import asyncio import json import tempfile from pathlib import Path from types import SimpleNamespace from unittest.mock import AsyncMock import pytest import nanobot.channels.weixin as weixin_mod from nanobot.bus.queue import MessageBus from nanobot.channels.weixin import ( ITEM_IMAGE, ITEM_TEXT, MESSAGE_TYPE_BOT, WEIXIN_CHANNEL_VERSION, _decrypt_aes_ecb, _encrypt_aes_ecb, WeixinChannel, WeixinConfig, ) def _make_channel() -> tuple[WeixinChannel, MessageBus]: bus = MessageBus() channel = WeixinChannel( WeixinConfig( enabled=True, allow_from=["*"], state_dir=tempfile.mkdtemp(prefix="nanobot-weixin-test-"), ), bus, ) return channel, bus def test_make_headers_includes_route_tag_when_configured() -> None: bus = MessageBus() channel = WeixinChannel( WeixinConfig(enabled=True, allow_from=["*"], route_tag=123), bus, ) channel._token = "token" headers = channel._make_headers() assert headers["Authorization"] == "Bearer token" assert headers["SKRouteTag"] == "123" assert headers["iLink-App-Id"] == "bot" assert headers["iLink-App-ClientVersion"] == str((2 << 16) | (1 << 8) | 1) def test_channel_version_matches_reference_plugin_version() -> None: pkg = json.loads(Path("package/package.json").read_text()) assert WEIXIN_CHANNEL_VERSION == pkg["version"] def test_save_and_load_state_persists_context_tokens(tmp_path) -> None: bus = MessageBus() channel = WeixinChannel( WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)), bus, ) channel._token = "token" channel._get_updates_buf = "cursor" channel._context_tokens = {"wx-user": "ctx-1"} channel._save_state() saved = json.loads((tmp_path / "account.json").read_text()) assert saved["context_tokens"] == {"wx-user": "ctx-1"} restored = WeixinChannel( WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)), bus, ) assert restored._load_state() is True assert restored._context_tokens == {"wx-user": "ctx-1"} @pytest.mark.asyncio async def test_process_message_deduplicates_inbound_ids() -> None: channel, bus = _make_channel() msg = { "message_type": 1, "message_id": "m1", "from_user_id": "wx-user", "context_token": "ctx-1", "item_list": [ {"type": ITEM_TEXT, "text_item": {"text": "hello"}}, ], } await channel._process_message(msg) first = await asyncio.wait_for(bus.consume_inbound(), timeout=1.0) await channel._process_message(msg) assert first.sender_id == "wx-user" assert first.chat_id == "wx-user" assert first.content == "hello" assert bus.inbound_size == 0 @pytest.mark.asyncio async def test_process_message_caches_context_token_and_send_uses_it() -> None: channel, _bus = _make_channel() channel._client = object() channel._token = "token" channel._send_text = AsyncMock() await channel._process_message( { "message_type": 1, "message_id": "m2", "from_user_id": "wx-user", "context_token": "ctx-2", "item_list": [ {"type": ITEM_TEXT, "text_item": {"text": "ping"}}, ], } ) await channel.send( type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})() ) channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-2") @pytest.mark.asyncio async def test_process_message_persists_context_token_to_state_file(tmp_path) -> None: bus = MessageBus() channel = WeixinChannel( WeixinConfig(enabled=True, allow_from=["*"], state_dir=str(tmp_path)), bus, ) await channel._process_message( { "message_type": 1, "message_id": "m2b", "from_user_id": "wx-user", "context_token": "ctx-2b", "item_list": [ {"type": ITEM_TEXT, "text_item": {"text": "ping"}}, ], } ) saved = json.loads((tmp_path / "account.json").read_text()) assert saved["context_tokens"] == {"wx-user": "ctx-2b"} @pytest.mark.asyncio async def test_process_message_extracts_media_and_preserves_paths() -> None: channel, bus = _make_channel() channel._download_media_item = AsyncMock(return_value="/tmp/test.jpg") await channel._process_message( { "message_type": 1, "message_id": "m3", "from_user_id": "wx-user", "context_token": "ctx-3", "item_list": [ {"type": ITEM_IMAGE, "image_item": {"media": {"encrypt_query_param": "x"}}}, ], } ) inbound = await asyncio.wait_for(bus.consume_inbound(), timeout=1.0) assert "[image]" in inbound.content assert "/tmp/test.jpg" in inbound.content assert inbound.media == ["/tmp/test.jpg"] @pytest.mark.asyncio async def test_send_without_context_token_does_not_send_text() -> None: channel, _bus = _make_channel() channel._client = object() channel._token = "token" channel._send_text = AsyncMock() await channel.send( type("Msg", (), {"chat_id": "unknown-user", "content": "pong", "media": [], "metadata": {}})() ) channel._send_text.assert_not_awaited() @pytest.mark.asyncio async def test_send_does_not_send_when_session_is_paused() -> None: channel, _bus = _make_channel() channel._client = object() channel._token = "token" channel._context_tokens["wx-user"] = "ctx-2" channel._pause_session(60) channel._send_text = AsyncMock() await channel.send( type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})() ) channel._send_text.assert_not_awaited() @pytest.mark.asyncio async def test_poll_once_pauses_session_on_expired_errcode() -> None: channel, _bus = _make_channel() channel._client = SimpleNamespace(timeout=None) channel._token = "token" channel._api_post = AsyncMock(return_value={"ret": 0, "errcode": -14, "errmsg": "expired"}) await channel._poll_once() assert channel._session_pause_remaining_s() > 0 @pytest.mark.asyncio async def test_qr_login_refreshes_expired_qr_and_then_succeeds() -> None: channel, _bus = _make_channel() channel._running = True channel._save_state = lambda: None channel._print_qr_code = lambda url: None channel._api_get = AsyncMock( side_effect=[ {"qrcode": "qr-1", "qrcode_img_content": "url-1"}, {"status": "expired"}, {"qrcode": "qr-2", "qrcode_img_content": "url-2"}, { "status": "confirmed", "bot_token": "token-2", "ilink_bot_id": "bot-2", "baseurl": "https://example.test", "ilink_user_id": "wx-user", }, ] ) ok = await channel._qr_login() assert ok is True assert channel._token == "token-2" assert channel.config.base_url == "https://example.test" @pytest.mark.asyncio async def test_qr_login_returns_false_after_too_many_expired_qr_codes() -> None: channel, _bus = _make_channel() channel._running = True channel._print_qr_code = lambda url: None channel._api_get = AsyncMock( side_effect=[ {"qrcode": "qr-1", "qrcode_img_content": "url-1"}, {"status": "expired"}, {"qrcode": "qr-2", "qrcode_img_content": "url-2"}, {"status": "expired"}, {"qrcode": "qr-3", "qrcode_img_content": "url-3"}, {"status": "expired"}, {"qrcode": "qr-4", "qrcode_img_content": "url-4"}, {"status": "expired"}, ] ) ok = await channel._qr_login() assert ok is False @pytest.mark.asyncio async def test_process_message_skips_bot_messages() -> None: channel, bus = _make_channel() await channel._process_message( { "message_type": MESSAGE_TYPE_BOT, "message_id": "m4", "from_user_id": "wx-user", "item_list": [ {"type": ITEM_TEXT, "text_item": {"text": "hello"}}, ], } ) assert bus.inbound_size == 0 class _DummyHttpResponse: def __init__(self, *, headers: dict[str, str] | None = None, status_code: int = 200) -> None: self.headers = headers or {} self.status_code = status_code def raise_for_status(self) -> None: return None @pytest.mark.asyncio async def test_send_media_uses_upload_full_url_when_present(tmp_path) -> None: channel, _bus = _make_channel() media_file = tmp_path / "photo.jpg" media_file.write_bytes(b"hello-weixin") cdn_post = AsyncMock(return_value=_DummyHttpResponse(headers={"x-encrypted-param": "dl-param"})) channel._client = SimpleNamespace(post=cdn_post) channel._api_post = AsyncMock( side_effect=[ { "upload_full_url": "https://upload-full.example.test/path?foo=bar", "upload_param": "should-not-be-used", }, {"ret": 0}, ] ) await channel._send_media_file("wx-user", str(media_file), "ctx-1") # first POST call is CDN upload cdn_url = cdn_post.await_args_list[0].args[0] assert cdn_url == "https://upload-full.example.test/path?foo=bar" @pytest.mark.asyncio async def test_send_media_falls_back_to_upload_param_url(tmp_path) -> None: channel, _bus = _make_channel() media_file = tmp_path / "photo.jpg" media_file.write_bytes(b"hello-weixin") cdn_post = AsyncMock(return_value=_DummyHttpResponse(headers={"x-encrypted-param": "dl-param"})) channel._client = SimpleNamespace(post=cdn_post) channel._api_post = AsyncMock( side_effect=[ {"upload_param": "enc-need-fallback"}, {"ret": 0}, ] ) await channel._send_media_file("wx-user", str(media_file), "ctx-1") cdn_url = cdn_post.await_args_list[0].args[0] assert cdn_url.startswith(f"{channel.config.cdn_base_url}/upload?encrypted_query_param=enc-need-fallback") assert "&filekey=" in cdn_url def test_decrypt_aes_ecb_strips_valid_pkcs7_padding() -> None: key_b64 = "MDEyMzQ1Njc4OWFiY2RlZg==" # base64("0123456789abcdef") plaintext = b"hello-weixin-padding" ciphertext = _encrypt_aes_ecb(plaintext, key_b64) decrypted = _decrypt_aes_ecb(ciphertext, key_b64) assert decrypted == plaintext class _DummyDownloadResponse: def __init__(self, content: bytes, status_code: int = 200) -> None: self.content = content self.status_code = status_code def raise_for_status(self) -> None: return None @pytest.mark.asyncio async def test_download_media_item_uses_full_url_when_present(tmp_path) -> None: channel, _bus = _make_channel() weixin_mod.get_media_dir = lambda _name: tmp_path full_url = "https://cdn.example.test/download/full" channel._client = SimpleNamespace( get=AsyncMock(return_value=_DummyDownloadResponse(content=b"raw-image-bytes")) ) item = { "media": { "full_url": full_url, "encrypt_query_param": "enc-fallback-should-not-be-used", }, } saved_path = await channel._download_media_item(item, "image") assert saved_path is not None assert Path(saved_path).read_bytes() == b"raw-image-bytes" channel._client.get.assert_awaited_once_with(full_url) @pytest.mark.asyncio async def test_download_media_item_falls_back_to_encrypt_query_param(tmp_path) -> None: channel, _bus = _make_channel() weixin_mod.get_media_dir = lambda _name: tmp_path channel._client = SimpleNamespace( get=AsyncMock(return_value=_DummyDownloadResponse(content=b"fallback-bytes")) ) item = {"media": {"encrypt_query_param": "enc-fallback"}} saved_path = await channel._download_media_item(item, "image") assert saved_path is not None assert Path(saved_path).read_bytes() == b"fallback-bytes" called_url = channel._client.get.await_args_list[0].args[0] assert called_url.startswith(f"{channel.config.cdn_base_url}/download?encrypted_query_param=enc-fallback")