# Mirror of https://github.com/HKUDS/nanobot.git (synced 2026-04-10 13:13:39 +00:00).
# Port Python implementation from a1ec7b192ad97ffd58250a720891ff09bbb73888
# (websocket channel module and channel tests; excludes webui debug app).
# 330 lines, 10 KiB, Python.
"""Unit and lightweight integration tests for the WebSocket channel."""
|
|
|
|
import asyncio
|
|
import functools
|
|
import json
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import httpx
|
|
import pytest
|
|
import websockets
|
|
|
|
from nanobot.bus.events import OutboundMessage
|
|
from nanobot.channels.websocket import (
|
|
WebSocketChannel,
|
|
WebSocketConfig,
|
|
_issue_route_secret_matches,
|
|
_normalize_config_path,
|
|
_normalize_http_path,
|
|
_parse_inbound_payload,
|
|
_parse_query,
|
|
_parse_request_path,
|
|
)
|
|
|
|
|
|
async def _http_get(url: str, headers: dict[str, str] | None = None) -> httpx.Response:
    """Perform a blocking ``httpx.get`` on a worker thread and return its response.

    The call is handed to ``asyncio.to_thread`` so the event loop that also
    drives the websocket side of the tests is not blocked while the request
    is in flight. A missing or empty ``headers`` mapping is sent as ``{}``;
    the request uses a 5 second timeout.
    """
    request_headers = headers if headers else {}
    return await asyncio.to_thread(httpx.get, url, headers=request_headers, timeout=5.0)
|
|
|
|
|
|
def test_normalize_http_path_strips_trailing_slash_except_root() -> None:
    """Trailing slashes and query strings are dropped; the bare root path is kept."""
    expectations = {
        "/chat/": "/chat",
        "/chat?x=1": "/chat",
        "/": "/",
    }
    for raw_path, normalized in expectations.items():
        assert _normalize_http_path(raw_path) == normalized
|
|
|
|
|
|
def test_parse_request_path_matches_normalize_and_query() -> None:
    """``_parse_request_path`` must agree with the two single-purpose helpers."""
    request_target = "/ws/?token=secret&client_id=u1"
    parsed_path, parsed_query = _parse_request_path(request_target)
    assert parsed_path == _normalize_http_path(request_target)
    assert parsed_query == _parse_query(request_target)
|
|
|
|
|
|
def test_normalize_config_path_matches_request() -> None:
    """A configured path loses its trailing slash, except for the root path."""
    for configured, wanted in (("/ws/", "/ws"), ("/", "/")):
        assert _normalize_config_path(configured) == wanted
|
|
|
|
|
|
def test_parse_query_extracts_token_and_client_id() -> None:
    """Each query parameter comes back as a single-item list of its value."""
    parsed = _parse_query("/?token=secret&client_id=u1")
    token_values = parsed.get("token")
    client_values = parsed.get("client_id")
    assert token_values == ["secret"]
    assert client_values == ["u1"]
|
|
|
|
|
|
@pytest.mark.parametrize(
    "raw, expected",
    [
        ("plain", "plain"),
        ('{"content": "hi"}', "hi"),
        ('{"text": "there"}', "there"),
        ('{"message": "x"}', "x"),
        (" ", None),
        ("{}", None),
    ],
)
def test_parse_inbound_payload(raw: str, expected: str | None) -> None:
    """Check the text extracted from an inbound frame for each sample.

    The cases cover plain text (returned as-is), JSON objects carrying a
    ``content``, ``text`` or ``message`` field (the field value is returned),
    and blank or empty-object input (``None``).
    """
    extracted = _parse_inbound_payload(raw)
    assert extracted == expected
|
|
|
|
|
|
def test_parse_inbound_invalid_json_falls_back_to_raw_string() -> None:
    """Text that starts like JSON but does not parse is returned unchanged."""
    malformed = "{not json"
    assert _parse_inbound_payload(malformed) == malformed
|
|
|
|
|
|
def test_web_socket_config_path_must_start_with_slash() -> None:
    """A config path without a leading slash is rejected with ``ValueError``."""
    expected_message = 'path must start with "/"'
    with pytest.raises(ValueError, match=expected_message):
        WebSocketConfig(path="bad")
|
|
|
|
|
|
def test_ssl_context_requires_both_cert_and_key_files() -> None:
    """Building the SSL context fails when a cert file is set but the key file is empty."""
    partial_tls_config = {
        "enabled": True,
        "allowFrom": ["*"],
        "sslCertfile": "/tmp/c.pem",
        "sslKeyfile": "",
    }
    channel = WebSocketChannel(partial_tls_config, MagicMock())
    with pytest.raises(ValueError, match="ssl_certfile and ssl_keyfile"):
        channel._build_ssl_context()
|
|
|
|
|
|
def test_default_config_includes_safe_bind_and_streaming() -> None:
    """The default config is disabled, bound to loopback, streaming, and open to all senders."""
    cfg = WebSocketChannel.default_config()

    # Identity checks: the flags must be real booleans, not merely truthy/falsy values.
    assert cfg["enabled"] is False
    assert cfg["streaming"] is True

    assert cfg["host"] == "127.0.0.1"
    assert cfg["allowFrom"] == ["*"]

    # The issue path may be missing from the defaults; missing is treated as empty here.
    assert cfg.get("tokenIssuePath", "") == ""
|
|
|
|
|
|
def test_token_issue_path_must_differ_from_websocket_path() -> None:
    """Using the same value for ``path`` and ``token_issue_path`` raises ``ValueError``."""
    shared_path = "/ws"
    with pytest.raises(ValueError, match="token_issue_path must differ"):
        WebSocketConfig(path=shared_path, token_issue_path=shared_path)
|
|
|
|
|
|
def test_issue_route_secret_matches_bearer_and_header() -> None:
    """The issue-route secret is accepted as a bearer token or via ``X-Nanobot-Auth``.

    A bearer token carrying a different value must be rejected.
    """
    from websockets.datastructures import Headers

    secret = "my-secret"
    cases = (
        (("Authorization", "Bearer my-secret"), True),
        (("X-Nanobot-Auth", "my-secret"), True),
        (("Authorization", "Bearer other"), False),
    )
    for header, accepted in cases:
        # ``is`` keeps the strict bool check: a merely truthy/falsy result must fail.
        assert _issue_route_secret_matches(Headers([header]), secret) is accepted
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_delivers_json_message_with_media_and_reply() -> None:
    """``send`` writes exactly one JSON frame carrying text, reply target and media."""
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())
    socket = AsyncMock()
    channel._connections["chat-1"] = socket

    outbound = OutboundMessage(
        channel="websocket",
        chat_id="chat-1",
        content="hello",
        reply_to="m1",
        media=["/tmp/a.png"],
    )
    await channel.send(outbound)

    socket.send.assert_awaited_once()
    # The frame is the first positional argument of the single ``send`` call.
    payload = json.loads(socket.send.call_args.args[0])
    expected_fields = {
        "event": "message",
        "text": "hello",
        "reply_to": "m1",
        "media": ["/tmp/a.png"],
    }
    for key, value in expected_fields.items():
        assert payload[key] == value
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_missing_connection_is_noop_without_error() -> None:
    """Sending to a chat id with no registered connection must not raise."""
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())
    orphan = OutboundMessage(channel="websocket", chat_id="missing", content="x")
    await channel.send(orphan)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_delta_emits_delta_and_stream_end() -> None:
    """``send_delta`` emits a ``delta`` frame and then a ``stream_end`` frame.

    Both frames must carry the stream id taken from the metadata.
    """
    config = {"enabled": True, "allowFrom": ["*"], "streaming": True}
    channel = WebSocketChannel(config, MagicMock())
    socket = AsyncMock()
    channel._connections["chat-1"] = socket

    await channel.send_delta("chat-1", "part", {"_stream_delta": True, "_stream_id": "sid"})
    await channel.send_delta("chat-1", "", {"_stream_end": True, "_stream_id": "sid"})

    assert socket.send.await_count == 2
    sent = socket.send.call_args_list
    delta_frame = json.loads(sent[0].args[0])
    end_frame = json.loads(sent[1].args[0])

    assert delta_frame["event"] == "delta"
    assert delta_frame["text"] == "part"
    assert delta_frame["stream_id"] == "sid"
    assert end_frame["event"] == "stream_end"
    assert end_frame["stream_id"] == "sid"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_end_to_end_client_receives_ready_and_agent_sees_inbound() -> None:
    """Run a real server and client and check the ready event and inbound publishing.

    The client must first receive a ``ready`` event echoing its ``client_id``.
    A JSON frame and then a plain-text frame sent by the client must each be
    published to the bus as an inbound message on the ``websocket`` channel.
    """
    bus = MagicMock()
    bus.publish_inbound = AsyncMock()
    port = 29876
    channel = WebSocketChannel(
        {
            "enabled": True,
            "allowFrom": ["*"],
            "host": "127.0.0.1",
            "port": port,
            "path": "/ws",
        },
        bus,
    )

    async def _wait_for_publish_after(baseline: int, timeout: float = 2.0) -> None:
        """Poll until ``publish_inbound`` has been awaited more than ``baseline`` times.

        Returns quietly once ``timeout`` seconds have passed, so that the
        assertions that follow report the failure instead of this helper.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while bus.publish_inbound.await_count <= baseline and loop.time() < deadline:
            await asyncio.sleep(0.01)

    server_task = asyncio.create_task(channel.start())
    # NOTE(review): presumably gives the server time to start listening — confirm.
    await asyncio.sleep(0.3)

    try:
        async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=tester") as client:
            ready_raw = await client.recv()
            ready = json.loads(ready_raw)
            assert ready["event"] == "ready"
            assert ready["client_id"] == "tester"
            chat_id = ready["chat_id"]

            # Wait on the mock's await counter rather than a fixed sleep, so a
            # slow machine does not fail the test before the frame is handled.
            seen = bus.publish_inbound.await_count
            await client.send(json.dumps({"content": "ping from client"}))
            await _wait_for_publish_after(seen)

            bus.publish_inbound.assert_awaited()
            inbound = bus.publish_inbound.call_args[0][0]
            assert inbound.channel == "websocket"
            assert inbound.sender_id == "tester"
            assert inbound.chat_id == chat_id
            assert inbound.content == "ping from client"

            seen = bus.publish_inbound.await_count
            await client.send("plain text frame")
            await _wait_for_publish_after(seen)
            assert bus.publish_inbound.await_count >= 2
            second = bus.publish_inbound.call_args_list[-1][0][0]
            assert second.content == "plain text frame"
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_token_rejects_handshake_when_mismatch() -> None:
    """A handshake presenting the wrong static token is refused with HTTP 401."""
    port = 29877
    settings = {
        "enabled": True,
        "allowFrom": ["*"],
        "host": "127.0.0.1",
        "port": port,
        "path": "/",
        "token": "secret",
    }
    channel = WebSocketChannel(settings, MagicMock())

    server_task = asyncio.create_task(channel.start())
    # NOTE(review): presumably gives the server time to start listening — confirm.
    await asyncio.sleep(0.3)

    try:
        with pytest.raises(websockets.exceptions.InvalidStatus) as rejected:
            async with websockets.connect(f"ws://127.0.0.1:{port}/?token=wrong"):
                pass
        assert rejected.value.response.status_code == 401
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_wrong_path_returns_404() -> None:
    """A handshake on a path other than the configured one is refused with HTTP 404."""
    port = 29878
    settings = {
        "enabled": True,
        "allowFrom": ["*"],
        "host": "127.0.0.1",
        "port": port,
        "path": "/ws",
    }
    channel = WebSocketChannel(settings, MagicMock())

    server_task = asyncio.create_task(channel.start())
    # NOTE(review): presumably gives the server time to start listening — confirm.
    await asyncio.sleep(0.3)

    try:
        with pytest.raises(websockets.exceptions.InvalidStatus) as rejected:
            async with websockets.connect(f"ws://127.0.0.1:{port}/other"):
                pass
        assert rejected.value.response.status_code == 404
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
def test_registry_discovers_websocket_channel() -> None:
    """The channel registry resolves ``"websocket"`` to a class whose ``name`` matches."""
    from nanobot.channels.registry import load_channel_class

    assert load_channel_class("websocket").name == "websocket"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_http_route_issues_token_then_websocket_requires_it() -> None:
    """Exercise the token-issue HTTP route together with the token-gated websocket.

    Checked in order: the issue route answers 401 without credentials and 200
    with the bearer secret, returning a token prefixed ``nbwt_``; a websocket
    handshake without a token gets 401; a handshake with the issued token
    receives ``ready``; a second handshake with the same token gets 401.
    """
    bus = MagicMock()
    bus.publish_inbound = AsyncMock()
    port = 29879
    settings = {
        "enabled": True,
        "allowFrom": ["*"],
        "host": "127.0.0.1",
        "port": port,
        "path": "/ws",
        "tokenIssuePath": "/auth/token",
        "tokenIssueSecret": "route-secret",
        "websocketRequiresToken": True,
    }
    channel = WebSocketChannel(settings, bus)

    server_task = asyncio.create_task(channel.start())
    # NOTE(review): presumably gives the server time to start listening — confirm.
    await asyncio.sleep(0.3)

    try:
        # 1. The issue route is protected by its own secret.
        unauthenticated = await _http_get(f"http://127.0.0.1:{port}/auth/token")
        assert unauthenticated.status_code == 401

        issued = await _http_get(
            f"http://127.0.0.1:{port}/auth/token",
            headers={"Authorization": "Bearer route-secret"},
        )
        assert issued.status_code == 200
        token = issued.json()["token"]
        assert token.startswith("nbwt_")

        # 2. The websocket refuses a handshake that carries no token.
        with pytest.raises(websockets.exceptions.InvalidStatus) as missing_token:
            async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=x"):
                pass
        assert missing_token.value.response.status_code == 401

        # 3. The issued token opens a connection.
        uri = f"ws://127.0.0.1:{port}/ws?token={token}&client_id=caller"
        async with websockets.connect(uri) as client:
            greeting = json.loads(await client.recv())
            assert greeting["event"] == "ready"
            assert greeting["client_id"] == "caller"

        # 4. Presenting the same token a second time is refused.
        with pytest.raises(websockets.exceptions.InvalidStatus) as reuse:
            async with websockets.connect(uri):
                pass
        assert reuse.value.response.status_code == 401
    finally:
        await channel.stop()
        await server_task
|