# Mirror of https://github.com/HKUDS/nanobot.git (synced 2026-06-14 06:43:53 +00:00).
# 1367 lines, 52 KiB, Python.
"""WebSocket server channel: nanobot acts as a WebSocket server and serves connected clients."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import email.utils
|
||
import hmac
|
||
import http
|
||
import json
|
||
import re
|
||
import ssl
|
||
import uuid
|
||
from collections.abc import Callable
|
||
from contextlib import suppress
|
||
from functools import partial
|
||
from pathlib import Path
|
||
from typing import TYPE_CHECKING, Any, Self
|
||
from urllib.parse import parse_qs, unquote, urlparse
|
||
|
||
from loguru import logger
|
||
from pydantic import Field, field_validator, model_validator
|
||
from websockets.asyncio.server import ServerConnection, serve, unix_serve
|
||
from websockets.datastructures import Headers
|
||
from websockets.exceptions import ConnectionClosed
|
||
from websockets.http11 import Request as WsRequest
|
||
from websockets.http11 import Response
|
||
|
||
from nanobot.bus.events import OUTBOUND_META_AGENT_UI, OutboundMessage
|
||
from nanobot.bus.queue import MessageBus
|
||
from nanobot.channels.base import BaseChannel
|
||
from nanobot.config.paths import get_media_dir, get_workspace_path
|
||
from nanobot.config.schema import Base
|
||
from nanobot.security.workspace_access import (
|
||
WORKSPACE_SCOPE_METADATA_KEY,
|
||
WorkspaceScopeError,
|
||
)
|
||
from nanobot.session.goal_state import goal_state_ws_blob
|
||
from nanobot.session.webui_turns import websocket_turn_wall_started_at
|
||
from nanobot.utils.media_decode import (
|
||
FileSizeExceeded,
|
||
save_base64_data_url,
|
||
)
|
||
from nanobot.webui.cli_apps_api import normalize_cli_app_mentions
|
||
from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions
|
||
from nanobot.webui.transcript import append_transcript_object
|
||
|
||
if TYPE_CHECKING:
|
||
from nanobot.session.manager import SessionManager
|
||
|
||
|
||
def _strip_trailing_slash(path: str) -> str:
|
||
if len(path) > 1 and path.endswith("/"):
|
||
return path.rstrip("/")
|
||
return path or "/"
|
||
|
||
|
||
def _normalize_config_path(path: str) -> str:
    """Canonicalize a configured route path by dropping any trailing slash."""
    normalized = _strip_trailing_slash(path)
    return normalized
|
||
|
||
|
||
def _case_insensitive_header(headers: Any, key: str) -> str:
|
||
"""Read a header from websockets/http test stubs without assuming casing."""
|
||
try:
|
||
value = headers.get(key)
|
||
except Exception:
|
||
value = None
|
||
if value is None:
|
||
try:
|
||
value = headers.get(key.lower())
|
||
except Exception:
|
||
value = None
|
||
return str(value or "").strip()
|
||
|
||
|
||
def _safe_host_header(value: str) -> str:
|
||
"""Return a safe Host header value, or empty when it should not be echoed."""
|
||
value = value.strip()
|
||
if not value:
|
||
return ""
|
||
if re.fullmatch(r"\[[0-9A-Fa-f:.]+\](?::\d{1,5})?", value):
|
||
return value
|
||
if re.fullmatch(r"[A-Za-z0-9.-]+(?::\d{1,5})?", value):
|
||
return value
|
||
return ""
|
||
|
||
|
||
def _host_for_url(host: str, port: int) -> str:
|
||
host = host.strip()
|
||
if host in ("0.0.0.0", "::"):
|
||
host = "127.0.0.1"
|
||
if ":" in host and not host.startswith("["):
|
||
host = f"[{host}]"
|
||
return f"{host}:{port}"
|
||
|
||
|
||
class WebSocketConfig(Base):
    """WebSocket server channel configuration.

    Clients connect with URLs like ``ws://{host}:{port}{path}?client_id=...&token=...``.

    - ``client_id``: Used for ``allow_from`` authorization; if omitted, a value is generated and logged.
    - ``token``: If non-empty, the ``token`` query param may match this static secret; short-lived tokens
      from ``token_issue_path`` are also accepted.
    - ``token_issue_path``: If non-empty, **GET** (HTTP/1.1) to this path returns JSON
      ``{"token": "...", "expires_in": <seconds>}``; use ``?token=...`` when opening the WebSocket.
      Must differ from ``path`` (the WS upgrade path). If the client runs in the **same process** as
      nanobot and shares the asyncio loop, use a thread or async HTTP client for GET—do not call
      blocking ``urllib`` or synchronous ``httpx`` from inside a coroutine.
    - ``token_issue_secret``: If non-empty, token requests must send ``Authorization: Bearer <secret>`` or
      ``X-Nanobot-Auth: <secret>``.
    - ``websocket_requires_token``: If True, the handshake must include a valid token (static or issued and not expired).
    - Each connection has its own session: a unique ``chat_id`` maps to the agent session internally.
    - ``media`` field in outbound messages contains local filesystem paths; remote clients need a
      shared filesystem or an HTTP file server to access these files.
    - A ``host`` that binds every interface (empty, ``0.0.0.0``, ``::`` or any other spelling of the
      unspecified address) is rejected unless ``token`` or ``token_issue_secret`` is set.
    """

    enabled: bool = False
    host: str = "127.0.0.1"
    port: int = 8765
    unix_socket_path: str = ""
    path: str = "/"
    token: str = ""
    token_issue_path: str = ""
    token_issue_secret: str = ""
    token_ttl_s: int = Field(default=300, ge=30, le=86_400)
    websocket_requires_token: bool = True
    allow_from: list[str] = Field(default_factory=lambda: ["*"])
    streaming: bool = True
    # Default 36 MB, upper 40 MB: supports up to 4 images at ~6 MB each after
    # client-side Worker normalization (see webui Composer). 4 × 6 MB × 1.37
    # (base64 overhead) + envelope framing stays under 36 MB; the 40 MB ceiling
    # leaves a small margin for sender slop without opening a DoS avenue.
    max_message_bytes: int = Field(default=37_748_736, ge=1024, le=41_943_040)
    ping_interval_s: float = Field(default=20.0, ge=5.0, le=300.0)
    ping_timeout_s: float = Field(default=20.0, ge=5.0, le=300.0)
    ssl_certfile: str = ""
    ssl_keyfile: str = ""

    @field_validator("unix_socket_path")
    @classmethod
    def unix_socket_path_format(cls, value: str) -> str:
        """Require an absolute, NUL-free socket path; ``~`` is expanded, empty means disabled."""
        value = value.strip()
        if not value:
            return ""
        if "\x00" in value:
            raise ValueError("unix_socket_path must not contain NUL bytes")
        path = Path(value).expanduser()
        if not path.is_absolute():
            raise ValueError("unix_socket_path must be an absolute path")
        return str(path)

    @field_validator("path")
    @classmethod
    def path_must_start_with_slash(cls, value: str) -> str:
        """Require a leading ``/`` and drop any trailing slash."""
        if not value.startswith("/"):
            raise ValueError('path must start with "/"')
        return _normalize_config_path(value)

    @field_validator("token_issue_path")
    @classmethod
    def token_issue_path_format(cls, value: str) -> str:
        """Normalize the token-issue route; empty disables it."""
        value = value.strip()
        if not value:
            return ""
        if not value.startswith("/"):
            raise ValueError('token_issue_path must start with "/"')
        return _normalize_config_path(value)

    @model_validator(mode="after")
    def token_issue_path_differs_from_ws_path(self) -> Self:
        """Reject a token-issue route that collides with the WebSocket upgrade path."""
        if not self.token_issue_path:
            return self
        if _normalize_config_path(self.token_issue_path) == _normalize_config_path(self.path):
            raise ValueError("token_issue_path must differ from path (the WebSocket upgrade path)")
        return self

    @model_validator(mode="after")
    def wildcard_host_requires_auth(self) -> Self:
        """Refuse to listen on every interface without a configured secret.

        Raises:
            ValueError: if ``host`` binds all interfaces and neither ``token``
                nor ``token_issue_secret`` is set.
        """
        import ipaddress  # local import: only this check needs it

        host = self.host.strip()
        if host:
            try:
                # Catches "0.0.0.0", "::" and equivalent spellings such as "::0".
                binds_everywhere = ipaddress.ip_address(host).is_unspecified
            except ValueError:
                # Not an IP literal (a hostname): not a wildcard bind.
                binds_everywhere = False
        else:
            # asyncio treats an empty host as "all interfaces".
            binds_everywhere = True
        if not binds_everywhere:
            return self
        if self.token.strip() or self.token_issue_secret.strip():
            return self
        shown = host or '""'
        raise ValueError(
            f"host is {shown} (all interfaces) but neither token nor "
            "token_issue_secret is set — set one to prevent unauthenticated access"
        )
|
||
|
||
|
||
def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
    """Build a ``Connection: close`` HTTP response whose body is *data* as UTF-8 JSON."""
    payload = json.dumps(data, ensure_ascii=False).encode("utf-8")
    header_pairs = [
        ("Date", email.utils.formatdate(usegmt=True)),
        ("Connection", "close"),
        ("Content-Length", str(len(payload))),
        ("Content-Type", "application/json; charset=utf-8"),
    ]
    phrase = http.HTTPStatus(status).phrase
    return Response(status, phrase, Headers(header_pairs), payload)
|
||
|
||
|
||
def publish_runtime_model_update(
    bus: MessageBus,
    model: str,
    model_preset: str | None,
) -> None:
    """Queue a broadcast (``chat_id="*"``) describing the current runtime model.

    The message is put on ``bus.outbound`` without waiting and carries no text
    content; the snapshot travels in ``metadata``.
    """
    outbound_queue = bus.outbound
    snapshot = {
        "_runtime_model_updated": True,
        "model": model,
        "model_preset": model_preset,
    }
    message = OutboundMessage(
        channel="websocket",
        chat_id="*",
        content="",
        metadata=snapshot,
    )
    outbound_queue.put_nowait(message)
|
||
|
||
|
||
def _default_model_name_from_config() -> str | None:
    """Return the model name of the resolved preset in the loaded config, or ``None``.

    Any failure while loading or resolving is logged at debug level and
    reported as ``None``; a blank model name also yields ``None``.
    """
    try:
        from nanobot.config.loader import load_config

        preset = load_config().resolve_preset()
        name = preset.model.strip()
    except Exception as e:
        logger.debug("bootstrap model_name could not load from config: {}", e)
        return None
    return name or None
|
||
|
||
|
||
def _resolve_bootstrap_model_name(
|
||
runtime_name: Callable[[], str | None] | None,
|
||
) -> str | None:
|
||
"""Prefer an in-process resolver (e.g. AgentLoop); else config-derived default."""
|
||
if runtime_name is not None:
|
||
try:
|
||
raw = runtime_name()
|
||
except Exception as e:
|
||
logger.debug("bootstrap runtime model resolver failed: {}", e)
|
||
else:
|
||
if isinstance(raw, str):
|
||
stripped = raw.strip()
|
||
if stripped:
|
||
return stripped
|
||
return _default_model_name_from_config()
|
||
|
||
|
||
def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]:
    """Split a request target into ``(normalized_path, query_params)``.

    Blank query values are kept; the path has its trailing slash normalized.
    """
    # A dummy scheme/host lets urlparse treat the input as a request target.
    parts = urlparse("ws://x" + path_with_query)
    raw_path = parts.path if parts.path else "/"
    params = parse_qs(parts.query, keep_blank_values=True)
    return _strip_trailing_slash(raw_path), params
|
||
|
||
|
||
def _normalize_http_path(path_with_query: str) -> str:
    """Return only the normalized path of a request target (query dropped, root stays ``/``)."""
    path, _query = _parse_request_path(path_with_query)
    return path
|
||
|
||
|
||
def _parse_query(path_with_query: str) -> dict[str, list[str]]:
    """Return only the parsed query parameters of a request target."""
    _path, params = _parse_request_path(path_with_query)
    return params
|
||
|
||
|
||
def _query_first(query: dict[str, list[str]], key: str) -> str | None:
|
||
"""Return the first value for *key*, or None."""
|
||
values = query.get(key)
|
||
return values[0] if values else None
|
||
|
||
|
||
def _parse_inbound_payload(raw: str) -> str | None:
|
||
"""Parse a client frame into text; return None for empty or unrecognized content."""
|
||
text = raw.strip()
|
||
if not text:
|
||
return None
|
||
if text.startswith("{"):
|
||
try:
|
||
data = json.loads(text)
|
||
except json.JSONDecodeError:
|
||
return text
|
||
if isinstance(data, dict):
|
||
for key in ("content", "text", "message"):
|
||
value = data.get(key)
|
||
if isinstance(value, str) and value.strip():
|
||
return value
|
||
return None
|
||
return None
|
||
return text
|
||
|
||
|
||
# Accept UUIDs and short scoped keys like "unified:default". Keeps the capability
|
||
# namespace small enough to rule out path traversal / quote injection tricks.
|
||
_CHAT_ID_RE = re.compile(r"^[A-Za-z0-9_:-]{1,64}$")
|
||
|
||
|
||
def _is_valid_chat_id(value: Any) -> bool:
|
||
return isinstance(value, str) and _CHAT_ID_RE.match(value) is not None
|
||
|
||
|
||
def _parse_envelope(raw: str) -> dict[str, Any] | None:
|
||
"""Return a typed envelope dict if the frame is a new-style JSON envelope, else None.
|
||
|
||
A frame qualifies when it parses as a JSON object with a string ``type`` field.
|
||
Legacy frames (plain text, or ``{"content": ...}`` without ``type``) return None;
|
||
callers should fall back to :func:`_parse_inbound_payload` for those.
|
||
"""
|
||
text = raw.strip()
|
||
if not text.startswith("{"):
|
||
return None
|
||
try:
|
||
data = json.loads(text)
|
||
except json.JSONDecodeError:
|
||
return None
|
||
if not isinstance(data, dict):
|
||
return None
|
||
t = data.get("type")
|
||
if not isinstance(t, str):
|
||
return None
|
||
return data
|
||
|
||
|
||
# Per-message media limits. The server-side guard is a touch looser than the
|
||
# client's ``Worker`` normalization target (6 MB) — tolerate client slop, but
|
||
# still cap total ingress at ``_MAX_IMAGES_PER_MESSAGE * _MAX_IMAGE_BYTES``
|
||
# which fits comfortably inside ``max_message_bytes``.
|
||
_MAX_IMAGES_PER_MESSAGE = 4
|
||
_MAX_IMAGE_BYTES = 8 * 1024 * 1024
|
||
_MAX_VIDEOS_PER_MESSAGE = 1
|
||
_MAX_VIDEO_BYTES = 20 * 1024 * 1024
|
||
|
||
# Image MIME whitelist — matches the Composer's ``accept`` list. SVG is
|
||
# explicitly excluded to avoid the XSS surface inside embedded scripts.
|
||
_IMAGE_MIME_ALLOWED: frozenset[str] = frozenset({
|
||
"image/png",
|
||
"image/jpeg",
|
||
"image/webp",
|
||
"image/gif",
|
||
})
|
||
|
||
_VIDEO_MIME_ALLOWED: frozenset[str] = frozenset({
|
||
"video/mp4",
|
||
"video/webm",
|
||
"video/quicktime",
|
||
})
|
||
|
||
_UPLOAD_MIME_ALLOWED: frozenset[str] = _IMAGE_MIME_ALLOWED | _VIDEO_MIME_ALLOWED
|
||
|
||
_DATA_URL_MIME_RE = re.compile(r"^data:([^;]+);base64,", re.DOTALL)
|
||
|
||
|
||
def _extract_data_url_mime(url: str) -> str | None:
|
||
"""Return the MIME type of a ``data:<mime>;base64,...`` URL, else ``None``."""
|
||
if not isinstance(url, str):
|
||
return None
|
||
m = _DATA_URL_MIME_RE.match(url)
|
||
if not m:
|
||
return None
|
||
return m.group(1).strip().lower() or None
|
||
|
||
|
||
_LOCALHOSTS = frozenset({"127.0.0.1", "::1", "localhost"})
|
||
|
||
# Matches the legacy chat-id pattern but allows file-system-safe stems too,
|
||
# so the API can address sessions whose keys came from non-WebSocket channels.
|
||
_API_KEY_RE = re.compile(r"^[A-Za-z0-9_:.-]{1,128}$")
|
||
|
||
|
||
def _decode_api_key(raw_key: str) -> str | None:
|
||
"""Decode a percent-encoded API path segment, then validate the result."""
|
||
key = unquote(raw_key)
|
||
if _API_KEY_RE.match(key) is None:
|
||
return None
|
||
return key
|
||
|
||
|
||
def _is_localhost(connection: Any) -> bool:
|
||
"""Return True if *connection* originated from the loopback interface."""
|
||
addr = getattr(connection, "remote_address", None)
|
||
if not addr:
|
||
return False
|
||
host = addr[0] if isinstance(addr, tuple) else addr
|
||
if not isinstance(host, str):
|
||
return False
|
||
# ``::ffff:127.0.0.1`` is loopback in IPv6-mapped form.
|
||
if host.startswith("::ffff:"):
|
||
host = host[7:]
|
||
return host in _LOCALHOSTS
|
||
|
||
|
||
def _http_response(
    body: bytes,
    *,
    status: int = 200,
    content_type: str = "text/plain; charset=utf-8",
    extra_headers: list[tuple[str, str]] | None = None,
) -> Response:
    """Build a ``Connection: close`` HTTP response around *body*.

    ``Date``, ``Connection``, ``Content-Length`` and ``Content-Type`` are always
    set; *extra_headers*, when given, are appended after them.
    """
    header_pairs = [
        ("Date", email.utils.formatdate(usegmt=True)),
        ("Connection", "close"),
        ("Content-Length", str(len(body))),
        ("Content-Type", content_type),
    ]
    if extra_headers:
        header_pairs += extra_headers
    phrase = http.HTTPStatus(status).phrase
    return Response(status, phrase, Headers(header_pairs), body)
|
||
|
||
|
||
def _http_error(status: int, message: str | None = None) -> Response:
    """Build a plain-text error response; the body defaults to the status phrase."""
    text = message if message else http.HTTPStatus(status).phrase
    return _http_response(text.encode("utf-8"), status=status)
|
||
|
||
|
||
def _bearer_token(headers: Any) -> str | None:
|
||
"""Pull a Bearer token out of standard or query-style headers."""
|
||
auth = headers.get("Authorization") or headers.get("authorization")
|
||
if auth and auth.lower().startswith("bearer "):
|
||
return auth[7:].strip() or None
|
||
return None
|
||
|
||
|
||
def _is_websocket_upgrade(request: WsRequest) -> bool:
|
||
"""Detect an actual WS upgrade; plain HTTP GETs to the same path should fall through."""
|
||
upgrade = request.headers.get("Upgrade") or request.headers.get("upgrade")
|
||
connection = request.headers.get("Connection") or request.headers.get("connection")
|
||
if not upgrade or "websocket" not in upgrade.lower():
|
||
return False
|
||
if not connection or "upgrade" not in connection.lower():
|
||
return False
|
||
return True
|
||
|
||
|
||
def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool:
|
||
"""Return True if the token-issue HTTP request carries credentials matching ``token_issue_secret``."""
|
||
if not configured_secret:
|
||
return True
|
||
authorization = headers.get("Authorization") or headers.get("authorization")
|
||
if authorization and authorization.lower().startswith("bearer "):
|
||
supplied = authorization[7:].strip()
|
||
return hmac.compare_digest(supplied, configured_secret)
|
||
header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth")
|
||
if not header_token:
|
||
return False
|
||
return hmac.compare_digest(header_token.strip(), configured_secret)
|
||
|
||
|
||
class WebSocketChannel(BaseChannel):
|
||
"""Run a local WebSocket server; forward text/JSON messages to the message bus."""
|
||
|
||
name = "websocket"
|
||
display_name = "WebSocket"
|
||
|
||
def __init__(
    self,
    config: Any,
    bus: MessageBus,
    *,
    session_manager: "SessionManager | None" = None,
    http_handler: Any | None = None,
    workspace_path: Path | None = None,
    restrict_to_workspace: bool = False,
    runtime_surface: str = "browser",
):
    """Set up channel state; no socket is opened here.

    A plain ``dict`` config is validated into :class:`WebSocketConfig` first.
    ``session_manager`` and ``workspace_path`` are accepted but not read by
    this constructor. ``runtime_surface`` is stored as ``"native"`` for
    ``"native"``/``"desktop"`` and as ``"browser"`` for anything else.
    """
    if isinstance(config, dict):
        config = WebSocketConfig.model_validate(config)
    super().__init__(config, bus)
    self.config: WebSocketConfig = config

    # Routing tables:
    #   _subs         chat_id -> connections subscribed to it (fan-out target)
    #   _conn_chats   connection -> chat_ids it joined (O(1) cleanup on disconnect)
    #   _conn_default connection -> chat_id for legacy frames that omit routing
    self._subs: dict[str, set[Any]] = {}
    self._conn_chats: dict[Any, set[str]] = {}
    self._conn_default: dict[Any, str] = {}

    # Server lifecycle handles, populated by ``start``.
    self._stop_event: asyncio.Event | None = None
    self._server_task: asyncio.Task[None] | None = None

    self._default_restrict_to_workspace = restrict_to_workspace
    if runtime_surface in {"native", "desktop"}:
        self._runtime_surface = "native"
    else:
        self._runtime_surface = "browser"

    # HTTP handler injected from outside (ChannelManager / gateway startup).
    # Owns tokens, sessions, media, settings, static serving.
    self._http = http_handler
    # Backwards-compat: workspace controller used in envelope dispatch.
    self._webui_workspaces = http_handler.workspaces if http_handler else None

    self._stream_text_buffers: dict[tuple[str, str], list[str]] = {}
|
||
|
||
# -- Subscription bookkeeping -------------------------------------------
|
||
|
||
def _attach(self, connection: Any, chat_id: str) -> None:
    """Subscribe *connection* to *chat_id*; repeating the call changes nothing."""
    subscribers = self._subs.setdefault(chat_id, set())
    subscribers.add(connection)
    joined = self._conn_chats.setdefault(connection, set())
    joined.add(chat_id)
|
||
|
||
def _cleanup_connection(self, connection: Any) -> None:
    """Drop *connection* from all routing tables; harmless when called repeatedly.

    Chat ids left without any subscriber are removed from ``_subs``.
    """
    for chat_id in self._conn_chats.pop(connection, set()):
        peers = self._subs.get(chat_id)
        if peers is not None:
            peers.discard(connection)
            if not peers:
                self._subs.pop(chat_id, None)
    self._conn_default.pop(connection, None)
|
||
|
||
async def _maybe_push_active_goal_state(self, chat_id: str) -> None:
    """Replay an active sustained goal from session metadata after *chat_id* is subscribed.

    Goal metadata lives on the session JSONL and survives gateway restarts, but
    connected clients normally see it via ``goal_state`` / ``turn_end`` frames.
    Pushing here makes refresh + reconnect restore the strip without a new model turn.

    Does nothing when no HTTP handler or session manager is available, or
    when the stored goal is not active.
    """
    # ``http_handler`` defaults to None in ``__init__``; without this guard the
    # attribute access raises and the caller's broad ``except`` drops the
    # freshly opened connection.
    manager = getattr(self._http, "session_manager", None)
    if manager is None:
        return
    row = manager.read_session_file(f"websocket:{chat_id}")
    meta = row.get("metadata", {}) if isinstance(row, dict) else {}
    if not isinstance(meta, dict):
        meta = {}
    blob = goal_state_ws_blob(meta)
    if not blob.get("active"):
        return
    await self.send_goal_state(chat_id, blob)
|
||
|
||
async def _maybe_push_turn_run_wall_clock(self, chat_id: str) -> None:
    """Send ``goal_status: running`` with the recorded start time if a turn is in progress.

    Nothing is sent when no start time is recorded for *chat_id*.
    """
    started_at = websocket_turn_wall_started_at(chat_id)
    if started_at is not None:
        await self.send_goal_status(chat_id, "running", started_at=started_at)
|
||
|
||
async def _hydrate_after_subscribe(self, chat_id: str) -> None:
    """After a subscribe, replay goal state first and then the running-turn status."""
    replays = (
        self._maybe_push_active_goal_state,
        self._maybe_push_turn_run_wall_clock,
    )
    for replay in replays:
        await replay(chat_id)
|
||
|
||
async def _send_event(self, connection: Any, event: str, **fields: Any) -> None:
    """Send one JSON control frame ``{"event": event, **fields}`` to *connection*.

    A closed connection is cleaned up; any other send failure is logged as a
    warning and not raised.
    """
    frame = json.dumps({"event": event, **fields}, ensure_ascii=False)
    try:
        await connection.send(frame)
    except ConnectionClosed:
        self._cleanup_connection(connection)
    except Exception as e:
        self.logger.warning("failed to send {} event: {}", event, e)
|
||
|
||
@classmethod
def default_config(cls) -> dict[str, Any]:
    """Return a default :class:`WebSocketConfig` dumped to a dict using field aliases."""
    defaults = WebSocketConfig()
    return defaults.model_dump(by_alias=True)
|
||
|
||
def _expected_path(self) -> str:
    """Return the configured WebSocket path in normalized form."""
    configured = self.config.path
    return _normalize_config_path(configured)
|
||
|
||
# -- Backwards-compat property aliases (used by tests) ------------------
|
||
|
||
@property
def _session_manager(self):
    """Session manager held by the HTTP handler."""
    handler = self._http
    return handler.session_manager

@_session_manager.setter
def _session_manager(self, value):
    handler = self._http
    handler.session_manager = value

@property
def _media_secret(self):
    """Media secret held by the HTTP handler."""
    handler = self._http
    return handler.media_secret

@property
def _issued_tokens(self):
    """Issued-token store held by the HTTP handler."""
    handler = self._http
    return handler.issued_tokens

@_issued_tokens.setter
def _issued_tokens(self, value):
    handler = self._http
    handler.issued_tokens = value

@property
def _api_tokens(self):
    """API-token store held by the HTTP handler."""
    handler = self._http
    return handler.api_tokens

def _check_api_token(self, request):
    """Delegate to the HTTP handler's ``check_api_token``."""
    delegate = self._http.check_api_token
    return delegate(request)

def _sign_media_path(self, path):
    """Delegate to the HTTP handler's ``sign_media_path``."""
    delegate = self._http.sign_media_path
    return delegate(path)

def _sign_or_stage_media_path(self, path):
    """Delegate to the HTTP handler's ``sign_or_stage_media_path``."""
    delegate = self._http.sign_or_stage_media_path
    return delegate(path)

def _rewrite_local_markdown_images(self, text):
    """Delegate to the HTTP handler's ``rewrite_local_markdown_images``."""
    delegate = self._http.rewrite_local_markdown_images
    return delegate(text)

def _handle_bootstrap(self, connection, request):
    """Delegate to the HTTP handler's ``_handle_bootstrap``."""
    delegate = self._http._handle_bootstrap
    return delegate(connection, request)

def _handle_sessions_list(self, request):
    """Delegate to the HTTP handler's ``_handle_sessions_list``."""
    delegate = self._http._handle_sessions_list
    return delegate(request)

def _handle_webui_thread_get(self, request, key):
    """Delegate to the HTTP handler's ``_handle_webui_thread_get``."""
    delegate = self._http._handle_webui_thread_get
    return delegate(request, key)

@property
def _workspace_path(self):
    """Workspace path held by the HTTP handler."""
    handler = self._http
    return handler.workspace_path
|
||
|
||
def _build_ssl_context(self) -> ssl.SSLContext | None:
    """Create the server TLS context from the configured cert/key pair.

    Returns ``None`` when neither file is configured (plain ``ws``). The
    context requires TLS 1.2 or newer.

    Raises:
        ValueError: if only one of ``ssl_certfile`` / ``ssl_keyfile`` is set.
    """
    certfile = self.config.ssl_certfile.strip()
    keyfile = self.config.ssl_keyfile.strip()
    if not (certfile or keyfile):
        return None
    if not (certfile and keyfile):
        raise ValueError(
            "ssl_certfile and ssl_keyfile must both be set for WSS, or both left empty"
        )
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.minimum_version = ssl.TLSVersion.TLSv1_2
    context.load_cert_chain(certfile=certfile, keyfile=keyfile)
    return context
|
||
|
||
# -- HTTP dispatch ------------------------------------------------------
|
||
|
||
async def _dispatch_http(self, connection: Any, request: WsRequest) -> Any:
    """Route an inbound HTTP request to the HTTP handler or WS upgrade.

    A WebSocket upgrade on the configured path is handled here: the
    ``client_id`` query value is checked with ``is_allowed`` (403 on
    refusal), then the token check decides. Every other request is passed to
    the injected HTTP handler.
    """
    got, query = _parse_request_path(request.path)

    # WebSocket upgrade — channel handles this itself
    if got == self._expected_path() and _is_websocket_upgrade(request):
        # Strip, then cap at 128 characters, exactly as ``_connection_loop``
        # does, so the allow-list check sees the identity later used as sender.
        client_id = (_query_first(query, "client_id") or "").strip()[:128]
        if not self.is_allowed(client_id):
            return connection.respond(403, "Forbidden")
        return self._authorize_websocket_handshake(connection, query)

    # Everything else goes to the HTTP handler
    return await self._http.dispatch(connection, request)
|
||
|
||
def _authorize_websocket_handshake(self, connection: Any, query: dict[str, list[str]]) -> Any:
    """Check the ``token`` query parameter of a WebSocket handshake.

    Returns ``None`` to let the handshake proceed, or a 401 response.

    - With a static ``token`` configured: the supplied token must equal it or
      be a valid issued token.
    - Otherwise, if ``websocket_requires_token`` is set: the supplied token
      must be a valid issued token.
    - Otherwise the handshake is accepted; a supplied token is still offered
      to the issued-token check and its result ignored.
    """
    supplied = _query_first(query, "token")
    static_token = self.config.token.strip()

    if static_token:
        # Compare UTF-8 bytes: ``hmac.compare_digest`` raises TypeError for
        # non-ASCII ``str``, which a client can trigger with ``?token=%C3%A9``.
        if supplied and hmac.compare_digest(
            supplied.encode("utf-8", "surrogatepass"),
            static_token.encode("utf-8", "surrogatepass"),
        ):
            return None
        if supplied and self._http.take_issued_token_if_valid(supplied):
            return None
        return connection.respond(401, "Unauthorized")

    if self.config.websocket_requires_token:
        if supplied and self._http.take_issued_token_if_valid(supplied):
            return None
        return connection.respond(401, "Unauthorized")

    if supplied:
        self._http.take_issued_token_if_valid(supplied)
    return None
|
||
|
||
# -- Server lifecycle and connection ingress ---------------------------
|
||
|
||
async def start(self) -> None:
    """Start the WebSocket server and wait until it has been stopped.

    Listens on ``unix_socket_path`` when configured, otherwise on
    ``host``/``port`` (with TLS if a cert/key pair is configured). The server
    runs in a task that waits on ``self._stop_event``; once the event is set
    the server is closed and a unix socket file, if any, is removed.
    """
    from nanobot.utils.logging_bridge import redirect_lib_logging

    redirect_lib_logging("websockets", level="WARNING")
    # NOTE(review): ``websockets_server_logger`` is not among the imports
    # visible at the top of this file — presumably defined elsewhere in the
    # module; confirm.
    ws_logger = websockets_server_logger()

    self._running = True
    self._stop_event = asyncio.Event()

    ssl_context = self._build_ssl_context()
    # ``scheme`` is only used for the log lines below.
    scheme = "wss" if ssl_context else "ws"

    async def process_request(
        connection: ServerConnection,
        request: WsRequest,
    ) -> Any:
        # Runs for every HTTP request, before any WebSocket upgrade.
        return await self._dispatch_http(connection, request)

    async def handler(connection: ServerConnection) -> None:
        # Runs once per accepted WebSocket connection.
        await self._connection_loop(connection)

    self.logger.info(
        "WebSocket server listening on {}",
        (
            f"unix:{self.config.unix_socket_path}{self.config.path}"
            if self.config.unix_socket_path
            else f"{scheme}://{self.config.host}:{self.config.port}{self.config.path}"
        ),
    )
    if self.config.token_issue_path:
        self.logger.info(
            "WebSocket token issue route: {}",
            (
                f"unix:{self.config.unix_socket_path}{_normalize_config_path(self.config.token_issue_path)}"
                if self.config.unix_socket_path
                else (
                    f"{scheme}://{self.config.host}:{self.config.port}"
                    f"{_normalize_config_path(self.config.token_issue_path)}"
                )
            ),
        )

    async def runner() -> None:
        socket_path = self.config.unix_socket_path
        if socket_path:
            path_obj = Path(socket_path)
            path_obj.parent.mkdir(parents=True, exist_ok=True)
            # Remove a leftover socket file so binding does not fail.
            with suppress(FileNotFoundError):
                path_obj.unlink()
            server = await unix_serve(
                handler,
                socket_path,
                process_request=process_request,
                max_size=self.config.max_message_bytes,
                ping_interval=self.config.ping_interval_s,
                ping_timeout=self.config.ping_timeout_s,
                logger=ws_logger,
            )
            # Best effort: restrict the socket to its owner.
            with suppress(OSError):
                path_obj.chmod(0o600)
        else:
            server = await serve(
                handler,
                self.config.host,
                self.config.port,
                process_request=process_request,
                max_size=self.config.max_message_bytes,
                ping_interval=self.config.ping_interval_s,
                ping_timeout=self.config.ping_timeout_s,
                ssl=ssl_context,
                logger=ws_logger,
            )
        try:
            assert self._stop_event is not None
            # Park here until the stop event is set.
            await self._stop_event.wait()
        finally:
            server.close()
            await server.wait_closed()
            if socket_path:
                with suppress(FileNotFoundError):
                    Path(socket_path).unlink()

    self._server_task = asyncio.create_task(runner())
    await self._server_task
|
||
|
||
async def _connection_loop(self, connection: Any) -> None:
    """Serve one accepted WebSocket connection until it ends.

    Sends a ``ready`` event carrying a fresh default ``chat_id`` and the
    client id, subscribes the connection to that chat, then reads frames:
    typed JSON envelopes go to ``_dispatch_envelope``; everything else is
    parsed as a legacy text frame and forwarded via ``_handle_message``.
    Any exception ends the loop (logged at debug level) and the connection
    is always removed from the routing tables.
    """
    request = connection.request
    path_part = request.path if request else "/"
    _, query = _parse_request_path(path_part)
    client_id_raw = _query_first(query, "client_id")
    client_id = client_id_raw.strip() if client_id_raw else ""
    if not client_id:
        # No usable id supplied: generate an anonymous one.
        client_id = f"anon-{uuid.uuid4().hex[:12]}"
    elif len(client_id) > 128:
        self.logger.warning("client_id too long ({} chars), truncating", len(client_id))
        client_id = client_id[:128]

    default_chat_id = str(uuid.uuid4())

    try:
        await connection.send(
            json.dumps(
                {
                    "event": "ready",
                    "chat_id": default_chat_id,
                    "client_id": client_id,
                },
                ensure_ascii=False,
            )
        )
        # Register only after ready is successfully sent to avoid out-of-order sends
        self._conn_default[connection] = default_chat_id
        self._attach(connection, default_chat_id)
        await self._hydrate_after_subscribe(default_chat_id)

        async for raw in connection:
            if isinstance(raw, bytes):
                # Binary frames are accepted only when they are valid UTF-8 text.
                try:
                    raw = raw.decode("utf-8")
                except UnicodeDecodeError:
                    self.logger.warning("ignoring non-utf8 binary frame")
                    continue

            envelope = _parse_envelope(raw)
            if envelope is not None:
                await self._dispatch_envelope(connection, client_id, envelope)
                continue

            # Legacy frame: plain text or JSON without a ``type`` field.
            content = _parse_inbound_payload(raw)
            if content is None:
                continue
            # WebSocket already authenticates at handshake time (token),
            # so pairing is not applicable. Treat as non-DM to avoid
            # sending pairing codes to an already-authenticated client.
            await self._handle_message(
                sender_id=client_id,
                chat_id=default_chat_id,
                content=content,
                metadata={"remote": getattr(connection, "remote_address", None)},
                is_dm=False,
            )
    except Exception as e:
        self.logger.debug("connection ended: {}", e)
    finally:
        self._cleanup_connection(connection)
|
||
|
||
# -- Inbound WebSocket envelopes ---------------------------------------
|
||
|
||
def _save_envelope_media(
    self,
    media: list[Any],
) -> tuple[list[str], str | None]:
    """Decode the ``media`` items of a ``message`` envelope and write them to disk.

    Expected item shape: ``{"data_url": str, "name"?: str | None}``.

    Returns ``(paths, None)`` when every item was saved, or ``([], reason)``
    at the first problem. ``reason`` is one of ``too_many_images``,
    ``too_many_videos``, ``malformed``, ``decode``, ``mime`` or ``size``.
    The caller is expected to report ``reason`` to the client and not publish
    the message. Files written earlier in the same call are unlinked before
    a failure is returned, so a rejected message leaves no orphans behind.
    """
    # Pass 1: count by declared MIME type so an oversized batch is rejected
    # before anything is written.
    declared = [
        _extract_data_url_mime(entry.get("data_url", "")) if isinstance(entry, dict) else None
        for entry in media
    ]
    n_videos = sum(1 for mime in declared if mime in _VIDEO_MIME_ALLOWED)
    n_images = sum(
        1
        for mime in declared
        if mime not in _VIDEO_MIME_ALLOWED and mime in _IMAGE_MIME_ALLOWED
    )
    if n_images > _MAX_IMAGES_PER_MESSAGE:
        return [], "too_many_images"
    if n_videos > _MAX_VIDEOS_PER_MESSAGE:
        return [], "too_many_videos"

    target_dir = get_media_dir("websocket")
    saved_paths: list[str] = []

    def _fail(reason: str) -> tuple[list[str], str]:
        # Roll back whatever this call already wrote.
        for written in saved_paths:
            try:
                Path(written).unlink(missing_ok=True)
            except OSError as exc:
                self.logger.warning(
                    "failed to unlink partial media {}: {}", written, exc
                )
        return [], reason

    # Pass 2: validate and persist each item in order.
    for entry in media:
        if not isinstance(entry, dict):
            return _fail("malformed")
        data_url = entry.get("data_url")
        if not (isinstance(data_url, str) and data_url):
            return _fail("malformed")
        mime = _extract_data_url_mime(data_url)
        if mime is None:
            return _fail("decode")
        if mime not in _UPLOAD_MIME_ALLOWED:
            return _fail("mime")
        limit = _MAX_VIDEO_BYTES if mime in _VIDEO_MIME_ALLOWED else _MAX_IMAGE_BYTES
        try:
            stored = save_base64_data_url(
                data_url, target_dir, max_bytes=limit,
            )
        except FileSizeExceeded:
            return _fail("size")
        except Exception as exc:
            self.logger.warning("media decode failed: {}", exc)
            return _fail("decode")
        if stored is None:
            return _fail("decode")
        saved_paths.append(stored)
    return saved_paths, None
|
||
|
||
async def _dispatch_envelope(
|
||
self,
|
||
connection: Any,
|
||
client_id: str,
|
||
envelope: dict[str, Any],
|
||
) -> None:
|
||
"""Route one typed inbound envelope (``new_chat`` / ``attach`` / ``message``)."""
|
||
t = envelope.get("type")
|
||
if t == "new_chat":
|
||
new_id = str(uuid.uuid4())
|
||
scope = await self._workspace_scope_or_error(
|
||
connection,
|
||
lambda: self._webui_workspaces.scope_for_new_chat(
|
||
envelope,
|
||
controls_available=_is_localhost(connection),
|
||
),
|
||
)
|
||
if scope is None:
|
||
return
|
||
self._webui_workspaces.persist_scope(new_id, scope)
|
||
self._attach(connection, new_id)
|
||
await self._send_event(connection, "attached", chat_id=new_id)
|
||
await self._send_event(
|
||
connection,
|
||
"session_updated",
|
||
chat_id=new_id,
|
||
scope="metadata",
|
||
workspace_scope=scope.payload(),
|
||
)
|
||
await self._hydrate_after_subscribe(new_id)
|
||
return
|
||
if t == "attach":
|
||
cid = envelope.get("chat_id")
|
||
if not _is_valid_chat_id(cid):
|
||
await self._send_event(connection, "error", detail="invalid chat_id")
|
||
return
|
||
self._attach(connection, cid)
|
||
await self._send_event(connection, "attached", chat_id=cid)
|
||
await self._hydrate_after_subscribe(cid)
|
||
return
|
||
if t == "set_workspace_scope":
|
||
cid = envelope.get("chat_id")
|
||
if not _is_valid_chat_id(cid):
|
||
await self._send_event(connection, "error", detail="invalid chat_id")
|
||
return
|
||
scope = await self._workspace_scope_or_error(
|
||
connection,
|
||
lambda: self._webui_workspaces.scope_for_set_request(
|
||
envelope,
|
||
chat_id=cid,
|
||
chat_running=websocket_turn_wall_started_at(cid) is not None,
|
||
controls_available=_is_localhost(connection),
|
||
),
|
||
chat_id=cid,
|
||
)
|
||
if scope is None:
|
||
return
|
||
self._webui_workspaces.persist_scope(cid, scope)
|
||
await self._send_event(
|
||
connection,
|
||
"session_updated",
|
||
chat_id=cid,
|
||
scope="metadata",
|
||
workspace_scope=scope.payload(),
|
||
)
|
||
return
|
||
if t == "message":
|
||
cid = envelope.get("chat_id")
|
||
content = envelope.get("content")
|
||
if not _is_valid_chat_id(cid):
|
||
await self._send_event(connection, "error", detail="invalid chat_id")
|
||
return
|
||
if not isinstance(content, str):
|
||
await self._send_event(connection, "error", detail="missing content")
|
||
return
|
||
|
||
raw_media = envelope.get("media")
|
||
media_paths: list[str] = []
|
||
if raw_media is not None:
|
||
if not isinstance(raw_media, list):
|
||
await self._send_event(
|
||
connection, "error",
|
||
detail="image_rejected", reason="malformed",
|
||
)
|
||
return
|
||
media_paths, reason = self._save_envelope_media(raw_media)
|
||
if reason is not None:
|
||
await self._send_event(
|
||
connection, "error",
|
||
detail="image_rejected", reason=reason,
|
||
)
|
||
return
|
||
|
||
# Allow image-only turns (content may be empty when media is attached).
|
||
if not content.strip() and not media_paths:
|
||
await self._send_event(connection, "error", detail="missing content")
|
||
return
|
||
scope = await self._workspace_scope_or_error(
|
||
connection,
|
||
lambda: self._webui_workspaces.scope_for_message(
|
||
envelope,
|
||
chat_id=cid,
|
||
chat_running=websocket_turn_wall_started_at(cid) is not None,
|
||
controls_available=_is_localhost(connection),
|
||
),
|
||
chat_id=cid,
|
||
)
|
||
if scope is None:
|
||
return
|
||
|
||
# Auto-attach on first use so clients can one-shot without a separate attach.
|
||
self._attach(connection, cid)
|
||
await self._hydrate_after_subscribe(cid)
|
||
metadata: dict[str, Any] = {"remote": getattr(connection, "remote_address", None)}
|
||
if envelope.get("webui") is True:
|
||
metadata["webui"] = True
|
||
cli_apps = normalize_cli_app_mentions(envelope.get("cli_apps"))
|
||
if cli_apps:
|
||
metadata["cli_apps"] = cli_apps
|
||
mcp_presets = normalize_mcp_preset_mentions(envelope.get("mcp_presets"))
|
||
if mcp_presets:
|
||
metadata["mcp_presets"] = mcp_presets
|
||
metadata[WORKSPACE_SCOPE_METADATA_KEY] = scope.metadata()
|
||
self._webui_workspaces.persist_scope(cid, scope)
|
||
image_generation = envelope.get("image_generation")
|
||
if isinstance(image_generation, dict) and image_generation.get("enabled") is True:
|
||
aspect_ratio = image_generation.get("aspect_ratio")
|
||
metadata["image_generation"] = {
|
||
"enabled": True,
|
||
"aspect_ratio": aspect_ratio if isinstance(aspect_ratio, str) else None,
|
||
}
|
||
await self._handle_message(
|
||
sender_id=client_id,
|
||
chat_id=cid,
|
||
content=content,
|
||
media=media_paths or None,
|
||
metadata=metadata,
|
||
is_dm=False,
|
||
)
|
||
return
|
||
await self._send_event(connection, "error", detail=f"unknown type: {t!r}")
|
||
|
||
async def _workspace_scope_or_error(
|
||
self,
|
||
connection: Any,
|
||
resolver: Callable[[], Any],
|
||
*,
|
||
chat_id: str | None = None,
|
||
) -> Any | None:
|
||
try:
|
||
return resolver()
|
||
except WorkspaceScopeError as exc:
|
||
await self._send_event(
|
||
connection,
|
||
"error",
|
||
detail="workspace_scope_rejected",
|
||
reason=exc.message,
|
||
**({"chat_id": chat_id} if chat_id else {}),
|
||
)
|
||
return None
|
||
|
||
# -- Outbound WebSocket events -----------------------------------------
|
||
|
||
async def stop(self) -> None:
    """Stop the channel and forget all per-connection state.

    Does nothing when the channel is not running. Otherwise it sets the stop
    event (if any), awaits the server task (logging, not raising, any error
    it finishes with), and clears the subscription maps and the HTTP token
    stores.
    """
    if not self._running:
        return
    self._running = False

    if self._stop_event:
        self._stop_event.set()

    if self._server_task:
        try:
            await self._server_task
        except Exception as exc:
            self.logger.warning("server task error during shutdown: {}", exc)
        self._server_task = None

    for registry in (
        self._subs,
        self._conn_chats,
        self._conn_default,
        self._http.issued_tokens,
        self._http.api_tokens,
    ):
        registry.clear()
|
||
|
||
async def _safe_send_to(self, connection: Any, raw: str, *, label: str = "") -> None:
|
||
"""Send a raw frame to one connection, cleaning up on ConnectionClosed."""
|
||
try:
|
||
await connection.send(raw)
|
||
except ConnectionClosed:
|
||
self._cleanup_connection(connection)
|
||
self.logger.warning("connection gone{}", label)
|
||
except Exception:
|
||
self.logger.exception("send failed{}", label)
|
||
raise
|
||
|
||
def _try_append_webui_transcript(self, chat_id: str, wire: dict[str, Any]) -> None:
|
||
sk = f"websocket:{chat_id}"
|
||
try:
|
||
dup = json.loads(json.dumps(wire, ensure_ascii=False))
|
||
append_transcript_object(sk, dup)
|
||
except (ValueError, TypeError) as e:
|
||
self.logger.warning("webui transcript append failed: {}", e)
|
||
|
||
async def send(self, msg: OutboundMessage) -> None:
    """Deliver one outbound bus message to the subscribers of its chat.

    Control messages flagged in ``msg.metadata`` (runtime model update, goal
    state / status, turn end, session update, file edits) are routed to the
    dedicated ``send_*`` helpers. Everything else is sent as a ``message``
    event and also appended to the WebUI transcript.
    """
    meta = msg.metadata
    if meta.get("_runtime_model_updated"):
        await self.send_runtime_model_updated(
            model_name=meta.get("model"),
            model_preset=meta.get("model_preset"),
        )
        return

    chat_id = msg.chat_id
    # Snapshot the subscriber set so ConnectionClosed cleanups mid-iteration are safe.
    recipients = list(self._subs.get(chat_id, ()))
    if not recipients:
        # These message kinds are logged at debug level; others warn.
        quiet_flags = (
            "_progress",
            "_file_edit_events",
            "_turn_end",
            "_session_updated",
            "_goal_status",
            "_goal_state_sync",
        )
        if any(meta.get(flag) for flag in quiet_flags):
            self.logger.debug("no active subscribers for chat_id={}", chat_id)
        else:
            self.logger.warning("no active subscribers for chat_id={}", chat_id)
        return

    if meta.get("_goal_state_sync"):
        snapshot = meta.get("goal_state")
        if not isinstance(snapshot, dict):
            snapshot = {"active": False}
        await self.send_goal_state(chat_id, snapshot)
        return

    if meta.get("_goal_status"):
        status = meta.get("goal_status")
        if status in ("running", "idle"):
            started_raw = meta.get("started_at", meta.get("goal_started_at"))
            started_at = float(started_raw) if isinstance(started_raw, (int, float)) else None
            await self.send_goal_status(chat_id, status, started_at=started_at)
        return

    # Signal that the agent has fully finished processing the current turn.
    if meta.get("_turn_end"):
        latency = meta.get("latency_ms")
        goal_state = meta.get("goal_state")
        await self.send_turn_end(
            chat_id,
            latency_ms=int(latency) if isinstance(latency, (int, float)) else None,
            goal_state=goal_state if isinstance(goal_state, dict) else None,
        )
        return

    if meta.get("_session_updated"):
        update_scope = meta.get("_session_update_scope")
        await self.send_session_updated(
            chat_id,
            scope=update_scope if isinstance(update_scope, str) else None,
        )
        return

    if meta.get("_file_edit_events"):
        edits = meta.get("_file_edit_events")
        await self.send_file_edit_events(
            chat_id,
            edits if isinstance(edits, list) else [],
            meta,
        )
        return

    original_text = msg.content
    payload: dict[str, Any] = {
        "event": "message",
        "chat_id": chat_id,
        "text": self._http.rewrite_local_markdown_images(original_text),
    }
    if msg.media:
        payload["media"] = msg.media
        signed_entries = (
            self._http.sign_or_stage_media_path(Path(entry)) for entry in msg.media
        )
        urls: list[dict[str, str]] = [
            signed for signed in signed_entries if signed is not None
        ]
        if urls:
            payload["media_urls"] = urls
    if msg.reply_to:
        payload["reply_to"] = msg.reply_to
    latency = meta.get("latency_ms")
    if isinstance(latency, (int, float)):
        payload["latency_ms"] = int(latency)
    if meta.get("_tool_events"):
        payload["tool_events"] = meta["_tool_events"]
    agent_ui = meta.get(OUTBOUND_META_AGENT_UI)
    if agent_ui is not None:
        payload["agent_ui"] = agent_ui
    # Mark intermediate agent breadcrumbs (tool-call hints, generic
    # progress strings) so WS clients can render them as subordinate
    # trace rows rather than conversational replies.
    if meta.get("_tool_hint"):
        payload["kind"] = "tool_hint"
    elif meta.get("_progress"):
        payload["kind"] = "progress"

    # The transcript gets the original text; the wire gets the rewritten one.
    self._try_append_webui_transcript(chat_id, {**payload, "text": original_text})
    raw = json.dumps(payload, ensure_ascii=False)
    for peer in recipients:
        await self._safe_send_to(peer, raw, label=" ")
|
||
|
||
async def send_reasoning_delta(
|
||
self,
|
||
chat_id: str,
|
||
delta: str,
|
||
metadata: dict[str, Any] | None = None,
|
||
) -> None:
|
||
"""Push one chunk of model reasoning. Mirrors ``send_delta`` shape so
|
||
clients receive a stream that opens, updates in place, and closes —
|
||
rendered above the active assistant bubble with a shimmer header
|
||
until the matching ``reasoning_end`` arrives.
|
||
"""
|
||
conns = list(self._subs.get(chat_id, ()))
|
||
if not conns or not delta:
|
||
return
|
||
meta = metadata or {}
|
||
body: dict[str, Any] = {
|
||
"event": "reasoning_delta",
|
||
"chat_id": chat_id,
|
||
"text": delta,
|
||
}
|
||
stream_id = meta.get("_stream_id")
|
||
if stream_id is not None:
|
||
body["stream_id"] = stream_id
|
||
self._try_append_webui_transcript(chat_id, body)
|
||
raw = json.dumps(body, ensure_ascii=False)
|
||
for connection in conns:
|
||
await self._safe_send_to(connection, raw, label=" reasoning ")
|
||
|
||
async def send_reasoning_end(
|
||
self,
|
||
chat_id: str,
|
||
metadata: dict[str, Any] | None = None,
|
||
) -> None:
|
||
"""Close the current reasoning stream segment for in-place renderers."""
|
||
conns = list(self._subs.get(chat_id, ()))
|
||
if not conns:
|
||
return
|
||
meta = metadata or {}
|
||
body: dict[str, Any] = {
|
||
"event": "reasoning_end",
|
||
"chat_id": chat_id,
|
||
}
|
||
stream_id = meta.get("_stream_id")
|
||
if stream_id is not None:
|
||
body["stream_id"] = stream_id
|
||
self._try_append_webui_transcript(chat_id, body)
|
||
raw = json.dumps(body, ensure_ascii=False)
|
||
for connection in conns:
|
||
await self._safe_send_to(connection, raw, label=" reasoning_end ")
|
||
|
||
async def send_file_edit_events(
|
||
self,
|
||
chat_id: str,
|
||
edits: list[dict[str, Any]],
|
||
metadata: dict[str, Any] | None = None,
|
||
) -> None:
|
||
conns = list(self._subs.get(chat_id, ()))
|
||
if not conns:
|
||
return
|
||
payload: dict[str, Any] = {
|
||
"event": "file_edit",
|
||
"chat_id": chat_id,
|
||
"edits": edits,
|
||
}
|
||
self._try_append_webui_transcript(chat_id, payload)
|
||
raw = json.dumps(payload, ensure_ascii=False)
|
||
for connection in conns:
|
||
await self._safe_send_to(connection, raw, label=" file_edit ")
|
||
|
||
async def send_delta(
|
||
self,
|
||
chat_id: str,
|
||
delta: str,
|
||
metadata: dict[str, Any] | None = None,
|
||
) -> None:
|
||
conns = list(self._subs.get(chat_id, ()))
|
||
if not conns:
|
||
return
|
||
meta = metadata or {}
|
||
stream_key = (chat_id, str(meta.get("_stream_id") or ""))
|
||
if meta.get("_stream_end"):
|
||
body: dict[str, Any] = {"event": "stream_end", "chat_id": chat_id}
|
||
buffered = self._stream_text_buffers.pop(stream_key, [])
|
||
if delta:
|
||
buffered.append(delta)
|
||
full_text = "".join(buffered)
|
||
rewritten = self._http.rewrite_local_markdown_images(full_text)
|
||
if rewritten != full_text:
|
||
body["text"] = rewritten
|
||
else:
|
||
body = {
|
||
"event": "delta",
|
||
"chat_id": chat_id,
|
||
"text": delta,
|
||
}
|
||
self._stream_text_buffers.setdefault(stream_key, []).append(delta)
|
||
if meta.get("_stream_id") is not None:
|
||
body["stream_id"] = meta["_stream_id"]
|
||
self._try_append_webui_transcript(chat_id, body)
|
||
raw = json.dumps(body, ensure_ascii=False)
|
||
for connection in conns:
|
||
await self._safe_send_to(connection, raw, label=" stream ")
|
||
|
||
async def send_turn_end(
|
||
self,
|
||
chat_id: str,
|
||
latency_ms: int | None = None,
|
||
*,
|
||
goal_state: dict[str, Any] | None = None,
|
||
) -> None:
|
||
"""Signal that the agent has fully finished processing the current turn."""
|
||
conns = list(self._subs.get(chat_id, ()))
|
||
if not conns:
|
||
return
|
||
body: dict[str, Any] = {"event": "turn_end", "chat_id": chat_id}
|
||
if latency_ms is not None:
|
||
body["latency_ms"] = int(latency_ms)
|
||
if goal_state is not None:
|
||
body["goal_state"] = goal_state
|
||
self._try_append_webui_transcript(chat_id, body)
|
||
raw = json.dumps(body, ensure_ascii=False)
|
||
for connection in conns:
|
||
await self._safe_send_to(connection, raw, label=" turn_end ")
|
||
|
||
async def send_goal_state(self, chat_id: str, blob: dict[str, Any]) -> None:
    """Send a ``goal_state`` event carrying *blob* to the subscribers of *chat_id*.

    Nothing happens when the chat has no subscribers. This event is not
    written to the WebUI transcript.
    """
    recipients = list(self._subs.get(chat_id, ()))
    if not recipients:
        return
    encoded = json.dumps(
        {"event": "goal_state", "chat_id": chat_id, "goal_state": blob},
        ensure_ascii=False,
    )
    for peer in recipients:
        await self._safe_send_to(peer, encoded, label=" goal_state ")
|
||
|
||
async def send_goal_status(
|
||
self,
|
||
chat_id: str,
|
||
status: str,
|
||
*,
|
||
started_at: float | None = None,
|
||
) -> None:
|
||
"""Notify subscribed clients that a turn started or finished (wall-clock hint)."""
|
||
conns = list(self._subs.get(chat_id, ()))
|
||
if not conns:
|
||
return
|
||
body: dict[str, Any] = {
|
||
"event": "goal_status",
|
||
"chat_id": chat_id,
|
||
"status": status,
|
||
}
|
||
if status == "running" and started_at is not None:
|
||
body["started_at"] = started_at
|
||
raw = json.dumps(body, ensure_ascii=False)
|
||
for connection in conns:
|
||
await self._safe_send_to(connection, raw, label=" goal_status ")
|
||
|
||
async def send_session_updated(self, chat_id: str, *, scope: str | None = None) -> None:
|
||
"""Notify clients that session metadata changed outside the main turn."""
|
||
conns = list(self._subs.get(chat_id, ()))
|
||
if not conns:
|
||
return
|
||
body: dict[str, Any] = {"event": "session_updated", "chat_id": chat_id}
|
||
if scope:
|
||
body["scope"] = scope
|
||
raw = json.dumps(body, ensure_ascii=False)
|
||
for connection in conns:
|
||
await self._safe_send_to(connection, raw, label=" session_updated ")
|
||
|
||
async def send_runtime_model_updated(
    self,
    *,
    model_name: Any,
    model_preset: Any = None,
) -> None:
    """Broadcast a ``runtime_model_updated`` event to every tracked connection.

    The event is skipped when there are no connections or *model_name* is
    not a non-blank string. ``model_preset`` is included only when it is a
    non-blank string. Both values are sent stripped of surrounding
    whitespace.
    """
    recipients = list(self._conn_chats)
    name = model_name.strip() if isinstance(model_name, str) else ""
    if not recipients or not name:
        return
    frame: dict[str, Any] = {"event": "runtime_model_updated", "model_name": name}
    preset = model_preset.strip() if isinstance(model_preset, str) else ""
    if preset:
        frame["model_preset"] = preset
    encoded = json.dumps(frame, ensure_ascii=False)
    for peer in recipients:
        await self._safe_send_to(peer, encoded, label=" runtime_model_updated ")
|