mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-14 06:43:53 +00:00
Maintainer edit for PR 4115: rebase onto origin/main and split gateway HTTP routing from token, media, and workspace services so WebSocketChannel depends on explicit gateway services instead of GatewayHTTPHandler internals. Preserve file edit channel capabilities and restore tools.restrict_to_workspace wiring through ChannelManager.
470 lines
15 KiB
Python
470 lines
15 KiB
Python
"""Tests for WS envelope media handling (client image upload path).
|
||
|
||
Exercises ``WebSocketChannel._dispatch_envelope`` for the ``message`` branch:
|
||
decoding base64 data URLs, rejecting malformed / oversized / non-whitelisted
|
||
payloads, preserving backward compatibility with media-less frames, and
|
||
forwarding saved paths to ``_handle_message``.
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import base64
|
||
import json
|
||
from pathlib import Path
|
||
from typing import Any
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
|
||
from nanobot.channels.websocket import (
|
||
WebSocketChannel,
|
||
WebSocketConfig,
|
||
_extract_data_url_mime,
|
||
)
|
||
from nanobot.webui.gateway_services import build_gateway_services
|
||
|
||
|
||
def _tiny_png_data_url() -> str:
|
||
"""A 1-pixel PNG prefixed as a data URL — just enough for magic-bytes sniffing."""
|
||
# 1x1 transparent PNG
|
||
png = (
|
||
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00"
|
||
b"\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89\x00\x00\x00\rIDATx"
|
||
b"\x9cc\xf8\xcf\xc0\x00\x00\x00\x03\x00\x01\x00\x18\xdd\x8d\xb4\x00"
|
||
b"\x00\x00\x00IEND\xaeB`\x82"
|
||
)
|
||
return f"data:image/png;base64,{base64.b64encode(png).decode()}"
|
||
|
||
|
||
def _data_url(mime: str, payload: bytes) -> str:
|
||
return f"data:{mime};base64,{base64.b64encode(payload).decode()}"
|
||
|
||
|
||
def _make_channel() -> WebSocketChannel:
    """Build a ``WebSocketChannel`` for tests.

    The channel is wired to a mocked bus and to gateway services rooted at
    the current working directory. ``_handle_message`` is replaced with an
    ``AsyncMock`` so tests can inspect what the dispatcher forwards.
    """
    raw_config = {"enabled": True, "allowFrom": ["*"], "websocketRequiresToken": False}

    message_bus = MagicMock()
    message_bus.publish_inbound = AsyncMock()

    services = build_gateway_services(
        config=WebSocketConfig.model_validate(raw_config),
        bus=message_bus,
        session_manager=None,
        static_dist_path=None,
        workspace_path=Path.cwd(),
        default_restrict_to_workspace=False,
        runtime_model_name=None,
        runtime_surface="browser",
        runtime_capabilities_overrides=None,
    )

    ws_channel = WebSocketChannel(raw_config, message_bus, gateway=services)
    # Stub the downstream handler; tests assert on its recorded await calls.
    ws_channel._handle_message = AsyncMock()  # type: ignore[method-assign]
    return ws_channel
|
||
|
||
|
||
# -- Pure helpers --------------------------------------------------------------
|
||
|
||
|
||
@pytest.mark.parametrize(
    ("url", "expected"),
    [
        ("data:image/png;base64,AAAA", "image/png"),
        ("data:image/jpeg;base64,AAAA", "image/jpeg"),
        ("data:IMAGE/PNG;base64,AAAA", "image/png"),
        ("data:image/svg+xml;base64,AAAA", "image/svg+xml"),
        ("data:text/plain;base64,AAAA", "text/plain"),
        ("http://evil.example/x.png", None),
        ("data:image/png,AAAA", None),  # missing `;base64`
        ("", None),
        (None, None),
    ],
)
def test_extract_data_url_mime(url: Any, expected: str | None) -> None:
    """``_extract_data_url_mime`` yields the lowercased MIME type, or ``None``.

    ``None`` is expected for inputs that are not base64 data URLs: plain HTTP
    URLs, data URLs without ``;base64``, the empty string and ``None`` itself.
    """
    actual = _extract_data_url_mime(url)
    assert actual == expected
|
||
|
||
|
||
# -- max_message_bytes bump ----------------------------------------------------
|
||
|
||
|
||
def test_max_message_bytes_default_supports_multi_image_frame() -> None:
    """The default frame limit must comfortably hold 4 × 6 MB base64-encoded images.

    Also checks that a value one byte above 40 MiB (41_943_040 bytes) is
    rejected when the config is constructed.
    """
    # ``WebSocketConfig`` comes from the module-level import; no local
    # re-import is needed.
    default = WebSocketConfig().max_message_bytes
    # 4 images × 6 MB × 1.37 base64 overhead ≈ 33 MB
    assert default >= 33 * 1024 * 1024
    # Upper bound 40 MB matches plan.
    # NOTE(review): assumes the bound is enforced by pydantic validation, whose
    # ``ValidationError`` subclasses ``ValueError``. Catching ``ValueError``
    # instead of bare ``Exception`` stops unrelated failures (e.g. a
    # ``TypeError`` from a renamed field) from satisfying this check.
    with pytest.raises(ValueError):
        WebSocketConfig(max_message_bytes=41_943_040 + 1)
|
||
|
||
|
||
# -- _dispatch_envelope message branch + media --------------------------------
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_message_without_media_backward_compatible() -> None:
    """A ``message`` frame with no ``media`` key is still dispatched unchanged."""
    channel = _make_channel()
    conn = AsyncMock()

    await channel._dispatch_envelope(
        conn,
        "client-1",
        {"type": "message", "chat_id": "abc123", "content": "hello"},
    )

    channel._handle_message.assert_awaited_once()
    forwarded = channel._handle_message.call_args.kwargs
    assert forwarded["chat_id"] == "abc123"
    assert forwarded["content"] == "hello"
    # Absent media is forwarded as ``None`` so downstream treats it as missing.
    assert forwarded["media"] is None
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_message_forwards_normalized_cli_app_attachments() -> None:
    """``cli_apps`` entries reach ``_handle_message`` metadata in normalized form.

    The first entry is expected with its ``name`` lowercased; the second entry
    is expected to be dropped (presumably because its name is not a valid
    identifier — confirm against the channel's normalizer).
    """
    channel = _make_channel()
    conn = AsyncMock()
    drawio_app = {
        "name": "DrawIO",
        "display_name": "Draw.io",
        "category": "diagram",
        "entry_point": "cli-anything-drawio",
        "logo_url": "https://example.invalid/drawio.svg",
        "brand_color": "#F08705",
    }
    rejected_app = {"name": "bad name", "entry_point": "nope"}
    frame = {
        "type": "message",
        "chat_id": "abc123",
        "content": "please use @drawio",
        "webui": True,
        "cli_apps": [drawio_app, rejected_app],
    }

    await channel._dispatch_envelope(conn, "client-1", frame)

    channel._handle_message.assert_awaited_once()
    metadata = channel._handle_message.call_args.kwargs["metadata"]
    assert metadata["webui"] is True
    expected_app = {
        "name": "drawio",
        "display_name": "Draw.io",
        "category": "diagram",
        "entry_point": "cli-anything-drawio",
        "logo_url": "https://example.invalid/drawio.svg",
        "brand_color": "#F08705",
    }
    assert metadata["cli_apps"] == [expected_app]
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_message_with_single_image_forwards_saved_path(tmp_path) -> None:
    """One valid PNG attachment is saved under the media dir and its path forwarded."""
    channel = _make_channel()
    conn = AsyncMock()
    attachment = {"data_url": _tiny_png_data_url(), "name": "shot.png"}
    frame = {
        "type": "message",
        "chat_id": "abc123",
        "content": "look at this",
        "media": [attachment],
    }

    # Point media storage at pytest's per-test temporary directory.
    with patch("nanobot.channels.websocket.get_media_dir", return_value=tmp_path):
        await channel._dispatch_envelope(conn, "client-1", frame)

    channel._handle_message.assert_awaited_once()
    media_paths = channel._handle_message.call_args.kwargs["media"]
    assert isinstance(media_paths, list)
    assert len(media_paths) == 1

    stored = Path(media_paths[0])
    assert stored.exists()
    assert stored.suffix == ".png"
    assert stored.is_relative_to(tmp_path)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_message_with_multiple_images(tmp_path) -> None:
    """Three valid images in one frame yield three distinct saved paths."""
    channel = _make_channel()
    mock_conn = AsyncMock()
    envelope = {
        "type": "message",
        "chat_id": "abc123",
        "content": "a couple",
        "media": [
            {"data_url": _tiny_png_data_url()},
            {"data_url": _tiny_png_data_url()},
            {"data_url": _tiny_png_data_url()},
        ],
    }

    # Point media storage at pytest's per-test temporary directory.
    with patch("nanobot.channels.websocket.get_media_dir", return_value=tmp_path):
        await channel._dispatch_envelope(mock_conn, "client-1", envelope)

    # Fail with a clear message if dispatch never reached the handler, rather
    # than with an AttributeError on ``call_args`` being ``None`` below.
    channel._handle_message.assert_awaited_once()
    paths = channel._handle_message.call_args.kwargs["media"]
    assert len(paths) == 3
    # Saved filenames must be unique even though the payloads are identical.
    assert len({Path(p).name for p in paths}) == 3
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_image_only_message_allows_empty_text(tmp_path) -> None:
    """Empty ``content`` is accepted as long as the frame carries media."""
    channel = _make_channel()
    conn = AsyncMock()
    frame = {
        "type": "message",
        "chat_id": "abc123",
        "content": "",
        "media": [{"data_url": _tiny_png_data_url()}],
    }

    with patch("nanobot.channels.websocket.get_media_dir", return_value=tmp_path):
        await channel._dispatch_envelope(conn, "client-1", frame)

    channel._handle_message.assert_awaited_once()
    # Nothing was sent back on the connection, i.e. no error event.
    conn.send.assert_not_awaited()
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_message_rejected_when_more_than_four_images(tmp_path) -> None:
    """A frame with five images is rejected with reason ``too_many_images``."""
    channel = _make_channel()
    conn = AsyncMock()
    one_image = {"data_url": _tiny_png_data_url()}
    frame = {
        "type": "message",
        "chat_id": "abc123",
        "content": "hi",
        "media": [one_image] * 5,
    }

    with patch("nanobot.channels.websocket.get_media_dir", return_value=tmp_path):
        await channel._dispatch_envelope(conn, "client-1", frame)

    channel._handle_message.assert_not_awaited()
    conn.send.assert_awaited_once()
    # The single frame sent to the client is the JSON-encoded error event.
    error_event = json.loads(conn.send.call_args.args[0])
    assert error_event["event"] == "error"
    assert error_event["detail"] == "image_rejected"
    assert error_event["reason"] == "too_many_images"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_message_rejected_on_oversize_payload(tmp_path) -> None:
    """A 9 MiB image payload is rejected with reason ``size``."""
    channel = _make_channel()
    conn = AsyncMock()
    nine_mib = 9 * 1024 * 1024
    blob = b"x" * nine_mib  # > 8 MB WS limit
    frame = {
        "type": "message",
        "chat_id": "abc123",
        "content": "big",
        "media": [{"data_url": _data_url("image/png", blob)}],
    }

    with patch("nanobot.channels.websocket.get_media_dir", return_value=tmp_path):
        await channel._dispatch_envelope(conn, "client-1", frame)

    channel._handle_message.assert_not_awaited()
    error_event = json.loads(conn.send.call_args.args[0])
    assert error_event["detail"] == "image_rejected"
    assert error_event["reason"] == "size"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_message_rejected_on_non_image_mime(tmp_path) -> None:
    """An ``application/pdf`` data URL is rejected with reason ``mime``."""
    channel = _make_channel()
    conn = AsyncMock()
    pdf_url = _data_url("application/pdf", b"%PDF-1.4")
    frame = {
        "type": "message",
        "chat_id": "abc123",
        "content": "pdf?",
        "media": [{"data_url": pdf_url}],
    }

    with patch("nanobot.channels.websocket.get_media_dir", return_value=tmp_path):
        await channel._dispatch_envelope(conn, "client-1", frame)

    channel._handle_message.assert_not_awaited()
    error_event = json.loads(conn.send.call_args.args[0])
    assert error_event["detail"] == "image_rejected"
    assert error_event["reason"] == "mime"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_message_rejected_on_svg_mime(tmp_path) -> None:
    """SVG uploads are refused with reason ``mime`` (embedded scripts are an XSS surface)."""
    channel = _make_channel()
    conn = AsyncMock()
    svg_url = _data_url("image/svg+xml", b"<svg/>")
    frame = {
        "type": "message",
        "chat_id": "abc123",
        "content": "svg",
        "media": [{"data_url": svg_url}],
    }

    with patch("nanobot.channels.websocket.get_media_dir", return_value=tmp_path):
        await channel._dispatch_envelope(conn, "client-1", frame)

    channel._handle_message.assert_not_awaited()
    error_event = json.loads(conn.send.call_args.args[0])
    assert error_event["reason"] == "mime"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_message_rejected_on_malformed_data_url(tmp_path) -> None:
    """A plain HTTP URL in ``data_url`` is rejected with reason ``decode``."""
    channel = _make_channel()
    conn = AsyncMock()
    frame = {
        "type": "message",
        "chat_id": "abc123",
        "content": "nope",
        "media": [{"data_url": "http://evil.example/image.png"}],
    }

    with patch("nanobot.channels.websocket.get_media_dir", return_value=tmp_path):
        await channel._dispatch_envelope(conn, "client-1", frame)

    channel._handle_message.assert_not_awaited()
    error_event = json.loads(conn.send.call_args.args[0])
    assert error_event["reason"] == "decode"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_message_rejected_on_broken_base64(tmp_path) -> None:
    """A data URL whose payload is not valid base64 is rejected with reason ``decode``."""
    channel = _make_channel()
    conn = AsyncMock()
    frame = {
        "type": "message",
        "chat_id": "abc123",
        "content": "nope",
        "media": [{"data_url": "data:image/png;base64,not-valid-base64!!!"}],
    }

    with patch("nanobot.channels.websocket.get_media_dir", return_value=tmp_path):
        await channel._dispatch_envelope(conn, "client-1", frame)

    channel._handle_message.assert_not_awaited()
    error_event = json.loads(conn.send.call_args.args[0])
    assert error_event["reason"] == "decode"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_message_rejected_when_media_item_shape_wrong(tmp_path) -> None:
    """A ``media`` list holding a bare string is rejected with reason ``malformed``."""
    channel = _make_channel()
    conn = AsyncMock()
    # The list item is a plain string rather than a dict with ``data_url``.
    bad_item = "data:image/png;base64,XXXX"
    frame = {
        "type": "message",
        "chat_id": "abc123",
        "content": "huh",
        "media": [bad_item],
    }

    with patch("nanobot.channels.websocket.get_media_dir", return_value=tmp_path):
        await channel._dispatch_envelope(conn, "client-1", frame)

    channel._handle_message.assert_not_awaited()
    error_event = json.loads(conn.send.call_args.args[0])
    assert error_event["reason"] == "malformed"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_message_rejected_when_media_field_is_not_list() -> None:
    """A non-list ``media`` value is rejected with reason ``malformed``."""
    channel = _make_channel()
    conn = AsyncMock()

    await channel._dispatch_envelope(
        conn,
        "client-1",
        {
            "type": "message",
            "chat_id": "abc123",
            "content": "huh",
            "media": "not-a-list",
        },
    )

    channel._handle_message.assert_not_awaited()
    error_event = json.loads(conn.send.call_args.args[0])
    assert error_event["detail"] == "image_rejected"
    assert error_event["reason"] == "malformed"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_failed_media_does_not_partially_persist(tmp_path) -> None:
    """A batch with one bad item is rejected as a whole and leaves no files.

    The first attachment is a valid PNG and the second is a PDF. Nothing may
    be forwarded to ``_handle_message``, and the media directory must hold no
    files afterwards, so a mixed valid/invalid batch never leaves orphans.
    """
    channel = _make_channel()
    conn = AsyncMock()
    valid_item = {"data_url": _tiny_png_data_url()}
    invalid_item = {"data_url": _data_url("application/pdf", b"%PDF-1.4")}
    frame = {
        "type": "message",
        "chat_id": "abc123",
        "content": "mixed",
        "media": [valid_item, invalid_item],
    }

    with patch("nanobot.channels.websocket.get_media_dir", return_value=tmp_path):
        await channel._dispatch_envelope(conn, "client-1", frame)

    channel._handle_message.assert_not_awaited()
    error_event = json.loads(conn.send.call_args.args[0])
    assert error_event["reason"] == "mime"

    # Partial-batch failures must not leak files to disk.
    leftover = [entry for entry in tmp_path.iterdir() if entry.is_file()]
    assert leftover == [], f"orphan media after rejected batch: {leftover}"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_rejects_empty_text_without_media() -> None:
    """Whitespace-only content with no media is still rejected.

    This matches the existing behavior, kept for backward compatibility.
    """
    channel = _make_channel()
    conn = AsyncMock()

    await channel._dispatch_envelope(
        conn,
        "client-1",
        {"type": "message", "chat_id": "abc123", "content": " "},
    )

    channel._handle_message.assert_not_awaited()
    error_event = json.loads(conn.send.call_args.args[0])
    assert error_event["detail"] == "missing content"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_non_string_content_still_rejected() -> None:
    """Integer ``content`` is rejected with the ``missing content`` error detail."""
    channel = _make_channel()
    conn = AsyncMock()

    await channel._dispatch_envelope(
        conn,
        "client-1",
        {"type": "message", "chat_id": "abc123", "content": 42},
    )

    channel._handle_message.assert_not_awaited()
    error_event = json.loads(conn.send.call_args.args[0])
    assert error_event["detail"] == "missing content"
|