refactor: split WebUI gateway dependencies

Maintainer edit for PR 4115: rebase onto origin/main and split gateway HTTP routing from token, media, and workspace services so WebSocketChannel depends on explicit gateway services instead of GatewayHTTPHandler internals.

Preserve file edit channel capabilities and restore tools.restrict_to_workspace wiring through ChannelManager.
This commit is contained in:
chengyongru 2026-06-02 14:49:06 +08:00 committed by Xubin Ren
parent 2420826e05
commit 2a98360105
14 changed files with 753 additions and 779 deletions

View File

@ -16,7 +16,6 @@ from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.schema import Config
from nanobot.utils.restart import consume_restart_notice_from_env, format_restart_completed_message
from nanobot.webui.ws_http import GatewayHTTPHandler
if TYPE_CHECKING:
from nanobot.session.manager import SessionManager
@ -113,21 +112,24 @@ class ChannelManager:
kwargs: dict[str, Any] = {}
if cls.name == "websocket":
from nanobot.channels.websocket import WebSocketConfig
from nanobot.webui.gateway_services import build_gateway_services
parsed = WebSocketConfig.model_validate(section)
static_path = _default_webui_dist() if self._webui_static_dist else None
workspace = Path(self.config.workspace_path)
http_handler = GatewayHTTPHandler(
gateway = build_gateway_services(
config=parsed,
bus=self.bus,
session_manager=self._session_manager,
static_dist_path=static_path,
workspace_path=workspace,
default_restrict_to_workspace=self.config.tools.restrict_to_workspace,
runtime_model_name=self._webui_runtime_model_name,
runtime_surface=self._webui_runtime_surface,
runtime_capabilities_overrides=self._webui_runtime_capabilities,
bus=self.bus,
logger=logger,
)
kwargs["http_handler"] = http_handler
kwargs["gateway"] = gateway
channel = cls(section, self.bus, **kwargs)
channel.transcription_provider = transcription_provider
channel.transcription_api_key = transcription_key

View File

@ -3,27 +3,20 @@
from __future__ import annotations
import asyncio
import email.utils
import hmac
import http
import json
import re
import ssl
import uuid
from collections.abc import Callable
from contextlib import suppress
from functools import partial
from pathlib import Path
from typing import Any, Self
from urllib.parse import parse_qs, unquote, urlparse
from loguru import logger
from pydantic import Field, field_validator, model_validator
from websockets.asyncio.server import ServerConnection, serve, unix_serve
from websockets.datastructures import Headers
from websockets.exceptions import ConnectionClosed
from websockets.http11 import Request as WsRequest
from websockets.http11 import Response
from nanobot.bus.events import OUTBOUND_META_AGENT_UI, OutboundMessage
from nanobot.bus.queue import MessageBus
@ -41,53 +34,22 @@ from nanobot.utils.media_decode import (
save_base64_data_url,
)
from nanobot.webui.cli_apps_api import normalize_cli_app_mentions
from nanobot.webui.gateway_services import GatewayServices
from nanobot.webui.http_utils import (
is_localhost as _is_localhost,
)
from nanobot.webui.http_utils import (
normalize_config_path as _normalize_config_path,
)
from nanobot.webui.http_utils import (
parse_request_path as _parse_request_path,
)
from nanobot.webui.http_utils import (
query_first as _query_first,
)
from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions
from nanobot.webui.transcript import append_transcript_object
def _strip_trailing_slash(path: str) -> str:
if len(path) > 1 and path.endswith("/"):
return path.rstrip("/")
return path or "/"
def _normalize_config_path(path: str) -> str:
return _strip_trailing_slash(path)
def _case_insensitive_header(headers: Any, key: str) -> str:
"""Read a header from websockets/http test stubs without assuming casing."""
try:
value = headers.get(key)
except Exception:
value = None
if value is None:
try:
value = headers.get(key.lower())
except Exception:
value = None
return str(value or "").strip()
def _safe_host_header(value: str) -> str:
"""Return a safe Host header value, or empty when it should not be echoed."""
value = value.strip()
if not value:
return ""
if re.fullmatch(r"\[[0-9A-Fa-f:.]+\](?::\d{1,5})?", value):
return value
if re.fullmatch(r"[A-Za-z0-9.-]+(?::\d{1,5})?", value):
return value
return ""
def _host_for_url(host: str, port: int) -> str:
host = host.strip()
if host in ("0.0.0.0", "::"):
host = "127.0.0.1"
if ":" in host and not host.startswith("["):
host = f"[{host}]"
return f"{host}:{port}"
from nanobot.webui.websocket_logging import websockets_server_logger
class WebSocketConfig(Base):
@ -182,20 +144,6 @@ class WebSocketConfig(Base):
)
def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
body = json.dumps(data, ensure_ascii=False).encode("utf-8")
headers = Headers(
[
("Date", email.utils.formatdate(usegmt=True)),
("Connection", "close"),
("Content-Length", str(len(body))),
("Content-Type", "application/json; charset=utf-8"),
]
)
reason = http.HTTPStatus(status).phrase
return Response(status, reason, headers, body)
def publish_runtime_model_update(
bus: MessageBus,
model: str,
@ -214,57 +162,6 @@ def publish_runtime_model_update(
))
def _default_model_name_from_config() -> str | None:
"""Resolved model string from on-disk config (bootstrap fallback)."""
try:
from nanobot.config.loader import load_config
model = load_config().resolve_preset().model.strip()
return model or None
except Exception as e:
logger.debug("bootstrap model_name could not load from config: {}", e)
return None
def _resolve_bootstrap_model_name(
runtime_name: Callable[[], str | None] | None,
) -> str | None:
"""Prefer an in-process resolver (e.g. AgentLoop); else config-derived default."""
if runtime_name is not None:
try:
raw = runtime_name()
except Exception as e:
logger.debug("bootstrap runtime model resolver failed: {}", e)
else:
if isinstance(raw, str):
stripped = raw.strip()
if stripped:
return stripped
return _default_model_name_from_config()
def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]:
"""Parse normalized path and query parameters in one pass."""
parsed = urlparse("ws://x" + path_with_query)
path = _strip_trailing_slash(parsed.path or "/")
return path, parse_qs(parsed.query, keep_blank_values=True)
def _normalize_http_path(path_with_query: str) -> str:
"""Return the path component (no query string), with trailing slash normalized (root stays ``/``)."""
return _parse_request_path(path_with_query)[0]
def _parse_query(path_with_query: str) -> dict[str, list[str]]:
return _parse_request_path(path_with_query)[1]
def _query_first(query: dict[str, list[str]], key: str) -> str | None:
"""Return the first value for *key*, or None."""
values = query.get(key)
return values[0] if values else None
def _parse_inbound_payload(raw: str) -> str | None:
"""Parse a client frame into text; return None for empty or unrecognized content."""
text = raw.strip()
@ -355,67 +252,6 @@ def _extract_data_url_mime(url: str) -> str | None:
return m.group(1).strip().lower() or None
_LOCALHOSTS = frozenset({"127.0.0.1", "::1", "localhost"})
# Matches the legacy chat-id pattern but allows file-system-safe stems too,
# so the API can address sessions whose keys came from non-WebSocket channels.
_API_KEY_RE = re.compile(r"^[A-Za-z0-9_:.-]{1,128}$")
def _decode_api_key(raw_key: str) -> str | None:
"""Decode a percent-encoded API path segment, then validate the result."""
key = unquote(raw_key)
if _API_KEY_RE.match(key) is None:
return None
return key
def _is_localhost(connection: Any) -> bool:
"""Return True if *connection* originated from the loopback interface."""
addr = getattr(connection, "remote_address", None)
if not addr:
return False
host = addr[0] if isinstance(addr, tuple) else addr
if not isinstance(host, str):
return False
# ``::ffff:127.0.0.1`` is loopback in IPv6-mapped form.
if host.startswith("::ffff:"):
host = host[7:]
return host in _LOCALHOSTS
def _http_response(
body: bytes,
*,
status: int = 200,
content_type: str = "text/plain; charset=utf-8",
extra_headers: list[tuple[str, str]] | None = None,
) -> Response:
headers = [
("Date", email.utils.formatdate(usegmt=True)),
("Connection", "close"),
("Content-Length", str(len(body))),
("Content-Type", content_type),
]
if extra_headers:
headers.extend(extra_headers)
reason = http.HTTPStatus(status).phrase
return Response(status, reason, Headers(headers), body)
def _http_error(status: int, message: str | None = None) -> Response:
body = (message or http.HTTPStatus(status).phrase).encode("utf-8")
return _http_response(body, status=status)
def _bearer_token(headers: Any) -> str | None:
"""Pull a Bearer token out of standard or query-style headers."""
auth = headers.get("Authorization") or headers.get("authorization")
if auth and auth.lower().startswith("bearer "):
return auth[7:].strip() or None
return None
def _is_websocket_upgrade(request: WsRequest) -> bool:
"""Detect an actual WS upgrade; plain HTTP GETs to the same path should fall through."""
upgrade = request.headers.get("Upgrade") or request.headers.get("upgrade")
@ -427,20 +263,6 @@ def _is_websocket_upgrade(request: WsRequest) -> bool:
return True
def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool:
"""Return True if the token-issue HTTP request carries credentials matching ``token_issue_secret``."""
if not configured_secret:
return True
authorization = headers.get("Authorization") or headers.get("authorization")
if authorization and authorization.lower().startswith("bearer "):
supplied = authorization[7:].strip()
return hmac.compare_digest(supplied, configured_secret)
header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth")
if not header_token:
return False
return hmac.compare_digest(header_token.strip(), configured_secret)
class WebSocketChannel(BaseChannel):
"""Run a local WebSocket server; forward text/JSON messages to the message bus."""
@ -452,7 +274,7 @@ class WebSocketChannel(BaseChannel):
config: Any,
bus: MessageBus,
*,
http_handler: Any | None = None,
gateway: GatewayServices,
):
if isinstance(config, dict):
config = WebSocketConfig.model_validate(config)
@ -467,11 +289,11 @@ class WebSocketChannel(BaseChannel):
self._stop_event: asyncio.Event | None = None
self._server_task: asyncio.Task[None] | None = None
# HTTP handler injected from outside (ChannelManager / gateway startup).
# Owns tokens, sessions, media, settings, static serving.
self._http = http_handler
# Backwards-compat: workspace controller used in envelope dispatch
self._webui_workspaces = http_handler.workspaces if http_handler else None
self.gateway = gateway
self._http_router = gateway.http
self._tokens = gateway.tokens
self._media = gateway.media
self._workspaces = gateway.workspaces
self._stream_text_buffers: dict[tuple[str, str], list[str]] = {}
@ -501,9 +323,9 @@ class WebSocketChannel(BaseChannel):
connected clients normally see it via ``goal_state`` / ``turn_end`` frames.
Pushing here makes refresh + reconnect restore the strip without a new model turn.
"""
if self._http.session_manager is None:
if self.gateway.session_manager is None:
return
row = self._http.session_manager.read_session_file(f"websocket:{chat_id}")
row = self.gateway.session_manager.read_session_file(f"websocket:{chat_id}")
meta = row.get("metadata", {}) if isinstance(row, dict) else {}
if not isinstance(meta, dict):
meta = {}
@ -543,57 +365,6 @@ class WebSocketChannel(BaseChannel):
def _expected_path(self) -> str:
return _normalize_config_path(self.config.path)
# -- Backwards-compat property aliases (used by tests) ------------------
@property
def _session_manager(self):
return self._http.session_manager
@_session_manager.setter
def _session_manager(self, value):
self._http.session_manager = value
@property
def _media_secret(self):
return self._http.media_secret
@property
def _issued_tokens(self):
return self._http.issued_tokens
@_issued_tokens.setter
def _issued_tokens(self, value):
self._http.issued_tokens = value
@property
def _api_tokens(self):
return self._http.api_tokens
def _check_api_token(self, request):
return self._http.check_api_token(request)
def _sign_media_path(self, path):
return self._http.sign_media_path(path)
def _sign_or_stage_media_path(self, path):
return self._http.sign_or_stage_media_path(path)
def _rewrite_local_markdown_images(self, text):
return self._http.rewrite_local_markdown_images(text)
def _handle_bootstrap(self, connection, request):
return self._http._handle_bootstrap(connection, request)
def _handle_sessions_list(self, request):
return self._http._handle_sessions_list(request)
def _handle_webui_thread_get(self, request, key):
return self._http._handle_webui_thread_get(request, key)
@property
def _workspace_path(self):
return self._http.workspace_path
def _build_ssl_context(self) -> ssl.SSLContext | None:
cert = self.config.ssl_certfile.strip()
key = self.config.ssl_keyfile.strip()
@ -624,8 +395,8 @@ class WebSocketChannel(BaseChannel):
return connection.respond(403, "Forbidden")
return self._authorize_websocket_handshake(connection, query)
<< # Everything else goes to the HTTP handler
return await self._http.dispatch(connection, request)
# Everything else goes to the HTTP handler
return await self._http_router.dispatch(connection, request)
def _authorize_websocket_handshake(self, connection: Any, query: dict[str, list[str]]) -> Any:
supplied = _query_first(query, "token")
@ -634,17 +405,17 @@ class WebSocketChannel(BaseChannel):
if static_token:
if supplied and hmac.compare_digest(supplied, static_token):
return None
if supplied and self._http.take_issued_token_if_valid(supplied):
if supplied and self._tokens.take_issued_token_if_valid(supplied):
return None
return connection.respond(401, "Unauthorized")
if self.config.websocket_requires_token:
if supplied and self._http.take_issued_token_if_valid(supplied):
if supplied and self._tokens.take_issued_token_if_valid(supplied):
return None
return connection.respond(401, "Unauthorized")
if supplied:
self._http.take_issued_token_if_valid(supplied)
self._tokens.take_issued_token_if_valid(supplied)
return None
# -- Server lifecycle and connection ingress ---------------------------
@ -878,14 +649,14 @@ class WebSocketChannel(BaseChannel):
new_id = str(uuid.uuid4())
scope = await self._workspace_scope_or_error(
connection,
lambda: self._webui_workspaces.scope_for_new_chat(
lambda: self._workspaces.scope_for_new_chat(
envelope,
controls_available=_is_localhost(connection),
),
)
if scope is None:
return
self._webui_workspaces.persist_scope(new_id, scope)
self._workspaces.persist_scope(new_id, scope)
self._attach(connection, new_id)
await self._send_event(connection, "attached", chat_id=new_id)
await self._send_event(
@ -913,7 +684,7 @@ class WebSocketChannel(BaseChannel):
return
scope = await self._workspace_scope_or_error(
connection,
lambda: self._webui_workspaces.scope_for_set_request(
lambda: self._workspaces.scope_for_set_request(
envelope,
chat_id=cid,
chat_running=websocket_turn_wall_started_at(cid) is not None,
@ -923,7 +694,7 @@ class WebSocketChannel(BaseChannel):
)
if scope is None:
return
self._webui_workspaces.persist_scope(cid, scope)
self._workspaces.persist_scope(cid, scope)
await self._send_event(
connection,
"session_updated",
@ -965,7 +736,7 @@ class WebSocketChannel(BaseChannel):
return
scope = await self._workspace_scope_or_error(
connection,
lambda: self._webui_workspaces.scope_for_message(
lambda: self._workspaces.scope_for_message(
envelope,
chat_id=cid,
chat_running=websocket_turn_wall_started_at(cid) is not None,
@ -989,7 +760,7 @@ class WebSocketChannel(BaseChannel):
if mcp_presets:
metadata["mcp_presets"] = mcp_presets
metadata[WORKSPACE_SCOPE_METADATA_KEY] = scope.metadata()
self._webui_workspaces.persist_scope(cid, scope)
self._workspaces.persist_scope(cid, scope)
image_generation = envelope.get("image_generation")
if isinstance(image_generation, dict) and image_generation.get("enabled") is True:
aspect_ratio = image_generation.get("aspect_ratio")
@ -1044,8 +815,7 @@ class WebSocketChannel(BaseChannel):
self._subs.clear()
self._conn_chats.clear()
self._conn_default.clear()
self._http.issued_tokens.clear()
self._http.api_tokens.clear()
self._tokens.clear()
async def _safe_send_to(self, connection: Any, raw: str, *, label: str = "") -> None:
"""Send a raw frame to one connection, cleaning up on ConnectionClosed."""
@ -1127,7 +897,7 @@ class WebSocketChannel(BaseChannel):
)
return
text = msg.content
wire_text = self._http.rewrite_local_markdown_images(text)
wire_text = self._media.rewrite_local_markdown_images(text)
payload: dict[str, Any] = {
"event": "message",
"chat_id": msg.chat_id,
@ -1137,7 +907,7 @@ class WebSocketChannel(BaseChannel):
payload["media"] = msg.media
urls: list[dict[str, str]] = []
for entry in msg.media:
signed = self._http.sign_or_stage_media_path(Path(entry))
signed = self._media.sign_or_stage_media_path(Path(entry))
if signed is not None:
urls.append(signed)
if urls:
@ -1252,7 +1022,7 @@ class WebSocketChannel(BaseChannel):
if delta:
buffered.append(delta)
full_text = "".join(buffered)
rewritten = self._http.rewrite_local_markdown_images(full_text)
rewritten = self._media.rewrite_local_markdown_images(full_text)
if rewritten != full_text:
body["text"] = rewritten
else:

View File

@ -0,0 +1,70 @@
"""Composition helpers for the embedded WebUI gateway."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from loguru import logger as default_logger
from nanobot.webui.gateway_tokens import GatewayTokenStore
from nanobot.webui.media_gateway import WebUIMediaGateway
from nanobot.webui.workspaces import WebUIWorkspaceController
from nanobot.webui.ws_http import GatewayHTTPHandler
@dataclass(frozen=True)
class GatewayServices:
"""Explicit dependencies shared by WebSocket transport and HTTP routes."""
http: GatewayHTTPHandler
tokens: GatewayTokenStore
media: WebUIMediaGateway
workspaces: WebUIWorkspaceController
session_manager: Any | None
def build_gateway_services(
*,
config: Any,
bus: Any,
session_manager: Any | None,
static_dist_path: Path | None,
workspace_path: Path,
default_restrict_to_workspace: bool,
runtime_model_name: Any | None,
runtime_surface: str,
runtime_capabilities_overrides: dict[str, Any] | None,
logger: Any = default_logger,
) -> GatewayServices:
tokens = GatewayTokenStore()
media = WebUIMediaGateway(
workspace_path=workspace_path,
logger=logger,
)
workspaces = WebUIWorkspaceController(
session_manager=session_manager,
default_workspace=workspace_path,
default_restrict_to_workspace=default_restrict_to_workspace,
)
http = GatewayHTTPHandler(
config=config,
session_manager=session_manager,
static_dist_path=static_dist_path,
runtime_model_name=runtime_model_name,
runtime_surface=runtime_surface,
runtime_capabilities_overrides=runtime_capabilities_overrides,
bus=bus,
tokens=tokens,
media=media,
workspaces=workspaces,
log=logger,
)
return GatewayServices(
http=http,
tokens=tokens,
media=media,
workspaces=workspaces,
session_manager=session_manager,
)

View File

@ -0,0 +1,82 @@
"""Token state for the embedded WebUI gateway."""
from __future__ import annotations
import secrets
import time
from dataclasses import dataclass, field
from typing import Any
from websockets.http11 import Request as WsRequest
from nanobot.webui.http_utils import bearer_token, parse_query, query_first
@dataclass
class GatewayTokenStore:
"""Own short-lived WebSocket and WebUI API tokens for one gateway process."""
max_tokens: int = 10_000
issued_tokens: dict[str, float] = field(default_factory=dict)
api_tokens: dict[str, float] = field(default_factory=dict)
def check_api_token(self, request: WsRequest) -> bool:
self._purge_expired_api_tokens()
token = bearer_token(request.headers) or query_first(
parse_query(request.path), "token"
)
if not token:
return False
expiry = self.api_tokens.get(token)
if expiry is None or time.monotonic() > expiry:
self.api_tokens.pop(token, None)
return False
return True
def can_issue(self, *, include_api_token: bool = False) -> bool:
self._purge_expired_issued_tokens()
self._purge_expired_api_tokens()
if len(self.issued_tokens) >= self.max_tokens:
return False
if include_api_token and len(self.api_tokens) >= self.max_tokens:
return False
return True
def issue_token(self, ttl_s: int | float, *, api_token: bool = False) -> str:
token_value = f"nbwt_{secrets.token_urlsafe(32)}"
expiry = time.monotonic() + float(ttl_s)
self.issued_tokens[token_value] = expiry
if api_token:
self.api_tokens[token_value] = expiry
return token_value
def take_issued_token_if_valid(self, token_value: str | None) -> bool:
if not token_value:
return False
self._purge_expired_issued_tokens()
expiry = self.issued_tokens.pop(token_value, None)
if expiry is None:
return False
if time.monotonic() > expiry:
return False
return True
def clear(self) -> None:
self.issued_tokens.clear()
self.api_tokens.clear()
def _purge_expired_api_tokens(self) -> None:
now = time.monotonic()
for token_key, expiry in list(self.api_tokens.items()):
if now > expiry:
self.api_tokens.pop(token_key, None)
def _purge_expired_issued_tokens(self) -> None:
now = time.monotonic()
for token_key, expiry in list(self.issued_tokens.items()):
if now > expiry:
self.issued_tokens.pop(token_key, None)
def token_response_payload(token: str, expires_in: Any) -> dict[str, Any]:
return {"token": token, "expires_in": expires_in}

151
nanobot/webui/http_utils.py Normal file
View File

@ -0,0 +1,151 @@
"""Shared HTTP helpers for the embedded WebUI gateway."""
from __future__ import annotations
import email.utils
import hmac
import http
import json
import re
from typing import Any
from urllib.parse import parse_qs, urlparse
from websockets.datastructures import Headers
from websockets.http11 import Response
QueryParams = dict[str, list[str]]
def strip_trailing_slash(path: str) -> str:
if len(path) > 1 and path.endswith("/"):
return path.rstrip("/")
return path or "/"
def normalize_config_path(path: str) -> str:
return strip_trailing_slash(path)
def case_insensitive_header(headers: Any, key: str) -> str:
"""Read a header from websockets/http test stubs without assuming casing."""
try:
value = headers.get(key)
except Exception:
value = None
if value is None:
try:
value = headers.get(key.lower())
except Exception:
value = None
return str(value or "").strip()
def safe_host_header(value: str) -> str:
"""Return a safe Host header value, or empty when it should not be echoed."""
value = value.strip()
if not value:
return ""
if re.fullmatch(r"\[[0-9A-Fa-f:.]+\](?::\d{1,5})?", value):
return value
if re.fullmatch(r"[A-Za-z0-9.-]+(?::\d{1,5})?", value):
return value
return ""
def host_for_url(host: str, port: int) -> str:
host = host.strip()
if host in ("0.0.0.0", "::"):
host = "127.0.0.1"
if ":" in host and not host.startswith("["):
host = f"[{host}]"
return f"{host}:{port}"
def http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
body = json.dumps(data, ensure_ascii=False).encode("utf-8")
headers = Headers(
[
("Date", email.utils.formatdate(usegmt=True)),
("Connection", "close"),
("Content-Length", str(len(body))),
("Content-Type", "application/json; charset=utf-8"),
]
)
reason = http.HTTPStatus(status).phrase
return Response(status, reason, headers, body)
def http_response(
body: bytes,
*,
status: int = 200,
content_type: str = "text/plain; charset=utf-8",
extra_headers: list[tuple[str, str]] | None = None,
) -> Response:
headers = [
("Date", email.utils.formatdate(usegmt=True)),
("Connection", "close"),
("Content-Length", str(len(body))),
("Content-Type", content_type),
]
if extra_headers:
headers.extend(extra_headers)
reason = http.HTTPStatus(status).phrase
return Response(status, reason, Headers(headers), body)
def http_error(status: int, message: str | None = None) -> Response:
body = (message or http.HTTPStatus(status).phrase).encode("utf-8")
return http_response(body, status=status)
def parse_request_path(path_with_query: str) -> tuple[str, QueryParams]:
"""Parse normalized path and query parameters in one pass."""
parsed = urlparse("ws://x" + path_with_query)
path = strip_trailing_slash(parsed.path or "/")
return path, parse_qs(parsed.query, keep_blank_values=True)
def normalize_http_path(path_with_query: str) -> str:
return parse_request_path(path_with_query)[0]
def parse_query(path_with_query: str) -> QueryParams:
return parse_request_path(path_with_query)[1]
def query_first(query: QueryParams, key: str) -> str | None:
values = query.get(key)
return values[0] if values else None
def is_localhost(connection: Any) -> bool:
addr = getattr(connection, "remote_address", None)
if not addr:
return False
host = addr[0] if isinstance(addr, tuple) else addr
if not isinstance(host, str):
return False
if host.startswith("::ffff:"):
host = host[7:]
return host in {"127.0.0.1", "::1", "localhost"}
def bearer_token(headers: Any) -> str | None:
auth = headers.get("Authorization") or headers.get("authorization")
if auth and auth.lower().startswith("bearer "):
return auth[7:].strip() or None
return None
def issue_route_secret_matches(headers: Any, configured_secret: str) -> bool:
if not configured_secret:
return True
authorization = headers.get("Authorization") or headers.get("authorization")
if authorization and authorization.lower().startswith("bearer "):
supplied = authorization[7:].strip()
return hmac.compare_digest(supplied, configured_secret)
header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth")
if not header_token:
return False
return hmac.compare_digest(header_token.strip(), configured_secret)

View File

@ -4,10 +4,8 @@ from __future__ import annotations
import base64
import binascii
import email.utils
import hashlib
import hmac
import http
import mimetypes
import re
import shutil
@ -16,12 +14,20 @@ from collections.abc import Callable
from pathlib import Path
from typing import Any
from websockets.datastructures import Headers
from websockets.http11 import Request as WsRequest
from websockets.http11 import Response
from nanobot.config.paths import get_media_dir
from nanobot.utils.helpers import safe_filename
from nanobot.webui.http_utils import (
case_insensitive_header as _case_insensitive_header,
)
from nanobot.webui.http_utils import (
http_error as _http_error,
)
from nanobot.webui.http_utils import (
http_response as _http_response,
)
MediaDirProvider = Callable[[str | None], Path]
SignedMediaPath = Callable[[Path], dict[str, str] | None]
@ -67,43 +73,6 @@ _SVG_MEDIA_HEADERS: tuple[tuple[str, str], ...] = (
_BYTE_RANGE_RE = re.compile(r"^bytes=(\d*)-(\d*)$")
def _http_response(
body: bytes,
*,
status: int = 200,
content_type: str = "text/plain; charset=utf-8",
extra_headers: list[tuple[str, str]] | None = None,
) -> Response:
headers = [
("Date", email.utils.formatdate(usegmt=True)),
("Connection", "close"),
("Content-Length", str(len(body))),
("Content-Type", content_type),
]
if extra_headers:
headers.extend(extra_headers)
reason = http.HTTPStatus(status).phrase
return Response(status, reason, Headers(headers), body)
def _http_error(status: int, message: str | None = None) -> Response:
body = (message or http.HTTPStatus(status).phrase).encode("utf-8")
return _http_response(body, status=status)
def _case_insensitive_header(headers: Any, key: str) -> str:
try:
value = headers.get(key)
except Exception:
value = None
if value is None:
try:
value = headers.get(key.lower())
except Exception:
value = None
return str(value or "").strip()
def _parse_single_byte_range(range_header: str, size: int) -> tuple[int, int]:
"""Parse a single HTTP byte range for signed media responses."""
if size <= 0 or "," in range_header:

View File

@ -0,0 +1,92 @@
"""Media gateway services shared by WebUI HTTP routes and WebSocket frames."""
from __future__ import annotations
import secrets
from collections.abc import Callable
from pathlib import Path
from typing import Any
from websockets.http11 import Request as WsRequest
from websockets.http11 import Response
from nanobot.config.paths import get_media_dir
from nanobot.webui.media_api import (
attach_signed_media_urls,
serve_signed_media,
sign_media_path,
sign_or_stage_media_path,
signed_media_attachments,
)
from nanobot.webui.transcript import rewrite_local_markdown_images
class WebUIMediaGateway:
"""Own media URL signing and WebUI markdown/media augmentation."""
def __init__(
self,
*,
workspace_path: Path,
logger: Any,
media_dir: Callable[[str | None], Path] | None = None,
secret: bytes | None = None,
) -> None:
self.workspace_path = workspace_path
self.logger = logger
self._media_dir = media_dir or (lambda channel=None: get_media_dir(channel))
self.secret = secret or secrets.token_bytes(32)
def serve_signed_media(
self,
sig: str,
payload: str,
*,
request: WsRequest | None = None,
) -> Response:
return serve_signed_media(
sig,
payload,
secret=self.secret,
request=request,
media_dir=self._media_dir,
)
def sign_media_path(self, abs_path: Path) -> str | None:
return sign_media_path(
abs_path,
secret=self.secret,
media_dir=self._media_dir,
)
def sign_or_stage_media_path(self, path: Path) -> dict[str, str] | None:
return sign_or_stage_media_path(
path,
secret=self.secret,
media_dir=self._media_dir,
logger=self.logger,
)
def rewrite_local_markdown_images(
self,
text: str,
*,
workspace_path: Path | None = None,
) -> str:
return rewrite_local_markdown_images(
text,
workspace_path=workspace_path or self.workspace_path,
sign_path=self.sign_or_stage_media_path,
)
def augment_media_urls(self, payload: dict[str, Any]) -> None:
attach_signed_media_urls(payload, sign_path=self.sign_media_path)
def augment_transcript_media(self, paths: list[str]) -> list[dict[str, Any]]:
return signed_media_attachments(
paths,
sign_path=self.sign_or_stage_media_path,
)
def augment_transcript_user_media(self, paths: list[str]) -> list[dict[str, Any]]:
return self.augment_transcript_media(paths)

View File

@ -9,187 +9,70 @@ Also houses shared HTTP utility functions used by both this module and
from __future__ import annotations
import email.utils
import hmac
import http
import json
import mimetypes
import re
import secrets
import time
from collections.abc import Callable
from pathlib import Path
from typing import TYPE_CHECKING, Any
from urllib.parse import parse_qs, urlparse
from loguru import logger
from websockets.datastructures import Headers
from websockets.http11 import Request as WsRequest
from websockets.http11 import Response
from nanobot.command.builtin import builtin_command_palette
from nanobot.config.paths import get_media_dir
from nanobot.utils.subagent_channel_display import scrub_subagent_messages_for_channel
from nanobot.webui.media_api import (
serve_signed_media,
sign_media_path,
sign_or_stage_media_path,
from nanobot.webui.gateway_tokens import GatewayTokenStore, token_response_payload
from nanobot.webui.http_utils import (
case_insensitive_header as _case_insensitive_header,
)
from nanobot.webui.http_utils import (
host_for_url as _host_for_url,
)
from nanobot.webui.http_utils import (
http_error as _http_error,
)
from nanobot.webui.http_utils import (
http_json_response as _http_json_response,
)
from nanobot.webui.http_utils import (
http_response as _http_response,
)
from nanobot.webui.http_utils import (
is_localhost as _is_localhost,
)
from nanobot.webui.http_utils import (
issue_route_secret_matches as _issue_route_secret_matches,
)
from nanobot.webui.http_utils import (
normalize_config_path as _normalize_config_path,
)
from nanobot.webui.http_utils import (
parse_query as _parse_query,
)
from nanobot.webui.http_utils import (
parse_request_path as _parse_request_path,
)
from nanobot.webui.http_utils import (
query_first as _query_first,
)
from nanobot.webui.http_utils import (
safe_host_header as _safe_host_header,
)
from nanobot.webui.media_gateway import WebUIMediaGateway
from nanobot.webui.sidebar_state import (
read_webui_sidebar_state,
write_webui_sidebar_state,
)
from nanobot.webui.thread_disk import delete_webui_thread
from nanobot.webui.transcript import (
build_webui_thread_response,
rewrite_local_markdown_images,
)
from nanobot.webui.transcript import build_webui_thread_response
from nanobot.webui.workspaces import WebUIWorkspaceController
if TYPE_CHECKING:
from nanobot.bus.queue import MessageBus
from nanobot.session.manager import SessionManager
# ---------------------------------------------------------------------------
# Shared HTTP utility functions (imported by websocket.py)
# ---------------------------------------------------------------------------
def _strip_trailing_slash(path: str) -> str:
if len(path) > 1 and path.endswith("/"):
return path.rstrip("/")
return path or "/"
def _normalize_config_path(path: str) -> str:
return _strip_trailing_slash(path)
def _case_insensitive_header(headers: Any, key: str) -> str:
"""Read a header from websockets/http test stubs without assuming casing."""
try:
value = headers.get(key)
except Exception:
value = None
if value is None:
try:
value = headers.get(key.lower())
except Exception:
value = None
return str(value or "").strip()
def _safe_host_header(value: str) -> str:
"""Return a safe Host header value, or empty when it should not be echoed."""
value = value.strip()
if not value:
return ""
if re.fullmatch(r"\[[0-9A-Fa-f:.]+\](?::\d{1,5})?", value):
return value
if re.fullmatch(r"[A-Za-z0-9.-]+(?::\d{1,5})?", value):
return value
return ""
def _host_for_url(host: str, port: int) -> str:
host = host.strip()
if host in ("0.0.0.0", "::"):
host = "127.0.0.1"
if ":" in host and not host.startswith("["):
host = f"[{host}]"
return f"{host}:{port}"
def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
body = json.dumps(data, ensure_ascii=False).encode("utf-8")
headers = Headers(
[
("Date", email.utils.formatdate(usegmt=True)),
("Connection", "close"),
("Content-Length", str(len(body))),
("Content-Type", "application/json; charset=utf-8"),
]
)
reason = http.HTTPStatus(status).phrase
return Response(status, reason, headers, body)
def _http_response(
body: bytes,
*,
status: int = 200,
content_type: str = "text/plain; charset=utf-8",
extra_headers: list[tuple[str, str]] | None = None,
) -> Response:
headers = [
("Date", email.utils.formatdate(usegmt=True)),
("Connection", "close"),
("Content-Length", str(len(body))),
("Content-Type", content_type),
]
if extra_headers:
headers.extend(extra_headers)
reason = http.HTTPStatus(status).phrase
return Response(status, reason, Headers(headers), body)
def _http_error(status: int, message: str | None = None) -> Response:
body = (message or http.HTTPStatus(status).phrase).encode("utf-8")
return _http_response(body, status=status)
def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]:
"""Parse normalized path and query parameters in one pass."""
parsed = urlparse("ws://x" + path_with_query)
path = _strip_trailing_slash(parsed.path or "/")
return path, parse_qs(parsed.query, keep_blank_values=True)
def _normalize_http_path(path_with_query: str) -> str:
return _parse_request_path(path_with_query)[0]
def _parse_query(path_with_query: str) -> dict[str, list[str]]:
return _parse_request_path(path_with_query)[1]
def _query_first(query: dict[str, list[str]], key: str) -> str | None:
values = query.get(key)
return values[0] if values else None
def _is_localhost(connection: Any) -> bool:
addr = getattr(connection, "remote_address", None)
if not addr:
return False
host = addr[0] if isinstance(addr, tuple) else addr
if not isinstance(host, str):
return False
if host.startswith("::ffff:"):
host = host[7:]
return host in {"127.0.0.1", "::1", "localhost"}
def _bearer_token(headers: Any) -> str | None:
auth = headers.get("Authorization") or headers.get("authorization")
if auth and auth.lower().startswith("bearer "):
return auth[7:].strip() or None
return None
def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool:
if not configured_secret:
return True
authorization = headers.get("Authorization") or headers.get("authorization")
if authorization and authorization.lower().startswith("bearer "):
supplied = authorization[7:].strip()
return hmac.compare_digest(supplied, configured_secret)
header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth")
if not header_token:
return False
return hmac.compare_digest(header_token.strip(), configured_secret)
def _decode_api_key(raw_key: str) -> str | None:
from urllib.parse import unquote
@ -234,48 +117,36 @@ def _resolve_bootstrap_model_name(
class GatewayHTTPHandler:
"""Handles all HTTP routes served alongside the WebSocket endpoint.
Owns token management, session API, media API, static file serving,
and delegates settings routes to ``WebUISettingsRouter``.
Routes HTTP requests and delegates stateful work to explicit gateway
services owned by the composition layer.
"""
_MAX_ISSUED_TOKENS = 10_000
def __init__(
self,
*,
config: Any, # WebSocketConfig
session_manager: SessionManager | None,
static_dist_path: Path | None,
workspace_path: Path,
runtime_model_name: Callable[[], str | None] | None,
runtime_surface: str,
runtime_capabilities_overrides: dict[str, Any] | None,
bus: MessageBus,
tokens: GatewayTokenStore,
media: WebUIMediaGateway,
workspaces: WebUIWorkspaceController,
log: Any = logger,
) -> None:
self.config = config
self.session_manager = session_manager
self.static_dist_path = static_dist_path
self.workspace_path = workspace_path
self.runtime_model_name = runtime_model_name
self.bus = bus
self.tokens = tokens
self.media = media
self.workspaces = workspaces
self._log = log
self._runtime_surface = runtime_surface
self.issued_tokens: dict[str, float] = {}
self.api_tokens: dict[str, float] = {}
self.media_secret: bytes = secrets.token_bytes(32)
# Workspace controller
from nanobot.webui.workspaces import WebUIWorkspaceController
self.workspaces = WebUIWorkspaceController(
session_manager=session_manager,
default_workspace=workspace_path,
default_restrict_to_workspace=None,
)
# Settings router
from nanobot.webui.settings_api import runtime_capabilities as _rc
from nanobot.webui.settings_routes import WebUISettingsRouter
@ -294,46 +165,13 @@ class GatewayHTTPHandler:
# -- Token management ---------------------------------------------------
def check_api_token(self, request: WsRequest) -> bool:
self._purge_expired_api_tokens()
token = _bearer_token(request.headers) or _query_first(
_parse_query(request.path), "token"
)
if not token:
return False
expiry = self.api_tokens.get(token)
if expiry is None or time.monotonic() > expiry:
self.api_tokens.pop(token, None)
return False
return True
def _purge_expired_api_tokens(self) -> None:
now = time.monotonic()
for token_key, expiry in list(self.api_tokens.items()):
if now > expiry:
self.api_tokens.pop(token_key, None)
def _purge_expired_issued_tokens(self) -> None:
now = time.monotonic()
for token_key, expiry in list(self.issued_tokens.items()):
if now > expiry:
self.issued_tokens.pop(token_key, None)
def take_issued_token_if_valid(self, token_value: str | None) -> bool:
if not token_value:
return False
self._purge_expired_issued_tokens()
expiry = self.issued_tokens.pop(token_value, None)
if expiry is None:
return False
if time.monotonic() > expiry:
return False
return True
return self.tokens.check_api_token(request)
# -- Main dispatch ------------------------------------------------------
async def dispatch(self, connection: Any, request: WsRequest) -> Any | None:
"""Route an HTTP request. Returns Response or None."""
got, query = _parse_request_path(request.path)
got, _ = _parse_request_path(request.path)
# Token issue endpoint
if self.config.token_issue_path:
@ -389,18 +227,14 @@ class GatewayHTTPHandler:
"token_issue_path is set but token_issue_secret is empty; "
"any client can obtain connection tokens — set token_issue_secret for production."
)
self._purge_expired_issued_tokens()
if len(self.issued_tokens) >= self._MAX_ISSUED_TOKENS:
if not self.tokens.can_issue():
self._log.error(
"too many outstanding issued tokens ({}), rejecting issuance",
len(self.issued_tokens),
len(self.tokens.issued_tokens),
)
return _http_json_response({"error": "too many outstanding tokens"}, status=429)
token_value = f"nbwt_{secrets.token_urlsafe(32)}"
self.issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s)
return _http_json_response(
{"token": token_value, "expires_in": self.config.token_ttl_s}
)
token_value = self.tokens.issue_token(self.config.token_ttl_s)
return _http_json_response(token_response_payload(token_value, self.config.token_ttl_s))
# -- Bootstrap ----------------------------------------------------------
@ -412,21 +246,13 @@ class GatewayHTTPHandler:
elif not _is_localhost(connection):
return _http_error(403, "bootstrap is localhost-only")
self._purge_expired_issued_tokens()
self._purge_expired_api_tokens()
if (
len(self.issued_tokens) >= self._MAX_ISSUED_TOKENS
or len(self.api_tokens) >= self._MAX_ISSUED_TOKENS
):
if not self.tokens.can_issue(include_api_token=True):
return _http_response(
json.dumps({"error": "too many outstanding tokens"}).encode("utf-8"),
status=429,
content_type="application/json; charset=utf-8",
)
token = f"nbwt_{secrets.token_urlsafe(32)}"
expiry = time.monotonic() + float(self.config.token_ttl_s)
self.issued_tokens[token] = expiry
self.api_tokens[token] = expiry
token = self.tokens.issue_token(self.config.token_ttl_s, api_token=True)
ws_url = self._bootstrap_ws_url(request)
expected_path = _normalize_config_path(self.config.path)
@ -510,7 +336,7 @@ class GatewayHTTPHandler:
messages = data.get("messages")
if isinstance(messages, list):
scrub_subagent_messages_for_channel(messages)
self._augment_media_urls(data)
self.media.augment_media_urls(data)
return _http_json_response(data)
def _handle_webui_thread_get(self, request: WsRequest, key: str) -> Response:
@ -524,11 +350,11 @@ class GatewayHTTPHandler:
scope = self.workspaces.scope_for_session_key(decoded_key)
data = build_webui_thread_response(
decoded_key,
augment_user_media=self._augment_transcript_user_media,
augment_assistant_text=lambda text: rewrite_local_markdown_images(
augment_user_media=self.media.augment_transcript_media,
augment_assistant_media=self.media.augment_transcript_media,
augment_assistant_text=lambda text: self.media.rewrite_local_markdown_images(
text,
workspace_path=scope.project_path,
sign_path=self.sign_or_stage_media_path,
),
)
if data is None:
@ -561,34 +387,10 @@ class GatewayHTTPHandler:
def _handle_media_fetch(
self, sig: str, payload: str, request: WsRequest | None = None
) -> Response:
return serve_signed_media(
return self.media.serve_signed_media(
sig,
payload,
secret=self.media_secret,
request=request,
media_dir=lambda channel=None: get_media_dir(channel),
)
def sign_media_path(self, abs_path: Path) -> str | None:
return sign_media_path(
abs_path,
secret=self.media_secret,
media_dir=lambda channel=None: get_media_dir(channel),
)
def sign_or_stage_media_path(self, path: Path) -> dict[str, str] | None:
return sign_or_stage_media_path(
path,
secret=self.media_secret,
media_dir=lambda channel=None: get_media_dir(channel),
logger=self._log,
)
def rewrite_local_markdown_images(self, text: str) -> str:
return rewrite_local_markdown_images(
text,
workspace_path=self.workspace_path,
sign_path=self.sign_or_stage_media_path,
)
# -- Misc routes --------------------------------------------------------
@ -688,44 +490,5 @@ class GatewayHTTPHandler:
extra_headers=[("Cache-Control", cache)],
)
# -- Media helpers (called by WebSocketChannel.send) --------------------
def _augment_media_urls(self, payload: dict[str, Any]) -> None:
messages = payload.get("messages")
if not isinstance(messages, list):
return
for msg in messages:
if not isinstance(msg, dict):
continue
media = msg.get("media")
if not isinstance(media, list) or not media:
continue
urls: list[dict[str, str]] = []
for entry in media:
if not isinstance(entry, str) or not entry:
continue
signed = self.sign_media_path(Path(entry))
if signed is None:
continue
urls.append({"url": signed, "name": Path(entry).name})
if urls:
msg["media_urls"] = urls
msg.pop("media", None)
def _augment_transcript_user_media(self, paths: list[str]) -> list[dict[str, Any]]:
out: list[dict[str, Any]] = []
for pstr in paths:
path = Path(pstr)
att = self.sign_or_stage_media_path(path)
if att is None:
continue
mime, _ = mimetypes.guess_type(path.name)
kind = "video" if mime and mime.startswith("video/") else "image"
out.append(
{"kind": kind, "url": att["url"], "name": att.get("name", path.name)},
)
return out
def _is_websocket_channel_session_key(key: str) -> bool:
return key.startswith("websocket:")

View File

@ -65,6 +65,32 @@ def manager() -> ChannelManager:
return mgr
def test_websocket_gateway_uses_configured_workspace_restriction(tmp_path, monkeypatch):
monkeypatch.setattr(
"nanobot.webui.workspaces.read_webui_default_access_mode",
lambda: "default",
)
config = Config.model_validate(
{
"agents": {"defaults": {"workspace": str(tmp_path)}},
"tools": {"restrictToWorkspace": True},
"channels": {
"websocket": {
"enabled": True,
"websocketRequiresToken": False,
},
},
}
)
mgr = ChannelManager(config, MessageBus(), webui_static_dist=False)
channel = mgr.channels["websocket"]
scope = channel.gateway.workspaces.default_scope()
assert scope.project_path == tmp_path
assert scope.restrict_to_workspace is True
@pytest.mark.asyncio
async def test_reasoning_delta_routes_to_send_reasoning_delta(manager):
channel = manager.channels["mock"]

View File

@ -20,20 +20,30 @@ from nanobot.channels.websocket import (
WebSocketChannel,
WebSocketConfig,
_is_valid_chat_id,
_issue_route_secret_matches,
_normalize_config_path,
_normalize_http_path,
_parse_envelope,
_parse_inbound_payload,
_parse_query,
_parse_request_path,
publish_runtime_model_update,
)
from nanobot.webui.ws_http import GatewayHTTPHandler
from nanobot.config.loader import load_config, save_config
from nanobot.config.schema import Config, ModelPresetConfig
from nanobot.session import webui_turns as wth
from nanobot.session.manager import SessionManager
from nanobot.webui.gateway_services import GatewayServices, build_gateway_services
from nanobot.webui.http_utils import (
issue_route_secret_matches as _issue_route_secret_matches,
)
from nanobot.webui.http_utils import (
normalize_config_path as _normalize_config_path,
)
from nanobot.webui.http_utils import (
normalize_http_path as _normalize_http_path,
)
from nanobot.webui.http_utils import (
parse_query as _parse_query,
)
from nanobot.webui.http_utils import (
parse_request_path as _parse_request_path,
)
from nanobot.webui.settings_api import settings_payload, update_provider_settings
# -- Shared helpers (aligned with test_websocket_integration.py) ---------------
@ -52,34 +62,36 @@ def _ch(bus: Any, **kw: Any) -> WebSocketChannel:
}
cfg.update(kw)
parsed = WebSocketConfig.model_validate(cfg)
http_handler = GatewayHTTPHandler(
gateway = build_gateway_services(
config=parsed,
bus=bus,
session_manager=None,
static_dist_path=None,
workspace_path=Path.cwd(),
default_restrict_to_workspace=False,
runtime_model_name=None,
runtime_surface="browser",
runtime_capabilities_overrides=None,
bus=bus,
)
return WebSocketChannel(cfg, bus, http_handler=http_handler)
return WebSocketChannel(cfg, bus, gateway=gateway)
def _basic_handler(bus: Any, **kw: Any) -> GatewayHTTPHandler:
def _basic_handler(bus: Any, **kw: Any) -> GatewayServices:
cfg = WebSocketConfig.model_validate({
"enabled": True, "allowFrom": ["*"],
"host": "127.0.0.1", "port": _PORT,
"path": "/ws", "websocketRequiresToken": False,
})
return GatewayHTTPHandler(
return build_gateway_services(
config=cfg,
bus=bus,
session_manager=kw.get("session_manager"),
static_dist_path=None,
workspace_path=kw.get("workspace_path", Path.cwd()),
default_restrict_to_workspace=kw.get("default_restrict_to_workspace", False),
runtime_model_name=None,
runtime_surface=kw.get("runtime_surface", "browser"),
runtime_capabilities_overrides=kw.get("runtime_capabilities_overrides"),
bus=bus,
)
@ -194,7 +206,7 @@ def test_ssl_context_requires_both_cert_and_key_files() -> None:
channel = WebSocketChannel(
{"enabled": True, "allowFrom": ["*"], "sslCertfile": "/tmp/c.pem", "sslKeyfile": ""},
bus,
http_handler=_basic_handler(bus),
gateway=_basic_handler(bus),
)
with pytest.raises(ValueError, match="ssl_certfile and ssl_keyfile"):
channel._build_ssl_context()
@ -310,7 +322,7 @@ async def test_webui_message_scope_inherits_persisted_session_scope(
channel = WebSocketChannel(
{"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"},
bus,
http_handler=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace),
gateway=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace),
)
conn = AsyncMock()
conn.remote_address = ("127.0.0.1", 50123)
@ -356,7 +368,7 @@ async def test_webui_scope_expands_home_project_path(
channel = WebSocketChannel(
{"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"},
bus,
http_handler=_basic_handler(bus, session_manager=SessionManager(tmp_path / "sessions"), workspace_path=default_workspace),
gateway=_basic_handler(bus, session_manager=SessionManager(tmp_path / "sessions"), workspace_path=default_workspace),
)
conn = AsyncMock()
conn.remote_address = ("127.0.0.1", 50123)
@ -393,7 +405,7 @@ async def test_webui_scope_rejects_missing_project_path(bus: MagicMock, tmp_path
channel = WebSocketChannel(
{"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"},
bus,
http_handler=_basic_handler(bus, session_manager=SessionManager(tmp_path / "sessions"), workspace_path=default_workspace),
gateway=_basic_handler(bus, session_manager=SessionManager(tmp_path / "sessions"), workspace_path=default_workspace),
)
conn = AsyncMock()
conn.remote_address = ("127.0.0.1", 50123)
@ -430,7 +442,7 @@ async def test_webui_scope_rejects_running_scope_change(bus: MagicMock, tmp_path
channel = WebSocketChannel(
{"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"},
bus,
http_handler=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace),
gateway=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace),
)
conn = AsyncMock()
conn.remote_address = ("127.0.0.1", 50123)
@ -486,7 +498,7 @@ async def test_webui_set_workspace_scope_rejects_running_chat(bus: MagicMock, tm
channel = WebSocketChannel(
{"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"},
bus,
http_handler=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace),
gateway=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace),
)
conn = AsyncMock()
conn.remote_address = ("127.0.0.1", 50123)
@ -545,7 +557,7 @@ async def test_webui_scope_rejects_non_loopback_custom_scope(bus: MagicMock, tmp
channel = WebSocketChannel(
{"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"},
bus,
http_handler=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace),
gateway=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace),
)
conn = AsyncMock()
conn.remote_address = ("203.0.113.8", 50123)
@ -574,7 +586,7 @@ async def test_webui_scope_rejects_non_loopback_custom_scope(bus: MagicMock, tmp
@pytest.mark.asyncio
async def test_send_delivers_json_message_with_media_and_reply() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
@ -600,7 +612,7 @@ async def test_send_delivers_json_message_with_media_and_reply() -> None:
@pytest.mark.asyncio
async def test_send_broadcasts_runtime_model_updates() -> None:
bus = MessageBus()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
@ -647,8 +659,8 @@ async def test_send_stages_external_media_as_signed_url(monkeypatch, tmp_path) -
return ws_media if channel == "websocket" else media_root
monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir)
monkeypatch.setattr("nanobot.webui.ws_http.get_media_dir", fake_media_dir)
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
monkeypatch.setattr("nanobot.webui.media_gateway.get_media_dir", fake_media_dir)
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
@ -671,7 +683,7 @@ async def test_send_stages_external_media_as_signed_url(monkeypatch, tmp_path) -
@pytest.mark.asyncio
async def test_send_missing_connection_is_noop_without_error() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
msg = OutboundMessage(channel="websocket", chat_id="missing", content="x")
await channel.send(msg)
@ -679,7 +691,7 @@ async def test_send_missing_connection_is_noop_without_error() -> None:
@pytest.mark.asyncio
async def test_send_removes_connection_on_connection_closed() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
mock_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True)
channel._attach(mock_ws, "chat-1")
@ -694,7 +706,7 @@ async def test_send_removes_connection_on_connection_closed() -> None:
@pytest.mark.asyncio
async def test_send_progress_includes_structured_tool_events() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
@ -742,7 +754,7 @@ async def test_send_progress_includes_structured_tool_events() -> None:
@pytest.mark.asyncio
async def test_send_file_edit_progress_uses_file_edit_event() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
@ -791,7 +803,7 @@ async def test_send_file_edit_progress_uses_file_edit_event() -> None:
@pytest.mark.asyncio
async def test_send_progress_includes_agent_ui_blob() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
@ -815,7 +827,7 @@ async def test_send_progress_includes_agent_ui_blob() -> None:
@pytest.mark.asyncio
async def test_send_delta_removes_connection_on_connection_closed() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
mock_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True)
channel._attach(mock_ws, "chat-1")
@ -829,7 +841,7 @@ async def test_send_delta_removes_connection_on_connection_closed() -> None:
@pytest.mark.asyncio
async def test_send_delta_emits_delta_and_stream_end() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
@ -862,11 +874,11 @@ async def test_send_delta_stream_end_rewrites_local_markdown_image(monkeypatch,
return path
monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir)
monkeypatch.setattr("nanobot.webui.ws_http.get_media_dir", fake_media_dir)
monkeypatch.setattr("nanobot.webui.media_gateway.get_media_dir", fake_media_dir)
channel = WebSocketChannel(
{"enabled": True, "allowFrom": ["*"], "streaming": True},
bus,
http_handler=_basic_handler(bus, workspace_path=workspace),
gateway=_basic_handler(bus, workspace_path=workspace),
)
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
@ -895,11 +907,11 @@ async def test_send_delta_stream_end_rewrites_inline_final_text(monkeypatch, tmp
return path
monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir)
monkeypatch.setattr("nanobot.webui.ws_http.get_media_dir", fake_media_dir)
monkeypatch.setattr("nanobot.webui.media_gateway.get_media_dir", fake_media_dir)
channel = WebSocketChannel(
{"enabled": True, "allowFrom": ["*"], "streaming": True},
bus,
http_handler=_basic_handler(bus, workspace_path=workspace),
gateway=_basic_handler(bus, workspace_path=workspace),
)
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
@ -919,7 +931,7 @@ async def test_send_delta_stream_end_rewrites_inline_final_text(monkeypatch, tmp
@pytest.mark.asyncio
async def test_send_reasoning_delta_emits_streaming_frame() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
@ -940,7 +952,7 @@ async def test_send_reasoning_delta_emits_streaming_frame() -> None:
@pytest.mark.asyncio
async def test_send_reasoning_end_emits_close_frame() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
@ -956,7 +968,7 @@ async def test_send_reasoning_one_shot_expands_to_delta_plus_end() -> None:
the base implementation must produce one delta and one end so the
WebUI sees the same shape either way."""
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
@ -978,7 +990,7 @@ async def test_send_reasoning_one_shot_expands_to_delta_plus_end() -> None:
@pytest.mark.asyncio
async def test_send_reasoning_delta_drops_empty_chunks() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
@ -990,7 +1002,7 @@ async def test_send_reasoning_delta_drops_empty_chunks() -> None:
@pytest.mark.asyncio
async def test_send_reasoning_without_subscribers_is_noop() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
await channel.send_reasoning_delta("unattached", "thinking", None)
await channel.send_reasoning_end("unattached", None)
@ -1000,7 +1012,7 @@ async def test_send_reasoning_without_subscribers_is_noop() -> None:
@pytest.mark.asyncio
async def test_send_turn_end_emits_turn_end_event() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
@ -1019,7 +1031,7 @@ async def test_send_turn_end_emits_turn_end_event() -> None:
@pytest.mark.asyncio
async def test_send_turn_end_includes_latency_ms_when_present() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
@ -1038,7 +1050,7 @@ async def test_send_turn_end_includes_latency_ms_when_present() -> None:
@pytest.mark.asyncio
async def test_send_turn_end_includes_goal_state_when_present() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
@ -1058,7 +1070,7 @@ async def test_send_turn_end_includes_goal_state_when_present() -> None:
@pytest.mark.asyncio
async def test_send_goal_status_running_emits_event_with_started_at() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
@ -1086,7 +1098,7 @@ async def test_send_goal_status_running_emits_event_with_started_at() -> None:
@pytest.mark.asyncio
async def test_send_goal_status_idle_omits_started_at() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
@ -1109,7 +1121,7 @@ async def test_send_goal_status_idle_omits_started_at() -> None:
@pytest.mark.asyncio
async def test_send_goal_state_emits_blob_per_chat() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
mock_a = AsyncMock()
mock_b = AsyncMock()
channel._attach(mock_a, "chat-a")
@ -1138,10 +1150,9 @@ async def test_send_goal_state_emits_blob_per_chat() -> None:
@pytest.mark.asyncio
async def test_maybe_push_active_goal_state_noop_without_session_manager() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
channel._session_manager = None
await channel._maybe_push_active_goal_state("chat-1")
mock_ws.send.assert_not_called()
@ -1149,10 +1160,13 @@ async def test_maybe_push_active_goal_state_noop_without_session_manager() -> No
@pytest.mark.asyncio
async def test_maybe_push_active_goal_state_skips_when_no_goal_on_disk() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
sm = MagicMock()
sm.read_session_file.return_value = None
channel._session_manager = sm
channel = WebSocketChannel(
{"enabled": True, "allowFrom": ["*"]},
bus,
gateway=_basic_handler(bus, session_manager=sm),
)
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
await channel._maybe_push_active_goal_state("chat-1")
@ -1162,7 +1176,6 @@ async def test_maybe_push_active_goal_state_skips_when_no_goal_on_disk() -> None
@pytest.mark.asyncio
async def test_maybe_push_active_goal_state_notifies_when_goal_active_on_disk() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
sm = MagicMock()
sm.read_session_file.return_value = {
"metadata": {
@ -1174,7 +1187,11 @@ async def test_maybe_push_active_goal_state_notifies_when_goal_active_on_disk()
},
"messages": [],
}
channel._session_manager = sm
channel = WebSocketChannel(
{"enabled": True, "allowFrom": ["*"]},
bus,
gateway=_basic_handler(bus, session_manager=sm),
)
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
await channel._maybe_push_active_goal_state("chat-1")
@ -1190,7 +1207,7 @@ async def test_maybe_push_active_goal_state_notifies_when_goal_active_on_disk()
@pytest.mark.asyncio
async def test_maybe_push_turn_run_wall_clock_skips_when_no_active_turn() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
from nanobot.session import webui_turns as wth
@ -1203,7 +1220,7 @@ async def test_maybe_push_turn_run_wall_clock_skips_when_no_active_turn() -> Non
@pytest.mark.asyncio
async def test_maybe_push_turn_run_wall_clock_replays_running() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
from nanobot.session import webui_turns as wth
@ -1228,7 +1245,7 @@ async def test_maybe_push_turn_run_wall_clock_replays_running() -> None:
@pytest.mark.asyncio
async def test_send_session_updated_emits_session_updated_event() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
@ -1247,7 +1264,7 @@ async def test_send_session_updated_emits_session_updated_event() -> None:
@pytest.mark.asyncio
async def test_send_session_updated_includes_scope_when_present() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
@ -1266,7 +1283,7 @@ async def test_send_session_updated_includes_scope_when_present() -> None:
@pytest.mark.asyncio
async def test_send_non_connection_closed_exception_is_raised() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
mock_ws = AsyncMock()
mock_ws.send.side_effect = RuntimeError("unexpected")
channel._attach(mock_ws, "chat-1")
@ -1279,7 +1296,7 @@ async def test_send_non_connection_closed_exception_is_raised() -> None:
@pytest.mark.asyncio
async def test_send_delta_missing_connection_is_noop() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus, gateway=_basic_handler(bus))
# No exception, no error — just a no-op
await channel.send_delta("nonexistent", "chunk", {"_stream_delta": True, "_stream_id": "s1"})
@ -1287,7 +1304,7 @@ async def test_send_delta_missing_connection_is_noop() -> None:
@pytest.mark.asyncio
async def test_stop_is_idempotent() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
# stop() before start() should not raise
await channel.stop()
await channel.stop()
@ -1448,7 +1465,7 @@ async def test_settings_api_returns_safe_subset_and_updates_whitelist(
)
channel = _ch(bus, port=port)
channel._api_tokens["tok"] = time.monotonic() + 300
channel.gateway.tokens.api_tokens["tok"] = time.monotonic() + 300
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
@ -1733,7 +1750,7 @@ async def test_settings_api_returns_safe_subset_and_updates_whitelist(
async def test_commands_api_returns_slash_command_metadata(bus: MagicMock) -> None:
port = 29892
channel = _ch(bus, port=port)
channel._api_tokens["tok"] = time.monotonic() + 300
channel.gateway.tokens.api_tokens["tok"] = time.monotonic() + 300
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
@ -1771,7 +1788,7 @@ async def test_bootstrap_exposes_native_surface(bus: MagicMock) -> None:
"websocketRequiresToken": True,
},
bus,
http_handler=_basic_handler(bus, runtime_surface="native", runtime_capabilities_overrides={"can_pick_folder": True}),
gateway=_basic_handler(bus, runtime_surface="native", runtime_capabilities_overrides={"can_pick_folder": True}),
)
server_task = asyncio.create_task(channel.start())
@ -1941,8 +1958,9 @@ async def test_token_issue_rejects_when_at_capacity(bus: MagicMock) -> None:
try:
# Fill issued tokens to capacity
channel._http.issued_tokens = {
f"nbwt_fill_{i}": time.monotonic() + 300 for i in range(channel._http._MAX_ISSUED_TOKENS)
channel.gateway.tokens.issued_tokens = {
f"nbwt_fill_{i}": time.monotonic() + 300
for i in range(channel.gateway.tokens.max_tokens)
}
resp = await _http_get(
@ -2299,10 +2317,8 @@ def test_sessions_list_includes_active_run_started_at() -> None:
from nanobot.session import webui_turns as wth
bus = MagicMock()
channel = _ch(bus)
channel._api_tokens["tok"] = time.monotonic() + 300.0
channel._session_manager = MagicMock()
channel._session_manager.list_sessions.return_value = [
session_manager = MagicMock()
session_manager.list_sessions.return_value = [
{
"key": "websocket:chat-1",
"created_at": "2026-05-19T10:00:00Z",
@ -2317,19 +2333,25 @@ def test_sessions_list_includes_active_run_started_at() -> None:
"updated_at": "2026-05-19T10:01:00Z",
},
]
channel = WebSocketChannel(
{"enabled": True, "allowFrom": ["*"]},
bus,
gateway=_basic_handler(bus, session_manager=session_manager),
)
channel.gateway.tokens.api_tokens["tok"] = time.monotonic() + 300.0
wth._WEBSOCKET_TURN_WALL_STARTED_AT.clear()
try:
wth._WEBSOCKET_TURN_WALL_STARTED_AT["chat-1"] = 1_700_000_000.0
req = Request("/api/sessions", Headers([("Authorization", "Bearer tok")]))
resp = channel._handle_sessions_list(req)
resp = channel.gateway.http._handle_sessions_list(req)
finally:
wth._WEBSOCKET_TURN_WALL_STARTED_AT.clear()
assert resp.status_code == 200
body = json.loads(resp.body.decode())
workspace_scope = body["sessions"][0].pop("workspace_scope")
assert workspace_scope["project_path"] == str(channel._workspace_path)
assert workspace_scope["project_path"] == str(channel.gateway.media.workspace_path)
assert workspace_scope["access_mode"] in {"restricted", "full"}
assert body["sessions"] == [
{
@ -2376,10 +2398,10 @@ def test_handle_webui_thread_get_returns_json(tmp_path, monkeypatch) -> None:
append_transcript_object(key, {"event": "user", "chat_id": "c1", "text": "hi"})
bus = MagicMock()
channel = _ch(bus)
channel._api_tokens["tok"] = time.monotonic() + 300.0
channel.gateway.tokens.api_tokens["tok"] = time.monotonic() + 300.0
enc = quote(key, safe="")
req = Request(f"/api/sessions/{enc}/webui-thread", Headers([("Authorization", "Bearer tok")]))
resp = channel._handle_webui_thread_get(req, enc)
resp = channel.gateway.http._handle_webui_thread_get(req, enc)
assert resp.status_code == 200
body = json.loads(resp.body.decode())
assert body["sessionKey"] == key

View File

@ -21,7 +21,7 @@ from nanobot.channels.websocket import (
WebSocketConfig,
_extract_data_url_mime,
)
from nanobot.webui.ws_http import GatewayHTTPHandler
from nanobot.webui.gateway_services import build_gateway_services
def _tiny_png_data_url() -> str:
@ -45,17 +45,18 @@ def _make_channel() -> WebSocketChannel:
bus.publish_inbound = AsyncMock()
cfg = {"enabled": True, "allowFrom": ["*"], "websocketRequiresToken": False}
parsed = WebSocketConfig.model_validate(cfg)
handler = GatewayHTTPHandler(
gateway = build_gateway_services(
config=parsed,
bus=bus,
session_manager=None,
static_dist_path=None,
workspace_path=Path.cwd(),
default_restrict_to_workspace=False,
runtime_model_name=None,
runtime_surface="browser",
runtime_capabilities_overrides=None,
bus=bus,
)
channel = WebSocketChannel(cfg, bus, http_handler=handler)
channel = WebSocketChannel(cfg, bus, gateway=gateway)
channel._handle_message = AsyncMock() # type: ignore[method-assign]
return channel

View File

@ -13,7 +13,7 @@ import pytest
from nanobot.channels.websocket import WebSocketChannel, WebSocketConfig
from nanobot.session.manager import Session, SessionManager
from nanobot.webui.ws_http import GatewayHTTPHandler
from nanobot.webui.gateway_services import GatewayServices, build_gateway_services
_PORT = 29900
@ -25,18 +25,19 @@ def _make_handler(
session_manager: SessionManager | None = None,
static_dist_path: Path | None = None,
runtime_model_name: Any | None = None,
) -> GatewayHTTPHandler:
) -> GatewayServices:
config = WebSocketConfig.model_validate(cfg) if isinstance(cfg, dict) else cfg
workspace = Path.cwd()
return GatewayHTTPHandler(
return build_gateway_services(
config=config,
bus=bus,
session_manager=session_manager,
static_dist_path=static_dist_path,
workspace_path=workspace,
default_restrict_to_workspace=False,
runtime_model_name=runtime_model_name,
runtime_surface="browser",
runtime_capabilities_overrides=None,
bus=bus,
)
@ -58,13 +59,13 @@ def _ch(
"websocketRequiresToken": False,
}
cfg.update(extra)
http_handler = _make_handler(
gateway = _make_handler(
cfg, bus,
session_manager=session_manager,
static_dist_path=static_dist_path,
runtime_model_name=runtime_model_name,
)
return WebSocketChannel(cfg, bus, http_handler=http_handler)
return WebSocketChannel(cfg, bus, gateway=gateway)
@pytest.fixture()
@ -729,20 +730,20 @@ async def test_api_token_pool_purges_expired(bus: MagicMock, tmp_path: Path) ->
channel = _ch(bus, session_manager=sm, port=29908)
# Don't start a server — directly inject and validate.
import time as _time
channel._http.api_tokens["expired"] = _time.monotonic() - 1
channel._http.api_tokens["live"] = _time.monotonic() + 60
channel.gateway.tokens.api_tokens["expired"] = _time.monotonic() - 1
channel.gateway.tokens.api_tokens["live"] = _time.monotonic() + 60
class _FakeReq:
path = "/api/sessions"
headers = {"Authorization": "Bearer expired"}
assert channel._http.check_api_token(_FakeReq()) is False
assert channel.gateway.tokens.check_api_token(_FakeReq()) is False
class _LiveReq:
path = "/api/sessions"
headers = {"Authorization": "Bearer live"}
assert channel._http.check_api_token(_LiveReq()) is True
assert channel.gateway.tokens.check_api_token(_LiveReq()) is True
class _FakeConn:
@ -797,7 +798,7 @@ def test_wildcard_ipv6_without_auth_raises(bus: MagicMock) -> None:
def test_wildcard_ipv6_with_secret_is_valid(bus: MagicMock) -> None:
channel = _ch(bus, host="::", tokenIssueSecret="s3cret")
resp = channel._handle_bootstrap(
resp = channel.gateway.http._handle_bootstrap(
_REMOTE, _FakeReq({"X-Nanobot-Auth": "s3cret"})
)
assert resp.status_code == 200
@ -806,7 +807,7 @@ def test_wildcard_ipv6_with_secret_is_valid(bus: MagicMock) -> None:
def test_bootstrap_accepts_static_token_as_secret(bus: MagicMock) -> None:
"""When only token (not token_issue_secret) is set, bootstrap accepts it."""
channel = _ch(bus, host="0.0.0.0", token="static-tok")
resp = channel._handle_bootstrap(
resp = channel.gateway.http._handle_bootstrap(
_REMOTE, _FakeReq({"Authorization": "Bearer static-tok"})
)
assert resp.status_code == 200
@ -816,7 +817,7 @@ def test_bootstrap_accepts_static_token_as_secret(bus: MagicMock) -> None:
def test_bootstrap_ws_url_uses_forwarded_https_host(bus: MagicMock) -> None:
channel = _ch(bus, host="127.0.0.1", port=29931)
resp = channel._handle_bootstrap(
resp = channel.gateway.http._handle_bootstrap(
_LOCAL,
_FakeReq({"Host": "nanobot.example", "X-Forwarded-Proto": "https"}),
)
@ -827,7 +828,7 @@ def test_bootstrap_ws_url_uses_forwarded_https_host(bus: MagicMock) -> None:
def test_localhost_without_auth_is_valid(bus: MagicMock) -> None:
channel = _ch(bus, host="127.0.0.1")
resp = channel._handle_bootstrap(_LOCAL, _NO_HEADERS)
resp = channel.gateway.http._handle_bootstrap(_LOCAL, _NO_HEADERS)
assert resp.status_code == 200
@ -837,7 +838,7 @@ def test_bootstrap_prefers_runtime_model_name(bus: MagicMock, monkeypatch: pytes
lambda: "from-disk",
)
channel = _ch(bus, host="127.0.0.1", runtime_model_name=lambda: " live/model ")
resp = channel._handle_bootstrap(_LOCAL, _NO_HEADERS)
resp = channel.gateway.http._handle_bootstrap(_LOCAL, _NO_HEADERS)
assert resp.status_code == 200
body = json.loads(resp.body)
assert body["model_name"] == "live/model"
@ -849,7 +850,7 @@ def test_bootstrap_falls_back_when_runtime_returns_empty(bus: MagicMock, monkeyp
lambda: "from-disk",
)
channel = _ch(bus, host="127.0.0.1", runtime_model_name=lambda: " ")
resp = channel._handle_bootstrap(_LOCAL, _NO_HEADERS)
resp = channel.gateway.http._handle_bootstrap(_LOCAL, _NO_HEADERS)
assert resp.status_code == 200
body = json.loads(resp.body)
assert body["model_name"] == "from-disk"
@ -865,7 +866,7 @@ def test_bootstrap_falls_back_when_runtime_raises(bus: MagicMock, monkeypatch: p
raise RuntimeError("resolver failed")
channel = _ch(bus, host="127.0.0.1", runtime_model_name=boom)
resp = channel._handle_bootstrap(_LOCAL, _NO_HEADERS)
resp = channel.gateway.http._handle_bootstrap(_LOCAL, _NO_HEADERS)
assert resp.status_code == 200
body = json.loads(resp.body)
assert body["model_name"] == "from-disk"
@ -873,7 +874,7 @@ def test_bootstrap_falls_back_when_runtime_raises(bus: MagicMock, monkeypatch: p
def test_bootstrap_rejects_wrong_secret(bus: MagicMock) -> None:
channel = _ch(bus, host="0.0.0.0", tokenIssueSecret="correct")
resp = channel._handle_bootstrap(
resp = channel.gateway.http._handle_bootstrap(
_REMOTE, _FakeReq({"Authorization": "Bearer wrong"})
)
assert resp.status_code == 401
@ -881,7 +882,7 @@ def test_bootstrap_rejects_wrong_secret(bus: MagicMock) -> None:
def test_bootstrap_accepts_remote_with_valid_secret(bus: MagicMock) -> None:
channel = _ch(bus, host="0.0.0.0", tokenIssueSecret="s3cret")
resp = channel._handle_bootstrap(
resp = channel.gateway.http._handle_bootstrap(
_REMOTE, _FakeReq({"Authorization": "Bearer s3cret"})
)
assert resp.status_code == 200
@ -891,7 +892,7 @@ def test_bootstrap_accepts_remote_with_valid_secret(bus: MagicMock) -> None:
def test_bootstrap_accepts_x_nanobot_auth_header(bus: MagicMock) -> None:
channel = _ch(bus, host="0.0.0.0", tokenIssueSecret="s3cret")
resp = channel._handle_bootstrap(
resp = channel.gateway.http._handle_bootstrap(
_REMOTE, _FakeReq({"X-Nanobot-Auth": "s3cret"})
)
assert resp.status_code == 200
@ -900,5 +901,5 @@ def test_bootstrap_accepts_x_nanobot_auth_header(bus: MagicMock) -> None:
def test_bootstrap_secret_also_enforced_on_localhost(bus: MagicMock) -> None:
"""When secret is set, even localhost must provide it (reverse-proxy safety)."""
channel = _ch(bus, host="0.0.0.0", tokenIssueSecret="s3cret")
resp = channel._handle_bootstrap(_LOCAL, _NO_HEADERS)
resp = channel.gateway.http._handle_bootstrap(_LOCAL, _NO_HEADERS)
assert resp.status_code == 401

View File

@ -7,19 +7,18 @@ multi-client scenarios, edge cases, and realistic usage patterns.
from __future__ import annotations
import asyncio
import json
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
import websockets
from nanobot.channels.websocket import WebSocketChannel, WebSocketConfig
from nanobot.bus.events import OutboundMessage
from nanobot.webui.ws_http import GatewayHTTPHandler
from ws_test_client import WsTestClient, issue_token, issue_token_ok
from nanobot.bus.events import OutboundMessage
from nanobot.channels.websocket import WebSocketChannel, WebSocketConfig
from nanobot.webui.gateway_services import build_gateway_services
def _ch(bus: Any, port: int, **kw: Any) -> WebSocketChannel:
cfg: dict[str, Any] = {
@ -32,17 +31,18 @@ def _ch(bus: Any, port: int, **kw: Any) -> WebSocketChannel:
}
cfg.update(kw)
parsed = WebSocketConfig.model_validate(cfg)
handler = GatewayHTTPHandler(
gateway = build_gateway_services(
config=parsed,
bus=bus,
session_manager=None,
static_dist_path=None,
workspace_path=Path.cwd(),
default_restrict_to_workspace=False,
runtime_model_name=None,
runtime_surface="browser",
runtime_capabilities_overrides=None,
bus=bus,
)
return WebSocketChannel(cfg, bus, http_handler=handler)
return WebSocketChannel(cfg, bus, gateway=gateway)
@pytest.fixture()
@ -67,7 +67,8 @@ async def test_ready_event_fields(bus: MagicMock) -> None:
assert len(r.chat_id) == 36
assert r.client_id == "c1"
finally:
await ch.stop(); await t
await ch.stop()
await t
@pytest.mark.asyncio
@ -80,7 +81,8 @@ async def test_anonymous_client_gets_generated_id(bus: MagicMock) -> None:
r = await c.recv_ready()
assert r.client_id.startswith("anon-")
finally:
await ch.stop(); await t
await ch.stop()
await t
@pytest.mark.asyncio
@ -93,7 +95,8 @@ async def test_each_connection_unique_chat_id(bus: MagicMock) -> None:
async with WsTestClient("ws://127.0.0.1:29903/", client_id="b") as c2:
assert (await c1.recv_ready()).chat_id != (await c2.recv_ready()).chat_id
finally:
await ch.stop(); await t
await ch.stop()
await t
# -- Inbound messages (client -> server) ----------------------------------
@ -113,7 +116,8 @@ async def test_plain_text(bus: MagicMock) -> None:
assert inbound.content == "hello world"
assert inbound.sender_id == "p"
finally:
await ch.stop(); await t
await ch.stop()
await t
@pytest.mark.asyncio
@ -128,7 +132,8 @@ async def test_json_content_field(bus: MagicMock) -> None:
await asyncio.sleep(0.1)
assert bus.publish_inbound.call_args[0][0].content == "structured"
finally:
await ch.stop(); await t
await ch.stop()
await t
@pytest.mark.asyncio
@ -146,7 +151,8 @@ async def test_json_text_and_message_fields(bus: MagicMock) -> None:
await asyncio.sleep(0.1)
assert bus.publish_inbound.call_args[0][0].content == "via message"
finally:
await ch.stop(); await t
await ch.stop()
await t
@pytest.mark.asyncio
@ -162,7 +168,8 @@ async def test_empty_payload_ignored(bus: MagicMock) -> None:
await asyncio.sleep(0.1)
bus.publish_inbound.assert_not_awaited()
finally:
await ch.stop(); await t
await ch.stop()
await t
@pytest.mark.asyncio
@ -179,7 +186,8 @@ async def test_messages_preserve_order(bus: MagicMock) -> None:
contents = [call[0][0].content for call in bus.publish_inbound.call_args_list]
assert contents == [f"msg-{i}" for i in range(5)]
finally:
await ch.stop(); await t
await ch.stop()
await t
# -- Outbound messages (server -> client) ---------------------------------
@ -199,7 +207,8 @@ async def test_server_send_message(bus: MagicMock) -> None:
msg = await c.recv_message()
assert msg.text == "reply"
finally:
await ch.stop(); await t
await ch.stop()
await t
@pytest.mark.asyncio
@ -238,7 +247,8 @@ async def test_server_send_tags_tool_hint_with_kind(bus: MagicMock) -> None:
prog = await c.recv_message()
assert prog.raw.get("kind") == "progress"
finally:
await ch.stop(); await t
await ch.stop()
await t
@pytest.mark.asyncio
@ -258,7 +268,8 @@ async def test_server_send_with_media_and_reply(bus: MagicMock) -> None:
assert msg.media == ["/tmp/a.png"]
assert msg.reply_to == "m1"
finally:
await ch.stop(); await t
await ch.stop()
await t
# -- Streaming ------------------------------------------------------------
@ -282,7 +293,8 @@ async def test_streaming_deltas_and_end(bus: MagicMock) -> None:
ends = [m for m in msgs if m.event == "stream_end"]
assert len(ends) == 1
finally:
await ch.stop(); await t
await ch.stop()
await t
@pytest.mark.asyncio
@ -306,7 +318,8 @@ async def test_interleaved_streams(bus: MagicMock) -> None:
assert sa == "A1A2"
assert sb == "B1B2"
finally:
await ch.stop(); await t
await ch.stop()
await t
# -- Multi-client ---------------------------------------------------------
@ -330,7 +343,8 @@ async def test_independent_sessions(bus: MagicMock) -> None:
))
assert (await c2.recv_message()).text == "for-u2"
finally:
await ch.stop(); await t
await ch.stop()
await t
@pytest.mark.asyncio
@ -348,7 +362,8 @@ async def test_disconnected_client_cleanup(bus: MagicMock) -> None:
))
assert chat_id not in ch._subs
finally:
await ch.stop(); await t
await ch.stop()
await t
# -- Authentication -------------------------------------------------------
@ -363,7 +378,8 @@ async def test_static_token_accepted(bus: MagicMock) -> None:
async with WsTestClient("ws://127.0.0.1:29915/", client_id="a", token="secret") as c:
assert (await c.recv_ready()).client_id == "a"
finally:
await ch.stop(); await t
await ch.stop()
await t
@pytest.mark.asyncio
@ -377,7 +393,8 @@ async def test_static_token_rejected(bus: MagicMock) -> None:
pass
assert exc.value.response.status_code == 401
finally:
await ch.stop(); await t
await ch.stop()
await t
@pytest.mark.asyncio
@ -411,7 +428,8 @@ async def test_token_issue_full_flow(bus: MagicMock) -> None:
pass
assert exc.value.response.status_code == 401
finally:
await ch.stop(); await t
await ch.stop()
await t
# -- Path routing ---------------------------------------------------------
@ -426,7 +444,8 @@ async def test_custom_path(bus: MagicMock) -> None:
async with WsTestClient("ws://127.0.0.1:29918/my-chat", client_id="p") as c:
assert (await c.recv_ready()).event == "ready"
finally:
await ch.stop(); await t
await ch.stop()
await t
@pytest.mark.asyncio
@ -440,7 +459,8 @@ async def test_wrong_path_404(bus: MagicMock) -> None:
pass
assert exc.value.response.status_code == 404
finally:
await ch.stop(); await t
await ch.stop()
await t
@pytest.mark.asyncio
@ -452,7 +472,8 @@ async def test_trailing_slash_normalized(bus: MagicMock) -> None:
async with WsTestClient("ws://127.0.0.1:29920/ws/", client_id="s") as c:
assert (await c.recv_ready()).event == "ready"
finally:
await ch.stop(); await t
await ch.stop()
await t
# -- Edge cases -----------------------------------------------------------
@ -471,7 +492,8 @@ async def test_large_message(bus: MagicMock) -> None:
await asyncio.sleep(0.2)
assert bus.publish_inbound.call_args[0][0].content == big
finally:
await ch.stop(); await t
await ch.stop()
await t
@pytest.mark.asyncio
@ -491,7 +513,8 @@ async def test_unicode_roundtrip(bus: MagicMock) -> None:
))
assert (await c.recv_message()).text == text
finally:
await ch.stop(); await t
await ch.stop()
await t
@pytest.mark.asyncio
@ -513,7 +536,8 @@ async def test_rapid_fire(bus: MagicMock) -> None:
received = [(await c.recv_message()).text for _ in range(50)]
assert received == [f"out-{i}" for i in range(50)]
finally:
await ch.stop(); await t
await ch.stop()
await t
@pytest.mark.asyncio
@ -528,4 +552,5 @@ async def test_invalid_json_as_plain_text(bus: MagicMock) -> None:
await asyncio.sleep(0.1)
assert bus.publish_inbound.call_args[0][0].content == "{broken json"
finally:
await ch.stop(); await t
await ch.stop()
await t

View File

@ -2,8 +2,8 @@
integration on ``/api/sessions/<key>/messages``.
The route is the return path for images attached to persisted user turns:
:meth:`WebSocketChannel._sign_media_path` mints URLs during session reads,
and :meth:`WebSocketChannel._handle_media_fetch` serves the bytes back.
:meth:`WebSocketChannel.gateway.media.sign_media_path` mints URLs during session reads,
and :meth:`GatewayHTTPHandler._handle_media_fetch` serves the bytes back.
These tests cover the two halves end-to-end plus the adversarial edges
(bad signatures, ``..`` traversal, non-existent files, non-image types).
"""
@ -22,13 +22,12 @@ import httpx
import pytest
from nanobot.channels.websocket import WebSocketChannel, WebSocketConfig
from nanobot.session.manager import Session, SessionManager
from nanobot.webui.gateway_services import build_gateway_services
from nanobot.webui.media_api import (
b64url_decode,
b64url_encode,
)
from nanobot.session.manager import Session, SessionManager
from nanobot.webui.ws_http import GatewayHTTPHandler
# PNG magic bytes + a couple of sentinel bytes so we can verify byte-for-byte
# round-trip of the served payload. Stays under mimetype + size limits.
@ -57,17 +56,18 @@ def _ch(
"websocketRequiresToken": False,
}
parsed = WebSocketConfig.model_validate(cfg)
http_handler = GatewayHTTPHandler(
gateway = build_gateway_services(
config=parsed,
bus=bus,
session_manager=session_manager,
static_dist_path=None,
workspace_path=workspace_path or Path.cwd(),
default_restrict_to_workspace=False,
runtime_model_name=None,
runtime_surface="browser",
runtime_capabilities_overrides=None,
bus=bus,
)
return WebSocketChannel(cfg, bus, http_handler=http_handler)
return WebSocketChannel(cfg, bus, gateway=gateway)
@pytest.fixture()
@ -95,7 +95,7 @@ async def _http_get(
# ---------------------------------------------------------------------------
# _sign_media_path: the URL minter
# gateway.media.sign_media_path: the URL minter
# ---------------------------------------------------------------------------
@ -114,11 +114,11 @@ def test_sign_media_path_rejects_paths_outside_media_root(
media = tmp_path / "media"
media.mkdir()
channel = _ch(bus, port=0)
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
assert channel._sign_media_path(outside) is None
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
assert channel.gateway.media.sign_media_path(outside) is None
# Traversal via the media root is also rejected — the resolve() step
# normalises ``..`` out before the relative_to check.
assert channel._sign_media_path(media / ".." / "secrets" / "cred.txt") is None
assert channel.gateway.media.sign_media_path(media / ".." / "secrets" / "cred.txt") is None
def test_sign_media_path_round_trips_via_hmac(
@ -129,13 +129,13 @@ def test_sign_media_path_round_trips_via_hmac(
media.mkdir()
(media / "a.png").write_bytes(_PNG_BYTES)
channel = _ch(bus, port=0)
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
url = channel._sign_media_path(media / "a.png")
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
url = channel.gateway.media.sign_media_path(media / "a.png")
assert url is not None
assert url.startswith("/api/media/")
sig, payload = url[len("/api/media/"):].split("/", 1)
expected = hmac.new(
channel._media_secret, payload.encode("ascii"), hashlib.sha256
channel.gateway.media.secret, payload.encode("ascii"), hashlib.sha256
).digest()[:16]
assert b64url_decode(sig) == expected
# The payload decodes back to the *relative* path — no absolute-path leaks.
@ -152,8 +152,8 @@ def test_local_markdown_image_is_staged_and_rewritten(
media = tmp_path / "media"
channel = _ch(bus, workspace_path=workspace, port=0)
with patch("nanobot.webui.ws_http.get_media_dir", side_effect=_fake_media_dir(media)):
rewritten = channel._rewrite_local_markdown_images(
with patch("nanobot.webui.media_gateway.get_media_dir", side_effect=_fake_media_dir(media)):
rewritten = channel.gateway.media.rewrite_local_markdown_images(
"The result:\n![Cloud Architecture Diagram](demo_arch.png)"
)
@ -174,8 +174,8 @@ def test_local_markdown_video_is_staged_and_rewritten(
media = tmp_path / "media"
channel = _ch(bus, workspace_path=workspace, port=0)
with patch("nanobot.webui.ws_http.get_media_dir", side_effect=_fake_media_dir(media)):
rewritten = channel._rewrite_local_markdown_images(
with patch("nanobot.webui.media_gateway.get_media_dir", side_effect=_fake_media_dir(media)):
rewritten = channel.gateway.media.rewrite_local_markdown_images(
"The result:\n![nanobot-intro.mp4](nanobot-intro.mp4)"
)
@ -197,8 +197,8 @@ def test_local_markdown_image_rejects_workspace_escape(
channel = _ch(bus, workspace_path=workspace, port=0)
text = "![nope](../outside.png)"
with patch("nanobot.webui.ws_http.get_media_dir", side_effect=_fake_media_dir(media)):
assert channel._rewrite_local_markdown_images(text) == text
with patch("nanobot.webui.media_gateway.get_media_dir", side_effect=_fake_media_dir(media)):
assert channel.gateway.media.rewrite_local_markdown_images(text) == text
assert not (media / "websocket").exists()
@ -219,8 +219,8 @@ async def test_media_route_serves_signed_file(
target.write_bytes(_PNG_BYTES)
channel = _ch(bus, port=29920)
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
url_path = channel._sign_media_path(target)
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
url_path = channel.gateway.media.sign_media_path(target)
assert url_path is not None
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
@ -252,8 +252,8 @@ async def test_media_route_serves_video_byte_ranges(
target.write_bytes(b"0123456789")
channel = _ch(bus, port=29927)
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
url_path = channel._sign_media_path(target)
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
url_path = channel.gateway.media.sign_media_path(target)
assert url_path is not None
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
@ -284,8 +284,8 @@ async def test_media_route_serves_suffix_video_byte_ranges(
target.write_bytes(b"0123456789")
channel = _ch(bus, port=29928)
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
url_path = channel._sign_media_path(target)
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
url_path = channel.gateway.media.sign_media_path(target)
assert url_path is not None
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
@ -313,8 +313,8 @@ async def test_media_route_rejects_unsatisfiable_byte_range(
target.write_bytes(b"0123456789")
channel = _ch(bus, port=29929)
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
url_path = channel._sign_media_path(target)
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
url_path = channel.gateway.media.sign_media_path(target)
assert url_path is not None
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
@ -339,15 +339,15 @@ async def test_media_route_rejects_bad_signature(
"""A payload re-signed with a different secret must 401.
Protects against a restart: old URLs baked into a stale tab become
un-forgeable once ``_media_secret`` regenerates.
un-forgeable once ``gateway.media.secret`` regenerates.
"""
media = tmp_path / "media"
media.mkdir()
(media / "f.png").write_bytes(_PNG_BYTES)
channel = _ch(bus, port=29921)
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
good = channel._sign_media_path(media / "f.png")
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
good = channel.gateway.media.sign_media_path(media / "f.png")
assert good is not None
_, payload = good[len("/api/media/"):].split("/", 1)
# Forge a sig with a *different* secret.
@ -385,11 +385,11 @@ async def test_media_route_rejects_path_traversal_payload(
# Hand-craft a traversal payload the legit signer would refuse to mint.
payload = b64url_encode(b"../secret.txt")
mac = hmac.new(
channel._media_secret, payload.encode("ascii"), hashlib.sha256
channel.gateway.media.secret, payload.encode("ascii"), hashlib.sha256
).digest()[:16]
url = f"/api/media/{b64url_encode(mac)}/{payload}"
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
@ -413,8 +413,8 @@ async def test_media_route_404s_missing_file(
target.write_bytes(_PNG_BYTES)
channel = _ch(bus, port=29923)
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
url_path = channel._sign_media_path(target)
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
url_path = channel.gateway.media.sign_media_path(target)
assert url_path is not None
target.unlink() # the file vanishes between signing and fetching
server_task = asyncio.create_task(channel.start())
@ -441,10 +441,10 @@ async def test_media_route_degrades_non_image_to_octet_stream(
(media / "scary.html").write_bytes(b"<script>alert(1)</script>")
channel = _ch(bus, port=29924)
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
payload = b64url_encode(b"scary.html")
mac = hmac.new(
channel._media_secret, payload.encode("ascii"), hashlib.sha256
channel.gateway.media.secret, payload.encode("ascii"), hashlib.sha256
).digest()[:16]
url = f"/api/media/{b64url_encode(mac)}/{payload}"
server_task = asyncio.create_task(channel.start())
@ -472,8 +472,8 @@ async def test_media_route_serves_svg_with_strict_csp(
target.write_text("<svg xmlns='http://www.w3.org/2000/svg'><script>alert(1)</script></svg>")
channel = _ch(bus, port=29928)
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
url_path = channel._sign_media_path(target)
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
url_path = channel.gateway.media.sign_media_path(target)
assert url_path is not None
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
@ -513,7 +513,7 @@ async def test_session_messages_exposes_signed_media_urls(
sm.save(sess)
channel = _ch(bus, session_manager=sm, port=29925)
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
@ -558,7 +558,7 @@ async def test_session_messages_skips_vanished_media(
sm.save(sess)
channel = _ch(bus, session_manager=sm, port=29926)
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try: