mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-10 05:03:41 +00:00
- Use hmac.compare_digest for timing-safe static token comparison
- Add issued token capacity limit (_MAX_ISSUED_TOKENS=10000) with 429 response
- Use atomic pop in _take_issued_token_if_valid to eliminate TOCTOU window
- Enforce TLSv1.2 minimum version for SSL connections
- Extract _safe_send helper for consistent ConnectionClosed handling
- Move connection registration after ready send to prevent out-of-order delivery
- Add HTTP-level allow_from check and client_id truncation in process_request
- Make stop() idempotent with graceful shutdown error handling
- Normalize path via validator instead of leaving raw value
- Default websocket_requires_token to True for secure-by-default behavior
- Add integration tests and ws_test_client helper
- Refactor tests to use shared _ch factory and bus fixture
599 lines
20 KiB
Python
599 lines
20 KiB
Python
"""Unit and lightweight integration tests for the WebSocket channel."""
|
|
|
|
import asyncio
|
|
import functools
|
|
import json
|
|
import time
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import httpx
|
|
import pytest
|
|
import websockets
|
|
from websockets.exceptions import ConnectionClosed
|
|
from websockets.frames import Close
|
|
|
|
from nanobot.bus.events import OutboundMessage
|
|
from nanobot.channels.websocket import (
|
|
WebSocketChannel,
|
|
WebSocketConfig,
|
|
_issue_route_secret_matches,
|
|
_normalize_config_path,
|
|
_normalize_http_path,
|
|
_parse_inbound_payload,
|
|
_parse_query,
|
|
_parse_request_path,
|
|
)
|
|
|
|
# -- Shared helpers (aligned with test_websocket_integration.py) ---------------
|
|
|
|
# Default port used by the ``_ch`` factory when a test does not pass ``port``;
# the server-backed tests below each pass an explicit port of their own.
_PORT = 29876
|
|
|
|
|
|
def _ch(bus: Any, **kw: Any) -> WebSocketChannel:
    """Build a ``WebSocketChannel`` from the shared test defaults.

    Extra keyword arguments override or extend the default config entries
    before the channel is constructed with the given ``bus``.
    """
    defaults: dict[str, Any] = {
        "enabled": True,
        "allowFrom": ["*"],
        "host": "127.0.0.1",
        "port": _PORT,
        "path": "/ws",
        "websocketRequiresToken": False,
    }
    return WebSocketChannel({**defaults, **kw}, bus)
|
|
|
|
|
|
@pytest.fixture()
def bus() -> MagicMock:
    """Provide a mock message bus whose ``publish_inbound`` is awaitable."""
    mock_bus = MagicMock()
    mock_bus.publish_inbound = AsyncMock()
    return mock_bus
|
|
|
|
|
|
async def _http_get(url: str, headers: dict[str, str] | None = None) -> httpx.Response:
    """Perform a blocking ``httpx.get`` in a worker thread.

    Keeping the request off the event loop avoids blocking the asyncio loop
    that is shared with the websockets server under test.
    """
    request = functools.partial(httpx.get, url, headers=headers or {}, timeout=5.0)
    return await asyncio.to_thread(request)
|
|
|
|
|
|
def test_normalize_http_path_strips_trailing_slash_except_root() -> None:
    """Trailing slash and query string are dropped; the bare root path is kept."""
    cases = {
        "/chat/": "/chat",
        "/chat?x=1": "/chat",
        "/": "/",
    }
    for raw, wanted in cases.items():
        assert _normalize_http_path(raw) == wanted
|
|
|
|
|
|
def test_parse_request_path_matches_normalize_and_query() -> None:
    """``_parse_request_path`` agrees with the two single-purpose helpers."""
    target = "/ws/?token=secret&client_id=u1"
    parsed_path, parsed_query = _parse_request_path(target)
    assert parsed_path == _normalize_http_path(target)
    assert parsed_query == _parse_query(target)
|
|
|
|
|
|
def test_normalize_config_path_matches_request() -> None:
    """Configured paths lose a trailing slash, except for the root path."""
    for configured, wanted in (("/ws/", "/ws"), ("/", "/")):
        assert _normalize_config_path(configured) == wanted
|
|
|
|
|
|
def test_parse_query_extracts_token_and_client_id() -> None:
    """Each query parameter is returned as a one-element list of values."""
    params = _parse_query("/?token=secret&client_id=u1")
    assert params.get("token") == ["secret"]
    assert params.get("client_id") == ["u1"]
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("raw", "expected"),
    [
        ("plain", "plain"),
        ('{"content": "hi"}', "hi"),
        ('{"text": "there"}', "there"),
        ('{"message": "x"}', "x"),
        ("   ", None),
        ("{}", None),
    ],
)
def test_parse_inbound_payload(raw: str, expected: str | None) -> None:
    """Plain text and the content/text/message JSON keys yield the message text."""
    parsed = _parse_inbound_payload(raw)
    assert parsed == expected
|
|
|
|
|
|
def test_parse_inbound_invalid_json_falls_back_to_raw_string() -> None:
    """A frame that is not valid JSON is returned unchanged."""
    malformed = "{not json"
    assert _parse_inbound_payload(malformed) == "{not json"
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("raw", "expected"),
    [
        # Empty string content.
        ('{"content": ""}', None),
        # Non-string content.
        ('{"content": 123}', None),
        # Whitespace-only content.
        ('{"content": "   "}', None),
        # JSON array: not a dict, so it is treated as plain text.
        ('["hello"]', '["hello"]'),
        # Unrecognized key.
        ('{"unknown_key": "val"}', None),
        # Null content.
        ('{"content": null}', None),
    ],
)
def test_parse_inbound_payload_edge_cases(raw: str, expected: str | None) -> None:
    """Unusable JSON content yields ``None``; non-dict JSON stays plain text."""
    parsed = _parse_inbound_payload(raw)
    assert parsed == expected
|
|
|
|
|
|
def test_web_socket_config_path_must_start_with_slash() -> None:
    """A path without a leading slash is rejected at config construction."""
    expected_error = pytest.raises(ValueError, match='path must start with "/"')
    with expected_error:
        WebSocketConfig(path="bad")
|
|
|
|
|
|
def test_ssl_context_requires_both_cert_and_key_files() -> None:
    """A cert file with an empty key file makes ``_build_ssl_context`` raise."""
    config = {
        "enabled": True,
        "allowFrom": ["*"],
        "sslCertfile": "/tmp/c.pem",
        "sslKeyfile": "",
    }
    channel = WebSocketChannel(config, MagicMock())
    with pytest.raises(ValueError, match="ssl_certfile and ssl_keyfile"):
        channel._build_ssl_context()
|
|
|
|
|
|
def test_default_config_includes_safe_bind_and_streaming() -> None:
    """Defaults are disabled, bound to loopback, streaming, with no issue path."""
    cfg = WebSocketChannel.default_config()

    assert cfg["enabled"] is False
    assert cfg["host"] == "127.0.0.1"
    assert cfg["streaming"] is True
    assert cfg["allowFrom"] == ["*"]
    # The key may be absent altogether; treat that the same as an empty path.
    assert cfg.get("tokenIssuePath", "") == ""
|
|
|
|
|
|
def test_token_issue_path_must_differ_from_websocket_path() -> None:
    """Using the same path for the socket and the token route is rejected."""
    shared_path = "/ws"
    with pytest.raises(ValueError, match="token_issue_path must differ"):
        WebSocketConfig(path=shared_path, token_issue_path=shared_path)
|
|
|
|
|
|
def test_issue_route_secret_matches_bearer_and_header() -> None:
    """The secret is accepted as a Bearer token or via ``X-Nanobot-Auth``."""
    from websockets.datastructures import Headers

    secret = "my-secret"
    cases = [
        (("Authorization", "Bearer my-secret"), True),
        (("X-Nanobot-Auth", "my-secret"), True),
        (("Authorization", "Bearer other"), False),
    ]
    for header, should_match in cases:
        assert _issue_route_secret_matches(Headers([header]), secret) is should_match
|
|
|
|
|
|
def test_issue_route_secret_matches_empty_secret() -> None:
    """With an empty secret the check passes whatever headers are sent."""
    from websockets.datastructures import Headers

    header_sets = (
        Headers([]),
        Headers([("Authorization", "Bearer anything")]),
    )
    for headers in header_sets:
        assert _issue_route_secret_matches(headers, "") is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_delivers_json_message_with_media_and_reply() -> None:
    """``send`` writes one JSON frame carrying text, reply id and media."""
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())
    socket = AsyncMock()
    channel._connections["chat-1"] = socket

    outbound = OutboundMessage(
        channel="websocket",
        chat_id="chat-1",
        content="hello",
        reply_to="m1",
        media=["/tmp/a.png"],
    )
    await channel.send(outbound)

    socket.send.assert_awaited_once()
    frame = json.loads(socket.send.call_args.args[0])
    assert frame["event"] == "message"
    assert frame["text"] == "hello"
    assert frame["reply_to"] == "m1"
    assert frame["media"] == ["/tmp/a.png"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_missing_connection_is_noop_without_error() -> None:
    """Sending to a chat id with no registered connection must not raise."""
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())
    orphan = OutboundMessage(channel="websocket", chat_id="missing", content="x")
    await channel.send(orphan)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_removes_connection_on_connection_closed() -> None:
    """A ``ConnectionClosed`` during ``send`` drops the connection entry."""
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())

    dead_socket = AsyncMock()
    dead_socket.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True)
    channel._connections["chat-1"] = dead_socket

    await channel.send(OutboundMessage(channel="websocket", chat_id="chat-1", content="hello"))

    assert "chat-1" not in channel._connections
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_delta_removes_connection_on_connection_closed() -> None:
    """A ``ConnectionClosed`` during ``send_delta`` drops the connection entry."""
    config = {"enabled": True, "allowFrom": ["*"], "streaming": True}
    channel = WebSocketChannel(config, MagicMock())

    dead_socket = AsyncMock()
    dead_socket.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True)
    channel._connections["chat-1"] = dead_socket

    metadata = {"_stream_delta": True, "_stream_id": "s1"}
    await channel.send_delta("chat-1", "chunk", metadata)

    assert "chat-1" not in channel._connections
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_delta_emits_delta_and_stream_end() -> None:
    """A delta call then an end call produce ``delta`` and ``stream_end`` frames."""
    config = {"enabled": True, "allowFrom": ["*"], "streaming": True}
    channel = WebSocketChannel(config, MagicMock())
    socket = AsyncMock()
    channel._connections["chat-1"] = socket

    await channel.send_delta("chat-1", "part", {"_stream_delta": True, "_stream_id": "sid"})
    await channel.send_delta("chat-1", "", {"_stream_end": True, "_stream_id": "sid"})

    assert socket.send.await_count == 2
    sent = socket.send.call_args_list
    delta_frame = json.loads(sent[0].args[0])
    end_frame = json.loads(sent[1].args[0])

    assert delta_frame["event"] == "delta"
    assert delta_frame["text"] == "part"
    assert delta_frame["stream_id"] == "sid"
    assert end_frame["event"] == "stream_end"
    assert end_frame["stream_id"] == "sid"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_non_connection_closed_exception_is_raised() -> None:
    """Errors other than ``ConnectionClosed`` propagate out of ``send``."""
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())

    failing_socket = AsyncMock()
    failing_socket.send.side_effect = RuntimeError("unexpected")
    channel._connections["chat-1"] = failing_socket

    outbound = OutboundMessage(channel="websocket", chat_id="chat-1", content="hello")
    with pytest.raises(RuntimeError, match="unexpected"):
        await channel.send(outbound)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_delta_missing_connection_is_noop() -> None:
    """``send_delta`` for an unknown chat id completes without raising."""
    config = {"enabled": True, "allowFrom": ["*"], "streaming": True}
    channel = WebSocketChannel(config, MagicMock())

    # Passing means simply returning: nothing is registered for this chat id.
    metadata = {"_stream_delta": True, "_stream_id": "s1"}
    await channel.send_delta("nonexistent", "chunk", metadata)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_stop_is_idempotent() -> None:
    """``stop`` may be called before ``start``, and twice in a row."""
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())

    # The channel was never started; neither call may raise.
    for _ in range(2):
        await channel.stop()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_end_to_end_client_receives_ready_and_agent_sees_inbound(bus: MagicMock) -> None:
    """A real client gets ``ready`` and its frames are published to the bus."""
    port = 29876
    channel = _ch(bus, port=port)
    url = f"ws://127.0.0.1:{port}/ws?client_id=tester"

    runner = asyncio.create_task(channel.start())
    # Give the server a moment to start listening before connecting.
    await asyncio.sleep(0.3)

    try:
        async with websockets.connect(url) as client:
            ready = json.loads(await client.recv())
            assert ready["event"] == "ready"
            assert ready["client_id"] == "tester"
            chat_id = ready["chat_id"]

            # JSON frame: the "content" value becomes the inbound text.
            await client.send(json.dumps({"content": "ping from client"}))
            await asyncio.sleep(0.08)

            bus.publish_inbound.assert_awaited()
            first_inbound = bus.publish_inbound.call_args.args[0]
            assert first_inbound.channel == "websocket"
            assert first_inbound.sender_id == "tester"
            assert first_inbound.chat_id == chat_id
            assert first_inbound.content == "ping from client"

            # Plain text frame: published as-is.
            await client.send("plain text frame")
            await asyncio.sleep(0.08)
            assert bus.publish_inbound.await_count >= 2
            latest_inbound = bus.publish_inbound.call_args_list[-1].args[0]
            assert latest_inbound.content == "plain text frame"
    finally:
        await channel.stop()
        await runner
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_token_rejects_handshake_when_mismatch(bus: MagicMock) -> None:
    """A handshake carrying the wrong static token is answered with HTTP 401."""
    port = 29877
    channel = _ch(bus, port=port, path="/", token="secret")
    url = f"ws://127.0.0.1:{port}/?token=wrong"

    runner = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        with pytest.raises(websockets.exceptions.InvalidStatus) as rejected:
            async with websockets.connect(url):
                pass
        assert rejected.value.response.status_code == 401
    finally:
        await channel.stop()
        await runner
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_wrong_path_returns_404(bus: MagicMock) -> None:
    """Connecting to a path other than the configured one yields HTTP 404."""
    port = 29878
    channel = _ch(bus, port=port)
    url = f"ws://127.0.0.1:{port}/other"

    runner = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        with pytest.raises(websockets.exceptions.InvalidStatus) as rejected:
            async with websockets.connect(url):
                pass
        assert rejected.value.response.status_code == 404
    finally:
        await channel.stop()
        await runner
|
|
|
|
|
|
def test_registry_discovers_websocket_channel() -> None:
    """The channel registry resolves ``"websocket"`` to a class with that name."""
    from nanobot.channels.registry import load_channel_class

    channel_cls = load_channel_class("websocket")
    assert channel_cls.name == "websocket"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_http_route_issues_token_then_websocket_requires_it(bus: MagicMock) -> None:
    """The issue route hands out a token that opens the socket exactly once."""
    port = 29879
    channel = _ch(
        bus,
        port=port,
        tokenIssuePath="/auth/token",
        tokenIssueSecret="route-secret",
        websocketRequiresToken=True,
    )
    issue_url = f"http://127.0.0.1:{port}/auth/token"

    runner = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        # Without the route secret the issue endpoint refuses.
        denied = await _http_get(issue_url)
        assert denied.status_code == 401

        # With the secret a token is issued.
        issued = await _http_get(issue_url, headers={"Authorization": "Bearer route-secret"})
        assert issued.status_code == 200
        token = issued.json()["token"]
        assert token.startswith("nbwt_")

        # A handshake without any token is rejected.
        with pytest.raises(websockets.exceptions.InvalidStatus) as missing_token:
            async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=x"):
                pass
        assert missing_token.value.response.status_code == 401

        # The issued token is accepted for the handshake.
        uri = f"ws://127.0.0.1:{port}/ws?token={token}&client_id=caller"
        async with websockets.connect(uri) as client:
            ready = json.loads(await client.recv())
            assert ready["event"] == "ready"
            assert ready["client_id"] == "caller"

        # Presenting the same token a second time is rejected.
        with pytest.raises(websockets.exceptions.InvalidStatus) as reuse:
            async with websockets.connect(uri):
                pass
        assert reuse.value.response.status_code == 401
    finally:
        await channel.stop()
        await runner
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_end_to_end_server_pushes_streaming_deltas_to_client(bus: MagicMock) -> None:
    """Deltas pushed via ``send_delta`` reach a real client in order."""
    port = 29880
    channel = _ch(bus, port=port, streaming=True)
    url = f"ws://127.0.0.1:{port}/ws?client_id=stream-tester"

    runner = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        async with websockets.connect(url) as client:
            ready = json.loads(await client.recv())
            chat_id = ready["chat_id"]

            # Push two text deltas followed by the end-of-stream marker.
            pushes = (
                ("Hello ", {"_stream_delta": True, "_stream_id": "s1"}),
                ("world", {"_stream_delta": True, "_stream_id": "s1"}),
                ("", {"_stream_end": True, "_stream_id": "s1"}),
            )
            for text, metadata in pushes:
                await channel.send_delta(chat_id, text, metadata)

            for expected_text in ("Hello ", "world"):
                delta = json.loads(await client.recv())
                assert delta["event"] == "delta"
                assert delta["text"] == expected_text
                assert delta["stream_id"] == "s1"

            end = json.loads(await client.recv())
            assert end["event"] == "stream_end"
            assert end["stream_id"] == "s1"
    finally:
        await channel.stop()
        await runner
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_token_issue_rejects_when_at_capacity(bus: MagicMock) -> None:
    """With the issued-token table full, the issue route answers HTTP 429."""
    port = 29881
    channel = _ch(bus, port=port, tokenIssuePath="/auth/token", tokenIssueSecret="s")
    issue_url = f"http://127.0.0.1:{port}/auth/token"

    runner = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        # Pre-fill the table with ``_MAX_ISSUED_TOKENS`` entries whose stored
        # value lies 300 seconds ahead of the monotonic clock.
        channel._issued_tokens = {
            f"nbwt_fill_{idx}": time.monotonic() + 300
            for idx in range(channel._MAX_ISSUED_TOKENS)
        }

        response = await _http_get(issue_url, headers={"Authorization": "Bearer s"})
        assert response.status_code == 429
        body = response.json()
        assert "error" in body
    finally:
        await channel.stop()
        await runner
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_allow_from_rejects_unauthorized_client_id(bus: MagicMock) -> None:
    """A client id outside ``allowFrom`` is refused with HTTP 403."""
    port = 29882
    channel = _ch(bus, port=port, allowFrom=["alice", "bob"])
    url = f"ws://127.0.0.1:{port}/ws?client_id=eve"

    runner = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        with pytest.raises(websockets.exceptions.InvalidStatus) as rejected:
            async with websockets.connect(url):
                pass
        assert rejected.value.response.status_code == 403
    finally:
        await channel.stop()
        await runner
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_client_id_truncation(bus: MagicMock) -> None:
    """A 200-character client id comes back as its first 128 characters."""
    port = 29883
    channel = _ch(bus, port=port)

    runner = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        oversized_id = "x" * 200
        url = f"ws://127.0.0.1:{port}/ws?client_id={oversized_id}"
        async with websockets.connect(url) as client:
            ready = json.loads(await client.recv())
            assert ready["client_id"] == "x" * 128
            assert len(ready["client_id"]) == 128
    finally:
        await channel.stop()
        await runner
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_non_utf8_binary_frame_ignored(bus: MagicMock) -> None:
    """A binary frame that is not valid UTF-8 is never published to the bus."""
    port = 29884
    channel = _ch(bus, port=port)
    url = f"ws://127.0.0.1:{port}/ws?client_id=bin-test"

    runner = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        async with websockets.connect(url) as client:
            # Discard the initial "ready" frame.
            await client.recv()

            undecodable = b"\xff\xfe\xfd"
            await client.send(undecodable)
            await asyncio.sleep(0.05)

            # Nothing may have reached the bus.
            bus.publish_inbound.assert_not_awaited()
    finally:
        await channel.stop()
        await runner
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_static_token_accepts_issued_token_as_fallback(bus: MagicMock) -> None:
    """With a static token configured, an issued token still opens the socket."""
    port = 29885
    channel = _ch(
        bus,
        port=port,
        token="static-secret",
        tokenIssuePath="/auth/token",
        tokenIssueSecret="route-secret",
    )

    runner = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        # Obtain a token from the issue route.
        response = await _http_get(
            f"http://127.0.0.1:{port}/auth/token",
            headers={"Authorization": "Bearer route-secret"},
        )
        assert response.status_code == 200
        issued_token = response.json()["token"]

        # Handshake with the issued token instead of the static one.
        url = f"ws://127.0.0.1:{port}/ws?token={issued_token}&client_id=caller"
        async with websockets.connect(url) as client:
            ready = json.loads(await client.recv())
            assert ready["event"] == "ready"
    finally:
        await channel.stop()
        await runner
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_allow_from_empty_list_denies_all(bus: MagicMock) -> None:
    """An empty ``allowFrom`` list makes the handshake fail with HTTP 403."""
    port = 29886
    channel = _ch(bus, port=port, allowFrom=[])
    url = f"ws://127.0.0.1:{port}/ws?client_id=anyone"

    runner = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        with pytest.raises(websockets.exceptions.InvalidStatus) as rejected:
            async with websockets.connect(url):
                pass
        assert rejected.value.response.status_code == 403
    finally:
        await channel.stop()
        await runner
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_websocket_requires_token_without_issue_path(bus: MagicMock) -> None:
    """Token required, but no static token or issue path set: every handshake gets 401."""
    port = 29887
    channel = _ch(bus, port=port, websocketRequiresToken=True)

    runner = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        attempts = (
            # No token at all.
            f"ws://127.0.0.1:{port}/ws?client_id=u",
            # A token that cannot match anything.
            f"ws://127.0.0.1:{port}/ws?client_id=u&token=wrong",
        )
        for url in attempts:
            with pytest.raises(websockets.exceptions.InvalidStatus) as rejected:
                async with websockets.connect(url):
                    pass
            assert rejected.value.response.status_code == 401
    finally:
        await channel.stop()
        await runner
|