feat(api): add OpenAI-compatible endpoint with x-session-key isolation

This commit is contained in:
Tink 2026-03-01 10:53:45 +08:00
parent e1832e75b5
commit 80219baf25
9 changed files with 1387 additions and 26 deletions

96
examples/curl.txt Normal file
View File

@ -0,0 +1,96 @@
# =============================================================================
# nanobot OpenAI-Compatible API — curl examples
# =============================================================================
#
# Prerequisites:
# pip install "nanobot-ai[api]" # installs aiohttp (quotes avoid zsh glob errors on [..])
# nanobot serve --port 8900 # start the API server
#
# The x-session-key header is REQUIRED for every request.
# Convention:
# Private chat: wx:dm:{sender_id}
# Group @: wx:group:{group_id}:user:{sender_id}
# =============================================================================
# --- 1. Basic chat completion (private chat) ---
curl -X POST http://localhost:8900/v1/chat/completions \
-H "Content-Type: application/json" \
-H "x-session-key: wx:dm:user_alice" \
-d '{
"model": "nanobot",
"messages": [
{"role": "user", "content": "Hello, who are you?"}
]
}'
# --- 2. Follow-up in the same session (context is remembered) ---
curl -X POST http://localhost:8900/v1/chat/completions \
-H "Content-Type: application/json" \
-H "x-session-key: wx:dm:user_alice" \
-d '{
"model": "nanobot",
"messages": [
{"role": "user", "content": "What did I just ask you?"}
]
}'
# --- 3. Different user — isolated session ---
curl -X POST http://localhost:8900/v1/chat/completions \
-H "Content-Type: application/json" \
-H "x-session-key: wx:dm:user_bob" \
-d '{
"model": "nanobot",
"messages": [
{"role": "user", "content": "What did I just ask you?"}
]
}'
# ↑ Bob gets a fresh context — he never asked anything before.
# --- 4. Group chat — per-user session within a group ---
curl -X POST http://localhost:8900/v1/chat/completions \
-H "Content-Type: application/json" \
-H "x-session-key: wx:group:group_abc:user:user_alice" \
-d '{
"model": "nanobot",
"messages": [
{"role": "user", "content": "Summarize our discussion"}
]
}'
# --- 5. List available models ---
curl http://localhost:8900/v1/models
# --- 6. Health check ---
curl http://localhost:8900/health
# --- 7. Missing header — expect 400 ---
curl -X POST http://localhost:8900/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "nanobot",
"messages": [
{"role": "user", "content": "hello"}
]
}'
# ↑ Returns: {"error": {"message": "Missing required header: x-session-key", ...}}
# --- 8. Stream not yet supported — expect 400 ---
curl -X POST http://localhost:8900/v1/chat/completions \
-H "Content-Type: application/json" \
-H "x-session-key: wx:dm:user_alice" \
-d '{
"model": "nanobot",
"messages": [
{"role": "user", "content": "hello"}
],
"stream": true
}'
# ↑ Returns: {"error": {"message": "stream=true is not supported yet...", ...}}

View File

@ -23,15 +23,25 @@ class ContextBuilder:
self.memory = MemoryStore(workspace)
self.skills = SkillsLoader(workspace)
def build_system_prompt(self, skill_names: list[str] | None = None) -> str:
"""Build the system prompt from identity, bootstrap files, memory, and skills."""
parts = [self._get_identity()]
def build_system_prompt(
self,
skill_names: list[str] | None = None,
memory_store: "MemoryStore | None" = None,
) -> str:
"""Build the system prompt from identity, bootstrap files, memory, and skills.
Args:
memory_store: If provided, use this MemoryStore instead of the default
workspace-level one. Used for per-session memory isolation.
"""
parts = [self._get_identity(memory_store=memory_store)]
bootstrap = self._load_bootstrap_files()
if bootstrap:
parts.append(bootstrap)
memory = self.memory.get_memory_context()
store = memory_store or self.memory
memory = store.get_memory_context()
if memory:
parts.append(f"# Memory\n\n{memory}")
@ -52,12 +62,19 @@ Skills with available="false" need dependencies installed first - you can try in
return "\n\n---\n\n".join(parts)
def _get_identity(self) -> str:
def _get_identity(self, memory_store: "MemoryStore | None" = None) -> str:
"""Get the core identity section."""
workspace_path = str(self.workspace.expanduser().resolve())
system = platform.system()
runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}"
if memory_store is not None:
mem_path = str(memory_store.memory_file)
hist_path = str(memory_store.history_file)
else:
mem_path = f"{workspace_path}/memory/MEMORY.md"
hist_path = f"{workspace_path}/memory/HISTORY.md"
return f"""# nanobot 🐈
You are nanobot, a helpful AI assistant.
@ -67,8 +84,8 @@ You are nanobot, a helpful AI assistant.
## Workspace
Your workspace is at: {workspace_path}
- Long-term memory: {workspace_path}/memory/MEMORY.md (write important facts here)
- History log: {workspace_path}/memory/HISTORY.md (grep-searchable). Each entry starts with [YYYY-MM-DD HH:MM].
- Long-term memory: {mem_path} (write important facts here)
- History log: {hist_path} (grep-searchable). Each entry starts with [YYYY-MM-DD HH:MM].
- Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md
## nanobot Guidelines
@ -110,10 +127,11 @@ Reply directly with text for conversations. Only use the 'message' tool to send
media: list[str] | None = None,
channel: str | None = None,
chat_id: str | None = None,
memory_store: "MemoryStore | None" = None,
) -> list[dict[str, Any]]:
"""Build the complete message list for an LLM call."""
return [
{"role": "system", "content": self.build_system_prompt(skill_names)},
{"role": "system", "content": self.build_system_prompt(skill_names, memory_store=memory_store)},
*history,
{"role": "user", "content": self._build_runtime_context(channel, chat_id)},
{"role": "user", "content": self._build_user_content(current_message, media)},

View File

@ -174,6 +174,7 @@ class AgentLoop:
self,
initial_messages: list[dict],
on_progress: Callable[..., Awaitable[None]] | None = None,
disabled_tools: set[str] | None = None,
) -> tuple[str | None, list[str], list[dict]]:
"""Run the agent iteration loop. Returns (final_content, tools_used, messages)."""
messages = initial_messages
@ -181,12 +182,19 @@ class AgentLoop:
final_content = None
tools_used: list[str] = []
# Build tool definitions, filtering out disabled tools
if disabled_tools:
tool_defs = [d for d in self.tools.get_definitions()
if d.get("function", {}).get("name") not in disabled_tools]
else:
tool_defs = self.tools.get_definitions()
while iteration < self.max_iterations:
iteration += 1
response = await self.provider.chat(
messages=messages,
tools=self.tools.get_definitions(),
tools=tool_defs,
model=self.model,
temperature=self.temperature,
max_tokens=self.max_tokens,
@ -219,7 +227,10 @@ class AgentLoop:
tools_used.append(tool_call.name)
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
logger.info("Tool call: {}({})", tool_call.name, args_str[:200])
result = await self.tools.execute(tool_call.name, tool_call.arguments)
if disabled_tools and tool_call.name in disabled_tools:
result = f"Error: Tool '{tool_call.name}' is not available in this mode."
else:
result = await self.tools.execute(tool_call.name, tool_call.arguments)
messages = self.context.add_tool_result(
messages, tool_call.id, tool_call.name, result
)
@ -322,6 +333,8 @@ class AgentLoop:
msg: InboundMessage,
session_key: str | None = None,
on_progress: Callable[[str], Awaitable[None]] | None = None,
memory_store: MemoryStore | None = None,
disabled_tools: set[str] | None = None,
) -> OutboundMessage | None:
"""Process a single inbound message and return the response."""
# System messages: parse origin from chat_id ("channel:chat_id")
@ -336,8 +349,11 @@ class AgentLoop:
messages = self.context.build_messages(
history=history,
current_message=msg.content, channel=channel, chat_id=chat_id,
memory_store=memory_store,
)
final_content, _, all_msgs = await self._run_agent_loop(
messages, disabled_tools=disabled_tools,
)
final_content, _, all_msgs = await self._run_agent_loop(messages)
self._save_turn(session, all_msgs, 1 + len(history))
self.sessions.save(session)
return OutboundMessage(channel=channel, chat_id=chat_id,
@ -360,7 +376,9 @@ class AgentLoop:
if snapshot:
temp = Session(key=session.key)
temp.messages = list(snapshot)
if not await self._consolidate_memory(temp, archive_all=True):
if not await self._consolidate_memory(
temp, archive_all=True, memory_store=memory_store,
):
return OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="Memory archival failed, session not cleared. Please try again.",
@ -393,7 +411,9 @@ class AgentLoop:
async def _consolidate_and_unlock():
try:
async with lock:
await self._consolidate_memory(session)
await self._consolidate_memory(
session, memory_store=memory_store,
)
finally:
self._consolidating.discard(session.key)
if not lock.locked():
@ -416,6 +436,7 @@ class AgentLoop:
current_message=msg.content,
media=msg.media if msg.media else None,
channel=msg.channel, chat_id=msg.chat_id,
memory_store=memory_store,
)
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
@ -428,6 +449,7 @@ class AgentLoop:
final_content, _, all_msgs = await self._run_agent_loop(
initial_messages, on_progress=on_progress or _bus_progress,
disabled_tools=disabled_tools,
)
if final_content is None:
@ -470,9 +492,30 @@ class AgentLoop:
session.messages.append(entry)
session.updated_at = datetime.now()
async def _consolidate_memory(self, session, archive_all: bool = False) -> bool:
"""Delegate to MemoryStore.consolidate(). Returns True on success."""
return await MemoryStore(self.workspace).consolidate(
def _isolated_memory_store(self, session_key: str) -> MemoryStore:
"""Return a per-session-key MemoryStore for multi-tenant isolation."""
from nanobot.utils.helpers import safe_filename
safe_key = safe_filename(session_key.replace(":", "_"))
memory_dir = self.workspace / "sessions" / safe_key / "memory"
memory_dir.mkdir(parents=True, exist_ok=True)
store = MemoryStore.__new__(MemoryStore)
store.memory_dir = memory_dir
store.memory_file = memory_dir / "MEMORY.md"
store.history_file = memory_dir / "HISTORY.md"
return store
async def _consolidate_memory(
self, session, archive_all: bool = False,
memory_store: MemoryStore | None = None,
) -> bool:
"""Delegate to MemoryStore.consolidate(). Returns True on success.
Args:
memory_store: If provided, consolidate into this store instead of
the default workspace-level one.
"""
store = memory_store or MemoryStore(self.workspace)
return await store.consolidate(
session, self.provider, self.model,
archive_all=archive_all, memory_window=self.memory_window,
)
@ -484,9 +527,26 @@ class AgentLoop:
channel: str = "cli",
chat_id: str = "direct",
on_progress: Callable[[str], Awaitable[None]] | None = None,
isolate_memory: bool = False,
disabled_tools: set[str] | None = None,
) -> str:
"""Process a message directly (for CLI or cron usage)."""
"""Process a message directly (for CLI or cron usage).
Args:
isolate_memory: When True, use a per-session-key memory directory
instead of the shared workspace memory. This prevents context
leakage between different session keys in multi-tenant (API) mode.
disabled_tools: Tool names to exclude from the LLM tool list and
reject at execution time. Use to block filesystem access in
multi-tenant API mode.
"""
await self._connect_mcp()
memory_store: MemoryStore | None = None
if isolate_memory:
memory_store = self._isolated_memory_store(session_key)
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
response = await self._process_message(msg, session_key=session_key, on_progress=on_progress)
response = await self._process_message(
msg, session_key=session_key, on_progress=on_progress,
memory_store=memory_store, disabled_tools=disabled_tools,
)
return response.content if response else ""

1
nanobot/api/__init__.py Normal file
View File

@ -0,0 +1 @@
"""OpenAI-compatible HTTP API for nanobot."""

222
nanobot/api/server.py Normal file
View File

@ -0,0 +1,222 @@
"""OpenAI-compatible HTTP API server for nanobot.
Provides /v1/chat/completions and /v1/models endpoints.
Session isolation is enforced via the x-session-key request header.
"""
from __future__ import annotations
import asyncio
import time
import uuid
from typing import Any
from aiohttp import web
from loguru import logger
# Tools that must NOT run in multi-tenant API mode.
# Filesystem tools allow the LLM to read/write the shared workspace (including
# global MEMORY.md), and exec allows shell commands that can bypass filesystem
# restrictions (e.g. `cat ~/.nanobot/workspace/memory/MEMORY.md`).
# NOTE(review): these names must match the tool registry exactly — confirm
# against the AgentLoop tool definitions when new filesystem/exec tools are added.
_API_DISABLED_TOOLS: set[str] = {
    "read_file", "write_file", "edit_file", "list_dir", "exec",
}
# ---------------------------------------------------------------------------
# Per-session-key lock manager
# ---------------------------------------------------------------------------
class _SessionLocks:
"""Manages one asyncio.Lock per session key for serial execution."""
def __init__(self) -> None:
self._locks: dict[str, asyncio.Lock] = {}
self._ref: dict[str, int] = {} # reference count for cleanup
def acquire(self, key: str) -> asyncio.Lock:
if key not in self._locks:
self._locks[key] = asyncio.Lock()
self._ref[key] = 0
self._ref[key] += 1
return self._locks[key]
def release(self, key: str) -> None:
self._ref[key] -= 1
if self._ref[key] <= 0:
self._locks.pop(key, None)
self._ref.pop(key, None)
# ---------------------------------------------------------------------------
# Response helpers
# ---------------------------------------------------------------------------
def _error_json(status: int, message: str, err_type: str = "invalid_request_error") -> web.Response:
    """Build an OpenAI-style error response carrying the given HTTP status."""
    payload = {"error": {"message": message, "type": err_type, "code": status}}
    return web.json_response(payload, status=status)
def _chat_completion_response(content: str, model: str) -> dict[str, Any]:
return {
"id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": content},
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
}
# ---------------------------------------------------------------------------
# Route handlers
# ---------------------------------------------------------------------------
async def handle_chat_completions(request: web.Request) -> web.Response:
    """POST /v1/chat/completions

    OpenAI-compatible chat endpoint. The x-session-key header is required;
    it selects an isolated nanobot session (history + per-session memory),
    and concurrent requests sharing a key are serialized via a per-key lock.
    """
    # --- x-session-key validation ---
    session_key = request.headers.get("x-session-key", "").strip()
    if not session_key:
        return _error_json(400, "Missing required header: x-session-key")

    # --- Parse body ---
    try:
        body = await request.json()
    except Exception:
        return _error_json(400, "Invalid JSON body")

    messages = body.get("messages")
    if not messages or not isinstance(messages, list):
        return _error_json(400, "messages field is required and must be a non-empty array")

    # Stream not yet supported
    if body.get("stream", False):
        return _error_json(400, "stream=true is not supported yet. Set stream=false or omit it.")

    # Extract last user message — nanobot manages its own multi-turn history
    user_content = None
    for msg in reversed(messages):
        if msg.get("role") == "user":
            user_content = msg.get("content", "")
            break
    if user_content is None:
        return _error_json(400, "messages must contain at least one user message")
    if isinstance(user_content, list):
        # Multi-modal content array — extract text parts
        user_content = " ".join(
            part.get("text", "") for part in user_content if part.get("type") == "text"
        )

    agent_loop = request.app["agent_loop"]
    timeout_s: float = request.app.get("request_timeout", 120.0)
    model_name: str = body.get("model") or request.app.get("model_name", "nanobot")
    locks: _SessionLocks = request.app["session_locks"]

    # Truncate long keys for log readability. BUGFIX: the original appended
    # "" in BOTH branches, so truncation was never indicated; append "…"
    # when the key was actually cut.
    safe_key = session_key[:32] + ("…" if len(session_key) > 32 else "")
    logger.info("API request session_key={} content={}", safe_key, user_content[:80])

    _FALLBACK = "I've completed processing but have no response to give."

    async def _invoke() -> str:
        # One agent call with the per-request timeout applied; factored out
        # so the first attempt and the empty-response retry stay identical.
        return await asyncio.wait_for(
            agent_loop.process_direct(
                content=user_content,
                session_key=session_key,
                channel="api",
                chat_id=session_key,
                isolate_memory=True,
                disabled_tools=_API_DISABLED_TOOLS,
            ),
            timeout=timeout_s,
        )

    lock = locks.acquire(session_key)
    try:
        async with lock:
            try:
                response_text = await _invoke()
                if not response_text or not response_text.strip():
                    # Retry once before falling back to a canned reply.
                    logger.warning("Empty response for session {}, retrying", safe_key)
                    response_text = await _invoke()
                if not response_text or not response_text.strip():
                    logger.warning("Empty response after retry for session {}, using fallback", safe_key)
                    response_text = _FALLBACK
            except asyncio.TimeoutError:
                return _error_json(504, f"Request timed out after {timeout_s}s")
            except Exception:
                logger.exception("Error processing request for session {}", safe_key)
                return _error_json(500, "Internal server error", err_type="server_error")
    finally:
        locks.release(session_key)

    return web.json_response(_chat_completion_response(response_text, model_name))
async def handle_models(request: web.Request) -> web.Response:
    """GET /v1/models — advertise the single configured model."""
    name = request.app.get("model_name", "nanobot")
    entry = {
        "id": name,
        "object": "model",
        "created": 0,
        "owned_by": "nanobot",
    }
    return web.json_response({"object": "list", "data": [entry]})
async def handle_health(request: web.Response) -> web.Response:
    """GET /health — liveness probe; always reports ok."""
    return web.json_response({"status": "ok"})
# ---------------------------------------------------------------------------
# App factory
# ---------------------------------------------------------------------------
def create_app(agent_loop, model_name: str = "nanobot", request_timeout: float = 120.0) -> web.Application:
    """Assemble the aiohttp application.

    Args:
        agent_loop: An initialized AgentLoop instance.
        model_name: Model name reported in responses.
        request_timeout: Per-request timeout in seconds.
    """
    application = web.Application()
    # Shared per-app state read by the route handlers.
    shared_state = {
        "agent_loop": agent_loop,
        "model_name": model_name,
        "request_timeout": request_timeout,
        "session_locks": _SessionLocks(),
    }
    for state_key, state_value in shared_state.items():
        application[state_key] = state_value
    application.router.add_post("/v1/chat/completions", handle_chat_completions)
    application.router.add_get("/v1/models", handle_models)
    application.router.add_get("/health", handle_health)
    return application
def run_server(agent_loop, host: str = "0.0.0.0", port: int = 8900,
               model_name: str = "nanobot", request_timeout: float = 120.0) -> None:
    """Build the app and serve it until interrupted (blocking)."""
    application = create_app(
        agent_loop, model_name=model_name, request_timeout=request_timeout,
    )
    # Route aiohttp's startup banner through loguru instead of stdout.
    web.run_app(application, host=host, port=port, print=lambda msg: logger.info(msg))

View File

@ -237,6 +237,83 @@ def _make_provider(config: Config):
)
# ============================================================================
# OpenAI-Compatible API Server
# ============================================================================
@app.command()
def serve(
    port: int = typer.Option(8900, "--port", "-p", help="API server port"),
    host: str = typer.Option("0.0.0.0", "--host", "-H", help="Bind address"),
    timeout: float = typer.Option(120.0, "--timeout", "-t", help="Per-request timeout (seconds)"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show nanobot runtime logs"),
):
    """Start the OpenAI-compatible API server (/v1/chat/completions)."""
    # aiohttp is an optional dependency (the "api" extra); fail fast with an
    # actionable message if it is missing.
    try:
        from aiohttp import web  # noqa: F401
    except ImportError:
        console.print("[red]aiohttp is required. Install with: pip install aiohttp[/red]")
        raise typer.Exit(1)
    from nanobot.config.loader import load_config
    from nanobot.api.server import create_app
    from loguru import logger
    # Quiet by default: API clients only consume the HTTP responses.
    if verbose:
        logger.enable("nanobot")
    else:
        logger.disable("nanobot")
    config = load_config()
    sync_workspace_templates(config.workspace_path)
    provider = _make_provider(config)
    from nanobot.bus.queue import MessageBus
    from nanobot.agent.loop import AgentLoop
    from nanobot.session.manager import SessionManager
    # The API server drives the agent via process_direct(), but AgentLoop's
    # constructor still requires a message bus instance.
    bus = MessageBus()
    session_manager = SessionManager(config.workspace_path)
    agent_loop = AgentLoop(
        bus=bus,
        provider=provider,
        workspace=config.workspace_path,
        model=config.agents.defaults.model,
        temperature=config.agents.defaults.temperature,
        max_tokens=config.agents.defaults.max_tokens,
        max_iterations=config.agents.defaults.max_tool_iterations,
        memory_window=config.agents.defaults.memory_window,
        brave_api_key=config.tools.web.search.api_key or None,
        exec_config=config.tools.exec,
        restrict_to_workspace=config.tools.restrict_to_workspace,
        session_manager=session_manager,
        mcp_servers=config.tools.mcp_servers,
        channels_config=config.channels,
    )
    model_name = config.agents.defaults.model
    console.print(f"{__logo__} Starting OpenAI-compatible API server")
    console.print(f"  [cyan]Endpoint[/cyan] : http://{host}:{port}/v1/chat/completions")
    console.print(f"  [cyan]Model[/cyan]    : {model_name}")
    console.print(f"  [cyan]Timeout[/cyan]  : {timeout}s")
    console.print(f"  [cyan]Header[/cyan]   : x-session-key (required)")
    console.print()
    api_app = create_app(agent_loop, model_name=model_name, request_timeout=timeout)

    # MCP connections follow the aiohttp app lifecycle: connect on startup,
    # tear down on cleanup.
    async def on_startup(_app):
        await agent_loop._connect_mcp()

    async def on_cleanup(_app):
        await agent_loop.close_mcp()

    api_app.on_startup.append(on_startup)
    api_app.on_cleanup.append(on_cleanup)
    # Blocking call; route aiohttp's banner through loguru.
    web.run_app(api_app, host=host, port=port, print=lambda msg: logger.info(msg))
# ============================================================================
# Gateway / Server
# ============================================================================

View File

@ -45,6 +45,9 @@ dependencies = [
]
[project.optional-dependencies]
api = [
"aiohttp>=3.9.0,<4.0.0",
]
matrix = [
"matrix-nio[e2e]>=0.25.2",
"mistune>=3.0.0,<4.0.0",
@ -53,6 +56,7 @@ matrix = [
dev = [
"pytest>=9.0.0,<10.0.0",
"pytest-asyncio>=1.3.0,<2.0.0",
"aiohttp>=3.9.0,<4.0.0",
"ruff>=0.1.0",
]

View File

@ -509,7 +509,7 @@ class TestConsolidationDeduplicationGuard:
consolidation_calls = 0
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
async def _fake_consolidate(_session, archive_all: bool = False, **kw) -> None:
nonlocal consolidation_calls
consolidation_calls += 1
await asyncio.sleep(0.05)
@ -555,7 +555,7 @@ class TestConsolidationDeduplicationGuard:
active = 0
max_active = 0
async def _fake_consolidate(_session, archive_all: bool = False) -> None:
async def _fake_consolidate(_session, archive_all: bool = False, **kw) -> None:
nonlocal consolidation_calls, active, max_active
consolidation_calls += 1
active += 1
@ -605,7 +605,7 @@ class TestConsolidationDeduplicationGuard:
started = asyncio.Event()
async def _slow_consolidate(_session, archive_all: bool = False) -> None:
async def _slow_consolidate(_session, archive_all: bool = False, **kw) -> None:
started.set()
await asyncio.sleep(0.1)
@ -652,7 +652,7 @@ class TestConsolidationDeduplicationGuard:
release = asyncio.Event()
archived_count = 0
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
async def _fake_consolidate(sess, archive_all: bool = False, **kw) -> bool:
nonlocal archived_count
if archive_all:
archived_count = len(sess.messages)
@ -707,7 +707,7 @@ class TestConsolidationDeduplicationGuard:
loop.sessions.save(session)
before_count = len(session.messages)
async def _failing_consolidate(sess, archive_all: bool = False) -> bool:
async def _failing_consolidate(sess, archive_all: bool = False, **kw) -> bool:
if archive_all:
return False
return True
@ -754,7 +754,7 @@ class TestConsolidationDeduplicationGuard:
release = asyncio.Event()
archived_count = -1
async def _fake_consolidate(sess, archive_all: bool = False) -> bool:
async def _fake_consolidate(sess, archive_all: bool = False, **kw) -> bool:
nonlocal archived_count
if archive_all:
archived_count = len(sess.messages)
@ -815,7 +815,7 @@ class TestConsolidationDeduplicationGuard:
loop._consolidation_locks.setdefault(session.key, asyncio.Lock())
assert session.key in loop._consolidation_locks
async def _ok_consolidate(sess, archive_all: bool = False) -> bool:
async def _ok_consolidate(sess, archive_all: bool = False, **kw) -> bool:
return True
loop._consolidate_memory = _ok_consolidate # type: ignore[method-assign]

883
tests/test_openai_api.py Normal file
View File

@ -0,0 +1,883 @@
"""Tests for the OpenAI-compatible API server."""
from __future__ import annotations
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from nanobot.api.server import _SessionLocks, _chat_completion_response, _error_json, create_app
# ---------------------------------------------------------------------------
# aiohttp test client helper
# ---------------------------------------------------------------------------
try:
from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
from aiohttp import web
HAS_AIOHTTP = True
except ImportError:
HAS_AIOHTTP = False
pytest_plugins = ("pytest_asyncio",)
# ---------------------------------------------------------------------------
# Unit tests — no aiohttp required
# ---------------------------------------------------------------------------
class TestSessionLocks:
    """Unit tests for the per-session-key lock manager."""

    def test_acquire_creates_lock(self):
        manager = _SessionLocks()
        assert isinstance(manager.acquire("k1"), asyncio.Lock)

    def test_same_key_returns_same_lock(self):
        manager = _SessionLocks()
        assert manager.acquire("k1") is manager.acquire("k1")

    def test_different_keys_different_locks(self):
        manager = _SessionLocks()
        assert manager.acquire("k1") is not manager.acquire("k2")

    def test_release_cleans_up(self):
        manager = _SessionLocks()
        manager.acquire("k1")
        manager.release("k1")
        assert "k1" not in manager._locks

    def test_release_keeps_lock_if_still_referenced(self):
        manager = _SessionLocks()
        for _ in range(2):
            manager.acquire("k1")
        manager.release("k1")
        assert "k1" in manager._locks
        manager.release("k1")
        assert "k1" not in manager._locks
class TestResponseHelpers:
    """Unit tests for the JSON response builder helpers."""

    def test_error_json(self):
        resp = _error_json(400, "bad request")
        assert resp.status == 400
        payload = json.loads(resp.body)
        assert payload["error"]["message"] == "bad request"
        assert payload["error"]["code"] == 400

    def test_chat_completion_response(self):
        payload = _chat_completion_response("hello world", "test-model")
        first_choice = payload["choices"][0]
        assert payload["object"] == "chat.completion"
        assert payload["model"] == "test-model"
        assert first_choice["message"]["content"] == "hello world"
        assert first_choice["finish_reason"] == "stop"
        assert payload["id"].startswith("chatcmpl-")
# ---------------------------------------------------------------------------
# Integration tests — require aiohttp
# ---------------------------------------------------------------------------
def _make_mock_agent(response_text: str = "mock response") -> MagicMock:
agent = MagicMock()
agent.process_direct = AsyncMock(return_value=response_text)
agent._connect_mcp = AsyncMock()
agent.close_mcp = AsyncMock()
return agent
@pytest.fixture
def mock_agent():
    # Fresh mock agent per test (default canned reply: "mock response").
    return _make_mock_agent()


@pytest.fixture
def app(mock_agent):
    # aiohttp app wired to the mock agent; short timeout keeps failures fast.
    return create_app(mock_agent, model_name="test-model", request_timeout=10.0)


@pytest.fixture
def cli(event_loop, aiohttp_client, app):
    # NOTE(review): the `event_loop` fixture is deprecated/removed in recent
    # pytest-asyncio releases — confirm the pinned version still provides it.
    return event_loop.run_until_complete(aiohttp_client(app))
# ---- Missing header tests ----
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_missing_session_key_returns_400(aiohttp_client, app):
    # Omitting x-session-key entirely must be rejected with 400.
    client = await aiohttp_client(app)
    resp = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello"}]},
    )
    assert resp.status == 400
    body = await resp.json()
    assert "x-session-key" in body["error"]["message"]


@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_empty_session_key_returns_400(aiohttp_client, app):
    # A whitespace-only header value counts as missing (server strips it).
    client = await aiohttp_client(app)
    resp = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello"}]},
        headers={"x-session-key": " "},
    )
    assert resp.status == 400


# ---- Missing messages tests ----
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_missing_messages_returns_400(aiohttp_client, app):
    # Body without a "messages" array is a 400.
    client = await aiohttp_client(app)
    resp = await client.post(
        "/v1/chat/completions",
        json={"model": "test"},
        headers={"x-session-key": "test-key"},
    )
    assert resp.status == 400


@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_no_user_message_returns_400(aiohttp_client, app):
    # messages present but with no role=user entry is also a 400.
    client = await aiohttp_client(app)
    resp = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "system", "content": "you are a bot"}]},
        headers={"x-session-key": "test-key"},
    )
    assert resp.status == 400


# ---- Stream not supported ----
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_stream_true_returns_400(aiohttp_client, app):
    # stream=true is explicitly unsupported; expect a 400 mentioning "stream".
    client = await aiohttp_client(app)
    resp = await client.post(
        "/v1/chat/completions",
        json={
            "messages": [{"role": "user", "content": "hello"}],
            "stream": True,
        },
        headers={"x-session-key": "test-key"},
    )
    assert resp.status == 400
    body = await resp.json()
    assert "stream" in body["error"]["message"].lower()
# ---- Successful request ----
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_successful_request(aiohttp_client, mock_agent):
    # Happy path: verifies the OpenAI response shape AND that the handler
    # passes session isolation flags (isolate_memory, disabled_tools) through
    # to AgentLoop.process_direct.
    app = create_app(mock_agent, model_name="test-model")
    client = await aiohttp_client(app)
    resp = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello"}]},
        headers={"x-session-key": "wx:dm:user1"},
    )
    assert resp.status == 200
    body = await resp.json()
    assert body["choices"][0]["message"]["content"] == "mock response"
    assert body["model"] == "test-model"
    mock_agent.process_direct.assert_called_once_with(
        content="hello",
        session_key="wx:dm:user1",
        channel="api",
        chat_id="wx:dm:user1",
        isolate_memory=True,
        disabled_tools={"read_file", "write_file", "edit_file", "list_dir", "exec"},
    )
# ---- Session isolation ----
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_session_isolation_different_keys(aiohttp_client):
    """Two different session keys must route to separate session_key arguments."""
    call_log: list[str] = []

    # Record each session_key the handler forwards to the agent.
    async def fake_process(content, session_key="", channel="", chat_id="",
                           isolate_memory=False, disabled_tools=None):
        call_log.append(session_key)
        return f"reply to {session_key}"

    agent = MagicMock()
    agent.process_direct = fake_process
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()
    app = create_app(agent, model_name="m")
    client = await aiohttp_client(app)
    r1 = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "msg1"}]},
        headers={"x-session-key": "wx:dm:alice"},
    )
    r2 = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "msg2"}]},
        headers={"x-session-key": "wx:group:g1:user:bob"},
    )
    assert r1.status == 200
    assert r2.status == 200
    b1 = await r1.json()
    b2 = await r2.json()
    assert b1["choices"][0]["message"]["content"] == "reply to wx:dm:alice"
    assert b2["choices"][0]["message"]["content"] == "reply to wx:group:g1:user:bob"
    assert call_log == ["wx:dm:alice", "wx:group:g1:user:bob"]
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_same_session_key_serialized(aiohttp_client):
    """Concurrent requests with the same session key must run serially."""
    order: list[str] = []
    barrier = asyncio.Event()

    # "first" holds the per-key lock for 0.1s; "second" waits on the barrier
    # so we know it was issued while "first" was still in flight.
    async def slow_process(content, session_key="", channel="", chat_id="",
                           isolate_memory=False, disabled_tools=None):
        order.append(f"start:{content}")
        if content == "first":
            barrier.set()
            await asyncio.sleep(0.1)  # hold lock
        else:
            await barrier.wait()  # ensure "second" starts after "first" begins
        order.append(f"end:{content}")
        return content

    agent = MagicMock()
    agent.process_direct = slow_process
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()
    app = create_app(agent, model_name="m")
    client = await aiohttp_client(app)

    async def send(msg):
        return await client.post(
            "/v1/chat/completions",
            json={"messages": [{"role": "user", "content": msg}]},
            headers={"x-session-key": "same-key"},
        )

    r1, r2 = await asyncio.gather(send("first"), send("second"))
    assert r1.status == 200
    assert r2.status == 200
    # "first" must fully complete before "second" starts
    assert order.index("end:first") < order.index("start:second")
# ---- /v1/models ----
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_models_endpoint(aiohttp_client, app):
    """GET /v1/models returns an OpenAI-style list that includes our model id."""
    client = await aiohttp_client(app)
    response = await client.get("/v1/models")
    assert response.status == 200
    payload = await response.json()
    assert payload["object"] == "list"
    assert len(payload["data"]) >= 1
    assert payload["data"][0]["id"] == "test-model"
# ---- /health ----
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_health_endpoint(aiohttp_client, app):
    """GET /health reports an ok status."""
    client = await aiohttp_client(app)
    response = await client.get("/health")
    assert response.status == 200
    payload = await response.json()
    assert payload["status"] == "ok"
# ---- Multimodal content array ----
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_multimodal_content_extracts_text(aiohttp_client, mock_agent):
    """An OpenAI content array must be reduced to its text parts before dispatch."""
    client = await aiohttp_client(create_app(mock_agent, model_name="m"))
    # Mixed text + image_url parts, as an OpenAI multimodal client would send.
    payload = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "describe this"},
                    {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
                ],
            }
        ]
    }
    resp = await client.post(
        "/v1/chat/completions",
        json=payload,
        headers={"x-session-key": "test"},
    )
    assert resp.status == 200
    mock_agent.process_direct.assert_called_once()
    # Only the text part reaches the agent; the image part is dropped.
    assert mock_agent.process_direct.call_args.kwargs["content"] == "describe this"
# ---------------------------------------------------------------------------
# Memory isolation regression tests (root cause of cross-session leakage)
# ---------------------------------------------------------------------------
class TestMemoryIsolation:
    """Verify that per-session-key memory prevents cross-session context leakage.

    Root cause: ContextBuilder.build_system_prompt() reads a SHARED
    workspace/memory/MEMORY.md into the system prompt of ALL users.
    If user_1 writes "my name is Alice" and the agent persists it to
    MEMORY.md, user_2/user_N will see it.

    Fix: API mode passes a per-session MemoryStore so each session reads/
    writes its own MEMORY.md.
    """

    def test_context_builder_uses_override_memory(self, tmp_path):
        """build_system_prompt with memory_store= must use the override, not global."""
        from nanobot.agent.context import ContextBuilder
        from nanobot.agent.memory import MemoryStore
        workspace = tmp_path / "workspace"
        workspace.mkdir()
        (workspace / "memory").mkdir()
        (workspace / "memory" / "MEMORY.md").write_text("Global: I am shared context")
        ctx = ContextBuilder(workspace)
        # Without override → sees global memory
        prompt_global = ctx.build_system_prompt()
        assert "I am shared context" in prompt_global
        # With override → sees only the override's memory
        override_dir = tmp_path / "isolated" / "memory"
        override_dir.mkdir(parents=True)
        (override_dir / "MEMORY.md").write_text("User Alice's private note")
        # __new__ bypasses MemoryStore.__init__ so the store can be pointed
        # at an arbitrary directory without any side effects on construction.
        override_store = MemoryStore.__new__(MemoryStore)
        override_store.memory_dir = override_dir
        override_store.memory_file = override_dir / "MEMORY.md"
        override_store.history_file = override_dir / "HISTORY.md"
        prompt_isolated = ctx.build_system_prompt(memory_store=override_store)
        assert "User Alice's private note" in prompt_isolated
        assert "I am shared context" not in prompt_isolated

    def test_different_session_keys_get_different_memory_dirs(self, tmp_path):
        """_isolated_memory_store must return distinct paths for distinct keys."""
        from unittest.mock import MagicMock
        from nanobot.agent.loop import AgentLoop
        agent = MagicMock(spec=AgentLoop)
        agent.workspace = tmp_path
        # Bind the real unbound method onto the mock so only
        # _isolated_memory_store's own logic is exercised.
        agent._isolated_memory_store = AgentLoop._isolated_memory_store.__get__(agent)
        store_a = agent._isolated_memory_store("wx:dm:alice")
        store_b = agent._isolated_memory_store("wx:dm:bob")
        assert store_a.memory_file != store_b.memory_file
        assert store_a.memory_dir != store_b.memory_dir
        # The store must also create its directory eagerly.
        assert store_a.memory_file.parent.exists()
        assert store_b.memory_file.parent.exists()

    def test_isolated_memory_does_not_leak_across_sessions(self, tmp_path):
        """End-to-end: writing to one session's memory must not appear in another's."""
        from nanobot.agent.context import ContextBuilder
        from nanobot.agent.memory import MemoryStore
        workspace = tmp_path / "workspace"
        workspace.mkdir()
        (workspace / "memory").mkdir()
        (workspace / "memory" / "MEMORY.md").write_text("")
        ctx = ContextBuilder(workspace)

        # Simulate two isolated memory stores (as the API server would create)
        def make_store(name):
            d = tmp_path / "sessions" / name / "memory"
            d.mkdir(parents=True)
            s = MemoryStore.__new__(MemoryStore)  # skip __init__ side effects
            s.memory_dir = d
            s.memory_file = d / "MEMORY.md"
            s.history_file = d / "HISTORY.md"
            return s

        store_alice = make_store("wx_dm_alice")
        store_bob = make_store("wx_dm_bob")
        # Use unique markers that won't appear in builtin skills/prompts
        alice_marker = "XYZZY_ALICE_PRIVATE_MARKER_42"
        store_alice.write_long_term(alice_marker)
        # Alice's prompt sees it
        prompt_alice = ctx.build_system_prompt(memory_store=store_alice)
        assert alice_marker in prompt_alice
        # Bob's prompt must NOT see it
        prompt_bob = ctx.build_system_prompt(memory_store=store_bob)
        assert alice_marker not in prompt_bob
        # Global prompt must NOT see it either
        prompt_global = ctx.build_system_prompt()
        assert alice_marker not in prompt_global

    def test_build_messages_passes_memory_store(self, tmp_path):
        """build_messages must forward memory_store to build_system_prompt."""
        from nanobot.agent.context import ContextBuilder
        from nanobot.agent.memory import MemoryStore
        workspace = tmp_path / "workspace"
        workspace.mkdir()
        (workspace / "memory").mkdir()
        (workspace / "memory" / "MEMORY.md").write_text("GLOBAL_SECRET")
        ctx = ContextBuilder(workspace)
        override_dir = tmp_path / "per_session" / "memory"
        override_dir.mkdir(parents=True)
        (override_dir / "MEMORY.md").write_text("SESSION_PRIVATE")
        # Hand-built store pointing at the per-session directory (no __init__).
        override_store = MemoryStore.__new__(MemoryStore)
        override_store.memory_dir = override_dir
        override_store.memory_file = override_dir / "MEMORY.md"
        override_store.history_file = override_dir / "HISTORY.md"
        messages = ctx.build_messages(
            history=[], current_message="hello",
            memory_store=override_store,
        )
        # messages[0] is the system message; it must embed only the
        # per-session memory, never the global one.
        system_content = messages[0]["content"]
        assert "SESSION_PRIVATE" in system_content
        assert "GLOBAL_SECRET" not in system_content

    def test_api_handler_passes_isolate_memory_and_disabled_tools(self):
        """The API handler must call process_direct with isolate_memory=True and disabled filesystem tools."""
        import ast
        from pathlib import Path
        # Static source check: parse nanobot/api/server.py and look for the
        # required keyword arguments anywhere in the module. This guards the
        # call site without needing to run the server.
        server_path = Path(__file__).parent.parent / "nanobot" / "api" / "server.py"
        source = server_path.read_text()
        tree = ast.parse(source)
        found_isolate = False
        found_disabled = False
        for node in ast.walk(tree):
            if isinstance(node, ast.keyword):
                # isolate_memory must be the literal True, not a variable.
                if node.arg == "isolate_memory" and isinstance(node.value, ast.Constant) and node.value.value is True:
                    found_isolate = True
                if node.arg == "disabled_tools":
                    found_disabled = True
        assert found_isolate, "server.py must call process_direct with isolate_memory=True"
        assert found_disabled, "server.py must call process_direct with disabled_tools"

    def test_disabled_tools_constant_blocks_filesystem_and_exec(self):
        """_API_DISABLED_TOOLS must include all filesystem tool names and exec."""
        from nanobot.api.server import _API_DISABLED_TOOLS
        for name in ("read_file", "write_file", "edit_file", "list_dir", "exec"):
            assert name in _API_DISABLED_TOOLS, f"{name} missing from _API_DISABLED_TOOLS"

    def test_system_prompt_uses_isolated_memory_path(self, tmp_path):
        """When memory_store is provided, the system prompt must reference
        the store's paths, NOT the global workspace/memory/MEMORY.md."""
        from nanobot.agent.context import ContextBuilder
        from nanobot.agent.memory import MemoryStore
        workspace = tmp_path / "workspace"
        workspace.mkdir()
        (workspace / "memory").mkdir()
        ctx = ContextBuilder(workspace)
        # Default prompt references global path
        default_prompt = ctx.build_system_prompt()
        assert "memory/MEMORY.md" in default_prompt
        # Isolated store
        iso_dir = tmp_path / "sessions" / "wx_dm_alice" / "memory"
        iso_dir.mkdir(parents=True)
        store = MemoryStore.__new__(MemoryStore)  # skip __init__ side effects
        store.memory_dir = iso_dir
        store.memory_file = iso_dir / "MEMORY.md"
        store.history_file = iso_dir / "HISTORY.md"
        iso_prompt = ctx.build_system_prompt(memory_store=store)
        # Must reference the isolated path
        assert str(iso_dir / "MEMORY.md") in iso_prompt
        assert str(iso_dir / "HISTORY.md") in iso_prompt
        # Must NOT reference the global workspace memory path
        global_mem = str(workspace.resolve() / "memory" / "MEMORY.md")
        assert global_mem not in iso_prompt

    def test_run_agent_loop_filters_disabled_tools(self):
        """_run_agent_loop must exclude disabled tools from definitions
        and reject execution of disabled tools."""
        from nanobot.agent.tools.registry import ToolRegistry
        registry = ToolRegistry()

        # Create minimal fake tool definitions
        class FakeTool:
            # Minimal duck-typed tool: just enough surface for ToolRegistry.
            def __init__(self, n):
                self._name = n
            @property
            def name(self):
                return self._name
            def to_schema(self):
                return {"type": "function", "function": {"name": self._name, "parameters": {}}}
            def validate_params(self, params):
                return []
            async def execute(self, **kw):
                return "ok"

        for n in ("read_file", "write_file", "web_search", "exec"):
            registry.register(FakeTool(n))
        all_defs = registry.get_definitions()
        assert len(all_defs) == 4
        # NOTE(review): this reproduces the filtering expression used by
        # _run_agent_loop rather than calling it — keep in sync with loop.py.
        disabled = {"read_file", "write_file"}
        filtered = [d for d in all_defs
                    if d.get("function", {}).get("name") not in disabled]
        assert len(filtered) == 2
        names = {d["function"]["name"] for d in filtered}
        assert names == {"web_search", "exec"}
# ---------------------------------------------------------------------------
# Consolidation isolation regression tests
# ---------------------------------------------------------------------------
class TestConsolidationIsolation:
    """Verify that memory consolidation in API (isolate_memory) mode writes
    to the per-session directory and never touches global workspace/memory."""

    @pytest.mark.asyncio
    async def test_consolidate_memory_uses_provided_store(self, tmp_path):
        """_consolidate_memory(memory_store=X) must call X.consolidate,
        not MemoryStore(self.workspace).consolidate."""
        from unittest.mock import AsyncMock, MagicMock, patch
        from nanobot.agent.loop import AgentLoop
        from nanobot.agent.memory import MemoryStore
        from nanobot.session.manager import Session
        # Mock agent carrying only the attributes _consolidate_memory reads.
        agent = MagicMock(spec=AgentLoop)
        agent.workspace = tmp_path / "workspace"
        agent.workspace.mkdir()
        agent.provider = MagicMock()
        agent.model = "test"
        agent.memory_window = 50
        # Bind the real method
        agent._consolidate_memory = AgentLoop._consolidate_memory.__get__(agent)
        session = Session(key="test")
        session.messages = [{"role": "user", "content": "hi", "timestamp": "2025-01-01T00:00"}] * 10
        # Create an isolated store and mock its consolidate
        iso_store = MagicMock(spec=MemoryStore)
        iso_store.consolidate = AsyncMock(return_value=True)
        result = await agent._consolidate_memory(session, memory_store=iso_store)
        assert result is True
        iso_store.consolidate.assert_called_once()
        call_args = iso_store.consolidate.call_args
        assert call_args[0][0] is session  # first positional arg is session

    @pytest.mark.asyncio
    async def test_consolidate_memory_defaults_to_global_when_no_store(self, tmp_path):
        """Without memory_store, _consolidate_memory must use MemoryStore(workspace)."""
        from unittest.mock import AsyncMock, MagicMock, patch
        from nanobot.agent.loop import AgentLoop
        from nanobot.session.manager import Session
        agent = MagicMock(spec=AgentLoop)
        agent.workspace = tmp_path / "workspace"
        agent.workspace.mkdir()
        (agent.workspace / "memory").mkdir()
        agent.provider = MagicMock()
        agent.model = "test"
        agent.memory_window = 50
        # Bind the real method so the default-store code path runs for real.
        agent._consolidate_memory = AgentLoop._consolidate_memory.__get__(agent)
        session = Session(key="test")
        # Patch the MemoryStore class as imported inside loop.py so we can
        # observe how the fallback store is constructed.
        with patch("nanobot.agent.loop.MemoryStore") as MockStore:
            mock_instance = MagicMock()
            mock_instance.consolidate = AsyncMock(return_value=True)
            MockStore.return_value = mock_instance
            await agent._consolidate_memory(session)
            MockStore.assert_called_once_with(agent.workspace)
            mock_instance.consolidate.assert_called_once()

    def test_consolidate_writes_to_isolated_dir_not_global(self, tmp_path):
        """End-to-end: MemoryStore.consolidate with an isolated store must
        write HISTORY.md in the isolated dir, not in workspace/memory."""
        from nanobot.agent.memory import MemoryStore
        # Set up global workspace memory
        global_mem_dir = tmp_path / "workspace" / "memory"
        global_mem_dir.mkdir(parents=True)
        (global_mem_dir / "MEMORY.md").write_text("")
        (global_mem_dir / "HISTORY.md").write_text("")
        # Set up isolated per-session store
        iso_dir = tmp_path / "sessions" / "wx_dm_alice" / "memory"
        iso_dir.mkdir(parents=True)
        iso_store = MemoryStore.__new__(MemoryStore)  # skip __init__ side effects
        iso_store.memory_dir = iso_dir
        iso_store.memory_file = iso_dir / "MEMORY.md"
        iso_store.history_file = iso_dir / "HISTORY.md"
        # Write via the isolated store
        iso_store.write_long_term("Alice's private data")
        iso_store.append_history("[2025-01-01 00:00] Alice asked about X")
        # Isolated store has the data
        assert "Alice's private data" in iso_store.read_long_term()
        assert "Alice asked about X" in iso_store.history_file.read_text()
        # Global store must NOT have it
        assert (global_mem_dir / "MEMORY.md").read_text() == ""
        assert (global_mem_dir / "HISTORY.md").read_text() == ""

    def test_process_message_passes_memory_store_to_consolidation_paths(self):
        """Verify that _process_message passes memory_store to both
        consolidation triggers (source code check)."""
        import ast
        from pathlib import Path
        loop_path = Path(__file__).parent.parent / "nanobot" / "agent" / "loop.py"
        source = loop_path.read_text()
        tree = ast.parse(source)
        # Find all calls to self._consolidate_memory inside _process_message
        # and verify they all pass memory_store=
        for node in ast.walk(tree):
            if not isinstance(node, ast.FunctionDef) or node.name != "_process_message":
                continue
            consolidate_calls = []
            for child in ast.walk(node):
                if (isinstance(child, ast.Call)
                        and isinstance(child.func, ast.Attribute)
                        and child.func.attr == "_consolidate_memory"):
                    kw_names = {kw.arg for kw in child.keywords}
                    consolidate_calls.append(kw_names)
            # NOTE(review): the count is pinned at exactly 2 so the test fails
            # loudly if a new consolidation call site is added without review.
            assert len(consolidate_calls) == 2, (
                f"Expected 2 _consolidate_memory calls in _process_message, "
                f"found {len(consolidate_calls)}"
            )
            for i, kw_names in enumerate(consolidate_calls):
                assert "memory_store" in kw_names, (
                    f"_consolidate_memory call #{i+1} in _process_message "
                    f"missing memory_store= keyword argument"
                )
# ---------------------------------------------------------------------------
# Empty response retry + fallback tests
# ---------------------------------------------------------------------------
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_empty_response_retry_then_success(aiohttp_client):
    """First call returns empty → retry once → second call returns real text."""
    calls = 0

    async def sometimes_empty(content, session_key="", channel="", chat_id="",
                              isolate_memory=False, disabled_tools=None):
        nonlocal calls
        calls += 1
        # Empty on the first attempt, real text on the retry.
        return "" if calls == 1 else "recovered response"

    agent = MagicMock()
    agent.process_direct = sometimes_empty
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()
    client = await aiohttp_client(create_app(agent, model_name="m"))
    resp = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello"}]},
        headers={"x-session-key": "retry-test"},
    )
    assert resp.status == 200
    payload = await resp.json()
    assert payload["choices"][0]["message"]["content"] == "recovered response"
    # Exactly one retry must have happened.
    assert calls == 2
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_empty_response_both_empty_returns_fallback(aiohttp_client):
    """Both calls return empty → must use the fallback text."""
    calls = 0

    async def always_empty(content, session_key="", channel="", chat_id="",
                           isolate_memory=False, disabled_tools=None):
        nonlocal calls
        calls += 1
        return ""

    agent = MagicMock()
    agent.process_direct = always_empty
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()
    client = await aiohttp_client(create_app(agent, model_name="m"))
    resp = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello"}]},
        headers={"x-session-key": "fallback-test"},
    )
    assert resp.status == 200
    payload = await resp.json()
    # After the retry also comes back empty, the server substitutes this text.
    assert payload["choices"][0]["message"]["content"] == "I've completed processing but have no response to give."
    assert calls == 2
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_whitespace_only_response_triggers_retry(aiohttp_client):
    """Whitespace-only response should be treated as empty and trigger retry."""
    calls = 0

    async def whitespace_then_ok(content, session_key="", channel="", chat_id="",
                                 isolate_memory=False, disabled_tools=None):
        nonlocal calls
        calls += 1
        # Spaces and a newline only — must count as "empty".
        return " \n " if calls == 1 else "real answer"

    agent = MagicMock()
    agent.process_direct = whitespace_then_ok
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()
    client = await aiohttp_client(create_app(agent, model_name="m"))
    resp = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello"}]},
        headers={"x-session-key": "ws-test"},
    )
    assert resp.status == 200
    payload = await resp.json()
    assert payload["choices"][0]["message"]["content"] == "real answer"
    assert calls == 2
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_none_response_triggers_retry(aiohttp_client):
    """None response should be treated as empty and trigger retry."""
    calls = 0

    async def none_then_ok(content, session_key="", channel="", chat_id="",
                           isolate_memory=False, disabled_tools=None):
        nonlocal calls
        calls += 1
        # None (not even a string) — must count as "empty".
        return None if calls == 1 else "got it"

    agent = MagicMock()
    agent.process_direct = none_then_ok
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()
    client = await aiohttp_client(create_app(agent, model_name="m"))
    resp = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello"}]},
        headers={"x-session-key": "none-test"},
    )
    assert resp.status == 200
    payload = await resp.json()
    assert payload["choices"][0]["message"]["content"] == "got it"
    assert calls == 2
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_nonempty_response_no_retry(aiohttp_client):
    """A normal non-empty response must NOT trigger a retry."""
    calls = 0

    async def normal_response(content, session_key="", channel="", chat_id="",
                              isolate_memory=False, disabled_tools=None):
        nonlocal calls
        calls += 1
        return "immediate answer"

    agent = MagicMock()
    agent.process_direct = normal_response
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()
    client = await aiohttp_client(create_app(agent, model_name="m"))
    resp = await client.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello"}]},
        headers={"x-session-key": "normal-test"},
    )
    assert resp.status == 200
    payload = await resp.json()
    assert payload["choices"][0]["message"]["content"] == "immediate answer"
    # Exactly one agent call: no retry for a good response.
    assert calls == 1