Merge remote-tracking branch 'origin/main' into pr-2614

This commit is contained in:
Xubin Ren 2026-03-31 11:29:36 +00:00
commit 6aad945719
27 changed files with 3598 additions and 458 deletions

108
README.md
View File

@ -115,6 +115,8 @@
- [Configuration](#-configuration)
- [Multiple Instances](#-multiple-instances)
- [CLI Reference](#-cli-reference)
- [Python SDK](#-python-sdk)
- [OpenAI-Compatible API](#-openai-compatible-api)
- [Docker](#-docker)
- [Linux Service](#-linux-service)
- [Project Structure](#-project-structure)
@ -854,7 +856,6 @@ Config file: `~/.nanobot/config.json`
> - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config.
> - **Alibaba Cloud BaiLian**: If you're using Alibaba Cloud BaiLian's OpenAI-compatible endpoint, set `"apiBase": "https://dashscope.aliyuncs.com/compatible-mode/v1"` in your dashscope provider config.
> - **Step Fun (Mainland China)**: If your API key is from Step Fun's mainland China platform (stepfun.com), set `"apiBase": "https://api.stepfun.com/v1"` in your stepfun provider config.
> - **Step Fun Step Plan**: Exclusive discount links for the nanobot community: [Overseas](https://platform.stepfun.ai/step-plan) · [Mainland China](https://platform.stepfun.com/step-plan)
| Provider | Purpose | Get API Key |
|----------|---------|-------------|
@ -1542,6 +1543,7 @@ nanobot gateway --config ~/.nanobot-telegram/config.json --workspace /tmp/nanobo
| `nanobot agent` | Interactive chat mode |
| `nanobot agent --no-markdown` | Show plain-text replies |
| `nanobot agent --logs` | Show runtime logs during chat |
| `nanobot serve` | Start the OpenAI-compatible API |
| `nanobot gateway` | Start the gateway |
| `nanobot status` | Show status |
| `nanobot provider login openai-codex` | OAuth login for providers |
@ -1570,6 +1572,110 @@ The agent can also manage this file itself — ask it to "add a periodic task" a
</details>
## 🐍 Python SDK
Use nanobot as a library — no CLI, no gateway, just Python:
```python
from nanobot import Nanobot
bot = Nanobot.from_config()
result = await bot.run("Summarize the README")
print(result.content)
```
Each call carries a `session_key` for conversation isolation — different keys get independent history:
```python
await bot.run("hi", session_key="user-alice")
await bot.run("hi", session_key="task-42")
```
Add lifecycle hooks to observe or customize the agent:
```python
from nanobot.agent import AgentHook, AgentHookContext
class AuditHook(AgentHook):
async def before_execute_tools(self, ctx: AgentHookContext) -> None:
for tc in ctx.tool_calls:
print(f"[tool] {tc.name}")
result = await bot.run("Hello", hooks=[AuditHook()])
```
See [docs/PYTHON_SDK.md](docs/PYTHON_SDK.md) for the full SDK reference.
## 🔌 OpenAI-Compatible API
nanobot can expose a minimal OpenAI-compatible endpoint for local integrations:
```bash
pip install "nanobot-ai[api]"
nanobot serve
```
By default, the API binds to `127.0.0.1:8900`. You can change this in `config.json`.
### Behavior
- Session isolation: pass `"session_id"` in the request body to isolate conversations; omit for a shared default session (`api:default`)
- Single-message input: each request must contain exactly one `user` message
- Fixed model: omit `model`, or pass the same model shown by `/v1/models`
- No streaming: `stream=true` is not supported
### Endpoints
- `GET /health`
- `GET /v1/models`
- `POST /v1/chat/completions`
### curl
```bash
curl http://127.0.0.1:8900/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "hi"}],
"session_id": "my-session"
}'
```
### Python (`requests`)
```python
import requests
resp = requests.post(
"http://127.0.0.1:8900/v1/chat/completions",
json={
"messages": [{"role": "user", "content": "hi"}],
"session_id": "my-session", # optional: isolate conversation
},
timeout=120,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```
### Python (`openai`)
```python
from openai import OpenAI
client = OpenAI(
base_url="http://127.0.0.1:8900/v1",
api_key="dummy",
)
resp = client.chat.completions.create(
model="MiniMax-M2.7",
messages=[{"role": "user", "content": "hi"}],
extra_body={"session_id": "my-session"}, # optional: isolate conversation
)
print(resp.choices[0].message.content)
```
## 🐳 Docker
> [!TIP]

View File

@ -1,5 +1,6 @@
#!/bin/bash
# Count core agent lines (excluding channels/, cli/, providers/ adapters)
# Count core agent lines (excluding channels/, cli/, api/, providers/ adapters,
# and the high-level Python SDK facade)
cd "$(dirname "$0")" || exit 1
echo "nanobot core agent line count"
@ -15,7 +16,7 @@ root=$(cat nanobot/__init__.py nanobot/__main__.py | wc -l)
printf " %-16s %5s lines\n" "(root)" "$root"
echo ""
total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/command/*" ! -path "*/providers/*" ! -path "*/skills/*" | xargs cat | wc -l)
total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/api/*" ! -path "*/command/*" ! -path "*/providers/*" ! -path "*/skills/*" ! -path "nanobot/nanobot.py" | xargs cat | wc -l)
echo " Core total: $total lines"
echo ""
echo " (excludes: channels/, cli/, command/, providers/, skills/)"
echo " (excludes: channels/, cli/, api/, command/, providers/, skills/, nanobot.py)"

136
docs/PYTHON_SDK.md Normal file
View File

@ -0,0 +1,136 @@
# Python SDK
Use nanobot programmatically — load config, run the agent, get results.
## Quick Start
```python
import asyncio
from nanobot import Nanobot
async def main():
bot = Nanobot.from_config()
result = await bot.run("What time is it in Tokyo?")
print(result.content)
asyncio.run(main())
```
## API
### `Nanobot.from_config(config_path?, *, workspace?)`
Create a `Nanobot` from a config file.
| Param | Type | Default | Description |
|-------|------|---------|-------------|
| `config_path` | `str \| Path \| None` | `None` | Path to `config.json`. Defaults to `~/.nanobot/config.json`. |
| `workspace` | `str \| Path \| None` | `None` | Override workspace directory from config. |
Raises `FileNotFoundError` if an explicit path doesn't exist.
### `await bot.run(message, *, session_key?, hooks?)`
Run the agent once. Returns a `RunResult`.
| Param | Type | Default | Description |
|-------|------|---------|-------------|
| `message` | `str` | *(required)* | The user message to process. |
| `session_key` | `str` | `"sdk:default"` | Session identifier for conversation isolation. Different keys get independent history. |
| `hooks` | `list[AgentHook] \| None` | `None` | Lifecycle hooks for this run only. |
```python
# Isolated sessions — each user gets independent conversation history
await bot.run("hi", session_key="user-alice")
await bot.run("hi", session_key="user-bob")
```
### `RunResult`
| Field | Type | Description |
|-------|------|-------------|
| `content` | `str` | The agent's final text response. |
| `tools_used` | `list[str]` | Tool names invoked during the run. |
| `messages` | `list[dict]` | Raw message history (for debugging). |
## Hooks
Hooks let you observe or modify the agent loop without touching internals.
Subclass `AgentHook` and override any method:
| Method | When |
|--------|------|
| `before_iteration(ctx)` | Before each LLM call |
| `on_stream(ctx, delta)` | On each streamed token |
| `on_stream_end(ctx)` | When streaming finishes |
| `before_execute_tools(ctx)` | Before tool execution (inspect `ctx.tool_calls`) |
| `after_iteration(ctx, response)` | After each LLM response |
| `finalize_content(ctx, content)` | Transform final output text |
### Example: Audit Hook
```python
from nanobot.agent import AgentHook, AgentHookContext
class AuditHook(AgentHook):
def __init__(self):
self.calls = []
async def before_execute_tools(self, ctx: AgentHookContext) -> None:
for tc in ctx.tool_calls:
self.calls.append(tc.name)
print(f"[audit] {tc.name}({tc.arguments})")
hook = AuditHook()
result = await bot.run("List files in /tmp", hooks=[hook])
print(f"Tools used: {hook.calls}")
```
### Composing Hooks
Pass multiple hooks — they run in order, errors in one don't block others:
```python
result = await bot.run("hi", hooks=[AuditHook(), MetricsHook()])
```
Under the hood this uses `CompositeHook` for fan-out with error isolation.
### `finalize_content` Pipeline
Unlike the async methods (fan-out), `finalize_content` is a pipeline — each hook's output feeds the next:
```python
class Censor(AgentHook):
def finalize_content(self, ctx, content):
return content.replace("secret", "***") if content else content
```
## Full Example
```python
import asyncio
from nanobot import Nanobot
from nanobot.agent import AgentHook, AgentHookContext
class TimingHook(AgentHook):
async def before_iteration(self, ctx: AgentHookContext) -> None:
import time
ctx.metadata["_t0"] = time.time()
async def after_iteration(self, ctx, response) -> None:
import time
elapsed = time.time() - ctx.metadata.get("_t0", 0)
print(f"[timing] iteration took {elapsed:.2f}s")
async def main():
bot = Nanobot.from_config(workspace="/my/project")
result = await bot.run(
"Explain the main function",
hooks=[TimingHook()],
)
print(result.content)
asyncio.run(main())
```

View File

@ -4,3 +4,7 @@ nanobot - A lightweight AI agent framework
__version__ = "0.1.4.post6"
__logo__ = "🐈"
from nanobot.nanobot import Nanobot, RunResult
__all__ = ["Nanobot", "RunResult"]

View File

@ -1,8 +1,19 @@
"""Agent core module."""
from nanobot.agent.context import ContextBuilder
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
from nanobot.agent.loop import AgentLoop
from nanobot.agent.memory import MemoryStore
from nanobot.agent.skills import SkillsLoader
from nanobot.agent.subagent import SubagentManager
__all__ = ["AgentLoop", "ContextBuilder", "MemoryStore", "SkillsLoader"]
__all__ = [
"AgentHook",
"AgentHookContext",
"AgentLoop",
"CompositeHook",
"ContextBuilder",
"MemoryStore",
"SkillsLoader",
"SubagentManager",
]

View File

@ -5,6 +5,8 @@ from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from loguru import logger
from nanobot.providers.base import LLMResponse, ToolCallRequest
@ -47,3 +49,60 @@ class AgentHook:
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
return content
class CompositeHook(AgentHook):
"""Fan-out hook that delegates to an ordered list of hooks.
Error isolation: async methods catch and log per-hook exceptions
so a faulty custom hook cannot crash the agent loop.
``finalize_content`` is a pipeline (no isolation bugs should surface).
"""
__slots__ = ("_hooks",)
def __init__(self, hooks: list[AgentHook]) -> None:
self._hooks = list(hooks)
def wants_streaming(self) -> bool:
return any(h.wants_streaming() for h in self._hooks)
async def before_iteration(self, context: AgentHookContext) -> None:
for h in self._hooks:
try:
await h.before_iteration(context)
except Exception:
logger.exception("AgentHook.before_iteration error in {}", type(h).__name__)
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
for h in self._hooks:
try:
await h.on_stream(context, delta)
except Exception:
logger.exception("AgentHook.on_stream error in {}", type(h).__name__)
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
for h in self._hooks:
try:
await h.on_stream_end(context, resuming=resuming)
except Exception:
logger.exception("AgentHook.on_stream_end error in {}", type(h).__name__)
async def before_execute_tools(self, context: AgentHookContext) -> None:
for h in self._hooks:
try:
await h.before_execute_tools(context)
except Exception:
logger.exception("AgentHook.before_execute_tools error in {}", type(h).__name__)
async def after_iteration(self, context: AgentHookContext) -> None:
for h in self._hooks:
try:
await h.after_iteration(context)
except Exception:
logger.exception("AgentHook.after_iteration error in {}", type(h).__name__)
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
for h in self._hooks:
content = h.finalize_content(context, content)
return content

View File

@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
from loguru import logger
from nanobot.agent.context import ContextBuilder
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
from nanobot.agent.memory import MemoryConsolidator
from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.agent.subagent import SubagentManager
@ -37,6 +37,111 @@ if TYPE_CHECKING:
from nanobot.cron.service import CronService
class _LoopHook(AgentHook):
"""Core lifecycle hook for the main agent loop.
Handles streaming delta relay, progress reporting, tool-call logging,
and think-tag stripping for the built-in agent path.
"""
def __init__(
self,
agent_loop: AgentLoop,
on_progress: Callable[..., Awaitable[None]] | None = None,
on_stream: Callable[[str], Awaitable[None]] | None = None,
on_stream_end: Callable[..., Awaitable[None]] | None = None,
*,
channel: str = "cli",
chat_id: str = "direct",
message_id: str | None = None,
) -> None:
self._loop = agent_loop
self._on_progress = on_progress
self._on_stream = on_stream
self._on_stream_end = on_stream_end
self._channel = channel
self._chat_id = chat_id
self._message_id = message_id
self._stream_buf = ""
def wants_streaming(self) -> bool:
return self._on_stream is not None
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
from nanobot.utils.helpers import strip_think
prev_clean = strip_think(self._stream_buf)
self._stream_buf += delta
new_clean = strip_think(self._stream_buf)
incremental = new_clean[len(prev_clean):]
if incremental and self._on_stream:
await self._on_stream(incremental)
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
if self._on_stream_end:
await self._on_stream_end(resuming=resuming)
self._stream_buf = ""
async def before_execute_tools(self, context: AgentHookContext) -> None:
if self._on_progress:
if not self._on_stream:
thought = self._loop._strip_think(
context.response.content if context.response else None
)
if thought:
await self._on_progress(thought)
tool_hint = self._loop._strip_think(self._loop._tool_hint(context.tool_calls))
await self._on_progress(tool_hint, tool_hint=True)
for tc in context.tool_calls:
args_str = json.dumps(tc.arguments, ensure_ascii=False)
logger.info("Tool call: {}({})", tc.name, args_str[:200])
self._loop._set_tool_context(self._channel, self._chat_id, self._message_id)
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
return self._loop._strip_think(content)
class _LoopHookChain(AgentHook):
"""Run the core loop hook first, then best-effort extra hooks.
This preserves the historical failure behavior of ``_LoopHook`` while still
letting user-supplied hooks opt into ``CompositeHook`` isolation.
"""
__slots__ = ("_primary", "_extras")
def __init__(self, primary: AgentHook, extra_hooks: list[AgentHook]) -> None:
self._primary = primary
self._extras = CompositeHook(extra_hooks)
def wants_streaming(self) -> bool:
return self._primary.wants_streaming() or self._extras.wants_streaming()
async def before_iteration(self, context: AgentHookContext) -> None:
await self._primary.before_iteration(context)
await self._extras.before_iteration(context)
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
await self._primary.on_stream(context, delta)
await self._extras.on_stream(context, delta)
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
await self._primary.on_stream_end(context, resuming=resuming)
await self._extras.on_stream_end(context, resuming=resuming)
async def before_execute_tools(self, context: AgentHookContext) -> None:
await self._primary.before_execute_tools(context)
await self._extras.before_execute_tools(context)
async def after_iteration(self, context: AgentHookContext) -> None:
await self._primary.after_iteration(context)
await self._extras.after_iteration(context)
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
content = self._primary.finalize_content(context, content)
return self._extras.finalize_content(context, content)
class AgentLoop:
"""
The agent loop is the core processing engine.
@ -68,6 +173,7 @@ class AgentLoop:
mcp_servers: dict | None = None,
channels_config: ChannelsConfig | None = None,
timezone: str | None = None,
hooks: list[AgentHook] | None = None,
):
from nanobot.config.schema import ExecToolConfig, WebSearchConfig
@ -85,6 +191,7 @@ class AgentLoop:
self.restrict_to_workspace = restrict_to_workspace
self._start_time = time.time()
self._last_usage: dict[str, int] = {}
self._extra_hooks: list[AgentHook] = hooks or []
self.context = ContextBuilder(workspace, timezone=timezone)
self.sessions = session_manager or SessionManager(workspace)
@ -217,52 +324,27 @@ class AgentLoop:
``resuming=True`` means tool calls follow (spinner should restart);
``resuming=False`` means this is the final response.
"""
loop_self = self
class _LoopHook(AgentHook):
def __init__(self) -> None:
self._stream_buf = ""
def wants_streaming(self) -> bool:
return on_stream is not None
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
from nanobot.utils.helpers import strip_think
prev_clean = strip_think(self._stream_buf)
self._stream_buf += delta
new_clean = strip_think(self._stream_buf)
incremental = new_clean[len(prev_clean):]
if incremental and on_stream:
await on_stream(incremental)
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
if on_stream_end:
await on_stream_end(resuming=resuming)
self._stream_buf = ""
async def before_execute_tools(self, context: AgentHookContext) -> None:
if on_progress:
if not on_stream:
thought = loop_self._strip_think(context.response.content if context.response else None)
if thought:
await on_progress(thought)
tool_hint = loop_self._strip_think(loop_self._tool_hint(context.tool_calls))
await on_progress(tool_hint, tool_hint=True)
for tc in context.tool_calls:
args_str = json.dumps(tc.arguments, ensure_ascii=False)
logger.info("Tool call: {}({})", tc.name, args_str[:200])
loop_self._set_tool_context(channel, chat_id, message_id)
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
return loop_self._strip_think(content)
loop_hook = _LoopHook(
self,
on_progress=on_progress,
on_stream=on_stream,
on_stream_end=on_stream_end,
channel=channel,
chat_id=chat_id,
message_id=message_id,
)
hook: AgentHook = (
_LoopHookChain(loop_hook, self._extra_hooks)
if self._extra_hooks
else loop_hook
)
result = await self.runner.run(AgentRunSpec(
initial_messages=initial_messages,
tools=self.tools,
model=self.model,
max_iterations=self.max_iterations,
hook=_LoopHook(),
hook=hook,
error_message="Sorry, I encountered an error calling the AI model.",
concurrent_tools=True,
))
@ -321,25 +403,25 @@ class AgentLoop:
return f"{stream_base_id}:{stream_segment}"
async def on_stream(delta: str) -> None:
meta = dict(msg.metadata or {})
meta["_stream_delta"] = True
meta["_stream_id"] = _current_stream_id()
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content=delta,
metadata={
"_stream_delta": True,
"_stream_id": _current_stream_id(),
},
metadata=meta,
))
async def on_stream_end(*, resuming: bool = False) -> None:
nonlocal stream_segment
meta = dict(msg.metadata or {})
meta["_stream_end"] = True
meta["_resuming"] = resuming
meta["_stream_id"] = _current_stream_id()
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="",
metadata={
"_stream_end": True,
"_resuming": resuming,
"_stream_id": _current_stream_id(),
},
metadata=meta,
))
stream_segment += 1

View File

@ -21,6 +21,21 @@ from nanobot.config.schema import ExecToolConfig
from nanobot.providers.base import LLMProvider
class _SubagentHook(AgentHook):
"""Logging-only hook for subagent execution."""
def __init__(self, task_id: str) -> None:
self._task_id = task_id
async def before_execute_tools(self, context: AgentHookContext) -> None:
for tool_call in context.tool_calls:
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
logger.debug(
"Subagent [{}] executing: {} with arguments: {}",
self._task_id, tool_call.name, args_str,
)
class SubagentManager:
"""Manages background subagent execution."""
@ -100,33 +115,28 @@ class SubagentManager:
tools.register(WriteFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
tools.register(EditFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir))
tools.register(ExecTool(
working_dir=str(self.workspace),
timeout=self.exec_config.timeout,
restrict_to_workspace=self.restrict_to_workspace,
path_append=self.exec_config.path_append,
))
if self.exec_config.enable:
tools.register(ExecTool(
working_dir=str(self.workspace),
timeout=self.exec_config.timeout,
restrict_to_workspace=self.restrict_to_workspace,
path_append=self.exec_config.path_append,
))
tools.register(WebSearchTool(config=self.web_search_config, proxy=self.web_proxy))
tools.register(WebFetchTool(proxy=self.web_proxy))
system_prompt = self._build_subagent_prompt()
messages: list[dict[str, Any]] = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": task},
]
class _SubagentHook(AgentHook):
async def before_execute_tools(self, context: AgentHookContext) -> None:
for tool_call in context.tool_calls:
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str)
result = await self.runner.run(AgentRunSpec(
initial_messages=messages,
tools=tools,
model=self.model,
max_iterations=15,
hook=_SubagentHook(),
hook=_SubagentHook(task_id),
max_iterations_message="Task completed but no final response was generated.",
error_message=None,
fail_on_tool_error=True,
@ -213,7 +223,7 @@ Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not men
lines.append("Failure:")
lines.append(f"- {result.error}")
return "\n".join(lines) or (result.error or "Error: subagent execution failed.")
def _build_subagent_prompt(self) -> str:
"""Build a focused system prompt for the subagent."""
from nanobot.agent.context import ContextBuilder

View File

@ -74,7 +74,7 @@ class CronTool(Tool):
"enum": ["add", "list", "remove"],
"description": "Action to perform",
},
"message": {"type": "string", "description": "Reminder message (for add)"},
"message": {"type": "string", "description": "Instruction for the agent to execute when the job triggers (e.g., 'Send a reminder to WeChat: xxx' or 'Check system status and report')"},
"every_seconds": {
"type": "integer",
"description": "Interval in seconds (for recurring tasks)",

View File

@ -170,7 +170,11 @@ async def connect_mcp_servers(
timeout: httpx.Timeout | None = None,
auth: httpx.Auth | None = None,
) -> httpx.AsyncClient:
merged_headers = {**(cfg.headers or {}), **(headers or {})}
merged_headers = {
"Accept": "application/json, text/event-stream",
**(cfg.headers or {}),
**(headers or {}),
}
return httpx.AsyncClient(
headers=merged_headers or None,
follow_redirects=True,

1
nanobot/api/__init__.py Normal file
View File

@ -0,0 +1 @@
"""OpenAI-compatible HTTP API for nanobot."""

193
nanobot/api/server.py Normal file
View File

@ -0,0 +1,193 @@
"""OpenAI-compatible HTTP API server for a fixed nanobot session.
Provides /v1/chat/completions and /v1/models endpoints.
All requests route to a single persistent API session.
"""
from __future__ import annotations
import asyncio
import time
import uuid
from typing import Any
from aiohttp import web
from loguru import logger
API_SESSION_KEY = "api:default"
API_CHAT_ID = "default"
# ---------------------------------------------------------------------------
# Response helpers
# ---------------------------------------------------------------------------
def _error_json(status: int, message: str, err_type: str = "invalid_request_error") -> web.Response:
return web.json_response(
{"error": {"message": message, "type": err_type, "code": status}},
status=status,
)
def _chat_completion_response(content: str, model: str) -> dict[str, Any]:
return {
"id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": content},
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
}
def _response_text(value: Any) -> str:
"""Normalize process_direct output to plain assistant text."""
if value is None:
return ""
if hasattr(value, "content"):
return str(getattr(value, "content") or "")
return str(value)
# ---------------------------------------------------------------------------
# Route handlers
# ---------------------------------------------------------------------------
async def handle_chat_completions(request: web.Request) -> web.Response:
"""POST /v1/chat/completions"""
# --- Parse body ---
try:
body = await request.json()
except Exception:
return _error_json(400, "Invalid JSON body")
messages = body.get("messages")
if not isinstance(messages, list) or len(messages) != 1:
return _error_json(400, "Only a single user message is supported")
# Stream not yet supported
if body.get("stream", False):
return _error_json(400, "stream=true is not supported yet. Set stream=false or omit it.")
message = messages[0]
if not isinstance(message, dict) or message.get("role") != "user":
return _error_json(400, "Only a single user message is supported")
user_content = message.get("content", "")
if isinstance(user_content, list):
# Multi-modal content array — extract text parts
user_content = " ".join(
part.get("text", "") for part in user_content if part.get("type") == "text"
)
agent_loop = request.app["agent_loop"]
timeout_s: float = request.app.get("request_timeout", 120.0)
model_name: str = request.app.get("model_name", "nanobot")
if (requested_model := body.get("model")) and requested_model != model_name:
return _error_json(400, f"Only configured model '{model_name}' is available")
session_key = f"api:{body['session_id']}" if body.get("session_id") else API_SESSION_KEY
session_locks: dict[str, asyncio.Lock] = request.app["session_locks"]
session_lock = session_locks.setdefault(session_key, asyncio.Lock())
logger.info("API request session_key={} content={}", session_key, user_content[:80])
_FALLBACK = "I've completed processing but have no response to give."
try:
async with session_lock:
try:
response = await asyncio.wait_for(
agent_loop.process_direct(
content=user_content,
session_key=session_key,
channel="api",
chat_id=API_CHAT_ID,
),
timeout=timeout_s,
)
response_text = _response_text(response)
if not response_text or not response_text.strip():
logger.warning(
"Empty response for session {}, retrying",
session_key,
)
retry_response = await asyncio.wait_for(
agent_loop.process_direct(
content=user_content,
session_key=session_key,
channel="api",
chat_id=API_CHAT_ID,
),
timeout=timeout_s,
)
response_text = _response_text(retry_response)
if not response_text or not response_text.strip():
logger.warning(
"Empty response after retry for session {}, using fallback",
session_key,
)
response_text = _FALLBACK
except asyncio.TimeoutError:
return _error_json(504, f"Request timed out after {timeout_s}s")
except Exception:
logger.exception("Error processing request for session {}", session_key)
return _error_json(500, "Internal server error", err_type="server_error")
except Exception:
logger.exception("Unexpected API lock error for session {}", session_key)
return _error_json(500, "Internal server error", err_type="server_error")
return web.json_response(_chat_completion_response(response_text, model_name))
async def handle_models(request: web.Request) -> web.Response:
"""GET /v1/models"""
model_name = request.app.get("model_name", "nanobot")
return web.json_response({
"object": "list",
"data": [
{
"id": model_name,
"object": "model",
"created": 0,
"owned_by": "nanobot",
}
],
})
async def handle_health(request: web.Request) -> web.Response:
"""GET /health"""
return web.json_response({"status": "ok"})
# ---------------------------------------------------------------------------
# App factory
# ---------------------------------------------------------------------------
def create_app(agent_loop, model_name: str = "nanobot", request_timeout: float = 120.0) -> web.Application:
"""Create the aiohttp application.
Args:
agent_loop: An initialized AgentLoop instance.
model_name: Model name reported in responses.
request_timeout: Per-request timeout in seconds.
"""
app = web.Application()
app["agent_loop"] = agent_loop
app["model_name"] = model_name
app["request_timeout"] = request_timeout
app["session_locks"] = {} # per-user locks, keyed by session_key
app.router.add_post("/v1/chat/completions", handle_chat_completions)
app.router.add_get("/v1/models", handle_models)
app.router.add_get("/health", handle_health)
return app

View File

@ -1,25 +1,37 @@
"""Discord channel implementation using Discord Gateway websocket."""
"""Discord channel implementation using discord.py."""
from __future__ import annotations
import asyncio
import json
import importlib.util
from pathlib import Path
from typing import Any, Literal
from typing import TYPE_CHECKING, Any, Literal
import httpx
from pydantic import Field
import websockets
from loguru import logger
from pydantic import Field
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.command.builtin import build_help_text
from nanobot.config.paths import get_media_dir
from nanobot.config.schema import Base
from nanobot.utils.helpers import split_message
from nanobot.utils.helpers import safe_filename, split_message
DISCORD_AVAILABLE = importlib.util.find_spec("discord") is not None
if TYPE_CHECKING:
import discord
from discord import app_commands
from discord.abc import Messageable
if DISCORD_AVAILABLE:
import discord
from discord import app_commands
from discord.abc import Messageable
DISCORD_API_BASE = "https://discord.com/api/v10"
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
MAX_MESSAGE_LEN = 2000 # Discord message character limit
TYPING_INTERVAL_S = 8
class DiscordConfig(Base):
@ -28,13 +40,205 @@ class DiscordConfig(Base):
enabled: bool = False
token: str = ""
allow_from: list[str] = Field(default_factory=list)
gateway_url: str = "wss://gateway.discord.gg/?v=10&encoding=json"
intents: int = 37377
group_policy: Literal["mention", "open"] = "mention"
read_receipt_emoji: str = "👀"
working_emoji: str = "🔧"
working_emoji_delay: float = 2.0
if DISCORD_AVAILABLE:
class DiscordBotClient(discord.Client):
"""discord.py client that forwards events to the channel."""
def __init__(self, channel: DiscordChannel, *, intents: discord.Intents) -> None:
super().__init__(intents=intents)
self._channel = channel
self.tree = app_commands.CommandTree(self)
self._register_app_commands()
async def on_ready(self) -> None:
self._channel._bot_user_id = str(self.user.id) if self.user else None
logger.info("Discord bot connected as user {}", self._channel._bot_user_id)
try:
synced = await self.tree.sync()
logger.info("Discord app commands synced: {}", len(synced))
except Exception as e:
logger.warning("Discord app command sync failed: {}", e)
async def on_message(self, message: discord.Message) -> None:
await self._channel._handle_discord_message(message)
async def _reply_ephemeral(self, interaction: discord.Interaction, text: str) -> bool:
"""Send an ephemeral interaction response and report success."""
try:
await interaction.response.send_message(text, ephemeral=True)
return True
except Exception as e:
logger.warning("Discord interaction response failed: {}", e)
return False
async def _forward_slash_command(
self,
interaction: discord.Interaction,
command_text: str,
) -> None:
sender_id = str(interaction.user.id)
channel_id = interaction.channel_id
if channel_id is None:
logger.warning("Discord slash command missing channel_id: {}", command_text)
return
if not self._channel.is_allowed(sender_id):
await self._reply_ephemeral(interaction, "You are not allowed to use this bot.")
return
await self._reply_ephemeral(interaction, f"Processing {command_text}...")
await self._channel._handle_message(
sender_id=sender_id,
chat_id=str(channel_id),
content=command_text,
metadata={
"interaction_id": str(interaction.id),
"guild_id": str(interaction.guild_id) if interaction.guild_id else None,
"is_slash_command": True,
},
)
def _register_app_commands(self) -> None:
commands = (
("new", "Start a new conversation", "/new"),
("stop", "Stop the current task", "/stop"),
("restart", "Restart the bot", "/restart"),
("status", "Show bot status", "/status"),
)
for name, description, command_text in commands:
@self.tree.command(name=name, description=description)
async def command_handler(
interaction: discord.Interaction,
_command_text: str = command_text,
) -> None:
await self._forward_slash_command(interaction, _command_text)
@self.tree.command(name="help", description="Show available commands")
async def help_command(interaction: discord.Interaction) -> None:
sender_id = str(interaction.user.id)
if not self._channel.is_allowed(sender_id):
await self._reply_ephemeral(interaction, "You are not allowed to use this bot.")
return
await self._reply_ephemeral(interaction, build_help_text())
@self.tree.error
async def on_app_command_error(
interaction: discord.Interaction,
error: app_commands.AppCommandError,
) -> None:
command_name = interaction.command.qualified_name if interaction.command else "?"
logger.warning(
"Discord app command failed user={} channel={} cmd={} error={}",
interaction.user.id,
interaction.channel_id,
command_name,
error,
)
async def send_outbound(self, msg: OutboundMessage) -> None:
"""Send a nanobot outbound message using Discord transport rules."""
channel_id = int(msg.chat_id)
channel = self.get_channel(channel_id)
if channel is None:
try:
channel = await self.fetch_channel(channel_id)
except Exception as e:
logger.warning("Discord channel {} unavailable: {}", msg.chat_id, e)
return
reference, mention_settings = self._build_reply_context(channel, msg.reply_to)
sent_media = False
failed_media: list[str] = []
for index, media_path in enumerate(msg.media or []):
if await self._send_file(
channel,
media_path,
reference=reference if index == 0 else None,
mention_settings=mention_settings,
):
sent_media = True
else:
failed_media.append(Path(media_path).name)
for index, chunk in enumerate(self._build_chunks(msg.content or "", failed_media, sent_media)):
kwargs: dict[str, Any] = {"content": chunk}
if index == 0 and reference is not None and not sent_media:
kwargs["reference"] = reference
kwargs["allowed_mentions"] = mention_settings
await channel.send(**kwargs)
async def _send_file(
self,
channel: Messageable,
file_path: str,
*,
reference: discord.PartialMessage | None,
mention_settings: discord.AllowedMentions,
) -> bool:
"""Send a file attachment via discord.py."""
path = Path(file_path)
if not path.is_file():
logger.warning("Discord file not found, skipping: {}", file_path)
return False
if path.stat().st_size > MAX_ATTACHMENT_BYTES:
logger.warning("Discord file too large (>20MB), skipping: {}", path.name)
return False
try:
kwargs: dict[str, Any] = {"file": discord.File(path)}
if reference is not None:
kwargs["reference"] = reference
kwargs["allowed_mentions"] = mention_settings
await channel.send(**kwargs)
logger.info("Discord file sent: {}", path.name)
return True
except Exception as e:
logger.error("Error sending Discord file {}: {}", path.name, e)
return False
@staticmethod
def _build_chunks(content: str, failed_media: list[str], sent_media: bool) -> list[str]:
"""Build outbound text chunks, including attachment-failure fallback text."""
chunks = split_message(content, MAX_MESSAGE_LEN)
if chunks or not failed_media or sent_media:
return chunks
fallback = "\n".join(f"[attachment: {name} - send failed]" for name in failed_media)
return split_message(fallback, MAX_MESSAGE_LEN)
@staticmethod
def _build_reply_context(
channel: Messageable,
reply_to: str | None,
) -> tuple[discord.PartialMessage | None, discord.AllowedMentions]:
"""Build reply context for outbound messages."""
mention_settings = discord.AllowedMentions(replied_user=False)
if not reply_to:
return None, mention_settings
try:
message_id = int(reply_to)
except ValueError:
logger.warning("Invalid Discord reply target: {}", reply_to)
return None, mention_settings
return channel.get_partial_message(message_id), mention_settings
class DiscordChannel(BaseChannel):
"""Discord channel using Gateway websocket."""
"""Discord channel using discord.py."""
name = "discord"
display_name = "Discord"
@ -43,353 +247,270 @@ class DiscordChannel(BaseChannel):
def default_config(cls) -> dict[str, Any]:
return DiscordConfig().model_dump(by_alias=True)
@staticmethod
def _channel_key(channel_or_id: Any) -> str:
"""Normalize channel-like objects and ids to a stable string key."""
channel_id = getattr(channel_or_id, "id", channel_or_id)
return str(channel_id)
def __init__(self, config: Any, bus: MessageBus):
if isinstance(config, dict):
config = DiscordConfig.model_validate(config)
super().__init__(config, bus)
self.config: DiscordConfig = config
self._ws: websockets.WebSocketClientProtocol | None = None
self._seq: int | None = None
self._heartbeat_task: asyncio.Task | None = None
self._typing_tasks: dict[str, asyncio.Task] = {}
self._http: httpx.AsyncClient | None = None
self._client: DiscordBotClient | None = None
self._typing_tasks: dict[str, asyncio.Task[None]] = {}
self._bot_user_id: str | None = None
self._pending_reactions: dict[str, Any] = {} # chat_id -> message object
self._working_emoji_tasks: dict[str, asyncio.Task[None]] = {}
async def start(self) -> None:
"""Start the Discord gateway connection."""
"""Start the Discord client."""
if not DISCORD_AVAILABLE:
logger.error("discord.py not installed. Run: pip install nanobot-ai[discord]")
return
if not self.config.token:
logger.error("Discord bot token not configured")
return
self._running = True
self._http = httpx.AsyncClient(timeout=30.0)
try:
intents = discord.Intents.none()
intents.value = self.config.intents
self._client = DiscordBotClient(self, intents=intents)
except Exception as e:
logger.error("Failed to initialize Discord client: {}", e)
self._client = None
self._running = False
return
while self._running:
try:
logger.info("Connecting to Discord gateway...")
async with websockets.connect(self.config.gateway_url) as ws:
self._ws = ws
await self._gateway_loop()
except asyncio.CancelledError:
break
except Exception as e:
logger.warning("Discord gateway error: {}", e)
if self._running:
logger.info("Reconnecting to Discord gateway in 5 seconds...")
await asyncio.sleep(5)
self._running = True
logger.info("Starting Discord client via discord.py...")
try:
await self._client.start(self.config.token)
except asyncio.CancelledError:
raise
except Exception as e:
logger.error("Discord client startup failed: {}", e)
finally:
self._running = False
await self._reset_runtime_state(close_client=True)
async def stop(self) -> None:
"""Stop the Discord channel."""
self._running = False
if self._heartbeat_task:
self._heartbeat_task.cancel()
self._heartbeat_task = None
for task in self._typing_tasks.values():
task.cancel()
self._typing_tasks.clear()
if self._ws:
await self._ws.close()
self._ws = None
if self._http:
await self._http.aclose()
self._http = None
await self._reset_runtime_state(close_client=True)
async def send(self, msg: OutboundMessage) -> None:
"""Send a message through Discord REST API, including file attachments."""
if not self._http:
logger.warning("Discord HTTP client not initialized")
"""Send a message through Discord using discord.py."""
client = self._client
if client is None or not client.is_ready():
logger.warning("Discord client not ready; dropping outbound message")
return
url = f"{DISCORD_API_BASE}/channels/{msg.chat_id}/messages"
headers = {"Authorization": f"Bot {self.config.token}"}
is_progress = bool((msg.metadata or {}).get("_progress"))
try:
sent_media = False
failed_media: list[str] = []
# Send file attachments first
for media_path in msg.media or []:
if await self._send_file(url, headers, media_path, reply_to=msg.reply_to):
sent_media = True
else:
failed_media.append(Path(media_path).name)
# Send text content
chunks = split_message(msg.content or "", MAX_MESSAGE_LEN)
if not chunks and failed_media and not sent_media:
chunks = split_message(
"\n".join(f"[attachment: {name} - send failed]" for name in failed_media),
MAX_MESSAGE_LEN,
)
if not chunks:
return
for i, chunk in enumerate(chunks):
payload: dict[str, Any] = {"content": chunk}
# Let the first successful attachment carry the reply if present.
if i == 0 and msg.reply_to and not sent_media:
payload["message_reference"] = {"message_id": msg.reply_to}
payload["allowed_mentions"] = {"replied_user": False}
if not await self._send_payload(url, headers, payload):
break # Abort remaining chunks on failure
await client.send_outbound(msg)
except Exception as e:
logger.error("Error sending Discord message: {}", e)
finally:
await self._stop_typing(msg.chat_id)
if not is_progress:
await self._stop_typing(msg.chat_id)
await self._clear_reactions(msg.chat_id)
async def _send_payload(
self, url: str, headers: dict[str, str], payload: dict[str, Any]
) -> bool:
"""Send a single Discord API payload with retry on rate-limit. Returns True on success."""
for attempt in range(3):
async def _handle_discord_message(self, message: discord.Message) -> None:
"""Handle incoming Discord messages from discord.py."""
if message.author.bot:
return
sender_id = str(message.author.id)
channel_id = self._channel_key(message.channel)
content = message.content or ""
if not self._should_accept_inbound(message, sender_id, content):
return
media_paths, attachment_markers = await self._download_attachments(message.attachments)
full_content = self._compose_inbound_content(content, attachment_markers)
metadata = self._build_inbound_metadata(message)
await self._start_typing(message.channel)
# Add read receipt reaction immediately, working emoji after delay
channel_id = self._channel_key(message.channel)
try:
await message.add_reaction(self.config.read_receipt_emoji)
self._pending_reactions[channel_id] = message
except Exception as e:
logger.debug("Failed to add read receipt reaction: {}", e)
# Delayed working indicator (cosmetic — not tied to subagent lifecycle)
async def _delayed_working_emoji() -> None:
await asyncio.sleep(self.config.working_emoji_delay)
try:
response = await self._http.post(url, headers=headers, json=payload)
if response.status_code == 429:
data = response.json()
retry_after = float(data.get("retry_after", 1.0))
logger.warning("Discord rate limited, retrying in {}s", retry_after)
await asyncio.sleep(retry_after)
continue
response.raise_for_status()
return True
except Exception as e:
if attempt == 2:
logger.error("Error sending Discord message: {}", e)
else:
await asyncio.sleep(1)
return False
await message.add_reaction(self.config.working_emoji)
except Exception:
pass
async def _send_file(
self._working_emoji_tasks[channel_id] = asyncio.create_task(_delayed_working_emoji())
try:
await self._handle_message(
sender_id=sender_id,
chat_id=channel_id,
content=full_content,
media=media_paths,
metadata=metadata,
)
except Exception:
await self._clear_reactions(channel_id)
await self._stop_typing(channel_id)
raise
async def _on_message(self, message: discord.Message) -> None:
"""Backward-compatible alias for legacy tests/callers."""
await self._handle_discord_message(message)
def _should_accept_inbound(
self,
url: str,
headers: dict[str, str],
file_path: str,
reply_to: str | None = None,
message: discord.Message,
sender_id: str,
content: str,
) -> bool:
"""Send a file attachment via Discord REST API using multipart/form-data."""
path = Path(file_path)
if not path.is_file():
logger.warning("Discord file not found, skipping: {}", file_path)
return False
if path.stat().st_size > MAX_ATTACHMENT_BYTES:
logger.warning("Discord file too large (>20MB), skipping: {}", path.name)
return False
payload_json: dict[str, Any] = {}
if reply_to:
payload_json["message_reference"] = {"message_id": reply_to}
payload_json["allowed_mentions"] = {"replied_user": False}
for attempt in range(3):
try:
with open(path, "rb") as f:
files = {"files[0]": (path.name, f, "application/octet-stream")}
data: dict[str, Any] = {}
if payload_json:
data["payload_json"] = json.dumps(payload_json)
response = await self._http.post(
url, headers=headers, files=files, data=data
)
if response.status_code == 429:
resp_data = response.json()
retry_after = float(resp_data.get("retry_after", 1.0))
logger.warning("Discord rate limited, retrying in {}s", retry_after)
await asyncio.sleep(retry_after)
continue
response.raise_for_status()
logger.info("Discord file sent: {}", path.name)
return True
except Exception as e:
if attempt == 2:
logger.error("Error sending Discord file {}: {}", path.name, e)
else:
await asyncio.sleep(1)
return False
async def _gateway_loop(self) -> None:
"""Main gateway loop: identify, heartbeat, dispatch events."""
if not self._ws:
return
async for raw in self._ws:
try:
data = json.loads(raw)
except json.JSONDecodeError:
logger.warning("Invalid JSON from Discord gateway: {}", raw[:100])
continue
op = data.get("op")
event_type = data.get("t")
seq = data.get("s")
payload = data.get("d")
if seq is not None:
self._seq = seq
if op == 10:
# HELLO: start heartbeat and identify
interval_ms = payload.get("heartbeat_interval", 45000)
await self._start_heartbeat(interval_ms / 1000)
await self._identify()
elif op == 0 and event_type == "READY":
logger.info("Discord gateway READY")
# Capture bot user ID for mention detection
user_data = payload.get("user") or {}
self._bot_user_id = user_data.get("id")
logger.info("Discord bot connected as user {}", self._bot_user_id)
elif op == 0 and event_type == "MESSAGE_CREATE":
await self._handle_message_create(payload)
elif op == 7:
# RECONNECT: exit loop to reconnect
logger.info("Discord gateway requested reconnect")
break
elif op == 9:
# INVALID_SESSION: reconnect
logger.warning("Discord gateway invalid session")
break
async def _identify(self) -> None:
"""Send IDENTIFY payload."""
if not self._ws:
return
identify = {
"op": 2,
"d": {
"token": self.config.token,
"intents": self.config.intents,
"properties": {
"os": "nanobot",
"browser": "nanobot",
"device": "nanobot",
},
},
}
await self._ws.send(json.dumps(identify))
async def _start_heartbeat(self, interval_s: float) -> None:
"""Start or restart the heartbeat loop."""
if self._heartbeat_task:
self._heartbeat_task.cancel()
async def heartbeat_loop() -> None:
while self._running and self._ws:
payload = {"op": 1, "d": self._seq}
try:
await self._ws.send(json.dumps(payload))
except Exception as e:
logger.warning("Discord heartbeat failed: {}", e)
break
await asyncio.sleep(interval_s)
self._heartbeat_task = asyncio.create_task(heartbeat_loop())
async def _handle_message_create(self, payload: dict[str, Any]) -> None:
"""Handle incoming Discord messages."""
author = payload.get("author") or {}
if author.get("bot"):
return
sender_id = str(author.get("id", ""))
channel_id = str(payload.get("channel_id", ""))
content = payload.get("content") or ""
guild_id = payload.get("guild_id")
if not sender_id or not channel_id:
return
"""Check if inbound Discord message should be processed."""
if not self.is_allowed(sender_id):
return
return False
if message.guild is not None and not self._should_respond_in_group(message, content):
return False
return True
# Check group channel policy (DMs always respond if is_allowed passes)
if guild_id is not None:
if not self._should_respond_in_group(payload, content):
return
content_parts = [content] if content else []
async def _download_attachments(
self,
attachments: list[discord.Attachment],
) -> tuple[list[str], list[str]]:
"""Download supported attachments and return paths + display markers."""
media_paths: list[str] = []
markers: list[str] = []
media_dir = get_media_dir("discord")
for attachment in payload.get("attachments") or []:
url = attachment.get("url")
filename = attachment.get("filename") or "attachment"
size = attachment.get("size") or 0
if not url or not self._http:
continue
if size and size > MAX_ATTACHMENT_BYTES:
content_parts.append(f"[attachment: {filename} - too large]")
for attachment in attachments:
filename = attachment.filename or "attachment"
if attachment.size and attachment.size > MAX_ATTACHMENT_BYTES:
markers.append(f"[attachment: {filename} - too large]")
continue
try:
media_dir.mkdir(parents=True, exist_ok=True)
file_path = media_dir / f"{attachment.get('id', 'file')}_{filename.replace('/', '_')}"
resp = await self._http.get(url)
resp.raise_for_status()
file_path.write_bytes(resp.content)
safe_name = safe_filename(filename)
file_path = media_dir / f"{attachment.id}_{safe_name}"
await attachment.save(file_path)
media_paths.append(str(file_path))
content_parts.append(f"[attachment: {file_path}]")
markers.append(f"[attachment: {file_path.name}]")
except Exception as e:
logger.warning("Failed to download Discord attachment: {}", e)
content_parts.append(f"[attachment: {filename} - download failed]")
markers.append(f"[attachment: {filename} - download failed]")
reply_to = (payload.get("referenced_message") or {}).get("id")
return media_paths, markers
await self._start_typing(channel_id)
@staticmethod
def _compose_inbound_content(content: str, attachment_markers: list[str]) -> str:
"""Combine message text with attachment markers."""
content_parts = [content] if content else []
content_parts.extend(attachment_markers)
return "\n".join(part for part in content_parts if part) or "[empty message]"
await self._handle_message(
sender_id=sender_id,
chat_id=channel_id,
content="\n".join(p for p in content_parts if p) or "[empty message]",
media=media_paths,
metadata={
"message_id": str(payload.get("id", "")),
"guild_id": guild_id,
"reply_to": reply_to,
},
)
@staticmethod
def _build_inbound_metadata(message: discord.Message) -> dict[str, str | None]:
"""Build metadata for inbound Discord messages."""
reply_to = str(message.reference.message_id) if message.reference and message.reference.message_id else None
return {
"message_id": str(message.id),
"guild_id": str(message.guild.id) if message.guild else None,
"reply_to": reply_to,
}
def _should_respond_in_group(self, payload: dict[str, Any], content: str) -> bool:
"""Check if bot should respond in a group channel based on policy."""
def _should_respond_in_group(self, message: discord.Message, content: str) -> bool:
"""Check if the bot should respond in a guild channel based on policy."""
if self.config.group_policy == "open":
return True
if self.config.group_policy == "mention":
# Check if bot was mentioned in the message
if self._bot_user_id:
# Check mentions array
mentions = payload.get("mentions") or []
for mention in mentions:
if str(mention.get("id")) == self._bot_user_id:
return True
# Also check content for mention format <@USER_ID>
if f"<@{self._bot_user_id}>" in content or f"<@!{self._bot_user_id}>" in content:
return True
logger.debug("Discord message in {} ignored (bot not mentioned)", payload.get("channel_id"))
bot_user_id = self._bot_user_id
if bot_user_id is None:
logger.debug("Discord message in {} ignored (bot identity unavailable)", message.channel.id)
return False
if any(str(user.id) == bot_user_id for user in message.mentions):
return True
if f"<@{bot_user_id}>" in content or f"<@!{bot_user_id}>" in content:
return True
logger.debug("Discord message in {} ignored (bot not mentioned)", message.channel.id)
return False
return True
async def _start_typing(self, channel_id: str) -> None:
async def _start_typing(self, channel: Messageable) -> None:
"""Start periodic typing indicator for a channel."""
channel_id = self._channel_key(channel)
await self._stop_typing(channel_id)
async def typing_loop() -> None:
url = f"{DISCORD_API_BASE}/channels/{channel_id}/typing"
headers = {"Authorization": f"Bot {self.config.token}"}
while self._running:
try:
await self._http.post(url, headers=headers)
async with channel.typing():
await asyncio.sleep(TYPING_INTERVAL_S)
except asyncio.CancelledError:
return
except Exception as e:
logger.debug("Discord typing indicator failed for {}: {}", channel_id, e)
return
await asyncio.sleep(8)
self._typing_tasks[channel_id] = asyncio.create_task(typing_loop())
async def _stop_typing(self, channel_id: str) -> None:
"""Stop typing indicator for a channel."""
task = self._typing_tasks.pop(channel_id, None)
if task:
task = self._typing_tasks.pop(self._channel_key(channel_id), None)
if task is None:
return
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
async def _clear_reactions(self, chat_id: str) -> None:
"""Remove all pending reactions after bot replies."""
# Cancel delayed working emoji if it hasn't fired yet
task = self._working_emoji_tasks.pop(chat_id, None)
if task and not task.done():
task.cancel()
msg_obj = self._pending_reactions.pop(chat_id, None)
if msg_obj is None:
return
bot_user = self._client.user if self._client else None
for emoji in (self.config.read_receipt_emoji, self.config.working_emoji):
try:
await msg_obj.remove_reaction(emoji, bot_user)
except Exception:
pass
async def _cancel_all_typing(self) -> None:
"""Stop all typing tasks."""
channel_ids = list(self._typing_tasks)
for channel_id in channel_ids:
await self._stop_typing(channel_id)
async def _reset_runtime_state(self, close_client: bool) -> None:
"""Reset client and typing state."""
await self._cancel_all_typing()
if close_client and self._client is not None and not self._client.is_closed():
try:
await self._client.close()
except Exception as e:
logger.warning("Discord client close failed: {}", e)
self._client = None
self._bot_user_id = None

View File

@ -3,6 +3,8 @@
import asyncio
import logging
import mimetypes
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, TypeAlias
@ -28,8 +30,8 @@ try:
RoomSendError,
RoomTypingError,
SyncError,
UploadError,
)
UploadError, RoomSendResponse,
)
from nio.crypto.attachments import decrypt_attachment
from nio.exceptions import EncryptionError
except ImportError as e:
@ -97,6 +99,22 @@ MATRIX_HTML_CLEANER = nh3.Cleaner(
link_rel="noopener noreferrer",
)
@dataclass
class _StreamBuf:
"""
Represents a buffer for managing LLM response stream data.
:ivar text: Stores the text content of the buffer.
:type text: str
:ivar event_id: Identifier for the associated event. None indicates no
specific event association.
:type event_id: str | None
:ivar last_edit: Timestamp of the most recent edit to the buffer.
:type last_edit: float
"""
text: str = ""
event_id: str | None = None
last_edit: float = 0.0
def _render_markdown_html(text: str) -> str | None:
"""Render markdown to sanitized HTML; returns None for plain text."""
@ -114,12 +132,47 @@ def _render_markdown_html(text: str) -> str | None:
return formatted
def _build_matrix_text_content(text: str) -> dict[str, object]:
"""Build Matrix m.text payload with optional HTML formatted_body."""
def _build_matrix_text_content(
text: str,
event_id: str | None = None,
thread_relates_to: dict[str, object] | None = None,
) -> dict[str, object]:
"""
Constructs and returns a dictionary representing the matrix text content with optional
HTML formatting and reference to an existing event for replacement. This function is
primarily used to create content payloads compatible with the Matrix messaging protocol.
:param text: The plain text content to include in the message.
:type text: str
:param event_id: Optional ID of the event to replace. If provided, the function will
include information indicating that the message is a replacement of the specified
event.
:type event_id: str | None
:param thread_relates_to: Optional Matrix thread relation metadata. For edits this is
stored in ``m.new_content`` so the replacement remains in the same thread.
:type thread_relates_to: dict[str, object] | None
:return: A dictionary containing the matrix text content, potentially enriched with
HTML formatting and replacement metadata if applicable.
:rtype: dict[str, object]
"""
content: dict[str, object] = {"msgtype": "m.text", "body": text, "m.mentions": {}}
if html := _render_markdown_html(text):
content["format"] = MATRIX_HTML_FORMAT
content["formatted_body"] = html
if event_id:
content["m.new_content"] = {
"body": text,
"msgtype": "m.text",
}
content["m.relates_to"] = {
"rel_type": "m.replace",
"event_id": event_id,
}
if thread_relates_to:
content["m.new_content"]["m.relates_to"] = thread_relates_to
elif thread_relates_to:
content["m.relates_to"] = thread_relates_to
return content
@ -159,7 +212,8 @@ class MatrixConfig(Base):
allow_from: list[str] = Field(default_factory=list)
group_policy: Literal["open", "mention", "allowlist"] = "open"
group_allow_from: list[str] = Field(default_factory=list)
allow_room_mentions: bool = False
allow_room_mentions: bool = False,
streaming: bool = False
class MatrixChannel(BaseChannel):
@ -167,6 +221,8 @@ class MatrixChannel(BaseChannel):
name = "matrix"
display_name = "Matrix"
_STREAM_EDIT_INTERVAL = 2 # min seconds between edit_message_text calls
monotonic_time = time.monotonic
@classmethod
def default_config(cls) -> dict[str, Any]:
@ -192,6 +248,8 @@ class MatrixChannel(BaseChannel):
)
self._server_upload_limit_bytes: int | None = None
self._server_upload_limit_checked = False
self._stream_bufs: dict[str, _StreamBuf] = {}
async def start(self) -> None:
"""Start Matrix client and begin sync loop."""
@ -297,14 +355,17 @@ class MatrixChannel(BaseChannel):
room = getattr(self.client, "rooms", {}).get(room_id)
return bool(getattr(room, "encrypted", False))
async def _send_room_content(self, room_id: str, content: dict[str, Any]) -> None:
async def _send_room_content(self, room_id: str,
content: dict[str, Any]) -> None | RoomSendResponse | RoomSendError:
"""Send m.room.message with E2EE options."""
if not self.client:
return
return None
kwargs: dict[str, Any] = {"room_id": room_id, "message_type": "m.room.message", "content": content}
if self.config.e2ee_enabled:
kwargs["ignore_unverified_devices"] = True
await self.client.room_send(**kwargs)
response = await self.client.room_send(**kwargs)
return response
async def _resolve_server_upload_limit_bytes(self) -> int | None:
"""Query homeserver upload limit once per channel lifecycle."""
@ -414,6 +475,53 @@ class MatrixChannel(BaseChannel):
if not is_progress:
await self._stop_typing_keepalive(msg.chat_id, clear_typing=True)
async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
meta = metadata or {}
relates_to = self._build_thread_relates_to(metadata)
if meta.get("_stream_end"):
buf = self._stream_bufs.pop(chat_id, None)
if not buf or not buf.event_id or not buf.text:
return
await self._stop_typing_keepalive(chat_id, clear_typing=True)
content = _build_matrix_text_content(
buf.text,
buf.event_id,
thread_relates_to=relates_to,
)
await self._send_room_content(chat_id, content)
return
buf = self._stream_bufs.get(chat_id)
if buf is None:
buf = _StreamBuf()
self._stream_bufs[chat_id] = buf
buf.text += delta
if not buf.text.strip():
return
now = self.monotonic_time()
if not buf.last_edit or (now - buf.last_edit) >= self._STREAM_EDIT_INTERVAL:
try:
content = _build_matrix_text_content(
buf.text,
buf.event_id,
thread_relates_to=relates_to,
)
response = await self._send_room_content(chat_id, content)
buf.last_edit = now
if not buf.event_id:
# we are editing the same message all the time, so only the first time the event id needs to be set
buf.event_id = response.event_id
except Exception:
await self._stop_typing_keepalive(chat_id, clear_typing=True)
pass
def _register_event_callbacks(self) -> None:
self.client.add_event_callback(self._on_message, RoomMessageText)
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER)

View File

@ -491,6 +491,91 @@ def _migrate_cron_store(config: "Config") -> None:
shutil.move(str(legacy_path), str(new_path))
# ============================================================================
# OpenAI-Compatible API Server
# ============================================================================
@app.command()
def serve(
port: int | None = typer.Option(None, "--port", "-p", help="API server port"),
host: str | None = typer.Option(None, "--host", "-H", help="Bind address"),
timeout: float | None = typer.Option(None, "--timeout", "-t", help="Per-request timeout (seconds)"),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Show nanobot runtime logs"),
workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
):
"""Start the OpenAI-compatible API server (/v1/chat/completions)."""
try:
from aiohttp import web # noqa: F401
except ImportError:
console.print("[red]aiohttp is required. Install with: pip install 'nanobot-ai[api]'[/red]")
raise typer.Exit(1)
from loguru import logger
from nanobot.agent.loop import AgentLoop
from nanobot.api.server import create_app
from nanobot.bus.queue import MessageBus
from nanobot.session.manager import SessionManager
if verbose:
logger.enable("nanobot")
else:
logger.disable("nanobot")
runtime_config = _load_runtime_config(config, workspace)
api_cfg = runtime_config.api
host = host if host is not None else api_cfg.host
port = port if port is not None else api_cfg.port
timeout = timeout if timeout is not None else api_cfg.timeout
sync_workspace_templates(runtime_config.workspace_path)
bus = MessageBus()
provider = _make_provider(runtime_config)
session_manager = SessionManager(runtime_config.workspace_path)
agent_loop = AgentLoop(
bus=bus,
provider=provider,
workspace=runtime_config.workspace_path,
model=runtime_config.agents.defaults.model,
max_iterations=runtime_config.agents.defaults.max_tool_iterations,
context_window_tokens=runtime_config.agents.defaults.context_window_tokens,
web_search_config=runtime_config.tools.web.search,
web_proxy=runtime_config.tools.web.proxy or None,
exec_config=runtime_config.tools.exec,
restrict_to_workspace=runtime_config.tools.restrict_to_workspace,
session_manager=session_manager,
mcp_servers=runtime_config.tools.mcp_servers,
channels_config=runtime_config.channels,
timezone=runtime_config.agents.defaults.timezone,
)
model_name = runtime_config.agents.defaults.model
console.print(f"{__logo__} Starting OpenAI-compatible API server")
console.print(f" [cyan]Endpoint[/cyan] : http://{host}:{port}/v1/chat/completions")
console.print(f" [cyan]Model[/cyan] : {model_name}")
console.print(" [cyan]Session[/cyan] : api:default")
console.print(f" [cyan]Timeout[/cyan] : {timeout}s")
if host in {"0.0.0.0", "::"}:
console.print(
"[yellow]Warning:[/yellow] API is bound to all interfaces. "
"Only do this behind a trusted network boundary, firewall, or reverse proxy."
)
console.print()
api_app = create_app(agent_loop, model_name=model_name, request_timeout=timeout)
async def on_startup(_app):
await agent_loop._connect_mcp()
async def on_cleanup(_app):
await agent_loop.close_mcp()
api_app.on_startup.append(on_startup)
api_app.on_cleanup.append(on_cleanup)
web.run_app(api_app, host=host, port=port, print=lambda msg: logger.info(msg))
# ============================================================================
# Gateway / Server
# ============================================================================

View File

@ -84,6 +84,16 @@ async def cmd_new(ctx: CommandContext) -> OutboundMessage:
async def cmd_help(ctx: CommandContext) -> OutboundMessage:
"""Return available slash commands."""
return OutboundMessage(
channel=ctx.msg.channel,
chat_id=ctx.msg.chat_id,
content=build_help_text(),
metadata={"render_as": "text"},
)
def build_help_text() -> str:
"""Build canonical help text shared across channels."""
lines = [
"🐈 nanobot commands:",
"/new — Start a new conversation",
@ -92,12 +102,7 @@ async def cmd_help(ctx: CommandContext) -> OutboundMessage:
"/status — Show bot status",
"/help — Show available commands",
]
return OutboundMessage(
channel=ctx.msg.channel,
chat_id=ctx.msg.chat_id,
content="\n".join(lines),
metadata={"render_as": "text"},
)
return "\n".join(lines)
def register_builtin_commands(router: CommandRouter) -> None:

View File

@ -96,6 +96,14 @@ class HeartbeatConfig(Base):
keep_recent_messages: int = 8
class ApiConfig(Base):
"""OpenAI-compatible API server configuration."""
host: str = "127.0.0.1" # Safer default: local-only bind.
port: int = 8900
timeout: float = 120.0 # Per-request timeout in seconds.
class GatewayConfig(Base):
"""Gateway/server configuration."""
@ -156,6 +164,7 @@ class Config(BaseSettings):
agents: AgentsConfig = Field(default_factory=AgentsConfig)
channels: ChannelsConfig = Field(default_factory=ChannelsConfig)
providers: ProvidersConfig = Field(default_factory=ProvidersConfig)
api: ApiConfig = Field(default_factory=ApiConfig)
gateway: GatewayConfig = Field(default_factory=GatewayConfig)
tools: ToolsConfig = Field(default_factory=ToolsConfig)

170
nanobot/nanobot.py Normal file
View File

@ -0,0 +1,170 @@
"""High-level programmatic interface to nanobot."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from nanobot.agent.hook import AgentHook
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
@dataclass(slots=True)
class RunResult:
"""Result of a single agent run."""
content: str
tools_used: list[str]
messages: list[dict[str, Any]]
class Nanobot:
"""Programmatic facade for running the nanobot agent.
Usage::
bot = Nanobot.from_config()
result = await bot.run("Summarize this repo", hooks=[MyHook()])
print(result.content)
"""
def __init__(self, loop: AgentLoop) -> None:
self._loop = loop
@classmethod
def from_config(
cls,
config_path: str | Path | None = None,
*,
workspace: str | Path | None = None,
) -> Nanobot:
"""Create a Nanobot instance from a config file.
Args:
config_path: Path to ``config.json``. Defaults to
``~/.nanobot/config.json``.
workspace: Override the workspace directory from config.
"""
from nanobot.config.loader import load_config
from nanobot.config.schema import Config
resolved: Path | None = None
if config_path is not None:
resolved = Path(config_path).expanduser().resolve()
if not resolved.exists():
raise FileNotFoundError(f"Config not found: {resolved}")
config: Config = load_config(resolved)
if workspace is not None:
config.agents.defaults.workspace = str(
Path(workspace).expanduser().resolve()
)
provider = _make_provider(config)
bus = MessageBus()
defaults = config.agents.defaults
loop = AgentLoop(
bus=bus,
provider=provider,
workspace=config.workspace_path,
model=defaults.model,
max_iterations=defaults.max_tool_iterations,
context_window_tokens=defaults.context_window_tokens,
web_search_config=config.tools.web.search,
web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec,
restrict_to_workspace=config.tools.restrict_to_workspace,
mcp_servers=config.tools.mcp_servers,
timezone=defaults.timezone,
)
return cls(loop)
async def run(
self,
message: str,
*,
session_key: str = "sdk:default",
hooks: list[AgentHook] | None = None,
) -> RunResult:
"""Run the agent once and return the result.
Args:
message: The user message to process.
session_key: Session identifier for conversation isolation.
Different keys get independent history.
hooks: Optional lifecycle hooks for this run.
"""
prev = self._loop._extra_hooks
if hooks is not None:
self._loop._extra_hooks = list(hooks)
try:
response = await self._loop.process_direct(
message, session_key=session_key,
)
finally:
self._loop._extra_hooks = prev
content = (response.content if response else None) or ""
return RunResult(content=content, tools_used=[], messages=[])
def _make_provider(config: Any) -> Any:
"""Create the LLM provider from config (extracted from CLI)."""
from nanobot.providers.base import GenerationSettings
from nanobot.providers.registry import find_by_name
model = config.agents.defaults.model
provider_name = config.get_provider_name(model)
p = config.get_provider(model)
spec = find_by_name(provider_name) if provider_name else None
backend = spec.backend if spec else "openai_compat"
if backend == "azure_openai":
if not p or not p.api_key or not p.api_base:
raise ValueError("Azure OpenAI requires api_key and api_base in config.")
elif backend == "openai_compat" and not model.startswith("bedrock/"):
needs_key = not (p and p.api_key)
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
if needs_key and not exempt:
raise ValueError(f"No API key configured for provider '{provider_name}'.")
if backend == "openai_codex":
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
provider = OpenAICodexProvider(default_model=model)
elif backend == "azure_openai":
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
provider = AzureOpenAIProvider(
api_key=p.api_key, api_base=p.api_base, default_model=model
)
elif backend == "anthropic":
from nanobot.providers.anthropic_provider import AnthropicProvider
provider = AnthropicProvider(
api_key=p.api_key if p else None,
api_base=config.get_api_base(model),
default_model=model,
extra_headers=p.extra_headers if p else None,
)
else:
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
provider = OpenAICompatProvider(
api_key=p.api_key if p else None,
api_base=config.get_api_base(model),
default_model=model,
extra_headers=p.extra_headers if p else None,
spec=spec,
)
defaults = config.agents.defaults
provider.generation = GenerationSettings(
temperature=defaults.temperature,
max_tokens=defaults.max_tokens,
reasoning_effort=defaults.reasoning_effort,
)
return provider

View File

@ -124,8 +124,8 @@ def build_assistant_message(
msg: dict[str, Any] = {"role": "assistant", "content": content}
if tool_calls:
msg["tool_calls"] = tool_calls
if reasoning_content is not None:
msg["reasoning_content"] = reasoning_content
if reasoning_content is not None or thinking_blocks:
msg["reasoning_content"] = reasoning_content if reasoning_content is not None else ""
if thinking_blocks:
msg["thinking_blocks"] = thinking_blocks
return msg

View File

@ -51,6 +51,9 @@ dependencies = [
]
[project.optional-dependencies]
api = [
"aiohttp>=3.9.0,<4.0.0",
]
wecom = [
"wecom-aibot-sdk-python>=0.1.5",
]
@ -64,12 +67,16 @@ matrix = [
"mistune>=3.0.0,<4.0.0",
"nh3>=0.2.17,<1.0.0",
]
discord = [
"discord.py>=2.5.2,<3.0.0",
]
langsmith = [
"langsmith>=0.1.0",
]
dev = [
"pytest>=9.0.0,<10.0.0",
"pytest-asyncio>=1.3.0,<2.0.0",
"aiohttp>=3.9.0,<4.0.0",
"pytest-cov>=6.0.0,<7.0.0",
"ruff>=0.1.0",
]

View File

@ -0,0 +1,351 @@
"""Tests for CompositeHook fan-out, error isolation, and integration."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
def _ctx() -> AgentHookContext:
return AgentHookContext(iteration=0, messages=[])
# ---------------------------------------------------------------------------
# Fan-out: every hook is called in order
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_composite_fans_out_before_iteration():
calls: list[str] = []
class H(AgentHook):
async def before_iteration(self, context: AgentHookContext) -> None:
calls.append(f"A:{context.iteration}")
class H2(AgentHook):
async def before_iteration(self, context: AgentHookContext) -> None:
calls.append(f"B:{context.iteration}")
hook = CompositeHook([H(), H2()])
ctx = _ctx()
await hook.before_iteration(ctx)
assert calls == ["A:0", "B:0"]
@pytest.mark.asyncio
async def test_composite_fans_out_all_async_methods():
"""Verify all async methods fan out to every hook."""
events: list[str] = []
class RecordingHook(AgentHook):
async def before_iteration(self, context: AgentHookContext) -> None:
events.append("before_iteration")
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
events.append(f"on_stream:{delta}")
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
events.append(f"on_stream_end:{resuming}")
async def before_execute_tools(self, context: AgentHookContext) -> None:
events.append("before_execute_tools")
async def after_iteration(self, context: AgentHookContext) -> None:
events.append("after_iteration")
hook = CompositeHook([RecordingHook(), RecordingHook()])
ctx = _ctx()
await hook.before_iteration(ctx)
await hook.on_stream(ctx, "hi")
await hook.on_stream_end(ctx, resuming=True)
await hook.before_execute_tools(ctx)
await hook.after_iteration(ctx)
assert events == [
"before_iteration", "before_iteration",
"on_stream:hi", "on_stream:hi",
"on_stream_end:True", "on_stream_end:True",
"before_execute_tools", "before_execute_tools",
"after_iteration", "after_iteration",
]
# ---------------------------------------------------------------------------
# Error isolation: one hook raises, others still run
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_composite_error_isolation_before_iteration():
calls: list[str] = []
class Bad(AgentHook):
async def before_iteration(self, context: AgentHookContext) -> None:
raise RuntimeError("boom")
class Good(AgentHook):
async def before_iteration(self, context: AgentHookContext) -> None:
calls.append("good")
hook = CompositeHook([Bad(), Good()])
await hook.before_iteration(_ctx())
assert calls == ["good"]
@pytest.mark.asyncio
async def test_composite_error_isolation_on_stream():
calls: list[str] = []
class Bad(AgentHook):
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
raise RuntimeError("stream-boom")
class Good(AgentHook):
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
calls.append(delta)
hook = CompositeHook([Bad(), Good()])
await hook.on_stream(_ctx(), "delta")
assert calls == ["delta"]
@pytest.mark.asyncio
async def test_composite_error_isolation_all_async():
"""Error isolation for on_stream_end, before_execute_tools, after_iteration."""
calls: list[str] = []
class Bad(AgentHook):
async def on_stream_end(self, context, *, resuming):
raise RuntimeError("err")
async def before_execute_tools(self, context):
raise RuntimeError("err")
async def after_iteration(self, context):
raise RuntimeError("err")
class Good(AgentHook):
async def on_stream_end(self, context, *, resuming):
calls.append("on_stream_end")
async def before_execute_tools(self, context):
calls.append("before_execute_tools")
async def after_iteration(self, context):
calls.append("after_iteration")
hook = CompositeHook([Bad(), Good()])
ctx = _ctx()
await hook.on_stream_end(ctx, resuming=False)
await hook.before_execute_tools(ctx)
await hook.after_iteration(ctx)
assert calls == ["on_stream_end", "before_execute_tools", "after_iteration"]
# ---------------------------------------------------------------------------
# finalize_content: pipeline semantics (no error isolation)
# ---------------------------------------------------------------------------
def test_composite_finalize_content_pipeline():
class Upper(AgentHook):
def finalize_content(self, context, content):
return content.upper() if content else content
class Suffix(AgentHook):
def finalize_content(self, context, content):
return (content + "!") if content else content
hook = CompositeHook([Upper(), Suffix()])
result = hook.finalize_content(_ctx(), "hello")
assert result == "HELLO!"
def test_composite_finalize_content_none_passthrough():
hook = CompositeHook([AgentHook()])
assert hook.finalize_content(_ctx(), None) is None
def test_composite_finalize_content_ordering():
"""First hook transforms first, result feeds second hook."""
steps: list[str] = []
class H1(AgentHook):
def finalize_content(self, context, content):
steps.append(f"H1:{content}")
return content.upper()
class H2(AgentHook):
def finalize_content(self, context, content):
steps.append(f"H2:{content}")
return content + "!"
hook = CompositeHook([H1(), H2()])
result = hook.finalize_content(_ctx(), "hi")
assert result == "HI!"
assert steps == ["H1:hi", "H2:HI"]
# ---------------------------------------------------------------------------
# wants_streaming: any-semantics
# ---------------------------------------------------------------------------
def test_composite_wants_streaming_any_true():
class No(AgentHook):
def wants_streaming(self):
return False
class Yes(AgentHook):
def wants_streaming(self):
return True
hook = CompositeHook([No(), Yes(), No()])
assert hook.wants_streaming() is True
def test_composite_wants_streaming_all_false():
hook = CompositeHook([AgentHook(), AgentHook()])
assert hook.wants_streaming() is False
def test_composite_wants_streaming_empty():
hook = CompositeHook([])
assert hook.wants_streaming() is False
# ---------------------------------------------------------------------------
# Empty hooks list: behaves like no-op AgentHook
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_composite_empty_hooks_no_ops():
hook = CompositeHook([])
ctx = _ctx()
await hook.before_iteration(ctx)
await hook.on_stream(ctx, "delta")
await hook.on_stream_end(ctx, resuming=False)
await hook.before_execute_tools(ctx)
await hook.after_iteration(ctx)
assert hook.finalize_content(ctx, "test") == "test"
# ---------------------------------------------------------------------------
# Integration: AgentLoop with extra hooks
# ---------------------------------------------------------------------------
def _make_loop(tmp_path, hooks=None):
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
provider.generation.max_tokens = 4096
with patch("nanobot.agent.loop.ContextBuilder"), \
patch("nanobot.agent.loop.SessionManager"), \
patch("nanobot.agent.loop.SubagentManager") as mock_sub_mgr, \
patch("nanobot.agent.loop.MemoryConsolidator"):
mock_sub_mgr.return_value.cancel_by_session = AsyncMock(return_value=0)
loop = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path, hooks=hooks,
)
return loop
@pytest.mark.asyncio
async def test_agent_loop_extra_hook_receives_calls(tmp_path):
"""Extra hook passed to AgentLoop is called alongside core LoopHook."""
from nanobot.providers.base import LLMResponse
events: list[str] = []
class TrackingHook(AgentHook):
async def before_iteration(self, context):
events.append(f"before_iter:{context.iteration}")
async def after_iteration(self, context):
events.append(f"after_iter:{context.iteration}")
loop = _make_loop(tmp_path, hooks=[TrackingHook()])
loop.provider.chat_with_retry = AsyncMock(
return_value=LLMResponse(content="done", tool_calls=[], usage={})
)
loop.tools.get_definitions = MagicMock(return_value=[])
content, tools_used, messages = await loop._run_agent_loop(
[{"role": "user", "content": "hi"}]
)
assert content == "done"
assert "before_iter:0" in events
assert "after_iter:0" in events
@pytest.mark.asyncio
async def test_agent_loop_extra_hook_error_isolation(tmp_path):
"""A faulty extra hook does not crash the agent loop."""
from nanobot.providers.base import LLMResponse
class BadHook(AgentHook):
async def before_iteration(self, context):
raise RuntimeError("I am broken")
loop = _make_loop(tmp_path, hooks=[BadHook()])
loop.provider.chat_with_retry = AsyncMock(
return_value=LLMResponse(content="still works", tool_calls=[], usage={})
)
loop.tools.get_definitions = MagicMock(return_value=[])
content, _, _ = await loop._run_agent_loop(
[{"role": "user", "content": "hi"}]
)
assert content == "still works"
@pytest.mark.asyncio
async def test_agent_loop_extra_hooks_do_not_swallow_loop_hook_errors(tmp_path):
"""Extra hooks must not change the core LoopHook failure behavior."""
from nanobot.providers.base import LLMResponse, ToolCallRequest
loop = _make_loop(tmp_path, hooks=[AgentHook()])
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
content="working",
tool_calls=[ToolCallRequest(id="c1", name="list_dir", arguments={"path": "."})],
usage={},
))
loop.tools.get_definitions = MagicMock(return_value=[])
loop.tools.execute = AsyncMock(return_value="ok")
async def bad_progress(*args, **kwargs):
raise RuntimeError("progress failed")
with pytest.raises(RuntimeError, match="progress failed"):
await loop._run_agent_loop([], on_progress=bad_progress)
@pytest.mark.asyncio
async def test_agent_loop_no_hooks_backward_compat(tmp_path):
"""Without hooks param, behavior is identical to before."""
from nanobot.providers.base import LLMResponse, ToolCallRequest
loop = _make_loop(tmp_path)
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
content="working",
tool_calls=[ToolCallRequest(id="c1", name="list_dir", arguments={"path": "."})],
))
loop.tools.get_definitions = MagicMock(return_value=[])
loop.tools.execute = AsyncMock(return_value="ok")
loop.max_iterations = 2
content, tools_used, _ = await loop._run_agent_loop([])
assert content == (
"I reached the maximum number of tool call iterations (2) "
"without completing the task. You can try breaking the task into smaller steps."
)
assert tools_used == ["list_dir", "list_dir"]

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@ -116,6 +117,43 @@ class TestDispatch:
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
assert out.content == "hi"
@pytest.mark.asyncio
async def test_dispatch_streaming_preserves_message_metadata(self):
from nanobot.bus.events import InboundMessage
loop, bus = _make_loop()
msg = InboundMessage(
channel="matrix",
sender_id="u1",
chat_id="!room:matrix.org",
content="hello",
metadata={
"_wants_stream": True,
"thread_root_event_id": "$root1",
"thread_reply_to_event_id": "$reply1",
},
)
async def fake_process(_msg, *, on_stream=None, on_stream_end=None, **kwargs):
assert on_stream is not None
assert on_stream_end is not None
await on_stream("hi")
await on_stream_end(resuming=False)
return None
loop._process_message = fake_process
await loop._dispatch(msg)
first = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
second = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
assert first.metadata["thread_root_event_id"] == "$root1"
assert first.metadata["thread_reply_to_event_id"] == "$reply1"
assert first.metadata["_stream_delta"] is True
assert second.metadata["thread_root_event_id"] == "$root1"
assert second.metadata["thread_reply_to_event_id"] == "$reply1"
assert second.metadata["_stream_end"] is True
@pytest.mark.asyncio
async def test_processing_lock_serializes(self):
from nanobot.bus.events import InboundMessage, OutboundMessage
@ -222,6 +260,39 @@ class TestSubagentCancellation:
assert assistant_messages[0]["reasoning_content"] == "hidden reasoning"
assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]
@pytest.mark.asyncio
async def test_subagent_exec_tool_not_registered_when_disabled(self, tmp_path):
from nanobot.agent.subagent import SubagentManager
from nanobot.bus.queue import MessageBus
from nanobot.config.schema import ExecToolConfig
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
mgr = SubagentManager(
provider=provider,
workspace=tmp_path,
bus=bus,
exec_config=ExecToolConfig(enable=False),
)
mgr._announce_result = AsyncMock()
async def fake_run(spec):
assert spec.tools.get("exec") is None
return SimpleNamespace(
stop_reason="done",
final_content="done",
error=None,
tool_events=[],
)
mgr.runner.run = AsyncMock(side_effect=fake_run)
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
mgr.runner.run.assert_awaited_once()
mgr._announce_result.assert_awaited_once()
@pytest.mark.asyncio
async def test_subagent_announces_error_when_tool_execution_fails(self, monkeypatch, tmp_path):
from nanobot.agent.subagent import SubagentManager

View File

@ -0,0 +1,676 @@
from __future__ import annotations
import asyncio
from pathlib import Path
from types import SimpleNamespace
import pytest
discord = pytest.importorskip("discord")
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.discord import DiscordBotClient, DiscordChannel, DiscordConfig
from nanobot.command.builtin import build_help_text
# Minimal Discord client test double used to control startup/readiness behavior.
class _FakeDiscordClient:
instances: list["_FakeDiscordClient"] = []
start_error: Exception | None = None
def __init__(self, owner, *, intents) -> None:
self.owner = owner
self.intents = intents
self.closed = False
self.ready = True
self.channels: dict[int, object] = {}
self.user = SimpleNamespace(id=999)
self.__class__.instances.append(self)
async def start(self, token: str) -> None:
self.token = token
if self.__class__.start_error is not None:
raise self.__class__.start_error
async def close(self) -> None:
self.closed = True
def is_closed(self) -> bool:
return self.closed
def is_ready(self) -> bool:
return self.ready
def get_channel(self, channel_id: int):
return self.channels.get(channel_id)
async def send_outbound(self, msg: OutboundMessage) -> None:
channel = self.get_channel(int(msg.chat_id))
if channel is None:
return
await channel.send(content=msg.content)
class _FakeAttachment:
# Attachment double that can simulate successful or failing save() calls.
def __init__(self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False) -> None:
self.id = attachment_id
self.filename = filename
self.size = size
self._fail = fail
async def save(self, path: str | Path) -> None:
if self._fail:
raise RuntimeError("save failed")
Path(path).write_bytes(b"attachment")
class _FakePartialMessage:
# Lightweight stand-in for Discord partial message references used in replies.
def __init__(self, message_id: int) -> None:
self.id = message_id
class _FakeChannel:
# Channel double that records outbound payloads and typing activity.
def __init__(self, channel_id: int = 123) -> None:
self.id = channel_id
self.sent_payloads: list[dict] = []
self.trigger_typing_calls = 0
self.typing_enter_hook = None
async def send(self, **kwargs) -> None:
payload = dict(kwargs)
if "file" in payload:
payload["file_name"] = payload["file"].filename
del payload["file"]
self.sent_payloads.append(payload)
def get_partial_message(self, message_id: int) -> _FakePartialMessage:
return _FakePartialMessage(message_id)
def typing(self):
channel = self
class _TypingContext:
async def __aenter__(self):
channel.trigger_typing_calls += 1
if channel.typing_enter_hook is not None:
await channel.typing_enter_hook()
async def __aexit__(self, exc_type, exc, tb):
return False
return _TypingContext()
class _FakeInteractionResponse:
def __init__(self) -> None:
self.messages: list[dict] = []
self._done = False
async def send_message(self, content: str, *, ephemeral: bool = False) -> None:
self.messages.append({"content": content, "ephemeral": ephemeral})
self._done = True
def is_done(self) -> bool:
return self._done
def _make_interaction(
*,
user_id: int = 123,
channel_id: int | None = 456,
guild_id: int | None = None,
interaction_id: int = 999,
):
return SimpleNamespace(
user=SimpleNamespace(id=user_id),
channel_id=channel_id,
guild_id=guild_id,
id=interaction_id,
command=SimpleNamespace(qualified_name="new"),
response=_FakeInteractionResponse(),
)
def _make_message(
*,
author_id: int = 123,
author_bot: bool = False,
channel_id: int = 456,
message_id: int = 789,
content: str = "hello",
guild_id: int | None = None,
mentions: list[object] | None = None,
attachments: list[object] | None = None,
reply_to: int | None = None,
):
# Factory for incoming Discord message objects with optional guild/reply/attachments.
guild = SimpleNamespace(id=guild_id) if guild_id is not None else None
reference = SimpleNamespace(message_id=reply_to) if reply_to is not None else None
return SimpleNamespace(
author=SimpleNamespace(id=author_id, bot=author_bot),
channel=_FakeChannel(channel_id),
content=content,
guild=guild,
mentions=mentions or [],
attachments=attachments or [],
reference=reference,
id=message_id,
)
@pytest.mark.asyncio
async def test_start_returns_when_token_missing() -> None:
# If no token is configured, startup should no-op and leave channel stopped.
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
await channel.start()
assert channel.is_running is False
assert channel._client is None
@pytest.mark.asyncio
async def test_start_returns_when_discord_dependency_missing(monkeypatch) -> None:
channel = DiscordChannel(
DiscordConfig(enabled=True, token="token", allow_from=["*"]),
MessageBus(),
)
monkeypatch.setattr("nanobot.channels.discord.DISCORD_AVAILABLE", False)
await channel.start()
assert channel.is_running is False
assert channel._client is None
@pytest.mark.asyncio
async def test_start_handles_client_construction_failure(monkeypatch) -> None:
# Construction errors from the Discord client should be swallowed and keep state clean.
channel = DiscordChannel(
DiscordConfig(enabled=True, token="token", allow_from=["*"]),
MessageBus(),
)
def _boom(owner, *, intents):
raise RuntimeError("bad client")
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _boom)
await channel.start()
assert channel.is_running is False
assert channel._client is None
@pytest.mark.asyncio
async def test_start_handles_client_start_failure(monkeypatch) -> None:
# If client.start fails, the partially created client should be closed and detached.
channel = DiscordChannel(
DiscordConfig(enabled=True, token="token", allow_from=["*"]),
MessageBus(),
)
_FakeDiscordClient.instances.clear()
_FakeDiscordClient.start_error = RuntimeError("connect failed")
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
await channel.start()
assert channel.is_running is False
assert channel._client is None
assert _FakeDiscordClient.instances[0].intents.value == channel.config.intents
assert _FakeDiscordClient.instances[0].closed is True
_FakeDiscordClient.start_error = None
@pytest.mark.asyncio
async def test_stop_is_safe_after_partial_start(monkeypatch) -> None:
# stop() should close/discard the client even when startup was only partially completed.
channel = DiscordChannel(
DiscordConfig(enabled=True, token="token", allow_from=["*"]),
MessageBus(),
)
client = _FakeDiscordClient(channel, intents=None)
channel._client = client
channel._running = True
await channel.stop()
assert channel.is_running is False
assert client.closed is True
assert channel._client is None
@pytest.mark.asyncio
async def test_on_message_ignores_bot_messages() -> None:
# Incoming bot-authored messages must be ignored to prevent feedback loops.
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
handled: list[dict] = []
channel._handle_message = lambda **kwargs: handled.append(kwargs) # type: ignore[method-assign]
await channel._on_message(_make_message(author_bot=True))
assert handled == []
# If inbound handling raises, typing should be stopped for that channel.
async def fail_handle(**kwargs) -> None:
raise RuntimeError("boom")
channel._handle_message = fail_handle # type: ignore[method-assign]
with pytest.raises(RuntimeError, match="boom"):
await channel._on_message(_make_message(author_id=123, channel_id=456))
assert channel._typing_tasks == {}
@pytest.mark.asyncio
async def test_on_message_accepts_allowlisted_dm() -> None:
# Allowed direct messages should be forwarded with normalized metadata.
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus())
handled: list[dict] = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle # type: ignore[method-assign]
await channel._on_message(_make_message(author_id=123, channel_id=456, message_id=789))
assert len(handled) == 1
assert handled[0]["chat_id"] == "456"
assert handled[0]["metadata"] == {"message_id": "789", "guild_id": None, "reply_to": None}
@pytest.mark.asyncio
async def test_on_message_ignores_unmentioned_guild_message() -> None:
# With mention-only group policy, guild messages without a bot mention are dropped.
channel = DiscordChannel(
DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention"),
MessageBus(),
)
channel._bot_user_id = "999"
handled: list[dict] = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle # type: ignore[method-assign]
await channel._on_message(_make_message(guild_id=1, content="hello everyone"))
assert handled == []
@pytest.mark.asyncio
async def test_on_message_accepts_mentioned_guild_message() -> None:
# Mentioned guild messages should be accepted and preserve reply threading metadata.
channel = DiscordChannel(
DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention"),
MessageBus(),
)
channel._bot_user_id = "999"
handled: list[dict] = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle # type: ignore[method-assign]
await channel._on_message(
_make_message(
guild_id=1,
content="<@999> hello",
mentions=[SimpleNamespace(id=999)],
reply_to=321,
)
)
assert len(handled) == 1
assert handled[0]["metadata"]["reply_to"] == "321"
@pytest.mark.asyncio
async def test_on_message_downloads_attachments(tmp_path, monkeypatch) -> None:
# Attachment downloads should be saved and referenced in forwarded content/media.
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
handled: list[dict] = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle # type: ignore[method-assign]
monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path)
await channel._on_message(
_make_message(
attachments=[_FakeAttachment(12, "photo.png")],
content="see file",
)
)
assert len(handled) == 1
assert handled[0]["media"] == [str(tmp_path / "12_photo.png")]
assert "[attachment:" in handled[0]["content"]
@pytest.mark.asyncio
async def test_on_message_marks_failed_attachment_download(tmp_path, monkeypatch) -> None:
# Failed attachment downloads should emit a readable placeholder and no media path.
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
handled: list[dict] = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle # type: ignore[method-assign]
monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path)
await channel._on_message(
_make_message(
attachments=[_FakeAttachment(12, "photo.png", fail=True)],
content="",
)
)
assert len(handled) == 1
assert handled[0]["media"] == []
assert handled[0]["content"] == "[attachment: photo.png - download failed]"
@pytest.mark.asyncio
async def test_send_warns_when_client_not_ready() -> None:
# Sending without a running/ready client should be a safe no-op.
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
assert channel._typing_tasks == {}
@pytest.mark.asyncio
async def test_send_skips_when_channel_not_cached() -> None:
# Outbound sends should be skipped when the destination channel is not resolvable.
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
client = DiscordBotClient(owner, intents=discord.Intents.none())
fetch_calls: list[int] = []
async def fetch_channel(channel_id: int):
fetch_calls.append(channel_id)
raise RuntimeError("not found")
client.fetch_channel = fetch_channel # type: ignore[method-assign]
await client.send_outbound(OutboundMessage(channel="discord", chat_id="123", content="hello"))
assert client.get_channel(123) is None
assert fetch_calls == [123]
@pytest.mark.asyncio
async def test_send_fetches_channel_when_not_cached() -> None:
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
client = DiscordBotClient(owner, intents=discord.Intents.none())
target = _FakeChannel(channel_id=123)
async def fetch_channel(channel_id: int):
return target if channel_id == 123 else None
client.fetch_channel = fetch_channel # type: ignore[method-assign]
await client.send_outbound(OutboundMessage(channel="discord", chat_id="123", content="hello"))
assert target.sent_payloads == [{"content": "hello"}]
@pytest.mark.asyncio
async def test_slash_new_forwards_when_user_is_allowlisted() -> None:
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus())
handled: list[dict] = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle # type: ignore[method-assign]
client = DiscordBotClient(channel, intents=discord.Intents.none())
interaction = _make_interaction(user_id=123, channel_id=456, interaction_id=321)
new_cmd = client.tree.get_command("new")
assert new_cmd is not None
await new_cmd.callback(interaction)
assert interaction.response.messages == [
{"content": "Processing /new...", "ephemeral": True}
]
assert len(handled) == 1
assert handled[0]["content"] == "/new"
assert handled[0]["sender_id"] == "123"
assert handled[0]["chat_id"] == "456"
assert handled[0]["metadata"]["interaction_id"] == "321"
assert handled[0]["metadata"]["is_slash_command"] is True
@pytest.mark.asyncio
async def test_slash_new_is_blocked_for_disallowed_user() -> None:
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["999"]), MessageBus())
handled: list[dict] = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle # type: ignore[method-assign]
client = DiscordBotClient(channel, intents=discord.Intents.none())
interaction = _make_interaction(user_id=123, channel_id=456)
new_cmd = client.tree.get_command("new")
assert new_cmd is not None
await new_cmd.callback(interaction)
assert interaction.response.messages == [
{"content": "You are not allowed to use this bot.", "ephemeral": True}
]
assert handled == []
@pytest.mark.parametrize("slash_name", ["stop", "restart", "status"])
@pytest.mark.asyncio
async def test_slash_commands_forward_via_handle_message(slash_name: str) -> None:
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
handled: list[dict] = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle # type: ignore[method-assign]
client = DiscordBotClient(channel, intents=discord.Intents.none())
interaction = _make_interaction()
interaction.command.qualified_name = slash_name
cmd = client.tree.get_command(slash_name)
assert cmd is not None
await cmd.callback(interaction)
assert interaction.response.messages == [
{"content": f"Processing /{slash_name}...", "ephemeral": True}
]
assert len(handled) == 1
assert handled[0]["content"] == f"/{slash_name}"
assert handled[0]["metadata"]["is_slash_command"] is True
@pytest.mark.asyncio
async def test_slash_help_returns_ephemeral_help_text() -> None:
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
handled: list[dict] = []
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle # type: ignore[method-assign]
client = DiscordBotClient(channel, intents=discord.Intents.none())
interaction = _make_interaction()
interaction.command.qualified_name = "help"
help_cmd = client.tree.get_command("help")
assert help_cmd is not None
await help_cmd.callback(interaction)
assert interaction.response.messages == [
{"content": build_help_text(), "ephemeral": True}
]
assert handled == []
@pytest.mark.asyncio
async def test_client_send_outbound_chunks_text_replies_and_uploads_files(tmp_path) -> None:
# Outbound payloads should upload files, attach reply references, and chunk long text.
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
client = DiscordBotClient(owner, intents=discord.Intents.none())
target = _FakeChannel(channel_id=123)
client.get_channel = lambda channel_id: target if channel_id == 123 else None # type: ignore[method-assign]
file_path = tmp_path / "demo.txt"
file_path.write_text("hi")
await client.send_outbound(
OutboundMessage(
channel="discord",
chat_id="123",
content="a" * 2100,
reply_to="55",
media=[str(file_path)],
)
)
assert len(target.sent_payloads) == 3
assert target.sent_payloads[0]["file_name"] == "demo.txt"
assert target.sent_payloads[0]["reference"].id == 55
assert target.sent_payloads[1]["content"] == "a" * 2000
assert target.sent_payloads[2]["content"] == "a" * 100
@pytest.mark.asyncio
async def test_client_send_outbound_reports_failed_attachments_when_no_text(tmp_path) -> None:
# If all attachment sends fail and no text exists, emit a failure placeholder message.
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
client = DiscordBotClient(owner, intents=discord.Intents.none())
target = _FakeChannel(channel_id=123)
client.get_channel = lambda channel_id: target if channel_id == 123 else None # type: ignore[method-assign]
missing_file = tmp_path / "missing.txt"
await client.send_outbound(
OutboundMessage(
channel="discord",
chat_id="123",
content="",
media=[str(missing_file)],
)
)
assert target.sent_payloads == [{"content": "[attachment: missing.txt - send failed]"}]
@pytest.mark.asyncio
async def test_send_stops_typing_after_send() -> None:
# Active typing indicators should be cancelled/cleared after a successful send.
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
client = _FakeDiscordClient(channel, intents=None)
channel._client = client
channel._running = True
start = asyncio.Event()
release = asyncio.Event()
async def slow_typing() -> None:
start.set()
await release.wait()
typing_channel = _FakeChannel(channel_id=123)
typing_channel.typing_enter_hook = slow_typing
await channel._start_typing(typing_channel)
await start.wait()
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
release.set()
await asyncio.sleep(0)
assert channel._typing_tasks == {}
# Progress messages should keep typing active until a final (non-progress) send.
start = asyncio.Event()
release = asyncio.Event()
async def slow_typing_progress() -> None:
start.set()
await release.wait()
typing_channel = _FakeChannel(channel_id=123)
typing_channel.typing_enter_hook = slow_typing_progress
await channel._start_typing(typing_channel)
await start.wait()
await channel.send(
OutboundMessage(
channel="discord",
chat_id="123",
content="progress",
metadata={"_progress": True},
)
)
assert "123" in channel._typing_tasks
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="final"))
release.set()
await asyncio.sleep(0)
assert channel._typing_tasks == {}
@pytest.mark.asyncio
async def test_start_typing_uses_typing_context_when_trigger_typing_missing() -> None:
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
channel._running = True
entered = asyncio.Event()
release = asyncio.Event()
class _TypingCtx:
async def __aenter__(self):
entered.set()
async def __aexit__(self, exc_type, exc, tb):
return False
class _NoTriggerChannel:
def __init__(self, channel_id: int = 123) -> None:
self.id = channel_id
def typing(self):
async def _waiter():
await release.wait()
# Hold the loop so task remains active until explicitly stopped.
class _Ctx(_TypingCtx):
async def __aenter__(self):
await super().__aenter__()
await _waiter()
return _Ctx()
typing_channel = _NoTriggerChannel(channel_id=123)
await channel._start_typing(typing_channel) # type: ignore[arg-type]
await entered.wait()
assert "123" in channel._typing_tasks
await channel._stop_typing("123")
release.set()
await asyncio.sleep(0)
assert channel._typing_tasks == {}

View File

@ -3,6 +3,9 @@ from pathlib import Path
from types import SimpleNamespace
import pytest
from nio import RoomSendResponse
from nanobot.channels.matrix import _build_matrix_text_content
# Check optional matrix dependencies before importing
try:
@ -65,6 +68,7 @@ class _FakeAsyncClient:
self.raise_on_send = False
self.raise_on_typing = False
self.raise_on_upload = False
self.room_send_response: RoomSendResponse | None = RoomSendResponse(event_id="", room_id="")
def add_event_callback(self, callback, event_type) -> None:
self.callbacks.append((callback, event_type))
@ -87,7 +91,7 @@ class _FakeAsyncClient:
message_type: str,
content: dict[str, object],
ignore_unverified_devices: object = _ROOM_SEND_UNSET,
) -> None:
) -> RoomSendResponse:
call: dict[str, object] = {
"room_id": room_id,
"message_type": message_type,
@ -98,6 +102,7 @@ class _FakeAsyncClient:
self.room_send_calls.append(call)
if self.raise_on_send:
raise RuntimeError("send failed")
return self.room_send_response
async def room_typing(
self,
@ -520,6 +525,7 @@ async def test_on_message_room_mention_requires_opt_in() -> None:
source={"content": {"m.mentions": {"room": True}}},
)
channel.config.allow_room_mentions = False
await channel._on_message(room, room_mention_event)
assert handled == []
assert client.typing_calls == []
@ -1322,3 +1328,302 @@ async def test_send_keeps_plaintext_only_for_plain_text() -> None:
"body": text,
"m.mentions": {},
}
def test_build_matrix_text_content_basic_text() -> None:
"""Test basic text content without HTML formatting."""
result = _build_matrix_text_content("Hello, World!")
expected = {
"msgtype": "m.text",
"body": "Hello, World!",
"m.mentions": {}
}
assert expected == result
def test_build_matrix_text_content_with_markdown() -> None:
"""Test text content with markdown that renders to HTML."""
text = "*Hello* **World**"
result = _build_matrix_text_content(text)
assert "msgtype" in result
assert "body" in result
assert result["body"] == text
assert "format" in result
assert result["format"] == "org.matrix.custom.html"
assert "formatted_body" in result
assert isinstance(result["formatted_body"], str)
assert len(result["formatted_body"]) > 0
def test_build_matrix_text_content_with_event_id() -> None:
"""Test text content with event_id for message replacement."""
event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
result = _build_matrix_text_content("Updated message", event_id)
assert "msgtype" in result
assert "body" in result
assert result["m.new_content"]
assert result["m.new_content"]["body"] == "Updated message"
assert result["m.relates_to"]["rel_type"] == "m.replace"
assert result["m.relates_to"]["event_id"] == event_id
def test_build_matrix_text_content_with_event_id_preserves_thread_relation() -> None:
"""Thread relations for edits should stay inside m.new_content."""
relates_to = {
"rel_type": "m.thread",
"event_id": "$root1",
"m.in_reply_to": {"event_id": "$reply1"},
"is_falling_back": True,
}
result = _build_matrix_text_content("Updated message", "event-1", relates_to)
assert result["m.relates_to"] == {
"rel_type": "m.replace",
"event_id": "event-1",
}
assert result["m.new_content"]["m.relates_to"] == relates_to
def test_build_matrix_text_content_no_event_id() -> None:
"""Test that when event_id is not provided, no extra properties are added."""
result = _build_matrix_text_content("Regular message")
# Basic required properties should be present
assert "msgtype" in result
assert "body" in result
assert result["body"] == "Regular message"
# Extra properties for replacement should NOT be present
assert "m.relates_to" not in result
assert "m.new_content" not in result
assert "format" not in result
assert "formatted_body" not in result
def test_build_matrix_text_content_plain_text_no_html() -> None:
"""Test plain text that should not include HTML formatting."""
result = _build_matrix_text_content("Simple plain text")
assert "msgtype" in result
assert "body" in result
assert "format" not in result
assert "formatted_body" not in result
@pytest.mark.asyncio
async def test_send_room_content_returns_room_send_response():
"""Test that _send_room_content returns the response from client.room_send."""
client = _FakeAsyncClient("", "", "", None)
channel = MatrixChannel(_make_config(), MessageBus())
channel.client = client
room_id = "!test_room:matrix.org"
content = {"msgtype": "m.text", "body": "Hello World"}
result = await channel._send_room_content(room_id, content)
assert result is client.room_send_response
@pytest.mark.asyncio
async def test_send_delta_creates_stream_buffer_and_sends_initial_message() -> None:
channel = MatrixChannel(_make_config(), MessageBus())
client = _FakeAsyncClient("", "", "", None)
channel.client = client
client.room_send_response.event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
await channel.send_delta("!room:matrix.org", "Hello")
assert "!room:matrix.org" in channel._stream_bufs
buf = channel._stream_bufs["!room:matrix.org"]
assert buf.text == "Hello"
assert buf.event_id == "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
assert len(client.room_send_calls) == 1
assert client.room_send_calls[0]["content"]["body"] == "Hello"
@pytest.mark.asyncio
async def test_send_delta_appends_without_sending_before_edit_interval(monkeypatch) -> None:
channel = MatrixChannel(_make_config(), MessageBus())
client = _FakeAsyncClient("", "", "", None)
channel.client = client
client.room_send_response.event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
now = 100.0
monkeypatch.setattr(channel, "monotonic_time", lambda: now)
await channel.send_delta("!room:matrix.org", "Hello")
assert len(client.room_send_calls) == 1
await channel.send_delta("!room:matrix.org", " world")
assert len(client.room_send_calls) == 1
buf = channel._stream_bufs["!room:matrix.org"]
assert buf.text == "Hello world"
assert buf.event_id == "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
@pytest.mark.asyncio
async def test_send_delta_edits_again_after_interval(monkeypatch) -> None:
channel = MatrixChannel(_make_config(), MessageBus())
client = _FakeAsyncClient("", "", "", None)
channel.client = client
client.room_send_response.event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
times = [100.0, 102.0, 104.0, 106.0, 108.0]
times.reverse()
monkeypatch.setattr(channel, "monotonic_time", lambda: times and times.pop())
await channel.send_delta("!room:matrix.org", "Hello")
await channel.send_delta("!room:matrix.org", " world")
assert len(client.room_send_calls) == 2
first_content = client.room_send_calls[0]["content"]
second_content = client.room_send_calls[1]["content"]
assert "body" in first_content
assert first_content["body"] == "Hello"
assert "m.relates_to" not in first_content
assert "body" in second_content
assert "m.relates_to" in second_content
assert second_content["body"] == "Hello world"
assert second_content["m.relates_to"] == {
"rel_type": "m.replace",
"event_id": "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo",
}
@pytest.mark.asyncio
async def test_send_delta_stream_end_replaces_existing_message() -> None:
channel = MatrixChannel(_make_config(), MessageBus())
client = _FakeAsyncClient("", "", "", None)
channel.client = client
channel._stream_bufs["!room:matrix.org"] = matrix_module._StreamBuf(
text="Final text",
event_id="event-1",
last_edit=100.0,
)
await channel.send_delta("!room:matrix.org", "", {"_stream_end": True})
assert "!room:matrix.org" not in channel._stream_bufs
assert client.typing_calls[-1] == ("!room:matrix.org", False, TYPING_NOTICE_TIMEOUT_MS)
assert len(client.room_send_calls) == 1
assert client.room_send_calls[0]["content"]["body"] == "Final text"
assert client.room_send_calls[0]["content"]["m.relates_to"] == {
"rel_type": "m.replace",
"event_id": "event-1",
}
@pytest.mark.asyncio
async def test_send_delta_starts_threaded_stream_inside_thread() -> None:
channel = MatrixChannel(_make_config(), MessageBus())
client = _FakeAsyncClient("", "", "", None)
channel.client = client
client.room_send_response.event_id = "event-1"
metadata = {
"thread_root_event_id": "$root1",
"thread_reply_to_event_id": "$reply1",
}
await channel.send_delta("!room:matrix.org", "Hello", metadata)
assert client.room_send_calls[0]["content"]["m.relates_to"] == {
"rel_type": "m.thread",
"event_id": "$root1",
"m.in_reply_to": {"event_id": "$reply1"},
"is_falling_back": True,
}
@pytest.mark.asyncio
async def test_send_delta_threaded_edit_keeps_replace_and_thread_relation(monkeypatch) -> None:
channel = MatrixChannel(_make_config(), MessageBus())
client = _FakeAsyncClient("", "", "", None)
channel.client = client
client.room_send_response.event_id = "event-1"
times = [100.0, 102.0, 104.0]
times.reverse()
monkeypatch.setattr(channel, "monotonic_time", lambda: times and times.pop())
metadata = {
"thread_root_event_id": "$root1",
"thread_reply_to_event_id": "$reply1",
}
await channel.send_delta("!room:matrix.org", "Hello", metadata)
await channel.send_delta("!room:matrix.org", " world", metadata)
await channel.send_delta("!room:matrix.org", "", {"_stream_end": True, **metadata})
edit_content = client.room_send_calls[1]["content"]
final_content = client.room_send_calls[2]["content"]
assert edit_content["m.relates_to"] == {
"rel_type": "m.replace",
"event_id": "event-1",
}
assert edit_content["m.new_content"]["m.relates_to"] == {
"rel_type": "m.thread",
"event_id": "$root1",
"m.in_reply_to": {"event_id": "$reply1"},
"is_falling_back": True,
}
assert final_content["m.relates_to"] == {
"rel_type": "m.replace",
"event_id": "event-1",
}
assert final_content["m.new_content"]["m.relates_to"] == {
"rel_type": "m.thread",
"event_id": "$root1",
"m.in_reply_to": {"event_id": "$reply1"},
"is_falling_back": True,
}
@pytest.mark.asyncio
async def test_send_delta_stream_end_noop_when_buffer_missing() -> None:
channel = MatrixChannel(_make_config(), MessageBus())
client = _FakeAsyncClient("", "", "", None)
channel.client = client
await channel.send_delta("!room:matrix.org", "", {"_stream_end": True})
assert client.room_send_calls == []
assert client.typing_calls == []
@pytest.mark.asyncio
async def test_send_delta_on_error_stops_typing(monkeypatch) -> None:
channel = MatrixChannel(_make_config(), MessageBus())
client = _FakeAsyncClient("", "", "", None)
client.raise_on_send = True
channel.client = client
now = 100.0
monkeypatch.setattr(channel, "monotonic_time", lambda: now)
await channel.send_delta("!room:matrix.org", "Hello", {"room_id": "!room:matrix.org"})
assert "!room:matrix.org" in channel._stream_bufs
assert channel._stream_bufs["!room:matrix.org"].text == "Hello"
assert len(client.room_send_calls) == 1
assert len(client.typing_calls) == 1
@pytest.mark.asyncio
async def test_send_delta_ignores_whitespace_only_delta(monkeypatch) -> None:
channel = MatrixChannel(_make_config(), MessageBus())
client = _FakeAsyncClient("", "", "", None)
channel.client = client
now = 100.0
monkeypatch.setattr(channel, "monotonic_time", lambda: now)
await channel.send_delta("!room:matrix.org", " ")
assert "!room:matrix.org" in channel._stream_bufs
assert channel._stream_bufs["!room:matrix.org"].text == " "
assert client.room_send_calls == []

View File

@ -642,27 +642,105 @@ def test_heartbeat_retains_recent_messages_by_default():
assert config.gateway.heartbeat.keep_recent_messages == 8
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
def _write_instance_config(tmp_path: Path) -> Path:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
return config_file
config = Config()
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
seen: dict[str, Path] = {}
def _stop_gateway_provider(_config) -> object:
raise _StopGatewayError("stop")
def _patch_cli_command_runtime(
monkeypatch,
config: Config,
*,
set_config_path=None,
sync_templates=None,
make_provider=None,
message_bus=None,
session_manager=None,
cron_service=None,
get_cron_dir=None,
) -> None:
monkeypatch.setattr(
"nanobot.config.loader.set_config_path",
lambda path: seen.__setitem__("config_path", path),
set_config_path or (lambda _path: None),
)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr(
"nanobot.cli.commands.sync_workspace_templates",
lambda path: seen.__setitem__("workspace", path),
sync_templates or (lambda _path: None),
)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
make_provider or (lambda _config: object()),
)
if message_bus is not None:
monkeypatch.setattr("nanobot.bus.queue.MessageBus", message_bus)
if session_manager is not None:
monkeypatch.setattr("nanobot.session.manager.SessionManager", session_manager)
if cron_service is not None:
monkeypatch.setattr("nanobot.cron.service.CronService", cron_service)
if get_cron_dir is not None:
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", get_cron_dir)
def _patch_serve_runtime(monkeypatch, config: Config, seen: dict[str, object]) -> None:
pytest.importorskip("aiohttp")
class _FakeApiApp:
def __init__(self) -> None:
self.on_startup: list[object] = []
self.on_cleanup: list[object] = []
class _FakeAgentLoop:
def __init__(self, **kwargs) -> None:
seen["workspace"] = kwargs["workspace"]
async def _connect_mcp(self) -> None:
return None
async def close_mcp(self) -> None:
return None
def _fake_create_app(agent_loop, model_name: str, request_timeout: float):
seen["agent_loop"] = agent_loop
seen["model_name"] = model_name
seen["request_timeout"] = request_timeout
return _FakeApiApp()
def _fake_run_app(api_app, host: str, port: int, print):
seen["api_app"] = api_app
seen["host"] = host
seen["port"] = port
_patch_cli_command_runtime(
monkeypatch,
config,
message_bus=lambda: object(),
session_manager=lambda _workspace: object(),
)
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
monkeypatch.setattr("nanobot.api.server.create_app", _fake_create_app)
monkeypatch.setattr("aiohttp.web.run_app", _fake_run_app)
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
config_file = _write_instance_config(tmp_path)
config = Config()
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
seen: dict[str, Path] = {}
_patch_cli_command_runtime(
monkeypatch,
config,
set_config_path=lambda path: seen.__setitem__("config_path", path),
sync_templates=lambda path: seen.__setitem__("workspace", path),
make_provider=_stop_gateway_provider,
)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
@ -673,24 +751,17 @@ def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Pa
def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config_file = _write_instance_config(tmp_path)
config = Config()
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
override = tmp_path / "override-workspace"
seen: dict[str, Path] = {}
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr(
"nanobot.cli.commands.sync_workspace_templates",
lambda path: seen.__setitem__("workspace", path),
)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
_patch_cli_command_runtime(
monkeypatch,
config,
sync_templates=lambda path: seen.__setitem__("workspace", path),
make_provider=_stop_gateway_provider,
)
result = runner.invoke(
@ -704,27 +775,23 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path)
def test_gateway_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config_file = _write_instance_config(tmp_path)
config = Config()
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
seen: dict[str, Path] = {}
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
class _StopCron:
def __init__(self, store_path: Path) -> None:
seen["cron_store"] = store_path
raise _StopGatewayError("stop")
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
_patch_cli_command_runtime(
monkeypatch,
config,
message_bus=lambda: object(),
session_manager=lambda _workspace: object(),
cron_service=_StopCron,
)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
@ -735,10 +802,7 @@ def test_gateway_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path:
def test_gateway_workspace_override_does_not_migrate_legacy_cron(
monkeypatch, tmp_path: Path
) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config_file = _write_instance_config(tmp_path)
legacy_dir = tmp_path / "global" / "cron"
legacy_dir.mkdir(parents=True)
legacy_file = legacy_dir / "jobs.json"
@ -748,20 +812,19 @@ def test_gateway_workspace_override_does_not_migrate_legacy_cron(
config = Config()
seen: dict[str, Path] = {}
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
class _StopCron:
def __init__(self, store_path: Path) -> None:
seen["cron_store"] = store_path
raise _StopGatewayError("stop")
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
_patch_cli_command_runtime(
monkeypatch,
config,
message_bus=lambda: object(),
session_manager=lambda _workspace: object(),
cron_service=_StopCron,
get_cron_dir=lambda: legacy_dir,
)
result = runner.invoke(
app,
@ -777,10 +840,7 @@ def test_gateway_workspace_override_does_not_migrate_legacy_cron(
def test_gateway_custom_config_workspace_does_not_migrate_legacy_cron(
monkeypatch, tmp_path: Path
) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config_file = _write_instance_config(tmp_path)
legacy_dir = tmp_path / "global" / "cron"
legacy_dir.mkdir(parents=True)
legacy_file = legacy_dir / "jobs.json"
@ -791,20 +851,19 @@ def test_gateway_custom_config_workspace_does_not_migrate_legacy_cron(
config.agents.defaults.workspace = str(custom_workspace)
seen: dict[str, Path] = {}
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
class _StopCron:
def __init__(self, store_path: Path) -> None:
seen["cron_store"] = store_path
raise _StopGatewayError("stop")
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
_patch_cli_command_runtime(
monkeypatch,
config,
message_bus=lambda: object(),
session_manager=lambda _workspace: object(),
cron_service=_StopCron,
get_cron_dir=lambda: legacy_dir,
)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
@ -856,19 +915,14 @@ def test_migrate_cron_store_skips_when_workspace_file_exists(tmp_path: Path) ->
def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config_file = _write_instance_config(tmp_path)
config = Config()
config.gateway.port = 18791
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
_patch_cli_command_runtime(
monkeypatch,
config,
make_provider=_stop_gateway_provider,
)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
@ -878,19 +932,14 @@ def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_
def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config_file = _write_instance_config(tmp_path)
config = Config()
config.gateway.port = 18791
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
_patch_cli_command_runtime(
monkeypatch,
config,
make_provider=_stop_gateway_provider,
)
result = runner.invoke(app, ["gateway", "--config", str(config_file), "--port", "18792"])
@ -899,6 +948,63 @@ def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path)
assert "port 18792" in result.stdout
def test_serve_uses_api_config_defaults_and_workspace_override(
monkeypatch, tmp_path: Path
) -> None:
config_file = _write_instance_config(tmp_path)
config = Config()
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
config.api.host = "127.0.0.2"
config.api.port = 18900
config.api.timeout = 45.0
override_workspace = tmp_path / "override-workspace"
seen: dict[str, object] = {}
_patch_serve_runtime(monkeypatch, config, seen)
result = runner.invoke(
app,
["serve", "--config", str(config_file), "--workspace", str(override_workspace)],
)
assert result.exit_code == 0
assert seen["workspace"] == override_workspace
assert seen["host"] == "127.0.0.2"
assert seen["port"] == 18900
assert seen["request_timeout"] == 45.0
def test_serve_cli_options_override_api_config(monkeypatch, tmp_path: Path) -> None:
config_file = _write_instance_config(tmp_path)
config = Config()
config.api.host = "127.0.0.2"
config.api.port = 18900
config.api.timeout = 45.0
seen: dict[str, object] = {}
_patch_serve_runtime(monkeypatch, config, seen)
result = runner.invoke(
app,
[
"serve",
"--config",
str(config_file),
"--host",
"127.0.0.1",
"--port",
"18901",
"--timeout",
"46",
],
)
assert result.exit_code == 0
assert seen["host"] == "127.0.0.1"
assert seen["port"] == 18901
assert seen["request_timeout"] == 46.0
def test_channels_login_requires_channel_name() -> None:
result = runner.invoke(app, ["channels", "login"])

View File

@ -0,0 +1,147 @@
"""Tests for the Nanobot programmatic facade."""
from __future__ import annotations
import json
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from nanobot.nanobot import Nanobot, RunResult
def _write_config(tmp_path: Path, overrides: dict | None = None) -> Path:
data = {
"providers": {"openrouter": {"apiKey": "sk-test-key"}},
"agents": {"defaults": {"model": "openai/gpt-4.1"}},
}
if overrides:
data.update(overrides)
config_path = tmp_path / "config.json"
config_path.write_text(json.dumps(data))
return config_path
def test_from_config_missing_file():
with pytest.raises(FileNotFoundError):
Nanobot.from_config("/nonexistent/config.json")
def test_from_config_creates_instance(tmp_path):
config_path = _write_config(tmp_path)
bot = Nanobot.from_config(config_path, workspace=tmp_path)
assert bot._loop is not None
assert bot._loop.workspace == tmp_path
def test_from_config_default_path():
from nanobot.config.schema import Config
with patch("nanobot.config.loader.load_config") as mock_load, \
patch("nanobot.nanobot._make_provider") as mock_prov:
mock_load.return_value = Config()
mock_prov.return_value = MagicMock()
mock_prov.return_value.get_default_model.return_value = "test"
mock_prov.return_value.generation.max_tokens = 4096
Nanobot.from_config()
mock_load.assert_called_once_with(None)
@pytest.mark.asyncio
async def test_run_returns_result(tmp_path):
config_path = _write_config(tmp_path)
bot = Nanobot.from_config(config_path, workspace=tmp_path)
from nanobot.bus.events import OutboundMessage
mock_response = OutboundMessage(
channel="cli", chat_id="direct", content="Hello back!"
)
bot._loop.process_direct = AsyncMock(return_value=mock_response)
result = await bot.run("hi")
assert isinstance(result, RunResult)
assert result.content == "Hello back!"
bot._loop.process_direct.assert_awaited_once_with("hi", session_key="sdk:default")
@pytest.mark.asyncio
async def test_run_with_hooks(tmp_path):
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.bus.events import OutboundMessage
config_path = _write_config(tmp_path)
bot = Nanobot.from_config(config_path, workspace=tmp_path)
class TestHook(AgentHook):
async def before_iteration(self, context: AgentHookContext) -> None:
pass
mock_response = OutboundMessage(
channel="cli", chat_id="direct", content="done"
)
bot._loop.process_direct = AsyncMock(return_value=mock_response)
result = await bot.run("hi", hooks=[TestHook()])
assert result.content == "done"
assert bot._loop._extra_hooks == []
@pytest.mark.asyncio
async def test_run_hooks_restored_on_error(tmp_path):
config_path = _write_config(tmp_path)
bot = Nanobot.from_config(config_path, workspace=tmp_path)
from nanobot.agent.hook import AgentHook
bot._loop.process_direct = AsyncMock(side_effect=RuntimeError("boom"))
original_hooks = bot._loop._extra_hooks
with pytest.raises(RuntimeError):
await bot.run("hi", hooks=[AgentHook()])
assert bot._loop._extra_hooks is original_hooks
@pytest.mark.asyncio
async def test_run_none_response(tmp_path):
config_path = _write_config(tmp_path)
bot = Nanobot.from_config(config_path, workspace=tmp_path)
bot._loop.process_direct = AsyncMock(return_value=None)
result = await bot.run("hi")
assert result.content == ""
def test_workspace_override(tmp_path):
config_path = _write_config(tmp_path)
custom_ws = tmp_path / "custom_workspace"
custom_ws.mkdir()
bot = Nanobot.from_config(config_path, workspace=custom_ws)
assert bot._loop.workspace == custom_ws
@pytest.mark.asyncio
async def test_run_custom_session_key(tmp_path):
from nanobot.bus.events import OutboundMessage
config_path = _write_config(tmp_path)
bot = Nanobot.from_config(config_path, workspace=tmp_path)
mock_response = OutboundMessage(
channel="cli", chat_id="direct", content="ok"
)
bot._loop.process_direct = AsyncMock(return_value=mock_response)
await bot.run("hi", session_key="user-alice")
bot._loop.process_direct.assert_awaited_once_with("hi", session_key="user-alice")
def test_import_from_top_level():
from nanobot import Nanobot as N, RunResult as R
assert N is Nanobot
assert R is RunResult

372
tests/test_openai_api.py Normal file
View File

@ -0,0 +1,372 @@
"""Focused tests for the fixed-session OpenAI-compatible API."""
from __future__ import annotations
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
import pytest_asyncio
from nanobot.api.server import (
API_CHAT_ID,
API_SESSION_KEY,
_chat_completion_response,
_error_json,
create_app,
handle_chat_completions,
)
try:
from aiohttp.test_utils import TestClient, TestServer
HAS_AIOHTTP = True
except ImportError:
HAS_AIOHTTP = False
pytest_plugins = ("pytest_asyncio",)
def _make_mock_agent(response_text: str = "mock response") -> MagicMock:
agent = MagicMock()
agent.process_direct = AsyncMock(return_value=response_text)
agent._connect_mcp = AsyncMock()
agent.close_mcp = AsyncMock()
return agent
@pytest.fixture
def mock_agent():
return _make_mock_agent()
@pytest.fixture
def app(mock_agent):
return create_app(mock_agent, model_name="test-model", request_timeout=10.0)
@pytest_asyncio.fixture
async def aiohttp_client():
clients: list[TestClient] = []
async def _make_client(app):
client = TestClient(TestServer(app))
await client.start_server()
clients.append(client)
return client
try:
yield _make_client
finally:
for client in clients:
await client.close()
def test_error_json() -> None:
resp = _error_json(400, "bad request")
assert resp.status == 400
body = json.loads(resp.body)
assert body["error"]["message"] == "bad request"
assert body["error"]["code"] == 400
def test_chat_completion_response() -> None:
result = _chat_completion_response("hello world", "test-model")
assert result["object"] == "chat.completion"
assert result["model"] == "test-model"
assert result["choices"][0]["message"]["content"] == "hello world"
assert result["choices"][0]["finish_reason"] == "stop"
assert result["id"].startswith("chatcmpl-")
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_missing_messages_returns_400(aiohttp_client, app) -> None:
client = await aiohttp_client(app)
resp = await client.post("/v1/chat/completions", json={"model": "test"})
assert resp.status == 400
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_no_user_message_returns_400(aiohttp_client, app) -> None:
client = await aiohttp_client(app)
resp = await client.post(
"/v1/chat/completions",
json={"messages": [{"role": "system", "content": "you are a bot"}]},
)
assert resp.status == 400
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_stream_true_returns_400(aiohttp_client, app) -> None:
client = await aiohttp_client(app)
resp = await client.post(
"/v1/chat/completions",
json={"messages": [{"role": "user", "content": "hello"}], "stream": True},
)
assert resp.status == 400
body = await resp.json()
assert "stream" in body["error"]["message"].lower()
@pytest.mark.asyncio
async def test_model_mismatch_returns_400() -> None:
request = MagicMock()
request.json = AsyncMock(
return_value={
"model": "other-model",
"messages": [{"role": "user", "content": "hello"}],
}
)
request.app = {
"agent_loop": _make_mock_agent(),
"model_name": "test-model",
"request_timeout": 10.0,
"session_lock": asyncio.Lock(),
}
resp = await handle_chat_completions(request)
assert resp.status == 400
body = json.loads(resp.body)
assert "test-model" in body["error"]["message"]
@pytest.mark.asyncio
async def test_single_user_message_required() -> None:
request = MagicMock()
request.json = AsyncMock(
return_value={
"messages": [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "previous reply"},
],
}
)
request.app = {
"agent_loop": _make_mock_agent(),
"model_name": "test-model",
"request_timeout": 10.0,
"session_lock": asyncio.Lock(),
}
resp = await handle_chat_completions(request)
assert resp.status == 400
body = json.loads(resp.body)
assert "single user message" in body["error"]["message"].lower()
@pytest.mark.asyncio
async def test_single_user_message_must_have_user_role() -> None:
request = MagicMock()
request.json = AsyncMock(
return_value={
"messages": [{"role": "system", "content": "you are a bot"}],
}
)
request.app = {
"agent_loop": _make_mock_agent(),
"model_name": "test-model",
"request_timeout": 10.0,
"session_lock": asyncio.Lock(),
}
resp = await handle_chat_completions(request)
assert resp.status == 400
body = json.loads(resp.body)
assert "single user message" in body["error"]["message"].lower()
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_successful_request_uses_fixed_api_session(aiohttp_client, mock_agent) -> None:
app = create_app(mock_agent, model_name="test-model")
client = await aiohttp_client(app)
resp = await client.post(
"/v1/chat/completions",
json={"messages": [{"role": "user", "content": "hello"}]},
)
assert resp.status == 200
body = await resp.json()
assert body["choices"][0]["message"]["content"] == "mock response"
assert body["model"] == "test-model"
mock_agent.process_direct.assert_called_once_with(
content="hello",
session_key=API_SESSION_KEY,
channel="api",
chat_id=API_CHAT_ID,
)
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_followup_requests_share_same_session_key(aiohttp_client) -> None:
call_log: list[str] = []
async def fake_process(content, session_key="", channel="", chat_id=""):
call_log.append(session_key)
return f"reply to {content}"
agent = MagicMock()
agent.process_direct = fake_process
agent._connect_mcp = AsyncMock()
agent.close_mcp = AsyncMock()
app = create_app(agent, model_name="m")
client = await aiohttp_client(app)
r1 = await client.post(
"/v1/chat/completions",
json={"messages": [{"role": "user", "content": "first"}]},
)
r2 = await client.post(
"/v1/chat/completions",
json={"messages": [{"role": "user", "content": "second"}]},
)
assert r1.status == 200
assert r2.status == 200
assert call_log == [API_SESSION_KEY, API_SESSION_KEY]
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_fixed_session_requests_are_serialized(aiohttp_client) -> None:
order: list[str] = []
barrier = asyncio.Event()
async def slow_process(content, session_key="", channel="", chat_id=""):
order.append(f"start:{content}")
if content == "first":
barrier.set()
await asyncio.sleep(0.1)
else:
await barrier.wait()
order.append(f"end:{content}")
return content
agent = MagicMock()
agent.process_direct = slow_process
agent._connect_mcp = AsyncMock()
agent.close_mcp = AsyncMock()
app = create_app(agent, model_name="m")
client = await aiohttp_client(app)
async def send(msg: str):
return await client.post(
"/v1/chat/completions",
json={"messages": [{"role": "user", "content": msg}]},
)
r1, r2 = await asyncio.gather(send("first"), send("second"))
assert r1.status == 200
assert r2.status == 200
assert order.index("end:first") < order.index("start:second")
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_models_endpoint(aiohttp_client, app) -> None:
client = await aiohttp_client(app)
resp = await client.get("/v1/models")
assert resp.status == 200
body = await resp.json()
assert body["object"] == "list"
assert body["data"][0]["id"] == "test-model"
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_health_endpoint(aiohttp_client, app) -> None:
client = await aiohttp_client(app)
resp = await client.get("/health")
assert resp.status == 200
body = await resp.json()
assert body["status"] == "ok"
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_multimodal_content_extracts_text(aiohttp_client, mock_agent) -> None:
app = create_app(mock_agent, model_name="m")
client = await aiohttp_client(app)
resp = await client.post(
"/v1/chat/completions",
json={
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "describe this"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
],
}
]
},
)
assert resp.status == 200
mock_agent.process_direct.assert_called_once_with(
content="describe this",
session_key=API_SESSION_KEY,
channel="api",
chat_id=API_CHAT_ID,
)
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_empty_response_retry_then_success(aiohttp_client) -> None:
call_count = 0
async def sometimes_empty(content, session_key="", channel="", chat_id=""):
nonlocal call_count
call_count += 1
if call_count == 1:
return ""
return "recovered response"
agent = MagicMock()
agent.process_direct = sometimes_empty
agent._connect_mcp = AsyncMock()
agent.close_mcp = AsyncMock()
app = create_app(agent, model_name="m")
client = await aiohttp_client(app)
resp = await client.post(
"/v1/chat/completions",
json={"messages": [{"role": "user", "content": "hello"}]},
)
assert resp.status == 200
body = await resp.json()
assert body["choices"][0]["message"]["content"] == "recovered response"
assert call_count == 2
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_empty_response_falls_back(aiohttp_client) -> None:
call_count = 0
async def always_empty(content, session_key="", channel="", chat_id=""):
nonlocal call_count
call_count += 1
return ""
agent = MagicMock()
agent.process_direct = always_empty
agent._connect_mcp = AsyncMock()
agent.close_mcp = AsyncMock()
app = create_app(agent, model_name="m")
client = await aiohttp_client(app)
resp = await client.post(
"/v1/chat/completions",
json={"messages": [{"role": "user", "content": "hello"}]},
)
assert resp.status == 200
body = await resp.json()
assert body["choices"][0]["message"]["content"] == "I've completed processing but have no response to give."
assert call_count == 2