mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 00:22:31 +00:00
401 lines
14 KiB
Python
401 lines
14 KiB
Python
"""OpenAI-compatible HTTP API server for a fixed nanobot session.
|
|
|
|
Provides /v1/chat/completions and /v1/models endpoints.
|
|
All requests route to a single persistent API session.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import contextlib
|
|
import json as _json
|
|
import time
|
|
import uuid
|
|
from typing import Any
|
|
|
|
from aiohttp import web
|
|
from loguru import logger
|
|
|
|
from nanobot.config.paths import get_media_dir
|
|
from nanobot.utils.helpers import safe_filename
|
|
from nanobot.utils.media_decode import (
|
|
MAX_FILE_SIZE,
|
|
)
|
|
from nanobot.utils.media_decode import (
|
|
FileSizeExceeded as _FileSizeExceeded,
|
|
)
|
|
from nanobot.utils.media_decode import (
|
|
save_base64_data_url as _save_base64_data_url,
|
|
)
|
|
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
|
|
|
# Names exported by ``from <this module> import *``. The three media_decode
# names are imported above (two of them under underscore aliases) and are
# listed here so they are re-exported from this module as well.
__all__ = (
    "MAX_FILE_SIZE",
    "_FileSizeExceeded",
    "_save_base64_data_url",
    "create_app",
    "handle_chat_completions",
)


# Session key used when a request does not carry its own session_id.
API_SESSION_KEY = "api:default"
# chat_id passed to the agent loop for every API request.
API_CHAT_ID = "default"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Response helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _error_json(status: int, message: str, err_type: str = "invalid_request_error") -> web.Response:
    """Build an OpenAI-style JSON error response.

    ``status`` is used both as the HTTP status of the response and as the
    ``code`` field inside the ``error`` object.
    """
    error_body = {"message": message, "type": err_type, "code": status}
    return web.json_response({"error": error_body}, status=status)
|
|
|
|
|
|
def _chat_completion_response(content: str, model: str) -> dict[str, Any]:
|
|
return {
|
|
"id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
|
|
"object": "chat.completion",
|
|
"created": int(time.time()),
|
|
"model": model,
|
|
"choices": [
|
|
{
|
|
"index": 0,
|
|
"message": {"role": "assistant", "content": content},
|
|
"finish_reason": "stop",
|
|
}
|
|
],
|
|
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
|
}
|
|
|
|
|
|
def _response_text(value: Any) -> str:
|
|
"""Normalize process_direct output to plain assistant text."""
|
|
if value is None:
|
|
return ""
|
|
if hasattr(value, "content"):
|
|
return str(getattr(value, "content") or "")
|
|
return str(value)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# SSE helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _sse_chunk(delta: str, model: str, chunk_id: str, finish_reason: str | None = None) -> bytes:
|
|
"""Format a single OpenAI-compatible SSE chunk."""
|
|
payload = {
|
|
"id": chunk_id,
|
|
"object": "chat.completion.chunk",
|
|
"created": int(time.time()),
|
|
"model": model,
|
|
"choices": [
|
|
{
|
|
"index": 0,
|
|
"delta": {"content": delta} if delta else {},
|
|
"finish_reason": finish_reason,
|
|
}
|
|
],
|
|
}
|
|
return f"data: {_json.dumps(payload)}\n\n".encode()
|
|
|
|
|
|
_SSE_DONE = b"data: [DONE]\n\n"
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Upload helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _parse_json_content(body: dict) -> tuple[str, list[str]]:
|
|
"""Parse JSON request body. Returns (text, media_paths)."""
|
|
messages = body.get("messages")
|
|
if not isinstance(messages, list) or len(messages) != 1:
|
|
raise ValueError("Only a single user message is supported")
|
|
message = messages[0]
|
|
if not isinstance(message, dict) or message.get("role") != "user":
|
|
raise ValueError("Only a single user message is supported")
|
|
|
|
user_content = message.get("content", "")
|
|
media_dir = get_media_dir("api")
|
|
media_paths: list[str] = []
|
|
|
|
if isinstance(user_content, list):
|
|
text_parts: list[str] = []
|
|
for part in user_content:
|
|
if not isinstance(part, dict):
|
|
continue
|
|
if part.get("type") == "text":
|
|
text_parts.append(part.get("text", ""))
|
|
elif part.get("type") == "image_url":
|
|
url = part.get("image_url", {}).get("url", "")
|
|
if url.startswith("data:"):
|
|
saved = _save_base64_data_url(url, media_dir)
|
|
if saved:
|
|
media_paths.append(saved)
|
|
elif url:
|
|
raise ValueError(
|
|
"Remote image URLs are not supported. "
|
|
"Use base64 data URLs or upload files via multipart/form-data."
|
|
)
|
|
text = " ".join(text_parts)
|
|
elif isinstance(user_content, str):
|
|
text = user_content
|
|
else:
|
|
raise ValueError("Invalid content format")
|
|
|
|
return text, media_paths
|
|
|
|
|
|
async def _read_part_limited(part: Any, limit: int) -> bytes:
    """Read one multipart body part, aborting once it grows past ``limit`` bytes."""
    buf = bytearray()
    while not part.at_eof():
        buf.extend(await part.read_chunk())
        if len(buf) > limit:
            raise _FileSizeExceeded(
                f"File '{part.filename}' exceeds {limit // (1024 * 1024)}MB limit"
            )
    return bytes(buf)


async def _parse_multipart(request: web.Request) -> tuple[str, list[str], str | None, str | None]:
    """Parse multipart/form-data. Returns (text, media_paths, session_id, model).

    Recognised form fields are ``message``, ``session_id``, ``model`` and
    ``files`` (which may repeat); parts with any other name are ignored. Each
    uploaded file is written to the API media directory under a random prefix
    followed by its sanitized file name. When no message text was sent, a
    default prompt is substituted.

    Raises:
        _FileSizeExceeded: If an uploaded file is larger than ``MAX_FILE_SIZE``.
        ValueError: If a text field is not valid UTF-8 (``UnicodeDecodeError``).
    """
    media_dir = get_media_dir("api")
    reader = await request.multipart()
    text = ""
    session_id = None
    model = None
    saved_files = []

    try:
        while True:
            part = await reader.next()
            if part is None:
                break
            if part.name == "message":
                text = (await part.read()).decode("utf-8")
            elif part.name == "session_id":
                session_id = (await part.read()).decode("utf-8").strip()
            elif part.name == "model":
                model = (await part.read()).decode("utf-8").strip()
            elif part.name == "files":
                # Enforce the limit while reading, so an oversized upload is
                # rejected without first being buffered in full.
                raw = await _read_part_limited(part, MAX_FILE_SIZE)
                base = safe_filename(part.filename or "upload.bin")
                dest = media_dir / f"{uuid.uuid4().hex[:12]}_{base}"
                # Track the file before writing so a failed write is cleaned up too.
                saved_files.append(dest)
                dest.write_bytes(raw)
    except BaseException:
        # The request is being rejected (or was cancelled): remove what was
        # already written, then let the original error propagate.
        for leftover in saved_files:
            with contextlib.suppress(OSError):
                leftover.unlink()
        raise

    if not text:
        text = "请分析上传的文件"

    return text, [str(path) for path in saved_files], session_id, model
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Route handlers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
async def handle_chat_completions(request: web.Request) -> web.Response:
    """POST /v1/chat/completions — supports JSON and multipart/form-data.

    The request is parsed into one user message plus optional media files and
    passed to the agent loop while holding a per-session lock, so requests
    that share a session key are processed one at a time.

    With a truthy ``stream`` field (JSON bodies only) the reply is written as
    OpenAI-style SSE chunks on a ``web.StreamResponse``; otherwise a single
    ``chat.completion`` JSON object is returned. An empty non-streaming reply
    is retried once before the fallback message is used.

    Error responses:
        400: invalid JSON, a body that is not a JSON object, invalid message
            content, or a requested model other than the configured one.
        413: an upload over the size limit, or any other failure while parsing.
        504: the agent did not answer within the timeout (non-streaming only).
        500: any other failure while processing (non-streaming only).
    """
    content_type = request.content_type or ""
    if not isinstance(content_type, str):
        content_type = ""

    agent_loop = request.app["agent_loop"]
    timeout_s: float = request.app.get("request_timeout", 120.0)
    model_name: str = request.app.get("model_name", "nanobot")

    stream = False
    try:
        if content_type.startswith("multipart/"):
            text, media_paths, session_id, requested_model = await _parse_multipart(request)
        else:
            try:
                body = await request.json()
            except Exception:
                return _error_json(400, "Invalid JSON body")
            # Valid JSON that is not an object (e.g. a list) would otherwise
            # fail on ``body.get`` and be reported as a 413 upload error.
            if not isinstance(body, dict):
                return _error_json(400, "Request body must be a JSON object")
            stream = body.get("stream", False)
            requested_model = body.get("model")
            text, media_paths = _parse_json_content(body)
            session_id = body.get("session_id")
    except ValueError as e:
        return _error_json(400, str(e))
    except _FileSizeExceeded as e:
        return _error_json(413, str(e), err_type="invalid_request_error")
    except Exception:
        logger.exception("Error parsing upload")
        return _error_json(413, "File too large or invalid upload")

    if requested_model and requested_model != model_name:
        return _error_json(400, f"Only configured model '{model_name}' is available")

    session_key = f"api:{session_id}" if session_id else API_SESSION_KEY
    session_locks: dict[str, asyncio.Lock] = request.app["session_locks"]
    # Create the lock only on first use of a session key. There is no await
    # between the lookup and the insert, so two requests cannot both create one.
    session_lock = session_locks.get(session_key)
    if session_lock is None:
        session_lock = session_locks[session_key] = asyncio.Lock()

    logger.info(
        "API request session_key={} media={} text={} stream={}",
        session_key, len(media_paths), text[:80], stream,
    )

    def _ask_agent(**callbacks: Any) -> Any:
        """Return an awaitable running the agent on this request under the timeout."""
        return asyncio.wait_for(
            agent_loop.process_direct(
                content=text,
                media=media_paths if media_paths else None,
                session_key=session_key,
                channel="api",
                chat_id=API_CHAT_ID,
                **callbacks,
            ),
            timeout=timeout_s,
        )

    # -- streaming path --
    if stream:
        resp = web.StreamResponse()
        resp.content_type = "text/event-stream"
        resp.headers["Cache-Control"] = "no-cache"
        resp.headers["Connection"] = "keep-alive"
        # NOTE(review): response compression may hold back small SSE frames
        # until enough data accumulates — confirm tokens still reach clients
        # incrementally when they send Accept-Encoding.
        resp.enable_compression()
        await resp.prepare(request)

        chunk_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
        # Tokens flow from the agent task to the writer loop; None ends the stream.
        queue: asyncio.Queue[str | None] = asyncio.Queue()
        stream_failed = False
        emitted_content = False

        async def _on_stream(token: str) -> None:
            nonlocal emitted_content
            if token:
                emitted_content = True
                await queue.put(token)

        async def _on_stream_end(*_a: Any, **_kw: Any) -> None:
            # Agent stream-end callbacks mark generation segment boundaries.
            # Tool-backed requests may continue after a segment ends, so the
            # HTTP SSE stream is closed only when process_direct returns.
            return None

        async def _run() -> None:
            nonlocal stream_failed
            try:
                async with session_lock:
                    response = await _ask_agent(
                        on_stream=_on_stream,
                        on_stream_end=_on_stream_end,
                    )
                    # Nothing was streamed: send the final text as one chunk.
                    if not emitted_content:
                        response_text = _response_text(response)
                        if response_text.strip():
                            await queue.put(response_text)
            except Exception:
                stream_failed = True
                logger.exception("Streaming error for session {}", session_key)
            finally:
                await queue.put(None)

        task = asyncio.create_task(_run())
        try:
            while True:
                token = await queue.get()
                if token is None:
                    break
                await resp.write(_sse_chunk(token, model_name, chunk_id))
        finally:
            # Reached early when writing fails: stop the agent task rather
            # than letting it run on unobserved.
            if not task.done():
                task.cancel()
                with contextlib.suppress(asyncio.CancelledError):
                    await task

        # On failure the stream simply ends without the stop chunk or [DONE].
        if not stream_failed:
            await resp.write(_sse_chunk("", model_name, chunk_id, finish_reason="stop"))
            await resp.write(_SSE_DONE)
        return resp

    # -- non-streaming path (original logic) --
    fallback = EMPTY_FINAL_RESPONSE_MESSAGE

    try:
        async with session_lock:
            try:
                response = await _ask_agent()
                response_text = _response_text(response)

                if not response_text or not response_text.strip():
                    logger.warning("Empty response for session {}, retrying", session_key)
                    retry_response = await _ask_agent()
                    response_text = _response_text(retry_response)
                    if not response_text or not response_text.strip():
                        logger.warning("Empty response after retry, using fallback")
                        response_text = fallback

            except asyncio.TimeoutError:
                return _error_json(504, f"Request timed out after {timeout_s}s")
            except Exception:
                logger.exception("Error processing request for session {}", session_key)
                return _error_json(500, "Internal server error", err_type="server_error")
    except Exception:
        logger.exception("Unexpected API lock error for session {}", session_key)
        return _error_json(500, "Internal server error", err_type="server_error")

    return web.json_response(_chat_completion_response(response_text, model_name))
|
|
|
|
|
|
async def handle_models(request: web.Request) -> web.Response:
    """GET /v1/models — list the single configured model."""
    model_entry = {
        "id": request.app.get("model_name", "nanobot"),
        "object": "model",
        "created": 0,
        "owned_by": "nanobot",
    }
    listing = {"object": "list", "data": [model_entry]}
    return web.json_response(listing)
|
|
|
|
|
|
async def handle_health(request: web.Request) -> web.Response:
    """GET /health — always answers with a fixed ``ok`` status."""
    status_payload = {"status": "ok"}
    return web.json_response(status_payload)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# App factory
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def create_app(
    agent_loop, model_name: str = "nanobot", request_timeout: float = 120.0
) -> web.Application:
    """Build the aiohttp application and register its routes.

    Args:
        agent_loop: An initialized AgentLoop instance.
        model_name: Model name reported in responses.
        request_timeout: Per-request timeout in seconds.

    Returns:
        The configured application, with the chat-completions, models and
        health routes registered.
    """
    max_body_bytes = 20 * 1024 * 1024  # 20MB for base64 images
    app = web.Application(client_max_size=max_body_bytes)

    shared_state = {
        "agent_loop": agent_loop,
        "model_name": model_name,
        "request_timeout": request_timeout,
        "session_locks": {},  # per-user locks, keyed by session_key
    }
    for key, value in shared_state.items():
        app[key] = value

    router = app.router
    router.add_post("/v1/chat/completions", handle_chat_completions)
    router.add_get("/v1/models", handle_models)
    router.add_get("/health", handle_health)
    return app
|