diff --git a/nanobot/channels/websocket.py b/nanobot/channels/websocket.py index 65adc123f..48d73535b 100644 --- a/nanobot/channels/websocket.py +++ b/nanobot/channels/websocket.py @@ -3,17 +3,13 @@ from __future__ import annotations import asyncio -import base64 -import binascii import email.utils -import hashlib import hmac import http import json import mimetypes import re import secrets -import shutil import ssl import time import uuid @@ -44,7 +40,6 @@ from nanobot.config.paths import get_media_dir, get_workspace_path from nanobot.config.schema import Base from nanobot.session.goal_state import goal_state_ws_blob from nanobot.session.webui_turns import websocket_turn_wall_started_at -from nanobot.utils.helpers import safe_filename from nanobot.utils.media_decode import ( FileSizeExceeded, save_base64_data_url, @@ -70,6 +65,11 @@ from nanobot.webui.cli_apps_api import ( cli_apps_payload, normalize_cli_app_mentions, ) +from nanobot.webui.media_api import ( + serve_signed_media, + sign_media_path, + sign_or_stage_media_path, +) from nanobot.webui.mcp_presets_api import ( mcp_presets_settings_action, normalize_mcp_preset_mentions, @@ -514,59 +514,6 @@ def _is_websocket_upgrade(request: WsRequest) -> bool: return True -def _b64url_encode(data: bytes) -> str: - """URL-safe base64 without padding — compact + friendly in URL paths.""" - return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") - - -def _b64url_decode(s: str) -> bytes: - """Reverse of :func:`_b64url_encode`; caller handles ``ValueError``.""" - pad = "=" * (-len(s) % 4) - return base64.urlsafe_b64decode(s + pad) - - -# Allowed MIME types we actually serve from the media endpoint. Anything -# outside this set is degraded to ``application/octet-stream`` so an -# attacker who somehow gets a signed URL for an unexpected file type can't -# trick the browser into sniffing executable content. -_MEDIA_ALLOWED_MIMES: frozenset[str] = frozenset({ - "image/png", - "image/jpeg", - "image/webp", - "image/gif", - "video/mp4", - "video/webm", - "video/quicktime", -}) - -_BYTE_RANGE_RE = re.compile(r"^bytes=(\d*)-(\d*)$") - - -def _parse_single_byte_range(range_header: str, size: int) -> tuple[int, int]: - """Parse a single HTTP byte range for signed media responses.""" - if size <= 0 or "," in range_header: - raise ValueError("invalid byte range") - m = _BYTE_RANGE_RE.fullmatch(range_header.strip()) - if m is None: - raise ValueError("invalid byte range") - start_text, end_text = m.groups() - if not start_text and not end_text: - raise ValueError("invalid byte range") - if not start_text: - suffix_length = int(end_text) - if suffix_length <= 0: - raise ValueError("invalid byte range") - start = max(size - suffix_length, 0) - end = size - 1 - else: - start = int(start_text) - end = int(end_text) if end_text else size - 1 - if start >= size or start > end: - raise ValueError("invalid byte range") - end = min(end, size - 1) - return start, end - - def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool: """Return True if the token-issue HTTP request carries credentials matching ``token_issue_secret``.""" if not configured_secret: @@ -1368,16 +1315,11 @@ class WebSocketChannel(BaseChannel): be fetched. The returned path is relative to the server origin; the client joins it against this server's HTTP origin (same host as WS). """ - try: - media_root = get_media_dir().resolve() - rel = abs_path.resolve().relative_to(media_root) - except (OSError, ValueError): - return None - payload = _b64url_encode(rel.as_posix().encode("utf-8")) - mac = hmac.new( - self._media_secret, payload.encode("ascii"), hashlib.sha256 - ).digest()[:16] - return f"/api/media/{_b64url_encode(mac)}/{payload}" + return sign_media_path( + abs_path, + secret=self._media_secret, + media_dir=lambda channel=None: get_media_dir(channel), + ) def _sign_or_stage_media_path(self, path: Path) -> dict[str, str] | None: """Return a signed media URL payload for *path*. @@ -1388,23 +1330,12 @@ class WebSocketChannel(BaseChannel): can fetch them through the existing signed media route without exposing arbitrary filesystem paths. """ - signed = self._sign_media_path(path) - if signed is not None: - return {"url": signed, "name": path.name} - try: - if not path.is_file(): - return None - media_dir = get_media_dir("websocket") - safe_name = safe_filename(path.name) or "attachment" - staged = media_dir / f"{uuid.uuid4().hex[:12]}-{safe_name}" - shutil.copyfile(path, staged) - except OSError as exc: - self.logger.warning("failed to stage outbound media {}: {}", path, exc) - return None - signed = self._sign_media_path(staged) - if signed is None: - return None - return {"url": signed, "name": path.name} + return sign_or_stage_media_path( + path, + secret=self._media_secret, + media_dir=lambda channel=None: get_media_dir(channel), + logger=self.logger, + ) def _rewrite_local_markdown_images(self, text: str) -> str: return rewrite_local_markdown_images( @@ -1421,86 +1352,12 @@ class WebSocketChannel(BaseChannel): payload to a relative path, and streams the file bytes with a long-lived immutable cache header (the URL already encodes the file identity, so caches can be aggressive).""" - try: - provided_mac = _b64url_decode(sig) - except (ValueError, binascii.Error): - return _http_error(401, "invalid signature") - expected_mac = hmac.new( - self._media_secret, payload.encode("ascii"), hashlib.sha256 - ).digest()[:16] - if not hmac.compare_digest(expected_mac, provided_mac): - return _http_error(401, "invalid signature") - try: - rel_bytes = _b64url_decode(payload) - rel_str = rel_bytes.decode("utf-8") - except (ValueError, binascii.Error, UnicodeDecodeError): - return _http_error(400, "invalid payload") - # An attacker who somehow bypassed the HMAC check would still need - # the resolved path to escape the media root; guard defensively. - try: - media_root = get_media_dir().resolve() - candidate = (media_root / rel_str).resolve() - candidate.relative_to(media_root) - except (OSError, ValueError): - return _http_error(404, "not found") - if not candidate.is_file(): - return _http_error(404, "not found") - mime, _ = mimetypes.guess_type(candidate.name) - if mime not in _MEDIA_ALLOWED_MIMES: - mime = "application/octet-stream" - common_headers = [ - ("Accept-Ranges", "bytes"), - ("Cache-Control", "private, max-age=31536000, immutable"), - # Paired with the MIME whitelist above: prevents browsers from - # MIME-sniffing an octet-stream fallback into executable HTML. - ("X-Content-Type-Options", "nosniff"), - ] - try: - size = candidate.stat().st_size - except OSError: - return _http_error(500, "read error") - - range_header = ( - _case_insensitive_header(request.headers, "Range") if request else "" - ) - if range_header: - try: - start, end = _parse_single_byte_range(range_header, size) - except ValueError: - return _http_response( - b"range not satisfiable", - status=416, - extra_headers=[ - ("Accept-Ranges", "bytes"), - ("Content-Range", f"bytes */{size}"), - ("X-Content-Type-Options", "nosniff"), - ], - ) - try: - length = end - start + 1 - with candidate.open("rb") as fh: - fh.seek(start) - body = fh.read(length) - except OSError: - return _http_error(500, "read error") - return _http_response( - body, - status=206, - content_type=mime, - extra_headers=[ - *common_headers, - ("Content-Range", f"bytes {start}-{end}/{size}"), - ], - ) - - try: - body = candidate.read_bytes() - except OSError: - return _http_error(500, "read error") - return _http_response( - body, - content_type=mime, - extra_headers=common_headers, + return serve_signed_media( + sig, + payload, + secret=self._media_secret, + request=request, + media_dir=lambda channel=None: get_media_dir(channel), ) def _handle_session_delete(self, request: WsRequest, key: str) -> Response: diff --git a/nanobot/webui/media_api.py b/nanobot/webui/media_api.py new file mode 100644 index 000000000..451252116 --- /dev/null +++ b/nanobot/webui/media_api.py @@ -0,0 +1,246 @@ +"""Signed media helpers for the WebUI HTTP surface.""" + +from __future__ import annotations + +import base64 +import binascii +import email.utils +import hashlib +import hmac +import http +import mimetypes +import re +import shutil +import uuid +from collections.abc import Callable +from pathlib import Path +from typing import Any + +from websockets.datastructures import Headers +from websockets.http11 import Request as WsRequest +from websockets.http11 import Response + +from nanobot.config.paths import get_media_dir +from nanobot.utils.helpers import safe_filename + +MediaDirProvider = Callable[[str | None], Path] + + +def b64url_encode(data: bytes) -> str: + """URL-safe base64 without padding.""" + return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") + + +def b64url_decode(value: str) -> bytes: + """Reverse of :func:`b64url_encode`; caller handles decode errors.""" + pad = "=" * (-len(value) % 4) + return base64.urlsafe_b64decode(value + pad) + + +def _default_media_dir(channel: str | None = None) -> Path: + return get_media_dir(channel) + + +# Allowed MIME types we actually serve from the media endpoint. Anything +# outside this set is degraded to ``application/octet-stream`` so an +# attacker who somehow gets a signed URL for an unexpected file type can't +# trick the browser into sniffing executable content. +_MEDIA_ALLOWED_MIMES: frozenset[str] = frozenset({ + "image/png", + "image/jpeg", + "image/webp", + "image/gif", + "video/mp4", + "video/webm", + "video/quicktime", +}) + +_BYTE_RANGE_RE = re.compile(r"^bytes=(\d*)-(\d*)$") + + +def _http_response( + body: bytes, + *, + status: int = 200, + content_type: str = "text/plain; charset=utf-8", + extra_headers: list[tuple[str, str]] | None = None, +) -> Response: + headers = [ + ("Date", email.utils.formatdate(usegmt=True)), + ("Connection", "close"), + ("Content-Length", str(len(body))), + ("Content-Type", content_type), + ] + if extra_headers: + headers.extend(extra_headers) + reason = http.HTTPStatus(status).phrase + return Response(status, reason, Headers(headers), body) + + +def _http_error(status: int, message: str | None = None) -> Response: + body = (message or http.HTTPStatus(status).phrase).encode("utf-8") + return _http_response(body, status=status) + + +def _case_insensitive_header(headers: Any, key: str) -> str: + try: + value = headers.get(key) + except Exception: + value = None + if value is None: + try: + value = headers.get(key.lower()) + except Exception: + value = None + return str(value or "").strip() + + +def _parse_single_byte_range(range_header: str, size: int) -> tuple[int, int]: + """Parse a single HTTP byte range for signed media responses.""" + if size <= 0 or "," in range_header: + raise ValueError("invalid byte range") + m = _BYTE_RANGE_RE.fullmatch(range_header.strip()) + if m is None: + raise ValueError("invalid byte range") + start_text, end_text = m.groups() + if not start_text and not end_text: + raise ValueError("invalid byte range") + if not start_text: + suffix_length = int(end_text) + if suffix_length <= 0: + raise ValueError("invalid byte range") + start = max(size - suffix_length, 0) + end = size - 1 + else: + start = int(start_text) + end = int(end_text) if end_text else size - 1 + if start >= size or start > end: + raise ValueError("invalid byte range") + end = min(end, size - 1) + return start, end + + +def sign_media_path( + abs_path: Path, + *, + secret: bytes, + media_dir: MediaDirProvider = _default_media_dir, +) -> str | None: + """Return a signed ``/api/media//`` URL for a media-root path.""" + try: + media_root = media_dir(None).resolve() + rel = abs_path.resolve().relative_to(media_root) + except (OSError, ValueError): + return None + payload = b64url_encode(rel.as_posix().encode("utf-8")) + mac = hmac.new(secret, payload.encode("ascii"), hashlib.sha256).digest()[:16] + return f"/api/media/{b64url_encode(mac)}/{payload}" + + +def sign_or_stage_media_path( + path: Path, + *, + secret: bytes, + media_dir: MediaDirProvider = _default_media_dir, + logger: Any | None = None, +) -> dict[str, str] | None: + """Sign an existing media-root path, or stage an arbitrary file before signing.""" + signed = sign_media_path(path, secret=secret, media_dir=media_dir) + if signed is not None: + return {"url": signed, "name": path.name} + try: + if not path.is_file(): + return None + target_dir = media_dir("websocket") + safe_name = safe_filename(path.name) or "attachment" + staged = target_dir / f"{uuid.uuid4().hex[:12]}-{safe_name}" + shutil.copyfile(path, staged) + except OSError as exc: + if logger is not None: + logger.warning("failed to stage outbound media {}: {}", path, exc) + return None + signed = sign_media_path(staged, secret=secret, media_dir=media_dir) + if signed is None: + return None + return {"url": signed, "name": path.name} + + +def serve_signed_media( + sig: str, + payload: str, + *, + secret: bytes, + request: WsRequest | None = None, + media_dir: MediaDirProvider = _default_media_dir, +) -> Response: + """Serve a signed media URL, including browser-friendly byte ranges.""" + try: + provided_mac = b64url_decode(sig) + except (ValueError, binascii.Error): + return _http_error(401, "invalid signature") + expected_mac = hmac.new(secret, payload.encode("ascii"), hashlib.sha256).digest()[:16] + if not hmac.compare_digest(expected_mac, provided_mac): + return _http_error(401, "invalid signature") + try: + rel_bytes = b64url_decode(payload) + rel_str = rel_bytes.decode("utf-8") + except (ValueError, binascii.Error, UnicodeDecodeError): + return _http_error(400, "invalid payload") + try: + media_root = media_dir(None).resolve() + candidate = (media_root / rel_str).resolve() + candidate.relative_to(media_root) + except (OSError, ValueError): + return _http_error(404, "not found") + if not candidate.is_file(): + return _http_error(404, "not found") + + mime, _ = mimetypes.guess_type(candidate.name) + if mime not in _MEDIA_ALLOWED_MIMES: + mime = "application/octet-stream" + common_headers = [ + ("Accept-Ranges", "bytes"), + ("Cache-Control", "private, max-age=31536000, immutable"), + ("X-Content-Type-Options", "nosniff"), + ] + try: + size = candidate.stat().st_size + except OSError: + return _http_error(500, "read error") + + range_header = _case_insensitive_header(request.headers, "Range") if request else "" + if range_header: + try: + start, end = _parse_single_byte_range(range_header, size) + except ValueError: + return _http_response( + b"range not satisfiable", + status=416, + extra_headers=[ + ("Accept-Ranges", "bytes"), + ("Content-Range", f"bytes */{size}"), + ("X-Content-Type-Options", "nosniff"), + ], + ) + try: + length = end - start + 1 + with candidate.open("rb") as fh: + fh.seek(start) + body = fh.read(length) + except OSError: + return _http_error(500, "read error") + return _http_response( + body, + status=206, + content_type=mime, + extra_headers=[ + *common_headers, + ("Content-Range", f"bytes {start}-{end}/{size}"), + ], + ) + + try: + body = candidate.read_bytes() + except OSError: + return _http_error(500, "read error") + return _http_response(body, content_type=mime, extra_headers=common_headers) diff --git a/tests/channels/test_websocket_media_route.py b/tests/channels/test_websocket_media_route.py index 84cb6b47f..4c3f9161d 100644 --- a/tests/channels/test_websocket_media_route.py +++ b/tests/channels/test_websocket_media_route.py @@ -21,10 +21,10 @@ from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest -from nanobot.channels.websocket import ( - WebSocketChannel, - _b64url_decode, - _b64url_encode, +from nanobot.channels.websocket import WebSocketChannel +from nanobot.webui.media_api import ( + b64url_decode, + b64url_encode, ) from nanobot.session.manager import Session, SessionManager @@ -129,9 +129,9 @@ def test_sign_media_path_round_trips_via_hmac( expected = hmac.new( channel._media_secret, payload.encode("ascii"), hashlib.sha256 ).digest()[:16] - assert _b64url_decode(sig) == expected + assert b64url_decode(sig) == expected # The payload decodes back to the *relative* path — no absolute-path leaks. - assert _b64url_decode(payload).decode() == "a.png" + assert b64url_decode(payload).decode() == "a.png" def test_local_markdown_image_is_staged_and_rewritten( @@ -346,7 +346,7 @@ async def test_media_route_rejects_bad_signature( forged_mac = hmac.new( b"\x00" * 32, payload.encode("ascii"), hashlib.sha256 ).digest()[:16] - forged = f"/api/media/{_b64url_encode(forged_mac)}/{payload}" + forged = f"/api/media/{b64url_encode(forged_mac)}/{payload}" server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) @@ -375,11 +375,11 @@ async def test_media_route_rejects_path_traversal_payload( channel = _ch(bus, port=29922) # Hand-craft a traversal payload the legit signer would refuse to mint. - payload = _b64url_encode(b"../secret.txt") + payload = b64url_encode(b"../secret.txt") mac = hmac.new( channel._media_secret, payload.encode("ascii"), hashlib.sha256 ).digest()[:16] - url = f"/api/media/{_b64url_encode(mac)}/{payload}" + url = f"/api/media/{b64url_encode(mac)}/{payload}" with patch("nanobot.channels.websocket.get_media_dir", return_value=media): server_task = asyncio.create_task(channel.start()) @@ -434,11 +434,11 @@ async def test_media_route_degrades_non_image_to_octet_stream( channel = _ch(bus, port=29924) with patch("nanobot.channels.websocket.get_media_dir", return_value=media): - payload = _b64url_encode(b"scary.html") + payload = b64url_encode(b"scary.html") mac = hmac.new( channel._media_secret, payload.encode("ascii"), hashlib.sha256 ).digest()[:16] - url = f"/api/media/{_b64url_encode(mac)}/{payload}" + url = f"/api/media/{b64url_encode(mac)}/{payload}" server_task = asyncio.create_task(channel.start()) await asyncio.sleep(0.3) try: