From 80219baf255d2f75b15edac616f24b7b0025ded1 Mon Sep 17 00:00:00 2001 From: Tink Date: Sun, 1 Mar 2026 10:53:45 +0800 Subject: [PATCH] feat(api): add OpenAI-compatible endpoint with x-session-key isolation --- examples/curl.txt | 96 ++++ nanobot/agent/context.py | 36 +- nanobot/agent/loop.py | 80 ++- nanobot/api/__init__.py | 1 + nanobot/api/server.py | 222 ++++++++ nanobot/cli/commands.py | 77 +++ pyproject.toml | 4 + tests/test_consolidate_offset.py | 14 +- tests/test_openai_api.py | 883 +++++++++++++++++++++++++++++++ 9 files changed, 1387 insertions(+), 26 deletions(-) create mode 100644 examples/curl.txt create mode 100644 nanobot/api/__init__.py create mode 100644 nanobot/api/server.py create mode 100644 tests/test_openai_api.py diff --git a/examples/curl.txt b/examples/curl.txt new file mode 100644 index 000000000..70dc4dfe7 --- /dev/null +++ b/examples/curl.txt @@ -0,0 +1,96 @@ +# ============================================================================= +# nanobot OpenAI-Compatible API — curl examples +# ============================================================================= +# +# Prerequisites: +# pip install nanobot-ai[api] # installs aiohttp +# nanobot serve --port 8900 # start the API server +# +# The x-session-key header is REQUIRED for every POST /v1/chat/completions request. +# Convention: +# Private chat: wx:dm:{sender_id} +# Group @mention: wx:group:{group_id}:user:{sender_id} +# ============================================================================= + +# --- 1. Basic chat completion (private chat) --- + +curl -X POST http://localhost:8900/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "x-session-key: wx:dm:user_alice" \ + -d '{ + "model": "nanobot", + "messages": [ + {"role": "user", "content": "Hello, who are you?"} + ] + }' + +# --- 2. 
Follow-up in the same session (context is remembered) --- + +curl -X POST http://localhost:8900/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "x-session-key: wx:dm:user_alice" \ + -d '{ + "model": "nanobot", + "messages": [ + {"role": "user", "content": "What did I just ask you?"} + ] + }' + +# --- 3. Different user — isolated session --- + +curl -X POST http://localhost:8900/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "x-session-key: wx:dm:user_bob" \ + -d '{ + "model": "nanobot", + "messages": [ + {"role": "user", "content": "What did I just ask you?"} + ] + }' +# ↑ Bob gets a fresh context — he never asked anything before. + +# --- 4. Group chat — per-user session within a group --- + +curl -X POST http://localhost:8900/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "x-session-key: wx:group:group_abc:user:user_alice" \ + -d '{ + "model": "nanobot", + "messages": [ + {"role": "user", "content": "Summarize our discussion"} + ] + }' + +# --- 5. List available models --- + +curl http://localhost:8900/v1/models + +# --- 6. Health check --- + +curl http://localhost:8900/health + +# --- 7. Missing header — expect 400 --- + +curl -X POST http://localhost:8900/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "nanobot", + "messages": [ + {"role": "user", "content": "hello"} + ] + }' +# ↑ Returns: {"error": {"message": "Missing required header: x-session-key", ...}} + +# --- 8. 
Stream not yet supported — expect 400 --- + +curl -X POST http://localhost:8900/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "x-session-key: wx:dm:user_alice" \ + -d '{ + "model": "nanobot", + "messages": [ + {"role": "user", "content": "hello"} + ], + "stream": true + }' +# ↑ Returns: {"error": {"message": "stream=true is not supported yet...", ...}} diff --git a/nanobot/agent/context.py b/nanobot/agent/context.py index be0ec5996..3665d7f3a 100644 --- a/nanobot/agent/context.py +++ b/nanobot/agent/context.py @@ -23,15 +23,25 @@ class ContextBuilder: self.memory = MemoryStore(workspace) self.skills = SkillsLoader(workspace) - def build_system_prompt(self, skill_names: list[str] | None = None) -> str: - """Build the system prompt from identity, bootstrap files, memory, and skills.""" - parts = [self._get_identity()] + def build_system_prompt( + self, + skill_names: list[str] | None = None, + memory_store: "MemoryStore | None" = None, + ) -> str: + """Build the system prompt from identity, bootstrap files, memory, and skills. + + Args: + memory_store: If provided, use this MemoryStore instead of the default + workspace-level one. Used for per-session memory isolation. 
+ """ + parts = [self._get_identity(memory_store=memory_store)] bootstrap = self._load_bootstrap_files() if bootstrap: parts.append(bootstrap) - memory = self.memory.get_memory_context() + store = memory_store or self.memory + memory = store.get_memory_context() if memory: parts.append(f"# Memory\n\n{memory}") @@ -52,12 +62,19 @@ Skills with available="false" need dependencies installed first - you can try in return "\n\n---\n\n".join(parts) - def _get_identity(self) -> str: + def _get_identity(self, memory_store: "MemoryStore | None" = None) -> str: """Get the core identity section.""" workspace_path = str(self.workspace.expanduser().resolve()) system = platform.system() runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}" - + + if memory_store is not None: + mem_path = str(memory_store.memory_file) + hist_path = str(memory_store.history_file) + else: + mem_path = f"{workspace_path}/memory/MEMORY.md" + hist_path = f"{workspace_path}/memory/HISTORY.md" + return f"""# nanobot 🐈 You are nanobot, a helpful AI assistant. @@ -67,8 +84,8 @@ You are nanobot, a helpful AI assistant. ## Workspace Your workspace is at: {workspace_path} -- Long-term memory: {workspace_path}/memory/MEMORY.md (write important facts here) -- History log: {workspace_path}/memory/HISTORY.md (grep-searchable). Each entry starts with [YYYY-MM-DD HH:MM]. +- Long-term memory: {mem_path} (write important facts here) +- History log: {hist_path} (grep-searchable). Each entry starts with [YYYY-MM-DD HH:MM]. - Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md ## nanobot Guidelines @@ -110,10 +127,11 @@ Reply directly with text for conversations. 
Only use the 'message' tool to send media: list[str] | None = None, channel: str | None = None, chat_id: str | None = None, + memory_store: "MemoryStore | None" = None, ) -> list[dict[str, Any]]: """Build the complete message list for an LLM call.""" return [ - {"role": "system", "content": self.build_system_prompt(skill_names)}, + {"role": "system", "content": self.build_system_prompt(skill_names, memory_store=memory_store)}, *history, {"role": "user", "content": self._build_runtime_context(channel, chat_id)}, {"role": "user", "content": self._build_user_content(current_message, media)}, diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index b605ae4a9..6a0d24f26 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -174,6 +174,7 @@ class AgentLoop: self, initial_messages: list[dict], on_progress: Callable[..., Awaitable[None]] | None = None, + disabled_tools: set[str] | None = None, ) -> tuple[str | None, list[str], list[dict]]: """Run the agent iteration loop. Returns (final_content, tools_used, messages).""" messages = initial_messages @@ -181,12 +182,19 @@ class AgentLoop: final_content = None tools_used: list[str] = [] + # Build tool definitions, filtering out disabled tools + if disabled_tools: + tool_defs = [d for d in self.tools.get_definitions() + if d.get("function", {}).get("name") not in disabled_tools] + else: + tool_defs = self.tools.get_definitions() + while iteration < self.max_iterations: iteration += 1 response = await self.provider.chat( messages=messages, - tools=self.tools.get_definitions(), + tools=tool_defs, model=self.model, temperature=self.temperature, max_tokens=self.max_tokens, @@ -219,7 +227,10 @@ class AgentLoop: tools_used.append(tool_call.name) args_str = json.dumps(tool_call.arguments, ensure_ascii=False) logger.info("Tool call: {}({})", tool_call.name, args_str[:200]) - result = await self.tools.execute(tool_call.name, tool_call.arguments) + if disabled_tools and tool_call.name in disabled_tools: + result = 
f"Error: Tool '{tool_call.name}' is not available in this mode." + else: + result = await self.tools.execute(tool_call.name, tool_call.arguments) messages = self.context.add_tool_result( messages, tool_call.id, tool_call.name, result ) @@ -322,6 +333,8 @@ class AgentLoop: msg: InboundMessage, session_key: str | None = None, on_progress: Callable[[str], Awaitable[None]] | None = None, + memory_store: MemoryStore | None = None, + disabled_tools: set[str] | None = None, ) -> OutboundMessage | None: """Process a single inbound message and return the response.""" # System messages: parse origin from chat_id ("channel:chat_id") @@ -336,8 +349,11 @@ class AgentLoop: messages = self.context.build_messages( history=history, current_message=msg.content, channel=channel, chat_id=chat_id, + memory_store=memory_store, + ) + final_content, _, all_msgs = await self._run_agent_loop( + messages, disabled_tools=disabled_tools, ) - final_content, _, all_msgs = await self._run_agent_loop(messages) self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) return OutboundMessage(channel=channel, chat_id=chat_id, @@ -360,7 +376,9 @@ class AgentLoop: if snapshot: temp = Session(key=session.key) temp.messages = list(snapshot) - if not await self._consolidate_memory(temp, archive_all=True): + if not await self._consolidate_memory( + temp, archive_all=True, memory_store=memory_store, + ): return OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, content="Memory archival failed, session not cleared. 
Please try again.", @@ -393,7 +411,9 @@ class AgentLoop: async def _consolidate_and_unlock(): try: async with lock: - await self._consolidate_memory(session) + await self._consolidate_memory( + session, memory_store=memory_store, + ) finally: self._consolidating.discard(session.key) if not lock.locked(): @@ -416,6 +436,7 @@ class AgentLoop: current_message=msg.content, media=msg.media if msg.media else None, channel=msg.channel, chat_id=msg.chat_id, + memory_store=memory_store, ) async def _bus_progress(content: str, *, tool_hint: bool = False) -> None: @@ -428,6 +449,7 @@ class AgentLoop: final_content, _, all_msgs = await self._run_agent_loop( initial_messages, on_progress=on_progress or _bus_progress, + disabled_tools=disabled_tools, ) if final_content is None: @@ -470,9 +492,30 @@ class AgentLoop: session.messages.append(entry) session.updated_at = datetime.now() - async def _consolidate_memory(self, session, archive_all: bool = False) -> bool: - """Delegate to MemoryStore.consolidate(). Returns True on success.""" - return await MemoryStore(self.workspace).consolidate( + def _isolated_memory_store(self, session_key: str) -> MemoryStore: + """Return a per-session-key MemoryStore for multi-tenant isolation.""" + from nanobot.utils.helpers import safe_filename + safe_key = safe_filename(session_key.replace(":", "_")) + memory_dir = self.workspace / "sessions" / safe_key / "memory" + memory_dir.mkdir(parents=True, exist_ok=True) + store = MemoryStore.__new__(MemoryStore) + store.memory_dir = memory_dir + store.memory_file = memory_dir / "MEMORY.md" + store.history_file = memory_dir / "HISTORY.md" + return store + + async def _consolidate_memory( + self, session, archive_all: bool = False, + memory_store: MemoryStore | None = None, + ) -> bool: + """Delegate to MemoryStore.consolidate(). Returns True on success. + + Args: + memory_store: If provided, consolidate into this store instead of + the default workspace-level one. 
+ """ + store = memory_store or MemoryStore(self.workspace) + return await store.consolidate( session, self.provider, self.model, archive_all=archive_all, memory_window=self.memory_window, ) @@ -484,9 +527,26 @@ class AgentLoop: channel: str = "cli", chat_id: str = "direct", on_progress: Callable[[str], Awaitable[None]] | None = None, + isolate_memory: bool = False, + disabled_tools: set[str] | None = None, ) -> str: - """Process a message directly (for CLI or cron usage).""" + """Process a message directly (for CLI or cron usage). + + Args: + isolate_memory: When True, use a per-session-key memory directory + instead of the shared workspace memory. This prevents context + leakage between different session keys in multi-tenant (API) mode. + disabled_tools: Tool names to exclude from the LLM tool list and + reject at execution time. Use to block filesystem access in + multi-tenant API mode. + """ await self._connect_mcp() + memory_store: MemoryStore | None = None + if isolate_memory: + memory_store = self._isolated_memory_store(session_key) msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content) - response = await self._process_message(msg, session_key=session_key, on_progress=on_progress) + response = await self._process_message( + msg, session_key=session_key, on_progress=on_progress, + memory_store=memory_store, disabled_tools=disabled_tools, + ) return response.content if response else "" diff --git a/nanobot/api/__init__.py b/nanobot/api/__init__.py new file mode 100644 index 000000000..f0c504cc1 --- /dev/null +++ b/nanobot/api/__init__.py @@ -0,0 +1 @@ +"""OpenAI-compatible HTTP API for nanobot.""" diff --git a/nanobot/api/server.py b/nanobot/api/server.py new file mode 100644 index 000000000..a3077537f --- /dev/null +++ b/nanobot/api/server.py @@ -0,0 +1,222 @@ +"""OpenAI-compatible HTTP API server for nanobot. + +Provides /v1/chat/completions and /v1/models endpoints. 
+Session isolation is enforced via the x-session-key request header. +""" + +from __future__ import annotations + +import asyncio +import time +import uuid +from typing import Any + +from aiohttp import web +from loguru import logger + +# Tools that must NOT run in multi-tenant API mode. +# Filesystem tools allow the LLM to read/write the shared workspace (including +# global MEMORY.md), and exec allows shell commands that can bypass filesystem +# restrictions (e.g. `cat ~/.nanobot/workspace/memory/MEMORY.md`). +_API_DISABLED_TOOLS: set[str] = { + "read_file", "write_file", "edit_file", "list_dir", "exec", +} + + +# --------------------------------------------------------------------------- +# Per-session-key lock manager +# --------------------------------------------------------------------------- + +class _SessionLocks: + """Manages one asyncio.Lock per session key for serial execution.""" + + def __init__(self) -> None: + self._locks: dict[str, asyncio.Lock] = {} + self._ref: dict[str, int] = {} # reference count for cleanup + + def acquire(self, key: str) -> asyncio.Lock: + if key not in self._locks: + self._locks[key] = asyncio.Lock() + self._ref[key] = 0 + self._ref[key] += 1 + return self._locks[key] + + def release(self, key: str) -> None: + self._ref[key] -= 1 + if self._ref[key] <= 0: + self._locks.pop(key, None) + self._ref.pop(key, None) + + +# --------------------------------------------------------------------------- +# Response helpers +# --------------------------------------------------------------------------- + +def _error_json(status: int, message: str, err_type: str = "invalid_request_error") -> web.Response: + return web.json_response( + {"error": {"message": message, "type": err_type, "code": status}}, + status=status, + ) + + +def _chat_completion_response(content: str, model: str) -> dict[str, Any]: + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:12]}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + 
"choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": content}, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + } + + +# --------------------------------------------------------------------------- +# Route handlers +# --------------------------------------------------------------------------- + +async def handle_chat_completions(request: web.Request) -> web.Response: + """POST /v1/chat/completions""" + + # --- x-session-key validation --- + session_key = request.headers.get("x-session-key", "").strip() + if not session_key: + return _error_json(400, "Missing required header: x-session-key") + + # --- Parse body --- + try: + body = await request.json() + except Exception: + return _error_json(400, "Invalid JSON body") + + messages = body.get("messages") + if not messages or not isinstance(messages, list): + return _error_json(400, "messages field is required and must be a non-empty array") + + # Stream not yet supported + if body.get("stream", False): + return _error_json(400, "stream=true is not supported yet. 
Set stream=false or omit it.") + + # Extract last user message — nanobot manages its own multi-turn history + user_content = None + for msg in reversed(messages): + if msg.get("role") == "user": + user_content = msg.get("content", "") + break + if user_content is None: + return _error_json(400, "messages must contain at least one user message") + if isinstance(user_content, list): + # Multi-modal content array — extract text parts + user_content = " ".join( + part.get("text", "") for part in user_content if part.get("type") == "text" + ) + + agent_loop = request.app["agent_loop"] + timeout_s: float = request.app.get("request_timeout", 120.0) + model_name: str = body.get("model") or request.app.get("model_name", "nanobot") + locks: _SessionLocks = request.app["session_locks"] + + safe_key = session_key[:32] + ("…" if len(session_key) > 32 else "") + logger.info("API request session_key={} content={}", safe_key, user_content[:80]) + + _FALLBACK = "I've completed processing but have no response to give." 
+ + lock = locks.acquire(session_key) + try: + async with lock: + try: + response_text = await asyncio.wait_for( + agent_loop.process_direct( + content=user_content, + session_key=session_key, + channel="api", + chat_id=session_key, + isolate_memory=True, + disabled_tools=_API_DISABLED_TOOLS, + ), + timeout=timeout_s, + ) + + if not response_text or not response_text.strip(): + logger.warning("Empty response for session {}, retrying", safe_key) + response_text = await asyncio.wait_for( + agent_loop.process_direct( + content=user_content, + session_key=session_key, + channel="api", + chat_id=session_key, + isolate_memory=True, + disabled_tools=_API_DISABLED_TOOLS, + ), + timeout=timeout_s, + ) + if not response_text or not response_text.strip(): + logger.warning("Empty response after retry for session {}, using fallback", safe_key) + response_text = _FALLBACK + + except asyncio.TimeoutError: + return _error_json(504, f"Request timed out after {timeout_s}s") + except Exception: + logger.exception("Error processing request for session {}", safe_key) + return _error_json(500, "Internal server error", err_type="server_error") + finally: + locks.release(session_key) + + return web.json_response(_chat_completion_response(response_text, model_name)) + + +async def handle_models(request: web.Request) -> web.Response: + """GET /v1/models""" + model_name = request.app.get("model_name", "nanobot") + return web.json_response({ + "object": "list", + "data": [ + { + "id": model_name, + "object": "model", + "created": 0, + "owned_by": "nanobot", + } + ], + }) + + +async def handle_health(request: web.Request) -> web.Response: + """GET /health""" + return web.json_response({"status": "ok"}) + + +# --------------------------------------------------------------------------- +# App factory +# --------------------------------------------------------------------------- + +def create_app(agent_loop, model_name: str = "nanobot", request_timeout: float = 120.0) -> web.Application: + 
"""Create the aiohttp application. + + Args: + agent_loop: An initialized AgentLoop instance. + model_name: Model name reported in responses. + request_timeout: Per-request timeout in seconds. + """ + app = web.Application() + app["agent_loop"] = agent_loop + app["model_name"] = model_name + app["request_timeout"] = request_timeout + app["session_locks"] = _SessionLocks() + + app.router.add_post("/v1/chat/completions", handle_chat_completions) + app.router.add_get("/v1/models", handle_models) + app.router.add_get("/health", handle_health) + return app + + +def run_server(agent_loop, host: str = "0.0.0.0", port: int = 8900, + model_name: str = "nanobot", request_timeout: float = 120.0) -> None: + """Create and run the server (blocking).""" + app = create_app(agent_loop, model_name=model_name, request_timeout=request_timeout) + web.run_app(app, host=host, port=port, print=lambda msg: logger.info(msg)) diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index fc4c261ea..208b4e742 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -237,6 +237,83 @@ def _make_provider(config: Config): ) +# ============================================================================ +# OpenAI-Compatible API Server +# ============================================================================ + + +@app.command() +def serve( + port: int = typer.Option(8900, "--port", "-p", help="API server port"), + host: str = typer.Option("0.0.0.0", "--host", "-H", help="Bind address"), + timeout: float = typer.Option(120.0, "--timeout", "-t", help="Per-request timeout (seconds)"), + verbose: bool = typer.Option(False, "--verbose", "-v", help="Show nanobot runtime logs"), +): + """Start the OpenAI-compatible API server (/v1/chat/completions).""" + try: + from aiohttp import web # noqa: F401 + except ImportError: + console.print("[red]aiohttp is required. 
Install with: pip install aiohttp[/red]") + raise typer.Exit(1) + + from nanobot.config.loader import load_config + from nanobot.api.server import create_app + from loguru import logger + + if verbose: + logger.enable("nanobot") + else: + logger.disable("nanobot") + + config = load_config() + sync_workspace_templates(config.workspace_path) + provider = _make_provider(config) + + from nanobot.bus.queue import MessageBus + from nanobot.agent.loop import AgentLoop + from nanobot.session.manager import SessionManager + + bus = MessageBus() + session_manager = SessionManager(config.workspace_path) + agent_loop = AgentLoop( + bus=bus, + provider=provider, + workspace=config.workspace_path, + model=config.agents.defaults.model, + temperature=config.agents.defaults.temperature, + max_tokens=config.agents.defaults.max_tokens, + max_iterations=config.agents.defaults.max_tool_iterations, + memory_window=config.agents.defaults.memory_window, + brave_api_key=config.tools.web.search.api_key or None, + exec_config=config.tools.exec, + restrict_to_workspace=config.tools.restrict_to_workspace, + session_manager=session_manager, + mcp_servers=config.tools.mcp_servers, + channels_config=config.channels, + ) + + model_name = config.agents.defaults.model + console.print(f"{__logo__} Starting OpenAI-compatible API server") + console.print(f" [cyan]Endpoint[/cyan] : http://{host}:{port}/v1/chat/completions") + console.print(f" [cyan]Model[/cyan] : {model_name}") + console.print(f" [cyan]Timeout[/cyan] : {timeout}s") + console.print(f" [cyan]Header[/cyan] : x-session-key (required)") + console.print() + + api_app = create_app(agent_loop, model_name=model_name, request_timeout=timeout) + + async def on_startup(_app): + await agent_loop._connect_mcp() + + async def on_cleanup(_app): + await agent_loop.close_mcp() + + api_app.on_startup.append(on_startup) + api_app.on_cleanup.append(on_cleanup) + + web.run_app(api_app, host=host, port=port, print=lambda msg: logger.info(msg)) + + # 
============================================================================ # Gateway / Server # ============================================================================ diff --git a/pyproject.toml b/pyproject.toml index 20dcb1e01..f71faa146 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,9 @@ dependencies = [ ] [project.optional-dependencies] +api = [ + "aiohttp>=3.9.0,<4.0.0", +] matrix = [ "matrix-nio[e2e]>=0.25.2", "mistune>=3.0.0,<4.0.0", @@ -53,6 +56,7 @@ matrix = [ dev = [ "pytest>=9.0.0,<10.0.0", "pytest-asyncio>=1.3.0,<2.0.0", + "aiohttp>=3.9.0,<4.0.0", "ruff>=0.1.0", ] diff --git a/tests/test_consolidate_offset.py b/tests/test_consolidate_offset.py index 675512406..fc72e0a63 100644 --- a/tests/test_consolidate_offset.py +++ b/tests/test_consolidate_offset.py @@ -509,7 +509,7 @@ class TestConsolidationDeduplicationGuard: consolidation_calls = 0 - async def _fake_consolidate(_session, archive_all: bool = False) -> None: + async def _fake_consolidate(_session, archive_all: bool = False, **kw) -> None: nonlocal consolidation_calls consolidation_calls += 1 await asyncio.sleep(0.05) @@ -555,7 +555,7 @@ class TestConsolidationDeduplicationGuard: active = 0 max_active = 0 - async def _fake_consolidate(_session, archive_all: bool = False) -> None: + async def _fake_consolidate(_session, archive_all: bool = False, **kw) -> None: nonlocal consolidation_calls, active, max_active consolidation_calls += 1 active += 1 @@ -605,7 +605,7 @@ class TestConsolidationDeduplicationGuard: started = asyncio.Event() - async def _slow_consolidate(_session, archive_all: bool = False) -> None: + async def _slow_consolidate(_session, archive_all: bool = False, **kw) -> None: started.set() await asyncio.sleep(0.1) @@ -652,7 +652,7 @@ class TestConsolidationDeduplicationGuard: release = asyncio.Event() archived_count = 0 - async def _fake_consolidate(sess, archive_all: bool = False) -> bool: + async def _fake_consolidate(sess, archive_all: bool = False, **kw) -> bool: 
nonlocal archived_count if archive_all: archived_count = len(sess.messages) @@ -707,7 +707,7 @@ class TestConsolidationDeduplicationGuard: loop.sessions.save(session) before_count = len(session.messages) - async def _failing_consolidate(sess, archive_all: bool = False) -> bool: + async def _failing_consolidate(sess, archive_all: bool = False, **kw) -> bool: if archive_all: return False return True @@ -754,7 +754,7 @@ class TestConsolidationDeduplicationGuard: release = asyncio.Event() archived_count = -1 - async def _fake_consolidate(sess, archive_all: bool = False) -> bool: + async def _fake_consolidate(sess, archive_all: bool = False, **kw) -> bool: nonlocal archived_count if archive_all: archived_count = len(sess.messages) @@ -815,7 +815,7 @@ class TestConsolidationDeduplicationGuard: loop._consolidation_locks.setdefault(session.key, asyncio.Lock()) assert session.key in loop._consolidation_locks - async def _ok_consolidate(sess, archive_all: bool = False) -> bool: + async def _ok_consolidate(sess, archive_all: bool = False, **kw) -> bool: return True loop._consolidate_memory = _ok_consolidate # type: ignore[method-assign] diff --git a/tests/test_openai_api.py b/tests/test_openai_api.py new file mode 100644 index 000000000..b4d831579 --- /dev/null +++ b/tests/test_openai_api.py @@ -0,0 +1,883 @@ +"""Tests for the OpenAI-compatible API server.""" + +from __future__ import annotations + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from nanobot.api.server import _SessionLocks, _chat_completion_response, _error_json, create_app + +# --------------------------------------------------------------------------- +# aiohttp test client helper +# --------------------------------------------------------------------------- + +try: + from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop + from aiohttp import web + + HAS_AIOHTTP = True +except ImportError: + HAS_AIOHTTP = False + +pytest_plugins = 
("pytest_asyncio",) + +# --------------------------------------------------------------------------- +# Unit tests — no aiohttp required +# --------------------------------------------------------------------------- + + +class TestSessionLocks: + def test_acquire_creates_lock(self): + sl = _SessionLocks() + lock = sl.acquire("k1") + assert isinstance(lock, asyncio.Lock) + + def test_same_key_returns_same_lock(self): + sl = _SessionLocks() + l1 = sl.acquire("k1") + l2 = sl.acquire("k1") + assert l1 is l2 + + def test_different_keys_different_locks(self): + sl = _SessionLocks() + l1 = sl.acquire("k1") + l2 = sl.acquire("k2") + assert l1 is not l2 + + def test_release_cleans_up(self): + sl = _SessionLocks() + sl.acquire("k1") + sl.release("k1") + assert "k1" not in sl._locks + + def test_release_keeps_lock_if_still_referenced(self): + sl = _SessionLocks() + sl.acquire("k1") + sl.acquire("k1") + sl.release("k1") + assert "k1" in sl._locks + sl.release("k1") + assert "k1" not in sl._locks + + +class TestResponseHelpers: + def test_error_json(self): + resp = _error_json(400, "bad request") + assert resp.status == 400 + body = json.loads(resp.body) + assert body["error"]["message"] == "bad request" + assert body["error"]["code"] == 400 + + def test_chat_completion_response(self): + result = _chat_completion_response("hello world", "test-model") + assert result["object"] == "chat.completion" + assert result["model"] == "test-model" + assert result["choices"][0]["message"]["content"] == "hello world" + assert result["choices"][0]["finish_reason"] == "stop" + assert result["id"].startswith("chatcmpl-") + + +# --------------------------------------------------------------------------- +# Integration tests — require aiohttp +# --------------------------------------------------------------------------- + + +def _make_mock_agent(response_text: str = "mock response") -> MagicMock: + agent = MagicMock() + agent.process_direct = AsyncMock(return_value=response_text) + 
agent._connect_mcp = AsyncMock()
+    agent.close_mcp = AsyncMock()
+    return agent
+
+
+@pytest.fixture
+def mock_agent():
+    """Return a fresh mock agent built by ``_make_mock_agent``."""
+    return _make_mock_agent()
+
+
+@pytest.fixture
+def app(mock_agent):
+    """Return the API application wired to ``mock_agent`` with a 10 s request timeout."""
+    return create_app(mock_agent, model_name="test-model", request_timeout=10.0)
+
+
+@pytest.fixture
+def cli(event_loop, aiohttp_client, app):
+    """Return a test client for ``app``, created by driving ``aiohttp_client`` on ``event_loop``.
+
+    NOTE(review): no test in this part of the module requests this fixture, and it
+    depends on the ``event_loop`` fixture -- confirm it is still provided by the
+    pinned pytest-asyncio version.
+    """
+    return event_loop.run_until_complete(aiohttp_client(app))
+
+
+# ---- Missing header tests ----
+
+
+@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
+@pytest.mark.asyncio
+async def test_missing_session_key_returns_400(aiohttp_client, app):
+    """A request without the x-session-key header gets 400 and the error names the header."""
+    client = await aiohttp_client(app)
+    resp = await client.post(
+        "/v1/chat/completions",
+        json={"messages": [{"role": "user", "content": "hello"}]},
+    )
+    assert resp.status == 400
+    body = await resp.json()
+    assert "x-session-key" in body["error"]["message"]
+
+
+@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
+@pytest.mark.asyncio
+async def test_empty_session_key_returns_400(aiohttp_client, app):
+    """A whitespace-only x-session-key header gets 400."""
+    client = await aiohttp_client(app)
+    resp = await client.post(
+        "/v1/chat/completions",
+        json={"messages": [{"role": "user", "content": "hello"}]},
+        headers={"x-session-key": " "},
+    )
+    assert resp.status == 400
+
+
+# ---- Missing messages tests ----
+
+
+@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
+@pytest.mark.asyncio
+async def test_missing_messages_returns_400(aiohttp_client, app):
+    """A body without a ``messages`` field gets 400."""
+    client = await aiohttp_client(app)
+    resp = await client.post(
+        "/v1/chat/completions",
+        json={"model": "test"},
+        headers={"x-session-key": "test-key"},
+    )
+    assert resp.status == 400
+
+
+@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
+@pytest.mark.asyncio
+async def test_no_user_message_returns_400(aiohttp_client, app):
+    """A ``messages`` list that contains no user-role entry gets 400."""
+    client = await aiohttp_client(app)
+    resp = await client.post(
+        "/v1/chat/completions",
+        json={"messages": [{"role": "system", "content": "you are a bot"}]},
+        headers={"x-session-key": "test-key"},
+    )
+    assert resp.status == 400
+
+
+# ---- Stream not supported ----
+
+
+@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
+@pytest.mark.asyncio
+async def test_stream_true_returns_400(aiohttp_client, app):
+    """``stream: true`` gets 400 with an error message that mentions streaming."""
+    client = await aiohttp_client(app)
+    resp = await client.post(
+        "/v1/chat/completions",
+        json={
+            "messages": [{"role": "user", "content": "hello"}],
+            "stream": True,
+        },
+        headers={"x-session-key": "test-key"},
+    )
+    assert resp.status == 400
+    body = await resp.json()
+    assert "stream" in body["error"]["message"].lower()
+
+
+# ---- Successful request ----
+
+
+@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
+@pytest.mark.asyncio
+async def test_successful_request(aiohttp_client, mock_agent):
+    """A valid request returns the agent reply and forwards the expected arguments."""
+    app = create_app(mock_agent, model_name="test-model")
+    client = await aiohttp_client(app)
+    resp = await client.post(
+        "/v1/chat/completions",
+        json={"messages": [{"role": "user", "content": "hello"}]},
+        headers={"x-session-key": "wx:dm:user1"},
+    )
+    assert resp.status == 200
+    body = await resp.json()
+    assert body["choices"][0]["message"]["content"] == "mock response"
+    assert body["model"] == "test-model"
+    mock_agent.process_direct.assert_called_once_with(
+        content="hello",
+        session_key="wx:dm:user1",
+        channel="api",
+        chat_id="wx:dm:user1",
+        isolate_memory=True,
+        disabled_tools={"read_file", "write_file", "edit_file", "list_dir", "exec"},
+    )
+
+
+# ---- Session isolation ----
+
+
+@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
+@pytest.mark.asyncio
+async def test_session_isolation_different_keys(aiohttp_client):
+    """Two different session keys must route to separate session_key arguments."""
+    call_log: list[str] = []
+
+    async def fake_process(content, session_key="", channel="", chat_id="",
+                           isolate_memory=False, disabled_tools=None):
+        call_log.append(session_key)
+        return f"reply to {session_key}"
+
+    agent = MagicMock()
+    agent.process_direct = fake_process
+    agent._connect_mcp = AsyncMock()
+    agent.close_mcp = AsyncMock()
+
+    app = create_app(agent, model_name="m")
+    client = await aiohttp_client(app)
+
+    r1 = await client.post(
+        "/v1/chat/completions",
+        json={"messages": [{"role": "user", "content": "msg1"}]},
+        headers={"x-session-key": "wx:dm:alice"},
+    )
+    r2 = await client.post(
+        "/v1/chat/completions",
+        json={"messages": [{"role": "user", "content": "msg2"}]},
+        headers={"x-session-key": "wx:group:g1:user:bob"},
+    )
+
+    assert r1.status == 200
+    assert r2.status == 200
+
+    b1 = await r1.json()
+    b2 = await r2.json()
+    assert b1["choices"][0]["message"]["content"] == "reply to wx:dm:alice"
+    assert b2["choices"][0]["message"]["content"] == "reply to wx:group:g1:user:bob"
+    assert call_log == ["wx:dm:alice", "wx:group:g1:user:bob"]
+
+
+@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
+@pytest.mark.asyncio
+async def test_same_session_key_serialized(aiohttp_client):
+    """Concurrent requests with the same session key must run serially."""
+    order: list[str] = []
+    barrier = asyncio.Event()
+
+    async def slow_process(content, session_key="", channel="", chat_id="",
+                           isolate_memory=False, disabled_tools=None):
+        order.append(f"start:{content}")
+        if content == "first":
+            barrier.set()
+            await asyncio.sleep(0.1)  # hold lock
+        else:
+            await barrier.wait()  # ensure "second" starts after "first" begins
+        order.append(f"end:{content}")
+        return content
+
+    agent = MagicMock()
+    agent.process_direct = slow_process
+    agent._connect_mcp = AsyncMock()
+    agent.close_mcp = AsyncMock()
+
+    app = create_app(agent, model_name="m")
+    client = await aiohttp_client(app)
+
+    async def send(msg):
+        return await client.post(
+            "/v1/chat/completions",
+            json={"messages": [{"role": "user", "content": msg}]},
+            headers={"x-session-key": "same-key"},
+        )
+
+    r1, r2 = await asyncio.gather(send("first"), send("second"))
+    assert r1.status == 200
+    assert r2.status == 200
+    # "first" must fully complete before "second" starts
+    assert order.index("end:first") < order.index("start:second")
+
+
+# ---- /v1/models ----
+
+
+@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
+@pytest.mark.asyncio
+async def test_models_endpoint(aiohttp_client, app):
+    """GET /v1/models returns a list whose first entry is the configured model name."""
+    client = await aiohttp_client(app)
+    resp = await client.get("/v1/models")
+    assert resp.status == 200
+    body = await resp.json()
+    assert body["object"] == "list"
+    assert len(body["data"]) >= 1
+    assert body["data"][0]["id"] == "test-model"
+
+
+# ---- /health ----
+
+
+@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
+@pytest.mark.asyncio
+async def test_health_endpoint(aiohttp_client, app):
+    """GET /health returns 200 with ``status == "ok"``."""
+    client = await aiohttp_client(app)
+    resp = await client.get("/health")
+    assert resp.status == 200
+    body = await resp.json()
+    assert body["status"] == "ok"
+
+
+# ---- Multimodal content array ----
+
+
+@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
+@pytest.mark.asyncio
+async def test_multimodal_content_extracts_text(aiohttp_client, mock_agent):
+    """For a content array, only the text part is forwarded to the agent as ``content``."""
+    app = create_app(mock_agent, model_name="m")
+    client = await aiohttp_client(app)
+    resp = await client.post(
+        "/v1/chat/completions",
+        json={
+            "messages": [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "text", "text": "describe this"},
+                        {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
+                    ],
+                }
+            ]
+        },
+        headers={"x-session-key": "test"},
+    )
+    assert resp.status == 200
+    mock_agent.process_direct.assert_called_once()
+    call_kwargs = mock_agent.process_direct.call_args
+    assert call_kwargs.kwargs["content"] == "describe this"
+
+
+# ---------------------------------------------------------------------------
+# Memory isolation regression tests (root cause of cross-session leakage)
+# ---------------------------------------------------------------------------
+
+
+class TestMemoryIsolation:
+    """Verify that per-session-key memory prevents cross-session context leakage.
+
+    Root cause: ContextBuilder.build_system_prompt() reads a SHARED
+    workspace/memory/MEMORY.md into the system prompt of ALL users.
+    If user_1 writes "my name is Alice" and the agent persists it to
+    MEMORY.md, user_2/user_N will see it.
+
+    Fix: API mode passes a per-session MemoryStore so each session reads/
+    writes its own MEMORY.md.
+    """
+
+    def test_context_builder_uses_override_memory(self, tmp_path):
+        """build_system_prompt with memory_store= must use the override, not global."""
+        from nanobot.agent.context import ContextBuilder
+        from nanobot.agent.memory import MemoryStore
+
+        workspace = tmp_path / "workspace"
+        workspace.mkdir()
+        (workspace / "memory").mkdir()
+        (workspace / "memory" / "MEMORY.md").write_text("Global: I am shared context")
+
+        ctx = ContextBuilder(workspace)
+
+        # Without override → sees global memory
+        prompt_global = ctx.build_system_prompt()
+        assert "I am shared context" in prompt_global
+
+        # With override → sees only the override's memory
+        override_dir = tmp_path / "isolated" / "memory"
+        override_dir.mkdir(parents=True)
+        (override_dir / "MEMORY.md").write_text("User Alice's private note")
+
+        override_store = MemoryStore.__new__(MemoryStore)
+        override_store.memory_dir = override_dir
+        override_store.memory_file = override_dir / "MEMORY.md"
+        override_store.history_file = override_dir / "HISTORY.md"
+
+        prompt_isolated = ctx.build_system_prompt(memory_store=override_store)
+        assert "User Alice's private note" in prompt_isolated
+        assert "I am shared context" not in prompt_isolated
+
+    def test_different_session_keys_get_different_memory_dirs(self, tmp_path):
+        """_isolated_memory_store must return distinct paths for distinct keys."""
+        from unittest.mock import MagicMock
+        from nanobot.agent.loop import AgentLoop
+
+        agent = MagicMock(spec=AgentLoop)
+        agent.workspace = tmp_path
+        agent._isolated_memory_store = AgentLoop._isolated_memory_store.__get__(agent)
+
+        store_a = agent._isolated_memory_store("wx:dm:alice")
+        store_b = agent._isolated_memory_store("wx:dm:bob")
+
+        assert store_a.memory_file != store_b.memory_file
+        assert store_a.memory_dir != store_b.memory_dir
+        assert store_a.memory_file.parent.exists()
+        assert store_b.memory_file.parent.exists()
+
+    def test_isolated_memory_does_not_leak_across_sessions(self, tmp_path):
+        """End-to-end: writing to one session's memory must not appear in another's."""
+        from nanobot.agent.context import ContextBuilder
+        from nanobot.agent.memory import MemoryStore
+
+        workspace = tmp_path / "workspace"
+        workspace.mkdir()
+        (workspace / "memory").mkdir()
+        (workspace / "memory" / "MEMORY.md").write_text("")
+
+        ctx = ContextBuilder(workspace)
+
+        # Simulate two isolated memory stores (as the API server would create)
+        def make_store(name):
+            d = tmp_path / "sessions" / name / "memory"
+            d.mkdir(parents=True)
+            s = MemoryStore.__new__(MemoryStore)
+            s.memory_dir = d
+            s.memory_file = d / "MEMORY.md"
+            s.history_file = d / "HISTORY.md"
+            return s
+
+        store_alice = make_store("wx_dm_alice")
+        store_bob = make_store("wx_dm_bob")
+
+        # Use unique markers that won't appear in builtin skills/prompts
+        alice_marker = "XYZZY_ALICE_PRIVATE_MARKER_42"
+        store_alice.write_long_term(alice_marker)
+
+        # Alice's prompt sees it
+        prompt_alice = ctx.build_system_prompt(memory_store=store_alice)
+        assert alice_marker in prompt_alice
+
+        # Bob's prompt must NOT see it
+        prompt_bob = ctx.build_system_prompt(memory_store=store_bob)
+        assert alice_marker not in prompt_bob
+
+        # Global prompt must NOT see it either
+        prompt_global = ctx.build_system_prompt()
+        assert alice_marker not in prompt_global
+
+    def test_build_messages_passes_memory_store(self, tmp_path):
+        """build_messages must forward memory_store to build_system_prompt."""
+        from nanobot.agent.context import ContextBuilder
+        from nanobot.agent.memory import MemoryStore
+
+        workspace = tmp_path / "workspace"
+        workspace.mkdir()
+        (workspace / "memory").mkdir()
+        (workspace / "memory" / "MEMORY.md").write_text("GLOBAL_SECRET")
+
+        ctx = ContextBuilder(workspace)
+
+        override_dir = tmp_path / "per_session" / "memory"
+        override_dir.mkdir(parents=True)
+        (override_dir / "MEMORY.md").write_text("SESSION_PRIVATE")
+
+        override_store = MemoryStore.__new__(MemoryStore)
+        override_store.memory_dir = override_dir
+        override_store.memory_file = override_dir / "MEMORY.md"
+        override_store.history_file = override_dir / "HISTORY.md"
+
+        messages = ctx.build_messages(
+            history=[], current_message="hello",
+            memory_store=override_store,
+        )
+        system_content = messages[0]["content"]
+        assert "SESSION_PRIVATE" in system_content
+        assert "GLOBAL_SECRET" not in system_content
+
+    def test_api_handler_passes_isolate_memory_and_disabled_tools(self):
+        """The API handler must call process_direct with isolate_memory=True and disabled filesystem tools."""
+        import ast
+        from pathlib import Path
+
+        server_path = Path(__file__).parent.parent / "nanobot" / "api" / "server.py"
+        # Read as UTF-8 explicitly so the check does not depend on the platform's
+        # default encoding.
+        source = server_path.read_text(encoding="utf-8")
+        tree = ast.parse(source)
+
+        found_isolate = False
+        found_disabled = False
+        for node in ast.walk(tree):
+            if isinstance(node, ast.keyword):
+                if node.arg == "isolate_memory" and isinstance(node.value, ast.Constant) and node.value.value is True:
+                    found_isolate = True
+                if node.arg == "disabled_tools":
+                    found_disabled = True
+        assert found_isolate, "server.py must call process_direct with isolate_memory=True"
+        assert found_disabled, "server.py must call process_direct with disabled_tools"
+
+    def test_disabled_tools_constant_blocks_filesystem_and_exec(self):
+        """_API_DISABLED_TOOLS must include all filesystem tool names and exec."""
+        from nanobot.api.server import _API_DISABLED_TOOLS
+        for name in ("read_file", "write_file", "edit_file", "list_dir", "exec"):
+            assert name in _API_DISABLED_TOOLS, f"{name} missing from _API_DISABLED_TOOLS"
+
+    def test_system_prompt_uses_isolated_memory_path(self, tmp_path):
+        """When memory_store is provided, the system prompt must reference
+        the store's paths, NOT the global workspace/memory/MEMORY.md."""
+        from nanobot.agent.context import ContextBuilder
+        from nanobot.agent.memory import MemoryStore
+
+        workspace = tmp_path / "workspace"
+        workspace.mkdir()
+        (workspace / "memory").mkdir()
+
+        ctx = ContextBuilder(workspace)
+
+        # Default prompt references global path
+        default_prompt = ctx.build_system_prompt()
+        assert "memory/MEMORY.md" in default_prompt
+
+        # Isolated store
+        iso_dir = tmp_path / "sessions" / "wx_dm_alice" / "memory"
+        iso_dir.mkdir(parents=True)
+        store = MemoryStore.__new__(MemoryStore)
+        store.memory_dir = iso_dir
+        store.memory_file = iso_dir / "MEMORY.md"
+        store.history_file = iso_dir / "HISTORY.md"
+
+        iso_prompt = ctx.build_system_prompt(memory_store=store)
+        # Must reference the isolated path
+        assert str(iso_dir / "MEMORY.md") in iso_prompt
+        assert str(iso_dir / "HISTORY.md") in iso_prompt
+        # Must NOT reference the global workspace memory path
+        global_mem = str(workspace.resolve() / "memory" / "MEMORY.md")
+        assert global_mem not in iso_prompt
+
+    def test_run_agent_loop_filters_disabled_tools(self):
+        """Filtering registry definitions by a disabled-name set drops exactly those tools.
+
+        NOTE(review): this applies the filter expression inside the test itself; it
+        does not call AgentLoop._run_agent_loop, so it does not prove that the loop
+        hides or rejects disabled tools.
+        """
+        from nanobot.agent.tools.registry import ToolRegistry
+
+        registry = ToolRegistry()
+
+        # Create minimal fake tool definitions
+        class FakeTool:
+            def __init__(self, n):
+                self._name = n
+
+            @property
+            def name(self):
+                return self._name
+
+            def to_schema(self):
+                return {"type": "function", "function": {"name": self._name, "parameters": {}}}
+
+            def validate_params(self, params):
+                return []
+
+            async def execute(self, **kw):
+                return "ok"
+
+        for n in ("read_file", "write_file", "web_search", "exec"):
+            registry.register(FakeTool(n))
+
+        all_defs = registry.get_definitions()
+        assert len(all_defs) == 4
+
+        disabled = {"read_file", "write_file"}
+        filtered = [d for d in all_defs
+                    if d.get("function", {}).get("name") not in disabled]
+        assert len(filtered) == 2
+        names = {d["function"]["name"] for d in filtered}
+        assert names == {"web_search", "exec"}
+
+
+# ---------------------------------------------------------------------------
+# Consolidation isolation regression tests
+# ---------------------------------------------------------------------------
+
+
+class TestConsolidationIsolation:
+    """Verify that memory consolidation in API (isolate_memory) mode writes
+    to the per-session directory and never touches global workspace/memory."""
+
+    @pytest.mark.asyncio
+    async def test_consolidate_memory_uses_provided_store(self, tmp_path):
+        """_consolidate_memory(memory_store=X) must call X.consolidate,
+        not MemoryStore(self.workspace).consolidate."""
+        from unittest.mock import AsyncMock, MagicMock
+        from nanobot.agent.loop import AgentLoop
+        from nanobot.agent.memory import MemoryStore
+        from nanobot.session.manager import Session
+
+        agent = MagicMock(spec=AgentLoop)
+        agent.workspace = tmp_path / "workspace"
+        agent.workspace.mkdir()
+        agent.provider = MagicMock()
+        agent.model = "test"
+        agent.memory_window = 50
+
+        # Bind the real method
+        agent._consolidate_memory = AgentLoop._consolidate_memory.__get__(agent)
+
+        session = Session(key="test")
+        session.messages = [{"role": "user", "content": "hi", "timestamp": "2025-01-01T00:00"}] * 10
+
+        # Create an isolated store and mock its consolidate
+        iso_store = MagicMock(spec=MemoryStore)
+        iso_store.consolidate = AsyncMock(return_value=True)
+
+        result = await agent._consolidate_memory(session, memory_store=iso_store)
+
+        assert result is True
+        iso_store.consolidate.assert_called_once()
+        call_args = iso_store.consolidate.call_args
+        assert call_args[0][0] is session  # first positional arg is session
+
+    @pytest.mark.asyncio
+    async def test_consolidate_memory_defaults_to_global_when_no_store(self, tmp_path):
+        """Without memory_store, _consolidate_memory must use MemoryStore(workspace)."""
+        from unittest.mock import AsyncMock, MagicMock, patch
+        from nanobot.agent.loop import AgentLoop
+        from nanobot.session.manager import Session
+
+        agent = MagicMock(spec=AgentLoop)
+        agent.workspace = tmp_path / "workspace"
+        agent.workspace.mkdir()
+        (agent.workspace / "memory").mkdir()
+        agent.provider = MagicMock()
+        agent.model = "test"
+        agent.memory_window = 50
+        agent._consolidate_memory = AgentLoop._consolidate_memory.__get__(agent)
+
+        session = Session(key="test")
+
+        with patch("nanobot.agent.loop.MemoryStore") as MockStore:
+            mock_instance = MagicMock()
+            mock_instance.consolidate = AsyncMock(return_value=True)
+            MockStore.return_value = mock_instance
+
+            await agent._consolidate_memory(session)
+
+            MockStore.assert_called_once_with(agent.workspace)
+            mock_instance.consolidate.assert_called_once()
+
+    def test_consolidate_writes_to_isolated_dir_not_global(self, tmp_path):
+        """Writes made through an isolated store (write_long_term / append_history)
+        must land in the isolated dir and leave workspace/memory untouched."""
+        from nanobot.agent.memory import MemoryStore
+
+        # Set up global workspace memory
+        global_mem_dir = tmp_path / "workspace" / "memory"
+        global_mem_dir.mkdir(parents=True)
+        (global_mem_dir / "MEMORY.md").write_text("")
+        (global_mem_dir / "HISTORY.md").write_text("")
+
+        # Set up isolated per-session store
+        iso_dir = tmp_path / "sessions" / "wx_dm_alice" / "memory"
+        iso_dir.mkdir(parents=True)
+
+        iso_store = MemoryStore.__new__(MemoryStore)
+        iso_store.memory_dir = iso_dir
+        iso_store.memory_file = iso_dir / "MEMORY.md"
+        iso_store.history_file = iso_dir / "HISTORY.md"
+
+        # Write via the isolated store
+        iso_store.write_long_term("Alice's private data")
+        iso_store.append_history("[2025-01-01 00:00] Alice asked about X")
+
+        # Isolated store has the data
+        assert "Alice's private data" in iso_store.read_long_term()
+        assert "Alice asked about X" in iso_store.history_file.read_text()
+
+        # Global store must NOT have it
+        assert (global_mem_dir / "MEMORY.md").read_text() == ""
+        assert (global_mem_dir / "HISTORY.md").read_text() == ""
+
+    def test_process_message_passes_memory_store_to_consolidation_paths(self):
+        """Verify that _process_message passes memory_store to both
+        consolidation triggers (source code check)."""
+        import ast
+        from pathlib import Path
+
+        loop_path = Path(__file__).parent.parent / "nanobot" / "agent" / "loop.py"
+        source = loop_path.read_text(encoding="utf-8")
+        tree = ast.parse(source)
+
+        # ``async def`` parses to ast.AsyncFunctionDef, which is a separate node
+        # class from ast.FunctionDef. Accept both, otherwise an async
+        # _process_message is never matched and the checks below never run.
+        targets = [
+            node for node in ast.walk(tree)
+            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
+            and node.name == "_process_message"
+        ]
+        # Fail loudly instead of passing vacuously when the function is missing.
+        assert targets, "_process_message not found in nanobot/agent/loop.py"
+
+        # Find all calls to self._consolidate_memory inside _process_message
+        # and verify they all pass memory_store=
+        for node in targets:
+            consolidate_calls = [
+                {kw.arg for kw in child.keywords}
+                for child in ast.walk(node)
+                if isinstance(child, ast.Call)
+                and isinstance(child.func, ast.Attribute)
+                and child.func.attr == "_consolidate_memory"
+            ]
+
+            assert len(consolidate_calls) == 2, (
+                f"Expected 2 _consolidate_memory calls in _process_message, "
+                f"found {len(consolidate_calls)}"
+            )
+            for i, kw_names in enumerate(consolidate_calls, start=1):
+                assert "memory_store" in kw_names, (
+                    f"_consolidate_memory call #{i} in _process_message "
+                    f"missing memory_store= keyword argument"
+                )
+
+
+# ---------------------------------------------------------------------------
+# Empty response retry + fallback tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
+@pytest.mark.asyncio
+async def test_empty_response_retry_then_success(aiohttp_client):
+    """First call returns empty → retry once → second call returns real text."""
+    call_count = 0
+
+    async def sometimes_empty(content, session_key="", channel="", chat_id="",
+                              isolate_memory=False, disabled_tools=None):
+        nonlocal call_count
+        call_count += 1
+        if call_count == 1:
+            return ""
+        return "recovered response"
+
+    agent = MagicMock()
+    agent.process_direct = sometimes_empty
+    agent._connect_mcp = AsyncMock()
+    agent.close_mcp = AsyncMock()
+
+    app = create_app(agent, model_name="m")
+    client = await aiohttp_client(app)
+
+    resp = await client.post(
+        "/v1/chat/completions",
+        json={"messages": [{"role": "user", "content": "hello"}]},
+        headers={"x-session-key": "retry-test"},
+    )
+    assert resp.status == 200
+    body = await resp.json()
+    assert body["choices"][0]["message"]["content"] == "recovered response"
+    assert call_count == 2
+
+
+@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
+@pytest.mark.asyncio
+async def test_empty_response_both_empty_returns_fallback(aiohttp_client):
+    """Both calls return empty → must use the fallback text."""
+    call_count = 0
+
+    async def always_empty(content, session_key="", channel="", chat_id="",
+                           isolate_memory=False, disabled_tools=None):
+        nonlocal call_count
+        call_count += 1
+        return ""
+
+    agent = MagicMock()
+    agent.process_direct = always_empty
+    agent._connect_mcp = AsyncMock()
+    agent.close_mcp = AsyncMock()
+
+    app = create_app(agent, model_name="m")
+    client = await aiohttp_client(app)
+
+    resp = await client.post(
+        "/v1/chat/completions",
+        json={"messages": [{"role": "user", "content": "hello"}]},
+        headers={"x-session-key": "fallback-test"},
+    )
+    assert resp.status == 200
+    body = await resp.json()
+    assert body["choices"][0]["message"]["content"] == "I've completed processing but have no response to give."
+    assert call_count == 2
+
+
+@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
+@pytest.mark.asyncio
+async def test_whitespace_only_response_triggers_retry(aiohttp_client):
+    """Whitespace-only response should be treated as empty and trigger retry."""
+    call_count = 0
+
+    async def whitespace_then_ok(content, session_key="", channel="", chat_id="",
+                                 isolate_memory=False, disabled_tools=None):
+        nonlocal call_count
+        call_count += 1
+        if call_count == 1:
+            return " \n "
+        return "real answer"
+
+    agent = MagicMock()
+    agent.process_direct = whitespace_then_ok
+    agent._connect_mcp = AsyncMock()
+    agent.close_mcp = AsyncMock()
+
+    app = create_app(agent, model_name="m")
+    client = await aiohttp_client(app)
+
+    resp = await client.post(
+        "/v1/chat/completions",
+        json={"messages": [{"role": "user", "content": "hello"}]},
+        headers={"x-session-key": "ws-test"},
+    )
+    assert resp.status == 200
+    body = await resp.json()
+    assert body["choices"][0]["message"]["content"] == "real answer"
+    assert call_count == 2
+
+
+@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
+@pytest.mark.asyncio
+async def test_none_response_triggers_retry(aiohttp_client):
+    """None response should be treated as empty and trigger retry."""
+    call_count = 0
+
+    async def none_then_ok(content, session_key="", channel="", chat_id="",
+                           isolate_memory=False, disabled_tools=None):
+        nonlocal call_count
+        call_count += 1
+        if call_count == 1:
+            return None
+        return "got it"
+
+    agent = MagicMock()
+    agent.process_direct = none_then_ok
+    agent._connect_mcp = AsyncMock()
+    agent.close_mcp = AsyncMock()
+
+    app = create_app(agent, model_name="m")
+    client = await aiohttp_client(app)
+
+    resp = await client.post(
+        "/v1/chat/completions",
+        json={"messages": [{"role": "user", "content": "hello"}]},
+        headers={"x-session-key": "none-test"},
+    )
+    assert resp.status == 200
+    body = await resp.json()
+    assert body["choices"][0]["message"]["content"] == "got it"
+    assert call_count == 2
+
+
+@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
+@pytest.mark.asyncio
+async def test_nonempty_response_no_retry(aiohttp_client):
+    """A normal non-empty response must NOT trigger a retry."""
+    call_count = 0
+
+    async def normal_response(content, session_key="", channel="", chat_id="",
+                              isolate_memory=False, disabled_tools=None):
+        nonlocal call_count
+        call_count += 1
+        return "immediate answer"
+
+    agent = MagicMock()
+    agent.process_direct = normal_response
+    agent._connect_mcp = AsyncMock()
+    agent.close_mcp = AsyncMock()
+
+    app = create_app(agent, model_name="m")
+    client = await aiohttp_client(app)
+
+    resp = await client.post(
+        "/v1/chat/completions",
+        json={"messages": [{"role": "user", "content": "hello"}]},
+        headers={"x-session-key": "normal-test"},
+    )
+    assert resp.status == 200
+    body = await resp.json()
+    assert body["choices"][0]["message"]["content"] == "immediate answer"
+    assert call_count == 1