mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-13 06:29:48 +00:00
428 lines
16 KiB
Python
428 lines
16 KiB
Python
"""WebSocket server channel: nanobot acts as a WebSocket server and serves connected clients."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import email.utils
|
|
import hmac
|
|
import http
|
|
import json
|
|
import secrets
|
|
import ssl
|
|
import time
|
|
import uuid
|
|
from typing import Any, Self
|
|
from urllib.parse import parse_qs, urlparse
|
|
|
|
from loguru import logger
|
|
from pydantic import Field, field_validator, model_validator
|
|
from websockets.asyncio.server import ServerConnection, serve
|
|
from websockets.datastructures import Headers
|
|
from websockets.exceptions import ConnectionClosed
|
|
from websockets.http11 import Request as WsRequest, Response
|
|
|
|
from nanobot.bus.events import OutboundMessage
|
|
from nanobot.bus.queue import MessageBus
|
|
from nanobot.channels.base import BaseChannel
|
|
from nanobot.config.schema import Base
|
|
|
|
|
|
def _strip_trailing_slash(path: str) -> str:
|
|
if len(path) > 1 and path.endswith("/"):
|
|
return path.rstrip("/")
|
|
return path or "/"
|
|
|
|
|
|
def _normalize_config_path(path: str) -> str:
    """Normalize a configured route path by dropping its trailing slash (root stays ``/``)."""
    normalized = _strip_trailing_slash(path)
    return normalized
|
|
|
|
|
|
class WebSocketConfig(Base):
    """Settings for the WebSocket server channel.

    Clients connect with URLs like ``ws://{host}:{port}{path}?client_id=...&token=...``.

    - ``client_id``: forwarded as the sender id of every inbound message (presumably what
      ``allow_from`` is matched against by the base channel); when the query string has
      none, the channel generates an ``anon-...`` id for that connection.
    - ``token``: static shared secret. When non-empty, the ``token`` query parameter must
      either equal it or be a still-valid short-lived token from ``token_issue_path``.
    - ``token_issue_path``: when non-empty, a **GET** (HTTP/1.1) to this path returns JSON
      ``{"token": "...", "expires_in": <seconds>}``; pass that value as ``?token=...`` when
      opening the WebSocket. It must differ from ``path`` (the WS upgrade path). A client
      living in the **same process** and asyncio loop as nanobot must fetch the token from
      a thread or with an async HTTP client; blocking ``urllib`` or synchronous ``httpx``
      inside a coroutine would stall the loop that has to answer the request.
    - ``token_issue_secret``: when non-empty, token requests must carry
      ``Authorization: Bearer <secret>`` or ``X-Nanobot-Auth: <secret>``.
    - ``token_ttl_s``: lifetime, in seconds, of a token handed out by ``token_issue_path``.
    - ``websocket_requires_token``: when True, the handshake is rejected unless it carries a
      valid token (the static one, or an issued one that has not expired).
    - Every connection is its own session: it receives a unique ``chat_id``.
    - The ``media`` field of outbound messages holds local filesystem paths; remote clients
      need a shared filesystem or an HTTP file server to read those files.
    """

    # --- listener -----------------------------------------------------------
    enabled: bool = False
    host: str = "127.0.0.1"
    port: int = 8765
    path: str = "/"

    # --- authentication -----------------------------------------------------
    token: str = ""
    token_issue_path: str = ""
    token_issue_secret: str = ""
    token_ttl_s: int = Field(default=300, ge=30, le=86_400)
    websocket_requires_token: bool = False
    allow_from: list[str] = Field(default_factory=lambda: ["*"])

    # --- transport ----------------------------------------------------------
    streaming: bool = True
    max_message_bytes: int = Field(default=1_048_576, ge=1024, le=16_777_216)
    ping_interval_s: float = Field(default=20.0, ge=5.0, le=300.0)
    ping_timeout_s: float = Field(default=20.0, ge=5.0, le=300.0)

    # --- TLS (both or neither must be set) ------------------------------------
    ssl_certfile: str = ""
    ssl_keyfile: str = ""

    @field_validator("path")
    @classmethod
    def path_must_start_with_slash(cls, value: str) -> str:
        """Reject a WebSocket path that is not rooted at ``/``."""
        if value.startswith("/"):
            return value
        raise ValueError('path must start with "/"')

    @field_validator("token_issue_path")
    @classmethod
    def token_issue_path_format(cls, value: str) -> str:
        """Trim and normalize the token route; a blank value disables the route."""
        trimmed = value.strip()
        if not trimmed:
            return ""
        if not trimmed.startswith("/"):
            raise ValueError('token_issue_path must start with "/"')
        return _normalize_config_path(trimmed)

    @model_validator(mode="after")
    def token_issue_path_differs_from_ws_path(self) -> Self:
        """Ensure the token route and the WebSocket upgrade route do not collide."""
        issue_path = self.token_issue_path
        if issue_path and _normalize_config_path(issue_path) == _normalize_config_path(self.path):
            raise ValueError("token_issue_path must differ from path (the WebSocket upgrade path)")
        return self
|
|
|
|
|
|
def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
    """Build a complete HTTP response whose body is *data* serialized as UTF-8 JSON.

    The reason phrase is looked up from ``http.HTTPStatus``, so *status* must be a
    status code known to the standard library.
    """
    payload = json.dumps(data, ensure_ascii=False).encode("utf-8")
    header_items = [
        ("Date", email.utils.formatdate(usegmt=True)),
        ("Connection", "close"),
        ("Content-Length", str(len(payload))),
        ("Content-Type", "application/json; charset=utf-8"),
    ]
    return Response(status, http.HTTPStatus(status).phrase, Headers(header_items), payload)
|
|
|
|
|
|
def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]:
    """Split a request target into its normalized path and its decoded query parameters."""
    # A dummy scheme and host let urlparse treat the bare request target as a full URL.
    url = urlparse(f"ws://x{path_with_query}")
    normalized_path = _strip_trailing_slash(url.path or "/")
    query_params = parse_qs(url.query)
    return normalized_path, query_params
|
|
|
|
|
|
def _normalize_http_path(path_with_query: str) -> str:
    """Return only the path of a request target, trailing slash removed (root stays ``/``)."""
    normalized_path, _ = _parse_request_path(path_with_query)
    return normalized_path
|
|
|
|
|
|
def _parse_query(path_with_query: str) -> dict[str, list[str]]:
    """Return the query parameters of a request target as a name -> list-of-values mapping."""
    _, query_params = _parse_request_path(path_with_query)
    return query_params
|
|
|
|
|
|
def _parse_inbound_payload(raw: str) -> str | None:
|
|
"""Parse a client frame into text; return None for empty or unrecognized content."""
|
|
text = raw.strip()
|
|
if not text:
|
|
return None
|
|
if text.startswith("{"):
|
|
try:
|
|
data = json.loads(text)
|
|
except json.JSONDecodeError:
|
|
return text
|
|
if isinstance(data, dict):
|
|
for key in ("content", "text", "message"):
|
|
value = data.get(key)
|
|
if isinstance(value, str) and value.strip():
|
|
return value
|
|
return None
|
|
return None
|
|
return text
|
|
|
|
|
|
def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool:
|
|
"""Return True if the token-issue HTTP request carries credentials matching ``token_issue_secret``."""
|
|
if not configured_secret:
|
|
return True
|
|
authorization = headers.get("Authorization") or headers.get("authorization")
|
|
if authorization and authorization.lower().startswith("bearer "):
|
|
supplied = authorization[7:].strip()
|
|
return hmac.compare_digest(supplied, configured_secret)
|
|
header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth")
|
|
if not header_token:
|
|
return False
|
|
return hmac.compare_digest(header_token.strip(), configured_secret)
|
|
|
|
|
|
class WebSocketChannel(BaseChannel):
    """Run a WebSocket server and bridge connected clients to the message bus.

    Every accepted connection is registered under a fresh ``chat_id``. Inbound frames
    are passed to ``self._handle_message`` (not defined in this class), and outbound
    messages are routed to the connection registered under their ``chat_id``.
    """

    name = "websocket"
    display_name = "WebSocket"

    # Upper bound on outstanding (issued, unused, unexpired) tokens, so that a token
    # route -- which may be unauthenticated -- cannot grow memory without limit.
    _MAX_ISSUED_TOKENS = 10_000

    def __init__(self, config: Any, bus: MessageBus):
        """Accept a ``WebSocketConfig`` or a plain dict that validates into one."""
        if isinstance(config, dict):
            config = WebSocketConfig.model_validate(config)
        super().__init__(config, bus)
        self.config: WebSocketConfig = config
        # chat_id -> live connection object
        self._connections: dict[str, Any] = {}
        # issued token -> time.monotonic() deadline
        self._issued_tokens: dict[str, float] = {}
        self._stop_event: asyncio.Event | None = None
        self._server_task: asyncio.Task[None] | None = None

    @classmethod
    def default_config(cls) -> dict[str, Any]:
        """Return the default configuration as a dict keyed by field alias."""
        return WebSocketConfig().model_dump(by_alias=True)

    @staticmethod
    def _tokens_equal(supplied: str, expected: str) -> bool:
        """Compare two secrets in constant time without raising on non-ASCII input."""
        try:
            # compare_digest() raises TypeError for non-ASCII str, so compare bytes.
            return hmac.compare_digest(
                supplied.encode("utf-8", "surrogateescape"),
                expected.encode("utf-8", "surrogateescape"),
            )
        except UnicodeEncodeError:
            return False

    def _expected_path(self) -> str:
        """Return the configured WebSocket upgrade path in normalized form."""
        return _normalize_config_path(self.config.path)

    def _build_ssl_context(self) -> ssl.SSLContext | None:
        """Create the server TLS context, or return None when TLS is not configured.

        Raises:
            ValueError: If only one of ``ssl_certfile`` / ``ssl_keyfile`` is set.
        """
        cert = self.config.ssl_certfile.strip()
        key = self.config.ssl_keyfile.strip()
        if not cert and not key:
            return None
        if not cert or not key:
            raise ValueError(
                "websocket: ssl_certfile and ssl_keyfile must both be set for WSS, or both left empty"
            )
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        # Refuse legacy protocol versions regardless of the interpreter's defaults.
        ctx.minimum_version = ssl.TLSVersion.TLSv1_2
        ctx.load_cert_chain(certfile=cert, keyfile=key)
        return ctx

    def _purge_expired_issued_tokens(self) -> None:
        """Forget every issued token whose deadline has passed."""
        now = time.monotonic()
        expired = [token for token, deadline in self._issued_tokens.items() if now > deadline]
        for token in expired:
            del self._issued_tokens[token]

    def _take_issued_token_if_valid(self, token_value: str | None) -> bool:
        """Validate and consume one issued token (single use per connection attempt)."""
        if not token_value:
            return False
        self._purge_expired_issued_tokens()
        # pop() removes the token whether or not it is still valid, so it can never
        # be presented twice.
        deadline = self._issued_tokens.pop(token_value, None)
        return deadline is not None and time.monotonic() <= deadline

    def _handle_token_issue_http(self, connection: Any, request: Any) -> Any:
        """Answer a request on the token route with a fresh short-lived token.

        Returns a 401 response when ``token_issue_secret`` is set and not matched, a 429
        response when too many tokens are already outstanding, and otherwise a JSON
        response ``{"token": ..., "expires_in": ...}``.
        """
        secret = self.config.token_issue_secret.strip()
        if secret:
            if not _issue_route_secret_matches(request.headers, secret):
                return connection.respond(401, "Unauthorized")
        else:
            logger.warning(
                "websocket: token_issue_path is set but token_issue_secret is empty; "
                "any client can obtain connection tokens — set token_issue_secret for production."
            )

        self._purge_expired_issued_tokens()
        if len(self._issued_tokens) >= self._MAX_ISSUED_TOKENS:
            logger.error(
                "websocket: {} issued tokens outstanding; refusing to issue more",
                len(self._issued_tokens),
            )
            return connection.respond(429, "Too Many Requests")

        token_value = f"nbwt_{secrets.token_urlsafe(32)}"
        self._issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s)
        return _http_json_response(
            {"token": token_value, "expires_in": self.config.token_ttl_s}
        )

    def _authorize_websocket_handshake(self, connection: Any, request_path: str) -> Any:
        """Decide whether a WebSocket upgrade may proceed.

        Returns None to accept the handshake, or a 401 response to reject it.
        """
        supplied = (_parse_query(request_path).get("token") or [None])[0]
        static_token = self.config.token.strip()

        if static_token:
            # Constant-time comparison: the static token is a long-lived secret.
            if supplied is not None and self._tokens_equal(supplied, static_token):
                return None
            if supplied and self._take_issued_token_if_valid(supplied):
                return None
            return connection.respond(401, "Unauthorized")

        if self.config.websocket_requires_token:
            if supplied and self._take_issued_token_if_valid(supplied):
                return None
            return connection.respond(401, "Unauthorized")

        # Tokens are optional here, but a presented issued token is still consumed
        # so it cannot be replayed later.
        if supplied:
            self._take_issued_token_if_valid(supplied)
        return None

    async def start(self) -> None:
        """Serve WebSocket clients until ``stop()`` is called.

        Raises:
            ValueError: If the TLS configuration is incomplete.
        """
        # Validate TLS first so a bad configuration fails before any state changes.
        ssl_context = self._build_ssl_context()
        scheme = "wss" if ssl_context else "ws"

        stop_event = asyncio.Event()
        self._stop_event = stop_event
        self._running = True

        async def process_request(
            connection: ServerConnection,
            request: WsRequest,
        ) -> Any:
            # Route by path: the token route answers over plain HTTP, the WebSocket
            # path goes through authorization, everything else is a 404.
            got, _ = _parse_request_path(request.path)
            if self.config.token_issue_path:
                issue_expected = _normalize_config_path(self.config.token_issue_path)
                if got == issue_expected:
                    return self._handle_token_issue_http(connection, request)

            if got != self._expected_path():
                return connection.respond(404, "Not Found")
            return self._authorize_websocket_handshake(connection, request.path)

        logger.info(
            "WebSocket server listening on {}://{}:{}{}",
            scheme,
            self.config.host,
            self.config.port,
            self.config.path,
        )
        if self.config.token_issue_path:
            logger.info(
                "WebSocket token issue route: {}://{}:{}{}",
                scheme,
                self.config.host,
                self.config.port,
                _normalize_config_path(self.config.token_issue_path),
            )

        async def runner() -> None:
            async with serve(
                self._connection_loop,
                self.config.host,
                self.config.port,
                process_request=process_request,
                max_size=self.config.max_message_bytes,
                ping_interval=self.config.ping_interval_s,
                ping_timeout=self.config.ping_timeout_s,
                ssl=ssl_context,
            ):
                await stop_event.wait()

        self._server_task = asyncio.create_task(runner())
        try:
            await self._server_task
        finally:
            # Whatever ended the server (stop, bind failure, cancellation), it is
            # no longer running.
            self._running = False

    async def _connection_loop(self, connection: Any) -> None:
        """Serve one client: announce its ids, then forward each frame it sends."""
        request = connection.request
        path_part = request.path if request else "/"
        _, query = _parse_request_path(path_part)
        client_id = ((query.get("client_id") or [None])[0] or "").strip()
        if not client_id:
            client_id = f"anon-{uuid.uuid4().hex[:12]}"
            logger.debug("websocket: no client_id supplied; generated {}", client_id)

        chat_id = str(uuid.uuid4())
        self._connections[chat_id] = connection

        try:
            # The greeting is sent inside the try block so that a client which vanishes
            # right after the handshake still gets its registry entry removed below.
            await connection.send(
                json.dumps(
                    {
                        "event": "ready",
                        "chat_id": chat_id,
                        "client_id": client_id,
                    },
                    ensure_ascii=False,
                )
            )

            async for raw in connection:
                if isinstance(raw, bytes):
                    try:
                        raw = raw.decode("utf-8")
                    except UnicodeDecodeError:
                        logger.warning("websocket: ignoring non-utf8 binary frame")
                        continue
                content = _parse_inbound_payload(raw)
                if content is None:
                    continue
                await self._handle_message(
                    sender_id=client_id,
                    chat_id=chat_id,
                    content=content,
                    metadata={"remote": getattr(connection, "remote_address", None)},
                )
        except ConnectionClosed as e:
            logger.debug("websocket connection ended: {}", e)
        except Exception as e:
            # Unexpected failures end this connection but must stay visible in the logs.
            logger.exception("websocket connection handler failed: {}", e)
        finally:
            self._connections.pop(chat_id, None)

    async def stop(self) -> None:
        """Signal the server to shut down, wait for it, and drop all per-run state."""
        self._running = False
        if self._stop_event is not None:
            self._stop_event.set()
        task, self._server_task = self._server_task, None
        try:
            if task is not None:
                await task
        finally:
            # Clear state even when the server task finished with an error.
            self._connections.clear()
            self._issued_tokens.clear()

    async def send(self, msg: OutboundMessage) -> None:
        """Deliver a complete outbound message to the connection owning ``msg.chat_id``."""
        connection = self._connections.get(msg.chat_id)
        if connection is None:
            logger.warning("websocket: no active connection for chat_id={}", msg.chat_id)
            return

        payload: dict[str, Any] = {
            "event": "message",
            "text": msg.content,
        }
        if msg.media:
            payload["media"] = msg.media
        if msg.reply_to:
            payload["reply_to"] = msg.reply_to

        try:
            await connection.send(json.dumps(payload, ensure_ascii=False))
        except ConnectionClosed:
            self._connections.pop(msg.chat_id, None)
            logger.warning("websocket: connection gone for chat_id={}", msg.chat_id)
        except Exception as e:
            logger.error("websocket send failed: {}", e)
            raise

    async def send_delta(
        self,
        chat_id: str,
        delta: str,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Send one streaming chunk, or the end-of-stream marker, to ``chat_id``.

        ``metadata["_stream_end"]`` selects the ``stream_end`` event instead of a
        ``delta`` event; ``metadata["_stream_id"]``, when present, is echoed as
        ``stream_id``. Unknown ``chat_id`` values are ignored silently.
        """
        connection = self._connections.get(chat_id)
        if connection is None:
            return

        meta = metadata or {}
        body: dict[str, Any]
        if meta.get("_stream_end"):
            body = {"event": "stream_end"}
        else:
            body = {"event": "delta", "text": delta}
        stream_id = meta.get("_stream_id")
        if stream_id is not None:
            body["stream_id"] = stream_id

        try:
            await connection.send(json.dumps(body, ensure_ascii=False))
        except ConnectionClosed:
            self._connections.pop(chat_id, None)
            logger.warning("websocket: stream connection gone for chat_id={}", chat_id)
        except Exception as e:
            logger.error("websocket stream send failed: {}", e)
            raise
|