feat(channels): add WebSocket server channel and tests

Port Python implementation from a1ec7b192ad97ffd58250a720891ff09bbb73888
(websocket channel module and channel tests; excludes webui debug app).
This commit is contained in:
Jack Lu 2026-04-08 21:39:35 +08:00 committed by chengyongru
parent 51200a954c
commit e00dca2f84
2 changed files with 747 additions and 0 deletions

View File

@ -0,0 +1,418 @@
"""WebSocket server channel: nanobot acts as a WebSocket server and serves connected clients."""
from __future__ import annotations
import asyncio
import email.utils
import hmac
import http
import json
import secrets
import ssl
import time
import uuid
from typing import Any, Self
from urllib.parse import parse_qs, urlparse
from loguru import logger
from pydantic import Field, field_validator, model_validator
from websockets.asyncio.server import ServerConnection, serve
from websockets.datastructures import Headers
from websockets.http11 import Request as WsRequest, Response
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.schema import Base
def _strip_trailing_slash(path: str) -> str:
if len(path) > 1 and path.endswith("/"):
return path.rstrip("/")
return path or "/"
def _normalize_config_path(path: str) -> str:
    """Normalize a configured route path by delegating to ``_strip_trailing_slash``."""
    normalized = _strip_trailing_slash(path)
    return normalized
class WebSocketConfig(Base):
    """Settings for the WebSocket server channel.

    Clients connect with URLs like ``ws://{host}:{port}{path}?client_id=...&token=...``.

    - ``client_id``: Used for ``allow_from`` authorization; if omitted, the server generates an
      ``anon-...`` id and reports it in the ``ready`` event.
    - ``token``: If non-empty, the ``token`` query param may match this static secret; short-lived
      tokens from ``token_issue_path`` are also accepted.
    - ``token_issue_path``: If non-empty, **GET** (HTTP/1.1) to this path returns JSON
      ``{"token": "...", "expires_in": <seconds>}``; use ``?token=...`` when opening the WebSocket.
      Must differ from ``path`` (the WS upgrade path). If the client runs in the **same process** as
      nanobot and shares the asyncio loop, use a thread or an async HTTP client for the GET; do not
      call blocking ``urllib`` or synchronous ``httpx`` from inside a coroutine.
    - ``token_issue_secret``: If non-empty, token requests must send
      ``Authorization: Bearer <secret>`` or ``X-Nanobot-Auth: <secret>``.
    - ``websocket_requires_token``: If True, the handshake must include a valid token
      (static, or issued and not expired).
    - Each connection has its own session: a unique ``chat_id`` maps to the agent session internally.
    """

    enabled: bool = False
    host: str = "127.0.0.1"
    port: int = 8765
    path: str = "/"
    token: str = ""
    token_issue_path: str = ""
    token_issue_secret: str = ""
    token_ttl_s: int = Field(default=300, ge=30, le=86_400)
    websocket_requires_token: bool = False
    allow_from: list[str] = Field(default_factory=lambda: ["*"])
    streaming: bool = True
    max_message_bytes: int = Field(default=1_048_576, ge=1024, le=16_777_216)
    ping_interval_s: float = Field(default=20.0, ge=5.0, le=300.0)
    ping_timeout_s: float = Field(default=20.0, ge=5.0, le=300.0)
    ssl_certfile: str = ""
    ssl_keyfile: str = ""

    @field_validator("path")
    @classmethod
    def path_must_start_with_slash(cls, value: str) -> str:
        """Accept only WebSocket paths that begin with ``/``; the value is returned unchanged."""
        if value.startswith("/"):
            return value
        raise ValueError('path must start with "/"')

    @field_validator("token_issue_path")
    @classmethod
    def token_issue_path_format(cls, value: str) -> str:
        """Trim and normalize the token-issue path; blank input disables the route."""
        trimmed = value.strip()
        if trimmed == "":
            return ""
        if trimmed.startswith("/"):
            return _normalize_config_path(trimmed)
        raise ValueError('token_issue_path must start with "/"')

    @model_validator(mode="after")
    def token_issue_path_differs_from_ws_path(self) -> Self:
        """Reject configs whose token-issue route equals the WebSocket upgrade path."""
        issue_path = self.token_issue_path
        if issue_path and _normalize_config_path(issue_path) == _normalize_config_path(self.path):
            raise ValueError("token_issue_path must differ from path (the WebSocket upgrade path)")
        return self
def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
    """Build an HTTP ``Response`` whose body is ``data`` serialized as UTF-8 JSON.

    Args:
        data: Mapping to serialize (non-ASCII characters are emitted unescaped).
        status: HTTP status code; the reason phrase is looked up via ``http.HTTPStatus``.

    Returns:
        A ``Response`` carrying ``Date``, ``Connection: close``, ``Content-Length`` and
        ``Content-Type`` headers plus the encoded body.
    """
    payload = json.dumps(data, ensure_ascii=False).encode("utf-8")
    header_pairs = [
        ("Date", email.utils.formatdate(usegmt=True)),
        ("Connection", "close"),
        ("Content-Length", str(len(payload))),
        ("Content-Type", "application/json; charset=utf-8"),
    ]
    return Response(status, http.HTTPStatus(status).phrase, Headers(header_pairs), payload)
def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]:
    """Split a request target into ``(normalized_path, query_parameters)``.

    The path has its trailing slash normalized via ``_strip_trailing_slash``; the query
    string is decoded with ``urllib.parse.parse_qs``.
    """
    # A placeholder scheme and host let ``urlparse`` treat the target as a complete URL.
    url = urlparse(f"ws://x{path_with_query}")
    raw_path = url.path if url.path else "/"
    return _strip_trailing_slash(raw_path), parse_qs(url.query)
def _normalize_http_path(path_with_query: str) -> str:
    """Return only the normalized path component of a request target (query string dropped)."""
    path, _query = _parse_request_path(path_with_query)
    return path
def _parse_query(path_with_query: str) -> dict[str, list[str]]:
    """Return only the decoded query parameters of a request target."""
    _path, query = _parse_request_path(path_with_query)
    return query
def _parse_inbound_payload(raw: str) -> str | None:
"""Parse a client frame into text; return None for empty or unrecognized content."""
text = raw.strip()
if not text:
return None
if text.startswith("{"):
try:
data = json.loads(text)
except json.JSONDecodeError:
return text
if isinstance(data, dict):
for key in ("content", "text", "message"):
value = data.get(key)
if isinstance(value, str) and value.strip():
return value
return None
return None
return text
def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool:
"""Return True if the token-issue HTTP request carries credentials matching ``token_issue_secret``."""
if not configured_secret:
return True
authorization = headers.get("Authorization") or headers.get("authorization")
if authorization and authorization.lower().startswith("bearer "):
supplied = authorization[7:].strip()
return hmac.compare_digest(supplied, configured_secret)
header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth")
if not header_token:
return False
return hmac.compare_digest(header_token.strip(), configured_secret)
class WebSocketChannel(BaseChannel):
    """Run a local WebSocket server; forward text/JSON messages to the message bus.

    Every accepted connection is registered under a freshly generated ``chat_id``.
    Outbound messages and stream deltas are routed to the connection registered for
    the ``chat_id`` they carry.
    """

    name = "websocket"
    display_name = "WebSocket"

    def __init__(self, config: Any, bus: MessageBus):
        """Create the channel.

        Args:
            config: A ``WebSocketConfig``, or a ``dict`` that is validated into one.
            bus: Message bus passed through to ``BaseChannel``.
        """
        if isinstance(config, dict):
            config = WebSocketConfig.model_validate(config)
        super().__init__(config, bus)
        self.config: WebSocketConfig = config
        # chat_id -> live connection object.
        self._connections: dict[str, Any] = {}
        # issued token -> ``time.monotonic()`` deadline after which it is rejected.
        self._issued_tokens: dict[str, float] = {}
        self._stop_event: asyncio.Event | None = None
        self._server_task: asyncio.Task[None] | None = None

    @classmethod
    def default_config(cls) -> dict[str, Any]:
        """Return the default ``WebSocketConfig`` dumped by alias."""
        return WebSocketConfig().model_dump(by_alias=True)

    def _expected_path(self) -> str:
        """Return the normalized WebSocket upgrade path from the config."""
        return _normalize_config_path(self.config.path)

    def _build_ssl_context(self) -> ssl.SSLContext | None:
        """Build the server TLS context, or return ``None`` when TLS is not configured.

        Raises:
            ValueError: If only one of ``ssl_certfile`` / ``ssl_keyfile`` is set.
        """
        cert = self.config.ssl_certfile.strip()
        key = self.config.ssl_keyfile.strip()
        if not cert and not key:
            return None
        if not cert or not key:
            raise ValueError(
                "websocket: ssl_certfile and ssl_keyfile must both be set for WSS, or both left empty"
            )
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ctx.load_cert_chain(certfile=cert, keyfile=key)
        return ctx

    def _purge_expired_issued_tokens(self) -> None:
        """Drop every issued token whose deadline has passed."""
        now = time.monotonic()
        expired = [token for token, expiry in self._issued_tokens.items() if now > expiry]
        for token in expired:
            self._issued_tokens.pop(token, None)

    def _take_issued_token_if_valid(self, token_value: str | None) -> bool:
        """Validate and consume one issued token (single use per connection attempt).

        Returns:
            ``True`` if the token was known and unexpired. The token is removed from
            the store in either case, so it cannot be reused.
        """
        if not token_value:
            return False
        self._purge_expired_issued_tokens()
        expiry = self._issued_tokens.pop(token_value, None)
        return expiry is not None and time.monotonic() <= expiry

    def _handle_token_issue_http(self, connection: Any, request: Any) -> Any:
        """Serve the token-issue route: check credentials, then mint a short-lived token.

        Returns:
            A 401 response when ``token_issue_secret`` is set and the request does not
            match it; otherwise a JSON response ``{"token": ..., "expires_in": ...}``.
        """
        secret = self.config.token_issue_secret.strip()
        if secret:
            if not _issue_route_secret_matches(request.headers, secret):
                return connection.respond(401, "Unauthorized")
        else:
            logger.warning(
                "websocket: token_issue_path is set but token_issue_secret is empty; "
                "any client can obtain connection tokens — set token_issue_secret for production."
            )
        self._purge_expired_issued_tokens()
        token_value = f"nbwt_{secrets.token_urlsafe(32)}"
        self._issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s)
        return _http_json_response(
            {"token": token_value, "expires_in": self.config.token_ttl_s}
        )

    def _authorize_websocket_handshake(self, connection: Any, request_path: str) -> Any:
        """Decide whether a WebSocket handshake may proceed.

        Returns:
            ``None`` to accept the handshake, or a 401 response to reject it.
        """
        supplied = (_parse_query(request_path).get("token") or [None])[0]
        static_token = self.config.token.strip()
        if static_token:
            # Constant-time comparison on bytes: ``==`` would leak timing information,
            # and ``compare_digest`` rejects non-ASCII ``str`` arguments.
            if supplied is not None and hmac.compare_digest(
                supplied.encode("utf-8", "surrogatepass"),
                static_token.encode("utf-8", "surrogatepass"),
            ):
                return None
            if supplied and self._take_issued_token_if_valid(supplied):
                return None
            return connection.respond(401, "Unauthorized")
        if self.config.websocket_requires_token:
            if supplied and self._take_issued_token_if_valid(supplied):
                return None
            return connection.respond(401, "Unauthorized")
        if supplied:
            # Tokens are optional here, but a presented issued token is still consumed.
            self._take_issued_token_if_valid(supplied)
        return None

    async def start(self) -> None:
        """Start the server and block until ``stop()`` is called.

        Raises:
            ValueError: If the TLS configuration is incomplete (see ``_build_ssl_context``).
        """
        self._running = True
        stop_event = asyncio.Event()
        self._stop_event = stop_event
        ssl_context = self._build_ssl_context()
        scheme = "wss" if ssl_context else "ws"

        async def process_request(
            connection: ServerConnection,
            request: WsRequest,
        ) -> Any:
            # Route by normalized path: token-issue route, then the WebSocket upgrade path.
            got, _ = _parse_request_path(request.path)
            issue_path = self.config.token_issue_path
            if issue_path and got == _normalize_config_path(issue_path):
                return self._handle_token_issue_http(connection, request)
            if got != self._expected_path():
                return connection.respond(404, "Not Found")
            return self._authorize_websocket_handshake(connection, request.path)

        async def handler(connection: ServerConnection) -> None:
            await self._connection_loop(connection)

        logger.info(
            "WebSocket server listening on {}://{}:{}{}",
            scheme,
            self.config.host,
            self.config.port,
            self.config.path,
        )
        if self.config.token_issue_path:
            logger.info(
                "WebSocket token issue route: {}://{}:{}{}",
                scheme,
                self.config.host,
                self.config.port,
                _normalize_config_path(self.config.token_issue_path),
            )

        async def runner() -> None:
            async with serve(
                handler,
                self.config.host,
                self.config.port,
                process_request=process_request,
                max_size=self.config.max_message_bytes,
                ping_interval=self.config.ping_interval_s,
                ping_timeout=self.config.ping_timeout_s,
                ssl=ssl_context,
            ):
                # Wait on the event captured above rather than asserting on the
                # attribute (``assert`` is stripped under ``-O``).
                await stop_event.wait()

        self._server_task = asyncio.create_task(runner())
        await self._server_task

    async def _connection_loop(self, connection: Any) -> None:
        """Serve one client: announce readiness, then forward each frame to ``_handle_message``.

        ``_handle_message`` is not defined in this class (presumably inherited from
        ``BaseChannel``). The connection is always removed from the registry on exit.
        """
        request = connection.request
        path_part = request.path if request else "/"
        _, query = _parse_request_path(path_part)
        client_id_raw = (query.get("client_id") or [None])[0]
        client_id = client_id_raw.strip() if client_id_raw else ""
        if not client_id:
            client_id = f"anon-{uuid.uuid4().hex[:12]}"
        chat_id = str(uuid.uuid4())
        ready_frame = json.dumps(
            {
                "event": "ready",
                "chat_id": chat_id,
                "client_id": client_id,
            },
            ensure_ascii=False,
        )
        self._connections[chat_id] = connection
        try:
            # Sent inside the ``finally`` scope so a failed send cannot leave a stale
            # registry entry behind; the error itself still propagates as before.
            await connection.send(ready_frame)
            try:
                async for raw in connection:
                    if isinstance(raw, bytes):
                        try:
                            raw = raw.decode("utf-8")
                        except UnicodeDecodeError:
                            logger.warning("websocket: ignoring non-utf8 binary frame")
                            continue
                    content = _parse_inbound_payload(raw)
                    if content is None:
                        continue
                    await self._handle_message(
                        sender_id=client_id,
                        chat_id=chat_id,
                        content=content,
                        metadata={"remote": getattr(connection, "remote_address", None)},
                    )
            except Exception as e:
                logger.debug("websocket connection ended: {}", e)
        finally:
            self._connections.pop(chat_id, None)

    async def stop(self) -> None:
        """Signal the server to shut down, wait for it, and clear channel state."""
        self._running = False
        if self._stop_event:
            self._stop_event.set()
        try:
            if self._server_task:
                await self._server_task
        finally:
            # Reset state even when the server task finished with an error.
            self._server_task = None
            self._connections.clear()
            self._issued_tokens.clear()

    async def send(self, msg: OutboundMessage) -> None:
        """Send an outbound message as a JSON ``message`` event to its connection.

        A missing connection is logged and ignored; a failed send is logged and re-raised.
        """
        connection = self._connections.get(msg.chat_id)
        if connection is None:
            logger.warning("websocket: no active connection for chat_id={}", msg.chat_id)
            return
        payload: dict[str, Any] = {
            "event": "message",
            "text": msg.content,
        }
        if msg.media:
            payload["media"] = msg.media
        if msg.reply_to:
            payload["reply_to"] = msg.reply_to
        raw = json.dumps(payload, ensure_ascii=False)
        try:
            await connection.send(raw)
        except Exception as e:
            logger.error("websocket send failed: {}", e)
            raise

    async def send_delta(
        self,
        chat_id: str,
        delta: str,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Send a streaming ``delta`` event, or ``stream_end`` when ``_stream_end`` is set.

        ``_stream_id`` from ``metadata`` is forwarded as ``stream_id`` when present.
        A missing connection is silently ignored; a failed send is logged and re-raised.
        """
        connection = self._connections.get(chat_id)
        if connection is None:
            return
        meta = metadata or {}
        body: dict[str, Any]
        if meta.get("_stream_end"):
            body = {"event": "stream_end"}
        else:
            body = {"event": "delta", "text": delta}
        stream_id = meta.get("_stream_id")
        if stream_id is not None:
            body["stream_id"] = stream_id
        raw = json.dumps(body, ensure_ascii=False)
        try:
            await connection.send(raw)
        except Exception as e:
            logger.error("websocket stream send failed: {}", e)
            raise

View File

@ -0,0 +1,329 @@
"""Unit and lightweight integration tests for the WebSocket channel."""
import asyncio
import functools
import json
from unittest.mock import AsyncMock, MagicMock
import httpx
import pytest
import websockets
from nanobot.bus.events import OutboundMessage
from nanobot.channels.websocket import (
WebSocketChannel,
WebSocketConfig,
_issue_route_secret_matches,
_normalize_config_path,
_normalize_http_path,
_parse_inbound_payload,
_parse_query,
_parse_request_path,
)
async def _http_get(url: str, headers: dict[str, str] | None = None) -> httpx.Response:
    """Issue a blocking ``httpx.get`` in a worker thread so the event loop stays free.

    The server under test shares this asyncio loop, so the request must not block it.
    """
    request_headers = headers or {}
    blocking_get = functools.partial(httpx.get, url, headers=request_headers, timeout=5.0)
    return await asyncio.to_thread(blocking_get)
def test_normalize_http_path_strips_trailing_slash_except_root() -> None:
    """Trailing slash and query string are removed, while the root path stays ``/``."""
    expectations = {
        "/chat/": "/chat",
        "/chat?x=1": "/chat",
        "/": "/",
    }
    for target, normalized in expectations.items():
        assert _normalize_http_path(target) == normalized
def test_parse_request_path_matches_normalize_and_query() -> None:
    """The combined parser agrees with the two single-purpose helpers."""
    target = "/ws/?token=secret&client_id=u1"
    parsed = _parse_request_path(target)
    assert parsed[0] == _normalize_http_path(target)
    assert parsed[1] == _parse_query(target)
def test_normalize_config_path_matches_request() -> None:
    """Configured paths lose their trailing slash; the root path is preserved."""
    for configured, normalized in (("/ws/", "/ws"), ("/", "/")):
        assert _normalize_config_path(configured) == normalized
def test_parse_query_extracts_token_and_client_id() -> None:
    """Both query parameters are returned as single-element lists."""
    params = _parse_query("/?token=secret&client_id=u1")
    token_values = params.get("token")
    client_values = params.get("client_id")
    assert token_values == ["secret"]
    assert client_values == ["u1"]
@pytest.mark.parametrize(
    ("raw", "expected"),
    [
        pytest.param("plain", "plain"),
        pytest.param('{"content": "hi"}', "hi"),
        pytest.param('{"text": "there"}', "there"),
        pytest.param('{"message": "x"}', "x"),
        pytest.param(" ", None),
        pytest.param("{}", None),
    ],
)
def test_parse_inbound_payload(raw: str, expected: str | None) -> None:
    """Plain text, each recognized JSON key, blank input and empty JSON objects."""
    parsed = _parse_inbound_payload(raw)
    assert parsed == expected
def test_parse_inbound_invalid_json_falls_back_to_raw_string() -> None:
    """Text that starts with ``{`` but is not valid JSON is returned as plain text."""
    malformed = "{not json"
    assert _parse_inbound_payload(malformed) == malformed
def test_web_socket_config_path_must_start_with_slash() -> None:
    """A ``path`` without a leading slash is rejected by the config validator."""
    expected_message = 'path must start with "/"'
    with pytest.raises(ValueError, match=expected_message):
        WebSocketConfig(path="bad")
def test_ssl_context_requires_both_cert_and_key_files() -> None:
    """Setting a certificate file without a key file makes ``_build_ssl_context`` raise."""
    settings = {
        "enabled": True,
        "allowFrom": ["*"],
        "sslCertfile": "/tmp/c.pem",
        "sslKeyfile": "",
    }
    channel = WebSocketChannel(settings, MagicMock())
    with pytest.raises(ValueError, match="ssl_certfile and ssl_keyfile"):
        channel._build_ssl_context()
def test_default_config_includes_safe_bind_and_streaming() -> None:
    """Defaults: disabled, loopback host, streaming on, wildcard allow-list, no issue route."""
    config = WebSocketChannel.default_config()
    assert config["enabled"] is False
    assert config["host"] == "127.0.0.1"
    assert config["streaming"] is True
    assert config["allowFrom"] == ["*"]
    issue_path = config.get("tokenIssuePath", "")
    assert issue_path == ""
def test_token_issue_path_must_differ_from_websocket_path() -> None:
    """Using the same route for token issuing and the WebSocket upgrade is rejected."""
    shared_route = "/ws"
    with pytest.raises(ValueError, match="token_issue_path must differ"):
        WebSocketConfig(path=shared_route, token_issue_path=shared_route)
def test_issue_route_secret_matches_bearer_and_header() -> None:
    """Both supported headers are accepted; a wrong bearer value is refused."""
    from websockets.datastructures import Headers

    secret = "my-secret"
    scenarios = [
        (Headers([("Authorization", "Bearer my-secret")]), True),
        (Headers([("X-Nanobot-Auth", "my-secret")]), True),
        (Headers([("Authorization", "Bearer other")]), False),
    ]
    for request_headers, accepted in scenarios:
        assert _issue_route_secret_matches(request_headers, secret) is accepted
@pytest.mark.asyncio
async def test_send_delivers_json_message_with_media_and_reply() -> None:
    """``send`` writes one JSON ``message`` event carrying text, media and reply_to."""
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())
    fake_connection = AsyncMock()
    channel._connections["chat-1"] = fake_connection

    outbound = OutboundMessage(
        channel="websocket",
        chat_id="chat-1",
        content="hello",
        reply_to="m1",
        media=["/tmp/a.png"],
    )
    await channel.send(outbound)

    fake_connection.send.assert_awaited_once()
    (sent_raw,) = fake_connection.send.call_args[0][:1]
    sent = json.loads(sent_raw)
    assert sent["event"] == "message"
    assert sent["text"] == "hello"
    assert sent["reply_to"] == "m1"
    assert sent["media"] == ["/tmp/a.png"]
@pytest.mark.asyncio
async def test_send_missing_connection_is_noop_without_error() -> None:
    """``send`` for an unregistered ``chat_id`` completes without raising."""
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())
    orphan = OutboundMessage(channel="websocket", chat_id="missing", content="x")
    await channel.send(orphan)
@pytest.mark.asyncio
async def test_send_delta_emits_delta_and_stream_end() -> None:
    """A delta followed by a stream end yields two events sharing one ``stream_id``."""
    settings = {"enabled": True, "allowFrom": ["*"], "streaming": True}
    channel = WebSocketChannel(settings, MagicMock())
    fake_connection = AsyncMock()
    channel._connections["chat-1"] = fake_connection

    await channel.send_delta("chat-1", "part", {"_stream_delta": True, "_stream_id": "sid"})
    await channel.send_delta("chat-1", "", {"_stream_end": True, "_stream_id": "sid"})

    assert fake_connection.send.await_count == 2
    delta_event, end_event = (
        json.loads(call[0][0]) for call in fake_connection.send.call_args_list[:2]
    )
    assert delta_event["event"] == "delta"
    assert delta_event["text"] == "part"
    assert delta_event["stream_id"] == "sid"
    assert end_event["event"] == "stream_end"
    assert end_event["stream_id"] == "sid"
@pytest.mark.asyncio
async def test_end_to_end_client_receives_ready_and_agent_sees_inbound() -> None:
    """A real client gets the ``ready`` event and its frames are published to the bus."""
    bus = MagicMock()
    bus.publish_inbound = AsyncMock()
    port = 29876
    settings = {
        "enabled": True,
        "allowFrom": ["*"],
        "host": "127.0.0.1",
        "port": port,
        "path": "/ws",
    }
    channel = WebSocketChannel(settings, bus)
    server_task = asyncio.create_task(channel.start())
    # Presumably gives the server time to start listening before the client connects.
    await asyncio.sleep(0.3)
    try:
        url = f"ws://127.0.0.1:{port}/ws?client_id=tester"
        async with websockets.connect(url) as client:
            ready = json.loads(await client.recv())
            assert ready["event"] == "ready"
            assert ready["client_id"] == "tester"
            chat_id = ready["chat_id"]

            # JSON frame: the ``content`` field is what reaches the bus.
            await client.send(json.dumps({"content": "ping from client"}))
            await asyncio.sleep(0.08)
            bus.publish_inbound.assert_awaited()
            first_inbound = bus.publish_inbound.call_args[0][0]
            assert first_inbound.channel == "websocket"
            assert first_inbound.sender_id == "tester"
            assert first_inbound.chat_id == chat_id
            assert first_inbound.content == "ping from client"

            # Plain-text frame: forwarded verbatim.
            await client.send("plain text frame")
            await asyncio.sleep(0.08)
            assert bus.publish_inbound.await_count >= 2
            published = [call[0][0] for call in bus.publish_inbound.call_args_list]
            latest_inbound = published[-1]
            assert latest_inbound.content == "plain text frame"
    finally:
        await channel.stop()
        await server_task
@pytest.mark.asyncio
async def test_token_rejects_handshake_when_mismatch() -> None:
    """A wrong static token makes the handshake fail with HTTP 401."""
    port = 29877
    settings = {
        "enabled": True,
        "allowFrom": ["*"],
        "host": "127.0.0.1",
        "port": port,
        "path": "/",
        "token": "secret",
    }
    channel = WebSocketChannel(settings, MagicMock())
    server_task = asyncio.create_task(channel.start())
    # Presumably gives the server time to start listening before the client connects.
    await asyncio.sleep(0.3)
    try:
        bad_url = f"ws://127.0.0.1:{port}/?token=wrong"
        with pytest.raises(websockets.exceptions.InvalidStatus) as rejected:
            async with websockets.connect(bad_url):
                pass
        assert rejected.value.response.status_code == 401
    finally:
        await channel.stop()
        await server_task
@pytest.mark.asyncio
async def test_wrong_path_returns_404() -> None:
    """Connecting to a path other than the configured one fails with HTTP 404."""
    port = 29878
    settings = {
        "enabled": True,
        "allowFrom": ["*"],
        "host": "127.0.0.1",
        "port": port,
        "path": "/ws",
    }
    channel = WebSocketChannel(settings, MagicMock())
    server_task = asyncio.create_task(channel.start())
    # Presumably gives the server time to start listening before the client connects.
    await asyncio.sleep(0.3)
    try:
        unknown_url = f"ws://127.0.0.1:{port}/other"
        with pytest.raises(websockets.exceptions.InvalidStatus) as rejected:
            async with websockets.connect(unknown_url):
                pass
        assert rejected.value.response.status_code == 404
    finally:
        await channel.stop()
        await server_task
def test_registry_discovers_websocket_channel() -> None:
    """The channel registry resolves ``"websocket"`` to a class with that ``name``."""
    from nanobot.channels.registry import load_channel_class

    channel_cls = load_channel_class("websocket")
    assert channel_cls.name == "websocket"
@pytest.mark.asyncio
async def test_http_route_issues_token_then_websocket_requires_it() -> None:
    """Token flow: the issue route needs the secret, and an issued token works exactly once."""
    bus = MagicMock()
    bus.publish_inbound = AsyncMock()
    port = 29879
    settings = {
        "enabled": True,
        "allowFrom": ["*"],
        "host": "127.0.0.1",
        "port": port,
        "path": "/ws",
        "tokenIssuePath": "/auth/token",
        "tokenIssueSecret": "route-secret",
        "websocketRequiresToken": True,
    }
    channel = WebSocketChannel(settings, bus)
    server_task = asyncio.create_task(channel.start())
    # Presumably gives the server time to start listening before the first request.
    await asyncio.sleep(0.3)
    try:
        issue_url = f"http://127.0.0.1:{port}/auth/token"

        # Without credentials the issue route refuses the request.
        denied = await _http_get(issue_url)
        assert denied.status_code == 401

        # With the bearer secret a token is issued.
        issued = await _http_get(issue_url, headers={"Authorization": "Bearer route-secret"})
        assert issued.status_code == 200
        token = issued.json()["token"]
        assert token.startswith("nbwt_")

        # A handshake without any token is rejected.
        with pytest.raises(websockets.exceptions.InvalidStatus) as missing_token:
            async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=x"):
                pass
        assert missing_token.value.response.status_code == 401

        # The issued token opens a connection.
        uri = f"ws://127.0.0.1:{port}/ws?token={token}&client_id=caller"
        async with websockets.connect(uri) as client:
            ready = json.loads(await client.recv())
            assert ready["event"] == "ready"
            assert ready["client_id"] == "caller"

        # The same token cannot be used a second time.
        with pytest.raises(websockets.exceptions.InvalidStatus) as reuse:
            async with websockets.connect(uri):
                pass
        assert reuse.value.response.status_code == 401
    finally:
        await channel.stop()
        await server_task