mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 22:34:06 +00:00
refactor: extract GatewayHTTPHandler from WebSocketChannel
Extract all HTTP route handling (bootstrap, sessions, settings, media, commands, sidebar state, static serving, token management) into a new GatewayHTTPHandler class in nanobot/channels/ws_http.py. WebSocketChannel is reduced from 1907 to 1372 lines (-28%), retaining only WebSocket connection management and message dispatch. No behavior change. 3730 tests pass, 0 failures. Shared HTTP utility functions (path parsing, response builders, auth helpers) now live in ws_http.py with websocket.py importing from there, avoiding circular dependencies. Backwards-compat property aliases on WebSocketChannel ensure existing tests continue to work without modification.
This commit is contained in:
parent
92fe40a690
commit
1a585288b2
@ -7,11 +7,8 @@ import email.utils
|
|||||||
import hmac
|
import hmac
|
||||||
import http
|
import http
|
||||||
import json
|
import json
|
||||||
import mimetypes
|
|
||||||
import re
|
import re
|
||||||
import secrets
|
|
||||||
import ssl
|
import ssl
|
||||||
import time
|
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
@ -31,7 +28,7 @@ from websockets.http11 import Response
|
|||||||
from nanobot.bus.events import OUTBOUND_META_AGENT_UI, OutboundMessage
|
from nanobot.bus.events import OUTBOUND_META_AGENT_UI, OutboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.base import BaseChannel
|
from nanobot.channels.base import BaseChannel
|
||||||
from nanobot.command.builtin import builtin_command_palette
|
from nanobot.channels.ws_http import GatewayHTTPHandler
|
||||||
from nanobot.config.paths import get_media_dir, get_workspace_path
|
from nanobot.config.paths import get_media_dir, get_workspace_path
|
||||||
from nanobot.config.schema import Base
|
from nanobot.config.schema import Base
|
||||||
from nanobot.security.workspace_access import (
|
from nanobot.security.workspace_access import (
|
||||||
@ -44,32 +41,9 @@ from nanobot.utils.media_decode import (
|
|||||||
FileSizeExceeded,
|
FileSizeExceeded,
|
||||||
save_base64_data_url,
|
save_base64_data_url,
|
||||||
)
|
)
|
||||||
from nanobot.utils.subagent_channel_display import scrub_subagent_messages_for_channel
|
|
||||||
from nanobot.webui.cli_apps_api import normalize_cli_app_mentions
|
from nanobot.webui.cli_apps_api import normalize_cli_app_mentions
|
||||||
from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions
|
from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions
|
||||||
from nanobot.webui.media_api import (
|
from nanobot.webui.transcript import append_transcript_object
|
||||||
attach_signed_media_urls,
|
|
||||||
serve_signed_media,
|
|
||||||
sign_media_path,
|
|
||||||
sign_or_stage_media_path,
|
|
||||||
signed_media_attachments,
|
|
||||||
)
|
|
||||||
from nanobot.webui.settings_api import runtime_capabilities
|
|
||||||
from nanobot.webui.settings_routes import WebUISettingsRouter
|
|
||||||
from nanobot.webui.sidebar_state import (
|
|
||||||
read_webui_sidebar_state,
|
|
||||||
write_webui_sidebar_state,
|
|
||||||
)
|
|
||||||
from nanobot.webui.thread_disk import delete_webui_thread
|
|
||||||
from nanobot.webui.transcript import (
|
|
||||||
append_transcript_object,
|
|
||||||
build_webui_thread_response,
|
|
||||||
rewrite_local_markdown_images,
|
|
||||||
)
|
|
||||||
from nanobot.webui.websocket_logging import websockets_server_logger
|
|
||||||
from nanobot.webui.workspaces import (
|
|
||||||
WebUIWorkspaceController,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from nanobot.session.manager import SessionManager
|
from nanobot.session.manager import SessionManager
|
||||||
@ -500,51 +474,36 @@ class WebSocketChannel(BaseChannel):
|
|||||||
self._conn_chats: dict[Any, set[str]] = {}
|
self._conn_chats: dict[Any, set[str]] = {}
|
||||||
# connection -> default chat_id for legacy frames that omit routing.
|
# connection -> default chat_id for legacy frames that omit routing.
|
||||||
self._conn_default: dict[Any, str] = {}
|
self._conn_default: dict[Any, str] = {}
|
||||||
# Single-use tokens consumed at WebSocket handshake.
|
|
||||||
self._issued_tokens: dict[str, float] = {}
|
|
||||||
# Multi-use tokens for HTTP routes served beside WS; checked but not consumed.
|
|
||||||
self._api_tokens: dict[str, float] = {}
|
|
||||||
self._stop_event: asyncio.Event | None = None
|
self._stop_event: asyncio.Event | None = None
|
||||||
self._server_task: asyncio.Task[None] | None = None
|
self._server_task: asyncio.Task[None] | None = None
|
||||||
self._session_manager = session_manager
|
_resolved_workspace = (
|
||||||
self._static_dist_path: Path | None = (
|
|
||||||
static_dist_path.resolve() if static_dist_path is not None else None
|
|
||||||
)
|
|
||||||
self._workspace_path = (
|
|
||||||
Path(workspace_path).expanduser()
|
Path(workspace_path).expanduser()
|
||||||
if workspace_path is not None
|
if workspace_path is not None
|
||||||
else get_workspace_path()
|
else get_workspace_path()
|
||||||
).resolve(strict=False)
|
).resolve(strict=False)
|
||||||
self._default_restrict_to_workspace = restrict_to_workspace
|
self._default_restrict_to_workspace = restrict_to_workspace
|
||||||
self._webui_workspaces = WebUIWorkspaceController(
|
|
||||||
session_manager=self._session_manager,
|
|
||||||
default_workspace=self._workspace_path,
|
|
||||||
default_restrict_to_workspace=self._default_restrict_to_workspace,
|
|
||||||
)
|
|
||||||
self._runtime_model_name = runtime_model_name
|
|
||||||
self._runtime_surface = (
|
self._runtime_surface = (
|
||||||
"native" if runtime_surface in {"native", "desktop"} else "browser"
|
"native" if runtime_surface in {"native", "desktop"} else "browser"
|
||||||
)
|
)
|
||||||
self._runtime_capabilities = runtime_capabilities(
|
|
||||||
self._runtime_surface,
|
# HTTP API handler — owns tokens, sessions, media, settings, static serving
|
||||||
runtime_capabilities_overrides,
|
self._http = GatewayHTTPHandler(
|
||||||
)
|
config=self.config,
|
||||||
self._settings_routes = WebUISettingsRouter(
|
session_manager=session_manager,
|
||||||
bus=self.bus,
|
static_dist_path=(
|
||||||
logger=self.logger,
|
static_dist_path.resolve() if static_dist_path is not None else None
|
||||||
check_api_token=self._check_api_token,
|
),
|
||||||
parse_query=_parse_query,
|
workspace_path=_resolved_workspace,
|
||||||
json_response=_http_json_response,
|
runtime_model_name=runtime_model_name,
|
||||||
error_response=_http_error,
|
|
||||||
runtime_surface=self._runtime_surface,
|
runtime_surface=self._runtime_surface,
|
||||||
runtime_capabilities=self._runtime_capabilities,
|
runtime_capabilities_overrides=runtime_capabilities_overrides,
|
||||||
|
bus=self.bus,
|
||||||
|
log=self.logger,
|
||||||
)
|
)
|
||||||
|
# Backwards-compat aliases for workspace controller used in envelope dispatch
|
||||||
|
self._webui_workspaces = self._http.workspaces
|
||||||
|
|
||||||
self._stream_text_buffers: dict[tuple[str, str], list[str]] = {}
|
self._stream_text_buffers: dict[tuple[str, str], list[str]] = {}
|
||||||
# Process-local secret used to HMAC-sign media URLs. The signed URL is
|
|
||||||
# the capability — anyone who holds a valid URL can fetch that one
|
|
||||||
# file, nothing else. The secret regenerates on restart so links
|
|
||||||
# become self-expiring (callers just refresh the session list).
|
|
||||||
self._media_secret: bytes = secrets.token_bytes(32)
|
|
||||||
|
|
||||||
# -- Subscription bookkeeping -------------------------------------------
|
# -- Subscription bookkeeping -------------------------------------------
|
||||||
|
|
||||||
@ -572,9 +531,9 @@ class WebSocketChannel(BaseChannel):
|
|||||||
connected clients normally see it via ``goal_state`` / ``turn_end`` frames.
|
connected clients normally see it via ``goal_state`` / ``turn_end`` frames.
|
||||||
Pushing here makes refresh + reconnect restore the strip without a new model turn.
|
Pushing here makes refresh + reconnect restore the strip without a new model turn.
|
||||||
"""
|
"""
|
||||||
if self._session_manager is None:
|
if self._http.session_manager is None:
|
||||||
return
|
return
|
||||||
row = self._session_manager.read_session_file(f"websocket:{chat_id}")
|
row = self._http.session_manager.read_session_file(f"websocket:{chat_id}")
|
||||||
meta = row.get("metadata", {}) if isinstance(row, dict) else {}
|
meta = row.get("metadata", {}) if isinstance(row, dict) else {}
|
||||||
if not isinstance(meta, dict):
|
if not isinstance(meta, dict):
|
||||||
meta = {}
|
meta = {}
|
||||||
@ -614,6 +573,59 @@ class WebSocketChannel(BaseChannel):
|
|||||||
def _expected_path(self) -> str:
|
def _expected_path(self) -> str:
|
||||||
return _normalize_config_path(self.config.path)
|
return _normalize_config_path(self.config.path)
|
||||||
|
|
||||||
|
# -- Backwards-compat property aliases (used by tests) ------------------
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _session_manager(self):
|
||||||
|
return self._http.session_manager
|
||||||
|
|
||||||
|
@_session_manager.setter
|
||||||
|
def _session_manager(self, value):
|
||||||
|
self._http.session_manager = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _media_secret(self):
|
||||||
|
return self._http.media_secret
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _issued_tokens(self):
|
||||||
|
return self._http.issued_tokens
|
||||||
|
|
||||||
|
@_issued_tokens.setter
|
||||||
|
def _issued_tokens(self, value):
|
||||||
|
self._http.issued_tokens = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _api_tokens(self):
|
||||||
|
return self._http.api_tokens
|
||||||
|
|
||||||
|
def _check_api_token(self, request):
|
||||||
|
return self._http.check_api_token(request)
|
||||||
|
|
||||||
|
def _sign_media_path(self, path):
|
||||||
|
return self._http.sign_media_path(path)
|
||||||
|
|
||||||
|
def _sign_or_stage_media_path(self, path):
|
||||||
|
return self._http.sign_or_stage_media_path(path)
|
||||||
|
|
||||||
|
def _rewrite_local_markdown_images(self, text):
|
||||||
|
return self._http.rewrite_local_markdown_images(text)
|
||||||
|
|
||||||
|
def _handle_bootstrap(self, connection, request):
|
||||||
|
return self._http._handle_bootstrap(connection, request)
|
||||||
|
|
||||||
|
_MAX_ISSUED_TOKENS = GatewayHTTPHandler._MAX_ISSUED_TOKENS
|
||||||
|
|
||||||
|
def _handle_sessions_list(self, request):
|
||||||
|
return self._http._handle_sessions_list(request)
|
||||||
|
|
||||||
|
def _handle_webui_thread_get(self, request, key):
|
||||||
|
return self._http._handle_webui_thread_get(request, key)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _workspace_path(self):
|
||||||
|
return self._http.workspace_path
|
||||||
|
|
||||||
def _build_ssl_context(self) -> ssl.SSLContext | None:
|
def _build_ssl_context(self) -> ssl.SSLContext | None:
|
||||||
cert = self.config.ssl_certfile.strip()
|
cert = self.config.ssl_certfile.strip()
|
||||||
key = self.config.ssl_keyfile.strip()
|
key = self.config.ssl_keyfile.strip()
|
||||||
@ -628,548 +640,24 @@ class WebSocketChannel(BaseChannel):
|
|||||||
ctx.load_cert_chain(certfile=cert, keyfile=key)
|
ctx.load_cert_chain(certfile=cert, keyfile=key)
|
||||||
return ctx
|
return ctx
|
||||||
|
|
||||||
_MAX_ISSUED_TOKENS = 10_000
|
|
||||||
|
|
||||||
def _purge_expired_issued_tokens(self) -> None:
|
|
||||||
now = time.monotonic()
|
|
||||||
for token_key, expiry in list(self._issued_tokens.items()):
|
|
||||||
if now > expiry:
|
|
||||||
self._issued_tokens.pop(token_key, None)
|
|
||||||
|
|
||||||
def _take_issued_token_if_valid(self, token_value: str | None) -> bool:
|
|
||||||
"""Validate and consume one issued token (single use per connection attempt).
|
|
||||||
|
|
||||||
Uses single-step pop to minimize the window between lookup and removal;
|
|
||||||
safe under asyncio's single-threaded cooperative model.
|
|
||||||
"""
|
|
||||||
if not token_value:
|
|
||||||
return False
|
|
||||||
self._purge_expired_issued_tokens()
|
|
||||||
expiry = self._issued_tokens.pop(token_value, None)
|
|
||||||
if expiry is None:
|
|
||||||
return False
|
|
||||||
if time.monotonic() > expiry:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _handle_token_issue_http(self, connection: Any, request: Any) -> Any:
|
|
||||||
secret = self.config.token_issue_secret.strip() or self.config.token.strip()
|
|
||||||
if secret:
|
|
||||||
if not _issue_route_secret_matches(request.headers, secret):
|
|
||||||
return connection.respond(401, "Unauthorized")
|
|
||||||
else:
|
|
||||||
self.logger.warning(
|
|
||||||
"token_issue_path is set but no token_issue_secret or static token is configured; "
|
|
||||||
"any client can obtain connection tokens — set a secret for production."
|
|
||||||
)
|
|
||||||
self._purge_expired_issued_tokens()
|
|
||||||
if len(self._issued_tokens) >= self._MAX_ISSUED_TOKENS:
|
|
||||||
self.logger.error(
|
|
||||||
"too many outstanding issued tokens ({}), rejecting issuance",
|
|
||||||
len(self._issued_tokens),
|
|
||||||
)
|
|
||||||
return _http_json_response({"error": "too many outstanding tokens"}, status=429)
|
|
||||||
token_value = f"nbwt_{secrets.token_urlsafe(32)}"
|
|
||||||
self._issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s)
|
|
||||||
|
|
||||||
return _http_json_response(
|
|
||||||
{"token": token_value, "expires_in": self.config.token_ttl_s}
|
|
||||||
)
|
|
||||||
|
|
||||||
# -- HTTP dispatch ------------------------------------------------------
|
# -- HTTP dispatch ------------------------------------------------------
|
||||||
|
|
||||||
async def _dispatch_http(self, connection: Any, request: WsRequest) -> Any:
|
async def _dispatch_http(self, connection: Any, request: WsRequest) -> Any:
|
||||||
"""Route an inbound HTTP request to a handler or to the WS upgrade path."""
|
"""Route an inbound HTTP request to the HTTP handler or WS upgrade."""
|
||||||
got, query = _parse_request_path(request.path)
|
got, query = _parse_request_path(request.path)
|
||||||
|
|
||||||
if self.config.token_issue_path:
|
# WebSocket upgrade — channel handles this itself
|
||||||
issue_expected = _normalize_config_path(self.config.token_issue_path)
|
|
||||||
if got == issue_expected:
|
|
||||||
return self._handle_token_issue_http(connection, request)
|
|
||||||
|
|
||||||
if got == "/webui/bootstrap":
|
|
||||||
return self._handle_bootstrap(connection, request)
|
|
||||||
|
|
||||||
api_response = await self._dispatch_api_route(connection, request, got)
|
|
||||||
if api_response is not None:
|
|
||||||
return api_response
|
|
||||||
|
|
||||||
ws_matched, ws_response = self._dispatch_websocket_upgrade(
|
|
||||||
connection, request, got, query
|
|
||||||
)
|
|
||||||
if ws_matched:
|
|
||||||
return ws_response
|
|
||||||
|
|
||||||
# API clients should never receive the SPA shell for an unknown route.
|
|
||||||
# Returning HTML here makes the WebUI fail with "Unexpected token <"
|
|
||||||
# when a dev server is pointed at an older gateway.
|
|
||||||
if got.startswith("/api/"):
|
|
||||||
return _http_error(404, "API route not found")
|
|
||||||
|
|
||||||
if self._static_dist_path is not None:
|
|
||||||
response = self._serve_static(got)
|
|
||||||
if response is not None:
|
|
||||||
return response
|
|
||||||
|
|
||||||
return connection.respond(404, "Not Found")
|
|
||||||
|
|
||||||
async def _dispatch_api_route(
|
|
||||||
self,
|
|
||||||
connection: Any,
|
|
||||||
request: WsRequest,
|
|
||||||
got: str,
|
|
||||||
) -> Any | None:
|
|
||||||
"""Route REST-ish WebUI requests served beside the WebSocket endpoint."""
|
|
||||||
response = await self._dispatch_settings_api_route(request, got)
|
|
||||||
if response is not None:
|
|
||||||
return response
|
|
||||||
response = self._dispatch_session_api_route(request, got)
|
|
||||||
if response is not None:
|
|
||||||
return response
|
|
||||||
response = self._dispatch_media_api_route(request, got)
|
|
||||||
if response is not None:
|
|
||||||
return response
|
|
||||||
return self._dispatch_misc_api_route(connection, request, got)
|
|
||||||
|
|
||||||
def _dispatch_misc_api_route(
|
|
||||||
self,
|
|
||||||
connection: Any,
|
|
||||||
request: WsRequest,
|
|
||||||
got: str,
|
|
||||||
) -> Response | None:
|
|
||||||
"""Route small API endpoints that do not belong to a larger route group."""
|
|
||||||
if got == "/api/sessions":
|
|
||||||
return self._handle_sessions_list(request)
|
|
||||||
|
|
||||||
if got == "/api/commands":
|
|
||||||
return self._handle_commands(request)
|
|
||||||
|
|
||||||
if got == "/api/workspaces":
|
|
||||||
return self._handle_workspaces(connection, request)
|
|
||||||
|
|
||||||
if got == "/api/webui/sidebar-state":
|
|
||||||
return self._handle_webui_sidebar_state(request)
|
|
||||||
|
|
||||||
if got == "/api/webui/sidebar-state/update":
|
|
||||||
return self._handle_webui_sidebar_state_update(request)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _dispatch_settings_api_route(
|
|
||||||
self,
|
|
||||||
request: WsRequest,
|
|
||||||
got: str,
|
|
||||||
) -> Response | None:
|
|
||||||
return await self._settings_routes.dispatch(request, got)
|
|
||||||
|
|
||||||
def _dispatch_session_api_route(
|
|
||||||
self,
|
|
||||||
request: WsRequest,
|
|
||||||
got: str,
|
|
||||||
) -> Response | None:
|
|
||||||
m = re.match(r"^/api/sessions/([^/]+)/messages$", got)
|
|
||||||
if m:
|
|
||||||
return self._handle_session_messages(request, m.group(1))
|
|
||||||
|
|
||||||
m = re.match(r"^/api/sessions/([^/]+)/webui-thread$", got)
|
|
||||||
if m:
|
|
||||||
return self._handle_webui_thread_get(request, m.group(1))
|
|
||||||
|
|
||||||
# NOTE: websockets' HTTP parser only accepts GET, so we cannot expose a
|
|
||||||
# true ``DELETE`` verb. The action is folded into the path instead.
|
|
||||||
m = re.match(r"^/api/sessions/([^/]+)/delete$", got)
|
|
||||||
if m:
|
|
||||||
return self._handle_session_delete(request, m.group(1))
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _dispatch_media_api_route(
|
|
||||||
self,
|
|
||||||
request: WsRequest,
|
|
||||||
got: str,
|
|
||||||
) -> Response | None:
|
|
||||||
m = re.match(r"^/api/media/([A-Za-z0-9_-]+)/([A-Za-z0-9_-]+)$", got)
|
|
||||||
if m:
|
|
||||||
return self._handle_media_fetch(m.group(1), m.group(2), request)
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _dispatch_websocket_upgrade(
|
|
||||||
self,
|
|
||||||
connection: Any,
|
|
||||||
request: WsRequest,
|
|
||||||
got: str,
|
|
||||||
query: dict[str, list[str]],
|
|
||||||
) -> tuple[bool, Any | None]:
|
|
||||||
"""Authorize only real WS upgrade requests for the configured path."""
|
|
||||||
expected_ws = self._expected_path()
|
expected_ws = self._expected_path()
|
||||||
if got != expected_ws or not _is_websocket_upgrade(request):
|
if got == expected_ws and _is_websocket_upgrade(request):
|
||||||
return False, None
|
client_id = _query_first(query, "client_id") or ""
|
||||||
client_id = _query_first(query, "client_id") or ""
|
if len(client_id) > 128:
|
||||||
if len(client_id) > 128:
|
client_id = client_id[:128]
|
||||||
client_id = client_id[:128]
|
if not self.is_allowed(client_id):
|
||||||
if not self.is_allowed(client_id):
|
return connection.respond(403, "Forbidden")
|
||||||
return True, connection.respond(403, "Forbidden")
|
return self._authorize_websocket_handshake(connection, query)
|
||||||
return True, self._authorize_websocket_handshake(connection, query)
|
|
||||||
|
|
||||||
# -- HTTP route handlers ------------------------------------------------
|
<< # Everything else goes to the HTTP handler
|
||||||
|
return await self._http.dispatch(connection, request)
|
||||||
def _check_api_token(self, request: WsRequest) -> bool:
|
|
||||||
"""Validate a request against the API token pool (multi-use, TTL-bound)."""
|
|
||||||
self._purge_expired_api_tokens()
|
|
||||||
token = _bearer_token(request.headers) or _query_first(
|
|
||||||
_parse_query(request.path), "token"
|
|
||||||
)
|
|
||||||
if not token:
|
|
||||||
return False
|
|
||||||
expiry = self._api_tokens.get(token)
|
|
||||||
if expiry is None or time.monotonic() > expiry:
|
|
||||||
self._api_tokens.pop(token, None)
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _purge_expired_api_tokens(self) -> None:
|
|
||||||
now = time.monotonic()
|
|
||||||
for token_key, expiry in list(self._api_tokens.items()):
|
|
||||||
if now > expiry:
|
|
||||||
self._api_tokens.pop(token_key, None)
|
|
||||||
|
|
||||||
def _handle_bootstrap(self, connection: Any, request: Any) -> Response:
|
|
||||||
# When a secret is configured (token_issue_secret or static token),
|
|
||||||
# validate it regardless of source IP. This secures deployments
|
|
||||||
# behind a reverse proxy where all connections appear as localhost.
|
|
||||||
secret = self.config.token_issue_secret.strip() or self.config.token.strip()
|
|
||||||
if secret:
|
|
||||||
if not _issue_route_secret_matches(request.headers, secret):
|
|
||||||
return _http_error(401, "Unauthorized")
|
|
||||||
elif not _is_localhost(connection):
|
|
||||||
# No secret configured: only allow localhost (local dev mode).
|
|
||||||
return _http_error(403, "bootstrap is localhost-only")
|
|
||||||
# Cap outstanding tokens to avoid runaway growth from a misbehaving client.
|
|
||||||
self._purge_expired_issued_tokens()
|
|
||||||
self._purge_expired_api_tokens()
|
|
||||||
if (
|
|
||||||
len(self._issued_tokens) >= self._MAX_ISSUED_TOKENS
|
|
||||||
or len(self._api_tokens) >= self._MAX_ISSUED_TOKENS
|
|
||||||
):
|
|
||||||
return _http_response(
|
|
||||||
json.dumps({"error": "too many outstanding tokens"}).encode("utf-8"),
|
|
||||||
status=429,
|
|
||||||
content_type="application/json; charset=utf-8",
|
|
||||||
)
|
|
||||||
token = f"nbwt_{secrets.token_urlsafe(32)}"
|
|
||||||
expiry = time.monotonic() + float(self.config.token_ttl_s)
|
|
||||||
# Same string registered in both pools: the WS handshake consumes one copy
|
|
||||||
# while the REST surface keeps validating the other until TTL expiry.
|
|
||||||
self._issued_tokens[token] = expiry
|
|
||||||
self._api_tokens[token] = expiry
|
|
||||||
ws_url = self._bootstrap_ws_url(request)
|
|
||||||
return _http_json_response(
|
|
||||||
{
|
|
||||||
"token": token,
|
|
||||||
"ws_path": self._expected_path(),
|
|
||||||
"ws_url": ws_url,
|
|
||||||
"expires_in": self.config.token_ttl_s,
|
|
||||||
"model_name": _resolve_bootstrap_model_name(self._runtime_model_name),
|
|
||||||
"runtime_surface": self._runtime_surface,
|
|
||||||
"runtime_capabilities": self._runtime_capabilities,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def _bootstrap_ws_url(self, request: Any) -> str:
|
|
||||||
"""Absolute WS URL clients should prefer over a dev-server proxy."""
|
|
||||||
headers = getattr(request, "headers", {}) or {}
|
|
||||||
host = _safe_host_header(_case_insensitive_header(headers, "Host"))
|
|
||||||
if not host:
|
|
||||||
host = _host_for_url(self.config.host, self.config.port)
|
|
||||||
|
|
||||||
proto = _case_insensitive_header(headers, "X-Forwarded-Proto")
|
|
||||||
proto = proto.split(",", 1)[0].strip().lower()
|
|
||||||
secure = proto in {"https", "wss"} or bool(self.config.ssl_certfile.strip())
|
|
||||||
scheme = "wss" if secure else "ws"
|
|
||||||
return f"{scheme}://{host}{self._expected_path()}"
|
|
||||||
|
|
||||||
def _handle_sessions_list(self, request: WsRequest) -> Response:
|
|
||||||
if not self._check_api_token(request):
|
|
||||||
return _http_error(401, "Unauthorized")
|
|
||||||
if self._session_manager is None:
|
|
||||||
return _http_error(503, "session manager unavailable")
|
|
||||||
sessions = self._session_manager.list_sessions()
|
|
||||||
# Sidebar/chat listing for WS-backed sessions only — CLI / Slack / etc.
|
|
||||||
# keys are not intended for resume over this HTTP surface.
|
|
||||||
cleaned = []
|
|
||||||
for s in sessions:
|
|
||||||
key = s.get("key")
|
|
||||||
if not (isinstance(key, str) and key.startswith("websocket:")):
|
|
||||||
continue
|
|
||||||
row = {k: v for k, v in s.items() if k != "path"}
|
|
||||||
chat_id = key.split(":", 1)[1]
|
|
||||||
started_at = websocket_turn_wall_started_at(chat_id)
|
|
||||||
if started_at is not None:
|
|
||||||
row["run_started_at"] = started_at
|
|
||||||
scope = self._webui_workspaces.scope_for_session_key(key)
|
|
||||||
row["workspace_scope"] = scope.payload()
|
|
||||||
cleaned.append(row)
|
|
||||||
return _http_json_response({"sessions": cleaned})
|
|
||||||
|
|
||||||
def _handle_workspaces(self, connection: Any, request: WsRequest) -> Response:
|
|
||||||
if not self._check_api_token(request):
|
|
||||||
return _http_error(401, "Unauthorized")
|
|
||||||
return _http_json_response(
|
|
||||||
self._webui_workspaces.payload(controls_available=_is_localhost(connection))
|
|
||||||
)
|
|
||||||
|
|
||||||
def _handle_commands(self, request: WsRequest) -> Response:
|
|
||||||
if not self._check_api_token(request):
|
|
||||||
return _http_error(401, "Unauthorized")
|
|
||||||
return _http_json_response({"commands": builtin_command_palette()})
|
|
||||||
|
|
||||||
def _handle_webui_sidebar_state(self, request: WsRequest) -> Response:
|
|
||||||
if not self._check_api_token(request):
|
|
||||||
return _http_error(401, "Unauthorized")
|
|
||||||
return _http_json_response(read_webui_sidebar_state())
|
|
||||||
|
|
||||||
def _handle_webui_sidebar_state_update(self, request: WsRequest) -> Response:
|
|
||||||
if not self._check_api_token(request):
|
|
||||||
return _http_error(401, "Unauthorized")
|
|
||||||
query = _parse_query(request.path)
|
|
||||||
raw_state = _query_first(query, "state")
|
|
||||||
if raw_state is None:
|
|
||||||
return _http_error(400, "missing state")
|
|
||||||
try:
|
|
||||||
decoded = json.loads(raw_state)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return _http_error(400, "state must be JSON")
|
|
||||||
if not isinstance(decoded, dict):
|
|
||||||
return _http_error(400, "state must be an object")
|
|
||||||
try:
|
|
||||||
state = write_webui_sidebar_state(decoded)
|
|
||||||
except ValueError as e:
|
|
||||||
return _http_error(400, str(e))
|
|
||||||
except OSError:
|
|
||||||
self.logger.exception("failed to write webui sidebar state")
|
|
||||||
return _http_error(500, "failed to write sidebar state")
|
|
||||||
return _http_json_response(state)
|
|
||||||
|
|
||||||
# -- Session replay, transcript, and signed media ----------------------
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _is_websocket_channel_session_key(key: str) -> bool:
|
|
||||||
"""True when *key* is a ``websocket:…`` session exposed on this HTTP surface."""
|
|
||||||
return key.startswith("websocket:")
|
|
||||||
|
|
||||||
def _handle_session_messages(self, request: WsRequest, key: str) -> Response:
|
|
||||||
if not self._check_api_token(request):
|
|
||||||
return _http_error(401, "Unauthorized")
|
|
||||||
if self._session_manager is None:
|
|
||||||
return _http_error(503, "session manager unavailable")
|
|
||||||
decoded_key = _decode_api_key(key)
|
|
||||||
if decoded_key is None:
|
|
||||||
return _http_error(400, "invalid session key")
|
|
||||||
# Only ``websocket:…`` sessions are listed/served here — same boundary as
|
|
||||||
# ``/api/sessions``. Block handcrafted URLs from probing CLI / Slack / etc.
|
|
||||||
if not self._is_websocket_channel_session_key(decoded_key):
|
|
||||||
return _http_error(404, "session not found")
|
|
||||||
data = self._session_manager.read_session_file(decoded_key)
|
|
||||||
if data is None:
|
|
||||||
return _http_error(404, "session not found")
|
|
||||||
messages = data.get("messages")
|
|
||||||
if isinstance(messages, list):
|
|
||||||
scrub_subagent_messages_for_channel(messages)
|
|
||||||
# Decorate persisted user messages with signed media URLs so the
|
|
||||||
# client can render previews. The raw on-disk ``media`` paths are
|
|
||||||
# stripped on the way out — they leak server filesystem layout and
|
|
||||||
# the client never needs them once it has the signed fetch URL.
|
|
||||||
attach_signed_media_urls(data, sign_path=self._sign_media_path)
|
|
||||||
return _http_json_response(data)
|
|
||||||
|
|
||||||
def _handle_webui_thread_get(self, request: WsRequest, key: str) -> Response:
|
|
||||||
if not self._check_api_token(request):
|
|
||||||
return _http_error(401, "Unauthorized")
|
|
||||||
decoded_key = _decode_api_key(key)
|
|
||||||
if decoded_key is None:
|
|
||||||
return _http_error(400, "invalid session key")
|
|
||||||
if not self._is_websocket_channel_session_key(decoded_key):
|
|
||||||
return _http_error(404, "session not found")
|
|
||||||
scope = self._webui_workspaces.scope_for_session_key(decoded_key)
|
|
||||||
augment_media = partial(
|
|
||||||
signed_media_attachments,
|
|
||||||
sign_path=self._sign_or_stage_media_path,
|
|
||||||
)
|
|
||||||
data = build_webui_thread_response(
|
|
||||||
decoded_key,
|
|
||||||
augment_user_media=augment_media,
|
|
||||||
augment_assistant_media=augment_media,
|
|
||||||
augment_assistant_text=lambda text: rewrite_local_markdown_images(
|
|
||||||
text,
|
|
||||||
workspace_path=scope.project_path,
|
|
||||||
sign_path=self._sign_or_stage_media_path,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
if data is None:
|
|
||||||
return _http_error(404, "webui thread not found")
|
|
||||||
data["workspace_scope"] = scope.payload()
|
|
||||||
return _http_json_response(data)
|
|
||||||
|
|
||||||
def _try_append_webui_transcript(self, chat_id: str, wire: dict[str, Any]) -> None:
|
|
||||||
sk = f"websocket:{chat_id}"
|
|
||||||
try:
|
|
||||||
dup = json.loads(json.dumps(wire, ensure_ascii=False))
|
|
||||||
append_transcript_object(sk, dup)
|
|
||||||
except (ValueError, TypeError) as e:
|
|
||||||
self.logger.warning("webui transcript append failed: {}", e)
|
|
||||||
|
|
||||||
async def _handle_message(
|
|
||||||
self,
|
|
||||||
sender_id: str,
|
|
||||||
chat_id: str,
|
|
||||||
content: str,
|
|
||||||
media: list[str] | None = None,
|
|
||||||
metadata: dict[str, Any] | None = None,
|
|
||||||
session_key: str | None = None,
|
|
||||||
is_dm: bool = False,
|
|
||||||
) -> None:
|
|
||||||
meta = metadata or {}
|
|
||||||
if meta.get("webui"):
|
|
||||||
user_obj: dict[str, Any] = {
|
|
||||||
"event": "user",
|
|
||||||
"chat_id": chat_id,
|
|
||||||
"text": content,
|
|
||||||
}
|
|
||||||
if media:
|
|
||||||
user_obj["media_paths"] = list(media)
|
|
||||||
cli_apps = meta.get("cli_apps")
|
|
||||||
if isinstance(cli_apps, list) and cli_apps:
|
|
||||||
user_obj["cli_apps"] = cli_apps
|
|
||||||
mcp_presets = meta.get("mcp_presets")
|
|
||||||
if isinstance(mcp_presets, list) and mcp_presets:
|
|
||||||
user_obj["mcp_presets"] = mcp_presets
|
|
||||||
self._try_append_webui_transcript(chat_id, user_obj)
|
|
||||||
await super()._handle_message(
|
|
||||||
sender_id,
|
|
||||||
chat_id,
|
|
||||||
content,
|
|
||||||
media,
|
|
||||||
metadata,
|
|
||||||
session_key,
|
|
||||||
is_dm,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _sign_media_path(self, abs_path: Path) -> str | None:
|
|
||||||
"""Return a ``/api/media/<sig>/<payload>`` URL for *abs_path*, or
|
|
||||||
``None`` when the path does not resolve inside the media root.
|
|
||||||
|
|
||||||
The URL is self-authenticating: the signature binds the payload to
|
|
||||||
this process's ``_media_secret``, so only paths we chose to sign can
|
|
||||||
be fetched. The returned path is relative to the server origin; the
|
|
||||||
client joins it against this server's HTTP origin (same host as WS).
|
|
||||||
"""
|
|
||||||
return sign_media_path(
|
|
||||||
abs_path,
|
|
||||||
secret=self._media_secret,
|
|
||||||
media_dir=lambda channel=None: get_media_dir(channel),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _sign_or_stage_media_path(self, path: Path) -> dict[str, str] | None:
|
|
||||||
"""Return a signed media URL payload for *path*.
|
|
||||||
|
|
||||||
Persisted inbound media already lives under ``get_media_dir`` and can
|
|
||||||
be signed directly. Outbound bot-generated files may live anywhere on
|
|
||||||
disk; copy those into the websocket media bucket first so the browser
|
|
||||||
can fetch them through the existing signed media route without
|
|
||||||
exposing arbitrary filesystem paths.
|
|
||||||
"""
|
|
||||||
return sign_or_stage_media_path(
|
|
||||||
path,
|
|
||||||
secret=self._media_secret,
|
|
||||||
media_dir=lambda channel=None: get_media_dir(channel),
|
|
||||||
logger=self.logger,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _rewrite_local_markdown_images(self, text: str) -> str:
|
|
||||||
return rewrite_local_markdown_images(
|
|
||||||
text,
|
|
||||||
workspace_path=self._workspace_path,
|
|
||||||
sign_path=self._sign_or_stage_media_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _handle_media_fetch(
|
|
||||||
self, sig: str, payload: str, request: WsRequest | None = None
|
|
||||||
) -> Response:
|
|
||||||
"""Serve a single media file previously signed via
|
|
||||||
:meth:`_sign_media_path`. Validates the signature, decodes the
|
|
||||||
payload to a relative path, and streams the file bytes with a
|
|
||||||
long-lived immutable cache header (the URL already encodes the
|
|
||||||
file identity, so caches can be aggressive)."""
|
|
||||||
return serve_signed_media(
|
|
||||||
sig,
|
|
||||||
payload,
|
|
||||||
secret=self._media_secret,
|
|
||||||
request=request,
|
|
||||||
media_dir=lambda channel=None: get_media_dir(channel),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _handle_session_delete(self, request: WsRequest, key: str) -> Response:
|
|
||||||
if not self._check_api_token(request):
|
|
||||||
return _http_error(401, "Unauthorized")
|
|
||||||
if self._session_manager is None:
|
|
||||||
return _http_error(503, "session manager unavailable")
|
|
||||||
decoded_key = _decode_api_key(key)
|
|
||||||
if decoded_key is None:
|
|
||||||
return _http_error(400, "invalid session key")
|
|
||||||
# Same boundary as ``_handle_session_messages``: mutations apply only to
|
|
||||||
# websocket-channel sessions; deletion unlinks local JSONL — keep scope narrow.
|
|
||||||
if not self._is_websocket_channel_session_key(decoded_key):
|
|
||||||
return _http_error(404, "session not found")
|
|
||||||
deleted = self._session_manager.delete_session(decoded_key)
|
|
||||||
delete_webui_thread(decoded_key)
|
|
||||||
return _http_json_response({"deleted": bool(deleted)})
|
|
||||||
|
|
||||||
# -- Static files and WebSocket handshake ------------------------------
|
|
||||||
|
|
||||||
def _serve_static(self, request_path: str) -> Response | None:
|
|
||||||
"""Resolve *request_path* against the built SPA directory; SPA fallback to index.html."""
|
|
||||||
assert self._static_dist_path is not None
|
|
||||||
rel = request_path.lstrip("/")
|
|
||||||
if not rel:
|
|
||||||
rel = "index.html"
|
|
||||||
# Reject path-traversal attempts and absolute targets.
|
|
||||||
if ".." in rel.split("/") or rel.startswith("/"):
|
|
||||||
return _http_error(403, "Forbidden")
|
|
||||||
candidate = (self._static_dist_path / rel).resolve()
|
|
||||||
try:
|
|
||||||
candidate.relative_to(self._static_dist_path)
|
|
||||||
except ValueError:
|
|
||||||
return _http_error(403, "Forbidden")
|
|
||||||
if not candidate.is_file():
|
|
||||||
# SPA history-mode fallback: unknown routes serve index.html so the
|
|
||||||
# client-side router can render them.
|
|
||||||
index = self._static_dist_path / "index.html"
|
|
||||||
if index.is_file():
|
|
||||||
candidate = index
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
body = candidate.read_bytes()
|
|
||||||
except OSError as e:
|
|
||||||
self.logger.warning("static: failed to read {}: {}", candidate, e)
|
|
||||||
return _http_error(500, "Internal Server Error")
|
|
||||||
ctype, _ = mimetypes.guess_type(candidate.name)
|
|
||||||
if ctype is None:
|
|
||||||
ctype = "application/octet-stream"
|
|
||||||
if ctype.startswith("text/") or ctype in {"application/javascript", "application/json"}:
|
|
||||||
ctype = f"{ctype}; charset=utf-8"
|
|
||||||
# Hash-named build assets are cache-friendly; index.html must stay fresh.
|
|
||||||
if candidate.name == "index.html":
|
|
||||||
cache = "no-cache"
|
|
||||||
else:
|
|
||||||
cache = "public, max-age=31536000, immutable"
|
|
||||||
return _http_response(
|
|
||||||
body,
|
|
||||||
status=200,
|
|
||||||
content_type=ctype,
|
|
||||||
extra_headers=[("Cache-Control", cache)],
|
|
||||||
)
|
|
||||||
|
|
||||||
def _authorize_websocket_handshake(self, connection: Any, query: dict[str, list[str]]) -> Any:
|
def _authorize_websocket_handshake(self, connection: Any, query: dict[str, list[str]]) -> Any:
|
||||||
supplied = _query_first(query, "token")
|
supplied = _query_first(query, "token")
|
||||||
@ -1178,20 +666,21 @@ class WebSocketChannel(BaseChannel):
|
|||||||
if static_token:
|
if static_token:
|
||||||
if supplied and hmac.compare_digest(supplied, static_token):
|
if supplied and hmac.compare_digest(supplied, static_token):
|
||||||
return None
|
return None
|
||||||
if supplied and self._take_issued_token_if_valid(supplied):
|
if supplied and self._http.take_issued_token_if_valid(supplied):
|
||||||
return None
|
return None
|
||||||
return connection.respond(401, "Unauthorized")
|
return connection.respond(401, "Unauthorized")
|
||||||
|
|
||||||
if self.config.websocket_requires_token:
|
if self.config.websocket_requires_token:
|
||||||
if supplied and self._take_issued_token_if_valid(supplied):
|
if supplied and self._http.take_issued_token_if_valid(supplied):
|
||||||
return None
|
return None
|
||||||
return connection.respond(401, "Unauthorized")
|
return connection.respond(401, "Unauthorized")
|
||||||
|
|
||||||
if supplied:
|
if supplied:
|
||||||
self._take_issued_token_if_valid(supplied)
|
self._http.take_issued_token_if_valid(supplied)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# -- Server lifecycle and connection ingress ---------------------------
|
# -- Server lifecycle and connection ingress ---------------------------
|
||||||
|
# -- Server lifecycle and connection ingress ---------------------------
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
from nanobot.utils.logging_bridge import redirect_lib_logging
|
from nanobot.utils.logging_bridge import redirect_lib_logging
|
||||||
@ -1587,8 +1076,8 @@ class WebSocketChannel(BaseChannel):
|
|||||||
self._subs.clear()
|
self._subs.clear()
|
||||||
self._conn_chats.clear()
|
self._conn_chats.clear()
|
||||||
self._conn_default.clear()
|
self._conn_default.clear()
|
||||||
self._issued_tokens.clear()
|
self._http.issued_tokens.clear()
|
||||||
self._api_tokens.clear()
|
self._http.api_tokens.clear()
|
||||||
|
|
||||||
async def _safe_send_to(self, connection: Any, raw: str, *, label: str = "") -> None:
|
async def _safe_send_to(self, connection: Any, raw: str, *, label: str = "") -> None:
|
||||||
"""Send a raw frame to one connection, cleaning up on ConnectionClosed."""
|
"""Send a raw frame to one connection, cleaning up on ConnectionClosed."""
|
||||||
@ -1601,6 +1090,14 @@ class WebSocketChannel(BaseChannel):
|
|||||||
self.logger.exception("send failed{}", label)
|
self.logger.exception("send failed{}", label)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def _try_append_webui_transcript(self, chat_id: str, wire: dict[str, Any]) -> None:
|
||||||
|
sk = f"websocket:{chat_id}"
|
||||||
|
try:
|
||||||
|
dup = json.loads(json.dumps(wire, ensure_ascii=False))
|
||||||
|
append_transcript_object(sk, dup)
|
||||||
|
except (ValueError, TypeError) as e:
|
||||||
|
self.logger.warning("webui transcript append failed: {}", e)
|
||||||
|
|
||||||
async def send(self, msg: OutboundMessage) -> None:
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
if msg.metadata.get("_runtime_model_updated"):
|
if msg.metadata.get("_runtime_model_updated"):
|
||||||
await self.send_runtime_model_updated(
|
await self.send_runtime_model_updated(
|
||||||
@ -1662,7 +1159,7 @@ class WebSocketChannel(BaseChannel):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
text = msg.content
|
text = msg.content
|
||||||
wire_text = self._rewrite_local_markdown_images(text)
|
wire_text = self._http.rewrite_local_markdown_images(text)
|
||||||
payload: dict[str, Any] = {
|
payload: dict[str, Any] = {
|
||||||
"event": "message",
|
"event": "message",
|
||||||
"chat_id": msg.chat_id,
|
"chat_id": msg.chat_id,
|
||||||
@ -1672,7 +1169,7 @@ class WebSocketChannel(BaseChannel):
|
|||||||
payload["media"] = msg.media
|
payload["media"] = msg.media
|
||||||
urls: list[dict[str, str]] = []
|
urls: list[dict[str, str]] = []
|
||||||
for entry in msg.media:
|
for entry in msg.media:
|
||||||
signed = self._sign_or_stage_media_path(Path(entry))
|
signed = self._http.sign_or_stage_media_path(Path(entry))
|
||||||
if signed is not None:
|
if signed is not None:
|
||||||
urls.append(signed)
|
urls.append(signed)
|
||||||
if urls:
|
if urls:
|
||||||
@ -1787,7 +1284,7 @@ class WebSocketChannel(BaseChannel):
|
|||||||
if delta:
|
if delta:
|
||||||
buffered.append(delta)
|
buffered.append(delta)
|
||||||
full_text = "".join(buffered)
|
full_text = "".join(buffered)
|
||||||
rewritten = self._rewrite_local_markdown_images(full_text)
|
rewritten = self._http.rewrite_local_markdown_images(full_text)
|
||||||
if rewritten != full_text:
|
if rewritten != full_text:
|
||||||
body["text"] = rewritten
|
body["text"] = rewritten
|
||||||
else:
|
else:
|
||||||
|
|||||||
731
nanobot/channels/ws_http.py
Normal file
731
nanobot/channels/ws_http.py
Normal file
@ -0,0 +1,731 @@
|
|||||||
|
"""HTTP API handler extracted from WebSocketChannel.
|
||||||
|
|
||||||
|
Handles all non-WebSocket HTTP routes: bootstrap, sessions, settings,
|
||||||
|
media, commands, sidebar state, static file serving, and token management.
|
||||||
|
|
||||||
|
Also houses shared HTTP utility functions used by both this module and
|
||||||
|
``websocket.py`` to avoid circular imports.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import email.utils
|
||||||
|
import hmac
|
||||||
|
import http
|
||||||
|
import json
|
||||||
|
import mimetypes
|
||||||
|
import re
|
||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
from collections.abc import Callable
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
from websockets.datastructures import Headers
|
||||||
|
from websockets.http11 import Request as WsRequest
|
||||||
|
from websockets.http11 import Response
|
||||||
|
|
||||||
|
from nanobot.command.builtin import builtin_command_palette
|
||||||
|
from nanobot.config.paths import get_media_dir
|
||||||
|
from nanobot.utils.subagent_channel_display import scrub_subagent_messages_for_channel
|
||||||
|
from nanobot.webui.media_api import (
|
||||||
|
serve_signed_media,
|
||||||
|
sign_media_path,
|
||||||
|
sign_or_stage_media_path,
|
||||||
|
)
|
||||||
|
from nanobot.webui.sidebar_state import (
|
||||||
|
read_webui_sidebar_state,
|
||||||
|
write_webui_sidebar_state,
|
||||||
|
)
|
||||||
|
from nanobot.webui.thread_disk import delete_webui_thread
|
||||||
|
from nanobot.webui.transcript import (
|
||||||
|
build_webui_thread_response,
|
||||||
|
rewrite_local_markdown_images,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.session.manager import SessionManager
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Shared HTTP utility functions (imported by websocket.py)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_trailing_slash(path: str) -> str:
|
||||||
|
if len(path) > 1 and path.endswith("/"):
|
||||||
|
return path.rstrip("/")
|
||||||
|
return path or "/"
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_config_path(path: str) -> str:
|
||||||
|
return _strip_trailing_slash(path)
|
||||||
|
|
||||||
|
|
||||||
|
def _case_insensitive_header(headers: Any, key: str) -> str:
|
||||||
|
"""Read a header from websockets/http test stubs without assuming casing."""
|
||||||
|
try:
|
||||||
|
value = headers.get(key)
|
||||||
|
except Exception:
|
||||||
|
value = None
|
||||||
|
if value is None:
|
||||||
|
try:
|
||||||
|
value = headers.get(key.lower())
|
||||||
|
except Exception:
|
||||||
|
value = None
|
||||||
|
return str(value or "").strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_host_header(value: str) -> str:
|
||||||
|
"""Return a safe Host header value, or empty when it should not be echoed."""
|
||||||
|
value = value.strip()
|
||||||
|
if not value:
|
||||||
|
return ""
|
||||||
|
if re.fullmatch(r"\[[0-9A-Fa-f:.]+\](?::\d{1,5})?", value):
|
||||||
|
return value
|
||||||
|
if re.fullmatch(r"[A-Za-z0-9.-]+(?::\d{1,5})?", value):
|
||||||
|
return value
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _host_for_url(host: str, port: int) -> str:
|
||||||
|
host = host.strip()
|
||||||
|
if host in ("0.0.0.0", "::"):
|
||||||
|
host = "127.0.0.1"
|
||||||
|
if ":" in host and not host.startswith("["):
|
||||||
|
host = f"[{host}]"
|
||||||
|
return f"{host}:{port}"
|
||||||
|
|
||||||
|
|
||||||
|
def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
|
||||||
|
body = json.dumps(data, ensure_ascii=False).encode("utf-8")
|
||||||
|
headers = Headers(
|
||||||
|
[
|
||||||
|
("Date", email.utils.formatdate(usegmt=True)),
|
||||||
|
("Connection", "close"),
|
||||||
|
("Content-Length", str(len(body))),
|
||||||
|
("Content-Type", "application/json; charset=utf-8"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
reason = http.HTTPStatus(status).phrase
|
||||||
|
return Response(status, reason, headers, body)
|
||||||
|
|
||||||
|
|
||||||
|
def _http_response(
|
||||||
|
body: bytes,
|
||||||
|
*,
|
||||||
|
status: int = 200,
|
||||||
|
content_type: str = "text/plain; charset=utf-8",
|
||||||
|
extra_headers: list[tuple[str, str]] | None = None,
|
||||||
|
) -> Response:
|
||||||
|
headers = [
|
||||||
|
("Date", email.utils.formatdate(usegmt=True)),
|
||||||
|
("Connection", "close"),
|
||||||
|
("Content-Length", str(len(body))),
|
||||||
|
("Content-Type", content_type),
|
||||||
|
]
|
||||||
|
if extra_headers:
|
||||||
|
headers.extend(extra_headers)
|
||||||
|
reason = http.HTTPStatus(status).phrase
|
||||||
|
return Response(status, reason, Headers(headers), body)
|
||||||
|
|
||||||
|
|
||||||
|
def _http_error(status: int, message: str | None = None) -> Response:
|
||||||
|
body = (message or http.HTTPStatus(status).phrase).encode("utf-8")
|
||||||
|
return _http_response(body, status=status)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]:
|
||||||
|
"""Parse normalized path and query parameters in one pass."""
|
||||||
|
parsed = urlparse("ws://x" + path_with_query)
|
||||||
|
path = _strip_trailing_slash(parsed.path or "/")
|
||||||
|
return path, parse_qs(parsed.query, keep_blank_values=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_http_path(path_with_query: str) -> str:
|
||||||
|
return _parse_request_path(path_with_query)[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_query(path_with_query: str) -> dict[str, list[str]]:
|
||||||
|
return _parse_request_path(path_with_query)[1]
|
||||||
|
|
||||||
|
|
||||||
|
def _query_first(query: dict[str, list[str]], key: str) -> str | None:
|
||||||
|
values = query.get(key)
|
||||||
|
return values[0] if values else None
|
||||||
|
|
||||||
|
|
||||||
|
def _is_localhost(connection: Any) -> bool:
|
||||||
|
addr = getattr(connection, "remote_address", None)
|
||||||
|
if not addr:
|
||||||
|
return False
|
||||||
|
host = addr[0] if isinstance(addr, tuple) else addr
|
||||||
|
if not isinstance(host, str):
|
||||||
|
return False
|
||||||
|
if host.startswith("::ffff:"):
|
||||||
|
host = host[7:]
|
||||||
|
return host in {"127.0.0.1", "::1", "localhost"}
|
||||||
|
|
||||||
|
|
||||||
|
def _bearer_token(headers: Any) -> str | None:
|
||||||
|
auth = headers.get("Authorization") or headers.get("authorization")
|
||||||
|
if auth and auth.lower().startswith("bearer "):
|
||||||
|
return auth[7:].strip() or None
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool:
|
||||||
|
if not configured_secret:
|
||||||
|
return True
|
||||||
|
authorization = headers.get("Authorization") or headers.get("authorization")
|
||||||
|
if authorization and authorization.lower().startswith("bearer "):
|
||||||
|
supplied = authorization[7:].strip()
|
||||||
|
return hmac.compare_digest(supplied, configured_secret)
|
||||||
|
header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth")
|
||||||
|
if not header_token:
|
||||||
|
return False
|
||||||
|
return hmac.compare_digest(header_token.strip(), configured_secret)
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_api_key(raw_key: str) -> str | None:
|
||||||
|
from urllib.parse import unquote
|
||||||
|
|
||||||
|
key = unquote(raw_key)
|
||||||
|
_api_key_re = re.compile(r"^[A-Za-z0-9_:.-]{1,128}$")
|
||||||
|
if _api_key_re.match(key) is None:
|
||||||
|
return None
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
|
def _default_model_name_from_config() -> str | None:
|
||||||
|
try:
|
||||||
|
from nanobot.config.loader import load_config
|
||||||
|
model = load_config().resolve_preset().model.strip()
|
||||||
|
return model or None
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("bootstrap model_name could not load from config: {}", e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_bootstrap_model_name(
|
||||||
|
runtime_name: Callable[[], str | None] | None,
|
||||||
|
) -> str | None:
|
||||||
|
if runtime_name is not None:
|
||||||
|
try:
|
||||||
|
raw = runtime_name()
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("bootstrap runtime model resolver failed: {}", e)
|
||||||
|
else:
|
||||||
|
if isinstance(raw, str):
|
||||||
|
stripped = raw.strip()
|
||||||
|
if stripped:
|
||||||
|
return stripped
|
||||||
|
return _default_model_name_from_config()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GatewayHTTPHandler
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class GatewayHTTPHandler:
|
||||||
|
"""Handles all HTTP routes served alongside the WebSocket endpoint.
|
||||||
|
|
||||||
|
Owns token management, session API, media API, static file serving,
|
||||||
|
and delegates settings routes to ``WebUISettingsRouter``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_MAX_ISSUED_TOKENS = 10_000
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
config: Any, # WebSocketConfig
|
||||||
|
session_manager: SessionManager | None,
|
||||||
|
static_dist_path: Path | None,
|
||||||
|
workspace_path: Path,
|
||||||
|
runtime_model_name: Callable[[], str | None] | None,
|
||||||
|
runtime_surface: str,
|
||||||
|
runtime_capabilities_overrides: dict[str, Any] | None,
|
||||||
|
bus: MessageBus,
|
||||||
|
log: Any = logger,
|
||||||
|
) -> None:
|
||||||
|
self.config = config
|
||||||
|
self.session_manager = session_manager
|
||||||
|
self.static_dist_path = static_dist_path
|
||||||
|
self.workspace_path = workspace_path
|
||||||
|
self.runtime_model_name = runtime_model_name
|
||||||
|
self.bus = bus
|
||||||
|
self._log = log
|
||||||
|
self._runtime_surface = runtime_surface
|
||||||
|
|
||||||
|
self.issued_tokens: dict[str, float] = {}
|
||||||
|
self.api_tokens: dict[str, float] = {}
|
||||||
|
self.media_secret: bytes = secrets.token_bytes(32)
|
||||||
|
|
||||||
|
# Workspace controller
|
||||||
|
from nanobot.webui.workspaces import WebUIWorkspaceController
|
||||||
|
|
||||||
|
self.workspaces = WebUIWorkspaceController(
|
||||||
|
session_manager=session_manager,
|
||||||
|
default_workspace=workspace_path,
|
||||||
|
default_restrict_to_workspace=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Settings router
|
||||||
|
from nanobot.webui.settings_api import runtime_capabilities as _rc
|
||||||
|
from nanobot.webui.settings_routes import WebUISettingsRouter
|
||||||
|
|
||||||
|
self._capabilities = _rc(runtime_surface, runtime_capabilities_overrides or {})
|
||||||
|
self.settings_routes = WebUISettingsRouter(
|
||||||
|
bus=bus,
|
||||||
|
logger=self._log,
|
||||||
|
check_api_token=self.check_api_token,
|
||||||
|
parse_query=_parse_query,
|
||||||
|
json_response=_http_json_response,
|
||||||
|
error_response=_http_error,
|
||||||
|
runtime_surface=runtime_surface,
|
||||||
|
runtime_capabilities=self._capabilities,
|
||||||
|
)
|
||||||
|
|
||||||
|
# -- Token management ---------------------------------------------------
|
||||||
|
|
||||||
|
def check_api_token(self, request: WsRequest) -> bool:
|
||||||
|
self._purge_expired_api_tokens()
|
||||||
|
token = _bearer_token(request.headers) or _query_first(
|
||||||
|
_parse_query(request.path), "token"
|
||||||
|
)
|
||||||
|
if not token:
|
||||||
|
return False
|
||||||
|
expiry = self.api_tokens.get(token)
|
||||||
|
if expiry is None or time.monotonic() > expiry:
|
||||||
|
self.api_tokens.pop(token, None)
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _purge_expired_api_tokens(self) -> None:
|
||||||
|
now = time.monotonic()
|
||||||
|
for token_key, expiry in list(self.api_tokens.items()):
|
||||||
|
if now > expiry:
|
||||||
|
self.api_tokens.pop(token_key, None)
|
||||||
|
|
||||||
|
def _purge_expired_issued_tokens(self) -> None:
|
||||||
|
now = time.monotonic()
|
||||||
|
for token_key, expiry in list(self.issued_tokens.items()):
|
||||||
|
if now > expiry:
|
||||||
|
self.issued_tokens.pop(token_key, None)
|
||||||
|
|
||||||
|
def take_issued_token_if_valid(self, token_value: str | None) -> bool:
|
||||||
|
if not token_value:
|
||||||
|
return False
|
||||||
|
self._purge_expired_issued_tokens()
|
||||||
|
expiry = self.issued_tokens.pop(token_value, None)
|
||||||
|
if expiry is None:
|
||||||
|
return False
|
||||||
|
if time.monotonic() > expiry:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
# -- Main dispatch ------------------------------------------------------
|
||||||
|
|
||||||
|
async def dispatch(self, connection: Any, request: WsRequest) -> Any | None:
|
||||||
|
"""Route an HTTP request. Returns Response or None."""
|
||||||
|
got, query = _parse_request_path(request.path)
|
||||||
|
|
||||||
|
# Token issue endpoint
|
||||||
|
if self.config.token_issue_path:
|
||||||
|
issue_expected = _normalize_config_path(self.config.token_issue_path)
|
||||||
|
if got == issue_expected:
|
||||||
|
return self._handle_token_issue(connection, request)
|
||||||
|
|
||||||
|
# Bootstrap
|
||||||
|
if got == "/webui/bootstrap":
|
||||||
|
return self._handle_bootstrap(connection, request)
|
||||||
|
|
||||||
|
# Settings routes (delegated)
|
||||||
|
response = await self.settings_routes.dispatch(request, got)
|
||||||
|
if response is not None:
|
||||||
|
return response
|
||||||
|
|
||||||
|
# Session routes
|
||||||
|
response = self._dispatch_session_routes(request, got)
|
||||||
|
if response is not None:
|
||||||
|
return response
|
||||||
|
|
||||||
|
# Media routes
|
||||||
|
response = self._dispatch_media_routes(request, got)
|
||||||
|
if response is not None:
|
||||||
|
return response
|
||||||
|
|
||||||
|
# Misc routes
|
||||||
|
response = self._dispatch_misc_routes(connection, request, got)
|
||||||
|
if response is not None:
|
||||||
|
return response
|
||||||
|
|
||||||
|
# API 404 (never serve SPA for /api/ routes)
|
||||||
|
if got.startswith("/api/"):
|
||||||
|
return _http_error(404, "API route not found")
|
||||||
|
|
||||||
|
# Static SPA serving
|
||||||
|
if self.static_dist_path is not None:
|
||||||
|
response = self._serve_static(got)
|
||||||
|
if response is not None:
|
||||||
|
return response
|
||||||
|
|
||||||
|
return connection.respond(404, "Not Found")
|
||||||
|
|
||||||
|
# -- Token issue --------------------------------------------------------
|
||||||
|
|
||||||
|
def _handle_token_issue(self, connection: Any, request: Any) -> Any:
|
||||||
|
secret = self.config.token_issue_secret.strip()
|
||||||
|
if secret:
|
||||||
|
if not _issue_route_secret_matches(request.headers, secret):
|
||||||
|
return connection.respond(401, "Unauthorized")
|
||||||
|
else:
|
||||||
|
self._log.warning(
|
||||||
|
"token_issue_path is set but token_issue_secret is empty; "
|
||||||
|
"any client can obtain connection tokens — set token_issue_secret for production."
|
||||||
|
)
|
||||||
|
self._purge_expired_issued_tokens()
|
||||||
|
if len(self.issued_tokens) >= self._MAX_ISSUED_TOKENS:
|
||||||
|
self._log.error(
|
||||||
|
"too many outstanding issued tokens ({}), rejecting issuance",
|
||||||
|
len(self.issued_tokens),
|
||||||
|
)
|
||||||
|
return _http_json_response({"error": "too many outstanding tokens"}, status=429)
|
||||||
|
token_value = f"nbwt_{secrets.token_urlsafe(32)}"
|
||||||
|
self.issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s)
|
||||||
|
return _http_json_response(
|
||||||
|
{"token": token_value, "expires_in": self.config.token_ttl_s}
|
||||||
|
)
|
||||||
|
|
||||||
|
# -- Bootstrap ----------------------------------------------------------
|
||||||
|
|
||||||
|
def _handle_bootstrap(self, connection: Any, request: Any) -> Response:
|
||||||
|
secret = self.config.token_issue_secret.strip() or self.config.token.strip()
|
||||||
|
if secret:
|
||||||
|
if not _issue_route_secret_matches(request.headers, secret):
|
||||||
|
return _http_error(401, "Unauthorized")
|
||||||
|
elif not _is_localhost(connection):
|
||||||
|
return _http_error(403, "bootstrap is localhost-only")
|
||||||
|
|
||||||
|
self._purge_expired_issued_tokens()
|
||||||
|
self._purge_expired_api_tokens()
|
||||||
|
if (
|
||||||
|
len(self.issued_tokens) >= self._MAX_ISSUED_TOKENS
|
||||||
|
or len(self.api_tokens) >= self._MAX_ISSUED_TOKENS
|
||||||
|
):
|
||||||
|
return _http_response(
|
||||||
|
json.dumps({"error": "too many outstanding tokens"}).encode("utf-8"),
|
||||||
|
status=429,
|
||||||
|
content_type="application/json; charset=utf-8",
|
||||||
|
)
|
||||||
|
token = f"nbwt_{secrets.token_urlsafe(32)}"
|
||||||
|
expiry = time.monotonic() + float(self.config.token_ttl_s)
|
||||||
|
self.issued_tokens[token] = expiry
|
||||||
|
self.api_tokens[token] = expiry
|
||||||
|
|
||||||
|
ws_url = self._bootstrap_ws_url(request)
|
||||||
|
expected_path = _normalize_config_path(self.config.path)
|
||||||
|
return _http_json_response(
|
||||||
|
{
|
||||||
|
"token": token,
|
||||||
|
"ws_path": expected_path,
|
||||||
|
"ws_url": ws_url,
|
||||||
|
"expires_in": self.config.token_ttl_s,
|
||||||
|
"model_name": _resolve_bootstrap_model_name(self.runtime_model_name),
|
||||||
|
"runtime_surface": self._runtime_surface,
|
||||||
|
"runtime_capabilities": self._capabilities,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _bootstrap_ws_url(self, request: Any) -> str:
|
||||||
|
headers = getattr(request, "headers", {}) or {}
|
||||||
|
host = _safe_host_header(_case_insensitive_header(headers, "Host"))
|
||||||
|
if not host:
|
||||||
|
host = _host_for_url(self.config.host, self.config.port)
|
||||||
|
proto = _case_insensitive_header(headers, "X-Forwarded-Proto")
|
||||||
|
proto = proto.split(",", 1)[0].strip().lower()
|
||||||
|
secure = proto in {"https", "wss"} or bool(self.config.ssl_certfile.strip())
|
||||||
|
scheme = "wss" if secure else "ws"
|
||||||
|
expected_path = _normalize_config_path(self.config.path)
|
||||||
|
return f"{scheme}://{host}{expected_path}"
|
||||||
|
|
||||||
|
# -- Session routes -----------------------------------------------------
|
||||||
|
|
||||||
|
def _dispatch_session_routes(self, request: WsRequest, got: str) -> Response | None:
|
||||||
|
m = re.match(r"^/api/sessions/([^/]+)/messages$", got)
|
||||||
|
if m:
|
||||||
|
return self._handle_session_messages(request, m.group(1))
|
||||||
|
|
||||||
|
m = re.match(r"^/api/sessions/([^/]+)/webui-thread$", got)
|
||||||
|
if m:
|
||||||
|
return self._handle_webui_thread_get(request, m.group(1))
|
||||||
|
|
||||||
|
m = re.match(r"^/api/sessions/([^/]+)/delete$", got)
|
||||||
|
if m:
|
||||||
|
return self._handle_session_delete(request, m.group(1))
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _handle_sessions_list(self, request: WsRequest) -> Response:
|
||||||
|
if not self.check_api_token(request):
|
||||||
|
return _http_error(401, "Unauthorized")
|
||||||
|
if self.session_manager is None:
|
||||||
|
return _http_error(503, "session manager unavailable")
|
||||||
|
sessions = self.session_manager.list_sessions()
|
||||||
|
from nanobot.session.webui_turns import websocket_turn_wall_started_at
|
||||||
|
|
||||||
|
cleaned = []
|
||||||
|
for s in sessions:
|
||||||
|
key = s.get("key")
|
||||||
|
if not (isinstance(key, str) and key.startswith("websocket:")):
|
||||||
|
continue
|
||||||
|
row = {k: v for k, v in s.items() if k != "path"}
|
||||||
|
chat_id = key.split(":", 1)[1]
|
||||||
|
started_at = websocket_turn_wall_started_at(chat_id)
|
||||||
|
if started_at is not None:
|
||||||
|
row["run_started_at"] = started_at
|
||||||
|
scope = self.workspaces.scope_for_session_key(key)
|
||||||
|
row["workspace_scope"] = scope.payload()
|
||||||
|
cleaned.append(row)
|
||||||
|
return _http_json_response({"sessions": cleaned})
|
||||||
|
|
||||||
|
def _handle_session_messages(self, request: WsRequest, key: str) -> Response:
|
||||||
|
if not self.check_api_token(request):
|
||||||
|
return _http_error(401, "Unauthorized")
|
||||||
|
if self.session_manager is None:
|
||||||
|
return _http_error(503, "session manager unavailable")
|
||||||
|
decoded_key = _decode_api_key(key)
|
||||||
|
if decoded_key is None:
|
||||||
|
return _http_error(400, "invalid session key")
|
||||||
|
if not _is_websocket_channel_session_key(decoded_key):
|
||||||
|
return _http_error(404, "session not found")
|
||||||
|
data = self.session_manager.read_session_file(decoded_key)
|
||||||
|
if data is None:
|
||||||
|
return _http_error(404, "session not found")
|
||||||
|
messages = data.get("messages")
|
||||||
|
if isinstance(messages, list):
|
||||||
|
scrub_subagent_messages_for_channel(messages)
|
||||||
|
self._augment_media_urls(data)
|
||||||
|
return _http_json_response(data)
|
||||||
|
|
||||||
|
def _handle_webui_thread_get(self, request: WsRequest, key: str) -> Response:
|
||||||
|
if not self.check_api_token(request):
|
||||||
|
return _http_error(401, "Unauthorized")
|
||||||
|
decoded_key = _decode_api_key(key)
|
||||||
|
if decoded_key is None:
|
||||||
|
return _http_error(400, "invalid session key")
|
||||||
|
if not _is_websocket_channel_session_key(decoded_key):
|
||||||
|
return _http_error(404, "session not found")
|
||||||
|
scope = self.workspaces.scope_for_session_key(decoded_key)
|
||||||
|
data = build_webui_thread_response(
|
||||||
|
decoded_key,
|
||||||
|
augment_user_media=self._augment_transcript_user_media,
|
||||||
|
augment_assistant_text=lambda text: rewrite_local_markdown_images(
|
||||||
|
text,
|
||||||
|
workspace_path=scope.project_path,
|
||||||
|
sign_path=self.sign_or_stage_media_path,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if data is None:
|
||||||
|
return _http_error(404, "webui thread not found")
|
||||||
|
data["workspace_scope"] = scope.payload()
|
||||||
|
return _http_json_response(data)
|
||||||
|
|
||||||
|
def _handle_session_delete(self, request: WsRequest, key: str) -> Response:
|
||||||
|
if not self.check_api_token(request):
|
||||||
|
return _http_error(401, "Unauthorized")
|
||||||
|
if self.session_manager is None:
|
||||||
|
return _http_error(503, "session manager unavailable")
|
||||||
|
decoded_key = _decode_api_key(key)
|
||||||
|
if decoded_key is None:
|
||||||
|
return _http_error(400, "invalid session key")
|
||||||
|
if not _is_websocket_channel_session_key(decoded_key):
|
||||||
|
return _http_error(404, "session not found")
|
||||||
|
deleted = self.session_manager.delete_session(decoded_key)
|
||||||
|
delete_webui_thread(decoded_key)
|
||||||
|
return _http_json_response({"deleted": bool(deleted)})
|
||||||
|
|
||||||
|
# -- Media routes -------------------------------------------------------
|
||||||
|
|
||||||
|
def _dispatch_media_routes(self, request: WsRequest, got: str) -> Response | None:
|
||||||
|
m = re.match(r"^/api/media/([A-Za-z0-9_-]+)/([A-Za-z0-9_-]+)$", got)
|
||||||
|
if m:
|
||||||
|
return self._handle_media_fetch(m.group(1), m.group(2), request)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _handle_media_fetch(
|
||||||
|
self, sig: str, payload: str, request: WsRequest | None = None
|
||||||
|
) -> Response:
|
||||||
|
return serve_signed_media(
|
||||||
|
sig,
|
||||||
|
payload,
|
||||||
|
secret=self.media_secret,
|
||||||
|
request=request,
|
||||||
|
media_dir=lambda channel=None: get_media_dir(channel),
|
||||||
|
)
|
||||||
|
|
||||||
|
def sign_media_path(self, abs_path: Path) -> str | None:
|
||||||
|
return sign_media_path(
|
||||||
|
abs_path,
|
||||||
|
secret=self.media_secret,
|
||||||
|
media_dir=lambda channel=None: get_media_dir(channel),
|
||||||
|
)
|
||||||
|
|
||||||
|
def sign_or_stage_media_path(self, path: Path) -> dict[str, str] | None:
|
||||||
|
return sign_or_stage_media_path(
|
||||||
|
path,
|
||||||
|
secret=self.media_secret,
|
||||||
|
media_dir=lambda channel=None: get_media_dir(channel),
|
||||||
|
logger=self._log,
|
||||||
|
)
|
||||||
|
|
||||||
|
def rewrite_local_markdown_images(self, text: str) -> str:
|
||||||
|
return rewrite_local_markdown_images(
|
||||||
|
text,
|
||||||
|
workspace_path=self.workspace_path,
|
||||||
|
sign_path=self.sign_or_stage_media_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
# -- Misc routes --------------------------------------------------------
|
||||||
|
|
||||||
|
def _dispatch_misc_routes(
|
||||||
|
self, connection: Any, request: WsRequest, got: str
|
||||||
|
) -> Response | None:
|
||||||
|
if got == "/api/sessions":
|
||||||
|
return self._handle_sessions_list(request)
|
||||||
|
if got == "/api/commands":
|
||||||
|
return self._handle_commands(request)
|
||||||
|
if got == "/api/workspaces":
|
||||||
|
return self._handle_workspaces(connection, request)
|
||||||
|
if got == "/api/webui/sidebar-state":
|
||||||
|
return self._handle_webui_sidebar_state(request)
|
||||||
|
if got == "/api/webui/sidebar-state/update":
|
||||||
|
return self._handle_webui_sidebar_state_update(request)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _handle_commands(self, request: WsRequest) -> Response:
|
||||||
|
if not self.check_api_token(request):
|
||||||
|
return _http_error(401, "Unauthorized")
|
||||||
|
return _http_json_response({"commands": builtin_command_palette()})
|
||||||
|
|
||||||
|
def _handle_workspaces(self, connection: Any, request: WsRequest) -> Response:
|
||||||
|
if not self.check_api_token(request):
|
||||||
|
return _http_error(401, "Unauthorized")
|
||||||
|
return _http_json_response(
|
||||||
|
self.workspaces.payload(controls_available=_is_localhost(connection))
|
||||||
|
)
|
||||||
|
|
||||||
|
def _handle_webui_sidebar_state(self, request: WsRequest) -> Response:
|
||||||
|
if not self.check_api_token(request):
|
||||||
|
return _http_error(401, "Unauthorized")
|
||||||
|
return _http_json_response(read_webui_sidebar_state())
|
||||||
|
|
||||||
|
def _handle_webui_sidebar_state_update(self, request: WsRequest) -> Response:
|
||||||
|
if not self.check_api_token(request):
|
||||||
|
return _http_error(401, "Unauthorized")
|
||||||
|
query = _parse_query(request.path)
|
||||||
|
raw_state = _query_first(query, "state")
|
||||||
|
if raw_state is None:
|
||||||
|
return _http_error(400, "missing state")
|
||||||
|
try:
|
||||||
|
decoded = json.loads(raw_state)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return _http_error(400, "state must be JSON")
|
||||||
|
if not isinstance(decoded, dict):
|
||||||
|
return _http_error(400, "state must be an object")
|
||||||
|
try:
|
||||||
|
state = write_webui_sidebar_state(decoded)
|
||||||
|
except ValueError as e:
|
||||||
|
return _http_error(400, str(e))
|
||||||
|
except OSError:
|
||||||
|
self._log.exception("failed to write webui sidebar state")
|
||||||
|
return _http_error(500, "failed to write sidebar state")
|
||||||
|
return _http_json_response(state)
|
||||||
|
|
||||||
|
# -- Static file serving ------------------------------------------------
|
||||||
|
|
||||||
|
def _serve_static(self, request_path: str) -> Response | None:
|
||||||
|
assert self.static_dist_path is not None
|
||||||
|
rel = request_path.lstrip("/")
|
||||||
|
if not rel:
|
||||||
|
rel = "index.html"
|
||||||
|
if ".." in rel.split("/") or rel.startswith("/"):
|
||||||
|
return _http_error(403, "Forbidden")
|
||||||
|
candidate = (self.static_dist_path / rel).resolve()
|
||||||
|
try:
|
||||||
|
candidate.relative_to(self.static_dist_path)
|
||||||
|
except ValueError:
|
||||||
|
return _http_error(403, "Forbidden")
|
||||||
|
if not candidate.is_file():
|
||||||
|
index = self.static_dist_path / "index.html"
|
||||||
|
if index.is_file():
|
||||||
|
candidate = index
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
body = candidate.read_bytes()
|
||||||
|
except OSError as e:
|
||||||
|
self._log.warning("static: failed to read {}: {}", candidate, e)
|
||||||
|
return _http_error(500, "Internal Server Error")
|
||||||
|
ctype, _ = mimetypes.guess_type(candidate.name)
|
||||||
|
if ctype is None:
|
||||||
|
ctype = "application/octet-stream"
|
||||||
|
if ctype.startswith("text/") or ctype in {"application/javascript", "application/json"}:
|
||||||
|
ctype = f"{ctype}; charset=utf-8"
|
||||||
|
if candidate.name == "index.html":
|
||||||
|
cache = "no-cache"
|
||||||
|
else:
|
||||||
|
cache = "public, max-age=31536000, immutable"
|
||||||
|
return _http_response(
|
||||||
|
body,
|
||||||
|
status=200,
|
||||||
|
content_type=ctype,
|
||||||
|
extra_headers=[("Cache-Control", cache)],
|
||||||
|
)
|
||||||
|
|
||||||
|
# -- Media helpers (called by WebSocketChannel.send) --------------------
|
||||||
|
|
||||||
|
def _augment_media_urls(self, payload: dict[str, Any]) -> None:
|
||||||
|
messages = payload.get("messages")
|
||||||
|
if not isinstance(messages, list):
|
||||||
|
return
|
||||||
|
for msg in messages:
|
||||||
|
if not isinstance(msg, dict):
|
||||||
|
continue
|
||||||
|
media = msg.get("media")
|
||||||
|
if not isinstance(media, list) or not media:
|
||||||
|
continue
|
||||||
|
urls: list[dict[str, str]] = []
|
||||||
|
for entry in media:
|
||||||
|
if not isinstance(entry, str) or not entry:
|
||||||
|
continue
|
||||||
|
signed = self.sign_media_path(Path(entry))
|
||||||
|
if signed is None:
|
||||||
|
continue
|
||||||
|
urls.append({"url": signed, "name": Path(entry).name})
|
||||||
|
if urls:
|
||||||
|
msg["media_urls"] = urls
|
||||||
|
msg.pop("media", None)
|
||||||
|
|
||||||
|
def _augment_transcript_user_media(self, paths: list[str]) -> list[dict[str, Any]]:
|
||||||
|
out: list[dict[str, Any]] = []
|
||||||
|
for pstr in paths:
|
||||||
|
path = Path(pstr)
|
||||||
|
att = self.sign_or_stage_media_path(path)
|
||||||
|
if att is None:
|
||||||
|
continue
|
||||||
|
mime, _ = mimetypes.guess_type(path.name)
|
||||||
|
kind = "video" if mime and mime.startswith("video/") else "image"
|
||||||
|
out.append(
|
||||||
|
{"kind": kind, "url": att["url"], "name": att.get("name", path.name)},
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _is_websocket_channel_session_key(key: str) -> bool:
|
||||||
|
return key.startswith("websocket:")
|
||||||
@ -626,6 +626,7 @@ async def test_send_stages_external_media_as_signed_url(monkeypatch, tmp_path) -
|
|||||||
return ws_media if channel == "websocket" else media_root
|
return ws_media if channel == "websocket" else media_root
|
||||||
|
|
||||||
monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir)
|
monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir)
|
||||||
|
monkeypatch.setattr("nanobot.channels.ws_http.get_media_dir", fake_media_dir)
|
||||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
|
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
|
||||||
mock_ws = AsyncMock()
|
mock_ws = AsyncMock()
|
||||||
channel._attach(mock_ws, "chat-1")
|
channel._attach(mock_ws, "chat-1")
|
||||||
@ -840,6 +841,7 @@ async def test_send_delta_stream_end_rewrites_local_markdown_image(monkeypatch,
|
|||||||
return path
|
return path
|
||||||
|
|
||||||
monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir)
|
monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir)
|
||||||
|
monkeypatch.setattr("nanobot.channels.ws_http.get_media_dir", fake_media_dir)
|
||||||
channel = WebSocketChannel(
|
channel = WebSocketChannel(
|
||||||
{"enabled": True, "allowFrom": ["*"], "streaming": True},
|
{"enabled": True, "allowFrom": ["*"], "streaming": True},
|
||||||
bus,
|
bus,
|
||||||
@ -872,6 +874,7 @@ async def test_send_delta_stream_end_rewrites_inline_final_text(monkeypatch, tmp
|
|||||||
return path
|
return path
|
||||||
|
|
||||||
monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir)
|
monkeypatch.setattr("nanobot.channels.websocket.get_media_dir", fake_media_dir)
|
||||||
|
monkeypatch.setattr("nanobot.channels.ws_http.get_media_dir", fake_media_dir)
|
||||||
channel = WebSocketChannel(
|
channel = WebSocketChannel(
|
||||||
{"enabled": True, "allowFrom": ["*"], "streaming": True},
|
{"enabled": True, "allowFrom": ["*"], "streaming": True},
|
||||||
bus,
|
bus,
|
||||||
|
|||||||
@ -710,20 +710,20 @@ async def test_api_token_pool_purges_expired(bus: MagicMock, tmp_path: Path) ->
|
|||||||
channel = _ch(bus, session_manager=sm, port=29908)
|
channel = _ch(bus, session_manager=sm, port=29908)
|
||||||
# Don't start a server — directly inject and validate.
|
# Don't start a server — directly inject and validate.
|
||||||
import time as _time
|
import time as _time
|
||||||
channel._api_tokens["expired"] = _time.monotonic() - 1
|
channel._http.api_tokens["expired"] = _time.monotonic() - 1
|
||||||
channel._api_tokens["live"] = _time.monotonic() + 60
|
channel._http.api_tokens["live"] = _time.monotonic() + 60
|
||||||
|
|
||||||
class _FakeReq:
|
class _FakeReq:
|
||||||
path = "/api/sessions"
|
path = "/api/sessions"
|
||||||
headers = {"Authorization": "Bearer expired"}
|
headers = {"Authorization": "Bearer expired"}
|
||||||
|
|
||||||
assert channel._check_api_token(_FakeReq()) is False
|
assert channel._http.check_api_token(_FakeReq()) is False
|
||||||
|
|
||||||
class _LiveReq:
|
class _LiveReq:
|
||||||
path = "/api/sessions"
|
path = "/api/sessions"
|
||||||
headers = {"Authorization": "Bearer live"}
|
headers = {"Authorization": "Bearer live"}
|
||||||
|
|
||||||
assert channel._check_api_token(_LiveReq()) is True
|
assert channel._http.check_api_token(_LiveReq()) is True
|
||||||
|
|
||||||
|
|
||||||
class _FakeConn:
|
class _FakeConn:
|
||||||
@ -814,7 +814,7 @@ def test_localhost_without_auth_is_valid(bus: MagicMock) -> None:
|
|||||||
|
|
||||||
def test_bootstrap_prefers_runtime_model_name(bus: MagicMock, monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_bootstrap_prefers_runtime_model_name(bus: MagicMock, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.channels.websocket._default_model_name_from_config",
|
"nanobot.channels.ws_http._default_model_name_from_config",
|
||||||
lambda: "from-disk",
|
lambda: "from-disk",
|
||||||
)
|
)
|
||||||
channel = _ch(bus, host="127.0.0.1", runtime_model_name=lambda: " live/model ")
|
channel = _ch(bus, host="127.0.0.1", runtime_model_name=lambda: " live/model ")
|
||||||
@ -826,7 +826,7 @@ def test_bootstrap_prefers_runtime_model_name(bus: MagicMock, monkeypatch: pytes
|
|||||||
|
|
||||||
def test_bootstrap_falls_back_when_runtime_returns_empty(bus: MagicMock, monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_bootstrap_falls_back_when_runtime_returns_empty(bus: MagicMock, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.channels.websocket._default_model_name_from_config",
|
"nanobot.channels.ws_http._default_model_name_from_config",
|
||||||
lambda: "from-disk",
|
lambda: "from-disk",
|
||||||
)
|
)
|
||||||
channel = _ch(bus, host="127.0.0.1", runtime_model_name=lambda: " ")
|
channel = _ch(bus, host="127.0.0.1", runtime_model_name=lambda: " ")
|
||||||
@ -838,7 +838,7 @@ def test_bootstrap_falls_back_when_runtime_returns_empty(bus: MagicMock, monkeyp
|
|||||||
|
|
||||||
def test_bootstrap_falls_back_when_runtime_raises(bus: MagicMock, monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_bootstrap_falls_back_when_runtime_raises(bus: MagicMock, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.channels.websocket._default_model_name_from_config",
|
"nanobot.channels.ws_http._default_model_name_from_config",
|
||||||
lambda: "from-disk",
|
lambda: "from-disk",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -106,7 +106,7 @@ def test_sign_media_path_rejects_paths_outside_media_root(
|
|||||||
media = tmp_path / "media"
|
media = tmp_path / "media"
|
||||||
media.mkdir()
|
media.mkdir()
|
||||||
channel = _ch(bus, port=0)
|
channel = _ch(bus, port=0)
|
||||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||||
assert channel._sign_media_path(outside) is None
|
assert channel._sign_media_path(outside) is None
|
||||||
# Traversal via the media root is also rejected — the resolve() step
|
# Traversal via the media root is also rejected — the resolve() step
|
||||||
# normalises ``..`` out before the relative_to check.
|
# normalises ``..`` out before the relative_to check.
|
||||||
@ -121,7 +121,7 @@ def test_sign_media_path_round_trips_via_hmac(
|
|||||||
media.mkdir()
|
media.mkdir()
|
||||||
(media / "a.png").write_bytes(_PNG_BYTES)
|
(media / "a.png").write_bytes(_PNG_BYTES)
|
||||||
channel = _ch(bus, port=0)
|
channel = _ch(bus, port=0)
|
||||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||||
url = channel._sign_media_path(media / "a.png")
|
url = channel._sign_media_path(media / "a.png")
|
||||||
assert url is not None
|
assert url is not None
|
||||||
assert url.startswith("/api/media/")
|
assert url.startswith("/api/media/")
|
||||||
@ -144,7 +144,7 @@ def test_local_markdown_image_is_staged_and_rewritten(
|
|||||||
media = tmp_path / "media"
|
media = tmp_path / "media"
|
||||||
channel = _ch(bus, workspace_path=workspace, port=0)
|
channel = _ch(bus, workspace_path=workspace, port=0)
|
||||||
|
|
||||||
with patch("nanobot.channels.websocket.get_media_dir", side_effect=_fake_media_dir(media)):
|
with patch("nanobot.channels.ws_http.get_media_dir", side_effect=_fake_media_dir(media)):
|
||||||
rewritten = channel._rewrite_local_markdown_images(
|
rewritten = channel._rewrite_local_markdown_images(
|
||||||
"The result:\n"
|
"The result:\n"
|
||||||
)
|
)
|
||||||
@ -166,7 +166,7 @@ def test_local_markdown_video_is_staged_and_rewritten(
|
|||||||
media = tmp_path / "media"
|
media = tmp_path / "media"
|
||||||
channel = _ch(bus, workspace_path=workspace, port=0)
|
channel = _ch(bus, workspace_path=workspace, port=0)
|
||||||
|
|
||||||
with patch("nanobot.channels.websocket.get_media_dir", side_effect=_fake_media_dir(media)):
|
with patch("nanobot.channels.ws_http.get_media_dir", side_effect=_fake_media_dir(media)):
|
||||||
rewritten = channel._rewrite_local_markdown_images(
|
rewritten = channel._rewrite_local_markdown_images(
|
||||||
"The result:\n"
|
"The result:\n"
|
||||||
)
|
)
|
||||||
@ -189,7 +189,7 @@ def test_local_markdown_image_rejects_workspace_escape(
|
|||||||
channel = _ch(bus, workspace_path=workspace, port=0)
|
channel = _ch(bus, workspace_path=workspace, port=0)
|
||||||
text = ""
|
text = ""
|
||||||
|
|
||||||
with patch("nanobot.channels.websocket.get_media_dir", side_effect=_fake_media_dir(media)):
|
with patch("nanobot.channels.ws_http.get_media_dir", side_effect=_fake_media_dir(media)):
|
||||||
assert channel._rewrite_local_markdown_images(text) == text
|
assert channel._rewrite_local_markdown_images(text) == text
|
||||||
|
|
||||||
assert not (media / "websocket").exists()
|
assert not (media / "websocket").exists()
|
||||||
@ -211,7 +211,7 @@ async def test_media_route_serves_signed_file(
|
|||||||
target.write_bytes(_PNG_BYTES)
|
target.write_bytes(_PNG_BYTES)
|
||||||
|
|
||||||
channel = _ch(bus, port=29920)
|
channel = _ch(bus, port=29920)
|
||||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||||
url_path = channel._sign_media_path(target)
|
url_path = channel._sign_media_path(target)
|
||||||
assert url_path is not None
|
assert url_path is not None
|
||||||
server_task = asyncio.create_task(channel.start())
|
server_task = asyncio.create_task(channel.start())
|
||||||
@ -244,7 +244,7 @@ async def test_media_route_serves_video_byte_ranges(
|
|||||||
target.write_bytes(b"0123456789")
|
target.write_bytes(b"0123456789")
|
||||||
|
|
||||||
channel = _ch(bus, port=29927)
|
channel = _ch(bus, port=29927)
|
||||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||||
url_path = channel._sign_media_path(target)
|
url_path = channel._sign_media_path(target)
|
||||||
assert url_path is not None
|
assert url_path is not None
|
||||||
server_task = asyncio.create_task(channel.start())
|
server_task = asyncio.create_task(channel.start())
|
||||||
@ -276,7 +276,7 @@ async def test_media_route_serves_suffix_video_byte_ranges(
|
|||||||
target.write_bytes(b"0123456789")
|
target.write_bytes(b"0123456789")
|
||||||
|
|
||||||
channel = _ch(bus, port=29928)
|
channel = _ch(bus, port=29928)
|
||||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||||
url_path = channel._sign_media_path(target)
|
url_path = channel._sign_media_path(target)
|
||||||
assert url_path is not None
|
assert url_path is not None
|
||||||
server_task = asyncio.create_task(channel.start())
|
server_task = asyncio.create_task(channel.start())
|
||||||
@ -305,7 +305,7 @@ async def test_media_route_rejects_unsatisfiable_byte_range(
|
|||||||
target.write_bytes(b"0123456789")
|
target.write_bytes(b"0123456789")
|
||||||
|
|
||||||
channel = _ch(bus, port=29929)
|
channel = _ch(bus, port=29929)
|
||||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||||
url_path = channel._sign_media_path(target)
|
url_path = channel._sign_media_path(target)
|
||||||
assert url_path is not None
|
assert url_path is not None
|
||||||
server_task = asyncio.create_task(channel.start())
|
server_task = asyncio.create_task(channel.start())
|
||||||
@ -338,7 +338,7 @@ async def test_media_route_rejects_bad_signature(
|
|||||||
(media / "f.png").write_bytes(_PNG_BYTES)
|
(media / "f.png").write_bytes(_PNG_BYTES)
|
||||||
|
|
||||||
channel = _ch(bus, port=29921)
|
channel = _ch(bus, port=29921)
|
||||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||||
good = channel._sign_media_path(media / "f.png")
|
good = channel._sign_media_path(media / "f.png")
|
||||||
assert good is not None
|
assert good is not None
|
||||||
_, payload = good[len("/api/media/"):].split("/", 1)
|
_, payload = good[len("/api/media/"):].split("/", 1)
|
||||||
@ -381,7 +381,7 @@ async def test_media_route_rejects_path_traversal_payload(
|
|||||||
).digest()[:16]
|
).digest()[:16]
|
||||||
url = f"/api/media/{b64url_encode(mac)}/{payload}"
|
url = f"/api/media/{b64url_encode(mac)}/{payload}"
|
||||||
|
|
||||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||||
server_task = asyncio.create_task(channel.start())
|
server_task = asyncio.create_task(channel.start())
|
||||||
await asyncio.sleep(0.3)
|
await asyncio.sleep(0.3)
|
||||||
try:
|
try:
|
||||||
@ -405,7 +405,7 @@ async def test_media_route_404s_missing_file(
|
|||||||
target.write_bytes(_PNG_BYTES)
|
target.write_bytes(_PNG_BYTES)
|
||||||
|
|
||||||
channel = _ch(bus, port=29923)
|
channel = _ch(bus, port=29923)
|
||||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||||
url_path = channel._sign_media_path(target)
|
url_path = channel._sign_media_path(target)
|
||||||
assert url_path is not None
|
assert url_path is not None
|
||||||
target.unlink() # the file vanishes between signing and fetching
|
target.unlink() # the file vanishes between signing and fetching
|
||||||
@ -433,7 +433,7 @@ async def test_media_route_degrades_non_image_to_octet_stream(
|
|||||||
(media / "scary.html").write_bytes(b"<script>alert(1)</script>")
|
(media / "scary.html").write_bytes(b"<script>alert(1)</script>")
|
||||||
|
|
||||||
channel = _ch(bus, port=29924)
|
channel = _ch(bus, port=29924)
|
||||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||||
payload = b64url_encode(b"scary.html")
|
payload = b64url_encode(b"scary.html")
|
||||||
mac = hmac.new(
|
mac = hmac.new(
|
||||||
channel._media_secret, payload.encode("ascii"), hashlib.sha256
|
channel._media_secret, payload.encode("ascii"), hashlib.sha256
|
||||||
@ -464,7 +464,7 @@ async def test_media_route_serves_svg_with_strict_csp(
|
|||||||
target.write_text("<svg xmlns='http://www.w3.org/2000/svg'><script>alert(1)</script></svg>")
|
target.write_text("<svg xmlns='http://www.w3.org/2000/svg'><script>alert(1)</script></svg>")
|
||||||
|
|
||||||
channel = _ch(bus, port=29928)
|
channel = _ch(bus, port=29928)
|
||||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||||
url_path = channel._sign_media_path(target)
|
url_path = channel._sign_media_path(target)
|
||||||
assert url_path is not None
|
assert url_path is not None
|
||||||
server_task = asyncio.create_task(channel.start())
|
server_task = asyncio.create_task(channel.start())
|
||||||
@ -505,7 +505,7 @@ async def test_session_messages_exposes_signed_media_urls(
|
|||||||
sm.save(sess)
|
sm.save(sess)
|
||||||
|
|
||||||
channel = _ch(bus, session_manager=sm, port=29925)
|
channel = _ch(bus, session_manager=sm, port=29925)
|
||||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||||
server_task = asyncio.create_task(channel.start())
|
server_task = asyncio.create_task(channel.start())
|
||||||
await asyncio.sleep(0.3)
|
await asyncio.sleep(0.3)
|
||||||
try:
|
try:
|
||||||
@ -550,7 +550,7 @@ async def test_session_messages_skips_vanished_media(
|
|||||||
sm.save(sess)
|
sm.save(sess)
|
||||||
|
|
||||||
channel = _ch(bus, session_manager=sm, port=29926)
|
channel = _ch(bus, session_manager=sm, port=29926)
|
||||||
with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
|
with patch("nanobot.channels.ws_http.get_media_dir", return_value=media):
|
||||||
server_task = asyncio.create_task(channel.start())
|
server_task = asyncio.create_task(channel.start())
|
||||||
await asyncio.sleep(0.3)
|
await asyncio.sleep(0.3)
|
||||||
try:
|
try:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user