mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-10 05:03:41 +00:00
fix(websocket): harden security and robustness
- Use hmac.compare_digest for timing-safe static token comparison - Add issued token capacity limit (_MAX_ISSUED_TOKENS=10000) with 429 response - Use atomic pop in _take_issued_token_if_valid to eliminate TOCTOU window - Enforce TLSv1.2 minimum version for SSL connections - Extract _safe_send helper for consistent ConnectionClosed handling - Move connection registration after ready send to prevent out-of-order delivery - Add HTTP-level allow_from check and client_id truncation in process_request - Make stop() idempotent with graceful shutdown error handling - Normalize path via validator instead of leaving raw value - Default websocket_requires_token to True for secure-by-default behavior - Add integration tests and ws_test_client helper - Refactor tests to use shared _ch factory and bus fixture
This commit is contained in:
parent
d327c19db0
commit
8f7ce9fef7
@ -65,7 +65,7 @@ class WebSocketConfig(Base):
|
||||
token_issue_path: str = ""
|
||||
token_issue_secret: str = ""
|
||||
token_ttl_s: int = Field(default=300, ge=30, le=86_400)
|
||||
websocket_requires_token: bool = False
|
||||
websocket_requires_token: bool = True
|
||||
allow_from: list[str] = Field(default_factory=lambda: ["*"])
|
||||
streaming: bool = True
|
||||
max_message_bytes: int = Field(default=1_048_576, ge=1024, le=16_777_216)
|
||||
@ -79,7 +79,7 @@ class WebSocketConfig(Base):
|
||||
def path_must_start_with_slash(cls, value: str) -> str:
|
||||
if not value.startswith("/"):
|
||||
raise ValueError('path must start with "/"')
|
||||
return value
|
||||
return _normalize_config_path(value)
|
||||
|
||||
@field_validator("token_issue_path")
|
||||
@classmethod
|
||||
@ -130,6 +130,12 @@ def _parse_query(path_with_query: str) -> dict[str, list[str]]:
|
||||
return _parse_request_path(path_with_query)[1]
|
||||
|
||||
|
||||
def _query_first(query: dict[str, list[str]], key: str) -> str | None:
|
||||
"""Return the first value for *key*, or None."""
|
||||
values = query.get(key)
|
||||
return values[0] if values else None
|
||||
|
||||
|
||||
def _parse_inbound_payload(raw: str) -> str | None:
|
||||
"""Parse a client frame into text; return None for empty or unrecognized content."""
|
||||
text = raw.strip()
|
||||
@ -197,9 +203,12 @@ class WebSocketChannel(BaseChannel):
|
||||
"websocket: ssl_certfile and ssl_keyfile must both be set for WSS, or both left empty"
|
||||
)
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
ctx.load_cert_chain(certfile=cert, keyfile=key)
|
||||
return ctx
|
||||
|
||||
_MAX_ISSUED_TOKENS = 10_000
|
||||
|
||||
def _purge_expired_issued_tokens(self) -> None:
|
||||
now = time.monotonic()
|
||||
for token_key, expiry in list(self._issued_tokens.items()):
|
||||
@ -207,17 +216,19 @@ class WebSocketChannel(BaseChannel):
|
||||
self._issued_tokens.pop(token_key, None)
|
||||
|
||||
def _take_issued_token_if_valid(self, token_value: str | None) -> bool:
|
||||
"""Validate and consume one issued token (single use per connection attempt)."""
|
||||
"""Validate and consume one issued token (single use per connection attempt).
|
||||
|
||||
Uses single-step pop to minimize the window between lookup and removal;
|
||||
safe under asyncio's single-threaded cooperative model.
|
||||
"""
|
||||
if not token_value:
|
||||
return False
|
||||
self._purge_expired_issued_tokens()
|
||||
expiry = self._issued_tokens.get(token_value)
|
||||
expiry = self._issued_tokens.pop(token_value, None)
|
||||
if expiry is None:
|
||||
return False
|
||||
if time.monotonic() > expiry:
|
||||
self._issued_tokens.pop(token_value, None)
|
||||
return False
|
||||
self._issued_tokens.pop(token_value, None)
|
||||
return True
|
||||
|
||||
def _handle_token_issue_http(self, connection: Any, request: Any) -> Any:
|
||||
@ -231,6 +242,12 @@ class WebSocketChannel(BaseChannel):
|
||||
"any client can obtain connection tokens — set token_issue_secret for production."
|
||||
)
|
||||
self._purge_expired_issued_tokens()
|
||||
if len(self._issued_tokens) >= self._MAX_ISSUED_TOKENS:
|
||||
logger.error(
|
||||
"websocket: too many outstanding issued tokens ({}), rejecting issuance",
|
||||
len(self._issued_tokens),
|
||||
)
|
||||
return _http_json_response({"error": "too many outstanding tokens"}, status=429)
|
||||
token_value = f"nbwt_{secrets.token_urlsafe(32)}"
|
||||
self._issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s)
|
||||
|
||||
@ -238,13 +255,12 @@ class WebSocketChannel(BaseChannel):
|
||||
{"token": token_value, "expires_in": self.config.token_ttl_s}
|
||||
)
|
||||
|
||||
def _authorize_websocket_handshake(self, connection: Any, request_path: str) -> Any:
|
||||
query = _parse_query(request_path)
|
||||
supplied = (query.get("token") or [None])[0]
|
||||
def _authorize_websocket_handshake(self, connection: Any, query: dict[str, list[str]]) -> Any:
|
||||
supplied = _query_first(query, "token")
|
||||
static_token = self.config.token.strip()
|
||||
|
||||
if static_token:
|
||||
if supplied == static_token:
|
||||
if supplied and hmac.compare_digest(supplied, static_token):
|
||||
return None
|
||||
if supplied and self._take_issued_token_if_valid(supplied):
|
||||
return None
|
||||
@ -279,7 +295,15 @@ class WebSocketChannel(BaseChannel):
|
||||
expected_ws = self._expected_path()
|
||||
if got != expected_ws:
|
||||
return connection.respond(404, "Not Found")
|
||||
return self._authorize_websocket_handshake(connection, request.path)
|
||||
# Early reject before WebSocket upgrade to avoid unnecessary overhead;
|
||||
# _handle_message() performs a second check as defense-in-depth.
|
||||
query = _parse_query(request.path)
|
||||
client_id = _query_first(query, "client_id") or ""
|
||||
if len(client_id) > 128:
|
||||
client_id = client_id[:128]
|
||||
if not self.is_allowed(client_id):
|
||||
return connection.respond(403, "Forbidden")
|
||||
return self._authorize_websocket_handshake(connection, query)
|
||||
|
||||
async def handler(connection: ServerConnection) -> None:
|
||||
await self._connection_loop(connection)
|
||||
@ -321,26 +345,30 @@ class WebSocketChannel(BaseChannel):
|
||||
request = connection.request
|
||||
path_part = request.path if request else "/"
|
||||
_, query = _parse_request_path(path_part)
|
||||
client_id_raw = (query.get("client_id") or [None])[0]
|
||||
client_id_raw = _query_first(query, "client_id")
|
||||
client_id = client_id_raw.strip() if client_id_raw else ""
|
||||
if not client_id:
|
||||
client_id = f"anon-{uuid.uuid4().hex[:12]}"
|
||||
elif len(client_id) > 128:
|
||||
logger.warning("websocket: client_id too long ({} chars), truncating", len(client_id))
|
||||
client_id = client_id[:128]
|
||||
|
||||
chat_id = str(uuid.uuid4())
|
||||
self._connections[chat_id] = connection
|
||||
|
||||
await connection.send(
|
||||
json.dumps(
|
||||
{
|
||||
"event": "ready",
|
||||
"chat_id": chat_id,
|
||||
"client_id": client_id,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
await connection.send(
|
||||
json.dumps(
|
||||
{
|
||||
"event": "ready",
|
||||
"chat_id": chat_id,
|
||||
"client_id": client_id,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
# Register only after ready is successfully sent to avoid out-of-order sends
|
||||
self._connections[chat_id] = connection
|
||||
|
||||
async for raw in connection:
|
||||
if isinstance(raw, bytes):
|
||||
try:
|
||||
@ -363,15 +391,34 @@ class WebSocketChannel(BaseChannel):
|
||||
self._connections.pop(chat_id, None)
|
||||
|
||||
async def stop(self) -> None:
|
||||
if not self._running:
|
||||
return
|
||||
self._running = False
|
||||
if self._stop_event:
|
||||
self._stop_event.set()
|
||||
if self._server_task:
|
||||
await self._server_task
|
||||
try:
|
||||
await self._server_task
|
||||
except Exception as e:
|
||||
logger.warning("websocket: server task error during shutdown: {}", e)
|
||||
self._server_task = None
|
||||
self._connections.clear()
|
||||
self._issued_tokens.clear()
|
||||
|
||||
async def _safe_send(self, chat_id: str, raw: str, *, label: str = "") -> None:
|
||||
"""Send a raw frame, cleaning up dead connections on ConnectionClosed."""
|
||||
connection = self._connections.get(chat_id)
|
||||
if connection is None:
|
||||
return
|
||||
try:
|
||||
await connection.send(raw)
|
||||
except ConnectionClosed:
|
||||
self._connections.pop(chat_id, None)
|
||||
logger.warning("websocket{}connection gone for chat_id={}", label, chat_id)
|
||||
except Exception as e:
|
||||
logger.error("websocket{}send failed: {}", label, e)
|
||||
raise
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
connection = self._connections.get(msg.chat_id)
|
||||
if connection is None:
|
||||
@ -386,14 +433,7 @@ class WebSocketChannel(BaseChannel):
|
||||
if msg.reply_to:
|
||||
payload["reply_to"] = msg.reply_to
|
||||
raw = json.dumps(payload, ensure_ascii=False)
|
||||
try:
|
||||
await connection.send(raw)
|
||||
except ConnectionClosed:
|
||||
self._connections.pop(msg.chat_id, None)
|
||||
logger.warning("websocket: connection gone for chat_id={}", msg.chat_id)
|
||||
except Exception as e:
|
||||
logger.error("websocket send failed: {}", e)
|
||||
raise
|
||||
await self._safe_send(msg.chat_id, raw, label=" ")
|
||||
|
||||
async def send_delta(
|
||||
self,
|
||||
@ -401,27 +441,17 @@ class WebSocketChannel(BaseChannel):
|
||||
delta: str,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
connection = self._connections.get(chat_id)
|
||||
if connection is None:
|
||||
if self._connections.get(chat_id) is None:
|
||||
return
|
||||
meta = metadata or {}
|
||||
if meta.get("_stream_end"):
|
||||
body: dict[str, Any] = {"event": "stream_end"}
|
||||
if meta.get("_stream_id") is not None:
|
||||
body["stream_id"] = meta["_stream_id"]
|
||||
else:
|
||||
body = {
|
||||
"event": "delta",
|
||||
"text": delta,
|
||||
}
|
||||
if meta.get("_stream_id") is not None:
|
||||
body["stream_id"] = meta["_stream_id"]
|
||||
if meta.get("_stream_id") is not None:
|
||||
body["stream_id"] = meta["_stream_id"]
|
||||
raw = json.dumps(body, ensure_ascii=False)
|
||||
try:
|
||||
await connection.send(raw)
|
||||
except ConnectionClosed:
|
||||
self._connections.pop(chat_id, None)
|
||||
logger.warning("websocket: stream connection gone for chat_id={}", chat_id)
|
||||
except Exception as e:
|
||||
logger.error("websocket stream send failed: {}", e)
|
||||
raise
|
||||
await self._safe_send(chat_id, raw, label=" stream ")
|
||||
|
||||
@ -3,11 +3,15 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import websockets
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
from websockets.frames import Close
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.channels.websocket import (
|
||||
@ -21,6 +25,30 @@ from nanobot.channels.websocket import (
|
||||
_parse_request_path,
|
||||
)
|
||||
|
||||
# -- Shared helpers (aligned with test_websocket_integration.py) ---------------
|
||||
|
||||
_PORT = 29876
|
||||
|
||||
|
||||
def _ch(bus: Any, **kw: Any) -> WebSocketChannel:
|
||||
cfg: dict[str, Any] = {
|
||||
"enabled": True,
|
||||
"allowFrom": ["*"],
|
||||
"host": "127.0.0.1",
|
||||
"port": _PORT,
|
||||
"path": "/ws",
|
||||
"websocketRequiresToken": False,
|
||||
}
|
||||
cfg.update(kw)
|
||||
return WebSocketChannel(cfg, bus)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def bus() -> MagicMock:
|
||||
b = MagicMock()
|
||||
b.publish_inbound = AsyncMock()
|
||||
return b
|
||||
|
||||
|
||||
async def _http_get(url: str, headers: dict[str, str] | None = None) -> httpx.Response:
|
||||
"""Run GET in a thread to avoid blocking the asyncio loop shared with websockets."""
|
||||
@ -71,6 +99,21 @@ def test_parse_inbound_invalid_json_falls_back_to_raw_string() -> None:
|
||||
assert _parse_inbound_payload("{not json") == "{not json"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("raw", "expected"),
|
||||
[
|
||||
('{"content": ""}', None), # empty string content
|
||||
('{"content": 123}', None), # non-string content
|
||||
('{"content": " "}', None), # whitespace-only content
|
||||
('["hello"]', '["hello"]'), # JSON array: not a dict, treated as plain text
|
||||
('{"unknown_key": "val"}', None), # unrecognized key
|
||||
('{"content": null}', None), # null content
|
||||
],
|
||||
)
|
||||
def test_parse_inbound_payload_edge_cases(raw: str, expected: str | None) -> None:
|
||||
assert _parse_inbound_payload(raw) == expected
|
||||
|
||||
|
||||
def test_web_socket_config_path_must_start_with_slash() -> None:
|
||||
with pytest.raises(ValueError, match='path must start with "/"'):
|
||||
WebSocketConfig(path="bad")
|
||||
@ -112,6 +155,14 @@ def test_issue_route_secret_matches_bearer_and_header() -> None:
|
||||
assert _issue_route_secret_matches(wrong, secret) is False
|
||||
|
||||
|
||||
def test_issue_route_secret_matches_empty_secret() -> None:
|
||||
from websockets.datastructures import Headers
|
||||
|
||||
# Empty secret always returns True regardless of headers
|
||||
assert _issue_route_secret_matches(Headers([]), "") is True
|
||||
assert _issue_route_secret_matches(Headers([("Authorization", "Bearer anything")]), "") is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delivers_json_message_with_media_and_reply() -> None:
|
||||
bus = MagicMock()
|
||||
@ -144,6 +195,33 @@ async def test_send_missing_connection_is_noop_without_error() -> None:
|
||||
await channel.send(msg)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_removes_connection_on_connection_closed() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True)
|
||||
channel._connections["chat-1"] = mock_ws
|
||||
|
||||
msg = OutboundMessage(channel="websocket", chat_id="chat-1", content="hello")
|
||||
await channel.send(msg)
|
||||
|
||||
assert "chat-1" not in channel._connections
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_removes_connection_on_connection_closed() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus)
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True)
|
||||
channel._connections["chat-1"] = mock_ws
|
||||
|
||||
await channel.send_delta("chat-1", "chunk", {"_stream_delta": True, "_stream_id": "s1"})
|
||||
|
||||
assert "chat-1" not in channel._connections
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_emits_delta_and_stream_end() -> None:
|
||||
bus = MagicMock()
|
||||
@ -165,20 +243,39 @@ async def test_send_delta_emits_delta_and_stream_end() -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end_client_receives_ready_and_agent_sees_inbound() -> None:
|
||||
async def test_send_non_connection_closed_exception_is_raised() -> None:
|
||||
bus = MagicMock()
|
||||
bus.publish_inbound = AsyncMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.send.side_effect = RuntimeError("unexpected")
|
||||
channel._connections["chat-1"] = mock_ws
|
||||
|
||||
msg = OutboundMessage(channel="websocket", chat_id="chat-1", content="hello")
|
||||
with pytest.raises(RuntimeError, match="unexpected"):
|
||||
await channel.send(msg)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_missing_connection_is_noop() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus)
|
||||
# No exception, no error — just a no-op
|
||||
await channel.send_delta("nonexistent", "chunk", {"_stream_delta": True, "_stream_id": "s1"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_is_idempotent() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
|
||||
# stop() before start() should not raise
|
||||
await channel.stop()
|
||||
await channel.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end_client_receives_ready_and_agent_sees_inbound(bus: MagicMock) -> None:
|
||||
port = 29876
|
||||
channel = WebSocketChannel(
|
||||
{
|
||||
"enabled": True,
|
||||
"allowFrom": ["*"],
|
||||
"host": "127.0.0.1",
|
||||
"port": port,
|
||||
"path": "/ws",
|
||||
},
|
||||
bus,
|
||||
)
|
||||
channel = _ch(bus, port=port)
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
@ -212,20 +309,9 @@ async def test_end_to_end_client_receives_ready_and_agent_sees_inbound() -> None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_rejects_handshake_when_mismatch() -> None:
|
||||
bus = MagicMock()
|
||||
async def test_token_rejects_handshake_when_mismatch(bus: MagicMock) -> None:
|
||||
port = 29877
|
||||
channel = WebSocketChannel(
|
||||
{
|
||||
"enabled": True,
|
||||
"allowFrom": ["*"],
|
||||
"host": "127.0.0.1",
|
||||
"port": port,
|
||||
"path": "/",
|
||||
"token": "secret",
|
||||
},
|
||||
bus,
|
||||
)
|
||||
channel = _ch(bus, port=port, path="/", token="secret")
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
@ -241,19 +327,9 @@ async def test_token_rejects_handshake_when_mismatch() -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wrong_path_returns_404() -> None:
|
||||
bus = MagicMock()
|
||||
async def test_wrong_path_returns_404(bus: MagicMock) -> None:
|
||||
port = 29878
|
||||
channel = WebSocketChannel(
|
||||
{
|
||||
"enabled": True,
|
||||
"allowFrom": ["*"],
|
||||
"host": "127.0.0.1",
|
||||
"port": port,
|
||||
"path": "/ws",
|
||||
},
|
||||
bus,
|
||||
)
|
||||
channel = _ch(bus, port=port)
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
@ -276,22 +352,13 @@ def test_registry_discovers_websocket_channel() -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_route_issues_token_then_websocket_requires_it() -> None:
|
||||
bus = MagicMock()
|
||||
bus.publish_inbound = AsyncMock()
|
||||
async def test_http_route_issues_token_then_websocket_requires_it(bus: MagicMock) -> None:
|
||||
port = 29879
|
||||
channel = WebSocketChannel(
|
||||
{
|
||||
"enabled": True,
|
||||
"allowFrom": ["*"],
|
||||
"host": "127.0.0.1",
|
||||
"port": port,
|
||||
"path": "/ws",
|
||||
"tokenIssuePath": "/auth/token",
|
||||
"tokenIssueSecret": "route-secret",
|
||||
"websocketRequiresToken": True,
|
||||
},
|
||||
bus,
|
||||
channel = _ch(
|
||||
bus, port=port,
|
||||
tokenIssuePath="/auth/token",
|
||||
tokenIssueSecret="route-secret",
|
||||
websocketRequiresToken=True,
|
||||
)
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
@ -327,3 +394,205 @@ async def test_http_route_issues_token_then_websocket_requires_it() -> None:
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end_server_pushes_streaming_deltas_to_client(bus: MagicMock) -> None:
|
||||
port = 29880
|
||||
channel = _ch(bus, port=port, streaming=True)
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=stream-tester") as client:
|
||||
ready_raw = await client.recv()
|
||||
ready = json.loads(ready_raw)
|
||||
chat_id = ready["chat_id"]
|
||||
|
||||
# Server pushes deltas directly
|
||||
await channel.send_delta(
|
||||
chat_id, "Hello ", {"_stream_delta": True, "_stream_id": "s1"}
|
||||
)
|
||||
await channel.send_delta(
|
||||
chat_id, "world", {"_stream_delta": True, "_stream_id": "s1"}
|
||||
)
|
||||
await channel.send_delta(
|
||||
chat_id, "", {"_stream_end": True, "_stream_id": "s1"}
|
||||
)
|
||||
|
||||
delta1 = json.loads(await client.recv())
|
||||
assert delta1["event"] == "delta"
|
||||
assert delta1["text"] == "Hello "
|
||||
assert delta1["stream_id"] == "s1"
|
||||
|
||||
delta2 = json.loads(await client.recv())
|
||||
assert delta2["event"] == "delta"
|
||||
assert delta2["text"] == "world"
|
||||
assert delta2["stream_id"] == "s1"
|
||||
|
||||
end = json.loads(await client.recv())
|
||||
assert end["event"] == "stream_end"
|
||||
assert end["stream_id"] == "s1"
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_issue_rejects_when_at_capacity(bus: MagicMock) -> None:
|
||||
port = 29881
|
||||
channel = _ch(bus, port=port, tokenIssuePath="/auth/token", tokenIssueSecret="s")
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
# Fill issued tokens to capacity
|
||||
channel._issued_tokens = {
|
||||
f"nbwt_fill_{i}": time.monotonic() + 300 for i in range(channel._MAX_ISSUED_TOKENS)
|
||||
}
|
||||
|
||||
resp = await _http_get(
|
||||
f"http://127.0.0.1:{port}/auth/token",
|
||||
headers={"Authorization": "Bearer s"},
|
||||
)
|
||||
assert resp.status_code == 429
|
||||
data = resp.json()
|
||||
assert "error" in data
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allow_from_rejects_unauthorized_client_id(bus: MagicMock) -> None:
|
||||
port = 29882
|
||||
channel = _ch(bus, port=port, allowFrom=["alice", "bob"])
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=eve"):
|
||||
pass
|
||||
assert exc_info.value.response.status_code == 403
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_id_truncation(bus: MagicMock) -> None:
|
||||
port = 29883
|
||||
channel = _ch(bus, port=port)
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
long_id = "x" * 200
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id={long_id}") as client:
|
||||
ready = json.loads(await client.recv())
|
||||
assert ready["client_id"] == "x" * 128
|
||||
assert len(ready["client_id"]) == 128
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_utf8_binary_frame_ignored(bus: MagicMock) -> None:
|
||||
port = 29884
|
||||
channel = _ch(bus, port=port)
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=bin-test") as client:
|
||||
await client.recv() # consume ready
|
||||
# Send non-UTF-8 bytes
|
||||
await client.send(b"\xff\xfe\xfd")
|
||||
await asyncio.sleep(0.05)
|
||||
# publish_inbound should NOT have been called
|
||||
bus.publish_inbound.assert_not_awaited()
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_static_token_accepts_issued_token_as_fallback(bus: MagicMock) -> None:
|
||||
port = 29885
|
||||
channel = _ch(
|
||||
bus, port=port,
|
||||
token="static-secret",
|
||||
tokenIssuePath="/auth/token",
|
||||
tokenIssueSecret="route-secret",
|
||||
)
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
# Get an issued token
|
||||
resp = await _http_get(
|
||||
f"http://127.0.0.1:{port}/auth/token",
|
||||
headers={"Authorization": "Bearer route-secret"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
issued_token = resp.json()["token"]
|
||||
|
||||
# Connect using issued token (not the static one)
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?token={issued_token}&client_id=caller") as client:
|
||||
ready = json.loads(await client.recv())
|
||||
assert ready["event"] == "ready"
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allow_from_empty_list_denies_all(bus: MagicMock) -> None:
|
||||
port = 29886
|
||||
channel = _ch(bus, port=port, allowFrom=[])
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=anyone"):
|
||||
pass
|
||||
assert exc_info.value.response.status_code == 403
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_requires_token_without_issue_path(bus: MagicMock) -> None:
|
||||
"""When websocket_requires_token is True but no token or issue path configured, all connections are rejected."""
|
||||
port = 29887
|
||||
channel = _ch(bus, port=port, websocketRequiresToken=True)
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
# No token at all → 401
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=u"):
|
||||
pass
|
||||
assert exc_info.value.response.status_code == 401
|
||||
|
||||
# Wrong token → 401
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=u&token=wrong"):
|
||||
pass
|
||||
assert exc_info.value.response.status_code == 401
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
477
tests/channels/test_websocket_integration.py
Normal file
477
tests/channels/test_websocket_integration.py
Normal file
@ -0,0 +1,477 @@
|
||||
"""Integration tests for the WebSocket channel using WsTestClient.
|
||||
|
||||
Complements the unit/lightweight tests in test_websocket_channel.py by covering
|
||||
multi-client scenarios, edge cases, and realistic usage patterns.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import websockets
|
||||
|
||||
from nanobot.channels.websocket import WebSocketChannel
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from ws_test_client import WsTestClient, issue_token, issue_token_ok
|
||||
|
||||
|
||||
def _ch(bus: Any, port: int, **kw: Any) -> WebSocketChannel:
|
||||
cfg: dict[str, Any] = {
|
||||
"enabled": True,
|
||||
"allowFrom": ["*"],
|
||||
"host": "127.0.0.1",
|
||||
"port": port,
|
||||
"path": "/",
|
||||
"websocketRequiresToken": False,
|
||||
}
|
||||
cfg.update(kw)
|
||||
return WebSocketChannel(cfg, bus)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def bus() -> MagicMock:
|
||||
b = MagicMock()
|
||||
b.publish_inbound = AsyncMock()
|
||||
return b
|
||||
|
||||
|
||||
# -- Connection basics ----------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ready_event_fields(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29901)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29901/", client_id="c1") as c:
|
||||
r = await c.recv_ready()
|
||||
assert r.event == "ready"
|
||||
assert len(r.chat_id) == 36
|
||||
assert r.client_id == "c1"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anonymous_client_gets_generated_id(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29902)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29902/", client_id="") as c:
|
||||
r = await c.recv_ready()
|
||||
assert r.client_id.startswith("anon-")
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_each_connection_unique_chat_id(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29903)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29903/", client_id="a") as c1:
|
||||
async with WsTestClient("ws://127.0.0.1:29903/", client_id="b") as c2:
|
||||
assert (await c1.recv_ready()).chat_id != (await c2.recv_ready()).chat_id
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
# -- Inbound messages (client -> server) ----------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plain_text(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29904)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29904/", client_id="p") as c:
|
||||
await c.recv_ready()
|
||||
await c.send_text("hello world")
|
||||
await asyncio.sleep(0.1)
|
||||
inbound = bus.publish_inbound.call_args[0][0]
|
||||
assert inbound.content == "hello world"
|
||||
assert inbound.sender_id == "p"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_json_content_field(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29905)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29905/", client_id="j") as c:
|
||||
await c.recv_ready()
|
||||
await c.send_json({"content": "structured"})
|
||||
await asyncio.sleep(0.1)
|
||||
assert bus.publish_inbound.call_args[0][0].content == "structured"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_json_text_and_message_fields(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29906)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29906/", client_id="x") as c:
|
||||
await c.recv_ready()
|
||||
await c.send_json({"text": "via text"})
|
||||
await asyncio.sleep(0.1)
|
||||
assert bus.publish_inbound.call_args[0][0].content == "via text"
|
||||
await c.send_json({"message": "via message"})
|
||||
await asyncio.sleep(0.1)
|
||||
assert bus.publish_inbound.call_args[0][0].content == "via message"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_payload_ignored(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29907)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29907/", client_id="e") as c:
|
||||
await c.recv_ready()
|
||||
await c.send_text(" ")
|
||||
await c.send_json({})
|
||||
await asyncio.sleep(0.1)
|
||||
bus.publish_inbound.assert_not_awaited()
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_preserve_order(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29908)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29908/", client_id="o") as c:
|
||||
await c.recv_ready()
|
||||
for i in range(5):
|
||||
await c.send_text(f"msg-{i}")
|
||||
await asyncio.sleep(0.2)
|
||||
contents = [call[0][0].content for call in bus.publish_inbound.call_args_list]
|
||||
assert contents == [f"msg-{i}" for i in range(5)]
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
# -- Outbound messages (server -> client) ---------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_server_send_message(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29909)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29909/", client_id="r") as c:
|
||||
ready = await c.recv_ready()
|
||||
await ch.send(OutboundMessage(
|
||||
channel="websocket", chat_id=ready.chat_id, content="reply",
|
||||
))
|
||||
msg = await c.recv_message()
|
||||
assert msg.text == "reply"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_server_send_with_media_and_reply(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29910)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29910/", client_id="m") as c:
|
||||
ready = await c.recv_ready()
|
||||
await ch.send(OutboundMessage(
|
||||
channel="websocket", chat_id=ready.chat_id, content="img",
|
||||
media=["/tmp/a.png"], reply_to="m1",
|
||||
))
|
||||
msg = await c.recv_message()
|
||||
assert msg.text == "img"
|
||||
assert msg.media == ["/tmp/a.png"]
|
||||
assert msg.reply_to == "m1"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
# -- Streaming ------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_deltas_and_end(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29911, streaming=True)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29911/", client_id="s") as c:
|
||||
cid = (await c.recv_ready()).chat_id
|
||||
for part in ("Hello", " ", "world", "!"):
|
||||
await ch.send_delta(cid, part, {"_stream_delta": True, "_stream_id": "s1"})
|
||||
await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "s1"})
|
||||
|
||||
msgs = await c.collect_stream()
|
||||
deltas = [m for m in msgs if m.event == "delta"]
|
||||
assert "".join(d.text for d in deltas) == "Hello world!"
|
||||
ends = [m for m in msgs if m.event == "stream_end"]
|
||||
assert len(ends) == 1
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_interleaved_streams(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29912, streaming=True)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29912/", client_id="i") as c:
|
||||
cid = (await c.recv_ready()).chat_id
|
||||
await ch.send_delta(cid, "A1", {"_stream_delta": True, "_stream_id": "sa"})
|
||||
await ch.send_delta(cid, "B1", {"_stream_delta": True, "_stream_id": "sb"})
|
||||
await ch.send_delta(cid, "A2", {"_stream_delta": True, "_stream_id": "sa"})
|
||||
await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "sa"})
|
||||
await ch.send_delta(cid, "B2", {"_stream_delta": True, "_stream_id": "sb"})
|
||||
await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "sb"})
|
||||
|
||||
msgs = await c.recv_n(6)
|
||||
sa = "".join(m.text for m in msgs if m.event == "delta" and m.stream_id == "sa")
|
||||
sb = "".join(m.text for m in msgs if m.event == "delta" and m.stream_id == "sb")
|
||||
assert sa == "A1A2"
|
||||
assert sb == "B1B2"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
# -- Multi-client ---------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_independent_sessions(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29913)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29913/", client_id="u1") as c1:
|
||||
async with WsTestClient("ws://127.0.0.1:29913/", client_id="u2") as c2:
|
||||
r1, r2 = await c1.recv_ready(), await c2.recv_ready()
|
||||
await ch.send(OutboundMessage(
|
||||
channel="websocket", chat_id=r1.chat_id, content="for-u1",
|
||||
))
|
||||
assert (await c1.recv_message()).text == "for-u1"
|
||||
await ch.send(OutboundMessage(
|
||||
channel="websocket", chat_id=r2.chat_id, content="for-u2",
|
||||
))
|
||||
assert (await c2.recv_message()).text == "for-u2"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnected_client_cleanup(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29914)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29914/", client_id="tmp") as c:
|
||||
chat_id = (await c.recv_ready()).chat_id
|
||||
# disconnected
|
||||
await ch.send(OutboundMessage(
|
||||
channel="websocket", chat_id=chat_id, content="orphan",
|
||||
))
|
||||
assert chat_id not in ch._connections
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
# -- Authentication -------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_static_token_accepted(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29915, token="secret")
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29915/", client_id="a", token="secret") as c:
|
||||
assert (await c.recv_ready()).client_id == "a"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_static_token_rejected(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29916, token="correct")
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
|
||||
async with WsTestClient("ws://127.0.0.1:29916/", client_id="b", token="wrong"):
|
||||
pass
|
||||
assert exc.value.response.status_code == 401
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_issue_full_flow(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29917, path="/ws",
|
||||
tokenIssuePath="/auth/token", tokenIssueSecret="s",
|
||||
websocketRequiresToken=True)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
# no secret -> 401
|
||||
_, status = await issue_token(port=29917, issue_path="/auth/token")
|
||||
assert status == 401
|
||||
|
||||
# with secret -> token
|
||||
token = await issue_token_ok(port=29917, issue_path="/auth/token", secret="s")
|
||||
|
||||
# no token -> 401
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
|
||||
async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="x"):
|
||||
pass
|
||||
assert exc.value.response.status_code == 401
|
||||
|
||||
# valid token -> ok
|
||||
async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="ok", token=token) as c:
|
||||
assert (await c.recv_ready()).client_id == "ok"
|
||||
|
||||
# reuse -> 401
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
|
||||
async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="r", token=token):
|
||||
pass
|
||||
assert exc.value.response.status_code == 401
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
# -- Path routing ---------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_path(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29918, path="/my-chat")
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29918/my-chat", client_id="p") as c:
|
||||
assert (await c.recv_ready()).event == "ready"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wrong_path_404(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29919, path="/ws")
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
|
||||
async with WsTestClient("ws://127.0.0.1:29919/wrong", client_id="x"):
|
||||
pass
|
||||
assert exc.value.response.status_code == 404
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trailing_slash_normalized(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29920, path="/ws")
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29920/ws/", client_id="s") as c:
|
||||
assert (await c.recv_ready()).event == "ready"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
# -- Edge cases -----------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_message(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29921)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29921/", client_id="big") as c:
|
||||
await c.recv_ready()
|
||||
big = "x" * 100_000
|
||||
await c.send_text(big)
|
||||
await asyncio.sleep(0.2)
|
||||
assert bus.publish_inbound.call_args[0][0].content == big
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unicode_roundtrip(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29922)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29922/", client_id="u") as c:
|
||||
ready = await c.recv_ready()
|
||||
text = "你好世界 🌍 日本語テスト"
|
||||
await c.send_text(text)
|
||||
await asyncio.sleep(0.1)
|
||||
assert bus.publish_inbound.call_args[0][0].content == text
|
||||
await ch.send(OutboundMessage(
|
||||
channel="websocket", chat_id=ready.chat_id, content=text,
|
||||
))
|
||||
assert (await c.recv_message()).text == text
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rapid_fire(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29923)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29923/", client_id="r") as c:
|
||||
ready = await c.recv_ready()
|
||||
for i in range(50):
|
||||
await c.send_text(f"in-{i}")
|
||||
await asyncio.sleep(0.5)
|
||||
assert bus.publish_inbound.await_count == 50
|
||||
for i in range(50):
|
||||
await ch.send(OutboundMessage(
|
||||
channel="websocket", chat_id=ready.chat_id, content=f"out-{i}",
|
||||
))
|
||||
received = [(await c.recv_message()).text for _ in range(50)]
|
||||
assert received == [f"out-{i}" for i in range(50)]
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_json_as_plain_text(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29924)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29924/", client_id="j") as c:
|
||||
await c.recv_ready()
|
||||
await c.send_text("{broken json")
|
||||
await asyncio.sleep(0.1)
|
||||
assert bus.publish_inbound.call_args[0][0].content == "{broken json"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
227
tests/channels/ws_test_client.py
Normal file
227
tests/channels/ws_test_client.py
Normal file
@ -0,0 +1,227 @@
|
||||
"""Lightweight WebSocket test client for integration testing the nanobot WebSocket channel.
|
||||
|
||||
Provides an async ``WsTestClient`` class and token-issuance helpers that
|
||||
integration tests can import and use directly::
|
||||
|
||||
from ws_test_client import WsTestClient
|
||||
|
||||
async with WsTestClient("ws://127.0.0.1:8765/", client_id="t") as c:
|
||||
ready = await c.recv_ready()
|
||||
await c.send_text("hello")
|
||||
msg = await c.recv_message()
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import websockets
|
||||
from websockets.asyncio.client import ClientConnection
|
||||
|
||||
|
||||
@dataclass
|
||||
class WsMessage:
|
||||
"""A parsed message received from the WebSocket server."""
|
||||
|
||||
event: str
|
||||
raw: dict[str, Any] = field(repr=False)
|
||||
|
||||
@property
|
||||
def text(self) -> str | None:
|
||||
return self.raw.get("text")
|
||||
|
||||
@property
|
||||
def chat_id(self) -> str | None:
|
||||
return self.raw.get("chat_id")
|
||||
|
||||
@property
|
||||
def client_id(self) -> str | None:
|
||||
return self.raw.get("client_id")
|
||||
|
||||
@property
|
||||
def media(self) -> list[str] | None:
|
||||
return self.raw.get("media")
|
||||
|
||||
@property
|
||||
def reply_to(self) -> str | None:
|
||||
return self.raw.get("reply_to")
|
||||
|
||||
@property
|
||||
def stream_id(self) -> str | None:
|
||||
return self.raw.get("stream_id")
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, WsMessage):
|
||||
return NotImplemented
|
||||
return self.event == other.event and self.raw == other.raw
|
||||
|
||||
|
||||
class WsTestClient:
|
||||
"""Async WebSocket test client with helper methods for common operations.
|
||||
|
||||
Usage::
|
||||
|
||||
async with WsTestClient("ws://127.0.0.1:8765/", client_id="tester") as client:
|
||||
ready = await client.recv_ready()
|
||||
await client.send_text("hello")
|
||||
msg = await client.recv_message(timeout=5.0)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
*,
|
||||
client_id: str = "test-client",
|
||||
token: str = "",
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
params: list[str] = []
|
||||
if client_id:
|
||||
params.append(f"client_id={client_id}")
|
||||
if token:
|
||||
params.append(f"token={token}")
|
||||
sep = "&" if "?" in uri else "?"
|
||||
self._uri = uri + sep + "&".join(params) if params else uri
|
||||
self._extra_headers = extra_headers
|
||||
self._ws: ClientConnection | None = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
self._ws = await websockets.connect(
|
||||
self._uri,
|
||||
additional_headers=self._extra_headers,
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
|
||||
async def __aenter__(self) -> WsTestClient:
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args: Any) -> None:
|
||||
await self.close()
|
||||
|
||||
@property
|
||||
def ws(self) -> ClientConnection:
|
||||
assert self._ws is not None, "Client is not connected"
|
||||
return self._ws
|
||||
|
||||
# -- Receiving --------------------------------------------------------
|
||||
|
||||
async def recv_raw(self, timeout: float = 10.0) -> dict[str, Any]:
|
||||
"""Receive and parse one raw JSON message with timeout."""
|
||||
raw = await asyncio.wait_for(self.ws.recv(), timeout=timeout)
|
||||
return json.loads(raw)
|
||||
|
||||
async def recv(self, timeout: float = 10.0) -> WsMessage:
|
||||
"""Receive one message, returning a WsMessage wrapper."""
|
||||
data = await self.recv_raw(timeout)
|
||||
return WsMessage(event=data.get("event", ""), raw=data)
|
||||
|
||||
async def recv_ready(self, timeout: float = 5.0) -> WsMessage:
|
||||
"""Receive and validate the 'ready' event."""
|
||||
msg = await self.recv(timeout)
|
||||
assert msg.event == "ready", f"Expected 'ready' event, got '{msg.event}'"
|
||||
return msg
|
||||
|
||||
async def recv_message(self, timeout: float = 10.0) -> WsMessage:
|
||||
"""Receive and validate a 'message' event."""
|
||||
msg = await self.recv(timeout)
|
||||
assert msg.event == "message", f"Expected 'message' event, got '{msg.event}'"
|
||||
return msg
|
||||
|
||||
async def recv_delta(self, timeout: float = 10.0) -> WsMessage:
|
||||
"""Receive and validate a 'delta' event."""
|
||||
msg = await self.recv(timeout)
|
||||
assert msg.event == "delta", f"Expected 'delta' event, got '{msg.event}'"
|
||||
return msg
|
||||
|
||||
async def recv_stream_end(self, timeout: float = 10.0) -> WsMessage:
|
||||
"""Receive and validate a 'stream_end' event."""
|
||||
msg = await self.recv(timeout)
|
||||
assert msg.event == "stream_end", f"Expected 'stream_end' event, got '{msg.event}'"
|
||||
return msg
|
||||
|
||||
async def collect_stream(self, timeout: float = 10.0) -> list[WsMessage]:
|
||||
"""Collect all deltas and the final stream_end into a list."""
|
||||
messages: list[WsMessage] = []
|
||||
while True:
|
||||
msg = await self.recv(timeout)
|
||||
messages.append(msg)
|
||||
if msg.event == "stream_end":
|
||||
break
|
||||
return messages
|
||||
|
||||
async def recv_n(self, n: int, timeout: float = 10.0) -> list[WsMessage]:
|
||||
"""Receive exactly *n* messages."""
|
||||
return [await self.recv(timeout) for _ in range(n)]
|
||||
|
||||
# -- Sending ----------------------------------------------------------
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
"""Send a plain text frame."""
|
||||
await self.ws.send(text)
|
||||
|
||||
async def send_json(self, data: dict[str, Any]) -> None:
|
||||
"""Send a JSON frame."""
|
||||
await self.ws.send(json.dumps(data, ensure_ascii=False))
|
||||
|
||||
async def send_content(self, content: str) -> None:
|
||||
"""Send content in the preferred JSON format ``{"content": ...}``."""
|
||||
await self.send_json({"content": content})
|
||||
|
||||
# -- Connection introspection -----------------------------------------
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._ws is None or self._ws.closed
|
||||
|
||||
|
||||
# -- Token issuance helpers -----------------------------------------------
|
||||
|
||||
|
||||
async def issue_token(
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 8765,
|
||||
issue_path: str = "/auth/token",
|
||||
secret: str = "",
|
||||
) -> tuple[dict[str, Any] | None, int]:
|
||||
"""Request a short-lived token from the token-issue HTTP endpoint.
|
||||
|
||||
Returns ``(parsed_json_or_None, status_code)``.
|
||||
"""
|
||||
url = f"http://{host}:{port}{issue_path}"
|
||||
headers: dict[str, str] = {}
|
||||
if secret:
|
||||
headers["Authorization"] = f"Bearer {secret}"
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
resp = await loop.run_in_executor(
|
||||
None, lambda: httpx.get(url, headers=headers, timeout=5.0)
|
||||
)
|
||||
try:
|
||||
data = resp.json()
|
||||
except Exception:
|
||||
data = None
|
||||
return data, resp.status_code
|
||||
|
||||
|
||||
async def issue_token_ok(
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 8765,
|
||||
issue_path: str = "/auth/token",
|
||||
secret: str = "",
|
||||
) -> str:
|
||||
"""Request a token, asserting success, and return the token string."""
|
||||
(data, status) = await issue_token(host, port, issue_path, secret)
|
||||
assert status == 200, f"Token issue failed with status {status}"
|
||||
assert data is not None
|
||||
token = data["token"]
|
||||
assert token.startswith("nbwt_"), f"Unexpected token format: {token}"
|
||||
return token
|
||||
Loading…
x
Reference in New Issue
Block a user