mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-04 18:32:44 +00:00
Merge remote-tracking branch 'origin/main' into pr-2614
This commit is contained in:
commit
6aad945719
108
README.md
108
README.md
@ -115,6 +115,8 @@
|
||||
- [Configuration](#️-configuration)
|
||||
- [Multiple Instances](#-multiple-instances)
|
||||
- [CLI Reference](#-cli-reference)
|
||||
- [Python SDK](#-python-sdk)
|
||||
- [OpenAI-Compatible API](#-openai-compatible-api)
|
||||
- [Docker](#-docker)
|
||||
- [Linux Service](#-linux-service)
|
||||
- [Project Structure](#-project-structure)
|
||||
@ -854,7 +856,6 @@ Config file: `~/.nanobot/config.json`
|
||||
> - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config.
|
||||
> - **Alibaba Cloud BaiLian**: If you're using Alibaba Cloud BaiLian's OpenAI-compatible endpoint, set `"apiBase": "https://dashscope.aliyuncs.com/compatible-mode/v1"` in your dashscope provider config.
|
||||
> - **Step Fun (Mainland China)**: If your API key is from Step Fun's mainland China platform (stepfun.com), set `"apiBase": "https://api.stepfun.com/v1"` in your stepfun provider config.
|
||||
> - **Step Fun Step Plan**: Exclusive discount links for the nanobot community: [Overseas](https://platform.stepfun.ai/step-plan) · [Mainland China](https://platform.stepfun.com/step-plan)
|
||||
|
||||
| Provider | Purpose | Get API Key |
|
||||
|----------|---------|-------------|
|
||||
@ -1542,6 +1543,7 @@ nanobot gateway --config ~/.nanobot-telegram/config.json --workspace /tmp/nanobo
|
||||
| `nanobot agent` | Interactive chat mode |
|
||||
| `nanobot agent --no-markdown` | Show plain-text replies |
|
||||
| `nanobot agent --logs` | Show runtime logs during chat |
|
||||
| `nanobot serve` | Start the OpenAI-compatible API |
|
||||
| `nanobot gateway` | Start the gateway |
|
||||
| `nanobot status` | Show status |
|
||||
| `nanobot provider login openai-codex` | OAuth login for providers |
|
||||
@ -1570,6 +1572,110 @@ The agent can also manage this file itself — ask it to "add a periodic task" a
|
||||
|
||||
</details>
|
||||
|
||||
## 🐍 Python SDK
|
||||
|
||||
Use nanobot as a library — no CLI, no gateway, just Python:
|
||||
|
||||
```python
|
||||
from nanobot import Nanobot
|
||||
|
||||
bot = Nanobot.from_config()
|
||||
result = await bot.run("Summarize the README")
|
||||
print(result.content)
|
||||
```
|
||||
|
||||
Each call carries a `session_key` for conversation isolation — different keys get independent history:
|
||||
|
||||
```python
|
||||
await bot.run("hi", session_key="user-alice")
|
||||
await bot.run("hi", session_key="task-42")
|
||||
```
|
||||
|
||||
Add lifecycle hooks to observe or customize the agent:
|
||||
|
||||
```python
|
||||
from nanobot.agent import AgentHook, AgentHookContext
|
||||
|
||||
class AuditHook(AgentHook):
|
||||
async def before_execute_tools(self, ctx: AgentHookContext) -> None:
|
||||
for tc in ctx.tool_calls:
|
||||
print(f"[tool] {tc.name}")
|
||||
|
||||
result = await bot.run("Hello", hooks=[AuditHook()])
|
||||
```
|
||||
|
||||
See [docs/PYTHON_SDK.md](docs/PYTHON_SDK.md) for the full SDK reference.
|
||||
|
||||
## 🔌 OpenAI-Compatible API
|
||||
|
||||
nanobot can expose a minimal OpenAI-compatible endpoint for local integrations:
|
||||
|
||||
```bash
|
||||
pip install "nanobot-ai[api]"
|
||||
nanobot serve
|
||||
```
|
||||
|
||||
By default, the API binds to `127.0.0.1:8900`. You can change this in `config.json`.
|
||||
|
||||
### Behavior
|
||||
|
||||
- Session isolation: pass `"session_id"` in the request body to isolate conversations; omit for a shared default session (`api:default`)
|
||||
- Single-message input: each request must contain exactly one `user` message
|
||||
- Fixed model: omit `model`, or pass the same model shown by `/v1/models`
|
||||
- No streaming: `stream=true` is not supported
|
||||
|
||||
### Endpoints
|
||||
|
||||
- `GET /health`
|
||||
- `GET /v1/models`
|
||||
- `POST /v1/chat/completions`
|
||||
|
||||
### curl
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:8900/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"session_id": "my-session"
|
||||
}'
|
||||
```
|
||||
|
||||
### Python (`requests`)
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
resp = requests.post(
|
||||
"http://127.0.0.1:8900/v1/chat/completions",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"session_id": "my-session", # optional: isolate conversation
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
print(resp.json()["choices"][0]["message"]["content"])
|
||||
```
|
||||
|
||||
### Python (`openai`)
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://127.0.0.1:8900/v1",
|
||||
api_key="dummy",
|
||||
)
|
||||
|
||||
resp = client.chat.completions.create(
|
||||
model="MiniMax-M2.7",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
extra_body={"session_id": "my-session"}, # optional: isolate conversation
|
||||
)
|
||||
print(resp.choices[0].message.content)
|
||||
```
|
||||
|
||||
## 🐳 Docker
|
||||
|
||||
> [!TIP]
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#!/bin/bash
|
||||
# Count core agent lines (excluding channels/, cli/, providers/ adapters)
|
||||
# Count core agent lines (excluding channels/, cli/, api/, providers/ adapters,
|
||||
# and the high-level Python SDK facade)
|
||||
cd "$(dirname "$0")" || exit 1
|
||||
|
||||
echo "nanobot core agent line count"
|
||||
@ -15,7 +16,7 @@ root=$(cat nanobot/__init__.py nanobot/__main__.py | wc -l)
|
||||
printf " %-16s %5s lines\n" "(root)" "$root"
|
||||
|
||||
echo ""
|
||||
total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/command/*" ! -path "*/providers/*" ! -path "*/skills/*" | xargs cat | wc -l)
|
||||
total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/api/*" ! -path "*/command/*" ! -path "*/providers/*" ! -path "*/skills/*" ! -path "nanobot/nanobot.py" | xargs cat | wc -l)
|
||||
echo " Core total: $total lines"
|
||||
echo ""
|
||||
echo " (excludes: channels/, cli/, command/, providers/, skills/)"
|
||||
echo " (excludes: channels/, cli/, api/, command/, providers/, skills/, nanobot.py)"
|
||||
|
||||
136
docs/PYTHON_SDK.md
Normal file
136
docs/PYTHON_SDK.md
Normal file
@ -0,0 +1,136 @@
|
||||
# Python SDK
|
||||
|
||||
Use nanobot programmatically — load config, run the agent, get results.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from nanobot import Nanobot
|
||||
|
||||
async def main():
|
||||
bot = Nanobot.from_config()
|
||||
result = await bot.run("What time is it in Tokyo?")
|
||||
print(result.content)
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
## API
|
||||
|
||||
### `Nanobot.from_config(config_path?, *, workspace?)`
|
||||
|
||||
Create a `Nanobot` from a config file.
|
||||
|
||||
| Param | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `config_path` | `str \| Path \| None` | `None` | Path to `config.json`. Defaults to `~/.nanobot/config.json`. |
|
||||
| `workspace` | `str \| Path \| None` | `None` | Override workspace directory from config. |
|
||||
|
||||
Raises `FileNotFoundError` if an explicit path doesn't exist.
|
||||
|
||||
### `await bot.run(message, *, session_key?, hooks?)`
|
||||
|
||||
Run the agent once. Returns a `RunResult`.
|
||||
|
||||
| Param | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `message` | `str` | *(required)* | The user message to process. |
|
||||
| `session_key` | `str` | `"sdk:default"` | Session identifier for conversation isolation. Different keys get independent history. |
|
||||
| `hooks` | `list[AgentHook] \| None` | `None` | Lifecycle hooks for this run only. |
|
||||
|
||||
```python
|
||||
# Isolated sessions — each user gets independent conversation history
|
||||
await bot.run("hi", session_key="user-alice")
|
||||
await bot.run("hi", session_key="user-bob")
|
||||
```
|
||||
|
||||
### `RunResult`
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `content` | `str` | The agent's final text response. |
|
||||
| `tools_used` | `list[str]` | Tool names invoked during the run. |
|
||||
| `messages` | `list[dict]` | Raw message history (for debugging). |
|
||||
|
||||
## Hooks
|
||||
|
||||
Hooks let you observe or modify the agent loop without touching internals.
|
||||
|
||||
Subclass `AgentHook` and override any method:
|
||||
|
||||
| Method | When |
|
||||
|--------|------|
|
||||
| `before_iteration(ctx)` | Before each LLM call |
|
||||
| `on_stream(ctx, delta)` | On each streamed token |
|
||||
| `on_stream_end(ctx)` | When streaming finishes |
|
||||
| `before_execute_tools(ctx)` | Before tool execution (inspect `ctx.tool_calls`) |
|
||||
| `after_iteration(ctx, response)` | After each LLM response |
|
||||
| `finalize_content(ctx, content)` | Transform final output text |
|
||||
|
||||
### Example: Audit Hook
|
||||
|
||||
```python
|
||||
from nanobot.agent import AgentHook, AgentHookContext
|
||||
|
||||
class AuditHook(AgentHook):
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
async def before_execute_tools(self, ctx: AgentHookContext) -> None:
|
||||
for tc in ctx.tool_calls:
|
||||
self.calls.append(tc.name)
|
||||
print(f"[audit] {tc.name}({tc.arguments})")
|
||||
|
||||
hook = AuditHook()
|
||||
result = await bot.run("List files in /tmp", hooks=[hook])
|
||||
print(f"Tools used: {hook.calls}")
|
||||
```
|
||||
|
||||
### Composing Hooks
|
||||
|
||||
Pass multiple hooks — they run in order, errors in one don't block others:
|
||||
|
||||
```python
|
||||
result = await bot.run("hi", hooks=[AuditHook(), MetricsHook()])
|
||||
```
|
||||
|
||||
Under the hood this uses `CompositeHook` for fan-out with error isolation.
|
||||
|
||||
### `finalize_content` Pipeline
|
||||
|
||||
Unlike the async methods (fan-out), `finalize_content` is a pipeline — each hook's output feeds the next:
|
||||
|
||||
```python
|
||||
class Censor(AgentHook):
|
||||
def finalize_content(self, ctx, content):
|
||||
return content.replace("secret", "***") if content else content
|
||||
```
|
||||
|
||||
## Full Example
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from nanobot import Nanobot
|
||||
from nanobot.agent import AgentHook, AgentHookContext
|
||||
|
||||
class TimingHook(AgentHook):
|
||||
async def before_iteration(self, ctx: AgentHookContext) -> None:
|
||||
import time
|
||||
ctx.metadata["_t0"] = time.time()
|
||||
|
||||
async def after_iteration(self, ctx, response) -> None:
|
||||
import time
|
||||
elapsed = time.time() - ctx.metadata.get("_t0", 0)
|
||||
print(f"[timing] iteration took {elapsed:.2f}s")
|
||||
|
||||
async def main():
|
||||
bot = Nanobot.from_config(workspace="/my/project")
|
||||
result = await bot.run(
|
||||
"Explain the main function",
|
||||
hooks=[TimingHook()],
|
||||
)
|
||||
print(result.content)
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
@ -4,3 +4,7 @@ nanobot - A lightweight AI agent framework
|
||||
|
||||
__version__ = "0.1.4.post6"
|
||||
__logo__ = "🐈"
|
||||
|
||||
from nanobot.nanobot import Nanobot, RunResult
|
||||
|
||||
__all__ = ["Nanobot", "RunResult"]
|
||||
|
||||
@ -1,8 +1,19 @@
|
||||
"""Agent core module."""
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.agent.skills import SkillsLoader
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
|
||||
__all__ = ["AgentLoop", "ContextBuilder", "MemoryStore", "SkillsLoader"]
|
||||
__all__ = [
|
||||
"AgentHook",
|
||||
"AgentHookContext",
|
||||
"AgentLoop",
|
||||
"CompositeHook",
|
||||
"ContextBuilder",
|
||||
"MemoryStore",
|
||||
"SkillsLoader",
|
||||
"SubagentManager",
|
||||
]
|
||||
|
||||
@ -5,6 +5,8 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
@ -47,3 +49,60 @@ class AgentHook:
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
return content
|
||||
|
||||
|
||||
class CompositeHook(AgentHook):
|
||||
"""Fan-out hook that delegates to an ordered list of hooks.
|
||||
|
||||
Error isolation: async methods catch and log per-hook exceptions
|
||||
so a faulty custom hook cannot crash the agent loop.
|
||||
``finalize_content`` is a pipeline (no isolation — bugs should surface).
|
||||
"""
|
||||
|
||||
__slots__ = ("_hooks",)
|
||||
|
||||
def __init__(self, hooks: list[AgentHook]) -> None:
|
||||
self._hooks = list(hooks)
|
||||
|
||||
def wants_streaming(self) -> bool:
|
||||
return any(h.wants_streaming() for h in self._hooks)
|
||||
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
for h in self._hooks:
|
||||
try:
|
||||
await h.before_iteration(context)
|
||||
except Exception:
|
||||
logger.exception("AgentHook.before_iteration error in {}", type(h).__name__)
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
for h in self._hooks:
|
||||
try:
|
||||
await h.on_stream(context, delta)
|
||||
except Exception:
|
||||
logger.exception("AgentHook.on_stream error in {}", type(h).__name__)
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
for h in self._hooks:
|
||||
try:
|
||||
await h.on_stream_end(context, resuming=resuming)
|
||||
except Exception:
|
||||
logger.exception("AgentHook.on_stream_end error in {}", type(h).__name__)
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
for h in self._hooks:
|
||||
try:
|
||||
await h.before_execute_tools(context)
|
||||
except Exception:
|
||||
logger.exception("AgentHook.before_execute_tools error in {}", type(h).__name__)
|
||||
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
||||
for h in self._hooks:
|
||||
try:
|
||||
await h.after_iteration(context)
|
||||
except Exception:
|
||||
logger.exception("AgentHook.after_iteration error in {}", type(h).__name__)
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
for h in self._hooks:
|
||||
content = h.finalize_content(context, content)
|
||||
return content
|
||||
|
||||
@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
||||
from nanobot.agent.memory import MemoryConsolidator
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
@ -37,6 +37,111 @@ if TYPE_CHECKING:
|
||||
from nanobot.cron.service import CronService
|
||||
|
||||
|
||||
class _LoopHook(AgentHook):
|
||||
"""Core lifecycle hook for the main agent loop.
|
||||
|
||||
Handles streaming delta relay, progress reporting, tool-call logging,
|
||||
and think-tag stripping for the built-in agent path.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_loop: AgentLoop,
|
||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||
*,
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
self._loop = agent_loop
|
||||
self._on_progress = on_progress
|
||||
self._on_stream = on_stream
|
||||
self._on_stream_end = on_stream_end
|
||||
self._channel = channel
|
||||
self._chat_id = chat_id
|
||||
self._message_id = message_id
|
||||
self._stream_buf = ""
|
||||
|
||||
def wants_streaming(self) -> bool:
|
||||
return self._on_stream is not None
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
from nanobot.utils.helpers import strip_think
|
||||
|
||||
prev_clean = strip_think(self._stream_buf)
|
||||
self._stream_buf += delta
|
||||
new_clean = strip_think(self._stream_buf)
|
||||
incremental = new_clean[len(prev_clean):]
|
||||
if incremental and self._on_stream:
|
||||
await self._on_stream(incremental)
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
if self._on_stream_end:
|
||||
await self._on_stream_end(resuming=resuming)
|
||||
self._stream_buf = ""
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
if self._on_progress:
|
||||
if not self._on_stream:
|
||||
thought = self._loop._strip_think(
|
||||
context.response.content if context.response else None
|
||||
)
|
||||
if thought:
|
||||
await self._on_progress(thought)
|
||||
tool_hint = self._loop._strip_think(self._loop._tool_hint(context.tool_calls))
|
||||
await self._on_progress(tool_hint, tool_hint=True)
|
||||
for tc in context.tool_calls:
|
||||
args_str = json.dumps(tc.arguments, ensure_ascii=False)
|
||||
logger.info("Tool call: {}({})", tc.name, args_str[:200])
|
||||
self._loop._set_tool_context(self._channel, self._chat_id, self._message_id)
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
return self._loop._strip_think(content)
|
||||
|
||||
|
||||
class _LoopHookChain(AgentHook):
|
||||
"""Run the core loop hook first, then best-effort extra hooks.
|
||||
|
||||
This preserves the historical failure behavior of ``_LoopHook`` while still
|
||||
letting user-supplied hooks opt into ``CompositeHook`` isolation.
|
||||
"""
|
||||
|
||||
__slots__ = ("_primary", "_extras")
|
||||
|
||||
def __init__(self, primary: AgentHook, extra_hooks: list[AgentHook]) -> None:
|
||||
self._primary = primary
|
||||
self._extras = CompositeHook(extra_hooks)
|
||||
|
||||
def wants_streaming(self) -> bool:
|
||||
return self._primary.wants_streaming() or self._extras.wants_streaming()
|
||||
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
await self._primary.before_iteration(context)
|
||||
await self._extras.before_iteration(context)
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
await self._primary.on_stream(context, delta)
|
||||
await self._extras.on_stream(context, delta)
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
await self._primary.on_stream_end(context, resuming=resuming)
|
||||
await self._extras.on_stream_end(context, resuming=resuming)
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
await self._primary.before_execute_tools(context)
|
||||
await self._extras.before_execute_tools(context)
|
||||
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
||||
await self._primary.after_iteration(context)
|
||||
await self._extras.after_iteration(context)
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
content = self._primary.finalize_content(context, content)
|
||||
return self._extras.finalize_content(context, content)
|
||||
|
||||
|
||||
class AgentLoop:
|
||||
"""
|
||||
The agent loop is the core processing engine.
|
||||
@ -68,6 +173,7 @@ class AgentLoop:
|
||||
mcp_servers: dict | None = None,
|
||||
channels_config: ChannelsConfig | None = None,
|
||||
timezone: str | None = None,
|
||||
hooks: list[AgentHook] | None = None,
|
||||
):
|
||||
from nanobot.config.schema import ExecToolConfig, WebSearchConfig
|
||||
|
||||
@ -85,6 +191,7 @@ class AgentLoop:
|
||||
self.restrict_to_workspace = restrict_to_workspace
|
||||
self._start_time = time.time()
|
||||
self._last_usage: dict[str, int] = {}
|
||||
self._extra_hooks: list[AgentHook] = hooks or []
|
||||
|
||||
self.context = ContextBuilder(workspace, timezone=timezone)
|
||||
self.sessions = session_manager or SessionManager(workspace)
|
||||
@ -217,52 +324,27 @@ class AgentLoop:
|
||||
``resuming=True`` means tool calls follow (spinner should restart);
|
||||
``resuming=False`` means this is the final response.
|
||||
"""
|
||||
loop_self = self
|
||||
|
||||
class _LoopHook(AgentHook):
|
||||
def __init__(self) -> None:
|
||||
self._stream_buf = ""
|
||||
|
||||
def wants_streaming(self) -> bool:
|
||||
return on_stream is not None
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
from nanobot.utils.helpers import strip_think
|
||||
|
||||
prev_clean = strip_think(self._stream_buf)
|
||||
self._stream_buf += delta
|
||||
new_clean = strip_think(self._stream_buf)
|
||||
incremental = new_clean[len(prev_clean):]
|
||||
if incremental and on_stream:
|
||||
await on_stream(incremental)
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
if on_stream_end:
|
||||
await on_stream_end(resuming=resuming)
|
||||
self._stream_buf = ""
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
if on_progress:
|
||||
if not on_stream:
|
||||
thought = loop_self._strip_think(context.response.content if context.response else None)
|
||||
if thought:
|
||||
await on_progress(thought)
|
||||
tool_hint = loop_self._strip_think(loop_self._tool_hint(context.tool_calls))
|
||||
await on_progress(tool_hint, tool_hint=True)
|
||||
for tc in context.tool_calls:
|
||||
args_str = json.dumps(tc.arguments, ensure_ascii=False)
|
||||
logger.info("Tool call: {}({})", tc.name, args_str[:200])
|
||||
loop_self._set_tool_context(channel, chat_id, message_id)
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
return loop_self._strip_think(content)
|
||||
loop_hook = _LoopHook(
|
||||
self,
|
||||
on_progress=on_progress,
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
hook: AgentHook = (
|
||||
_LoopHookChain(loop_hook, self._extra_hooks)
|
||||
if self._extra_hooks
|
||||
else loop_hook
|
||||
)
|
||||
|
||||
result = await self.runner.run(AgentRunSpec(
|
||||
initial_messages=initial_messages,
|
||||
tools=self.tools,
|
||||
model=self.model,
|
||||
max_iterations=self.max_iterations,
|
||||
hook=_LoopHook(),
|
||||
hook=hook,
|
||||
error_message="Sorry, I encountered an error calling the AI model.",
|
||||
concurrent_tools=True,
|
||||
))
|
||||
@ -321,25 +403,25 @@ class AgentLoop:
|
||||
return f"{stream_base_id}:{stream_segment}"
|
||||
|
||||
async def on_stream(delta: str) -> None:
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_stream_delta"] = True
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content=delta,
|
||||
metadata={
|
||||
"_stream_delta": True,
|
||||
"_stream_id": _current_stream_id(),
|
||||
},
|
||||
metadata=meta,
|
||||
))
|
||||
|
||||
async def on_stream_end(*, resuming: bool = False) -> None:
|
||||
nonlocal stream_segment
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_stream_end"] = True
|
||||
meta["_resuming"] = resuming
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="",
|
||||
metadata={
|
||||
"_stream_end": True,
|
||||
"_resuming": resuming,
|
||||
"_stream_id": _current_stream_id(),
|
||||
},
|
||||
metadata=meta,
|
||||
))
|
||||
stream_segment += 1
|
||||
|
||||
|
||||
@ -21,6 +21,21 @@ from nanobot.config.schema import ExecToolConfig
|
||||
from nanobot.providers.base import LLMProvider
|
||||
|
||||
|
||||
class _SubagentHook(AgentHook):
|
||||
"""Logging-only hook for subagent execution."""
|
||||
|
||||
def __init__(self, task_id: str) -> None:
|
||||
self._task_id = task_id
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
for tool_call in context.tool_calls:
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.debug(
|
||||
"Subagent [{}] executing: {} with arguments: {}",
|
||||
self._task_id, tool_call.name, args_str,
|
||||
)
|
||||
|
||||
|
||||
class SubagentManager:
|
||||
"""Manages background subagent execution."""
|
||||
|
||||
@ -100,33 +115,28 @@ class SubagentManager:
|
||||
tools.register(WriteFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
tools.register(EditFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
tools.register(ExecTool(
|
||||
working_dir=str(self.workspace),
|
||||
timeout=self.exec_config.timeout,
|
||||
restrict_to_workspace=self.restrict_to_workspace,
|
||||
path_append=self.exec_config.path_append,
|
||||
))
|
||||
if self.exec_config.enable:
|
||||
tools.register(ExecTool(
|
||||
working_dir=str(self.workspace),
|
||||
timeout=self.exec_config.timeout,
|
||||
restrict_to_workspace=self.restrict_to_workspace,
|
||||
path_append=self.exec_config.path_append,
|
||||
))
|
||||
tools.register(WebSearchTool(config=self.web_search_config, proxy=self.web_proxy))
|
||||
tools.register(WebFetchTool(proxy=self.web_proxy))
|
||||
|
||||
|
||||
system_prompt = self._build_subagent_prompt()
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": task},
|
||||
]
|
||||
|
||||
class _SubagentHook(AgentHook):
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
for tool_call in context.tool_calls:
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str)
|
||||
|
||||
result = await self.runner.run(AgentRunSpec(
|
||||
initial_messages=messages,
|
||||
tools=tools,
|
||||
model=self.model,
|
||||
max_iterations=15,
|
||||
hook=_SubagentHook(),
|
||||
hook=_SubagentHook(task_id),
|
||||
max_iterations_message="Task completed but no final response was generated.",
|
||||
error_message=None,
|
||||
fail_on_tool_error=True,
|
||||
@ -213,7 +223,7 @@ Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not men
|
||||
lines.append("Failure:")
|
||||
lines.append(f"- {result.error}")
|
||||
return "\n".join(lines) or (result.error or "Error: subagent execution failed.")
|
||||
|
||||
|
||||
def _build_subagent_prompt(self) -> str:
|
||||
"""Build a focused system prompt for the subagent."""
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
|
||||
@ -74,7 +74,7 @@ class CronTool(Tool):
|
||||
"enum": ["add", "list", "remove"],
|
||||
"description": "Action to perform",
|
||||
},
|
||||
"message": {"type": "string", "description": "Reminder message (for add)"},
|
||||
"message": {"type": "string", "description": "Instruction for the agent to execute when the job triggers (e.g., 'Send a reminder to WeChat: xxx' or 'Check system status and report')"},
|
||||
"every_seconds": {
|
||||
"type": "integer",
|
||||
"description": "Interval in seconds (for recurring tasks)",
|
||||
|
||||
@ -170,7 +170,11 @@ async def connect_mcp_servers(
|
||||
timeout: httpx.Timeout | None = None,
|
||||
auth: httpx.Auth | None = None,
|
||||
) -> httpx.AsyncClient:
|
||||
merged_headers = {**(cfg.headers or {}), **(headers or {})}
|
||||
merged_headers = {
|
||||
"Accept": "application/json, text/event-stream",
|
||||
**(cfg.headers or {}),
|
||||
**(headers or {}),
|
||||
}
|
||||
return httpx.AsyncClient(
|
||||
headers=merged_headers or None,
|
||||
follow_redirects=True,
|
||||
|
||||
1
nanobot/api/__init__.py
Normal file
1
nanobot/api/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""OpenAI-compatible HTTP API for nanobot."""
|
||||
193
nanobot/api/server.py
Normal file
193
nanobot/api/server.py
Normal file
@ -0,0 +1,193 @@
|
||||
"""OpenAI-compatible HTTP API server for a fixed nanobot session.
|
||||
|
||||
Provides /v1/chat/completions and /v1/models endpoints.
|
||||
All requests route to a single persistent API session.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import web
|
||||
from loguru import logger
|
||||
|
||||
API_SESSION_KEY = "api:default"
|
||||
API_CHAT_ID = "default"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _error_json(status: int, message: str, err_type: str = "invalid_request_error") -> web.Response:
|
||||
return web.json_response(
|
||||
{"error": {"message": message, "type": err_type, "code": status}},
|
||||
status=status,
|
||||
)
|
||||
|
||||
|
||||
def _chat_completion_response(content: str, model: str) -> dict[str, Any]:
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": content},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
}
|
||||
|
||||
|
||||
def _response_text(value: Any) -> str:
|
||||
"""Normalize process_direct output to plain assistant text."""
|
||||
if value is None:
|
||||
return ""
|
||||
if hasattr(value, "content"):
|
||||
return str(getattr(value, "content") or "")
|
||||
return str(value)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Route handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def handle_chat_completions(request: web.Request) -> web.Response:
|
||||
"""POST /v1/chat/completions"""
|
||||
|
||||
# --- Parse body ---
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
return _error_json(400, "Invalid JSON body")
|
||||
|
||||
messages = body.get("messages")
|
||||
if not isinstance(messages, list) or len(messages) != 1:
|
||||
return _error_json(400, "Only a single user message is supported")
|
||||
|
||||
# Stream not yet supported
|
||||
if body.get("stream", False):
|
||||
return _error_json(400, "stream=true is not supported yet. Set stream=false or omit it.")
|
||||
|
||||
message = messages[0]
|
||||
if not isinstance(message, dict) or message.get("role") != "user":
|
||||
return _error_json(400, "Only a single user message is supported")
|
||||
user_content = message.get("content", "")
|
||||
if isinstance(user_content, list):
|
||||
# Multi-modal content array — extract text parts
|
||||
user_content = " ".join(
|
||||
part.get("text", "") for part in user_content if part.get("type") == "text"
|
||||
)
|
||||
|
||||
agent_loop = request.app["agent_loop"]
|
||||
timeout_s: float = request.app.get("request_timeout", 120.0)
|
||||
model_name: str = request.app.get("model_name", "nanobot")
|
||||
if (requested_model := body.get("model")) and requested_model != model_name:
|
||||
return _error_json(400, f"Only configured model '{model_name}' is available")
|
||||
|
||||
session_key = f"api:{body['session_id']}" if body.get("session_id") else API_SESSION_KEY
|
||||
session_locks: dict[str, asyncio.Lock] = request.app["session_locks"]
|
||||
session_lock = session_locks.setdefault(session_key, asyncio.Lock())
|
||||
|
||||
logger.info("API request session_key={} content={}", session_key, user_content[:80])
|
||||
|
||||
_FALLBACK = "I've completed processing but have no response to give."
|
||||
|
||||
try:
|
||||
async with session_lock:
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
agent_loop.process_direct(
|
||||
content=user_content,
|
||||
session_key=session_key,
|
||||
channel="api",
|
||||
chat_id=API_CHAT_ID,
|
||||
),
|
||||
timeout=timeout_s,
|
||||
)
|
||||
response_text = _response_text(response)
|
||||
|
||||
if not response_text or not response_text.strip():
|
||||
logger.warning(
|
||||
"Empty response for session {}, retrying",
|
||||
session_key,
|
||||
)
|
||||
retry_response = await asyncio.wait_for(
|
||||
agent_loop.process_direct(
|
||||
content=user_content,
|
||||
session_key=session_key,
|
||||
channel="api",
|
||||
chat_id=API_CHAT_ID,
|
||||
),
|
||||
timeout=timeout_s,
|
||||
)
|
||||
response_text = _response_text(retry_response)
|
||||
if not response_text or not response_text.strip():
|
||||
logger.warning(
|
||||
"Empty response after retry for session {}, using fallback",
|
||||
session_key,
|
||||
)
|
||||
response_text = _FALLBACK
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return _error_json(504, f"Request timed out after {timeout_s}s")
|
||||
except Exception:
|
||||
logger.exception("Error processing request for session {}", session_key)
|
||||
return _error_json(500, "Internal server error", err_type="server_error")
|
||||
except Exception:
|
||||
logger.exception("Unexpected API lock error for session {}", session_key)
|
||||
return _error_json(500, "Internal server error", err_type="server_error")
|
||||
|
||||
return web.json_response(_chat_completion_response(response_text, model_name))
|
||||
|
||||
|
||||
async def handle_models(request: web.Request) -> web.Response:
|
||||
"""GET /v1/models"""
|
||||
model_name = request.app.get("model_name", "nanobot")
|
||||
return web.json_response({
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": model_name,
|
||||
"object": "model",
|
||||
"created": 0,
|
||||
"owned_by": "nanobot",
|
||||
}
|
||||
],
|
||||
})
|
||||
|
||||
|
||||
async def handle_health(request: web.Request) -> web.Response:
|
||||
"""GET /health"""
|
||||
return web.json_response({"status": "ok"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# App factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def create_app(agent_loop, model_name: str = "nanobot", request_timeout: float = 120.0) -> web.Application:
|
||||
"""Create the aiohttp application.
|
||||
|
||||
Args:
|
||||
agent_loop: An initialized AgentLoop instance.
|
||||
model_name: Model name reported in responses.
|
||||
request_timeout: Per-request timeout in seconds.
|
||||
"""
|
||||
app = web.Application()
|
||||
app["agent_loop"] = agent_loop
|
||||
app["model_name"] = model_name
|
||||
app["request_timeout"] = request_timeout
|
||||
app["session_locks"] = {} # per-user locks, keyed by session_key
|
||||
|
||||
app.router.add_post("/v1/chat/completions", handle_chat_completions)
|
||||
app.router.add_get("/v1/models", handle_models)
|
||||
app.router.add_get("/health", handle_health)
|
||||
return app
|
||||
@ -1,25 +1,37 @@
|
||||
"""Discord channel implementation using Discord Gateway websocket."""
|
||||
"""Discord channel implementation using discord.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import httpx
|
||||
from pydantic import Field
|
||||
import websockets
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.command.builtin import build_help_text
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import Base
|
||||
from nanobot.utils.helpers import split_message
|
||||
from nanobot.utils.helpers import safe_filename, split_message
|
||||
|
||||
DISCORD_AVAILABLE = importlib.util.find_spec("discord") is not None
|
||||
if TYPE_CHECKING:
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.abc import Messageable
|
||||
|
||||
if DISCORD_AVAILABLE:
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.abc import Messageable
|
||||
|
||||
DISCORD_API_BASE = "https://discord.com/api/v10"
|
||||
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
|
||||
MAX_MESSAGE_LEN = 2000 # Discord message character limit
|
||||
TYPING_INTERVAL_S = 8
|
||||
|
||||
|
||||
class DiscordConfig(Base):
|
||||
@ -28,13 +40,205 @@ class DiscordConfig(Base):
|
||||
enabled: bool = False
|
||||
token: str = ""
|
||||
allow_from: list[str] = Field(default_factory=list)
|
||||
gateway_url: str = "wss://gateway.discord.gg/?v=10&encoding=json"
|
||||
intents: int = 37377
|
||||
group_policy: Literal["mention", "open"] = "mention"
|
||||
read_receipt_emoji: str = "👀"
|
||||
working_emoji: str = "🔧"
|
||||
working_emoji_delay: float = 2.0
|
||||
|
||||
|
||||
if DISCORD_AVAILABLE:
|
||||
|
||||
class DiscordBotClient(discord.Client):
|
||||
"""discord.py client that forwards events to the channel."""
|
||||
|
||||
def __init__(self, channel: DiscordChannel, *, intents: discord.Intents) -> None:
|
||||
super().__init__(intents=intents)
|
||||
self._channel = channel
|
||||
self.tree = app_commands.CommandTree(self)
|
||||
self._register_app_commands()
|
||||
|
||||
async def on_ready(self) -> None:
|
||||
self._channel._bot_user_id = str(self.user.id) if self.user else None
|
||||
logger.info("Discord bot connected as user {}", self._channel._bot_user_id)
|
||||
try:
|
||||
synced = await self.tree.sync()
|
||||
logger.info("Discord app commands synced: {}", len(synced))
|
||||
except Exception as e:
|
||||
logger.warning("Discord app command sync failed: {}", e)
|
||||
|
||||
async def on_message(self, message: discord.Message) -> None:
|
||||
await self._channel._handle_discord_message(message)
|
||||
|
||||
async def _reply_ephemeral(self, interaction: discord.Interaction, text: str) -> bool:
|
||||
"""Send an ephemeral interaction response and report success."""
|
||||
try:
|
||||
await interaction.response.send_message(text, ephemeral=True)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning("Discord interaction response failed: {}", e)
|
||||
return False
|
||||
|
||||
async def _forward_slash_command(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
command_text: str,
|
||||
) -> None:
|
||||
sender_id = str(interaction.user.id)
|
||||
channel_id = interaction.channel_id
|
||||
|
||||
if channel_id is None:
|
||||
logger.warning("Discord slash command missing channel_id: {}", command_text)
|
||||
return
|
||||
|
||||
if not self._channel.is_allowed(sender_id):
|
||||
await self._reply_ephemeral(interaction, "You are not allowed to use this bot.")
|
||||
return
|
||||
|
||||
await self._reply_ephemeral(interaction, f"Processing {command_text}...")
|
||||
|
||||
await self._channel._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=str(channel_id),
|
||||
content=command_text,
|
||||
metadata={
|
||||
"interaction_id": str(interaction.id),
|
||||
"guild_id": str(interaction.guild_id) if interaction.guild_id else None,
|
||||
"is_slash_command": True,
|
||||
},
|
||||
)
|
||||
|
||||
def _register_app_commands(self) -> None:
|
||||
commands = (
|
||||
("new", "Start a new conversation", "/new"),
|
||||
("stop", "Stop the current task", "/stop"),
|
||||
("restart", "Restart the bot", "/restart"),
|
||||
("status", "Show bot status", "/status"),
|
||||
)
|
||||
|
||||
for name, description, command_text in commands:
|
||||
@self.tree.command(name=name, description=description)
|
||||
async def command_handler(
|
||||
interaction: discord.Interaction,
|
||||
_command_text: str = command_text,
|
||||
) -> None:
|
||||
await self._forward_slash_command(interaction, _command_text)
|
||||
|
||||
@self.tree.command(name="help", description="Show available commands")
|
||||
async def help_command(interaction: discord.Interaction) -> None:
|
||||
sender_id = str(interaction.user.id)
|
||||
if not self._channel.is_allowed(sender_id):
|
||||
await self._reply_ephemeral(interaction, "You are not allowed to use this bot.")
|
||||
return
|
||||
await self._reply_ephemeral(interaction, build_help_text())
|
||||
|
||||
@self.tree.error
|
||||
async def on_app_command_error(
|
||||
interaction: discord.Interaction,
|
||||
error: app_commands.AppCommandError,
|
||||
) -> None:
|
||||
command_name = interaction.command.qualified_name if interaction.command else "?"
|
||||
logger.warning(
|
||||
"Discord app command failed user={} channel={} cmd={} error={}",
|
||||
interaction.user.id,
|
||||
interaction.channel_id,
|
||||
command_name,
|
||||
error,
|
||||
)
|
||||
|
||||
async def send_outbound(self, msg: OutboundMessage) -> None:
|
||||
"""Send a nanobot outbound message using Discord transport rules."""
|
||||
channel_id = int(msg.chat_id)
|
||||
|
||||
channel = self.get_channel(channel_id)
|
||||
if channel is None:
|
||||
try:
|
||||
channel = await self.fetch_channel(channel_id)
|
||||
except Exception as e:
|
||||
logger.warning("Discord channel {} unavailable: {}", msg.chat_id, e)
|
||||
return
|
||||
|
||||
reference, mention_settings = self._build_reply_context(channel, msg.reply_to)
|
||||
sent_media = False
|
||||
failed_media: list[str] = []
|
||||
|
||||
for index, media_path in enumerate(msg.media or []):
|
||||
if await self._send_file(
|
||||
channel,
|
||||
media_path,
|
||||
reference=reference if index == 0 else None,
|
||||
mention_settings=mention_settings,
|
||||
):
|
||||
sent_media = True
|
||||
else:
|
||||
failed_media.append(Path(media_path).name)
|
||||
|
||||
for index, chunk in enumerate(self._build_chunks(msg.content or "", failed_media, sent_media)):
|
||||
kwargs: dict[str, Any] = {"content": chunk}
|
||||
if index == 0 and reference is not None and not sent_media:
|
||||
kwargs["reference"] = reference
|
||||
kwargs["allowed_mentions"] = mention_settings
|
||||
await channel.send(**kwargs)
|
||||
|
||||
async def _send_file(
|
||||
self,
|
||||
channel: Messageable,
|
||||
file_path: str,
|
||||
*,
|
||||
reference: discord.PartialMessage | None,
|
||||
mention_settings: discord.AllowedMentions,
|
||||
) -> bool:
|
||||
"""Send a file attachment via discord.py."""
|
||||
path = Path(file_path)
|
||||
if not path.is_file():
|
||||
logger.warning("Discord file not found, skipping: {}", file_path)
|
||||
return False
|
||||
|
||||
if path.stat().st_size > MAX_ATTACHMENT_BYTES:
|
||||
logger.warning("Discord file too large (>20MB), skipping: {}", path.name)
|
||||
return False
|
||||
|
||||
try:
|
||||
kwargs: dict[str, Any] = {"file": discord.File(path)}
|
||||
if reference is not None:
|
||||
kwargs["reference"] = reference
|
||||
kwargs["allowed_mentions"] = mention_settings
|
||||
await channel.send(**kwargs)
|
||||
logger.info("Discord file sent: {}", path.name)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Error sending Discord file {}: {}", path.name, e)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _build_chunks(content: str, failed_media: list[str], sent_media: bool) -> list[str]:
|
||||
"""Build outbound text chunks, including attachment-failure fallback text."""
|
||||
chunks = split_message(content, MAX_MESSAGE_LEN)
|
||||
if chunks or not failed_media or sent_media:
|
||||
return chunks
|
||||
fallback = "\n".join(f"[attachment: {name} - send failed]" for name in failed_media)
|
||||
return split_message(fallback, MAX_MESSAGE_LEN)
|
||||
|
||||
@staticmethod
|
||||
def _build_reply_context(
|
||||
channel: Messageable,
|
||||
reply_to: str | None,
|
||||
) -> tuple[discord.PartialMessage | None, discord.AllowedMentions]:
|
||||
"""Build reply context for outbound messages."""
|
||||
mention_settings = discord.AllowedMentions(replied_user=False)
|
||||
if not reply_to:
|
||||
return None, mention_settings
|
||||
try:
|
||||
message_id = int(reply_to)
|
||||
except ValueError:
|
||||
logger.warning("Invalid Discord reply target: {}", reply_to)
|
||||
return None, mention_settings
|
||||
|
||||
return channel.get_partial_message(message_id), mention_settings
|
||||
|
||||
|
||||
class DiscordChannel(BaseChannel):
|
||||
"""Discord channel using Gateway websocket."""
|
||||
"""Discord channel using discord.py."""
|
||||
|
||||
name = "discord"
|
||||
display_name = "Discord"
|
||||
@ -43,353 +247,270 @@ class DiscordChannel(BaseChannel):
|
||||
def default_config(cls) -> dict[str, Any]:
|
||||
return DiscordConfig().model_dump(by_alias=True)
|
||||
|
||||
@staticmethod
|
||||
def _channel_key(channel_or_id: Any) -> str:
|
||||
"""Normalize channel-like objects and ids to a stable string key."""
|
||||
channel_id = getattr(channel_or_id, "id", channel_or_id)
|
||||
return str(channel_id)
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus):
|
||||
if isinstance(config, dict):
|
||||
config = DiscordConfig.model_validate(config)
|
||||
super().__init__(config, bus)
|
||||
self.config: DiscordConfig = config
|
||||
self._ws: websockets.WebSocketClientProtocol | None = None
|
||||
self._seq: int | None = None
|
||||
self._heartbeat_task: asyncio.Task | None = None
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._http: httpx.AsyncClient | None = None
|
||||
self._client: DiscordBotClient | None = None
|
||||
self._typing_tasks: dict[str, asyncio.Task[None]] = {}
|
||||
self._bot_user_id: str | None = None
|
||||
self._pending_reactions: dict[str, Any] = {} # chat_id -> message object
|
||||
self._working_emoji_tasks: dict[str, asyncio.Task[None]] = {}
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Discord gateway connection."""
|
||||
"""Start the Discord client."""
|
||||
if not DISCORD_AVAILABLE:
|
||||
logger.error("discord.py not installed. Run: pip install nanobot-ai[discord]")
|
||||
return
|
||||
|
||||
if not self.config.token:
|
||||
logger.error("Discord bot token not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._http = httpx.AsyncClient(timeout=30.0)
|
||||
try:
|
||||
intents = discord.Intents.none()
|
||||
intents.value = self.config.intents
|
||||
self._client = DiscordBotClient(self, intents=intents)
|
||||
except Exception as e:
|
||||
logger.error("Failed to initialize Discord client: {}", e)
|
||||
self._client = None
|
||||
self._running = False
|
||||
return
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
logger.info("Connecting to Discord gateway...")
|
||||
async with websockets.connect(self.config.gateway_url) as ws:
|
||||
self._ws = ws
|
||||
await self._gateway_loop()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning("Discord gateway error: {}", e)
|
||||
if self._running:
|
||||
logger.info("Reconnecting to Discord gateway in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
self._running = True
|
||||
logger.info("Starting Discord client via discord.py...")
|
||||
|
||||
try:
|
||||
await self._client.start(self.config.token)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Discord client startup failed: {}", e)
|
||||
finally:
|
||||
self._running = False
|
||||
await self._reset_runtime_state(close_client=True)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Discord channel."""
|
||||
self._running = False
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
self._heartbeat_task = None
|
||||
for task in self._typing_tasks.values():
|
||||
task.cancel()
|
||||
self._typing_tasks.clear()
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
if self._http:
|
||||
await self._http.aclose()
|
||||
self._http = None
|
||||
await self._reset_runtime_state(close_client=True)
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through Discord REST API, including file attachments."""
|
||||
if not self._http:
|
||||
logger.warning("Discord HTTP client not initialized")
|
||||
"""Send a message through Discord using discord.py."""
|
||||
client = self._client
|
||||
if client is None or not client.is_ready():
|
||||
logger.warning("Discord client not ready; dropping outbound message")
|
||||
return
|
||||
|
||||
url = f"{DISCORD_API_BASE}/channels/{msg.chat_id}/messages"
|
||||
headers = {"Authorization": f"Bot {self.config.token}"}
|
||||
is_progress = bool((msg.metadata or {}).get("_progress"))
|
||||
|
||||
try:
|
||||
sent_media = False
|
||||
failed_media: list[str] = []
|
||||
|
||||
# Send file attachments first
|
||||
for media_path in msg.media or []:
|
||||
if await self._send_file(url, headers, media_path, reply_to=msg.reply_to):
|
||||
sent_media = True
|
||||
else:
|
||||
failed_media.append(Path(media_path).name)
|
||||
|
||||
# Send text content
|
||||
chunks = split_message(msg.content or "", MAX_MESSAGE_LEN)
|
||||
if not chunks and failed_media and not sent_media:
|
||||
chunks = split_message(
|
||||
"\n".join(f"[attachment: {name} - send failed]" for name in failed_media),
|
||||
MAX_MESSAGE_LEN,
|
||||
)
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
payload: dict[str, Any] = {"content": chunk}
|
||||
|
||||
# Let the first successful attachment carry the reply if present.
|
||||
if i == 0 and msg.reply_to and not sent_media:
|
||||
payload["message_reference"] = {"message_id": msg.reply_to}
|
||||
payload["allowed_mentions"] = {"replied_user": False}
|
||||
|
||||
if not await self._send_payload(url, headers, payload):
|
||||
break # Abort remaining chunks on failure
|
||||
await client.send_outbound(msg)
|
||||
except Exception as e:
|
||||
logger.error("Error sending Discord message: {}", e)
|
||||
finally:
|
||||
await self._stop_typing(msg.chat_id)
|
||||
if not is_progress:
|
||||
await self._stop_typing(msg.chat_id)
|
||||
await self._clear_reactions(msg.chat_id)
|
||||
|
||||
async def _send_payload(
|
||||
self, url: str, headers: dict[str, str], payload: dict[str, Any]
|
||||
) -> bool:
|
||||
"""Send a single Discord API payload with retry on rate-limit. Returns True on success."""
|
||||
for attempt in range(3):
|
||||
async def _handle_discord_message(self, message: discord.Message) -> None:
|
||||
"""Handle incoming Discord messages from discord.py."""
|
||||
if message.author.bot:
|
||||
return
|
||||
|
||||
sender_id = str(message.author.id)
|
||||
channel_id = self._channel_key(message.channel)
|
||||
content = message.content or ""
|
||||
|
||||
if not self._should_accept_inbound(message, sender_id, content):
|
||||
return
|
||||
|
||||
media_paths, attachment_markers = await self._download_attachments(message.attachments)
|
||||
full_content = self._compose_inbound_content(content, attachment_markers)
|
||||
metadata = self._build_inbound_metadata(message)
|
||||
|
||||
await self._start_typing(message.channel)
|
||||
|
||||
# Add read receipt reaction immediately, working emoji after delay
|
||||
channel_id = self._channel_key(message.channel)
|
||||
try:
|
||||
await message.add_reaction(self.config.read_receipt_emoji)
|
||||
self._pending_reactions[channel_id] = message
|
||||
except Exception as e:
|
||||
logger.debug("Failed to add read receipt reaction: {}", e)
|
||||
|
||||
# Delayed working indicator (cosmetic — not tied to subagent lifecycle)
|
||||
async def _delayed_working_emoji() -> None:
|
||||
await asyncio.sleep(self.config.working_emoji_delay)
|
||||
try:
|
||||
response = await self._http.post(url, headers=headers, json=payload)
|
||||
if response.status_code == 429:
|
||||
data = response.json()
|
||||
retry_after = float(data.get("retry_after", 1.0))
|
||||
logger.warning("Discord rate limited, retrying in {}s", retry_after)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
response.raise_for_status()
|
||||
return True
|
||||
except Exception as e:
|
||||
if attempt == 2:
|
||||
logger.error("Error sending Discord message: {}", e)
|
||||
else:
|
||||
await asyncio.sleep(1)
|
||||
return False
|
||||
await message.add_reaction(self.config.working_emoji)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _send_file(
|
||||
self._working_emoji_tasks[channel_id] = asyncio.create_task(_delayed_working_emoji())
|
||||
|
||||
try:
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=channel_id,
|
||||
content=full_content,
|
||||
media=media_paths,
|
||||
metadata=metadata,
|
||||
)
|
||||
except Exception:
|
||||
await self._clear_reactions(channel_id)
|
||||
await self._stop_typing(channel_id)
|
||||
raise
|
||||
|
||||
async def _on_message(self, message: discord.Message) -> None:
|
||||
"""Backward-compatible alias for legacy tests/callers."""
|
||||
await self._handle_discord_message(message)
|
||||
|
||||
def _should_accept_inbound(
|
||||
self,
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
file_path: str,
|
||||
reply_to: str | None = None,
|
||||
message: discord.Message,
|
||||
sender_id: str,
|
||||
content: str,
|
||||
) -> bool:
|
||||
"""Send a file attachment via Discord REST API using multipart/form-data."""
|
||||
path = Path(file_path)
|
||||
if not path.is_file():
|
||||
logger.warning("Discord file not found, skipping: {}", file_path)
|
||||
return False
|
||||
|
||||
if path.stat().st_size > MAX_ATTACHMENT_BYTES:
|
||||
logger.warning("Discord file too large (>20MB), skipping: {}", path.name)
|
||||
return False
|
||||
|
||||
payload_json: dict[str, Any] = {}
|
||||
if reply_to:
|
||||
payload_json["message_reference"] = {"message_id": reply_to}
|
||||
payload_json["allowed_mentions"] = {"replied_user": False}
|
||||
|
||||
for attempt in range(3):
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
files = {"files[0]": (path.name, f, "application/octet-stream")}
|
||||
data: dict[str, Any] = {}
|
||||
if payload_json:
|
||||
data["payload_json"] = json.dumps(payload_json)
|
||||
response = await self._http.post(
|
||||
url, headers=headers, files=files, data=data
|
||||
)
|
||||
if response.status_code == 429:
|
||||
resp_data = response.json()
|
||||
retry_after = float(resp_data.get("retry_after", 1.0))
|
||||
logger.warning("Discord rate limited, retrying in {}s", retry_after)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
response.raise_for_status()
|
||||
logger.info("Discord file sent: {}", path.name)
|
||||
return True
|
||||
except Exception as e:
|
||||
if attempt == 2:
|
||||
logger.error("Error sending Discord file {}: {}", path.name, e)
|
||||
else:
|
||||
await asyncio.sleep(1)
|
||||
return False
|
||||
|
||||
async def _gateway_loop(self) -> None:
|
||||
"""Main gateway loop: identify, heartbeat, dispatch events."""
|
||||
if not self._ws:
|
||||
return
|
||||
|
||||
async for raw in self._ws:
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid JSON from Discord gateway: {}", raw[:100])
|
||||
continue
|
||||
|
||||
op = data.get("op")
|
||||
event_type = data.get("t")
|
||||
seq = data.get("s")
|
||||
payload = data.get("d")
|
||||
|
||||
if seq is not None:
|
||||
self._seq = seq
|
||||
|
||||
if op == 10:
|
||||
# HELLO: start heartbeat and identify
|
||||
interval_ms = payload.get("heartbeat_interval", 45000)
|
||||
await self._start_heartbeat(interval_ms / 1000)
|
||||
await self._identify()
|
||||
elif op == 0 and event_type == "READY":
|
||||
logger.info("Discord gateway READY")
|
||||
# Capture bot user ID for mention detection
|
||||
user_data = payload.get("user") or {}
|
||||
self._bot_user_id = user_data.get("id")
|
||||
logger.info("Discord bot connected as user {}", self._bot_user_id)
|
||||
elif op == 0 and event_type == "MESSAGE_CREATE":
|
||||
await self._handle_message_create(payload)
|
||||
elif op == 7:
|
||||
# RECONNECT: exit loop to reconnect
|
||||
logger.info("Discord gateway requested reconnect")
|
||||
break
|
||||
elif op == 9:
|
||||
# INVALID_SESSION: reconnect
|
||||
logger.warning("Discord gateway invalid session")
|
||||
break
|
||||
|
||||
async def _identify(self) -> None:
|
||||
"""Send IDENTIFY payload."""
|
||||
if not self._ws:
|
||||
return
|
||||
|
||||
identify = {
|
||||
"op": 2,
|
||||
"d": {
|
||||
"token": self.config.token,
|
||||
"intents": self.config.intents,
|
||||
"properties": {
|
||||
"os": "nanobot",
|
||||
"browser": "nanobot",
|
||||
"device": "nanobot",
|
||||
},
|
||||
},
|
||||
}
|
||||
await self._ws.send(json.dumps(identify))
|
||||
|
||||
async def _start_heartbeat(self, interval_s: float) -> None:
|
||||
"""Start or restart the heartbeat loop."""
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
|
||||
async def heartbeat_loop() -> None:
|
||||
while self._running and self._ws:
|
||||
payload = {"op": 1, "d": self._seq}
|
||||
try:
|
||||
await self._ws.send(json.dumps(payload))
|
||||
except Exception as e:
|
||||
logger.warning("Discord heartbeat failed: {}", e)
|
||||
break
|
||||
await asyncio.sleep(interval_s)
|
||||
|
||||
self._heartbeat_task = asyncio.create_task(heartbeat_loop())
|
||||
|
||||
async def _handle_message_create(self, payload: dict[str, Any]) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
author = payload.get("author") or {}
|
||||
if author.get("bot"):
|
||||
return
|
||||
|
||||
sender_id = str(author.get("id", ""))
|
||||
channel_id = str(payload.get("channel_id", ""))
|
||||
content = payload.get("content") or ""
|
||||
guild_id = payload.get("guild_id")
|
||||
|
||||
if not sender_id or not channel_id:
|
||||
return
|
||||
|
||||
"""Check if inbound Discord message should be processed."""
|
||||
if not self.is_allowed(sender_id):
|
||||
return
|
||||
return False
|
||||
if message.guild is not None and not self._should_respond_in_group(message, content):
|
||||
return False
|
||||
return True
|
||||
|
||||
# Check group channel policy (DMs always respond if is_allowed passes)
|
||||
if guild_id is not None:
|
||||
if not self._should_respond_in_group(payload, content):
|
||||
return
|
||||
|
||||
content_parts = [content] if content else []
|
||||
async def _download_attachments(
|
||||
self,
|
||||
attachments: list[discord.Attachment],
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Download supported attachments and return paths + display markers."""
|
||||
media_paths: list[str] = []
|
||||
markers: list[str] = []
|
||||
media_dir = get_media_dir("discord")
|
||||
|
||||
for attachment in payload.get("attachments") or []:
|
||||
url = attachment.get("url")
|
||||
filename = attachment.get("filename") or "attachment"
|
||||
size = attachment.get("size") or 0
|
||||
if not url or not self._http:
|
||||
continue
|
||||
if size and size > MAX_ATTACHMENT_BYTES:
|
||||
content_parts.append(f"[attachment: {filename} - too large]")
|
||||
for attachment in attachments:
|
||||
filename = attachment.filename or "attachment"
|
||||
if attachment.size and attachment.size > MAX_ATTACHMENT_BYTES:
|
||||
markers.append(f"[attachment: {filename} - too large]")
|
||||
continue
|
||||
try:
|
||||
media_dir.mkdir(parents=True, exist_ok=True)
|
||||
file_path = media_dir / f"{attachment.get('id', 'file')}_{filename.replace('/', '_')}"
|
||||
resp = await self._http.get(url)
|
||||
resp.raise_for_status()
|
||||
file_path.write_bytes(resp.content)
|
||||
safe_name = safe_filename(filename)
|
||||
file_path = media_dir / f"{attachment.id}_{safe_name}"
|
||||
await attachment.save(file_path)
|
||||
media_paths.append(str(file_path))
|
||||
content_parts.append(f"[attachment: {file_path}]")
|
||||
markers.append(f"[attachment: {file_path.name}]")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to download Discord attachment: {}", e)
|
||||
content_parts.append(f"[attachment: {filename} - download failed]")
|
||||
markers.append(f"[attachment: {filename} - download failed]")
|
||||
|
||||
reply_to = (payload.get("referenced_message") or {}).get("id")
|
||||
return media_paths, markers
|
||||
|
||||
await self._start_typing(channel_id)
|
||||
@staticmethod
|
||||
def _compose_inbound_content(content: str, attachment_markers: list[str]) -> str:
|
||||
"""Combine message text with attachment markers."""
|
||||
content_parts = [content] if content else []
|
||||
content_parts.extend(attachment_markers)
|
||||
return "\n".join(part for part in content_parts if part) or "[empty message]"
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=channel_id,
|
||||
content="\n".join(p for p in content_parts if p) or "[empty message]",
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"message_id": str(payload.get("id", "")),
|
||||
"guild_id": guild_id,
|
||||
"reply_to": reply_to,
|
||||
},
|
||||
)
|
||||
@staticmethod
|
||||
def _build_inbound_metadata(message: discord.Message) -> dict[str, str | None]:
|
||||
"""Build metadata for inbound Discord messages."""
|
||||
reply_to = str(message.reference.message_id) if message.reference and message.reference.message_id else None
|
||||
return {
|
||||
"message_id": str(message.id),
|
||||
"guild_id": str(message.guild.id) if message.guild else None,
|
||||
"reply_to": reply_to,
|
||||
}
|
||||
|
||||
def _should_respond_in_group(self, payload: dict[str, Any], content: str) -> bool:
|
||||
"""Check if bot should respond in a group channel based on policy."""
|
||||
def _should_respond_in_group(self, message: discord.Message, content: str) -> bool:
|
||||
"""Check if the bot should respond in a guild channel based on policy."""
|
||||
if self.config.group_policy == "open":
|
||||
return True
|
||||
|
||||
if self.config.group_policy == "mention":
|
||||
# Check if bot was mentioned in the message
|
||||
if self._bot_user_id:
|
||||
# Check mentions array
|
||||
mentions = payload.get("mentions") or []
|
||||
for mention in mentions:
|
||||
if str(mention.get("id")) == self._bot_user_id:
|
||||
return True
|
||||
# Also check content for mention format <@USER_ID>
|
||||
if f"<@{self._bot_user_id}>" in content or f"<@!{self._bot_user_id}>" in content:
|
||||
return True
|
||||
logger.debug("Discord message in {} ignored (bot not mentioned)", payload.get("channel_id"))
|
||||
bot_user_id = self._bot_user_id
|
||||
if bot_user_id is None:
|
||||
logger.debug("Discord message in {} ignored (bot identity unavailable)", message.channel.id)
|
||||
return False
|
||||
|
||||
if any(str(user.id) == bot_user_id for user in message.mentions):
|
||||
return True
|
||||
if f"<@{bot_user_id}>" in content or f"<@!{bot_user_id}>" in content:
|
||||
return True
|
||||
|
||||
logger.debug("Discord message in {} ignored (bot not mentioned)", message.channel.id)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _start_typing(self, channel_id: str) -> None:
|
||||
async def _start_typing(self, channel: Messageable) -> None:
|
||||
"""Start periodic typing indicator for a channel."""
|
||||
channel_id = self._channel_key(channel)
|
||||
await self._stop_typing(channel_id)
|
||||
|
||||
async def typing_loop() -> None:
|
||||
url = f"{DISCORD_API_BASE}/channels/{channel_id}/typing"
|
||||
headers = {"Authorization": f"Bot {self.config.token}"}
|
||||
while self._running:
|
||||
try:
|
||||
await self._http.post(url, headers=headers)
|
||||
async with channel.typing():
|
||||
await asyncio.sleep(TYPING_INTERVAL_S)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug("Discord typing indicator failed for {}: {}", channel_id, e)
|
||||
return
|
||||
await asyncio.sleep(8)
|
||||
|
||||
self._typing_tasks[channel_id] = asyncio.create_task(typing_loop())
|
||||
|
||||
async def _stop_typing(self, channel_id: str) -> None:
|
||||
"""Stop typing indicator for a channel."""
|
||||
task = self._typing_tasks.pop(channel_id, None)
|
||||
if task:
|
||||
task = self._typing_tasks.pop(self._channel_key(channel_id), None)
|
||||
if task is None:
|
||||
return
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
async def _clear_reactions(self, chat_id: str) -> None:
|
||||
"""Remove all pending reactions after bot replies."""
|
||||
# Cancel delayed working emoji if it hasn't fired yet
|
||||
task = self._working_emoji_tasks.pop(chat_id, None)
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
|
||||
msg_obj = self._pending_reactions.pop(chat_id, None)
|
||||
if msg_obj is None:
|
||||
return
|
||||
bot_user = self._client.user if self._client else None
|
||||
for emoji in (self.config.read_receipt_emoji, self.config.working_emoji):
|
||||
try:
|
||||
await msg_obj.remove_reaction(emoji, bot_user)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _cancel_all_typing(self) -> None:
|
||||
"""Stop all typing tasks."""
|
||||
channel_ids = list(self._typing_tasks)
|
||||
for channel_id in channel_ids:
|
||||
await self._stop_typing(channel_id)
|
||||
|
||||
async def _reset_runtime_state(self, close_client: bool) -> None:
|
||||
"""Reset client and typing state."""
|
||||
await self._cancel_all_typing()
|
||||
if close_client and self._client is not None and not self._client.is_closed():
|
||||
try:
|
||||
await self._client.close()
|
||||
except Exception as e:
|
||||
logger.warning("Discord client close failed: {}", e)
|
||||
self._client = None
|
||||
self._bot_user_id = None
|
||||
|
||||
@ -3,6 +3,8 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, TypeAlias
|
||||
|
||||
@ -28,8 +30,8 @@ try:
|
||||
RoomSendError,
|
||||
RoomTypingError,
|
||||
SyncError,
|
||||
UploadError,
|
||||
)
|
||||
UploadError, RoomSendResponse,
|
||||
)
|
||||
from nio.crypto.attachments import decrypt_attachment
|
||||
from nio.exceptions import EncryptionError
|
||||
except ImportError as e:
|
||||
@ -97,6 +99,22 @@ MATRIX_HTML_CLEANER = nh3.Cleaner(
|
||||
link_rel="noopener noreferrer",
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class _StreamBuf:
|
||||
"""
|
||||
Represents a buffer for managing LLM response stream data.
|
||||
|
||||
:ivar text: Stores the text content of the buffer.
|
||||
:type text: str
|
||||
:ivar event_id: Identifier for the associated event. None indicates no
|
||||
specific event association.
|
||||
:type event_id: str | None
|
||||
:ivar last_edit: Timestamp of the most recent edit to the buffer.
|
||||
:type last_edit: float
|
||||
"""
|
||||
text: str = ""
|
||||
event_id: str | None = None
|
||||
last_edit: float = 0.0
|
||||
|
||||
def _render_markdown_html(text: str) -> str | None:
|
||||
"""Render markdown to sanitized HTML; returns None for plain text."""
|
||||
@ -114,12 +132,47 @@ def _render_markdown_html(text: str) -> str | None:
|
||||
return formatted
|
||||
|
||||
|
||||
def _build_matrix_text_content(text: str) -> dict[str, object]:
|
||||
"""Build Matrix m.text payload with optional HTML formatted_body."""
|
||||
def _build_matrix_text_content(
|
||||
text: str,
|
||||
event_id: str | None = None,
|
||||
thread_relates_to: dict[str, object] | None = None,
|
||||
) -> dict[str, object]:
|
||||
"""
|
||||
Constructs and returns a dictionary representing the matrix text content with optional
|
||||
HTML formatting and reference to an existing event for replacement. This function is
|
||||
primarily used to create content payloads compatible with the Matrix messaging protocol.
|
||||
|
||||
:param text: The plain text content to include in the message.
|
||||
:type text: str
|
||||
:param event_id: Optional ID of the event to replace. If provided, the function will
|
||||
include information indicating that the message is a replacement of the specified
|
||||
event.
|
||||
:type event_id: str | None
|
||||
:param thread_relates_to: Optional Matrix thread relation metadata. For edits this is
|
||||
stored in ``m.new_content`` so the replacement remains in the same thread.
|
||||
:type thread_relates_to: dict[str, object] | None
|
||||
:return: A dictionary containing the matrix text content, potentially enriched with
|
||||
HTML formatting and replacement metadata if applicable.
|
||||
:rtype: dict[str, object]
|
||||
"""
|
||||
content: dict[str, object] = {"msgtype": "m.text", "body": text, "m.mentions": {}}
|
||||
if html := _render_markdown_html(text):
|
||||
content["format"] = MATRIX_HTML_FORMAT
|
||||
content["formatted_body"] = html
|
||||
if event_id:
|
||||
content["m.new_content"] = {
|
||||
"body": text,
|
||||
"msgtype": "m.text",
|
||||
}
|
||||
content["m.relates_to"] = {
|
||||
"rel_type": "m.replace",
|
||||
"event_id": event_id,
|
||||
}
|
||||
if thread_relates_to:
|
||||
content["m.new_content"]["m.relates_to"] = thread_relates_to
|
||||
elif thread_relates_to:
|
||||
content["m.relates_to"] = thread_relates_to
|
||||
|
||||
return content
|
||||
|
||||
|
||||
@ -159,7 +212,8 @@ class MatrixConfig(Base):
|
||||
allow_from: list[str] = Field(default_factory=list)
|
||||
group_policy: Literal["open", "mention", "allowlist"] = "open"
|
||||
group_allow_from: list[str] = Field(default_factory=list)
|
||||
allow_room_mentions: bool = False
|
||||
allow_room_mentions: bool = False,
|
||||
streaming: bool = False
|
||||
|
||||
|
||||
class MatrixChannel(BaseChannel):
|
||||
@ -167,6 +221,8 @@ class MatrixChannel(BaseChannel):
|
||||
|
||||
name = "matrix"
|
||||
display_name = "Matrix"
|
||||
_STREAM_EDIT_INTERVAL = 2 # min seconds between edit_message_text calls
|
||||
monotonic_time = time.monotonic
|
||||
|
||||
@classmethod
|
||||
def default_config(cls) -> dict[str, Any]:
|
||||
@ -192,6 +248,8 @@ class MatrixChannel(BaseChannel):
|
||||
)
|
||||
self._server_upload_limit_bytes: int | None = None
|
||||
self._server_upload_limit_checked = False
|
||||
self._stream_bufs: dict[str, _StreamBuf] = {}
|
||||
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start Matrix client and begin sync loop."""
|
||||
@ -297,14 +355,17 @@ class MatrixChannel(BaseChannel):
|
||||
room = getattr(self.client, "rooms", {}).get(room_id)
|
||||
return bool(getattr(room, "encrypted", False))
|
||||
|
||||
async def _send_room_content(self, room_id: str, content: dict[str, Any]) -> None:
|
||||
async def _send_room_content(self, room_id: str,
|
||||
content: dict[str, Any]) -> None | RoomSendResponse | RoomSendError:
|
||||
"""Send m.room.message with E2EE options."""
|
||||
if not self.client:
|
||||
return
|
||||
return None
|
||||
kwargs: dict[str, Any] = {"room_id": room_id, "message_type": "m.room.message", "content": content}
|
||||
|
||||
if self.config.e2ee_enabled:
|
||||
kwargs["ignore_unverified_devices"] = True
|
||||
await self.client.room_send(**kwargs)
|
||||
response = await self.client.room_send(**kwargs)
|
||||
return response
|
||||
|
||||
async def _resolve_server_upload_limit_bytes(self) -> int | None:
|
||||
"""Query homeserver upload limit once per channel lifecycle."""
|
||||
@ -414,6 +475,53 @@ class MatrixChannel(BaseChannel):
|
||||
if not is_progress:
|
||||
await self._stop_typing_keepalive(msg.chat_id, clear_typing=True)
|
||||
|
||||
async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
|
||||
meta = metadata or {}
|
||||
relates_to = self._build_thread_relates_to(metadata)
|
||||
|
||||
if meta.get("_stream_end"):
|
||||
buf = self._stream_bufs.pop(chat_id, None)
|
||||
if not buf or not buf.event_id or not buf.text:
|
||||
return
|
||||
|
||||
await self._stop_typing_keepalive(chat_id, clear_typing=True)
|
||||
|
||||
content = _build_matrix_text_content(
|
||||
buf.text,
|
||||
buf.event_id,
|
||||
thread_relates_to=relates_to,
|
||||
)
|
||||
await self._send_room_content(chat_id, content)
|
||||
return
|
||||
|
||||
buf = self._stream_bufs.get(chat_id)
|
||||
if buf is None:
|
||||
buf = _StreamBuf()
|
||||
self._stream_bufs[chat_id] = buf
|
||||
buf.text += delta
|
||||
|
||||
if not buf.text.strip():
|
||||
return
|
||||
|
||||
now = self.monotonic_time()
|
||||
|
||||
if not buf.last_edit or (now - buf.last_edit) >= self._STREAM_EDIT_INTERVAL:
|
||||
try:
|
||||
content = _build_matrix_text_content(
|
||||
buf.text,
|
||||
buf.event_id,
|
||||
thread_relates_to=relates_to,
|
||||
)
|
||||
response = await self._send_room_content(chat_id, content)
|
||||
buf.last_edit = now
|
||||
if not buf.event_id:
|
||||
# we are editing the same message all the time, so only the first time the event id needs to be set
|
||||
buf.event_id = response.event_id
|
||||
except Exception:
|
||||
await self._stop_typing_keepalive(chat_id, clear_typing=True)
|
||||
pass
|
||||
|
||||
|
||||
def _register_event_callbacks(self) -> None:
|
||||
self.client.add_event_callback(self._on_message, RoomMessageText)
|
||||
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER)
|
||||
|
||||
@ -491,6 +491,91 @@ def _migrate_cron_store(config: "Config") -> None:
|
||||
shutil.move(str(legacy_path), str(new_path))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OpenAI-Compatible API Server
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@app.command()
|
||||
def serve(
|
||||
port: int | None = typer.Option(None, "--port", "-p", help="API server port"),
|
||||
host: str | None = typer.Option(None, "--host", "-H", help="Bind address"),
|
||||
timeout: float | None = typer.Option(None, "--timeout", "-t", help="Per-request timeout (seconds)"),
|
||||
verbose: bool = typer.Option(False, "--verbose", "-v", help="Show nanobot runtime logs"),
|
||||
workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
|
||||
config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
|
||||
):
|
||||
"""Start the OpenAI-compatible API server (/v1/chat/completions)."""
|
||||
try:
|
||||
from aiohttp import web # noqa: F401
|
||||
except ImportError:
|
||||
console.print("[red]aiohttp is required. Install with: pip install 'nanobot-ai[api]'[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
from loguru import logger
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.api.server import create_app
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.session.manager import SessionManager
|
||||
|
||||
if verbose:
|
||||
logger.enable("nanobot")
|
||||
else:
|
||||
logger.disable("nanobot")
|
||||
|
||||
runtime_config = _load_runtime_config(config, workspace)
|
||||
api_cfg = runtime_config.api
|
||||
host = host if host is not None else api_cfg.host
|
||||
port = port if port is not None else api_cfg.port
|
||||
timeout = timeout if timeout is not None else api_cfg.timeout
|
||||
sync_workspace_templates(runtime_config.workspace_path)
|
||||
bus = MessageBus()
|
||||
provider = _make_provider(runtime_config)
|
||||
session_manager = SessionManager(runtime_config.workspace_path)
|
||||
agent_loop = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=runtime_config.workspace_path,
|
||||
model=runtime_config.agents.defaults.model,
|
||||
max_iterations=runtime_config.agents.defaults.max_tool_iterations,
|
||||
context_window_tokens=runtime_config.agents.defaults.context_window_tokens,
|
||||
web_search_config=runtime_config.tools.web.search,
|
||||
web_proxy=runtime_config.tools.web.proxy or None,
|
||||
exec_config=runtime_config.tools.exec,
|
||||
restrict_to_workspace=runtime_config.tools.restrict_to_workspace,
|
||||
session_manager=session_manager,
|
||||
mcp_servers=runtime_config.tools.mcp_servers,
|
||||
channels_config=runtime_config.channels,
|
||||
timezone=runtime_config.agents.defaults.timezone,
|
||||
)
|
||||
|
||||
model_name = runtime_config.agents.defaults.model
|
||||
console.print(f"{__logo__} Starting OpenAI-compatible API server")
|
||||
console.print(f" [cyan]Endpoint[/cyan] : http://{host}:{port}/v1/chat/completions")
|
||||
console.print(f" [cyan]Model[/cyan] : {model_name}")
|
||||
console.print(" [cyan]Session[/cyan] : api:default")
|
||||
console.print(f" [cyan]Timeout[/cyan] : {timeout}s")
|
||||
if host in {"0.0.0.0", "::"}:
|
||||
console.print(
|
||||
"[yellow]Warning:[/yellow] API is bound to all interfaces. "
|
||||
"Only do this behind a trusted network boundary, firewall, or reverse proxy."
|
||||
)
|
||||
console.print()
|
||||
|
||||
api_app = create_app(agent_loop, model_name=model_name, request_timeout=timeout)
|
||||
|
||||
async def on_startup(_app):
|
||||
await agent_loop._connect_mcp()
|
||||
|
||||
async def on_cleanup(_app):
|
||||
await agent_loop.close_mcp()
|
||||
|
||||
api_app.on_startup.append(on_startup)
|
||||
api_app.on_cleanup.append(on_cleanup)
|
||||
|
||||
web.run_app(api_app, host=host, port=port, print=lambda msg: logger.info(msg))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Gateway / Server
|
||||
# ============================================================================
|
||||
|
||||
@ -84,6 +84,16 @@ async def cmd_new(ctx: CommandContext) -> OutboundMessage:
|
||||
|
||||
async def cmd_help(ctx: CommandContext) -> OutboundMessage:
|
||||
"""Return available slash commands."""
|
||||
return OutboundMessage(
|
||||
channel=ctx.msg.channel,
|
||||
chat_id=ctx.msg.chat_id,
|
||||
content=build_help_text(),
|
||||
metadata={"render_as": "text"},
|
||||
)
|
||||
|
||||
|
||||
def build_help_text() -> str:
|
||||
"""Build canonical help text shared across channels."""
|
||||
lines = [
|
||||
"🐈 nanobot commands:",
|
||||
"/new — Start a new conversation",
|
||||
@ -92,12 +102,7 @@ async def cmd_help(ctx: CommandContext) -> OutboundMessage:
|
||||
"/status — Show bot status",
|
||||
"/help — Show available commands",
|
||||
]
|
||||
return OutboundMessage(
|
||||
channel=ctx.msg.channel,
|
||||
chat_id=ctx.msg.chat_id,
|
||||
content="\n".join(lines),
|
||||
metadata={"render_as": "text"},
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def register_builtin_commands(router: CommandRouter) -> None:
|
||||
|
||||
@ -96,6 +96,14 @@ class HeartbeatConfig(Base):
|
||||
keep_recent_messages: int = 8
|
||||
|
||||
|
||||
class ApiConfig(Base):
|
||||
"""OpenAI-compatible API server configuration."""
|
||||
|
||||
host: str = "127.0.0.1" # Safer default: local-only bind.
|
||||
port: int = 8900
|
||||
timeout: float = 120.0 # Per-request timeout in seconds.
|
||||
|
||||
|
||||
class GatewayConfig(Base):
|
||||
"""Gateway/server configuration."""
|
||||
|
||||
@ -156,6 +164,7 @@ class Config(BaseSettings):
|
||||
agents: AgentsConfig = Field(default_factory=AgentsConfig)
|
||||
channels: ChannelsConfig = Field(default_factory=ChannelsConfig)
|
||||
providers: ProvidersConfig = Field(default_factory=ProvidersConfig)
|
||||
api: ApiConfig = Field(default_factory=ApiConfig)
|
||||
gateway: GatewayConfig = Field(default_factory=GatewayConfig)
|
||||
tools: ToolsConfig = Field(default_factory=ToolsConfig)
|
||||
|
||||
|
||||
170
nanobot/nanobot.py
Normal file
170
nanobot/nanobot.py
Normal file
@ -0,0 +1,170 @@
|
||||
"""High-level programmatic interface to nanobot."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.hook import AgentHook
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class RunResult:
|
||||
"""Result of a single agent run."""
|
||||
|
||||
content: str
|
||||
tools_used: list[str]
|
||||
messages: list[dict[str, Any]]
|
||||
|
||||
|
||||
class Nanobot:
|
||||
"""Programmatic facade for running the nanobot agent.
|
||||
|
||||
Usage::
|
||||
|
||||
bot = Nanobot.from_config()
|
||||
result = await bot.run("Summarize this repo", hooks=[MyHook()])
|
||||
print(result.content)
|
||||
"""
|
||||
|
||||
def __init__(self, loop: AgentLoop) -> None:
|
||||
self._loop = loop
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config_path: str | Path | None = None,
|
||||
*,
|
||||
workspace: str | Path | None = None,
|
||||
) -> Nanobot:
|
||||
"""Create a Nanobot instance from a config file.
|
||||
|
||||
Args:
|
||||
config_path: Path to ``config.json``. Defaults to
|
||||
``~/.nanobot/config.json``.
|
||||
workspace: Override the workspace directory from config.
|
||||
"""
|
||||
from nanobot.config.loader import load_config
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
resolved: Path | None = None
|
||||
if config_path is not None:
|
||||
resolved = Path(config_path).expanduser().resolve()
|
||||
if not resolved.exists():
|
||||
raise FileNotFoundError(f"Config not found: {resolved}")
|
||||
|
||||
config: Config = load_config(resolved)
|
||||
if workspace is not None:
|
||||
config.agents.defaults.workspace = str(
|
||||
Path(workspace).expanduser().resolve()
|
||||
)
|
||||
|
||||
provider = _make_provider(config)
|
||||
bus = MessageBus()
|
||||
defaults = config.agents.defaults
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=config.workspace_path,
|
||||
model=defaults.model,
|
||||
max_iterations=defaults.max_tool_iterations,
|
||||
context_window_tokens=defaults.context_window_tokens,
|
||||
web_search_config=config.tools.web.search,
|
||||
web_proxy=config.tools.web.proxy or None,
|
||||
exec_config=config.tools.exec,
|
||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||
mcp_servers=config.tools.mcp_servers,
|
||||
timezone=defaults.timezone,
|
||||
)
|
||||
return cls(loop)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
session_key: str = "sdk:default",
|
||||
hooks: list[AgentHook] | None = None,
|
||||
) -> RunResult:
|
||||
"""Run the agent once and return the result.
|
||||
|
||||
Args:
|
||||
message: The user message to process.
|
||||
session_key: Session identifier for conversation isolation.
|
||||
Different keys get independent history.
|
||||
hooks: Optional lifecycle hooks for this run.
|
||||
"""
|
||||
prev = self._loop._extra_hooks
|
||||
if hooks is not None:
|
||||
self._loop._extra_hooks = list(hooks)
|
||||
try:
|
||||
response = await self._loop.process_direct(
|
||||
message, session_key=session_key,
|
||||
)
|
||||
finally:
|
||||
self._loop._extra_hooks = prev
|
||||
|
||||
content = (response.content if response else None) or ""
|
||||
return RunResult(content=content, tools_used=[], messages=[])
|
||||
|
||||
|
||||
def _make_provider(config: Any) -> Any:
|
||||
"""Create the LLM provider from config (extracted from CLI)."""
|
||||
from nanobot.providers.base import GenerationSettings
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
model = config.agents.defaults.model
|
||||
provider_name = config.get_provider_name(model)
|
||||
p = config.get_provider(model)
|
||||
spec = find_by_name(provider_name) if provider_name else None
|
||||
backend = spec.backend if spec else "openai_compat"
|
||||
|
||||
if backend == "azure_openai":
|
||||
if not p or not p.api_key or not p.api_base:
|
||||
raise ValueError("Azure OpenAI requires api_key and api_base in config.")
|
||||
elif backend == "openai_compat" and not model.startswith("bedrock/"):
|
||||
needs_key = not (p and p.api_key)
|
||||
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
|
||||
if needs_key and not exempt:
|
||||
raise ValueError(f"No API key configured for provider '{provider_name}'.")
|
||||
|
||||
if backend == "openai_codex":
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
|
||||
provider = OpenAICodexProvider(default_model=model)
|
||||
elif backend == "azure_openai":
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key=p.api_key, api_base=p.api_base, default_model=model
|
||||
)
|
||||
elif backend == "anthropic":
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
|
||||
provider = AnthropicProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
)
|
||||
else:
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
spec=spec,
|
||||
)
|
||||
|
||||
defaults = config.agents.defaults
|
||||
provider.generation = GenerationSettings(
|
||||
temperature=defaults.temperature,
|
||||
max_tokens=defaults.max_tokens,
|
||||
reasoning_effort=defaults.reasoning_effort,
|
||||
)
|
||||
return provider
|
||||
@ -124,8 +124,8 @@ def build_assistant_message(
|
||||
msg: dict[str, Any] = {"role": "assistant", "content": content}
|
||||
if tool_calls:
|
||||
msg["tool_calls"] = tool_calls
|
||||
if reasoning_content is not None:
|
||||
msg["reasoning_content"] = reasoning_content
|
||||
if reasoning_content is not None or thinking_blocks:
|
||||
msg["reasoning_content"] = reasoning_content if reasoning_content is not None else ""
|
||||
if thinking_blocks:
|
||||
msg["thinking_blocks"] = thinking_blocks
|
||||
return msg
|
||||
|
||||
@ -51,6 +51,9 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
api = [
|
||||
"aiohttp>=3.9.0,<4.0.0",
|
||||
]
|
||||
wecom = [
|
||||
"wecom-aibot-sdk-python>=0.1.5",
|
||||
]
|
||||
@ -64,12 +67,16 @@ matrix = [
|
||||
"mistune>=3.0.0,<4.0.0",
|
||||
"nh3>=0.2.17,<1.0.0",
|
||||
]
|
||||
discord = [
|
||||
"discord.py>=2.5.2,<3.0.0",
|
||||
]
|
||||
langsmith = [
|
||||
"langsmith>=0.1.0",
|
||||
]
|
||||
dev = [
|
||||
"pytest>=9.0.0,<10.0.0",
|
||||
"pytest-asyncio>=1.3.0,<2.0.0",
|
||||
"aiohttp>=3.9.0,<4.0.0",
|
||||
"pytest-cov>=6.0.0,<7.0.0",
|
||||
"ruff>=0.1.0",
|
||||
]
|
||||
|
||||
351
tests/agent/test_hook_composite.py
Normal file
351
tests/agent/test_hook_composite.py
Normal file
@ -0,0 +1,351 @@
|
||||
"""Tests for CompositeHook fan-out, error isolation, and integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
||||
|
||||
|
||||
def _ctx() -> AgentHookContext:
|
||||
return AgentHookContext(iteration=0, messages=[])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fan-out: every hook is called in order
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_fans_out_before_iteration():
|
||||
calls: list[str] = []
|
||||
|
||||
class H(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
calls.append(f"A:{context.iteration}")
|
||||
|
||||
class H2(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
calls.append(f"B:{context.iteration}")
|
||||
|
||||
hook = CompositeHook([H(), H2()])
|
||||
ctx = _ctx()
|
||||
await hook.before_iteration(ctx)
|
||||
assert calls == ["A:0", "B:0"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_fans_out_all_async_methods():
|
||||
"""Verify all async methods fan out to every hook."""
|
||||
events: list[str] = []
|
||||
|
||||
class RecordingHook(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
events.append("before_iteration")
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
events.append(f"on_stream:{delta}")
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
events.append(f"on_stream_end:{resuming}")
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
events.append("before_execute_tools")
|
||||
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
||||
events.append("after_iteration")
|
||||
|
||||
hook = CompositeHook([RecordingHook(), RecordingHook()])
|
||||
ctx = _ctx()
|
||||
|
||||
await hook.before_iteration(ctx)
|
||||
await hook.on_stream(ctx, "hi")
|
||||
await hook.on_stream_end(ctx, resuming=True)
|
||||
await hook.before_execute_tools(ctx)
|
||||
await hook.after_iteration(ctx)
|
||||
|
||||
assert events == [
|
||||
"before_iteration", "before_iteration",
|
||||
"on_stream:hi", "on_stream:hi",
|
||||
"on_stream_end:True", "on_stream_end:True",
|
||||
"before_execute_tools", "before_execute_tools",
|
||||
"after_iteration", "after_iteration",
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error isolation: one hook raises, others still run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_error_isolation_before_iteration():
|
||||
calls: list[str] = []
|
||||
|
||||
class Bad(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
class Good(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
calls.append("good")
|
||||
|
||||
hook = CompositeHook([Bad(), Good()])
|
||||
await hook.before_iteration(_ctx())
|
||||
assert calls == ["good"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_error_isolation_on_stream():
|
||||
calls: list[str] = []
|
||||
|
||||
class Bad(AgentHook):
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
raise RuntimeError("stream-boom")
|
||||
|
||||
class Good(AgentHook):
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
calls.append(delta)
|
||||
|
||||
hook = CompositeHook([Bad(), Good()])
|
||||
await hook.on_stream(_ctx(), "delta")
|
||||
assert calls == ["delta"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_error_isolation_all_async():
|
||||
"""Error isolation for on_stream_end, before_execute_tools, after_iteration."""
|
||||
calls: list[str] = []
|
||||
|
||||
class Bad(AgentHook):
|
||||
async def on_stream_end(self, context, *, resuming):
|
||||
raise RuntimeError("err")
|
||||
async def before_execute_tools(self, context):
|
||||
raise RuntimeError("err")
|
||||
async def after_iteration(self, context):
|
||||
raise RuntimeError("err")
|
||||
|
||||
class Good(AgentHook):
|
||||
async def on_stream_end(self, context, *, resuming):
|
||||
calls.append("on_stream_end")
|
||||
async def before_execute_tools(self, context):
|
||||
calls.append("before_execute_tools")
|
||||
async def after_iteration(self, context):
|
||||
calls.append("after_iteration")
|
||||
|
||||
hook = CompositeHook([Bad(), Good()])
|
||||
ctx = _ctx()
|
||||
await hook.on_stream_end(ctx, resuming=False)
|
||||
await hook.before_execute_tools(ctx)
|
||||
await hook.after_iteration(ctx)
|
||||
assert calls == ["on_stream_end", "before_execute_tools", "after_iteration"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# finalize_content: pipeline semantics (no error isolation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_composite_finalize_content_pipeline():
|
||||
class Upper(AgentHook):
|
||||
def finalize_content(self, context, content):
|
||||
return content.upper() if content else content
|
||||
|
||||
class Suffix(AgentHook):
|
||||
def finalize_content(self, context, content):
|
||||
return (content + "!") if content else content
|
||||
|
||||
hook = CompositeHook([Upper(), Suffix()])
|
||||
result = hook.finalize_content(_ctx(), "hello")
|
||||
assert result == "HELLO!"
|
||||
|
||||
|
||||
def test_composite_finalize_content_none_passthrough():
|
||||
hook = CompositeHook([AgentHook()])
|
||||
assert hook.finalize_content(_ctx(), None) is None
|
||||
|
||||
|
||||
def test_composite_finalize_content_ordering():
|
||||
"""First hook transforms first, result feeds second hook."""
|
||||
steps: list[str] = []
|
||||
|
||||
class H1(AgentHook):
|
||||
def finalize_content(self, context, content):
|
||||
steps.append(f"H1:{content}")
|
||||
return content.upper()
|
||||
|
||||
class H2(AgentHook):
|
||||
def finalize_content(self, context, content):
|
||||
steps.append(f"H2:{content}")
|
||||
return content + "!"
|
||||
|
||||
hook = CompositeHook([H1(), H2()])
|
||||
result = hook.finalize_content(_ctx(), "hi")
|
||||
assert result == "HI!"
|
||||
assert steps == ["H1:hi", "H2:HI"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# wants_streaming: any-semantics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_composite_wants_streaming_any_true():
|
||||
class No(AgentHook):
|
||||
def wants_streaming(self):
|
||||
return False
|
||||
|
||||
class Yes(AgentHook):
|
||||
def wants_streaming(self):
|
||||
return True
|
||||
|
||||
hook = CompositeHook([No(), Yes(), No()])
|
||||
assert hook.wants_streaming() is True
|
||||
|
||||
|
||||
def test_composite_wants_streaming_all_false():
|
||||
hook = CompositeHook([AgentHook(), AgentHook()])
|
||||
assert hook.wants_streaming() is False
|
||||
|
||||
|
||||
def test_composite_wants_streaming_empty():
|
||||
hook = CompositeHook([])
|
||||
assert hook.wants_streaming() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Empty hooks list: behaves like no-op AgentHook
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_empty_hooks_no_ops():
|
||||
hook = CompositeHook([])
|
||||
ctx = _ctx()
|
||||
await hook.before_iteration(ctx)
|
||||
await hook.on_stream(ctx, "delta")
|
||||
await hook.on_stream_end(ctx, resuming=False)
|
||||
await hook.before_execute_tools(ctx)
|
||||
await hook.after_iteration(ctx)
|
||||
assert hook.finalize_content(ctx, "test") == "test"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: AgentLoop with extra hooks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_loop(tmp_path, hooks=None):
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.generation.max_tokens = 4096
|
||||
|
||||
with patch("nanobot.agent.loop.ContextBuilder"), \
|
||||
patch("nanobot.agent.loop.SessionManager"), \
|
||||
patch("nanobot.agent.loop.SubagentManager") as mock_sub_mgr, \
|
||||
patch("nanobot.agent.loop.MemoryConsolidator"):
|
||||
mock_sub_mgr.return_value.cancel_by_session = AsyncMock(return_value=0)
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, hooks=hooks,
|
||||
)
|
||||
return loop
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_loop_extra_hook_receives_calls(tmp_path):
|
||||
"""Extra hook passed to AgentLoop is called alongside core LoopHook."""
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
events: list[str] = []
|
||||
|
||||
class TrackingHook(AgentHook):
|
||||
async def before_iteration(self, context):
|
||||
events.append(f"before_iter:{context.iteration}")
|
||||
|
||||
async def after_iteration(self, context):
|
||||
events.append(f"after_iter:{context.iteration}")
|
||||
|
||||
loop = _make_loop(tmp_path, hooks=[TrackingHook()])
|
||||
loop.provider.chat_with_retry = AsyncMock(
|
||||
return_value=LLMResponse(content="done", tool_calls=[], usage={})
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
content, tools_used, messages = await loop._run_agent_loop(
|
||||
[{"role": "user", "content": "hi"}]
|
||||
)
|
||||
|
||||
assert content == "done"
|
||||
assert "before_iter:0" in events
|
||||
assert "after_iter:0" in events
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_loop_extra_hook_error_isolation(tmp_path):
|
||||
"""A faulty extra hook does not crash the agent loop."""
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
class BadHook(AgentHook):
|
||||
async def before_iteration(self, context):
|
||||
raise RuntimeError("I am broken")
|
||||
|
||||
loop = _make_loop(tmp_path, hooks=[BadHook()])
|
||||
loop.provider.chat_with_retry = AsyncMock(
|
||||
return_value=LLMResponse(content="still works", tool_calls=[], usage={})
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
content, _, _ = await loop._run_agent_loop(
|
||||
[{"role": "user", "content": "hi"}]
|
||||
)
|
||||
|
||||
assert content == "still works"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_loop_extra_hooks_do_not_swallow_loop_hook_errors(tmp_path):
|
||||
"""Extra hooks must not change the core LoopHook failure behavior."""
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
loop = _make_loop(tmp_path, hooks=[AgentHook()])
|
||||
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="working",
|
||||
tool_calls=[ToolCallRequest(id="c1", name="list_dir", arguments={"path": "."})],
|
||||
usage={},
|
||||
))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.tools.execute = AsyncMock(return_value="ok")
|
||||
|
||||
async def bad_progress(*args, **kwargs):
|
||||
raise RuntimeError("progress failed")
|
||||
|
||||
with pytest.raises(RuntimeError, match="progress failed"):
|
||||
await loop._run_agent_loop([], on_progress=bad_progress)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_loop_no_hooks_backward_compat(tmp_path):
|
||||
"""Without hooks param, behavior is identical to before."""
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
loop = _make_loop(tmp_path)
|
||||
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="working",
|
||||
tool_calls=[ToolCallRequest(id="c1", name="list_dir", arguments={"path": "."})],
|
||||
))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.tools.execute = AsyncMock(return_value="ok")
|
||||
loop.max_iterations = 2
|
||||
|
||||
content, tools_used, _ = await loop._run_agent_loop([])
|
||||
assert content == (
|
||||
"I reached the maximum number of tool call iterations (2) "
|
||||
"without completing the task. You can try breaking the task into smaller steps."
|
||||
)
|
||||
assert tools_used == ["list_dir", "list_dir"]
|
||||
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -116,6 +117,43 @@ class TestDispatch:
|
||||
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||
assert out.content == "hi"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_streaming_preserves_message_metadata(self):
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
loop, bus = _make_loop()
|
||||
msg = InboundMessage(
|
||||
channel="matrix",
|
||||
sender_id="u1",
|
||||
chat_id="!room:matrix.org",
|
||||
content="hello",
|
||||
metadata={
|
||||
"_wants_stream": True,
|
||||
"thread_root_event_id": "$root1",
|
||||
"thread_reply_to_event_id": "$reply1",
|
||||
},
|
||||
)
|
||||
|
||||
async def fake_process(_msg, *, on_stream=None, on_stream_end=None, **kwargs):
|
||||
assert on_stream is not None
|
||||
assert on_stream_end is not None
|
||||
await on_stream("hi")
|
||||
await on_stream_end(resuming=False)
|
||||
return None
|
||||
|
||||
loop._process_message = fake_process
|
||||
|
||||
await loop._dispatch(msg)
|
||||
first = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||
second = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||
|
||||
assert first.metadata["thread_root_event_id"] == "$root1"
|
||||
assert first.metadata["thread_reply_to_event_id"] == "$reply1"
|
||||
assert first.metadata["_stream_delta"] is True
|
||||
assert second.metadata["thread_root_event_id"] == "$root1"
|
||||
assert second.metadata["thread_reply_to_event_id"] == "$reply1"
|
||||
assert second.metadata["_stream_end"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_processing_lock_serializes(self):
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
@ -222,6 +260,39 @@ class TestSubagentCancellation:
|
||||
assert assistant_messages[0]["reasoning_content"] == "hidden reasoning"
|
||||
assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subagent_exec_tool_not_registered_when_disabled(self, tmp_path):
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.config.schema import ExecToolConfig
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
mgr = SubagentManager(
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
bus=bus,
|
||||
exec_config=ExecToolConfig(enable=False),
|
||||
)
|
||||
mgr._announce_result = AsyncMock()
|
||||
|
||||
async def fake_run(spec):
|
||||
assert spec.tools.get("exec") is None
|
||||
return SimpleNamespace(
|
||||
stop_reason="done",
|
||||
final_content="done",
|
||||
error=None,
|
||||
tool_events=[],
|
||||
)
|
||||
|
||||
mgr.runner.run = AsyncMock(side_effect=fake_run)
|
||||
|
||||
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||
|
||||
mgr.runner.run.assert_awaited_once()
|
||||
mgr._announce_result.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subagent_announces_error_when_tool_execution_fails(self, monkeypatch, tmp_path):
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
|
||||
676
tests/channels/test_discord_channel.py
Normal file
676
tests/channels/test_discord_channel.py
Normal file
@ -0,0 +1,676 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
discord = pytest.importorskip("discord")
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.discord import DiscordBotClient, DiscordChannel, DiscordConfig
|
||||
from nanobot.command.builtin import build_help_text
|
||||
|
||||
|
||||
# Minimal Discord client test double used to control startup/readiness behavior.
|
||||
class _FakeDiscordClient:
|
||||
instances: list["_FakeDiscordClient"] = []
|
||||
start_error: Exception | None = None
|
||||
|
||||
def __init__(self, owner, *, intents) -> None:
|
||||
self.owner = owner
|
||||
self.intents = intents
|
||||
self.closed = False
|
||||
self.ready = True
|
||||
self.channels: dict[int, object] = {}
|
||||
self.user = SimpleNamespace(id=999)
|
||||
self.__class__.instances.append(self)
|
||||
|
||||
async def start(self, token: str) -> None:
|
||||
self.token = token
|
||||
if self.__class__.start_error is not None:
|
||||
raise self.__class__.start_error
|
||||
|
||||
async def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
return self.closed
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
return self.ready
|
||||
|
||||
def get_channel(self, channel_id: int):
|
||||
return self.channels.get(channel_id)
|
||||
|
||||
async def send_outbound(self, msg: OutboundMessage) -> None:
|
||||
channel = self.get_channel(int(msg.chat_id))
|
||||
if channel is None:
|
||||
return
|
||||
await channel.send(content=msg.content)
|
||||
|
||||
|
||||
class _FakeAttachment:
|
||||
# Attachment double that can simulate successful or failing save() calls.
|
||||
def __init__(self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False) -> None:
|
||||
self.id = attachment_id
|
||||
self.filename = filename
|
||||
self.size = size
|
||||
self._fail = fail
|
||||
|
||||
async def save(self, path: str | Path) -> None:
|
||||
if self._fail:
|
||||
raise RuntimeError("save failed")
|
||||
Path(path).write_bytes(b"attachment")
|
||||
|
||||
|
||||
class _FakePartialMessage:
|
||||
# Lightweight stand-in for Discord partial message references used in replies.
|
||||
def __init__(self, message_id: int) -> None:
|
||||
self.id = message_id
|
||||
|
||||
|
||||
class _FakeChannel:
|
||||
# Channel double that records outbound payloads and typing activity.
|
||||
def __init__(self, channel_id: int = 123) -> None:
|
||||
self.id = channel_id
|
||||
self.sent_payloads: list[dict] = []
|
||||
self.trigger_typing_calls = 0
|
||||
self.typing_enter_hook = None
|
||||
|
||||
async def send(self, **kwargs) -> None:
|
||||
payload = dict(kwargs)
|
||||
if "file" in payload:
|
||||
payload["file_name"] = payload["file"].filename
|
||||
del payload["file"]
|
||||
self.sent_payloads.append(payload)
|
||||
|
||||
def get_partial_message(self, message_id: int) -> _FakePartialMessage:
|
||||
return _FakePartialMessage(message_id)
|
||||
|
||||
def typing(self):
|
||||
channel = self
|
||||
|
||||
class _TypingContext:
|
||||
async def __aenter__(self):
|
||||
channel.trigger_typing_calls += 1
|
||||
if channel.typing_enter_hook is not None:
|
||||
await channel.typing_enter_hook()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
return _TypingContext()
|
||||
|
||||
|
||||
class _FakeInteractionResponse:
|
||||
def __init__(self) -> None:
|
||||
self.messages: list[dict] = []
|
||||
self._done = False
|
||||
|
||||
async def send_message(self, content: str, *, ephemeral: bool = False) -> None:
|
||||
self.messages.append({"content": content, "ephemeral": ephemeral})
|
||||
self._done = True
|
||||
|
||||
def is_done(self) -> bool:
|
||||
return self._done
|
||||
|
||||
|
||||
def _make_interaction(
|
||||
*,
|
||||
user_id: int = 123,
|
||||
channel_id: int | None = 456,
|
||||
guild_id: int | None = None,
|
||||
interaction_id: int = 999,
|
||||
):
|
||||
return SimpleNamespace(
|
||||
user=SimpleNamespace(id=user_id),
|
||||
channel_id=channel_id,
|
||||
guild_id=guild_id,
|
||||
id=interaction_id,
|
||||
command=SimpleNamespace(qualified_name="new"),
|
||||
response=_FakeInteractionResponse(),
|
||||
)
|
||||
|
||||
|
||||
def _make_message(
|
||||
*,
|
||||
author_id: int = 123,
|
||||
author_bot: bool = False,
|
||||
channel_id: int = 456,
|
||||
message_id: int = 789,
|
||||
content: str = "hello",
|
||||
guild_id: int | None = None,
|
||||
mentions: list[object] | None = None,
|
||||
attachments: list[object] | None = None,
|
||||
reply_to: int | None = None,
|
||||
):
|
||||
# Factory for incoming Discord message objects with optional guild/reply/attachments.
|
||||
guild = SimpleNamespace(id=guild_id) if guild_id is not None else None
|
||||
reference = SimpleNamespace(message_id=reply_to) if reply_to is not None else None
|
||||
return SimpleNamespace(
|
||||
author=SimpleNamespace(id=author_id, bot=author_bot),
|
||||
channel=_FakeChannel(channel_id),
|
||||
content=content,
|
||||
guild=guild,
|
||||
mentions=mentions or [],
|
||||
attachments=attachments or [],
|
||||
reference=reference,
|
||||
id=message_id,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_returns_when_token_missing() -> None:
|
||||
# If no token is configured, startup should no-op and leave channel stopped.
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
|
||||
await channel.start()
|
||||
|
||||
assert channel.is_running is False
|
||||
assert channel._client is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_returns_when_discord_dependency_missing(monkeypatch) -> None:
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, token="token", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.channels.discord.DISCORD_AVAILABLE", False)
|
||||
|
||||
await channel.start()
|
||||
|
||||
assert channel.is_running is False
|
||||
assert channel._client is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_handles_client_construction_failure(monkeypatch) -> None:
|
||||
# Construction errors from the Discord client should be swallowed and keep state clean.
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, token="token", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
|
||||
def _boom(owner, *, intents):
|
||||
raise RuntimeError("bad client")
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _boom)
|
||||
|
||||
await channel.start()
|
||||
|
||||
assert channel.is_running is False
|
||||
assert channel._client is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_handles_client_start_failure(monkeypatch) -> None:
|
||||
# If client.start fails, the partially created client should be closed and detached.
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, token="token", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
|
||||
_FakeDiscordClient.instances.clear()
|
||||
_FakeDiscordClient.start_error = RuntimeError("connect failed")
|
||||
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
|
||||
|
||||
await channel.start()
|
||||
|
||||
assert channel.is_running is False
|
||||
assert channel._client is None
|
||||
assert _FakeDiscordClient.instances[0].intents.value == channel.config.intents
|
||||
assert _FakeDiscordClient.instances[0].closed is True
|
||||
|
||||
_FakeDiscordClient.start_error = None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_is_safe_after_partial_start(monkeypatch) -> None:
|
||||
# stop() should close/discard the client even when startup was only partially completed.
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, token="token", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
client = _FakeDiscordClient(channel, intents=None)
|
||||
channel._client = client
|
||||
channel._running = True
|
||||
|
||||
await channel.stop()
|
||||
|
||||
assert channel.is_running is False
|
||||
assert client.closed is True
|
||||
assert channel._client is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_ignores_bot_messages() -> None:
|
||||
# Incoming bot-authored messages must be ignored to prevent feedback loops.
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
channel._handle_message = lambda **kwargs: handled.append(kwargs) # type: ignore[method-assign]
|
||||
|
||||
await channel._on_message(_make_message(author_bot=True))
|
||||
|
||||
assert handled == []
|
||||
|
||||
# If inbound handling raises, typing should be stopped for that channel.
|
||||
async def fail_handle(**kwargs) -> None:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
channel._handle_message = fail_handle # type: ignore[method-assign]
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
await channel._on_message(_make_message(author_id=123, channel_id=456))
|
||||
|
||||
assert channel._typing_tasks == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_accepts_allowlisted_dm() -> None:
|
||||
# Allowed direct messages should be forwarded with normalized metadata.
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
|
||||
await channel._on_message(_make_message(author_id=123, channel_id=456, message_id=789))
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["chat_id"] == "456"
|
||||
assert handled[0]["metadata"] == {"message_id": "789", "guild_id": None, "reply_to": None}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_ignores_unmentioned_guild_message() -> None:
|
||||
# With mention-only group policy, guild messages without a bot mention are dropped.
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._bot_user_id = "999"
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
|
||||
await channel._on_message(_make_message(guild_id=1, content="hello everyone"))
|
||||
|
||||
assert handled == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_accepts_mentioned_guild_message() -> None:
|
||||
# Mentioned guild messages should be accepted and preserve reply threading metadata.
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._bot_user_id = "999"
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
|
||||
await channel._on_message(
|
||||
_make_message(
|
||||
guild_id=1,
|
||||
content="<@999> hello",
|
||||
mentions=[SimpleNamespace(id=999)],
|
||||
reply_to=321,
|
||||
)
|
||||
)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["metadata"]["reply_to"] == "321"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_downloads_attachments(tmp_path, monkeypatch) -> None:
|
||||
# Attachment downloads should be saved and referenced in forwarded content/media.
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path)
|
||||
|
||||
await channel._on_message(
|
||||
_make_message(
|
||||
attachments=[_FakeAttachment(12, "photo.png")],
|
||||
content="see file",
|
||||
)
|
||||
)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["media"] == [str(tmp_path / "12_photo.png")]
|
||||
assert "[attachment:" in handled[0]["content"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_marks_failed_attachment_download(tmp_path, monkeypatch) -> None:
|
||||
# Failed attachment downloads should emit a readable placeholder and no media path.
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path)
|
||||
|
||||
await channel._on_message(
|
||||
_make_message(
|
||||
attachments=[_FakeAttachment(12, "photo.png", fail=True)],
|
||||
content="",
|
||||
)
|
||||
)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["media"] == []
|
||||
assert handled[0]["content"] == "[attachment: photo.png - download failed]"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_warns_when_client_not_ready() -> None:
|
||||
# Sending without a running/ready client should be a safe no-op.
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
|
||||
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
|
||||
|
||||
assert channel._typing_tasks == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_skips_when_channel_not_cached() -> None:
|
||||
# Outbound sends should be skipped when the destination channel is not resolvable.
|
||||
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
client = DiscordBotClient(owner, intents=discord.Intents.none())
|
||||
fetch_calls: list[int] = []
|
||||
|
||||
async def fetch_channel(channel_id: int):
|
||||
fetch_calls.append(channel_id)
|
||||
raise RuntimeError("not found")
|
||||
|
||||
client.fetch_channel = fetch_channel # type: ignore[method-assign]
|
||||
|
||||
await client.send_outbound(OutboundMessage(channel="discord", chat_id="123", content="hello"))
|
||||
|
||||
assert client.get_channel(123) is None
|
||||
assert fetch_calls == [123]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_fetches_channel_when_not_cached() -> None:
|
||||
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
client = DiscordBotClient(owner, intents=discord.Intents.none())
|
||||
target = _FakeChannel(channel_id=123)
|
||||
|
||||
async def fetch_channel(channel_id: int):
|
||||
return target if channel_id == 123 else None
|
||||
|
||||
client.fetch_channel = fetch_channel # type: ignore[method-assign]
|
||||
|
||||
await client.send_outbound(OutboundMessage(channel="discord", chat_id="123", content="hello"))
|
||||
|
||||
assert target.sent_payloads == [{"content": "hello"}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slash_new_forwards_when_user_is_allowlisted() -> None:
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
client = DiscordBotClient(channel, intents=discord.Intents.none())
|
||||
interaction = _make_interaction(user_id=123, channel_id=456, interaction_id=321)
|
||||
|
||||
new_cmd = client.tree.get_command("new")
|
||||
assert new_cmd is not None
|
||||
await new_cmd.callback(interaction)
|
||||
|
||||
assert interaction.response.messages == [
|
||||
{"content": "Processing /new...", "ephemeral": True}
|
||||
]
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["content"] == "/new"
|
||||
assert handled[0]["sender_id"] == "123"
|
||||
assert handled[0]["chat_id"] == "456"
|
||||
assert handled[0]["metadata"]["interaction_id"] == "321"
|
||||
assert handled[0]["metadata"]["is_slash_command"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slash_new_is_blocked_for_disallowed_user() -> None:
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["999"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
client = DiscordBotClient(channel, intents=discord.Intents.none())
|
||||
interaction = _make_interaction(user_id=123, channel_id=456)
|
||||
|
||||
new_cmd = client.tree.get_command("new")
|
||||
assert new_cmd is not None
|
||||
await new_cmd.callback(interaction)
|
||||
|
||||
assert interaction.response.messages == [
|
||||
{"content": "You are not allowed to use this bot.", "ephemeral": True}
|
||||
]
|
||||
assert handled == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize("slash_name", ["stop", "restart", "status"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_slash_commands_forward_via_handle_message(slash_name: str) -> None:
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
client = DiscordBotClient(channel, intents=discord.Intents.none())
|
||||
interaction = _make_interaction()
|
||||
interaction.command.qualified_name = slash_name
|
||||
|
||||
cmd = client.tree.get_command(slash_name)
|
||||
assert cmd is not None
|
||||
await cmd.callback(interaction)
|
||||
|
||||
assert interaction.response.messages == [
|
||||
{"content": f"Processing /{slash_name}...", "ephemeral": True}
|
||||
]
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["content"] == f"/{slash_name}"
|
||||
assert handled[0]["metadata"]["is_slash_command"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slash_help_returns_ephemeral_help_text() -> None:
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
client = DiscordBotClient(channel, intents=discord.Intents.none())
|
||||
interaction = _make_interaction()
|
||||
interaction.command.qualified_name = "help"
|
||||
|
||||
help_cmd = client.tree.get_command("help")
|
||||
assert help_cmd is not None
|
||||
await help_cmd.callback(interaction)
|
||||
|
||||
assert interaction.response.messages == [
|
||||
{"content": build_help_text(), "ephemeral": True}
|
||||
]
|
||||
assert handled == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_send_outbound_chunks_text_replies_and_uploads_files(tmp_path) -> None:
|
||||
# Outbound payloads should upload files, attach reply references, and chunk long text.
|
||||
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
client = DiscordBotClient(owner, intents=discord.Intents.none())
|
||||
target = _FakeChannel(channel_id=123)
|
||||
client.get_channel = lambda channel_id: target if channel_id == 123 else None # type: ignore[method-assign]
|
||||
|
||||
file_path = tmp_path / "demo.txt"
|
||||
file_path.write_text("hi")
|
||||
|
||||
await client.send_outbound(
|
||||
OutboundMessage(
|
||||
channel="discord",
|
||||
chat_id="123",
|
||||
content="a" * 2100,
|
||||
reply_to="55",
|
||||
media=[str(file_path)],
|
||||
)
|
||||
)
|
||||
|
||||
assert len(target.sent_payloads) == 3
|
||||
assert target.sent_payloads[0]["file_name"] == "demo.txt"
|
||||
assert target.sent_payloads[0]["reference"].id == 55
|
||||
assert target.sent_payloads[1]["content"] == "a" * 2000
|
||||
assert target.sent_payloads[2]["content"] == "a" * 100
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_send_outbound_reports_failed_attachments_when_no_text(tmp_path) -> None:
|
||||
# If all attachment sends fail and no text exists, emit a failure placeholder message.
|
||||
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
client = DiscordBotClient(owner, intents=discord.Intents.none())
|
||||
target = _FakeChannel(channel_id=123)
|
||||
client.get_channel = lambda channel_id: target if channel_id == 123 else None # type: ignore[method-assign]
|
||||
|
||||
missing_file = tmp_path / "missing.txt"
|
||||
|
||||
await client.send_outbound(
|
||||
OutboundMessage(
|
||||
channel="discord",
|
||||
chat_id="123",
|
||||
content="",
|
||||
media=[str(missing_file)],
|
||||
)
|
||||
)
|
||||
|
||||
assert target.sent_payloads == [{"content": "[attachment: missing.txt - send failed]"}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_stops_typing_after_send() -> None:
|
||||
# Active typing indicators should be cancelled/cleared after a successful send.
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
client = _FakeDiscordClient(channel, intents=None)
|
||||
channel._client = client
|
||||
channel._running = True
|
||||
|
||||
start = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def slow_typing() -> None:
|
||||
start.set()
|
||||
await release.wait()
|
||||
|
||||
typing_channel = _FakeChannel(channel_id=123)
|
||||
typing_channel.typing_enter_hook = slow_typing
|
||||
|
||||
await channel._start_typing(typing_channel)
|
||||
await start.wait()
|
||||
|
||||
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
|
||||
release.set()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert channel._typing_tasks == {}
|
||||
|
||||
# Progress messages should keep typing active until a final (non-progress) send.
|
||||
start = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def slow_typing_progress() -> None:
|
||||
start.set()
|
||||
await release.wait()
|
||||
|
||||
typing_channel = _FakeChannel(channel_id=123)
|
||||
typing_channel.typing_enter_hook = slow_typing_progress
|
||||
|
||||
await channel._start_typing(typing_channel)
|
||||
await start.wait()
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="discord",
|
||||
chat_id="123",
|
||||
content="progress",
|
||||
metadata={"_progress": True},
|
||||
)
|
||||
)
|
||||
|
||||
assert "123" in channel._typing_tasks
|
||||
|
||||
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="final"))
|
||||
release.set()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert channel._typing_tasks == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_typing_uses_typing_context_when_trigger_typing_missing() -> None:
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
channel._running = True
|
||||
|
||||
entered = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
|
||||
class _TypingCtx:
|
||||
async def __aenter__(self):
|
||||
entered.set()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
class _NoTriggerChannel:
|
||||
def __init__(self, channel_id: int = 123) -> None:
|
||||
self.id = channel_id
|
||||
|
||||
def typing(self):
|
||||
async def _waiter():
|
||||
await release.wait()
|
||||
# Hold the loop so task remains active until explicitly stopped.
|
||||
class _Ctx(_TypingCtx):
|
||||
async def __aenter__(self):
|
||||
await super().__aenter__()
|
||||
await _waiter()
|
||||
return _Ctx()
|
||||
|
||||
typing_channel = _NoTriggerChannel(channel_id=123)
|
||||
await channel._start_typing(typing_channel) # type: ignore[arg-type]
|
||||
await entered.wait()
|
||||
|
||||
assert "123" in channel._typing_tasks
|
||||
|
||||
await channel._stop_typing("123")
|
||||
release.set()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert channel._typing_tasks == {}
|
||||
@ -3,6 +3,9 @@ from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from nio import RoomSendResponse
|
||||
|
||||
from nanobot.channels.matrix import _build_matrix_text_content
|
||||
|
||||
# Check optional matrix dependencies before importing
|
||||
try:
|
||||
@ -65,6 +68,7 @@ class _FakeAsyncClient:
|
||||
self.raise_on_send = False
|
||||
self.raise_on_typing = False
|
||||
self.raise_on_upload = False
|
||||
self.room_send_response: RoomSendResponse | None = RoomSendResponse(event_id="", room_id="")
|
||||
|
||||
def add_event_callback(self, callback, event_type) -> None:
|
||||
self.callbacks.append((callback, event_type))
|
||||
@ -87,7 +91,7 @@ class _FakeAsyncClient:
|
||||
message_type: str,
|
||||
content: dict[str, object],
|
||||
ignore_unverified_devices: object = _ROOM_SEND_UNSET,
|
||||
) -> None:
|
||||
) -> RoomSendResponse:
|
||||
call: dict[str, object] = {
|
||||
"room_id": room_id,
|
||||
"message_type": message_type,
|
||||
@ -98,6 +102,7 @@ class _FakeAsyncClient:
|
||||
self.room_send_calls.append(call)
|
||||
if self.raise_on_send:
|
||||
raise RuntimeError("send failed")
|
||||
return self.room_send_response
|
||||
|
||||
async def room_typing(
|
||||
self,
|
||||
@ -520,6 +525,7 @@ async def test_on_message_room_mention_requires_opt_in() -> None:
|
||||
source={"content": {"m.mentions": {"room": True}}},
|
||||
)
|
||||
|
||||
channel.config.allow_room_mentions = False
|
||||
await channel._on_message(room, room_mention_event)
|
||||
assert handled == []
|
||||
assert client.typing_calls == []
|
||||
@ -1322,3 +1328,302 @@ async def test_send_keeps_plaintext_only_for_plain_text() -> None:
|
||||
"body": text,
|
||||
"m.mentions": {},
|
||||
}
|
||||
|
||||
|
||||
def test_build_matrix_text_content_basic_text() -> None:
|
||||
"""Test basic text content without HTML formatting."""
|
||||
result = _build_matrix_text_content("Hello, World!")
|
||||
expected = {
|
||||
"msgtype": "m.text",
|
||||
"body": "Hello, World!",
|
||||
"m.mentions": {}
|
||||
}
|
||||
assert expected == result
|
||||
|
||||
|
||||
def test_build_matrix_text_content_with_markdown() -> None:
|
||||
"""Test text content with markdown that renders to HTML."""
|
||||
text = "*Hello* **World**"
|
||||
result = _build_matrix_text_content(text)
|
||||
assert "msgtype" in result
|
||||
assert "body" in result
|
||||
assert result["body"] == text
|
||||
assert "format" in result
|
||||
assert result["format"] == "org.matrix.custom.html"
|
||||
assert "formatted_body" in result
|
||||
assert isinstance(result["formatted_body"], str)
|
||||
assert len(result["formatted_body"]) > 0
|
||||
|
||||
|
||||
def test_build_matrix_text_content_with_event_id() -> None:
|
||||
"""Test text content with event_id for message replacement."""
|
||||
event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
|
||||
result = _build_matrix_text_content("Updated message", event_id)
|
||||
assert "msgtype" in result
|
||||
assert "body" in result
|
||||
assert result["m.new_content"]
|
||||
assert result["m.new_content"]["body"] == "Updated message"
|
||||
assert result["m.relates_to"]["rel_type"] == "m.replace"
|
||||
assert result["m.relates_to"]["event_id"] == event_id
|
||||
|
||||
|
||||
def test_build_matrix_text_content_with_event_id_preserves_thread_relation() -> None:
|
||||
"""Thread relations for edits should stay inside m.new_content."""
|
||||
relates_to = {
|
||||
"rel_type": "m.thread",
|
||||
"event_id": "$root1",
|
||||
"m.in_reply_to": {"event_id": "$reply1"},
|
||||
"is_falling_back": True,
|
||||
}
|
||||
result = _build_matrix_text_content("Updated message", "event-1", relates_to)
|
||||
|
||||
assert result["m.relates_to"] == {
|
||||
"rel_type": "m.replace",
|
||||
"event_id": "event-1",
|
||||
}
|
||||
assert result["m.new_content"]["m.relates_to"] == relates_to
|
||||
|
||||
|
||||
def test_build_matrix_text_content_no_event_id() -> None:
|
||||
"""Test that when event_id is not provided, no extra properties are added."""
|
||||
result = _build_matrix_text_content("Regular message")
|
||||
|
||||
# Basic required properties should be present
|
||||
assert "msgtype" in result
|
||||
assert "body" in result
|
||||
assert result["body"] == "Regular message"
|
||||
|
||||
# Extra properties for replacement should NOT be present
|
||||
assert "m.relates_to" not in result
|
||||
assert "m.new_content" not in result
|
||||
assert "format" not in result
|
||||
assert "formatted_body" not in result
|
||||
|
||||
|
||||
def test_build_matrix_text_content_plain_text_no_html() -> None:
|
||||
"""Test plain text that should not include HTML formatting."""
|
||||
result = _build_matrix_text_content("Simple plain text")
|
||||
assert "msgtype" in result
|
||||
assert "body" in result
|
||||
assert "format" not in result
|
||||
assert "formatted_body" not in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_room_content_returns_room_send_response():
|
||||
"""Test that _send_room_content returns the response from client.room_send."""
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel = MatrixChannel(_make_config(), MessageBus())
|
||||
channel.client = client
|
||||
|
||||
room_id = "!test_room:matrix.org"
|
||||
content = {"msgtype": "m.text", "body": "Hello World"}
|
||||
|
||||
result = await channel._send_room_content(room_id, content)
|
||||
|
||||
assert result is client.room_send_response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_creates_stream_buffer_and_sends_initial_message() -> None:
|
||||
channel = MatrixChannel(_make_config(), MessageBus())
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel.client = client
|
||||
client.room_send_response.event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
|
||||
|
||||
await channel.send_delta("!room:matrix.org", "Hello")
|
||||
|
||||
assert "!room:matrix.org" in channel._stream_bufs
|
||||
buf = channel._stream_bufs["!room:matrix.org"]
|
||||
assert buf.text == "Hello"
|
||||
assert buf.event_id == "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
|
||||
assert len(client.room_send_calls) == 1
|
||||
assert client.room_send_calls[0]["content"]["body"] == "Hello"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_appends_without_sending_before_edit_interval(monkeypatch) -> None:
|
||||
channel = MatrixChannel(_make_config(), MessageBus())
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel.client = client
|
||||
client.room_send_response.event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
|
||||
|
||||
now = 100.0
|
||||
monkeypatch.setattr(channel, "monotonic_time", lambda: now)
|
||||
|
||||
await channel.send_delta("!room:matrix.org", "Hello")
|
||||
assert len(client.room_send_calls) == 1
|
||||
|
||||
await channel.send_delta("!room:matrix.org", " world")
|
||||
assert len(client.room_send_calls) == 1
|
||||
|
||||
buf = channel._stream_bufs["!room:matrix.org"]
|
||||
assert buf.text == "Hello world"
|
||||
assert buf.event_id == "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_edits_again_after_interval(monkeypatch) -> None:
|
||||
channel = MatrixChannel(_make_config(), MessageBus())
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel.client = client
|
||||
client.room_send_response.event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
|
||||
|
||||
times = [100.0, 102.0, 104.0, 106.0, 108.0]
|
||||
times.reverse()
|
||||
monkeypatch.setattr(channel, "monotonic_time", lambda: times and times.pop())
|
||||
|
||||
await channel.send_delta("!room:matrix.org", "Hello")
|
||||
await channel.send_delta("!room:matrix.org", " world")
|
||||
|
||||
assert len(client.room_send_calls) == 2
|
||||
first_content = client.room_send_calls[0]["content"]
|
||||
second_content = client.room_send_calls[1]["content"]
|
||||
|
||||
assert "body" in first_content
|
||||
assert first_content["body"] == "Hello"
|
||||
assert "m.relates_to" not in first_content
|
||||
|
||||
assert "body" in second_content
|
||||
assert "m.relates_to" in second_content
|
||||
assert second_content["body"] == "Hello world"
|
||||
assert second_content["m.relates_to"] == {
|
||||
"rel_type": "m.replace",
|
||||
"event_id": "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_stream_end_replaces_existing_message() -> None:
|
||||
channel = MatrixChannel(_make_config(), MessageBus())
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel.client = client
|
||||
|
||||
channel._stream_bufs["!room:matrix.org"] = matrix_module._StreamBuf(
|
||||
text="Final text",
|
||||
event_id="event-1",
|
||||
last_edit=100.0,
|
||||
)
|
||||
|
||||
await channel.send_delta("!room:matrix.org", "", {"_stream_end": True})
|
||||
|
||||
assert "!room:matrix.org" not in channel._stream_bufs
|
||||
assert client.typing_calls[-1] == ("!room:matrix.org", False, TYPING_NOTICE_TIMEOUT_MS)
|
||||
assert len(client.room_send_calls) == 1
|
||||
assert client.room_send_calls[0]["content"]["body"] == "Final text"
|
||||
assert client.room_send_calls[0]["content"]["m.relates_to"] == {
|
||||
"rel_type": "m.replace",
|
||||
"event_id": "event-1",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_starts_threaded_stream_inside_thread() -> None:
|
||||
channel = MatrixChannel(_make_config(), MessageBus())
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel.client = client
|
||||
client.room_send_response.event_id = "event-1"
|
||||
|
||||
metadata = {
|
||||
"thread_root_event_id": "$root1",
|
||||
"thread_reply_to_event_id": "$reply1",
|
||||
}
|
||||
await channel.send_delta("!room:matrix.org", "Hello", metadata)
|
||||
|
||||
assert client.room_send_calls[0]["content"]["m.relates_to"] == {
|
||||
"rel_type": "m.thread",
|
||||
"event_id": "$root1",
|
||||
"m.in_reply_to": {"event_id": "$reply1"},
|
||||
"is_falling_back": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_threaded_edit_keeps_replace_and_thread_relation(monkeypatch) -> None:
|
||||
channel = MatrixChannel(_make_config(), MessageBus())
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel.client = client
|
||||
client.room_send_response.event_id = "event-1"
|
||||
|
||||
times = [100.0, 102.0, 104.0]
|
||||
times.reverse()
|
||||
monkeypatch.setattr(channel, "monotonic_time", lambda: times and times.pop())
|
||||
|
||||
metadata = {
|
||||
"thread_root_event_id": "$root1",
|
||||
"thread_reply_to_event_id": "$reply1",
|
||||
}
|
||||
await channel.send_delta("!room:matrix.org", "Hello", metadata)
|
||||
await channel.send_delta("!room:matrix.org", " world", metadata)
|
||||
await channel.send_delta("!room:matrix.org", "", {"_stream_end": True, **metadata})
|
||||
|
||||
edit_content = client.room_send_calls[1]["content"]
|
||||
final_content = client.room_send_calls[2]["content"]
|
||||
|
||||
assert edit_content["m.relates_to"] == {
|
||||
"rel_type": "m.replace",
|
||||
"event_id": "event-1",
|
||||
}
|
||||
assert edit_content["m.new_content"]["m.relates_to"] == {
|
||||
"rel_type": "m.thread",
|
||||
"event_id": "$root1",
|
||||
"m.in_reply_to": {"event_id": "$reply1"},
|
||||
"is_falling_back": True,
|
||||
}
|
||||
assert final_content["m.relates_to"] == {
|
||||
"rel_type": "m.replace",
|
||||
"event_id": "event-1",
|
||||
}
|
||||
assert final_content["m.new_content"]["m.relates_to"] == {
|
||||
"rel_type": "m.thread",
|
||||
"event_id": "$root1",
|
||||
"m.in_reply_to": {"event_id": "$reply1"},
|
||||
"is_falling_back": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_stream_end_noop_when_buffer_missing() -> None:
|
||||
channel = MatrixChannel(_make_config(), MessageBus())
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel.client = client
|
||||
|
||||
await channel.send_delta("!room:matrix.org", "", {"_stream_end": True})
|
||||
|
||||
assert client.room_send_calls == []
|
||||
assert client.typing_calls == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_on_error_stops_typing(monkeypatch) -> None:
|
||||
channel = MatrixChannel(_make_config(), MessageBus())
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
client.raise_on_send = True
|
||||
channel.client = client
|
||||
|
||||
now = 100.0
|
||||
monkeypatch.setattr(channel, "monotonic_time", lambda: now)
|
||||
|
||||
await channel.send_delta("!room:matrix.org", "Hello", {"room_id": "!room:matrix.org"})
|
||||
|
||||
assert "!room:matrix.org" in channel._stream_bufs
|
||||
assert channel._stream_bufs["!room:matrix.org"].text == "Hello"
|
||||
assert len(client.room_send_calls) == 1
|
||||
|
||||
assert len(client.typing_calls) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_ignores_whitespace_only_delta(monkeypatch) -> None:
|
||||
channel = MatrixChannel(_make_config(), MessageBus())
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel.client = client
|
||||
|
||||
now = 100.0
|
||||
monkeypatch.setattr(channel, "monotonic_time", lambda: now)
|
||||
|
||||
await channel.send_delta("!room:matrix.org", " ")
|
||||
|
||||
assert "!room:matrix.org" in channel._stream_bufs
|
||||
assert channel._stream_bufs["!room:matrix.org"].text == " "
|
||||
assert client.room_send_calls == []
|
||||
@ -642,27 +642,105 @@ def test_heartbeat_retains_recent_messages_by_default():
|
||||
assert config.gateway.heartbeat.keep_recent_messages == 8
|
||||
|
||||
|
||||
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
|
||||
def _write_instance_config(tmp_path: Path) -> Path:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
return config_file
|
||||
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
def _stop_gateway_provider(_config) -> object:
|
||||
raise _StopGatewayError("stop")
|
||||
|
||||
|
||||
def _patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config: Config,
|
||||
*,
|
||||
set_config_path=None,
|
||||
sync_templates=None,
|
||||
make_provider=None,
|
||||
message_bus=None,
|
||||
session_manager=None,
|
||||
cron_service=None,
|
||||
get_cron_dir=None,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
"nanobot.config.loader.set_config_path",
|
||||
lambda path: seen.__setitem__("config_path", path),
|
||||
set_config_path or (lambda _path: None),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands.sync_workspace_templates",
|
||||
lambda path: seen.__setitem__("workspace", path),
|
||||
sync_templates or (lambda _path: None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
|
||||
make_provider or (lambda _config: object()),
|
||||
)
|
||||
|
||||
if message_bus is not None:
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", message_bus)
|
||||
if session_manager is not None:
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", session_manager)
|
||||
if cron_service is not None:
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", cron_service)
|
||||
if get_cron_dir is not None:
|
||||
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", get_cron_dir)
|
||||
|
||||
|
||||
def _patch_serve_runtime(monkeypatch, config: Config, seen: dict[str, object]) -> None:
|
||||
pytest.importorskip("aiohttp")
|
||||
|
||||
class _FakeApiApp:
|
||||
def __init__(self) -> None:
|
||||
self.on_startup: list[object] = []
|
||||
self.on_cleanup: list[object] = []
|
||||
|
||||
class _FakeAgentLoop:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
seen["workspace"] = kwargs["workspace"]
|
||||
|
||||
async def _connect_mcp(self) -> None:
|
||||
return None
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
return None
|
||||
|
||||
def _fake_create_app(agent_loop, model_name: str, request_timeout: float):
|
||||
seen["agent_loop"] = agent_loop
|
||||
seen["model_name"] = model_name
|
||||
seen["request_timeout"] = request_timeout
|
||||
return _FakeApiApp()
|
||||
|
||||
def _fake_run_app(api_app, host: str, port: int, print):
|
||||
seen["api_app"] = api_app
|
||||
seen["host"] = host
|
||||
seen["port"] = port
|
||||
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
message_bus=lambda: object(),
|
||||
session_manager=lambda _workspace: object(),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.api.server.create_app", _fake_create_app)
|
||||
monkeypatch.setattr("aiohttp.web.run_app", _fake_run_app)
|
||||
|
||||
|
||||
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
set_config_path=lambda path: seen.__setitem__("config_path", path),
|
||||
sync_templates=lambda path: seen.__setitem__("workspace", path),
|
||||
make_provider=_stop_gateway_provider,
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
@ -673,24 +751,17 @@ def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Pa
|
||||
|
||||
|
||||
def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
override = tmp_path / "override-workspace"
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands.sync_workspace_templates",
|
||||
lambda path: seen.__setitem__("workspace", path),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
sync_templates=lambda path: seen.__setitem__("workspace", path),
|
||||
make_provider=_stop_gateway_provider,
|
||||
)
|
||||
|
||||
result = runner.invoke(
|
||||
@ -704,27 +775,23 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path)
|
||||
|
||||
|
||||
def test_gateway_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
||||
|
||||
class _StopCron:
|
||||
def __init__(self, store_path: Path) -> None:
|
||||
seen["cron_store"] = store_path
|
||||
raise _StopGatewayError("stop")
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
message_bus=lambda: object(),
|
||||
session_manager=lambda _workspace: object(),
|
||||
cron_service=_StopCron,
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
@ -735,10 +802,7 @@ def test_gateway_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path:
|
||||
def test_gateway_workspace_override_does_not_migrate_legacy_cron(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
legacy_dir = tmp_path / "global" / "cron"
|
||||
legacy_dir.mkdir(parents=True)
|
||||
legacy_file = legacy_dir / "jobs.json"
|
||||
@ -748,20 +812,19 @@ def test_gateway_workspace_override_does_not_migrate_legacy_cron(
|
||||
config = Config()
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
||||
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
|
||||
|
||||
class _StopCron:
|
||||
def __init__(self, store_path: Path) -> None:
|
||||
seen["cron_store"] = store_path
|
||||
raise _StopGatewayError("stop")
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
message_bus=lambda: object(),
|
||||
session_manager=lambda _workspace: object(),
|
||||
cron_service=_StopCron,
|
||||
get_cron_dir=lambda: legacy_dir,
|
||||
)
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
@ -777,10 +840,7 @@ def test_gateway_workspace_override_does_not_migrate_legacy_cron(
|
||||
def test_gateway_custom_config_workspace_does_not_migrate_legacy_cron(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
legacy_dir = tmp_path / "global" / "cron"
|
||||
legacy_dir.mkdir(parents=True)
|
||||
legacy_file = legacy_dir / "jobs.json"
|
||||
@ -791,20 +851,19 @@ def test_gateway_custom_config_workspace_does_not_migrate_legacy_cron(
|
||||
config.agents.defaults.workspace = str(custom_workspace)
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
||||
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
|
||||
|
||||
class _StopCron:
|
||||
def __init__(self, store_path: Path) -> None:
|
||||
seen["cron_store"] = store_path
|
||||
raise _StopGatewayError("stop")
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
message_bus=lambda: object(),
|
||||
session_manager=lambda _workspace: object(),
|
||||
cron_service=_StopCron,
|
||||
get_cron_dir=lambda: legacy_dir,
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
@ -856,19 +915,14 @@ def test_migrate_cron_store_skips_when_workspace_file_exists(tmp_path: Path) ->
|
||||
|
||||
|
||||
def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
config = Config()
|
||||
config.gateway.port = 18791
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
make_provider=_stop_gateway_provider,
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
@ -878,19 +932,14 @@ def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_
|
||||
|
||||
|
||||
def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
config = Config()
|
||||
config.gateway.port = 18791
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
make_provider=_stop_gateway_provider,
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file), "--port", "18792"])
|
||||
@ -899,6 +948,63 @@ def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path)
|
||||
assert "port 18792" in result.stdout
|
||||
|
||||
|
||||
def test_serve_uses_api_config_defaults_and_workspace_override(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
config.api.host = "127.0.0.2"
|
||||
config.api.port = 18900
|
||||
config.api.timeout = 45.0
|
||||
override_workspace = tmp_path / "override-workspace"
|
||||
seen: dict[str, object] = {}
|
||||
|
||||
_patch_serve_runtime(monkeypatch, config, seen)
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
["serve", "--config", str(config_file), "--workspace", str(override_workspace)],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert seen["workspace"] == override_workspace
|
||||
assert seen["host"] == "127.0.0.2"
|
||||
assert seen["port"] == 18900
|
||||
assert seen["request_timeout"] == 45.0
|
||||
|
||||
|
||||
def test_serve_cli_options_override_api_config(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
config = Config()
|
||||
config.api.host = "127.0.0.2"
|
||||
config.api.port = 18900
|
||||
config.api.timeout = 45.0
|
||||
seen: dict[str, object] = {}
|
||||
|
||||
_patch_serve_runtime(monkeypatch, config, seen)
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"serve",
|
||||
"--config",
|
||||
str(config_file),
|
||||
"--host",
|
||||
"127.0.0.1",
|
||||
"--port",
|
||||
"18901",
|
||||
"--timeout",
|
||||
"46",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert seen["host"] == "127.0.0.1"
|
||||
assert seen["port"] == 18901
|
||||
assert seen["request_timeout"] == 46.0
|
||||
|
||||
|
||||
def test_channels_login_requires_channel_name() -> None:
|
||||
result = runner.invoke(app, ["channels", "login"])
|
||||
|
||||
|
||||
147
tests/test_nanobot_facade.py
Normal file
147
tests/test_nanobot_facade.py
Normal file
@ -0,0 +1,147 @@
|
||||
"""Tests for the Nanobot programmatic facade."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.nanobot import Nanobot, RunResult
|
||||
|
||||
|
||||
def _write_config(tmp_path: Path, overrides: dict | None = None) -> Path:
|
||||
data = {
|
||||
"providers": {"openrouter": {"apiKey": "sk-test-key"}},
|
||||
"agents": {"defaults": {"model": "openai/gpt-4.1"}},
|
||||
}
|
||||
if overrides:
|
||||
data.update(overrides)
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(json.dumps(data))
|
||||
return config_path
|
||||
|
||||
|
||||
def test_from_config_missing_file():
|
||||
with pytest.raises(FileNotFoundError):
|
||||
Nanobot.from_config("/nonexistent/config.json")
|
||||
|
||||
|
||||
def test_from_config_creates_instance(tmp_path):
|
||||
config_path = _write_config(tmp_path)
|
||||
bot = Nanobot.from_config(config_path, workspace=tmp_path)
|
||||
assert bot._loop is not None
|
||||
assert bot._loop.workspace == tmp_path
|
||||
|
||||
|
||||
def test_from_config_default_path():
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
with patch("nanobot.config.loader.load_config") as mock_load, \
|
||||
patch("nanobot.nanobot._make_provider") as mock_prov:
|
||||
mock_load.return_value = Config()
|
||||
mock_prov.return_value = MagicMock()
|
||||
mock_prov.return_value.get_default_model.return_value = "test"
|
||||
mock_prov.return_value.generation.max_tokens = 4096
|
||||
Nanobot.from_config()
|
||||
mock_load.assert_called_once_with(None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_returns_result(tmp_path):
|
||||
config_path = _write_config(tmp_path)
|
||||
bot = Nanobot.from_config(config_path, workspace=tmp_path)
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
|
||||
mock_response = OutboundMessage(
|
||||
channel="cli", chat_id="direct", content="Hello back!"
|
||||
)
|
||||
bot._loop.process_direct = AsyncMock(return_value=mock_response)
|
||||
|
||||
result = await bot.run("hi")
|
||||
|
||||
assert isinstance(result, RunResult)
|
||||
assert result.content == "Hello back!"
|
||||
bot._loop.process_direct.assert_awaited_once_with("hi", session_key="sdk:default")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_hooks(tmp_path):
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
|
||||
config_path = _write_config(tmp_path)
|
||||
bot = Nanobot.from_config(config_path, workspace=tmp_path)
|
||||
|
||||
class TestHook(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
pass
|
||||
|
||||
mock_response = OutboundMessage(
|
||||
channel="cli", chat_id="direct", content="done"
|
||||
)
|
||||
bot._loop.process_direct = AsyncMock(return_value=mock_response)
|
||||
|
||||
result = await bot.run("hi", hooks=[TestHook()])
|
||||
|
||||
assert result.content == "done"
|
||||
assert bot._loop._extra_hooks == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_hooks_restored_on_error(tmp_path):
|
||||
config_path = _write_config(tmp_path)
|
||||
bot = Nanobot.from_config(config_path, workspace=tmp_path)
|
||||
|
||||
from nanobot.agent.hook import AgentHook
|
||||
|
||||
bot._loop.process_direct = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
original_hooks = bot._loop._extra_hooks
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await bot.run("hi", hooks=[AgentHook()])
|
||||
|
||||
assert bot._loop._extra_hooks is original_hooks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_none_response(tmp_path):
|
||||
config_path = _write_config(tmp_path)
|
||||
bot = Nanobot.from_config(config_path, workspace=tmp_path)
|
||||
bot._loop.process_direct = AsyncMock(return_value=None)
|
||||
|
||||
result = await bot.run("hi")
|
||||
assert result.content == ""
|
||||
|
||||
|
||||
def test_workspace_override(tmp_path):
|
||||
config_path = _write_config(tmp_path)
|
||||
custom_ws = tmp_path / "custom_workspace"
|
||||
custom_ws.mkdir()
|
||||
|
||||
bot = Nanobot.from_config(config_path, workspace=custom_ws)
|
||||
assert bot._loop.workspace == custom_ws
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_custom_session_key(tmp_path):
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
|
||||
config_path = _write_config(tmp_path)
|
||||
bot = Nanobot.from_config(config_path, workspace=tmp_path)
|
||||
|
||||
mock_response = OutboundMessage(
|
||||
channel="cli", chat_id="direct", content="ok"
|
||||
)
|
||||
bot._loop.process_direct = AsyncMock(return_value=mock_response)
|
||||
|
||||
await bot.run("hi", session_key="user-alice")
|
||||
bot._loop.process_direct.assert_awaited_once_with("hi", session_key="user-alice")
|
||||
|
||||
|
||||
def test_import_from_top_level():
|
||||
from nanobot import Nanobot as N, RunResult as R
|
||||
assert N is Nanobot
|
||||
assert R is RunResult
|
||||
372
tests/test_openai_api.py
Normal file
372
tests/test_openai_api.py
Normal file
@ -0,0 +1,372 @@
|
||||
"""Focused tests for the fixed-session OpenAI-compatible API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from nanobot.api.server import (
|
||||
API_CHAT_ID,
|
||||
API_SESSION_KEY,
|
||||
_chat_completion_response,
|
||||
_error_json,
|
||||
create_app,
|
||||
handle_chat_completions,
|
||||
)
|
||||
|
||||
try:
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
HAS_AIOHTTP = True
|
||||
except ImportError:
|
||||
HAS_AIOHTTP = False
|
||||
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
|
||||
|
||||
def _make_mock_agent(response_text: str = "mock response") -> MagicMock:
|
||||
agent = MagicMock()
|
||||
agent.process_direct = AsyncMock(return_value=response_text)
|
||||
agent._connect_mcp = AsyncMock()
|
||||
agent.close_mcp = AsyncMock()
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
return _make_mock_agent()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(mock_agent):
|
||||
return create_app(mock_agent, model_name="test-model", request_timeout=10.0)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def aiohttp_client():
|
||||
clients: list[TestClient] = []
|
||||
|
||||
async def _make_client(app):
|
||||
client = TestClient(TestServer(app))
|
||||
await client.start_server()
|
||||
clients.append(client)
|
||||
return client
|
||||
|
||||
try:
|
||||
yield _make_client
|
||||
finally:
|
||||
for client in clients:
|
||||
await client.close()
|
||||
|
||||
|
||||
def test_error_json() -> None:
|
||||
resp = _error_json(400, "bad request")
|
||||
assert resp.status == 400
|
||||
body = json.loads(resp.body)
|
||||
assert body["error"]["message"] == "bad request"
|
||||
assert body["error"]["code"] == 400
|
||||
|
||||
|
||||
def test_chat_completion_response() -> None:
|
||||
result = _chat_completion_response("hello world", "test-model")
|
||||
assert result["object"] == "chat.completion"
|
||||
assert result["model"] == "test-model"
|
||||
assert result["choices"][0]["message"]["content"] == "hello world"
|
||||
assert result["choices"][0]["finish_reason"] == "stop"
|
||||
assert result["id"].startswith("chatcmpl-")
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_messages_returns_400(aiohttp_client, app) -> None:
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post("/v1/chat/completions", json={"model": "test"})
|
||||
assert resp.status == 400
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_user_message_returns_400(aiohttp_client, app) -> None:
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "system", "content": "you are a bot"}]},
|
||||
)
|
||||
assert resp.status == 400
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_true_returns_400(aiohttp_client, app) -> None:
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hello"}], "stream": True},
|
||||
)
|
||||
assert resp.status == 400
|
||||
body = await resp.json()
|
||||
assert "stream" in body["error"]["message"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_mismatch_returns_400() -> None:
|
||||
request = MagicMock()
|
||||
request.json = AsyncMock(
|
||||
return_value={
|
||||
"model": "other-model",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
}
|
||||
)
|
||||
request.app = {
|
||||
"agent_loop": _make_mock_agent(),
|
||||
"model_name": "test-model",
|
||||
"request_timeout": 10.0,
|
||||
"session_lock": asyncio.Lock(),
|
||||
}
|
||||
|
||||
resp = await handle_chat_completions(request)
|
||||
assert resp.status == 400
|
||||
body = json.loads(resp.body)
|
||||
assert "test-model" in body["error"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_user_message_required() -> None:
|
||||
request = MagicMock()
|
||||
request.json = AsyncMock(
|
||||
return_value={
|
||||
"messages": [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "previous reply"},
|
||||
],
|
||||
}
|
||||
)
|
||||
request.app = {
|
||||
"agent_loop": _make_mock_agent(),
|
||||
"model_name": "test-model",
|
||||
"request_timeout": 10.0,
|
||||
"session_lock": asyncio.Lock(),
|
||||
}
|
||||
|
||||
resp = await handle_chat_completions(request)
|
||||
assert resp.status == 400
|
||||
body = json.loads(resp.body)
|
||||
assert "single user message" in body["error"]["message"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_user_message_must_have_user_role() -> None:
|
||||
request = MagicMock()
|
||||
request.json = AsyncMock(
|
||||
return_value={
|
||||
"messages": [{"role": "system", "content": "you are a bot"}],
|
||||
}
|
||||
)
|
||||
request.app = {
|
||||
"agent_loop": _make_mock_agent(),
|
||||
"model_name": "test-model",
|
||||
"request_timeout": 10.0,
|
||||
"session_lock": asyncio.Lock(),
|
||||
}
|
||||
|
||||
resp = await handle_chat_completions(request)
|
||||
assert resp.status == 400
|
||||
body = json.loads(resp.body)
|
||||
assert "single user message" in body["error"]["message"].lower()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_request_uses_fixed_api_session(aiohttp_client, mock_agent) -> None:
|
||||
app = create_app(mock_agent, model_name="test-model")
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hello"}]},
|
||||
)
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["choices"][0]["message"]["content"] == "mock response"
|
||||
assert body["model"] == "test-model"
|
||||
mock_agent.process_direct.assert_called_once_with(
|
||||
content="hello",
|
||||
session_key=API_SESSION_KEY,
|
||||
channel="api",
|
||||
chat_id=API_CHAT_ID,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_followup_requests_share_same_session_key(aiohttp_client) -> None:
|
||||
call_log: list[str] = []
|
||||
|
||||
async def fake_process(content, session_key="", channel="", chat_id=""):
|
||||
call_log.append(session_key)
|
||||
return f"reply to {content}"
|
||||
|
||||
agent = MagicMock()
|
||||
agent.process_direct = fake_process
|
||||
agent._connect_mcp = AsyncMock()
|
||||
agent.close_mcp = AsyncMock()
|
||||
|
||||
app = create_app(agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
r1 = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "first"}]},
|
||||
)
|
||||
r2 = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "second"}]},
|
||||
)
|
||||
|
||||
assert r1.status == 200
|
||||
assert r2.status == 200
|
||||
assert call_log == [API_SESSION_KEY, API_SESSION_KEY]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_fixed_session_requests_are_serialized(aiohttp_client) -> None:
|
||||
order: list[str] = []
|
||||
barrier = asyncio.Event()
|
||||
|
||||
async def slow_process(content, session_key="", channel="", chat_id=""):
|
||||
order.append(f"start:{content}")
|
||||
if content == "first":
|
||||
barrier.set()
|
||||
await asyncio.sleep(0.1)
|
||||
else:
|
||||
await barrier.wait()
|
||||
order.append(f"end:{content}")
|
||||
return content
|
||||
|
||||
agent = MagicMock()
|
||||
agent.process_direct = slow_process
|
||||
agent._connect_mcp = AsyncMock()
|
||||
agent.close_mcp = AsyncMock()
|
||||
|
||||
app = create_app(agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
async def send(msg: str):
|
||||
return await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": msg}]},
|
||||
)
|
||||
|
||||
r1, r2 = await asyncio.gather(send("first"), send("second"))
|
||||
assert r1.status == 200
|
||||
assert r2.status == 200
|
||||
assert order.index("end:first") < order.index("start:second")
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_models_endpoint(aiohttp_client, app) -> None:
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/v1/models")
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["object"] == "list"
|
||||
assert body["data"][0]["id"] == "test-model"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_endpoint(aiohttp_client, app) -> None:
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/health")
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["status"] == "ok"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_multimodal_content_extracts_text(aiohttp_client, mock_agent) -> None:
|
||||
app = create_app(mock_agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "describe this"},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
|
||||
],
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
assert resp.status == 200
|
||||
mock_agent.process_direct.assert_called_once_with(
|
||||
content="describe this",
|
||||
session_key=API_SESSION_KEY,
|
||||
channel="api",
|
||||
chat_id=API_CHAT_ID,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_response_retry_then_success(aiohttp_client) -> None:
|
||||
call_count = 0
|
||||
|
||||
async def sometimes_empty(content, session_key="", channel="", chat_id=""):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return ""
|
||||
return "recovered response"
|
||||
|
||||
agent = MagicMock()
|
||||
agent.process_direct = sometimes_empty
|
||||
agent._connect_mcp = AsyncMock()
|
||||
agent.close_mcp = AsyncMock()
|
||||
|
||||
app = create_app(agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hello"}]},
|
||||
)
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["choices"][0]["message"]["content"] == "recovered response"
|
||||
assert call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_response_falls_back(aiohttp_client) -> None:
|
||||
call_count = 0
|
||||
|
||||
async def always_empty(content, session_key="", channel="", chat_id=""):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return ""
|
||||
|
||||
agent = MagicMock()
|
||||
agent.process_direct = always_empty
|
||||
agent._connect_mcp = AsyncMock()
|
||||
agent.close_mcp = AsyncMock()
|
||||
|
||||
app = create_app(agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hello"}]},
|
||||
)
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["choices"][0]["message"]["content"] == "I've completed processing but have no response to give."
|
||||
assert call_count == 2
|
||||
Loading…
x
Reference in New Issue
Block a user