fix(napcat): harden async handlers and action errors

maintainer edit: track background handler tasks, surface failed OneBot actions, reject image redirects, and add focused unit coverage for group routing and edge cases.
This commit is contained in:
chengyongru 2026-06-02 00:52:54 +08:00 committed by Xubin Ren
parent 0c3063b78c
commit d5692bf94c
2 changed files with 219 additions and 9 deletions

View File

@ -12,11 +12,12 @@ import uuid
from collections import deque
from pathlib import Path
from typing import Annotated, Any, Literal
import aiohttp
from loguru import logger
from pydantic import Field
from websockets.asyncio.client import ClientConnection, connect as ws_connect
from websockets.asyncio.client import ClientConnection
from websockets.asyncio.client import connect as ws_connect
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
@ -75,6 +76,7 @@ class NapcatChannel(BaseChannel):
self._pending: dict[str, asyncio.Future[dict[str, Any]]] = {}
self._processed_ids: deque[int] = deque(maxlen=2000)
self._bot_outbound_ids: deque[int] = deque(maxlen=2000)
self._background_tasks: set[asyncio.Task[None]] = set()
# ------------------------------------------------------------------
# Lifecycle
@ -161,6 +163,12 @@ class NapcatChannel(BaseChannel):
pass
self._http = None
self._fail_pending(RuntimeError("napcat: stopped"))
tasks = list(self._background_tasks)
for task in tasks:
task.cancel()
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
self._background_tasks.clear()
def _fail_pending(self, err: BaseException) -> None:
for fut in self._pending.values():
@ -198,9 +206,24 @@ class NapcatChannel(BaseChannel):
post_type = payload.get("post_type")
if post_type == "message":
asyncio.create_task(self._on_message(payload))
self._create_background_task(self._on_message(payload), "message")
elif post_type == "notice":
asyncio.create_task(self._on_notice(payload)) # avoid deadlock
self._create_background_task(self._on_notice(payload), "notice")
def _create_background_task(self, coro: Any, kind: str) -> None:
task = asyncio.create_task(coro)
self._background_tasks.add(task)
def _done(done: asyncio.Task[None]) -> None:
self._background_tasks.discard(done)
try:
done.result()
except asyncio.CancelledError:
pass
except Exception as e:
logger.warning("napcat: {} handler failed: {}", kind, e)
task.add_done_callback(_done)
# ------------------------------------------------------------------
# Inbound: messages
@ -245,8 +268,6 @@ class NapcatChannel(BaseChannel):
return
chat_id = f"group:{group_id}"
label = nickname or str(user_id)
content = f"{label}: {text}"
content = self._format_group_content(
text=text,
nickname=nickname,
@ -367,7 +388,14 @@ class NapcatChannel(BaseChannel):
if group_id is None or user_id is None:
return
nickname = await self._lookup_member_name(int(group_id), int(user_id))
try:
group_id_int = int(group_id)
user_id_int = int(user_id)
except (TypeError, ValueError):
logger.warning("napcat: invalid group_increase ids group_id={} user_id={}", group_id, user_id)
return
nickname = await self._lookup_member_name(group_id_int, user_id_int)
# Note: this routes through is_allowed(). For group bots set
# `allow_from: ["*"]` (or include the joining user's id) for welcomes
@ -467,7 +495,14 @@ class NapcatChannel(BaseChannel):
await self._ws.send(
json.dumps({"action": action, "params": params, "echo": echo}, ensure_ascii=False)
)
return await asyncio.wait_for(fut, timeout=timeout)
resp = await asyncio.wait_for(fut, timeout=timeout)
status = resp.get("status")
retcode = resp.get("retcode")
if (status and status != "ok") or (retcode not in (None, 0)):
raise RuntimeError(
f"napcat: action {action} failed status={status!r} retcode={retcode!r}"
)
return resp
finally:
self._pending.pop(echo, None)
@ -503,7 +538,10 @@ class NapcatChannel(BaseChannel):
pass
try:
async with self._http.get(url, allow_redirects=True) as resp:
async with self._http.get(url, allow_redirects=False) as resp:
if 300 <= resp.status < 400:
logger.warning("napcat: image download redirect rejected url={}", url)
return None
if resp.status >= 400:
logger.warning("napcat: image download status={} url={}", resp.status, url)
return None

View File

@ -0,0 +1,172 @@
import asyncio
import pytest
from nanobot.bus.queue import MessageBus
from nanobot.channels.napcat import NapcatChannel, NapcatConfig
class _FakeWs:
def __init__(self) -> None:
self.sent: list[str] = []
async def send(self, payload: str) -> None:
self.sent.append(payload)
async def close(self) -> None:
pass
class _FakeContent:
def __init__(self, chunks: list[bytes]) -> None:
self._chunks = chunks
async def iter_chunked(self, _size: int):
for chunk in self._chunks:
yield chunk
class _FakeResponse:
def __init__(self, status: int, chunks: list[bytes] | None = None) -> None:
self.status = status
self.content = _FakeContent(chunks or [])
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb) -> None:
return None
class _FakeHttp:
def __init__(self, response: _FakeResponse) -> None:
self.response = response
self.calls: list[dict] = []
def get(self, url: str, **kwargs):
self.calls.append({"url": url, "kwargs": kwargs})
return self.response
def _channel(config: NapcatConfig | None = None) -> NapcatChannel:
return NapcatChannel(config or NapcatConfig(allow_from=["*"]), MessageBus())
@pytest.mark.asyncio
async def test_group_message_requires_mention_by_default() -> None:
channel = _channel(NapcatConfig(allow_from=["user1"], group_policy="mention"))
channel._self_id = 42
await channel._on_message(
{
"message_id": 1,
"message_type": "group",
"group_id": 100,
"user_id": "user1",
"sender": {"nickname": "Alice"},
"message": [{"type": "text", "data": {"text": "hello"}}],
}
)
assert channel.bus.inbound_size == 0
@pytest.mark.asyncio
async def test_group_mention_routes_with_sender_label() -> None:
channel = _channel(NapcatConfig(allow_from=["user1"], group_policy="mention"))
channel._self_id = 42
await channel._on_message(
{
"message_id": 1,
"message_type": "group",
"group_id": 100,
"user_id": "user1",
"sender": {"card": "Alice"},
"message": [
{"type": "at", "data": {"qq": "42"}},
{"type": "text", "data": {"text": "hello"}},
],
}
)
msg = await channel.bus.consume_inbound()
assert msg.sender_id == "user1"
assert msg.chat_id == "group:100"
assert msg.content == "Alice: hello"
assert msg.metadata["message_id"] == 1
@pytest.mark.asyncio
async def test_call_action_raises_on_onebot_failure_and_clears_pending() -> None:
channel = _channel()
channel._ws = _FakeWs()
task = asyncio.create_task(channel._call_action("send_msg", {"message": []}))
while not channel._pending:
await asyncio.sleep(0)
fut = next(iter(channel._pending.values()))
fut.set_result({"status": "failed", "retcode": 1400, "wording": "bad request"})
with pytest.raises(RuntimeError, match="action send_msg failed"):
await task
assert channel._pending == {}
@pytest.mark.asyncio
async def test_notice_with_invalid_ids_is_ignored(monkeypatch) -> None:
channel = _channel()
async def fail_lookup(*_args, **_kwargs):
raise AssertionError("lookup should not be called for invalid ids")
monkeypatch.setattr(channel, "_lookup_member_name", fail_lookup)
await channel._on_notice(
{
"notice_type": "group_increase",
"group_id": "not-an-int",
"user_id": "user1",
}
)
assert channel.bus.inbound_size == 0
@pytest.mark.asyncio
async def test_download_image_rejects_redirects(tmp_path, monkeypatch) -> None:
channel = _channel()
channel._media_root = tmp_path
channel._http = _FakeHttp(_FakeResponse(status=302))
monkeypatch.setattr(
"nanobot.channels.napcat.validate_url_target",
lambda _url: (True, ""),
)
result = await channel._download_image({"url": "https://example.com/a.png", "file": "a.png"})
assert result is None
assert channel._http.calls == [
{"url": "https://example.com/a.png", "kwargs": {"allow_redirects": False}}
]
assert list(tmp_path.iterdir()) == []
@pytest.mark.asyncio
async def test_dispatch_tracks_and_discards_background_tasks() -> None:
channel = _channel()
seen = asyncio.Event()
async def fake_on_message(_payload):
seen.set()
channel._on_message = fake_on_message
await channel._dispatch_frame(
'{"post_type":"message","message_type":"private","user_id":"user1","message":"hi"}'
)
assert len(channel._background_tasks) == 1
await asyncio.wait_for(seen.wait(), timeout=1)
await asyncio.sleep(0)
assert channel._background_tasks == set()