# nanobot/nanobot/api/server.py
#
# 401 lines
# 14 KiB
# Python

"""OpenAI-compatible HTTP API server for a fixed nanobot session.
Provides /v1/chat/completions and /v1/models endpoints.
All requests route to a single persistent API session.
"""
from __future__ import annotations
import asyncio
import contextlib
import json as _json
import time
import uuid
from typing import Any
from aiohttp import web
from loguru import logger
from nanobot.config.paths import get_media_dir
from nanobot.utils.helpers import safe_filename
from nanobot.utils.media_decode import (
MAX_FILE_SIZE,
)
from nanobot.utils.media_decode import (
FileSizeExceeded as _FileSizeExceeded,
)
from nanobot.utils.media_decode import (
save_base64_data_url as _save_base64_data_url,
)
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
# Names exported by a star-import of this module: the media_decode names
# imported above plus this module's two public entry points.
__all__ = (
    "MAX_FILE_SIZE",
    "_FileSizeExceeded",
    "_save_base64_data_url",
    "create_app",
    "handle_chat_completions",
)
# Session key used when a request does not supply a session_id.
API_SESSION_KEY = "api:default"
# chat_id passed to the agent loop for every API request.
API_CHAT_ID = "default"
# ---------------------------------------------------------------------------
# Response helpers
# ---------------------------------------------------------------------------
def _error_json(status: int, message: str, err_type: str = "invalid_request_error") -> web.Response:
    """Build an OpenAI-style JSON error response.

    Args:
        status: HTTP status code; also echoed as the error ``code``.
        message: Human-readable description of the failure.
        err_type: Value placed in the error ``type`` field.

    Returns:
        A JSON response with body ``{"error": {...}}`` and the given status.
    """
    error_detail = {"message": message, "type": err_type, "code": status}
    envelope = {"error": error_detail}
    return web.json_response(envelope, status=status)
def _chat_completion_response(content: str, model: str) -> dict[str, Any]:
    """Assemble a non-streaming ``chat.completion`` payload.

    Args:
        content: Assistant text placed in the single choice.
        model: Model name echoed back in the payload.

    Returns:
        A dict with a fresh ``chatcmpl-`` id, the current Unix timestamp, one
        assistant choice finished with ``"stop"``, and an all-zero usage block.
    """
    completion_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
    created_at = int(time.time())

    assistant_message = {"role": "assistant", "content": content}
    only_choice = {
        "index": 0,
        "message": assistant_message,
        "finish_reason": "stop",
    }
    # Usage is always reported as zeros by this helper.
    zero_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}

    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": created_at,
        "model": model,
        "choices": [only_choice],
        "usage": zero_usage,
    }
def _response_text(value: Any) -> str:
    """Normalize process_direct output to plain assistant text.

    ``None`` becomes an empty string. Objects exposing a ``content``
    attribute contribute that attribute (falsy content becomes ``""``).
    Anything else is passed through ``str()``.
    """
    if value is None:
        return ""
    if not hasattr(value, "content"):
        return str(value)
    inner = value.content
    return str(inner) if inner else ""
# ---------------------------------------------------------------------------
# SSE helpers
# ---------------------------------------------------------------------------
def _sse_chunk(delta: str, model: str, chunk_id: str, finish_reason: str | None = None) -> bytes:
    """Format a single OpenAI-compatible SSE chunk.

    Args:
        delta: Text fragment for this chunk; an empty string yields an empty
            ``delta`` object instead of ``{"content": ""}``.
        model: Model name echoed in the chunk.
        chunk_id: Completion id shared by every chunk of one stream.
        finish_reason: ``None`` for intermediate chunks, or the reason the
            stream ended (e.g. ``"stop"``).

    Returns:
        The encoded ``data: <json>`` line followed by a blank line.
    """
    delta_obj: dict[str, str] = {}
    if delta:
        delta_obj["content"] = delta

    choice = {"index": 0, "delta": delta_obj, "finish_reason": finish_reason}
    event = {
        "id": chunk_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [choice],
    }
    line = "data: " + _json.dumps(event) + "\n\n"
    return line.encode()
# Terminal SSE event, written after the final "stop" chunk of a stream that
# completed without error.
_SSE_DONE = b"data: [DONE]\n\n"
# ---------------------------------------------------------------------------
# Upload helpers
# ---------------------------------------------------------------------------
def _parse_json_content(body: dict) -> tuple[str, list[str]]:
    """Parse JSON request body. Returns (text, media_paths).

    The body must carry exactly one message with role ``"user"``. Its
    ``content`` is either a plain string or a list of OpenAI-style content
    parts. ``text`` parts are joined with single spaces. ``image_url`` parts
    holding a base64 ``data:`` URL are saved via ``_save_base64_data_url``.
    List entries that are not dicts, and parts of any other type, are ignored.

    Args:
        body: Decoded JSON request body.

    Returns:
        ``(text, media_paths)`` where ``media_paths`` lists the saved images.

    Raises:
        ValueError: If the body is not an object, does not hold exactly one
            user message, has content that is neither a string nor a list,
            has a malformed ``text`` / ``image_url`` part, or references a
            remote (non-``data:``) image URL.
    """
    # Everything below is client-controlled JSON. Malformed shapes are
    # reported as ValueError so the caller can answer 400, rather than
    # leaking AttributeError/TypeError.
    if not isinstance(body, dict):
        raise ValueError("JSON body must be an object")

    messages = body.get("messages")
    if not isinstance(messages, list) or len(messages) != 1:
        raise ValueError("Only a single user message is supported")
    message = messages[0]
    if not isinstance(message, dict) or message.get("role") != "user":
        raise ValueError("Only a single user message is supported")

    user_content = message.get("content", "")
    if isinstance(user_content, str):
        return user_content, []
    if not isinstance(user_content, list):
        raise ValueError("Invalid content format")

    # The media directory is resolved only once an inline image shows up, so
    # text-only requests never touch it.
    media_dir = None
    media_paths: list[str] = []
    text_parts: list[str] = []

    for part in user_content:
        if not isinstance(part, dict):
            continue
        part_type = part.get("type")
        if part_type == "text":
            piece = part.get("text", "")
            if not isinstance(piece, str):
                raise ValueError("Invalid content format: 'text' must be a string")
            text_parts.append(piece)
        elif part_type == "image_url":
            image = part.get("image_url", {})
            url = image.get("url", "") if isinstance(image, dict) else None
            if not isinstance(url, str):
                raise ValueError(
                    "Invalid content format: 'image_url' must be an object with a string 'url'"
                )
            if url.startswith("data:"):
                if media_dir is None:
                    media_dir = get_media_dir("api")
                saved = _save_base64_data_url(url, media_dir)
                # A falsy result means nothing was saved; the part is skipped.
                if saved:
                    media_paths.append(saved)
            elif url:
                raise ValueError(
                    "Remote image URLs are not supported. "
                    "Use base64 data URLs or upload files via multipart/form-data."
                )

    return " ".join(text_parts), media_paths
async def _read_file_part(part: Any) -> bytes:
    """Read one uploaded file part, aborting once it exceeds MAX_FILE_SIZE.

    The part is consumed chunk by chunk, so an oversized upload is rejected
    as soon as the limit is crossed instead of after it is fully buffered.

    Raises:
        _FileSizeExceeded: If the part holds more than ``MAX_FILE_SIZE`` bytes.
    """
    buffer = bytearray()
    while True:
        chunk = await part.read_chunk()
        if not chunk:
            break
        buffer.extend(chunk)
        if len(buffer) > MAX_FILE_SIZE:
            raise _FileSizeExceeded(
                f"File '{part.filename}' exceeds {MAX_FILE_SIZE // (1024 * 1024)}MB limit"
            )
    return bytes(buffer)


async def _parse_multipart(request: web.Request) -> tuple[str, list[str], str | None, str | None]:
    """Parse multipart/form-data. Returns (text, media_paths, session_id, model).

    Recognized form fields:
        * ``message``: the user text (UTF-8).
        * ``session_id`` / ``model``: optional, whitespace-stripped strings.
        * ``files``: uploaded files (the field may repeat). Each file is
          written to the API media directory under a random-prefixed,
          sanitized name.

    Parts with any other name are skipped.

    Returns:
        ``(text, media_paths, session_id, model)``. When no text was
        supplied, a default prompt asking to analyze the uploaded files is
        used instead.

    Raises:
        _FileSizeExceeded: If an uploaded file is larger than ``MAX_FILE_SIZE``.
        UnicodeDecodeError: If a text field is not valid UTF-8 (a
            ``ValueError`` subclass).
    """
    media_dir = get_media_dir("api")
    reader = await request.multipart()

    text = ""
    session_id: str | None = None
    model: str | None = None
    media_paths: list[str] = []

    while True:
        part = await reader.next()
        if part is None:
            break

        # NOTE(review): the text fields below are still read without a size
        # bound; only file parts are capped.
        if part.name == "message":
            text = (await part.read()).decode("utf-8")
        elif part.name == "session_id":
            session_id = (await part.read()).decode("utf-8").strip()
        elif part.name == "model":
            model = (await part.read()).decode("utf-8").strip()
        elif part.name == "files":
            raw = await _read_file_part(part)
            base = safe_filename(part.filename or "upload.bin")
            # The random prefix keeps same-named uploads from overwriting
            # each other.
            filename = f"{uuid.uuid4().hex[:12]}_{base}"
            dest = media_dir / filename
            dest.write_bytes(raw)
            media_paths.append(str(dest))

    if not text:
        # Default prompt (Chinese for "Please analyze the uploaded file").
        text = "请分析上传的文件"
    return text, media_paths, session_id, model
# ---------------------------------------------------------------------------
# Route handlers
# ---------------------------------------------------------------------------
async def handle_chat_completions(request: web.Request) -> web.Response:
    """POST /v1/chat/completions — supports JSON and multipart/form-data.

    The request is reduced to one user text plus optional media files and
    forwarded to ``agent_loop.process_direct`` while holding the lock of the
    target session, so requests for the same session run one at a time.

    Request handling:
        * ``multipart/*`` bodies go through ``_parse_multipart``; everything
          else is decoded as JSON and goes through ``_parse_json_content``.
        * ``stream`` is only read from JSON bodies; multipart requests are
          never streamed.
        * ``session_id`` selects the session ``api:<session_id>``; without it
          the shared ``API_SESSION_KEY`` session is used.

    Responses:
        * 200 with a ``chat.completion`` document, or an SSE stream of
          ``chat.completion.chunk`` events when ``stream`` is truthy.
        * 400 for malformed bodies (including JSON that is not an object) or
          a requested model other than the configured one.
        * 413 for oversized or otherwise unparsable uploads.
        * 504 when the agent exceeds the request timeout (non-streaming).
        * 500 for any other agent failure (non-streaming).

    A streaming request that fails or times out is logged, and the stream
    ends without the final ``stop`` chunk or ``[DONE]`` marker.
    """
    content_type = request.content_type or ""
    if not isinstance(content_type, str):
        content_type = ""

    agent_loop = request.app["agent_loop"]
    timeout_s: float = request.app.get("request_timeout", 120.0)
    model_name: str = request.app.get("model_name", "nanobot")

    stream = False
    try:
        if content_type.startswith("multipart/"):
            text, media_paths, session_id, requested_model = await _parse_multipart(request)
        else:
            try:
                body = await request.json()
            except Exception:
                return _error_json(400, "Invalid JSON body")
            # Valid JSON need not be an object ([], "x", null, ...). Reject it
            # here so the .get() calls below cannot raise and be misreported
            # as an upload failure.
            if not isinstance(body, dict):
                return _error_json(400, "JSON body must be an object")
            stream = body.get("stream", False)
            requested_model = body.get("model")
            text, media_paths = _parse_json_content(body)
            session_id = body.get("session_id")
    except ValueError as e:
        return _error_json(400, str(e))
    except _FileSizeExceeded as e:
        return _error_json(413, str(e), err_type="invalid_request_error")
    except Exception:
        logger.exception("Error parsing upload")
        return _error_json(413, "File too large or invalid upload")

    if requested_model and requested_model != model_name:
        return _error_json(400, f"Only configured model '{model_name}' is available")

    session_key = f"api:{session_id}" if session_id else API_SESSION_KEY
    session_locks: dict[str, asyncio.Lock] = request.app["session_locks"]
    session_lock = session_locks.setdefault(session_key, asyncio.Lock())

    logger.info(
        "API request session_key={} media={} text={} stream={}",
        session_key, len(media_paths), text[:80], stream,
    )

    async def _ask_agent(**stream_callbacks: Any) -> Any:
        """Run one agent turn for this request, bounded by the request timeout."""
        return await asyncio.wait_for(
            agent_loop.process_direct(
                content=text,
                media=media_paths if media_paths else None,
                session_key=session_key,
                channel="api",
                chat_id=API_CHAT_ID,
                **stream_callbacks,
            ),
            timeout=timeout_s,
        )

    # -- streaming path --
    if stream:
        resp = web.StreamResponse()
        resp.content_type = "text/event-stream"
        resp.headers["Cache-Control"] = "no-cache"
        resp.headers["Connection"] = "keep-alive"
        resp.enable_compression()
        await resp.prepare(request)

        chunk_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
        # Tokens flow from the agent task to this handler through the queue;
        # ``None`` is the end-of-stream sentinel pushed by _run().
        queue: asyncio.Queue[str | None] = asyncio.Queue()
        stream_failed = False
        emitted_content = False

        async def _on_stream(token: str) -> None:
            nonlocal emitted_content
            if token:
                emitted_content = True
                await queue.put(token)

        async def _on_stream_end(*_a: Any, **_kw: Any) -> None:
            # Agent stream-end callbacks mark generation segment boundaries.
            # Tool-backed requests may continue after a segment ends, so the
            # HTTP SSE stream is closed only when process_direct returns.
            return None

        async def _run() -> None:
            nonlocal stream_failed
            try:
                async with session_lock:
                    response = await _ask_agent(
                        on_stream=_on_stream,
                        on_stream_end=_on_stream_end,
                    )
                # Nothing was streamed: fall back to sending the final
                # response text as a single chunk.
                if not emitted_content:
                    response_text = _response_text(response)
                    if response_text.strip():
                        await queue.put(response_text)
            except Exception:
                stream_failed = True
                logger.exception("Streaming error for session {}", session_key)
            finally:
                await queue.put(None)

        task = asyncio.create_task(_run())
        try:
            while True:
                token = await queue.get()
                if token is None:
                    break
                await resp.write(_sse_chunk(token, model_name, chunk_id))
        finally:
            # Leaving the loop early (e.g. a failed write) must not leave the
            # agent task running in the background.
            if not task.done():
                task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await task

        if not stream_failed:
            await resp.write(_sse_chunk("", model_name, chunk_id, finish_reason="stop"))
            await resp.write(_SSE_DONE)
        return resp

    # -- non-streaming path (original logic) --
    fallback = EMPTY_FINAL_RESPONSE_MESSAGE
    try:
        async with session_lock:
            try:
                response = await _ask_agent()
                response_text = _response_text(response)
                if not response_text or not response_text.strip():
                    logger.warning("Empty response for session {}, retrying", session_key)
                    retry_response = await _ask_agent()
                    response_text = _response_text(retry_response)
                    if not response_text or not response_text.strip():
                        logger.warning("Empty response after retry, using fallback")
                        response_text = fallback
            except asyncio.TimeoutError:
                return _error_json(504, f"Request timed out after {timeout_s}s")
            except Exception:
                logger.exception("Error processing request for session {}", session_key)
                return _error_json(500, "Internal server error", err_type="server_error")
    except Exception:
        logger.exception("Unexpected API lock error for session {}", session_key)
        return _error_json(500, "Internal server error", err_type="server_error")

    return web.json_response(_chat_completion_response(response_text, model_name))
async def handle_models(request: web.Request) -> web.Response:
    """GET /v1/models — list the single configured model."""
    configured = request.app.get("model_name", "nanobot")
    model_entry = {
        "id": configured,
        "object": "model",
        "created": 0,
        "owned_by": "nanobot",
    }
    listing = {"object": "list", "data": [model_entry]}
    return web.json_response(listing)
async def handle_health(request: web.Request) -> web.Response:
    """GET /health — always answers ``{"status": "ok"}``."""
    status_body = {"status": "ok"}
    return web.json_response(status_body)
# ---------------------------------------------------------------------------
# App factory
# ---------------------------------------------------------------------------
def create_app(
    agent_loop, model_name: str = "nanobot", request_timeout: float = 120.0
) -> web.Application:
    """Build the aiohttp application serving the API.

    Args:
        agent_loop: An initialized AgentLoop instance.
        model_name: Model name reported in responses.
        request_timeout: Per-request timeout in seconds.

    Returns:
        An application with the shared state stored on it and the chat
        completions, models and health routes registered.
    """
    max_body_bytes = 20 * 1024 * 1024  # 20MB for base64 images
    app = web.Application(client_max_size=max_body_bytes)

    shared_state = {
        "agent_loop": agent_loop,
        "model_name": model_name,
        "request_timeout": request_timeout,
        "session_locks": {},  # per-user locks, keyed by session_key
    }
    for key, value in shared_state.items():
        app[key] = value

    router = app.router
    router.add_post("/v1/chat/completions", handle_chat_completions)
    router.add_get("/v1/models", handle_models)
    router.add_get("/health", handle_health)
    return app