From 1f5492ea9e33d431852b967b058d2c48d40ef8fb Mon Sep 17 00:00:00 2001 From: xcosmosbox <2162381070@qq.com> Date: Tue, 24 Mar 2026 14:52:13 +0800 Subject: [PATCH] fix(WeiXin): persist _context_tokens with account.json to restore conversations after restart --- nanobot/channels/weixin.py | 11 ++++++ tests/channels/test_weixin_channel.py | 56 ++++++++++++++++++++++++++- 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/nanobot/channels/weixin.py b/nanobot/channels/weixin.py index e572d68a2..115cca7ff 100644 --- a/nanobot/channels/weixin.py +++ b/nanobot/channels/weixin.py @@ -147,6 +147,15 @@ class WeixinChannel(BaseChannel): data = json.loads(state_file.read_text()) self._token = data.get("token", "") self._get_updates_buf = data.get("get_updates_buf", "") + context_tokens = data.get("context_tokens", {}) + if isinstance(context_tokens, dict): + self._context_tokens = { + str(user_id): str(token) + for user_id, token in context_tokens.items() + if str(user_id).strip() and str(token).strip() + } + else: + self._context_tokens = {} base_url = data.get("base_url", "") if base_url: self.config.base_url = base_url @@ -161,6 +170,7 @@ class WeixinChannel(BaseChannel): data = { "token": self._token, "get_updates_buf": self._get_updates_buf, + "context_tokens": self._context_tokens, "base_url": self.config.base_url, } state_file.write_text(json.dumps(data, ensure_ascii=False)) @@ -502,6 +512,7 @@ class WeixinChannel(BaseChannel): ctx_token = msg.get("context_token", "") if ctx_token: self._context_tokens[from_user_id] = ctx_token + self._save_state() # Parse item_list (WeixinMessage.item_list — types.ts:161) item_list: list[dict] = msg.get("item_list") or [] diff --git a/tests/channels/test_weixin_channel.py b/tests/channels/test_weixin_channel.py index 0a01b72c7..36e56315b 100644 --- a/tests/channels/test_weixin_channel.py +++ b/tests/channels/test_weixin_channel.py @@ -1,4 +1,6 @@ import asyncio +import json +import tempfile from types import SimpleNamespace from unittest.mock import AsyncMock @@ -17,7 +19,11 @@ from nanobot.channels.weixin import ( def _make_channel() -> tuple[WeixinChannel, MessageBus]: bus = MessageBus() channel = WeixinChannel( - WeixinConfig(enabled=True, allow_from=["*"]), + WeixinConfig( + enabled=True, + allow_from=["*"], + state_dir=tempfile.mkdtemp(prefix="nanobot-weixin-test-"), + ), bus, ) return channel, bus @@ -37,6 +43,30 @@ def test_make_headers_includes_route_tag_when_configured() -> None: assert headers["SKRouteTag"] == "123" +def test_save_and_load_state_persists_context_tokens(tmp_path) -> None: + bus = MessageBus() + channel = WeixinChannel( + WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)), + bus, + ) + channel._token = "token" + channel._get_updates_buf = "cursor" + channel._context_tokens = {"wx-user": "ctx-1"} + + channel._save_state() + + saved = json.loads((tmp_path / "account.json").read_text()) + assert saved["context_tokens"] == {"wx-user": "ctx-1"} + + restored = WeixinChannel( + WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)), + bus, + ) + + assert restored._load_state() is True + assert restored._context_tokens == {"wx-user": "ctx-1"} + + @pytest.mark.asyncio async def test_process_message_deduplicates_inbound_ids() -> None: channel, bus = _make_channel() @@ -86,6 +116,30 @@ async def test_process_message_caches_context_token_and_send_uses_it() -> None: channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-2") +@pytest.mark.asyncio +async def test_process_message_persists_context_token_to_state_file(tmp_path) -> None: + bus = MessageBus() + channel = WeixinChannel( + WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)), + bus, + ) + + await channel._process_message( + { + "message_type": 1, + "message_id": "m2b", + "from_user_id": "wx-user", + "context_token": "ctx-2b", + "item_list": [ + {"type": ITEM_TEXT, "text_item": {"text": "ping"}}, + ], + } + ) + + saved = json.loads((tmp_path / "account.json").read_text()) + assert saved["context_tokens"] == {"wx-user": "ctx-2b"} + + @pytest.mark.asyncio async def test_process_message_extracts_media_and_preserves_paths() -> None: channel, bus = _make_channel()