refactor: extract GatewayHTTPHandler from WebSocketChannel

Extract all HTTP route handling (bootstrap, sessions, settings, media,
commands, sidebar state, static serving, token management) into a new
GatewayHTTPHandler class in nanobot/channels/ws_http.py.

WebSocketChannel is reduced from 1907 to 1372 lines (-28%), retaining
only WebSocket connection management and message dispatch.

No behavior change. 3730 tests pass, 0 failures.

Shared HTTP utility functions (path parsing, response builders, auth
helpers) now live in ws_http.py with websocket.py importing from there,
avoiding circular dependencies.

Backwards-compat property aliases on WebSocketChannel ensure existing
tests continue to work without modification.
This commit is contained in:
chengyongru 2026-05-31 12:55:23 +08:00 committed by Xubin Ren
parent 92fe40a690
commit 1a585288b2
5 changed files with 859 additions and 628 deletions

View File

@ -7,11 +7,8 @@ import email.utils
import hmac
import http
import json
import mimetypes
import re
import secrets
import ssl
import time
import uuid
from collections.abc import Callable
from contextlib import suppress
@ -31,7 +28,7 @@ from websockets.http11 import Response
from nanobot.bus.events import OUTBOUND_META_AGENT_UI, OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.command.builtin import builtin_command_palette
from nanobot.channels.ws_http import GatewayHTTPHandler
from nanobot.config.paths import get_media_dir, get_workspace_path
from nanobot.config.schema import Base
from nanobot.security.workspace_access import (
@ -44,32 +41,9 @@ from nanobot.utils.media_decode import (
FileSizeExceeded,
save_base64_data_url,
)
from nanobot.utils.subagent_channel_display import scrub_subagent_messages_for_channel
from nanobot.webui.cli_apps_api import normalize_cli_app_mentions
from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions
from nanobot.webui.media_api import (
attach_signed_media_urls,
serve_signed_media,
sign_media_path,
sign_or_stage_media_path,
signed_media_attachments,
)
from nanobot.webui.settings_api import runtime_capabilities
from nanobot.webui.settings_routes import WebUISettingsRouter
from nanobot.webui.sidebar_state import (
read_webui_sidebar_state,
write_webui_sidebar_state,
)
from nanobot.webui.thread_disk import delete_webui_thread
from nanobot.webui.transcript import (
append_transcript_object,
build_webui_thread_response,
rewrite_local_markdown_images,
)
from nanobot.webui.websocket_logging import websockets_server_logger
from nanobot.webui.workspaces import (
WebUIWorkspaceController,
)
from nanobot.webui.transcript import append_transcript_object
if TYPE_CHECKING:
from nanobot.session.manager import SessionManager
@ -500,51 +474,36 @@ class WebSocketChannel(BaseChannel):
self._conn_chats: dict[Any, set[str]] = {}
# connection -> default chat_id for legacy frames that omit routing.
self._conn_default: dict[Any, str] = {}
# Single-use tokens consumed at WebSocket handshake.
self._issued_tokens: dict[str, float] = {}
# Multi-use tokens for HTTP routes served beside WS; checked but not consumed.
self._api_tokens: dict[str, float] = {}
self._stop_event: asyncio.Event | None = None
self._server_task: asyncio.Task[None] | None = None
self._session_manager = session_manager
self._static_dist_path: Path | None = (
static_dist_path.resolve() if static_dist_path is not None else None
)
self._workspace_path = (
_resolved_workspace = (
Path(workspace_path).expanduser()
if workspace_path is not None
else get_workspace_path()
).resolve(strict=False)
self._default_restrict_to_workspace = restrict_to_workspace
self._webui_workspaces = WebUIWorkspaceController(
session_manager=self._session_manager,
default_workspace=self._workspace_path,
default_restrict_to_workspace=self._default_restrict_to_workspace,
)
self._runtime_model_name = runtime_model_name
self._runtime_surface = (
"native" if runtime_surface in {"native", "desktop"} else "browser"
)
self._runtime_capabilities = runtime_capabilities(
self._runtime_surface,
runtime_capabilities_overrides,
)
self._settings_routes = WebUISettingsRouter(
bus=self.bus,
logger=self.logger,
check_api_token=self._check_api_token,
parse_query=_parse_query,
json_response=_http_json_response,
error_response=_http_error,
# HTTP API handler — owns tokens, sessions, media, settings, static serving
self._http = GatewayHTTPHandler(
config=self.config,
session_manager=session_manager,
static_dist_path=(
static_dist_path.resolve() if static_dist_path is not None else None
),
workspace_path=_resolved_workspace,
runtime_model_name=runtime_model_name,
runtime_surface=self._runtime_surface,
runtime_capabilities=self._runtime_capabilities,
runtime_capabilities_overrides=runtime_capabilities_overrides,
bus=self.bus,
log=self.logger,
)
# Backwards-compat aliases for workspace controller used in envelope dispatch
self._webui_workspaces = self._http.workspaces
self._stream_text_buffers: dict[tuple[str, str], list[str]] = {}
# Process-local secret used to HMAC-sign media URLs. The signed URL is
# the capability — anyone who holds a valid URL can fetch that one
# file, nothing else. The secret regenerates on restart so links
# become self-expiring (callers just refresh the session list).
self._media_secret: bytes = secrets.token_bytes(32)
# -- Subscription bookkeeping -------------------------------------------
@ -572,9 +531,9 @@ class WebSocketChannel(BaseChannel):
connected clients normally see it via ``goal_state`` / ``turn_end`` frames.
Pushing here makes refresh + reconnect restore the strip without a new model turn.
"""
if self._session_manager is None:
if self._http.session_manager is None:
return
row = self._session_manager.read_session_file(f"websocket:{chat_id}")
row = self._http.session_manager.read_session_file(f"websocket:{chat_id}")
meta = row.get("metadata", {}) if isinstance(row, dict) else {}
if not isinstance(meta, dict):
meta = {}
@ -614,6 +573,59 @@ class WebSocketChannel(BaseChannel):
def _expected_path(self) -> str:
return _normalize_config_path(self.config.path)
# -- Backwards-compat property aliases (used by tests) ------------------
@property
def _session_manager(self):
return self._http.session_manager
@_session_manager.setter
def _session_manager(self, value):
self._http.session_manager = value
@property
def _media_secret(self):
return self._http.media_secret
@property
def _issued_tokens(self):
return self._http.issued_tokens
@_issued_tokens.setter
def _issued_tokens(self, value):
self._http.issued_tokens = value
@property
def _api_tokens(self):
return self._http.api_tokens
def _check_api_token(self, request):
return self._http.check_api_token(request)
def _sign_media_path(self, path):
return self._http.sign_media_path(path)
def _sign_or_stage_media_path(self, path):
return self._http.sign_or_stage_media_path(path)
def _rewrite_local_markdown_images(self, text):
return self._http.rewrite_local_markdown_images(text)
def _handle_bootstrap(self, connection, request):
return self._http._handle_bootstrap(connection, request)
_MAX_ISSUED_TOKENS = GatewayHTTPHandler._MAX_ISSUED_TOKENS
def _handle_sessions_list(self, request):
return self._http._handle_sessions_list(request)
def _handle_webui_thread_get(self, request, key):
return self._http._handle_webui_thread_get(request, key)
@property
def _workspace_path(self):
return self._http.workspace_path
def _build_ssl_context(self) -> ssl.SSLContext | None:
cert = self.config.ssl_certfile.strip()
key = self.config.ssl_keyfile.strip()
@ -628,548 +640,24 @@ class WebSocketChannel(BaseChannel):
ctx.load_cert_chain(certfile=cert, keyfile=key)
return ctx
_MAX_ISSUED_TOKENS = 10_000
def _purge_expired_issued_tokens(self) -> None:
now = time.monotonic()
for token_key, expiry in list(self._issued_tokens.items()):
if now > expiry:
self._issued_tokens.pop(token_key, None)
def _take_issued_token_if_valid(self, token_value: str | None) -> bool:
"""Validate and consume one issued token (single use per connection attempt).
Uses single-step pop to minimize the window between lookup and removal;
safe under asyncio's single-threaded cooperative model.
"""
if not token_value:
return False
self._purge_expired_issued_tokens()
expiry = self._issued_tokens.pop(token_value, None)
if expiry is None:
return False
if time.monotonic() > expiry:
return False
return True
def _handle_token_issue_http(self, connection: Any, request: Any) -> Any:
secret = self.config.token_issue_secret.strip() or self.config.token.strip()
if secret:
if not _issue_route_secret_matches(request.headers, secret):
return connection.respond(401, "Unauthorized")
else:
self.logger.warning(
"token_issue_path is set but no token_issue_secret or static token is configured; "
"any client can obtain connection tokens — set a secret for production."
)
self._purge_expired_issued_tokens()
if len(self._issued_tokens) >= self._MAX_ISSUED_TOKENS:
self.logger.error(
"too many outstanding issued tokens ({}), rejecting issuance",
len(self._issued_tokens),
)
return _http_json_response({"error": "too many outstanding tokens"}, status=429)
token_value = f"nbwt_{secrets.token_urlsafe(32)}"
self._issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s)
return _http_json_response(
{"token": token_value, "expires_in": self.config.token_ttl_s}
)
# -- HTTP dispatch ------------------------------------------------------
async def _dispatch_http(self, connection: Any, request: WsRequest) -> Any:
"""Route an inbound HTTP request to a handler or to the WS upgrade path."""
"""Route an inbound HTTP request to the HTTP handler or WS upgrade."""
got, query = _parse_request_path(request.path)
if self.config.token_issue_path:
issue_expected = _normalize_config_path(self.config.token_issue_path)
if got == issue_expected:
return self._handle_token_issue_http(connection, request)
if got == "/webui/bootstrap":
return self._handle_bootstrap(connection, request)
api_response = await self._dispatch_api_route(connection, request, got)
if api_response is not None:
return api_response
ws_matched, ws_response = self._dispatch_websocket_upgrade(
connection, request, got, query
)
if ws_matched:
return ws_response
# API clients should never receive the SPA shell for an unknown route.
# Returning HTML here makes the WebUI fail with "Unexpected token <"
# when a dev server is pointed at an older gateway.
if got.startswith("/api/"):
return _http_error(404, "API route not found")
if self._static_dist_path is not None:
response = self._serve_static(got)
if response is not None:
return response
return connection.respond(404, "Not Found")
async def _dispatch_api_route(
self,
connection: Any,
request: WsRequest,
got: str,
) -> Any | None:
"""Route REST-ish WebUI requests served beside the WebSocket endpoint."""
response = await self._dispatch_settings_api_route(request, got)
if response is not None:
return response
response = self._dispatch_session_api_route(request, got)
if response is not None:
return response
response = self._dispatch_media_api_route(request, got)
if response is not None:
return response
return self._dispatch_misc_api_route(connection, request, got)
def _dispatch_misc_api_route(
self,
connection: Any,
request: WsRequest,
got: str,
) -> Response | None:
"""Route small API endpoints that do not belong to a larger route group."""
if got == "/api/sessions":
return self._handle_sessions_list(request)
if got == "/api/commands":
return self._handle_commands(request)
if got == "/api/workspaces":
return self._handle_workspaces(connection, request)
if got == "/api/webui/sidebar-state":
return self._handle_webui_sidebar_state(request)
if got == "/api/webui/sidebar-state/update":
return self._handle_webui_sidebar_state_update(request)
return None
async def _dispatch_settings_api_route(
self,
request: WsRequest,
got: str,
) -> Response | None:
return await self._settings_routes.dispatch(request, got)
def _dispatch_session_api_route(
self,
request: WsRequest,
got: str,
) -> Response | None:
m = re.match(r"^/api/sessions/([^/]+)/messages$", got)
if m:
return self._handle_session_messages(request, m.group(1))
m = re.match(r"^/api/sessions/([^/]+)/webui-thread$", got)
if m:
return self._handle_webui_thread_get(request, m.group(1))
# NOTE: websockets' HTTP parser only accepts GET, so we cannot expose a
# true ``DELETE`` verb. The action is folded into the path instead.
m = re.match(r"^/api/sessions/([^/]+)/delete$", got)
if m:
return self._handle_session_delete(request, m.group(1))
return None
def _dispatch_media_api_route(
self,
request: WsRequest,
got: str,
) -> Response | None:
m = re.match(r"^/api/media/([A-Za-z0-9_-]+)/([A-Za-z0-9_-]+)$", got)
if m:
return self._handle_media_fetch(m.group(1), m.group(2), request)
return None
def _dispatch_websocket_upgrade(
self,
connection: Any,
request: WsRequest,
got: str,
query: dict[str, list[str]],
) -> tuple[bool, Any | None]:
"""Authorize only real WS upgrade requests for the configured path."""
# WebSocket upgrade — channel handles this itself
expected_ws = self._expected_path()
if got != expected_ws or not _is_websocket_upgrade(request):
return False, None
client_id = _query_first(query, "client_id") or ""
if len(client_id) > 128:
client_id = client_id[:128]
if not self.is_allowed(client_id):
return True, connection.respond(403, "Forbidden")
return True, self._authorize_websocket_handshake(connection, query)
if got == expected_ws and _is_websocket_upgrade(request):
client_id = _query_first(query, "client_id") or ""
if len(client_id) > 128:
client_id = client_id[:128]
if not self.is_allowed(client_id):
return connection.respond(403, "Forbidden")
return self._authorize_websocket_handshake(connection, query)
# -- HTTP route handlers ------------------------------------------------
def _check_api_token(self, request: WsRequest) -> bool:
"""Validate a request against the API token pool (multi-use, TTL-bound)."""
self._purge_expired_api_tokens()
token = _bearer_token(request.headers) or _query_first(
_parse_query(request.path), "token"
)
if not token:
return False
expiry = self._api_tokens.get(token)
if expiry is None or time.monotonic() > expiry:
self._api_tokens.pop(token, None)
return False
return True
def _purge_expired_api_tokens(self) -> None:
now = time.monotonic()
for token_key, expiry in list(self._api_tokens.items()):
if now > expiry:
self._api_tokens.pop(token_key, None)
def _handle_bootstrap(self, connection: Any, request: Any) -> Response:
# When a secret is configured (token_issue_secret or static token),
# validate it regardless of source IP. This secures deployments
# behind a reverse proxy where all connections appear as localhost.
secret = self.config.token_issue_secret.strip() or self.config.token.strip()
if secret:
if not _issue_route_secret_matches(request.headers, secret):
return _http_error(401, "Unauthorized")
elif not _is_localhost(connection):
# No secret configured: only allow localhost (local dev mode).
return _http_error(403, "bootstrap is localhost-only")
# Cap outstanding tokens to avoid runaway growth from a misbehaving client.
self._purge_expired_issued_tokens()
self._purge_expired_api_tokens()
if (
len(self._issued_tokens) >= self._MAX_ISSUED_TOKENS
or len(self._api_tokens) >= self._MAX_ISSUED_TOKENS
):
return _http_response(
json.dumps({"error": "too many outstanding tokens"}).encode("utf-8"),
status=429,
content_type="application/json; charset=utf-8",
)
token = f"nbwt_{secrets.token_urlsafe(32)}"
expiry = time.monotonic() + float(self.config.token_ttl_s)
# Same string registered in both pools: the WS handshake consumes one copy
# while the REST surface keeps validating the other until TTL expiry.
self._issued_tokens[token] = expiry
self._api_tokens[token] = expiry
ws_url = self._bootstrap_ws_url(request)
return _http_json_response(
{
"token": token,
"ws_path": self._expected_path(),
"ws_url": ws_url,
"expires_in": self.config.token_ttl_s,
"model_name": _resolve_bootstrap_model_name(self._runtime_model_name),
"runtime_surface": self._runtime_surface,
"runtime_capabilities": self._runtime_capabilities,
}
)
def _bootstrap_ws_url(self, request: Any) -> str:
"""Absolute WS URL clients should prefer over a dev-server proxy."""
headers = getattr(request, "headers", {}) or {}
host = _safe_host_header(_case_insensitive_header(headers, "Host"))
if not host:
host = _host_for_url(self.config.host, self.config.port)
proto = _case_insensitive_header(headers, "X-Forwarded-Proto")
proto = proto.split(",", 1)[0].strip().lower()
secure = proto in {"https", "wss"} or bool(self.config.ssl_certfile.strip())
scheme = "wss" if secure else "ws"
return f"{scheme}://{host}{self._expected_path()}"
def _handle_sessions_list(self, request: WsRequest) -> Response:
if not self._check_api_token(request):
return _http_error(401, "Unauthorized")
if self._session_manager is None:
return _http_error(503, "session manager unavailable")
sessions = self._session_manager.list_sessions()
# Sidebar/chat listing for WS-backed sessions only — CLI / Slack / etc.
# keys are not intended for resume over this HTTP surface.
cleaned = []
for s in sessions:
key = s.get("key")
if not (isinstance(key, str) and key.startswith("websocket:")):
continue
row = {k: v for k, v in s.items() if k != "path"}
chat_id = key.split(":", 1)[1]
started_at = websocket_turn_wall_started_at(chat_id)
if started_at is not None:
row["run_started_at"] = started_at
scope = self._webui_workspaces.scope_for_session_key(key)
row["workspace_scope"] = scope.payload()
cleaned.append(row)
return _http_json_response({"sessions": cleaned})
def _handle_workspaces(self, connection: Any, request: WsRequest) -> Response:
if not self._check_api_token(request):
return _http_error(401, "Unauthorized")
return _http_json_response(
self._webui_workspaces.payload(controls_available=_is_localhost(connection))
)
def _handle_commands(self, request: WsRequest) -> Response:
if not self._check_api_token(request):
return _http_error(401, "Unauthorized")
return _http_json_response({"commands": builtin_command_palette()})
def _handle_webui_sidebar_state(self, request: WsRequest) -> Response:
if not self._check_api_token(request):
return _http_error(401, "Unauthorized")
return _http_json_response(read_webui_sidebar_state())
def _handle_webui_sidebar_state_update(self, request: WsRequest) -> Response:
if not self._check_api_token(request):
return _http_error(401, "Unauthorized")
query = _parse_query(request.path)
raw_state = _query_first(query, "state")
if raw_state is None:
return _http_error(400, "missing state")
try:
decoded = json.loads(raw_state)
except json.JSONDecodeError:
return _http_error(400, "state must be JSON")
if not isinstance(decoded, dict):
return _http_error(400, "state must be an object")
try:
state = write_webui_sidebar_state(decoded)
except ValueError as e:
return _http_error(400, str(e))
except OSError:
self.logger.exception("failed to write webui sidebar state")
return _http_error(500, "failed to write sidebar state")
return _http_json_response(state)
# -- Session replay, transcript, and signed media ----------------------
@staticmethod
def _is_websocket_channel_session_key(key: str) -> bool:
"""True when *key* is a ``websocket:…`` session exposed on this HTTP surface."""
return key.startswith("websocket:")
def _handle_session_messages(self, request: WsRequest, key: str) -> Response:
if not self._check_api_token(request):
return _http_error(401, "Unauthorized")
if self._session_manager is None:
return _http_error(503, "session manager unavailable")
decoded_key = _decode_api_key(key)
if decoded_key is None:
return _http_error(400, "invalid session key")
# Only ``websocket:…`` sessions are listed/served here — same boundary as
# ``/api/sessions``. Block handcrafted URLs from probing CLI / Slack / etc.
if not self._is_websocket_channel_session_key(decoded_key):
return _http_error(404, "session not found")
data = self._session_manager.read_session_file(decoded_key)
if data is None:
return _http_error(404, "session not found")
messages = data.get("messages")
if isinstance(messages, list):
scrub_subagent_messages_for_channel(messages)
# Decorate persisted user messages with signed media URLs so the
# client can render previews. The raw on-disk ``media`` paths are
# stripped on the way out — they leak server filesystem layout and
# the client never needs them once it has the signed fetch URL.
attach_signed_media_urls(data, sign_path=self._sign_media_path)
return _http_json_response(data)
def _handle_webui_thread_get(self, request: WsRequest, key: str) -> Response:
if not self._check_api_token(request):
return _http_error(401, "Unauthorized")
decoded_key = _decode_api_key(key)
if decoded_key is None:
return _http_error(400, "invalid session key")
if not self._is_websocket_channel_session_key(decoded_key):
return _http_error(404, "session not found")
scope = self._webui_workspaces.scope_for_session_key(decoded_key)
augment_media = partial(
signed_media_attachments,
sign_path=self._sign_or_stage_media_path,
)
data = build_webui_thread_response(
decoded_key,
augment_user_media=augment_media,
augment_assistant_media=augment_media,
augment_assistant_text=lambda text: rewrite_local_markdown_images(
text,
workspace_path=scope.project_path,
sign_path=self._sign_or_stage_media_path,
),
)
if data is None:
return _http_error(404, "webui thread not found")
data["workspace_scope"] = scope.payload()
return _http_json_response(data)
def _try_append_webui_transcript(self, chat_id: str, wire: dict[str, Any]) -> None:
sk = f"websocket:{chat_id}"
try:
dup = json.loads(json.dumps(wire, ensure_ascii=False))
append_transcript_object(sk, dup)
except (ValueError, TypeError) as e:
self.logger.warning("webui transcript append failed: {}", e)
async def _handle_message(
self,
sender_id: str,
chat_id: str,
content: str,
media: list[str] | None = None,
metadata: dict[str, Any] | None = None,
session_key: str | None = None,
is_dm: bool = False,
) -> None:
meta = metadata or {}
if meta.get("webui"):
user_obj: dict[str, Any] = {
"event": "user",
"chat_id": chat_id,
"text": content,
}
if media:
user_obj["media_paths"] = list(media)
cli_apps = meta.get("cli_apps")
if isinstance(cli_apps, list) and cli_apps:
user_obj["cli_apps"] = cli_apps
mcp_presets = meta.get("mcp_presets")
if isinstance(mcp_presets, list) and mcp_presets:
user_obj["mcp_presets"] = mcp_presets
self._try_append_webui_transcript(chat_id, user_obj)
await super()._handle_message(
sender_id,
chat_id,
content,
media,
metadata,
session_key,
is_dm,
)
def _sign_media_path(self, abs_path: Path) -> str | None:
"""Return a ``/api/media/<sig>/<payload>`` URL for *abs_path*, or
``None`` when the path does not resolve inside the media root.
The URL is self-authenticating: the signature binds the payload to
this process's ``_media_secret``, so only paths we chose to sign can
be fetched. The returned path is relative to the server origin; the
client joins it against this server's HTTP origin (same host as WS).
"""
return sign_media_path(
abs_path,
secret=self._media_secret,
media_dir=lambda channel=None: get_media_dir(channel),
)
def _sign_or_stage_media_path(self, path: Path) -> dict[str, str] | None:
"""Return a signed media URL payload for *path*.
Persisted inbound media already lives under ``get_media_dir`` and can
be signed directly. Outbound bot-generated files may live anywhere on
disk; copy those into the websocket media bucket first so the browser
can fetch them through the existing signed media route without
exposing arbitrary filesystem paths.
"""
return sign_or_stage_media_path(
path,
secret=self._media_secret,
media_dir=lambda channel=None: get_media_dir(channel),
logger=self.logger,
)
def _rewrite_local_markdown_images(self, text: str) -> str:
return rewrite_local_markdown_images(
text,
workspace_path=self._workspace_path,
sign_path=self._sign_or_stage_media_path,
)
def _handle_media_fetch(
self, sig: str, payload: str, request: WsRequest | None = None
) -> Response:
"""Serve a single media file previously signed via
:meth:`_sign_media_path`. Validates the signature, decodes the
payload to a relative path, and streams the file bytes with a
long-lived immutable cache header (the URL already encodes the
file identity, so caches can be aggressive)."""
return serve_signed_media(
sig,
payload,
secret=self._media_secret,
request=request,
media_dir=lambda channel=None: get_media_dir(channel),
)
def _handle_session_delete(self, request: WsRequest, key: str) -> Response:
if not self._check_api_token(request):
return _http_error(401, "Unauthorized")
if self._session_manager is None:
return _http_error(503, "session manager unavailable")
decoded_key = _decode_api_key(key)
if decoded_key is None:
return _http_error(400, "invalid session key")
# Same boundary as ``_handle_session_messages``: mutations apply only to
# websocket-channel sessions; deletion unlinks local JSONL — keep scope narrow.
if not self._is_websocket_channel_session_key(decoded_key):
return _http_error(404, "session not found")
deleted = self._session_manager.delete_session(decoded_key)
delete_webui_thread(decoded_key)
return _http_json_response({"deleted": bool(deleted)})
# -- Static files and WebSocket handshake ------------------------------
def _serve_static(self, request_path: str) -> Response | None:
"""Resolve *request_path* against the built SPA directory; SPA fallback to index.html."""
assert self._static_dist_path is not None
rel = request_path.lstrip("/")
if not rel:
rel = "index.html"
# Reject path-traversal attempts and absolute targets.
if ".." in rel.split("/") or rel.startswith("/"):
return _http_error(403, "Forbidden")
candidate = (self._static_dist_path / rel).resolve()
try:
candidate.relative_to(self._static_dist_path)
except ValueError:
return _http_error(403, "Forbidden")
if not candidate.is_file():
# SPA history-mode fallback: unknown routes serve index.html so the
# client-side router can render them.
index = self._static_dist_path / "index.html"
if index.is_file():
candidate = index
else:
return None
try:
body = candidate.read_bytes()
except OSError as e:
self.logger.warning("static: failed to read {}: {}", candidate, e)
return _http_error(500, "Internal Server Error")
ctype, _ = mimetypes.guess_type(candidate.name)
if ctype is None:
ctype = "application/octet-stream"
if ctype.startswith("text/") or ctype in {"application/javascript", "application/json"}:
ctype = f"{ctype}; charset=utf-8"
# Hash-named build assets are cache-friendly; index.html must stay fresh.
if candidate.name == "index.html":
cache = "no-cache"
else:
cache = "public, max-age=31536000, immutable"
return _http_response(
body,
status=200,
content_type=ctype,
extra_headers=[("Cache-Control", cache)],
)
<< # Everything else goes to the HTTP handler
return await self._http.dispatch(connection, request)
def _authorize_websocket_handshake(self, connection: Any, query: dict[str, list[str]]) -> Any:
supplied = _query_first(query, "token")
@ -1178,20 +666,21 @@ class WebSocketChannel(BaseChannel):
if static_token:
if supplied and hmac.compare_digest(supplied, static_token):
return None
if supplied and self._take_issued_token_if_valid(supplied):
if supplied and self._http.take_issued_token_if_valid(supplied):
return None
return connection.respond(401, "Unauthorized")
if self.config.websocket_requires_token:
if supplied and self._take_issued_token_if_valid(supplied):
if supplied and self._http.take_issued_token_if_valid(supplied):
return None
return connection.respond(401, "Unauthorized")
if supplied:
self._take_issued_token_if_valid(supplied)
self._http.take_issued_token_if_valid(supplied)
return None
# -- Server lifecycle and connection ingress ---------------------------
# -- Server lifecycle and connection ingress ---------------------------
async def start(self) -> None:
from nanobot.utils.logging_bridge import redirect_lib_logging
@ -1587,8 +1076,8 @@ class WebSocketChannel(BaseChannel):
self._subs.clear()
self._conn_chats.clear()
self._conn_default.clear()
self._issued_tokens.clear()
self._api_tokens.clear()
self._http.issued_tokens.clear()
self._http.api_tokens.clear()
async def _safe_send_to(self, connection: Any, raw: str, *, label: str = "") -> None:
"""Send a raw frame to one connection, cleaning up on ConnectionClosed."""
@ -1601,6 +1090,14 @@ class WebSocketChannel(BaseChannel):
self.logger.exception("send failed{}", label)
raise
def _try_append_webui_transcript(self, chat_id: str, wire: dict[str, Any]) -> None:
sk = f"websocket:{chat_id}"
try:
dup = json.loads(json.dumps(wire, ensure_ascii=False))
append_transcript_object(sk, dup)
except (ValueError, TypeError) as e:
self.logger.warning("webui transcript append failed: {}", e)
async def send(self, msg: OutboundMessage) -> None:
if msg.metadata.get("_runtime_model_updated"):
await self.send_runtime_model_updated(
@ -1662,7 +1159,7 @@ class WebSocketChannel(BaseChannel):
)
return
text = msg.content
wire_text = self._rewrite_local_markdown_images(text)
wire_text = self._http.rewrite_local_markdown_images(text)
payload: dict[str, Any] = {
"event": "message",
"chat_id": msg.chat_id,
@ -1672,7 +1169,7 @@ class WebSocketChannel(BaseChannel):
payload["media"] = msg.media
urls: list[dict[str, str]] = []
for entry in msg.media:
signed = self._sign_or_stage_media_path(Path(entry))
signed = self._http.sign_or_stage_media_path(Path(entry))
if signed is not None:
urls.append(signed)
if urls:
@ -1787,7 +1284,7 @@ class WebSocketChannel(BaseChannel):
if delta:
buffered.append(delta)
full_text = "".join(buffered)
rewritten = self._rewrite_local_markdown_images(full_text)
rewritten = self._http.rewrite_local_markdown_images(full_text)
if rewritten != full_text:
body["text"] = rewritten
else:

731
nanobot/channels/ws_http.py Normal file
View File

@ -0,0 +1,731 @@
"""HTTP API handler extracted from WebSocketChannel.
Handles all non-WebSocket HTTP routes: bootstrap, sessions, settings,
media, commands, sidebar state, static file serving, and token management.
Also houses shared HTTP utility functions used by both this module and
``websocket.py`` to avoid circular imports.
"""
from __future__ import annotations
import email.utils
import hmac
import http
import json
import mimetypes
import re
import secrets
import time
from collections.abc import Callable
from pathlib import Path
from typing import TYPE_CHECKING, Any
from urllib.parse import parse_qs, urlparse
from loguru import logger
from websockets.datastructures import Headers
from websockets.http11 import Request as WsRequest
from websockets.http11 import Response
from nanobot.command.builtin import builtin_command_palette
from nanobot.config.paths import get_media_dir
from nanobot.utils.subagent_channel_display import scrub_subagent_messages_for_channel
from nanobot.webui.media_api import (
serve_signed_media,
sign_media_path,
sign_or_stage_media_path,
)
from nanobot.webui.sidebar_state import (
read_webui_sidebar_state,
write_webui_sidebar_state,
)
from nanobot.webui.thread_disk import delete_webui_thread
from nanobot.webui.transcript import (
build_webui_thread_response,
rewrite_local_markdown_images,
)
if TYPE_CHECKING:
from nanobot.bus.queue import MessageBus
from nanobot.session.manager import SessionManager
# ---------------------------------------------------------------------------
# Shared HTTP utility functions (imported by websocket.py)
# ---------------------------------------------------------------------------
def _strip_trailing_slash(path: str) -> str:
if len(path) > 1 and path.endswith("/"):
return path.rstrip("/")
return path or "/"
def _normalize_config_path(path: str) -> str:
return _strip_trailing_slash(path)
def _case_insensitive_header(headers: Any, key: str) -> str:
"""Read a header from websockets/http test stubs without assuming casing."""
try:
value = headers.get(key)
except Exception:
value = None
if value is None:
try:
value = headers.get(key.lower())
except Exception:
value = None
return str(value or "").strip()
def _safe_host_header(value: str) -> str:
"""Return a safe Host header value, or empty when it should not be echoed."""
value = value.strip()
if not value:
return ""
if re.fullmatch(r"\[[0-9A-Fa-f:.]+\](?::\d{1,5})?", value):
return value
if re.fullmatch(r"[A-Za-z0-9.-]+(?::\d{1,5})?", value):
return value
return ""
def _host_for_url(host: str, port: int) -> str:
host = host.strip()
if host in ("0.0.0.0", "::"):
host = "127.0.0.1"
if ":" in host and not host.startswith("["):
host = f"[{host}]"
return f"{host}:{port}"
def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
body = json.dumps(data, ensure_ascii=False).encode("utf-8")
headers = Headers(
[
("Date", email.utils.formatdate(usegmt=True)),
("Connection", "close"),
("Content-Length", str(len(body))),
("Content-Type", "application/json; charset=utf-8"),
]
)
reason = http.HTTPStatus(status).phrase
return Response(status, reason, headers, body)
def _http_response(
body: bytes,
*,
status: int = 200,
content_type: str = "text/plain; charset=utf-8",
extra_headers: list[tuple[str, str]] | None = None,
) -> Response:
headers = [
("Date", email.utils.formatdate(usegmt=True)),
("Connection", "close"),
("Content-Length", str(len(body))),
("Content-Type", content_type),
]
if extra_headers:
headers.extend(extra_headers)
reason = http.HTTPStatus(status).phrase
return Response(status, reason, Headers(headers), body)
def _http_error(status: int, message: str | None = None) -> Response:
body = (message or http.HTTPStatus(status).phrase).encode("utf-8")
return _http_response(body, status=status)
def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]:
"""Parse normalized path and query parameters in one pass."""
parsed = urlparse("ws://x" + path_with_query)
path = _strip_trailing_slash(parsed.path or "/")
return path, parse_qs(parsed.query, keep_blank_values=True)
def _normalize_http_path(path_with_query: str) -> str:
return _parse_request_path(path_with_query)[0]
def _parse_query(path_with_query: str) -> dict[str, list[str]]:
return _parse_request_path(path_with_query)[1]
def _query_first(query: dict[str, list[str]], key: str) -> str | None:
values = query.get(key)
return values[0] if values else None
def _is_localhost(connection: Any) -> bool:
addr = getattr(connection, "remote_address", None)
if not addr:
return False
host = addr[0] if isinstance(addr, tuple) else addr
if not isinstance(host, str):
return False
if host.startswith("::ffff:"):
host = host[7:]
return host in {"127.0.0.1", "::1", "localhost"}
def _bearer_token(headers: Any) -> str | None:
auth = headers.get("Authorization") or headers.get("authorization")
if auth and auth.lower().startswith("bearer "):
return auth[7:].strip() or None
return None
def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool:
if not configured_secret:
return True
authorization = headers.get("Authorization") or headers.get("authorization")
if authorization and authorization.lower().startswith("bearer "):
supplied = authorization[7:].strip()
return hmac.compare_digest(supplied, configured_secret)
header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth")
if not header_token:
return False
return hmac.compare_digest(header_token.strip(), configured_secret)
def _decode_api_key(raw_key: str) -> str | None:
from urllib.parse import unquote
key = unquote(raw_key)
_api_key_re = re.compile(r"^[A-Za-z0-9_:.-]{1,128}$")
if _api_key_re.match(key) is None:
return None
return key
def _default_model_name_from_config() -> str | None:
try:
from nanobot.config.loader import load_config
model = load_config().resolve_preset().model.strip()
return model or None
except Exception as e:
logger.debug("bootstrap model_name could not load from config: {}", e)
return None
def _resolve_bootstrap_model_name(
runtime_name: Callable[[], str | None] | None,
) -> str | None:
if runtime_name is not None:
try:
raw = runtime_name()
except Exception as e:
logger.debug("bootstrap runtime model resolver failed: {}", e)
else:
if isinstance(raw, str):
stripped = raw.strip()
if stripped:
return stripped
return _default_model_name_from_config()
# ---------------------------------------------------------------------------
# GatewayHTTPHandler
# ---------------------------------------------------------------------------
class GatewayHTTPHandler:
"""Handles all HTTP routes served alongside the WebSocket endpoint.
Owns token management, session API, media API, static file serving,
and delegates settings routes to ``WebUISettingsRouter``.
"""
_MAX_ISSUED_TOKENS = 10_000
def __init__(
self,
*,
config: Any, # WebSocketConfig
session_manager: SessionManager | None,
static_dist_path: Path | None,
workspace_path: Path,
runtime_model_name: Callable[[], str | None] | None,
runtime_surface: str,
runtime_capabilities_overrides: dict[str, Any] | None,
bus: MessageBus,
log: Any = logger,
) -> None:
self.config = config
self.session_manager = session_manager
self.static_dist_path = static_dist_path
self.workspace_path = workspace_path
self.runtime_model_name = runtime_model_name
self.bus = bus
self._log = log
self._runtime_surface = runtime_surface
self.issued_tokens: dict[str, float] = {}
self.api_tokens: dict[str, float] = {}
self.media_secret: bytes = secrets.token_bytes(32)
# Workspace controller
from nanobot.webui.workspaces import WebUIWorkspaceController
self.workspaces = WebUIWorkspaceController(
session_manager=session_manager,
default_workspace=workspace_path,
default_restrict_to_workspace=None,
)
# Settings router
from nanobot.webui.settings_api import runtime_capabilities as _rc
from nanobot.webui.settings_routes import WebUISettingsRouter
self._capabilities = _rc(runtime_surface, runtime_capabilities_overrides or {})
self.settings_routes = WebUISettingsRouter(
bus=bus,
logger=self._log,
check_api_token=self.check_api_token,
parse_query=_parse_query,
json_response=_http_json_response,
error_response=_http_error,
runtime_surface=runtime_surface,
runtime_capabilities=self._capabilities,
)
# -- Token management ---------------------------------------------------
def check_api_token(self, request: WsRequest) -> bool:
self._purge_expired_api_tokens()
token = _bearer_token(request.headers) or _query_first(
_parse_query(request.path), "token"
)
if not token:
return False
expiry = self.api_tokens.get(token)
if expiry is None or time.monotonic() > expiry:
self.api_tokens.pop(token, None)
return False
return True
def _purge_expired_api_tokens(self) -> None:
now = time.monotonic()
for token_key, expiry in list(self.api_tokens.items()):
if now > expiry:
self.api_tokens.pop(token_key, None)
def _purge_expired_issued_tokens(self) -> None:
now = time.monotonic()
for token_key, expiry in list(self.issued_tokens.items()):
if now > expiry:
self.issued_tokens.pop(token_key, None)
def take_issued_token_if_valid(self, token_value: str | None) -> bool:
if not token_value:
return False
self._purge_expired_issued_tokens()
expiry = self.issued_tokens.pop(token_value, None)
if expiry is None:
return False
if time.monotonic() > expiry:
return False
return True
# -- Main dispatch ------------------------------------------------------
async def dispatch(self, connection: Any, request: WsRequest) -> Any | None:
"""Route an HTTP request. Returns Response or None."""
got, query = _parse_request_path(request.path)
# Token issue endpoint
if self.config.token_issue_path:
issue_expected = _normalize_config_path(self.config.token_issue_path)
if got == issue_expected:
return self._handle_token_issue(connection, request)
# Bootstrap
if got == "/webui/bootstrap":
return self._handle_bootstrap(connection, request)
# Settings routes (delegated)
response = await self.settings_routes.dispatch(request, got)
if response is not None:
return response
# Session routes
response = self._dispatch_session_routes(request, got)
if response is not None:
return response
# Media routes
response = self._dispatch_media_routes(request, got)
if response is not None:
return response
# Misc routes
response = self._dispatch_misc_routes(connection, request, got)
if response is not None:
return response
# API 404 (never serve SPA for /api/ routes)
if got.startswith("/api/"):
return _http_error(404, "API route not found")
# Static SPA serving
if self.static_dist_path is not None:
response = self._serve_static(got)
if response is not None:
return response
return connection.respond(404, "Not Found")
# -- Token issue --------------------------------------------------------
def _handle_token_issue(self, connection: Any, request: Any) -> Any:
secret = self.config.token_issue_secret.strip()
if secret:
if not _issue_route_secret_matches(request.headers, secret):
return connection.respond(401, "Unauthorized")
else:
self._log.warning(
"token_issue_path is set but token_issue_secret is empty; "
"any client can obtain connection tokens — set token_issue_secret for production."
)
self._purge_expired_issued_tokens()
if len(self.issued_tokens) >= self._MAX_ISSUED_TOKENS:
self._log.error(
"too many outstanding issued tokens ({}), rejecting issuance",
len(self.issued_tokens),
)
return _http_json_response({"error": "too many outstanding tokens"}, status=429)
token_value = f"nbwt_{secrets.token_urlsafe(32)}"
self.issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s)
return _http_json_response(
{"token": token_value, "expires_in": self.config.token_ttl_s}
)
# -- Bootstrap ----------------------------------------------------------
def _handle_bootstrap(self, connection: Any, request: Any) -> Response:
secret = self.config.token_issue_secret.strip() or self.config.token.strip()
if secret:
if not _issue_route_secret_matches(request.headers, secret):
return _http_error(401, "Unauthorized")
elif not _is_localhost(connection):
return _http_error(403, "bootstrap is localhost-only")
self._purge_expired_issued_tokens()
self._purge_expired_api_tokens()
if (
len(self.issued_tokens) >= self._MAX_ISSUED_TOKENS
or len(self.api_tokens) >= self._MAX_ISSUED_TOKENS
):
return _http_response(
json.dumps({"error": "too many outstanding tokens"}).encode("utf-8"),
status=429,
content_type="application/json; charset=utf-8",
)
token = f"nbwt_{secrets.token_urlsafe(32)}"
expiry = time.monotonic() + float(self.config.token_ttl_s)
self.issued_tokens[token] = expiry
self.api_tokens[token] = expiry
ws_url = self._bootstrap_ws_url(request)
expected_path = _normalize_config_path(self.config.path)
return _http_json_response(
{
"token": token,
"ws_path": expected_path,
"ws_url": ws_url,
"expires_in": self.config.token_ttl_s,
"model_name": _resolve_bootstrap_model_name(self.runtime_model_name),
"runtime_surface": self._runtime_surface,
"runtime_capabilities": self._capabilities,
}
)
def _bootstrap_ws_url(self, request: Any) -> str:
headers = getattr(request, "headers", {}) or {}
host = _safe_host_header(_case_insensitive_header(headers, "Host"))
if not host:
host = _host_for_url(self.config.host, self.config.port)
proto = _case_insensitive_header(headers, "X-Forwarded-Proto")
proto = proto.split(",", 1)[0].strip().lower()
secure = proto in {"https", "wss"} or bool(self.config.ssl_certfile.strip())
scheme = "wss" if secure else "ws"
expected_path = _normalize_config_path(self.config.path)
return f"{scheme}://{host}{expected_path}"
# -- Session routes -----------------------------------------------------
def _dispatch_session_routes(self, request: WsRequest, got: str) -> Response | None:
m = re.match(r"^/api/sessions/([^/]+)/messages$", got)
if m:
return self._handle_session_messages(request, m.group(1))
m = re.match(r"^/api/sessions/([^/]+)/webui-thread$", got)
if m:
return self._handle_webui_thread_get(request, m.group(1))
m = re.match(r"^/api/sessions/([^/]+)/delete$", got)
if m:
return self._handle_session_delete(request, m.group(1))
return None
def _handle_sessions_list(self, request: WsRequest) -> Response:
if not self.check_api_token(request):
return _http_error(401, "Unauthorized")
if self.session_manager is None:
return _http_error(503, "session manager unavailable")
sessions = self.session_manager.list_sessions()
from nanobot.session.webui_turns import websocket_turn_wall_started_at
cleaned = []
for s in sessions:
key = s.get("key")
if not (isinstance(key, str) and key.startswith("websocket:")):
continue
row = {k: v for k, v in s.items() if k != "path"}
chat_id = key.split(":", 1)[1]
started_at = websocket_turn_wall_started_at(chat_id)
if started_at is not None:
row["run_started_at"] = started_at
scope = self.workspaces.scope_for_session_key(key)
row["workspace_scope"] = scope.payload()
cleaned.append(row)
return _http_json_response({"sessions": cleaned})
def _handle_session_messages(self, request: WsRequest, key: str) -> Response:
if not self.check_api_token(request):
return _http_error(401, "Unauthorized")
if self.session_manager is None:
return _http_error(503, "session manager unavailable")
decoded_key = _decode_api_key(key)
if decoded_key is None:
return _http_error(400, "invalid session key")
if not _is_websocket_channel_session_key(decoded_key):
return _http_error(404, "session not found")
data = self.session_manager.read_session_file(decoded_key)
if data is None:
return _http_error(404, "session not found")
messages = data.get("messages")
if isinstance(messages, list):
scrub_subagent_messages_for_channel(messages)
self._augment_media_urls(data)
return _http_json_response(data)
def _handle_webui_thread_get(self, request: WsRequest, key: str) -> Response:
if not self.check_api_token(request):
return _http_error(401, "Unauthorized")
decoded_key = _decode_api_key(key)
if decoded_key is None:
return _http_error(400, "invalid session key")
if not _is_websocket_channel_session_key(decoded_key):
return _http_error(404, "session not found")
scope = self.workspaces.scope_for_session_key(decoded_key)
data = build_webui_thread_response(
decoded_key,
augment_user_media=self._augment_transcript_user_media,
augment_assistant_text=lambda text: rewrite_local_markdown_images(
text,
workspace_path=scope.project_path,
sign_path=self.sign_or_stage_media_path,
),
)
if data is None:
return _http_error(404, "webui thread not found")
data["workspace_scope"] = scope.payload()
return _http_json_response(data)
def _handle_session_delete(self, request: WsRequest, key: str) -> Response:
if not self.check_api_token(request):
return _http_error(401, "Unauthorized")
if self.session_manager is None:
return _http_error(503, "session manager unavailable")
decoded_key = _decode_api_key(key)
if decoded_key is None:
return _http_error(400, "invalid session key")
if not _is_websocket_channel_session_key(decoded_key):
return _http_error(404, "session not found")
deleted = self.session_manager.delete_session(decoded_key)
delete_webui_thread(decoded_key)
return _http_json_response({"deleted": bool(deleted)})
# -- Media routes -------------------------------------------------------
def _dispatch_media_routes(self, request: WsRequest, got: str) -> Response | None:
m = re.match(r"^/api/media/([A-Za-z0-9_-]+)/([A-Za-z0-9_-]+)$", got)
if m:
return self._handle_media_fetch(m.group(1), m.group(2), request)
return None
def _handle_media_fetch(
self, sig: str, payload: str, request: WsRequest | None = None
) -> Response:
return serve_signed_media(
sig,
payload,
secret=self.media_secret,
request=request,
media_dir=lambda channel=None: get_media_dir(channel),
)
def sign_media_path(self, abs_path: Path) -> str | None:
return sign_media_path(
abs_path,
secret=self.media_secret,
media_dir=lambda channel=None: get_media_dir(channel),
)
def sign_or_stage_media_path(self, path: Path) -> dict[str, str] | None:
return sign_or_stage_media_path(
path,
secret=self.media_secret,
media_dir=lambda channel=None: get_media_dir(channel),
logger=self._log,
)
def rewrite_local_markdown_images(self, text: str) -> str:
return rewrite_local_markdown_images(
text,
workspace_path=self.workspace_path,
sign_path=self.sign_or_stage_media_path,
)
# -- Misc routes --------------------------------------------------------
def _dispatch_misc_routes(
self, connection: Any, request: WsRequest, got: str
) -> Response | None:
if got == "/api/sessions":
return self._handle_sessions_list(request)
if got == "/api/commands":
return self._handle_commands(request)
if got == "/api/workspaces":
return self._handle_workspaces(connection, request)
if got == "/api/webui/sidebar-state":
return self._handle_webui_sidebar_state(request)
if got == "/api/webui/sidebar-state/update":
return self._handle_webui_sidebar_state_update(request)
return None
def _handle_commands(self, request: WsRequest) -> Response:
if not self.check_api_token(request):
return _http_error(401, "Unauthorized")
return _http_json_response({"commands": builtin_command_palette()})
def _handle_workspaces(self, connection: Any, request: WsRequest) -> Response:
if not self.check_api_token(request):
return _http_error(401, "Unauthorized")
return _http_json_response(
self.workspaces.payload(controls_available=_is_localhost(connection))
)
def _handle_webui_sidebar_state(self, request: WsRequest) -> Response:
if not self.check_api_token(request):
return _http_error(401, "Unauthorized")
return _http_json_response(read_webui_sidebar_state())
def _handle_webui_sidebar_state_update(self, request: WsRequest) -> Response:
if not self.check_api_token(request):
return _http_error(401, "Unauthorized")
query = _parse_query(request.path)
raw_state = _query_first(query, "state")
if raw_state is None:
return _http_error(400, "missing state")
try:
decoded = json.loads(raw_state)
except json.JSONDecodeError:
return _http_error(400, "state must be JSON")
if not isinstance(decoded, dict):
return _http_error(400, "state must be an object")
try:
state = write_webui_sidebar_state(decoded)
except ValueError as e:
return _http_error(400, str(e))
except OSError:
self._log.exception("failed to write webui sidebar state")
return _http_error(500, "failed to write sidebar state")
return _http_json_response(state)
# -- Static file serving ------------------------------------------------
def _serve_static(self, request_path: str) -> Response | None:
assert self.static_dist_path is not None
rel = request_path.lstrip("/")
if not rel:
rel = "index.html"
if ".." in rel.split("/") or rel.startswith("/"):
return _http_error(403, "Forbidden")
candidate = (self.static_dist_path / rel).resolve()
try:
candidate.relative_to(self.static_dist_path)
except ValueError:
return _http_error(403, "Forbidden")
if not candidate.is_file():
index = self.static_dist_path / "index.html"
if index.is_file():
candidate = index
else:
return None
try:
body = candidate.read_bytes()
except OSError as e:
self._log.warning("static: failed to read {}: {}", candidate, e)
return _http_error(500, "Internal Server Error")
ctype, _ = mimetypes.guess_type(candidate.name)
if ctype is None:
ctype = "application/octet-stream"
if ctype.startswith("text/") or ctype in {"application/javascript", "application/json"}:
ctype = f"{ctype}; charset=utf-8"
if candidate.name == "index.html":
cache = "no-cache"
else:
cache = "public, max-age=31536000, immutable"
return _http_response(
body,
status=200,
content_type=ctype,
extra_headers=[("Cache-Control", cache)],
)
# -- Media helpers (called by WebSocketChannel.send) --------------------
def _augment_media_urls(self, payload: dict[str, Any]) -> None:
messages = payload.get("messages")
if not isinstance(messages, list):
return
for msg in messages:
if not isinstance(msg, dict):
continue
media = msg.get("media")
if not isinstance(media, list) or not media:
continue
urls: list[dict[str, str]] = []
for entry in media:
if not isinstance(entry, str) or not entry:
continue
signed = self.sign_media_path(Path(entry))
if signed is None:
continue
urls.append({"url": signed, "name": Path(entry).name})
if urls:
msg["media_urls"] = urls
msg.pop("media", None)
def _augment_transcript_user_media(self, paths: list[str]) -> list[dict[str, Any]]:
out: list[dict[str, Any]] = []
for pstr in paths:
path = Path(pstr)
att = self.sign_or_stage_media_path(path)
if att is None:
continue
mime, _ = mimetypes.guess_type(path.name)
kind = "video" if mime and mime.startswith("video/") else "image"
out.append(
{"kind": kind, "url": att["url"], "name": att.get("name", path.name)},
)
return out
def _is_websocket_channel_session_key(key: str) -> bool:
return key.startswith("websocket:")

View File

@ -626,6 +626,7 @@ async def test_send_stages_external_media_as_signed_url(monkeypatch, tmp_path) -
return ws_media if channel == "websocket" else media_root
monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir)
monkeypatch.setattr("nanobot.channels.ws_http.get_media_dir", fake_media_dir)
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
mock_ws = AsyncMock()
channel._attach(mock_ws, "chat-1")
@ -840,6 +841,7 @@ async def test_send_delta_stream_end_rewrites_local_markdown_image(monkeypatch,
return path
monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir)
monkeypatch.setattr("nanobot.channels.ws_http.get_media_dir", fake_media_dir)
channel = WebSocketChannel(
{"enabled": True, "allowFrom": ["*"], "streaming": True},
bus,
@ -872,6 +874,7 @@ async def test_send_delta_stream_end_rewrites_inline_final_text(monkeypatch, tmp
return path
monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir)
monkeypatch.setattr("nanobot.channels.ws_http.get_media_dir", fake_media_dir)
channel = WebSocketChannel(
{"enabled": True, "allowFrom": ["*"], "streaming": True},
bus,

View File

@ -710,20 +710,20 @@ async def test_api_token_pool_purges_expired(bus: MagicMock, tmp_path: Path) ->
channel = _ch(bus, session_manager=sm, port=29908)
# Don't start a server — directly inject and validate.
import time as _time
channel._api_tokens["expired"] = _time.monotonic() - 1
channel._api_tokens["live"] = _time.monotonic() + 60
channel._http.api_tokens["expired"] = _time.monotonic() - 1
channel._http.api_tokens["live"] = _time.monotonic() + 60
class _FakeReq:
path = "/api/sessions"
headers = {"Authorization": "Bearer expired"}
assert channel._check_api_token(_FakeReq()) is False
assert channel._http.check_api_token(_FakeReq()) is False
class _LiveReq:
path = "/api/sessions"
headers = {"Authorization": "Bearer live"}
assert channel._check_api_token(_LiveReq()) is True
assert channel._http.check_api_token(_LiveReq()) is True
class _FakeConn:
@ -814,7 +814,7 @@ def test_localhost_without_auth_is_valid(bus: MagicMock) -> None:
def test_bootstrap_prefers_runtime_model_name(bus: MagicMock, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"nanobot.channels.websocket._default_model_name_from_config",
"nanobot.channels.ws_http._default_model_name_from_config",
lambda: "from-disk",
)
channel = _ch(bus, host="127.0.0.1", runtime_model_name=lambda: " live/model ")
@ -826,7 +826,7 @@ def test_bootstrap_prefers_runtime_model_name(bus: MagicMock, monkeypatch: pytes
def test_bootstrap_falls_back_when_runtime_returns_empty(bus: MagicMock, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"nanobot.channels.websocket._default_model_name_from_config",
"nanobot.channels.ws_http._default_model_name_from_config",
lambda: "from-disk",
)
channel = _ch(bus, host="127.0.0.1", runtime_model_name=lambda: " ")
@ -838,7 +838,7 @@ def test_bootstrap_falls_back_when_runtime_returns_empty(bus: MagicMock, monkeyp
def test_bootstrap_falls_back_when_runtime_raises(bus: MagicMock, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
"nanobot.channels.websocket._default_model_name_from_config",
"nanobot.channels.ws_http._default_model_name_from_config",
lambda: "from-disk",
)

View File

@ -106,7 +106,7 @@ def test_sign_media_path_rejects_paths_outside_media_root(
media = tmp_path / "media"
media.mkdir()
channel = _ch(bus, port=0)
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
assert channel._sign_media_path(outside) is None
# Traversal via the media root is also rejected — the resolve() step
# normalises ``..`` out before the relative_to check.
@ -121,7 +121,7 @@ def test_sign_media_path_round_trips_via_hmac(
media.mkdir()
(media / "a.png").write_bytes(_PNG_BYTES)
channel = _ch(bus, port=0)
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
url = channel._sign_media_path(media / "a.png")
assert url is not None
assert url.startswith("/api/media/")
@ -144,7 +144,7 @@ def test_local_markdown_image_is_staged_and_rewritten(
media = tmp_path / "media"
channel = _ch(bus, workspace_path=workspace, port=0)
with patch("nanobot.channels.websocket.get_media_dir", side_effect=_fake_media_dir(media)):
with patch("nanobot.channels.ws_http.get_media_dir", side_effect=_fake_media_dir(media)):
rewritten = channel._rewrite_local_markdown_images(
"The result:\n![Cloud Architecture Diagram](demo_arch.png)"
)
@ -166,7 +166,7 @@ def test_local_markdown_video_is_staged_and_rewritten(
media = tmp_path / "media"
channel = _ch(bus, workspace_path=workspace, port=0)
with patch("nanobot.channels.websocket.get_media_dir", side_effect=_fake_media_dir(media)):
with patch("nanobot.channels.ws_http.get_media_dir", side_effect=_fake_media_dir(media)):
rewritten = channel._rewrite_local_markdown_images(
"The result:\n![nanobot-intro.mp4](nanobot-intro.mp4)"
)
@ -189,7 +189,7 @@ def test_local_markdown_image_rejects_workspace_escape(
channel = _ch(bus, workspace_path=workspace, port=0)
text = "![nope](../outside.png)"
with patch("nanobot.channels.websocket.get_media_dir", side_effect=_fake_media_dir(media)):
with patch("nanobot.channels.ws_http.get_media_dir", side_effect=_fake_media_dir(media)):
assert channel._rewrite_local_markdown_images(text) == text
assert not (media / "websocket").exists()
@ -211,7 +211,7 @@ async def test_media_route_serves_signed_file(
target.write_bytes(_PNG_BYTES)
channel = _ch(bus, port=29920)
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
url_path = channel._sign_media_path(target)
assert url_path is not None
server_task = asyncio.create_task(channel.start())
@ -244,7 +244,7 @@ async def test_media_route_serves_video_byte_ranges(
target.write_bytes(b"0123456789")
channel = _ch(bus, port=29927)
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
url_path = channel._sign_media_path(target)
assert url_path is not None
server_task = asyncio.create_task(channel.start())
@ -276,7 +276,7 @@ async def test_media_route_serves_suffix_video_byte_ranges(
target.write_bytes(b"0123456789")
channel = _ch(bus, port=29928)
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
url_path = channel._sign_media_path(target)
assert url_path is not None
server_task = asyncio.create_task(channel.start())
@ -305,7 +305,7 @@ async def test_media_route_rejects_unsatisfiable_byte_range(
target.write_bytes(b"0123456789")
channel = _ch(bus, port=29929)
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
url_path = channel._sign_media_path(target)
assert url_path is not None
server_task = asyncio.create_task(channel.start())
@ -338,7 +338,7 @@ async def test_media_route_rejects_bad_signature(
(media / "f.png").write_bytes(_PNG_BYTES)
channel = _ch(bus, port=29921)
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
good = channel._sign_media_path(media / "f.png")
assert good is not None
_, payload = good[len("/api/media/"):].split("/", 1)
@ -381,7 +381,7 @@ async def test_media_route_rejects_path_traversal_payload(
).digest()[:16]
url = f"/api/media/{b64url_encode(mac)}/{payload}"
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
@ -405,7 +405,7 @@ async def test_media_route_404s_missing_file(
target.write_bytes(_PNG_BYTES)
channel = _ch(bus, port=29923)
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
url_path = channel._sign_media_path(target)
assert url_path is not None
target.unlink() # the file vanishes between signing and fetching
@ -433,7 +433,7 @@ async def test_media_route_degrades_non_image_to_octet_stream(
(media / "scary.html").write_bytes(b"<script>alert(1)</script>")
channel = _ch(bus, port=29924)
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
payload = b64url_encode(b"scary.html")
mac = hmac.new(
channel._media_secret, payload.encode("ascii"), hashlib.sha256
@ -464,7 +464,7 @@ async def test_media_route_serves_svg_with_strict_csp(
target.write_text("<svg xmlns='http://www.w3.org/2000/svg'><script>alert(1)</script></svg>")
channel = _ch(bus, port=29928)
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
url_path = channel._sign_media_path(target)
assert url_path is not None
server_task = asyncio.create_task(channel.start())
@ -505,7 +505,7 @@ async def test_session_messages_exposes_signed_media_urls(
sm.save(sess)
channel = _ch(bus, session_manager=sm, port=29925)
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
@ -550,7 +550,7 @@ async def test_session_messages_skips_vanished_media(
sm.save(sess)
channel = _ch(bus, session_manager=sm, port=29926)
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try: