refactor(webui): move media replay helpers out of websocket channel

This commit is contained in:
Xubin Ren 2026-06-02 16:13:49 +08:00
parent 7aa5e620be
commit f382133bb4
5 changed files with 155 additions and 125 deletions

View File

@ -7,7 +7,6 @@ import email.utils
import hmac
import http
import json
import logging
import mimetypes
import re
import secrets
@ -16,6 +15,7 @@ import time
import uuid
from collections.abc import Callable
from contextlib import suppress
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Self
from urllib.parse import parse_qs, unquote, urlparse
@ -48,9 +48,11 @@ from nanobot.utils.subagent_channel_display import scrub_subagent_messages_for_c
from nanobot.webui.cli_apps_api import normalize_cli_app_mentions
from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions
from nanobot.webui.media_api import (
attach_signed_media_urls,
serve_signed_media,
sign_media_path,
sign_or_stage_media_path,
signed_media_attachments,
)
from nanobot.webui.settings_api import runtime_capabilities
from nanobot.webui.settings_routes import WebUISettingsRouter
@ -64,6 +66,7 @@ from nanobot.webui.transcript import (
build_webui_thread_response,
rewrite_local_markdown_images,
)
from nanobot.webui.websocket_logging import websockets_server_logger
from nanobot.webui.workspaces import (
WebUIWorkspaceController,
)
@ -117,45 +120,6 @@ def _host_for_url(host: str, port: int) -> str:
return f"{host}:{port}"
_OPENING_HANDSHAKE_FAILED_MESSAGE = "opening handshake failed"
def _exception_chain_has_disconnect(exc: BaseException | None) -> bool:
seen: set[int] = set()
while exc is not None:
ident = id(exc)
if ident in seen:
return False
seen.add(ident)
if isinstance(exc, (
BrokenPipeError,
ConnectionAbortedError,
ConnectionResetError,
ConnectionClosed,
)):
return True
exc = exc.__cause__ or exc.__context__
return False
class _WebSocketHandshakeNoiseFilter(logging.Filter):
"""Suppress noisy restart-time handshakes where the client already disconnected."""
def filter(self, record: logging.LogRecord) -> bool:
if record.getMessage() != _OPENING_HANDSHAKE_FAILED_MESSAGE:
return True
exc_info = record.exc_info
exc = exc_info[1] if isinstance(exc_info, tuple) and len(exc_info) >= 2 else None
return not _exception_chain_has_disconnect(exc)
def _websockets_server_logger() -> logging.Logger:
ws_logger = logging.getLogger("websockets.server")
if not any(isinstance(f, _WebSocketHandshakeNoiseFilter) for f in ws_logger.filters):
ws_logger.addFilter(_WebSocketHandshakeNoiseFilter())
return ws_logger
class WebSocketConfig(Base):
"""WebSocket server channel configuration.
@ -1016,7 +980,7 @@ class WebSocketChannel(BaseChannel):
# client can render previews. The raw on-disk ``media`` paths are
# stripped on the way out — they leak server filesystem layout and
# the client never needs them once it has the signed fetch URL.
self._augment_media_urls(data)
attach_signed_media_urls(data, sign_path=self._sign_media_path)
return _http_json_response(data)
def _handle_webui_thread_get(self, request: WsRequest, key: str) -> Response:
@ -1028,10 +992,14 @@ class WebSocketChannel(BaseChannel):
if not self._is_websocket_channel_session_key(decoded_key):
return _http_error(404, "session not found")
scope = self._webui_workspaces.scope_for_session_key(decoded_key)
augment_media = partial(
signed_media_attachments,
sign_path=self._sign_or_stage_media_path,
)
data = build_webui_thread_response(
decoded_key,
augment_user_media=self._augment_transcript_media,
augment_assistant_media=self._augment_transcript_media,
augment_user_media=augment_media,
augment_assistant_media=augment_media,
augment_assistant_text=lambda text: rewrite_local_markdown_images(
text,
workspace_path=scope.project_path,
@ -1051,25 +1019,6 @@ class WebSocketChannel(BaseChannel):
except (ValueError, TypeError) as e:
self.logger.warning("webui transcript append failed: {}", e)
def _augment_transcript_media(self, paths: list[str]) -> list[dict[str, Any]]:
out: list[dict[str, Any]] = []
for pstr in paths:
path = Path(pstr)
att = self._sign_or_stage_media_path(path)
if att is None:
continue
mime, _ = mimetypes.guess_type(path.name)
if mime and mime.startswith("video/"):
kind = "video"
elif mime and mime.startswith("image/"):
kind = "image"
else:
kind = "file"
out.append(
{"kind": kind, "url": att["url"], "name": att.get("name", path.name)},
)
return out
async def _handle_message(
self,
sender_id: str,
@ -1106,37 +1055,6 @@ class WebSocketChannel(BaseChannel):
is_dm,
)
def _augment_media_urls(self, payload: dict[str, Any]) -> None:
"""Mutate *payload* in place: each message's ``media`` path list is
replaced by a parallel ``media_urls`` list of signed fetch URLs.
Messages without media or with non-string path entries are left
untouched. Paths that no longer live inside ``media_dir`` (e.g. the
file was deleted, or the dir was relocated) are silently skipped;
the client falls back to the historical-replay placeholder tile.
"""
messages = payload.get("messages")
if not isinstance(messages, list):
return
for msg in messages:
if not isinstance(msg, dict):
continue
media = msg.get("media")
if not isinstance(media, list) or not media:
continue
urls: list[dict[str, str]] = []
for entry in media:
if not isinstance(entry, str) or not entry:
continue
signed = self._sign_media_path(Path(entry))
if signed is None:
continue
urls.append({"url": signed, "name": Path(entry).name})
if urls:
msg["media_urls"] = urls
# Always drop the raw paths from the wire payload.
msg.pop("media", None)
def _sign_media_path(self, abs_path: Path) -> str | None:
"""Return a ``/api/media/<sig>/<payload>`` URL for *abs_path*, or
``None`` when the path does not resolve inside the media root.
@ -1279,7 +1197,7 @@ class WebSocketChannel(BaseChannel):
from nanobot.utils.logging_bridge import redirect_lib_logging
redirect_lib_logging("websockets", level="WARNING")
ws_logger = _websockets_server_logger()
ws_logger = websockets_server_logger()
self._running = True
self._stop_event = asyncio.Event()

View File

@ -24,6 +24,8 @@ from nanobot.config.paths import get_media_dir
from nanobot.utils.helpers import safe_filename
MediaDirProvider = Callable[[str | None], Path]
SignedMediaPath = Callable[[Path], dict[str, str] | None]
SignedMediaUrl = Callable[[Path], str | None]
def b64url_encode(data: bytes) -> str:
@ -172,6 +174,64 @@ def sign_or_stage_media_path(
return {"url": signed, "name": path.name}
def media_attachment_kind(name: str) -> str:
"""Infer the WebUI media attachment kind from a filename."""
mime, _ = mimetypes.guess_type(name)
if mime and mime.startswith("video/"):
return "video"
if mime and mime.startswith("image/"):
return "image"
return "file"
def signed_media_attachments(
paths: list[str],
*,
sign_path: SignedMediaPath,
) -> list[dict[str, Any]]:
"""Map persisted media paths to WebUI attachment dicts with fresh signed URLs."""
out: list[dict[str, Any]] = []
for pstr in paths:
path = Path(pstr)
att = sign_path(path)
if att is None:
continue
url = att.get("url")
if not url:
continue
name = att.get("name") or path.name
out.append({"kind": media_attachment_kind(name), "url": url, "name": name})
return out
def attach_signed_media_urls(
payload: dict[str, Any],
*,
sign_path: SignedMediaUrl,
) -> None:
"""Replace raw media path lists in a WebUI session payload with signed URLs."""
messages = payload.get("messages")
if not isinstance(messages, list):
return
for msg in messages:
if not isinstance(msg, dict):
continue
media = msg.get("media")
if not isinstance(media, list) or not media:
continue
urls: list[dict[str, str]] = []
for entry in media:
if not isinstance(entry, str) or not entry:
continue
signed = sign_path(Path(entry))
if signed is None:
continue
urls.append({"url": signed, "name": Path(entry).name})
if urls:
msg["media_urls"] = urls
msg.pop("media", None)
def serve_signed_media(
sig: str,
payload: str,

View File

@ -0,0 +1,45 @@
"""Logging helpers for the WebUI WebSocket server surface."""
from __future__ import annotations
import logging
from websockets.exceptions import ConnectionClosed
OPENING_HANDSHAKE_FAILED_MESSAGE = "opening handshake failed"
def _exception_chain_has_disconnect(exc: BaseException | None) -> bool:
seen: set[int] = set()
while exc is not None:
ident = id(exc)
if ident in seen:
return False
seen.add(ident)
if isinstance(exc, (
BrokenPipeError,
ConnectionAbortedError,
ConnectionResetError,
ConnectionClosed,
)):
return True
exc = exc.__cause__ or exc.__context__
return False
class WebSocketHandshakeNoiseFilter(logging.Filter):
"""Suppress restart-time handshakes where the browser already disconnected."""
def filter(self, record: logging.LogRecord) -> bool:
if record.getMessage() != OPENING_HANDSHAKE_FAILED_MESSAGE:
return True
exc_info = record.exc_info
exc = exc_info[1] if isinstance(exc_info, tuple) and len(exc_info) >= 2 else None
return not _exception_chain_has_disconnect(exc)
def websockets_server_logger() -> logging.Logger:
ws_logger = logging.getLogger("websockets.server")
if not any(isinstance(f, WebSocketHandshakeNoiseFilter) for f in ws_logger.filters):
ws_logger.addFilter(WebSocketHandshakeNoiseFilter())
return ws_logger

View File

@ -3,7 +3,6 @@
import asyncio
import functools
import json
import logging
import time
from typing import Any
from unittest.mock import AsyncMock, MagicMock
@ -17,7 +16,6 @@ from websockets.frames import Close
from nanobot.bus.events import OUTBOUND_META_AGENT_UI, OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.websocket import (
_OPENING_HANDSHAKE_FAILED_MESSAGE,
WebSocketChannel,
WebSocketConfig,
_is_valid_chat_id,
@ -28,7 +26,6 @@ from nanobot.channels.websocket import (
_parse_inbound_payload,
_parse_query,
_parse_request_path,
_WebSocketHandshakeNoiseFilter,
publish_runtime_model_update,
)
from nanobot.config.loader import load_config, save_config
@ -42,18 +39,6 @@ from nanobot.webui.settings_api import settings_payload, update_provider_setting
_PORT = 29876
def _log_record(message: str, exc: BaseException) -> logging.LogRecord:
return logging.LogRecord(
name="websockets.server",
level=logging.ERROR,
pathname=__file__,
lineno=1,
msg=message,
args=(),
exc_info=(type(exc), exc, exc.__traceback__),
)
def _ch(bus: Any, **kw: Any) -> WebSocketChannel:
cfg: dict[str, Any] = {
"enabled": True,
@ -128,22 +113,6 @@ def test_websocket_config_rejects_relative_unix_socket() -> None:
WebSocketConfig(unix_socket_path="engine.sock")
def test_websocket_handshake_noise_filter_suppresses_disconnects() -> None:
filter_ = _WebSocketHandshakeNoiseFilter()
wrapped = RuntimeError("wrapped")
wrapped.__cause__ = BrokenPipeError(32, "Broken pipe")
assert not filter_.filter(_log_record(_OPENING_HANDSHAKE_FAILED_MESSAGE, BrokenPipeError()))
assert not filter_.filter(_log_record(_OPENING_HANDSHAKE_FAILED_MESSAGE, wrapped))
def test_websocket_handshake_noise_filter_keeps_real_errors() -> None:
filter_ = _WebSocketHandshakeNoiseFilter()
assert filter_.filter(_log_record(_OPENING_HANDSHAKE_FAILED_MESSAGE, RuntimeError("boom")))
assert filter_.filter(_log_record("connection handler failed", BrokenPipeError()))
def test_parse_query_extracts_token_and_client_id() -> None:
query = _parse_query("/?token=secret&client_id=u1")
assert query.get("token") == ["secret"]

View File

@ -0,0 +1,38 @@
"""Tests for WebUI websocket logging helpers."""
from __future__ import annotations
import logging
from nanobot.webui.websocket_logging import (
OPENING_HANDSHAKE_FAILED_MESSAGE,
WebSocketHandshakeNoiseFilter,
)
def _log_record(message: str, exc: BaseException) -> logging.LogRecord:
return logging.LogRecord(
name="websockets.server",
level=logging.ERROR,
pathname=__file__,
lineno=1,
msg=message,
args=(),
exc_info=(type(exc), exc, exc.__traceback__),
)
def test_websocket_handshake_noise_filter_suppresses_disconnects() -> None:
filter_ = WebSocketHandshakeNoiseFilter()
wrapped = RuntimeError("wrapped")
wrapped.__cause__ = BrokenPipeError(32, "Broken pipe")
assert not filter_.filter(_log_record(OPENING_HANDSHAKE_FAILED_MESSAGE, BrokenPipeError()))
assert not filter_.filter(_log_record(OPENING_HANDSHAKE_FAILED_MESSAGE, wrapped))
def test_websocket_handshake_noise_filter_keeps_real_errors() -> None:
filter_ = WebSocketHandshakeNoiseFilter()
assert filter_.filter(_log_record(OPENING_HANDSHAKE_FAILED_MESSAGE, RuntimeError("boom")))
assert filter_.filter(_log_record("connection handler failed", BrokenPipeError()))