mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-17 00:04:07 +00:00
feat(api): support file uploads via JSON base64 and multipart/form-data
This commit is contained in:
parent
e21ba5f667
commit
a068df5a79
39
README.md
39
README.md
@ -1757,6 +1757,7 @@ By default, the API binds to `127.0.0.1:8900`. You can change this in `config.js
|
|||||||
- Single-message input: each request must contain exactly one `user` message
|
- Single-message input: each request must contain exactly one `user` message
|
||||||
- Fixed model: omit `model`, or pass the same model shown by `/v1/models`
|
- Fixed model: omit `model`, or pass the same model shown by `/v1/models`
|
||||||
- No streaming: `stream=true` is not supported
|
- No streaming: `stream=true` is not supported
|
||||||
|
- **File uploads**: supports images, PDF, Word (.docx), Excel (.xlsx), PowerPoint (.pptx) via JSON base64 or `multipart/form-data` (max 10MB per file)
|
||||||
|
|
||||||
### Endpoints
|
### Endpoints
|
||||||
|
|
||||||
@ -1775,6 +1776,44 @@ curl http://127.0.0.1:8900/v1/chat/completions \
|
|||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### File Upload (JSON base64)
|
||||||
|
|
||||||
|
Send images inline using the OpenAI multimodal content format:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://127.0.0.1:8900/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"messages": [{"role": "user", "content": [
|
||||||
|
{"type": "text", "text": "Describe this image"},
|
||||||
|
{"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}}
|
||||||
|
]}]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### File Upload (multipart/form-data)
|
||||||
|
|
||||||
|
Upload any supported file type (images, PDF, Word .docx, Excel .xlsx, PowerPoint .pptx) via multipart:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Single file
|
||||||
|
curl http://127.0.0.1:8900/v1/chat/completions \
|
||||||
|
-F "message=Summarize this report" \
|
||||||
|
-F "files=@report.docx"
|
||||||
|
|
||||||
|
# Multiple files with session isolation
|
||||||
|
curl http://127.0.0.1:8900/v1/chat/completions \
|
||||||
|
-F "message=Compare these files" \
|
||||||
|
-F "files=@chart.png" \
|
||||||
|
-F "files=@data.xlsx" \
|
||||||
|
-F "session_id=my-session"
|
||||||
|
```
|
||||||
|
|
||||||
|
Supported file types:
|
||||||
|
- **Images**: PNG, JPEG, GIF, WebP (sent to AI as base64 for vision analysis)
|
||||||
|
- **Documents**: PDF, Word (.docx), Excel (.xlsx), PowerPoint (.pptx) (text extracted and sent to AI)
|
||||||
|
- **Text**: TXT, Markdown, CSV, JSON, etc. (read directly)
|
||||||
|
|
||||||
### Python (`requests`)
|
### Python (`requests`)
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|||||||
@ -144,31 +144,56 @@ class ContextBuilder:
|
|||||||
messages.append({"role": current_role, "content": merged})
|
messages.append({"role": current_role, "content": merged})
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
|
def _build_user_content(
|
||||||
"""Build user message content with optional base64-encoded images."""
|
self, text: str, media: list[str] | None
|
||||||
|
) -> str | list[dict[str, Any]]:
|
||||||
|
"""Build user message content with optional media.
|
||||||
|
|
||||||
|
Images are converted to base64 vision blocks.
|
||||||
|
Documents (PDF, Word, Excel, PPT) have their text extracted and appended.
|
||||||
|
"""
|
||||||
if not media:
|
if not media:
|
||||||
return text
|
return text
|
||||||
|
|
||||||
images = []
|
images: list[dict[str, Any]] = []
|
||||||
|
doc_texts: list[str] = []
|
||||||
|
|
||||||
for path in media:
|
for path in media:
|
||||||
p = Path(path)
|
p = Path(path)
|
||||||
if not p.is_file():
|
if not p.is_file():
|
||||||
continue
|
continue
|
||||||
raw = p.read_bytes()
|
raw = p.read_bytes()
|
||||||
# Detect real MIME type from magic bytes; fallback to filename guess
|
|
||||||
mime = detect_image_mime(raw) or mimetypes.guess_type(path)[0]
|
mime = detect_image_mime(raw) or mimetypes.guess_type(path)[0]
|
||||||
if not mime or not mime.startswith("image/"):
|
|
||||||
continue
|
|
||||||
b64 = base64.b64encode(raw).decode()
|
|
||||||
images.append({
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {"url": f"data:{mime};base64,{b64}"},
|
|
||||||
"_meta": {"path": str(p)},
|
|
||||||
})
|
|
||||||
|
|
||||||
if not images:
|
if mime and mime.startswith("image/"):
|
||||||
|
b64 = base64.b64encode(raw).decode()
|
||||||
|
images.append({
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": f"data:{mime};base64,{b64}"},
|
||||||
|
"_meta": {"path": str(p)},
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
# Try document text extraction
|
||||||
|
from nanobot.utils.document import extract_text
|
||||||
|
extracted = extract_text(p)
|
||||||
|
if extracted and not extracted.startswith("Error"):
|
||||||
|
doc_texts.append(f"[File: {p.name}]\n{extracted}")
|
||||||
|
|
||||||
|
# Build final content
|
||||||
|
parts: list[dict[str, Any]] = []
|
||||||
|
parts.extend(images)
|
||||||
|
|
||||||
|
combined_text = text
|
||||||
|
if doc_texts:
|
||||||
|
combined_text = text + "\n\n" + "\n\n".join(doc_texts)
|
||||||
|
|
||||||
|
if images:
|
||||||
|
parts.append({"type": "text", "text": combined_text})
|
||||||
|
return parts
|
||||||
|
elif doc_texts:
|
||||||
|
return combined_text
|
||||||
|
else:
|
||||||
return text
|
return text
|
||||||
return images + [{"type": "text", "text": text}]
|
|
||||||
|
|
||||||
def add_tool_result(
|
def add_tool_result(
|
||||||
self, messages: list[dict[str, Any]],
|
self, messages: list[dict[str, Any]],
|
||||||
|
|||||||
@ -765,13 +765,17 @@ class AgentLoop:
|
|||||||
session_key: str = "cli:direct",
|
session_key: str = "cli:direct",
|
||||||
channel: str = "cli",
|
channel: str = "cli",
|
||||||
chat_id: str = "direct",
|
chat_id: str = "direct",
|
||||||
|
media: list[str] | None = None,
|
||||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||||
) -> OutboundMessage | None:
|
) -> OutboundMessage | None:
|
||||||
"""Process a message directly and return the outbound payload."""
|
"""Process a message directly and return the outbound payload."""
|
||||||
await self._connect_mcp()
|
await self._connect_mcp()
|
||||||
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
msg = InboundMessage(
|
||||||
|
channel=channel, sender_id="user", chat_id=chat_id,
|
||||||
|
content=content, media=media or [],
|
||||||
|
)
|
||||||
return await self._process_message(
|
return await self._process_message(
|
||||||
msg, session_key=session_key, on_progress=on_progress,
|
msg, session_key=session_key, on_progress=on_progress,
|
||||||
on_stream=on_stream, on_stream_end=on_stream_end,
|
on_stream=on_stream, on_stream_end=on_stream_end,
|
||||||
|
|||||||
@ -7,15 +7,28 @@ All requests route to a single persistent API session.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import mimetypes
|
||||||
|
import re
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from nanobot.config.paths import get_media_dir
|
||||||
|
from nanobot.utils.helpers import safe_filename
|
||||||
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
||||||
|
|
||||||
|
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB
|
||||||
|
_DATA_URL_RE = re.compile(r"^data:([^;]+);base64,(.+)$", re.DOTALL)
|
||||||
|
|
||||||
|
|
||||||
|
class _FileSizeExceeded(Exception):
|
||||||
|
"""Raised when an uploaded file exceeds the size limit."""
|
||||||
|
|
||||||
API_SESSION_KEY = "api:default"
|
API_SESSION_KEY = "api:default"
|
||||||
API_CHAT_ID = "default"
|
API_CHAT_ID = "default"
|
||||||
|
|
||||||
@ -57,48 +70,134 @@ def _response_text(value: Any) -> str:
|
|||||||
return str(value)
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Upload helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _save_base64_data_url(data_url: str, media_dir: Path) -> str | None:
    """Decode a ``data:<mime>;base64,<payload>`` URL and save it to disk.

    Args:
        data_url: The data URL to decode.
        media_dir: Directory the decoded file is written into.

    Returns:
        The path of the written file as a string, or ``None`` when
        ``data_url`` is not a base64 data URL or its payload cannot be
        decoded.

    Raises:
        _FileSizeExceeded: If the decoded payload is larger than
            ``MAX_FILE_SIZE``, matching the limit applied to multipart
            uploads.
    """
    m = _DATA_URL_RE.match(data_url)
    if not m:
        return None
    mime_type, b64_payload = m.group(1), m.group(2)
    try:
        raw = base64.b64decode(b64_payload)
    except ValueError:
        # binascii.Error (bad padding/alphabet) is a ValueError subclass.
        return None
    if len(raw) > MAX_FILE_SIZE:
        raise _FileSizeExceeded(
            f"Inline file exceeds {MAX_FILE_SIZE // (1024 * 1024)}MB limit"
        )
    # Unknown MIME types fall back to a generic binary extension.
    ext = mimetypes.guess_extension(mime_type) or ".bin"
    filename = f"{uuid.uuid4().hex[:12]}{ext}"
    dest = media_dir / safe_filename(filename)
    dest.write_bytes(raw)
    return str(dest)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_json_content(body: dict) -> tuple[str, list[str]]:
|
||||||
|
"""Parse JSON request body. Returns (text, media_paths)."""
|
||||||
|
messages = body.get("messages")
|
||||||
|
if not isinstance(messages, list) or len(messages) != 1:
|
||||||
|
raise ValueError("Only a single user message is supported")
|
||||||
|
message = messages[0]
|
||||||
|
if not isinstance(message, dict) or message.get("role") != "user":
|
||||||
|
raise ValueError("Only a single user message is supported")
|
||||||
|
|
||||||
|
user_content = message.get("content", "")
|
||||||
|
media_dir = get_media_dir("api")
|
||||||
|
media_paths: list[str] = []
|
||||||
|
|
||||||
|
if isinstance(user_content, list):
|
||||||
|
text_parts: list[str] = []
|
||||||
|
for part in user_content:
|
||||||
|
if not isinstance(part, dict):
|
||||||
|
continue
|
||||||
|
if part.get("type") == "text":
|
||||||
|
text_parts.append(part.get("text", ""))
|
||||||
|
elif part.get("type") == "image_url":
|
||||||
|
url = part.get("image_url", {}).get("url", "")
|
||||||
|
if url.startswith("data:"):
|
||||||
|
saved = _save_base64_data_url(url, media_dir)
|
||||||
|
if saved:
|
||||||
|
media_paths.append(saved)
|
||||||
|
text = " ".join(text_parts)
|
||||||
|
elif isinstance(user_content, str):
|
||||||
|
text = user_content
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid content format")
|
||||||
|
|
||||||
|
return text, media_paths
|
||||||
|
|
||||||
|
|
||||||
|
async def _read_part_limited(part: Any, limit: int, what: str) -> bytes:
    """Read a multipart body part in chunks, aborting once it exceeds *limit* bytes."""
    buf = bytearray()
    while True:
        chunk = await part.read_chunk()
        if not chunk:
            return bytes(buf)
        buf.extend(chunk)
        if len(buf) > limit:
            raise _FileSizeExceeded(f"{what} exceeds {limit // (1024 * 1024)}MB limit")


async def _parse_multipart(request: web.Request) -> tuple[str, list[str], str | None]:
    """Parse a multipart/form-data request.

    Recognised fields are ``message`` (the user text), ``session_id`` and
    any number of ``files`` parts; parts with other names are ignored.
    Each part is read in chunks so an oversized upload is rejected as soon
    as it crosses ``MAX_FILE_SIZE`` instead of being buffered in full first.

    Args:
        request: The incoming aiohttp request.

    Returns:
        A ``(text, media_paths, session_id)`` tuple. ``text`` falls back to
        a default prompt when no message was supplied; ``session_id`` is
        ``None`` when the field is absent.

    Raises:
        _FileSizeExceeded: If a part is larger than ``MAX_FILE_SIZE``.
    """
    media_dir = get_media_dir("api")
    reader = await request.multipart()
    text = ""
    session_id = None
    media_paths: list[str] = []

    while True:
        part = await reader.next()
        if part is None:
            break
        if part.name == "message":
            raw = await _read_part_limited(part, MAX_FILE_SIZE, "Field 'message'")
            text = raw.decode("utf-8")
        elif part.name == "session_id":
            raw = await _read_part_limited(part, MAX_FILE_SIZE, "Field 'session_id'")
            session_id = raw.decode("utf-8").strip()
        elif part.name == "files":
            raw = await _read_part_limited(part, MAX_FILE_SIZE, f"File '{part.filename}'")
            # Parts without a filename get a random name with a generic extension.
            filename = safe_filename(part.filename or f"{uuid.uuid4().hex[:12]}.bin")
            dest = media_dir / filename
            dest.write_bytes(raw)
            media_paths.append(str(dest))

    if not text:
        # Default prompt, in Chinese: "Please analyze the uploaded file(s)".
        text = "请分析上传的文件"

    return text, media_paths, session_id
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Route handlers
|
# Route handlers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
async def handle_chat_completions(request: web.Request) -> web.Response:
|
async def handle_chat_completions(request: web.Request) -> web.Response:
|
||||||
"""POST /v1/chat/completions"""
|
"""POST /v1/chat/completions — supports JSON and multipart/form-data."""
|
||||||
|
content_type = request.content_type or ""
|
||||||
# --- Parse body ---
|
if not isinstance(content_type, str):
|
||||||
try:
|
content_type = ""
|
||||||
body = await request.json()
|
|
||||||
except Exception:
|
|
||||||
return _error_json(400, "Invalid JSON body")
|
|
||||||
|
|
||||||
messages = body.get("messages")
|
|
||||||
if not isinstance(messages, list) or len(messages) != 1:
|
|
||||||
return _error_json(400, "Only a single user message is supported")
|
|
||||||
|
|
||||||
# Stream not yet supported
|
|
||||||
if body.get("stream", False):
|
|
||||||
return _error_json(400, "stream=true is not supported yet. Set stream=false or omit it.")
|
|
||||||
|
|
||||||
message = messages[0]
|
|
||||||
if not isinstance(message, dict) or message.get("role") != "user":
|
|
||||||
return _error_json(400, "Only a single user message is supported")
|
|
||||||
user_content = message.get("content", "")
|
|
||||||
if isinstance(user_content, list):
|
|
||||||
# Multi-modal content array — extract text parts
|
|
||||||
user_content = " ".join(
|
|
||||||
part.get("text", "") for part in user_content if part.get("type") == "text"
|
|
||||||
)
|
|
||||||
|
|
||||||
agent_loop = request.app["agent_loop"]
|
agent_loop = request.app["agent_loop"]
|
||||||
timeout_s: float = request.app.get("request_timeout", 120.0)
|
timeout_s: float = request.app.get("request_timeout", 120.0)
|
||||||
model_name: str = request.app.get("model_name", "nanobot")
|
model_name: str = request.app.get("model_name", "nanobot")
|
||||||
if (requested_model := body.get("model")) and requested_model != model_name:
|
|
||||||
return _error_json(400, f"Only configured model '{model_name}' is available")
|
|
||||||
|
|
||||||
session_key = f"api:{body['session_id']}" if body.get("session_id") else API_SESSION_KEY
|
try:
|
||||||
|
if content_type.startswith("multipart/"):
|
||||||
|
text, media_paths, session_id = await _parse_multipart(request)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
except Exception:
|
||||||
|
return _error_json(400, "Invalid JSON body")
|
||||||
|
if body.get("stream", False):
|
||||||
|
return _error_json(400, "stream=true is not supported yet. Set stream=false or omit it.")
|
||||||
|
if (requested_model := body.get("model")) and requested_model != model_name:
|
||||||
|
return _error_json(400, f"Only configured model '{model_name}' is available")
|
||||||
|
text, media_paths = _parse_json_content(body)
|
||||||
|
session_id = body.get("session_id")
|
||||||
|
except ValueError as e:
|
||||||
|
return _error_json(400, str(e))
|
||||||
|
except _FileSizeExceeded as e:
|
||||||
|
return _error_json(413, str(e), err_type="invalid_request_error")
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error parsing upload")
|
||||||
|
return _error_json(413, "File too large or invalid upload")
|
||||||
|
|
||||||
|
session_key = f"api:{session_id}" if session_id else API_SESSION_KEY
|
||||||
session_locks: dict[str, asyncio.Lock] = request.app["session_locks"]
|
session_locks: dict[str, asyncio.Lock] = request.app["session_locks"]
|
||||||
session_lock = session_locks.setdefault(session_key, asyncio.Lock())
|
session_lock = session_locks.setdefault(session_key, asyncio.Lock())
|
||||||
|
|
||||||
logger.info("API request session_key={} content={}", session_key, user_content[:80])
|
logger.info("API request session_key={} media={} text={}", session_key, len(media_paths), text[:80])
|
||||||
|
|
||||||
_FALLBACK = EMPTY_FINAL_RESPONSE_MESSAGE
|
_FALLBACK = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||||
|
|
||||||
@ -107,7 +206,8 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
|||||||
try:
|
try:
|
||||||
response = await asyncio.wait_for(
|
response = await asyncio.wait_for(
|
||||||
agent_loop.process_direct(
|
agent_loop.process_direct(
|
||||||
content=user_content,
|
content=text,
|
||||||
|
media=media_paths if media_paths else None,
|
||||||
session_key=session_key,
|
session_key=session_key,
|
||||||
channel="api",
|
channel="api",
|
||||||
chat_id=API_CHAT_ID,
|
chat_id=API_CHAT_ID,
|
||||||
@ -117,13 +217,11 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
|||||||
response_text = _response_text(response)
|
response_text = _response_text(response)
|
||||||
|
|
||||||
if not response_text or not response_text.strip():
|
if not response_text or not response_text.strip():
|
||||||
logger.warning(
|
logger.warning("Empty response for session {}, retrying", session_key)
|
||||||
"Empty response for session {}, retrying",
|
|
||||||
session_key,
|
|
||||||
)
|
|
||||||
retry_response = await asyncio.wait_for(
|
retry_response = await asyncio.wait_for(
|
||||||
agent_loop.process_direct(
|
agent_loop.process_direct(
|
||||||
content=user_content,
|
content=text,
|
||||||
|
media=media_paths if media_paths else None,
|
||||||
session_key=session_key,
|
session_key=session_key,
|
||||||
channel="api",
|
channel="api",
|
||||||
chat_id=API_CHAT_ID,
|
chat_id=API_CHAT_ID,
|
||||||
@ -132,10 +230,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
|||||||
)
|
)
|
||||||
response_text = _response_text(retry_response)
|
response_text = _response_text(retry_response)
|
||||||
if not response_text or not response_text.strip():
|
if not response_text or not response_text.strip():
|
||||||
logger.warning(
|
logger.warning("Empty response after retry, using fallback")
|
||||||
"Empty response after retry for session {}, using fallback",
|
|
||||||
session_key,
|
|
||||||
)
|
|
||||||
response_text = _FALLBACK
|
response_text = _FALLBACK
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
@ -183,7 +278,7 @@ def create_app(agent_loop, model_name: str = "nanobot", request_timeout: float =
|
|||||||
model_name: Model name reported in responses.
|
model_name: Model name reported in responses.
|
||||||
request_timeout: Per-request timeout in seconds.
|
request_timeout: Per-request timeout in seconds.
|
||||||
"""
|
"""
|
||||||
app = web.Application()
|
app = web.Application(client_max_size=20 * 1024 * 1024) # 20MB for base64 images
|
||||||
app["agent_loop"] = agent_loop
|
app["agent_loop"] = agent_loop
|
||||||
app["model_name"] = model_name
|
app["model_name"] = model_name
|
||||||
app["request_timeout"] = request_timeout
|
app["request_timeout"] = request_timeout
|
||||||
|
|||||||
206
nanobot/utils/document.py
Normal file
206
nanobot/utils/document.py
Normal file
@ -0,0 +1,206 @@
|
|||||||
|
"""Document text extraction utilities for nanobot."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pypdf import PdfReader
|
||||||
|
except ImportError:
|
||||||
|
PdfReader = None # type: ignore
|
||||||
|
|
||||||
|
try:
|
||||||
|
from docx import Document as DocxDocument
|
||||||
|
except ImportError:
|
||||||
|
DocxDocument = None # type: ignore
|
||||||
|
|
||||||
|
try:
|
||||||
|
from openpyxl import load_workbook
|
||||||
|
except ImportError:
|
||||||
|
load_workbook = None # type: ignore
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pptx import Presentation as PptxPresentation
|
||||||
|
except ImportError:
|
||||||
|
PptxPresentation = None # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
# Supported file extensions for text extraction
|
||||||
|
SUPPORTED_EXTENSIONS: set[str] = {
    # Office / document formats
    ".pdf", ".docx", ".xlsx", ".pptx",
    # Plain-text formats
    ".txt", ".md", ".csv", ".json", ".xml", ".html", ".htm",
    ".log", ".yaml", ".yml", ".toml", ".ini", ".cfg",
    # Image formats (for future OCR support)
    ".png", ".jpg", ".jpeg", ".gif", ".webp",
}
|
||||||
|
|
||||||
|
_MAX_TEXT_LENGTH = 200_000
|
||||||
|
|
||||||
|
|
||||||
|
def extract_text(path: Path) -> str | None:
|
||||||
|
"""Extract text from a file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: Path to the file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Extracted text as string, None for unsupported types,
|
||||||
|
or error string for failures.
|
||||||
|
"""
|
||||||
|
if not isinstance(path, Path):
|
||||||
|
path = Path(path)
|
||||||
|
|
||||||
|
if not path.exists():
|
||||||
|
return f"[error: file not found: {path}]"
|
||||||
|
|
||||||
|
ext = path.suffix.lower()
|
||||||
|
|
||||||
|
# Document formats
|
||||||
|
if ext == ".pdf":
|
||||||
|
if PdfReader is None:
|
||||||
|
return "[error: pypdf not installed]"
|
||||||
|
return _extract_pdf(path)
|
||||||
|
elif ext == ".docx":
|
||||||
|
if DocxDocument is None:
|
||||||
|
return "[error: python-docx not installed]"
|
||||||
|
return _extract_docx(path)
|
||||||
|
elif ext == ".xlsx":
|
||||||
|
if load_workbook is None:
|
||||||
|
return "[error: openpyxl not installed]"
|
||||||
|
return _extract_xlsx(path)
|
||||||
|
elif ext == ".pptx":
|
||||||
|
if PptxPresentation is None:
|
||||||
|
return "[error: python-pptx not installed]"
|
||||||
|
return _extract_pptx(path)
|
||||||
|
elif _is_text_extension(ext):
|
||||||
|
return _extract_text_file(path)
|
||||||
|
elif ext in {".png", ".jpg", ".jpeg", ".gif", ".webp"}:
|
||||||
|
# Image files - for future OCR support
|
||||||
|
return f"[image: {path.name}]"
|
||||||
|
else:
|
||||||
|
# Unsupported extension
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_pdf(path: Path) -> str:
    """Return the text of every PDF page, each under a ``--- Page N ---`` header.

    Failures are logged and reported as an ``[error: ...]`` string rather
    than raised.
    """
    try:
        document = PdfReader(path)
        sections: list[str] = [
            f"--- Page {number} ---\n{page.extract_text() or ''}"
            for number, page in enumerate(document.pages, 1)
        ]
        return _truncate("\n\n".join(sections), _MAX_TEXT_LENGTH)
    except Exception as e:
        logger.error("Failed to extract PDF {}: {}", path, e)
        return f"[error: failed to extract PDF: {e!s}]"
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_docx(path: Path) -> str:
    """Return the non-blank paragraphs of a DOCX file, separated by blank lines.

    Failures are logged and reported as an ``[error: ...]`` string rather
    than raised.
    """
    try:
        document = DocxDocument(path)
        kept: list[str] = []
        for para in document.paragraphs:
            content = para.text
            if content.strip():  # drop empty / whitespace-only paragraphs
                kept.append(content)
        return _truncate("\n\n".join(kept), _MAX_TEXT_LENGTH)
    except Exception as e:
        logger.error("Failed to extract DOCX {}: {}", path, e)
        return f"[error: failed to extract DOCX: {e!s}]"
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_xlsx(path: Path) -> str:
    """Extract text from an XLSX workbook using openpyxl.

    Each non-empty sheet is rendered under a ``--- Sheet: <name> ---``
    header with one tab-separated line per non-blank row. The workbook is
    opened with ``read_only=True, data_only=True``.

    Args:
        path: Path to the ``.xlsx`` file.

    Returns:
        The extracted (possibly truncated) text, or an ``[error: ...]``
        string when the workbook cannot be read.
    """
    try:
        wb = load_workbook(path, read_only=True, data_only=True)
        try:
            sheets: list[str] = []
            for sheet_name in wb.sheetnames:
                ws = wb[sheet_name]
                rows: list[str] = []
                for row in ws.iter_rows(values_only=True):
                    row_text = "\t".join(str(cell) if cell is not None else "" for cell in row)
                    if row_text.strip():
                        rows.append(row_text)
                if rows:
                    sheets.append(f"--- Sheet: {sheet_name} ---\n" + "\n".join(rows))
        finally:
            # Read-only workbooks keep the file open until closed, so close
            # even when iteration fails part-way through.
            wb.close()
        return _truncate("\n\n".join(sheets), _MAX_TEXT_LENGTH)
    except Exception as e:
        logger.error("Failed to extract XLSX {}: {}", path, e)
        return f"[error: failed to extract XLSX: {e!s}]"
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_pptx(path: Path) -> str:
    """Return the shape text of each PPTX slide under a ``--- Slide N ---`` header.

    Slides without any text are omitted. Failures are logged and reported
    as an ``[error: ...]`` string rather than raised.
    """
    try:
        deck = PptxPresentation(path)
        rendered: list[str] = []
        for number, slide in enumerate(deck.slides, 1):
            # Only shapes that expose a non-empty ``text`` attribute contribute.
            texts: list[str] = [
                shape.text
                for shape in slide.shapes
                if hasattr(shape, "text") and shape.text
            ]
            if texts:
                rendered.append(f"--- Slide {number} ---\n" + "\n".join(texts))
        return _truncate("\n\n".join(rendered), _MAX_TEXT_LENGTH)
    except Exception as e:
        logger.error("Failed to extract PPTX {}: {}", path, e)
        return f"[error: failed to extract PPTX: {e!s}]"
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_text_file(path: Path) -> str:
    """Read a plain-text file, decoding as UTF-8 with a latin-1 fallback.

    Failures are logged and reported as an ``[error: ...]`` string rather
    than raised.
    """
    try:
        try:
            decoded = path.read_text(encoding="utf-8")
        except UnicodeDecodeError:
            # Not valid UTF-8: retry with latin-1.
            decoded = path.read_text(encoding="latin-1")
        return _truncate(decoded, _MAX_TEXT_LENGTH)
    except Exception as e:
        logger.error("Failed to read text file {}: {}", path, e)
        return f"[error: failed to read file: {e!s}]"
|
||||||
|
|
||||||
|
|
||||||
|
def _truncate(text: str, max_length: int) -> str:
|
||||||
|
"""Truncate text with a suffix indicating truncation."""
|
||||||
|
if len(text) <= max_length:
|
||||||
|
return text
|
||||||
|
return text[:max_length] + f"... (truncated, {len(text)} chars total)"
|
||||||
|
|
||||||
|
|
||||||
|
def _is_text_extension(ext: str) -> bool:
|
||||||
|
"""Check if extension is a text format."""
|
||||||
|
return ext in {
|
||||||
|
".txt",
|
||||||
|
".md",
|
||||||
|
".csv",
|
||||||
|
".json",
|
||||||
|
".xml",
|
||||||
|
".html",
|
||||||
|
".htm",
|
||||||
|
".log",
|
||||||
|
".yaml",
|
||||||
|
".yml",
|
||||||
|
".toml",
|
||||||
|
".ini",
|
||||||
|
".cfg",
|
||||||
|
}
|
||||||
@ -50,6 +50,10 @@ dependencies = [
|
|||||||
"tiktoken>=0.12.0,<1.0.0",
|
"tiktoken>=0.12.0,<1.0.0",
|
||||||
"jinja2>=3.1.0,<4.0.0",
|
"jinja2>=3.1.0,<4.0.0",
|
||||||
"dulwich>=0.22.0,<1.0.0",
|
"dulwich>=0.22.0,<1.0.0",
|
||||||
|
"pypdf>=5.0.0,<6.0.0",
|
||||||
|
"python-docx>=1.1.0,<2.0.0",
|
||||||
|
"openpyxl>=3.1.0,<4.0.0",
|
||||||
|
"python-pptx>=1.0.0,<2.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|||||||
379
tests/test_api_attachment.py
Normal file
379
tests/test_api_attachment.py
Normal file
@ -0,0 +1,379 @@
|
|||||||
|
"""Tests for API file upload functionality (JSON base64 + multipart)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
from io import BytesIO
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from nanobot.api.server import (
|
||||||
|
API_CHAT_ID,
|
||||||
|
API_SESSION_KEY,
|
||||||
|
_parse_json_content,
|
||||||
|
_save_base64_data_url,
|
||||||
|
create_app,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from aiohttp.test_utils import TestClient, TestServer
|
||||||
|
|
||||||
|
HAS_AIOHTTP = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_AIOHTTP = False
|
||||||
|
|
||||||
|
pytest_plugins = ("pytest_asyncio",)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_agent(response_text: str = "mock response") -> MagicMock:
|
||||||
|
agent = MagicMock()
|
||||||
|
agent.process_direct = AsyncMock(return_value=response_text)
|
||||||
|
agent._connect_mcp = AsyncMock()
|
||||||
|
agent.close_mcp = AsyncMock()
|
||||||
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_agent():
    """Provide a mock agent built with the default response text."""
    agent = _make_mock_agent()
    return agent
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def app(mock_agent):
    """Build the API application around the mock agent."""
    application = create_app(
        mock_agent,
        model_name="test-model",
        request_timeout=10.0,
    )
    return application
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def aiohttp_client():
    """Yield a factory that starts test clients and closes them all on teardown."""
    started: list[TestClient] = []

    async def _factory(app):
        test_client = TestClient(TestServer(app))
        await test_client.start_server()
        # Remember the client so teardown can close it.
        started.append(test_client)
        return test_client

    try:
        yield _factory
    finally:
        for test_client in started:
            await test_client.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helper function tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_save_base64_data_url_saves_png(tmp_path) -> None:
    """Saving a base64 data URL creates a file with correct extension."""
    from pathlib import Path

    payload = b"fake png data"
    b64_data = base64.b64encode(payload).decode()
    data_url = f"data:image/png;base64,{b64_data}"
    result = _save_base64_data_url(data_url, tmp_path)
    assert result is not None
    assert result.endswith(".png")
    # Resolve the returned path directly instead of stripping a hard-coded
    # "/" separator, which is not portable across platforms.
    saved = Path(result)
    assert saved.parent == tmp_path
    assert saved.read_bytes() == payload
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_base64_data_url_handles_invalid_b64(tmp_path) -> None:
    """A payload that cannot be base64-decoded yields None."""
    broken_url = "data:image/png;base64,not-valid-base64!!!"
    assert _save_base64_data_url(broken_url, tmp_path) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_base64_data_url_handles_unknown_mime(tmp_path) -> None:
    """An unrecognised MIME type is saved with the .bin extension."""
    encoded = base64.b64encode(b"some data").decode()
    saved_path = _save_base64_data_url(f"data:unknown/type;base64,{encoded}", tmp_path)
    assert saved_path is not None
    assert saved_path.endswith(".bin")
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_json_content_extracts_text_and_media(tmp_path, monkeypatch) -> None:
    """Parse JSON with text + base64 image saves image and returns paths."""
    b64_data = base64.b64encode(b"img").decode()
    body = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "describe this"},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64_data}"}},
                ],
            }
        ]
    }
    # monkeypatch restores the working directory on teardown, even when an
    # assertion fails, so no manual try/finally is needed.
    monkeypatch.chdir(tmp_path)

    text, media_paths = _parse_json_content(body)
    assert text == "describe this"
    assert len(media_paths) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_json_content_plain_text_only() -> None:
    """String content is returned as-is with an empty media list."""
    request_body = {"messages": [{"role": "user", "content": "hello"}]}
    parsed_text, parsed_media = _parse_json_content(request_body)
    assert parsed_text == "hello"
    assert parsed_media == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_json_content_validates_single_message() -> None:
    """More than one message in the request is rejected with ``ValueError``."""
    two_messages = [
        {"role": "user", "content": "first"},
        {"role": "user", "content": "second"},
    ]
    with pytest.raises(ValueError, match="single user message"):
        _parse_json_content({"messages": two_messages})
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_json_content_validates_user_role() -> None:
    """A lone message whose role is not ``user`` is rejected with ``ValueError``."""
    system_only = {"messages": [{"role": "system", "content": "you are a bot"}]}
    with pytest.raises(ValueError, match="single user message"):
        _parse_json_content(system_only)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Multipart upload tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_multipart_upload_saves_file(aiohttp_client, mock_agent, tmp_path) -> None:
    """A multipart request with one file forwards the message and one media entry."""
    import os

    # Run from tmp_path and always restore the previous working directory.
    previous_dir = os.getcwd()
    os.chdir(tmp_path)
    try:
        client = await aiohttp_client(create_app(mock_agent, model_name="m"))

        form = {"message": "analyze this", "files": BytesIO(b"test file content")}
        resp = await client.post("/v1/chat/completions", data=form)
        assert resp.status == 200

        forwarded = mock_agent.process_direct.call_args.kwargs
        assert forwarded["content"] == "analyze this"
        assert len(forwarded.get("media", [])) == 1
    finally:
        os.chdir(previous_dir)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_multipart_multiple_files(aiohttp_client, mock_agent, tmp_path) -> None:
    """Multipart upload with two ``files`` parts forwards two media paths.

    The request is built with ``aiohttp.FormData`` because a plain dict cannot
    carry two parts under the same field name. Both uploads use the ``files``
    field, the same field the single-file tests in this module use.
    """
    import os

    import aiohttp

    # Run from tmp_path and always restore the previous working directory.
    original_cwd = os.getcwd()
    os.chdir(tmp_path)
    try:
        app = create_app(mock_agent, model_name="m")
        client = await aiohttp_client(app)

        form = aiohttp.FormData()
        form.add_field("message", "analyze")
        # Distinct filenames so the two uploads cannot overwrite each other on disk.
        form.add_field(
            "files", b"first file content", filename="first.txt", content_type="text/plain"
        )
        form.add_field(
            "files", b"second file content", filename="second.txt", content_type="text/plain"
        )

        resp = await client.post("/v1/chat/completions", data=form)
        assert resp.status == 200

        call_kwargs = mock_agent.process_direct.call_args.kwargs
        assert call_kwargs["content"] == "analyze"
        # Every uploaded file must reach the agent, not just the first one.
        assert len(call_kwargs.get("media") or []) == 2
    finally:
        os.chdir(original_cwd)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_multipart_file_size_limit(aiohttp_client, mock_agent, tmp_path) -> None:
    """An upload above the per-file size limit is answered with HTTP 413."""
    import os

    # Run from tmp_path and always restore the previous working directory.
    previous_dir = os.getcwd()
    os.chdir(tmp_path)
    try:
        client = await aiohttp_client(create_app(mock_agent, model_name="m"))

        # 11 MiB payload, i.e. larger than the 10MB limit.
        oversized = BytesIO(b"x" * (11 * 1024 * 1024))
        resp = await client.post(
            "/v1/chat/completions",
            data={"message": "analyze", "files": oversized},
        )
        assert resp.status == 413
    finally:
        os.chdir(previous_dir)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_multipart_defaults_text_when_missing(aiohttp_client, mock_agent, tmp_path) -> None:
    """Without a ``message`` field the agent receives the built-in default prompt."""
    import os

    # Run from tmp_path and always restore the previous working directory.
    previous_dir = os.getcwd()
    os.chdir(tmp_path)
    try:
        client = await aiohttp_client(create_app(mock_agent, model_name="m"))

        # Only a file part is sent; no accompanying text.
        resp = await client.post(
            "/v1/chat/completions",
            data={"files": BytesIO(b"content")},
        )
        assert resp.status == 200

        forwarded = mock_agent.process_direct.call_args.kwargs
        assert forwarded["content"] == "请分析上传的文件"
    finally:
        os.chdir(previous_dir)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_multipart_with_session_id(aiohttp_client, mock_agent, tmp_path) -> None:
    """A ``session_id`` form field becomes the ``api:<session_id>`` session key."""
    import os

    # Run from tmp_path and always restore the previous working directory.
    previous_dir = os.getcwd()
    os.chdir(tmp_path)
    try:
        client = await aiohttp_client(create_app(mock_agent, model_name="m"))

        form = {
            "message": "hello",
            "session_id": "my-session",
            "files": BytesIO(b"content"),
        }
        resp = await client.post("/v1/chat/completions", data=form)
        assert resp.status == 200

        forwarded = mock_agent.process_direct.call_args.kwargs
        assert forwarded["session_key"] == "api:my-session"
    finally:
        os.chdir(previous_dir)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Backward compatibility tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_plain_text_backward_compat(aiohttp_client, mock_agent) -> None:
    """A text-only JSON request still succeeds and forwards no media."""
    client = await aiohttp_client(create_app(mock_agent, model_name="m"))

    payload = {"messages": [{"role": "user", "content": "hello world"}]}
    resp = await client.post("/v1/chat/completions", json=payload)
    assert resp.status == 200

    reply = await resp.json()
    assert reply["choices"][0]["message"]["content"] == "mock response"

    forwarded = mock_agent.process_direct.call_args.kwargs
    assert forwarded["content"] == "hello world"
    assert forwarded.get("media") is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_json_base64_image_upload(aiohttp_client, mock_agent, tmp_path) -> None:
    """A JSON request carrying a base64 data URL forwards the text and one media entry."""
    import os

    # Run from tmp_path and always restore the previous working directory.
    previous_dir = os.getcwd()
    os.chdir(tmp_path)
    try:
        client = await aiohttp_client(create_app(mock_agent, model_name="m"))

        # Base64 of a tiny PNG image.
        tiny_png_b64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg=="
        user_message = {
            "role": "user",
            "content": [
                {"type": "text", "text": "what is this"},
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{tiny_png_b64}"}},
            ],
        }

        resp = await client.post(
            "/v1/chat/completions",
            json={"messages": [user_message]},
        )
        assert resp.status == 200

        forwarded = mock_agent.process_direct.call_args.kwargs
        assert forwarded["content"] == "what is this"
        assert len(forwarded.get("media", [])) == 1
    finally:
        os.chdir(previous_dir)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# DOCX document extraction tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_docx_upload_extracted_and_sent(aiohttp_client, tmp_path) -> None:
    """A DOCX uploaded via multipart reaches the agent as one media path naming the file."""
    import os

    import aiohttp
    from docx import Document

    agent = _make_mock_agent("This report shows $5M revenue")

    # Run from tmp_path and always restore the previous working directory.
    previous_dir = os.getcwd()
    os.chdir(tmp_path)
    try:
        client = await aiohttp_client(create_app(agent, model_name="m"))

        # Build a small Word document entirely in memory.
        report = Document()
        report.add_heading("Q1 Report", level=1)
        report.add_paragraph("Total revenue: $5,000,000")
        buffer = BytesIO()
        report.save(buffer)

        form = aiohttp.FormData()
        form.add_field("message", "summarize the report")
        form.add_field(
            "files",
            buffer.getvalue(),
            filename="report.docx",
            content_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        )

        resp = await client.post("/v1/chat/completions", data=form)
        assert resp.status == 200

        forwarded_media = agent.process_direct.call_args.kwargs.get("media", [])
        assert len(forwarded_media) == 1
        assert "report.docx" in forwarded_media[0]
    finally:
        os.chdir(previous_dir)
|
||||||
66
tests/test_context_documents.py
Normal file
66
tests/test_context_documents.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
"""Tests for context builder document handling."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from nanobot.agent.context import ContextBuilder
|
||||||
|
|
||||||
|
|
||||||
|
def _make_builder(tmp_path: Path) -> ContextBuilder:
    """Return a ``ContextBuilder`` whose workspace is *tmp_path*, using the UTC timezone."""
    builder = ContextBuilder(workspace=tmp_path, timezone="UTC")
    return builder
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_user_content_with_no_media_returns_string(tmp_path: Path) -> None:
    """With ``None`` as media the user text is returned unchanged."""
    content = _make_builder(tmp_path)._build_user_content("hello", None)
    assert content == "hello"
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_user_content_with_image_returns_list(tmp_path: Path) -> None:
    """An image attachment produces a list holding ``image_url`` and ``text`` blocks."""
    builder = _make_builder(tmp_path)

    # PNG signature followed by padding bytes.
    image_path = tmp_path / "test.png"
    image_path.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)

    blocks = builder._build_user_content("describe this", [str(image_path)])
    assert isinstance(blocks, list)

    block_types = [block["type"] for block in blocks]
    assert "image_url" in block_types
    assert "text" in block_types
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_user_content_with_docx_includes_extracted_text(tmp_path: Path) -> None:
    """A DOCX attachment yields a string result that contains the document's text."""
    from docx import Document

    report = Document()
    report.add_paragraph("Quarterly revenue is $5M")
    report_path = tmp_path / "report.docx"
    report.save(report_path)

    content = _make_builder(tmp_path)._build_user_content(
        "summarize this", [str(report_path)]
    )
    assert isinstance(content, str)
    assert "Quarterly revenue" in content
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_user_content_mixed_image_and_document(tmp_path: Path) -> None:
    """Image plus DOCX: the result list has an image block and text carrying the doc body."""
    from docx import Document

    # PNG signature followed by padding bytes.
    image_path = tmp_path / "chart.png"
    image_path.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)

    report = Document()
    report.add_paragraph("Report text here")
    report_path = tmp_path / "report.docx"
    report.save(report_path)

    attachments = [str(image_path), str(report_path)]
    blocks = _make_builder(tmp_path)._build_user_content("analyze both", attachments)

    assert isinstance(blocks, list)
    assert any(block["type"] == "image_url" for block in blocks)

    texts = [block.get("text", "") for block in blocks if block.get("type") == "text"]
    assert any("Report text here" in text for text in texts)
|
||||||
276
tests/test_document_parsing.py
Normal file
276
tests/test_document_parsing.py
Normal file
@ -0,0 +1,276 @@
|
|||||||
|
"""Tests for document text extraction utilities."""
|
||||||
|
|
||||||
|
import io
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.utils.document import (
|
||||||
|
SUPPORTED_EXTENSIONS,
|
||||||
|
_is_text_extension,
|
||||||
|
extract_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSupportedExtensions:
    """Checks on the contents of ``SUPPORTED_EXTENSIONS``."""

    def test_supported_extensions_include_common_formats(self):
        """Common document, text and image suffixes must all be listed."""
        document_formats = (".pdf", ".docx", ".xlsx", ".pptx")
        text_formats = (".txt", ".md", ".csv", ".json", ".yaml", ".yml")
        image_formats = (".png", ".jpg", ".jpeg")

        for suffix in document_formats + text_formats + image_formats:
            assert suffix in SUPPORTED_EXTENSIONS
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractText:
    """Behavioural tests for ``extract_text`` across the file types it handles."""

    def test_extract_text_unsupported_returns_none(self, tmp_path: Path):
        """A file with an unsupported suffix yields ``None``."""
        target = tmp_path / "file.xyz"
        target.write_text("content")

        assert extract_text(target) is None

    def test_extract_text_file_not_found(self, tmp_path: Path):
        """A missing ``.txt`` path yields an error marker string, not ``None``."""
        outcome = extract_text(tmp_path / "nonexistent.txt")

        assert outcome is not None
        assert "[error: file not found:" in outcome

    def test_extract_text_txt_file(self, tmp_path: Path):
        """A ``.txt`` file is returned verbatim."""
        expected = "Hello, world!\nThis is a test."
        source = tmp_path / "test.txt"
        source.write_text(expected, encoding="utf-8")

        assert extract_text(source) == expected

    def test_extract_text_txt_file_with_truncation(self, tmp_path: Path):
        """An oversized text file comes back shortened and with a truncation notice."""
        source = tmp_path / "large.txt"
        # NOTE(review): 300k chars is assumed to exceed the extractor's length cap - confirm.
        source.write_text("x" * 300_000, encoding="utf-8")

        outcome = extract_text(source)
        assert len(outcome) < 300_000
        assert "(truncated," in outcome
        assert "chars total)" in outcome

    def test_extract_text_md_file(self, tmp_path: Path):
        """A ``.md`` file is returned verbatim."""
        expected = "# Header\n\nSome markdown content."
        source = tmp_path / "test.md"
        source.write_text(expected, encoding="utf-8")

        assert extract_text(source) == expected

    def test_extract_text_csv_file(self, tmp_path: Path):
        """A ``.csv`` file is returned verbatim."""
        expected = "name,age\nAlice,30\nBob,25"
        source = tmp_path / "test.csv"
        source.write_text(expected, encoding="utf-8")

        assert extract_text(source) == expected

    def test_extract_text_json_file(self, tmp_path: Path):
        """A ``.json`` file is returned verbatim."""
        expected = '{"key": "value", "number": 42}'
        source = tmp_path / "test.json"
        source.write_text(expected, encoding="utf-8")

        assert extract_text(source) == expected

    def test_extract_text_xlsx(self, tmp_path: Path):
        """A two-sheet workbook yields per-sheet headers and the cell values."""
        from openpyxl import Workbook

        workbook_path = tmp_path / "test.xlsx"
        book = Workbook()

        people = book.active
        people.title = "Sheet1"
        people_cells = (
            ("A1", "Name"),
            ("B1", "Age"),
            ("A2", "Alice"),
            ("B2", 30),
            ("A3", "Bob"),
            ("B3", 25),
        )
        for ref, value in people_cells:
            people[ref] = value

        # Second sheet, to check that every sheet is visited.
        products = book.create_sheet("Sheet2")
        product_cells = (
            ("A1", "Product"),
            ("B1", "Price"),
            ("A2", "Widget"),
            ("B2", 9.99),
        )
        for ref, value in product_cells:
            products[ref] = value

        book.save(workbook_path)
        book.close()

        outcome = extract_text(workbook_path)
        assert outcome is not None
        expected_fragments = (
            "--- Sheet: Sheet1 ---",
            "--- Sheet: Sheet2 ---",
            "Alice",
            "Bob",
            "Widget",
            "9.99",
        )
        for fragment in expected_fragments:
            assert fragment in outcome

    def test_extract_text_xlsx_empty_sheet(self, tmp_path: Path):
        """A workbook whose only sheet is empty yields the bare header or an empty string."""
        from openpyxl import Workbook

        workbook_path = tmp_path / "empty.xlsx"
        book = Workbook()
        # Swap the default sheet for a freshly created, empty one.
        book.remove(book.active)
        book.create_sheet("EmptySheet")
        book.save(workbook_path)
        book.close()

        outcome = extract_text(workbook_path)
        # Either representation of "nothing in the sheet" is accepted.
        assert outcome in ("--- Sheet: EmptySheet ---", "")

    def test_extract_text_docx(self, tmp_path: Path):
        """Heading and paragraph text from a ``.docx`` all appear in the output."""
        from docx import Document

        document_path = tmp_path / "test.docx"
        document = Document()
        document.add_heading("Test Document", 0)
        for paragraph in ("This is paragraph one.", "This is paragraph two."):
            document.add_paragraph(paragraph)
        document.save(document_path)

        outcome = extract_text(document_path)
        assert outcome is not None
        for fragment in (
            "Test Document",
            "This is paragraph one.",
            "This is paragraph two.",
        ):
            assert fragment in outcome

    def test_extract_text_docx_empty(self, tmp_path: Path):
        """A ``.docx`` with no content yields an empty string."""
        from docx import Document

        document_path = tmp_path / "empty.docx"
        Document().save(document_path)

        assert extract_text(document_path) == ""

    def test_extract_text_pptx(self, tmp_path: Path):
        """A two-slide deck yields one ``--- Slide N ---`` header per slide."""
        from pptx import Presentation

        deck_path = tmp_path / "test.pptx"
        deck = Presentation()

        # Slide 1: fill every shape that exposes a ``text`` attribute.
        title_slide = deck.slides.add_slide(deck.slide_layouts[0])
        for shape in title_slide.shapes:
            if hasattr(shape, "text"):
                shape.text = "First Slide Title"

        # Slide 2: a single text box; position and size all use the same value.
        content_slide = deck.slides.add_slide(deck.slide_layouts[5])
        extent = 1000000
        text_box = content_slide.shapes.add_textbox(extent, extent, extent, extent)
        text_box.text_frame.text = "Bullet point content"

        deck.save(deck_path)

        outcome = extract_text(deck_path)
        assert outcome is not None
        assert "--- Slide 1 ---" in outcome
        assert "--- Slide 2 ---" in outcome
        # Slide body text is not asserted here; only that output is non-empty.
        assert len(outcome) > 0

    def test_extract_text_pdf_not_found(self, tmp_path: Path):
        """A missing ``.pdf`` path yields an error marker string, not ``None``."""
        outcome = extract_text(tmp_path / "nonexistent.pdf")

        assert outcome is not None
        assert "[error: file not found:" in outcome

    def test_extract_text_image_files(self, tmp_path: Path):
        """A PNG yields a placeholder string that names the file."""
        # PNG signature followed by IHDR, IDAT and IEND chunks.
        png_bytes = (
            b"\x89PNG\r\n\x1a\n"
            b"\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01"
            b"\x08\x02\x00\x00\x00\x90wS\xde"
            b"\x00\x00\x00\x0cIDATx\x9cc\x00\x01\x00\x00\x05\x00\x01"
            b"\r\n-\xb4\x00\x00\x00\x00IEND\xaeB`\x82"
        )
        image_path = tmp_path / "test.png"
        image_path.write_bytes(png_bytes)

        outcome = extract_text(image_path)
        assert outcome is not None
        assert "[image:" in outcome
        assert "test.png" in outcome
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsTextExtension:
    """Checks for the ``_is_text_extension`` helper."""

    def test_text_extensions_return_true(self):
        """Every known plain-text suffix is reported as text."""
        text_suffixes = (
            ".txt",
            ".md",
            ".csv",
            ".json",
            ".yaml",
            ".yml",
            ".xml",
            ".html",
            ".htm",
        )
        for suffix in text_suffixes:
            assert _is_text_extension(suffix) is True

    def test_non_text_extensions_return_false(self):
        """Binary document, image and unknown suffixes are not reported as text."""
        other_suffixes = (".pdf", ".docx", ".xlsx", ".pptx", ".png", ".xyz")
        for suffix in other_suffixes:
            assert _is_text_extension(suffix) is False

    def test_case_sensitivity(self):
        """The helper matches lowercase suffixes only.

        NOTE(review): callers such as ``extract_text`` are expected to
        lowercase the suffix before calling the helper - not verified here.
        """
        assert _is_text_extension(".txt") is True
        # The upper-case form of a text suffix is rejected.
        assert _is_text_extension(".TXT") is False
        assert _is_text_extension(".pdf") is False
|
||||||
@ -194,6 +194,7 @@ async def test_successful_request_uses_fixed_api_session(aiohttp_client, mock_ag
|
|||||||
assert body["model"] == "test-model"
|
assert body["model"] == "test-model"
|
||||||
mock_agent.process_direct.assert_called_once_with(
|
mock_agent.process_direct.assert_called_once_with(
|
||||||
content="hello",
|
content="hello",
|
||||||
|
media=None,
|
||||||
session_key=API_SESSION_KEY,
|
session_key=API_SESSION_KEY,
|
||||||
channel="api",
|
channel="api",
|
||||||
chat_id=API_CHAT_ID,
|
chat_id=API_CHAT_ID,
|
||||||
@ -205,7 +206,7 @@ async def test_successful_request_uses_fixed_api_session(aiohttp_client, mock_ag
|
|||||||
async def test_followup_requests_share_same_session_key(aiohttp_client) -> None:
|
async def test_followup_requests_share_same_session_key(aiohttp_client) -> None:
|
||||||
call_log: list[str] = []
|
call_log: list[str] = []
|
||||||
|
|
||||||
async def fake_process(content, session_key="", channel="", chat_id=""):
|
async def fake_process(content, session_key="", channel="", chat_id="", **kwargs):
|
||||||
call_log.append(session_key)
|
call_log.append(session_key)
|
||||||
return f"reply to {content}"
|
return f"reply to {content}"
|
||||||
|
|
||||||
@ -236,7 +237,7 @@ async def test_followup_requests_share_same_session_key(aiohttp_client) -> None:
|
|||||||
async def test_fixed_session_requests_are_serialized(aiohttp_client) -> None:
|
async def test_fixed_session_requests_are_serialized(aiohttp_client) -> None:
|
||||||
order: list[str] = []
|
order: list[str] = []
|
||||||
|
|
||||||
async def slow_process(content, session_key="", channel="", chat_id=""):
|
async def slow_process(content, session_key="", channel="", chat_id="", **kwargs):
|
||||||
order.append(f"start:{content}")
|
order.append(f"start:{content}")
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
order.append(f"end:{content}")
|
order.append(f"end:{content}")
|
||||||
@ -307,12 +308,12 @@ async def test_multimodal_content_extracts_text(aiohttp_client, mock_agent) -> N
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
mock_agent.process_direct.assert_called_once_with(
|
call_kwargs = mock_agent.process_direct.call_args.kwargs
|
||||||
content="describe this",
|
assert call_kwargs["content"] == "describe this"
|
||||||
session_key=API_SESSION_KEY,
|
assert call_kwargs["session_key"] == API_SESSION_KEY
|
||||||
channel="api",
|
assert call_kwargs["channel"] == "api"
|
||||||
chat_id=API_CHAT_ID,
|
assert call_kwargs["chat_id"] == API_CHAT_ID
|
||||||
)
|
assert len(call_kwargs.get("media") or []) >= 0 # base64 images saved to disk
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||||
@ -320,7 +321,7 @@ async def test_multimodal_content_extracts_text(aiohttp_client, mock_agent) -> N
|
|||||||
async def test_empty_response_retry_then_success(aiohttp_client) -> None:
|
async def test_empty_response_retry_then_success(aiohttp_client) -> None:
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def sometimes_empty(content, session_key="", channel="", chat_id=""):
|
async def sometimes_empty(content, session_key="", channel="", chat_id="", **kwargs):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
if call_count == 1:
|
if call_count == 1:
|
||||||
@ -351,7 +352,7 @@ async def test_empty_response_falls_back(aiohttp_client) -> None:
|
|||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def always_empty(content, session_key="", channel="", chat_id=""):
|
async def always_empty(content, session_key="", channel="", chat_id="", **kwargs):
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
return ""
|
return ""
|
||||||
@ -371,3 +372,31 @@ async def test_empty_response_falls_back(aiohttp_client) -> None:
|
|||||||
body = await resp.json()
|
body = await resp.json()
|
||||||
assert body["choices"][0]["message"]["content"] == EMPTY_FINAL_RESPONSE_MESSAGE
|
assert body["choices"][0]["message"]["content"] == EMPTY_FINAL_RESPONSE_MESSAGE
|
||||||
assert call_count == 2
|
assert call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_process_direct_accepts_media() -> None:
    """``process_direct`` must hand its ``media`` list to ``_process_message``."""
    from nanobot.agent.loop import AgentLoop

    # Build the loop without running __init__, then stub what the test touches.
    agent_loop = AgentLoop.__new__(AgentLoop)
    agent_loop._connect_mcp = AsyncMock()

    received: list = []

    async def fake_process(msg, *, session_key="", on_progress=None, on_stream=None, on_stream_end=None):
        # Record the message passed in by process_direct.
        received.append(msg)
        return None

    agent_loop._process_message = fake_process

    media_paths = ["/tmp/image.png", "/tmp/report.pdf"]
    await agent_loop.process_direct(
        content="analyze this",
        media=media_paths,
        session_key="test:1",
    )

    assert received
    forwarded_msg = received[-1]
    assert forwarded_msg is not None
    assert forwarded_msg.media == ["/tmp/image.png", "/tmp/report.pdf"]
    assert forwarded_msg.content == "analyze this"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user