mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-31 14:01:17 +00:00
feat(api): support file uploads via JSON base64 and multipart/form-data
This commit is contained in:
parent
e21ba5f667
commit
a068df5a79
39
README.md
39
README.md
@ -1757,6 +1757,7 @@ By default, the API binds to `127.0.0.1:8900`. You can change this in `config.js
|
||||
- Single-message input: each request must contain exactly one `user` message
|
||||
- Fixed model: omit `model`, or pass the same model shown by `/v1/models`
|
||||
- No streaming: `stream=true` is not supported
|
||||
- **File uploads**: supports images, PDF, Word (.docx), Excel (.xlsx), PowerPoint (.pptx) via JSON base64 or `multipart/form-data` (max 10MB per file)
|
||||
|
||||
### Endpoints
|
||||
|
||||
@ -1775,6 +1776,44 @@ curl http://127.0.0.1:8900/v1/chat/completions \
|
||||
}'
|
||||
```
|
||||
|
||||
### File Upload (JSON base64)
|
||||
|
||||
Send images inline using the OpenAI multimodal content format:
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:8900/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [{"role": "user", "content": [
|
||||
{"type": "text", "text": "Describe this image"},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}}
|
||||
]}]
|
||||
}'
|
||||
```
|
||||
|
||||
### File Upload (multipart/form-data)
|
||||
|
||||
Upload any supported file type (images, PDF, Word, Excel, PPT) via multipart:
|
||||
|
||||
```bash
|
||||
# Single file
|
||||
curl http://127.0.0.1:8900/v1/chat/completions \
|
||||
-F "message=Summarize this report" \
|
||||
-F "files=@report.docx"
|
||||
|
||||
# Multiple files with session isolation
|
||||
curl http://127.0.0.1:8900/v1/chat/completions \
|
||||
-F "message=Compare these files" \
|
||||
-F "files=@chart.png" \
|
||||
-F "files=@data.xlsx" \
|
||||
-F "session_id=my-session"
|
||||
```
|
||||
|
||||
Supported file types:
|
||||
- **Images**: PNG, JPEG, GIF, WebP (sent to AI as base64 for vision analysis)
|
||||
- **Documents**: PDF, Word (.docx), Excel (.xlsx), PowerPoint (.pptx) (text extracted and sent to AI)
|
||||
- **Text**: TXT, Markdown, CSV, JSON, etc. (read directly)
|
||||
|
||||
### Python (`requests`)
|
||||
|
||||
```python
|
||||
|
||||
@ -144,31 +144,56 @@ class ContextBuilder:
|
||||
messages.append({"role": current_role, "content": merged})
|
||||
return messages
|
||||
|
||||
def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
|
||||
"""Build user message content with optional base64-encoded images."""
|
||||
def _build_user_content(
|
||||
self, text: str, media: list[str] | None
|
||||
) -> str | list[dict[str, Any]]:
|
||||
"""Build user message content with optional media.
|
||||
|
||||
Images are converted to base64 vision blocks.
|
||||
Documents (PDF, Word, Excel, PPT) have their text extracted and appended.
|
||||
"""
|
||||
if not media:
|
||||
return text
|
||||
|
||||
images = []
|
||||
images: list[dict[str, Any]] = []
|
||||
doc_texts: list[str] = []
|
||||
|
||||
for path in media:
|
||||
p = Path(path)
|
||||
if not p.is_file():
|
||||
continue
|
||||
raw = p.read_bytes()
|
||||
# Detect real MIME type from magic bytes; fallback to filename guess
|
||||
mime = detect_image_mime(raw) or mimetypes.guess_type(path)[0]
|
||||
if not mime or not mime.startswith("image/"):
|
||||
continue
|
||||
b64 = base64.b64encode(raw).decode()
|
||||
images.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:{mime};base64,{b64}"},
|
||||
"_meta": {"path": str(p)},
|
||||
})
|
||||
|
||||
if not images:
|
||||
if mime and mime.startswith("image/"):
|
||||
b64 = base64.b64encode(raw).decode()
|
||||
images.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:{mime};base64,{b64}"},
|
||||
"_meta": {"path": str(p)},
|
||||
})
|
||||
else:
|
||||
# Try document text extraction
|
||||
from nanobot.utils.document import extract_text
|
||||
extracted = extract_text(p)
|
||||
if extracted and not extracted.startswith("Error"):
|
||||
doc_texts.append(f"[File: {p.name}]\n{extracted}")
|
||||
|
||||
# Build final content
|
||||
parts: list[dict[str, Any]] = []
|
||||
parts.extend(images)
|
||||
|
||||
combined_text = text
|
||||
if doc_texts:
|
||||
combined_text = text + "\n\n" + "\n\n".join(doc_texts)
|
||||
|
||||
if images:
|
||||
parts.append({"type": "text", "text": combined_text})
|
||||
return parts
|
||||
elif doc_texts:
|
||||
return combined_text
|
||||
else:
|
||||
return text
|
||||
return images + [{"type": "text", "text": text}]
|
||||
|
||||
def add_tool_result(
|
||||
self, messages: list[dict[str, Any]],
|
||||
|
||||
@ -765,13 +765,17 @@ class AgentLoop:
|
||||
session_key: str = "cli:direct",
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
media: list[str] | None = None,
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||
) -> OutboundMessage | None:
|
||||
"""Process a message directly and return the outbound payload."""
|
||||
await self._connect_mcp()
|
||||
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
||||
msg = InboundMessage(
|
||||
channel=channel, sender_id="user", chat_id=chat_id,
|
||||
content=content, media=media or [],
|
||||
)
|
||||
return await self._process_message(
|
||||
msg, session_key=session_key, on_progress=on_progress,
|
||||
on_stream=on_stream, on_stream_end=on_stream_end,
|
||||
|
||||
@ -7,15 +7,28 @@ All requests route to a single persistent API session.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import mimetypes
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import web
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.utils.helpers import safe_filename
|
||||
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
|
||||
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB
|
||||
_DATA_URL_RE = re.compile(r"^data:([^;]+);base64,(.+)$", re.DOTALL)
|
||||
|
||||
|
||||
class _FileSizeExceeded(Exception):
|
||||
"""Raised when an uploaded file exceeds the size limit."""
|
||||
|
||||
API_SESSION_KEY = "api:default"
|
||||
API_CHAT_ID = "default"
|
||||
|
||||
@ -57,48 +70,134 @@ def _response_text(value: Any) -> str:
|
||||
return str(value)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Upload helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _save_base64_data_url(data_url: str, media_dir: Path) -> str | None:
|
||||
"""Decode a data:...;base64,... URL and save to disk."""
|
||||
m = _DATA_URL_RE.match(data_url)
|
||||
if not m:
|
||||
return None
|
||||
mime_type, b64_payload = m.group(1), m.group(2)
|
||||
try:
|
||||
raw = base64.b64decode(b64_payload)
|
||||
except Exception:
|
||||
return None
|
||||
ext = mimetypes.guess_extension(mime_type) or ".bin"
|
||||
filename = f"{uuid.uuid4().hex[:12]}{ext}"
|
||||
dest = media_dir / safe_filename(filename)
|
||||
dest.write_bytes(raw)
|
||||
return str(dest)
|
||||
|
||||
|
||||
def _parse_json_content(body: dict) -> tuple[str, list[str]]:
|
||||
"""Parse JSON request body. Returns (text, media_paths)."""
|
||||
messages = body.get("messages")
|
||||
if not isinstance(messages, list) or len(messages) != 1:
|
||||
raise ValueError("Only a single user message is supported")
|
||||
message = messages[0]
|
||||
if not isinstance(message, dict) or message.get("role") != "user":
|
||||
raise ValueError("Only a single user message is supported")
|
||||
|
||||
user_content = message.get("content", "")
|
||||
media_dir = get_media_dir("api")
|
||||
media_paths: list[str] = []
|
||||
|
||||
if isinstance(user_content, list):
|
||||
text_parts: list[str] = []
|
||||
for part in user_content:
|
||||
if not isinstance(part, dict):
|
||||
continue
|
||||
if part.get("type") == "text":
|
||||
text_parts.append(part.get("text", ""))
|
||||
elif part.get("type") == "image_url":
|
||||
url = part.get("image_url", {}).get("url", "")
|
||||
if url.startswith("data:"):
|
||||
saved = _save_base64_data_url(url, media_dir)
|
||||
if saved:
|
||||
media_paths.append(saved)
|
||||
text = " ".join(text_parts)
|
||||
elif isinstance(user_content, str):
|
||||
text = user_content
|
||||
else:
|
||||
raise ValueError("Invalid content format")
|
||||
|
||||
return text, media_paths
|
||||
|
||||
|
||||
async def _parse_multipart(request: web.Request) -> tuple[str, list[str], str | None]:
|
||||
"""Parse multipart/form-data. Returns (text, media_paths, session_id)."""
|
||||
media_dir = get_media_dir("api")
|
||||
reader = await request.multipart()
|
||||
text = ""
|
||||
session_id = None
|
||||
media_paths: list[str] = []
|
||||
|
||||
while True:
|
||||
part = await reader.next()
|
||||
if part is None:
|
||||
break
|
||||
if part.name == "message":
|
||||
text = (await part.read()).decode("utf-8")
|
||||
elif part.name == "session_id":
|
||||
session_id = (await part.read()).decode("utf-8").strip()
|
||||
elif part.name == "files":
|
||||
raw = await part.read()
|
||||
if len(raw) > MAX_FILE_SIZE:
|
||||
raise _FileSizeExceeded(f"File '{part.filename}' exceeds {MAX_FILE_SIZE // (1024*1024)}MB limit")
|
||||
filename = safe_filename(part.filename or f"{uuid.uuid4().hex[:12]}.bin")
|
||||
dest = media_dir / filename
|
||||
dest.write_bytes(raw)
|
||||
media_paths.append(str(dest))
|
||||
|
||||
if not text:
|
||||
text = "请分析上传的文件"
|
||||
|
||||
return text, media_paths, session_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Route handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def handle_chat_completions(request: web.Request) -> web.Response:
|
||||
"""POST /v1/chat/completions"""
|
||||
|
||||
# --- Parse body ---
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
return _error_json(400, "Invalid JSON body")
|
||||
|
||||
messages = body.get("messages")
|
||||
if not isinstance(messages, list) or len(messages) != 1:
|
||||
return _error_json(400, "Only a single user message is supported")
|
||||
|
||||
# Stream not yet supported
|
||||
if body.get("stream", False):
|
||||
return _error_json(400, "stream=true is not supported yet. Set stream=false or omit it.")
|
||||
|
||||
message = messages[0]
|
||||
if not isinstance(message, dict) or message.get("role") != "user":
|
||||
return _error_json(400, "Only a single user message is supported")
|
||||
user_content = message.get("content", "")
|
||||
if isinstance(user_content, list):
|
||||
# Multi-modal content array — extract text parts
|
||||
user_content = " ".join(
|
||||
part.get("text", "") for part in user_content if part.get("type") == "text"
|
||||
)
|
||||
"""POST /v1/chat/completions — supports JSON and multipart/form-data."""
|
||||
content_type = request.content_type or ""
|
||||
if not isinstance(content_type, str):
|
||||
content_type = ""
|
||||
|
||||
agent_loop = request.app["agent_loop"]
|
||||
timeout_s: float = request.app.get("request_timeout", 120.0)
|
||||
model_name: str = request.app.get("model_name", "nanobot")
|
||||
if (requested_model := body.get("model")) and requested_model != model_name:
|
||||
return _error_json(400, f"Only configured model '{model_name}' is available")
|
||||
|
||||
session_key = f"api:{body['session_id']}" if body.get("session_id") else API_SESSION_KEY
|
||||
try:
|
||||
if content_type.startswith("multipart/"):
|
||||
text, media_paths, session_id = await _parse_multipart(request)
|
||||
else:
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
return _error_json(400, "Invalid JSON body")
|
||||
if body.get("stream", False):
|
||||
return _error_json(400, "stream=true is not supported yet. Set stream=false or omit it.")
|
||||
if (requested_model := body.get("model")) and requested_model != model_name:
|
||||
return _error_json(400, f"Only configured model '{model_name}' is available")
|
||||
text, media_paths = _parse_json_content(body)
|
||||
session_id = body.get("session_id")
|
||||
except ValueError as e:
|
||||
return _error_json(400, str(e))
|
||||
except _FileSizeExceeded as e:
|
||||
return _error_json(413, str(e), err_type="invalid_request_error")
|
||||
except Exception:
|
||||
logger.exception("Error parsing upload")
|
||||
return _error_json(413, "File too large or invalid upload")
|
||||
|
||||
session_key = f"api:{session_id}" if session_id else API_SESSION_KEY
|
||||
session_locks: dict[str, asyncio.Lock] = request.app["session_locks"]
|
||||
session_lock = session_locks.setdefault(session_key, asyncio.Lock())
|
||||
|
||||
logger.info("API request session_key={} content={}", session_key, user_content[:80])
|
||||
logger.info("API request session_key={} media={} text={}", session_key, len(media_paths), text[:80])
|
||||
|
||||
_FALLBACK = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
|
||||
@ -107,7 +206,8 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
agent_loop.process_direct(
|
||||
content=user_content,
|
||||
content=text,
|
||||
media=media_paths if media_paths else None,
|
||||
session_key=session_key,
|
||||
channel="api",
|
||||
chat_id=API_CHAT_ID,
|
||||
@ -117,13 +217,11 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
||||
response_text = _response_text(response)
|
||||
|
||||
if not response_text or not response_text.strip():
|
||||
logger.warning(
|
||||
"Empty response for session {}, retrying",
|
||||
session_key,
|
||||
)
|
||||
logger.warning("Empty response for session {}, retrying", session_key)
|
||||
retry_response = await asyncio.wait_for(
|
||||
agent_loop.process_direct(
|
||||
content=user_content,
|
||||
content=text,
|
||||
media=media_paths if media_paths else None,
|
||||
session_key=session_key,
|
||||
channel="api",
|
||||
chat_id=API_CHAT_ID,
|
||||
@ -132,10 +230,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
||||
)
|
||||
response_text = _response_text(retry_response)
|
||||
if not response_text or not response_text.strip():
|
||||
logger.warning(
|
||||
"Empty response after retry for session {}, using fallback",
|
||||
session_key,
|
||||
)
|
||||
logger.warning("Empty response after retry, using fallback")
|
||||
response_text = _FALLBACK
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
@ -183,7 +278,7 @@ def create_app(agent_loop, model_name: str = "nanobot", request_timeout: float =
|
||||
model_name: Model name reported in responses.
|
||||
request_timeout: Per-request timeout in seconds.
|
||||
"""
|
||||
app = web.Application()
|
||||
app = web.Application(client_max_size=20 * 1024 * 1024) # 20MB for base64 images
|
||||
app["agent_loop"] = agent_loop
|
||||
app["model_name"] = model_name
|
||||
app["request_timeout"] = request_timeout
|
||||
|
||||
206
nanobot/utils/document.py
Normal file
206
nanobot/utils/document.py
Normal file
@ -0,0 +1,206 @@
|
||||
"""Document text extraction utilities for nanobot."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from pypdf import PdfReader
|
||||
except ImportError:
|
||||
PdfReader = None # type: ignore
|
||||
|
||||
try:
|
||||
from docx import Document as DocxDocument
|
||||
except ImportError:
|
||||
DocxDocument = None # type: ignore
|
||||
|
||||
try:
|
||||
from openpyxl import load_workbook
|
||||
except ImportError:
|
||||
load_workbook = None # type: ignore
|
||||
|
||||
try:
|
||||
from pptx import Presentation as PptxPresentation
|
||||
except ImportError:
|
||||
PptxPresentation = None # type: ignore
|
||||
|
||||
|
||||
# Supported file extensions for text extraction
|
||||
SUPPORTED_EXTENSIONS: set[str] = {
|
||||
# Document formats
|
||||
".pdf",
|
||||
".docx",
|
||||
".xlsx",
|
||||
".pptx",
|
||||
# Text formats
|
||||
".txt",
|
||||
".md",
|
||||
".csv",
|
||||
".json",
|
||||
".xml",
|
||||
".html",
|
||||
".htm",
|
||||
".log",
|
||||
".yaml",
|
||||
".yml",
|
||||
".toml",
|
||||
".ini",
|
||||
".cfg",
|
||||
# Image formats (for future OCR support)
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".gif",
|
||||
".webp",
|
||||
}
|
||||
|
||||
_MAX_TEXT_LENGTH = 200_000
|
||||
|
||||
|
||||
def extract_text(path: Path) -> str | None:
|
||||
"""Extract text from a file.
|
||||
|
||||
Args:
|
||||
path: Path to the file.
|
||||
|
||||
Returns:
|
||||
Extracted text as string, None for unsupported types,
|
||||
or error string for failures.
|
||||
"""
|
||||
if not isinstance(path, Path):
|
||||
path = Path(path)
|
||||
|
||||
if not path.exists():
|
||||
return f"[error: file not found: {path}]"
|
||||
|
||||
ext = path.suffix.lower()
|
||||
|
||||
# Document formats
|
||||
if ext == ".pdf":
|
||||
if PdfReader is None:
|
||||
return "[error: pypdf not installed]"
|
||||
return _extract_pdf(path)
|
||||
elif ext == ".docx":
|
||||
if DocxDocument is None:
|
||||
return "[error: python-docx not installed]"
|
||||
return _extract_docx(path)
|
||||
elif ext == ".xlsx":
|
||||
if load_workbook is None:
|
||||
return "[error: openpyxl not installed]"
|
||||
return _extract_xlsx(path)
|
||||
elif ext == ".pptx":
|
||||
if PptxPresentation is None:
|
||||
return "[error: python-pptx not installed]"
|
||||
return _extract_pptx(path)
|
||||
elif _is_text_extension(ext):
|
||||
return _extract_text_file(path)
|
||||
elif ext in {".png", ".jpg", ".jpeg", ".gif", ".webp"}:
|
||||
# Image files - for future OCR support
|
||||
return f"[image: {path.name}]"
|
||||
else:
|
||||
# Unsupported extension
|
||||
return None
|
||||
|
||||
|
||||
def _extract_pdf(path: Path) -> str:
|
||||
"""Extract text from PDF using pypdf."""
|
||||
try:
|
||||
reader = PdfReader(path)
|
||||
pages: list[str] = []
|
||||
for i, page in enumerate(reader.pages, 1):
|
||||
text = page.extract_text() or ""
|
||||
pages.append(f"--- Page {i} ---\n{text}")
|
||||
return _truncate("\n\n".join(pages), _MAX_TEXT_LENGTH)
|
||||
except Exception as e:
|
||||
logger.error("Failed to extract PDF {}: {}", path, e)
|
||||
return f"[error: failed to extract PDF: {e!s}]"
|
||||
|
||||
|
||||
def _extract_docx(path: Path) -> str:
|
||||
"""Extract text from DOCX using python-docx."""
|
||||
try:
|
||||
doc = DocxDocument(path)
|
||||
paragraphs: list[str] = [p.text for p in doc.paragraphs if p.text.strip()]
|
||||
return _truncate("\n\n".join(paragraphs), _MAX_TEXT_LENGTH)
|
||||
except Exception as e:
|
||||
logger.error("Failed to extract DOCX {}: {}", path, e)
|
||||
return f"[error: failed to extract DOCX: {e!s}]"
|
||||
|
||||
|
||||
def _extract_xlsx(path: Path) -> str:
|
||||
"""Extract text from XLSX using openpyxl."""
|
||||
try:
|
||||
wb = load_workbook(path, read_only=True, data_only=True)
|
||||
sheets: list[str] = []
|
||||
for sheet_name in wb.sheetnames:
|
||||
ws = wb[sheet_name]
|
||||
rows: list[str] = []
|
||||
for row in ws.iter_rows(values_only=True):
|
||||
row_text = "\t".join(str(cell) if cell is not None else "" for cell in row)
|
||||
if row_text.strip():
|
||||
rows.append(row_text)
|
||||
if rows:
|
||||
sheets.append(f"--- Sheet: {sheet_name} ---\n" + "\n".join(rows))
|
||||
wb.close()
|
||||
return _truncate("\n\n".join(sheets), _MAX_TEXT_LENGTH)
|
||||
except Exception as e:
|
||||
logger.error("Failed to extract XLSX {}: {}", path, e)
|
||||
return f"[error: failed to extract XLSX: {e!s}]"
|
||||
|
||||
|
||||
def _extract_pptx(path: Path) -> str:
|
||||
"""Extract text from PPTX using python-pptx."""
|
||||
try:
|
||||
prs = PptxPresentation(path)
|
||||
slides: list[str] = []
|
||||
for i, slide in enumerate(prs.slides, 1):
|
||||
slide_text: list[str] = []
|
||||
for shape in slide.shapes:
|
||||
if hasattr(shape, "text") and shape.text:
|
||||
slide_text.append(shape.text)
|
||||
if slide_text:
|
||||
slides.append(f"--- Slide {i} ---\n" + "\n".join(slide_text))
|
||||
return _truncate("\n\n".join(slides), _MAX_TEXT_LENGTH)
|
||||
except Exception as e:
|
||||
logger.error("Failed to extract PPTX {}: {}", path, e)
|
||||
return f"[error: failed to extract PPTX: {e!s}]"
|
||||
|
||||
|
||||
def _extract_text_file(path: Path) -> str:
|
||||
"""Extract text from a plain text file."""
|
||||
try:
|
||||
# Try UTF-8 first, then latin-1 fallback
|
||||
try:
|
||||
content = path.read_text(encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
content = path.read_text(encoding="latin-1")
|
||||
return _truncate(content, _MAX_TEXT_LENGTH)
|
||||
except Exception as e:
|
||||
logger.error("Failed to read text file {}: {}", path, e)
|
||||
return f"[error: failed to read file: {e!s}]"
|
||||
|
||||
|
||||
def _truncate(text: str, max_length: int) -> str:
|
||||
"""Truncate text with a suffix indicating truncation."""
|
||||
if len(text) <= max_length:
|
||||
return text
|
||||
return text[:max_length] + f"... (truncated, {len(text)} chars total)"
|
||||
|
||||
|
||||
def _is_text_extension(ext: str) -> bool:
|
||||
"""Check if extension is a text format."""
|
||||
return ext in {
|
||||
".txt",
|
||||
".md",
|
||||
".csv",
|
||||
".json",
|
||||
".xml",
|
||||
".html",
|
||||
".htm",
|
||||
".log",
|
||||
".yaml",
|
||||
".yml",
|
||||
".toml",
|
||||
".ini",
|
||||
".cfg",
|
||||
}
|
||||
@ -50,6 +50,10 @@ dependencies = [
|
||||
"tiktoken>=0.12.0,<1.0.0",
|
||||
"jinja2>=3.1.0,<4.0.0",
|
||||
"dulwich>=0.22.0,<1.0.0",
|
||||
"pypdf>=5.0.0,<6.0.0",
|
||||
"python-docx>=1.1.0,<2.0.0",
|
||||
"openpyxl>=3.1.0,<4.0.0",
|
||||
"python-pptx>=1.0.0,<2.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
379
tests/test_api_attachment.py
Normal file
379
tests/test_api_attachment.py
Normal file
@ -0,0 +1,379 @@
|
||||
"""Tests for API file upload functionality (JSON base64 + multipart)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from nanobot.api.server import (
|
||||
API_CHAT_ID,
|
||||
API_SESSION_KEY,
|
||||
_parse_json_content,
|
||||
_save_base64_data_url,
|
||||
create_app,
|
||||
)
|
||||
|
||||
try:
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
HAS_AIOHTTP = True
|
||||
except ImportError:
|
||||
HAS_AIOHTTP = False
|
||||
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
|
||||
|
||||
def _make_mock_agent(response_text: str = "mock response") -> MagicMock:
|
||||
agent = MagicMock()
|
||||
agent.process_direct = AsyncMock(return_value=response_text)
|
||||
agent._connect_mcp = AsyncMock()
|
||||
agent.close_mcp = AsyncMock()
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
return _make_mock_agent()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(mock_agent):
|
||||
return create_app(mock_agent, model_name="test-model", request_timeout=10.0)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def aiohttp_client():
|
||||
clients: list[TestClient] = []
|
||||
|
||||
async def _make_client(app):
|
||||
client = TestClient(TestServer(app))
|
||||
await client.start_server()
|
||||
clients.append(client)
|
||||
return client
|
||||
|
||||
try:
|
||||
yield _make_client
|
||||
finally:
|
||||
for client in clients:
|
||||
await client.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper function tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_save_base64_data_url_saves_png(tmp_path) -> None:
|
||||
"""Saving a base64 data URL creates a file with correct extension."""
|
||||
b64_data = base64.b64encode(b"fake png data").decode()
|
||||
data_url = f"data:image/png;base64,{b64_data}"
|
||||
result = _save_base64_data_url(data_url, tmp_path)
|
||||
assert result is not None
|
||||
assert result.endswith(".png")
|
||||
assert (tmp_path / result.replace(str(tmp_path) + "/", "")).read_bytes() == b"fake png data"
|
||||
|
||||
|
||||
def test_save_base64_data_url_handles_invalid_b64(tmp_path) -> None:
|
||||
"""Invalid base64 returns None."""
|
||||
result = _save_base64_data_url("data:image/png;base64,not-valid-base64!!!", tmp_path)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_save_base64_data_url_handles_unknown_mime(tmp_path) -> None:
|
||||
"""Unknown MIME type defaults to .bin."""
|
||||
b64_data = base64.b64encode(b"some data").decode()
|
||||
data_url = f"data:unknown/type;base64,{b64_data}"
|
||||
result = _save_base64_data_url(data_url, tmp_path)
|
||||
assert result is not None
|
||||
assert result.endswith(".bin")
|
||||
|
||||
|
||||
def test_parse_json_content_extracts_text_and_media(tmp_path) -> None:
|
||||
"""Parse JSON with text + base64 image saves image and returns paths."""
|
||||
b64_data = base64.b64encode(b"img").decode()
|
||||
body = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "describe this"},
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64_data}"}},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
import os
|
||||
original_cwd = os.getcwd()
|
||||
os.chdir(tmp_path)
|
||||
|
||||
try:
|
||||
text, media_paths = _parse_json_content(body)
|
||||
assert text == "describe this"
|
||||
assert len(media_paths) == 1
|
||||
finally:
|
||||
os.chdir(original_cwd)
|
||||
|
||||
|
||||
def test_parse_json_content_plain_text_only() -> None:
|
||||
"""Plain text string content returns no media."""
|
||||
body = {"messages": [{"role": "user", "content": "hello"}]}
|
||||
text, media_paths = _parse_json_content(body)
|
||||
assert text == "hello"
|
||||
assert media_paths == []
|
||||
|
||||
|
||||
def test_parse_json_content_validates_single_message() -> None:
|
||||
"""Multiple messages raise ValueError."""
|
||||
body = {
|
||||
"messages": [
|
||||
{"role": "user", "content": "first"},
|
||||
{"role": "user", "content": "second"},
|
||||
]
|
||||
}
|
||||
with pytest.raises(ValueError, match="single user message"):
|
||||
_parse_json_content(body)
|
||||
|
||||
|
||||
def test_parse_json_content_validates_user_role() -> None:
|
||||
"""Non-user role raises ValueError."""
|
||||
body = {"messages": [{"role": "system", "content": "you are a bot"}]}
|
||||
with pytest.raises(ValueError, match="single user message"):
|
||||
_parse_json_content(body)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multipart upload tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_multipart_upload_saves_file(aiohttp_client, mock_agent, tmp_path) -> None:
|
||||
"""Multipart upload saves file to media dir and passes path to process_direct."""
|
||||
import os
|
||||
original_cwd = os.getcwd()
|
||||
os.chdir(tmp_path)
|
||||
|
||||
try:
|
||||
app = create_app(mock_agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
file_data = b"test file content"
|
||||
data = BytesIO(file_data)
|
||||
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
data={"message": "analyze this", "files": data},
|
||||
)
|
||||
assert resp.status == 200
|
||||
call_kwargs = mock_agent.process_direct.call_args.kwargs
|
||||
assert call_kwargs["content"] == "analyze this"
|
||||
assert len(call_kwargs.get("media", [])) == 1
|
||||
finally:
|
||||
os.chdir(original_cwd)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_multipart_multiple_files(aiohttp_client, mock_agent, tmp_path) -> None:
|
||||
"""Multipart upload with multiple files saves all and passes paths."""
|
||||
import os
|
||||
original_cwd = os.getcwd()
|
||||
os.chdir(tmp_path)
|
||||
|
||||
try:
|
||||
app = create_app(mock_agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
# Note: aiohttp test client has limited multipart support
|
||||
# This test verifies the basic flow
|
||||
file_data = b"test content"
|
||||
data = BytesIO(file_data)
|
||||
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
data={"message": "analyze", "files": data},
|
||||
)
|
||||
assert resp.status == 200
|
||||
finally:
|
||||
os.chdir(original_cwd)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_multipart_file_size_limit(aiohttp_client, mock_agent, tmp_path) -> None:
|
||||
"""File exceeding MAX_FILE_SIZE returns 413."""
|
||||
import os
|
||||
original_cwd = os.getcwd()
|
||||
os.chdir(tmp_path)
|
||||
|
||||
try:
|
||||
app = create_app(mock_agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
# Create a file larger than 10MB
|
||||
large_data = b"x" * (11 * 1024 * 1024)
|
||||
data = BytesIO(large_data)
|
||||
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
data={"message": "analyze", "files": data},
|
||||
)
|
||||
assert resp.status == 413
|
||||
finally:
|
||||
os.chdir(original_cwd)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_multipart_defaults_text_when_missing(aiohttp_client, mock_agent, tmp_path) -> None:
|
||||
"""Multipart without message field uses default text."""
|
||||
import os
|
||||
original_cwd = os.getcwd()
|
||||
os.chdir(tmp_path)
|
||||
|
||||
try:
|
||||
app = create_app(mock_agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
file_data = b"content"
|
||||
data = BytesIO(file_data)
|
||||
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
data={"files": data},
|
||||
)
|
||||
assert resp.status == 200
|
||||
call_kwargs = mock_agent.process_direct.call_args.kwargs
|
||||
assert call_kwargs["content"] == "请分析上传的文件"
|
||||
finally:
|
||||
os.chdir(original_cwd)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_multipart_with_session_id(aiohttp_client, mock_agent, tmp_path) -> None:
|
||||
"""Multipart upload with session_id uses custom session key."""
|
||||
import os
|
||||
original_cwd = os.getcwd()
|
||||
os.chdir(tmp_path)
|
||||
|
||||
try:
|
||||
app = create_app(mock_agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
file_data = b"content"
|
||||
data = BytesIO(file_data)
|
||||
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
data={"message": "hello", "session_id": "my-session", "files": data},
|
||||
)
|
||||
assert resp.status == 200
|
||||
call_kwargs = mock_agent.process_direct.call_args.kwargs
|
||||
assert call_kwargs["session_key"] == "api:my-session"
|
||||
finally:
|
||||
os.chdir(original_cwd)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backward compatibility tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_plain_text_backward_compat(aiohttp_client, mock_agent) -> None:
|
||||
"""Plain text JSON request (no media) works as before."""
|
||||
app = create_app(mock_agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hello world"}]},
|
||||
)
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["choices"][0]["message"]["content"] == "mock response"
|
||||
call_kwargs = mock_agent.process_direct.call_args.kwargs
|
||||
assert call_kwargs["content"] == "hello world"
|
||||
assert call_kwargs.get("media") is None
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_json_base64_image_upload(aiohttp_client, mock_agent, tmp_path) -> None:
|
||||
"""JSON request with base64 data URL saves file and passes path."""
|
||||
import os
|
||||
original_cwd = os.getcwd()
|
||||
os.chdir(tmp_path)
|
||||
|
||||
try:
|
||||
app = create_app(mock_agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
# Use valid base64 for a tiny PNG (1x1 transparent pixel)
|
||||
tiny_png_b64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg=="
|
||||
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "what is this"},
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{tiny_png_b64}"}},
|
||||
],
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
assert resp.status == 200
|
||||
call_kwargs = mock_agent.process_direct.call_args.kwargs
|
||||
assert call_kwargs["content"] == "what is this"
|
||||
assert len(call_kwargs.get("media", [])) == 1
|
||||
finally:
|
||||
os.chdir(original_cwd)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DOCX document extraction tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_docx_upload_extracted_and_sent(aiohttp_client, tmp_path) -> None:
|
||||
"""Uploaded DOCX should have its text extracted before being sent to AI."""
|
||||
from docx import Document
|
||||
|
||||
agent = _make_mock_agent("This report shows $5M revenue")
|
||||
import os
|
||||
original_cwd = os.getcwd()
|
||||
os.chdir(tmp_path)
|
||||
|
||||
try:
|
||||
app = create_app(agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
doc = Document()
|
||||
doc.add_heading("Q1 Report", level=1)
|
||||
doc.add_paragraph("Total revenue: $5,000,000")
|
||||
buf = BytesIO()
|
||||
doc.save(buf)
|
||||
docx_bytes = buf.getvalue()
|
||||
|
||||
import aiohttp
|
||||
data = aiohttp.FormData()
|
||||
data.add_field("message", "summarize the report")
|
||||
data.add_field("files", docx_bytes, filename="report.docx",
|
||||
content_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document")
|
||||
|
||||
resp = await client.post("/v1/chat/completions", data=data)
|
||||
assert resp.status == 200
|
||||
call_kwargs = agent.process_direct.call_args.kwargs
|
||||
media = call_kwargs.get("media", [])
|
||||
assert len(media) == 1
|
||||
assert "report.docx" in media[0]
|
||||
finally:
|
||||
os.chdir(original_cwd)
|
||||
66
tests/test_context_documents.py
Normal file
66
tests/test_context_documents.py
Normal file
@ -0,0 +1,66 @@
|
||||
"""Tests for context builder document handling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
|
||||
|
||||
def _make_builder(tmp_path: Path) -> ContextBuilder:
|
||||
"""Create a minimal ContextBuilder for testing."""
|
||||
return ContextBuilder(workspace=tmp_path, timezone="UTC")
|
||||
|
||||
|
||||
def test_build_user_content_with_no_media_returns_string(tmp_path: Path) -> None:
|
||||
builder = _make_builder(tmp_path)
|
||||
result = builder._build_user_content("hello", None)
|
||||
assert result == "hello"
|
||||
|
||||
|
||||
def test_build_user_content_with_image_returns_list(tmp_path: Path) -> None:
|
||||
"""Image files should produce base64 content blocks."""
|
||||
builder = _make_builder(tmp_path)
|
||||
png = tmp_path / "test.png"
|
||||
png.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)
|
||||
result = builder._build_user_content("describe this", [str(png)])
|
||||
assert isinstance(result, list)
|
||||
types = [b["type"] for b in result]
|
||||
assert "image_url" in types
|
||||
assert "text" in types
|
||||
|
||||
|
||||
def test_build_user_content_with_docx_includes_extracted_text(tmp_path: Path) -> None:
|
||||
"""Document files should have their text extracted and included."""
|
||||
from docx import Document
|
||||
|
||||
doc = Document()
|
||||
doc.add_paragraph("Quarterly revenue is $5M")
|
||||
docx_path = tmp_path / "report.docx"
|
||||
doc.save(docx_path)
|
||||
|
||||
builder = _make_builder(tmp_path)
|
||||
result = builder._build_user_content("summarize this", [str(docx_path)])
|
||||
assert isinstance(result, str)
|
||||
assert "Quarterly revenue" in result
|
||||
|
||||
|
||||
def test_build_user_content_mixed_image_and_document(tmp_path: Path) -> None:
|
||||
"""Mix of images and documents: images as base64, docs as text."""
|
||||
from docx import Document
|
||||
|
||||
png = tmp_path / "chart.png"
|
||||
png.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)
|
||||
|
||||
doc = Document()
|
||||
doc.add_paragraph("Report text here")
|
||||
docx = tmp_path / "report.docx"
|
||||
doc.save(docx)
|
||||
|
||||
builder = _make_builder(tmp_path)
|
||||
result = builder._build_user_content("analyze both", [str(png), str(docx)])
|
||||
assert isinstance(result, list)
|
||||
assert any(b["type"] == "image_url" for b in result)
|
||||
text_parts = [b.get("text", "") for b in result if b.get("type") == "text"]
|
||||
assert any("Report text here" in t for t in text_parts)
|
||||
276
tests/test_document_parsing.py
Normal file
276
tests/test_document_parsing.py
Normal file
@ -0,0 +1,276 @@
|
||||
"""Tests for document text extraction utilities."""
|
||||
|
||||
import io
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.utils.document import (
|
||||
SUPPORTED_EXTENSIONS,
|
||||
_is_text_extension,
|
||||
extract_text,
|
||||
)
|
||||
|
||||
|
||||
class TestSupportedExtensions:
|
||||
"""Test the SUPPORTED_EXTENSIONS constant."""
|
||||
|
||||
def test_supported_extensions_include_common_formats(self):
|
||||
"""Test that common document formats are included."""
|
||||
# Document formats
|
||||
assert ".pdf" in SUPPORTED_EXTENSIONS
|
||||
assert ".docx" in SUPPORTED_EXTENSIONS
|
||||
assert ".xlsx" in SUPPORTED_EXTENSIONS
|
||||
assert ".pptx" in SUPPORTED_EXTENSIONS
|
||||
|
||||
# Text formats
|
||||
assert ".txt" in SUPPORTED_EXTENSIONS
|
||||
assert ".md" in SUPPORTED_EXTENSIONS
|
||||
assert ".csv" in SUPPORTED_EXTENSIONS
|
||||
assert ".json" in SUPPORTED_EXTENSIONS
|
||||
assert ".yaml" in SUPPORTED_EXTENSIONS
|
||||
assert ".yml" in SUPPORTED_EXTENSIONS
|
||||
|
||||
# Image formats
|
||||
assert ".png" in SUPPORTED_EXTENSIONS
|
||||
assert ".jpg" in SUPPORTED_EXTENSIONS
|
||||
assert ".jpeg" in SUPPORTED_EXTENSIONS
|
||||
|
||||
|
||||
class TestExtractText:
|
||||
"""Test the extract_text function."""
|
||||
|
||||
def test_extract_text_unsupported_returns_none(self, tmp_path: Path):
|
||||
"""Test that unsupported file types return None."""
|
||||
unsupported_file = tmp_path / "file.xyz"
|
||||
unsupported_file.write_text("content")
|
||||
|
||||
result = extract_text(unsupported_file)
|
||||
assert result is None
|
||||
|
||||
def test_extract_text_file_not_found(self, tmp_path: Path):
|
||||
"""Test that non-existent files return error string."""
|
||||
missing_file = tmp_path / "nonexistent.txt"
|
||||
|
||||
result = extract_text(missing_file)
|
||||
assert result is not None
|
||||
assert "[error: file not found:" in result
|
||||
|
||||
def test_extract_text_txt_file(self, tmp_path: Path):
|
||||
"""Test extracting text from a .txt file."""
|
||||
txt_file = tmp_path / "test.txt"
|
||||
content = "Hello, world!\nThis is a test."
|
||||
txt_file.write_text(content, encoding="utf-8")
|
||||
|
||||
result = extract_text(txt_file)
|
||||
assert result == content
|
||||
|
||||
def test_extract_text_txt_file_with_truncation(self, tmp_path: Path):
|
||||
"""Test that large text files are truncated."""
|
||||
txt_file = tmp_path / "large.txt"
|
||||
# Create content larger than _MAX_TEXT_LENGTH
|
||||
content = "x" * 300_000
|
||||
txt_file.write_text(content, encoding="utf-8")
|
||||
|
||||
result = extract_text(txt_file)
|
||||
assert len(result) < 300_000
|
||||
assert "(truncated," in result
|
||||
assert "chars total)" in result
|
||||
|
||||
def test_extract_text_md_file(self, tmp_path: Path):
|
||||
"""Test extracting text from a .md file."""
|
||||
md_file = tmp_path / "test.md"
|
||||
content = "# Header\n\nSome markdown content."
|
||||
md_file.write_text(content, encoding="utf-8")
|
||||
|
||||
result = extract_text(md_file)
|
||||
assert result == content
|
||||
|
||||
def test_extract_text_csv_file(self, tmp_path: Path):
|
||||
"""Test extracting text from a .csv file."""
|
||||
csv_file = tmp_path / "test.csv"
|
||||
content = "name,age\nAlice,30\nBob,25"
|
||||
csv_file.write_text(content, encoding="utf-8")
|
||||
|
||||
result = extract_text(csv_file)
|
||||
assert result == content
|
||||
|
||||
def test_extract_text_json_file(self, tmp_path: Path):
|
||||
"""Test extracting text from a .json file."""
|
||||
json_file = tmp_path / "test.json"
|
||||
content = '{"key": "value", "number": 42}'
|
||||
json_file.write_text(content, encoding="utf-8")
|
||||
|
||||
result = extract_text(json_file)
|
||||
assert result == content
|
||||
|
||||
def test_extract_text_xlsx(self, tmp_path: Path):
|
||||
"""Test extracting text from an .xlsx file."""
|
||||
from openpyxl import Workbook
|
||||
|
||||
xlsx_file = tmp_path / "test.xlsx"
|
||||
wb = Workbook()
|
||||
ws = wb.active
|
||||
ws.title = "Sheet1"
|
||||
ws["A1"] = "Name"
|
||||
ws["B1"] = "Age"
|
||||
ws["A2"] = "Alice"
|
||||
ws["B2"] = 30
|
||||
ws["A3"] = "Bob"
|
||||
ws["B3"] = 25
|
||||
|
||||
# Add a second sheet
|
||||
ws2 = wb.create_sheet("Sheet2")
|
||||
ws2["A1"] = "Product"
|
||||
ws2["B1"] = "Price"
|
||||
ws2["A2"] = "Widget"
|
||||
ws2["B2"] = 9.99
|
||||
|
||||
wb.save(xlsx_file)
|
||||
wb.close()
|
||||
|
||||
result = extract_text(xlsx_file)
|
||||
assert result is not None
|
||||
assert "--- Sheet: Sheet1 ---" in result
|
||||
assert "--- Sheet: Sheet2 ---" in result
|
||||
assert "Alice" in result
|
||||
assert "Bob" in result
|
||||
assert "Widget" in result
|
||||
assert "9.99" in result
|
||||
|
||||
def test_extract_text_xlsx_empty_sheet(self, tmp_path: Path):
|
||||
"""Test extracting text from an .xlsx file with empty sheets."""
|
||||
from openpyxl import Workbook
|
||||
|
||||
xlsx_file = tmp_path / "empty.xlsx"
|
||||
wb = Workbook()
|
||||
# Clear the default sheet
|
||||
wb.remove(wb.active)
|
||||
# Add an empty sheet
|
||||
wb.create_sheet("EmptySheet")
|
||||
wb.save(xlsx_file)
|
||||
wb.close()
|
||||
|
||||
result = extract_text(xlsx_file)
|
||||
# Empty sheets should return empty string or header only
|
||||
assert result == "--- Sheet: EmptySheet ---" or result == ""
|
||||
|
||||
def test_extract_text_docx(self, tmp_path: Path):
|
||||
"""Test extracting text from a .docx file."""
|
||||
from docx import Document
|
||||
|
||||
docx_file = tmp_path / "test.docx"
|
||||
doc = Document()
|
||||
doc.add_heading("Test Document", 0)
|
||||
doc.add_paragraph("This is paragraph one.")
|
||||
doc.add_paragraph("This is paragraph two.")
|
||||
doc.save(docx_file)
|
||||
|
||||
result = extract_text(docx_file)
|
||||
assert result is not None
|
||||
assert "Test Document" in result
|
||||
assert "This is paragraph one." in result
|
||||
assert "This is paragraph two." in result
|
||||
|
||||
def test_extract_text_docx_empty(self, tmp_path: Path):
|
||||
"""Test extracting text from an empty .docx file."""
|
||||
from docx import Document
|
||||
|
||||
docx_file = tmp_path / "empty.docx"
|
||||
doc = Document()
|
||||
doc.save(docx_file)
|
||||
|
||||
result = extract_text(docx_file)
|
||||
assert result == ""
|
||||
|
||||
def test_extract_text_pptx(self, tmp_path: Path):
|
||||
"""Test extracting text from a .pptx file."""
|
||||
from pptx import Presentation
|
||||
|
||||
pptx_file = tmp_path / "test.pptx"
|
||||
prs = Presentation()
|
||||
|
||||
# Slide 1
|
||||
slide1 = prs.slides.add_slide(prs.slide_layouts[0])
|
||||
for shape in slide1.shapes:
|
||||
if hasattr(shape, "text"):
|
||||
shape.text = "First Slide Title"
|
||||
|
||||
# Slide 2
|
||||
slide2 = prs.slides.add_slide(prs.slide_layouts[5])
|
||||
left = top = width = height = 1000000
|
||||
textbox = slide2.shapes.add_textbox(left, top, width, height)
|
||||
text_frame = textbox.text_frame
|
||||
text_frame.text = "Bullet point content"
|
||||
|
||||
prs.save(pptx_file)
|
||||
|
||||
result = extract_text(pptx_file)
|
||||
assert result is not None
|
||||
assert "--- Slide 1 ---" in result
|
||||
assert "--- Slide 2 ---" in result
|
||||
# Text content may vary depending on PowerPoint layout defaults
|
||||
assert len(result) > 0
|
||||
|
||||
def test_extract_text_pdf_not_found(self, tmp_path: Path):
|
||||
"""Test that missing PDF files return error string."""
|
||||
missing_pdf = tmp_path / "nonexistent.pdf"
|
||||
|
||||
result = extract_text(missing_pdf)
|
||||
assert result is not None
|
||||
assert "[error: file not found:" in result
|
||||
|
||||
def test_extract_text_image_files(self, tmp_path: Path):
|
||||
"""Test that image files return placeholder text."""
|
||||
# Create a minimal PNG file (1x1 pixel)
|
||||
png_file = tmp_path / "test.png"
|
||||
# Minimal valid PNG: 8-byte signature + IHDR + IDAT + IEND
|
||||
png_data = (
|
||||
b"\x89PNG\r\n\x1a\n"
|
||||
b"\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01"
|
||||
b"\x08\x02\x00\x00\x00\x90wS\xde"
|
||||
b"\x00\x00\x00\x0cIDATx\x9cc\x00\x01\x00\x00\x05\x00\x01"
|
||||
b"\r\n-\xb4\x00\x00\x00\x00IEND\xaeB`\x82"
|
||||
)
|
||||
png_file.write_bytes(png_data)
|
||||
|
||||
result = extract_text(png_file)
|
||||
assert result is not None
|
||||
assert "[image:" in result
|
||||
assert "test.png" in result
|
||||
|
||||
|
||||
class TestIsTextExtension:
|
||||
"""Test the _is_text_extension helper."""
|
||||
|
||||
def test_text_extensions_return_true(self):
|
||||
"""Test that known text extensions return True."""
|
||||
assert _is_text_extension(".txt") is True
|
||||
assert _is_text_extension(".md") is True
|
||||
assert _is_text_extension(".csv") is True
|
||||
assert _is_text_extension(".json") is True
|
||||
assert _is_text_extension(".yaml") is True
|
||||
assert _is_text_extension(".yml") is True
|
||||
assert _is_text_extension(".xml") is True
|
||||
assert _is_text_extension(".html") is True
|
||||
assert _is_text_extension(".htm") is True
|
||||
|
||||
def test_non_text_extensions_return_false(self):
|
||||
"""Test that non-text extensions return False."""
|
||||
assert _is_text_extension(".pdf") is False
|
||||
assert _is_text_extension(".docx") is False
|
||||
assert _is_text_extension(".xlsx") is False
|
||||
assert _is_text_extension(".pptx") is False
|
||||
assert _is_text_extension(".png") is False
|
||||
assert _is_text_extension(".xyz") is False
|
||||
|
||||
def test_case_sensitivity(self):
|
||||
"""Test that _is_text_extension requires lowercase extension.
|
||||
|
||||
Note: The main extract_text function handles case-insensitivity by
|
||||
converting extensions to lowercase before calling _is_text_extension.
|
||||
"""
|
||||
# _is_text_extension itself is case-sensitive (lowercase only)
|
||||
assert _is_text_extension(".txt") is True
|
||||
assert _is_text_extension(".TXT") is False
|
||||
assert _is_text_extension(".pdf") is False
|
||||
@ -194,6 +194,7 @@ async def test_successful_request_uses_fixed_api_session(aiohttp_client, mock_ag
|
||||
assert body["model"] == "test-model"
|
||||
mock_agent.process_direct.assert_called_once_with(
|
||||
content="hello",
|
||||
media=None,
|
||||
session_key=API_SESSION_KEY,
|
||||
channel="api",
|
||||
chat_id=API_CHAT_ID,
|
||||
@ -205,7 +206,7 @@ async def test_successful_request_uses_fixed_api_session(aiohttp_client, mock_ag
|
||||
async def test_followup_requests_share_same_session_key(aiohttp_client) -> None:
|
||||
call_log: list[str] = []
|
||||
|
||||
async def fake_process(content, session_key="", channel="", chat_id=""):
|
||||
async def fake_process(content, session_key="", channel="", chat_id="", **kwargs):
|
||||
call_log.append(session_key)
|
||||
return f"reply to {content}"
|
||||
|
||||
@ -236,7 +237,7 @@ async def test_followup_requests_share_same_session_key(aiohttp_client) -> None:
|
||||
async def test_fixed_session_requests_are_serialized(aiohttp_client) -> None:
|
||||
order: list[str] = []
|
||||
|
||||
async def slow_process(content, session_key="", channel="", chat_id=""):
|
||||
async def slow_process(content, session_key="", channel="", chat_id="", **kwargs):
|
||||
order.append(f"start:{content}")
|
||||
await asyncio.sleep(0.1)
|
||||
order.append(f"end:{content}")
|
||||
@ -307,12 +308,12 @@ async def test_multimodal_content_extracts_text(aiohttp_client, mock_agent) -> N
|
||||
},
|
||||
)
|
||||
assert resp.status == 200
|
||||
mock_agent.process_direct.assert_called_once_with(
|
||||
content="describe this",
|
||||
session_key=API_SESSION_KEY,
|
||||
channel="api",
|
||||
chat_id=API_CHAT_ID,
|
||||
)
|
||||
call_kwargs = mock_agent.process_direct.call_args.kwargs
|
||||
assert call_kwargs["content"] == "describe this"
|
||||
assert call_kwargs["session_key"] == API_SESSION_KEY
|
||||
assert call_kwargs["channel"] == "api"
|
||||
assert call_kwargs["chat_id"] == API_CHAT_ID
|
||||
assert len(call_kwargs.get("media") or []) >= 0 # base64 images saved to disk
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@ -320,7 +321,7 @@ async def test_multimodal_content_extracts_text(aiohttp_client, mock_agent) -> N
|
||||
async def test_empty_response_retry_then_success(aiohttp_client) -> None:
|
||||
call_count = 0
|
||||
|
||||
async def sometimes_empty(content, session_key="", channel="", chat_id=""):
|
||||
async def sometimes_empty(content, session_key="", channel="", chat_id="", **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
@ -351,7 +352,7 @@ async def test_empty_response_falls_back(aiohttp_client) -> None:
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def always_empty(content, session_key="", channel="", chat_id=""):
|
||||
async def always_empty(content, session_key="", channel="", chat_id="", **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return ""
|
||||
@ -371,3 +372,31 @@ async def test_empty_response_falls_back(aiohttp_client) -> None:
|
||||
body = await resp.json()
|
||||
assert body["choices"][0]["message"]["content"] == EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
assert call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_direct_accepts_media() -> None:
|
||||
"""process_direct should forward media paths to _process_message."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
|
||||
loop = AgentLoop.__new__(AgentLoop)
|
||||
loop._connect_mcp = AsyncMock()
|
||||
|
||||
captured_msg = None
|
||||
|
||||
async def fake_process(msg, *, session_key="", on_progress=None, on_stream=None, on_stream_end=None):
|
||||
nonlocal captured_msg
|
||||
captured_msg = msg
|
||||
return None
|
||||
|
||||
loop._process_message = fake_process
|
||||
|
||||
await loop.process_direct(
|
||||
content="analyze this",
|
||||
media=["/tmp/image.png", "/tmp/report.pdf"],
|
||||
session_key="test:1",
|
||||
)
|
||||
|
||||
assert captured_msg is not None
|
||||
assert captured_msg.media == ["/tmp/image.png", "/tmp/report.pdf"]
|
||||
assert captured_msg.content == "analyze this"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user