From d5692bf94c39af6894ccd71d371af4e438f5f0d9 Mon Sep 17 00:00:00 2001 From: chengyongru <2755839590@qq.com> Date: Tue, 2 Jun 2026 00:52:54 +0800 Subject: [PATCH] fix(napcat): harden async handlers and action errors maintainer edit: track background handler tasks, surface failed OneBot actions, reject image redirects, and add focused unit coverage for group routing and edge cases. --- nanobot/channels/napcat.py | 56 +++++++-- tests/channels/test_napcat_channel.py | 172 ++++++++++++++++++++++++++ 2 files changed, 219 insertions(+), 9 deletions(-) create mode 100644 tests/channels/test_napcat_channel.py diff --git a/nanobot/channels/napcat.py b/nanobot/channels/napcat.py index 64c0ac8dd..c0d961f01 100644 --- a/nanobot/channels/napcat.py +++ b/nanobot/channels/napcat.py @@ -12,11 +12,12 @@ import uuid from collections import deque from pathlib import Path from typing import Annotated, Any, Literal + import aiohttp from loguru import logger from pydantic import Field - -from websockets.asyncio.client import ClientConnection, connect as ws_connect +from websockets.asyncio.client import ClientConnection +from websockets.asyncio.client import connect as ws_connect from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus @@ -75,6 +76,7 @@ class NapcatChannel(BaseChannel): self._pending: dict[str, asyncio.Future[dict[str, Any]]] = {} self._processed_ids: deque[int] = deque(maxlen=2000) self._bot_outbound_ids: deque[int] = deque(maxlen=2000) + self._background_tasks: set[asyncio.Task[None]] = set() # ------------------------------------------------------------------ # Lifecycle @@ -161,6 +163,12 @@ class NapcatChannel(BaseChannel): pass self._http = None self._fail_pending(RuntimeError("napcat: stopped")) + tasks = list(self._background_tasks) + for task in tasks: + task.cancel() + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + self._background_tasks.clear() def _fail_pending(self, err: BaseException) -> None: for fut in self._pending.values(): @@ -198,9 +206,24 @@ class NapcatChannel(BaseChannel): post_type = payload.get("post_type") if post_type == "message": - asyncio.create_task(self._on_message(payload)) + self._create_background_task(self._on_message(payload), "message") elif post_type == "notice": - asyncio.create_task(self._on_notice(payload)) # avoid deadlock + self._create_background_task(self._on_notice(payload), "notice") + + def _create_background_task(self, coro: Any, kind: str) -> None: + task = asyncio.create_task(coro) + self._background_tasks.add(task) + + def _done(done: asyncio.Task[None]) -> None: + self._background_tasks.discard(done) + try: + done.result() + except asyncio.CancelledError: + pass + except Exception as e: + logger.warning("napcat: {} handler failed: {}", kind, e) + + task.add_done_callback(_done) # ------------------------------------------------------------------ # Inbound: messages @@ -245,8 +268,6 @@ class NapcatChannel(BaseChannel): return chat_id = f"group:{group_id}" - label = nickname or str(user_id) - content = f"{label}: {text}" content = self._format_group_content( text=text, nickname=nickname, @@ -367,7 +388,14 @@ class NapcatChannel(BaseChannel): if group_id is None or user_id is None: return - nickname = await self._lookup_member_name(int(group_id), int(user_id)) + try: + group_id_int = int(group_id) + user_id_int = int(user_id) + except (TypeError, ValueError): + logger.warning("napcat: invalid group_increase ids group_id={} user_id={}", group_id, user_id) + return + + nickname = await self._lookup_member_name(group_id_int, user_id_int) # Note: this routes through is_allowed(). For group bots set # `allow_from: ["*"]` (or include the joining user's id) for welcomes @@ -467,7 +495,14 @@ class NapcatChannel(BaseChannel): await self._ws.send( json.dumps({"action": action, "params": params, "echo": echo}, ensure_ascii=False) ) - return await asyncio.wait_for(fut, timeout=timeout) + resp = await asyncio.wait_for(fut, timeout=timeout) + status = resp.get("status") + retcode = resp.get("retcode") + if (status and status != "ok") or (retcode not in (None, 0)): + raise RuntimeError( + f"napcat: action {action} failed status={status!r} retcode={retcode!r}" + ) + return resp finally: self._pending.pop(echo, None) @@ -503,7 +538,10 @@ class NapcatChannel(BaseChannel): pass try: - async with self._http.get(url, allow_redirects=True) as resp: + async with self._http.get(url, allow_redirects=False) as resp: + if 300 <= resp.status < 400: + logger.warning("napcat: image download redirect rejected url={}", url) + return None if resp.status >= 400: logger.warning("napcat: image download status={} url={}", resp.status, url) return None diff --git a/tests/channels/test_napcat_channel.py b/tests/channels/test_napcat_channel.py new file mode 100644 index 000000000..7ebc917b2 --- /dev/null +++ b/tests/channels/test_napcat_channel.py @@ -0,0 +1,172 @@ +import asyncio + +import pytest + +from nanobot.bus.queue import MessageBus +from nanobot.channels.napcat import NapcatChannel, NapcatConfig + + +class _FakeWs: + def __init__(self) -> None: + self.sent: list[str] = [] + + async def send(self, payload: str) -> None: + self.sent.append(payload) + + async def close(self) -> None: + pass + + +class _FakeContent: + def __init__(self, chunks: list[bytes]) -> None: + self._chunks = chunks + + async def iter_chunked(self, _size: int): + for chunk in self._chunks: + yield chunk + + +class _FakeResponse: + def __init__(self, status: int, chunks: list[bytes] | None = None) -> None: + self.status = status + self.content = _FakeContent(chunks or []) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + return None + + +class _FakeHttp: + def __init__(self, response: _FakeResponse) -> None: + self.response = response + self.calls: list[dict] = [] + + def get(self, url: str, **kwargs): + self.calls.append({"url": url, "kwargs": kwargs}) + return self.response + + +def _channel(config: NapcatConfig | None = None) -> NapcatChannel: + return NapcatChannel(config or NapcatConfig(allow_from=["*"]), MessageBus()) + + +@pytest.mark.asyncio +async def test_group_message_requires_mention_by_default() -> None: + channel = _channel(NapcatConfig(allow_from=["user1"], group_policy="mention")) + channel._self_id = 42 + + await channel._on_message( + { + "message_id": 1, + "message_type": "group", + "group_id": 100, + "user_id": "user1", + "sender": {"nickname": "Alice"}, + "message": [{"type": "text", "data": {"text": "hello"}}], + } + ) + + assert channel.bus.inbound_size == 0 + + +@pytest.mark.asyncio +async def test_group_mention_routes_with_sender_label() -> None: + channel = _channel(NapcatConfig(allow_from=["user1"], group_policy="mention")) + channel._self_id = 42 + + await channel._on_message( + { + "message_id": 1, + "message_type": "group", + "group_id": 100, + "user_id": "user1", + "sender": {"card": "Alice"}, + "message": [ + {"type": "at", "data": {"qq": "42"}}, + {"type": "text", "data": {"text": "hello"}}, + ], + } + ) + + msg = await channel.bus.consume_inbound() + assert msg.sender_id == "user1" + assert msg.chat_id == "group:100" + assert msg.content == "Alice: hello" + assert msg.metadata["message_id"] == 1 + + +@pytest.mark.asyncio +async def test_call_action_raises_on_onebot_failure_and_clears_pending() -> None: + channel = _channel() + channel._ws = _FakeWs() + + task = asyncio.create_task(channel._call_action("send_msg", {"message": []})) + while not channel._pending: + await asyncio.sleep(0) + fut = next(iter(channel._pending.values())) + fut.set_result({"status": "failed", "retcode": 1400, "wording": "bad request"}) + + with pytest.raises(RuntimeError, match="action send_msg failed"): + await task + assert channel._pending == {} + + +@pytest.mark.asyncio +async def test_notice_with_invalid_ids_is_ignored(monkeypatch) -> None: + channel = _channel() + + async def fail_lookup(*_args, **_kwargs): + raise AssertionError("lookup should not be called for invalid ids") + + monkeypatch.setattr(channel, "_lookup_member_name", fail_lookup) + + await channel._on_notice( + { + "notice_type": "group_increase", + "group_id": "not-an-int", + "user_id": "user1", + } + ) + + assert channel.bus.inbound_size == 0 + + +@pytest.mark.asyncio +async def test_download_image_rejects_redirects(tmp_path, monkeypatch) -> None: + channel = _channel() + channel._media_root = tmp_path + channel._http = _FakeHttp(_FakeResponse(status=302)) + monkeypatch.setattr( + "nanobot.channels.napcat.validate_url_target", + lambda _url: (True, ""), + ) + + result = await channel._download_image({"url": "https://example.com/a.png", "file": "a.png"}) + + assert result is None + assert channel._http.calls == [ + {"url": "https://example.com/a.png", "kwargs": {"allow_redirects": False}} + ] + assert list(tmp_path.iterdir()) == [] + + +@pytest.mark.asyncio +async def test_dispatch_tracks_and_discards_background_tasks() -> None: + channel = _channel() + seen = asyncio.Event() + + async def fake_on_message(_payload): + seen.set() + + channel._on_message = fake_on_message + + await channel._dispatch_frame( + '{"post_type":"message","message_type":"private","user_id":"user1","message":"hi"}' + ) + + assert len(channel._background_tasks) == 1 + await asyncio.wait_for(seen.wait(), timeout=1) + await asyncio.sleep(0) + assert channel._background_tasks == set()