mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-02 07:45:54 +00:00
feat(channels): add WebSocket server channel and tests
Port Python implementation from a1ec7b192ad97ffd58250a720891ff09bbb73888 (websocket channel module and channel tests; excludes webui debug app).
This commit is contained in:
parent
51200a954c
commit
e00dca2f84
418
nanobot/channels/websocket.py
Normal file
418
nanobot/channels/websocket.py
Normal file
@ -0,0 +1,418 @@
|
|||||||
|
"""WebSocket server channel: nanobot acts as a WebSocket server and serves connected clients."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import email.utils
|
||||||
|
import hmac
|
||||||
|
import http
|
||||||
|
import json
|
||||||
|
import secrets
|
||||||
|
import ssl
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Self
|
||||||
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
from pydantic import Field, field_validator, model_validator
|
||||||
|
from websockets.asyncio.server import ServerConnection, serve
|
||||||
|
from websockets.datastructures import Headers
|
||||||
|
from websockets.http11 import Request as WsRequest, Response
|
||||||
|
|
||||||
|
from nanobot.bus.events import OutboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.channels.base import BaseChannel
|
||||||
|
from nanobot.config.schema import Base
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_trailing_slash(path: str) -> str:
|
||||||
|
if len(path) > 1 and path.endswith("/"):
|
||||||
|
return path.rstrip("/")
|
||||||
|
return path or "/"
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_config_path(path: str) -> str:
|
||||||
|
return _strip_trailing_slash(path)
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketConfig(Base):
|
||||||
|
"""WebSocket server channel configuration.
|
||||||
|
|
||||||
|
Clients connect with URLs like ``ws://{host}:{port}{path}?client_id=...&token=...``.
|
||||||
|
- ``client_id``: Used for ``allow_from`` authorization; if omitted, a value is generated and logged.
|
||||||
|
- ``token``: If non-empty, the ``token`` query param may match this static secret; short-lived tokens
|
||||||
|
from ``token_issue_path`` are also accepted.
|
||||||
|
- ``token_issue_path``: If non-empty, **GET** (HTTP/1.1) to this path returns JSON
|
||||||
|
``{"token": "...", "expires_in": <seconds>}``; use ``?token=...`` when opening the WebSocket.
|
||||||
|
Must differ from ``path`` (the WS upgrade path). If the client runs in the **same process** as
|
||||||
|
nanobot and shares the asyncio loop, use a thread or async HTTP client for GET—do not call
|
||||||
|
blocking ``urllib`` or synchronous ``httpx`` from inside a coroutine.
|
||||||
|
- ``token_issue_secret``: If non-empty, token requests must send ``Authorization: Bearer <secret>`` or
|
||||||
|
``X-Nanobot-Auth: <secret>``.
|
||||||
|
- ``websocket_requires_token``: If True, the handshake must include a valid token (static or issued and not expired).
|
||||||
|
- Each connection has its own session: a unique ``chat_id`` maps to the agent session internally.
|
||||||
|
"""
|
||||||
|
|
||||||
|
enabled: bool = False
|
||||||
|
host: str = "127.0.0.1"
|
||||||
|
port: int = 8765
|
||||||
|
path: str = "/"
|
||||||
|
token: str = ""
|
||||||
|
token_issue_path: str = ""
|
||||||
|
token_issue_secret: str = ""
|
||||||
|
token_ttl_s: int = Field(default=300, ge=30, le=86_400)
|
||||||
|
websocket_requires_token: bool = False
|
||||||
|
allow_from: list[str] = Field(default_factory=lambda: ["*"])
|
||||||
|
streaming: bool = True
|
||||||
|
max_message_bytes: int = Field(default=1_048_576, ge=1024, le=16_777_216)
|
||||||
|
ping_interval_s: float = Field(default=20.0, ge=5.0, le=300.0)
|
||||||
|
ping_timeout_s: float = Field(default=20.0, ge=5.0, le=300.0)
|
||||||
|
ssl_certfile: str = ""
|
||||||
|
ssl_keyfile: str = ""
|
||||||
|
|
||||||
|
@field_validator("path")
|
||||||
|
@classmethod
|
||||||
|
def path_must_start_with_slash(cls, value: str) -> str:
|
||||||
|
if not value.startswith("/"):
|
||||||
|
raise ValueError('path must start with "/"')
|
||||||
|
return value
|
||||||
|
|
||||||
|
@field_validator("token_issue_path")
|
||||||
|
@classmethod
|
||||||
|
def token_issue_path_format(cls, value: str) -> str:
|
||||||
|
value = value.strip()
|
||||||
|
if not value:
|
||||||
|
return ""
|
||||||
|
if not value.startswith("/"):
|
||||||
|
raise ValueError('token_issue_path must start with "/"')
|
||||||
|
return _normalize_config_path(value)
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def token_issue_path_differs_from_ws_path(self) -> Self:
|
||||||
|
if not self.token_issue_path:
|
||||||
|
return self
|
||||||
|
if _normalize_config_path(self.token_issue_path) == _normalize_config_path(self.path):
|
||||||
|
raise ValueError("token_issue_path must differ from path (the WebSocket upgrade path)")
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
|
||||||
|
body = json.dumps(data, ensure_ascii=False).encode("utf-8")
|
||||||
|
headers = Headers(
|
||||||
|
[
|
||||||
|
("Date", email.utils.formatdate(usegmt=True)),
|
||||||
|
("Connection", "close"),
|
||||||
|
("Content-Length", str(len(body))),
|
||||||
|
("Content-Type", "application/json; charset=utf-8"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
reason = http.HTTPStatus(status).phrase
|
||||||
|
return Response(status, reason, headers, body)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]:
|
||||||
|
"""Parse normalized path and query parameters in one pass."""
|
||||||
|
parsed = urlparse("ws://x" + path_with_query)
|
||||||
|
path = _strip_trailing_slash(parsed.path or "/")
|
||||||
|
return path, parse_qs(parsed.query)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_http_path(path_with_query: str) -> str:
|
||||||
|
"""Return the path component (no query string), with trailing slash normalized (root stays ``/``)."""
|
||||||
|
return _parse_request_path(path_with_query)[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_query(path_with_query: str) -> dict[str, list[str]]:
|
||||||
|
return _parse_request_path(path_with_query)[1]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_inbound_payload(raw: str) -> str | None:
|
||||||
|
"""Parse a client frame into text; return None for empty or unrecognized content."""
|
||||||
|
text = raw.strip()
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
|
if text.startswith("{"):
|
||||||
|
try:
|
||||||
|
data = json.loads(text)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return text
|
||||||
|
if isinstance(data, dict):
|
||||||
|
for key in ("content", "text", "message"):
|
||||||
|
value = data.get(key)
|
||||||
|
if isinstance(value, str) and value.strip():
|
||||||
|
return value
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool:
|
||||||
|
"""Return True if the token-issue HTTP request carries credentials matching ``token_issue_secret``."""
|
||||||
|
if not configured_secret:
|
||||||
|
return True
|
||||||
|
authorization = headers.get("Authorization") or headers.get("authorization")
|
||||||
|
if authorization and authorization.lower().startswith("bearer "):
|
||||||
|
supplied = authorization[7:].strip()
|
||||||
|
return hmac.compare_digest(supplied, configured_secret)
|
||||||
|
header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth")
|
||||||
|
if not header_token:
|
||||||
|
return False
|
||||||
|
return hmac.compare_digest(header_token.strip(), configured_secret)
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketChannel(BaseChannel):
|
||||||
|
"""Run a local WebSocket server; forward text/JSON messages to the message bus."""
|
||||||
|
|
||||||
|
name = "websocket"
|
||||||
|
display_name = "WebSocket"
|
||||||
|
|
||||||
|
def __init__(self, config: Any, bus: MessageBus):
|
||||||
|
if isinstance(config, dict):
|
||||||
|
config = WebSocketConfig.model_validate(config)
|
||||||
|
super().__init__(config, bus)
|
||||||
|
self.config: WebSocketConfig = config
|
||||||
|
self._connections: dict[str, Any] = {}
|
||||||
|
self._issued_tokens: dict[str, float] = {}
|
||||||
|
self._stop_event: asyncio.Event | None = None
|
||||||
|
self._server_task: asyncio.Task[None] | None = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def default_config(cls) -> dict[str, Any]:
|
||||||
|
return WebSocketConfig().model_dump(by_alias=True)
|
||||||
|
|
||||||
|
def _expected_path(self) -> str:
|
||||||
|
return _normalize_config_path(self.config.path)
|
||||||
|
|
||||||
|
def _build_ssl_context(self) -> ssl.SSLContext | None:
|
||||||
|
cert = self.config.ssl_certfile.strip()
|
||||||
|
key = self.config.ssl_keyfile.strip()
|
||||||
|
if not cert and not key:
|
||||||
|
return None
|
||||||
|
if not cert or not key:
|
||||||
|
raise ValueError(
|
||||||
|
"websocket: ssl_certfile and ssl_keyfile must both be set for WSS, or both left empty"
|
||||||
|
)
|
||||||
|
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||||
|
ctx.load_cert_chain(certfile=cert, keyfile=key)
|
||||||
|
return ctx
|
||||||
|
|
||||||
|
def _purge_expired_issued_tokens(self) -> None:
|
||||||
|
now = time.monotonic()
|
||||||
|
for token_key, expiry in list(self._issued_tokens.items()):
|
||||||
|
if now > expiry:
|
||||||
|
self._issued_tokens.pop(token_key, None)
|
||||||
|
|
||||||
|
def _take_issued_token_if_valid(self, token_value: str | None) -> bool:
|
||||||
|
"""Validate and consume one issued token (single use per connection attempt)."""
|
||||||
|
if not token_value:
|
||||||
|
return False
|
||||||
|
self._purge_expired_issued_tokens()
|
||||||
|
expiry = self._issued_tokens.get(token_value)
|
||||||
|
if expiry is None:
|
||||||
|
return False
|
||||||
|
if time.monotonic() > expiry:
|
||||||
|
self._issued_tokens.pop(token_value, None)
|
||||||
|
return False
|
||||||
|
self._issued_tokens.pop(token_value, None)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _handle_token_issue_http(self, connection: Any, request: Any) -> Any:
|
||||||
|
secret = self.config.token_issue_secret.strip()
|
||||||
|
if secret:
|
||||||
|
if not _issue_route_secret_matches(request.headers, secret):
|
||||||
|
return connection.respond(401, "Unauthorized")
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"websocket: token_issue_path is set but token_issue_secret is empty; "
|
||||||
|
"any client can obtain connection tokens — set token_issue_secret for production."
|
||||||
|
)
|
||||||
|
self._purge_expired_issued_tokens()
|
||||||
|
token_value = f"nbwt_{secrets.token_urlsafe(32)}"
|
||||||
|
self._issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s)
|
||||||
|
|
||||||
|
return _http_json_response(
|
||||||
|
{"token": token_value, "expires_in": self.config.token_ttl_s}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _authorize_websocket_handshake(self, connection: Any, request_path: str) -> Any:
|
||||||
|
query = _parse_query(request_path)
|
||||||
|
supplied = (query.get("token") or [None])[0]
|
||||||
|
static_token = self.config.token.strip()
|
||||||
|
|
||||||
|
if static_token:
|
||||||
|
if supplied == static_token:
|
||||||
|
return None
|
||||||
|
if supplied and self._take_issued_token_if_valid(supplied):
|
||||||
|
return None
|
||||||
|
return connection.respond(401, "Unauthorized")
|
||||||
|
|
||||||
|
if self.config.websocket_requires_token:
|
||||||
|
if supplied and self._take_issued_token_if_valid(supplied):
|
||||||
|
return None
|
||||||
|
return connection.respond(401, "Unauthorized")
|
||||||
|
|
||||||
|
if supplied:
|
||||||
|
self._take_issued_token_if_valid(supplied)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
self._running = True
|
||||||
|
self._stop_event = asyncio.Event()
|
||||||
|
|
||||||
|
ssl_context = self._build_ssl_context()
|
||||||
|
scheme = "wss" if ssl_context else "ws"
|
||||||
|
|
||||||
|
async def process_request(
|
||||||
|
connection: ServerConnection,
|
||||||
|
request: WsRequest,
|
||||||
|
) -> Any:
|
||||||
|
got, _ = _parse_request_path(request.path)
|
||||||
|
if self.config.token_issue_path:
|
||||||
|
issue_expected = _normalize_config_path(self.config.token_issue_path)
|
||||||
|
if got == issue_expected:
|
||||||
|
return self._handle_token_issue_http(connection, request)
|
||||||
|
|
||||||
|
expected_ws = self._expected_path()
|
||||||
|
if got != expected_ws:
|
||||||
|
return connection.respond(404, "Not Found")
|
||||||
|
return self._authorize_websocket_handshake(connection, request.path)
|
||||||
|
|
||||||
|
async def handler(connection: ServerConnection) -> None:
|
||||||
|
await self._connection_loop(connection)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"WebSocket server listening on {}://{}:{}{}",
|
||||||
|
scheme,
|
||||||
|
self.config.host,
|
||||||
|
self.config.port,
|
||||||
|
self.config.path,
|
||||||
|
)
|
||||||
|
if self.config.token_issue_path:
|
||||||
|
logger.info(
|
||||||
|
"WebSocket token issue route: {}://{}:{}{}",
|
||||||
|
scheme,
|
||||||
|
self.config.host,
|
||||||
|
self.config.port,
|
||||||
|
_normalize_config_path(self.config.token_issue_path),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def runner() -> None:
|
||||||
|
async with serve(
|
||||||
|
handler,
|
||||||
|
self.config.host,
|
||||||
|
self.config.port,
|
||||||
|
process_request=process_request,
|
||||||
|
max_size=self.config.max_message_bytes,
|
||||||
|
ping_interval=self.config.ping_interval_s,
|
||||||
|
ping_timeout=self.config.ping_timeout_s,
|
||||||
|
ssl=ssl_context,
|
||||||
|
):
|
||||||
|
assert self._stop_event is not None
|
||||||
|
await self._stop_event.wait()
|
||||||
|
|
||||||
|
self._server_task = asyncio.create_task(runner())
|
||||||
|
await self._server_task
|
||||||
|
|
||||||
|
async def _connection_loop(self, connection: Any) -> None:
|
||||||
|
request = connection.request
|
||||||
|
path_part = request.path if request else "/"
|
||||||
|
_, query = _parse_request_path(path_part)
|
||||||
|
client_id_raw = (query.get("client_id") or [None])[0]
|
||||||
|
client_id = client_id_raw.strip() if client_id_raw else ""
|
||||||
|
if not client_id:
|
||||||
|
client_id = f"anon-{uuid.uuid4().hex[:12]}"
|
||||||
|
|
||||||
|
chat_id = str(uuid.uuid4())
|
||||||
|
self._connections[chat_id] = connection
|
||||||
|
|
||||||
|
await connection.send(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"event": "ready",
|
||||||
|
"chat_id": chat_id,
|
||||||
|
"client_id": client_id,
|
||||||
|
},
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for raw in connection:
|
||||||
|
if isinstance(raw, bytes):
|
||||||
|
try:
|
||||||
|
raw = raw.decode("utf-8")
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
logger.warning("websocket: ignoring non-utf8 binary frame")
|
||||||
|
continue
|
||||||
|
content = _parse_inbound_payload(raw)
|
||||||
|
if content is None:
|
||||||
|
continue
|
||||||
|
await self._handle_message(
|
||||||
|
sender_id=client_id,
|
||||||
|
chat_id=chat_id,
|
||||||
|
content=content,
|
||||||
|
metadata={"remote": getattr(connection, "remote_address", None)},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("websocket connection ended: {}", e)
|
||||||
|
finally:
|
||||||
|
self._connections.pop(chat_id, None)
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
self._running = False
|
||||||
|
if self._stop_event:
|
||||||
|
self._stop_event.set()
|
||||||
|
if self._server_task:
|
||||||
|
await self._server_task
|
||||||
|
self._server_task = None
|
||||||
|
self._connections.clear()
|
||||||
|
self._issued_tokens.clear()
|
||||||
|
|
||||||
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
|
connection = self._connections.get(msg.chat_id)
|
||||||
|
if connection is None:
|
||||||
|
logger.warning("websocket: no active connection for chat_id={}", msg.chat_id)
|
||||||
|
return
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"event": "message",
|
||||||
|
"text": msg.content,
|
||||||
|
}
|
||||||
|
if msg.media:
|
||||||
|
payload["media"] = msg.media
|
||||||
|
if msg.reply_to:
|
||||||
|
payload["reply_to"] = msg.reply_to
|
||||||
|
raw = json.dumps(payload, ensure_ascii=False)
|
||||||
|
try:
|
||||||
|
await connection.send(raw)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("websocket send failed: {}", e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def send_delta(
|
||||||
|
self,
|
||||||
|
chat_id: str,
|
||||||
|
delta: str,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
connection = self._connections.get(chat_id)
|
||||||
|
if connection is None:
|
||||||
|
return
|
||||||
|
meta = metadata or {}
|
||||||
|
if meta.get("_stream_end"):
|
||||||
|
body: dict[str, Any] = {"event": "stream_end"}
|
||||||
|
if meta.get("_stream_id") is not None:
|
||||||
|
body["stream_id"] = meta["_stream_id"]
|
||||||
|
else:
|
||||||
|
body = {
|
||||||
|
"event": "delta",
|
||||||
|
"text": delta,
|
||||||
|
}
|
||||||
|
if meta.get("_stream_id") is not None:
|
||||||
|
body["stream_id"] = meta["_stream_id"]
|
||||||
|
raw = json.dumps(body, ensure_ascii=False)
|
||||||
|
try:
|
||||||
|
await connection.send(raw)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("websocket stream send failed: {}", e)
|
||||||
|
raise
|
||||||
329
tests/channels/test_websocket_channel.py
Normal file
329
tests/channels/test_websocket_channel.py
Normal file
@ -0,0 +1,329 @@
|
|||||||
|
"""Unit and lightweight integration tests for the WebSocket channel."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import functools
|
||||||
|
import json
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
from nanobot.bus.events import OutboundMessage
|
||||||
|
from nanobot.channels.websocket import (
|
||||||
|
WebSocketChannel,
|
||||||
|
WebSocketConfig,
|
||||||
|
_issue_route_secret_matches,
|
||||||
|
_normalize_config_path,
|
||||||
|
_normalize_http_path,
|
||||||
|
_parse_inbound_payload,
|
||||||
|
_parse_query,
|
||||||
|
_parse_request_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _http_get(url: str, headers: dict[str, str] | None = None) -> httpx.Response:
|
||||||
|
"""Run GET in a thread to avoid blocking the asyncio loop shared with websockets."""
|
||||||
|
return await asyncio.to_thread(
|
||||||
|
functools.partial(httpx.get, url, headers=headers or {}, timeout=5.0)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_http_path_strips_trailing_slash_except_root() -> None:
|
||||||
|
assert _normalize_http_path("/chat/") == "/chat"
|
||||||
|
assert _normalize_http_path("/chat?x=1") == "/chat"
|
||||||
|
assert _normalize_http_path("/") == "/"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_request_path_matches_normalize_and_query() -> None:
|
||||||
|
path, query = _parse_request_path("/ws/?token=secret&client_id=u1")
|
||||||
|
assert path == _normalize_http_path("/ws/?token=secret&client_id=u1")
|
||||||
|
assert query == _parse_query("/ws/?token=secret&client_id=u1")
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_config_path_matches_request() -> None:
|
||||||
|
assert _normalize_config_path("/ws/") == "/ws"
|
||||||
|
assert _normalize_config_path("/") == "/"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_query_extracts_token_and_client_id() -> None:
|
||||||
|
query = _parse_query("/?token=secret&client_id=u1")
|
||||||
|
assert query.get("token") == ["secret"]
|
||||||
|
assert query.get("client_id") == ["u1"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("raw", "expected"),
|
||||||
|
[
|
||||||
|
("plain", "plain"),
|
||||||
|
('{"content": "hi"}', "hi"),
|
||||||
|
('{"text": "there"}', "there"),
|
||||||
|
('{"message": "x"}', "x"),
|
||||||
|
(" ", None),
|
||||||
|
("{}", None),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_parse_inbound_payload(raw: str, expected: str | None) -> None:
|
||||||
|
assert _parse_inbound_payload(raw) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_inbound_invalid_json_falls_back_to_raw_string() -> None:
|
||||||
|
assert _parse_inbound_payload("{not json") == "{not json"
|
||||||
|
|
||||||
|
|
||||||
|
def test_web_socket_config_path_must_start_with_slash() -> None:
|
||||||
|
with pytest.raises(ValueError, match='path must start with "/"'):
|
||||||
|
WebSocketConfig(path="bad")
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssl_context_requires_both_cert_and_key_files() -> None:
|
||||||
|
bus = MagicMock()
|
||||||
|
channel = WebSocketChannel(
|
||||||
|
{"enabled": True, "allowFrom": ["*"], "sslCertfile": "/tmp/c.pem", "sslKeyfile": ""},
|
||||||
|
bus,
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match="ssl_certfile and ssl_keyfile"):
|
||||||
|
channel._build_ssl_context()
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_config_includes_safe_bind_and_streaming() -> None:
|
||||||
|
defaults = WebSocketChannel.default_config()
|
||||||
|
assert defaults["enabled"] is False
|
||||||
|
assert defaults["host"] == "127.0.0.1"
|
||||||
|
assert defaults["streaming"] is True
|
||||||
|
assert defaults["allowFrom"] == ["*"]
|
||||||
|
assert defaults.get("tokenIssuePath", "") == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_issue_path_must_differ_from_websocket_path() -> None:
|
||||||
|
with pytest.raises(ValueError, match="token_issue_path must differ"):
|
||||||
|
WebSocketConfig(path="/ws", token_issue_path="/ws")
|
||||||
|
|
||||||
|
|
||||||
|
def test_issue_route_secret_matches_bearer_and_header() -> None:
|
||||||
|
from websockets.datastructures import Headers
|
||||||
|
|
||||||
|
secret = "my-secret"
|
||||||
|
bearer_headers = Headers([("Authorization", "Bearer my-secret")])
|
||||||
|
assert _issue_route_secret_matches(bearer_headers, secret) is True
|
||||||
|
x_headers = Headers([("X-Nanobot-Auth", "my-secret")])
|
||||||
|
assert _issue_route_secret_matches(x_headers, secret) is True
|
||||||
|
wrong = Headers([("Authorization", "Bearer other")])
|
||||||
|
assert _issue_route_secret_matches(wrong, secret) is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_delivers_json_message_with_media_and_reply() -> None:
|
||||||
|
bus = MagicMock()
|
||||||
|
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
|
||||||
|
mock_ws = AsyncMock()
|
||||||
|
channel._connections["chat-1"] = mock_ws
|
||||||
|
|
||||||
|
msg = OutboundMessage(
|
||||||
|
channel="websocket",
|
||||||
|
chat_id="chat-1",
|
||||||
|
content="hello",
|
||||||
|
reply_to="m1",
|
||||||
|
media=["/tmp/a.png"],
|
||||||
|
)
|
||||||
|
await channel.send(msg)
|
||||||
|
|
||||||
|
mock_ws.send.assert_awaited_once()
|
||||||
|
payload = json.loads(mock_ws.send.call_args[0][0])
|
||||||
|
assert payload["event"] == "message"
|
||||||
|
assert payload["text"] == "hello"
|
||||||
|
assert payload["reply_to"] == "m1"
|
||||||
|
assert payload["media"] == ["/tmp/a.png"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_missing_connection_is_noop_without_error() -> None:
|
||||||
|
bus = MagicMock()
|
||||||
|
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
|
||||||
|
msg = OutboundMessage(channel="websocket", chat_id="missing", content="x")
|
||||||
|
await channel.send(msg)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_delta_emits_delta_and_stream_end() -> None:
|
||||||
|
bus = MagicMock()
|
||||||
|
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus)
|
||||||
|
mock_ws = AsyncMock()
|
||||||
|
channel._connections["chat-1"] = mock_ws
|
||||||
|
|
||||||
|
await channel.send_delta("chat-1", "part", {"_stream_delta": True, "_stream_id": "sid"})
|
||||||
|
await channel.send_delta("chat-1", "", {"_stream_end": True, "_stream_id": "sid"})
|
||||||
|
|
||||||
|
assert mock_ws.send.await_count == 2
|
||||||
|
first = json.loads(mock_ws.send.call_args_list[0][0][0])
|
||||||
|
second = json.loads(mock_ws.send.call_args_list[1][0][0])
|
||||||
|
assert first["event"] == "delta"
|
||||||
|
assert first["text"] == "part"
|
||||||
|
assert first["stream_id"] == "sid"
|
||||||
|
assert second["event"] == "stream_end"
|
||||||
|
assert second["stream_id"] == "sid"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_end_to_end_client_receives_ready_and_agent_sees_inbound() -> None:
|
||||||
|
bus = MagicMock()
|
||||||
|
bus.publish_inbound = AsyncMock()
|
||||||
|
port = 29876
|
||||||
|
channel = WebSocketChannel(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"allowFrom": ["*"],
|
||||||
|
"host": "127.0.0.1",
|
||||||
|
"port": port,
|
||||||
|
"path": "/ws",
|
||||||
|
},
|
||||||
|
bus,
|
||||||
|
)
|
||||||
|
|
||||||
|
server_task = asyncio.create_task(channel.start())
|
||||||
|
await asyncio.sleep(0.3)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=tester") as client:
|
||||||
|
ready_raw = await client.recv()
|
||||||
|
ready = json.loads(ready_raw)
|
||||||
|
assert ready["event"] == "ready"
|
||||||
|
assert ready["client_id"] == "tester"
|
||||||
|
chat_id = ready["chat_id"]
|
||||||
|
|
||||||
|
await client.send(json.dumps({"content": "ping from client"}))
|
||||||
|
await asyncio.sleep(0.08)
|
||||||
|
|
||||||
|
bus.publish_inbound.assert_awaited()
|
||||||
|
inbound = bus.publish_inbound.call_args[0][0]
|
||||||
|
assert inbound.channel == "websocket"
|
||||||
|
assert inbound.sender_id == "tester"
|
||||||
|
assert inbound.chat_id == chat_id
|
||||||
|
assert inbound.content == "ping from client"
|
||||||
|
|
||||||
|
await client.send("plain text frame")
|
||||||
|
await asyncio.sleep(0.08)
|
||||||
|
assert bus.publish_inbound.await_count >= 2
|
||||||
|
second = [c[0][0] for c in bus.publish_inbound.call_args_list][-1]
|
||||||
|
assert second.content == "plain text frame"
|
||||||
|
finally:
|
||||||
|
await channel.stop()
|
||||||
|
await server_task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_token_rejects_handshake_when_mismatch() -> None:
|
||||||
|
bus = MagicMock()
|
||||||
|
port = 29877
|
||||||
|
channel = WebSocketChannel(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"allowFrom": ["*"],
|
||||||
|
"host": "127.0.0.1",
|
||||||
|
"port": port,
|
||||||
|
"path": "/",
|
||||||
|
"token": "secret",
|
||||||
|
},
|
||||||
|
bus,
|
||||||
|
)
|
||||||
|
|
||||||
|
server_task = asyncio.create_task(channel.start())
|
||||||
|
await asyncio.sleep(0.3)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with pytest.raises(websockets.exceptions.InvalidStatus) as excinfo:
|
||||||
|
async with websockets.connect(f"ws://127.0.0.1:{port}/?token=wrong"):
|
||||||
|
pass
|
||||||
|
assert excinfo.value.response.status_code == 401
|
||||||
|
finally:
|
||||||
|
await channel.stop()
|
||||||
|
await server_task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_wrong_path_returns_404() -> None:
|
||||||
|
bus = MagicMock()
|
||||||
|
port = 29878
|
||||||
|
channel = WebSocketChannel(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"allowFrom": ["*"],
|
||||||
|
"host": "127.0.0.1",
|
||||||
|
"port": port,
|
||||||
|
"path": "/ws",
|
||||||
|
},
|
||||||
|
bus,
|
||||||
|
)
|
||||||
|
|
||||||
|
server_task = asyncio.create_task(channel.start())
|
||||||
|
await asyncio.sleep(0.3)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with pytest.raises(websockets.exceptions.InvalidStatus) as excinfo:
|
||||||
|
async with websockets.connect(f"ws://127.0.0.1:{port}/other"):
|
||||||
|
pass
|
||||||
|
assert excinfo.value.response.status_code == 404
|
||||||
|
finally:
|
||||||
|
await channel.stop()
|
||||||
|
await server_task
|
||||||
|
|
||||||
|
|
||||||
|
def test_registry_discovers_websocket_channel() -> None:
|
||||||
|
from nanobot.channels.registry import load_channel_class
|
||||||
|
|
||||||
|
cls = load_channel_class("websocket")
|
||||||
|
assert cls.name == "websocket"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_http_route_issues_token_then_websocket_requires_it() -> None:
|
||||||
|
bus = MagicMock()
|
||||||
|
bus.publish_inbound = AsyncMock()
|
||||||
|
port = 29879
|
||||||
|
channel = WebSocketChannel(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"allowFrom": ["*"],
|
||||||
|
"host": "127.0.0.1",
|
||||||
|
"port": port,
|
||||||
|
"path": "/ws",
|
||||||
|
"tokenIssuePath": "/auth/token",
|
||||||
|
"tokenIssueSecret": "route-secret",
|
||||||
|
"websocketRequiresToken": True,
|
||||||
|
},
|
||||||
|
bus,
|
||||||
|
)
|
||||||
|
|
||||||
|
server_task = asyncio.create_task(channel.start())
|
||||||
|
await asyncio.sleep(0.3)
|
||||||
|
|
||||||
|
try:
|
||||||
|
deny = await _http_get(f"http://127.0.0.1:{port}/auth/token")
|
||||||
|
assert deny.status_code == 401
|
||||||
|
|
||||||
|
issue = await _http_get(
|
||||||
|
f"http://127.0.0.1:{port}/auth/token",
|
||||||
|
headers={"Authorization": "Bearer route-secret"},
|
||||||
|
)
|
||||||
|
assert issue.status_code == 200
|
||||||
|
token = issue.json()["token"]
|
||||||
|
assert token.startswith("nbwt_")
|
||||||
|
|
||||||
|
with pytest.raises(websockets.exceptions.InvalidStatus) as missing_token:
|
||||||
|
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=x"):
|
||||||
|
pass
|
||||||
|
assert missing_token.value.response.status_code == 401
|
||||||
|
|
||||||
|
uri = f"ws://127.0.0.1:{port}/ws?token={token}&client_id=caller"
|
||||||
|
async with websockets.connect(uri) as client:
|
||||||
|
ready = json.loads(await client.recv())
|
||||||
|
assert ready["event"] == "ready"
|
||||||
|
assert ready["client_id"] == "caller"
|
||||||
|
|
||||||
|
with pytest.raises(websockets.exceptions.InvalidStatus) as reuse:
|
||||||
|
async with websockets.connect(uri):
|
||||||
|
pass
|
||||||
|
assert reuse.value.response.status_code == 401
|
||||||
|
finally:
|
||||||
|
await channel.stop()
|
||||||
|
await server_task
|
||||||
Loading…
x
Reference in New Issue
Block a user