feat(api): support file uploads via JSON base64 and multipart/form-data

This commit is contained in:
dengjingren 2026-04-08 15:28:36 +08:00
parent e21ba5f667
commit a068df5a79
10 changed files with 1188 additions and 65 deletions

View File

@ -1757,6 +1757,7 @@ By default, the API binds to `127.0.0.1:8900`. You can change this in `config.js
- Single-message input: each request must contain exactly one `user` message
- Fixed model: omit `model`, or pass the same model shown by `/v1/models`
- No streaming: `stream=true` is not supported
- **File uploads**: supports images, PDF, Word (.docx), Excel (.xlsx), PowerPoint (.pptx) via JSON base64 or `multipart/form-data` (max 10MB per file)
### Endpoints
@ -1775,6 +1776,44 @@ curl http://127.0.0.1:8900/v1/chat/completions \
}'
```
### File Upload (JSON base64)
Send images inline using the OpenAI multimodal content format:
```bash
curl http://127.0.0.1:8900/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": [
{"type": "text", "text": "Describe this image"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}}
]}]
}'
```
### File Upload (multipart/form-data)
Upload any supported file type (images, PDF, Word, Excel, PPT) via multipart:
```bash
# Single file
curl http://127.0.0.1:8900/v1/chat/completions \
-F "message=Summarize this report" \
-F "files=@report.docx"
# Multiple files with session isolation
curl http://127.0.0.1:8900/v1/chat/completions \
-F "message=Compare these files" \
-F "files=@chart.png" \
-F "files=@data.xlsx" \
-F "session_id=my-session"
```
Supported file types:
- **Images**: PNG, JPEG, GIF, WebP (sent to AI as base64 for vision analysis)
- **Documents**: PDF, Word (.docx), Excel (.xlsx), PowerPoint (.pptx) (text extracted and sent to AI)
- **Text**: TXT, Markdown, CSV, JSON, etc. (read directly)
### Python (`requests`)
```python

View File

@ -144,31 +144,56 @@ class ContextBuilder:
messages.append({"role": current_role, "content": merged})
return messages
def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
"""Build user message content with optional base64-encoded images."""
def _build_user_content(
self, text: str, media: list[str] | None
) -> str | list[dict[str, Any]]:
"""Build user message content with optional media.
Images are converted to base64 vision blocks.
Documents (PDF, Word, Excel, PPT) have their text extracted and appended.
"""
if not media:
return text
images = []
images: list[dict[str, Any]] = []
doc_texts: list[str] = []
for path in media:
p = Path(path)
if not p.is_file():
continue
raw = p.read_bytes()
# Detect real MIME type from magic bytes; fallback to filename guess
mime = detect_image_mime(raw) or mimetypes.guess_type(path)[0]
if not mime or not mime.startswith("image/"):
continue
b64 = base64.b64encode(raw).decode()
images.append({
"type": "image_url",
"image_url": {"url": f"data:{mime};base64,{b64}"},
"_meta": {"path": str(p)},
})
if not images:
if mime and mime.startswith("image/"):
b64 = base64.b64encode(raw).decode()
images.append({
"type": "image_url",
"image_url": {"url": f"data:{mime};base64,{b64}"},
"_meta": {"path": str(p)},
})
else:
# Try document text extraction
from nanobot.utils.document import extract_text
extracted = extract_text(p)
if extracted and not extracted.startswith("Error"):
doc_texts.append(f"[File: {p.name}]\n{extracted}")
# Build final content
parts: list[dict[str, Any]] = []
parts.extend(images)
combined_text = text
if doc_texts:
combined_text = text + "\n\n" + "\n\n".join(doc_texts)
if images:
parts.append({"type": "text", "text": combined_text})
return parts
elif doc_texts:
return combined_text
else:
return text
return images + [{"type": "text", "text": text}]
def add_tool_result(
self, messages: list[dict[str, Any]],

View File

@ -765,13 +765,17 @@ class AgentLoop:
session_key: str = "cli:direct",
channel: str = "cli",
chat_id: str = "direct",
media: list[str] | None = None,
on_progress: Callable[[str], Awaitable[None]] | None = None,
on_stream: Callable[[str], Awaitable[None]] | None = None,
on_stream_end: Callable[..., Awaitable[None]] | None = None,
) -> OutboundMessage | None:
"""Process a message directly and return the outbound payload."""
await self._connect_mcp()
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
msg = InboundMessage(
channel=channel, sender_id="user", chat_id=chat_id,
content=content, media=media or [],
)
return await self._process_message(
msg, session_key=session_key, on_progress=on_progress,
on_stream=on_stream, on_stream_end=on_stream_end,

View File

@ -7,15 +7,28 @@ All requests route to a single persistent API session.
from __future__ import annotations
import asyncio
import base64
import mimetypes
import re
import time
import uuid
from pathlib import Path
from typing import Any
from aiohttp import web
from loguru import logger
from nanobot.config.paths import get_media_dir
from nanobot.utils.helpers import safe_filename
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB
_DATA_URL_RE = re.compile(r"^data:([^;]+);base64,(.+)$", re.DOTALL)
class _FileSizeExceeded(Exception):
"""Raised when an uploaded file exceeds the size limit."""
API_SESSION_KEY = "api:default"
API_CHAT_ID = "default"
@ -57,48 +70,134 @@ def _response_text(value: Any) -> str:
return str(value)
# ---------------------------------------------------------------------------
# Upload helpers
# ---------------------------------------------------------------------------
def _save_base64_data_url(data_url: str, media_dir: Path) -> str | None:
"""Decode a data:...;base64,... URL and save to disk."""
m = _DATA_URL_RE.match(data_url)
if not m:
return None
mime_type, b64_payload = m.group(1), m.group(2)
try:
raw = base64.b64decode(b64_payload)
except Exception:
return None
ext = mimetypes.guess_extension(mime_type) or ".bin"
filename = f"{uuid.uuid4().hex[:12]}{ext}"
dest = media_dir / safe_filename(filename)
dest.write_bytes(raw)
return str(dest)
def _parse_json_content(body: dict) -> tuple[str, list[str]]:
"""Parse JSON request body. Returns (text, media_paths)."""
messages = body.get("messages")
if not isinstance(messages, list) or len(messages) != 1:
raise ValueError("Only a single user message is supported")
message = messages[0]
if not isinstance(message, dict) or message.get("role") != "user":
raise ValueError("Only a single user message is supported")
user_content = message.get("content", "")
media_dir = get_media_dir("api")
media_paths: list[str] = []
if isinstance(user_content, list):
text_parts: list[str] = []
for part in user_content:
if not isinstance(part, dict):
continue
if part.get("type") == "text":
text_parts.append(part.get("text", ""))
elif part.get("type") == "image_url":
url = part.get("image_url", {}).get("url", "")
if url.startswith("data:"):
saved = _save_base64_data_url(url, media_dir)
if saved:
media_paths.append(saved)
text = " ".join(text_parts)
elif isinstance(user_content, str):
text = user_content
else:
raise ValueError("Invalid content format")
return text, media_paths
async def _parse_multipart(request: web.Request) -> tuple[str, list[str], str | None]:
"""Parse multipart/form-data. Returns (text, media_paths, session_id)."""
media_dir = get_media_dir("api")
reader = await request.multipart()
text = ""
session_id = None
media_paths: list[str] = []
while True:
part = await reader.next()
if part is None:
break
if part.name == "message":
text = (await part.read()).decode("utf-8")
elif part.name == "session_id":
session_id = (await part.read()).decode("utf-8").strip()
elif part.name == "files":
raw = await part.read()
if len(raw) > MAX_FILE_SIZE:
raise _FileSizeExceeded(f"File '{part.filename}' exceeds {MAX_FILE_SIZE // (1024*1024)}MB limit")
filename = safe_filename(part.filename or f"{uuid.uuid4().hex[:12]}.bin")
dest = media_dir / filename
dest.write_bytes(raw)
media_paths.append(str(dest))
if not text:
text = "请分析上传的文件"
return text, media_paths, session_id
# ---------------------------------------------------------------------------
# Route handlers
# ---------------------------------------------------------------------------
async def handle_chat_completions(request: web.Request) -> web.Response:
"""POST /v1/chat/completions"""
# --- Parse body ---
try:
body = await request.json()
except Exception:
return _error_json(400, "Invalid JSON body")
messages = body.get("messages")
if not isinstance(messages, list) or len(messages) != 1:
return _error_json(400, "Only a single user message is supported")
# Stream not yet supported
if body.get("stream", False):
return _error_json(400, "stream=true is not supported yet. Set stream=false or omit it.")
message = messages[0]
if not isinstance(message, dict) or message.get("role") != "user":
return _error_json(400, "Only a single user message is supported")
user_content = message.get("content", "")
if isinstance(user_content, list):
# Multi-modal content array — extract text parts
user_content = " ".join(
part.get("text", "") for part in user_content if part.get("type") == "text"
)
"""POST /v1/chat/completions — supports JSON and multipart/form-data."""
content_type = request.content_type or ""
if not isinstance(content_type, str):
content_type = ""
agent_loop = request.app["agent_loop"]
timeout_s: float = request.app.get("request_timeout", 120.0)
model_name: str = request.app.get("model_name", "nanobot")
if (requested_model := body.get("model")) and requested_model != model_name:
return _error_json(400, f"Only configured model '{model_name}' is available")
session_key = f"api:{body['session_id']}" if body.get("session_id") else API_SESSION_KEY
try:
if content_type.startswith("multipart/"):
text, media_paths, session_id = await _parse_multipart(request)
else:
try:
body = await request.json()
except Exception:
return _error_json(400, "Invalid JSON body")
if body.get("stream", False):
return _error_json(400, "stream=true is not supported yet. Set stream=false or omit it.")
if (requested_model := body.get("model")) and requested_model != model_name:
return _error_json(400, f"Only configured model '{model_name}' is available")
text, media_paths = _parse_json_content(body)
session_id = body.get("session_id")
except ValueError as e:
return _error_json(400, str(e))
except _FileSizeExceeded as e:
return _error_json(413, str(e), err_type="invalid_request_error")
except Exception:
logger.exception("Error parsing upload")
return _error_json(413, "File too large or invalid upload")
session_key = f"api:{session_id}" if session_id else API_SESSION_KEY
session_locks: dict[str, asyncio.Lock] = request.app["session_locks"]
session_lock = session_locks.setdefault(session_key, asyncio.Lock())
logger.info("API request session_key={} content={}", session_key, user_content[:80])
logger.info("API request session_key={} media={} text={}", session_key, len(media_paths), text[:80])
_FALLBACK = EMPTY_FINAL_RESPONSE_MESSAGE
@ -107,7 +206,8 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
try:
response = await asyncio.wait_for(
agent_loop.process_direct(
content=user_content,
content=text,
media=media_paths if media_paths else None,
session_key=session_key,
channel="api",
chat_id=API_CHAT_ID,
@ -117,13 +217,11 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
response_text = _response_text(response)
if not response_text or not response_text.strip():
logger.warning(
"Empty response for session {}, retrying",
session_key,
)
logger.warning("Empty response for session {}, retrying", session_key)
retry_response = await asyncio.wait_for(
agent_loop.process_direct(
content=user_content,
content=text,
media=media_paths if media_paths else None,
session_key=session_key,
channel="api",
chat_id=API_CHAT_ID,
@ -132,10 +230,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
)
response_text = _response_text(retry_response)
if not response_text or not response_text.strip():
logger.warning(
"Empty response after retry for session {}, using fallback",
session_key,
)
logger.warning("Empty response after retry, using fallback")
response_text = _FALLBACK
except asyncio.TimeoutError:
@ -183,7 +278,7 @@ def create_app(agent_loop, model_name: str = "nanobot", request_timeout: float =
model_name: Model name reported in responses.
request_timeout: Per-request timeout in seconds.
"""
app = web.Application()
app = web.Application(client_max_size=20 * 1024 * 1024) # 20MB for base64 images
app["agent_loop"] = agent_loop
app["model_name"] = model_name
app["request_timeout"] = request_timeout

206
nanobot/utils/document.py Normal file
View File

@ -0,0 +1,206 @@
"""Document text extraction utilities for nanobot."""
from pathlib import Path
from loguru import logger
try:
from pypdf import PdfReader
except ImportError:
PdfReader = None # type: ignore
try:
from docx import Document as DocxDocument
except ImportError:
DocxDocument = None # type: ignore
try:
from openpyxl import load_workbook
except ImportError:
load_workbook = None # type: ignore
try:
from pptx import Presentation as PptxPresentation
except ImportError:
PptxPresentation = None # type: ignore
# Supported file extensions for text extraction
SUPPORTED_EXTENSIONS: set[str] = {
# Document formats
".pdf",
".docx",
".xlsx",
".pptx",
# Text formats
".txt",
".md",
".csv",
".json",
".xml",
".html",
".htm",
".log",
".yaml",
".yml",
".toml",
".ini",
".cfg",
# Image formats (for future OCR support)
".png",
".jpg",
".jpeg",
".gif",
".webp",
}
_MAX_TEXT_LENGTH = 200_000
def extract_text(path: Path) -> str | None:
"""Extract text from a file.
Args:
path: Path to the file.
Returns:
Extracted text as string, None for unsupported types,
or error string for failures.
"""
if not isinstance(path, Path):
path = Path(path)
if not path.exists():
return f"[error: file not found: {path}]"
ext = path.suffix.lower()
# Document formats
if ext == ".pdf":
if PdfReader is None:
return "[error: pypdf not installed]"
return _extract_pdf(path)
elif ext == ".docx":
if DocxDocument is None:
return "[error: python-docx not installed]"
return _extract_docx(path)
elif ext == ".xlsx":
if load_workbook is None:
return "[error: openpyxl not installed]"
return _extract_xlsx(path)
elif ext == ".pptx":
if PptxPresentation is None:
return "[error: python-pptx not installed]"
return _extract_pptx(path)
elif _is_text_extension(ext):
return _extract_text_file(path)
elif ext in {".png", ".jpg", ".jpeg", ".gif", ".webp"}:
# Image files - for future OCR support
return f"[image: {path.name}]"
else:
# Unsupported extension
return None
def _extract_pdf(path: Path) -> str:
"""Extract text from PDF using pypdf."""
try:
reader = PdfReader(path)
pages: list[str] = []
for i, page in enumerate(reader.pages, 1):
text = page.extract_text() or ""
pages.append(f"--- Page {i} ---\n{text}")
return _truncate("\n\n".join(pages), _MAX_TEXT_LENGTH)
except Exception as e:
logger.error("Failed to extract PDF {}: {}", path, e)
return f"[error: failed to extract PDF: {e!s}]"
def _extract_docx(path: Path) -> str:
"""Extract text from DOCX using python-docx."""
try:
doc = DocxDocument(path)
paragraphs: list[str] = [p.text for p in doc.paragraphs if p.text.strip()]
return _truncate("\n\n".join(paragraphs), _MAX_TEXT_LENGTH)
except Exception as e:
logger.error("Failed to extract DOCX {}: {}", path, e)
return f"[error: failed to extract DOCX: {e!s}]"
def _extract_xlsx(path: Path) -> str:
"""Extract text from XLSX using openpyxl."""
try:
wb = load_workbook(path, read_only=True, data_only=True)
sheets: list[str] = []
for sheet_name in wb.sheetnames:
ws = wb[sheet_name]
rows: list[str] = []
for row in ws.iter_rows(values_only=True):
row_text = "\t".join(str(cell) if cell is not None else "" for cell in row)
if row_text.strip():
rows.append(row_text)
if rows:
sheets.append(f"--- Sheet: {sheet_name} ---\n" + "\n".join(rows))
wb.close()
return _truncate("\n\n".join(sheets), _MAX_TEXT_LENGTH)
except Exception as e:
logger.error("Failed to extract XLSX {}: {}", path, e)
return f"[error: failed to extract XLSX: {e!s}]"
def _extract_pptx(path: Path) -> str:
"""Extract text from PPTX using python-pptx."""
try:
prs = PptxPresentation(path)
slides: list[str] = []
for i, slide in enumerate(prs.slides, 1):
slide_text: list[str] = []
for shape in slide.shapes:
if hasattr(shape, "text") and shape.text:
slide_text.append(shape.text)
if slide_text:
slides.append(f"--- Slide {i} ---\n" + "\n".join(slide_text))
return _truncate("\n\n".join(slides), _MAX_TEXT_LENGTH)
except Exception as e:
logger.error("Failed to extract PPTX {}: {}", path, e)
return f"[error: failed to extract PPTX: {e!s}]"
def _extract_text_file(path: Path) -> str:
"""Extract text from a plain text file."""
try:
# Try UTF-8 first, then latin-1 fallback
try:
content = path.read_text(encoding="utf-8")
except UnicodeDecodeError:
content = path.read_text(encoding="latin-1")
return _truncate(content, _MAX_TEXT_LENGTH)
except Exception as e:
logger.error("Failed to read text file {}: {}", path, e)
return f"[error: failed to read file: {e!s}]"
def _truncate(text: str, max_length: int) -> str:
"""Truncate text with a suffix indicating truncation."""
if len(text) <= max_length:
return text
return text[:max_length] + f"... (truncated, {len(text)} chars total)"
def _is_text_extension(ext: str) -> bool:
"""Check if extension is a text format."""
return ext in {
".txt",
".md",
".csv",
".json",
".xml",
".html",
".htm",
".log",
".yaml",
".yml",
".toml",
".ini",
".cfg",
}

View File

@ -50,6 +50,10 @@ dependencies = [
"tiktoken>=0.12.0,<1.0.0",
"jinja2>=3.1.0,<4.0.0",
"dulwich>=0.22.0,<1.0.0",
"pypdf>=5.0.0,<6.0.0",
"python-docx>=1.1.0,<2.0.0",
"openpyxl>=3.1.0,<4.0.0",
"python-pptx>=1.0.0,<2.0.0",
]
[project.optional-dependencies]

View File

@ -0,0 +1,379 @@
"""Tests for API file upload functionality (JSON base64 + multipart)."""
from __future__ import annotations
import base64
from io import BytesIO
from unittest.mock import AsyncMock, MagicMock
import pytest
import pytest_asyncio
from nanobot.api.server import (
API_CHAT_ID,
API_SESSION_KEY,
_parse_json_content,
_save_base64_data_url,
create_app,
)
try:
from aiohttp.test_utils import TestClient, TestServer
HAS_AIOHTTP = True
except ImportError:
HAS_AIOHTTP = False
pytest_plugins = ("pytest_asyncio",)
def _make_mock_agent(response_text: str = "mock response") -> MagicMock:
agent = MagicMock()
agent.process_direct = AsyncMock(return_value=response_text)
agent._connect_mcp = AsyncMock()
agent.close_mcp = AsyncMock()
return agent
@pytest.fixture
def mock_agent():
return _make_mock_agent()
@pytest.fixture
def app(mock_agent):
return create_app(mock_agent, model_name="test-model", request_timeout=10.0)
@pytest_asyncio.fixture
async def aiohttp_client():
clients: list[TestClient] = []
async def _make_client(app):
client = TestClient(TestServer(app))
await client.start_server()
clients.append(client)
return client
try:
yield _make_client
finally:
for client in clients:
await client.close()
# ---------------------------------------------------------------------------
# Helper function tests
# ---------------------------------------------------------------------------
def test_save_base64_data_url_saves_png(tmp_path) -> None:
"""Saving a base64 data URL creates a file with correct extension."""
b64_data = base64.b64encode(b"fake png data").decode()
data_url = f"data:image/png;base64,{b64_data}"
result = _save_base64_data_url(data_url, tmp_path)
assert result is not None
assert result.endswith(".png")
assert (tmp_path / result.replace(str(tmp_path) + "/", "")).read_bytes() == b"fake png data"
def test_save_base64_data_url_handles_invalid_b64(tmp_path) -> None:
"""Invalid base64 returns None."""
result = _save_base64_data_url("data:image/png;base64,not-valid-base64!!!", tmp_path)
assert result is None
def test_save_base64_data_url_handles_unknown_mime(tmp_path) -> None:
"""Unknown MIME type defaults to .bin."""
b64_data = base64.b64encode(b"some data").decode()
data_url = f"data:unknown/type;base64,{b64_data}"
result = _save_base64_data_url(data_url, tmp_path)
assert result is not None
assert result.endswith(".bin")
def test_parse_json_content_extracts_text_and_media(tmp_path) -> None:
"""Parse JSON with text + base64 image saves image and returns paths."""
b64_data = base64.b64encode(b"img").decode()
body = {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "describe this"},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64_data}"}},
],
}
]
}
import os
original_cwd = os.getcwd()
os.chdir(tmp_path)
try:
text, media_paths = _parse_json_content(body)
assert text == "describe this"
assert len(media_paths) == 1
finally:
os.chdir(original_cwd)
def test_parse_json_content_plain_text_only() -> None:
"""Plain text string content returns no media."""
body = {"messages": [{"role": "user", "content": "hello"}]}
text, media_paths = _parse_json_content(body)
assert text == "hello"
assert media_paths == []
def test_parse_json_content_validates_single_message() -> None:
"""Multiple messages raise ValueError."""
body = {
"messages": [
{"role": "user", "content": "first"},
{"role": "user", "content": "second"},
]
}
with pytest.raises(ValueError, match="single user message"):
_parse_json_content(body)
def test_parse_json_content_validates_user_role() -> None:
"""Non-user role raises ValueError."""
body = {"messages": [{"role": "system", "content": "you are a bot"}]}
with pytest.raises(ValueError, match="single user message"):
_parse_json_content(body)
# ---------------------------------------------------------------------------
# Multipart upload tests
# ---------------------------------------------------------------------------
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_multipart_upload_saves_file(aiohttp_client, mock_agent, tmp_path) -> None:
"""Multipart upload saves file to media dir and passes path to process_direct."""
import os
original_cwd = os.getcwd()
os.chdir(tmp_path)
try:
app = create_app(mock_agent, model_name="m")
client = await aiohttp_client(app)
file_data = b"test file content"
data = BytesIO(file_data)
resp = await client.post(
"/v1/chat/completions",
data={"message": "analyze this", "files": data},
)
assert resp.status == 200
call_kwargs = mock_agent.process_direct.call_args.kwargs
assert call_kwargs["content"] == "analyze this"
assert len(call_kwargs.get("media", [])) == 1
finally:
os.chdir(original_cwd)
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_multipart_multiple_files(aiohttp_client, mock_agent, tmp_path) -> None:
"""Multipart upload with multiple files saves all and passes paths."""
import os
original_cwd = os.getcwd()
os.chdir(tmp_path)
try:
app = create_app(mock_agent, model_name="m")
client = await aiohttp_client(app)
# Note: aiohttp test client has limited multipart support
# This test verifies the basic flow
file_data = b"test content"
data = BytesIO(file_data)
resp = await client.post(
"/v1/chat/completions",
data={"message": "analyze", "files": data},
)
assert resp.status == 200
finally:
os.chdir(original_cwd)
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_multipart_file_size_limit(aiohttp_client, mock_agent, tmp_path) -> None:
"""File exceeding MAX_FILE_SIZE returns 413."""
import os
original_cwd = os.getcwd()
os.chdir(tmp_path)
try:
app = create_app(mock_agent, model_name="m")
client = await aiohttp_client(app)
# Create a file larger than 10MB
large_data = b"x" * (11 * 1024 * 1024)
data = BytesIO(large_data)
resp = await client.post(
"/v1/chat/completions",
data={"message": "analyze", "files": data},
)
assert resp.status == 413
finally:
os.chdir(original_cwd)
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_multipart_defaults_text_when_missing(aiohttp_client, mock_agent, tmp_path) -> None:
"""Multipart without message field uses default text."""
import os
original_cwd = os.getcwd()
os.chdir(tmp_path)
try:
app = create_app(mock_agent, model_name="m")
client = await aiohttp_client(app)
file_data = b"content"
data = BytesIO(file_data)
resp = await client.post(
"/v1/chat/completions",
data={"files": data},
)
assert resp.status == 200
call_kwargs = mock_agent.process_direct.call_args.kwargs
assert call_kwargs["content"] == "请分析上传的文件"
finally:
os.chdir(original_cwd)
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_multipart_with_session_id(aiohttp_client, mock_agent, tmp_path) -> None:
"""Multipart upload with session_id uses custom session key."""
import os
original_cwd = os.getcwd()
os.chdir(tmp_path)
try:
app = create_app(mock_agent, model_name="m")
client = await aiohttp_client(app)
file_data = b"content"
data = BytesIO(file_data)
resp = await client.post(
"/v1/chat/completions",
data={"message": "hello", "session_id": "my-session", "files": data},
)
assert resp.status == 200
call_kwargs = mock_agent.process_direct.call_args.kwargs
assert call_kwargs["session_key"] == "api:my-session"
finally:
os.chdir(original_cwd)
# ---------------------------------------------------------------------------
# Backward compatibility tests
# ---------------------------------------------------------------------------
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_plain_text_backward_compat(aiohttp_client, mock_agent) -> None:
"""Plain text JSON request (no media) works as before."""
app = create_app(mock_agent, model_name="m")
client = await aiohttp_client(app)
resp = await client.post(
"/v1/chat/completions",
json={"messages": [{"role": "user", "content": "hello world"}]},
)
assert resp.status == 200
body = await resp.json()
assert body["choices"][0]["message"]["content"] == "mock response"
call_kwargs = mock_agent.process_direct.call_args.kwargs
assert call_kwargs["content"] == "hello world"
assert call_kwargs.get("media") is None
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_json_base64_image_upload(aiohttp_client, mock_agent, tmp_path) -> None:
"""JSON request with base64 data URL saves file and passes path."""
import os
original_cwd = os.getcwd()
os.chdir(tmp_path)
try:
app = create_app(mock_agent, model_name="m")
client = await aiohttp_client(app)
# Use valid base64 for a tiny PNG (1x1 transparent pixel)
tiny_png_b64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg=="
resp = await client.post(
"/v1/chat/completions",
json={
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "what is this"},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{tiny_png_b64}"}},
],
}
]
},
)
assert resp.status == 200
call_kwargs = mock_agent.process_direct.call_args.kwargs
assert call_kwargs["content"] == "what is this"
assert len(call_kwargs.get("media", [])) == 1
finally:
os.chdir(original_cwd)
# ---------------------------------------------------------------------------
# DOCX document extraction tests
# ---------------------------------------------------------------------------
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_docx_upload_extracted_and_sent(aiohttp_client, tmp_path) -> None:
"""Uploaded DOCX should have its text extracted before being sent to AI."""
from docx import Document
agent = _make_mock_agent("This report shows $5M revenue")
import os
original_cwd = os.getcwd()
os.chdir(tmp_path)
try:
app = create_app(agent, model_name="m")
client = await aiohttp_client(app)
doc = Document()
doc.add_heading("Q1 Report", level=1)
doc.add_paragraph("Total revenue: $5,000,000")
buf = BytesIO()
doc.save(buf)
docx_bytes = buf.getvalue()
import aiohttp
data = aiohttp.FormData()
data.add_field("message", "summarize the report")
data.add_field("files", docx_bytes, filename="report.docx",
content_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document")
resp = await client.post("/v1/chat/completions", data=data)
assert resp.status == 200
call_kwargs = agent.process_direct.call_args.kwargs
media = call_kwargs.get("media", [])
assert len(media) == 1
assert "report.docx" in media[0]
finally:
os.chdir(original_cwd)

View File

@ -0,0 +1,66 @@
"""Tests for context builder document handling."""
from __future__ import annotations
import pytest
from pathlib import Path
from nanobot.agent.context import ContextBuilder
def _make_builder(tmp_path: Path) -> ContextBuilder:
"""Create a minimal ContextBuilder for testing."""
return ContextBuilder(workspace=tmp_path, timezone="UTC")
def test_build_user_content_with_no_media_returns_string(tmp_path: Path) -> None:
builder = _make_builder(tmp_path)
result = builder._build_user_content("hello", None)
assert result == "hello"
def test_build_user_content_with_image_returns_list(tmp_path: Path) -> None:
"""Image files should produce base64 content blocks."""
builder = _make_builder(tmp_path)
png = tmp_path / "test.png"
png.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)
result = builder._build_user_content("describe this", [str(png)])
assert isinstance(result, list)
types = [b["type"] for b in result]
assert "image_url" in types
assert "text" in types
def test_build_user_content_with_docx_includes_extracted_text(tmp_path: Path) -> None:
"""Document files should have their text extracted and included."""
from docx import Document
doc = Document()
doc.add_paragraph("Quarterly revenue is $5M")
docx_path = tmp_path / "report.docx"
doc.save(docx_path)
builder = _make_builder(tmp_path)
result = builder._build_user_content("summarize this", [str(docx_path)])
assert isinstance(result, str)
assert "Quarterly revenue" in result
def test_build_user_content_mixed_image_and_document(tmp_path: Path) -> None:
"""Mix of images and documents: images as base64, docs as text."""
from docx import Document
png = tmp_path / "chart.png"
png.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)
doc = Document()
doc.add_paragraph("Report text here")
docx = tmp_path / "report.docx"
doc.save(docx)
builder = _make_builder(tmp_path)
result = builder._build_user_content("analyze both", [str(png), str(docx)])
assert isinstance(result, list)
assert any(b["type"] == "image_url" for b in result)
text_parts = [b.get("text", "") for b in result if b.get("type") == "text"]
assert any("Report text here" in t for t in text_parts)

View File

@ -0,0 +1,276 @@
"""Tests for document text extraction utilities."""
import io
from pathlib import Path
import pytest
from nanobot.utils.document import (
SUPPORTED_EXTENSIONS,
_is_text_extension,
extract_text,
)
class TestSupportedExtensions:
"""Test the SUPPORTED_EXTENSIONS constant."""
def test_supported_extensions_include_common_formats(self):
"""Test that common document formats are included."""
# Document formats
assert ".pdf" in SUPPORTED_EXTENSIONS
assert ".docx" in SUPPORTED_EXTENSIONS
assert ".xlsx" in SUPPORTED_EXTENSIONS
assert ".pptx" in SUPPORTED_EXTENSIONS
# Text formats
assert ".txt" in SUPPORTED_EXTENSIONS
assert ".md" in SUPPORTED_EXTENSIONS
assert ".csv" in SUPPORTED_EXTENSIONS
assert ".json" in SUPPORTED_EXTENSIONS
assert ".yaml" in SUPPORTED_EXTENSIONS
assert ".yml" in SUPPORTED_EXTENSIONS
# Image formats
assert ".png" in SUPPORTED_EXTENSIONS
assert ".jpg" in SUPPORTED_EXTENSIONS
assert ".jpeg" in SUPPORTED_EXTENSIONS
class TestExtractText:
"""Test the extract_text function."""
def test_extract_text_unsupported_returns_none(self, tmp_path: Path):
"""Test that unsupported file types return None."""
unsupported_file = tmp_path / "file.xyz"
unsupported_file.write_text("content")
result = extract_text(unsupported_file)
assert result is None
def test_extract_text_file_not_found(self, tmp_path: Path):
"""Test that non-existent files return error string."""
missing_file = tmp_path / "nonexistent.txt"
result = extract_text(missing_file)
assert result is not None
assert "[error: file not found:" in result
def test_extract_text_txt_file(self, tmp_path: Path):
"""Test extracting text from a .txt file."""
txt_file = tmp_path / "test.txt"
content = "Hello, world!\nThis is a test."
txt_file.write_text(content, encoding="utf-8")
result = extract_text(txt_file)
assert result == content
def test_extract_text_txt_file_with_truncation(self, tmp_path: Path):
"""Test that large text files are truncated."""
txt_file = tmp_path / "large.txt"
# Create content larger than _MAX_TEXT_LENGTH
content = "x" * 300_000
txt_file.write_text(content, encoding="utf-8")
result = extract_text(txt_file)
assert len(result) < 300_000
assert "(truncated," in result
assert "chars total)" in result
def test_extract_text_md_file(self, tmp_path: Path):
"""Test extracting text from a .md file."""
md_file = tmp_path / "test.md"
content = "# Header\n\nSome markdown content."
md_file.write_text(content, encoding="utf-8")
result = extract_text(md_file)
assert result == content
def test_extract_text_csv_file(self, tmp_path: Path):
"""Test extracting text from a .csv file."""
csv_file = tmp_path / "test.csv"
content = "name,age\nAlice,30\nBob,25"
csv_file.write_text(content, encoding="utf-8")
result = extract_text(csv_file)
assert result == content
def test_extract_text_json_file(self, tmp_path: Path):
"""Test extracting text from a .json file."""
json_file = tmp_path / "test.json"
content = '{"key": "value", "number": 42}'
json_file.write_text(content, encoding="utf-8")
result = extract_text(json_file)
assert result == content
def test_extract_text_xlsx(self, tmp_path: Path):
"""Test extracting text from an .xlsx file."""
from openpyxl import Workbook
xlsx_file = tmp_path / "test.xlsx"
wb = Workbook()
ws = wb.active
ws.title = "Sheet1"
ws["A1"] = "Name"
ws["B1"] = "Age"
ws["A2"] = "Alice"
ws["B2"] = 30
ws["A3"] = "Bob"
ws["B3"] = 25
# Add a second sheet
ws2 = wb.create_sheet("Sheet2")
ws2["A1"] = "Product"
ws2["B1"] = "Price"
ws2["A2"] = "Widget"
ws2["B2"] = 9.99
wb.save(xlsx_file)
wb.close()
result = extract_text(xlsx_file)
assert result is not None
assert "--- Sheet: Sheet1 ---" in result
assert "--- Sheet: Sheet2 ---" in result
assert "Alice" in result
assert "Bob" in result
assert "Widget" in result
assert "9.99" in result
def test_extract_text_xlsx_empty_sheet(self, tmp_path: Path):
"""Test extracting text from an .xlsx file with empty sheets."""
from openpyxl import Workbook
xlsx_file = tmp_path / "empty.xlsx"
wb = Workbook()
# Clear the default sheet
wb.remove(wb.active)
# Add an empty sheet
wb.create_sheet("EmptySheet")
wb.save(xlsx_file)
wb.close()
result = extract_text(xlsx_file)
# Empty sheets should return empty string or header only
assert result == "--- Sheet: EmptySheet ---" or result == ""
def test_extract_text_docx(self, tmp_path: Path):
"""Test extracting text from a .docx file."""
from docx import Document
docx_file = tmp_path / "test.docx"
doc = Document()
doc.add_heading("Test Document", 0)
doc.add_paragraph("This is paragraph one.")
doc.add_paragraph("This is paragraph two.")
doc.save(docx_file)
result = extract_text(docx_file)
assert result is not None
assert "Test Document" in result
assert "This is paragraph one." in result
assert "This is paragraph two." in result
def test_extract_text_docx_empty(self, tmp_path: Path):
"""Test extracting text from an empty .docx file."""
from docx import Document
docx_file = tmp_path / "empty.docx"
doc = Document()
doc.save(docx_file)
result = extract_text(docx_file)
assert result == ""
def test_extract_text_pptx(self, tmp_path: Path):
"""Test extracting text from a .pptx file."""
from pptx import Presentation
pptx_file = tmp_path / "test.pptx"
prs = Presentation()
# Slide 1
slide1 = prs.slides.add_slide(prs.slide_layouts[0])
for shape in slide1.shapes:
if hasattr(shape, "text"):
shape.text = "First Slide Title"
# Slide 2
slide2 = prs.slides.add_slide(prs.slide_layouts[5])
left = top = width = height = 1000000
textbox = slide2.shapes.add_textbox(left, top, width, height)
text_frame = textbox.text_frame
text_frame.text = "Bullet point content"
prs.save(pptx_file)
result = extract_text(pptx_file)
assert result is not None
assert "--- Slide 1 ---" in result
assert "--- Slide 2 ---" in result
# Text content may vary depending on PowerPoint layout defaults
assert len(result) > 0
def test_extract_text_pdf_not_found(self, tmp_path: Path):
"""Test that missing PDF files return error string."""
missing_pdf = tmp_path / "nonexistent.pdf"
result = extract_text(missing_pdf)
assert result is not None
assert "[error: file not found:" in result
def test_extract_text_image_files(self, tmp_path: Path):
"""Test that image files return placeholder text."""
# Create a minimal PNG file (1x1 pixel)
png_file = tmp_path / "test.png"
# Minimal valid PNG: 8-byte signature + IHDR + IDAT + IEND
png_data = (
b"\x89PNG\r\n\x1a\n"
b"\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01"
b"\x08\x02\x00\x00\x00\x90wS\xde"
b"\x00\x00\x00\x0cIDATx\x9cc\x00\x01\x00\x00\x05\x00\x01"
b"\r\n-\xb4\x00\x00\x00\x00IEND\xaeB`\x82"
)
png_file.write_bytes(png_data)
result = extract_text(png_file)
assert result is not None
assert "[image:" in result
assert "test.png" in result
class TestIsTextExtension:
"""Test the _is_text_extension helper."""
def test_text_extensions_return_true(self):
"""Test that known text extensions return True."""
assert _is_text_extension(".txt") is True
assert _is_text_extension(".md") is True
assert _is_text_extension(".csv") is True
assert _is_text_extension(".json") is True
assert _is_text_extension(".yaml") is True
assert _is_text_extension(".yml") is True
assert _is_text_extension(".xml") is True
assert _is_text_extension(".html") is True
assert _is_text_extension(".htm") is True
def test_non_text_extensions_return_false(self):
"""Test that non-text extensions return False."""
assert _is_text_extension(".pdf") is False
assert _is_text_extension(".docx") is False
assert _is_text_extension(".xlsx") is False
assert _is_text_extension(".pptx") is False
assert _is_text_extension(".png") is False
assert _is_text_extension(".xyz") is False
def test_case_sensitivity(self):
"""Test that _is_text_extension requires lowercase extension.
Note: The main extract_text function handles case-insensitivity by
converting extensions to lowercase before calling _is_text_extension.
"""
# _is_text_extension itself is case-sensitive (lowercase only)
assert _is_text_extension(".txt") is True
assert _is_text_extension(".TXT") is False
assert _is_text_extension(".pdf") is False

View File

@ -194,6 +194,7 @@ async def test_successful_request_uses_fixed_api_session(aiohttp_client, mock_ag
assert body["model"] == "test-model"
mock_agent.process_direct.assert_called_once_with(
content="hello",
media=None,
session_key=API_SESSION_KEY,
channel="api",
chat_id=API_CHAT_ID,
@ -205,7 +206,7 @@ async def test_successful_request_uses_fixed_api_session(aiohttp_client, mock_ag
async def test_followup_requests_share_same_session_key(aiohttp_client) -> None:
call_log: list[str] = []
async def fake_process(content, session_key="", channel="", chat_id=""):
async def fake_process(content, session_key="", channel="", chat_id="", **kwargs):
call_log.append(session_key)
return f"reply to {content}"
@ -236,7 +237,7 @@ async def test_followup_requests_share_same_session_key(aiohttp_client) -> None:
async def test_fixed_session_requests_are_serialized(aiohttp_client) -> None:
order: list[str] = []
async def slow_process(content, session_key="", channel="", chat_id=""):
async def slow_process(content, session_key="", channel="", chat_id="", **kwargs):
order.append(f"start:{content}")
await asyncio.sleep(0.1)
order.append(f"end:{content}")
@ -307,12 +308,12 @@ async def test_multimodal_content_extracts_text(aiohttp_client, mock_agent) -> N
},
)
assert resp.status == 200
mock_agent.process_direct.assert_called_once_with(
content="describe this",
session_key=API_SESSION_KEY,
channel="api",
chat_id=API_CHAT_ID,
)
call_kwargs = mock_agent.process_direct.call_args.kwargs
assert call_kwargs["content"] == "describe this"
assert call_kwargs["session_key"] == API_SESSION_KEY
assert call_kwargs["channel"] == "api"
assert call_kwargs["chat_id"] == API_CHAT_ID
assert len(call_kwargs.get("media") or []) >= 0 # base64 images saved to disk
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@ -320,7 +321,7 @@ async def test_multimodal_content_extracts_text(aiohttp_client, mock_agent) -> N
async def test_empty_response_retry_then_success(aiohttp_client) -> None:
call_count = 0
async def sometimes_empty(content, session_key="", channel="", chat_id=""):
async def sometimes_empty(content, session_key="", channel="", chat_id="", **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
@ -351,7 +352,7 @@ async def test_empty_response_falls_back(aiohttp_client) -> None:
call_count = 0
async def always_empty(content, session_key="", channel="", chat_id=""):
async def always_empty(content, session_key="", channel="", chat_id="", **kwargs):
nonlocal call_count
call_count += 1
return ""
@ -371,3 +372,31 @@ async def test_empty_response_falls_back(aiohttp_client) -> None:
body = await resp.json()
assert body["choices"][0]["message"]["content"] == EMPTY_FINAL_RESPONSE_MESSAGE
assert call_count == 2
@pytest.mark.asyncio
async def test_process_direct_accepts_media() -> None:
"""process_direct should forward media paths to _process_message."""
from nanobot.agent.loop import AgentLoop
loop = AgentLoop.__new__(AgentLoop)
loop._connect_mcp = AsyncMock()
captured_msg = None
async def fake_process(msg, *, session_key="", on_progress=None, on_stream=None, on_stream_end=None):
nonlocal captured_msg
captured_msg = msg
return None
loop._process_message = fake_process
await loop.process_direct(
content="analyze this",
media=["/tmp/image.png", "/tmp/report.pdf"],
session_key="test:1",
)
assert captured_msg is not None
assert captured_msg.media == ["/tmp/image.png", "/tmp/report.pdf"]
assert captured_msg.content == "analyze this"