mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-14 06:43:53 +00:00
refactor(webui): move media replay helpers out of websocket channel
This commit is contained in:
parent
7aa5e620be
commit
f382133bb4
@ -7,7 +7,6 @@ import email.utils
|
|||||||
import hmac
|
import hmac
|
||||||
import http
|
import http
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import re
|
import re
|
||||||
import secrets
|
import secrets
|
||||||
@ -16,6 +15,7 @@ import time
|
|||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Self
|
from typing import TYPE_CHECKING, Any, Self
|
||||||
from urllib.parse import parse_qs, unquote, urlparse
|
from urllib.parse import parse_qs, unquote, urlparse
|
||||||
@ -48,9 +48,11 @@ from nanobot.utils.subagent_channel_display import scrub_subagent_messages_for_c
|
|||||||
from nanobot.webui.cli_apps_api import normalize_cli_app_mentions
|
from nanobot.webui.cli_apps_api import normalize_cli_app_mentions
|
||||||
from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions
|
from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions
|
||||||
from nanobot.webui.media_api import (
|
from nanobot.webui.media_api import (
|
||||||
|
attach_signed_media_urls,
|
||||||
serve_signed_media,
|
serve_signed_media,
|
||||||
sign_media_path,
|
sign_media_path,
|
||||||
sign_or_stage_media_path,
|
sign_or_stage_media_path,
|
||||||
|
signed_media_attachments,
|
||||||
)
|
)
|
||||||
from nanobot.webui.settings_api import runtime_capabilities
|
from nanobot.webui.settings_api import runtime_capabilities
|
||||||
from nanobot.webui.settings_routes import WebUISettingsRouter
|
from nanobot.webui.settings_routes import WebUISettingsRouter
|
||||||
@ -64,6 +66,7 @@ from nanobot.webui.transcript import (
|
|||||||
build_webui_thread_response,
|
build_webui_thread_response,
|
||||||
rewrite_local_markdown_images,
|
rewrite_local_markdown_images,
|
||||||
)
|
)
|
||||||
|
from nanobot.webui.websocket_logging import websockets_server_logger
|
||||||
from nanobot.webui.workspaces import (
|
from nanobot.webui.workspaces import (
|
||||||
WebUIWorkspaceController,
|
WebUIWorkspaceController,
|
||||||
)
|
)
|
||||||
@ -117,45 +120,6 @@ def _host_for_url(host: str, port: int) -> str:
|
|||||||
return f"{host}:{port}"
|
return f"{host}:{port}"
|
||||||
|
|
||||||
|
|
||||||
_OPENING_HANDSHAKE_FAILED_MESSAGE = "opening handshake failed"
|
|
||||||
|
|
||||||
|
|
||||||
def _exception_chain_has_disconnect(exc: BaseException | None) -> bool:
|
|
||||||
seen: set[int] = set()
|
|
||||||
while exc is not None:
|
|
||||||
ident = id(exc)
|
|
||||||
if ident in seen:
|
|
||||||
return False
|
|
||||||
seen.add(ident)
|
|
||||||
if isinstance(exc, (
|
|
||||||
BrokenPipeError,
|
|
||||||
ConnectionAbortedError,
|
|
||||||
ConnectionResetError,
|
|
||||||
ConnectionClosed,
|
|
||||||
)):
|
|
||||||
return True
|
|
||||||
exc = exc.__cause__ or exc.__context__
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class _WebSocketHandshakeNoiseFilter(logging.Filter):
|
|
||||||
"""Suppress noisy restart-time handshakes where the client already disconnected."""
|
|
||||||
|
|
||||||
def filter(self, record: logging.LogRecord) -> bool:
|
|
||||||
if record.getMessage() != _OPENING_HANDSHAKE_FAILED_MESSAGE:
|
|
||||||
return True
|
|
||||||
exc_info = record.exc_info
|
|
||||||
exc = exc_info[1] if isinstance(exc_info, tuple) and len(exc_info) >= 2 else None
|
|
||||||
return not _exception_chain_has_disconnect(exc)
|
|
||||||
|
|
||||||
|
|
||||||
def _websockets_server_logger() -> logging.Logger:
|
|
||||||
ws_logger = logging.getLogger("websockets.server")
|
|
||||||
if not any(isinstance(f, _WebSocketHandshakeNoiseFilter) for f in ws_logger.filters):
|
|
||||||
ws_logger.addFilter(_WebSocketHandshakeNoiseFilter())
|
|
||||||
return ws_logger
|
|
||||||
|
|
||||||
|
|
||||||
class WebSocketConfig(Base):
|
class WebSocketConfig(Base):
|
||||||
"""WebSocket server channel configuration.
|
"""WebSocket server channel configuration.
|
||||||
|
|
||||||
@ -1016,7 +980,7 @@ class WebSocketChannel(BaseChannel):
|
|||||||
# client can render previews. The raw on-disk ``media`` paths are
|
# client can render previews. The raw on-disk ``media`` paths are
|
||||||
# stripped on the way out — they leak server filesystem layout and
|
# stripped on the way out — they leak server filesystem layout and
|
||||||
# the client never needs them once it has the signed fetch URL.
|
# the client never needs them once it has the signed fetch URL.
|
||||||
self._augment_media_urls(data)
|
attach_signed_media_urls(data, sign_path=self._sign_media_path)
|
||||||
return _http_json_response(data)
|
return _http_json_response(data)
|
||||||
|
|
||||||
def _handle_webui_thread_get(self, request: WsRequest, key: str) -> Response:
|
def _handle_webui_thread_get(self, request: WsRequest, key: str) -> Response:
|
||||||
@ -1028,10 +992,14 @@ class WebSocketChannel(BaseChannel):
|
|||||||
if not self._is_websocket_channel_session_key(decoded_key):
|
if not self._is_websocket_channel_session_key(decoded_key):
|
||||||
return _http_error(404, "session not found")
|
return _http_error(404, "session not found")
|
||||||
scope = self._webui_workspaces.scope_for_session_key(decoded_key)
|
scope = self._webui_workspaces.scope_for_session_key(decoded_key)
|
||||||
|
augment_media = partial(
|
||||||
|
signed_media_attachments,
|
||||||
|
sign_path=self._sign_or_stage_media_path,
|
||||||
|
)
|
||||||
data = build_webui_thread_response(
|
data = build_webui_thread_response(
|
||||||
decoded_key,
|
decoded_key,
|
||||||
augment_user_media=self._augment_transcript_media,
|
augment_user_media=augment_media,
|
||||||
augment_assistant_media=self._augment_transcript_media,
|
augment_assistant_media=augment_media,
|
||||||
augment_assistant_text=lambda text: rewrite_local_markdown_images(
|
augment_assistant_text=lambda text: rewrite_local_markdown_images(
|
||||||
text,
|
text,
|
||||||
workspace_path=scope.project_path,
|
workspace_path=scope.project_path,
|
||||||
@ -1051,25 +1019,6 @@ class WebSocketChannel(BaseChannel):
|
|||||||
except (ValueError, TypeError) as e:
|
except (ValueError, TypeError) as e:
|
||||||
self.logger.warning("webui transcript append failed: {}", e)
|
self.logger.warning("webui transcript append failed: {}", e)
|
||||||
|
|
||||||
def _augment_transcript_media(self, paths: list[str]) -> list[dict[str, Any]]:
|
|
||||||
out: list[dict[str, Any]] = []
|
|
||||||
for pstr in paths:
|
|
||||||
path = Path(pstr)
|
|
||||||
att = self._sign_or_stage_media_path(path)
|
|
||||||
if att is None:
|
|
||||||
continue
|
|
||||||
mime, _ = mimetypes.guess_type(path.name)
|
|
||||||
if mime and mime.startswith("video/"):
|
|
||||||
kind = "video"
|
|
||||||
elif mime and mime.startswith("image/"):
|
|
||||||
kind = "image"
|
|
||||||
else:
|
|
||||||
kind = "file"
|
|
||||||
out.append(
|
|
||||||
{"kind": kind, "url": att["url"], "name": att.get("name", path.name)},
|
|
||||||
)
|
|
||||||
return out
|
|
||||||
|
|
||||||
async def _handle_message(
|
async def _handle_message(
|
||||||
self,
|
self,
|
||||||
sender_id: str,
|
sender_id: str,
|
||||||
@ -1106,37 +1055,6 @@ class WebSocketChannel(BaseChannel):
|
|||||||
is_dm,
|
is_dm,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _augment_media_urls(self, payload: dict[str, Any]) -> None:
|
|
||||||
"""Mutate *payload* in place: each message's ``media`` path list is
|
|
||||||
replaced by a parallel ``media_urls`` list of signed fetch URLs.
|
|
||||||
|
|
||||||
Messages without media or with non-string path entries are left
|
|
||||||
untouched. Paths that no longer live inside ``media_dir`` (e.g. the
|
|
||||||
file was deleted, or the dir was relocated) are silently skipped;
|
|
||||||
the client falls back to the historical-replay placeholder tile.
|
|
||||||
"""
|
|
||||||
messages = payload.get("messages")
|
|
||||||
if not isinstance(messages, list):
|
|
||||||
return
|
|
||||||
for msg in messages:
|
|
||||||
if not isinstance(msg, dict):
|
|
||||||
continue
|
|
||||||
media = msg.get("media")
|
|
||||||
if not isinstance(media, list) or not media:
|
|
||||||
continue
|
|
||||||
urls: list[dict[str, str]] = []
|
|
||||||
for entry in media:
|
|
||||||
if not isinstance(entry, str) or not entry:
|
|
||||||
continue
|
|
||||||
signed = self._sign_media_path(Path(entry))
|
|
||||||
if signed is None:
|
|
||||||
continue
|
|
||||||
urls.append({"url": signed, "name": Path(entry).name})
|
|
||||||
if urls:
|
|
||||||
msg["media_urls"] = urls
|
|
||||||
# Always drop the raw paths from the wire payload.
|
|
||||||
msg.pop("media", None)
|
|
||||||
|
|
||||||
def _sign_media_path(self, abs_path: Path) -> str | None:
|
def _sign_media_path(self, abs_path: Path) -> str | None:
|
||||||
"""Return a ``/api/media/<sig>/<payload>`` URL for *abs_path*, or
|
"""Return a ``/api/media/<sig>/<payload>`` URL for *abs_path*, or
|
||||||
``None`` when the path does not resolve inside the media root.
|
``None`` when the path does not resolve inside the media root.
|
||||||
@ -1279,7 +1197,7 @@ class WebSocketChannel(BaseChannel):
|
|||||||
from nanobot.utils.logging_bridge import redirect_lib_logging
|
from nanobot.utils.logging_bridge import redirect_lib_logging
|
||||||
|
|
||||||
redirect_lib_logging("websockets", level="WARNING")
|
redirect_lib_logging("websockets", level="WARNING")
|
||||||
ws_logger = _websockets_server_logger()
|
ws_logger = websockets_server_logger()
|
||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
self._stop_event = asyncio.Event()
|
self._stop_event = asyncio.Event()
|
||||||
|
|||||||
@ -24,6 +24,8 @@ from nanobot.config.paths import get_media_dir
|
|||||||
from nanobot.utils.helpers import safe_filename
|
from nanobot.utils.helpers import safe_filename
|
||||||
|
|
||||||
MediaDirProvider = Callable[[str | None], Path]
|
MediaDirProvider = Callable[[str | None], Path]
|
||||||
|
SignedMediaPath = Callable[[Path], dict[str, str] | None]
|
||||||
|
SignedMediaUrl = Callable[[Path], str | None]
|
||||||
|
|
||||||
|
|
||||||
def b64url_encode(data: bytes) -> str:
|
def b64url_encode(data: bytes) -> str:
|
||||||
@ -172,6 +174,64 @@ def sign_or_stage_media_path(
|
|||||||
return {"url": signed, "name": path.name}
|
return {"url": signed, "name": path.name}
|
||||||
|
|
||||||
|
|
||||||
|
def media_attachment_kind(name: str) -> str:
|
||||||
|
"""Infer the WebUI media attachment kind from a filename."""
|
||||||
|
mime, _ = mimetypes.guess_type(name)
|
||||||
|
if mime and mime.startswith("video/"):
|
||||||
|
return "video"
|
||||||
|
if mime and mime.startswith("image/"):
|
||||||
|
return "image"
|
||||||
|
return "file"
|
||||||
|
|
||||||
|
|
||||||
|
def signed_media_attachments(
|
||||||
|
paths: list[str],
|
||||||
|
*,
|
||||||
|
sign_path: SignedMediaPath,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Map persisted media paths to WebUI attachment dicts with fresh signed URLs."""
|
||||||
|
out: list[dict[str, Any]] = []
|
||||||
|
for pstr in paths:
|
||||||
|
path = Path(pstr)
|
||||||
|
att = sign_path(path)
|
||||||
|
if att is None:
|
||||||
|
continue
|
||||||
|
url = att.get("url")
|
||||||
|
if not url:
|
||||||
|
continue
|
||||||
|
name = att.get("name") or path.name
|
||||||
|
out.append({"kind": media_attachment_kind(name), "url": url, "name": name})
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def attach_signed_media_urls(
|
||||||
|
payload: dict[str, Any],
|
||||||
|
*,
|
||||||
|
sign_path: SignedMediaUrl,
|
||||||
|
) -> None:
|
||||||
|
"""Replace raw media path lists in a WebUI session payload with signed URLs."""
|
||||||
|
messages = payload.get("messages")
|
||||||
|
if not isinstance(messages, list):
|
||||||
|
return
|
||||||
|
for msg in messages:
|
||||||
|
if not isinstance(msg, dict):
|
||||||
|
continue
|
||||||
|
media = msg.get("media")
|
||||||
|
if not isinstance(media, list) or not media:
|
||||||
|
continue
|
||||||
|
urls: list[dict[str, str]] = []
|
||||||
|
for entry in media:
|
||||||
|
if not isinstance(entry, str) or not entry:
|
||||||
|
continue
|
||||||
|
signed = sign_path(Path(entry))
|
||||||
|
if signed is None:
|
||||||
|
continue
|
||||||
|
urls.append({"url": signed, "name": Path(entry).name})
|
||||||
|
if urls:
|
||||||
|
msg["media_urls"] = urls
|
||||||
|
msg.pop("media", None)
|
||||||
|
|
||||||
|
|
||||||
def serve_signed_media(
|
def serve_signed_media(
|
||||||
sig: str,
|
sig: str,
|
||||||
payload: str,
|
payload: str,
|
||||||
|
|||||||
45
nanobot/webui/websocket_logging.py
Normal file
45
nanobot/webui/websocket_logging.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
"""Logging helpers for the WebUI WebSocket server surface."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from websockets.exceptions import ConnectionClosed
|
||||||
|
|
||||||
|
OPENING_HANDSHAKE_FAILED_MESSAGE = "opening handshake failed"
|
||||||
|
|
||||||
|
|
||||||
|
def _exception_chain_has_disconnect(exc: BaseException | None) -> bool:
|
||||||
|
seen: set[int] = set()
|
||||||
|
while exc is not None:
|
||||||
|
ident = id(exc)
|
||||||
|
if ident in seen:
|
||||||
|
return False
|
||||||
|
seen.add(ident)
|
||||||
|
if isinstance(exc, (
|
||||||
|
BrokenPipeError,
|
||||||
|
ConnectionAbortedError,
|
||||||
|
ConnectionResetError,
|
||||||
|
ConnectionClosed,
|
||||||
|
)):
|
||||||
|
return True
|
||||||
|
exc = exc.__cause__ or exc.__context__
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketHandshakeNoiseFilter(logging.Filter):
|
||||||
|
"""Suppress restart-time handshakes where the browser already disconnected."""
|
||||||
|
|
||||||
|
def filter(self, record: logging.LogRecord) -> bool:
|
||||||
|
if record.getMessage() != OPENING_HANDSHAKE_FAILED_MESSAGE:
|
||||||
|
return True
|
||||||
|
exc_info = record.exc_info
|
||||||
|
exc = exc_info[1] if isinstance(exc_info, tuple) and len(exc_info) >= 2 else None
|
||||||
|
return not _exception_chain_has_disconnect(exc)
|
||||||
|
|
||||||
|
|
||||||
|
def websockets_server_logger() -> logging.Logger:
|
||||||
|
ws_logger = logging.getLogger("websockets.server")
|
||||||
|
if not any(isinstance(f, WebSocketHandshakeNoiseFilter) for f in ws_logger.filters):
|
||||||
|
ws_logger.addFilter(WebSocketHandshakeNoiseFilter())
|
||||||
|
return ws_logger
|
||||||
@ -3,7 +3,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
@ -17,7 +16,6 @@ from websockets.frames import Close
|
|||||||
from nanobot.bus.events import OUTBOUND_META_AGENT_UI, OutboundMessage
|
from nanobot.bus.events import OUTBOUND_META_AGENT_UI, OutboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.websocket import (
|
from nanobot.channels.websocket import (
|
||||||
_OPENING_HANDSHAKE_FAILED_MESSAGE,
|
|
||||||
WebSocketChannel,
|
WebSocketChannel,
|
||||||
WebSocketConfig,
|
WebSocketConfig,
|
||||||
_is_valid_chat_id,
|
_is_valid_chat_id,
|
||||||
@ -28,7 +26,6 @@ from nanobot.channels.websocket import (
|
|||||||
_parse_inbound_payload,
|
_parse_inbound_payload,
|
||||||
_parse_query,
|
_parse_query,
|
||||||
_parse_request_path,
|
_parse_request_path,
|
||||||
_WebSocketHandshakeNoiseFilter,
|
|
||||||
publish_runtime_model_update,
|
publish_runtime_model_update,
|
||||||
)
|
)
|
||||||
from nanobot.config.loader import load_config, save_config
|
from nanobot.config.loader import load_config, save_config
|
||||||
@ -42,18 +39,6 @@ from nanobot.webui.settings_api import settings_payload, update_provider_setting
|
|||||||
_PORT = 29876
|
_PORT = 29876
|
||||||
|
|
||||||
|
|
||||||
def _log_record(message: str, exc: BaseException) -> logging.LogRecord:
|
|
||||||
return logging.LogRecord(
|
|
||||||
name="websockets.server",
|
|
||||||
level=logging.ERROR,
|
|
||||||
pathname=__file__,
|
|
||||||
lineno=1,
|
|
||||||
msg=message,
|
|
||||||
args=(),
|
|
||||||
exc_info=(type(exc), exc, exc.__traceback__),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _ch(bus: Any, **kw: Any) -> WebSocketChannel:
|
def _ch(bus: Any, **kw: Any) -> WebSocketChannel:
|
||||||
cfg: dict[str, Any] = {
|
cfg: dict[str, Any] = {
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
@ -128,22 +113,6 @@ def test_websocket_config_rejects_relative_unix_socket() -> None:
|
|||||||
WebSocketConfig(unix_socket_path="engine.sock")
|
WebSocketConfig(unix_socket_path="engine.sock")
|
||||||
|
|
||||||
|
|
||||||
def test_websocket_handshake_noise_filter_suppresses_disconnects() -> None:
|
|
||||||
filter_ = _WebSocketHandshakeNoiseFilter()
|
|
||||||
wrapped = RuntimeError("wrapped")
|
|
||||||
wrapped.__cause__ = BrokenPipeError(32, "Broken pipe")
|
|
||||||
|
|
||||||
assert not filter_.filter(_log_record(_OPENING_HANDSHAKE_FAILED_MESSAGE, BrokenPipeError()))
|
|
||||||
assert not filter_.filter(_log_record(_OPENING_HANDSHAKE_FAILED_MESSAGE, wrapped))
|
|
||||||
|
|
||||||
|
|
||||||
def test_websocket_handshake_noise_filter_keeps_real_errors() -> None:
|
|
||||||
filter_ = _WebSocketHandshakeNoiseFilter()
|
|
||||||
|
|
||||||
assert filter_.filter(_log_record(_OPENING_HANDSHAKE_FAILED_MESSAGE, RuntimeError("boom")))
|
|
||||||
assert filter_.filter(_log_record("connection handler failed", BrokenPipeError()))
|
|
||||||
|
|
||||||
|
|
||||||
def test_parse_query_extracts_token_and_client_id() -> None:
|
def test_parse_query_extracts_token_and_client_id() -> None:
|
||||||
query = _parse_query("/?token=secret&client_id=u1")
|
query = _parse_query("/?token=secret&client_id=u1")
|
||||||
assert query.get("token") == ["secret"]
|
assert query.get("token") == ["secret"]
|
||||||
|
|||||||
38
tests/utils/test_webui_websocket_logging.py
Normal file
38
tests/utils/test_webui_websocket_logging.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
"""Tests for WebUI websocket logging helpers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from nanobot.webui.websocket_logging import (
|
||||||
|
OPENING_HANDSHAKE_FAILED_MESSAGE,
|
||||||
|
WebSocketHandshakeNoiseFilter,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_record(message: str, exc: BaseException) -> logging.LogRecord:
|
||||||
|
return logging.LogRecord(
|
||||||
|
name="websockets.server",
|
||||||
|
level=logging.ERROR,
|
||||||
|
pathname=__file__,
|
||||||
|
lineno=1,
|
||||||
|
msg=message,
|
||||||
|
args=(),
|
||||||
|
exc_info=(type(exc), exc, exc.__traceback__),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_websocket_handshake_noise_filter_suppresses_disconnects() -> None:
|
||||||
|
filter_ = WebSocketHandshakeNoiseFilter()
|
||||||
|
wrapped = RuntimeError("wrapped")
|
||||||
|
wrapped.__cause__ = BrokenPipeError(32, "Broken pipe")
|
||||||
|
|
||||||
|
assert not filter_.filter(_log_record(OPENING_HANDSHAKE_FAILED_MESSAGE, BrokenPipeError()))
|
||||||
|
assert not filter_.filter(_log_record(OPENING_HANDSHAKE_FAILED_MESSAGE, wrapped))
|
||||||
|
|
||||||
|
|
||||||
|
def test_websocket_handshake_noise_filter_keeps_real_errors() -> None:
|
||||||
|
filter_ = WebSocketHandshakeNoiseFilter()
|
||||||
|
|
||||||
|
assert filter_.filter(_log_record(OPENING_HANDSHAKE_FAILED_MESSAGE, RuntimeError("boom")))
|
||||||
|
assert filter_.filter(_log_record("connection handler failed", BrokenPipeError()))
|
||||||
Loading…
x
Reference in New Issue
Block a user