mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 22:34:06 +00:00
refactor: extract GatewayHTTPHandler from WebSocketChannel
Extract all HTTP route handling (bootstrap, sessions, settings, media, commands, sidebar state, static serving, token management) into a new GatewayHTTPHandler class in nanobot/channels/ws_http.py. WebSocketChannel is reduced from 1907 to 1372 lines (-28%), retaining only WebSocket connection management and message dispatch. No behavior change. 3730 tests pass, 0 failures. Shared HTTP utility functions (path parsing, response builders, auth helpers) now live in ws_http.py with websocket.py importing from there, avoiding circular dependencies. Backwards-compat property aliases on WebSocketChannel ensure existing tests continue to work without modification.
This commit is contained in:
parent
92fe40a690
commit
1a585288b2
@ -7,11 +7,8 @@ import email.utils
|
||||
import hmac
|
||||
import http
|
||||
import json
|
||||
import mimetypes
|
||||
import re
|
||||
import secrets
|
||||
import ssl
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from contextlib import suppress
|
||||
@ -31,7 +28,7 @@ from websockets.http11 import Response
|
||||
from nanobot.bus.events import OUTBOUND_META_AGENT_UI, OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.command.builtin import builtin_command_palette
|
||||
from nanobot.channels.ws_http import GatewayHTTPHandler
|
||||
from nanobot.config.paths import get_media_dir, get_workspace_path
|
||||
from nanobot.config.schema import Base
|
||||
from nanobot.security.workspace_access import (
|
||||
@ -44,32 +41,9 @@ from nanobot.utils.media_decode import (
|
||||
FileSizeExceeded,
|
||||
save_base64_data_url,
|
||||
)
|
||||
from nanobot.utils.subagent_channel_display import scrub_subagent_messages_for_channel
|
||||
from nanobot.webui.cli_apps_api import normalize_cli_app_mentions
|
||||
from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions
|
||||
from nanobot.webui.media_api import (
|
||||
attach_signed_media_urls,
|
||||
serve_signed_media,
|
||||
sign_media_path,
|
||||
sign_or_stage_media_path,
|
||||
signed_media_attachments,
|
||||
)
|
||||
from nanobot.webui.settings_api import runtime_capabilities
|
||||
from nanobot.webui.settings_routes import WebUISettingsRouter
|
||||
from nanobot.webui.sidebar_state import (
|
||||
read_webui_sidebar_state,
|
||||
write_webui_sidebar_state,
|
||||
)
|
||||
from nanobot.webui.thread_disk import delete_webui_thread
|
||||
from nanobot.webui.transcript import (
|
||||
append_transcript_object,
|
||||
build_webui_thread_response,
|
||||
rewrite_local_markdown_images,
|
||||
)
|
||||
from nanobot.webui.websocket_logging import websockets_server_logger
|
||||
from nanobot.webui.workspaces import (
|
||||
WebUIWorkspaceController,
|
||||
)
|
||||
from nanobot.webui.transcript import append_transcript_object
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.session.manager import SessionManager
|
||||
@ -500,51 +474,36 @@ class WebSocketChannel(BaseChannel):
|
||||
self._conn_chats: dict[Any, set[str]] = {}
|
||||
# connection -> default chat_id for legacy frames that omit routing.
|
||||
self._conn_default: dict[Any, str] = {}
|
||||
# Single-use tokens consumed at WebSocket handshake.
|
||||
self._issued_tokens: dict[str, float] = {}
|
||||
# Multi-use tokens for HTTP routes served beside WS; checked but not consumed.
|
||||
self._api_tokens: dict[str, float] = {}
|
||||
self._stop_event: asyncio.Event | None = None
|
||||
self._server_task: asyncio.Task[None] | None = None
|
||||
self._session_manager = session_manager
|
||||
self._static_dist_path: Path | None = (
|
||||
static_dist_path.resolve() if static_dist_path is not None else None
|
||||
)
|
||||
self._workspace_path = (
|
||||
_resolved_workspace = (
|
||||
Path(workspace_path).expanduser()
|
||||
if workspace_path is not None
|
||||
else get_workspace_path()
|
||||
).resolve(strict=False)
|
||||
self._default_restrict_to_workspace = restrict_to_workspace
|
||||
self._webui_workspaces = WebUIWorkspaceController(
|
||||
session_manager=self._session_manager,
|
||||
default_workspace=self._workspace_path,
|
||||
default_restrict_to_workspace=self._default_restrict_to_workspace,
|
||||
)
|
||||
self._runtime_model_name = runtime_model_name
|
||||
self._runtime_surface = (
|
||||
"native" if runtime_surface in {"native", "desktop"} else "browser"
|
||||
)
|
||||
self._runtime_capabilities = runtime_capabilities(
|
||||
self._runtime_surface,
|
||||
runtime_capabilities_overrides,
|
||||
)
|
||||
self._settings_routes = WebUISettingsRouter(
|
||||
bus=self.bus,
|
||||
logger=self.logger,
|
||||
check_api_token=self._check_api_token,
|
||||
parse_query=_parse_query,
|
||||
json_response=_http_json_response,
|
||||
error_response=_http_error,
|
||||
|
||||
# HTTP API handler — owns tokens, sessions, media, settings, static serving
|
||||
self._http = GatewayHTTPHandler(
|
||||
config=self.config,
|
||||
session_manager=session_manager,
|
||||
static_dist_path=(
|
||||
static_dist_path.resolve() if static_dist_path is not None else None
|
||||
),
|
||||
workspace_path=_resolved_workspace,
|
||||
runtime_model_name=runtime_model_name,
|
||||
runtime_surface=self._runtime_surface,
|
||||
runtime_capabilities=self._runtime_capabilities,
|
||||
runtime_capabilities_overrides=runtime_capabilities_overrides,
|
||||
bus=self.bus,
|
||||
log=self.logger,
|
||||
)
|
||||
# Backwards-compat aliases for workspace controller used in envelope dispatch
|
||||
self._webui_workspaces = self._http.workspaces
|
||||
|
||||
self._stream_text_buffers: dict[tuple[str, str], list[str]] = {}
|
||||
# Process-local secret used to HMAC-sign media URLs. The signed URL is
|
||||
# the capability — anyone who holds a valid URL can fetch that one
|
||||
# file, nothing else. The secret regenerates on restart so links
|
||||
# become self-expiring (callers just refresh the session list).
|
||||
self._media_secret: bytes = secrets.token_bytes(32)
|
||||
|
||||
# -- Subscription bookkeeping -------------------------------------------
|
||||
|
||||
@ -572,9 +531,9 @@ class WebSocketChannel(BaseChannel):
|
||||
connected clients normally see it via ``goal_state`` / ``turn_end`` frames.
|
||||
Pushing here makes refresh + reconnect restore the strip without a new model turn.
|
||||
"""
|
||||
if self._session_manager is None:
|
||||
if self._http.session_manager is None:
|
||||
return
|
||||
row = self._session_manager.read_session_file(f"websocket:{chat_id}")
|
||||
row = self._http.session_manager.read_session_file(f"websocket:{chat_id}")
|
||||
meta = row.get("metadata", {}) if isinstance(row, dict) else {}
|
||||
if not isinstance(meta, dict):
|
||||
meta = {}
|
||||
@ -614,6 +573,59 @@ class WebSocketChannel(BaseChannel):
|
||||
def _expected_path(self) -> str:
|
||||
return _normalize_config_path(self.config.path)
|
||||
|
||||
# -- Backwards-compat property aliases (used by tests) ------------------
|
||||
|
||||
@property
|
||||
def _session_manager(self):
|
||||
return self._http.session_manager
|
||||
|
||||
@_session_manager.setter
|
||||
def _session_manager(self, value):
|
||||
self._http.session_manager = value
|
||||
|
||||
@property
|
||||
def _media_secret(self):
|
||||
return self._http.media_secret
|
||||
|
||||
@property
|
||||
def _issued_tokens(self):
|
||||
return self._http.issued_tokens
|
||||
|
||||
@_issued_tokens.setter
|
||||
def _issued_tokens(self, value):
|
||||
self._http.issued_tokens = value
|
||||
|
||||
@property
|
||||
def _api_tokens(self):
|
||||
return self._http.api_tokens
|
||||
|
||||
def _check_api_token(self, request):
|
||||
return self._http.check_api_token(request)
|
||||
|
||||
def _sign_media_path(self, path):
|
||||
return self._http.sign_media_path(path)
|
||||
|
||||
def _sign_or_stage_media_path(self, path):
|
||||
return self._http.sign_or_stage_media_path(path)
|
||||
|
||||
def _rewrite_local_markdown_images(self, text):
|
||||
return self._http.rewrite_local_markdown_images(text)
|
||||
|
||||
def _handle_bootstrap(self, connection, request):
|
||||
return self._http._handle_bootstrap(connection, request)
|
||||
|
||||
_MAX_ISSUED_TOKENS = GatewayHTTPHandler._MAX_ISSUED_TOKENS
|
||||
|
||||
def _handle_sessions_list(self, request):
|
||||
return self._http._handle_sessions_list(request)
|
||||
|
||||
def _handle_webui_thread_get(self, request, key):
|
||||
return self._http._handle_webui_thread_get(request, key)
|
||||
|
||||
@property
|
||||
def _workspace_path(self):
|
||||
return self._http.workspace_path
|
||||
|
||||
def _build_ssl_context(self) -> ssl.SSLContext | None:
|
||||
cert = self.config.ssl_certfile.strip()
|
||||
key = self.config.ssl_keyfile.strip()
|
||||
@ -628,548 +640,24 @@ class WebSocketChannel(BaseChannel):
|
||||
ctx.load_cert_chain(certfile=cert, keyfile=key)
|
||||
return ctx
|
||||
|
||||
_MAX_ISSUED_TOKENS = 10_000
|
||||
|
||||
def _purge_expired_issued_tokens(self) -> None:
|
||||
now = time.monotonic()
|
||||
for token_key, expiry in list(self._issued_tokens.items()):
|
||||
if now > expiry:
|
||||
self._issued_tokens.pop(token_key, None)
|
||||
|
||||
def _take_issued_token_if_valid(self, token_value: str | None) -> bool:
|
||||
"""Validate and consume one issued token (single use per connection attempt).
|
||||
|
||||
Uses single-step pop to minimize the window between lookup and removal;
|
||||
safe under asyncio's single-threaded cooperative model.
|
||||
"""
|
||||
if not token_value:
|
||||
return False
|
||||
self._purge_expired_issued_tokens()
|
||||
expiry = self._issued_tokens.pop(token_value, None)
|
||||
if expiry is None:
|
||||
return False
|
||||
if time.monotonic() > expiry:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _handle_token_issue_http(self, connection: Any, request: Any) -> Any:
|
||||
secret = self.config.token_issue_secret.strip() or self.config.token.strip()
|
||||
if secret:
|
||||
if not _issue_route_secret_matches(request.headers, secret):
|
||||
return connection.respond(401, "Unauthorized")
|
||||
else:
|
||||
self.logger.warning(
|
||||
"token_issue_path is set but no token_issue_secret or static token is configured; "
|
||||
"any client can obtain connection tokens — set a secret for production."
|
||||
)
|
||||
self._purge_expired_issued_tokens()
|
||||
if len(self._issued_tokens) >= self._MAX_ISSUED_TOKENS:
|
||||
self.logger.error(
|
||||
"too many outstanding issued tokens ({}), rejecting issuance",
|
||||
len(self._issued_tokens),
|
||||
)
|
||||
return _http_json_response({"error": "too many outstanding tokens"}, status=429)
|
||||
token_value = f"nbwt_{secrets.token_urlsafe(32)}"
|
||||
self._issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s)
|
||||
|
||||
return _http_json_response(
|
||||
{"token": token_value, "expires_in": self.config.token_ttl_s}
|
||||
)
|
||||
|
||||
# -- HTTP dispatch ------------------------------------------------------
|
||||
|
||||
async def _dispatch_http(self, connection: Any, request: WsRequest) -> Any:
|
||||
"""Route an inbound HTTP request to a handler or to the WS upgrade path."""
|
||||
"""Route an inbound HTTP request to the HTTP handler or WS upgrade."""
|
||||
got, query = _parse_request_path(request.path)
|
||||
|
||||
if self.config.token_issue_path:
|
||||
issue_expected = _normalize_config_path(self.config.token_issue_path)
|
||||
if got == issue_expected:
|
||||
return self._handle_token_issue_http(connection, request)
|
||||
|
||||
if got == "/webui/bootstrap":
|
||||
return self._handle_bootstrap(connection, request)
|
||||
|
||||
api_response = await self._dispatch_api_route(connection, request, got)
|
||||
if api_response is not None:
|
||||
return api_response
|
||||
|
||||
ws_matched, ws_response = self._dispatch_websocket_upgrade(
|
||||
connection, request, got, query
|
||||
)
|
||||
if ws_matched:
|
||||
return ws_response
|
||||
|
||||
# API clients should never receive the SPA shell for an unknown route.
|
||||
# Returning HTML here makes the WebUI fail with "Unexpected token <"
|
||||
# when a dev server is pointed at an older gateway.
|
||||
if got.startswith("/api/"):
|
||||
return _http_error(404, "API route not found")
|
||||
|
||||
if self._static_dist_path is not None:
|
||||
response = self._serve_static(got)
|
||||
if response is not None:
|
||||
return response
|
||||
|
||||
return connection.respond(404, "Not Found")
|
||||
|
||||
async def _dispatch_api_route(
|
||||
self,
|
||||
connection: Any,
|
||||
request: WsRequest,
|
||||
got: str,
|
||||
) -> Any | None:
|
||||
"""Route REST-ish WebUI requests served beside the WebSocket endpoint."""
|
||||
response = await self._dispatch_settings_api_route(request, got)
|
||||
if response is not None:
|
||||
return response
|
||||
response = self._dispatch_session_api_route(request, got)
|
||||
if response is not None:
|
||||
return response
|
||||
response = self._dispatch_media_api_route(request, got)
|
||||
if response is not None:
|
||||
return response
|
||||
return self._dispatch_misc_api_route(connection, request, got)
|
||||
|
||||
def _dispatch_misc_api_route(
|
||||
self,
|
||||
connection: Any,
|
||||
request: WsRequest,
|
||||
got: str,
|
||||
) -> Response | None:
|
||||
"""Route small API endpoints that do not belong to a larger route group."""
|
||||
if got == "/api/sessions":
|
||||
return self._handle_sessions_list(request)
|
||||
|
||||
if got == "/api/commands":
|
||||
return self._handle_commands(request)
|
||||
|
||||
if got == "/api/workspaces":
|
||||
return self._handle_workspaces(connection, request)
|
||||
|
||||
if got == "/api/webui/sidebar-state":
|
||||
return self._handle_webui_sidebar_state(request)
|
||||
|
||||
if got == "/api/webui/sidebar-state/update":
|
||||
return self._handle_webui_sidebar_state_update(request)
|
||||
|
||||
return None
|
||||
|
||||
async def _dispatch_settings_api_route(
|
||||
self,
|
||||
request: WsRequest,
|
||||
got: str,
|
||||
) -> Response | None:
|
||||
return await self._settings_routes.dispatch(request, got)
|
||||
|
||||
def _dispatch_session_api_route(
|
||||
self,
|
||||
request: WsRequest,
|
||||
got: str,
|
||||
) -> Response | None:
|
||||
m = re.match(r"^/api/sessions/([^/]+)/messages$", got)
|
||||
if m:
|
||||
return self._handle_session_messages(request, m.group(1))
|
||||
|
||||
m = re.match(r"^/api/sessions/([^/]+)/webui-thread$", got)
|
||||
if m:
|
||||
return self._handle_webui_thread_get(request, m.group(1))
|
||||
|
||||
# NOTE: websockets' HTTP parser only accepts GET, so we cannot expose a
|
||||
# true ``DELETE`` verb. The action is folded into the path instead.
|
||||
m = re.match(r"^/api/sessions/([^/]+)/delete$", got)
|
||||
if m:
|
||||
return self._handle_session_delete(request, m.group(1))
|
||||
|
||||
return None
|
||||
|
||||
def _dispatch_media_api_route(
|
||||
self,
|
||||
request: WsRequest,
|
||||
got: str,
|
||||
) -> Response | None:
|
||||
m = re.match(r"^/api/media/([A-Za-z0-9_-]+)/([A-Za-z0-9_-]+)$", got)
|
||||
if m:
|
||||
return self._handle_media_fetch(m.group(1), m.group(2), request)
|
||||
|
||||
return None
|
||||
|
||||
def _dispatch_websocket_upgrade(
|
||||
self,
|
||||
connection: Any,
|
||||
request: WsRequest,
|
||||
got: str,
|
||||
query: dict[str, list[str]],
|
||||
) -> tuple[bool, Any | None]:
|
||||
"""Authorize only real WS upgrade requests for the configured path."""
|
||||
# WebSocket upgrade — channel handles this itself
|
||||
expected_ws = self._expected_path()
|
||||
if got != expected_ws or not _is_websocket_upgrade(request):
|
||||
return False, None
|
||||
client_id = _query_first(query, "client_id") or ""
|
||||
if len(client_id) > 128:
|
||||
client_id = client_id[:128]
|
||||
if not self.is_allowed(client_id):
|
||||
return True, connection.respond(403, "Forbidden")
|
||||
return True, self._authorize_websocket_handshake(connection, query)
|
||||
if got == expected_ws and _is_websocket_upgrade(request):
|
||||
client_id = _query_first(query, "client_id") or ""
|
||||
if len(client_id) > 128:
|
||||
client_id = client_id[:128]
|
||||
if not self.is_allowed(client_id):
|
||||
return connection.respond(403, "Forbidden")
|
||||
return self._authorize_websocket_handshake(connection, query)
|
||||
|
||||
# -- HTTP route handlers ------------------------------------------------
|
||||
|
||||
def _check_api_token(self, request: WsRequest) -> bool:
|
||||
"""Validate a request against the API token pool (multi-use, TTL-bound)."""
|
||||
self._purge_expired_api_tokens()
|
||||
token = _bearer_token(request.headers) or _query_first(
|
||||
_parse_query(request.path), "token"
|
||||
)
|
||||
if not token:
|
||||
return False
|
||||
expiry = self._api_tokens.get(token)
|
||||
if expiry is None or time.monotonic() > expiry:
|
||||
self._api_tokens.pop(token, None)
|
||||
return False
|
||||
return True
|
||||
|
||||
def _purge_expired_api_tokens(self) -> None:
|
||||
now = time.monotonic()
|
||||
for token_key, expiry in list(self._api_tokens.items()):
|
||||
if now > expiry:
|
||||
self._api_tokens.pop(token_key, None)
|
||||
|
||||
def _handle_bootstrap(self, connection: Any, request: Any) -> Response:
|
||||
# When a secret is configured (token_issue_secret or static token),
|
||||
# validate it regardless of source IP. This secures deployments
|
||||
# behind a reverse proxy where all connections appear as localhost.
|
||||
secret = self.config.token_issue_secret.strip() or self.config.token.strip()
|
||||
if secret:
|
||||
if not _issue_route_secret_matches(request.headers, secret):
|
||||
return _http_error(401, "Unauthorized")
|
||||
elif not _is_localhost(connection):
|
||||
# No secret configured: only allow localhost (local dev mode).
|
||||
return _http_error(403, "bootstrap is localhost-only")
|
||||
# Cap outstanding tokens to avoid runaway growth from a misbehaving client.
|
||||
self._purge_expired_issued_tokens()
|
||||
self._purge_expired_api_tokens()
|
||||
if (
|
||||
len(self._issued_tokens) >= self._MAX_ISSUED_TOKENS
|
||||
or len(self._api_tokens) >= self._MAX_ISSUED_TOKENS
|
||||
):
|
||||
return _http_response(
|
||||
json.dumps({"error": "too many outstanding tokens"}).encode("utf-8"),
|
||||
status=429,
|
||||
content_type="application/json; charset=utf-8",
|
||||
)
|
||||
token = f"nbwt_{secrets.token_urlsafe(32)}"
|
||||
expiry = time.monotonic() + float(self.config.token_ttl_s)
|
||||
# Same string registered in both pools: the WS handshake consumes one copy
|
||||
# while the REST surface keeps validating the other until TTL expiry.
|
||||
self._issued_tokens[token] = expiry
|
||||
self._api_tokens[token] = expiry
|
||||
ws_url = self._bootstrap_ws_url(request)
|
||||
return _http_json_response(
|
||||
{
|
||||
"token": token,
|
||||
"ws_path": self._expected_path(),
|
||||
"ws_url": ws_url,
|
||||
"expires_in": self.config.token_ttl_s,
|
||||
"model_name": _resolve_bootstrap_model_name(self._runtime_model_name),
|
||||
"runtime_surface": self._runtime_surface,
|
||||
"runtime_capabilities": self._runtime_capabilities,
|
||||
}
|
||||
)
|
||||
|
||||
def _bootstrap_ws_url(self, request: Any) -> str:
|
||||
"""Absolute WS URL clients should prefer over a dev-server proxy."""
|
||||
headers = getattr(request, "headers", {}) or {}
|
||||
host = _safe_host_header(_case_insensitive_header(headers, "Host"))
|
||||
if not host:
|
||||
host = _host_for_url(self.config.host, self.config.port)
|
||||
|
||||
proto = _case_insensitive_header(headers, "X-Forwarded-Proto")
|
||||
proto = proto.split(",", 1)[0].strip().lower()
|
||||
secure = proto in {"https", "wss"} or bool(self.config.ssl_certfile.strip())
|
||||
scheme = "wss" if secure else "ws"
|
||||
return f"{scheme}://{host}{self._expected_path()}"
|
||||
|
||||
def _handle_sessions_list(self, request: WsRequest) -> Response:
|
||||
if not self._check_api_token(request):
|
||||
return _http_error(401, "Unauthorized")
|
||||
if self._session_manager is None:
|
||||
return _http_error(503, "session manager unavailable")
|
||||
sessions = self._session_manager.list_sessions()
|
||||
# Sidebar/chat listing for WS-backed sessions only — CLI / Slack / etc.
|
||||
# keys are not intended for resume over this HTTP surface.
|
||||
cleaned = []
|
||||
for s in sessions:
|
||||
key = s.get("key")
|
||||
if not (isinstance(key, str) and key.startswith("websocket:")):
|
||||
continue
|
||||
row = {k: v for k, v in s.items() if k != "path"}
|
||||
chat_id = key.split(":", 1)[1]
|
||||
started_at = websocket_turn_wall_started_at(chat_id)
|
||||
if started_at is not None:
|
||||
row["run_started_at"] = started_at
|
||||
scope = self._webui_workspaces.scope_for_session_key(key)
|
||||
row["workspace_scope"] = scope.payload()
|
||||
cleaned.append(row)
|
||||
return _http_json_response({"sessions": cleaned})
|
||||
|
||||
def _handle_workspaces(self, connection: Any, request: WsRequest) -> Response:
|
||||
if not self._check_api_token(request):
|
||||
return _http_error(401, "Unauthorized")
|
||||
return _http_json_response(
|
||||
self._webui_workspaces.payload(controls_available=_is_localhost(connection))
|
||||
)
|
||||
|
||||
def _handle_commands(self, request: WsRequest) -> Response:
|
||||
if not self._check_api_token(request):
|
||||
return _http_error(401, "Unauthorized")
|
||||
return _http_json_response({"commands": builtin_command_palette()})
|
||||
|
||||
def _handle_webui_sidebar_state(self, request: WsRequest) -> Response:
|
||||
if not self._check_api_token(request):
|
||||
return _http_error(401, "Unauthorized")
|
||||
return _http_json_response(read_webui_sidebar_state())
|
||||
|
||||
def _handle_webui_sidebar_state_update(self, request: WsRequest) -> Response:
|
||||
if not self._check_api_token(request):
|
||||
return _http_error(401, "Unauthorized")
|
||||
query = _parse_query(request.path)
|
||||
raw_state = _query_first(query, "state")
|
||||
if raw_state is None:
|
||||
return _http_error(400, "missing state")
|
||||
try:
|
||||
decoded = json.loads(raw_state)
|
||||
except json.JSONDecodeError:
|
||||
return _http_error(400, "state must be JSON")
|
||||
if not isinstance(decoded, dict):
|
||||
return _http_error(400, "state must be an object")
|
||||
try:
|
||||
state = write_webui_sidebar_state(decoded)
|
||||
except ValueError as e:
|
||||
return _http_error(400, str(e))
|
||||
except OSError:
|
||||
self.logger.exception("failed to write webui sidebar state")
|
||||
return _http_error(500, "failed to write sidebar state")
|
||||
return _http_json_response(state)
|
||||
|
||||
# -- Session replay, transcript, and signed media ----------------------
|
||||
|
||||
@staticmethod
|
||||
def _is_websocket_channel_session_key(key: str) -> bool:
|
||||
"""True when *key* is a ``websocket:…`` session exposed on this HTTP surface."""
|
||||
return key.startswith("websocket:")
|
||||
|
||||
def _handle_session_messages(self, request: WsRequest, key: str) -> Response:
|
||||
if not self._check_api_token(request):
|
||||
return _http_error(401, "Unauthorized")
|
||||
if self._session_manager is None:
|
||||
return _http_error(503, "session manager unavailable")
|
||||
decoded_key = _decode_api_key(key)
|
||||
if decoded_key is None:
|
||||
return _http_error(400, "invalid session key")
|
||||
# Only ``websocket:…`` sessions are listed/served here — same boundary as
|
||||
# ``/api/sessions``. Block handcrafted URLs from probing CLI / Slack / etc.
|
||||
if not self._is_websocket_channel_session_key(decoded_key):
|
||||
return _http_error(404, "session not found")
|
||||
data = self._session_manager.read_session_file(decoded_key)
|
||||
if data is None:
|
||||
return _http_error(404, "session not found")
|
||||
messages = data.get("messages")
|
||||
if isinstance(messages, list):
|
||||
scrub_subagent_messages_for_channel(messages)
|
||||
# Decorate persisted user messages with signed media URLs so the
|
||||
# client can render previews. The raw on-disk ``media`` paths are
|
||||
# stripped on the way out — they leak server filesystem layout and
|
||||
# the client never needs them once it has the signed fetch URL.
|
||||
attach_signed_media_urls(data, sign_path=self._sign_media_path)
|
||||
return _http_json_response(data)
|
||||
|
||||
def _handle_webui_thread_get(self, request: WsRequest, key: str) -> Response:
|
||||
if not self._check_api_token(request):
|
||||
return _http_error(401, "Unauthorized")
|
||||
decoded_key = _decode_api_key(key)
|
||||
if decoded_key is None:
|
||||
return _http_error(400, "invalid session key")
|
||||
if not self._is_websocket_channel_session_key(decoded_key):
|
||||
return _http_error(404, "session not found")
|
||||
scope = self._webui_workspaces.scope_for_session_key(decoded_key)
|
||||
augment_media = partial(
|
||||
signed_media_attachments,
|
||||
sign_path=self._sign_or_stage_media_path,
|
||||
)
|
||||
data = build_webui_thread_response(
|
||||
decoded_key,
|
||||
augment_user_media=augment_media,
|
||||
augment_assistant_media=augment_media,
|
||||
augment_assistant_text=lambda text: rewrite_local_markdown_images(
|
||||
text,
|
||||
workspace_path=scope.project_path,
|
||||
sign_path=self._sign_or_stage_media_path,
|
||||
),
|
||||
)
|
||||
if data is None:
|
||||
return _http_error(404, "webui thread not found")
|
||||
data["workspace_scope"] = scope.payload()
|
||||
return _http_json_response(data)
|
||||
|
||||
def _try_append_webui_transcript(self, chat_id: str, wire: dict[str, Any]) -> None:
|
||||
sk = f"websocket:{chat_id}"
|
||||
try:
|
||||
dup = json.loads(json.dumps(wire, ensure_ascii=False))
|
||||
append_transcript_object(sk, dup)
|
||||
except (ValueError, TypeError) as e:
|
||||
self.logger.warning("webui transcript append failed: {}", e)
|
||||
|
||||
async def _handle_message(
|
||||
self,
|
||||
sender_id: str,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
media: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
session_key: str | None = None,
|
||||
is_dm: bool = False,
|
||||
) -> None:
|
||||
meta = metadata or {}
|
||||
if meta.get("webui"):
|
||||
user_obj: dict[str, Any] = {
|
||||
"event": "user",
|
||||
"chat_id": chat_id,
|
||||
"text": content,
|
||||
}
|
||||
if media:
|
||||
user_obj["media_paths"] = list(media)
|
||||
cli_apps = meta.get("cli_apps")
|
||||
if isinstance(cli_apps, list) and cli_apps:
|
||||
user_obj["cli_apps"] = cli_apps
|
||||
mcp_presets = meta.get("mcp_presets")
|
||||
if isinstance(mcp_presets, list) and mcp_presets:
|
||||
user_obj["mcp_presets"] = mcp_presets
|
||||
self._try_append_webui_transcript(chat_id, user_obj)
|
||||
await super()._handle_message(
|
||||
sender_id,
|
||||
chat_id,
|
||||
content,
|
||||
media,
|
||||
metadata,
|
||||
session_key,
|
||||
is_dm,
|
||||
)
|
||||
|
||||
def _sign_media_path(self, abs_path: Path) -> str | None:
|
||||
"""Return a ``/api/media/<sig>/<payload>`` URL for *abs_path*, or
|
||||
``None`` when the path does not resolve inside the media root.
|
||||
|
||||
The URL is self-authenticating: the signature binds the payload to
|
||||
this process's ``_media_secret``, so only paths we chose to sign can
|
||||
be fetched. The returned path is relative to the server origin; the
|
||||
client joins it against this server's HTTP origin (same host as WS).
|
||||
"""
|
||||
return sign_media_path(
|
||||
abs_path,
|
||||
secret=self._media_secret,
|
||||
media_dir=lambda channel=None: get_media_dir(channel),
|
||||
)
|
||||
|
||||
def _sign_or_stage_media_path(self, path: Path) -> dict[str, str] | None:
|
||||
"""Return a signed media URL payload for *path*.
|
||||
|
||||
Persisted inbound media already lives under ``get_media_dir`` and can
|
||||
be signed directly. Outbound bot-generated files may live anywhere on
|
||||
disk; copy those into the websocket media bucket first so the browser
|
||||
can fetch them through the existing signed media route without
|
||||
exposing arbitrary filesystem paths.
|
||||
"""
|
||||
return sign_or_stage_media_path(
|
||||
path,
|
||||
secret=self._media_secret,
|
||||
media_dir=lambda channel=None: get_media_dir(channel),
|
||||
logger=self.logger,
|
||||
)
|
||||
|
||||
def _rewrite_local_markdown_images(self, text: str) -> str:
|
||||
return rewrite_local_markdown_images(
|
||||
text,
|
||||
workspace_path=self._workspace_path,
|
||||
sign_path=self._sign_or_stage_media_path,
|
||||
)
|
||||
|
||||
def _handle_media_fetch(
|
||||
self, sig: str, payload: str, request: WsRequest | None = None
|
||||
) -> Response:
|
||||
"""Serve a single media file previously signed via
|
||||
:meth:`_sign_media_path`. Validates the signature, decodes the
|
||||
payload to a relative path, and streams the file bytes with a
|
||||
long-lived immutable cache header (the URL already encodes the
|
||||
file identity, so caches can be aggressive)."""
|
||||
return serve_signed_media(
|
||||
sig,
|
||||
payload,
|
||||
secret=self._media_secret,
|
||||
request=request,
|
||||
media_dir=lambda channel=None: get_media_dir(channel),
|
||||
)
|
||||
|
||||
def _handle_session_delete(self, request: WsRequest, key: str) -> Response:
|
||||
if not self._check_api_token(request):
|
||||
return _http_error(401, "Unauthorized")
|
||||
if self._session_manager is None:
|
||||
return _http_error(503, "session manager unavailable")
|
||||
decoded_key = _decode_api_key(key)
|
||||
if decoded_key is None:
|
||||
return _http_error(400, "invalid session key")
|
||||
# Same boundary as ``_handle_session_messages``: mutations apply only to
|
||||
# websocket-channel sessions; deletion unlinks local JSONL — keep scope narrow.
|
||||
if not self._is_websocket_channel_session_key(decoded_key):
|
||||
return _http_error(404, "session not found")
|
||||
deleted = self._session_manager.delete_session(decoded_key)
|
||||
delete_webui_thread(decoded_key)
|
||||
return _http_json_response({"deleted": bool(deleted)})
|
||||
|
||||
# -- Static files and WebSocket handshake ------------------------------
|
||||
|
||||
def _serve_static(self, request_path: str) -> Response | None:
|
||||
"""Resolve *request_path* against the built SPA directory; SPA fallback to index.html."""
|
||||
assert self._static_dist_path is not None
|
||||
rel = request_path.lstrip("/")
|
||||
if not rel:
|
||||
rel = "index.html"
|
||||
# Reject path-traversal attempts and absolute targets.
|
||||
if ".." in rel.split("/") or rel.startswith("/"):
|
||||
return _http_error(403, "Forbidden")
|
||||
candidate = (self._static_dist_path / rel).resolve()
|
||||
try:
|
||||
candidate.relative_to(self._static_dist_path)
|
||||
except ValueError:
|
||||
return _http_error(403, "Forbidden")
|
||||
if not candidate.is_file():
|
||||
# SPA history-mode fallback: unknown routes serve index.html so the
|
||||
# client-side router can render them.
|
||||
index = self._static_dist_path / "index.html"
|
||||
if index.is_file():
|
||||
candidate = index
|
||||
else:
|
||||
return None
|
||||
try:
|
||||
body = candidate.read_bytes()
|
||||
except OSError as e:
|
||||
self.logger.warning("static: failed to read {}: {}", candidate, e)
|
||||
return _http_error(500, "Internal Server Error")
|
||||
ctype, _ = mimetypes.guess_type(candidate.name)
|
||||
if ctype is None:
|
||||
ctype = "application/octet-stream"
|
||||
if ctype.startswith("text/") or ctype in {"application/javascript", "application/json"}:
|
||||
ctype = f"{ctype}; charset=utf-8"
|
||||
# Hash-named build assets are cache-friendly; index.html must stay fresh.
|
||||
if candidate.name == "index.html":
|
||||
cache = "no-cache"
|
||||
else:
|
||||
cache = "public, max-age=31536000, immutable"
|
||||
return _http_response(
|
||||
body,
|
||||
status=200,
|
||||
content_type=ctype,
|
||||
extra_headers=[("Cache-Control", cache)],
|
||||
)
|
||||
<< # Everything else goes to the HTTP handler
|
||||
return await self._http.dispatch(connection, request)
|
||||
|
||||
def _authorize_websocket_handshake(self, connection: Any, query: dict[str, list[str]]) -> Any:
|
||||
supplied = _query_first(query, "token")
|
||||
@ -1178,20 +666,21 @@ class WebSocketChannel(BaseChannel):
|
||||
if static_token:
|
||||
if supplied and hmac.compare_digest(supplied, static_token):
|
||||
return None
|
||||
if supplied and self._take_issued_token_if_valid(supplied):
|
||||
if supplied and self._http.take_issued_token_if_valid(supplied):
|
||||
return None
|
||||
return connection.respond(401, "Unauthorized")
|
||||
|
||||
if self.config.websocket_requires_token:
|
||||
if supplied and self._take_issued_token_if_valid(supplied):
|
||||
if supplied and self._http.take_issued_token_if_valid(supplied):
|
||||
return None
|
||||
return connection.respond(401, "Unauthorized")
|
||||
|
||||
if supplied:
|
||||
self._take_issued_token_if_valid(supplied)
|
||||
self._http.take_issued_token_if_valid(supplied)
|
||||
return None
|
||||
|
||||
# -- Server lifecycle and connection ingress ---------------------------
|
||||
# -- Server lifecycle and connection ingress ---------------------------
|
||||
|
||||
async def start(self) -> None:
|
||||
from nanobot.utils.logging_bridge import redirect_lib_logging
|
||||
@ -1587,8 +1076,8 @@ class WebSocketChannel(BaseChannel):
|
||||
self._subs.clear()
|
||||
self._conn_chats.clear()
|
||||
self._conn_default.clear()
|
||||
self._issued_tokens.clear()
|
||||
self._api_tokens.clear()
|
||||
self._http.issued_tokens.clear()
|
||||
self._http.api_tokens.clear()
|
||||
|
||||
async def _safe_send_to(self, connection: Any, raw: str, *, label: str = "") -> None:
|
||||
"""Send a raw frame to one connection, cleaning up on ConnectionClosed."""
|
||||
@ -1601,6 +1090,14 @@ class WebSocketChannel(BaseChannel):
|
||||
self.logger.exception("send failed{}", label)
|
||||
raise
|
||||
|
||||
def _try_append_webui_transcript(self, chat_id: str, wire: dict[str, Any]) -> None:
|
||||
sk = f"websocket:{chat_id}"
|
||||
try:
|
||||
dup = json.loads(json.dumps(wire, ensure_ascii=False))
|
||||
append_transcript_object(sk, dup)
|
||||
except (ValueError, TypeError) as e:
|
||||
self.logger.warning("webui transcript append failed: {}", e)
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
if msg.metadata.get("_runtime_model_updated"):
|
||||
await self.send_runtime_model_updated(
|
||||
@ -1662,7 +1159,7 @@ class WebSocketChannel(BaseChannel):
|
||||
)
|
||||
return
|
||||
text = msg.content
|
||||
wire_text = self._rewrite_local_markdown_images(text)
|
||||
wire_text = self._http.rewrite_local_markdown_images(text)
|
||||
payload: dict[str, Any] = {
|
||||
"event": "message",
|
||||
"chat_id": msg.chat_id,
|
||||
@ -1672,7 +1169,7 @@ class WebSocketChannel(BaseChannel):
|
||||
payload["media"] = msg.media
|
||||
urls: list[dict[str, str]] = []
|
||||
for entry in msg.media:
|
||||
signed = self._sign_or_stage_media_path(Path(entry))
|
||||
signed = self._http.sign_or_stage_media_path(Path(entry))
|
||||
if signed is not None:
|
||||
urls.append(signed)
|
||||
if urls:
|
||||
@ -1787,7 +1284,7 @@ class WebSocketChannel(BaseChannel):
|
||||
if delta:
|
||||
buffered.append(delta)
|
||||
full_text = "".join(buffered)
|
||||
rewritten = self._rewrite_local_markdown_images(full_text)
|
||||
rewritten = self._http.rewrite_local_markdown_images(full_text)
|
||||
if rewritten != full_text:
|
||||
body["text"] = rewritten
|
||||
else:
|
||||
|
||||
731
nanobot/channels/ws_http.py
Normal file
731
nanobot/channels/ws_http.py
Normal file
@ -0,0 +1,731 @@
|
||||
"""HTTP API handler extracted from WebSocketChannel.
|
||||
|
||||
Handles all non-WebSocket HTTP routes: bootstrap, sessions, settings,
|
||||
media, commands, sidebar state, static file serving, and token management.
|
||||
|
||||
Also houses shared HTTP utility functions used by both this module and
|
||||
``websocket.py`` to avoid circular imports.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import email.utils
|
||||
import hmac
|
||||
import http
|
||||
import json
|
||||
import mimetypes
|
||||
import re
|
||||
import secrets
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from loguru import logger
|
||||
from websockets.datastructures import Headers
|
||||
from websockets.http11 import Request as WsRequest
|
||||
from websockets.http11 import Response
|
||||
|
||||
from nanobot.command.builtin import builtin_command_palette
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.utils.subagent_channel_display import scrub_subagent_messages_for_channel
|
||||
from nanobot.webui.media_api import (
|
||||
serve_signed_media,
|
||||
sign_media_path,
|
||||
sign_or_stage_media_path,
|
||||
)
|
||||
from nanobot.webui.sidebar_state import (
|
||||
read_webui_sidebar_state,
|
||||
write_webui_sidebar_state,
|
||||
)
|
||||
from nanobot.webui.thread_disk import delete_webui_thread
|
||||
from nanobot.webui.transcript import (
|
||||
build_webui_thread_response,
|
||||
rewrite_local_markdown_images,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.session.manager import SessionManager
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared HTTP utility functions (imported by websocket.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _strip_trailing_slash(path: str) -> str:
|
||||
if len(path) > 1 and path.endswith("/"):
|
||||
return path.rstrip("/")
|
||||
return path or "/"
|
||||
|
||||
|
||||
def _normalize_config_path(path: str) -> str:
|
||||
return _strip_trailing_slash(path)
|
||||
|
||||
|
||||
def _case_insensitive_header(headers: Any, key: str) -> str:
|
||||
"""Read a header from websockets/http test stubs without assuming casing."""
|
||||
try:
|
||||
value = headers.get(key)
|
||||
except Exception:
|
||||
value = None
|
||||
if value is None:
|
||||
try:
|
||||
value = headers.get(key.lower())
|
||||
except Exception:
|
||||
value = None
|
||||
return str(value or "").strip()
|
||||
|
||||
|
||||
def _safe_host_header(value: str) -> str:
|
||||
"""Return a safe Host header value, or empty when it should not be echoed."""
|
||||
value = value.strip()
|
||||
if not value:
|
||||
return ""
|
||||
if re.fullmatch(r"\[[0-9A-Fa-f:.]+\](?::\d{1,5})?", value):
|
||||
return value
|
||||
if re.fullmatch(r"[A-Za-z0-9.-]+(?::\d{1,5})?", value):
|
||||
return value
|
||||
return ""
|
||||
|
||||
|
||||
def _host_for_url(host: str, port: int) -> str:
|
||||
host = host.strip()
|
||||
if host in ("0.0.0.0", "::"):
|
||||
host = "127.0.0.1"
|
||||
if ":" in host and not host.startswith("["):
|
||||
host = f"[{host}]"
|
||||
return f"{host}:{port}"
|
||||
|
||||
|
||||
def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
|
||||
body = json.dumps(data, ensure_ascii=False).encode("utf-8")
|
||||
headers = Headers(
|
||||
[
|
||||
("Date", email.utils.formatdate(usegmt=True)),
|
||||
("Connection", "close"),
|
||||
("Content-Length", str(len(body))),
|
||||
("Content-Type", "application/json; charset=utf-8"),
|
||||
]
|
||||
)
|
||||
reason = http.HTTPStatus(status).phrase
|
||||
return Response(status, reason, headers, body)
|
||||
|
||||
|
||||
def _http_response(
|
||||
body: bytes,
|
||||
*,
|
||||
status: int = 200,
|
||||
content_type: str = "text/plain; charset=utf-8",
|
||||
extra_headers: list[tuple[str, str]] | None = None,
|
||||
) -> Response:
|
||||
headers = [
|
||||
("Date", email.utils.formatdate(usegmt=True)),
|
||||
("Connection", "close"),
|
||||
("Content-Length", str(len(body))),
|
||||
("Content-Type", content_type),
|
||||
]
|
||||
if extra_headers:
|
||||
headers.extend(extra_headers)
|
||||
reason = http.HTTPStatus(status).phrase
|
||||
return Response(status, reason, Headers(headers), body)
|
||||
|
||||
|
||||
def _http_error(status: int, message: str | None = None) -> Response:
|
||||
body = (message or http.HTTPStatus(status).phrase).encode("utf-8")
|
||||
return _http_response(body, status=status)
|
||||
|
||||
|
||||
def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]:
|
||||
"""Parse normalized path and query parameters in one pass."""
|
||||
parsed = urlparse("ws://x" + path_with_query)
|
||||
path = _strip_trailing_slash(parsed.path or "/")
|
||||
return path, parse_qs(parsed.query, keep_blank_values=True)
|
||||
|
||||
|
||||
def _normalize_http_path(path_with_query: str) -> str:
|
||||
return _parse_request_path(path_with_query)[0]
|
||||
|
||||
|
||||
def _parse_query(path_with_query: str) -> dict[str, list[str]]:
|
||||
return _parse_request_path(path_with_query)[1]
|
||||
|
||||
|
||||
def _query_first(query: dict[str, list[str]], key: str) -> str | None:
|
||||
values = query.get(key)
|
||||
return values[0] if values else None
|
||||
|
||||
|
||||
def _is_localhost(connection: Any) -> bool:
|
||||
addr = getattr(connection, "remote_address", None)
|
||||
if not addr:
|
||||
return False
|
||||
host = addr[0] if isinstance(addr, tuple) else addr
|
||||
if not isinstance(host, str):
|
||||
return False
|
||||
if host.startswith("::ffff:"):
|
||||
host = host[7:]
|
||||
return host in {"127.0.0.1", "::1", "localhost"}
|
||||
|
||||
|
||||
def _bearer_token(headers: Any) -> str | None:
|
||||
auth = headers.get("Authorization") or headers.get("authorization")
|
||||
if auth and auth.lower().startswith("bearer "):
|
||||
return auth[7:].strip() or None
|
||||
return None
|
||||
|
||||
|
||||
def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool:
|
||||
if not configured_secret:
|
||||
return True
|
||||
authorization = headers.get("Authorization") or headers.get("authorization")
|
||||
if authorization and authorization.lower().startswith("bearer "):
|
||||
supplied = authorization[7:].strip()
|
||||
return hmac.compare_digest(supplied, configured_secret)
|
||||
header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth")
|
||||
if not header_token:
|
||||
return False
|
||||
return hmac.compare_digest(header_token.strip(), configured_secret)
|
||||
|
||||
|
||||
def _decode_api_key(raw_key: str) -> str | None:
|
||||
from urllib.parse import unquote
|
||||
|
||||
key = unquote(raw_key)
|
||||
_api_key_re = re.compile(r"^[A-Za-z0-9_:.-]{1,128}$")
|
||||
if _api_key_re.match(key) is None:
|
||||
return None
|
||||
return key
|
||||
|
||||
|
||||
def _default_model_name_from_config() -> str | None:
|
||||
try:
|
||||
from nanobot.config.loader import load_config
|
||||
model = load_config().resolve_preset().model.strip()
|
||||
return model or None
|
||||
except Exception as e:
|
||||
logger.debug("bootstrap model_name could not load from config: {}", e)
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_bootstrap_model_name(
|
||||
runtime_name: Callable[[], str | None] | None,
|
||||
) -> str | None:
|
||||
if runtime_name is not None:
|
||||
try:
|
||||
raw = runtime_name()
|
||||
except Exception as e:
|
||||
logger.debug("bootstrap runtime model resolver failed: {}", e)
|
||||
else:
|
||||
if isinstance(raw, str):
|
||||
stripped = raw.strip()
|
||||
if stripped:
|
||||
return stripped
|
||||
return _default_model_name_from_config()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GatewayHTTPHandler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class GatewayHTTPHandler:
|
||||
"""Handles all HTTP routes served alongside the WebSocket endpoint.
|
||||
|
||||
Owns token management, session API, media API, static file serving,
|
||||
and delegates settings routes to ``WebUISettingsRouter``.
|
||||
"""
|
||||
|
||||
_MAX_ISSUED_TOKENS = 10_000
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
config: Any, # WebSocketConfig
|
||||
session_manager: SessionManager | None,
|
||||
static_dist_path: Path | None,
|
||||
workspace_path: Path,
|
||||
runtime_model_name: Callable[[], str | None] | None,
|
||||
runtime_surface: str,
|
||||
runtime_capabilities_overrides: dict[str, Any] | None,
|
||||
bus: MessageBus,
|
||||
log: Any = logger,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.session_manager = session_manager
|
||||
self.static_dist_path = static_dist_path
|
||||
self.workspace_path = workspace_path
|
||||
self.runtime_model_name = runtime_model_name
|
||||
self.bus = bus
|
||||
self._log = log
|
||||
self._runtime_surface = runtime_surface
|
||||
|
||||
self.issued_tokens: dict[str, float] = {}
|
||||
self.api_tokens: dict[str, float] = {}
|
||||
self.media_secret: bytes = secrets.token_bytes(32)
|
||||
|
||||
# Workspace controller
|
||||
from nanobot.webui.workspaces import WebUIWorkspaceController
|
||||
|
||||
self.workspaces = WebUIWorkspaceController(
|
||||
session_manager=session_manager,
|
||||
default_workspace=workspace_path,
|
||||
default_restrict_to_workspace=None,
|
||||
)
|
||||
|
||||
# Settings router
|
||||
from nanobot.webui.settings_api import runtime_capabilities as _rc
|
||||
from nanobot.webui.settings_routes import WebUISettingsRouter
|
||||
|
||||
self._capabilities = _rc(runtime_surface, runtime_capabilities_overrides or {})
|
||||
self.settings_routes = WebUISettingsRouter(
|
||||
bus=bus,
|
||||
logger=self._log,
|
||||
check_api_token=self.check_api_token,
|
||||
parse_query=_parse_query,
|
||||
json_response=_http_json_response,
|
||||
error_response=_http_error,
|
||||
runtime_surface=runtime_surface,
|
||||
runtime_capabilities=self._capabilities,
|
||||
)
|
||||
|
||||
# -- Token management ---------------------------------------------------
|
||||
|
||||
def check_api_token(self, request: WsRequest) -> bool:
|
||||
self._purge_expired_api_tokens()
|
||||
token = _bearer_token(request.headers) or _query_first(
|
||||
_parse_query(request.path), "token"
|
||||
)
|
||||
if not token:
|
||||
return False
|
||||
expiry = self.api_tokens.get(token)
|
||||
if expiry is None or time.monotonic() > expiry:
|
||||
self.api_tokens.pop(token, None)
|
||||
return False
|
||||
return True
|
||||
|
||||
def _purge_expired_api_tokens(self) -> None:
|
||||
now = time.monotonic()
|
||||
for token_key, expiry in list(self.api_tokens.items()):
|
||||
if now > expiry:
|
||||
self.api_tokens.pop(token_key, None)
|
||||
|
||||
def _purge_expired_issued_tokens(self) -> None:
|
||||
now = time.monotonic()
|
||||
for token_key, expiry in list(self.issued_tokens.items()):
|
||||
if now > expiry:
|
||||
self.issued_tokens.pop(token_key, None)
|
||||
|
||||
def take_issued_token_if_valid(self, token_value: str | None) -> bool:
|
||||
if not token_value:
|
||||
return False
|
||||
self._purge_expired_issued_tokens()
|
||||
expiry = self.issued_tokens.pop(token_value, None)
|
||||
if expiry is None:
|
||||
return False
|
||||
if time.monotonic() > expiry:
|
||||
return False
|
||||
return True
|
||||
|
||||
# -- Main dispatch ------------------------------------------------------
|
||||
|
||||
async def dispatch(self, connection: Any, request: WsRequest) -> Any | None:
|
||||
"""Route an HTTP request. Returns Response or None."""
|
||||
got, query = _parse_request_path(request.path)
|
||||
|
||||
# Token issue endpoint
|
||||
if self.config.token_issue_path:
|
||||
issue_expected = _normalize_config_path(self.config.token_issue_path)
|
||||
if got == issue_expected:
|
||||
return self._handle_token_issue(connection, request)
|
||||
|
||||
# Bootstrap
|
||||
if got == "/webui/bootstrap":
|
||||
return self._handle_bootstrap(connection, request)
|
||||
|
||||
# Settings routes (delegated)
|
||||
response = await self.settings_routes.dispatch(request, got)
|
||||
if response is not None:
|
||||
return response
|
||||
|
||||
# Session routes
|
||||
response = self._dispatch_session_routes(request, got)
|
||||
if response is not None:
|
||||
return response
|
||||
|
||||
# Media routes
|
||||
response = self._dispatch_media_routes(request, got)
|
||||
if response is not None:
|
||||
return response
|
||||
|
||||
# Misc routes
|
||||
response = self._dispatch_misc_routes(connection, request, got)
|
||||
if response is not None:
|
||||
return response
|
||||
|
||||
# API 404 (never serve SPA for /api/ routes)
|
||||
if got.startswith("/api/"):
|
||||
return _http_error(404, "API route not found")
|
||||
|
||||
# Static SPA serving
|
||||
if self.static_dist_path is not None:
|
||||
response = self._serve_static(got)
|
||||
if response is not None:
|
||||
return response
|
||||
|
||||
return connection.respond(404, "Not Found")
|
||||
|
||||
# -- Token issue --------------------------------------------------------
|
||||
|
||||
def _handle_token_issue(self, connection: Any, request: Any) -> Any:
|
||||
secret = self.config.token_issue_secret.strip()
|
||||
if secret:
|
||||
if not _issue_route_secret_matches(request.headers, secret):
|
||||
return connection.respond(401, "Unauthorized")
|
||||
else:
|
||||
self._log.warning(
|
||||
"token_issue_path is set but token_issue_secret is empty; "
|
||||
"any client can obtain connection tokens — set token_issue_secret for production."
|
||||
)
|
||||
self._purge_expired_issued_tokens()
|
||||
if len(self.issued_tokens) >= self._MAX_ISSUED_TOKENS:
|
||||
self._log.error(
|
||||
"too many outstanding issued tokens ({}), rejecting issuance",
|
||||
len(self.issued_tokens),
|
||||
)
|
||||
return _http_json_response({"error": "too many outstanding tokens"}, status=429)
|
||||
token_value = f"nbwt_{secrets.token_urlsafe(32)}"
|
||||
self.issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s)
|
||||
return _http_json_response(
|
||||
{"token": token_value, "expires_in": self.config.token_ttl_s}
|
||||
)
|
||||
|
||||
# -- Bootstrap ----------------------------------------------------------
|
||||
|
||||
def _handle_bootstrap(self, connection: Any, request: Any) -> Response:
|
||||
secret = self.config.token_issue_secret.strip() or self.config.token.strip()
|
||||
if secret:
|
||||
if not _issue_route_secret_matches(request.headers, secret):
|
||||
return _http_error(401, "Unauthorized")
|
||||
elif not _is_localhost(connection):
|
||||
return _http_error(403, "bootstrap is localhost-only")
|
||||
|
||||
self._purge_expired_issued_tokens()
|
||||
self._purge_expired_api_tokens()
|
||||
if (
|
||||
len(self.issued_tokens) >= self._MAX_ISSUED_TOKENS
|
||||
or len(self.api_tokens) >= self._MAX_ISSUED_TOKENS
|
||||
):
|
||||
return _http_response(
|
||||
json.dumps({"error": "too many outstanding tokens"}).encode("utf-8"),
|
||||
status=429,
|
||||
content_type="application/json; charset=utf-8",
|
||||
)
|
||||
token = f"nbwt_{secrets.token_urlsafe(32)}"
|
||||
expiry = time.monotonic() + float(self.config.token_ttl_s)
|
||||
self.issued_tokens[token] = expiry
|
||||
self.api_tokens[token] = expiry
|
||||
|
||||
ws_url = self._bootstrap_ws_url(request)
|
||||
expected_path = _normalize_config_path(self.config.path)
|
||||
return _http_json_response(
|
||||
{
|
||||
"token": token,
|
||||
"ws_path": expected_path,
|
||||
"ws_url": ws_url,
|
||||
"expires_in": self.config.token_ttl_s,
|
||||
"model_name": _resolve_bootstrap_model_name(self.runtime_model_name),
|
||||
"runtime_surface": self._runtime_surface,
|
||||
"runtime_capabilities": self._capabilities,
|
||||
}
|
||||
)
|
||||
|
||||
def _bootstrap_ws_url(self, request: Any) -> str:
|
||||
headers = getattr(request, "headers", {}) or {}
|
||||
host = _safe_host_header(_case_insensitive_header(headers, "Host"))
|
||||
if not host:
|
||||
host = _host_for_url(self.config.host, self.config.port)
|
||||
proto = _case_insensitive_header(headers, "X-Forwarded-Proto")
|
||||
proto = proto.split(",", 1)[0].strip().lower()
|
||||
secure = proto in {"https", "wss"} or bool(self.config.ssl_certfile.strip())
|
||||
scheme = "wss" if secure else "ws"
|
||||
expected_path = _normalize_config_path(self.config.path)
|
||||
return f"{scheme}://{host}{expected_path}"
|
||||
|
||||
# -- Session routes -----------------------------------------------------
|
||||
|
||||
def _dispatch_session_routes(self, request: WsRequest, got: str) -> Response | None:
|
||||
m = re.match(r"^/api/sessions/([^/]+)/messages$", got)
|
||||
if m:
|
||||
return self._handle_session_messages(request, m.group(1))
|
||||
|
||||
m = re.match(r"^/api/sessions/([^/]+)/webui-thread$", got)
|
||||
if m:
|
||||
return self._handle_webui_thread_get(request, m.group(1))
|
||||
|
||||
m = re.match(r"^/api/sessions/([^/]+)/delete$", got)
|
||||
if m:
|
||||
return self._handle_session_delete(request, m.group(1))
|
||||
|
||||
return None
|
||||
|
||||
def _handle_sessions_list(self, request: WsRequest) -> Response:
|
||||
if not self.check_api_token(request):
|
||||
return _http_error(401, "Unauthorized")
|
||||
if self.session_manager is None:
|
||||
return _http_error(503, "session manager unavailable")
|
||||
sessions = self.session_manager.list_sessions()
|
||||
from nanobot.session.webui_turns import websocket_turn_wall_started_at
|
||||
|
||||
cleaned = []
|
||||
for s in sessions:
|
||||
key = s.get("key")
|
||||
if not (isinstance(key, str) and key.startswith("websocket:")):
|
||||
continue
|
||||
row = {k: v for k, v in s.items() if k != "path"}
|
||||
chat_id = key.split(":", 1)[1]
|
||||
started_at = websocket_turn_wall_started_at(chat_id)
|
||||
if started_at is not None:
|
||||
row["run_started_at"] = started_at
|
||||
scope = self.workspaces.scope_for_session_key(key)
|
||||
row["workspace_scope"] = scope.payload()
|
||||
cleaned.append(row)
|
||||
return _http_json_response({"sessions": cleaned})
|
||||
|
||||
def _handle_session_messages(self, request: WsRequest, key: str) -> Response:
|
||||
if not self.check_api_token(request):
|
||||
return _http_error(401, "Unauthorized")
|
||||
if self.session_manager is None:
|
||||
return _http_error(503, "session manager unavailable")
|
||||
decoded_key = _decode_api_key(key)
|
||||
if decoded_key is None:
|
||||
return _http_error(400, "invalid session key")
|
||||
if not _is_websocket_channel_session_key(decoded_key):
|
||||
return _http_error(404, "session not found")
|
||||
data = self.session_manager.read_session_file(decoded_key)
|
||||
if data is None:
|
||||
return _http_error(404, "session not found")
|
||||
messages = data.get("messages")
|
||||
if isinstance(messages, list):
|
||||
scrub_subagent_messages_for_channel(messages)
|
||||
self._augment_media_urls(data)
|
||||
return _http_json_response(data)
|
||||
|
||||
def _handle_webui_thread_get(self, request: WsRequest, key: str) -> Response:
|
||||
if not self.check_api_token(request):
|
||||
return _http_error(401, "Unauthorized")
|
||||
decoded_key = _decode_api_key(key)
|
||||
if decoded_key is None:
|
||||
return _http_error(400, "invalid session key")
|
||||
if not _is_websocket_channel_session_key(decoded_key):
|
||||
return _http_error(404, "session not found")
|
||||
scope = self.workspaces.scope_for_session_key(decoded_key)
|
||||
data = build_webui_thread_response(
|
||||
decoded_key,
|
||||
augment_user_media=self._augment_transcript_user_media,
|
||||
augment_assistant_text=lambda text: rewrite_local_markdown_images(
|
||||
text,
|
||||
workspace_path=scope.project_path,
|
||||
sign_path=self.sign_or_stage_media_path,
|
||||
),
|
||||
)
|
||||
if data is None:
|
||||
return _http_error(404, "webui thread not found")
|
||||
data["workspace_scope"] = scope.payload()
|
||||
return _http_json_response(data)
|
||||
|
||||
def _handle_session_delete(self, request: WsRequest, key: str) -> Response:
|
||||
if not self.check_api_token(request):
|
||||
return _http_error(401, "Unauthorized")
|
||||
if self.session_manager is None:
|
||||
return _http_error(503, "session manager unavailable")
|
||||
decoded_key = _decode_api_key(key)
|
||||
if decoded_key is None:
|
||||
return _http_error(400, "invalid session key")
|
||||
if not _is_websocket_channel_session_key(decoded_key):
|
||||
return _http_error(404, "session not found")
|
||||
deleted = self.session_manager.delete_session(decoded_key)
|
||||
delete_webui_thread(decoded_key)
|
||||
return _http_json_response({"deleted": bool(deleted)})
|
||||
|
||||
# -- Media routes -------------------------------------------------------
|
||||
|
||||
def _dispatch_media_routes(self, request: WsRequest, got: str) -> Response | None:
|
||||
m = re.match(r"^/api/media/([A-Za-z0-9_-]+)/([A-Za-z0-9_-]+)$", got)
|
||||
if m:
|
||||
return self._handle_media_fetch(m.group(1), m.group(2), request)
|
||||
return None
|
||||
|
||||
def _handle_media_fetch(
|
||||
self, sig: str, payload: str, request: WsRequest | None = None
|
||||
) -> Response:
|
||||
return serve_signed_media(
|
||||
sig,
|
||||
payload,
|
||||
secret=self.media_secret,
|
||||
request=request,
|
||||
media_dir=lambda channel=None: get_media_dir(channel),
|
||||
)
|
||||
|
||||
def sign_media_path(self, abs_path: Path) -> str | None:
|
||||
return sign_media_path(
|
||||
abs_path,
|
||||
secret=self.media_secret,
|
||||
media_dir=lambda channel=None: get_media_dir(channel),
|
||||
)
|
||||
|
||||
def sign_or_stage_media_path(self, path: Path) -> dict[str, str] | None:
|
||||
return sign_or_stage_media_path(
|
||||
path,
|
||||
secret=self.media_secret,
|
||||
media_dir=lambda channel=None: get_media_dir(channel),
|
||||
logger=self._log,
|
||||
)
|
||||
|
||||
def rewrite_local_markdown_images(self, text: str) -> str:
|
||||
return rewrite_local_markdown_images(
|
||||
text,
|
||||
workspace_path=self.workspace_path,
|
||||
sign_path=self.sign_or_stage_media_path,
|
||||
)
|
||||
|
||||
# -- Misc routes --------------------------------------------------------
|
||||
|
||||
def _dispatch_misc_routes(
|
||||
self, connection: Any, request: WsRequest, got: str
|
||||
) -> Response | None:
|
||||
if got == "/api/sessions":
|
||||
return self._handle_sessions_list(request)
|
||||
if got == "/api/commands":
|
||||
return self._handle_commands(request)
|
||||
if got == "/api/workspaces":
|
||||
return self._handle_workspaces(connection, request)
|
||||
if got == "/api/webui/sidebar-state":
|
||||
return self._handle_webui_sidebar_state(request)
|
||||
if got == "/api/webui/sidebar-state/update":
|
||||
return self._handle_webui_sidebar_state_update(request)
|
||||
return None
|
||||
|
||||
def _handle_commands(self, request: WsRequest) -> Response:
|
||||
if not self.check_api_token(request):
|
||||
return _http_error(401, "Unauthorized")
|
||||
return _http_json_response({"commands": builtin_command_palette()})
|
||||
|
||||
def _handle_workspaces(self, connection: Any, request: WsRequest) -> Response:
|
||||
if not self.check_api_token(request):
|
||||
return _http_error(401, "Unauthorized")
|
||||
return _http_json_response(
|
||||
self.workspaces.payload(controls_available=_is_localhost(connection))
|
||||
)
|
||||
|
||||
def _handle_webui_sidebar_state(self, request: WsRequest) -> Response:
|
||||
if not self.check_api_token(request):
|
||||
return _http_error(401, "Unauthorized")
|
||||
return _http_json_response(read_webui_sidebar_state())
|
||||
|
||||
def _handle_webui_sidebar_state_update(self, request: WsRequest) -> Response:
|
||||
if not self.check_api_token(request):
|
||||
return _http_error(401, "Unauthorized")
|
||||
query = _parse_query(request.path)
|
||||
raw_state = _query_first(query, "state")
|
||||
if raw_state is None:
|
||||
return _http_error(400, "missing state")
|
||||
try:
|
||||
decoded = json.loads(raw_state)
|
||||
except json.JSONDecodeError:
|
||||
return _http_error(400, "state must be JSON")
|
||||
if not isinstance(decoded, dict):
|
||||
return _http_error(400, "state must be an object")
|
||||
try:
|
||||
state = write_webui_sidebar_state(decoded)
|
||||
except ValueError as e:
|
||||
return _http_error(400, str(e))
|
||||
except OSError:
|
||||
self._log.exception("failed to write webui sidebar state")
|
||||
return _http_error(500, "failed to write sidebar state")
|
||||
return _http_json_response(state)
|
||||
|
||||
# -- Static file serving ------------------------------------------------
|
||||
|
||||
def _serve_static(self, request_path: str) -> Response | None:
|
||||
assert self.static_dist_path is not None
|
||||
rel = request_path.lstrip("/")
|
||||
if not rel:
|
||||
rel = "index.html"
|
||||
if ".." in rel.split("/") or rel.startswith("/"):
|
||||
return _http_error(403, "Forbidden")
|
||||
candidate = (self.static_dist_path / rel).resolve()
|
||||
try:
|
||||
candidate.relative_to(self.static_dist_path)
|
||||
except ValueError:
|
||||
return _http_error(403, "Forbidden")
|
||||
if not candidate.is_file():
|
||||
index = self.static_dist_path / "index.html"
|
||||
if index.is_file():
|
||||
candidate = index
|
||||
else:
|
||||
return None
|
||||
try:
|
||||
body = candidate.read_bytes()
|
||||
except OSError as e:
|
||||
self._log.warning("static: failed to read {}: {}", candidate, e)
|
||||
return _http_error(500, "Internal Server Error")
|
||||
ctype, _ = mimetypes.guess_type(candidate.name)
|
||||
if ctype is None:
|
||||
ctype = "application/octet-stream"
|
||||
if ctype.startswith("text/") or ctype in {"application/javascript", "application/json"}:
|
||||
ctype = f"{ctype}; charset=utf-8"
|
||||
if candidate.name == "index.html":
|
||||
cache = "no-cache"
|
||||
else:
|
||||
cache = "public, max-age=31536000, immutable"
|
||||
return _http_response(
|
||||
body,
|
||||
status=200,
|
||||
content_type=ctype,
|
||||
extra_headers=[("Cache-Control", cache)],
|
||||
)
|
||||
|
||||
# -- Media helpers (called by WebSocketChannel.send) --------------------
|
||||
|
||||
def _augment_media_urls(self, payload: dict[str, Any]) -> None:
|
||||
messages = payload.get("messages")
|
||||
if not isinstance(messages, list):
|
||||
return
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
media = msg.get("media")
|
||||
if not isinstance(media, list) or not media:
|
||||
continue
|
||||
urls: list[dict[str, str]] = []
|
||||
for entry in media:
|
||||
if not isinstance(entry, str) or not entry:
|
||||
continue
|
||||
signed = self.sign_media_path(Path(entry))
|
||||
if signed is None:
|
||||
continue
|
||||
urls.append({"url": signed, "name": Path(entry).name})
|
||||
if urls:
|
||||
msg["media_urls"] = urls
|
||||
msg.pop("media", None)
|
||||
|
||||
def _augment_transcript_user_media(self, paths: list[str]) -> list[dict[str, Any]]:
|
||||
out: list[dict[str, Any]] = []
|
||||
for pstr in paths:
|
||||
path = Path(pstr)
|
||||
att = self.sign_or_stage_media_path(path)
|
||||
if att is None:
|
||||
continue
|
||||
mime, _ = mimetypes.guess_type(path.name)
|
||||
kind = "video" if mime and mime.startswith("video/") else "image"
|
||||
out.append(
|
||||
{"kind": kind, "url": att["url"], "name": att.get("name", path.name)},
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def _is_websocket_channel_session_key(key: str) -> bool:
|
||||
return key.startswith("websocket:")
|
||||
@ -626,6 +626,7 @@ async def test_send_stages_external_media_as_signed_url(monkeypatch, tmp_path) -
|
||||
return ws_media if channel == "websocket" else media_root
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir)
|
||||
monkeypatch.setattr("nanobot.channels.ws_http.get_media_dir", fake_media_dir)
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
|
||||
mock_ws = AsyncMock()
|
||||
channel._attach(mock_ws, "chat-1")
|
||||
@ -840,6 +841,7 @@ async def test_send_delta_stream_end_rewrites_local_markdown_image(monkeypatch,
|
||||
return path
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir)
|
||||
monkeypatch.setattr("nanobot.channels.ws_http.get_media_dir", fake_media_dir)
|
||||
channel = WebSocketChannel(
|
||||
{"enabled": True, "allowFrom": ["*"], "streaming": True},
|
||||
bus,
|
||||
@ -872,6 +874,7 @@ async def test_send_delta_stream_end_rewrites_inline_final_text(monkeypatch, tmp
|
||||
return path
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir)
|
||||
monkeypatch.setattr("nanobot.channels.ws_http.get_media_dir", fake_media_dir)
|
||||
channel = WebSocketChannel(
|
||||
{"enabled": True, "allowFrom": ["*"], "streaming": True},
|
||||
bus,
|
||||
|
||||
@ -710,20 +710,20 @@ async def test_api_token_pool_purges_expired(bus: MagicMock, tmp_path: Path) ->
|
||||
channel = _ch(bus, session_manager=sm, port=29908)
|
||||
# Don't start a server — directly inject and validate.
|
||||
import time as _time
|
||||
channel._api_tokens["expired"] = _time.monotonic() - 1
|
||||
channel._api_tokens["live"] = _time.monotonic() + 60
|
||||
channel._http.api_tokens["expired"] = _time.monotonic() - 1
|
||||
channel._http.api_tokens["live"] = _time.monotonic() + 60
|
||||
|
||||
class _FakeReq:
|
||||
path = "/api/sessions"
|
||||
headers = {"Authorization": "Bearer expired"}
|
||||
|
||||
assert channel._check_api_token(_FakeReq()) is False
|
||||
assert channel._http.check_api_token(_FakeReq()) is False
|
||||
|
||||
class _LiveReq:
|
||||
path = "/api/sessions"
|
||||
headers = {"Authorization": "Bearer live"}
|
||||
|
||||
assert channel._check_api_token(_LiveReq()) is True
|
||||
assert channel._http.check_api_token(_LiveReq()) is True
|
||||
|
||||
|
||||
class _FakeConn:
|
||||
@ -814,7 +814,7 @@ def test_localhost_without_auth_is_valid(bus: MagicMock) -> None:
|
||||
|
||||
def test_bootstrap_prefers_runtime_model_name(bus: MagicMock, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.websocket._default_model_name_from_config",
|
||||
"nanobot.channels.ws_http._default_model_name_from_config",
|
||||
lambda: "from-disk",
|
||||
)
|
||||
channel = _ch(bus, host="127.0.0.1", runtime_model_name=lambda: " live/model ")
|
||||
@ -826,7 +826,7 @@ def test_bootstrap_prefers_runtime_model_name(bus: MagicMock, monkeypatch: pytes
|
||||
|
||||
def test_bootstrap_falls_back_when_runtime_returns_empty(bus: MagicMock, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.websocket._default_model_name_from_config",
|
||||
"nanobot.channels.ws_http._default_model_name_from_config",
|
||||
lambda: "from-disk",
|
||||
)
|
||||
channel = _ch(bus, host="127.0.0.1", runtime_model_name=lambda: " ")
|
||||
@ -838,7 +838,7 @@ def test_bootstrap_falls_back_when_runtime_returns_empty(bus: MagicMock, monkeyp
|
||||
|
||||
def test_bootstrap_falls_back_when_runtime_raises(bus: MagicMock, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.websocket._default_model_name_from_config",
|
||||
"nanobot.channels.ws_http._default_model_name_from_config",
|
||||
lambda: "from-disk",
|
||||
)
|
||||
|
||||
|
||||
@ -106,7 +106,7 @@ def test_sign_media_path_rejects_paths_outside_media_root(
|
||||
media = tmp_path / "media"
|
||||
media.mkdir()
|
||||
channel = _ch(bus, port=0)
|
||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
||||
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||
assert channel._sign_media_path(outside) is None
|
||||
# Traversal via the media root is also rejected — the resolve() step
|
||||
# normalises ``..`` out before the relative_to check.
|
||||
@ -121,7 +121,7 @@ def test_sign_media_path_round_trips_via_hmac(
|
||||
media.mkdir()
|
||||
(media / "a.png").write_bytes(_PNG_BYTES)
|
||||
channel = _ch(bus, port=0)
|
||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
||||
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||
url = channel._sign_media_path(media / "a.png")
|
||||
assert url is not None
|
||||
assert url.startswith("/api/media/")
|
||||
@ -144,7 +144,7 @@ def test_local_markdown_image_is_staged_and_rewritten(
|
||||
media = tmp_path / "media"
|
||||
channel = _ch(bus, workspace_path=workspace, port=0)
|
||||
|
||||
with patch("nanobot.channels.websocket.get_media_dir", side_effect=_fake_media_dir(media)):
|
||||
with patch("nanobot.channels.ws_http.get_media_dir", side_effect=_fake_media_dir(media)):
|
||||
rewritten = channel._rewrite_local_markdown_images(
|
||||
"The result:\n"
|
||||
)
|
||||
@ -166,7 +166,7 @@ def test_local_markdown_video_is_staged_and_rewritten(
|
||||
media = tmp_path / "media"
|
||||
channel = _ch(bus, workspace_path=workspace, port=0)
|
||||
|
||||
with patch("nanobot.channels.websocket.get_media_dir", side_effect=_fake_media_dir(media)):
|
||||
with patch("nanobot.channels.ws_http.get_media_dir", side_effect=_fake_media_dir(media)):
|
||||
rewritten = channel._rewrite_local_markdown_images(
|
||||
"The result:\n"
|
||||
)
|
||||
@ -189,7 +189,7 @@ def test_local_markdown_image_rejects_workspace_escape(
|
||||
channel = _ch(bus, workspace_path=workspace, port=0)
|
||||
text = ""
|
||||
|
||||
with patch("nanobot.channels.websocket.get_media_dir", side_effect=_fake_media_dir(media)):
|
||||
with patch("nanobot.channels.ws_http.get_media_dir", side_effect=_fake_media_dir(media)):
|
||||
assert channel._rewrite_local_markdown_images(text) == text
|
||||
|
||||
assert not (media / "websocket").exists()
|
||||
@ -211,7 +211,7 @@ async def test_media_route_serves_signed_file(
|
||||
target.write_bytes(_PNG_BYTES)
|
||||
|
||||
channel = _ch(bus, port=29920)
|
||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
||||
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||
url_path = channel._sign_media_path(target)
|
||||
assert url_path is not None
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
@ -244,7 +244,7 @@ async def test_media_route_serves_video_byte_ranges(
|
||||
target.write_bytes(b"0123456789")
|
||||
|
||||
channel = _ch(bus, port=29927)
|
||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
||||
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||
url_path = channel._sign_media_path(target)
|
||||
assert url_path is not None
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
@ -276,7 +276,7 @@ async def test_media_route_serves_suffix_video_byte_ranges(
|
||||
target.write_bytes(b"0123456789")
|
||||
|
||||
channel = _ch(bus, port=29928)
|
||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
||||
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||
url_path = channel._sign_media_path(target)
|
||||
assert url_path is not None
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
@ -305,7 +305,7 @@ async def test_media_route_rejects_unsatisfiable_byte_range(
|
||||
target.write_bytes(b"0123456789")
|
||||
|
||||
channel = _ch(bus, port=29929)
|
||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
||||
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||
url_path = channel._sign_media_path(target)
|
||||
assert url_path is not None
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
@ -338,7 +338,7 @@ async def test_media_route_rejects_bad_signature(
|
||||
(media / "f.png").write_bytes(_PNG_BYTES)
|
||||
|
||||
channel = _ch(bus, port=29921)
|
||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
||||
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||
good = channel._sign_media_path(media / "f.png")
|
||||
assert good is not None
|
||||
_, payload = good[len("/api/media/"):].split("/", 1)
|
||||
@ -381,7 +381,7 @@ async def test_media_route_rejects_path_traversal_payload(
|
||||
).digest()[:16]
|
||||
url = f"/api/media/{b64url_encode(mac)}/{payload}"
|
||||
|
||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
||||
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
@ -405,7 +405,7 @@ async def test_media_route_404s_missing_file(
|
||||
target.write_bytes(_PNG_BYTES)
|
||||
|
||||
channel = _ch(bus, port=29923)
|
||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
||||
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||
url_path = channel._sign_media_path(target)
|
||||
assert url_path is not None
|
||||
target.unlink() # the file vanishes between signing and fetching
|
||||
@ -433,7 +433,7 @@ async def test_media_route_degrades_non_image_to_octet_stream(
|
||||
(media / "scary.html").write_bytes(b"<script>alert(1)</script>")
|
||||
|
||||
channel = _ch(bus, port=29924)
|
||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
||||
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||
payload = b64url_encode(b"scary.html")
|
||||
mac = hmac.new(
|
||||
channel._media_secret, payload.encode("ascii"), hashlib.sha256
|
||||
@ -464,7 +464,7 @@ async def test_media_route_serves_svg_with_strict_csp(
|
||||
target.write_text("<svg xmlns='http://www.w3.org/2000/svg'><script>alert(1)</script></svg>")
|
||||
|
||||
channel = _ch(bus, port=29928)
|
||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
||||
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||
url_path = channel._sign_media_path(target)
|
||||
assert url_path is not None
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
@ -505,7 +505,7 @@ async def test_session_messages_exposes_signed_media_urls(
|
||||
sm.save(sess)
|
||||
|
||||
channel = _ch(bus, session_manager=sm, port=29925)
|
||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
||||
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
@ -550,7 +550,7 @@ async def test_session_messages_skips_vanished_media(
|
||||
sm.save(sess)
|
||||
|
||||
channel = _ch(bus, session_manager=sm, port=29926)
|
||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
||||
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user