diff --git a/nanobot/channels/manager.py b/nanobot/channels/manager.py index baae42263..5bbc8879d 100644 --- a/nanobot/channels/manager.py +++ b/nanobot/channels/manager.py @@ -16,7 +16,6 @@ from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel from nanobot.config.schema import Config from nanobot.utils.restart import consume_restart_notice_from_env, format_restart_completed_message -from nanobot.webui.ws_http import GatewayHTTPHandler if TYPE_CHECKING: from nanobot.session.manager import SessionManager @@ -113,21 +112,24 @@ class ChannelManager: kwargs: dict[str, Any] = {} if cls.name == "websocket": from nanobot.channels.websocket import WebSocketConfig + from nanobot.webui.gateway_services import build_gateway_services parsed = WebSocketConfig.model_validate(section) static_path = _default_webui_dist() if self._webui_static_dist else None workspace = Path(self.config.workspace_path) - http_handler = GatewayHTTPHandler( + gateway = build_gateway_services( config=parsed, + bus=self.bus, session_manager=self._session_manager, static_dist_path=static_path, workspace_path=workspace, + default_restrict_to_workspace=self.config.tools.restrict_to_workspace, runtime_model_name=self._webui_runtime_model_name, runtime_surface=self._webui_runtime_surface, runtime_capabilities_overrides=self._webui_runtime_capabilities, - bus=self.bus, + logger=logger, ) - kwargs["http_handler"] = http_handler + kwargs["gateway"] = gateway channel = cls(section, self.bus, **kwargs) channel.transcription_provider = transcription_provider channel.transcription_api_key = transcription_key diff --git a/nanobot/channels/websocket.py b/nanobot/channels/websocket.py index 231f763ca..2a8fc2e7c 100644 --- a/nanobot/channels/websocket.py +++ b/nanobot/channels/websocket.py @@ -3,27 +3,20 @@ from __future__ import annotations import asyncio -import email.utils import hmac -import http import json import re import ssl import uuid from collections.abc import Callable from contextlib import suppress -from functools import partial from pathlib import Path from typing import Any, Self -from urllib.parse import parse_qs, unquote, urlparse -from loguru import logger from pydantic import Field, field_validator, model_validator from websockets.asyncio.server import ServerConnection, serve, unix_serve -from websockets.datastructures import Headers from websockets.exceptions import ConnectionClosed from websockets.http11 import Request as WsRequest -from websockets.http11 import Response from nanobot.bus.events import OUTBOUND_META_AGENT_UI, OutboundMessage from nanobot.bus.queue import MessageBus @@ -41,53 +34,22 @@ from nanobot.utils.media_decode import ( save_base64_data_url, ) from nanobot.webui.cli_apps_api import normalize_cli_app_mentions +from nanobot.webui.gateway_services import GatewayServices +from nanobot.webui.http_utils import ( + is_localhost as _is_localhost, +) +from nanobot.webui.http_utils import ( + normalize_config_path as _normalize_config_path, +) +from nanobot.webui.http_utils import ( + parse_request_path as _parse_request_path, +) +from nanobot.webui.http_utils import ( + query_first as _query_first, +) from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions from nanobot.webui.transcript import append_transcript_object - - -def _strip_trailing_slash(path: str) -> str: - if len(path) > 1 and path.endswith("/"): - return path.rstrip("/") - return path or "/" - - -def _normalize_config_path(path: str) -> str: - return _strip_trailing_slash(path) - - -def _case_insensitive_header(headers: Any, key: str) -> str: - """Read a header from websockets/http test stubs without assuming casing.""" - try: - value = headers.get(key) - except Exception: - value = None - if value is None: - try: - value = headers.get(key.lower()) - except Exception: - value = None - return str(value or "").strip() - - -def _safe_host_header(value: str) -> str: - """Return a safe Host header value, or empty when it should not be echoed.""" - value = value.strip() - if not value: - return "" - if re.fullmatch(r"\[[0-9A-Fa-f:.]+\](?::\d{1,5})?", value): - return value - if re.fullmatch(r"[A-Za-z0-9.-]+(?::\d{1,5})?", value): - return value - return "" - - -def _host_for_url(host: str, port: int) -> str: - host = host.strip() - if host in ("0.0.0.0", "::"): - host = "127.0.0.1" - if ":" in host and not host.startswith("["): - host = f"[{host}]" - return f"{host}:{port}" +from nanobot.webui.websocket_logging import websockets_server_logger class WebSocketConfig(Base): @@ -182,20 +144,6 @@ class WebSocketConfig(Base): ) -def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response: - body = json.dumps(data, ensure_ascii=False).encode("utf-8") - headers = Headers( - [ - ("Date", email.utils.formatdate(usegmt=True)), - ("Connection", "close"), - ("Content-Length", str(len(body))), - ("Content-Type", "application/json; charset=utf-8"), - ] - ) - reason = http.HTTPStatus(status).phrase - return Response(status, reason, headers, body) - - def publish_runtime_model_update( bus: MessageBus, model: str, @@ -214,57 +162,6 @@ def publish_runtime_model_update( )) -def _default_model_name_from_config() -> str | None: - """Resolved model string from on-disk config (bootstrap fallback).""" - try: - from nanobot.config.loader import load_config - - model = load_config().resolve_preset().model.strip() - return model or None - except Exception as e: - logger.debug("bootstrap model_name could not load from config: {}", e) - return None - - -def _resolve_bootstrap_model_name( - runtime_name: Callable[[], str | None] | None, -) -> str | None: - """Prefer an in-process resolver (e.g. AgentLoop); else config-derived default.""" - if runtime_name is not None: - try: - raw = runtime_name() - except Exception as e: - logger.debug("bootstrap runtime model resolver failed: {}", e) - else: - if isinstance(raw, str): - stripped = raw.strip() - if stripped: - return stripped - return _default_model_name_from_config() - - -def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]: - """Parse normalized path and query parameters in one pass.""" - parsed = urlparse("ws://x" + path_with_query) - path = _strip_trailing_slash(parsed.path or "/") - return path, parse_qs(parsed.query, keep_blank_values=True) - - -def _normalize_http_path(path_with_query: str) -> str: - """Return the path component (no query string), with trailing slash normalized (root stays ``/``).""" - return _parse_request_path(path_with_query)[0] - - -def _parse_query(path_with_query: str) -> dict[str, list[str]]: - return _parse_request_path(path_with_query)[1] - - -def _query_first(query: dict[str, list[str]], key: str) -> str | None: - """Return the first value for *key*, or None.""" - values = query.get(key) - return values[0] if values else None - - def _parse_inbound_payload(raw: str) -> str | None: """Parse a client frame into text; return None for empty or unrecognized content.""" text = raw.strip() @@ -355,67 +252,6 @@ def _extract_data_url_mime(url: str) -> str | None: return m.group(1).strip().lower() or None -_LOCALHOSTS = frozenset({"127.0.0.1", "::1", "localhost"}) - -# Matches the legacy chat-id pattern but allows file-system-safe stems too, -# so the API can address sessions whose keys came from non-WebSocket channels. -_API_KEY_RE = re.compile(r"^[A-Za-z0-9_:.-]{1,128}$") - - -def _decode_api_key(raw_key: str) -> str | None: - """Decode a percent-encoded API path segment, then validate the result.""" - key = unquote(raw_key) - if _API_KEY_RE.match(key) is None: - return None - return key - - -def _is_localhost(connection: Any) -> bool: - """Return True if *connection* originated from the loopback interface.""" - addr = getattr(connection, "remote_address", None) - if not addr: - return False - host = addr[0] if isinstance(addr, tuple) else addr - if not isinstance(host, str): - return False - # ``::ffff:127.0.0.1`` is loopback in IPv6-mapped form. - if host.startswith("::ffff:"): - host = host[7:] - return host in _LOCALHOSTS - - -def _http_response( - body: bytes, - *, - status: int = 200, - content_type: str = "text/plain; charset=utf-8", - extra_headers: list[tuple[str, str]] | None = None, -) -> Response: - headers = [ - ("Date", email.utils.formatdate(usegmt=True)), - ("Connection", "close"), - ("Content-Length", str(len(body))), - ("Content-Type", content_type), - ] - if extra_headers: - headers.extend(extra_headers) - reason = http.HTTPStatus(status).phrase - return Response(status, reason, Headers(headers), body) - - -def _http_error(status: int, message: str | None = None) -> Response: - body = (message or http.HTTPStatus(status).phrase).encode("utf-8") - return _http_response(body, status=status) - - -def _bearer_token(headers: Any) -> str | None: - """Pull a Bearer token out of standard or query-style headers.""" - auth = headers.get("Authorization") or headers.get("authorization") - if auth and auth.lower().startswith("bearer "): - return auth[7:].strip() or None - return None - - def _is_websocket_upgrade(request: WsRequest) -> bool: """Detect an actual WS upgrade; plain HTTP GETs to the same path should fall through.""" upgrade = request.headers.get("Upgrade") or request.headers.get("upgrade") @@ -427,20 +263,6 @@ def _is_websocket_upgrade(request: WsRequest) -> bool: return True -def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool: - """Return True if the token-issue HTTP request carries credentials matching ``token_issue_secret``.""" - if not configured_secret: - return True - authorization = headers.get("Authorization") or headers.get("authorization") - if authorization and authorization.lower().startswith("bearer "): - supplied = authorization[7:].strip() - return hmac.compare_digest(supplied, configured_secret) - header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth") - if not header_token: - return False - return hmac.compare_digest(header_token.strip(), configured_secret) - - class WebSocketChannel(BaseChannel): """Run a local WebSocket server; forward text/JSON messages to the message bus.""" @@ -452,7 +274,7 @@ class WebSocketChannel(BaseChannel): config: Any, bus: MessageBus, *, - http_handler: Any | None = None, + gateway: GatewayServices, ): if isinstance(config, dict): config = WebSocketConfig.model_validate(config) @@ -467,11 +289,11 @@ class WebSocketChannel(BaseChannel): self._stop_event: asyncio.Event | None = None self._server_task: asyncio.Task[None] | None = None - # HTTP handler injected from outside (ChannelManager / gateway startup). - # Owns tokens, sessions, media, settings, static serving. - self._http = http_handler - # Backwards-compat: workspace controller used in envelope dispatch - self._webui_workspaces = http_handler.workspaces if http_handler else None + self.gateway = gateway + self._http_router = gateway.http + self._tokens = gateway.tokens + self._media = gateway.media + self._workspaces = gateway.workspaces self._stream_text_buffers: dict[tuple[str, str], list[str]] = {} @@ -501,9 +323,9 @@ class WebSocketChannel(BaseChannel): connected clients normally see it via ``goal_state`` / ``turn_end`` frames. Pushing here makes refresh + reconnect restore the strip without a new model turn. """ - if self._http.session_manager is None: + if self.gateway.session_manager is None: return - row = self._http.session_manager.read_session_file(f"websocket:{chat_id}") + row = self.gateway.session_manager.read_session_file(f"websocket:{chat_id}") meta = row.get("metadata", {}) if isinstance(row, dict) else {} if not isinstance(meta, dict): meta = {} @@ -543,57 +365,6 @@ class WebSocketChannel(BaseChannel): def _expected_path(self) -> str: return _normalize_config_path(self.config.path) - # -- Backwards-compat property aliases (used by tests) ------------------ - - @property - def _session_manager(self): - return self._http.session_manager - - @_session_manager.setter - def _session_manager(self, value): - self._http.session_manager = value - - @property - def _media_secret(self): - return self._http.media_secret - - @property - def _issued_tokens(self): - return self._http.issued_tokens - - @_issued_tokens.setter - def _issued_tokens(self, value): - self._http.issued_tokens = value - - @property - def _api_tokens(self): - return self._http.api_tokens - - def _check_api_token(self, request): - return self._http.check_api_token(request) - - def _sign_media_path(self, path): - return self._http.sign_media_path(path) - - def _sign_or_stage_media_path(self, path): - return self._http.sign_or_stage_media_path(path) - - def _rewrite_local_markdown_images(self, text): - return self._http.rewrite_local_markdown_images(text) - - def _handle_bootstrap(self, connection, request): - return self._http._handle_bootstrap(connection, request) - - def _handle_sessions_list(self, request): - return self._http._handle_sessions_list(request) - - def _handle_webui_thread_get(self, request, key): - return self._http._handle_webui_thread_get(request, key) - - @property - def _workspace_path(self): - return self._http.workspace_path - def _build_ssl_context(self) -> ssl.SSLContext | None: cert = self.config.ssl_certfile.strip() key = self.config.ssl_keyfile.strip() @@ -624,8 +395,8 @@ class WebSocketChannel(BaseChannel): return connection.respond(403, "Forbidden") return self._authorize_websocket_handshake(connection, query) -<< # Everything else goes to the HTTP handler - return await self._http.dispatch(connection, request) + # Everything else goes to the HTTP handler + return await self._http_router.dispatch(connection, request) def _authorize_websocket_handshake(self, connection: Any, query: dict[str, list[str]]) -> Any: supplied = _query_first(query, "token") @@ -634,17 +405,17 @@ class WebSocketChannel(BaseChannel): if static_token: if supplied and hmac.compare_digest(supplied, static_token): return None - if supplied and self._http.take_issued_token_if_valid(supplied): + if supplied and self._tokens.take_issued_token_if_valid(supplied): return None return connection.respond(401, "Unauthorized") if self.config.websocket_requires_token: - if supplied and self._http.take_issued_token_if_valid(supplied): + if supplied and self._tokens.take_issued_token_if_valid(supplied): return None return connection.respond(401, "Unauthorized") if supplied: - self._http.take_issued_token_if_valid(supplied) + self._tokens.take_issued_token_if_valid(supplied) return None # -- Server lifecycle and connection ingress --------------------------- @@ -878,14 +649,14 @@ class WebSocketChannel(BaseChannel): new_id = str(uuid.uuid4()) scope = await self._workspace_scope_or_error( connection, - lambda: self._webui_workspaces.scope_for_new_chat( + lambda: self._workspaces.scope_for_new_chat( envelope, controls_available=_is_localhost(connection), ), ) if scope is None: return - self._webui_workspaces.persist_scope(new_id, scope) + self._workspaces.persist_scope(new_id, scope) self._attach(connection, new_id) await self._send_event(connection, "attached", chat_id=new_id) await self._send_event( @@ -913,7 +684,7 @@ class WebSocketChannel(BaseChannel): return scope = await self._workspace_scope_or_error( connection, - lambda: self._webui_workspaces.scope_for_set_request( + lambda: self._workspaces.scope_for_set_request( envelope, chat_id=cid, chat_running=websocket_turn_wall_started_at(cid) is not None, @@ -923,7 +694,7 @@ class WebSocketChannel(BaseChannel): ) if scope is None: return - self._webui_workspaces.persist_scope(cid, scope) + self._workspaces.persist_scope(cid, scope) await self._send_event( connection, "session_updated", @@ -965,7 +736,7 @@ class WebSocketChannel(BaseChannel): return scope = await self._workspace_scope_or_error( connection, - lambda: self._webui_workspaces.scope_for_message( + lambda: self._workspaces.scope_for_message( envelope, chat_id=cid, chat_running=websocket_turn_wall_started_at(cid) is not None, @@ -989,7 +760,7 @@ class WebSocketChannel(BaseChannel): if mcp_presets: metadata["mcp_presets"] = mcp_presets metadata[WORKSPACE_SCOPE_METADATA_KEY] = scope.metadata() - self._webui_workspaces.persist_scope(cid, scope) + self._workspaces.persist_scope(cid, scope) image_generation = envelope.get("image_generation") if isinstance(image_generation, dict) and image_generation.get("enabled") is True: aspect_ratio = image_generation.get("aspect_ratio") @@ -1044,8 +815,7 @@ class WebSocketChannel(BaseChannel): self._subs.clear() self._conn_chats.clear() self._conn_default.clear() - self._http.issued_tokens.clear() - self._http.api_tokens.clear() + self._tokens.clear() async def _safe_send_to(self, connection: Any, raw: str, *, label: str = "") -> None: """Send a raw frame to one connection, cleaning up on ConnectionClosed.""" @@ -1127,7 +897,7 @@ class WebSocketChannel(BaseChannel): ) return text = msg.content - wire_text = self._http.rewrite_local_markdown_images(text) + wire_text = self._media.rewrite_local_markdown_images(text) payload: dict[str, Any] = { "event": "message", "chat_id": msg.chat_id, @@ -1137,7 +907,7 @@ class WebSocketChannel(BaseChannel): payload["media"] = msg.media urls: list[dict[str, str]] = [] for entry in msg.media: - signed = self._http.sign_or_stage_media_path(Path(entry)) + signed = self._media.sign_or_stage_media_path(Path(entry)) if signed is not None: urls.append(signed) if urls: @@ -1252,7 +1022,7 @@ class WebSocketChannel(BaseChannel): if delta: buffered.append(delta) full_text = "".join(buffered) - rewritten = self._http.rewrite_local_markdown_images(full_text) + rewritten = self._media.rewrite_local_markdown_images(full_text) if rewritten != full_text: body["text"] = rewritten else: diff --git a/nanobot/webui/gateway_services.py b/nanobot/webui/gateway_services.py new file mode 100644 index 000000000..cf3eede19 --- /dev/null +++ b/nanobot/webui/gateway_services.py @@ -0,0 +1,70 @@ +"""Composition helpers for the embedded WebUI gateway.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from loguru import logger as default_logger + +from nanobot.webui.gateway_tokens import GatewayTokenStore +from nanobot.webui.media_gateway import WebUIMediaGateway +from nanobot.webui.workspaces import WebUIWorkspaceController +from nanobot.webui.ws_http import GatewayHTTPHandler + + +@dataclass(frozen=True) +class GatewayServices: + """Explicit dependencies shared by WebSocket transport and HTTP routes.""" + + http: GatewayHTTPHandler + tokens: GatewayTokenStore + media: WebUIMediaGateway + workspaces: WebUIWorkspaceController + session_manager: Any | None + + +def build_gateway_services( + *, + config: Any, + bus: Any, + session_manager: Any | None, + static_dist_path: Path | None, + workspace_path: Path, + default_restrict_to_workspace: bool, + runtime_model_name: Any | None, + runtime_surface: str, + runtime_capabilities_overrides: dict[str, Any] | None, + logger: Any = default_logger, +) -> GatewayServices: + tokens = GatewayTokenStore() + media = WebUIMediaGateway( + workspace_path=workspace_path, + logger=logger, + ) + workspaces = WebUIWorkspaceController( + session_manager=session_manager, + default_workspace=workspace_path, + default_restrict_to_workspace=default_restrict_to_workspace, + ) + http = GatewayHTTPHandler( + config=config, + session_manager=session_manager, + static_dist_path=static_dist_path, + runtime_model_name=runtime_model_name, + runtime_surface=runtime_surface, + runtime_capabilities_overrides=runtime_capabilities_overrides, + bus=bus, + tokens=tokens, + media=media, + workspaces=workspaces, + log=logger, + ) + return GatewayServices( + http=http, + tokens=tokens, + media=media, + workspaces=workspaces, + session_manager=session_manager, + ) diff --git a/nanobot/webui/gateway_tokens.py b/nanobot/webui/gateway_tokens.py new file mode 100644 index 000000000..a7a5b5903 --- /dev/null +++ b/nanobot/webui/gateway_tokens.py @@ -0,0 +1,82 @@ +"""Token state for the embedded WebUI gateway.""" + +from __future__ import annotations + +import secrets +import time +from dataclasses import dataclass, field +from typing import Any + +from websockets.http11 import Request as WsRequest + +from nanobot.webui.http_utils import bearer_token, parse_query, query_first + + +@dataclass +class GatewayTokenStore: + """Own short-lived WebSocket and WebUI API tokens for one gateway process.""" + + max_tokens: int = 10_000 + issued_tokens: dict[str, float] = field(default_factory=dict) + api_tokens: dict[str, float] = field(default_factory=dict) + + def check_api_token(self, request: WsRequest) -> bool: + self._purge_expired_api_tokens() + token = bearer_token(request.headers) or query_first( + parse_query(request.path), "token" + ) + if not token: + return False + expiry = self.api_tokens.get(token) + if expiry is None or time.monotonic() > expiry: + self.api_tokens.pop(token, None) + return False + return True + + def can_issue(self, *, include_api_token: bool = False) -> bool: + self._purge_expired_issued_tokens() + self._purge_expired_api_tokens() + if len(self.issued_tokens) >= self.max_tokens: + return False + if include_api_token and len(self.api_tokens) >= self.max_tokens: + return False + return True + + def issue_token(self, ttl_s: int | float, *, api_token: bool = False) -> str: + token_value = f"nbwt_{secrets.token_urlsafe(32)}" + expiry = time.monotonic() + float(ttl_s) + self.issued_tokens[token_value] = expiry + if api_token: + self.api_tokens[token_value] = expiry + return token_value + + def take_issued_token_if_valid(self, token_value: str | None) -> bool: + if not token_value: + return False + self._purge_expired_issued_tokens() + expiry = self.issued_tokens.pop(token_value, None) + if expiry is None: + return False + if time.monotonic() > expiry: + return False + return True + + def clear(self) -> None: + self.issued_tokens.clear() + self.api_tokens.clear() + + def _purge_expired_api_tokens(self) -> None: + now = time.monotonic() + for token_key, expiry in list(self.api_tokens.items()): + if now > expiry: + self.api_tokens.pop(token_key, None) + + def _purge_expired_issued_tokens(self) -> None: + now = time.monotonic() + for token_key, expiry in list(self.issued_tokens.items()): + if now > expiry: + self.issued_tokens.pop(token_key, None) + + +def token_response_payload(token: str, expires_in: Any) -> dict[str, Any]: + return {"token": token, "expires_in": expires_in} diff --git a/nanobot/webui/http_utils.py b/nanobot/webui/http_utils.py new file mode 100644 index 000000000..01f3f54bb --- /dev/null +++ b/nanobot/webui/http_utils.py @@ -0,0 +1,151 @@ +"""Shared HTTP helpers for the embedded WebUI gateway.""" + +from __future__ import annotations + +import email.utils +import hmac +import http +import json +import re +from typing import Any +from urllib.parse import parse_qs, urlparse + +from websockets.datastructures import Headers +from websockets.http11 import Response + +QueryParams = dict[str, list[str]] + + +def strip_trailing_slash(path: str) -> str: + if len(path) > 1 and path.endswith("/"): + return path.rstrip("/") + return path or "/" + + +def normalize_config_path(path: str) -> str: + return strip_trailing_slash(path) + + +def case_insensitive_header(headers: Any, key: str) -> str: + """Read a header from websockets/http test stubs without assuming casing.""" + try: + value = headers.get(key) + except Exception: + value = None + if value is None: + try: + value = headers.get(key.lower()) + except Exception: + value = None + return str(value or "").strip() + + +def safe_host_header(value: str) -> str: + """Return a safe Host header value, or empty when it should not be echoed.""" + value = value.strip() + if not value: + return "" + if re.fullmatch(r"\[[0-9A-Fa-f:.]+\](?::\d{1,5})?", value): + return value + if re.fullmatch(r"[A-Za-z0-9.-]+(?::\d{1,5})?", value): + return value + return "" + + +def host_for_url(host: str, port: int) -> str: + host = host.strip() + if host in ("0.0.0.0", "::"): + host = "127.0.0.1" + if ":" in host and not host.startswith("["): + host = f"[{host}]" + return f"{host}:{port}" + + +def http_json_response(data: dict[str, Any], *, status: int = 200) -> Response: + body = json.dumps(data, ensure_ascii=False).encode("utf-8") + headers = Headers( + [ + ("Date", email.utils.formatdate(usegmt=True)), + ("Connection", "close"), + ("Content-Length", str(len(body))), + ("Content-Type", "application/json; charset=utf-8"), + ] + ) + reason = http.HTTPStatus(status).phrase + return Response(status, reason, headers, body) + + +def http_response( + body: bytes, + *, + status: int = 200, + content_type: str = "text/plain; charset=utf-8", + extra_headers: list[tuple[str, str]] | None = None, +) -> Response: + headers = [ + ("Date", email.utils.formatdate(usegmt=True)), + ("Connection", "close"), + ("Content-Length", str(len(body))), + ("Content-Type", content_type), + ] + if extra_headers: + headers.extend(extra_headers) + reason = http.HTTPStatus(status).phrase + return Response(status, reason, Headers(headers), body) + + +def http_error(status: int, message: str | None = None) -> Response: + body = (message or http.HTTPStatus(status).phrase).encode("utf-8") + return http_response(body, status=status) + + +def parse_request_path(path_with_query: str) -> tuple[str, QueryParams]: + """Parse normalized path and query parameters in one pass.""" + parsed = urlparse("ws://x" + path_with_query) + path = strip_trailing_slash(parsed.path or "/") + return path, parse_qs(parsed.query, keep_blank_values=True) + + +def normalize_http_path(path_with_query: str) -> str: + return parse_request_path(path_with_query)[0] + + +def parse_query(path_with_query: str) -> QueryParams: + return parse_request_path(path_with_query)[1] + + +def query_first(query: QueryParams, key: str) -> str | None: + values = query.get(key) + return values[0] if values else None + + +def is_localhost(connection: Any) -> bool: + addr = getattr(connection, "remote_address", None) + if not addr: + return False + host = addr[0] if isinstance(addr, tuple) else addr + if not isinstance(host, str): + return False + if host.startswith("::ffff:"): + host = host[7:] + return host in {"127.0.0.1", "::1", "localhost"} + + +def bearer_token(headers: Any) -> str | None: + auth = headers.get("Authorization") or headers.get("authorization") + if auth and auth.lower().startswith("bearer "): + return auth[7:].strip() or None + return None + + +def issue_route_secret_matches(headers: Any, configured_secret: str) -> bool: + if not configured_secret: + return True + authorization = headers.get("Authorization") or headers.get("authorization") + if authorization and authorization.lower().startswith("bearer "): + supplied = authorization[7:].strip() + return hmac.compare_digest(supplied, configured_secret) + header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth") + if not header_token: + return False + return hmac.compare_digest(header_token.strip(), configured_secret) diff --git a/nanobot/webui/media_api.py b/nanobot/webui/media_api.py index 845ecd903..f8292d40d 100644 --- a/nanobot/webui/media_api.py +++ b/nanobot/webui/media_api.py @@ -4,10 +4,8 @@ from __future__ import annotations import base64 import binascii -import email.utils import hashlib import hmac -import http import mimetypes import re import shutil @@ -16,12 +14,20 @@ from collections.abc import Callable from pathlib import Path from typing import Any -from websockets.datastructures import Headers from websockets.http11 import Request as WsRequest from websockets.http11 import Response from nanobot.config.paths import get_media_dir from nanobot.utils.helpers import safe_filename +from nanobot.webui.http_utils import ( + case_insensitive_header as _case_insensitive_header, +) +from nanobot.webui.http_utils import ( + http_error as _http_error, +) +from nanobot.webui.http_utils import ( + http_response as _http_response, +) MediaDirProvider = Callable[[str | None], Path] SignedMediaPath = Callable[[Path], dict[str, str] | None] @@ -67,43 +73,6 @@ _SVG_MEDIA_HEADERS: tuple[tuple[str, str], ...] = ( _BYTE_RANGE_RE = re.compile(r"^bytes=(\d*)-(\d*)$") -def _http_response( - body: bytes, - *, - status: int = 200, - content_type: str = "text/plain; charset=utf-8", - extra_headers: list[tuple[str, str]] | None = None, -) -> Response: - headers = [ - ("Date", email.utils.formatdate(usegmt=True)), - ("Connection", "close"), - ("Content-Length", str(len(body))), - ("Content-Type", content_type), - ] - if extra_headers: - headers.extend(extra_headers) - reason = http.HTTPStatus(status).phrase - return Response(status, reason, Headers(headers), body) - - -def _http_error(status: int, message: str | None = None) -> Response: - body = (message or http.HTTPStatus(status).phrase).encode("utf-8") - return _http_response(body, status=status) - - -def _case_insensitive_header(headers: Any, key: str) -> str: - try: - value = headers.get(key) - except Exception: - value = None - if value is None: - try: - value = headers.get(key.lower()) - except Exception: - value = None - return str(value or "").strip() - - def _parse_single_byte_range(range_header: str, size: int) -> tuple[int, int]: """Parse a single HTTP byte range for signed media responses.""" if size <= 0 or "," in range_header: diff --git a/nanobot/webui/media_gateway.py b/nanobot/webui/media_gateway.py new file mode 100644 index 000000000..27109fa32 --- /dev/null +++ b/nanobot/webui/media_gateway.py @@ -0,0 +1,92 @@ +"""Media gateway services shared by WebUI HTTP routes and WebSocket frames.""" + +from __future__ import annotations + +import secrets +from collections.abc import Callable +from pathlib import Path +from typing import Any + +from websockets.http11 import Request as WsRequest +from websockets.http11 import Response + +from nanobot.config.paths import get_media_dir +from nanobot.webui.media_api import ( + attach_signed_media_urls, + serve_signed_media, + sign_media_path, + sign_or_stage_media_path, + signed_media_attachments, +) +from nanobot.webui.transcript import rewrite_local_markdown_images + + +class WebUIMediaGateway: + """Own media URL signing and WebUI markdown/media augmentation.""" + + def __init__( + self, + *, + workspace_path: Path, + logger: Any, + media_dir: Callable[[str | None], Path] | None = None, + secret: bytes | None = None, + ) -> None: + self.workspace_path = workspace_path + self.logger = logger + self._media_dir = media_dir or (lambda channel=None: get_media_dir(channel)) + self.secret = secret or secrets.token_bytes(32) + + def serve_signed_media( + self, + sig: str, + payload: str, + *, + request: WsRequest | None = None, + ) -> Response: + return serve_signed_media( + sig, + payload, + secret=self.secret, + request=request, + media_dir=self._media_dir, + ) + + def sign_media_path(self, abs_path: Path) -> str | None: + return sign_media_path( + abs_path, + secret=self.secret, + media_dir=self._media_dir, + ) + + def sign_or_stage_media_path(self, path: Path) -> dict[str, str] | None: + return sign_or_stage_media_path( + path, + secret=self.secret, + media_dir=self._media_dir, + logger=self.logger, + ) + + def rewrite_local_markdown_images( + self, + text: str, + *, + workspace_path: Path | None = None, + ) -> str: + return rewrite_local_markdown_images( + text, + workspace_path=workspace_path or self.workspace_path, + sign_path=self.sign_or_stage_media_path, + ) + + def augment_media_urls(self, payload: dict[str, Any]) -> None: + attach_signed_media_urls(payload, sign_path=self.sign_media_path) + + def augment_transcript_media(self, paths: list[str]) -> list[dict[str, Any]]: + return signed_media_attachments( + paths, + sign_path=self.sign_or_stage_media_path, + ) + + def augment_transcript_user_media(self, paths: list[str]) -> list[dict[str, Any]]: + return self.augment_transcript_media(paths) diff --git a/nanobot/webui/ws_http.py b/nanobot/webui/ws_http.py index 1d6f35b5e..89f5e7b12 100644 --- a/nanobot/webui/ws_http.py +++ b/nanobot/webui/ws_http.py @@ -9,187 +9,70 @@ Also houses shared HTTP utility functions used by both this module and from __future__ import annotations -import email.utils -import hmac -import http import json import mimetypes import re -import secrets -import time from collections.abc import Callable from pathlib import Path from typing import TYPE_CHECKING, Any -from urllib.parse import parse_qs, urlparse from loguru import logger -from websockets.datastructures import Headers from websockets.http11 import Request as WsRequest from websockets.http11 import Response from nanobot.command.builtin import builtin_command_palette -from nanobot.config.paths import get_media_dir from nanobot.utils.subagent_channel_display import scrub_subagent_messages_for_channel -from nanobot.webui.media_api import ( - serve_signed_media, - sign_media_path, - sign_or_stage_media_path, +from nanobot.webui.gateway_tokens import GatewayTokenStore, token_response_payload +from nanobot.webui.http_utils import ( + case_insensitive_header as _case_insensitive_header, ) +from nanobot.webui.http_utils import ( + host_for_url as _host_for_url, +) +from nanobot.webui.http_utils import ( + http_error as _http_error, +) +from nanobot.webui.http_utils import ( + http_json_response as _http_json_response, +) +from nanobot.webui.http_utils import ( + http_response as _http_response, +) +from nanobot.webui.http_utils import ( + is_localhost as _is_localhost, +) +from nanobot.webui.http_utils import ( + issue_route_secret_matches as _issue_route_secret_matches, +) +from nanobot.webui.http_utils import ( + normalize_config_path as _normalize_config_path, +) +from nanobot.webui.http_utils import ( + parse_query as _parse_query, +) +from nanobot.webui.http_utils import ( + parse_request_path as _parse_request_path, +) +from nanobot.webui.http_utils import ( + query_first as _query_first, +) +from nanobot.webui.http_utils import ( + safe_host_header as _safe_host_header, +) +from nanobot.webui.media_gateway import WebUIMediaGateway from nanobot.webui.sidebar_state import ( read_webui_sidebar_state, write_webui_sidebar_state, ) from nanobot.webui.thread_disk import delete_webui_thread -from nanobot.webui.transcript import ( - build_webui_thread_response, - rewrite_local_markdown_images, -) +from nanobot.webui.transcript import build_webui_thread_response +from nanobot.webui.workspaces import WebUIWorkspaceController if TYPE_CHECKING: from nanobot.bus.queue import MessageBus from nanobot.session.manager import SessionManager -# --------------------------------------------------------------------------- -# Shared HTTP utility functions (imported by websocket.py) -# --------------------------------------------------------------------------- - - -def _strip_trailing_slash(path: str) -> str: - if len(path) > 1 and path.endswith("/"): - return path.rstrip("/") - return path or "/" - - -def _normalize_config_path(path: str) -> str: - return _strip_trailing_slash(path) - - -def _case_insensitive_header(headers: Any, key: str) -> str: - """Read a header from websockets/http test stubs without assuming casing.""" - try: - value = headers.get(key) - except Exception: - value = None - if value is None: - try: - value = headers.get(key.lower()) - except Exception: - value = None - return str(value or "").strip() - - -def _safe_host_header(value: str) -> str: - """Return a safe Host header value, or empty when it should not be echoed.""" - value = value.strip() - if not value: - return "" - if re.fullmatch(r"\[[0-9A-Fa-f:.]+\](?::\d{1,5})?", value): - return value - if re.fullmatch(r"[A-Za-z0-9.-]+(?::\d{1,5})?", value): - return value - return "" - - -def _host_for_url(host: str, port: int) -> str: - host = host.strip() - if host in ("0.0.0.0", "::"): - host = "127.0.0.1" - if ":" in host and not host.startswith("["): - host = f"[{host}]" - return f"{host}:{port}" - - -def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response: - body = json.dumps(data, ensure_ascii=False).encode("utf-8") - headers = Headers( - [ - ("Date", email.utils.formatdate(usegmt=True)), - ("Connection", "close"), - ("Content-Length", str(len(body))), - ("Content-Type", "application/json; charset=utf-8"), - ] - ) - reason = http.HTTPStatus(status).phrase - return Response(status, reason, headers, body) - - -def _http_response( - body: bytes, - *, - status: int = 200, - content_type: str = "text/plain; charset=utf-8", - extra_headers: list[tuple[str, str]] | None = None, -) -> Response: - headers = [ - ("Date", email.utils.formatdate(usegmt=True)), - ("Connection", "close"), - ("Content-Length", str(len(body))), - ("Content-Type", content_type), - ] - if extra_headers: - headers.extend(extra_headers) - reason = http.HTTPStatus(status).phrase - return Response(status, reason, Headers(headers), body) - - -def _http_error(status: int, message: str | None = None) -> Response: - body = (message or http.HTTPStatus(status).phrase).encode("utf-8") - return _http_response(body, status=status) - - -def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]: - """Parse normalized path and query parameters in one pass.""" - parsed = urlparse("ws://x" + path_with_query) - path = _strip_trailing_slash(parsed.path or "/") - return path, parse_qs(parsed.query, keep_blank_values=True) - - -def _normalize_http_path(path_with_query: str) -> str: - return _parse_request_path(path_with_query)[0] - - -def _parse_query(path_with_query: str) -> dict[str, list[str]]: - return _parse_request_path(path_with_query)[1] - - -def _query_first(query: dict[str, list[str]], key: str) -> str | None: - values = query.get(key) - return values[0] if values else None - - -def _is_localhost(connection: Any) -> bool: - addr = getattr(connection, "remote_address", None) - if not addr: - return False - host = addr[0] if isinstance(addr, tuple) else addr - if not isinstance(host, str): - return False - if host.startswith("::ffff:"): - host = host[7:] - return host in {"127.0.0.1", "::1", "localhost"} - - -def _bearer_token(headers: Any) -> str | None: - auth = headers.get("Authorization") or headers.get("authorization") - if auth and auth.lower().startswith("bearer "): - return auth[7:].strip() or None - return None - - -def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool: - if not configured_secret: - return True - authorization = headers.get("Authorization") or headers.get("authorization") - if authorization and authorization.lower().startswith("bearer "): - supplied = authorization[7:].strip() - return hmac.compare_digest(supplied, configured_secret) - header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth") - if not header_token: - return False - return hmac.compare_digest(header_token.strip(), configured_secret) - - def _decode_api_key(raw_key: str) -> str | None: from urllib.parse import unquote @@ -234,48 +117,36 @@ def _resolve_bootstrap_model_name( class GatewayHTTPHandler: """Handles all HTTP routes served alongside the WebSocket endpoint. - Owns token management, session API, media API, static file serving, - and delegates settings routes to ``WebUISettingsRouter``. + Routes HTTP requests and delegates stateful work to explicit gateway + services owned by the composition layer. """ - _MAX_ISSUED_TOKENS = 10_000 - def __init__( self, *, config: Any, # WebSocketConfig session_manager: SessionManager | None, static_dist_path: Path | None, - workspace_path: Path, runtime_model_name: Callable[[], str | None] | None, runtime_surface: str, runtime_capabilities_overrides: dict[str, Any] | None, bus: MessageBus, + tokens: GatewayTokenStore, + media: WebUIMediaGateway, + workspaces: WebUIWorkspaceController, log: Any = logger, ) -> None: self.config = config self.session_manager = session_manager self.static_dist_path = static_dist_path - self.workspace_path = workspace_path self.runtime_model_name = runtime_model_name self.bus = bus + self.tokens = tokens + self.media = media + self.workspaces = workspaces self._log = log self._runtime_surface = runtime_surface - self.issued_tokens: dict[str, float] = {} - self.api_tokens: dict[str, float] = {} - self.media_secret: bytes = secrets.token_bytes(32) - - # Workspace controller - from nanobot.webui.workspaces import WebUIWorkspaceController - - self.workspaces = WebUIWorkspaceController( - session_manager=session_manager, - default_workspace=workspace_path, - default_restrict_to_workspace=None, - ) - - # Settings router from nanobot.webui.settings_api import runtime_capabilities as _rc from nanobot.webui.settings_routes import WebUISettingsRouter @@ -294,46 +165,13 @@ class GatewayHTTPHandler: # -- Token management --------------------------------------------------- def check_api_token(self, request: WsRequest) -> bool: - self._purge_expired_api_tokens() - token = _bearer_token(request.headers) or _query_first( - _parse_query(request.path), "token" - ) - if not token: - return False - expiry = self.api_tokens.get(token) - if expiry is None or time.monotonic() > expiry: - self.api_tokens.pop(token, None) - return False - return True - - def _purge_expired_api_tokens(self) -> None: - now = time.monotonic() - for token_key, expiry in list(self.api_tokens.items()): - if now > expiry: - self.api_tokens.pop(token_key, None) - - def _purge_expired_issued_tokens(self) -> None: - now = time.monotonic() - for token_key, expiry in list(self.issued_tokens.items()): - if now > expiry: - self.issued_tokens.pop(token_key, None) - - def take_issued_token_if_valid(self, token_value: str | None) -> bool: - if not token_value: - return False - self._purge_expired_issued_tokens() - expiry = self.issued_tokens.pop(token_value, None) - if expiry is None: - return False - if time.monotonic() > expiry: - return False - return True + return self.tokens.check_api_token(request) # -- Main dispatch ------------------------------------------------------ async def dispatch(self, connection: Any, request: WsRequest) -> Any | None: """Route an HTTP request. Returns Response or None.""" - got, query = _parse_request_path(request.path) + got, _ = _parse_request_path(request.path) # Token issue endpoint if self.config.token_issue_path: @@ -389,18 +227,14 @@ class GatewayHTTPHandler: "token_issue_path is set but token_issue_secret is empty; " "any client can obtain connection tokens — set token_issue_secret for production." ) - self._purge_expired_issued_tokens() - if len(self.issued_tokens) >= self._MAX_ISSUED_TOKENS: + if not self.tokens.can_issue(): self._log.error( "too many outstanding issued tokens ({}), rejecting issuance", - len(self.issued_tokens), + len(self.tokens.issued_tokens), ) return _http_json_response({"error": "too many outstanding tokens"}, status=429) - token_value = f"nbwt_{secrets.token_urlsafe(32)}" - self.issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s) - return _http_json_response( - {"token": token_value, "expires_in": self.config.token_ttl_s} - ) + token_value = self.tokens.issue_token(self.config.token_ttl_s) + return _http_json_response(token_response_payload(token_value, self.config.token_ttl_s)) # -- Bootstrap ---------------------------------------------------------- @@ -412,21 +246,13 @@ class GatewayHTTPHandler: elif not _is_localhost(connection): return _http_error(403, "bootstrap is localhost-only") - self._purge_expired_issued_tokens() - self._purge_expired_api_tokens() - if ( - len(self.issued_tokens) >= self._MAX_ISSUED_TOKENS - or len(self.api_tokens) >= self._MAX_ISSUED_TOKENS - ): + if not self.tokens.can_issue(include_api_token=True): return _http_response( json.dumps({"error": "too many outstanding tokens"}).encode("utf-8"), status=429, content_type="application/json; charset=utf-8", ) - token = f"nbwt_{secrets.token_urlsafe(32)}" - expiry = time.monotonic() + float(self.config.token_ttl_s) - self.issued_tokens[token] = expiry - self.api_tokens[token] = expiry + token = self.tokens.issue_token(self.config.token_ttl_s, api_token=True) ws_url = self._bootstrap_ws_url(request) expected_path = _normalize_config_path(self.config.path) @@ -510,7 +336,7 @@ class GatewayHTTPHandler: messages = data.get("messages") if isinstance(messages, list): scrub_subagent_messages_for_channel(messages) - self._augment_media_urls(data) + self.media.augment_media_urls(data) return _http_json_response(data) def _handle_webui_thread_get(self, request: WsRequest, key: str) -> Response: @@ -524,11 +350,11 @@ class GatewayHTTPHandler: scope = self.workspaces.scope_for_session_key(decoded_key) data = build_webui_thread_response( decoded_key, - augment_user_media=self._augment_transcript_user_media, - augment_assistant_text=lambda text: rewrite_local_markdown_images( + augment_user_media=self.media.augment_transcript_media, + augment_assistant_media=self.media.augment_transcript_media, + augment_assistant_text=lambda text: self.media.rewrite_local_markdown_images( text, workspace_path=scope.project_path, - sign_path=self.sign_or_stage_media_path, ), ) if data is None: @@ -561,34 +387,10 @@ class GatewayHTTPHandler: def _handle_media_fetch( self, sig: str, payload: str, request: WsRequest | None = None ) -> Response: - return serve_signed_media( + return self.media.serve_signed_media( sig, payload, - secret=self.media_secret, request=request, - media_dir=lambda channel=None: get_media_dir(channel), - ) - - def sign_media_path(self, abs_path: Path) -> str | None: - return sign_media_path( - abs_path, - secret=self.media_secret, - media_dir=lambda channel=None: get_media_dir(channel), - ) - - def sign_or_stage_media_path(self, path: Path) -> dict[str, str] | None: - return sign_or_stage_media_path( - path, - secret=self.media_secret, - media_dir=lambda channel=None: get_media_dir(channel), - logger=self._log, - ) - - def rewrite_local_markdown_images(self, text: str) -> str: - return rewrite_local_markdown_images( - text, - workspace_path=self.workspace_path, - sign_path=self.sign_or_stage_media_path, ) # -- Misc routes -------------------------------------------------------- @@ -688,44 +490,5 @@ class GatewayHTTPHandler: extra_headers=[("Cache-Control", cache)], ) - # -- Media helpers (called by WebSocketChannel.send) -------------------- - - def _augment_media_urls(self, payload: dict[str, Any]) -> None: - messages = payload.get("messages") - if not isinstance(messages, list): - return - for msg in messages: - if not isinstance(msg, dict): - continue - media = msg.get("media") - if not isinstance(media, list) or not media: - continue - urls: list[dict[str, str]] = [] - for entry in media: - if not isinstance(entry, str) or not entry: - continue - signed = self.sign_media_path(Path(entry)) - if signed is None: - continue - urls.append({"url": signed, "name": Path(entry).name}) - if urls: - msg["media_urls"] = urls - msg.pop("media", None) - - def _augment_transcript_user_media(self, paths: list[str]) -> list[dict[str, Any]]: - out: list[dict[str, Any]] = [] - for pstr in paths: - path = Path(pstr) - att = self.sign_or_stage_media_path(path) - if att is None: - continue - mime, _ = mimetypes.guess_type(path.name) - kind = "video" if mime and mime.startswith("video/") else "image" - out.append( - {"kind": kind, "url": att["url"], "name": att.get("name", path.name)}, - ) - return out - - def _is_websocket_channel_session_key(key: str) -> bool: return key.startswith("websocket:") diff --git a/tests/channels/test_channel_manager_reasoning.py b/tests/channels/test_channel_manager_reasoning.py index b02262751..5df1b3fbf 100644 --- a/tests/channels/test_channel_manager_reasoning.py +++ b/tests/channels/test_channel_manager_reasoning.py @@ -65,6 +65,32 @@ def manager() -> ChannelManager: return mgr +def test_websocket_gateway_uses_configured_workspace_restriction(tmp_path, monkeypatch): + monkeypatch.setattr( + "nanobot.webui.workspaces.read_webui_default_access_mode", + lambda: "default", + ) + config = Config.model_validate( + { + "agents": {"defaults": {"workspace": str(tmp_path)}}, + "tools": {"restrictToWorkspace": True}, + "channels": { + "websocket": { + "enabled": True, + "websocketRequiresToken": False, + }, + }, + } + ) + + mgr = ChannelManager(config, MessageBus(), webui_static_dist=False) + channel = mgr.channels["websocket"] + + scope = channel.gateway.workspaces.default_scope() + assert scope.project_path == tmp_path + assert scope.restrict_to_workspace is True + + @pytest.mark.asyncio async def test_reasoning_delta_routes_to_send_reasoning_delta(manager): channel = manager.channels["mock"] diff --git a/tests/channels/test_websocket_channel.py b/tests/channels/test_websocket_channel.py index 51fb28e7d..d6f047de3 100644 --- a/tests/channels/test_websocket_channel.py +++ b/tests/channels/test_websocket_channel.py @@ -20,20 +20,30 @@ from nanobot.channels.websocket import ( WebSocketChannel, WebSocketConfig, _is_valid_chat_id, - _issue_route_secret_matches, - _normalize_config_path, - _normalize_http_path, _parse_envelope, _parse_inbound_payload, - _parse_query, - _parse_request_path, publish_runtime_model_update, ) -from nanobot.webui.ws_http import GatewayHTTPHandler from nanobot.config.loader import load_config, save_config from nanobot.config.schema import Config, ModelPresetConfig from nanobot.session import webui_turns as wth from nanobot.session.manager import SessionManager +from nanobot.webui.gateway_services import GatewayServices, build_gateway_services +from nanobot.webui.http_utils import ( + issue_route_secret_matches as _issue_route_secret_matches, +) +from nanobot.webui.http_utils import ( + normalize_config_path as _normalize_config_path, +) +from nanobot.webui.http_utils import ( + normalize_http_path as _normalize_http_path, +) +from nanobot.webui.http_utils import ( + parse_query as _parse_query, +) +from nanobot.webui.http_utils import ( + parse_request_path as _parse_request_path, +) from nanobot.webui.settings_api import settings_payload, update_provider_settings # -- Shared helpers (aligned with test_websocket_integration.py) --------------- @@ -52,34 +62,36 @@ def _ch(bus: Any, **kw: Any) -> WebSocketChannel: } cfg.update(kw) parsed = WebSocketConfig.model_validate(cfg) - http_handler = GatewayHTTPHandler( + gateway = build_gateway_services( config=parsed, + bus=bus, session_manager=None, static_dist_path=None, workspace_path=Path.cwd(), + default_restrict_to_workspace=False, runtime_model_name=None, runtime_surface="browser", runtime_capabilities_overrides=None, - bus=bus, ) - return WebSocketChannel(cfg, bus, http_handler=http_handler) + return WebSocketChannel(cfg, bus, gateway=gateway) -def _basic_handler(bus: Any, **kw: Any) -> GatewayHTTPHandler: +def _basic_handler(bus: Any, **kw: Any) -> GatewayServices: cfg = WebSocketConfig.model_validate({ "enabled": True, "allowFrom": ["*"], "host": "127.0.0.1", "port": _PORT, "path": "/ws", "websocketRequiresToken": False, }) - return GatewayHTTPHandler( + return build_gateway_services( config=cfg, + bus=bus, session_manager=kw.get("session_manager"), static_dist_path=None, workspace_path=kw.get("workspace_path", Path.cwd()), + default_restrict_to_workspace=kw.get("default_restrict_to_workspace", False), runtime_model_name=None, runtime_surface=kw.get("runtime_surface", "browser"), runtime_capabilities_overrides=kw.get("runtime_capabilities_overrides"), - bus=bus, ) @@ -194,7 +206,7 @@ def test_ssl_context_requires_both_cert_and_key_files() -> None: channel = WebSocketChannel( {"enabled": True, "allowFrom": ["*"], "sslCertfile": "/tmp/c.pem", "sslKeyfile": ""}, bus, - http_handler=_basic_handler(bus), + gateway=_basic_handler(bus), ) with pytest.raises(ValueError, match="ssl_certfile and ssl_keyfile"): channel._build_ssl_context() @@ -310,7 +322,7 @@ async def test_webui_message_scope_inherits_persisted_session_scope( channel = WebSocketChannel( {"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"}, bus, - http_handler=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace), + gateway=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace), ) conn = AsyncMock() conn.remote_address = ("127.0.0.1", 50123) @@ -356,7 +368,7 @@ async def test_webui_scope_expands_home_project_path( channel = WebSocketChannel( {"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"}, bus, - http_handler=_basic_handler(bus, session_manager=SessionManager(tmp_path / "sessions"), workspace_path=default_workspace), + gateway=_basic_handler(bus, session_manager=SessionManager(tmp_path / "sessions"), workspace_path=default_workspace), ) conn = AsyncMock() conn.remote_address = ("127.0.0.1", 50123) @@ -393,7 +405,7 @@ async def test_webui_scope_rejects_missing_project_path(bus: MagicMock, tmp_path channel = WebSocketChannel( {"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"}, bus, - http_handler=_basic_handler(bus, session_manager=SessionManager(tmp_path / "sessions"), workspace_path=default_workspace), + gateway=_basic_handler(bus, session_manager=SessionManager(tmp_path / "sessions"), workspace_path=default_workspace), ) conn = AsyncMock() conn.remote_address = ("127.0.0.1", 50123) @@ -430,7 +442,7 @@ async def test_webui_scope_rejects_running_scope_change(bus: MagicMock, tmp_path channel = WebSocketChannel( {"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"}, bus, - http_handler=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace), + gateway=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace), ) conn = AsyncMock() conn.remote_address = ("127.0.0.1", 50123) @@ -486,7 +498,7 @@ async def test_webui_set_workspace_scope_rejects_running_chat(bus: MagicMock, tm channel = WebSocketChannel( {"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"}, bus, - http_handler=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace), + gateway=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace), ) conn = AsyncMock() conn.remote_address = ("127.0.0.1", 50123) @@ -545,7 +557,7 @@ async def test_webui_scope_rejects_non_loopback_custom_scope(bus: MagicMock, tmp channel = WebSocketChannel( {"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"}, bus, - http_handler=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace), + gateway=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace), ) conn = AsyncMock() conn.remote_address = ("203.0.113.8", 50123) @@ -574,7 +586,7 @@ async def test_webui_scope_rejects_non_loopback_custom_scope(bus: MagicMock, tmp @pytest.mark.asyncio async def test_send_delivers_json_message_with_media_and_reply() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") @@ -600,7 +612,7 @@ async def test_send_delivers_json_message_with_media_and_reply() -> None: @pytest.mark.asyncio async def test_send_broadcasts_runtime_model_updates() -> None: bus = MessageBus() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") @@ -647,8 +659,8 @@ async def test_send_stages_external_media_as_signed_url(monkeypatch, tmp_path) - return ws_media if channel == "websocket" else media_root monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir) - monkeypatch.setattr("nanobot.webui.ws_http.get_media_dir", fake_media_dir) - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + monkeypatch.setattr("nanobot.webui.media_gateway.get_media_dir", fake_media_dir) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") @@ -671,7 +683,7 @@ async def test_send_stages_external_media_as_signed_url(monkeypatch, tmp_path) - @pytest.mark.asyncio async def test_send_missing_connection_is_noop_without_error() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) msg = OutboundMessage(channel="websocket", chat_id="missing", content="x") await channel.send(msg) @@ -679,7 +691,7 @@ async def test_send_missing_connection_is_noop_without_error() -> None: @pytest.mark.asyncio async def test_send_removes_connection_on_connection_closed() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() mock_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True) channel._attach(mock_ws, "chat-1") @@ -694,7 +706,7 @@ async def test_send_removes_connection_on_connection_closed() -> None: @pytest.mark.asyncio async def test_send_progress_includes_structured_tool_events() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") @@ -742,7 +754,7 @@ async def test_send_progress_includes_structured_tool_events() -> None: @pytest.mark.asyncio async def test_send_file_edit_progress_uses_file_edit_event() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") @@ -791,7 +803,7 @@ async def test_send_file_edit_progress_uses_file_edit_event() -> None: @pytest.mark.asyncio async def test_send_progress_includes_agent_ui_blob() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") @@ -815,7 +827,7 @@ async def test_send_progress_includes_agent_ui_blob() -> None: @pytest.mark.asyncio async def test_send_delta_removes_connection_on_connection_closed() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() mock_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True) channel._attach(mock_ws, "chat-1") @@ -829,7 +841,7 @@ async def test_send_delta_removes_connection_on_connection_closed() -> None: @pytest.mark.asyncio async def test_send_delta_emits_delta_and_stream_end() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") @@ -862,11 +874,11 @@ async def test_send_delta_stream_end_rewrites_local_markdown_image(monkeypatch, return path monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir) - monkeypatch.setattr("nanobot.webui.ws_http.get_media_dir", fake_media_dir) + monkeypatch.setattr("nanobot.webui.media_gateway.get_media_dir", fake_media_dir) channel = WebSocketChannel( {"enabled": True, "allowFrom": ["*"], "streaming": True}, bus, - http_handler=_basic_handler(bus, workspace_path=workspace), + gateway=_basic_handler(bus, workspace_path=workspace), ) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") @@ -895,11 +907,11 @@ async def test_send_delta_stream_end_rewrites_inline_final_text(monkeypatch, tmp return path monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir) - monkeypatch.setattr("nanobot.webui.ws_http.get_media_dir", fake_media_dir) + monkeypatch.setattr("nanobot.webui.media_gateway.get_media_dir", fake_media_dir) channel = WebSocketChannel( {"enabled": True, "allowFrom": ["*"], "streaming": True}, bus, - http_handler=_basic_handler(bus, workspace_path=workspace), + gateway=_basic_handler(bus, workspace_path=workspace), ) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") @@ -919,7 +931,7 @@ async def test_send_delta_stream_end_rewrites_inline_final_text(monkeypatch, tmp @pytest.mark.asyncio async def test_send_reasoning_delta_emits_streaming_frame() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") @@ -940,7 +952,7 @@ async def test_send_reasoning_delta_emits_streaming_frame() -> None: @pytest.mark.asyncio async def test_send_reasoning_end_emits_close_frame() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") @@ -956,7 +968,7 @@ async def test_send_reasoning_one_shot_expands_to_delta_plus_end() -> None: the base implementation must produce one delta and one end so the WebUI sees the same shape either way.""" bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") @@ -978,7 +990,7 @@ async def test_send_reasoning_one_shot_expands_to_delta_plus_end() -> None: @pytest.mark.asyncio async def test_send_reasoning_delta_drops_empty_chunks() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") @@ -990,7 +1002,7 @@ async def test_send_reasoning_delta_drops_empty_chunks() -> None: @pytest.mark.asyncio async def test_send_reasoning_without_subscribers_is_noop() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) await channel.send_reasoning_delta("unattached", "thinking", None) await channel.send_reasoning_end("unattached", None) @@ -1000,7 +1012,7 @@ async def test_send_reasoning_without_subscribers_is_noop() -> None: @pytest.mark.asyncio async def test_send_turn_end_emits_turn_end_event() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") @@ -1019,7 +1031,7 @@ async def test_send_turn_end_emits_turn_end_event() -> None: @pytest.mark.asyncio async def test_send_turn_end_includes_latency_ms_when_present() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") @@ -1038,7 +1050,7 @@ async def test_send_turn_end_includes_latency_ms_when_present() -> None: @pytest.mark.asyncio async def test_send_turn_end_includes_goal_state_when_present() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") @@ -1058,7 +1070,7 @@ async def test_send_turn_end_includes_goal_state_when_present() -> None: @pytest.mark.asyncio async def test_send_goal_status_running_emits_event_with_started_at() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") @@ -1086,7 +1098,7 @@ async def test_send_goal_status_running_emits_event_with_started_at() -> None: @pytest.mark.asyncio async def test_send_goal_status_idle_omits_started_at() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") @@ -1109,7 +1121,7 @@ async def test_send_goal_status_idle_omits_started_at() -> None: @pytest.mark.asyncio async def test_send_goal_state_emits_blob_per_chat() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) mock_a = AsyncMock() mock_b = AsyncMock() channel._attach(mock_a, "chat-a") @@ -1138,10 +1150,9 @@ async def test_send_goal_state_emits_blob_per_chat() -> None: @pytest.mark.asyncio async def test_maybe_push_active_goal_state_noop_without_session_manager() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") - channel._session_manager = None await channel._maybe_push_active_goal_state("chat-1") mock_ws.send.assert_not_called() @@ -1149,10 +1160,13 @@ async def test_maybe_push_active_goal_state_noop_without_session_manager() -> No @pytest.mark.asyncio async def test_maybe_push_active_goal_state_skips_when_no_goal_on_disk() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) sm = MagicMock() sm.read_session_file.return_value = None - channel._session_manager = sm + channel = WebSocketChannel( + {"enabled": True, "allowFrom": ["*"]}, + bus, + gateway=_basic_handler(bus, session_manager=sm), + ) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") await channel._maybe_push_active_goal_state("chat-1") @@ -1162,7 +1176,6 @@ async def test_maybe_push_active_goal_state_skips_when_no_goal_on_disk() -> None @pytest.mark.asyncio async def test_maybe_push_active_goal_state_notifies_when_goal_active_on_disk() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) sm = MagicMock() sm.read_session_file.return_value = { "metadata": { @@ -1174,7 +1187,11 @@ async def test_maybe_push_active_goal_state_notifies_when_goal_active_on_disk() }, "messages": [], } - channel._session_manager = sm + channel = WebSocketChannel( + {"enabled": True, "allowFrom": ["*"]}, + bus, + gateway=_basic_handler(bus, session_manager=sm), + ) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") await channel._maybe_push_active_goal_state("chat-1") @@ -1190,7 +1207,7 @@ async def test_maybe_push_active_goal_state_notifies_when_goal_active_on_disk() @pytest.mark.asyncio async def test_maybe_push_turn_run_wall_clock_skips_when_no_active_turn() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") from nanobot.session import webui_turns as wth @@ -1203,7 +1220,7 @@ async def test_maybe_push_turn_run_wall_clock_skips_when_no_active_turn() -> Non @pytest.mark.asyncio async def test_maybe_push_turn_run_wall_clock_replays_running() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") from nanobot.session import webui_turns as wth @@ -1228,7 +1245,7 @@ async def test_maybe_push_turn_run_wall_clock_replays_running() -> None: @pytest.mark.asyncio async def test_send_session_updated_emits_session_updated_event() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") @@ -1247,7 +1264,7 @@ async def test_send_session_updated_emits_session_updated_event() -> None: @pytest.mark.asyncio async def test_send_session_updated_includes_scope_when_present() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() channel._attach(mock_ws, "chat-1") @@ -1266,7 +1283,7 @@ async def test_send_session_updated_includes_scope_when_present() -> None: @pytest.mark.asyncio async def test_send_non_connection_closed_exception_is_raised() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) mock_ws = AsyncMock() mock_ws.send.side_effect = RuntimeError("unexpected") channel._attach(mock_ws, "chat-1") @@ -1279,7 +1296,7 @@ async def test_send_non_connection_closed_exception_is_raised() -> None: @pytest.mark.asyncio async def test_send_delta_missing_connection_is_noop() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus, gateway=_basic_handler(bus)) # No exception, no error — just a no-op await channel.send_delta("nonexistent", "chunk", {"_stream_delta": True, "_stream_id": "s1"}) @@ -1287,7 +1304,7 @@ async def test_send_delta_missing_connection_is_noop() -> None: @pytest.mark.asyncio async def test_stop_is_idempotent() -> None: bus = MagicMock() - channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus)) + channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus)) # stop() before start() should not raise await channel.stop() await channel.stop() @@ -1448,7 +1465,7 @@ async def test_settings_api_returns_safe_subset_and_updates_whitelist( ) channel = _ch(bus, port=port) - channel._api_tokens["tok"] = time.monotonic() + 300 + channel.gateway.tokens.api_tokens["tok"] = time.monotonic() + 300 server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) @@ -1733,7 +1750,7 @@ async def test_settings_api_returns_safe_subset_and_updates_whitelist( async def test_commands_api_returns_slash_command_metadata(bus: MagicMock) -> None: port = 29892 channel = _ch(bus, port=port) - channel._api_tokens["tok"] = time.monotonic() + 300 + channel.gateway.tokens.api_tokens["tok"] = time.monotonic() + 300 server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) @@ -1771,7 +1788,7 @@ async def test_bootstrap_exposes_native_surface(bus: MagicMock) -> None: "websocketRequiresToken": True, }, bus, - http_handler=_basic_handler(bus, runtime_surface="native", runtime_capabilities_overrides={"can_pick_folder": True}), + gateway=_basic_handler(bus, runtime_surface="native", runtime_capabilities_overrides={"can_pick_folder": True}), ) server_task = asyncio.create_task(channel.start()) @@ -1941,8 +1958,9 @@ async def test_token_issue_rejects_when_at_capacity(bus: MagicMock) -> None: try: # Fill issued tokens to capacity - channel._http.issued_tokens = { - f"nbwt_fill_{i}": time.monotonic() + 300 for i in range(channel._http._MAX_ISSUED_TOKENS) + channel.gateway.tokens.issued_tokens = { + f"nbwt_fill_{i}": time.monotonic() + 300 + for i in range(channel.gateway.tokens.max_tokens) } resp = await _http_get( @@ -2299,10 +2317,8 @@ def test_sessions_list_includes_active_run_started_at() -> None: from nanobot.session import webui_turns as wth bus = MagicMock() - channel = _ch(bus) - channel._api_tokens["tok"] = time.monotonic() + 300.0 - channel._session_manager = MagicMock() - channel._session_manager.list_sessions.return_value = [ + session_manager = MagicMock() + session_manager.list_sessions.return_value = [ { "key": "websocket:chat-1", "created_at": "2026-05-19T10:00:00Z", @@ -2317,19 +2333,25 @@ def test_sessions_list_includes_active_run_started_at() -> None: "updated_at": "2026-05-19T10:01:00Z", }, ] + channel = WebSocketChannel( + {"enabled": True, "allowFrom": ["*"]}, + bus, + gateway=_basic_handler(bus, session_manager=session_manager), + ) + channel.gateway.tokens.api_tokens["tok"] = time.monotonic() + 300.0 wth._WEBSOCKET_TURN_WALL_STARTED_AT.clear() try: wth._WEBSOCKET_TURN_WALL_STARTED_AT["chat-1"] = 1_700_000_000.0 req = Request("/api/sessions", Headers([("Authorization", "Bearer tok")])) - resp = channel._handle_sessions_list(req) + resp = channel.gateway.http._handle_sessions_list(req) finally: wth._WEBSOCKET_TURN_WALL_STARTED_AT.clear() assert resp.status_code == 200 body = json.loads(resp.body.decode()) workspace_scope = body["sessions"][0].pop("workspace_scope") - assert workspace_scope["project_path"] == str(channel._workspace_path) + assert workspace_scope["project_path"] == str(channel.gateway.media.workspace_path) assert workspace_scope["access_mode"] in {"restricted", "full"} assert body["sessions"] == [ { @@ -2376,10 +2398,10 @@ def test_handle_webui_thread_get_returns_json(tmp_path, monkeypatch) -> None: append_transcript_object(key, {"event": "user", "chat_id": "c1", "text": "hi"}) bus = MagicMock() channel = _ch(bus) - channel._api_tokens["tok"] = time.monotonic() + 300.0 + channel.gateway.tokens.api_tokens["tok"] = time.monotonic() + 300.0 enc = quote(key, safe="") req = Request(f"/api/sessions/{enc}/webui-thread", Headers([("Authorization", "Bearer tok")])) - resp = channel._handle_webui_thread_get(req, enc) + resp = channel.gateway.http._handle_webui_thread_get(req, enc) assert resp.status_code == 200 body = json.loads(resp.body.decode()) assert body["sessionKey"] == key diff --git a/tests/channels/test_websocket_envelope_media.py b/tests/channels/test_websocket_envelope_media.py index 4d840f263..0b67320da 100644 --- a/tests/channels/test_websocket_envelope_media.py +++ b/tests/channels/test_websocket_envelope_media.py @@ -21,7 +21,7 @@ from nanobot.channels.websocket import ( WebSocketConfig, _extract_data_url_mime, ) -from nanobot.webui.ws_http import GatewayHTTPHandler +from nanobot.webui.gateway_services import build_gateway_services def _tiny_png_data_url() -> str: @@ -45,17 +45,18 @@ def _make_channel() -> WebSocketChannel: bus.publish_inbound = AsyncMock() cfg = {"enabled": True, "allowFrom": ["*"], "websocketRequiresToken": False} parsed = WebSocketConfig.model_validate(cfg) - handler = GatewayHTTPHandler( + gateway = build_gateway_services( config=parsed, + bus=bus, session_manager=None, static_dist_path=None, workspace_path=Path.cwd(), + default_restrict_to_workspace=False, runtime_model_name=None, runtime_surface="browser", runtime_capabilities_overrides=None, - bus=bus, ) - channel = WebSocketChannel(cfg, bus, http_handler=handler) + channel = WebSocketChannel(cfg, bus, gateway=gateway) channel._handle_message = AsyncMock() # type: ignore[method-assign] return channel diff --git a/tests/channels/test_websocket_http_routes.py b/tests/channels/test_websocket_http_routes.py index f58f72781..3eee4074c 100644 --- a/tests/channels/test_websocket_http_routes.py +++ b/tests/channels/test_websocket_http_routes.py @@ -13,7 +13,7 @@ import pytest from nanobot.channels.websocket import WebSocketChannel, WebSocketConfig from nanobot.session.manager import Session, SessionManager -from nanobot.webui.ws_http import GatewayHTTPHandler +from nanobot.webui.gateway_services import GatewayServices, build_gateway_services _PORT = 29900 @@ -25,18 +25,19 @@ def _make_handler( session_manager: SessionManager | None = None, static_dist_path: Path | None = None, runtime_model_name: Any | None = None, -) -> GatewayHTTPHandler: +) -> GatewayServices: config = WebSocketConfig.model_validate(cfg) if isinstance(cfg, dict) else cfg workspace = Path.cwd() - return GatewayHTTPHandler( + return build_gateway_services( config=config, + bus=bus, session_manager=session_manager, static_dist_path=static_dist_path, workspace_path=workspace, + default_restrict_to_workspace=False, runtime_model_name=runtime_model_name, runtime_surface="browser", runtime_capabilities_overrides=None, - bus=bus, ) @@ -58,13 +59,13 @@ def _ch( "websocketRequiresToken": False, } cfg.update(extra) - http_handler = _make_handler( + gateway = _make_handler( cfg, bus, session_manager=session_manager, static_dist_path=static_dist_path, runtime_model_name=runtime_model_name, ) - return WebSocketChannel(cfg, bus, http_handler=http_handler) + return WebSocketChannel(cfg, bus, gateway=gateway) @pytest.fixture() @@ -729,20 +730,20 @@ async def test_api_token_pool_purges_expired(bus: MagicMock, tmp_path: Path) -> channel = _ch(bus, session_manager=sm, port=29908) # Don't start a server — directly inject and validate. import time as _time - channel._http.api_tokens["expired"] = _time.monotonic() - 1 - channel._http.api_tokens["live"] = _time.monotonic() + 60 + channel.gateway.tokens.api_tokens["expired"] = _time.monotonic() - 1 + channel.gateway.tokens.api_tokens["live"] = _time.monotonic() + 60 class _FakeReq: path = "/api/sessions" headers = {"Authorization": "Bearer expired"} - assert channel._http.check_api_token(_FakeReq()) is False + assert channel.gateway.tokens.check_api_token(_FakeReq()) is False class _LiveReq: path = "/api/sessions" headers = {"Authorization": "Bearer live"} - assert channel._http.check_api_token(_LiveReq()) is True + assert channel.gateway.tokens.check_api_token(_LiveReq()) is True class _FakeConn: @@ -797,7 +798,7 @@ def test_wildcard_ipv6_without_auth_raises(bus: MagicMock) -> None: def test_wildcard_ipv6_with_secret_is_valid(bus: MagicMock) -> None: channel = _ch(bus, host="::", tokenIssueSecret="s3cret") - resp = channel._handle_bootstrap( + resp = channel.gateway.http._handle_bootstrap( _REMOTE, _FakeReq({"X-Nanobot-Auth": "s3cret"}) ) assert resp.status_code == 200 @@ -806,7 +807,7 @@ def test_wildcard_ipv6_with_secret_is_valid(bus: MagicMock) -> None: def test_bootstrap_accepts_static_token_as_secret(bus: MagicMock) -> None: """When only token (not token_issue_secret) is set, bootstrap accepts it.""" channel = _ch(bus, host="0.0.0.0", token="static-tok") - resp = channel._handle_bootstrap( + resp = channel.gateway.http._handle_bootstrap( _REMOTE, _FakeReq({"Authorization": "Bearer static-tok"}) ) assert resp.status_code == 200 @@ -816,7 +817,7 @@ def test_bootstrap_accepts_static_token_as_secret(bus: MagicMock) -> None: def test_bootstrap_ws_url_uses_forwarded_https_host(bus: MagicMock) -> None: channel = _ch(bus, host="127.0.0.1", port=29931) - resp = channel._handle_bootstrap( + resp = channel.gateway.http._handle_bootstrap( _LOCAL, _FakeReq({"Host": "nanobot.example", "X-Forwarded-Proto": "https"}), ) @@ -827,7 +828,7 @@ def test_bootstrap_ws_url_uses_forwarded_https_host(bus: MagicMock) -> None: def test_localhost_without_auth_is_valid(bus: MagicMock) -> None: channel = _ch(bus, host="127.0.0.1") - resp = channel._handle_bootstrap(_LOCAL, _NO_HEADERS) + resp = channel.gateway.http._handle_bootstrap(_LOCAL, _NO_HEADERS) assert resp.status_code == 200 @@ -837,7 +838,7 @@ def test_bootstrap_prefers_runtime_model_name(bus: MagicMock, monkeypatch: pytes lambda: "from-disk", ) channel = _ch(bus, host="127.0.0.1", runtime_model_name=lambda: " live/model ") - resp = channel._handle_bootstrap(_LOCAL, _NO_HEADERS) + resp = channel.gateway.http._handle_bootstrap(_LOCAL, _NO_HEADERS) assert resp.status_code == 200 body = json.loads(resp.body) assert body["model_name"] == "live/model" @@ -849,7 +850,7 @@ def test_bootstrap_falls_back_when_runtime_returns_empty(bus: MagicMock, monkeyp lambda: "from-disk", ) channel = _ch(bus, host="127.0.0.1", runtime_model_name=lambda: " ") - resp = channel._handle_bootstrap(_LOCAL, _NO_HEADERS) + resp = channel.gateway.http._handle_bootstrap(_LOCAL, _NO_HEADERS) assert resp.status_code == 200 body = json.loads(resp.body) assert body["model_name"] == "from-disk" @@ -865,7 +866,7 @@ def test_bootstrap_falls_back_when_runtime_raises(bus: MagicMock, monkeypatch: p raise RuntimeError("resolver failed") channel = _ch(bus, host="127.0.0.1", runtime_model_name=boom) - resp = channel._handle_bootstrap(_LOCAL, _NO_HEADERS) + resp = channel.gateway.http._handle_bootstrap(_LOCAL, _NO_HEADERS) assert resp.status_code == 200 body = json.loads(resp.body) assert body["model_name"] == "from-disk" @@ -873,7 +874,7 @@ def test_bootstrap_falls_back_when_runtime_raises(bus: MagicMock, monkeypatch: p def test_bootstrap_rejects_wrong_secret(bus: MagicMock) -> None: channel = _ch(bus, host="0.0.0.0", tokenIssueSecret="correct") - resp = channel._handle_bootstrap( + resp = channel.gateway.http._handle_bootstrap( _REMOTE, _FakeReq({"Authorization": "Bearer wrong"}) ) assert resp.status_code == 401 @@ -881,7 +882,7 @@ def test_bootstrap_rejects_wrong_secret(bus: MagicMock) -> None: def test_bootstrap_accepts_remote_with_valid_secret(bus: MagicMock) -> None: channel = _ch(bus, host="0.0.0.0", tokenIssueSecret="s3cret") - resp = channel._handle_bootstrap( + resp = channel.gateway.http._handle_bootstrap( _REMOTE, _FakeReq({"Authorization": "Bearer s3cret"}) ) assert resp.status_code == 200 @@ -891,7 +892,7 @@ def test_bootstrap_accepts_remote_with_valid_secret(bus: MagicMock) -> None: def test_bootstrap_accepts_x_nanobot_auth_header(bus: MagicMock) -> None: channel = _ch(bus, host="0.0.0.0", tokenIssueSecret="s3cret") - resp = channel._handle_bootstrap( + resp = channel.gateway.http._handle_bootstrap( _REMOTE, _FakeReq({"X-Nanobot-Auth": "s3cret"}) ) assert resp.status_code == 200 @@ -900,5 +901,5 @@ def test_bootstrap_accepts_x_nanobot_auth_header(bus: MagicMock) -> None: def test_bootstrap_secret_also_enforced_on_localhost(bus: MagicMock) -> None: """When secret is set, even localhost must provide it (reverse-proxy safety).""" channel = _ch(bus, host="0.0.0.0", tokenIssueSecret="s3cret") - resp = channel._handle_bootstrap(_LOCAL, _NO_HEADERS) + resp = channel.gateway.http._handle_bootstrap(_LOCAL, _NO_HEADERS) assert resp.status_code == 401 diff --git a/tests/channels/test_websocket_integration.py b/tests/channels/test_websocket_integration.py index cb0bf7606..24bf9f4c4 100644 --- a/tests/channels/test_websocket_integration.py +++ b/tests/channels/test_websocket_integration.py @@ -7,19 +7,18 @@ multi-client scenarios, edge cases, and realistic usage patterns. from __future__ import annotations import asyncio -import json from pathlib import Path from typing import Any from unittest.mock import AsyncMock, MagicMock import pytest import websockets - -from nanobot.channels.websocket import WebSocketChannel, WebSocketConfig -from nanobot.bus.events import OutboundMessage -from nanobot.webui.ws_http import GatewayHTTPHandler from ws_test_client import WsTestClient, issue_token, issue_token_ok +from nanobot.bus.events import OutboundMessage +from nanobot.channels.websocket import WebSocketChannel, WebSocketConfig +from nanobot.webui.gateway_services import build_gateway_services + def _ch(bus: Any, port: int, **kw: Any) -> WebSocketChannel: cfg: dict[str, Any] = { @@ -32,17 +31,18 @@ def _ch(bus: Any, port: int, **kw: Any) -> WebSocketChannel: } cfg.update(kw) parsed = WebSocketConfig.model_validate(cfg) - handler = GatewayHTTPHandler( + gateway = build_gateway_services( config=parsed, + bus=bus, session_manager=None, static_dist_path=None, workspace_path=Path.cwd(), + default_restrict_to_workspace=False, runtime_model_name=None, runtime_surface="browser", runtime_capabilities_overrides=None, - bus=bus, ) - return WebSocketChannel(cfg, bus, http_handler=handler) + return WebSocketChannel(cfg, bus, gateway=gateway) @pytest.fixture() @@ -67,7 +67,8 @@ async def test_ready_event_fields(bus: MagicMock) -> None: assert len(r.chat_id) == 36 assert r.client_id == "c1" finally: - await ch.stop(); await t + await ch.stop() + await t @pytest.mark.asyncio @@ -80,7 +81,8 @@ async def test_anonymous_client_gets_generated_id(bus: MagicMock) -> None: r = await c.recv_ready() assert r.client_id.startswith("anon-") finally: - await ch.stop(); await t + await ch.stop() + await t @pytest.mark.asyncio @@ -93,7 +95,8 @@ async def test_each_connection_unique_chat_id(bus: MagicMock) -> None: async with WsTestClient("ws://127.0.0.1:29903/", client_id="b") as c2: assert (await c1.recv_ready()).chat_id != (await c2.recv_ready()).chat_id finally: - await ch.stop(); await t + await ch.stop() + await t # -- Inbound messages (client -> server) ---------------------------------- @@ -113,7 +116,8 @@ async def test_plain_text(bus: MagicMock) -> None: assert inbound.content == "hello world" assert inbound.sender_id == "p" finally: - await ch.stop(); await t + await ch.stop() + await t @pytest.mark.asyncio @@ -128,7 +132,8 @@ async def test_json_content_field(bus: MagicMock) -> None: await asyncio.sleep(0.1) assert bus.publish_inbound.call_args[0][0].content == "structured" finally: - await ch.stop(); await t + await ch.stop() + await t @pytest.mark.asyncio @@ -146,7 +151,8 @@ async def test_json_text_and_message_fields(bus: MagicMock) -> None: await asyncio.sleep(0.1) assert bus.publish_inbound.call_args[0][0].content == "via message" finally: - await ch.stop(); await t + await ch.stop() + await t @pytest.mark.asyncio @@ -162,7 +168,8 @@ async def test_empty_payload_ignored(bus: MagicMock) -> None: await asyncio.sleep(0.1) bus.publish_inbound.assert_not_awaited() finally: - await ch.stop(); await t + await ch.stop() + await t @pytest.mark.asyncio @@ -179,7 +186,8 @@ async def test_messages_preserve_order(bus: MagicMock) -> None: contents = [call[0][0].content for call in bus.publish_inbound.call_args_list] assert contents == [f"msg-{i}" for i in range(5)] finally: - await ch.stop(); await t + await ch.stop() + await t # -- Outbound messages (server -> client) --------------------------------- @@ -199,7 +207,8 @@ async def test_server_send_message(bus: MagicMock) -> None: msg = await c.recv_message() assert msg.text == "reply" finally: - await ch.stop(); await t + await ch.stop() + await t @pytest.mark.asyncio @@ -238,7 +247,8 @@ async def test_server_send_tags_tool_hint_with_kind(bus: MagicMock) -> None: prog = await c.recv_message() assert prog.raw.get("kind") == "progress" finally: - await ch.stop(); await t + await ch.stop() + await t @pytest.mark.asyncio @@ -258,7 +268,8 @@ async def test_server_send_with_media_and_reply(bus: MagicMock) -> None: assert msg.media == ["/tmp/a.png"] assert msg.reply_to == "m1" finally: - await ch.stop(); await t + await ch.stop() + await t # -- Streaming ------------------------------------------------------------ @@ -282,7 +293,8 @@ async def test_streaming_deltas_and_end(bus: MagicMock) -> None: ends = [m for m in msgs if m.event == "stream_end"] assert len(ends) == 1 finally: - await ch.stop(); await t + await ch.stop() + await t @pytest.mark.asyncio @@ -306,7 +318,8 @@ async def test_interleaved_streams(bus: MagicMock) -> None: assert sa == "A1A2" assert sb == "B1B2" finally: - await ch.stop(); await t + await ch.stop() + await t # -- Multi-client --------------------------------------------------------- @@ -330,7 +343,8 @@ async def test_independent_sessions(bus: MagicMock) -> None: )) assert (await c2.recv_message()).text == "for-u2" finally: - await ch.stop(); await t + await ch.stop() + await t @pytest.mark.asyncio @@ -348,7 +362,8 @@ async def test_disconnected_client_cleanup(bus: MagicMock) -> None: )) assert chat_id not in ch._subs finally: - await ch.stop(); await t + await ch.stop() + await t # -- Authentication ------------------------------------------------------- @@ -363,7 +378,8 @@ async def test_static_token_accepted(bus: MagicMock) -> None: async with WsTestClient("ws://127.0.0.1:29915/", client_id="a", token="secret") as c: assert (await c.recv_ready()).client_id == "a" finally: - await ch.stop(); await t + await ch.stop() + await t @pytest.mark.asyncio @@ -377,7 +393,8 @@ async def test_static_token_rejected(bus: MagicMock) -> None: pass assert exc.value.response.status_code == 401 finally: - await ch.stop(); await t + await ch.stop() + await t @pytest.mark.asyncio @@ -411,7 +428,8 @@ async def test_token_issue_full_flow(bus: MagicMock) -> None: pass assert exc.value.response.status_code == 401 finally: - await ch.stop(); await t + await ch.stop() + await t # -- Path routing --------------------------------------------------------- @@ -426,7 +444,8 @@ async def test_custom_path(bus: MagicMock) -> None: async with WsTestClient("ws://127.0.0.1:29918/my-chat", client_id="p") as c: assert (await c.recv_ready()).event == "ready" finally: - await ch.stop(); await t + await ch.stop() + await t @pytest.mark.asyncio @@ -440,7 +459,8 @@ async def test_wrong_path_404(bus: MagicMock) -> None: pass assert exc.value.response.status_code == 404 finally: - await ch.stop(); await t + await ch.stop() + await t @pytest.mark.asyncio @@ -452,7 +472,8 @@ async def test_trailing_slash_normalized(bus: MagicMock) -> None: async with WsTestClient("ws://127.0.0.1:29920/ws/", client_id="s") as c: assert (await c.recv_ready()).event == "ready" finally: - await ch.stop(); await t + await ch.stop() + await t # -- Edge cases ----------------------------------------------------------- @@ -471,7 +492,8 @@ async def test_large_message(bus: MagicMock) -> None: await asyncio.sleep(0.2) assert bus.publish_inbound.call_args[0][0].content == big finally: - await ch.stop(); await t + await ch.stop() + await t @pytest.mark.asyncio @@ -491,7 +513,8 @@ async def test_unicode_roundtrip(bus: MagicMock) -> None: )) assert (await c.recv_message()).text == text finally: - await ch.stop(); await t + await ch.stop() + await t @pytest.mark.asyncio @@ -513,7 +536,8 @@ async def test_rapid_fire(bus: MagicMock) -> None: received = [(await c.recv_message()).text for _ in range(50)] assert received == [f"out-{i}" for i in range(50)] finally: - await ch.stop(); await t + await ch.stop() + await t @pytest.mark.asyncio @@ -528,4 +552,5 @@ async def test_invalid_json_as_plain_text(bus: MagicMock) -> None: await asyncio.sleep(0.1) assert bus.publish_inbound.call_args[0][0].content == "{broken json" finally: - await ch.stop(); await t + await ch.stop() + await t diff --git a/tests/channels/test_websocket_media_route.py b/tests/channels/test_websocket_media_route.py index 71bca63b7..d539dd914 100644 --- a/tests/channels/test_websocket_media_route.py +++ b/tests/channels/test_websocket_media_route.py @@ -2,8 +2,8 @@ integration on ``/api/sessions//messages``. The route is the return path for images attached to persisted user turns: -:meth:`WebSocketChannel._sign_media_path` mints URLs during session reads, -and :meth:`WebSocketChannel._handle_media_fetch` serves the bytes back. +:meth:`WebSocketChannel.gateway.media.sign_media_path` mints URLs during session reads, +and :meth:`GatewayHTTPHandler._handle_media_fetch` serves the bytes back. These tests cover the two halves end-to-end plus the adversarial edges (bad signatures, ``..`` traversal, non-existent files, non-image types). """ @@ -22,13 +22,12 @@ import httpx import pytest from nanobot.channels.websocket import WebSocketChannel, WebSocketConfig +from nanobot.session.manager import Session, SessionManager +from nanobot.webui.gateway_services import build_gateway_services from nanobot.webui.media_api import ( b64url_decode, b64url_encode, ) -from nanobot.session.manager import Session, SessionManager -from nanobot.webui.ws_http import GatewayHTTPHandler - # PNG magic bytes + a couple of sentinel bytes so we can verify byte-for-byte # round-trip of the served payload. Stays under mimetype + size limits. @@ -57,17 +56,18 @@ def _ch( "websocketRequiresToken": False, } parsed = WebSocketConfig.model_validate(cfg) - http_handler = GatewayHTTPHandler( + gateway = build_gateway_services( config=parsed, + bus=bus, session_manager=session_manager, static_dist_path=None, workspace_path=workspace_path or Path.cwd(), + default_restrict_to_workspace=False, runtime_model_name=None, runtime_surface="browser", runtime_capabilities_overrides=None, - bus=bus, ) - return WebSocketChannel(cfg, bus, http_handler=http_handler) + return WebSocketChannel(cfg, bus, gateway=gateway) @pytest.fixture() @@ -95,7 +95,7 @@ async def _http_get( # --------------------------------------------------------------------------- -# _sign_media_path: the URL minter +# gateway.media.sign_media_path: the URL minter # --------------------------------------------------------------------------- @@ -114,11 +114,11 @@ def test_sign_media_path_rejects_paths_outside_media_root( media = tmp_path / "media" media.mkdir() channel = _ch(bus, port=0) - with patch("nanobot.webui.ws_http.get_media_dir", return_value=media): - assert channel._sign_media_path(outside) is None + with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media): + assert channel.gateway.media.sign_media_path(outside) is None # Traversal via the media root is also rejected — the resolve() step # normalises ``..`` out before the relative_to check. - assert channel._sign_media_path(media / ".." / "secrets" / "cred.txt") is None + assert channel.gateway.media.sign_media_path(media / ".." / "secrets" / "cred.txt") is None def test_sign_media_path_round_trips_via_hmac( @@ -129,13 +129,13 @@ def test_sign_media_path_round_trips_via_hmac( media.mkdir() (media / "a.png").write_bytes(_PNG_BYTES) channel = _ch(bus, port=0) - with patch("nanobot.webui.ws_http.get_media_dir", return_value=media): - url = channel._sign_media_path(media / "a.png") + with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media): + url = channel.gateway.media.sign_media_path(media / "a.png") assert url is not None assert url.startswith("/api/media/") sig, payload = url[len("/api/media/"):].split("/", 1) expected = hmac.new( - channel._media_secret, payload.encode("ascii"), hashlib.sha256 + channel.gateway.media.secret, payload.encode("ascii"), hashlib.sha256 ).digest()[:16] assert b64url_decode(sig) == expected # The payload decodes back to the *relative* path — no absolute-path leaks. @@ -152,8 +152,8 @@ def test_local_markdown_image_is_staged_and_rewritten( media = tmp_path / "media" channel = _ch(bus, workspace_path=workspace, port=0) - with patch("nanobot.webui.ws_http.get_media_dir", side_effect=_fake_media_dir(media)): - rewritten = channel._rewrite_local_markdown_images( + with patch("nanobot.webui.media_gateway.get_media_dir", side_effect=_fake_media_dir(media)): + rewritten = channel.gateway.media.rewrite_local_markdown_images( "The result:\n![Cloud Architecture Diagram](demo_arch.png)" ) @@ -174,8 +174,8 @@ def test_local_markdown_video_is_staged_and_rewritten( media = tmp_path / "media" channel = _ch(bus, workspace_path=workspace, port=0) - with patch("nanobot.webui.ws_http.get_media_dir", side_effect=_fake_media_dir(media)): - rewritten = channel._rewrite_local_markdown_images( + with patch("nanobot.webui.media_gateway.get_media_dir", side_effect=_fake_media_dir(media)): + rewritten = channel.gateway.media.rewrite_local_markdown_images( "The result:\n![nanobot-intro.mp4](nanobot-intro.mp4)" ) @@ -197,8 +197,8 @@ def test_local_markdown_image_rejects_workspace_escape( channel = _ch(bus, workspace_path=workspace, port=0) text = "![nope](../outside.png)" - with patch("nanobot.webui.ws_http.get_media_dir", side_effect=_fake_media_dir(media)): - assert channel._rewrite_local_markdown_images(text) == text + with patch("nanobot.webui.media_gateway.get_media_dir", side_effect=_fake_media_dir(media)): + assert channel.gateway.media.rewrite_local_markdown_images(text) == text assert not (media / "websocket").exists() @@ -219,8 +219,8 @@ async def test_media_route_serves_signed_file( target.write_bytes(_PNG_BYTES) channel = _ch(bus, port=29920) - with patch("nanobot.webui.ws_http.get_media_dir", return_value=media): - url_path = channel._sign_media_path(target) + with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media): + url_path = channel.gateway.media.sign_media_path(target) assert url_path is not None server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) @@ -252,8 +252,8 @@ async def test_media_route_serves_video_byte_ranges( target.write_bytes(b"0123456789") channel = _ch(bus, port=29927) - with patch("nanobot.webui.ws_http.get_media_dir", return_value=media): - url_path = channel._sign_media_path(target) + with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media): + url_path = channel.gateway.media.sign_media_path(target) assert url_path is not None server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) @@ -284,8 +284,8 @@ async def test_media_route_serves_suffix_video_byte_ranges( target.write_bytes(b"0123456789") channel = _ch(bus, port=29928) - with patch("nanobot.webui.ws_http.get_media_dir", return_value=media): - url_path = channel._sign_media_path(target) + with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media): + url_path = channel.gateway.media.sign_media_path(target) assert url_path is not None server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) @@ -313,8 +313,8 @@ async def test_media_route_rejects_unsatisfiable_byte_range( target.write_bytes(b"0123456789") channel = _ch(bus, port=29929) - with patch("nanobot.webui.ws_http.get_media_dir", return_value=media): - url_path = channel._sign_media_path(target) + with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media): + url_path = channel.gateway.media.sign_media_path(target) assert url_path is not None server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) @@ -339,15 +339,15 @@ async def test_media_route_rejects_bad_signature( """A payload re-signed with a different secret must 401. Protects against a restart: old URLs baked into a stale tab become - un-forgeable once ``_media_secret`` regenerates. + un-forgeable once ``gateway.media.secret`` regenerates. """ media = tmp_path / "media" media.mkdir() (media / "f.png").write_bytes(_PNG_BYTES) channel = _ch(bus, port=29921) - with patch("nanobot.webui.ws_http.get_media_dir", return_value=media): - good = channel._sign_media_path(media / "f.png") + with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media): + good = channel.gateway.media.sign_media_path(media / "f.png") assert good is not None _, payload = good[len("/api/media/"):].split("/", 1) # Forge a sig with a *different* secret. @@ -385,11 +385,11 @@ async def test_media_route_rejects_path_traversal_payload( # Hand-craft a traversal payload the legit signer would refuse to mint. payload = b64url_encode(b"../secret.txt") mac = hmac.new( - channel._media_secret, payload.encode("ascii"), hashlib.sha256 + channel.gateway.media.secret, payload.encode("ascii"), hashlib.sha256 ).digest()[:16] url = f"/api/media/{b64url_encode(mac)}/{payload}" - with patch("nanobot.webui.ws_http.get_media_dir", return_value=media): + with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media): server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) try: @@ -413,8 +413,8 @@ async def test_media_route_404s_missing_file( target.write_bytes(_PNG_BYTES) channel = _ch(bus, port=29923) - with patch("nanobot.webui.ws_http.get_media_dir", return_value=media): - url_path = channel._sign_media_path(target) + with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media): + url_path = channel.gateway.media.sign_media_path(target) assert url_path is not None target.unlink() # the file vanishes between signing and fetching server_task = asyncio.create_task(channel.start()) @@ -441,10 +441,10 @@ async def test_media_route_degrades_non_image_to_octet_stream( (media / "scary.html").write_bytes(b"") channel = _ch(bus, port=29924) - with patch("nanobot.webui.ws_http.get_media_dir", return_value=media): + with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media): payload = b64url_encode(b"scary.html") mac = hmac.new( - channel._media_secret, payload.encode("ascii"), hashlib.sha256 + channel.gateway.media.secret, payload.encode("ascii"), hashlib.sha256 ).digest()[:16] url = f"/api/media/{b64url_encode(mac)}/{payload}" server_task = asyncio.create_task(channel.start()) @@ -472,8 +472,8 @@ async def test_media_route_serves_svg_with_strict_csp( target.write_text("") channel = _ch(bus, port=29928) - with patch("nanobot.webui.ws_http.get_media_dir", return_value=media): - url_path = channel._sign_media_path(target) + with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media): + url_path = channel.gateway.media.sign_media_path(target) assert url_path is not None server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) @@ -513,7 +513,7 @@ async def test_session_messages_exposes_signed_media_urls( sm.save(sess) channel = _ch(bus, session_manager=sm, port=29925) - with patch("nanobot.webui.ws_http.get_media_dir", return_value=media): + with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media): server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) try: @@ -558,7 +558,7 @@ async def test_session_messages_skips_vanished_media( sm.save(sess) channel = _ch(bus, session_manager=sm, port=29926) - with patch("nanobot.webui.ws_http.get_media_dir", return_value=media): + with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media): server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) try: