mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-14 06:43:53 +00:00
Maintainer edit for PR 4115: rebase onto origin/main and split gateway HTTP routing from token, media, and workspace services so WebSocketChannel depends on explicit gateway services instead of GatewayHTTPHandler internals. Preserve file edit channel capabilities and restore tools.restrict_to_workspace wiring through ChannelManager.
557 lines
19 KiB
Python
557 lines
19 KiB
Python
"""Integration tests for the WebSocket channel using WsTestClient.
|
|
|
|
Complements the unit/lightweight tests in test_websocket_channel.py by covering
|
|
multi-client scenarios, edge cases, and realistic usage patterns.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
import websockets
|
|
from ws_test_client import WsTestClient, issue_token, issue_token_ok
|
|
|
|
from nanobot.bus.events import OutboundMessage
|
|
from nanobot.channels.websocket import WebSocketChannel, WebSocketConfig
|
|
from nanobot.webui.gateway_services import build_gateway_services
|
|
|
|
|
|
def _ch(bus: Any, port: int, **kw: Any) -> WebSocketChannel:
    """Create a ``WebSocketChannel`` for tests, bound to ``127.0.0.1:port``.

    The base configuration sets ``enabled`` to true, ``allowFrom`` to
    ``["*"]``, ``path`` to ``"/"`` and ``websocketRequiresToken`` to false.
    Entries in ``kw`` override or extend that mapping using the same keys.

    Args:
        bus: Message bus handed to both the gateway services and the channel.
        port: TCP port placed in the configuration.
        **kw: Extra or overriding configuration entries.

    Returns:
        A channel wired to freshly built gateway services.
    """
    defaults: dict[str, Any] = {
        "enabled": True,
        "allowFrom": ["*"],
        "host": "127.0.0.1",
        "port": port,
        "path": "/",
        "websocketRequiresToken": False,
    }
    cfg: dict[str, Any] = {**defaults, **kw}
    validated = WebSocketConfig.model_validate(cfg)
    services = build_gateway_services(
        config=validated,
        bus=bus,
        session_manager=None,
        static_dist_path=None,
        workspace_path=Path.cwd(),
        default_restrict_to_workspace=False,
        runtime_model_name=None,
        runtime_surface="browser",
        runtime_capabilities_overrides=None,
    )
    # The channel receives the raw mapping; only the gateway services get the
    # validated config model.
    return WebSocketChannel(cfg, bus, gateway=services)
|
|
|
|
|
|
@pytest.fixture()
def bus() -> MagicMock:
    """Provide a mock bus whose ``publish_inbound`` attribute is an ``AsyncMock``."""
    return MagicMock(publish_inbound=AsyncMock())
|
|
|
|
|
|
# -- Connection basics ----------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_ready_event_fields(bus: MagicMock) -> None:
    """The ready frame reports event ``ready``, a 36-char chat id and the client id."""
    channel = _ch(bus, 29901)
    server_task = asyncio.create_task(channel.start())
    # Fixed pause before connecting (presumably lets the server start listening).
    await asyncio.sleep(0.3)
    try:
        url = "ws://127.0.0.1:29901/"
        async with WsTestClient(url, client_id="c1") as client:
            ready = await client.recv_ready()
            assert ready.event == "ready"
            assert len(ready.chat_id) == 36
            assert ready.client_id == "c1"
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_anonymous_client_gets_generated_id(bus: MagicMock) -> None:
    """A client connecting with an empty client id is given one starting with ``anon-``."""
    channel = _ch(bus, 29902)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        url = "ws://127.0.0.1:29902/"
        async with WsTestClient(url, client_id="") as client:
            ready = await client.recv_ready()
            assigned_id = ready.client_id
            assert assigned_id.startswith("anon-")
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_each_connection_unique_chat_id(bus: MagicMock) -> None:
    """Two simultaneous connections are announced with different chat ids."""
    channel = _ch(bus, 29903)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        url = "ws://127.0.0.1:29903/"
        async with (
            WsTestClient(url, client_id="a") as first,
            WsTestClient(url, client_id="b") as second,
        ):
            first_chat = (await first.recv_ready()).chat_id
            second_chat = (await second.recv_ready()).chat_id
            assert first_chat != second_chat
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
# -- Inbound messages (client -> server) ----------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_plain_text(bus: MagicMock) -> None:
    """A plain text frame is published inbound with its content and the sender's client id."""
    channel = _ch(bus, 29904)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        url = "ws://127.0.0.1:29904/"
        async with WsTestClient(url, client_id="p") as client:
            await client.recv_ready()
            await client.send_text("hello world")
            # Fixed pause before inspecting the bus mock.
            await asyncio.sleep(0.1)
            (published,) = bus.publish_inbound.call_args[0][:1]
            assert published.content == "hello world"
            assert published.sender_id == "p"
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_json_content_field(bus: MagicMock) -> None:
    """A JSON frame's ``content`` value becomes the inbound message content."""
    channel = _ch(bus, 29905)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        url = "ws://127.0.0.1:29905/"
        async with WsTestClient(url, client_id="j") as client:
            await client.recv_ready()
            await client.send_json({"content": "structured"})
            await asyncio.sleep(0.1)
            latest = bus.publish_inbound.call_args[0][0]
            assert latest.content == "structured"
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_json_text_and_message_fields(bus: MagicMock) -> None:
    """JSON frames keyed by ``text`` or ``message`` are also published as content."""
    channel = _ch(bus, 29906)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        url = "ws://127.0.0.1:29906/"
        async with WsTestClient(url, client_id="x") as client:
            await client.recv_ready()
            # Each payload is sent in turn and checked against the most recent publish.
            for payload, expected in (
                ({"text": "via text"}, "via text"),
                ({"message": "via message"}, "via message"),
            ):
                await client.send_json(payload)
                await asyncio.sleep(0.1)
                latest = bus.publish_inbound.call_args[0][0]
                assert latest.content == expected
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_empty_payload_ignored(bus: MagicMock) -> None:
    """A whitespace-only text frame and an empty JSON object publish nothing inbound."""
    channel = _ch(bus, 29907)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        url = "ws://127.0.0.1:29907/"
        async with WsTestClient(url, client_id="e") as client:
            await client.recv_ready()
            await client.send_text(" ")
            await client.send_json({})
            await asyncio.sleep(0.1)
            publish = bus.publish_inbound
            publish.assert_not_awaited()
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_messages_preserve_order(bus: MagicMock) -> None:
    """Five text frames sent back to back are published inbound in sending order."""
    channel = _ch(bus, 29908)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        expected = [f"msg-{i}" for i in range(5)]
        url = "ws://127.0.0.1:29908/"
        async with WsTestClient(url, client_id="o") as client:
            await client.recv_ready()
            for text in expected:
                await client.send_text(text)
            await asyncio.sleep(0.2)
            published = [
                recorded[0][0].content
                for recorded in bus.publish_inbound.call_args_list
            ]
            assert published == expected
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
# -- Outbound messages (server -> client) ---------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_server_send_message(bus: MagicMock) -> None:
    """An ``OutboundMessage`` addressed to the client's chat id arrives with its text."""
    channel = _ch(bus, 29909)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        url = "ws://127.0.0.1:29909/"
        async with WsTestClient(url, client_id="r") as client:
            chat_id = (await client.recv_ready()).chat_id
            outbound = OutboundMessage(
                channel="websocket", chat_id=chat_id, content="reply",
            )
            await channel.send(outbound)
            received = await client.recv_message()
            assert received.text == "reply"
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_server_send_tags_tool_hint_with_kind(bus: MagicMock) -> None:
    """``_tool_hint`` metadata must surface as ``kind: "tool_hint"`` so WS
    clients render breadcrumbs separately from conversational replies.

    Three frames are checked in order: a plain reply (no ``kind``), a
    tool-hint breadcrumb (``kind == "tool_hint"``) and a generic progress
    update (``kind == "progress"``).
    """
    # Dedicated port: 29919 is already bound by test_wrong_path_404, and
    # sharing a port between tests risks "address in use" failures when tests
    # run in parallel or a previous listener has not fully released it.
    ch = _ch(bus, 29925)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29925/", client_id="h") as c:
            ready = await c.recv_ready()

            # Plain reply: no "kind" field.
            await ch.send(OutboundMessage(
                channel="websocket", chat_id=ready.chat_id, content="hi",
            ))
            plain = await c.recv_message()
            assert plain.raw.get("kind") is None

            # Tool-hint breadcrumb: kind == "tool_hint".
            await ch.send(OutboundMessage(
                channel="websocket", chat_id=ready.chat_id,
                content='weather("get")',
                metadata={"_progress": True, "_tool_hint": True},
            ))
            hint = await c.recv_message()
            assert hint.raw.get("kind") == "tool_hint"
            assert hint.text == 'weather("get")'

            # Generic progress (non-tool-hint) gets the softer "progress" label.
            await ch.send(OutboundMessage(
                channel="websocket", chat_id=ready.chat_id,
                content="thinking…",
                metadata={"_progress": True},
            ))
            prog = await c.recv_message()
            assert prog.raw.get("kind") == "progress"
    finally:
        await ch.stop()
        await t
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_server_send_with_media_and_reply(bus: MagicMock) -> None:
    """Outbound ``media`` and ``reply_to`` values are delivered alongside the text."""
    channel = _ch(bus, 29910)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        url = "ws://127.0.0.1:29910/"
        async with WsTestClient(url, client_id="m") as client:
            chat_id = (await client.recv_ready()).chat_id
            outbound = OutboundMessage(
                channel="websocket",
                chat_id=chat_id,
                content="img",
                media=["/tmp/a.png"],
                reply_to="m1",
            )
            await channel.send(outbound)
            received = await client.recv_message()
            assert received.text == "img"
            assert received.media == ["/tmp/a.png"]
            assert received.reply_to == "m1"
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
# -- Streaming ------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_streaming_deltas_and_end(bus: MagicMock) -> None:
    """Delta frames concatenate to the full text and exactly one ``stream_end`` follows."""
    channel = _ch(bus, 29911, streaming=True)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        url = "ws://127.0.0.1:29911/"
        async with WsTestClient(url, client_id="s") as client:
            ready = await client.recv_ready()
            chat_id = ready.chat_id
            for chunk in ("Hello", " ", "world", "!"):
                await channel.send_delta(
                    chat_id, chunk, {"_stream_delta": True, "_stream_id": "s1"}
                )
            await channel.send_delta(
                chat_id, "", {"_stream_end": True, "_stream_id": "s1"}
            )

            frames = await client.collect_stream()
            streamed_text = "".join(
                frame.text for frame in frames if frame.event == "delta"
            )
            assert streamed_text == "Hello world!"
            end_frames = [frame for frame in frames if frame.event == "stream_end"]
            assert len(end_frames) == 1
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_interleaved_streams(bus: MagicMock) -> None:
    """Deltas of two interleaved stream ids can be reassembled per stream id."""
    channel = _ch(bus, 29912, streaming=True)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        # Frames in the exact order they are pushed to the client.
        script = [
            ("A1", {"_stream_delta": True, "_stream_id": "sa"}),
            ("B1", {"_stream_delta": True, "_stream_id": "sb"}),
            ("A2", {"_stream_delta": True, "_stream_id": "sa"}),
            ("", {"_stream_end": True, "_stream_id": "sa"}),
            ("B2", {"_stream_delta": True, "_stream_id": "sb"}),
            ("", {"_stream_end": True, "_stream_id": "sb"}),
        ]
        url = "ws://127.0.0.1:29912/"
        async with WsTestClient(url, client_id="i") as client:
            chat_id = (await client.recv_ready()).chat_id
            for text, meta in script:
                await channel.send_delta(chat_id, text, meta)

            frames = await client.recv_n(6)
            deltas = [frame for frame in frames if frame.event == "delta"]
            stream_a = "".join(d.text for d in deltas if d.stream_id == "sa")
            stream_b = "".join(d.text for d in deltas if d.stream_id == "sb")
            assert stream_a == "A1A2"
            assert stream_b == "B1B2"
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
# -- Multi-client ---------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_independent_sessions(bus: MagicMock) -> None:
    """Messages addressed to one client's chat id are received by that client."""
    channel = _ch(bus, 29913)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        url = "ws://127.0.0.1:29913/"
        async with (
            WsTestClient(url, client_id="u1") as first,
            WsTestClient(url, client_id="u2") as second,
        ):
            first_ready = await first.recv_ready()
            second_ready = await second.recv_ready()

            await channel.send(OutboundMessage(
                channel="websocket", chat_id=first_ready.chat_id, content="for-u1",
            ))
            first_msg = await first.recv_message()
            assert first_msg.text == "for-u1"

            await channel.send(OutboundMessage(
                channel="websocket", chat_id=second_ready.chat_id, content="for-u2",
            ))
            second_msg = await second.recv_message()
            assert second_msg.text == "for-u2"
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_disconnected_client_cleanup(bus: MagicMock) -> None:
    """After a client leaves, sending to its chat id leaves no entry in ``ch._subs``."""
    channel = _ch(bus, 29914)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        url = "ws://127.0.0.1:29914/"
        async with WsTestClient(url, client_id="tmp") as client:
            ready = await client.recv_ready()
            stale_chat_id = ready.chat_id
        # The client context has exited here, so the connection is closed.
        await asyncio.sleep(0.1)
        orphan = OutboundMessage(
            channel="websocket", chat_id=stale_chat_id, content="orphan",
        )
        await channel.send(orphan)
        assert stale_chat_id not in channel._subs
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
# -- Authentication -------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_static_token_accepted(bus: MagicMock) -> None:
    """A client presenting the configured static token receives its ready frame."""
    channel = _ch(bus, 29915, token="secret")
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        url = "ws://127.0.0.1:29915/"
        async with WsTestClient(url, client_id="a", token="secret") as client:
            ready = await client.recv_ready()
            assert ready.client_id == "a"
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_static_token_rejected(bus: MagicMock) -> None:
    """A client presenting the wrong static token is refused with HTTP 401."""
    channel = _ch(bus, 29916, token="correct")
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        url = "ws://127.0.0.1:29916/"
        with pytest.raises(websockets.exceptions.InvalidStatus) as rejection:
            async with WsTestClient(url, client_id="b", token="wrong"):
                pass
        status = rejection.value.response.status_code
        assert status == 401
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_token_issue_full_flow(bus: MagicMock) -> None:
    """Walk the issued-token flow: issue, connect once, and fail on reuse.

    Steps checked in order: issuing without the secret gives 401; issuing with
    the secret yields a token; connecting without a token gives 401;
    connecting with the token succeeds; connecting with it again gives 401.
    """
    channel = _ch(
        bus,
        29917,
        path="/ws",
        tokenIssuePath="/auth/token",
        tokenIssueSecret="s",
        websocketRequiresToken=True,
    )
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        ws_url = "ws://127.0.0.1:29917/ws"
        invalid_status = websockets.exceptions.InvalidStatus

        # no secret -> 401
        _body, issue_status = await issue_token(port=29917, issue_path="/auth/token")
        assert issue_status == 401

        # with secret -> token
        issued = await issue_token_ok(port=29917, issue_path="/auth/token", secret="s")

        # no token -> 401
        with pytest.raises(invalid_status) as missing_token:
            async with WsTestClient(ws_url, client_id="x"):
                pass
        assert missing_token.value.response.status_code == 401

        # valid token -> ok
        async with WsTestClient(ws_url, client_id="ok", token=issued) as client:
            ready = await client.recv_ready()
            assert ready.client_id == "ok"

        # reuse -> 401
        with pytest.raises(invalid_status) as reused_token:
            async with WsTestClient(ws_url, client_id="r", token=issued):
                pass
        assert reused_token.value.response.status_code == 401
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
# -- Path routing ---------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_custom_path(bus: MagicMock) -> None:
    """A channel configured with path ``/my-chat`` serves clients on that path."""
    channel = _ch(bus, 29918, path="/my-chat")
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        url = "ws://127.0.0.1:29918/my-chat"
        async with WsTestClient(url, client_id="p") as client:
            ready = await client.recv_ready()
            assert ready.event == "ready"
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_wrong_path_404(bus: MagicMock) -> None:
    """Connecting to a path other than the configured one is refused with HTTP 404."""
    channel = _ch(bus, 29919, path="/ws")
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        url = "ws://127.0.0.1:29919/wrong"
        with pytest.raises(websockets.exceptions.InvalidStatus) as rejection:
            async with WsTestClient(url, client_id="x"):
                pass
        status = rejection.value.response.status_code
        assert status == 404
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_trailing_slash_normalized(bus: MagicMock) -> None:
    """With path ``/ws`` configured, a request for ``/ws/`` still gets a ready frame."""
    channel = _ch(bus, 29920, path="/ws")
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        url = "ws://127.0.0.1:29920/ws/"
        async with WsTestClient(url, client_id="s") as client:
            ready = await client.recv_ready()
            assert ready.event == "ready"
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
# -- Edge cases -----------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_large_message(bus: MagicMock) -> None:
    """A 100,000-character text frame is published inbound unchanged."""
    channel = _ch(bus, 29921)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        payload = "x" * 100_000
        url = "ws://127.0.0.1:29921/"
        async with WsTestClient(url, client_id="big") as client:
            await client.recv_ready()
            await client.send_text(payload)
            await asyncio.sleep(0.2)
            latest = bus.publish_inbound.call_args[0][0]
            assert latest.content == payload
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_unicode_roundtrip(bus: MagicMock) -> None:
    """Non-ASCII text is preserved both inbound (to the bus) and outbound (to the client)."""
    channel = _ch(bus, 29922)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        sample = "你好世界 🌍 日本語テスト"
        url = "ws://127.0.0.1:29922/"
        async with WsTestClient(url, client_id="u") as client:
            chat_id = (await client.recv_ready()).chat_id

            # Inbound direction: client -> bus.
            await client.send_text(sample)
            await asyncio.sleep(0.1)
            latest = bus.publish_inbound.call_args[0][0]
            assert latest.content == sample

            # Outbound direction: channel -> client.
            await channel.send(OutboundMessage(
                channel="websocket", chat_id=chat_id, content=sample,
            ))
            echoed = await client.recv_message()
            assert echoed.text == sample
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_rapid_fire(bus: MagicMock) -> None:
    """Fifty inbound frames are all published and fifty outbound frames arrive in order."""
    channel = _ch(bus, 29923)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        inbound_texts = [f"in-{i}" for i in range(50)]
        outbound_texts = [f"out-{i}" for i in range(50)]
        url = "ws://127.0.0.1:29923/"
        async with WsTestClient(url, client_id="r") as client:
            chat_id = (await client.recv_ready()).chat_id

            for text in inbound_texts:
                await client.send_text(text)
            await asyncio.sleep(0.5)
            assert bus.publish_inbound.await_count == 50

            for text in outbound_texts:
                await channel.send(OutboundMessage(
                    channel="websocket", chat_id=chat_id, content=text,
                ))
            delivered = []
            for _ in range(50):
                frame = await client.recv_message()
                delivered.append(frame.text)
            assert delivered == outbound_texts
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_invalid_json_as_plain_text(bus: MagicMock) -> None:
    """A frame that looks like JSON but is malformed is published verbatim as content."""
    channel = _ch(bus, 29924)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        malformed = "{broken json"
        url = "ws://127.0.0.1:29924/"
        async with WsTestClient(url, client_id="j") as client:
            await client.recv_ready()
            await client.send_text(malformed)
            await asyncio.sleep(0.1)
            latest = bus.publish_inbound.call_args[0][0]
            assert latest.content == malformed
    finally:
        await channel.stop()
        await server_task
|