# Mirror of https://github.com/HKUDS/nanobot.git
# Synced 2026-04-30 14:56:01 +00:00 · 421 lines · 13 KiB · Python
"""Tests for WS envelope media handling (client image upload path).

Exercises ``WebSocketChannel._dispatch_envelope`` for the ``message`` branch:
decoding base64 data URLs, rejecting malformed / oversized / non-whitelisted
payloads, preserving backward compatibility with media-less frames, and
forwarding saved paths to ``_handle_message``.
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import base64
|
||
import json
|
||
from pathlib import Path
|
||
from typing import Any
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
|
||
from nanobot.channels.websocket import (
|
||
WebSocketChannel,
|
||
_extract_data_url_mime,
|
||
)
|
||
|
||
|
||
def _tiny_png_data_url() -> str:
    """Return a 1x1 transparent PNG encoded as a ``data:image/png;base64,`` URL.

    The payload begins with the PNG signature, which is just enough for
    magic-bytes sniffing.
    """
    # Raw bytes of the 1x1 transparent PNG (signature, IHDR, IDAT, IEND).
    png = (
        b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00"
        b"\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89\x00\x00\x00\rIDATx"
        b"\x9cc\xf8\xcf\xc0\x00\x00\x00\x03\x00\x01\x00\x18\xdd\x8d\xb4\x00"
        b"\x00\x00\x00IEND\xaeB`\x82"
    )
    return f"data:image/png;base64,{base64.b64encode(png).decode()}"
|
||
|
||
|
||
def _data_url(mime: str, payload: bytes) -> str:
    """Build a base64 ``data:`` URL for *payload* labelled with *mime*.

    Nothing is validated here, so tests can craft deliberately bad uploads
    (unexpected MIME type, oversized body).
    """
    encoded = base64.b64encode(payload).decode()
    return f"data:{mime};base64,{encoded}"
|
||
|
||
|
||
def _make_channel() -> WebSocketChannel:
    """Build a ``WebSocketChannel`` wired to mocks for envelope-dispatch tests.

    The bus is a ``MagicMock`` whose ``publish_inbound`` is awaitable, and the
    channel's ``_handle_message`` is swapped for an ``AsyncMock`` so each test
    can inspect what would have been forwarded.
    """
    bus = MagicMock()
    bus.publish_inbound = AsyncMock()
    channel = WebSocketChannel(
        {"enabled": True, "allowFrom": ["*"], "websocketRequiresToken": False},
        bus,
    )
    channel._handle_message = AsyncMock()  # type: ignore[method-assign]
    return channel
|
||
|
||
|
||
# -- Pure helpers --------------------------------------------------------------
|
||
|
||
|
||
@pytest.mark.parametrize(
    ("url", "expected"),
    [
        # Base64 data URLs: the MIME type is returned lower-cased, whatever it is.
        ("data:image/png;base64,AAAA", "image/png"),
        ("data:image/jpeg;base64,AAAA", "image/jpeg"),
        ("data:IMAGE/PNG;base64,AAAA", "image/png"),
        ("data:image/svg+xml;base64,AAAA", "image/svg+xml"),
        ("data:text/plain;base64,AAAA", "text/plain"),
        # Anything that is not a base64 data URL yields None.
        ("http://evil.example/x.png", None),
        ("data:image/png,AAAA", None),  # missing `;base64`
        ("", None),
        (None, None),
    ],
)
def test_extract_data_url_mime(url: Any, expected: str | None) -> None:
    """``_extract_data_url_mime`` maps each input in the table to its expected MIME."""
    assert _extract_data_url_mime(url) == expected
|
||
|
||
|
||
# -- max_message_bytes bump ----------------------------------------------------
|
||
|
||
|
||
def test_max_message_bytes_default_supports_multi_image_frame() -> None:
    """Default 36 MB must comfortably hold 4 × 6 MB base64-encoded images."""
    from nanobot.channels.websocket import WebSocketConfig

    mib = 1024 * 1024
    default = WebSocketConfig().max_message_bytes
    # 4 images × 6 MB × 1.37 base64 overhead ≈ 33 MB
    assert default >= 33 * mib
    # Upper bound 40 MB matches plan: one byte over the cap must be refused.
    # NOTE(review): ``Exception`` is broad; narrow it to the config's
    # validation error type once that is confirmed.
    with pytest.raises(Exception):
        WebSocketConfig(max_message_bytes=40 * mib + 1)
|
||
|
||
|
||
# -- _dispatch_envelope message branch + media --------------------------------
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_message_without_media_backward_compatible() -> None:
    """Existing clients that don't send ``media`` keep working unchanged."""
    channel = _make_channel()
    mock_conn = AsyncMock()
    envelope = {"type": "message", "chat_id": "abc123", "content": "hello"}

    await channel._dispatch_envelope(mock_conn, "client-1", envelope)

    channel._handle_message.assert_awaited_once()
    forwarded = channel._handle_message.call_args.kwargs
    assert forwarded["chat_id"] == "abc123"
    assert forwarded["content"] == "hello"
    # When no media, we pass ``media=None`` so downstream treats it as absent.
    assert forwarded["media"] is None
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_message_with_single_image_forwards_saved_path(tmp_path: Path) -> None:
    """One valid PNG is written under the media dir and its path is forwarded."""
    channel = _make_channel()
    mock_conn = AsyncMock()
    envelope = {
        "type": "message",
        "chat_id": "abc123",
        "content": "look at this",
        "media": [{"data_url": _tiny_png_data_url(), "name": "shot.png"}],
    }

    # Redirect the media directory so files land in pytest's per-test tmp dir.
    with patch(
        "nanobot.channels.websocket.get_media_dir", return_value=tmp_path
    ):
        await channel._dispatch_envelope(mock_conn, "client-1", envelope)

    channel._handle_message.assert_awaited_once()
    paths = channel._handle_message.call_args.kwargs["media"]
    # Separate asserts so a failure says which expectation broke.
    assert isinstance(paths, list)
    assert len(paths) == 1
    saved = Path(paths[0])
    assert saved.exists()
    assert saved.suffix == ".png"
    assert saved.is_relative_to(tmp_path)
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_message_with_multiple_images(tmp_path: Path) -> None:
    """Three identical images are forwarded as three distinct saved paths."""
    channel = _make_channel()
    mock_conn = AsyncMock()
    envelope = {
        "type": "message",
        "chat_id": "abc123",
        "content": "a couple",
        "media": [
            {"data_url": _tiny_png_data_url()},
            {"data_url": _tiny_png_data_url()},
            {"data_url": _tiny_png_data_url()},
        ],
    }

    with patch(
        "nanobot.channels.websocket.get_media_dir", return_value=tmp_path
    ):
        await channel._dispatch_envelope(mock_conn, "client-1", envelope)

    # Guard first: ``call_args`` is None when nothing was forwarded, which
    # would otherwise surface as an opaque AttributeError below.
    channel._handle_message.assert_awaited_once()
    paths = channel._handle_message.call_args.kwargs["media"]
    assert len(paths) == 3
    # Saved filenames must be unique.
    assert len({Path(p).name for p in paths}) == 3
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_image_only_message_allows_empty_text(tmp_path: Path) -> None:
    """When media is attached, empty text is acceptable."""
    channel = _make_channel()
    mock_conn = AsyncMock()
    envelope = {
        "type": "message",
        "chat_id": "abc123",
        "content": "",
        "media": [{"data_url": _tiny_png_data_url()}],
    }

    with patch(
        "nanobot.channels.websocket.get_media_dir", return_value=tmp_path
    ):
        await channel._dispatch_envelope(mock_conn, "client-1", envelope)

    channel._handle_message.assert_awaited_once()
    # Error event NOT sent.
    mock_conn.send.assert_not_awaited()
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_message_rejected_when_more_than_four_images(tmp_path: Path) -> None:
    """Five images in one frame are rejected with reason ``too_many_images``."""
    channel = _make_channel()
    mock_conn = AsyncMock()
    envelope = {
        "type": "message",
        "chat_id": "abc123",
        "content": "hi",
        # The same dict repeated five times; the items are only read.
        "media": [{"data_url": _tiny_png_data_url()}] * 5,
    }

    with patch(
        "nanobot.channels.websocket.get_media_dir", return_value=tmp_path
    ):
        await channel._dispatch_envelope(mock_conn, "client-1", envelope)

    channel._handle_message.assert_not_awaited()
    mock_conn.send.assert_awaited_once()
    err = json.loads(mock_conn.send.call_args.args[0])
    assert err["event"] == "error"
    assert err["detail"] == "image_rejected"
    assert err["reason"] == "too_many_images"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_message_rejected_on_oversize_payload(tmp_path: Path) -> None:
    """A 9 MiB image payload is rejected with reason ``size``."""
    channel = _make_channel()
    mock_conn = AsyncMock()
    oversized = b"x" * (9 * 1024 * 1024)  # > 8 MB WS limit
    envelope = {
        "type": "message",
        "chat_id": "abc123",
        "content": "big",
        "media": [{"data_url": _data_url("image/png", oversized)}],
    }

    with patch(
        "nanobot.channels.websocket.get_media_dir", return_value=tmp_path
    ):
        await channel._dispatch_envelope(mock_conn, "client-1", envelope)

    channel._handle_message.assert_not_awaited()
    # Guard first: without an error frame ``call_args`` is None and the
    # indexing below would raise an unhelpful TypeError.
    mock_conn.send.assert_awaited()
    err = json.loads(mock_conn.send.call_args.args[0])
    assert err["detail"] == "image_rejected"
    assert err["reason"] == "size"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_message_rejected_on_non_image_mime(tmp_path: Path) -> None:
    """An ``application/pdf`` data URL is rejected with reason ``mime``."""
    channel = _make_channel()
    mock_conn = AsyncMock()
    envelope = {
        "type": "message",
        "chat_id": "abc123",
        "content": "pdf?",
        "media": [{"data_url": _data_url("application/pdf", b"%PDF-1.4")}],
    }

    with patch(
        "nanobot.channels.websocket.get_media_dir", return_value=tmp_path
    ):
        await channel._dispatch_envelope(mock_conn, "client-1", envelope)

    channel._handle_message.assert_not_awaited()
    # Guard first: without an error frame ``call_args`` is None and the
    # indexing below would raise an unhelpful TypeError.
    mock_conn.send.assert_awaited()
    err = json.loads(mock_conn.send.call_args.args[0])
    assert err["detail"] == "image_rejected"
    assert err["reason"] == "mime"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_message_rejected_on_svg_mime(tmp_path: Path) -> None:
    """SVG is explicitly rejected — XSS surface inside embedded scripts."""
    channel = _make_channel()
    mock_conn = AsyncMock()
    envelope = {
        "type": "message",
        "chat_id": "abc123",
        "content": "svg",
        "media": [{"data_url": _data_url("image/svg+xml", b"<svg/>")}],
    }

    with patch(
        "nanobot.channels.websocket.get_media_dir", return_value=tmp_path
    ):
        await channel._dispatch_envelope(mock_conn, "client-1", envelope)

    channel._handle_message.assert_not_awaited()
    # Guard first: without an error frame ``call_args`` is None and the
    # indexing below would raise an unhelpful TypeError.
    mock_conn.send.assert_awaited()
    err = json.loads(mock_conn.send.call_args.args[0])
    assert err["reason"] == "mime"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_message_rejected_on_malformed_data_url(tmp_path: Path) -> None:
    """A plain ``http://`` URL in ``data_url`` is rejected with reason ``decode``."""
    channel = _make_channel()
    mock_conn = AsyncMock()
    envelope = {
        "type": "message",
        "chat_id": "abc123",
        "content": "nope",
        "media": [{"data_url": "http://evil.example/image.png"}],
    }

    with patch(
        "nanobot.channels.websocket.get_media_dir", return_value=tmp_path
    ):
        await channel._dispatch_envelope(mock_conn, "client-1", envelope)

    channel._handle_message.assert_not_awaited()
    # Guard first: without an error frame ``call_args`` is None and the
    # indexing below would raise an unhelpful TypeError.
    mock_conn.send.assert_awaited()
    err = json.loads(mock_conn.send.call_args.args[0])
    assert err["reason"] == "decode"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_message_rejected_on_broken_base64(tmp_path: Path) -> None:
    """A data URL whose body is not valid base64 is rejected with reason ``decode``."""
    channel = _make_channel()
    mock_conn = AsyncMock()
    envelope = {
        "type": "message",
        "chat_id": "abc123",
        "content": "nope",
        "media": [{"data_url": "data:image/png;base64,not-valid-base64!!!"}],
    }

    with patch(
        "nanobot.channels.websocket.get_media_dir", return_value=tmp_path
    ):
        await channel._dispatch_envelope(mock_conn, "client-1", envelope)

    channel._handle_message.assert_not_awaited()
    # Guard first: without an error frame ``call_args`` is None and the
    # indexing below would raise an unhelpful TypeError.
    mock_conn.send.assert_awaited()
    err = json.loads(mock_conn.send.call_args.args[0])
    assert err["reason"] == "decode"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_message_rejected_when_media_item_shape_wrong(tmp_path: Path) -> None:
    """A media entry that is a bare string is rejected with reason ``malformed``."""
    channel = _make_channel()
    mock_conn = AsyncMock()
    envelope = {
        "type": "message",
        "chat_id": "abc123",
        "content": "huh",
        # Not a dict — plain string at the top level.
        "media": ["data:image/png;base64,XXXX"],
    }

    with patch(
        "nanobot.channels.websocket.get_media_dir", return_value=tmp_path
    ):
        await channel._dispatch_envelope(mock_conn, "client-1", envelope)

    channel._handle_message.assert_not_awaited()
    # Guard first: without an error frame ``call_args`` is None and the
    # indexing below would raise an unhelpful TypeError.
    mock_conn.send.assert_awaited()
    err = json.loads(mock_conn.send.call_args.args[0])
    assert err["reason"] == "malformed"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_message_rejected_when_media_field_is_not_list() -> None:
    """A ``media`` field that is a string, not a list, is rejected as ``malformed``."""
    channel = _make_channel()
    mock_conn = AsyncMock()
    envelope = {
        "type": "message",
        "chat_id": "abc123",
        "content": "huh",
        "media": "not-a-list",
    }

    await channel._dispatch_envelope(mock_conn, "client-1", envelope)

    channel._handle_message.assert_not_awaited()
    # Guard first: without an error frame ``call_args`` is None and the
    # indexing below would raise an unhelpful TypeError.
    mock_conn.send.assert_awaited()
    err = json.loads(mock_conn.send.call_args.args[0])
    assert err["detail"] == "image_rejected"
    assert err["reason"] == "malformed"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_failed_media_does_not_partially_persist(tmp_path: Path) -> None:
    """If the second image is invalid, the first must not be forwarded.

    Also: images already written in this call are cleaned up on failure, so
    a mixed-valid/invalid batch never leaves orphan files in the media dir.
    """
    channel = _make_channel()
    mock_conn = AsyncMock()
    envelope = {
        "type": "message",
        "chat_id": "abc123",
        "content": "mixed",
        "media": [
            {"data_url": _tiny_png_data_url()},
            {"data_url": _data_url("application/pdf", b"%PDF-1.4")},
        ],
    }

    with patch(
        "nanobot.channels.websocket.get_media_dir", return_value=tmp_path
    ):
        await channel._dispatch_envelope(mock_conn, "client-1", envelope)

    channel._handle_message.assert_not_awaited()
    # Guard first: without an error frame ``call_args`` is None and the
    # indexing below would raise an unhelpful TypeError.
    mock_conn.send.assert_awaited()
    err = json.loads(mock_conn.send.call_args.args[0])
    assert err["reason"] == "mime"
    # Partial-batch failures must not leak files to disk.
    leftover = [p for p in tmp_path.iterdir() if p.is_file()]
    assert leftover == [], f"orphan media after rejected batch: {leftover}"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_rejects_empty_text_without_media() -> None:
    """When no media is attached, whitespace-only content is still rejected
    (matches the existing behavior for backward compat)."""
    channel = _make_channel()
    mock_conn = AsyncMock()
    envelope = {
        "type": "message",
        "chat_id": "abc123",
        "content": "   ",
    }

    await channel._dispatch_envelope(mock_conn, "client-1", envelope)

    channel._handle_message.assert_not_awaited()
    # Guard first: without an error frame ``call_args`` is None and the
    # indexing below would raise an unhelpful TypeError.
    mock_conn.send.assert_awaited()
    err = json.loads(mock_conn.send.call_args.args[0])
    assert err["detail"] == "missing content"
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_non_string_content_still_rejected() -> None:
    """Integer ``content`` with no media is rejected as ``missing content``."""
    channel = _make_channel()
    mock_conn = AsyncMock()
    envelope = {
        "type": "message",
        "chat_id": "abc123",
        "content": 42,
    }

    await channel._dispatch_envelope(mock_conn, "client-1", envelope)

    channel._handle_message.assert_not_awaited()
    # Guard first: without an error frame ``call_args`` is None and the
    # indexing below would raise an unhelpful TypeError.
    mock_conn.send.assert_awaited()
    err = json.loads(mock_conn.send.call_args.args[0])
    assert err["detail"] == "missing content"
|