mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 09:22:36 +00:00
feat(api): add OpenAI-compatible endpoint with x-session-key isolation
This commit is contained in:
parent
e1832e75b5
commit
80219baf25
96
examples/curl.txt
Normal file
96
examples/curl.txt
Normal file
@ -0,0 +1,96 @@
|
||||
# =============================================================================
|
||||
# nanobot OpenAI-Compatible API — curl examples
|
||||
# =============================================================================
|
||||
#
|
||||
# Prerequisites:
|
||||
# pip install nanobot-ai[api] # installs aiohttp
|
||||
# nanobot serve --port 8900 # start the API server
|
||||
#
|
||||
# The x-session-key header is REQUIRED for every request.
|
||||
# Convention:
|
||||
# Private chat: wx:dm:{sender_id}
|
||||
# Group @: wx:group:{group_id}:user:{sender_id}
|
||||
# =============================================================================
|
||||
|
||||
# --- 1. Basic chat completion (private chat) ---
|
||||
|
||||
curl -X POST http://localhost:8900/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-session-key: wx:dm:user_alice" \
|
||||
-d '{
|
||||
"model": "nanobot",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello, who are you?"}
|
||||
]
|
||||
}'
|
||||
|
||||
# --- 2. Follow-up in the same session (context is remembered) ---
|
||||
|
||||
curl -X POST http://localhost:8900/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-session-key: wx:dm:user_alice" \
|
||||
-d '{
|
||||
"model": "nanobot",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What did I just ask you?"}
|
||||
]
|
||||
}'
|
||||
|
||||
# --- 3. Different user — isolated session ---
|
||||
|
||||
curl -X POST http://localhost:8900/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-session-key: wx:dm:user_bob" \
|
||||
-d '{
|
||||
"model": "nanobot",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What did I just ask you?"}
|
||||
]
|
||||
}'
|
||||
# ↑ Bob gets a fresh context — he never asked anything before.
|
||||
|
||||
# --- 4. Group chat — per-user session within a group ---
|
||||
|
||||
curl -X POST http://localhost:8900/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-session-key: wx:group:group_abc:user:user_alice" \
|
||||
-d '{
|
||||
"model": "nanobot",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Summarize our discussion"}
|
||||
]
|
||||
}'
|
||||
|
||||
# --- 5. List available models ---
|
||||
|
||||
curl http://localhost:8900/v1/models
|
||||
|
||||
# --- 6. Health check ---
|
||||
|
||||
curl http://localhost:8900/health
|
||||
|
||||
# --- 7. Missing header — expect 400 ---
|
||||
|
||||
curl -X POST http://localhost:8900/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "nanobot",
|
||||
"messages": [
|
||||
{"role": "user", "content": "hello"}
|
||||
]
|
||||
}'
|
||||
# ↑ Returns: {"error": {"message": "Missing required header: x-session-key", ...}}
|
||||
|
||||
# --- 8. Stream not yet supported — expect 400 ---
|
||||
|
||||
curl -X POST http://localhost:8900/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "x-session-key: wx:dm:user_alice" \
|
||||
-d '{
|
||||
"model": "nanobot",
|
||||
"messages": [
|
||||
{"role": "user", "content": "hello"}
|
||||
],
|
||||
"stream": true
|
||||
}'
|
||||
# ↑ Returns: {"error": {"message": "stream=true is not supported yet...", ...}}
|
||||
@ -23,15 +23,25 @@ class ContextBuilder:
|
||||
self.memory = MemoryStore(workspace)
|
||||
self.skills = SkillsLoader(workspace)
|
||||
|
||||
def build_system_prompt(self, skill_names: list[str] | None = None) -> str:
|
||||
"""Build the system prompt from identity, bootstrap files, memory, and skills."""
|
||||
parts = [self._get_identity()]
|
||||
def build_system_prompt(
|
||||
self,
|
||||
skill_names: list[str] | None = None,
|
||||
memory_store: "MemoryStore | None" = None,
|
||||
) -> str:
|
||||
"""Build the system prompt from identity, bootstrap files, memory, and skills.
|
||||
|
||||
Args:
|
||||
memory_store: If provided, use this MemoryStore instead of the default
|
||||
workspace-level one. Used for per-session memory isolation.
|
||||
"""
|
||||
parts = [self._get_identity(memory_store=memory_store)]
|
||||
|
||||
bootstrap = self._load_bootstrap_files()
|
||||
if bootstrap:
|
||||
parts.append(bootstrap)
|
||||
|
||||
memory = self.memory.get_memory_context()
|
||||
store = memory_store or self.memory
|
||||
memory = store.get_memory_context()
|
||||
if memory:
|
||||
parts.append(f"# Memory\n\n{memory}")
|
||||
|
||||
@ -52,12 +62,19 @@ Skills with available="false" need dependencies installed first - you can try in
|
||||
|
||||
return "\n\n---\n\n".join(parts)
|
||||
|
||||
def _get_identity(self) -> str:
|
||||
def _get_identity(self, memory_store: "MemoryStore | None" = None) -> str:
|
||||
"""Get the core identity section."""
|
||||
workspace_path = str(self.workspace.expanduser().resolve())
|
||||
system = platform.system()
|
||||
runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}"
|
||||
|
||||
|
||||
if memory_store is not None:
|
||||
mem_path = str(memory_store.memory_file)
|
||||
hist_path = str(memory_store.history_file)
|
||||
else:
|
||||
mem_path = f"{workspace_path}/memory/MEMORY.md"
|
||||
hist_path = f"{workspace_path}/memory/HISTORY.md"
|
||||
|
||||
return f"""# nanobot 🐈
|
||||
|
||||
You are nanobot, a helpful AI assistant.
|
||||
@ -67,8 +84,8 @@ You are nanobot, a helpful AI assistant.
|
||||
|
||||
## Workspace
|
||||
Your workspace is at: {workspace_path}
|
||||
- Long-term memory: {workspace_path}/memory/MEMORY.md (write important facts here)
|
||||
- History log: {workspace_path}/memory/HISTORY.md (grep-searchable). Each entry starts with [YYYY-MM-DD HH:MM].
|
||||
- Long-term memory: {mem_path} (write important facts here)
|
||||
- History log: {hist_path} (grep-searchable). Each entry starts with [YYYY-MM-DD HH:MM].
|
||||
- Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md
|
||||
|
||||
## nanobot Guidelines
|
||||
@ -110,10 +127,11 @@ Reply directly with text for conversations. Only use the 'message' tool to send
|
||||
media: list[str] | None = None,
|
||||
channel: str | None = None,
|
||||
chat_id: str | None = None,
|
||||
memory_store: "MemoryStore | None" = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build the complete message list for an LLM call."""
|
||||
return [
|
||||
{"role": "system", "content": self.build_system_prompt(skill_names)},
|
||||
{"role": "system", "content": self.build_system_prompt(skill_names, memory_store=memory_store)},
|
||||
*history,
|
||||
{"role": "user", "content": self._build_runtime_context(channel, chat_id)},
|
||||
{"role": "user", "content": self._build_user_content(current_message, media)},
|
||||
|
||||
@ -174,6 +174,7 @@ class AgentLoop:
|
||||
self,
|
||||
initial_messages: list[dict],
|
||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||
disabled_tools: set[str] | None = None,
|
||||
) -> tuple[str | None, list[str], list[dict]]:
|
||||
"""Run the agent iteration loop. Returns (final_content, tools_used, messages)."""
|
||||
messages = initial_messages
|
||||
@ -181,12 +182,19 @@ class AgentLoop:
|
||||
final_content = None
|
||||
tools_used: list[str] = []
|
||||
|
||||
# Build tool definitions, filtering out disabled tools
|
||||
if disabled_tools:
|
||||
tool_defs = [d for d in self.tools.get_definitions()
|
||||
if d.get("function", {}).get("name") not in disabled_tools]
|
||||
else:
|
||||
tool_defs = self.tools.get_definitions()
|
||||
|
||||
while iteration < self.max_iterations:
|
||||
iteration += 1
|
||||
|
||||
response = await self.provider.chat(
|
||||
messages=messages,
|
||||
tools=self.tools.get_definitions(),
|
||||
tools=tool_defs,
|
||||
model=self.model,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
@ -219,7 +227,10 @@ class AgentLoop:
|
||||
tools_used.append(tool_call.name)
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.info("Tool call: {}({})", tool_call.name, args_str[:200])
|
||||
result = await self.tools.execute(tool_call.name, tool_call.arguments)
|
||||
if disabled_tools and tool_call.name in disabled_tools:
|
||||
result = f"Error: Tool '{tool_call.name}' is not available in this mode."
|
||||
else:
|
||||
result = await self.tools.execute(tool_call.name, tool_call.arguments)
|
||||
messages = self.context.add_tool_result(
|
||||
messages, tool_call.id, tool_call.name, result
|
||||
)
|
||||
@ -322,6 +333,8 @@ class AgentLoop:
|
||||
msg: InboundMessage,
|
||||
session_key: str | None = None,
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
memory_store: MemoryStore | None = None,
|
||||
disabled_tools: set[str] | None = None,
|
||||
) -> OutboundMessage | None:
|
||||
"""Process a single inbound message and return the response."""
|
||||
# System messages: parse origin from chat_id ("channel:chat_id")
|
||||
@ -336,8 +349,11 @@ class AgentLoop:
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=msg.content, channel=channel, chat_id=chat_id,
|
||||
memory_store=memory_store,
|
||||
)
|
||||
final_content, _, all_msgs = await self._run_agent_loop(
|
||||
messages, disabled_tools=disabled_tools,
|
||||
)
|
||||
final_content, _, all_msgs = await self._run_agent_loop(messages)
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self.sessions.save(session)
|
||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
||||
@ -360,7 +376,9 @@ class AgentLoop:
|
||||
if snapshot:
|
||||
temp = Session(key=session.key)
|
||||
temp.messages = list(snapshot)
|
||||
if not await self._consolidate_memory(temp, archive_all=True):
|
||||
if not await self._consolidate_memory(
|
||||
temp, archive_all=True, memory_store=memory_store,
|
||||
):
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="Memory archival failed, session not cleared. Please try again.",
|
||||
@ -393,7 +411,9 @@ class AgentLoop:
|
||||
async def _consolidate_and_unlock():
|
||||
try:
|
||||
async with lock:
|
||||
await self._consolidate_memory(session)
|
||||
await self._consolidate_memory(
|
||||
session, memory_store=memory_store,
|
||||
)
|
||||
finally:
|
||||
self._consolidating.discard(session.key)
|
||||
if not lock.locked():
|
||||
@ -416,6 +436,7 @@ class AgentLoop:
|
||||
current_message=msg.content,
|
||||
media=msg.media if msg.media else None,
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
memory_store=memory_store,
|
||||
)
|
||||
|
||||
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||
@ -428,6 +449,7 @@ class AgentLoop:
|
||||
|
||||
final_content, _, all_msgs = await self._run_agent_loop(
|
||||
initial_messages, on_progress=on_progress or _bus_progress,
|
||||
disabled_tools=disabled_tools,
|
||||
)
|
||||
|
||||
if final_content is None:
|
||||
@ -470,9 +492,30 @@ class AgentLoop:
|
||||
session.messages.append(entry)
|
||||
session.updated_at = datetime.now()
|
||||
|
||||
async def _consolidate_memory(self, session, archive_all: bool = False) -> bool:
|
||||
"""Delegate to MemoryStore.consolidate(). Returns True on success."""
|
||||
return await MemoryStore(self.workspace).consolidate(
|
||||
def _isolated_memory_store(self, session_key: str) -> MemoryStore:
|
||||
"""Return a per-session-key MemoryStore for multi-tenant isolation."""
|
||||
from nanobot.utils.helpers import safe_filename
|
||||
safe_key = safe_filename(session_key.replace(":", "_"))
|
||||
memory_dir = self.workspace / "sessions" / safe_key / "memory"
|
||||
memory_dir.mkdir(parents=True, exist_ok=True)
|
||||
store = MemoryStore.__new__(MemoryStore)
|
||||
store.memory_dir = memory_dir
|
||||
store.memory_file = memory_dir / "MEMORY.md"
|
||||
store.history_file = memory_dir / "HISTORY.md"
|
||||
return store
|
||||
|
||||
async def _consolidate_memory(
|
||||
self, session, archive_all: bool = False,
|
||||
memory_store: MemoryStore | None = None,
|
||||
) -> bool:
|
||||
"""Delegate to MemoryStore.consolidate(). Returns True on success.
|
||||
|
||||
Args:
|
||||
memory_store: If provided, consolidate into this store instead of
|
||||
the default workspace-level one.
|
||||
"""
|
||||
store = memory_store or MemoryStore(self.workspace)
|
||||
return await store.consolidate(
|
||||
session, self.provider, self.model,
|
||||
archive_all=archive_all, memory_window=self.memory_window,
|
||||
)
|
||||
@ -484,9 +527,26 @@ class AgentLoop:
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
isolate_memory: bool = False,
|
||||
disabled_tools: set[str] | None = None,
|
||||
) -> str:
|
||||
"""Process a message directly (for CLI or cron usage)."""
|
||||
"""Process a message directly (for CLI or cron usage).
|
||||
|
||||
Args:
|
||||
isolate_memory: When True, use a per-session-key memory directory
|
||||
instead of the shared workspace memory. This prevents context
|
||||
leakage between different session keys in multi-tenant (API) mode.
|
||||
disabled_tools: Tool names to exclude from the LLM tool list and
|
||||
reject at execution time. Use to block filesystem access in
|
||||
multi-tenant API mode.
|
||||
"""
|
||||
await self._connect_mcp()
|
||||
memory_store: MemoryStore | None = None
|
||||
if isolate_memory:
|
||||
memory_store = self._isolated_memory_store(session_key)
|
||||
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
||||
response = await self._process_message(msg, session_key=session_key, on_progress=on_progress)
|
||||
response = await self._process_message(
|
||||
msg, session_key=session_key, on_progress=on_progress,
|
||||
memory_store=memory_store, disabled_tools=disabled_tools,
|
||||
)
|
||||
return response.content if response else ""
|
||||
|
||||
1
nanobot/api/__init__.py
Normal file
1
nanobot/api/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""OpenAI-compatible HTTP API for nanobot."""
|
||||
222
nanobot/api/server.py
Normal file
222
nanobot/api/server.py
Normal file
@ -0,0 +1,222 @@
|
||||
"""OpenAI-compatible HTTP API server for nanobot.
|
||||
|
||||
Provides /v1/chat/completions and /v1/models endpoints.
|
||||
Session isolation is enforced via the x-session-key request header.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import web
|
||||
from loguru import logger
|
||||
|
||||
# Tools that must NOT run in multi-tenant API mode.
|
||||
# Filesystem tools allow the LLM to read/write the shared workspace (including
|
||||
# global MEMORY.md), and exec allows shell commands that can bypass filesystem
|
||||
# restrictions (e.g. `cat ~/.nanobot/workspace/memory/MEMORY.md`).
|
||||
_API_DISABLED_TOOLS: set[str] = {
|
||||
"read_file", "write_file", "edit_file", "list_dir", "exec",
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-session-key lock manager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _SessionLocks:
|
||||
"""Manages one asyncio.Lock per session key for serial execution."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._locks: dict[str, asyncio.Lock] = {}
|
||||
self._ref: dict[str, int] = {} # reference count for cleanup
|
||||
|
||||
def acquire(self, key: str) -> asyncio.Lock:
|
||||
if key not in self._locks:
|
||||
self._locks[key] = asyncio.Lock()
|
||||
self._ref[key] = 0
|
||||
self._ref[key] += 1
|
||||
return self._locks[key]
|
||||
|
||||
def release(self, key: str) -> None:
|
||||
self._ref[key] -= 1
|
||||
if self._ref[key] <= 0:
|
||||
self._locks.pop(key, None)
|
||||
self._ref.pop(key, None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _error_json(status: int, message: str, err_type: str = "invalid_request_error") -> web.Response:
|
||||
return web.json_response(
|
||||
{"error": {"message": message, "type": err_type, "code": status}},
|
||||
status=status,
|
||||
)
|
||||
|
||||
|
||||
def _chat_completion_response(content: str, model: str) -> dict[str, Any]:
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": content},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Route handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def handle_chat_completions(request: web.Request) -> web.Response:
|
||||
"""POST /v1/chat/completions"""
|
||||
|
||||
# --- x-session-key validation ---
|
||||
session_key = request.headers.get("x-session-key", "").strip()
|
||||
if not session_key:
|
||||
return _error_json(400, "Missing required header: x-session-key")
|
||||
|
||||
# --- Parse body ---
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
return _error_json(400, "Invalid JSON body")
|
||||
|
||||
messages = body.get("messages")
|
||||
if not messages or not isinstance(messages, list):
|
||||
return _error_json(400, "messages field is required and must be a non-empty array")
|
||||
|
||||
# Stream not yet supported
|
||||
if body.get("stream", False):
|
||||
return _error_json(400, "stream=true is not supported yet. Set stream=false or omit it.")
|
||||
|
||||
# Extract last user message — nanobot manages its own multi-turn history
|
||||
user_content = None
|
||||
for msg in reversed(messages):
|
||||
if msg.get("role") == "user":
|
||||
user_content = msg.get("content", "")
|
||||
break
|
||||
if user_content is None:
|
||||
return _error_json(400, "messages must contain at least one user message")
|
||||
if isinstance(user_content, list):
|
||||
# Multi-modal content array — extract text parts
|
||||
user_content = " ".join(
|
||||
part.get("text", "") for part in user_content if part.get("type") == "text"
|
||||
)
|
||||
|
||||
agent_loop = request.app["agent_loop"]
|
||||
timeout_s: float = request.app.get("request_timeout", 120.0)
|
||||
model_name: str = body.get("model") or request.app.get("model_name", "nanobot")
|
||||
locks: _SessionLocks = request.app["session_locks"]
|
||||
|
||||
safe_key = session_key[:32] + ("…" if len(session_key) > 32 else "")
|
||||
logger.info("API request session_key={} content={}", safe_key, user_content[:80])
|
||||
|
||||
_FALLBACK = "I've completed processing but have no response to give."
|
||||
|
||||
lock = locks.acquire(session_key)
|
||||
try:
|
||||
async with lock:
|
||||
try:
|
||||
response_text = await asyncio.wait_for(
|
||||
agent_loop.process_direct(
|
||||
content=user_content,
|
||||
session_key=session_key,
|
||||
channel="api",
|
||||
chat_id=session_key,
|
||||
isolate_memory=True,
|
||||
disabled_tools=_API_DISABLED_TOOLS,
|
||||
),
|
||||
timeout=timeout_s,
|
||||
)
|
||||
|
||||
if not response_text or not response_text.strip():
|
||||
logger.warning("Empty response for session {}, retrying", safe_key)
|
||||
response_text = await asyncio.wait_for(
|
||||
agent_loop.process_direct(
|
||||
content=user_content,
|
||||
session_key=session_key,
|
||||
channel="api",
|
||||
chat_id=session_key,
|
||||
isolate_memory=True,
|
||||
disabled_tools=_API_DISABLED_TOOLS,
|
||||
),
|
||||
timeout=timeout_s,
|
||||
)
|
||||
if not response_text or not response_text.strip():
|
||||
logger.warning("Empty response after retry for session {}, using fallback", safe_key)
|
||||
response_text = _FALLBACK
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return _error_json(504, f"Request timed out after {timeout_s}s")
|
||||
except Exception:
|
||||
logger.exception("Error processing request for session {}", safe_key)
|
||||
return _error_json(500, "Internal server error", err_type="server_error")
|
||||
finally:
|
||||
locks.release(session_key)
|
||||
|
||||
return web.json_response(_chat_completion_response(response_text, model_name))
|
||||
|
||||
|
||||
async def handle_models(request: web.Request) -> web.Response:
|
||||
"""GET /v1/models"""
|
||||
model_name = request.app.get("model_name", "nanobot")
|
||||
return web.json_response({
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": model_name,
|
||||
"object": "model",
|
||||
"created": 0,
|
||||
"owned_by": "nanobot",
|
||||
}
|
||||
],
|
||||
})
|
||||
|
||||
|
||||
async def handle_health(request: web.Request) -> web.Response:
|
||||
"""GET /health"""
|
||||
return web.json_response({"status": "ok"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# App factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def create_app(agent_loop, model_name: str = "nanobot", request_timeout: float = 120.0) -> web.Application:
|
||||
"""Create the aiohttp application.
|
||||
|
||||
Args:
|
||||
agent_loop: An initialized AgentLoop instance.
|
||||
model_name: Model name reported in responses.
|
||||
request_timeout: Per-request timeout in seconds.
|
||||
"""
|
||||
app = web.Application()
|
||||
app["agent_loop"] = agent_loop
|
||||
app["model_name"] = model_name
|
||||
app["request_timeout"] = request_timeout
|
||||
app["session_locks"] = _SessionLocks()
|
||||
|
||||
app.router.add_post("/v1/chat/completions", handle_chat_completions)
|
||||
app.router.add_get("/v1/models", handle_models)
|
||||
app.router.add_get("/health", handle_health)
|
||||
return app
|
||||
|
||||
|
||||
def run_server(agent_loop, host: str = "0.0.0.0", port: int = 8900,
|
||||
model_name: str = "nanobot", request_timeout: float = 120.0) -> None:
|
||||
"""Create and run the server (blocking)."""
|
||||
app = create_app(agent_loop, model_name=model_name, request_timeout=request_timeout)
|
||||
web.run_app(app, host=host, port=port, print=lambda msg: logger.info(msg))
|
||||
@ -237,6 +237,83 @@ def _make_provider(config: Config):
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OpenAI-Compatible API Server
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@app.command()
|
||||
def serve(
|
||||
port: int = typer.Option(8900, "--port", "-p", help="API server port"),
|
||||
host: str = typer.Option("0.0.0.0", "--host", "-H", help="Bind address"),
|
||||
timeout: float = typer.Option(120.0, "--timeout", "-t", help="Per-request timeout (seconds)"),
|
||||
verbose: bool = typer.Option(False, "--verbose", "-v", help="Show nanobot runtime logs"),
|
||||
):
|
||||
"""Start the OpenAI-compatible API server (/v1/chat/completions)."""
|
||||
try:
|
||||
from aiohttp import web # noqa: F401
|
||||
except ImportError:
|
||||
console.print("[red]aiohttp is required. Install with: pip install aiohttp[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
from nanobot.config.loader import load_config
|
||||
from nanobot.api.server import create_app
|
||||
from loguru import logger
|
||||
|
||||
if verbose:
|
||||
logger.enable("nanobot")
|
||||
else:
|
||||
logger.disable("nanobot")
|
||||
|
||||
config = load_config()
|
||||
sync_workspace_templates(config.workspace_path)
|
||||
provider = _make_provider(config)
|
||||
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.session.manager import SessionManager
|
||||
|
||||
bus = MessageBus()
|
||||
session_manager = SessionManager(config.workspace_path)
|
||||
agent_loop = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=config.workspace_path,
|
||||
model=config.agents.defaults.model,
|
||||
temperature=config.agents.defaults.temperature,
|
||||
max_tokens=config.agents.defaults.max_tokens,
|
||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||
memory_window=config.agents.defaults.memory_window,
|
||||
brave_api_key=config.tools.web.search.api_key or None,
|
||||
exec_config=config.tools.exec,
|
||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||
session_manager=session_manager,
|
||||
mcp_servers=config.tools.mcp_servers,
|
||||
channels_config=config.channels,
|
||||
)
|
||||
|
||||
model_name = config.agents.defaults.model
|
||||
console.print(f"{__logo__} Starting OpenAI-compatible API server")
|
||||
console.print(f" [cyan]Endpoint[/cyan] : http://{host}:{port}/v1/chat/completions")
|
||||
console.print(f" [cyan]Model[/cyan] : {model_name}")
|
||||
console.print(f" [cyan]Timeout[/cyan] : {timeout}s")
|
||||
console.print(f" [cyan]Header[/cyan] : x-session-key (required)")
|
||||
console.print()
|
||||
|
||||
api_app = create_app(agent_loop, model_name=model_name, request_timeout=timeout)
|
||||
|
||||
async def on_startup(_app):
|
||||
await agent_loop._connect_mcp()
|
||||
|
||||
async def on_cleanup(_app):
|
||||
await agent_loop.close_mcp()
|
||||
|
||||
api_app.on_startup.append(on_startup)
|
||||
api_app.on_cleanup.append(on_cleanup)
|
||||
|
||||
web.run_app(api_app, host=host, port=port, print=lambda msg: logger.info(msg))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Gateway / Server
|
||||
# ============================================================================
|
||||
|
||||
@ -45,6 +45,9 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
api = [
|
||||
"aiohttp>=3.9.0,<4.0.0",
|
||||
]
|
||||
matrix = [
|
||||
"matrix-nio[e2e]>=0.25.2",
|
||||
"mistune>=3.0.0,<4.0.0",
|
||||
@ -53,6 +56,7 @@ matrix = [
|
||||
dev = [
|
||||
"pytest>=9.0.0,<10.0.0",
|
||||
"pytest-asyncio>=1.3.0,<2.0.0",
|
||||
"aiohttp>=3.9.0,<4.0.0",
|
||||
"ruff>=0.1.0",
|
||||
]
|
||||
|
||||
|
||||
@ -509,7 +509,7 @@ class TestConsolidationDeduplicationGuard:
|
||||
|
||||
consolidation_calls = 0
|
||||
|
||||
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
|
||||
async def _fake_consolidate(_session, archive_all: bool = False, **kw) -> None:
|
||||
nonlocal consolidation_calls
|
||||
consolidation_calls += 1
|
||||
await asyncio.sleep(0.05)
|
||||
@ -555,7 +555,7 @@ class TestConsolidationDeduplicationGuard:
|
||||
active = 0
|
||||
max_active = 0
|
||||
|
||||
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
|
||||
async def _fake_consolidate(_session, archive_all: bool = False, **kw) -> None:
|
||||
nonlocal consolidation_calls, active, max_active
|
||||
consolidation_calls += 1
|
||||
active += 1
|
||||
@ -605,7 +605,7 @@ class TestConsolidationDeduplicationGuard:
|
||||
|
||||
started = asyncio.Event()
|
||||
|
||||
async def _slow_consolidate(_session, archive_all: bool = False) -> None:
|
||||
async def _slow_consolidate(_session, archive_all: bool = False, **kw) -> None:
|
||||
started.set()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
@ -652,7 +652,7 @@ class TestConsolidationDeduplicationGuard:
|
||||
release = asyncio.Event()
|
||||
archived_count = 0
|
||||
|
||||
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
|
||||
async def _fake_consolidate(sess, archive_all: bool = False, **kw) -> bool:
|
||||
nonlocal archived_count
|
||||
if archive_all:
|
||||
archived_count = len(sess.messages)
|
||||
@ -707,7 +707,7 @@ class TestConsolidationDeduplicationGuard:
|
||||
loop.sessions.save(session)
|
||||
before_count = len(session.messages)
|
||||
|
||||
async def _failing_consolidate(sess, archive_all: bool = False) -> bool:
|
||||
async def _failing_consolidate(sess, archive_all: bool = False, **kw) -> bool:
|
||||
if archive_all:
|
||||
return False
|
||||
return True
|
||||
@ -754,7 +754,7 @@ class TestConsolidationDeduplicationGuard:
|
||||
release = asyncio.Event()
|
||||
archived_count = -1
|
||||
|
||||
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
|
||||
async def _fake_consolidate(sess, archive_all: bool = False, **kw) -> bool:
|
||||
nonlocal archived_count
|
||||
if archive_all:
|
||||
archived_count = len(sess.messages)
|
||||
@ -815,7 +815,7 @@ class TestConsolidationDeduplicationGuard:
|
||||
loop._consolidation_locks.setdefault(session.key, asyncio.Lock())
|
||||
assert session.key in loop._consolidation_locks
|
||||
|
||||
async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
|
||||
async def _ok_consolidate(sess, archive_all: bool = False, **kw) -> bool:
|
||||
return True
|
||||
|
||||
loop._consolidate_memory = _ok_consolidate # type: ignore[method-assign]
|
||||
|
||||
883
tests/test_openai_api.py
Normal file
883
tests/test_openai_api.py
Normal file
@ -0,0 +1,883 @@
|
||||
"""Tests for the OpenAI-compatible API server."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.api.server import _SessionLocks, _chat_completion_response, _error_json, create_app
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# aiohttp test client helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
try:
|
||||
from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
|
||||
from aiohttp import web
|
||||
|
||||
HAS_AIOHTTP = True
|
||||
except ImportError:
|
||||
HAS_AIOHTTP = False
|
||||
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests — no aiohttp required
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSessionLocks:
|
||||
def test_acquire_creates_lock(self):
|
||||
sl = _SessionLocks()
|
||||
lock = sl.acquire("k1")
|
||||
assert isinstance(lock, asyncio.Lock)
|
||||
|
||||
def test_same_key_returns_same_lock(self):
|
||||
sl = _SessionLocks()
|
||||
l1 = sl.acquire("k1")
|
||||
l2 = sl.acquire("k1")
|
||||
assert l1 is l2
|
||||
|
||||
def test_different_keys_different_locks(self):
|
||||
sl = _SessionLocks()
|
||||
l1 = sl.acquire("k1")
|
||||
l2 = sl.acquire("k2")
|
||||
assert l1 is not l2
|
||||
|
||||
def test_release_cleans_up(self):
|
||||
sl = _SessionLocks()
|
||||
sl.acquire("k1")
|
||||
sl.release("k1")
|
||||
assert "k1" not in sl._locks
|
||||
|
||||
def test_release_keeps_lock_if_still_referenced(self):
|
||||
sl = _SessionLocks()
|
||||
sl.acquire("k1")
|
||||
sl.acquire("k1")
|
||||
sl.release("k1")
|
||||
assert "k1" in sl._locks
|
||||
sl.release("k1")
|
||||
assert "k1" not in sl._locks
|
||||
|
||||
|
||||
class TestResponseHelpers:
    """Unit tests for the OpenAI-compatible response/error serializers."""

    def test_error_json(self):
        # The helper must produce an HTTP error whose JSON body follows the
        # OpenAI error envelope: {"error": {"message": ..., "code": ...}}.
        response = _error_json(400, "bad request")
        assert response.status == 400
        error = json.loads(response.body)["error"]
        assert error["message"] == "bad request"
        assert error["code"] == 400

    def test_chat_completion_response(self):
        # The helper must produce a well-formed chat.completion payload.
        payload = _chat_completion_response("hello world", "test-model")
        assert payload["object"] == "chat.completion"
        assert payload["model"] == "test-model"
        first_choice = payload["choices"][0]
        assert first_choice["message"]["content"] == "hello world"
        assert first_choice["finish_reason"] == "stop"
        # IDs follow the OpenAI "chatcmpl-" convention.
        assert payload["id"].startswith("chatcmpl-")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration tests — require aiohttp
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_mock_agent(response_text: str = "mock response") -> MagicMock:
    """Build a mock agent whose async entry points are pre-stubbed.

    ``process_direct`` resolves to *response_text*; the MCP hooks are
    AsyncMocks so any ``await`` on them (presumably during app startup and
    shutdown — see create_app) succeeds silently.
    """
    mock = MagicMock()
    mock.process_direct = AsyncMock(return_value=response_text)
    mock._connect_mcp = AsyncMock()
    mock.close_mcp = AsyncMock()
    return mock
|
||||
|
||||
|
||||
@pytest.fixture
def mock_agent():
    """A fresh mock agent (canned reply "mock response") for each test."""
    return _make_mock_agent()
|
||||
|
||||
|
||||
@pytest.fixture
def app(mock_agent):
    """An API application wired to the mock agent.

    Uses a generous 10s request timeout so slow CI never trips it.
    """
    return create_app(mock_agent, model_name="test-model", request_timeout=10.0)
|
||||
|
||||
|
||||
@pytest.fixture
async def cli(aiohttp_client, app):
    """A ready-to-use aiohttp test client bound to the app fixture.

    Written as an async fixture: the previous implementation drove the async
    ``aiohttp_client`` factory synchronously via the ``event_loop`` fixture
    and ``run_until_complete``, a pattern pytest-asyncio deprecated and
    removed (the ``event_loop`` fixture is gone as of pytest-asyncio 0.23),
    which made this fixture error out on modern toolchains.
    """
    return await aiohttp_client(app)
|
||||
|
||||
|
||||
# ---- Missing header tests ----
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_missing_session_key_returns_400(aiohttp_client, app):
    """Omitting the mandatory x-session-key header is a client error."""
    client = await aiohttp_client(app)
    payload = {"messages": [{"role": "user", "content": "hello"}]}
    resp = await client.post("/v1/chat/completions", json=payload)
    assert resp.status == 400
    body = await resp.json()
    # The error text must name the missing header so callers can self-diagnose.
    assert "x-session-key" in body["error"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_empty_session_key_returns_400(aiohttp_client, app):
    """A whitespace-only x-session-key counts as missing."""
    client = await aiohttp_client(app)
    payload = {"messages": [{"role": "user", "content": "hello"}]}
    resp = await client.post(
        "/v1/chat/completions",
        json=payload,
        headers={"x-session-key": " "},
    )
    assert resp.status == 400
|
||||
|
||||
|
||||
# ---- Missing messages tests ----
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_missing_messages_returns_400(aiohttp_client, app):
    """A body without a "messages" array must be rejected."""
    client = await aiohttp_client(app)
    resp = await client.post(
        "/v1/chat/completions",
        json={"model": "test"},
        headers={"x-session-key": "test-key"},
    )
    assert resp.status == 400
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_no_user_message_returns_400(aiohttp_client, app):
    """messages containing only non-user roles must be rejected."""
    client = await aiohttp_client(app)
    system_only = {"messages": [{"role": "system", "content": "you are a bot"}]}
    resp = await client.post(
        "/v1/chat/completions",
        json=system_only,
        headers={"x-session-key": "test-key"},
    )
    assert resp.status == 400
|
||||
|
||||
|
||||
# ---- Stream not supported ----
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_stream_true_returns_400(aiohttp_client, app):
    """Streaming is not implemented; stream=true must fail loudly, not hang."""
    client = await aiohttp_client(app)
    payload = {
        "messages": [{"role": "user", "content": "hello"}],
        "stream": True,
    }
    resp = await client.post(
        "/v1/chat/completions",
        json=payload,
        headers={"x-session-key": "test-key"},
    )
    assert resp.status == 400
    body = await resp.json()
    # The error message must mention streaming so clients know what to turn off.
    assert "stream" in body["error"]["message"].lower()
|
||||
|
||||
|
||||
# ---- Successful request ----
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_successful_request(aiohttp_client, mock_agent):
    """Happy path: a valid request returns 200 with the agent's reply."""
    client = await aiohttp_client(create_app(mock_agent, model_name="test-model"))
    resp = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello"}]},
        headers={"x-session-key": "wx:dm:user1"},
    )
    assert resp.status == 200

    body = await resp.json()
    assert body["model"] == "test-model"
    assert body["choices"][0]["message"]["content"] == "mock response"

    # The handler must forward the header as both session_key and chat_id,
    # request memory isolation, and disable filesystem/exec tools.
    mock_agent.process_direct.assert_called_once_with(
        content="hello",
        session_key="wx:dm:user1",
        channel="api",
        chat_id="wx:dm:user1",
        isolate_memory=True,
        disabled_tools={"read_file", "write_file", "edit_file", "list_dir", "exec"},
    )
|
||||
|
||||
|
||||
# ---- Session isolation ----
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_session_isolation_different_keys(aiohttp_client):
    """Two different session keys must route to separate session_key arguments."""
    call_log: list[str] = []

    # Stand-in for process_direct that records which session key it was given.
    async def fake_process(content, session_key="", channel="", chat_id="",
                           isolate_memory=False, disabled_tools=None):
        call_log.append(session_key)
        return f"reply to {session_key}"

    agent = MagicMock()
    agent.process_direct = fake_process
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()

    client = await aiohttp_client(create_app(agent, model_name="m"))

    async def ask(key, text):
        return await client.post(
            "/v1/chat/completions",
            json={"messages": [{"role": "user", "content": text}]},
            headers={"x-session-key": key},
        )

    r1 = await ask("wx:dm:alice", "msg1")
    r2 = await ask("wx:group:g1:user:bob", "msg2")

    assert r1.status == 200
    assert r2.status == 200

    b1 = await r1.json()
    b2 = await r2.json()
    # Each reply must echo exactly the key that request carried.
    assert b1["choices"][0]["message"]["content"] == "reply to wx:dm:alice"
    assert b2["choices"][0]["message"]["content"] == "reply to wx:group:g1:user:bob"
    assert call_log == ["wx:dm:alice", "wx:group:g1:user:bob"]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_same_session_key_serialized(aiohttp_client):
    """Concurrent requests with the same session key must run serially.

    Choreography: "first" sets the barrier then sleeps while holding the
    session lock; "second" records its start and then waits on the barrier.
    If the server did NOT serialize per key, "second" would record
    "start:second" while "first" is still sleeping, i.e. before "end:first".
    """
    order: list[str] = []
    barrier = asyncio.Event()

    async def slow_process(content, session_key="", channel="", chat_id="",
                           isolate_memory=False, disabled_tools=None):
        order.append(f"start:{content}")
        if content == "first":
            barrier.set()
            await asyncio.sleep(0.1)  # hold lock
        else:
            await barrier.wait()  # ensure "second" starts after "first" begins
        order.append(f"end:{content}")
        return content

    agent = MagicMock()
    agent.process_direct = slow_process
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()

    app = create_app(agent, model_name="m")
    client = await aiohttp_client(app)

    async def send(msg):
        return await client.post(
            "/v1/chat/completions",
            json={"messages": [{"role": "user", "content": msg}]},
            headers={"x-session-key": "same-key"},
        )

    # Fire both requests concurrently against the SAME session key.
    r1, r2 = await asyncio.gather(send("first"), send("second"))
    assert r1.status == 200
    assert r2.status == 200
    # "first" must fully complete before "second" starts
    assert order.index("end:first") < order.index("start:second")
|
||||
|
||||
|
||||
# ---- /v1/models ----
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_models_endpoint(aiohttp_client, app):
    """GET /v1/models lists the configured model in OpenAI list format."""
    client = await aiohttp_client(app)
    resp = await client.get("/v1/models")
    assert resp.status == 200

    body = await resp.json()
    assert body["object"] == "list"
    models = body["data"]
    assert len(models) >= 1
    assert models[0]["id"] == "test-model"
|
||||
|
||||
|
||||
# ---- /health ----
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_health_endpoint(aiohttp_client, app):
    """GET /health reports ok without requiring any headers."""
    client = await aiohttp_client(app)
    resp = await client.get("/health")
    assert resp.status == 200
    assert (await resp.json())["status"] == "ok"
|
||||
|
||||
|
||||
# ---- Multimodal content array ----
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_multimodal_content_extracts_text(aiohttp_client, mock_agent):
    """OpenAI content-part arrays: only the text parts reach the agent."""
    client = await aiohttp_client(create_app(mock_agent, model_name="m"))
    multimodal_message = {
        "role": "user",
        "content": [
            {"type": "text", "text": "describe this"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
        ],
    }
    resp = await client.post(
        "/v1/chat/completions",
        json={"messages": [multimodal_message]},
        headers={"x-session-key": "test"},
    )
    assert resp.status == 200

    mock_agent.process_direct.assert_called_once()
    # The image part must be dropped; only the text fragment is forwarded.
    assert mock_agent.process_direct.call_args.kwargs["content"] == "describe this"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Memory isolation regression tests (root cause of cross-session leakage)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMemoryIsolation:
    """Verify that per-session-key memory prevents cross-session context leakage.

    Root cause: ContextBuilder.build_system_prompt() reads a SHARED
    workspace/memory/MEMORY.md into the system prompt of ALL users.
    If user_1 writes "my name is Alice" and the agent persists it to
    MEMORY.md, user_2/user_N will see it.

    Fix: API mode passes a per-session MemoryStore so each session reads/
    writes its own MEMORY.md.
    """

    def test_context_builder_uses_override_memory(self, tmp_path):
        """build_system_prompt with memory_store= must use the override, not global."""
        from nanobot.agent.context import ContextBuilder
        from nanobot.agent.memory import MemoryStore

        workspace = tmp_path / "workspace"
        workspace.mkdir()
        (workspace / "memory").mkdir()
        (workspace / "memory" / "MEMORY.md").write_text("Global: I am shared context")

        ctx = ContextBuilder(workspace)

        # Without override → sees global memory
        prompt_global = ctx.build_system_prompt()
        assert "I am shared context" in prompt_global

        # With override → sees only the override's memory
        override_dir = tmp_path / "isolated" / "memory"
        override_dir.mkdir(parents=True)
        (override_dir / "MEMORY.md").write_text("User Alice's private note")

        # __new__ bypasses MemoryStore.__init__ so we can point the store at
        # an arbitrary directory without whatever setup __init__ performs.
        override_store = MemoryStore.__new__(MemoryStore)
        override_store.memory_dir = override_dir
        override_store.memory_file = override_dir / "MEMORY.md"
        override_store.history_file = override_dir / "HISTORY.md"

        prompt_isolated = ctx.build_system_prompt(memory_store=override_store)
        assert "User Alice's private note" in prompt_isolated
        assert "I am shared context" not in prompt_isolated

    def test_different_session_keys_get_different_memory_dirs(self, tmp_path):
        """_isolated_memory_store must return distinct paths for distinct keys."""
        from unittest.mock import MagicMock
        from nanobot.agent.loop import AgentLoop

        # Bind the real method onto a spec'd mock so only the code under test runs.
        agent = MagicMock(spec=AgentLoop)
        agent.workspace = tmp_path
        agent._isolated_memory_store = AgentLoop._isolated_memory_store.__get__(agent)

        store_a = agent._isolated_memory_store("wx:dm:alice")
        store_b = agent._isolated_memory_store("wx:dm:bob")

        assert store_a.memory_file != store_b.memory_file
        assert store_a.memory_dir != store_b.memory_dir
        # The per-session directories must be created eagerly.
        assert store_a.memory_file.parent.exists()
        assert store_b.memory_file.parent.exists()

    def test_isolated_memory_does_not_leak_across_sessions(self, tmp_path):
        """End-to-end: writing to one session's memory must not appear in another's."""
        from nanobot.agent.context import ContextBuilder
        from nanobot.agent.memory import MemoryStore

        workspace = tmp_path / "workspace"
        workspace.mkdir()
        (workspace / "memory").mkdir()
        (workspace / "memory" / "MEMORY.md").write_text("")

        ctx = ContextBuilder(workspace)

        # Simulate two isolated memory stores (as the API server would create)
        def make_store(name):
            d = tmp_path / "sessions" / name / "memory"
            d.mkdir(parents=True)
            s = MemoryStore.__new__(MemoryStore)
            s.memory_dir = d
            s.memory_file = d / "MEMORY.md"
            s.history_file = d / "HISTORY.md"
            return s

        store_alice = make_store("wx_dm_alice")
        store_bob = make_store("wx_dm_bob")

        # Use unique markers that won't appear in builtin skills/prompts
        alice_marker = "XYZZY_ALICE_PRIVATE_MARKER_42"
        store_alice.write_long_term(alice_marker)

        # Alice's prompt sees it
        prompt_alice = ctx.build_system_prompt(memory_store=store_alice)
        assert alice_marker in prompt_alice

        # Bob's prompt must NOT see it
        prompt_bob = ctx.build_system_prompt(memory_store=store_bob)
        assert alice_marker not in prompt_bob

        # Global prompt must NOT see it either
        prompt_global = ctx.build_system_prompt()
        assert alice_marker not in prompt_global

    def test_build_messages_passes_memory_store(self, tmp_path):
        """build_messages must forward memory_store to build_system_prompt."""
        from nanobot.agent.context import ContextBuilder
        from nanobot.agent.memory import MemoryStore

        workspace = tmp_path / "workspace"
        workspace.mkdir()
        (workspace / "memory").mkdir()
        (workspace / "memory" / "MEMORY.md").write_text("GLOBAL_SECRET")

        ctx = ContextBuilder(workspace)

        override_dir = tmp_path / "per_session" / "memory"
        override_dir.mkdir(parents=True)
        (override_dir / "MEMORY.md").write_text("SESSION_PRIVATE")

        override_store = MemoryStore.__new__(MemoryStore)
        override_store.memory_dir = override_dir
        override_store.memory_file = override_dir / "MEMORY.md"
        override_store.history_file = override_dir / "HISTORY.md"

        messages = ctx.build_messages(
            history=[], current_message="hello",
            memory_store=override_store,
        )
        # The system prompt (messages[0]) must contain only the override memory.
        system_content = messages[0]["content"]
        assert "SESSION_PRIVATE" in system_content
        assert "GLOBAL_SECRET" not in system_content

    def test_api_handler_passes_isolate_memory_and_disabled_tools(self):
        """The API handler must call process_direct with isolate_memory=True and disabled filesystem tools."""
        import ast
        from pathlib import Path

        # Static source check: walk server.py's AST for the two keyword args.
        server_path = Path(__file__).parent.parent / "nanobot" / "api" / "server.py"
        source = server_path.read_text()
        tree = ast.parse(source)

        found_isolate = False
        found_disabled = False
        for node in ast.walk(tree):
            if isinstance(node, ast.keyword):
                if node.arg == "isolate_memory" and isinstance(node.value, ast.Constant) and node.value.value is True:
                    found_isolate = True
                if node.arg == "disabled_tools":
                    found_disabled = True
        assert found_isolate, "server.py must call process_direct with isolate_memory=True"
        assert found_disabled, "server.py must call process_direct with disabled_tools"

    def test_disabled_tools_constant_blocks_filesystem_and_exec(self):
        """_API_DISABLED_TOOLS must include all filesystem tool names and exec."""
        from nanobot.api.server import _API_DISABLED_TOOLS
        for name in ("read_file", "write_file", "edit_file", "list_dir", "exec"):
            assert name in _API_DISABLED_TOOLS, f"{name} missing from _API_DISABLED_TOOLS"

    def test_system_prompt_uses_isolated_memory_path(self, tmp_path):
        """When memory_store is provided, the system prompt must reference
        the store's paths, NOT the global workspace/memory/MEMORY.md."""
        from nanobot.agent.context import ContextBuilder
        from nanobot.agent.memory import MemoryStore

        workspace = tmp_path / "workspace"
        workspace.mkdir()
        (workspace / "memory").mkdir()

        ctx = ContextBuilder(workspace)

        # Default prompt references global path
        default_prompt = ctx.build_system_prompt()
        assert "memory/MEMORY.md" in default_prompt

        # Isolated store
        iso_dir = tmp_path / "sessions" / "wx_dm_alice" / "memory"
        iso_dir.mkdir(parents=True)
        store = MemoryStore.__new__(MemoryStore)
        store.memory_dir = iso_dir
        store.memory_file = iso_dir / "MEMORY.md"
        store.history_file = iso_dir / "HISTORY.md"

        iso_prompt = ctx.build_system_prompt(memory_store=store)
        # Must reference the isolated path
        assert str(iso_dir / "MEMORY.md") in iso_prompt
        assert str(iso_dir / "HISTORY.md") in iso_prompt
        # Must NOT reference the global workspace memory path
        global_mem = str(workspace.resolve() / "memory" / "MEMORY.md")
        assert global_mem not in iso_prompt

    def test_run_agent_loop_filters_disabled_tools(self):
        """_run_agent_loop must exclude disabled tools from definitions
        and reject execution of disabled tools."""
        from nanobot.agent.tools.registry import ToolRegistry

        registry = ToolRegistry()

        # Create minimal fake tool definitions
        class FakeTool:
            def __init__(self, n):
                self._name = n

            @property
            def name(self):
                return self._name

            def to_schema(self):
                return {"type": "function", "function": {"name": self._name, "parameters": {}}}

            def validate_params(self, params):
                return []

            async def execute(self, **kw):
                return "ok"

        for n in ("read_file", "write_file", "web_search", "exec"):
            registry.register(FakeTool(n))

        all_defs = registry.get_definitions()
        assert len(all_defs) == 4

        # Replicates the filtering expression the server applies to tool defs.
        disabled = {"read_file", "write_file"}
        filtered = [d for d in all_defs
                    if d.get("function", {}).get("name") not in disabled]
        assert len(filtered) == 2
        names = {d["function"]["name"] for d in filtered}
        assert names == {"web_search", "exec"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Consolidation isolation regression tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConsolidationIsolation:
    """Verify that memory consolidation in API (isolate_memory) mode writes
    to the per-session directory and never touches global workspace/memory."""

    @pytest.mark.asyncio
    async def test_consolidate_memory_uses_provided_store(self, tmp_path):
        """_consolidate_memory(memory_store=X) must call X.consolidate,
        not MemoryStore(self.workspace).consolidate."""
        from unittest.mock import AsyncMock, MagicMock, patch
        from nanobot.agent.loop import AgentLoop
        from nanobot.agent.memory import MemoryStore
        from nanobot.session.manager import Session

        # Mock agent carrying only the attributes _consolidate_memory reads.
        agent = MagicMock(spec=AgentLoop)
        agent.workspace = tmp_path / "workspace"
        agent.workspace.mkdir()
        agent.provider = MagicMock()
        agent.model = "test"
        agent.memory_window = 50

        # Bind the real method
        agent._consolidate_memory = AgentLoop._consolidate_memory.__get__(agent)

        session = Session(key="test")
        session.messages = [{"role": "user", "content": "hi", "timestamp": "2025-01-01T00:00"}] * 10

        # Create an isolated store and mock its consolidate
        iso_store = MagicMock(spec=MemoryStore)
        iso_store.consolidate = AsyncMock(return_value=True)

        result = await agent._consolidate_memory(session, memory_store=iso_store)

        assert result is True
        iso_store.consolidate.assert_called_once()
        call_args = iso_store.consolidate.call_args
        assert call_args[0][0] is session  # first positional arg is session

    @pytest.mark.asyncio
    async def test_consolidate_memory_defaults_to_global_when_no_store(self, tmp_path):
        """Without memory_store, _consolidate_memory must use MemoryStore(workspace)."""
        from unittest.mock import AsyncMock, MagicMock, patch
        from nanobot.agent.loop import AgentLoop
        from nanobot.session.manager import Session

        agent = MagicMock(spec=AgentLoop)
        agent.workspace = tmp_path / "workspace"
        agent.workspace.mkdir()
        (agent.workspace / "memory").mkdir()
        agent.provider = MagicMock()
        agent.model = "test"
        agent.memory_window = 50
        agent._consolidate_memory = AgentLoop._consolidate_memory.__get__(agent)

        session = Session(key="test")

        # Patch the class where loop.py looks it up, so we can observe
        # which workspace the fallback store is constructed with.
        with patch("nanobot.agent.loop.MemoryStore") as MockStore:
            mock_instance = MagicMock()
            mock_instance.consolidate = AsyncMock(return_value=True)
            MockStore.return_value = mock_instance

            await agent._consolidate_memory(session)

            MockStore.assert_called_once_with(agent.workspace)
            mock_instance.consolidate.assert_called_once()

    def test_consolidate_writes_to_isolated_dir_not_global(self, tmp_path):
        """End-to-end: MemoryStore.consolidate with an isolated store must
        write HISTORY.md in the isolated dir, not in workspace/memory."""
        from nanobot.agent.memory import MemoryStore

        # Set up global workspace memory
        global_mem_dir = tmp_path / "workspace" / "memory"
        global_mem_dir.mkdir(parents=True)
        (global_mem_dir / "MEMORY.md").write_text("")
        (global_mem_dir / "HISTORY.md").write_text("")

        # Set up isolated per-session store
        iso_dir = tmp_path / "sessions" / "wx_dm_alice" / "memory"
        iso_dir.mkdir(parents=True)

        # __new__ bypasses __init__ so the store can be pointed at iso_dir.
        iso_store = MemoryStore.__new__(MemoryStore)
        iso_store.memory_dir = iso_dir
        iso_store.memory_file = iso_dir / "MEMORY.md"
        iso_store.history_file = iso_dir / "HISTORY.md"

        # Write via the isolated store
        iso_store.write_long_term("Alice's private data")
        iso_store.append_history("[2025-01-01 00:00] Alice asked about X")

        # Isolated store has the data
        assert "Alice's private data" in iso_store.read_long_term()
        assert "Alice asked about X" in iso_store.history_file.read_text()

        # Global store must NOT have it
        assert (global_mem_dir / "MEMORY.md").read_text() == ""
        assert (global_mem_dir / "HISTORY.md").read_text() == ""

    def test_process_message_passes_memory_store_to_consolidation_paths(self):
        """Verify that _process_message passes memory_store to both
        consolidation triggers (source code check)."""
        import ast
        from pathlib import Path

        loop_path = Path(__file__).parent.parent / "nanobot" / "agent" / "loop.py"
        source = loop_path.read_text()
        tree = ast.parse(source)

        # Find all calls to self._consolidate_memory inside _process_message
        # and verify they all pass memory_store=
        for node in ast.walk(tree):
            if not isinstance(node, ast.FunctionDef) or node.name != "_process_message":
                continue
            consolidate_calls = []
            for child in ast.walk(node):
                if (isinstance(child, ast.Call)
                        and isinstance(child.func, ast.Attribute)
                        and child.func.attr == "_consolidate_memory"):
                    kw_names = {kw.arg for kw in child.keywords}
                    consolidate_calls.append(kw_names)

            assert len(consolidate_calls) == 2, (
                f"Expected 2 _consolidate_memory calls in _process_message, "
                f"found {len(consolidate_calls)}"
            )
            for i, kw_names in enumerate(consolidate_calls):
                assert "memory_store" in kw_names, (
                    f"_consolidate_memory call #{i+1} in _process_message "
                    f"missing memory_store= keyword argument"
                )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Empty response retry + fallback tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_empty_response_retry_then_success(aiohttp_client):
    """First call returns empty → retry once → second call returns real text."""
    call_count = 0

    async def sometimes_empty(content, session_key="", channel="", chat_id="",
                              isolate_memory=False, disabled_tools=None):
        nonlocal call_count
        call_count += 1
        return "" if call_count == 1 else "recovered response"

    agent = MagicMock()
    agent.process_direct = sometimes_empty
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()

    client = await aiohttp_client(create_app(agent, model_name="m"))
    resp = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello"}]},
        headers={"x-session-key": "retry-test"},
    )

    assert resp.status == 200
    body = await resp.json()
    assert body["choices"][0]["message"]["content"] == "recovered response"
    # Exactly one retry: two agent calls in total.
    assert call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_empty_response_both_empty_returns_fallback(aiohttp_client):
    """Both calls return empty → must use the fallback text."""
    call_count = 0

    async def always_empty(content, session_key="", channel="", chat_id="",
                           isolate_memory=False, disabled_tools=None):
        nonlocal call_count
        call_count += 1
        return ""

    agent = MagicMock()
    agent.process_direct = always_empty
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()

    client = await aiohttp_client(create_app(agent, model_name="m"))
    resp = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello"}]},
        headers={"x-session-key": "fallback-test"},
    )

    # Still a 200 — the client gets the canned fallback, never an empty body.
    assert resp.status == 200
    body = await resp.json()
    assert body["choices"][0]["message"]["content"] == "I've completed processing but have no response to give."
    # One retry, then give up: two agent calls in total.
    assert call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_whitespace_only_response_triggers_retry(aiohttp_client):
    """Whitespace-only response should be treated as empty and trigger retry."""
    call_count = 0

    async def whitespace_then_ok(content, session_key="", channel="", chat_id="",
                                 isolate_memory=False, disabled_tools=None):
        nonlocal call_count
        call_count += 1
        return " \n " if call_count == 1 else "real answer"

    agent = MagicMock()
    agent.process_direct = whitespace_then_ok
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()

    client = await aiohttp_client(create_app(agent, model_name="m"))
    resp = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello"}]},
        headers={"x-session-key": "ws-test"},
    )

    assert resp.status == 200
    body = await resp.json()
    assert body["choices"][0]["message"]["content"] == "real answer"
    assert call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_none_response_triggers_retry(aiohttp_client):
    """None response should be treated as empty and trigger retry."""
    call_count = 0

    async def none_then_ok(content, session_key="", channel="", chat_id="",
                           isolate_memory=False, disabled_tools=None):
        nonlocal call_count
        call_count += 1
        return None if call_count == 1 else "got it"

    agent = MagicMock()
    agent.process_direct = none_then_ok
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()

    client = await aiohttp_client(create_app(agent, model_name="m"))
    resp = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello"}]},
        headers={"x-session-key": "none-test"},
    )

    assert resp.status == 200
    body = await resp.json()
    assert body["choices"][0]["message"]["content"] == "got it"
    assert call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_nonempty_response_no_retry(aiohttp_client):
    """A normal non-empty response must NOT trigger a retry."""
    call_count = 0

    async def normal_response(content, session_key="", channel="", chat_id="",
                              isolate_memory=False, disabled_tools=None):
        nonlocal call_count
        call_count += 1
        return "immediate answer"

    agent = MagicMock()
    agent.process_direct = normal_response
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()

    client = await aiohttp_client(create_app(agent, model_name="m"))
    resp = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello"}]},
        headers={"x-session-key": "normal-test"},
    )

    assert resp.status == 200
    body = await resp.json()
    assert body["choices"][0]["message"]["content"] == "immediate answer"
    # The agent must have been called exactly once — no spurious retry.
    assert call_count == 1
|
||||
Loading…
x
Reference in New Issue
Block a user