mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 06:14:02 +00:00
fix(napcat): harden async handlers and action errors
maintainer edit: track background handler tasks, surface failed OneBot actions, reject image redirects, and add focused unit coverage for group routing and edge cases.
This commit is contained in:
parent
0c3063b78c
commit
d5692bf94c
@ -12,11 +12,12 @@ import uuid
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
|
||||
from websockets.asyncio.client import ClientConnection, connect as ws_connect
|
||||
from websockets.asyncio.client import ClientConnection
|
||||
from websockets.asyncio.client import connect as ws_connect
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
@ -75,6 +76,7 @@ class NapcatChannel(BaseChannel):
|
||||
self._pending: dict[str, asyncio.Future[dict[str, Any]]] = {}
|
||||
self._processed_ids: deque[int] = deque(maxlen=2000)
|
||||
self._bot_outbound_ids: deque[int] = deque(maxlen=2000)
|
||||
self._background_tasks: set[asyncio.Task[None]] = set()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
@ -161,6 +163,12 @@ class NapcatChannel(BaseChannel):
|
||||
pass
|
||||
self._http = None
|
||||
self._fail_pending(RuntimeError("napcat: stopped"))
|
||||
tasks = list(self._background_tasks)
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
self._background_tasks.clear()
|
||||
|
||||
def _fail_pending(self, err: BaseException) -> None:
|
||||
for fut in self._pending.values():
|
||||
@ -198,9 +206,24 @@ class NapcatChannel(BaseChannel):
|
||||
|
||||
post_type = payload.get("post_type")
|
||||
if post_type == "message":
|
||||
asyncio.create_task(self._on_message(payload))
|
||||
self._create_background_task(self._on_message(payload), "message")
|
||||
elif post_type == "notice":
|
||||
asyncio.create_task(self._on_notice(payload)) # avoid deadlock
|
||||
self._create_background_task(self._on_notice(payload), "notice")
|
||||
|
||||
def _create_background_task(self, coro: Any, kind: str) -> None:
|
||||
task = asyncio.create_task(coro)
|
||||
self._background_tasks.add(task)
|
||||
|
||||
def _done(done: asyncio.Task[None]) -> None:
|
||||
self._background_tasks.discard(done)
|
||||
try:
|
||||
done.result()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning("napcat: {} handler failed: {}", kind, e)
|
||||
|
||||
task.add_done_callback(_done)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inbound: messages
|
||||
@ -245,8 +268,6 @@ class NapcatChannel(BaseChannel):
|
||||
return
|
||||
|
||||
chat_id = f"group:{group_id}"
|
||||
label = nickname or str(user_id)
|
||||
content = f"{label}: {text}"
|
||||
content = self._format_group_content(
|
||||
text=text,
|
||||
nickname=nickname,
|
||||
@ -367,7 +388,14 @@ class NapcatChannel(BaseChannel):
|
||||
if group_id is None or user_id is None:
|
||||
return
|
||||
|
||||
nickname = await self._lookup_member_name(int(group_id), int(user_id))
|
||||
try:
|
||||
group_id_int = int(group_id)
|
||||
user_id_int = int(user_id)
|
||||
except (TypeError, ValueError):
|
||||
logger.warning("napcat: invalid group_increase ids group_id={} user_id={}", group_id, user_id)
|
||||
return
|
||||
|
||||
nickname = await self._lookup_member_name(group_id_int, user_id_int)
|
||||
|
||||
# Note: this routes through is_allowed(). For group bots set
|
||||
# `allow_from: ["*"]` (or include the joining user's id) for welcomes
|
||||
@ -467,7 +495,14 @@ class NapcatChannel(BaseChannel):
|
||||
await self._ws.send(
|
||||
json.dumps({"action": action, "params": params, "echo": echo}, ensure_ascii=False)
|
||||
)
|
||||
return await asyncio.wait_for(fut, timeout=timeout)
|
||||
resp = await asyncio.wait_for(fut, timeout=timeout)
|
||||
status = resp.get("status")
|
||||
retcode = resp.get("retcode")
|
||||
if (status and status != "ok") or (retcode not in (None, 0)):
|
||||
raise RuntimeError(
|
||||
f"napcat: action {action} failed status={status!r} retcode={retcode!r}"
|
||||
)
|
||||
return resp
|
||||
finally:
|
||||
self._pending.pop(echo, None)
|
||||
|
||||
@ -503,7 +538,10 @@ class NapcatChannel(BaseChannel):
|
||||
pass
|
||||
|
||||
try:
|
||||
async with self._http.get(url, allow_redirects=True) as resp:
|
||||
async with self._http.get(url, allow_redirects=False) as resp:
|
||||
if 300 <= resp.status < 400:
|
||||
logger.warning("napcat: image download redirect rejected url={}", url)
|
||||
return None
|
||||
if resp.status >= 400:
|
||||
logger.warning("napcat: image download status={} url={}", resp.status, url)
|
||||
return None
|
||||
|
||||
172
tests/channels/test_napcat_channel.py
Normal file
172
tests/channels/test_napcat_channel.py
Normal file
@ -0,0 +1,172 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.napcat import NapcatChannel, NapcatConfig
|
||||
|
||||
|
||||
class _FakeWs:
|
||||
def __init__(self) -> None:
|
||||
self.sent: list[str] = []
|
||||
|
||||
async def send(self, payload: str) -> None:
|
||||
self.sent.append(payload)
|
||||
|
||||
async def close(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class _FakeContent:
|
||||
def __init__(self, chunks: list[bytes]) -> None:
|
||||
self._chunks = chunks
|
||||
|
||||
async def iter_chunked(self, _size: int):
|
||||
for chunk in self._chunks:
|
||||
yield chunk
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
def __init__(self, status: int, chunks: list[bytes] | None = None) -> None:
|
||||
self.status = status
|
||||
self.content = _FakeContent(chunks or [])
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class _FakeHttp:
|
||||
def __init__(self, response: _FakeResponse) -> None:
|
||||
self.response = response
|
||||
self.calls: list[dict] = []
|
||||
|
||||
def get(self, url: str, **kwargs):
|
||||
self.calls.append({"url": url, "kwargs": kwargs})
|
||||
return self.response
|
||||
|
||||
|
||||
def _channel(config: NapcatConfig | None = None) -> NapcatChannel:
|
||||
return NapcatChannel(config or NapcatConfig(allow_from=["*"]), MessageBus())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_message_requires_mention_by_default() -> None:
|
||||
channel = _channel(NapcatConfig(allow_from=["user1"], group_policy="mention"))
|
||||
channel._self_id = 42
|
||||
|
||||
await channel._on_message(
|
||||
{
|
||||
"message_id": 1,
|
||||
"message_type": "group",
|
||||
"group_id": 100,
|
||||
"user_id": "user1",
|
||||
"sender": {"nickname": "Alice"},
|
||||
"message": [{"type": "text", "data": {"text": "hello"}}],
|
||||
}
|
||||
)
|
||||
|
||||
assert channel.bus.inbound_size == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_mention_routes_with_sender_label() -> None:
|
||||
channel = _channel(NapcatConfig(allow_from=["user1"], group_policy="mention"))
|
||||
channel._self_id = 42
|
||||
|
||||
await channel._on_message(
|
||||
{
|
||||
"message_id": 1,
|
||||
"message_type": "group",
|
||||
"group_id": 100,
|
||||
"user_id": "user1",
|
||||
"sender": {"card": "Alice"},
|
||||
"message": [
|
||||
{"type": "at", "data": {"qq": "42"}},
|
||||
{"type": "text", "data": {"text": "hello"}},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert msg.sender_id == "user1"
|
||||
assert msg.chat_id == "group:100"
|
||||
assert msg.content == "Alice: hello"
|
||||
assert msg.metadata["message_id"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_action_raises_on_onebot_failure_and_clears_pending() -> None:
|
||||
channel = _channel()
|
||||
channel._ws = _FakeWs()
|
||||
|
||||
task = asyncio.create_task(channel._call_action("send_msg", {"message": []}))
|
||||
while not channel._pending:
|
||||
await asyncio.sleep(0)
|
||||
fut = next(iter(channel._pending.values()))
|
||||
fut.set_result({"status": "failed", "retcode": 1400, "wording": "bad request"})
|
||||
|
||||
with pytest.raises(RuntimeError, match="action send_msg failed"):
|
||||
await task
|
||||
assert channel._pending == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notice_with_invalid_ids_is_ignored(monkeypatch) -> None:
|
||||
channel = _channel()
|
||||
|
||||
async def fail_lookup(*_args, **_kwargs):
|
||||
raise AssertionError("lookup should not be called for invalid ids")
|
||||
|
||||
monkeypatch.setattr(channel, "_lookup_member_name", fail_lookup)
|
||||
|
||||
await channel._on_notice(
|
||||
{
|
||||
"notice_type": "group_increase",
|
||||
"group_id": "not-an-int",
|
||||
"user_id": "user1",
|
||||
}
|
||||
)
|
||||
|
||||
assert channel.bus.inbound_size == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_image_rejects_redirects(tmp_path, monkeypatch) -> None:
|
||||
channel = _channel()
|
||||
channel._media_root = tmp_path
|
||||
channel._http = _FakeHttp(_FakeResponse(status=302))
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.napcat.validate_url_target",
|
||||
lambda _url: (True, ""),
|
||||
)
|
||||
|
||||
result = await channel._download_image({"url": "https://example.com/a.png", "file": "a.png"})
|
||||
|
||||
assert result is None
|
||||
assert channel._http.calls == [
|
||||
{"url": "https://example.com/a.png", "kwargs": {"allow_redirects": False}}
|
||||
]
|
||||
assert list(tmp_path.iterdir()) == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_tracks_and_discards_background_tasks() -> None:
|
||||
channel = _channel()
|
||||
seen = asyncio.Event()
|
||||
|
||||
async def fake_on_message(_payload):
|
||||
seen.set()
|
||||
|
||||
channel._on_message = fake_on_message
|
||||
|
||||
await channel._dispatch_frame(
|
||||
'{"post_type":"message","message_type":"private","user_id":"user1","message":"hi"}'
|
||||
)
|
||||
|
||||
assert len(channel._background_tasks) == 1
|
||||
await asyncio.wait_for(seen.wait(), timeout=1)
|
||||
await asyncio.sleep(0)
|
||||
assert channel._background_tasks == set()
|
||||
Loading…
x
Reference in New Issue
Block a user