nanobot/nanobot/channels/websocket.py

1367 lines
52 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""WebSocket server channel: nanobot acts as a WebSocket server and serves connected clients."""
from __future__ import annotations
import asyncio
import email.utils
import hmac
import http
import json
import re
import ssl
import uuid
from collections.abc import Callable
from contextlib import suppress
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Self
from urllib.parse import parse_qs, unquote, urlparse
from loguru import logger
from pydantic import Field, field_validator, model_validator
from websockets.asyncio.server import ServerConnection, serve, unix_serve
from websockets.datastructures import Headers
from websockets.exceptions import ConnectionClosed
from websockets.http11 import Request as WsRequest
from websockets.http11 import Response
from nanobot.bus.events import OUTBOUND_META_AGENT_UI, OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_media_dir, get_workspace_path
from nanobot.config.schema import Base
from nanobot.security.workspace_access import (
WORKSPACE_SCOPE_METADATA_KEY,
WorkspaceScopeError,
)
from nanobot.session.goal_state import goal_state_ws_blob
from nanobot.session.webui_turns import websocket_turn_wall_started_at
from nanobot.utils.media_decode import (
FileSizeExceeded,
save_base64_data_url,
)
from nanobot.webui.cli_apps_api import normalize_cli_app_mentions
from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions
from nanobot.webui.transcript import append_transcript_object
if TYPE_CHECKING:
from nanobot.session.manager import SessionManager
def _strip_trailing_slash(path: str) -> str:
if len(path) > 1 and path.endswith("/"):
return path.rstrip("/")
return path or "/"
def _normalize_config_path(path: str) -> str:
    """Normalize a configured route path by delegating to ``_strip_trailing_slash``."""
    normalized = _strip_trailing_slash(path)
    return normalized
def _case_insensitive_header(headers: Any, key: str) -> str:
"""Read a header from websockets/http test stubs without assuming casing."""
try:
value = headers.get(key)
except Exception:
value = None
if value is None:
try:
value = headers.get(key.lower())
except Exception:
value = None
return str(value or "").strip()
def _safe_host_header(value: str) -> str:
"""Return a safe Host header value, or empty when it should not be echoed."""
value = value.strip()
if not value:
return ""
if re.fullmatch(r"\[[0-9A-Fa-f:.]+\](?::\d{1,5})?", value):
return value
if re.fullmatch(r"[A-Za-z0-9.-]+(?::\d{1,5})?", value):
return value
return ""
def _host_for_url(host: str, port: int) -> str:
host = host.strip()
if host in ("0.0.0.0", "::"):
host = "127.0.0.1"
if ":" in host and not host.startswith("["):
host = f"[{host}]"
return f"{host}:{port}"
class WebSocketConfig(Base):
    """WebSocket server channel configuration.

    Clients connect with URLs like ``ws://{host}:{port}{path}?client_id=...&token=...``.

    - ``client_id``: Used for ``allow_from`` authorization; if omitted, a value is generated and logged.
    - ``token``: If non-empty, the ``token`` query param may match this static secret; short-lived tokens
      from ``token_issue_path`` are also accepted.
    - ``token_issue_path``: If non-empty, **GET** (HTTP/1.1) to this path returns JSON
      ``{"token": "...", "expires_in": <seconds>}``; use ``?token=...`` when opening the WebSocket.
      Must differ from ``path`` (the WS upgrade path). If the client runs in the **same process** as
      nanobot and shares the asyncio loop, use a thread or async HTTP client for GET—do not call
      blocking ``urllib`` or synchronous ``httpx`` from inside a coroutine.
    - ``token_issue_secret``: If non-empty, token requests must send ``Authorization: Bearer <secret>`` or
      ``X-Nanobot-Auth: <secret>``.
    - ``websocket_requires_token``: If True, the handshake must include a valid token (static or issued and not expired).
    - Each connection has its own session: a unique ``chat_id`` maps to the agent session internally.
    - ``media`` field in outbound messages contains local filesystem paths; remote clients need a
      shared filesystem or an HTTP file server to access these files.
    """

    enabled: bool = False
    host: str = "127.0.0.1"
    port: int = 8765
    # When non-empty, the server binds to this Unix socket instead of host/port.
    unix_socket_path: str = ""
    path: str = "/"
    token: str = ""
    token_issue_path: str = ""
    token_issue_secret: str = ""
    token_ttl_s: int = Field(default=300, ge=30, le=86_400)
    websocket_requires_token: bool = True
    allow_from: list[str] = Field(default_factory=lambda: ["*"])
    streaming: bool = True
    # Default 36 MB, upper 40 MB: supports up to 4 images at ~6 MB each after
    # client-side Worker normalization (see webui Composer). 4 × 6 MB × 1.37
    # (base64 overhead) + envelope framing stays under 36 MB; the 40 MB ceiling
    # leaves a small margin for sender slop without opening a DoS avenue.
    max_message_bytes: int = Field(default=37_748_736, ge=1024, le=41_943_040)
    ping_interval_s: float = Field(default=20.0, ge=5.0, le=300.0)
    ping_timeout_s: float = Field(default=20.0, ge=5.0, le=300.0)
    ssl_certfile: str = ""
    ssl_keyfile: str = ""

    @field_validator("unix_socket_path")
    @classmethod
    def unix_socket_path_format(cls, value: str) -> str:
        """Trim the socket path; require it to be NUL-free and absolute after ``~`` expansion."""
        value = value.strip()
        if not value:
            return ""
        if "\x00" in value:
            raise ValueError("unix_socket_path must not contain NUL bytes")
        path = Path(value).expanduser()
        if not path.is_absolute():
            raise ValueError("unix_socket_path must be an absolute path")
        return str(path)

    @field_validator("path")
    @classmethod
    def path_must_start_with_slash(cls, value: str) -> str:
        """Require a leading ``/`` and store the path with trailing slashes normalized."""
        if not value.startswith("/"):
            raise ValueError('path must start with "/"')
        return _normalize_config_path(value)

    @field_validator("token_issue_path")
    @classmethod
    def token_issue_path_format(cls, value: str) -> str:
        """Trim the token-issue path; empty disables it, otherwise it must start with ``/``."""
        value = value.strip()
        if not value:
            return ""
        if not value.startswith("/"):
            raise ValueError('token_issue_path must start with "/"')
        return _normalize_config_path(value)

    @model_validator(mode="after")
    def token_issue_path_differs_from_ws_path(self) -> Self:
        """Reject a token-issue route that collides with the WebSocket upgrade path."""
        if not self.token_issue_path:
            return self
        if _normalize_config_path(self.token_issue_path) == _normalize_config_path(self.path):
            raise ValueError("token_issue_path must differ from path (the WebSocket upgrade path)")
        return self

    @model_validator(mode="after")
    def wildcard_host_requires_auth(self) -> Self:
        """Refuse to bind to all interfaces unless some shared secret is configured."""
        if self.host not in ("0.0.0.0", "::"):
            return self
        if self.token.strip() or self.token_issue_secret.strip():
            return self
        # Report the host that was actually configured: the IPv6 wildcard
        # ``::`` used to be described as "0.0.0.0" in this message.
        raise ValueError(
            f"host is {self.host} (all interfaces) but neither token nor "
            "token_issue_secret is set — set one to prevent unauthenticated access"
        )
def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
    """Build a ``Connection: close`` HTTP response whose body is *data* encoded as UTF-8 JSON."""
    payload = json.dumps(data, ensure_ascii=False).encode("utf-8")
    header_pairs = [
        ("Date", email.utils.formatdate(usegmt=True)),
        ("Connection", "close"),
        ("Content-Length", str(len(payload))),
        ("Content-Type", "application/json; charset=utf-8"),
    ]
    wire_headers = Headers(header_pairs)
    # The reason phrase is derived from the numeric status code.
    return Response(status, http.HTTPStatus(status).phrase, wire_headers, payload)
def publish_runtime_model_update(
    bus: MessageBus,
    model: str,
    model_preset: str | None,
) -> None:
    """Queue a runtime-model snapshot on the bus for the websocket channel.

    The message targets ``chat_id="*"`` with empty content; the payload is
    carried in the metadata under the ``_runtime_model_updated`` marker.
    """
    snapshot = {
        "_runtime_model_updated": True,
        "model": model,
        "model_preset": model_preset,
    }
    update = OutboundMessage(
        channel="websocket",
        chat_id="*",
        content="",
        metadata=snapshot,
    )
    bus.outbound.put_nowait(update)
def _default_model_name_from_config() -> str | None:
    """Return the model name resolved from the on-disk config, or ``None``.

    Any failure while importing, loading or resolving the config is logged at
    debug level and reported as ``None``; a blank model name is also ``None``.
    """
    try:
        from nanobot.config.loader import load_config

        resolved = load_config().resolve_preset().model.strip()
    except Exception as e:
        logger.debug("bootstrap model_name could not load from config: {}", e)
        return None
    return resolved or None
def _resolve_bootstrap_model_name(
runtime_name: Callable[[], str | None] | None,
) -> str | None:
"""Prefer an in-process resolver (e.g. AgentLoop); else config-derived default."""
if runtime_name is not None:
try:
raw = runtime_name()
except Exception as e:
logger.debug("bootstrap runtime model resolver failed: {}", e)
else:
if isinstance(raw, str):
stripped = raw.strip()
if stripped:
return stripped
return _default_model_name_from_config()
def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]:
    """Split a request target into ``(normalized_path, query_params)``.

    Blank query values are kept; the path has trailing slashes normalized and
    falls back to ``/`` when empty.
    """
    # A dummy scheme and host let ``urlparse`` treat the target as a full URL.
    parts = urlparse("ws://x" + path_with_query)
    normalized = _strip_trailing_slash(parts.path or "/")
    params = parse_qs(parts.query, keep_blank_values=True)
    return normalized, params
def _normalize_http_path(path_with_query: str) -> str:
    """Return only the normalized path of a request target (query dropped, root stays ``/``)."""
    path, _query = _parse_request_path(path_with_query)
    return path
def _parse_query(path_with_query: str) -> dict[str, list[str]]:
    """Return only the parsed query parameters of a request target."""
    _path, params = _parse_request_path(path_with_query)
    return params
def _query_first(query: dict[str, list[str]], key: str) -> str | None:
"""Return the first value for *key*, or None."""
values = query.get(key)
return values[0] if values else None
def _parse_inbound_payload(raw: str) -> str | None:
"""Parse a client frame into text; return None for empty or unrecognized content."""
text = raw.strip()
if not text:
return None
if text.startswith("{"):
try:
data = json.loads(text)
except json.JSONDecodeError:
return text
if isinstance(data, dict):
for key in ("content", "text", "message"):
value = data.get(key)
if isinstance(value, str) and value.strip():
return value
return None
return None
return text
# Accept UUIDs and short scoped keys like "unified:default". Keeps the capability
# namespace small enough to rule out path traversal / quote injection tricks.
_CHAT_ID_RE = re.compile(r"^[A-Za-z0-9_:-]{1,64}$")
def _is_valid_chat_id(value: Any) -> bool:
return isinstance(value, str) and _CHAT_ID_RE.match(value) is not None
def _parse_envelope(raw: str) -> dict[str, Any] | None:
"""Return a typed envelope dict if the frame is a new-style JSON envelope, else None.
A frame qualifies when it parses as a JSON object with a string ``type`` field.
Legacy frames (plain text, or ``{"content": ...}`` without ``type``) return None;
callers should fall back to :func:`_parse_inbound_payload` for those.
"""
text = raw.strip()
if not text.startswith("{"):
return None
try:
data = json.loads(text)
except json.JSONDecodeError:
return None
if not isinstance(data, dict):
return None
t = data.get("type")
if not isinstance(t, str):
return None
return data
# Per-message media limits. The server-side guard is a touch looser than the
# client's ``Worker`` normalization target (6 MB) — tolerate client slop, but
# still cap total ingress at ``_MAX_IMAGES_PER_MESSAGE * _MAX_IMAGE_BYTES``
# which fits comfortably inside ``max_message_bytes``.
# NOTE(review): if these caps apply to decoded bytes, 4 × 8 MiB is ~43 MiB once
# base64-encoded, which is above the 36 MiB default ``max_message_bytes`` — the
# frame-size limit would then be the effective ceiling. Confirm.
_MAX_IMAGES_PER_MESSAGE = 4
_MAX_IMAGE_BYTES = 8 * 1024 * 1024
# Videos: at most one per message, with a 20 MiB per-file cap.
_MAX_VIDEOS_PER_MESSAGE = 1
_MAX_VIDEO_BYTES = 20 * 1024 * 1024
# Image MIME whitelist — matches the Composer's ``accept`` list. SVG is
# explicitly excluded to avoid the XSS surface inside embedded scripts.
_IMAGE_MIME_ALLOWED: frozenset[str] = frozenset({
    "image/png",
    "image/jpeg",
    "image/webp",
    "image/gif",
})
# Video MIME whitelist for uploads.
_VIDEO_MIME_ALLOWED: frozenset[str] = frozenset({
    "video/mp4",
    "video/webm",
    "video/quicktime",
})
# Every MIME type accepted for an uploaded media item (images plus videos).
_UPLOAD_MIME_ALLOWED: frozenset[str] = _IMAGE_MIME_ALLOWED | _VIDEO_MIME_ALLOWED
_DATA_URL_MIME_RE = re.compile(r"^data:([^;]+);base64,", re.DOTALL)
def _extract_data_url_mime(url: str) -> str | None:
"""Return the MIME type of a ``data:<mime>;base64,...`` URL, else ``None``."""
if not isinstance(url, str):
return None
m = _DATA_URL_MIME_RE.match(url)
if not m:
return None
return m.group(1).strip().lower() or None
# Host strings treated as loopback when classifying a connection's peer address.
_LOCALHOSTS = frozenset({"127.0.0.1", "::1", "localhost"})
# Matches the legacy chat-id pattern but allows file-system-safe stems too,
# so the API can address sessions whose keys came from non-WebSocket channels.
_API_KEY_RE = re.compile(r"^[A-Za-z0-9_:.-]{1,128}$")
def _decode_api_key(raw_key: str) -> str | None:
"""Decode a percent-encoded API path segment, then validate the result."""
key = unquote(raw_key)
if _API_KEY_RE.match(key) is None:
return None
return key
def _is_localhost(connection: Any) -> bool:
    """Report whether *connection*'s ``remote_address`` is a known loopback host.

    The address may be a tuple (its first element is the host) or a bare
    value. A missing, empty or non-string host is never considered local.
    """
    peer = getattr(connection, "remote_address", None)
    if not peer:
        return False
    candidate = peer[0] if isinstance(peer, tuple) else peer
    if not isinstance(candidate, str):
        return False
    # Unwrap the IPv4-mapped IPv6 form, e.g. ``::ffff:127.0.0.1``.
    return candidate.removeprefix("::ffff:") in _LOCALHOSTS
def _http_response(
    body: bytes,
    *,
    status: int = 200,
    content_type: str = "text/plain; charset=utf-8",
    extra_headers: list[tuple[str, str]] | None = None,
) -> Response:
    """Build a ``Connection: close`` HTTP response around *body*.

    The standard Date / Connection / Content-Length / Content-Type headers
    come first, followed by any *extra_headers* in the order given.
    """
    header_pairs: list[tuple[str, str]] = [
        ("Date", email.utils.formatdate(usegmt=True)),
        ("Connection", "close"),
        ("Content-Length", str(len(body))),
        ("Content-Type", content_type),
        *(extra_headers or ()),
    ]
    phrase = http.HTTPStatus(status).phrase
    return Response(status, phrase, Headers(header_pairs), body)
def _http_error(status: int, message: str | None = None) -> Response:
    """Build a plain-text error response; the body defaults to the status reason phrase."""
    text = message if message else http.HTTPStatus(status).phrase
    return _http_response(text.encode("utf-8"), status=status)
def _bearer_token(headers: Any) -> str | None:
"""Pull a Bearer token out of standard or query-style headers."""
auth = headers.get("Authorization") or headers.get("authorization")
if auth and auth.lower().startswith("bearer "):
return auth[7:].strip() or None
return None
def _is_websocket_upgrade(request: WsRequest) -> bool:
    """Tell a real WebSocket upgrade apart from a plain HTTP request.

    True only when the ``Upgrade`` header mentions ``websocket`` and the
    ``Connection`` header mentions ``upgrade`` (both case-insensitive), so
    ordinary GETs to the same path can be routed elsewhere.
    """
    hdrs = request.headers
    upgrade_value = hdrs.get("Upgrade") or hdrs.get("upgrade")
    connection_value = hdrs.get("Connection") or hdrs.get("connection")
    wants_websocket = bool(upgrade_value) and "websocket" in upgrade_value.lower()
    wants_upgrade = bool(connection_value) and "upgrade" in connection_value.lower()
    return wants_websocket and wants_upgrade
def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool:
"""Return True if the token-issue HTTP request carries credentials matching ``token_issue_secret``."""
if not configured_secret:
return True
authorization = headers.get("Authorization") or headers.get("authorization")
if authorization and authorization.lower().startswith("bearer "):
supplied = authorization[7:].strip()
return hmac.compare_digest(supplied, configured_secret)
header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth")
if not header_token:
return False
return hmac.compare_digest(header_token.strip(), configured_secret)
class WebSocketChannel(BaseChannel):
"""Run a local WebSocket server; forward text/JSON messages to the message bus."""
name = "websocket"
display_name = "WebSocket"
def __init__(
    self,
    config: Any,
    bus: MessageBus,
    *,
    session_manager: "SessionManager | None" = None,
    http_handler: Any | None = None,
    workspace_path: Path | None = None,
    restrict_to_workspace: bool = False,
    runtime_surface: str = "browser",
):
    """Set up channel state; a ``dict`` config is validated into ``WebSocketConfig``.

    ``session_manager`` and ``workspace_path`` are accepted but not read by
    this constructor; the corresponding properties delegate to the injected
    ``http_handler``.
    """
    validated = WebSocketConfig.model_validate(config) if isinstance(config, dict) else config
    super().__init__(validated, bus)
    self.config: WebSocketConfig = validated

    # Fan-out index: chat_id -> connections subscribed to it.
    self._subs: dict[str, set[Any]] = {}
    # Reverse index: connection -> chat_ids it joined (cheap cleanup on disconnect).
    self._conn_chats: dict[Any, set[str]] = {}
    # Default chat_id per connection, used by legacy frames without routing.
    self._conn_default: dict[Any, str] = {}

    self._stop_event: asyncio.Event | None = None
    self._server_task: asyncio.Task[None] | None = None
    self._default_restrict_to_workspace = restrict_to_workspace

    # Only two surfaces exist; "desktop" is folded into "native".
    if runtime_surface in {"native", "desktop"}:
        self._runtime_surface = "native"
    else:
        self._runtime_surface = "browser"

    # The HTTP handler is injected from outside (ChannelManager / gateway
    # startup) and owns tokens, sessions, media, settings and static serving.
    self._http = http_handler
    # Backwards-compat: workspace controller used in envelope dispatch.
    if http_handler:
        self._webui_workspaces = http_handler.workspaces
    else:
        self._webui_workspaces = None

    self._stream_text_buffers: dict[tuple[str, str], list[str]] = {}
# -- Subscription bookkeeping -------------------------------------------
def _attach(self, connection: Any, chat_id: str) -> None:
    """Subscribe *connection* to *chat_id* in both indexes; repeating the call is harmless."""
    subscribers = self._subs.setdefault(chat_id, set())
    subscribers.add(connection)
    joined = self._conn_chats.setdefault(connection, set())
    joined.add(chat_id)
def _cleanup_connection(self, connection: Any) -> None:
    """Drop *connection* from every subscription and forget its default chat.

    Chats left without subscribers are removed from the fan-out index.
    Calling this again for the same connection is a no-op.
    """
    for chat_id in self._conn_chats.pop(connection, set()):
        remaining = self._subs.get(chat_id)
        if remaining is not None:
            remaining.discard(connection)
            if not remaining:
                self._subs.pop(chat_id, None)
    self._conn_default.pop(connection, None)
async def _maybe_push_active_goal_state(self, chat_id: str) -> None:
    """Replay an active sustained goal from session metadata after *chat_id* is subscribed.

    Goal metadata lives on the session JSONL and survives gateway restarts, but
    connected clients normally see it via ``goal_state`` / ``turn_end`` frames.
    Pushing here makes refresh + reconnect restore the strip without a new model turn.

    Does nothing when no HTTP handler was injected, when the handler has no
    session manager, or when the stored goal is not active.
    """
    # ``http_handler`` defaults to None in the constructor. Without this guard
    # the attribute access below raised, and because this runs right after the
    # ``ready`` frame it tore down every connection of a handler-less channel.
    if self._http is None:
        return
    session_manager = self._http.session_manager
    if session_manager is None:
        return
    row = session_manager.read_session_file(f"websocket:{chat_id}")
    meta = row.get("metadata", {}) if isinstance(row, dict) else {}
    if not isinstance(meta, dict):
        meta = {}
    blob = goal_state_ws_blob(meta)
    if not blob.get("active"):
        return
    await self.send_goal_state(chat_id, blob)
async def _maybe_push_turn_run_wall_clock(self, chat_id: str) -> None:
    """Re-send ``goal_status: running`` when a turn for *chat_id* has a recorded start time."""
    started_at = websocket_turn_wall_started_at(chat_id)
    if started_at is not None:
        await self.send_goal_status(chat_id, "running", started_at=started_at)
async def _hydrate_after_subscribe(self, chat_id: str) -> None:
    """Replay goal state, then run status, for a freshly subscribed *chat_id*."""
    replay_steps = (
        self._maybe_push_active_goal_state,
        self._maybe_push_turn_run_wall_clock,
    )
    for replay in replay_steps:
        await replay(chat_id)
async def _send_event(self, connection: Any, event: str, **fields: Any) -> None:
    """Send one control event (``attached``, ``error``, ...) to a single connection.

    The frame is a JSON object with ``event`` plus *fields*. A closed
    connection is cleaned up; any other send failure is logged and swallowed.
    """
    frame = json.dumps({"event": event, **fields}, ensure_ascii=False)
    try:
        await connection.send(frame)
    except ConnectionClosed:
        self._cleanup_connection(connection)
    except Exception as e:
        self.logger.warning("failed to send {} event: {}", event, e)
@classmethod
def default_config(cls) -> dict[str, Any]:
    """Return the default ``WebSocketConfig`` dumped to a dict using field aliases."""
    defaults = WebSocketConfig()
    return defaults.model_dump(by_alias=True)
def _expected_path(self) -> str:
    """Return the configured WebSocket upgrade path in normalized form."""
    configured = self.config.path
    return _normalize_config_path(configured)
# -- Backwards-compat property aliases (used by tests) ------------------
@property
def _session_manager(self):
    """Backwards-compat alias for ``self._http.session_manager`` (readable and writable)."""
    return self._http.session_manager

@_session_manager.setter
def _session_manager(self, value):
    # Writes go straight through to the HTTP handler.
    self._http.session_manager = value
@property
def _media_secret(self):
    """Backwards-compat read-only alias for ``self._http.media_secret``."""
    return self._http.media_secret
@property
def _issued_tokens(self):
    """Backwards-compat alias for ``self._http.issued_tokens`` (readable and writable)."""
    return self._http.issued_tokens

@_issued_tokens.setter
def _issued_tokens(self, value):
    # Writes go straight through to the HTTP handler.
    self._http.issued_tokens = value
@property
def _api_tokens(self):
    """Backwards-compat read-only alias for ``self._http.api_tokens``."""
    return self._http.api_tokens
def _check_api_token(self, request):
    """Backwards-compat wrapper: delegate to ``self._http.check_api_token``."""
    return self._http.check_api_token(request)
def _sign_media_path(self, path):
    """Backwards-compat wrapper: delegate to ``self._http.sign_media_path``."""
    return self._http.sign_media_path(path)
def _sign_or_stage_media_path(self, path):
    """Backwards-compat wrapper: delegate to ``self._http.sign_or_stage_media_path``."""
    return self._http.sign_or_stage_media_path(path)
def _rewrite_local_markdown_images(self, text):
    """Backwards-compat wrapper: delegate to ``self._http.rewrite_local_markdown_images``."""
    return self._http.rewrite_local_markdown_images(text)
def _handle_bootstrap(self, connection, request):
    """Backwards-compat wrapper: delegate to ``self._http._handle_bootstrap``."""
    return self._http._handle_bootstrap(connection, request)
def _handle_sessions_list(self, request):
    """Backwards-compat wrapper: delegate to ``self._http._handle_sessions_list``."""
    return self._http._handle_sessions_list(request)
def _handle_webui_thread_get(self, request, key):
    """Backwards-compat wrapper: delegate to ``self._http._handle_webui_thread_get``."""
    return self._http._handle_webui_thread_get(request, key)
@property
def _workspace_path(self):
    """Backwards-compat read-only alias for ``self._http.workspace_path``."""
    return self._http.workspace_path
def _build_ssl_context(self) -> ssl.SSLContext | None:
    """Create the server TLS context from the configured cert/key pair.

    Returns ``None`` when neither file is configured (plain ``ws``). Raises
    ``ValueError`` when only one of the two is set. The context requires
    TLS 1.2 or newer.
    """
    certfile = self.config.ssl_certfile.strip()
    keyfile = self.config.ssl_keyfile.strip()
    provided = (bool(certfile), bool(keyfile))
    if not any(provided):
        return None
    if not all(provided):
        raise ValueError(
            "ssl_certfile and ssl_keyfile must both be set for WSS, or both left empty"
        )
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.minimum_version = ssl.TLSVersion.TLSv1_2
    context.load_cert_chain(certfile=certfile, keyfile=keyfile)
    return context
# -- HTTP dispatch ------------------------------------------------------
async def _dispatch_http(self, connection: Any, request: WsRequest) -> Any:
    """Route an inbound HTTP request to the HTTP handler or WS upgrade.

    A WebSocket upgrade on the configured path is handled here: the
    ``client_id`` is checked against ``is_allowed`` (403 when refused) and the
    handshake token is then authorized. Every other request is forwarded to
    the injected HTTP handler.
    """
    got, query = _parse_request_path(request.path)
    # WebSocket upgrade — channel handles this itself
    if got == self._expected_path() and _is_websocket_upgrade(request):
        # Normalize exactly like ``_connection_loop`` (strip, then cap at 128
        # chars) so authorization is decided on the identity the connection
        # will actually use.
        client_id = (_query_first(query, "client_id") or "").strip()[:128]
        if not self.is_allowed(client_id):
            return connection.respond(403, "Forbidden")
        return self._authorize_websocket_handshake(connection, query)
    # Everything else goes to the HTTP handler
    return await self._http.dispatch(connection, request)
def _authorize_websocket_handshake(self, connection: Any, query: dict[str, list[str]]) -> Any:
    """Decide whether the handshake's ``token`` query parameter is acceptable.

    Returns ``None`` to let the upgrade proceed, or a 401 response.

    - With a static ``token`` configured, the supplied value must equal it or
      be a valid issued token.
    - Otherwise, when ``websocket_requires_token`` is set, a valid issued
      token is required.
    - Otherwise the handshake is accepted; a supplied token is still passed to
      the issued-token check (result ignored).
    """
    supplied = _query_first(query, "token")
    static_token = self.config.token.strip()

    def _matches_static() -> bool:
        # Compare as bytes: ``hmac.compare_digest`` raises TypeError for str
        # arguments containing non-ASCII characters, and the token comes
        # straight from the untrusted query string.
        if not supplied:
            return False
        return hmac.compare_digest(
            supplied.encode("utf-8", "surrogateescape"),
            static_token.encode("utf-8", "surrogateescape"),
        )

    def _issued_token_ok() -> bool:
        # Without an HTTP handler there is no issued-token store to consult.
        if not supplied or self._http is None:
            return False
        return bool(self._http.take_issued_token_if_valid(supplied))

    if static_token:
        if _matches_static() or _issued_token_ok():
            return None
        return connection.respond(401, "Unauthorized")
    if self.config.websocket_requires_token:
        if _issued_token_ok():
            return None
        return connection.respond(401, "Unauthorized")
    # Token not required: still run the issued-token check on a supplied value.
    _issued_token_ok()
    return None
# -- Server lifecycle and connection ingress ---------------------------
async def start(self) -> None:
    """Start the server and block until ``stop()`` sets the stop event.

    Binds either a Unix socket (when ``unix_socket_path`` is set) or
    ``host:port``, routes every HTTP request through ``_dispatch_http`` and
    every accepted WebSocket through ``_connection_loop``. The TLS context is
    only passed to the TCP listener.
    """
    from nanobot.utils.logging_bridge import redirect_lib_logging
    redirect_lib_logging("websockets", level="WARNING")
    # NOTE(review): ``websockets_server_logger`` is not among the imports
    # visible at the top of the file — presumably defined elsewhere in this
    # module; confirm.
    ws_logger = websockets_server_logger()
    self._running = True
    self._stop_event = asyncio.Event()
    ssl_context = self._build_ssl_context()
    scheme = "wss" if ssl_context else "ws"

    async def process_request(
        connection: ServerConnection,
        request: WsRequest,
    ) -> Any:
        # Pre-upgrade hook: decides between WS upgrade and plain HTTP handling.
        return await self._dispatch_http(connection, request)

    async def handler(connection: ServerConnection) -> None:
        # Per-connection coroutine for accepted WebSocket connections.
        await self._connection_loop(connection)

    # Logged before the listener is actually bound (binding happens in ``runner``).
    self.logger.info(
        "WebSocket server listening on {}",
        (
            f"unix:{self.config.unix_socket_path}{self.config.path}"
            if self.config.unix_socket_path
            else f"{scheme}://{self.config.host}:{self.config.port}{self.config.path}"
        ),
    )
    if self.config.token_issue_path:
        self.logger.info(
            "WebSocket token issue route: {}",
            (
                f"unix:{self.config.unix_socket_path}{_normalize_config_path(self.config.token_issue_path)}"
                if self.config.unix_socket_path
                else (
                    f"{scheme}://{self.config.host}:{self.config.port}"
                    f"{_normalize_config_path(self.config.token_issue_path)}"
                )
            ),
        )

    async def runner() -> None:
        socket_path = self.config.unix_socket_path
        if socket_path:
            path_obj = Path(socket_path)
            path_obj.parent.mkdir(parents=True, exist_ok=True)
            # Remove a stale socket file left over from a previous run.
            with suppress(FileNotFoundError):
                path_obj.unlink()
            server = await unix_serve(
                handler,
                socket_path,
                process_request=process_request,
                max_size=self.config.max_message_bytes,
                ping_interval=self.config.ping_interval_s,
                ping_timeout=self.config.ping_timeout_s,
                logger=ws_logger,
            )
            # Restrict the socket to its owner; failure to chmod is tolerated.
            with suppress(OSError):
                path_obj.chmod(0o600)
        else:
            server = await serve(
                handler,
                self.config.host,
                self.config.port,
                process_request=process_request,
                max_size=self.config.max_message_bytes,
                ping_interval=self.config.ping_interval_s,
                ping_timeout=self.config.ping_timeout_s,
                ssl=ssl_context,
                logger=ws_logger,
            )
        try:
            assert self._stop_event is not None
            # Park here until ``stop()`` signals shutdown.
            await self._stop_event.wait()
        finally:
            server.close()
            await server.wait_closed()
            if socket_path:
                # Leave no socket file behind after shutdown.
                with suppress(FileNotFoundError):
                    Path(socket_path).unlink()

    self._server_task = asyncio.create_task(runner())
    await self._server_task
async def _connection_loop(self, connection: Any) -> None:
    """Serve one accepted WebSocket connection until it closes.

    Sends a ``ready`` event carrying a fresh default ``chat_id`` and the
    resolved ``client_id`` (generated as ``anon-...`` when none was supplied,
    truncated to 128 characters when too long), subscribes the connection,
    then processes frames: typed envelopes go to ``_dispatch_envelope``,
    legacy text/JSON frames go to ``_handle_message`` on the default chat.
    The connection is always unsubscribed on exit.
    """
    request = connection.request
    path_part = request.path if request else "/"
    _, query = _parse_request_path(path_part)
    client_id_raw = _query_first(query, "client_id")
    client_id = client_id_raw.strip() if client_id_raw else ""
    if not client_id:
        client_id = f"anon-{uuid.uuid4().hex[:12]}"
    elif len(client_id) > 128:
        self.logger.warning("client_id too long ({} chars), truncating", len(client_id))
        client_id = client_id[:128]
    default_chat_id = str(uuid.uuid4())
    try:
        await connection.send(
            json.dumps(
                {
                    "event": "ready",
                    "chat_id": default_chat_id,
                    "client_id": client_id,
                },
                ensure_ascii=False,
            )
        )
        # Register only after ready is successfully sent to avoid out-of-order sends
        self._conn_default[connection] = default_chat_id
        self._attach(connection, default_chat_id)
        await self._hydrate_after_subscribe(default_chat_id)
        async for raw in connection:
            if isinstance(raw, bytes):
                try:
                    raw = raw.decode("utf-8")
                except UnicodeDecodeError:
                    self.logger.warning("ignoring non-utf8 binary frame")
                    continue
            envelope = _parse_envelope(raw)
            if envelope is not None:
                await self._dispatch_envelope(connection, client_id, envelope)
                continue
            content = _parse_inbound_payload(raw)
            if content is None:
                continue
            # WebSocket already authenticates at handshake time (token),
            # so pairing is not applicable. Treat as non-DM to avoid
            # sending pairing codes to an already-authenticated client.
            await self._handle_message(
                sender_id=client_id,
                chat_id=default_chat_id,
                content=content,
                metadata={"remote": getattr(connection, "remote_address", None)},
                is_dm=False,
            )
    except ConnectionClosed as e:
        # The peer going away is the normal way for this loop to end.
        self.logger.debug("connection ended: {}", e)
    except Exception as e:
        # Anything else is a handler failure. It is still contained so one
        # bad connection cannot take the server down, but it is logged with a
        # traceback instead of being hidden at debug level.
        self.logger.exception("connection handler failed: {}", e)
    finally:
        self._cleanup_connection(connection)
# -- Inbound WebSocket envelopes ---------------------------------------
def _save_envelope_media(
    self,
    media: list[Any],
) -> tuple[list[str], str | None]:
    """Decode and persist the ``media`` items of a ``message`` envelope.

    Returns ``(paths, None)`` when every item was saved, or ``([], reason)``
    on the first problem so the caller can report it and skip publishing.
    Files already written during the same call are unlinked before a failure
    is returned. ``reason`` is one of the short tokens ``too_many_images``,
    ``too_many_videos``, ``malformed``, ``decode``, ``mime`` or ``size``.

    Expected item shape: ``{"data_url": str, "name"?: str | None}``.
    """
    # Count pass first: reject an over-limit message before touching the disk.
    declared = [
        _extract_data_url_mime(entry.get("data_url", "")) if isinstance(entry, dict) else None
        for entry in media
    ]
    if sum(kind in _IMAGE_MIME_ALLOWED for kind in declared) > _MAX_IMAGES_PER_MESSAGE:
        return [], "too_many_images"
    if sum(kind in _VIDEO_MIME_ALLOWED for kind in declared) > _MAX_VIDEOS_PER_MESSAGE:
        return [], "too_many_videos"

    target_dir = get_media_dir("websocket")
    written: list[str] = []

    def _abort(reason: str) -> tuple[list[str], str]:
        # Roll back whatever this call has saved so far.
        for stale in written:
            try:
                Path(stale).unlink(missing_ok=True)
            except OSError as exc:
                self.logger.warning(
                    "failed to unlink partial media {}: {}", stale, exc
                )
        return [], reason

    for entry in media:
        if not isinstance(entry, dict):
            return _abort("malformed")
        data_url = entry.get("data_url")
        if not (isinstance(data_url, str) and data_url):
            return _abort("malformed")
        mime = _extract_data_url_mime(data_url)
        if mime is None:
            return _abort("decode")
        if mime not in _UPLOAD_MIME_ALLOWED:
            return _abort("mime")
        # Videos get their own, larger size cap.
        limit = _MAX_VIDEO_BYTES if mime in _VIDEO_MIME_ALLOWED else _MAX_IMAGE_BYTES
        try:
            stored = save_base64_data_url(data_url, target_dir, max_bytes=limit)
        except FileSizeExceeded:
            return _abort("size")
        except Exception as exc:
            self.logger.warning("media decode failed: {}", exc)
            return _abort("decode")
        if stored is None:
            return _abort("decode")
        written.append(stored)
    return written, None
async def _dispatch_envelope(
    self,
    connection: Any,
    client_id: str,
    envelope: dict[str, Any],
) -> None:
    """Route one typed inbound envelope by its ``type`` field.

    Handled types:

    - ``new_chat``: create a chat id, persist its workspace scope, subscribe.
    - ``attach``: subscribe the connection to an existing ``chat_id``.
    - ``set_workspace_scope``: resolve and persist a new scope for ``chat_id``.
    - ``message``: validate content/media, resolve the scope, subscribe, and
      forward the message via ``_handle_message``.

    Validation problems and unknown types are reported to the client as an
    ``error`` event; nothing is published in that case.
    """
    t = envelope.get("type")
    if t == "new_chat":
        new_id = str(uuid.uuid4())
        scope = await self._workspace_scope_or_error(
            connection,
            lambda: self._webui_workspaces.scope_for_new_chat(
                envelope,
                controls_available=_is_localhost(connection),
            ),
        )
        if scope is None:
            return
        self._webui_workspaces.persist_scope(new_id, scope)
        self._attach(connection, new_id)
        await self._send_event(connection, "attached", chat_id=new_id)
        await self._send_event(
            connection,
            "session_updated",
            chat_id=new_id,
            scope="metadata",
            workspace_scope=scope.payload(),
        )
        await self._hydrate_after_subscribe(new_id)
        return
    if t == "attach":
        cid = envelope.get("chat_id")
        if not _is_valid_chat_id(cid):
            await self._send_event(connection, "error", detail="invalid chat_id")
            return
        self._attach(connection, cid)
        await self._send_event(connection, "attached", chat_id=cid)
        await self._hydrate_after_subscribe(cid)
        return
    if t == "set_workspace_scope":
        cid = envelope.get("chat_id")
        if not _is_valid_chat_id(cid):
            await self._send_event(connection, "error", detail="invalid chat_id")
            return
        scope = await self._workspace_scope_or_error(
            connection,
            lambda: self._webui_workspaces.scope_for_set_request(
                envelope,
                chat_id=cid,
                chat_running=websocket_turn_wall_started_at(cid) is not None,
                controls_available=_is_localhost(connection),
            ),
            chat_id=cid,
        )
        if scope is None:
            return
        self._webui_workspaces.persist_scope(cid, scope)
        await self._send_event(
            connection,
            "session_updated",
            chat_id=cid,
            scope="metadata",
            workspace_scope=scope.payload(),
        )
        return
    if t == "message":
        cid = envelope.get("chat_id")
        content = envelope.get("content")
        if not _is_valid_chat_id(cid):
            await self._send_event(connection, "error", detail="invalid chat_id")
            return
        if not isinstance(content, str):
            await self._send_event(connection, "error", detail="missing content")
            return
        raw_media = envelope.get("media")
        media_paths: list[str] = []
        if raw_media is not None:
            if not isinstance(raw_media, list):
                await self._send_event(
                    connection, "error",
                    detail="image_rejected", reason="malformed",
                )
                return
            media_paths, reason = self._save_envelope_media(raw_media)
            if reason is not None:
                await self._send_event(
                    connection, "error",
                    detail="image_rejected", reason=reason,
                )
                return
        # Allow image-only turns (content may be empty when media is attached).
        if not content.strip() and not media_paths:
            await self._send_event(connection, "error", detail="missing content")
            return
        scope = await self._workspace_scope_or_error(
            connection,
            lambda: self._webui_workspaces.scope_for_message(
                envelope,
                chat_id=cid,
                chat_running=websocket_turn_wall_started_at(cid) is not None,
                controls_available=_is_localhost(connection),
            ),
            chat_id=cid,
        )
        if scope is None:
            # The message is rejected after its media was already written to
            # disk; remove those files so a refused message leaves no orphans.
            for saved in media_paths:
                with suppress(OSError):
                    Path(saved).unlink(missing_ok=True)
            return
        # Auto-attach on first use so clients can one-shot without a separate attach.
        self._attach(connection, cid)
        await self._hydrate_after_subscribe(cid)
        metadata: dict[str, Any] = {"remote": getattr(connection, "remote_address", None)}
        if envelope.get("webui") is True:
            metadata["webui"] = True
        cli_apps = normalize_cli_app_mentions(envelope.get("cli_apps"))
        if cli_apps:
            metadata["cli_apps"] = cli_apps
        mcp_presets = normalize_mcp_preset_mentions(envelope.get("mcp_presets"))
        if mcp_presets:
            metadata["mcp_presets"] = mcp_presets
        metadata[WORKSPACE_SCOPE_METADATA_KEY] = scope.metadata()
        self._webui_workspaces.persist_scope(cid, scope)
        image_generation = envelope.get("image_generation")
        if isinstance(image_generation, dict) and image_generation.get("enabled") is True:
            aspect_ratio = image_generation.get("aspect_ratio")
            metadata["image_generation"] = {
                "enabled": True,
                "aspect_ratio": aspect_ratio if isinstance(aspect_ratio, str) else None,
            }
        await self._handle_message(
            sender_id=client_id,
            chat_id=cid,
            content=content,
            media=media_paths or None,
            metadata=metadata,
            is_dm=False,
        )
        return
    await self._send_event(connection, "error", detail=f"unknown type: {t!r}")
async def _workspace_scope_or_error(
    self,
    connection: Any,
    resolver: Callable[[], Any],
    *,
    chat_id: str | None = None,
) -> Any | None:
    """Call *resolver* and hand back its result, reporting scope rejections.

    ``resolver`` is invoked synchronously with no arguments.  When it raises
    ``WorkspaceScopeError`` an ``error`` event carrying
    ``detail="workspace_scope_rejected"`` and the exception's ``message`` as
    ``reason`` is sent on *connection*, and ``None`` is returned.  The event
    includes ``chat_id`` only when a truthy *chat_id* was supplied.  Any
    other exception raised by the resolver propagates unchanged.
    """
    try:
        resolved = resolver()
    except WorkspaceScopeError as exc:
        # Tag the rejection with the chat only when the caller named one.
        chat_fields: dict[str, Any] = {}
        if chat_id:
            chat_fields["chat_id"] = chat_id
        await self._send_event(
            connection,
            "error",
            detail="workspace_scope_rejected",
            reason=exc.message,
            **chat_fields,
        )
        return None
    return resolved
# -- Outbound WebSocket events -----------------------------------------
async def stop(self) -> None:
    """Stop the channel and drop all connection and token bookkeeping.

    Does nothing when the channel is not marked as running.  Otherwise it
    clears the running flag, sets the stop event (if one exists), awaits the
    server task (an ``Exception`` raised by it is logged as a warning rather
    than propagated), and finally empties the subscription maps and the
    HTTP helper's issued/API token stores.
    """
    if not self._running:
        return
    self._running = False

    stop_event = self._stop_event
    if stop_event:
        stop_event.set()

    server_task = self._server_task
    if server_task:
        try:
            await server_task
        except Exception as exc:
            self.logger.warning("server task error during shutdown: {}", exc)
        self._server_task = None

    # Forget every subscription, connection default and token in one sweep.
    for registry in (
        self._subs,
        self._conn_chats,
        self._conn_default,
        self._http.issued_tokens,
        self._http.api_tokens,
    ):
        registry.clear()
async def _safe_send_to(self, connection: Any, raw: str, *, label: str = "") -> None:
    """Deliver one already-serialized frame to a single connection.

    A ``ConnectionClosed`` raised by the send is handled here: the
    connection is passed to ``_cleanup_connection`` and a warning is logged.
    Any other exception is logged with its traceback and re-raised to the
    caller.  *label* is inserted verbatim into the log message.
    """
    try:
        await connection.send(raw)
    except ConnectionClosed:
        # Peer went away: unregister it instead of failing the caller.
        self._cleanup_connection(connection)
        self.logger.warning("connection gone{}", label)
    except Exception:
        # Unexpected failure: record it, then let the caller decide.
        self.logger.exception("send failed{}", label)
        raise
def _try_append_webui_transcript(self, chat_id: str, wire: dict[str, Any]) -> None:
    """Append *wire* to the transcript stored under ``websocket:<chat_id>``.

    The frame is round-tripped through JSON first, so the transcript
    receives a detached copy of it.  A ``ValueError`` or ``TypeError`` raised
    while copying or appending is logged as a warning and swallowed.
    """
    session_key = f"websocket:{chat_id}"
    try:
        encoded = json.dumps(wire, ensure_ascii=False)
        append_transcript_object(session_key, json.loads(encoded))
    except (ValueError, TypeError) as exc:
        self.logger.warning("webui transcript append failed: {}", exc)
async def send(self, msg: OutboundMessage) -> None:
    """Route one outbound message to the websocket clients of ``msg.chat_id``.

    Metadata flags choose the frame type and are checked in a fixed order:

    1. ``_runtime_model_updated`` is forwarded to
       ``send_runtime_model_updated`` before subscribers are looked up.
    2. With no subscribers for the chat, the message is dropped and logged:
       at debug level for progress/control flags, at warning level otherwise.
    3. ``_goal_state_sync``, ``_goal_status``, ``_turn_end``,
       ``_session_updated`` and ``_file_edit_events`` are each delegated to
       the matching ``send_*`` method.
    4. Anything else becomes a ``message`` event, sent to every subscriber
       and appended to the web UI transcript.
    """
    meta = msg.metadata
    chat_id = msg.chat_id

    if meta.get("_runtime_model_updated"):
        await self.send_runtime_model_updated(
            model_name=meta.get("model"),
            model_preset=meta.get("model_preset"),
        )
        return

    # Work on a copy so cleanups triggered by a closed connection cannot
    # disturb the iteration below.
    subscribers = list(self._subs.get(chat_id, ()))
    if not subscribers:
        quiet_flags = (
            "_progress",
            "_file_edit_events",
            "_turn_end",
            "_session_updated",
            "_goal_status",
            "_goal_state_sync",
        )
        # Flagged progress/control frames log at debug; the rest warn.
        if any(meta.get(flag) for flag in quiet_flags):
            self.logger.debug("no active subscribers for chat_id={}", chat_id)
        else:
            self.logger.warning("no active subscribers for chat_id={}", chat_id)
        return

    if meta.get("_goal_state_sync"):
        snapshot = meta.get("goal_state")
        if not isinstance(snapshot, dict):
            snapshot = {"active": False}
        await self.send_goal_state(chat_id, snapshot)
        return

    if meta.get("_goal_status"):
        status = meta.get("goal_status")
        # Statuses other than running/idle are dropped without a frame.
        if status in ("running", "idle"):
            started_raw = meta.get("started_at", meta.get("goal_started_at"))
            if isinstance(started_raw, (int, float)):
                started_at = float(started_raw)
            else:
                started_at = None
            await self.send_goal_status(chat_id, status, started_at=started_at)
        return

    # The agent has fully finished processing the current turn.
    if meta.get("_turn_end"):
        latency = meta.get("latency_ms")
        goal_state = meta.get("goal_state")
        await self.send_turn_end(
            chat_id,
            latency_ms=int(latency) if isinstance(latency, (int, float)) else None,
            goal_state=goal_state if isinstance(goal_state, dict) else None,
        )
        return

    if meta.get("_session_updated"):
        update_scope = meta.get("_session_update_scope")
        await self.send_session_updated(
            chat_id,
            scope=update_scope if isinstance(update_scope, str) else None,
        )
        return

    edits = meta.get("_file_edit_events")
    if edits:
        await self.send_file_edit_events(
            chat_id,
            edits if isinstance(edits, list) else [],
            meta,
        )
        return

    # Plain conversational message.  Keys are inserted in wire order.
    text = msg.content
    payload: dict[str, Any] = {
        "event": "message",
        "chat_id": chat_id,
        "text": self._http.rewrite_local_markdown_images(text),
    }
    media = msg.media
    if media:
        payload["media"] = media
        signed_urls: list[dict[str, str]] = [
            signed
            for entry in media
            if (signed := self._http.sign_or_stage_media_path(Path(entry))) is not None
        ]
        if signed_urls:
            payload["media_urls"] = signed_urls
    if msg.reply_to:
        payload["reply_to"] = msg.reply_to
    latency = meta.get("latency_ms")
    if isinstance(latency, (int, float)):
        payload["latency_ms"] = int(latency)
    tool_events = meta.get("_tool_events")
    if tool_events:
        payload["tool_events"] = tool_events
    agent_ui = meta.get(OUTBOUND_META_AGENT_UI)
    if agent_ui is not None:
        payload["agent_ui"] = agent_ui
    # Tag intermediate breadcrumbs (tool-call hints, generic progress
    # strings) so clients can render them as trace rows, not replies.
    if meta.get("_tool_hint"):
        payload["kind"] = "tool_hint"
    elif meta.get("_progress"):
        payload["kind"] = "progress"

    # The transcript stores the original text, not the rewritten wire text.
    self._try_append_webui_transcript(chat_id, {**payload, "text": text})
    raw = json.dumps(payload, ensure_ascii=False)
    for connection in subscribers:
        await self._safe_send_to(connection, raw, label=" ")
async def send_reasoning_delta(
    self,
    chat_id: str,
    delta: str,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Push one chunk of model reasoning to the chat's subscribers.

    The frame has the same shape as the ``send_delta`` text stream, so
    clients see a stream that opens, updates in place, and is closed by the
    matching ``reasoning_end`` event.  Nothing is sent or recorded when
    *delta* is empty or the chat has no subscribers.  ``_stream_id`` from
    *metadata* is forwarded as ``stream_id`` when present.
    """
    if not delta:
        return
    subscribers = list(self._subs.get(chat_id, ()))
    if not subscribers:
        return

    frame: dict[str, Any] = {
        "event": "reasoning_delta",
        "chat_id": chat_id,
        "text": delta,
    }
    stream_id = (metadata or {}).get("_stream_id")
    if stream_id is not None:
        frame["stream_id"] = stream_id

    self._try_append_webui_transcript(chat_id, frame)
    raw = json.dumps(frame, ensure_ascii=False)
    for connection in subscribers:
        await self._safe_send_to(connection, raw, label=" reasoning ")
async def send_reasoning_end(
    self,
    chat_id: str,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Emit ``reasoning_end`` so in-place renderers can close the segment.

    Does nothing when the chat has no subscribers.  ``_stream_id`` from
    *metadata* is forwarded as ``stream_id`` when present.  The frame is
    appended to the web UI transcript before it is sent.
    """
    subscribers = list(self._subs.get(chat_id, ()))
    if not subscribers:
        return

    frame: dict[str, Any] = {"event": "reasoning_end", "chat_id": chat_id}
    stream_id = (metadata or {}).get("_stream_id")
    if stream_id is not None:
        frame["stream_id"] = stream_id

    self._try_append_webui_transcript(chat_id, frame)
    raw = json.dumps(frame, ensure_ascii=False)
    for connection in subscribers:
        await self._safe_send_to(connection, raw, label=" reasoning_end ")
async def send_file_edit_events(
    self,
    chat_id: str,
    edits: list[dict[str, Any]],
    metadata: dict[str, Any] | None = None,
) -> None:
    """Send a ``file_edit`` event carrying *edits* to the chat's subscribers.

    Does nothing when the chat has no subscribers.  The frame is appended to
    the web UI transcript before it is sent.  *metadata* is accepted but not
    read by this method.
    """
    subscribers = list(self._subs.get(chat_id, ()))
    if not subscribers:
        return

    frame: dict[str, Any] = {"event": "file_edit", "chat_id": chat_id, "edits": edits}
    self._try_append_webui_transcript(chat_id, frame)
    raw = json.dumps(frame, ensure_ascii=False)
    for connection in subscribers:
        await self._safe_send_to(connection, raw, label=" file_edit ")
async def send_delta(
    self,
    chat_id: str,
    delta: str,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Stream one chunk of assistant text, or close the stream.

    Without ``_stream_end`` in *metadata*, a ``delta`` event is sent and the
    chunk is also buffered per ``(chat_id, stream id)``.  With
    ``_stream_end``, a ``stream_end`` event is sent; the buffered text (plus
    any final *delta*) is joined and passed through the HTTP helper's
    markdown-image rewrite, and the rewritten text is attached as ``text``
    only when the rewrite changed something.  ``_stream_id`` is forwarded as
    ``stream_id`` when present.  Each frame is appended to the web UI
    transcript before it is sent.

    When the chat has no subscribers nothing is sent or recorded, but a
    stream end still releases the stream's buffer.
    """
    meta = metadata or {}
    stream_key = (chat_id, str(meta.get("_stream_id") or ""))
    stream_ended = bool(meta.get("_stream_end"))

    conns = list(self._subs.get(chat_id, ()))
    if not conns:
        if stream_ended:
            # Subscribers may leave mid-stream; without this pop the chunks
            # buffered so far would never be released.
            self._stream_text_buffers.pop(stream_key, None)
        return

    if stream_ended:
        body: dict[str, Any] = {"event": "stream_end", "chat_id": chat_id}
        buffered = self._stream_text_buffers.pop(stream_key, [])
        if delta:
            buffered.append(delta)
        full_text = "".join(buffered)
        rewritten = self._http.rewrite_local_markdown_images(full_text)
        # Only resend the text when the rewrite actually changed it.
        if rewritten != full_text:
            body["text"] = rewritten
    else:
        body = {
            "event": "delta",
            "chat_id": chat_id,
            "text": delta,
        }
        self._stream_text_buffers.setdefault(stream_key, []).append(delta)

    if meta.get("_stream_id") is not None:
        body["stream_id"] = meta["_stream_id"]
    self._try_append_webui_transcript(chat_id, body)
    raw = json.dumps(body, ensure_ascii=False)
    for connection in conns:
        await self._safe_send_to(connection, raw, label=" stream ")
async def send_turn_end(
    self,
    chat_id: str,
    latency_ms: int | None = None,
    *,
    goal_state: dict[str, Any] | None = None,
) -> None:
    """Tell subscribers the agent has fully finished the current turn.

    Does nothing when the chat has no subscribers.  ``latency_ms`` (coerced
    with ``int``) and ``goal_state`` are included in the ``turn_end`` frame
    only when they are not ``None``.  The frame is appended to the web UI
    transcript before it is sent.
    """
    subscribers = list(self._subs.get(chat_id, ()))
    if not subscribers:
        return

    frame: dict[str, Any] = {"event": "turn_end", "chat_id": chat_id}
    optional_fields = {
        "latency_ms": None if latency_ms is None else int(latency_ms),
        "goal_state": goal_state,
    }
    # Absent values are left off the wire entirely rather than sent as null.
    frame.update(
        {key: value for key, value in optional_fields.items() if value is not None}
    )

    self._try_append_webui_transcript(chat_id, frame)
    raw = json.dumps(frame, ensure_ascii=False)
    for connection in subscribers:
        await self._safe_send_to(connection, raw, label=" turn_end ")
async def send_goal_state(self, chat_id: str, blob: dict[str, Any]) -> None:
    """Send the goal-state snapshot *blob* to subscribers of *chat_id* only.

    Does nothing when the chat has no subscribers.
    """
    subscribers = list(self._subs.get(chat_id, ()))
    if not subscribers:
        return

    raw = json.dumps(
        {"event": "goal_state", "chat_id": chat_id, "goal_state": blob},
        ensure_ascii=False,
    )
    for connection in subscribers:
        await self._safe_send_to(connection, raw, label=" goal_state ")
async def send_goal_status(
    self,
    chat_id: str,
    status: str,
    *,
    started_at: float | None = None,
) -> None:
    """Notify subscribers that a turn started or finished.

    Does nothing when the chat has no subscribers.  ``started_at`` is a
    wall-clock hint and is included in the ``goal_status`` frame only when
    *status* is ``"running"`` and a value was supplied.
    """
    subscribers = list(self._subs.get(chat_id, ()))
    if not subscribers:
        return

    frame: dict[str, Any] = {
        "event": "goal_status",
        "chat_id": chat_id,
        "status": status,
    }
    has_start_hint = started_at is not None
    if has_start_hint and status == "running":
        frame["started_at"] = started_at

    raw = json.dumps(frame, ensure_ascii=False)
    for connection in subscribers:
        await self._safe_send_to(connection, raw, label=" goal_status ")
async def send_session_updated(self, chat_id: str, *, scope: str | None = None) -> None:
    """Tell subscribers that session metadata changed outside the main turn.

    Does nothing when the chat has no subscribers.  ``scope`` is included in
    the ``session_updated`` frame only when it is truthy.
    """
    subscribers = list(self._subs.get(chat_id, ()))
    if not subscribers:
        return

    frame: dict[str, Any] = {"event": "session_updated", "chat_id": chat_id}
    if scope:
        frame["scope"] = scope

    raw = json.dumps(frame, ensure_ascii=False)
    for connection in subscribers:
        await self._safe_send_to(connection, raw, label=" session_updated ")
async def send_runtime_model_updated(
    self,
    *,
    model_name: Any,
    model_preset: Any = None,
) -> None:
    """Broadcast a runtime model change to every open websocket connection.

    Nothing is sent when there are no connections or when *model_name* is
    not a non-blank string.  Both values are stripped of surrounding
    whitespace; ``model_preset`` is included only when it is a non-blank
    string.
    """
    connections = list(self._conn_chats)
    if not connections:
        return
    if not isinstance(model_name, str):
        return
    name = model_name.strip()
    if not name:
        return

    frame: dict[str, Any] = {"event": "runtime_model_updated", "model_name": name}
    if isinstance(model_preset, str):
        preset = model_preset.strip()
        if preset:
            frame["model_preset"] = preset

    raw = json.dumps(frame, ensure_ascii=False)
    for connection in connections:
        await self._safe_send_to(connection, raw, label=" runtime_model_updated ")