mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-10 13:13:39 +00:00
- Use hmac.compare_digest for timing-safe static token comparison - Add issued token capacity limit (_MAX_ISSUED_TOKENS=10000) with 429 response - Use atomic pop in _take_issued_token_if_valid to eliminate TOCTOU window - Enforce TLSv1.2 minimum version for SSL connections - Extract _safe_send helper for consistent ConnectionClosed handling - Move connection registration after ready send to prevent out-of-order delivery - Add HTTP-level allow_from check and client_id truncation in process_request - Make stop() idempotent with graceful shutdown error handling - Normalize path via validator instead of leaving raw value - Default websocket_requires_token to True for secure-by-default behavior - Add integration tests and ws_test_client helper - Refactor tests to use shared _ch factory and bus fixture
228 lines
7.1 KiB
Python
228 lines
7.1 KiB
Python
"""Lightweight WebSocket test client for integration testing the nanobot WebSocket channel.
|
|
|
|
Provides an async ``WsTestClient`` class and token-issuance helpers that
|
|
integration tests can import and use directly::
|
|
|
|
from ws_test_client import WsTestClient
|
|
|
|
async with WsTestClient("ws://127.0.0.1:8765/", client_id="t") as c:
|
|
ready = await c.recv_ready()
|
|
await c.send_text("hello")
|
|
msg = await c.recv_message()
|
|
"""
|
|
|
|
from __future__ import annotations

import asyncio
import json
from dataclasses import dataclass, field
from typing import Any
from urllib.parse import urlencode

import httpx
import websockets
from websockets.asyncio.client import ClientConnection
from websockets.protocol import State
|
|
|
|
|
|
@dataclass
class WsMessage:
    """One decoded JSON frame received from the WebSocket server.

    ``event`` holds the frame's event name; ``raw`` keeps the complete decoded
    payload and is excluded from ``repr``. The read-only properties below pull
    optional keys out of ``raw`` and evaluate to ``None`` when a key is absent.
    """

    event: str
    raw: dict[str, Any] = field(repr=False)

    def _lookup(self, key: str) -> Any:
        # Shared accessor for the optional payload keys exposed as properties.
        return self.raw.get(key)

    @property
    def text(self) -> str | None:
        """The ``"text"`` entry of the payload, if present."""
        return self._lookup("text")

    @property
    def chat_id(self) -> str | None:
        """The ``"chat_id"`` entry of the payload, if present."""
        return self._lookup("chat_id")

    @property
    def client_id(self) -> str | None:
        """The ``"client_id"`` entry of the payload, if present."""
        return self._lookup("client_id")

    @property
    def media(self) -> list[str] | None:
        """The ``"media"`` entry of the payload, if present."""
        return self._lookup("media")

    @property
    def reply_to(self) -> str | None:
        """The ``"reply_to"`` entry of the payload, if present."""
        return self._lookup("reply_to")

    @property
    def stream_id(self) -> str | None:
        """The ``"stream_id"`` entry of the payload, if present."""
        return self._lookup("stream_id")

    def __eq__(self, other: object) -> bool:
        # isinstance (rather than an exact class check) means subclass
        # instances are compared too; event name and full payload must match.
        if isinstance(other, WsMessage):
            return (self.event, self.raw) == (other.event, other.raw)
        return NotImplemented
|
|
|
|
|
|
class WsTestClient:
    """Async WebSocket test client with helper methods for common operations.

    Usage::

        async with WsTestClient("ws://127.0.0.1:8765/", client_id="tester") as client:
            ready = await client.recv_ready()
            await client.send_text("hello")
            msg = await client.recv_message(timeout=5.0)
    """

    def __init__(
        self,
        uri: str,
        *,
        client_id: str = "test-client",
        token: str = "",
        extra_headers: dict[str, str] | None = None,
    ) -> None:
        """Prepare (but do not open) a connection.

        Args:
            uri: Base WebSocket URI, e.g. ``ws://127.0.0.1:8765/``; it may
                already carry a query string.
            client_id: Sent as the ``client_id`` query parameter; omitted
                when empty.
            token: Sent as the ``token`` query parameter; omitted when empty.
            extra_headers: Additional HTTP headers for the opening handshake.
        """
        # URL-encode the values so ids/tokens containing "&", "=", spaces or
        # non-ASCII characters cannot corrupt the query string.
        query = urlencode(
            [(key, value) for key, value in (("client_id", client_id), ("token", token)) if value]
        )
        if not query:
            self._uri = uri
        elif uri.endswith(("?", "&")):
            # The caller already left a separator in place; don't double it.
            self._uri = uri + query
        else:
            self._uri = uri + ("&" if "?" in uri else "?") + query
        self._extra_headers = extra_headers
        self._ws: ClientConnection | None = None

    async def connect(self) -> None:
        """Open the WebSocket connection to the URI built in ``__init__``."""
        self._ws = await websockets.connect(
            self._uri,
            additional_headers=self._extra_headers,
        )

    async def close(self) -> None:
        """Close the connection if one is open; safe to call repeatedly."""
        # Detach first so the client is marked disconnected even if the
        # close handshake raises.
        ws, self._ws = self._ws, None
        if ws is not None:
            await ws.close()

    async def __aenter__(self) -> WsTestClient:
        await self.connect()
        return self

    async def __aexit__(self, *args: Any) -> None:
        await self.close()

    @property
    def ws(self) -> ClientConnection:
        """The live connection; fails an assertion if not connected."""
        assert self._ws is not None, "Client is not connected"
        return self._ws

    # -- Receiving --------------------------------------------------------

    async def recv_raw(self, timeout: float = 10.0) -> dict[str, Any]:
        """Receive one frame and decode it as JSON.

        Raises ``TimeoutError`` (via ``asyncio.wait_for``) when no frame
        arrives within *timeout* seconds.
        """
        raw = await asyncio.wait_for(self.ws.recv(), timeout=timeout)
        return json.loads(raw)

    async def recv(self, timeout: float = 10.0) -> WsMessage:
        """Receive one frame wrapped in a ``WsMessage`` ("" if it has no event)."""
        data = await self.recv_raw(timeout)
        return WsMessage(event=data.get("event", ""), raw=data)

    async def recv_ready(self, timeout: float = 5.0) -> WsMessage:
        """Receive the next frame and assert it is a 'ready' event."""
        msg = await self.recv(timeout)
        assert msg.event == "ready", f"Expected 'ready' event, got '{msg.event}'"
        return msg

    async def recv_message(self, timeout: float = 10.0) -> WsMessage:
        """Receive the next frame and assert it is a 'message' event."""
        msg = await self.recv(timeout)
        assert msg.event == "message", f"Expected 'message' event, got '{msg.event}'"
        return msg

    async def recv_delta(self, timeout: float = 10.0) -> WsMessage:
        """Receive the next frame and assert it is a 'delta' event."""
        msg = await self.recv(timeout)
        assert msg.event == "delta", f"Expected 'delta' event, got '{msg.event}'"
        return msg

    async def recv_stream_end(self, timeout: float = 10.0) -> WsMessage:
        """Receive the next frame and assert it is a 'stream_end' event."""
        msg = await self.recv(timeout)
        assert msg.event == "stream_end", f"Expected 'stream_end' event, got '{msg.event}'"
        return msg

    async def collect_stream(self, timeout: float = 10.0) -> list[WsMessage]:
        """Collect frames up to and including the first 'stream_end'.

        *timeout* applies to each individual frame, not to the whole stream.
        """
        messages: list[WsMessage] = []
        while True:
            msg = await self.recv(timeout)
            messages.append(msg)
            if msg.event == "stream_end":
                break
        return messages

    async def recv_n(self, n: int, timeout: float = 10.0) -> list[WsMessage]:
        """Receive exactly *n* frames, each with its own *timeout*."""
        return [await self.recv(timeout) for _ in range(n)]

    # -- Sending ----------------------------------------------------------

    async def send_text(self, text: str) -> None:
        """Send a plain text frame."""
        await self.ws.send(text)

    async def send_json(self, data: dict[str, Any]) -> None:
        """Send *data* serialized as a JSON text frame (non-ASCII kept as-is)."""
        await self.ws.send(json.dumps(data, ensure_ascii=False))

    async def send_content(self, content: str) -> None:
        """Send content in the preferred JSON format ``{"content": ...}``."""
        await self.send_json({"content": content})

    # -- Connection introspection -----------------------------------------

    @property
    def closed(self) -> bool:
        """True if never connected, closed via ``close()``, or closed by the peer."""
        # The websockets.asyncio ClientConnection has no ``closed`` attribute
        # (that was the legacy API); its ``state`` reports the lifecycle.
        return self._ws is None or self._ws.state is State.CLOSED
|
|
|
|
|
|
# -- Token issuance helpers -----------------------------------------------
|
|
|
|
|
|
async def issue_token(
    host: str = "127.0.0.1",
    port: int = 8765,
    issue_path: str = "/auth/token",
    secret: str = "",
) -> tuple[dict[str, Any] | None, int]:
    """Request a short-lived token from the token-issue HTTP endpoint.

    Args:
        host: Server host name or IP address.
        port: Server port.
        issue_path: Path of the token-issue endpoint.
        secret: When non-empty, sent as ``Authorization: Bearer <secret>``.

    Returns:
        ``(parsed_json_or_None, status_code)``; the parsed body is ``None``
        when the response is not valid JSON (e.g. an error page).
    """
    url = f"http://{host}:{port}{issue_path}"
    headers: dict[str, str] = {"Authorization": f"Bearer {secret}"} if secret else {}

    # Use httpx's native async client instead of running the blocking
    # ``httpx.get`` on the default thread-pool executor.
    async with httpx.AsyncClient(timeout=5.0) as client:
        resp = await client.get(url, headers=headers)

    try:
        data = resp.json()
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        data = None
    return data, resp.status_code
|
|
|
|
|
|
async def issue_token_ok(
    host: str = "127.0.0.1",
    port: int = 8765,
    issue_path: str = "/auth/token",
    secret: str = "",
) -> str:
    """Issue a token and return it, failing the calling test on any problem.

    Asserts that the endpoint answered HTTP 200 with a JSON body, and that
    the returned token carries the ``nbwt_`` prefix.
    """
    result, status_code = await issue_token(
        host=host,
        port=port,
        issue_path=issue_path,
        secret=secret,
    )
    assert status_code == 200, f"Token issue failed with status {status_code}"
    assert result is not None
    issued = result["token"]
    assert issued.startswith("nbwt_"), f"Unexpected token format: {issued}"
    return issued
|