mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 06:14:02 +00:00
fix(napcat): harden async handlers and action errors
maintainer edit: track background handler tasks, surface failed OneBot actions, reject image redirects, and add focused unit coverage for group routing and edge cases.
This commit is contained in:
parent
0c3063b78c
commit
d5692bf94c
@ -12,11 +12,12 @@ import uuid
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, Any, Literal
|
from typing import Annotated, Any, Literal
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from websockets.asyncio.client import ClientConnection
|
||||||
from websockets.asyncio.client import ClientConnection, connect as ws_connect
|
from websockets.asyncio.client import connect as ws_connect
|
||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
@ -75,6 +76,7 @@ class NapcatChannel(BaseChannel):
|
|||||||
self._pending: dict[str, asyncio.Future[dict[str, Any]]] = {}
|
self._pending: dict[str, asyncio.Future[dict[str, Any]]] = {}
|
||||||
self._processed_ids: deque[int] = deque(maxlen=2000)
|
self._processed_ids: deque[int] = deque(maxlen=2000)
|
||||||
self._bot_outbound_ids: deque[int] = deque(maxlen=2000)
|
self._bot_outbound_ids: deque[int] = deque(maxlen=2000)
|
||||||
|
self._background_tasks: set[asyncio.Task[None]] = set()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Lifecycle
|
# Lifecycle
|
||||||
@ -161,6 +163,12 @@ class NapcatChannel(BaseChannel):
|
|||||||
pass
|
pass
|
||||||
self._http = None
|
self._http = None
|
||||||
self._fail_pending(RuntimeError("napcat: stopped"))
|
self._fail_pending(RuntimeError("napcat: stopped"))
|
||||||
|
tasks = list(self._background_tasks)
|
||||||
|
for task in tasks:
|
||||||
|
task.cancel()
|
||||||
|
if tasks:
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
self._background_tasks.clear()
|
||||||
|
|
||||||
def _fail_pending(self, err: BaseException) -> None:
|
def _fail_pending(self, err: BaseException) -> None:
|
||||||
for fut in self._pending.values():
|
for fut in self._pending.values():
|
||||||
@ -198,9 +206,24 @@ class NapcatChannel(BaseChannel):
|
|||||||
|
|
||||||
post_type = payload.get("post_type")
|
post_type = payload.get("post_type")
|
||||||
if post_type == "message":
|
if post_type == "message":
|
||||||
asyncio.create_task(self._on_message(payload))
|
self._create_background_task(self._on_message(payload), "message")
|
||||||
elif post_type == "notice":
|
elif post_type == "notice":
|
||||||
asyncio.create_task(self._on_notice(payload)) # avoid deadlock
|
self._create_background_task(self._on_notice(payload), "notice")
|
||||||
|
|
||||||
|
def _create_background_task(self, coro: Any, kind: str) -> None:
|
||||||
|
task = asyncio.create_task(coro)
|
||||||
|
self._background_tasks.add(task)
|
||||||
|
|
||||||
|
def _done(done: asyncio.Task[None]) -> None:
|
||||||
|
self._background_tasks.discard(done)
|
||||||
|
try:
|
||||||
|
done.result()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("napcat: {} handler failed: {}", kind, e)
|
||||||
|
|
||||||
|
task.add_done_callback(_done)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Inbound: messages
|
# Inbound: messages
|
||||||
@ -245,8 +268,6 @@ class NapcatChannel(BaseChannel):
|
|||||||
return
|
return
|
||||||
|
|
||||||
chat_id = f"group:{group_id}"
|
chat_id = f"group:{group_id}"
|
||||||
label = nickname or str(user_id)
|
|
||||||
content = f"{label}: {text}"
|
|
||||||
content = self._format_group_content(
|
content = self._format_group_content(
|
||||||
text=text,
|
text=text,
|
||||||
nickname=nickname,
|
nickname=nickname,
|
||||||
@ -367,7 +388,14 @@ class NapcatChannel(BaseChannel):
|
|||||||
if group_id is None or user_id is None:
|
if group_id is None or user_id is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
nickname = await self._lookup_member_name(int(group_id), int(user_id))
|
try:
|
||||||
|
group_id_int = int(group_id)
|
||||||
|
user_id_int = int(user_id)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
logger.warning("napcat: invalid group_increase ids group_id={} user_id={}", group_id, user_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
nickname = await self._lookup_member_name(group_id_int, user_id_int)
|
||||||
|
|
||||||
# Note: this routes through is_allowed(). For group bots set
|
# Note: this routes through is_allowed(). For group bots set
|
||||||
# `allow_from: ["*"]` (or include the joining user's id) for welcomes
|
# `allow_from: ["*"]` (or include the joining user's id) for welcomes
|
||||||
@ -467,7 +495,14 @@ class NapcatChannel(BaseChannel):
|
|||||||
await self._ws.send(
|
await self._ws.send(
|
||||||
json.dumps({"action": action, "params": params, "echo": echo}, ensure_ascii=False)
|
json.dumps({"action": action, "params": params, "echo": echo}, ensure_ascii=False)
|
||||||
)
|
)
|
||||||
return await asyncio.wait_for(fut, timeout=timeout)
|
resp = await asyncio.wait_for(fut, timeout=timeout)
|
||||||
|
status = resp.get("status")
|
||||||
|
retcode = resp.get("retcode")
|
||||||
|
if (status and status != "ok") or (retcode not in (None, 0)):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"napcat: action {action} failed status={status!r} retcode={retcode!r}"
|
||||||
|
)
|
||||||
|
return resp
|
||||||
finally:
|
finally:
|
||||||
self._pending.pop(echo, None)
|
self._pending.pop(echo, None)
|
||||||
|
|
||||||
@ -503,7 +538,10 @@ class NapcatChannel(BaseChannel):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with self._http.get(url, allow_redirects=True) as resp:
|
async with self._http.get(url, allow_redirects=False) as resp:
|
||||||
|
if 300 <= resp.status < 400:
|
||||||
|
logger.warning("napcat: image download redirect rejected url={}", url)
|
||||||
|
return None
|
||||||
if resp.status >= 400:
|
if resp.status >= 400:
|
||||||
logger.warning("napcat: image download status={} url={}", resp.status, url)
|
logger.warning("napcat: image download status={} url={}", resp.status, url)
|
||||||
return None
|
return None
|
||||||
|
|||||||
172
tests/channels/test_napcat_channel.py
Normal file
172
tests/channels/test_napcat_channel.py
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.channels.napcat import NapcatChannel, NapcatConfig
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeWs:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.sent: list[str] = []
|
||||||
|
|
||||||
|
async def send(self, payload: str) -> None:
|
||||||
|
self.sent.append(payload)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeContent:
|
||||||
|
def __init__(self, chunks: list[bytes]) -> None:
|
||||||
|
self._chunks = chunks
|
||||||
|
|
||||||
|
async def iter_chunked(self, _size: int):
|
||||||
|
for chunk in self._chunks:
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeResponse:
|
||||||
|
def __init__(self, status: int, chunks: list[bytes] | None = None) -> None:
|
||||||
|
self.status = status
|
||||||
|
self.content = _FakeContent(chunks or [])
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeHttp:
|
||||||
|
def __init__(self, response: _FakeResponse) -> None:
|
||||||
|
self.response = response
|
||||||
|
self.calls: list[dict] = []
|
||||||
|
|
||||||
|
def get(self, url: str, **kwargs):
|
||||||
|
self.calls.append({"url": url, "kwargs": kwargs})
|
||||||
|
return self.response
|
||||||
|
|
||||||
|
|
||||||
|
def _channel(config: NapcatConfig | None = None) -> NapcatChannel:
|
||||||
|
return NapcatChannel(config or NapcatConfig(allow_from=["*"]), MessageBus())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_group_message_requires_mention_by_default() -> None:
|
||||||
|
channel = _channel(NapcatConfig(allow_from=["user1"], group_policy="mention"))
|
||||||
|
channel._self_id = 42
|
||||||
|
|
||||||
|
await channel._on_message(
|
||||||
|
{
|
||||||
|
"message_id": 1,
|
||||||
|
"message_type": "group",
|
||||||
|
"group_id": 100,
|
||||||
|
"user_id": "user1",
|
||||||
|
"sender": {"nickname": "Alice"},
|
||||||
|
"message": [{"type": "text", "data": {"text": "hello"}}],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert channel.bus.inbound_size == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_group_mention_routes_with_sender_label() -> None:
|
||||||
|
channel = _channel(NapcatConfig(allow_from=["user1"], group_policy="mention"))
|
||||||
|
channel._self_id = 42
|
||||||
|
|
||||||
|
await channel._on_message(
|
||||||
|
{
|
||||||
|
"message_id": 1,
|
||||||
|
"message_type": "group",
|
||||||
|
"group_id": 100,
|
||||||
|
"user_id": "user1",
|
||||||
|
"sender": {"card": "Alice"},
|
||||||
|
"message": [
|
||||||
|
{"type": "at", "data": {"qq": "42"}},
|
||||||
|
{"type": "text", "data": {"text": "hello"}},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = await channel.bus.consume_inbound()
|
||||||
|
assert msg.sender_id == "user1"
|
||||||
|
assert msg.chat_id == "group:100"
|
||||||
|
assert msg.content == "Alice: hello"
|
||||||
|
assert msg.metadata["message_id"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_action_raises_on_onebot_failure_and_clears_pending() -> None:
|
||||||
|
channel = _channel()
|
||||||
|
channel._ws = _FakeWs()
|
||||||
|
|
||||||
|
task = asyncio.create_task(channel._call_action("send_msg", {"message": []}))
|
||||||
|
while not channel._pending:
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
fut = next(iter(channel._pending.values()))
|
||||||
|
fut.set_result({"status": "failed", "retcode": 1400, "wording": "bad request"})
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="action send_msg failed"):
|
||||||
|
await task
|
||||||
|
assert channel._pending == {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_notice_with_invalid_ids_is_ignored(monkeypatch) -> None:
|
||||||
|
channel = _channel()
|
||||||
|
|
||||||
|
async def fail_lookup(*_args, **_kwargs):
|
||||||
|
raise AssertionError("lookup should not be called for invalid ids")
|
||||||
|
|
||||||
|
monkeypatch.setattr(channel, "_lookup_member_name", fail_lookup)
|
||||||
|
|
||||||
|
await channel._on_notice(
|
||||||
|
{
|
||||||
|
"notice_type": "group_increase",
|
||||||
|
"group_id": "not-an-int",
|
||||||
|
"user_id": "user1",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert channel.bus.inbound_size == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_image_rejects_redirects(tmp_path, monkeypatch) -> None:
|
||||||
|
channel = _channel()
|
||||||
|
channel._media_root = tmp_path
|
||||||
|
channel._http = _FakeHttp(_FakeResponse(status=302))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.channels.napcat.validate_url_target",
|
||||||
|
lambda _url: (True, ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await channel._download_image({"url": "https://example.com/a.png", "file": "a.png"})
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
assert channel._http.calls == [
|
||||||
|
{"url": "https://example.com/a.png", "kwargs": {"allow_redirects": False}}
|
||||||
|
]
|
||||||
|
assert list(tmp_path.iterdir()) == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dispatch_tracks_and_discards_background_tasks() -> None:
|
||||||
|
channel = _channel()
|
||||||
|
seen = asyncio.Event()
|
||||||
|
|
||||||
|
async def fake_on_message(_payload):
|
||||||
|
seen.set()
|
||||||
|
|
||||||
|
channel._on_message = fake_on_message
|
||||||
|
|
||||||
|
await channel._dispatch_frame(
|
||||||
|
'{"post_type":"message","message_type":"private","user_id":"user1","message":"hi"}'
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(channel._background_tasks) == 1
|
||||||
|
await asyncio.wait_for(seen.wait(), timeout=1)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert channel._background_tasks == set()
|
||||||
Loading…
x
Reference in New Issue
Block a user