"""Unit and lightweight integration tests for the WebSocket channel.""" import asyncio import functools import json import time from typing import Any from unittest.mock import AsyncMock, MagicMock import httpx import pytest import websockets from websockets.exceptions import ConnectionClosed from websockets.frames import Close from nanobot.bus.events import OutboundMessage from nanobot.channels.websocket import ( WebSocketChannel, WebSocketConfig, _issue_route_secret_matches, _normalize_config_path, _normalize_http_path, _parse_inbound_payload, _parse_query, _parse_request_path, ) # -- Shared helpers (aligned with test_websocket_integration.py) --------------- _PORT = 29876 def _ch(bus: Any, **kw: Any) -> WebSocketChannel: cfg: dict[str, Any] = { "enabled": True, "allowFrom": ["*"], "host": "127.0.0.1", "port": _PORT, "path": "/ws", "websocketRequiresToken": False, } cfg.update(kw) return WebSocketChannel(cfg, bus) @pytest.fixture() def bus() -> MagicMock: b = MagicMock() b.publish_inbound = AsyncMock() return b async def _http_get(url: str, headers: dict[str, str] | None = None) -> httpx.Response: """Run GET in a thread to avoid blocking the asyncio loop shared with websockets.""" return await asyncio.to_thread( functools.partial(httpx.get, url, headers=headers or {}, timeout=5.0) ) def test_normalize_http_path_strips_trailing_slash_except_root() -> None: assert _normalize_http_path("/chat/") == "/chat" assert _normalize_http_path("/chat?x=1") == "/chat" assert _normalize_http_path("/") == "/" def test_parse_request_path_matches_normalize_and_query() -> None: path, query = _parse_request_path("/ws/?token=secret&client_id=u1") assert path == _normalize_http_path("/ws/?token=secret&client_id=u1") assert query == _parse_query("/ws/?token=secret&client_id=u1") def test_normalize_config_path_matches_request() -> None: assert _normalize_config_path("/ws/") == "/ws" assert _normalize_config_path("/") == "/" def test_parse_query_extracts_token_and_client_id() -> None: query = _parse_query("/?token=secret&client_id=u1") assert query.get("token") == ["secret"] assert query.get("client_id") == ["u1"] @pytest.mark.parametrize( ("raw", "expected"), [ ("plain", "plain"), ('{"content": "hi"}', "hi"), ('{"text": "there"}', "there"), ('{"message": "x"}', "x"), (" ", None), ("{}", None), ], ) def test_parse_inbound_payload(raw: str, expected: str | None) -> None: assert _parse_inbound_payload(raw) == expected def test_parse_inbound_invalid_json_falls_back_to_raw_string() -> None: assert _parse_inbound_payload("{not json") == "{not json" @pytest.mark.parametrize( ("raw", "expected"), [ ('{"content": ""}', None), # empty string content ('{"content": 123}', None), # non-string content ('{"content": " "}', None), # whitespace-only content ('["hello"]', '["hello"]'), # JSON array: not a dict, treated as plain text ('{"unknown_key": "val"}', None), # unrecognized key ('{"content": null}', None), # null content ], ) def test_parse_inbound_payload_edge_cases(raw: str, expected: str | None) -> None: assert _parse_inbound_payload(raw) == expected def test_web_socket_config_path_must_start_with_slash() -> None: with pytest.raises(ValueError, match='path must start with "/"'): WebSocketConfig(path="bad") def test_ssl_context_requires_both_cert_and_key_files() -> None: bus = MagicMock() channel = WebSocketChannel( {"enabled": True, "allowFrom": ["*"], "sslCertfile": "/tmp/c.pem", "sslKeyfile": ""}, bus, ) with pytest.raises(ValueError, match="ssl_certfile and ssl_keyfile"): channel._build_ssl_context() def test_default_config_includes_safe_bind_and_streaming() -> None: defaults = WebSocketChannel.default_config() assert defaults["enabled"] is False assert defaults["host"] == "127.0.0.1" assert defaults["streaming"] is True assert defaults["allowFrom"] == ["*"] assert defaults.get("tokenIssuePath", "") == "" def test_token_issue_path_must_differ_from_websocket_path() -> None: with pytest.raises(ValueError, match="token_issue_path must differ"): WebSocketConfig(path="/ws", token_issue_path="/ws") def test_issue_route_secret_matches_bearer_and_header() -> None: from websockets.datastructures import Headers secret = "my-secret" bearer_headers = Headers([("Authorization", "Bearer my-secret")]) assert _issue_route_secret_matches(bearer_headers, secret) is True x_headers = Headers([("X-Nanobot-Auth", "my-secret")]) assert _issue_route_secret_matches(x_headers, secret) is True wrong = Headers([("Authorization", "Bearer other")]) assert _issue_route_secret_matches(wrong, secret) is False def test_issue_route_secret_matches_empty_secret() -> None: from websockets.datastructures import Headers # Empty secret always returns True regardless of headers assert _issue_route_secret_matches(Headers([]), "") is True assert _issue_route_secret_matches(Headers([("Authorization", "Bearer anything")]), "") is True @pytest.mark.asyncio async def test_send_delivers_json_message_with_media_and_reply() -> None: bus = MagicMock() channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) mock_ws = AsyncMock() channel._connections["chat-1"] = mock_ws msg = OutboundMessage( channel="websocket", chat_id="chat-1", content="hello", reply_to="m1", media=["/tmp/a.png"], ) await channel.send(msg) mock_ws.send.assert_awaited_once() payload = json.loads(mock_ws.send.call_args[0][0]) assert payload["event"] == "message" assert payload["text"] == "hello" assert payload["reply_to"] == "m1" assert payload["media"] == ["/tmp/a.png"] @pytest.mark.asyncio async def test_send_missing_connection_is_noop_without_error() -> None: bus = MagicMock() channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) msg = OutboundMessage(channel="websocket", chat_id="missing", content="x") await channel.send(msg) @pytest.mark.asyncio async def test_send_removes_connection_on_connection_closed() -> None: bus = MagicMock() channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) mock_ws = AsyncMock() mock_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True) channel._connections["chat-1"] = mock_ws msg = OutboundMessage(channel="websocket", chat_id="chat-1", content="hello") await channel.send(msg) assert "chat-1" not in channel._connections @pytest.mark.asyncio async def test_send_delta_removes_connection_on_connection_closed() -> None: bus = MagicMock() channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus) mock_ws = AsyncMock() mock_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True) channel._connections["chat-1"] = mock_ws await channel.send_delta("chat-1", "chunk", {"_stream_delta": True, "_stream_id": "s1"}) assert "chat-1" not in channel._connections @pytest.mark.asyncio async def test_send_delta_emits_delta_and_stream_end() -> None: bus = MagicMock() channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus) mock_ws = AsyncMock() channel._connections["chat-1"] = mock_ws await channel.send_delta("chat-1", "part", {"_stream_delta": True, "_stream_id": "sid"}) await channel.send_delta("chat-1", "", {"_stream_end": True, "_stream_id": "sid"}) assert mock_ws.send.await_count == 2 first = json.loads(mock_ws.send.call_args_list[0][0][0]) second = json.loads(mock_ws.send.call_args_list[1][0][0]) assert first["event"] == "delta" assert first["text"] == "part" assert first["stream_id"] == "sid" assert second["event"] == "stream_end" assert second["stream_id"] == "sid" @pytest.mark.asyncio async def test_send_non_connection_closed_exception_is_raised() -> None: bus = MagicMock() channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) mock_ws = AsyncMock() mock_ws.send.side_effect = RuntimeError("unexpected") channel._connections["chat-1"] = mock_ws msg = OutboundMessage(channel="websocket", chat_id="chat-1", content="hello") with pytest.raises(RuntimeError, match="unexpected"): await channel.send(msg) @pytest.mark.asyncio async def test_send_delta_missing_connection_is_noop() -> None: bus = MagicMock() channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus) # No exception, no error — just a no-op await channel.send_delta("nonexistent", "chunk", {"_stream_delta": True, "_stream_id": "s1"}) @pytest.mark.asyncio async def test_stop_is_idempotent() -> None: bus = MagicMock() channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) # stop() before start() should not raise await channel.stop() await channel.stop() @pytest.mark.asyncio async def test_end_to_end_client_receives_ready_and_agent_sees_inbound(bus: MagicMock) -> None: port = 29876 channel = _ch(bus, port=port) server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) try: async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=tester") as client: ready_raw = await client.recv() ready = json.loads(ready_raw) assert ready["event"] == "ready" assert ready["client_id"] == "tester" chat_id = ready["chat_id"] await client.send(json.dumps({"content": "ping from client"})) await asyncio.sleep(0.08) bus.publish_inbound.assert_awaited() inbound = bus.publish_inbound.call_args[0][0] assert inbound.channel == "websocket" assert inbound.sender_id == "tester" assert inbound.chat_id == chat_id assert inbound.content == "ping from client" await client.send("plain text frame") await asyncio.sleep(0.08) assert bus.publish_inbound.await_count >= 2 second = [c[0][0] for c in bus.publish_inbound.call_args_list][-1] assert second.content == "plain text frame" finally: await channel.stop() await server_task @pytest.mark.asyncio async def test_token_rejects_handshake_when_mismatch(bus: MagicMock) -> None: port = 29877 channel = _ch(bus, port=port, path="/", token="secret") server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) try: with pytest.raises(websockets.exceptions.InvalidStatus) as excinfo: async with websockets.connect(f"ws://127.0.0.1:{port}/?token=wrong"): pass assert excinfo.value.response.status_code == 401 finally: await channel.stop() await server_task @pytest.mark.asyncio async def test_wrong_path_returns_404(bus: MagicMock) -> None: port = 29878 channel = _ch(bus, port=port) server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) try: with pytest.raises(websockets.exceptions.InvalidStatus) as excinfo: async with websockets.connect(f"ws://127.0.0.1:{port}/other"): pass assert excinfo.value.response.status_code == 404 finally: await channel.stop() await server_task def test_registry_discovers_websocket_channel() -> None: from nanobot.channels.registry import load_channel_class cls = load_channel_class("websocket") assert cls.name == "websocket" @pytest.mark.asyncio async def test_http_route_issues_token_then_websocket_requires_it(bus: MagicMock) -> None: port = 29879 channel = _ch( bus, port=port, tokenIssuePath="/auth/token", tokenIssueSecret="route-secret", websocketRequiresToken=True, ) server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) try: deny = await _http_get(f"http://127.0.0.1:{port}/auth/token") assert deny.status_code == 401 issue = await _http_get( f"http://127.0.0.1:{port}/auth/token", headers={"Authorization": "Bearer route-secret"}, ) assert issue.status_code == 200 token = issue.json()["token"] assert token.startswith("nbwt_") with pytest.raises(websockets.exceptions.InvalidStatus) as missing_token: async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=x"): pass assert missing_token.value.response.status_code == 401 uri = f"ws://127.0.0.1:{port}/ws?token={token}&client_id=caller" async with websockets.connect(uri) as client: ready = json.loads(await client.recv()) assert ready["event"] == "ready" assert ready["client_id"] == "caller" with pytest.raises(websockets.exceptions.InvalidStatus) as reuse: async with websockets.connect(uri): pass assert reuse.value.response.status_code == 401 finally: await channel.stop() await server_task @pytest.mark.asyncio async def test_end_to_end_server_pushes_streaming_deltas_to_client(bus: MagicMock) -> None: port = 29880 channel = _ch(bus, port=port, streaming=True) server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) try: async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=stream-tester") as client: ready_raw = await client.recv() ready = json.loads(ready_raw) chat_id = ready["chat_id"] # Server pushes deltas directly await channel.send_delta( chat_id, "Hello ", {"_stream_delta": True, "_stream_id": "s1"} ) await channel.send_delta( chat_id, "world", {"_stream_delta": True, "_stream_id": "s1"} ) await channel.send_delta( chat_id, "", {"_stream_end": True, "_stream_id": "s1"} ) delta1 = json.loads(await client.recv()) assert delta1["event"] == "delta" assert delta1["text"] == "Hello " assert delta1["stream_id"] == "s1" delta2 = json.loads(await client.recv()) assert delta2["event"] == "delta" assert delta2["text"] == "world" assert delta2["stream_id"] == "s1" end = json.loads(await client.recv()) assert end["event"] == "stream_end" assert end["stream_id"] == "s1" finally: await channel.stop() await server_task @pytest.mark.asyncio async def test_token_issue_rejects_when_at_capacity(bus: MagicMock) -> None: port = 29881 channel = _ch(bus, port=port, tokenIssuePath="/auth/token", tokenIssueSecret="s") server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) try: # Fill issued tokens to capacity channel._issued_tokens = { f"nbwt_fill_{i}": time.monotonic() + 300 for i in range(channel._MAX_ISSUED_TOKENS) } resp = await _http_get( f"http://127.0.0.1:{port}/auth/token", headers={"Authorization": "Bearer s"}, ) assert resp.status_code == 429 data = resp.json() assert "error" in data finally: await channel.stop() await server_task @pytest.mark.asyncio async def test_allow_from_rejects_unauthorized_client_id(bus: MagicMock) -> None: port = 29882 channel = _ch(bus, port=port, allowFrom=["alice", "bob"]) server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) try: with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info: async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=eve"): pass assert exc_info.value.response.status_code == 403 finally: await channel.stop() await server_task @pytest.mark.asyncio async def test_client_id_truncation(bus: MagicMock) -> None: port = 29883 channel = _ch(bus, port=port) server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) try: long_id = "x" * 200 async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id={long_id}") as client: ready = json.loads(await client.recv()) assert ready["client_id"] == "x" * 128 assert len(ready["client_id"]) == 128 finally: await channel.stop() await server_task @pytest.mark.asyncio async def test_non_utf8_binary_frame_ignored(bus: MagicMock) -> None: port = 29884 channel = _ch(bus, port=port) server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) try: async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=bin-test") as client: await client.recv() # consume ready # Send non-UTF-8 bytes await client.send(b"\xff\xfe\xfd") await asyncio.sleep(0.05) # publish_inbound should NOT have been called bus.publish_inbound.assert_not_awaited() finally: await channel.stop() await server_task @pytest.mark.asyncio async def test_static_token_accepts_issued_token_as_fallback(bus: MagicMock) -> None: port = 29885 channel = _ch( bus, port=port, token="static-secret", tokenIssuePath="/auth/token", tokenIssueSecret="route-secret", ) server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) try: # Get an issued token resp = await _http_get( f"http://127.0.0.1:{port}/auth/token", headers={"Authorization": "Bearer route-secret"}, ) assert resp.status_code == 200 issued_token = resp.json()["token"] # Connect using issued token (not the static one) async with websockets.connect(f"ws://127.0.0.1:{port}/ws?token={issued_token}&client_id=caller") as client: ready = json.loads(await client.recv()) assert ready["event"] == "ready" finally: await channel.stop() await server_task @pytest.mark.asyncio async def test_allow_from_empty_list_denies_all(bus: MagicMock) -> None: port = 29886 channel = _ch(bus, port=port, allowFrom=[]) server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) try: with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info: async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=anyone"): pass assert exc_info.value.response.status_code == 403 finally: await channel.stop() await server_task @pytest.mark.asyncio async def test_websocket_requires_token_without_issue_path(bus: MagicMock) -> None: """When websocket_requires_token is True but no token or issue path configured, all connections are rejected.""" port = 29887 channel = _ch(bus, port=port, websocketRequiresToken=True) server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) try: # No token at all → 401 with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info: async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=u"): pass assert exc_info.value.response.status_code == 401 # Wrong token → 401 with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info: async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=u&token=wrong"): pass assert exc_info.value.response.status_code == 401 finally: await channel.stop() await server_task