refactor(webui): move media replay helpers out of websocket channel

This commit is contained in:
Xubin Ren 2026-06-02 16:13:49 +08:00
parent 7aa5e620be
commit f382133bb4
5 changed files with 155 additions and 125 deletions

View File

@ -7,7 +7,6 @@ import email.utils
import hmac import hmac
import http import http
import json import json
import logging
import mimetypes import mimetypes
import re import re
import secrets import secrets
@ -16,6 +15,7 @@ import time
import uuid import uuid
from collections.abc import Callable from collections.abc import Callable
from contextlib import suppress from contextlib import suppress
from functools import partial
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Self from typing import TYPE_CHECKING, Any, Self
from urllib.parse import parse_qs, unquote, urlparse from urllib.parse import parse_qs, unquote, urlparse
@ -48,9 +48,11 @@ from nanobot.utils.subagent_channel_display import scrub_subagent_messages_for_c
from nanobot.webui.cli_apps_api import normalize_cli_app_mentions from nanobot.webui.cli_apps_api import normalize_cli_app_mentions
from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions
from nanobot.webui.media_api import ( from nanobot.webui.media_api import (
attach_signed_media_urls,
serve_signed_media, serve_signed_media,
sign_media_path, sign_media_path,
sign_or_stage_media_path, sign_or_stage_media_path,
signed_media_attachments,
) )
from nanobot.webui.settings_api import runtime_capabilities from nanobot.webui.settings_api import runtime_capabilities
from nanobot.webui.settings_routes import WebUISettingsRouter from nanobot.webui.settings_routes import WebUISettingsRouter
@ -64,6 +66,7 @@ from nanobot.webui.transcript import (
build_webui_thread_response, build_webui_thread_response,
rewrite_local_markdown_images, rewrite_local_markdown_images,
) )
from nanobot.webui.websocket_logging import websockets_server_logger
from nanobot.webui.workspaces import ( from nanobot.webui.workspaces import (
WebUIWorkspaceController, WebUIWorkspaceController,
) )
@ -117,45 +120,6 @@ def _host_for_url(host: str, port: int) -> str:
return f"{host}:{port}" return f"{host}:{port}"
_OPENING_HANDSHAKE_FAILED_MESSAGE = "opening handshake failed"
def _exception_chain_has_disconnect(exc: BaseException | None) -> bool:
seen: set[int] = set()
while exc is not None:
ident = id(exc)
if ident in seen:
return False
seen.add(ident)
if isinstance(exc, (
BrokenPipeError,
ConnectionAbortedError,
ConnectionResetError,
ConnectionClosed,
)):
return True
exc = exc.__cause__ or exc.__context__
return False
class _WebSocketHandshakeNoiseFilter(logging.Filter):
"""Suppress noisy restart-time handshakes where the client already disconnected."""
def filter(self, record: logging.LogRecord) -> bool:
if record.getMessage() != _OPENING_HANDSHAKE_FAILED_MESSAGE:
return True
exc_info = record.exc_info
exc = exc_info[1] if isinstance(exc_info, tuple) and len(exc_info) >= 2 else None
return not _exception_chain_has_disconnect(exc)
def _websockets_server_logger() -> logging.Logger:
ws_logger = logging.getLogger("websockets.server")
if not any(isinstance(f, _WebSocketHandshakeNoiseFilter) for f in ws_logger.filters):
ws_logger.addFilter(_WebSocketHandshakeNoiseFilter())
return ws_logger
class WebSocketConfig(Base): class WebSocketConfig(Base):
"""WebSocket server channel configuration. """WebSocket server channel configuration.
@ -1016,7 +980,7 @@ class WebSocketChannel(BaseChannel):
# client can render previews. The raw on-disk ``media`` paths are # client can render previews. The raw on-disk ``media`` paths are
# stripped on the way out — they leak server filesystem layout and # stripped on the way out — they leak server filesystem layout and
# the client never needs them once it has the signed fetch URL. # the client never needs them once it has the signed fetch URL.
self._augment_media_urls(data) attach_signed_media_urls(data, sign_path=self._sign_media_path)
return _http_json_response(data) return _http_json_response(data)
def _handle_webui_thread_get(self, request: WsRequest, key: str) -> Response: def _handle_webui_thread_get(self, request: WsRequest, key: str) -> Response:
@ -1028,10 +992,14 @@ class WebSocketChannel(BaseChannel):
if not self._is_websocket_channel_session_key(decoded_key): if not self._is_websocket_channel_session_key(decoded_key):
return _http_error(404, "session not found") return _http_error(404, "session not found")
scope = self._webui_workspaces.scope_for_session_key(decoded_key) scope = self._webui_workspaces.scope_for_session_key(decoded_key)
augment_media = partial(
signed_media_attachments,
sign_path=self._sign_or_stage_media_path,
)
data = build_webui_thread_response( data = build_webui_thread_response(
decoded_key, decoded_key,
augment_user_media=self._augment_transcript_media, augment_user_media=augment_media,
augment_assistant_media=self._augment_transcript_media, augment_assistant_media=augment_media,
augment_assistant_text=lambda text: rewrite_local_markdown_images( augment_assistant_text=lambda text: rewrite_local_markdown_images(
text, text,
workspace_path=scope.project_path, workspace_path=scope.project_path,
@ -1051,25 +1019,6 @@ class WebSocketChannel(BaseChannel):
except (ValueError, TypeError) as e: except (ValueError, TypeError) as e:
self.logger.warning("webui transcript append failed: {}", e) self.logger.warning("webui transcript append failed: {}", e)
def _augment_transcript_media(self, paths: list[str]) -> list[dict[str, Any]]:
out: list[dict[str, Any]] = []
for pstr in paths:
path = Path(pstr)
att = self._sign_or_stage_media_path(path)
if att is None:
continue
mime, _ = mimetypes.guess_type(path.name)
if mime and mime.startswith("video/"):
kind = "video"
elif mime and mime.startswith("image/"):
kind = "image"
else:
kind = "file"
out.append(
{"kind": kind, "url": att["url"], "name": att.get("name", path.name)},
)
return out
async def _handle_message( async def _handle_message(
self, self,
sender_id: str, sender_id: str,
@ -1106,37 +1055,6 @@ class WebSocketChannel(BaseChannel):
is_dm, is_dm,
) )
def _augment_media_urls(self, payload: dict[str, Any]) -> None:
"""Mutate *payload* in place: each message's ``media`` path list is
replaced by a parallel ``media_urls`` list of signed fetch URLs.
Messages without media or with non-string path entries are left
untouched. Paths that no longer live inside ``media_dir`` (e.g. the
file was deleted, or the dir was relocated) are silently skipped;
the client falls back to the historical-replay placeholder tile.
"""
messages = payload.get("messages")
if not isinstance(messages, list):
return
for msg in messages:
if not isinstance(msg, dict):
continue
media = msg.get("media")
if not isinstance(media, list) or not media:
continue
urls: list[dict[str, str]] = []
for entry in media:
if not isinstance(entry, str) or not entry:
continue
signed = self._sign_media_path(Path(entry))
if signed is None:
continue
urls.append({"url": signed, "name": Path(entry).name})
if urls:
msg["media_urls"] = urls
# Always drop the raw paths from the wire payload.
msg.pop("media", None)
def _sign_media_path(self, abs_path: Path) -> str | None: def _sign_media_path(self, abs_path: Path) -> str | None:
"""Return a ``/api/media/<sig>/<payload>`` URL for *abs_path*, or """Return a ``/api/media/<sig>/<payload>`` URL for *abs_path*, or
``None`` when the path does not resolve inside the media root. ``None`` when the path does not resolve inside the media root.
@ -1279,7 +1197,7 @@ class WebSocketChannel(BaseChannel):
from nanobot.utils.logging_bridge import redirect_lib_logging from nanobot.utils.logging_bridge import redirect_lib_logging
redirect_lib_logging("websockets", level="WARNING") redirect_lib_logging("websockets", level="WARNING")
ws_logger = _websockets_server_logger() ws_logger = websockets_server_logger()
self._running = True self._running = True
self._stop_event = asyncio.Event() self._stop_event = asyncio.Event()

View File

@ -24,6 +24,8 @@ from nanobot.config.paths import get_media_dir
from nanobot.utils.helpers import safe_filename from nanobot.utils.helpers import safe_filename
MediaDirProvider = Callable[[str | None], Path] MediaDirProvider = Callable[[str | None], Path]
SignedMediaPath = Callable[[Path], dict[str, str] | None]
SignedMediaUrl = Callable[[Path], str | None]
def b64url_encode(data: bytes) -> str: def b64url_encode(data: bytes) -> str:
@ -172,6 +174,64 @@ def sign_or_stage_media_path(
return {"url": signed, "name": path.name} return {"url": signed, "name": path.name}
def media_attachment_kind(name: str) -> str:
"""Infer the WebUI media attachment kind from a filename."""
mime, _ = mimetypes.guess_type(name)
if mime and mime.startswith("video/"):
return "video"
if mime and mime.startswith("image/"):
return "image"
return "file"
def signed_media_attachments(
paths: list[str],
*,
sign_path: SignedMediaPath,
) -> list[dict[str, Any]]:
"""Map persisted media paths to WebUI attachment dicts with fresh signed URLs."""
out: list[dict[str, Any]] = []
for pstr in paths:
path = Path(pstr)
att = sign_path(path)
if att is None:
continue
url = att.get("url")
if not url:
continue
name = att.get("name") or path.name
out.append({"kind": media_attachment_kind(name), "url": url, "name": name})
return out
def attach_signed_media_urls(
payload: dict[str, Any],
*,
sign_path: SignedMediaUrl,
) -> None:
"""Replace raw media path lists in a WebUI session payload with signed URLs."""
messages = payload.get("messages")
if not isinstance(messages, list):
return
for msg in messages:
if not isinstance(msg, dict):
continue
media = msg.get("media")
if not isinstance(media, list) or not media:
continue
urls: list[dict[str, str]] = []
for entry in media:
if not isinstance(entry, str) or not entry:
continue
signed = sign_path(Path(entry))
if signed is None:
continue
urls.append({"url": signed, "name": Path(entry).name})
if urls:
msg["media_urls"] = urls
msg.pop("media", None)
def serve_signed_media( def serve_signed_media(
sig: str, sig: str,
payload: str, payload: str,

View File

@ -0,0 +1,45 @@
"""Logging helpers for the WebUI WebSocket server surface."""
from __future__ import annotations
import logging
from websockets.exceptions import ConnectionClosed
OPENING_HANDSHAKE_FAILED_MESSAGE = "opening handshake failed"
def _exception_chain_has_disconnect(exc: BaseException | None) -> bool:
seen: set[int] = set()
while exc is not None:
ident = id(exc)
if ident in seen:
return False
seen.add(ident)
if isinstance(exc, (
BrokenPipeError,
ConnectionAbortedError,
ConnectionResetError,
ConnectionClosed,
)):
return True
exc = exc.__cause__ or exc.__context__
return False
class WebSocketHandshakeNoiseFilter(logging.Filter):
"""Suppress restart-time handshakes where the browser already disconnected."""
def filter(self, record: logging.LogRecord) -> bool:
if record.getMessage() != OPENING_HANDSHAKE_FAILED_MESSAGE:
return True
exc_info = record.exc_info
exc = exc_info[1] if isinstance(exc_info, tuple) and len(exc_info) >= 2 else None
return not _exception_chain_has_disconnect(exc)
def websockets_server_logger() -> logging.Logger:
ws_logger = logging.getLogger("websockets.server")
if not any(isinstance(f, WebSocketHandshakeNoiseFilter) for f in ws_logger.filters):
ws_logger.addFilter(WebSocketHandshakeNoiseFilter())
return ws_logger

View File

@ -3,7 +3,6 @@
import asyncio import asyncio
import functools import functools
import json import json
import logging
import time import time
from typing import Any from typing import Any
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
@ -17,7 +16,6 @@ from websockets.frames import Close
from nanobot.bus.events import OUTBOUND_META_AGENT_UI, OutboundMessage from nanobot.bus.events import OUTBOUND_META_AGENT_UI, OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.channels.websocket import ( from nanobot.channels.websocket import (
_OPENING_HANDSHAKE_FAILED_MESSAGE,
WebSocketChannel, WebSocketChannel,
WebSocketConfig, WebSocketConfig,
_is_valid_chat_id, _is_valid_chat_id,
@ -28,7 +26,6 @@ from nanobot.channels.websocket import (
_parse_inbound_payload, _parse_inbound_payload,
_parse_query, _parse_query,
_parse_request_path, _parse_request_path,
_WebSocketHandshakeNoiseFilter,
publish_runtime_model_update, publish_runtime_model_update,
) )
from nanobot.config.loader import load_config, save_config from nanobot.config.loader import load_config, save_config
@ -42,18 +39,6 @@ from nanobot.webui.settings_api import settings_payload, update_provider_setting
_PORT = 29876 _PORT = 29876
def _log_record(message: str, exc: BaseException) -> logging.LogRecord:
return logging.LogRecord(
name="websockets.server",
level=logging.ERROR,
pathname=__file__,
lineno=1,
msg=message,
args=(),
exc_info=(type(exc), exc, exc.__traceback__),
)
def _ch(bus: Any, **kw: Any) -> WebSocketChannel: def _ch(bus: Any, **kw: Any) -> WebSocketChannel:
cfg: dict[str, Any] = { cfg: dict[str, Any] = {
"enabled": True, "enabled": True,
@ -128,22 +113,6 @@ def test_websocket_config_rejects_relative_unix_socket() -> None:
WebSocketConfig(unix_socket_path="engine.sock") WebSocketConfig(unix_socket_path="engine.sock")
def test_websocket_handshake_noise_filter_suppresses_disconnects() -> None:
filter_ = _WebSocketHandshakeNoiseFilter()
wrapped = RuntimeError("wrapped")
wrapped.__cause__ = BrokenPipeError(32, "Broken pipe")
assert not filter_.filter(_log_record(_OPENING_HANDSHAKE_FAILED_MESSAGE, BrokenPipeError()))
assert not filter_.filter(_log_record(_OPENING_HANDSHAKE_FAILED_MESSAGE, wrapped))
def test_websocket_handshake_noise_filter_keeps_real_errors() -> None:
filter_ = _WebSocketHandshakeNoiseFilter()
assert filter_.filter(_log_record(_OPENING_HANDSHAKE_FAILED_MESSAGE, RuntimeError("boom")))
assert filter_.filter(_log_record("connection handler failed", BrokenPipeError()))
def test_parse_query_extracts_token_and_client_id() -> None: def test_parse_query_extracts_token_and_client_id() -> None:
query = _parse_query("/?token=secret&client_id=u1") query = _parse_query("/?token=secret&client_id=u1")
assert query.get("token") == ["secret"] assert query.get("token") == ["secret"]

View File

@ -0,0 +1,38 @@
"""Tests for WebUI websocket logging helpers."""
from __future__ import annotations
import logging
from nanobot.webui.websocket_logging import (
OPENING_HANDSHAKE_FAILED_MESSAGE,
WebSocketHandshakeNoiseFilter,
)
def _log_record(message: str, exc: BaseException) -> logging.LogRecord:
return logging.LogRecord(
name="websockets.server",
level=logging.ERROR,
pathname=__file__,
lineno=1,
msg=message,
args=(),
exc_info=(type(exc), exc, exc.__traceback__),
)
def test_websocket_handshake_noise_filter_suppresses_disconnects() -> None:
filter_ = WebSocketHandshakeNoiseFilter()
wrapped = RuntimeError("wrapped")
wrapped.__cause__ = BrokenPipeError(32, "Broken pipe")
assert not filter_.filter(_log_record(OPENING_HANDSHAKE_FAILED_MESSAGE, BrokenPipeError()))
assert not filter_.filter(_log_record(OPENING_HANDSHAKE_FAILED_MESSAGE, wrapped))
def test_websocket_handshake_noise_filter_keeps_real_errors() -> None:
filter_ = WebSocketHandshakeNoiseFilter()
assert filter_.filter(_log_record(OPENING_HANDSHAKE_FAILED_MESSAGE, RuntimeError("boom")))
assert filter_.filter(_log_record("connection handler failed", BrokenPipeError()))