From f382133bb410b8b14502e40d7902c1e5f3d0c744 Mon Sep 17 00:00:00 2001 From: Xubin Ren <52506698+Re-bin@users.noreply.github.com> Date: Tue, 2 Jun 2026 16:13:49 +0800 Subject: [PATCH] refactor(webui): move media replay helpers out of websocket channel --- nanobot/channels/websocket.py | 106 +++----------------- nanobot/webui/media_api.py | 60 +++++++++++ nanobot/webui/websocket_logging.py | 45 +++++++++ tests/channels/test_websocket_channel.py | 31 ------ tests/utils/test_webui_websocket_logging.py | 38 +++++++ 5 files changed, 155 insertions(+), 125 deletions(-) create mode 100644 nanobot/webui/websocket_logging.py create mode 100644 tests/utils/test_webui_websocket_logging.py diff --git a/nanobot/channels/websocket.py b/nanobot/channels/websocket.py index 58f0c265d..038aa7a22 100644 --- a/nanobot/channels/websocket.py +++ b/nanobot/channels/websocket.py @@ -7,7 +7,6 @@ import email.utils import hmac import http import json -import logging import mimetypes import re import secrets @@ -16,6 +15,7 @@ import time import uuid from collections.abc import Callable from contextlib import suppress +from functools import partial from pathlib import Path from typing import TYPE_CHECKING, Any, Self from urllib.parse import parse_qs, unquote, urlparse @@ -48,9 +48,11 @@ from nanobot.utils.subagent_channel_display import scrub_subagent_messages_for_c from nanobot.webui.cli_apps_api import normalize_cli_app_mentions from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions from nanobot.webui.media_api import ( + attach_signed_media_urls, serve_signed_media, sign_media_path, sign_or_stage_media_path, + signed_media_attachments, ) from nanobot.webui.settings_api import runtime_capabilities from nanobot.webui.settings_routes import WebUISettingsRouter @@ -64,6 +66,7 @@ from nanobot.webui.transcript import ( build_webui_thread_response, rewrite_local_markdown_images, ) +from nanobot.webui.websocket_logging import websockets_server_logger from nanobot.webui.workspaces 
import ( WebUIWorkspaceController, ) @@ -117,45 +120,6 @@ def _host_for_url(host: str, port: int) -> str: return f"{host}:{port}" -_OPENING_HANDSHAKE_FAILED_MESSAGE = "opening handshake failed" - - -def _exception_chain_has_disconnect(exc: BaseException | None) -> bool: - seen: set[int] = set() - while exc is not None: - ident = id(exc) - if ident in seen: - return False - seen.add(ident) - if isinstance(exc, ( - BrokenPipeError, - ConnectionAbortedError, - ConnectionResetError, - ConnectionClosed, - )): - return True - exc = exc.__cause__ or exc.__context__ - return False - - -class _WebSocketHandshakeNoiseFilter(logging.Filter): - """Suppress noisy restart-time handshakes where the client already disconnected.""" - - def filter(self, record: logging.LogRecord) -> bool: - if record.getMessage() != _OPENING_HANDSHAKE_FAILED_MESSAGE: - return True - exc_info = record.exc_info - exc = exc_info[1] if isinstance(exc_info, tuple) and len(exc_info) >= 2 else None - return not _exception_chain_has_disconnect(exc) - - -def _websockets_server_logger() -> logging.Logger: - ws_logger = logging.getLogger("websockets.server") - if not any(isinstance(f, _WebSocketHandshakeNoiseFilter) for f in ws_logger.filters): - ws_logger.addFilter(_WebSocketHandshakeNoiseFilter()) - return ws_logger - - class WebSocketConfig(Base): """WebSocket server channel configuration. @@ -1016,7 +980,7 @@ class WebSocketChannel(BaseChannel): # client can render previews. The raw on-disk ``media`` paths are # stripped on the way out — they leak server filesystem layout and # the client never needs them once it has the signed fetch URL. 
- self._augment_media_urls(data) + attach_signed_media_urls(data, sign_path=self._sign_media_path) return _http_json_response(data) def _handle_webui_thread_get(self, request: WsRequest, key: str) -> Response: @@ -1028,10 +992,14 @@ class WebSocketChannel(BaseChannel): if not self._is_websocket_channel_session_key(decoded_key): return _http_error(404, "session not found") scope = self._webui_workspaces.scope_for_session_key(decoded_key) + augment_media = partial( + signed_media_attachments, + sign_path=self._sign_or_stage_media_path, + ) data = build_webui_thread_response( decoded_key, - augment_user_media=self._augment_transcript_media, - augment_assistant_media=self._augment_transcript_media, + augment_user_media=augment_media, + augment_assistant_media=augment_media, augment_assistant_text=lambda text: rewrite_local_markdown_images( text, workspace_path=scope.project_path, @@ -1051,25 +1019,6 @@ class WebSocketChannel(BaseChannel): except (ValueError, TypeError) as e: self.logger.warning("webui transcript append failed: {}", e) - def _augment_transcript_media(self, paths: list[str]) -> list[dict[str, Any]]: - out: list[dict[str, Any]] = [] - for pstr in paths: - path = Path(pstr) - att = self._sign_or_stage_media_path(path) - if att is None: - continue - mime, _ = mimetypes.guess_type(path.name) - if mime and mime.startswith("video/"): - kind = "video" - elif mime and mime.startswith("image/"): - kind = "image" - else: - kind = "file" - out.append( - {"kind": kind, "url": att["url"], "name": att.get("name", path.name)}, - ) - return out - async def _handle_message( self, sender_id: str, @@ -1106,37 +1055,6 @@ class WebSocketChannel(BaseChannel): is_dm, ) - def _augment_media_urls(self, payload: dict[str, Any]) -> None: - """Mutate *payload* in place: each message's ``media`` path list is - replaced by a parallel ``media_urls`` list of signed fetch URLs. - - Messages without media or with non-string path entries are left - untouched. 
Paths that no longer live inside ``media_dir`` (e.g. the - file was deleted, or the dir was relocated) are silently skipped; - the client falls back to the historical-replay placeholder tile. - """ - messages = payload.get("messages") - if not isinstance(messages, list): - return - for msg in messages: - if not isinstance(msg, dict): - continue - media = msg.get("media") - if not isinstance(media, list) or not media: - continue - urls: list[dict[str, str]] = [] - for entry in media: - if not isinstance(entry, str) or not entry: - continue - signed = self._sign_media_path(Path(entry)) - if signed is None: - continue - urls.append({"url": signed, "name": Path(entry).name}) - if urls: - msg["media_urls"] = urls - # Always drop the raw paths from the wire payload. - msg.pop("media", None) - def _sign_media_path(self, abs_path: Path) -> str | None: """Return a ``/api/media/<sig>/<payload>`` URL for *abs_path*, or ``None`` when the path does not resolve inside the media root. @@ -1279,7 +1197,7 @@ class WebSocketChannel(BaseChannel): from nanobot.utils.logging_bridge import redirect_lib_logging redirect_lib_logging("websockets", level="WARNING") - ws_logger = _websockets_server_logger() + ws_logger = websockets_server_logger() self._running = True self._stop_event = asyncio.Event() diff --git a/nanobot/webui/media_api.py b/nanobot/webui/media_api.py index a4f6a7770..845ecd903 100644 --- a/nanobot/webui/media_api.py +++ b/nanobot/webui/media_api.py @@ -24,6 +24,8 @@ from nanobot.config.paths import get_media_dir from nanobot.utils.helpers import safe_filename MediaDirProvider = Callable[[str | None], Path] +SignedMediaPath = Callable[[Path], dict[str, str] | None] +SignedMediaUrl = Callable[[Path], str | None] def b64url_encode(data: bytes) -> str: @@ -172,6 +174,64 @@ def sign_or_stage_media_path( return {"url": signed, "name": path.name} +def media_attachment_kind(name: str) -> str: + """Infer the WebUI media attachment kind from a filename.""" + mime, _ = mimetypes.guess_type(name) 
+ if mime and mime.startswith("video/"): + return "video" + if mime and mime.startswith("image/"): + return "image" + return "file" + + +def signed_media_attachments( + paths: list[str], + *, + sign_path: SignedMediaPath, +) -> list[dict[str, Any]]: + """Map persisted media paths to WebUI attachment dicts with fresh signed URLs.""" + out: list[dict[str, Any]] = [] + for pstr in paths: + path = Path(pstr) + att = sign_path(path) + if att is None: + continue + url = att.get("url") + if not url: + continue + name = att.get("name") or path.name + out.append({"kind": media_attachment_kind(name), "url": url, "name": name}) + return out + + +def attach_signed_media_urls( + payload: dict[str, Any], + *, + sign_path: SignedMediaUrl, +) -> None: + """Replace raw media path lists in a WebUI session payload with signed URLs.""" + messages = payload.get("messages") + if not isinstance(messages, list): + return + for msg in messages: + if not isinstance(msg, dict): + continue + media = msg.get("media") + if not isinstance(media, list) or not media: + continue + urls: list[dict[str, str]] = [] + for entry in media: + if not isinstance(entry, str) or not entry: + continue + signed = sign_path(Path(entry)) + if signed is None: + continue + urls.append({"url": signed, "name": Path(entry).name}) + if urls: + msg["media_urls"] = urls + msg.pop("media", None) + + def serve_signed_media( sig: str, payload: str, diff --git a/nanobot/webui/websocket_logging.py b/nanobot/webui/websocket_logging.py new file mode 100644 index 000000000..046b7a06b --- /dev/null +++ b/nanobot/webui/websocket_logging.py @@ -0,0 +1,45 @@ +"""Logging helpers for the WebUI WebSocket server surface.""" + +from __future__ import annotations + +import logging + +from websockets.exceptions import ConnectionClosed + +OPENING_HANDSHAKE_FAILED_MESSAGE = "opening handshake failed" + + +def _exception_chain_has_disconnect(exc: BaseException | None) -> bool: + seen: set[int] = set() + while exc is not None: + ident = 
id(exc) + if ident in seen: + return False + seen.add(ident) + if isinstance(exc, ( + BrokenPipeError, + ConnectionAbortedError, + ConnectionResetError, + ConnectionClosed, + )): + return True + exc = exc.__cause__ or exc.__context__ + return False + + +class WebSocketHandshakeNoiseFilter(logging.Filter): + """Suppress restart-time handshakes where the browser already disconnected.""" + + def filter(self, record: logging.LogRecord) -> bool: + if record.getMessage() != OPENING_HANDSHAKE_FAILED_MESSAGE: + return True + exc_info = record.exc_info + exc = exc_info[1] if isinstance(exc_info, tuple) and len(exc_info) >= 2 else None + return not _exception_chain_has_disconnect(exc) + + +def websockets_server_logger() -> logging.Logger: + ws_logger = logging.getLogger("websockets.server") + if not any(isinstance(f, WebSocketHandshakeNoiseFilter) for f in ws_logger.filters): + ws_logger.addFilter(WebSocketHandshakeNoiseFilter()) + return ws_logger diff --git a/tests/channels/test_websocket_channel.py b/tests/channels/test_websocket_channel.py index 78d22b969..03cee58f7 100644 --- a/tests/channels/test_websocket_channel.py +++ b/tests/channels/test_websocket_channel.py @@ -3,7 +3,6 @@ import asyncio import functools import json -import logging import time from typing import Any from unittest.mock import AsyncMock, MagicMock @@ -17,7 +16,6 @@ from websockets.frames import Close from nanobot.bus.events import OUTBOUND_META_AGENT_UI, OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.websocket import ( - _OPENING_HANDSHAKE_FAILED_MESSAGE, WebSocketChannel, WebSocketConfig, _is_valid_chat_id, @@ -28,7 +26,6 @@ from nanobot.channels.websocket import ( _parse_inbound_payload, _parse_query, _parse_request_path, - _WebSocketHandshakeNoiseFilter, publish_runtime_model_update, ) from nanobot.config.loader import load_config, save_config @@ -42,18 +39,6 @@ from nanobot.webui.settings_api import settings_payload, update_provider_setting _PORT = 29876 -def 
_log_record(message: str, exc: BaseException) -> logging.LogRecord: - return logging.LogRecord( - name="websockets.server", - level=logging.ERROR, - pathname=__file__, - lineno=1, - msg=message, - args=(), - exc_info=(type(exc), exc, exc.__traceback__), - ) - - def _ch(bus: Any, **kw: Any) -> WebSocketChannel: cfg: dict[str, Any] = { "enabled": True, @@ -128,22 +113,6 @@ def test_websocket_config_rejects_relative_unix_socket() -> None: WebSocketConfig(unix_socket_path="engine.sock") -def test_websocket_handshake_noise_filter_suppresses_disconnects() -> None: - filter_ = _WebSocketHandshakeNoiseFilter() - wrapped = RuntimeError("wrapped") - wrapped.__cause__ = BrokenPipeError(32, "Broken pipe") - - assert not filter_.filter(_log_record(_OPENING_HANDSHAKE_FAILED_MESSAGE, BrokenPipeError())) - assert not filter_.filter(_log_record(_OPENING_HANDSHAKE_FAILED_MESSAGE, wrapped)) - - -def test_websocket_handshake_noise_filter_keeps_real_errors() -> None: - filter_ = _WebSocketHandshakeNoiseFilter() - - assert filter_.filter(_log_record(_OPENING_HANDSHAKE_FAILED_MESSAGE, RuntimeError("boom"))) - assert filter_.filter(_log_record("connection handler failed", BrokenPipeError())) - - def test_parse_query_extracts_token_and_client_id() -> None: query = _parse_query("/?token=secret&client_id=u1") assert query.get("token") == ["secret"] diff --git a/tests/utils/test_webui_websocket_logging.py b/tests/utils/test_webui_websocket_logging.py new file mode 100644 index 000000000..9adee3369 --- /dev/null +++ b/tests/utils/test_webui_websocket_logging.py @@ -0,0 +1,38 @@ +"""Tests for WebUI websocket logging helpers.""" + +from __future__ import annotations + +import logging + +from nanobot.webui.websocket_logging import ( + OPENING_HANDSHAKE_FAILED_MESSAGE, + WebSocketHandshakeNoiseFilter, +) + + +def _log_record(message: str, exc: BaseException) -> logging.LogRecord: + return logging.LogRecord( + name="websockets.server", + level=logging.ERROR, + pathname=__file__, + lineno=1, + 
msg=message, + args=(), + exc_info=(type(exc), exc, exc.__traceback__), + ) + + +def test_websocket_handshake_noise_filter_suppresses_disconnects() -> None: + filter_ = WebSocketHandshakeNoiseFilter() + wrapped = RuntimeError("wrapped") + wrapped.__cause__ = BrokenPipeError(32, "Broken pipe") + + assert not filter_.filter(_log_record(OPENING_HANDSHAKE_FAILED_MESSAGE, BrokenPipeError())) + assert not filter_.filter(_log_record(OPENING_HANDSHAKE_FAILED_MESSAGE, wrapped)) + + +def test_websocket_handshake_noise_filter_keeps_real_errors() -> None: + filter_ = WebSocketHandshakeNoiseFilter() + + assert filter_.filter(_log_record(OPENING_HANDSHAKE_FAILED_MESSAGE, RuntimeError("boom"))) + assert filter_.filter(_log_record("connection handler failed", BrokenPipeError()))