mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-09 11:15:55 +00:00
fix(WeiXin): resolve polling issues in WeiXin plugin
- Prevent repeated retries on expired sessions in the polling thread - Stop sending messages to invalid agent sessions to eliminate noise logs and unnecessary requests
This commit is contained in:
parent
3a9d6ea536
commit
9c872c3458
@ -57,6 +57,7 @@ BASE_INFO: dict[str, str] = {"channel_version": "1.0.2"}
|
|||||||
|
|
||||||
# Session-expired error code
|
# Session-expired error code
|
||||||
ERRCODE_SESSION_EXPIRED = -14
|
ERRCODE_SESSION_EXPIRED = -14
|
||||||
|
SESSION_PAUSE_DURATION_S = 60 * 60
|
||||||
|
|
||||||
# Retry constants (matching the reference plugin's monitor.ts)
|
# Retry constants (matching the reference plugin's monitor.ts)
|
||||||
MAX_CONSECUTIVE_FAILURES = 3
|
MAX_CONSECUTIVE_FAILURES = 3
|
||||||
@ -120,6 +121,7 @@ class WeixinChannel(BaseChannel):
|
|||||||
self._token: str = ""
|
self._token: str = ""
|
||||||
self._poll_task: asyncio.Task | None = None
|
self._poll_task: asyncio.Task | None = None
|
||||||
self._next_poll_timeout_s: int = DEFAULT_LONG_POLL_TIMEOUT_S
|
self._next_poll_timeout_s: int = DEFAULT_LONG_POLL_TIMEOUT_S
|
||||||
|
self._session_pause_until: float = 0.0
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# State persistence
|
# State persistence
|
||||||
@ -395,7 +397,34 @@ class WeixinChannel(BaseChannel):
|
|||||||
# Polling (matches monitor.ts monitorWeixinProvider)
|
# Polling (matches monitor.ts monitorWeixinProvider)
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _pause_session(self, duration_s: int = SESSION_PAUSE_DURATION_S) -> None:
|
||||||
|
self._session_pause_until = time.time() + duration_s
|
||||||
|
|
||||||
|
def _session_pause_remaining_s(self) -> int:
|
||||||
|
remaining = int(self._session_pause_until - time.time())
|
||||||
|
if remaining <= 0:
|
||||||
|
self._session_pause_until = 0.0
|
||||||
|
return 0
|
||||||
|
return remaining
|
||||||
|
|
||||||
|
def _assert_session_active(self) -> None:
|
||||||
|
remaining = self._session_pause_remaining_s()
|
||||||
|
if remaining > 0:
|
||||||
|
remaining_min = max((remaining + 59) // 60, 1)
|
||||||
|
raise RuntimeError(
|
||||||
|
f"WeChat session paused, {remaining_min} min remaining (errcode {ERRCODE_SESSION_EXPIRED})"
|
||||||
|
)
|
||||||
|
|
||||||
async def _poll_once(self) -> None:
|
async def _poll_once(self) -> None:
|
||||||
|
remaining = self._session_pause_remaining_s()
|
||||||
|
if remaining > 0:
|
||||||
|
logger.warning(
|
||||||
|
"WeChat session paused, waiting {} min before next poll.",
|
||||||
|
max((remaining + 59) // 60, 1),
|
||||||
|
)
|
||||||
|
await asyncio.sleep(remaining)
|
||||||
|
return
|
||||||
|
|
||||||
body: dict[str, Any] = {
|
body: dict[str, Any] = {
|
||||||
"get_updates_buf": self._get_updates_buf,
|
"get_updates_buf": self._get_updates_buf,
|
||||||
"base_info": BASE_INFO,
|
"base_info": BASE_INFO,
|
||||||
@ -414,11 +443,13 @@ class WeixinChannel(BaseChannel):
|
|||||||
|
|
||||||
if is_error:
|
if is_error:
|
||||||
if errcode == ERRCODE_SESSION_EXPIRED or ret == ERRCODE_SESSION_EXPIRED:
|
if errcode == ERRCODE_SESSION_EXPIRED or ret == ERRCODE_SESSION_EXPIRED:
|
||||||
|
self._pause_session()
|
||||||
|
remaining = self._session_pause_remaining_s()
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"WeChat session expired (errcode {}). Pausing 60 min.",
|
"WeChat session expired (errcode {}). Pausing {} min.",
|
||||||
errcode,
|
errcode,
|
||||||
|
max((remaining + 59) // 60, 1),
|
||||||
)
|
)
|
||||||
await asyncio.sleep(3600)
|
|
||||||
return
|
return
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"getUpdates failed: ret={ret} errcode={errcode} errmsg={data.get('errmsg', '')}"
|
f"getUpdates failed: ret={ret} errcode={errcode} errmsg={data.get('errmsg', '')}"
|
||||||
@ -654,6 +685,11 @@ class WeixinChannel(BaseChannel):
|
|||||||
if not self._client or not self._token:
|
if not self._client or not self._token:
|
||||||
logger.warning("WeChat client not initialized or not authenticated")
|
logger.warning("WeChat client not initialized or not authenticated")
|
||||||
return
|
return
|
||||||
|
try:
|
||||||
|
self._assert_session_active()
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.warning("WeChat send blocked: {}", e)
|
||||||
|
return
|
||||||
|
|
||||||
content = msg.content.strip()
|
content = msg.content.strip()
|
||||||
ctx_token = self._context_tokens.get(msg.chat_id, "")
|
ctx_token = self._context_tokens.get(msg.chat_id, "")
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -123,6 +124,34 @@ async def test_send_without_context_token_does_not_send_text() -> None:
|
|||||||
channel._send_text.assert_not_awaited()
|
channel._send_text.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_does_not_send_when_session_is_paused() -> None:
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._client = object()
|
||||||
|
channel._token = "token"
|
||||||
|
channel._context_tokens["wx-user"] = "ctx-2"
|
||||||
|
channel._pause_session(60)
|
||||||
|
channel._send_text = AsyncMock()
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})()
|
||||||
|
)
|
||||||
|
|
||||||
|
channel._send_text.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_poll_once_pauses_session_on_expired_errcode() -> None:
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._client = SimpleNamespace(timeout=None)
|
||||||
|
channel._token = "token"
|
||||||
|
channel._api_post = AsyncMock(return_value={"ret": 0, "errcode": -14, "errmsg": "expired"})
|
||||||
|
|
||||||
|
await channel._poll_once()
|
||||||
|
|
||||||
|
assert channel._session_pause_remaining_s() > 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_process_message_skips_bot_messages() -> None:
|
async def test_process_message_skips_bot_messages() -> None:
|
||||||
channel, bus = _make_channel()
|
channel, bus = _make_channel()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user