mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-10 13:13:39 +00:00
feat(channels): add WebSocket server channel and tests
Port Python implementation from a1ec7b192ad97ffd58250a720891ff09bbb73888 (websocket channel module and channel tests; excludes webui debug app).
This commit is contained in:
parent
51200a954c
commit
e00dca2f84
418
nanobot/channels/websocket.py
Normal file
418
nanobot/channels/websocket.py
Normal file
@ -0,0 +1,418 @@
|
||||
"""WebSocket server channel: nanobot acts as a WebSocket server and serves connected clients."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import email.utils
|
||||
import hmac
|
||||
import http
|
||||
import json
|
||||
import secrets
|
||||
import ssl
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Self
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import Field, field_validator, model_validator
|
||||
from websockets.asyncio.server import ServerConnection, serve
|
||||
from websockets.datastructures import Headers
|
||||
from websockets.http11 import Request as WsRequest, Response
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import Base
|
||||
|
||||
|
||||
def _strip_trailing_slash(path: str) -> str:
|
||||
if len(path) > 1 and path.endswith("/"):
|
||||
return path.rstrip("/")
|
||||
return path or "/"
|
||||
|
||||
|
||||
def _normalize_config_path(path: str) -> str:
|
||||
return _strip_trailing_slash(path)
|
||||
|
||||
|
||||
class WebSocketConfig(Base):
|
||||
"""WebSocket server channel configuration.
|
||||
|
||||
Clients connect with URLs like ``ws://{host}:{port}{path}?client_id=...&token=...``.
|
||||
- ``client_id``: Used for ``allow_from`` authorization; if omitted, a value is generated and logged.
|
||||
- ``token``: If non-empty, the ``token`` query param may match this static secret; short-lived tokens
|
||||
from ``token_issue_path`` are also accepted.
|
||||
- ``token_issue_path``: If non-empty, **GET** (HTTP/1.1) to this path returns JSON
|
||||
``{"token": "...", "expires_in": <seconds>}``; use ``?token=...`` when opening the WebSocket.
|
||||
Must differ from ``path`` (the WS upgrade path). If the client runs in the **same process** as
|
||||
nanobot and shares the asyncio loop, use a thread or async HTTP client for GET—do not call
|
||||
blocking ``urllib`` or synchronous ``httpx`` from inside a coroutine.
|
||||
- ``token_issue_secret``: If non-empty, token requests must send ``Authorization: Bearer <secret>`` or
|
||||
``X-Nanobot-Auth: <secret>``.
|
||||
- ``websocket_requires_token``: If True, the handshake must include a valid token (static or issued and not expired).
|
||||
- Each connection has its own session: a unique ``chat_id`` maps to the agent session internally.
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 8765
|
||||
path: str = "/"
|
||||
token: str = ""
|
||||
token_issue_path: str = ""
|
||||
token_issue_secret: str = ""
|
||||
token_ttl_s: int = Field(default=300, ge=30, le=86_400)
|
||||
websocket_requires_token: bool = False
|
||||
allow_from: list[str] = Field(default_factory=lambda: ["*"])
|
||||
streaming: bool = True
|
||||
max_message_bytes: int = Field(default=1_048_576, ge=1024, le=16_777_216)
|
||||
ping_interval_s: float = Field(default=20.0, ge=5.0, le=300.0)
|
||||
ping_timeout_s: float = Field(default=20.0, ge=5.0, le=300.0)
|
||||
ssl_certfile: str = ""
|
||||
ssl_keyfile: str = ""
|
||||
|
||||
@field_validator("path")
|
||||
@classmethod
|
||||
def path_must_start_with_slash(cls, value: str) -> str:
|
||||
if not value.startswith("/"):
|
||||
raise ValueError('path must start with "/"')
|
||||
return value
|
||||
|
||||
@field_validator("token_issue_path")
|
||||
@classmethod
|
||||
def token_issue_path_format(cls, value: str) -> str:
|
||||
value = value.strip()
|
||||
if not value:
|
||||
return ""
|
||||
if not value.startswith("/"):
|
||||
raise ValueError('token_issue_path must start with "/"')
|
||||
return _normalize_config_path(value)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def token_issue_path_differs_from_ws_path(self) -> Self:
|
||||
if not self.token_issue_path:
|
||||
return self
|
||||
if _normalize_config_path(self.token_issue_path) == _normalize_config_path(self.path):
|
||||
raise ValueError("token_issue_path must differ from path (the WebSocket upgrade path)")
|
||||
return self
|
||||
|
||||
|
||||
def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
|
||||
body = json.dumps(data, ensure_ascii=False).encode("utf-8")
|
||||
headers = Headers(
|
||||
[
|
||||
("Date", email.utils.formatdate(usegmt=True)),
|
||||
("Connection", "close"),
|
||||
("Content-Length", str(len(body))),
|
||||
("Content-Type", "application/json; charset=utf-8"),
|
||||
]
|
||||
)
|
||||
reason = http.HTTPStatus(status).phrase
|
||||
return Response(status, reason, headers, body)
|
||||
|
||||
|
||||
def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]:
|
||||
"""Parse normalized path and query parameters in one pass."""
|
||||
parsed = urlparse("ws://x" + path_with_query)
|
||||
path = _strip_trailing_slash(parsed.path or "/")
|
||||
return path, parse_qs(parsed.query)
|
||||
|
||||
|
||||
def _normalize_http_path(path_with_query: str) -> str:
|
||||
"""Return the path component (no query string), with trailing slash normalized (root stays ``/``)."""
|
||||
return _parse_request_path(path_with_query)[0]
|
||||
|
||||
|
||||
def _parse_query(path_with_query: str) -> dict[str, list[str]]:
|
||||
return _parse_request_path(path_with_query)[1]
|
||||
|
||||
|
||||
def _parse_inbound_payload(raw: str) -> str | None:
|
||||
"""Parse a client frame into text; return None for empty or unrecognized content."""
|
||||
text = raw.strip()
|
||||
if not text:
|
||||
return None
|
||||
if text.startswith("{"):
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return text
|
||||
if isinstance(data, dict):
|
||||
for key in ("content", "text", "message"):
|
||||
value = data.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value
|
||||
return None
|
||||
return None
|
||||
return text
|
||||
|
||||
|
||||
def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool:
|
||||
"""Return True if the token-issue HTTP request carries credentials matching ``token_issue_secret``."""
|
||||
if not configured_secret:
|
||||
return True
|
||||
authorization = headers.get("Authorization") or headers.get("authorization")
|
||||
if authorization and authorization.lower().startswith("bearer "):
|
||||
supplied = authorization[7:].strip()
|
||||
return hmac.compare_digest(supplied, configured_secret)
|
||||
header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth")
|
||||
if not header_token:
|
||||
return False
|
||||
return hmac.compare_digest(header_token.strip(), configured_secret)
|
||||
|
||||
|
||||
class WebSocketChannel(BaseChannel):
|
||||
"""Run a local WebSocket server; forward text/JSON messages to the message bus."""
|
||||
|
||||
name = "websocket"
|
||||
display_name = "WebSocket"
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus):
|
||||
if isinstance(config, dict):
|
||||
config = WebSocketConfig.model_validate(config)
|
||||
super().__init__(config, bus)
|
||||
self.config: WebSocketConfig = config
|
||||
self._connections: dict[str, Any] = {}
|
||||
self._issued_tokens: dict[str, float] = {}
|
||||
self._stop_event: asyncio.Event | None = None
|
||||
self._server_task: asyncio.Task[None] | None = None
|
||||
|
||||
@classmethod
|
||||
def default_config(cls) -> dict[str, Any]:
|
||||
return WebSocketConfig().model_dump(by_alias=True)
|
||||
|
||||
def _expected_path(self) -> str:
|
||||
return _normalize_config_path(self.config.path)
|
||||
|
||||
def _build_ssl_context(self) -> ssl.SSLContext | None:
|
||||
cert = self.config.ssl_certfile.strip()
|
||||
key = self.config.ssl_keyfile.strip()
|
||||
if not cert and not key:
|
||||
return None
|
||||
if not cert or not key:
|
||||
raise ValueError(
|
||||
"websocket: ssl_certfile and ssl_keyfile must both be set for WSS, or both left empty"
|
||||
)
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
ctx.load_cert_chain(certfile=cert, keyfile=key)
|
||||
return ctx
|
||||
|
||||
def _purge_expired_issued_tokens(self) -> None:
|
||||
now = time.monotonic()
|
||||
for token_key, expiry in list(self._issued_tokens.items()):
|
||||
if now > expiry:
|
||||
self._issued_tokens.pop(token_key, None)
|
||||
|
||||
def _take_issued_token_if_valid(self, token_value: str | None) -> bool:
|
||||
"""Validate and consume one issued token (single use per connection attempt)."""
|
||||
if not token_value:
|
||||
return False
|
||||
self._purge_expired_issued_tokens()
|
||||
expiry = self._issued_tokens.get(token_value)
|
||||
if expiry is None:
|
||||
return False
|
||||
if time.monotonic() > expiry:
|
||||
self._issued_tokens.pop(token_value, None)
|
||||
return False
|
||||
self._issued_tokens.pop(token_value, None)
|
||||
return True
|
||||
|
||||
def _handle_token_issue_http(self, connection: Any, request: Any) -> Any:
|
||||
secret = self.config.token_issue_secret.strip()
|
||||
if secret:
|
||||
if not _issue_route_secret_matches(request.headers, secret):
|
||||
return connection.respond(401, "Unauthorized")
|
||||
else:
|
||||
logger.warning(
|
||||
"websocket: token_issue_path is set but token_issue_secret is empty; "
|
||||
"any client can obtain connection tokens — set token_issue_secret for production."
|
||||
)
|
||||
self._purge_expired_issued_tokens()
|
||||
token_value = f"nbwt_{secrets.token_urlsafe(32)}"
|
||||
self._issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s)
|
||||
|
||||
return _http_json_response(
|
||||
{"token": token_value, "expires_in": self.config.token_ttl_s}
|
||||
)
|
||||
|
||||
def _authorize_websocket_handshake(self, connection: Any, request_path: str) -> Any:
|
||||
query = _parse_query(request_path)
|
||||
supplied = (query.get("token") or [None])[0]
|
||||
static_token = self.config.token.strip()
|
||||
|
||||
if static_token:
|
||||
if supplied == static_token:
|
||||
return None
|
||||
if supplied and self._take_issued_token_if_valid(supplied):
|
||||
return None
|
||||
return connection.respond(401, "Unauthorized")
|
||||
|
||||
if self.config.websocket_requires_token:
|
||||
if supplied and self._take_issued_token_if_valid(supplied):
|
||||
return None
|
||||
return connection.respond(401, "Unauthorized")
|
||||
|
||||
if supplied:
|
||||
self._take_issued_token_if_valid(supplied)
|
||||
return None
|
||||
|
||||
async def start(self) -> None:
|
||||
self._running = True
|
||||
self._stop_event = asyncio.Event()
|
||||
|
||||
ssl_context = self._build_ssl_context()
|
||||
scheme = "wss" if ssl_context else "ws"
|
||||
|
||||
async def process_request(
|
||||
connection: ServerConnection,
|
||||
request: WsRequest,
|
||||
) -> Any:
|
||||
got, _ = _parse_request_path(request.path)
|
||||
if self.config.token_issue_path:
|
||||
issue_expected = _normalize_config_path(self.config.token_issue_path)
|
||||
if got == issue_expected:
|
||||
return self._handle_token_issue_http(connection, request)
|
||||
|
||||
expected_ws = self._expected_path()
|
||||
if got != expected_ws:
|
||||
return connection.respond(404, "Not Found")
|
||||
return self._authorize_websocket_handshake(connection, request.path)
|
||||
|
||||
async def handler(connection: ServerConnection) -> None:
|
||||
await self._connection_loop(connection)
|
||||
|
||||
logger.info(
|
||||
"WebSocket server listening on {}://{}:{}{}",
|
||||
scheme,
|
||||
self.config.host,
|
||||
self.config.port,
|
||||
self.config.path,
|
||||
)
|
||||
if self.config.token_issue_path:
|
||||
logger.info(
|
||||
"WebSocket token issue route: {}://{}:{}{}",
|
||||
scheme,
|
||||
self.config.host,
|
||||
self.config.port,
|
||||
_normalize_config_path(self.config.token_issue_path),
|
||||
)
|
||||
|
||||
async def runner() -> None:
|
||||
async with serve(
|
||||
handler,
|
||||
self.config.host,
|
||||
self.config.port,
|
||||
process_request=process_request,
|
||||
max_size=self.config.max_message_bytes,
|
||||
ping_interval=self.config.ping_interval_s,
|
||||
ping_timeout=self.config.ping_timeout_s,
|
||||
ssl=ssl_context,
|
||||
):
|
||||
assert self._stop_event is not None
|
||||
await self._stop_event.wait()
|
||||
|
||||
self._server_task = asyncio.create_task(runner())
|
||||
await self._server_task
|
||||
|
||||
async def _connection_loop(self, connection: Any) -> None:
|
||||
request = connection.request
|
||||
path_part = request.path if request else "/"
|
||||
_, query = _parse_request_path(path_part)
|
||||
client_id_raw = (query.get("client_id") or [None])[0]
|
||||
client_id = client_id_raw.strip() if client_id_raw else ""
|
||||
if not client_id:
|
||||
client_id = f"anon-{uuid.uuid4().hex[:12]}"
|
||||
|
||||
chat_id = str(uuid.uuid4())
|
||||
self._connections[chat_id] = connection
|
||||
|
||||
await connection.send(
|
||||
json.dumps(
|
||||
{
|
||||
"event": "ready",
|
||||
"chat_id": chat_id,
|
||||
"client_id": client_id,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
async for raw in connection:
|
||||
if isinstance(raw, bytes):
|
||||
try:
|
||||
raw = raw.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
logger.warning("websocket: ignoring non-utf8 binary frame")
|
||||
continue
|
||||
content = _parse_inbound_payload(raw)
|
||||
if content is None:
|
||||
continue
|
||||
await self._handle_message(
|
||||
sender_id=client_id,
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
metadata={"remote": getattr(connection, "remote_address", None)},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("websocket connection ended: {}", e)
|
||||
finally:
|
||||
self._connections.pop(chat_id, None)
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._running = False
|
||||
if self._stop_event:
|
||||
self._stop_event.set()
|
||||
if self._server_task:
|
||||
await self._server_task
|
||||
self._server_task = None
|
||||
self._connections.clear()
|
||||
self._issued_tokens.clear()
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
connection = self._connections.get(msg.chat_id)
|
||||
if connection is None:
|
||||
logger.warning("websocket: no active connection for chat_id={}", msg.chat_id)
|
||||
return
|
||||
payload: dict[str, Any] = {
|
||||
"event": "message",
|
||||
"text": msg.content,
|
||||
}
|
||||
if msg.media:
|
||||
payload["media"] = msg.media
|
||||
if msg.reply_to:
|
||||
payload["reply_to"] = msg.reply_to
|
||||
raw = json.dumps(payload, ensure_ascii=False)
|
||||
try:
|
||||
await connection.send(raw)
|
||||
except Exception as e:
|
||||
logger.error("websocket send failed: {}", e)
|
||||
raise
|
||||
|
||||
async def send_delta(
|
||||
self,
|
||||
chat_id: str,
|
||||
delta: str,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
connection = self._connections.get(chat_id)
|
||||
if connection is None:
|
||||
return
|
||||
meta = metadata or {}
|
||||
if meta.get("_stream_end"):
|
||||
body: dict[str, Any] = {"event": "stream_end"}
|
||||
if meta.get("_stream_id") is not None:
|
||||
body["stream_id"] = meta["_stream_id"]
|
||||
else:
|
||||
body = {
|
||||
"event": "delta",
|
||||
"text": delta,
|
||||
}
|
||||
if meta.get("_stream_id") is not None:
|
||||
body["stream_id"] = meta["_stream_id"]
|
||||
raw = json.dumps(body, ensure_ascii=False)
|
||||
try:
|
||||
await connection.send(raw)
|
||||
except Exception as e:
|
||||
logger.error("websocket stream send failed: {}", e)
|
||||
raise
|
||||
329
tests/channels/test_websocket_channel.py
Normal file
329
tests/channels/test_websocket_channel.py
Normal file
@ -0,0 +1,329 @@
|
||||
"""Unit and lightweight integration tests for the WebSocket channel."""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import websockets
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.channels.websocket import (
|
||||
WebSocketChannel,
|
||||
WebSocketConfig,
|
||||
_issue_route_secret_matches,
|
||||
_normalize_config_path,
|
||||
_normalize_http_path,
|
||||
_parse_inbound_payload,
|
||||
_parse_query,
|
||||
_parse_request_path,
|
||||
)
|
||||
|
||||
|
||||
async def _http_get(url: str, headers: dict[str, str] | None = None) -> httpx.Response:
|
||||
"""Run GET in a thread to avoid blocking the asyncio loop shared with websockets."""
|
||||
return await asyncio.to_thread(
|
||||
functools.partial(httpx.get, url, headers=headers or {}, timeout=5.0)
|
||||
)
|
||||
|
||||
|
||||
def test_normalize_http_path_strips_trailing_slash_except_root() -> None:
|
||||
assert _normalize_http_path("/chat/") == "/chat"
|
||||
assert _normalize_http_path("/chat?x=1") == "/chat"
|
||||
assert _normalize_http_path("/") == "/"
|
||||
|
||||
|
||||
def test_parse_request_path_matches_normalize_and_query() -> None:
|
||||
path, query = _parse_request_path("/ws/?token=secret&client_id=u1")
|
||||
assert path == _normalize_http_path("/ws/?token=secret&client_id=u1")
|
||||
assert query == _parse_query("/ws/?token=secret&client_id=u1")
|
||||
|
||||
|
||||
def test_normalize_config_path_matches_request() -> None:
|
||||
assert _normalize_config_path("/ws/") == "/ws"
|
||||
assert _normalize_config_path("/") == "/"
|
||||
|
||||
|
||||
def test_parse_query_extracts_token_and_client_id() -> None:
|
||||
query = _parse_query("/?token=secret&client_id=u1")
|
||||
assert query.get("token") == ["secret"]
|
||||
assert query.get("client_id") == ["u1"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("raw", "expected"),
|
||||
[
|
||||
("plain", "plain"),
|
||||
('{"content": "hi"}', "hi"),
|
||||
('{"text": "there"}', "there"),
|
||||
('{"message": "x"}', "x"),
|
||||
(" ", None),
|
||||
("{}", None),
|
||||
],
|
||||
)
|
||||
def test_parse_inbound_payload(raw: str, expected: str | None) -> None:
|
||||
assert _parse_inbound_payload(raw) == expected
|
||||
|
||||
|
||||
def test_parse_inbound_invalid_json_falls_back_to_raw_string() -> None:
|
||||
assert _parse_inbound_payload("{not json") == "{not json"
|
||||
|
||||
|
||||
def test_web_socket_config_path_must_start_with_slash() -> None:
|
||||
with pytest.raises(ValueError, match='path must start with "/"'):
|
||||
WebSocketConfig(path="bad")
|
||||
|
||||
|
||||
def test_ssl_context_requires_both_cert_and_key_files() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel(
|
||||
{"enabled": True, "allowFrom": ["*"], "sslCertfile": "/tmp/c.pem", "sslKeyfile": ""},
|
||||
bus,
|
||||
)
|
||||
with pytest.raises(ValueError, match="ssl_certfile and ssl_keyfile"):
|
||||
channel._build_ssl_context()
|
||||
|
||||
|
||||
def test_default_config_includes_safe_bind_and_streaming() -> None:
|
||||
defaults = WebSocketChannel.default_config()
|
||||
assert defaults["enabled"] is False
|
||||
assert defaults["host"] == "127.0.0.1"
|
||||
assert defaults["streaming"] is True
|
||||
assert defaults["allowFrom"] == ["*"]
|
||||
assert defaults.get("tokenIssuePath", "") == ""
|
||||
|
||||
|
||||
def test_token_issue_path_must_differ_from_websocket_path() -> None:
|
||||
with pytest.raises(ValueError, match="token_issue_path must differ"):
|
||||
WebSocketConfig(path="/ws", token_issue_path="/ws")
|
||||
|
||||
|
||||
def test_issue_route_secret_matches_bearer_and_header() -> None:
|
||||
from websockets.datastructures import Headers
|
||||
|
||||
secret = "my-secret"
|
||||
bearer_headers = Headers([("Authorization", "Bearer my-secret")])
|
||||
assert _issue_route_secret_matches(bearer_headers, secret) is True
|
||||
x_headers = Headers([("X-Nanobot-Auth", "my-secret")])
|
||||
assert _issue_route_secret_matches(x_headers, secret) is True
|
||||
wrong = Headers([("Authorization", "Bearer other")])
|
||||
assert _issue_route_secret_matches(wrong, secret) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delivers_json_message_with_media_and_reply() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
|
||||
mock_ws = AsyncMock()
|
||||
channel._connections["chat-1"] = mock_ws
|
||||
|
||||
msg = OutboundMessage(
|
||||
channel="websocket",
|
||||
chat_id="chat-1",
|
||||
content="hello",
|
||||
reply_to="m1",
|
||||
media=["/tmp/a.png"],
|
||||
)
|
||||
await channel.send(msg)
|
||||
|
||||
mock_ws.send.assert_awaited_once()
|
||||
payload = json.loads(mock_ws.send.call_args[0][0])
|
||||
assert payload["event"] == "message"
|
||||
assert payload["text"] == "hello"
|
||||
assert payload["reply_to"] == "m1"
|
||||
assert payload["media"] == ["/tmp/a.png"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_missing_connection_is_noop_without_error() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
|
||||
msg = OutboundMessage(channel="websocket", chat_id="missing", content="x")
|
||||
await channel.send(msg)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_emits_delta_and_stream_end() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus)
|
||||
mock_ws = AsyncMock()
|
||||
channel._connections["chat-1"] = mock_ws
|
||||
|
||||
await channel.send_delta("chat-1", "part", {"_stream_delta": True, "_stream_id": "sid"})
|
||||
await channel.send_delta("chat-1", "", {"_stream_end": True, "_stream_id": "sid"})
|
||||
|
||||
assert mock_ws.send.await_count == 2
|
||||
first = json.loads(mock_ws.send.call_args_list[0][0][0])
|
||||
second = json.loads(mock_ws.send.call_args_list[1][0][0])
|
||||
assert first["event"] == "delta"
|
||||
assert first["text"] == "part"
|
||||
assert first["stream_id"] == "sid"
|
||||
assert second["event"] == "stream_end"
|
||||
assert second["stream_id"] == "sid"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end_client_receives_ready_and_agent_sees_inbound() -> None:
|
||||
bus = MagicMock()
|
||||
bus.publish_inbound = AsyncMock()
|
||||
port = 29876
|
||||
channel = WebSocketChannel(
|
||||
{
|
||||
"enabled": True,
|
||||
"allowFrom": ["*"],
|
||||
"host": "127.0.0.1",
|
||||
"port": port,
|
||||
"path": "/ws",
|
||||
},
|
||||
bus,
|
||||
)
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=tester") as client:
|
||||
ready_raw = await client.recv()
|
||||
ready = json.loads(ready_raw)
|
||||
assert ready["event"] == "ready"
|
||||
assert ready["client_id"] == "tester"
|
||||
chat_id = ready["chat_id"]
|
||||
|
||||
await client.send(json.dumps({"content": "ping from client"}))
|
||||
await asyncio.sleep(0.08)
|
||||
|
||||
bus.publish_inbound.assert_awaited()
|
||||
inbound = bus.publish_inbound.call_args[0][0]
|
||||
assert inbound.channel == "websocket"
|
||||
assert inbound.sender_id == "tester"
|
||||
assert inbound.chat_id == chat_id
|
||||
assert inbound.content == "ping from client"
|
||||
|
||||
await client.send("plain text frame")
|
||||
await asyncio.sleep(0.08)
|
||||
assert bus.publish_inbound.await_count >= 2
|
||||
second = [c[0][0] for c in bus.publish_inbound.call_args_list][-1]
|
||||
assert second.content == "plain text frame"
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_rejects_handshake_when_mismatch() -> None:
|
||||
bus = MagicMock()
|
||||
port = 29877
|
||||
channel = WebSocketChannel(
|
||||
{
|
||||
"enabled": True,
|
||||
"allowFrom": ["*"],
|
||||
"host": "127.0.0.1",
|
||||
"port": port,
|
||||
"path": "/",
|
||||
"token": "secret",
|
||||
},
|
||||
bus,
|
||||
)
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as excinfo:
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/?token=wrong"):
|
||||
pass
|
||||
assert excinfo.value.response.status_code == 401
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wrong_path_returns_404() -> None:
|
||||
bus = MagicMock()
|
||||
port = 29878
|
||||
channel = WebSocketChannel(
|
||||
{
|
||||
"enabled": True,
|
||||
"allowFrom": ["*"],
|
||||
"host": "127.0.0.1",
|
||||
"port": port,
|
||||
"path": "/ws",
|
||||
},
|
||||
bus,
|
||||
)
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as excinfo:
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/other"):
|
||||
pass
|
||||
assert excinfo.value.response.status_code == 404
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
def test_registry_discovers_websocket_channel() -> None:
|
||||
from nanobot.channels.registry import load_channel_class
|
||||
|
||||
cls = load_channel_class("websocket")
|
||||
assert cls.name == "websocket"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_route_issues_token_then_websocket_requires_it() -> None:
|
||||
bus = MagicMock()
|
||||
bus.publish_inbound = AsyncMock()
|
||||
port = 29879
|
||||
channel = WebSocketChannel(
|
||||
{
|
||||
"enabled": True,
|
||||
"allowFrom": ["*"],
|
||||
"host": "127.0.0.1",
|
||||
"port": port,
|
||||
"path": "/ws",
|
||||
"tokenIssuePath": "/auth/token",
|
||||
"tokenIssueSecret": "route-secret",
|
||||
"websocketRequiresToken": True,
|
||||
},
|
||||
bus,
|
||||
)
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
deny = await _http_get(f"http://127.0.0.1:{port}/auth/token")
|
||||
assert deny.status_code == 401
|
||||
|
||||
issue = await _http_get(
|
||||
f"http://127.0.0.1:{port}/auth/token",
|
||||
headers={"Authorization": "Bearer route-secret"},
|
||||
)
|
||||
assert issue.status_code == 200
|
||||
token = issue.json()["token"]
|
||||
assert token.startswith("nbwt_")
|
||||
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as missing_token:
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=x"):
|
||||
pass
|
||||
assert missing_token.value.response.status_code == 401
|
||||
|
||||
uri = f"ws://127.0.0.1:{port}/ws?token={token}&client_id=caller"
|
||||
async with websockets.connect(uri) as client:
|
||||
ready = json.loads(await client.recv())
|
||||
assert ready["event"] == "ready"
|
||||
assert ready["client_id"] == "caller"
|
||||
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as reuse:
|
||||
async with websockets.connect(uri):
|
||||
pass
|
||||
assert reuse.value.response.status_code == 401
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
Loading…
x
Reference in New Issue
Block a user