From 1a585288b2ec5f67f4f540f9ae0b1c95606b0a0a Mon Sep 17 00:00:00 2001 From: chengyongru <2755839590@qq.com> Date: Sun, 31 May 2026 12:55:23 +0800 Subject: [PATCH] refactor: extract GatewayHTTPHandler from WebSocketChannel Extract all HTTP route handling (bootstrap, sessions, settings, media, commands, sidebar state, static serving, token management) into a new GatewayHTTPHandler class in nanobot/channels/ws_http.py. WebSocketChannel is reduced from 1907 to 1372 lines (-28%), retaining only WebSocket connection management and message dispatch. No behavior change. 3730 tests pass, 0 failures. Shared HTTP utility functions (path parsing, response builders, auth helpers) now live in ws_http.py with websocket.py importing from there, avoiding circular dependencies. Backwards-compat property aliases on WebSocketChannel ensure existing tests continue to work without modification. --- nanobot/channels/websocket.py | 707 +++--------------- nanobot/channels/ws_http.py | 731 +++++++++++++++++++ tests/channels/test_websocket_channel.py | 3 + tests/channels/test_websocket_http_routes.py | 14 +- tests/channels/test_websocket_media_route.py | 32 +- 5 files changed, 859 insertions(+), 628 deletions(-) create mode 100644 nanobot/channels/ws_http.py diff --git a/nanobot/channels/websocket.py b/nanobot/channels/websocket.py index 038aa7a22..2316d0613 100644 --- a/nanobot/channels/websocket.py +++ b/nanobot/channels/websocket.py @@ -7,11 +7,8 @@ import email.utils import hmac import http import json -import mimetypes import re -import secrets import ssl -import time import uuid from collections.abc import Callable from contextlib import suppress @@ -31,7 +28,7 @@ from websockets.http11 import Response from nanobot.bus.events import OUTBOUND_META_AGENT_UI, OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel -from nanobot.command.builtin import builtin_command_palette +from nanobot.channels.ws_http import GatewayHTTPHandler from 
nanobot.config.paths import get_media_dir, get_workspace_path from nanobot.config.schema import Base from nanobot.security.workspace_access import ( @@ -44,32 +41,9 @@ from nanobot.utils.media_decode import ( FileSizeExceeded, save_base64_data_url, ) -from nanobot.utils.subagent_channel_display import scrub_subagent_messages_for_channel from nanobot.webui.cli_apps_api import normalize_cli_app_mentions from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions -from nanobot.webui.media_api import ( - attach_signed_media_urls, - serve_signed_media, - sign_media_path, - sign_or_stage_media_path, - signed_media_attachments, -) -from nanobot.webui.settings_api import runtime_capabilities -from nanobot.webui.settings_routes import WebUISettingsRouter -from nanobot.webui.sidebar_state import ( - read_webui_sidebar_state, - write_webui_sidebar_state, -) -from nanobot.webui.thread_disk import delete_webui_thread -from nanobot.webui.transcript import ( - append_transcript_object, - build_webui_thread_response, - rewrite_local_markdown_images, -) -from nanobot.webui.websocket_logging import websockets_server_logger -from nanobot.webui.workspaces import ( - WebUIWorkspaceController, -) +from nanobot.webui.transcript import append_transcript_object if TYPE_CHECKING: from nanobot.session.manager import SessionManager @@ -500,51 +474,36 @@ class WebSocketChannel(BaseChannel): self._conn_chats: dict[Any, set[str]] = {} # connection -> default chat_id for legacy frames that omit routing. self._conn_default: dict[Any, str] = {} - # Single-use tokens consumed at WebSocket handshake. - self._issued_tokens: dict[str, float] = {} - # Multi-use tokens for HTTP routes served beside WS; checked but not consumed. 
- self._api_tokens: dict[str, float] = {} self._stop_event: asyncio.Event | None = None self._server_task: asyncio.Task[None] | None = None - self._session_manager = session_manager - self._static_dist_path: Path | None = ( - static_dist_path.resolve() if static_dist_path is not None else None - ) - self._workspace_path = ( + _resolved_workspace = ( Path(workspace_path).expanduser() if workspace_path is not None else get_workspace_path() ).resolve(strict=False) self._default_restrict_to_workspace = restrict_to_workspace - self._webui_workspaces = WebUIWorkspaceController( - session_manager=self._session_manager, - default_workspace=self._workspace_path, - default_restrict_to_workspace=self._default_restrict_to_workspace, - ) - self._runtime_model_name = runtime_model_name self._runtime_surface = ( "native" if runtime_surface in {"native", "desktop"} else "browser" ) - self._runtime_capabilities = runtime_capabilities( - self._runtime_surface, - runtime_capabilities_overrides, - ) - self._settings_routes = WebUISettingsRouter( - bus=self.bus, - logger=self.logger, - check_api_token=self._check_api_token, - parse_query=_parse_query, - json_response=_http_json_response, - error_response=_http_error, + + # HTTP API handler — owns tokens, sessions, media, settings, static serving + self._http = GatewayHTTPHandler( + config=self.config, + session_manager=session_manager, + static_dist_path=( + static_dist_path.resolve() if static_dist_path is not None else None + ), + workspace_path=_resolved_workspace, + runtime_model_name=runtime_model_name, runtime_surface=self._runtime_surface, - runtime_capabilities=self._runtime_capabilities, + runtime_capabilities_overrides=runtime_capabilities_overrides, + bus=self.bus, + log=self.logger, ) + # Backwards-compat aliases for workspace controller used in envelope dispatch + self._webui_workspaces = self._http.workspaces + self._stream_text_buffers: dict[tuple[str, str], list[str]] = {} - # Process-local secret used to HMAC-sign 
media URLs. The signed URL is - # the capability — anyone who holds a valid URL can fetch that one - # file, nothing else. The secret regenerates on restart so links - # become self-expiring (callers just refresh the session list). - self._media_secret: bytes = secrets.token_bytes(32) # -- Subscription bookkeeping ------------------------------------------- @@ -572,9 +531,9 @@ class WebSocketChannel(BaseChannel): connected clients normally see it via ``goal_state`` / ``turn_end`` frames. Pushing here makes refresh + reconnect restore the strip without a new model turn. """ - if self._session_manager is None: + if self._http.session_manager is None: return - row = self._session_manager.read_session_file(f"websocket:{chat_id}") + row = self._http.session_manager.read_session_file(f"websocket:{chat_id}") meta = row.get("metadata", {}) if isinstance(row, dict) else {} if not isinstance(meta, dict): meta = {} @@ -614,6 +573,59 @@ class WebSocketChannel(BaseChannel): def _expected_path(self) -> str: return _normalize_config_path(self.config.path) + # -- Backwards-compat property aliases (used by tests) ------------------ + + @property + def _session_manager(self): + return self._http.session_manager + + @_session_manager.setter + def _session_manager(self, value): + self._http.session_manager = value + + @property + def _media_secret(self): + return self._http.media_secret + + @property + def _issued_tokens(self): + return self._http.issued_tokens + + @_issued_tokens.setter + def _issued_tokens(self, value): + self._http.issued_tokens = value + + @property + def _api_tokens(self): + return self._http.api_tokens + + def _check_api_token(self, request): + return self._http.check_api_token(request) + + def _sign_media_path(self, path): + return self._http.sign_media_path(path) + + def _sign_or_stage_media_path(self, path): + return self._http.sign_or_stage_media_path(path) + + def _rewrite_local_markdown_images(self, text): + return 
self._http.rewrite_local_markdown_images(text) + + def _handle_bootstrap(self, connection, request): + return self._http._handle_bootstrap(connection, request) + + _MAX_ISSUED_TOKENS = GatewayHTTPHandler._MAX_ISSUED_TOKENS + + def _handle_sessions_list(self, request): + return self._http._handle_sessions_list(request) + + def _handle_webui_thread_get(self, request, key): + return self._http._handle_webui_thread_get(request, key) + + @property + def _workspace_path(self): + return self._http.workspace_path + def _build_ssl_context(self) -> ssl.SSLContext | None: cert = self.config.ssl_certfile.strip() key = self.config.ssl_keyfile.strip() @@ -628,548 +640,24 @@ class WebSocketChannel(BaseChannel): ctx.load_cert_chain(certfile=cert, keyfile=key) return ctx - _MAX_ISSUED_TOKENS = 10_000 - - def _purge_expired_issued_tokens(self) -> None: - now = time.monotonic() - for token_key, expiry in list(self._issued_tokens.items()): - if now > expiry: - self._issued_tokens.pop(token_key, None) - - def _take_issued_token_if_valid(self, token_value: str | None) -> bool: - """Validate and consume one issued token (single use per connection attempt). - - Uses single-step pop to minimize the window between lookup and removal; - safe under asyncio's single-threaded cooperative model. - """ - if not token_value: - return False - self._purge_expired_issued_tokens() - expiry = self._issued_tokens.pop(token_value, None) - if expiry is None: - return False - if time.monotonic() > expiry: - return False - return True - - def _handle_token_issue_http(self, connection: Any, request: Any) -> Any: - secret = self.config.token_issue_secret.strip() or self.config.token.strip() - if secret: - if not _issue_route_secret_matches(request.headers, secret): - return connection.respond(401, "Unauthorized") - else: - self.logger.warning( - "token_issue_path is set but no token_issue_secret or static token is configured; " - "any client can obtain connection tokens — set a secret for production." 
- ) - self._purge_expired_issued_tokens() - if len(self._issued_tokens) >= self._MAX_ISSUED_TOKENS: - self.logger.error( - "too many outstanding issued tokens ({}), rejecting issuance", - len(self._issued_tokens), - ) - return _http_json_response({"error": "too many outstanding tokens"}, status=429) - token_value = f"nbwt_{secrets.token_urlsafe(32)}" - self._issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s) - - return _http_json_response( - {"token": token_value, "expires_in": self.config.token_ttl_s} - ) - # -- HTTP dispatch ------------------------------------------------------ async def _dispatch_http(self, connection: Any, request: WsRequest) -> Any: - """Route an inbound HTTP request to a handler or to the WS upgrade path.""" + """Route an inbound HTTP request to the HTTP handler or WS upgrade.""" got, query = _parse_request_path(request.path) - if self.config.token_issue_path: - issue_expected = _normalize_config_path(self.config.token_issue_path) - if got == issue_expected: - return self._handle_token_issue_http(connection, request) - - if got == "/webui/bootstrap": - return self._handle_bootstrap(connection, request) - - api_response = await self._dispatch_api_route(connection, request, got) - if api_response is not None: - return api_response - - ws_matched, ws_response = self._dispatch_websocket_upgrade( - connection, request, got, query - ) - if ws_matched: - return ws_response - - # API clients should never receive the SPA shell for an unknown route. - # Returning HTML here makes the WebUI fail with "Unexpected token <" - # when a dev server is pointed at an older gateway. 
- if got.startswith("/api/"): - return _http_error(404, "API route not found") - - if self._static_dist_path is not None: - response = self._serve_static(got) - if response is not None: - return response - - return connection.respond(404, "Not Found") - - async def _dispatch_api_route( - self, - connection: Any, - request: WsRequest, - got: str, - ) -> Any | None: - """Route REST-ish WebUI requests served beside the WebSocket endpoint.""" - response = await self._dispatch_settings_api_route(request, got) - if response is not None: - return response - response = self._dispatch_session_api_route(request, got) - if response is not None: - return response - response = self._dispatch_media_api_route(request, got) - if response is not None: - return response - return self._dispatch_misc_api_route(connection, request, got) - - def _dispatch_misc_api_route( - self, - connection: Any, - request: WsRequest, - got: str, - ) -> Response | None: - """Route small API endpoints that do not belong to a larger route group.""" - if got == "/api/sessions": - return self._handle_sessions_list(request) - - if got == "/api/commands": - return self._handle_commands(request) - - if got == "/api/workspaces": - return self._handle_workspaces(connection, request) - - if got == "/api/webui/sidebar-state": - return self._handle_webui_sidebar_state(request) - - if got == "/api/webui/sidebar-state/update": - return self._handle_webui_sidebar_state_update(request) - - return None - - async def _dispatch_settings_api_route( - self, - request: WsRequest, - got: str, - ) -> Response | None: - return await self._settings_routes.dispatch(request, got) - - def _dispatch_session_api_route( - self, - request: WsRequest, - got: str, - ) -> Response | None: - m = re.match(r"^/api/sessions/([^/]+)/messages$", got) - if m: - return self._handle_session_messages(request, m.group(1)) - - m = re.match(r"^/api/sessions/([^/]+)/webui-thread$", got) - if m: - return self._handle_webui_thread_get(request, 
m.group(1)) - - # NOTE: websockets' HTTP parser only accepts GET, so we cannot expose a - # true ``DELETE`` verb. The action is folded into the path instead. - m = re.match(r"^/api/sessions/([^/]+)/delete$", got) - if m: - return self._handle_session_delete(request, m.group(1)) - - return None - - def _dispatch_media_api_route( - self, - request: WsRequest, - got: str, - ) -> Response | None: - m = re.match(r"^/api/media/([A-Za-z0-9_-]+)/([A-Za-z0-9_-]+)$", got) - if m: - return self._handle_media_fetch(m.group(1), m.group(2), request) - - return None - - def _dispatch_websocket_upgrade( - self, - connection: Any, - request: WsRequest, - got: str, - query: dict[str, list[str]], - ) -> tuple[bool, Any | None]: - """Authorize only real WS upgrade requests for the configured path.""" + # WebSocket upgrade — channel handles this itself expected_ws = self._expected_path() - if got != expected_ws or not _is_websocket_upgrade(request): - return False, None - client_id = _query_first(query, "client_id") or "" - if len(client_id) > 128: - client_id = client_id[:128] - if not self.is_allowed(client_id): - return True, connection.respond(403, "Forbidden") - return True, self._authorize_websocket_handshake(connection, query) + if got == expected_ws and _is_websocket_upgrade(request): + client_id = _query_first(query, "client_id") or "" + if len(client_id) > 128: + client_id = client_id[:128] + if not self.is_allowed(client_id): + return connection.respond(403, "Forbidden") + return self._authorize_websocket_handshake(connection, query) - # -- HTTP route handlers ------------------------------------------------ - - def _check_api_token(self, request: WsRequest) -> bool: - """Validate a request against the API token pool (multi-use, TTL-bound).""" - self._purge_expired_api_tokens() - token = _bearer_token(request.headers) or _query_first( - _parse_query(request.path), "token" - ) - if not token: - return False - expiry = self._api_tokens.get(token) - if expiry is None or 
time.monotonic() > expiry: - self._api_tokens.pop(token, None) - return False - return True - - def _purge_expired_api_tokens(self) -> None: - now = time.monotonic() - for token_key, expiry in list(self._api_tokens.items()): - if now > expiry: - self._api_tokens.pop(token_key, None) - - def _handle_bootstrap(self, connection: Any, request: Any) -> Response: - # When a secret is configured (token_issue_secret or static token), - # validate it regardless of source IP. This secures deployments - # behind a reverse proxy where all connections appear as localhost. - secret = self.config.token_issue_secret.strip() or self.config.token.strip() - if secret: - if not _issue_route_secret_matches(request.headers, secret): - return _http_error(401, "Unauthorized") - elif not _is_localhost(connection): - # No secret configured: only allow localhost (local dev mode). - return _http_error(403, "bootstrap is localhost-only") - # Cap outstanding tokens to avoid runaway growth from a misbehaving client. - self._purge_expired_issued_tokens() - self._purge_expired_api_tokens() - if ( - len(self._issued_tokens) >= self._MAX_ISSUED_TOKENS - or len(self._api_tokens) >= self._MAX_ISSUED_TOKENS - ): - return _http_response( - json.dumps({"error": "too many outstanding tokens"}).encode("utf-8"), - status=429, - content_type="application/json; charset=utf-8", - ) - token = f"nbwt_{secrets.token_urlsafe(32)}" - expiry = time.monotonic() + float(self.config.token_ttl_s) - # Same string registered in both pools: the WS handshake consumes one copy - # while the REST surface keeps validating the other until TTL expiry. 
- self._issued_tokens[token] = expiry - self._api_tokens[token] = expiry - ws_url = self._bootstrap_ws_url(request) - return _http_json_response( - { - "token": token, - "ws_path": self._expected_path(), - "ws_url": ws_url, - "expires_in": self.config.token_ttl_s, - "model_name": _resolve_bootstrap_model_name(self._runtime_model_name), - "runtime_surface": self._runtime_surface, - "runtime_capabilities": self._runtime_capabilities, - } - ) - - def _bootstrap_ws_url(self, request: Any) -> str: - """Absolute WS URL clients should prefer over a dev-server proxy.""" - headers = getattr(request, "headers", {}) or {} - host = _safe_host_header(_case_insensitive_header(headers, "Host")) - if not host: - host = _host_for_url(self.config.host, self.config.port) - - proto = _case_insensitive_header(headers, "X-Forwarded-Proto") - proto = proto.split(",", 1)[0].strip().lower() - secure = proto in {"https", "wss"} or bool(self.config.ssl_certfile.strip()) - scheme = "wss" if secure else "ws" - return f"{scheme}://{host}{self._expected_path()}" - - def _handle_sessions_list(self, request: WsRequest) -> Response: - if not self._check_api_token(request): - return _http_error(401, "Unauthorized") - if self._session_manager is None: - return _http_error(503, "session manager unavailable") - sessions = self._session_manager.list_sessions() - # Sidebar/chat listing for WS-backed sessions only — CLI / Slack / etc. - # keys are not intended for resume over this HTTP surface. 
- cleaned = [] - for s in sessions: - key = s.get("key") - if not (isinstance(key, str) and key.startswith("websocket:")): - continue - row = {k: v for k, v in s.items() if k != "path"} - chat_id = key.split(":", 1)[1] - started_at = websocket_turn_wall_started_at(chat_id) - if started_at is not None: - row["run_started_at"] = started_at - scope = self._webui_workspaces.scope_for_session_key(key) - row["workspace_scope"] = scope.payload() - cleaned.append(row) - return _http_json_response({"sessions": cleaned}) - - def _handle_workspaces(self, connection: Any, request: WsRequest) -> Response: - if not self._check_api_token(request): - return _http_error(401, "Unauthorized") - return _http_json_response( - self._webui_workspaces.payload(controls_available=_is_localhost(connection)) - ) - - def _handle_commands(self, request: WsRequest) -> Response: - if not self._check_api_token(request): - return _http_error(401, "Unauthorized") - return _http_json_response({"commands": builtin_command_palette()}) - - def _handle_webui_sidebar_state(self, request: WsRequest) -> Response: - if not self._check_api_token(request): - return _http_error(401, "Unauthorized") - return _http_json_response(read_webui_sidebar_state()) - - def _handle_webui_sidebar_state_update(self, request: WsRequest) -> Response: - if not self._check_api_token(request): - return _http_error(401, "Unauthorized") - query = _parse_query(request.path) - raw_state = _query_first(query, "state") - if raw_state is None: - return _http_error(400, "missing state") - try: - decoded = json.loads(raw_state) - except json.JSONDecodeError: - return _http_error(400, "state must be JSON") - if not isinstance(decoded, dict): - return _http_error(400, "state must be an object") - try: - state = write_webui_sidebar_state(decoded) - except ValueError as e: - return _http_error(400, str(e)) - except OSError: - self.logger.exception("failed to write webui sidebar state") - return _http_error(500, "failed to write sidebar 
state") - return _http_json_response(state) - - # -- Session replay, transcript, and signed media ---------------------- - - @staticmethod - def _is_websocket_channel_session_key(key: str) -> bool: - """True when *key* is a ``websocket:…`` session exposed on this HTTP surface.""" - return key.startswith("websocket:") - - def _handle_session_messages(self, request: WsRequest, key: str) -> Response: - if not self._check_api_token(request): - return _http_error(401, "Unauthorized") - if self._session_manager is None: - return _http_error(503, "session manager unavailable") - decoded_key = _decode_api_key(key) - if decoded_key is None: - return _http_error(400, "invalid session key") - # Only ``websocket:…`` sessions are listed/served here — same boundary as - # ``/api/sessions``. Block handcrafted URLs from probing CLI / Slack / etc. - if not self._is_websocket_channel_session_key(decoded_key): - return _http_error(404, "session not found") - data = self._session_manager.read_session_file(decoded_key) - if data is None: - return _http_error(404, "session not found") - messages = data.get("messages") - if isinstance(messages, list): - scrub_subagent_messages_for_channel(messages) - # Decorate persisted user messages with signed media URLs so the - # client can render previews. The raw on-disk ``media`` paths are - # stripped on the way out — they leak server filesystem layout and - # the client never needs them once it has the signed fetch URL. 
- attach_signed_media_urls(data, sign_path=self._sign_media_path) - return _http_json_response(data) - - def _handle_webui_thread_get(self, request: WsRequest, key: str) -> Response: - if not self._check_api_token(request): - return _http_error(401, "Unauthorized") - decoded_key = _decode_api_key(key) - if decoded_key is None: - return _http_error(400, "invalid session key") - if not self._is_websocket_channel_session_key(decoded_key): - return _http_error(404, "session not found") - scope = self._webui_workspaces.scope_for_session_key(decoded_key) - augment_media = partial( - signed_media_attachments, - sign_path=self._sign_or_stage_media_path, - ) - data = build_webui_thread_response( - decoded_key, - augment_user_media=augment_media, - augment_assistant_media=augment_media, - augment_assistant_text=lambda text: rewrite_local_markdown_images( - text, - workspace_path=scope.project_path, - sign_path=self._sign_or_stage_media_path, - ), - ) - if data is None: - return _http_error(404, "webui thread not found") - data["workspace_scope"] = scope.payload() - return _http_json_response(data) - - def _try_append_webui_transcript(self, chat_id: str, wire: dict[str, Any]) -> None: - sk = f"websocket:{chat_id}" - try: - dup = json.loads(json.dumps(wire, ensure_ascii=False)) - append_transcript_object(sk, dup) - except (ValueError, TypeError) as e: - self.logger.warning("webui transcript append failed: {}", e) - - async def _handle_message( - self, - sender_id: str, - chat_id: str, - content: str, - media: list[str] | None = None, - metadata: dict[str, Any] | None = None, - session_key: str | None = None, - is_dm: bool = False, - ) -> None: - meta = metadata or {} - if meta.get("webui"): - user_obj: dict[str, Any] = { - "event": "user", - "chat_id": chat_id, - "text": content, - } - if media: - user_obj["media_paths"] = list(media) - cli_apps = meta.get("cli_apps") - if isinstance(cli_apps, list) and cli_apps: - user_obj["cli_apps"] = cli_apps - mcp_presets = 
meta.get("mcp_presets") - if isinstance(mcp_presets, list) and mcp_presets: - user_obj["mcp_presets"] = mcp_presets - self._try_append_webui_transcript(chat_id, user_obj) - await super()._handle_message( - sender_id, - chat_id, - content, - media, - metadata, - session_key, - is_dm, - ) - - def _sign_media_path(self, abs_path: Path) -> str | None: - """Return a ``/api/media/<sig>/<payload>`` URL for *abs_path*, or - ``None`` when the path does not resolve inside the media root. - - The URL is self-authenticating: the signature binds the payload to - this process's ``_media_secret``, so only paths we chose to sign can - be fetched. The returned path is relative to the server origin; the - client joins it against this server's HTTP origin (same host as WS). - """ - return sign_media_path( - abs_path, - secret=self._media_secret, - media_dir=lambda channel=None: get_media_dir(channel), - ) - - def _sign_or_stage_media_path(self, path: Path) -> dict[str, str] | None: - """Return a signed media URL payload for *path*. - - Persisted inbound media already lives under ``get_media_dir`` and can - be signed directly. Outbound bot-generated files may live anywhere on - disk; copy those into the websocket media bucket first so the browser - can fetch them through the existing signed media route without - exposing arbitrary filesystem paths. - """ - return sign_or_stage_media_path( - path, - secret=self._media_secret, - media_dir=lambda channel=None: get_media_dir(channel), - logger=self.logger, - ) - - def _rewrite_local_markdown_images(self, text: str) -> str: - return rewrite_local_markdown_images( - text, - workspace_path=self._workspace_path, - sign_path=self._sign_or_stage_media_path, - ) - - def _handle_media_fetch( - self, sig: str, payload: str, request: WsRequest | None = None - ) -> Response: - """Serve a single media file previously signed via - :meth:`_sign_media_path`. 
Validates the signature, decodes the - payload to a relative path, and streams the file bytes with a - long-lived immutable cache header (the URL already encodes the - file identity, so caches can be aggressive).""" - return serve_signed_media( - sig, - payload, - secret=self._media_secret, - request=request, - media_dir=lambda channel=None: get_media_dir(channel), - ) - - def _handle_session_delete(self, request: WsRequest, key: str) -> Response: - if not self._check_api_token(request): - return _http_error(401, "Unauthorized") - if self._session_manager is None: - return _http_error(503, "session manager unavailable") - decoded_key = _decode_api_key(key) - if decoded_key is None: - return _http_error(400, "invalid session key") - # Same boundary as ``_handle_session_messages``: mutations apply only to - # websocket-channel sessions; deletion unlinks local JSONL — keep scope narrow. - if not self._is_websocket_channel_session_key(decoded_key): - return _http_error(404, "session not found") - deleted = self._session_manager.delete_session(decoded_key) - delete_webui_thread(decoded_key) - return _http_json_response({"deleted": bool(deleted)}) - - # -- Static files and WebSocket handshake ------------------------------ - - def _serve_static(self, request_path: str) -> Response | None: - """Resolve *request_path* against the built SPA directory; SPA fallback to index.html.""" - assert self._static_dist_path is not None - rel = request_path.lstrip("/") - if not rel: - rel = "index.html" - # Reject path-traversal attempts and absolute targets. - if ".." in rel.split("/") or rel.startswith("/"): - return _http_error(403, "Forbidden") - candidate = (self._static_dist_path / rel).resolve() - try: - candidate.relative_to(self._static_dist_path) - except ValueError: - return _http_error(403, "Forbidden") - if not candidate.is_file(): - # SPA history-mode fallback: unknown routes serve index.html so the - # client-side router can render them. 
- index = self._static_dist_path / "index.html" - if index.is_file(): - candidate = index - else: - return None - try: - body = candidate.read_bytes() - except OSError as e: - self.logger.warning("static: failed to read {}: {}", candidate, e) - return _http_error(500, "Internal Server Error") - ctype, _ = mimetypes.guess_type(candidate.name) - if ctype is None: - ctype = "application/octet-stream" - if ctype.startswith("text/") or ctype in {"application/javascript", "application/json"}: - ctype = f"{ctype}; charset=utf-8" - # Hash-named build assets are cache-friendly; index.html must stay fresh. - if candidate.name == "index.html": - cache = "no-cache" - else: - cache = "public, max-age=31536000, immutable" - return _http_response( - body, - status=200, - content_type=ctype, - extra_headers=[("Cache-Control", cache)], - ) + # Everything else goes to the HTTP handler + return await self._http.dispatch(connection, request) def _authorize_websocket_handshake(self, connection: Any, query: dict[str, list[str]]) -> Any: supplied = _query_first(query, "token") @@ -1178,20 +666,21 @@ class WebSocketChannel(BaseChannel): if static_token: if supplied and hmac.compare_digest(supplied, static_token): return None - if supplied and self._take_issued_token_if_valid(supplied): + if supplied and self._http.take_issued_token_if_valid(supplied): return None return connection.respond(401, "Unauthorized") if self.config.websocket_requires_token: - if supplied and self._take_issued_token_if_valid(supplied): + if supplied and self._http.take_issued_token_if_valid(supplied): return None return connection.respond(401, "Unauthorized") if supplied: - self._take_issued_token_if_valid(supplied) + self._http.take_issued_token_if_valid(supplied) return None # -- Server lifecycle and connection ingress --------------------------- + # -- Server lifecycle and connection ingress --------------------------- async def start(self) -> None: from nanobot.utils.logging_bridge import 
redirect_lib_logging @@ -1587,8 +1076,8 @@ class WebSocketChannel(BaseChannel): self._subs.clear() self._conn_chats.clear() self._conn_default.clear() - self._issued_tokens.clear() - self._api_tokens.clear() + self._http.issued_tokens.clear() + self._http.api_tokens.clear() async def _safe_send_to(self, connection: Any, raw: str, *, label: str = "") -> None: """Send a raw frame to one connection, cleaning up on ConnectionClosed.""" @@ -1601,6 +1090,14 @@ class WebSocketChannel(BaseChannel): self.logger.exception("send failed{}", label) raise + def _try_append_webui_transcript(self, chat_id: str, wire: dict[str, Any]) -> None: + sk = f"websocket:{chat_id}" + try: + dup = json.loads(json.dumps(wire, ensure_ascii=False)) + append_transcript_object(sk, dup) + except (ValueError, TypeError) as e: + self.logger.warning("webui transcript append failed: {}", e) + async def send(self, msg: OutboundMessage) -> None: if msg.metadata.get("_runtime_model_updated"): await self.send_runtime_model_updated( @@ -1662,7 +1159,7 @@ class WebSocketChannel(BaseChannel): ) return text = msg.content - wire_text = self._rewrite_local_markdown_images(text) + wire_text = self._http.rewrite_local_markdown_images(text) payload: dict[str, Any] = { "event": "message", "chat_id": msg.chat_id, @@ -1672,7 +1169,7 @@ class WebSocketChannel(BaseChannel): payload["media"] = msg.media urls: list[dict[str, str]] = [] for entry in msg.media: - signed = self._sign_or_stage_media_path(Path(entry)) + signed = self._http.sign_or_stage_media_path(Path(entry)) if signed is not None: urls.append(signed) if urls: @@ -1787,7 +1284,7 @@ class WebSocketChannel(BaseChannel): if delta: buffered.append(delta) full_text = "".join(buffered) - rewritten = self._rewrite_local_markdown_images(full_text) + rewritten = self._http.rewrite_local_markdown_images(full_text) if rewritten != full_text: body["text"] = rewritten else: diff --git a/nanobot/channels/ws_http.py b/nanobot/channels/ws_http.py new file mode 100644 index 
000000000..572ad4c6a --- /dev/null +++ b/nanobot/channels/ws_http.py @@ -0,0 +1,731 @@ +"""HTTP API handler extracted from WebSocketChannel. + +Handles all non-WebSocket HTTP routes: bootstrap, sessions, settings, +media, commands, sidebar state, static file serving, and token management. + +Also houses shared HTTP utility functions used by both this module and +``websocket.py`` to avoid circular imports. +""" + +from __future__ import annotations + +import email.utils +import hmac +import http +import json +import mimetypes +import re +import secrets +import time +from collections.abc import Callable +from pathlib import Path +from typing import TYPE_CHECKING, Any +from urllib.parse import parse_qs, urlparse + +from loguru import logger +from websockets.datastructures import Headers +from websockets.http11 import Request as WsRequest +from websockets.http11 import Response + +from nanobot.command.builtin import builtin_command_palette +from nanobot.config.paths import get_media_dir +from nanobot.utils.subagent_channel_display import scrub_subagent_messages_for_channel +from nanobot.webui.media_api import ( + serve_signed_media, + sign_media_path, + sign_or_stage_media_path, +) +from nanobot.webui.sidebar_state import ( + read_webui_sidebar_state, + write_webui_sidebar_state, +) +from nanobot.webui.thread_disk import delete_webui_thread +from nanobot.webui.transcript import ( + build_webui_thread_response, + rewrite_local_markdown_images, +) + +if TYPE_CHECKING: + from nanobot.bus.queue import MessageBus + from nanobot.session.manager import SessionManager + + +# --------------------------------------------------------------------------- +# Shared HTTP utility functions (imported by websocket.py) +# --------------------------------------------------------------------------- + + +def _strip_trailing_slash(path: str) -> str: + if len(path) > 1 and path.endswith("/"): + return path.rstrip("/") + return path or "/" + + +def _normalize_config_path(path: str) -> str: + 
return _strip_trailing_slash(path) + + +def _case_insensitive_header(headers: Any, key: str) -> str: + """Read a header from websockets/http test stubs without assuming casing.""" + try: + value = headers.get(key) + except Exception: + value = None + if value is None: + try: + value = headers.get(key.lower()) + except Exception: + value = None + return str(value or "").strip() + + +def _safe_host_header(value: str) -> str: + """Return a safe Host header value, or empty when it should not be echoed.""" + value = value.strip() + if not value: + return "" + if re.fullmatch(r"\[[0-9A-Fa-f:.]+\](?::\d{1,5})?", value): + return value + if re.fullmatch(r"[A-Za-z0-9.-]+(?::\d{1,5})?", value): + return value + return "" + + +def _host_for_url(host: str, port: int) -> str: + host = host.strip() + if host in ("0.0.0.0", "::"): + host = "127.0.0.1" + if ":" in host and not host.startswith("["): + host = f"[{host}]" + return f"{host}:{port}" + + +def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response: + body = json.dumps(data, ensure_ascii=False).encode("utf-8") + headers = Headers( + [ + ("Date", email.utils.formatdate(usegmt=True)), + ("Connection", "close"), + ("Content-Length", str(len(body))), + ("Content-Type", "application/json; charset=utf-8"), + ] + ) + reason = http.HTTPStatus(status).phrase + return Response(status, reason, headers, body) + + +def _http_response( + body: bytes, + *, + status: int = 200, + content_type: str = "text/plain; charset=utf-8", + extra_headers: list[tuple[str, str]] | None = None, +) -> Response: + headers = [ + ("Date", email.utils.formatdate(usegmt=True)), + ("Connection", "close"), + ("Content-Length", str(len(body))), + ("Content-Type", content_type), + ] + if extra_headers: + headers.extend(extra_headers) + reason = http.HTTPStatus(status).phrase + return Response(status, reason, Headers(headers), body) + + +def _http_error(status: int, message: str | None = None) -> Response: + body = (message or 
http.HTTPStatus(status).phrase).encode("utf-8") + return _http_response(body, status=status) + + +def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]: + """Parse normalized path and query parameters in one pass.""" + parsed = urlparse("ws://x" + path_with_query) + path = _strip_trailing_slash(parsed.path or "/") + return path, parse_qs(parsed.query, keep_blank_values=True) + + +def _normalize_http_path(path_with_query: str) -> str: + return _parse_request_path(path_with_query)[0] + + +def _parse_query(path_with_query: str) -> dict[str, list[str]]: + return _parse_request_path(path_with_query)[1] + + +def _query_first(query: dict[str, list[str]], key: str) -> str | None: + values = query.get(key) + return values[0] if values else None + + +def _is_localhost(connection: Any) -> bool: + addr = getattr(connection, "remote_address", None) + if not addr: + return False + host = addr[0] if isinstance(addr, tuple) else addr + if not isinstance(host, str): + return False + if host.startswith("::ffff:"): + host = host[7:] + return host in {"127.0.0.1", "::1", "localhost"} + + +def _bearer_token(headers: Any) -> str | None: + auth = headers.get("Authorization") or headers.get("authorization") + if auth and auth.lower().startswith("bearer "): + return auth[7:].strip() or None + return None + + +def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool: + if not configured_secret: + return True + authorization = headers.get("Authorization") or headers.get("authorization") + if authorization and authorization.lower().startswith("bearer "): + supplied = authorization[7:].strip() + return hmac.compare_digest(supplied, configured_secret) + header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth") + if not header_token: + return False + return hmac.compare_digest(header_token.strip(), configured_secret) + + +def _decode_api_key(raw_key: str) -> str | None: + from urllib.parse import unquote + + key = 
unquote(raw_key) + _api_key_re = re.compile(r"^[A-Za-z0-9_:.-]{1,128}$") + if _api_key_re.match(key) is None: + return None + return key + + +def _default_model_name_from_config() -> str | None: + try: + from nanobot.config.loader import load_config + model = load_config().resolve_preset().model.strip() + return model or None + except Exception as e: + logger.debug("bootstrap model_name could not load from config: {}", e) + return None + + +def _resolve_bootstrap_model_name( + runtime_name: Callable[[], str | None] | None, +) -> str | None: + if runtime_name is not None: + try: + raw = runtime_name() + except Exception as e: + logger.debug("bootstrap runtime model resolver failed: {}", e) + else: + if isinstance(raw, str): + stripped = raw.strip() + if stripped: + return stripped + return _default_model_name_from_config() + + +# --------------------------------------------------------------------------- +# GatewayHTTPHandler +# --------------------------------------------------------------------------- + + +class GatewayHTTPHandler: + """Handles all HTTP routes served alongside the WebSocket endpoint. + + Owns token management, session API, media API, static file serving, + and delegates settings routes to ``WebUISettingsRouter``. 
+ """ + + _MAX_ISSUED_TOKENS = 10_000 + + def __init__( + self, + *, + config: Any, # WebSocketConfig + session_manager: SessionManager | None, + static_dist_path: Path | None, + workspace_path: Path, + runtime_model_name: Callable[[], str | None] | None, + runtime_surface: str, + runtime_capabilities_overrides: dict[str, Any] | None, + bus: MessageBus, + log: Any = logger, + ) -> None: + self.config = config + self.session_manager = session_manager + self.static_dist_path = static_dist_path + self.workspace_path = workspace_path + self.runtime_model_name = runtime_model_name + self.bus = bus + self._log = log + self._runtime_surface = runtime_surface + + self.issued_tokens: dict[str, float] = {} + self.api_tokens: dict[str, float] = {} + self.media_secret: bytes = secrets.token_bytes(32) + + # Workspace controller + from nanobot.webui.workspaces import WebUIWorkspaceController + + self.workspaces = WebUIWorkspaceController( + session_manager=session_manager, + default_workspace=workspace_path, + default_restrict_to_workspace=None, + ) + + # Settings router + from nanobot.webui.settings_api import runtime_capabilities as _rc + from nanobot.webui.settings_routes import WebUISettingsRouter + + self._capabilities = _rc(runtime_surface, runtime_capabilities_overrides or {}) + self.settings_routes = WebUISettingsRouter( + bus=bus, + logger=self._log, + check_api_token=self.check_api_token, + parse_query=_parse_query, + json_response=_http_json_response, + error_response=_http_error, + runtime_surface=runtime_surface, + runtime_capabilities=self._capabilities, + ) + + # -- Token management --------------------------------------------------- + + def check_api_token(self, request: WsRequest) -> bool: + self._purge_expired_api_tokens() + token = _bearer_token(request.headers) or _query_first( + _parse_query(request.path), "token" + ) + if not token: + return False + expiry = self.api_tokens.get(token) + if expiry is None or time.monotonic() > expiry: + 
self.api_tokens.pop(token, None) + return False + return True + + def _purge_expired_api_tokens(self) -> None: + now = time.monotonic() + for token_key, expiry in list(self.api_tokens.items()): + if now > expiry: + self.api_tokens.pop(token_key, None) + + def _purge_expired_issued_tokens(self) -> None: + now = time.monotonic() + for token_key, expiry in list(self.issued_tokens.items()): + if now > expiry: + self.issued_tokens.pop(token_key, None) + + def take_issued_token_if_valid(self, token_value: str | None) -> bool: + if not token_value: + return False + self._purge_expired_issued_tokens() + expiry = self.issued_tokens.pop(token_value, None) + if expiry is None: + return False + if time.monotonic() > expiry: + return False + return True + + # -- Main dispatch ------------------------------------------------------ + + async def dispatch(self, connection: Any, request: WsRequest) -> Any | None: + """Route an HTTP request. Returns Response or None.""" + got, query = _parse_request_path(request.path) + + # Token issue endpoint + if self.config.token_issue_path: + issue_expected = _normalize_config_path(self.config.token_issue_path) + if got == issue_expected: + return self._handle_token_issue(connection, request) + + # Bootstrap + if got == "/webui/bootstrap": + return self._handle_bootstrap(connection, request) + + # Settings routes (delegated) + response = await self.settings_routes.dispatch(request, got) + if response is not None: + return response + + # Session routes + response = self._dispatch_session_routes(request, got) + if response is not None: + return response + + # Media routes + response = self._dispatch_media_routes(request, got) + if response is not None: + return response + + # Misc routes + response = self._dispatch_misc_routes(connection, request, got) + if response is not None: + return response + + # API 404 (never serve SPA for /api/ routes) + if got.startswith("/api/"): + return _http_error(404, "API route not found") + + # Static SPA serving 
+ if self.static_dist_path is not None: + response = self._serve_static(got) + if response is not None: + return response + + return connection.respond(404, "Not Found") + + # -- Token issue -------------------------------------------------------- + + def _handle_token_issue(self, connection: Any, request: Any) -> Any: + secret = self.config.token_issue_secret.strip() + if secret: + if not _issue_route_secret_matches(request.headers, secret): + return connection.respond(401, "Unauthorized") + else: + self._log.warning( + "token_issue_path is set but token_issue_secret is empty; " + "any client can obtain connection tokens — set token_issue_secret for production." + ) + self._purge_expired_issued_tokens() + if len(self.issued_tokens) >= self._MAX_ISSUED_TOKENS: + self._log.error( + "too many outstanding issued tokens ({}), rejecting issuance", + len(self.issued_tokens), + ) + return _http_json_response({"error": "too many outstanding tokens"}, status=429) + token_value = f"nbwt_{secrets.token_urlsafe(32)}" + self.issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s) + return _http_json_response( + {"token": token_value, "expires_in": self.config.token_ttl_s} + ) + + # -- Bootstrap ---------------------------------------------------------- + + def _handle_bootstrap(self, connection: Any, request: Any) -> Response: + secret = self.config.token_issue_secret.strip() or self.config.token.strip() + if secret: + if not _issue_route_secret_matches(request.headers, secret): + return _http_error(401, "Unauthorized") + elif not _is_localhost(connection): + return _http_error(403, "bootstrap is localhost-only") + + self._purge_expired_issued_tokens() + self._purge_expired_api_tokens() + if ( + len(self.issued_tokens) >= self._MAX_ISSUED_TOKENS + or len(self.api_tokens) >= self._MAX_ISSUED_TOKENS + ): + return _http_response( + json.dumps({"error": "too many outstanding tokens"}).encode("utf-8"), + status=429, + content_type="application/json; 
charset=utf-8", + ) + token = f"nbwt_{secrets.token_urlsafe(32)}" + expiry = time.monotonic() + float(self.config.token_ttl_s) + self.issued_tokens[token] = expiry + self.api_tokens[token] = expiry + + ws_url = self._bootstrap_ws_url(request) + expected_path = _normalize_config_path(self.config.path) + return _http_json_response( + { + "token": token, + "ws_path": expected_path, + "ws_url": ws_url, + "expires_in": self.config.token_ttl_s, + "model_name": _resolve_bootstrap_model_name(self.runtime_model_name), + "runtime_surface": self._runtime_surface, + "runtime_capabilities": self._capabilities, + } + ) + + def _bootstrap_ws_url(self, request: Any) -> str: + headers = getattr(request, "headers", {}) or {} + host = _safe_host_header(_case_insensitive_header(headers, "Host")) + if not host: + host = _host_for_url(self.config.host, self.config.port) + proto = _case_insensitive_header(headers, "X-Forwarded-Proto") + proto = proto.split(",", 1)[0].strip().lower() + secure = proto in {"https", "wss"} or bool(self.config.ssl_certfile.strip()) + scheme = "wss" if secure else "ws" + expected_path = _normalize_config_path(self.config.path) + return f"{scheme}://{host}{expected_path}" + + # -- Session routes ----------------------------------------------------- + + def _dispatch_session_routes(self, request: WsRequest, got: str) -> Response | None: + m = re.match(r"^/api/sessions/([^/]+)/messages$", got) + if m: + return self._handle_session_messages(request, m.group(1)) + + m = re.match(r"^/api/sessions/([^/]+)/webui-thread$", got) + if m: + return self._handle_webui_thread_get(request, m.group(1)) + + m = re.match(r"^/api/sessions/([^/]+)/delete$", got) + if m: + return self._handle_session_delete(request, m.group(1)) + + return None + + def _handle_sessions_list(self, request: WsRequest) -> Response: + if not self.check_api_token(request): + return _http_error(401, "Unauthorized") + if self.session_manager is None: + return _http_error(503, "session manager unavailable") 
+ sessions = self.session_manager.list_sessions() + from nanobot.session.webui_turns import websocket_turn_wall_started_at + + cleaned = [] + for s in sessions: + key = s.get("key") + if not (isinstance(key, str) and key.startswith("websocket:")): + continue + row = {k: v for k, v in s.items() if k != "path"} + chat_id = key.split(":", 1)[1] + started_at = websocket_turn_wall_started_at(chat_id) + if started_at is not None: + row["run_started_at"] = started_at + scope = self.workspaces.scope_for_session_key(key) + row["workspace_scope"] = scope.payload() + cleaned.append(row) + return _http_json_response({"sessions": cleaned}) + + def _handle_session_messages(self, request: WsRequest, key: str) -> Response: + if not self.check_api_token(request): + return _http_error(401, "Unauthorized") + if self.session_manager is None: + return _http_error(503, "session manager unavailable") + decoded_key = _decode_api_key(key) + if decoded_key is None: + return _http_error(400, "invalid session key") + if not _is_websocket_channel_session_key(decoded_key): + return _http_error(404, "session not found") + data = self.session_manager.read_session_file(decoded_key) + if data is None: + return _http_error(404, "session not found") + messages = data.get("messages") + if isinstance(messages, list): + scrub_subagent_messages_for_channel(messages) + self._augment_media_urls(data) + return _http_json_response(data) + + def _handle_webui_thread_get(self, request: WsRequest, key: str) -> Response: + if not self.check_api_token(request): + return _http_error(401, "Unauthorized") + decoded_key = _decode_api_key(key) + if decoded_key is None: + return _http_error(400, "invalid session key") + if not _is_websocket_channel_session_key(decoded_key): + return _http_error(404, "session not found") + scope = self.workspaces.scope_for_session_key(decoded_key) + data = build_webui_thread_response( + decoded_key, + augment_user_media=self._augment_transcript_user_media, + 
augment_assistant_text=lambda text: rewrite_local_markdown_images( + text, + workspace_path=scope.project_path, + sign_path=self.sign_or_stage_media_path, + ), + ) + if data is None: + return _http_error(404, "webui thread not found") + data["workspace_scope"] = scope.payload() + return _http_json_response(data) + + def _handle_session_delete(self, request: WsRequest, key: str) -> Response: + if not self.check_api_token(request): + return _http_error(401, "Unauthorized") + if self.session_manager is None: + return _http_error(503, "session manager unavailable") + decoded_key = _decode_api_key(key) + if decoded_key is None: + return _http_error(400, "invalid session key") + if not _is_websocket_channel_session_key(decoded_key): + return _http_error(404, "session not found") + deleted = self.session_manager.delete_session(decoded_key) + delete_webui_thread(decoded_key) + return _http_json_response({"deleted": bool(deleted)}) + + # -- Media routes ------------------------------------------------------- + + def _dispatch_media_routes(self, request: WsRequest, got: str) -> Response | None: + m = re.match(r"^/api/media/([A-Za-z0-9_-]+)/([A-Za-z0-9_-]+)$", got) + if m: + return self._handle_media_fetch(m.group(1), m.group(2), request) + return None + + def _handle_media_fetch( + self, sig: str, payload: str, request: WsRequest | None = None + ) -> Response: + return serve_signed_media( + sig, + payload, + secret=self.media_secret, + request=request, + media_dir=lambda channel=None: get_media_dir(channel), + ) + + def sign_media_path(self, abs_path: Path) -> str | None: + return sign_media_path( + abs_path, + secret=self.media_secret, + media_dir=lambda channel=None: get_media_dir(channel), + ) + + def sign_or_stage_media_path(self, path: Path) -> dict[str, str] | None: + return sign_or_stage_media_path( + path, + secret=self.media_secret, + media_dir=lambda channel=None: get_media_dir(channel), + logger=self._log, + ) + + def rewrite_local_markdown_images(self, text: str) 
-> str: + return rewrite_local_markdown_images( + text, + workspace_path=self.workspace_path, + sign_path=self.sign_or_stage_media_path, + ) + + # -- Misc routes -------------------------------------------------------- + + def _dispatch_misc_routes( + self, connection: Any, request: WsRequest, got: str + ) -> Response | None: + if got == "/api/sessions": + return self._handle_sessions_list(request) + if got == "/api/commands": + return self._handle_commands(request) + if got == "/api/workspaces": + return self._handle_workspaces(connection, request) + if got == "/api/webui/sidebar-state": + return self._handle_webui_sidebar_state(request) + if got == "/api/webui/sidebar-state/update": + return self._handle_webui_sidebar_state_update(request) + return None + + def _handle_commands(self, request: WsRequest) -> Response: + if not self.check_api_token(request): + return _http_error(401, "Unauthorized") + return _http_json_response({"commands": builtin_command_palette()}) + + def _handle_workspaces(self, connection: Any, request: WsRequest) -> Response: + if not self.check_api_token(request): + return _http_error(401, "Unauthorized") + return _http_json_response( + self.workspaces.payload(controls_available=_is_localhost(connection)) + ) + + def _handle_webui_sidebar_state(self, request: WsRequest) -> Response: + if not self.check_api_token(request): + return _http_error(401, "Unauthorized") + return _http_json_response(read_webui_sidebar_state()) + + def _handle_webui_sidebar_state_update(self, request: WsRequest) -> Response: + if not self.check_api_token(request): + return _http_error(401, "Unauthorized") + query = _parse_query(request.path) + raw_state = _query_first(query, "state") + if raw_state is None: + return _http_error(400, "missing state") + try: + decoded = json.loads(raw_state) + except json.JSONDecodeError: + return _http_error(400, "state must be JSON") + if not isinstance(decoded, dict): + return _http_error(400, "state must be an object") + try: + 
state = write_webui_sidebar_state(decoded) + except ValueError as e: + return _http_error(400, str(e)) + except OSError: + self._log.exception("failed to write webui sidebar state") + return _http_error(500, "failed to write sidebar state") + return _http_json_response(state) + + # -- Static file serving ------------------------------------------------ + + def _serve_static(self, request_path: str) -> Response | None: + assert self.static_dist_path is not None + rel = request_path.lstrip("/") + if not rel: + rel = "index.html" + if ".." in rel.split("/") or rel.startswith("/"): + return _http_error(403, "Forbidden") + candidate = (self.static_dist_path / rel).resolve() + try: + candidate.relative_to(self.static_dist_path) + except ValueError: + return _http_error(403, "Forbidden") + if not candidate.is_file(): + index = self.static_dist_path / "index.html" + if index.is_file(): + candidate = index + else: + return None + try: + body = candidate.read_bytes() + except OSError as e: + self._log.warning("static: failed to read {}: {}", candidate, e) + return _http_error(500, "Internal Server Error") + ctype, _ = mimetypes.guess_type(candidate.name) + if ctype is None: + ctype = "application/octet-stream" + if ctype.startswith("text/") or ctype in {"application/javascript", "application/json"}: + ctype = f"{ctype}; charset=utf-8" + if candidate.name == "index.html": + cache = "no-cache" + else: + cache = "public, max-age=31536000, immutable" + return _http_response( + body, + status=200, + content_type=ctype, + extra_headers=[("Cache-Control", cache)], + ) + + # -- Media helpers (called by WebSocketChannel.send) -------------------- + + def _augment_media_urls(self, payload: dict[str, Any]) -> None: + messages = payload.get("messages") + if not isinstance(messages, list): + return + for msg in messages: + if not isinstance(msg, dict): + continue + media = msg.get("media") + if not isinstance(media, list) or not media: + continue + urls: list[dict[str, str]] = [] + for 
entry in media: + if not isinstance(entry, str) or not entry: + continue + signed = self.sign_media_path(Path(entry)) + if signed is None: + continue + urls.append({"url": signed, "name": Path(entry).name}) + if urls: + msg["media_urls"] = urls + msg.pop("media", None) + + def _augment_transcript_user_media(self, paths: list[str]) -> list[dict[str, Any]]: + out: list[dict[str, Any]] = [] + for pstr in paths: + path = Path(pstr) + att = self.sign_or_stage_media_path(path) + if att is None: + continue + mime, _ = mimetypes.guess_type(path.name) + kind = "video" if mime and mime.startswith("video/") else "image" + out.append( + {"kind": kind, "url": att["url"], "name": att.get("name", path.name)}, + ) + return out + + +def _is_websocket_channel_session_key(key: str) -> bool: + return key.startswith("websocket:") diff --git a/tests/channels/test_websocket_channel.py b/tests/channels/test_websocket_channel.py index 03cee58f7..a39658e37 100644 --- a/tests/channels/test_websocket_channel.py +++ b/tests/channels/test_websocket_channel.py @@ -626,6 +626,7 @@ async def test_send_stages_external_media_as_signed_url(monkeypatch, tmp_path) - return ws_media if channel == "websocket" else media_root monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir) + monkeypatch.setattr("nanobot.channels.ws_http.get_media_dir", fake_media_dir) channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") @@ -840,6 +841,7 @@ async def test_send_delta_stream_end_rewrites_local_markdown_image(monkeypatch, return path monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir) + monkeypatch.setattr("nanobot.channels.ws_http.get_media_dir", fake_media_dir) channel = WebSocketChannel( {"enabled": True, "allowFrom": ["*"], "streaming": True}, bus, @@ -872,6 +874,7 @@ async def test_send_delta_stream_end_rewrites_inline_final_text(monkeypatch, tmp return path 
monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir) + monkeypatch.setattr("nanobot.channels.ws_http.get_media_dir", fake_media_dir) channel = WebSocketChannel( {"enabled": True, "allowFrom": ["*"], "streaming": True}, bus, diff --git a/tests/channels/test_websocket_http_routes.py b/tests/channels/test_websocket_http_routes.py index ddf771c13..6bbff4495 100644 --- a/tests/channels/test_websocket_http_routes.py +++ b/tests/channels/test_websocket_http_routes.py @@ -710,20 +710,20 @@ async def test_api_token_pool_purges_expired(bus: MagicMock, tmp_path: Path) -> channel = _ch(bus, session_manager=sm, port=29908) # Don't start a server — directly inject and validate. import time as _time - channel._api_tokens["expired"] = _time.monotonic() - 1 - channel._api_tokens["live"] = _time.monotonic() + 60 + channel._http.api_tokens["expired"] = _time.monotonic() - 1 + channel._http.api_tokens["live"] = _time.monotonic() + 60 class _FakeReq: path = "/api/sessions" headers = {"Authorization": "Bearer expired"} - assert channel._check_api_token(_FakeReq()) is False + assert channel._http.check_api_token(_FakeReq()) is False class _LiveReq: path = "/api/sessions" headers = {"Authorization": "Bearer live"} - assert channel._check_api_token(_LiveReq()) is True + assert channel._http.check_api_token(_LiveReq()) is True class _FakeConn: @@ -814,7 +814,7 @@ def test_localhost_without_auth_is_valid(bus: MagicMock) -> None: def test_bootstrap_prefers_runtime_model_name(bus: MagicMock, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( - "nanobot.channels.websocket._default_model_name_from_config", + "nanobot.channels.ws_http._default_model_name_from_config", lambda: "from-disk", ) channel = _ch(bus, host="127.0.0.1", runtime_model_name=lambda: " live/model ") @@ -826,7 +826,7 @@ def test_bootstrap_prefers_runtime_model_name(bus: MagicMock, monkeypatch: pytes def test_bootstrap_falls_back_when_runtime_returns_empty(bus: MagicMock, monkeypatch: 
pytest.MonkeyPatch) -> None: monkeypatch.setattr( - "nanobot.channels.websocket._default_model_name_from_config", + "nanobot.channels.ws_http._default_model_name_from_config", lambda: "from-disk", ) channel = _ch(bus, host="127.0.0.1", runtime_model_name=lambda: " ") @@ -838,7 +838,7 @@ def test_bootstrap_falls_back_when_runtime_returns_empty(bus: MagicMock, monkeyp def test_bootstrap_falls_back_when_runtime_raises(bus: MagicMock, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( - "nanobot.channels.websocket._default_model_name_from_config", + "nanobot.channels.ws_http._default_model_name_from_config", lambda: "from-disk", ) diff --git a/tests/channels/test_websocket_media_route.py b/tests/channels/test_websocket_media_route.py index f70826f8e..2b6737aa0 100644 --- a/tests/channels/test_websocket_media_route.py +++ b/tests/channels/test_websocket_media_route.py @@ -106,7 +106,7 @@ def test_sign_media_path_rejects_paths_outside_media_root( media = tmp_path / "media" media.mkdir() channel = _ch(bus, port=0) - with patch("nanobot.channels.websocket.get_media_dir", return_value=media): + with patch("nanobot.channels.ws_http.get_media_dir", return_value=media): assert channel._sign_media_path(outside) is None # Traversal via the media root is also rejected — the resolve() step # normalises ``..`` out before the relative_to check. 
@@ -121,7 +121,7 @@ def test_sign_media_path_round_trips_via_hmac( media.mkdir() (media / "a.png").write_bytes(_PNG_BYTES) channel = _ch(bus, port=0) - with patch("nanobot.channels.websocket.get_media_dir", return_value=media): + with patch("nanobot.channels.ws_http.get_media_dir", return_value=media): url = channel._sign_media_path(media / "a.png") assert url is not None assert url.startswith("/api/media/") @@ -144,7 +144,7 @@ def test_local_markdown_image_is_staged_and_rewritten( media = tmp_path / "media" channel = _ch(bus, workspace_path=workspace, port=0) - with patch("nanobot.channels.websocket.get_media_dir", side_effect=_fake_media_dir(media)): + with patch("nanobot.channels.ws_http.get_media_dir", side_effect=_fake_media_dir(media)): rewritten = channel._rewrite_local_markdown_images( "The result:\n![Cloud Architecture Diagram](demo_arch.png)" ) @@ -166,7 +166,7 @@ def test_local_markdown_video_is_staged_and_rewritten( media = tmp_path / "media" channel = _ch(bus, workspace_path=workspace, port=0) - with patch("nanobot.channels.websocket.get_media_dir", side_effect=_fake_media_dir(media)): + with patch("nanobot.channels.ws_http.get_media_dir", side_effect=_fake_media_dir(media)): rewritten = channel._rewrite_local_markdown_images( "The result:\n![nanobot-intro.mp4](nanobot-intro.mp4)" ) @@ -189,7 +189,7 @@ def test_local_markdown_image_rejects_workspace_escape( channel = _ch(bus, workspace_path=workspace, port=0) text = "![nope](../outside.png)" - with patch("nanobot.channels.websocket.get_media_dir", side_effect=_fake_media_dir(media)): + with patch("nanobot.channels.ws_http.get_media_dir", side_effect=_fake_media_dir(media)): assert channel._rewrite_local_markdown_images(text) == text assert not (media / "websocket").exists() @@ -211,7 +211,7 @@ async def test_media_route_serves_signed_file( target.write_bytes(_PNG_BYTES) channel = _ch(bus, port=29920) - with patch("nanobot.channels.websocket.get_media_dir", return_value=media): + with 
patch("nanobot.channels.ws_http.get_media_dir", return_value=media): url_path = channel._sign_media_path(target) assert url_path is not None server_task = asyncio.create_task(channel.start()) @@ -244,7 +244,7 @@ async def test_media_route_serves_video_byte_ranges( target.write_bytes(b"0123456789") channel = _ch(bus, port=29927) - with patch("nanobot.channels.websocket.get_media_dir", return_value=media): + with patch("nanobot.channels.ws_http.get_media_dir", return_value=media): url_path = channel._sign_media_path(target) assert url_path is not None server_task = asyncio.create_task(channel.start()) @@ -276,7 +276,7 @@ async def test_media_route_serves_suffix_video_byte_ranges( target.write_bytes(b"0123456789") channel = _ch(bus, port=29928) - with patch("nanobot.channels.websocket.get_media_dir", return_value=media): + with patch("nanobot.channels.ws_http.get_media_dir", return_value=media): url_path = channel._sign_media_path(target) assert url_path is not None server_task = asyncio.create_task(channel.start()) @@ -305,7 +305,7 @@ async def test_media_route_rejects_unsatisfiable_byte_range( target.write_bytes(b"0123456789") channel = _ch(bus, port=29929) - with patch("nanobot.channels.websocket.get_media_dir", return_value=media): + with patch("nanobot.channels.ws_http.get_media_dir", return_value=media): url_path = channel._sign_media_path(target) assert url_path is not None server_task = asyncio.create_task(channel.start()) @@ -338,7 +338,7 @@ async def test_media_route_rejects_bad_signature( (media / "f.png").write_bytes(_PNG_BYTES) channel = _ch(bus, port=29921) - with patch("nanobot.channels.websocket.get_media_dir", return_value=media): + with patch("nanobot.channels.ws_http.get_media_dir", return_value=media): good = channel._sign_media_path(media / "f.png") assert good is not None _, payload = good[len("/api/media/"):].split("/", 1) @@ -381,7 +381,7 @@ async def test_media_route_rejects_path_traversal_payload( ).digest()[:16] url = 
f"/api/media/{b64url_encode(mac)}/{payload}" - with patch("nanobot.channels.websocket.get_media_dir", return_value=media): + with patch("nanobot.channels.ws_http.get_media_dir", return_value=media): server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) try: @@ -405,7 +405,7 @@ async def test_media_route_404s_missing_file( target.write_bytes(_PNG_BYTES) channel = _ch(bus, port=29923) - with patch("nanobot.channels.websocket.get_media_dir", return_value=media): + with patch("nanobot.channels.ws_http.get_media_dir", return_value=media): url_path = channel._sign_media_path(target) assert url_path is not None target.unlink() # the file vanishes between signing and fetching @@ -433,7 +433,7 @@ async def test_media_route_degrades_non_image_to_octet_stream( (media / "scary.html").write_bytes(b"") channel = _ch(bus, port=29924) - with patch("nanobot.channels.websocket.get_media_dir", return_value=media): + with patch("nanobot.channels.ws_http.get_media_dir", return_value=media): payload = b64url_encode(b"scary.html") mac = hmac.new( channel._media_secret, payload.encode("ascii"), hashlib.sha256 @@ -464,7 +464,7 @@ async def test_media_route_serves_svg_with_strict_csp( target.write_text("") channel = _ch(bus, port=29928) - with patch("nanobot.channels.websocket.get_media_dir", return_value=media): + with patch("nanobot.channels.ws_http.get_media_dir", return_value=media): url_path = channel._sign_media_path(target) assert url_path is not None server_task = asyncio.create_task(channel.start()) @@ -505,7 +505,7 @@ async def test_session_messages_exposes_signed_media_urls( sm.save(sess) channel = _ch(bus, session_manager=sm, port=29925) - with patch("nanobot.channels.websocket.get_media_dir", return_value=media): + with patch("nanobot.channels.ws_http.get_media_dir", return_value=media): server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) try: @@ -550,7 +550,7 @@ async def test_session_messages_skips_vanished_media( sm.save(sess) 
channel = _ch(bus, session_manager=sm, port=29926) - with patch("nanobot.channels.websocket.get_media_dir", return_value=media): + with patch("nanobot.channels.ws_http.get_media_dir", return_value=media): server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) try: