From e00dca2f84efe4a06d332a8449e454b1da730e9e Mon Sep 17 00:00:00 2001 From: Jack Lu <46274946+JackLuguibin@users.noreply.github.com> Date: Wed, 8 Apr 2026 21:39:35 +0800 Subject: [PATCH] feat(channels): add WebSocket server channel and tests Port Python implementation from a1ec7b192ad97ffd58250a720891ff09bbb73888 (websocket channel module and channel tests; excludes webui debug app). --- nanobot/channels/websocket.py | 418 +++++++++++++++++++++++ tests/channels/test_websocket_channel.py | 329 ++++++++++++++++++ 2 files changed, 747 insertions(+) create mode 100644 nanobot/channels/websocket.py create mode 100644 tests/channels/test_websocket_channel.py diff --git a/nanobot/channels/websocket.py b/nanobot/channels/websocket.py new file mode 100644 index 000000000..e09e6303e --- /dev/null +++ b/nanobot/channels/websocket.py @@ -0,0 +1,418 @@ +"""WebSocket server channel: nanobot acts as a WebSocket server and serves connected clients.""" + +from __future__ import annotations + +import asyncio +import email.utils +import hmac +import http +import json +import secrets +import ssl +import time +import uuid +from typing import Any, Self +from urllib.parse import parse_qs, urlparse + +from loguru import logger +from pydantic import Field, field_validator, model_validator +from websockets.asyncio.server import ServerConnection, serve +from websockets.datastructures import Headers +from websockets.http11 import Request as WsRequest, Response + +from nanobot.bus.events import OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.base import BaseChannel +from nanobot.config.schema import Base + + +def _strip_trailing_slash(path: str) -> str: + if len(path) > 1 and path.endswith("/"): + return path.rstrip("/") + return path or "/" + + +def _normalize_config_path(path: str) -> str: + return _strip_trailing_slash(path) + + +class WebSocketConfig(Base): + """WebSocket server channel configuration. 
+ + Clients connect with URLs like ``ws://{host}:{port}{path}?client_id=...&token=...``. + - ``client_id``: Used for ``allow_from`` authorization; if omitted, a value is generated and logged. + - ``token``: If non-empty, the ``token`` query param may match this static secret; short-lived tokens + from ``token_issue_path`` are also accepted. + - ``token_issue_path``: If non-empty, **GET** (HTTP/1.1) to this path returns JSON + ``{"token": "...", "expires_in": <seconds>}``; use ``?token=...`` when opening the WebSocket. + Must differ from ``path`` (the WS upgrade path). If the client runs in the **same process** as + nanobot and shares the asyncio loop, use a thread or async HTTP client for GET—do not call + blocking ``urllib`` or synchronous ``httpx`` from inside a coroutine. + - ``token_issue_secret``: If non-empty, token requests must send ``Authorization: Bearer <secret>`` or + ``X-Nanobot-Auth: <secret>``. + - ``websocket_requires_token``: If True, the handshake must include a valid token (static or issued and not expired). + - Each connection has its own session: a unique ``chat_id`` maps to the agent session internally. 
+ """ + + enabled: bool = False + host: str = "127.0.0.1" + port: int = 8765 + path: str = "/" + token: str = "" + token_issue_path: str = "" + token_issue_secret: str = "" + token_ttl_s: int = Field(default=300, ge=30, le=86_400) + websocket_requires_token: bool = False + allow_from: list[str] = Field(default_factory=lambda: ["*"]) + streaming: bool = True + max_message_bytes: int = Field(default=1_048_576, ge=1024, le=16_777_216) + ping_interval_s: float = Field(default=20.0, ge=5.0, le=300.0) + ping_timeout_s: float = Field(default=20.0, ge=5.0, le=300.0) + ssl_certfile: str = "" + ssl_keyfile: str = "" + + @field_validator("path") + @classmethod + def path_must_start_with_slash(cls, value: str) -> str: + if not value.startswith("/"): + raise ValueError('path must start with "/"') + return value + + @field_validator("token_issue_path") + @classmethod + def token_issue_path_format(cls, value: str) -> str: + value = value.strip() + if not value: + return "" + if not value.startswith("/"): + raise ValueError('token_issue_path must start with "/"') + return _normalize_config_path(value) + + @model_validator(mode="after") + def token_issue_path_differs_from_ws_path(self) -> Self: + if not self.token_issue_path: + return self + if _normalize_config_path(self.token_issue_path) == _normalize_config_path(self.path): + raise ValueError("token_issue_path must differ from path (the WebSocket upgrade path)") + return self + + +def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response: + body = json.dumps(data, ensure_ascii=False).encode("utf-8") + headers = Headers( + [ + ("Date", email.utils.formatdate(usegmt=True)), + ("Connection", "close"), + ("Content-Length", str(len(body))), + ("Content-Type", "application/json; charset=utf-8"), + ] + ) + reason = http.HTTPStatus(status).phrase + return Response(status, reason, headers, body) + + +def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]: + """Parse normalized path and 
query parameters in one pass.""" + parsed = urlparse("ws://x" + path_with_query) + path = _strip_trailing_slash(parsed.path or "/") + return path, parse_qs(parsed.query) + + +def _normalize_http_path(path_with_query: str) -> str: + """Return the path component (no query string), with trailing slash normalized (root stays ``/``).""" + return _parse_request_path(path_with_query)[0] + + +def _parse_query(path_with_query: str) -> dict[str, list[str]]: + return _parse_request_path(path_with_query)[1] + + +def _parse_inbound_payload(raw: str) -> str | None: + """Parse a client frame into text; return None for empty or unrecognized content.""" + text = raw.strip() + if not text: + return None + if text.startswith("{"): + try: + data = json.loads(text) + except json.JSONDecodeError: + return text + if isinstance(data, dict): + for key in ("content", "text", "message"): + value = data.get(key) + if isinstance(value, str) and value.strip(): + return value + return None + return None + return text + + +def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool: + """Return True if the token-issue HTTP request carries credentials matching ``token_issue_secret``.""" + if not configured_secret: + return True + authorization = headers.get("Authorization") or headers.get("authorization") + if authorization and authorization.lower().startswith("bearer "): + supplied = authorization[7:].strip() + return hmac.compare_digest(supplied, configured_secret) + header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth") + if not header_token: + return False + return hmac.compare_digest(header_token.strip(), configured_secret) + + +class WebSocketChannel(BaseChannel): + """Run a local WebSocket server; forward text/JSON messages to the message bus.""" + + name = "websocket" + display_name = "WebSocket" + + def __init__(self, config: Any, bus: MessageBus): + if isinstance(config, dict): + config = WebSocketConfig.model_validate(config) + 
super().__init__(config, bus) + self.config: WebSocketConfig = config + self._connections: dict[str, Any] = {} + self._issued_tokens: dict[str, float] = {} + self._stop_event: asyncio.Event | None = None + self._server_task: asyncio.Task[None] | None = None + + @classmethod + def default_config(cls) -> dict[str, Any]: + return WebSocketConfig().model_dump(by_alias=True) + + def _expected_path(self) -> str: + return _normalize_config_path(self.config.path) + + def _build_ssl_context(self) -> ssl.SSLContext | None: + cert = self.config.ssl_certfile.strip() + key = self.config.ssl_keyfile.strip() + if not cert and not key: + return None + if not cert or not key: + raise ValueError( + "websocket: ssl_certfile and ssl_keyfile must both be set for WSS, or both left empty" + ) + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + ctx.load_cert_chain(certfile=cert, keyfile=key) + return ctx + + def _purge_expired_issued_tokens(self) -> None: + now = time.monotonic() + for token_key, expiry in list(self._issued_tokens.items()): + if now > expiry: + self._issued_tokens.pop(token_key, None) + + def _take_issued_token_if_valid(self, token_value: str | None) -> bool: + """Validate and consume one issued token (single use per connection attempt).""" + if not token_value: + return False + self._purge_expired_issued_tokens() + expiry = self._issued_tokens.get(token_value) + if expiry is None: + return False + if time.monotonic() > expiry: + self._issued_tokens.pop(token_value, None) + return False + self._issued_tokens.pop(token_value, None) + return True + + def _handle_token_issue_http(self, connection: Any, request: Any) -> Any: + secret = self.config.token_issue_secret.strip() + if secret: + if not _issue_route_secret_matches(request.headers, secret): + return connection.respond(401, "Unauthorized") + else: + logger.warning( + "websocket: token_issue_path is set but token_issue_secret is empty; " + "any client can obtain connection tokens — set token_issue_secret for production." 
+ ) + self._purge_expired_issued_tokens() + token_value = f"nbwt_{secrets.token_urlsafe(32)}" + self._issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s) + + return _http_json_response( + {"token": token_value, "expires_in": self.config.token_ttl_s} + ) + + def _authorize_websocket_handshake(self, connection: Any, request_path: str) -> Any: + query = _parse_query(request_path) + supplied = (query.get("token") or [None])[0] + static_token = self.config.token.strip() + + if static_token: + if supplied == static_token: + return None + if supplied and self._take_issued_token_if_valid(supplied): + return None + return connection.respond(401, "Unauthorized") + + if self.config.websocket_requires_token: + if supplied and self._take_issued_token_if_valid(supplied): + return None + return connection.respond(401, "Unauthorized") + + if supplied: + self._take_issued_token_if_valid(supplied) + return None + + async def start(self) -> None: + self._running = True + self._stop_event = asyncio.Event() + + ssl_context = self._build_ssl_context() + scheme = "wss" if ssl_context else "ws" + + async def process_request( + connection: ServerConnection, + request: WsRequest, + ) -> Any: + got, _ = _parse_request_path(request.path) + if self.config.token_issue_path: + issue_expected = _normalize_config_path(self.config.token_issue_path) + if got == issue_expected: + return self._handle_token_issue_http(connection, request) + + expected_ws = self._expected_path() + if got != expected_ws: + return connection.respond(404, "Not Found") + return self._authorize_websocket_handshake(connection, request.path) + + async def handler(connection: ServerConnection) -> None: + await self._connection_loop(connection) + + logger.info( + "WebSocket server listening on {}://{}:{}{}", + scheme, + self.config.host, + self.config.port, + self.config.path, + ) + if self.config.token_issue_path: + logger.info( + "WebSocket token issue route: {}://{}:{}{}", + scheme, + 
self.config.host, + self.config.port, + _normalize_config_path(self.config.token_issue_path), + ) + + async def runner() -> None: + async with serve( + handler, + self.config.host, + self.config.port, + process_request=process_request, + max_size=self.config.max_message_bytes, + ping_interval=self.config.ping_interval_s, + ping_timeout=self.config.ping_timeout_s, + ssl=ssl_context, + ): + assert self._stop_event is not None + await self._stop_event.wait() + + self._server_task = asyncio.create_task(runner()) + await self._server_task + + async def _connection_loop(self, connection: Any) -> None: + request = connection.request + path_part = request.path if request else "/" + _, query = _parse_request_path(path_part) + client_id_raw = (query.get("client_id") or [None])[0] + client_id = client_id_raw.strip() if client_id_raw else "" + if not client_id: + client_id = f"anon-{uuid.uuid4().hex[:12]}" + + chat_id = str(uuid.uuid4()) + self._connections[chat_id] = connection + + await connection.send( + json.dumps( + { + "event": "ready", + "chat_id": chat_id, + "client_id": client_id, + }, + ensure_ascii=False, + ) + ) + + try: + async for raw in connection: + if isinstance(raw, bytes): + try: + raw = raw.decode("utf-8") + except UnicodeDecodeError: + logger.warning("websocket: ignoring non-utf8 binary frame") + continue + content = _parse_inbound_payload(raw) + if content is None: + continue + await self._handle_message( + sender_id=client_id, + chat_id=chat_id, + content=content, + metadata={"remote": getattr(connection, "remote_address", None)}, + ) + except Exception as e: + logger.debug("websocket connection ended: {}", e) + finally: + self._connections.pop(chat_id, None) + + async def stop(self) -> None: + self._running = False + if self._stop_event: + self._stop_event.set() + if self._server_task: + await self._server_task + self._server_task = None + self._connections.clear() + self._issued_tokens.clear() + + async def send(self, msg: OutboundMessage) -> None: + 
connection = self._connections.get(msg.chat_id) + if connection is None: + logger.warning("websocket: no active connection for chat_id={}", msg.chat_id) + return + payload: dict[str, Any] = { + "event": "message", + "text": msg.content, + } + if msg.media: + payload["media"] = msg.media + if msg.reply_to: + payload["reply_to"] = msg.reply_to + raw = json.dumps(payload, ensure_ascii=False) + try: + await connection.send(raw) + except Exception as e: + logger.error("websocket send failed: {}", e) + raise + + async def send_delta( + self, + chat_id: str, + delta: str, + metadata: dict[str, Any] | None = None, + ) -> None: + connection = self._connections.get(chat_id) + if connection is None: + return + meta = metadata or {} + if meta.get("_stream_end"): + body: dict[str, Any] = {"event": "stream_end"} + if meta.get("_stream_id") is not None: + body["stream_id"] = meta["_stream_id"] + else: + body = { + "event": "delta", + "text": delta, + } + if meta.get("_stream_id") is not None: + body["stream_id"] = meta["_stream_id"] + raw = json.dumps(body, ensure_ascii=False) + try: + await connection.send(raw) + except Exception as e: + logger.error("websocket stream send failed: {}", e) + raise diff --git a/tests/channels/test_websocket_channel.py b/tests/channels/test_websocket_channel.py new file mode 100644 index 000000000..e4c5ad635 --- /dev/null +++ b/tests/channels/test_websocket_channel.py @@ -0,0 +1,329 @@ +"""Unit and lightweight integration tests for the WebSocket channel.""" + +import asyncio +import functools +import json +from unittest.mock import AsyncMock, MagicMock + +import httpx +import pytest +import websockets + +from nanobot.bus.events import OutboundMessage +from nanobot.channels.websocket import ( + WebSocketChannel, + WebSocketConfig, + _issue_route_secret_matches, + _normalize_config_path, + _normalize_http_path, + _parse_inbound_payload, + _parse_query, + _parse_request_path, +) + + +async def _http_get(url: str, headers: dict[str, str] | None = None) 
-> httpx.Response: + """Run GET in a thread to avoid blocking the asyncio loop shared with websockets.""" + return await asyncio.to_thread( + functools.partial(httpx.get, url, headers=headers or {}, timeout=5.0) + ) + + +def test_normalize_http_path_strips_trailing_slash_except_root() -> None: + assert _normalize_http_path("/chat/") == "/chat" + assert _normalize_http_path("/chat?x=1") == "/chat" + assert _normalize_http_path("/") == "/" + + +def test_parse_request_path_matches_normalize_and_query() -> None: + path, query = _parse_request_path("/ws/?token=secret&client_id=u1") + assert path == _normalize_http_path("/ws/?token=secret&client_id=u1") + assert query == _parse_query("/ws/?token=secret&client_id=u1") + + +def test_normalize_config_path_matches_request() -> None: + assert _normalize_config_path("/ws/") == "/ws" + assert _normalize_config_path("/") == "/" + + +def test_parse_query_extracts_token_and_client_id() -> None: + query = _parse_query("/?token=secret&client_id=u1") + assert query.get("token") == ["secret"] + assert query.get("client_id") == ["u1"] + + +@pytest.mark.parametrize( + ("raw", "expected"), + [ + ("plain", "plain"), + ('{"content": "hi"}', "hi"), + ('{"text": "there"}', "there"), + ('{"message": "x"}', "x"), + (" ", None), + ("{}", None), + ], +) +def test_parse_inbound_payload(raw: str, expected: str | None) -> None: + assert _parse_inbound_payload(raw) == expected + + +def test_parse_inbound_invalid_json_falls_back_to_raw_string() -> None: + assert _parse_inbound_payload("{not json") == "{not json" + + +def test_web_socket_config_path_must_start_with_slash() -> None: + with pytest.raises(ValueError, match='path must start with "/"'): + WebSocketConfig(path="bad") + + +def test_ssl_context_requires_both_cert_and_key_files() -> None: + bus = MagicMock() + channel = WebSocketChannel( + {"enabled": True, "allowFrom": ["*"], "sslCertfile": "/tmp/c.pem", "sslKeyfile": ""}, + bus, + ) + with pytest.raises(ValueError, match="ssl_certfile and 
ssl_keyfile"): + channel._build_ssl_context() + + +def test_default_config_includes_safe_bind_and_streaming() -> None: + defaults = WebSocketChannel.default_config() + assert defaults["enabled"] is False + assert defaults["host"] == "127.0.0.1" + assert defaults["streaming"] is True + assert defaults["allowFrom"] == ["*"] + assert defaults.get("tokenIssuePath", "") == "" + + +def test_token_issue_path_must_differ_from_websocket_path() -> None: + with pytest.raises(ValueError, match="token_issue_path must differ"): + WebSocketConfig(path="/ws", token_issue_path="/ws") + + +def test_issue_route_secret_matches_bearer_and_header() -> None: + from websockets.datastructures import Headers + + secret = "my-secret" + bearer_headers = Headers([("Authorization", "Bearer my-secret")]) + assert _issue_route_secret_matches(bearer_headers, secret) is True + x_headers = Headers([("X-Nanobot-Auth", "my-secret")]) + assert _issue_route_secret_matches(x_headers, secret) is True + wrong = Headers([("Authorization", "Bearer other")]) + assert _issue_route_secret_matches(wrong, secret) is False + + +@pytest.mark.asyncio +async def test_send_delivers_json_message_with_media_and_reply() -> None: + bus = MagicMock() + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) + mock_ws = AsyncMock() + channel._connections["chat-1"] = mock_ws + + msg = OutboundMessage( + channel="websocket", + chat_id="chat-1", + content="hello", + reply_to="m1", + media=["/tmp/a.png"], + ) + await channel.send(msg) + + mock_ws.send.assert_awaited_once() + payload = json.loads(mock_ws.send.call_args[0][0]) + assert payload["event"] == "message" + assert payload["text"] == "hello" + assert payload["reply_to"] == "m1" + assert payload["media"] == ["/tmp/a.png"] + + +@pytest.mark.asyncio +async def test_send_missing_connection_is_noop_without_error() -> None: + bus = MagicMock() + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) + msg = OutboundMessage(channel="websocket", 
chat_id="missing", content="x") + await channel.send(msg) + + +@pytest.mark.asyncio +async def test_send_delta_emits_delta_and_stream_end() -> None: + bus = MagicMock() + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus) + mock_ws = AsyncMock() + channel._connections["chat-1"] = mock_ws + + await channel.send_delta("chat-1", "part", {"_stream_delta": True, "_stream_id": "sid"}) + await channel.send_delta("chat-1", "", {"_stream_end": True, "_stream_id": "sid"}) + + assert mock_ws.send.await_count == 2 + first = json.loads(mock_ws.send.call_args_list[0][0][0]) + second = json.loads(mock_ws.send.call_args_list[1][0][0]) + assert first["event"] == "delta" + assert first["text"] == "part" + assert first["stream_id"] == "sid" + assert second["event"] == "stream_end" + assert second["stream_id"] == "sid" + + +@pytest.mark.asyncio +async def test_end_to_end_client_receives_ready_and_agent_sees_inbound() -> None: + bus = MagicMock() + bus.publish_inbound = AsyncMock() + port = 29876 + channel = WebSocketChannel( + { + "enabled": True, + "allowFrom": ["*"], + "host": "127.0.0.1", + "port": port, + "path": "/ws", + }, + bus, + ) + + server_task = asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=tester") as client: + ready_raw = await client.recv() + ready = json.loads(ready_raw) + assert ready["event"] == "ready" + assert ready["client_id"] == "tester" + chat_id = ready["chat_id"] + + await client.send(json.dumps({"content": "ping from client"})) + await asyncio.sleep(0.08) + + bus.publish_inbound.assert_awaited() + inbound = bus.publish_inbound.call_args[0][0] + assert inbound.channel == "websocket" + assert inbound.sender_id == "tester" + assert inbound.chat_id == chat_id + assert inbound.content == "ping from client" + + await client.send("plain text frame") + await asyncio.sleep(0.08) + assert bus.publish_inbound.await_count >= 2 + second = 
[c[0][0] for c in bus.publish_inbound.call_args_list][-1] + assert second.content == "plain text frame" + finally: + await channel.stop() + await server_task + + +@pytest.mark.asyncio +async def test_token_rejects_handshake_when_mismatch() -> None: + bus = MagicMock() + port = 29877 + channel = WebSocketChannel( + { + "enabled": True, + "allowFrom": ["*"], + "host": "127.0.0.1", + "port": port, + "path": "/", + "token": "secret", + }, + bus, + ) + + server_task = asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + with pytest.raises(websockets.exceptions.InvalidStatus) as excinfo: + async with websockets.connect(f"ws://127.0.0.1:{port}/?token=wrong"): + pass + assert excinfo.value.response.status_code == 401 + finally: + await channel.stop() + await server_task + + +@pytest.mark.asyncio +async def test_wrong_path_returns_404() -> None: + bus = MagicMock() + port = 29878 + channel = WebSocketChannel( + { + "enabled": True, + "allowFrom": ["*"], + "host": "127.0.0.1", + "port": port, + "path": "/ws", + }, + bus, + ) + + server_task = asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + with pytest.raises(websockets.exceptions.InvalidStatus) as excinfo: + async with websockets.connect(f"ws://127.0.0.1:{port}/other"): + pass + assert excinfo.value.response.status_code == 404 + finally: + await channel.stop() + await server_task + + +def test_registry_discovers_websocket_channel() -> None: + from nanobot.channels.registry import load_channel_class + + cls = load_channel_class("websocket") + assert cls.name == "websocket" + + +@pytest.mark.asyncio +async def test_http_route_issues_token_then_websocket_requires_it() -> None: + bus = MagicMock() + bus.publish_inbound = AsyncMock() + port = 29879 + channel = WebSocketChannel( + { + "enabled": True, + "allowFrom": ["*"], + "host": "127.0.0.1", + "port": port, + "path": "/ws", + "tokenIssuePath": "/auth/token", + "tokenIssueSecret": "route-secret", + "websocketRequiresToken": True, 
+ }, + bus, + ) + + server_task = asyncio.create_task(channel.start()) + await asyncio.sleep(0.3) + + try: + deny = await _http_get(f"http://127.0.0.1:{port}/auth/token") + assert deny.status_code == 401 + + issue = await _http_get( + f"http://127.0.0.1:{port}/auth/token", + headers={"Authorization": "Bearer route-secret"}, + ) + assert issue.status_code == 200 + token = issue.json()["token"] + assert token.startswith("nbwt_") + + with pytest.raises(websockets.exceptions.InvalidStatus) as missing_token: + async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=x"): + pass + assert missing_token.value.response.status_code == 401 + + uri = f"ws://127.0.0.1:{port}/ws?token={token}&client_id=caller" + async with websockets.connect(uri) as client: + ready = json.loads(await client.recv()) + assert ready["event"] == "ready" + assert ready["client_id"] == "caller" + + with pytest.raises(websockets.exceptions.InvalidStatus) as reuse: + async with websockets.connect(uri): + pass + assert reuse.value.response.status_code == 401 + finally: + await channel.stop() + await server_task