fix(websocket): harden security and robustness

- Use hmac.compare_digest for timing-safe static token comparison
- Add issued token capacity limit (_MAX_ISSUED_TOKENS=10000) with 429 response
- Use atomic pop in _take_issued_token_if_valid to eliminate TOCTOU window
- Enforce TLSv1.2 minimum version for SSL connections
- Extract _safe_send helper for consistent ConnectionClosed handling
- Move connection registration after ready send to prevent out-of-order delivery
- Add HTTP-level allow_from check and client_id truncation in process_request
- Make stop() idempotent with graceful shutdown error handling
- Normalize path via validator instead of leaving raw value
- Default websocket_requires_token to True for secure-by-default behavior
- Add integration tests and ws_test_client helper
- Refactor tests to use shared _ch factory and bus fixture
This commit is contained in:
chengyongru 2026-04-09 15:18:02 +08:00 committed by chengyongru
parent d327c19db0
commit 8f7ce9fef7
4 changed files with 1102 additions and 99 deletions

View File

@ -65,7 +65,7 @@ class WebSocketConfig(Base):
token_issue_path: str = ""
token_issue_secret: str = ""
token_ttl_s: int = Field(default=300, ge=30, le=86_400)
websocket_requires_token: bool = False
websocket_requires_token: bool = True
allow_from: list[str] = Field(default_factory=lambda: ["*"])
streaming: bool = True
max_message_bytes: int = Field(default=1_048_576, ge=1024, le=16_777_216)
@ -79,7 +79,7 @@ class WebSocketConfig(Base):
def path_must_start_with_slash(cls, value: str) -> str:
if not value.startswith("/"):
raise ValueError('path must start with "/"')
return value
return _normalize_config_path(value)
@field_validator("token_issue_path")
@classmethod
@ -130,6 +130,12 @@ def _parse_query(path_with_query: str) -> dict[str, list[str]]:
return _parse_request_path(path_with_query)[1]
def _query_first(query: dict[str, list[str]], key: str) -> str | None:
"""Return the first value for *key*, or None."""
values = query.get(key)
return values[0] if values else None
def _parse_inbound_payload(raw: str) -> str | None:
"""Parse a client frame into text; return None for empty or unrecognized content."""
text = raw.strip()
@ -197,9 +203,12 @@ class WebSocketChannel(BaseChannel):
"websocket: ssl_certfile and ssl_keyfile must both be set for WSS, or both left empty"
)
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
ctx.load_cert_chain(certfile=cert, keyfile=key)
return ctx
_MAX_ISSUED_TOKENS = 10_000
def _purge_expired_issued_tokens(self) -> None:
now = time.monotonic()
for token_key, expiry in list(self._issued_tokens.items()):
@ -207,17 +216,19 @@ class WebSocketChannel(BaseChannel):
self._issued_tokens.pop(token_key, None)
def _take_issued_token_if_valid(self, token_value: str | None) -> bool:
"""Validate and consume one issued token (single use per connection attempt)."""
"""Validate and consume one issued token (single use per connection attempt).
Uses single-step pop to minimize the window between lookup and removal;
safe under asyncio's single-threaded cooperative model.
"""
if not token_value:
return False
self._purge_expired_issued_tokens()
expiry = self._issued_tokens.get(token_value)
expiry = self._issued_tokens.pop(token_value, None)
if expiry is None:
return False
if time.monotonic() > expiry:
self._issued_tokens.pop(token_value, None)
return False
self._issued_tokens.pop(token_value, None)
return True
def _handle_token_issue_http(self, connection: Any, request: Any) -> Any:
@ -231,6 +242,12 @@ class WebSocketChannel(BaseChannel):
"any client can obtain connection tokens — set token_issue_secret for production."
)
self._purge_expired_issued_tokens()
if len(self._issued_tokens) >= self._MAX_ISSUED_TOKENS:
logger.error(
"websocket: too many outstanding issued tokens ({}), rejecting issuance",
len(self._issued_tokens),
)
return _http_json_response({"error": "too many outstanding tokens"}, status=429)
token_value = f"nbwt_{secrets.token_urlsafe(32)}"
self._issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s)
@ -238,13 +255,12 @@ class WebSocketChannel(BaseChannel):
{"token": token_value, "expires_in": self.config.token_ttl_s}
)
def _authorize_websocket_handshake(self, connection: Any, request_path: str) -> Any:
query = _parse_query(request_path)
supplied = (query.get("token") or [None])[0]
def _authorize_websocket_handshake(self, connection: Any, query: dict[str, list[str]]) -> Any:
supplied = _query_first(query, "token")
static_token = self.config.token.strip()
if static_token:
if supplied == static_token:
if supplied and hmac.compare_digest(supplied, static_token):
return None
if supplied and self._take_issued_token_if_valid(supplied):
return None
@ -279,7 +295,15 @@ class WebSocketChannel(BaseChannel):
expected_ws = self._expected_path()
if got != expected_ws:
return connection.respond(404, "Not Found")
return self._authorize_websocket_handshake(connection, request.path)
# Early reject before WebSocket upgrade to avoid unnecessary overhead;
# _handle_message() performs a second check as defense-in-depth.
query = _parse_query(request.path)
client_id = _query_first(query, "client_id") or ""
if len(client_id) > 128:
client_id = client_id[:128]
if not self.is_allowed(client_id):
return connection.respond(403, "Forbidden")
return self._authorize_websocket_handshake(connection, query)
async def handler(connection: ServerConnection) -> None:
await self._connection_loop(connection)
@ -321,26 +345,30 @@ class WebSocketChannel(BaseChannel):
request = connection.request
path_part = request.path if request else "/"
_, query = _parse_request_path(path_part)
client_id_raw = (query.get("client_id") or [None])[0]
client_id_raw = _query_first(query, "client_id")
client_id = client_id_raw.strip() if client_id_raw else ""
if not client_id:
client_id = f"anon-{uuid.uuid4().hex[:12]}"
elif len(client_id) > 128:
logger.warning("websocket: client_id too long ({} chars), truncating", len(client_id))
client_id = client_id[:128]
chat_id = str(uuid.uuid4())
self._connections[chat_id] = connection
await connection.send(
json.dumps(
{
"event": "ready",
"chat_id": chat_id,
"client_id": client_id,
},
ensure_ascii=False,
)
)
try:
await connection.send(
json.dumps(
{
"event": "ready",
"chat_id": chat_id,
"client_id": client_id,
},
ensure_ascii=False,
)
)
# Register only after ready is successfully sent to avoid out-of-order sends
self._connections[chat_id] = connection
async for raw in connection:
if isinstance(raw, bytes):
try:
@ -363,15 +391,34 @@ class WebSocketChannel(BaseChannel):
self._connections.pop(chat_id, None)
async def stop(self) -> None:
if not self._running:
return
self._running = False
if self._stop_event:
self._stop_event.set()
if self._server_task:
await self._server_task
try:
await self._server_task
except Exception as e:
logger.warning("websocket: server task error during shutdown: {}", e)
self._server_task = None
self._connections.clear()
self._issued_tokens.clear()
async def _safe_send(self, chat_id: str, raw: str, *, label: str = "") -> None:
"""Send a raw frame, cleaning up dead connections on ConnectionClosed."""
connection = self._connections.get(chat_id)
if connection is None:
return
try:
await connection.send(raw)
except ConnectionClosed:
self._connections.pop(chat_id, None)
logger.warning("websocket{}connection gone for chat_id={}", label, chat_id)
except Exception as e:
logger.error("websocket{}send failed: {}", label, e)
raise
async def send(self, msg: OutboundMessage) -> None:
connection = self._connections.get(msg.chat_id)
if connection is None:
@ -386,14 +433,7 @@ class WebSocketChannel(BaseChannel):
if msg.reply_to:
payload["reply_to"] = msg.reply_to
raw = json.dumps(payload, ensure_ascii=False)
try:
await connection.send(raw)
except ConnectionClosed:
self._connections.pop(msg.chat_id, None)
logger.warning("websocket: connection gone for chat_id={}", msg.chat_id)
except Exception as e:
logger.error("websocket send failed: {}", e)
raise
await self._safe_send(msg.chat_id, raw, label=" ")
async def send_delta(
self,
@ -401,27 +441,17 @@ class WebSocketChannel(BaseChannel):
delta: str,
metadata: dict[str, Any] | None = None,
) -> None:
connection = self._connections.get(chat_id)
if connection is None:
if self._connections.get(chat_id) is None:
return
meta = metadata or {}
if meta.get("_stream_end"):
body: dict[str, Any] = {"event": "stream_end"}
if meta.get("_stream_id") is not None:
body["stream_id"] = meta["_stream_id"]
else:
body = {
"event": "delta",
"text": delta,
}
if meta.get("_stream_id") is not None:
body["stream_id"] = meta["_stream_id"]
if meta.get("_stream_id") is not None:
body["stream_id"] = meta["_stream_id"]
raw = json.dumps(body, ensure_ascii=False)
try:
await connection.send(raw)
except ConnectionClosed:
self._connections.pop(chat_id, None)
logger.warning("websocket: stream connection gone for chat_id={}", chat_id)
except Exception as e:
logger.error("websocket stream send failed: {}", e)
raise
await self._safe_send(chat_id, raw, label=" stream ")

View File

@ -3,11 +3,15 @@
import asyncio
import functools
import json
import time
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import httpx
import pytest
import websockets
from websockets.exceptions import ConnectionClosed
from websockets.frames import Close
from nanobot.bus.events import OutboundMessage
from nanobot.channels.websocket import (
@ -21,6 +25,30 @@ from nanobot.channels.websocket import (
_parse_request_path,
)
# -- Shared helpers (aligned with test_websocket_integration.py) ---------------
_PORT = 29876
def _ch(bus: Any, **kw: Any) -> WebSocketChannel:
cfg: dict[str, Any] = {
"enabled": True,
"allowFrom": ["*"],
"host": "127.0.0.1",
"port": _PORT,
"path": "/ws",
"websocketRequiresToken": False,
}
cfg.update(kw)
return WebSocketChannel(cfg, bus)
@pytest.fixture()
def bus() -> MagicMock:
b = MagicMock()
b.publish_inbound = AsyncMock()
return b
async def _http_get(url: str, headers: dict[str, str] | None = None) -> httpx.Response:
"""Run GET in a thread to avoid blocking the asyncio loop shared with websockets."""
@ -71,6 +99,21 @@ def test_parse_inbound_invalid_json_falls_back_to_raw_string() -> None:
assert _parse_inbound_payload("{not json") == "{not json"
@pytest.mark.parametrize(
("raw", "expected"),
[
('{"content": ""}', None), # empty string content
('{"content": 123}', None), # non-string content
('{"content": " "}', None), # whitespace-only content
('["hello"]', '["hello"]'), # JSON array: not a dict, treated as plain text
('{"unknown_key": "val"}', None), # unrecognized key
('{"content": null}', None), # null content
],
)
def test_parse_inbound_payload_edge_cases(raw: str, expected: str | None) -> None:
assert _parse_inbound_payload(raw) == expected
def test_web_socket_config_path_must_start_with_slash() -> None:
with pytest.raises(ValueError, match='path must start with "/"'):
WebSocketConfig(path="bad")
@ -112,6 +155,14 @@ def test_issue_route_secret_matches_bearer_and_header() -> None:
assert _issue_route_secret_matches(wrong, secret) is False
def test_issue_route_secret_matches_empty_secret() -> None:
from websockets.datastructures import Headers
# Empty secret always returns True regardless of headers
assert _issue_route_secret_matches(Headers([]), "") is True
assert _issue_route_secret_matches(Headers([("Authorization", "Bearer anything")]), "") is True
@pytest.mark.asyncio
async def test_send_delivers_json_message_with_media_and_reply() -> None:
bus = MagicMock()
@ -144,6 +195,33 @@ async def test_send_missing_connection_is_noop_without_error() -> None:
await channel.send(msg)
@pytest.mark.asyncio
async def test_send_removes_connection_on_connection_closed() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
mock_ws = AsyncMock()
mock_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True)
channel._connections["chat-1"] = mock_ws
msg = OutboundMessage(channel="websocket", chat_id="chat-1", content="hello")
await channel.send(msg)
assert "chat-1" not in channel._connections
@pytest.mark.asyncio
async def test_send_delta_removes_connection_on_connection_closed() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus)
mock_ws = AsyncMock()
mock_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True)
channel._connections["chat-1"] = mock_ws
await channel.send_delta("chat-1", "chunk", {"_stream_delta": True, "_stream_id": "s1"})
assert "chat-1" not in channel._connections
@pytest.mark.asyncio
async def test_send_delta_emits_delta_and_stream_end() -> None:
bus = MagicMock()
@ -165,20 +243,39 @@ async def test_send_delta_emits_delta_and_stream_end() -> None:
@pytest.mark.asyncio
async def test_end_to_end_client_receives_ready_and_agent_sees_inbound() -> None:
async def test_send_non_connection_closed_exception_is_raised() -> None:
bus = MagicMock()
bus.publish_inbound = AsyncMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
mock_ws = AsyncMock()
mock_ws.send.side_effect = RuntimeError("unexpected")
channel._connections["chat-1"] = mock_ws
msg = OutboundMessage(channel="websocket", chat_id="chat-1", content="hello")
with pytest.raises(RuntimeError, match="unexpected"):
await channel.send(msg)
@pytest.mark.asyncio
async def test_send_delta_missing_connection_is_noop() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus)
# No exception, no error — just a no-op
await channel.send_delta("nonexistent", "chunk", {"_stream_delta": True, "_stream_id": "s1"})
@pytest.mark.asyncio
async def test_stop_is_idempotent() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
# stop() before start() should not raise
await channel.stop()
await channel.stop()
@pytest.mark.asyncio
async def test_end_to_end_client_receives_ready_and_agent_sees_inbound(bus: MagicMock) -> None:
port = 29876
channel = WebSocketChannel(
{
"enabled": True,
"allowFrom": ["*"],
"host": "127.0.0.1",
"port": port,
"path": "/ws",
},
bus,
)
channel = _ch(bus, port=port)
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
@ -212,20 +309,9 @@ async def test_end_to_end_client_receives_ready_and_agent_sees_inbound() -> None
@pytest.mark.asyncio
async def test_token_rejects_handshake_when_mismatch() -> None:
bus = MagicMock()
async def test_token_rejects_handshake_when_mismatch(bus: MagicMock) -> None:
port = 29877
channel = WebSocketChannel(
{
"enabled": True,
"allowFrom": ["*"],
"host": "127.0.0.1",
"port": port,
"path": "/",
"token": "secret",
},
bus,
)
channel = _ch(bus, port=port, path="/", token="secret")
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
@ -241,19 +327,9 @@ async def test_token_rejects_handshake_when_mismatch() -> None:
@pytest.mark.asyncio
async def test_wrong_path_returns_404() -> None:
bus = MagicMock()
async def test_wrong_path_returns_404(bus: MagicMock) -> None:
port = 29878
channel = WebSocketChannel(
{
"enabled": True,
"allowFrom": ["*"],
"host": "127.0.0.1",
"port": port,
"path": "/ws",
},
bus,
)
channel = _ch(bus, port=port)
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
@ -276,22 +352,13 @@ def test_registry_discovers_websocket_channel() -> None:
@pytest.mark.asyncio
async def test_http_route_issues_token_then_websocket_requires_it() -> None:
bus = MagicMock()
bus.publish_inbound = AsyncMock()
async def test_http_route_issues_token_then_websocket_requires_it(bus: MagicMock) -> None:
port = 29879
channel = WebSocketChannel(
{
"enabled": True,
"allowFrom": ["*"],
"host": "127.0.0.1",
"port": port,
"path": "/ws",
"tokenIssuePath": "/auth/token",
"tokenIssueSecret": "route-secret",
"websocketRequiresToken": True,
},
bus,
channel = _ch(
bus, port=port,
tokenIssuePath="/auth/token",
tokenIssueSecret="route-secret",
websocketRequiresToken=True,
)
server_task = asyncio.create_task(channel.start())
@ -327,3 +394,205 @@ async def test_http_route_issues_token_then_websocket_requires_it() -> None:
finally:
await channel.stop()
await server_task
@pytest.mark.asyncio
async def test_end_to_end_server_pushes_streaming_deltas_to_client(bus: MagicMock) -> None:
port = 29880
channel = _ch(bus, port=port, streaming=True)
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=stream-tester") as client:
ready_raw = await client.recv()
ready = json.loads(ready_raw)
chat_id = ready["chat_id"]
# Server pushes deltas directly
await channel.send_delta(
chat_id, "Hello ", {"_stream_delta": True, "_stream_id": "s1"}
)
await channel.send_delta(
chat_id, "world", {"_stream_delta": True, "_stream_id": "s1"}
)
await channel.send_delta(
chat_id, "", {"_stream_end": True, "_stream_id": "s1"}
)
delta1 = json.loads(await client.recv())
assert delta1["event"] == "delta"
assert delta1["text"] == "Hello "
assert delta1["stream_id"] == "s1"
delta2 = json.loads(await client.recv())
assert delta2["event"] == "delta"
assert delta2["text"] == "world"
assert delta2["stream_id"] == "s1"
end = json.loads(await client.recv())
assert end["event"] == "stream_end"
assert end["stream_id"] == "s1"
finally:
await channel.stop()
await server_task
@pytest.mark.asyncio
async def test_token_issue_rejects_when_at_capacity(bus: MagicMock) -> None:
port = 29881
channel = _ch(bus, port=port, tokenIssuePath="/auth/token", tokenIssueSecret="s")
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
# Fill issued tokens to capacity
channel._issued_tokens = {
f"nbwt_fill_{i}": time.monotonic() + 300 for i in range(channel._MAX_ISSUED_TOKENS)
}
resp = await _http_get(
f"http://127.0.0.1:{port}/auth/token",
headers={"Authorization": "Bearer s"},
)
assert resp.status_code == 429
data = resp.json()
assert "error" in data
finally:
await channel.stop()
await server_task
@pytest.mark.asyncio
async def test_allow_from_rejects_unauthorized_client_id(bus: MagicMock) -> None:
port = 29882
channel = _ch(bus, port=port, allowFrom=["alice", "bob"])
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=eve"):
pass
assert exc_info.value.response.status_code == 403
finally:
await channel.stop()
await server_task
@pytest.mark.asyncio
async def test_client_id_truncation(bus: MagicMock) -> None:
port = 29883
channel = _ch(bus, port=port)
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
long_id = "x" * 200
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id={long_id}") as client:
ready = json.loads(await client.recv())
assert ready["client_id"] == "x" * 128
assert len(ready["client_id"]) == 128
finally:
await channel.stop()
await server_task
@pytest.mark.asyncio
async def test_non_utf8_binary_frame_ignored(bus: MagicMock) -> None:
port = 29884
channel = _ch(bus, port=port)
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=bin-test") as client:
await client.recv() # consume ready
# Send non-UTF-8 bytes
await client.send(b"\xff\xfe\xfd")
await asyncio.sleep(0.05)
# publish_inbound should NOT have been called
bus.publish_inbound.assert_not_awaited()
finally:
await channel.stop()
await server_task
@pytest.mark.asyncio
async def test_static_token_accepts_issued_token_as_fallback(bus: MagicMock) -> None:
port = 29885
channel = _ch(
bus, port=port,
token="static-secret",
tokenIssuePath="/auth/token",
tokenIssueSecret="route-secret",
)
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
# Get an issued token
resp = await _http_get(
f"http://127.0.0.1:{port}/auth/token",
headers={"Authorization": "Bearer route-secret"},
)
assert resp.status_code == 200
issued_token = resp.json()["token"]
# Connect using issued token (not the static one)
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?token={issued_token}&client_id=caller") as client:
ready = json.loads(await client.recv())
assert ready["event"] == "ready"
finally:
await channel.stop()
await server_task
@pytest.mark.asyncio
async def test_allow_from_empty_list_denies_all(bus: MagicMock) -> None:
port = 29886
channel = _ch(bus, port=port, allowFrom=[])
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=anyone"):
pass
assert exc_info.value.response.status_code == 403
finally:
await channel.stop()
await server_task
@pytest.mark.asyncio
async def test_websocket_requires_token_without_issue_path(bus: MagicMock) -> None:
"""When websocket_requires_token is True but no token or issue path configured, all connections are rejected."""
port = 29887
channel = _ch(bus, port=port, websocketRequiresToken=True)
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
# No token at all → 401
with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=u"):
pass
assert exc_info.value.response.status_code == 401
# Wrong token → 401
with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=u&token=wrong"):
pass
assert exc_info.value.response.status_code == 401
finally:
await channel.stop()
await server_task

View File

@ -0,0 +1,477 @@
"""Integration tests for the WebSocket channel using WsTestClient.
Complements the unit/lightweight tests in test_websocket_channel.py by covering
multi-client scenarios, edge cases, and realistic usage patterns.
"""
from __future__ import annotations
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
import websockets
from nanobot.channels.websocket import WebSocketChannel
from nanobot.bus.events import OutboundMessage
from ws_test_client import WsTestClient, issue_token, issue_token_ok
def _ch(bus: Any, port: int, **kw: Any) -> WebSocketChannel:
cfg: dict[str, Any] = {
"enabled": True,
"allowFrom": ["*"],
"host": "127.0.0.1",
"port": port,
"path": "/",
"websocketRequiresToken": False,
}
cfg.update(kw)
return WebSocketChannel(cfg, bus)
@pytest.fixture()
def bus() -> MagicMock:
b = MagicMock()
b.publish_inbound = AsyncMock()
return b
# -- Connection basics ----------------------------------------------------
@pytest.mark.asyncio
async def test_ready_event_fields(bus: MagicMock) -> None:
ch = _ch(bus, 29901)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29901/", client_id="c1") as c:
r = await c.recv_ready()
assert r.event == "ready"
assert len(r.chat_id) == 36
assert r.client_id == "c1"
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_anonymous_client_gets_generated_id(bus: MagicMock) -> None:
ch = _ch(bus, 29902)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29902/", client_id="") as c:
r = await c.recv_ready()
assert r.client_id.startswith("anon-")
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_each_connection_unique_chat_id(bus: MagicMock) -> None:
ch = _ch(bus, 29903)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29903/", client_id="a") as c1:
async with WsTestClient("ws://127.0.0.1:29903/", client_id="b") as c2:
assert (await c1.recv_ready()).chat_id != (await c2.recv_ready()).chat_id
finally:
await ch.stop(); await t
# -- Inbound messages (client -> server) ----------------------------------
@pytest.mark.asyncio
async def test_plain_text(bus: MagicMock) -> None:
ch = _ch(bus, 29904)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29904/", client_id="p") as c:
await c.recv_ready()
await c.send_text("hello world")
await asyncio.sleep(0.1)
inbound = bus.publish_inbound.call_args[0][0]
assert inbound.content == "hello world"
assert inbound.sender_id == "p"
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_json_content_field(bus: MagicMock) -> None:
ch = _ch(bus, 29905)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29905/", client_id="j") as c:
await c.recv_ready()
await c.send_json({"content": "structured"})
await asyncio.sleep(0.1)
assert bus.publish_inbound.call_args[0][0].content == "structured"
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_json_text_and_message_fields(bus: MagicMock) -> None:
ch = _ch(bus, 29906)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29906/", client_id="x") as c:
await c.recv_ready()
await c.send_json({"text": "via text"})
await asyncio.sleep(0.1)
assert bus.publish_inbound.call_args[0][0].content == "via text"
await c.send_json({"message": "via message"})
await asyncio.sleep(0.1)
assert bus.publish_inbound.call_args[0][0].content == "via message"
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_empty_payload_ignored(bus: MagicMock) -> None:
ch = _ch(bus, 29907)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29907/", client_id="e") as c:
await c.recv_ready()
await c.send_text(" ")
await c.send_json({})
await asyncio.sleep(0.1)
bus.publish_inbound.assert_not_awaited()
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_messages_preserve_order(bus: MagicMock) -> None:
ch = _ch(bus, 29908)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29908/", client_id="o") as c:
await c.recv_ready()
for i in range(5):
await c.send_text(f"msg-{i}")
await asyncio.sleep(0.2)
contents = [call[0][0].content for call in bus.publish_inbound.call_args_list]
assert contents == [f"msg-{i}" for i in range(5)]
finally:
await ch.stop(); await t
# -- Outbound messages (server -> client) ---------------------------------
@pytest.mark.asyncio
async def test_server_send_message(bus: MagicMock) -> None:
ch = _ch(bus, 29909)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29909/", client_id="r") as c:
ready = await c.recv_ready()
await ch.send(OutboundMessage(
channel="websocket", chat_id=ready.chat_id, content="reply",
))
msg = await c.recv_message()
assert msg.text == "reply"
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_server_send_with_media_and_reply(bus: MagicMock) -> None:
ch = _ch(bus, 29910)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29910/", client_id="m") as c:
ready = await c.recv_ready()
await ch.send(OutboundMessage(
channel="websocket", chat_id=ready.chat_id, content="img",
media=["/tmp/a.png"], reply_to="m1",
))
msg = await c.recv_message()
assert msg.text == "img"
assert msg.media == ["/tmp/a.png"]
assert msg.reply_to == "m1"
finally:
await ch.stop(); await t
# -- Streaming ------------------------------------------------------------
@pytest.mark.asyncio
async def test_streaming_deltas_and_end(bus: MagicMock) -> None:
ch = _ch(bus, 29911, streaming=True)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29911/", client_id="s") as c:
cid = (await c.recv_ready()).chat_id
for part in ("Hello", " ", "world", "!"):
await ch.send_delta(cid, part, {"_stream_delta": True, "_stream_id": "s1"})
await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "s1"})
msgs = await c.collect_stream()
deltas = [m for m in msgs if m.event == "delta"]
assert "".join(d.text for d in deltas) == "Hello world!"
ends = [m for m in msgs if m.event == "stream_end"]
assert len(ends) == 1
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_interleaved_streams(bus: MagicMock) -> None:
ch = _ch(bus, 29912, streaming=True)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29912/", client_id="i") as c:
cid = (await c.recv_ready()).chat_id
await ch.send_delta(cid, "A1", {"_stream_delta": True, "_stream_id": "sa"})
await ch.send_delta(cid, "B1", {"_stream_delta": True, "_stream_id": "sb"})
await ch.send_delta(cid, "A2", {"_stream_delta": True, "_stream_id": "sa"})
await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "sa"})
await ch.send_delta(cid, "B2", {"_stream_delta": True, "_stream_id": "sb"})
await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "sb"})
msgs = await c.recv_n(6)
sa = "".join(m.text for m in msgs if m.event == "delta" and m.stream_id == "sa")
sb = "".join(m.text for m in msgs if m.event == "delta" and m.stream_id == "sb")
assert sa == "A1A2"
assert sb == "B1B2"
finally:
await ch.stop(); await t
# -- Multi-client ---------------------------------------------------------
@pytest.mark.asyncio
async def test_independent_sessions(bus: MagicMock) -> None:
ch = _ch(bus, 29913)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29913/", client_id="u1") as c1:
async with WsTestClient("ws://127.0.0.1:29913/", client_id="u2") as c2:
r1, r2 = await c1.recv_ready(), await c2.recv_ready()
await ch.send(OutboundMessage(
channel="websocket", chat_id=r1.chat_id, content="for-u1",
))
assert (await c1.recv_message()).text == "for-u1"
await ch.send(OutboundMessage(
channel="websocket", chat_id=r2.chat_id, content="for-u2",
))
assert (await c2.recv_message()).text == "for-u2"
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_disconnected_client_cleanup(bus: MagicMock) -> None:
ch = _ch(bus, 29914)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29914/", client_id="tmp") as c:
chat_id = (await c.recv_ready()).chat_id
# disconnected
await ch.send(OutboundMessage(
channel="websocket", chat_id=chat_id, content="orphan",
))
assert chat_id not in ch._connections
finally:
await ch.stop(); await t
# -- Authentication -------------------------------------------------------
@pytest.mark.asyncio
async def test_static_token_accepted(bus: MagicMock) -> None:
ch = _ch(bus, 29915, token="secret")
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29915/", client_id="a", token="secret") as c:
assert (await c.recv_ready()).client_id == "a"
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_static_token_rejected(bus: MagicMock) -> None:
ch = _ch(bus, 29916, token="correct")
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
async with WsTestClient("ws://127.0.0.1:29916/", client_id="b", token="wrong"):
pass
assert exc.value.response.status_code == 401
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_token_issue_full_flow(bus: MagicMock) -> None:
ch = _ch(bus, 29917, path="/ws",
tokenIssuePath="/auth/token", tokenIssueSecret="s",
websocketRequiresToken=True)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
# no secret -> 401
_, status = await issue_token(port=29917, issue_path="/auth/token")
assert status == 401
# with secret -> token
token = await issue_token_ok(port=29917, issue_path="/auth/token", secret="s")
# no token -> 401
with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="x"):
pass
assert exc.value.response.status_code == 401
# valid token -> ok
async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="ok", token=token) as c:
assert (await c.recv_ready()).client_id == "ok"
# reuse -> 401
with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="r", token=token):
pass
assert exc.value.response.status_code == 401
finally:
await ch.stop(); await t
# -- Path routing ---------------------------------------------------------
@pytest.mark.asyncio
async def test_custom_path(bus: MagicMock) -> None:
ch = _ch(bus, 29918, path="/my-chat")
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29918/my-chat", client_id="p") as c:
assert (await c.recv_ready()).event == "ready"
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_wrong_path_404(bus: MagicMock) -> None:
ch = _ch(bus, 29919, path="/ws")
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
async with WsTestClient("ws://127.0.0.1:29919/wrong", client_id="x"):
pass
assert exc.value.response.status_code == 404
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_trailing_slash_normalized(bus: MagicMock) -> None:
ch = _ch(bus, 29920, path="/ws")
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29920/ws/", client_id="s") as c:
assert (await c.recv_ready()).event == "ready"
finally:
await ch.stop(); await t
# -- Edge cases -----------------------------------------------------------
@pytest.mark.asyncio
async def test_large_message(bus: MagicMock) -> None:
ch = _ch(bus, 29921)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29921/", client_id="big") as c:
await c.recv_ready()
big = "x" * 100_000
await c.send_text(big)
await asyncio.sleep(0.2)
assert bus.publish_inbound.call_args[0][0].content == big
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_unicode_roundtrip(bus: MagicMock) -> None:
ch = _ch(bus, 29922)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29922/", client_id="u") as c:
ready = await c.recv_ready()
text = "你好世界 🌍 日本語テスト"
await c.send_text(text)
await asyncio.sleep(0.1)
assert bus.publish_inbound.call_args[0][0].content == text
await ch.send(OutboundMessage(
channel="websocket", chat_id=ready.chat_id, content=text,
))
assert (await c.recv_message()).text == text
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_rapid_fire(bus: MagicMock) -> None:
ch = _ch(bus, 29923)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29923/", client_id="r") as c:
ready = await c.recv_ready()
for i in range(50):
await c.send_text(f"in-{i}")
await asyncio.sleep(0.5)
assert bus.publish_inbound.await_count == 50
for i in range(50):
await ch.send(OutboundMessage(
channel="websocket", chat_id=ready.chat_id, content=f"out-{i}",
))
received = [(await c.recv_message()).text for _ in range(50)]
assert received == [f"out-{i}" for i in range(50)]
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_invalid_json_as_plain_text(bus: MagicMock) -> None:
ch = _ch(bus, 29924)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29924/", client_id="j") as c:
await c.recv_ready()
await c.send_text("{broken json")
await asyncio.sleep(0.1)
assert bus.publish_inbound.call_args[0][0].content == "{broken json"
finally:
await ch.stop(); await t

View File

@ -0,0 +1,227 @@
"""Lightweight WebSocket test client for integration testing the nanobot WebSocket channel.
Provides an async ``WsTestClient`` class and token-issuance helpers that
integration tests can import and use directly::
from ws_test_client import WsTestClient
async with WsTestClient("ws://127.0.0.1:8765/", client_id="t") as c:
ready = await c.recv_ready()
await c.send_text("hello")
msg = await c.recv_message()
"""
from __future__ import annotations
import asyncio
import json
from dataclasses import dataclass, field
from typing import Any
import httpx
import websockets
from websockets.asyncio.client import ClientConnection
@dataclass
class WsMessage:
"""A parsed message received from the WebSocket server."""
event: str
raw: dict[str, Any] = field(repr=False)
@property
def text(self) -> str | None:
return self.raw.get("text")
@property
def chat_id(self) -> str | None:
return self.raw.get("chat_id")
@property
def client_id(self) -> str | None:
return self.raw.get("client_id")
@property
def media(self) -> list[str] | None:
return self.raw.get("media")
@property
def reply_to(self) -> str | None:
return self.raw.get("reply_to")
@property
def stream_id(self) -> str | None:
return self.raw.get("stream_id")
def __eq__(self, other: object) -> bool:
if not isinstance(other, WsMessage):
return NotImplemented
return self.event == other.event and self.raw == other.raw
class WsTestClient:
"""Async WebSocket test client with helper methods for common operations.
Usage::
async with WsTestClient("ws://127.0.0.1:8765/", client_id="tester") as client:
ready = await client.recv_ready()
await client.send_text("hello")
msg = await client.recv_message(timeout=5.0)
"""
def __init__(
self,
uri: str,
*,
client_id: str = "test-client",
token: str = "",
extra_headers: dict[str, str] | None = None,
) -> None:
params: list[str] = []
if client_id:
params.append(f"client_id={client_id}")
if token:
params.append(f"token={token}")
sep = "&" if "?" in uri else "?"
self._uri = uri + sep + "&".join(params) if params else uri
self._extra_headers = extra_headers
self._ws: ClientConnection | None = None
async def connect(self) -> None:
self._ws = await websockets.connect(
self._uri,
additional_headers=self._extra_headers,
)
async def close(self) -> None:
if self._ws:
await self._ws.close()
self._ws = None
async def __aenter__(self) -> WsTestClient:
await self.connect()
return self
async def __aexit__(self, *args: Any) -> None:
await self.close()
@property
def ws(self) -> ClientConnection:
assert self._ws is not None, "Client is not connected"
return self._ws
# -- Receiving --------------------------------------------------------
async def recv_raw(self, timeout: float = 10.0) -> dict[str, Any]:
"""Receive and parse one raw JSON message with timeout."""
raw = await asyncio.wait_for(self.ws.recv(), timeout=timeout)
return json.loads(raw)
async def recv(self, timeout: float = 10.0) -> WsMessage:
"""Receive one message, returning a WsMessage wrapper."""
data = await self.recv_raw(timeout)
return WsMessage(event=data.get("event", ""), raw=data)
async def recv_ready(self, timeout: float = 5.0) -> WsMessage:
"""Receive and validate the 'ready' event."""
msg = await self.recv(timeout)
assert msg.event == "ready", f"Expected 'ready' event, got '{msg.event}'"
return msg
async def recv_message(self, timeout: float = 10.0) -> WsMessage:
"""Receive and validate a 'message' event."""
msg = await self.recv(timeout)
assert msg.event == "message", f"Expected 'message' event, got '{msg.event}'"
return msg
async def recv_delta(self, timeout: float = 10.0) -> WsMessage:
"""Receive and validate a 'delta' event."""
msg = await self.recv(timeout)
assert msg.event == "delta", f"Expected 'delta' event, got '{msg.event}'"
return msg
async def recv_stream_end(self, timeout: float = 10.0) -> WsMessage:
"""Receive and validate a 'stream_end' event."""
msg = await self.recv(timeout)
assert msg.event == "stream_end", f"Expected 'stream_end' event, got '{msg.event}'"
return msg
async def collect_stream(self, timeout: float = 10.0) -> list[WsMessage]:
"""Collect all deltas and the final stream_end into a list."""
messages: list[WsMessage] = []
while True:
msg = await self.recv(timeout)
messages.append(msg)
if msg.event == "stream_end":
break
return messages
async def recv_n(self, n: int, timeout: float = 10.0) -> list[WsMessage]:
"""Receive exactly *n* messages."""
return [await self.recv(timeout) for _ in range(n)]
# -- Sending ----------------------------------------------------------
async def send_text(self, text: str) -> None:
"""Send a plain text frame."""
await self.ws.send(text)
async def send_json(self, data: dict[str, Any]) -> None:
"""Send a JSON frame."""
await self.ws.send(json.dumps(data, ensure_ascii=False))
async def send_content(self, content: str) -> None:
"""Send content in the preferred JSON format ``{"content": ...}``."""
await self.send_json({"content": content})
# -- Connection introspection -----------------------------------------
@property
def closed(self) -> bool:
return self._ws is None or self._ws.closed
# -- Token issuance helpers -----------------------------------------------
async def issue_token(
host: str = "127.0.0.1",
port: int = 8765,
issue_path: str = "/auth/token",
secret: str = "",
) -> tuple[dict[str, Any] | None, int]:
"""Request a short-lived token from the token-issue HTTP endpoint.
Returns ``(parsed_json_or_None, status_code)``.
"""
url = f"http://{host}:{port}{issue_path}"
headers: dict[str, str] = {}
if secret:
headers["Authorization"] = f"Bearer {secret}"
loop = asyncio.get_running_loop()
resp = await loop.run_in_executor(
None, lambda: httpx.get(url, headers=headers, timeout=5.0)
)
try:
data = resp.json()
except Exception:
data = None
return data, resp.status_code
async def issue_token_ok(
host: str = "127.0.0.1",
port: int = 8765,
issue_path: str = "/auth/token",
secret: str = "",
) -> str:
"""Request a token, asserting success, and return the token string."""
(data, status) = await issue_token(host, port, issue_path, secret)
assert status == 200, f"Token issue failed with status {status}"
assert data is not None
token = data["token"]
assert token.startswith("nbwt_"), f"Unexpected token format: {token}"
return token