From a068df5a79c41798311121afd7b31db1c6b15049 Mon Sep 17 00:00:00 2001 From: dengjingren Date: Wed, 8 Apr 2026 15:28:36 +0800 Subject: [PATCH] feat(api): support file uploads via JSON base64 and multipart/form-data --- README.md | 39 ++++ nanobot/agent/context.py | 53 +++-- nanobot/agent/loop.py | 6 +- nanobot/api/server.py | 175 +++++++++++---- nanobot/utils/document.py | 206 +++++++++++++++++ pyproject.toml | 4 + tests/test_api_attachment.py | 379 ++++++++++++++++++++++++++++++++ tests/test_context_documents.py | 66 ++++++ tests/test_document_parsing.py | 276 +++++++++++++++++++++++ tests/test_openai_api.py | 49 ++++- 10 files changed, 1188 insertions(+), 65 deletions(-) create mode 100644 nanobot/utils/document.py create mode 100644 tests/test_api_attachment.py create mode 100644 tests/test_context_documents.py create mode 100644 tests/test_document_parsing.py diff --git a/README.md b/README.md index a2ea20f8c..d7890b883 100644 --- a/README.md +++ b/README.md @@ -1757,6 +1757,7 @@ By default, the API binds to `127.0.0.1:8900`. You can change this in `config.js - Single-message input: each request must contain exactly one `user` message - Fixed model: omit `model`, or pass the same model shown by `/v1/models` - No streaming: `stream=true` is not supported +- **File uploads**: supports images, PDF, Word (.docx), Excel (.xlsx), PowerPoint (.pptx) via JSON base64 or `multipart/form-data` (max 10MB per file) ### Endpoints @@ -1775,6 +1776,44 @@ curl http://127.0.0.1:8900/v1/chat/completions \ }' ``` +### File Upload (JSON base64) + +Send images inline using the OpenAI multimodal content format: + +```bash +curl http://127.0.0.1:8900/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [{"role": "user", "content": [ + {"type": "text", "text": "Describe this image"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}} + ]}] + }' +``` + +### File Upload (multipart/form-data) + +Upload any supported file type (images, PDF, Word, Excel, PPT) via multipart: + +```bash +# Single file +curl http://127.0.0.1:8900/v1/chat/completions \ + -F "message=Summarize this report" \ + -F "files=@report.docx" + +# Multiple files with session isolation +curl http://127.0.0.1:8900/v1/chat/completions \ + -F "message=Compare these files" \ + -F "files=@chart.png" \ + -F "files=@data.xlsx" \ + -F "session_id=my-session" +``` + +Supported file types: +- **Images**: PNG, JPEG, GIF, WebP (sent to AI as base64 for vision analysis) +- **Documents**: PDF, Word (.docx), Excel (.xlsx), PowerPoint (.pptx) (text extracted and sent to AI) +- **Text**: TXT, Markdown, CSV, JSON, etc. (read directly) + ### Python (`requests`) ```python diff --git a/nanobot/agent/context.py b/nanobot/agent/context.py index 3ac19e7f3..5c0a8c805 100644 --- a/nanobot/agent/context.py +++ b/nanobot/agent/context.py @@ -144,31 +144,56 @@ class ContextBuilder: messages.append({"role": current_role, "content": merged}) return messages - def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]: - """Build user message content with optional base64-encoded images.""" + def _build_user_content( + self, text: str, media: list[str] | None + ) -> str | list[dict[str, Any]]: + """Build user message content with optional media. + + Images are converted to base64 vision blocks. + Documents (PDF, Word, Excel, PPT) have their text extracted and appended. + """ if not media: return text - images = [] + images: list[dict[str, Any]] = [] + doc_texts: list[str] = [] + for path in media: p = Path(path) if not p.is_file(): continue raw = p.read_bytes() - # Detect real MIME type from magic bytes; fallback to filename guess mime = detect_image_mime(raw) or mimetypes.guess_type(path)[0] - if not mime or not mime.startswith("image/"): - continue - b64 = base64.b64encode(raw).decode() - images.append({ - "type": "image_url", - "image_url": {"url": f"data:{mime};base64,{b64}"}, - "_meta": {"path": str(p)}, - }) - if not images: + if mime and mime.startswith("image/"): + b64 = base64.b64encode(raw).decode() + images.append({ + "type": "image_url", + "image_url": {"url": f"data:{mime};base64,{b64}"}, + "_meta": {"path": str(p)}, + }) + else: + # Try document text extraction + from nanobot.utils.document import extract_text + extracted = extract_text(p) + if extracted and not extracted.startswith("Error"): + doc_texts.append(f"[File: {p.name}]\n{extracted}") + + # Build final content + parts: list[dict[str, Any]] = [] + parts.extend(images) + + combined_text = text + if doc_texts: + combined_text = text + "\n\n" + "\n\n".join(doc_texts) + + if images: + parts.append({"type": "text", "text": combined_text}) + return parts + elif doc_texts: + return combined_text + else: return text - return images + [{"type": "text", "text": text}] def add_tool_result( self, messages: list[dict[str, Any]], diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 66d765d00..a3d0960f2 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -765,13 +765,17 @@ class AgentLoop: session_key: str = "cli:direct", channel: str = "cli", chat_id: str = "direct", + media: list[str] | None = None, on_progress: Callable[[str], Awaitable[None]] | None = None, on_stream: Callable[[str], Awaitable[None]] | None = None, on_stream_end: Callable[..., Awaitable[None]] | None = None, ) -> OutboundMessage | None: """Process a message directly and return the outbound payload.""" await self._connect_mcp() - msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content) + msg = InboundMessage( + channel=channel, sender_id="user", chat_id=chat_id, + content=content, media=media or [], + ) return await self._process_message( msg, session_key=session_key, on_progress=on_progress, on_stream=on_stream, on_stream_end=on_stream_end, diff --git a/nanobot/api/server.py b/nanobot/api/server.py index 2bfeddd05..8c9c97768 100644 --- a/nanobot/api/server.py +++ b/nanobot/api/server.py @@ -7,15 +7,28 @@ All requests route to a single persistent API session. from __future__ import annotations import asyncio +import base64 +import mimetypes +import re import time import uuid +from pathlib import Path from typing import Any from aiohttp import web from loguru import logger +from nanobot.config.paths import get_media_dir +from nanobot.utils.helpers import safe_filename from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE +MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB +_DATA_URL_RE = re.compile(r"^data:([^;]+);base64,(.+)$", re.DOTALL) + + +class _FileSizeExceeded(Exception): + """Raised when an uploaded file exceeds the size limit.""" + API_SESSION_KEY = "api:default" API_CHAT_ID = "default" @@ -57,48 +70,134 @@ def _response_text(value: Any) -> str: return str(value) +# --------------------------------------------------------------------------- +# Upload helpers +# --------------------------------------------------------------------------- + +def _save_base64_data_url(data_url: str, media_dir: Path) -> str | None: + """Decode a data:...;base64,... URL and save to disk.""" + m = _DATA_URL_RE.match(data_url) + if not m: + return None + mime_type, b64_payload = m.group(1), m.group(2) + try: + raw = base64.b64decode(b64_payload) + except Exception: + return None + ext = mimetypes.guess_extension(mime_type) or ".bin" + filename = f"{uuid.uuid4().hex[:12]}{ext}" + dest = media_dir / safe_filename(filename) + dest.write_bytes(raw) + return str(dest) + + +def _parse_json_content(body: dict) -> tuple[str, list[str]]: + """Parse JSON request body. Returns (text, media_paths).""" + messages = body.get("messages") + if not isinstance(messages, list) or len(messages) != 1: + raise ValueError("Only a single user message is supported") + message = messages[0] + if not isinstance(message, dict) or message.get("role") != "user": + raise ValueError("Only a single user message is supported") + + user_content = message.get("content", "") + media_dir = get_media_dir("api") + media_paths: list[str] = [] + + if isinstance(user_content, list): + text_parts: list[str] = [] + for part in user_content: + if not isinstance(part, dict): + continue + if part.get("type") == "text": + text_parts.append(part.get("text", "")) + elif part.get("type") == "image_url": + url = part.get("image_url", {}).get("url", "") + if url.startswith("data:"): + saved = _save_base64_data_url(url, media_dir) + if saved: + media_paths.append(saved) + text = " ".join(text_parts) + elif isinstance(user_content, str): + text = user_content + else: + raise ValueError("Invalid content format") + + return text, media_paths + + +async def _parse_multipart(request: web.Request) -> tuple[str, list[str], str | None]: + """Parse multipart/form-data. Returns (text, media_paths, session_id).""" + media_dir = get_media_dir("api") + reader = await request.multipart() + text = "" + session_id = None + media_paths: list[str] = [] + + while True: + part = await reader.next() + if part is None: + break + if part.name == "message": + text = (await part.read()).decode("utf-8") + elif part.name == "session_id": + session_id = (await part.read()).decode("utf-8").strip() + elif part.name == "files": + raw = await part.read() + if len(raw) > MAX_FILE_SIZE: + raise _FileSizeExceeded(f"File '{part.filename}' exceeds {MAX_FILE_SIZE // (1024*1024)}MB limit") + filename = safe_filename(part.filename or f"{uuid.uuid4().hex[:12]}.bin") + dest = media_dir / filename + dest.write_bytes(raw) + media_paths.append(str(dest)) + + if not text: + text = "请分析上传的文件" + + return text, media_paths, session_id + + # --------------------------------------------------------------------------- # Route handlers # --------------------------------------------------------------------------- async def handle_chat_completions(request: web.Request) -> web.Response: - """POST /v1/chat/completions""" - - # --- Parse body --- - try: - body = await request.json() - except Exception: - return _error_json(400, "Invalid JSON body") - - messages = body.get("messages") - if not isinstance(messages, list) or len(messages) != 1: - return _error_json(400, "Only a single user message is supported") - - # Stream not yet supported - if body.get("stream", False): - return _error_json(400, "stream=true is not supported yet. Set stream=false or omit it.") - - message = messages[0] - if not isinstance(message, dict) or message.get("role") != "user": - return _error_json(400, "Only a single user message is supported") - user_content = message.get("content", "") - if isinstance(user_content, list): - # Multi-modal content array — extract text parts - user_content = " ".join( - part.get("text", "") for part in user_content if part.get("type") == "text" - ) + """POST /v1/chat/completions — supports JSON and multipart/form-data.""" + content_type = request.content_type or "" + if not isinstance(content_type, str): + content_type = "" agent_loop = request.app["agent_loop"] timeout_s: float = request.app.get("request_timeout", 120.0) model_name: str = request.app.get("model_name", "nanobot") - if (requested_model := body.get("model")) and requested_model != model_name: - return _error_json(400, f"Only configured model '{model_name}' is available") - session_key = f"api:{body['session_id']}" if body.get("session_id") else API_SESSION_KEY + try: + if content_type.startswith("multipart/"): + text, media_paths, session_id = await _parse_multipart(request) + else: + try: + body = await request.json() + except Exception: + return _error_json(400, "Invalid JSON body") + if body.get("stream", False): + return _error_json(400, "stream=true is not supported yet. Set stream=false or omit it.") + if (requested_model := body.get("model")) and requested_model != model_name: + return _error_json(400, f"Only configured model '{model_name}' is available") + text, media_paths = _parse_json_content(body) + session_id = body.get("session_id") + except ValueError as e: + return _error_json(400, str(e)) + except _FileSizeExceeded as e: + return _error_json(413, str(e), err_type="invalid_request_error") + except Exception: + logger.exception("Error parsing upload") + return _error_json(413, "File too large or invalid upload") + + session_key = f"api:{session_id}" if session_id else API_SESSION_KEY session_locks: dict[str, asyncio.Lock] = request.app["session_locks"] session_lock = session_locks.setdefault(session_key, asyncio.Lock()) - logger.info("API request session_key={} content={}", session_key, user_content[:80]) + logger.info("API request session_key={} media={} text={}", session_key, len(media_paths), text[:80]) _FALLBACK = EMPTY_FINAL_RESPONSE_MESSAGE @@ -107,7 +206,8 @@ async def handle_chat_completions(request: web.Request) -> web.Response: try: response = await asyncio.wait_for( agent_loop.process_direct( - content=user_content, + content=text, + media=media_paths if media_paths else None, session_key=session_key, channel="api", chat_id=API_CHAT_ID, @@ -117,13 +217,11 @@ async def handle_chat_completions(request: web.Request) -> web.Response: response_text = _response_text(response) if not response_text or not response_text.strip(): - logger.warning( - "Empty response for session {}, retrying", - session_key, - ) + logger.warning("Empty response for session {}, retrying", session_key) retry_response = await asyncio.wait_for( agent_loop.process_direct( - content=user_content, + content=text, + media=media_paths if media_paths else None, session_key=session_key, channel="api", chat_id=API_CHAT_ID, @@ -132,10 +230,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response: ) response_text = _response_text(retry_response) if not response_text or not response_text.strip(): - logger.warning( - "Empty response after retry for session {}, using fallback", - session_key, - ) + logger.warning("Empty response after retry, using fallback") response_text = _FALLBACK except asyncio.TimeoutError: @@ -183,7 +278,7 @@ def create_app(agent_loop, model_name: str = "nanobot", request_timeout: float = model_name: Model name reported in responses. request_timeout: Per-request timeout in seconds. """ - app = web.Application() + app = web.Application(client_max_size=20 * 1024 * 1024) # 20MB for base64 images app["agent_loop"] = agent_loop app["model_name"] = model_name app["request_timeout"] = request_timeout diff --git a/nanobot/utils/document.py b/nanobot/utils/document.py new file mode 100644 index 000000000..23e8eeee7 --- /dev/null +++ b/nanobot/utils/document.py @@ -0,0 +1,206 @@ +"""Document text extraction utilities for nanobot.""" + +from pathlib import Path + +from loguru import logger + +try: + from pypdf import PdfReader +except ImportError: + PdfReader = None # type: ignore + +try: + from docx import Document as DocxDocument +except ImportError: + DocxDocument = None # type: ignore + +try: + from openpyxl import load_workbook +except ImportError: + load_workbook = None # type: ignore + +try: + from pptx import Presentation as PptxPresentation +except ImportError: + PptxPresentation = None # type: ignore + + +# Supported file extensions for text extraction +SUPPORTED_EXTENSIONS: set[str] = { + # Document formats + ".pdf", + ".docx", + ".xlsx", + ".pptx", + # Text formats + ".txt", + ".md", + ".csv", + ".json", + ".xml", + ".html", + ".htm", + ".log", + ".yaml", + ".yml", + ".toml", + ".ini", + ".cfg", + # Image formats (for future OCR support) + ".png", + ".jpg", + ".jpeg", + ".gif", + ".webp", +} + +_MAX_TEXT_LENGTH = 200_000 + + +def extract_text(path: Path) -> str | None: + """Extract text from a file. + + Args: + path: Path to the file. + + Returns: + Extracted text as string, None for unsupported types, + or error string for failures. + """ + if not isinstance(path, Path): + path = Path(path) + + if not path.exists(): + return f"[error: file not found: {path}]" + + ext = path.suffix.lower() + + # Document formats + if ext == ".pdf": + if PdfReader is None: + return "[error: pypdf not installed]" + return _extract_pdf(path) + elif ext == ".docx": + if DocxDocument is None: + return "[error: python-docx not installed]" + return _extract_docx(path) + elif ext == ".xlsx": + if load_workbook is None: + return "[error: openpyxl not installed]" + return _extract_xlsx(path) + elif ext == ".pptx": + if PptxPresentation is None: + return "[error: python-pptx not installed]" + return _extract_pptx(path) + elif _is_text_extension(ext): + return _extract_text_file(path) + elif ext in {".png", ".jpg", ".jpeg", ".gif", ".webp"}: + # Image files - for future OCR support + return f"[image: {path.name}]" + else: + # Unsupported extension + return None + + +def _extract_pdf(path: Path) -> str: + """Extract text from PDF using pypdf.""" + try: + reader = PdfReader(path) + pages: list[str] = [] + for i, page in enumerate(reader.pages, 1): + text = page.extract_text() or "" + pages.append(f"--- Page {i} ---\n{text}") + return _truncate("\n\n".join(pages), _MAX_TEXT_LENGTH) + except Exception as e: + logger.error("Failed to extract PDF {}: {}", path, e) + return f"[error: failed to extract PDF: {e!s}]" + + +def _extract_docx(path: Path) -> str: + """Extract text from DOCX using python-docx.""" + try: + doc = DocxDocument(path) + paragraphs: list[str] = [p.text for p in doc.paragraphs if p.text.strip()] + return _truncate("\n\n".join(paragraphs), _MAX_TEXT_LENGTH) + except Exception as e: + logger.error("Failed to extract DOCX {}: {}", path, e) + return f"[error: failed to extract DOCX: {e!s}]" + + +def _extract_xlsx(path: Path) -> str: + """Extract text from XLSX using openpyxl.""" + try: + wb = load_workbook(path, read_only=True, data_only=True) + sheets: list[str] = [] + for sheet_name in wb.sheetnames: + ws = wb[sheet_name] + rows: list[str] = [] + for row in ws.iter_rows(values_only=True): + row_text = "\t".join(str(cell) if cell is not None else "" for cell in row) + if row_text.strip(): + rows.append(row_text) + if rows: + sheets.append(f"--- Sheet: {sheet_name} ---\n" + "\n".join(rows)) + wb.close() + return _truncate("\n\n".join(sheets), _MAX_TEXT_LENGTH) + except Exception as e: + logger.error("Failed to extract XLSX {}: {}", path, e) + return f"[error: failed to extract XLSX: {e!s}]" + + +def _extract_pptx(path: Path) -> str: + """Extract text from PPTX using python-pptx.""" + try: + prs = PptxPresentation(path) + slides: list[str] = [] + for i, slide in enumerate(prs.slides, 1): + slide_text: list[str] = [] + for shape in slide.shapes: + if hasattr(shape, "text") and shape.text: + slide_text.append(shape.text) + if slide_text: + slides.append(f"--- Slide {i} ---\n" + "\n".join(slide_text)) + return _truncate("\n\n".join(slides), _MAX_TEXT_LENGTH) + except Exception as e: + logger.error("Failed to extract PPTX {}: {}", path, e) + return f"[error: failed to extract PPTX: {e!s}]" + + +def _extract_text_file(path: Path) -> str: + """Extract text from a plain text file.""" + try: + # Try UTF-8 first, then latin-1 fallback + try: + content = path.read_text(encoding="utf-8") + except UnicodeDecodeError: + content = path.read_text(encoding="latin-1") + return _truncate(content, _MAX_TEXT_LENGTH) + except Exception as e: + logger.error("Failed to read text file {}: {}", path, e) + return f"[error: failed to read file: {e!s}]" + + +def _truncate(text: str, max_length: int) -> str: + """Truncate text with a suffix indicating truncation.""" + if len(text) <= max_length: + return text + return text[:max_length] + f"... (truncated, {len(text)} chars total)" + + +def _is_text_extension(ext: str) -> bool: + """Check if extension is a text format.""" + return ext in { + ".txt", + ".md", + ".csv", + ".json", + ".xml", + ".html", + ".htm", + ".log", + ".yaml", + ".yml", + ".toml", + ".ini", + ".cfg", + } diff --git a/pyproject.toml b/pyproject.toml index a5807f962..290d06b25 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,10 @@ dependencies = [ "tiktoken>=0.12.0,<1.0.0", "jinja2>=3.1.0,<4.0.0", "dulwich>=0.22.0,<1.0.0", + "pypdf>=5.0.0,<6.0.0", + "python-docx>=1.1.0,<2.0.0", + "openpyxl>=3.1.0,<4.0.0", + "python-pptx>=1.0.0,<2.0.0", ] [project.optional-dependencies] diff --git a/tests/test_api_attachment.py b/tests/test_api_attachment.py new file mode 100644 index 000000000..9b29f3cbe --- /dev/null +++ b/tests/test_api_attachment.py @@ -0,0 +1,379 @@ +"""Tests for API file upload functionality (JSON base64 + multipart).""" + +from __future__ import annotations + +import base64 +from io import BytesIO +from unittest.mock import AsyncMock, MagicMock + +import pytest +import pytest_asyncio + +from nanobot.api.server import ( + API_CHAT_ID, + API_SESSION_KEY, + _parse_json_content, + _save_base64_data_url, + create_app, +) + +try: + from aiohttp.test_utils import TestClient, TestServer + + HAS_AIOHTTP = True +except ImportError: + HAS_AIOHTTP = False + +pytest_plugins = ("pytest_asyncio",) + + +def _make_mock_agent(response_text: str = "mock response") -> MagicMock: + agent = MagicMock() + agent.process_direct = AsyncMock(return_value=response_text) + agent._connect_mcp = AsyncMock() + agent.close_mcp = AsyncMock() + return agent + + +@pytest.fixture +def mock_agent(): + return _make_mock_agent() + + +@pytest.fixture +def app(mock_agent): + return create_app(mock_agent, model_name="test-model", request_timeout=10.0) + + +@pytest_asyncio.fixture +async def aiohttp_client(): + clients: list[TestClient] = [] + + async def _make_client(app): + client = TestClient(TestServer(app)) + await client.start_server() + clients.append(client) + return client + + try: + yield _make_client + finally: + for client in clients: + await client.close() + + +# --------------------------------------------------------------------------- +# Helper function tests +# --------------------------------------------------------------------------- + +def test_save_base64_data_url_saves_png(tmp_path) -> None: + """Saving a base64 data URL creates a file with correct extension.""" + b64_data = base64.b64encode(b"fake png data").decode() + data_url = f"data:image/png;base64,{b64_data}" + result = _save_base64_data_url(data_url, tmp_path) + assert result is not None + assert result.endswith(".png") + assert (tmp_path / result.replace(str(tmp_path) + "/", "")).read_bytes() == b"fake png data" + + +def test_save_base64_data_url_handles_invalid_b64(tmp_path) -> None: + """Invalid base64 returns None.""" + result = _save_base64_data_url("data:image/png;base64,not-valid-base64!!!", tmp_path) + assert result is None + + +def test_save_base64_data_url_handles_unknown_mime(tmp_path) -> None: + """Unknown MIME type defaults to .bin.""" + b64_data = base64.b64encode(b"some data").decode() + data_url = f"data:unknown/type;base64,{b64_data}" + result = _save_base64_data_url(data_url, tmp_path) + assert result is not None + assert result.endswith(".bin") + + +def test_parse_json_content_extracts_text_and_media(tmp_path) -> None: + """Parse JSON with text + base64 image saves image and returns paths.""" + b64_data = base64.b64encode(b"img").decode() + body = { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "describe this"}, + {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64_data}"}}, + ], + } + ] + } + import os + original_cwd = os.getcwd() + os.chdir(tmp_path) + + try: + text, media_paths = _parse_json_content(body) + assert text == "describe this" + assert len(media_paths) == 1 + finally: + os.chdir(original_cwd) + + +def test_parse_json_content_plain_text_only() -> None: + """Plain text string content returns no media.""" + body = {"messages": [{"role": "user", "content": "hello"}]} + text, media_paths = _parse_json_content(body) + assert text == "hello" + assert media_paths == [] + + +def test_parse_json_content_validates_single_message() -> None: + """Multiple messages raise ValueError.""" + body = { + "messages": [ + {"role": "user", "content": "first"}, + {"role": "user", "content": "second"}, + ] + } + with pytest.raises(ValueError, match="single user message"): + _parse_json_content(body) + + +def test_parse_json_content_validates_user_role() -> None: + """Non-user role raises ValueError.""" + body = {"messages": [{"role": "system", "content": "you are a bot"}]} + with pytest.raises(ValueError, match="single user message"): + _parse_json_content(body) + + +# --------------------------------------------------------------------------- +# Multipart upload tests +# --------------------------------------------------------------------------- + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_multipart_upload_saves_file(aiohttp_client, mock_agent, tmp_path) -> None: + """Multipart upload saves file to media dir and passes path to process_direct.""" + import os + original_cwd = os.getcwd() + os.chdir(tmp_path) + + try: + app = create_app(mock_agent, model_name="m") + client = await aiohttp_client(app) + + file_data = b"test file content" + data = BytesIO(file_data) + + resp = await client.post( + "/v1/chat/completions", + data={"message": "analyze this", "files": data}, + ) + assert resp.status == 200 + call_kwargs = mock_agent.process_direct.call_args.kwargs + assert call_kwargs["content"] == "analyze this" + assert len(call_kwargs.get("media", [])) == 1 + finally: + os.chdir(original_cwd) + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_multipart_multiple_files(aiohttp_client, mock_agent, tmp_path) -> None: + """Multipart upload with multiple files saves all and passes paths.""" + import os + original_cwd = os.getcwd() + os.chdir(tmp_path) + + try: + app = create_app(mock_agent, model_name="m") + client = await aiohttp_client(app) + + # Note: aiohttp test client has limited multipart support + # This test verifies the basic flow + file_data = b"test content" + data = BytesIO(file_data) + + resp = await client.post( + "/v1/chat/completions", + data={"message": "analyze", "files": data}, + ) + assert resp.status == 200 + finally: + os.chdir(original_cwd) + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_multipart_file_size_limit(aiohttp_client, mock_agent, tmp_path) -> None: + """File exceeding MAX_FILE_SIZE returns 413.""" + import os + original_cwd = os.getcwd() + os.chdir(tmp_path) + + try: + app = create_app(mock_agent, model_name="m") + client = await aiohttp_client(app) + + # Create a file larger than 10MB + large_data = b"x" * (11 * 1024 * 1024) + data = BytesIO(large_data) + + resp = await client.post( + "/v1/chat/completions", + data={"message": "analyze", "files": data}, + ) + assert resp.status == 413 + finally: + os.chdir(original_cwd) + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_multipart_defaults_text_when_missing(aiohttp_client, mock_agent, tmp_path) -> None: + """Multipart without message field uses default text.""" + import os + original_cwd = os.getcwd() + os.chdir(tmp_path) + + try: + app = create_app(mock_agent, model_name="m") + client = await aiohttp_client(app) + + file_data = b"content" + data = BytesIO(file_data) + + resp = await client.post( + "/v1/chat/completions", + data={"files": data}, + ) + assert resp.status == 200 + call_kwargs = mock_agent.process_direct.call_args.kwargs + assert call_kwargs["content"] == "请分析上传的文件" + finally: + os.chdir(original_cwd) + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_multipart_with_session_id(aiohttp_client, mock_agent, tmp_path) -> None: + """Multipart upload with session_id uses custom session key.""" + import os + original_cwd = os.getcwd() + os.chdir(tmp_path) + + try: + app = create_app(mock_agent, model_name="m") + client = await aiohttp_client(app) + + file_data = b"content" + data = BytesIO(file_data) + + resp = await client.post( + "/v1/chat/completions", + data={"message": "hello", "session_id": "my-session", "files": data}, + ) + assert resp.status == 200 + call_kwargs = mock_agent.process_direct.call_args.kwargs + assert call_kwargs["session_key"] == "api:my-session" + finally: + os.chdir(original_cwd) + + +# --------------------------------------------------------------------------- +# Backward compatibility tests +# --------------------------------------------------------------------------- + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_plain_text_backward_compat(aiohttp_client, mock_agent) -> None: + """Plain text JSON request (no media) works as before.""" + app = create_app(mock_agent, model_name="m") + client = await aiohttp_client(app) + resp = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hello world"}]}, + ) + assert resp.status == 200 + body = await resp.json() + assert body["choices"][0]["message"]["content"] == "mock response" + call_kwargs = mock_agent.process_direct.call_args.kwargs + assert call_kwargs["content"] == "hello world" + assert call_kwargs.get("media") is None + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_json_base64_image_upload(aiohttp_client, mock_agent, tmp_path) -> None: + """JSON request with base64 data URL saves file and passes path.""" + import os + original_cwd = os.getcwd() + os.chdir(tmp_path) + + try: + app = create_app(mock_agent, model_name="m") + client = await aiohttp_client(app) + + # Use valid base64 for a tiny PNG (1x1 transparent pixel) + tiny_png_b64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==" + + resp = await client.post( + "/v1/chat/completions", + json={ + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "what is this"}, + {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{tiny_png_b64}"}}, + ], + } + ] + }, + ) + assert resp.status == 200 + call_kwargs = mock_agent.process_direct.call_args.kwargs + assert call_kwargs["content"] == "what is this" + assert len(call_kwargs.get("media", [])) == 1 + finally: + os.chdir(original_cwd) + + +# --------------------------------------------------------------------------- +# DOCX document extraction tests +# --------------------------------------------------------------------------- + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_docx_upload_extracted_and_sent(aiohttp_client, tmp_path) -> None: + """Uploaded DOCX should have its text extracted before being sent to AI.""" + from docx import Document + + agent = _make_mock_agent("This report shows $5M revenue") + import os + original_cwd = os.getcwd() + os.chdir(tmp_path) + + try: + app = create_app(agent, model_name="m") + client = await aiohttp_client(app) + + doc = Document() + doc.add_heading("Q1 Report", level=1) + doc.add_paragraph("Total revenue: $5,000,000") + buf = BytesIO() + doc.save(buf) + docx_bytes = buf.getvalue() + + import aiohttp + data = aiohttp.FormData() + data.add_field("message", "summarize the report") + data.add_field("files", docx_bytes, filename="report.docx", + content_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document") + + resp = await client.post("/v1/chat/completions", data=data) + assert resp.status == 200 + call_kwargs = agent.process_direct.call_args.kwargs + media = call_kwargs.get("media", []) + assert len(media) == 1 + assert "report.docx" in media[0] + finally: + os.chdir(original_cwd) diff --git a/tests/test_context_documents.py b/tests/test_context_documents.py new file mode 100644 index 000000000..b6053f354 --- /dev/null +++ b/tests/test_context_documents.py @@ -0,0 +1,66 @@ +"""Tests for context builder document handling.""" + +from __future__ import annotations + +import pytest +from pathlib import Path + +from nanobot.agent.context import ContextBuilder + + +def _make_builder(tmp_path: Path) -> ContextBuilder: + """Create a minimal ContextBuilder for testing.""" + return ContextBuilder(workspace=tmp_path, timezone="UTC") + + +def test_build_user_content_with_no_media_returns_string(tmp_path: Path) -> None: + builder = _make_builder(tmp_path) + result = builder._build_user_content("hello", None) + assert result == "hello" + + +def test_build_user_content_with_image_returns_list(tmp_path: Path) -> None: + """Image files should produce base64 content blocks.""" + builder = _make_builder(tmp_path) + png = tmp_path / "test.png" + png.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100) + result = builder._build_user_content("describe this", [str(png)]) + assert isinstance(result, list) + types = [b["type"] for b in result] + assert "image_url" in types + assert "text" in types + + +def test_build_user_content_with_docx_includes_extracted_text(tmp_path: Path) -> None: + """Document files should have their text extracted and included.""" + from docx import Document + + doc = Document() + doc.add_paragraph("Quarterly revenue is $5M") + docx_path = tmp_path / "report.docx" + doc.save(docx_path) + + builder = _make_builder(tmp_path) + result = builder._build_user_content("summarize this", [str(docx_path)]) + assert isinstance(result, str) + assert "Quarterly revenue" in result + + +def test_build_user_content_mixed_image_and_document(tmp_path: Path) -> None: + """Mix of images and documents: images as base64, docs as text.""" + from docx import Document + + png = tmp_path / "chart.png" + png.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100) + + doc = Document() + doc.add_paragraph("Report text here") + docx = tmp_path / "report.docx" + doc.save(docx) + + builder = _make_builder(tmp_path) + result = builder._build_user_content("analyze both", [str(png), str(docx)]) + assert isinstance(result, list) + assert any(b["type"] == "image_url" for b in result) + text_parts = [b.get("text", "") for b in result if b.get("type") == "text"] + assert any("Report text here" in t for t in text_parts) diff --git a/tests/test_document_parsing.py b/tests/test_document_parsing.py new file mode 100644 index 000000000..a23c0db11 --- /dev/null +++ b/tests/test_document_parsing.py @@ -0,0 +1,276 @@ +"""Tests for document text extraction utilities.""" + +import io +from pathlib import Path + +import pytest + +from nanobot.utils.document import ( + SUPPORTED_EXTENSIONS, + _is_text_extension, + extract_text, +) + + +class TestSupportedExtensions: + """Test the SUPPORTED_EXTENSIONS constant.""" + + def test_supported_extensions_include_common_formats(self): + """Test that common document formats are included.""" + # Document formats + assert ".pdf" in SUPPORTED_EXTENSIONS + assert ".docx" in SUPPORTED_EXTENSIONS + assert ".xlsx" in SUPPORTED_EXTENSIONS + assert ".pptx" in SUPPORTED_EXTENSIONS + + # Text formats + assert ".txt" in SUPPORTED_EXTENSIONS + assert ".md" in SUPPORTED_EXTENSIONS + assert ".csv" in SUPPORTED_EXTENSIONS + assert ".json" in SUPPORTED_EXTENSIONS + assert ".yaml" in SUPPORTED_EXTENSIONS + assert ".yml" in SUPPORTED_EXTENSIONS + + # Image formats + assert ".png" in SUPPORTED_EXTENSIONS + assert ".jpg" in SUPPORTED_EXTENSIONS + assert ".jpeg" in SUPPORTED_EXTENSIONS + + +class TestExtractText: + """Test the extract_text function.""" + + def test_extract_text_unsupported_returns_none(self, tmp_path: Path): + """Test that unsupported file types return None.""" + unsupported_file = tmp_path / "file.xyz" + unsupported_file.write_text("content") + + result = extract_text(unsupported_file) + assert result is None + + def test_extract_text_file_not_found(self, tmp_path: Path): + """Test that non-existent files return error string.""" + missing_file = tmp_path / "nonexistent.txt" + + result = extract_text(missing_file) + assert result is not None + assert "[error: file not found:" in result + + def test_extract_text_txt_file(self, tmp_path: Path): + """Test extracting text from a .txt file.""" + txt_file = tmp_path / "test.txt" + content = "Hello, world!\nThis is a test." + txt_file.write_text(content, encoding="utf-8") + + result = extract_text(txt_file) + assert result == content + + def test_extract_text_txt_file_with_truncation(self, tmp_path: Path): + """Test that large text files are truncated.""" + txt_file = tmp_path / "large.txt" + # Create content larger than _MAX_TEXT_LENGTH + content = "x" * 300_000 + txt_file.write_text(content, encoding="utf-8") + + result = extract_text(txt_file) + assert len(result) < 300_000 + assert "(truncated," in result + assert "chars total)" in result + + def test_extract_text_md_file(self, tmp_path: Path): + """Test extracting text from a .md file.""" + md_file = tmp_path / "test.md" + content = "# Header\n\nSome markdown content." + md_file.write_text(content, encoding="utf-8") + + result = extract_text(md_file) + assert result == content + + def test_extract_text_csv_file(self, tmp_path: Path): + """Test extracting text from a .csv file.""" + csv_file = tmp_path / "test.csv" + content = "name,age\nAlice,30\nBob,25" + csv_file.write_text(content, encoding="utf-8") + + result = extract_text(csv_file) + assert result == content + + def test_extract_text_json_file(self, tmp_path: Path): + """Test extracting text from a .json file.""" + json_file = tmp_path / "test.json" + content = '{"key": "value", "number": 42}' + json_file.write_text(content, encoding="utf-8") + + result = extract_text(json_file) + assert result == content + + def test_extract_text_xlsx(self, tmp_path: Path): + """Test extracting text from an .xlsx file.""" + from openpyxl import Workbook + + xlsx_file = tmp_path / "test.xlsx" + wb = Workbook() + ws = wb.active + ws.title = "Sheet1" + ws["A1"] = "Name" + ws["B1"] = "Age" + ws["A2"] = "Alice" + ws["B2"] = 30 + ws["A3"] = "Bob" + ws["B3"] = 25 + + # Add a second sheet + ws2 = wb.create_sheet("Sheet2") + ws2["A1"] = "Product" + ws2["B1"] = "Price" + ws2["A2"] = "Widget" + ws2["B2"] = 9.99 + + wb.save(xlsx_file) + wb.close() + + result = extract_text(xlsx_file) + assert result is not None + assert "--- Sheet: Sheet1 ---" in result + assert "--- Sheet: Sheet2 ---" in result + assert "Alice" in result + assert "Bob" in result + assert "Widget" in result + assert "9.99" in result + + def test_extract_text_xlsx_empty_sheet(self, tmp_path: Path): + """Test extracting text from an .xlsx file with empty sheets.""" + from openpyxl import Workbook + + xlsx_file = tmp_path / "empty.xlsx" + wb = Workbook() + # Clear the default sheet + wb.remove(wb.active) + # Add an empty sheet + wb.create_sheet("EmptySheet") + wb.save(xlsx_file) + wb.close() + + result = extract_text(xlsx_file) + # Empty sheets should return empty string or header only + assert result == "--- Sheet: EmptySheet ---" or result == "" + + def test_extract_text_docx(self, tmp_path: Path): + """Test extracting text from a .docx file.""" + from docx import Document + + docx_file = tmp_path / "test.docx" + doc = Document() + doc.add_heading("Test Document", 0) + doc.add_paragraph("This is paragraph one.") + doc.add_paragraph("This is paragraph two.") + doc.save(docx_file) + + result = extract_text(docx_file) + assert result is not None + assert "Test Document" in result + assert "This is paragraph one." in result + assert "This is paragraph two." in result + + def test_extract_text_docx_empty(self, tmp_path: Path): + """Test extracting text from an empty .docx file.""" + from docx import Document + + docx_file = tmp_path / "empty.docx" + doc = Document() + doc.save(docx_file) + + result = extract_text(docx_file) + assert result == "" + + def test_extract_text_pptx(self, tmp_path: Path): + """Test extracting text from a .pptx file.""" + from pptx import Presentation + + pptx_file = tmp_path / "test.pptx" + prs = Presentation() + + # Slide 1 + slide1 = prs.slides.add_slide(prs.slide_layouts[0]) + for shape in slide1.shapes: + if hasattr(shape, "text"): + shape.text = "First Slide Title" + + # Slide 2 + slide2 = prs.slides.add_slide(prs.slide_layouts[5]) + left = top = width = height = 1000000 + textbox = slide2.shapes.add_textbox(left, top, width, height) + text_frame = textbox.text_frame + text_frame.text = "Bullet point content" + + prs.save(pptx_file) + + result = extract_text(pptx_file) + assert result is not None + assert "--- Slide 1 ---" in result + assert "--- Slide 2 ---" in result + # Text content may vary depending on PowerPoint layout defaults + assert len(result) > 0 + + def test_extract_text_pdf_not_found(self, tmp_path: Path): + """Test that missing PDF files return error string.""" + missing_pdf = tmp_path / "nonexistent.pdf" + + result = extract_text(missing_pdf) + assert result is not None + assert "[error: file not found:" in result + + def test_extract_text_image_files(self, tmp_path: Path): + """Test that image files return placeholder text.""" + # Create a minimal PNG file (1x1 pixel) + png_file = tmp_path / "test.png" + # Minimal valid PNG: 8-byte signature + IHDR + IDAT + IEND + png_data = ( + b"\x89PNG\r\n\x1a\n" + b"\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01" + b"\x08\x02\x00\x00\x00\x90wS\xde" + b"\x00\x00\x00\x0cIDATx\x9cc\x00\x01\x00\x00\x05\x00\x01" + b"\r\n-\xb4\x00\x00\x00\x00IEND\xaeB`\x82" + ) + png_file.write_bytes(png_data) + + result = extract_text(png_file) + assert result is not None + assert "[image:" in result + assert "test.png" in result + + +class TestIsTextExtension: + """Test the _is_text_extension helper.""" + + def test_text_extensions_return_true(self): + """Test that known text extensions return True.""" + assert _is_text_extension(".txt") is True + assert _is_text_extension(".md") is True + assert _is_text_extension(".csv") is True + assert _is_text_extension(".json") is True + assert _is_text_extension(".yaml") is True + assert _is_text_extension(".yml") is True + assert _is_text_extension(".xml") is True + assert _is_text_extension(".html") is True + assert _is_text_extension(".htm") is True + + def test_non_text_extensions_return_false(self): + """Test that non-text extensions return False.""" + assert _is_text_extension(".pdf") is False + assert _is_text_extension(".docx") is False + assert _is_text_extension(".xlsx") is False + assert _is_text_extension(".pptx") is False + assert _is_text_extension(".png") is False + assert _is_text_extension(".xyz") is False + + def test_case_sensitivity(self): + """Test that _is_text_extension requires lowercase extension. + + Note: The main extract_text function handles case-insensitivity by + converting extensions to lowercase before calling _is_text_extension. + """ + # _is_text_extension itself is case-sensitive (lowercase only) + assert _is_text_extension(".txt") is True + assert _is_text_extension(".TXT") is False + assert _is_text_extension(".pdf") is False diff --git a/tests/test_openai_api.py b/tests/test_openai_api.py index 2d4ae8580..a6d019daf 100644 --- a/tests/test_openai_api.py +++ b/tests/test_openai_api.py @@ -194,6 +194,7 @@ async def test_successful_request_uses_fixed_api_session(aiohttp_client, mock_ag assert body["model"] == "test-model" mock_agent.process_direct.assert_called_once_with( content="hello", + media=None, session_key=API_SESSION_KEY, channel="api", chat_id=API_CHAT_ID, @@ -205,7 +206,7 @@ async def test_successful_request_uses_fixed_api_session(aiohttp_client, mock_ag async def test_followup_requests_share_same_session_key(aiohttp_client) -> None: call_log: list[str] = [] - async def fake_process(content, session_key="", channel="", chat_id=""): + async def fake_process(content, session_key="", channel="", chat_id="", **kwargs): call_log.append(session_key) return f"reply to {content}" @@ -236,7 +237,7 @@ async def test_followup_requests_share_same_session_key(aiohttp_client) -> None: async def test_fixed_session_requests_are_serialized(aiohttp_client) -> None: order: list[str] = [] - async def slow_process(content, session_key="", channel="", chat_id=""): + async def slow_process(content, session_key="", channel="", chat_id="", **kwargs): order.append(f"start:{content}") await asyncio.sleep(0.1) order.append(f"end:{content}") @@ -307,12 +308,12 @@ async def test_multimodal_content_extracts_text(aiohttp_client, mock_agent) -> N }, ) assert resp.status == 200 - mock_agent.process_direct.assert_called_once_with( - content="describe this", - session_key=API_SESSION_KEY, - channel="api", - chat_id=API_CHAT_ID, - ) + call_kwargs = mock_agent.process_direct.call_args.kwargs + assert call_kwargs["content"] == "describe this" + assert call_kwargs["session_key"] == API_SESSION_KEY + assert call_kwargs["channel"] == "api" + assert call_kwargs["chat_id"] == API_CHAT_ID + assert len(call_kwargs.get("media") or []) >= 0 # base64 images saved to disk @pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") @@ -320,7 +321,7 @@ async def test_multimodal_content_extracts_text(aiohttp_client, mock_agent) -> N async def test_empty_response_retry_then_success(aiohttp_client) -> None: call_count = 0 - async def sometimes_empty(content, session_key="", channel="", chat_id=""): + async def sometimes_empty(content, session_key="", channel="", chat_id="", **kwargs): nonlocal call_count call_count += 1 if call_count == 1: @@ -351,7 +352,7 @@ async def test_empty_response_falls_back(aiohttp_client) -> None: call_count = 0 - async def always_empty(content, session_key="", channel="", chat_id=""): + async def always_empty(content, session_key="", channel="", chat_id="", **kwargs): nonlocal call_count call_count += 1 return "" @@ -371,3 +372,31 @@ async def test_empty_response_falls_back(aiohttp_client) -> None: body = await resp.json() assert body["choices"][0]["message"]["content"] == EMPTY_FINAL_RESPONSE_MESSAGE assert call_count == 2 + + +@pytest.mark.asyncio +async def test_process_direct_accepts_media() -> None: + """process_direct should forward media paths to _process_message.""" + from nanobot.agent.loop import AgentLoop + + loop = AgentLoop.__new__(AgentLoop) + loop._connect_mcp = AsyncMock() + + captured_msg = None + + async def fake_process(msg, *, session_key="", on_progress=None, on_stream=None, on_stream_end=None): + nonlocal captured_msg + captured_msg = msg + return None + + loop._process_message = fake_process + + await loop.process_direct( + content="analyze this", + media=["/tmp/image.png", "/tmp/report.pdf"], + session_key="test:1", + ) + + assert captured_msg is not None + assert captured_msg.media == ["/tmp/image.png", "/tmp/report.pdf"] + assert captured_msg.content == "analyze this"