nanobot/tests/channels/test_websocket_integration.py
chengyongru 2a98360105 refactor: split WebUI gateway dependencies
Maintainer edit for PR 4115: rebase onto origin/main and split gateway HTTP routing from token, media, and workspace services so WebSocketChannel depends on explicit gateway services instead of GatewayHTTPHandler internals.

Preserve file edit channel capabilities and restore tools.restrict_to_workspace wiring through ChannelManager.
2026-06-02 17:14:38 +08:00

557 lines
19 KiB
Python

"""Integration tests for the WebSocket channel using WsTestClient.
Complements the unit/lightweight tests in test_websocket_channel.py by covering
multi-client scenarios, edge cases, and realistic usage patterns.
"""
from __future__ import annotations
import asyncio
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
import websockets
from ws_test_client import WsTestClient, issue_token, issue_token_ok
from nanobot.bus.events import OutboundMessage
from nanobot.channels.websocket import WebSocketChannel, WebSocketConfig
from nanobot.webui.gateway_services import build_gateway_services
def _ch(bus: Any, port: int, **kw: Any) -> WebSocketChannel:
cfg: dict[str, Any] = {
"enabled": True,
"allowFrom": ["*"],
"host": "127.0.0.1",
"port": port,
"path": "/",
"websocketRequiresToken": False,
}
cfg.update(kw)
parsed = WebSocketConfig.model_validate(cfg)
gateway = build_gateway_services(
config=parsed,
bus=bus,
session_manager=None,
static_dist_path=None,
workspace_path=Path.cwd(),
default_restrict_to_workspace=False,
runtime_model_name=None,
runtime_surface="browser",
runtime_capabilities_overrides=None,
)
return WebSocketChannel(cfg, bus, gateway=gateway)
@pytest.fixture()
def bus() -> MagicMock:
b = MagicMock()
b.publish_inbound = AsyncMock()
return b
# -- Connection basics ----------------------------------------------------
@pytest.mark.asyncio
async def test_ready_event_fields(bus: MagicMock) -> None:
ch = _ch(bus, 29901)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29901/", client_id="c1") as c:
r = await c.recv_ready()
assert r.event == "ready"
assert len(r.chat_id) == 36
assert r.client_id == "c1"
finally:
await ch.stop()
await t
@pytest.mark.asyncio
async def test_anonymous_client_gets_generated_id(bus: MagicMock) -> None:
ch = _ch(bus, 29902)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29902/", client_id="") as c:
r = await c.recv_ready()
assert r.client_id.startswith("anon-")
finally:
await ch.stop()
await t
@pytest.mark.asyncio
async def test_each_connection_unique_chat_id(bus: MagicMock) -> None:
ch = _ch(bus, 29903)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29903/", client_id="a") as c1:
async with WsTestClient("ws://127.0.0.1:29903/", client_id="b") as c2:
assert (await c1.recv_ready()).chat_id != (await c2.recv_ready()).chat_id
finally:
await ch.stop()
await t
# -- Inbound messages (client -> server) ----------------------------------
@pytest.mark.asyncio
async def test_plain_text(bus: MagicMock) -> None:
ch = _ch(bus, 29904)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29904/", client_id="p") as c:
await c.recv_ready()
await c.send_text("hello world")
await asyncio.sleep(0.1)
inbound = bus.publish_inbound.call_args[0][0]
assert inbound.content == "hello world"
assert inbound.sender_id == "p"
finally:
await ch.stop()
await t
@pytest.mark.asyncio
async def test_json_content_field(bus: MagicMock) -> None:
ch = _ch(bus, 29905)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29905/", client_id="j") as c:
await c.recv_ready()
await c.send_json({"content": "structured"})
await asyncio.sleep(0.1)
assert bus.publish_inbound.call_args[0][0].content == "structured"
finally:
await ch.stop()
await t
@pytest.mark.asyncio
async def test_json_text_and_message_fields(bus: MagicMock) -> None:
ch = _ch(bus, 29906)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29906/", client_id="x") as c:
await c.recv_ready()
await c.send_json({"text": "via text"})
await asyncio.sleep(0.1)
assert bus.publish_inbound.call_args[0][0].content == "via text"
await c.send_json({"message": "via message"})
await asyncio.sleep(0.1)
assert bus.publish_inbound.call_args[0][0].content == "via message"
finally:
await ch.stop()
await t
@pytest.mark.asyncio
async def test_empty_payload_ignored(bus: MagicMock) -> None:
ch = _ch(bus, 29907)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29907/", client_id="e") as c:
await c.recv_ready()
await c.send_text(" ")
await c.send_json({})
await asyncio.sleep(0.1)
bus.publish_inbound.assert_not_awaited()
finally:
await ch.stop()
await t
@pytest.mark.asyncio
async def test_messages_preserve_order(bus: MagicMock) -> None:
ch = _ch(bus, 29908)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29908/", client_id="o") as c:
await c.recv_ready()
for i in range(5):
await c.send_text(f"msg-{i}")
await asyncio.sleep(0.2)
contents = [call[0][0].content for call in bus.publish_inbound.call_args_list]
assert contents == [f"msg-{i}" for i in range(5)]
finally:
await ch.stop()
await t
# -- Outbound messages (server -> client) ---------------------------------
@pytest.mark.asyncio
async def test_server_send_message(bus: MagicMock) -> None:
ch = _ch(bus, 29909)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29909/", client_id="r") as c:
ready = await c.recv_ready()
await ch.send(OutboundMessage(
channel="websocket", chat_id=ready.chat_id, content="reply",
))
msg = await c.recv_message()
assert msg.text == "reply"
finally:
await ch.stop()
await t
@pytest.mark.asyncio
async def test_server_send_tags_tool_hint_with_kind(bus: MagicMock) -> None:
"""``_tool_hint`` metadata must surface as ``kind: "tool_hint"`` so WS
clients render breadcrumbs separately from conversational replies."""
ch = _ch(bus, 29919)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29919/", client_id="h") as c:
ready = await c.recv_ready()
# Plain reply: no "kind" field.
await ch.send(OutboundMessage(
channel="websocket", chat_id=ready.chat_id, content="hi",
))
plain = await c.recv_message()
assert plain.raw.get("kind") is None
# Tool-hint breadcrumb: kind == "tool_hint".
await ch.send(OutboundMessage(
channel="websocket", chat_id=ready.chat_id,
content='weather("get")',
metadata={"_progress": True, "_tool_hint": True},
))
hint = await c.recv_message()
assert hint.raw.get("kind") == "tool_hint"
assert hint.text == 'weather("get")'
# Generic progress (non-tool-hint) gets the softer "progress" label.
await ch.send(OutboundMessage(
channel="websocket", chat_id=ready.chat_id,
content="thinking…",
metadata={"_progress": True},
))
prog = await c.recv_message()
assert prog.raw.get("kind") == "progress"
finally:
await ch.stop()
await t
@pytest.mark.asyncio
async def test_server_send_with_media_and_reply(bus: MagicMock) -> None:
ch = _ch(bus, 29910)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29910/", client_id="m") as c:
ready = await c.recv_ready()
await ch.send(OutboundMessage(
channel="websocket", chat_id=ready.chat_id, content="img",
media=["/tmp/a.png"], reply_to="m1",
))
msg = await c.recv_message()
assert msg.text == "img"
assert msg.media == ["/tmp/a.png"]
assert msg.reply_to == "m1"
finally:
await ch.stop()
await t
# -- Streaming ------------------------------------------------------------
@pytest.mark.asyncio
async def test_streaming_deltas_and_end(bus: MagicMock) -> None:
ch = _ch(bus, 29911, streaming=True)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29911/", client_id="s") as c:
cid = (await c.recv_ready()).chat_id
for part in ("Hello", " ", "world", "!"):
await ch.send_delta(cid, part, {"_stream_delta": True, "_stream_id": "s1"})
await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "s1"})
msgs = await c.collect_stream()
deltas = [m for m in msgs if m.event == "delta"]
assert "".join(d.text for d in deltas) == "Hello world!"
ends = [m for m in msgs if m.event == "stream_end"]
assert len(ends) == 1
finally:
await ch.stop()
await t
@pytest.mark.asyncio
async def test_interleaved_streams(bus: MagicMock) -> None:
ch = _ch(bus, 29912, streaming=True)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29912/", client_id="i") as c:
cid = (await c.recv_ready()).chat_id
await ch.send_delta(cid, "A1", {"_stream_delta": True, "_stream_id": "sa"})
await ch.send_delta(cid, "B1", {"_stream_delta": True, "_stream_id": "sb"})
await ch.send_delta(cid, "A2", {"_stream_delta": True, "_stream_id": "sa"})
await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "sa"})
await ch.send_delta(cid, "B2", {"_stream_delta": True, "_stream_id": "sb"})
await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "sb"})
msgs = await c.recv_n(6)
sa = "".join(m.text for m in msgs if m.event == "delta" and m.stream_id == "sa")
sb = "".join(m.text for m in msgs if m.event == "delta" and m.stream_id == "sb")
assert sa == "A1A2"
assert sb == "B1B2"
finally:
await ch.stop()
await t
# -- Multi-client ---------------------------------------------------------
@pytest.mark.asyncio
async def test_independent_sessions(bus: MagicMock) -> None:
ch = _ch(bus, 29913)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29913/", client_id="u1") as c1:
async with WsTestClient("ws://127.0.0.1:29913/", client_id="u2") as c2:
r1, r2 = await c1.recv_ready(), await c2.recv_ready()
await ch.send(OutboundMessage(
channel="websocket", chat_id=r1.chat_id, content="for-u1",
))
assert (await c1.recv_message()).text == "for-u1"
await ch.send(OutboundMessage(
channel="websocket", chat_id=r2.chat_id, content="for-u2",
))
assert (await c2.recv_message()).text == "for-u2"
finally:
await ch.stop()
await t
@pytest.mark.asyncio
async def test_disconnected_client_cleanup(bus: MagicMock) -> None:
ch = _ch(bus, 29914)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29914/", client_id="tmp") as c:
chat_id = (await c.recv_ready()).chat_id
# disconnected
await asyncio.sleep(0.1)
await ch.send(OutboundMessage(
channel="websocket", chat_id=chat_id, content="orphan",
))
assert chat_id not in ch._subs
finally:
await ch.stop()
await t
# -- Authentication -------------------------------------------------------
@pytest.mark.asyncio
async def test_static_token_accepted(bus: MagicMock) -> None:
ch = _ch(bus, 29915, token="secret")
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29915/", client_id="a", token="secret") as c:
assert (await c.recv_ready()).client_id == "a"
finally:
await ch.stop()
await t
@pytest.mark.asyncio
async def test_static_token_rejected(bus: MagicMock) -> None:
ch = _ch(bus, 29916, token="correct")
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
async with WsTestClient("ws://127.0.0.1:29916/", client_id="b", token="wrong"):
pass
assert exc.value.response.status_code == 401
finally:
await ch.stop()
await t
@pytest.mark.asyncio
async def test_token_issue_full_flow(bus: MagicMock) -> None:
ch = _ch(bus, 29917, path="/ws",
tokenIssuePath="/auth/token", tokenIssueSecret="s",
websocketRequiresToken=True)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
# no secret -> 401
_, status = await issue_token(port=29917, issue_path="/auth/token")
assert status == 401
# with secret -> token
token = await issue_token_ok(port=29917, issue_path="/auth/token", secret="s")
# no token -> 401
with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="x"):
pass
assert exc.value.response.status_code == 401
# valid token -> ok
async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="ok", token=token) as c:
assert (await c.recv_ready()).client_id == "ok"
# reuse -> 401
with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="r", token=token):
pass
assert exc.value.response.status_code == 401
finally:
await ch.stop()
await t
# -- Path routing ---------------------------------------------------------
@pytest.mark.asyncio
async def test_custom_path(bus: MagicMock) -> None:
ch = _ch(bus, 29918, path="/my-chat")
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29918/my-chat", client_id="p") as c:
assert (await c.recv_ready()).event == "ready"
finally:
await ch.stop()
await t
@pytest.mark.asyncio
async def test_wrong_path_404(bus: MagicMock) -> None:
ch = _ch(bus, 29919, path="/ws")
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
async with WsTestClient("ws://127.0.0.1:29919/wrong", client_id="x"):
pass
assert exc.value.response.status_code == 404
finally:
await ch.stop()
await t
@pytest.mark.asyncio
async def test_trailing_slash_normalized(bus: MagicMock) -> None:
ch = _ch(bus, 29920, path="/ws")
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29920/ws/", client_id="s") as c:
assert (await c.recv_ready()).event == "ready"
finally:
await ch.stop()
await t
# -- Edge cases -----------------------------------------------------------
@pytest.mark.asyncio
async def test_large_message(bus: MagicMock) -> None:
ch = _ch(bus, 29921)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29921/", client_id="big") as c:
await c.recv_ready()
big = "x" * 100_000
await c.send_text(big)
await asyncio.sleep(0.2)
assert bus.publish_inbound.call_args[0][0].content == big
finally:
await ch.stop()
await t
@pytest.mark.asyncio
async def test_unicode_roundtrip(bus: MagicMock) -> None:
ch = _ch(bus, 29922)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29922/", client_id="u") as c:
ready = await c.recv_ready()
text = "你好世界 🌍 日本語テスト"
await c.send_text(text)
await asyncio.sleep(0.1)
assert bus.publish_inbound.call_args[0][0].content == text
await ch.send(OutboundMessage(
channel="websocket", chat_id=ready.chat_id, content=text,
))
assert (await c.recv_message()).text == text
finally:
await ch.stop()
await t
@pytest.mark.asyncio
async def test_rapid_fire(bus: MagicMock) -> None:
ch = _ch(bus, 29923)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29923/", client_id="r") as c:
ready = await c.recv_ready()
for i in range(50):
await c.send_text(f"in-{i}")
await asyncio.sleep(0.5)
assert bus.publish_inbound.await_count == 50
for i in range(50):
await ch.send(OutboundMessage(
channel="websocket", chat_id=ready.chat_id, content=f"out-{i}",
))
received = [(await c.recv_message()).text for _ in range(50)]
assert received == [f"out-{i}" for i in range(50)]
finally:
await ch.stop()
await t
@pytest.mark.asyncio
async def test_invalid_json_as_plain_text(bus: MagicMock) -> None:
ch = _ch(bus, 29924)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29924/", client_id="j") as c:
await c.recv_ready()
await c.send_text("{broken json")
await asyncio.sleep(0.1)
assert bus.publish_inbound.call_args[0][0].content == "{broken json"
finally:
await ch.stop()
await t