mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 22:34:06 +00:00
refactor: split WebUI gateway dependencies
Maintainer edit for PR 4115: rebase onto origin/main and split gateway HTTP routing from token, media, and workspace services so WebSocketChannel depends on explicit gateway services instead of GatewayHTTPHandler internals. Preserve file edit channel capabilities and restore tools.restrict_to_workspace wiring through ChannelManager.
This commit is contained in:
parent
2420826e05
commit
2a98360105
@ -16,7 +16,6 @@ from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.utils.restart import consume_restart_notice_from_env, format_restart_completed_message
|
||||
from nanobot.webui.ws_http import GatewayHTTPHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.session.manager import SessionManager
|
||||
@ -113,21 +112,24 @@ class ChannelManager:
|
||||
kwargs: dict[str, Any] = {}
|
||||
if cls.name == "websocket":
|
||||
from nanobot.channels.websocket import WebSocketConfig
|
||||
from nanobot.webui.gateway_services import build_gateway_services
|
||||
|
||||
parsed = WebSocketConfig.model_validate(section)
|
||||
static_path = _default_webui_dist() if self._webui_static_dist else None
|
||||
workspace = Path(self.config.workspace_path)
|
||||
http_handler = GatewayHTTPHandler(
|
||||
gateway = build_gateway_services(
|
||||
config=parsed,
|
||||
bus=self.bus,
|
||||
session_manager=self._session_manager,
|
||||
static_dist_path=static_path,
|
||||
workspace_path=workspace,
|
||||
default_restrict_to_workspace=self.config.tools.restrict_to_workspace,
|
||||
runtime_model_name=self._webui_runtime_model_name,
|
||||
runtime_surface=self._webui_runtime_surface,
|
||||
runtime_capabilities_overrides=self._webui_runtime_capabilities,
|
||||
bus=self.bus,
|
||||
logger=logger,
|
||||
)
|
||||
kwargs["http_handler"] = http_handler
|
||||
kwargs["gateway"] = gateway
|
||||
channel = cls(section, self.bus, **kwargs)
|
||||
channel.transcription_provider = transcription_provider
|
||||
channel.transcription_api_key = transcription_key
|
||||
|
||||
@ -3,27 +3,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import email.utils
|
||||
import hmac
|
||||
import http
|
||||
import json
|
||||
import re
|
||||
import ssl
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from contextlib import suppress
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Self
|
||||
from urllib.parse import parse_qs, unquote, urlparse
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import Field, field_validator, model_validator
|
||||
from websockets.asyncio.server import ServerConnection, serve, unix_serve
|
||||
from websockets.datastructures import Headers
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
from websockets.http11 import Request as WsRequest
|
||||
from websockets.http11 import Response
|
||||
|
||||
from nanobot.bus.events import OUTBOUND_META_AGENT_UI, OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
@ -41,53 +34,22 @@ from nanobot.utils.media_decode import (
|
||||
save_base64_data_url,
|
||||
)
|
||||
from nanobot.webui.cli_apps_api import normalize_cli_app_mentions
|
||||
from nanobot.webui.gateway_services import GatewayServices
|
||||
from nanobot.webui.http_utils import (
|
||||
is_localhost as _is_localhost,
|
||||
)
|
||||
from nanobot.webui.http_utils import (
|
||||
normalize_config_path as _normalize_config_path,
|
||||
)
|
||||
from nanobot.webui.http_utils import (
|
||||
parse_request_path as _parse_request_path,
|
||||
)
|
||||
from nanobot.webui.http_utils import (
|
||||
query_first as _query_first,
|
||||
)
|
||||
from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions
|
||||
from nanobot.webui.transcript import append_transcript_object
|
||||
|
||||
|
||||
def _strip_trailing_slash(path: str) -> str:
|
||||
if len(path) > 1 and path.endswith("/"):
|
||||
return path.rstrip("/")
|
||||
return path or "/"
|
||||
|
||||
|
||||
def _normalize_config_path(path: str) -> str:
|
||||
return _strip_trailing_slash(path)
|
||||
|
||||
|
||||
def _case_insensitive_header(headers: Any, key: str) -> str:
|
||||
"""Read a header from websockets/http test stubs without assuming casing."""
|
||||
try:
|
||||
value = headers.get(key)
|
||||
except Exception:
|
||||
value = None
|
||||
if value is None:
|
||||
try:
|
||||
value = headers.get(key.lower())
|
||||
except Exception:
|
||||
value = None
|
||||
return str(value or "").strip()
|
||||
|
||||
|
||||
def _safe_host_header(value: str) -> str:
|
||||
"""Return a safe Host header value, or empty when it should not be echoed."""
|
||||
value = value.strip()
|
||||
if not value:
|
||||
return ""
|
||||
if re.fullmatch(r"\[[0-9A-Fa-f:.]+\](?::\d{1,5})?", value):
|
||||
return value
|
||||
if re.fullmatch(r"[A-Za-z0-9.-]+(?::\d{1,5})?", value):
|
||||
return value
|
||||
return ""
|
||||
|
||||
|
||||
def _host_for_url(host: str, port: int) -> str:
|
||||
host = host.strip()
|
||||
if host in ("0.0.0.0", "::"):
|
||||
host = "127.0.0.1"
|
||||
if ":" in host and not host.startswith("["):
|
||||
host = f"[{host}]"
|
||||
return f"{host}:{port}"
|
||||
from nanobot.webui.websocket_logging import websockets_server_logger
|
||||
|
||||
|
||||
class WebSocketConfig(Base):
|
||||
@ -182,20 +144,6 @@ class WebSocketConfig(Base):
|
||||
)
|
||||
|
||||
|
||||
def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
|
||||
body = json.dumps(data, ensure_ascii=False).encode("utf-8")
|
||||
headers = Headers(
|
||||
[
|
||||
("Date", email.utils.formatdate(usegmt=True)),
|
||||
("Connection", "close"),
|
||||
("Content-Length", str(len(body))),
|
||||
("Content-Type", "application/json; charset=utf-8"),
|
||||
]
|
||||
)
|
||||
reason = http.HTTPStatus(status).phrase
|
||||
return Response(status, reason, headers, body)
|
||||
|
||||
|
||||
def publish_runtime_model_update(
|
||||
bus: MessageBus,
|
||||
model: str,
|
||||
@ -214,57 +162,6 @@ def publish_runtime_model_update(
|
||||
))
|
||||
|
||||
|
||||
def _default_model_name_from_config() -> str | None:
|
||||
"""Resolved model string from on-disk config (bootstrap fallback)."""
|
||||
try:
|
||||
from nanobot.config.loader import load_config
|
||||
|
||||
model = load_config().resolve_preset().model.strip()
|
||||
return model or None
|
||||
except Exception as e:
|
||||
logger.debug("bootstrap model_name could not load from config: {}", e)
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_bootstrap_model_name(
|
||||
runtime_name: Callable[[], str | None] | None,
|
||||
) -> str | None:
|
||||
"""Prefer an in-process resolver (e.g. AgentLoop); else config-derived default."""
|
||||
if runtime_name is not None:
|
||||
try:
|
||||
raw = runtime_name()
|
||||
except Exception as e:
|
||||
logger.debug("bootstrap runtime model resolver failed: {}", e)
|
||||
else:
|
||||
if isinstance(raw, str):
|
||||
stripped = raw.strip()
|
||||
if stripped:
|
||||
return stripped
|
||||
return _default_model_name_from_config()
|
||||
|
||||
|
||||
def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]:
|
||||
"""Parse normalized path and query parameters in one pass."""
|
||||
parsed = urlparse("ws://x" + path_with_query)
|
||||
path = _strip_trailing_slash(parsed.path or "/")
|
||||
return path, parse_qs(parsed.query, keep_blank_values=True)
|
||||
|
||||
|
||||
def _normalize_http_path(path_with_query: str) -> str:
|
||||
"""Return the path component (no query string), with trailing slash normalized (root stays ``/``)."""
|
||||
return _parse_request_path(path_with_query)[0]
|
||||
|
||||
|
||||
def _parse_query(path_with_query: str) -> dict[str, list[str]]:
|
||||
return _parse_request_path(path_with_query)[1]
|
||||
|
||||
|
||||
def _query_first(query: dict[str, list[str]], key: str) -> str | None:
|
||||
"""Return the first value for *key*, or None."""
|
||||
values = query.get(key)
|
||||
return values[0] if values else None
|
||||
|
||||
|
||||
def _parse_inbound_payload(raw: str) -> str | None:
|
||||
"""Parse a client frame into text; return None for empty or unrecognized content."""
|
||||
text = raw.strip()
|
||||
@ -355,67 +252,6 @@ def _extract_data_url_mime(url: str) -> str | None:
|
||||
return m.group(1).strip().lower() or None
|
||||
|
||||
|
||||
_LOCALHOSTS = frozenset({"127.0.0.1", "::1", "localhost"})
|
||||
|
||||
# Matches the legacy chat-id pattern but allows file-system-safe stems too,
|
||||
# so the API can address sessions whose keys came from non-WebSocket channels.
|
||||
_API_KEY_RE = re.compile(r"^[A-Za-z0-9_:.-]{1,128}$")
|
||||
|
||||
|
||||
def _decode_api_key(raw_key: str) -> str | None:
|
||||
"""Decode a percent-encoded API path segment, then validate the result."""
|
||||
key = unquote(raw_key)
|
||||
if _API_KEY_RE.match(key) is None:
|
||||
return None
|
||||
return key
|
||||
|
||||
|
||||
def _is_localhost(connection: Any) -> bool:
|
||||
"""Return True if *connection* originated from the loopback interface."""
|
||||
addr = getattr(connection, "remote_address", None)
|
||||
if not addr:
|
||||
return False
|
||||
host = addr[0] if isinstance(addr, tuple) else addr
|
||||
if not isinstance(host, str):
|
||||
return False
|
||||
# ``::ffff:127.0.0.1`` is loopback in IPv6-mapped form.
|
||||
if host.startswith("::ffff:"):
|
||||
host = host[7:]
|
||||
return host in _LOCALHOSTS
|
||||
|
||||
|
||||
def _http_response(
|
||||
body: bytes,
|
||||
*,
|
||||
status: int = 200,
|
||||
content_type: str = "text/plain; charset=utf-8",
|
||||
extra_headers: list[tuple[str, str]] | None = None,
|
||||
) -> Response:
|
||||
headers = [
|
||||
("Date", email.utils.formatdate(usegmt=True)),
|
||||
("Connection", "close"),
|
||||
("Content-Length", str(len(body))),
|
||||
("Content-Type", content_type),
|
||||
]
|
||||
if extra_headers:
|
||||
headers.extend(extra_headers)
|
||||
reason = http.HTTPStatus(status).phrase
|
||||
return Response(status, reason, Headers(headers), body)
|
||||
|
||||
|
||||
def _http_error(status: int, message: str | None = None) -> Response:
|
||||
body = (message or http.HTTPStatus(status).phrase).encode("utf-8")
|
||||
return _http_response(body, status=status)
|
||||
|
||||
|
||||
def _bearer_token(headers: Any) -> str | None:
|
||||
"""Pull a Bearer token out of standard or query-style headers."""
|
||||
auth = headers.get("Authorization") or headers.get("authorization")
|
||||
if auth and auth.lower().startswith("bearer "):
|
||||
return auth[7:].strip() or None
|
||||
return None
|
||||
|
||||
|
||||
def _is_websocket_upgrade(request: WsRequest) -> bool:
|
||||
"""Detect an actual WS upgrade; plain HTTP GETs to the same path should fall through."""
|
||||
upgrade = request.headers.get("Upgrade") or request.headers.get("upgrade")
|
||||
@ -427,20 +263,6 @@ def _is_websocket_upgrade(request: WsRequest) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool:
|
||||
"""Return True if the token-issue HTTP request carries credentials matching ``token_issue_secret``."""
|
||||
if not configured_secret:
|
||||
return True
|
||||
authorization = headers.get("Authorization") or headers.get("authorization")
|
||||
if authorization and authorization.lower().startswith("bearer "):
|
||||
supplied = authorization[7:].strip()
|
||||
return hmac.compare_digest(supplied, configured_secret)
|
||||
header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth")
|
||||
if not header_token:
|
||||
return False
|
||||
return hmac.compare_digest(header_token.strip(), configured_secret)
|
||||
|
||||
|
||||
class WebSocketChannel(BaseChannel):
|
||||
"""Run a local WebSocket server; forward text/JSON messages to the message bus."""
|
||||
|
||||
@ -452,7 +274,7 @@ class WebSocketChannel(BaseChannel):
|
||||
config: Any,
|
||||
bus: MessageBus,
|
||||
*,
|
||||
http_handler: Any | None = None,
|
||||
gateway: GatewayServices,
|
||||
):
|
||||
if isinstance(config, dict):
|
||||
config = WebSocketConfig.model_validate(config)
|
||||
@ -467,11 +289,11 @@ class WebSocketChannel(BaseChannel):
|
||||
self._stop_event: asyncio.Event | None = None
|
||||
self._server_task: asyncio.Task[None] | None = None
|
||||
|
||||
# HTTP handler injected from outside (ChannelManager / gateway startup).
|
||||
# Owns tokens, sessions, media, settings, static serving.
|
||||
self._http = http_handler
|
||||
# Backwards-compat: workspace controller used in envelope dispatch
|
||||
self._webui_workspaces = http_handler.workspaces if http_handler else None
|
||||
self.gateway = gateway
|
||||
self._http_router = gateway.http
|
||||
self._tokens = gateway.tokens
|
||||
self._media = gateway.media
|
||||
self._workspaces = gateway.workspaces
|
||||
|
||||
self._stream_text_buffers: dict[tuple[str, str], list[str]] = {}
|
||||
|
||||
@ -501,9 +323,9 @@ class WebSocketChannel(BaseChannel):
|
||||
connected clients normally see it via ``goal_state`` / ``turn_end`` frames.
|
||||
Pushing here makes refresh + reconnect restore the strip without a new model turn.
|
||||
"""
|
||||
if self._http.session_manager is None:
|
||||
if self.gateway.session_manager is None:
|
||||
return
|
||||
row = self._http.session_manager.read_session_file(f"websocket:{chat_id}")
|
||||
row = self.gateway.session_manager.read_session_file(f"websocket:{chat_id}")
|
||||
meta = row.get("metadata", {}) if isinstance(row, dict) else {}
|
||||
if not isinstance(meta, dict):
|
||||
meta = {}
|
||||
@ -543,57 +365,6 @@ class WebSocketChannel(BaseChannel):
|
||||
def _expected_path(self) -> str:
|
||||
return _normalize_config_path(self.config.path)
|
||||
|
||||
# -- Backwards-compat property aliases (used by tests) ------------------
|
||||
|
||||
@property
|
||||
def _session_manager(self):
|
||||
return self._http.session_manager
|
||||
|
||||
@_session_manager.setter
|
||||
def _session_manager(self, value):
|
||||
self._http.session_manager = value
|
||||
|
||||
@property
|
||||
def _media_secret(self):
|
||||
return self._http.media_secret
|
||||
|
||||
@property
|
||||
def _issued_tokens(self):
|
||||
return self._http.issued_tokens
|
||||
|
||||
@_issued_tokens.setter
|
||||
def _issued_tokens(self, value):
|
||||
self._http.issued_tokens = value
|
||||
|
||||
@property
|
||||
def _api_tokens(self):
|
||||
return self._http.api_tokens
|
||||
|
||||
def _check_api_token(self, request):
|
||||
return self._http.check_api_token(request)
|
||||
|
||||
def _sign_media_path(self, path):
|
||||
return self._http.sign_media_path(path)
|
||||
|
||||
def _sign_or_stage_media_path(self, path):
|
||||
return self._http.sign_or_stage_media_path(path)
|
||||
|
||||
def _rewrite_local_markdown_images(self, text):
|
||||
return self._http.rewrite_local_markdown_images(text)
|
||||
|
||||
def _handle_bootstrap(self, connection, request):
|
||||
return self._http._handle_bootstrap(connection, request)
|
||||
|
||||
def _handle_sessions_list(self, request):
|
||||
return self._http._handle_sessions_list(request)
|
||||
|
||||
def _handle_webui_thread_get(self, request, key):
|
||||
return self._http._handle_webui_thread_get(request, key)
|
||||
|
||||
@property
|
||||
def _workspace_path(self):
|
||||
return self._http.workspace_path
|
||||
|
||||
def _build_ssl_context(self) -> ssl.SSLContext | None:
|
||||
cert = self.config.ssl_certfile.strip()
|
||||
key = self.config.ssl_keyfile.strip()
|
||||
@ -624,8 +395,8 @@ class WebSocketChannel(BaseChannel):
|
||||
return connection.respond(403, "Forbidden")
|
||||
return self._authorize_websocket_handshake(connection, query)
|
||||
|
||||
<< # Everything else goes to the HTTP handler
|
||||
return await self._http.dispatch(connection, request)
|
||||
# Everything else goes to the HTTP handler
|
||||
return await self._http_router.dispatch(connection, request)
|
||||
|
||||
def _authorize_websocket_handshake(self, connection: Any, query: dict[str, list[str]]) -> Any:
|
||||
supplied = _query_first(query, "token")
|
||||
@ -634,17 +405,17 @@ class WebSocketChannel(BaseChannel):
|
||||
if static_token:
|
||||
if supplied and hmac.compare_digest(supplied, static_token):
|
||||
return None
|
||||
if supplied and self._http.take_issued_token_if_valid(supplied):
|
||||
if supplied and self._tokens.take_issued_token_if_valid(supplied):
|
||||
return None
|
||||
return connection.respond(401, "Unauthorized")
|
||||
|
||||
if self.config.websocket_requires_token:
|
||||
if supplied and self._http.take_issued_token_if_valid(supplied):
|
||||
if supplied and self._tokens.take_issued_token_if_valid(supplied):
|
||||
return None
|
||||
return connection.respond(401, "Unauthorized")
|
||||
|
||||
if supplied:
|
||||
self._http.take_issued_token_if_valid(supplied)
|
||||
self._tokens.take_issued_token_if_valid(supplied)
|
||||
return None
|
||||
|
||||
# -- Server lifecycle and connection ingress ---------------------------
|
||||
@ -878,14 +649,14 @@ class WebSocketChannel(BaseChannel):
|
||||
new_id = str(uuid.uuid4())
|
||||
scope = await self._workspace_scope_or_error(
|
||||
connection,
|
||||
lambda: self._webui_workspaces.scope_for_new_chat(
|
||||
lambda: self._workspaces.scope_for_new_chat(
|
||||
envelope,
|
||||
controls_available=_is_localhost(connection),
|
||||
),
|
||||
)
|
||||
if scope is None:
|
||||
return
|
||||
self._webui_workspaces.persist_scope(new_id, scope)
|
||||
self._workspaces.persist_scope(new_id, scope)
|
||||
self._attach(connection, new_id)
|
||||
await self._send_event(connection, "attached", chat_id=new_id)
|
||||
await self._send_event(
|
||||
@ -913,7 +684,7 @@ class WebSocketChannel(BaseChannel):
|
||||
return
|
||||
scope = await self._workspace_scope_or_error(
|
||||
connection,
|
||||
lambda: self._webui_workspaces.scope_for_set_request(
|
||||
lambda: self._workspaces.scope_for_set_request(
|
||||
envelope,
|
||||
chat_id=cid,
|
||||
chat_running=websocket_turn_wall_started_at(cid) is not None,
|
||||
@ -923,7 +694,7 @@ class WebSocketChannel(BaseChannel):
|
||||
)
|
||||
if scope is None:
|
||||
return
|
||||
self._webui_workspaces.persist_scope(cid, scope)
|
||||
self._workspaces.persist_scope(cid, scope)
|
||||
await self._send_event(
|
||||
connection,
|
||||
"session_updated",
|
||||
@ -965,7 +736,7 @@ class WebSocketChannel(BaseChannel):
|
||||
return
|
||||
scope = await self._workspace_scope_or_error(
|
||||
connection,
|
||||
lambda: self._webui_workspaces.scope_for_message(
|
||||
lambda: self._workspaces.scope_for_message(
|
||||
envelope,
|
||||
chat_id=cid,
|
||||
chat_running=websocket_turn_wall_started_at(cid) is not None,
|
||||
@ -989,7 +760,7 @@ class WebSocketChannel(BaseChannel):
|
||||
if mcp_presets:
|
||||
metadata["mcp_presets"] = mcp_presets
|
||||
metadata[WORKSPACE_SCOPE_METADATA_KEY] = scope.metadata()
|
||||
self._webui_workspaces.persist_scope(cid, scope)
|
||||
self._workspaces.persist_scope(cid, scope)
|
||||
image_generation = envelope.get("image_generation")
|
||||
if isinstance(image_generation, dict) and image_generation.get("enabled") is True:
|
||||
aspect_ratio = image_generation.get("aspect_ratio")
|
||||
@ -1044,8 +815,7 @@ class WebSocketChannel(BaseChannel):
|
||||
self._subs.clear()
|
||||
self._conn_chats.clear()
|
||||
self._conn_default.clear()
|
||||
self._http.issued_tokens.clear()
|
||||
self._http.api_tokens.clear()
|
||||
self._tokens.clear()
|
||||
|
||||
async def _safe_send_to(self, connection: Any, raw: str, *, label: str = "") -> None:
|
||||
"""Send a raw frame to one connection, cleaning up on ConnectionClosed."""
|
||||
@ -1127,7 +897,7 @@ class WebSocketChannel(BaseChannel):
|
||||
)
|
||||
return
|
||||
text = msg.content
|
||||
wire_text = self._http.rewrite_local_markdown_images(text)
|
||||
wire_text = self._media.rewrite_local_markdown_images(text)
|
||||
payload: dict[str, Any] = {
|
||||
"event": "message",
|
||||
"chat_id": msg.chat_id,
|
||||
@ -1137,7 +907,7 @@ class WebSocketChannel(BaseChannel):
|
||||
payload["media"] = msg.media
|
||||
urls: list[dict[str, str]] = []
|
||||
for entry in msg.media:
|
||||
signed = self._http.sign_or_stage_media_path(Path(entry))
|
||||
signed = self._media.sign_or_stage_media_path(Path(entry))
|
||||
if signed is not None:
|
||||
urls.append(signed)
|
||||
if urls:
|
||||
@ -1252,7 +1022,7 @@ class WebSocketChannel(BaseChannel):
|
||||
if delta:
|
||||
buffered.append(delta)
|
||||
full_text = "".join(buffered)
|
||||
rewritten = self._http.rewrite_local_markdown_images(full_text)
|
||||
rewritten = self._media.rewrite_local_markdown_images(full_text)
|
||||
if rewritten != full_text:
|
||||
body["text"] = rewritten
|
||||
else:
|
||||
|
||||
70
nanobot/webui/gateway_services.py
Normal file
70
nanobot/webui/gateway_services.py
Normal file
@ -0,0 +1,70 @@
|
||||
"""Composition helpers for the embedded WebUI gateway."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger as default_logger
|
||||
|
||||
from nanobot.webui.gateway_tokens import GatewayTokenStore
|
||||
from nanobot.webui.media_gateway import WebUIMediaGateway
|
||||
from nanobot.webui.workspaces import WebUIWorkspaceController
|
||||
from nanobot.webui.ws_http import GatewayHTTPHandler
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GatewayServices:
|
||||
"""Explicit dependencies shared by WebSocket transport and HTTP routes."""
|
||||
|
||||
http: GatewayHTTPHandler
|
||||
tokens: GatewayTokenStore
|
||||
media: WebUIMediaGateway
|
||||
workspaces: WebUIWorkspaceController
|
||||
session_manager: Any | None
|
||||
|
||||
|
||||
def build_gateway_services(
|
||||
*,
|
||||
config: Any,
|
||||
bus: Any,
|
||||
session_manager: Any | None,
|
||||
static_dist_path: Path | None,
|
||||
workspace_path: Path,
|
||||
default_restrict_to_workspace: bool,
|
||||
runtime_model_name: Any | None,
|
||||
runtime_surface: str,
|
||||
runtime_capabilities_overrides: dict[str, Any] | None,
|
||||
logger: Any = default_logger,
|
||||
) -> GatewayServices:
|
||||
tokens = GatewayTokenStore()
|
||||
media = WebUIMediaGateway(
|
||||
workspace_path=workspace_path,
|
||||
logger=logger,
|
||||
)
|
||||
workspaces = WebUIWorkspaceController(
|
||||
session_manager=session_manager,
|
||||
default_workspace=workspace_path,
|
||||
default_restrict_to_workspace=default_restrict_to_workspace,
|
||||
)
|
||||
http = GatewayHTTPHandler(
|
||||
config=config,
|
||||
session_manager=session_manager,
|
||||
static_dist_path=static_dist_path,
|
||||
runtime_model_name=runtime_model_name,
|
||||
runtime_surface=runtime_surface,
|
||||
runtime_capabilities_overrides=runtime_capabilities_overrides,
|
||||
bus=bus,
|
||||
tokens=tokens,
|
||||
media=media,
|
||||
workspaces=workspaces,
|
||||
log=logger,
|
||||
)
|
||||
return GatewayServices(
|
||||
http=http,
|
||||
tokens=tokens,
|
||||
media=media,
|
||||
workspaces=workspaces,
|
||||
session_manager=session_manager,
|
||||
)
|
||||
82
nanobot/webui/gateway_tokens.py
Normal file
82
nanobot/webui/gateway_tokens.py
Normal file
@ -0,0 +1,82 @@
|
||||
"""Token state for the embedded WebUI gateway."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from websockets.http11 import Request as WsRequest
|
||||
|
||||
from nanobot.webui.http_utils import bearer_token, parse_query, query_first
|
||||
|
||||
|
||||
@dataclass
|
||||
class GatewayTokenStore:
|
||||
"""Own short-lived WebSocket and WebUI API tokens for one gateway process."""
|
||||
|
||||
max_tokens: int = 10_000
|
||||
issued_tokens: dict[str, float] = field(default_factory=dict)
|
||||
api_tokens: dict[str, float] = field(default_factory=dict)
|
||||
|
||||
def check_api_token(self, request: WsRequest) -> bool:
|
||||
self._purge_expired_api_tokens()
|
||||
token = bearer_token(request.headers) or query_first(
|
||||
parse_query(request.path), "token"
|
||||
)
|
||||
if not token:
|
||||
return False
|
||||
expiry = self.api_tokens.get(token)
|
||||
if expiry is None or time.monotonic() > expiry:
|
||||
self.api_tokens.pop(token, None)
|
||||
return False
|
||||
return True
|
||||
|
||||
def can_issue(self, *, include_api_token: bool = False) -> bool:
|
||||
self._purge_expired_issued_tokens()
|
||||
self._purge_expired_api_tokens()
|
||||
if len(self.issued_tokens) >= self.max_tokens:
|
||||
return False
|
||||
if include_api_token and len(self.api_tokens) >= self.max_tokens:
|
||||
return False
|
||||
return True
|
||||
|
||||
def issue_token(self, ttl_s: int | float, *, api_token: bool = False) -> str:
|
||||
token_value = f"nbwt_{secrets.token_urlsafe(32)}"
|
||||
expiry = time.monotonic() + float(ttl_s)
|
||||
self.issued_tokens[token_value] = expiry
|
||||
if api_token:
|
||||
self.api_tokens[token_value] = expiry
|
||||
return token_value
|
||||
|
||||
def take_issued_token_if_valid(self, token_value: str | None) -> bool:
|
||||
if not token_value:
|
||||
return False
|
||||
self._purge_expired_issued_tokens()
|
||||
expiry = self.issued_tokens.pop(token_value, None)
|
||||
if expiry is None:
|
||||
return False
|
||||
if time.monotonic() > expiry:
|
||||
return False
|
||||
return True
|
||||
|
||||
def clear(self) -> None:
|
||||
self.issued_tokens.clear()
|
||||
self.api_tokens.clear()
|
||||
|
||||
def _purge_expired_api_tokens(self) -> None:
|
||||
now = time.monotonic()
|
||||
for token_key, expiry in list(self.api_tokens.items()):
|
||||
if now > expiry:
|
||||
self.api_tokens.pop(token_key, None)
|
||||
|
||||
def _purge_expired_issued_tokens(self) -> None:
|
||||
now = time.monotonic()
|
||||
for token_key, expiry in list(self.issued_tokens.items()):
|
||||
if now > expiry:
|
||||
self.issued_tokens.pop(token_key, None)
|
||||
|
||||
|
||||
def token_response_payload(token: str, expires_in: Any) -> dict[str, Any]:
|
||||
return {"token": token, "expires_in": expires_in}
|
||||
151
nanobot/webui/http_utils.py
Normal file
151
nanobot/webui/http_utils.py
Normal file
@ -0,0 +1,151 @@
|
||||
"""Shared HTTP helpers for the embedded WebUI gateway."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import email.utils
|
||||
import hmac
|
||||
import http
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from websockets.datastructures import Headers
|
||||
from websockets.http11 import Response
|
||||
|
||||
QueryParams = dict[str, list[str]]
|
||||
|
||||
|
||||
def strip_trailing_slash(path: str) -> str:
|
||||
if len(path) > 1 and path.endswith("/"):
|
||||
return path.rstrip("/")
|
||||
return path or "/"
|
||||
|
||||
|
||||
def normalize_config_path(path: str) -> str:
|
||||
return strip_trailing_slash(path)
|
||||
|
||||
|
||||
def case_insensitive_header(headers: Any, key: str) -> str:
|
||||
"""Read a header from websockets/http test stubs without assuming casing."""
|
||||
try:
|
||||
value = headers.get(key)
|
||||
except Exception:
|
||||
value = None
|
||||
if value is None:
|
||||
try:
|
||||
value = headers.get(key.lower())
|
||||
except Exception:
|
||||
value = None
|
||||
return str(value or "").strip()
|
||||
|
||||
|
||||
def safe_host_header(value: str) -> str:
|
||||
"""Return a safe Host header value, or empty when it should not be echoed."""
|
||||
value = value.strip()
|
||||
if not value:
|
||||
return ""
|
||||
if re.fullmatch(r"\[[0-9A-Fa-f:.]+\](?::\d{1,5})?", value):
|
||||
return value
|
||||
if re.fullmatch(r"[A-Za-z0-9.-]+(?::\d{1,5})?", value):
|
||||
return value
|
||||
return ""
|
||||
|
||||
|
||||
def host_for_url(host: str, port: int) -> str:
|
||||
host = host.strip()
|
||||
if host in ("0.0.0.0", "::"):
|
||||
host = "127.0.0.1"
|
||||
if ":" in host and not host.startswith("["):
|
||||
host = f"[{host}]"
|
||||
return f"{host}:{port}"
|
||||
|
||||
|
||||
def http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
|
||||
body = json.dumps(data, ensure_ascii=False).encode("utf-8")
|
||||
headers = Headers(
|
||||
[
|
||||
("Date", email.utils.formatdate(usegmt=True)),
|
||||
("Connection", "close"),
|
||||
("Content-Length", str(len(body))),
|
||||
("Content-Type", "application/json; charset=utf-8"),
|
||||
]
|
||||
)
|
||||
reason = http.HTTPStatus(status).phrase
|
||||
return Response(status, reason, headers, body)
|
||||
|
||||
|
||||
def http_response(
|
||||
body: bytes,
|
||||
*,
|
||||
status: int = 200,
|
||||
content_type: str = "text/plain; charset=utf-8",
|
||||
extra_headers: list[tuple[str, str]] | None = None,
|
||||
) -> Response:
|
||||
headers = [
|
||||
("Date", email.utils.formatdate(usegmt=True)),
|
||||
("Connection", "close"),
|
||||
("Content-Length", str(len(body))),
|
||||
("Content-Type", content_type),
|
||||
]
|
||||
if extra_headers:
|
||||
headers.extend(extra_headers)
|
||||
reason = http.HTTPStatus(status).phrase
|
||||
return Response(status, reason, Headers(headers), body)
|
||||
|
||||
|
||||
def http_error(status: int, message: str | None = None) -> Response:
|
||||
body = (message or http.HTTPStatus(status).phrase).encode("utf-8")
|
||||
return http_response(body, status=status)
|
||||
|
||||
|
||||
def parse_request_path(path_with_query: str) -> tuple[str, QueryParams]:
|
||||
"""Parse normalized path and query parameters in one pass."""
|
||||
parsed = urlparse("ws://x" + path_with_query)
|
||||
path = strip_trailing_slash(parsed.path or "/")
|
||||
return path, parse_qs(parsed.query, keep_blank_values=True)
|
||||
|
||||
|
||||
def normalize_http_path(path_with_query: str) -> str:
|
||||
return parse_request_path(path_with_query)[0]
|
||||
|
||||
|
||||
def parse_query(path_with_query: str) -> QueryParams:
|
||||
return parse_request_path(path_with_query)[1]
|
||||
|
||||
|
||||
def query_first(query: QueryParams, key: str) -> str | None:
|
||||
values = query.get(key)
|
||||
return values[0] if values else None
|
||||
|
||||
|
||||
def is_localhost(connection: Any) -> bool:
|
||||
addr = getattr(connection, "remote_address", None)
|
||||
if not addr:
|
||||
return False
|
||||
host = addr[0] if isinstance(addr, tuple) else addr
|
||||
if not isinstance(host, str):
|
||||
return False
|
||||
if host.startswith("::ffff:"):
|
||||
host = host[7:]
|
||||
return host in {"127.0.0.1", "::1", "localhost"}
|
||||
|
||||
|
||||
def bearer_token(headers: Any) -> str | None:
|
||||
auth = headers.get("Authorization") or headers.get("authorization")
|
||||
if auth and auth.lower().startswith("bearer "):
|
||||
return auth[7:].strip() or None
|
||||
return None
|
||||
|
||||
|
||||
def issue_route_secret_matches(headers: Any, configured_secret: str) -> bool:
|
||||
if not configured_secret:
|
||||
return True
|
||||
authorization = headers.get("Authorization") or headers.get("authorization")
|
||||
if authorization and authorization.lower().startswith("bearer "):
|
||||
supplied = authorization[7:].strip()
|
||||
return hmac.compare_digest(supplied, configured_secret)
|
||||
header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth")
|
||||
if not header_token:
|
||||
return False
|
||||
return hmac.compare_digest(header_token.strip(), configured_secret)
|
||||
@ -4,10 +4,8 @@ from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import email.utils
|
||||
import hashlib
|
||||
import hmac
|
||||
import http
|
||||
import mimetypes
|
||||
import re
|
||||
import shutil
|
||||
@ -16,12 +14,20 @@ from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from websockets.datastructures import Headers
|
||||
from websockets.http11 import Request as WsRequest
|
||||
from websockets.http11 import Response
|
||||
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.utils.helpers import safe_filename
|
||||
from nanobot.webui.http_utils import (
|
||||
case_insensitive_header as _case_insensitive_header,
|
||||
)
|
||||
from nanobot.webui.http_utils import (
|
||||
http_error as _http_error,
|
||||
)
|
||||
from nanobot.webui.http_utils import (
|
||||
http_response as _http_response,
|
||||
)
|
||||
|
||||
MediaDirProvider = Callable[[str | None], Path]
|
||||
SignedMediaPath = Callable[[Path], dict[str, str] | None]
|
||||
@ -67,43 +73,6 @@ _SVG_MEDIA_HEADERS: tuple[tuple[str, str], ...] = (
|
||||
_BYTE_RANGE_RE = re.compile(r"^bytes=(\d*)-(\d*)$")
|
||||
|
||||
|
||||
def _http_response(
|
||||
body: bytes,
|
||||
*,
|
||||
status: int = 200,
|
||||
content_type: str = "text/plain; charset=utf-8",
|
||||
extra_headers: list[tuple[str, str]] | None = None,
|
||||
) -> Response:
|
||||
headers = [
|
||||
("Date", email.utils.formatdate(usegmt=True)),
|
||||
("Connection", "close"),
|
||||
("Content-Length", str(len(body))),
|
||||
("Content-Type", content_type),
|
||||
]
|
||||
if extra_headers:
|
||||
headers.extend(extra_headers)
|
||||
reason = http.HTTPStatus(status).phrase
|
||||
return Response(status, reason, Headers(headers), body)
|
||||
|
||||
|
||||
def _http_error(status: int, message: str | None = None) -> Response:
|
||||
body = (message or http.HTTPStatus(status).phrase).encode("utf-8")
|
||||
return _http_response(body, status=status)
|
||||
|
||||
|
||||
def _case_insensitive_header(headers: Any, key: str) -> str:
|
||||
try:
|
||||
value = headers.get(key)
|
||||
except Exception:
|
||||
value = None
|
||||
if value is None:
|
||||
try:
|
||||
value = headers.get(key.lower())
|
||||
except Exception:
|
||||
value = None
|
||||
return str(value or "").strip()
|
||||
|
||||
|
||||
def _parse_single_byte_range(range_header: str, size: int) -> tuple[int, int]:
|
||||
"""Parse a single HTTP byte range for signed media responses."""
|
||||
if size <= 0 or "," in range_header:
|
||||
|
||||
92
nanobot/webui/media_gateway.py
Normal file
92
nanobot/webui/media_gateway.py
Normal file
@ -0,0 +1,92 @@
|
||||
"""Media gateway services shared by WebUI HTTP routes and WebSocket frames."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from websockets.http11 import Request as WsRequest
|
||||
from websockets.http11 import Response
|
||||
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.webui.media_api import (
|
||||
attach_signed_media_urls,
|
||||
serve_signed_media,
|
||||
sign_media_path,
|
||||
sign_or_stage_media_path,
|
||||
signed_media_attachments,
|
||||
)
|
||||
from nanobot.webui.transcript import rewrite_local_markdown_images
|
||||
|
||||
|
||||
class WebUIMediaGateway:
|
||||
"""Own media URL signing and WebUI markdown/media augmentation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
workspace_path: Path,
|
||||
logger: Any,
|
||||
media_dir: Callable[[str | None], Path] | None = None,
|
||||
secret: bytes | None = None,
|
||||
) -> None:
|
||||
self.workspace_path = workspace_path
|
||||
self.logger = logger
|
||||
self._media_dir = media_dir or (lambda channel=None: get_media_dir(channel))
|
||||
self.secret = secret or secrets.token_bytes(32)
|
||||
|
||||
def serve_signed_media(
|
||||
self,
|
||||
sig: str,
|
||||
payload: str,
|
||||
*,
|
||||
request: WsRequest | None = None,
|
||||
) -> Response:
|
||||
return serve_signed_media(
|
||||
sig,
|
||||
payload,
|
||||
secret=self.secret,
|
||||
request=request,
|
||||
media_dir=self._media_dir,
|
||||
)
|
||||
|
||||
def sign_media_path(self, abs_path: Path) -> str | None:
|
||||
return sign_media_path(
|
||||
abs_path,
|
||||
secret=self.secret,
|
||||
media_dir=self._media_dir,
|
||||
)
|
||||
|
||||
def sign_or_stage_media_path(self, path: Path) -> dict[str, str] | None:
|
||||
return sign_or_stage_media_path(
|
||||
path,
|
||||
secret=self.secret,
|
||||
media_dir=self._media_dir,
|
||||
logger=self.logger,
|
||||
)
|
||||
|
||||
def rewrite_local_markdown_images(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
workspace_path: Path | None = None,
|
||||
) -> str:
|
||||
return rewrite_local_markdown_images(
|
||||
text,
|
||||
workspace_path=workspace_path or self.workspace_path,
|
||||
sign_path=self.sign_or_stage_media_path,
|
||||
)
|
||||
|
||||
def augment_media_urls(self, payload: dict[str, Any]) -> None:
|
||||
attach_signed_media_urls(payload, sign_path=self.sign_media_path)
|
||||
|
||||
def augment_transcript_media(self, paths: list[str]) -> list[dict[str, Any]]:
|
||||
return signed_media_attachments(
|
||||
paths,
|
||||
sign_path=self.sign_or_stage_media_path,
|
||||
)
|
||||
|
||||
def augment_transcript_user_media(self, paths: list[str]) -> list[dict[str, Any]]:
|
||||
return self.augment_transcript_media(paths)
|
||||
@ -9,187 +9,70 @@ Also houses shared HTTP utility functions used by both this module and
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import email.utils
|
||||
import hmac
|
||||
import http
|
||||
import json
|
||||
import mimetypes
|
||||
import re
|
||||
import secrets
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from loguru import logger
|
||||
from websockets.datastructures import Headers
|
||||
from websockets.http11 import Request as WsRequest
|
||||
from websockets.http11 import Response
|
||||
|
||||
from nanobot.command.builtin import builtin_command_palette
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.utils.subagent_channel_display import scrub_subagent_messages_for_channel
|
||||
from nanobot.webui.media_api import (
|
||||
serve_signed_media,
|
||||
sign_media_path,
|
||||
sign_or_stage_media_path,
|
||||
from nanobot.webui.gateway_tokens import GatewayTokenStore, token_response_payload
|
||||
from nanobot.webui.http_utils import (
|
||||
case_insensitive_header as _case_insensitive_header,
|
||||
)
|
||||
from nanobot.webui.http_utils import (
|
||||
host_for_url as _host_for_url,
|
||||
)
|
||||
from nanobot.webui.http_utils import (
|
||||
http_error as _http_error,
|
||||
)
|
||||
from nanobot.webui.http_utils import (
|
||||
http_json_response as _http_json_response,
|
||||
)
|
||||
from nanobot.webui.http_utils import (
|
||||
http_response as _http_response,
|
||||
)
|
||||
from nanobot.webui.http_utils import (
|
||||
is_localhost as _is_localhost,
|
||||
)
|
||||
from nanobot.webui.http_utils import (
|
||||
issue_route_secret_matches as _issue_route_secret_matches,
|
||||
)
|
||||
from nanobot.webui.http_utils import (
|
||||
normalize_config_path as _normalize_config_path,
|
||||
)
|
||||
from nanobot.webui.http_utils import (
|
||||
parse_query as _parse_query,
|
||||
)
|
||||
from nanobot.webui.http_utils import (
|
||||
parse_request_path as _parse_request_path,
|
||||
)
|
||||
from nanobot.webui.http_utils import (
|
||||
query_first as _query_first,
|
||||
)
|
||||
from nanobot.webui.http_utils import (
|
||||
safe_host_header as _safe_host_header,
|
||||
)
|
||||
from nanobot.webui.media_gateway import WebUIMediaGateway
|
||||
from nanobot.webui.sidebar_state import (
|
||||
read_webui_sidebar_state,
|
||||
write_webui_sidebar_state,
|
||||
)
|
||||
from nanobot.webui.thread_disk import delete_webui_thread
|
||||
from nanobot.webui.transcript import (
|
||||
build_webui_thread_response,
|
||||
rewrite_local_markdown_images,
|
||||
)
|
||||
from nanobot.webui.transcript import build_webui_thread_response
|
||||
from nanobot.webui.workspaces import WebUIWorkspaceController
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.session.manager import SessionManager
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared HTTP utility functions (imported by websocket.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _strip_trailing_slash(path: str) -> str:
|
||||
if len(path) > 1 and path.endswith("/"):
|
||||
return path.rstrip("/")
|
||||
return path or "/"
|
||||
|
||||
|
||||
def _normalize_config_path(path: str) -> str:
|
||||
return _strip_trailing_slash(path)
|
||||
|
||||
|
||||
def _case_insensitive_header(headers: Any, key: str) -> str:
|
||||
"""Read a header from websockets/http test stubs without assuming casing."""
|
||||
try:
|
||||
value = headers.get(key)
|
||||
except Exception:
|
||||
value = None
|
||||
if value is None:
|
||||
try:
|
||||
value = headers.get(key.lower())
|
||||
except Exception:
|
||||
value = None
|
||||
return str(value or "").strip()
|
||||
|
||||
|
||||
def _safe_host_header(value: str) -> str:
|
||||
"""Return a safe Host header value, or empty when it should not be echoed."""
|
||||
value = value.strip()
|
||||
if not value:
|
||||
return ""
|
||||
if re.fullmatch(r"\[[0-9A-Fa-f:.]+\](?::\d{1,5})?", value):
|
||||
return value
|
||||
if re.fullmatch(r"[A-Za-z0-9.-]+(?::\d{1,5})?", value):
|
||||
return value
|
||||
return ""
|
||||
|
||||
|
||||
def _host_for_url(host: str, port: int) -> str:
|
||||
host = host.strip()
|
||||
if host in ("0.0.0.0", "::"):
|
||||
host = "127.0.0.1"
|
||||
if ":" in host and not host.startswith("["):
|
||||
host = f"[{host}]"
|
||||
return f"{host}:{port}"
|
||||
|
||||
|
||||
def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
|
||||
body = json.dumps(data, ensure_ascii=False).encode("utf-8")
|
||||
headers = Headers(
|
||||
[
|
||||
("Date", email.utils.formatdate(usegmt=True)),
|
||||
("Connection", "close"),
|
||||
("Content-Length", str(len(body))),
|
||||
("Content-Type", "application/json; charset=utf-8"),
|
||||
]
|
||||
)
|
||||
reason = http.HTTPStatus(status).phrase
|
||||
return Response(status, reason, headers, body)
|
||||
|
||||
|
||||
def _http_response(
|
||||
body: bytes,
|
||||
*,
|
||||
status: int = 200,
|
||||
content_type: str = "text/plain; charset=utf-8",
|
||||
extra_headers: list[tuple[str, str]] | None = None,
|
||||
) -> Response:
|
||||
headers = [
|
||||
("Date", email.utils.formatdate(usegmt=True)),
|
||||
("Connection", "close"),
|
||||
("Content-Length", str(len(body))),
|
||||
("Content-Type", content_type),
|
||||
]
|
||||
if extra_headers:
|
||||
headers.extend(extra_headers)
|
||||
reason = http.HTTPStatus(status).phrase
|
||||
return Response(status, reason, Headers(headers), body)
|
||||
|
||||
|
||||
def _http_error(status: int, message: str | None = None) -> Response:
|
||||
body = (message or http.HTTPStatus(status).phrase).encode("utf-8")
|
||||
return _http_response(body, status=status)
|
||||
|
||||
|
||||
def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]:
|
||||
"""Parse normalized path and query parameters in one pass."""
|
||||
parsed = urlparse("ws://x" + path_with_query)
|
||||
path = _strip_trailing_slash(parsed.path or "/")
|
||||
return path, parse_qs(parsed.query, keep_blank_values=True)
|
||||
|
||||
|
||||
def _normalize_http_path(path_with_query: str) -> str:
|
||||
return _parse_request_path(path_with_query)[0]
|
||||
|
||||
|
||||
def _parse_query(path_with_query: str) -> dict[str, list[str]]:
|
||||
return _parse_request_path(path_with_query)[1]
|
||||
|
||||
|
||||
def _query_first(query: dict[str, list[str]], key: str) -> str | None:
|
||||
values = query.get(key)
|
||||
return values[0] if values else None
|
||||
|
||||
|
||||
def _is_localhost(connection: Any) -> bool:
|
||||
addr = getattr(connection, "remote_address", None)
|
||||
if not addr:
|
||||
return False
|
||||
host = addr[0] if isinstance(addr, tuple) else addr
|
||||
if not isinstance(host, str):
|
||||
return False
|
||||
if host.startswith("::ffff:"):
|
||||
host = host[7:]
|
||||
return host in {"127.0.0.1", "::1", "localhost"}
|
||||
|
||||
|
||||
def _bearer_token(headers: Any) -> str | None:
|
||||
auth = headers.get("Authorization") or headers.get("authorization")
|
||||
if auth and auth.lower().startswith("bearer "):
|
||||
return auth[7:].strip() or None
|
||||
return None
|
||||
|
||||
|
||||
def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool:
|
||||
if not configured_secret:
|
||||
return True
|
||||
authorization = headers.get("Authorization") or headers.get("authorization")
|
||||
if authorization and authorization.lower().startswith("bearer "):
|
||||
supplied = authorization[7:].strip()
|
||||
return hmac.compare_digest(supplied, configured_secret)
|
||||
header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth")
|
||||
if not header_token:
|
||||
return False
|
||||
return hmac.compare_digest(header_token.strip(), configured_secret)
|
||||
|
||||
|
||||
def _decode_api_key(raw_key: str) -> str | None:
|
||||
from urllib.parse import unquote
|
||||
|
||||
@ -234,48 +117,36 @@ def _resolve_bootstrap_model_name(
|
||||
class GatewayHTTPHandler:
|
||||
"""Handles all HTTP routes served alongside the WebSocket endpoint.
|
||||
|
||||
Owns token management, session API, media API, static file serving,
|
||||
and delegates settings routes to ``WebUISettingsRouter``.
|
||||
Routes HTTP requests and delegates stateful work to explicit gateway
|
||||
services owned by the composition layer.
|
||||
"""
|
||||
|
||||
_MAX_ISSUED_TOKENS = 10_000
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
config: Any, # WebSocketConfig
|
||||
session_manager: SessionManager | None,
|
||||
static_dist_path: Path | None,
|
||||
workspace_path: Path,
|
||||
runtime_model_name: Callable[[], str | None] | None,
|
||||
runtime_surface: str,
|
||||
runtime_capabilities_overrides: dict[str, Any] | None,
|
||||
bus: MessageBus,
|
||||
tokens: GatewayTokenStore,
|
||||
media: WebUIMediaGateway,
|
||||
workspaces: WebUIWorkspaceController,
|
||||
log: Any = logger,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.session_manager = session_manager
|
||||
self.static_dist_path = static_dist_path
|
||||
self.workspace_path = workspace_path
|
||||
self.runtime_model_name = runtime_model_name
|
||||
self.bus = bus
|
||||
self.tokens = tokens
|
||||
self.media = media
|
||||
self.workspaces = workspaces
|
||||
self._log = log
|
||||
self._runtime_surface = runtime_surface
|
||||
|
||||
self.issued_tokens: dict[str, float] = {}
|
||||
self.api_tokens: dict[str, float] = {}
|
||||
self.media_secret: bytes = secrets.token_bytes(32)
|
||||
|
||||
# Workspace controller
|
||||
from nanobot.webui.workspaces import WebUIWorkspaceController
|
||||
|
||||
self.workspaces = WebUIWorkspaceController(
|
||||
session_manager=session_manager,
|
||||
default_workspace=workspace_path,
|
||||
default_restrict_to_workspace=None,
|
||||
)
|
||||
|
||||
# Settings router
|
||||
from nanobot.webui.settings_api import runtime_capabilities as _rc
|
||||
from nanobot.webui.settings_routes import WebUISettingsRouter
|
||||
|
||||
@ -294,46 +165,13 @@ class GatewayHTTPHandler:
|
||||
# -- Token management ---------------------------------------------------
|
||||
|
||||
def check_api_token(self, request: WsRequest) -> bool:
|
||||
self._purge_expired_api_tokens()
|
||||
token = _bearer_token(request.headers) or _query_first(
|
||||
_parse_query(request.path), "token"
|
||||
)
|
||||
if not token:
|
||||
return False
|
||||
expiry = self.api_tokens.get(token)
|
||||
if expiry is None or time.monotonic() > expiry:
|
||||
self.api_tokens.pop(token, None)
|
||||
return False
|
||||
return True
|
||||
|
||||
def _purge_expired_api_tokens(self) -> None:
|
||||
now = time.monotonic()
|
||||
for token_key, expiry in list(self.api_tokens.items()):
|
||||
if now > expiry:
|
||||
self.api_tokens.pop(token_key, None)
|
||||
|
||||
def _purge_expired_issued_tokens(self) -> None:
|
||||
now = time.monotonic()
|
||||
for token_key, expiry in list(self.issued_tokens.items()):
|
||||
if now > expiry:
|
||||
self.issued_tokens.pop(token_key, None)
|
||||
|
||||
def take_issued_token_if_valid(self, token_value: str | None) -> bool:
|
||||
if not token_value:
|
||||
return False
|
||||
self._purge_expired_issued_tokens()
|
||||
expiry = self.issued_tokens.pop(token_value, None)
|
||||
if expiry is None:
|
||||
return False
|
||||
if time.monotonic() > expiry:
|
||||
return False
|
||||
return True
|
||||
return self.tokens.check_api_token(request)
|
||||
|
||||
# -- Main dispatch ------------------------------------------------------
|
||||
|
||||
async def dispatch(self, connection: Any, request: WsRequest) -> Any | None:
|
||||
"""Route an HTTP request. Returns Response or None."""
|
||||
got, query = _parse_request_path(request.path)
|
||||
got, _ = _parse_request_path(request.path)
|
||||
|
||||
# Token issue endpoint
|
||||
if self.config.token_issue_path:
|
||||
@ -389,18 +227,14 @@ class GatewayHTTPHandler:
|
||||
"token_issue_path is set but token_issue_secret is empty; "
|
||||
"any client can obtain connection tokens — set token_issue_secret for production."
|
||||
)
|
||||
self._purge_expired_issued_tokens()
|
||||
if len(self.issued_tokens) >= self._MAX_ISSUED_TOKENS:
|
||||
if not self.tokens.can_issue():
|
||||
self._log.error(
|
||||
"too many outstanding issued tokens ({}), rejecting issuance",
|
||||
len(self.issued_tokens),
|
||||
len(self.tokens.issued_tokens),
|
||||
)
|
||||
return _http_json_response({"error": "too many outstanding tokens"}, status=429)
|
||||
token_value = f"nbwt_{secrets.token_urlsafe(32)}"
|
||||
self.issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s)
|
||||
return _http_json_response(
|
||||
{"token": token_value, "expires_in": self.config.token_ttl_s}
|
||||
)
|
||||
token_value = self.tokens.issue_token(self.config.token_ttl_s)
|
||||
return _http_json_response(token_response_payload(token_value, self.config.token_ttl_s))
|
||||
|
||||
# -- Bootstrap ----------------------------------------------------------
|
||||
|
||||
@ -412,21 +246,13 @@ class GatewayHTTPHandler:
|
||||
elif not _is_localhost(connection):
|
||||
return _http_error(403, "bootstrap is localhost-only")
|
||||
|
||||
self._purge_expired_issued_tokens()
|
||||
self._purge_expired_api_tokens()
|
||||
if (
|
||||
len(self.issued_tokens) >= self._MAX_ISSUED_TOKENS
|
||||
or len(self.api_tokens) >= self._MAX_ISSUED_TOKENS
|
||||
):
|
||||
if not self.tokens.can_issue(include_api_token=True):
|
||||
return _http_response(
|
||||
json.dumps({"error": "too many outstanding tokens"}).encode("utf-8"),
|
||||
status=429,
|
||||
content_type="application/json; charset=utf-8",
|
||||
)
|
||||
token = f"nbwt_{secrets.token_urlsafe(32)}"
|
||||
expiry = time.monotonic() + float(self.config.token_ttl_s)
|
||||
self.issued_tokens[token] = expiry
|
||||
self.api_tokens[token] = expiry
|
||||
token = self.tokens.issue_token(self.config.token_ttl_s, api_token=True)
|
||||
|
||||
ws_url = self._bootstrap_ws_url(request)
|
||||
expected_path = _normalize_config_path(self.config.path)
|
||||
@ -510,7 +336,7 @@ class GatewayHTTPHandler:
|
||||
messages = data.get("messages")
|
||||
if isinstance(messages, list):
|
||||
scrub_subagent_messages_for_channel(messages)
|
||||
self._augment_media_urls(data)
|
||||
self.media.augment_media_urls(data)
|
||||
return _http_json_response(data)
|
||||
|
||||
def _handle_webui_thread_get(self, request: WsRequest, key: str) -> Response:
|
||||
@ -524,11 +350,11 @@ class GatewayHTTPHandler:
|
||||
scope = self.workspaces.scope_for_session_key(decoded_key)
|
||||
data = build_webui_thread_response(
|
||||
decoded_key,
|
||||
augment_user_media=self._augment_transcript_user_media,
|
||||
augment_assistant_text=lambda text: rewrite_local_markdown_images(
|
||||
augment_user_media=self.media.augment_transcript_media,
|
||||
augment_assistant_media=self.media.augment_transcript_media,
|
||||
augment_assistant_text=lambda text: self.media.rewrite_local_markdown_images(
|
||||
text,
|
||||
workspace_path=scope.project_path,
|
||||
sign_path=self.sign_or_stage_media_path,
|
||||
),
|
||||
)
|
||||
if data is None:
|
||||
@ -561,34 +387,10 @@ class GatewayHTTPHandler:
|
||||
def _handle_media_fetch(
|
||||
self, sig: str, payload: str, request: WsRequest | None = None
|
||||
) -> Response:
|
||||
return serve_signed_media(
|
||||
return self.media.serve_signed_media(
|
||||
sig,
|
||||
payload,
|
||||
secret=self.media_secret,
|
||||
request=request,
|
||||
media_dir=lambda channel=None: get_media_dir(channel),
|
||||
)
|
||||
|
||||
def sign_media_path(self, abs_path: Path) -> str | None:
|
||||
return sign_media_path(
|
||||
abs_path,
|
||||
secret=self.media_secret,
|
||||
media_dir=lambda channel=None: get_media_dir(channel),
|
||||
)
|
||||
|
||||
def sign_or_stage_media_path(self, path: Path) -> dict[str, str] | None:
|
||||
return sign_or_stage_media_path(
|
||||
path,
|
||||
secret=self.media_secret,
|
||||
media_dir=lambda channel=None: get_media_dir(channel),
|
||||
logger=self._log,
|
||||
)
|
||||
|
||||
def rewrite_local_markdown_images(self, text: str) -> str:
|
||||
return rewrite_local_markdown_images(
|
||||
text,
|
||||
workspace_path=self.workspace_path,
|
||||
sign_path=self.sign_or_stage_media_path,
|
||||
)
|
||||
|
||||
# -- Misc routes --------------------------------------------------------
|
||||
@ -688,44 +490,5 @@ class GatewayHTTPHandler:
|
||||
extra_headers=[("Cache-Control", cache)],
|
||||
)
|
||||
|
||||
# -- Media helpers (called by WebSocketChannel.send) --------------------
|
||||
|
||||
def _augment_media_urls(self, payload: dict[str, Any]) -> None:
|
||||
messages = payload.get("messages")
|
||||
if not isinstance(messages, list):
|
||||
return
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
media = msg.get("media")
|
||||
if not isinstance(media, list) or not media:
|
||||
continue
|
||||
urls: list[dict[str, str]] = []
|
||||
for entry in media:
|
||||
if not isinstance(entry, str) or not entry:
|
||||
continue
|
||||
signed = self.sign_media_path(Path(entry))
|
||||
if signed is None:
|
||||
continue
|
||||
urls.append({"url": signed, "name": Path(entry).name})
|
||||
if urls:
|
||||
msg["media_urls"] = urls
|
||||
msg.pop("media", None)
|
||||
|
||||
def _augment_transcript_user_media(self, paths: list[str]) -> list[dict[str, Any]]:
|
||||
out: list[dict[str, Any]] = []
|
||||
for pstr in paths:
|
||||
path = Path(pstr)
|
||||
att = self.sign_or_stage_media_path(path)
|
||||
if att is None:
|
||||
continue
|
||||
mime, _ = mimetypes.guess_type(path.name)
|
||||
kind = "video" if mime and mime.startswith("video/") else "image"
|
||||
out.append(
|
||||
{"kind": kind, "url": att["url"], "name": att.get("name", path.name)},
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def _is_websocket_channel_session_key(key: str) -> bool:
|
||||
return key.startswith("websocket:")
|
||||
|
||||
@ -65,6 +65,32 @@ def manager() -> ChannelManager:
|
||||
return mgr
|
||||
|
||||
|
||||
def test_websocket_gateway_uses_configured_workspace_restriction(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"nanobot.webui.workspaces.read_webui_default_access_mode",
|
||||
lambda: "default",
|
||||
)
|
||||
config = Config.model_validate(
|
||||
{
|
||||
"agents": {"defaults": {"workspace": str(tmp_path)}},
|
||||
"tools": {"restrictToWorkspace": True},
|
||||
"channels": {
|
||||
"websocket": {
|
||||
"enabled": True,
|
||||
"websocketRequiresToken": False,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
mgr = ChannelManager(config, MessageBus(), webui_static_dist=False)
|
||||
channel = mgr.channels["websocket"]
|
||||
|
||||
scope = channel.gateway.workspaces.default_scope()
|
||||
assert scope.project_path == tmp_path
|
||||
assert scope.restrict_to_workspace is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reasoning_delta_routes_to_send_reasoning_delta(manager):
|
||||
channel = manager.channels["mock"]
|
||||
|
||||
@ -20,20 +20,30 @@ from nanobot.channels.websocket import (
|
||||
WebSocketChannel,
|
||||
WebSocketConfig,
|
||||
_is_valid_chat_id,
|
||||
_issue_route_secret_matches,
|
||||
_normalize_config_path,
|
||||
_normalize_http_path,
|
||||
_parse_envelope,
|
||||
_parse_inbound_payload,
|
||||
_parse_query,
|
||||
_parse_request_path,
|
||||
publish_runtime_model_update,
|
||||
)
|
||||
from nanobot.webui.ws_http import GatewayHTTPHandler
|
||||
from nanobot.config.loader import load_config, save_config
|
||||
from nanobot.config.schema import Config, ModelPresetConfig
|
||||
from nanobot.session import webui_turns as wth
|
||||
from nanobot.session.manager import SessionManager
|
||||
from nanobot.webui.gateway_services import GatewayServices, build_gateway_services
|
||||
from nanobot.webui.http_utils import (
|
||||
issue_route_secret_matches as _issue_route_secret_matches,
|
||||
)
|
||||
from nanobot.webui.http_utils import (
|
||||
normalize_config_path as _normalize_config_path,
|
||||
)
|
||||
from nanobot.webui.http_utils import (
|
||||
normalize_http_path as _normalize_http_path,
|
||||
)
|
||||
from nanobot.webui.http_utils import (
|
||||
parse_query as _parse_query,
|
||||
)
|
||||
from nanobot.webui.http_utils import (
|
||||
parse_request_path as _parse_request_path,
|
||||
)
|
||||
from nanobot.webui.settings_api import settings_payload, update_provider_settings
|
||||
|
||||
# -- Shared helpers (aligned with test_websocket_integration.py) ---------------
|
||||
@ -52,34 +62,36 @@ def _ch(bus: Any, **kw: Any) -> WebSocketChannel:
|
||||
}
|
||||
cfg.update(kw)
|
||||
parsed = WebSocketConfig.model_validate(cfg)
|
||||
http_handler = GatewayHTTPHandler(
|
||||
gateway = build_gateway_services(
|
||||
config=parsed,
|
||||
bus=bus,
|
||||
session_manager=None,
|
||||
static_dist_path=None,
|
||||
workspace_path=Path.cwd(),
|
||||
default_restrict_to_workspace=False,
|
||||
runtime_model_name=None,
|
||||
runtime_surface="browser",
|
||||
runtime_capabilities_overrides=None,
|
||||
bus=bus,
|
||||
)
|
||||
return WebSocketChannel(cfg, bus, http_handler=http_handler)
|
||||
return WebSocketChannel(cfg, bus, gateway=gateway)
|
||||
|
||||
|
||||
def _basic_handler(bus: Any, **kw: Any) -> GatewayHTTPHandler:
|
||||
def _basic_handler(bus: Any, **kw: Any) -> GatewayServices:
|
||||
cfg = WebSocketConfig.model_validate({
|
||||
"enabled": True, "allowFrom": ["*"],
|
||||
"host": "127.0.0.1", "port": _PORT,
|
||||
"path": "/ws", "websocketRequiresToken": False,
|
||||
})
|
||||
return GatewayHTTPHandler(
|
||||
return build_gateway_services(
|
||||
config=cfg,
|
||||
bus=bus,
|
||||
session_manager=kw.get("session_manager"),
|
||||
static_dist_path=None,
|
||||
workspace_path=kw.get("workspace_path", Path.cwd()),
|
||||
default_restrict_to_workspace=kw.get("default_restrict_to_workspace", False),
|
||||
runtime_model_name=None,
|
||||
runtime_surface=kw.get("runtime_surface", "browser"),
|
||||
runtime_capabilities_overrides=kw.get("runtime_capabilities_overrides"),
|
||||
bus=bus,
|
||||
)
|
||||
|
||||
|
||||
@ -194,7 +206,7 @@ def test_ssl_context_requires_both_cert_and_key_files() -> None:
|
||||
channel = WebSocketChannel(
|
||||
{"enabled": True, "allowFrom": ["*"], "sslCertfile": "/tmp/c.pem", "sslKeyfile": ""},
|
||||
bus,
|
||||
http_handler=_basic_handler(bus),
|
||||
gateway=_basic_handler(bus),
|
||||
)
|
||||
with pytest.raises(ValueError, match="ssl_certfile and ssl_keyfile"):
|
||||
channel._build_ssl_context()
|
||||
@ -310,7 +322,7 @@ async def test_webui_message_scope_inherits_persisted_session_scope(
|
||||
channel = WebSocketChannel(
|
||||
{"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"},
|
||||
bus,
|
||||
http_handler=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace),
|
||||
gateway=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace),
|
||||
)
|
||||
conn = AsyncMock()
|
||||
conn.remote_address = ("127.0.0.1", 50123)
|
||||
@ -356,7 +368,7 @@ async def test_webui_scope_expands_home_project_path(
|
||||
channel = WebSocketChannel(
|
||||
{"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"},
|
||||
bus,
|
||||
http_handler=_basic_handler(bus, session_manager=SessionManager(tmp_path / "sessions"), workspace_path=default_workspace),
|
||||
gateway=_basic_handler(bus, session_manager=SessionManager(tmp_path / "sessions"), workspace_path=default_workspace),
|
||||
)
|
||||
conn = AsyncMock()
|
||||
conn.remote_address = ("127.0.0.1", 50123)
|
||||
@ -393,7 +405,7 @@ async def test_webui_scope_rejects_missing_project_path(bus: MagicMock, tmp_path
|
||||
channel = WebSocketChannel(
|
||||
{"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"},
|
||||
bus,
|
||||
http_handler=_basic_handler(bus, session_manager=SessionManager(tmp_path / "sessions"), workspace_path=default_workspace),
|
||||
gateway=_basic_handler(bus, session_manager=SessionManager(tmp_path / "sessions"), workspace_path=default_workspace),
|
||||
)
|
||||
conn = AsyncMock()
|
||||
conn.remote_address = ("127.0.0.1", 50123)
|
||||
@ -430,7 +442,7 @@ async def test_webui_scope_rejects_running_scope_change(bus: MagicMock, tmp_path
|
||||
channel = WebSocketChannel(
|
||||
{"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"},
|
||||
bus,
|
||||
http_handler=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace),
|
||||
gateway=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace),
|
||||
)
|
||||
conn = AsyncMock()
|
||||
conn.remote_address = ("127.0.0.1", 50123)
|
||||
@ -486,7 +498,7 @@ async def test_webui_set_workspace_scope_rejects_running_chat(bus: MagicMock, tm
|
||||
channel = WebSocketChannel(
|
||||
{"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"},
|
||||
bus,
|
||||
http_handler=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace),
|
||||
gateway=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace),
|
||||
)
|
||||
conn = AsyncMock()
|
||||
conn.remote_address = ("127.0.0.1", 50123)
|
||||
@ -545,7 +557,7 @@ async def test_webui_scope_rejects_non_loopback_custom_scope(bus: MagicMock, tmp
|
||||
channel = WebSocketChannel(
|
||||
{"enabled": True, "allowFrom": ["*"], "host": "127.0.0.1"},
|
||||
bus,
|
||||
http_handler=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace),
|
||||
gateway=_basic_handler(bus, session_manager=sessions, workspace_path=default_workspace),
|
||||
)
|
||||
conn = AsyncMock()
|
||||
conn.remote_address = ("203.0.113.8", 50123)
|
||||
@ -574,7 +586,7 @@ async def test_webui_scope_rejects_non_loopback_custom_scope(bus: MagicMock, tmp
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delivers_json_message_with_media_and_reply() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
|
||||
@ -600,7 +612,7 @@ async def test_send_delivers_json_message_with_media_and_reply() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_broadcasts_runtime_model_updates() -> None:
|
||||
bus = MessageBus()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
|
||||
@ -647,8 +659,8 @@ async def test_send_stages_external_media_as_signed_url(monkeypatch, tmp_path) -
|
||||
return ws_media if channel == "websocket" else media_root
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir)
|
||||
monkeypatch.setattr("nanobot.webui.ws_http.get_media_dir", fake_media_dir)
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
monkeypatch.setattr("nanobot.webui.media_gateway.get_media_dir", fake_media_dir)
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
|
||||
@ -671,7 +683,7 @@ async def test_send_stages_external_media_as_signed_url(monkeypatch, tmp_path) -
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_missing_connection_is_noop_without_error() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
msg = OutboundMessage(channel="websocket", chat_id="missing", content="x")
|
||||
await channel.send(msg)
|
||||
|
||||
@ -679,7 +691,7 @@ async def test_send_missing_connection_is_noop_without_error() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_removes_connection_on_connection_closed() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True)
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
@ -694,7 +706,7 @@ async def test_send_removes_connection_on_connection_closed() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_progress_includes_structured_tool_events() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
|
||||
@ -742,7 +754,7 @@ async def test_send_progress_includes_structured_tool_events() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_file_edit_progress_uses_file_edit_event() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
|
||||
@ -791,7 +803,7 @@ async def test_send_file_edit_progress_uses_file_edit_event() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_progress_includes_agent_ui_blob() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
|
||||
@ -815,7 +827,7 @@ async def test_send_progress_includes_agent_ui_blob() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_removes_connection_on_connection_closed() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True)
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
@ -829,7 +841,7 @@ async def test_send_delta_removes_connection_on_connection_closed() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_emits_delta_and_stream_end() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
|
||||
@ -862,11 +874,11 @@ async def test_send_delta_stream_end_rewrites_local_markdown_image(monkeypatch,
|
||||
return path
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir)
|
||||
monkeypatch.setattr("nanobot.webui.ws_http.get_media_dir", fake_media_dir)
|
||||
monkeypatch.setattr("nanobot.webui.media_gateway.get_media_dir", fake_media_dir)
|
||||
channel = WebSocketChannel(
|
||||
{"enabled": True, "allowFrom": ["*"], "streaming": True},
|
||||
bus,
|
||||
http_handler=_basic_handler(bus, workspace_path=workspace),
|
||||
gateway=_basic_handler(bus, workspace_path=workspace),
|
||||
)
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
@ -895,11 +907,11 @@ async def test_send_delta_stream_end_rewrites_inline_final_text(monkeypatch, tmp
|
||||
return path
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir)
|
||||
monkeypatch.setattr("nanobot.webui.ws_http.get_media_dir", fake_media_dir)
|
||||
monkeypatch.setattr("nanobot.webui.media_gateway.get_media_dir", fake_media_dir)
|
||||
channel = WebSocketChannel(
|
||||
{"enabled": True, "allowFrom": ["*"], "streaming": True},
|
||||
bus,
|
||||
http_handler=_basic_handler(bus, workspace_path=workspace),
|
||||
gateway=_basic_handler(bus, workspace_path=workspace),
|
||||
)
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
@ -919,7 +931,7 @@ async def test_send_delta_stream_end_rewrites_inline_final_text(monkeypatch, tmp
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_reasoning_delta_emits_streaming_frame() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
|
||||
@ -940,7 +952,7 @@ async def test_send_reasoning_delta_emits_streaming_frame() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_reasoning_end_emits_close_frame() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
|
||||
@ -956,7 +968,7 @@ async def test_send_reasoning_one_shot_expands_to_delta_plus_end() -> None:
|
||||
the base implementation must produce one delta and one end so the
|
||||
WebUI sees the same shape either way."""
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
|
||||
@ -978,7 +990,7 @@ async def test_send_reasoning_one_shot_expands_to_delta_plus_end() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_reasoning_delta_drops_empty_chunks() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
|
||||
@ -990,7 +1002,7 @@ async def test_send_reasoning_delta_drops_empty_chunks() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_reasoning_without_subscribers_is_noop() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
|
||||
await channel.send_reasoning_delta("unattached", "thinking", None)
|
||||
await channel.send_reasoning_end("unattached", None)
|
||||
@ -1000,7 +1012,7 @@ async def test_send_reasoning_without_subscribers_is_noop() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_turn_end_emits_turn_end_event() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
|
||||
@ -1019,7 +1031,7 @@ async def test_send_turn_end_emits_turn_end_event() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_turn_end_includes_latency_ms_when_present() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
|
||||
@ -1038,7 +1050,7 @@ async def test_send_turn_end_includes_latency_ms_when_present() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_turn_end_includes_goal_state_when_present() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
|
||||
@ -1058,7 +1070,7 @@ async def test_send_turn_end_includes_goal_state_when_present() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_goal_status_running_emits_event_with_started_at() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
|
||||
@ -1086,7 +1098,7 @@ async def test_send_goal_status_running_emits_event_with_started_at() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_goal_status_idle_omits_started_at() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
|
||||
@ -1109,7 +1121,7 @@ async def test_send_goal_status_idle_omits_started_at() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_goal_state_emits_blob_per_chat() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
mock_a = AsyncMock()
|
||||
mock_b = AsyncMock()
|
||||
channel._attach(mock_a, "chat-a")
|
||||
@ -1138,10 +1150,9 @@ async def test_send_goal_state_emits_blob_per_chat() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_push_active_goal_state_noop_without_session_manager() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
channel._session_manager = None
|
||||
await channel._maybe_push_active_goal_state("chat-1")
|
||||
mock_ws.send.assert_not_called()
|
||||
|
||||
@ -1149,10 +1160,13 @@ async def test_maybe_push_active_goal_state_noop_without_session_manager() -> No
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_push_active_goal_state_skips_when_no_goal_on_disk() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
sm = MagicMock()
|
||||
sm.read_session_file.return_value = None
|
||||
channel._session_manager = sm
|
||||
channel = WebSocketChannel(
|
||||
{"enabled": True, "allowFrom": ["*"]},
|
||||
bus,
|
||||
gateway=_basic_handler(bus, session_manager=sm),
|
||||
)
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
await channel._maybe_push_active_goal_state("chat-1")
|
||||
@ -1162,7 +1176,6 @@ async def test_maybe_push_active_goal_state_skips_when_no_goal_on_disk() -> None
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_push_active_goal_state_notifies_when_goal_active_on_disk() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
sm = MagicMock()
|
||||
sm.read_session_file.return_value = {
|
||||
"metadata": {
|
||||
@ -1174,7 +1187,11 @@ async def test_maybe_push_active_goal_state_notifies_when_goal_active_on_disk()
|
||||
},
|
||||
"messages": [],
|
||||
}
|
||||
channel._session_manager = sm
|
||||
channel = WebSocketChannel(
|
||||
{"enabled": True, "allowFrom": ["*"]},
|
||||
bus,
|
||||
gateway=_basic_handler(bus, session_manager=sm),
|
||||
)
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
await channel._maybe_push_active_goal_state("chat-1")
|
||||
@ -1190,7 +1207,7 @@ async def test_maybe_push_active_goal_state_notifies_when_goal_active_on_disk()
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_push_turn_run_wall_clock_skips_when_no_active_turn() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
from nanobot.session import webui_turns as wth
|
||||
@ -1203,7 +1220,7 @@ async def test_maybe_push_turn_run_wall_clock_skips_when_no_active_turn() -> Non
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_push_turn_run_wall_clock_replays_running() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
from nanobot.session import webui_turns as wth
|
||||
@ -1228,7 +1245,7 @@ async def test_maybe_push_turn_run_wall_clock_replays_running() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_session_updated_emits_session_updated_event() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
|
||||
@ -1247,7 +1264,7 @@ async def test_send_session_updated_emits_session_updated_event() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_session_updated_includes_scope_when_present() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
|
||||
@ -1266,7 +1283,7 @@ async def test_send_session_updated_includes_scope_when_present() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_non_connection_closed_exception_is_raised() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.send.side_effect = RuntimeError("unexpected")
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
@ -1279,7 +1296,7 @@ async def test_send_non_connection_closed_exception_is_raised() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_missing_connection_is_noop() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus, gateway=_basic_handler(bus))
|
||||
# No exception, no error — just a no-op
|
||||
await channel.send_delta("nonexistent", "chunk", {"_stream_delta": True, "_stream_id": "s1"})
|
||||
|
||||
@ -1287,7 +1304,7 @@ async def test_send_delta_missing_connection_is_noop() -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_is_idempotent() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, http_handler=_basic_handler(bus))
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=_basic_handler(bus))
|
||||
# stop() before start() should not raise
|
||||
await channel.stop()
|
||||
await channel.stop()
|
||||
@ -1448,7 +1465,7 @@ async def test_settings_api_returns_safe_subset_and_updates_whitelist(
|
||||
)
|
||||
|
||||
channel = _ch(bus, port=port)
|
||||
channel._api_tokens["tok"] = time.monotonic() + 300
|
||||
channel.gateway.tokens.api_tokens["tok"] = time.monotonic() + 300
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
@ -1733,7 +1750,7 @@ async def test_settings_api_returns_safe_subset_and_updates_whitelist(
|
||||
async def test_commands_api_returns_slash_command_metadata(bus: MagicMock) -> None:
|
||||
port = 29892
|
||||
channel = _ch(bus, port=port)
|
||||
channel._api_tokens["tok"] = time.monotonic() + 300
|
||||
channel.gateway.tokens.api_tokens["tok"] = time.monotonic() + 300
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
@ -1771,7 +1788,7 @@ async def test_bootstrap_exposes_native_surface(bus: MagicMock) -> None:
|
||||
"websocketRequiresToken": True,
|
||||
},
|
||||
bus,
|
||||
http_handler=_basic_handler(bus, runtime_surface="native", runtime_capabilities_overrides={"can_pick_folder": True}),
|
||||
gateway=_basic_handler(bus, runtime_surface="native", runtime_capabilities_overrides={"can_pick_folder": True}),
|
||||
)
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
@ -1941,8 +1958,9 @@ async def test_token_issue_rejects_when_at_capacity(bus: MagicMock) -> None:
|
||||
|
||||
try:
|
||||
# Fill issued tokens to capacity
|
||||
channel._http.issued_tokens = {
|
||||
f"nbwt_fill_{i}": time.monotonic() + 300 for i in range(channel._http._MAX_ISSUED_TOKENS)
|
||||
channel.gateway.tokens.issued_tokens = {
|
||||
f"nbwt_fill_{i}": time.monotonic() + 300
|
||||
for i in range(channel.gateway.tokens.max_tokens)
|
||||
}
|
||||
|
||||
resp = await _http_get(
|
||||
@ -2299,10 +2317,8 @@ def test_sessions_list_includes_active_run_started_at() -> None:
|
||||
from nanobot.session import webui_turns as wth
|
||||
|
||||
bus = MagicMock()
|
||||
channel = _ch(bus)
|
||||
channel._api_tokens["tok"] = time.monotonic() + 300.0
|
||||
channel._session_manager = MagicMock()
|
||||
channel._session_manager.list_sessions.return_value = [
|
||||
session_manager = MagicMock()
|
||||
session_manager.list_sessions.return_value = [
|
||||
{
|
||||
"key": "websocket:chat-1",
|
||||
"created_at": "2026-05-19T10:00:00Z",
|
||||
@ -2317,19 +2333,25 @@ def test_sessions_list_includes_active_run_started_at() -> None:
|
||||
"updated_at": "2026-05-19T10:01:00Z",
|
||||
},
|
||||
]
|
||||
channel = WebSocketChannel(
|
||||
{"enabled": True, "allowFrom": ["*"]},
|
||||
bus,
|
||||
gateway=_basic_handler(bus, session_manager=session_manager),
|
||||
)
|
||||
channel.gateway.tokens.api_tokens["tok"] = time.monotonic() + 300.0
|
||||
|
||||
wth._WEBSOCKET_TURN_WALL_STARTED_AT.clear()
|
||||
try:
|
||||
wth._WEBSOCKET_TURN_WALL_STARTED_AT["chat-1"] = 1_700_000_000.0
|
||||
req = Request("/api/sessions", Headers([("Authorization", "Bearer tok")]))
|
||||
resp = channel._handle_sessions_list(req)
|
||||
resp = channel.gateway.http._handle_sessions_list(req)
|
||||
finally:
|
||||
wth._WEBSOCKET_TURN_WALL_STARTED_AT.clear()
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = json.loads(resp.body.decode())
|
||||
workspace_scope = body["sessions"][0].pop("workspace_scope")
|
||||
assert workspace_scope["project_path"] == str(channel._workspace_path)
|
||||
assert workspace_scope["project_path"] == str(channel.gateway.media.workspace_path)
|
||||
assert workspace_scope["access_mode"] in {"restricted", "full"}
|
||||
assert body["sessions"] == [
|
||||
{
|
||||
@ -2376,10 +2398,10 @@ def test_handle_webui_thread_get_returns_json(tmp_path, monkeypatch) -> None:
|
||||
append_transcript_object(key, {"event": "user", "chat_id": "c1", "text": "hi"})
|
||||
bus = MagicMock()
|
||||
channel = _ch(bus)
|
||||
channel._api_tokens["tok"] = time.monotonic() + 300.0
|
||||
channel.gateway.tokens.api_tokens["tok"] = time.monotonic() + 300.0
|
||||
enc = quote(key, safe="")
|
||||
req = Request(f"/api/sessions/{enc}/webui-thread", Headers([("Authorization", "Bearer tok")]))
|
||||
resp = channel._handle_webui_thread_get(req, enc)
|
||||
resp = channel.gateway.http._handle_webui_thread_get(req, enc)
|
||||
assert resp.status_code == 200
|
||||
body = json.loads(resp.body.decode())
|
||||
assert body["sessionKey"] == key
|
||||
|
||||
@ -21,7 +21,7 @@ from nanobot.channels.websocket import (
|
||||
WebSocketConfig,
|
||||
_extract_data_url_mime,
|
||||
)
|
||||
from nanobot.webui.ws_http import GatewayHTTPHandler
|
||||
from nanobot.webui.gateway_services import build_gateway_services
|
||||
|
||||
|
||||
def _tiny_png_data_url() -> str:
|
||||
@ -45,17 +45,18 @@ def _make_channel() -> WebSocketChannel:
|
||||
bus.publish_inbound = AsyncMock()
|
||||
cfg = {"enabled": True, "allowFrom": ["*"], "websocketRequiresToken": False}
|
||||
parsed = WebSocketConfig.model_validate(cfg)
|
||||
handler = GatewayHTTPHandler(
|
||||
gateway = build_gateway_services(
|
||||
config=parsed,
|
||||
bus=bus,
|
||||
session_manager=None,
|
||||
static_dist_path=None,
|
||||
workspace_path=Path.cwd(),
|
||||
default_restrict_to_workspace=False,
|
||||
runtime_model_name=None,
|
||||
runtime_surface="browser",
|
||||
runtime_capabilities_overrides=None,
|
||||
bus=bus,
|
||||
)
|
||||
channel = WebSocketChannel(cfg, bus, http_handler=handler)
|
||||
channel = WebSocketChannel(cfg, bus, gateway=gateway)
|
||||
channel._handle_message = AsyncMock() # type: ignore[method-assign]
|
||||
return channel
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ import pytest
|
||||
|
||||
from nanobot.channels.websocket import WebSocketChannel, WebSocketConfig
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
from nanobot.webui.ws_http import GatewayHTTPHandler
|
||||
from nanobot.webui.gateway_services import GatewayServices, build_gateway_services
|
||||
|
||||
_PORT = 29900
|
||||
|
||||
@ -25,18 +25,19 @@ def _make_handler(
|
||||
session_manager: SessionManager | None = None,
|
||||
static_dist_path: Path | None = None,
|
||||
runtime_model_name: Any | None = None,
|
||||
) -> GatewayHTTPHandler:
|
||||
) -> GatewayServices:
|
||||
config = WebSocketConfig.model_validate(cfg) if isinstance(cfg, dict) else cfg
|
||||
workspace = Path.cwd()
|
||||
return GatewayHTTPHandler(
|
||||
return build_gateway_services(
|
||||
config=config,
|
||||
bus=bus,
|
||||
session_manager=session_manager,
|
||||
static_dist_path=static_dist_path,
|
||||
workspace_path=workspace,
|
||||
default_restrict_to_workspace=False,
|
||||
runtime_model_name=runtime_model_name,
|
||||
runtime_surface="browser",
|
||||
runtime_capabilities_overrides=None,
|
||||
bus=bus,
|
||||
)
|
||||
|
||||
|
||||
@ -58,13 +59,13 @@ def _ch(
|
||||
"websocketRequiresToken": False,
|
||||
}
|
||||
cfg.update(extra)
|
||||
http_handler = _make_handler(
|
||||
gateway = _make_handler(
|
||||
cfg, bus,
|
||||
session_manager=session_manager,
|
||||
static_dist_path=static_dist_path,
|
||||
runtime_model_name=runtime_model_name,
|
||||
)
|
||||
return WebSocketChannel(cfg, bus, http_handler=http_handler)
|
||||
return WebSocketChannel(cfg, bus, gateway=gateway)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@ -729,20 +730,20 @@ async def test_api_token_pool_purges_expired(bus: MagicMock, tmp_path: Path) ->
|
||||
channel = _ch(bus, session_manager=sm, port=29908)
|
||||
# Don't start a server — directly inject and validate.
|
||||
import time as _time
|
||||
channel._http.api_tokens["expired"] = _time.monotonic() - 1
|
||||
channel._http.api_tokens["live"] = _time.monotonic() + 60
|
||||
channel.gateway.tokens.api_tokens["expired"] = _time.monotonic() - 1
|
||||
channel.gateway.tokens.api_tokens["live"] = _time.monotonic() + 60
|
||||
|
||||
class _FakeReq:
|
||||
path = "/api/sessions"
|
||||
headers = {"Authorization": "Bearer expired"}
|
||||
|
||||
assert channel._http.check_api_token(_FakeReq()) is False
|
||||
assert channel.gateway.tokens.check_api_token(_FakeReq()) is False
|
||||
|
||||
class _LiveReq:
|
||||
path = "/api/sessions"
|
||||
headers = {"Authorization": "Bearer live"}
|
||||
|
||||
assert channel._http.check_api_token(_LiveReq()) is True
|
||||
assert channel.gateway.tokens.check_api_token(_LiveReq()) is True
|
||||
|
||||
|
||||
class _FakeConn:
|
||||
@ -797,7 +798,7 @@ def test_wildcard_ipv6_without_auth_raises(bus: MagicMock) -> None:
|
||||
|
||||
def test_wildcard_ipv6_with_secret_is_valid(bus: MagicMock) -> None:
|
||||
channel = _ch(bus, host="::", tokenIssueSecret="s3cret")
|
||||
resp = channel._handle_bootstrap(
|
||||
resp = channel.gateway.http._handle_bootstrap(
|
||||
_REMOTE, _FakeReq({"X-Nanobot-Auth": "s3cret"})
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
@ -806,7 +807,7 @@ def test_wildcard_ipv6_with_secret_is_valid(bus: MagicMock) -> None:
|
||||
def test_bootstrap_accepts_static_token_as_secret(bus: MagicMock) -> None:
|
||||
"""When only token (not token_issue_secret) is set, bootstrap accepts it."""
|
||||
channel = _ch(bus, host="0.0.0.0", token="static-tok")
|
||||
resp = channel._handle_bootstrap(
|
||||
resp = channel.gateway.http._handle_bootstrap(
|
||||
_REMOTE, _FakeReq({"Authorization": "Bearer static-tok"})
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
@ -816,7 +817,7 @@ def test_bootstrap_accepts_static_token_as_secret(bus: MagicMock) -> None:
|
||||
|
||||
def test_bootstrap_ws_url_uses_forwarded_https_host(bus: MagicMock) -> None:
|
||||
channel = _ch(bus, host="127.0.0.1", port=29931)
|
||||
resp = channel._handle_bootstrap(
|
||||
resp = channel.gateway.http._handle_bootstrap(
|
||||
_LOCAL,
|
||||
_FakeReq({"Host": "nanobot.example", "X-Forwarded-Proto": "https"}),
|
||||
)
|
||||
@ -827,7 +828,7 @@ def test_bootstrap_ws_url_uses_forwarded_https_host(bus: MagicMock) -> None:
|
||||
|
||||
def test_localhost_without_auth_is_valid(bus: MagicMock) -> None:
|
||||
channel = _ch(bus, host="127.0.0.1")
|
||||
resp = channel._handle_bootstrap(_LOCAL, _NO_HEADERS)
|
||||
resp = channel.gateway.http._handle_bootstrap(_LOCAL, _NO_HEADERS)
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
@ -837,7 +838,7 @@ def test_bootstrap_prefers_runtime_model_name(bus: MagicMock, monkeypatch: pytes
|
||||
lambda: "from-disk",
|
||||
)
|
||||
channel = _ch(bus, host="127.0.0.1", runtime_model_name=lambda: " live/model ")
|
||||
resp = channel._handle_bootstrap(_LOCAL, _NO_HEADERS)
|
||||
resp = channel.gateway.http._handle_bootstrap(_LOCAL, _NO_HEADERS)
|
||||
assert resp.status_code == 200
|
||||
body = json.loads(resp.body)
|
||||
assert body["model_name"] == "live/model"
|
||||
@ -849,7 +850,7 @@ def test_bootstrap_falls_back_when_runtime_returns_empty(bus: MagicMock, monkeyp
|
||||
lambda: "from-disk",
|
||||
)
|
||||
channel = _ch(bus, host="127.0.0.1", runtime_model_name=lambda: " ")
|
||||
resp = channel._handle_bootstrap(_LOCAL, _NO_HEADERS)
|
||||
resp = channel.gateway.http._handle_bootstrap(_LOCAL, _NO_HEADERS)
|
||||
assert resp.status_code == 200
|
||||
body = json.loads(resp.body)
|
||||
assert body["model_name"] == "from-disk"
|
||||
@ -865,7 +866,7 @@ def test_bootstrap_falls_back_when_runtime_raises(bus: MagicMock, monkeypatch: p
|
||||
raise RuntimeError("resolver failed")
|
||||
|
||||
channel = _ch(bus, host="127.0.0.1", runtime_model_name=boom)
|
||||
resp = channel._handle_bootstrap(_LOCAL, _NO_HEADERS)
|
||||
resp = channel.gateway.http._handle_bootstrap(_LOCAL, _NO_HEADERS)
|
||||
assert resp.status_code == 200
|
||||
body = json.loads(resp.body)
|
||||
assert body["model_name"] == "from-disk"
|
||||
@ -873,7 +874,7 @@ def test_bootstrap_falls_back_when_runtime_raises(bus: MagicMock, monkeypatch: p
|
||||
|
||||
def test_bootstrap_rejects_wrong_secret(bus: MagicMock) -> None:
|
||||
channel = _ch(bus, host="0.0.0.0", tokenIssueSecret="correct")
|
||||
resp = channel._handle_bootstrap(
|
||||
resp = channel.gateway.http._handle_bootstrap(
|
||||
_REMOTE, _FakeReq({"Authorization": "Bearer wrong"})
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
@ -881,7 +882,7 @@ def test_bootstrap_rejects_wrong_secret(bus: MagicMock) -> None:
|
||||
|
||||
def test_bootstrap_accepts_remote_with_valid_secret(bus: MagicMock) -> None:
|
||||
channel = _ch(bus, host="0.0.0.0", tokenIssueSecret="s3cret")
|
||||
resp = channel._handle_bootstrap(
|
||||
resp = channel.gateway.http._handle_bootstrap(
|
||||
_REMOTE, _FakeReq({"Authorization": "Bearer s3cret"})
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
@ -891,7 +892,7 @@ def test_bootstrap_accepts_remote_with_valid_secret(bus: MagicMock) -> None:
|
||||
|
||||
def test_bootstrap_accepts_x_nanobot_auth_header(bus: MagicMock) -> None:
|
||||
channel = _ch(bus, host="0.0.0.0", tokenIssueSecret="s3cret")
|
||||
resp = channel._handle_bootstrap(
|
||||
resp = channel.gateway.http._handle_bootstrap(
|
||||
_REMOTE, _FakeReq({"X-Nanobot-Auth": "s3cret"})
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
@ -900,5 +901,5 @@ def test_bootstrap_accepts_x_nanobot_auth_header(bus: MagicMock) -> None:
|
||||
def test_bootstrap_secret_also_enforced_on_localhost(bus: MagicMock) -> None:
|
||||
"""When secret is set, even localhost must provide it (reverse-proxy safety)."""
|
||||
channel = _ch(bus, host="0.0.0.0", tokenIssueSecret="s3cret")
|
||||
resp = channel._handle_bootstrap(_LOCAL, _NO_HEADERS)
|
||||
resp = channel.gateway.http._handle_bootstrap(_LOCAL, _NO_HEADERS)
|
||||
assert resp.status_code == 401
|
||||
|
||||
@ -7,19 +7,18 @@ multi-client scenarios, edge cases, and realistic usage patterns.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import websockets
|
||||
|
||||
from nanobot.channels.websocket import WebSocketChannel, WebSocketConfig
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.webui.ws_http import GatewayHTTPHandler
|
||||
from ws_test_client import WsTestClient, issue_token, issue_token_ok
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.channels.websocket import WebSocketChannel, WebSocketConfig
|
||||
from nanobot.webui.gateway_services import build_gateway_services
|
||||
|
||||
|
||||
def _ch(bus: Any, port: int, **kw: Any) -> WebSocketChannel:
|
||||
cfg: dict[str, Any] = {
|
||||
@ -32,17 +31,18 @@ def _ch(bus: Any, port: int, **kw: Any) -> WebSocketChannel:
|
||||
}
|
||||
cfg.update(kw)
|
||||
parsed = WebSocketConfig.model_validate(cfg)
|
||||
handler = GatewayHTTPHandler(
|
||||
gateway = build_gateway_services(
|
||||
config=parsed,
|
||||
bus=bus,
|
||||
session_manager=None,
|
||||
static_dist_path=None,
|
||||
workspace_path=Path.cwd(),
|
||||
default_restrict_to_workspace=False,
|
||||
runtime_model_name=None,
|
||||
runtime_surface="browser",
|
||||
runtime_capabilities_overrides=None,
|
||||
bus=bus,
|
||||
)
|
||||
return WebSocketChannel(cfg, bus, http_handler=handler)
|
||||
return WebSocketChannel(cfg, bus, gateway=gateway)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@ -67,7 +67,8 @@ async def test_ready_event_fields(bus: MagicMock) -> None:
|
||||
assert len(r.chat_id) == 36
|
||||
assert r.client_id == "c1"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -80,7 +81,8 @@ async def test_anonymous_client_gets_generated_id(bus: MagicMock) -> None:
|
||||
r = await c.recv_ready()
|
||||
assert r.client_id.startswith("anon-")
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -93,7 +95,8 @@ async def test_each_connection_unique_chat_id(bus: MagicMock) -> None:
|
||||
async with WsTestClient("ws://127.0.0.1:29903/", client_id="b") as c2:
|
||||
assert (await c1.recv_ready()).chat_id != (await c2.recv_ready()).chat_id
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
# -- Inbound messages (client -> server) ----------------------------------
|
||||
@ -113,7 +116,8 @@ async def test_plain_text(bus: MagicMock) -> None:
|
||||
assert inbound.content == "hello world"
|
||||
assert inbound.sender_id == "p"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -128,7 +132,8 @@ async def test_json_content_field(bus: MagicMock) -> None:
|
||||
await asyncio.sleep(0.1)
|
||||
assert bus.publish_inbound.call_args[0][0].content == "structured"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -146,7 +151,8 @@ async def test_json_text_and_message_fields(bus: MagicMock) -> None:
|
||||
await asyncio.sleep(0.1)
|
||||
assert bus.publish_inbound.call_args[0][0].content == "via message"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -162,7 +168,8 @@ async def test_empty_payload_ignored(bus: MagicMock) -> None:
|
||||
await asyncio.sleep(0.1)
|
||||
bus.publish_inbound.assert_not_awaited()
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -179,7 +186,8 @@ async def test_messages_preserve_order(bus: MagicMock) -> None:
|
||||
contents = [call[0][0].content for call in bus.publish_inbound.call_args_list]
|
||||
assert contents == [f"msg-{i}" for i in range(5)]
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
# -- Outbound messages (server -> client) ---------------------------------
|
||||
@ -199,7 +207,8 @@ async def test_server_send_message(bus: MagicMock) -> None:
|
||||
msg = await c.recv_message()
|
||||
assert msg.text == "reply"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -238,7 +247,8 @@ async def test_server_send_tags_tool_hint_with_kind(bus: MagicMock) -> None:
|
||||
prog = await c.recv_message()
|
||||
assert prog.raw.get("kind") == "progress"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -258,7 +268,8 @@ async def test_server_send_with_media_and_reply(bus: MagicMock) -> None:
|
||||
assert msg.media == ["/tmp/a.png"]
|
||||
assert msg.reply_to == "m1"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
# -- Streaming ------------------------------------------------------------
|
||||
@ -282,7 +293,8 @@ async def test_streaming_deltas_and_end(bus: MagicMock) -> None:
|
||||
ends = [m for m in msgs if m.event == "stream_end"]
|
||||
assert len(ends) == 1
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -306,7 +318,8 @@ async def test_interleaved_streams(bus: MagicMock) -> None:
|
||||
assert sa == "A1A2"
|
||||
assert sb == "B1B2"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
# -- Multi-client ---------------------------------------------------------
|
||||
@ -330,7 +343,8 @@ async def test_independent_sessions(bus: MagicMock) -> None:
|
||||
))
|
||||
assert (await c2.recv_message()).text == "for-u2"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -348,7 +362,8 @@ async def test_disconnected_client_cleanup(bus: MagicMock) -> None:
|
||||
))
|
||||
assert chat_id not in ch._subs
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
# -- Authentication -------------------------------------------------------
|
||||
@ -363,7 +378,8 @@ async def test_static_token_accepted(bus: MagicMock) -> None:
|
||||
async with WsTestClient("ws://127.0.0.1:29915/", client_id="a", token="secret") as c:
|
||||
assert (await c.recv_ready()).client_id == "a"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -377,7 +393,8 @@ async def test_static_token_rejected(bus: MagicMock) -> None:
|
||||
pass
|
||||
assert exc.value.response.status_code == 401
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -411,7 +428,8 @@ async def test_token_issue_full_flow(bus: MagicMock) -> None:
|
||||
pass
|
||||
assert exc.value.response.status_code == 401
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
# -- Path routing ---------------------------------------------------------
|
||||
@ -426,7 +444,8 @@ async def test_custom_path(bus: MagicMock) -> None:
|
||||
async with WsTestClient("ws://127.0.0.1:29918/my-chat", client_id="p") as c:
|
||||
assert (await c.recv_ready()).event == "ready"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -440,7 +459,8 @@ async def test_wrong_path_404(bus: MagicMock) -> None:
|
||||
pass
|
||||
assert exc.value.response.status_code == 404
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -452,7 +472,8 @@ async def test_trailing_slash_normalized(bus: MagicMock) -> None:
|
||||
async with WsTestClient("ws://127.0.0.1:29920/ws/", client_id="s") as c:
|
||||
assert (await c.recv_ready()).event == "ready"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
# -- Edge cases -----------------------------------------------------------
|
||||
@ -471,7 +492,8 @@ async def test_large_message(bus: MagicMock) -> None:
|
||||
await asyncio.sleep(0.2)
|
||||
assert bus.publish_inbound.call_args[0][0].content == big
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -491,7 +513,8 @@ async def test_unicode_roundtrip(bus: MagicMock) -> None:
|
||||
))
|
||||
assert (await c.recv_message()).text == text
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -513,7 +536,8 @@ async def test_rapid_fire(bus: MagicMock) -> None:
|
||||
received = [(await c.recv_message()).text for _ in range(50)]
|
||||
assert received == [f"out-{i}" for i in range(50)]
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -528,4 +552,5 @@ async def test_invalid_json_as_plain_text(bus: MagicMock) -> None:
|
||||
await asyncio.sleep(0.1)
|
||||
assert bus.publish_inbound.call_args[0][0].content == "{broken json"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
await ch.stop()
|
||||
await t
|
||||
|
||||
@ -2,8 +2,8 @@
|
||||
integration on ``/api/sessions/<key>/messages``.
|
||||
|
||||
The route is the return path for images attached to persisted user turns:
|
||||
:meth:`WebSocketChannel._sign_media_path` mints URLs during session reads,
|
||||
and :meth:`WebSocketChannel._handle_media_fetch` serves the bytes back.
|
||||
:meth:`WebSocketChannel.gateway.media.sign_media_path` mints URLs during session reads,
|
||||
and :meth:`GatewayHTTPHandler._handle_media_fetch` serves the bytes back.
|
||||
These tests cover the two halves end-to-end plus the adversarial edges
|
||||
(bad signatures, ``..`` traversal, non-existent files, non-image types).
|
||||
"""
|
||||
@ -22,13 +22,12 @@ import httpx
|
||||
import pytest
|
||||
|
||||
from nanobot.channels.websocket import WebSocketChannel, WebSocketConfig
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
from nanobot.webui.gateway_services import build_gateway_services
|
||||
from nanobot.webui.media_api import (
|
||||
b64url_decode,
|
||||
b64url_encode,
|
||||
)
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
from nanobot.webui.ws_http import GatewayHTTPHandler
|
||||
|
||||
|
||||
# PNG magic bytes + a couple of sentinel bytes so we can verify byte-for-byte
|
||||
# round-trip of the served payload. Stays under mimetype + size limits.
|
||||
@ -57,17 +56,18 @@ def _ch(
|
||||
"websocketRequiresToken": False,
|
||||
}
|
||||
parsed = WebSocketConfig.model_validate(cfg)
|
||||
http_handler = GatewayHTTPHandler(
|
||||
gateway = build_gateway_services(
|
||||
config=parsed,
|
||||
bus=bus,
|
||||
session_manager=session_manager,
|
||||
static_dist_path=None,
|
||||
workspace_path=workspace_path or Path.cwd(),
|
||||
default_restrict_to_workspace=False,
|
||||
runtime_model_name=None,
|
||||
runtime_surface="browser",
|
||||
runtime_capabilities_overrides=None,
|
||||
bus=bus,
|
||||
)
|
||||
return WebSocketChannel(cfg, bus, http_handler=http_handler)
|
||||
return WebSocketChannel(cfg, bus, gateway=gateway)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@ -95,7 +95,7 @@ async def _http_get(
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _sign_media_path: the URL minter
|
||||
# gateway.media.sign_media_path: the URL minter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@ -114,11 +114,11 @@ def test_sign_media_path_rejects_paths_outside_media_root(
|
||||
media = tmp_path / "media"
|
||||
media.mkdir()
|
||||
channel = _ch(bus, port=0)
|
||||
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
|
||||
assert channel._sign_media_path(outside) is None
|
||||
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
|
||||
assert channel.gateway.media.sign_media_path(outside) is None
|
||||
# Traversal via the media root is also rejected — the resolve() step
|
||||
# normalises ``..`` out before the relative_to check.
|
||||
assert channel._sign_media_path(media / ".." / "secrets" / "cred.txt") is None
|
||||
assert channel.gateway.media.sign_media_path(media / ".." / "secrets" / "cred.txt") is None
|
||||
|
||||
|
||||
def test_sign_media_path_round_trips_via_hmac(
|
||||
@ -129,13 +129,13 @@ def test_sign_media_path_round_trips_via_hmac(
|
||||
media.mkdir()
|
||||
(media / "a.png").write_bytes(_PNG_BYTES)
|
||||
channel = _ch(bus, port=0)
|
||||
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
|
||||
url = channel._sign_media_path(media / "a.png")
|
||||
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
|
||||
url = channel.gateway.media.sign_media_path(media / "a.png")
|
||||
assert url is not None
|
||||
assert url.startswith("/api/media/")
|
||||
sig, payload = url[len("/api/media/"):].split("/", 1)
|
||||
expected = hmac.new(
|
||||
channel._media_secret, payload.encode("ascii"), hashlib.sha256
|
||||
channel.gateway.media.secret, payload.encode("ascii"), hashlib.sha256
|
||||
).digest()[:16]
|
||||
assert b64url_decode(sig) == expected
|
||||
# The payload decodes back to the *relative* path — no absolute-path leaks.
|
||||
@ -152,8 +152,8 @@ def test_local_markdown_image_is_staged_and_rewritten(
|
||||
media = tmp_path / "media"
|
||||
channel = _ch(bus, workspace_path=workspace, port=0)
|
||||
|
||||
with patch("nanobot.webui.ws_http.get_media_dir", side_effect=_fake_media_dir(media)):
|
||||
rewritten = channel._rewrite_local_markdown_images(
|
||||
with patch("nanobot.webui.media_gateway.get_media_dir", side_effect=_fake_media_dir(media)):
|
||||
rewritten = channel.gateway.media.rewrite_local_markdown_images(
|
||||
"The result:\n"
|
||||
)
|
||||
|
||||
@ -174,8 +174,8 @@ def test_local_markdown_video_is_staged_and_rewritten(
|
||||
media = tmp_path / "media"
|
||||
channel = _ch(bus, workspace_path=workspace, port=0)
|
||||
|
||||
with patch("nanobot.webui.ws_http.get_media_dir", side_effect=_fake_media_dir(media)):
|
||||
rewritten = channel._rewrite_local_markdown_images(
|
||||
with patch("nanobot.webui.media_gateway.get_media_dir", side_effect=_fake_media_dir(media)):
|
||||
rewritten = channel.gateway.media.rewrite_local_markdown_images(
|
||||
"The result:\n"
|
||||
)
|
||||
|
||||
@ -197,8 +197,8 @@ def test_local_markdown_image_rejects_workspace_escape(
|
||||
channel = _ch(bus, workspace_path=workspace, port=0)
|
||||
text = ""
|
||||
|
||||
with patch("nanobot.webui.ws_http.get_media_dir", side_effect=_fake_media_dir(media)):
|
||||
assert channel._rewrite_local_markdown_images(text) == text
|
||||
with patch("nanobot.webui.media_gateway.get_media_dir", side_effect=_fake_media_dir(media)):
|
||||
assert channel.gateway.media.rewrite_local_markdown_images(text) == text
|
||||
|
||||
assert not (media / "websocket").exists()
|
||||
|
||||
@ -219,8 +219,8 @@ async def test_media_route_serves_signed_file(
|
||||
target.write_bytes(_PNG_BYTES)
|
||||
|
||||
channel = _ch(bus, port=29920)
|
||||
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
|
||||
url_path = channel._sign_media_path(target)
|
||||
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
|
||||
url_path = channel.gateway.media.sign_media_path(target)
|
||||
assert url_path is not None
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
@ -252,8 +252,8 @@ async def test_media_route_serves_video_byte_ranges(
|
||||
target.write_bytes(b"0123456789")
|
||||
|
||||
channel = _ch(bus, port=29927)
|
||||
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
|
||||
url_path = channel._sign_media_path(target)
|
||||
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
|
||||
url_path = channel.gateway.media.sign_media_path(target)
|
||||
assert url_path is not None
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
@ -284,8 +284,8 @@ async def test_media_route_serves_suffix_video_byte_ranges(
|
||||
target.write_bytes(b"0123456789")
|
||||
|
||||
channel = _ch(bus, port=29928)
|
||||
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
|
||||
url_path = channel._sign_media_path(target)
|
||||
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
|
||||
url_path = channel.gateway.media.sign_media_path(target)
|
||||
assert url_path is not None
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
@ -313,8 +313,8 @@ async def test_media_route_rejects_unsatisfiable_byte_range(
|
||||
target.write_bytes(b"0123456789")
|
||||
|
||||
channel = _ch(bus, port=29929)
|
||||
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
|
||||
url_path = channel._sign_media_path(target)
|
||||
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
|
||||
url_path = channel.gateway.media.sign_media_path(target)
|
||||
assert url_path is not None
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
@ -339,15 +339,15 @@ async def test_media_route_rejects_bad_signature(
|
||||
"""A payload re-signed with a different secret must 401.
|
||||
|
||||
Protects against a restart: old URLs baked into a stale tab become
|
||||
un-forgeable once ``_media_secret`` regenerates.
|
||||
un-forgeable once ``gateway.media.secret`` regenerates.
|
||||
"""
|
||||
media = tmp_path / "media"
|
||||
media.mkdir()
|
||||
(media / "f.png").write_bytes(_PNG_BYTES)
|
||||
|
||||
channel = _ch(bus, port=29921)
|
||||
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
|
||||
good = channel._sign_media_path(media / "f.png")
|
||||
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
|
||||
good = channel.gateway.media.sign_media_path(media / "f.png")
|
||||
assert good is not None
|
||||
_, payload = good[len("/api/media/"):].split("/", 1)
|
||||
# Forge a sig with a *different* secret.
|
||||
@ -385,11 +385,11 @@ async def test_media_route_rejects_path_traversal_payload(
|
||||
# Hand-craft a traversal payload the legit signer would refuse to mint.
|
||||
payload = b64url_encode(b"../secret.txt")
|
||||
mac = hmac.new(
|
||||
channel._media_secret, payload.encode("ascii"), hashlib.sha256
|
||||
channel.gateway.media.secret, payload.encode("ascii"), hashlib.sha256
|
||||
).digest()[:16]
|
||||
url = f"/api/media/{b64url_encode(mac)}/{payload}"
|
||||
|
||||
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
|
||||
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
@ -413,8 +413,8 @@ async def test_media_route_404s_missing_file(
|
||||
target.write_bytes(_PNG_BYTES)
|
||||
|
||||
channel = _ch(bus, port=29923)
|
||||
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
|
||||
url_path = channel._sign_media_path(target)
|
||||
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
|
||||
url_path = channel.gateway.media.sign_media_path(target)
|
||||
assert url_path is not None
|
||||
target.unlink() # the file vanishes between signing and fetching
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
@ -441,10 +441,10 @@ async def test_media_route_degrades_non_image_to_octet_stream(
|
||||
(media / "scary.html").write_bytes(b"<script>alert(1)</script>")
|
||||
|
||||
channel = _ch(bus, port=29924)
|
||||
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
|
||||
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
|
||||
payload = b64url_encode(b"scary.html")
|
||||
mac = hmac.new(
|
||||
channel._media_secret, payload.encode("ascii"), hashlib.sha256
|
||||
channel.gateway.media.secret, payload.encode("ascii"), hashlib.sha256
|
||||
).digest()[:16]
|
||||
url = f"/api/media/{b64url_encode(mac)}/{payload}"
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
@ -472,8 +472,8 @@ async def test_media_route_serves_svg_with_strict_csp(
|
||||
target.write_text("<svg xmlns='http://www.w3.org/2000/svg'><script>alert(1)</script></svg>")
|
||||
|
||||
channel = _ch(bus, port=29928)
|
||||
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
|
||||
url_path = channel._sign_media_path(target)
|
||||
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
|
||||
url_path = channel.gateway.media.sign_media_path(target)
|
||||
assert url_path is not None
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
@ -513,7 +513,7 @@ async def test_session_messages_exposes_signed_media_urls(
|
||||
sm.save(sess)
|
||||
|
||||
channel = _ch(bus, session_manager=sm, port=29925)
|
||||
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
|
||||
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
@ -558,7 +558,7 @@ async def test_session_messages_skips_vanished_media(
|
||||
sm.save(sess)
|
||||
|
||||
channel = _ch(bus, session_manager=sm, port=29926)
|
||||
with patch("nanobot.webui.ws_http.get_media_dir", return_value=media):
|
||||
with patch("nanobot.webui.media_gateway.get_media_dir", return_value=media):
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user