From 80219baf255d2f75b15edac616f24b7b0025ded1 Mon Sep 17 00:00:00 2001 From: Tink Date: Sun, 1 Mar 2026 10:53:45 +0800 Subject: [PATCH 01/22] feat(api): add OpenAI-compatible endpoint with x-session-key isolation --- examples/curl.txt | 96 ++++ nanobot/agent/context.py | 36 +- nanobot/agent/loop.py | 80 ++- nanobot/api/__init__.py | 1 + nanobot/api/server.py | 222 ++++++++ nanobot/cli/commands.py | 77 +++ pyproject.toml | 4 + tests/test_consolidate_offset.py | 14 +- tests/test_openai_api.py | 883 +++++++++++++++++++++++++++++++ 9 files changed, 1387 insertions(+), 26 deletions(-) create mode 100644 examples/curl.txt create mode 100644 nanobot/api/__init__.py create mode 100644 nanobot/api/server.py create mode 100644 tests/test_openai_api.py diff --git a/examples/curl.txt b/examples/curl.txt new file mode 100644 index 000000000..70dc4dfe7 --- /dev/null +++ b/examples/curl.txt @@ -0,0 +1,96 @@ +# ============================================================================= +# nanobot OpenAI-Compatible API — curl examples +# ============================================================================= +# +# Prerequisites: +# pip install nanobot-ai[api] # installs aiohttp +# nanobot serve --port 8900 # start the API server +# +# The x-session-key header is REQUIRED for every request. +# Convention: +# Private chat: wx:dm:{sender_id} +# Group @: wx:group:{group_id}:user:{sender_id} +# ============================================================================= + +# --- 1. Basic chat completion (private chat) --- + +curl -X POST http://localhost:8900/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "x-session-key: wx:dm:user_alice" \ + -d '{ + "model": "nanobot", + "messages": [ + {"role": "user", "content": "Hello, who are you?"} + ] + }' + +# --- 2. Follow-up in the same session (context is remembered) --- + +curl -X POST http://localhost:8900/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "x-session-key: wx:dm:user_alice" \ + -d '{ + "model": "nanobot", + "messages": [ + {"role": "user", "content": "What did I just ask you?"} + ] + }' + +# --- 3. Different user — isolated session --- + +curl -X POST http://localhost:8900/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "x-session-key: wx:dm:user_bob" \ + -d '{ + "model": "nanobot", + "messages": [ + {"role": "user", "content": "What did I just ask you?"} + ] + }' +# ↑ Bob gets a fresh context — he never asked anything before. + +# --- 4. Group chat — per-user session within a group --- + +curl -X POST http://localhost:8900/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "x-session-key: wx:group:group_abc:user:user_alice" \ + -d '{ + "model": "nanobot", + "messages": [ + {"role": "user", "content": "Summarize our discussion"} + ] + }' + +# --- 5. List available models --- + +curl http://localhost:8900/v1/models + +# --- 6. Health check --- + +curl http://localhost:8900/health + +# --- 7. Missing header — expect 400 --- + +curl -X POST http://localhost:8900/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "nanobot", + "messages": [ + {"role": "user", "content": "hello"} + ] + }' +# ↑ Returns: {"error": {"message": "Missing required header: x-session-key", ...}} + +# --- 8. Stream not yet supported — expect 400 --- + +curl -X POST http://localhost:8900/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "x-session-key: wx:dm:user_alice" \ + -d '{ + "model": "nanobot", + "messages": [ + {"role": "user", "content": "hello"} + ], + "stream": true + }' +# ↑ Returns: {"error": {"message": "stream=true is not supported yet...", ...}} diff --git a/nanobot/agent/context.py b/nanobot/agent/context.py index be0ec5996..3665d7f3a 100644 --- a/nanobot/agent/context.py +++ b/nanobot/agent/context.py @@ -23,15 +23,25 @@ class ContextBuilder: self.memory = MemoryStore(workspace) self.skills = SkillsLoader(workspace) - def build_system_prompt(self, skill_names: list[str] | None = None) -> str: - """Build the system prompt from identity, bootstrap files, memory, and skills.""" - parts = [self._get_identity()] + def build_system_prompt( + self, + skill_names: list[str] | None = None, + memory_store: "MemoryStore | None" = None, + ) -> str: + """Build the system prompt from identity, bootstrap files, memory, and skills. + + Args: + memory_store: If provided, use this MemoryStore instead of the default + workspace-level one. Used for per-session memory isolation. + """ + parts = [self._get_identity(memory_store=memory_store)] bootstrap = self._load_bootstrap_files() if bootstrap: parts.append(bootstrap) - memory = self.memory.get_memory_context() + store = memory_store or self.memory + memory = store.get_memory_context() if memory: parts.append(f"# Memory\n\n{memory}") @@ -52,12 +62,19 @@ Skills with available="false" need dependencies installed first - you can try in return "\n\n---\n\n".join(parts) - def _get_identity(self) -> str: + def _get_identity(self, memory_store: "MemoryStore | None" = None) -> str: """Get the core identity section.""" workspace_path = str(self.workspace.expanduser().resolve()) system = platform.system() runtime = f"{'macOS' if system == 'Darwin' else system} {platform.machine()}, Python {platform.python_version()}" - + + if memory_store is not None: + mem_path = str(memory_store.memory_file) + hist_path = str(memory_store.history_file) + else: + mem_path = f"{workspace_path}/memory/MEMORY.md" + hist_path = f"{workspace_path}/memory/HISTORY.md" + return f"""# nanobot 🐈 You are nanobot, a helpful AI assistant. @@ -67,8 +84,8 @@ You are nanobot, a helpful AI assistant. ## Workspace Your workspace is at: {workspace_path} -- Long-term memory: {workspace_path}/memory/MEMORY.md (write important facts here) -- History log: {workspace_path}/memory/HISTORY.md (grep-searchable). Each entry starts with [YYYY-MM-DD HH:MM]. +- Long-term memory: {mem_path} (write important facts here) +- History log: {hist_path} (grep-searchable). Each entry starts with [YYYY-MM-DD HH:MM]. - Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md ## nanobot Guidelines @@ -110,10 +127,11 @@ Reply directly with text for conversations. Only use the 'message' tool to send media: list[str] | None = None, channel: str | None = None, chat_id: str | None = None, + memory_store: "MemoryStore | None" = None, ) -> list[dict[str, Any]]: """Build the complete message list for an LLM call.""" return [ - {"role": "system", "content": self.build_system_prompt(skill_names)}, + {"role": "system", "content": self.build_system_prompt(skill_names, memory_store=memory_store)}, *history, {"role": "user", "content": self._build_runtime_context(channel, chat_id)}, {"role": "user", "content": self._build_user_content(current_message, media)}, diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index b605ae4a9..6a0d24f26 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -174,6 +174,7 @@ class AgentLoop: self, initial_messages: list[dict], on_progress: Callable[..., Awaitable[None]] | None = None, + disabled_tools: set[str] | None = None, ) -> tuple[str | None, list[str], list[dict]]: """Run the agent iteration loop. Returns (final_content, tools_used, messages).""" messages = initial_messages @@ -181,12 +182,19 @@ class AgentLoop: final_content = None tools_used: list[str] = [] + # Build tool definitions, filtering out disabled tools + if disabled_tools: + tool_defs = [d for d in self.tools.get_definitions() + if d.get("function", {}).get("name") not in disabled_tools] + else: + tool_defs = self.tools.get_definitions() + while iteration < self.max_iterations: iteration += 1 response = await self.provider.chat( messages=messages, - tools=self.tools.get_definitions(), + tools=tool_defs, model=self.model, temperature=self.temperature, max_tokens=self.max_tokens, @@ -219,7 +227,10 @@ class AgentLoop: tools_used.append(tool_call.name) args_str = json.dumps(tool_call.arguments, ensure_ascii=False) logger.info("Tool call: {}({})", tool_call.name, args_str[:200]) - result = await self.tools.execute(tool_call.name, tool_call.arguments) + if disabled_tools and tool_call.name in disabled_tools: + result = f"Error: Tool '{tool_call.name}' is not available in this mode." + else: + result = await self.tools.execute(tool_call.name, tool_call.arguments) messages = self.context.add_tool_result( messages, tool_call.id, tool_call.name, result ) @@ -322,6 +333,8 @@ class AgentLoop: msg: InboundMessage, session_key: str | None = None, on_progress: Callable[[str], Awaitable[None]] | None = None, + memory_store: MemoryStore | None = None, + disabled_tools: set[str] | None = None, ) -> OutboundMessage | None: """Process a single inbound message and return the response.""" # System messages: parse origin from chat_id ("channel:chat_id") @@ -336,8 +349,11 @@ class AgentLoop: messages = self.context.build_messages( history=history, current_message=msg.content, channel=channel, chat_id=chat_id, + memory_store=memory_store, + ) + final_content, _, all_msgs = await self._run_agent_loop( + messages, disabled_tools=disabled_tools, ) - final_content, _, all_msgs = await self._run_agent_loop(messages) self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) return OutboundMessage(channel=channel, chat_id=chat_id, @@ -360,7 +376,9 @@ class AgentLoop: if snapshot: temp = Session(key=session.key) temp.messages = list(snapshot) - if not await self._consolidate_memory(temp, archive_all=True): + if not await self._consolidate_memory( + temp, archive_all=True, memory_store=memory_store, + ): return OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, content="Memory archival failed, session not cleared. Please try again.", @@ -393,7 +411,9 @@ class AgentLoop: async def _consolidate_and_unlock(): try: async with lock: - await self._consolidate_memory(session) + await self._consolidate_memory( + session, memory_store=memory_store, + ) finally: self._consolidating.discard(session.key) if not lock.locked(): @@ -416,6 +436,7 @@ class AgentLoop: current_message=msg.content, media=msg.media if msg.media else None, channel=msg.channel, chat_id=msg.chat_id, + memory_store=memory_store, ) async def _bus_progress(content: str, *, tool_hint: bool = False) -> None: @@ -428,6 +449,7 @@ class AgentLoop: final_content, _, all_msgs = await self._run_agent_loop( initial_messages, on_progress=on_progress or _bus_progress, + disabled_tools=disabled_tools, ) if final_content is None: @@ -470,9 +492,30 @@ class AgentLoop: session.messages.append(entry) session.updated_at = datetime.now() - async def _consolidate_memory(self, session, archive_all: bool = False) -> bool: - """Delegate to MemoryStore.consolidate(). Returns True on success.""" - return await MemoryStore(self.workspace).consolidate( + def _isolated_memory_store(self, session_key: str) -> MemoryStore: + """Return a per-session-key MemoryStore for multi-tenant isolation.""" + from nanobot.utils.helpers import safe_filename + safe_key = safe_filename(session_key.replace(":", "_")) + memory_dir = self.workspace / "sessions" / safe_key / "memory" + memory_dir.mkdir(parents=True, exist_ok=True) + store = MemoryStore.__new__(MemoryStore) + store.memory_dir = memory_dir + store.memory_file = memory_dir / "MEMORY.md" + store.history_file = memory_dir / "HISTORY.md" + return store + + async def _consolidate_memory( + self, session, archive_all: bool = False, + memory_store: MemoryStore | None = None, + ) -> bool: + """Delegate to MemoryStore.consolidate(). Returns True on success. + + Args: + memory_store: If provided, consolidate into this store instead of + the default workspace-level one. + """ + store = memory_store or MemoryStore(self.workspace) + return await store.consolidate( session, self.provider, self.model, archive_all=archive_all, memory_window=self.memory_window, ) @@ -484,9 +527,26 @@ class AgentLoop: channel: str = "cli", chat_id: str = "direct", on_progress: Callable[[str], Awaitable[None]] | None = None, + isolate_memory: bool = False, + disabled_tools: set[str] | None = None, ) -> str: - """Process a message directly (for CLI or cron usage).""" + """Process a message directly (for CLI or cron usage). + + Args: + isolate_memory: When True, use a per-session-key memory directory + instead of the shared workspace memory. This prevents context + leakage between different session keys in multi-tenant (API) mode. + disabled_tools: Tool names to exclude from the LLM tool list and + reject at execution time. Use to block filesystem access in + multi-tenant API mode. + """ await self._connect_mcp() + memory_store: MemoryStore | None = None + if isolate_memory: + memory_store = self._isolated_memory_store(session_key) msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content) - response = await self._process_message(msg, session_key=session_key, on_progress=on_progress) + response = await self._process_message( + msg, session_key=session_key, on_progress=on_progress, + memory_store=memory_store, disabled_tools=disabled_tools, + ) return response.content if response else "" diff --git a/nanobot/api/__init__.py b/nanobot/api/__init__.py new file mode 100644 index 000000000..f0c504cc1 --- /dev/null +++ b/nanobot/api/__init__.py @@ -0,0 +1 @@ +"""OpenAI-compatible HTTP API for nanobot.""" diff --git a/nanobot/api/server.py b/nanobot/api/server.py new file mode 100644 index 000000000..a3077537f --- /dev/null +++ b/nanobot/api/server.py @@ -0,0 +1,222 @@ +"""OpenAI-compatible HTTP API server for nanobot. + +Provides /v1/chat/completions and /v1/models endpoints. +Session isolation is enforced via the x-session-key request header. +""" + +from __future__ import annotations + +import asyncio +import time +import uuid +from typing import Any + +from aiohttp import web +from loguru import logger + +# Tools that must NOT run in multi-tenant API mode. +# Filesystem tools allow the LLM to read/write the shared workspace (including +# global MEMORY.md), and exec allows shell commands that can bypass filesystem +# restrictions (e.g. `cat ~/.nanobot/workspace/memory/MEMORY.md`). +_API_DISABLED_TOOLS: set[str] = { + "read_file", "write_file", "edit_file", "list_dir", "exec", +} + + +# --------------------------------------------------------------------------- +# Per-session-key lock manager +# --------------------------------------------------------------------------- + +class _SessionLocks: + """Manages one asyncio.Lock per session key for serial execution.""" + + def __init__(self) -> None: + self._locks: dict[str, asyncio.Lock] = {} + self._ref: dict[str, int] = {} # reference count for cleanup + + def acquire(self, key: str) -> asyncio.Lock: + if key not in self._locks: + self._locks[key] = asyncio.Lock() + self._ref[key] = 0 + self._ref[key] += 1 + return self._locks[key] + + def release(self, key: str) -> None: + self._ref[key] -= 1 + if self._ref[key] <= 0: + self._locks.pop(key, None) + self._ref.pop(key, None) + + +# --------------------------------------------------------------------------- +# Response helpers +# --------------------------------------------------------------------------- + +def _error_json(status: int, message: str, err_type: str = "invalid_request_error") -> web.Response: + return web.json_response( + {"error": {"message": message, "type": err_type, "code": status}}, + status=status, + ) + + +def _chat_completion_response(content: str, model: str) -> dict[str, Any]: + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:12]}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": content}, + "finish_reason": "stop", + } + ], + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + } + + +# --------------------------------------------------------------------------- +# Route handlers +# --------------------------------------------------------------------------- + +async def handle_chat_completions(request: web.Request) -> web.Response: + """POST /v1/chat/completions""" + + # --- x-session-key validation --- + session_key = request.headers.get("x-session-key", "").strip() + if not session_key: + return _error_json(400, "Missing required header: x-session-key") + + # --- Parse body --- + try: + body = await request.json() + except Exception: + return _error_json(400, "Invalid JSON body") + + messages = body.get("messages") + if not messages or not isinstance(messages, list): + return _error_json(400, "messages field is required and must be a non-empty array") + + # Stream not yet supported + if body.get("stream", False): + return _error_json(400, "stream=true is not supported yet. Set stream=false or omit it.") + + # Extract last user message — nanobot manages its own multi-turn history + user_content = None + for msg in reversed(messages): + if msg.get("role") == "user": + user_content = msg.get("content", "") + break + if user_content is None: + return _error_json(400, "messages must contain at least one user message") + if isinstance(user_content, list): + # Multi-modal content array — extract text parts + user_content = " ".join( + part.get("text", "") for part in user_content if part.get("type") == "text" + ) + + agent_loop = request.app["agent_loop"] + timeout_s: float = request.app.get("request_timeout", 120.0) + model_name: str = body.get("model") or request.app.get("model_name", "nanobot") + locks: _SessionLocks = request.app["session_locks"] + + safe_key = session_key[:32] + ("…" if len(session_key) > 32 else "") + logger.info("API request session_key={} content={}", safe_key, user_content[:80]) + + _FALLBACK = "I've completed processing but have no response to give." + + lock = locks.acquire(session_key) + try: + async with lock: + try: + response_text = await asyncio.wait_for( + agent_loop.process_direct( + content=user_content, + session_key=session_key, + channel="api", + chat_id=session_key, + isolate_memory=True, + disabled_tools=_API_DISABLED_TOOLS, + ), + timeout=timeout_s, + ) + + if not response_text or not response_text.strip(): + logger.warning("Empty response for session {}, retrying", safe_key) + response_text = await asyncio.wait_for( + agent_loop.process_direct( + content=user_content, + session_key=session_key, + channel="api", + chat_id=session_key, + isolate_memory=True, + disabled_tools=_API_DISABLED_TOOLS, + ), + timeout=timeout_s, + ) + if not response_text or not response_text.strip(): + logger.warning("Empty response after retry for session {}, using fallback", safe_key) + response_text = _FALLBACK + + except asyncio.TimeoutError: + return _error_json(504, f"Request timed out after {timeout_s}s") + except Exception: + logger.exception("Error processing request for session {}", safe_key) + return _error_json(500, "Internal server error", err_type="server_error") + finally: + locks.release(session_key) + + return web.json_response(_chat_completion_response(response_text, model_name)) + + +async def handle_models(request: web.Request) -> web.Response: + """GET /v1/models""" + model_name = request.app.get("model_name", "nanobot") + return web.json_response({ + "object": "list", + "data": [ + { + "id": model_name, + "object": "model", + "created": 0, + "owned_by": "nanobot", + } + ], + }) + + +async def handle_health(request: web.Request) -> web.Response: + """GET /health""" + return web.json_response({"status": "ok"}) + + +# --------------------------------------------------------------------------- +# App factory +# --------------------------------------------------------------------------- + +def create_app(agent_loop, model_name: str = "nanobot", request_timeout: float = 120.0) -> web.Application: + """Create the aiohttp application. + + Args: + agent_loop: An initialized AgentLoop instance. + model_name: Model name reported in responses. + request_timeout: Per-request timeout in seconds. + """ + app = web.Application() + app["agent_loop"] = agent_loop + app["model_name"] = model_name + app["request_timeout"] = request_timeout + app["session_locks"] = _SessionLocks() + + app.router.add_post("/v1/chat/completions", handle_chat_completions) + app.router.add_get("/v1/models", handle_models) + app.router.add_get("/health", handle_health) + return app + + +def run_server(agent_loop, host: str = "0.0.0.0", port: int = 8900, + model_name: str = "nanobot", request_timeout: float = 120.0) -> None: + """Create and run the server (blocking).""" + app = create_app(agent_loop, model_name=model_name, request_timeout=request_timeout) + web.run_app(app, host=host, port=port, print=lambda msg: logger.info(msg)) diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index fc4c261ea..208b4e742 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -237,6 +237,83 @@ def _make_provider(config: Config): ) +# ============================================================================ +# OpenAI-Compatible API Server +# ============================================================================ + + +@app.command() +def serve( + port: int = typer.Option(8900, "--port", "-p", help="API server port"), + host: str = typer.Option("0.0.0.0", "--host", "-H", help="Bind address"), + timeout: float = typer.Option(120.0, "--timeout", "-t", help="Per-request timeout (seconds)"), + verbose: bool = typer.Option(False, "--verbose", "-v", help="Show nanobot runtime logs"), +): + """Start the OpenAI-compatible API server (/v1/chat/completions).""" + try: + from aiohttp import web # noqa: F401 + except ImportError: + console.print("[red]aiohttp is required. Install with: pip install aiohttp[/red]") + raise typer.Exit(1) + + from nanobot.config.loader import load_config + from nanobot.api.server import create_app + from loguru import logger + + if verbose: + logger.enable("nanobot") + else: + logger.disable("nanobot") + + config = load_config() + sync_workspace_templates(config.workspace_path) + provider = _make_provider(config) + + from nanobot.bus.queue import MessageBus + from nanobot.agent.loop import AgentLoop + from nanobot.session.manager import SessionManager + + bus = MessageBus() + session_manager = SessionManager(config.workspace_path) + agent_loop = AgentLoop( + bus=bus, + provider=provider, + workspace=config.workspace_path, + model=config.agents.defaults.model, + temperature=config.agents.defaults.temperature, + max_tokens=config.agents.defaults.max_tokens, + max_iterations=config.agents.defaults.max_tool_iterations, + memory_window=config.agents.defaults.memory_window, + brave_api_key=config.tools.web.search.api_key or None, + exec_config=config.tools.exec, + restrict_to_workspace=config.tools.restrict_to_workspace, + session_manager=session_manager, + mcp_servers=config.tools.mcp_servers, + channels_config=config.channels, + ) + + model_name = config.agents.defaults.model + console.print(f"{__logo__} Starting OpenAI-compatible API server") + console.print(f" [cyan]Endpoint[/cyan] : http://{host}:{port}/v1/chat/completions") + console.print(f" [cyan]Model[/cyan] : {model_name}") + console.print(f" [cyan]Timeout[/cyan] : {timeout}s") + console.print(f" [cyan]Header[/cyan] : x-session-key (required)") + console.print() + + api_app = create_app(agent_loop, model_name=model_name, request_timeout=timeout) + + async def on_startup(_app): + await agent_loop._connect_mcp() + + async def on_cleanup(_app): + await agent_loop.close_mcp() + + api_app.on_startup.append(on_startup) + api_app.on_cleanup.append(on_cleanup) + + web.run_app(api_app, host=host, port=port, print=lambda msg: logger.info(msg)) + + # ============================================================================ # Gateway / Server # ============================================================================ diff --git a/pyproject.toml b/pyproject.toml index 20dcb1e01..f71faa146 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,9 @@ dependencies = [ ] [project.optional-dependencies] +api = [ + "aiohttp>=3.9.0,<4.0.0", +] matrix = [ "matrix-nio[e2e]>=0.25.2", "mistune>=3.0.0,<4.0.0", @@ -53,6 +56,7 @@ matrix = [ dev = [ "pytest>=9.0.0,<10.0.0", "pytest-asyncio>=1.3.0,<2.0.0", + "aiohttp>=3.9.0,<4.0.0", "ruff>=0.1.0", ] diff --git a/tests/test_consolidate_offset.py b/tests/test_consolidate_offset.py index 675512406..fc72e0a63 100644 --- a/tests/test_consolidate_offset.py +++ b/tests/test_consolidate_offset.py @@ -509,7 +509,7 @@ class TestConsolidationDeduplicationGuard: consolidation_calls = 0 - async def _fake_consolidate(_session, archive_all: bool = False) -> None: + async def _fake_consolidate(_session, archive_all: bool = False, **kw) -> None: nonlocal consolidation_calls consolidation_calls += 1 await asyncio.sleep(0.05) @@ -555,7 +555,7 @@ class TestConsolidationDeduplicationGuard: active = 0 max_active = 0 - async def _fake_consolidate(_session, archive_all: bool = False) -> None: + async def _fake_consolidate(_session, archive_all: bool = False, **kw) -> None: nonlocal consolidation_calls, active, max_active consolidation_calls += 1 active += 1 @@ -605,7 +605,7 @@ class TestConsolidationDeduplicationGuard: started = asyncio.Event() - async def _slow_consolidate(_session, archive_all: bool = False) -> None: + async def _slow_consolidate(_session, archive_all: bool = False, **kw) -> None: started.set() await asyncio.sleep(0.1) @@ -652,7 +652,7 @@ class TestConsolidationDeduplicationGuard: release = asyncio.Event() archived_count = 0 - async def _fake_consolidate(sess, archive_all: bool = False) -> bool: + async def _fake_consolidate(sess, archive_all: bool = False, **kw) -> bool: nonlocal archived_count if archive_all: archived_count = len(sess.messages) @@ -707,7 +707,7 @@ class TestConsolidationDeduplicationGuard: loop.sessions.save(session) before_count = len(session.messages) - async def _failing_consolidate(sess, archive_all: bool = False) -> bool: + async def _failing_consolidate(sess, archive_all: bool = False, **kw) -> bool: if archive_all: return False return True @@ -754,7 +754,7 @@ class TestConsolidationDeduplicationGuard: release = asyncio.Event() archived_count = -1 - async def _fake_consolidate(sess, archive_all: bool = False) -> bool: + async def _fake_consolidate(sess, archive_all: bool = False, **kw) -> bool: nonlocal archived_count if archive_all: archived_count = len(sess.messages) @@ -815,7 +815,7 @@ class TestConsolidationDeduplicationGuard: loop._consolidation_locks.setdefault(session.key, asyncio.Lock()) assert session.key in loop._consolidation_locks - async def _ok_consolidate(sess, archive_all: bool = False) -> bool: + async def _ok_consolidate(sess, archive_all: bool = False, **kw) -> bool: return True loop._consolidate_memory = _ok_consolidate # type: ignore[method-assign] diff --git a/tests/test_openai_api.py b/tests/test_openai_api.py new file mode 100644 index 000000000..b4d831579 --- /dev/null +++ b/tests/test_openai_api.py @@ -0,0 +1,883 @@ +"""Tests for the OpenAI-compatible API server.""" + +from __future__ import annotations + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from nanobot.api.server import _SessionLocks, _chat_completion_response, _error_json, create_app + +# --------------------------------------------------------------------------- +# aiohttp test client helper +# --------------------------------------------------------------------------- + +try: + from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop + from aiohttp import web + + HAS_AIOHTTP = True +except ImportError: + HAS_AIOHTTP = False + +pytest_plugins = ("pytest_asyncio",) + +# --------------------------------------------------------------------------- +# Unit tests — no aiohttp required +# --------------------------------------------------------------------------- + + +class TestSessionLocks: + def test_acquire_creates_lock(self): + sl = _SessionLocks() + lock = sl.acquire("k1") + assert isinstance(lock, asyncio.Lock) + + def test_same_key_returns_same_lock(self): + sl = _SessionLocks() + l1 = sl.acquire("k1") + l2 = sl.acquire("k1") + assert l1 is l2 + + def test_different_keys_different_locks(self): + sl = _SessionLocks() + l1 = sl.acquire("k1") + l2 = sl.acquire("k2") + assert l1 is not l2 + + def test_release_cleans_up(self): + sl = _SessionLocks() + sl.acquire("k1") + sl.release("k1") + assert "k1" not in sl._locks + + def test_release_keeps_lock_if_still_referenced(self): + sl = _SessionLocks() + sl.acquire("k1") + sl.acquire("k1") + sl.release("k1") + assert "k1" in sl._locks + sl.release("k1") + assert "k1" not in sl._locks + + +class TestResponseHelpers: + def test_error_json(self): + resp = _error_json(400, "bad request") + assert resp.status == 400 + body = json.loads(resp.body) + assert body["error"]["message"] == "bad request" + assert body["error"]["code"] == 400 + + def test_chat_completion_response(self): + result = _chat_completion_response("hello world", "test-model") + assert result["object"] == "chat.completion" + assert result["model"] == "test-model" + assert result["choices"][0]["message"]["content"] == "hello world" + assert result["choices"][0]["finish_reason"] == "stop" + assert result["id"].startswith("chatcmpl-") + + +# --------------------------------------------------------------------------- +# Integration tests — require aiohttp +# --------------------------------------------------------------------------- + + +def _make_mock_agent(response_text: str = "mock response") -> MagicMock: + agent = MagicMock() + agent.process_direct = AsyncMock(return_value=response_text) + agent._connect_mcp = AsyncMock() + agent.close_mcp = AsyncMock() + return agent + + +@pytest.fixture +def mock_agent(): + return _make_mock_agent() + + +@pytest.fixture +def app(mock_agent): + return create_app(mock_agent, model_name="test-model", request_timeout=10.0) + + +@pytest.fixture +def cli(event_loop, aiohttp_client, app): + return event_loop.run_until_complete(aiohttp_client(app)) + + +# ---- Missing header tests ---- + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_missing_session_key_returns_400(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hello"}]}, + ) + assert resp.status == 400 + body = await resp.json() + assert "x-session-key" in body["error"]["message"] + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_empty_session_key_returns_400(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hello"}]}, + headers={"x-session-key": " "}, + ) + assert resp.status == 400 + + +# ---- Missing messages tests ---- + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_missing_messages_returns_400(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.post( + "/v1/chat/completions", + json={"model": "test"}, + headers={"x-session-key": "test-key"}, + ) + assert resp.status == 400 + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_no_user_message_returns_400(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "system", "content": "you are a bot"}]}, + headers={"x-session-key": "test-key"}, + ) + assert resp.status == 400 + + +# ---- Stream not supported ---- + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_stream_true_returns_400(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.post( + "/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "hello"}], + "stream": True, + }, + headers={"x-session-key": "test-key"}, + ) + assert resp.status == 400 + body = await resp.json() + assert "stream" in body["error"]["message"].lower() + + +# ---- Successful request ---- + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_successful_request(aiohttp_client, mock_agent): + app = create_app(mock_agent, model_name="test-model") + client = await aiohttp_client(app) + resp = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hello"}]}, + headers={"x-session-key": "wx:dm:user1"}, + ) + assert resp.status == 200 + body = await resp.json() + assert body["choices"][0]["message"]["content"] == "mock response" + assert body["model"] == "test-model" + mock_agent.process_direct.assert_called_once_with( + content="hello", + session_key="wx:dm:user1", + channel="api", + chat_id="wx:dm:user1", + isolate_memory=True, + disabled_tools={"read_file", "write_file", "edit_file", "list_dir", "exec"}, + ) + + +# ---- Session isolation ---- + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_session_isolation_different_keys(aiohttp_client): + """Two different session keys must route to separate session_key arguments.""" + call_log: list[str] = [] + + async def fake_process(content, session_key="", channel="", chat_id="", + isolate_memory=False, disabled_tools=None): + call_log.append(session_key) + return f"reply to {session_key}" + + agent = MagicMock() + agent.process_direct = fake_process + agent._connect_mcp = AsyncMock() + agent.close_mcp = AsyncMock() + + app = create_app(agent, model_name="m") + client = await aiohttp_client(app) + + r1 = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "msg1"}]}, + headers={"x-session-key": "wx:dm:alice"}, + ) + r2 = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "msg2"}]}, + headers={"x-session-key": "wx:group:g1:user:bob"}, + ) + + assert r1.status == 200 + assert r2.status == 200 + + b1 = await r1.json() + b2 = await r2.json() + assert b1["choices"][0]["message"]["content"] == "reply to wx:dm:alice" + assert b2["choices"][0]["message"]["content"] == "reply to wx:group:g1:user:bob" + assert call_log == ["wx:dm:alice", "wx:group:g1:user:bob"] + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_same_session_key_serialized(aiohttp_client): + """Concurrent requests with the same session key must run serially.""" + order: list[str] = [] + barrier = asyncio.Event() + + async def slow_process(content, session_key="", channel="", chat_id="", + isolate_memory=False, disabled_tools=None): + order.append(f"start:{content}") + if content == "first": + barrier.set() + await asyncio.sleep(0.1) # hold lock + else: + await barrier.wait() # ensure "second" starts after "first" begins + order.append(f"end:{content}") + return content + + agent = MagicMock() + agent.process_direct = slow_process + agent._connect_mcp = AsyncMock() + agent.close_mcp = AsyncMock() + + app = create_app(agent, model_name="m") + client = await aiohttp_client(app) + + async def send(msg): + return await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": msg}]}, + headers={"x-session-key": "same-key"}, + ) + + r1, r2 = await asyncio.gather(send("first"), send("second")) + assert r1.status == 200 + assert r2.status == 200 + # "first" must fully complete before "second" starts + assert order.index("end:first") < order.index("start:second") + + +# ---- /v1/models ---- + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_models_endpoint(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.get("/v1/models") + assert resp.status == 200 + body = await resp.json() + assert body["object"] == "list" + assert len(body["data"]) >= 1 + assert body["data"][0]["id"] == "test-model" + + +# ---- /health ---- + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_health_endpoint(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.get("/health") + assert resp.status == 200 + body = await resp.json() + assert body["status"] == "ok" + + +# ---- Multimodal content array ---- + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_multimodal_content_extracts_text(aiohttp_client, mock_agent): + app = create_app(mock_agent, model_name="m") + client = await aiohttp_client(app) + resp = await client.post( + "/v1/chat/completions", + json={ + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "describe this"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, + ], + } + ] + }, + headers={"x-session-key": "test"}, + ) + assert resp.status == 200 + mock_agent.process_direct.assert_called_once() + call_kwargs = mock_agent.process_direct.call_args + assert call_kwargs.kwargs["content"] == "describe this" + + +# --------------------------------------------------------------------------- +# Memory isolation regression tests (root cause of cross-session leakage) +# --------------------------------------------------------------------------- + + +class TestMemoryIsolation: + """Verify that per-session-key memory prevents cross-session context leakage. + + Root cause: ContextBuilder.build_system_prompt() reads a SHARED + workspace/memory/MEMORY.md into the system prompt of ALL users. + If user_1 writes "my name is Alice" and the agent persists it to + MEMORY.md, user_2/user_N will see it. + + Fix: API mode passes a per-session MemoryStore so each session reads/ + writes its own MEMORY.md. + """ + + def test_context_builder_uses_override_memory(self, tmp_path): + """build_system_prompt with memory_store= must use the override, not global.""" + from nanobot.agent.context import ContextBuilder + from nanobot.agent.memory import MemoryStore + + workspace = tmp_path / "workspace" + workspace.mkdir() + (workspace / "memory").mkdir() + (workspace / "memory" / "MEMORY.md").write_text("Global: I am shared context") + + ctx = ContextBuilder(workspace) + + # Without override → sees global memory + prompt_global = ctx.build_system_prompt() + assert "I am shared context" in prompt_global + + # With override → sees only the override's memory + override_dir = tmp_path / "isolated" / "memory" + override_dir.mkdir(parents=True) + (override_dir / "MEMORY.md").write_text("User Alice's private note") + + override_store = MemoryStore.__new__(MemoryStore) + override_store.memory_dir = override_dir + override_store.memory_file = override_dir / "MEMORY.md" + override_store.history_file = override_dir / "HISTORY.md" + + prompt_isolated = ctx.build_system_prompt(memory_store=override_store) + assert "User Alice's private note" in prompt_isolated + assert "I am shared context" not in prompt_isolated + + def test_different_session_keys_get_different_memory_dirs(self, tmp_path): + """_isolated_memory_store must return distinct paths for distinct keys.""" + from unittest.mock import MagicMock + from nanobot.agent.loop import AgentLoop + + agent = MagicMock(spec=AgentLoop) + agent.workspace = tmp_path + agent._isolated_memory_store = AgentLoop._isolated_memory_store.__get__(agent) + + store_a = agent._isolated_memory_store("wx:dm:alice") + store_b = agent._isolated_memory_store("wx:dm:bob") + + assert store_a.memory_file != store_b.memory_file + assert store_a.memory_dir != store_b.memory_dir + assert store_a.memory_file.parent.exists() + assert store_b.memory_file.parent.exists() + + def test_isolated_memory_does_not_leak_across_sessions(self, tmp_path): + """End-to-end: writing to one session's memory must not appear in another's.""" + from nanobot.agent.context import ContextBuilder + from nanobot.agent.memory import MemoryStore + + workspace = tmp_path / "workspace" + workspace.mkdir() + (workspace / "memory").mkdir() + (workspace / "memory" / "MEMORY.md").write_text("") + + ctx = ContextBuilder(workspace) + + # Simulate two isolated memory stores (as the API server would create) + def make_store(name): + d = tmp_path / "sessions" / name / "memory" + d.mkdir(parents=True) + s = MemoryStore.__new__(MemoryStore) + s.memory_dir = d + s.memory_file = d / "MEMORY.md" + s.history_file = d / "HISTORY.md" + return s + + store_alice = make_store("wx_dm_alice") + store_bob = make_store("wx_dm_bob") + + # Use unique markers that won't appear in builtin skills/prompts + alice_marker = "XYZZY_ALICE_PRIVATE_MARKER_42" + store_alice.write_long_term(alice_marker) + + # Alice's prompt sees it + prompt_alice = ctx.build_system_prompt(memory_store=store_alice) + assert alice_marker in prompt_alice + + # Bob's prompt must NOT see it + prompt_bob = ctx.build_system_prompt(memory_store=store_bob) + assert alice_marker not in prompt_bob + + # Global prompt must NOT see it either + prompt_global = ctx.build_system_prompt() + assert alice_marker not in prompt_global + + def test_build_messages_passes_memory_store(self, tmp_path): + """build_messages must forward memory_store to build_system_prompt.""" + from nanobot.agent.context import ContextBuilder + from nanobot.agent.memory import MemoryStore + + workspace = tmp_path / "workspace" + workspace.mkdir() + (workspace / "memory").mkdir() + (workspace / "memory" / "MEMORY.md").write_text("GLOBAL_SECRET") + + ctx = ContextBuilder(workspace) + + override_dir = tmp_path / "per_session" / "memory" + override_dir.mkdir(parents=True) + (override_dir / "MEMORY.md").write_text("SESSION_PRIVATE") + + override_store = MemoryStore.__new__(MemoryStore) + override_store.memory_dir = override_dir + override_store.memory_file = override_dir / "MEMORY.md" + override_store.history_file = override_dir / "HISTORY.md" + + messages = ctx.build_messages( + history=[], current_message="hello", + memory_store=override_store, + ) + system_content = messages[0]["content"] + assert "SESSION_PRIVATE" in system_content + assert "GLOBAL_SECRET" not in system_content + + def test_api_handler_passes_isolate_memory_and_disabled_tools(self): + """The API handler must call process_direct with isolate_memory=True and disabled filesystem tools.""" + import ast + from pathlib import Path + + server_path = Path(__file__).parent.parent / "nanobot" / "api" / "server.py" + source = server_path.read_text() + tree = ast.parse(source) + + found_isolate = False + found_disabled = False + for node in ast.walk(tree): + if isinstance(node, ast.keyword): + if node.arg == "isolate_memory" and isinstance(node.value, ast.Constant) and node.value.value is True: + found_isolate = True + if node.arg == "disabled_tools": + found_disabled = True + assert found_isolate, "server.py must call process_direct with isolate_memory=True" + assert found_disabled, "server.py must call process_direct with disabled_tools" + + def test_disabled_tools_constant_blocks_filesystem_and_exec(self): + """_API_DISABLED_TOOLS must include all filesystem tool names and exec.""" + from nanobot.api.server import _API_DISABLED_TOOLS + for name in ("read_file", "write_file", "edit_file", "list_dir", "exec"): + assert name in _API_DISABLED_TOOLS, f"{name} missing from _API_DISABLED_TOOLS" + + def test_system_prompt_uses_isolated_memory_path(self, tmp_path): + """When memory_store is provided, the system prompt must reference + the store's paths, NOT the global workspace/memory/MEMORY.md.""" + from nanobot.agent.context import ContextBuilder + from nanobot.agent.memory import MemoryStore + + workspace = tmp_path / "workspace" + workspace.mkdir() + (workspace / "memory").mkdir() + + ctx = ContextBuilder(workspace) + + # Default prompt references global path + default_prompt = ctx.build_system_prompt() + assert "memory/MEMORY.md" in default_prompt + + # Isolated store + iso_dir = tmp_path / "sessions" / "wx_dm_alice" / "memory" + iso_dir.mkdir(parents=True) + store = MemoryStore.__new__(MemoryStore) + store.memory_dir = iso_dir + store.memory_file = iso_dir / "MEMORY.md" + store.history_file = iso_dir / "HISTORY.md" + + iso_prompt = ctx.build_system_prompt(memory_store=store) + # Must reference the isolated path + assert str(iso_dir / "MEMORY.md") in iso_prompt + assert str(iso_dir / "HISTORY.md") in iso_prompt + # Must NOT reference the global workspace memory path + global_mem = str(workspace.resolve() / "memory" / "MEMORY.md") + assert global_mem not in iso_prompt + + def test_run_agent_loop_filters_disabled_tools(self): + """_run_agent_loop must exclude disabled tools from definitions + and reject execution of disabled tools.""" + from nanobot.agent.tools.registry import ToolRegistry + + registry = ToolRegistry() + + # Create minimal fake tool definitions + class FakeTool: + def __init__(self, n): + self._name = n + + @property + def name(self): + return self._name + + def to_schema(self): + return {"type": "function", "function": {"name": self._name, "parameters": {}}} + + def validate_params(self, params): + return [] + + async def execute(self, **kw): + return "ok" + + for n in ("read_file", "write_file", "web_search", "exec"): + registry.register(FakeTool(n)) + + all_defs = registry.get_definitions() + assert len(all_defs) == 4 + + disabled = {"read_file", "write_file"} + filtered = [d for d in all_defs + if d.get("function", {}).get("name") not in disabled] + assert len(filtered) == 2 + names = {d["function"]["name"] for d in filtered} + assert names == {"web_search", "exec"} + + +# --------------------------------------------------------------------------- +# Consolidation isolation regression tests +# --------------------------------------------------------------------------- + + +class TestConsolidationIsolation: + """Verify that memory consolidation in API (isolate_memory) mode writes + to the per-session directory and never touches global workspace/memory.""" + + @pytest.mark.asyncio + async def test_consolidate_memory_uses_provided_store(self, tmp_path): + """_consolidate_memory(memory_store=X) must call X.consolidate, + not MemoryStore(self.workspace).consolidate.""" + from unittest.mock import AsyncMock, MagicMock, patch + from nanobot.agent.loop import AgentLoop + from nanobot.agent.memory import MemoryStore + from nanobot.session.manager import Session + + agent = MagicMock(spec=AgentLoop) + agent.workspace = tmp_path / "workspace" + agent.workspace.mkdir() + agent.provider = MagicMock() + agent.model = "test" + agent.memory_window = 50 + + # Bind the real method + agent._consolidate_memory = AgentLoop._consolidate_memory.__get__(agent) + + session = Session(key="test") + session.messages = [{"role": "user", "content": "hi", "timestamp": "2025-01-01T00:00"}] * 10 + + # Create an isolated store and mock its consolidate + iso_store = MagicMock(spec=MemoryStore) + iso_store.consolidate = AsyncMock(return_value=True) + + result = await agent._consolidate_memory(session, memory_store=iso_store) + + assert result is True + iso_store.consolidate.assert_called_once() + call_args = iso_store.consolidate.call_args + assert call_args[0][0] is session # first positional arg is session + + @pytest.mark.asyncio + async def test_consolidate_memory_defaults_to_global_when_no_store(self, tmp_path): + """Without memory_store, _consolidate_memory must use MemoryStore(workspace).""" + from unittest.mock import AsyncMock, MagicMock, patch + from nanobot.agent.loop import AgentLoop + from nanobot.session.manager import Session + + agent = MagicMock(spec=AgentLoop) + agent.workspace = tmp_path / "workspace" + agent.workspace.mkdir() + (agent.workspace / "memory").mkdir() + agent.provider = MagicMock() + agent.model = "test" + agent.memory_window = 50 + agent._consolidate_memory = AgentLoop._consolidate_memory.__get__(agent) + + session = Session(key="test") + + with patch("nanobot.agent.loop.MemoryStore") as MockStore: + mock_instance = MagicMock() + mock_instance.consolidate = AsyncMock(return_value=True) + MockStore.return_value = mock_instance + + await agent._consolidate_memory(session) + + MockStore.assert_called_once_with(agent.workspace) + mock_instance.consolidate.assert_called_once() + + def test_consolidate_writes_to_isolated_dir_not_global(self, tmp_path): + """End-to-end: MemoryStore.consolidate with an isolated store must + write HISTORY.md in the isolated dir, not in workspace/memory.""" + from nanobot.agent.memory import MemoryStore + + # Set up global workspace memory + global_mem_dir = tmp_path / "workspace" / "memory" + global_mem_dir.mkdir(parents=True) + (global_mem_dir / "MEMORY.md").write_text("") + (global_mem_dir / "HISTORY.md").write_text("") + + # Set up isolated per-session store + iso_dir = tmp_path / "sessions" / "wx_dm_alice" / "memory" + iso_dir.mkdir(parents=True) + + iso_store = MemoryStore.__new__(MemoryStore) + iso_store.memory_dir = iso_dir + iso_store.memory_file = iso_dir / "MEMORY.md" + iso_store.history_file = iso_dir / "HISTORY.md" + + # Write via the isolated store + iso_store.write_long_term("Alice's private data") + iso_store.append_history("[2025-01-01 00:00] Alice asked about X") + + # Isolated store has the data + assert "Alice's private data" in iso_store.read_long_term() + assert "Alice asked about X" in iso_store.history_file.read_text() + + # Global store must NOT have it + assert (global_mem_dir / "MEMORY.md").read_text() == "" + assert (global_mem_dir / "HISTORY.md").read_text() == "" + + def test_process_message_passes_memory_store_to_consolidation_paths(self): + """Verify that _process_message passes memory_store to both + consolidation triggers (source code check).""" + import ast + from pathlib import Path + + loop_path = Path(__file__).parent.parent / "nanobot" / "agent" / "loop.py" + source = loop_path.read_text() + tree = ast.parse(source) + + # Find all calls to self._consolidate_memory inside _process_message + # and verify they all pass memory_store= + for node in ast.walk(tree): + if not isinstance(node, ast.FunctionDef) or node.name != "_process_message": + continue + consolidate_calls = [] + for child in ast.walk(node): + if (isinstance(child, ast.Call) + and isinstance(child.func, ast.Attribute) + and child.func.attr == "_consolidate_memory"): + kw_names = {kw.arg for kw in child.keywords} + consolidate_calls.append(kw_names) + + assert len(consolidate_calls) == 2, ( + f"Expected 2 _consolidate_memory calls in _process_message, " + f"found {len(consolidate_calls)}" + ) + for i, kw_names in enumerate(consolidate_calls): + assert "memory_store" in kw_names, ( + f"_consolidate_memory call #{i+1} in _process_message " + f"missing memory_store= keyword argument" + ) + + +# --------------------------------------------------------------------------- +# Empty response retry + fallback tests +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_empty_response_retry_then_success(aiohttp_client): + """First call returns empty → retry once → second call returns real text.""" + call_count = 0 + + async def sometimes_empty(content, session_key="", channel="", chat_id="", + isolate_memory=False, disabled_tools=None): + nonlocal call_count + call_count += 1 + if call_count == 1: + return "" + return "recovered response" + + agent = MagicMock() + agent.process_direct = sometimes_empty + agent._connect_mcp = AsyncMock() + agent.close_mcp = AsyncMock() + + app = create_app(agent, model_name="m") + client = await aiohttp_client(app) + + resp = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hello"}]}, + headers={"x-session-key": "retry-test"}, + ) + assert resp.status == 200 + body = await resp.json() + assert body["choices"][0]["message"]["content"] == "recovered response" + assert call_count == 2 + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_empty_response_both_empty_returns_fallback(aiohttp_client): + """Both calls return empty → must use the fallback text.""" + call_count = 0 + + async def always_empty(content, session_key="", channel="", chat_id="", + isolate_memory=False, disabled_tools=None): + nonlocal call_count + call_count += 1 + return "" + + agent = MagicMock() + agent.process_direct = always_empty + agent._connect_mcp = AsyncMock() + agent.close_mcp = AsyncMock() + + app = create_app(agent, model_name="m") + client = await aiohttp_client(app) + + resp = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hello"}]}, + headers={"x-session-key": "fallback-test"}, + ) + assert resp.status == 200 + body = await resp.json() + assert body["choices"][0]["message"]["content"] == "I've completed processing but have no response to give." + assert call_count == 2 + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_whitespace_only_response_triggers_retry(aiohttp_client): + """Whitespace-only response should be treated as empty and trigger retry.""" + call_count = 0 + + async def whitespace_then_ok(content, session_key="", channel="", chat_id="", + isolate_memory=False, disabled_tools=None): + nonlocal call_count + call_count += 1 + if call_count == 1: + return " \n " + return "real answer" + + agent = MagicMock() + agent.process_direct = whitespace_then_ok + agent._connect_mcp = AsyncMock() + agent.close_mcp = AsyncMock() + + app = create_app(agent, model_name="m") + client = await aiohttp_client(app) + + resp = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hello"}]}, + headers={"x-session-key": "ws-test"}, + ) + assert resp.status == 200 + body = await resp.json() + assert body["choices"][0]["message"]["content"] == "real answer" + assert call_count == 2 + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_none_response_triggers_retry(aiohttp_client): + """None response should be treated as empty and trigger retry.""" + call_count = 0 + + async def none_then_ok(content, session_key="", channel="", chat_id="", + isolate_memory=False, disabled_tools=None): + nonlocal call_count + call_count += 1 + if call_count == 1: + return None + return "got it" + + agent = MagicMock() + agent.process_direct = none_then_ok + agent._connect_mcp = AsyncMock() + agent.close_mcp = AsyncMock() + + app = create_app(agent, model_name="m") + client = await aiohttp_client(app) + + resp = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hello"}]}, + headers={"x-session-key": "none-test"}, + ) + assert resp.status == 200 + body = await resp.json() + assert body["choices"][0]["message"]["content"] == "got it" + assert call_count == 2 + + +@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") +@pytest.mark.asyncio +async def test_nonempty_response_no_retry(aiohttp_client): + """A normal non-empty response must NOT trigger a retry.""" + call_count = 0 + + async def normal_response(content, session_key="", channel="", chat_id="", + isolate_memory=False, disabled_tools=None): + nonlocal call_count + call_count += 1 + return "immediate answer" + + agent = MagicMock() + agent.process_direct = normal_response + agent._connect_mcp = AsyncMock() + agent.close_mcp = AsyncMock() + + app = create_app(agent, model_name="m") + client = await aiohttp_client(app) + + resp = await client.post( + "/v1/chat/completions", + json={"messages": [{"role": "user", "content": "hello"}]}, + headers={"x-session-key": "normal-test"}, + ) + assert resp.status == 200 + body = await resp.json() + assert body["choices"][0]["message"]["content"] == "immediate answer" + assert call_count == 1 From e868fb32d2cf83d17eadfa885b616a576567fd98 Mon Sep 17 00:00:00 2001 From: Tink Date: Fri, 6 Mar 2026 19:09:38 +0800 Subject: [PATCH 02/22] fix: add from __future__ import annotations to fix Python <3.11 compat These two files from upstream use PEP 604 union syntax (str | None) without the future annotations import. While the project requires Python >=3.11, this makes local testing possible on 3.9/3.10. Co-Authored-By: Claude Opus 4.6 --- nanobot/agent/skills.py | 2 ++ nanobot/utils/helpers.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/nanobot/agent/skills.py b/nanobot/agent/skills.py index 9afee82f0..0e1388255 100644 --- a/nanobot/agent/skills.py +++ b/nanobot/agent/skills.py @@ -1,5 +1,7 @@ """Skills loader for agent capabilities.""" +from __future__ import annotations + import json import os import re diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index c57c3654e..7e6531a86 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ -1,5 +1,7 @@ """Utility functions for nanobot.""" +from __future__ import annotations + import re from datetime import datetime from pathlib import Path From 6b3997c463df94242121c556bd539da676433dad Mon Sep 17 00:00:00 2001 From: Tink Date: Fri, 6 Mar 2026 19:13:56 +0800 Subject: [PATCH 03/22] fix: add from __future__ import annotations across codebase Ensure all modules using PEP 604 union syntax (X | Y) include the future annotations import for Python <3.10 compatibility. While the project requires >=3.11, this avoids import-time TypeErrors when running tests on older interpreters. Co-Authored-By: Claude Opus 4.6 --- nanobot/agent/context.py | 2 ++ nanobot/agent/subagent.py | 2 ++ nanobot/agent/tools/base.py | 2 ++ nanobot/agent/tools/cron.py | 2 ++ nanobot/agent/tools/filesystem.py | 2 ++ nanobot/agent/tools/mcp.py | 2 ++ nanobot/agent/tools/message.py | 2 ++ nanobot/agent/tools/registry.py | 2 ++ nanobot/agent/tools/shell.py | 2 ++ nanobot/agent/tools/spawn.py | 2 ++ nanobot/agent/tools/web.py | 2 ++ nanobot/bus/events.py | 2 ++ nanobot/channels/base.py | 2 ++ nanobot/channels/dingtalk.py | 2 ++ nanobot/channels/discord.py | 2 ++ nanobot/channels/email.py | 2 ++ nanobot/channels/feishu.py | 2 ++ nanobot/channels/matrix.py | 2 ++ nanobot/channels/qq.py | 2 ++ nanobot/channels/slack.py | 2 ++ nanobot/cli/commands.py | 2 ++ nanobot/config/loader.py | 2 ++ nanobot/config/schema.py | 2 ++ nanobot/cron/service.py | 2 ++ nanobot/cron/types.py | 2 ++ nanobot/providers/base.py | 2 ++ nanobot/providers/litellm_provider.py | 2 ++ nanobot/providers/transcription.py | 2 ++ nanobot/session/manager.py | 2 ++ 29 files changed, 58 insertions(+) diff --git a/nanobot/agent/context.py b/nanobot/agent/context.py index 6a43d3e91..905562a98 100644 --- a/nanobot/agent/context.py +++ b/nanobot/agent/context.py @@ -1,5 +1,7 @@ """Context builder for assembling agent prompts.""" +from __future__ import annotations + import base64 import mimetypes import platform diff --git a/nanobot/agent/subagent.py b/nanobot/agent/subagent.py index f2d6ee5f2..20dbaede0 100644 --- a/nanobot/agent/subagent.py +++ b/nanobot/agent/subagent.py @@ -1,5 +1,7 @@ """Subagent manager for background task execution.""" +from __future__ import annotations + import asyncio import json import uuid diff --git a/nanobot/agent/tools/base.py b/nanobot/agent/tools/base.py index 051fc9acf..ea5b66318 100644 --- a/nanobot/agent/tools/base.py +++ b/nanobot/agent/tools/base.py @@ -1,5 +1,7 @@ """Base class for agent tools.""" +from __future__ import annotations + from abc import ABC, abstractmethod from typing import Any diff --git a/nanobot/agent/tools/cron.py b/nanobot/agent/tools/cron.py index f8e737b39..350e261f8 100644 --- a/nanobot/agent/tools/cron.py +++ b/nanobot/agent/tools/cron.py @@ -1,5 +1,7 @@ """Cron tool for scheduling reminders and tasks.""" +from __future__ import annotations + from contextvars import ContextVar from typing import Any diff --git a/nanobot/agent/tools/filesystem.py b/nanobot/agent/tools/filesystem.py index 7b0b86725..c13464e69 100644 --- a/nanobot/agent/tools/filesystem.py +++ b/nanobot/agent/tools/filesystem.py @@ -1,5 +1,7 @@ """File system tools: read, write, edit.""" +from __future__ import annotations + import difflib from pathlib import Path from typing import Any diff --git a/nanobot/agent/tools/mcp.py b/nanobot/agent/tools/mcp.py index 2cbffd09d..dd6ce8c52 100644 --- a/nanobot/agent/tools/mcp.py +++ b/nanobot/agent/tools/mcp.py @@ -1,5 +1,7 @@ """MCP client: connects to MCP servers and wraps their tools as native nanobot tools.""" +from __future__ import annotations + import asyncio from contextlib import AsyncExitStack from typing import Any diff --git a/nanobot/agent/tools/message.py b/nanobot/agent/tools/message.py index 35e519a00..9d7cfbdca 100644 --- a/nanobot/agent/tools/message.py +++ b/nanobot/agent/tools/message.py @@ -1,5 +1,7 @@ """Message tool for sending messages to users.""" +from __future__ import annotations + from typing import Any, Awaitable, Callable from nanobot.agent.tools.base import Tool diff --git a/nanobot/agent/tools/registry.py b/nanobot/agent/tools/registry.py index 5d36e52cd..6edb88e16 100644 --- a/nanobot/agent/tools/registry.py +++ b/nanobot/agent/tools/registry.py @@ -1,5 +1,7 @@ """Tool registry for dynamic tool management.""" +from __future__ import annotations + from typing import Any from nanobot.agent.tools.base import Tool diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index ce1992092..74d1923f5 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -1,5 +1,7 @@ """Shell execution tool.""" +from __future__ import annotations + import asyncio import os import re diff --git a/nanobot/agent/tools/spawn.py b/nanobot/agent/tools/spawn.py index fc62bf8df..935dd319f 100644 --- a/nanobot/agent/tools/spawn.py +++ b/nanobot/agent/tools/spawn.py @@ -1,5 +1,7 @@ """Spawn tool for creating background subagents.""" +from __future__ import annotations + from typing import TYPE_CHECKING, Any from nanobot.agent.tools.base import Tool diff --git a/nanobot/agent/tools/web.py b/nanobot/agent/tools/web.py index 0d8f4d167..61920d981 100644 --- a/nanobot/agent/tools/web.py +++ b/nanobot/agent/tools/web.py @@ -1,5 +1,7 @@ """Web tools: web_search and web_fetch.""" +from __future__ import annotations + import html import json import os diff --git a/nanobot/bus/events.py b/nanobot/bus/events.py index 018c25b3d..0bc8f3971 100644 --- a/nanobot/bus/events.py +++ b/nanobot/bus/events.py @@ -1,5 +1,7 @@ """Event types for the message bus.""" +from __future__ import annotations + from dataclasses import dataclass, field from datetime import datetime from typing import Any diff --git a/nanobot/channels/base.py b/nanobot/channels/base.py index b38fcaf28..296426c68 100644 --- a/nanobot/channels/base.py +++ b/nanobot/channels/base.py @@ -1,5 +1,7 @@ """Base channel interface for chat platforms.""" +from __future__ import annotations + from abc import ABC, abstractmethod from typing import Any diff --git a/nanobot/channels/dingtalk.py b/nanobot/channels/dingtalk.py index 8d02fa6cd..76f25d11a 100644 --- a/nanobot/channels/dingtalk.py +++ b/nanobot/channels/dingtalk.py @@ -1,5 +1,7 @@ """DingTalk/DingDing channel implementation using Stream Mode.""" +from __future__ import annotations + import asyncio import json import mimetypes diff --git a/nanobot/channels/discord.py b/nanobot/channels/discord.py index c868bbf3a..fd4926742 100644 --- a/nanobot/channels/discord.py +++ b/nanobot/channels/discord.py @@ -1,5 +1,7 @@ """Discord channel implementation using Discord Gateway websocket.""" +from __future__ import annotations + import asyncio import json from pathlib import Path diff --git a/nanobot/channels/email.py b/nanobot/channels/email.py index 16771fb64..d0e1b61d1 100644 --- a/nanobot/channels/email.py +++ b/nanobot/channels/email.py @@ -1,5 +1,7 @@ """Email channel implementation using IMAP polling + SMTP replies.""" +from __future__ import annotations + import asyncio import html import imaplib diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py index 8f69c0952..e56b7da23 100644 --- a/nanobot/channels/feishu.py +++ b/nanobot/channels/feishu.py @@ -1,5 +1,7 @@ """Feishu/Lark channel implementation using lark-oapi SDK with WebSocket long connection.""" +from __future__ import annotations + import asyncio import json import os diff --git a/nanobot/channels/matrix.py b/nanobot/channels/matrix.py index 4967ac13c..488b607ec 100644 --- a/nanobot/channels/matrix.py +++ b/nanobot/channels/matrix.py @@ -1,5 +1,7 @@ """Matrix (Element) channel — inbound sync + outbound message/media delivery.""" +from __future__ import annotations + import asyncio import logging import mimetypes diff --git a/nanobot/channels/qq.py b/nanobot/channels/qq.py index 6c5804900..1a4c8af03 100644 --- a/nanobot/channels/qq.py +++ b/nanobot/channels/qq.py @@ -1,5 +1,7 @@ """QQ channel implementation using botpy SDK.""" +from __future__ import annotations + import asyncio from collections import deque from typing import TYPE_CHECKING diff --git a/nanobot/channels/slack.py b/nanobot/channels/slack.py index afd1d2dcd..7301ced67 100644 --- a/nanobot/channels/slack.py +++ b/nanobot/channels/slack.py @@ -1,5 +1,7 @@ """Slack channel implementation using Socket Mode.""" +from __future__ import annotations + import asyncio import re from typing import Any diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index b28dcedc9..8035b2639 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -1,5 +1,7 @@ """CLI commands for nanobot.""" +from __future__ import annotations + import asyncio import os import select diff --git a/nanobot/config/loader.py b/nanobot/config/loader.py index c789efdaf..d16c0d468 100644 --- a/nanobot/config/loader.py +++ b/nanobot/config/loader.py @@ -1,5 +1,7 @@ """Configuration loading utilities.""" +from __future__ import annotations + import json from pathlib import Path diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 2073eeb07..5eefa831a 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -1,5 +1,7 @@ """Configuration schema using Pydantic.""" +from __future__ import annotations + from pathlib import Path from typing import Literal diff --git a/nanobot/cron/service.py b/nanobot/cron/service.py index 1ed71f0f4..c9cd86811 100644 --- a/nanobot/cron/service.py +++ b/nanobot/cron/service.py @@ -1,5 +1,7 @@ """Cron service for scheduling agent tasks.""" +from __future__ import annotations + import asyncio import json import time diff --git a/nanobot/cron/types.py b/nanobot/cron/types.py index 2b4206057..209fddf57 100644 --- a/nanobot/cron/types.py +++ b/nanobot/cron/types.py @@ -1,5 +1,7 @@ """Cron types.""" +from __future__ import annotations + from dataclasses import dataclass, field from typing import Literal diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index 55bd80571..7a90db4d1 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -1,5 +1,7 @@ """Base LLM provider interface.""" +from __future__ import annotations + from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Any diff --git a/nanobot/providers/litellm_provider.py b/nanobot/providers/litellm_provider.py index 620424e61..5a76cb0ea 100644 --- a/nanobot/providers/litellm_provider.py +++ b/nanobot/providers/litellm_provider.py @@ -1,5 +1,7 @@ """LiteLLM provider implementation for multi-provider support.""" +from __future__ import annotations + import os import secrets import string diff --git a/nanobot/providers/transcription.py b/nanobot/providers/transcription.py index 1c8cb6a3f..d7fa9b3d0 100644 --- a/nanobot/providers/transcription.py +++ b/nanobot/providers/transcription.py @@ -1,5 +1,7 @@ """Voice transcription provider using Groq.""" +from __future__ import annotations + import os from pathlib import Path diff --git a/nanobot/session/manager.py b/nanobot/session/manager.py index dce4b2ec4..2cde436ed 100644 --- a/nanobot/session/manager.py +++ b/nanobot/session/manager.py @@ -1,5 +1,7 @@ """Session management for conversation history.""" +from __future__ import annotations + import json import shutil from dataclasses import dataclass, field From 9d69ba9f56a7e99e64f689ce2aaa37a82d17ffdb Mon Sep 17 00:00:00 2001 From: Tink Date: Fri, 13 Mar 2026 19:26:50 +0800 Subject: [PATCH 04/22] fix: isolate /new consolidation in API mode --- nanobot/agent/loop.py | 14 ++++---- nanobot/agent/memory.py | 25 +++++++++---- tests/test_consolidate_offset.py | 36 +++++++++++++++++-- tests/test_loop_consolidation_tokens.py | 2 +- tests/test_openai_api.py | 47 +++++++++++++++++++++++++ 5 files changed, 108 insertions(+), 16 deletions(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index ea14bc013..474068904 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable from loguru import logger from nanobot.agent.context import ContextBuilder -from nanobot.agent.memory import MemoryConsolidator +from nanobot.agent.memory import MemoryConsolidator, MemoryStore from nanobot.agent.subagent import SubagentManager from nanobot.agent.tools.cron import CronTool from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool @@ -362,7 +362,7 @@ class AgentLoop: logger.info("Processing system message from {}", msg.sender_id) key = f"{channel}:{chat_id}" session = self.sessions.get_or_create(key) - await self.memory_consolidator.maybe_consolidate_by_tokens(session) + await self.memory_consolidator.maybe_consolidate_by_tokens(session, store=memory_store) self._set_tool_context(channel, chat_id, msg.metadata.get("message_id")) history = session.get_history(max_messages=0) messages = self.context.build_messages( @@ -375,7 +375,7 @@ class AgentLoop: ) self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) - await self.memory_consolidator.maybe_consolidate_by_tokens(session) + await self.memory_consolidator.maybe_consolidate_by_tokens(session, store=memory_store) return OutboundMessage(channel=channel, chat_id=chat_id, content=final_content or "Background task completed.") @@ -389,7 +389,9 @@ class AgentLoop: cmd = msg.content.strip().lower() if cmd == "/new": try: - if not await self.memory_consolidator.archive_unconsolidated(session): + if not await self.memory_consolidator.archive_unconsolidated( + session, store=memory_store, + ): return OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, @@ -419,7 +421,7 @@ class AgentLoop: return OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, content="\n".join(lines), ) - await self.memory_consolidator.maybe_consolidate_by_tokens(session) + await self.memory_consolidator.maybe_consolidate_by_tokens(session, store=memory_store) self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id")) if message_tool := self.tools.get("message"): @@ -453,7 +455,7 @@ class AgentLoop: self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) - await self.memory_consolidator.maybe_consolidate_by_tokens(session) + await self.memory_consolidator.maybe_consolidate_by_tokens(session, store=memory_store) if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: return None diff --git a/nanobot/agent/memory.py b/nanobot/agent/memory.py index f220f2346..407cc20fe 100644 --- a/nanobot/agent/memory.py +++ b/nanobot/agent/memory.py @@ -247,9 +247,14 @@ class MemoryConsolidator: """Return the shared consolidation lock for one session.""" return self._locks.setdefault(session_key, asyncio.Lock()) - async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool: + async def consolidate_messages( + self, + messages: list[dict[str, object]], + store: MemoryStore | None = None, + ) -> bool: """Archive a selected message chunk into persistent memory.""" - return await self.store.consolidate(messages, self.provider, self.model) + target = store or self.store + return await target.consolidate(messages, self.provider, self.model) def pick_consolidation_boundary( self, @@ -290,16 +295,24 @@ class MemoryConsolidator: self._get_tool_definitions(), ) - async def archive_unconsolidated(self, session: Session) -> bool: + async def archive_unconsolidated( + self, + session: Session, + store: MemoryStore | None = None, + ) -> bool: """Archive the full unconsolidated tail for /new-style session rollover.""" lock = self.get_lock(session.key) async with lock: snapshot = session.messages[session.last_consolidated:] if not snapshot: return True - return await self.consolidate_messages(snapshot) + return await self.consolidate_messages(snapshot, store=store) - async def maybe_consolidate_by_tokens(self, session: Session) -> None: + async def maybe_consolidate_by_tokens( + self, + session: Session, + store: MemoryStore | None = None, + ) -> None: """Loop: archive old messages until prompt fits within half the context window.""" if not session.messages or self.context_window_tokens <= 0: return @@ -347,7 +360,7 @@ class MemoryConsolidator: source, len(chunk), ) - if not await self.consolidate_messages(chunk): + if not await self.consolidate_messages(chunk, store=store): return session.last_consolidated = end_idx self.sessions.save(session) diff --git a/tests/test_consolidate_offset.py b/tests/test_consolidate_offset.py index 7d12338aa..bea193fcb 100644 --- a/tests/test_consolidate_offset.py +++ b/tests/test_consolidate_offset.py @@ -516,7 +516,7 @@ class TestNewCommandArchival: loop.sessions.save(session) before_count = len(session.messages) - async def _failing_consolidate(_messages) -> bool: + async def _failing_consolidate(_messages, store=None) -> bool: return False loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign] @@ -542,7 +542,7 @@ class TestNewCommandArchival: archived_count = -1 - async def _fake_consolidate(messages) -> bool: + async def _fake_consolidate(messages, store=None) -> bool: nonlocal archived_count archived_count = len(messages) return True @@ -567,7 +567,7 @@ class TestNewCommandArchival: session.add_message("assistant", f"resp{i}") loop.sessions.save(session) - async def _ok_consolidate(_messages) -> bool: + async def _ok_consolidate(_messages, store=None) -> bool: return True loop.memory_consolidator.consolidate_messages = _ok_consolidate # type: ignore[method-assign] @@ -578,3 +578,33 @@ class TestNewCommandArchival: assert response is not None assert "new session started" in response.content.lower() assert loop.sessions.get_or_create("cli:test").messages == [] + + @pytest.mark.asyncio + async def test_new_archives_to_custom_store_when_provided(self, tmp_path: Path) -> None: + """When memory_store is passed, /new must archive through that store.""" + from nanobot.bus.events import InboundMessage + from nanobot.agent.memory import MemoryStore + + loop = self._make_loop(tmp_path) + session = loop.sessions.get_or_create("cli:test") + for i in range(5): + session.add_message("user", f"msg{i}") + session.add_message("assistant", f"resp{i}") + loop.sessions.save(session) + + used_store = None + + async def _tracking_consolidate(messages, store=None) -> bool: + nonlocal used_store + used_store = store + return True + + loop.memory_consolidator.consolidate_messages = _tracking_consolidate # type: ignore[method-assign] + + iso_store = MagicMock(spec=MemoryStore) + new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new") + response = await loop._process_message(new_msg, memory_store=iso_store) + + assert response is not None + assert "new session started" in response.content.lower() + assert used_store is iso_store, "archive_unconsolidated must use the provided store" diff --git a/tests/test_loop_consolidation_tokens.py b/tests/test_loop_consolidation_tokens.py index b0f3dda53..7daa38809 100644 --- a/tests/test_loop_consolidation_tokens.py +++ b/tests/test_loop_consolidation_tokens.py @@ -158,7 +158,7 @@ async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200) - async def track_consolidate(messages): + async def track_consolidate(messages, store=None): order.append("consolidate") return True loop.memory_consolidator.consolidate_messages = track_consolidate # type: ignore[method-assign] diff --git a/tests/test_openai_api.py b/tests/test_openai_api.py index 216596de0..d2d30b8b8 100644 --- a/tests/test_openai_api.py +++ b/tests/test_openai_api.py @@ -622,6 +622,53 @@ class TestConsolidationIsolation: assert (global_mem_dir / "MEMORY.md").read_text() == "" assert (global_mem_dir / "HISTORY.md").read_text() == "" + @pytest.mark.asyncio + async def test_new_command_uses_isolated_store(self, tmp_path): + """process_direct(isolate_memory=True) + /new must archive to the isolated store.""" + from unittest.mock import AsyncMock, MagicMock + from nanobot.agent.loop import AgentLoop + from nanobot.agent.memory import MemoryStore + from nanobot.bus.queue import MessageBus + from nanobot.providers.base import LLMResponse + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + provider.estimate_prompt_tokens.return_value = (10_000, "test") + agent = AgentLoop( + bus=bus, provider=provider, workspace=tmp_path, + model="test-model", context_window_tokens=1, + ) + agent._mcp_connected = True # skip MCP connect + agent.tools.get_definitions = MagicMock(return_value=[]) + + # Pre-populate session so /new has something to archive + session = agent.sessions.get_or_create("api:alice") + for i in range(3): + session.add_message("user", f"msg{i}") + session.add_message("assistant", f"resp{i}") + agent.sessions.save(session) + + used_store = None + + async def _tracking_consolidate(messages, store=None) -> bool: + nonlocal used_store + used_store = store + return True + + agent.memory_consolidator.consolidate_messages = _tracking_consolidate # type: ignore[method-assign] + + result = await agent.process_direct( + "/new", session_key="api:alice", isolate_memory=True, + ) + + assert "new session started" in result.lower() + assert used_store is not None, "consolidation must receive a store" + assert isinstance(used_store, MemoryStore) + assert "sessions" in str(used_store.memory_dir), ( + "store must point to per-session dir, not global workspace" + ) + # --------------------------------------------------------------------------- From c8c520cc9a4dbe619eb3f21200dc40971a36b665 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Sat, 28 Mar 2026 13:28:56 +0000 Subject: [PATCH 05/22] docs: update providers information --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 60f131244..828b56477 100644 --- a/README.md +++ b/README.md @@ -854,7 +854,6 @@ Config file: `~/.nanobot/config.json` > - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config. > - **Alibaba Cloud BaiLian**: If you're using Alibaba Cloud BaiLian's OpenAI-compatible endpoint, set `"apiBase": "https://dashscope.aliyuncs.com/compatible-mode/v1"` in your dashscope provider config. > - **Step Fun (Mainland China)**: If your API key is from Step Fun's mainland China platform (stepfun.com), set `"apiBase": "https://api.stepfun.com/v1"` in your stepfun provider config. -> - **Step Fun Step Plan**: Exclusive discount links for the nanobot community: [Overseas](https://platform.stepfun.ai/step-plan) · [Mainland China](https://platform.stepfun.com/step-plan) | Provider | Purpose | Get API Key | |----------|---------|-------------| From 5635907e3318f16979c2833bb1fc2b2a0c9b6aab Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Sun, 29 Mar 2026 15:32:33 +0000 Subject: [PATCH 06/22] feat(api): load serve settings from config Read serve host, port, and timeout from config by default, keep CLI flags higher priority, and bind the API to localhost by default for safer local usage. --- nanobot/api/server.py | 2 +- nanobot/cli/commands.py | 15 ++- nanobot/config/schema.py | 9 ++ tests/cli/test_commands.py | 262 ++++++++++++++++++++++++++----------- 4 files changed, 206 insertions(+), 82 deletions(-) diff --git a/nanobot/api/server.py b/nanobot/api/server.py index 1dd58d512..2a818667a 100644 --- a/nanobot/api/server.py +++ b/nanobot/api/server.py @@ -192,7 +192,7 @@ def create_app(agent_loop, model_name: str = "nanobot", request_timeout: float = return app -def run_server(agent_loop, host: str = "0.0.0.0", port: int = 8900, +def run_server(agent_loop, host: str = "127.0.0.1", port: int = 8900, model_name: str = "nanobot", request_timeout: float = 120.0) -> None: """Create and run the server (blocking).""" app = create_app(agent_loop, model_name=model_name, request_timeout=request_timeout) diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index d3fc68e8f..7f7d24f39 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -498,9 +498,9 @@ def _migrate_cron_store(config: "Config") -> None: @app.command() def serve( - port: int = typer.Option(8900, "--port", "-p", help="API server port"), - host: str = typer.Option("0.0.0.0", "--host", "-H", help="Bind address"), - timeout: float = typer.Option(120.0, "--timeout", "-t", help="Per-request timeout (seconds)"), + port: int | None = typer.Option(None, "--port", "-p", help="API server port"), + host: str | None = typer.Option(None, "--host", "-H", help="Bind address"), + timeout: float | None = typer.Option(None, "--timeout", "-t", help="Per-request timeout (seconds)"), verbose: bool = typer.Option(False, "--verbose", "-v", help="Show nanobot runtime logs"), workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"), config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"), @@ -524,6 +524,10 @@ def serve( logger.disable("nanobot") runtime_config = _load_runtime_config(config, workspace) + api_cfg = runtime_config.api + host = host if host is not None else api_cfg.host + port = port if port is not None else api_cfg.port + timeout = timeout if timeout is not None else api_cfg.timeout sync_workspace_templates(runtime_config.workspace_path) bus = MessageBus() provider = _make_provider(runtime_config) @@ -551,6 +555,11 @@ def serve( console.print(f" [cyan]Model[/cyan] : {model_name}") console.print(" [cyan]Session[/cyan] : api:default") console.print(f" [cyan]Timeout[/cyan] : {timeout}s") + if host in {"0.0.0.0", "::"}: + console.print( + "[yellow]Warning:[/yellow] API is bound to all interfaces. " + "Only do this behind a trusted network boundary, firewall, or reverse proxy." + ) console.print() api_app = create_app(agent_loop, model_name=model_name, request_timeout=timeout) diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index c8b69b42e..c4c927afd 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -96,6 +96,14 @@ class HeartbeatConfig(Base): keep_recent_messages: int = 8 +class ApiConfig(Base): + """OpenAI-compatible API server configuration.""" + + host: str = "127.0.0.1" # Safer default: local-only bind. + port: int = 8900 + timeout: float = 120.0 # Per-request timeout in seconds. + + class GatewayConfig(Base): """Gateway/server configuration.""" @@ -156,6 +164,7 @@ class Config(BaseSettings): agents: AgentsConfig = Field(default_factory=AgentsConfig) channels: ChannelsConfig = Field(default_factory=ChannelsConfig) providers: ProvidersConfig = Field(default_factory=ProvidersConfig) + api: ApiConfig = Field(default_factory=ApiConfig) gateway: GatewayConfig = Field(default_factory=GatewayConfig) tools: ToolsConfig = Field(default_factory=ToolsConfig) diff --git a/tests/cli/test_commands.py b/tests/cli/test_commands.py index a8fcc4aa0..735c02a5a 100644 --- a/tests/cli/test_commands.py +++ b/tests/cli/test_commands.py @@ -642,27 +642,105 @@ def test_heartbeat_retains_recent_messages_by_default(): assert config.gateway.heartbeat.keep_recent_messages == 8 -def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None: +def _write_instance_config(tmp_path: Path) -> Path: config_file = tmp_path / "instance" / "config.json" config_file.parent.mkdir(parents=True) config_file.write_text("{}") + return config_file - config = Config() - config.agents.defaults.workspace = str(tmp_path / "config-workspace") - seen: dict[str, Path] = {} +def _stop_gateway_provider(_config) -> object: + raise _StopGatewayError("stop") + + +def _patch_cli_command_runtime( + monkeypatch, + config: Config, + *, + set_config_path=None, + sync_templates=None, + make_provider=None, + message_bus=None, + session_manager=None, + cron_service=None, + get_cron_dir=None, +) -> None: monkeypatch.setattr( "nanobot.config.loader.set_config_path", - lambda path: seen.__setitem__("config_path", path), + set_config_path or (lambda _path: None), ) monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) monkeypatch.setattr( "nanobot.cli.commands.sync_workspace_templates", - lambda path: seen.__setitem__("workspace", path), + sync_templates or (lambda _path: None), ) monkeypatch.setattr( "nanobot.cli.commands._make_provider", - lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")), + make_provider or (lambda _config: object()), + ) + + if message_bus is not None: + monkeypatch.setattr("nanobot.bus.queue.MessageBus", message_bus) + if session_manager is not None: + monkeypatch.setattr("nanobot.session.manager.SessionManager", session_manager) + if cron_service is not None: + monkeypatch.setattr("nanobot.cron.service.CronService", cron_service) + if get_cron_dir is not None: + monkeypatch.setattr("nanobot.config.paths.get_cron_dir", get_cron_dir) + + +def _patch_serve_runtime(monkeypatch, config: Config, seen: dict[str, object]) -> None: + pytest.importorskip("aiohttp") + + class _FakeApiApp: + def __init__(self) -> None: + self.on_startup: list[object] = [] + self.on_cleanup: list[object] = [] + + class _FakeAgentLoop: + def __init__(self, **kwargs) -> None: + seen["workspace"] = kwargs["workspace"] + + async def _connect_mcp(self) -> None: + return None + + async def close_mcp(self) -> None: + return None + + def _fake_create_app(agent_loop, model_name: str, request_timeout: float): + seen["agent_loop"] = agent_loop + seen["model_name"] = model_name + seen["request_timeout"] = request_timeout + return _FakeApiApp() + + def _fake_run_app(api_app, host: str, port: int, print): + seen["api_app"] = api_app + seen["host"] = host + seen["port"] = port + + _patch_cli_command_runtime( + monkeypatch, + config, + message_bus=lambda: object(), + session_manager=lambda _workspace: object(), + ) + monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop) + monkeypatch.setattr("nanobot.api.server.create_app", _fake_create_app) + monkeypatch.setattr("aiohttp.web.run_app", _fake_run_app) + + +def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None: + config_file = _write_instance_config(tmp_path) + config = Config() + config.agents.defaults.workspace = str(tmp_path / "config-workspace") + seen: dict[str, Path] = {} + + _patch_cli_command_runtime( + monkeypatch, + config, + set_config_path=lambda path: seen.__setitem__("config_path", path), + sync_templates=lambda path: seen.__setitem__("workspace", path), + make_provider=_stop_gateway_provider, ) result = runner.invoke(app, ["gateway", "--config", str(config_file)]) @@ -673,24 +751,17 @@ def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Pa def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) -> None: - config_file = tmp_path / "instance" / "config.json" - config_file.parent.mkdir(parents=True) - config_file.write_text("{}") - + config_file = _write_instance_config(tmp_path) config = Config() config.agents.defaults.workspace = str(tmp_path / "config-workspace") override = tmp_path / "override-workspace" seen: dict[str, Path] = {} - monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) - monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) - monkeypatch.setattr( - "nanobot.cli.commands.sync_workspace_templates", - lambda path: seen.__setitem__("workspace", path), - ) - monkeypatch.setattr( - "nanobot.cli.commands._make_provider", - lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")), + _patch_cli_command_runtime( + monkeypatch, + config, + sync_templates=lambda path: seen.__setitem__("workspace", path), + make_provider=_stop_gateway_provider, ) result = runner.invoke( @@ -704,27 +775,23 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) def test_gateway_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None: - config_file = tmp_path / "instance" / "config.json" - config_file.parent.mkdir(parents=True) - config_file.write_text("{}") - + config_file = _write_instance_config(tmp_path) config = Config() config.agents.defaults.workspace = str(tmp_path / "config-workspace") seen: dict[str, Path] = {} - monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) - monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) - monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) - monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object()) - monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object()) - monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object()) - class _StopCron: def __init__(self, store_path: Path) -> None: seen["cron_store"] = store_path raise _StopGatewayError("stop") - monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron) + _patch_cli_command_runtime( + monkeypatch, + config, + message_bus=lambda: object(), + session_manager=lambda _workspace: object(), + cron_service=_StopCron, + ) result = runner.invoke(app, ["gateway", "--config", str(config_file)]) @@ -735,10 +802,7 @@ def test_gateway_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: def test_gateway_workspace_override_does_not_migrate_legacy_cron( monkeypatch, tmp_path: Path ) -> None: - config_file = tmp_path / "instance" / "config.json" - config_file.parent.mkdir(parents=True) - config_file.write_text("{}") - + config_file = _write_instance_config(tmp_path) legacy_dir = tmp_path / "global" / "cron" legacy_dir.mkdir(parents=True) legacy_file = legacy_dir / "jobs.json" @@ -748,20 +812,19 @@ def test_gateway_workspace_override_does_not_migrate_legacy_cron( config = Config() seen: dict[str, Path] = {} - monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) - monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) - monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) - monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object()) - monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object()) - monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object()) - monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir) - class _StopCron: def __init__(self, store_path: Path) -> None: seen["cron_store"] = store_path raise _StopGatewayError("stop") - monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron) + _patch_cli_command_runtime( + monkeypatch, + config, + message_bus=lambda: object(), + session_manager=lambda _workspace: object(), + cron_service=_StopCron, + get_cron_dir=lambda: legacy_dir, + ) result = runner.invoke( app, @@ -777,10 +840,7 @@ def test_gateway_workspace_override_does_not_migrate_legacy_cron( def test_gateway_custom_config_workspace_does_not_migrate_legacy_cron( monkeypatch, tmp_path: Path ) -> None: - config_file = tmp_path / "instance" / "config.json" - config_file.parent.mkdir(parents=True) - config_file.write_text("{}") - + config_file = _write_instance_config(tmp_path) legacy_dir = tmp_path / "global" / "cron" legacy_dir.mkdir(parents=True) legacy_file = legacy_dir / "jobs.json" @@ -791,20 +851,19 @@ def test_gateway_custom_config_workspace_does_not_migrate_legacy_cron( config.agents.defaults.workspace = str(custom_workspace) seen: dict[str, Path] = {} - monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) - monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) - monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) - monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object()) - monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object()) - monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object()) - monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir) - class _StopCron: def __init__(self, store_path: Path) -> None: seen["cron_store"] = store_path raise _StopGatewayError("stop") - monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron) + _patch_cli_command_runtime( + monkeypatch, + config, + message_bus=lambda: object(), + session_manager=lambda _workspace: object(), + cron_service=_StopCron, + get_cron_dir=lambda: legacy_dir, + ) result = runner.invoke(app, ["gateway", "--config", str(config_file)]) @@ -856,19 +915,14 @@ def test_migrate_cron_store_skips_when_workspace_file_exists(tmp_path: Path) -> def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_path: Path) -> None: - config_file = tmp_path / "instance" / "config.json" - config_file.parent.mkdir(parents=True) - config_file.write_text("{}") - + config_file = _write_instance_config(tmp_path) config = Config() config.gateway.port = 18791 - monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) - monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) - monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) - monkeypatch.setattr( - "nanobot.cli.commands._make_provider", - lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")), + _patch_cli_command_runtime( + monkeypatch, + config, + make_provider=_stop_gateway_provider, ) result = runner.invoke(app, ["gateway", "--config", str(config_file)]) @@ -878,19 +932,14 @@ def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_ def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path) -> None: - config_file = tmp_path / "instance" / "config.json" - config_file.parent.mkdir(parents=True) - config_file.write_text("{}") - + config_file = _write_instance_config(tmp_path) config = Config() config.gateway.port = 18791 - monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) - monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) - monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) - monkeypatch.setattr( - "nanobot.cli.commands._make_provider", - lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")), + _patch_cli_command_runtime( + monkeypatch, + config, + make_provider=_stop_gateway_provider, ) result = runner.invoke(app, ["gateway", "--config", str(config_file), "--port", "18792"]) @@ -899,6 +948,63 @@ def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path) assert "port 18792" in result.stdout +def test_serve_uses_api_config_defaults_and_workspace_override( + monkeypatch, tmp_path: Path +) -> None: + config_file = _write_instance_config(tmp_path) + config = Config() + config.agents.defaults.workspace = str(tmp_path / "config-workspace") + config.api.host = "127.0.0.2" + config.api.port = 18900 + config.api.timeout = 45.0 + override_workspace = tmp_path / "override-workspace" + seen: dict[str, object] = {} + + _patch_serve_runtime(monkeypatch, config, seen) + + result = runner.invoke( + app, + ["serve", "--config", str(config_file), "--workspace", str(override_workspace)], + ) + + assert result.exit_code == 0 + assert seen["workspace"] == override_workspace + assert seen["host"] == "127.0.0.2" + assert seen["port"] == 18900 + assert seen["request_timeout"] == 45.0 + + +def test_serve_cli_options_override_api_config(monkeypatch, tmp_path: Path) -> None: + config_file = _write_instance_config(tmp_path) + config = Config() + config.api.host = "127.0.0.2" + config.api.port = 18900 + config.api.timeout = 45.0 + seen: dict[str, object] = {} + + _patch_serve_runtime(monkeypatch, config, seen) + + result = runner.invoke( + app, + [ + "serve", + "--config", + str(config_file), + "--host", + "127.0.0.1", + "--port", + "18901", + "--timeout", + "46", + ], + ) + + assert result.exit_code == 0 + assert seen["host"] == "127.0.0.1" + assert seen["port"] == 18901 + assert seen["request_timeout"] == 46.0 + + def test_channels_login_requires_channel_name() -> None: result = runner.invoke(app, ["channels", "login"]) From 55501057ac138b4ab75e36d5ef605ea4c96a5af6 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Mon, 30 Mar 2026 14:20:14 +0000 Subject: [PATCH 07/22] refactor(api): tighten fixed-session chat input contract Reject mismatched models and require a single user message so the OpenAI-compatible endpoint reflects the fixed-session nanobot runtime without extra compatibility noise. --- nanobot/api/server.py | 27 ++++++---------- tests/test_openai_api.py | 68 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 18 deletions(-) diff --git a/nanobot/api/server.py b/nanobot/api/server.py index 2a818667a..34b73ad57 100644 --- a/nanobot/api/server.py +++ b/nanobot/api/server.py @@ -69,21 +69,17 @@ async def handle_chat_completions(request: web.Request) -> web.Response: return _error_json(400, "Invalid JSON body") messages = body.get("messages") - if not messages or not isinstance(messages, list): - return _error_json(400, "messages field is required and must be a non-empty array") + if not isinstance(messages, list) or len(messages) != 1: + return _error_json(400, "Only a single user message is supported") # Stream not yet supported if body.get("stream", False): return _error_json(400, "stream=true is not supported yet. Set stream=false or omit it.") - # Extract last user message — nanobot manages its own multi-turn history - user_content = None - for msg in reversed(messages): - if msg.get("role") == "user": - user_content = msg.get("content", "") - break - if user_content is None: - return _error_json(400, "messages must contain at least one user message") + message = messages[0] + if not isinstance(message, dict) or message.get("role") != "user": + return _error_json(400, "Only a single user message is supported") + user_content = message.get("content", "") if isinstance(user_content, list): # Multi-modal content array — extract text parts user_content = " ".join( @@ -92,7 +88,9 @@ async def handle_chat_completions(request: web.Request) -> web.Response: agent_loop = request.app["agent_loop"] timeout_s: float = request.app.get("request_timeout", 120.0) - model_name: str = body.get("model") or request.app.get("model_name", "nanobot") + model_name: str = request.app.get("model_name", "nanobot") + if (requested_model := body.get("model")) and requested_model != model_name: + return _error_json(400, f"Only configured model '{model_name}' is available") session_lock: asyncio.Lock = request.app["session_lock"] logger.info("API request session_key={} content={}", API_SESSION_KEY, user_content[:80]) @@ -190,10 +188,3 @@ def create_app(agent_loop, model_name: str = "nanobot", request_timeout: float = app.router.add_get("/v1/models", handle_models) app.router.add_get("/health", handle_health) return app - - -def run_server(agent_loop, host: str = "127.0.0.1", port: int = 8900, - model_name: str = "nanobot", request_timeout: float = 120.0) -> None: - """Create and run the server (blocking).""" - app = create_app(agent_loop, model_name=model_name, request_timeout=request_timeout) - web.run_app(app, host=host, port=port, print=lambda msg: logger.info(msg)) diff --git a/tests/test_openai_api.py b/tests/test_openai_api.py index dbb47f6b6..d935729a8 100644 --- a/tests/test_openai_api.py +++ b/tests/test_openai_api.py @@ -14,6 +14,7 @@ from nanobot.api.server import ( _chat_completion_response, _error_json, create_app, + handle_chat_completions, ) try: @@ -93,6 +94,73 @@ async def test_stream_true_returns_400(aiohttp_client, app) -> None: assert "stream" in body["error"]["message"].lower() +@pytest.mark.asyncio +async def test_model_mismatch_returns_400() -> None: + request = MagicMock() + request.json = AsyncMock( + return_value={ + "model": "other-model", + "messages": [{"role": "user", "content": "hello"}], + } + ) + request.app = { + "agent_loop": _make_mock_agent(), + "model_name": "test-model", + "request_timeout": 10.0, + "session_lock": asyncio.Lock(), + } + + resp = await handle_chat_completions(request) + assert resp.status == 400 + body = json.loads(resp.body) + assert "test-model" in body["error"]["message"] + + +@pytest.mark.asyncio +async def test_single_user_message_required() -> None: + request = MagicMock() + request.json = AsyncMock( + return_value={ + "messages": [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "previous reply"}, + ], + } + ) + request.app = { + "agent_loop": _make_mock_agent(), + "model_name": "test-model", + "request_timeout": 10.0, + "session_lock": asyncio.Lock(), + } + + resp = await handle_chat_completions(request) + assert resp.status == 400 + body = json.loads(resp.body) + assert "single user message" in body["error"]["message"].lower() + + +@pytest.mark.asyncio +async def test_single_user_message_must_have_user_role() -> None: + request = MagicMock() + request.json = AsyncMock( + return_value={ + "messages": [{"role": "system", "content": "you are a bot"}], + } + ) + request.app = { + "agent_loop": _make_mock_agent(), + "model_name": "test-model", + "request_timeout": 10.0, + "session_lock": asyncio.Lock(), + } + + resp = await handle_chat_completions(request) + assert resp.status == 400 + body = json.loads(resp.body) + assert "single user message" in body["error"]["message"].lower() + + @pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") @pytest.mark.asyncio async def test_successful_request_uses_fixed_api_session(aiohttp_client, mock_agent) -> None: From d9a5080d66874affd9812fc5bcb5c07004ccd081 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Mon, 30 Mar 2026 14:43:22 +0000 Subject: [PATCH 08/22] refactor(api): tighten fixed-session API contract Require a single user message, reject mismatched models, document the OpenAI-compatible API, and exclude api/ from core agent line counts so the interface matches nanobot's minimal fixed-session runtime. --- README.md | 76 +++++++++++++++++++++++++++++++++++++++++++++ core_agent_lines.sh | 6 ++-- 2 files changed, 79 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 828b56477..01bc11c25 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,7 @@ - [Configuration](#️-configuration) - [Multiple Instances](#-multiple-instances) - [CLI Reference](#-cli-reference) +- [OpenAI-Compatible API](#-openai-compatible-api) - [Docker](#-docker) - [Linux Service](#-linux-service) - [Project Structure](#-project-structure) @@ -1541,6 +1542,7 @@ nanobot gateway --config ~/.nanobot-telegram/config.json --workspace /tmp/nanobo | `nanobot agent` | Interactive chat mode | | `nanobot agent --no-markdown` | Show plain-text replies | | `nanobot agent --logs` | Show runtime logs during chat | +| `nanobot serve` | Start the OpenAI-compatible API | | `nanobot gateway` | Start the gateway | | `nanobot status` | Show status | | `nanobot provider login openai-codex` | OAuth login for providers | @@ -1569,6 +1571,80 @@ The agent can also manage this file itself — ask it to "add a periodic task" a +## 🔌 OpenAI-Compatible API + +nanobot can expose a minimal OpenAI-compatible endpoint for local integrations: + +```bash +pip install "nanobot-ai[api]" +nanobot serve +``` + +By default, the API binds to `127.0.0.1:8900`. + +### Behavior + +- Fixed session: all requests share the same nanobot session (`api:default`) +- Single-message input: each request must contain exactly one `user` message +- Fixed model: omit `model`, or pass the same model shown by `/v1/models` +- No streaming: `stream=true` is not supported + +### Endpoints + +- `GET /health` +- `GET /v1/models` +- `POST /v1/chat/completions` + +### curl + +```bash +curl http://127.0.0.1:8900/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "messages": [ + { + "role": "user", + "content": "hi" + } + ] + }' +``` + +### Python (`requests`) + +```python +import requests + +resp = requests.post( + "http://127.0.0.1:8900/v1/chat/completions", + json={ + "messages": [ + {"role": "user", "content": "hi"} + ] + }, + timeout=120, +) +resp.raise_for_status() +print(resp.json()["choices"][0]["message"]["content"]) +``` + +### Python (`openai`) + +```python +from openai import OpenAI + +client = OpenAI( + base_url="http://127.0.0.1:8900/v1", + api_key="dummy", +) + +resp = client.chat.completions.create( + model="MiniMax-M2.7", + messages=[{"role": "user", "content": "hi"}], +) +print(resp.choices[0].message.content) +``` + ## 🐳 Docker > [!TIP] diff --git a/core_agent_lines.sh b/core_agent_lines.sh index d35207cb4..90f39aacc 100755 --- a/core_agent_lines.sh +++ b/core_agent_lines.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Count core agent lines (excluding channels/, cli/, providers/ adapters) +# Count core agent lines (excluding channels/, cli/, api/, providers/ adapters) cd "$(dirname "$0")" || exit 1 echo "nanobot core agent line count" @@ -15,7 +15,7 @@ root=$(cat nanobot/__init__.py nanobot/__main__.py | wc -l) printf " %-16s %5s lines\n" "(root)" "$root" echo "" -total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/command/*" ! -path "*/providers/*" ! -path "*/skills/*" | xargs cat | wc -l) +total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/api/*" ! -path "*/command/*" ! -path "*/providers/*" ! -path "*/skills/*" | xargs cat | wc -l) echo " Core total: $total lines" echo "" -echo " (excludes: channels/, cli/, command/, providers/, skills/)" +echo " (excludes: channels/, cli/, api/, command/, providers/, skills/)" From 5e99b81c6e55a8ea9b99edb0ea5804d9eb731eab Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Mon, 30 Mar 2026 15:05:06 +0000 Subject: [PATCH 09/22] refactor(api): reduce compatibility and test noise Make the fixed-session API surface explicit, document its usage, exclude api/ from core agent line counts, and remove implicit aiohttp pytest fixture dependencies from API tests. --- tests/test_openai_api.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/test_openai_api.py b/tests/test_openai_api.py index d935729a8..3d29d4767 100644 --- a/tests/test_openai_api.py +++ b/tests/test_openai_api.py @@ -7,6 +7,7 @@ import json from unittest.mock import AsyncMock, MagicMock import pytest +import pytest_asyncio from nanobot.api.server import ( API_CHAT_ID, @@ -18,7 +19,7 @@ from nanobot.api.server import ( ) try: - import aiohttp # noqa: F401 + from aiohttp.test_utils import TestClient, TestServer HAS_AIOHTTP = True except ImportError: @@ -45,6 +46,23 @@ def app(mock_agent): return create_app(mock_agent, model_name="test-model", request_timeout=10.0) +@pytest_asyncio.fixture +async def aiohttp_client(): + clients: list[TestClient] = [] + + async def _make_client(app): + client = TestClient(TestServer(app)) + await client.start_server() + clients.append(client) + return client + + try: + yield _make_client + finally: + for client in clients: + await client.close() + + def test_error_json() -> None: resp = _error_json(400, "bad request") assert resp.status == 400 From f08de72f18c0889592458c95c547fdf03cb2e78a Mon Sep 17 00:00:00 2001 From: sontianye Date: Sun, 29 Mar 2026 22:56:02 +0800 Subject: [PATCH 10/22] feat(agent): add CompositeHook for composable lifecycle hooks Introduce a CompositeHook that fans out lifecycle callbacks to an ordered list of AgentHook instances with per-hook error isolation. Extract the nested _LoopHook and _SubagentHook to module scope as public LoopHook / SubagentHook so downstream users can subclass or compose them. Add `hooks` parameter to AgentLoop.__init__ for registering custom hooks at construction time. Closes #2603 --- nanobot/agent/__init__.py | 17 +- nanobot/agent/hook.py | 59 ++++++ nanobot/agent/loop.py | 124 +++++++---- nanobot/agent/subagent.py | 30 ++- tests/agent/test_hook_composite.py | 330 +++++++++++++++++++++++++++++ 5 files changed, 508 insertions(+), 52 deletions(-) create mode 100644 tests/agent/test_hook_composite.py diff --git a/nanobot/agent/__init__.py b/nanobot/agent/__init__.py index f9ba8b87a..d3805805b 100644 --- a/nanobot/agent/__init__.py +++ b/nanobot/agent/__init__.py @@ -1,8 +1,21 @@ """Agent core module.""" from nanobot.agent.context import ContextBuilder -from nanobot.agent.loop import AgentLoop +from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook +from nanobot.agent.loop import AgentLoop, LoopHook from nanobot.agent.memory import MemoryStore from nanobot.agent.skills import SkillsLoader +from nanobot.agent.subagent import SubagentHook, SubagentManager -__all__ = ["AgentLoop", "ContextBuilder", "MemoryStore", "SkillsLoader"] +__all__ = [ + "AgentHook", + "AgentHookContext", + "AgentLoop", + "CompositeHook", + "ContextBuilder", + "LoopHook", + "MemoryStore", + "SkillsLoader", + "SubagentHook", + "SubagentManager", +] diff --git a/nanobot/agent/hook.py b/nanobot/agent/hook.py index 368c46aa2..97ec7a07d 100644 --- a/nanobot/agent/hook.py +++ b/nanobot/agent/hook.py @@ -5,6 +5,8 @@ from __future__ import annotations from dataclasses import dataclass, field from typing import Any +from loguru import logger + from nanobot.providers.base import LLMResponse, ToolCallRequest @@ -47,3 +49,60 @@ class AgentHook: def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: return content + + +class CompositeHook(AgentHook): + """Fan-out hook that delegates to an ordered list of hooks. + + Error isolation: async methods catch and log per-hook exceptions + so a faulty custom hook cannot crash the agent loop. + ``finalize_content`` is a pipeline (no isolation — bugs should surface). + """ + + __slots__ = ("_hooks",) + + def __init__(self, hooks: list[AgentHook]) -> None: + self._hooks = list(hooks) + + def wants_streaming(self) -> bool: + return any(h.wants_streaming() for h in self._hooks) + + async def before_iteration(self, context: AgentHookContext) -> None: + for h in self._hooks: + try: + await h.before_iteration(context) + except Exception: + logger.exception("AgentHook.before_iteration error in {}", type(h).__name__) + + async def on_stream(self, context: AgentHookContext, delta: str) -> None: + for h in self._hooks: + try: + await h.on_stream(context, delta) + except Exception: + logger.exception("AgentHook.on_stream error in {}", type(h).__name__) + + async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: + for h in self._hooks: + try: + await h.on_stream_end(context, resuming=resuming) + except Exception: + logger.exception("AgentHook.on_stream_end error in {}", type(h).__name__) + + async def before_execute_tools(self, context: AgentHookContext) -> None: + for h in self._hooks: + try: + await h.before_execute_tools(context) + except Exception: + logger.exception("AgentHook.before_execute_tools error in {}", type(h).__name__) + + async def after_iteration(self, context: AgentHookContext) -> None: + for h in self._hooks: + try: + await h.after_iteration(context) + except Exception: + logger.exception("AgentHook.after_iteration error in {}", type(h).__name__) + + def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: + for h in self._hooks: + content = h.finalize_content(context, content) + return content diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 63ee92ca5..0e58fa557 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable from loguru import logger from nanobot.agent.context import ContextBuilder -from nanobot.agent.hook import AgentHook, AgentHookContext +from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook from nanobot.agent.memory import MemoryConsolidator from nanobot.agent.runner import AgentRunSpec, AgentRunner from nanobot.agent.subagent import SubagentManager @@ -37,6 +37,71 @@ if TYPE_CHECKING: from nanobot.cron.service import CronService +class LoopHook(AgentHook): + """Core lifecycle hook for the main agent loop. + + Handles streaming delta relay, progress reporting, tool-call logging, + and think-tag stripping. Public so downstream users can subclass or + compose it via :class:`CompositeHook`. + """ + + def __init__( + self, + agent_loop: AgentLoop, + on_progress: Callable[..., Awaitable[None]] | None = None, + on_stream: Callable[[str], Awaitable[None]] | None = None, + on_stream_end: Callable[..., Awaitable[None]] | None = None, + *, + channel: str = "cli", + chat_id: str = "direct", + message_id: str | None = None, + ) -> None: + self._loop = agent_loop + self._on_progress = on_progress + self._on_stream = on_stream + self._on_stream_end = on_stream_end + self._channel = channel + self._chat_id = chat_id + self._message_id = message_id + self._stream_buf = "" + + def wants_streaming(self) -> bool: + return self._on_stream is not None + + async def on_stream(self, context: AgentHookContext, delta: str) -> None: + from nanobot.utils.helpers import strip_think + + prev_clean = strip_think(self._stream_buf) + self._stream_buf += delta + new_clean = strip_think(self._stream_buf) + incremental = new_clean[len(prev_clean):] + if incremental and self._on_stream: + await self._on_stream(incremental) + + async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: + if self._on_stream_end: + await self._on_stream_end(resuming=resuming) + self._stream_buf = "" + + async def before_execute_tools(self, context: AgentHookContext) -> None: + if self._on_progress: + if not self._on_stream: + thought = self._loop._strip_think( + context.response.content if context.response else None + ) + if thought: + await self._on_progress(thought) + tool_hint = self._loop._strip_think(self._loop._tool_hint(context.tool_calls)) + await self._on_progress(tool_hint, tool_hint=True) + for tc in context.tool_calls: + args_str = json.dumps(tc.arguments, ensure_ascii=False) + logger.info("Tool call: {}({})", tc.name, args_str[:200]) + self._loop._set_tool_context(self._channel, self._chat_id, self._message_id) + + def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: + return self._loop._strip_think(content) + + class AgentLoop: """ The agent loop is the core processing engine. @@ -68,6 +133,7 @@ class AgentLoop: mcp_servers: dict | None = None, channels_config: ChannelsConfig | None = None, timezone: str | None = None, + hooks: list[AgentHook] | None = None, ): from nanobot.config.schema import ExecToolConfig, WebSearchConfig @@ -85,6 +151,7 @@ class AgentLoop: self.restrict_to_workspace = restrict_to_workspace self._start_time = time.time() self._last_usage: dict[str, int] = {} + self._extra_hooks: list[AgentHook] = hooks or [] self.context = ContextBuilder(workspace, timezone=timezone) self.sessions = session_manager or SessionManager(workspace) @@ -217,52 +284,27 @@ class AgentLoop: ``resuming=True`` means tool calls follow (spinner should restart); ``resuming=False`` means this is the final response. """ - loop_self = self - - class _LoopHook(AgentHook): - def __init__(self) -> None: - self._stream_buf = "" - - def wants_streaming(self) -> bool: - return on_stream is not None - - async def on_stream(self, context: AgentHookContext, delta: str) -> None: - from nanobot.utils.helpers import strip_think - - prev_clean = strip_think(self._stream_buf) - self._stream_buf += delta - new_clean = strip_think(self._stream_buf) - incremental = new_clean[len(prev_clean):] - if incremental and on_stream: - await on_stream(incremental) - - async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: - if on_stream_end: - await on_stream_end(resuming=resuming) - self._stream_buf = "" - - async def before_execute_tools(self, context: AgentHookContext) -> None: - if on_progress: - if not on_stream: - thought = loop_self._strip_think(context.response.content if context.response else None) - if thought: - await on_progress(thought) - tool_hint = loop_self._strip_think(loop_self._tool_hint(context.tool_calls)) - await on_progress(tool_hint, tool_hint=True) - for tc in context.tool_calls: - args_str = json.dumps(tc.arguments, ensure_ascii=False) - logger.info("Tool call: {}({})", tc.name, args_str[:200]) - loop_self._set_tool_context(channel, chat_id, message_id) - - def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: - return loop_self._strip_think(content) + loop_hook = LoopHook( + self, + on_progress=on_progress, + on_stream=on_stream, + on_stream_end=on_stream_end, + channel=channel, + chat_id=chat_id, + message_id=message_id, + ) + hook: AgentHook = ( + CompositeHook([loop_hook, *self._extra_hooks]) + if self._extra_hooks + else loop_hook + ) result = await self.runner.run(AgentRunSpec( initial_messages=initial_messages, tools=self.tools, model=self.model, max_iterations=self.max_iterations, - hook=_LoopHook(), + hook=hook, error_message="Sorry, I encountered an error calling the AI model.", concurrent_tools=True, )) diff --git a/nanobot/agent/subagent.py b/nanobot/agent/subagent.py index 5266fc8b1..691f53820 100644 --- a/nanobot/agent/subagent.py +++ b/nanobot/agent/subagent.py @@ -21,6 +21,24 @@ from nanobot.config.schema import ExecToolConfig from nanobot.providers.base import LLMProvider +class SubagentHook(AgentHook): + """Logging-only hook for subagent execution. + + Public so downstream users can subclass or compose via :class:`CompositeHook`. + """ + + def __init__(self, task_id: str) -> None: + self._task_id = task_id + + async def before_execute_tools(self, context: AgentHookContext) -> None: + for tool_call in context.tool_calls: + args_str = json.dumps(tool_call.arguments, ensure_ascii=False) + logger.debug( + "Subagent [{}] executing: {} with arguments: {}", + self._task_id, tool_call.name, args_str, + ) + + class SubagentManager: """Manages background subagent execution.""" @@ -108,25 +126,19 @@ class SubagentManager: )) tools.register(WebSearchTool(config=self.web_search_config, proxy=self.web_proxy)) tools.register(WebFetchTool(proxy=self.web_proxy)) - + system_prompt = self._build_subagent_prompt() messages: list[dict[str, Any]] = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": task}, ] - class _SubagentHook(AgentHook): - async def before_execute_tools(self, context: AgentHookContext) -> None: - for tool_call in context.tool_calls: - args_str = json.dumps(tool_call.arguments, ensure_ascii=False) - logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str) - result = await self.runner.run(AgentRunSpec( initial_messages=messages, tools=tools, model=self.model, max_iterations=15, - hook=_SubagentHook(), + hook=SubagentHook(task_id), max_iterations_message="Task completed but no final response was generated.", error_message=None, fail_on_tool_error=True, @@ -213,7 +225,7 @@ Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not men lines.append("Failure:") lines.append(f"- {result.error}") return "\n".join(lines) or (result.error or "Error: subagent execution failed.") - + def _build_subagent_prompt(self) -> str: """Build a focused system prompt for the subagent.""" from nanobot.agent.context import ContextBuilder diff --git a/tests/agent/test_hook_composite.py b/tests/agent/test_hook_composite.py new file mode 100644 index 000000000..8a43a4249 --- /dev/null +++ b/tests/agent/test_hook_composite.py @@ -0,0 +1,330 @@ +"""Tests for CompositeHook fan-out, error isolation, and integration.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook + + +def _ctx() -> AgentHookContext: + return AgentHookContext(iteration=0, messages=[]) + + +# --------------------------------------------------------------------------- +# Fan-out: every hook is called in order +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_composite_fans_out_before_iteration(): + calls: list[str] = [] + + class H(AgentHook): + async def before_iteration(self, context: AgentHookContext) -> None: + calls.append(f"A:{context.iteration}") + + class H2(AgentHook): + async def before_iteration(self, context: AgentHookContext) -> None: + calls.append(f"B:{context.iteration}") + + hook = CompositeHook([H(), H2()]) + ctx = _ctx() + await hook.before_iteration(ctx) + assert calls == ["A:0", "B:0"] + + +@pytest.mark.asyncio +async def test_composite_fans_out_all_async_methods(): + """Verify all async methods fan out to every hook.""" + events: list[str] = [] + + class RecordingHook(AgentHook): + async def before_iteration(self, context: AgentHookContext) -> None: + events.append("before_iteration") + + async def on_stream(self, context: AgentHookContext, delta: str) -> None: + events.append(f"on_stream:{delta}") + + async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: + events.append(f"on_stream_end:{resuming}") + + async def before_execute_tools(self, context: AgentHookContext) -> None: + events.append("before_execute_tools") + + async def after_iteration(self, context: AgentHookContext) -> None: + events.append("after_iteration") + + hook = CompositeHook([RecordingHook(), RecordingHook()]) + ctx = _ctx() + + await hook.before_iteration(ctx) + await hook.on_stream(ctx, "hi") + await hook.on_stream_end(ctx, resuming=True) + await hook.before_execute_tools(ctx) + await hook.after_iteration(ctx) + + assert events == [ + "before_iteration", "before_iteration", + "on_stream:hi", "on_stream:hi", + "on_stream_end:True", "on_stream_end:True", + "before_execute_tools", "before_execute_tools", + "after_iteration", "after_iteration", + ] + + +# --------------------------------------------------------------------------- +# Error isolation: one hook raises, others still run +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_composite_error_isolation_before_iteration(): + calls: list[str] = [] + + class Bad(AgentHook): + async def before_iteration(self, context: AgentHookContext) -> None: + raise RuntimeError("boom") + + class Good(AgentHook): + async def before_iteration(self, context: AgentHookContext) -> None: + calls.append("good") + + hook = CompositeHook([Bad(), Good()]) + await hook.before_iteration(_ctx()) + assert calls == ["good"] + + +@pytest.mark.asyncio +async def test_composite_error_isolation_on_stream(): + calls: list[str] = [] + + class Bad(AgentHook): + async def on_stream(self, context: AgentHookContext, delta: str) -> None: + raise RuntimeError("stream-boom") + + class Good(AgentHook): + async def on_stream(self, context: AgentHookContext, delta: str) -> None: + calls.append(delta) + + hook = CompositeHook([Bad(), Good()]) + await hook.on_stream(_ctx(), "delta") + assert calls == ["delta"] + + +@pytest.mark.asyncio +async def test_composite_error_isolation_all_async(): + """Error isolation for on_stream_end, before_execute_tools, after_iteration.""" + calls: list[str] = [] + + class Bad(AgentHook): + async def on_stream_end(self, context, *, resuming): + raise RuntimeError("err") + async def before_execute_tools(self, context): + raise RuntimeError("err") + async def after_iteration(self, context): + raise RuntimeError("err") + + class Good(AgentHook): + async def on_stream_end(self, context, *, resuming): + calls.append("on_stream_end") + async def before_execute_tools(self, context): + calls.append("before_execute_tools") + async def after_iteration(self, context): + calls.append("after_iteration") + + hook = CompositeHook([Bad(), Good()]) + ctx = _ctx() + await hook.on_stream_end(ctx, resuming=False) + await hook.before_execute_tools(ctx) + await hook.after_iteration(ctx) + assert calls == ["on_stream_end", "before_execute_tools", "after_iteration"] + + +# --------------------------------------------------------------------------- +# finalize_content: pipeline semantics (no error isolation) +# --------------------------------------------------------------------------- + + +def test_composite_finalize_content_pipeline(): + class Upper(AgentHook): + def finalize_content(self, context, content): + return content.upper() if content else content + + class Suffix(AgentHook): + def finalize_content(self, context, content): + return (content + "!") if content else content + + hook = CompositeHook([Upper(), Suffix()]) + result = hook.finalize_content(_ctx(), "hello") + assert result == "HELLO!" + + +def test_composite_finalize_content_none_passthrough(): + hook = CompositeHook([AgentHook()]) + assert hook.finalize_content(_ctx(), None) is None + + +def test_composite_finalize_content_ordering(): + """First hook transforms first, result feeds second hook.""" + steps: list[str] = [] + + class H1(AgentHook): + def finalize_content(self, context, content): + steps.append(f"H1:{content}") + return content.upper() + + class H2(AgentHook): + def finalize_content(self, context, content): + steps.append(f"H2:{content}") + return content + "!" + + hook = CompositeHook([H1(), H2()]) + result = hook.finalize_content(_ctx(), "hi") + assert result == "HI!" + assert steps == ["H1:hi", "H2:HI"] + + +# --------------------------------------------------------------------------- +# wants_streaming: any-semantics +# --------------------------------------------------------------------------- + + +def test_composite_wants_streaming_any_true(): + class No(AgentHook): + def wants_streaming(self): + return False + + class Yes(AgentHook): + def wants_streaming(self): + return True + + hook = CompositeHook([No(), Yes(), No()]) + assert hook.wants_streaming() is True + + +def test_composite_wants_streaming_all_false(): + hook = CompositeHook([AgentHook(), AgentHook()]) + assert hook.wants_streaming() is False + + +def test_composite_wants_streaming_empty(): + hook = CompositeHook([]) + assert hook.wants_streaming() is False + + +# --------------------------------------------------------------------------- +# Empty hooks list: behaves like no-op AgentHook +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_composite_empty_hooks_no_ops(): + hook = CompositeHook([]) + ctx = _ctx() + await hook.before_iteration(ctx) + await hook.on_stream(ctx, "delta") + await hook.on_stream_end(ctx, resuming=False) + await hook.before_execute_tools(ctx) + await hook.after_iteration(ctx) + assert hook.finalize_content(ctx, "test") == "test" + + +# --------------------------------------------------------------------------- +# Integration: AgentLoop with extra hooks +# --------------------------------------------------------------------------- + + +def _make_loop(tmp_path, hooks=None): + from nanobot.agent.loop import AgentLoop + from nanobot.bus.queue import MessageBus + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + provider.generation.max_tokens = 4096 + + with patch("nanobot.agent.loop.ContextBuilder"), \ + patch("nanobot.agent.loop.SessionManager"), \ + patch("nanobot.agent.loop.SubagentManager") as mock_sub_mgr, \ + patch("nanobot.agent.loop.MemoryConsolidator"): + mock_sub_mgr.return_value.cancel_by_session = AsyncMock(return_value=0) + loop = AgentLoop( + bus=bus, provider=provider, workspace=tmp_path, hooks=hooks, + ) + return loop + + +@pytest.mark.asyncio +async def test_agent_loop_extra_hook_receives_calls(tmp_path): + """Extra hook passed to AgentLoop is called alongside core LoopHook.""" + from nanobot.providers.base import LLMResponse + + events: list[str] = [] + + class TrackingHook(AgentHook): + async def before_iteration(self, context): + events.append(f"before_iter:{context.iteration}") + + async def after_iteration(self, context): + events.append(f"after_iter:{context.iteration}") + + loop = _make_loop(tmp_path, hooks=[TrackingHook()]) + loop.provider.chat_with_retry = AsyncMock( + return_value=LLMResponse(content="done", tool_calls=[], usage={}) + ) + loop.tools.get_definitions = MagicMock(return_value=[]) + + content, tools_used, messages = await loop._run_agent_loop( + [{"role": "user", "content": "hi"}] + ) + + assert content == "done" + assert "before_iter:0" in events + assert "after_iter:0" in events + + +@pytest.mark.asyncio +async def test_agent_loop_extra_hook_error_isolation(tmp_path): + """A faulty extra hook does not crash the agent loop.""" + from nanobot.providers.base import LLMResponse + + class BadHook(AgentHook): + async def before_iteration(self, context): + raise RuntimeError("I am broken") + + loop = _make_loop(tmp_path, hooks=[BadHook()]) + loop.provider.chat_with_retry = AsyncMock( + return_value=LLMResponse(content="still works", tool_calls=[], usage={}) + ) + loop.tools.get_definitions = MagicMock(return_value=[]) + + content, _, _ = await loop._run_agent_loop( + [{"role": "user", "content": "hi"}] + ) + + assert content == "still works" + + +@pytest.mark.asyncio +async def test_agent_loop_no_hooks_backward_compat(tmp_path): + """Without hooks param, behavior is identical to before.""" + from nanobot.providers.base import LLMResponse, ToolCallRequest + + loop = _make_loop(tmp_path) + loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse( + content="working", + tool_calls=[ToolCallRequest(id="c1", name="list_dir", arguments={"path": "."})], + )) + loop.tools.get_definitions = MagicMock(return_value=[]) + loop.tools.execute = AsyncMock(return_value="ok") + loop.max_iterations = 2 + + content, tools_used, _ = await loop._run_agent_loop([]) + assert content == ( + "I reached the maximum number of tool call iterations (2) " + "without completing the task. You can try breaking the task into smaller steps." + ) + assert tools_used == ["list_dir", "list_dir"] From 758c4e74c9d3f6e494d497a050f12b5d5bdad2f8 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Mon, 30 Mar 2026 17:57:49 +0000 Subject: [PATCH 11/22] fix(agent): preserve LoopHook error semantics when extra hooks are present --- nanobot/agent/loop.py | 43 +++++++++++++++++++++++++++++- tests/agent/test_hook_composite.py | 21 +++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 0e58fa557..c45257657 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -102,6 +102,47 @@ class LoopHook(AgentHook): return self._loop._strip_think(content) +class _LoopHookChain(AgentHook): + """Run the core loop hook first, then best-effort extra hooks. + + This preserves the historical failure behavior of ``LoopHook`` while still + letting user-supplied hooks opt into ``CompositeHook`` isolation. + """ + + __slots__ = ("_primary", "_extras") + + def __init__(self, primary: AgentHook, extra_hooks: list[AgentHook]) -> None: + self._primary = primary + self._extras = CompositeHook(extra_hooks) + + def wants_streaming(self) -> bool: + return self._primary.wants_streaming() or self._extras.wants_streaming() + + async def before_iteration(self, context: AgentHookContext) -> None: + await self._primary.before_iteration(context) + await self._extras.before_iteration(context) + + async def on_stream(self, context: AgentHookContext, delta: str) -> None: + await self._primary.on_stream(context, delta) + await self._extras.on_stream(context, delta) + + async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: + await self._primary.on_stream_end(context, resuming=resuming) + await self._extras.on_stream_end(context, resuming=resuming) + + async def before_execute_tools(self, context: AgentHookContext) -> None: + await self._primary.before_execute_tools(context) + await self._extras.before_execute_tools(context) + + async def after_iteration(self, context: AgentHookContext) -> None: + await self._primary.after_iteration(context) + await self._extras.after_iteration(context) + + def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: + content = self._primary.finalize_content(context, content) + return self._extras.finalize_content(context, content) + + class AgentLoop: """ The agent loop is the core processing engine. @@ -294,7 +335,7 @@ class AgentLoop: message_id=message_id, ) hook: AgentHook = ( - CompositeHook([loop_hook, *self._extra_hooks]) + _LoopHookChain(loop_hook, self._extra_hooks) if self._extra_hooks else loop_hook ) diff --git a/tests/agent/test_hook_composite.py b/tests/agent/test_hook_composite.py index 8a43a4249..203c892fb 100644 --- a/tests/agent/test_hook_composite.py +++ b/tests/agent/test_hook_composite.py @@ -308,6 +308,27 @@ async def test_agent_loop_extra_hook_error_isolation(tmp_path): assert content == "still works" +@pytest.mark.asyncio +async def test_agent_loop_extra_hooks_do_not_swallow_loop_hook_errors(tmp_path): + """Extra hooks must not change the core LoopHook failure behavior.""" + from nanobot.providers.base import LLMResponse, ToolCallRequest + + loop = _make_loop(tmp_path, hooks=[AgentHook()]) + loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse( + content="working", + tool_calls=[ToolCallRequest(id="c1", name="list_dir", arguments={"path": "."})], + usage={}, + )) + loop.tools.get_definitions = MagicMock(return_value=[]) + loop.tools.execute = AsyncMock(return_value="ok") + + async def bad_progress(*args, **kwargs): + raise RuntimeError("progress failed") + + with pytest.raises(RuntimeError, match="progress failed"): + await loop._run_agent_loop([], on_progress=bad_progress) + + @pytest.mark.asyncio async def test_agent_loop_no_hooks_backward_compat(tmp_path): """Without hooks param, behavior is identical to before.""" From 842b8b255dc472e55e206b3c2c04af5d29ffe8c3 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Mon, 30 Mar 2026 18:14:11 +0000 Subject: [PATCH 12/22] fix(agent): preserve core hook failure semantics --- nanobot/agent/__init__.py | 6 ++---- nanobot/agent/loop.py | 9 ++++----- nanobot/agent/subagent.py | 9 +++------ 3 files changed, 9 insertions(+), 15 deletions(-) diff --git a/nanobot/agent/__init__.py b/nanobot/agent/__init__.py index d3805805b..7d3ab2af4 100644 --- a/nanobot/agent/__init__.py +++ b/nanobot/agent/__init__.py @@ -2,10 +2,10 @@ from nanobot.agent.context import ContextBuilder from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook -from nanobot.agent.loop import AgentLoop, LoopHook +from nanobot.agent.loop import AgentLoop from nanobot.agent.memory import MemoryStore from nanobot.agent.skills import SkillsLoader -from nanobot.agent.subagent import SubagentHook, SubagentManager +from nanobot.agent.subagent import SubagentManager __all__ = [ "AgentHook", @@ -13,9 +13,7 @@ __all__ = [ "AgentLoop", "CompositeHook", "ContextBuilder", - "LoopHook", "MemoryStore", "SkillsLoader", - "SubagentHook", "SubagentManager", ] diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index c45257657..97d352cb8 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -37,12 +37,11 @@ if TYPE_CHECKING: from nanobot.cron.service import CronService -class LoopHook(AgentHook): +class _LoopHook(AgentHook): """Core lifecycle hook for the main agent loop. Handles streaming delta relay, progress reporting, tool-call logging, - and think-tag stripping. Public so downstream users can subclass or - compose it via :class:`CompositeHook`. + and think-tag stripping for the built-in agent path. """ def __init__( @@ -105,7 +104,7 @@ class LoopHook(AgentHook): class _LoopHookChain(AgentHook): """Run the core loop hook first, then best-effort extra hooks. - This preserves the historical failure behavior of ``LoopHook`` while still + This preserves the historical failure behavior of ``_LoopHook`` while still letting user-supplied hooks opt into ``CompositeHook`` isolation. """ @@ -325,7 +324,7 @@ class AgentLoop: ``resuming=True`` means tool calls follow (spinner should restart); ``resuming=False`` means this is the final response. """ - loop_hook = LoopHook( + loop_hook = _LoopHook( self, on_progress=on_progress, on_stream=on_stream, diff --git a/nanobot/agent/subagent.py b/nanobot/agent/subagent.py index 691f53820..c1aaa2d0d 100644 --- a/nanobot/agent/subagent.py +++ b/nanobot/agent/subagent.py @@ -21,11 +21,8 @@ from nanobot.config.schema import ExecToolConfig from nanobot.providers.base import LLMProvider -class SubagentHook(AgentHook): - """Logging-only hook for subagent execution. - - Public so downstream users can subclass or compose via :class:`CompositeHook`. - """ +class _SubagentHook(AgentHook): + """Logging-only hook for subagent execution.""" def __init__(self, task_id: str) -> None: self._task_id = task_id @@ -138,7 +135,7 @@ class SubagentManager: tools=tools, model=self.model, max_iterations=15, - hook=SubagentHook(task_id), + hook=_SubagentHook(task_id), max_iterations_message="Task completed but no final response was generated.", error_message=None, fail_on_tool_error=True, From 7fad14802e77983176a6c60649fcf3ff63ecc1ab Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Mon, 30 Mar 2026 18:46:11 +0000 Subject: [PATCH 13/22] feat: add Python SDK facade and per-session isolation --- README.md | 53 ++++++++--- core_agent_lines.sh | 7 +- docs/PYTHON_SDK.md | 136 ++++++++++++++++++++++++++++ nanobot/__init__.py | 4 + nanobot/api/server.py | 21 +++-- nanobot/nanobot.py | 170 +++++++++++++++++++++++++++++++++++ tests/test_nanobot_facade.py | 147 ++++++++++++++++++++++++++++++ 7 files changed, 515 insertions(+), 23 deletions(-) create mode 100644 docs/PYTHON_SDK.md create mode 100644 nanobot/nanobot.py create mode 100644 tests/test_nanobot_facade.py diff --git a/README.md b/README.md index 01bc11c25..8a8c864d0 100644 --- a/README.md +++ b/README.md @@ -115,6 +115,7 @@ - [Configuration](#️-configuration) - [Multiple Instances](#-multiple-instances) - [CLI Reference](#-cli-reference) +- [Python SDK](#-python-sdk) - [OpenAI-Compatible API](#-openai-compatible-api) - [Docker](#-docker) - [Linux Service](#-linux-service) @@ -1571,6 +1572,40 @@ The agent can also manage this file itself — ask it to "add a periodic task" a +## 🐍 Python SDK + +Use nanobot as a library — no CLI, no gateway, just Python: + +```python +from nanobot import Nanobot + +bot = Nanobot.from_config() +result = await bot.run("Summarize the README") +print(result.content) +``` + +Each call carries a `session_key` for conversation isolation — different keys get independent history: + +```python +await bot.run("hi", session_key="user-alice") +await bot.run("hi", session_key="task-42") +``` + +Add lifecycle hooks to observe or customize the agent: + +```python +from nanobot.agent import AgentHook, AgentHookContext + +class AuditHook(AgentHook): + async def before_execute_tools(self, ctx: AgentHookContext) -> None: + for tc in ctx.tool_calls: + print(f"[tool] {tc.name}") + +result = await bot.run("Hello", hooks=[AuditHook()]) +``` + +See [docs/PYTHON_SDK.md](docs/PYTHON_SDK.md) for the full SDK reference. + ## 🔌 OpenAI-Compatible API nanobot can expose a minimal OpenAI-compatible endpoint for local integrations: @@ -1580,11 +1615,11 @@ pip install "nanobot-ai[api]" nanobot serve ``` -By default, the API binds to `127.0.0.1:8900`. +By default, the API binds to `127.0.0.1:8900`. You can change this in `config.json`. ### Behavior -- Fixed session: all requests share the same nanobot session (`api:default`) +- Session isolation: pass `"session_id"` in the request body to isolate conversations; omit for a shared default session (`api:default`) - Single-message input: each request must contain exactly one `user` message - Fixed model: omit `model`, or pass the same model shown by `/v1/models` - No streaming: `stream=true` is not supported @@ -1601,12 +1636,8 @@ By default, the API binds to `127.0.0.1:8900`. curl http://127.0.0.1:8900/v1/chat/completions \ -H "Content-Type: application/json" \ -d '{ - "messages": [ - { - "role": "user", - "content": "hi" - } - ] + "messages": [{"role": "user", "content": "hi"}], + "session_id": "my-session" }' ``` @@ -1618,9 +1649,8 @@ import requests resp = requests.post( "http://127.0.0.1:8900/v1/chat/completions", json={ - "messages": [ - {"role": "user", "content": "hi"} - ] + "messages": [{"role": "user", "content": "hi"}], + "session_id": "my-session", # optional: isolate conversation }, timeout=120, ) @@ -1641,6 +1671,7 @@ client = OpenAI( resp = client.chat.completions.create( model="MiniMax-M2.7", messages=[{"role": "user", "content": "hi"}], + extra_body={"session_id": "my-session"}, # optional: isolate conversation ) print(resp.choices[0].message.content) ``` diff --git a/core_agent_lines.sh b/core_agent_lines.sh index 90f39aacc..0891347d5 100755 --- a/core_agent_lines.sh +++ b/core_agent_lines.sh @@ -1,5 +1,6 @@ #!/bin/bash -# Count core agent lines (excluding channels/, cli/, api/, providers/ adapters) +# Count core agent lines (excluding channels/, cli/, api/, providers/ adapters, +# and the high-level Python SDK facade) cd "$(dirname "$0")" || exit 1 echo "nanobot core agent line count" @@ -15,7 +16,7 @@ root=$(cat nanobot/__init__.py nanobot/__main__.py | wc -l) printf " %-16s %5s lines\n" "(root)" "$root" echo "" -total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/api/*" ! -path "*/command/*" ! -path "*/providers/*" ! -path "*/skills/*" | xargs cat | wc -l) +total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/api/*" ! -path "*/command/*" ! -path "*/providers/*" ! -path "*/skills/*" ! -path "nanobot/nanobot.py" | xargs cat | wc -l) echo " Core total: $total lines" echo "" -echo " (excludes: channels/, cli/, api/, command/, providers/, skills/)" +echo " (excludes: channels/, cli/, api/, command/, providers/, skills/, nanobot.py)" diff --git a/docs/PYTHON_SDK.md b/docs/PYTHON_SDK.md new file mode 100644 index 000000000..357722e5e --- /dev/null +++ b/docs/PYTHON_SDK.md @@ -0,0 +1,136 @@ +# Python SDK + +Use nanobot programmatically — load config, run the agent, get results. + +## Quick Start + +```python +import asyncio +from nanobot import Nanobot + +async def main(): + bot = Nanobot.from_config() + result = await bot.run("What time is it in Tokyo?") + print(result.content) + +asyncio.run(main()) +``` + +## API + +### `Nanobot.from_config(config_path?, *, workspace?)` + +Create a `Nanobot` from a config file. + +| Param | Type | Default | Description | +|-------|------|---------|-------------| +| `config_path` | `str \| Path \| None` | `None` | Path to `config.json`. Defaults to `~/.nanobot/config.json`. | +| `workspace` | `str \| Path \| None` | `None` | Override workspace directory from config. | + +Raises `FileNotFoundError` if an explicit path doesn't exist. + +### `await bot.run(message, *, session_key?, hooks?)` + +Run the agent once. Returns a `RunResult`. + +| Param | Type | Default | Description | +|-------|------|---------|-------------| +| `message` | `str` | *(required)* | The user message to process. | +| `session_key` | `str` | `"sdk:default"` | Session identifier for conversation isolation. Different keys get independent history. | +| `hooks` | `list[AgentHook] \| None` | `None` | Lifecycle hooks for this run only. | + +```python +# Isolated sessions — each user gets independent conversation history +await bot.run("hi", session_key="user-alice") +await bot.run("hi", session_key="user-bob") +``` + +### `RunResult` + +| Field | Type | Description | +|-------|------|-------------| +| `content` | `str` | The agent's final text response. | +| `tools_used` | `list[str]` | Tool names invoked during the run. | +| `messages` | `list[dict]` | Raw message history (for debugging). | + +## Hooks + +Hooks let you observe or modify the agent loop without touching internals. + +Subclass `AgentHook` and override any method: + +| Method | When | +|--------|------| +| `before_iteration(ctx)` | Before each LLM call | +| `on_stream(ctx, delta)` | On each streamed token | +| `on_stream_end(ctx)` | When streaming finishes | +| `before_execute_tools(ctx)` | Before tool execution (inspect `ctx.tool_calls`) | +| `after_iteration(ctx, response)` | After each LLM response | +| `finalize_content(ctx, content)` | Transform final output text | + +### Example: Audit Hook + +```python +from nanobot.agent import AgentHook, AgentHookContext + +class AuditHook(AgentHook): + def __init__(self): + self.calls = [] + + async def before_execute_tools(self, ctx: AgentHookContext) -> None: + for tc in ctx.tool_calls: + self.calls.append(tc.name) + print(f"[audit] {tc.name}({tc.arguments})") + +hook = AuditHook() +result = await bot.run("List files in /tmp", hooks=[hook]) +print(f"Tools used: {hook.calls}") +``` + +### Composing Hooks + +Pass multiple hooks — they run in order, errors in one don't block others: + +```python +result = await bot.run("hi", hooks=[AuditHook(), MetricsHook()]) +``` + +Under the hood this uses `CompositeHook` for fan-out with error isolation. + +### `finalize_content` Pipeline + +Unlike the async methods (fan-out), `finalize_content` is a pipeline — each hook's output feeds the next: + +```python +class Censor(AgentHook): + def finalize_content(self, ctx, content): + return content.replace("secret", "***") if content else content +``` + +## Full Example + +```python +import asyncio +from nanobot import Nanobot +from nanobot.agent import AgentHook, AgentHookContext + +class TimingHook(AgentHook): + async def before_iteration(self, ctx: AgentHookContext) -> None: + import time + ctx.metadata["_t0"] = time.time() + + async def after_iteration(self, ctx, response) -> None: + import time + elapsed = time.time() - ctx.metadata.get("_t0", 0) + print(f"[timing] iteration took {elapsed:.2f}s") + +async def main(): + bot = Nanobot.from_config(workspace="/my/project") + result = await bot.run( + "Explain the main function", + hooks=[TimingHook()], + ) + print(result.content) + +asyncio.run(main()) +``` diff --git a/nanobot/__init__.py b/nanobot/__init__.py index 07efd09cf..11833c696 100644 --- a/nanobot/__init__.py +++ b/nanobot/__init__.py @@ -4,3 +4,7 @@ nanobot - A lightweight AI agent framework __version__ = "0.1.4.post6" __logo__ = "🐈" + +from nanobot.nanobot import Nanobot, RunResult + +__all__ = ["Nanobot", "RunResult"] diff --git a/nanobot/api/server.py b/nanobot/api/server.py index 34b73ad57..9494b6e31 100644 --- a/nanobot/api/server.py +++ b/nanobot/api/server.py @@ -91,9 +91,12 @@ async def handle_chat_completions(request: web.Request) -> web.Response: model_name: str = request.app.get("model_name", "nanobot") if (requested_model := body.get("model")) and requested_model != model_name: return _error_json(400, f"Only configured model '{model_name}' is available") - session_lock: asyncio.Lock = request.app["session_lock"] - logger.info("API request session_key={} content={}", API_SESSION_KEY, user_content[:80]) + session_key = f"api:{body['session_id']}" if body.get("session_id") else API_SESSION_KEY + session_locks: dict[str, asyncio.Lock] = request.app["session_locks"] + session_lock = session_locks.setdefault(session_key, asyncio.Lock()) + + logger.info("API request session_key={} content={}", session_key, user_content[:80]) _FALLBACK = "I've completed processing but have no response to give." @@ -103,7 +106,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response: response = await asyncio.wait_for( agent_loop.process_direct( content=user_content, - session_key=API_SESSION_KEY, + session_key=session_key, channel="api", chat_id=API_CHAT_ID, ), @@ -114,12 +117,12 @@ async def handle_chat_completions(request: web.Request) -> web.Response: if not response_text or not response_text.strip(): logger.warning( "Empty response for session {}, retrying", - API_SESSION_KEY, + session_key, ) retry_response = await asyncio.wait_for( agent_loop.process_direct( content=user_content, - session_key=API_SESSION_KEY, + session_key=session_key, channel="api", chat_id=API_CHAT_ID, ), @@ -129,17 +132,17 @@ async def handle_chat_completions(request: web.Request) -> web.Response: if not response_text or not response_text.strip(): logger.warning( "Empty response after retry for session {}, using fallback", - API_SESSION_KEY, + session_key, ) response_text = _FALLBACK except asyncio.TimeoutError: return _error_json(504, f"Request timed out after {timeout_s}s") except Exception: - logger.exception("Error processing request for session {}", API_SESSION_KEY) + logger.exception("Error processing request for session {}", session_key) return _error_json(500, "Internal server error", err_type="server_error") except Exception: - logger.exception("Unexpected API lock error for session {}", API_SESSION_KEY) + logger.exception("Unexpected API lock error for session {}", session_key) return _error_json(500, "Internal server error", err_type="server_error") return web.json_response(_chat_completion_response(response_text, model_name)) @@ -182,7 +185,7 @@ def create_app(agent_loop, model_name: str = "nanobot", request_timeout: float = app["agent_loop"] = agent_loop app["model_name"] = model_name app["request_timeout"] = request_timeout - app["session_lock"] = asyncio.Lock() + app["session_locks"] = {} # per-user locks, keyed by session_key app.router.add_post("/v1/chat/completions", handle_chat_completions) app.router.add_get("/v1/models", handle_models) diff --git a/nanobot/nanobot.py b/nanobot/nanobot.py new file mode 100644 index 000000000..137688455 --- /dev/null +++ b/nanobot/nanobot.py @@ -0,0 +1,170 @@ +"""High-level programmatic interface to nanobot.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from nanobot.agent.hook import AgentHook +from nanobot.agent.loop import AgentLoop +from nanobot.bus.queue import MessageBus + + +@dataclass(slots=True) +class RunResult: + """Result of a single agent run.""" + + content: str + tools_used: list[str] + messages: list[dict[str, Any]] + + +class Nanobot: + """Programmatic facade for running the nanobot agent. + + Usage:: + + bot = Nanobot.from_config() + result = await bot.run("Summarize this repo", hooks=[MyHook()]) + print(result.content) + """ + + def __init__(self, loop: AgentLoop) -> None: + self._loop = loop + + @classmethod + def from_config( + cls, + config_path: str | Path | None = None, + *, + workspace: str | Path | None = None, + ) -> Nanobot: + """Create a Nanobot instance from a config file. + + Args: + config_path: Path to ``config.json``. Defaults to + ``~/.nanobot/config.json``. + workspace: Override the workspace directory from config. + """ + from nanobot.config.loader import load_config + from nanobot.config.schema import Config + + resolved: Path | None = None + if config_path is not None: + resolved = Path(config_path).expanduser().resolve() + if not resolved.exists(): + raise FileNotFoundError(f"Config not found: {resolved}") + + config: Config = load_config(resolved) + if workspace is not None: + config.agents.defaults.workspace = str( + Path(workspace).expanduser().resolve() + ) + + provider = _make_provider(config) + bus = MessageBus() + defaults = config.agents.defaults + + loop = AgentLoop( + bus=bus, + provider=provider, + workspace=config.workspace_path, + model=defaults.model, + max_iterations=defaults.max_tool_iterations, + context_window_tokens=defaults.context_window_tokens, + web_search_config=config.tools.web.search, + web_proxy=config.tools.web.proxy or None, + exec_config=config.tools.exec, + restrict_to_workspace=config.tools.restrict_to_workspace, + mcp_servers=config.tools.mcp_servers, + timezone=defaults.timezone, + ) + return cls(loop) + + async def run( + self, + message: str, + *, + session_key: str = "sdk:default", + hooks: list[AgentHook] | None = None, + ) -> RunResult: + """Run the agent once and return the result. + + Args: + message: The user message to process. + session_key: Session identifier for conversation isolation. + Different keys get independent history. + hooks: Optional lifecycle hooks for this run. + """ + prev = self._loop._extra_hooks + if hooks is not None: + self._loop._extra_hooks = list(hooks) + try: + response = await self._loop.process_direct( + message, session_key=session_key, + ) + finally: + self._loop._extra_hooks = prev + + content = (response.content if response else None) or "" + return RunResult(content=content, tools_used=[], messages=[]) + + +def _make_provider(config: Any) -> Any: + """Create the LLM provider from config (extracted from CLI).""" + from nanobot.providers.base import GenerationSettings + from nanobot.providers.registry import find_by_name + + model = config.agents.defaults.model + provider_name = config.get_provider_name(model) + p = config.get_provider(model) + spec = find_by_name(provider_name) if provider_name else None + backend = spec.backend if spec else "openai_compat" + + if backend == "azure_openai": + if not p or not p.api_key or not p.api_base: + raise ValueError("Azure OpenAI requires api_key and api_base in config.") + elif backend == "openai_compat" and not model.startswith("bedrock/"): + needs_key = not (p and p.api_key) + exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct) + if needs_key and not exempt: + raise ValueError(f"No API key configured for provider '{provider_name}'.") + + if backend == "openai_codex": + from nanobot.providers.openai_codex_provider import OpenAICodexProvider + + provider = OpenAICodexProvider(default_model=model) + elif backend == "azure_openai": + from nanobot.providers.azure_openai_provider import AzureOpenAIProvider + + provider = AzureOpenAIProvider( + api_key=p.api_key, api_base=p.api_base, default_model=model + ) + elif backend == "anthropic": + from nanobot.providers.anthropic_provider import AnthropicProvider + + provider = AnthropicProvider( + api_key=p.api_key if p else None, + api_base=config.get_api_base(model), + default_model=model, + extra_headers=p.extra_headers if p else None, + ) + else: + from nanobot.providers.openai_compat_provider import OpenAICompatProvider + + provider = OpenAICompatProvider( + api_key=p.api_key if p else None, + api_base=config.get_api_base(model), + default_model=model, + extra_headers=p.extra_headers if p else None, + spec=spec, + ) + + defaults = config.agents.defaults + provider.generation = GenerationSettings( + temperature=defaults.temperature, + max_tokens=defaults.max_tokens, + reasoning_effort=defaults.reasoning_effort, + ) + return provider diff --git a/tests/test_nanobot_facade.py b/tests/test_nanobot_facade.py new file mode 100644 index 000000000..9d0d8a175 --- /dev/null +++ b/tests/test_nanobot_facade.py @@ -0,0 +1,147 @@ +"""Tests for the Nanobot programmatic facade.""" + +from __future__ import annotations + +import json +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from nanobot.nanobot import Nanobot, RunResult + + +def _write_config(tmp_path: Path, overrides: dict | None = None) -> Path: + data = { + "providers": {"openrouter": {"apiKey": "sk-test-key"}}, + "agents": {"defaults": {"model": "openai/gpt-4.1"}}, + } + if overrides: + data.update(overrides) + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps(data)) + return config_path + + +def test_from_config_missing_file(): + with pytest.raises(FileNotFoundError): + Nanobot.from_config("/nonexistent/config.json") + + +def test_from_config_creates_instance(tmp_path): + config_path = _write_config(tmp_path) + bot = Nanobot.from_config(config_path, workspace=tmp_path) + assert bot._loop is not None + assert bot._loop.workspace == tmp_path + + +def test_from_config_default_path(): + from nanobot.config.schema import Config + + with patch("nanobot.config.loader.load_config") as mock_load, \ + patch("nanobot.nanobot._make_provider") as mock_prov: + mock_load.return_value = Config() + mock_prov.return_value = MagicMock() + mock_prov.return_value.get_default_model.return_value = "test" + mock_prov.return_value.generation.max_tokens = 4096 + Nanobot.from_config() + mock_load.assert_called_once_with(None) + + +@pytest.mark.asyncio +async def test_run_returns_result(tmp_path): + config_path = _write_config(tmp_path) + bot = Nanobot.from_config(config_path, workspace=tmp_path) + + from nanobot.bus.events import OutboundMessage + + mock_response = OutboundMessage( + channel="cli", chat_id="direct", content="Hello back!" + ) + bot._loop.process_direct = AsyncMock(return_value=mock_response) + + result = await bot.run("hi") + + assert isinstance(result, RunResult) + assert result.content == "Hello back!" + bot._loop.process_direct.assert_awaited_once_with("hi", session_key="sdk:default") + + +@pytest.mark.asyncio +async def test_run_with_hooks(tmp_path): + from nanobot.agent.hook import AgentHook, AgentHookContext + from nanobot.bus.events import OutboundMessage + + config_path = _write_config(tmp_path) + bot = Nanobot.from_config(config_path, workspace=tmp_path) + + class TestHook(AgentHook): + async def before_iteration(self, context: AgentHookContext) -> None: + pass + + mock_response = OutboundMessage( + channel="cli", chat_id="direct", content="done" + ) + bot._loop.process_direct = AsyncMock(return_value=mock_response) + + result = await bot.run("hi", hooks=[TestHook()]) + + assert result.content == "done" + assert bot._loop._extra_hooks == [] + + +@pytest.mark.asyncio +async def test_run_hooks_restored_on_error(tmp_path): + config_path = _write_config(tmp_path) + bot = Nanobot.from_config(config_path, workspace=tmp_path) + + from nanobot.agent.hook import AgentHook + + bot._loop.process_direct = AsyncMock(side_effect=RuntimeError("boom")) + original_hooks = bot._loop._extra_hooks + + with pytest.raises(RuntimeError): + await bot.run("hi", hooks=[AgentHook()]) + + assert bot._loop._extra_hooks is original_hooks + + +@pytest.mark.asyncio +async def test_run_none_response(tmp_path): + config_path = _write_config(tmp_path) + bot = Nanobot.from_config(config_path, workspace=tmp_path) + bot._loop.process_direct = AsyncMock(return_value=None) + + result = await bot.run("hi") + assert result.content == "" + + +def test_workspace_override(tmp_path): + config_path = _write_config(tmp_path) + custom_ws = tmp_path / "custom_workspace" + custom_ws.mkdir() + + bot = Nanobot.from_config(config_path, workspace=custom_ws) + assert bot._loop.workspace == custom_ws + + +@pytest.mark.asyncio +async def test_run_custom_session_key(tmp_path): + from nanobot.bus.events import OutboundMessage + + config_path = _write_config(tmp_path) + bot = Nanobot.from_config(config_path, workspace=tmp_path) + + mock_response = OutboundMessage( + channel="cli", chat_id="direct", content="ok" + ) + bot._loop.process_direct = AsyncMock(return_value=mock_response) + + await bot.run("hi", session_key="user-alice") + bot._loop.process_direct.assert_awaited_once_with("hi", session_key="user-alice") + + +def test_import_from_top_level(): + from nanobot import Nanobot as N, RunResult as R + assert N is Nanobot + assert R is RunResult From 8682b017e25af0eaf658d8b862222efb13a9b1e0 Mon Sep 17 00:00:00 2001 From: 04cb <0x04cb@gmail.com> Date: Tue, 31 Mar 2026 08:53:35 +0800 Subject: [PATCH 14/22] fix(tools): add Accept header for MCP SSE connections (#2651) --- nanobot/agent/tools/mcp.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/nanobot/agent/tools/mcp.py b/nanobot/agent/tools/mcp.py index c1c3e79a2..51533333e 100644 --- a/nanobot/agent/tools/mcp.py +++ b/nanobot/agent/tools/mcp.py @@ -170,7 +170,11 @@ async def connect_mcp_servers( timeout: httpx.Timeout | None = None, auth: httpx.Auth | None = None, ) -> httpx.AsyncClient: - merged_headers = {**(cfg.headers or {}), **(headers or {})} + merged_headers = { + "Accept": "application/json, text/event-stream", + **(cfg.headers or {}), + **(headers or {}), + } return httpx.AsyncClient( headers=merged_headers or None, follow_redirects=True, From 3f21e83af8056dcdb682cc7eee0a10b667460da1 Mon Sep 17 00:00:00 2001 From: 04cb <0x04cb@gmail.com> Date: Tue, 31 Mar 2026 08:53:39 +0800 Subject: [PATCH 15/22] fix(tools): clarify cron message param as agent instruction (#2566) --- nanobot/agent/tools/cron.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nanobot/agent/tools/cron.py b/nanobot/agent/tools/cron.py index 9989af55f..00f726c08 100644 --- a/nanobot/agent/tools/cron.py +++ b/nanobot/agent/tools/cron.py @@ -74,7 +74,7 @@ class CronTool(Tool): "enum": ["add", "list", "remove"], "description": "Action to perform", }, - "message": {"type": "string", "description": "Reminder message (for add)"}, + "message": {"type": "string", "description": "Instruction for the agent to execute when the job triggers (e.g., 'Send a reminder to WeChat: xxx' or 'Check system status and report')"}, "every_seconds": { "type": "integer", "description": "Interval in seconds (for recurring tasks)", From 929ee094995f716bfa9cff6d69cdd5b1bd6dd7d9 Mon Sep 17 00:00:00 2001 From: 04cb <0x04cb@gmail.com> Date: Tue, 31 Mar 2026 08:53:44 +0800 Subject: [PATCH 16/22] fix(utils): ensure reasoning_content present with thinking_blocks (#2579) --- nanobot/utils/helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index a10a4f18b..a7c2c2574 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ -124,8 +124,8 @@ def build_assistant_message( msg: dict[str, Any] = {"role": "assistant", "content": content} if tool_calls: msg["tool_calls"] = tool_calls - if reasoning_content is not None: - msg["reasoning_content"] = reasoning_content + if reasoning_content is not None or thinking_blocks: + msg["reasoning_content"] = reasoning_content if reasoning_content is not None else "" if thinking_blocks: msg["thinking_blocks"] = thinking_blocks return msg From c3c1424db35e1158377c8d2beb7168d3dd104573 Mon Sep 17 00:00:00 2001 From: "zhangxiaoyu.york" Date: Tue, 31 Mar 2026 00:09:01 +0800 Subject: [PATCH 17/22] fix:register exec when enable exec_config --- nanobot/agent/subagent.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/nanobot/agent/subagent.py b/nanobot/agent/subagent.py index c1aaa2d0d..9d936f034 100644 --- a/nanobot/agent/subagent.py +++ b/nanobot/agent/subagent.py @@ -115,12 +115,13 @@ class SubagentManager: tools.register(WriteFileTool(workspace=self.workspace, allowed_dir=allowed_dir)) tools.register(EditFileTool(workspace=self.workspace, allowed_dir=allowed_dir)) tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir)) - tools.register(ExecTool( - working_dir=str(self.workspace), - timeout=self.exec_config.timeout, - restrict_to_workspace=self.restrict_to_workspace, - path_append=self.exec_config.path_append, - )) + if self.exec_config.enable: + tools.register(ExecTool( + working_dir=str(self.workspace), + timeout=self.exec_config.timeout, + restrict_to_workspace=self.restrict_to_workspace, + path_append=self.exec_config.path_append, + )) tools.register(WebSearchTool(config=self.web_search_config, proxy=self.web_proxy)) tools.register(WebFetchTool(proxy=self.web_proxy)) From 351e3720b6c65ab12b4eba4fd2eb859c0096042a Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Tue, 31 Mar 2026 04:11:54 +0000 Subject: [PATCH 18/22] test(agent): cover disabled subagent exec tool Add a regression test for the maintainer fix so subagents cannot register ExecTool when exec support is disabled. Made-with: Cursor --- tests/agent/test_task_cancel.py | 34 +++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/agent/test_task_cancel.py b/tests/agent/test_task_cancel.py index 8894cd973..4902a4c80 100644 --- a/tests/agent/test_task_cancel.py +++ b/tests/agent/test_task_cancel.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -222,6 +223,39 @@ class TestSubagentCancellation: assert assistant_messages[0]["reasoning_content"] == "hidden reasoning" assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}] + @pytest.mark.asyncio + async def test_subagent_exec_tool_not_registered_when_disabled(self, tmp_path): + from nanobot.agent.subagent import SubagentManager + from nanobot.bus.queue import MessageBus + from nanobot.config.schema import ExecToolConfig + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + mgr = SubagentManager( + provider=provider, + workspace=tmp_path, + bus=bus, + exec_config=ExecToolConfig(enable=False), + ) + mgr._announce_result = AsyncMock() + + async def fake_run(spec): + assert spec.tools.get("exec") is None + return SimpleNamespace( + stop_reason="done", + final_content="done", + error=None, + tool_events=[], + ) + + mgr.runner.run = AsyncMock(side_effect=fake_run) + + await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"}) + + mgr.runner.run.assert_awaited_once() + mgr._announce_result.assert_awaited_once() + @pytest.mark.asyncio async def test_subagent_announces_error_when_tool_execution_fails(self, monkeypatch, tmp_path): from nanobot.agent.subagent import SubagentManager From b94d4c0509e1d273703a5fb2c05f3b6e630e5668 Mon Sep 17 00:00:00 2001 From: npodbielski Date: Fri, 27 Mar 2026 08:12:14 +0100 Subject: [PATCH 19/22] feat(matrix): streaming support (#2447) * Added streaming message support with incremental updates for Matrix channel * Improve Matrix message handling and add tests * Adjust Matrix streaming edit interval to 2 seconds --------- Co-authored-by: natan --- nanobot/channels/matrix.py | 107 +++++++++++- tests/channels/test_matrix_channel.py | 225 +++++++++++++++++++++++++- 2 files changed, 323 insertions(+), 9 deletions(-) diff --git a/nanobot/channels/matrix.py b/nanobot/channels/matrix.py index 98926735e..dcece1043 100644 --- a/nanobot/channels/matrix.py +++ b/nanobot/channels/matrix.py @@ -3,6 +3,8 @@ import asyncio import logging import mimetypes +import time +from dataclasses import dataclass from pathlib import Path from typing import Any, Literal, TypeAlias @@ -28,8 +30,8 @@ try: RoomSendError, RoomTypingError, SyncError, - UploadError, - ) + UploadError, RoomSendResponse, +) from nio.crypto.attachments import decrypt_attachment from nio.exceptions import EncryptionError except ImportError as e: @@ -97,6 +99,22 @@ MATRIX_HTML_CLEANER = nh3.Cleaner( link_rel="noopener noreferrer", ) +@dataclass +class _StreamBuf: + """ + Represents a buffer for managing LLM response stream data. + + :ivar text: Stores the text content of the buffer. + :type text: str + :ivar event_id: Identifier for the associated event. None indicates no + specific event association. + :type event_id: str | None + :ivar last_edit: Timestamp of the most recent edit to the buffer. + :type last_edit: float + """ + text: str = "" + event_id: str | None = None + last_edit: float = 0.0 def _render_markdown_html(text: str) -> str | None: """Render markdown to sanitized HTML; returns None for plain text.""" @@ -114,12 +132,36 @@ def _render_markdown_html(text: str) -> str | None: return formatted -def _build_matrix_text_content(text: str) -> dict[str, object]: - """Build Matrix m.text payload with optional HTML formatted_body.""" +def _build_matrix_text_content(text: str, event_id: str | None = None) -> dict[str, object]: + """ + Constructs and returns a dictionary representing the matrix text content with optional + HTML formatting and reference to an existing event for replacement. This function is + primarily used to create content payloads compatible with the Matrix messaging protocol. + + :param text: The plain text content to include in the message. + :type text: str + :param event_id: Optional ID of the event to replace. If provided, the function will + include information indicating that the message is a replacement of the specified + event. + :type event_id: str | None + :return: A dictionary containing the matrix text content, potentially enriched with + HTML formatting and replacement metadata if applicable. + :rtype: dict[str, object] + """ content: dict[str, object] = {"msgtype": "m.text", "body": text, "m.mentions": {}} if html := _render_markdown_html(text): content["format"] = MATRIX_HTML_FORMAT content["formatted_body"] = html + if event_id: + content["m.new_content"] = { + "body": text, + "msgtype": "m.text" + } + content["m.relates_to"] = { + "rel_type": "m.replace", + "event_id": event_id + } + return content @@ -159,7 +201,8 @@ class MatrixConfig(Base): allow_from: list[str] = Field(default_factory=list) group_policy: Literal["open", "mention", "allowlist"] = "open" group_allow_from: list[str] = Field(default_factory=list) - allow_room_mentions: bool = False + allow_room_mentions: bool = False, + streaming: bool = False class MatrixChannel(BaseChannel): @@ -167,6 +210,8 @@ class MatrixChannel(BaseChannel): name = "matrix" display_name = "Matrix" + _STREAM_EDIT_INTERVAL = 2 # min seconds between edit_message_text calls + monotonic_time = time.monotonic @classmethod def default_config(cls) -> dict[str, Any]: @@ -192,6 +237,8 @@ class MatrixChannel(BaseChannel): ) self._server_upload_limit_bytes: int | None = None self._server_upload_limit_checked = False + self._stream_bufs: dict[str, _StreamBuf] = {} + async def start(self) -> None: """Start Matrix client and begin sync loop.""" @@ -297,14 +344,17 @@ class MatrixChannel(BaseChannel): room = getattr(self.client, "rooms", {}).get(room_id) return bool(getattr(room, "encrypted", False)) - async def _send_room_content(self, room_id: str, content: dict[str, Any]) -> None: + async def _send_room_content(self, room_id: str, + content: dict[str, Any]) -> None | RoomSendResponse | RoomSendError: """Send m.room.message with E2EE options.""" if not self.client: - return + return None kwargs: dict[str, Any] = {"room_id": room_id, "message_type": "m.room.message", "content": content} + if self.config.e2ee_enabled: kwargs["ignore_unverified_devices"] = True - await self.client.room_send(**kwargs) + response = await self.client.room_send(**kwargs) + return response async def _resolve_server_upload_limit_bytes(self) -> int | None: """Query homeserver upload limit once per channel lifecycle.""" @@ -414,6 +464,47 @@ class MatrixChannel(BaseChannel): if not is_progress: await self._stop_typing_keepalive(msg.chat_id, clear_typing=True) + async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None: + meta = metadata or {} + relates_to = self._build_thread_relates_to(metadata) + + if meta.get("_stream_end"): + buf = self._stream_bufs.pop(chat_id, None) + if not buf or not buf.event_id or not buf.text: + return + + await self._stop_typing_keepalive(chat_id, clear_typing=True) + + content = _build_matrix_text_content(buf.text, buf.event_id) + if relates_to: + content["m.relates_to"] = relates_to + await self._send_room_content(chat_id, content) + return + + buf = self._stream_bufs.get(chat_id) + if buf is None: + buf = _StreamBuf() + self._stream_bufs[chat_id] = buf + buf.text += delta + + if not buf.text.strip(): + return + + now = self.monotonic_time() + + if not buf.last_edit or (now - buf.last_edit) >= self._STREAM_EDIT_INTERVAL: + try: + content = _build_matrix_text_content(buf.text, buf.event_id) + response = await self._send_room_content(chat_id, content) + buf.last_edit = now + if not buf.event_id: + # we are editing the same message all the time, so only the first time the event id needs to be set + buf.event_id = response.event_id + except Exception: + await self._stop_typing_keepalive(metadata["room_id"], clear_typing=True) + pass + + def _register_event_callbacks(self) -> None: self.client.add_event_callback(self._on_message, RoomMessageText) self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER) diff --git a/tests/channels/test_matrix_channel.py b/tests/channels/test_matrix_channel.py index dd5e97d90..3ad65e76b 100644 --- a/tests/channels/test_matrix_channel.py +++ b/tests/channels/test_matrix_channel.py @@ -3,6 +3,9 @@ from pathlib import Path from types import SimpleNamespace import pytest +from nio import RoomSendResponse + +from nanobot.channels.matrix import _build_matrix_text_content # Check optional matrix dependencies before importing try: @@ -65,6 +68,7 @@ class _FakeAsyncClient: self.raise_on_send = False self.raise_on_typing = False self.raise_on_upload = False + self.room_send_response: RoomSendResponse | None = RoomSendResponse(event_id="", room_id="") def add_event_callback(self, callback, event_type) -> None: self.callbacks.append((callback, event_type)) @@ -87,7 +91,7 @@ class _FakeAsyncClient: message_type: str, content: dict[str, object], ignore_unverified_devices: object = _ROOM_SEND_UNSET, - ) -> None: + ) -> RoomSendResponse: call: dict[str, object] = { "room_id": room_id, "message_type": message_type, @@ -98,6 +102,7 @@ class _FakeAsyncClient: self.room_send_calls.append(call) if self.raise_on_send: raise RuntimeError("send failed") + return self.room_send_response async def room_typing( self, @@ -520,6 +525,7 @@ async def test_on_message_room_mention_requires_opt_in() -> None: source={"content": {"m.mentions": {"room": True}}}, ) + channel.config.allow_room_mentions = False await channel._on_message(room, room_mention_event) assert handled == [] assert client.typing_calls == [] @@ -1322,3 +1328,220 @@ async def test_send_keeps_plaintext_only_for_plain_text() -> None: "body": text, "m.mentions": {}, } + + +def test_build_matrix_text_content_basic_text() -> None: + """Test basic text content without HTML formatting.""" + result = _build_matrix_text_content("Hello, World!") + expected = { + "msgtype": "m.text", + "body": "Hello, World!", + "m.mentions": {} + } + assert expected == result + + +def test_build_matrix_text_content_with_markdown() -> None: + """Test text content with markdown that renders to HTML.""" + text = "*Hello* **World**" + result = _build_matrix_text_content(text) + assert "msgtype" in result + assert "body" in result + assert result["body"] == text + assert "format" in result + assert result["format"] == "org.matrix.custom.html" + assert "formatted_body" in result + assert isinstance(result["formatted_body"], str) + assert len(result["formatted_body"]) > 0 + + +def test_build_matrix_text_content_with_event_id() -> None: + """Test text content with event_id for message replacement.""" + event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo" + result = _build_matrix_text_content("Updated message", event_id) + assert "msgtype" in result + assert "body" in result + assert result["m.new_content"] + assert result["m.new_content"]["body"] == "Updated message" + assert result["m.relates_to"]["rel_type"] == "m.replace" + assert result["m.relates_to"]["event_id"] == event_id + + +def test_build_matrix_text_content_no_event_id() -> None: + """Test that when event_id is not provided, no extra properties are added.""" + result = _build_matrix_text_content("Regular message") + + # Basic required properties should be present + assert "msgtype" in result + assert "body" in result + assert result["body"] == "Regular message" + + # Extra properties for replacement should NOT be present + assert "m.relates_to" not in result + assert "m.new_content" not in result + assert "format" not in result + assert "formatted_body" not in result + + +def test_build_matrix_text_content_plain_text_no_html() -> None: + """Test plain text that should not include HTML formatting.""" + result = _build_matrix_text_content("Simple plain text") + assert "msgtype" in result + assert "body" in result + assert "format" not in result + assert "formatted_body" not in result + + +@pytest.mark.asyncio +async def test_send_room_content_returns_room_send_response(): + """Test that _send_room_content returns the response from client.room_send.""" + client = _FakeAsyncClient("", "", "", None) + channel = MatrixChannel(_make_config(), MessageBus()) + channel.client = client + + room_id = "!test_room:matrix.org" + content = {"msgtype": "m.text", "body": "Hello World"} + + result = await channel._send_room_content(room_id, content) + + assert result is client.room_send_response + + +@pytest.mark.asyncio +async def test_send_delta_creates_stream_buffer_and_sends_initial_message() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + client.room_send_response.event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo" + + await channel.send_delta("!room:matrix.org", "Hello") + + assert "!room:matrix.org" in channel._stream_bufs + buf = channel._stream_bufs["!room:matrix.org"] + assert buf.text == "Hello" + assert buf.event_id == "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo" + assert len(client.room_send_calls) == 1 + assert client.room_send_calls[0]["content"]["body"] == "Hello" + + +@pytest.mark.asyncio +async def test_send_delta_appends_without_sending_before_edit_interval(monkeypatch) -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + client.room_send_response.event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo" + + now = 100.0 + monkeypatch.setattr(channel, "monotonic_time", lambda: now) + + await channel.send_delta("!room:matrix.org", "Hello") + assert len(client.room_send_calls) == 1 + + await channel.send_delta("!room:matrix.org", " world") + assert len(client.room_send_calls) == 1 + + buf = channel._stream_bufs["!room:matrix.org"] + assert buf.text == "Hello world" + assert buf.event_id == "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo" + + +@pytest.mark.asyncio +async def test_send_delta_edits_again_after_interval(monkeypatch) -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + client.room_send_response.event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo" + + times = [100.0, 102.0, 104.0, 106.0, 108.0] + times.reverse() + monkeypatch.setattr(channel, "monotonic_time", lambda: times and times.pop()) + + await channel.send_delta("!room:matrix.org", "Hello") + await channel.send_delta("!room:matrix.org", " world") + + assert len(client.room_send_calls) == 2 + first_content = client.room_send_calls[0]["content"] + second_content = client.room_send_calls[1]["content"] + + assert "body" in first_content + assert first_content["body"] == "Hello" + assert "m.relates_to" not in first_content + + assert "body" in second_content + assert "m.relates_to" in second_content + assert second_content["body"] == "Hello world" + assert second_content["m.relates_to"] == { + "rel_type": "m.replace", + "event_id": "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo", + } + + +@pytest.mark.asyncio +async def test_send_delta_stream_end_replaces_existing_message() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + channel._stream_bufs["!room:matrix.org"] = matrix_module._StreamBuf( + text="Final text", + event_id="event-1", + last_edit=100.0, + ) + + await channel.send_delta("!room:matrix.org", "", {"_stream_end": True}) + + assert "!room:matrix.org" not in channel._stream_bufs + assert client.typing_calls[-1] == ("!room:matrix.org", False, TYPING_NOTICE_TIMEOUT_MS) + assert len(client.room_send_calls) == 1 + assert client.room_send_calls[0]["content"]["body"] == "Final text" + assert client.room_send_calls[0]["content"]["m.relates_to"] == { + "rel_type": "m.replace", + "event_id": "event-1", + } + + +@pytest.mark.asyncio +async def test_send_delta_stream_end_noop_when_buffer_missing() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + await channel.send_delta("!room:matrix.org", "", {"_stream_end": True}) + + assert client.room_send_calls == [] + assert client.typing_calls == [] + + +@pytest.mark.asyncio +async def test_send_delta_on_error_stops_typing(monkeypatch) -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + client.raise_on_send = True + channel.client = client + + now = 100.0 + monkeypatch.setattr(channel, "monotonic_time", lambda: now) + + await channel.send_delta("!room:matrix.org", "Hello", {"room_id": "!room:matrix.org"}) + + assert "!room:matrix.org" in channel._stream_bufs + assert channel._stream_bufs["!room:matrix.org"].text == "Hello" + assert len(client.room_send_calls) == 1 + + assert len(client.typing_calls) == 1 + + +@pytest.mark.asyncio +async def test_send_delta_ignores_whitespace_only_delta(monkeypatch) -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + now = 100.0 + monkeypatch.setattr(channel, "monotonic_time", lambda: now) + + await channel.send_delta("!room:matrix.org", " ") + + assert "!room:matrix.org" in channel._stream_bufs + assert channel._stream_bufs["!room:matrix.org"].text == " " + assert client.room_send_calls == [] \ No newline at end of file From 0506e6c1c1fe908bbfca46408f5c8ff3b3ba8ab9 Mon Sep 17 00:00:00 2001 From: Paresh Mathur Date: Fri, 27 Mar 2026 02:51:45 +0100 Subject: [PATCH 20/22] feat(discord): Use `discord.py` for stable discord channel (#2486) Co-authored-by: Pares Mathur Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- nanobot/channels/discord.py | 665 +++++++++++++----------- nanobot/command/builtin.py | 17 +- pyproject.toml | 3 + tests/channels/test_discord_channel.py | 676 +++++++++++++++++++++++++ 4 files changed, 1061 insertions(+), 300 deletions(-) create mode 100644 tests/channels/test_discord_channel.py diff --git a/nanobot/channels/discord.py b/nanobot/channels/discord.py index 82eafcc00..ef7d41d77 100644 --- a/nanobot/channels/discord.py +++ b/nanobot/channels/discord.py @@ -1,25 +1,37 @@ -"""Discord channel implementation using Discord Gateway websocket.""" +"""Discord channel implementation using discord.py.""" + +from __future__ import annotations import asyncio -import json +import importlib.util from pathlib import Path -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal -import httpx -from pydantic import Field -import websockets from loguru import logger +from pydantic import Field from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel +from nanobot.command.builtin import build_help_text from nanobot.config.paths import get_media_dir from nanobot.config.schema import Base -from nanobot.utils.helpers import split_message +from nanobot.utils.helpers import safe_filename, split_message + +DISCORD_AVAILABLE = importlib.util.find_spec("discord") is not None +if TYPE_CHECKING: + import discord + from discord import app_commands + from discord.abc import Messageable + +if DISCORD_AVAILABLE: + import discord + from discord import app_commands + from discord.abc import Messageable -DISCORD_API_BASE = "https://discord.com/api/v10" MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB MAX_MESSAGE_LEN = 2000 # Discord message character limit +TYPING_INTERVAL_S = 8 class DiscordConfig(Base): @@ -28,13 +40,202 @@ class DiscordConfig(Base): enabled: bool = False token: str = "" allow_from: list[str] = Field(default_factory=list) - gateway_url: str = "wss://gateway.discord.gg/?v=10&encoding=json" intents: int = 37377 group_policy: Literal["mention", "open"] = "mention" +if DISCORD_AVAILABLE: + + class DiscordBotClient(discord.Client): + """discord.py client that forwards events to the channel.""" + + def __init__(self, channel: DiscordChannel, *, intents: discord.Intents) -> None: + super().__init__(intents=intents) + self._channel = channel + self.tree = app_commands.CommandTree(self) + self._register_app_commands() + + async def on_ready(self) -> None: + self._channel._bot_user_id = str(self.user.id) if self.user else None + logger.info("Discord bot connected as user {}", self._channel._bot_user_id) + try: + synced = await self.tree.sync() + logger.info("Discord app commands synced: {}", len(synced)) + except Exception as e: + logger.warning("Discord app command sync failed: {}", e) + + async def on_message(self, message: discord.Message) -> None: + await self._channel._handle_discord_message(message) + + async def _reply_ephemeral(self, interaction: discord.Interaction, text: str) -> bool: + """Send an ephemeral interaction response and report success.""" + try: + await interaction.response.send_message(text, ephemeral=True) + return True + except Exception as e: + logger.warning("Discord interaction response failed: {}", e) + return False + + async def _forward_slash_command( + self, + interaction: discord.Interaction, + command_text: str, + ) -> None: + sender_id = str(interaction.user.id) + channel_id = interaction.channel_id + + if channel_id is None: + logger.warning("Discord slash command missing channel_id: {}", command_text) + return + + if not self._channel.is_allowed(sender_id): + await self._reply_ephemeral(interaction, "You are not allowed to use this bot.") + return + + await self._reply_ephemeral(interaction, f"Processing {command_text}...") + + await self._channel._handle_message( + sender_id=sender_id, + chat_id=str(channel_id), + content=command_text, + metadata={ + "interaction_id": str(interaction.id), + "guild_id": str(interaction.guild_id) if interaction.guild_id else None, + "is_slash_command": True, + }, + ) + + def _register_app_commands(self) -> None: + commands = ( + ("new", "Start a new conversation", "/new"), + ("stop", "Stop the current task", "/stop"), + ("restart", "Restart the bot", "/restart"), + ("status", "Show bot status", "/status"), + ) + + for name, description, command_text in commands: + @self.tree.command(name=name, description=description) + async def command_handler( + interaction: discord.Interaction, + _command_text: str = command_text, + ) -> None: + await self._forward_slash_command(interaction, _command_text) + + @self.tree.command(name="help", description="Show available commands") + async def help_command(interaction: discord.Interaction) -> None: + sender_id = str(interaction.user.id) + if not self._channel.is_allowed(sender_id): + await self._reply_ephemeral(interaction, "You are not allowed to use this bot.") + return + await self._reply_ephemeral(interaction, build_help_text()) + + @self.tree.error + async def on_app_command_error( + interaction: discord.Interaction, + error: app_commands.AppCommandError, + ) -> None: + command_name = interaction.command.qualified_name if interaction.command else "?" + logger.warning( + "Discord app command failed user={} channel={} cmd={} error={}", + interaction.user.id, + interaction.channel_id, + command_name, + error, + ) + + async def send_outbound(self, msg: OutboundMessage) -> None: + """Send a nanobot outbound message using Discord transport rules.""" + channel_id = int(msg.chat_id) + + channel = self.get_channel(channel_id) + if channel is None: + try: + channel = await self.fetch_channel(channel_id) + except Exception as e: + logger.warning("Discord channel {} unavailable: {}", msg.chat_id, e) + return + + reference, mention_settings = self._build_reply_context(channel, msg.reply_to) + sent_media = False + failed_media: list[str] = [] + + for index, media_path in enumerate(msg.media or []): + if await self._send_file( + channel, + media_path, + reference=reference if index == 0 else None, + mention_settings=mention_settings, + ): + sent_media = True + else: + failed_media.append(Path(media_path).name) + + for index, chunk in enumerate(self._build_chunks(msg.content or "", failed_media, sent_media)): + kwargs: dict[str, Any] = {"content": chunk} + if index == 0 and reference is not None and not sent_media: + kwargs["reference"] = reference + kwargs["allowed_mentions"] = mention_settings + await channel.send(**kwargs) + + async def _send_file( + self, + channel: Messageable, + file_path: str, + *, + reference: discord.PartialMessage | None, + mention_settings: discord.AllowedMentions, + ) -> bool: + """Send a file attachment via discord.py.""" + path = Path(file_path) + if not path.is_file(): + logger.warning("Discord file not found, skipping: {}", file_path) + return False + + if path.stat().st_size > MAX_ATTACHMENT_BYTES: + logger.warning("Discord file too large (>20MB), skipping: {}", path.name) + return False + + try: + kwargs: dict[str, Any] = {"file": discord.File(path)} + if reference is not None: + kwargs["reference"] = reference + kwargs["allowed_mentions"] = mention_settings + await channel.send(**kwargs) + logger.info("Discord file sent: {}", path.name) + return True + except Exception as e: + logger.error("Error sending Discord file {}: {}", path.name, e) + return False + + @staticmethod + def _build_chunks(content: str, failed_media: list[str], sent_media: bool) -> list[str]: + """Build outbound text chunks, including attachment-failure fallback text.""" + chunks = split_message(content, MAX_MESSAGE_LEN) + if chunks or not failed_media or sent_media: + return chunks + fallback = "\n".join(f"[attachment: {name} - send failed]" for name in failed_media) + return split_message(fallback, MAX_MESSAGE_LEN) + + @staticmethod + def _build_reply_context( + channel: Messageable, + reply_to: str | None, + ) -> tuple[discord.PartialMessage | None, discord.AllowedMentions]: + """Build reply context for outbound messages.""" + mention_settings = discord.AllowedMentions(replied_user=False) + if not reply_to: + return None, mention_settings + try: + message_id = int(reply_to) + except ValueError: + logger.warning("Invalid Discord reply target: {}", reply_to) + return None, mention_settings + + return channel.get_partial_message(message_id), mention_settings + + class DiscordChannel(BaseChannel): - """Discord channel using Gateway websocket.""" + """Discord channel using discord.py.""" name = "discord" display_name = "Discord" @@ -43,353 +244,229 @@ class DiscordChannel(BaseChannel): def default_config(cls) -> dict[str, Any]: return DiscordConfig().model_dump(by_alias=True) + @staticmethod + def _channel_key(channel_or_id: Any) -> str: + """Normalize channel-like objects and ids to a stable string key.""" + channel_id = getattr(channel_or_id, "id", channel_or_id) + return str(channel_id) + def __init__(self, config: Any, bus: MessageBus): if isinstance(config, dict): config = DiscordConfig.model_validate(config) super().__init__(config, bus) self.config: DiscordConfig = config - self._ws: websockets.WebSocketClientProtocol | None = None - self._seq: int | None = None - self._heartbeat_task: asyncio.Task | None = None - self._typing_tasks: dict[str, asyncio.Task] = {} - self._http: httpx.AsyncClient | None = None + self._client: DiscordBotClient | None = None + self._typing_tasks: dict[str, asyncio.Task[None]] = {} self._bot_user_id: str | None = None async def start(self) -> None: - """Start the Discord gateway connection.""" + """Start the Discord client.""" + if not DISCORD_AVAILABLE: + logger.error("discord.py not installed. Run: pip install nanobot-ai[discord]") + return + if not self.config.token: logger.error("Discord bot token not configured") return - self._running = True - self._http = httpx.AsyncClient(timeout=30.0) + try: + intents = discord.Intents.none() + intents.value = self.config.intents + self._client = DiscordBotClient(self, intents=intents) + except Exception as e: + logger.error("Failed to initialize Discord client: {}", e) + self._client = None + self._running = False + return - while self._running: - try: - logger.info("Connecting to Discord gateway...") - async with websockets.connect(self.config.gateway_url) as ws: - self._ws = ws - await self._gateway_loop() - except asyncio.CancelledError: - break - except Exception as e: - logger.warning("Discord gateway error: {}", e) - if self._running: - logger.info("Reconnecting to Discord gateway in 5 seconds...") - await asyncio.sleep(5) + self._running = True + logger.info("Starting Discord client via discord.py...") + + try: + await self._client.start(self.config.token) + except asyncio.CancelledError: + raise + except Exception as e: + logger.error("Discord client startup failed: {}", e) + finally: + self._running = False + await self._reset_runtime_state(close_client=True) async def stop(self) -> None: """Stop the Discord channel.""" self._running = False - if self._heartbeat_task: - self._heartbeat_task.cancel() - self._heartbeat_task = None - for task in self._typing_tasks.values(): - task.cancel() - self._typing_tasks.clear() - if self._ws: - await self._ws.close() - self._ws = None - if self._http: - await self._http.aclose() - self._http = None + await self._reset_runtime_state(close_client=True) async def send(self, msg: OutboundMessage) -> None: - """Send a message through Discord REST API, including file attachments.""" - if not self._http: - logger.warning("Discord HTTP client not initialized") + """Send a message through Discord using discord.py.""" + client = self._client + if client is None or not client.is_ready(): + logger.warning("Discord client not ready; dropping outbound message") return - url = f"{DISCORD_API_BASE}/channels/{msg.chat_id}/messages" - headers = {"Authorization": f"Bot {self.config.token}"} + is_progress = bool((msg.metadata or {}).get("_progress")) + try: + await client.send_outbound(msg) + except Exception as e: + logger.error("Error sending Discord message: {}", e) + finally: + if not is_progress: + await self._stop_typing(msg.chat_id) + + async def _handle_discord_message(self, message: discord.Message) -> None: + """Handle incoming Discord messages from discord.py.""" + if message.author.bot: + return + + sender_id = str(message.author.id) + channel_id = self._channel_key(message.channel) + content = message.content or "" + + if not self._should_accept_inbound(message, sender_id, content): + return + + media_paths, attachment_markers = await self._download_attachments(message.attachments) + full_content = self._compose_inbound_content(content, attachment_markers) + metadata = self._build_inbound_metadata(message) + + await self._start_typing(message.channel) try: - sent_media = False - failed_media: list[str] = [] + await self._handle_message( + sender_id=sender_id, + chat_id=channel_id, + content=full_content, + media=media_paths, + metadata=metadata, + ) + except Exception: + await self._stop_typing(channel_id) + raise - # Send file attachments first - for media_path in msg.media or []: - if await self._send_file(url, headers, media_path, reply_to=msg.reply_to): - sent_media = True - else: - failed_media.append(Path(media_path).name) + async def _on_message(self, message: discord.Message) -> None: + """Backward-compatible alias for legacy tests/callers.""" + await self._handle_discord_message(message) - # Send text content - chunks = split_message(msg.content or "", MAX_MESSAGE_LEN) - if not chunks and failed_media and not sent_media: - chunks = split_message( - "\n".join(f"[attachment: {name} - send failed]" for name in failed_media), - MAX_MESSAGE_LEN, - ) - if not chunks: - return - - for i, chunk in enumerate(chunks): - payload: dict[str, Any] = {"content": chunk} - - # Let the first successful attachment carry the reply if present. - if i == 0 and msg.reply_to and not sent_media: - payload["message_reference"] = {"message_id": msg.reply_to} - payload["allowed_mentions"] = {"replied_user": False} - - if not await self._send_payload(url, headers, payload): - break # Abort remaining chunks on failure - finally: - await self._stop_typing(msg.chat_id) - - async def _send_payload( - self, url: str, headers: dict[str, str], payload: dict[str, Any] - ) -> bool: - """Send a single Discord API payload with retry on rate-limit. Returns True on success.""" - for attempt in range(3): - try: - response = await self._http.post(url, headers=headers, json=payload) - if response.status_code == 429: - data = response.json() - retry_after = float(data.get("retry_after", 1.0)) - logger.warning("Discord rate limited, retrying in {}s", retry_after) - await asyncio.sleep(retry_after) - continue - response.raise_for_status() - return True - except Exception as e: - if attempt == 2: - logger.error("Error sending Discord message: {}", e) - else: - await asyncio.sleep(1) - return False - - async def _send_file( + def _should_accept_inbound( self, - url: str, - headers: dict[str, str], - file_path: str, - reply_to: str | None = None, + message: discord.Message, + sender_id: str, + content: str, ) -> bool: - """Send a file attachment via Discord REST API using multipart/form-data.""" - path = Path(file_path) - if not path.is_file(): - logger.warning("Discord file not found, skipping: {}", file_path) - return False - - if path.stat().st_size > MAX_ATTACHMENT_BYTES: - logger.warning("Discord file too large (>20MB), skipping: {}", path.name) - return False - - payload_json: dict[str, Any] = {} - if reply_to: - payload_json["message_reference"] = {"message_id": reply_to} - payload_json["allowed_mentions"] = {"replied_user": False} - - for attempt in range(3): - try: - with open(path, "rb") as f: - files = {"files[0]": (path.name, f, "application/octet-stream")} - data: dict[str, Any] = {} - if payload_json: - data["payload_json"] = json.dumps(payload_json) - response = await self._http.post( - url, headers=headers, files=files, data=data - ) - if response.status_code == 429: - resp_data = response.json() - retry_after = float(resp_data.get("retry_after", 1.0)) - logger.warning("Discord rate limited, retrying in {}s", retry_after) - await asyncio.sleep(retry_after) - continue - response.raise_for_status() - logger.info("Discord file sent: {}", path.name) - return True - except Exception as e: - if attempt == 2: - logger.error("Error sending Discord file {}: {}", path.name, e) - else: - await asyncio.sleep(1) - return False - - async def _gateway_loop(self) -> None: - """Main gateway loop: identify, heartbeat, dispatch events.""" - if not self._ws: - return - - async for raw in self._ws: - try: - data = json.loads(raw) - except json.JSONDecodeError: - logger.warning("Invalid JSON from Discord gateway: {}", raw[:100]) - continue - - op = data.get("op") - event_type = data.get("t") - seq = data.get("s") - payload = data.get("d") - - if seq is not None: - self._seq = seq - - if op == 10: - # HELLO: start heartbeat and identify - interval_ms = payload.get("heartbeat_interval", 45000) - await self._start_heartbeat(interval_ms / 1000) - await self._identify() - elif op == 0 and event_type == "READY": - logger.info("Discord gateway READY") - # Capture bot user ID for mention detection - user_data = payload.get("user") or {} - self._bot_user_id = user_data.get("id") - logger.info("Discord bot connected as user {}", self._bot_user_id) - elif op == 0 and event_type == "MESSAGE_CREATE": - await self._handle_message_create(payload) - elif op == 7: - # RECONNECT: exit loop to reconnect - logger.info("Discord gateway requested reconnect") - break - elif op == 9: - # INVALID_SESSION: reconnect - logger.warning("Discord gateway invalid session") - break - - async def _identify(self) -> None: - """Send IDENTIFY payload.""" - if not self._ws: - return - - identify = { - "op": 2, - "d": { - "token": self.config.token, - "intents": self.config.intents, - "properties": { - "os": "nanobot", - "browser": "nanobot", - "device": "nanobot", - }, - }, - } - await self._ws.send(json.dumps(identify)) - - async def _start_heartbeat(self, interval_s: float) -> None: - """Start or restart the heartbeat loop.""" - if self._heartbeat_task: - self._heartbeat_task.cancel() - - async def heartbeat_loop() -> None: - while self._running and self._ws: - payload = {"op": 1, "d": self._seq} - try: - await self._ws.send(json.dumps(payload)) - except Exception as e: - logger.warning("Discord heartbeat failed: {}", e) - break - await asyncio.sleep(interval_s) - - self._heartbeat_task = asyncio.create_task(heartbeat_loop()) - - async def _handle_message_create(self, payload: dict[str, Any]) -> None: - """Handle incoming Discord messages.""" - author = payload.get("author") or {} - if author.get("bot"): - return - - sender_id = str(author.get("id", "")) - channel_id = str(payload.get("channel_id", "")) - content = payload.get("content") or "" - guild_id = payload.get("guild_id") - - if not sender_id or not channel_id: - return - + """Check if inbound Discord message should be processed.""" if not self.is_allowed(sender_id): - return + return False + if message.guild is not None and not self._should_respond_in_group(message, content): + return False + return True - # Check group channel policy (DMs always respond if is_allowed passes) - if guild_id is not None: - if not self._should_respond_in_group(payload, content): - return - - content_parts = [content] if content else [] + async def _download_attachments( + self, + attachments: list[discord.Attachment], + ) -> tuple[list[str], list[str]]: + """Download supported attachments and return paths + display markers.""" media_paths: list[str] = [] + markers: list[str] = [] media_dir = get_media_dir("discord") - for attachment in payload.get("attachments") or []: - url = attachment.get("url") - filename = attachment.get("filename") or "attachment" - size = attachment.get("size") or 0 - if not url or not self._http: - continue - if size and size > MAX_ATTACHMENT_BYTES: - content_parts.append(f"[attachment: {filename} - too large]") + for attachment in attachments: + filename = attachment.filename or "attachment" + if attachment.size and attachment.size > MAX_ATTACHMENT_BYTES: + markers.append(f"[attachment: {filename} - too large]") continue try: media_dir.mkdir(parents=True, exist_ok=True) - file_path = media_dir / f"{attachment.get('id', 'file')}_{filename.replace('/', '_')}" - resp = await self._http.get(url) - resp.raise_for_status() - file_path.write_bytes(resp.content) + safe_name = safe_filename(filename) + file_path = media_dir / f"{attachment.id}_{safe_name}" + await attachment.save(file_path) media_paths.append(str(file_path)) - content_parts.append(f"[attachment: {file_path}]") + markers.append(f"[attachment: {file_path.name}]") except Exception as e: logger.warning("Failed to download Discord attachment: {}", e) - content_parts.append(f"[attachment: {filename} - download failed]") + markers.append(f"[attachment: {filename} - download failed]") - reply_to = (payload.get("referenced_message") or {}).get("id") + return media_paths, markers - await self._start_typing(channel_id) + @staticmethod + def _compose_inbound_content(content: str, attachment_markers: list[str]) -> str: + """Combine message text with attachment markers.""" + content_parts = [content] if content else [] + content_parts.extend(attachment_markers) + return "\n".join(part for part in content_parts if part) or "[empty message]" - await self._handle_message( - sender_id=sender_id, - chat_id=channel_id, - content="\n".join(p for p in content_parts if p) or "[empty message]", - media=media_paths, - metadata={ - "message_id": str(payload.get("id", "")), - "guild_id": guild_id, - "reply_to": reply_to, - }, - ) + @staticmethod + def _build_inbound_metadata(message: discord.Message) -> dict[str, str | None]: + """Build metadata for inbound Discord messages.""" + reply_to = str(message.reference.message_id) if message.reference and message.reference.message_id else None + return { + "message_id": str(message.id), + "guild_id": str(message.guild.id) if message.guild else None, + "reply_to": reply_to, + } - def _should_respond_in_group(self, payload: dict[str, Any], content: str) -> bool: - """Check if bot should respond in a group channel based on policy.""" + def _should_respond_in_group(self, message: discord.Message, content: str) -> bool: + """Check if the bot should respond in a guild channel based on policy.""" if self.config.group_policy == "open": return True if self.config.group_policy == "mention": - # Check if bot was mentioned in the message - if self._bot_user_id: - # Check mentions array - mentions = payload.get("mentions") or [] - for mention in mentions: - if str(mention.get("id")) == self._bot_user_id: - return True - # Also check content for mention format <@USER_ID> - if f"<@{self._bot_user_id}>" in content or f"<@!{self._bot_user_id}>" in content: - return True - logger.debug("Discord message in {} ignored (bot not mentioned)", payload.get("channel_id")) + bot_user_id = self._bot_user_id + if bot_user_id is None: + logger.debug("Discord message in {} ignored (bot identity unavailable)", message.channel.id) + return False + + if any(str(user.id) == bot_user_id for user in message.mentions): + return True + if f"<@{bot_user_id}>" in content or f"<@!{bot_user_id}>" in content: + return True + + logger.debug("Discord message in {} ignored (bot not mentioned)", message.channel.id) return False return True - async def _start_typing(self, channel_id: str) -> None: + async def _start_typing(self, channel: Messageable) -> None: """Start periodic typing indicator for a channel.""" + channel_id = self._channel_key(channel) await self._stop_typing(channel_id) async def typing_loop() -> None: - url = f"{DISCORD_API_BASE}/channels/{channel_id}/typing" - headers = {"Authorization": f"Bot {self.config.token}"} while self._running: try: - await self._http.post(url, headers=headers) + async with channel.typing(): + await asyncio.sleep(TYPING_INTERVAL_S) except asyncio.CancelledError: return except Exception as e: logger.debug("Discord typing indicator failed for {}: {}", channel_id, e) return - await asyncio.sleep(8) self._typing_tasks[channel_id] = asyncio.create_task(typing_loop()) async def _stop_typing(self, channel_id: str) -> None: """Stop typing indicator for a channel.""" - task = self._typing_tasks.pop(channel_id, None) - if task: - task.cancel() + task = self._typing_tasks.pop(self._channel_key(channel_id), None) + if task is None: + return + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + async def _cancel_all_typing(self) -> None: + """Stop all typing tasks.""" + channel_ids = list(self._typing_tasks) + for channel_id in channel_ids: + await self._stop_typing(channel_id) + + async def _reset_runtime_state(self, close_client: bool) -> None: + """Reset client and typing state.""" + await self._cancel_all_typing() + if close_client and self._client is not None and not self._client.is_closed(): + try: + await self._client.close() + except Exception as e: + logger.warning("Discord client close failed: {}", e) + self._client = None + self._bot_user_id = None diff --git a/nanobot/command/builtin.py b/nanobot/command/builtin.py index 0a9af3cb9..643397057 100644 --- a/nanobot/command/builtin.py +++ b/nanobot/command/builtin.py @@ -84,6 +84,16 @@ async def cmd_new(ctx: CommandContext) -> OutboundMessage: async def cmd_help(ctx: CommandContext) -> OutboundMessage: """Return available slash commands.""" + return OutboundMessage( + channel=ctx.msg.channel, + chat_id=ctx.msg.chat_id, + content=build_help_text(), + metadata={"render_as": "text"}, + ) + + +def build_help_text() -> str: + """Build canonical help text shared across channels.""" lines = [ "🐈 nanobot commands:", "/new — Start a new conversation", @@ -92,12 +102,7 @@ async def cmd_help(ctx: CommandContext) -> OutboundMessage: "/status — Show bot status", "/help — Show available commands", ] - return OutboundMessage( - channel=ctx.msg.channel, - chat_id=ctx.msg.chat_id, - content="\n".join(lines), - metadata={"render_as": "text"}, - ) + return "\n".join(lines) def register_builtin_commands(router: CommandRouter) -> None: diff --git a/pyproject.toml b/pyproject.toml index 8298d112a..51d494668 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,9 @@ matrix = [ "mistune>=3.0.0,<4.0.0", "nh3>=0.2.17,<1.0.0", ] +discord = [ + "discord.py>=2.5.2,<3.0.0", +] langsmith = [ "langsmith>=0.1.0", ] diff --git a/tests/channels/test_discord_channel.py b/tests/channels/test_discord_channel.py new file mode 100644 index 000000000..3f1f996fc --- /dev/null +++ b/tests/channels/test_discord_channel.py @@ -0,0 +1,676 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path +from types import SimpleNamespace + +import discord +import pytest + +from nanobot.bus.events import OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.discord import DiscordBotClient, DiscordChannel, DiscordConfig +from nanobot.command.builtin import build_help_text + + +# Minimal Discord client test double used to control startup/readiness behavior. +class _FakeDiscordClient: + instances: list["_FakeDiscordClient"] = [] + start_error: Exception | None = None + + def __init__(self, owner, *, intents) -> None: + self.owner = owner + self.intents = intents + self.closed = False + self.ready = True + self.channels: dict[int, object] = {} + self.user = SimpleNamespace(id=999) + self.__class__.instances.append(self) + + async def start(self, token: str) -> None: + self.token = token + if self.__class__.start_error is not None: + raise self.__class__.start_error + + async def close(self) -> None: + self.closed = True + + def is_closed(self) -> bool: + return self.closed + + def is_ready(self) -> bool: + return self.ready + + def get_channel(self, channel_id: int): + return self.channels.get(channel_id) + + async def send_outbound(self, msg: OutboundMessage) -> None: + channel = self.get_channel(int(msg.chat_id)) + if channel is None: + return + await channel.send(content=msg.content) + + +class _FakeAttachment: + # Attachment double that can simulate successful or failing save() calls. + def __init__(self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False) -> None: + self.id = attachment_id + self.filename = filename + self.size = size + self._fail = fail + + async def save(self, path: str | Path) -> None: + if self._fail: + raise RuntimeError("save failed") + Path(path).write_bytes(b"attachment") + + +class _FakePartialMessage: + # Lightweight stand-in for Discord partial message references used in replies. + def __init__(self, message_id: int) -> None: + self.id = message_id + + +class _FakeChannel: + # Channel double that records outbound payloads and typing activity. + def __init__(self, channel_id: int = 123) -> None: + self.id = channel_id + self.sent_payloads: list[dict] = [] + self.trigger_typing_calls = 0 + self.typing_enter_hook = None + + async def send(self, **kwargs) -> None: + payload = dict(kwargs) + if "file" in payload: + payload["file_name"] = payload["file"].filename + del payload["file"] + self.sent_payloads.append(payload) + + def get_partial_message(self, message_id: int) -> _FakePartialMessage: + return _FakePartialMessage(message_id) + + def typing(self): + channel = self + + class _TypingContext: + async def __aenter__(self): + channel.trigger_typing_calls += 1 + if channel.typing_enter_hook is not None: + await channel.typing_enter_hook() + + async def __aexit__(self, exc_type, exc, tb): + return False + + return _TypingContext() + + +class _FakeInteractionResponse: + def __init__(self) -> None: + self.messages: list[dict] = [] + self._done = False + + async def send_message(self, content: str, *, ephemeral: bool = False) -> None: + self.messages.append({"content": content, "ephemeral": ephemeral}) + self._done = True + + def is_done(self) -> bool: + return self._done + + +def _make_interaction( + *, + user_id: int = 123, + channel_id: int | None = 456, + guild_id: int | None = None, + interaction_id: int = 999, +): + return SimpleNamespace( + user=SimpleNamespace(id=user_id), + channel_id=channel_id, + guild_id=guild_id, + id=interaction_id, + command=SimpleNamespace(qualified_name="new"), + response=_FakeInteractionResponse(), + ) + + +def _make_message( + *, + author_id: int = 123, + author_bot: bool = False, + channel_id: int = 456, + message_id: int = 789, + content: str = "hello", + guild_id: int | None = None, + mentions: list[object] | None = None, + attachments: list[object] | None = None, + reply_to: int | None = None, +): + # Factory for incoming Discord message objects with optional guild/reply/attachments. + guild = SimpleNamespace(id=guild_id) if guild_id is not None else None + reference = SimpleNamespace(message_id=reply_to) if reply_to is not None else None + return SimpleNamespace( + author=SimpleNamespace(id=author_id, bot=author_bot), + channel=_FakeChannel(channel_id), + content=content, + guild=guild, + mentions=mentions or [], + attachments=attachments or [], + reference=reference, + id=message_id, + ) + + +@pytest.mark.asyncio +async def test_start_returns_when_token_missing() -> None: + # If no token is configured, startup should no-op and leave channel stopped. + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + + await channel.start() + + assert channel.is_running is False + assert channel._client is None + + +@pytest.mark.asyncio +async def test_start_returns_when_discord_dependency_missing(monkeypatch) -> None: + channel = DiscordChannel( + DiscordConfig(enabled=True, token="token", allow_from=["*"]), + MessageBus(), + ) + monkeypatch.setattr("nanobot.channels.discord.DISCORD_AVAILABLE", False) + + await channel.start() + + assert channel.is_running is False + assert channel._client is None + + +@pytest.mark.asyncio +async def test_start_handles_client_construction_failure(monkeypatch) -> None: + # Construction errors from the Discord client should be swallowed and keep state clean. + channel = DiscordChannel( + DiscordConfig(enabled=True, token="token", allow_from=["*"]), + MessageBus(), + ) + + def _boom(owner, *, intents): + raise RuntimeError("bad client") + + monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _boom) + + await channel.start() + + assert channel.is_running is False + assert channel._client is None + + +@pytest.mark.asyncio +async def test_start_handles_client_start_failure(monkeypatch) -> None: + # If client.start fails, the partially created client should be closed and detached. + channel = DiscordChannel( + DiscordConfig(enabled=True, token="token", allow_from=["*"]), + MessageBus(), + ) + + _FakeDiscordClient.instances.clear() + _FakeDiscordClient.start_error = RuntimeError("connect failed") + monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient) + + await channel.start() + + assert channel.is_running is False + assert channel._client is None + assert _FakeDiscordClient.instances[0].intents.value == channel.config.intents + assert _FakeDiscordClient.instances[0].closed is True + + _FakeDiscordClient.start_error = None + + +@pytest.mark.asyncio +async def test_stop_is_safe_after_partial_start(monkeypatch) -> None: + # stop() should close/discard the client even when startup was only partially completed. + channel = DiscordChannel( + DiscordConfig(enabled=True, token="token", allow_from=["*"]), + MessageBus(), + ) + client = _FakeDiscordClient(channel, intents=None) + channel._client = client + channel._running = True + + await channel.stop() + + assert channel.is_running is False + assert client.closed is True + assert channel._client is None + + +@pytest.mark.asyncio +async def test_on_message_ignores_bot_messages() -> None: + # Incoming bot-authored messages must be ignored to prevent feedback loops. + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + handled: list[dict] = [] + channel._handle_message = lambda **kwargs: handled.append(kwargs) # type: ignore[method-assign] + + await channel._on_message(_make_message(author_bot=True)) + + assert handled == [] + + # If inbound handling raises, typing should be stopped for that channel. + async def fail_handle(**kwargs) -> None: + raise RuntimeError("boom") + + channel._handle_message = fail_handle # type: ignore[method-assign] + + with pytest.raises(RuntimeError, match="boom"): + await channel._on_message(_make_message(author_id=123, channel_id=456)) + + assert channel._typing_tasks == {} + + +@pytest.mark.asyncio +async def test_on_message_accepts_allowlisted_dm() -> None: + # Allowed direct messages should be forwarded with normalized metadata. + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus()) + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + + await channel._on_message(_make_message(author_id=123, channel_id=456, message_id=789)) + + assert len(handled) == 1 + assert handled[0]["chat_id"] == "456" + assert handled[0]["metadata"] == {"message_id": "789", "guild_id": None, "reply_to": None} + + +@pytest.mark.asyncio +async def test_on_message_ignores_unmentioned_guild_message() -> None: + # With mention-only group policy, guild messages without a bot mention are dropped. + channel = DiscordChannel( + DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention"), + MessageBus(), + ) + channel._bot_user_id = "999" + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + + await channel._on_message(_make_message(guild_id=1, content="hello everyone")) + + assert handled == [] + + +@pytest.mark.asyncio +async def test_on_message_accepts_mentioned_guild_message() -> None: + # Mentioned guild messages should be accepted and preserve reply threading metadata. + channel = DiscordChannel( + DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention"), + MessageBus(), + ) + channel._bot_user_id = "999" + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + + await channel._on_message( + _make_message( + guild_id=1, + content="<@999> hello", + mentions=[SimpleNamespace(id=999)], + reply_to=321, + ) + ) + + assert len(handled) == 1 + assert handled[0]["metadata"]["reply_to"] == "321" + + +@pytest.mark.asyncio +async def test_on_message_downloads_attachments(tmp_path, monkeypatch) -> None: + # Attachment downloads should be saved and referenced in forwarded content/media. + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path) + + await channel._on_message( + _make_message( + attachments=[_FakeAttachment(12, "photo.png")], + content="see file", + ) + ) + + assert len(handled) == 1 + assert handled[0]["media"] == [str(tmp_path / "12_photo.png")] + assert "[attachment:" in handled[0]["content"] + + +@pytest.mark.asyncio +async def test_on_message_marks_failed_attachment_download(tmp_path, monkeypatch) -> None: + # Failed attachment downloads should emit a readable placeholder and no media path. + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path) + + await channel._on_message( + _make_message( + attachments=[_FakeAttachment(12, "photo.png", fail=True)], + content="", + ) + ) + + assert len(handled) == 1 + assert handled[0]["media"] == [] + assert handled[0]["content"] == "[attachment: photo.png - download failed]" + + +@pytest.mark.asyncio +async def test_send_warns_when_client_not_ready() -> None: + # Sending without a running/ready client should be a safe no-op. + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + + await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello")) + + assert channel._typing_tasks == {} + + +@pytest.mark.asyncio +async def test_send_skips_when_channel_not_cached() -> None: + # Outbound sends should be skipped when the destination channel is not resolvable. + owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = DiscordBotClient(owner, intents=discord.Intents.none()) + fetch_calls: list[int] = [] + + async def fetch_channel(channel_id: int): + fetch_calls.append(channel_id) + raise RuntimeError("not found") + + client.fetch_channel = fetch_channel # type: ignore[method-assign] + + await client.send_outbound(OutboundMessage(channel="discord", chat_id="123", content="hello")) + + assert client.get_channel(123) is None + assert fetch_calls == [123] + + +@pytest.mark.asyncio +async def test_send_fetches_channel_when_not_cached() -> None: + owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = DiscordBotClient(owner, intents=discord.Intents.none()) + target = _FakeChannel(channel_id=123) + + async def fetch_channel(channel_id: int): + return target if channel_id == 123 else None + + client.fetch_channel = fetch_channel # type: ignore[method-assign] + + await client.send_outbound(OutboundMessage(channel="discord", chat_id="123", content="hello")) + + assert target.sent_payloads == [{"content": "hello"}] + + +@pytest.mark.asyncio +async def test_slash_new_forwards_when_user_is_allowlisted() -> None: + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus()) + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + client = DiscordBotClient(channel, intents=discord.Intents.none()) + interaction = _make_interaction(user_id=123, channel_id=456, interaction_id=321) + + new_cmd = client.tree.get_command("new") + assert new_cmd is not None + await new_cmd.callback(interaction) + + assert interaction.response.messages == [ + {"content": "Processing /new...", "ephemeral": True} + ] + assert len(handled) == 1 + assert handled[0]["content"] == "/new" + assert handled[0]["sender_id"] == "123" + assert handled[0]["chat_id"] == "456" + assert handled[0]["metadata"]["interaction_id"] == "321" + assert handled[0]["metadata"]["is_slash_command"] is True + + +@pytest.mark.asyncio +async def test_slash_new_is_blocked_for_disallowed_user() -> None: + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["999"]), MessageBus()) + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + client = DiscordBotClient(channel, intents=discord.Intents.none()) + interaction = _make_interaction(user_id=123, channel_id=456) + + new_cmd = client.tree.get_command("new") + assert new_cmd is not None + await new_cmd.callback(interaction) + + assert interaction.response.messages == [ + {"content": "You are not allowed to use this bot.", "ephemeral": True} + ] + assert handled == [] + + +@pytest.mark.parametrize("slash_name", ["stop", "restart", "status"]) +@pytest.mark.asyncio +async def test_slash_commands_forward_via_handle_message(slash_name: str) -> None: + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + client = DiscordBotClient(channel, intents=discord.Intents.none()) + interaction = _make_interaction() + interaction.command.qualified_name = slash_name + + cmd = client.tree.get_command(slash_name) + assert cmd is not None + await cmd.callback(interaction) + + assert interaction.response.messages == [ + {"content": f"Processing /{slash_name}...", "ephemeral": True} + ] + assert len(handled) == 1 + assert handled[0]["content"] == f"/{slash_name}" + assert handled[0]["metadata"]["is_slash_command"] is True + + +@pytest.mark.asyncio +async def test_slash_help_returns_ephemeral_help_text() -> None: + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + handled: list[dict] = [] + + async def capture_handle(**kwargs) -> None: + handled.append(kwargs) + + channel._handle_message = capture_handle # type: ignore[method-assign] + client = DiscordBotClient(channel, intents=discord.Intents.none()) + interaction = _make_interaction() + interaction.command.qualified_name = "help" + + help_cmd = client.tree.get_command("help") + assert help_cmd is not None + await help_cmd.callback(interaction) + + assert interaction.response.messages == [ + {"content": build_help_text(), "ephemeral": True} + ] + assert handled == [] + + +@pytest.mark.asyncio +async def test_client_send_outbound_chunks_text_replies_and_uploads_files(tmp_path) -> None: + # Outbound payloads should upload files, attach reply references, and chunk long text. + owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = DiscordBotClient(owner, intents=discord.Intents.none()) + target = _FakeChannel(channel_id=123) + client.get_channel = lambda channel_id: target if channel_id == 123 else None # type: ignore[method-assign] + + file_path = tmp_path / "demo.txt" + file_path.write_text("hi") + + await client.send_outbound( + OutboundMessage( + channel="discord", + chat_id="123", + content="a" * 2100, + reply_to="55", + media=[str(file_path)], + ) + ) + + assert len(target.sent_payloads) == 3 + assert target.sent_payloads[0]["file_name"] == "demo.txt" + assert target.sent_payloads[0]["reference"].id == 55 + assert target.sent_payloads[1]["content"] == "a" * 2000 + assert target.sent_payloads[2]["content"] == "a" * 100 + + +@pytest.mark.asyncio +async def test_client_send_outbound_reports_failed_attachments_when_no_text(tmp_path) -> None: + # If all attachment sends fail and no text exists, emit a failure placeholder message. + owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = DiscordBotClient(owner, intents=discord.Intents.none()) + target = _FakeChannel(channel_id=123) + client.get_channel = lambda channel_id: target if channel_id == 123 else None # type: ignore[method-assign] + + missing_file = tmp_path / "missing.txt" + + await client.send_outbound( + OutboundMessage( + channel="discord", + chat_id="123", + content="", + media=[str(missing_file)], + ) + ) + + assert target.sent_payloads == [{"content": "[attachment: missing.txt - send failed]"}] + + +@pytest.mark.asyncio +async def test_send_stops_typing_after_send() -> None: + # Active typing indicators should be cancelled/cleared after a successful send. + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + client = _FakeDiscordClient(channel, intents=None) + channel._client = client + channel._running = True + + start = asyncio.Event() + release = asyncio.Event() + + async def slow_typing() -> None: + start.set() + await release.wait() + + typing_channel = _FakeChannel(channel_id=123) + typing_channel.typing_enter_hook = slow_typing + + await channel._start_typing(typing_channel) + await start.wait() + + await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello")) + release.set() + await asyncio.sleep(0) + + assert channel._typing_tasks == {} + + # Progress messages should keep typing active until a final (non-progress) send. + start = asyncio.Event() + release = asyncio.Event() + + async def slow_typing_progress() -> None: + start.set() + await release.wait() + + typing_channel = _FakeChannel(channel_id=123) + typing_channel.typing_enter_hook = slow_typing_progress + + await channel._start_typing(typing_channel) + await start.wait() + + await channel.send( + OutboundMessage( + channel="discord", + chat_id="123", + content="progress", + metadata={"_progress": True}, + ) + ) + + assert "123" in channel._typing_tasks + + await channel.send(OutboundMessage(channel="discord", chat_id="123", content="final")) + release.set() + await asyncio.sleep(0) + + assert channel._typing_tasks == {} + + +@pytest.mark.asyncio +async def test_start_typing_uses_typing_context_when_trigger_typing_missing() -> None: + channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus()) + channel._running = True + + entered = asyncio.Event() + release = asyncio.Event() + + class _TypingCtx: + async def __aenter__(self): + entered.set() + + async def __aexit__(self, exc_type, exc, tb): + return False + + class _NoTriggerChannel: + def __init__(self, channel_id: int = 123) -> None: + self.id = channel_id + + def typing(self): + async def _waiter(): + await release.wait() + # Hold the loop so task remains active until explicitly stopped. + class _Ctx(_TypingCtx): + async def __aenter__(self): + await super().__aenter__() + await _waiter() + return _Ctx() + + typing_channel = _NoTriggerChannel(channel_id=123) + await channel._start_typing(typing_channel) # type: ignore[arg-type] + await entered.wait() + + assert "123" in channel._typing_tasks + + await channel._stop_typing("123") + release.set() + await asyncio.sleep(0) + + assert channel._typing_tasks == {} From 8956df3668de0e0b009275aa38d88049535b3cd6 Mon Sep 17 00:00:00 2001 From: Jesse <74103710+95256155o@users.noreply.github.com> Date: Mon, 30 Mar 2026 02:02:43 -0400 Subject: [PATCH 21/22] feat(discord): configurable read receipt + subagent working indicator (#2330) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(discord): channel-side read receipt and subagent indicator - Add 👀 reaction on message receipt, removed after bot reply - Add 🔧 reaction on first progress message, removed on final reply - Both managed purely in discord.py channel layer, no subagent.py changes - Config: read_receipt_emoji, subagent_emoji with sensible defaults Addresses maintainer feedback on HKUDS/nanobot#2330 Co-Authored-By: Claude Sonnet 4.6 * fix(discord): add both reactions on inbound, not on progress _progress flag is for streaming chunks, not subagent lifecycle. Add 👀 + 🔧 immediately on message receipt, clear both on final reply. * fix: remove stale _subagent_active reference in _clear_reactions * fix(discord): clean up reactions on message handling failure Previously, if _handle_message raised an exception, pending reactions (read receipt + subagent indicator) would remain on the user's message indefinitely since send() — which handles normal cleanup — would never be called. Co-Authored-By: Claude Opus 4.6 (1M context) * refactor(discord): replace subagent_emoji with delayed working indicator - Rename subagent_emoji → working_emoji (honest naming: not tied to subagent lifecycle) - Add working_emoji_delay (default 2s) — cosmetic delay so 🔧 appears after 👀, cancelled if bot replies before delay fires - Clean up: cancel pending task + remove both reactions on reply/error Co-Authored-By: Claude Opus 4.6 (1M context) --------- Co-authored-by: Claude Sonnet 4.6 --- nanobot/channels/discord.py | 44 +++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/nanobot/channels/discord.py b/nanobot/channels/discord.py index ef7d41d77..9bf4d919c 100644 --- a/nanobot/channels/discord.py +++ b/nanobot/channels/discord.py @@ -42,6 +42,9 @@ class DiscordConfig(Base): allow_from: list[str] = Field(default_factory=list) intents: int = 37377 group_policy: Literal["mention", "open"] = "mention" + read_receipt_emoji: str = "👀" + working_emoji: str = "🔧" + working_emoji_delay: float = 2.0 if DISCORD_AVAILABLE: @@ -258,6 +261,8 @@ class DiscordChannel(BaseChannel): self._client: DiscordBotClient | None = None self._typing_tasks: dict[str, asyncio.Task[None]] = {} self._bot_user_id: str | None = None + self._pending_reactions: dict[str, Any] = {} # chat_id -> message object + self._working_emoji_tasks: dict[str, asyncio.Task[None]] = {} async def start(self) -> None: """Start the Discord client.""" @@ -305,6 +310,7 @@ class DiscordChannel(BaseChannel): return is_progress = bool((msg.metadata or {}).get("_progress")) + try: await client.send_outbound(msg) except Exception as e: @@ -312,6 +318,7 @@ class DiscordChannel(BaseChannel): finally: if not is_progress: await self._stop_typing(msg.chat_id) + await self._clear_reactions(msg.chat_id) async def _handle_discord_message(self, message: discord.Message) -> None: """Handle incoming Discord messages from discord.py.""" @@ -331,6 +338,24 @@ class DiscordChannel(BaseChannel): await self._start_typing(message.channel) + # Add read receipt reaction immediately, working emoji after delay + channel_id = self._channel_key(message.channel) + try: + await message.add_reaction(self.config.read_receipt_emoji) + self._pending_reactions[channel_id] = message + except Exception as e: + logger.debug("Failed to add read receipt reaction: {}", e) + + # Delayed working indicator (cosmetic — not tied to subagent lifecycle) + async def _delayed_working_emoji() -> None: + await asyncio.sleep(self.config.working_emoji_delay) + try: + await message.add_reaction(self.config.working_emoji) + except Exception: + pass + + self._working_emoji_tasks[channel_id] = asyncio.create_task(_delayed_working_emoji()) + try: await self._handle_message( sender_id=sender_id, @@ -340,6 +365,7 @@ class DiscordChannel(BaseChannel): metadata=metadata, ) except Exception: + await self._clear_reactions(channel_id) await self._stop_typing(channel_id) raise @@ -454,6 +480,24 @@ class DiscordChannel(BaseChannel): except asyncio.CancelledError: pass + + async def _clear_reactions(self, chat_id: str) -> None: + """Remove all pending reactions after bot replies.""" + # Cancel delayed working emoji if it hasn't fired yet + task = self._working_emoji_tasks.pop(chat_id, None) + if task and not task.done(): + task.cancel() + + msg_obj = self._pending_reactions.pop(chat_id, None) + if msg_obj is None: + return + bot_user = self._client.user if self._client else None + for emoji in (self.config.read_receipt_emoji, self.config.working_emoji): + try: + await msg_obj.remove_reaction(emoji, bot_user) + except Exception: + pass + async def _cancel_all_typing(self) -> None: """Stop all typing tasks.""" channel_ids = list(self._typing_tasks) From f450c6ef6c0ca9afc2c03c91fd727e94f28464a6 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Tue, 31 Mar 2026 11:18:18 +0000 Subject: [PATCH 22/22] fix(channel): preserve threaded streaming context --- nanobot/agent/loop.py | 18 +++--- nanobot/channels/matrix.py | 35 ++++++++--- tests/agent/test_task_cancel.py | 37 ++++++++++++ tests/channels/test_discord_channel.py | 2 +- tests/channels/test_matrix_channel.py | 82 ++++++++++++++++++++++++++ 5 files changed, 155 insertions(+), 19 deletions(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 97d352cb8..a9dc589e8 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -403,25 +403,25 @@ class AgentLoop: return f"{stream_base_id}:{stream_segment}" async def on_stream(delta: str) -> None: + meta = dict(msg.metadata or {}) + meta["_stream_delta"] = True + meta["_stream_id"] = _current_stream_id() await self.bus.publish_outbound(OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, content=delta, - metadata={ - "_stream_delta": True, - "_stream_id": _current_stream_id(), - }, + metadata=meta, )) async def on_stream_end(*, resuming: bool = False) -> None: nonlocal stream_segment + meta = dict(msg.metadata or {}) + meta["_stream_end"] = True + meta["_resuming"] = resuming + meta["_stream_id"] = _current_stream_id() await self.bus.publish_outbound(OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, content="", - metadata={ - "_stream_end": True, - "_resuming": resuming, - "_stream_id": _current_stream_id(), - }, + metadata=meta, )) stream_segment += 1 diff --git a/nanobot/channels/matrix.py b/nanobot/channels/matrix.py index dcece1043..bc6d9398a 100644 --- a/nanobot/channels/matrix.py +++ b/nanobot/channels/matrix.py @@ -132,7 +132,11 @@ def _render_markdown_html(text: str) -> str | None: return formatted -def _build_matrix_text_content(text: str, event_id: str | None = None) -> dict[str, object]: +def _build_matrix_text_content( + text: str, + event_id: str | None = None, + thread_relates_to: dict[str, object] | None = None, +) -> dict[str, object]: """ Constructs and returns a dictionary representing the matrix text content with optional HTML formatting and reference to an existing event for replacement. This function is @@ -144,6 +148,9 @@ def _build_matrix_text_content(text: str, event_id: str | None = None) -> dict[s include information indicating that the message is a replacement of the specified event. :type event_id: str | None + :param thread_relates_to: Optional Matrix thread relation metadata. For edits this is + stored in ``m.new_content`` so the replacement remains in the same thread. + :type thread_relates_to: dict[str, object] | None :return: A dictionary containing the matrix text content, potentially enriched with HTML formatting and replacement metadata if applicable. :rtype: dict[str, object] @@ -153,14 +160,18 @@ def _build_matrix_text_content(text: str, event_id: str | None = None) -> dict[s content["format"] = MATRIX_HTML_FORMAT content["formatted_body"] = html if event_id: - content["m.new_content"] = { + content["m.new_content"] = { "body": text, - "msgtype": "m.text" + "msgtype": "m.text", } content["m.relates_to"] = { "rel_type": "m.replace", - "event_id": event_id + "event_id": event_id, } + if thread_relates_to: + content["m.new_content"]["m.relates_to"] = thread_relates_to + elif thread_relates_to: + content["m.relates_to"] = thread_relates_to return content @@ -475,9 +486,11 @@ class MatrixChannel(BaseChannel): await self._stop_typing_keepalive(chat_id, clear_typing=True) - content = _build_matrix_text_content(buf.text, buf.event_id) - if relates_to: - content["m.relates_to"] = relates_to + content = _build_matrix_text_content( + buf.text, + buf.event_id, + thread_relates_to=relates_to, + ) await self._send_room_content(chat_id, content) return @@ -494,14 +507,18 @@ class MatrixChannel(BaseChannel): if not buf.last_edit or (now - buf.last_edit) >= self._STREAM_EDIT_INTERVAL: try: - content = _build_matrix_text_content(buf.text, buf.event_id) + content = _build_matrix_text_content( + buf.text, + buf.event_id, + thread_relates_to=relates_to, + ) response = await self._send_room_content(chat_id, content) buf.last_edit = now if not buf.event_id: # we are editing the same message all the time, so only the first time the event id needs to be set buf.event_id = response.event_id except Exception: - await self._stop_typing_keepalive(metadata["room_id"], clear_typing=True) + await self._stop_typing_keepalive(chat_id, clear_typing=True) pass diff --git a/tests/agent/test_task_cancel.py b/tests/agent/test_task_cancel.py index 4902a4c80..70f7621d1 100644 --- a/tests/agent/test_task_cancel.py +++ b/tests/agent/test_task_cancel.py @@ -117,6 +117,43 @@ class TestDispatch: out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0) assert out.content == "hi" + @pytest.mark.asyncio + async def test_dispatch_streaming_preserves_message_metadata(self): + from nanobot.bus.events import InboundMessage + + loop, bus = _make_loop() + msg = InboundMessage( + channel="matrix", + sender_id="u1", + chat_id="!room:matrix.org", + content="hello", + metadata={ + "_wants_stream": True, + "thread_root_event_id": "$root1", + "thread_reply_to_event_id": "$reply1", + }, + ) + + async def fake_process(_msg, *, on_stream=None, on_stream_end=None, **kwargs): + assert on_stream is not None + assert on_stream_end is not None + await on_stream("hi") + await on_stream_end(resuming=False) + return None + + loop._process_message = fake_process + + await loop._dispatch(msg) + first = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0) + second = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0) + + assert first.metadata["thread_root_event_id"] == "$root1" + assert first.metadata["thread_reply_to_event_id"] == "$reply1" + assert first.metadata["_stream_delta"] is True + assert second.metadata["thread_root_event_id"] == "$root1" + assert second.metadata["thread_reply_to_event_id"] == "$reply1" + assert second.metadata["_stream_end"] is True + @pytest.mark.asyncio async def test_processing_lock_serializes(self): from nanobot.bus.events import InboundMessage, OutboundMessage diff --git a/tests/channels/test_discord_channel.py b/tests/channels/test_discord_channel.py index 3f1f996fc..d352c788c 100644 --- a/tests/channels/test_discord_channel.py +++ b/tests/channels/test_discord_channel.py @@ -4,8 +4,8 @@ import asyncio from pathlib import Path from types import SimpleNamespace -import discord import pytest +discord = pytest.importorskip("discord") from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus diff --git a/tests/channels/test_matrix_channel.py b/tests/channels/test_matrix_channel.py index 3ad65e76b..18a8e1097 100644 --- a/tests/channels/test_matrix_channel.py +++ b/tests/channels/test_matrix_channel.py @@ -1367,6 +1367,23 @@ def test_build_matrix_text_content_with_event_id() -> None: assert result["m.relates_to"]["event_id"] == event_id +def test_build_matrix_text_content_with_event_id_preserves_thread_relation() -> None: + """Thread relations for edits should stay inside m.new_content.""" + relates_to = { + "rel_type": "m.thread", + "event_id": "$root1", + "m.in_reply_to": {"event_id": "$reply1"}, + "is_falling_back": True, + } + result = _build_matrix_text_content("Updated message", "event-1", relates_to) + + assert result["m.relates_to"] == { + "rel_type": "m.replace", + "event_id": "event-1", + } + assert result["m.new_content"]["m.relates_to"] == relates_to + + def test_build_matrix_text_content_no_event_id() -> None: """Test that when event_id is not provided, no extra properties are added.""" result = _build_matrix_text_content("Regular message") @@ -1500,6 +1517,71 @@ async def test_send_delta_stream_end_replaces_existing_message() -> None: } +@pytest.mark.asyncio +async def test_send_delta_starts_threaded_stream_inside_thread() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + client.room_send_response.event_id = "event-1" + + metadata = { + "thread_root_event_id": "$root1", + "thread_reply_to_event_id": "$reply1", + } + await channel.send_delta("!room:matrix.org", "Hello", metadata) + + assert client.room_send_calls[0]["content"]["m.relates_to"] == { + "rel_type": "m.thread", + "event_id": "$root1", + "m.in_reply_to": {"event_id": "$reply1"}, + "is_falling_back": True, + } + + +@pytest.mark.asyncio +async def test_send_delta_threaded_edit_keeps_replace_and_thread_relation(monkeypatch) -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + client.room_send_response.event_id = "event-1" + + times = [100.0, 102.0, 104.0] + times.reverse() + monkeypatch.setattr(channel, "monotonic_time", lambda: times and times.pop()) + + metadata = { + "thread_root_event_id": "$root1", + "thread_reply_to_event_id": "$reply1", + } + await channel.send_delta("!room:matrix.org", "Hello", metadata) + await channel.send_delta("!room:matrix.org", " world", metadata) + await channel.send_delta("!room:matrix.org", "", {"_stream_end": True, **metadata}) + + edit_content = client.room_send_calls[1]["content"] + final_content = client.room_send_calls[2]["content"] + + assert edit_content["m.relates_to"] == { + "rel_type": "m.replace", + "event_id": "event-1", + } + assert edit_content["m.new_content"]["m.relates_to"] == { + "rel_type": "m.thread", + "event_id": "$root1", + "m.in_reply_to": {"event_id": "$reply1"}, + "is_falling_back": True, + } + assert final_content["m.relates_to"] == { + "rel_type": "m.replace", + "event_id": "event-1", + } + assert final_content["m.new_content"]["m.relates_to"] == { + "rel_type": "m.thread", + "event_id": "$root1", + "m.in_reply_to": {"event_id": "$reply1"}, + "is_falling_back": True, + } + + @pytest.mark.asyncio async def test_send_delta_stream_end_noop_when_buffer_missing() -> None: channel = MatrixChannel(_make_config(), MessageBus())