nanobot/tests/channels/test_websocket_integration.py
chengyongru 8f7ce9fef7 fix(websocket): harden security and robustness
- Use hmac.compare_digest for timing-safe static token comparison
- Add issued token capacity limit (_MAX_ISSUED_TOKENS=10000) with 429 response
- Use atomic pop in _take_issued_token_if_valid to eliminate TOCTOU window
- Enforce TLSv1.2 minimum version for SSL connections
- Extract _safe_send helper for consistent ConnectionClosed handling
- Move connection registration after ready send to prevent out-of-order delivery
- Add HTTP-level allow_from check and client_id truncation in process_request
- Make stop() idempotent with graceful shutdown error handling
- Normalize path via validator instead of leaving raw value
- Default websocket_requires_token to True for secure-by-default behavior
- Add integration tests and ws_test_client helper
- Refactor tests to use shared _ch factory and bus fixture
2026-04-09 15:56:34 +08:00

478 lines
16 KiB
Python

"""Integration tests for the WebSocket channel using WsTestClient.
Complements the unit/lightweight tests in test_websocket_channel.py by covering
multi-client scenarios, edge cases, and realistic usage patterns.
"""
from __future__ import annotations

import asyncio
import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock

import pytest
import websockets

from nanobot.bus.events import OutboundMessage
from nanobot.channels.websocket import WebSocketChannel
from ws_test_client import WsTestClient, issue_token, issue_token_ok
def _ch(bus: Any, port: int, **kw: Any) -> WebSocketChannel:
    """Create a WebSocketChannel listening on 127.0.0.1:*port* for a test.

    The base config enables the channel, allows every sender (``"*"``),
    serves on path ``"/"`` and does not require a token. Keyword arguments
    are merged on top, so a test can override a default (e.g. ``path="/ws"``)
    or add an option (e.g. ``token="secret"``, ``streaming=True``).

    Args:
        bus: Message bus handed to the channel (a mock in these tests).
        port: TCP port to listen on; each test in this module uses its own.
        **kw: Extra or overriding config entries.

    Returns:
        A channel that has not been started; callers run ``start()`` and
        ``stop()`` themselves.
    """
    cfg: dict[str, Any] = {
        "enabled": True,
        "allowFrom": ["*"],
        "host": "127.0.0.1",
        "port": port,
        "path": "/",
        "websocketRequiresToken": False,
    }
    cfg.update(kw)  # caller-supplied keys win over the defaults above
    return WebSocketChannel(cfg, bus)
@pytest.fixture()
def bus() -> MagicMock:
    """Stand-in message bus whose ``publish_inbound`` can be awaited.

    Everything else is a plain ``MagicMock``; tests inspect the
    ``publish_inbound`` AsyncMock to see what the channel forwarded.
    """
    return MagicMock(publish_inbound=AsyncMock())
# -- Connection basics ----------------------------------------------------
@pytest.mark.asyncio
async def test_ready_event_fields(bus: MagicMock) -> None:
    """The first frame after connecting is a ``ready`` event echoing the client id."""
    port = 29901
    channel = _ch(bus, port)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)  # fixed grace period for start() to begin listening
    try:
        async with WsTestClient(f"ws://127.0.0.1:{port}/", client_id="c1") as client:
            ready = await client.recv_ready()
            assert ready.event == "ready"
            assert len(ready.chat_id) == 36  # 36 == length of a hyphenated UUID string
            assert ready.client_id == "c1"
    finally:
        await channel.stop()
        await server
@pytest.mark.asyncio
async def test_anonymous_client_gets_generated_id(bus: MagicMock) -> None:
    """An empty client_id is replaced by an id starting with ``anon-``."""
    port = 29902
    channel = _ch(bus, port)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient(f"ws://127.0.0.1:{port}/", client_id="") as client:
            ready = await client.recv_ready()
            assert ready.client_id.startswith("anon-")
    finally:
        await channel.stop()
        await server
@pytest.mark.asyncio
async def test_each_connection_unique_chat_id(bus: MagicMock) -> None:
    """Two simultaneous connections are assigned different chat ids."""
    port = 29903
    url = f"ws://127.0.0.1:{port}/"
    channel = _ch(bus, port)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient(url, client_id="a") as first:
            async with WsTestClient(url, client_id="b") as second:
                ready_first = await first.recv_ready()
                ready_second = await second.recv_ready()
                assert ready_first.chat_id != ready_second.chat_id
    finally:
        await channel.stop()
        await server
# -- Inbound messages (client -> server) ----------------------------------
@pytest.mark.asyncio
async def test_plain_text(bus: MagicMock) -> None:
    """A raw text frame is published verbatim, attributed to the client id."""
    ch = _ch(bus, 29904)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29904/", client_id="p") as c:
            await c.recv_ready()
            await c.send_text("hello world")
            # The server handles the frame asynchronously. Poll (up to ~2 s)
            # instead of sleeping a fixed amount: a fixed sleep is flaky and
            # leaves call_args as None (-> TypeError) when the frame is late.
            for _ in range(200):
                if bus.publish_inbound.await_count >= 1:
                    break
                await asyncio.sleep(0.01)
            bus.publish_inbound.assert_awaited_once()
            inbound = bus.publish_inbound.call_args[0][0]
            assert inbound.content == "hello world"
            assert inbound.sender_id == "p"
    finally:
        await ch.stop()
        await t
@pytest.mark.asyncio
async def test_json_content_field(bus: MagicMock) -> None:
    """A JSON frame's ``content`` field becomes the inbound message content."""
    ch = _ch(bus, 29905)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29905/", client_id="j") as c:
            await c.recv_ready()
            await c.send_json({"content": "structured"})
            # Poll (up to ~2 s) rather than relying on a fixed sleep; see the
            # race described in test_plain_text.
            for _ in range(200):
                if bus.publish_inbound.await_count >= 1:
                    break
                await asyncio.sleep(0.01)
            bus.publish_inbound.assert_awaited_once()
            assert bus.publish_inbound.call_args[0][0].content == "structured"
    finally:
        await ch.stop()
        await t
@pytest.mark.asyncio
async def test_json_text_and_message_fields(bus: MagicMock) -> None:
    """JSON ``text`` and ``message`` fields are accepted as content too."""
    ch = _ch(bus, 29906)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)

    async def wait_for_inbound(count: int) -> None:
        # Poll (up to ~2 s) until `count` messages were published; a fixed
        # sleep races the server and can leave call_args stale or None.
        for _ in range(200):
            if bus.publish_inbound.await_count >= count:
                return
            await asyncio.sleep(0.01)

    try:
        async with WsTestClient("ws://127.0.0.1:29906/", client_id="x") as c:
            await c.recv_ready()
            await c.send_json({"text": "via text"})
            await wait_for_inbound(1)
            assert bus.publish_inbound.call_args[0][0].content == "via text"
            await c.send_json({"message": "via message"})
            await wait_for_inbound(2)
            assert bus.publish_inbound.call_args[0][0].content == "via message"
            assert bus.publish_inbound.await_count == 2
    finally:
        await ch.stop()
        await t
@pytest.mark.asyncio
async def test_empty_payload_ignored(bus: MagicMock) -> None:
    """Whitespace-only text and JSON without content are not published."""
    ch = _ch(bus, 29907)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29907/", client_id="e") as c:
            await c.recv_ready()
            await c.send_text(" ")
            await c.send_json({})
            # A negative check after a short fixed sleep can pass merely
            # because the server has not processed the frames yet. Send a
            # non-empty sentinel last on the same connection. Frames on one
            # connection are handled in order (what
            # test_messages_preserve_order checks), so once the sentinel is
            # published, both empty frames ahead of it have been handled.
            await c.send_text("sentinel")
            for _ in range(200):
                if bus.publish_inbound.await_count >= 1:
                    break
                await asyncio.sleep(0.01)
            published = [call[0][0].content for call in bus.publish_inbound.call_args_list]
            assert published == ["sentinel"]
    finally:
        await ch.stop()
        await t
@pytest.mark.asyncio
async def test_messages_preserve_order(bus: MagicMock) -> None:
    """Frames from one client are published in the order they were sent."""
    ch = _ch(bus, 29908)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29908/", client_id="o") as c:
            await c.recv_ready()
            for i in range(5):
                await c.send_text(f"msg-{i}")
            # Poll (up to ~2 s) for all five publishes instead of a fixed
            # sleep, which is flaky on slow machines.
            for _ in range(200):
                if bus.publish_inbound.await_count >= 5:
                    break
                await asyncio.sleep(0.01)
            contents = [call[0][0].content for call in bus.publish_inbound.call_args_list]
            assert contents == [f"msg-{i}" for i in range(5)]
    finally:
        await ch.stop()
        await t
# -- Outbound messages (server -> client) ---------------------------------
@pytest.mark.asyncio
async def test_server_send_message(bus: MagicMock) -> None:
    """``send`` delivers an outbound message to the client owning the chat id."""
    port = 29909
    channel = _ch(bus, port)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient(f"ws://127.0.0.1:{port}/", client_id="r") as client:
            ready = await client.recv_ready()
            outbound = OutboundMessage(
                channel="websocket",
                chat_id=ready.chat_id,
                content="reply",
            )
            await channel.send(outbound)
            received = await client.recv_message()
            assert received.text == "reply"
    finally:
        await channel.stop()
        await server
@pytest.mark.asyncio
async def test_server_send_with_media_and_reply(bus: MagicMock) -> None:
    """Media paths and ``reply_to`` are carried through to the client."""
    port = 29910
    channel = _ch(bus, port)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient(f"ws://127.0.0.1:{port}/", client_id="m") as client:
            ready = await client.recv_ready()
            outbound = OutboundMessage(
                channel="websocket",
                chat_id=ready.chat_id,
                content="img",
                media=["/tmp/a.png"],
                reply_to="m1",
            )
            await channel.send(outbound)
            received = await client.recv_message()
            assert (received.text, received.media, received.reply_to) == (
                "img",
                ["/tmp/a.png"],
                "m1",
            )
    finally:
        await channel.stop()
        await server
# -- Streaming ------------------------------------------------------------
@pytest.mark.asyncio
async def test_streaming_deltas_and_end(bus: MagicMock) -> None:
    """Deltas of one stream reassemble in order and end with one stream_end."""
    port = 29911
    channel = _ch(bus, port, streaming=True)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient(f"ws://127.0.0.1:{port}/", client_id="s") as client:
            ready = await client.recv_ready()
            chat_id = ready.chat_id
            for chunk in ("Hello", " ", "world", "!"):
                await channel.send_delta(
                    chat_id, chunk, {"_stream_delta": True, "_stream_id": "s1"}
                )
            await channel.send_delta(
                chat_id, "", {"_stream_end": True, "_stream_id": "s1"}
            )
            frames = await client.collect_stream()
            assembled = "".join(f.text for f in frames if f.event == "delta")
            assert assembled == "Hello world!"
            assert sum(1 for f in frames if f.event == "stream_end") == 1
    finally:
        await channel.stop()
        await server
@pytest.mark.asyncio
async def test_interleaved_streams(bus: MagicMock) -> None:
    """Two streams interleaved on one chat stay separable by stream id."""
    port = 29912
    channel = _ch(bus, port, streaming=True)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient(f"ws://127.0.0.1:{port}/", client_id="i") as client:
            chat_id = (await client.recv_ready()).chat_id
            # (text, stream id, is end marker) in the exact order sent.
            script = (
                ("A1", "sa", False),
                ("B1", "sb", False),
                ("A2", "sa", False),
                ("", "sa", True),
                ("B2", "sb", False),
                ("", "sb", True),
            )
            for text, stream_id, is_end in script:
                marker = "_stream_end" if is_end else "_stream_delta"
                await channel.send_delta(
                    chat_id, text, {marker: True, "_stream_id": stream_id}
                )
            frames = await client.recv_n(len(script))

            def joined(sid: str) -> str:
                return "".join(
                    f.text for f in frames if f.event == "delta" and f.stream_id == sid
                )

            assert joined("sa") == "A1A2"
            assert joined("sb") == "B1B2"
    finally:
        await channel.stop()
        await server
# -- Multi-client ---------------------------------------------------------
@pytest.mark.asyncio
async def test_independent_sessions(bus: MagicMock) -> None:
    """A message addressed to one chat id reaches only that client."""
    port = 29913
    url = f"ws://127.0.0.1:{port}/"
    channel = _ch(bus, port)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient(url, client_id="u1") as client_1:
            async with WsTestClient(url, client_id="u2") as client_2:
                ready_1 = await client_1.recv_ready()
                ready_2 = await client_2.recv_ready()
                routes = (
                    (client_1, ready_1, "for-u1"),
                    (client_2, ready_2, "for-u2"),
                )
                for client, ready, body in routes:
                    await channel.send(
                        OutboundMessage(
                            channel="websocket", chat_id=ready.chat_id, content=body,
                        )
                    )
                    assert (await client.recv_message()).text == body
    finally:
        await channel.stop()
        await server
@pytest.mark.asyncio
async def test_disconnected_client_cleanup(bus: MagicMock) -> None:
    """Sending to a departed client does not raise, and its connection is dropped.

    After the client context exits, ``send`` must complete without raising,
    and the chat id must eventually disappear from ``ch._connections``.
    """
    ch = _ch(bus, 29914)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29914/", client_id="tmp") as c:
            chat_id = (await c.recv_ready()).chat_id
        # The client has disconnected; this send must not raise.
        await ch.send(OutboundMessage(
            channel="websocket", chat_id=chat_id, content="orphan",
        ))
        # The server observes the close on its own task, so the entry can
        # outlive the send briefly. Poll (up to ~2 s) instead of asserting
        # immediately, which is racy.
        for _ in range(200):
            if chat_id not in ch._connections:
                break
            await asyncio.sleep(0.01)
        assert chat_id not in ch._connections
    finally:
        await ch.stop()
        await t
# -- Authentication -------------------------------------------------------
@pytest.mark.asyncio
async def test_static_token_accepted(bus: MagicMock) -> None:
    """A handshake presenting the configured static token is accepted."""
    port = 29915
    channel = _ch(bus, port, token="secret")
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        url = f"ws://127.0.0.1:{port}/"
        async with WsTestClient(url, client_id="a", token="secret") as client:
            ready = await client.recv_ready()
            assert ready.client_id == "a"
    finally:
        await channel.stop()
        await server
@pytest.mark.asyncio
async def test_static_token_rejected(bus: MagicMock) -> None:
    """A handshake with the wrong static token is refused with HTTP 401."""
    port = 29916
    channel = _ch(bus, port, token="correct")
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        url = f"ws://127.0.0.1:{port}/"
        with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
            async with WsTestClient(url, client_id="b", token="wrong"):
                pass
        assert exc_info.value.response.status_code == 401
    finally:
        await channel.stop()
        await server
@pytest.mark.asyncio
async def test_token_issue_full_flow(bus: MagicMock) -> None:
    """Issue a token over HTTP, connect with it once, and have reuse refused."""
    port = 29917
    ws_url = f"ws://127.0.0.1:{port}/ws"
    channel = _ch(
        bus,
        port,
        path="/ws",
        tokenIssuePath="/auth/token",
        tokenIssueSecret="s",
        websocketRequiresToken=True,
    )
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    async def expect_unauthorized(**client_kw):
        # The WebSocket handshake must fail with HTTP 401.
        with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
            async with WsTestClient(ws_url, **client_kw):
                pass
        assert exc_info.value.response.status_code == 401

    try:
        # Asking for a token without the issue secret is refused.
        _, status = await issue_token(port=port, issue_path="/auth/token")
        assert status == 401
        # Supplying the secret yields a token.
        token = await issue_token_ok(port=port, issue_path="/auth/token", secret="s")
        # The endpoint rejects a handshake that carries no token...
        await expect_unauthorized(client_id="x")
        # ...accepts the issued token...
        async with WsTestClient(ws_url, client_id="ok", token=token) as client:
            assert (await client.recv_ready()).client_id == "ok"
        # ...and refuses the same token a second time.
        await expect_unauthorized(client_id="r", token=token)
    finally:
        await channel.stop()
        await server
# -- Path routing ---------------------------------------------------------
@pytest.mark.asyncio
async def test_custom_path(bus: MagicMock) -> None:
    """The channel accepts connections on a configured non-root path."""
    port = 29918
    channel = _ch(bus, port, path="/my-chat")
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        url = f"ws://127.0.0.1:{port}/my-chat"
        async with WsTestClient(url, client_id="p") as client:
            ready = await client.recv_ready()
            assert ready.event == "ready"
    finally:
        await channel.stop()
        await server
@pytest.mark.asyncio
async def test_wrong_path_404(bus: MagicMock) -> None:
    """A handshake on a path other than the configured one gets HTTP 404."""
    port = 29919
    channel = _ch(bus, port, path="/ws")
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        url = f"ws://127.0.0.1:{port}/wrong"
        with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
            async with WsTestClient(url, client_id="x"):
                pass
        assert exc_info.value.response.status_code == 404
    finally:
        await channel.stop()
        await server
@pytest.mark.asyncio
async def test_trailing_slash_normalized(bus: MagicMock) -> None:
    """A trailing slash on the request path still matches the configured path."""
    port = 29920
    channel = _ch(bus, port, path="/ws")
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        url = f"ws://127.0.0.1:{port}/ws/"
        async with WsTestClient(url, client_id="s") as client:
            ready = await client.recv_ready()
            assert ready.event == "ready"
    finally:
        await channel.stop()
        await server
# -- Edge cases -----------------------------------------------------------
@pytest.mark.asyncio
async def test_large_message(bus: MagicMock) -> None:
    """A 100,000-character text frame is published intact."""
    ch = _ch(bus, 29921)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29921/", client_id="big") as c:
            await c.recv_ready()
            big = "x" * 100_000
            await c.send_text(big)
            # A large frame can take longer than a fixed 0.2 s on a busy
            # machine. Poll (up to ~2 s) so call_args is not still None.
            for _ in range(200):
                if bus.publish_inbound.await_count >= 1:
                    break
                await asyncio.sleep(0.01)
            bus.publish_inbound.assert_awaited_once()
            assert bus.publish_inbound.call_args[0][0].content == big
    finally:
        await ch.stop()
        await t
@pytest.mark.asyncio
async def test_unicode_roundtrip(bus: MagicMock) -> None:
    """Non-ASCII text (CJK, emoji) survives both inbound and outbound paths."""
    ch = _ch(bus, 29922)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29922/", client_id="u") as c:
            ready = await c.recv_ready()
            text = "你好世界 🌍 日本語テスト"
            await c.send_text(text)
            # Poll (up to ~2 s) for the inbound publish instead of a fixed
            # sleep that can race the server.
            for _ in range(200):
                if bus.publish_inbound.await_count >= 1:
                    break
                await asyncio.sleep(0.01)
            bus.publish_inbound.assert_awaited_once()
            assert bus.publish_inbound.call_args[0][0].content == text
            await ch.send(OutboundMessage(
                channel="websocket", chat_id=ready.chat_id, content=text,
            ))
            assert (await c.recv_message()).text == text
    finally:
        await ch.stop()
        await t
@pytest.mark.asyncio
async def test_rapid_fire(bus: MagicMock) -> None:
    """50 back-to-back frames each way are all delivered, outbound in order."""
    ch = _ch(bus, 29923)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29923/", client_id="r") as c:
            ready = await c.recv_ready()
            for i in range(50):
                await c.send_text(f"in-{i}")
            # 50 frames in a fixed 0.5 s is flaky on slow CI. Poll (up to
            # ~5 s) until every frame has been published.
            for _ in range(500):
                if bus.publish_inbound.await_count >= 50:
                    break
                await asyncio.sleep(0.01)
            assert bus.publish_inbound.await_count == 50
            for i in range(50):
                await ch.send(OutboundMessage(
                    channel="websocket", chat_id=ready.chat_id, content=f"out-{i}",
                ))
            received = [(await c.recv_message()).text for _ in range(50)]
            assert received == [f"out-{i}" for i in range(50)]
    finally:
        await ch.stop()
        await t
@pytest.mark.asyncio
async def test_invalid_json_as_plain_text(bus: MagicMock) -> None:
    """A frame that looks like JSON but fails to parse is published verbatim."""
    ch = _ch(bus, 29924)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29924/", client_id="j") as c:
            await c.recv_ready()
            await c.send_text("{broken json")
            # Poll (up to ~2 s) rather than a fixed sleep that races the
            # server and can leave call_args as None.
            for _ in range(200):
                if bus.publish_inbound.await_count >= 1:
                    break
                await asyncio.sleep(0.01)
            bus.publish_inbound.assert_awaited_once()
            assert bus.publish_inbound.call_args[0][0].content == "{broken json"
    finally:
        await ch.stop()
        await t