diff --git a/nanobot/channels/weixin.py b/nanobot/channels/weixin.py index a8a4a636d..e572d68a2 100644 --- a/nanobot/channels/weixin.py +++ b/nanobot/channels/weixin.py @@ -57,6 +57,7 @@ BASE_INFO: dict[str, str] = {"channel_version": "1.0.2"} # Session-expired error code ERRCODE_SESSION_EXPIRED = -14 +SESSION_PAUSE_DURATION_S = 60 * 60 # Retry constants (matching the reference plugin's monitor.ts) MAX_CONSECUTIVE_FAILURES = 3 @@ -120,6 +121,7 @@ class WeixinChannel(BaseChannel): self._token: str = "" self._poll_task: asyncio.Task | None = None self._next_poll_timeout_s: int = DEFAULT_LONG_POLL_TIMEOUT_S + self._session_pause_until: float = 0.0 # ------------------------------------------------------------------ # State persistence @@ -395,7 +397,34 @@ class WeixinChannel(BaseChannel): # Polling (matches monitor.ts monitorWeixinProvider) # ------------------------------------------------------------------ + def _pause_session(self, duration_s: int = SESSION_PAUSE_DURATION_S) -> None: + self._session_pause_until = time.time() + duration_s + + def _session_pause_remaining_s(self) -> int: + remaining = int(self._session_pause_until - time.time()) + if remaining <= 0: + self._session_pause_until = 0.0 + return 0 + return remaining + + def _assert_session_active(self) -> None: + remaining = self._session_pause_remaining_s() + if remaining > 0: + remaining_min = max((remaining + 59) // 60, 1) + raise RuntimeError( + f"WeChat session paused, {remaining_min} min remaining (errcode {ERRCODE_SESSION_EXPIRED})" + ) + async def _poll_once(self) -> None: + remaining = self._session_pause_remaining_s() + if remaining > 0: + logger.warning( + "WeChat session paused, waiting {} min before next poll.", + max((remaining + 59) // 60, 1), + ) + await asyncio.sleep(remaining) + return + body: dict[str, Any] = { "get_updates_buf": self._get_updates_buf, "base_info": BASE_INFO, @@ -414,11 +443,13 @@ class WeixinChannel(BaseChannel): if is_error: if errcode == ERRCODE_SESSION_EXPIRED or ret == ERRCODE_SESSION_EXPIRED: + self._pause_session() + remaining = self._session_pause_remaining_s() logger.warning( - "WeChat session expired (errcode {}). Pausing 60 min.", + "WeChat session expired (errcode {}). Pausing {} min.", errcode, + max((remaining + 59) // 60, 1), ) - await asyncio.sleep(3600) return raise RuntimeError( f"getUpdates failed: ret={ret} errcode={errcode} errmsg={data.get('errmsg', '')}" @@ -654,6 +685,11 @@ class WeixinChannel(BaseChannel): if not self._client or not self._token: logger.warning("WeChat client not initialized or not authenticated") return + try: + self._assert_session_active() + except RuntimeError as e: + logger.warning("WeChat send blocked: {}", e) + return content = msg.content.strip() ctx_token = self._context_tokens.get(msg.chat_id, "") diff --git a/tests/channels/test_weixin_channel.py b/tests/channels/test_weixin_channel.py index 6107d117b..0a01b72c7 100644 --- a/tests/channels/test_weixin_channel.py +++ b/tests/channels/test_weixin_channel.py @@ -1,4 +1,5 @@ import asyncio +from types import SimpleNamespace from unittest.mock import AsyncMock import pytest @@ -123,6 +124,34 @@ async def test_send_without_context_token_does_not_send_text() -> None: channel._send_text.assert_not_awaited() +@pytest.mark.asyncio +async def test_send_does_not_send_when_session_is_paused() -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-2" + channel._pause_session(60) + channel._send_text = AsyncMock() + + await channel.send( + type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})() + ) + + channel._send_text.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_poll_once_pauses_session_on_expired_errcode() -> None: + channel, _bus = _make_channel() + channel._client = SimpleNamespace(timeout=None) + channel._token = "token" + channel._api_post = AsyncMock(return_value={"ret": 0, "errcode": -14, "errmsg": "expired"}) + + await channel._poll_once() + + assert channel._session_pause_remaining_s() > 0 + + @pytest.mark.asyncio async def test_process_message_skips_bot_messages() -> None: channel, bus = _make_channel()