mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 00:22:31 +00:00
* feat(settings): expand settings api payload * feat(webui): build app-style settings center * feat(webui): add centered chat search dialog * fix(webui): shorten chat search label * fix(webui): center dialog entrance animation * fix(webui): simplify chat search results * fix(webui): tighten mobile settings navigation * feat(webui): persist sidebar state * feat(webui): add sidebar organization controls * refactor(webui): organize backend helpers * refactor(webui): remove utils compatibility shims * refactor(session): move shared webui helpers out of webui package * feat(webui): add image generation settings * style(webui): refine settings overview layout * fix(webui): localize settings zh-CN copy * style(webui): add settings status indicators * feat(webui): show sidebar run indicators * fix(webui): persist sidebar run indicators * fix(webui): highlight settings pending status * fix(webui): align settings test with provider update * fix(utils): preserve legacy webui helper imports
1730 lines
70 KiB
Python
1730 lines
70 KiB
Python
"""WebSocket server channel: nanobot acts as a WebSocket server and serves connected clients."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import base64
|
||
import binascii
|
||
import email.utils
|
||
import hashlib
|
||
import hmac
|
||
import http
|
||
import json
|
||
import mimetypes
|
||
import re
|
||
import secrets
|
||
import shutil
|
||
import ssl
|
||
import time
|
||
import uuid
|
||
from collections.abc import Callable
|
||
from pathlib import Path
|
||
from typing import TYPE_CHECKING, Any, Self
|
||
from urllib.parse import parse_qs, unquote, urlparse
|
||
|
||
from loguru import logger
|
||
from pydantic import Field, field_validator, model_validator
|
||
from websockets.asyncio.server import ServerConnection, serve
|
||
from websockets.datastructures import Headers
|
||
from websockets.exceptions import ConnectionClosed
|
||
from websockets.http11 import Request as WsRequest
|
||
from websockets.http11 import Response
|
||
|
||
from nanobot.bus.events import OUTBOUND_META_AGENT_UI, OutboundMessage
|
||
from nanobot.bus.queue import MessageBus
|
||
from nanobot.channels.base import BaseChannel
|
||
from nanobot.command.builtin import builtin_command_palette
|
||
from nanobot.config.paths import get_media_dir
|
||
from nanobot.config.schema import Base
|
||
from nanobot.session.goal_state import goal_state_ws_blob
|
||
from nanobot.session.webui_turns import websocket_turn_wall_started_at
|
||
from nanobot.utils.helpers import safe_filename
|
||
from nanobot.utils.media_decode import (
|
||
FileSizeExceeded,
|
||
save_base64_data_url,
|
||
)
|
||
from nanobot.utils.subagent_channel_display import scrub_subagent_messages_for_channel
|
||
from nanobot.webui.settings_api import (
|
||
WebUISettingsError,
|
||
settings_payload,
|
||
update_agent_settings,
|
||
update_image_generation_settings,
|
||
update_provider_settings,
|
||
update_web_search_settings,
|
||
)
|
||
from nanobot.webui.sidebar_state import (
|
||
read_webui_sidebar_state,
|
||
write_webui_sidebar_state,
|
||
)
|
||
from nanobot.webui.thread_disk import delete_webui_thread
|
||
from nanobot.webui.transcript import append_transcript_object, build_webui_thread_response
|
||
|
||
if TYPE_CHECKING:
|
||
from nanobot.session.manager import SessionManager
|
||
|
||
|
||
def _strip_trailing_slash(path: str) -> str:
|
||
if len(path) > 1 and path.endswith("/"):
|
||
return path.rstrip("/")
|
||
return path or "/"
|
||
|
||
|
||
def _normalize_config_path(path: str) -> str:
|
||
return _strip_trailing_slash(path)
|
||
|
||
|
||
class WebSocketConfig(Base):
|
||
"""WebSocket server channel configuration.
|
||
|
||
Clients connect with URLs like ``ws://{host}:{port}{path}?client_id=...&token=...``.
|
||
- ``client_id``: Used for ``allow_from`` authorization; if omitted, a value is generated and logged.
|
||
- ``token``: If non-empty, the ``token`` query param may match this static secret; short-lived tokens
|
||
from ``token_issue_path`` are also accepted.
|
||
- ``token_issue_path``: If non-empty, **GET** (HTTP/1.1) to this path returns JSON
|
||
``{"token": "...", "expires_in": <seconds>}``; use ``?token=...`` when opening the WebSocket.
|
||
Must differ from ``path`` (the WS upgrade path). If the client runs in the **same process** as
|
||
nanobot and shares the asyncio loop, use a thread or async HTTP client for GET—do not call
|
||
blocking ``urllib`` or synchronous ``httpx`` from inside a coroutine.
|
||
- ``token_issue_secret``: If non-empty, token requests must send ``Authorization: Bearer <secret>`` or
|
||
``X-Nanobot-Auth: <secret>``.
|
||
- ``websocket_requires_token``: If True, the handshake must include a valid token (static or issued and not expired).
|
||
- Each connection has its own session: a unique ``chat_id`` maps to the agent session internally.
|
||
- ``media`` field in outbound messages contains local filesystem paths; remote clients need a
|
||
shared filesystem or an HTTP file server to access these files.
|
||
"""
|
||
|
||
enabled: bool = False
|
||
host: str = "127.0.0.1"
|
||
port: int = 8765
|
||
path: str = "/"
|
||
token: str = ""
|
||
token_issue_path: str = ""
|
||
token_issue_secret: str = ""
|
||
token_ttl_s: int = Field(default=300, ge=30, le=86_400)
|
||
websocket_requires_token: bool = True
|
||
allow_from: list[str] = Field(default_factory=lambda: ["*"])
|
||
streaming: bool = True
|
||
# Default 36 MB, upper 40 MB: supports up to 4 images at ~6 MB each after
|
||
# client-side Worker normalization (see webui Composer). 4 × 6 MB × 1.37
|
||
# (base64 overhead) + envelope framing stays under 36 MB; the 40 MB ceiling
|
||
# leaves a small margin for sender slop without opening a DoS avenue.
|
||
max_message_bytes: int = Field(default=37_748_736, ge=1024, le=41_943_040)
|
||
ping_interval_s: float = Field(default=20.0, ge=5.0, le=300.0)
|
||
ping_timeout_s: float = Field(default=20.0, ge=5.0, le=300.0)
|
||
ssl_certfile: str = ""
|
||
ssl_keyfile: str = ""
|
||
|
||
@field_validator("path")
|
||
@classmethod
|
||
def path_must_start_with_slash(cls, value: str) -> str:
|
||
if not value.startswith("/"):
|
||
raise ValueError('path must start with "/"')
|
||
return _normalize_config_path(value)
|
||
|
||
@field_validator("token_issue_path")
|
||
@classmethod
|
||
def token_issue_path_format(cls, value: str) -> str:
|
||
value = value.strip()
|
||
if not value:
|
||
return ""
|
||
if not value.startswith("/"):
|
||
raise ValueError('token_issue_path must start with "/"')
|
||
return _normalize_config_path(value)
|
||
|
||
@model_validator(mode="after")
|
||
def token_issue_path_differs_from_ws_path(self) -> Self:
|
||
if not self.token_issue_path:
|
||
return self
|
||
if _normalize_config_path(self.token_issue_path) == _normalize_config_path(self.path):
|
||
raise ValueError("token_issue_path must differ from path (the WebSocket upgrade path)")
|
||
return self
|
||
|
||
@model_validator(mode="after")
|
||
def wildcard_host_requires_auth(self) -> Self:
|
||
if self.host not in ("0.0.0.0", "::"):
|
||
return self
|
||
if self.token.strip() or self.token_issue_secret.strip():
|
||
return self
|
||
raise ValueError(
|
||
"host is 0.0.0.0 (all interfaces) but neither token nor "
|
||
"token_issue_secret is set — set one to prevent unauthenticated access"
|
||
)
|
||
|
||
|
||
def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
|
||
body = json.dumps(data, ensure_ascii=False).encode("utf-8")
|
||
headers = Headers(
|
||
[
|
||
("Date", email.utils.formatdate(usegmt=True)),
|
||
("Connection", "close"),
|
||
("Content-Length", str(len(body))),
|
||
("Content-Type", "application/json; charset=utf-8"),
|
||
]
|
||
)
|
||
reason = http.HTTPStatus(status).phrase
|
||
return Response(status, reason, headers, body)
|
||
|
||
|
||
def publish_runtime_model_update(
|
||
bus: MessageBus,
|
||
model: str,
|
||
model_preset: str | None,
|
||
) -> None:
|
||
"""Enqueue a runtime model snapshot for websocket subscribers (fan-out in-channel)."""
|
||
bus.outbound.put_nowait(OutboundMessage(
|
||
channel="websocket",
|
||
chat_id="*",
|
||
content="",
|
||
metadata={
|
||
"_runtime_model_updated": True,
|
||
"model": model,
|
||
"model_preset": model_preset,
|
||
},
|
||
))
|
||
|
||
|
||
def _default_model_name_from_config() -> str | None:
|
||
"""Resolved model string from on-disk config (bootstrap fallback)."""
|
||
try:
|
||
from nanobot.config.loader import load_config
|
||
|
||
model = load_config().resolve_preset().model.strip()
|
||
return model or None
|
||
except Exception as e:
|
||
logger.debug("bootstrap model_name could not load from config: {}", e)
|
||
return None
|
||
|
||
|
||
def _resolve_bootstrap_model_name(
|
||
runtime_name: Callable[[], str | None] | None,
|
||
) -> str | None:
|
||
"""Prefer an in-process resolver (e.g. AgentLoop); else config-derived default."""
|
||
if runtime_name is not None:
|
||
try:
|
||
raw = runtime_name()
|
||
except Exception as e:
|
||
logger.debug("bootstrap runtime model resolver failed: {}", e)
|
||
else:
|
||
if isinstance(raw, str):
|
||
stripped = raw.strip()
|
||
if stripped:
|
||
return stripped
|
||
return _default_model_name_from_config()
|
||
|
||
|
||
def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]:
|
||
"""Parse normalized path and query parameters in one pass."""
|
||
parsed = urlparse("ws://x" + path_with_query)
|
||
path = _strip_trailing_slash(parsed.path or "/")
|
||
return path, parse_qs(parsed.query, keep_blank_values=True)
|
||
|
||
|
||
def _normalize_http_path(path_with_query: str) -> str:
|
||
"""Return the path component (no query string), with trailing slash normalized (root stays ``/``)."""
|
||
return _parse_request_path(path_with_query)[0]
|
||
|
||
|
||
def _parse_query(path_with_query: str) -> dict[str, list[str]]:
|
||
return _parse_request_path(path_with_query)[1]
|
||
|
||
|
||
def _query_first(query: dict[str, list[str]], key: str) -> str | None:
|
||
"""Return the first value for *key*, or None."""
|
||
values = query.get(key)
|
||
return values[0] if values else None
|
||
|
||
|
||
def _parse_inbound_payload(raw: str) -> str | None:
|
||
"""Parse a client frame into text; return None for empty or unrecognized content."""
|
||
text = raw.strip()
|
||
if not text:
|
||
return None
|
||
if text.startswith("{"):
|
||
try:
|
||
data = json.loads(text)
|
||
except json.JSONDecodeError:
|
||
return text
|
||
if isinstance(data, dict):
|
||
for key in ("content", "text", "message"):
|
||
value = data.get(key)
|
||
if isinstance(value, str) and value.strip():
|
||
return value
|
||
return None
|
||
return None
|
||
return text
|
||
|
||
|
||
# Accept UUIDs and short scoped keys like "unified:default". Keeps the capability
|
||
# namespace small enough to rule out path traversal / quote injection tricks.
|
||
_CHAT_ID_RE = re.compile(r"^[A-Za-z0-9_:-]{1,64}$")
|
||
|
||
|
||
def _is_valid_chat_id(value: Any) -> bool:
|
||
return isinstance(value, str) and _CHAT_ID_RE.match(value) is not None
|
||
|
||
|
||
def _parse_envelope(raw: str) -> dict[str, Any] | None:
|
||
"""Return a typed envelope dict if the frame is a new-style JSON envelope, else None.
|
||
|
||
A frame qualifies when it parses as a JSON object with a string ``type`` field.
|
||
Legacy frames (plain text, or ``{"content": ...}`` without ``type``) return None;
|
||
callers should fall back to :func:`_parse_inbound_payload` for those.
|
||
"""
|
||
text = raw.strip()
|
||
if not text.startswith("{"):
|
||
return None
|
||
try:
|
||
data = json.loads(text)
|
||
except json.JSONDecodeError:
|
||
return None
|
||
if not isinstance(data, dict):
|
||
return None
|
||
t = data.get("type")
|
||
if not isinstance(t, str):
|
||
return None
|
||
return data
|
||
|
||
|
||
# Per-message media limits. The server-side guard is a touch looser than the
|
||
# client's ``Worker`` normalization target (6 MB) — tolerate client slop, but
|
||
# still cap total ingress at ``_MAX_IMAGES_PER_MESSAGE * _MAX_IMAGE_BYTES``
|
||
# which fits comfortably inside ``max_message_bytes``.
|
||
_MAX_IMAGES_PER_MESSAGE = 4
|
||
_MAX_IMAGE_BYTES = 8 * 1024 * 1024
|
||
_MAX_VIDEOS_PER_MESSAGE = 1
|
||
_MAX_VIDEO_BYTES = 20 * 1024 * 1024
|
||
|
||
# Image MIME whitelist — matches the Composer's ``accept`` list. SVG is
|
||
# explicitly excluded to avoid the XSS surface inside embedded scripts.
|
||
_IMAGE_MIME_ALLOWED: frozenset[str] = frozenset({
|
||
"image/png",
|
||
"image/jpeg",
|
||
"image/webp",
|
||
"image/gif",
|
||
})
|
||
|
||
_VIDEO_MIME_ALLOWED: frozenset[str] = frozenset({
|
||
"video/mp4",
|
||
"video/webm",
|
||
"video/quicktime",
|
||
})
|
||
|
||
_UPLOAD_MIME_ALLOWED: frozenset[str] = _IMAGE_MIME_ALLOWED | _VIDEO_MIME_ALLOWED
|
||
|
||
_DATA_URL_MIME_RE = re.compile(r"^data:([^;]+);base64,", re.DOTALL)
|
||
|
||
|
||
def _extract_data_url_mime(url: str) -> str | None:
|
||
"""Return the MIME type of a ``data:<mime>;base64,...`` URL, else ``None``."""
|
||
if not isinstance(url, str):
|
||
return None
|
||
m = _DATA_URL_MIME_RE.match(url)
|
||
if not m:
|
||
return None
|
||
return m.group(1).strip().lower() or None
|
||
|
||
|
||
_LOCALHOSTS = frozenset({"127.0.0.1", "::1", "localhost"})
|
||
|
||
# Matches the legacy chat-id pattern but allows file-system-safe stems too,
|
||
# so the API can address sessions whose keys came from non-WebSocket channels.
|
||
_API_KEY_RE = re.compile(r"^[A-Za-z0-9_:.-]{1,128}$")
|
||
|
||
|
||
def _decode_api_key(raw_key: str) -> str | None:
|
||
"""Decode a percent-encoded API path segment, then validate the result."""
|
||
key = unquote(raw_key)
|
||
if _API_KEY_RE.match(key) is None:
|
||
return None
|
||
return key
|
||
|
||
|
||
def _is_localhost(connection: Any) -> bool:
|
||
"""Return True if *connection* originated from the loopback interface."""
|
||
addr = getattr(connection, "remote_address", None)
|
||
if not addr:
|
||
return False
|
||
host = addr[0] if isinstance(addr, tuple) else addr
|
||
if not isinstance(host, str):
|
||
return False
|
||
# ``::ffff:127.0.0.1`` is loopback in IPv6-mapped form.
|
||
if host.startswith("::ffff:"):
|
||
host = host[7:]
|
||
return host in _LOCALHOSTS
|
||
|
||
|
||
def _http_response(
|
||
body: bytes,
|
||
*,
|
||
status: int = 200,
|
||
content_type: str = "text/plain; charset=utf-8",
|
||
extra_headers: list[tuple[str, str]] | None = None,
|
||
) -> Response:
|
||
headers = [
|
||
("Date", email.utils.formatdate(usegmt=True)),
|
||
("Connection", "close"),
|
||
("Content-Length", str(len(body))),
|
||
("Content-Type", content_type),
|
||
]
|
||
if extra_headers:
|
||
headers.extend(extra_headers)
|
||
reason = http.HTTPStatus(status).phrase
|
||
return Response(status, reason, Headers(headers), body)
|
||
|
||
|
||
def _http_error(status: int, message: str | None = None) -> Response:
|
||
body = (message or http.HTTPStatus(status).phrase).encode("utf-8")
|
||
return _http_response(body, status=status)
|
||
|
||
|
||
def _bearer_token(headers: Any) -> str | None:
|
||
"""Pull a Bearer token out of standard or query-style headers."""
|
||
auth = headers.get("Authorization") or headers.get("authorization")
|
||
if auth and auth.lower().startswith("bearer "):
|
||
return auth[7:].strip() or None
|
||
return None
|
||
|
||
|
||
def _is_websocket_upgrade(request: WsRequest) -> bool:
|
||
"""Detect an actual WS upgrade; plain HTTP GETs to the same path should fall through."""
|
||
upgrade = request.headers.get("Upgrade") or request.headers.get("upgrade")
|
||
connection = request.headers.get("Connection") or request.headers.get("connection")
|
||
if not upgrade or "websocket" not in upgrade.lower():
|
||
return False
|
||
if not connection or "upgrade" not in connection.lower():
|
||
return False
|
||
return True
|
||
|
||
|
||
def _b64url_encode(data: bytes) -> str:
|
||
"""URL-safe base64 without padding — compact + friendly in URL paths."""
|
||
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
|
||
|
||
|
||
def _b64url_decode(s: str) -> bytes:
|
||
"""Reverse of :func:`_b64url_encode`; caller handles ``ValueError``."""
|
||
pad = "=" * (-len(s) % 4)
|
||
return base64.urlsafe_b64decode(s + pad)
|
||
|
||
|
||
# Allowed MIME types we actually serve from the media endpoint. Anything
|
||
# outside this set is degraded to ``application/octet-stream`` so an
|
||
# attacker who somehow gets a signed URL for an unexpected file type can't
|
||
# trick the browser into sniffing executable content.
|
||
_MEDIA_ALLOWED_MIMES: frozenset[str] = frozenset({
|
||
"image/png",
|
||
"image/jpeg",
|
||
"image/webp",
|
||
"image/gif",
|
||
"video/mp4",
|
||
"video/webm",
|
||
"video/quicktime",
|
||
})
|
||
|
||
|
||
def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool:
|
||
"""Return True if the token-issue HTTP request carries credentials matching ``token_issue_secret``."""
|
||
if not configured_secret:
|
||
return True
|
||
authorization = headers.get("Authorization") or headers.get("authorization")
|
||
if authorization and authorization.lower().startswith("bearer "):
|
||
supplied = authorization[7:].strip()
|
||
return hmac.compare_digest(supplied, configured_secret)
|
||
header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth")
|
||
if not header_token:
|
||
return False
|
||
return hmac.compare_digest(header_token.strip(), configured_secret)
|
||
|
||
|
||
class WebSocketChannel(BaseChannel):
|
||
"""Run a local WebSocket server; forward text/JSON messages to the message bus."""
|
||
|
||
name = "websocket"
|
||
display_name = "WebSocket"
|
||
|
||
def __init__(
|
||
self,
|
||
config: Any,
|
||
bus: MessageBus,
|
||
*,
|
||
session_manager: "SessionManager | None" = None,
|
||
static_dist_path: Path | None = None,
|
||
runtime_model_name: Callable[[], str | None] | None = None,
|
||
):
|
||
if isinstance(config, dict):
|
||
config = WebSocketConfig.model_validate(config)
|
||
super().__init__(config, bus)
|
||
self.config: WebSocketConfig = config
|
||
# chat_id -> connections subscribed to it (fan-out target).
|
||
self._subs: dict[str, set[Any]] = {}
|
||
# connection -> chat_ids it is subscribed to (O(1) cleanup on disconnect).
|
||
self._conn_chats: dict[Any, set[str]] = {}
|
||
# connection -> default chat_id for legacy frames that omit routing.
|
||
self._conn_default: dict[Any, str] = {}
|
||
# Single-use tokens consumed at WebSocket handshake.
|
||
self._issued_tokens: dict[str, float] = {}
|
||
# Multi-use tokens for HTTP routes served beside WS; checked but not consumed.
|
||
self._api_tokens: dict[str, float] = {}
|
||
self._stop_event: asyncio.Event | None = None
|
||
self._server_task: asyncio.Task[None] | None = None
|
||
self._session_manager = session_manager
|
||
self._static_dist_path: Path | None = (
|
||
static_dist_path.resolve() if static_dist_path is not None else None
|
||
)
|
||
self._runtime_model_name = runtime_model_name
|
||
self._settings_restart_sections: set[str] = set()
|
||
# Process-local secret used to HMAC-sign media URLs. The signed URL is
|
||
# the capability — anyone who holds a valid URL can fetch that one
|
||
# file, nothing else. The secret regenerates on restart so links
|
||
# become self-expiring (callers just refresh the session list).
|
||
self._media_secret: bytes = secrets.token_bytes(32)
|
||
|
||
# -- Subscription bookkeeping -------------------------------------------
|
||
|
||
def _attach(self, connection: Any, chat_id: str) -> None:
|
||
"""Idempotently subscribe *connection* to *chat_id*."""
|
||
self._subs.setdefault(chat_id, set()).add(connection)
|
||
self._conn_chats.setdefault(connection, set()).add(chat_id)
|
||
|
||
def _cleanup_connection(self, connection: Any) -> None:
|
||
"""Remove *connection* from every subscription set; safe to call multiple times."""
|
||
chat_ids = self._conn_chats.pop(connection, set())
|
||
for cid in chat_ids:
|
||
subs = self._subs.get(cid)
|
||
if subs is None:
|
||
continue
|
||
subs.discard(connection)
|
||
if not subs:
|
||
self._subs.pop(cid, None)
|
||
self._conn_default.pop(connection, None)
|
||
|
||
async def _maybe_push_active_goal_state(self, chat_id: str) -> None:
|
||
"""Replay an active sustained goal from session metadata after *chat_id* is subscribed.
|
||
|
||
Goal metadata lives on the session JSONL and survives gateway restarts, but
|
||
connected clients normally see it via ``goal_state`` / ``turn_end`` frames.
|
||
Pushing here makes refresh + reconnect restore the strip without a new model turn.
|
||
"""
|
||
if self._session_manager is None:
|
||
return
|
||
row = self._session_manager.read_session_file(f"websocket:{chat_id}")
|
||
meta = row.get("metadata", {}) if isinstance(row, dict) else {}
|
||
if not isinstance(meta, dict):
|
||
meta = {}
|
||
blob = goal_state_ws_blob(meta)
|
||
if not blob.get("active"):
|
||
return
|
||
await self.send_goal_state(chat_id, blob)
|
||
|
||
async def _maybe_push_turn_run_wall_clock(self, chat_id: str) -> None:
|
||
"""Replay ``goal_status: running`` when a turn is still active (same-process refresh)."""
|
||
t0 = websocket_turn_wall_started_at(chat_id)
|
||
if t0 is None:
|
||
return
|
||
await self.send_goal_status(chat_id, "running", started_at=t0)
|
||
|
||
async def _hydrate_after_subscribe(self, chat_id: str) -> None:
|
||
"""Replay goal/run strip state after subscribe (same-process refresh)."""
|
||
await self._maybe_push_active_goal_state(chat_id)
|
||
await self._maybe_push_turn_run_wall_clock(chat_id)
|
||
|
||
async def _send_event(self, connection: Any, event: str, **fields: Any) -> None:
|
||
"""Send a control event (attached, error, ...) to a single connection."""
|
||
payload: dict[str, Any] = {"event": event}
|
||
payload.update(fields)
|
||
raw = json.dumps(payload, ensure_ascii=False)
|
||
try:
|
||
await connection.send(raw)
|
||
except ConnectionClosed:
|
||
self._cleanup_connection(connection)
|
||
except Exception as e:
|
||
self.logger.warning("failed to send {} event: {}", event, e)
|
||
|
||
@classmethod
|
||
def default_config(cls) -> dict[str, Any]:
|
||
return WebSocketConfig().model_dump(by_alias=True)
|
||
|
||
def _expected_path(self) -> str:
|
||
return _normalize_config_path(self.config.path)
|
||
|
||
def _build_ssl_context(self) -> ssl.SSLContext | None:
|
||
cert = self.config.ssl_certfile.strip()
|
||
key = self.config.ssl_keyfile.strip()
|
||
if not cert and not key:
|
||
return None
|
||
if not cert or not key:
|
||
raise ValueError(
|
||
"ssl_certfile and ssl_keyfile must both be set for WSS, or both left empty"
|
||
)
|
||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
|
||
ctx.load_cert_chain(certfile=cert, keyfile=key)
|
||
return ctx
|
||
|
||
_MAX_ISSUED_TOKENS = 10_000
|
||
|
||
def _purge_expired_issued_tokens(self) -> None:
|
||
now = time.monotonic()
|
||
for token_key, expiry in list(self._issued_tokens.items()):
|
||
if now > expiry:
|
||
self._issued_tokens.pop(token_key, None)
|
||
|
||
def _take_issued_token_if_valid(self, token_value: str | None) -> bool:
|
||
"""Validate and consume one issued token (single use per connection attempt).
|
||
|
||
Uses single-step pop to minimize the window between lookup and removal;
|
||
safe under asyncio's single-threaded cooperative model.
|
||
"""
|
||
if not token_value:
|
||
return False
|
||
self._purge_expired_issued_tokens()
|
||
expiry = self._issued_tokens.pop(token_value, None)
|
||
if expiry is None:
|
||
return False
|
||
if time.monotonic() > expiry:
|
||
return False
|
||
return True
|
||
|
||
def _handle_token_issue_http(self, connection: Any, request: Any) -> Any:
|
||
secret = self.config.token_issue_secret.strip()
|
||
if secret:
|
||
if not _issue_route_secret_matches(request.headers, secret):
|
||
return connection.respond(401, "Unauthorized")
|
||
else:
|
||
self.logger.warning(
|
||
"token_issue_path is set but token_issue_secret is empty; "
|
||
"any client can obtain connection tokens — set token_issue_secret for production."
|
||
)
|
||
self._purge_expired_issued_tokens()
|
||
if len(self._issued_tokens) >= self._MAX_ISSUED_TOKENS:
|
||
self.logger.error(
|
||
"too many outstanding issued tokens ({}), rejecting issuance",
|
||
len(self._issued_tokens),
|
||
)
|
||
return _http_json_response({"error": "too many outstanding tokens"}, status=429)
|
||
token_value = f"nbwt_{secrets.token_urlsafe(32)}"
|
||
self._issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s)
|
||
|
||
return _http_json_response(
|
||
{"token": token_value, "expires_in": self.config.token_ttl_s}
|
||
)
|
||
|
||
# -- HTTP dispatch ------------------------------------------------------
|
||
|
||
async def _dispatch_http(self, connection: Any, request: WsRequest) -> Any:
|
||
"""Route an inbound HTTP request to a handler or to the WS upgrade path."""
|
||
got, query = _parse_request_path(request.path)
|
||
|
||
# 1. Token issue endpoint (legacy, optional, gated by configured secret).
|
||
if self.config.token_issue_path:
|
||
issue_expected = _normalize_config_path(self.config.token_issue_path)
|
||
if got == issue_expected:
|
||
return self._handle_token_issue_http(connection, request)
|
||
|
||
# 2. Bootstrap (`/webui/bootstrap`): mint WS/API tokens + shared session metadata.
|
||
if got == "/webui/bootstrap":
|
||
return self._handle_bootstrap(connection, request)
|
||
|
||
# 3. REST handlers co-located with this channel (sessions, settings, …).
|
||
if got == "/api/sessions":
|
||
return self._handle_sessions_list(request)
|
||
|
||
if got == "/api/settings":
|
||
return self._handle_settings(request)
|
||
|
||
if got == "/api/commands":
|
||
return self._handle_commands(request)
|
||
|
||
if got == "/api/webui/sidebar-state":
|
||
return self._handle_webui_sidebar_state(request)
|
||
|
||
if got == "/api/webui/sidebar-state/update":
|
||
return self._handle_webui_sidebar_state_update(request)
|
||
|
||
if got == "/api/settings/update":
|
||
return self._handle_settings_update(request)
|
||
|
||
if got == "/api/settings/provider/update":
|
||
return self._handle_settings_provider_update(request)
|
||
|
||
if got == "/api/settings/web-search/update":
|
||
return self._handle_settings_web_search_update(request)
|
||
|
||
if got == "/api/settings/image-generation/update":
|
||
return self._handle_settings_image_generation_update(request)
|
||
|
||
m = re.match(r"^/api/sessions/([^/]+)/messages$", got)
|
||
if m:
|
||
return self._handle_session_messages(request, m.group(1))
|
||
|
||
m = re.match(r"^/api/sessions/([^/]+)/webui-thread$", got)
|
||
if m:
|
||
return self._handle_webui_thread_get(request, m.group(1))
|
||
|
||
# NOTE: websockets' HTTP parser only accepts GET, so we cannot expose a
|
||
# true ``DELETE`` verb. The action is folded into the path instead.
|
||
m = re.match(r"^/api/sessions/([^/]+)/delete$", got)
|
||
if m:
|
||
return self._handle_session_delete(request, m.group(1))
|
||
|
||
# Signed media fetch: ``<sig>`` is an HMAC over ``<payload>``; the
|
||
# payload decodes to a path inside :func:`get_media_dir`. See
|
||
# :meth:`_sign_media_path` for the inverse direction used to build
|
||
# these URLs when replaying a session.
|
||
m = re.match(r"^/api/media/([A-Za-z0-9_-]+)/([A-Za-z0-9_-]+)$", got)
|
||
if m:
|
||
return self._handle_media_fetch(m.group(1), m.group(2))
|
||
|
||
# 4. WebSocket upgrade (the channel's primary purpose). Only run the
|
||
# handshake gate on requests that actually ask to upgrade; otherwise
|
||
# a bare ``GET /`` from the browser would be rejected as an
|
||
# unauthorized WS handshake instead of serving the SPA's index.html.
|
||
expected_ws = self._expected_path()
|
||
if got == expected_ws and _is_websocket_upgrade(request):
|
||
client_id = _query_first(query, "client_id") or ""
|
||
if len(client_id) > 128:
|
||
client_id = client_id[:128]
|
||
if not self.is_allowed(client_id):
|
||
return connection.respond(403, "Forbidden")
|
||
return self._authorize_websocket_handshake(connection, query)
|
||
|
||
# 5. Static SPA serving (only if a build directory was wired in).
|
||
if self._static_dist_path is not None:
|
||
response = self._serve_static(got)
|
||
if response is not None:
|
||
return response
|
||
|
||
return connection.respond(404, "Not Found")
|
||
|
||
# -- HTTP route handlers ------------------------------------------------
|
||
|
||
def _check_api_token(self, request: WsRequest) -> bool:
|
||
"""Validate a request against the API token pool (multi-use, TTL-bound)."""
|
||
self._purge_expired_api_tokens()
|
||
token = _bearer_token(request.headers) or _query_first(
|
||
_parse_query(request.path), "token"
|
||
)
|
||
if not token:
|
||
return False
|
||
expiry = self._api_tokens.get(token)
|
||
if expiry is None or time.monotonic() > expiry:
|
||
self._api_tokens.pop(token, None)
|
||
return False
|
||
return True
|
||
|
||
def _purge_expired_api_tokens(self) -> None:
|
||
now = time.monotonic()
|
||
for token_key, expiry in list(self._api_tokens.items()):
|
||
if now > expiry:
|
||
self._api_tokens.pop(token_key, None)
|
||
|
||
def _handle_bootstrap(self, connection: Any, request: Any) -> Response:
|
||
# When a secret is configured (token_issue_secret or static token),
|
||
# validate it regardless of source IP. This secures deployments
|
||
# behind a reverse proxy where all connections appear as localhost.
|
||
secret = self.config.token_issue_secret.strip() or self.config.token.strip()
|
||
if secret:
|
||
if not _issue_route_secret_matches(request.headers, secret):
|
||
return _http_error(401, "Unauthorized")
|
||
elif not _is_localhost(connection):
|
||
# No secret configured: only allow localhost (local dev mode).
|
||
return _http_error(403, "bootstrap is localhost-only")
|
||
# Cap outstanding tokens to avoid runaway growth from a misbehaving client.
|
||
self._purge_expired_issued_tokens()
|
||
self._purge_expired_api_tokens()
|
||
if (
|
||
len(self._issued_tokens) >= self._MAX_ISSUED_TOKENS
|
||
or len(self._api_tokens) >= self._MAX_ISSUED_TOKENS
|
||
):
|
||
return _http_response(
|
||
json.dumps({"error": "too many outstanding tokens"}).encode("utf-8"),
|
||
status=429,
|
||
content_type="application/json; charset=utf-8",
|
||
)
|
||
token = f"nbwt_{secrets.token_urlsafe(32)}"
|
||
expiry = time.monotonic() + float(self.config.token_ttl_s)
|
||
# Same string registered in both pools: the WS handshake consumes one copy
|
||
# while the REST surface keeps validating the other until TTL expiry.
|
||
self._issued_tokens[token] = expiry
|
||
self._api_tokens[token] = expiry
|
||
return _http_json_response(
|
||
{
|
||
"token": token,
|
||
"ws_path": self._expected_path(),
|
||
"expires_in": self.config.token_ttl_s,
|
||
"model_name": _resolve_bootstrap_model_name(self._runtime_model_name),
|
||
}
|
||
)
|
||
|
||
def _handle_sessions_list(self, request: WsRequest) -> Response:
|
||
if not self._check_api_token(request):
|
||
return _http_error(401, "Unauthorized")
|
||
if self._session_manager is None:
|
||
return _http_error(503, "session manager unavailable")
|
||
sessions = self._session_manager.list_sessions()
|
||
# Sidebar/chat listing for WS-backed sessions only — CLI / Slack / etc.
|
||
# keys are not intended for resume over this HTTP surface.
|
||
cleaned = []
|
||
for s in sessions:
|
||
key = s.get("key")
|
||
if not (isinstance(key, str) and key.startswith("websocket:")):
|
||
continue
|
||
row = {k: v for k, v in s.items() if k != "path"}
|
||
chat_id = key.split(":", 1)[1]
|
||
started_at = websocket_turn_wall_started_at(chat_id)
|
||
if started_at is not None:
|
||
row["run_started_at"] = started_at
|
||
cleaned.append(row)
|
||
return _http_json_response({"sessions": cleaned})
|
||
|
||
def _handle_settings(self, request: WsRequest) -> Response:
|
||
if not self._check_api_token(request):
|
||
return _http_error(401, "Unauthorized")
|
||
return _http_json_response(self._with_settings_restart_state(settings_payload()))
|
||
|
||
def _with_settings_restart_state(
|
||
self,
|
||
payload: dict[str, Any],
|
||
*,
|
||
section: str | None = None,
|
||
) -> dict[str, Any]:
|
||
"""Keep restart-required state alive for this gateway process."""
|
||
if section and payload.get("requires_restart"):
|
||
self._settings_restart_sections.add(section)
|
||
if self._settings_restart_sections:
|
||
payload = dict(payload)
|
||
payload["requires_restart"] = True
|
||
payload["restart_required_sections"] = sorted(self._settings_restart_sections)
|
||
else:
|
||
payload = dict(payload)
|
||
payload["restart_required_sections"] = []
|
||
return payload
|
||
|
||
def _handle_commands(self, request: WsRequest) -> Response:
|
||
if not self._check_api_token(request):
|
||
return _http_error(401, "Unauthorized")
|
||
return _http_json_response({"commands": builtin_command_palette()})
|
||
|
||
def _handle_webui_sidebar_state(self, request: WsRequest) -> Response:
|
||
if not self._check_api_token(request):
|
||
return _http_error(401, "Unauthorized")
|
||
return _http_json_response(read_webui_sidebar_state())
|
||
|
||
def _handle_webui_sidebar_state_update(self, request: WsRequest) -> Response:
|
||
if not self._check_api_token(request):
|
||
return _http_error(401, "Unauthorized")
|
||
query = _parse_query(request.path)
|
||
raw_state = _query_first(query, "state")
|
||
if raw_state is None:
|
||
return _http_error(400, "missing state")
|
||
try:
|
||
decoded = json.loads(raw_state)
|
||
except json.JSONDecodeError:
|
||
return _http_error(400, "state must be JSON")
|
||
if not isinstance(decoded, dict):
|
||
return _http_error(400, "state must be an object")
|
||
try:
|
||
state = write_webui_sidebar_state(decoded)
|
||
except ValueError as e:
|
||
return _http_error(400, str(e))
|
||
except OSError:
|
||
self.logger.exception("failed to write webui sidebar state")
|
||
return _http_error(500, "failed to write sidebar state")
|
||
return _http_json_response(state)
|
||
|
||
def _handle_settings_update(self, request: WsRequest) -> Response:
|
||
if not self._check_api_token(request):
|
||
return _http_error(401, "Unauthorized")
|
||
query = _parse_query(request.path)
|
||
try:
|
||
payload = update_agent_settings(query)
|
||
except WebUISettingsError as e:
|
||
return _http_error(e.status, e.message)
|
||
return _http_json_response(
|
||
self._with_settings_restart_state(payload, section="runtime")
|
||
)
|
||
|
||
def _handle_settings_provider_update(self, request: WsRequest) -> Response:
|
||
if not self._check_api_token(request):
|
||
return _http_error(401, "Unauthorized")
|
||
query = _parse_query(request.path)
|
||
try:
|
||
payload = update_provider_settings(query)
|
||
except WebUISettingsError as e:
|
||
return _http_error(e.status, e.message)
|
||
return _http_json_response(self._with_settings_restart_state(payload, section="image"))
|
||
|
||
def _handle_settings_web_search_update(self, request: WsRequest) -> Response:
|
||
if not self._check_api_token(request):
|
||
return _http_error(401, "Unauthorized")
|
||
query = _parse_query(request.path)
|
||
try:
|
||
payload = update_web_search_settings(query)
|
||
except WebUISettingsError as e:
|
||
return _http_error(e.status, e.message)
|
||
return _http_json_response(self._with_settings_restart_state(payload, section="web"))
|
||
|
||
def _handle_settings_image_generation_update(self, request: WsRequest) -> Response:
|
||
if not self._check_api_token(request):
|
||
return _http_error(401, "Unauthorized")
|
||
query = _parse_query(request.path)
|
||
try:
|
||
payload = update_image_generation_settings(query)
|
||
except WebUISettingsError as e:
|
||
return _http_error(e.status, e.message)
|
||
return _http_json_response(self._with_settings_restart_state(payload, section="image"))
|
||
|
||
@staticmethod
|
||
def _is_websocket_channel_session_key(key: str) -> bool:
|
||
"""True when *key* is a ``websocket:…`` session exposed on this HTTP surface."""
|
||
return key.startswith("websocket:")
|
||
|
||
def _handle_session_messages(self, request: WsRequest, key: str) -> Response:
|
||
if not self._check_api_token(request):
|
||
return _http_error(401, "Unauthorized")
|
||
if self._session_manager is None:
|
||
return _http_error(503, "session manager unavailable")
|
||
decoded_key = _decode_api_key(key)
|
||
if decoded_key is None:
|
||
return _http_error(400, "invalid session key")
|
||
# Only ``websocket:…`` sessions are listed/served here — same boundary as
|
||
# ``/api/sessions``. Block handcrafted URLs from probing CLI / Slack / etc.
|
||
if not self._is_websocket_channel_session_key(decoded_key):
|
||
return _http_error(404, "session not found")
|
||
data = self._session_manager.read_session_file(decoded_key)
|
||
if data is None:
|
||
return _http_error(404, "session not found")
|
||
messages = data.get("messages")
|
||
if isinstance(messages, list):
|
||
scrub_subagent_messages_for_channel(messages)
|
||
# Decorate persisted user messages with signed media URLs so the
|
||
# client can render previews. The raw on-disk ``media`` paths are
|
||
# stripped on the way out — they leak server filesystem layout and
|
||
# the client never needs them once it has the signed fetch URL.
|
||
self._augment_media_urls(data)
|
||
return _http_json_response(data)
|
||
|
||
def _handle_webui_thread_get(self, request: WsRequest, key: str) -> Response:
|
||
if not self._check_api_token(request):
|
||
return _http_error(401, "Unauthorized")
|
||
decoded_key = _decode_api_key(key)
|
||
if decoded_key is None:
|
||
return _http_error(400, "invalid session key")
|
||
if not self._is_websocket_channel_session_key(decoded_key):
|
||
return _http_error(404, "session not found")
|
||
data = build_webui_thread_response(
|
||
decoded_key,
|
||
augment_user_media=self._augment_transcript_user_media,
|
||
)
|
||
if data is None:
|
||
return _http_error(404, "webui thread not found")
|
||
return _http_json_response(data)
|
||
|
||
def _try_append_webui_transcript(self, chat_id: str, wire: dict[str, Any]) -> None:
|
||
sk = f"websocket:{chat_id}"
|
||
try:
|
||
dup = json.loads(json.dumps(wire, ensure_ascii=False))
|
||
append_transcript_object(sk, dup)
|
||
except (ValueError, TypeError) as e:
|
||
self.logger.warning("webui transcript append failed: {}", e)
|
||
|
||
def _augment_transcript_user_media(self, paths: list[str]) -> list[dict[str, Any]]:
|
||
out: list[dict[str, Any]] = []
|
||
for pstr in paths:
|
||
path = Path(pstr)
|
||
att = self._sign_or_stage_media_path(path)
|
||
if att is None:
|
||
continue
|
||
mime, _ = mimetypes.guess_type(path.name)
|
||
kind = "video" if mime and mime.startswith("video/") else "image"
|
||
out.append(
|
||
{"kind": kind, "url": att["url"], "name": att.get("name", path.name)},
|
||
)
|
||
return out
|
||
|
||
async def _handle_message(
|
||
self,
|
||
sender_id: str,
|
||
chat_id: str,
|
||
content: str,
|
||
media: list[str] | None = None,
|
||
metadata: dict[str, Any] | None = None,
|
||
session_key: str | None = None,
|
||
is_dm: bool = False,
|
||
) -> None:
|
||
meta = metadata or {}
|
||
if meta.get("webui"):
|
||
user_obj: dict[str, Any] = {
|
||
"event": "user",
|
||
"chat_id": chat_id,
|
||
"text": content,
|
||
}
|
||
if media:
|
||
user_obj["media_paths"] = list(media)
|
||
self._try_append_webui_transcript(chat_id, user_obj)
|
||
await super()._handle_message(
|
||
sender_id,
|
||
chat_id,
|
||
content,
|
||
media,
|
||
metadata,
|
||
session_key,
|
||
is_dm,
|
||
)
|
||
|
||
def _augment_media_urls(self, payload: dict[str, Any]) -> None:
|
||
"""Mutate *payload* in place: each message's ``media`` path list is
|
||
replaced by a parallel ``media_urls`` list of signed fetch URLs.
|
||
|
||
Messages without media or with non-string path entries are left
|
||
untouched. Paths that no longer live inside ``media_dir`` (e.g. the
|
||
file was deleted, or the dir was relocated) are silently skipped;
|
||
the client falls back to the historical-replay placeholder tile.
|
||
"""
|
||
messages = payload.get("messages")
|
||
if not isinstance(messages, list):
|
||
return
|
||
for msg in messages:
|
||
if not isinstance(msg, dict):
|
||
continue
|
||
media = msg.get("media")
|
||
if not isinstance(media, list) or not media:
|
||
continue
|
||
urls: list[dict[str, str]] = []
|
||
for entry in media:
|
||
if not isinstance(entry, str) or not entry:
|
||
continue
|
||
signed = self._sign_media_path(Path(entry))
|
||
if signed is None:
|
||
continue
|
||
urls.append({"url": signed, "name": Path(entry).name})
|
||
if urls:
|
||
msg["media_urls"] = urls
|
||
# Always drop the raw paths from the wire payload.
|
||
msg.pop("media", None)
|
||
|
||
def _sign_media_path(self, abs_path: Path) -> str | None:
|
||
"""Return a ``/api/media/<sig>/<payload>`` URL for *abs_path*, or
|
||
``None`` when the path does not resolve inside the media root.
|
||
|
||
The URL is self-authenticating: the signature binds the payload to
|
||
this process's ``_media_secret``, so only paths we chose to sign can
|
||
be fetched. The returned path is relative to the server origin; the
|
||
client joins it against this server's HTTP origin (same host as WS).
|
||
"""
|
||
try:
|
||
media_root = get_media_dir().resolve()
|
||
rel = abs_path.resolve().relative_to(media_root)
|
||
except (OSError, ValueError):
|
||
return None
|
||
payload = _b64url_encode(rel.as_posix().encode("utf-8"))
|
||
mac = hmac.new(
|
||
self._media_secret, payload.encode("ascii"), hashlib.sha256
|
||
).digest()[:16]
|
||
return f"/api/media/{_b64url_encode(mac)}/{payload}"
|
||
|
||
def _sign_or_stage_media_path(self, path: Path) -> dict[str, str] | None:
|
||
"""Return a signed media URL payload for *path*.
|
||
|
||
Persisted inbound media already lives under ``get_media_dir`` and can
|
||
be signed directly. Outbound bot-generated files may live anywhere on
|
||
disk; copy those into the websocket media bucket first so the browser
|
||
can fetch them through the existing signed media route without
|
||
exposing arbitrary filesystem paths.
|
||
"""
|
||
signed = self._sign_media_path(path)
|
||
if signed is not None:
|
||
return {"url": signed, "name": path.name}
|
||
try:
|
||
if not path.is_file():
|
||
return None
|
||
media_dir = get_media_dir("websocket")
|
||
safe_name = safe_filename(path.name) or "attachment"
|
||
staged = media_dir / f"{uuid.uuid4().hex[:12]}-{safe_name}"
|
||
shutil.copyfile(path, staged)
|
||
except OSError as exc:
|
||
self.logger.warning("failed to stage outbound media {}: {}", path, exc)
|
||
return None
|
||
signed = self._sign_media_path(staged)
|
||
if signed is None:
|
||
return None
|
||
return {"url": signed, "name": path.name}
|
||
|
||
def _handle_media_fetch(self, sig: str, payload: str) -> Response:
|
||
"""Serve a single media file previously signed via
|
||
:meth:`_sign_media_path`. Validates the signature, decodes the
|
||
payload to a relative path, and streams the file bytes with a
|
||
long-lived immutable cache header (the URL already encodes the
|
||
file identity, so caches can be aggressive)."""
|
||
try:
|
||
provided_mac = _b64url_decode(sig)
|
||
except (ValueError, binascii.Error):
|
||
return _http_error(401, "invalid signature")
|
||
expected_mac = hmac.new(
|
||
self._media_secret, payload.encode("ascii"), hashlib.sha256
|
||
).digest()[:16]
|
||
if not hmac.compare_digest(expected_mac, provided_mac):
|
||
return _http_error(401, "invalid signature")
|
||
try:
|
||
rel_bytes = _b64url_decode(payload)
|
||
rel_str = rel_bytes.decode("utf-8")
|
||
except (ValueError, binascii.Error, UnicodeDecodeError):
|
||
return _http_error(400, "invalid payload")
|
||
# An attacker who somehow bypassed the HMAC check would still need
|
||
# the resolved path to escape the media root; guard defensively.
|
||
try:
|
||
media_root = get_media_dir().resolve()
|
||
candidate = (media_root / rel_str).resolve()
|
||
candidate.relative_to(media_root)
|
||
except (OSError, ValueError):
|
||
return _http_error(404, "not found")
|
||
if not candidate.is_file():
|
||
return _http_error(404, "not found")
|
||
try:
|
||
body = candidate.read_bytes()
|
||
except OSError:
|
||
return _http_error(500, "read error")
|
||
mime, _ = mimetypes.guess_type(candidate.name)
|
||
if mime not in _MEDIA_ALLOWED_MIMES:
|
||
mime = "application/octet-stream"
|
||
return _http_response(
|
||
body,
|
||
content_type=mime,
|
||
extra_headers=[
|
||
("Cache-Control", "private, max-age=31536000, immutable"),
|
||
# Paired with the MIME whitelist above: prevents browsers from
|
||
# MIME-sniffing an octet-stream fallback into executable HTML.
|
||
("X-Content-Type-Options", "nosniff"),
|
||
],
|
||
)
|
||
|
||
def _handle_session_delete(self, request: WsRequest, key: str) -> Response:
|
||
if not self._check_api_token(request):
|
||
return _http_error(401, "Unauthorized")
|
||
if self._session_manager is None:
|
||
return _http_error(503, "session manager unavailable")
|
||
decoded_key = _decode_api_key(key)
|
||
if decoded_key is None:
|
||
return _http_error(400, "invalid session key")
|
||
# Same boundary as ``_handle_session_messages``: mutations apply only to
|
||
# websocket-channel sessions; deletion unlinks local JSONL — keep scope narrow.
|
||
if not self._is_websocket_channel_session_key(decoded_key):
|
||
return _http_error(404, "session not found")
|
||
deleted = self._session_manager.delete_session(decoded_key)
|
||
delete_webui_thread(decoded_key)
|
||
return _http_json_response({"deleted": bool(deleted)})
|
||
|
||
def _serve_static(self, request_path: str) -> Response | None:
|
||
"""Resolve *request_path* against the built SPA directory; SPA fallback to index.html."""
|
||
assert self._static_dist_path is not None
|
||
rel = request_path.lstrip("/")
|
||
if not rel:
|
||
rel = "index.html"
|
||
# Reject path-traversal attempts and absolute targets.
|
||
if ".." in rel.split("/") or rel.startswith("/"):
|
||
return _http_error(403, "Forbidden")
|
||
candidate = (self._static_dist_path / rel).resolve()
|
||
try:
|
||
candidate.relative_to(self._static_dist_path)
|
||
except ValueError:
|
||
return _http_error(403, "Forbidden")
|
||
if not candidate.is_file():
|
||
# SPA history-mode fallback: unknown routes serve index.html so the
|
||
# client-side router can render them.
|
||
index = self._static_dist_path / "index.html"
|
||
if index.is_file():
|
||
candidate = index
|
||
else:
|
||
return None
|
||
try:
|
||
body = candidate.read_bytes()
|
||
except OSError as e:
|
||
self.logger.warning("static: failed to read {}: {}", candidate, e)
|
||
return _http_error(500, "Internal Server Error")
|
||
ctype, _ = mimetypes.guess_type(candidate.name)
|
||
if ctype is None:
|
||
ctype = "application/octet-stream"
|
||
if ctype.startswith("text/") or ctype in {"application/javascript", "application/json"}:
|
||
ctype = f"{ctype}; charset=utf-8"
|
||
# Hash-named build assets are cache-friendly; index.html must stay fresh.
|
||
if candidate.name == "index.html":
|
||
cache = "no-cache"
|
||
else:
|
||
cache = "public, max-age=31536000, immutable"
|
||
return _http_response(
|
||
body,
|
||
status=200,
|
||
content_type=ctype,
|
||
extra_headers=[("Cache-Control", cache)],
|
||
)
|
||
|
||
def _authorize_websocket_handshake(self, connection: Any, query: dict[str, list[str]]) -> Any:
|
||
supplied = _query_first(query, "token")
|
||
static_token = self.config.token.strip()
|
||
|
||
if static_token:
|
||
if supplied and hmac.compare_digest(supplied, static_token):
|
||
return None
|
||
if supplied and self._take_issued_token_if_valid(supplied):
|
||
return None
|
||
return connection.respond(401, "Unauthorized")
|
||
|
||
if self.config.websocket_requires_token:
|
||
if supplied and self._take_issued_token_if_valid(supplied):
|
||
return None
|
||
return connection.respond(401, "Unauthorized")
|
||
|
||
if supplied:
|
||
self._take_issued_token_if_valid(supplied)
|
||
return None
|
||
|
||
async def start(self) -> None:
|
||
from nanobot.utils.logging_bridge import redirect_lib_logging
|
||
|
||
redirect_lib_logging("websockets", level="WARNING")
|
||
|
||
self._running = True
|
||
self._stop_event = asyncio.Event()
|
||
|
||
ssl_context = self._build_ssl_context()
|
||
scheme = "wss" if ssl_context else "ws"
|
||
|
||
async def process_request(
|
||
connection: ServerConnection,
|
||
request: WsRequest,
|
||
) -> Any:
|
||
return await self._dispatch_http(connection, request)
|
||
|
||
async def handler(connection: ServerConnection) -> None:
|
||
await self._connection_loop(connection)
|
||
|
||
self.logger.info(
|
||
"WebSocket server listening on {}://{}:{}{}",
|
||
scheme,
|
||
self.config.host,
|
||
self.config.port,
|
||
self.config.path,
|
||
)
|
||
if self.config.token_issue_path:
|
||
self.logger.info(
|
||
"WebSocket token issue route: {}://{}:{}{}",
|
||
scheme,
|
||
self.config.host,
|
||
self.config.port,
|
||
_normalize_config_path(self.config.token_issue_path),
|
||
)
|
||
|
||
async def runner() -> None:
|
||
async with serve(
|
||
handler,
|
||
self.config.host,
|
||
self.config.port,
|
||
process_request=process_request,
|
||
max_size=self.config.max_message_bytes,
|
||
ping_interval=self.config.ping_interval_s,
|
||
ping_timeout=self.config.ping_timeout_s,
|
||
ssl=ssl_context,
|
||
):
|
||
assert self._stop_event is not None
|
||
await self._stop_event.wait()
|
||
|
||
self._server_task = asyncio.create_task(runner())
|
||
await self._server_task
|
||
|
||
async def _connection_loop(self, connection: Any) -> None:
|
||
request = connection.request
|
||
path_part = request.path if request else "/"
|
||
_, query = _parse_request_path(path_part)
|
||
client_id_raw = _query_first(query, "client_id")
|
||
client_id = client_id_raw.strip() if client_id_raw else ""
|
||
if not client_id:
|
||
client_id = f"anon-{uuid.uuid4().hex[:12]}"
|
||
elif len(client_id) > 128:
|
||
self.logger.warning("client_id too long ({} chars), truncating", len(client_id))
|
||
client_id = client_id[:128]
|
||
|
||
default_chat_id = str(uuid.uuid4())
|
||
|
||
try:
|
||
await connection.send(
|
||
json.dumps(
|
||
{
|
||
"event": "ready",
|
||
"chat_id": default_chat_id,
|
||
"client_id": client_id,
|
||
},
|
||
ensure_ascii=False,
|
||
)
|
||
)
|
||
# Register only after ready is successfully sent to avoid out-of-order sends
|
||
self._conn_default[connection] = default_chat_id
|
||
self._attach(connection, default_chat_id)
|
||
await self._hydrate_after_subscribe(default_chat_id)
|
||
|
||
async for raw in connection:
|
||
if isinstance(raw, bytes):
|
||
try:
|
||
raw = raw.decode("utf-8")
|
||
except UnicodeDecodeError:
|
||
self.logger.warning("ignoring non-utf8 binary frame")
|
||
continue
|
||
|
||
envelope = _parse_envelope(raw)
|
||
if envelope is not None:
|
||
await self._dispatch_envelope(connection, client_id, envelope)
|
||
continue
|
||
|
||
content = _parse_inbound_payload(raw)
|
||
if content is None:
|
||
continue
|
||
# WebSocket already authenticates at handshake time (token),
|
||
# so pairing is not applicable. Treat as non-DM to avoid
|
||
# sending pairing codes to an already-authenticated client.
|
||
await self._handle_message(
|
||
sender_id=client_id,
|
||
chat_id=default_chat_id,
|
||
content=content,
|
||
metadata={"remote": getattr(connection, "remote_address", None)},
|
||
is_dm=False,
|
||
)
|
||
except Exception as e:
|
||
self.logger.debug("connection ended: {}", e)
|
||
finally:
|
||
self._cleanup_connection(connection)
|
||
|
||
def _save_envelope_media(
|
||
self,
|
||
media: list[Any],
|
||
) -> tuple[list[str], str | None]:
|
||
"""Decode and persist ``media`` items from a ``message`` envelope.
|
||
|
||
Returns ``(paths, None)`` on success or ``([], reason)`` on the first
|
||
failure — the caller is expected to surface ``reason`` to the client
|
||
and skip publishing so no half-formed message ever reaches the agent.
|
||
On failure, any files already written to disk earlier in the same
|
||
call are unlinked so partial ingress doesn't leak orphan files.
|
||
``reason`` is a short, stable token suitable for UI localization.
|
||
|
||
Shape: ``list[{"data_url": str, "name"?: str | None}]``.
|
||
"""
|
||
image_count = 0
|
||
video_count = 0
|
||
for item in media:
|
||
mime = _extract_data_url_mime(item.get("data_url", "")) if isinstance(item, dict) else None
|
||
if mime in _VIDEO_MIME_ALLOWED:
|
||
video_count += 1
|
||
elif mime in _IMAGE_MIME_ALLOWED:
|
||
image_count += 1
|
||
if image_count > _MAX_IMAGES_PER_MESSAGE:
|
||
return [], "too_many_images"
|
||
if video_count > _MAX_VIDEOS_PER_MESSAGE:
|
||
return [], "too_many_videos"
|
||
|
||
media_dir = get_media_dir("websocket")
|
||
paths: list[str] = []
|
||
|
||
def _abort(reason: str) -> tuple[list[str], str]:
|
||
for p in paths:
|
||
try:
|
||
Path(p).unlink(missing_ok=True)
|
||
except OSError as exc:
|
||
self.logger.warning(
|
||
"failed to unlink partial media {}: {}", p, exc
|
||
)
|
||
return [], reason
|
||
|
||
for item in media:
|
||
if not isinstance(item, dict):
|
||
return _abort("malformed")
|
||
data_url = item.get("data_url")
|
||
if not isinstance(data_url, str) or not data_url:
|
||
return _abort("malformed")
|
||
mime = _extract_data_url_mime(data_url)
|
||
if mime is None:
|
||
return _abort("decode")
|
||
if mime not in _UPLOAD_MIME_ALLOWED:
|
||
return _abort("mime")
|
||
is_video = mime in _VIDEO_MIME_ALLOWED
|
||
max_bytes = _MAX_VIDEO_BYTES if is_video else _MAX_IMAGE_BYTES
|
||
try:
|
||
saved = save_base64_data_url(
|
||
data_url, media_dir, max_bytes=max_bytes,
|
||
)
|
||
except FileSizeExceeded:
|
||
return _abort("size")
|
||
except Exception as exc:
|
||
self.logger.warning("media decode failed: {}", exc)
|
||
return _abort("decode")
|
||
if saved is None:
|
||
return _abort("decode")
|
||
paths.append(saved)
|
||
return paths, None
|
||
|
||
async def _dispatch_envelope(
|
||
self,
|
||
connection: Any,
|
||
client_id: str,
|
||
envelope: dict[str, Any],
|
||
) -> None:
|
||
"""Route one typed inbound envelope (``new_chat`` / ``attach`` / ``message``)."""
|
||
t = envelope.get("type")
|
||
if t == "new_chat":
|
||
new_id = str(uuid.uuid4())
|
||
self._attach(connection, new_id)
|
||
await self._send_event(connection, "attached", chat_id=new_id)
|
||
await self._hydrate_after_subscribe(new_id)
|
||
return
|
||
if t == "attach":
|
||
cid = envelope.get("chat_id")
|
||
if not _is_valid_chat_id(cid):
|
||
await self._send_event(connection, "error", detail="invalid chat_id")
|
||
return
|
||
self._attach(connection, cid)
|
||
await self._send_event(connection, "attached", chat_id=cid)
|
||
await self._hydrate_after_subscribe(cid)
|
||
return
|
||
if t == "message":
|
||
cid = envelope.get("chat_id")
|
||
content = envelope.get("content")
|
||
if not _is_valid_chat_id(cid):
|
||
await self._send_event(connection, "error", detail="invalid chat_id")
|
||
return
|
||
if not isinstance(content, str):
|
||
await self._send_event(connection, "error", detail="missing content")
|
||
return
|
||
|
||
raw_media = envelope.get("media")
|
||
media_paths: list[str] = []
|
||
if raw_media is not None:
|
||
if not isinstance(raw_media, list):
|
||
await self._send_event(
|
||
connection, "error",
|
||
detail="image_rejected", reason="malformed",
|
||
)
|
||
return
|
||
media_paths, reason = self._save_envelope_media(raw_media)
|
||
if reason is not None:
|
||
await self._send_event(
|
||
connection, "error",
|
||
detail="image_rejected", reason=reason,
|
||
)
|
||
return
|
||
|
||
# Allow image-only turns (content may be empty when media is attached).
|
||
if not content.strip() and not media_paths:
|
||
await self._send_event(connection, "error", detail="missing content")
|
||
return
|
||
|
||
# Auto-attach on first use so clients can one-shot without a separate attach.
|
||
self._attach(connection, cid)
|
||
await self._hydrate_after_subscribe(cid)
|
||
metadata: dict[str, Any] = {"remote": getattr(connection, "remote_address", None)}
|
||
if envelope.get("webui") is True:
|
||
metadata["webui"] = True
|
||
image_generation = envelope.get("image_generation")
|
||
if isinstance(image_generation, dict) and image_generation.get("enabled") is True:
|
||
aspect_ratio = image_generation.get("aspect_ratio")
|
||
metadata["image_generation"] = {
|
||
"enabled": True,
|
||
"aspect_ratio": aspect_ratio if isinstance(aspect_ratio, str) else None,
|
||
}
|
||
await self._handle_message(
|
||
sender_id=client_id,
|
||
chat_id=cid,
|
||
content=content,
|
||
media=media_paths or None,
|
||
metadata=metadata,
|
||
is_dm=False,
|
||
)
|
||
return
|
||
await self._send_event(connection, "error", detail=f"unknown type: {t!r}")
|
||
|
||
async def stop(self) -> None:
|
||
if not self._running:
|
||
return
|
||
self._running = False
|
||
if self._stop_event:
|
||
self._stop_event.set()
|
||
if self._server_task:
|
||
try:
|
||
await self._server_task
|
||
except Exception as e:
|
||
self.logger.warning("server task error during shutdown: {}", e)
|
||
self._server_task = None
|
||
self._subs.clear()
|
||
self._conn_chats.clear()
|
||
self._conn_default.clear()
|
||
self._issued_tokens.clear()
|
||
self._api_tokens.clear()
|
||
|
||
async def _safe_send_to(self, connection: Any, raw: str, *, label: str = "") -> None:
|
||
"""Send a raw frame to one connection, cleaning up on ConnectionClosed."""
|
||
try:
|
||
await connection.send(raw)
|
||
except ConnectionClosed:
|
||
self._cleanup_connection(connection)
|
||
self.logger.warning("connection gone{}", label)
|
||
except Exception:
|
||
self.logger.exception("send failed{}", label)
|
||
raise
|
||
|
||
async def send(self, msg: OutboundMessage) -> None:
|
||
if msg.metadata.get("_runtime_model_updated"):
|
||
await self.send_runtime_model_updated(
|
||
model_name=msg.metadata.get("model"),
|
||
model_preset=msg.metadata.get("model_preset"),
|
||
)
|
||
return
|
||
|
||
# Snapshot the subscriber set so ConnectionClosed cleanups mid-iteration are safe.
|
||
conns = list(self._subs.get(msg.chat_id, ()))
|
||
if not conns:
|
||
if (
|
||
msg.metadata.get("_progress")
|
||
or msg.metadata.get("_file_edit_events")
|
||
or msg.metadata.get("_turn_end")
|
||
or msg.metadata.get("_session_updated")
|
||
or msg.metadata.get("_goal_status")
|
||
or msg.metadata.get("_goal_state_sync")
|
||
):
|
||
self.logger.debug("no active subscribers for chat_id={}", msg.chat_id)
|
||
else:
|
||
self.logger.warning("no active subscribers for chat_id={}", msg.chat_id)
|
||
return
|
||
if msg.metadata.get("_goal_state_sync"):
|
||
blob = msg.metadata.get("goal_state")
|
||
await self.send_goal_state(msg.chat_id, blob if isinstance(blob, dict) else {"active": False})
|
||
return
|
||
if msg.metadata.get("_goal_status"):
|
||
status = msg.metadata.get("goal_status")
|
||
if status in ("running", "idle"):
|
||
started_raw = msg.metadata.get("started_at", msg.metadata.get("goal_started_at"))
|
||
await self.send_goal_status(
|
||
msg.chat_id,
|
||
status,
|
||
started_at=float(started_raw) if isinstance(started_raw, int | float) else None,
|
||
)
|
||
return
|
||
# Signal that the agent has fully finished processing the current turn.
|
||
if msg.metadata.get("_turn_end"):
|
||
lat = msg.metadata.get("latency_ms")
|
||
lat_i = int(lat) if isinstance(lat, (int, float)) else None
|
||
gs = msg.metadata.get("goal_state")
|
||
gs_blob = gs if isinstance(gs, dict) else None
|
||
await self.send_turn_end(msg.chat_id, latency_ms=lat_i, goal_state=gs_blob)
|
||
return
|
||
if msg.metadata.get("_session_updated"):
|
||
scope = msg.metadata.get("_session_update_scope")
|
||
await self.send_session_updated(
|
||
msg.chat_id,
|
||
scope=scope if isinstance(scope, str) else None,
|
||
)
|
||
return
|
||
if msg.metadata.get("_file_edit_events"):
|
||
payload: dict[str, Any] = {
|
||
"event": "file_edit",
|
||
"chat_id": msg.chat_id,
|
||
"edits": msg.metadata["_file_edit_events"],
|
||
}
|
||
self._try_append_webui_transcript(msg.chat_id, payload)
|
||
raw = json.dumps(payload, ensure_ascii=False)
|
||
for connection in conns:
|
||
await self._safe_send_to(connection, raw, label=" ")
|
||
return
|
||
text = msg.content
|
||
payload: dict[str, Any] = {
|
||
"event": "message",
|
||
"chat_id": msg.chat_id,
|
||
"text": text,
|
||
}
|
||
if msg.media:
|
||
payload["media"] = msg.media
|
||
urls: list[dict[str, str]] = []
|
||
for entry in msg.media:
|
||
signed = self._sign_or_stage_media_path(Path(entry))
|
||
if signed is not None:
|
||
urls.append(signed)
|
||
if urls:
|
||
payload["media_urls"] = urls
|
||
if msg.reply_to:
|
||
payload["reply_to"] = msg.reply_to
|
||
lat = msg.metadata.get("latency_ms")
|
||
if isinstance(lat, (int, float)):
|
||
payload["latency_ms"] = int(lat)
|
||
if msg.metadata.get("_tool_events"):
|
||
payload["tool_events"] = msg.metadata["_tool_events"]
|
||
agent_ui = msg.metadata.get(OUTBOUND_META_AGENT_UI)
|
||
if agent_ui is not None:
|
||
payload["agent_ui"] = agent_ui
|
||
# Mark intermediate agent breadcrumbs (tool-call hints, generic
|
||
# progress strings) so WS clients can render them as subordinate
|
||
# trace rows rather than conversational replies.
|
||
if msg.metadata.get("_tool_hint"):
|
||
payload["kind"] = "tool_hint"
|
||
elif msg.metadata.get("_progress"):
|
||
payload["kind"] = "progress"
|
||
self._try_append_webui_transcript(msg.chat_id, payload)
|
||
raw = json.dumps(payload, ensure_ascii=False)
|
||
for connection in conns:
|
||
await self._safe_send_to(connection, raw, label=" ")
|
||
|
||
async def send_reasoning_delta(
|
||
self,
|
||
chat_id: str,
|
||
delta: str,
|
||
metadata: dict[str, Any] | None = None,
|
||
) -> None:
|
||
"""Push one chunk of model reasoning. Mirrors ``send_delta`` shape so
|
||
clients receive a stream that opens, updates in place, and closes —
|
||
rendered above the active assistant bubble with a shimmer header
|
||
until the matching ``reasoning_end`` arrives.
|
||
"""
|
||
conns = list(self._subs.get(chat_id, ()))
|
||
if not conns or not delta:
|
||
return
|
||
meta = metadata or {}
|
||
body: dict[str, Any] = {
|
||
"event": "reasoning_delta",
|
||
"chat_id": chat_id,
|
||
"text": delta,
|
||
}
|
||
stream_id = meta.get("_stream_id")
|
||
if stream_id is not None:
|
||
body["stream_id"] = stream_id
|
||
self._try_append_webui_transcript(chat_id, body)
|
||
raw = json.dumps(body, ensure_ascii=False)
|
||
for connection in conns:
|
||
await self._safe_send_to(connection, raw, label=" reasoning ")
|
||
|
||
async def send_reasoning_end(
|
||
self,
|
||
chat_id: str,
|
||
metadata: dict[str, Any] | None = None,
|
||
) -> None:
|
||
"""Close the current reasoning stream segment for in-place renderers."""
|
||
conns = list(self._subs.get(chat_id, ()))
|
||
if not conns:
|
||
return
|
||
meta = metadata or {}
|
||
body: dict[str, Any] = {
|
||
"event": "reasoning_end",
|
||
"chat_id": chat_id,
|
||
}
|
||
stream_id = meta.get("_stream_id")
|
||
if stream_id is not None:
|
||
body["stream_id"] = stream_id
|
||
self._try_append_webui_transcript(chat_id, body)
|
||
raw = json.dumps(body, ensure_ascii=False)
|
||
for connection in conns:
|
||
await self._safe_send_to(connection, raw, label=" reasoning_end ")
|
||
|
||
async def send_delta(
|
||
self,
|
||
chat_id: str,
|
||
delta: str,
|
||
metadata: dict[str, Any] | None = None,
|
||
) -> None:
|
||
conns = list(self._subs.get(chat_id, ()))
|
||
if not conns:
|
||
return
|
||
meta = metadata or {}
|
||
if meta.get("_stream_end"):
|
||
body: dict[str, Any] = {"event": "stream_end", "chat_id": chat_id}
|
||
else:
|
||
body = {
|
||
"event": "delta",
|
||
"chat_id": chat_id,
|
||
"text": delta,
|
||
}
|
||
if meta.get("_stream_id") is not None:
|
||
body["stream_id"] = meta["_stream_id"]
|
||
self._try_append_webui_transcript(chat_id, body)
|
||
raw = json.dumps(body, ensure_ascii=False)
|
||
for connection in conns:
|
||
await self._safe_send_to(connection, raw, label=" stream ")
|
||
|
||
async def send_turn_end(
|
||
self,
|
||
chat_id: str,
|
||
latency_ms: int | None = None,
|
||
*,
|
||
goal_state: dict[str, Any] | None = None,
|
||
) -> None:
|
||
"""Signal that the agent has fully finished processing the current turn."""
|
||
conns = list(self._subs.get(chat_id, ()))
|
||
if not conns:
|
||
return
|
||
body: dict[str, Any] = {"event": "turn_end", "chat_id": chat_id}
|
||
if latency_ms is not None:
|
||
body["latency_ms"] = int(latency_ms)
|
||
if goal_state is not None:
|
||
body["goal_state"] = goal_state
|
||
self._try_append_webui_transcript(chat_id, body)
|
||
raw = json.dumps(body, ensure_ascii=False)
|
||
for connection in conns:
|
||
await self._safe_send_to(connection, raw, label=" turn_end ")
|
||
|
||
async def send_goal_state(self, chat_id: str, blob: dict[str, Any]) -> None:
|
||
"""Push persisted goal-state snapshot for *chat_id* (multi-chat isolation)."""
|
||
conns = list(self._subs.get(chat_id, ()))
|
||
if not conns:
|
||
return
|
||
body = {"event": "goal_state", "chat_id": chat_id, "goal_state": blob}
|
||
raw = json.dumps(body, ensure_ascii=False)
|
||
for connection in conns:
|
||
await self._safe_send_to(connection, raw, label=" goal_state ")
|
||
|
||
async def send_goal_status(
|
||
self,
|
||
chat_id: str,
|
||
status: str,
|
||
*,
|
||
started_at: float | None = None,
|
||
) -> None:
|
||
"""Notify subscribed clients that a turn started or finished (wall-clock hint)."""
|
||
conns = list(self._subs.get(chat_id, ()))
|
||
if not conns:
|
||
return
|
||
body: dict[str, Any] = {
|
||
"event": "goal_status",
|
||
"chat_id": chat_id,
|
||
"status": status,
|
||
}
|
||
if status == "running" and started_at is not None:
|
||
body["started_at"] = started_at
|
||
raw = json.dumps(body, ensure_ascii=False)
|
||
for connection in conns:
|
||
await self._safe_send_to(connection, raw, label=" goal_status ")
|
||
|
||
async def send_session_updated(self, chat_id: str, *, scope: str | None = None) -> None:
|
||
"""Notify clients that session metadata changed outside the main turn."""
|
||
conns = list(self._subs.get(chat_id, ()))
|
||
if not conns:
|
||
return
|
||
body: dict[str, Any] = {"event": "session_updated", "chat_id": chat_id}
|
||
if scope:
|
||
body["scope"] = scope
|
||
raw = json.dumps(body, ensure_ascii=False)
|
||
for connection in conns:
|
||
await self._safe_send_to(connection, raw, label=" session_updated ")
|
||
|
||
async def send_runtime_model_updated(
|
||
self,
|
||
*,
|
||
model_name: Any,
|
||
model_preset: Any = None,
|
||
) -> None:
|
||
"""Broadcast runtime model changes to every open websocket connection."""
|
||
conns = list(self._conn_chats)
|
||
if not conns or not isinstance(model_name, str) or not model_name.strip():
|
||
return
|
||
body: dict[str, Any] = {
|
||
"event": "runtime_model_updated",
|
||
"model_name": model_name.strip(),
|
||
}
|
||
if isinstance(model_preset, str) and model_preset.strip():
|
||
body["model_preset"] = model_preset.strip()
|
||
raw = json.dumps(body, ensure_ascii=False)
|
||
for connection in conns:
|
||
await self._safe_send_to(connection, raw, label=" runtime_model_updated ")
|