mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-30 06:45:55 +00:00
Bug 1: _drain_pending did not call extract_documents on follow-up messages arriving mid-turn. Documents attached to queued messages were silently dropped because _build_user_content only handles images. Fix: call extract_documents before _build_user_content in _drain_pending. Bug 2: extract_documents read the entire file into memory (up to 50 MB) just to check 16 bytes of magic header for MIME detection. Fix: read only the first 16 bytes via open()+read(16) instead of Path.read_bytes(). Added regression tests for both bugs. Made-with: Cursor
497 lines
16 KiB
Python
497 lines
16 KiB
Python
"""Tests for API file upload functionality (JSON base64 + multipart)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
from io import BytesIO
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
|
|
from nanobot.api.server import (
|
|
_FileSizeExceeded,
|
|
_parse_json_content,
|
|
_save_base64_data_url,
|
|
create_app,
|
|
)
|
|
from nanobot.utils.document import extract_documents
|
|
|
|
try:
|
|
from aiohttp.test_utils import TestClient, TestServer
|
|
|
|
HAS_AIOHTTP = True
|
|
except ImportError:
|
|
HAS_AIOHTTP = False
|
|
|
|
pytest_plugins = ("pytest_asyncio",)
|
|
|
|
|
|
def _make_mock_agent(response_text: str = "mock response") -> MagicMock:
|
|
agent = MagicMock()
|
|
agent.process_direct = AsyncMock(return_value=response_text)
|
|
agent._connect_mcp = AsyncMock()
|
|
agent.close_mcp = AsyncMock()
|
|
return agent
|
|
|
|
|
|
@pytest.fixture
def mock_agent():
    """Default mocked agent shared by most tests in this module."""
    agent = _make_mock_agent()
    return agent
|
|
|
|
|
|
@pytest.fixture
def app(mock_agent):
    """aiohttp application wired to the default mocked agent."""
    application = create_app(mock_agent, model_name="test-model", request_timeout=10.0)
    return application
|
|
|
|
|
|
@pytest_asyncio.fixture
async def aiohttp_client():
    """Yield a factory that starts TestClients and closes them all on teardown."""
    started: list[TestClient] = []

    async def _factory(app):
        test_client = TestClient(TestServer(app))
        await test_client.start_server()
        started.append(test_client)
        return test_client

    try:
        yield _factory
    finally:
        for test_client in started:
            await test_client.close()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helper function tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_save_base64_data_url_saves_png(tmp_path) -> None:
    """A PNG data URL is written to disk and named with a .png extension."""
    payload = b"fake png data"
    url = f"data:image/png;base64,{base64.b64encode(payload).decode()}"
    saved = _save_base64_data_url(url, tmp_path)
    assert saved is not None
    assert saved.endswith(".png")
    # Strip the tmp_path prefix so we can re-resolve the file under tmp_path.
    relative = saved.replace(str(tmp_path) + "/", "")
    assert (tmp_path / relative).read_bytes() == payload
|
|
|
|
|
|
def test_save_base64_data_url_handles_invalid_b64(tmp_path) -> None:
    """Malformed base64 payloads yield None rather than raising."""
    assert _save_base64_data_url("data:image/png;base64,not-valid-base64!!!", tmp_path) is None
|
|
|
|
|
|
def test_save_base64_data_url_handles_unknown_mime(tmp_path) -> None:
    """An unrecognized MIME type falls back to the .bin extension."""
    encoded = base64.b64encode(b"some data").decode()
    saved = _save_base64_data_url(f"data:unknown/type;base64,{encoded}", tmp_path)
    assert saved is not None
    assert saved.endswith(".bin")
|
|
|
|
|
|
def test_save_base64_data_url_rejects_oversized_payload(tmp_path) -> None:
    """Base64 uploads share the same per-file size cap as multipart uploads."""
    # 11 MB of raw data — just over the 10 MB limit.
    oversized = base64.b64encode(b"x" * (11 * 1024 * 1024)).decode()

    with pytest.raises(_FileSizeExceeded, match="10MB limit"):
        _save_base64_data_url(f"data:image/png;base64,{oversized}", tmp_path)
|
|
|
|
|
|
def test_parse_json_content_extracts_text_and_media(tmp_path) -> None:
    """JSON with text + base64 image saves the image and returns its path."""
    encoded = base64.b64encode(b"img").decode()
    request_body = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "describe this"},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded}"}},
                ],
            }
        ]
    }
    import os

    # _parse_json_content writes media relative to the CWD, so run from tmp_path.
    previous_cwd = os.getcwd()
    os.chdir(tmp_path)

    try:
        text, media_paths = _parse_json_content(request_body)
        assert text == "describe this"
        assert len(media_paths) == 1
    finally:
        os.chdir(previous_cwd)
|
|
|
|
|
|
def test_parse_json_content_plain_text_only() -> None:
    """A bare string content produces the text and no media paths."""
    parsed_text, parsed_media = _parse_json_content(
        {"messages": [{"role": "user", "content": "hello"}]}
    )
    assert parsed_text == "hello"
    assert parsed_media == []
|
|
|
|
|
|
def test_parse_json_content_validates_single_message() -> None:
    """Sending more than one message is rejected with ValueError."""
    two_messages = {
        "messages": [
            {"role": "user", "content": "first"},
            {"role": "user", "content": "second"},
        ]
    }
    with pytest.raises(ValueError, match="single user message"):
        _parse_json_content(two_messages)
|
|
|
|
|
|
def test_parse_json_content_validates_user_role() -> None:
    """A message whose role is not 'user' is rejected with ValueError."""
    with pytest.raises(ValueError, match="single user message"):
        _parse_json_content({"messages": [{"role": "system", "content": "you are a bot"}]})
|
|
|
|
|
|
def test_parse_json_content_rejects_oversized_base64_file(tmp_path) -> None:
    """An oversized JSON data URL fails before anything is written to disk."""
    # 11 MB of raw data — just over the 10 MB limit.
    oversized = base64.b64encode(b"x" * (11 * 1024 * 1024)).decode()
    request_body = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "describe"},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{oversized}"}},
                ],
            }
        ]
    }
    import os

    # Media paths are resolved relative to the CWD, so run from tmp_path.
    previous_cwd = os.getcwd()
    os.chdir(tmp_path)

    try:
        with pytest.raises(_FileSizeExceeded, match="10MB limit"):
            _parse_json_content(request_body)
    finally:
        os.chdir(previous_cwd)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Multipart upload tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_multipart_upload_saves_file(aiohttp_client, mock_agent, tmp_path) -> None:
    """A multipart file lands in the media dir and its path reaches process_direct."""
    import os

    previous_cwd = os.getcwd()
    os.chdir(tmp_path)

    try:
        client = await aiohttp_client(create_app(mock_agent, model_name="m"))

        upload = BytesIO(b"test file content")

        resp = await client.post(
            "/v1/chat/completions",
            data={"message": "analyze this", "files": upload},
        )
        assert resp.status == 200
        kwargs = mock_agent.process_direct.call_args.kwargs
        assert kwargs["content"] == "analyze this"
        assert len(kwargs.get("media") or []) == 1
    finally:
        os.chdir(previous_cwd)
|
|
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_multipart_multiple_files(aiohttp_client, mock_agent, tmp_path) -> None:
    """Multipart upload with files completes successfully end to end."""
    import os

    previous_cwd = os.getcwd()
    os.chdir(tmp_path)

    try:
        client = await aiohttp_client(create_app(mock_agent, model_name="m"))

        # aiohttp's test client has limited multipart support, so this only
        # exercises the basic request/response flow with a single file.
        upload = BytesIO(b"test content")

        resp = await client.post(
            "/v1/chat/completions",
            data={"message": "analyze", "files": upload},
        )
        assert resp.status == 200
    finally:
        os.chdir(previous_cwd)
|
|
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_multipart_file_size_limit(aiohttp_client, mock_agent, tmp_path) -> None:
    """A file over MAX_FILE_SIZE is rejected with HTTP 413."""
    import os

    previous_cwd = os.getcwd()
    os.chdir(tmp_path)

    try:
        client = await aiohttp_client(create_app(mock_agent, model_name="m"))

        # 11 MB payload — just over the 10 MB cap.
        upload = BytesIO(b"x" * (11 * 1024 * 1024))

        resp = await client.post(
            "/v1/chat/completions",
            data={"message": "analyze", "files": upload},
        )
        assert resp.status == 413
    finally:
        os.chdir(previous_cwd)
|
|
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_multipart_defaults_text_when_missing(aiohttp_client, mock_agent, tmp_path) -> None:
    """Omitting the message field falls back to the default prompt text."""
    import os

    previous_cwd = os.getcwd()
    os.chdir(tmp_path)

    try:
        client = await aiohttp_client(create_app(mock_agent, model_name="m"))

        upload = BytesIO(b"content")

        resp = await client.post(
            "/v1/chat/completions",
            data={"files": upload},
        )
        assert resp.status == 200
        kwargs = mock_agent.process_direct.call_args.kwargs
        # Default prompt ("please analyze the uploaded file", in Chinese).
        assert kwargs["content"] == "请分析上传的文件"
    finally:
        os.chdir(previous_cwd)
|
|
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_multipart_with_session_id(aiohttp_client, mock_agent, tmp_path) -> None:
    """A session_id form field is prefixed into the agent session key."""
    import os

    previous_cwd = os.getcwd()
    os.chdir(tmp_path)

    try:
        client = await aiohttp_client(create_app(mock_agent, model_name="m"))

        upload = BytesIO(b"content")

        resp = await client.post(
            "/v1/chat/completions",
            data={"message": "hello", "session_id": "my-session", "files": upload},
        )
        assert resp.status == 200
        kwargs = mock_agent.process_direct.call_args.kwargs
        assert kwargs["session_key"] == "api:my-session"
    finally:
        os.chdir(previous_cwd)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Backward compatibility tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_plain_text_backward_compat(aiohttp_client, mock_agent) -> None:
    """A plain-text JSON request (no media) behaves exactly as before."""
    client = await aiohttp_client(create_app(mock_agent, model_name="m"))
    resp = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello world"}]},
    )
    assert resp.status == 200
    payload = await resp.json()
    assert payload["choices"][0]["message"]["content"] == "mock response"
    kwargs = mock_agent.process_direct.call_args.kwargs
    assert kwargs["content"] == "hello world"
    assert kwargs.get("media") is None
|
|
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_json_base64_image_upload(aiohttp_client, mock_agent, tmp_path) -> None:
    """A JSON request carrying a base64 data URL saves the file and passes its path."""
    import os

    previous_cwd = os.getcwd()
    os.chdir(tmp_path)

    try:
        client = await aiohttp_client(create_app(mock_agent, model_name="m"))

        # Valid base64 for a tiny 1x1 transparent PNG.
        tiny_png_b64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg=="

        resp = await client.post(
            "/v1/chat/completions",
            json={
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": "what is this"},
                            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{tiny_png_b64}"}},
                        ],
                    }
                ]
            },
        )
        assert resp.status == 200
        kwargs = mock_agent.process_direct.call_args.kwargs
        assert kwargs["content"] == "what is this"
        assert len(kwargs.get("media", [])) == 1
    finally:
        os.chdir(previous_cwd)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# extract_documents tests (now in nanobot.utils.document)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_extract_documents_separates_images_from_docs(tmp_path) -> None:
    """Images are kept as media; document text is folded into the prompt."""
    from docx import Document

    image_file = tmp_path / "chart.png"
    image_file.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)

    report = Document()
    report.add_paragraph("Quarterly revenue is $5M")
    report_path = tmp_path / "report.docx"
    report.save(report_path)

    text, image_paths = extract_documents("summarize", [str(image_file), str(report_path)])
    assert image_paths == [str(image_file)]
    assert "Quarterly revenue" in text
    assert "summarize" in text
|
|
|
|
|
|
def test_extract_documents_skips_extraction_errors(tmp_path, monkeypatch) -> None:
    """An extractor error marker must not leak into the user-facing text."""
    broken = tmp_path / "broken.docx"
    broken.write_text("not a docx", encoding="utf-8")

    import nanobot.utils.document as _doc

    # Force the extractor to return its error-marker string.
    monkeypatch.setattr(
        _doc,
        "extract_text",
        lambda _path: "[error: failed to extract DOCX: boom]",
    )

    text, image_paths = extract_documents("hello", [str(broken)])
    assert text == "hello"
    assert image_paths == []
|
|
|
|
|
|
def test_extract_documents_images_only(tmp_path) -> None:
    """With only image inputs the text is untouched and every path is kept."""
    image_file = tmp_path / "a.png"
    image_file.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)
    text, image_paths = extract_documents("describe", [str(image_file)])
    assert text == "describe"
    assert len(image_paths) == 1
|
|
|
|
|
|
def test_extract_documents_skips_oversized_files(tmp_path) -> None:
    """Files larger than max_file_size are silently dropped."""
    oversized = tmp_path / "huge.txt"
    oversized.write_bytes(b"x" * 200)  # 200 bytes against a 100-byte cap

    text, image_paths = extract_documents("hello", [str(oversized)], max_file_size=100)
    assert text == "hello"
    assert image_paths == []
|
|
|
|
|
|
def test_extract_documents_does_not_read_full_file_for_mime(tmp_path) -> None:
    """MIME sniffing must read only a small header, never the whole file."""
    from pathlib import Path as _Path

    sample = tmp_path / "big.txt"
    sample.write_bytes(b"hello world " * 100_000)  # ~1.2 MB

    original_read_bytes = _Path.read_bytes
    read_sizes: list[int] = []

    def _spy_read_bytes(self):
        # Record how much each read_bytes() call actually pulled in.
        payload = original_read_bytes(self)
        read_sizes.append(len(payload))
        return payload

    import unittest.mock

    with unittest.mock.patch.object(_Path, "read_bytes", _spy_read_bytes):
        extract_documents("test", [str(sample)])

    # A full read for MIME detection would show up as a >1MB entry here;
    # after the fix only a small header is ever read via read_bytes.
    assert all(size <= 4096 for size in read_sizes), (
        f"extract_documents read full file for MIME detection: sizes={read_sizes}"
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# DOCX upload test — API saves file, loop layer extracts text
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_docx_upload_passes_media_path(aiohttp_client, tmp_path) -> None:
    """An uploaded DOCX is saved and its path forwarded as media.

    Text extraction happens later, in AgentLoop._process_message."""
    agent = _make_mock_agent("report summary")
    import os

    previous_cwd = os.getcwd()
    os.chdir(tmp_path)

    try:
        client = await aiohttp_client(create_app(agent, model_name="m"))

        from docx import Document

        report = Document()
        report.add_paragraph("Total revenue: $5,000,000")
        buffer = BytesIO()
        report.save(buffer)

        import aiohttp

        form = aiohttp.FormData()
        form.add_field("message", "summarize the report")
        form.add_field(
            "files",
            buffer.getvalue(),
            filename="report.docx",
            content_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        )

        resp = await client.post("/v1/chat/completions", data=form)
        assert resp.status == 200
        kwargs = agent.process_direct.call_args.kwargs
        assert kwargs["content"] == "summarize the report"
        media = kwargs.get("media", [])
        assert len(media) == 1
        assert "report.docx" in media[0]
    finally:
        os.chdir(previous_cwd)