From 8f7ce9fef79232c70233516b01de82b40e598999 Mon Sep 17 00:00:00 2001 From: chengyongru Date: Thu, 9 Apr 2026 15:18:02 +0800 Subject: [PATCH] fix(websocket): harden security and robustness - Use hmac.compare_digest for timing-safe static token comparison - Add issued token capacity limit (_MAX_ISSUED_TOKENS=10000) with 429 response - Use atomic pop in _take_issued_token_if_valid to eliminate TOCTOU window - Enforce TLSv1.2 minimum version for SSL connections - Extract _safe_send helper for consistent ConnectionClosed handling - Move connection registration after ready send to prevent out-of-order delivery - Add HTTP-level allow_from check and client_id truncation in process_request - Make stop() idempotent with graceful shutdown error handling - Normalize path via validator instead of leaving raw value - Default websocket_requires_token to True for secure-by-default behavior - Add integration tests and ws_test_client helper - Refactor tests to use shared _ch factory and bus fixture --- nanobot/channels/websocket.py | 124 +++-- tests/channels/test_websocket_channel.py | 373 +++++++++++++-- tests/channels/test_websocket_integration.py | 477 +++++++++++++++++++ tests/channels/ws_test_client.py | 227 +++++++++ 4 files changed, 1102 insertions(+), 99 deletions(-) create mode 100644 tests/channels/test_websocket_integration.py create mode 100644 tests/channels/ws_test_client.py diff --git a/nanobot/channels/websocket.py b/nanobot/channels/websocket.py index 1660cbe7e..2af61d6f7 100644 --- a/nanobot/channels/websocket.py +++ b/nanobot/channels/websocket.py @@ -65,7 +65,7 @@ class WebSocketConfig(Base): token_issue_path: str = "" token_issue_secret: str = "" token_ttl_s: int = Field(default=300, ge=30, le=86_400) - websocket_requires_token: bool = False + websocket_requires_token: bool = True allow_from: list[str] = Field(default_factory=lambda: ["*"]) streaming: bool = True max_message_bytes: int = Field(default=1_048_576, ge=1024, le=16_777_216) @@ -79,7 +79,7 @@ class WebSocketConfig(Base): def path_must_start_with_slash(cls, value: str) -> str: if not value.startswith("/"): raise ValueError('path must start with "/"') - return value + return _normalize_config_path(value) @field_validator("token_issue_path") @classmethod @@ -130,6 +130,12 @@ def _parse_query(path_with_query: str) -> dict[str, list[str]]: return _parse_request_path(path_with_query)[1] +def _query_first(query: dict[str, list[str]], key: str) -> str | None: + """Return the first value for *key*, or None.""" + values = query.get(key) + return values[0] if values else None + + def _parse_inbound_payload(raw: str) -> str | None: """Parse a client frame into text; return None for empty or unrecognized content.""" text = raw.strip() @@ -197,9 +203,12 @@ class WebSocketChannel(BaseChannel): "websocket: ssl_certfile and ssl_keyfile must both be set for WSS, or both left empty" ) ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + ctx.minimum_version = ssl.TLSVersion.TLSv1_2 ctx.load_cert_chain(certfile=cert, keyfile=key) return ctx + _MAX_ISSUED_TOKENS = 10_000 + def _purge_expired_issued_tokens(self) -> None: now = time.monotonic() for token_key, expiry in list(self._issued_tokens.items()): @@ -207,17 +216,19 @@ class WebSocketChannel(BaseChannel): self._issued_tokens.pop(token_key, None) def _take_issued_token_if_valid(self, token_value: str | None) -> bool: - """Validate and consume one issued token (single use per connection attempt).""" + """Validate and consume one issued token (single use per connection attempt). + + Uses single-step pop to minimize the window between lookup and removal; + safe under asyncio's single-threaded cooperative model. + """ if not token_value: return False self._purge_expired_issued_tokens() - expiry = self._issued_tokens.get(token_value) + expiry = self._issued_tokens.pop(token_value, None) if expiry is None: return False if time.monotonic() > expiry: - self._issued_tokens.pop(token_value, None) return False - self._issued_tokens.pop(token_value, None) return True def _handle_token_issue_http(self, connection: Any, request: Any) -> Any: @@ -231,6 +242,12 @@ class WebSocketChannel(BaseChannel): "any client can obtain connection tokens — set token_issue_secret for production." ) self._purge_expired_issued_tokens() + if len(self._issued_tokens) >= self._MAX_ISSUED_TOKENS: + logger.error( + "websocket: too many outstanding issued tokens ({}), rejecting issuance", + len(self._issued_tokens), + ) + return _http_json_response({"error": "too many outstanding tokens"}, status=429) token_value = f"nbwt_{secrets.token_urlsafe(32)}" self._issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s) @@ -238,13 +255,12 @@ class WebSocketChannel(BaseChannel): {"token": token_value, "expires_in": self.config.token_ttl_s} ) - def _authorize_websocket_handshake(self, connection: Any, request_path: str) -> Any: - query = _parse_query(request_path) - supplied = (query.get("token") or [None])[0] + def _authorize_websocket_handshake(self, connection: Any, query: dict[str, list[str]]) -> Any: + supplied = _query_first(query, "token") static_token = self.config.token.strip() if static_token: - if supplied == static_token: + if supplied and hmac.compare_digest(supplied, static_token): return None if supplied and self._take_issued_token_if_valid(supplied): return None @@ -279,7 +295,15 @@ class WebSocketChannel(BaseChannel): expected_ws = self._expected_path() if got != expected_ws: return connection.respond(404, "Not Found") - return self._authorize_websocket_handshake(connection, request.path) + # Early reject before WebSocket upgrade to avoid unnecessary overhead; + # _handle_message() performs a second check as defense-in-depth. + query = _parse_query(request.path) + client_id = _query_first(query, "client_id") or "" + if len(client_id) > 128: + client_id = client_id[:128] + if not self.is_allowed(client_id): + return connection.respond(403, "Forbidden") + return self._authorize_websocket_handshake(connection, query) async def handler(connection: ServerConnection) -> None: await self._connection_loop(connection) @@ -321,26 +345,30 @@ class WebSocketChannel(BaseChannel): request = connection.request path_part = request.path if request else "/" _, query = _parse_request_path(path_part) - client_id_raw = (query.get("client_id") or [None])[0] + client_id_raw = _query_first(query, "client_id") client_id = client_id_raw.strip() if client_id_raw else "" if not client_id: client_id = f"anon-{uuid.uuid4().hex[:12]}" + elif len(client_id) > 128: + logger.warning("websocket: client_id too long ({} chars), truncating", len(client_id)) + client_id = client_id[:128] chat_id = str(uuid.uuid4()) - self._connections[chat_id] = connection - - await connection.send( - json.dumps( - { - "event": "ready", - "chat_id": chat_id, - "client_id": client_id, - }, - ensure_ascii=False, - ) - ) try: + await connection.send( + json.dumps( + { + "event": "ready", + "chat_id": chat_id, + "client_id": client_id, + }, + ensure_ascii=False, + ) + ) + # Register only after ready is successfully sent to avoid out-of-order sends + self._connections[chat_id] = connection + async for raw in connection: if isinstance(raw, bytes): try: @@ -363,15 +391,34 @@ class WebSocketChannel(BaseChannel): self._connections.pop(chat_id, None) async def stop(self) -> None: + if not self._running: + return self._running = False if self._stop_event: self._stop_event.set() if self._server_task: - await self._server_task + try: + await self._server_task + except Exception as e: + logger.warning("websocket: server task error during shutdown: {}", e) self._server_task = None self._connections.clear() self._issued_tokens.clear() + async def _safe_send(self, chat_id: str, raw: str, *, label: str = "") -> None: + """Send a raw frame, cleaning up dead connections on ConnectionClosed.""" + connection = self._connections.get(chat_id) + if connection is None: + return + try: + await connection.send(raw) + except ConnectionClosed: + self._connections.pop(chat_id, None) + logger.warning("websocket{}connection gone for chat_id={}", label, chat_id) + except Exception as e: + logger.error("websocket{}send failed: {}", label, e) + raise + async def send(self, msg: OutboundMessage) -> None: connection = self._connections.get(msg.chat_id) if connection is None: @@ -386,14 +433,7 @@ class WebSocketChannel(BaseChannel): if msg.reply_to: payload["reply_to"] = msg.reply_to raw = json.dumps(payload, ensure_ascii=False) - try: - await connection.send(raw) - except ConnectionClosed: - self._connections.pop(msg.chat_id, None) - logger.warning("websocket: connection gone for chat_id={}", msg.chat_id) - except Exception as e: - logger.error("websocket send failed: {}", e) - raise + await self._safe_send(msg.chat_id, raw, label=" ") async def send_delta( self, @@ -401,27 +441,17 @@ class WebSocketChannel(BaseChannel): delta: str, metadata: dict[str, Any] | None = None, ) -> None: - connection = self._connections.get(chat_id) - if connection is None: + if self._connections.get(chat_id) is None: return meta = metadata or {} if meta.get("_stream_end"): body: dict[str, Any] = {"event": "stream_end"} - if meta.get("_stream_id") is not None: - body["stream_id"] = meta["_stream_id"] else: body = { "event": "delta", "text": delta, } - if meta.get("_stream_id") is not None: - body["stream_id"] = meta["_stream_id"] + if meta.get("_stream_id") is not None: + body["stream_id"] = meta["_stream_id"] raw = json.dumps(body, ensure_ascii=False) - try: - await connection.send(raw) - except ConnectionClosed: - self._connections.pop(chat_id, None) - logger.warning("websocket: stream connection gone for chat_id={}", chat_id) - except Exception as e: - logger.error("websocket stream send failed: {}", e) - raise + await self._safe_send(chat_id, raw, label=" stream ") diff --git a/tests/channels/test_websocket_channel.py b/tests/channels/test_websocket_channel.py index e4c5ad635..89a330a18 100644 --- a/tests/channels/test_websocket_channel.py +++ b/tests/channels/test_websocket_channel.py @@ -3,11 +3,15 @@ import asyncio import functools import json +import time +from typing import Any from unittest.mock import AsyncMock, MagicMock import httpx import pytest import websockets +from websockets.exceptions import ConnectionClosed +from websockets.frames import Close from nanobot.bus.events import OutboundMessage from nanobot.channels.websocket import ( @@ -21,6 +25,30 @@ from nanobot.channels.websocket import ( _parse_request_path, ) +# -- Shared helpers (aligned with test_websocket_integration.py) --------------- + +_PORT = 29876 + + +def _ch(bus: Any, **kw: Any) -> WebSocketChannel: + cfg: dict[str, Any] = { + "enabled": True, + "allowFrom": ["*"], + "host": "127.0.0.1", + "port": _PORT, + "path": "/ws", + "websocketRequiresToken": False, + } + cfg.update(kw) + return WebSocketChannel(cfg, bus) + + +@pytest.fixture() +def bus() -> MagicMock: + b = MagicMock() + b.publish_inbound = AsyncMock() + return b + async def _http_get(url: str, headers: dict[str, str] | None = None) -> httpx.Response: """Run GET in a thread to avoid blocking the asyncio loop shared with websockets.""" @@ -71,6 +99,21 @@ def test_parse_inbound_invalid_json_falls_back_to_raw_string() -> None: assert _parse_inbound_payload("{not json") == "{not json" +@pytest.mark.parametrize( + ("raw", "expected"), + [ + ('{"content": ""}', None), # empty string content + ('{"content": 123}', None), # non-string content + ('{"content": " "}', None), # whitespace-only content + ('["hello"]', '["hello"]'), # JSON array: not a dict, treated as plain text + ('{"unknown_key": "val"}', None), # unrecognized key + ('{"content": null}', None), # null content + ], +) +def test_parse_inbound_payload_edge_cases(raw: str, expected: str | None) -> None: + assert _parse_inbound_payload(raw) == expected + + def test_web_socket_config_path_must_start_with_slash() -> None: with pytest.raises(ValueError, match='path must start with "/"'): WebSocketConfig(path="bad") @@ -112,6 +155,14 @@ def test_issue_route_secret_matches_bearer_and_header() -> None: assert _issue_route_secret_matches(wrong, secret) is False +def test_issue_route_secret_matches_empty_secret() -> None: + from websockets.datastructures import Headers + + # Empty secret always returns True regardless of headers + assert _issue_route_secret_matches(Headers([]), "") is True + assert _issue_route_secret_matches(Headers([("Authorization", "Bearer anything")]), "") is True + + @pytest.mark.asyncio async def test_send_delivers_json_message_with_media_and_reply() -> None: bus = MagicMock() @@ -144,6 +195,33 @@ async def test_send_missing_connection_is_noop_without_error() -> None: await channel.send(msg) +@pytest.mark.asyncio +async def test_send_removes_connection_on_connection_closed() -> None: + bus = MagicMock() + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) + mock_ws = AsyncMock() + mock_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True) + channel._connections["chat-1"] = mock_ws + + msg = OutboundMessage(channel="websocket", chat_id="chat-1", content="hello") + await channel.send(msg) + + assert "chat-1" not in channel._connections + + +@pytest.mark.asyncio +async def test_send_delta_removes_connection_on_connection_closed() -> None: + bus = MagicMock() + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus) + mock_ws = AsyncMock() + mock_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True) + channel._connections["chat-1"] = mock_ws + + await channel.send_delta("chat-1", "chunk", {"_stream_delta": True, "_stream_id": "s1"}) + + assert "chat-1" not in channel._connections + + @pytest.mark.asyncio async def test_send_delta_emits_delta_and_stream_end() -> None: bus = MagicMock() @@ -165,20 +243,39 @@ async def test_send_delta_emits_delta_and_stream_end() -> None: @pytest.mark.asyncio -async def test_end_to_end_client_receives_ready_and_agent_sees_inbound() -> None: +async def test_send_non_connection_closed_exception_is_raised() -> None: bus = MagicMock() - bus.publish_inbound = AsyncMock() + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) + mock_ws = AsyncMock() + mock_ws.send.side_effect = RuntimeError("unexpected") + channel._connections["chat-1"] = mock_ws + + msg = OutboundMessage(channel="websocket", chat_id="chat-1", content="hello") + with pytest.raises(RuntimeError, match="unexpected"): + await channel.send(msg) + + +@pytest.mark.asyncio +async def test_send_delta_missing_connection_is_noop() -> None: + bus = MagicMock() + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus) + # No exception, no error — just a no-op + await channel.send_delta("nonexistent", "chunk", {"_stream_delta": True, "_stream_id": "s1"}) + + +@pytest.mark.asyncio +async def test_stop_is_idempotent() -> None: + bus = MagicMock() + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) + # stop() before start() should not raise + await channel.stop() + await channel.stop() + + +@pytest.mark.asyncio +async def test_end_to_end_client_receives_ready_and_agent_sees_inbound(bus: MagicMock) -> None: port = 29876 - channel = WebSocketChannel( - { - "enabled": True, - "allowFrom": ["*"], - "host": "127.0.0.1", - "port": port, - "path": "/ws", - }, - bus, - ) + channel = _ch(bus, port=port) server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) @@ -212,20 +309,9 @@ async def test_end_to_end_client_receives_ready_and_agent_sees_inbound() -> None @pytest.mark.asyncio -async def test_token_rejects_handshake_when_mismatch() -> None: - bus = MagicMock() +async def test_token_rejects_handshake_when_mismatch(bus: MagicMock) -> None: port = 29877 - channel = WebSocketChannel( - { - "enabled": True, - "allowFrom": ["*"], - "host": "127.0.0.1", - "port": port, - "path": "/", - "token": "secret", - }, - bus, - ) + channel = _ch(bus, port=port, path="/", token="secret") server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) @@ -241,19 +327,9 @@ async def test_token_rejects_handshake_when_mismatch() -> None: @pytest.mark.asyncio -async def test_wrong_path_returns_404() -> None: - bus = MagicMock() +async def test_wrong_path_returns_404(bus: MagicMock) -> None: port = 29878 - channel = WebSocketChannel( - { - "enabled": True, - "allowFrom": ["*"], - "host": "127.0.0.1", - "port": port, - "path": "/ws", - }, - bus, - ) + channel = _ch(bus, port=port) server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) @@ -276,22 +352,13 @@ def test_registry_discovers_websocket_channel() -> None: @pytest.mark.asyncio -async def test_http_route_issues_token_then_websocket_requires_it() -> None: - bus = MagicMock() - bus.publish_inbound = AsyncMock() +async def test_http_route_issues_token_then_websocket_requires_it(bus: MagicMock) -> None: port = 29879 - channel = WebSocketChannel( - { - "enabled": True, - "allowFrom": ["*"], - "host": "127.0.0.1", - "port": port, - "path": "/ws", - "tokenIssuePath": "/auth/token", - "tokenIssueSecret": "route-secret", - "websocketRequiresToken": True, - }, - bus, + channel = _ch( + bus, port=port, + tokenIssuePath="/auth/token", + tokenIssueSecret="route-secret", + websocketRequiresToken=True, ) server_task = asyncio.create_task(channel.start()) @@ -327,3 +394,205 @@ async def test_http_route_issues_token_then_websocket_requires_it() -> None: finally: await channel.stop() await server_task + + +@pytest.mark.asyncio +async def test_end_to_end_server_pushes_streaming_deltas_to_client(bus: MagicMock) -> None: + port = 29880 + channel = _ch(bus, port=port, streaming=True) + + server_task = asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=stream-tester") as client: + ready_raw = await client.recv() + ready = json.loads(ready_raw) + chat_id = ready["chat_id"] + + # Server pushes deltas directly + await channel.send_delta( + chat_id, "Hello ", {"_stream_delta": True, "_stream_id": "s1"} + ) + await channel.send_delta( + chat_id, "world", {"_stream_delta": True, "_stream_id": "s1"} + ) + await channel.send_delta( + chat_id, "", {"_stream_end": True, "_stream_id": "s1"} + ) + + delta1 = json.loads(await client.recv()) + assert delta1["event"] == "delta" + assert delta1["text"] == "Hello " + assert delta1["stream_id"] == "s1" + + delta2 = json.loads(await client.recv()) + assert delta2["event"] == "delta" + assert delta2["text"] == "world" + assert delta2["stream_id"] == "s1" + + end = json.loads(await client.recv()) + assert end["event"] == "stream_end" + assert end["stream_id"] == "s1" + finally: + await channel.stop() + await server_task + + +@pytest.mark.asyncio +async def test_token_issue_rejects_when_at_capacity(bus: MagicMock) -> None: + port = 29881 + channel = _ch(bus, port=port, tokenIssuePath="/auth/token", tokenIssueSecret="s") + + server_task = asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + # Fill issued tokens to capacity + channel._issued_tokens = { + f"nbwt_fill_{i}": time.monotonic() + 300 for i in range(channel._MAX_ISSUED_TOKENS) + } + + resp = await _http_get( + f"http://127.0.0.1:{port}/auth/token", + headers={"Authorization": "Bearer s"}, + ) + assert resp.status_code == 429 + data = resp.json() + assert "error" in data + finally: + await channel.stop() + await server_task + + +@pytest.mark.asyncio +async def test_allow_from_rejects_unauthorized_client_id(bus: MagicMock) -> None: + port = 29882 + channel = _ch(bus, port=port, allowFrom=["alice", "bob"]) + + server_task = asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info: + async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=eve"): + pass + assert exc_info.value.response.status_code == 403 + finally: + await channel.stop() + await server_task + + +@pytest.mark.asyncio +async def test_client_id_truncation(bus: MagicMock) -> None: + port = 29883 + channel = _ch(bus, port=port) + + server_task = asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + long_id = "x" * 200 + async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id={long_id}") as client: + ready = json.loads(await client.recv()) + assert ready["client_id"] == "x" * 128 + assert len(ready["client_id"]) == 128 + finally: + await channel.stop() + await server_task + + +@pytest.mark.asyncio +async def test_non_utf8_binary_frame_ignored(bus: MagicMock) -> None: + port = 29884 + channel = _ch(bus, port=port) + + server_task = asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=bin-test") as client: + await client.recv() # consume ready + # Send non-UTF-8 bytes + await client.send(b"\xff\xfe\xfd") + await asyncio.sleep(0.05) + # publish_inbound should NOT have been called + bus.publish_inbound.assert_not_awaited() + finally: + await channel.stop() + await server_task + + +@pytest.mark.asyncio +async def test_static_token_accepts_issued_token_as_fallback(bus: MagicMock) -> None: + port = 29885 + channel = _ch( + bus, port=port, + token="static-secret", + tokenIssuePath="/auth/token", + tokenIssueSecret="route-secret", + ) + + server_task = asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + # Get an issued token + resp = await _http_get( + f"http://127.0.0.1:{port}/auth/token", + headers={"Authorization": "Bearer route-secret"}, + ) + assert resp.status_code == 200 + issued_token = resp.json()["token"] + + # Connect using issued token (not the static one) + async with websockets.connect(f"ws://127.0.0.1:{port}/ws?token={issued_token}&client_id=caller") as client: + ready = json.loads(await client.recv()) + assert ready["event"] == "ready" + finally: + await channel.stop() + await server_task + + +@pytest.mark.asyncio +async def test_allow_from_empty_list_denies_all(bus: MagicMock) -> None: + port = 29886 + channel = _ch(bus, port=port, allowFrom=[]) + + server_task = asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info: + async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=anyone"): + pass + assert exc_info.value.response.status_code == 403 + finally: + await channel.stop() + await server_task + + +@pytest.mark.asyncio +async def test_websocket_requires_token_without_issue_path(bus: MagicMock) -> None: + """When websocket_requires_token is True but no token or issue path configured, all connections are rejected.""" + port = 29887 + channel = _ch(bus, port=port, websocketRequiresToken=True) + + server_task = asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + # No token at all → 401 + with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info: + async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=u"): + pass + assert exc_info.value.response.status_code == 401 + + # Wrong token → 401 + with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info: + async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=u&token=wrong"): + pass + assert exc_info.value.response.status_code == 401 + finally: + await channel.stop() + await server_task diff --git a/tests/channels/test_websocket_integration.py b/tests/channels/test_websocket_integration.py new file mode 100644 index 000000000..2cf0331ab --- /dev/null +++ b/tests/channels/test_websocket_integration.py @@ -0,0 +1,477 @@ +"""Integration tests for the WebSocket channel using WsTestClient. + +Complements the unit/lightweight tests in test_websocket_channel.py by covering +multi-client scenarios, edge cases, and realistic usage patterns. +""" + +from __future__ import annotations + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock + +import pytest +import websockets + +from nanobot.channels.websocket import WebSocketChannel +from nanobot.bus.events import OutboundMessage +from ws_test_client import WsTestClient, issue_token, issue_token_ok + + +def _ch(bus: Any, port: int, **kw: Any) -> WebSocketChannel: + cfg: dict[str, Any] = { + "enabled": True, + "allowFrom": ["*"], + "host": "127.0.0.1", + "port": port, + "path": "/", + "websocketRequiresToken": False, + } + cfg.update(kw) + return WebSocketChannel(cfg, bus) + + +@pytest.fixture() +def bus() -> MagicMock: + b = MagicMock() + b.publish_inbound = AsyncMock() + return b + + +# -- Connection basics ---------------------------------------------------- + + +@pytest.mark.asyncio +async def test_ready_event_fields(bus: MagicMock) -> None: + ch = _ch(bus, 29901) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29901/", client_id="c1") as c: + r = await c.recv_ready() + assert r.event == "ready" + assert len(r.chat_id) == 36 + assert r.client_id == "c1" + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_anonymous_client_gets_generated_id(bus: MagicMock) -> None: + ch = _ch(bus, 29902) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29902/", client_id="") as c: + r = await c.recv_ready() + assert r.client_id.startswith("anon-") + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_each_connection_unique_chat_id(bus: MagicMock) -> None: + ch = _ch(bus, 29903) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29903/", client_id="a") as c1: + async with WsTestClient("ws://127.0.0.1:29903/", client_id="b") as c2: + assert (await c1.recv_ready()).chat_id != (await c2.recv_ready()).chat_id + finally: + await ch.stop(); await t + + +# -- Inbound messages (client -> server) ---------------------------------- + + +@pytest.mark.asyncio +async def test_plain_text(bus: MagicMock) -> None: + ch = _ch(bus, 29904) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29904/", client_id="p") as c: + await c.recv_ready() + await c.send_text("hello world") + await asyncio.sleep(0.1) + inbound = bus.publish_inbound.call_args[0][0] + assert inbound.content == "hello world" + assert inbound.sender_id == "p" + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_json_content_field(bus: MagicMock) -> None: + ch = _ch(bus, 29905) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29905/", client_id="j") as c: + await c.recv_ready() + await c.send_json({"content": "structured"}) + await asyncio.sleep(0.1) + assert bus.publish_inbound.call_args[0][0].content == "structured" + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_json_text_and_message_fields(bus: MagicMock) -> None: + ch = _ch(bus, 29906) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29906/", client_id="x") as c: + await c.recv_ready() + await c.send_json({"text": "via text"}) + await asyncio.sleep(0.1) + assert bus.publish_inbound.call_args[0][0].content == "via text" + await c.send_json({"message": "via message"}) + await asyncio.sleep(0.1) + assert bus.publish_inbound.call_args[0][0].content == "via message" + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_empty_payload_ignored(bus: MagicMock) -> None: + ch = _ch(bus, 29907) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29907/", client_id="e") as c: + await c.recv_ready() + await c.send_text(" ") + await c.send_json({}) + await asyncio.sleep(0.1) + bus.publish_inbound.assert_not_awaited() + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_messages_preserve_order(bus: MagicMock) -> None: + ch = _ch(bus, 29908) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29908/", client_id="o") as c: + await c.recv_ready() + for i in range(5): + await c.send_text(f"msg-{i}") + await asyncio.sleep(0.2) + contents = [call[0][0].content for call in bus.publish_inbound.call_args_list] + assert contents == [f"msg-{i}" for i in range(5)] + finally: + await ch.stop(); await t + + +# -- Outbound messages (server -> client) --------------------------------- + + +@pytest.mark.asyncio +async def test_server_send_message(bus: MagicMock) -> None: + ch = _ch(bus, 29909) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29909/", client_id="r") as c: + ready = await c.recv_ready() + await ch.send(OutboundMessage( + channel="websocket", chat_id=ready.chat_id, content="reply", + )) + msg = await c.recv_message() + assert msg.text == "reply" + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_server_send_with_media_and_reply(bus: MagicMock) -> None: + ch = _ch(bus, 29910) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29910/", client_id="m") as c: + ready = await c.recv_ready() + await ch.send(OutboundMessage( + channel="websocket", chat_id=ready.chat_id, content="img", + media=["/tmp/a.png"], reply_to="m1", + )) + msg = await c.recv_message() + assert msg.text == "img" + assert msg.media == ["/tmp/a.png"] + assert msg.reply_to == "m1" + finally: + await ch.stop(); await t + + +# -- Streaming ------------------------------------------------------------ + + +@pytest.mark.asyncio +async def test_streaming_deltas_and_end(bus: MagicMock) -> None: + ch = _ch(bus, 29911, streaming=True) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29911/", client_id="s") as c: + cid = (await c.recv_ready()).chat_id + for part in ("Hello", " ", "world", "!"): + await ch.send_delta(cid, part, {"_stream_delta": True, "_stream_id": "s1"}) + await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "s1"}) + + msgs = await c.collect_stream() + deltas = [m for m in msgs if m.event == "delta"] + assert "".join(d.text for d in deltas) == "Hello world!" + ends = [m for m in msgs if m.event == "stream_end"] + assert len(ends) == 1 + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_interleaved_streams(bus: MagicMock) -> None: + ch = _ch(bus, 29912, streaming=True) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29912/", client_id="i") as c: + cid = (await c.recv_ready()).chat_id + await ch.send_delta(cid, "A1", {"_stream_delta": True, "_stream_id": "sa"}) + await ch.send_delta(cid, "B1", {"_stream_delta": True, "_stream_id": "sb"}) + await ch.send_delta(cid, "A2", {"_stream_delta": True, "_stream_id": "sa"}) + await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "sa"}) + await ch.send_delta(cid, "B2", {"_stream_delta": True, "_stream_id": "sb"}) + await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "sb"}) + + msgs = await c.recv_n(6) + sa = "".join(m.text for m in msgs if m.event == "delta" and m.stream_id == "sa") + sb = "".join(m.text for m in msgs if m.event == "delta" and m.stream_id == "sb") + assert sa == "A1A2" + assert sb == "B1B2" + finally: + await ch.stop(); await t + + +# -- Multi-client --------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_independent_sessions(bus: MagicMock) -> None: + ch = _ch(bus, 29913) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29913/", client_id="u1") as c1: + async with WsTestClient("ws://127.0.0.1:29913/", client_id="u2") as c2: + r1, r2 = await c1.recv_ready(), await c2.recv_ready() + await ch.send(OutboundMessage( + channel="websocket", chat_id=r1.chat_id, content="for-u1", + )) + assert (await c1.recv_message()).text == "for-u1" + await ch.send(OutboundMessage( + channel="websocket", chat_id=r2.chat_id, content="for-u2", + )) + assert (await c2.recv_message()).text == "for-u2" + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_disconnected_client_cleanup(bus: MagicMock) -> None: + ch = _ch(bus, 29914) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29914/", client_id="tmp") as c: + chat_id = (await c.recv_ready()).chat_id + # disconnected + await ch.send(OutboundMessage( + channel="websocket", chat_id=chat_id, content="orphan", + )) + assert chat_id not in ch._connections + finally: + await ch.stop(); await t + + +# -- Authentication ------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_static_token_accepted(bus: MagicMock) -> None: + ch = _ch(bus, 29915, token="secret") + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29915/", client_id="a", token="secret") as c: + assert (await c.recv_ready()).client_id == "a" + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_static_token_rejected(bus: MagicMock) -> None: + ch = _ch(bus, 29916, token="correct") + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + with pytest.raises(websockets.exceptions.InvalidStatus) as exc: + async with WsTestClient("ws://127.0.0.1:29916/", client_id="b", token="wrong"): + pass + assert exc.value.response.status_code == 401 + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_token_issue_full_flow(bus: MagicMock) -> None: + ch = _ch(bus, 29917, path="/ws", + tokenIssuePath="/auth/token", tokenIssueSecret="s", + websocketRequiresToken=True) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + # no secret -> 401 + _, status = await issue_token(port=29917, issue_path="/auth/token") + assert status == 401 + + # with secret -> token + token = await issue_token_ok(port=29917, issue_path="/auth/token", secret="s") + + # no token -> 401 + with pytest.raises(websockets.exceptions.InvalidStatus) as exc: + async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="x"): + pass + assert exc.value.response.status_code == 401 + + # valid token -> ok + async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="ok", token=token) as c: + assert (await c.recv_ready()).client_id == "ok" + + # reuse -> 401 + with pytest.raises(websockets.exceptions.InvalidStatus) as exc: + async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="r", token=token): + pass + assert exc.value.response.status_code == 401 + finally: + await ch.stop(); await t + + +# -- Path routing --------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_custom_path(bus: MagicMock) -> None: + ch = _ch(bus, 29918, path="/my-chat") + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29918/my-chat", client_id="p") as c: + assert (await c.recv_ready()).event == "ready" + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_wrong_path_404(bus: MagicMock) -> None: + ch = _ch(bus, 29919, path="/ws") + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + with pytest.raises(websockets.exceptions.InvalidStatus) as exc: + async with WsTestClient("ws://127.0.0.1:29919/wrong", client_id="x"): + pass + assert exc.value.response.status_code == 404 + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_trailing_slash_normalized(bus: MagicMock) -> None: + ch = _ch(bus, 29920, path="/ws") + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29920/ws/", client_id="s") as c: + assert (await c.recv_ready()).event == "ready" + finally: + await ch.stop(); await t + + +# -- Edge cases ----------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_large_message(bus: MagicMock) -> None: + ch = _ch(bus, 29921) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29921/", client_id="big") as c: + await c.recv_ready() + big = "x" * 100_000 + await c.send_text(big) + await asyncio.sleep(0.2) + assert bus.publish_inbound.call_args[0][0].content == big + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_unicode_roundtrip(bus: MagicMock) -> None: + ch = _ch(bus, 29922) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29922/", client_id="u") as c: + ready = await c.recv_ready() + text = "你好世界 🌍 日本語テスト" + await c.send_text(text) + await asyncio.sleep(0.1) + assert bus.publish_inbound.call_args[0][0].content == text + await ch.send(OutboundMessage( + channel="websocket", chat_id=ready.chat_id, content=text, + )) + assert (await c.recv_message()).text == text + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_rapid_fire(bus: MagicMock) -> None: + ch = _ch(bus, 29923) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29923/", client_id="r") as c: + ready = await c.recv_ready() + for i in range(50): + await c.send_text(f"in-{i}") + await asyncio.sleep(0.5) + assert bus.publish_inbound.await_count == 50 + for i in range(50): + await ch.send(OutboundMessage( + channel="websocket", chat_id=ready.chat_id, content=f"out-{i}", + )) + received = [(await c.recv_message()).text for _ in range(50)] + assert received == [f"out-{i}" for i in range(50)] + finally: + await ch.stop(); await t + + +@pytest.mark.asyncio +async def test_invalid_json_as_plain_text(bus: MagicMock) -> None: + ch = _ch(bus, 29924) + t = asyncio.create_task(ch.start()) + await asyncio.sleep(0.3) + try: + async with WsTestClient("ws://127.0.0.1:29924/", client_id="j") as c: + await c.recv_ready() + await c.send_text("{broken json") + await asyncio.sleep(0.1) + assert bus.publish_inbound.call_args[0][0].content == "{broken json" + finally: + await ch.stop(); await t diff --git a/tests/channels/ws_test_client.py b/tests/channels/ws_test_client.py new file mode 100644 index 000000000..ec3ba1460 --- /dev/null +++ b/tests/channels/ws_test_client.py @@ -0,0 +1,227 @@ +"""Lightweight WebSocket test client for integration testing the nanobot WebSocket channel. + +Provides an async ``WsTestClient`` class and token-issuance helpers that +integration tests can import and use directly:: + + from ws_test_client import WsTestClient + + async with WsTestClient("ws://127.0.0.1:8765/", client_id="t") as c: + ready = await c.recv_ready() + await c.send_text("hello") + msg = await c.recv_message() +""" + +from __future__ import annotations + +import asyncio +import json +from dataclasses import dataclass, field +from typing import Any + +import httpx +import websockets +from websockets.asyncio.client import ClientConnection + + +@dataclass +class WsMessage: + """A parsed message received from the WebSocket server.""" + + event: str + raw: dict[str, Any] = field(repr=False) + + @property + def text(self) -> str | None: + return self.raw.get("text") + + @property + def chat_id(self) -> str | None: + return self.raw.get("chat_id") + + @property + def client_id(self) -> str | None: + return self.raw.get("client_id") + + @property + def media(self) -> list[str] | None: + return self.raw.get("media") + + @property + def reply_to(self) -> str | None: + return self.raw.get("reply_to") + + @property + def stream_id(self) -> str | None: + return self.raw.get("stream_id") + + def __eq__(self, other: object) -> bool: + if not isinstance(other, WsMessage): + return NotImplemented + return self.event == other.event and self.raw == other.raw + + +class WsTestClient: + """Async WebSocket test client with helper methods for common operations. + + Usage:: + + async with WsTestClient("ws://127.0.0.1:8765/", client_id="tester") as client: + ready = await client.recv_ready() + await client.send_text("hello") + msg = await client.recv_message(timeout=5.0) + """ + + def __init__( + self, + uri: str, + *, + client_id: str = "test-client", + token: str = "", + extra_headers: dict[str, str] | None = None, + ) -> None: + params: list[str] = [] + if client_id: + params.append(f"client_id={client_id}") + if token: + params.append(f"token={token}") + sep = "&" if "?" in uri else "?" + self._uri = uri + sep + "&".join(params) if params else uri + self._extra_headers = extra_headers + self._ws: ClientConnection | None = None + + async def connect(self) -> None: + self._ws = await websockets.connect( + self._uri, + additional_headers=self._extra_headers, + ) + + async def close(self) -> None: + if self._ws: + await self._ws.close() + self._ws = None + + async def __aenter__(self) -> WsTestClient: + await self.connect() + return self + + async def __aexit__(self, *args: Any) -> None: + await self.close() + + @property + def ws(self) -> ClientConnection: + assert self._ws is not None, "Client is not connected" + return self._ws + + # -- Receiving -------------------------------------------------------- + + async def recv_raw(self, timeout: float = 10.0) -> dict[str, Any]: + """Receive and parse one raw JSON message with timeout.""" + raw = await asyncio.wait_for(self.ws.recv(), timeout=timeout) + return json.loads(raw) + + async def recv(self, timeout: float = 10.0) -> WsMessage: + """Receive one message, returning a WsMessage wrapper.""" + data = await self.recv_raw(timeout) + return WsMessage(event=data.get("event", ""), raw=data) + + async def recv_ready(self, timeout: float = 5.0) -> WsMessage: + """Receive and validate the 'ready' event.""" + msg = await self.recv(timeout) + assert msg.event == "ready", f"Expected 'ready' event, got '{msg.event}'" + return msg + + async def recv_message(self, timeout: float = 10.0) -> WsMessage: + """Receive and validate a 'message' event.""" + msg = await self.recv(timeout) + assert msg.event == "message", f"Expected 'message' event, got '{msg.event}'" + return msg + + async def recv_delta(self, timeout: float = 10.0) -> WsMessage: + """Receive and validate a 'delta' event.""" + msg = await self.recv(timeout) + assert msg.event == "delta", f"Expected 'delta' event, got '{msg.event}'" + return msg + + async def recv_stream_end(self, timeout: float = 10.0) -> WsMessage: + """Receive and validate a 'stream_end' event.""" + msg = await self.recv(timeout) + assert msg.event == "stream_end", f"Expected 'stream_end' event, got '{msg.event}'" + return msg + + async def collect_stream(self, timeout: float = 10.0) -> list[WsMessage]: + """Collect all deltas and the final stream_end into a list.""" + messages: list[WsMessage] = [] + while True: + msg = await self.recv(timeout) + messages.append(msg) + if msg.event == "stream_end": + break + return messages + + async def recv_n(self, n: int, timeout: float = 10.0) -> list[WsMessage]: + """Receive exactly *n* messages.""" + return [await self.recv(timeout) for _ in range(n)] + + # -- Sending ---------------------------------------------------------- + + async def send_text(self, text: str) -> None: + """Send a plain text frame.""" + await self.ws.send(text) + + async def send_json(self, data: dict[str, Any]) -> None: + """Send a JSON frame.""" + await self.ws.send(json.dumps(data, ensure_ascii=False)) + + async def send_content(self, content: str) -> None: + """Send content in the preferred JSON format ``{"content": ...}``.""" + await self.send_json({"content": content}) + + # -- Connection introspection ----------------------------------------- + + @property + def closed(self) -> bool: + return self._ws is None or self._ws.closed + + +# -- Token issuance helpers ----------------------------------------------- + + +async def issue_token( + host: str = "127.0.0.1", + port: int = 8765, + issue_path: str = "/auth/token", + secret: str = "", +) -> tuple[dict[str, Any] | None, int]: + """Request a short-lived token from the token-issue HTTP endpoint. + + Returns ``(parsed_json_or_None, status_code)``. + """ + url = f"http://{host}:{port}{issue_path}" + headers: dict[str, str] = {} + if secret: + headers["Authorization"] = f"Bearer {secret}" + + loop = asyncio.get_running_loop() + resp = await loop.run_in_executor( + None, lambda: httpx.get(url, headers=headers, timeout=5.0) + ) + try: + data = resp.json() + except Exception: + data = None + return data, resp.status_code + + +async def issue_token_ok( + host: str = "127.0.0.1", + port: int = 8765, + issue_path: str = "/auth/token", + secret: str = "", +) -> str: + """Request a token, asserting success, and return the token string.""" + (data, status) = await issue_token(host, port, issue_path, secret) + assert status == 200, f"Token issue failed with status {status}" + assert data is not None + token = data["token"] + assert token.startswith("nbwt_"), f"Unexpected token format: {token}" + return token