merge: resolve conflicts with upstream/main, preserve typing indicator

This commit is contained in:
cypggs 2026-04-02 14:28:23 +08:00
commit ca68a89ce6
44 changed files with 5581 additions and 666 deletions

107
README.md
View File

@ -115,6 +115,8 @@
- [Configuration](#-configuration)
- [Multiple Instances](#-multiple-instances)
- [CLI Reference](#-cli-reference)
- [Python SDK](#-python-sdk)
- [OpenAI-Compatible API](#-openai-compatible-api)
- [Docker](#-docker)
- [Linux Service](#-linux-service)
- [Project Structure](#-project-structure)
@ -1541,6 +1543,7 @@ nanobot gateway --config ~/.nanobot-telegram/config.json --workspace /tmp/nanobo
| `nanobot agent` | Interactive chat mode |
| `nanobot agent --no-markdown` | Show plain-text replies |
| `nanobot agent --logs` | Show runtime logs during chat |
| `nanobot serve` | Start the OpenAI-compatible API |
| `nanobot gateway` | Start the gateway |
| `nanobot status` | Show status |
| `nanobot provider login openai-codex` | OAuth login for providers |
@ -1569,6 +1572,110 @@ The agent can also manage this file itself — ask it to "add a periodic task" a
</details>
## 🐍 Python SDK
Use nanobot as a library — no CLI, no gateway, just Python:
```python
from nanobot import Nanobot
bot = Nanobot.from_config()
result = await bot.run("Summarize the README")
print(result.content)
```
Each call carries a `session_key` for conversation isolation — different keys get independent history:
```python
await bot.run("hi", session_key="user-alice")
await bot.run("hi", session_key="task-42")
```
Add lifecycle hooks to observe or customize the agent:
```python
from nanobot.agent import AgentHook, AgentHookContext
class AuditHook(AgentHook):
async def before_execute_tools(self, ctx: AgentHookContext) -> None:
for tc in ctx.tool_calls:
print(f"[tool] {tc.name}")
result = await bot.run("Hello", hooks=[AuditHook()])
```
See [docs/PYTHON_SDK.md](docs/PYTHON_SDK.md) for the full SDK reference.
## 🔌 OpenAI-Compatible API
nanobot can expose a minimal OpenAI-compatible endpoint for local integrations:
```bash
pip install "nanobot-ai[api]"
nanobot serve
```
By default, the API binds to `127.0.0.1:8900`. You can change this in `config.json`.
### Behavior
- Session isolation: pass `"session_id"` in the request body to isolate conversations; omit for a shared default session (`api:default`)
- Single-message input: each request must contain exactly one `user` message
- Fixed model: omit `model`, or pass the same model shown by `/v1/models`
- No streaming: `stream=true` is not supported
### Endpoints
- `GET /health`
- `GET /v1/models`
- `POST /v1/chat/completions`
### curl
```bash
curl http://127.0.0.1:8900/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [{"role": "user", "content": "hi"}],
"session_id": "my-session"
}'
```
### Python (`requests`)
```python
import requests
resp = requests.post(
"http://127.0.0.1:8900/v1/chat/completions",
json={
"messages": [{"role": "user", "content": "hi"}],
"session_id": "my-session", # optional: isolate conversation
},
timeout=120,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```
### Python (`openai`)
```python
from openai import OpenAI
client = OpenAI(
base_url="http://127.0.0.1:8900/v1",
api_key="dummy",
)
resp = client.chat.completions.create(
model="MiniMax-M2.7",
messages=[{"role": "user", "content": "hi"}],
extra_body={"session_id": "my-session"}, # optional: isolate conversation
)
print(resp.choices[0].message.content)
```
## 🐳 Docker
> [!TIP]

View File

@ -1,5 +1,6 @@
#!/bin/bash
# Count core agent lines (excluding channels/, cli/, providers/ adapters)
# Count core agent lines (excluding channels/, cli/, api/, providers/ adapters,
# and the high-level Python SDK facade)
cd "$(dirname "$0")" || exit 1
echo "nanobot core agent line count"
@ -15,7 +16,7 @@ root=$(cat nanobot/__init__.py nanobot/__main__.py | wc -l)
printf " %-16s %5s lines\n" "(root)" "$root"
echo ""
total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/command/*" ! -path "*/providers/*" ! -path "*/skills/*" | xargs cat | wc -l)
total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/api/*" ! -path "*/command/*" ! -path "*/providers/*" ! -path "*/skills/*" ! -path "nanobot/nanobot.py" | xargs cat | wc -l)
echo " Core total: $total lines"
echo ""
echo " (excludes: channels/, cli/, command/, providers/, skills/)"
echo " (excludes: channels/, cli/, api/, command/, providers/, skills/, nanobot.py)"

136
docs/PYTHON_SDK.md Normal file
View File

@ -0,0 +1,136 @@
# Python SDK
Use nanobot programmatically — load config, run the agent, get results.
## Quick Start
```python
import asyncio
from nanobot import Nanobot
async def main():
bot = Nanobot.from_config()
result = await bot.run("What time is it in Tokyo?")
print(result.content)
asyncio.run(main())
```
## API
### `Nanobot.from_config(config_path?, *, workspace?)`
Create a `Nanobot` from a config file.
| Param | Type | Default | Description |
|-------|------|---------|-------------|
| `config_path` | `str \| Path \| None` | `None` | Path to `config.json`. Defaults to `~/.nanobot/config.json`. |
| `workspace` | `str \| Path \| None` | `None` | Override workspace directory from config. |
Raises `FileNotFoundError` if an explicit path doesn't exist.
### `await bot.run(message, *, session_key?, hooks?)`
Run the agent once. Returns a `RunResult`.
| Param | Type | Default | Description |
|-------|------|---------|-------------|
| `message` | `str` | *(required)* | The user message to process. |
| `session_key` | `str` | `"sdk:default"` | Session identifier for conversation isolation. Different keys get independent history. |
| `hooks` | `list[AgentHook] \| None` | `None` | Lifecycle hooks for this run only. |
```python
# Isolated sessions — each user gets independent conversation history
await bot.run("hi", session_key="user-alice")
await bot.run("hi", session_key="user-bob")
```
### `RunResult`
| Field | Type | Description |
|-------|------|-------------|
| `content` | `str` | The agent's final text response. |
| `tools_used` | `list[str]` | Tool names invoked during the run. |
| `messages` | `list[dict]` | Raw message history (for debugging). |
## Hooks
Hooks let you observe or modify the agent loop without touching internals.
Subclass `AgentHook` and override any method:
| Method | When |
|--------|------|
| `before_iteration(ctx)` | Before each LLM call |
| `on_stream(ctx, delta)` | On each streamed token |
| `on_stream_end(ctx, *, resuming)` | When streaming finishes (`resuming=True` means tool calls follow) |
| `before_execute_tools(ctx)` | Before tool execution (inspect `ctx.tool_calls`) |
| `after_iteration(ctx)` | After each LLM response (inspect `ctx.response`) |
| `finalize_content(ctx, content)` | Transform final output text |
### Example: Audit Hook
```python
from nanobot.agent import AgentHook, AgentHookContext
class AuditHook(AgentHook):
def __init__(self):
self.calls = []
async def before_execute_tools(self, ctx: AgentHookContext) -> None:
for tc in ctx.tool_calls:
self.calls.append(tc.name)
print(f"[audit] {tc.name}({tc.arguments})")
hook = AuditHook()
result = await bot.run("List files in /tmp", hooks=[hook])
print(f"Tools used: {hook.calls}")
```
### Composing Hooks
Pass multiple hooks — they run in order, errors in one don't block others:
```python
result = await bot.run("hi", hooks=[AuditHook(), MetricsHook()])
```
Under the hood this uses `CompositeHook` for fan-out with error isolation.
### `finalize_content` Pipeline
Unlike the async methods (fan-out), `finalize_content` is a pipeline — each hook's output feeds the next:
```python
class Censor(AgentHook):
def finalize_content(self, ctx, content):
return content.replace("secret", "***") if content else content
```
## Full Example
```python
import asyncio
from nanobot import Nanobot
from nanobot.agent import AgentHook, AgentHookContext
class TimingHook(AgentHook):
async def before_iteration(self, ctx: AgentHookContext) -> None:
import time
ctx.metadata["_t0"] = time.time()
async def after_iteration(self, ctx: AgentHookContext) -> None:
import time
elapsed = time.time() - ctx.metadata.get("_t0", 0)
print(f"[timing] iteration took {elapsed:.2f}s")
async def main():
bot = Nanobot.from_config(workspace="/my/project")
result = await bot.run(
"Explain the main function",
hooks=[TimingHook()],
)
print(result.content)
asyncio.run(main())
```

View File

@ -4,3 +4,7 @@ nanobot - A lightweight AI agent framework
__version__ = "0.1.4.post6"
__logo__ = "🐈"
from nanobot.nanobot import Nanobot, RunResult
__all__ = ["Nanobot", "RunResult"]

View File

@ -1,8 +1,19 @@
"""Agent core module."""
from nanobot.agent.context import ContextBuilder
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
from nanobot.agent.loop import AgentLoop
from nanobot.agent.memory import MemoryStore
from nanobot.agent.skills import SkillsLoader
from nanobot.agent.subagent import SubagentManager
__all__ = ["AgentLoop", "ContextBuilder", "MemoryStore", "SkillsLoader"]
__all__ = [
"AgentHook",
"AgentHookContext",
"AgentLoop",
"CompositeHook",
"ContextBuilder",
"MemoryStore",
"SkillsLoader",
"SubagentManager",
]

View File

@ -5,6 +5,8 @@ from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from loguru import logger
from nanobot.providers.base import LLMResponse, ToolCallRequest
@ -47,3 +49,60 @@ class AgentHook:
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
return content
class CompositeHook(AgentHook):
    """Dispatch every lifecycle callback to an ordered list of child hooks.

    The async callbacks are isolated per child: an exception raised by one
    hook is logged and the remaining hooks still run, so a faulty custom hook
    cannot crash the agent loop.

    ``finalize_content`` is different: it is a pipeline in which each hook
    receives the previous hook's output, and it is *not* isolated, so
    exceptions raised there propagate to the caller.
    """

    __slots__ = ("_hooks",)

    def __init__(self, hooks: list[AgentHook]) -> None:
        # Snapshot the list so later mutation by the caller has no effect here.
        self._hooks = list(hooks)

    def wants_streaming(self) -> bool:
        """Return True as soon as any child hook asks for streaming."""
        for child in self._hooks:
            if child.wants_streaming():
                return True
        return False

    async def before_iteration(self, context: AgentHookContext) -> None:
        """Call ``before_iteration`` on each child, logging per-child failures."""
        for child in self._hooks:
            try:
                await child.before_iteration(context)
            except Exception:
                logger.exception("AgentHook.before_iteration error in {}", type(child).__name__)

    async def on_stream(self, context: AgentHookContext, delta: str) -> None:
        """Forward a streamed delta to each child, logging per-child failures."""
        for child in self._hooks:
            try:
                await child.on_stream(context, delta)
            except Exception:
                logger.exception("AgentHook.on_stream error in {}", type(child).__name__)

    async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
        """Notify each child that a stream segment ended, logging per-child failures."""
        for child in self._hooks:
            try:
                await child.on_stream_end(context, resuming=resuming)
            except Exception:
                logger.exception("AgentHook.on_stream_end error in {}", type(child).__name__)

    async def before_execute_tools(self, context: AgentHookContext) -> None:
        """Call ``before_execute_tools`` on each child, logging per-child failures."""
        for child in self._hooks:
            try:
                await child.before_execute_tools(context)
            except Exception:
                logger.exception("AgentHook.before_execute_tools error in {}", type(child).__name__)

    async def after_iteration(self, context: AgentHookContext) -> None:
        """Call ``after_iteration`` on each child, logging per-child failures."""
        for child in self._hooks:
            try:
                await child.after_iteration(context)
            except Exception:
                logger.exception("AgentHook.after_iteration error in {}", type(child).__name__)

    def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
        """Thread ``content`` through every child in order and return the result."""
        result = content
        for child in self._hooks:
            result = child.finalize_content(context, result)
        return result

View File

@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
from loguru import logger
from nanobot.agent.context import ContextBuilder
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
from nanobot.agent.memory import MemoryConsolidator
from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.agent.subagent import SubagentManager
@ -37,6 +37,120 @@ if TYPE_CHECKING:
from nanobot.cron.service import CronService
class _LoopHook(AgentHook):
    """Built-in lifecycle hook for the main agent loop.

    Relays think-stripped streaming deltas, reports progress and tool hints,
    logs tool calls and token usage, and strips think tags from the final
    content.
    """

    def __init__(
        self,
        agent_loop: AgentLoop,
        on_progress: Callable[..., Awaitable[None]] | None = None,
        on_stream: Callable[[str], Awaitable[None]] | None = None,
        on_stream_end: Callable[..., Awaitable[None]] | None = None,
        *,
        channel: str = "cli",
        chat_id: str = "direct",
        message_id: str | None = None,
    ) -> None:
        # Raw (unstripped) text accumulated for the current stream segment.
        self._stream_buf = ""
        self._loop = agent_loop
        # Optional callbacks supplied by the caller of the agent loop.
        self._on_progress = on_progress
        self._on_stream = on_stream
        self._on_stream_end = on_stream_end
        # Routing information handed to the tools before they execute.
        self._channel = channel
        self._chat_id = chat_id
        self._message_id = message_id

    def wants_streaming(self) -> bool:
        """Streaming is requested exactly when an ``on_stream`` callback exists."""
        return self._on_stream is not None

    async def on_stream(self, context: AgentHookContext, delta: str) -> None:
        """Buffer ``delta`` and forward only the newly visible cleaned text."""
        from nanobot.utils.helpers import strip_think

        # Length of the cleaned text before this delta was appended.
        already_clean = len(strip_think(self._stream_buf))
        self._stream_buf += delta
        fresh = strip_think(self._stream_buf)[already_clean:]
        if fresh and self._on_stream:
            await self._on_stream(fresh)

    async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
        """Forward the end-of-stream signal, then reset the segment buffer."""
        if self._on_stream_end:
            await self._on_stream_end(resuming=resuming)
        self._stream_buf = ""

    async def before_execute_tools(self, context: AgentHookContext) -> None:
        """Report progress, log each pending tool call and set the tool context."""
        notify = self._on_progress
        if notify:
            # The model's interim text is reported only when nothing is streaming.
            if not self._on_stream:
                response = context.response
                thought = self._loop._strip_think(response.content if response else None)
                if thought:
                    await notify(thought)
            hint = self._loop._strip_think(self._loop._tool_hint(context.tool_calls))
            await notify(hint, tool_hint=True)

        for call in context.tool_calls:
            rendered_args = json.dumps(call.arguments, ensure_ascii=False)
            # Arguments are truncated to 200 characters in the log line.
            logger.info("Tool call: {}({})", call.name, rendered_args[:200])

        self._loop._set_tool_context(self._channel, self._chat_id, self._message_id)

    async def after_iteration(self, context: AgentHookContext) -> None:
        """Log the token usage recorded on the context (missing fields as 0)."""
        usage = context.usage or {}
        logger.debug(
            "LLM usage: prompt={} completion={} cached={}",
            usage.get("prompt_tokens", 0),
            usage.get("completion_tokens", 0),
            usage.get("cached_tokens", 0),
        )

    def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
        """Return ``content`` after the agent loop's think-stripping."""
        return self._loop._strip_think(content)
class _LoopHookChain(AgentHook):
    """Run the core loop hook first, then the user-supplied extra hooks.

    The primary hook is called directly, so its exceptions propagate exactly
    as they did before extra hooks existed.  The extras are wrapped in a
    ``CompositeHook``, which provides their per-hook error isolation.
    """

    __slots__ = ("_primary", "_extras")

    def __init__(self, primary: AgentHook, extra_hooks: list[AgentHook]) -> None:
        self._primary = primary
        self._extras = CompositeHook(extra_hooks)

    def wants_streaming(self) -> bool:
        """Streaming is wanted if either the primary or any extra hook wants it."""
        if self._primary.wants_streaming():
            return True
        return self._extras.wants_streaming()

    async def before_iteration(self, context: AgentHookContext) -> None:
        """Primary first, then extras."""
        for stage in (self._primary, self._extras):
            await stage.before_iteration(context)

    async def on_stream(self, context: AgentHookContext, delta: str) -> None:
        """Primary first, then extras."""
        for stage in (self._primary, self._extras):
            await stage.on_stream(context, delta)

    async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
        """Primary first, then extras."""
        for stage in (self._primary, self._extras):
            await stage.on_stream_end(context, resuming=resuming)

    async def before_execute_tools(self, context: AgentHookContext) -> None:
        """Primary first, then extras."""
        for stage in (self._primary, self._extras):
            await stage.before_execute_tools(context)

    async def after_iteration(self, context: AgentHookContext) -> None:
        """Primary first, then extras."""
        for stage in (self._primary, self._extras):
            await stage.after_iteration(context)

    def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
        """Feed the primary hook's output into the extras pipeline."""
        return self._extras.finalize_content(
            context,
            self._primary.finalize_content(context, content),
        )
class AgentLoop:
"""
The agent loop is the core processing engine.
@ -68,6 +182,7 @@ class AgentLoop:
mcp_servers: dict | None = None,
channels_config: ChannelsConfig | None = None,
timezone: str | None = None,
hooks: list[AgentHook] | None = None,
):
from nanobot.config.schema import ExecToolConfig, WebSearchConfig
@ -85,6 +200,7 @@ class AgentLoop:
self.restrict_to_workspace = restrict_to_workspace
self._start_time = time.time()
self._last_usage: dict[str, int] = {}
self._extra_hooks: list[AgentHook] = hooks or []
self.context = ContextBuilder(workspace, timezone=timezone)
self.sessions = session_manager or SessionManager(workspace)
@ -217,52 +333,27 @@ class AgentLoop:
``resuming=True`` means tool calls follow (spinner should restart);
``resuming=False`` means this is the final response.
"""
loop_self = self
class _LoopHook(AgentHook):
def __init__(self) -> None:
self._stream_buf = ""
def wants_streaming(self) -> bool:
return on_stream is not None
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
from nanobot.utils.helpers import strip_think
prev_clean = strip_think(self._stream_buf)
self._stream_buf += delta
new_clean = strip_think(self._stream_buf)
incremental = new_clean[len(prev_clean):]
if incremental and on_stream:
await on_stream(incremental)
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
if on_stream_end:
await on_stream_end(resuming=resuming)
self._stream_buf = ""
async def before_execute_tools(self, context: AgentHookContext) -> None:
if on_progress:
if not on_stream:
thought = loop_self._strip_think(context.response.content if context.response else None)
if thought:
await on_progress(thought)
tool_hint = loop_self._strip_think(loop_self._tool_hint(context.tool_calls))
await on_progress(tool_hint, tool_hint=True)
for tc in context.tool_calls:
args_str = json.dumps(tc.arguments, ensure_ascii=False)
logger.info("Tool call: {}({})", tc.name, args_str[:200])
loop_self._set_tool_context(channel, chat_id, message_id)
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
return loop_self._strip_think(content)
loop_hook = _LoopHook(
self,
on_progress=on_progress,
on_stream=on_stream,
on_stream_end=on_stream_end,
channel=channel,
chat_id=chat_id,
message_id=message_id,
)
hook: AgentHook = (
_LoopHookChain(loop_hook, self._extra_hooks)
if self._extra_hooks
else loop_hook
)
result = await self.runner.run(AgentRunSpec(
initial_messages=initial_messages,
tools=self.tools,
model=self.model,
max_iterations=self.max_iterations,
hook=_LoopHook(),
hook=hook,
error_message="Sorry, I encountered an error calling the AI model.",
concurrent_tools=True,
))
@ -321,25 +412,25 @@ class AgentLoop:
return f"{stream_base_id}:{stream_segment}"
async def on_stream(delta: str) -> None:
meta = dict(msg.metadata or {})
meta["_stream_delta"] = True
meta["_stream_id"] = _current_stream_id()
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content=delta,
metadata={
"_stream_delta": True,
"_stream_id": _current_stream_id(),
},
metadata=meta,
))
async def on_stream_end(*, resuming: bool = False) -> None:
nonlocal stream_segment
meta = dict(msg.metadata or {})
meta["_stream_end"] = True
meta["_resuming"] = resuming
meta["_stream_id"] = _current_stream_id()
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="",
metadata={
"_stream_end": True,
"_resuming": resuming,
"_stream_id": _current_stream_id(),
},
metadata=meta,
))
stream_segment += 1

View File

@ -60,7 +60,7 @@ class AgentRunner:
messages = list(spec.initial_messages)
final_content: str | None = None
tools_used: list[str] = []
usage = {"prompt_tokens": 0, "completion_tokens": 0}
usage: dict[str, int] = {}
error: str | None = None
stop_reason = "completed"
tool_events: list[dict[str, str]] = []
@ -92,13 +92,15 @@ class AgentRunner:
response = await self.provider.chat_with_retry(**kwargs)
raw_usage = response.usage or {}
usage = {
"prompt_tokens": int(raw_usage.get("prompt_tokens", 0) or 0),
"completion_tokens": int(raw_usage.get("completion_tokens", 0) or 0),
}
context.response = response
context.usage = usage
context.usage = raw_usage
context.tool_calls = list(response.tool_calls)
# Accumulate standard fields into result usage.
usage["prompt_tokens"] = usage.get("prompt_tokens", 0) + int(raw_usage.get("prompt_tokens", 0) or 0)
usage["completion_tokens"] = usage.get("completion_tokens", 0) + int(raw_usage.get("completion_tokens", 0) or 0)
cached = raw_usage.get("cached_tokens")
if cached:
usage["cached_tokens"] = usage.get("cached_tokens", 0) + int(cached)
if response.has_tool_calls:
if hook.wants_streaming():

View File

@ -21,6 +21,21 @@ from nanobot.config.schema import ExecToolConfig
from nanobot.providers.base import LLMProvider
class _SubagentHook(AgentHook):
    """Hook for subagent runs that only logs the tool calls about to execute."""

    def __init__(self, task_id: str) -> None:
        # Identifier included in every log line emitted by this hook.
        self._task_id = task_id

    async def before_execute_tools(self, context: AgentHookContext) -> None:
        """Emit one debug line per pending tool call, with JSON-encoded arguments."""
        for call in context.tool_calls:
            encoded_args = json.dumps(call.arguments, ensure_ascii=False)
            logger.debug(
                "Subagent [{}] executing: {} with arguments: {}",
                self._task_id,
                call.name,
                encoded_args,
            )
class SubagentManager:
"""Manages background subagent execution."""
@ -100,33 +115,28 @@ class SubagentManager:
tools.register(WriteFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
tools.register(EditFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir))
tools.register(ExecTool(
working_dir=str(self.workspace),
timeout=self.exec_config.timeout,
restrict_to_workspace=self.restrict_to_workspace,
path_append=self.exec_config.path_append,
))
if self.exec_config.enable:
tools.register(ExecTool(
working_dir=str(self.workspace),
timeout=self.exec_config.timeout,
restrict_to_workspace=self.restrict_to_workspace,
path_append=self.exec_config.path_append,
))
tools.register(WebSearchTool(config=self.web_search_config, proxy=self.web_proxy))
tools.register(WebFetchTool(proxy=self.web_proxy))
system_prompt = self._build_subagent_prompt()
messages: list[dict[str, Any]] = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": task},
]
class _SubagentHook(AgentHook):
async def before_execute_tools(self, context: AgentHookContext) -> None:
for tool_call in context.tool_calls:
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str)
result = await self.runner.run(AgentRunSpec(
initial_messages=messages,
tools=tools,
model=self.model,
max_iterations=15,
hook=_SubagentHook(),
hook=_SubagentHook(task_id),
max_iterations_message="Task completed but no final response was generated.",
error_message=None,
fail_on_tool_error=True,
@ -213,7 +223,7 @@ Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not men
lines.append("Failure:")
lines.append(f"- {result.error}")
return "\n".join(lines) or (result.error or "Error: subagent execution failed.")
def _build_subagent_prompt(self) -> str:
"""Build a focused system prompt for the subagent."""
from nanobot.agent.context import ContextBuilder

View File

@ -74,7 +74,7 @@ class CronTool(Tool):
"enum": ["add", "list", "remove"],
"description": "Action to perform",
},
"message": {"type": "string", "description": "Reminder message (for add)"},
"message": {"type": "string", "description": "Instruction for the agent to execute when the job triggers (e.g., 'Send a reminder to WeChat: xxx' or 'Check system status and report')"},
"every_seconds": {
"type": "integer",
"description": "Interval in seconds (for recurring tasks)",
@ -97,6 +97,11 @@ class CronTool(Tool):
f"(e.g. '2026-02-12T10:30:00'). Naive values default to {self._default_timezone}."
),
},
"deliver": {
"type": "boolean",
"description": "Whether to deliver the execution result to the user channel (default true)",
"default": True
},
"job_id": {"type": "string", "description": "Job ID (for remove)"},
},
"required": ["action"],
@ -111,12 +116,13 @@ class CronTool(Tool):
tz: str | None = None,
at: str | None = None,
job_id: str | None = None,
deliver: bool = True,
**kwargs: Any,
) -> str:
if action == "add":
if self._in_cron_context.get():
return "Error: cannot schedule new jobs from within a cron job execution"
return self._add_job(message, every_seconds, cron_expr, tz, at)
return self._add_job(message, every_seconds, cron_expr, tz, at, deliver)
elif action == "list":
return self._list_jobs()
elif action == "remove":
@ -130,6 +136,7 @@ class CronTool(Tool):
cron_expr: str | None,
tz: str | None,
at: str | None,
deliver: bool = True,
) -> str:
if not message:
return "Error: message is required for add"
@ -171,7 +178,7 @@ class CronTool(Tool):
name=message[:30],
schedule=schedule,
message=message,
deliver=True,
deliver=deliver,
channel=self._channel,
to=self._chat_id,
delete_after_run=delete_after,

View File

@ -170,7 +170,11 @@ async def connect_mcp_servers(
timeout: httpx.Timeout | None = None,
auth: httpx.Auth | None = None,
) -> httpx.AsyncClient:
merged_headers = {**(cfg.headers or {}), **(headers or {})}
merged_headers = {
"Accept": "application/json, text/event-stream",
**(cfg.headers or {}),
**(headers or {}),
}
return httpx.AsyncClient(
headers=merged_headers or None,
follow_redirects=True,

View File

@ -86,7 +86,15 @@ class MessageTool(Tool):
) -> str:
channel = channel or self._default_channel
chat_id = chat_id or self._default_chat_id
message_id = message_id or self._default_message_id
# Only inherit default message_id when targeting the same channel+chat.
# Cross-chat sends must not carry the original message_id, because
# some channels (e.g. Feishu) use it to determine the target
# conversation via their Reply API, which would route the message
# to the wrong chat entirely.
if channel == self._default_channel and chat_id == self._default_chat_id:
message_id = message_id or self._default_message_id
else:
message_id = None
if not channel or not chat_id:
return "Error: No target channel/chat specified"
@ -101,7 +109,7 @@ class MessageTool(Tool):
media=media or [],
metadata={
"message_id": message_id,
},
} if message_id else {},
)
try:

View File

@ -186,7 +186,9 @@ class ExecTool(Tool):
@staticmethod
def _extract_absolute_paths(command: str) -> list[str]:
win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]+", command) # Windows: C:\...
# Windows: match drive-root paths like `C:\` as well as `C:\path\to\file`
# NOTE: `*` is required so `C:\` (nothing after the slash) is still extracted.
win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]*", command)
posix_paths = re.findall(r"(?:^|[\s|>'\"])(/[^\s\"'>;|<]+)", command) # POSIX: /absolute only
home_paths = re.findall(r"(?:^|[\s|>'\"])(~[^\s\"'>;|<]*)", command) # POSIX/Windows home shortcut: ~
return win_paths + posix_paths + home_paths

1
nanobot/api/__init__.py Normal file
View File

@ -0,0 +1 @@
"""OpenAI-compatible HTTP API for nanobot."""

193
nanobot/api/server.py Normal file
View File

@ -0,0 +1,193 @@
"""OpenAI-compatible HTTP API server backed by a nanobot agent loop.
Provides /v1/chat/completions, /v1/models and /health endpoints.
Requests use the shared default API session unless the body carries a ``session_id``.
"""
from __future__ import annotations
import asyncio
import time
import uuid
from typing import Any
from aiohttp import web
from loguru import logger
API_SESSION_KEY = "api:default"
API_CHAT_ID = "default"
# ---------------------------------------------------------------------------
# Response helpers
# ---------------------------------------------------------------------------
def _error_json(status: int, message: str, err_type: str = "invalid_request_error") -> web.Response:
    """Build an OpenAI-style JSON error response.

    The HTTP status is used both as the response status and as the ``code``
    field inside the ``error`` object.
    """
    error_body = {"message": message, "type": err_type, "code": status}
    return web.json_response({"error": error_body}, status=status)
def _chat_completion_response(content: str, model: str) -> dict[str, Any]:
    """Wrap assistant text in a ``chat.completion`` payload.

    The payload has a single choice with ``finish_reason`` ``"stop"``; the
    usage counters are always reported as zero.
    """
    completion_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
    created_at = int(time.time())
    only_choice = {
        "index": 0,
        "message": {"role": "assistant", "content": content},
        "finish_reason": "stop",
    }
    zero_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
    return {
        "id": completion_id,
        "object": "chat.completion",
        "created": created_at,
        "model": model,
        "choices": [only_choice],
        "usage": zero_usage,
    }
def _response_text(value: Any) -> str:
    """Reduce a ``process_direct`` result to plain assistant text.

    ``None`` becomes the empty string.  Objects exposing a ``content``
    attribute yield that attribute (falsy content becomes ``""``).  Anything
    else is converted with ``str``.
    """
    if value is None:
        return ""
    if not hasattr(value, "content"):
        return str(value)
    return str(value.content or "")
# ---------------------------------------------------------------------------
# Route handlers
# ---------------------------------------------------------------------------
async def handle_chat_completions(request: web.Request) -> web.Response:
    """Handle ``POST /v1/chat/completions``.

    Accepts exactly one ``user`` message, runs it through the agent loop
    stored on the application and returns an OpenAI-style completion payload.
    Requests that resolve to the same session key are serialized with a
    per-session lock.

    Responses:
        400: the body is not a JSON object, it does not hold exactly one
            user message, ``stream`` is truthy, the message content has an
            unsupported type, or a ``model`` other than the configured one
            is requested.
        504: the agent did not answer within the configured timeout.
        500: any other failure while processing the request.
    """
    # --- Parse body ---
    try:
        body = await request.json()
    except Exception:
        return _error_json(400, "Invalid JSON body")
    # Valid JSON that is not an object (e.g. a list) has no ``.get``.
    if not isinstance(body, dict):
        return _error_json(400, "Request body must be a JSON object")

    messages = body.get("messages")
    if not isinstance(messages, list) or len(messages) != 1:
        return _error_json(400, "Only a single user message is supported")

    # Stream not yet supported
    if body.get("stream", False):
        return _error_json(400, "stream=true is not supported yet. Set stream=false or omit it.")

    message = messages[0]
    if not isinstance(message, dict) or message.get("role") != "user":
        return _error_json(400, "Only a single user message is supported")

    user_content = message.get("content", "")
    if user_content is None:
        # An explicit JSON null is treated like an omitted content field.
        user_content = ""
    elif isinstance(user_content, list):
        # Multi-modal content array: keep only well-formed text parts.
        user_content = " ".join(
            part["text"]
            for part in user_content
            if isinstance(part, dict)
            and part.get("type") == "text"
            and isinstance(part.get("text"), str)
        )
    elif not isinstance(user_content, str):
        return _error_json(400, "Message content must be a string or an array of content parts")

    agent_loop = request.app["agent_loop"]
    timeout_s: float = request.app.get("request_timeout", 120.0)
    model_name: str = request.app.get("model_name", "nanobot")

    if (requested_model := body.get("model")) and requested_model != model_name:
        return _error_json(400, f"Only configured model '{model_name}' is available")

    session_key = f"api:{body['session_id']}" if body.get("session_id") else API_SESSION_KEY
    session_locks: dict[str, asyncio.Lock] = request.app["session_locks"]
    # NOTE(review): this handler never removes entries from ``session_locks``,
    # so the mapping grows with every distinct session_id it sees.
    session_lock = session_locks.setdefault(session_key, asyncio.Lock())

    logger.info("API request session_key={} content={}", session_key, user_content[:80])

    _FALLBACK = "I've completed processing but have no response to give."

    async def _ask_agent() -> str:
        """Run one agent turn under the request timeout and return its text."""
        reply = await asyncio.wait_for(
            agent_loop.process_direct(
                content=user_content,
                session_key=session_key,
                channel="api",
                chat_id=API_CHAT_ID,
            ),
            timeout=timeout_s,
        )
        return _response_text(reply)

    try:
        async with session_lock:
            try:
                response_text = await _ask_agent()

                # A blank answer gets exactly one retry before the fallback text.
                if not response_text or not response_text.strip():
                    logger.warning(
                        "Empty response for session {}, retrying",
                        session_key,
                    )
                    response_text = await _ask_agent()
                    if not response_text or not response_text.strip():
                        logger.warning(
                            "Empty response after retry for session {}, using fallback",
                            session_key,
                        )
                        response_text = _FALLBACK

            except asyncio.TimeoutError:
                return _error_json(504, f"Request timed out after {timeout_s}s")
            except Exception:
                logger.exception("Error processing request for session {}", session_key)
                return _error_json(500, "Internal server error", err_type="server_error")
    except Exception:
        logger.exception("Unexpected API lock error for session {}", session_key)
        return _error_json(500, "Internal server error", err_type="server_error")

    return web.json_response(_chat_completion_response(response_text, model_name))
async def handle_models(request: web.Request) -> web.Response:
    """Handle ``GET /v1/models`` by listing the single configured model."""
    configured = request.app.get("model_name", "nanobot")
    model_entry = {
        "id": configured,
        "object": "model",
        "created": 0,
        "owned_by": "nanobot",
    }
    return web.json_response({"object": "list", "data": [model_entry]})
async def handle_health(request: web.Request) -> web.Response:
    """Handle ``GET /health`` with a constant ``{"status": "ok"}`` body."""
    payload = {"status": "ok"}
    return web.json_response(payload)
# ---------------------------------------------------------------------------
# App factory
# ---------------------------------------------------------------------------
def create_app(agent_loop, model_name: str = "nanobot", request_timeout: float = 120.0) -> web.Application:
    """Build the aiohttp application that serves the API routes.

    Args:
        agent_loop: An initialized AgentLoop instance.
        model_name: Model name reported in responses.
        request_timeout: Per-request timeout in seconds.

    Returns:
        The application with its shared state stored as mapping entries and
        the chat-completions, models and health routes registered.
    """
    application = web.Application()

    # Shared state that handlers read back from ``request.app``.
    shared_state = {
        "agent_loop": agent_loop,
        "model_name": model_name,
        "request_timeout": request_timeout,
        "session_locks": {},  # per-user locks, keyed by session_key
    }
    for state_key, state_value in shared_state.items():
        application[state_key] = state_value

    router = application.router
    router.add_post("/v1/chat/completions", handle_chat_completions)
    router.add_get("/v1/models", handle_models)
    router.add_get("/health", handle_health)
    return application

View File

@ -1,25 +1,37 @@
"""Discord channel implementation using Discord Gateway websocket."""
"""Discord channel implementation using discord.py."""
from __future__ import annotations
import asyncio
import json
import importlib.util
from pathlib import Path
from typing import Any, Literal
from typing import TYPE_CHECKING, Any, Literal
import httpx
from pydantic import Field
import websockets
from loguru import logger
from pydantic import Field
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.command.builtin import build_help_text
from nanobot.config.paths import get_media_dir
from nanobot.config.schema import Base
from nanobot.utils.helpers import split_message
from nanobot.utils.helpers import safe_filename, split_message
DISCORD_AVAILABLE = importlib.util.find_spec("discord") is not None
if TYPE_CHECKING:
import discord
from discord import app_commands
from discord.abc import Messageable
if DISCORD_AVAILABLE:
import discord
from discord import app_commands
from discord.abc import Messageable
DISCORD_API_BASE = "https://discord.com/api/v10"
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
MAX_MESSAGE_LEN = 2000 # Discord message character limit
TYPING_INTERVAL_S = 8
class DiscordConfig(Base):
@ -28,13 +40,205 @@ class DiscordConfig(Base):
enabled: bool = False
token: str = ""
allow_from: list[str] = Field(default_factory=list)
gateway_url: str = "wss://gateway.discord.gg/?v=10&encoding=json"
intents: int = 37377
group_policy: Literal["mention", "open"] = "mention"
read_receipt_emoji: str = "👀"
working_emoji: str = "🔧"
working_emoji_delay: float = 2.0
if DISCORD_AVAILABLE:

    class DiscordBotClient(discord.Client):
        """discord.py client that forwards events to the channel.

        Only defined when discord.py is importable, because it subclasses
        ``discord.Client``. Inbound messages and slash commands are handed to
        the owning ``DiscordChannel``; outbound messages are sent through
        :meth:`send_outbound`.
        """

        def __init__(self, channel: DiscordChannel, *, intents: discord.Intents) -> None:
            """Bind the client to ``channel`` and register the slash commands."""
            super().__init__(intents=intents)
            self._channel = channel
            self.tree = app_commands.CommandTree(self)
            self._register_app_commands()

        async def on_ready(self) -> None:
            """Record the bot user id on the channel and sync app commands."""
            self._channel._bot_user_id = str(self.user.id) if self.user else None
            logger.info("Discord bot connected as user {}", self._channel._bot_user_id)
            try:
                synced = await self.tree.sync()
                logger.info("Discord app commands synced: {}", len(synced))
            except Exception as e:
                # A failed sync is logged only; on_ready still completes.
                logger.warning("Discord app command sync failed: {}", e)

        async def on_message(self, message: discord.Message) -> None:
            """Forward every gateway message to the channel's handler."""
            await self._channel._handle_discord_message(message)

        async def _reply_ephemeral(self, interaction: discord.Interaction, text: str) -> bool:
            """Send an ephemeral interaction response and report success."""
            try:
                await interaction.response.send_message(text, ephemeral=True)
                return True
            except Exception as e:
                logger.warning("Discord interaction response failed: {}", e)
                return False

        async def _forward_slash_command(
            self,
            interaction: discord.Interaction,
            command_text: str,
        ) -> None:
            """Acknowledge a slash command and forward it as a text command.

            Args:
                interaction: The interaction that triggered the command.
                command_text: Text (e.g. ``"/new"``) passed to the channel as
                    the message content.
            """
            sender_id = str(interaction.user.id)
            channel_id = interaction.channel_id

            if channel_id is None:
                logger.warning("Discord slash command missing channel_id: {}", command_text)
                # Still acknowledge, so the interaction is not left unanswered.
                await self._reply_ephemeral(interaction, "This command must be used in a channel.")
                return

            if not self._channel.is_allowed(sender_id):
                await self._reply_ephemeral(interaction, "You are not allowed to use this bot.")
                return

            await self._reply_ephemeral(interaction, f"Processing {command_text}...")
            await self._channel._handle_message(
                sender_id=sender_id,
                chat_id=str(channel_id),
                content=command_text,
                metadata={
                    "interaction_id": str(interaction.id),
                    "guild_id": str(interaction.guild_id) if interaction.guild_id else None,
                    "is_slash_command": True,
                },
            )

        def _register_app_commands(self) -> None:
            """Register the slash commands and the command error handler on ``self.tree``."""
            commands = (
                ("new", "Start a new conversation", "/new"),
                ("stop", "Stop the current task", "/stop"),
                ("restart", "Restart the bot", "/restart"),
                ("status", "Show bot status", "/status"),
            )

            def make_forwarder(command_text: str):
                # Bind command_text through a closure rather than a defaulted
                # callback parameter: discord.py turns every callback parameter
                # after ``interaction`` into a user-visible command option,
                # which would let users override the forwarded text.
                async def command_handler(interaction: discord.Interaction) -> None:
                    await self._forward_slash_command(interaction, command_text)

                return command_handler

            for name, description, command_text in commands:
                self.tree.command(name=name, description=description)(make_forwarder(command_text))

            @self.tree.command(name="help", description="Show available commands")
            async def help_command(interaction: discord.Interaction) -> None:
                sender_id = str(interaction.user.id)
                if not self._channel.is_allowed(sender_id):
                    await self._reply_ephemeral(interaction, "You are not allowed to use this bot.")
                    return
                await self._reply_ephemeral(interaction, build_help_text())

            @self.tree.error
            async def on_app_command_error(
                interaction: discord.Interaction,
                error: app_commands.AppCommandError,
            ) -> None:
                command_name = interaction.command.qualified_name if interaction.command else "?"
                logger.warning(
                    "Discord app command failed user={} channel={} cmd={} error={}",
                    interaction.user.id,
                    interaction.channel_id,
                    command_name,
                    error,
                )

        async def send_outbound(self, msg: OutboundMessage) -> None:
            """Send a nanobot outbound message using Discord transport rules.

            Attachments go first, one message each. The text follows, split
            into chunks of at most ``MAX_MESSAGE_LEN`` characters. The reply
            reference is attached to the first attachment, or to the first
            text chunk when no attachment was sent.
            """
            channel_id = int(msg.chat_id)
            channel = self.get_channel(channel_id)
            if channel is None:
                # Not in the local cache: fall back to an API fetch.
                try:
                    channel = await self.fetch_channel(channel_id)
                except Exception as e:
                    logger.warning("Discord channel {} unavailable: {}", msg.chat_id, e)
                    return

            reference, mention_settings = self._build_reply_context(channel, msg.reply_to)
            sent_media = False
            failed_media: list[str] = []

            for index, media_path in enumerate(msg.media or []):
                if await self._send_file(
                    channel,
                    media_path,
                    reference=reference if index == 0 else None,
                    mention_settings=mention_settings,
                ):
                    sent_media = True
                else:
                    failed_media.append(Path(media_path).name)

            for index, chunk in enumerate(self._build_chunks(msg.content or "", failed_media, sent_media)):
                kwargs: dict[str, Any] = {"content": chunk}
                if index == 0 and reference is not None and not sent_media:
                    kwargs["reference"] = reference
                    kwargs["allowed_mentions"] = mention_settings
                await channel.send(**kwargs)

        async def _send_file(
            self,
            channel: Messageable,
            file_path: str,
            *,
            reference: discord.PartialMessage | None,
            mention_settings: discord.AllowedMentions,
        ) -> bool:
            """Send a file attachment via discord.py.

            Returns:
                True when the file was sent; False when it is missing, larger
                than ``MAX_ATTACHMENT_BYTES``, or the send raised.
            """
            path = Path(file_path)
            if not path.is_file():
                logger.warning("Discord file not found, skipping: {}", file_path)
                return False

            if path.stat().st_size > MAX_ATTACHMENT_BYTES:
                logger.warning("Discord file too large (>20MB), skipping: {}", path.name)
                return False

            try:
                kwargs: dict[str, Any] = {"file": discord.File(path)}
                if reference is not None:
                    kwargs["reference"] = reference
                    kwargs["allowed_mentions"] = mention_settings
                await channel.send(**kwargs)
                logger.info("Discord file sent: {}", path.name)
                return True
            except Exception as e:
                logger.error("Error sending Discord file {}: {}", path.name, e)
                return False

        @staticmethod
        def _build_chunks(content: str, failed_media: list[str], sent_media: bool) -> list[str]:
            """Build outbound text chunks, including attachment-failure fallback text.

            The fallback listing failed attachments is used only when there is
            no text to send, at least one attachment failed, and none succeeded.
            """
            chunks = split_message(content, MAX_MESSAGE_LEN)
            if chunks or not failed_media or sent_media:
                return chunks
            fallback = "\n".join(f"[attachment: {name} - send failed]" for name in failed_media)
            return split_message(fallback, MAX_MESSAGE_LEN)

        @staticmethod
        def _build_reply_context(
            channel: Messageable,
            reply_to: str | None,
        ) -> tuple[discord.PartialMessage | None, discord.AllowedMentions]:
            """Build reply context for outbound messages.

            Returns:
                ``(reference, mention_settings)``. ``reference`` is None when
                ``reply_to`` is empty or not an integer id. The mention
                settings are created with ``replied_user=False``.
            """
            mention_settings = discord.AllowedMentions(replied_user=False)
            if not reply_to:
                return None, mention_settings
            try:
                message_id = int(reply_to)
            except ValueError:
                logger.warning("Invalid Discord reply target: {}", reply_to)
                return None, mention_settings

            return channel.get_partial_message(message_id), mention_settings
class DiscordChannel(BaseChannel):
"""Discord channel using Gateway websocket."""
"""Discord channel using discord.py."""
name = "discord"
display_name = "Discord"
@ -43,353 +247,270 @@ class DiscordChannel(BaseChannel):
def default_config(cls) -> dict[str, Any]:
return DiscordConfig().model_dump(by_alias=True)
@staticmethod
def _channel_key(channel_or_id: Any) -> str:
"""Normalize channel-like objects and ids to a stable string key."""
channel_id = getattr(channel_or_id, "id", channel_or_id)
return str(channel_id)
def __init__(self, config: Any, bus: MessageBus):
if isinstance(config, dict):
config = DiscordConfig.model_validate(config)
super().__init__(config, bus)
self.config: DiscordConfig = config
self._ws: websockets.WebSocketClientProtocol | None = None
self._seq: int | None = None
self._heartbeat_task: asyncio.Task | None = None
self._typing_tasks: dict[str, asyncio.Task] = {}
self._http: httpx.AsyncClient | None = None
self._client: DiscordBotClient | None = None
self._typing_tasks: dict[str, asyncio.Task[None]] = {}
self._bot_user_id: str | None = None
self._pending_reactions: dict[str, Any] = {} # chat_id -> message object
self._working_emoji_tasks: dict[str, asyncio.Task[None]] = {}
async def start(self) -> None:
"""Start the Discord gateway connection."""
"""Start the Discord client."""
if not DISCORD_AVAILABLE:
logger.error("discord.py not installed. Run: pip install nanobot-ai[discord]")
return
if not self.config.token:
logger.error("Discord bot token not configured")
return
self._running = True
self._http = httpx.AsyncClient(timeout=30.0)
try:
intents = discord.Intents.none()
intents.value = self.config.intents
self._client = DiscordBotClient(self, intents=intents)
except Exception as e:
logger.error("Failed to initialize Discord client: {}", e)
self._client = None
self._running = False
return
while self._running:
try:
logger.info("Connecting to Discord gateway...")
async with websockets.connect(self.config.gateway_url) as ws:
self._ws = ws
await self._gateway_loop()
except asyncio.CancelledError:
break
except Exception as e:
logger.warning("Discord gateway error: {}", e)
if self._running:
logger.info("Reconnecting to Discord gateway in 5 seconds...")
await asyncio.sleep(5)
self._running = True
logger.info("Starting Discord client via discord.py...")
try:
await self._client.start(self.config.token)
except asyncio.CancelledError:
raise
except Exception as e:
logger.error("Discord client startup failed: {}", e)
finally:
self._running = False
await self._reset_runtime_state(close_client=True)
async def stop(self) -> None:
"""Stop the Discord channel."""
self._running = False
if self._heartbeat_task:
self._heartbeat_task.cancel()
self._heartbeat_task = None
for task in self._typing_tasks.values():
task.cancel()
self._typing_tasks.clear()
if self._ws:
await self._ws.close()
self._ws = None
if self._http:
await self._http.aclose()
self._http = None
await self._reset_runtime_state(close_client=True)
async def send(self, msg: OutboundMessage) -> None:
"""Send a message through Discord REST API, including file attachments."""
if not self._http:
logger.warning("Discord HTTP client not initialized")
"""Send a message through Discord using discord.py."""
client = self._client
if client is None or not client.is_ready():
logger.warning("Discord client not ready; dropping outbound message")
return
url = f"{DISCORD_API_BASE}/channels/{msg.chat_id}/messages"
headers = {"Authorization": f"Bot {self.config.token}"}
is_progress = bool((msg.metadata or {}).get("_progress"))
try:
sent_media = False
failed_media: list[str] = []
# Send file attachments first
for media_path in msg.media or []:
if await self._send_file(url, headers, media_path, reply_to=msg.reply_to):
sent_media = True
else:
failed_media.append(Path(media_path).name)
# Send text content
chunks = split_message(msg.content or "", MAX_MESSAGE_LEN)
if not chunks and failed_media and not sent_media:
chunks = split_message(
"\n".join(f"[attachment: {name} - send failed]" for name in failed_media),
MAX_MESSAGE_LEN,
)
if not chunks:
return
for i, chunk in enumerate(chunks):
payload: dict[str, Any] = {"content": chunk}
# Let the first successful attachment carry the reply if present.
if i == 0 and msg.reply_to and not sent_media:
payload["message_reference"] = {"message_id": msg.reply_to}
payload["allowed_mentions"] = {"replied_user": False}
if not await self._send_payload(url, headers, payload):
break # Abort remaining chunks on failure
await client.send_outbound(msg)
except Exception as e:
logger.error("Error sending Discord message: {}", e)
finally:
await self._stop_typing(msg.chat_id)
if not is_progress:
await self._stop_typing(msg.chat_id)
await self._clear_reactions(msg.chat_id)
async def _send_payload(
self, url: str, headers: dict[str, str], payload: dict[str, Any]
) -> bool:
"""Send a single Discord API payload with retry on rate-limit. Returns True on success."""
for attempt in range(3):
async def _handle_discord_message(self, message: discord.Message) -> None:
"""Handle incoming Discord messages from discord.py."""
if message.author.bot:
return
sender_id = str(message.author.id)
channel_id = self._channel_key(message.channel)
content = message.content or ""
if not self._should_accept_inbound(message, sender_id, content):
return
media_paths, attachment_markers = await self._download_attachments(message.attachments)
full_content = self._compose_inbound_content(content, attachment_markers)
metadata = self._build_inbound_metadata(message)
await self._start_typing(message.channel)
# Add read receipt reaction immediately, working emoji after delay
channel_id = self._channel_key(message.channel)
try:
await message.add_reaction(self.config.read_receipt_emoji)
self._pending_reactions[channel_id] = message
except Exception as e:
logger.debug("Failed to add read receipt reaction: {}", e)
# Delayed working indicator (cosmetic — not tied to subagent lifecycle)
async def _delayed_working_emoji() -> None:
await asyncio.sleep(self.config.working_emoji_delay)
try:
response = await self._http.post(url, headers=headers, json=payload)
if response.status_code == 429:
data = response.json()
retry_after = float(data.get("retry_after", 1.0))
logger.warning("Discord rate limited, retrying in {}s", retry_after)
await asyncio.sleep(retry_after)
continue
response.raise_for_status()
return True
except Exception as e:
if attempt == 2:
logger.error("Error sending Discord message: {}", e)
else:
await asyncio.sleep(1)
return False
await message.add_reaction(self.config.working_emoji)
except Exception:
pass
async def _send_file(
self._working_emoji_tasks[channel_id] = asyncio.create_task(_delayed_working_emoji())
try:
await self._handle_message(
sender_id=sender_id,
chat_id=channel_id,
content=full_content,
media=media_paths,
metadata=metadata,
)
except Exception:
await self._clear_reactions(channel_id)
await self._stop_typing(channel_id)
raise
async def _on_message(self, message: discord.Message) -> None:
    """Legacy alias kept for older tests/callers.

    Delegates to ``_handle_discord_message``.
    """
    handler = self._handle_discord_message
    await handler(message)
def _should_accept_inbound(
self,
url: str,
headers: dict[str, str],
file_path: str,
reply_to: str | None = None,
message: discord.Message,
sender_id: str,
content: str,
) -> bool:
"""Send a file attachment via Discord REST API using multipart/form-data."""
path = Path(file_path)
if not path.is_file():
logger.warning("Discord file not found, skipping: {}", file_path)
return False
if path.stat().st_size > MAX_ATTACHMENT_BYTES:
logger.warning("Discord file too large (>20MB), skipping: {}", path.name)
return False
payload_json: dict[str, Any] = {}
if reply_to:
payload_json["message_reference"] = {"message_id": reply_to}
payload_json["allowed_mentions"] = {"replied_user": False}
for attempt in range(3):
try:
with open(path, "rb") as f:
files = {"files[0]": (path.name, f, "application/octet-stream")}
data: dict[str, Any] = {}
if payload_json:
data["payload_json"] = json.dumps(payload_json)
response = await self._http.post(
url, headers=headers, files=files, data=data
)
if response.status_code == 429:
resp_data = response.json()
retry_after = float(resp_data.get("retry_after", 1.0))
logger.warning("Discord rate limited, retrying in {}s", retry_after)
await asyncio.sleep(retry_after)
continue
response.raise_for_status()
logger.info("Discord file sent: {}", path.name)
return True
except Exception as e:
if attempt == 2:
logger.error("Error sending Discord file {}: {}", path.name, e)
else:
await asyncio.sleep(1)
return False
async def _gateway_loop(self) -> None:
"""Main gateway loop: identify, heartbeat, dispatch events."""
if not self._ws:
return
async for raw in self._ws:
try:
data = json.loads(raw)
except json.JSONDecodeError:
logger.warning("Invalid JSON from Discord gateway: {}", raw[:100])
continue
op = data.get("op")
event_type = data.get("t")
seq = data.get("s")
payload = data.get("d")
if seq is not None:
self._seq = seq
if op == 10:
# HELLO: start heartbeat and identify
interval_ms = payload.get("heartbeat_interval", 45000)
await self._start_heartbeat(interval_ms / 1000)
await self._identify()
elif op == 0 and event_type == "READY":
logger.info("Discord gateway READY")
# Capture bot user ID for mention detection
user_data = payload.get("user") or {}
self._bot_user_id = user_data.get("id")
logger.info("Discord bot connected as user {}", self._bot_user_id)
elif op == 0 and event_type == "MESSAGE_CREATE":
await self._handle_message_create(payload)
elif op == 7:
# RECONNECT: exit loop to reconnect
logger.info("Discord gateway requested reconnect")
break
elif op == 9:
# INVALID_SESSION: reconnect
logger.warning("Discord gateway invalid session")
break
async def _identify(self) -> None:
"""Send IDENTIFY payload."""
if not self._ws:
return
identify = {
"op": 2,
"d": {
"token": self.config.token,
"intents": self.config.intents,
"properties": {
"os": "nanobot",
"browser": "nanobot",
"device": "nanobot",
},
},
}
await self._ws.send(json.dumps(identify))
async def _start_heartbeat(self, interval_s: float) -> None:
"""Start or restart the heartbeat loop."""
if self._heartbeat_task:
self._heartbeat_task.cancel()
async def heartbeat_loop() -> None:
while self._running and self._ws:
payload = {"op": 1, "d": self._seq}
try:
await self._ws.send(json.dumps(payload))
except Exception as e:
logger.warning("Discord heartbeat failed: {}", e)
break
await asyncio.sleep(interval_s)
self._heartbeat_task = asyncio.create_task(heartbeat_loop())
async def _handle_message_create(self, payload: dict[str, Any]) -> None:
"""Handle incoming Discord messages."""
author = payload.get("author") or {}
if author.get("bot"):
return
sender_id = str(author.get("id", ""))
channel_id = str(payload.get("channel_id", ""))
content = payload.get("content") or ""
guild_id = payload.get("guild_id")
if not sender_id or not channel_id:
return
"""Check if inbound Discord message should be processed."""
if not self.is_allowed(sender_id):
return
return False
if message.guild is not None and not self._should_respond_in_group(message, content):
return False
return True
# Check group channel policy (DMs always respond if is_allowed passes)
if guild_id is not None:
if not self._should_respond_in_group(payload, content):
return
content_parts = [content] if content else []
async def _download_attachments(
self,
attachments: list[discord.Attachment],
) -> tuple[list[str], list[str]]:
"""Download supported attachments and return paths + display markers."""
media_paths: list[str] = []
markers: list[str] = []
media_dir = get_media_dir("discord")
for attachment in payload.get("attachments") or []:
url = attachment.get("url")
filename = attachment.get("filename") or "attachment"
size = attachment.get("size") or 0
if not url or not self._http:
continue
if size and size > MAX_ATTACHMENT_BYTES:
content_parts.append(f"[attachment: {filename} - too large]")
for attachment in attachments:
filename = attachment.filename or "attachment"
if attachment.size and attachment.size > MAX_ATTACHMENT_BYTES:
markers.append(f"[attachment: {filename} - too large]")
continue
try:
media_dir.mkdir(parents=True, exist_ok=True)
file_path = media_dir / f"{attachment.get('id', 'file')}_{filename.replace('/', '_')}"
resp = await self._http.get(url)
resp.raise_for_status()
file_path.write_bytes(resp.content)
safe_name = safe_filename(filename)
file_path = media_dir / f"{attachment.id}_{safe_name}"
await attachment.save(file_path)
media_paths.append(str(file_path))
content_parts.append(f"[attachment: {file_path}]")
markers.append(f"[attachment: {file_path.name}]")
except Exception as e:
logger.warning("Failed to download Discord attachment: {}", e)
content_parts.append(f"[attachment: {filename} - download failed]")
markers.append(f"[attachment: {filename} - download failed]")
reply_to = (payload.get("referenced_message") or {}).get("id")
return media_paths, markers
await self._start_typing(channel_id)
@staticmethod
def _compose_inbound_content(content: str, attachment_markers: list[str]) -> str:
"""Combine message text with attachment markers."""
content_parts = [content] if content else []
content_parts.extend(attachment_markers)
return "\n".join(part for part in content_parts if part) or "[empty message]"
await self._handle_message(
sender_id=sender_id,
chat_id=channel_id,
content="\n".join(p for p in content_parts if p) or "[empty message]",
media=media_paths,
metadata={
"message_id": str(payload.get("id", "")),
"guild_id": guild_id,
"reply_to": reply_to,
},
)
@staticmethod
def _build_inbound_metadata(message: discord.Message) -> dict[str, str | None]:
"""Build metadata for inbound Discord messages."""
reply_to = str(message.reference.message_id) if message.reference and message.reference.message_id else None
return {
"message_id": str(message.id),
"guild_id": str(message.guild.id) if message.guild else None,
"reply_to": reply_to,
}
def _should_respond_in_group(self, payload: dict[str, Any], content: str) -> bool:
"""Check if bot should respond in a group channel based on policy."""
def _should_respond_in_group(self, message: discord.Message, content: str) -> bool:
"""Check if the bot should respond in a guild channel based on policy."""
if self.config.group_policy == "open":
return True
if self.config.group_policy == "mention":
# Check if bot was mentioned in the message
if self._bot_user_id:
# Check mentions array
mentions = payload.get("mentions") or []
for mention in mentions:
if str(mention.get("id")) == self._bot_user_id:
return True
# Also check content for mention format <@USER_ID>
if f"<@{self._bot_user_id}>" in content or f"<@!{self._bot_user_id}>" in content:
return True
logger.debug("Discord message in {} ignored (bot not mentioned)", payload.get("channel_id"))
bot_user_id = self._bot_user_id
if bot_user_id is None:
logger.debug("Discord message in {} ignored (bot identity unavailable)", message.channel.id)
return False
if any(str(user.id) == bot_user_id for user in message.mentions):
return True
if f"<@{bot_user_id}>" in content or f"<@!{bot_user_id}>" in content:
return True
logger.debug("Discord message in {} ignored (bot not mentioned)", message.channel.id)
return False
return True
async def _start_typing(self, channel_id: str) -> None:
async def _start_typing(self, channel: Messageable) -> None:
"""Start periodic typing indicator for a channel."""
channel_id = self._channel_key(channel)
await self._stop_typing(channel_id)
async def typing_loop() -> None:
url = f"{DISCORD_API_BASE}/channels/{channel_id}/typing"
headers = {"Authorization": f"Bot {self.config.token}"}
while self._running:
try:
await self._http.post(url, headers=headers)
async with channel.typing():
await asyncio.sleep(TYPING_INTERVAL_S)
except asyncio.CancelledError:
return
except Exception as e:
logger.debug("Discord typing indicator failed for {}: {}", channel_id, e)
return
await asyncio.sleep(8)
self._typing_tasks[channel_id] = asyncio.create_task(typing_loop())
async def _stop_typing(self, channel_id: str) -> None:
"""Stop typing indicator for a channel."""
task = self._typing_tasks.pop(channel_id, None)
if task:
task = self._typing_tasks.pop(self._channel_key(channel_id), None)
if task is None:
return
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
async def _clear_reactions(self, chat_id: str) -> None:
"""Remove all pending reactions after bot replies."""
# Cancel delayed working emoji if it hasn't fired yet
task = self._working_emoji_tasks.pop(chat_id, None)
if task and not task.done():
task.cancel()
msg_obj = self._pending_reactions.pop(chat_id, None)
if msg_obj is None:
return
bot_user = self._client.user if self._client else None
for emoji in (self.config.read_receipt_emoji, self.config.working_emoji):
try:
await msg_obj.remove_reaction(emoji, bot_user)
except Exception:
pass
async def _cancel_all_typing(self) -> None:
"""Stop all typing tasks."""
channel_ids = list(self._typing_tasks)
for channel_id in channel_ids:
await self._stop_typing(channel_id)
async def _reset_runtime_state(self, close_client: bool) -> None:
"""Reset client and typing state."""
await self._cancel_all_typing()
if close_client and self._client is not None and not self._client.is_closed():
try:
await self._client.close()
except Exception as e:
logger.warning("Discord client close failed: {}", e)
self._client = None
self._bot_user_id = None

View File

@ -3,6 +3,8 @@
import asyncio
import logging
import mimetypes
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, TypeAlias
@ -28,8 +30,8 @@ try:
RoomSendError,
RoomTypingError,
SyncError,
UploadError,
)
UploadError, RoomSendResponse,
)
from nio.crypto.attachments import decrypt_attachment
from nio.exceptions import EncryptionError
except ImportError as e:
@ -97,6 +99,22 @@ MATRIX_HTML_CLEANER = nh3.Cleaner(
link_rel="noopener noreferrer",
)
@dataclass
class _StreamBuf:
"""
Represents a buffer for managing LLM response stream data.
:ivar text: Stores the text content of the buffer.
:type text: str
:ivar event_id: Identifier for the associated event. None indicates no
specific event association.
:type event_id: str | None
:ivar last_edit: Timestamp of the most recent edit to the buffer.
:type last_edit: float
"""
text: str = ""
event_id: str | None = None
last_edit: float = 0.0
def _render_markdown_html(text: str) -> str | None:
"""Render markdown to sanitized HTML; returns None for plain text."""
@ -114,12 +132,47 @@ def _render_markdown_html(text: str) -> str | None:
return formatted
def _build_matrix_text_content(text: str) -> dict[str, object]:
"""Build Matrix m.text payload with optional HTML formatted_body."""
def _build_matrix_text_content(
    text: str,
    event_id: str | None = None,
    thread_relates_to: dict[str, object] | None = None,
) -> dict[str, object]:
    """Build a Matrix ``m.text`` payload, optionally as an edit of ``event_id``.

    The payload always carries the plain ``body`` and an empty ``m.mentions``.
    ``format``/``formatted_body`` are added when the markdown renderer returns
    HTML.

    :param text: The plain text content of the message.
    :param event_id: Optional id of the event to replace. When given, the
        payload gets an ``m.replace`` relation and an ``m.new_content`` section
        holding the replacement content.
    :param thread_relates_to: Optional Matrix thread relation metadata. For
        edits it is stored inside ``m.new_content``; for new messages it
        becomes the payload's ``m.relates_to``.
    :return: The content dictionary to send as ``m.room.message``.
    """
    content: dict[str, object] = {"msgtype": "m.text", "body": text, "m.mentions": {}}
    html = _render_markdown_html(text)
    if html:
        content["format"] = MATRIX_HTML_FORMAT
        content["formatted_body"] = html

    if event_id:
        new_content: dict[str, object] = {
            "body": text,
            "msgtype": "m.text",
        }
        if html:
            # m.new_content is the replacement content, so it needs the HTML
            # rendering too; otherwise the edit carries plain text only.
            new_content["format"] = MATRIX_HTML_FORMAT
            new_content["formatted_body"] = html
        if thread_relates_to:
            new_content["m.relates_to"] = thread_relates_to
        content["m.new_content"] = new_content
        content["m.relates_to"] = {
            "rel_type": "m.replace",
            "event_id": event_id,
        }
    elif thread_relates_to:
        content["m.relates_to"] = thread_relates_to

    return content
@ -159,7 +212,8 @@ class MatrixConfig(Base):
allow_from: list[str] = Field(default_factory=list)
group_policy: Literal["open", "mention", "allowlist"] = "open"
group_allow_from: list[str] = Field(default_factory=list)
allow_room_mentions: bool = False
# No trailing comma here: `False,` would make the default the tuple (False,),
# which is truthy.
allow_room_mentions: bool = False
# NOTE(review): presumably toggles streamed (edit-in-place) replies - confirm against the caller.
streaming: bool = False
class MatrixChannel(BaseChannel):
@ -167,6 +221,8 @@ class MatrixChannel(BaseChannel):
name = "matrix"
display_name = "Matrix"
_STREAM_EDIT_INTERVAL = 2 # min seconds between edit_message_text calls
monotonic_time = time.monotonic
@classmethod
def default_config(cls) -> dict[str, Any]:
@ -192,6 +248,8 @@ class MatrixChannel(BaseChannel):
)
self._server_upload_limit_bytes: int | None = None
self._server_upload_limit_checked = False
self._stream_bufs: dict[str, _StreamBuf] = {}
async def start(self) -> None:
"""Start Matrix client and begin sync loop."""
@ -297,14 +355,17 @@ class MatrixChannel(BaseChannel):
room = getattr(self.client, "rooms", {}).get(room_id)
return bool(getattr(room, "encrypted", False))
async def _send_room_content(self, room_id: str, content: dict[str, Any]) -> None:
async def _send_room_content(self, room_id: str,
content: dict[str, Any]) -> None | RoomSendResponse | RoomSendError:
"""Send m.room.message with E2EE options."""
if not self.client:
return
return None
kwargs: dict[str, Any] = {"room_id": room_id, "message_type": "m.room.message", "content": content}
if self.config.e2ee_enabled:
kwargs["ignore_unverified_devices"] = True
await self.client.room_send(**kwargs)
response = await self.client.room_send(**kwargs)
return response
async def _resolve_server_upload_limit_bytes(self) -> int | None:
"""Query homeserver upload limit once per channel lifecycle."""
@ -414,6 +475,53 @@ class MatrixChannel(BaseChannel):
if not is_progress:
await self._stop_typing_keepalive(msg.chat_id, clear_typing=True)
async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
    """Stream a partial response into one Matrix message that is edited in place.

    Deltas accumulate per ``chat_id`` in ``self._stream_bufs``. The first
    flush sends a new message and remembers its event id; later flushes
    replace that event and are throttled to one per ``_STREAM_EDIT_INTERVAL``
    seconds. A truthy ``_stream_end`` in ``metadata`` flushes the final text
    and discards the buffer.

    Args:
        chat_id: Room id the stream belongs to.
        delta: Text to append to the buffered response.
        metadata: Optional message metadata. ``_stream_end`` marks the end of
            the stream, and the whole dict is passed to
            ``_build_thread_relates_to``.
    """
    meta = metadata or {}
    relates_to = self._build_thread_relates_to(metadata)

    if meta.get("_stream_end"):
        buf = self._stream_bufs.pop(chat_id, None)
        if not buf or not buf.text:
            return
        if not buf.event_id and not buf.text.strip():
            # Only whitespace was streamed and nothing was ever sent.
            return
        await self._stop_typing_keepalive(chat_id, clear_typing=True)
        # With no event id (no earlier flush succeeded) this sends a fresh
        # message instead of dropping the final text.
        content = _build_matrix_text_content(
            buf.text,
            buf.event_id,
            thread_relates_to=relates_to,
        )
        await self._send_room_content(chat_id, content)
        return

    buf = self._stream_bufs.get(chat_id)
    if buf is None:
        buf = _StreamBuf()
        self._stream_bufs[chat_id] = buf
    buf.text += delta

    if not buf.text.strip():
        return

    now = self.monotonic_time()
    # Throttle: skip the flush while the previous one is recent enough.
    if buf.last_edit and (now - buf.last_edit) < self._STREAM_EDIT_INTERVAL:
        return

    try:
        content = _build_matrix_text_content(
            buf.text,
            buf.event_id,
            thread_relates_to=relates_to,
        )
        response = await self._send_room_content(chat_id, content)
    except Exception:
        # Streaming is best-effort: report the failure and stop the typing
        # keepalive, but do not propagate to the caller.
        logging.getLogger(__name__).warning(
            "Matrix stream flush failed for %s", chat_id, exc_info=True
        )
        await self._stop_typing_keepalive(chat_id, clear_typing=True)
        return

    buf.last_edit = now
    if not buf.event_id:
        # Every later flush edits this same message, so the id is set once.
        # Responses without an event_id (None or an error response) leave it
        # unset so a later flush can retry.
        buf.event_id = getattr(response, "event_id", None)
def _register_event_callbacks(self) -> None:
self.client.add_event_callback(self._on_message, RoomMessageText)
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER)

View File

@ -14,6 +14,7 @@ import base64
import hashlib
import json
import os
import random
import re
import time
import uuid
@ -52,7 +53,26 @@ MESSAGE_TYPE_BOT = 2
MESSAGE_STATE_FINISH = 2
WEIXIN_MAX_MESSAGE_LEN = 4000
WEIXIN_CHANNEL_VERSION = "1.0.3"
WEIXIN_CHANNEL_VERSION = "2.1.1"
ILINK_APP_ID = "bot"
def _build_client_version(version: str) -> int:
"""Encode semantic version as 0x00MMNNPP (major/minor/patch in one uint32)."""
parts = version.split(".")
def _as_int(idx: int) -> int:
try:
return int(parts[idx])
except Exception:
return 0
major = _as_int(0)
minor = _as_int(1)
patch = _as_int(2)
return ((major & 0xFF) << 16) | ((minor & 0xFF) << 8) | (patch & 0xFF)
ILINK_APP_CLIENT_VERSION = _build_client_version(WEIXIN_CHANNEL_VERSION)
BASE_INFO: dict[str, str] = {"channel_version": WEIXIN_CHANNEL_VERSION}
# Session-expired error code
@ -64,18 +84,32 @@ MAX_CONSECUTIVE_FAILURES = 3
BACKOFF_DELAY_S = 30
RETRY_DELAY_S = 2
MAX_QR_REFRESH_COUNT = 3
TYPING_STATUS_TYPING = 1
TYPING_STATUS_CANCEL = 2
TYPING_TICKET_TTL_S = 24 * 60 * 60
TYPING_KEEPALIVE_INTERVAL_S = 5
CONFIG_CACHE_INITIAL_RETRY_S = 2
CONFIG_CACHE_MAX_RETRY_S = 60 * 60
# Default long-poll timeout; overridden by server via longpolling_timeout_ms.
DEFAULT_LONG_POLL_TIMEOUT_S = 35
# Media-type codes for getuploadurl (1=image, 2=video, 3=file)
# Media-type codes for getuploadurl (1=image, 2=video, 3=file, 4=voice)
UPLOAD_MEDIA_IMAGE = 1
UPLOAD_MEDIA_VIDEO = 2
UPLOAD_MEDIA_FILE = 3
UPLOAD_MEDIA_VOICE = 4
# File extensions considered as images / videos for outbound media
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".tiff", ".ico", ".svg"}
_VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv"}
_VOICE_EXTS = {".mp3", ".wav", ".amr", ".silk", ".ogg", ".m4a", ".aac", ".flac"}
def _has_downloadable_media_locator(media: dict[str, Any] | None) -> bool:
if not isinstance(media, dict):
return False
return bool(str(media.get("encrypt_query_param", "") or "") or str(media.get("full_url", "") or "").strip())
class WeixinConfig(Base):
@ -124,7 +158,7 @@ class WeixinChannel(BaseChannel):
self._next_poll_timeout_s: int = DEFAULT_LONG_POLL_TIMEOUT_S
self._session_pause_until: float = 0.0
self._typing_tasks: dict[str, asyncio.Task] = {}
self._typing_tickets: dict[str, str] = {}
self._typing_tickets: dict[str, dict[str, Any]] = {}
# ------------------------------------------------------------------
# State persistence
@ -162,9 +196,9 @@ class WeixinChannel(BaseChannel):
typing_tickets = data.get("typing_tickets", {})
if isinstance(typing_tickets, dict):
self._typing_tickets = {
str(user_id): str(ticket)
str(user_id): ticket
for user_id, ticket in typing_tickets.items()
if str(user_id).strip() and str(ticket).strip()
if str(user_id).strip() and isinstance(ticket, dict)
}
else:
self._typing_tickets = {}
@ -172,8 +206,7 @@ class WeixinChannel(BaseChannel):
if base_url:
self.config.base_url = base_url
return bool(self._token)
except Exception as e:
logger.warning("Failed to load WeChat state: {}", e)
except Exception:
return False
def _save_state(self) -> None:
@ -187,8 +220,8 @@ class WeixinChannel(BaseChannel):
"base_url": self.config.base_url,
}
state_file.write_text(json.dumps(data, ensure_ascii=False))
except Exception as e:
logger.warning("Failed to save WeChat state: {}", e)
except Exception:
pass
# ------------------------------------------------------------------
# HTTP helpers (matches api.ts buildHeaders / apiFetch)
@ -210,6 +243,8 @@ class WeixinChannel(BaseChannel):
"X-WECHAT-UIN": self._random_wechat_uin(),
"Content-Type": "application/json",
"AuthorizationType": "ilink_bot_token",
"iLink-App-Id": ILINK_APP_ID,
"iLink-App-ClientVersion": str(ILINK_APP_CLIENT_VERSION),
}
if auth and self._token:
headers["Authorization"] = f"Bearer {self._token}"
@ -217,6 +252,15 @@ class WeixinChannel(BaseChannel):
headers["SKRouteTag"] = str(self.config.route_tag).strip()
return headers
@staticmethod
def _is_retryable_media_download_error(err: Exception) -> bool:
    """Decide whether a failed media download is worth retrying.

    Timeouts and transport-level failures are retryable. An HTTP status
    error is retryable only when the status code is 500 or above. Every
    other exception is treated as permanent.
    """
    if isinstance(err, (httpx.TimeoutException, httpx.TransportError)):
        return True
    if not isinstance(err, httpx.HTTPStatusError):
        return False
    response = err.response
    status_code = response.status_code if response is not None else 0
    return status_code >= 500
async def _api_get(
self,
endpoint: str,
@ -234,6 +278,25 @@ class WeixinChannel(BaseChannel):
resp.raise_for_status()
return resp.json()
async def _api_get_with_base(
    self,
    *,
    base_url: str,
    endpoint: str,
    params: dict | None = None,
    auth: bool = True,
    extra_headers: dict[str, str] | None = None,
) -> dict:
    """Issue a GET against an explicitly supplied base URL.

    Used for QR redirect polling, where the host to poll may differ from
    the configured one.

    Args:
        base_url: Base URL to use; a trailing slash is tolerated.
        endpoint: Path appended to ``base_url``.
        params: Optional query parameters.
        auth: Passed through to ``_make_headers``.
        extra_headers: Optional headers layered on top of the defaults.

    Returns:
        The decoded JSON body.

    Raises:
        Whatever ``raise_for_status`` raises for an error response.
    """
    assert self._client is not None
    target = "/".join((base_url.rstrip("/"), endpoint))
    request_headers = self._make_headers(auth=auth)
    # Caller-supplied headers override the defaults on key collisions.
    request_headers.update(extra_headers or {})
    response = await self._client.get(target, params=params, headers=request_headers)
    response.raise_for_status()
    return response.json()
async def _api_post(
self,
endpoint: str,
@ -270,23 +333,27 @@ class WeixinChannel(BaseChannel):
async def _qr_login(self) -> bool:
"""Perform QR code login flow. Returns True on success."""
try:
logger.info("Starting WeChat QR code login...")
refresh_count = 0
qrcode_id, scan_url = await self._fetch_qr_code()
self._print_qr_code(scan_url)
current_poll_base_url = self.config.base_url
logger.info("Waiting for QR code scan...")
while self._running:
try:
# Reference plugin sends iLink-App-ClientVersion header for
# QR status polling (login-qr.ts:81).
status_data = await self._api_get(
"ilink/bot/get_qrcode_status",
status_data = await self._api_get_with_base(
base_url=current_poll_base_url,
endpoint="ilink/bot/get_qrcode_status",
params={"qrcode": qrcode_id},
auth=False,
extra_headers={"iLink-App-ClientVersion": "1"},
)
except httpx.TimeoutException:
except Exception as e:
if self._is_retryable_qr_poll_error(e):
await asyncio.sleep(1)
continue
raise
if not isinstance(status_data, dict):
await asyncio.sleep(1)
continue
status = status_data.get("status", "")
@ -309,8 +376,15 @@ class WeixinChannel(BaseChannel):
else:
logger.error("Login confirmed but no bot_token in response")
return False
elif status == "scaned":
logger.info("QR code scanned, waiting for confirmation...")
elif status == "scaned_but_redirect":
redirect_host = str(status_data.get("redirect_host", "") or "").strip()
if redirect_host:
if redirect_host.startswith("http://") or redirect_host.startswith("https://"):
redirected_base = redirect_host
else:
redirected_base = f"https://{redirect_host}"
if redirected_base != current_poll_base_url:
current_poll_base_url = redirected_base
elif status == "expired":
refresh_count += 1
if refresh_count > MAX_QR_REFRESH_COUNT:
@ -320,14 +394,9 @@ class WeixinChannel(BaseChannel):
MAX_QR_REFRESH_COUNT,
)
return False
logger.warning(
"QR code expired, refreshing... ({}/{})",
refresh_count,
MAX_QR_REFRESH_COUNT,
)
qrcode_id, scan_url = await self._fetch_qr_code()
current_poll_base_url = self.config.base_url
self._print_qr_code(scan_url)
logger.info("New QR code generated, waiting for scan...")
continue
# status == "wait" — keep polling
@ -338,6 +407,16 @@ class WeixinChannel(BaseChannel):
return False
@staticmethod
def _is_retryable_qr_poll_error(err: Exception) -> bool:
    """Tell whether a QR status poll failure should simply be retried.

    Returns ``True`` for timeouts, transport-level failures, and HTTP
    status errors whose status code is 500 or above; ``False`` otherwise.
    """
    transient_types = (httpx.TimeoutException, httpx.TransportError)
    if isinstance(err, transient_types):
        return True
    if isinstance(err, httpx.HTTPStatusError):
        response = err.response
        code = 0 if response is None else response.status_code
        return code >= 500
    return False
@staticmethod
def _print_qr_code(url: str) -> None:
try:
@ -348,7 +427,6 @@ class WeixinChannel(BaseChannel):
qr.make(fit=True)
qr.print_ascii(invert=True)
except ImportError:
logger.info("QR code URL (install 'qrcode' for terminal display): {}", url)
print(f"\nLogin URL: {url}\n")
# ------------------------------------------------------------------
@ -410,12 +488,6 @@ class WeixinChannel(BaseChannel):
if not self._running:
break
consecutive_failures += 1
logger.error(
"WeChat poll error ({}/{}): {}",
consecutive_failures,
MAX_CONSECUTIVE_FAILURES,
e,
)
if consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
consecutive_failures = 0
await asyncio.sleep(BACKOFF_DELAY_S)
@ -432,8 +504,6 @@ class WeixinChannel(BaseChannel):
await self._client.aclose()
self._client = None
self._save_state()
logger.info("WeChat channel stopped")
# ------------------------------------------------------------------
# Polling (matches monitor.ts monitorWeixinProvider)
# ------------------------------------------------------------------
@ -459,10 +529,6 @@ class WeixinChannel(BaseChannel):
async def _poll_once(self) -> None:
remaining = self._session_pause_remaining_s()
if remaining > 0:
logger.warning(
"WeChat session paused, waiting {} min before next poll.",
max((remaining + 59) // 60, 1),
)
await asyncio.sleep(remaining)
return
@ -512,8 +578,8 @@ class WeixinChannel(BaseChannel):
for msg in msgs:
try:
await self._process_message(msg)
except Exception as e:
logger.error("Error processing WeChat message: {}", e)
except Exception:
pass
# ------------------------------------------------------------------
# Inbound message processing (matches inbound.ts + process-message.ts)
@ -549,6 +615,7 @@ class WeixinChannel(BaseChannel):
item_list: list[dict] = msg.get("item_list") or []
content_parts: list[str] = []
media_paths: list[str] = []
has_top_level_downloadable_media = False
for item in item_list:
item_type = item.get("type", 0)
@ -585,6 +652,8 @@ class WeixinChannel(BaseChannel):
elif item_type == ITEM_IMAGE:
image_item = item.get("image_item") or {}
if _has_downloadable_media_locator(image_item.get("media")):
has_top_level_downloadable_media = True
file_path = await self._download_media_item(image_item, "image")
if file_path:
content_parts.append(f"[image]\n[Image: source: {file_path}]")
@ -599,6 +668,8 @@ class WeixinChannel(BaseChannel):
if voice_text:
content_parts.append(f"[voice] {voice_text}")
else:
if _has_downloadable_media_locator(voice_item.get("media")):
has_top_level_downloadable_media = True
file_path = await self._download_media_item(voice_item, "voice")
if file_path:
transcription = await self.transcribe_audio(file_path)
@ -612,6 +683,8 @@ class WeixinChannel(BaseChannel):
elif item_type == ITEM_FILE:
file_item = item.get("file_item") or {}
if _has_downloadable_media_locator(file_item.get("media")):
has_top_level_downloadable_media = True
file_name = file_item.get("file_name", "unknown")
file_path = await self._download_media_item(
file_item,
@ -626,6 +699,8 @@ class WeixinChannel(BaseChannel):
elif item_type == ITEM_VIDEO:
video_item = item.get("video_item") or {}
if _has_downloadable_media_locator(video_item.get("media")):
has_top_level_downloadable_media = True
file_path = await self._download_media_item(video_item, "video")
if file_path:
content_parts.append(f"[video]\n[Video: source: {file_path}]")
@ -633,6 +708,52 @@ class WeixinChannel(BaseChannel):
else:
content_parts.append("[video]")
# Fallback: when no top-level media was downloaded, try quoted/referenced media.
# This aligns with the reference plugin behavior that checks ref_msg.message_item
# when main item_list has no downloadable media.
if not media_paths and not has_top_level_downloadable_media:
ref_media_item: dict[str, Any] | None = None
for item in item_list:
if item.get("type", 0) != ITEM_TEXT:
continue
ref = item.get("ref_msg") or {}
candidate = ref.get("message_item") or {}
if candidate.get("type", 0) in (ITEM_IMAGE, ITEM_VOICE, ITEM_FILE, ITEM_VIDEO):
ref_media_item = candidate
break
if ref_media_item:
ref_type = ref_media_item.get("type", 0)
if ref_type == ITEM_IMAGE:
image_item = ref_media_item.get("image_item") or {}
file_path = await self._download_media_item(image_item, "image")
if file_path:
content_parts.append(f"[image]\n[Image: source: {file_path}]")
media_paths.append(file_path)
elif ref_type == ITEM_VOICE:
voice_item = ref_media_item.get("voice_item") or {}
file_path = await self._download_media_item(voice_item, "voice")
if file_path:
transcription = await self.transcribe_audio(file_path)
if transcription:
content_parts.append(f"[voice] {transcription}")
else:
content_parts.append(f"[voice]\n[Audio: source: {file_path}]")
media_paths.append(file_path)
elif ref_type == ITEM_FILE:
file_item = ref_media_item.get("file_item") or {}
file_name = file_item.get("file_name", "unknown")
file_path = await self._download_media_item(file_item, "file", file_name)
if file_path:
content_parts.append(f"[file: {file_name}]\n[File: source: {file_path}]")
media_paths.append(file_path)
elif ref_type == ITEM_VIDEO:
video_item = ref_media_item.get("video_item") or {}
file_path = await self._download_media_item(video_item, "video")
if file_path:
content_parts.append(f"[video]\n[Video: source: {file_path}]")
media_paths.append(file_path)
content = "\n".join(content_parts)
if not content:
return
@ -667,9 +788,10 @@ class WeixinChannel(BaseChannel):
"""Download + AES-decrypt a media item. Returns local path or None."""
try:
media = typed_item.get("media") or {}
encrypt_query_param = media.get("encrypt_query_param", "")
encrypt_query_param = str(media.get("encrypt_query_param", "") or "")
full_url = str(media.get("full_url", "") or "").strip()
if not encrypt_query_param:
if not encrypt_query_param and not full_url:
return None
# Resolve AES key (media-download.ts:43-45, pic-decrypt.ts:40-52)
@ -686,21 +808,50 @@ class WeixinChannel(BaseChannel):
elif media_aes_key_b64:
aes_key_b64 = media_aes_key_b64
# Build CDN download URL with proper URL-encoding (cdn-url.ts:7)
cdn_url = (
f"{self.config.cdn_base_url}/download"
f"?encrypted_query_param={quote(encrypt_query_param)}"
)
# Reference protocol behavior: VOICE/FILE/VIDEO require aes_key;
# only IMAGE may be downloaded as plain bytes when key is missing.
if media_type != "image" and not aes_key_b64:
return None
assert self._client is not None
resp = await self._client.get(cdn_url)
resp.raise_for_status()
data = resp.content
fallback_url = ""
if encrypt_query_param:
fallback_url = (
f"{self.config.cdn_base_url}/download"
f"?encrypted_query_param={quote(encrypt_query_param)}"
)
download_candidates: list[tuple[str, str]] = []
if full_url:
download_candidates.append(("full_url", full_url))
if fallback_url and (not full_url or fallback_url != full_url):
download_candidates.append(("encrypt_query_param", fallback_url))
data = b""
for idx, (download_source, cdn_url) in enumerate(download_candidates):
try:
resp = await self._client.get(cdn_url)
resp.raise_for_status()
data = resp.content
break
except Exception as e:
has_more_candidates = idx + 1 < len(download_candidates)
should_fallback = (
download_source == "full_url"
and has_more_candidates
and self._is_retryable_media_download_error(e)
)
if should_fallback:
logger.warning(
"WeChat media download failed via full_url, falling back to encrypt_query_param: type={} err={}",
media_type,
e,
)
continue
raise
if aes_key_b64 and data:
data = _decrypt_aes_ecb(data, aes_key_b64)
elif not aes_key_b64:
logger.debug("No AES key for {} item, using raw bytes", media_type)
if not data:
return None
@ -709,12 +860,12 @@ class WeixinChannel(BaseChannel):
ext = _ext_for_type(media_type)
if not filename:
ts = int(time.time())
h = abs(hash(encrypt_query_param)) % 100000
hash_seed = encrypt_query_param or full_url
h = abs(hash(hash_seed)) % 100000
filename = f"{media_type}_{ts}_{h}{ext}"
safe_name = os.path.basename(filename)
file_path = media_dir / safe_name
file_path.write_bytes(data)
logger.debug("Downloaded WeChat {} to {}", media_type, file_path)
return str(file_path)
except Exception as e:
@ -725,14 +876,76 @@ class WeixinChannel(BaseChannel):
# Outbound (matches send.ts buildTextMessageReq + sendMessageWeixin)
# ------------------------------------------------------------------
async def _get_typing_ticket(self, user_id: str, context_token: str = "") -> str:
    """Return the typing ticket for a user, using a per-user cache.

    A cached entry is reused until its ``next_fetch_at`` time. Otherwise
    ``ilink/bot/getconfig`` is queried:

    * ``ret == 0``: the ticket is cached with a refresh time drawn at random
      from the next ``TYPING_TICKET_TTL_S`` seconds, and the retry delay is
      reset to its initial value.
    * any other ``ret``: an existing entry keeps its ticket but has its retry
      delay doubled (capped at ``CONFIG_CACHE_MAX_RETRY_S``); with no entry,
      an empty one is stored that retries after the initial delay.

    Args:
        user_id: User the ticket is for; also the cache key.
        context_token: Optional context token; sent as ``None`` when empty.

    Returns:
        The ticket string, or ``""`` when none is known.
    """
    now = time.time()
    cached = self._typing_tickets.get(user_id)
    if cached and now < float(cached.get("next_fetch_at", 0)):
        return str(cached.get("ticket", "") or "")

    request: dict[str, Any] = {
        "ilink_user_id": user_id,
        "context_token": context_token or None,
        "base_info": BASE_INFO,
    }
    data = await self._api_post("ilink/bot/getconfig", request)

    if data.get("ret", 0) == 0:
        fresh_ticket = str(data.get("typing_ticket", "") or "")
        self._typing_tickets[user_id] = {
            "ticket": fresh_ticket,
            "ever_succeeded": True,
            "next_fetch_at": now + (random.random() * TYPING_TICKET_TTL_S),
            "retry_delay_s": CONFIG_CACHE_INITIAL_RETRY_S,
        }
        return fresh_ticket

    if not cached:
        # First contact failed: remember that, and retry after the initial delay.
        self._typing_tickets[user_id] = {
            "ticket": "",
            "ever_succeeded": False,
            "next_fetch_at": now + CONFIG_CACHE_INITIAL_RETRY_S,
            "retry_delay_s": CONFIG_CACHE_INITIAL_RETRY_S,
        }
        return ""

    # Refresh failed: keep serving the old ticket and back off exponentially.
    previous_delay = float(cached.get("retry_delay_s", CONFIG_CACHE_INITIAL_RETRY_S))
    backoff = min(previous_delay * 2, CONFIG_CACHE_MAX_RETRY_S)
    cached["next_fetch_at"] = now + backoff
    cached["retry_delay_s"] = backoff
    return str(cached.get("ticket", "") or "")
async def _send_typing(self, user_id: str, typing_ticket: str, status: int) -> None:
    """Post a typing status for a user to ``ilink/bot/sendtyping``.

    Does nothing when ``typing_ticket`` is empty. Errors from the request
    are not handled here and propagate to the caller.

    Args:
        user_id: User the typing status is addressed to.
        typing_ticket: Ticket to send with the request.
        status: Status code to send (see the ``TYPING_STATUS_*`` constants).
    """
    if typing_ticket:
        payload: dict[str, Any] = {
            "ilink_user_id": user_id,
            "typing_ticket": typing_ticket,
            "status": status,
            "base_info": BASE_INFO,
        }
        await self._api_post("ilink/bot/sendtyping", payload)
async def _typing_keepalive_loop(self, user_id: str, typing_ticket: str, stop_event: asyncio.Event) -> None:
    """Re-send the typing status at a fixed interval until asked to stop.

    Sleeps ``TYPING_KEEPALIVE_INTERVAL_S`` seconds between sends and checks
    ``stop_event`` both before and after each sleep. A failed send is
    ignored so one error does not end the loop.

    Args:
        user_id: User the typing status is addressed to.
        typing_ticket: Ticket passed to ``_send_typing``.
        stop_event: Set by the caller to end the loop.
    """
    while not stop_event.is_set():
        await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_S)
        if stop_event.is_set():
            return
        try:
            await self._send_typing(user_id, typing_ticket, TYPING_STATUS_TYPING)
        except Exception:
            # Best-effort keepalive: skip this tick and try again on the next.
            continue
async def send(self, msg: OutboundMessage) -> None:
if not self._client or not self._token:
logger.warning("WeChat client not initialized or not authenticated")
return
try:
self._assert_session_active()
except RuntimeError as e:
logger.warning("WeChat send blocked: {}", e)
except RuntimeError:
return
is_progress = bool((msg.metadata or {}).get("_progress", False))
@ -748,94 +961,103 @@ class WeixinChannel(BaseChannel):
)
return
# --- Send media files first (following Telegram channel pattern) ---
for media_path in (msg.media or []):
try:
await self._send_media_file(msg.chat_id, media_path, ctx_token)
except Exception as e:
filename = Path(media_path).name
logger.error("Failed to send WeChat media {}: {}", media_path, e)
# Notify user about failure via text
await self._send_text(
msg.chat_id, f"[Failed to send: {filename}]", ctx_token,
)
typing_ticket = ""
try:
typing_ticket = await self._get_typing_ticket(msg.chat_id, ctx_token)
except Exception:
typing_ticket = ""
# --- Send text content ---
if not content:
return
if typing_ticket:
try:
await self._send_typing(msg.chat_id, typing_ticket, TYPING_STATUS_TYPING)
except Exception:
pass
typing_keepalive_stop = asyncio.Event()
typing_keepalive_task: asyncio.Task | None = None
if typing_ticket:
typing_keepalive_task = asyncio.create_task(
self._typing_keepalive_loop(msg.chat_id, typing_ticket, typing_keepalive_stop)
)
try:
# --- Send media files first (following Telegram channel pattern) ---
for media_path in (msg.media or []):
try:
await self._send_media_file(msg.chat_id, media_path, ctx_token)
except Exception as e:
filename = Path(media_path).name
logger.error("Failed to send WeChat media {}: {}", media_path, e)
# Notify user about failure via text
await self._send_text(
msg.chat_id, f"[Failed to send: {filename}]", ctx_token,
)
# --- Send text content ---
if not content:
return
chunks = split_message(content, WEIXIN_MAX_MESSAGE_LEN)
for chunk in chunks:
await self._send_text(msg.chat_id, chunk, ctx_token)
except Exception as e:
logger.error("Error sending WeChat message: {}", e)
raise
finally:
if typing_keepalive_task:
typing_keepalive_stop.set()
typing_keepalive_task.cancel()
try:
await typing_keepalive_task
except asyncio.CancelledError:
pass
async def _get_typing_ticket(self, user_id: str, context_token: str) -> str:
"""Fetch and cache typing ticket for a user/context pair."""
if not self._client or not self._token or not user_id or not context_token:
return ""
cached = self._typing_tickets.get(user_id, "")
if cached:
return cached
try:
data = await self._api_post(
"ilink/bot/getconfig",
{
"ilink_user_id": user_id,
"context_token": context_token,
},
)
except Exception as e:
logger.debug("WeChat getconfig failed for {}: {}", user_id, e)
return ""
ticket = str(data.get("typing_ticket") or "").strip()
if ticket:
self._typing_tickets[user_id] = ticket
self._save_state()
return ticket
if typing_ticket and not is_progress:
try:
await self._send_typing(msg.chat_id, typing_ticket, TYPING_STATUS_CANCEL)
except Exception:
pass
async def _send_typing_status(self, to_user_id: str, typing_ticket: str, status: int) -> None:
if not typing_ticket:
return
await self._api_post(
"ilink/bot/sendtyping",
{
"ilink_user_id": to_user_id,
"typing_ticket": typing_ticket,
"status": status,
},
)
async def _start_typing(self, chat_id: str, context_token: str) -> None:
if not self._client or not self._token or not chat_id or not context_token:
async def _start_typing(self, chat_id: str, context_token: str = "") -> None:
"""Start typing indicator immediately when a message is received."""
if not self._client or not self._token or not chat_id:
return
await self._stop_typing(chat_id, clear_remote=False)
ticket = await self._get_typing_ticket(chat_id, context_token)
if not ticket:
return
try:
await self._send_typing_status(chat_id, ticket, 1)
ticket = await self._get_typing_ticket(chat_id, context_token)
if not ticket:
return
await self._send_typing(chat_id, ticket, TYPING_STATUS_TYPING)
except Exception as e:
logger.debug("WeChat typing indicator failed for {}: {}", chat_id, e)
logger.debug("WeChat typing indicator start failed for {}: {}", chat_id, e)
return
async def typing_loop() -> None:
try:
while self._running:
await asyncio.sleep(5)
await self._send_typing_status(chat_id, ticket, 1)
except asyncio.CancelledError:
pass
except Exception as e:
logger.debug("WeChat typing keepalive stopped for {}: {}", chat_id, e)
stop_event = asyncio.Event()
self._typing_tasks[chat_id] = asyncio.create_task(typing_loop())
async def keepalive() -> None:
try:
while not stop_event.is_set():
await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_S)
if stop_event.is_set():
break
try:
await self._send_typing(chat_id, ticket, TYPING_STATUS_TYPING)
except Exception:
pass
finally:
pass
task = asyncio.create_task(keepalive())
task._typing_stop_event = stop_event # type: ignore[attr-defined]
self._typing_tasks[chat_id] = task
async def _stop_typing(self, chat_id: str, *, clear_remote: bool) -> None:
"""Stop typing indicator for a chat."""
task = self._typing_tasks.pop(chat_id, None)
if task and not task.done():
stop_event = getattr(task, "_typing_stop_event", None)
if stop_event:
stop_event.set()
task.cancel()
try:
await task
@ -843,11 +1065,12 @@ class WeixinChannel(BaseChannel):
pass
if not clear_remote:
return
ticket = self._typing_tickets.get(chat_id, "")
entry = self._typing_tickets.get(chat_id)
ticket = str(entry.get("ticket", "") or "") if isinstance(entry, dict) else ""
if not ticket:
return
try:
await self._send_typing_status(chat_id, ticket, 2)
await self._send_typing(chat_id, ticket, TYPING_STATUS_CANCEL)
except Exception as e:
logger.debug("WeChat typing clear failed for {}: {}", chat_id, e)
@ -923,6 +1146,10 @@ class WeixinChannel(BaseChannel):
upload_type = UPLOAD_MEDIA_VIDEO
item_type = ITEM_VIDEO
item_key = "video_item"
elif ext in _VOICE_EXTS:
upload_type = UPLOAD_MEDIA_VOICE
item_type = ITEM_VOICE
item_key = "voice_item"
else:
upload_type = UPLOAD_MEDIA_FILE
item_type = ITEM_FILE
@ -936,7 +1163,7 @@ class WeixinChannel(BaseChannel):
# Matches aesEcbPaddedSize: Math.ceil((size + 1) / 16) * 16
padded_size = ((raw_size + 1 + 15) // 16) * 16
# Step 1: Get upload URL (upload_param) from server
# Step 1: Get upload URL from server (prefer upload_full_url, fallback to upload_param)
file_key = os.urandom(16).hex()
upload_body: dict[str, Any] = {
"filekey": file_key,
@ -951,22 +1178,27 @@ class WeixinChannel(BaseChannel):
assert self._client is not None
upload_resp = await self._api_post("ilink/bot/getuploadurl", upload_body)
logger.debug("WeChat getuploadurl response: {}", upload_resp)
upload_param = upload_resp.get("upload_param", "")
if not upload_param:
raise RuntimeError(f"getuploadurl returned no upload_param: {upload_resp}")
upload_full_url = str(upload_resp.get("upload_full_url", "") or "").strip()
upload_param = str(upload_resp.get("upload_param", "") or "")
if not upload_full_url and not upload_param:
raise RuntimeError(
"getuploadurl returned no upload URL "
f"(need upload_full_url or upload_param): {upload_resp}"
)
# Step 2: AES-128-ECB encrypt and POST to CDN
aes_key_b64 = base64.b64encode(aes_key_raw).decode()
encrypted_data = _encrypt_aes_ecb(raw_data, aes_key_b64)
cdn_upload_url = (
f"{self.config.cdn_base_url}/upload"
f"?encrypted_query_param={quote(upload_param)}"
f"&filekey={quote(file_key)}"
)
logger.debug("WeChat CDN POST url={} ciphertextSize={}", cdn_upload_url[:80], len(encrypted_data))
if upload_full_url:
cdn_upload_url = upload_full_url
else:
cdn_upload_url = (
f"{self.config.cdn_base_url}/upload"
f"?encrypted_query_param={quote(upload_param)}"
f"&filekey={quote(file_key)}"
)
cdn_resp = await self._client.post(
cdn_upload_url,
@ -982,7 +1214,6 @@ class WeixinChannel(BaseChannel):
"CDN upload response missing x-encrypted-param header; "
f"status={cdn_resp.status_code} headers={dict(cdn_resp.headers)}"
)
logger.debug("WeChat CDN upload success for {}, got download_param", p.name)
# Step 3: Send message with the media item
# aes_key for CDNMedia is the hex key encoded as base64
@ -1031,7 +1262,6 @@ class WeixinChannel(BaseChannel):
raise RuntimeError(
f"WeChat send media error (code {errcode}): {data.get('errmsg', '')}"
)
logger.info("WeChat media sent: {} (type={})", p.name, item_key)
# ---------------------------------------------------------------------------
@ -1103,23 +1333,42 @@ def _decrypt_aes_ecb(data: bytes, aes_key_b64: str) -> bytes:
logger.warning("Failed to parse AES key, returning raw data: {}", e)
return data
decrypted: bytes | None = None
try:
from Crypto.Cipher import AES
cipher = AES.new(key, AES.MODE_ECB)
return cipher.decrypt(data) # pycryptodome auto-strips PKCS7 with unpad
decrypted = cipher.decrypt(data)
except ImportError:
pass
try:
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
if decrypted is None:
try:
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
cipher_obj = Cipher(algorithms.AES(key), modes.ECB())
decryptor = cipher_obj.decryptor()
return decryptor.update(data) + decryptor.finalize()
except ImportError:
logger.warning("Cannot decrypt media: install 'pycryptodome' or 'cryptography'")
cipher_obj = Cipher(algorithms.AES(key), modes.ECB())
decryptor = cipher_obj.decryptor()
decrypted = decryptor.update(data) + decryptor.finalize()
except ImportError:
logger.warning("Cannot decrypt media: install 'pycryptodome' or 'cryptography'")
return data
return _pkcs7_unpad_safe(decrypted)
def _pkcs7_unpad_safe(data: bytes, block_size: int = 16) -> bytes:
"""Safely remove PKCS7 padding when valid; otherwise return original bytes."""
if not data:
return data
if len(data) % block_size != 0:
return data
pad_len = data[-1]
if pad_len < 1 or pad_len > block_size:
return data
if data[-pad_len:] != bytes([pad_len]) * pad_len:
return data
return data[:-pad_len]
def _ext_for_type(media_type: str) -> str:

View File

@ -415,6 +415,9 @@ def _make_provider(config: Config):
api_base=p.api_base,
default_model=model,
)
elif backend == "github_copilot":
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
provider = GitHubCopilotProvider(default_model=model)
elif backend == "anthropic":
from nanobot.providers.anthropic_provider import AnthropicProvider
provider = AnthropicProvider(
@ -491,6 +494,91 @@ def _migrate_cron_store(config: "Config") -> None:
shutil.move(str(legacy_path), str(new_path))
# ============================================================================
# OpenAI-Compatible API Server
# ============================================================================
@app.command()
def serve(
    port: int | None = typer.Option(None, "--port", "-p", help="API server port"),
    host: str | None = typer.Option(None, "--host", "-H", help="Bind address"),
    timeout: float | None = typer.Option(None, "--timeout", "-t", help="Per-request timeout (seconds)"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Show nanobot runtime logs"),
    workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
    config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
):
    """Start the OpenAI-compatible API server (/v1/chat/completions)."""
    # aiohttp is an optional extra: bail out with an install hint before any other setup.
    try:
        from aiohttp import web  # noqa: F401
    except ImportError:
        console.print("[red]aiohttp is required. Install with: pip install 'nanobot-ai[api]'[/red]")
        raise typer.Exit(1)

    from loguru import logger

    from nanobot.agent.loop import AgentLoop
    from nanobot.api.server import create_app
    from nanobot.bus.queue import MessageBus
    from nanobot.session.manager import SessionManager

    # Runtime logs are opt-in via --verbose.
    toggle_logs = logger.enable if verbose else logger.disable
    toggle_logs("nanobot")

    runtime_config = _load_runtime_config(config, workspace)

    # Explicit CLI flags win; anything left unset falls back to the config's api section.
    api_cfg = runtime_config.api
    if host is None:
        host = api_cfg.host
    if port is None:
        port = api_cfg.port
    if timeout is None:
        timeout = api_cfg.timeout

    workspace_path = runtime_config.workspace_path
    sync_workspace_templates(workspace_path)

    bus = MessageBus()
    provider = _make_provider(runtime_config)
    session_manager = SessionManager(workspace_path)

    agent_defaults = runtime_config.agents.defaults
    tool_cfg = runtime_config.tools
    agent_loop = AgentLoop(
        bus=bus,
        provider=provider,
        workspace=workspace_path,
        model=agent_defaults.model,
        max_iterations=agent_defaults.max_tool_iterations,
        context_window_tokens=agent_defaults.context_window_tokens,
        web_search_config=tool_cfg.web.search,
        web_proxy=tool_cfg.web.proxy or None,
        exec_config=tool_cfg.exec,
        restrict_to_workspace=tool_cfg.restrict_to_workspace,
        session_manager=session_manager,
        mcp_servers=tool_cfg.mcp_servers,
        channels_config=runtime_config.channels,
        timezone=agent_defaults.timezone,
    )

    # Startup banner.
    model_name = agent_defaults.model
    console.print(f"{__logo__} Starting OpenAI-compatible API server")
    console.print(f"  [cyan]Endpoint[/cyan] : http://{host}:{port}/v1/chat/completions")
    console.print(f"  [cyan]Model[/cyan]    : {model_name}")
    console.print("  [cyan]Session[/cyan]  : api:default")
    console.print(f"  [cyan]Timeout[/cyan]  : {timeout}s")

    # Binding to a wildcard address exposes the API beyond this machine.
    if host in ("0.0.0.0", "::"):
        console.print(
            "[yellow]Warning:[/yellow] API is bound to all interfaces. "
            "Only do this behind a trusted network boundary, firewall, or reverse proxy."
        )
    console.print()

    api_app = create_app(agent_loop, model_name=model_name, request_timeout=timeout)

    # MCP connections are opened and closed with the web app's own lifecycle.
    async def _open_mcp(_app):
        await agent_loop._connect_mcp()

    async def _close_mcp(_app):
        await agent_loop.close_mcp()

    api_app.on_startup.append(_open_mcp)
    api_app.on_cleanup.append(_close_mcp)

    web.run_app(api_app, host=host, port=port, print=lambda msg: logger.info(msg))
# ============================================================================
# Gateway / Server
# ============================================================================
@ -1204,26 +1292,16 @@ def _login_openai_codex() -> None:
@_register_login("github_copilot")
def _login_github_copilot() -> None:
import asyncio
from openai import AsyncOpenAI
console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n")
async def _trigger():
client = AsyncOpenAI(
api_key="dummy",
base_url="https://api.githubcopilot.com",
)
await client.chat.completions.create(
model="gpt-4o",
messages=[{"role": "user", "content": "hi"}],
max_tokens=1,
)
try:
asyncio.run(_trigger())
console.print("[green]✓ Authenticated with GitHub Copilot[/green]")
from nanobot.providers.github_copilot_provider import login_github_copilot
console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n")
token = login_github_copilot(
print_fn=lambda s: console.print(s),
prompt_fn=lambda s: typer.prompt(s),
)
account = token.account_id or "GitHub"
console.print(f"[green]✓ Authenticated with GitHub Copilot[/green] [dim]{account}[/dim]")
except Exception as e:
console.print(f"[red]Authentication error: {e}[/red]")
raise typer.Exit(1)

View File

@ -84,6 +84,16 @@ async def cmd_new(ctx: CommandContext) -> OutboundMessage:
async def cmd_help(ctx: CommandContext) -> OutboundMessage:
    """Reply with the list of slash commands.

    The reply is addressed to the channel and chat the request came from
    and carries ``render_as: text`` in its metadata.
    """
    incoming = ctx.msg
    reply = OutboundMessage(
        channel=incoming.channel,
        chat_id=incoming.chat_id,
        content=build_help_text(),
        metadata={"render_as": "text"},
    )
    return reply
def build_help_text() -> str:
"""Build canonical help text shared across channels."""
lines = [
"🐈 nanobot commands:",
"/new — Start a new conversation",
@ -92,12 +102,7 @@ async def cmd_help(ctx: CommandContext) -> OutboundMessage:
"/status — Show bot status",
"/help — Show available commands",
]
return OutboundMessage(
channel=ctx.msg.channel,
chat_id=ctx.msg.chat_id,
content="\n".join(lines),
metadata={"render_as": "text"},
)
return "\n".join(lines)
def register_builtin_commands(router: CommandRouter) -> None:

View File

@ -96,6 +96,14 @@ class HeartbeatConfig(Base):
keep_recent_messages: int = 8
class ApiConfig(Base):
    """Settings for the OpenAI-compatible API server."""

    # Bind address; loopback by default so the server is local-only.
    host: str = "127.0.0.1"
    # TCP port to listen on.
    port: int = 8900
    # Per-request timeout, in seconds.
    timeout: float = 120.0
class GatewayConfig(Base):
"""Gateway/server configuration."""
@ -156,6 +164,7 @@ class Config(BaseSettings):
agents: AgentsConfig = Field(default_factory=AgentsConfig)
channels: ChannelsConfig = Field(default_factory=ChannelsConfig)
providers: ProvidersConfig = Field(default_factory=ProvidersConfig)
api: ApiConfig = Field(default_factory=ApiConfig)
gateway: GatewayConfig = Field(default_factory=GatewayConfig)
tools: ToolsConfig = Field(default_factory=ToolsConfig)

174
nanobot/nanobot.py Normal file
View File

@ -0,0 +1,174 @@
"""High-level programmatic interface to nanobot."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from nanobot.agent.hook import AgentHook
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
@dataclass(slots=True)
class RunResult:
"""Result of a single agent run."""
content: str
tools_used: list[str]
messages: list[dict[str, Any]]
class Nanobot:
    """Library-level wrapper that drives a single :class:`AgentLoop`.

    Example::

        bot = Nanobot.from_config()
        result = await bot.run("Summarize this repo", hooks=[MyHook()])
        print(result.content)
    """

    def __init__(self, loop: AgentLoop) -> None:
        """Wrap an already-constructed agent loop."""
        self._loop = loop

    @classmethod
    def from_config(
        cls,
        config_path: str | Path | None = None,
        *,
        workspace: str | Path | None = None,
    ) -> Nanobot:
        """Build a :class:`Nanobot` from configuration.

        Args:
            config_path: Location of the config file. When omitted,
                ``load_config`` is called with ``None`` and chooses its
                own default location.
            workspace: If given, replaces the workspace directory taken
                from the loaded config.

        Raises:
            FileNotFoundError: If ``config_path`` is given but does not exist.
        """
        from nanobot.config.loader import load_config
        from nanobot.config.schema import Config

        if config_path is None:
            resolved: Path | None = None
        else:
            resolved = Path(config_path).expanduser().resolve()
            if not resolved.exists():
                raise FileNotFoundError(f"Config not found: {resolved}")

        config: Config = load_config(resolved)
        if workspace is not None:
            workspace_dir = Path(workspace).expanduser().resolve()
            config.agents.defaults.workspace = str(workspace_dir)

        provider = _make_provider(config)
        bus = MessageBus()
        defaults = config.agents.defaults

        agent_loop = AgentLoop(
            bus=bus,
            provider=provider,
            workspace=config.workspace_path,
            model=defaults.model,
            max_iterations=defaults.max_tool_iterations,
            context_window_tokens=defaults.context_window_tokens,
            web_search_config=config.tools.web.search,
            web_proxy=config.tools.web.proxy or None,
            exec_config=config.tools.exec,
            restrict_to_workspace=config.tools.restrict_to_workspace,
            mcp_servers=config.tools.mcp_servers,
            timezone=defaults.timezone,
        )
        return cls(agent_loop)

    async def run(
        self,
        message: str,
        *,
        session_key: str = "sdk:default",
        hooks: list[AgentHook] | None = None,
    ) -> RunResult:
        """Process one user message and return the agent's reply.

        Args:
            message: The user message to process.
            session_key: Session identifier passed to the loop; distinct
                keys are meant to keep conversation history separate.
            hooks: Lifecycle hooks installed on the loop for this call
                only. ``None`` leaves the loop's current hooks in place.

        Returns:
            A :class:`RunResult` whose ``content`` is the reply text (empty
            string when there is none). ``tools_used`` and ``messages`` are
            currently always returned as empty lists.
        """
        agent_loop = self._loop
        saved_hooks = agent_loop._extra_hooks
        if hooks is not None:
            # Copy so later mutation of the caller's list cannot affect the run.
            agent_loop._extra_hooks = list(hooks)
        try:
            response = await agent_loop.process_direct(
                message,
                session_key=session_key,
            )
        finally:
            # Always put the previous hooks back, even when the run fails.
            agent_loop._extra_hooks = saved_hooks

        reply_text = response.content if response else None
        return RunResult(content=reply_text or "", tools_used=[], messages=[])
def _make_provider(config: Any) -> Any:
    """Instantiate the LLM provider selected by the default model in *config*.

    The provider spec looked up for the model decides which backend class
    is used; when no spec is found the OpenAI-compatible backend is used.
    Generation settings (temperature, max tokens, reasoning effort) from
    the agent defaults are attached to the provider before it is returned.

    Raises:
        ValueError: If Azure OpenAI lacks ``api_key``/``api_base``, or an
            OpenAI-compatible provider that needs an API key has none.
    """
    from nanobot.providers.base import GenerationSettings
    from nanobot.providers.registry import find_by_name

    model = config.agents.defaults.model
    provider_name = config.get_provider_name(model)
    p = config.get_provider(model)
    spec = find_by_name(provider_name) if provider_name else None
    backend = spec.backend if spec else "openai_compat"

    # Phase 1: validate credentials before building anything.
    if backend == "azure_openai":
        if not (p and p.api_key and p.api_base):
            raise ValueError("Azure OpenAI requires api_key and api_base in config.")
    elif backend == "openai_compat" and not model.startswith("bedrock/"):
        has_key = bool(p and p.api_key)
        # OAuth, local and direct providers do not need a configured key.
        exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
        if not has_key and not exempt:
            raise ValueError(f"No API key configured for provider '{provider_name}'.")

    # Phase 2: build the backend-specific provider.
    if backend == "openai_codex":
        from nanobot.providers.openai_codex_provider import OpenAICodexProvider

        provider = OpenAICodexProvider(default_model=model)
    elif backend == "github_copilot":
        from nanobot.providers.github_copilot_provider import GitHubCopilotProvider

        provider = GitHubCopilotProvider(default_model=model)
    elif backend == "azure_openai":
        from nanobot.providers.azure_openai_provider import AzureOpenAIProvider

        provider = AzureOpenAIProvider(
            api_key=p.api_key,
            api_base=p.api_base,
            default_model=model,
        )
    elif backend == "anthropic":
        from nanobot.providers.anthropic_provider import AnthropicProvider

        provider = AnthropicProvider(
            api_key=p.api_key if p else None,
            api_base=config.get_api_base(model),
            default_model=model,
            extra_headers=p.extra_headers if p else None,
        )
    else:
        from nanobot.providers.openai_compat_provider import OpenAICompatProvider

        provider = OpenAICompatProvider(
            api_key=p.api_key if p else None,
            api_base=config.get_api_base(model),
            default_model=model,
            extra_headers=p.extra_headers if p else None,
            spec=spec,
        )

    agent_defaults = config.agents.defaults
    provider.generation = GenerationSettings(
        temperature=agent_defaults.temperature,
        max_tokens=agent_defaults.max_tokens,
        reasoning_effort=agent_defaults.reasoning_effort,
    )
    return provider

View File

@ -13,6 +13,7 @@ __all__ = [
"AnthropicProvider",
"OpenAICompatProvider",
"OpenAICodexProvider",
"GitHubCopilotProvider",
"AzureOpenAIProvider",
]
@ -20,12 +21,14 @@ _LAZY_IMPORTS = {
"AnthropicProvider": ".anthropic_provider",
"OpenAICompatProvider": ".openai_compat_provider",
"OpenAICodexProvider": ".openai_codex_provider",
"GitHubCopilotProvider": ".github_copilot_provider",
"AzureOpenAIProvider": ".azure_openai_provider",
}
if TYPE_CHECKING:
from nanobot.providers.anthropic_provider import AnthropicProvider
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
from nanobot.providers.openai_codex_provider import OpenAICodexProvider

View File

@ -370,15 +370,22 @@ class AnthropicProvider(LLMProvider):
usage: dict[str, int] = {}
if response.usage:
input_tokens = response.usage.input_tokens
cache_creation = getattr(response.usage, "cache_creation_input_tokens", 0) or 0
cache_read = getattr(response.usage, "cache_read_input_tokens", 0) or 0
total_prompt_tokens = input_tokens + cache_creation + cache_read
usage = {
"prompt_tokens": response.usage.input_tokens,
"prompt_tokens": total_prompt_tokens,
"completion_tokens": response.usage.output_tokens,
"total_tokens": response.usage.input_tokens + response.usage.output_tokens,
"total_tokens": total_prompt_tokens + response.usage.output_tokens,
}
for attr in ("cache_creation_input_tokens", "cache_read_input_tokens"):
val = getattr(response.usage, attr, 0)
if val:
usage[attr] = val
# Normalize to cached_tokens for downstream consistency.
if cache_read:
usage["cached_tokens"] = cache_read
return LLMResponse(
content="".join(content_parts) or None,

View File

@ -0,0 +1,257 @@
"""GitHub Copilot OAuth-backed provider."""
from __future__ import annotations
import time
import webbrowser
from collections.abc import Callable
import httpx
from oauth_cli_kit.models import OAuthToken
from oauth_cli_kit.storage import FileTokenStorage
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
DEFAULT_GITHUB_DEVICE_CODE_URL = "https://github.com/login/device/code"
DEFAULT_GITHUB_ACCESS_TOKEN_URL = "https://github.com/login/oauth/access_token"
DEFAULT_GITHUB_USER_URL = "https://api.github.com/user"
DEFAULT_COPILOT_TOKEN_URL = "https://api.github.com/copilot_internal/v2/token"
DEFAULT_COPILOT_BASE_URL = "https://api.githubcopilot.com"
GITHUB_COPILOT_CLIENT_ID = "Iv1.b507a08c87ecfe98"
GITHUB_COPILOT_SCOPE = "read:user"
TOKEN_FILENAME = "github-copilot.json"
TOKEN_APP_NAME = "nanobot"
USER_AGENT = "nanobot/0.1"
EDITOR_VERSION = "vscode/1.99.0"
EDITOR_PLUGIN_VERSION = "copilot-chat/0.26.0"
_EXPIRY_SKEW_SECONDS = 60
_LONG_LIVED_TOKEN_SECONDS = 315360000
def _storage() -> FileTokenStorage:
    """Create the file-backed store for the GitHub Copilot OAuth token."""
    options = {
        "token_filename": TOKEN_FILENAME,
        "app_name": TOKEN_APP_NAME,
        "import_codex_cli": False,
    }
    return FileTokenStorage(**options)
def _copilot_headers(token: str) -> dict[str, str]:
    """Return the request headers for the Copilot token endpoint.

    *token* is sent as ``Authorization: token <token>`` together with the
    module's user-agent and editor identification headers.
    """
    headers: dict[str, str] = {
        "Authorization": f"token {token}",
        "Accept": "application/json",
    }
    headers["User-Agent"] = USER_AGENT
    headers["Editor-Version"] = EDITOR_VERSION
    headers["Editor-Plugin-Version"] = EDITOR_PLUGIN_VERSION
    return headers
def _load_github_token() -> OAuthToken | None:
    """Load the stored GitHub token, or ``None`` if absent or without an access value."""
    stored = _storage().load()
    if stored and stored.access:
        return stored
    return None
def get_github_copilot_login_status() -> OAuthToken | None:
    """Public accessor for the persisted GitHub OAuth token (``None`` when not logged in)."""
    stored_token = _load_github_token()
    return stored_token
def login_github_copilot(
    print_fn: Callable[[str], None] | None = None,
    prompt_fn: Callable[[str], str] | None = None,
) -> OAuthToken:
    """Run GitHub device flow and persist the GitHub OAuth token used for Copilot.

    Args:
        print_fn: Sink for the verification URL and user code; falls back
            to the built-in ``print``.
        prompt_fn: Unused; the device flow never asks for typed input.

    Returns:
        The token that was saved. ``access`` holds the GitHub OAuth token,
        ``refresh`` is empty, and ``account_id`` is the GitHub login (or
        numeric id) when the user endpoint returned one.

    Raises:
        RuntimeError: If the device code expires, the user denies access,
            GitHub reports another error, or polling runs past the device
            code lifetime.
        httpx.HTTPStatusError: If a GitHub endpoint answers with an error status.
    """
    del prompt_fn
    printer = print_fn or print
    timeout = httpx.Timeout(20.0, connect=20.0)

    with httpx.Client(timeout=timeout, follow_redirects=True, trust_env=True) as client:
        # Step 1: request a device code / user code pair.
        response = client.post(
            DEFAULT_GITHUB_DEVICE_CODE_URL,
            headers={"Accept": "application/json", "User-Agent": USER_AGENT},
            data={"client_id": GITHUB_COPILOT_CLIENT_ID, "scope": GITHUB_COPILOT_SCOPE},
        )
        response.raise_for_status()
        payload = response.json()

        device_code = str(payload["device_code"])
        user_code = str(payload["user_code"])
        verify_url = str(payload.get("verification_uri") or payload.get("verification_uri_complete") or "")
        verify_complete = str(payload.get("verification_uri_complete") or verify_url)
        interval = max(1, int(payload.get("interval") or 5))
        expires_in = int(payload.get("expires_in") or 900)

        printer(f"Open: {verify_url}")
        printer(f"Code: {user_code}")
        if verify_complete:
            # Opening a browser is a convenience only; the URL was already printed.
            try:
                webbrowser.open(verify_complete)
            except Exception:
                pass

        # Step 2: poll until the user authorizes, denies, or the code expires.
        deadline = time.time() + expires_in
        current_interval = interval
        access_token = None
        token_expires_in = _LONG_LIVED_TOKEN_SECONDS
        while time.time() < deadline:
            poll = client.post(
                DEFAULT_GITHUB_ACCESS_TOKEN_URL,
                headers={"Accept": "application/json", "User-Agent": USER_AGENT},
                data={
                    "client_id": GITHUB_COPILOT_CLIENT_ID,
                    "device_code": device_code,
                    "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
                },
            )
            poll.raise_for_status()
            poll_payload = poll.json()

            access_token = poll_payload.get("access_token")
            if access_token:
                token_expires_in = int(poll_payload.get("expires_in") or _LONG_LIVED_TOKEN_SECONDS)
                break

            error = poll_payload.get("error")
            if error == "authorization_pending":
                time.sleep(current_interval)
                continue
            if error == "slow_down":
                # Back off by 5s, and honour the new minimum interval when the
                # server sends one that is larger, so we are not throttled again.
                current_interval += 5
                server_interval = poll_payload.get("interval")
                if isinstance(server_interval, (int, float)):
                    current_interval = max(current_interval, int(server_interval))
                time.sleep(current_interval)
                continue
            if error == "expired_token":
                raise RuntimeError("GitHub device code expired. Please run login again.")
            if error == "access_denied":
                raise RuntimeError("GitHub device flow was denied.")
            if error:
                desc = poll_payload.get("error_description") or error
                raise RuntimeError(str(desc))
            # Neither a token nor an error: wait and try again.
            time.sleep(current_interval)
        else:
            # Loop ended without ``break``: the device code lifetime elapsed.
            raise RuntimeError("GitHub device flow timed out.")

        # Step 3: look up the account so the caller can show who logged in.
        user = client.get(
            DEFAULT_GITHUB_USER_URL,
            headers={
                "Authorization": f"Bearer {access_token}",
                "Accept": "application/vnd.github+json",
                "User-Agent": USER_AGENT,
            },
        )
        user.raise_for_status()
        user_payload = user.json()
        account_id = user_payload.get("login") or str(user_payload.get("id") or "") or None

        # Expiry is stored as epoch milliseconds.
        expires_ms = int((time.time() + token_expires_in) * 1000)
        token = OAuthToken(
            access=str(access_token),
            refresh="",
            expires=expires_ms,
            account_id=str(account_id) if account_id else None,
        )
        _storage().save(token)
        return token
class GitHubCopilotProvider(OpenAICompatProvider):
    """OpenAI-compatible provider authenticated through GitHub Copilot.

    Before every chat call the stored GitHub OAuth token is exchanged for
    a Copilot access token, which is cached in memory until shortly
    before it expires and installed as the client's API key.
    """

    def __init__(self, default_model: str = "github-copilot/gpt-4.1"):
        """Configure the provider against the Copilot base URL with editor headers."""
        from nanobot.providers.registry import find_by_name

        # In-memory cache for the exchanged Copilot token and its expiry (epoch seconds).
        self._copilot_access_token: str | None = None
        self._copilot_expires_at: float = 0.0

        editor_headers = {
            "Editor-Version": EDITOR_VERSION,
            "Editor-Plugin-Version": EDITOR_PLUGIN_VERSION,
            "User-Agent": USER_AGENT,
        }
        super().__init__(
            # Placeholder; the real key is installed before each request.
            api_key="no-key",
            api_base=DEFAULT_COPILOT_BASE_URL,
            default_model=default_model,
            extra_headers=editor_headers,
            spec=find_by_name("github_copilot"),
        )

    async def _get_copilot_access_token(self) -> str:
        """Return a usable Copilot access token, exchanging for a new one if needed.

        Raises:
            RuntimeError: If no GitHub token is stored, or the exchange
                response carries no token.
        """
        now = time.time()
        cached = self._copilot_access_token
        if cached and now < self._copilot_expires_at - _EXPIRY_SKEW_SECONDS:
            return cached

        github_token = _load_github_token()
        if not (github_token and github_token.access):
            raise RuntimeError("GitHub Copilot is not logged in. Run: nanobot provider login github-copilot")

        request_timeout = httpx.Timeout(20.0, connect=20.0)
        async with httpx.AsyncClient(timeout=request_timeout, follow_redirects=True, trust_env=True) as client:
            response = await client.get(
                DEFAULT_COPILOT_TOKEN_URL,
                headers=_copilot_headers(github_token.access),
            )
            response.raise_for_status()
            payload = response.json()

        fresh = payload.get("token")
        if not fresh:
            raise RuntimeError("GitHub Copilot token exchange returned no token.")

        expires_at = payload.get("expires_at")
        if isinstance(expires_at, (int, float)):
            self._copilot_expires_at = float(expires_at)
        else:
            # No numeric expiry: fall back to ``refresh_in`` (1500s default).
            refresh_in = payload.get("refresh_in") or 1500
            self._copilot_expires_at = time.time() + int(refresh_in)

        self._copilot_access_token = str(fresh)
        return self._copilot_access_token

    async def _refresh_client_api_key(self) -> str:
        """Install the current Copilot token on this provider and its client."""
        current = await self._get_copilot_access_token()
        self.api_key = current
        self._client.api_key = current
        return current

    async def chat(
        self,
        messages: list[dict[str, object]],
        tools: list[dict[str, object]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, object] | None = None,
    ):
        """Refresh the Copilot token, then delegate to the parent ``chat``."""
        await self._refresh_client_api_key()
        result = await super().chat(
            messages=messages,
            tools=tools,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            reasoning_effort=reasoning_effort,
            tool_choice=tool_choice,
        )
        return result

    async def chat_stream(
        self,
        messages: list[dict[str, object]],
        tools: list[dict[str, object]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, object] | None = None,
        on_content_delta: Callable[[str], None] | None = None,
    ):
        """Refresh the Copilot token, then delegate to the parent ``chat_stream``."""
        await self._refresh_client_api_key()
        result = await super().chat_stream(
            messages=messages,
            tools=tools,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            reasoning_effort=reasoning_effort,
            tool_choice=tool_choice,
            on_content_delta=on_content_delta,
        )
        return result

View File

@ -235,7 +235,9 @@ class OpenAICompatProvider(LLMProvider):
spec = self._spec
if spec and spec.supports_prompt_caching:
messages, tools = self._apply_cache_control(messages, tools)
model_name = model or self.default_model
if any(model_name.lower().startswith(k) for k in ("anthropic/", "claude")):
messages, tools = self._apply_cache_control(messages, tools)
if spec and spec.strip_model_prefix:
model_name = model_name.split("/")[-1]
@ -308,6 +310,13 @@ class OpenAICompatProvider(LLMProvider):
@classmethod
def _extract_usage(cls, response: Any) -> dict[str, int]:
"""Extract token usage from an OpenAI-compatible response.
Handles both dict-based (raw JSON) and object-based (SDK Pydantic)
responses. Provider-specific ``cached_tokens`` fields are normalised
under a single key; see the priority chain inside for details.
"""
# --- resolve usage object ---
usage_obj = None
response_map = cls._maybe_mapping(response)
if response_map is not None:
@ -317,19 +326,53 @@ class OpenAICompatProvider(LLMProvider):
usage_map = cls._maybe_mapping(usage_obj)
if usage_map is not None:
return {
result = {
"prompt_tokens": int(usage_map.get("prompt_tokens") or 0),
"completion_tokens": int(usage_map.get("completion_tokens") or 0),
"total_tokens": int(usage_map.get("total_tokens") or 0),
}
if usage_obj:
return {
elif usage_obj:
result = {
"prompt_tokens": getattr(usage_obj, "prompt_tokens", 0) or 0,
"completion_tokens": getattr(usage_obj, "completion_tokens", 0) or 0,
"total_tokens": getattr(usage_obj, "total_tokens", 0) or 0,
}
return {}
else:
return {}
# --- cached_tokens (normalised across providers) ---
# Try nested paths first (dict), fall back to attribute (SDK object).
# Priority order ensures the most specific field wins.
for path in (
("prompt_tokens_details", "cached_tokens"), # OpenAI/Zhipu/MiniMax/Qwen/Mistral/xAI
("cached_tokens",), # StepFun/Moonshot (top-level)
("prompt_cache_hit_tokens",), # DeepSeek/SiliconFlow
):
cached = cls._get_nested_int(usage_map, path)
if not cached and usage_obj:
cached = cls._get_nested_int(usage_obj, path)
if cached:
result["cached_tokens"] = cached
break
return result
@staticmethod
def _get_nested_int(obj: Any, path: tuple[str, ...]) -> int:
"""Drill into *obj* by *path* segments and return an ``int`` value.
Supports both dict-key access and attribute access so it works
uniformly with raw JSON dicts **and** SDK Pydantic models.
"""
current = obj
for segment in path:
if current is None:
return 0
if isinstance(current, dict):
current = current.get(segment)
else:
current = getattr(current, segment, None)
return int(current or 0) if current is not None else 0
def _parse(self, response: Any) -> LLMResponse:
if isinstance(response, str):
@ -586,4 +629,4 @@ class OpenAICompatProvider(LLMProvider):
return self._handle_error(e)
def get_default_model(self) -> str:
return self.default_model
return self.default_model

View File

@ -34,7 +34,7 @@ class ProviderSpec:
display_name: str = "" # shown in `nanobot status`
# which provider implementation to use
# "openai_compat" | "anthropic" | "azure_openai" | "openai_codex"
# "openai_compat" | "anthropic" | "azure_openai" | "openai_codex" | "github_copilot"
backend: str = "openai_compat"
# extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),)
@ -218,8 +218,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
keywords=("github_copilot", "copilot"),
env_key="",
display_name="Github Copilot",
backend="openai_compat",
backend="github_copilot",
default_api_base="https://api.githubcopilot.com",
strip_model_prefix=True,
is_oauth=True,
),
# DeepSeek: OpenAI-compatible at api.deepseek.com

View File

@ -124,8 +124,8 @@ def build_assistant_message(
msg: dict[str, Any] = {"role": "assistant", "content": content}
if tool_calls:
msg["tool_calls"] = tool_calls
if reasoning_content is not None:
msg["reasoning_content"] = reasoning_content
if reasoning_content is not None or thinking_blocks:
msg["reasoning_content"] = reasoning_content if reasoning_content is not None else ""
if thinking_blocks:
msg["thinking_blocks"] = thinking_blocks
return msg
@ -255,14 +255,18 @@ def build_status_content(
)
last_in = last_usage.get("prompt_tokens", 0)
last_out = last_usage.get("completion_tokens", 0)
cached = last_usage.get("cached_tokens", 0)
ctx_total = max(context_window_tokens, 0)
ctx_pct = int((context_tokens_estimate / ctx_total) * 100) if ctx_total > 0 else 0
ctx_used_str = f"{context_tokens_estimate // 1000}k" if context_tokens_estimate >= 1000 else str(context_tokens_estimate)
ctx_total_str = f"{ctx_total // 1024}k" if ctx_total > 0 else "n/a"
token_line = f"\U0001f4ca Tokens: {last_in} in / {last_out} out"
if cached and last_in:
token_line += f" ({cached * 100 // last_in}% cached)"
return "\n".join([
f"\U0001f408 nanobot v{version}",
f"\U0001f9e0 Model: {model}",
f"\U0001f4ca Tokens: {last_in} in / {last_out} out",
token_line,
f"\U0001f4da Context: {ctx_used_str}/{ctx_total_str} ({ctx_pct}%)",
f"\U0001f4ac Session: {session_msg_count} messages",
f"\u23f1 Uptime: {uptime}",

View File

@ -51,6 +51,9 @@ dependencies = [
]
[project.optional-dependencies]
api = [
"aiohttp>=3.9.0,<4.0.0",
]
wecom = [
"wecom-aibot-sdk-python>=0.1.5",
]
@ -64,12 +67,16 @@ matrix = [
"mistune>=3.0.0,<4.0.0",
"nh3>=0.2.17,<1.0.0",
]
discord = [
"discord.py>=2.5.2,<3.0.0",
]
langsmith = [
"langsmith>=0.1.0",
]
dev = [
"pytest>=9.0.0,<10.0.0",
"pytest-asyncio>=1.3.0,<2.0.0",
"aiohttp>=3.9.0,<4.0.0",
"pytest-cov>=6.0.0,<7.0.0",
"ruff>=0.1.0",
]

View File

@ -0,0 +1,351 @@
"""Tests for CompositeHook fan-out, error isolation, and integration."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
def _ctx() -> AgentHookContext:
    """Build a fresh hook context at iteration 0 with no messages."""
    return AgentHookContext(messages=[], iteration=0)
# ---------------------------------------------------------------------------
# Fan-out: every hook is called in order
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_composite_fans_out_before_iteration():
    """Every wrapped hook receives ``before_iteration``, in list order."""
    seen: list[str] = []

    def _tagged(tag: str) -> AgentHook:
        class _Recorder(AgentHook):
            async def before_iteration(self, context: AgentHookContext) -> None:
                seen.append(f"{tag}:{context.iteration}")

        return _Recorder()

    composite = CompositeHook([_tagged("A"), _tagged("B")])
    await composite.before_iteration(_ctx())
    assert seen == ["A:0", "B:0"]
@pytest.mark.asyncio
async def test_composite_fans_out_all_async_methods():
"""Verify all async methods fan out to every hook."""
events: list[str] = []
class RecordingHook(AgentHook):
async def before_iteration(self, context: AgentHookContext) -> None:
events.append("before_iteration")
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
events.append(f"on_stream:{delta}")
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
events.append(f"on_stream_end:{resuming}")
async def before_execute_tools(self, context: AgentHookContext) -> None:
events.append("before_execute_tools")
async def after_iteration(self, context: AgentHookContext) -> None:
events.append("after_iteration")
hook = CompositeHook([RecordingHook(), RecordingHook()])
ctx = _ctx()
await hook.before_iteration(ctx)
await hook.on_stream(ctx, "hi")
await hook.on_stream_end(ctx, resuming=True)
await hook.before_execute_tools(ctx)
await hook.after_iteration(ctx)
assert events == [
"before_iteration", "before_iteration",
"on_stream:hi", "on_stream:hi",
"on_stream_end:True", "on_stream_end:True",
"before_execute_tools", "before_execute_tools",
"after_iteration", "after_iteration",
]
# ---------------------------------------------------------------------------
# Error isolation: one hook raises, others still run
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_composite_error_isolation_before_iteration():
    """A hook that raises in ``before_iteration`` must not block later hooks."""
    reached: list[str] = []

    class Exploding(AgentHook):
        async def before_iteration(self, context: AgentHookContext) -> None:
            raise RuntimeError("boom")

    class Recording(AgentHook):
        async def before_iteration(self, context: AgentHookContext) -> None:
            reached.append("good")

    composite = CompositeHook([Exploding(), Recording()])
    await composite.before_iteration(_ctx())
    assert reached == ["good"]
@pytest.mark.asyncio
async def test_composite_error_isolation_on_stream():
    """A hook that raises in ``on_stream`` must not block later hooks."""
    received: list[str] = []

    class Exploding(AgentHook):
        async def on_stream(self, context: AgentHookContext, delta: str) -> None:
            raise RuntimeError("stream-boom")

    class Recording(AgentHook):
        async def on_stream(self, context: AgentHookContext, delta: str) -> None:
            received.append(delta)

    composite = CompositeHook([Exploding(), Recording()])
    await composite.on_stream(_ctx(), "delta")
    assert received == ["delta"]
@pytest.mark.asyncio
async def test_composite_error_isolation_all_async():
"""Error isolation for on_stream_end, before_execute_tools, after_iteration."""
calls: list[str] = []
class Bad(AgentHook):
async def on_stream_end(self, context, *, resuming):
raise RuntimeError("err")
async def before_execute_tools(self, context):
raise RuntimeError("err")
async def after_iteration(self, context):
raise RuntimeError("err")
class Good(AgentHook):
async def on_stream_end(self, context, *, resuming):
calls.append("on_stream_end")
async def before_execute_tools(self, context):
calls.append("before_execute_tools")
async def after_iteration(self, context):
calls.append("after_iteration")
hook = CompositeHook([Bad(), Good()])
ctx = _ctx()
await hook.on_stream_end(ctx, resuming=False)
await hook.before_execute_tools(ctx)
await hook.after_iteration(ctx)
assert calls == ["on_stream_end", "before_execute_tools", "after_iteration"]
# ---------------------------------------------------------------------------
# finalize_content: pipeline semantics (no error isolation)
# ---------------------------------------------------------------------------
def test_composite_finalize_content_pipeline():
    """``finalize_content`` chains hooks: each receives the previous output."""

    class Upper(AgentHook):
        def finalize_content(self, context, content):
            if not content:
                return content
            return content.upper()

    class Suffix(AgentHook):
        def finalize_content(self, context, content):
            if not content:
                return content
            return content + "!"

    composite = CompositeHook([Upper(), Suffix()])
    assert composite.finalize_content(_ctx(), "hello") == "HELLO!"
def test_composite_finalize_content_none_passthrough():
    """``None`` content passes through a default hook unchanged."""
    composite = CompositeHook([AgentHook()])
    finalized = composite.finalize_content(_ctx(), None)
    assert finalized is None
def test_composite_finalize_content_ordering():
"""First hook transforms first, result feeds second hook."""
steps: list[str] = []
class H1(AgentHook):
def finalize_content(self, context, content):
steps.append(f"H1:{content}")
return content.upper()
class H2(AgentHook):
def finalize_content(self, context, content):
steps.append(f"H2:{content}")
return content + "!"
hook = CompositeHook([H1(), H2()])
result = hook.finalize_content(_ctx(), "hi")
assert result == "HI!"
assert steps == ["H1:hi", "H2:HI"]
# ---------------------------------------------------------------------------
# wants_streaming: any-semantics
# ---------------------------------------------------------------------------
def test_composite_wants_streaming_any_true():
    """One hook asking for streaming is enough for the composite to ask too."""

    def _fixed(answer: bool) -> AgentHook:
        class _Hook(AgentHook):
            def wants_streaming(self):
                return answer

        return _Hook()

    composite = CompositeHook([_fixed(False), _fixed(True), _fixed(False)])
    assert composite.wants_streaming() is True
def test_composite_wants_streaming_all_false():
    """Two default hooks yield a composite that does not want streaming."""
    default_hooks = [AgentHook(), AgentHook()]
    composite = CompositeHook(default_hooks)
    assert composite.wants_streaming() is False
def test_composite_wants_streaming_empty():
    """A composite with no hooks does not want streaming."""
    composite = CompositeHook([])
    wants = composite.wants_streaming()
    assert wants is False
# ---------------------------------------------------------------------------
# Empty hooks list: behaves like no-op AgentHook
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_composite_empty_hooks_no_ops():
    """With no hooks, every callback completes and content is returned as given."""
    composite = CompositeHook([])
    context = _ctx()

    await composite.before_iteration(context)
    await composite.on_stream(context, "delta")
    await composite.on_stream_end(context, resuming=False)
    await composite.before_execute_tools(context)
    await composite.after_iteration(context)

    finalized = composite.finalize_content(context, "test")
    assert finalized == "test"
# ---------------------------------------------------------------------------
# Integration: AgentLoop with extra hooks
# ---------------------------------------------------------------------------
def _make_loop(tmp_path, hooks=None):
    """Construct an ``AgentLoop`` for tests with its collaborators patched out.

    The provider is a ``MagicMock`` reporting ``"test-model"`` as default
    model and 4096 max tokens; *hooks* is passed straight to the loop.
    """
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.queue import MessageBus

    bus = MessageBus()

    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    provider.generation.max_tokens = 4096

    with (
        patch("nanobot.agent.loop.ContextBuilder"),
        patch("nanobot.agent.loop.SessionManager"),
        patch("nanobot.agent.loop.SubagentManager") as mock_sub_mgr,
        patch("nanobot.agent.loop.MemoryConsolidator"),
    ):
        # The patched subagent manager needs an awaitable cancel_by_session.
        mock_sub_mgr.return_value.cancel_by_session = AsyncMock(return_value=0)
        loop = AgentLoop(
            bus=bus,
            provider=provider,
            workspace=tmp_path,
            hooks=hooks,
        )
    return loop
@pytest.mark.asyncio
async def test_agent_loop_extra_hook_receives_calls(tmp_path):
"""Extra hook passed to AgentLoop is called alongside core LoopHook."""
from nanobot.providers.base import LLMResponse
events: list[str] = []
class TrackingHook(AgentHook):
async def before_iteration(self, context):
events.append(f"before_iter:{context.iteration}")
async def after_iteration(self, context):
events.append(f"after_iter:{context.iteration}")
loop = _make_loop(tmp_path, hooks=[TrackingHook()])
loop.provider.chat_with_retry = AsyncMock(
return_value=LLMResponse(content="done", tool_calls=[], usage={})
)
loop.tools.get_definitions = MagicMock(return_value=[])
content, tools_used, messages = await loop._run_agent_loop(
[{"role": "user", "content": "hi"}]
)
assert content == "done"
assert "before_iter:0" in events
assert "after_iter:0" in events
@pytest.mark.asyncio
async def test_agent_loop_extra_hook_error_isolation(tmp_path):
"""A faulty extra hook does not crash the agent loop."""
from nanobot.providers.base import LLMResponse
class BadHook(AgentHook):
async def before_iteration(self, context):
raise RuntimeError("I am broken")
loop = _make_loop(tmp_path, hooks=[BadHook()])
loop.provider.chat_with_retry = AsyncMock(
return_value=LLMResponse(content="still works", tool_calls=[], usage={})
)
loop.tools.get_definitions = MagicMock(return_value=[])
content, _, _ = await loop._run_agent_loop(
[{"role": "user", "content": "hi"}]
)
assert content == "still works"
@pytest.mark.asyncio
async def test_agent_loop_extra_hooks_do_not_swallow_loop_hook_errors(tmp_path):
"""Extra hooks must not change the core LoopHook failure behavior."""
from nanobot.providers.base import LLMResponse, ToolCallRequest
loop = _make_loop(tmp_path, hooks=[AgentHook()])
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
content="working",
tool_calls=[ToolCallRequest(id="c1", name="list_dir", arguments={"path": "."})],
usage={},
))
loop.tools.get_definitions = MagicMock(return_value=[])
loop.tools.execute = AsyncMock(return_value="ok")
async def bad_progress(*args, **kwargs):
raise RuntimeError("progress failed")
with pytest.raises(RuntimeError, match="progress failed"):
await loop._run_agent_loop([], on_progress=bad_progress)
@pytest.mark.asyncio
async def test_agent_loop_no_hooks_backward_compat(tmp_path):
"""Without hooks param, behavior is identical to before."""
from nanobot.providers.base import LLMResponse, ToolCallRequest
loop = _make_loop(tmp_path)
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
content="working",
tool_calls=[ToolCallRequest(id="c1", name="list_dir", arguments={"path": "."})],
))
loop.tools.get_definitions = MagicMock(return_value=[])
loop.tools.execute = AsyncMock(return_value="ok")
loop.max_iterations = 2
content, tools_used, _ = await loop._run_agent_loop([])
assert content == (
"I reached the maximum number of tool call iterations (2) "
"without completing the task. You can try breaking the task into smaller steps."
)
assert tools_used == ["list_dir", "list_dir"]

View File

@ -333,3 +333,82 @@ async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, mon
args = mgr._announce_result.await_args.args
assert args[3] == "Task completed but no final response was generated."
assert args[5] == "ok"
@pytest.mark.asyncio
async def test_runner_accumulates_usage_and_preserves_cached_tokens():
    """Usage counters, including ``cached_tokens``, are summed over all iterations.

    The stub provider answers the first call with a tool call and every later
    call with a final message, each carrying its own usage dict.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    calls_seen = 0

    async def chat_with_retry(*, messages, **kwargs):
        nonlocal calls_seen
        calls_seen += 1
        if calls_seen != 1:
            return LLMResponse(
                content="done",
                tool_calls=[],
                usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150},
            )
        return LLMResponse(
            content="thinking",
            tool_calls=[ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})],
            usage={"prompt_tokens": 100, "completion_tokens": 10, "cached_tokens": 80},
        )

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="file content")

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tools,
        model="test-model",
        max_iterations=3,
    )
    result = await runner.run(spec)

    # Each total is the sum of the two scripted responses.
    assert result.usage["prompt_tokens"] == 300  # 100 + 200
    assert result.usage["completion_tokens"] == 30  # 10 + 20
    assert result.usage["cached_tokens"] == 230  # 80 + 150
@pytest.mark.asyncio
async def test_runner_passes_cached_tokens_to_hook_context():
    """``after_iteration`` receives a context whose ``usage`` includes ``cached_tokens``."""
    from nanobot.agent.hook import AgentHook, AgentHookContext
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    usage_snapshots: list[dict] = []

    class UsageHook(AgentHook):
        async def after_iteration(self, context: AgentHookContext) -> None:
            # Copy so later mutation of context.usage cannot alter the snapshot.
            usage_snapshots.append(dict(context.usage))

    async def chat_with_retry(**kwargs):
        final_usage = {"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150}
        return LLMResponse(content="done", tool_calls=[], usage=final_usage)

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=1,
        hook=UsageHook(),
    )
    await runner.run(spec)

    assert len(usage_snapshots) == 1
    assert usage_snapshots[0]["cached_tokens"] == 150

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@ -116,6 +117,43 @@ class TestDispatch:
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
assert out.content == "hi"
@pytest.mark.asyncio
async def test_dispatch_streaming_preserves_message_metadata(self):
    """Stream delta and stream-end outbound messages keep the inbound thread metadata."""
    from nanobot.bus.events import InboundMessage

    async def fake_process(_msg, *, on_stream=None, on_stream_end=None, **kwargs):
        # The dispatcher must supply both streaming callbacks.
        assert on_stream is not None
        assert on_stream_end is not None
        await on_stream("hi")
        await on_stream_end(resuming=False)
        return None

    thread_keys = {
        "thread_root_event_id": "$root1",
        "thread_reply_to_event_id": "$reply1",
    }

    loop, bus = _make_loop()
    inbound = InboundMessage(
        channel="matrix",
        sender_id="u1",
        chat_id="!room:matrix.org",
        content="hello",
        metadata={"_wants_stream": True, **thread_keys},
    )
    loop._process_message = fake_process

    await loop._dispatch(inbound)

    delta_msg = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
    end_msg = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)

    assert delta_msg.metadata["thread_root_event_id"] == "$root1"
    assert delta_msg.metadata["thread_reply_to_event_id"] == "$reply1"
    assert delta_msg.metadata["_stream_delta"] is True
    assert end_msg.metadata["thread_root_event_id"] == "$root1"
    assert end_msg.metadata["thread_reply_to_event_id"] == "$reply1"
    assert end_msg.metadata["_stream_end"] is True
@pytest.mark.asyncio
async def test_processing_lock_serializes(self):
from nanobot.bus.events import InboundMessage, OutboundMessage
@ -222,6 +260,39 @@ class TestSubagentCancellation:
assert assistant_messages[0]["reasoning_content"] == "hidden reasoning"
assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]
@pytest.mark.asyncio
async def test_subagent_exec_tool_not_registered_when_disabled(self, tmp_path):
    """With ``ExecToolConfig(enable=False)`` the run spec's tools have no ``exec`` entry."""
    from nanobot.agent.subagent import SubagentManager
    from nanobot.bus.queue import MessageBus
    from nanobot.config.schema import ExecToolConfig

    async def fake_run(spec):
        # The check itself lives here: it runs when the manager invokes the runner.
        assert spec.tools.get("exec") is None
        return SimpleNamespace(
            stop_reason="done",
            final_content="done",
            error=None,
            tool_events=[],
        )

    bus = MessageBus()
    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    mgr = SubagentManager(
        provider=provider,
        workspace=tmp_path,
        bus=bus,
        exec_config=ExecToolConfig(enable=False),
    )
    mgr._announce_result = AsyncMock()
    mgr.runner.run = AsyncMock(side_effect=fake_run)

    origin = {"channel": "test", "chat_id": "c1"}
    await mgr._run_subagent("sub-1", "do task", "label", origin)

    mgr.runner.run.assert_awaited_once()
    mgr._announce_result.assert_awaited_once()
@pytest.mark.asyncio
async def test_subagent_announces_error_when_tool_execution_fails(self, monkeypatch, tmp_path):
from nanobot.agent.subagent import SubagentManager

View File

@ -0,0 +1,676 @@
from __future__ import annotations
import asyncio
from pathlib import Path
from types import SimpleNamespace
import pytest
discord = pytest.importorskip("discord")
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.discord import DiscordBotClient, DiscordChannel, DiscordConfig
from nanobot.command.builtin import build_help_text
class _FakeDiscordClient:
    """Test double for the Discord client with controllable start-up behaviour.

    Every constructed instance is recorded in the class-level ``instances``
    list, and setting the class-level ``start_error`` makes ``start()`` raise it.
    """

    instances: list["_FakeDiscordClient"] = []
    start_error: Exception | None = None

    def __init__(self, owner, *, intents) -> None:
        self.owner = owner
        self.intents = intents
        self.closed = False
        self.ready = True
        self.channels: dict[int, object] = {}
        self.user = SimpleNamespace(id=999)
        type(self).instances.append(self)

    async def start(self, token: str) -> None:
        """Remember the token, then raise the configured ``start_error`` if any."""
        self.token = token
        injected_error = type(self).start_error
        if injected_error is not None:
            raise injected_error

    async def close(self) -> None:
        """Mark the client as closed."""
        self.closed = True

    def is_closed(self) -> bool:
        return self.closed

    def is_ready(self) -> bool:
        return self.ready

    def get_channel(self, channel_id: int):
        """Return the registered channel for ``channel_id`` or ``None``."""
        return self.channels.get(channel_id)

    async def send_outbound(self, msg: OutboundMessage) -> None:
        """Forward ``msg.content`` to the target channel; do nothing if it is unknown."""
        target = self.get_channel(int(msg.chat_id))
        if target is not None:
            await target.send(content=msg.content)
class _FakeAttachment:
    """Attachment stand-in whose ``save()`` either writes a small file or raises."""

    def __init__(self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False) -> None:
        self.id = attachment_id
        self.filename = filename
        self.size = size
        self._fail = fail

    async def save(self, path: str | Path) -> None:
        """Write fixed bytes to ``path``, or raise ``RuntimeError`` when built with ``fail=True``."""
        if not self._fail:
            Path(path).write_bytes(b"attachment")
            return
        raise RuntimeError("save failed")
class _FakePartialMessage:
    """Minimal object exposing only an ``id``, used as a reply reference in tests."""

    def __init__(self, message_id: int) -> None:
        self.id = message_id
class _FakeChannel:
    """Channel stand-in that records sent payloads and counts typing entries."""

    def __init__(self, channel_id: int = 123) -> None:
        self.id = channel_id
        self.sent_payloads: list[dict] = []
        self.trigger_typing_calls = 0
        # Optional async callable awaited each time the typing context is entered.
        self.typing_enter_hook = None

    async def send(self, **kwargs) -> None:
        """Record the keyword arguments, replacing a ``file`` object by its file name."""
        recorded = dict(kwargs)
        if "file" in recorded:
            recorded["file_name"] = recorded.pop("file").filename
        self.sent_payloads.append(recorded)

    def get_partial_message(self, message_id: int) -> _FakePartialMessage:
        return _FakePartialMessage(message_id)

    def typing(self):
        """Return an async context manager that bumps the typing counter on entry."""
        owner = self

        class _TypingContext:
            async def __aenter__(self):
                owner.trigger_typing_calls += 1
                hook = owner.typing_enter_hook
                if hook is not None:
                    await hook()

            async def __aexit__(self, exc_type, exc, tb):
                # Never suppress exceptions raised inside the block.
                return False

        return _TypingContext()
class _FakeInteractionResponse:
    """Interaction response stand-in that records messages and tracks completion."""

    def __init__(self) -> None:
        self.messages: list[dict] = []
        self._done = False

    async def send_message(self, content: str, *, ephemeral: bool = False) -> None:
        """Store the message with its ``ephemeral`` flag and mark the response done."""
        entry = {"content": content, "ephemeral": ephemeral}
        self.messages.append(entry)
        self._done = True

    def is_done(self) -> bool:
        return self._done
def _make_interaction(
    *,
    user_id: int = 123,
    channel_id: int | None = 456,
    guild_id: int | None = None,
    interaction_id: int = 999,
):
    """Build a slash-command interaction stand-in.

    The command's ``qualified_name`` defaults to ``"new"`` and the response is
    a fresh ``_FakeInteractionResponse``.
    """
    invoking_user = SimpleNamespace(id=user_id)
    invoked_command = SimpleNamespace(qualified_name="new")
    return SimpleNamespace(
        user=invoking_user,
        channel_id=channel_id,
        guild_id=guild_id,
        id=interaction_id,
        command=invoked_command,
        response=_FakeInteractionResponse(),
    )
def _make_message(
    *,
    author_id: int = 123,
    author_bot: bool = False,
    channel_id: int = 456,
    message_id: int = 789,
    content: str = "hello",
    guild_id: int | None = None,
    mentions: list[object] | None = None,
    attachments: list[object] | None = None,
    reply_to: int | None = None,
):
    """Build an incoming-message stand-in.

    ``guild`` and ``reference`` are ``None`` unless ``guild_id`` / ``reply_to``
    are given; ``mentions`` and ``attachments`` default to empty lists.
    """
    guild = None
    if guild_id is not None:
        guild = SimpleNamespace(id=guild_id)

    reference = None
    if reply_to is not None:
        reference = SimpleNamespace(message_id=reply_to)

    author = SimpleNamespace(id=author_id, bot=author_bot)
    return SimpleNamespace(
        author=author,
        channel=_FakeChannel(channel_id),
        content=content,
        guild=guild,
        mentions=mentions or [],
        attachments=attachments or [],
        reference=reference,
        id=message_id,
    )
@pytest.mark.asyncio
async def test_start_returns_when_token_missing() -> None:
    """``start()`` with a config that carries no token leaves the channel stopped and clientless."""
    tokenless_config = DiscordConfig(enabled=True, allow_from=["*"])
    channel = DiscordChannel(tokenless_config, MessageBus())

    await channel.start()

    assert channel.is_running is False
    assert channel._client is None
@pytest.mark.asyncio
async def test_start_returns_when_discord_dependency_missing(monkeypatch) -> None:
    """``start()`` leaves the channel stopped when ``DISCORD_AVAILABLE`` is patched to False."""
    config = DiscordConfig(enabled=True, token="token", allow_from=["*"])
    channel = DiscordChannel(config, MessageBus())
    monkeypatch.setattr("nanobot.channels.discord.DISCORD_AVAILABLE", False)

    await channel.start()

    assert channel.is_running is False
    assert channel._client is None
@pytest.mark.asyncio
async def test_start_handles_client_construction_failure(monkeypatch) -> None:
    """A client factory that raises must not escape ``start()`` nor leave a client attached."""

    def _boom(owner, *, intents):
        raise RuntimeError("bad client")

    config = DiscordConfig(enabled=True, token="token", allow_from=["*"])
    channel = DiscordChannel(config, MessageBus())
    monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _boom)

    await channel.start()

    assert channel.is_running is False
    assert channel._client is None
@pytest.mark.asyncio
async def test_start_handles_client_start_failure(monkeypatch) -> None:
    """If ``client.start`` raises, the created client is closed and detached.

    The fake's class-level ``instances`` and ``start_error`` are patched via
    ``monkeypatch`` so they are restored even when an assertion below fails;
    resetting them by hand on the last line would leak the injected error
    into later tests on failure.
    """
    channel = DiscordChannel(
        DiscordConfig(enabled=True, token="token", allow_from=["*"]),
        MessageBus(),
    )
    # Fresh instance list for this test only; restored automatically on teardown.
    monkeypatch.setattr(_FakeDiscordClient, "instances", [])
    monkeypatch.setattr(_FakeDiscordClient, "start_error", RuntimeError("connect failed"))
    monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)

    await channel.start()

    assert channel.is_running is False
    assert channel._client is None
    created = _FakeDiscordClient.instances[0]
    assert created.intents.value == channel.config.intents
    assert created.closed is True
@pytest.mark.asyncio
async def test_stop_is_safe_after_partial_start(monkeypatch) -> None:
    """``stop()`` closes and drops a client that was attached by hand, without a full start."""
    config = DiscordConfig(enabled=True, token="token", allow_from=["*"])
    channel = DiscordChannel(config, MessageBus())
    attached_client = _FakeDiscordClient(channel, intents=None)
    channel._client = attached_client
    channel._running = True

    await channel.stop()

    assert channel.is_running is False
    assert attached_client.closed is True
    assert channel._client is None
@pytest.mark.asyncio
async def test_on_message_ignores_bot_messages() -> None:
    """Bot-authored messages are dropped, and a failing handler leaves no typing task.

    The capture stub is a coroutine function, like the stubs in the sibling
    tests. ``_handle_message`` is awaited, so a synchronous stub would turn a
    regression into an unrelated ``TypeError`` instead of failing the
    ``handled == []`` assertion.
    """
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    handled: list[dict] = []

    async def capture_handle(**kwargs) -> None:
        handled.append(kwargs)

    channel._handle_message = capture_handle  # type: ignore[method-assign]

    await channel._on_message(_make_message(author_bot=True))

    assert handled == []

    # If inbound handling raises, typing should be stopped for that channel.
    async def fail_handle(**kwargs) -> None:
        raise RuntimeError("boom")

    channel._handle_message = fail_handle  # type: ignore[method-assign]

    with pytest.raises(RuntimeError, match="boom"):
        await channel._on_message(_make_message(author_id=123, channel_id=456))

    assert channel._typing_tasks == {}
@pytest.mark.asyncio
async def test_on_message_accepts_allowlisted_dm() -> None:
    """A guild-less message from an allow-listed author is forwarded with string metadata."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus())
    forwarded: list[dict] = []

    async def record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = record  # type: ignore[method-assign]

    incoming = _make_message(author_id=123, channel_id=456, message_id=789)
    await channel._on_message(incoming)

    assert len(forwarded) == 1
    call = forwarded[0]
    assert call["chat_id"] == "456"
    assert call["metadata"] == {"message_id": "789", "guild_id": None, "reply_to": None}
@pytest.mark.asyncio
async def test_on_message_ignores_unmentioned_guild_message() -> None:
    """Under ``group_policy="mention"`` a guild message without a bot mention is not forwarded."""
    config = DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention")
    channel = DiscordChannel(config, MessageBus())
    channel._bot_user_id = "999"
    forwarded: list[dict] = []

    async def record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = record  # type: ignore[method-assign]

    await channel._on_message(_make_message(guild_id=1, content="hello everyone"))

    assert forwarded == []
@pytest.mark.asyncio
async def test_on_message_accepts_mentioned_guild_message() -> None:
    """A guild message mentioning the bot is forwarded and carries ``reply_to`` as a string."""
    config = DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention")
    channel = DiscordChannel(config, MessageBus())
    channel._bot_user_id = "999"
    forwarded: list[dict] = []

    async def record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = record  # type: ignore[method-assign]

    mentioning_message = _make_message(
        guild_id=1,
        content="<@999> hello",
        mentions=[SimpleNamespace(id=999)],
        reply_to=321,
    )
    await channel._on_message(mentioning_message)

    assert len(forwarded) == 1
    assert forwarded[0]["metadata"]["reply_to"] == "321"
@pytest.mark.asyncio
async def test_on_message_downloads_attachments(tmp_path, monkeypatch) -> None:
    """A saved attachment shows up in ``media`` and as an ``[attachment:`` marker in content."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    forwarded: list[dict] = []

    async def record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = record  # type: ignore[method-assign]
    # Redirect downloads into the per-test temporary directory.
    monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path)

    message = _make_message(
        attachments=[_FakeAttachment(12, "photo.png")],
        content="see file",
    )
    await channel._on_message(message)

    assert len(forwarded) == 1
    call = forwarded[0]
    assert call["media"] == [str(tmp_path / "12_photo.png")]
    assert "[attachment:" in call["content"]
@pytest.mark.asyncio
async def test_on_message_marks_failed_attachment_download(tmp_path, monkeypatch) -> None:
    """A failing attachment save yields a placeholder in content and an empty ``media`` list."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    forwarded: list[dict] = []

    async def record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = record  # type: ignore[method-assign]
    monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path)

    message = _make_message(
        attachments=[_FakeAttachment(12, "photo.png", fail=True)],
        content="",
    )
    await channel._on_message(message)

    assert len(forwarded) == 1
    call = forwarded[0]
    assert call["media"] == []
    assert call["content"] == "[attachment: photo.png - download failed]"
@pytest.mark.asyncio
async def test_send_warns_when_client_not_ready() -> None:
    """``send()`` on a channel that was never started completes and leaves no typing tasks."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    outbound = OutboundMessage(channel="discord", chat_id="123", content="hello")

    await channel.send(outbound)

    assert channel._typing_tasks == {}
@pytest.mark.asyncio
async def test_send_skips_when_channel_not_cached() -> None:
    """When ``fetch_channel`` raises for an uncached channel, ``send_outbound`` still returns.

    The fetch is attempted exactly once, with the integer channel id.
    """
    fetch_calls: list[int] = []

    async def fetch_channel(channel_id: int):
        fetch_calls.append(channel_id)
        raise RuntimeError("not found")

    owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    client = DiscordBotClient(owner, intents=discord.Intents.none())
    client.fetch_channel = fetch_channel  # type: ignore[method-assign]

    outbound = OutboundMessage(channel="discord", chat_id="123", content="hello")
    await client.send_outbound(outbound)

    assert client.get_channel(123) is None
    assert fetch_calls == [123]
@pytest.mark.asyncio
async def test_send_fetches_channel_when_not_cached() -> None:
    """An uncached channel returned by ``fetch_channel`` receives the outbound text."""
    owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    client = DiscordBotClient(owner, intents=discord.Intents.none())
    target = _FakeChannel(channel_id=123)

    async def fetch_channel(channel_id: int):
        if channel_id == 123:
            return target
        return None

    client.fetch_channel = fetch_channel  # type: ignore[method-assign]

    outbound = OutboundMessage(channel="discord", chat_id="123", content="hello")
    await client.send_outbound(outbound)

    assert target.sent_payloads == [{"content": "hello"}]
@pytest.mark.asyncio
async def test_slash_new_forwards_when_user_is_allowlisted() -> None:
    """``/new`` from an allow-listed user is acknowledged ephemerally and forwarded as a message."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus())
    forwarded: list[dict] = []

    async def record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = record  # type: ignore[method-assign]
    client = DiscordBotClient(channel, intents=discord.Intents.none())
    interaction = _make_interaction(user_id=123, channel_id=456, interaction_id=321)

    new_cmd = client.tree.get_command("new")
    assert new_cmd is not None
    await new_cmd.callback(interaction)

    expected_ack = [{"content": "Processing /new...", "ephemeral": True}]
    assert interaction.response.messages == expected_ack
    assert len(forwarded) == 1
    call = forwarded[0]
    assert call["content"] == "/new"
    assert call["sender_id"] == "123"
    assert call["chat_id"] == "456"
    assert call["metadata"]["interaction_id"] == "321"
    assert call["metadata"]["is_slash_command"] is True
@pytest.mark.asyncio
async def test_slash_new_is_blocked_for_disallowed_user() -> None:
    """``/new`` from a user outside ``allow_from`` gets a refusal and nothing is forwarded."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["999"]), MessageBus())
    forwarded: list[dict] = []

    async def record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = record  # type: ignore[method-assign]
    client = DiscordBotClient(channel, intents=discord.Intents.none())
    interaction = _make_interaction(user_id=123, channel_id=456)

    new_cmd = client.tree.get_command("new")
    assert new_cmd is not None
    await new_cmd.callback(interaction)

    expected_refusal = [{"content": "You are not allowed to use this bot.", "ephemeral": True}]
    assert interaction.response.messages == expected_refusal
    assert forwarded == []
@pytest.mark.parametrize("slash_name", ["stop", "restart", "status"])
@pytest.mark.asyncio
async def test_slash_commands_forward_via_handle_message(slash_name: str) -> None:
    """Each listed slash command is acknowledged and forwarded as ``/<name>`` text."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    forwarded: list[dict] = []

    async def record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = record  # type: ignore[method-assign]
    client = DiscordBotClient(channel, intents=discord.Intents.none())
    interaction = _make_interaction()
    # The factory defaults the command name to "new"; point it at the command under test.
    interaction.command.qualified_name = slash_name

    cmd = client.tree.get_command(slash_name)
    assert cmd is not None
    await cmd.callback(interaction)

    expected_ack = [{"content": f"Processing /{slash_name}...", "ephemeral": True}]
    assert interaction.response.messages == expected_ack
    assert len(forwarded) == 1
    call = forwarded[0]
    assert call["content"] == f"/{slash_name}"
    assert call["metadata"]["is_slash_command"] is True
@pytest.mark.asyncio
async def test_slash_help_returns_ephemeral_help_text() -> None:
    """``/help`` replies with ``build_help_text()`` ephemerally and forwards nothing."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    forwarded: list[dict] = []

    async def record(**kwargs) -> None:
        forwarded.append(kwargs)

    channel._handle_message = record  # type: ignore[method-assign]
    client = DiscordBotClient(channel, intents=discord.Intents.none())
    interaction = _make_interaction()
    interaction.command.qualified_name = "help"

    help_cmd = client.tree.get_command("help")
    assert help_cmd is not None
    await help_cmd.callback(interaction)

    expected_reply = [{"content": build_help_text(), "ephemeral": True}]
    assert interaction.response.messages == expected_reply
    assert forwarded == []
@pytest.mark.asyncio
async def test_client_send_outbound_chunks_text_replies_and_uploads_files(tmp_path) -> None:
    """One file plus 2100 characters of text produce three sends.

    The first send is the file upload carrying the reply reference; the text
    follows as a 2000-character chunk and a 100-character chunk.
    """
    owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    client = DiscordBotClient(owner, intents=discord.Intents.none())
    target = _FakeChannel(channel_id=123)

    def lookup_channel(channel_id):
        return target if channel_id == 123 else None

    client.get_channel = lookup_channel  # type: ignore[method-assign]

    file_path = tmp_path / "demo.txt"
    file_path.write_text("hi")

    outbound = OutboundMessage(
        channel="discord",
        chat_id="123",
        content="a" * 2100,
        reply_to="55",
        media=[str(file_path)],
    )
    await client.send_outbound(outbound)

    payloads = target.sent_payloads
    assert len(payloads) == 3
    assert payloads[0]["file_name"] == "demo.txt"
    assert payloads[0]["reference"].id == 55
    assert payloads[1]["content"] == "a" * 2000
    assert payloads[2]["content"] == "a" * 100
@pytest.mark.asyncio
async def test_client_send_outbound_reports_failed_attachments_when_no_text(tmp_path) -> None:
    """A nonexistent media path with empty text results in a single failure placeholder send."""
    owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    client = DiscordBotClient(owner, intents=discord.Intents.none())
    target = _FakeChannel(channel_id=123)

    def lookup_channel(channel_id):
        return target if channel_id == 123 else None

    client.get_channel = lookup_channel  # type: ignore[method-assign]

    # Deliberately never created on disk.
    missing_file = tmp_path / "missing.txt"

    outbound = OutboundMessage(
        channel="discord",
        chat_id="123",
        content="",
        media=[str(missing_file)],
    )
    await client.send_outbound(outbound)

    assert target.sent_payloads == [{"content": "[attachment: missing.txt - send failed]"}]
@pytest.mark.asyncio
async def test_send_stops_typing_after_send() -> None:
    """A final send clears the typing task; a ``_progress`` send keeps it.

    Both scenarios start a typing task whose context entry blocks on an event,
    so the task is still alive when ``send()`` is called.
    """
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    client = _FakeDiscordClient(channel, intents=None)
    channel._client = client
    channel._running = True

    async def begin_blocked_typing() -> asyncio.Event:
        """Start typing on channel 123 and return the event that unblocks it."""
        entered = asyncio.Event()
        unblock = asyncio.Event()

        async def hold_typing() -> None:
            entered.set()
            await unblock.wait()

        typing_channel = _FakeChannel(channel_id=123)
        typing_channel.typing_enter_hook = hold_typing
        await channel._start_typing(typing_channel)
        await entered.wait()
        return unblock

    # Scenario 1: a plain send stops typing.
    release = await begin_blocked_typing()
    await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
    release.set()
    await asyncio.sleep(0)

    assert channel._typing_tasks == {}

    # Scenario 2: a progress send keeps typing until the final send.
    release = await begin_blocked_typing()
    progress_update = OutboundMessage(
        channel="discord",
        chat_id="123",
        content="progress",
        metadata={"_progress": True},
    )
    await channel.send(progress_update)

    assert "123" in channel._typing_tasks

    await channel.send(OutboundMessage(channel="discord", chat_id="123", content="final"))
    release.set()
    await asyncio.sleep(0)

    assert channel._typing_tasks == {}
@pytest.mark.asyncio
async def test_start_typing_uses_typing_context_when_trigger_typing_missing() -> None:
    """A channel exposing only ``typing()`` still gets a tracked typing task.

    The typing context blocks inside ``__aenter__`` so the task is still
    registered when it is inspected and then stopped explicitly.
    """
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    channel._running = True
    entered = asyncio.Event()
    release = asyncio.Event()

    class _BlockingTypingCtx:
        async def __aenter__(self):
            entered.set()
            # Hold the task open until the test releases it.
            await release.wait()

        async def __aexit__(self, exc_type, exc, tb):
            return False

    class _NoTriggerChannel:
        def __init__(self, channel_id: int = 123) -> None:
            self.id = channel_id

        def typing(self):
            return _BlockingTypingCtx()

    typing_channel = _NoTriggerChannel(channel_id=123)
    await channel._start_typing(typing_channel)  # type: ignore[arg-type]
    await entered.wait()

    assert "123" in channel._typing_tasks

    await channel._stop_typing("123")
    release.set()
    await asyncio.sleep(0)

    assert channel._typing_tasks == {}

View File

@ -3,6 +3,9 @@ from pathlib import Path
from types import SimpleNamespace
import pytest
from nio import RoomSendResponse
from nanobot.channels.matrix import _build_matrix_text_content
# Check optional matrix dependencies before importing
try:
@ -65,6 +68,7 @@ class _FakeAsyncClient:
self.raise_on_send = False
self.raise_on_typing = False
self.raise_on_upload = False
self.room_send_response: RoomSendResponse | None = RoomSendResponse(event_id="", room_id="")
def add_event_callback(self, callback, event_type) -> None:
self.callbacks.append((callback, event_type))
@ -87,7 +91,7 @@ class _FakeAsyncClient:
message_type: str,
content: dict[str, object],
ignore_unverified_devices: object = _ROOM_SEND_UNSET,
) -> None:
) -> RoomSendResponse:
call: dict[str, object] = {
"room_id": room_id,
"message_type": message_type,
@ -98,6 +102,7 @@ class _FakeAsyncClient:
self.room_send_calls.append(call)
if self.raise_on_send:
raise RuntimeError("send failed")
return self.room_send_response
async def room_typing(
self,
@ -520,6 +525,7 @@ async def test_on_message_room_mention_requires_opt_in() -> None:
source={"content": {"m.mentions": {"room": True}}},
)
channel.config.allow_room_mentions = False
await channel._on_message(room, room_mention_event)
assert handled == []
assert client.typing_calls == []
@ -1322,3 +1328,302 @@ async def test_send_keeps_plaintext_only_for_plain_text() -> None:
"body": text,
"m.mentions": {},
}
def test_build_matrix_text_content_basic_text() -> None:
    """Plain text yields exactly msgtype, body and an empty mentions mapping."""
    result = _build_matrix_text_content("Hello, World!")

    assert {
        "msgtype": "m.text",
        "body": "Hello, World!",
        "m.mentions": {},
    } == result
def test_build_matrix_text_content_with_markdown() -> None:
    """Markdown input keeps the raw body and adds a non-empty HTML ``formatted_body``."""
    markdown_source = "*Hello* **World**"

    result = _build_matrix_text_content(markdown_source)

    assert "msgtype" in result
    assert "body" in result
    assert result["body"] == markdown_source
    assert "format" in result
    assert result["format"] == "org.matrix.custom.html"
    assert "formatted_body" in result
    rendered = result["formatted_body"]
    assert isinstance(rendered, str)
    assert len(rendered) > 0
def test_build_matrix_text_content_with_event_id() -> None:
    """Passing an event id produces an ``m.replace`` relation plus ``m.new_content``."""
    target_event = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"

    result = _build_matrix_text_content("Updated message", target_event)

    assert "msgtype" in result
    assert "body" in result
    assert result["m.new_content"]
    assert result["m.new_content"]["body"] == "Updated message"
    relation = result["m.relates_to"]
    assert relation["rel_type"] == "m.replace"
    assert relation["event_id"] == target_event
def test_build_matrix_text_content_with_event_id_preserves_thread_relation() -> None:
    """For an edit, the thread relation lives in ``m.new_content``; the top level is ``m.replace``."""
    thread_relation = {
        "rel_type": "m.thread",
        "event_id": "$root1",
        "m.in_reply_to": {"event_id": "$reply1"},
        "is_falling_back": True,
    }

    result = _build_matrix_text_content("Updated message", "event-1", thread_relation)

    expected_replace = {"rel_type": "m.replace", "event_id": "event-1"}
    assert result["m.relates_to"] == expected_replace
    assert result["m.new_content"]["m.relates_to"] == thread_relation
def test_build_matrix_text_content_no_event_id() -> None:
    """Without an event id, neither replacement nor HTML formatting keys are emitted."""
    result = _build_matrix_text_content("Regular message")

    # Required properties.
    assert "msgtype" in result
    assert "body" in result
    assert result["body"] == "Regular message"

    # Keys that only belong to edits or formatted messages.
    assert "m.relates_to" not in result
    assert "m.new_content" not in result
    assert "format" not in result
    assert "formatted_body" not in result
def test_build_matrix_text_content_plain_text_no_html() -> None:
    """Plain text carries msgtype and body but no HTML formatting keys."""
    result = _build_matrix_text_content("Simple plain text")

    assert "msgtype" in result
    assert "body" in result
    assert "format" not in result
    assert "formatted_body" not in result
@pytest.mark.asyncio
async def test_send_room_content_returns_room_send_response():
    """``_send_room_content`` hands back the very object returned by ``client.room_send``."""
    client = _FakeAsyncClient("", "", "", None)
    channel = MatrixChannel(_make_config(), MessageBus())
    channel.client = client

    payload = {"msgtype": "m.text", "body": "Hello World"}
    result = await channel._send_room_content("!test_room:matrix.org", payload)

    assert result is client.room_send_response
@pytest.mark.asyncio
async def test_send_delta_creates_stream_buffer_and_sends_initial_message() -> None:
    """The first delta for a room sends one message and records its event id in a buffer."""
    room_id = "!room:matrix.org"
    sent_event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"

    channel = MatrixChannel(_make_config(), MessageBus())
    client = _FakeAsyncClient("", "", "", None)
    channel.client = client
    client.room_send_response.event_id = sent_event_id

    await channel.send_delta(room_id, "Hello")

    assert room_id in channel._stream_bufs
    buf = channel._stream_bufs[room_id]
    assert buf.text == "Hello"
    assert buf.event_id == sent_event_id
    assert len(client.room_send_calls) == 1
    assert client.room_send_calls[0]["content"]["body"] == "Hello"
@pytest.mark.asyncio
async def test_send_delta_appends_without_sending_before_edit_interval(monkeypatch) -> None:
    """With the clock frozen, a second delta is buffered and no second send happens."""
    room_id = "!room:matrix.org"
    sent_event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"

    channel = MatrixChannel(_make_config(), MessageBus())
    client = _FakeAsyncClient("", "", "", None)
    channel.client = client
    client.room_send_response.event_id = sent_event_id

    frozen_now = 100.0
    monkeypatch.setattr(channel, "monotonic_time", lambda: frozen_now)

    await channel.send_delta(room_id, "Hello")
    assert len(client.room_send_calls) == 1

    await channel.send_delta(room_id, " world")
    assert len(client.room_send_calls) == 1

    buf = channel._stream_bufs[room_id]
    assert buf.text == "Hello world"
    assert buf.event_id == sent_event_id
@pytest.mark.asyncio
async def test_send_delta_edits_again_after_interval(monkeypatch) -> None:
    """With the fake clock advancing 2.0 per reading, the second delta is sent as an edit."""
    room_id = "!room:matrix.org"
    sent_event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"

    channel = MatrixChannel(_make_config(), MessageBus())
    client = _FakeAsyncClient("", "", "", None)
    channel.client = client
    client.room_send_response.event_id = sent_event_id

    # Stored newest-first so pop() from the end yields 100.0, 102.0, ... in order.
    pending_ticks = [108.0, 106.0, 104.0, 102.0, 100.0]

    def fake_clock():
        return pending_ticks and pending_ticks.pop()

    monkeypatch.setattr(channel, "monotonic_time", fake_clock)

    await channel.send_delta(room_id, "Hello")
    await channel.send_delta(room_id, " world")

    assert len(client.room_send_calls) == 2
    initial = client.room_send_calls[0]["content"]
    edit = client.room_send_calls[1]["content"]

    assert "body" in initial
    assert initial["body"] == "Hello"
    assert "m.relates_to" not in initial

    assert "body" in edit
    assert "m.relates_to" in edit
    assert edit["body"] == "Hello world"
    assert edit["m.relates_to"] == {
        "rel_type": "m.replace",
        "event_id": sent_event_id,
    }
@pytest.mark.asyncio
async def test_send_delta_stream_end_replaces_existing_message() -> None:
    """A stream-end delta sends the buffered text as an edit, drops the buffer and clears typing."""
    room_id = "!room:matrix.org"
    channel = MatrixChannel(_make_config(), MessageBus())
    fake_client = _FakeAsyncClient("", "", "", None)
    channel.client = fake_client

    # Pre-seed a buffer as if a stream were already in progress.
    pending = matrix_module._StreamBuf(
        text="Final text",
        event_id="event-1",
        last_edit=100.0,
    )
    channel._stream_bufs[room_id] = pending

    await channel.send_delta(room_id, "", {"_stream_end": True})

    assert room_id not in channel._stream_bufs
    assert fake_client.typing_calls[-1] == (room_id, False, TYPING_NOTICE_TIMEOUT_MS)
    assert len(fake_client.room_send_calls) == 1

    final_content = fake_client.room_send_calls[0]["content"]
    assert final_content["body"] == "Final text"
    assert final_content["m.relates_to"] == {
        "rel_type": "m.replace",
        "event_id": "event-1",
    }
@pytest.mark.asyncio
async def test_send_delta_starts_threaded_stream_inside_thread() -> None:
    """The first delta carrying thread metadata is sent with an m.thread relation."""
    channel = MatrixChannel(_make_config(), MessageBus())
    fake_client = _FakeAsyncClient("", "", "", None)
    channel.client = fake_client
    fake_client.room_send_response.event_id = "event-1"

    thread_metadata = {
        "thread_root_event_id": "$root1",
        "thread_reply_to_event_id": "$reply1",
    }
    expected_relation = {
        "rel_type": "m.thread",
        "event_id": "$root1",
        "m.in_reply_to": {"event_id": "$reply1"},
        "is_falling_back": True,
    }

    await channel.send_delta("!room:matrix.org", "Hello", thread_metadata)

    first_content = fake_client.room_send_calls[0]["content"]
    assert first_content["m.relates_to"] == expected_relation
@pytest.mark.asyncio
async def test_send_delta_threaded_edit_keeps_replace_and_thread_relation(monkeypatch) -> None:
    """Threaded edits use m.replace at the top level and keep m.thread inside m.new_content."""
    room_id = "!room:matrix.org"
    channel = MatrixChannel(_make_config(), MessageBus())
    fake_client = _FakeAsyncClient("", "", "", None)
    channel.client = fake_client
    fake_client.room_send_response.event_id = "event-1"

    # Stored latest-first so that pop() yields 100.0, 102.0, 104.0 in order.
    pending_ticks = [104.0, 102.0, 100.0]
    monkeypatch.setattr(channel, "monotonic_time", lambda: pending_ticks and pending_ticks.pop())

    thread_metadata = {
        "thread_root_event_id": "$root1",
        "thread_reply_to_event_id": "$reply1",
    }
    await channel.send_delta(room_id, "Hello", thread_metadata)
    await channel.send_delta(room_id, " world", thread_metadata)
    await channel.send_delta(room_id, "", {"_stream_end": True, **thread_metadata})

    replace_relation = {
        "rel_type": "m.replace",
        "event_id": "event-1",
    }
    thread_relation = {
        "rel_type": "m.thread",
        "event_id": "$root1",
        "m.in_reply_to": {"event_id": "$reply1"},
        "is_falling_back": True,
    }

    # Call 1 is the intermediate edit, call 2 the stream-end edit.
    edit_content = fake_client.room_send_calls[1]["content"]
    final_content = fake_client.room_send_calls[2]["content"]

    assert edit_content["m.relates_to"] == replace_relation
    assert edit_content["m.new_content"]["m.relates_to"] == thread_relation
    assert final_content["m.relates_to"] == replace_relation
    assert final_content["m.new_content"]["m.relates_to"] == thread_relation
@pytest.mark.asyncio
async def test_send_delta_stream_end_noop_when_buffer_missing() -> None:
    """Stream-end for a room with no buffer produces no room send and no typing call."""
    channel = MatrixChannel(_make_config(), MessageBus())
    fake_client = _FakeAsyncClient("", "", "", None)
    channel.client = fake_client

    end_marker = {"_stream_end": True}
    await channel.send_delta("!room:matrix.org", "", end_marker)

    assert fake_client.room_send_calls == []
    assert fake_client.typing_calls == []
@pytest.mark.asyncio
async def test_send_delta_on_error_stops_typing(monkeypatch) -> None:
    """With the fake client's raise_on_send flag set, the delta stays buffered.

    Exactly one send attempt and one typing call are expected to be recorded.
    """
    room_id = "!room:matrix.org"
    channel = MatrixChannel(_make_config(), MessageBus())
    fake_client = _FakeAsyncClient("", "", "", None)
    fake_client.raise_on_send = True
    channel.client = fake_client

    frozen_time = 100.0
    monkeypatch.setattr(channel, "monotonic_time", lambda: frozen_time)

    await channel.send_delta(room_id, "Hello", {"room_id": "!room:matrix.org"})

    assert room_id in channel._stream_bufs
    assert channel._stream_bufs[room_id].text == "Hello"
    assert len(fake_client.room_send_calls) == 1
    assert len(fake_client.typing_calls) == 1
@pytest.mark.asyncio
async def test_send_delta_ignores_whitespace_only_delta(monkeypatch) -> None:
    """A whitespace-only delta is kept in the buffer but nothing is sent to the room."""
    room_id = "!room:matrix.org"
    channel = MatrixChannel(_make_config(), MessageBus())
    fake_client = _FakeAsyncClient("", "", "", None)
    channel.client = fake_client

    frozen_time = 100.0
    monkeypatch.setattr(channel, "monotonic_time", lambda: frozen_time)

    await channel.send_delta(room_id, " ")

    assert room_id in channel._stream_bufs
    assert channel._stream_bufs[room_id].text == " "
    assert fake_client.room_send_calls == []

View File

@ -1,17 +1,22 @@
import asyncio
import json
import tempfile
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
import httpx
import nanobot.channels.weixin as weixin_mod
from nanobot.bus.queue import MessageBus
from nanobot.channels.weixin import (
ITEM_IMAGE,
ITEM_TEXT,
MESSAGE_TYPE_BOT,
WEIXIN_CHANNEL_VERSION,
_decrypt_aes_ecb,
_encrypt_aes_ecb,
WeixinChannel,
WeixinConfig,
)
@ -42,10 +47,12 @@ def test_make_headers_includes_route_tag_when_configured() -> None:
assert headers["Authorization"] == "Bearer token"
assert headers["SKRouteTag"] == "123"
assert headers["iLink-App-Id"] == "bot"
assert headers["iLink-App-ClientVersion"] == str((2 << 16) | (1 << 8) | 1)
def test_channel_version_matches_reference_plugin_version() -> None:
    """The channel version constant must equal "2.1.1"."""
    # Only one expected value can hold; "2.1.1" is the version whose packed
    # form is (2 << 16) | (1 << 8) | 1.
    assert WEIXIN_CHANNEL_VERSION == "2.1.1"
def test_save_and_load_state_persists_context_tokens(tmp_path) -> None:
@ -169,6 +176,120 @@ async def test_process_message_extracts_media_and_preserves_paths() -> None:
assert inbound.media == ["/tmp/test.jpg"]
@pytest.mark.asyncio
async def test_process_message_falls_back_to_referenced_media_when_no_top_level_media() -> None:
    """With no top-level media item, the image inside ref_msg is downloaded instead."""
    channel, bus = _make_channel()
    channel._download_media_item = AsyncMock(return_value="/tmp/ref.jpg")

    quoted_image = {
        "type": ITEM_IMAGE,
        "image_item": {"media": {"encrypt_query_param": "ref-enc"}},
    }
    text_item = {
        "type": ITEM_TEXT,
        "text_item": {"text": "reply to image"},
        "ref_msg": {"message_item": quoted_image},
    }
    raw_message = {
        "message_type": 1,
        "message_id": "m3-ref-fallback",
        "from_user_id": "wx-user",
        "context_token": "ctx-3-ref-fallback",
        "item_list": [text_item],
    }
    await channel._process_message(raw_message)

    inbound = await asyncio.wait_for(bus.consume_inbound(), timeout=1.0)

    channel._download_media_item.assert_awaited_once_with(
        {"media": {"encrypt_query_param": "ref-enc"}},
        "image",
    )
    assert inbound.media == ["/tmp/ref.jpg"]
    for expected_fragment in ("reply to image", "[image]"):
        assert expected_fragment in inbound.content
@pytest.mark.asyncio
async def test_process_message_does_not_use_referenced_fallback_when_top_level_media_exists() -> None:
    """When a top-level image is present, only that image is downloaded; ref_msg media is ignored."""
    channel, bus = _make_channel()
    # The second value would only be consumed if the reference fallback ran.
    channel._download_media_item = AsyncMock(side_effect=["/tmp/top.jpg", "/tmp/ref.jpg"])

    top_level_image = {
        "type": ITEM_IMAGE,
        "image_item": {"media": {"encrypt_query_param": "top-enc"}},
    }
    text_with_quote = {
        "type": ITEM_TEXT,
        "text_item": {"text": "has top-level media"},
        "ref_msg": {
            "message_item": {
                "type": ITEM_IMAGE,
                "image_item": {"media": {"encrypt_query_param": "ref-enc"}},
            },
        },
    }
    raw_message = {
        "message_type": 1,
        "message_id": "m3-ref-no-fallback",
        "from_user_id": "wx-user",
        "context_token": "ctx-3-ref-no-fallback",
        "item_list": [top_level_image, text_with_quote],
    }
    await channel._process_message(raw_message)

    inbound = await asyncio.wait_for(bus.consume_inbound(), timeout=1.0)

    channel._download_media_item.assert_awaited_once_with(
        {"media": {"encrypt_query_param": "top-enc"}},
        "image",
    )
    assert inbound.media == ["/tmp/top.jpg"]
    assert "/tmp/ref.jpg" not in inbound.content
@pytest.mark.asyncio
async def test_process_message_does_not_fallback_when_top_level_media_exists_but_download_fails() -> None:
    """A failed top-level download must not trigger the referenced-media fallback."""
    channel, bus = _make_channel()
    # First await (top-level image) yields None; the second value would only be
    # handed out if the fallback were wrongly triggered.
    channel._download_media_item = AsyncMock(side_effect=[None, "/tmp/ref.jpg"])

    top_level_image = {
        "type": ITEM_IMAGE,
        "image_item": {"media": {"encrypt_query_param": "top-enc"}},
    }
    text_with_quote = {
        "type": ITEM_TEXT,
        "text_item": {"text": "quoted has media"},
        "ref_msg": {
            "message_item": {
                "type": ITEM_IMAGE,
                "image_item": {"media": {"encrypt_query_param": "ref-enc"}},
            },
        },
    }
    raw_message = {
        "message_type": 1,
        "message_id": "m3-ref-no-fallback-on-failure",
        "from_user_id": "wx-user",
        "context_token": "ctx-3-ref-no-fallback-on-failure",
        "item_list": [top_level_image, text_with_quote],
    }
    await channel._process_message(raw_message)

    inbound = await asyncio.wait_for(bus.consume_inbound(), timeout=1.0)

    # Only the top-level item may have been attempted.
    channel._download_media_item.assert_awaited_once_with(
        {"media": {"encrypt_query_param": "top-enc"}},
        "image",
    )
    assert inbound.media == []
    assert "[image]" in inbound.content
    assert "/tmp/ref.jpg" not in inbound.content
@pytest.mark.asyncio
async def test_send_without_context_token_does_not_send_text() -> None:
channel, _bus = _make_channel()
@ -199,6 +320,70 @@ async def test_send_does_not_send_when_session_is_paused() -> None:
channel._send_text.assert_not_awaited()
@pytest.mark.asyncio
async def test_get_typing_ticket_fetches_and_caches_per_user() -> None:
    """Two lookups for the same user return the same ticket from a single getconfig call."""
    channel, _bus = _make_channel()
    channel._client = object()
    channel._token = "token"
    channel._api_post = AsyncMock(return_value={"ret": 0, "typing_ticket": "ticket-1"})

    tickets = [
        await channel._get_typing_ticket("wx-user", "ctx-1"),
        await channel._get_typing_ticket("wx-user", "ctx-2"),
    ]

    assert tickets[0] == "ticket-1"
    assert tickets[1] == "ticket-1"
    # Only the first lookup (with its context token) reached the API.
    expected_body = {
        "ilink_user_id": "wx-user",
        "context_token": "ctx-1",
        "base_info": weixin_mod.BASE_INFO,
    }
    channel._api_post.assert_awaited_once_with("ilink/bot/getconfig", expected_body)
@pytest.mark.asyncio
async def test_send_uses_typing_start_and_cancel_when_ticket_available() -> None:
    """send() posts getconfig, then sendtyping status 1, then sendtyping status 2."""
    channel, _bus = _make_channel()
    channel._client = object()
    channel._token = "token"
    channel._context_tokens["wx-user"] = "ctx-typing"
    channel._send_text = AsyncMock()
    api_responses = [
        {"ret": 0, "typing_ticket": "ticket-typing"},
        {"ret": 0},
        {"ret": 0},
    ]
    channel._api_post = AsyncMock(side_effect=api_responses)

    outbound = type(
        "Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}}
    )()
    await channel.send(outbound)

    channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-typing")
    assert channel._api_post.await_count == 3

    config_call, start_call, cancel_call = channel._api_post.await_args_list
    assert config_call.args[0] == "ilink/bot/getconfig"
    assert start_call.args[0] == "ilink/bot/sendtyping"
    assert start_call.args[1]["status"] == 1
    assert cancel_call.args[0] == "ilink/bot/sendtyping"
    assert cancel_call.args[1]["status"] == 2
@pytest.mark.asyncio
async def test_send_still_sends_text_when_typing_ticket_missing() -> None:
    """If getconfig returns no ticket, the text is still sent and no sendtyping call is made."""
    channel, _bus = _make_channel()
    channel._client = object()
    channel._token = "token"
    channel._context_tokens["wx-user"] = "ctx-no-ticket"
    channel._send_text = AsyncMock()
    failed_config = {"ret": 1, "errmsg": "no config"}
    channel._api_post = AsyncMock(return_value=failed_config)

    outbound = type(
        "Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}}
    )()
    await channel.send(outbound)

    channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-no-ticket")
    channel._api_post.assert_awaited_once()
    only_call = channel._api_post.await_args_list[0]
    assert only_call.args[0] == "ilink/bot/getconfig"
@pytest.mark.asyncio
async def test_poll_once_pauses_session_on_expired_errcode() -> None:
channel, _bus = _make_channel()
@ -220,8 +405,12 @@ async def test_qr_login_refreshes_expired_qr_and_then_succeeds() -> None:
channel._api_get = AsyncMock(
side_effect=[
{"qrcode": "qr-1", "qrcode_img_content": "url-1"},
{"status": "expired"},
{"qrcode": "qr-2", "qrcode_img_content": "url-2"},
]
)
channel._api_get_with_base = AsyncMock(
side_effect=[
{"status": "expired"},
{
"status": "confirmed",
"bot_token": "token-2",
@ -247,12 +436,16 @@ async def test_qr_login_returns_false_after_too_many_expired_qr_codes() -> None:
channel._api_get = AsyncMock(
side_effect=[
{"qrcode": "qr-1", "qrcode_img_content": "url-1"},
{"status": "expired"},
{"qrcode": "qr-2", "qrcode_img_content": "url-2"},
{"status": "expired"},
{"qrcode": "qr-3", "qrcode_img_content": "url-3"},
{"status": "expired"},
{"qrcode": "qr-4", "qrcode_img_content": "url-4"},
]
)
channel._api_get_with_base = AsyncMock(
side_effect=[
{"status": "expired"},
{"status": "expired"},
{"status": "expired"},
{"status": "expired"},
]
)
@ -262,6 +455,105 @@ async def test_qr_login_returns_false_after_too_many_expired_qr_codes() -> None:
assert ok is False
@pytest.mark.asyncio
async def test_qr_login_switches_polling_base_url_on_redirect_status() -> None:
    """After a scaned_but_redirect status with a host, the next poll uses that host."""
    channel, _bus = _make_channel()
    channel._running = True
    channel._save_state = lambda: None
    channel._print_qr_code = lambda url: None
    channel._fetch_qr_code = AsyncMock(return_value=("qr-1", "url-1"))

    redirect_status = {"status": "scaned_but_redirect", "redirect_host": "idc.redirect.test"}
    confirmed_status = {
        "status": "confirmed",
        "bot_token": "token-3",
        "ilink_bot_id": "bot-3",
        "baseurl": "https://example.test",
        "ilink_user_id": "wx-user",
    }
    # Each mock gets its own copy of the response sequence.
    channel._api_get = AsyncMock(side_effect=[redirect_status, confirmed_status])
    channel._api_get_with_base = AsyncMock(side_effect=[redirect_status, confirmed_status])

    ok = await channel._qr_login()

    assert ok is True
    assert channel._token == "token-3"
    assert channel._api_get_with_base.await_count == 2

    poll_calls = channel._api_get_with_base.await_args_list
    assert poll_calls[0].kwargs["base_url"] == "https://ilinkai.weixin.qq.com"
    assert poll_calls[1].kwargs["base_url"] == "https://idc.redirect.test"
@pytest.mark.asyncio
async def test_qr_login_redirect_without_host_keeps_current_polling_base_url() -> None:
    """A redirect status lacking redirect_host leaves the polling base URL unchanged."""
    channel, _bus = _make_channel()
    channel._running = True
    channel._save_state = lambda: None
    channel._print_qr_code = lambda url: None
    channel._fetch_qr_code = AsyncMock(return_value=("qr-1", "url-1"))

    hostless_redirect = {"status": "scaned_but_redirect"}
    confirmed_status = {
        "status": "confirmed",
        "bot_token": "token-4",
        "ilink_bot_id": "bot-4",
        "baseurl": "https://example.test",
        "ilink_user_id": "wx-user",
    }
    # Each mock gets its own copy of the response sequence.
    channel._api_get = AsyncMock(side_effect=[hostless_redirect, confirmed_status])
    channel._api_get_with_base = AsyncMock(side_effect=[hostless_redirect, confirmed_status])

    ok = await channel._qr_login()

    assert ok is True
    assert channel._token == "token-4"
    assert channel._api_get_with_base.await_count == 2

    poll_calls = channel._api_get_with_base.await_args_list
    assert poll_calls[0].kwargs["base_url"] == "https://ilinkai.weixin.qq.com"
    assert poll_calls[1].kwargs["base_url"] == "https://ilinkai.weixin.qq.com"
@pytest.mark.asyncio
async def test_qr_login_resets_redirect_base_url_after_qr_refresh() -> None:
    """Once the QR code expires and is refreshed, polling returns to the default base URL."""
    channel, _bus = _make_channel()
    channel._running = True
    channel._save_state = lambda: None
    channel._print_qr_code = lambda url: None
    channel._fetch_qr_code = AsyncMock(side_effect=[("qr-1", "url-1"), ("qr-2", "url-2")])

    status_sequence = [
        {"status": "scaned_but_redirect", "redirect_host": "idc.redirect.test"},
        {"status": "expired"},
        {
            "status": "confirmed",
            "bot_token": "token-5",
            "ilink_bot_id": "bot-5",
            "baseurl": "https://example.test",
            "ilink_user_id": "wx-user",
        },
    ]
    channel._api_get_with_base = AsyncMock(side_effect=status_sequence)

    ok = await channel._qr_login()

    assert ok is True
    assert channel._token == "token-5"
    assert channel._api_get_with_base.await_count == 3

    before_redirect, after_redirect, after_refresh = channel._api_get_with_base.await_args_list
    assert before_redirect.kwargs["base_url"] == "https://ilinkai.weixin.qq.com"
    assert after_redirect.kwargs["base_url"] == "https://idc.redirect.test"
    assert after_refresh.kwargs["base_url"] == "https://ilinkai.weixin.qq.com"
@pytest.mark.asyncio
async def test_process_message_skips_bot_messages() -> None:
channel, bus = _make_channel()
@ -281,12 +573,13 @@ async def test_process_message_skips_bot_messages() -> None:
@pytest.mark.asyncio
async def test_process_message_fetches_typing_ticket_and_starts_typing() -> None:
async def test_process_message_starts_typing_on_inbound() -> None:
"""Typing indicator fires immediately when user message arrives."""
channel, _bus = _make_channel()
channel._running = True
channel._client = object()
channel._token = "token"
channel._api_post = AsyncMock(return_value={"typing_ticket": "ticket-1"})
channel._start_typing = AsyncMock()
await channel._process_message(
{
@ -300,42 +593,42 @@ async def test_process_message_fetches_typing_ticket_and_starts_typing() -> None
}
)
assert channel._typing_tickets["wx-user"] == "ticket-1"
assert "wx-user" in channel._typing_tasks
await channel._stop_typing("wx-user", clear_remote=False)
channel._start_typing.assert_awaited_once_with("wx-user", "ctx-typing")
@pytest.mark.asyncio
async def test_send_final_message_clears_typing_indicator() -> None:
    """Non-progress send should cancel typing status."""
    channel, _bus = _make_channel()
    channel._client = object()
    channel._token = "token"
    channel._context_tokens["wx-user"] = "ctx-2"
    # Seed the ticket cache with a far-future next_fetch_at; presumably this keeps
    # send() from refetching the ticket -- confirm against _get_typing_ticket.
    channel._typing_tickets["wx-user"] = {"ticket": "ticket-2", "next_fetch_at": 9999999999}
    channel._send_text = AsyncMock()
    channel._api_post = AsyncMock(return_value={"ret": 0})

    await channel.send(
        type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})()
    )

    channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-2")
    # send() may post more than one sendtyping request, so look for the cancel
    # (status 2) among all calls rather than requiring a single call.
    typing_cancel_calls = [
        c for c in channel._api_post.await_args_list
        if c.args[0] == "ilink/bot/sendtyping" and c.args[1]["status"] == 2
    ]
    assert len(typing_cancel_calls) >= 1
@pytest.mark.asyncio
async def test_send_progress_message_keeps_typing_indicator() -> None:
"""Progress messages must not cancel typing status."""
channel, _bus = _make_channel()
channel._client = object()
channel._token = "token"
channel._context_tokens["wx-user"] = "ctx-2"
channel._typing_tickets["wx-user"] = "ticket-2"
channel._typing_tickets["wx-user"] = {"ticket": "ticket-2", "next_fetch_at": 9999999999}
channel._send_text = AsyncMock()
channel._api_post = AsyncMock(return_value={})
channel._api_post = AsyncMock(return_value={"ret": 0})
await channel.send(
type(
@ -351,4 +644,362 @@ async def test_send_progress_message_keeps_typing_indicator() -> None:
)
channel._send_text.assert_awaited_once_with("wx-user", "thinking", "ctx-2")
channel._api_post.assert_not_awaited()
typing_cancel_calls = [
c for c in channel._api_post.await_args_list
if c.args and c.args[0] == "ilink/bot/sendtyping" and c.args[1].get("status") == 2
]
assert len(typing_cancel_calls) == 0
class _DummyHttpResponse:
def __init__(self, *, headers: dict[str, str] | None = None, status_code: int = 200) -> None:
self.headers = headers or {}
self.status_code = status_code
def raise_for_status(self) -> None:
return None
@pytest.mark.asyncio
async def test_send_media_uses_upload_full_url_when_present(tmp_path) -> None:
    """When the upload response carries upload_full_url, the CDN POST goes to exactly that URL."""
    channel, _bus = _make_channel()

    media_file = tmp_path / "photo.jpg"
    media_file.write_bytes(b"hello-weixin")

    upload_response = _DummyHttpResponse(headers={"x-encrypted-param": "dl-param"})
    cdn_post = AsyncMock(return_value=upload_response)
    channel._client = SimpleNamespace(post=cdn_post)

    upload_info = {
        "upload_full_url": "https://upload-full.example.test/path?foo=bar",
        "upload_param": "should-not-be-used",
    }
    channel._api_post = AsyncMock(side_effect=[upload_info, {"ret": 0}])

    await channel._send_media_file("wx-user", str(media_file), "ctx-1")

    # The first POST on the client is the CDN upload.
    first_post = cdn_post.await_args_list[0]
    assert first_post.args[0] == "https://upload-full.example.test/path?foo=bar"
@pytest.mark.asyncio
async def test_send_media_falls_back_to_upload_param_url(tmp_path) -> None:
    """Without upload_full_url, the CDN URL is built from cdn_base_url and upload_param."""
    channel, _bus = _make_channel()

    media_file = tmp_path / "photo.jpg"
    media_file.write_bytes(b"hello-weixin")

    upload_response = _DummyHttpResponse(headers={"x-encrypted-param": "dl-param"})
    cdn_post = AsyncMock(return_value=upload_response)
    channel._client = SimpleNamespace(post=cdn_post)
    channel._api_post = AsyncMock(
        side_effect=[{"upload_param": "enc-need-fallback"}, {"ret": 0}]
    )

    await channel._send_media_file("wx-user", str(media_file), "ctx-1")

    cdn_url = cdn_post.await_args_list[0].args[0]
    expected_prefix = f"{channel.config.cdn_base_url}/upload?encrypted_query_param=enc-need-fallback"
    assert cdn_url.startswith(expected_prefix)
    assert "&filekey=" in cdn_url
@pytest.mark.asyncio
async def test_send_media_voice_file_uses_voice_item_and_voice_upload_type(tmp_path) -> None:
    """An .mp3 file is uploaded with media_type 4 and sent as a type-3 voice_item."""
    channel, _bus = _make_channel()

    media_file = tmp_path / "voice.mp3"
    media_file.write_bytes(b"voice-bytes")

    upload_response = _DummyHttpResponse(headers={"x-encrypted-param": "voice-dl-param"})
    cdn_post = AsyncMock(return_value=upload_response)
    channel._client = SimpleNamespace(post=cdn_post)
    channel._api_post = AsyncMock(
        side_effect=[
            {"upload_full_url": "https://upload-full.example.test/voice?foo=bar"},
            {"ret": 0},
        ]
    )

    await channel._send_media_file("wx-user", str(media_file), "ctx-voice")

    # First API call requests the upload, second sends the message.
    upload_call, send_call = channel._api_post.await_args_list[:2]
    assert upload_call.args[1]["media_type"] == 4

    item = send_call.args[1]["msg"]["item_list"][0]
    assert item["type"] == 3
    assert "voice_item" in item
    assert "file_item" not in item
    assert item["voice_item"]["media"]["encrypt_query_param"] == "voice-dl-param"
@pytest.mark.asyncio
async def test_send_typing_uses_keepalive_until_send_finishes() -> None:
    """While a slow send is in flight, typing status 1 repeats and status 2 comes last."""
    channel, _bus = _make_channel()
    channel._client = object()
    channel._token = "token"
    channel._context_tokens["wx-user"] = "ctx-typing-loop"

    async def _fake_api_post(endpoint: str, _body: dict | None = None, *, auth: bool = True):
        # Only getconfig needs to hand back a ticket; everything else just succeeds.
        if endpoint != "ilink/bot/getconfig":
            return {"ret": 0}
        return {"ret": 0, "typing_ticket": "ticket-keepalive"}

    async def _delayed_send_text(*_args, **_kwargs) -> None:
        # Long enough for several shortened keepalive intervals to elapse.
        await asyncio.sleep(0.03)

    channel._api_post = AsyncMock(side_effect=_fake_api_post)
    channel._send_text = AsyncMock(side_effect=_delayed_send_text)

    outbound = type(
        "Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}}
    )()

    # Shorten the module-level keepalive interval for this call, then restore it.
    saved_interval = weixin_mod.TYPING_KEEPALIVE_INTERVAL_S
    weixin_mod.TYPING_KEEPALIVE_INTERVAL_S = 0.01
    try:
        await channel.send(outbound)
    finally:
        weixin_mod.TYPING_KEEPALIVE_INTERVAL_S = saved_interval

    typing_statuses = []
    for call in channel._api_post.await_args_list:
        if call.args and call.args[0] == "ilink/bot/sendtyping":
            typing_statuses.append(call.args[1]["status"])

    assert typing_statuses.count(1) >= 2
    assert typing_statuses[-1] == 2
@pytest.mark.asyncio
async def test_get_typing_ticket_failure_uses_backoff_and_cached_ticket(monkeypatch) -> None:
    """A failed refresh returns the cached ticket and is not retried one second later."""
    channel, _bus = _make_channel()
    channel._client = object()
    channel._token = "token"

    # Controllable wall clock and fixed random value for the module under test.
    clock = {"value": 1000.0}
    monkeypatch.setattr(weixin_mod.time, "time", lambda: clock["value"])
    monkeypatch.setattr(weixin_mod.random, "random", lambda: 0.5)

    channel._api_post = AsyncMock(return_value={"ret": 0, "typing_ticket": "ticket-ok"})
    initial_ticket = await channel._get_typing_ticket("wx-user", "ctx-1")
    assert initial_ticket == "ticket-ok"

    # Jump just past twelve hours so the next lookup attempts a refresh.
    clock["value"] = clock["value"] + (12 * 60 * 60) + 1
    channel._api_post = AsyncMock(return_value={"ret": 1, "errmsg": "temporary failure"})

    # The refresh fails, yet the cached ticket is still returned.
    after_failed_refresh = await channel._get_typing_ticket("wx-user", "ctx-2")
    assert after_failed_refresh == "ticket-ok"
    assert channel._api_post.await_count == 1

    # One second later no further fetch should have been made.
    clock["value"] += 1
    during_backoff = await channel._get_typing_ticket("wx-user", "ctx-3")
    assert during_backoff == "ticket-ok"
    assert channel._api_post.await_count == 1
@pytest.mark.asyncio
async def test_qr_login_treats_temporary_connect_error_as_wait_and_recovers() -> None:
    """A ConnectError while polling does not abort login; the next confirmed status succeeds."""
    channel, _bus = _make_channel()
    channel._running = True
    channel._save_state = lambda: None
    channel._print_qr_code = lambda url: None
    channel._fetch_qr_code = AsyncMock(return_value=("qr-1", "url-1"))

    poll_request = httpx.Request("GET", "https://ilinkai.weixin.qq.com/ilink/bot/get_qrcode_status")
    network_failure = httpx.ConnectError("temporary network", request=poll_request)
    confirmed_status = {
        "status": "confirmed",
        "bot_token": "token-net-ok",
        "ilink_bot_id": "bot-id",
        "baseurl": "https://example.test",
        "ilink_user_id": "wx-user",
    }
    # An exception instance in side_effect is raised by the mock on that call.
    channel._api_get_with_base = AsyncMock(side_effect=[network_failure, confirmed_status])

    ok = await channel._qr_login()

    assert ok is True
    assert channel._token == "token-net-ok"
@pytest.mark.asyncio
async def test_qr_login_treats_5xx_gateway_response_error_as_wait_and_recovers() -> None:
    """An HTTPStatusError with status 524 while polling does not abort login."""
    channel, _bus = _make_channel()
    channel._running = True
    channel._save_state = lambda: None
    channel._print_qr_code = lambda url: None
    channel._fetch_qr_code = AsyncMock(return_value=("qr-1", "url-1"))

    poll_request = httpx.Request("GET", "https://ilinkai.weixin.qq.com/ilink/bot/get_qrcode_status")
    gateway_response = httpx.Response(status_code=524, request=poll_request)
    gateway_failure = httpx.HTTPStatusError(
        "gateway timeout", request=poll_request, response=gateway_response
    )
    confirmed_status = {
        "status": "confirmed",
        "bot_token": "token-5xx-ok",
        "ilink_bot_id": "bot-id",
        "baseurl": "https://example.test",
        "ilink_user_id": "wx-user",
    }
    # An exception instance in side_effect is raised by the mock on that call.
    channel._api_get_with_base = AsyncMock(side_effect=[gateway_failure, confirmed_status])

    ok = await channel._qr_login()

    assert ok is True
    assert channel._token == "token-5xx-ok"
def test_decrypt_aes_ecb_strips_valid_pkcs7_padding() -> None:
    """Encrypting then decrypting with the same key returns the original plaintext."""
    # base64("0123456789abcdef")
    key_b64 = "MDEyMzQ1Njc4OWFiY2RlZg=="
    original = b"hello-weixin-padding"

    round_tripped = _decrypt_aes_ecb(_encrypt_aes_ecb(original, key_b64), key_b64)

    assert round_tripped == original
class _DummyDownloadResponse:
    """Minimal download-response stand-in exposing content, status_code and raise_for_status."""

    def __init__(self, content: bytes, status_code: int = 200) -> None:
        self.status_code = status_code
        self.content = content

    def raise_for_status(self) -> None:
        """Do nothing; this stand-in never raises."""
        return None
class _DummyErrorDownloadResponse(_DummyDownloadResponse):
    """Download-response stand-in whose raise_for_status always raises httpx.HTTPStatusError."""

    def __init__(self, url: str, status_code: int) -> None:
        super().__init__(content=b"", status_code=status_code)
        # Kept so raise_for_status can build a request for the error.
        self._url = url

    def raise_for_status(self) -> None:
        """Raise HTTPStatusError built from a GET request to the stored URL and status code."""
        failed_request = httpx.Request("GET", self._url)
        failed_response = httpx.Response(self.status_code, request=failed_request)
        message = f"download failed with status {self.status_code}"
        raise httpx.HTTPStatusError(
            message,
            request=failed_request,
            response=failed_response,
        )
@pytest.mark.asyncio
async def test_download_media_item_uses_full_url_when_present(tmp_path) -> None:
    """An image whose media carries full_url is fetched from that URL; the saved file holds the response bytes."""
    from unittest.mock import patch

    channel, _bus = _make_channel()
    full_url = "https://cdn.example.test/download/full"
    channel._client = SimpleNamespace(
        get=AsyncMock(return_value=_DummyDownloadResponse(content=b"raw-image-bytes"))
    )
    item = {
        "media": {
            "full_url": full_url,
            "encrypt_query_param": "enc-fallback-should-not-be-used",
        },
    }

    # Patch get_media_dir only for the duration of the call; a bare module
    # assignment would leak this test's tmp_path into every later test.
    with patch.object(weixin_mod, "get_media_dir", lambda _name: tmp_path, create=True):
        saved_path = await channel._download_media_item(item, "image")

    assert saved_path is not None
    assert Path(saved_path).read_bytes() == b"raw-image-bytes"
    channel._client.get.assert_awaited_once_with(full_url)
@pytest.mark.asyncio
async def test_download_media_item_falls_back_when_full_url_returns_retryable_error(tmp_path) -> None:
    """A 500 from full_url is followed by a second GET built from encrypt_query_param."""
    from unittest.mock import patch

    channel, _bus = _make_channel()
    full_url = "https://cdn.example.test/download/full?taskid=123"
    channel._client = SimpleNamespace(
        get=AsyncMock(
            side_effect=[
                _DummyErrorDownloadResponse(full_url, 500),
                _DummyDownloadResponse(content=b"fallback-bytes"),
            ]
        )
    )
    item = {
        "media": {
            "full_url": full_url,
            "encrypt_query_param": "enc-fallback",
        },
    }

    # Patch get_media_dir only for the duration of the call; a bare module
    # assignment would leak this test's tmp_path into every later test.
    with patch.object(weixin_mod, "get_media_dir", lambda _name: tmp_path, create=True):
        saved_path = await channel._download_media_item(item, "image")

    assert saved_path is not None
    assert Path(saved_path).read_bytes() == b"fallback-bytes"
    assert channel._client.get.await_count == 2
    assert channel._client.get.await_args_list[0].args[0] == full_url
    fallback_url = channel._client.get.await_args_list[1].args[0]
    assert fallback_url.startswith(f"{channel.config.cdn_base_url}/download?encrypted_query_param=enc-fallback")
@pytest.mark.asyncio
async def test_download_media_item_falls_back_to_encrypt_query_param(tmp_path) -> None:
    """Without full_url, the download URL is built from cdn_base_url and encrypt_query_param."""
    from unittest.mock import patch

    channel, _bus = _make_channel()
    channel._client = SimpleNamespace(
        get=AsyncMock(return_value=_DummyDownloadResponse(content=b"fallback-bytes"))
    )
    item = {"media": {"encrypt_query_param": "enc-fallback"}}

    # Patch get_media_dir only for the duration of the call; a bare module
    # assignment would leak this test's tmp_path into every later test.
    with patch.object(weixin_mod, "get_media_dir", lambda _name: tmp_path, create=True):
        saved_path = await channel._download_media_item(item, "image")

    assert saved_path is not None
    assert Path(saved_path).read_bytes() == b"fallback-bytes"
    called_url = channel._client.get.await_args_list[0].args[0]
    assert called_url.startswith(f"{channel.config.cdn_base_url}/download?encrypted_query_param=enc-fallback")
@pytest.mark.asyncio
async def test_download_media_item_does_not_retry_when_full_url_fails_without_fallback(tmp_path) -> None:
    """If full_url fails and no encrypt_query_param exists, one GET is made and None is returned."""
    from unittest.mock import patch

    channel, _bus = _make_channel()
    full_url = "https://cdn.example.test/download/full"
    channel._client = SimpleNamespace(
        get=AsyncMock(return_value=_DummyErrorDownloadResponse(full_url, 500))
    )
    item = {"media": {"full_url": full_url}}

    # Patch get_media_dir only for the duration of the call; a bare module
    # assignment would leak this test's tmp_path into every later test.
    with patch.object(weixin_mod, "get_media_dir", lambda _name: tmp_path, create=True):
        saved_path = await channel._download_media_item(item, "image")

    assert saved_path is None
    channel._client.get.assert_awaited_once_with(full_url)
@pytest.mark.asyncio
async def test_download_media_item_non_image_requires_aes_key_even_with_full_url(tmp_path) -> None:
    """A voice item with only full_url (no key material) returns None without any GET."""
    from unittest.mock import patch

    channel, _bus = _make_channel()
    full_url = "https://cdn.example.test/download/voice"
    channel._client = SimpleNamespace(
        get=AsyncMock(return_value=_DummyDownloadResponse(content=b"ciphertext-or-unknown"))
    )
    item = {
        "media": {
            "full_url": full_url,
        },
    }

    # Patch get_media_dir only for the duration of the call; a bare module
    # assignment would leak this test's tmp_path into every later test.
    with patch.object(weixin_mod, "get_media_dir", lambda _name: tmp_path, create=True):
        saved_path = await channel._download_media_item(item, "voice")

    assert saved_path is None
    channel._client.get.assert_not_awaited()

View File

@ -317,6 +317,75 @@ def test_openai_compat_provider_passes_model_through():
assert provider.get_default_model() == "github-copilot/gpt-5.3-codex"
def test_make_provider_uses_github_copilot_backend():
    """A config selecting the github-copilot provider yields a GitHubCopilotProvider instance."""
    from nanobot.cli.commands import _make_provider
    from nanobot.config.schema import Config

    agent_defaults = {
        "provider": "github-copilot",
        "model": "github-copilot/gpt-4.1",
    }
    config = Config.model_validate({"agents": {"defaults": agent_defaults}})

    # Keep the OpenAI client from being constructed for real.
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        provider = _make_provider(config)

    assert provider.__class__.__name__ == "GitHubCopilotProvider"
def test_github_copilot_provider_strips_prefixed_model_name():
    """_build_kwargs drops the "github-copilot/" prefix from the model name."""
    from nanobot.providers.github_copilot_provider import GitHubCopilotProvider

    prefixed_model = "github-copilot/gpt-5.1"
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        provider = GitHubCopilotProvider(default_model=prefixed_model)

    request_kwargs = provider._build_kwargs(
        messages=[{"role": "user", "content": "hi"}],
        tools=None,
        model=prefixed_model,
        max_tokens=16,
        temperature=0.1,
        reasoning_effort=None,
        tool_choice=None,
    )

    assert request_kwargs["model"] == "gpt-5.1"
@pytest.mark.asyncio
async def test_github_copilot_provider_refreshes_client_api_key_before_chat():
    """chat() fetches a Copilot access token once and stores it as the client's api_key."""
    from nanobot.providers.github_copilot_provider import GitHubCopilotProvider

    completion_payload = {
        "choices": [{"message": {"content": "ok"}, "finish_reason": "stop"}],
        "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
    }
    mock_client = MagicMock()
    mock_client.api_key = "no-key"
    mock_client.chat.completions.create = AsyncMock(return_value=completion_payload)

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI", return_value=mock_client):
        provider = GitHubCopilotProvider(default_model="github-copilot/gpt-5.1")

    provider._get_copilot_access_token = AsyncMock(return_value="copilot-access-token")

    response = await provider.chat(
        messages=[{"role": "user", "content": "hi"}],
        model="github-copilot/gpt-5.1",
        max_tokens=16,
        temperature=0.1,
    )

    assert response.content == "ok"
    # The placeholder key must have been swapped for the fetched token.
    assert provider._client.api_key == "copilot-access-token"
    provider._get_copilot_access_token.assert_awaited_once()
    mock_client.chat.completions.create.assert_awaited_once()
def test_openai_codex_strip_prefix_supports_hyphen_and_underscore():
assert _strip_model_prefix("openai-codex/gpt-5.1-codex") == "gpt-5.1-codex"
assert _strip_model_prefix("openai_codex/gpt-5.1-codex") == "gpt-5.1-codex"
@ -642,27 +711,105 @@ def test_heartbeat_retains_recent_messages_by_default():
assert config.gateway.heartbeat.keep_recent_messages == 8
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
def _write_instance_config(tmp_path: Path) -> Path:
    """Create ``<tmp_path>/instance/config.json`` containing ``{}`` and return its path."""
    instance_dir = tmp_path / "instance"
    # No exist_ok: creating the directory twice raises FileExistsError.
    instance_dir.mkdir(parents=True)
    target = instance_dir / "config.json"
    target.write_text("{}")
    return target
config = Config()
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
seen: dict[str, Path] = {}
def _stop_gateway_provider(_config) -> object:
    """Provider-factory stand-in that always raises ``_StopGatewayError("stop")``."""
    stop_signal = _StopGatewayError("stop")
    raise stop_signal
def _patch_cli_command_runtime(
    monkeypatch,
    config: Config,
    *,
    set_config_path=None,
    sync_templates=None,
    make_provider=None,
    message_bus=None,
    session_manager=None,
    cron_service=None,
    get_cron_dir=None,
) -> None:
    """Patch the runtime dependencies that the CLI command tests rely on.

    Four targets are always patched: ``set_config_path``, ``load_config``,
    ``sync_workspace_templates`` and ``_make_provider``. When the caller
    supplies no replacement they become harmless defaults, and
    ``load_config`` always returns ``config``.

    Args:
        monkeypatch: Object exposing ``setattr(target, value)`` (pytest's fixture).
        config: Value returned by the patched ``load_config``.
        set_config_path: Optional replacement for ``nanobot.config.loader.set_config_path``.
        sync_templates: Optional replacement for ``sync_workspace_templates``.
        make_provider: Optional replacement for ``_make_provider``.
        message_bus: Patched over ``MessageBus`` only when not ``None``.
        session_manager: Patched over ``SessionManager`` only when not ``None``.
        cron_service: Patched over ``CronService`` only when not ``None``.
        get_cron_dir: Patched over ``get_cron_dir`` only when not ``None``.
    """
    # Each call passes exactly (target, replacement); the old hard-coded
    # lambdas that referenced an out-of-scope ``seen`` dict are gone.
    monkeypatch.setattr(
        "nanobot.config.loader.set_config_path",
        set_config_path or (lambda _path: None),
    )
    monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
    monkeypatch.setattr(
        "nanobot.cli.commands.sync_workspace_templates",
        sync_templates or (lambda _path: None),
    )
    monkeypatch.setattr(
        "nanobot.cli.commands._make_provider",
        make_provider or (lambda _config: object()),
    )

    # Optional collaborators are only replaced when a test asks for it.
    optional_patches = (
        ("nanobot.bus.queue.MessageBus", message_bus),
        ("nanobot.session.manager.SessionManager", session_manager),
        ("nanobot.cron.service.CronService", cron_service),
        ("nanobot.config.paths.get_cron_dir", get_cron_dir),
    )
    for target, replacement in optional_patches:
        if replacement is not None:
            monkeypatch.setattr(target, replacement)
def _patch_serve_runtime(monkeypatch, config: Config, seen: dict[str, object]) -> None:
    """Install fakes for the ``serve`` command and record what it passes into ``seen``.

    The calling test is skipped when ``aiohttp`` cannot be imported.
    """
    pytest.importorskip("aiohttp")

    class _FakeApiApp:
        """Stand-in web app exposing empty startup/cleanup hook lists."""

        def __init__(self) -> None:
            self.on_startup: list[object] = []
            self.on_cleanup: list[object] = []

    class _FakeAgentLoop:
        """Records the ``workspace`` keyword argument; the MCP hooks are no-ops."""

        def __init__(self, **kwargs) -> None:
            seen["workspace"] = kwargs["workspace"]

        async def _connect_mcp(self) -> None:
            return None

        async def close_mcp(self) -> None:
            return None

    def _fake_create_app(agent_loop, model_name: str, request_timeout: float):
        seen.update(
            agent_loop=agent_loop,
            model_name=model_name,
            request_timeout=request_timeout,
        )
        return _FakeApiApp()

    def _fake_run_app(api_app, host: str, port: int, print):
        # Parameter names (including ``print``) are kept as the original declares them.
        seen.update(api_app=api_app, host=host, port=port)

    _patch_cli_command_runtime(
        monkeypatch,
        config,
        message_bus=lambda: object(),
        session_manager=lambda _workspace: object(),
    )
    replacements = (
        ("nanobot.agent.loop.AgentLoop", _FakeAgentLoop),
        ("nanobot.api.server.create_app", _fake_create_app),
        ("aiohttp.web.run_app", _fake_run_app),
    )
    for target, fake in replacements:
        monkeypatch.setattr(target, fake)
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
config_file = _write_instance_config(tmp_path)
config = Config()
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
seen: dict[str, Path] = {}
_patch_cli_command_runtime(
monkeypatch,
config,
set_config_path=lambda path: seen.__setitem__("config_path", path),
sync_templates=lambda path: seen.__setitem__("workspace", path),
make_provider=_stop_gateway_provider,
)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
@ -673,24 +820,17 @@ def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Pa
def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config_file = _write_instance_config(tmp_path)
config = Config()
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
override = tmp_path / "override-workspace"
seen: dict[str, Path] = {}
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr(
"nanobot.cli.commands.sync_workspace_templates",
lambda path: seen.__setitem__("workspace", path),
)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
_patch_cli_command_runtime(
monkeypatch,
config,
sync_templates=lambda path: seen.__setitem__("workspace", path),
make_provider=_stop_gateway_provider,
)
result = runner.invoke(
@ -704,27 +844,23 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path)
def test_gateway_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config_file = _write_instance_config(tmp_path)
config = Config()
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
seen: dict[str, Path] = {}
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
class _StopCron:
def __init__(self, store_path: Path) -> None:
seen["cron_store"] = store_path
raise _StopGatewayError("stop")
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
_patch_cli_command_runtime(
monkeypatch,
config,
message_bus=lambda: object(),
session_manager=lambda _workspace: object(),
cron_service=_StopCron,
)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
@ -735,10 +871,7 @@ def test_gateway_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path:
def test_gateway_workspace_override_does_not_migrate_legacy_cron(
monkeypatch, tmp_path: Path
) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config_file = _write_instance_config(tmp_path)
legacy_dir = tmp_path / "global" / "cron"
legacy_dir.mkdir(parents=True)
legacy_file = legacy_dir / "jobs.json"
@ -748,20 +881,19 @@ def test_gateway_workspace_override_does_not_migrate_legacy_cron(
config = Config()
seen: dict[str, Path] = {}
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
class _StopCron:
def __init__(self, store_path: Path) -> None:
seen["cron_store"] = store_path
raise _StopGatewayError("stop")
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
_patch_cli_command_runtime(
monkeypatch,
config,
message_bus=lambda: object(),
session_manager=lambda _workspace: object(),
cron_service=_StopCron,
get_cron_dir=lambda: legacy_dir,
)
result = runner.invoke(
app,
@ -777,10 +909,7 @@ def test_gateway_workspace_override_does_not_migrate_legacy_cron(
def test_gateway_custom_config_workspace_does_not_migrate_legacy_cron(
monkeypatch, tmp_path: Path
) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config_file = _write_instance_config(tmp_path)
legacy_dir = tmp_path / "global" / "cron"
legacy_dir.mkdir(parents=True)
legacy_file = legacy_dir / "jobs.json"
@ -791,20 +920,19 @@ def test_gateway_custom_config_workspace_does_not_migrate_legacy_cron(
config.agents.defaults.workspace = str(custom_workspace)
seen: dict[str, Path] = {}
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
class _StopCron:
def __init__(self, store_path: Path) -> None:
seen["cron_store"] = store_path
raise _StopGatewayError("stop")
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
_patch_cli_command_runtime(
monkeypatch,
config,
message_bus=lambda: object(),
session_manager=lambda _workspace: object(),
cron_service=_StopCron,
get_cron_dir=lambda: legacy_dir,
)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
@ -856,19 +984,14 @@ def test_migrate_cron_store_skips_when_workspace_file_exists(tmp_path: Path) ->
def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config_file = _write_instance_config(tmp_path)
config = Config()
config.gateway.port = 18791
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
_patch_cli_command_runtime(
monkeypatch,
config,
make_provider=_stop_gateway_provider,
)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
@ -878,19 +1001,14 @@ def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_
def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config_file = _write_instance_config(tmp_path)
config = Config()
config.gateway.port = 18791
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
_patch_cli_command_runtime(
monkeypatch,
config,
make_provider=_stop_gateway_provider,
)
result = runner.invoke(app, ["gateway", "--config", str(config_file), "--port", "18792"])
@ -899,6 +1017,63 @@ def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path)
assert "port 18792" in result.stdout
def test_serve_uses_api_config_defaults_and_workspace_override(
    monkeypatch, tmp_path: Path
) -> None:
    """Without CLI overrides ``serve`` uses ``config.api``; ``--workspace`` still wins."""
    config = Config()
    config.agents.defaults.workspace = str(tmp_path / "config-workspace")
    config.api.host = "127.0.0.2"
    config.api.port = 18900
    config.api.timeout = 45.0
    config_file = _write_instance_config(tmp_path)
    override_workspace = tmp_path / "override-workspace"

    seen: dict[str, object] = {}
    _patch_serve_runtime(monkeypatch, config, seen)

    argv = ["serve", "--config", str(config_file), "--workspace", str(override_workspace)]
    result = runner.invoke(app, argv)

    assert result.exit_code == 0
    assert seen["workspace"] == override_workspace
    assert seen["host"] == "127.0.0.2"
    assert seen["port"] == 18900
    assert seen["request_timeout"] == 45.0
def test_serve_cli_options_override_api_config(monkeypatch, tmp_path: Path) -> None:
    """``--host``, ``--port`` and ``--timeout`` take precedence over ``config.api``."""
    config = Config()
    config.api.host = "127.0.0.2"
    config.api.port = 18900
    config.api.timeout = 45.0
    config_file = _write_instance_config(tmp_path)

    seen: dict[str, object] = {}
    _patch_serve_runtime(monkeypatch, config, seen)

    argv = ["serve", "--config", str(config_file)]
    argv += ["--host", "127.0.0.1"]
    argv += ["--port", "18901"]
    argv += ["--timeout", "46"]
    result = runner.invoke(app, argv)

    assert result.exit_code == 0
    assert seen["host"] == "127.0.0.1"
    assert seen["port"] == 18901
    assert seen["request_timeout"] == 46.0
def test_channels_login_requires_channel_name() -> None:
result = runner.invoke(app, ["channels", "login"])

View File

@ -152,10 +152,12 @@ class TestRestartCommand:
])
await loop._run_agent_loop([])
assert loop._last_usage == {"prompt_tokens": 9, "completion_tokens": 4}
assert loop._last_usage["prompt_tokens"] == 9
assert loop._last_usage["completion_tokens"] == 4
await loop._run_agent_loop([])
assert loop._last_usage == {"prompt_tokens": 0, "completion_tokens": 0}
assert loop._last_usage["prompt_tokens"] == 0
assert loop._last_usage["completion_tokens"] == 0
@pytest.mark.asyncio
async def test_status_falls_back_to_last_usage_when_context_estimate_missing(self):

View File

@ -285,6 +285,28 @@ def test_add_at_job_uses_default_timezone_for_naive_datetime(tmp_path) -> None:
assert job.schedule.at_ms == expected
def test_add_job_delivers_by_default(tmp_path) -> None:
    """A job added without a ``deliver`` argument is stored with ``deliver is True``."""
    tool = _make_tool(tmp_path)
    tool.set_context("telegram", "chat-1")

    outcome = tool._add_job("Morning standup", 60, None, None, None)

    assert outcome.startswith("Created job")
    jobs = tool._cron.list_jobs()
    assert jobs[0].payload.deliver is True
def test_add_job_can_disable_delivery(tmp_path) -> None:
    """Passing ``deliver=False`` is reflected on the stored job payload."""
    tool = _make_tool(tmp_path)
    tool.set_context("telegram", "chat-1")

    outcome = tool._add_job("Background refresh", 60, None, None, None, deliver=False)

    assert outcome.startswith("Created job")
    jobs = tool._cron.list_jobs()
    assert jobs[0].payload.deliver is False
def test_list_excludes_disabled_jobs(tmp_path) -> None:
tool = _make_tool(tmp_path)
job = tool._cron.add_job(

View File

@ -0,0 +1,233 @@
"""Tests for cached token extraction from OpenAI-compatible providers."""
from __future__ import annotations
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
class FakeUsage:
    """Attribute bag imitating an SDK usage object (attributes rather than dict keys)."""

    def __init__(self, **kwargs):
        # Every keyword argument becomes an instance attribute of the same name.
        vars(self).update(kwargs)
class FakePromptDetails:
    """Stand-in for the ``prompt_tokens_details`` sub-object of a usage payload."""

    def __init__(self, cached_tokens=0) -> None:
        # The only attribute the tests read from this object.
        self.cached_tokens = cached_tokens
class _FakeSpec:
    """Provider spec stub with every feature switched off or unset."""

    # Boolean feature switches.
    supports_prompt_caching = False
    strip_model_prefix = False
    max_completion_tokens = False

    # Optional settings left unset.
    model_id_prefix = None
    reasoning_effort = None
def _provider():
    """Build an ``OpenAICompatProvider`` without running ``__init__``.

    The instance is given a ``MagicMock`` as ``client`` and a ``_FakeSpec`` as ``spec``.
    """
    from unittest.mock import MagicMock

    provider = OpenAICompatProvider.__new__(OpenAICompatProvider)
    provider.client = MagicMock()
    provider.spec = _FakeSpec()
    return provider
# Minimal valid choice so _parse reaches _extract_usage.
# Shared as the single "choices" entry by every dict-based response in this module.
_DICT_CHOICE = {"message": {"content": "Hello"}}
class _FakeMessage:
    """Object-style message stub: text content and no tool calls."""

    tool_calls = None
    content = "Hello"
class _FakeChoice:
    """Object-style choice stub wrapping a ``_FakeMessage`` with a ``stop`` finish reason."""

    finish_reason = "stop"
    message = _FakeMessage()
# --- dict-based response (raw JSON / mapping) ---
def test_extract_usage_openai_cached_tokens_dict():
    """Nested ``prompt_tokens_details.cached_tokens`` in a dict response is surfaced."""
    usage = {
        "prompt_tokens": 2000,
        "completion_tokens": 300,
        "total_tokens": 2300,
        "prompt_tokens_details": {"cached_tokens": 1200},
    }
    parsed = _provider()._parse({"choices": [_DICT_CHOICE], "usage": usage})

    assert parsed.usage["cached_tokens"] == 1200
    assert parsed.usage["prompt_tokens"] == 2000
def test_extract_usage_deepseek_cached_tokens_dict():
    """DeepSeek-style ``prompt_cache_hit_tokens`` in a dict response maps to ``cached_tokens``."""
    usage = {
        "prompt_tokens": 1500,
        "completion_tokens": 200,
        "total_tokens": 1700,
        "prompt_cache_hit_tokens": 1200,
        "prompt_cache_miss_tokens": 300,
    }
    parsed = _provider()._parse({"choices": [_DICT_CHOICE], "usage": usage})

    assert parsed.usage["cached_tokens"] == 1200
def test_extract_usage_no_cached_tokens_dict():
    """With no cache-related fields the parsed usage has no ``cached_tokens`` key."""
    usage = {
        "prompt_tokens": 1000,
        "completion_tokens": 200,
        "total_tokens": 1200,
    }
    parsed = _provider()._parse({"choices": [_DICT_CHOICE], "usage": usage})

    assert "cached_tokens" not in parsed.usage
def test_extract_usage_openai_cached_zero_dict():
    """A nested ``cached_tokens`` of 0 is omitted from the parsed usage."""
    usage = {
        "prompt_tokens": 2000,
        "completion_tokens": 300,
        "total_tokens": 2300,
        "prompt_tokens_details": {"cached_tokens": 0},
    }
    parsed = _provider()._parse({"choices": [_DICT_CHOICE], "usage": usage})

    assert "cached_tokens" not in parsed.usage
# --- object-based response (OpenAI SDK Pydantic model) ---
def test_extract_usage_openai_cached_tokens_obj():
    """Nested ``prompt_tokens_details.cached_tokens`` on an attribute-style response is surfaced."""
    details = FakePromptDetails(cached_tokens=1200)
    usage = FakeUsage(
        prompt_tokens=2000,
        completion_tokens=300,
        total_tokens=2300,
        prompt_tokens_details=details,
    )
    sdk_response = FakeUsage(choices=[_FakeChoice()], usage=usage)

    parsed = _provider()._parse(sdk_response)

    assert parsed.usage["cached_tokens"] == 1200
def test_extract_usage_deepseek_cached_tokens_obj():
    """``prompt_cache_hit_tokens`` as an object attribute maps to ``cached_tokens``."""
    usage = FakeUsage(
        prompt_tokens=1500,
        completion_tokens=200,
        total_tokens=1700,
        prompt_cache_hit_tokens=1200,
    )
    sdk_response = FakeUsage(choices=[_FakeChoice()], usage=usage)

    parsed = _provider()._parse(sdk_response)

    assert parsed.usage["cached_tokens"] == 1200
def test_extract_usage_stepfun_top_level_cached_tokens_dict():
    """StepFun/Moonshot style: a top-level ``usage.cached_tokens`` dict key is picked up."""
    usage = {
        "prompt_tokens": 591,
        "completion_tokens": 120,
        "total_tokens": 711,
        "cached_tokens": 512,
    }
    parsed = _provider()._parse({"choices": [_DICT_CHOICE], "usage": usage})

    assert parsed.usage["cached_tokens"] == 512
def test_extract_usage_stepfun_top_level_cached_tokens_obj():
    """StepFun/Moonshot style: a top-level ``cached_tokens`` attribute is picked up."""
    usage = FakeUsage(
        prompt_tokens=591,
        completion_tokens=120,
        total_tokens=711,
        cached_tokens=512,
    )
    sdk_response = FakeUsage(choices=[_FakeChoice()], usage=usage)

    parsed = _provider()._parse(sdk_response)

    assert parsed.usage["cached_tokens"] == 512
def test_extract_usage_priority_nested_over_top_level_dict():
    """If nested and top-level ``cached_tokens`` are both present, the nested value is reported."""
    usage = {
        "prompt_tokens": 2000,
        "completion_tokens": 300,
        "total_tokens": 2300,
        "prompt_tokens_details": {"cached_tokens": 100},
        "cached_tokens": 500,
    }
    parsed = _provider()._parse({"choices": [_DICT_CHOICE], "usage": usage})

    assert parsed.usage["cached_tokens"] == 100
def test_anthropic_maps_cache_fields_to_cached_tokens():
    """Anthropic ``cache_read_input_tokens`` is reported as ``cached_tokens``.

    Also pins the prompt/total token figures and that the cache-creation count is kept.
    """
    from nanobot.providers.anthropic_provider import AnthropicProvider

    usage = FakeUsage(
        input_tokens=800,
        output_tokens=200,
        cache_creation_input_tokens=300,
        cache_read_input_tokens=1200,
    )
    text_block = FakeUsage(type="text", text="hello")
    message = FakeUsage(
        id="msg_1",
        type="message",
        stop_reason="end_turn",
        content=[text_block],
        usage=usage,
    )

    parsed = AnthropicProvider._parse_response(message)

    expected = {
        "cached_tokens": 1200,
        "prompt_tokens": 2300,
        "total_tokens": 2500,
        "cache_creation_input_tokens": 300,
    }
    for key, value in expected.items():
        assert parsed.usage[key] == value
def test_anthropic_no_cache_fields():
    """An Anthropic response without cache counters yields no ``cached_tokens`` key."""
    from nanobot.providers.anthropic_provider import AnthropicProvider

    text_block = FakeUsage(type="text", text="hello")
    message = FakeUsage(
        id="msg_1",
        type="message",
        stop_reason="end_turn",
        content=[text_block],
        usage=FakeUsage(input_tokens=800, output_tokens=200),
    )

    parsed = AnthropicProvider._parse_response(message)

    assert "cached_tokens" not in parsed.usage

View File

@ -11,6 +11,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None:
monkeypatch.delitem(sys.modules, "nanobot.providers.anthropic_provider", raising=False)
monkeypatch.delitem(sys.modules, "nanobot.providers.openai_compat_provider", raising=False)
monkeypatch.delitem(sys.modules, "nanobot.providers.openai_codex_provider", raising=False)
monkeypatch.delitem(sys.modules, "nanobot.providers.github_copilot_provider", raising=False)
monkeypatch.delitem(sys.modules, "nanobot.providers.azure_openai_provider", raising=False)
providers = importlib.import_module("nanobot.providers")
@ -18,6 +19,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None:
assert "nanobot.providers.anthropic_provider" not in sys.modules
assert "nanobot.providers.openai_compat_provider" not in sys.modules
assert "nanobot.providers.openai_codex_provider" not in sys.modules
assert "nanobot.providers.github_copilot_provider" not in sys.modules
assert "nanobot.providers.azure_openai_provider" not in sys.modules
assert providers.__all__ == [
"LLMProvider",
@ -25,6 +27,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None:
"AnthropicProvider",
"OpenAICompatProvider",
"OpenAICodexProvider",
"GitHubCopilotProvider",
"AzureOpenAIProvider",
]

View File

@ -0,0 +1,59 @@
"""Tests for build_status_content cache hit rate display."""
from nanobot.utils.helpers import build_status_content
def test_status_shows_cache_hit_rate():
    """1200 cached out of 2000 prompt tokens is rendered as ``60% cached``."""
    usage = {"prompt_tokens": 2000, "completion_tokens": 300, "cached_tokens": 1200}
    rendered = build_status_content(
        version="0.1.0",
        model="glm-4-plus",
        start_time=1000000.0,
        last_usage=usage,
        context_window_tokens=128000,
        session_msg_count=10,
        context_tokens_estimate=5000,
    )

    for fragment in ("60% cached", "2000 in / 300 out"):
        assert fragment in rendered
def test_status_no_cache_info():
    """When usage has no ``cached_tokens`` the status text never mentions caching."""
    usage = {"prompt_tokens": 2000, "completion_tokens": 300}
    rendered = build_status_content(
        version="0.1.0",
        model="glm-4-plus",
        start_time=1000000.0,
        last_usage=usage,
        context_window_tokens=128000,
        session_msg_count=10,
        context_tokens_estimate=5000,
    )

    assert "cached" not in rendered.lower()
    assert "2000 in / 300 out" in rendered
def test_status_zero_cached_tokens():
    """A ``cached_tokens`` value of 0 produces no cache percentage in the status text."""
    usage = {"prompt_tokens": 2000, "completion_tokens": 300, "cached_tokens": 0}
    rendered = build_status_content(
        version="0.1.0",
        model="glm-4-plus",
        start_time=1000000.0,
        last_usage=usage,
        context_window_tokens=128000,
        session_msg_count=10,
        context_tokens_estimate=5000,
    )

    assert "cached" not in rendered.lower()
def test_status_100_percent_cached():
    """When every prompt token is cached the status text shows ``100% cached``."""
    usage = {"prompt_tokens": 1000, "completion_tokens": 100, "cached_tokens": 1000}
    rendered = build_status_content(
        version="0.1.0",
        model="glm-4-plus",
        start_time=1000000.0,
        last_usage=usage,
        context_window_tokens=128000,
        session_msg_count=5,
        context_tokens_estimate=3000,
    )

    assert "100% cached" in rendered

View File

@ -0,0 +1,168 @@
"""Tests for the Nanobot programmatic facade."""
from __future__ import annotations
import json
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from nanobot.nanobot import Nanobot, RunResult
def _write_config(tmp_path: Path, overrides: dict | None = None) -> Path:
    """Write a minimal JSON config to ``<tmp_path>/config.json`` and return its path.

    ``overrides`` is merged shallowly: a top-level key replaces the default
    value for that key wholesale.
    """
    payload = {
        "providers": {"openrouter": {"apiKey": "sk-test-key"}},
        "agents": {"defaults": {"model": "openai/gpt-4.1"}},
    }
    payload.update(overrides or {})

    target = tmp_path / "config.json"
    target.write_text(json.dumps(payload))
    return target
def test_from_config_missing_file():
    """``from_config`` raises ``FileNotFoundError`` for a path that does not exist."""
    missing_path = "/nonexistent/config.json"
    with pytest.raises(FileNotFoundError):
        Nanobot.from_config(missing_path)
def test_from_config_creates_instance(tmp_path):
    """``from_config`` builds an agent loop bound to the requested workspace."""
    bot = Nanobot.from_config(_write_config(tmp_path), workspace=tmp_path)

    loop = bot._loop
    assert loop is not None
    assert loop.workspace == tmp_path
def test_from_config_default_path():
    """With no arguments ``from_config`` calls ``load_config(None)`` exactly once."""
    from nanobot.config.schema import Config

    load_patch = patch("nanobot.config.loader.load_config")
    provider_patch = patch("nanobot.nanobot._make_provider")
    with load_patch as mock_load, provider_patch as mock_prov:
        mock_load.return_value = Config()

        fake_provider = MagicMock()
        mock_prov.return_value = fake_provider
        fake_provider.get_default_model.return_value = "test"
        fake_provider.generation.max_tokens = 4096

        Nanobot.from_config()
        mock_load.assert_called_once_with(None)
@pytest.mark.asyncio
async def test_run_returns_result(tmp_path):
    """``run`` wraps the loop's reply in a ``RunResult`` and uses the ``sdk:default`` session."""
    from nanobot.bus.events import OutboundMessage

    bot = Nanobot.from_config(_write_config(tmp_path), workspace=tmp_path)
    reply = OutboundMessage(channel="cli", chat_id="direct", content="Hello back!")
    bot._loop.process_direct = AsyncMock(return_value=reply)

    result = await bot.run("hi")

    assert isinstance(result, RunResult)
    assert result.content == "Hello back!"
    bot._loop.process_direct.assert_awaited_once_with("hi", session_key="sdk:default")
@pytest.mark.asyncio
async def test_run_with_hooks(tmp_path):
    """Hooks passed to ``run`` are not left on the loop: ``_extra_hooks`` is ``[]`` afterwards."""
    from nanobot.agent.hook import AgentHook, AgentHookContext
    from nanobot.bus.events import OutboundMessage

    class _NoopHook(AgentHook):
        async def before_iteration(self, context: AgentHookContext) -> None:
            pass

    bot = Nanobot.from_config(_write_config(tmp_path), workspace=tmp_path)
    reply = OutboundMessage(channel="cli", chat_id="direct", content="done")
    bot._loop.process_direct = AsyncMock(return_value=reply)

    result = await bot.run("hi", hooks=[_NoopHook()])

    assert result.content == "done"
    assert bot._loop._extra_hooks == []
@pytest.mark.asyncio
async def test_run_hooks_restored_on_error(tmp_path):
    """If the loop raises, ``_extra_hooks`` is the same object it was before the call."""
    from nanobot.agent.hook import AgentHook

    bot = Nanobot.from_config(_write_config(tmp_path), workspace=tmp_path)
    bot._loop.process_direct = AsyncMock(side_effect=RuntimeError("boom"))
    hooks_before = bot._loop._extra_hooks

    with pytest.raises(RuntimeError):
        await bot.run("hi", hooks=[AgentHook()])

    # Identity, not equality: the original list object must be put back.
    assert bot._loop._extra_hooks is hooks_before
@pytest.mark.asyncio
async def test_run_none_response(tmp_path):
    """A ``None`` reply from the loop becomes a result with empty content."""
    bot = Nanobot.from_config(_write_config(tmp_path), workspace=tmp_path)
    bot._loop.process_direct = AsyncMock(return_value=None)

    outcome = await bot.run("hi")

    assert outcome.content == ""
def test_workspace_override(tmp_path):
    """An explicit ``workspace`` argument is the workspace the loop ends up with."""
    config_path = _write_config(tmp_path)
    workspace_dir = tmp_path / "custom_workspace"
    workspace_dir.mkdir()

    bot = Nanobot.from_config(config_path, workspace=workspace_dir)

    assert bot._loop.workspace == workspace_dir
def test_sdk_make_provider_uses_github_copilot_backend():
    """Selecting the ``github-copilot`` provider makes the SDK build a ``GitHubCopilotProvider``."""
    from nanobot.config.schema import Config
    from nanobot.nanobot import _make_provider

    defaults = {
        "provider": "github-copilot",
        "model": "github-copilot/gpt-4.1",
    }
    config = Config.model_validate({"agents": {"defaults": defaults}})

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        provider = _make_provider(config)

    assert provider.__class__.__name__ == "GitHubCopilotProvider"
@pytest.mark.asyncio
async def test_run_custom_session_key(tmp_path):
    """A caller-supplied ``session_key`` is forwarded unchanged to ``process_direct``."""
    from nanobot.bus.events import OutboundMessage

    bot = Nanobot.from_config(_write_config(tmp_path), workspace=tmp_path)
    reply = OutboundMessage(channel="cli", chat_id="direct", content="ok")
    bot._loop.process_direct = AsyncMock(return_value=reply)

    await bot.run("hi", session_key="user-alice")

    bot._loop.process_direct.assert_awaited_once_with("hi", session_key="user-alice")
def test_import_from_top_level():
    """The package root exports the same ``Nanobot`` and ``RunResult`` objects as ``nanobot.nanobot``."""
    from nanobot import Nanobot as exported_bot, RunResult as exported_result

    assert exported_bot is Nanobot
    assert exported_result is RunResult

371
tests/test_openai_api.py Normal file
View File

@ -0,0 +1,371 @@
"""Focused tests for the fixed-session OpenAI-compatible API."""
from __future__ import annotations
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
import pytest_asyncio
from nanobot.api.server import (
API_CHAT_ID,
API_SESSION_KEY,
_chat_completion_response,
_error_json,
create_app,
handle_chat_completions,
)
try:
from aiohttp.test_utils import TestClient, TestServer
HAS_AIOHTTP = True
except ImportError:
HAS_AIOHTTP = False
pytest_plugins = ("pytest_asyncio",)
def _make_mock_agent(response_text: str = "mock response") -> MagicMock:
    """Return a ``MagicMock`` agent whose async ``process_direct`` resolves to ``response_text``.

    ``_connect_mcp`` and ``close_mcp`` are replaced with bare ``AsyncMock`` instances.
    """
    agent = MagicMock()
    agent.process_direct = AsyncMock(return_value=response_text)
    for lifecycle_hook in ("_connect_mcp", "close_mcp"):
        setattr(agent, lifecycle_hook, AsyncMock())
    return agent
@pytest.fixture
def mock_agent():
    """Provide a mock agent built with the default canned reply."""
    agent = _make_mock_agent()
    return agent
@pytest.fixture
def app(mock_agent):
    """Provide an API app around ``mock_agent`` with model ``test-model`` and a 10 second timeout."""
    app_settings = {"model_name": "test-model", "request_timeout": 10.0}
    return create_app(mock_agent, **app_settings)
@pytest_asyncio.fixture
async def aiohttp_client():
    """Yield a factory that starts a ``TestClient`` for an app.

    Every client the factory created is closed when the fixture is torn down.
    """
    started: list[TestClient] = []

    async def _make_client(app):
        test_client = TestClient(TestServer(app))
        await test_client.start_server()
        started.append(test_client)
        return test_client

    try:
        yield _make_client
    finally:
        for test_client in started:
            await test_client.close()
def test_error_json() -> None:
    """``_error_json`` sets the HTTP status and mirrors message and code in the body."""
    response = _error_json(400, "bad request")

    assert response.status == 400
    error = json.loads(response.body)["error"]
    assert error["message"] == "bad request"
    assert error["code"] == 400
def test_chat_completion_response() -> None:
    """The completion payload carries the expected object type, model, content, finish reason and id prefix."""
    payload = _chat_completion_response("hello world", "test-model")

    assert payload["object"] == "chat.completion"
    assert payload["model"] == "test-model"
    first_choice = payload["choices"][0]
    assert first_choice["message"]["content"] == "hello world"
    assert first_choice["finish_reason"] == "stop"
    assert payload["id"].startswith("chatcmpl-")
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_missing_messages_returns_400(aiohttp_client, app) -> None:
    """A request body without ``messages`` is rejected with HTTP 400."""
    client = await aiohttp_client(app)
    payload = {"model": "test"}

    resp = await client.post("/v1/chat/completions", json=payload)

    assert resp.status == 400
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_no_user_message_returns_400(aiohttp_client, app) -> None:
    """A request containing only a system message is rejected with HTTP 400."""
    client = await aiohttp_client(app)
    payload = {"messages": [{"role": "system", "content": "you are a bot"}]}

    resp = await client.post("/v1/chat/completions", json=payload)

    assert resp.status == 400
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_stream_true_returns_400(aiohttp_client, app) -> None:
    """``stream: true`` is rejected with HTTP 400 and an error message mentioning "stream"."""
    client = await aiohttp_client(app)
    payload = {"messages": [{"role": "user", "content": "hello"}], "stream": True}

    resp = await client.post("/v1/chat/completions", json=payload)

    assert resp.status == 400
    body = await resp.json()
    error_message = body["error"]["message"]
    assert "stream" in error_message.lower()
@pytest.mark.asyncio
async def test_model_mismatch_returns_400() -> None:
    """Requesting a model other than the configured one yields HTTP 400 naming the configured model."""
    payload = {
        "model": "other-model",
        "messages": [{"role": "user", "content": "hello"}],
    }
    app_state = {
        "agent_loop": _make_mock_agent(),
        "model_name": "test-model",
        "request_timeout": 10.0,
        "session_lock": asyncio.Lock(),
    }
    request = MagicMock()
    request.json = AsyncMock(return_value=payload)
    request.app = app_state

    resp = await handle_chat_completions(request)

    assert resp.status == 400
    error_message = json.loads(resp.body)["error"]["message"]
    assert "test-model" in error_message
@pytest.mark.asyncio
async def test_single_user_message_required() -> None:
    """A multi-message conversation is rejected with a "single user message" error."""
    conversation = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "previous reply"},
    ]
    app_state = {
        "agent_loop": _make_mock_agent(),
        "model_name": "test-model",
        "request_timeout": 10.0,
        "session_lock": asyncio.Lock(),
    }
    request = MagicMock()
    request.json = AsyncMock(return_value={"messages": conversation})
    request.app = app_state

    resp = await handle_chat_completions(request)

    assert resp.status == 400
    error_message = json.loads(resp.body)["error"]["message"]
    assert "single user message" in error_message.lower()
@pytest.mark.asyncio
async def test_single_user_message_must_have_user_role() -> None:
    """A lone message whose role is not ``user`` is rejected with the same error."""
    payload = {"messages": [{"role": "system", "content": "you are a bot"}]}
    app_state = {
        "agent_loop": _make_mock_agent(),
        "model_name": "test-model",
        "request_timeout": 10.0,
        "session_lock": asyncio.Lock(),
    }
    request = MagicMock()
    request.json = AsyncMock(return_value=payload)
    request.app = app_state

    resp = await handle_chat_completions(request)

    assert resp.status == 400
    error_message = json.loads(resp.body)["error"]["message"]
    assert "single user message" in error_message.lower()
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_successful_request_uses_fixed_api_session(aiohttp_client, mock_agent) -> None:
    """A valid request returns the agent's reply and calls it with the fixed API session and chat id."""
    client = await aiohttp_client(create_app(mock_agent, model_name="test-model"))
    payload = {"messages": [{"role": "user", "content": "hello"}]}

    resp = await client.post("/v1/chat/completions", json=payload)

    assert resp.status == 200
    body = await resp.json()
    assert body["choices"][0]["message"]["content"] == "mock response"
    assert body["model"] == "test-model"

    expected_call = {
        "content": "hello",
        "session_key": API_SESSION_KEY,
        "channel": "api",
        "chat_id": API_CHAT_ID,
    }
    mock_agent.process_direct.assert_called_once_with(**expected_call)
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_followup_requests_share_same_session_key(aiohttp_client) -> None:
    """Two consecutive requests must both be processed under API_SESSION_KEY."""
    seen_keys: list[str] = []

    async def record_session(content, session_key="", channel="", chat_id=""):
        # Remember which session each request was dispatched under.
        seen_keys.append(session_key)
        return f"reply to {content}"

    agent = MagicMock()
    agent.process_direct = record_session
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()

    client = await aiohttp_client(create_app(agent, model_name="m"))

    # Send the requests one after the other, in a fixed order.
    responses = []
    for text in ("first", "second"):
        reply = await client.post(
            "/v1/chat/completions",
            json={"messages": [{"role": "user", "content": text}]},
        )
        responses.append(reply)

    first, second = responses
    assert first.status == 200
    assert second.status == 200
    assert seen_keys == [API_SESSION_KEY] * 2
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_fixed_session_requests_are_serialized(aiohttp_client) -> None:
    """Concurrent requests must not overlap inside ``process_direct``.

    Two requests are fired together; whichever starts first has to record its
    "end" event before the other one records its "start" event.
    """
    events: list[str] = []

    async def slow_process(content, session_key="", channel="", chat_id=""):
        events.append(f"start:{content}")
        # Hold the "session" long enough for the other request to arrive.
        await asyncio.sleep(0.1)
        events.append(f"end:{content}")
        return content

    agent = MagicMock()
    agent.process_direct = slow_process
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()

    client = await aiohttp_client(create_app(agent, model_name="m"))

    async def send(msg: str):
        body = {"messages": [{"role": "user", "content": msg}]}
        return await client.post("/v1/chat/completions", json=body)

    first_resp, second_resp = await asyncio.gather(send("first"), send("second"))
    assert first_resp.status == 200
    assert second_resp.status == 200

    # Verify serialization: the request that started first must fully finish
    # before the other one starts.
    if events[0] == "start:first":
        leader, follower = "first", "second"
    else:
        leader, follower = "second", "first"
    assert events.index(f"end:{leader}") < events.index(f"start:{follower}")
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_models_endpoint(aiohttp_client, app) -> None:
    """GET /v1/models returns a "list" object whose first entry is "test-model"."""
    client = await aiohttp_client(app)

    response = await client.get("/v1/models")

    assert response.status == 200
    listing = await response.json()
    assert listing["object"] == "list"
    first_model = listing["data"][0]
    assert first_model["id"] == "test-model"
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_health_endpoint(aiohttp_client, app) -> None:
    """GET /health answers 200 with a JSON body whose status is "ok"."""
    client = await aiohttp_client(app)

    response = await client.get("/health")

    assert response.status == 200
    health = await response.json()
    assert health["status"] == "ok"
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_multimodal_content_extracts_text(aiohttp_client, mock_agent) -> None:
    """List-style content is reduced to its text part before reaching the agent.

    The message carries one text part and one image_url part; the agent must
    be called once with only the text ("describe this") as content.
    """
    content_parts = [
        {"type": "text", "text": "describe this"},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
    ]
    payload = {"messages": [{"role": "user", "content": content_parts}]}

    client = await aiohttp_client(create_app(mock_agent, model_name="m"))
    response = await client.post("/v1/chat/completions", json=payload)

    assert response.status == 200
    expected_kwargs = {
        "content": "describe this",
        "session_key": API_SESSION_KEY,
        "channel": "api",
        "chat_id": API_CHAT_ID,
    }
    mock_agent.process_direct.assert_called_once_with(**expected_kwargs)
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_empty_response_retry_then_success(aiohttp_client) -> None:
    """An empty first reply triggers one retry whose result is returned.

    The fake agent answers "" on the first call and "recovered response"
    afterwards; exactly two calls are expected.
    """
    attempts: list[str] = []

    async def sometimes_empty(content, session_key="", channel="", chat_id=""):
        attempts.append(content)
        # Only the very first invocation comes back empty.
        return "" if len(attempts) == 1 else "recovered response"

    agent = MagicMock()
    agent.process_direct = sometimes_empty
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()

    client = await aiohttp_client(create_app(agent, model_name="m"))
    payload = {"messages": [{"role": "user", "content": "hello"}]}
    response = await client.post("/v1/chat/completions", json=payload)

    assert response.status == 200
    data = await response.json()
    assert data["choices"][0]["message"]["content"] == "recovered response"
    assert len(attempts) == 2
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_empty_response_falls_back(aiohttp_client) -> None:
    """When every reply is empty, the fallback text is returned after two calls."""
    attempts: list[str] = []

    async def always_empty(content, session_key="", channel="", chat_id=""):
        attempts.append(content)
        return ""

    agent = MagicMock()
    agent.process_direct = always_empty
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()

    client = await aiohttp_client(create_app(agent, model_name="m"))
    payload = {"messages": [{"role": "user", "content": "hello"}]}
    response = await client.post("/v1/chat/completions", json=payload)

    assert response.status == 200
    data = await response.json()
    reply_text = data["choices"][0]["message"]["content"]
    assert reply_text == "I've completed processing but have no response to give."
    assert len(attempts) == 2

View File

@ -95,6 +95,14 @@ def test_exec_extract_absolute_paths_keeps_full_windows_path() -> None:
assert paths == [r"C:\user\workspace\txt"]
def test_exec_extract_absolute_paths_captures_windows_drive_root_path() -> None:
    r"""A bare Windows drive root such as ``E:\`` must be extracted for workspace guarding."""
    # A raw string cannot end in a single backslash, hence the escaped form.
    extracted = ExecTool._extract_absolute_paths("dir E:\\")
    expected = ["E:\\"]
    assert extracted == expected
def test_exec_extract_absolute_paths_ignores_relative_posix_segments() -> None:
cmd = ".venv/bin/python script.py"
paths = ExecTool._extract_absolute_paths(cmd)
@ -134,6 +142,45 @@ def test_exec_guard_blocks_quoted_home_path_outside_workspace(tmp_path) -> None:
assert error == "Error: Command blocked by safety guard (path outside working dir)"
def test_exec_guard_blocks_windows_drive_root_outside_workspace(monkeypatch) -> None:
    r"""The guard must block ``dir E:\`` when the working dir is ``E:\workspace``.

    ``Path`` in the shell module is replaced with a small string-based
    stand-in that uses backslash separators and drive letters.
    NOTE(review): presumably this lets the test run on non-Windows hosts;
    confirm against the guard implementation.
    """
    import nanobot.agent.tools.shell as shell_mod

    class FakeWindowsPath:
        """String-backed stand-in for ``Path`` with Windows-style semantics.

        Exposes ``resolve``, ``expanduser``, ``is_absolute``, ``parents`` and
        case-insensitive equality/hashing.
        """

        def __init__(self, raw: str) -> None:
            # Collapse any run of trailing backslashes to exactly one, and
            # add none when the input did not end with a backslash.
            self.raw = raw.rstrip("\\") + ("\\" if raw.endswith("\\") else "")

        def __repr__(self) -> str:
            return f"FakeWindowsPath({self.raw!r})"

        def resolve(self) -> "FakeWindowsPath":
            """Return self unchanged; the fake performs no resolution."""
            return self

        def expanduser(self) -> "FakeWindowsPath":
            """Return self unchanged; the fake performs no ``~`` expansion."""
            return self

        def is_absolute(self) -> bool:
            """Return True when the path starts with ``<char>:\\``."""
            return len(self.raw) >= 3 and self.raw[1:3] == ":\\"

        @property
        def parents(self) -> list["FakeWindowsPath"]:
            """Return ancestors from nearest to the drive root; empty for a root."""
            if not self.is_absolute():
                return []
            trimmed = self.raw.rstrip("\\")
            if len(trimmed) <= 2:
                # Only "X:" is left: this is the drive root itself.
                return []
            idx = trimmed.rfind("\\")
            if idx <= 2:
                # Direct child of the drive root; its only parent is "X:\\".
                return [FakeWindowsPath(trimmed[:2] + "\\")]
            parent = FakeWindowsPath(trimmed[:idx])
            return [parent, *parent.parents]

        def __eq__(self, other: object) -> bool:
            return isinstance(other, FakeWindowsPath) and self.raw.lower() == other.raw.lower()

        def __hash__(self) -> int:
            # Defining __eq__ alone would set __hash__ to None; keep the fake
            # hashable and consistent with the case-insensitive equality.
            return hash(self.raw.lower())

    monkeypatch.setattr(shell_mod, "Path", FakeWindowsPath)

    tool = ExecTool(restrict_to_workspace=True)
    error = tool._guard_command("dir E:\\", "E:\\workspace")
    assert error == "Error: Command blocked by safety guard (path outside working dir)"
# --- cast_params tests ---