mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 22:34:06 +00:00
refactor(webui): move media replay helpers out of websocket channel
This commit is contained in:
parent
7aa5e620be
commit
f382133bb4
@ -7,7 +7,6 @@ import email.utils
|
||||
import hmac
|
||||
import http
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import re
|
||||
import secrets
|
||||
@ -16,6 +15,7 @@ import time
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from contextlib import suppress
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Self
|
||||
from urllib.parse import parse_qs, unquote, urlparse
|
||||
@ -48,9 +48,11 @@ from nanobot.utils.subagent_channel_display import scrub_subagent_messages_for_c
|
||||
from nanobot.webui.cli_apps_api import normalize_cli_app_mentions
|
||||
from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions
|
||||
from nanobot.webui.media_api import (
|
||||
attach_signed_media_urls,
|
||||
serve_signed_media,
|
||||
sign_media_path,
|
||||
sign_or_stage_media_path,
|
||||
signed_media_attachments,
|
||||
)
|
||||
from nanobot.webui.settings_api import runtime_capabilities
|
||||
from nanobot.webui.settings_routes import WebUISettingsRouter
|
||||
@ -64,6 +66,7 @@ from nanobot.webui.transcript import (
|
||||
build_webui_thread_response,
|
||||
rewrite_local_markdown_images,
|
||||
)
|
||||
from nanobot.webui.websocket_logging import websockets_server_logger
|
||||
from nanobot.webui.workspaces import (
|
||||
WebUIWorkspaceController,
|
||||
)
|
||||
@ -117,45 +120,6 @@ def _host_for_url(host: str, port: int) -> str:
|
||||
return f"{host}:{port}"
|
||||
|
||||
|
||||
_OPENING_HANDSHAKE_FAILED_MESSAGE = "opening handshake failed"
|
||||
|
||||
|
||||
def _exception_chain_has_disconnect(exc: BaseException | None) -> bool:
|
||||
seen: set[int] = set()
|
||||
while exc is not None:
|
||||
ident = id(exc)
|
||||
if ident in seen:
|
||||
return False
|
||||
seen.add(ident)
|
||||
if isinstance(exc, (
|
||||
BrokenPipeError,
|
||||
ConnectionAbortedError,
|
||||
ConnectionResetError,
|
||||
ConnectionClosed,
|
||||
)):
|
||||
return True
|
||||
exc = exc.__cause__ or exc.__context__
|
||||
return False
|
||||
|
||||
|
||||
class _WebSocketHandshakeNoiseFilter(logging.Filter):
|
||||
"""Suppress noisy restart-time handshakes where the client already disconnected."""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
if record.getMessage() != _OPENING_HANDSHAKE_FAILED_MESSAGE:
|
||||
return True
|
||||
exc_info = record.exc_info
|
||||
exc = exc_info[1] if isinstance(exc_info, tuple) and len(exc_info) >= 2 else None
|
||||
return not _exception_chain_has_disconnect(exc)
|
||||
|
||||
|
||||
def _websockets_server_logger() -> logging.Logger:
|
||||
ws_logger = logging.getLogger("websockets.server")
|
||||
if not any(isinstance(f, _WebSocketHandshakeNoiseFilter) for f in ws_logger.filters):
|
||||
ws_logger.addFilter(_WebSocketHandshakeNoiseFilter())
|
||||
return ws_logger
|
||||
|
||||
|
||||
class WebSocketConfig(Base):
|
||||
"""WebSocket server channel configuration.
|
||||
|
||||
@ -1016,7 +980,7 @@ class WebSocketChannel(BaseChannel):
|
||||
# client can render previews. The raw on-disk ``media`` paths are
|
||||
# stripped on the way out — they leak server filesystem layout and
|
||||
# the client never needs them once it has the signed fetch URL.
|
||||
self._augment_media_urls(data)
|
||||
attach_signed_media_urls(data, sign_path=self._sign_media_path)
|
||||
return _http_json_response(data)
|
||||
|
||||
def _handle_webui_thread_get(self, request: WsRequest, key: str) -> Response:
|
||||
@ -1028,10 +992,14 @@ class WebSocketChannel(BaseChannel):
|
||||
if not self._is_websocket_channel_session_key(decoded_key):
|
||||
return _http_error(404, "session not found")
|
||||
scope = self._webui_workspaces.scope_for_session_key(decoded_key)
|
||||
augment_media = partial(
|
||||
signed_media_attachments,
|
||||
sign_path=self._sign_or_stage_media_path,
|
||||
)
|
||||
data = build_webui_thread_response(
|
||||
decoded_key,
|
||||
augment_user_media=self._augment_transcript_media,
|
||||
augment_assistant_media=self._augment_transcript_media,
|
||||
augment_user_media=augment_media,
|
||||
augment_assistant_media=augment_media,
|
||||
augment_assistant_text=lambda text: rewrite_local_markdown_images(
|
||||
text,
|
||||
workspace_path=scope.project_path,
|
||||
@ -1051,25 +1019,6 @@ class WebSocketChannel(BaseChannel):
|
||||
except (ValueError, TypeError) as e:
|
||||
self.logger.warning("webui transcript append failed: {}", e)
|
||||
|
||||
def _augment_transcript_media(self, paths: list[str]) -> list[dict[str, Any]]:
|
||||
out: list[dict[str, Any]] = []
|
||||
for pstr in paths:
|
||||
path = Path(pstr)
|
||||
att = self._sign_or_stage_media_path(path)
|
||||
if att is None:
|
||||
continue
|
||||
mime, _ = mimetypes.guess_type(path.name)
|
||||
if mime and mime.startswith("video/"):
|
||||
kind = "video"
|
||||
elif mime and mime.startswith("image/"):
|
||||
kind = "image"
|
||||
else:
|
||||
kind = "file"
|
||||
out.append(
|
||||
{"kind": kind, "url": att["url"], "name": att.get("name", path.name)},
|
||||
)
|
||||
return out
|
||||
|
||||
async def _handle_message(
|
||||
self,
|
||||
sender_id: str,
|
||||
@ -1106,37 +1055,6 @@ class WebSocketChannel(BaseChannel):
|
||||
is_dm,
|
||||
)
|
||||
|
||||
def _augment_media_urls(self, payload: dict[str, Any]) -> None:
|
||||
"""Mutate *payload* in place: each message's ``media`` path list is
|
||||
replaced by a parallel ``media_urls`` list of signed fetch URLs.
|
||||
|
||||
Messages without media or with non-string path entries are left
|
||||
untouched. Paths that no longer live inside ``media_dir`` (e.g. the
|
||||
file was deleted, or the dir was relocated) are silently skipped;
|
||||
the client falls back to the historical-replay placeholder tile.
|
||||
"""
|
||||
messages = payload.get("messages")
|
||||
if not isinstance(messages, list):
|
||||
return
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
media = msg.get("media")
|
||||
if not isinstance(media, list) or not media:
|
||||
continue
|
||||
urls: list[dict[str, str]] = []
|
||||
for entry in media:
|
||||
if not isinstance(entry, str) or not entry:
|
||||
continue
|
||||
signed = self._sign_media_path(Path(entry))
|
||||
if signed is None:
|
||||
continue
|
||||
urls.append({"url": signed, "name": Path(entry).name})
|
||||
if urls:
|
||||
msg["media_urls"] = urls
|
||||
# Always drop the raw paths from the wire payload.
|
||||
msg.pop("media", None)
|
||||
|
||||
def _sign_media_path(self, abs_path: Path) -> str | None:
|
||||
"""Return a ``/api/media/<sig>/<payload>`` URL for *abs_path*, or
|
||||
``None`` when the path does not resolve inside the media root.
|
||||
@ -1279,7 +1197,7 @@ class WebSocketChannel(BaseChannel):
|
||||
from nanobot.utils.logging_bridge import redirect_lib_logging
|
||||
|
||||
redirect_lib_logging("websockets", level="WARNING")
|
||||
ws_logger = _websockets_server_logger()
|
||||
ws_logger = websockets_server_logger()
|
||||
|
||||
self._running = True
|
||||
self._stop_event = asyncio.Event()
|
||||
|
||||
@ -24,6 +24,8 @@ from nanobot.config.paths import get_media_dir
|
||||
from nanobot.utils.helpers import safe_filename
|
||||
|
||||
MediaDirProvider = Callable[[str | None], Path]
|
||||
SignedMediaPath = Callable[[Path], dict[str, str] | None]
|
||||
SignedMediaUrl = Callable[[Path], str | None]
|
||||
|
||||
|
||||
def b64url_encode(data: bytes) -> str:
|
||||
@ -172,6 +174,64 @@ def sign_or_stage_media_path(
|
||||
return {"url": signed, "name": path.name}
|
||||
|
||||
|
||||
def media_attachment_kind(name: str) -> str:
|
||||
"""Infer the WebUI media attachment kind from a filename."""
|
||||
mime, _ = mimetypes.guess_type(name)
|
||||
if mime and mime.startswith("video/"):
|
||||
return "video"
|
||||
if mime and mime.startswith("image/"):
|
||||
return "image"
|
||||
return "file"
|
||||
|
||||
|
||||
def signed_media_attachments(
|
||||
paths: list[str],
|
||||
*,
|
||||
sign_path: SignedMediaPath,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Map persisted media paths to WebUI attachment dicts with fresh signed URLs."""
|
||||
out: list[dict[str, Any]] = []
|
||||
for pstr in paths:
|
||||
path = Path(pstr)
|
||||
att = sign_path(path)
|
||||
if att is None:
|
||||
continue
|
||||
url = att.get("url")
|
||||
if not url:
|
||||
continue
|
||||
name = att.get("name") or path.name
|
||||
out.append({"kind": media_attachment_kind(name), "url": url, "name": name})
|
||||
return out
|
||||
|
||||
|
||||
def attach_signed_media_urls(
|
||||
payload: dict[str, Any],
|
||||
*,
|
||||
sign_path: SignedMediaUrl,
|
||||
) -> None:
|
||||
"""Replace raw media path lists in a WebUI session payload with signed URLs."""
|
||||
messages = payload.get("messages")
|
||||
if not isinstance(messages, list):
|
||||
return
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
media = msg.get("media")
|
||||
if not isinstance(media, list) or not media:
|
||||
continue
|
||||
urls: list[dict[str, str]] = []
|
||||
for entry in media:
|
||||
if not isinstance(entry, str) or not entry:
|
||||
continue
|
||||
signed = sign_path(Path(entry))
|
||||
if signed is None:
|
||||
continue
|
||||
urls.append({"url": signed, "name": Path(entry).name})
|
||||
if urls:
|
||||
msg["media_urls"] = urls
|
||||
msg.pop("media", None)
|
||||
|
||||
|
||||
def serve_signed_media(
|
||||
sig: str,
|
||||
payload: str,
|
||||
|
||||
45
nanobot/webui/websocket_logging.py
Normal file
45
nanobot/webui/websocket_logging.py
Normal file
@ -0,0 +1,45 @@
|
||||
"""Logging helpers for the WebUI WebSocket server surface."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
|
||||
OPENING_HANDSHAKE_FAILED_MESSAGE = "opening handshake failed"
|
||||
|
||||
|
||||
def _exception_chain_has_disconnect(exc: BaseException | None) -> bool:
|
||||
seen: set[int] = set()
|
||||
while exc is not None:
|
||||
ident = id(exc)
|
||||
if ident in seen:
|
||||
return False
|
||||
seen.add(ident)
|
||||
if isinstance(exc, (
|
||||
BrokenPipeError,
|
||||
ConnectionAbortedError,
|
||||
ConnectionResetError,
|
||||
ConnectionClosed,
|
||||
)):
|
||||
return True
|
||||
exc = exc.__cause__ or exc.__context__
|
||||
return False
|
||||
|
||||
|
||||
class WebSocketHandshakeNoiseFilter(logging.Filter):
|
||||
"""Suppress restart-time handshakes where the browser already disconnected."""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
if record.getMessage() != OPENING_HANDSHAKE_FAILED_MESSAGE:
|
||||
return True
|
||||
exc_info = record.exc_info
|
||||
exc = exc_info[1] if isinstance(exc_info, tuple) and len(exc_info) >= 2 else None
|
||||
return not _exception_chain_has_disconnect(exc)
|
||||
|
||||
|
||||
def websockets_server_logger() -> logging.Logger:
|
||||
ws_logger = logging.getLogger("websockets.server")
|
||||
if not any(isinstance(f, WebSocketHandshakeNoiseFilter) for f in ws_logger.filters):
|
||||
ws_logger.addFilter(WebSocketHandshakeNoiseFilter())
|
||||
return ws_logger
|
||||
@ -3,7 +3,6 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
@ -17,7 +16,6 @@ from websockets.frames import Close
|
||||
from nanobot.bus.events import OUTBOUND_META_AGENT_UI, OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.websocket import (
|
||||
_OPENING_HANDSHAKE_FAILED_MESSAGE,
|
||||
WebSocketChannel,
|
||||
WebSocketConfig,
|
||||
_is_valid_chat_id,
|
||||
@ -28,7 +26,6 @@ from nanobot.channels.websocket import (
|
||||
_parse_inbound_payload,
|
||||
_parse_query,
|
||||
_parse_request_path,
|
||||
_WebSocketHandshakeNoiseFilter,
|
||||
publish_runtime_model_update,
|
||||
)
|
||||
from nanobot.config.loader import load_config, save_config
|
||||
@ -42,18 +39,6 @@ from nanobot.webui.settings_api import settings_payload, update_provider_setting
|
||||
_PORT = 29876
|
||||
|
||||
|
||||
def _log_record(message: str, exc: BaseException) -> logging.LogRecord:
|
||||
return logging.LogRecord(
|
||||
name="websockets.server",
|
||||
level=logging.ERROR,
|
||||
pathname=__file__,
|
||||
lineno=1,
|
||||
msg=message,
|
||||
args=(),
|
||||
exc_info=(type(exc), exc, exc.__traceback__),
|
||||
)
|
||||
|
||||
|
||||
def _ch(bus: Any, **kw: Any) -> WebSocketChannel:
|
||||
cfg: dict[str, Any] = {
|
||||
"enabled": True,
|
||||
@ -128,22 +113,6 @@ def test_websocket_config_rejects_relative_unix_socket() -> None:
|
||||
WebSocketConfig(unix_socket_path="engine.sock")
|
||||
|
||||
|
||||
def test_websocket_handshake_noise_filter_suppresses_disconnects() -> None:
|
||||
filter_ = _WebSocketHandshakeNoiseFilter()
|
||||
wrapped = RuntimeError("wrapped")
|
||||
wrapped.__cause__ = BrokenPipeError(32, "Broken pipe")
|
||||
|
||||
assert not filter_.filter(_log_record(_OPENING_HANDSHAKE_FAILED_MESSAGE, BrokenPipeError()))
|
||||
assert not filter_.filter(_log_record(_OPENING_HANDSHAKE_FAILED_MESSAGE, wrapped))
|
||||
|
||||
|
||||
def test_websocket_handshake_noise_filter_keeps_real_errors() -> None:
|
||||
filter_ = _WebSocketHandshakeNoiseFilter()
|
||||
|
||||
assert filter_.filter(_log_record(_OPENING_HANDSHAKE_FAILED_MESSAGE, RuntimeError("boom")))
|
||||
assert filter_.filter(_log_record("connection handler failed", BrokenPipeError()))
|
||||
|
||||
|
||||
def test_parse_query_extracts_token_and_client_id() -> None:
|
||||
query = _parse_query("/?token=secret&client_id=u1")
|
||||
assert query.get("token") == ["secret"]
|
||||
|
||||
38
tests/utils/test_webui_websocket_logging.py
Normal file
38
tests/utils/test_webui_websocket_logging.py
Normal file
@ -0,0 +1,38 @@
|
||||
"""Tests for WebUI websocket logging helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from nanobot.webui.websocket_logging import (
|
||||
OPENING_HANDSHAKE_FAILED_MESSAGE,
|
||||
WebSocketHandshakeNoiseFilter,
|
||||
)
|
||||
|
||||
|
||||
def _log_record(message: str, exc: BaseException) -> logging.LogRecord:
|
||||
return logging.LogRecord(
|
||||
name="websockets.server",
|
||||
level=logging.ERROR,
|
||||
pathname=__file__,
|
||||
lineno=1,
|
||||
msg=message,
|
||||
args=(),
|
||||
exc_info=(type(exc), exc, exc.__traceback__),
|
||||
)
|
||||
|
||||
|
||||
def test_websocket_handshake_noise_filter_suppresses_disconnects() -> None:
|
||||
filter_ = WebSocketHandshakeNoiseFilter()
|
||||
wrapped = RuntimeError("wrapped")
|
||||
wrapped.__cause__ = BrokenPipeError(32, "Broken pipe")
|
||||
|
||||
assert not filter_.filter(_log_record(OPENING_HANDSHAKE_FAILED_MESSAGE, BrokenPipeError()))
|
||||
assert not filter_.filter(_log_record(OPENING_HANDSHAKE_FAILED_MESSAGE, wrapped))
|
||||
|
||||
|
||||
def test_websocket_handshake_noise_filter_keeps_real_errors() -> None:
|
||||
filter_ = WebSocketHandshakeNoiseFilter()
|
||||
|
||||
assert filter_.filter(_log_record(OPENING_HANDSHAKE_FAILED_MESSAGE, RuntimeError("boom")))
|
||||
assert filter_.filter(_log_record("connection handler failed", BrokenPipeError()))
|
||||
Loading…
x
Reference in New Issue
Block a user