nanobot/tests/channels/test_websocket_channel.py
chengyongru 8f7ce9fef7 fix(websocket): harden security and robustness
- Use hmac.compare_digest for timing-safe static token comparison
- Add issued token capacity limit (_MAX_ISSUED_TOKENS=10000) with 429 response
- Use atomic pop in _take_issued_token_if_valid to eliminate TOCTOU window
- Enforce TLSv1.2 minimum version for SSL connections
- Extract _safe_send helper for consistent ConnectionClosed handling
- Move connection registration after ready send to prevent out-of-order delivery
- Add HTTP-level allow_from check and client_id truncation in process_request
- Make stop() idempotent with graceful shutdown error handling
- Normalize path via validator instead of leaving raw value
- Default websocket_requires_token to True for secure-by-default behavior
- Add integration tests and ws_test_client helper
- Refactor tests to use shared _ch factory and bus fixture
2026-04-09 15:56:34 +08:00

599 lines
20 KiB
Python

"""Unit and lightweight integration tests for the WebSocket channel."""
import asyncio
import functools
import json
import time
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import httpx
import pytest
import websockets
from websockets.exceptions import ConnectionClosed
from websockets.frames import Close
from nanobot.bus.events import OutboundMessage
from nanobot.channels.websocket import (
WebSocketChannel,
WebSocketConfig,
_issue_route_secret_matches,
_normalize_config_path,
_normalize_http_path,
_parse_inbound_payload,
_parse_query,
_parse_request_path,
)
# -- Shared helpers (aligned with test_websocket_integration.py) ---------------
_PORT = 29876


def _ch(bus: Any, **kw: Any) -> WebSocketChannel:
    """Build a WebSocketChannel with permissive test defaults.

    Defaults: enabled, ``allowFrom=["*"]``, bound to ``127.0.0.1:_PORT`` on
    ``/ws``, and ``websocketRequiresToken=False``. Any keyword argument
    overrides (or adds) the camelCase config key of the same name.
    """
    base_config: dict[str, Any] = dict(
        enabled=True,
        allowFrom=["*"],
        host="127.0.0.1",
        port=_PORT,
        path="/ws",
        websocketRequiresToken=False,
    )
    return WebSocketChannel({**base_config, **kw}, bus)
@pytest.fixture()
def bus() -> MagicMock:
    """Provide a mock message bus whose ``publish_inbound`` can be awaited."""
    fake_bus = MagicMock()
    fake_bus.publish_inbound = AsyncMock()
    return fake_bus
async def _http_get(url: str, headers: dict[str, str] | None = None) -> httpx.Response:
    """Perform a synchronous ``httpx.get`` on a worker thread.

    The websockets server under test runs on this same event loop, so a
    blocking request issued directly on the loop would stall the server it is
    talking to. The request uses a 5 second timeout.
    """
    request = functools.partial(
        httpx.get,
        url,
        headers=headers or {},
        timeout=5.0,
    )
    return await asyncio.to_thread(request)
def test_normalize_http_path_strips_trailing_slash_except_root() -> None:
    """Trailing slashes and query strings are dropped; the bare root survives."""
    cases = {
        "/chat/": "/chat",
        "/chat?x=1": "/chat",
        "/": "/",
    }
    for raw, want in cases.items():
        assert _normalize_http_path(raw) == want, raw
def test_parse_request_path_matches_normalize_and_query() -> None:
    """_parse_request_path agrees with the standalone path and query helpers."""
    raw = "/ws/?token=secret&client_id=u1"
    parsed_path, parsed_query = _parse_request_path(raw)
    assert parsed_path == _normalize_http_path(raw)
    assert parsed_query == _parse_query(raw)
def test_normalize_config_path_matches_request() -> None:
    """Configured paths normalize to the same form as incoming request paths."""
    assert _normalize_config_path("/ws/") == "/ws"
    assert _normalize_config_path("/") == "/"
    # The test name promises parity with request-path normalization; check it
    # directly so a configured path always matches the equivalent request path.
    for raw in ("/ws/", "/ws", "/"):
        assert _normalize_config_path(raw) == _normalize_http_path(raw), raw
def test_parse_query_extracts_token_and_client_id() -> None:
    """Query parameters are returned as lists of values per key."""
    params = _parse_query("/?token=secret&client_id=u1")
    for key, value in (("token", "secret"), ("client_id", "u1")):
        assert params.get(key) == [value]
@pytest.mark.parametrize(
    ("frame", "want"),
    [
        ("plain", "plain"),
        ('{"content": "hi"}', "hi"),
        ('{"text": "there"}', "there"),
        ('{"message": "x"}', "x"),
        (" ", None),
        ("{}", None),
    ],
)
def test_parse_inbound_payload(frame: str, want: str | None) -> None:
    """Plain text passes through; JSON objects yield content/text/message."""
    got = _parse_inbound_payload(frame)
    assert got == want
def test_parse_inbound_invalid_json_falls_back_to_raw_string() -> None:
    """Malformed JSON is treated as a plain-text frame, not as an error."""
    broken = "{not json"
    assert _parse_inbound_payload(broken) == broken
@pytest.mark.parametrize(
    ("frame", "want"),
    [
        ('{"content": ""}', None),  # empty string content
        ('{"content": 123}', None),  # non-string content
        ('{"content": " "}', None),  # whitespace-only content
        ('["hello"]', '["hello"]'),  # JSON array: not a dict, treated as plain text
        ('{"unknown_key": "val"}', None),  # unrecognized key
        ('{"content": null}', None),  # null content
    ],
)
def test_parse_inbound_payload_edge_cases(frame: str, want: str | None) -> None:
    """Unusable JSON objects yield None; non-object JSON is kept verbatim."""
    got = _parse_inbound_payload(frame)
    assert got == want
def test_web_socket_config_path_must_start_with_slash() -> None:
    """A relative ``path`` is rejected when the config is validated."""
    expected_msg = 'path must start with "/"'
    with pytest.raises(ValueError, match=expected_msg):
        WebSocketConfig(path="bad")
def test_ssl_context_requires_both_cert_and_key_files() -> None:
    """A cert file without a key file makes SSL context construction fail."""
    half_ssl_config = {
        "enabled": True,
        "allowFrom": ["*"],
        "sslCertfile": "/tmp/c.pem",
        "sslKeyfile": "",
    }
    channel = WebSocketChannel(half_ssl_config, MagicMock())
    with pytest.raises(ValueError, match="ssl_certfile and ssl_keyfile"):
        channel._build_ssl_context()
def test_default_config_includes_safe_bind_and_streaming() -> None:
    """Defaults: disabled, loopback bind, streaming on, open allowFrom, no issue route."""
    cfg = WebSocketChannel.default_config()
    assert cfg["enabled"] is False
    assert cfg["host"] == "127.0.0.1"
    assert cfg["streaming"] is True
    assert cfg["allowFrom"] == ["*"]
    # The token-issue route is opt-in: absent or empty by default.
    assert cfg.get("tokenIssuePath", "") == ""
def test_token_issue_path_must_differ_from_websocket_path() -> None:
    """The token-issue HTTP route may not reuse the WebSocket path."""
    clashing = {"path": "/ws", "token_issue_path": "/ws"}
    with pytest.raises(ValueError, match="token_issue_path must differ"):
        WebSocketConfig(**clashing)
def test_issue_route_secret_matches_bearer_and_header() -> None:
    """The issue-route secret is accepted via Bearer auth or X-Nanobot-Auth."""
    from websockets.datastructures import Headers

    secret = "my-secret"
    cases = [
        ([("Authorization", "Bearer my-secret")], True),
        ([("X-Nanobot-Auth", "my-secret")], True),
        ([("Authorization", "Bearer other")], False),
    ]
    for header_pairs, expected in cases:
        assert _issue_route_secret_matches(Headers(header_pairs), secret) is expected
def test_issue_route_secret_matches_empty_secret() -> None:
    """With an empty secret, any headers (or none) are accepted."""
    from websockets.datastructures import Headers

    for header_pairs in ([], [("Authorization", "Bearer anything")]):
        assert _issue_route_secret_matches(Headers(header_pairs), "") is True
@pytest.mark.asyncio
async def test_send_delivers_json_message_with_media_and_reply() -> None:
    """send() serializes text, reply_to and media into one ``message`` frame."""
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())
    fake_ws = AsyncMock()
    channel._connections["chat-1"] = fake_ws

    await channel.send(
        OutboundMessage(
            channel="websocket",
            chat_id="chat-1",
            content="hello",
            reply_to="m1",
            media=["/tmp/a.png"],
        )
    )

    fake_ws.send.assert_awaited_once()
    payload = json.loads(fake_ws.send.call_args.args[0])
    assert payload["event"] == "message"
    expected_fields = {"text": "hello", "reply_to": "m1", "media": ["/tmp/a.png"]}
    for key, value in expected_fields.items():
        assert payload[key] == value
@pytest.mark.asyncio
async def test_send_missing_connection_is_noop_without_error() -> None:
    """send() to a chat with no registered connection returns without raising."""
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())
    orphan = OutboundMessage(channel="websocket", chat_id="missing", content="x")
    await channel.send(orphan)
@pytest.mark.asyncio
async def test_send_removes_connection_on_connection_closed() -> None:
    """A ConnectionClosed raised during send() drops the stale connection entry."""
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())
    dead_ws = AsyncMock()
    # 1006 = abnormal closure, as seen when a peer vanishes without a close frame.
    dead_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True)
    channel._connections["chat-1"] = dead_ws

    await channel.send(OutboundMessage(channel="websocket", chat_id="chat-1", content="hello"))

    assert "chat-1" not in channel._connections
@pytest.mark.asyncio
async def test_send_delta_removes_connection_on_connection_closed() -> None:
    """A ConnectionClosed raised during send_delta() also drops the connection."""
    channel = WebSocketChannel(
        {"enabled": True, "allowFrom": ["*"], "streaming": True}, MagicMock()
    )
    dead_ws = AsyncMock()
    dead_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True)
    channel._connections["chat-1"] = dead_ws

    delta_meta = {"_stream_delta": True, "_stream_id": "s1"}
    await channel.send_delta("chat-1", "chunk", delta_meta)

    assert "chat-1" not in channel._connections
@pytest.mark.asyncio
async def test_send_delta_emits_delta_and_stream_end() -> None:
    """A delta and an end marker produce ``delta`` then ``stream_end`` frames."""
    channel = WebSocketChannel(
        {"enabled": True, "allowFrom": ["*"], "streaming": True}, MagicMock()
    )
    fake_ws = AsyncMock()
    channel._connections["chat-1"] = fake_ws

    await channel.send_delta("chat-1", "part", {"_stream_delta": True, "_stream_id": "sid"})
    await channel.send_delta("chat-1", "", {"_stream_end": True, "_stream_id": "sid"})

    assert fake_ws.send.await_count == 2
    frames = [json.loads(sent.args[0]) for sent in fake_ws.send.call_args_list]
    delta_frame, end_frame = frames[0], frames[1]
    assert delta_frame["event"] == "delta"
    assert delta_frame["text"] == "part"
    assert delta_frame["stream_id"] == "sid"
    assert end_frame["event"] == "stream_end"
    assert end_frame["stream_id"] == "sid"
@pytest.mark.asyncio
async def test_send_non_connection_closed_exception_is_raised() -> None:
    """Errors other than ConnectionClosed propagate out of send()."""
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())
    broken_ws = AsyncMock()
    broken_ws.send.side_effect = RuntimeError("unexpected")
    channel._connections["chat-1"] = broken_ws
    outbound = OutboundMessage(channel="websocket", chat_id="chat-1", content="hello")
    with pytest.raises(RuntimeError, match="unexpected"):
        await channel.send(outbound)
@pytest.mark.asyncio
async def test_send_delta_missing_connection_is_noop() -> None:
    """send_delta() to an unknown chat id returns quietly."""
    channel = WebSocketChannel(
        {"enabled": True, "allowFrom": ["*"], "streaming": True}, MagicMock()
    )
    # Must not raise: there is simply no connection to deliver to.
    await channel.send_delta(
        "nonexistent", "chunk", {"_stream_delta": True, "_stream_id": "s1"}
    )
@pytest.mark.asyncio
async def test_stop_is_idempotent() -> None:
    """stop() can be called repeatedly, even when start() never ran."""
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())
    for _ in range(2):
        await channel.stop()
@pytest.mark.asyncio
async def test_end_to_end_client_receives_ready_and_agent_sees_inbound(bus: MagicMock) -> None:
    """A real client gets a ``ready`` event and its frames reach the bus.

    Both a JSON ``{"content": ...}`` frame and a plain-text frame must be
    published as inbound messages attributed to the connecting client.
    """
    port = _PORT
    channel = _ch(bus, port=port)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    async def wait_for_publishes(count: int, timeout: float = 2.0) -> None:
        # Poll with a deadline instead of sleeping a fixed few milliseconds:
        # the server handles the frame on this same loop and may need longer
        # on a loaded CI machine. The assertions below report any shortfall.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while bus.publish_inbound.await_count < count and loop.time() < deadline:
            await asyncio.sleep(0.01)

    try:
        async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=tester") as client:
            ready = json.loads(await client.recv())
            assert ready["event"] == "ready"
            assert ready["client_id"] == "tester"
            chat_id = ready["chat_id"]

            await client.send(json.dumps({"content": "ping from client"}))
            await wait_for_publishes(1)
            bus.publish_inbound.assert_awaited()
            inbound = bus.publish_inbound.call_args[0][0]
            assert inbound.channel == "websocket"
            assert inbound.sender_id == "tester"
            assert inbound.chat_id == chat_id
            assert inbound.content == "ping from client"

            await client.send("plain text frame")
            await wait_for_publishes(2)
            assert bus.publish_inbound.await_count >= 2
            second = bus.publish_inbound.call_args_list[-1][0][0]
            assert second.content == "plain text frame"
    finally:
        await channel.stop()
        await server_task
@pytest.mark.asyncio
async def test_token_rejects_handshake_when_mismatch(bus: MagicMock) -> None:
    """A wrong ``?token=`` on a static-token channel fails the handshake with 401."""
    port = 29877
    channel = _ch(bus, port=port, path="/", token="secret")
    serving = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    bad_uri = f"ws://127.0.0.1:{port}/?token=wrong"
    try:
        with pytest.raises(websockets.exceptions.InvalidStatus) as rejected:
            async with websockets.connect(bad_uri):
                pass
        assert rejected.value.response.status_code == 401
    finally:
        await channel.stop()
        await serving
@pytest.mark.asyncio
async def test_wrong_path_returns_404(bus: MagicMock) -> None:
    """Connecting on a path other than the configured one yields HTTP 404."""
    port = 29878
    channel = _ch(bus, port=port)
    serving = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    stray_uri = f"ws://127.0.0.1:{port}/other"
    try:
        with pytest.raises(websockets.exceptions.InvalidStatus) as rejected:
            async with websockets.connect(stray_uri):
                pass
        assert rejected.value.response.status_code == 404
    finally:
        await channel.stop()
        await serving
def test_registry_discovers_websocket_channel() -> None:
    """The channel registry resolves ``websocket`` to a class with that name."""
    from nanobot.channels.registry import load_channel_class

    assert load_channel_class("websocket").name == "websocket"
@pytest.mark.asyncio
async def test_http_route_issues_token_then_websocket_requires_it(bus: MagicMock) -> None:
    """Issued tokens gate the WebSocket and are single-use.

    Flow: unauthenticated issue request -> 401; authenticated request -> 200
    with an ``nbwt_`` token; WebSocket without a token -> 401; with the token
    -> ready; reusing the same token -> 401.
    """
    port = 29879
    http_base = f"http://127.0.0.1:{port}"
    ws_base = f"ws://127.0.0.1:{port}"
    channel = _ch(
        bus,
        port=port,
        tokenIssuePath="/auth/token",
        tokenIssueSecret="route-secret",
        websocketRequiresToken=True,
    )
    serving = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        unauthorized = await _http_get(f"{http_base}/auth/token")
        assert unauthorized.status_code == 401

        granted = await _http_get(
            f"{http_base}/auth/token",
            headers={"Authorization": "Bearer route-secret"},
        )
        assert granted.status_code == 200
        token = granted.json()["token"]
        assert token.startswith("nbwt_")

        with pytest.raises(websockets.exceptions.InvalidStatus) as no_token:
            async with websockets.connect(f"{ws_base}/ws?client_id=x"):
                pass
        assert no_token.value.response.status_code == 401

        authed_uri = f"{ws_base}/ws?token={token}&client_id=caller"
        async with websockets.connect(authed_uri) as client:
            ready = json.loads(await client.recv())
            assert ready["event"] == "ready"
            assert ready["client_id"] == "caller"

        # The token was consumed by the first handshake; replay must fail.
        with pytest.raises(websockets.exceptions.InvalidStatus) as replay:
            async with websockets.connect(authed_uri):
                pass
        assert replay.value.response.status_code == 401
    finally:
        await channel.stop()
        await serving
@pytest.mark.asyncio
async def test_end_to_end_server_pushes_streaming_deltas_to_client(bus: MagicMock) -> None:
    """Deltas pushed via send_delta() arrive in order, followed by stream_end."""
    port = 29880
    channel = _ch(bus, port=port, streaming=True)
    serving = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=stream-tester") as client:
            chat_id = json.loads(await client.recv())["chat_id"]

            # Push every frame first, then read them back in order.
            pushes = [
                ("Hello ", {"_stream_delta": True, "_stream_id": "s1"}),
                ("world", {"_stream_delta": True, "_stream_id": "s1"}),
                ("", {"_stream_end": True, "_stream_id": "s1"}),
            ]
            for text, meta in pushes:
                await channel.send_delta(chat_id, text, meta)

            for expected_text in ("Hello ", "world"):
                frame = json.loads(await client.recv())
                assert frame["event"] == "delta"
                assert frame["text"] == expected_text
                assert frame["stream_id"] == "s1"

            closing = json.loads(await client.recv())
            assert closing["event"] == "stream_end"
            assert closing["stream_id"] == "s1"
    finally:
        await channel.stop()
        await serving
@pytest.mark.asyncio
async def test_token_issue_rejects_when_at_capacity(bus: MagicMock) -> None:
    """Once the issued-token store is full, the issue route answers 429."""
    port = 29881
    channel = _ch(bus, port=port, tokenIssuePath="/auth/token", tokenIssueSecret="s")
    serving = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        # Pre-fill the store with unexpired tokens up to the capacity limit.
        channel._issued_tokens = {
            f"nbwt_fill_{n}": time.monotonic() + 300
            for n in range(channel._MAX_ISSUED_TOKENS)
        }
        overflow = await _http_get(
            f"http://127.0.0.1:{port}/auth/token",
            headers={"Authorization": "Bearer s"},
        )
        assert overflow.status_code == 429
        assert "error" in overflow.json()
    finally:
        await channel.stop()
        await serving
@pytest.mark.asyncio
async def test_allow_from_rejects_unauthorized_client_id(bus: MagicMock) -> None:
    """A client_id outside ``allowFrom`` is refused with HTTP 403."""
    port = 29882
    channel = _ch(bus, port=port, allowFrom=["alice", "bob"])
    serving = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    intruder_uri = f"ws://127.0.0.1:{port}/ws?client_id=eve"
    try:
        with pytest.raises(websockets.exceptions.InvalidStatus) as refused:
            async with websockets.connect(intruder_uri):
                pass
        assert refused.value.response.status_code == 403
    finally:
        await channel.stop()
        await serving
@pytest.mark.asyncio
async def test_client_id_truncation(bus: MagicMock) -> None:
    """Oversized client_id values are cut to 128 characters."""
    port = 29883
    channel = _ch(bus, port=port)
    serving = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        oversized = "x" * 200
        async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id={oversized}") as client:
            echoed = json.loads(await client.recv())["client_id"]
            assert echoed == "x" * 128
            assert len(echoed) == 128
    finally:
        await channel.stop()
        await serving
@pytest.mark.asyncio
async def test_non_utf8_binary_frame_ignored(bus: MagicMock) -> None:
    """A binary frame that is not valid UTF-8 is dropped without closing the socket.

    A valid text frame is sent afterwards as a sentinel. Once the sentinel is
    published, we know the server has moved past the bad frame. This avoids
    relying on a fixed sleep, after which a not-awaited check could pass
    vacuously.
    """
    port = 29884
    channel = _ch(bus, port=port)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=bin-test") as client:
            await client.recv()  # consume ready
            # Send non-UTF-8 bytes, then a valid sentinel frame.
            await client.send(b"\xff\xfe\xfd")
            await client.send("after-binary")

            loop = asyncio.get_running_loop()
            deadline = loop.time() + 2.0
            while bus.publish_inbound.await_count < 1 and loop.time() < deadline:
                await asyncio.sleep(0.01)

            # Assumes the channel handles a connection's frames sequentially.
            # Only the sentinel may have been published, and the connection
            # must have survived the bad frame to deliver it.
            assert bus.publish_inbound.await_count == 1
            assert bus.publish_inbound.call_args[0][0].content == "after-binary"
    finally:
        await channel.stop()
        await server_task
@pytest.mark.asyncio
async def test_static_token_accepts_issued_token_as_fallback(bus: MagicMock) -> None:
    """With a static token configured, issued tokens are accepted as a fallback.

    The static token itself must keep working alongside the fallback.
    """
    port = 29885
    channel = _ch(
        bus, port=port,
        token="static-secret",
        tokenIssuePath="/auth/token",
        tokenIssueSecret="route-secret",
    )
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        # Get an issued token
        resp = await _http_get(
            f"http://127.0.0.1:{port}/auth/token",
            headers={"Authorization": "Bearer route-secret"},
        )
        assert resp.status_code == 200
        issued_token = resp.json()["token"]
        # Connect using issued token (not the static one)
        async with websockets.connect(f"ws://127.0.0.1:{port}/ws?token={issued_token}&client_id=caller") as client:
            ready = json.loads(await client.recv())
            assert ready["event"] == "ready"
        # Enabling the fallback must not break the primary static token.
        async with websockets.connect(f"ws://127.0.0.1:{port}/ws?token=static-secret&client_id=static-caller") as client:
            ready = json.loads(await client.recv())
            assert ready["event"] == "ready"
    finally:
        await channel.stop()
        await server_task
@pytest.mark.asyncio
async def test_allow_from_empty_list_denies_all(bus: MagicMock) -> None:
    """An empty ``allowFrom`` list refuses every client with HTTP 403."""
    port = 29886
    channel = _ch(bus, port=port, allowFrom=[])
    serving = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    any_uri = f"ws://127.0.0.1:{port}/ws?client_id=anyone"
    try:
        with pytest.raises(websockets.exceptions.InvalidStatus) as refused:
            async with websockets.connect(any_uri):
                pass
        assert refused.value.response.status_code == 403
    finally:
        await channel.stop()
        await serving
@pytest.mark.asyncio
async def test_websocket_requires_token_without_issue_path(bus: MagicMock) -> None:
    """Requiring a token with no static token or issue route rejects every client."""
    port = 29887
    channel = _ch(bus, port=port, websocketRequiresToken=True)
    serving = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        # Both a missing token and a wrong token must be refused with 401.
        for query in ("client_id=u", "client_id=u&token=wrong"):
            with pytest.raises(websockets.exceptions.InvalidStatus) as refused:
                async with websockets.connect(f"ws://127.0.0.1:{port}/ws?{query}"):
                    pass
            assert refused.value.response.status_code == 401, query
    finally:
        await channel.stop()
        await serving