refactor(webui): isolate signed media serving

This commit is contained in:
Xubin Ren 2026-05-29 17:05:59 +08:00
parent 4a0035ef8f
commit 9ed5643d93
3 changed files with 279 additions and 176 deletions

View File

@ -3,17 +3,13 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import base64
import binascii
import email.utils import email.utils
import hashlib
import hmac import hmac
import http import http
import json import json
import mimetypes import mimetypes
import re import re
import secrets import secrets
import shutil
import ssl import ssl
import time import time
import uuid import uuid
@ -44,7 +40,6 @@ from nanobot.config.paths import get_media_dir, get_workspace_path
from nanobot.config.schema import Base from nanobot.config.schema import Base
from nanobot.session.goal_state import goal_state_ws_blob from nanobot.session.goal_state import goal_state_ws_blob
from nanobot.session.webui_turns import websocket_turn_wall_started_at from nanobot.session.webui_turns import websocket_turn_wall_started_at
from nanobot.utils.helpers import safe_filename
from nanobot.utils.media_decode import ( from nanobot.utils.media_decode import (
FileSizeExceeded, FileSizeExceeded,
save_base64_data_url, save_base64_data_url,
@ -70,6 +65,11 @@ from nanobot.webui.cli_apps_api import (
cli_apps_payload, cli_apps_payload,
normalize_cli_app_mentions, normalize_cli_app_mentions,
) )
from nanobot.webui.media_api import (
serve_signed_media,
sign_media_path,
sign_or_stage_media_path,
)
from nanobot.webui.mcp_presets_api import ( from nanobot.webui.mcp_presets_api import (
mcp_presets_settings_action, mcp_presets_settings_action,
normalize_mcp_preset_mentions, normalize_mcp_preset_mentions,
@ -514,59 +514,6 @@ def _is_websocket_upgrade(request: WsRequest) -> bool:
return True return True
def _b64url_encode(data: bytes) -> str:
"""URL-safe base64 without padding — compact + friendly in URL paths."""
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
def _b64url_decode(s: str) -> bytes:
"""Reverse of :func:`_b64url_encode`; caller handles ``ValueError``."""
pad = "=" * (-len(s) % 4)
return base64.urlsafe_b64decode(s + pad)
# Allowed MIME types we actually serve from the media endpoint. Anything
# outside this set is degraded to ``application/octet-stream`` so an
# attacker who somehow gets a signed URL for an unexpected file type can't
# trick the browser into sniffing executable content.
_MEDIA_ALLOWED_MIMES: frozenset[str] = frozenset({
"image/png",
"image/jpeg",
"image/webp",
"image/gif",
"video/mp4",
"video/webm",
"video/quicktime",
})
_BYTE_RANGE_RE = re.compile(r"^bytes=(\d*)-(\d*)$")
def _parse_single_byte_range(range_header: str, size: int) -> tuple[int, int]:
"""Parse a single HTTP byte range for signed media responses."""
if size <= 0 or "," in range_header:
raise ValueError("invalid byte range")
m = _BYTE_RANGE_RE.fullmatch(range_header.strip())
if m is None:
raise ValueError("invalid byte range")
start_text, end_text = m.groups()
if not start_text and not end_text:
raise ValueError("invalid byte range")
if not start_text:
suffix_length = int(end_text)
if suffix_length <= 0:
raise ValueError("invalid byte range")
start = max(size - suffix_length, 0)
end = size - 1
else:
start = int(start_text)
end = int(end_text) if end_text else size - 1
if start >= size or start > end:
raise ValueError("invalid byte range")
end = min(end, size - 1)
return start, end
def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool: def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool:
"""Return True if the token-issue HTTP request carries credentials matching ``token_issue_secret``.""" """Return True if the token-issue HTTP request carries credentials matching ``token_issue_secret``."""
if not configured_secret: if not configured_secret:
@ -1368,16 +1315,11 @@ class WebSocketChannel(BaseChannel):
be fetched. The returned path is relative to the server origin; the be fetched. The returned path is relative to the server origin; the
client joins it against this server's HTTP origin (same host as WS). client joins it against this server's HTTP origin (same host as WS).
""" """
try: return sign_media_path(
media_root = get_media_dir().resolve() abs_path,
rel = abs_path.resolve().relative_to(media_root) secret=self._media_secret,
except (OSError, ValueError): media_dir=lambda channel=None: get_media_dir(channel),
return None )
payload = _b64url_encode(rel.as_posix().encode("utf-8"))
mac = hmac.new(
self._media_secret, payload.encode("ascii"), hashlib.sha256
).digest()[:16]
return f"/api/media/{_b64url_encode(mac)}/{payload}"
def _sign_or_stage_media_path(self, path: Path) -> dict[str, str] | None: def _sign_or_stage_media_path(self, path: Path) -> dict[str, str] | None:
"""Return a signed media URL payload for *path*. """Return a signed media URL payload for *path*.
@ -1388,23 +1330,12 @@ class WebSocketChannel(BaseChannel):
can fetch them through the existing signed media route without can fetch them through the existing signed media route without
exposing arbitrary filesystem paths. exposing arbitrary filesystem paths.
""" """
signed = self._sign_media_path(path) return sign_or_stage_media_path(
if signed is not None: path,
return {"url": signed, "name": path.name} secret=self._media_secret,
try: media_dir=lambda channel=None: get_media_dir(channel),
if not path.is_file(): logger=self.logger,
return None )
media_dir = get_media_dir("websocket")
safe_name = safe_filename(path.name) or "attachment"
staged = media_dir / f"{uuid.uuid4().hex[:12]}-{safe_name}"
shutil.copyfile(path, staged)
except OSError as exc:
self.logger.warning("failed to stage outbound media {}: {}", path, exc)
return None
signed = self._sign_media_path(staged)
if signed is None:
return None
return {"url": signed, "name": path.name}
def _rewrite_local_markdown_images(self, text: str) -> str: def _rewrite_local_markdown_images(self, text: str) -> str:
return rewrite_local_markdown_images( return rewrite_local_markdown_images(
@ -1421,86 +1352,12 @@ class WebSocketChannel(BaseChannel):
payload to a relative path, and streams the file bytes with a payload to a relative path, and streams the file bytes with a
long-lived immutable cache header (the URL already encodes the long-lived immutable cache header (the URL already encodes the
file identity, so caches can be aggressive).""" file identity, so caches can be aggressive)."""
try: return serve_signed_media(
provided_mac = _b64url_decode(sig) sig,
except (ValueError, binascii.Error): payload,
return _http_error(401, "invalid signature") secret=self._media_secret,
expected_mac = hmac.new( request=request,
self._media_secret, payload.encode("ascii"), hashlib.sha256 media_dir=lambda channel=None: get_media_dir(channel),
).digest()[:16]
if not hmac.compare_digest(expected_mac, provided_mac):
return _http_error(401, "invalid signature")
try:
rel_bytes = _b64url_decode(payload)
rel_str = rel_bytes.decode("utf-8")
except (ValueError, binascii.Error, UnicodeDecodeError):
return _http_error(400, "invalid payload")
# An attacker who somehow bypassed the HMAC check would still need
# the resolved path to escape the media root; guard defensively.
try:
media_root = get_media_dir().resolve()
candidate = (media_root / rel_str).resolve()
candidate.relative_to(media_root)
except (OSError, ValueError):
return _http_error(404, "not found")
if not candidate.is_file():
return _http_error(404, "not found")
mime, _ = mimetypes.guess_type(candidate.name)
if mime not in _MEDIA_ALLOWED_MIMES:
mime = "application/octet-stream"
common_headers = [
("Accept-Ranges", "bytes"),
("Cache-Control", "private, max-age=31536000, immutable"),
# Paired with the MIME whitelist above: prevents browsers from
# MIME-sniffing an octet-stream fallback into executable HTML.
("X-Content-Type-Options", "nosniff"),
]
try:
size = candidate.stat().st_size
except OSError:
return _http_error(500, "read error")
range_header = (
_case_insensitive_header(request.headers, "Range") if request else ""
)
if range_header:
try:
start, end = _parse_single_byte_range(range_header, size)
except ValueError:
return _http_response(
b"range not satisfiable",
status=416,
extra_headers=[
("Accept-Ranges", "bytes"),
("Content-Range", f"bytes */{size}"),
("X-Content-Type-Options", "nosniff"),
],
)
try:
length = end - start + 1
with candidate.open("rb") as fh:
fh.seek(start)
body = fh.read(length)
except OSError:
return _http_error(500, "read error")
return _http_response(
body,
status=206,
content_type=mime,
extra_headers=[
*common_headers,
("Content-Range", f"bytes {start}-{end}/{size}"),
],
)
try:
body = candidate.read_bytes()
except OSError:
return _http_error(500, "read error")
return _http_response(
body,
content_type=mime,
extra_headers=common_headers,
) )
def _handle_session_delete(self, request: WsRequest, key: str) -> Response: def _handle_session_delete(self, request: WsRequest, key: str) -> Response:

246
nanobot/webui/media_api.py Normal file
View File

@ -0,0 +1,246 @@
"""Signed media helpers for the WebUI HTTP surface."""
from __future__ import annotations
import base64
import binascii
import email.utils
import hashlib
import hmac
import http
import mimetypes
import re
import shutil
import uuid
from collections.abc import Callable
from pathlib import Path
from typing import Any
from websockets.datastructures import Headers
from websockets.http11 import Request as WsRequest
from websockets.http11 import Response
from nanobot.config.paths import get_media_dir
from nanobot.utils.helpers import safe_filename
MediaDirProvider = Callable[[str | None], Path]
def b64url_encode(data: bytes) -> str:
"""URL-safe base64 without padding."""
return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")
def b64url_decode(value: str) -> bytes:
"""Reverse of :func:`b64url_encode`; caller handles decode errors."""
pad = "=" * (-len(value) % 4)
return base64.urlsafe_b64decode(value + pad)
def _default_media_dir(channel: str | None = None) -> Path:
return get_media_dir(channel)
# Allowed MIME types we actually serve from the media endpoint. Anything
# outside this set is degraded to ``application/octet-stream`` so an
# attacker who somehow gets a signed URL for an unexpected file type can't
# trick the browser into sniffing executable content.
_MEDIA_ALLOWED_MIMES: frozenset[str] = frozenset({
"image/png",
"image/jpeg",
"image/webp",
"image/gif",
"video/mp4",
"video/webm",
"video/quicktime",
})
_BYTE_RANGE_RE = re.compile(r"^bytes=(\d*)-(\d*)$")
def _http_response(
body: bytes,
*,
status: int = 200,
content_type: str = "text/plain; charset=utf-8",
extra_headers: list[tuple[str, str]] | None = None,
) -> Response:
headers = [
("Date", email.utils.formatdate(usegmt=True)),
("Connection", "close"),
("Content-Length", str(len(body))),
("Content-Type", content_type),
]
if extra_headers:
headers.extend(extra_headers)
reason = http.HTTPStatus(status).phrase
return Response(status, reason, Headers(headers), body)
def _http_error(status: int, message: str | None = None) -> Response:
body = (message or http.HTTPStatus(status).phrase).encode("utf-8")
return _http_response(body, status=status)
def _case_insensitive_header(headers: Any, key: str) -> str:
try:
value = headers.get(key)
except Exception:
value = None
if value is None:
try:
value = headers.get(key.lower())
except Exception:
value = None
return str(value or "").strip()
def _parse_single_byte_range(range_header: str, size: int) -> tuple[int, int]:
"""Parse a single HTTP byte range for signed media responses."""
if size <= 0 or "," in range_header:
raise ValueError("invalid byte range")
m = _BYTE_RANGE_RE.fullmatch(range_header.strip())
if m is None:
raise ValueError("invalid byte range")
start_text, end_text = m.groups()
if not start_text and not end_text:
raise ValueError("invalid byte range")
if not start_text:
suffix_length = int(end_text)
if suffix_length <= 0:
raise ValueError("invalid byte range")
start = max(size - suffix_length, 0)
end = size - 1
else:
start = int(start_text)
end = int(end_text) if end_text else size - 1
if start >= size or start > end:
raise ValueError("invalid byte range")
end = min(end, size - 1)
return start, end
def sign_media_path(
abs_path: Path,
*,
secret: bytes,
media_dir: MediaDirProvider = _default_media_dir,
) -> str | None:
"""Return a signed ``/api/media/<sig>/<payload>`` URL for a media-root path."""
try:
media_root = media_dir(None).resolve()
rel = abs_path.resolve().relative_to(media_root)
except (OSError, ValueError):
return None
payload = b64url_encode(rel.as_posix().encode("utf-8"))
mac = hmac.new(secret, payload.encode("ascii"), hashlib.sha256).digest()[:16]
return f"/api/media/{b64url_encode(mac)}/{payload}"
def sign_or_stage_media_path(
path: Path,
*,
secret: bytes,
media_dir: MediaDirProvider = _default_media_dir,
logger: Any | None = None,
) -> dict[str, str] | None:
"""Sign an existing media-root path, or stage an arbitrary file before signing."""
signed = sign_media_path(path, secret=secret, media_dir=media_dir)
if signed is not None:
return {"url": signed, "name": path.name}
try:
if not path.is_file():
return None
target_dir = media_dir("websocket")
safe_name = safe_filename(path.name) or "attachment"
staged = target_dir / f"{uuid.uuid4().hex[:12]}-{safe_name}"
shutil.copyfile(path, staged)
except OSError as exc:
if logger is not None:
logger.warning("failed to stage outbound media {}: {}", path, exc)
return None
signed = sign_media_path(staged, secret=secret, media_dir=media_dir)
if signed is None:
return None
return {"url": signed, "name": path.name}
def serve_signed_media(
sig: str,
payload: str,
*,
secret: bytes,
request: WsRequest | None = None,
media_dir: MediaDirProvider = _default_media_dir,
) -> Response:
"""Serve a signed media URL, including browser-friendly byte ranges."""
try:
provided_mac = b64url_decode(sig)
except (ValueError, binascii.Error):
return _http_error(401, "invalid signature")
expected_mac = hmac.new(secret, payload.encode("ascii"), hashlib.sha256).digest()[:16]
if not hmac.compare_digest(expected_mac, provided_mac):
return _http_error(401, "invalid signature")
try:
rel_bytes = b64url_decode(payload)
rel_str = rel_bytes.decode("utf-8")
except (ValueError, binascii.Error, UnicodeDecodeError):
return _http_error(400, "invalid payload")
try:
media_root = media_dir(None).resolve()
candidate = (media_root / rel_str).resolve()
candidate.relative_to(media_root)
except (OSError, ValueError):
return _http_error(404, "not found")
if not candidate.is_file():
return _http_error(404, "not found")
mime, _ = mimetypes.guess_type(candidate.name)
if mime not in _MEDIA_ALLOWED_MIMES:
mime = "application/octet-stream"
common_headers = [
("Accept-Ranges", "bytes"),
("Cache-Control", "private, max-age=31536000, immutable"),
("X-Content-Type-Options", "nosniff"),
]
try:
size = candidate.stat().st_size
except OSError:
return _http_error(500, "read error")
range_header = _case_insensitive_header(request.headers, "Range") if request else ""
if range_header:
try:
start, end = _parse_single_byte_range(range_header, size)
except ValueError:
return _http_response(
b"range not satisfiable",
status=416,
extra_headers=[
("Accept-Ranges", "bytes"),
("Content-Range", f"bytes */{size}"),
("X-Content-Type-Options", "nosniff"),
],
)
try:
length = end - start + 1
with candidate.open("rb") as fh:
fh.seek(start)
body = fh.read(length)
except OSError:
return _http_error(500, "read error")
return _http_response(
body,
status=206,
content_type=mime,
extra_headers=[
*common_headers,
("Content-Range", f"bytes {start}-{end}/{size}"),
],
)
try:
body = candidate.read_bytes()
except OSError:
return _http_error(500, "read error")
return _http_response(body, content_type=mime, extra_headers=common_headers)

View File

@ -21,10 +21,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
import httpx import httpx
import pytest import pytest
from nanobot.channels.websocket import ( from nanobot.channels.websocket import WebSocketChannel
WebSocketChannel, from nanobot.webui.media_api import (
_b64url_decode, b64url_decode,
_b64url_encode, b64url_encode,
) )
from nanobot.session.manager import Session, SessionManager from nanobot.session.manager import Session, SessionManager
@ -129,9 +129,9 @@ def test_sign_media_path_round_trips_via_hmac(
expected = hmac.new( expected = hmac.new(
channel._media_secret, payload.encode("ascii"), hashlib.sha256 channel._media_secret, payload.encode("ascii"), hashlib.sha256
).digest()[:16] ).digest()[:16]
assert _b64url_decode(sig) == expected assert b64url_decode(sig) == expected
# The payload decodes back to the *relative* path — no absolute-path leaks. # The payload decodes back to the *relative* path — no absolute-path leaks.
assert _b64url_decode(payload).decode() == "a.png" assert b64url_decode(payload).decode() == "a.png"
def test_local_markdown_image_is_staged_and_rewritten( def test_local_markdown_image_is_staged_and_rewritten(
@ -346,7 +346,7 @@ async def test_media_route_rejects_bad_signature(
forged_mac = hmac.new( forged_mac = hmac.new(
b"\x00" * 32, payload.encode("ascii"), hashlib.sha256 b"\x00" * 32, payload.encode("ascii"), hashlib.sha256
).digest()[:16] ).digest()[:16]
forged = f"/api/media/{_b64url_encode(forged_mac)}/{payload}" forged = f"/api/media/{b64url_encode(forged_mac)}/{payload}"
server_task = asyncio.create_task(channel.start()) server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3) await asyncio.sleep(0.3)
@ -375,11 +375,11 @@ async def test_media_route_rejects_path_traversal_payload(
channel = _ch(bus, port=29922) channel = _ch(bus, port=29922)
# Hand-craft a traversal payload the legit signer would refuse to mint. # Hand-craft a traversal payload the legit signer would refuse to mint.
payload = _b64url_encode(b"../secret.txt") payload = b64url_encode(b"../secret.txt")
mac = hmac.new( mac = hmac.new(
channel._media_secret, payload.encode("ascii"), hashlib.sha256 channel._media_secret, payload.encode("ascii"), hashlib.sha256
).digest()[:16] ).digest()[:16]
url = f"/api/media/{_b64url_encode(mac)}/{payload}" url = f"/api/media/{b64url_encode(mac)}/{payload}"
with patch("nanobot.channels.websocket.get_media_dir", return_value=media): with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
server_task = asyncio.create_task(channel.start()) server_task = asyncio.create_task(channel.start())
@ -434,11 +434,11 @@ async def test_media_route_degrades_non_image_to_octet_stream(
channel = _ch(bus, port=29924) channel = _ch(bus, port=29924)
with patch("nanobot.channels.websocket.get_media_dir", return_value=media): with patch("nanobot.channels.websocket.get_media_dir", return_value=media):
payload = _b64url_encode(b"scary.html") payload = b64url_encode(b"scary.html")
mac = hmac.new( mac = hmac.new(
channel._media_secret, payload.encode("ascii"), hashlib.sha256 channel._media_secret, payload.encode("ascii"), hashlib.sha256
).digest()[:16] ).digest()[:16]
url = f"/api/media/{_b64url_encode(mac)}/{payload}" url = f"/api/media/{b64url_encode(mac)}/{payload}"
server_task = asyncio.create_task(channel.start()) server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3) await asyncio.sleep(0.3)
try: try: