mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 00:22:31 +00:00
1346 lines
54 KiB
Python
1346 lines
54 KiB
Python
"""WebSocket server channel: nanobot acts as a WebSocket server and serves connected clients."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import base64
|
||
import binascii
|
||
import email.utils
|
||
import hashlib
|
||
import hmac
|
||
import http
|
||
import json
|
||
import mimetypes
|
||
import re
|
||
import secrets
|
||
import shutil
|
||
import ssl
|
||
import time
|
||
import uuid
|
||
from pathlib import Path
|
||
from typing import TYPE_CHECKING, Any, Self
|
||
from urllib.parse import parse_qs, unquote, urlparse
|
||
|
||
from loguru import logger
|
||
from pydantic import Field, field_validator, model_validator
|
||
from websockets.asyncio.server import ServerConnection, serve
|
||
from websockets.datastructures import Headers
|
||
from websockets.exceptions import ConnectionClosed
|
||
from websockets.http11 import Request as WsRequest
|
||
from websockets.http11 import Response
|
||
|
||
from nanobot.bus.events import OutboundMessage
|
||
from nanobot.bus.queue import MessageBus
|
||
from nanobot.channels.base import BaseChannel
|
||
from nanobot.command.builtin import builtin_command_palette
|
||
from nanobot.config.paths import get_media_dir
|
||
from nanobot.config.schema import Base
|
||
from nanobot.utils.helpers import safe_filename
|
||
from nanobot.utils.media_decode import (
|
||
FileSizeExceeded,
|
||
save_base64_data_url,
|
||
)
|
||
|
||
if TYPE_CHECKING:
|
||
from nanobot.session.manager import SessionManager
|
||
|
||
|
||
def _strip_trailing_slash(path: str) -> str:
|
||
if len(path) > 1 and path.endswith("/"):
|
||
return path.rstrip("/")
|
||
return path or "/"
|
||
|
||
|
||
def _normalize_config_path(path: str) -> str:
    """Normalize a configured route path exactly like an inbound request path.

    Delegates to :func:`_strip_trailing_slash` so config values and request
    paths compare equal regardless of trailing slashes.
    """
    normalized = _strip_trailing_slash(path)
    return normalized
|
||
|
||
|
||
def _append_buttons_as_text(text: str, buttons: list[list[str]]) -> str:
|
||
labels = [label for row in buttons for label in row if label]
|
||
if not labels:
|
||
return text
|
||
fallback = "\n".join(f"{index}. {label}" for index, label in enumerate(labels, 1))
|
||
return f"{text}\n\n{fallback}" if text else fallback
|
||
|
||
|
||
class WebSocketConfig(Base):
    """WebSocket server channel configuration.

    Clients connect with URLs like ``ws://{host}:{port}{path}?client_id=...&token=...``.

    - ``client_id``: Used for ``allow_from`` authorization; if omitted, a value is generated and logged.
    - ``token``: If non-empty, the ``token`` query param may match this static secret; short-lived tokens
      from ``token_issue_path`` are also accepted.
    - ``token_issue_path``: If non-empty, **GET** (HTTP/1.1) to this path returns JSON
      ``{"token": "...", "expires_in": <seconds>}``; use ``?token=...`` when opening the WebSocket.
      Must differ from ``path`` (the WS upgrade path). If the client runs in the **same process** as
      nanobot and shares the asyncio loop, use a thread or async HTTP client for GET—do not call
      blocking ``urllib`` or synchronous ``httpx`` from inside a coroutine.
    - ``token_issue_secret``: If non-empty, token requests must send ``Authorization: Bearer <secret>`` or
      ``X-Nanobot-Auth: <secret>``.
    - ``websocket_requires_token``: If True, the handshake must include a valid token (static or issued and not expired).
    - ``host``: Wildcard binds (``0.0.0.0``, ``::``, or an empty string, which asyncio's
      ``create_server`` treats as "all interfaces") are rejected unless ``token`` or
      ``token_issue_secret`` is set.
    - Each connection has its own session: a unique ``chat_id`` maps to the agent session internally.
    - ``media`` field in outbound messages contains local filesystem paths; remote clients need a
      shared filesystem or an HTTP file server to access these files.
    """

    enabled: bool = False
    host: str = "127.0.0.1"
    port: int = 8765
    path: str = "/"
    token: str = ""
    token_issue_path: str = ""
    token_issue_secret: str = ""
    token_ttl_s: int = Field(default=300, ge=30, le=86_400)
    websocket_requires_token: bool = True
    allow_from: list[str] = Field(default_factory=lambda: ["*"])
    streaming: bool = True
    # Default 36 MB, upper 40 MB: supports up to 4 images at ~6 MB each after
    # client-side Worker normalization (see webui Composer). 4 × 6 MB × 1.37
    # (base64 overhead) + envelope framing stays under 36 MB; the 40 MB ceiling
    # leaves a small margin for sender slop without opening a DoS avenue.
    max_message_bytes: int = Field(default=37_748_736, ge=1024, le=41_943_040)
    ping_interval_s: float = Field(default=20.0, ge=5.0, le=300.0)
    ping_timeout_s: float = Field(default=20.0, ge=5.0, le=300.0)
    ssl_certfile: str = ""
    ssl_keyfile: str = ""

    @field_validator("path")
    @classmethod
    def path_must_start_with_slash(cls, value: str) -> str:
        """Require an absolute upgrade path and normalize its trailing slash."""
        if not value.startswith("/"):
            raise ValueError('path must start with "/"')
        return _normalize_config_path(value)

    @field_validator("token_issue_path")
    @classmethod
    def token_issue_path_format(cls, value: str) -> str:
        """Strip whitespace; an empty value disables the route, otherwise it must be absolute."""
        value = value.strip()
        if not value:
            return ""
        if not value.startswith("/"):
            raise ValueError('token_issue_path must start with "/"')
        return _normalize_config_path(value)

    @model_validator(mode="after")
    def token_issue_path_differs_from_ws_path(self) -> Self:
        """Reject a token-issue route that would shadow the WebSocket upgrade path."""
        if not self.token_issue_path:
            return self
        if _normalize_config_path(self.token_issue_path) == _normalize_config_path(self.path):
            raise ValueError("token_issue_path must differ from path (the WebSocket upgrade path)")
        return self

    @model_validator(mode="after")
    def wildcard_host_requires_auth(self) -> Self:
        """Refuse to listen on every interface unless some shared secret is configured.

        ``0.0.0.0`` and ``::`` are explicit wildcards. An empty host is treated
        the same way because asyncio's ``create_server`` binds all interfaces
        when ``host`` is ``""``; previously that slipped past this check.

        Raises:
            ValueError: a wildcard host is used with neither ``token`` nor
                ``token_issue_secret`` set.
        """
        if self.host not in ("0.0.0.0", "::", ""):
            return self
        if self.token.strip() or self.token_issue_secret.strip():
            return self
        # Show the actual host so "::" / "" are not misreported as 0.0.0.0;
        # the message for 0.0.0.0 itself is unchanged.
        shown_host = self.host or '""'
        raise ValueError(
            f"host is {shown_host} (all interfaces) but neither token nor "
            "token_issue_secret is set — set one to prevent unauthenticated access"
        )
|
||
|
||
|
||
def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
    """Build a one-shot JSON HTTP response (``Connection: close``).

    *data* is serialized with non-ASCII characters kept as-is and encoded as
    UTF-8. The reason phrase is derived from *status*.
    """
    payload = json.dumps(data, ensure_ascii=False).encode("utf-8")
    header_pairs = [
        ("Date", email.utils.formatdate(usegmt=True)),
        ("Connection", "close"),
        ("Content-Length", str(len(payload))),
        ("Content-Type", "application/json; charset=utf-8"),
    ]
    return Response(status, http.HTTPStatus(status).phrase, Headers(header_pairs), payload)
|
||
|
||
|
||
def _read_webui_model_name() -> str | None:
    """Look up the configured default model name for read-only display in the webui.

    Best effort: any failure while importing or loading the config is logged
    at debug level and reported as ``None``. A blank model name is also ``None``.
    """
    try:
        from nanobot.config.loader import load_config

        configured = load_config().agents.defaults.model.strip()
    except Exception as e:
        logger.debug("webui bootstrap could not load model name: {}", e)
        return None
    return configured if configured else None
|
||
|
||
|
||
def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]:
    """Split a request target into ``(normalized_path, query_params)``.

    The path has trailing slashes normalized via :func:`_strip_trailing_slash`.
    The query is decoded with :func:`urllib.parse.parse_qs`, so every key maps
    to a list of values.
    """
    # urlparse needs a scheme and netloc to treat the input as a path; the
    # throwaway "ws://x" prefix supplies both.
    components = urlparse(f"ws://x{path_with_query}")
    normalized = _strip_trailing_slash(components.path or "/")
    params = parse_qs(components.query)
    return normalized, params
|
||
|
||
|
||
def _normalize_http_path(path_with_query: str) -> str:
    """Return just the normalized path of a request target, dropping any query string."""
    path, _query = _parse_request_path(path_with_query)
    return path
|
||
|
||
|
||
def _parse_query(path_with_query: str) -> dict[str, list[str]]:
    """Return only the decoded query parameters of a request target."""
    _path, query = _parse_request_path(path_with_query)
    return query
|
||
|
||
|
||
def _query_first(query: dict[str, list[str]], key: str) -> str | None:
|
||
"""Return the first value for *key*, or None."""
|
||
values = query.get(key)
|
||
return values[0] if values else None
|
||
|
||
|
||
def _parse_inbound_payload(raw: str) -> str | None:
|
||
"""Parse a client frame into text; return None for empty or unrecognized content."""
|
||
text = raw.strip()
|
||
if not text:
|
||
return None
|
||
if text.startswith("{"):
|
||
try:
|
||
data = json.loads(text)
|
||
except json.JSONDecodeError:
|
||
return text
|
||
if isinstance(data, dict):
|
||
for key in ("content", "text", "message"):
|
||
value = data.get(key)
|
||
if isinstance(value, str) and value.strip():
|
||
return value
|
||
return None
|
||
return None
|
||
return text
|
||
|
||
|
||
# Accept UUIDs and short scoped keys like "unified:default". Keeps the capability
|
||
# namespace small enough to rule out path traversal / quote injection tricks.
|
||
_CHAT_ID_RE = re.compile(r"^[A-Za-z0-9_:-]{1,64}$")
|
||
|
||
|
||
def _is_valid_chat_id(value: Any) -> bool:
|
||
return isinstance(value, str) and _CHAT_ID_RE.match(value) is not None
|
||
|
||
|
||
def _parse_envelope(raw: str) -> dict[str, Any] | None:
|
||
"""Return a typed envelope dict if the frame is a new-style JSON envelope, else None.
|
||
|
||
A frame qualifies when it parses as a JSON object with a string ``type`` field.
|
||
Legacy frames (plain text, or ``{"content": ...}`` without ``type``) return None;
|
||
callers should fall back to :func:`_parse_inbound_payload` for those.
|
||
"""
|
||
text = raw.strip()
|
||
if not text.startswith("{"):
|
||
return None
|
||
try:
|
||
data = json.loads(text)
|
||
except json.JSONDecodeError:
|
||
return None
|
||
if not isinstance(data, dict):
|
||
return None
|
||
t = data.get("type")
|
||
if not isinstance(t, str):
|
||
return None
|
||
return data
|
||
|
||
|
||
# Per-message media limits. The server-side guard is a touch looser than the
|
||
# client's ``Worker`` normalization target (6 MB) — tolerate client slop, but
|
||
# still cap total ingress at ``_MAX_IMAGES_PER_MESSAGE * _MAX_IMAGE_BYTES``
|
||
# which fits comfortably inside ``max_message_bytes``.
|
||
_MAX_IMAGES_PER_MESSAGE = 4
|
||
_MAX_IMAGE_BYTES = 8 * 1024 * 1024
|
||
_MAX_VIDEOS_PER_MESSAGE = 1
|
||
_MAX_VIDEO_BYTES = 20 * 1024 * 1024
|
||
|
||
# Image MIME whitelist — matches the Composer's ``accept`` list. SVG is
|
||
# explicitly excluded to avoid the XSS surface inside embedded scripts.
|
||
_IMAGE_MIME_ALLOWED: frozenset[str] = frozenset({
|
||
"image/png",
|
||
"image/jpeg",
|
||
"image/webp",
|
||
"image/gif",
|
||
})
|
||
|
||
_VIDEO_MIME_ALLOWED: frozenset[str] = frozenset({
|
||
"video/mp4",
|
||
"video/webm",
|
||
"video/quicktime",
|
||
})
|
||
|
||
_UPLOAD_MIME_ALLOWED: frozenset[str] = _IMAGE_MIME_ALLOWED | _VIDEO_MIME_ALLOWED
|
||
|
||
_DATA_URL_MIME_RE = re.compile(r"^data:([^;]+);base64,", re.DOTALL)
|
||
|
||
|
||
def _extract_data_url_mime(url: str) -> str | None:
    """Pull the lower-cased MIME type out of a ``data:<mime>;base64,`` URL.

    Returns ``None`` for non-string input, URLs lacking that prefix, or a MIME
    part that is blank once stripped.
    """
    if isinstance(url, str):
        hit = _DATA_URL_MIME_RE.match(url)
        if hit is not None:
            mime = hit.group(1).strip().lower()
            return mime if mime else None
    return None
|
||
|
||
|
||
_LOCALHOSTS = frozenset({"127.0.0.1", "::1", "localhost"})
|
||
|
||
# Matches the legacy chat-id pattern but allows file-system-safe stems too,
|
||
# so the API can address sessions whose keys came from non-WebSocket channels.
|
||
_API_KEY_RE = re.compile(r"^[A-Za-z0-9_:.-]{1,128}$")
|
||
|
||
|
||
def _decode_api_key(raw_key: str) -> str | None:
|
||
"""Decode a percent-encoded API path segment, then validate the result."""
|
||
key = unquote(raw_key)
|
||
if _API_KEY_RE.match(key) is None:
|
||
return None
|
||
return key
|
||
|
||
|
||
def _is_localhost(connection: Any) -> bool:
|
||
"""Return True if *connection* originated from the loopback interface."""
|
||
addr = getattr(connection, "remote_address", None)
|
||
if not addr:
|
||
return False
|
||
host = addr[0] if isinstance(addr, tuple) else addr
|
||
if not isinstance(host, str):
|
||
return False
|
||
# ``::ffff:127.0.0.1`` is loopback in IPv6-mapped form.
|
||
if host.startswith("::ffff:"):
|
||
host = host[7:]
|
||
return host in _LOCALHOSTS
|
||
|
||
|
||
def _http_response(
    body: bytes,
    *,
    status: int = 200,
    content_type: str = "text/plain; charset=utf-8",
    extra_headers: list[tuple[str, str]] | None = None,
) -> Response:
    """Wrap *body* in a one-shot HTTP response (``Connection: close``).

    The standard headers (Date, Connection, Content-Length, Content-Type) come
    first, followed by any *extra_headers* in the order given.
    """
    header_list: list[tuple[str, str]] = [
        ("Date", email.utils.formatdate(usegmt=True)),
        ("Connection", "close"),
        ("Content-Length", str(len(body))),
        ("Content-Type", content_type),
    ]
    header_list += extra_headers or []
    return Response(status, http.HTTPStatus(status).phrase, Headers(header_list), body)
|
||
|
||
|
||
def _http_error(status: int, message: str | None = None) -> Response:
    """Return a plain-text error response; an empty *message* falls back to the reason phrase."""
    text = message if message else http.HTTPStatus(status).phrase
    return _http_response(text.encode("utf-8"), status=status)
|
||
|
||
|
||
def _bearer_token(headers: Any) -> str | None:
|
||
"""Pull a Bearer token out of standard or query-style headers."""
|
||
auth = headers.get("Authorization") or headers.get("authorization")
|
||
if auth and auth.lower().startswith("bearer "):
|
||
return auth[7:].strip() or None
|
||
return None
|
||
|
||
|
||
def _is_websocket_upgrade(request: WsRequest) -> bool:
|
||
"""Detect an actual WS upgrade; plain HTTP GETs to the same path should fall through."""
|
||
upgrade = request.headers.get("Upgrade") or request.headers.get("upgrade")
|
||
connection = request.headers.get("Connection") or request.headers.get("connection")
|
||
if not upgrade or "websocket" not in upgrade.lower():
|
||
return False
|
||
if not connection or "upgrade" not in connection.lower():
|
||
return False
|
||
return True
|
||
|
||
|
||
def _b64url_encode(data: bytes) -> str:
|
||
"""URL-safe base64 without padding — compact + friendly in URL paths."""
|
||
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
|
||
|
||
|
||
def _b64url_decode(s: str) -> bytes:
|
||
"""Reverse of :func:`_b64url_encode`; caller handles ``ValueError``."""
|
||
pad = "=" * (-len(s) % 4)
|
||
return base64.urlsafe_b64decode(s + pad)
|
||
|
||
|
||
# Allowed MIME types we actually serve from the media endpoint. Anything
|
||
# outside this set is degraded to ``application/octet-stream`` so an
|
||
# attacker who somehow gets a signed URL for an unexpected file type can't
|
||
# trick the browser into sniffing executable content.
|
||
_MEDIA_ALLOWED_MIMES: frozenset[str] = frozenset({
|
||
"image/png",
|
||
"image/jpeg",
|
||
"image/webp",
|
||
"image/gif",
|
||
"video/mp4",
|
||
"video/webm",
|
||
"video/quicktime",
|
||
})
|
||
|
||
|
||
def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool:
|
||
"""Return True if the token-issue HTTP request carries credentials matching ``token_issue_secret``."""
|
||
if not configured_secret:
|
||
return True
|
||
authorization = headers.get("Authorization") or headers.get("authorization")
|
||
if authorization and authorization.lower().startswith("bearer "):
|
||
supplied = authorization[7:].strip()
|
||
return hmac.compare_digest(supplied, configured_secret)
|
||
header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth")
|
||
if not header_token:
|
||
return False
|
||
return hmac.compare_digest(header_token.strip(), configured_secret)
|
||
|
||
|
||
class WebSocketChannel(BaseChannel):
|
||
"""Run a local WebSocket server; forward text/JSON messages to the message bus."""
|
||
|
||
# Registry identifier and human-readable label for this channel.
name, display_name = "websocket", "WebSocket"
|
||
|
||
def __init__(
    self,
    config: Any,
    bus: MessageBus,
    *,
    session_manager: "SessionManager | None" = None,
    static_dist_path: Path | None = None,
):
    """Initialize channel state.

    Args:
        config: Channel settings. A plain ``dict`` is validated into a
            :class:`WebSocketConfig`; any other value is used as-is.
        bus: Message bus passed through to :class:`BaseChannel`.
        session_manager: Optional session store used by the ``/api/sessions``
            REST routes.
        static_dist_path: Optional webui build directory; stored resolved.
    """
    validated = WebSocketConfig.model_validate(config) if isinstance(config, dict) else config
    super().__init__(validated, bus)
    self.config: WebSocketConfig = validated

    # Subscription indexes, kept in sync in both directions:
    #   chat_id -> connections receiving its fan-out,
    #   connection -> chat_ids it joined (cheap cleanup on disconnect),
    #   connection -> chat_id used for legacy frames that omit routing.
    self._subs: dict[str, set[Any]] = {}
    self._conn_chats: dict[Any, set[str]] = {}
    self._conn_default: dict[Any, str] = {}

    # token -> monotonic expiry. Handshake tokens are consumed on first use;
    # API tokens for the webui REST surface are re-checked until they expire.
    self._issued_tokens: dict[str, float] = {}
    self._api_tokens: dict[str, float] = {}

    self._stop_event: asyncio.Event | None = None
    self._server_task: asyncio.Task[None] | None = None
    self._session_manager = session_manager
    self._static_dist_path: Path | None = (
        None if static_dist_path is None else static_dist_path.resolve()
    )

    # Per-process HMAC key for signed media URLs. A signed URL is the
    # capability for exactly one file. The key is regenerated on every
    # restart, so previously issued links stop validating.
    self._media_secret: bytes = secrets.token_bytes(32)
|
||
|
||
# -- Subscription bookkeeping -------------------------------------------
|
||
|
||
def _attach(self, connection: Any, chat_id: str) -> None:
|
||
"""Idempotently subscribe *connection* to *chat_id*."""
|
||
self._subs.setdefault(chat_id, set()).add(connection)
|
||
self._conn_chats.setdefault(connection, set()).add(chat_id)
|
||
|
||
def _cleanup_connection(self, connection: Any) -> None:
|
||
"""Remove *connection* from every subscription set; safe to call multiple times."""
|
||
chat_ids = self._conn_chats.pop(connection, set())
|
||
for cid in chat_ids:
|
||
subs = self._subs.get(cid)
|
||
if subs is None:
|
||
continue
|
||
subs.discard(connection)
|
||
if not subs:
|
||
self._subs.pop(cid, None)
|
||
self._conn_default.pop(connection, None)
|
||
|
||
async def _send_event(self, connection: Any, event: str, **fields: Any) -> None:
|
||
"""Send a control event (attached, error, ...) to a single connection."""
|
||
payload: dict[str, Any] = {"event": event}
|
||
payload.update(fields)
|
||
raw = json.dumps(payload, ensure_ascii=False)
|
||
try:
|
||
await connection.send(raw)
|
||
except ConnectionClosed:
|
||
self._cleanup_connection(connection)
|
||
except Exception as e:
|
||
self.logger.warning("failed to send {} event: {}", event, e)
|
||
|
||
@classmethod
def default_config(cls) -> dict[str, Any]:
    """Return the channel's default settings serialized with field aliases."""
    defaults = WebSocketConfig()
    return defaults.model_dump(by_alias=True)
|
||
|
||
def _expected_path(self) -> str:
    """Return the normalized WebSocket upgrade path from the channel config."""
    configured = self.config.path
    return _normalize_config_path(configured)
|
||
|
||
def _build_ssl_context(self) -> ssl.SSLContext | None:
|
||
cert = self.config.ssl_certfile.strip()
|
||
key = self.config.ssl_keyfile.strip()
|
||
if not cert and not key:
|
||
return None
|
||
if not cert or not key:
|
||
raise ValueError(
|
||
"ssl_certfile and ssl_keyfile must both be set for WSS, or both left empty"
|
||
)
|
||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
|
||
ctx.load_cert_chain(certfile=cert, keyfile=key)
|
||
return ctx
|
||
|
||
_MAX_ISSUED_TOKENS = 10_000
|
||
|
||
def _purge_expired_issued_tokens(self) -> None:
|
||
now = time.monotonic()
|
||
for token_key, expiry in list(self._issued_tokens.items()):
|
||
if now > expiry:
|
||
self._issued_tokens.pop(token_key, None)
|
||
|
||
def _take_issued_token_if_valid(self, token_value: str | None) -> bool:
|
||
"""Validate and consume one issued token (single use per connection attempt).
|
||
|
||
Uses single-step pop to minimize the window between lookup and removal;
|
||
safe under asyncio's single-threaded cooperative model.
|
||
"""
|
||
if not token_value:
|
||
return False
|
||
self._purge_expired_issued_tokens()
|
||
expiry = self._issued_tokens.pop(token_value, None)
|
||
if expiry is None:
|
||
return False
|
||
if time.monotonic() > expiry:
|
||
return False
|
||
return True
|
||
|
||
def _handle_token_issue_http(self, connection: Any, request: Any) -> Any:
    """Serve the legacy token-issue route: mint a short-lived handshake token.

    - With ``token_issue_secret`` set, the request must carry matching credentials
      (401 otherwise).
    - Without one, issuance is open and a warning is logged on each request.
    - Returns 429 once ``_MAX_ISSUED_TOKENS`` unexpired tokens are outstanding.
    """
    configured_secret = self.config.token_issue_secret.strip()
    if not configured_secret:
        self.logger.warning(
            "token_issue_path is set but token_issue_secret is empty; "
            "any client can obtain connection tokens — set token_issue_secret for production."
        )
    elif not _issue_route_secret_matches(request.headers, configured_secret):
        return connection.respond(401, "Unauthorized")

    self._purge_expired_issued_tokens()
    outstanding = len(self._issued_tokens)
    if outstanding >= self._MAX_ISSUED_TOKENS:
        self.logger.error(
            "too many outstanding issued tokens ({}), rejecting issuance",
            outstanding,
        )
        return _http_json_response({"error": "too many outstanding tokens"}, status=429)

    ttl = self.config.token_ttl_s
    token_value = f"nbwt_{secrets.token_urlsafe(32)}"
    self._issued_tokens[token_value] = time.monotonic() + float(ttl)
    return _http_json_response({"token": token_value, "expires_in": ttl})
|
||
|
||
# -- HTTP dispatch ------------------------------------------------------
|
||
|
||
async def _dispatch_http(self, connection: Any, request: WsRequest) -> Any:
    """Route one inbound HTTP request; the first matching rule wins.

    Order:

    1. the optional legacy token-issue path (``token_issue_path``);
    2. ``/webui/bootstrap``;
    3. the REST surface (fixed ``/api/...`` paths, per-session routes, signed media);
    4. the WebSocket path, but only for real upgrade requests;
    5. the static SPA build, when a dist directory was provided.

    Anything left over gets ``404 Not Found``.
    """
    route, query = _parse_request_path(request.path)

    # 1. Legacy token-issue endpoint, only when configured.
    issue_path = self.config.token_issue_path
    if issue_path and route == _normalize_config_path(issue_path):
        return self._handle_token_issue_http(connection, request)

    # 2. WebUI bootstrap mints tokens for the embedded UI.
    if route == "/webui/bootstrap":
        return self._handle_webui_bootstrap(connection, request)

    # 3a. Fixed REST endpoints; each handler takes only the request.
    fixed_routes = {
        "/api/sessions": self._handle_sessions_list,
        "/api/settings": self._handle_settings,
        "/api/commands": self._handle_commands,
        "/api/settings/update": self._handle_settings_update,
    }
    fixed_handler = fixed_routes.get(route)
    if fixed_handler is not None:
        return fixed_handler(request)

    # 3b. Per-session endpoints. websockets' HTTP parser only accepts GET, so
    # deletion cannot use a real DELETE verb and is encoded in the path.
    session_routes = (
        (r"^/api/sessions/([^/]+)/messages$", self._handle_session_messages),
        (r"^/api/sessions/([^/]+)/delete$", self._handle_session_delete),
    )
    for pattern, session_handler in session_routes:
        session_match = re.match(pattern, route)
        if session_match:
            return session_handler(request, session_match.group(1))

    # 3c. Signed media fetch: /api/media/<sig>/<payload>, where <sig> is an
    # HMAC over <payload> (see _sign_media_path for how these URLs are built).
    media_match = re.match(r"^/api/media/([A-Za-z0-9_-]+)/([A-Za-z0-9_-]+)$", route)
    if media_match:
        signature, payload = media_match.groups()
        return self._handle_media_fetch(signature, payload)

    # 4. WebSocket upgrade. Gate only real upgrade requests, so a plain browser
    # GET of the same path can still fall through to the SPA's index.html.
    if route == self._expected_path() and _is_websocket_upgrade(request):
        client_id = (_query_first(query, "client_id") or "")[:128]
        if not self.is_allowed(client_id):
            return connection.respond(403, "Forbidden")
        return self._authorize_websocket_handshake(connection, query)

    # 5. Static SPA assets, when a build directory was wired in.
    if self._static_dist_path is not None:
        static_response = self._serve_static(route)
        if static_response is not None:
            return static_response

    return connection.respond(404, "Not Found")
|
||
|
||
# -- HTTP route handlers ------------------------------------------------
|
||
|
||
def _check_api_token(self, request: WsRequest) -> bool:
    """Authorize a REST call against the multi-use, TTL-bound API token pool.

    The token is taken from a Bearer ``Authorization`` header, or else from a
    ``token`` query parameter. Expired or unknown tokens are rejected, and an
    expired entry is evicted.
    """
    self._purge_expired_api_tokens()
    token = _bearer_token(request.headers)
    if not token:
        token = _query_first(_parse_query(request.path), "token")
    if not token:
        return False
    deadline = self._api_tokens.get(token)
    if deadline is not None and time.monotonic() <= deadline:
        return True
    self._api_tokens.pop(token, None)
    return False
|
||
|
||
def _purge_expired_api_tokens(self) -> None:
|
||
now = time.monotonic()
|
||
for token_key, expiry in list(self._api_tokens.items()):
|
||
if now > expiry:
|
||
self._api_tokens.pop(token_key, None)
|
||
|
||
def _handle_webui_bootstrap(self, connection: Any, request: Any) -> Response:
    """Mint a token for the embedded webui and report connection details.

    Access rules:

    - If ``token_issue_secret`` (or, failing that, the static ``token``) is
      set, the request must present it, whatever its source address. This
      keeps reverse-proxied deployments (where every peer looks like
      localhost) protected.
    - With no secret configured, only loopback clients are served.

    The same token string goes into both pools: the WebSocket handshake
    consumes the issued copy, while the REST API keeps accepting the API copy
    until the TTL expires. Returns 429 when either pool is full.
    """
    required_secret = self.config.token_issue_secret.strip() or self.config.token.strip()
    if not required_secret:
        if not _is_localhost(connection):
            return _http_error(403, "webui bootstrap is localhost-only")
    elif not _issue_route_secret_matches(request.headers, required_secret):
        return _http_error(401, "Unauthorized")

    self._purge_expired_issued_tokens()
    self._purge_expired_api_tokens()
    limit = self._MAX_ISSUED_TOKENS
    if len(self._issued_tokens) >= limit or len(self._api_tokens) >= limit:
        return _http_json_response({"error": "too many outstanding tokens"}, status=429)

    token = f"nbwt_{secrets.token_urlsafe(32)}"
    expiry = time.monotonic() + float(self.config.token_ttl_s)
    self._issued_tokens[token] = expiry
    self._api_tokens[token] = expiry
    return _http_json_response(
        {
            "token": token,
            "ws_path": self._expected_path(),
            "expires_in": self.config.token_ttl_s,
            "model_name": _read_webui_model_name(),
        }
    )
|
||
|
||
def _handle_sessions_list(self, request: WsRequest) -> Response:
    """List websocket-channel sessions for the webui sidebar.

    Requires a valid API token (401) and a session manager (503). Sessions
    from other channels (CLI, Slack, ...) cannot be resumed from the browser,
    so only ``websocket:`` keys are returned. Each entry's absolute ``path``
    field is stripped.
    """
    if not self._check_api_token(request):
        return _http_error(401, "Unauthorized")
    if self._session_manager is None:
        return _http_error(503, "session manager unavailable")
    visible: list[dict[str, Any]] = []
    for entry in self._session_manager.list_sessions():
        key = entry.get("key")
        if not (isinstance(key, str) and self._is_webui_session_key(key)):
            continue
        visible.append({name: value for name, value in entry.items() if name != "path"})
    return _http_json_response({"sessions": visible})
|
||
|
||
def _settings_payload(self, *, requires_restart: bool = False) -> dict[str, Any]:
    """Build the settings view shown by the webui from the on-disk config.

    ``provider`` is the explicitly selected provider's canonical name, or the
    auto-resolved one when the selection is ``"auto"`` or unknown.
    ``resolved_provider`` is what the model name resolves to. The provider
    list always starts with an ``auto`` entry.
    """
    from nanobot.config.loader import get_config_path, load_config
    from nanobot.providers.registry import PROVIDERS, find_by_name

    config = load_config()
    defaults = config.agents.defaults
    resolved = config.get_provider_name(defaults.model) or defaults.provider
    provider_cfg = config.get_provider(defaults.model)

    selected = resolved
    if defaults.provider != "auto":
        explicit_spec = find_by_name(defaults.provider)
        if explicit_spec:
            selected = explicit_spec.name

    provider_options = [{"name": "auto", "label": "Auto"}]
    provider_options.extend({"name": s.name, "label": s.label} for s in PROVIDERS)

    return {
        "agent": {
            "model": defaults.model,
            "provider": selected,
            "resolved_provider": resolved,
            "has_api_key": bool(provider_cfg and provider_cfg.api_key),
        },
        "providers": provider_options,
        "runtime": {"config_path": str(get_config_path().expanduser())},
        "requires_restart": requires_restart,
    }
|
||
|
||
def _handle_settings(self, request: WsRequest) -> Response:
    """Return the current settings payload; requires a valid API token."""
    if self._check_api_token(request):
        return _http_json_response(self._settings_payload())
    return _http_error(401, "Unauthorized")
|
||
|
||
def _handle_commands(self, request: WsRequest) -> Response:
    """Return the built-in command palette for the webui; requires a valid API token."""
    if self._check_api_token(request):
        return _http_json_response({"commands": builtin_command_palette()})
    return _http_error(401, "Unauthorized")
|
||
|
||
def _handle_settings_update(self, request: WsRequest) -> Response:
    """Apply ``model`` / ``provider`` query parameters to the saved config.

    Requires a valid API token (401).

    - A supplied ``model`` must be non-blank (400).
    - A blank ``provider`` means ``"auto"``; any other value must name a known
      provider (400).
    - The config is saved only when something actually changed, and the
      response's ``requires_restart`` reflects that.
    """
    if not self._check_api_token(request):
        return _http_error(401, "Unauthorized")

    from nanobot.config.loader import load_config, save_config
    from nanobot.providers.registry import find_by_name

    query = _parse_query(request.path)
    config = load_config()
    defaults = config.agents.defaults
    dirty = False

    requested_model = _query_first(query, "model")
    if requested_model is not None:
        requested_model = requested_model.strip()
        if not requested_model:
            return _http_error(400, "model is required")
        if requested_model != defaults.model:
            defaults.model = requested_model
            dirty = True

    requested_provider = _query_first(query, "provider")
    if requested_provider is not None:
        requested_provider = requested_provider.strip() or "auto"
        if requested_provider != "auto" and find_by_name(requested_provider) is None:
            return _http_error(400, "unknown provider")
        if requested_provider != defaults.provider:
            defaults.provider = requested_provider
            dirty = True

    if dirty:
        save_config(config)
    return _http_json_response(self._settings_payload(requires_restart=dirty))
|
||
|
||
@staticmethod
|
||
def _is_webui_session_key(key: str) -> bool:
|
||
"""Return True when *key* belongs to the webui's websocket-only surface."""
|
||
return key.startswith("websocket:")
|
||
|
||
def _handle_session_messages(self, request: WsRequest, key: str) -> Response:
    """Return one websocket session's persisted messages with signed media URLs.

    Status codes:

    - 401: no valid API token.
    - 503: no session manager.
    - 400: the percent-decoded key is invalid.
    - 404: the key is not a ``websocket:`` session or has no file. Keeping this
      aligned with ``/api/sessions`` stops callers from probing other channels'
      history via handcrafted URLs.

    Raw on-disk ``media`` paths are replaced by signed ``media_urls`` before the
    payload is sent, so server filesystem layout is not exposed.
    """
    if not self._check_api_token(request):
        return _http_error(401, "Unauthorized")
    manager = self._session_manager
    if manager is None:
        return _http_error(503, "session manager unavailable")
    session_key = _decode_api_key(key)
    if session_key is None:
        return _http_error(400, "invalid session key")
    data = (
        manager.read_session_file(session_key)
        if self._is_webui_session_key(session_key)
        else None
    )
    if data is None:
        return _http_error(404, "session not found")
    self._augment_media_urls(data)
    return _http_json_response(data)
|
||
|
||
def _augment_media_urls(self, payload: dict[str, Any]) -> None:
|
||
"""Mutate *payload* in place: each message's ``media`` path list is
|
||
replaced by a parallel ``media_urls`` list of signed fetch URLs.
|
||
|
||
Messages without media or with non-string path entries are left
|
||
untouched. Paths that no longer live inside ``media_dir`` (e.g. the
|
||
file was deleted, or the dir was relocated) are silently skipped;
|
||
the client falls back to the historical-replay placeholder tile.
|
||
"""
|
||
messages = payload.get("messages")
|
||
if not isinstance(messages, list):
|
||
return
|
||
for msg in messages:
|
||
if not isinstance(msg, dict):
|
||
continue
|
||
media = msg.get("media")
|
||
if not isinstance(media, list) or not media:
|
||
continue
|
||
urls: list[dict[str, str]] = []
|
||
for entry in media:
|
||
if not isinstance(entry, str) or not entry:
|
||
continue
|
||
signed = self._sign_media_path(Path(entry))
|
||
if signed is None:
|
||
continue
|
||
urls.append({"url": signed, "name": Path(entry).name})
|
||
if urls:
|
||
msg["media_urls"] = urls
|
||
# Always drop the raw paths from the wire payload.
|
||
msg.pop("media", None)
|
||
|
||
def _sign_media_path(self, abs_path: Path) -> str | None:
    """Build an origin-relative ``/api/media/<sig>/<payload>`` URL for *abs_path*.

    *abs_path* is resolved and expressed relative to the resolved
    ``get_media_dir()`` root. If resolving fails with ``OSError``, or the
    path is outside that root (``ValueError``), ``None`` is returned.
    ``<payload>`` is ``_b64url_encode`` of the POSIX-style relative path in
    UTF-8. ``<sig>`` is ``_b64url_encode`` of the first 16 bytes of an
    HMAC-SHA256 over the payload, keyed with ``self._media_secret``. Only
    paths signed here can pass the check in the media fetch handler.
    """
    try:
        root = get_media_dir().resolve()
        relative = abs_path.resolve().relative_to(root)
    except (OSError, ValueError):
        return None
    encoded_path = _b64url_encode(relative.as_posix().encode("utf-8"))
    digest = hmac.new(self._media_secret, encoded_path.encode("ascii"), hashlib.sha256).digest()
    signature = _b64url_encode(digest[:16])
    return f"/api/media/{signature}/{encoded_path}"
|
||
|
||
def _sign_or_stage_media_path(self, path: Path) -> dict[str, str] | None:
|
||
"""Return a signed media URL payload for *path*.
|
||
|
||
Persisted inbound media already lives under ``get_media_dir`` and can
|
||
be signed directly. Outbound bot-generated files may live anywhere on
|
||
disk; copy those into the websocket media bucket first so the browser
|
||
can fetch them through the existing signed media route without
|
||
exposing arbitrary filesystem paths.
|
||
"""
|
||
signed = self._sign_media_path(path)
|
||
if signed is not None:
|
||
return {"url": signed, "name": path.name}
|
||
try:
|
||
if not path.is_file():
|
||
return None
|
||
media_dir = get_media_dir("websocket")
|
||
safe_name = safe_filename(path.name) or "attachment"
|
||
staged = media_dir / f"{uuid.uuid4().hex[:12]}-{safe_name}"
|
||
shutil.copyfile(path, staged)
|
||
except OSError as exc:
|
||
self.logger.warning("failed to stage outbound media {}: {}", path, exc)
|
||
return None
|
||
signed = self._sign_media_path(staged)
|
||
if signed is None:
|
||
return None
|
||
return {"url": signed, "name": path.name}
|
||
|
||
def _handle_media_fetch(self, sig: str, payload: str) -> Response:
    """Serve one media file previously signed by :meth:`_sign_media_path`.

    Steps:
    1. Decode the base64url signature in *sig*.
    2. Recompute the 16-byte-truncated HMAC-SHA256 of *payload* with
       ``self._media_secret`` and compare it in constant time.
    3. Decode *payload* to a UTF-8 relative path and confine it to the
       resolved media root.
    4. Return the file bytes.

    Responses:
    - 401: *sig* cannot be decoded, the MAC does not match, or *payload*
      contains non-ASCII characters. The signer ASCII-encodes every payload
      it signs, so such a payload can never be one we issued.
    - 400: *payload* cannot be decoded to UTF-8 text.
    - 404: the path escapes the media root or is not a regular file.
    - 500: the file cannot be read.

    The Content-Type is the guessed MIME type only when it is in
    ``_MEDIA_ALLOWED_MIMES``. Anything else is served as
    ``application/octet-stream`` with ``nosniff``. Caching is ``private``
    and ``immutable`` because the URL itself encodes the file identity.
    """
    try:
        provided_mac = _b64url_decode(sig)
    except (ValueError, binascii.Error):
        return _http_error(401, "invalid signature")
    # *payload* comes straight from the request path. ``str.encode("ascii")``
    # raises UnicodeEncodeError on any non-ASCII character, which used to
    # escape this handler unhandled. Treat it as a bad signature instead.
    try:
        payload_bytes = payload.encode("ascii")
    except UnicodeEncodeError:
        return _http_error(401, "invalid signature")
    expected_mac = hmac.new(
        self._media_secret, payload_bytes, hashlib.sha256
    ).digest()[:16]
    if not hmac.compare_digest(expected_mac, provided_mac):
        return _http_error(401, "invalid signature")
    try:
        rel_bytes = _b64url_decode(payload)
        rel_str = rel_bytes.decode("utf-8")
    except (ValueError, binascii.Error, UnicodeDecodeError):
        return _http_error(400, "invalid payload")
    # An attacker who somehow bypassed the HMAC check would still need
    # the resolved path to escape the media root; guard defensively.
    try:
        media_root = get_media_dir().resolve()
        candidate = (media_root / rel_str).resolve()
        candidate.relative_to(media_root)
    except (OSError, ValueError):
        return _http_error(404, "not found")
    if not candidate.is_file():
        return _http_error(404, "not found")
    try:
        body = candidate.read_bytes()
    except OSError:
        return _http_error(500, "read error")
    mime, _ = mimetypes.guess_type(candidate.name)
    if mime not in _MEDIA_ALLOWED_MIMES:
        mime = "application/octet-stream"
    return _http_response(
        body,
        content_type=mime,
        extra_headers=[
            ("Cache-Control", "private, max-age=31536000, immutable"),
            # Paired with the MIME whitelist above: prevents browsers from
            # MIME-sniffing an octet-stream fallback into executable HTML.
            ("X-Content-Type-Options", "nosniff"),
        ],
    )
|
||
|
||
def _handle_session_delete(self, request: WsRequest, key: str) -> Response:
    """Delete a webui session and report whether anything was removed.

    *key* is decoded with ``_decode_api_key``. Responses:
    - 401: the API token check fails.
    - 503: no session manager is configured.
    - 400: *key* cannot be decoded.
    - 404: ``_is_webui_session_key`` rejects the key.
    - Otherwise: ``{"deleted": bool(delete_session(...))}``.
    """
    if not self._check_api_token(request):
        return _http_error(401, "Unauthorized")
    manager = self._session_manager
    if manager is None:
        return _http_error(503, "session manager unavailable")
    session_key = _decode_api_key(key)
    if session_key is None:
        return _http_error(400, "invalid session key")
    # The read endpoint uses the same boundary: the webui may only touch
    # websocket-channel sessions. Deletion unlinks the session's local
    # JSONL, so keep this narrow.
    if not self._is_webui_session_key(session_key):
        return _http_error(404, "session not found")
    return _http_json_response({"deleted": bool(manager.delete_session(session_key))})
|
||
|
||
def _serve_static(self, request_path: str) -> Response | None:
|
||
"""Resolve *request_path* against the built SPA directory; SPA fallback to index.html."""
|
||
assert self._static_dist_path is not None
|
||
rel = request_path.lstrip("/")
|
||
if not rel:
|
||
rel = "index.html"
|
||
# Reject path-traversal attempts and absolute targets.
|
||
if ".." in rel.split("/") or rel.startswith("/"):
|
||
return _http_error(403, "Forbidden")
|
||
candidate = (self._static_dist_path / rel).resolve()
|
||
try:
|
||
candidate.relative_to(self._static_dist_path)
|
||
except ValueError:
|
||
return _http_error(403, "Forbidden")
|
||
if not candidate.is_file():
|
||
# SPA history-mode fallback: unknown routes serve index.html so the
|
||
# client-side router can render them.
|
||
index = self._static_dist_path / "index.html"
|
||
if index.is_file():
|
||
candidate = index
|
||
else:
|
||
return None
|
||
try:
|
||
body = candidate.read_bytes()
|
||
except OSError as e:
|
||
self.logger.warning("static: failed to read {}: {}", candidate, e)
|
||
return _http_error(500, "Internal Server Error")
|
||
ctype, _ = mimetypes.guess_type(candidate.name)
|
||
if ctype is None:
|
||
ctype = "application/octet-stream"
|
||
if ctype.startswith("text/") or ctype in {"application/javascript", "application/json"}:
|
||
ctype = f"{ctype}; charset=utf-8"
|
||
# Hash-named build assets are cache-friendly; index.html must stay fresh.
|
||
if candidate.name == "index.html":
|
||
cache = "no-cache"
|
||
else:
|
||
cache = "public, max-age=31536000, immutable"
|
||
return _http_response(
|
||
body,
|
||
status=200,
|
||
content_type=ctype,
|
||
extra_headers=[("Cache-Control", cache)],
|
||
)
|
||
|
||
def _authorize_websocket_handshake(self, connection: Any, query: dict[str, list[str]]) -> Any:
    """Decide whether a WebSocket upgrade request may proceed.

    Returns ``None`` to accept the handshake, or the result of
    ``connection.respond(401, "Unauthorized")`` to reject it. The candidate
    token is the first ``token`` query parameter.

    - ``config.token`` set (after stripping): accept when the supplied token
      equals it (constant-time compare) or when
      ``_take_issued_token_if_valid`` accepts it.
    - Otherwise, if ``config.websocket_requires_token`` is set: accept only
      when ``_take_issued_token_if_valid`` accepts the token.
    - Otherwise the handshake is open. A supplied token is still passed to
      ``_take_issued_token_if_valid`` (presumably consuming it — TODO
      confirm), but the outcome is ignored.
    """
    supplied = _query_first(query, "token")
    static_token = self.config.token.strip()

    if static_token:
        # Compare UTF-8 bytes, not str. ``hmac.compare_digest`` raises
        # TypeError for ``str`` arguments containing non-ASCII characters,
        # and the query value is client-controlled, so e.g. ``?token=%C3%A9``
        # would otherwise crash the handshake instead of yielding a 401.
        if supplied and hmac.compare_digest(
            supplied.encode("utf-8"), static_token.encode("utf-8")
        ):
            return None
        if supplied and self._take_issued_token_if_valid(supplied):
            return None
        return connection.respond(401, "Unauthorized")

    if self.config.websocket_requires_token:
        if supplied and self._take_issued_token_if_valid(supplied):
            return None
        return connection.respond(401, "Unauthorized")

    if supplied:
        self._take_issued_token_if_valid(supplied)
    return None
|
||
|
||
async def start(self) -> None:
    """Run the WebSocket/HTTP server until :meth:`stop` sets the stop event.

    Marks the channel running, creates the stop event, and opens
    ``websockets`` ``serve`` on ``config.host``/``config.port``. TLS is used
    when ``_build_ssl_context`` returns a context. Plain HTTP requests are
    routed through ``_dispatch_http`` via the ``process_request`` hook, and
    upgraded sockets go to ``_connection_loop``. The serving coroutine is
    stored in ``self._server_task`` and awaited here, so this returns when
    that task finishes.
    """
    self._running = True
    self._stop_event = asyncio.Event()

    tls_context = self._build_ssl_context()
    scheme = "wss" if tls_context else "ws"

    async def handler(connection: ServerConnection) -> None:
        await self._connection_loop(connection)

    async def process_request(
        connection: ServerConnection,
        request: WsRequest,
    ) -> Any:
        return await self._dispatch_http(connection, request)

    self.logger.info(
        "WebSocket server listening on {}://{}:{}{}",
        scheme, self.config.host, self.config.port, self.config.path,
    )
    issue_path = self.config.token_issue_path
    if issue_path:
        self.logger.info(
            "WebSocket token issue route: {}://{}:{}{}",
            scheme, self.config.host, self.config.port, _normalize_config_path(issue_path),
        )

    async def serve_until_stopped() -> None:
        async with serve(
            handler,
            self.config.host,
            self.config.port,
            process_request=process_request,
            max_size=self.config.max_message_bytes,
            ping_interval=self.config.ping_interval_s,
            ping_timeout=self.config.ping_timeout_s,
            ssl=tls_context,
        ):
            assert self._stop_event is not None
            await self._stop_event.wait()

    self._server_task = asyncio.create_task(serve_until_stopped())
    await self._server_task
|
||
|
||
async def _connection_loop(self, connection: Any) -> None:
    """Serve one accepted WebSocket connection until it closes.

    The ``client_id`` comes from the ``client_id`` query parameter. It is
    stripped and truncated to 128 characters; a random ``anon-…`` id is
    used when it is missing or blank.

    Setup: send a ``ready`` event announcing a fresh default ``chat_id``,
    then register the connection on that chat.

    Frames: binary frames must decode as UTF-8 or are skipped. A frame is
    either a typed envelope, handled by ``_dispatch_envelope``, or a plain
    payload, forwarded to ``_handle_message`` on the default chat.

    Errors: ``ConnectionClosed`` is the normal way this loop ends and is
    logged at debug level. Any other exception also ends the connection but
    is logged with a traceback so real bugs are not hidden. The connection
    is always cleaned up.
    """
    request = connection.request
    path_part = request.path if request else "/"
    _, query = _parse_request_path(path_part)
    client_id_raw = _query_first(query, "client_id")
    client_id = client_id_raw.strip() if client_id_raw else ""
    if not client_id:
        client_id = f"anon-{uuid.uuid4().hex[:12]}"
    elif len(client_id) > 128:
        self.logger.warning("client_id too long ({} chars), truncating", len(client_id))
        client_id = client_id[:128]

    default_chat_id = str(uuid.uuid4())

    try:
        await connection.send(
            json.dumps(
                {
                    "event": "ready",
                    "chat_id": default_chat_id,
                    "client_id": client_id,
                },
                ensure_ascii=False,
            )
        )
        # Register only after ready is successfully sent to avoid out-of-order sends
        self._conn_default[connection] = default_chat_id
        self._attach(connection, default_chat_id)

        async for raw in connection:
            if isinstance(raw, bytes):
                try:
                    raw = raw.decode("utf-8")
                except UnicodeDecodeError:
                    self.logger.warning("ignoring non-utf8 binary frame")
                    continue

            envelope = _parse_envelope(raw)
            if envelope is not None:
                await self._dispatch_envelope(connection, client_id, envelope)
                continue

            # Non-envelope frame: deliver to this connection's default chat.
            content = _parse_inbound_payload(raw)
            if content is None:
                continue
            await self._handle_message(
                sender_id=client_id,
                chat_id=default_chat_id,
                content=content,
                metadata={"remote": getattr(connection, "remote_address", None)},
            )
    except ConnectionClosed as e:
        # Routine disconnect.
        self.logger.debug("connection ended: {}", e)
    except Exception:
        # Unexpected failure (e.g. in dispatch or _handle_message). This was
        # previously swallowed at debug level together with normal closes.
        self.logger.exception("connection handler failed")
    finally:
        self._cleanup_connection(connection)
|
||
|
||
@staticmethod
def _save_envelope_media(
    media: list[Any],
) -> tuple[list[str], str | None]:
    """Decode and persist the ``media`` items of a ``message`` envelope.

    Expected shape: ``list[{"data_url": str, "name"?: str | None}]``.

    Returns ``(paths, None)`` when every item was written under
    ``get_media_dir("websocket")``. Otherwise returns ``([], reason)`` at
    the first problem; ``reason`` is a short, stable token suitable for UI
    localization:
    - ``too_many_images`` / ``too_many_videos``: count limits, checked
      before anything is written.
    - ``malformed``: an item is not a dict or lacks a non-empty string
      ``data_url``.
    - ``decode``: no MIME type could be extracted, or decoding/saving
      failed.
    - ``mime``: the type is not in ``_UPLOAD_MIME_ALLOWED``.
    - ``size``: the per-type byte cap was exceeded.

    On failure, files already written by this call are unlinked so partial
    ingress leaves no orphan files. The caller is expected to surface
    ``reason`` to the client and skip publishing, so no half-formed message
    reaches the agent.
    """
    # Pass 1: enforce per-message image/video count limits before any I/O.
    image_count = 0
    video_count = 0
    for item in media:
        data_url = item.get("data_url") if isinstance(item, dict) else None
        # Only sniff real strings. Pass 2 rejects any other ``data_url`` as
        # "malformed". Passing e.g. an int to ``_extract_data_url_mime``
        # here could raise before that clean rejection is reached.
        # NOTE(review): depends on that helper's input checks.
        mime = _extract_data_url_mime(data_url) if isinstance(data_url, str) else None
        if mime in _VIDEO_MIME_ALLOWED:
            video_count += 1
        elif mime in _IMAGE_MIME_ALLOWED:
            image_count += 1
    if image_count > _MAX_IMAGES_PER_MESSAGE:
        return [], "too_many_images"
    if video_count > _MAX_VIDEOS_PER_MESSAGE:
        return [], "too_many_videos"

    media_dir = get_media_dir("websocket")
    paths: list[str] = []

    def _abort(reason: str) -> tuple[list[str], str]:
        # Roll back files persisted earlier in this call.
        for p in paths:
            try:
                Path(p).unlink(missing_ok=True)
            except OSError as exc:
                logger.warning(
                    "failed to unlink partial media {}: {}", p, exc
                )
        return [], reason

    # Pass 2: validate, decode and persist each item.
    for item in media:
        if not isinstance(item, dict):
            return _abort("malformed")
        data_url = item.get("data_url")
        if not isinstance(data_url, str) or not data_url:
            return _abort("malformed")
        mime = _extract_data_url_mime(data_url)
        if mime is None:
            return _abort("decode")
        if mime not in _UPLOAD_MIME_ALLOWED:
            return _abort("mime")
        is_video = mime in _VIDEO_MIME_ALLOWED
        max_bytes = _MAX_VIDEO_BYTES if is_video else _MAX_IMAGE_BYTES
        try:
            saved = save_base64_data_url(
                data_url, media_dir, max_bytes=max_bytes,
            )
        except FileSizeExceeded:
            return _abort("size")
        except Exception as exc:
            logger.warning("media decode failed: {}", exc)
            return _abort("decode")
        if saved is None:
            return _abort("decode")
        paths.append(saved)
    return paths, None
|
||
|
||
async def _dispatch_envelope(
|
||
self,
|
||
connection: Any,
|
||
client_id: str,
|
||
envelope: dict[str, Any],
|
||
) -> None:
|
||
"""Route one typed inbound envelope (``new_chat`` / ``attach`` / ``message``)."""
|
||
t = envelope.get("type")
|
||
if t == "new_chat":
|
||
new_id = str(uuid.uuid4())
|
||
self._attach(connection, new_id)
|
||
await self._send_event(connection, "attached", chat_id=new_id)
|
||
return
|
||
if t == "attach":
|
||
cid = envelope.get("chat_id")
|
||
if not _is_valid_chat_id(cid):
|
||
await self._send_event(connection, "error", detail="invalid chat_id")
|
||
return
|
||
self._attach(connection, cid)
|
||
await self._send_event(connection, "attached", chat_id=cid)
|
||
return
|
||
if t == "message":
|
||
cid = envelope.get("chat_id")
|
||
content = envelope.get("content")
|
||
if not _is_valid_chat_id(cid):
|
||
await self._send_event(connection, "error", detail="invalid chat_id")
|
||
return
|
||
if not isinstance(content, str):
|
||
await self._send_event(connection, "error", detail="missing content")
|
||
return
|
||
|
||
raw_media = envelope.get("media")
|
||
media_paths: list[str] = []
|
||
if raw_media is not None:
|
||
if not isinstance(raw_media, list):
|
||
await self._send_event(
|
||
connection, "error",
|
||
detail="image_rejected", reason="malformed",
|
||
)
|
||
return
|
||
media_paths, reason = self._save_envelope_media(raw_media)
|
||
if reason is not None:
|
||
await self._send_event(
|
||
connection, "error",
|
||
detail="image_rejected", reason=reason,
|
||
)
|
||
return
|
||
|
||
# Allow image-only turns (content may be empty when media is attached).
|
||
if not content.strip() and not media_paths:
|
||
await self._send_event(connection, "error", detail="missing content")
|
||
return
|
||
|
||
# Auto-attach on first use so clients can one-shot without a separate attach.
|
||
self._attach(connection, cid)
|
||
metadata: dict[str, Any] = {"remote": getattr(connection, "remote_address", None)}
|
||
if envelope.get("webui") is True:
|
||
metadata["webui"] = True
|
||
await self._handle_message(
|
||
sender_id=client_id,
|
||
chat_id=cid,
|
||
content=content,
|
||
media=media_paths or None,
|
||
metadata=metadata,
|
||
)
|
||
return
|
||
await self._send_event(connection, "error", detail=f"unknown type: {t!r}")
|
||
|
||
async def stop(self) -> None:
    """Shut the server down and drop all connection and token state.

    No-op unless the channel is running. Sets the stop event so the serving
    task can leave its ``serve`` context. Then awaits that task, logging
    rather than raising any ``Exception`` it ended with, and finally clears
    ``_subs``, ``_conn_chats``, ``_conn_default``, ``_issued_tokens`` and
    ``_api_tokens``.
    """
    if not self._running:
        return
    self._running = False
    event = self._stop_event
    if event:
        event.set()
    task = self._server_task
    if task:
        try:
            await task
        except Exception as e:
            self.logger.warning("server task error during shutdown: {}", e)
        self._server_task = None
    for registry in (
        self._subs,
        self._conn_chats,
        self._conn_default,
        self._issued_tokens,
        self._api_tokens,
    ):
        registry.clear()
|
||
|
||
async def _safe_send_to(self, connection: Any, raw: str, *, label: str = "") -> None:
|
||
"""Send a raw frame to one connection, cleaning up on ConnectionClosed."""
|
||
try:
|
||
await connection.send(raw)
|
||
except ConnectionClosed:
|
||
self._cleanup_connection(connection)
|
||
self.logger.warning("connection gone{}", label)
|
||
except Exception:
|
||
self.logger.exception("send failed{}", label)
|
||
raise
|
||
|
||
async def send(self, msg: OutboundMessage) -> None:
    """Broadcast an outbound agent message to every subscriber of its chat.

    With no subscribers, a warning is logged and nothing is sent. Control
    metadata is translated: ``_turn_end`` becomes a ``turn_end`` event and
    ``_session_updated`` a ``session_updated`` event.

    Any other message becomes a ``message`` event carrying:
    - the text, with button labels appended via ``_append_buttons_as_text``;
    - the structured ``buttons`` plus the unmodified prompt;
    - the raw ``media`` list and ``media_urls`` for entries that could be
      signed or staged;
    - ``reply_to``;
    - ``kind`` of ``tool_hint`` or ``progress`` for intermediate
      breadcrumbs.
    """
    # Copy the subscriber set: a ConnectionClosed during sending removes
    # that connection from ``_subs`` while we are still looping.
    recipients = list(self._subs.get(msg.chat_id, ()))
    if not recipients:
        self.logger.warning("no active subscribers for chat_id={}", msg.chat_id)
        return

    meta = msg.metadata
    if meta.get("_turn_end"):
        await self.send_turn_end(msg.chat_id)
        return
    if meta.get("_session_updated"):
        await self.send_session_updated(msg.chat_id)
        return

    payload: dict[str, Any] = {
        "event": "message",
        "chat_id": msg.chat_id,
        "text": _append_buttons_as_text(msg.content, msg.buttons) if msg.buttons else msg.content,
    }
    if msg.buttons:
        payload["buttons"] = msg.buttons
        payload["button_prompt"] = msg.content
    if msg.media:
        payload["media"] = msg.media
        descriptors = (self._sign_or_stage_media_path(Path(item)) for item in msg.media)
        signed = [d for d in descriptors if d is not None]
        if signed:
            payload["media_urls"] = signed
    if msg.reply_to:
        payload["reply_to"] = msg.reply_to
    # Breadcrumb rows (tool-call hints win over generic progress) let
    # clients render them as subordinate trace lines, not replies.
    if meta.get("_tool_hint"):
        payload["kind"] = "tool_hint"
    elif meta.get("_progress"):
        payload["kind"] = "progress"

    frame = json.dumps(payload, ensure_ascii=False)
    for conn in recipients:
        await self._safe_send_to(conn, frame, label=" ")
|
||
|
||
async def send_delta(
|
||
self,
|
||
chat_id: str,
|
||
delta: str,
|
||
metadata: dict[str, Any] | None = None,
|
||
) -> None:
|
||
conns = list(self._subs.get(chat_id, ()))
|
||
if not conns:
|
||
return
|
||
meta = metadata or {}
|
||
if meta.get("_stream_end"):
|
||
body: dict[str, Any] = {"event": "stream_end", "chat_id": chat_id}
|
||
else:
|
||
body = {
|
||
"event": "delta",
|
||
"chat_id": chat_id,
|
||
"text": delta,
|
||
}
|
||
if meta.get("_stream_id") is not None:
|
||
body["stream_id"] = meta["_stream_id"]
|
||
raw = json.dumps(body, ensure_ascii=False)
|
||
for connection in conns:
|
||
await self._safe_send_to(connection, raw, label=" stream ")
|
||
|
||
async def send_turn_end(self, chat_id: str) -> None:
    """Tell every subscriber of *chat_id* that the agent finished its turn.

    Sends ``{"event": "turn_end", "chat_id": ...}`` to each subscribed
    connection. Does nothing when there are no subscribers.
    """
    subscribers = list(self._subs.get(chat_id, ()))
    if subscribers:
        frame = json.dumps({"event": "turn_end", "chat_id": chat_id}, ensure_ascii=False)
        for subscriber in subscribers:
            await self._safe_send_to(subscriber, frame, label=" turn_end ")
|
||
|
||
async def send_session_updated(self, chat_id: str) -> None:
    """Notify subscribers of *chat_id* that session metadata changed outside a turn.

    Sends ``{"event": "session_updated", "chat_id": ...}`` to each
    subscribed connection. Does nothing when there are no subscribers.
    """
    subscribers = list(self._subs.get(chat_id, ()))
    if subscribers:
        frame = json.dumps({"event": "session_updated", "chat_id": chat_id}, ensure_ascii=False)
        for subscriber in subscribers:
            await self._safe_send_to(subscriber, frame, label=" session_updated ")
|