diff --git a/nanobot/channels/weixin.py b/nanobot/channels/weixin.py index 74d3a4736..4341f21d1 100644 --- a/nanobot/channels/weixin.py +++ b/nanobot/channels/weixin.py @@ -15,6 +15,7 @@ import hashlib import json import mimetypes import os +import random import re import time import uuid @@ -102,18 +103,23 @@ MAX_QR_REFRESH_COUNT = 3 TYPING_STATUS_TYPING = 1 TYPING_STATUS_CANCEL = 2 TYPING_TICKET_TTL_S = 24 * 60 * 60 +TYPING_KEEPALIVE_INTERVAL_S = 5 +CONFIG_CACHE_INITIAL_RETRY_S = 2 +CONFIG_CACHE_MAX_RETRY_S = 60 * 60 # Default long-poll timeout; overridden by server via longpolling_timeout_ms. DEFAULT_LONG_POLL_TIMEOUT_S = 35 -# Media-type codes for getuploadurl (1=image, 2=video, 3=file) +# Media-type codes for getuploadurl (1=image, 2=video, 3=file, 4=voice) UPLOAD_MEDIA_IMAGE = 1 UPLOAD_MEDIA_VIDEO = 2 UPLOAD_MEDIA_FILE = 3 +UPLOAD_MEDIA_VOICE = 4 # File extensions considered as images / videos for outbound media _IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".tiff", ".ico", ".svg"} _VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv"} +_VOICE_EXTS = {".mp3", ".wav", ".amr", ".silk", ".ogg", ".m4a", ".aac", ".flac"} def _has_downloadable_media_locator(media: dict[str, Any] | None) -> bool: @@ -167,7 +173,7 @@ class WeixinChannel(BaseChannel): self._poll_task: asyncio.Task | None = None self._next_poll_timeout_s: int = DEFAULT_LONG_POLL_TIMEOUT_S self._session_pause_until: float = 0.0 - self._typing_tickets: dict[str, tuple[str, float]] = {} + self._typing_tickets: dict[str, dict[str, Any]] = {} # ------------------------------------------------------------------ # State persistence @@ -339,7 +345,16 @@ class WeixinChannel(BaseChannel): params={"qrcode": qrcode_id}, auth=False, ) - except httpx.TimeoutException: + except Exception as e: + if self._is_retryable_qr_poll_error(e): + logger.warning("QR polling temporary error, will retry: {}", e) + await asyncio.sleep(1) + continue + raise + + if not isinstance(status_data, dict): + logger.warning("QR polling got non-object response, continue waiting") + await asyncio.sleep(1) continue status = status_data.get("status", "") @@ -408,6 +423,16 @@ class WeixinChannel(BaseChannel): return False + @staticmethod + def _is_retryable_qr_poll_error(err: Exception) -> bool: + if isinstance(err, httpx.TimeoutException | httpx.TransportError): + return True + if isinstance(err, httpx.HTTPStatusError): + status_code = err.response.status_code if err.response is not None else 0 + if status_code >= 500: + return True + return False + @staticmethod def _print_qr_code(url: str) -> None: try: @@ -858,13 +883,11 @@ class WeixinChannel(BaseChannel): # ------------------------------------------------------------------ async def _get_typing_ticket(self, user_id: str, context_token: str = "") -> str: - """Get typing ticket for a user with simple per-user TTL cache.""" + """Get typing ticket with per-user refresh + failure backoff cache.""" now = time.time() - cached = self._typing_tickets.get(user_id) - if cached: - ticket, expires_at = cached - if ticket and now < expires_at: - return ticket + entry = self._typing_tickets.get(user_id) + if entry and now < float(entry.get("next_fetch_at", 0)): + return str(entry.get("ticket", "") or "") body: dict[str, Any] = { "ilink_user_id": user_id, @@ -874,9 +897,27 @@ class WeixinChannel(BaseChannel): data = await self._api_post("ilink/bot/getconfig", body) if data.get("ret", 0) == 0: ticket = str(data.get("typing_ticket", "") or "") - if ticket: - self._typing_tickets[user_id] = (ticket, now + TYPING_TICKET_TTL_S) - return ticket + self._typing_tickets[user_id] = { + "ticket": ticket, + "ever_succeeded": True, + "next_fetch_at": now + (random.random() * TYPING_TICKET_TTL_S), + "retry_delay_s": CONFIG_CACHE_INITIAL_RETRY_S, + } + return ticket + + prev_delay = float(entry.get("retry_delay_s", CONFIG_CACHE_INITIAL_RETRY_S)) if entry else CONFIG_CACHE_INITIAL_RETRY_S + next_delay = min(prev_delay * 2, CONFIG_CACHE_MAX_RETRY_S) + if entry: + entry["next_fetch_at"] = now + next_delay + entry["retry_delay_s"] = next_delay + return str(entry.get("ticket", "") or "") + + self._typing_tickets[user_id] = { + "ticket": "", + "ever_succeeded": False, + "next_fetch_at": now + CONFIG_CACHE_INITIAL_RETRY_S, + "retry_delay_s": CONFIG_CACHE_INITIAL_RETRY_S, + } return "" async def _send_typing(self, user_id: str, typing_ticket: str, status: int) -> None: @@ -891,6 +932,16 @@ class WeixinChannel(BaseChannel): } await self._api_post("ilink/bot/sendtyping", body) + async def _typing_keepalive_loop(self, user_id: str, typing_ticket: str, stop_event: asyncio.Event) -> None: + while not stop_event.is_set(): + await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_S) + if stop_event.is_set(): + break + try: + await self._send_typing(user_id, typing_ticket, TYPING_STATUS_TYPING) + except Exception as e: + logger.debug("WeChat sendtyping(keepalive) failed for {}: {}", user_id, e) + async def send(self, msg: OutboundMessage) -> None: if not self._client or not self._token: logger.warning("WeChat client not initialized or not authenticated") @@ -923,6 +974,13 @@ class WeixinChannel(BaseChannel): except Exception as e: logger.debug("WeChat sendtyping(start) failed for {}: {}", msg.chat_id, e) + typing_keepalive_stop = asyncio.Event() + typing_keepalive_task: asyncio.Task | None = None + if typing_ticket: + typing_keepalive_task = asyncio.create_task( + self._typing_keepalive_loop(msg.chat_id, typing_ticket, typing_keepalive_stop) + ) + try: # --- Send media files first (following Telegram channel pattern) --- for media_path in (msg.media or []): @@ -947,6 +1005,14 @@ class WeixinChannel(BaseChannel): logger.error("Error sending WeChat message: {}", e) raise finally: + if typing_keepalive_task: + typing_keepalive_stop.set() + typing_keepalive_task.cancel() + try: + await typing_keepalive_task + except asyncio.CancelledError: + pass + if typing_ticket: try: await self._send_typing(msg.chat_id, typing_ticket, TYPING_STATUS_CANCEL) @@ -1025,6 +1091,10 @@ class WeixinChannel(BaseChannel): upload_type = UPLOAD_MEDIA_VIDEO item_type = ITEM_VIDEO item_key = "video_item" + elif ext in _VOICE_EXTS: + upload_type = UPLOAD_MEDIA_VOICE + item_type = ITEM_VOICE + item_key = "voice_item" else: upload_type = UPLOAD_MEDIA_FILE item_type = ITEM_FILE diff --git a/tests/channels/test_weixin_channel.py b/tests/channels/test_weixin_channel.py index 7701ad597..c4e5cf552 100644 --- a/tests/channels/test_weixin_channel.py +++ b/tests/channels/test_weixin_channel.py @@ -6,6 +6,7 @@ from types import SimpleNamespace from unittest.mock import AsyncMock import pytest +import httpx import nanobot.channels.weixin as weixin_mod from nanobot.bus.queue import MessageBus @@ -595,6 +596,158 @@ async def test_send_media_falls_back_to_upload_param_url(tmp_path) -> None: assert "&filekey=" in cdn_url +@pytest.mark.asyncio +async def test_send_media_voice_file_uses_voice_item_and_voice_upload_type(tmp_path) -> None: + channel, _bus = _make_channel() + + media_file = tmp_path / "voice.mp3" + media_file.write_bytes(b"voice-bytes") + + cdn_post = AsyncMock(return_value=_DummyHttpResponse(headers={"x-encrypted-param": "voice-dl-param"})) + channel._client = SimpleNamespace(post=cdn_post) + channel._api_post = AsyncMock( + side_effect=[ + {"upload_full_url": "https://upload-full.example.test/voice?foo=bar"}, + {"ret": 0}, + ] + ) + + await channel._send_media_file("wx-user", str(media_file), "ctx-voice") + + getupload_body = channel._api_post.await_args_list[0].args[1] + assert getupload_body["media_type"] == 4 + + sendmessage_body = channel._api_post.await_args_list[1].args[1] + item = sendmessage_body["msg"]["item_list"][0] + assert item["type"] == 3 + assert "voice_item" in item + assert "file_item" not in item + assert item["voice_item"]["media"]["encrypt_query_param"] == "voice-dl-param" + + +@pytest.mark.asyncio +async def test_send_typing_uses_keepalive_until_send_finishes() -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-typing-loop" + async def _api_post_side_effect(endpoint: str, _body: dict | None = None, *, auth: bool = True): + if endpoint == "ilink/bot/getconfig": + return {"ret": 0, "typing_ticket": "ticket-keepalive"} + return {"ret": 0} + + channel._api_post = AsyncMock(side_effect=_api_post_side_effect) + + async def _slow_send_text(*_args, **_kwargs) -> None: + await asyncio.sleep(0.03) + + channel._send_text = AsyncMock(side_effect=_slow_send_text) + + old_interval = weixin_mod.TYPING_KEEPALIVE_INTERVAL_S + weixin_mod.TYPING_KEEPALIVE_INTERVAL_S = 0.01 + try: + await channel.send( + type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})() + ) + finally: + weixin_mod.TYPING_KEEPALIVE_INTERVAL_S = old_interval + + status_calls = [ + c.args[1]["status"] + for c in channel._api_post.await_args_list + if c.args and c.args[0] == "ilink/bot/sendtyping" + ] + assert status_calls.count(1) >= 2 + assert status_calls[-1] == 2 + + +@pytest.mark.asyncio +async def test_get_typing_ticket_failure_uses_backoff_and_cached_ticket(monkeypatch) -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + + now = {"value": 1000.0} + monkeypatch.setattr(weixin_mod.time, "time", lambda: now["value"]) + monkeypatch.setattr(weixin_mod.random, "random", lambda: 0.5) + + channel._api_post = AsyncMock(return_value={"ret": 0, "typing_ticket": "ticket-ok"}) + first = await channel._get_typing_ticket("wx-user", "ctx-1") + assert first == "ticket-ok" + + # force refresh window reached + now["value"] = now["value"] + (12 * 60 * 60) + 1 + channel._api_post = AsyncMock(return_value={"ret": 1, "errmsg": "temporary failure"}) + + # On refresh failure, should still return cached ticket and apply backoff. + second = await channel._get_typing_ticket("wx-user", "ctx-2") + assert second == "ticket-ok" + assert channel._api_post.await_count == 1 + + # Before backoff expiry, no extra fetch should happen. + now["value"] += 1 + third = await channel._get_typing_ticket("wx-user", "ctx-3") + assert third == "ticket-ok" + assert channel._api_post.await_count == 1 + + +@pytest.mark.asyncio +async def test_qr_login_treats_temporary_connect_error_as_wait_and_recovers() -> None: + channel, _bus = _make_channel() + channel._running = True + channel._save_state = lambda: None + channel._print_qr_code = lambda url: None + channel._fetch_qr_code = AsyncMock(return_value=("qr-1", "url-1")) + + request = httpx.Request("GET", "https://ilinkai.weixin.qq.com/ilink/bot/get_qrcode_status") + channel._api_get_with_base = AsyncMock( + side_effect=[ + httpx.ConnectError("temporary network", request=request), + { + "status": "confirmed", + "bot_token": "token-net-ok", + "ilink_bot_id": "bot-id", + "baseurl": "https://example.test", + "ilink_user_id": "wx-user", + }, + ] + ) + + ok = await channel._qr_login() + + assert ok is True + assert channel._token == "token-net-ok" + + +@pytest.mark.asyncio +async def test_qr_login_treats_5xx_gateway_response_error_as_wait_and_recovers() -> None: + channel, _bus = _make_channel() + channel._running = True + channel._save_state = lambda: None + channel._print_qr_code = lambda url: None + channel._fetch_qr_code = AsyncMock(return_value=("qr-1", "url-1")) + + request = httpx.Request("GET", "https://ilinkai.weixin.qq.com/ilink/bot/get_qrcode_status") + response = httpx.Response(status_code=524, request=request) + channel._api_get_with_base = AsyncMock( + side_effect=[ + httpx.HTTPStatusError("gateway timeout", request=request, response=response), + { + "status": "confirmed", + "bot_token": "token-5xx-ok", + "ilink_bot_id": "bot-id", + "baseurl": "https://example.test", + "ilink_user_id": "wx-user", + }, + ] + ) + + ok = await channel._qr_login() + + assert ok is True + assert channel._token == "token-5xx-ok" + + def test_decrypt_aes_ecb_strips_valid_pkcs7_padding() -> None: key_b64 = "MDEyMzQ1Njc4OWFiY2RlZg==" # base64("0123456789abcdef") plaintext = b"hello-weixin-padding"