mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-04 10:22:33 +00:00
merge: resolve conflicts with upstream/main, preserve typing indicator
This commit is contained in:
commit
ca68a89ce6
107
README.md
107
README.md
@ -115,6 +115,8 @@
|
||||
- [Configuration](#️-configuration)
|
||||
- [Multiple Instances](#-multiple-instances)
|
||||
- [CLI Reference](#-cli-reference)
|
||||
- [Python SDK](#-python-sdk)
|
||||
- [OpenAI-Compatible API](#-openai-compatible-api)
|
||||
- [Docker](#-docker)
|
||||
- [Linux Service](#-linux-service)
|
||||
- [Project Structure](#-project-structure)
|
||||
@ -1541,6 +1543,7 @@ nanobot gateway --config ~/.nanobot-telegram/config.json --workspace /tmp/nanobo
|
||||
| `nanobot agent` | Interactive chat mode |
|
||||
| `nanobot agent --no-markdown` | Show plain-text replies |
|
||||
| `nanobot agent --logs` | Show runtime logs during chat |
|
||||
| `nanobot serve` | Start the OpenAI-compatible API |
|
||||
| `nanobot gateway` | Start the gateway |
|
||||
| `nanobot status` | Show status |
|
||||
| `nanobot provider login openai-codex` | OAuth login for providers |
|
||||
@ -1569,6 +1572,110 @@ The agent can also manage this file itself — ask it to "add a periodic task" a
|
||||
|
||||
</details>
|
||||
|
||||
## 🐍 Python SDK
|
||||
|
||||
Use nanobot as a library — no CLI, no gateway, just Python:
|
||||
|
||||
```python
|
||||
from nanobot import Nanobot
|
||||
|
||||
bot = Nanobot.from_config()
|
||||
result = await bot.run("Summarize the README")
|
||||
print(result.content)
|
||||
```
|
||||
|
||||
Each call carries a `session_key` for conversation isolation — different keys get independent history:
|
||||
|
||||
```python
|
||||
await bot.run("hi", session_key="user-alice")
|
||||
await bot.run("hi", session_key="task-42")
|
||||
```
|
||||
|
||||
Add lifecycle hooks to observe or customize the agent:
|
||||
|
||||
```python
|
||||
from nanobot.agent import AgentHook, AgentHookContext
|
||||
|
||||
class AuditHook(AgentHook):
|
||||
async def before_execute_tools(self, ctx: AgentHookContext) -> None:
|
||||
for tc in ctx.tool_calls:
|
||||
print(f"[tool] {tc.name}")
|
||||
|
||||
result = await bot.run("Hello", hooks=[AuditHook()])
|
||||
```
|
||||
|
||||
See [docs/PYTHON_SDK.md](docs/PYTHON_SDK.md) for the full SDK reference.
|
||||
|
||||
## 🔌 OpenAI-Compatible API
|
||||
|
||||
nanobot can expose a minimal OpenAI-compatible endpoint for local integrations:
|
||||
|
||||
```bash
|
||||
pip install "nanobot-ai[api]"
|
||||
nanobot serve
|
||||
```
|
||||
|
||||
By default, the API binds to `127.0.0.1:8900`. You can change this in `config.json`.
|
||||
|
||||
### Behavior
|
||||
|
||||
- Session isolation: pass `"session_id"` in the request body to isolate conversations; omit for a shared default session (`api:default`)
|
||||
- Single-message input: each request must contain exactly one `user` message
|
||||
- Fixed model: omit `model`, or pass the same model shown by `/v1/models`
|
||||
- No streaming: `stream=true` is not supported
|
||||
|
||||
### Endpoints
|
||||
|
||||
- `GET /health`
|
||||
- `GET /v1/models`
|
||||
- `POST /v1/chat/completions`
|
||||
|
||||
### curl
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:8900/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"session_id": "my-session"
|
||||
}'
|
||||
```
|
||||
|
||||
### Python (`requests`)
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
resp = requests.post(
|
||||
"http://127.0.0.1:8900/v1/chat/completions",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"session_id": "my-session", # optional: isolate conversation
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
print(resp.json()["choices"][0]["message"]["content"])
|
||||
```
|
||||
|
||||
### Python (`openai`)
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://127.0.0.1:8900/v1",
|
||||
api_key="dummy",
|
||||
)
|
||||
|
||||
resp = client.chat.completions.create(
|
||||
model="MiniMax-M2.7",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
extra_body={"session_id": "my-session"}, # optional: isolate conversation
|
||||
)
|
||||
print(resp.choices[0].message.content)
|
||||
```
|
||||
|
||||
## 🐳 Docker
|
||||
|
||||
> [!TIP]
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#!/bin/bash
|
||||
# Count core agent lines (excluding channels/, cli/, providers/ adapters)
|
||||
# Count core agent lines (excluding channels/, cli/, api/, providers/ adapters,
|
||||
# and the high-level Python SDK facade)
|
||||
cd "$(dirname "$0")" || exit 1
|
||||
|
||||
echo "nanobot core agent line count"
|
||||
@ -15,7 +16,7 @@ root=$(cat nanobot/__init__.py nanobot/__main__.py | wc -l)
|
||||
printf " %-16s %5s lines\n" "(root)" "$root"
|
||||
|
||||
echo ""
|
||||
total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/command/*" ! -path "*/providers/*" ! -path "*/skills/*" | xargs cat | wc -l)
|
||||
total=$(find nanobot -name "*.py" ! -path "*/channels/*" ! -path "*/cli/*" ! -path "*/api/*" ! -path "*/command/*" ! -path "*/providers/*" ! -path "*/skills/*" ! -path "nanobot/nanobot.py" | xargs cat | wc -l)
|
||||
echo " Core total: $total lines"
|
||||
echo ""
|
||||
echo " (excludes: channels/, cli/, command/, providers/, skills/)"
|
||||
echo " (excludes: channels/, cli/, api/, command/, providers/, skills/, nanobot.py)"
|
||||
|
||||
136
docs/PYTHON_SDK.md
Normal file
136
docs/PYTHON_SDK.md
Normal file
@ -0,0 +1,136 @@
|
||||
# Python SDK
|
||||
|
||||
Use nanobot programmatically — load config, run the agent, get results.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from nanobot import Nanobot
|
||||
|
||||
async def main():
|
||||
bot = Nanobot.from_config()
|
||||
result = await bot.run("What time is it in Tokyo?")
|
||||
print(result.content)
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
## API
|
||||
|
||||
### `Nanobot.from_config(config_path?, *, workspace?)`
|
||||
|
||||
Create a `Nanobot` from a config file.
|
||||
|
||||
| Param | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `config_path` | `str \| Path \| None` | `None` | Path to `config.json`. Defaults to `~/.nanobot/config.json`. |
|
||||
| `workspace` | `str \| Path \| None` | `None` | Override workspace directory from config. |
|
||||
|
||||
Raises `FileNotFoundError` if an explicit path doesn't exist.
|
||||
|
||||
### `await bot.run(message, *, session_key?, hooks?)`
|
||||
|
||||
Run the agent once. Returns a `RunResult`.
|
||||
|
||||
| Param | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `message` | `str` | *(required)* | The user message to process. |
|
||||
| `session_key` | `str` | `"sdk:default"` | Session identifier for conversation isolation. Different keys get independent history. |
|
||||
| `hooks` | `list[AgentHook] \| None` | `None` | Lifecycle hooks for this run only. |
|
||||
|
||||
```python
|
||||
# Isolated sessions — each user gets independent conversation history
|
||||
await bot.run("hi", session_key="user-alice")
|
||||
await bot.run("hi", session_key="user-bob")
|
||||
```
|
||||
|
||||
### `RunResult`
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `content` | `str` | The agent's final text response. |
|
||||
| `tools_used` | `list[str]` | Tool names invoked during the run. |
|
||||
| `messages` | `list[dict]` | Raw message history (for debugging). |
|
||||
|
||||
## Hooks
|
||||
|
||||
Hooks let you observe or modify the agent loop without touching internals.
|
||||
|
||||
Subclass `AgentHook` and override any method:
|
||||
|
||||
| Method | When |
|
||||
|--------|------|
|
||||
| `before_iteration(ctx)` | Before each LLM call |
|
||||
| `on_stream(ctx, delta)` | On each streamed token |
|
||||
| `on_stream_end(ctx)` | When streaming finishes |
|
||||
| `before_execute_tools(ctx)` | Before tool execution (inspect `ctx.tool_calls`) |
|
||||
| `after_iteration(ctx, response)` | After each LLM response |
|
||||
| `finalize_content(ctx, content)` | Transform final output text |
|
||||
|
||||
### Example: Audit Hook
|
||||
|
||||
```python
|
||||
from nanobot.agent import AgentHook, AgentHookContext
|
||||
|
||||
class AuditHook(AgentHook):
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
async def before_execute_tools(self, ctx: AgentHookContext) -> None:
|
||||
for tc in ctx.tool_calls:
|
||||
self.calls.append(tc.name)
|
||||
print(f"[audit] {tc.name}({tc.arguments})")
|
||||
|
||||
hook = AuditHook()
|
||||
result = await bot.run("List files in /tmp", hooks=[hook])
|
||||
print(f"Tools used: {hook.calls}")
|
||||
```
|
||||
|
||||
### Composing Hooks
|
||||
|
||||
Pass multiple hooks — they run in order, errors in one don't block others:
|
||||
|
||||
```python
|
||||
result = await bot.run("hi", hooks=[AuditHook(), MetricsHook()])
|
||||
```
|
||||
|
||||
Under the hood this uses `CompositeHook` for fan-out with error isolation.
|
||||
|
||||
### `finalize_content` Pipeline
|
||||
|
||||
Unlike the async methods (fan-out), `finalize_content` is a pipeline — each hook's output feeds the next:
|
||||
|
||||
```python
|
||||
class Censor(AgentHook):
|
||||
def finalize_content(self, ctx, content):
|
||||
return content.replace("secret", "***") if content else content
|
||||
```
|
||||
|
||||
## Full Example
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from nanobot import Nanobot
|
||||
from nanobot.agent import AgentHook, AgentHookContext
|
||||
|
||||
class TimingHook(AgentHook):
|
||||
async def before_iteration(self, ctx: AgentHookContext) -> None:
|
||||
import time
|
||||
ctx.metadata["_t0"] = time.time()
|
||||
|
||||
async def after_iteration(self, ctx, response) -> None:
|
||||
import time
|
||||
elapsed = time.time() - ctx.metadata.get("_t0", 0)
|
||||
print(f"[timing] iteration took {elapsed:.2f}s")
|
||||
|
||||
async def main():
|
||||
bot = Nanobot.from_config(workspace="/my/project")
|
||||
result = await bot.run(
|
||||
"Explain the main function",
|
||||
hooks=[TimingHook()],
|
||||
)
|
||||
print(result.content)
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
@ -4,3 +4,7 @@ nanobot - A lightweight AI agent framework
|
||||
|
||||
__version__ = "0.1.4.post6"
|
||||
__logo__ = "🐈"
|
||||
|
||||
from nanobot.nanobot import Nanobot, RunResult
|
||||
|
||||
__all__ = ["Nanobot", "RunResult"]
|
||||
|
||||
@ -1,8 +1,19 @@
|
||||
"""Agent core module."""
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.agent.skills import SkillsLoader
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
|
||||
__all__ = ["AgentLoop", "ContextBuilder", "MemoryStore", "SkillsLoader"]
|
||||
__all__ = [
|
||||
"AgentHook",
|
||||
"AgentHookContext",
|
||||
"AgentLoop",
|
||||
"CompositeHook",
|
||||
"ContextBuilder",
|
||||
"MemoryStore",
|
||||
"SkillsLoader",
|
||||
"SubagentManager",
|
||||
]
|
||||
|
||||
@ -5,6 +5,8 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
@ -47,3 +49,60 @@ class AgentHook:
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
return content
|
||||
|
||||
|
||||
class CompositeHook(AgentHook):
|
||||
"""Fan-out hook that delegates to an ordered list of hooks.
|
||||
|
||||
Error isolation: async methods catch and log per-hook exceptions
|
||||
so a faulty custom hook cannot crash the agent loop.
|
||||
``finalize_content`` is a pipeline (no isolation — bugs should surface).
|
||||
"""
|
||||
|
||||
__slots__ = ("_hooks",)
|
||||
|
||||
def __init__(self, hooks: list[AgentHook]) -> None:
|
||||
self._hooks = list(hooks)
|
||||
|
||||
def wants_streaming(self) -> bool:
|
||||
return any(h.wants_streaming() for h in self._hooks)
|
||||
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
for h in self._hooks:
|
||||
try:
|
||||
await h.before_iteration(context)
|
||||
except Exception:
|
||||
logger.exception("AgentHook.before_iteration error in {}", type(h).__name__)
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
for h in self._hooks:
|
||||
try:
|
||||
await h.on_stream(context, delta)
|
||||
except Exception:
|
||||
logger.exception("AgentHook.on_stream error in {}", type(h).__name__)
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
for h in self._hooks:
|
||||
try:
|
||||
await h.on_stream_end(context, resuming=resuming)
|
||||
except Exception:
|
||||
logger.exception("AgentHook.on_stream_end error in {}", type(h).__name__)
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
for h in self._hooks:
|
||||
try:
|
||||
await h.before_execute_tools(context)
|
||||
except Exception:
|
||||
logger.exception("AgentHook.before_execute_tools error in {}", type(h).__name__)
|
||||
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
||||
for h in self._hooks:
|
||||
try:
|
||||
await h.after_iteration(context)
|
||||
except Exception:
|
||||
logger.exception("AgentHook.after_iteration error in {}", type(h).__name__)
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
for h in self._hooks:
|
||||
content = h.finalize_content(context, content)
|
||||
return content
|
||||
|
||||
@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
||||
from nanobot.agent.memory import MemoryConsolidator
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
@ -37,6 +37,120 @@ if TYPE_CHECKING:
|
||||
from nanobot.cron.service import CronService
|
||||
|
||||
|
||||
class _LoopHook(AgentHook):
|
||||
"""Core lifecycle hook for the main agent loop.
|
||||
|
||||
Handles streaming delta relay, progress reporting, tool-call logging,
|
||||
and think-tag stripping for the built-in agent path.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_loop: AgentLoop,
|
||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||
*,
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
self._loop = agent_loop
|
||||
self._on_progress = on_progress
|
||||
self._on_stream = on_stream
|
||||
self._on_stream_end = on_stream_end
|
||||
self._channel = channel
|
||||
self._chat_id = chat_id
|
||||
self._message_id = message_id
|
||||
self._stream_buf = ""
|
||||
|
||||
def wants_streaming(self) -> bool:
|
||||
return self._on_stream is not None
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
from nanobot.utils.helpers import strip_think
|
||||
|
||||
prev_clean = strip_think(self._stream_buf)
|
||||
self._stream_buf += delta
|
||||
new_clean = strip_think(self._stream_buf)
|
||||
incremental = new_clean[len(prev_clean):]
|
||||
if incremental and self._on_stream:
|
||||
await self._on_stream(incremental)
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
if self._on_stream_end:
|
||||
await self._on_stream_end(resuming=resuming)
|
||||
self._stream_buf = ""
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
if self._on_progress:
|
||||
if not self._on_stream:
|
||||
thought = self._loop._strip_think(
|
||||
context.response.content if context.response else None
|
||||
)
|
||||
if thought:
|
||||
await self._on_progress(thought)
|
||||
tool_hint = self._loop._strip_think(self._loop._tool_hint(context.tool_calls))
|
||||
await self._on_progress(tool_hint, tool_hint=True)
|
||||
for tc in context.tool_calls:
|
||||
args_str = json.dumps(tc.arguments, ensure_ascii=False)
|
||||
logger.info("Tool call: {}({})", tc.name, args_str[:200])
|
||||
self._loop._set_tool_context(self._channel, self._chat_id, self._message_id)
|
||||
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
||||
u = context.usage or {}
|
||||
logger.debug(
|
||||
"LLM usage: prompt={} completion={} cached={}",
|
||||
u.get("prompt_tokens", 0),
|
||||
u.get("completion_tokens", 0),
|
||||
u.get("cached_tokens", 0),
|
||||
)
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
return self._loop._strip_think(content)
|
||||
|
||||
|
||||
class _LoopHookChain(AgentHook):
|
||||
"""Run the core loop hook first, then best-effort extra hooks.
|
||||
|
||||
This preserves the historical failure behavior of ``_LoopHook`` while still
|
||||
letting user-supplied hooks opt into ``CompositeHook`` isolation.
|
||||
"""
|
||||
|
||||
__slots__ = ("_primary", "_extras")
|
||||
|
||||
def __init__(self, primary: AgentHook, extra_hooks: list[AgentHook]) -> None:
|
||||
self._primary = primary
|
||||
self._extras = CompositeHook(extra_hooks)
|
||||
|
||||
def wants_streaming(self) -> bool:
|
||||
return self._primary.wants_streaming() or self._extras.wants_streaming()
|
||||
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
await self._primary.before_iteration(context)
|
||||
await self._extras.before_iteration(context)
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
await self._primary.on_stream(context, delta)
|
||||
await self._extras.on_stream(context, delta)
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
await self._primary.on_stream_end(context, resuming=resuming)
|
||||
await self._extras.on_stream_end(context, resuming=resuming)
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
await self._primary.before_execute_tools(context)
|
||||
await self._extras.before_execute_tools(context)
|
||||
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
||||
await self._primary.after_iteration(context)
|
||||
await self._extras.after_iteration(context)
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
content = self._primary.finalize_content(context, content)
|
||||
return self._extras.finalize_content(context, content)
|
||||
|
||||
|
||||
class AgentLoop:
|
||||
"""
|
||||
The agent loop is the core processing engine.
|
||||
@ -68,6 +182,7 @@ class AgentLoop:
|
||||
mcp_servers: dict | None = None,
|
||||
channels_config: ChannelsConfig | None = None,
|
||||
timezone: str | None = None,
|
||||
hooks: list[AgentHook] | None = None,
|
||||
):
|
||||
from nanobot.config.schema import ExecToolConfig, WebSearchConfig
|
||||
|
||||
@ -85,6 +200,7 @@ class AgentLoop:
|
||||
self.restrict_to_workspace = restrict_to_workspace
|
||||
self._start_time = time.time()
|
||||
self._last_usage: dict[str, int] = {}
|
||||
self._extra_hooks: list[AgentHook] = hooks or []
|
||||
|
||||
self.context = ContextBuilder(workspace, timezone=timezone)
|
||||
self.sessions = session_manager or SessionManager(workspace)
|
||||
@ -217,52 +333,27 @@ class AgentLoop:
|
||||
``resuming=True`` means tool calls follow (spinner should restart);
|
||||
``resuming=False`` means this is the final response.
|
||||
"""
|
||||
loop_self = self
|
||||
|
||||
class _LoopHook(AgentHook):
|
||||
def __init__(self) -> None:
|
||||
self._stream_buf = ""
|
||||
|
||||
def wants_streaming(self) -> bool:
|
||||
return on_stream is not None
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
from nanobot.utils.helpers import strip_think
|
||||
|
||||
prev_clean = strip_think(self._stream_buf)
|
||||
self._stream_buf += delta
|
||||
new_clean = strip_think(self._stream_buf)
|
||||
incremental = new_clean[len(prev_clean):]
|
||||
if incremental and on_stream:
|
||||
await on_stream(incremental)
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
if on_stream_end:
|
||||
await on_stream_end(resuming=resuming)
|
||||
self._stream_buf = ""
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
if on_progress:
|
||||
if not on_stream:
|
||||
thought = loop_self._strip_think(context.response.content if context.response else None)
|
||||
if thought:
|
||||
await on_progress(thought)
|
||||
tool_hint = loop_self._strip_think(loop_self._tool_hint(context.tool_calls))
|
||||
await on_progress(tool_hint, tool_hint=True)
|
||||
for tc in context.tool_calls:
|
||||
args_str = json.dumps(tc.arguments, ensure_ascii=False)
|
||||
logger.info("Tool call: {}({})", tc.name, args_str[:200])
|
||||
loop_self._set_tool_context(channel, chat_id, message_id)
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
return loop_self._strip_think(content)
|
||||
loop_hook = _LoopHook(
|
||||
self,
|
||||
on_progress=on_progress,
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
hook: AgentHook = (
|
||||
_LoopHookChain(loop_hook, self._extra_hooks)
|
||||
if self._extra_hooks
|
||||
else loop_hook
|
||||
)
|
||||
|
||||
result = await self.runner.run(AgentRunSpec(
|
||||
initial_messages=initial_messages,
|
||||
tools=self.tools,
|
||||
model=self.model,
|
||||
max_iterations=self.max_iterations,
|
||||
hook=_LoopHook(),
|
||||
hook=hook,
|
||||
error_message="Sorry, I encountered an error calling the AI model.",
|
||||
concurrent_tools=True,
|
||||
))
|
||||
@ -321,25 +412,25 @@ class AgentLoop:
|
||||
return f"{stream_base_id}:{stream_segment}"
|
||||
|
||||
async def on_stream(delta: str) -> None:
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_stream_delta"] = True
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content=delta,
|
||||
metadata={
|
||||
"_stream_delta": True,
|
||||
"_stream_id": _current_stream_id(),
|
||||
},
|
||||
metadata=meta,
|
||||
))
|
||||
|
||||
async def on_stream_end(*, resuming: bool = False) -> None:
|
||||
nonlocal stream_segment
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_stream_end"] = True
|
||||
meta["_resuming"] = resuming
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="",
|
||||
metadata={
|
||||
"_stream_end": True,
|
||||
"_resuming": resuming,
|
||||
"_stream_id": _current_stream_id(),
|
||||
},
|
||||
metadata=meta,
|
||||
))
|
||||
stream_segment += 1
|
||||
|
||||
|
||||
@ -60,7 +60,7 @@ class AgentRunner:
|
||||
messages = list(spec.initial_messages)
|
||||
final_content: str | None = None
|
||||
tools_used: list[str] = []
|
||||
usage = {"prompt_tokens": 0, "completion_tokens": 0}
|
||||
usage: dict[str, int] = {}
|
||||
error: str | None = None
|
||||
stop_reason = "completed"
|
||||
tool_events: list[dict[str, str]] = []
|
||||
@ -92,13 +92,15 @@ class AgentRunner:
|
||||
response = await self.provider.chat_with_retry(**kwargs)
|
||||
|
||||
raw_usage = response.usage or {}
|
||||
usage = {
|
||||
"prompt_tokens": int(raw_usage.get("prompt_tokens", 0) or 0),
|
||||
"completion_tokens": int(raw_usage.get("completion_tokens", 0) or 0),
|
||||
}
|
||||
context.response = response
|
||||
context.usage = usage
|
||||
context.usage = raw_usage
|
||||
context.tool_calls = list(response.tool_calls)
|
||||
# Accumulate standard fields into result usage.
|
||||
usage["prompt_tokens"] = usage.get("prompt_tokens", 0) + int(raw_usage.get("prompt_tokens", 0) or 0)
|
||||
usage["completion_tokens"] = usage.get("completion_tokens", 0) + int(raw_usage.get("completion_tokens", 0) or 0)
|
||||
cached = raw_usage.get("cached_tokens")
|
||||
if cached:
|
||||
usage["cached_tokens"] = usage.get("cached_tokens", 0) + int(cached)
|
||||
|
||||
if response.has_tool_calls:
|
||||
if hook.wants_streaming():
|
||||
|
||||
@ -21,6 +21,21 @@ from nanobot.config.schema import ExecToolConfig
|
||||
from nanobot.providers.base import LLMProvider
|
||||
|
||||
|
||||
class _SubagentHook(AgentHook):
|
||||
"""Logging-only hook for subagent execution."""
|
||||
|
||||
def __init__(self, task_id: str) -> None:
|
||||
self._task_id = task_id
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
for tool_call in context.tool_calls:
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.debug(
|
||||
"Subagent [{}] executing: {} with arguments: {}",
|
||||
self._task_id, tool_call.name, args_str,
|
||||
)
|
||||
|
||||
|
||||
class SubagentManager:
|
||||
"""Manages background subagent execution."""
|
||||
|
||||
@ -100,33 +115,28 @@ class SubagentManager:
|
||||
tools.register(WriteFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
tools.register(EditFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
tools.register(ExecTool(
|
||||
working_dir=str(self.workspace),
|
||||
timeout=self.exec_config.timeout,
|
||||
restrict_to_workspace=self.restrict_to_workspace,
|
||||
path_append=self.exec_config.path_append,
|
||||
))
|
||||
if self.exec_config.enable:
|
||||
tools.register(ExecTool(
|
||||
working_dir=str(self.workspace),
|
||||
timeout=self.exec_config.timeout,
|
||||
restrict_to_workspace=self.restrict_to_workspace,
|
||||
path_append=self.exec_config.path_append,
|
||||
))
|
||||
tools.register(WebSearchTool(config=self.web_search_config, proxy=self.web_proxy))
|
||||
tools.register(WebFetchTool(proxy=self.web_proxy))
|
||||
|
||||
|
||||
system_prompt = self._build_subagent_prompt()
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": task},
|
||||
]
|
||||
|
||||
class _SubagentHook(AgentHook):
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
for tool_call in context.tool_calls:
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str)
|
||||
|
||||
result = await self.runner.run(AgentRunSpec(
|
||||
initial_messages=messages,
|
||||
tools=tools,
|
||||
model=self.model,
|
||||
max_iterations=15,
|
||||
hook=_SubagentHook(),
|
||||
hook=_SubagentHook(task_id),
|
||||
max_iterations_message="Task completed but no final response was generated.",
|
||||
error_message=None,
|
||||
fail_on_tool_error=True,
|
||||
@ -213,7 +223,7 @@ Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not men
|
||||
lines.append("Failure:")
|
||||
lines.append(f"- {result.error}")
|
||||
return "\n".join(lines) or (result.error or "Error: subagent execution failed.")
|
||||
|
||||
|
||||
def _build_subagent_prompt(self) -> str:
|
||||
"""Build a focused system prompt for the subagent."""
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
|
||||
@ -74,7 +74,7 @@ class CronTool(Tool):
|
||||
"enum": ["add", "list", "remove"],
|
||||
"description": "Action to perform",
|
||||
},
|
||||
"message": {"type": "string", "description": "Reminder message (for add)"},
|
||||
"message": {"type": "string", "description": "Instruction for the agent to execute when the job triggers (e.g., 'Send a reminder to WeChat: xxx' or 'Check system status and report')"},
|
||||
"every_seconds": {
|
||||
"type": "integer",
|
||||
"description": "Interval in seconds (for recurring tasks)",
|
||||
@ -97,6 +97,11 @@ class CronTool(Tool):
|
||||
f"(e.g. '2026-02-12T10:30:00'). Naive values default to {self._default_timezone}."
|
||||
),
|
||||
},
|
||||
"deliver": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to deliver the execution result to the user channel (default true)",
|
||||
"default": True
|
||||
},
|
||||
"job_id": {"type": "string", "description": "Job ID (for remove)"},
|
||||
},
|
||||
"required": ["action"],
|
||||
@ -111,12 +116,13 @@ class CronTool(Tool):
|
||||
tz: str | None = None,
|
||||
at: str | None = None,
|
||||
job_id: str | None = None,
|
||||
deliver: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
if action == "add":
|
||||
if self._in_cron_context.get():
|
||||
return "Error: cannot schedule new jobs from within a cron job execution"
|
||||
return self._add_job(message, every_seconds, cron_expr, tz, at)
|
||||
return self._add_job(message, every_seconds, cron_expr, tz, at, deliver)
|
||||
elif action == "list":
|
||||
return self._list_jobs()
|
||||
elif action == "remove":
|
||||
@ -130,6 +136,7 @@ class CronTool(Tool):
|
||||
cron_expr: str | None,
|
||||
tz: str | None,
|
||||
at: str | None,
|
||||
deliver: bool = True,
|
||||
) -> str:
|
||||
if not message:
|
||||
return "Error: message is required for add"
|
||||
@ -171,7 +178,7 @@ class CronTool(Tool):
|
||||
name=message[:30],
|
||||
schedule=schedule,
|
||||
message=message,
|
||||
deliver=True,
|
||||
deliver=deliver,
|
||||
channel=self._channel,
|
||||
to=self._chat_id,
|
||||
delete_after_run=delete_after,
|
||||
|
||||
@ -170,7 +170,11 @@ async def connect_mcp_servers(
|
||||
timeout: httpx.Timeout | None = None,
|
||||
auth: httpx.Auth | None = None,
|
||||
) -> httpx.AsyncClient:
|
||||
merged_headers = {**(cfg.headers or {}), **(headers or {})}
|
||||
merged_headers = {
|
||||
"Accept": "application/json, text/event-stream",
|
||||
**(cfg.headers or {}),
|
||||
**(headers or {}),
|
||||
}
|
||||
return httpx.AsyncClient(
|
||||
headers=merged_headers or None,
|
||||
follow_redirects=True,
|
||||
|
||||
@ -86,7 +86,15 @@ class MessageTool(Tool):
|
||||
) -> str:
|
||||
channel = channel or self._default_channel
|
||||
chat_id = chat_id or self._default_chat_id
|
||||
message_id = message_id or self._default_message_id
|
||||
# Only inherit default message_id when targeting the same channel+chat.
|
||||
# Cross-chat sends must not carry the original message_id, because
|
||||
# some channels (e.g. Feishu) use it to determine the target
|
||||
# conversation via their Reply API, which would route the message
|
||||
# to the wrong chat entirely.
|
||||
if channel == self._default_channel and chat_id == self._default_chat_id:
|
||||
message_id = message_id or self._default_message_id
|
||||
else:
|
||||
message_id = None
|
||||
|
||||
if not channel or not chat_id:
|
||||
return "Error: No target channel/chat specified"
|
||||
@ -101,7 +109,7 @@ class MessageTool(Tool):
|
||||
media=media or [],
|
||||
metadata={
|
||||
"message_id": message_id,
|
||||
},
|
||||
} if message_id else {},
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@ -186,7 +186,9 @@ class ExecTool(Tool):
|
||||
|
||||
@staticmethod
|
||||
def _extract_absolute_paths(command: str) -> list[str]:
|
||||
win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]+", command) # Windows: C:\...
|
||||
# Windows: match drive-root paths like `C:\` as well as `C:\path\to\file`
|
||||
# NOTE: `*` is required so `C:\` (nothing after the slash) is still extracted.
|
||||
win_paths = re.findall(r"[A-Za-z]:\\[^\s\"'|><;]*", command)
|
||||
posix_paths = re.findall(r"(?:^|[\s|>'\"])(/[^\s\"'>;|<]+)", command) # POSIX: /absolute only
|
||||
home_paths = re.findall(r"(?:^|[\s|>'\"])(~[^\s\"'>;|<]*)", command) # POSIX/Windows home shortcut: ~
|
||||
return win_paths + posix_paths + home_paths
|
||||
|
||||
1
nanobot/api/__init__.py
Normal file
1
nanobot/api/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""OpenAI-compatible HTTP API for nanobot."""
|
||||
193
nanobot/api/server.py
Normal file
193
nanobot/api/server.py
Normal file
@ -0,0 +1,193 @@
|
||||
"""OpenAI-compatible HTTP API server for a fixed nanobot session.
|
||||
|
||||
Provides /v1/chat/completions and /v1/models endpoints.
|
||||
All requests route to a single persistent API session.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import web
|
||||
from loguru import logger
|
||||
|
||||
API_SESSION_KEY = "api:default"
|
||||
API_CHAT_ID = "default"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _error_json(status: int, message: str, err_type: str = "invalid_request_error") -> web.Response:
|
||||
return web.json_response(
|
||||
{"error": {"message": message, "type": err_type, "code": status}},
|
||||
status=status,
|
||||
)
|
||||
|
||||
|
||||
def _chat_completion_response(content: str, model: str) -> dict[str, Any]:
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": content},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
}
|
||||
|
||||
|
||||
def _response_text(value: Any) -> str:
|
||||
"""Normalize process_direct output to plain assistant text."""
|
||||
if value is None:
|
||||
return ""
|
||||
if hasattr(value, "content"):
|
||||
return str(getattr(value, "content") or "")
|
||||
return str(value)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Route handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def handle_chat_completions(request: web.Request) -> web.Response:
|
||||
"""POST /v1/chat/completions"""
|
||||
|
||||
# --- Parse body ---
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
return _error_json(400, "Invalid JSON body")
|
||||
|
||||
messages = body.get("messages")
|
||||
if not isinstance(messages, list) or len(messages) != 1:
|
||||
return _error_json(400, "Only a single user message is supported")
|
||||
|
||||
# Stream not yet supported
|
||||
if body.get("stream", False):
|
||||
return _error_json(400, "stream=true is not supported yet. Set stream=false or omit it.")
|
||||
|
||||
message = messages[0]
|
||||
if not isinstance(message, dict) or message.get("role") != "user":
|
||||
return _error_json(400, "Only a single user message is supported")
|
||||
user_content = message.get("content", "")
|
||||
if isinstance(user_content, list):
|
||||
# Multi-modal content array — extract text parts
|
||||
user_content = " ".join(
|
||||
part.get("text", "") for part in user_content if part.get("type") == "text"
|
||||
)
|
||||
|
||||
agent_loop = request.app["agent_loop"]
|
||||
timeout_s: float = request.app.get("request_timeout", 120.0)
|
||||
model_name: str = request.app.get("model_name", "nanobot")
|
||||
if (requested_model := body.get("model")) and requested_model != model_name:
|
||||
return _error_json(400, f"Only configured model '{model_name}' is available")
|
||||
|
||||
session_key = f"api:{body['session_id']}" if body.get("session_id") else API_SESSION_KEY
|
||||
session_locks: dict[str, asyncio.Lock] = request.app["session_locks"]
|
||||
session_lock = session_locks.setdefault(session_key, asyncio.Lock())
|
||||
|
||||
logger.info("API request session_key={} content={}", session_key, user_content[:80])
|
||||
|
||||
_FALLBACK = "I've completed processing but have no response to give."
|
||||
|
||||
try:
|
||||
async with session_lock:
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
agent_loop.process_direct(
|
||||
content=user_content,
|
||||
session_key=session_key,
|
||||
channel="api",
|
||||
chat_id=API_CHAT_ID,
|
||||
),
|
||||
timeout=timeout_s,
|
||||
)
|
||||
response_text = _response_text(response)
|
||||
|
||||
if not response_text or not response_text.strip():
|
||||
logger.warning(
|
||||
"Empty response for session {}, retrying",
|
||||
session_key,
|
||||
)
|
||||
retry_response = await asyncio.wait_for(
|
||||
agent_loop.process_direct(
|
||||
content=user_content,
|
||||
session_key=session_key,
|
||||
channel="api",
|
||||
chat_id=API_CHAT_ID,
|
||||
),
|
||||
timeout=timeout_s,
|
||||
)
|
||||
response_text = _response_text(retry_response)
|
||||
if not response_text or not response_text.strip():
|
||||
logger.warning(
|
||||
"Empty response after retry for session {}, using fallback",
|
||||
session_key,
|
||||
)
|
||||
response_text = _FALLBACK
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return _error_json(504, f"Request timed out after {timeout_s}s")
|
||||
except Exception:
|
||||
logger.exception("Error processing request for session {}", session_key)
|
||||
return _error_json(500, "Internal server error", err_type="server_error")
|
||||
except Exception:
|
||||
logger.exception("Unexpected API lock error for session {}", session_key)
|
||||
return _error_json(500, "Internal server error", err_type="server_error")
|
||||
|
||||
return web.json_response(_chat_completion_response(response_text, model_name))
|
||||
|
||||
|
||||
async def handle_models(request: web.Request) -> web.Response:
|
||||
"""GET /v1/models"""
|
||||
model_name = request.app.get("model_name", "nanobot")
|
||||
return web.json_response({
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": model_name,
|
||||
"object": "model",
|
||||
"created": 0,
|
||||
"owned_by": "nanobot",
|
||||
}
|
||||
],
|
||||
})
|
||||
|
||||
|
||||
async def handle_health(request: web.Request) -> web.Response:
|
||||
"""GET /health"""
|
||||
return web.json_response({"status": "ok"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# App factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def create_app(agent_loop, model_name: str = "nanobot", request_timeout: float = 120.0) -> web.Application:
|
||||
"""Create the aiohttp application.
|
||||
|
||||
Args:
|
||||
agent_loop: An initialized AgentLoop instance.
|
||||
model_name: Model name reported in responses.
|
||||
request_timeout: Per-request timeout in seconds.
|
||||
"""
|
||||
app = web.Application()
|
||||
app["agent_loop"] = agent_loop
|
||||
app["model_name"] = model_name
|
||||
app["request_timeout"] = request_timeout
|
||||
app["session_locks"] = {} # per-user locks, keyed by session_key
|
||||
|
||||
app.router.add_post("/v1/chat/completions", handle_chat_completions)
|
||||
app.router.add_get("/v1/models", handle_models)
|
||||
app.router.add_get("/health", handle_health)
|
||||
return app
|
||||
@ -1,25 +1,37 @@
|
||||
"""Discord channel implementation using Discord Gateway websocket."""
|
||||
"""Discord channel implementation using discord.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import importlib.util
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import httpx
|
||||
from pydantic import Field
|
||||
import websockets
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.command.builtin import build_help_text
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import Base
|
||||
from nanobot.utils.helpers import split_message
|
||||
from nanobot.utils.helpers import safe_filename, split_message
|
||||
|
||||
DISCORD_AVAILABLE = importlib.util.find_spec("discord") is not None
|
||||
if TYPE_CHECKING:
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.abc import Messageable
|
||||
|
||||
if DISCORD_AVAILABLE:
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.abc import Messageable
|
||||
|
||||
DISCORD_API_BASE = "https://discord.com/api/v10"
|
||||
MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024 # 20MB
|
||||
MAX_MESSAGE_LEN = 2000 # Discord message character limit
|
||||
TYPING_INTERVAL_S = 8
|
||||
|
||||
|
||||
class DiscordConfig(Base):
|
||||
@ -28,13 +40,205 @@ class DiscordConfig(Base):
|
||||
enabled: bool = False
|
||||
token: str = ""
|
||||
allow_from: list[str] = Field(default_factory=list)
|
||||
gateway_url: str = "wss://gateway.discord.gg/?v=10&encoding=json"
|
||||
intents: int = 37377
|
||||
group_policy: Literal["mention", "open"] = "mention"
|
||||
read_receipt_emoji: str = "👀"
|
||||
working_emoji: str = "🔧"
|
||||
working_emoji_delay: float = 2.0
|
||||
|
||||
|
||||
if DISCORD_AVAILABLE:
|
||||
|
||||
class DiscordBotClient(discord.Client):
|
||||
"""discord.py client that forwards events to the channel."""
|
||||
|
||||
def __init__(self, channel: DiscordChannel, *, intents: discord.Intents) -> None:
|
||||
super().__init__(intents=intents)
|
||||
self._channel = channel
|
||||
self.tree = app_commands.CommandTree(self)
|
||||
self._register_app_commands()
|
||||
|
||||
async def on_ready(self) -> None:
|
||||
self._channel._bot_user_id = str(self.user.id) if self.user else None
|
||||
logger.info("Discord bot connected as user {}", self._channel._bot_user_id)
|
||||
try:
|
||||
synced = await self.tree.sync()
|
||||
logger.info("Discord app commands synced: {}", len(synced))
|
||||
except Exception as e:
|
||||
logger.warning("Discord app command sync failed: {}", e)
|
||||
|
||||
async def on_message(self, message: discord.Message) -> None:
|
||||
await self._channel._handle_discord_message(message)
|
||||
|
||||
async def _reply_ephemeral(self, interaction: discord.Interaction, text: str) -> bool:
|
||||
"""Send an ephemeral interaction response and report success."""
|
||||
try:
|
||||
await interaction.response.send_message(text, ephemeral=True)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning("Discord interaction response failed: {}", e)
|
||||
return False
|
||||
|
||||
async def _forward_slash_command(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
command_text: str,
|
||||
) -> None:
|
||||
sender_id = str(interaction.user.id)
|
||||
channel_id = interaction.channel_id
|
||||
|
||||
if channel_id is None:
|
||||
logger.warning("Discord slash command missing channel_id: {}", command_text)
|
||||
return
|
||||
|
||||
if not self._channel.is_allowed(sender_id):
|
||||
await self._reply_ephemeral(interaction, "You are not allowed to use this bot.")
|
||||
return
|
||||
|
||||
await self._reply_ephemeral(interaction, f"Processing {command_text}...")
|
||||
|
||||
await self._channel._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=str(channel_id),
|
||||
content=command_text,
|
||||
metadata={
|
||||
"interaction_id": str(interaction.id),
|
||||
"guild_id": str(interaction.guild_id) if interaction.guild_id else None,
|
||||
"is_slash_command": True,
|
||||
},
|
||||
)
|
||||
|
||||
def _register_app_commands(self) -> None:
|
||||
commands = (
|
||||
("new", "Start a new conversation", "/new"),
|
||||
("stop", "Stop the current task", "/stop"),
|
||||
("restart", "Restart the bot", "/restart"),
|
||||
("status", "Show bot status", "/status"),
|
||||
)
|
||||
|
||||
for name, description, command_text in commands:
|
||||
@self.tree.command(name=name, description=description)
|
||||
async def command_handler(
|
||||
interaction: discord.Interaction,
|
||||
_command_text: str = command_text,
|
||||
) -> None:
|
||||
await self._forward_slash_command(interaction, _command_text)
|
||||
|
||||
@self.tree.command(name="help", description="Show available commands")
|
||||
async def help_command(interaction: discord.Interaction) -> None:
|
||||
sender_id = str(interaction.user.id)
|
||||
if not self._channel.is_allowed(sender_id):
|
||||
await self._reply_ephemeral(interaction, "You are not allowed to use this bot.")
|
||||
return
|
||||
await self._reply_ephemeral(interaction, build_help_text())
|
||||
|
||||
@self.tree.error
|
||||
async def on_app_command_error(
|
||||
interaction: discord.Interaction,
|
||||
error: app_commands.AppCommandError,
|
||||
) -> None:
|
||||
command_name = interaction.command.qualified_name if interaction.command else "?"
|
||||
logger.warning(
|
||||
"Discord app command failed user={} channel={} cmd={} error={}",
|
||||
interaction.user.id,
|
||||
interaction.channel_id,
|
||||
command_name,
|
||||
error,
|
||||
)
|
||||
|
||||
async def send_outbound(self, msg: OutboundMessage) -> None:
|
||||
"""Send a nanobot outbound message using Discord transport rules."""
|
||||
channel_id = int(msg.chat_id)
|
||||
|
||||
channel = self.get_channel(channel_id)
|
||||
if channel is None:
|
||||
try:
|
||||
channel = await self.fetch_channel(channel_id)
|
||||
except Exception as e:
|
||||
logger.warning("Discord channel {} unavailable: {}", msg.chat_id, e)
|
||||
return
|
||||
|
||||
reference, mention_settings = self._build_reply_context(channel, msg.reply_to)
|
||||
sent_media = False
|
||||
failed_media: list[str] = []
|
||||
|
||||
for index, media_path in enumerate(msg.media or []):
|
||||
if await self._send_file(
|
||||
channel,
|
||||
media_path,
|
||||
reference=reference if index == 0 else None,
|
||||
mention_settings=mention_settings,
|
||||
):
|
||||
sent_media = True
|
||||
else:
|
||||
failed_media.append(Path(media_path).name)
|
||||
|
||||
for index, chunk in enumerate(self._build_chunks(msg.content or "", failed_media, sent_media)):
|
||||
kwargs: dict[str, Any] = {"content": chunk}
|
||||
if index == 0 and reference is not None and not sent_media:
|
||||
kwargs["reference"] = reference
|
||||
kwargs["allowed_mentions"] = mention_settings
|
||||
await channel.send(**kwargs)
|
||||
|
||||
async def _send_file(
|
||||
self,
|
||||
channel: Messageable,
|
||||
file_path: str,
|
||||
*,
|
||||
reference: discord.PartialMessage | None,
|
||||
mention_settings: discord.AllowedMentions,
|
||||
) -> bool:
|
||||
"""Send a file attachment via discord.py."""
|
||||
path = Path(file_path)
|
||||
if not path.is_file():
|
||||
logger.warning("Discord file not found, skipping: {}", file_path)
|
||||
return False
|
||||
|
||||
if path.stat().st_size > MAX_ATTACHMENT_BYTES:
|
||||
logger.warning("Discord file too large (>20MB), skipping: {}", path.name)
|
||||
return False
|
||||
|
||||
try:
|
||||
kwargs: dict[str, Any] = {"file": discord.File(path)}
|
||||
if reference is not None:
|
||||
kwargs["reference"] = reference
|
||||
kwargs["allowed_mentions"] = mention_settings
|
||||
await channel.send(**kwargs)
|
||||
logger.info("Discord file sent: {}", path.name)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Error sending Discord file {}: {}", path.name, e)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _build_chunks(content: str, failed_media: list[str], sent_media: bool) -> list[str]:
|
||||
"""Build outbound text chunks, including attachment-failure fallback text."""
|
||||
chunks = split_message(content, MAX_MESSAGE_LEN)
|
||||
if chunks or not failed_media or sent_media:
|
||||
return chunks
|
||||
fallback = "\n".join(f"[attachment: {name} - send failed]" for name in failed_media)
|
||||
return split_message(fallback, MAX_MESSAGE_LEN)
|
||||
|
||||
@staticmethod
|
||||
def _build_reply_context(
|
||||
channel: Messageable,
|
||||
reply_to: str | None,
|
||||
) -> tuple[discord.PartialMessage | None, discord.AllowedMentions]:
|
||||
"""Build reply context for outbound messages."""
|
||||
mention_settings = discord.AllowedMentions(replied_user=False)
|
||||
if not reply_to:
|
||||
return None, mention_settings
|
||||
try:
|
||||
message_id = int(reply_to)
|
||||
except ValueError:
|
||||
logger.warning("Invalid Discord reply target: {}", reply_to)
|
||||
return None, mention_settings
|
||||
|
||||
return channel.get_partial_message(message_id), mention_settings
|
||||
|
||||
|
||||
class DiscordChannel(BaseChannel):
|
||||
"""Discord channel using Gateway websocket."""
|
||||
"""Discord channel using discord.py."""
|
||||
|
||||
name = "discord"
|
||||
display_name = "Discord"
|
||||
@ -43,353 +247,270 @@ class DiscordChannel(BaseChannel):
|
||||
def default_config(cls) -> dict[str, Any]:
|
||||
return DiscordConfig().model_dump(by_alias=True)
|
||||
|
||||
@staticmethod
|
||||
def _channel_key(channel_or_id: Any) -> str:
|
||||
"""Normalize channel-like objects and ids to a stable string key."""
|
||||
channel_id = getattr(channel_or_id, "id", channel_or_id)
|
||||
return str(channel_id)
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus):
|
||||
if isinstance(config, dict):
|
||||
config = DiscordConfig.model_validate(config)
|
||||
super().__init__(config, bus)
|
||||
self.config: DiscordConfig = config
|
||||
self._ws: websockets.WebSocketClientProtocol | None = None
|
||||
self._seq: int | None = None
|
||||
self._heartbeat_task: asyncio.Task | None = None
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._http: httpx.AsyncClient | None = None
|
||||
self._client: DiscordBotClient | None = None
|
||||
self._typing_tasks: dict[str, asyncio.Task[None]] = {}
|
||||
self._bot_user_id: str | None = None
|
||||
self._pending_reactions: dict[str, Any] = {} # chat_id -> message object
|
||||
self._working_emoji_tasks: dict[str, asyncio.Task[None]] = {}
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Discord gateway connection."""
|
||||
"""Start the Discord client."""
|
||||
if not DISCORD_AVAILABLE:
|
||||
logger.error("discord.py not installed. Run: pip install nanobot-ai[discord]")
|
||||
return
|
||||
|
||||
if not self.config.token:
|
||||
logger.error("Discord bot token not configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._http = httpx.AsyncClient(timeout=30.0)
|
||||
try:
|
||||
intents = discord.Intents.none()
|
||||
intents.value = self.config.intents
|
||||
self._client = DiscordBotClient(self, intents=intents)
|
||||
except Exception as e:
|
||||
logger.error("Failed to initialize Discord client: {}", e)
|
||||
self._client = None
|
||||
self._running = False
|
||||
return
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
logger.info("Connecting to Discord gateway...")
|
||||
async with websockets.connect(self.config.gateway_url) as ws:
|
||||
self._ws = ws
|
||||
await self._gateway_loop()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning("Discord gateway error: {}", e)
|
||||
if self._running:
|
||||
logger.info("Reconnecting to Discord gateway in 5 seconds...")
|
||||
await asyncio.sleep(5)
|
||||
self._running = True
|
||||
logger.info("Starting Discord client via discord.py...")
|
||||
|
||||
try:
|
||||
await self._client.start(self.config.token)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Discord client startup failed: {}", e)
|
||||
finally:
|
||||
self._running = False
|
||||
await self._reset_runtime_state(close_client=True)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Discord channel."""
|
||||
self._running = False
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
self._heartbeat_task = None
|
||||
for task in self._typing_tasks.values():
|
||||
task.cancel()
|
||||
self._typing_tasks.clear()
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
if self._http:
|
||||
await self._http.aclose()
|
||||
self._http = None
|
||||
await self._reset_runtime_state(close_client=True)
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through Discord REST API, including file attachments."""
|
||||
if not self._http:
|
||||
logger.warning("Discord HTTP client not initialized")
|
||||
"""Send a message through Discord using discord.py."""
|
||||
client = self._client
|
||||
if client is None or not client.is_ready():
|
||||
logger.warning("Discord client not ready; dropping outbound message")
|
||||
return
|
||||
|
||||
url = f"{DISCORD_API_BASE}/channels/{msg.chat_id}/messages"
|
||||
headers = {"Authorization": f"Bot {self.config.token}"}
|
||||
is_progress = bool((msg.metadata or {}).get("_progress"))
|
||||
|
||||
try:
|
||||
sent_media = False
|
||||
failed_media: list[str] = []
|
||||
|
||||
# Send file attachments first
|
||||
for media_path in msg.media or []:
|
||||
if await self._send_file(url, headers, media_path, reply_to=msg.reply_to):
|
||||
sent_media = True
|
||||
else:
|
||||
failed_media.append(Path(media_path).name)
|
||||
|
||||
# Send text content
|
||||
chunks = split_message(msg.content or "", MAX_MESSAGE_LEN)
|
||||
if not chunks and failed_media and not sent_media:
|
||||
chunks = split_message(
|
||||
"\n".join(f"[attachment: {name} - send failed]" for name in failed_media),
|
||||
MAX_MESSAGE_LEN,
|
||||
)
|
||||
if not chunks:
|
||||
return
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
payload: dict[str, Any] = {"content": chunk}
|
||||
|
||||
# Let the first successful attachment carry the reply if present.
|
||||
if i == 0 and msg.reply_to and not sent_media:
|
||||
payload["message_reference"] = {"message_id": msg.reply_to}
|
||||
payload["allowed_mentions"] = {"replied_user": False}
|
||||
|
||||
if not await self._send_payload(url, headers, payload):
|
||||
break # Abort remaining chunks on failure
|
||||
await client.send_outbound(msg)
|
||||
except Exception as e:
|
||||
logger.error("Error sending Discord message: {}", e)
|
||||
finally:
|
||||
await self._stop_typing(msg.chat_id)
|
||||
if not is_progress:
|
||||
await self._stop_typing(msg.chat_id)
|
||||
await self._clear_reactions(msg.chat_id)
|
||||
|
||||
async def _send_payload(
|
||||
self, url: str, headers: dict[str, str], payload: dict[str, Any]
|
||||
) -> bool:
|
||||
"""Send a single Discord API payload with retry on rate-limit. Returns True on success."""
|
||||
for attempt in range(3):
|
||||
async def _handle_discord_message(self, message: discord.Message) -> None:
|
||||
"""Handle incoming Discord messages from discord.py."""
|
||||
if message.author.bot:
|
||||
return
|
||||
|
||||
sender_id = str(message.author.id)
|
||||
channel_id = self._channel_key(message.channel)
|
||||
content = message.content or ""
|
||||
|
||||
if not self._should_accept_inbound(message, sender_id, content):
|
||||
return
|
||||
|
||||
media_paths, attachment_markers = await self._download_attachments(message.attachments)
|
||||
full_content = self._compose_inbound_content(content, attachment_markers)
|
||||
metadata = self._build_inbound_metadata(message)
|
||||
|
||||
await self._start_typing(message.channel)
|
||||
|
||||
# Add read receipt reaction immediately, working emoji after delay
|
||||
channel_id = self._channel_key(message.channel)
|
||||
try:
|
||||
await message.add_reaction(self.config.read_receipt_emoji)
|
||||
self._pending_reactions[channel_id] = message
|
||||
except Exception as e:
|
||||
logger.debug("Failed to add read receipt reaction: {}", e)
|
||||
|
||||
# Delayed working indicator (cosmetic — not tied to subagent lifecycle)
|
||||
async def _delayed_working_emoji() -> None:
|
||||
await asyncio.sleep(self.config.working_emoji_delay)
|
||||
try:
|
||||
response = await self._http.post(url, headers=headers, json=payload)
|
||||
if response.status_code == 429:
|
||||
data = response.json()
|
||||
retry_after = float(data.get("retry_after", 1.0))
|
||||
logger.warning("Discord rate limited, retrying in {}s", retry_after)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
response.raise_for_status()
|
||||
return True
|
||||
except Exception as e:
|
||||
if attempt == 2:
|
||||
logger.error("Error sending Discord message: {}", e)
|
||||
else:
|
||||
await asyncio.sleep(1)
|
||||
return False
|
||||
await message.add_reaction(self.config.working_emoji)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _send_file(
|
||||
self._working_emoji_tasks[channel_id] = asyncio.create_task(_delayed_working_emoji())
|
||||
|
||||
try:
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=channel_id,
|
||||
content=full_content,
|
||||
media=media_paths,
|
||||
metadata=metadata,
|
||||
)
|
||||
except Exception:
|
||||
await self._clear_reactions(channel_id)
|
||||
await self._stop_typing(channel_id)
|
||||
raise
|
||||
|
||||
async def _on_message(self, message: discord.Message) -> None:
|
||||
"""Backward-compatible alias for legacy tests/callers."""
|
||||
await self._handle_discord_message(message)
|
||||
|
||||
def _should_accept_inbound(
|
||||
self,
|
||||
url: str,
|
||||
headers: dict[str, str],
|
||||
file_path: str,
|
||||
reply_to: str | None = None,
|
||||
message: discord.Message,
|
||||
sender_id: str,
|
||||
content: str,
|
||||
) -> bool:
|
||||
"""Send a file attachment via Discord REST API using multipart/form-data."""
|
||||
path = Path(file_path)
|
||||
if not path.is_file():
|
||||
logger.warning("Discord file not found, skipping: {}", file_path)
|
||||
return False
|
||||
|
||||
if path.stat().st_size > MAX_ATTACHMENT_BYTES:
|
||||
logger.warning("Discord file too large (>20MB), skipping: {}", path.name)
|
||||
return False
|
||||
|
||||
payload_json: dict[str, Any] = {}
|
||||
if reply_to:
|
||||
payload_json["message_reference"] = {"message_id": reply_to}
|
||||
payload_json["allowed_mentions"] = {"replied_user": False}
|
||||
|
||||
for attempt in range(3):
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
files = {"files[0]": (path.name, f, "application/octet-stream")}
|
||||
data: dict[str, Any] = {}
|
||||
if payload_json:
|
||||
data["payload_json"] = json.dumps(payload_json)
|
||||
response = await self._http.post(
|
||||
url, headers=headers, files=files, data=data
|
||||
)
|
||||
if response.status_code == 429:
|
||||
resp_data = response.json()
|
||||
retry_after = float(resp_data.get("retry_after", 1.0))
|
||||
logger.warning("Discord rate limited, retrying in {}s", retry_after)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
response.raise_for_status()
|
||||
logger.info("Discord file sent: {}", path.name)
|
||||
return True
|
||||
except Exception as e:
|
||||
if attempt == 2:
|
||||
logger.error("Error sending Discord file {}: {}", path.name, e)
|
||||
else:
|
||||
await asyncio.sleep(1)
|
||||
return False
|
||||
|
||||
async def _gateway_loop(self) -> None:
|
||||
"""Main gateway loop: identify, heartbeat, dispatch events."""
|
||||
if not self._ws:
|
||||
return
|
||||
|
||||
async for raw in self._ws:
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid JSON from Discord gateway: {}", raw[:100])
|
||||
continue
|
||||
|
||||
op = data.get("op")
|
||||
event_type = data.get("t")
|
||||
seq = data.get("s")
|
||||
payload = data.get("d")
|
||||
|
||||
if seq is not None:
|
||||
self._seq = seq
|
||||
|
||||
if op == 10:
|
||||
# HELLO: start heartbeat and identify
|
||||
interval_ms = payload.get("heartbeat_interval", 45000)
|
||||
await self._start_heartbeat(interval_ms / 1000)
|
||||
await self._identify()
|
||||
elif op == 0 and event_type == "READY":
|
||||
logger.info("Discord gateway READY")
|
||||
# Capture bot user ID for mention detection
|
||||
user_data = payload.get("user") or {}
|
||||
self._bot_user_id = user_data.get("id")
|
||||
logger.info("Discord bot connected as user {}", self._bot_user_id)
|
||||
elif op == 0 and event_type == "MESSAGE_CREATE":
|
||||
await self._handle_message_create(payload)
|
||||
elif op == 7:
|
||||
# RECONNECT: exit loop to reconnect
|
||||
logger.info("Discord gateway requested reconnect")
|
||||
break
|
||||
elif op == 9:
|
||||
# INVALID_SESSION: reconnect
|
||||
logger.warning("Discord gateway invalid session")
|
||||
break
|
||||
|
||||
async def _identify(self) -> None:
|
||||
"""Send IDENTIFY payload."""
|
||||
if not self._ws:
|
||||
return
|
||||
|
||||
identify = {
|
||||
"op": 2,
|
||||
"d": {
|
||||
"token": self.config.token,
|
||||
"intents": self.config.intents,
|
||||
"properties": {
|
||||
"os": "nanobot",
|
||||
"browser": "nanobot",
|
||||
"device": "nanobot",
|
||||
},
|
||||
},
|
||||
}
|
||||
await self._ws.send(json.dumps(identify))
|
||||
|
||||
async def _start_heartbeat(self, interval_s: float) -> None:
|
||||
"""Start or restart the heartbeat loop."""
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
|
||||
async def heartbeat_loop() -> None:
|
||||
while self._running and self._ws:
|
||||
payload = {"op": 1, "d": self._seq}
|
||||
try:
|
||||
await self._ws.send(json.dumps(payload))
|
||||
except Exception as e:
|
||||
logger.warning("Discord heartbeat failed: {}", e)
|
||||
break
|
||||
await asyncio.sleep(interval_s)
|
||||
|
||||
self._heartbeat_task = asyncio.create_task(heartbeat_loop())
|
||||
|
||||
async def _handle_message_create(self, payload: dict[str, Any]) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
author = payload.get("author") or {}
|
||||
if author.get("bot"):
|
||||
return
|
||||
|
||||
sender_id = str(author.get("id", ""))
|
||||
channel_id = str(payload.get("channel_id", ""))
|
||||
content = payload.get("content") or ""
|
||||
guild_id = payload.get("guild_id")
|
||||
|
||||
if not sender_id or not channel_id:
|
||||
return
|
||||
|
||||
"""Check if inbound Discord message should be processed."""
|
||||
if not self.is_allowed(sender_id):
|
||||
return
|
||||
return False
|
||||
if message.guild is not None and not self._should_respond_in_group(message, content):
|
||||
return False
|
||||
return True
|
||||
|
||||
# Check group channel policy (DMs always respond if is_allowed passes)
|
||||
if guild_id is not None:
|
||||
if not self._should_respond_in_group(payload, content):
|
||||
return
|
||||
|
||||
content_parts = [content] if content else []
|
||||
async def _download_attachments(
|
||||
self,
|
||||
attachments: list[discord.Attachment],
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Download supported attachments and return paths + display markers."""
|
||||
media_paths: list[str] = []
|
||||
markers: list[str] = []
|
||||
media_dir = get_media_dir("discord")
|
||||
|
||||
for attachment in payload.get("attachments") or []:
|
||||
url = attachment.get("url")
|
||||
filename = attachment.get("filename") or "attachment"
|
||||
size = attachment.get("size") or 0
|
||||
if not url or not self._http:
|
||||
continue
|
||||
if size and size > MAX_ATTACHMENT_BYTES:
|
||||
content_parts.append(f"[attachment: {filename} - too large]")
|
||||
for attachment in attachments:
|
||||
filename = attachment.filename or "attachment"
|
||||
if attachment.size and attachment.size > MAX_ATTACHMENT_BYTES:
|
||||
markers.append(f"[attachment: {filename} - too large]")
|
||||
continue
|
||||
try:
|
||||
media_dir.mkdir(parents=True, exist_ok=True)
|
||||
file_path = media_dir / f"{attachment.get('id', 'file')}_{filename.replace('/', '_')}"
|
||||
resp = await self._http.get(url)
|
||||
resp.raise_for_status()
|
||||
file_path.write_bytes(resp.content)
|
||||
safe_name = safe_filename(filename)
|
||||
file_path = media_dir / f"{attachment.id}_{safe_name}"
|
||||
await attachment.save(file_path)
|
||||
media_paths.append(str(file_path))
|
||||
content_parts.append(f"[attachment: {file_path}]")
|
||||
markers.append(f"[attachment: {file_path.name}]")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to download Discord attachment: {}", e)
|
||||
content_parts.append(f"[attachment: {filename} - download failed]")
|
||||
markers.append(f"[attachment: {filename} - download failed]")
|
||||
|
||||
reply_to = (payload.get("referenced_message") or {}).get("id")
|
||||
return media_paths, markers
|
||||
|
||||
await self._start_typing(channel_id)
|
||||
@staticmethod
|
||||
def _compose_inbound_content(content: str, attachment_markers: list[str]) -> str:
|
||||
"""Combine message text with attachment markers."""
|
||||
content_parts = [content] if content else []
|
||||
content_parts.extend(attachment_markers)
|
||||
return "\n".join(part for part in content_parts if part) or "[empty message]"
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=channel_id,
|
||||
content="\n".join(p for p in content_parts if p) or "[empty message]",
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"message_id": str(payload.get("id", "")),
|
||||
"guild_id": guild_id,
|
||||
"reply_to": reply_to,
|
||||
},
|
||||
)
|
||||
@staticmethod
|
||||
def _build_inbound_metadata(message: discord.Message) -> dict[str, str | None]:
|
||||
"""Build metadata for inbound Discord messages."""
|
||||
reply_to = str(message.reference.message_id) if message.reference and message.reference.message_id else None
|
||||
return {
|
||||
"message_id": str(message.id),
|
||||
"guild_id": str(message.guild.id) if message.guild else None,
|
||||
"reply_to": reply_to,
|
||||
}
|
||||
|
||||
def _should_respond_in_group(self, payload: dict[str, Any], content: str) -> bool:
|
||||
"""Check if bot should respond in a group channel based on policy."""
|
||||
def _should_respond_in_group(self, message: discord.Message, content: str) -> bool:
|
||||
"""Check if the bot should respond in a guild channel based on policy."""
|
||||
if self.config.group_policy == "open":
|
||||
return True
|
||||
|
||||
if self.config.group_policy == "mention":
|
||||
# Check if bot was mentioned in the message
|
||||
if self._bot_user_id:
|
||||
# Check mentions array
|
||||
mentions = payload.get("mentions") or []
|
||||
for mention in mentions:
|
||||
if str(mention.get("id")) == self._bot_user_id:
|
||||
return True
|
||||
# Also check content for mention format <@USER_ID>
|
||||
if f"<@{self._bot_user_id}>" in content or f"<@!{self._bot_user_id}>" in content:
|
||||
return True
|
||||
logger.debug("Discord message in {} ignored (bot not mentioned)", payload.get("channel_id"))
|
||||
bot_user_id = self._bot_user_id
|
||||
if bot_user_id is None:
|
||||
logger.debug("Discord message in {} ignored (bot identity unavailable)", message.channel.id)
|
||||
return False
|
||||
|
||||
if any(str(user.id) == bot_user_id for user in message.mentions):
|
||||
return True
|
||||
if f"<@{bot_user_id}>" in content or f"<@!{bot_user_id}>" in content:
|
||||
return True
|
||||
|
||||
logger.debug("Discord message in {} ignored (bot not mentioned)", message.channel.id)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _start_typing(self, channel_id: str) -> None:
|
||||
async def _start_typing(self, channel: Messageable) -> None:
|
||||
"""Start periodic typing indicator for a channel."""
|
||||
channel_id = self._channel_key(channel)
|
||||
await self._stop_typing(channel_id)
|
||||
|
||||
async def typing_loop() -> None:
|
||||
url = f"{DISCORD_API_BASE}/channels/{channel_id}/typing"
|
||||
headers = {"Authorization": f"Bot {self.config.token}"}
|
||||
while self._running:
|
||||
try:
|
||||
await self._http.post(url, headers=headers)
|
||||
async with channel.typing():
|
||||
await asyncio.sleep(TYPING_INTERVAL_S)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug("Discord typing indicator failed for {}: {}", channel_id, e)
|
||||
return
|
||||
await asyncio.sleep(8)
|
||||
|
||||
self._typing_tasks[channel_id] = asyncio.create_task(typing_loop())
|
||||
|
||||
async def _stop_typing(self, channel_id: str) -> None:
|
||||
"""Stop typing indicator for a channel."""
|
||||
task = self._typing_tasks.pop(channel_id, None)
|
||||
if task:
|
||||
task = self._typing_tasks.pop(self._channel_key(channel_id), None)
|
||||
if task is None:
|
||||
return
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
async def _clear_reactions(self, chat_id: str) -> None:
|
||||
"""Remove all pending reactions after bot replies."""
|
||||
# Cancel delayed working emoji if it hasn't fired yet
|
||||
task = self._working_emoji_tasks.pop(chat_id, None)
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
|
||||
msg_obj = self._pending_reactions.pop(chat_id, None)
|
||||
if msg_obj is None:
|
||||
return
|
||||
bot_user = self._client.user if self._client else None
|
||||
for emoji in (self.config.read_receipt_emoji, self.config.working_emoji):
|
||||
try:
|
||||
await msg_obj.remove_reaction(emoji, bot_user)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _cancel_all_typing(self) -> None:
|
||||
"""Stop all typing tasks."""
|
||||
channel_ids = list(self._typing_tasks)
|
||||
for channel_id in channel_ids:
|
||||
await self._stop_typing(channel_id)
|
||||
|
||||
async def _reset_runtime_state(self, close_client: bool) -> None:
|
||||
"""Reset client and typing state."""
|
||||
await self._cancel_all_typing()
|
||||
if close_client and self._client is not None and not self._client.is_closed():
|
||||
try:
|
||||
await self._client.close()
|
||||
except Exception as e:
|
||||
logger.warning("Discord client close failed: {}", e)
|
||||
self._client = None
|
||||
self._bot_user_id = None
|
||||
|
||||
@ -3,6 +3,8 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, TypeAlias
|
||||
|
||||
@ -28,8 +30,8 @@ try:
|
||||
RoomSendError,
|
||||
RoomTypingError,
|
||||
SyncError,
|
||||
UploadError,
|
||||
)
|
||||
UploadError, RoomSendResponse,
|
||||
)
|
||||
from nio.crypto.attachments import decrypt_attachment
|
||||
from nio.exceptions import EncryptionError
|
||||
except ImportError as e:
|
||||
@ -97,6 +99,22 @@ MATRIX_HTML_CLEANER = nh3.Cleaner(
|
||||
link_rel="noopener noreferrer",
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class _StreamBuf:
|
||||
"""
|
||||
Represents a buffer for managing LLM response stream data.
|
||||
|
||||
:ivar text: Stores the text content of the buffer.
|
||||
:type text: str
|
||||
:ivar event_id: Identifier for the associated event. None indicates no
|
||||
specific event association.
|
||||
:type event_id: str | None
|
||||
:ivar last_edit: Timestamp of the most recent edit to the buffer.
|
||||
:type last_edit: float
|
||||
"""
|
||||
text: str = ""
|
||||
event_id: str | None = None
|
||||
last_edit: float = 0.0
|
||||
|
||||
def _render_markdown_html(text: str) -> str | None:
|
||||
"""Render markdown to sanitized HTML; returns None for plain text."""
|
||||
@ -114,12 +132,47 @@ def _render_markdown_html(text: str) -> str | None:
|
||||
return formatted
|
||||
|
||||
|
||||
def _build_matrix_text_content(text: str) -> dict[str, object]:
|
||||
"""Build Matrix m.text payload with optional HTML formatted_body."""
|
||||
def _build_matrix_text_content(
|
||||
text: str,
|
||||
event_id: str | None = None,
|
||||
thread_relates_to: dict[str, object] | None = None,
|
||||
) -> dict[str, object]:
|
||||
"""
|
||||
Constructs and returns a dictionary representing the matrix text content with optional
|
||||
HTML formatting and reference to an existing event for replacement. This function is
|
||||
primarily used to create content payloads compatible with the Matrix messaging protocol.
|
||||
|
||||
:param text: The plain text content to include in the message.
|
||||
:type text: str
|
||||
:param event_id: Optional ID of the event to replace. If provided, the function will
|
||||
include information indicating that the message is a replacement of the specified
|
||||
event.
|
||||
:type event_id: str | None
|
||||
:param thread_relates_to: Optional Matrix thread relation metadata. For edits this is
|
||||
stored in ``m.new_content`` so the replacement remains in the same thread.
|
||||
:type thread_relates_to: dict[str, object] | None
|
||||
:return: A dictionary containing the matrix text content, potentially enriched with
|
||||
HTML formatting and replacement metadata if applicable.
|
||||
:rtype: dict[str, object]
|
||||
"""
|
||||
content: dict[str, object] = {"msgtype": "m.text", "body": text, "m.mentions": {}}
|
||||
if html := _render_markdown_html(text):
|
||||
content["format"] = MATRIX_HTML_FORMAT
|
||||
content["formatted_body"] = html
|
||||
if event_id:
|
||||
content["m.new_content"] = {
|
||||
"body": text,
|
||||
"msgtype": "m.text",
|
||||
}
|
||||
content["m.relates_to"] = {
|
||||
"rel_type": "m.replace",
|
||||
"event_id": event_id,
|
||||
}
|
||||
if thread_relates_to:
|
||||
content["m.new_content"]["m.relates_to"] = thread_relates_to
|
||||
elif thread_relates_to:
|
||||
content["m.relates_to"] = thread_relates_to
|
||||
|
||||
return content
|
||||
|
||||
|
||||
@ -159,7 +212,8 @@ class MatrixConfig(Base):
|
||||
allow_from: list[str] = Field(default_factory=list)
|
||||
group_policy: Literal["open", "mention", "allowlist"] = "open"
|
||||
group_allow_from: list[str] = Field(default_factory=list)
|
||||
allow_room_mentions: bool = False
|
||||
allow_room_mentions: bool = False,
|
||||
streaming: bool = False
|
||||
|
||||
|
||||
class MatrixChannel(BaseChannel):
|
||||
@ -167,6 +221,8 @@ class MatrixChannel(BaseChannel):
|
||||
|
||||
name = "matrix"
|
||||
display_name = "Matrix"
|
||||
_STREAM_EDIT_INTERVAL = 2 # min seconds between edit_message_text calls
|
||||
monotonic_time = time.monotonic
|
||||
|
||||
@classmethod
|
||||
def default_config(cls) -> dict[str, Any]:
|
||||
@ -192,6 +248,8 @@ class MatrixChannel(BaseChannel):
|
||||
)
|
||||
self._server_upload_limit_bytes: int | None = None
|
||||
self._server_upload_limit_checked = False
|
||||
self._stream_bufs: dict[str, _StreamBuf] = {}
|
||||
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start Matrix client and begin sync loop."""
|
||||
@ -297,14 +355,17 @@ class MatrixChannel(BaseChannel):
|
||||
room = getattr(self.client, "rooms", {}).get(room_id)
|
||||
return bool(getattr(room, "encrypted", False))
|
||||
|
||||
async def _send_room_content(self, room_id: str, content: dict[str, Any]) -> None:
|
||||
async def _send_room_content(self, room_id: str,
|
||||
content: dict[str, Any]) -> None | RoomSendResponse | RoomSendError:
|
||||
"""Send m.room.message with E2EE options."""
|
||||
if not self.client:
|
||||
return
|
||||
return None
|
||||
kwargs: dict[str, Any] = {"room_id": room_id, "message_type": "m.room.message", "content": content}
|
||||
|
||||
if self.config.e2ee_enabled:
|
||||
kwargs["ignore_unverified_devices"] = True
|
||||
await self.client.room_send(**kwargs)
|
||||
response = await self.client.room_send(**kwargs)
|
||||
return response
|
||||
|
||||
async def _resolve_server_upload_limit_bytes(self) -> int | None:
|
||||
"""Query homeserver upload limit once per channel lifecycle."""
|
||||
@ -414,6 +475,53 @@ class MatrixChannel(BaseChannel):
|
||||
if not is_progress:
|
||||
await self._stop_typing_keepalive(msg.chat_id, clear_typing=True)
|
||||
|
||||
async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
|
||||
meta = metadata or {}
|
||||
relates_to = self._build_thread_relates_to(metadata)
|
||||
|
||||
if meta.get("_stream_end"):
|
||||
buf = self._stream_bufs.pop(chat_id, None)
|
||||
if not buf or not buf.event_id or not buf.text:
|
||||
return
|
||||
|
||||
await self._stop_typing_keepalive(chat_id, clear_typing=True)
|
||||
|
||||
content = _build_matrix_text_content(
|
||||
buf.text,
|
||||
buf.event_id,
|
||||
thread_relates_to=relates_to,
|
||||
)
|
||||
await self._send_room_content(chat_id, content)
|
||||
return
|
||||
|
||||
buf = self._stream_bufs.get(chat_id)
|
||||
if buf is None:
|
||||
buf = _StreamBuf()
|
||||
self._stream_bufs[chat_id] = buf
|
||||
buf.text += delta
|
||||
|
||||
if not buf.text.strip():
|
||||
return
|
||||
|
||||
now = self.monotonic_time()
|
||||
|
||||
if not buf.last_edit or (now - buf.last_edit) >= self._STREAM_EDIT_INTERVAL:
|
||||
try:
|
||||
content = _build_matrix_text_content(
|
||||
buf.text,
|
||||
buf.event_id,
|
||||
thread_relates_to=relates_to,
|
||||
)
|
||||
response = await self._send_room_content(chat_id, content)
|
||||
buf.last_edit = now
|
||||
if not buf.event_id:
|
||||
# we are editing the same message all the time, so only the first time the event id needs to be set
|
||||
buf.event_id = response.event_id
|
||||
except Exception:
|
||||
await self._stop_typing_keepalive(chat_id, clear_typing=True)
|
||||
pass
|
||||
|
||||
|
||||
def _register_event_callbacks(self) -> None:
|
||||
self.client.add_event_callback(self._on_message, RoomMessageText)
|
||||
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER)
|
||||
|
||||
@ -14,6 +14,7 @@ import base64
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
@ -52,7 +53,26 @@ MESSAGE_TYPE_BOT = 2
|
||||
MESSAGE_STATE_FINISH = 2
|
||||
|
||||
WEIXIN_MAX_MESSAGE_LEN = 4000
|
||||
WEIXIN_CHANNEL_VERSION = "1.0.3"
|
||||
WEIXIN_CHANNEL_VERSION = "2.1.1"
|
||||
ILINK_APP_ID = "bot"
|
||||
|
||||
|
||||
def _build_client_version(version: str) -> int:
|
||||
"""Encode semantic version as 0x00MMNNPP (major/minor/patch in one uint32)."""
|
||||
parts = version.split(".")
|
||||
|
||||
def _as_int(idx: int) -> int:
|
||||
try:
|
||||
return int(parts[idx])
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
major = _as_int(0)
|
||||
minor = _as_int(1)
|
||||
patch = _as_int(2)
|
||||
return ((major & 0xFF) << 16) | ((minor & 0xFF) << 8) | (patch & 0xFF)
|
||||
|
||||
ILINK_APP_CLIENT_VERSION = _build_client_version(WEIXIN_CHANNEL_VERSION)
|
||||
BASE_INFO: dict[str, str] = {"channel_version": WEIXIN_CHANNEL_VERSION}
|
||||
|
||||
# Session-expired error code
|
||||
@ -64,18 +84,32 @@ MAX_CONSECUTIVE_FAILURES = 3
|
||||
BACKOFF_DELAY_S = 30
|
||||
RETRY_DELAY_S = 2
|
||||
MAX_QR_REFRESH_COUNT = 3
|
||||
TYPING_STATUS_TYPING = 1
|
||||
TYPING_STATUS_CANCEL = 2
|
||||
TYPING_TICKET_TTL_S = 24 * 60 * 60
|
||||
TYPING_KEEPALIVE_INTERVAL_S = 5
|
||||
CONFIG_CACHE_INITIAL_RETRY_S = 2
|
||||
CONFIG_CACHE_MAX_RETRY_S = 60 * 60
|
||||
|
||||
# Default long-poll timeout; overridden by server via longpolling_timeout_ms.
|
||||
DEFAULT_LONG_POLL_TIMEOUT_S = 35
|
||||
|
||||
# Media-type codes for getuploadurl (1=image, 2=video, 3=file)
|
||||
# Media-type codes for getuploadurl (1=image, 2=video, 3=file, 4=voice)
|
||||
UPLOAD_MEDIA_IMAGE = 1
|
||||
UPLOAD_MEDIA_VIDEO = 2
|
||||
UPLOAD_MEDIA_FILE = 3
|
||||
UPLOAD_MEDIA_VOICE = 4
|
||||
|
||||
# File extensions considered as images / videos for outbound media
|
||||
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".tiff", ".ico", ".svg"}
|
||||
_VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv"}
|
||||
_VOICE_EXTS = {".mp3", ".wav", ".amr", ".silk", ".ogg", ".m4a", ".aac", ".flac"}
|
||||
|
||||
|
||||
def _has_downloadable_media_locator(media: dict[str, Any] | None) -> bool:
|
||||
if not isinstance(media, dict):
|
||||
return False
|
||||
return bool(str(media.get("encrypt_query_param", "") or "") or str(media.get("full_url", "") or "").strip())
|
||||
|
||||
|
||||
class WeixinConfig(Base):
|
||||
@ -124,7 +158,7 @@ class WeixinChannel(BaseChannel):
|
||||
self._next_poll_timeout_s: int = DEFAULT_LONG_POLL_TIMEOUT_S
|
||||
self._session_pause_until: float = 0.0
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._typing_tickets: dict[str, str] = {}
|
||||
self._typing_tickets: dict[str, dict[str, Any]] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# State persistence
|
||||
@ -162,9 +196,9 @@ class WeixinChannel(BaseChannel):
|
||||
typing_tickets = data.get("typing_tickets", {})
|
||||
if isinstance(typing_tickets, dict):
|
||||
self._typing_tickets = {
|
||||
str(user_id): str(ticket)
|
||||
str(user_id): ticket
|
||||
for user_id, ticket in typing_tickets.items()
|
||||
if str(user_id).strip() and str(ticket).strip()
|
||||
if str(user_id).strip() and isinstance(ticket, dict)
|
||||
}
|
||||
else:
|
||||
self._typing_tickets = {}
|
||||
@ -172,8 +206,7 @@ class WeixinChannel(BaseChannel):
|
||||
if base_url:
|
||||
self.config.base_url = base_url
|
||||
return bool(self._token)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load WeChat state: {}", e)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _save_state(self) -> None:
|
||||
@ -187,8 +220,8 @@ class WeixinChannel(BaseChannel):
|
||||
"base_url": self.config.base_url,
|
||||
}
|
||||
state_file.write_text(json.dumps(data, ensure_ascii=False))
|
||||
except Exception as e:
|
||||
logger.warning("Failed to save WeChat state: {}", e)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# HTTP helpers (matches api.ts buildHeaders / apiFetch)
|
||||
@ -210,6 +243,8 @@ class WeixinChannel(BaseChannel):
|
||||
"X-WECHAT-UIN": self._random_wechat_uin(),
|
||||
"Content-Type": "application/json",
|
||||
"AuthorizationType": "ilink_bot_token",
|
||||
"iLink-App-Id": ILINK_APP_ID,
|
||||
"iLink-App-ClientVersion": str(ILINK_APP_CLIENT_VERSION),
|
||||
}
|
||||
if auth and self._token:
|
||||
headers["Authorization"] = f"Bearer {self._token}"
|
||||
@ -217,6 +252,15 @@ class WeixinChannel(BaseChannel):
|
||||
headers["SKRouteTag"] = str(self.config.route_tag).strip()
|
||||
return headers
|
||||
|
||||
@staticmethod
|
||||
def _is_retryable_media_download_error(err: Exception) -> bool:
|
||||
if isinstance(err, httpx.TimeoutException | httpx.TransportError):
|
||||
return True
|
||||
if isinstance(err, httpx.HTTPStatusError):
|
||||
status_code = err.response.status_code if err.response is not None else 0
|
||||
return status_code >= 500
|
||||
return False
|
||||
|
||||
async def _api_get(
|
||||
self,
|
||||
endpoint: str,
|
||||
@ -234,6 +278,25 @@ class WeixinChannel(BaseChannel):
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
async def _api_get_with_base(
|
||||
self,
|
||||
*,
|
||||
base_url: str,
|
||||
endpoint: str,
|
||||
params: dict | None = None,
|
||||
auth: bool = True,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
) -> dict:
|
||||
"""GET helper that allows overriding base_url for QR redirect polling."""
|
||||
assert self._client is not None
|
||||
url = f"{base_url.rstrip('/')}/{endpoint}"
|
||||
hdrs = self._make_headers(auth=auth)
|
||||
if extra_headers:
|
||||
hdrs.update(extra_headers)
|
||||
resp = await self._client.get(url, params=params, headers=hdrs)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
async def _api_post(
|
||||
self,
|
||||
endpoint: str,
|
||||
@ -270,23 +333,27 @@ class WeixinChannel(BaseChannel):
|
||||
async def _qr_login(self) -> bool:
|
||||
"""Perform QR code login flow. Returns True on success."""
|
||||
try:
|
||||
logger.info("Starting WeChat QR code login...")
|
||||
refresh_count = 0
|
||||
qrcode_id, scan_url = await self._fetch_qr_code()
|
||||
self._print_qr_code(scan_url)
|
||||
current_poll_base_url = self.config.base_url
|
||||
|
||||
logger.info("Waiting for QR code scan...")
|
||||
while self._running:
|
||||
try:
|
||||
# Reference plugin sends iLink-App-ClientVersion header for
|
||||
# QR status polling (login-qr.ts:81).
|
||||
status_data = await self._api_get(
|
||||
"ilink/bot/get_qrcode_status",
|
||||
status_data = await self._api_get_with_base(
|
||||
base_url=current_poll_base_url,
|
||||
endpoint="ilink/bot/get_qrcode_status",
|
||||
params={"qrcode": qrcode_id},
|
||||
auth=False,
|
||||
extra_headers={"iLink-App-ClientVersion": "1"},
|
||||
)
|
||||
except httpx.TimeoutException:
|
||||
except Exception as e:
|
||||
if self._is_retryable_qr_poll_error(e):
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
raise
|
||||
|
||||
if not isinstance(status_data, dict):
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
status = status_data.get("status", "")
|
||||
@ -309,8 +376,15 @@ class WeixinChannel(BaseChannel):
|
||||
else:
|
||||
logger.error("Login confirmed but no bot_token in response")
|
||||
return False
|
||||
elif status == "scaned":
|
||||
logger.info("QR code scanned, waiting for confirmation...")
|
||||
elif status == "scaned_but_redirect":
|
||||
redirect_host = str(status_data.get("redirect_host", "") or "").strip()
|
||||
if redirect_host:
|
||||
if redirect_host.startswith("http://") or redirect_host.startswith("https://"):
|
||||
redirected_base = redirect_host
|
||||
else:
|
||||
redirected_base = f"https://{redirect_host}"
|
||||
if redirected_base != current_poll_base_url:
|
||||
current_poll_base_url = redirected_base
|
||||
elif status == "expired":
|
||||
refresh_count += 1
|
||||
if refresh_count > MAX_QR_REFRESH_COUNT:
|
||||
@ -320,14 +394,9 @@ class WeixinChannel(BaseChannel):
|
||||
MAX_QR_REFRESH_COUNT,
|
||||
)
|
||||
return False
|
||||
logger.warning(
|
||||
"QR code expired, refreshing... ({}/{})",
|
||||
refresh_count,
|
||||
MAX_QR_REFRESH_COUNT,
|
||||
)
|
||||
qrcode_id, scan_url = await self._fetch_qr_code()
|
||||
current_poll_base_url = self.config.base_url
|
||||
self._print_qr_code(scan_url)
|
||||
logger.info("New QR code generated, waiting for scan...")
|
||||
continue
|
||||
# status == "wait" — keep polling
|
||||
|
||||
@ -338,6 +407,16 @@ class WeixinChannel(BaseChannel):
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _is_retryable_qr_poll_error(err: Exception) -> bool:
|
||||
if isinstance(err, httpx.TimeoutException | httpx.TransportError):
|
||||
return True
|
||||
if isinstance(err, httpx.HTTPStatusError):
|
||||
status_code = err.response.status_code if err.response is not None else 0
|
||||
if status_code >= 500:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _print_qr_code(url: str) -> None:
|
||||
try:
|
||||
@ -348,7 +427,6 @@ class WeixinChannel(BaseChannel):
|
||||
qr.make(fit=True)
|
||||
qr.print_ascii(invert=True)
|
||||
except ImportError:
|
||||
logger.info("QR code URL (install 'qrcode' for terminal display): {}", url)
|
||||
print(f"\nLogin URL: {url}\n")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@ -410,12 +488,6 @@ class WeixinChannel(BaseChannel):
|
||||
if not self._running:
|
||||
break
|
||||
consecutive_failures += 1
|
||||
logger.error(
|
||||
"WeChat poll error ({}/{}): {}",
|
||||
consecutive_failures,
|
||||
MAX_CONSECUTIVE_FAILURES,
|
||||
e,
|
||||
)
|
||||
if consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
|
||||
consecutive_failures = 0
|
||||
await asyncio.sleep(BACKOFF_DELAY_S)
|
||||
@ -432,8 +504,6 @@ class WeixinChannel(BaseChannel):
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
self._save_state()
|
||||
logger.info("WeChat channel stopped")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Polling (matches monitor.ts monitorWeixinProvider)
|
||||
# ------------------------------------------------------------------
|
||||
@ -459,10 +529,6 @@ class WeixinChannel(BaseChannel):
|
||||
async def _poll_once(self) -> None:
|
||||
remaining = self._session_pause_remaining_s()
|
||||
if remaining > 0:
|
||||
logger.warning(
|
||||
"WeChat session paused, waiting {} min before next poll.",
|
||||
max((remaining + 59) // 60, 1),
|
||||
)
|
||||
await asyncio.sleep(remaining)
|
||||
return
|
||||
|
||||
@ -512,8 +578,8 @@ class WeixinChannel(BaseChannel):
|
||||
for msg in msgs:
|
||||
try:
|
||||
await self._process_message(msg)
|
||||
except Exception as e:
|
||||
logger.error("Error processing WeChat message: {}", e)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inbound message processing (matches inbound.ts + process-message.ts)
|
||||
@ -549,6 +615,7 @@ class WeixinChannel(BaseChannel):
|
||||
item_list: list[dict] = msg.get("item_list") or []
|
||||
content_parts: list[str] = []
|
||||
media_paths: list[str] = []
|
||||
has_top_level_downloadable_media = False
|
||||
|
||||
for item in item_list:
|
||||
item_type = item.get("type", 0)
|
||||
@ -585,6 +652,8 @@ class WeixinChannel(BaseChannel):
|
||||
|
||||
elif item_type == ITEM_IMAGE:
|
||||
image_item = item.get("image_item") or {}
|
||||
if _has_downloadable_media_locator(image_item.get("media")):
|
||||
has_top_level_downloadable_media = True
|
||||
file_path = await self._download_media_item(image_item, "image")
|
||||
if file_path:
|
||||
content_parts.append(f"[image]\n[Image: source: {file_path}]")
|
||||
@ -599,6 +668,8 @@ class WeixinChannel(BaseChannel):
|
||||
if voice_text:
|
||||
content_parts.append(f"[voice] {voice_text}")
|
||||
else:
|
||||
if _has_downloadable_media_locator(voice_item.get("media")):
|
||||
has_top_level_downloadable_media = True
|
||||
file_path = await self._download_media_item(voice_item, "voice")
|
||||
if file_path:
|
||||
transcription = await self.transcribe_audio(file_path)
|
||||
@ -612,6 +683,8 @@ class WeixinChannel(BaseChannel):
|
||||
|
||||
elif item_type == ITEM_FILE:
|
||||
file_item = item.get("file_item") or {}
|
||||
if _has_downloadable_media_locator(file_item.get("media")):
|
||||
has_top_level_downloadable_media = True
|
||||
file_name = file_item.get("file_name", "unknown")
|
||||
file_path = await self._download_media_item(
|
||||
file_item,
|
||||
@ -626,6 +699,8 @@ class WeixinChannel(BaseChannel):
|
||||
|
||||
elif item_type == ITEM_VIDEO:
|
||||
video_item = item.get("video_item") or {}
|
||||
if _has_downloadable_media_locator(video_item.get("media")):
|
||||
has_top_level_downloadable_media = True
|
||||
file_path = await self._download_media_item(video_item, "video")
|
||||
if file_path:
|
||||
content_parts.append(f"[video]\n[Video: source: {file_path}]")
|
||||
@ -633,6 +708,52 @@ class WeixinChannel(BaseChannel):
|
||||
else:
|
||||
content_parts.append("[video]")
|
||||
|
||||
# Fallback: when no top-level media was downloaded, try quoted/referenced media.
|
||||
# This aligns with the reference plugin behavior that checks ref_msg.message_item
|
||||
# when main item_list has no downloadable media.
|
||||
if not media_paths and not has_top_level_downloadable_media:
|
||||
ref_media_item: dict[str, Any] | None = None
|
||||
for item in item_list:
|
||||
if item.get("type", 0) != ITEM_TEXT:
|
||||
continue
|
||||
ref = item.get("ref_msg") or {}
|
||||
candidate = ref.get("message_item") or {}
|
||||
if candidate.get("type", 0) in (ITEM_IMAGE, ITEM_VOICE, ITEM_FILE, ITEM_VIDEO):
|
||||
ref_media_item = candidate
|
||||
break
|
||||
|
||||
if ref_media_item:
|
||||
ref_type = ref_media_item.get("type", 0)
|
||||
if ref_type == ITEM_IMAGE:
|
||||
image_item = ref_media_item.get("image_item") or {}
|
||||
file_path = await self._download_media_item(image_item, "image")
|
||||
if file_path:
|
||||
content_parts.append(f"[image]\n[Image: source: {file_path}]")
|
||||
media_paths.append(file_path)
|
||||
elif ref_type == ITEM_VOICE:
|
||||
voice_item = ref_media_item.get("voice_item") or {}
|
||||
file_path = await self._download_media_item(voice_item, "voice")
|
||||
if file_path:
|
||||
transcription = await self.transcribe_audio(file_path)
|
||||
if transcription:
|
||||
content_parts.append(f"[voice] {transcription}")
|
||||
else:
|
||||
content_parts.append(f"[voice]\n[Audio: source: {file_path}]")
|
||||
media_paths.append(file_path)
|
||||
elif ref_type == ITEM_FILE:
|
||||
file_item = ref_media_item.get("file_item") or {}
|
||||
file_name = file_item.get("file_name", "unknown")
|
||||
file_path = await self._download_media_item(file_item, "file", file_name)
|
||||
if file_path:
|
||||
content_parts.append(f"[file: {file_name}]\n[File: source: {file_path}]")
|
||||
media_paths.append(file_path)
|
||||
elif ref_type == ITEM_VIDEO:
|
||||
video_item = ref_media_item.get("video_item") or {}
|
||||
file_path = await self._download_media_item(video_item, "video")
|
||||
if file_path:
|
||||
content_parts.append(f"[video]\n[Video: source: {file_path}]")
|
||||
media_paths.append(file_path)
|
||||
|
||||
content = "\n".join(content_parts)
|
||||
if not content:
|
||||
return
|
||||
@ -667,9 +788,10 @@ class WeixinChannel(BaseChannel):
|
||||
"""Download + AES-decrypt a media item. Returns local path or None."""
|
||||
try:
|
||||
media = typed_item.get("media") or {}
|
||||
encrypt_query_param = media.get("encrypt_query_param", "")
|
||||
encrypt_query_param = str(media.get("encrypt_query_param", "") or "")
|
||||
full_url = str(media.get("full_url", "") or "").strip()
|
||||
|
||||
if not encrypt_query_param:
|
||||
if not encrypt_query_param and not full_url:
|
||||
return None
|
||||
|
||||
# Resolve AES key (media-download.ts:43-45, pic-decrypt.ts:40-52)
|
||||
@ -686,21 +808,50 @@ class WeixinChannel(BaseChannel):
|
||||
elif media_aes_key_b64:
|
||||
aes_key_b64 = media_aes_key_b64
|
||||
|
||||
# Build CDN download URL with proper URL-encoding (cdn-url.ts:7)
|
||||
cdn_url = (
|
||||
f"{self.config.cdn_base_url}/download"
|
||||
f"?encrypted_query_param={quote(encrypt_query_param)}"
|
||||
)
|
||||
# Reference protocol behavior: VOICE/FILE/VIDEO require aes_key;
|
||||
# only IMAGE may be downloaded as plain bytes when key is missing.
|
||||
if media_type != "image" and not aes_key_b64:
|
||||
return None
|
||||
|
||||
assert self._client is not None
|
||||
resp = await self._client.get(cdn_url)
|
||||
resp.raise_for_status()
|
||||
data = resp.content
|
||||
fallback_url = ""
|
||||
if encrypt_query_param:
|
||||
fallback_url = (
|
||||
f"{self.config.cdn_base_url}/download"
|
||||
f"?encrypted_query_param={quote(encrypt_query_param)}"
|
||||
)
|
||||
|
||||
download_candidates: list[tuple[str, str]] = []
|
||||
if full_url:
|
||||
download_candidates.append(("full_url", full_url))
|
||||
if fallback_url and (not full_url or fallback_url != full_url):
|
||||
download_candidates.append(("encrypt_query_param", fallback_url))
|
||||
|
||||
data = b""
|
||||
for idx, (download_source, cdn_url) in enumerate(download_candidates):
|
||||
try:
|
||||
resp = await self._client.get(cdn_url)
|
||||
resp.raise_for_status()
|
||||
data = resp.content
|
||||
break
|
||||
except Exception as e:
|
||||
has_more_candidates = idx + 1 < len(download_candidates)
|
||||
should_fallback = (
|
||||
download_source == "full_url"
|
||||
and has_more_candidates
|
||||
and self._is_retryable_media_download_error(e)
|
||||
)
|
||||
if should_fallback:
|
||||
logger.warning(
|
||||
"WeChat media download failed via full_url, falling back to encrypt_query_param: type={} err={}",
|
||||
media_type,
|
||||
e,
|
||||
)
|
||||
continue
|
||||
raise
|
||||
|
||||
if aes_key_b64 and data:
|
||||
data = _decrypt_aes_ecb(data, aes_key_b64)
|
||||
elif not aes_key_b64:
|
||||
logger.debug("No AES key for {} item, using raw bytes", media_type)
|
||||
|
||||
if not data:
|
||||
return None
|
||||
@ -709,12 +860,12 @@ class WeixinChannel(BaseChannel):
|
||||
ext = _ext_for_type(media_type)
|
||||
if not filename:
|
||||
ts = int(time.time())
|
||||
h = abs(hash(encrypt_query_param)) % 100000
|
||||
hash_seed = encrypt_query_param or full_url
|
||||
h = abs(hash(hash_seed)) % 100000
|
||||
filename = f"{media_type}_{ts}_{h}{ext}"
|
||||
safe_name = os.path.basename(filename)
|
||||
file_path = media_dir / safe_name
|
||||
file_path.write_bytes(data)
|
||||
logger.debug("Downloaded WeChat {} to {}", media_type, file_path)
|
||||
return str(file_path)
|
||||
|
||||
except Exception as e:
|
||||
@ -725,14 +876,76 @@ class WeixinChannel(BaseChannel):
|
||||
# Outbound (matches send.ts buildTextMessageReq + sendMessageWeixin)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _get_typing_ticket(self, user_id: str, context_token: str = "") -> str:
|
||||
"""Get typing ticket with per-user refresh + failure backoff cache."""
|
||||
now = time.time()
|
||||
entry = self._typing_tickets.get(user_id)
|
||||
if entry and now < float(entry.get("next_fetch_at", 0)):
|
||||
return str(entry.get("ticket", "") or "")
|
||||
|
||||
body: dict[str, Any] = {
|
||||
"ilink_user_id": user_id,
|
||||
"context_token": context_token or None,
|
||||
"base_info": BASE_INFO,
|
||||
}
|
||||
data = await self._api_post("ilink/bot/getconfig", body)
|
||||
if data.get("ret", 0) == 0:
|
||||
ticket = str(data.get("typing_ticket", "") or "")
|
||||
self._typing_tickets[user_id] = {
|
||||
"ticket": ticket,
|
||||
"ever_succeeded": True,
|
||||
"next_fetch_at": now + (random.random() * TYPING_TICKET_TTL_S),
|
||||
"retry_delay_s": CONFIG_CACHE_INITIAL_RETRY_S,
|
||||
}
|
||||
return ticket
|
||||
|
||||
prev_delay = float(entry.get("retry_delay_s", CONFIG_CACHE_INITIAL_RETRY_S)) if entry else CONFIG_CACHE_INITIAL_RETRY_S
|
||||
next_delay = min(prev_delay * 2, CONFIG_CACHE_MAX_RETRY_S)
|
||||
if entry:
|
||||
entry["next_fetch_at"] = now + next_delay
|
||||
entry["retry_delay_s"] = next_delay
|
||||
return str(entry.get("ticket", "") or "")
|
||||
|
||||
self._typing_tickets[user_id] = {
|
||||
"ticket": "",
|
||||
"ever_succeeded": False,
|
||||
"next_fetch_at": now + CONFIG_CACHE_INITIAL_RETRY_S,
|
||||
"retry_delay_s": CONFIG_CACHE_INITIAL_RETRY_S,
|
||||
}
|
||||
return ""
|
||||
|
||||
async def _send_typing(self, user_id: str, typing_ticket: str, status: int) -> None:
|
||||
"""Best-effort sendtyping wrapper."""
|
||||
if not typing_ticket:
|
||||
return
|
||||
body: dict[str, Any] = {
|
||||
"ilink_user_id": user_id,
|
||||
"typing_ticket": typing_ticket,
|
||||
"status": status,
|
||||
"base_info": BASE_INFO,
|
||||
}
|
||||
await self._api_post("ilink/bot/sendtyping", body)
|
||||
|
||||
async def _typing_keepalive_loop(self, user_id: str, typing_ticket: str, stop_event: asyncio.Event) -> None:
|
||||
try:
|
||||
while not stop_event.is_set():
|
||||
await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_S)
|
||||
if stop_event.is_set():
|
||||
break
|
||||
try:
|
||||
await self._send_typing(user_id, typing_ticket, TYPING_STATUS_TYPING)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
pass
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
if not self._client or not self._token:
|
||||
logger.warning("WeChat client not initialized or not authenticated")
|
||||
return
|
||||
try:
|
||||
self._assert_session_active()
|
||||
except RuntimeError as e:
|
||||
logger.warning("WeChat send blocked: {}", e)
|
||||
except RuntimeError:
|
||||
return
|
||||
|
||||
is_progress = bool((msg.metadata or {}).get("_progress", False))
|
||||
@ -748,94 +961,103 @@ class WeixinChannel(BaseChannel):
|
||||
)
|
||||
return
|
||||
|
||||
# --- Send media files first (following Telegram channel pattern) ---
|
||||
for media_path in (msg.media or []):
|
||||
try:
|
||||
await self._send_media_file(msg.chat_id, media_path, ctx_token)
|
||||
except Exception as e:
|
||||
filename = Path(media_path).name
|
||||
logger.error("Failed to send WeChat media {}: {}", media_path, e)
|
||||
# Notify user about failure via text
|
||||
await self._send_text(
|
||||
msg.chat_id, f"[Failed to send: {filename}]", ctx_token,
|
||||
)
|
||||
typing_ticket = ""
|
||||
try:
|
||||
typing_ticket = await self._get_typing_ticket(msg.chat_id, ctx_token)
|
||||
except Exception:
|
||||
typing_ticket = ""
|
||||
|
||||
# --- Send text content ---
|
||||
if not content:
|
||||
return
|
||||
if typing_ticket:
|
||||
try:
|
||||
await self._send_typing(msg.chat_id, typing_ticket, TYPING_STATUS_TYPING)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
typing_keepalive_stop = asyncio.Event()
|
||||
typing_keepalive_task: asyncio.Task | None = None
|
||||
if typing_ticket:
|
||||
typing_keepalive_task = asyncio.create_task(
|
||||
self._typing_keepalive_loop(msg.chat_id, typing_ticket, typing_keepalive_stop)
|
||||
)
|
||||
|
||||
try:
|
||||
# --- Send media files first (following Telegram channel pattern) ---
|
||||
for media_path in (msg.media or []):
|
||||
try:
|
||||
await self._send_media_file(msg.chat_id, media_path, ctx_token)
|
||||
except Exception as e:
|
||||
filename = Path(media_path).name
|
||||
logger.error("Failed to send WeChat media {}: {}", media_path, e)
|
||||
# Notify user about failure via text
|
||||
await self._send_text(
|
||||
msg.chat_id, f"[Failed to send: {filename}]", ctx_token,
|
||||
)
|
||||
|
||||
# --- Send text content ---
|
||||
if not content:
|
||||
return
|
||||
|
||||
chunks = split_message(content, WEIXIN_MAX_MESSAGE_LEN)
|
||||
for chunk in chunks:
|
||||
await self._send_text(msg.chat_id, chunk, ctx_token)
|
||||
except Exception as e:
|
||||
logger.error("Error sending WeChat message: {}", e)
|
||||
raise
|
||||
finally:
|
||||
if typing_keepalive_task:
|
||||
typing_keepalive_stop.set()
|
||||
typing_keepalive_task.cancel()
|
||||
try:
|
||||
await typing_keepalive_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _get_typing_ticket(self, user_id: str, context_token: str) -> str:
|
||||
"""Fetch and cache typing ticket for a user/context pair."""
|
||||
if not self._client or not self._token or not user_id or not context_token:
|
||||
return ""
|
||||
cached = self._typing_tickets.get(user_id, "")
|
||||
if cached:
|
||||
return cached
|
||||
try:
|
||||
data = await self._api_post(
|
||||
"ilink/bot/getconfig",
|
||||
{
|
||||
"ilink_user_id": user_id,
|
||||
"context_token": context_token,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("WeChat getconfig failed for {}: {}", user_id, e)
|
||||
return ""
|
||||
ticket = str(data.get("typing_ticket") or "").strip()
|
||||
if ticket:
|
||||
self._typing_tickets[user_id] = ticket
|
||||
self._save_state()
|
||||
return ticket
|
||||
if typing_ticket and not is_progress:
|
||||
try:
|
||||
await self._send_typing(msg.chat_id, typing_ticket, TYPING_STATUS_CANCEL)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _send_typing_status(self, to_user_id: str, typing_ticket: str, status: int) -> None:
|
||||
if not typing_ticket:
|
||||
return
|
||||
await self._api_post(
|
||||
"ilink/bot/sendtyping",
|
||||
{
|
||||
"ilink_user_id": to_user_id,
|
||||
"typing_ticket": typing_ticket,
|
||||
"status": status,
|
||||
},
|
||||
)
|
||||
|
||||
async def _start_typing(self, chat_id: str, context_token: str) -> None:
|
||||
if not self._client or not self._token or not chat_id or not context_token:
|
||||
async def _start_typing(self, chat_id: str, context_token: str = "") -> None:
|
||||
"""Start typing indicator immediately when a message is received."""
|
||||
if not self._client or not self._token or not chat_id:
|
||||
return
|
||||
await self._stop_typing(chat_id, clear_remote=False)
|
||||
ticket = await self._get_typing_ticket(chat_id, context_token)
|
||||
if not ticket:
|
||||
return
|
||||
try:
|
||||
await self._send_typing_status(chat_id, ticket, 1)
|
||||
ticket = await self._get_typing_ticket(chat_id, context_token)
|
||||
if not ticket:
|
||||
return
|
||||
await self._send_typing(chat_id, ticket, TYPING_STATUS_TYPING)
|
||||
except Exception as e:
|
||||
logger.debug("WeChat typing indicator failed for {}: {}", chat_id, e)
|
||||
logger.debug("WeChat typing indicator start failed for {}: {}", chat_id, e)
|
||||
return
|
||||
|
||||
async def typing_loop() -> None:
|
||||
try:
|
||||
while self._running:
|
||||
await asyncio.sleep(5)
|
||||
await self._send_typing_status(chat_id, ticket, 1)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug("WeChat typing keepalive stopped for {}: {}", chat_id, e)
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
self._typing_tasks[chat_id] = asyncio.create_task(typing_loop())
|
||||
async def keepalive() -> None:
|
||||
try:
|
||||
while not stop_event.is_set():
|
||||
await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_S)
|
||||
if stop_event.is_set():
|
||||
break
|
||||
try:
|
||||
await self._send_typing(chat_id, ticket, TYPING_STATUS_TYPING)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
pass
|
||||
|
||||
task = asyncio.create_task(keepalive())
|
||||
task._typing_stop_event = stop_event # type: ignore[attr-defined]
|
||||
self._typing_tasks[chat_id] = task
|
||||
|
||||
async def _stop_typing(self, chat_id: str, *, clear_remote: bool) -> None:
|
||||
"""Stop typing indicator for a chat."""
|
||||
task = self._typing_tasks.pop(chat_id, None)
|
||||
if task and not task.done():
|
||||
stop_event = getattr(task, "_typing_stop_event", None)
|
||||
if stop_event:
|
||||
stop_event.set()
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
@ -843,11 +1065,12 @@ class WeixinChannel(BaseChannel):
|
||||
pass
|
||||
if not clear_remote:
|
||||
return
|
||||
ticket = self._typing_tickets.get(chat_id, "")
|
||||
entry = self._typing_tickets.get(chat_id)
|
||||
ticket = str(entry.get("ticket", "") or "") if isinstance(entry, dict) else ""
|
||||
if not ticket:
|
||||
return
|
||||
try:
|
||||
await self._send_typing_status(chat_id, ticket, 2)
|
||||
await self._send_typing(chat_id, ticket, TYPING_STATUS_CANCEL)
|
||||
except Exception as e:
|
||||
logger.debug("WeChat typing clear failed for {}: {}", chat_id, e)
|
||||
|
||||
@ -923,6 +1146,10 @@ class WeixinChannel(BaseChannel):
|
||||
upload_type = UPLOAD_MEDIA_VIDEO
|
||||
item_type = ITEM_VIDEO
|
||||
item_key = "video_item"
|
||||
elif ext in _VOICE_EXTS:
|
||||
upload_type = UPLOAD_MEDIA_VOICE
|
||||
item_type = ITEM_VOICE
|
||||
item_key = "voice_item"
|
||||
else:
|
||||
upload_type = UPLOAD_MEDIA_FILE
|
||||
item_type = ITEM_FILE
|
||||
@ -936,7 +1163,7 @@ class WeixinChannel(BaseChannel):
|
||||
# Matches aesEcbPaddedSize: Math.ceil((size + 1) / 16) * 16
|
||||
padded_size = ((raw_size + 1 + 15) // 16) * 16
|
||||
|
||||
# Step 1: Get upload URL (upload_param) from server
|
||||
# Step 1: Get upload URL from server (prefer upload_full_url, fallback to upload_param)
|
||||
file_key = os.urandom(16).hex()
|
||||
upload_body: dict[str, Any] = {
|
||||
"filekey": file_key,
|
||||
@ -951,22 +1178,27 @@ class WeixinChannel(BaseChannel):
|
||||
|
||||
assert self._client is not None
|
||||
upload_resp = await self._api_post("ilink/bot/getuploadurl", upload_body)
|
||||
logger.debug("WeChat getuploadurl response: {}", upload_resp)
|
||||
|
||||
upload_param = upload_resp.get("upload_param", "")
|
||||
if not upload_param:
|
||||
raise RuntimeError(f"getuploadurl returned no upload_param: {upload_resp}")
|
||||
upload_full_url = str(upload_resp.get("upload_full_url", "") or "").strip()
|
||||
upload_param = str(upload_resp.get("upload_param", "") or "")
|
||||
if not upload_full_url and not upload_param:
|
||||
raise RuntimeError(
|
||||
"getuploadurl returned no upload URL "
|
||||
f"(need upload_full_url or upload_param): {upload_resp}"
|
||||
)
|
||||
|
||||
# Step 2: AES-128-ECB encrypt and POST to CDN
|
||||
aes_key_b64 = base64.b64encode(aes_key_raw).decode()
|
||||
encrypted_data = _encrypt_aes_ecb(raw_data, aes_key_b64)
|
||||
|
||||
cdn_upload_url = (
|
||||
f"{self.config.cdn_base_url}/upload"
|
||||
f"?encrypted_query_param={quote(upload_param)}"
|
||||
f"&filekey={quote(file_key)}"
|
||||
)
|
||||
logger.debug("WeChat CDN POST url={} ciphertextSize={}", cdn_upload_url[:80], len(encrypted_data))
|
||||
if upload_full_url:
|
||||
cdn_upload_url = upload_full_url
|
||||
else:
|
||||
cdn_upload_url = (
|
||||
f"{self.config.cdn_base_url}/upload"
|
||||
f"?encrypted_query_param={quote(upload_param)}"
|
||||
f"&filekey={quote(file_key)}"
|
||||
)
|
||||
|
||||
cdn_resp = await self._client.post(
|
||||
cdn_upload_url,
|
||||
@ -982,7 +1214,6 @@ class WeixinChannel(BaseChannel):
|
||||
"CDN upload response missing x-encrypted-param header; "
|
||||
f"status={cdn_resp.status_code} headers={dict(cdn_resp.headers)}"
|
||||
)
|
||||
logger.debug("WeChat CDN upload success for {}, got download_param", p.name)
|
||||
|
||||
# Step 3: Send message with the media item
|
||||
# aes_key for CDNMedia is the hex key encoded as base64
|
||||
@ -1031,7 +1262,6 @@ class WeixinChannel(BaseChannel):
|
||||
raise RuntimeError(
|
||||
f"WeChat send media error (code {errcode}): {data.get('errmsg', '')}"
|
||||
)
|
||||
logger.info("WeChat media sent: {} (type={})", p.name, item_key)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -1103,23 +1333,42 @@ def _decrypt_aes_ecb(data: bytes, aes_key_b64: str) -> bytes:
|
||||
logger.warning("Failed to parse AES key, returning raw data: {}", e)
|
||||
return data
|
||||
|
||||
decrypted: bytes | None = None
|
||||
|
||||
try:
|
||||
from Crypto.Cipher import AES
|
||||
|
||||
cipher = AES.new(key, AES.MODE_ECB)
|
||||
return cipher.decrypt(data) # pycryptodome auto-strips PKCS7 with unpad
|
||||
decrypted = cipher.decrypt(data)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
if decrypted is None:
|
||||
try:
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
cipher_obj = Cipher(algorithms.AES(key), modes.ECB())
|
||||
decryptor = cipher_obj.decryptor()
|
||||
return decryptor.update(data) + decryptor.finalize()
|
||||
except ImportError:
|
||||
logger.warning("Cannot decrypt media: install 'pycryptodome' or 'cryptography'")
|
||||
cipher_obj = Cipher(algorithms.AES(key), modes.ECB())
|
||||
decryptor = cipher_obj.decryptor()
|
||||
decrypted = decryptor.update(data) + decryptor.finalize()
|
||||
except ImportError:
|
||||
logger.warning("Cannot decrypt media: install 'pycryptodome' or 'cryptography'")
|
||||
return data
|
||||
|
||||
return _pkcs7_unpad_safe(decrypted)
|
||||
|
||||
|
||||
def _pkcs7_unpad_safe(data: bytes, block_size: int = 16) -> bytes:
|
||||
"""Safely remove PKCS7 padding when valid; otherwise return original bytes."""
|
||||
if not data:
|
||||
return data
|
||||
if len(data) % block_size != 0:
|
||||
return data
|
||||
pad_len = data[-1]
|
||||
if pad_len < 1 or pad_len > block_size:
|
||||
return data
|
||||
if data[-pad_len:] != bytes([pad_len]) * pad_len:
|
||||
return data
|
||||
return data[:-pad_len]
|
||||
|
||||
|
||||
def _ext_for_type(media_type: str) -> str:
|
||||
|
||||
@ -415,6 +415,9 @@ def _make_provider(config: Config):
|
||||
api_base=p.api_base,
|
||||
default_model=model,
|
||||
)
|
||||
elif backend == "github_copilot":
|
||||
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
|
||||
provider = GitHubCopilotProvider(default_model=model)
|
||||
elif backend == "anthropic":
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
provider = AnthropicProvider(
|
||||
@ -491,6 +494,91 @@ def _migrate_cron_store(config: "Config") -> None:
|
||||
shutil.move(str(legacy_path), str(new_path))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OpenAI-Compatible API Server
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@app.command()
|
||||
def serve(
|
||||
port: int | None = typer.Option(None, "--port", "-p", help="API server port"),
|
||||
host: str | None = typer.Option(None, "--host", "-H", help="Bind address"),
|
||||
timeout: float | None = typer.Option(None, "--timeout", "-t", help="Per-request timeout (seconds)"),
|
||||
verbose: bool = typer.Option(False, "--verbose", "-v", help="Show nanobot runtime logs"),
|
||||
workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
|
||||
config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
|
||||
):
|
||||
"""Start the OpenAI-compatible API server (/v1/chat/completions)."""
|
||||
try:
|
||||
from aiohttp import web # noqa: F401
|
||||
except ImportError:
|
||||
console.print("[red]aiohttp is required. Install with: pip install 'nanobot-ai[api]'[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
from loguru import logger
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.api.server import create_app
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.session.manager import SessionManager
|
||||
|
||||
if verbose:
|
||||
logger.enable("nanobot")
|
||||
else:
|
||||
logger.disable("nanobot")
|
||||
|
||||
runtime_config = _load_runtime_config(config, workspace)
|
||||
api_cfg = runtime_config.api
|
||||
host = host if host is not None else api_cfg.host
|
||||
port = port if port is not None else api_cfg.port
|
||||
timeout = timeout if timeout is not None else api_cfg.timeout
|
||||
sync_workspace_templates(runtime_config.workspace_path)
|
||||
bus = MessageBus()
|
||||
provider = _make_provider(runtime_config)
|
||||
session_manager = SessionManager(runtime_config.workspace_path)
|
||||
agent_loop = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=runtime_config.workspace_path,
|
||||
model=runtime_config.agents.defaults.model,
|
||||
max_iterations=runtime_config.agents.defaults.max_tool_iterations,
|
||||
context_window_tokens=runtime_config.agents.defaults.context_window_tokens,
|
||||
web_search_config=runtime_config.tools.web.search,
|
||||
web_proxy=runtime_config.tools.web.proxy or None,
|
||||
exec_config=runtime_config.tools.exec,
|
||||
restrict_to_workspace=runtime_config.tools.restrict_to_workspace,
|
||||
session_manager=session_manager,
|
||||
mcp_servers=runtime_config.tools.mcp_servers,
|
||||
channels_config=runtime_config.channels,
|
||||
timezone=runtime_config.agents.defaults.timezone,
|
||||
)
|
||||
|
||||
model_name = runtime_config.agents.defaults.model
|
||||
console.print(f"{__logo__} Starting OpenAI-compatible API server")
|
||||
console.print(f" [cyan]Endpoint[/cyan] : http://{host}:{port}/v1/chat/completions")
|
||||
console.print(f" [cyan]Model[/cyan] : {model_name}")
|
||||
console.print(" [cyan]Session[/cyan] : api:default")
|
||||
console.print(f" [cyan]Timeout[/cyan] : {timeout}s")
|
||||
if host in {"0.0.0.0", "::"}:
|
||||
console.print(
|
||||
"[yellow]Warning:[/yellow] API is bound to all interfaces. "
|
||||
"Only do this behind a trusted network boundary, firewall, or reverse proxy."
|
||||
)
|
||||
console.print()
|
||||
|
||||
api_app = create_app(agent_loop, model_name=model_name, request_timeout=timeout)
|
||||
|
||||
async def on_startup(_app):
|
||||
await agent_loop._connect_mcp()
|
||||
|
||||
async def on_cleanup(_app):
|
||||
await agent_loop.close_mcp()
|
||||
|
||||
api_app.on_startup.append(on_startup)
|
||||
api_app.on_cleanup.append(on_cleanup)
|
||||
|
||||
web.run_app(api_app, host=host, port=port, print=lambda msg: logger.info(msg))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Gateway / Server
|
||||
# ============================================================================
|
||||
@ -1204,26 +1292,16 @@ def _login_openai_codex() -> None:
|
||||
|
||||
@_register_login("github_copilot")
|
||||
def _login_github_copilot() -> None:
|
||||
import asyncio
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n")
|
||||
|
||||
async def _trigger():
|
||||
client = AsyncOpenAI(
|
||||
api_key="dummy",
|
||||
base_url="https://api.githubcopilot.com",
|
||||
)
|
||||
await client.chat.completions.create(
|
||||
model="gpt-4o",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
max_tokens=1,
|
||||
)
|
||||
|
||||
try:
|
||||
asyncio.run(_trigger())
|
||||
console.print("[green]✓ Authenticated with GitHub Copilot[/green]")
|
||||
from nanobot.providers.github_copilot_provider import login_github_copilot
|
||||
|
||||
console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n")
|
||||
token = login_github_copilot(
|
||||
print_fn=lambda s: console.print(s),
|
||||
prompt_fn=lambda s: typer.prompt(s),
|
||||
)
|
||||
account = token.account_id or "GitHub"
|
||||
console.print(f"[green]✓ Authenticated with GitHub Copilot[/green] [dim]{account}[/dim]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]Authentication error: {e}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
@ -84,6 +84,16 @@ async def cmd_new(ctx: CommandContext) -> OutboundMessage:
|
||||
|
||||
async def cmd_help(ctx: CommandContext) -> OutboundMessage:
|
||||
"""Return available slash commands."""
|
||||
return OutboundMessage(
|
||||
channel=ctx.msg.channel,
|
||||
chat_id=ctx.msg.chat_id,
|
||||
content=build_help_text(),
|
||||
metadata={"render_as": "text"},
|
||||
)
|
||||
|
||||
|
||||
def build_help_text() -> str:
|
||||
"""Build canonical help text shared across channels."""
|
||||
lines = [
|
||||
"🐈 nanobot commands:",
|
||||
"/new — Start a new conversation",
|
||||
@ -92,12 +102,7 @@ async def cmd_help(ctx: CommandContext) -> OutboundMessage:
|
||||
"/status — Show bot status",
|
||||
"/help — Show available commands",
|
||||
]
|
||||
return OutboundMessage(
|
||||
channel=ctx.msg.channel,
|
||||
chat_id=ctx.msg.chat_id,
|
||||
content="\n".join(lines),
|
||||
metadata={"render_as": "text"},
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def register_builtin_commands(router: CommandRouter) -> None:
|
||||
|
||||
@ -96,6 +96,14 @@ class HeartbeatConfig(Base):
|
||||
keep_recent_messages: int = 8
|
||||
|
||||
|
||||
class ApiConfig(Base):
|
||||
"""OpenAI-compatible API server configuration."""
|
||||
|
||||
host: str = "127.0.0.1" # Safer default: local-only bind.
|
||||
port: int = 8900
|
||||
timeout: float = 120.0 # Per-request timeout in seconds.
|
||||
|
||||
|
||||
class GatewayConfig(Base):
|
||||
"""Gateway/server configuration."""
|
||||
|
||||
@ -156,6 +164,7 @@ class Config(BaseSettings):
|
||||
agents: AgentsConfig = Field(default_factory=AgentsConfig)
|
||||
channels: ChannelsConfig = Field(default_factory=ChannelsConfig)
|
||||
providers: ProvidersConfig = Field(default_factory=ProvidersConfig)
|
||||
api: ApiConfig = Field(default_factory=ApiConfig)
|
||||
gateway: GatewayConfig = Field(default_factory=GatewayConfig)
|
||||
tools: ToolsConfig = Field(default_factory=ToolsConfig)
|
||||
|
||||
|
||||
174
nanobot/nanobot.py
Normal file
174
nanobot/nanobot.py
Normal file
@ -0,0 +1,174 @@
|
||||
"""High-level programmatic interface to nanobot."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.hook import AgentHook
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class RunResult:
|
||||
"""Result of a single agent run."""
|
||||
|
||||
content: str
|
||||
tools_used: list[str]
|
||||
messages: list[dict[str, Any]]
|
||||
|
||||
|
||||
class Nanobot:
|
||||
"""Programmatic facade for running the nanobot agent.
|
||||
|
||||
Usage::
|
||||
|
||||
bot = Nanobot.from_config()
|
||||
result = await bot.run("Summarize this repo", hooks=[MyHook()])
|
||||
print(result.content)
|
||||
"""
|
||||
|
||||
def __init__(self, loop: AgentLoop) -> None:
|
||||
self._loop = loop
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config_path: str | Path | None = None,
|
||||
*,
|
||||
workspace: str | Path | None = None,
|
||||
) -> Nanobot:
|
||||
"""Create a Nanobot instance from a config file.
|
||||
|
||||
Args:
|
||||
config_path: Path to ``config.json``. Defaults to
|
||||
``~/.nanobot/config.json``.
|
||||
workspace: Override the workspace directory from config.
|
||||
"""
|
||||
from nanobot.config.loader import load_config
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
resolved: Path | None = None
|
||||
if config_path is not None:
|
||||
resolved = Path(config_path).expanduser().resolve()
|
||||
if not resolved.exists():
|
||||
raise FileNotFoundError(f"Config not found: {resolved}")
|
||||
|
||||
config: Config = load_config(resolved)
|
||||
if workspace is not None:
|
||||
config.agents.defaults.workspace = str(
|
||||
Path(workspace).expanduser().resolve()
|
||||
)
|
||||
|
||||
provider = _make_provider(config)
|
||||
bus = MessageBus()
|
||||
defaults = config.agents.defaults
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=config.workspace_path,
|
||||
model=defaults.model,
|
||||
max_iterations=defaults.max_tool_iterations,
|
||||
context_window_tokens=defaults.context_window_tokens,
|
||||
web_search_config=config.tools.web.search,
|
||||
web_proxy=config.tools.web.proxy or None,
|
||||
exec_config=config.tools.exec,
|
||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||
mcp_servers=config.tools.mcp_servers,
|
||||
timezone=defaults.timezone,
|
||||
)
|
||||
return cls(loop)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
message: str,
|
||||
*,
|
||||
session_key: str = "sdk:default",
|
||||
hooks: list[AgentHook] | None = None,
|
||||
) -> RunResult:
|
||||
"""Run the agent once and return the result.
|
||||
|
||||
Args:
|
||||
message: The user message to process.
|
||||
session_key: Session identifier for conversation isolation.
|
||||
Different keys get independent history.
|
||||
hooks: Optional lifecycle hooks for this run.
|
||||
"""
|
||||
prev = self._loop._extra_hooks
|
||||
if hooks is not None:
|
||||
self._loop._extra_hooks = list(hooks)
|
||||
try:
|
||||
response = await self._loop.process_direct(
|
||||
message, session_key=session_key,
|
||||
)
|
||||
finally:
|
||||
self._loop._extra_hooks = prev
|
||||
|
||||
content = (response.content if response else None) or ""
|
||||
return RunResult(content=content, tools_used=[], messages=[])
|
||||
|
||||
|
||||
def _make_provider(config: Any) -> Any:
|
||||
"""Create the LLM provider from config (extracted from CLI)."""
|
||||
from nanobot.providers.base import GenerationSettings
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
model = config.agents.defaults.model
|
||||
provider_name = config.get_provider_name(model)
|
||||
p = config.get_provider(model)
|
||||
spec = find_by_name(provider_name) if provider_name else None
|
||||
backend = spec.backend if spec else "openai_compat"
|
||||
|
||||
if backend == "azure_openai":
|
||||
if not p or not p.api_key or not p.api_base:
|
||||
raise ValueError("Azure OpenAI requires api_key and api_base in config.")
|
||||
elif backend == "openai_compat" and not model.startswith("bedrock/"):
|
||||
needs_key = not (p and p.api_key)
|
||||
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
|
||||
if needs_key and not exempt:
|
||||
raise ValueError(f"No API key configured for provider '{provider_name}'.")
|
||||
|
||||
if backend == "openai_codex":
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
|
||||
provider = OpenAICodexProvider(default_model=model)
|
||||
elif backend == "github_copilot":
|
||||
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
|
||||
|
||||
provider = GitHubCopilotProvider(default_model=model)
|
||||
elif backend == "azure_openai":
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key=p.api_key, api_base=p.api_base, default_model=model
|
||||
)
|
||||
elif backend == "anthropic":
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
|
||||
provider = AnthropicProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
)
|
||||
else:
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
spec=spec,
|
||||
)
|
||||
|
||||
defaults = config.agents.defaults
|
||||
provider.generation = GenerationSettings(
|
||||
temperature=defaults.temperature,
|
||||
max_tokens=defaults.max_tokens,
|
||||
reasoning_effort=defaults.reasoning_effort,
|
||||
)
|
||||
return provider
|
||||
@ -13,6 +13,7 @@ __all__ = [
|
||||
"AnthropicProvider",
|
||||
"OpenAICompatProvider",
|
||||
"OpenAICodexProvider",
|
||||
"GitHubCopilotProvider",
|
||||
"AzureOpenAIProvider",
|
||||
]
|
||||
|
||||
@ -20,12 +21,14 @@ _LAZY_IMPORTS = {
|
||||
"AnthropicProvider": ".anthropic_provider",
|
||||
"OpenAICompatProvider": ".openai_compat_provider",
|
||||
"OpenAICodexProvider": ".openai_codex_provider",
|
||||
"GitHubCopilotProvider": ".github_copilot_provider",
|
||||
"AzureOpenAIProvider": ".azure_openai_provider",
|
||||
}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
|
||||
|
||||
@ -370,15 +370,22 @@ class AnthropicProvider(LLMProvider):
|
||||
|
||||
usage: dict[str, int] = {}
|
||||
if response.usage:
|
||||
input_tokens = response.usage.input_tokens
|
||||
cache_creation = getattr(response.usage, "cache_creation_input_tokens", 0) or 0
|
||||
cache_read = getattr(response.usage, "cache_read_input_tokens", 0) or 0
|
||||
total_prompt_tokens = input_tokens + cache_creation + cache_read
|
||||
usage = {
|
||||
"prompt_tokens": response.usage.input_tokens,
|
||||
"prompt_tokens": total_prompt_tokens,
|
||||
"completion_tokens": response.usage.output_tokens,
|
||||
"total_tokens": response.usage.input_tokens + response.usage.output_tokens,
|
||||
"total_tokens": total_prompt_tokens + response.usage.output_tokens,
|
||||
}
|
||||
for attr in ("cache_creation_input_tokens", "cache_read_input_tokens"):
|
||||
val = getattr(response.usage, attr, 0)
|
||||
if val:
|
||||
usage[attr] = val
|
||||
# Normalize to cached_tokens for downstream consistency.
|
||||
if cache_read:
|
||||
usage["cached_tokens"] = cache_read
|
||||
|
||||
return LLMResponse(
|
||||
content="".join(content_parts) or None,
|
||||
|
||||
257
nanobot/providers/github_copilot_provider.py
Normal file
257
nanobot/providers/github_copilot_provider.py
Normal file
@ -0,0 +1,257 @@
|
||||
"""GitHub Copilot OAuth-backed provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import webbrowser
|
||||
from collections.abc import Callable
|
||||
|
||||
import httpx
|
||||
from oauth_cli_kit.models import OAuthToken
|
||||
from oauth_cli_kit.storage import FileTokenStorage
|
||||
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
DEFAULT_GITHUB_DEVICE_CODE_URL = "https://github.com/login/device/code"
|
||||
DEFAULT_GITHUB_ACCESS_TOKEN_URL = "https://github.com/login/oauth/access_token"
|
||||
DEFAULT_GITHUB_USER_URL = "https://api.github.com/user"
|
||||
DEFAULT_COPILOT_TOKEN_URL = "https://api.github.com/copilot_internal/v2/token"
|
||||
DEFAULT_COPILOT_BASE_URL = "https://api.githubcopilot.com"
|
||||
GITHUB_COPILOT_CLIENT_ID = "Iv1.b507a08c87ecfe98"
|
||||
GITHUB_COPILOT_SCOPE = "read:user"
|
||||
TOKEN_FILENAME = "github-copilot.json"
|
||||
TOKEN_APP_NAME = "nanobot"
|
||||
USER_AGENT = "nanobot/0.1"
|
||||
EDITOR_VERSION = "vscode/1.99.0"
|
||||
EDITOR_PLUGIN_VERSION = "copilot-chat/0.26.0"
|
||||
_EXPIRY_SKEW_SECONDS = 60
|
||||
_LONG_LIVED_TOKEN_SECONDS = 315360000
|
||||
|
||||
|
||||
def _storage() -> FileTokenStorage:
|
||||
return FileTokenStorage(
|
||||
token_filename=TOKEN_FILENAME,
|
||||
app_name=TOKEN_APP_NAME,
|
||||
import_codex_cli=False,
|
||||
)
|
||||
|
||||
|
||||
def _copilot_headers(token: str) -> dict[str, str]:
|
||||
return {
|
||||
"Authorization": f"token {token}",
|
||||
"Accept": "application/json",
|
||||
"User-Agent": USER_AGENT,
|
||||
"Editor-Version": EDITOR_VERSION,
|
||||
"Editor-Plugin-Version": EDITOR_PLUGIN_VERSION,
|
||||
}
|
||||
|
||||
|
||||
def _load_github_token() -> OAuthToken | None:
|
||||
token = _storage().load()
|
||||
if not token or not token.access:
|
||||
return None
|
||||
return token
|
||||
|
||||
|
||||
def get_github_copilot_login_status() -> OAuthToken | None:
|
||||
"""Return the persisted GitHub OAuth token if available."""
|
||||
return _load_github_token()
|
||||
|
||||
|
||||
def login_github_copilot(
|
||||
print_fn: Callable[[str], None] | None = None,
|
||||
prompt_fn: Callable[[str], str] | None = None,
|
||||
) -> OAuthToken:
|
||||
"""Run GitHub device flow and persist the GitHub OAuth token used for Copilot."""
|
||||
del prompt_fn
|
||||
printer = print_fn or print
|
||||
timeout = httpx.Timeout(20.0, connect=20.0)
|
||||
|
||||
with httpx.Client(timeout=timeout, follow_redirects=True, trust_env=True) as client:
|
||||
response = client.post(
|
||||
DEFAULT_GITHUB_DEVICE_CODE_URL,
|
||||
headers={"Accept": "application/json", "User-Agent": USER_AGENT},
|
||||
data={"client_id": GITHUB_COPILOT_CLIENT_ID, "scope": GITHUB_COPILOT_SCOPE},
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
|
||||
device_code = str(payload["device_code"])
|
||||
user_code = str(payload["user_code"])
|
||||
verify_url = str(payload.get("verification_uri") or payload.get("verification_uri_complete") or "")
|
||||
verify_complete = str(payload.get("verification_uri_complete") or verify_url)
|
||||
interval = max(1, int(payload.get("interval") or 5))
|
||||
expires_in = int(payload.get("expires_in") or 900)
|
||||
|
||||
printer(f"Open: {verify_url}")
|
||||
printer(f"Code: {user_code}")
|
||||
if verify_complete:
|
||||
try:
|
||||
webbrowser.open(verify_complete)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
deadline = time.time() + expires_in
|
||||
current_interval = interval
|
||||
access_token = None
|
||||
token_expires_in = _LONG_LIVED_TOKEN_SECONDS
|
||||
while time.time() < deadline:
|
||||
poll = client.post(
|
||||
DEFAULT_GITHUB_ACCESS_TOKEN_URL,
|
||||
headers={"Accept": "application/json", "User-Agent": USER_AGENT},
|
||||
data={
|
||||
"client_id": GITHUB_COPILOT_CLIENT_ID,
|
||||
"device_code": device_code,
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||||
},
|
||||
)
|
||||
poll.raise_for_status()
|
||||
poll_payload = poll.json()
|
||||
|
||||
access_token = poll_payload.get("access_token")
|
||||
if access_token:
|
||||
token_expires_in = int(poll_payload.get("expires_in") or _LONG_LIVED_TOKEN_SECONDS)
|
||||
break
|
||||
|
||||
error = poll_payload.get("error")
|
||||
if error == "authorization_pending":
|
||||
time.sleep(current_interval)
|
||||
continue
|
||||
if error == "slow_down":
|
||||
current_interval += 5
|
||||
time.sleep(current_interval)
|
||||
continue
|
||||
if error == "expired_token":
|
||||
raise RuntimeError("GitHub device code expired. Please run login again.")
|
||||
if error == "access_denied":
|
||||
raise RuntimeError("GitHub device flow was denied.")
|
||||
if error:
|
||||
desc = poll_payload.get("error_description") or error
|
||||
raise RuntimeError(str(desc))
|
||||
time.sleep(current_interval)
|
||||
else:
|
||||
raise RuntimeError("GitHub device flow timed out.")
|
||||
|
||||
user = client.get(
|
||||
DEFAULT_GITHUB_USER_URL,
|
||||
headers={
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Accept": "application/vnd.github+json",
|
||||
"User-Agent": USER_AGENT,
|
||||
},
|
||||
)
|
||||
user.raise_for_status()
|
||||
user_payload = user.json()
|
||||
account_id = user_payload.get("login") or str(user_payload.get("id") or "") or None
|
||||
|
||||
expires_ms = int((time.time() + token_expires_in) * 1000)
|
||||
token = OAuthToken(
|
||||
access=str(access_token),
|
||||
refresh="",
|
||||
expires=expires_ms,
|
||||
account_id=str(account_id) if account_id else None,
|
||||
)
|
||||
_storage().save(token)
|
||||
return token
|
||||
|
||||
|
||||
class GitHubCopilotProvider(OpenAICompatProvider):
|
||||
"""Provider that exchanges a stored GitHub OAuth token for Copilot access tokens."""
|
||||
|
||||
def __init__(self, default_model: str = "github-copilot/gpt-4.1"):
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
self._copilot_access_token: str | None = None
|
||||
self._copilot_expires_at: float = 0.0
|
||||
super().__init__(
|
||||
api_key="no-key",
|
||||
api_base=DEFAULT_COPILOT_BASE_URL,
|
||||
default_model=default_model,
|
||||
extra_headers={
|
||||
"Editor-Version": EDITOR_VERSION,
|
||||
"Editor-Plugin-Version": EDITOR_PLUGIN_VERSION,
|
||||
"User-Agent": USER_AGENT,
|
||||
},
|
||||
spec=find_by_name("github_copilot"),
|
||||
)
|
||||
|
||||
async def _get_copilot_access_token(self) -> str:
|
||||
now = time.time()
|
||||
if self._copilot_access_token and now < self._copilot_expires_at - _EXPIRY_SKEW_SECONDS:
|
||||
return self._copilot_access_token
|
||||
|
||||
github_token = _load_github_token()
|
||||
if not github_token or not github_token.access:
|
||||
raise RuntimeError("GitHub Copilot is not logged in. Run: nanobot provider login github-copilot")
|
||||
|
||||
timeout = httpx.Timeout(20.0, connect=20.0)
|
||||
async with httpx.AsyncClient(timeout=timeout, follow_redirects=True, trust_env=True) as client:
|
||||
response = await client.get(
|
||||
DEFAULT_COPILOT_TOKEN_URL,
|
||||
headers=_copilot_headers(github_token.access),
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
|
||||
token = payload.get("token")
|
||||
if not token:
|
||||
raise RuntimeError("GitHub Copilot token exchange returned no token.")
|
||||
|
||||
expires_at = payload.get("expires_at")
|
||||
if isinstance(expires_at, (int, float)):
|
||||
self._copilot_expires_at = float(expires_at)
|
||||
else:
|
||||
refresh_in = payload.get("refresh_in") or 1500
|
||||
self._copilot_expires_at = time.time() + int(refresh_in)
|
||||
self._copilot_access_token = str(token)
|
||||
return self._copilot_access_token
|
||||
|
||||
async def _refresh_client_api_key(self) -> str:
|
||||
token = await self._get_copilot_access_token()
|
||||
self.api_key = token
|
||||
self._client.api_key = token
|
||||
return token
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, object]],
|
||||
tools: list[dict[str, object]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, object] | None = None,
|
||||
):
|
||||
await self._refresh_client_api_key()
|
||||
return await super().chat(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
reasoning_effort=reasoning_effort,
|
||||
tool_choice=tool_choice,
|
||||
)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: list[dict[str, object]],
|
||||
tools: list[dict[str, object]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, object] | None = None,
|
||||
on_content_delta: Callable[[str], None] | None = None,
|
||||
):
|
||||
await self._refresh_client_api_key()
|
||||
return await super().chat_stream(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
reasoning_effort=reasoning_effort,
|
||||
tool_choice=tool_choice,
|
||||
on_content_delta=on_content_delta,
|
||||
)
|
||||
@ -235,7 +235,9 @@ class OpenAICompatProvider(LLMProvider):
|
||||
spec = self._spec
|
||||
|
||||
if spec and spec.supports_prompt_caching:
|
||||
messages, tools = self._apply_cache_control(messages, tools)
|
||||
model_name = model or self.default_model
|
||||
if any(model_name.lower().startswith(k) for k in ("anthropic/", "claude")):
|
||||
messages, tools = self._apply_cache_control(messages, tools)
|
||||
|
||||
if spec and spec.strip_model_prefix:
|
||||
model_name = model_name.split("/")[-1]
|
||||
@ -308,6 +310,13 @@ class OpenAICompatProvider(LLMProvider):
|
||||
|
||||
@classmethod
|
||||
def _extract_usage(cls, response: Any) -> dict[str, int]:
|
||||
"""Extract token usage from an OpenAI-compatible response.
|
||||
|
||||
Handles both dict-based (raw JSON) and object-based (SDK Pydantic)
|
||||
responses. Provider-specific ``cached_tokens`` fields are normalised
|
||||
under a single key; see the priority chain inside for details.
|
||||
"""
|
||||
# --- resolve usage object ---
|
||||
usage_obj = None
|
||||
response_map = cls._maybe_mapping(response)
|
||||
if response_map is not None:
|
||||
@ -317,19 +326,53 @@ class OpenAICompatProvider(LLMProvider):
|
||||
|
||||
usage_map = cls._maybe_mapping(usage_obj)
|
||||
if usage_map is not None:
|
||||
return {
|
||||
result = {
|
||||
"prompt_tokens": int(usage_map.get("prompt_tokens") or 0),
|
||||
"completion_tokens": int(usage_map.get("completion_tokens") or 0),
|
||||
"total_tokens": int(usage_map.get("total_tokens") or 0),
|
||||
}
|
||||
|
||||
if usage_obj:
|
||||
return {
|
||||
elif usage_obj:
|
||||
result = {
|
||||
"prompt_tokens": getattr(usage_obj, "prompt_tokens", 0) or 0,
|
||||
"completion_tokens": getattr(usage_obj, "completion_tokens", 0) or 0,
|
||||
"total_tokens": getattr(usage_obj, "total_tokens", 0) or 0,
|
||||
}
|
||||
return {}
|
||||
else:
|
||||
return {}
|
||||
|
||||
# --- cached_tokens (normalised across providers) ---
|
||||
# Try nested paths first (dict), fall back to attribute (SDK object).
|
||||
# Priority order ensures the most specific field wins.
|
||||
for path in (
|
||||
("prompt_tokens_details", "cached_tokens"), # OpenAI/Zhipu/MiniMax/Qwen/Mistral/xAI
|
||||
("cached_tokens",), # StepFun/Moonshot (top-level)
|
||||
("prompt_cache_hit_tokens",), # DeepSeek/SiliconFlow
|
||||
):
|
||||
cached = cls._get_nested_int(usage_map, path)
|
||||
if not cached and usage_obj:
|
||||
cached = cls._get_nested_int(usage_obj, path)
|
||||
if cached:
|
||||
result["cached_tokens"] = cached
|
||||
break
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _get_nested_int(obj: Any, path: tuple[str, ...]) -> int:
|
||||
"""Drill into *obj* by *path* segments and return an ``int`` value.
|
||||
|
||||
Supports both dict-key access and attribute access so it works
|
||||
uniformly with raw JSON dicts **and** SDK Pydantic models.
|
||||
"""
|
||||
current = obj
|
||||
for segment in path:
|
||||
if current is None:
|
||||
return 0
|
||||
if isinstance(current, dict):
|
||||
current = current.get(segment)
|
||||
else:
|
||||
current = getattr(current, segment, None)
|
||||
return int(current or 0) if current is not None else 0
|
||||
|
||||
def _parse(self, response: Any) -> LLMResponse:
|
||||
if isinstance(response, str):
|
||||
@ -586,4 +629,4 @@ class OpenAICompatProvider(LLMProvider):
|
||||
return self._handle_error(e)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return self.default_model
|
||||
return self.default_model
|
||||
@ -34,7 +34,7 @@ class ProviderSpec:
|
||||
display_name: str = "" # shown in `nanobot status`
|
||||
|
||||
# which provider implementation to use
|
||||
# "openai_compat" | "anthropic" | "azure_openai" | "openai_codex"
|
||||
# "openai_compat" | "anthropic" | "azure_openai" | "openai_codex" | "github_copilot"
|
||||
backend: str = "openai_compat"
|
||||
|
||||
# extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),)
|
||||
@ -218,8 +218,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
keywords=("github_copilot", "copilot"),
|
||||
env_key="",
|
||||
display_name="Github Copilot",
|
||||
backend="openai_compat",
|
||||
backend="github_copilot",
|
||||
default_api_base="https://api.githubcopilot.com",
|
||||
strip_model_prefix=True,
|
||||
is_oauth=True,
|
||||
),
|
||||
# DeepSeek: OpenAI-compatible at api.deepseek.com
|
||||
|
||||
@ -124,8 +124,8 @@ def build_assistant_message(
|
||||
msg: dict[str, Any] = {"role": "assistant", "content": content}
|
||||
if tool_calls:
|
||||
msg["tool_calls"] = tool_calls
|
||||
if reasoning_content is not None:
|
||||
msg["reasoning_content"] = reasoning_content
|
||||
if reasoning_content is not None or thinking_blocks:
|
||||
msg["reasoning_content"] = reasoning_content if reasoning_content is not None else ""
|
||||
if thinking_blocks:
|
||||
msg["thinking_blocks"] = thinking_blocks
|
||||
return msg
|
||||
@ -255,14 +255,18 @@ def build_status_content(
|
||||
)
|
||||
last_in = last_usage.get("prompt_tokens", 0)
|
||||
last_out = last_usage.get("completion_tokens", 0)
|
||||
cached = last_usage.get("cached_tokens", 0)
|
||||
ctx_total = max(context_window_tokens, 0)
|
||||
ctx_pct = int((context_tokens_estimate / ctx_total) * 100) if ctx_total > 0 else 0
|
||||
ctx_used_str = f"{context_tokens_estimate // 1000}k" if context_tokens_estimate >= 1000 else str(context_tokens_estimate)
|
||||
ctx_total_str = f"{ctx_total // 1024}k" if ctx_total > 0 else "n/a"
|
||||
token_line = f"\U0001f4ca Tokens: {last_in} in / {last_out} out"
|
||||
if cached and last_in:
|
||||
token_line += f" ({cached * 100 // last_in}% cached)"
|
||||
return "\n".join([
|
||||
f"\U0001f408 nanobot v{version}",
|
||||
f"\U0001f9e0 Model: {model}",
|
||||
f"\U0001f4ca Tokens: {last_in} in / {last_out} out",
|
||||
token_line,
|
||||
f"\U0001f4da Context: {ctx_used_str}/{ctx_total_str} ({ctx_pct}%)",
|
||||
f"\U0001f4ac Session: {session_msg_count} messages",
|
||||
f"\u23f1 Uptime: {uptime}",
|
||||
|
||||
@ -51,6 +51,9 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
api = [
|
||||
"aiohttp>=3.9.0,<4.0.0",
|
||||
]
|
||||
wecom = [
|
||||
"wecom-aibot-sdk-python>=0.1.5",
|
||||
]
|
||||
@ -64,12 +67,16 @@ matrix = [
|
||||
"mistune>=3.0.0,<4.0.0",
|
||||
"nh3>=0.2.17,<1.0.0",
|
||||
]
|
||||
discord = [
|
||||
"discord.py>=2.5.2,<3.0.0",
|
||||
]
|
||||
langsmith = [
|
||||
"langsmith>=0.1.0",
|
||||
]
|
||||
dev = [
|
||||
"pytest>=9.0.0,<10.0.0",
|
||||
"pytest-asyncio>=1.3.0,<2.0.0",
|
||||
"aiohttp>=3.9.0,<4.0.0",
|
||||
"pytest-cov>=6.0.0,<7.0.0",
|
||||
"ruff>=0.1.0",
|
||||
]
|
||||
|
||||
351
tests/agent/test_hook_composite.py
Normal file
351
tests/agent/test_hook_composite.py
Normal file
@ -0,0 +1,351 @@
|
||||
"""Tests for CompositeHook fan-out, error isolation, and integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
||||
|
||||
|
||||
def _ctx() -> AgentHookContext:
|
||||
return AgentHookContext(iteration=0, messages=[])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fan-out: every hook is called in order
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_fans_out_before_iteration():
|
||||
calls: list[str] = []
|
||||
|
||||
class H(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
calls.append(f"A:{context.iteration}")
|
||||
|
||||
class H2(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
calls.append(f"B:{context.iteration}")
|
||||
|
||||
hook = CompositeHook([H(), H2()])
|
||||
ctx = _ctx()
|
||||
await hook.before_iteration(ctx)
|
||||
assert calls == ["A:0", "B:0"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_fans_out_all_async_methods():
|
||||
"""Verify all async methods fan out to every hook."""
|
||||
events: list[str] = []
|
||||
|
||||
class RecordingHook(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
events.append("before_iteration")
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
events.append(f"on_stream:{delta}")
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
events.append(f"on_stream_end:{resuming}")
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
events.append("before_execute_tools")
|
||||
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
||||
events.append("after_iteration")
|
||||
|
||||
hook = CompositeHook([RecordingHook(), RecordingHook()])
|
||||
ctx = _ctx()
|
||||
|
||||
await hook.before_iteration(ctx)
|
||||
await hook.on_stream(ctx, "hi")
|
||||
await hook.on_stream_end(ctx, resuming=True)
|
||||
await hook.before_execute_tools(ctx)
|
||||
await hook.after_iteration(ctx)
|
||||
|
||||
assert events == [
|
||||
"before_iteration", "before_iteration",
|
||||
"on_stream:hi", "on_stream:hi",
|
||||
"on_stream_end:True", "on_stream_end:True",
|
||||
"before_execute_tools", "before_execute_tools",
|
||||
"after_iteration", "after_iteration",
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error isolation: one hook raises, others still run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_error_isolation_before_iteration():
|
||||
calls: list[str] = []
|
||||
|
||||
class Bad(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
class Good(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
calls.append("good")
|
||||
|
||||
hook = CompositeHook([Bad(), Good()])
|
||||
await hook.before_iteration(_ctx())
|
||||
assert calls == ["good"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_error_isolation_on_stream():
|
||||
calls: list[str] = []
|
||||
|
||||
class Bad(AgentHook):
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
raise RuntimeError("stream-boom")
|
||||
|
||||
class Good(AgentHook):
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
calls.append(delta)
|
||||
|
||||
hook = CompositeHook([Bad(), Good()])
|
||||
await hook.on_stream(_ctx(), "delta")
|
||||
assert calls == ["delta"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_error_isolation_all_async():
|
||||
"""Error isolation for on_stream_end, before_execute_tools, after_iteration."""
|
||||
calls: list[str] = []
|
||||
|
||||
class Bad(AgentHook):
|
||||
async def on_stream_end(self, context, *, resuming):
|
||||
raise RuntimeError("err")
|
||||
async def before_execute_tools(self, context):
|
||||
raise RuntimeError("err")
|
||||
async def after_iteration(self, context):
|
||||
raise RuntimeError("err")
|
||||
|
||||
class Good(AgentHook):
|
||||
async def on_stream_end(self, context, *, resuming):
|
||||
calls.append("on_stream_end")
|
||||
async def before_execute_tools(self, context):
|
||||
calls.append("before_execute_tools")
|
||||
async def after_iteration(self, context):
|
||||
calls.append("after_iteration")
|
||||
|
||||
hook = CompositeHook([Bad(), Good()])
|
||||
ctx = _ctx()
|
||||
await hook.on_stream_end(ctx, resuming=False)
|
||||
await hook.before_execute_tools(ctx)
|
||||
await hook.after_iteration(ctx)
|
||||
assert calls == ["on_stream_end", "before_execute_tools", "after_iteration"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# finalize_content: pipeline semantics (no error isolation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_composite_finalize_content_pipeline():
|
||||
class Upper(AgentHook):
|
||||
def finalize_content(self, context, content):
|
||||
return content.upper() if content else content
|
||||
|
||||
class Suffix(AgentHook):
|
||||
def finalize_content(self, context, content):
|
||||
return (content + "!") if content else content
|
||||
|
||||
hook = CompositeHook([Upper(), Suffix()])
|
||||
result = hook.finalize_content(_ctx(), "hello")
|
||||
assert result == "HELLO!"
|
||||
|
||||
|
||||
def test_composite_finalize_content_none_passthrough():
|
||||
hook = CompositeHook([AgentHook()])
|
||||
assert hook.finalize_content(_ctx(), None) is None
|
||||
|
||||
|
||||
def test_composite_finalize_content_ordering():
|
||||
"""First hook transforms first, result feeds second hook."""
|
||||
steps: list[str] = []
|
||||
|
||||
class H1(AgentHook):
|
||||
def finalize_content(self, context, content):
|
||||
steps.append(f"H1:{content}")
|
||||
return content.upper()
|
||||
|
||||
class H2(AgentHook):
|
||||
def finalize_content(self, context, content):
|
||||
steps.append(f"H2:{content}")
|
||||
return content + "!"
|
||||
|
||||
hook = CompositeHook([H1(), H2()])
|
||||
result = hook.finalize_content(_ctx(), "hi")
|
||||
assert result == "HI!"
|
||||
assert steps == ["H1:hi", "H2:HI"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# wants_streaming: any-semantics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_composite_wants_streaming_any_true():
|
||||
class No(AgentHook):
|
||||
def wants_streaming(self):
|
||||
return False
|
||||
|
||||
class Yes(AgentHook):
|
||||
def wants_streaming(self):
|
||||
return True
|
||||
|
||||
hook = CompositeHook([No(), Yes(), No()])
|
||||
assert hook.wants_streaming() is True
|
||||
|
||||
|
||||
def test_composite_wants_streaming_all_false():
|
||||
hook = CompositeHook([AgentHook(), AgentHook()])
|
||||
assert hook.wants_streaming() is False
|
||||
|
||||
|
||||
def test_composite_wants_streaming_empty():
|
||||
hook = CompositeHook([])
|
||||
assert hook.wants_streaming() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Empty hooks list: behaves like no-op AgentHook
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_empty_hooks_no_ops():
|
||||
hook = CompositeHook([])
|
||||
ctx = _ctx()
|
||||
await hook.before_iteration(ctx)
|
||||
await hook.on_stream(ctx, "delta")
|
||||
await hook.on_stream_end(ctx, resuming=False)
|
||||
await hook.before_execute_tools(ctx)
|
||||
await hook.after_iteration(ctx)
|
||||
assert hook.finalize_content(ctx, "test") == "test"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: AgentLoop with extra hooks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_loop(tmp_path, hooks=None):
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.generation.max_tokens = 4096
|
||||
|
||||
with patch("nanobot.agent.loop.ContextBuilder"), \
|
||||
patch("nanobot.agent.loop.SessionManager"), \
|
||||
patch("nanobot.agent.loop.SubagentManager") as mock_sub_mgr, \
|
||||
patch("nanobot.agent.loop.MemoryConsolidator"):
|
||||
mock_sub_mgr.return_value.cancel_by_session = AsyncMock(return_value=0)
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, hooks=hooks,
|
||||
)
|
||||
return loop
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_loop_extra_hook_receives_calls(tmp_path):
|
||||
"""Extra hook passed to AgentLoop is called alongside core LoopHook."""
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
events: list[str] = []
|
||||
|
||||
class TrackingHook(AgentHook):
|
||||
async def before_iteration(self, context):
|
||||
events.append(f"before_iter:{context.iteration}")
|
||||
|
||||
async def after_iteration(self, context):
|
||||
events.append(f"after_iter:{context.iteration}")
|
||||
|
||||
loop = _make_loop(tmp_path, hooks=[TrackingHook()])
|
||||
loop.provider.chat_with_retry = AsyncMock(
|
||||
return_value=LLMResponse(content="done", tool_calls=[], usage={})
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
content, tools_used, messages = await loop._run_agent_loop(
|
||||
[{"role": "user", "content": "hi"}]
|
||||
)
|
||||
|
||||
assert content == "done"
|
||||
assert "before_iter:0" in events
|
||||
assert "after_iter:0" in events
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_loop_extra_hook_error_isolation(tmp_path):
|
||||
"""A faulty extra hook does not crash the agent loop."""
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
class BadHook(AgentHook):
|
||||
async def before_iteration(self, context):
|
||||
raise RuntimeError("I am broken")
|
||||
|
||||
loop = _make_loop(tmp_path, hooks=[BadHook()])
|
||||
loop.provider.chat_with_retry = AsyncMock(
|
||||
return_value=LLMResponse(content="still works", tool_calls=[], usage={})
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
content, _, _ = await loop._run_agent_loop(
|
||||
[{"role": "user", "content": "hi"}]
|
||||
)
|
||||
|
||||
assert content == "still works"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_loop_extra_hooks_do_not_swallow_loop_hook_errors(tmp_path):
|
||||
"""Extra hooks must not change the core LoopHook failure behavior."""
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
loop = _make_loop(tmp_path, hooks=[AgentHook()])
|
||||
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="working",
|
||||
tool_calls=[ToolCallRequest(id="c1", name="list_dir", arguments={"path": "."})],
|
||||
usage={},
|
||||
))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.tools.execute = AsyncMock(return_value="ok")
|
||||
|
||||
async def bad_progress(*args, **kwargs):
|
||||
raise RuntimeError("progress failed")
|
||||
|
||||
with pytest.raises(RuntimeError, match="progress failed"):
|
||||
await loop._run_agent_loop([], on_progress=bad_progress)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_loop_no_hooks_backward_compat(tmp_path):
|
||||
"""Without hooks param, behavior is identical to before."""
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
loop = _make_loop(tmp_path)
|
||||
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="working",
|
||||
tool_calls=[ToolCallRequest(id="c1", name="list_dir", arguments={"path": "."})],
|
||||
))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.tools.execute = AsyncMock(return_value="ok")
|
||||
loop.max_iterations = 2
|
||||
|
||||
content, tools_used, _ = await loop._run_agent_loop([])
|
||||
assert content == (
|
||||
"I reached the maximum number of tool call iterations (2) "
|
||||
"without completing the task. You can try breaking the task into smaller steps."
|
||||
)
|
||||
assert tools_used == ["list_dir", "list_dir"]
|
||||
@ -333,3 +333,82 @@ async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, mon
|
||||
args = mgr._announce_result.await_args.args
|
||||
assert args[3] == "Task completed but no final response was generated."
|
||||
assert args[5] == "ok"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_accumulates_usage_and_preserves_cached_tokens():
|
||||
"""Runner should accumulate prompt/completion tokens across iterations
|
||||
and preserve cached_tokens from provider responses."""
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def chat_with_retry(*, messages, **kwargs):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
return LLMResponse(
|
||||
content="thinking",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})],
|
||||
usage={"prompt_tokens": 100, "completion_tokens": 10, "cached_tokens": 80},
|
||||
)
|
||||
return LLMResponse(
|
||||
content="done",
|
||||
tool_calls=[],
|
||||
usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150},
|
||||
)
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
tools.execute = AsyncMock(return_value="file content")
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "do task"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=3,
|
||||
))
|
||||
|
||||
# Usage should be accumulated across iterations
|
||||
assert result.usage["prompt_tokens"] == 300 # 100 + 200
|
||||
assert result.usage["completion_tokens"] == 30 # 10 + 20
|
||||
assert result.usage["cached_tokens"] == 230 # 80 + 150
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_passes_cached_tokens_to_hook_context():
|
||||
"""Hook context.usage should contain cached_tokens."""
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
captured_usage: list[dict] = []
|
||||
|
||||
class UsageHook(AgentHook):
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
||||
captured_usage.append(dict(context.usage))
|
||||
|
||||
async def chat_with_retry(**kwargs):
|
||||
return LLMResponse(
|
||||
content="done",
|
||||
tool_calls=[],
|
||||
usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150},
|
||||
)
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
await runner.run(AgentRunSpec(
|
||||
initial_messages=[],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
hook=UsageHook(),
|
||||
))
|
||||
|
||||
assert len(captured_usage) == 1
|
||||
assert captured_usage[0]["cached_tokens"] == 150
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -116,6 +117,43 @@ class TestDispatch:
|
||||
out = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||
assert out.content == "hi"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_streaming_preserves_message_metadata(self):
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
loop, bus = _make_loop()
|
||||
msg = InboundMessage(
|
||||
channel="matrix",
|
||||
sender_id="u1",
|
||||
chat_id="!room:matrix.org",
|
||||
content="hello",
|
||||
metadata={
|
||||
"_wants_stream": True,
|
||||
"thread_root_event_id": "$root1",
|
||||
"thread_reply_to_event_id": "$reply1",
|
||||
},
|
||||
)
|
||||
|
||||
async def fake_process(_msg, *, on_stream=None, on_stream_end=None, **kwargs):
|
||||
assert on_stream is not None
|
||||
assert on_stream_end is not None
|
||||
await on_stream("hi")
|
||||
await on_stream_end(resuming=False)
|
||||
return None
|
||||
|
||||
loop._process_message = fake_process
|
||||
|
||||
await loop._dispatch(msg)
|
||||
first = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||
second = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
|
||||
|
||||
assert first.metadata["thread_root_event_id"] == "$root1"
|
||||
assert first.metadata["thread_reply_to_event_id"] == "$reply1"
|
||||
assert first.metadata["_stream_delta"] is True
|
||||
assert second.metadata["thread_root_event_id"] == "$root1"
|
||||
assert second.metadata["thread_reply_to_event_id"] == "$reply1"
|
||||
assert second.metadata["_stream_end"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_processing_lock_serializes(self):
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
@ -222,6 +260,39 @@ class TestSubagentCancellation:
|
||||
assert assistant_messages[0]["reasoning_content"] == "hidden reasoning"
|
||||
assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subagent_exec_tool_not_registered_when_disabled(self, tmp_path):
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.config.schema import ExecToolConfig
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
mgr = SubagentManager(
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
bus=bus,
|
||||
exec_config=ExecToolConfig(enable=False),
|
||||
)
|
||||
mgr._announce_result = AsyncMock()
|
||||
|
||||
async def fake_run(spec):
|
||||
assert spec.tools.get("exec") is None
|
||||
return SimpleNamespace(
|
||||
stop_reason="done",
|
||||
final_content="done",
|
||||
error=None,
|
||||
tool_events=[],
|
||||
)
|
||||
|
||||
mgr.runner.run = AsyncMock(side_effect=fake_run)
|
||||
|
||||
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||
|
||||
mgr.runner.run.assert_awaited_once()
|
||||
mgr._announce_result.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subagent_announces_error_when_tool_execution_fails(self, monkeypatch, tmp_path):
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
|
||||
676
tests/channels/test_discord_channel.py
Normal file
676
tests/channels/test_discord_channel.py
Normal file
@ -0,0 +1,676 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
discord = pytest.importorskip("discord")
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.discord import DiscordBotClient, DiscordChannel, DiscordConfig
|
||||
from nanobot.command.builtin import build_help_text
|
||||
|
||||
|
||||
# Minimal Discord client test double used to control startup/readiness behavior.
|
||||
class _FakeDiscordClient:
|
||||
instances: list["_FakeDiscordClient"] = []
|
||||
start_error: Exception | None = None
|
||||
|
||||
def __init__(self, owner, *, intents) -> None:
|
||||
self.owner = owner
|
||||
self.intents = intents
|
||||
self.closed = False
|
||||
self.ready = True
|
||||
self.channels: dict[int, object] = {}
|
||||
self.user = SimpleNamespace(id=999)
|
||||
self.__class__.instances.append(self)
|
||||
|
||||
async def start(self, token: str) -> None:
|
||||
self.token = token
|
||||
if self.__class__.start_error is not None:
|
||||
raise self.__class__.start_error
|
||||
|
||||
async def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
return self.closed
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
return self.ready
|
||||
|
||||
def get_channel(self, channel_id: int):
|
||||
return self.channels.get(channel_id)
|
||||
|
||||
async def send_outbound(self, msg: OutboundMessage) -> None:
|
||||
channel = self.get_channel(int(msg.chat_id))
|
||||
if channel is None:
|
||||
return
|
||||
await channel.send(content=msg.content)
|
||||
|
||||
|
||||
class _FakeAttachment:
|
||||
# Attachment double that can simulate successful or failing save() calls.
|
||||
def __init__(self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False) -> None:
|
||||
self.id = attachment_id
|
||||
self.filename = filename
|
||||
self.size = size
|
||||
self._fail = fail
|
||||
|
||||
async def save(self, path: str | Path) -> None:
|
||||
if self._fail:
|
||||
raise RuntimeError("save failed")
|
||||
Path(path).write_bytes(b"attachment")
|
||||
|
||||
|
||||
class _FakePartialMessage:
|
||||
# Lightweight stand-in for Discord partial message references used in replies.
|
||||
def __init__(self, message_id: int) -> None:
|
||||
self.id = message_id
|
||||
|
||||
|
||||
class _FakeChannel:
|
||||
# Channel double that records outbound payloads and typing activity.
|
||||
def __init__(self, channel_id: int = 123) -> None:
|
||||
self.id = channel_id
|
||||
self.sent_payloads: list[dict] = []
|
||||
self.trigger_typing_calls = 0
|
||||
self.typing_enter_hook = None
|
||||
|
||||
async def send(self, **kwargs) -> None:
|
||||
payload = dict(kwargs)
|
||||
if "file" in payload:
|
||||
payload["file_name"] = payload["file"].filename
|
||||
del payload["file"]
|
||||
self.sent_payloads.append(payload)
|
||||
|
||||
def get_partial_message(self, message_id: int) -> _FakePartialMessage:
|
||||
return _FakePartialMessage(message_id)
|
||||
|
||||
def typing(self):
|
||||
channel = self
|
||||
|
||||
class _TypingContext:
|
||||
async def __aenter__(self):
|
||||
channel.trigger_typing_calls += 1
|
||||
if channel.typing_enter_hook is not None:
|
||||
await channel.typing_enter_hook()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
return _TypingContext()
|
||||
|
||||
|
||||
class _FakeInteractionResponse:
|
||||
def __init__(self) -> None:
|
||||
self.messages: list[dict] = []
|
||||
self._done = False
|
||||
|
||||
async def send_message(self, content: str, *, ephemeral: bool = False) -> None:
|
||||
self.messages.append({"content": content, "ephemeral": ephemeral})
|
||||
self._done = True
|
||||
|
||||
def is_done(self) -> bool:
|
||||
return self._done
|
||||
|
||||
|
||||
def _make_interaction(
|
||||
*,
|
||||
user_id: int = 123,
|
||||
channel_id: int | None = 456,
|
||||
guild_id: int | None = None,
|
||||
interaction_id: int = 999,
|
||||
):
|
||||
return SimpleNamespace(
|
||||
user=SimpleNamespace(id=user_id),
|
||||
channel_id=channel_id,
|
||||
guild_id=guild_id,
|
||||
id=interaction_id,
|
||||
command=SimpleNamespace(qualified_name="new"),
|
||||
response=_FakeInteractionResponse(),
|
||||
)
|
||||
|
||||
|
||||
def _make_message(
|
||||
*,
|
||||
author_id: int = 123,
|
||||
author_bot: bool = False,
|
||||
channel_id: int = 456,
|
||||
message_id: int = 789,
|
||||
content: str = "hello",
|
||||
guild_id: int | None = None,
|
||||
mentions: list[object] | None = None,
|
||||
attachments: list[object] | None = None,
|
||||
reply_to: int | None = None,
|
||||
):
|
||||
# Factory for incoming Discord message objects with optional guild/reply/attachments.
|
||||
guild = SimpleNamespace(id=guild_id) if guild_id is not None else None
|
||||
reference = SimpleNamespace(message_id=reply_to) if reply_to is not None else None
|
||||
return SimpleNamespace(
|
||||
author=SimpleNamespace(id=author_id, bot=author_bot),
|
||||
channel=_FakeChannel(channel_id),
|
||||
content=content,
|
||||
guild=guild,
|
||||
mentions=mentions or [],
|
||||
attachments=attachments or [],
|
||||
reference=reference,
|
||||
id=message_id,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_returns_when_token_missing() -> None:
|
||||
# If no token is configured, startup should no-op and leave channel stopped.
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
|
||||
await channel.start()
|
||||
|
||||
assert channel.is_running is False
|
||||
assert channel._client is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_returns_when_discord_dependency_missing(monkeypatch) -> None:
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, token="token", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.channels.discord.DISCORD_AVAILABLE", False)
|
||||
|
||||
await channel.start()
|
||||
|
||||
assert channel.is_running is False
|
||||
assert channel._client is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_handles_client_construction_failure(monkeypatch) -> None:
|
||||
# Construction errors from the Discord client should be swallowed and keep state clean.
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, token="token", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
|
||||
def _boom(owner, *, intents):
|
||||
raise RuntimeError("bad client")
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _boom)
|
||||
|
||||
await channel.start()
|
||||
|
||||
assert channel.is_running is False
|
||||
assert channel._client is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_handles_client_start_failure(monkeypatch) -> None:
|
||||
# If client.start fails, the partially created client should be closed and detached.
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, token="token", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
|
||||
_FakeDiscordClient.instances.clear()
|
||||
_FakeDiscordClient.start_error = RuntimeError("connect failed")
|
||||
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
|
||||
|
||||
await channel.start()
|
||||
|
||||
assert channel.is_running is False
|
||||
assert channel._client is None
|
||||
assert _FakeDiscordClient.instances[0].intents.value == channel.config.intents
|
||||
assert _FakeDiscordClient.instances[0].closed is True
|
||||
|
||||
_FakeDiscordClient.start_error = None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_is_safe_after_partial_start(monkeypatch) -> None:
|
||||
# stop() should close/discard the client even when startup was only partially completed.
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, token="token", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
client = _FakeDiscordClient(channel, intents=None)
|
||||
channel._client = client
|
||||
channel._running = True
|
||||
|
||||
await channel.stop()
|
||||
|
||||
assert channel.is_running is False
|
||||
assert client.closed is True
|
||||
assert channel._client is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_ignores_bot_messages() -> None:
|
||||
# Incoming bot-authored messages must be ignored to prevent feedback loops.
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
channel._handle_message = lambda **kwargs: handled.append(kwargs) # type: ignore[method-assign]
|
||||
|
||||
await channel._on_message(_make_message(author_bot=True))
|
||||
|
||||
assert handled == []
|
||||
|
||||
# If inbound handling raises, typing should be stopped for that channel.
|
||||
async def fail_handle(**kwargs) -> None:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
channel._handle_message = fail_handle # type: ignore[method-assign]
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
await channel._on_message(_make_message(author_id=123, channel_id=456))
|
||||
|
||||
assert channel._typing_tasks == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_accepts_allowlisted_dm() -> None:
|
||||
# Allowed direct messages should be forwarded with normalized metadata.
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
|
||||
await channel._on_message(_make_message(author_id=123, channel_id=456, message_id=789))
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["chat_id"] == "456"
|
||||
assert handled[0]["metadata"] == {"message_id": "789", "guild_id": None, "reply_to": None}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_ignores_unmentioned_guild_message() -> None:
|
||||
# With mention-only group policy, guild messages without a bot mention are dropped.
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._bot_user_id = "999"
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
|
||||
await channel._on_message(_make_message(guild_id=1, content="hello everyone"))
|
||||
|
||||
assert handled == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_accepts_mentioned_guild_message() -> None:
|
||||
# Mentioned guild messages should be accepted and preserve reply threading metadata.
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, allow_from=["*"], group_policy="mention"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._bot_user_id = "999"
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
|
||||
await channel._on_message(
|
||||
_make_message(
|
||||
guild_id=1,
|
||||
content="<@999> hello",
|
||||
mentions=[SimpleNamespace(id=999)],
|
||||
reply_to=321,
|
||||
)
|
||||
)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["metadata"]["reply_to"] == "321"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_downloads_attachments(tmp_path, monkeypatch) -> None:
|
||||
# Attachment downloads should be saved and referenced in forwarded content/media.
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path)
|
||||
|
||||
await channel._on_message(
|
||||
_make_message(
|
||||
attachments=[_FakeAttachment(12, "photo.png")],
|
||||
content="see file",
|
||||
)
|
||||
)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["media"] == [str(tmp_path / "12_photo.png")]
|
||||
assert "[attachment:" in handled[0]["content"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_marks_failed_attachment_download(tmp_path, monkeypatch) -> None:
|
||||
# Failed attachment downloads should emit a readable placeholder and no media path.
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
monkeypatch.setattr("nanobot.channels.discord.get_media_dir", lambda _name: tmp_path)
|
||||
|
||||
await channel._on_message(
|
||||
_make_message(
|
||||
attachments=[_FakeAttachment(12, "photo.png", fail=True)],
|
||||
content="",
|
||||
)
|
||||
)
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["media"] == []
|
||||
assert handled[0]["content"] == "[attachment: photo.png - download failed]"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_warns_when_client_not_ready() -> None:
|
||||
# Sending without a running/ready client should be a safe no-op.
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
|
||||
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
|
||||
|
||||
assert channel._typing_tasks == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_skips_when_channel_not_cached() -> None:
|
||||
# Outbound sends should be skipped when the destination channel is not resolvable.
|
||||
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
client = DiscordBotClient(owner, intents=discord.Intents.none())
|
||||
fetch_calls: list[int] = []
|
||||
|
||||
async def fetch_channel(channel_id: int):
|
||||
fetch_calls.append(channel_id)
|
||||
raise RuntimeError("not found")
|
||||
|
||||
client.fetch_channel = fetch_channel # type: ignore[method-assign]
|
||||
|
||||
await client.send_outbound(OutboundMessage(channel="discord", chat_id="123", content="hello"))
|
||||
|
||||
assert client.get_channel(123) is None
|
||||
assert fetch_calls == [123]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_fetches_channel_when_not_cached() -> None:
|
||||
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
client = DiscordBotClient(owner, intents=discord.Intents.none())
|
||||
target = _FakeChannel(channel_id=123)
|
||||
|
||||
async def fetch_channel(channel_id: int):
|
||||
return target if channel_id == 123 else None
|
||||
|
||||
client.fetch_channel = fetch_channel # type: ignore[method-assign]
|
||||
|
||||
await client.send_outbound(OutboundMessage(channel="discord", chat_id="123", content="hello"))
|
||||
|
||||
assert target.sent_payloads == [{"content": "hello"}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slash_new_forwards_when_user_is_allowlisted() -> None:
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
client = DiscordBotClient(channel, intents=discord.Intents.none())
|
||||
interaction = _make_interaction(user_id=123, channel_id=456, interaction_id=321)
|
||||
|
||||
new_cmd = client.tree.get_command("new")
|
||||
assert new_cmd is not None
|
||||
await new_cmd.callback(interaction)
|
||||
|
||||
assert interaction.response.messages == [
|
||||
{"content": "Processing /new...", "ephemeral": True}
|
||||
]
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["content"] == "/new"
|
||||
assert handled[0]["sender_id"] == "123"
|
||||
assert handled[0]["chat_id"] == "456"
|
||||
assert handled[0]["metadata"]["interaction_id"] == "321"
|
||||
assert handled[0]["metadata"]["is_slash_command"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slash_new_is_blocked_for_disallowed_user() -> None:
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["999"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
client = DiscordBotClient(channel, intents=discord.Intents.none())
|
||||
interaction = _make_interaction(user_id=123, channel_id=456)
|
||||
|
||||
new_cmd = client.tree.get_command("new")
|
||||
assert new_cmd is not None
|
||||
await new_cmd.callback(interaction)
|
||||
|
||||
assert interaction.response.messages == [
|
||||
{"content": "You are not allowed to use this bot.", "ephemeral": True}
|
||||
]
|
||||
assert handled == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize("slash_name", ["stop", "restart", "status"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_slash_commands_forward_via_handle_message(slash_name: str) -> None:
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
client = DiscordBotClient(channel, intents=discord.Intents.none())
|
||||
interaction = _make_interaction()
|
||||
interaction.command.qualified_name = slash_name
|
||||
|
||||
cmd = client.tree.get_command(slash_name)
|
||||
assert cmd is not None
|
||||
await cmd.callback(interaction)
|
||||
|
||||
assert interaction.response.messages == [
|
||||
{"content": f"Processing /{slash_name}...", "ephemeral": True}
|
||||
]
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["content"] == f"/{slash_name}"
|
||||
assert handled[0]["metadata"]["is_slash_command"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slash_help_returns_ephemeral_help_text() -> None:
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
client = DiscordBotClient(channel, intents=discord.Intents.none())
|
||||
interaction = _make_interaction()
|
||||
interaction.command.qualified_name = "help"
|
||||
|
||||
help_cmd = client.tree.get_command("help")
|
||||
assert help_cmd is not None
|
||||
await help_cmd.callback(interaction)
|
||||
|
||||
assert interaction.response.messages == [
|
||||
{"content": build_help_text(), "ephemeral": True}
|
||||
]
|
||||
assert handled == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_send_outbound_chunks_text_replies_and_uploads_files(tmp_path) -> None:
|
||||
# Outbound payloads should upload files, attach reply references, and chunk long text.
|
||||
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
client = DiscordBotClient(owner, intents=discord.Intents.none())
|
||||
target = _FakeChannel(channel_id=123)
|
||||
client.get_channel = lambda channel_id: target if channel_id == 123 else None # type: ignore[method-assign]
|
||||
|
||||
file_path = tmp_path / "demo.txt"
|
||||
file_path.write_text("hi")
|
||||
|
||||
await client.send_outbound(
|
||||
OutboundMessage(
|
||||
channel="discord",
|
||||
chat_id="123",
|
||||
content="a" * 2100,
|
||||
reply_to="55",
|
||||
media=[str(file_path)],
|
||||
)
|
||||
)
|
||||
|
||||
assert len(target.sent_payloads) == 3
|
||||
assert target.sent_payloads[0]["file_name"] == "demo.txt"
|
||||
assert target.sent_payloads[0]["reference"].id == 55
|
||||
assert target.sent_payloads[1]["content"] == "a" * 2000
|
||||
assert target.sent_payloads[2]["content"] == "a" * 100
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_send_outbound_reports_failed_attachments_when_no_text(tmp_path) -> None:
|
||||
# If all attachment sends fail and no text exists, emit a failure placeholder message.
|
||||
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
client = DiscordBotClient(owner, intents=discord.Intents.none())
|
||||
target = _FakeChannel(channel_id=123)
|
||||
client.get_channel = lambda channel_id: target if channel_id == 123 else None # type: ignore[method-assign]
|
||||
|
||||
missing_file = tmp_path / "missing.txt"
|
||||
|
||||
await client.send_outbound(
|
||||
OutboundMessage(
|
||||
channel="discord",
|
||||
chat_id="123",
|
||||
content="",
|
||||
media=[str(missing_file)],
|
||||
)
|
||||
)
|
||||
|
||||
assert target.sent_payloads == [{"content": "[attachment: missing.txt - send failed]"}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_stops_typing_after_send() -> None:
|
||||
# Active typing indicators should be cancelled/cleared after a successful send.
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
client = _FakeDiscordClient(channel, intents=None)
|
||||
channel._client = client
|
||||
channel._running = True
|
||||
|
||||
start = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def slow_typing() -> None:
|
||||
start.set()
|
||||
await release.wait()
|
||||
|
||||
typing_channel = _FakeChannel(channel_id=123)
|
||||
typing_channel.typing_enter_hook = slow_typing
|
||||
|
||||
await channel._start_typing(typing_channel)
|
||||
await start.wait()
|
||||
|
||||
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
|
||||
release.set()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert channel._typing_tasks == {}
|
||||
|
||||
# Progress messages should keep typing active until a final (non-progress) send.
|
||||
start = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def slow_typing_progress() -> None:
|
||||
start.set()
|
||||
await release.wait()
|
||||
|
||||
typing_channel = _FakeChannel(channel_id=123)
|
||||
typing_channel.typing_enter_hook = slow_typing_progress
|
||||
|
||||
await channel._start_typing(typing_channel)
|
||||
await start.wait()
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="discord",
|
||||
chat_id="123",
|
||||
content="progress",
|
||||
metadata={"_progress": True},
|
||||
)
|
||||
)
|
||||
|
||||
assert "123" in channel._typing_tasks
|
||||
|
||||
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="final"))
|
||||
release.set()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert channel._typing_tasks == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_typing_uses_typing_context_when_trigger_typing_missing() -> None:
|
||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||
channel._running = True
|
||||
|
||||
entered = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
|
||||
class _TypingCtx:
|
||||
async def __aenter__(self):
|
||||
entered.set()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
class _NoTriggerChannel:
|
||||
def __init__(self, channel_id: int = 123) -> None:
|
||||
self.id = channel_id
|
||||
|
||||
def typing(self):
|
||||
async def _waiter():
|
||||
await release.wait()
|
||||
# Hold the loop so task remains active until explicitly stopped.
|
||||
class _Ctx(_TypingCtx):
|
||||
async def __aenter__(self):
|
||||
await super().__aenter__()
|
||||
await _waiter()
|
||||
return _Ctx()
|
||||
|
||||
typing_channel = _NoTriggerChannel(channel_id=123)
|
||||
await channel._start_typing(typing_channel) # type: ignore[arg-type]
|
||||
await entered.wait()
|
||||
|
||||
assert "123" in channel._typing_tasks
|
||||
|
||||
await channel._stop_typing("123")
|
||||
release.set()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert channel._typing_tasks == {}
|
||||
@ -3,6 +3,9 @@ from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from nio import RoomSendResponse
|
||||
|
||||
from nanobot.channels.matrix import _build_matrix_text_content
|
||||
|
||||
# Check optional matrix dependencies before importing
|
||||
try:
|
||||
@ -65,6 +68,7 @@ class _FakeAsyncClient:
|
||||
self.raise_on_send = False
|
||||
self.raise_on_typing = False
|
||||
self.raise_on_upload = False
|
||||
self.room_send_response: RoomSendResponse | None = RoomSendResponse(event_id="", room_id="")
|
||||
|
||||
def add_event_callback(self, callback, event_type) -> None:
|
||||
self.callbacks.append((callback, event_type))
|
||||
@ -87,7 +91,7 @@ class _FakeAsyncClient:
|
||||
message_type: str,
|
||||
content: dict[str, object],
|
||||
ignore_unverified_devices: object = _ROOM_SEND_UNSET,
|
||||
) -> None:
|
||||
) -> RoomSendResponse:
|
||||
call: dict[str, object] = {
|
||||
"room_id": room_id,
|
||||
"message_type": message_type,
|
||||
@ -98,6 +102,7 @@ class _FakeAsyncClient:
|
||||
self.room_send_calls.append(call)
|
||||
if self.raise_on_send:
|
||||
raise RuntimeError("send failed")
|
||||
return self.room_send_response
|
||||
|
||||
async def room_typing(
|
||||
self,
|
||||
@ -520,6 +525,7 @@ async def test_on_message_room_mention_requires_opt_in() -> None:
|
||||
source={"content": {"m.mentions": {"room": True}}},
|
||||
)
|
||||
|
||||
channel.config.allow_room_mentions = False
|
||||
await channel._on_message(room, room_mention_event)
|
||||
assert handled == []
|
||||
assert client.typing_calls == []
|
||||
@ -1322,3 +1328,302 @@ async def test_send_keeps_plaintext_only_for_plain_text() -> None:
|
||||
"body": text,
|
||||
"m.mentions": {},
|
||||
}
|
||||
|
||||
|
||||
def test_build_matrix_text_content_basic_text() -> None:
|
||||
"""Test basic text content without HTML formatting."""
|
||||
result = _build_matrix_text_content("Hello, World!")
|
||||
expected = {
|
||||
"msgtype": "m.text",
|
||||
"body": "Hello, World!",
|
||||
"m.mentions": {}
|
||||
}
|
||||
assert expected == result
|
||||
|
||||
|
||||
def test_build_matrix_text_content_with_markdown() -> None:
|
||||
"""Test text content with markdown that renders to HTML."""
|
||||
text = "*Hello* **World**"
|
||||
result = _build_matrix_text_content(text)
|
||||
assert "msgtype" in result
|
||||
assert "body" in result
|
||||
assert result["body"] == text
|
||||
assert "format" in result
|
||||
assert result["format"] == "org.matrix.custom.html"
|
||||
assert "formatted_body" in result
|
||||
assert isinstance(result["formatted_body"], str)
|
||||
assert len(result["formatted_body"]) > 0
|
||||
|
||||
|
||||
def test_build_matrix_text_content_with_event_id() -> None:
|
||||
"""Test text content with event_id for message replacement."""
|
||||
event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
|
||||
result = _build_matrix_text_content("Updated message", event_id)
|
||||
assert "msgtype" in result
|
||||
assert "body" in result
|
||||
assert result["m.new_content"]
|
||||
assert result["m.new_content"]["body"] == "Updated message"
|
||||
assert result["m.relates_to"]["rel_type"] == "m.replace"
|
||||
assert result["m.relates_to"]["event_id"] == event_id
|
||||
|
||||
|
||||
def test_build_matrix_text_content_with_event_id_preserves_thread_relation() -> None:
|
||||
"""Thread relations for edits should stay inside m.new_content."""
|
||||
relates_to = {
|
||||
"rel_type": "m.thread",
|
||||
"event_id": "$root1",
|
||||
"m.in_reply_to": {"event_id": "$reply1"},
|
||||
"is_falling_back": True,
|
||||
}
|
||||
result = _build_matrix_text_content("Updated message", "event-1", relates_to)
|
||||
|
||||
assert result["m.relates_to"] == {
|
||||
"rel_type": "m.replace",
|
||||
"event_id": "event-1",
|
||||
}
|
||||
assert result["m.new_content"]["m.relates_to"] == relates_to
|
||||
|
||||
|
||||
def test_build_matrix_text_content_no_event_id() -> None:
|
||||
"""Test that when event_id is not provided, no extra properties are added."""
|
||||
result = _build_matrix_text_content("Regular message")
|
||||
|
||||
# Basic required properties should be present
|
||||
assert "msgtype" in result
|
||||
assert "body" in result
|
||||
assert result["body"] == "Regular message"
|
||||
|
||||
# Extra properties for replacement should NOT be present
|
||||
assert "m.relates_to" not in result
|
||||
assert "m.new_content" not in result
|
||||
assert "format" not in result
|
||||
assert "formatted_body" not in result
|
||||
|
||||
|
||||
def test_build_matrix_text_content_plain_text_no_html() -> None:
|
||||
"""Test plain text that should not include HTML formatting."""
|
||||
result = _build_matrix_text_content("Simple plain text")
|
||||
assert "msgtype" in result
|
||||
assert "body" in result
|
||||
assert "format" not in result
|
||||
assert "formatted_body" not in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_room_content_returns_room_send_response():
|
||||
"""Test that _send_room_content returns the response from client.room_send."""
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel = MatrixChannel(_make_config(), MessageBus())
|
||||
channel.client = client
|
||||
|
||||
room_id = "!test_room:matrix.org"
|
||||
content = {"msgtype": "m.text", "body": "Hello World"}
|
||||
|
||||
result = await channel._send_room_content(room_id, content)
|
||||
|
||||
assert result is client.room_send_response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_creates_stream_buffer_and_sends_initial_message() -> None:
|
||||
channel = MatrixChannel(_make_config(), MessageBus())
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel.client = client
|
||||
client.room_send_response.event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
|
||||
|
||||
await channel.send_delta("!room:matrix.org", "Hello")
|
||||
|
||||
assert "!room:matrix.org" in channel._stream_bufs
|
||||
buf = channel._stream_bufs["!room:matrix.org"]
|
||||
assert buf.text == "Hello"
|
||||
assert buf.event_id == "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
|
||||
assert len(client.room_send_calls) == 1
|
||||
assert client.room_send_calls[0]["content"]["body"] == "Hello"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_appends_without_sending_before_edit_interval(monkeypatch) -> None:
|
||||
channel = MatrixChannel(_make_config(), MessageBus())
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel.client = client
|
||||
client.room_send_response.event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
|
||||
|
||||
now = 100.0
|
||||
monkeypatch.setattr(channel, "monotonic_time", lambda: now)
|
||||
|
||||
await channel.send_delta("!room:matrix.org", "Hello")
|
||||
assert len(client.room_send_calls) == 1
|
||||
|
||||
await channel.send_delta("!room:matrix.org", " world")
|
||||
assert len(client.room_send_calls) == 1
|
||||
|
||||
buf = channel._stream_bufs["!room:matrix.org"]
|
||||
assert buf.text == "Hello world"
|
||||
assert buf.event_id == "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_edits_again_after_interval(monkeypatch) -> None:
|
||||
channel = MatrixChannel(_make_config(), MessageBus())
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel.client = client
|
||||
client.room_send_response.event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
|
||||
|
||||
times = [100.0, 102.0, 104.0, 106.0, 108.0]
|
||||
times.reverse()
|
||||
monkeypatch.setattr(channel, "monotonic_time", lambda: times and times.pop())
|
||||
|
||||
await channel.send_delta("!room:matrix.org", "Hello")
|
||||
await channel.send_delta("!room:matrix.org", " world")
|
||||
|
||||
assert len(client.room_send_calls) == 2
|
||||
first_content = client.room_send_calls[0]["content"]
|
||||
second_content = client.room_send_calls[1]["content"]
|
||||
|
||||
assert "body" in first_content
|
||||
assert first_content["body"] == "Hello"
|
||||
assert "m.relates_to" not in first_content
|
||||
|
||||
assert "body" in second_content
|
||||
assert "m.relates_to" in second_content
|
||||
assert second_content["body"] == "Hello world"
|
||||
assert second_content["m.relates_to"] == {
|
||||
"rel_type": "m.replace",
|
||||
"event_id": "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_stream_end_replaces_existing_message() -> None:
|
||||
channel = MatrixChannel(_make_config(), MessageBus())
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel.client = client
|
||||
|
||||
channel._stream_bufs["!room:matrix.org"] = matrix_module._StreamBuf(
|
||||
text="Final text",
|
||||
event_id="event-1",
|
||||
last_edit=100.0,
|
||||
)
|
||||
|
||||
await channel.send_delta("!room:matrix.org", "", {"_stream_end": True})
|
||||
|
||||
assert "!room:matrix.org" not in channel._stream_bufs
|
||||
assert client.typing_calls[-1] == ("!room:matrix.org", False, TYPING_NOTICE_TIMEOUT_MS)
|
||||
assert len(client.room_send_calls) == 1
|
||||
assert client.room_send_calls[0]["content"]["body"] == "Final text"
|
||||
assert client.room_send_calls[0]["content"]["m.relates_to"] == {
|
||||
"rel_type": "m.replace",
|
||||
"event_id": "event-1",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_starts_threaded_stream_inside_thread() -> None:
|
||||
channel = MatrixChannel(_make_config(), MessageBus())
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel.client = client
|
||||
client.room_send_response.event_id = "event-1"
|
||||
|
||||
metadata = {
|
||||
"thread_root_event_id": "$root1",
|
||||
"thread_reply_to_event_id": "$reply1",
|
||||
}
|
||||
await channel.send_delta("!room:matrix.org", "Hello", metadata)
|
||||
|
||||
assert client.room_send_calls[0]["content"]["m.relates_to"] == {
|
||||
"rel_type": "m.thread",
|
||||
"event_id": "$root1",
|
||||
"m.in_reply_to": {"event_id": "$reply1"},
|
||||
"is_falling_back": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_threaded_edit_keeps_replace_and_thread_relation(monkeypatch) -> None:
|
||||
channel = MatrixChannel(_make_config(), MessageBus())
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel.client = client
|
||||
client.room_send_response.event_id = "event-1"
|
||||
|
||||
times = [100.0, 102.0, 104.0]
|
||||
times.reverse()
|
||||
monkeypatch.setattr(channel, "monotonic_time", lambda: times and times.pop())
|
||||
|
||||
metadata = {
|
||||
"thread_root_event_id": "$root1",
|
||||
"thread_reply_to_event_id": "$reply1",
|
||||
}
|
||||
await channel.send_delta("!room:matrix.org", "Hello", metadata)
|
||||
await channel.send_delta("!room:matrix.org", " world", metadata)
|
||||
await channel.send_delta("!room:matrix.org", "", {"_stream_end": True, **metadata})
|
||||
|
||||
edit_content = client.room_send_calls[1]["content"]
|
||||
final_content = client.room_send_calls[2]["content"]
|
||||
|
||||
assert edit_content["m.relates_to"] == {
|
||||
"rel_type": "m.replace",
|
||||
"event_id": "event-1",
|
||||
}
|
||||
assert edit_content["m.new_content"]["m.relates_to"] == {
|
||||
"rel_type": "m.thread",
|
||||
"event_id": "$root1",
|
||||
"m.in_reply_to": {"event_id": "$reply1"},
|
||||
"is_falling_back": True,
|
||||
}
|
||||
assert final_content["m.relates_to"] == {
|
||||
"rel_type": "m.replace",
|
||||
"event_id": "event-1",
|
||||
}
|
||||
assert final_content["m.new_content"]["m.relates_to"] == {
|
||||
"rel_type": "m.thread",
|
||||
"event_id": "$root1",
|
||||
"m.in_reply_to": {"event_id": "$reply1"},
|
||||
"is_falling_back": True,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_stream_end_noop_when_buffer_missing() -> None:
|
||||
channel = MatrixChannel(_make_config(), MessageBus())
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel.client = client
|
||||
|
||||
await channel.send_delta("!room:matrix.org", "", {"_stream_end": True})
|
||||
|
||||
assert client.room_send_calls == []
|
||||
assert client.typing_calls == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_on_error_stops_typing(monkeypatch) -> None:
|
||||
channel = MatrixChannel(_make_config(), MessageBus())
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
client.raise_on_send = True
|
||||
channel.client = client
|
||||
|
||||
now = 100.0
|
||||
monkeypatch.setattr(channel, "monotonic_time", lambda: now)
|
||||
|
||||
await channel.send_delta("!room:matrix.org", "Hello", {"room_id": "!room:matrix.org"})
|
||||
|
||||
assert "!room:matrix.org" in channel._stream_bufs
|
||||
assert channel._stream_bufs["!room:matrix.org"].text == "Hello"
|
||||
assert len(client.room_send_calls) == 1
|
||||
|
||||
assert len(client.typing_calls) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_ignores_whitespace_only_delta(monkeypatch) -> None:
|
||||
channel = MatrixChannel(_make_config(), MessageBus())
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel.client = client
|
||||
|
||||
now = 100.0
|
||||
monkeypatch.setattr(channel, "monotonic_time", lambda: now)
|
||||
|
||||
await channel.send_delta("!room:matrix.org", " ")
|
||||
|
||||
assert "!room:matrix.org" in channel._stream_bufs
|
||||
assert channel._stream_bufs["!room:matrix.org"].text == " "
|
||||
assert client.room_send_calls == []
|
||||
@ -1,17 +1,22 @@
|
||||
import asyncio
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
import httpx
|
||||
|
||||
import nanobot.channels.weixin as weixin_mod
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.weixin import (
|
||||
ITEM_IMAGE,
|
||||
ITEM_TEXT,
|
||||
MESSAGE_TYPE_BOT,
|
||||
WEIXIN_CHANNEL_VERSION,
|
||||
_decrypt_aes_ecb,
|
||||
_encrypt_aes_ecb,
|
||||
WeixinChannel,
|
||||
WeixinConfig,
|
||||
)
|
||||
@ -42,10 +47,12 @@ def test_make_headers_includes_route_tag_when_configured() -> None:
|
||||
|
||||
assert headers["Authorization"] == "Bearer token"
|
||||
assert headers["SKRouteTag"] == "123"
|
||||
assert headers["iLink-App-Id"] == "bot"
|
||||
assert headers["iLink-App-ClientVersion"] == str((2 << 16) | (1 << 8) | 1)
|
||||
|
||||
|
||||
def test_channel_version_matches_reference_plugin_version() -> None:
|
||||
assert WEIXIN_CHANNEL_VERSION == "1.0.3"
|
||||
assert WEIXIN_CHANNEL_VERSION == "2.1.1"
|
||||
|
||||
|
||||
def test_save_and_load_state_persists_context_tokens(tmp_path) -> None:
|
||||
@ -169,6 +176,120 @@ async def test_process_message_extracts_media_and_preserves_paths() -> None:
|
||||
assert inbound.media == ["/tmp/test.jpg"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_falls_back_to_referenced_media_when_no_top_level_media() -> None:
|
||||
channel, bus = _make_channel()
|
||||
channel._download_media_item = AsyncMock(return_value="/tmp/ref.jpg")
|
||||
|
||||
await channel._process_message(
|
||||
{
|
||||
"message_type": 1,
|
||||
"message_id": "m3-ref-fallback",
|
||||
"from_user_id": "wx-user",
|
||||
"context_token": "ctx-3-ref-fallback",
|
||||
"item_list": [
|
||||
{
|
||||
"type": ITEM_TEXT,
|
||||
"text_item": {"text": "reply to image"},
|
||||
"ref_msg": {
|
||||
"message_item": {
|
||||
"type": ITEM_IMAGE,
|
||||
"image_item": {"media": {"encrypt_query_param": "ref-enc"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
inbound = await asyncio.wait_for(bus.consume_inbound(), timeout=1.0)
|
||||
|
||||
channel._download_media_item.assert_awaited_once_with(
|
||||
{"media": {"encrypt_query_param": "ref-enc"}},
|
||||
"image",
|
||||
)
|
||||
assert inbound.media == ["/tmp/ref.jpg"]
|
||||
assert "reply to image" in inbound.content
|
||||
assert "[image]" in inbound.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_does_not_use_referenced_fallback_when_top_level_media_exists() -> None:
|
||||
channel, bus = _make_channel()
|
||||
channel._download_media_item = AsyncMock(side_effect=["/tmp/top.jpg", "/tmp/ref.jpg"])
|
||||
|
||||
await channel._process_message(
|
||||
{
|
||||
"message_type": 1,
|
||||
"message_id": "m3-ref-no-fallback",
|
||||
"from_user_id": "wx-user",
|
||||
"context_token": "ctx-3-ref-no-fallback",
|
||||
"item_list": [
|
||||
{"type": ITEM_IMAGE, "image_item": {"media": {"encrypt_query_param": "top-enc"}}},
|
||||
{
|
||||
"type": ITEM_TEXT,
|
||||
"text_item": {"text": "has top-level media"},
|
||||
"ref_msg": {
|
||||
"message_item": {
|
||||
"type": ITEM_IMAGE,
|
||||
"image_item": {"media": {"encrypt_query_param": "ref-enc"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
inbound = await asyncio.wait_for(bus.consume_inbound(), timeout=1.0)
|
||||
|
||||
channel._download_media_item.assert_awaited_once_with(
|
||||
{"media": {"encrypt_query_param": "top-enc"}},
|
||||
"image",
|
||||
)
|
||||
assert inbound.media == ["/tmp/top.jpg"]
|
||||
assert "/tmp/ref.jpg" not in inbound.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_does_not_fallback_when_top_level_media_exists_but_download_fails() -> None:
|
||||
channel, bus = _make_channel()
|
||||
# Top-level image download fails (None), referenced image would succeed if fallback were triggered.
|
||||
channel._download_media_item = AsyncMock(side_effect=[None, "/tmp/ref.jpg"])
|
||||
|
||||
await channel._process_message(
|
||||
{
|
||||
"message_type": 1,
|
||||
"message_id": "m3-ref-no-fallback-on-failure",
|
||||
"from_user_id": "wx-user",
|
||||
"context_token": "ctx-3-ref-no-fallback-on-failure",
|
||||
"item_list": [
|
||||
{"type": ITEM_IMAGE, "image_item": {"media": {"encrypt_query_param": "top-enc"}}},
|
||||
{
|
||||
"type": ITEM_TEXT,
|
||||
"text_item": {"text": "quoted has media"},
|
||||
"ref_msg": {
|
||||
"message_item": {
|
||||
"type": ITEM_IMAGE,
|
||||
"image_item": {"media": {"encrypt_query_param": "ref-enc"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
inbound = await asyncio.wait_for(bus.consume_inbound(), timeout=1.0)
|
||||
|
||||
# Should only attempt top-level media item; reference fallback must not activate.
|
||||
channel._download_media_item.assert_awaited_once_with(
|
||||
{"media": {"encrypt_query_param": "top-enc"}},
|
||||
"image",
|
||||
)
|
||||
assert inbound.media == []
|
||||
assert "[image]" in inbound.content
|
||||
assert "/tmp/ref.jpg" not in inbound.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_without_context_token_does_not_send_text() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
@ -199,6 +320,70 @@ async def test_send_does_not_send_when_session_is_paused() -> None:
|
||||
channel._send_text.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_typing_ticket_fetches_and_caches_per_user() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._api_post = AsyncMock(return_value={"ret": 0, "typing_ticket": "ticket-1"})
|
||||
|
||||
first = await channel._get_typing_ticket("wx-user", "ctx-1")
|
||||
second = await channel._get_typing_ticket("wx-user", "ctx-2")
|
||||
|
||||
assert first == "ticket-1"
|
||||
assert second == "ticket-1"
|
||||
channel._api_post.assert_awaited_once_with(
|
||||
"ilink/bot/getconfig",
|
||||
{"ilink_user_id": "wx-user", "context_token": "ctx-1", "base_info": weixin_mod.BASE_INFO},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_uses_typing_start_and_cancel_when_ticket_available() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._context_tokens["wx-user"] = "ctx-typing"
|
||||
channel._send_text = AsyncMock()
|
||||
channel._api_post = AsyncMock(
|
||||
side_effect=[
|
||||
{"ret": 0, "typing_ticket": "ticket-typing"},
|
||||
{"ret": 0},
|
||||
{"ret": 0},
|
||||
]
|
||||
)
|
||||
|
||||
await channel.send(
|
||||
type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})()
|
||||
)
|
||||
|
||||
channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-typing")
|
||||
assert channel._api_post.await_count == 3
|
||||
assert channel._api_post.await_args_list[0].args[0] == "ilink/bot/getconfig"
|
||||
assert channel._api_post.await_args_list[1].args[0] == "ilink/bot/sendtyping"
|
||||
assert channel._api_post.await_args_list[1].args[1]["status"] == 1
|
||||
assert channel._api_post.await_args_list[2].args[0] == "ilink/bot/sendtyping"
|
||||
assert channel._api_post.await_args_list[2].args[1]["status"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_still_sends_text_when_typing_ticket_missing() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._context_tokens["wx-user"] = "ctx-no-ticket"
|
||||
channel._send_text = AsyncMock()
|
||||
channel._api_post = AsyncMock(return_value={"ret": 1, "errmsg": "no config"})
|
||||
|
||||
await channel.send(
|
||||
type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})()
|
||||
)
|
||||
|
||||
channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-no-ticket")
|
||||
channel._api_post.assert_awaited_once()
|
||||
assert channel._api_post.await_args_list[0].args[0] == "ilink/bot/getconfig"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_once_pauses_session_on_expired_errcode() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
@ -220,8 +405,12 @@ async def test_qr_login_refreshes_expired_qr_and_then_succeeds() -> None:
|
||||
channel._api_get = AsyncMock(
|
||||
side_effect=[
|
||||
{"qrcode": "qr-1", "qrcode_img_content": "url-1"},
|
||||
{"status": "expired"},
|
||||
{"qrcode": "qr-2", "qrcode_img_content": "url-2"},
|
||||
]
|
||||
)
|
||||
channel._api_get_with_base = AsyncMock(
|
||||
side_effect=[
|
||||
{"status": "expired"},
|
||||
{
|
||||
"status": "confirmed",
|
||||
"bot_token": "token-2",
|
||||
@ -247,12 +436,16 @@ async def test_qr_login_returns_false_after_too_many_expired_qr_codes() -> None:
|
||||
channel._api_get = AsyncMock(
|
||||
side_effect=[
|
||||
{"qrcode": "qr-1", "qrcode_img_content": "url-1"},
|
||||
{"status": "expired"},
|
||||
{"qrcode": "qr-2", "qrcode_img_content": "url-2"},
|
||||
{"status": "expired"},
|
||||
{"qrcode": "qr-3", "qrcode_img_content": "url-3"},
|
||||
{"status": "expired"},
|
||||
{"qrcode": "qr-4", "qrcode_img_content": "url-4"},
|
||||
]
|
||||
)
|
||||
channel._api_get_with_base = AsyncMock(
|
||||
side_effect=[
|
||||
{"status": "expired"},
|
||||
{"status": "expired"},
|
||||
{"status": "expired"},
|
||||
{"status": "expired"},
|
||||
]
|
||||
)
|
||||
@ -262,6 +455,105 @@ async def test_qr_login_returns_false_after_too_many_expired_qr_codes() -> None:
|
||||
assert ok is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qr_login_switches_polling_base_url_on_redirect_status() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._running = True
|
||||
channel._save_state = lambda: None
|
||||
channel._print_qr_code = lambda url: None
|
||||
channel._fetch_qr_code = AsyncMock(return_value=("qr-1", "url-1"))
|
||||
|
||||
status_side_effect = [
|
||||
{"status": "scaned_but_redirect", "redirect_host": "idc.redirect.test"},
|
||||
{
|
||||
"status": "confirmed",
|
||||
"bot_token": "token-3",
|
||||
"ilink_bot_id": "bot-3",
|
||||
"baseurl": "https://example.test",
|
||||
"ilink_user_id": "wx-user",
|
||||
},
|
||||
]
|
||||
channel._api_get = AsyncMock(side_effect=list(status_side_effect))
|
||||
channel._api_get_with_base = AsyncMock(side_effect=list(status_side_effect))
|
||||
|
||||
ok = await channel._qr_login()
|
||||
|
||||
assert ok is True
|
||||
assert channel._token == "token-3"
|
||||
assert channel._api_get_with_base.await_count == 2
|
||||
first_call = channel._api_get_with_base.await_args_list[0]
|
||||
second_call = channel._api_get_with_base.await_args_list[1]
|
||||
assert first_call.kwargs["base_url"] == "https://ilinkai.weixin.qq.com"
|
||||
assert second_call.kwargs["base_url"] == "https://idc.redirect.test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qr_login_redirect_without_host_keeps_current_polling_base_url() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._running = True
|
||||
channel._save_state = lambda: None
|
||||
channel._print_qr_code = lambda url: None
|
||||
channel._fetch_qr_code = AsyncMock(return_value=("qr-1", "url-1"))
|
||||
|
||||
status_side_effect = [
|
||||
{"status": "scaned_but_redirect"},
|
||||
{
|
||||
"status": "confirmed",
|
||||
"bot_token": "token-4",
|
||||
"ilink_bot_id": "bot-4",
|
||||
"baseurl": "https://example.test",
|
||||
"ilink_user_id": "wx-user",
|
||||
},
|
||||
]
|
||||
channel._api_get = AsyncMock(side_effect=list(status_side_effect))
|
||||
channel._api_get_with_base = AsyncMock(side_effect=list(status_side_effect))
|
||||
|
||||
ok = await channel._qr_login()
|
||||
|
||||
assert ok is True
|
||||
assert channel._token == "token-4"
|
||||
assert channel._api_get_with_base.await_count == 2
|
||||
first_call = channel._api_get_with_base.await_args_list[0]
|
||||
second_call = channel._api_get_with_base.await_args_list[1]
|
||||
assert first_call.kwargs["base_url"] == "https://ilinkai.weixin.qq.com"
|
||||
assert second_call.kwargs["base_url"] == "https://ilinkai.weixin.qq.com"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qr_login_resets_redirect_base_url_after_qr_refresh() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._running = True
|
||||
channel._save_state = lambda: None
|
||||
channel._print_qr_code = lambda url: None
|
||||
channel._fetch_qr_code = AsyncMock(side_effect=[("qr-1", "url-1"), ("qr-2", "url-2")])
|
||||
|
||||
channel._api_get_with_base = AsyncMock(
|
||||
side_effect=[
|
||||
{"status": "scaned_but_redirect", "redirect_host": "idc.redirect.test"},
|
||||
{"status": "expired"},
|
||||
{
|
||||
"status": "confirmed",
|
||||
"bot_token": "token-5",
|
||||
"ilink_bot_id": "bot-5",
|
||||
"baseurl": "https://example.test",
|
||||
"ilink_user_id": "wx-user",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
ok = await channel._qr_login()
|
||||
|
||||
assert ok is True
|
||||
assert channel._token == "token-5"
|
||||
assert channel._api_get_with_base.await_count == 3
|
||||
first_call = channel._api_get_with_base.await_args_list[0]
|
||||
second_call = channel._api_get_with_base.await_args_list[1]
|
||||
third_call = channel._api_get_with_base.await_args_list[2]
|
||||
assert first_call.kwargs["base_url"] == "https://ilinkai.weixin.qq.com"
|
||||
assert second_call.kwargs["base_url"] == "https://idc.redirect.test"
|
||||
assert third_call.kwargs["base_url"] == "https://ilinkai.weixin.qq.com"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_skips_bot_messages() -> None:
|
||||
channel, bus = _make_channel()
|
||||
@ -281,12 +573,13 @@ async def test_process_message_skips_bot_messages() -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_fetches_typing_ticket_and_starts_typing() -> None:
|
||||
async def test_process_message_starts_typing_on_inbound() -> None:
|
||||
"""Typing indicator fires immediately when user message arrives."""
|
||||
channel, _bus = _make_channel()
|
||||
channel._running = True
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._api_post = AsyncMock(return_value={"typing_ticket": "ticket-1"})
|
||||
channel._start_typing = AsyncMock()
|
||||
|
||||
await channel._process_message(
|
||||
{
|
||||
@ -300,42 +593,42 @@ async def test_process_message_fetches_typing_ticket_and_starts_typing() -> None
|
||||
}
|
||||
)
|
||||
|
||||
assert channel._typing_tickets["wx-user"] == "ticket-1"
|
||||
assert "wx-user" in channel._typing_tasks
|
||||
await channel._stop_typing("wx-user", clear_remote=False)
|
||||
channel._start_typing.assert_awaited_once_with("wx-user", "ctx-typing")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_final_message_clears_typing_indicator() -> None:
|
||||
"""Non-progress send should cancel typing status."""
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._context_tokens["wx-user"] = "ctx-2"
|
||||
channel._typing_tickets["wx-user"] = "ticket-2"
|
||||
channel._typing_tickets["wx-user"] = {"ticket": "ticket-2", "next_fetch_at": 9999999999}
|
||||
channel._send_text = AsyncMock()
|
||||
channel._api_post = AsyncMock(return_value={})
|
||||
channel._api_post = AsyncMock(return_value={"ret": 0})
|
||||
|
||||
await channel.send(
|
||||
type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})()
|
||||
)
|
||||
|
||||
channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-2")
|
||||
channel._api_post.assert_awaited_once()
|
||||
endpoint, body = channel._api_post.await_args.args
|
||||
assert endpoint == "ilink/bot/sendtyping"
|
||||
assert body["status"] == 2
|
||||
assert body["typing_ticket"] == "ticket-2"
|
||||
typing_cancel_calls = [
|
||||
c for c in channel._api_post.await_args_list
|
||||
if c.args[0] == "ilink/bot/sendtyping" and c.args[1]["status"] == 2
|
||||
]
|
||||
assert len(typing_cancel_calls) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_progress_message_keeps_typing_indicator() -> None:
|
||||
"""Progress messages must not cancel typing status."""
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._context_tokens["wx-user"] = "ctx-2"
|
||||
channel._typing_tickets["wx-user"] = "ticket-2"
|
||||
channel._typing_tickets["wx-user"] = {"ticket": "ticket-2", "next_fetch_at": 9999999999}
|
||||
channel._send_text = AsyncMock()
|
||||
channel._api_post = AsyncMock(return_value={})
|
||||
channel._api_post = AsyncMock(return_value={"ret": 0})
|
||||
|
||||
await channel.send(
|
||||
type(
|
||||
@ -351,4 +644,362 @@ async def test_send_progress_message_keeps_typing_indicator() -> None:
|
||||
)
|
||||
|
||||
channel._send_text.assert_awaited_once_with("wx-user", "thinking", "ctx-2")
|
||||
channel._api_post.assert_not_awaited()
|
||||
typing_cancel_calls = [
|
||||
c for c in channel._api_post.await_args_list
|
||||
if c.args and c.args[0] == "ilink/bot/sendtyping" and c.args[1].get("status") == 2
|
||||
]
|
||||
assert len(typing_cancel_calls) == 0
|
||||
|
||||
|
||||
class _DummyHttpResponse:
|
||||
def __init__(self, *, headers: dict[str, str] | None = None, status_code: int = 200) -> None:
|
||||
self.headers = headers or {}
|
||||
self.status_code = status_code
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_uses_upload_full_url_when_present(tmp_path) -> None:
|
||||
channel, _bus = _make_channel()
|
||||
|
||||
media_file = tmp_path / "photo.jpg"
|
||||
media_file.write_bytes(b"hello-weixin")
|
||||
|
||||
cdn_post = AsyncMock(return_value=_DummyHttpResponse(headers={"x-encrypted-param": "dl-param"}))
|
||||
channel._client = SimpleNamespace(post=cdn_post)
|
||||
channel._api_post = AsyncMock(
|
||||
side_effect=[
|
||||
{
|
||||
"upload_full_url": "https://upload-full.example.test/path?foo=bar",
|
||||
"upload_param": "should-not-be-used",
|
||||
},
|
||||
{"ret": 0},
|
||||
]
|
||||
)
|
||||
|
||||
await channel._send_media_file("wx-user", str(media_file), "ctx-1")
|
||||
|
||||
# first POST call is CDN upload
|
||||
cdn_url = cdn_post.await_args_list[0].args[0]
|
||||
assert cdn_url == "https://upload-full.example.test/path?foo=bar"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_falls_back_to_upload_param_url(tmp_path) -> None:
|
||||
channel, _bus = _make_channel()
|
||||
|
||||
media_file = tmp_path / "photo.jpg"
|
||||
media_file.write_bytes(b"hello-weixin")
|
||||
|
||||
cdn_post = AsyncMock(return_value=_DummyHttpResponse(headers={"x-encrypted-param": "dl-param"}))
|
||||
channel._client = SimpleNamespace(post=cdn_post)
|
||||
channel._api_post = AsyncMock(
|
||||
side_effect=[
|
||||
{"upload_param": "enc-need-fallback"},
|
||||
{"ret": 0},
|
||||
]
|
||||
)
|
||||
|
||||
await channel._send_media_file("wx-user", str(media_file), "ctx-1")
|
||||
|
||||
cdn_url = cdn_post.await_args_list[0].args[0]
|
||||
assert cdn_url.startswith(f"{channel.config.cdn_base_url}/upload?encrypted_query_param=enc-need-fallback")
|
||||
assert "&filekey=" in cdn_url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_voice_file_uses_voice_item_and_voice_upload_type(tmp_path) -> None:
|
||||
channel, _bus = _make_channel()
|
||||
|
||||
media_file = tmp_path / "voice.mp3"
|
||||
media_file.write_bytes(b"voice-bytes")
|
||||
|
||||
cdn_post = AsyncMock(return_value=_DummyHttpResponse(headers={"x-encrypted-param": "voice-dl-param"}))
|
||||
channel._client = SimpleNamespace(post=cdn_post)
|
||||
channel._api_post = AsyncMock(
|
||||
side_effect=[
|
||||
{"upload_full_url": "https://upload-full.example.test/voice?foo=bar"},
|
||||
{"ret": 0},
|
||||
]
|
||||
)
|
||||
|
||||
await channel._send_media_file("wx-user", str(media_file), "ctx-voice")
|
||||
|
||||
getupload_body = channel._api_post.await_args_list[0].args[1]
|
||||
assert getupload_body["media_type"] == 4
|
||||
|
||||
sendmessage_body = channel._api_post.await_args_list[1].args[1]
|
||||
item = sendmessage_body["msg"]["item_list"][0]
|
||||
assert item["type"] == 3
|
||||
assert "voice_item" in item
|
||||
assert "file_item" not in item
|
||||
assert item["voice_item"]["media"]["encrypt_query_param"] == "voice-dl-param"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_typing_uses_keepalive_until_send_finishes() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._context_tokens["wx-user"] = "ctx-typing-loop"
|
||||
async def _api_post_side_effect(endpoint: str, _body: dict | None = None, *, auth: bool = True):
|
||||
if endpoint == "ilink/bot/getconfig":
|
||||
return {"ret": 0, "typing_ticket": "ticket-keepalive"}
|
||||
return {"ret": 0}
|
||||
|
||||
channel._api_post = AsyncMock(side_effect=_api_post_side_effect)
|
||||
|
||||
async def _slow_send_text(*_args, **_kwargs) -> None:
|
||||
await asyncio.sleep(0.03)
|
||||
|
||||
channel._send_text = AsyncMock(side_effect=_slow_send_text)
|
||||
|
||||
old_interval = weixin_mod.TYPING_KEEPALIVE_INTERVAL_S
|
||||
weixin_mod.TYPING_KEEPALIVE_INTERVAL_S = 0.01
|
||||
try:
|
||||
await channel.send(
|
||||
type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})()
|
||||
)
|
||||
finally:
|
||||
weixin_mod.TYPING_KEEPALIVE_INTERVAL_S = old_interval
|
||||
|
||||
status_calls = [
|
||||
c.args[1]["status"]
|
||||
for c in channel._api_post.await_args_list
|
||||
if c.args and c.args[0] == "ilink/bot/sendtyping"
|
||||
]
|
||||
assert status_calls.count(1) >= 2
|
||||
assert status_calls[-1] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_typing_ticket_failure_uses_backoff_and_cached_ticket(monkeypatch) -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
|
||||
now = {"value": 1000.0}
|
||||
monkeypatch.setattr(weixin_mod.time, "time", lambda: now["value"])
|
||||
monkeypatch.setattr(weixin_mod.random, "random", lambda: 0.5)
|
||||
|
||||
channel._api_post = AsyncMock(return_value={"ret": 0, "typing_ticket": "ticket-ok"})
|
||||
first = await channel._get_typing_ticket("wx-user", "ctx-1")
|
||||
assert first == "ticket-ok"
|
||||
|
||||
# force refresh window reached
|
||||
now["value"] = now["value"] + (12 * 60 * 60) + 1
|
||||
channel._api_post = AsyncMock(return_value={"ret": 1, "errmsg": "temporary failure"})
|
||||
|
||||
# On refresh failure, should still return cached ticket and apply backoff.
|
||||
second = await channel._get_typing_ticket("wx-user", "ctx-2")
|
||||
assert second == "ticket-ok"
|
||||
assert channel._api_post.await_count == 1
|
||||
|
||||
# Before backoff expiry, no extra fetch should happen.
|
||||
now["value"] += 1
|
||||
third = await channel._get_typing_ticket("wx-user", "ctx-3")
|
||||
assert third == "ticket-ok"
|
||||
assert channel._api_post.await_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qr_login_treats_temporary_connect_error_as_wait_and_recovers() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._running = True
|
||||
channel._save_state = lambda: None
|
||||
channel._print_qr_code = lambda url: None
|
||||
channel._fetch_qr_code = AsyncMock(return_value=("qr-1", "url-1"))
|
||||
|
||||
request = httpx.Request("GET", "https://ilinkai.weixin.qq.com/ilink/bot/get_qrcode_status")
|
||||
channel._api_get_with_base = AsyncMock(
|
||||
side_effect=[
|
||||
httpx.ConnectError("temporary network", request=request),
|
||||
{
|
||||
"status": "confirmed",
|
||||
"bot_token": "token-net-ok",
|
||||
"ilink_bot_id": "bot-id",
|
||||
"baseurl": "https://example.test",
|
||||
"ilink_user_id": "wx-user",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
ok = await channel._qr_login()
|
||||
|
||||
assert ok is True
|
||||
assert channel._token == "token-net-ok"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qr_login_treats_5xx_gateway_response_error_as_wait_and_recovers() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._running = True
|
||||
channel._save_state = lambda: None
|
||||
channel._print_qr_code = lambda url: None
|
||||
channel._fetch_qr_code = AsyncMock(return_value=("qr-1", "url-1"))
|
||||
|
||||
request = httpx.Request("GET", "https://ilinkai.weixin.qq.com/ilink/bot/get_qrcode_status")
|
||||
response = httpx.Response(status_code=524, request=request)
|
||||
channel._api_get_with_base = AsyncMock(
|
||||
side_effect=[
|
||||
httpx.HTTPStatusError("gateway timeout", request=request, response=response),
|
||||
{
|
||||
"status": "confirmed",
|
||||
"bot_token": "token-5xx-ok",
|
||||
"ilink_bot_id": "bot-id",
|
||||
"baseurl": "https://example.test",
|
||||
"ilink_user_id": "wx-user",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
ok = await channel._qr_login()
|
||||
|
||||
assert ok is True
|
||||
assert channel._token == "token-5xx-ok"
|
||||
|
||||
|
||||
def test_decrypt_aes_ecb_strips_valid_pkcs7_padding() -> None:
|
||||
key_b64 = "MDEyMzQ1Njc4OWFiY2RlZg==" # base64("0123456789abcdef")
|
||||
plaintext = b"hello-weixin-padding"
|
||||
|
||||
ciphertext = _encrypt_aes_ecb(plaintext, key_b64)
|
||||
decrypted = _decrypt_aes_ecb(ciphertext, key_b64)
|
||||
|
||||
assert decrypted == plaintext
|
||||
|
||||
|
||||
class _DummyDownloadResponse:
|
||||
def __init__(self, content: bytes, status_code: int = 200) -> None:
|
||||
self.content = content
|
||||
self.status_code = status_code
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class _DummyErrorDownloadResponse(_DummyDownloadResponse):
|
||||
def __init__(self, url: str, status_code: int) -> None:
|
||||
super().__init__(content=b"", status_code=status_code)
|
||||
self._url = url
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
request = httpx.Request("GET", self._url)
|
||||
response = httpx.Response(self.status_code, request=request)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"download failed with status {self.status_code}",
|
||||
request=request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_media_item_uses_full_url_when_present(tmp_path) -> None:
|
||||
channel, _bus = _make_channel()
|
||||
weixin_mod.get_media_dir = lambda _name: tmp_path
|
||||
|
||||
full_url = "https://cdn.example.test/download/full"
|
||||
channel._client = SimpleNamespace(
|
||||
get=AsyncMock(return_value=_DummyDownloadResponse(content=b"raw-image-bytes"))
|
||||
)
|
||||
|
||||
item = {
|
||||
"media": {
|
||||
"full_url": full_url,
|
||||
"encrypt_query_param": "enc-fallback-should-not-be-used",
|
||||
},
|
||||
}
|
||||
saved_path = await channel._download_media_item(item, "image")
|
||||
|
||||
assert saved_path is not None
|
||||
assert Path(saved_path).read_bytes() == b"raw-image-bytes"
|
||||
channel._client.get.assert_awaited_once_with(full_url)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_media_item_falls_back_when_full_url_returns_retryable_error(tmp_path) -> None:
|
||||
channel, _bus = _make_channel()
|
||||
weixin_mod.get_media_dir = lambda _name: tmp_path
|
||||
|
||||
full_url = "https://cdn.example.test/download/full?taskid=123"
|
||||
channel._client = SimpleNamespace(
|
||||
get=AsyncMock(
|
||||
side_effect=[
|
||||
_DummyErrorDownloadResponse(full_url, 500),
|
||||
_DummyDownloadResponse(content=b"fallback-bytes"),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
item = {
|
||||
"media": {
|
||||
"full_url": full_url,
|
||||
"encrypt_query_param": "enc-fallback",
|
||||
},
|
||||
}
|
||||
saved_path = await channel._download_media_item(item, "image")
|
||||
|
||||
assert saved_path is not None
|
||||
assert Path(saved_path).read_bytes() == b"fallback-bytes"
|
||||
assert channel._client.get.await_count == 2
|
||||
assert channel._client.get.await_args_list[0].args[0] == full_url
|
||||
fallback_url = channel._client.get.await_args_list[1].args[0]
|
||||
assert fallback_url.startswith(f"{channel.config.cdn_base_url}/download?encrypted_query_param=enc-fallback")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_media_item_falls_back_to_encrypt_query_param(tmp_path) -> None:
|
||||
channel, _bus = _make_channel()
|
||||
weixin_mod.get_media_dir = lambda _name: tmp_path
|
||||
|
||||
channel._client = SimpleNamespace(
|
||||
get=AsyncMock(return_value=_DummyDownloadResponse(content=b"fallback-bytes"))
|
||||
)
|
||||
|
||||
item = {"media": {"encrypt_query_param": "enc-fallback"}}
|
||||
saved_path = await channel._download_media_item(item, "image")
|
||||
|
||||
assert saved_path is not None
|
||||
assert Path(saved_path).read_bytes() == b"fallback-bytes"
|
||||
called_url = channel._client.get.await_args_list[0].args[0]
|
||||
assert called_url.startswith(f"{channel.config.cdn_base_url}/download?encrypted_query_param=enc-fallback")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_media_item_does_not_retry_when_full_url_fails_without_fallback(tmp_path) -> None:
|
||||
channel, _bus = _make_channel()
|
||||
weixin_mod.get_media_dir = lambda _name: tmp_path
|
||||
|
||||
full_url = "https://cdn.example.test/download/full"
|
||||
channel._client = SimpleNamespace(
|
||||
get=AsyncMock(return_value=_DummyErrorDownloadResponse(full_url, 500))
|
||||
)
|
||||
|
||||
item = {"media": {"full_url": full_url}}
|
||||
saved_path = await channel._download_media_item(item, "image")
|
||||
|
||||
assert saved_path is None
|
||||
channel._client.get.assert_awaited_once_with(full_url)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_media_item_non_image_requires_aes_key_even_with_full_url(tmp_path) -> None:
|
||||
channel, _bus = _make_channel()
|
||||
weixin_mod.get_media_dir = lambda _name: tmp_path
|
||||
|
||||
full_url = "https://cdn.example.test/download/voice"
|
||||
channel._client = SimpleNamespace(
|
||||
get=AsyncMock(return_value=_DummyDownloadResponse(content=b"ciphertext-or-unknown"))
|
||||
)
|
||||
|
||||
item = {
|
||||
"media": {
|
||||
"full_url": full_url,
|
||||
},
|
||||
}
|
||||
saved_path = await channel._download_media_item(item, "voice")
|
||||
|
||||
assert saved_path is None
|
||||
channel._client.get.assert_not_awaited()
|
||||
|
||||
@ -317,6 +317,75 @@ def test_openai_compat_provider_passes_model_through():
|
||||
assert provider.get_default_model() == "github-copilot/gpt-5.3-codex"
|
||||
|
||||
|
||||
def test_make_provider_uses_github_copilot_backend():
|
||||
from nanobot.cli.commands import _make_provider
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
config = Config.model_validate(
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"provider": "github-copilot",
|
||||
"model": "github-copilot/gpt-4.1",
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = _make_provider(config)
|
||||
|
||||
assert provider.__class__.__name__ == "GitHubCopilotProvider"
|
||||
|
||||
|
||||
def test_github_copilot_provider_strips_prefixed_model_name():
|
||||
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = GitHubCopilotProvider(default_model="github-copilot/gpt-5.1")
|
||||
|
||||
kwargs = provider._build_kwargs(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=None,
|
||||
model="github-copilot/gpt-5.1",
|
||||
max_tokens=16,
|
||||
temperature=0.1,
|
||||
reasoning_effort=None,
|
||||
tool_choice=None,
|
||||
)
|
||||
|
||||
assert kwargs["model"] == "gpt-5.1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_github_copilot_provider_refreshes_client_api_key_before_chat():
|
||||
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.api_key = "no-key"
|
||||
mock_client.chat.completions.create = AsyncMock(return_value={
|
||||
"choices": [{"message": {"content": "ok"}, "finish_reason": "stop"}],
|
||||
"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
|
||||
})
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI", return_value=mock_client):
|
||||
provider = GitHubCopilotProvider(default_model="github-copilot/gpt-5.1")
|
||||
|
||||
provider._get_copilot_access_token = AsyncMock(return_value="copilot-access-token")
|
||||
|
||||
response = await provider.chat(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
model="github-copilot/gpt-5.1",
|
||||
max_tokens=16,
|
||||
temperature=0.1,
|
||||
)
|
||||
|
||||
assert response.content == "ok"
|
||||
assert provider._client.api_key == "copilot-access-token"
|
||||
provider._get_copilot_access_token.assert_awaited_once()
|
||||
mock_client.chat.completions.create.assert_awaited_once()
|
||||
|
||||
|
||||
def test_openai_codex_strip_prefix_supports_hyphen_and_underscore():
|
||||
assert _strip_model_prefix("openai-codex/gpt-5.1-codex") == "gpt-5.1-codex"
|
||||
assert _strip_model_prefix("openai_codex/gpt-5.1-codex") == "gpt-5.1-codex"
|
||||
@ -642,27 +711,105 @@ def test_heartbeat_retains_recent_messages_by_default():
|
||||
assert config.gateway.heartbeat.keep_recent_messages == 8
|
||||
|
||||
|
||||
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
|
||||
def _write_instance_config(tmp_path: Path) -> Path:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
return config_file
|
||||
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
def _stop_gateway_provider(_config) -> object:
|
||||
raise _StopGatewayError("stop")
|
||||
|
||||
|
||||
def _patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config: Config,
|
||||
*,
|
||||
set_config_path=None,
|
||||
sync_templates=None,
|
||||
make_provider=None,
|
||||
message_bus=None,
|
||||
session_manager=None,
|
||||
cron_service=None,
|
||||
get_cron_dir=None,
|
||||
) -> None:
|
||||
monkeypatch.setattr(
|
||||
"nanobot.config.loader.set_config_path",
|
||||
lambda path: seen.__setitem__("config_path", path),
|
||||
set_config_path or (lambda _path: None),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands.sync_workspace_templates",
|
||||
lambda path: seen.__setitem__("workspace", path),
|
||||
sync_templates or (lambda _path: None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
|
||||
make_provider or (lambda _config: object()),
|
||||
)
|
||||
|
||||
if message_bus is not None:
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", message_bus)
|
||||
if session_manager is not None:
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", session_manager)
|
||||
if cron_service is not None:
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", cron_service)
|
||||
if get_cron_dir is not None:
|
||||
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", get_cron_dir)
|
||||
|
||||
|
||||
def _patch_serve_runtime(monkeypatch, config: Config, seen: dict[str, object]) -> None:
|
||||
pytest.importorskip("aiohttp")
|
||||
|
||||
class _FakeApiApp:
|
||||
def __init__(self) -> None:
|
||||
self.on_startup: list[object] = []
|
||||
self.on_cleanup: list[object] = []
|
||||
|
||||
class _FakeAgentLoop:
|
||||
def __init__(self, **kwargs) -> None:
|
||||
seen["workspace"] = kwargs["workspace"]
|
||||
|
||||
async def _connect_mcp(self) -> None:
|
||||
return None
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
return None
|
||||
|
||||
def _fake_create_app(agent_loop, model_name: str, request_timeout: float):
|
||||
seen["agent_loop"] = agent_loop
|
||||
seen["model_name"] = model_name
|
||||
seen["request_timeout"] = request_timeout
|
||||
return _FakeApiApp()
|
||||
|
||||
def _fake_run_app(api_app, host: str, port: int, print):
|
||||
seen["api_app"] = api_app
|
||||
seen["host"] = host
|
||||
seen["port"] = port
|
||||
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
message_bus=lambda: object(),
|
||||
session_manager=lambda _workspace: object(),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.api.server.create_app", _fake_create_app)
|
||||
monkeypatch.setattr("aiohttp.web.run_app", _fake_run_app)
|
||||
|
||||
|
||||
def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
set_config_path=lambda path: seen.__setitem__("config_path", path),
|
||||
sync_templates=lambda path: seen.__setitem__("workspace", path),
|
||||
make_provider=_stop_gateway_provider,
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
@ -673,24 +820,17 @@ def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Pa
|
||||
|
||||
|
||||
def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
override = tmp_path / "override-workspace"
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands.sync_workspace_templates",
|
||||
lambda path: seen.__setitem__("workspace", path),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
sync_templates=lambda path: seen.__setitem__("workspace", path),
|
||||
make_provider=_stop_gateway_provider,
|
||||
)
|
||||
|
||||
result = runner.invoke(
|
||||
@ -704,27 +844,23 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path)
|
||||
|
||||
|
||||
def test_gateway_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
||||
|
||||
class _StopCron:
|
||||
def __init__(self, store_path: Path) -> None:
|
||||
seen["cron_store"] = store_path
|
||||
raise _StopGatewayError("stop")
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
message_bus=lambda: object(),
|
||||
session_manager=lambda _workspace: object(),
|
||||
cron_service=_StopCron,
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
@ -735,10 +871,7 @@ def test_gateway_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path:
|
||||
def test_gateway_workspace_override_does_not_migrate_legacy_cron(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
legacy_dir = tmp_path / "global" / "cron"
|
||||
legacy_dir.mkdir(parents=True)
|
||||
legacy_file = legacy_dir / "jobs.json"
|
||||
@ -748,20 +881,19 @@ def test_gateway_workspace_override_does_not_migrate_legacy_cron(
|
||||
config = Config()
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
||||
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
|
||||
|
||||
class _StopCron:
|
||||
def __init__(self, store_path: Path) -> None:
|
||||
seen["cron_store"] = store_path
|
||||
raise _StopGatewayError("stop")
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
message_bus=lambda: object(),
|
||||
session_manager=lambda _workspace: object(),
|
||||
cron_service=_StopCron,
|
||||
get_cron_dir=lambda: legacy_dir,
|
||||
)
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
@ -777,10 +909,7 @@ def test_gateway_workspace_override_does_not_migrate_legacy_cron(
|
||||
def test_gateway_custom_config_workspace_does_not_migrate_legacy_cron(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
legacy_dir = tmp_path / "global" / "cron"
|
||||
legacy_dir.mkdir(parents=True)
|
||||
legacy_file = legacy_dir / "jobs.json"
|
||||
@ -791,20 +920,19 @@ def test_gateway_custom_config_workspace_does_not_migrate_legacy_cron(
|
||||
config.agents.defaults.workspace = str(custom_workspace)
|
||||
seen: dict[str, Path] = {}
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: object())
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
||||
monkeypatch.setattr("nanobot.config.paths.get_cron_dir", lambda: legacy_dir)
|
||||
|
||||
class _StopCron:
|
||||
def __init__(self, store_path: Path) -> None:
|
||||
seen["cron_store"] = store_path
|
||||
raise _StopGatewayError("stop")
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
message_bus=lambda: object(),
|
||||
session_manager=lambda _workspace: object(),
|
||||
cron_service=_StopCron,
|
||||
get_cron_dir=lambda: legacy_dir,
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
@ -856,19 +984,14 @@ def test_migrate_cron_store_skips_when_workspace_file_exists(tmp_path: Path) ->
|
||||
|
||||
|
||||
def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
config = Config()
|
||||
config.gateway.port = 18791
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
make_provider=_stop_gateway_provider,
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
@ -878,19 +1001,14 @@ def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_
|
||||
|
||||
|
||||
def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
config_file.write_text("{}")
|
||||
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
config = Config()
|
||||
config.gateway.port = 18791
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
make_provider=_stop_gateway_provider,
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file), "--port", "18792"])
|
||||
@ -899,6 +1017,63 @@ def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path)
|
||||
assert "port 18792" in result.stdout
|
||||
|
||||
|
||||
def test_serve_uses_api_config_defaults_and_workspace_override(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
config = Config()
|
||||
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
|
||||
config.api.host = "127.0.0.2"
|
||||
config.api.port = 18900
|
||||
config.api.timeout = 45.0
|
||||
override_workspace = tmp_path / "override-workspace"
|
||||
seen: dict[str, object] = {}
|
||||
|
||||
_patch_serve_runtime(monkeypatch, config, seen)
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
["serve", "--config", str(config_file), "--workspace", str(override_workspace)],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert seen["workspace"] == override_workspace
|
||||
assert seen["host"] == "127.0.0.2"
|
||||
assert seen["port"] == 18900
|
||||
assert seen["request_timeout"] == 45.0
|
||||
|
||||
|
||||
def test_serve_cli_options_override_api_config(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
config = Config()
|
||||
config.api.host = "127.0.0.2"
|
||||
config.api.port = 18900
|
||||
config.api.timeout = 45.0
|
||||
seen: dict[str, object] = {}
|
||||
|
||||
_patch_serve_runtime(monkeypatch, config, seen)
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
[
|
||||
"serve",
|
||||
"--config",
|
||||
str(config_file),
|
||||
"--host",
|
||||
"127.0.0.1",
|
||||
"--port",
|
||||
"18901",
|
||||
"--timeout",
|
||||
"46",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert seen["host"] == "127.0.0.1"
|
||||
assert seen["port"] == 18901
|
||||
assert seen["request_timeout"] == 46.0
|
||||
|
||||
|
||||
def test_channels_login_requires_channel_name() -> None:
|
||||
result = runner.invoke(app, ["channels", "login"])
|
||||
|
||||
|
||||
@ -152,10 +152,12 @@ class TestRestartCommand:
|
||||
])
|
||||
|
||||
await loop._run_agent_loop([])
|
||||
assert loop._last_usage == {"prompt_tokens": 9, "completion_tokens": 4}
|
||||
assert loop._last_usage["prompt_tokens"] == 9
|
||||
assert loop._last_usage["completion_tokens"] == 4
|
||||
|
||||
await loop._run_agent_loop([])
|
||||
assert loop._last_usage == {"prompt_tokens": 0, "completion_tokens": 0}
|
||||
assert loop._last_usage["prompt_tokens"] == 0
|
||||
assert loop._last_usage["completion_tokens"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_falls_back_to_last_usage_when_context_estimate_missing(self):
|
||||
|
||||
@ -285,6 +285,28 @@ def test_add_at_job_uses_default_timezone_for_naive_datetime(tmp_path) -> None:
|
||||
assert job.schedule.at_ms == expected
|
||||
|
||||
|
||||
def test_add_job_delivers_by_default(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
tool.set_context("telegram", "chat-1")
|
||||
|
||||
result = tool._add_job("Morning standup", 60, None, None, None)
|
||||
|
||||
assert result.startswith("Created job")
|
||||
job = tool._cron.list_jobs()[0]
|
||||
assert job.payload.deliver is True
|
||||
|
||||
|
||||
def test_add_job_can_disable_delivery(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
tool.set_context("telegram", "chat-1")
|
||||
|
||||
result = tool._add_job("Background refresh", 60, None, None, None, deliver=False)
|
||||
|
||||
assert result.startswith("Created job")
|
||||
job = tool._cron.list_jobs()[0]
|
||||
assert job.payload.deliver is False
|
||||
|
||||
|
||||
def test_list_excludes_disabled_jobs(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
job = tool._cron.add_job(
|
||||
|
||||
233
tests/providers/test_cached_tokens.py
Normal file
233
tests/providers/test_cached_tokens.py
Normal file
@ -0,0 +1,233 @@
|
||||
"""Tests for cached token extraction from OpenAI-compatible providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
|
||||
class FakeUsage:
|
||||
"""Mimics an OpenAI SDK usage object (has attributes, not dict keys)."""
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
class FakePromptDetails:
|
||||
"""Mimics prompt_tokens_details sub-object."""
|
||||
def __init__(self, cached_tokens=0):
|
||||
self.cached_tokens = cached_tokens
|
||||
|
||||
|
||||
class _FakeSpec:
|
||||
supports_prompt_caching = False
|
||||
model_id_prefix = None
|
||||
strip_model_prefix = False
|
||||
max_completion_tokens = False
|
||||
reasoning_effort = None
|
||||
|
||||
|
||||
def _provider():
|
||||
from unittest.mock import MagicMock
|
||||
p = OpenAICompatProvider.__new__(OpenAICompatProvider)
|
||||
p.client = MagicMock()
|
||||
p.spec = _FakeSpec()
|
||||
return p
|
||||
|
||||
|
||||
# Minimal valid choice so _parse reaches _extract_usage.
|
||||
_DICT_CHOICE = {"message": {"content": "Hello"}}
|
||||
|
||||
class _FakeMessage:
|
||||
content = "Hello"
|
||||
tool_calls = None
|
||||
|
||||
|
||||
class _FakeChoice:
|
||||
message = _FakeMessage()
|
||||
finish_reason = "stop"
|
||||
|
||||
|
||||
# --- dict-based response (raw JSON / mapping) ---
|
||||
|
||||
def test_extract_usage_openai_cached_tokens_dict():
|
||||
"""prompt_tokens_details.cached_tokens from a dict response."""
|
||||
p = _provider()
|
||||
response = {
|
||||
"choices": [_DICT_CHOICE],
|
||||
"usage": {
|
||||
"prompt_tokens": 2000,
|
||||
"completion_tokens": 300,
|
||||
"total_tokens": 2300,
|
||||
"prompt_tokens_details": {"cached_tokens": 1200},
|
||||
}
|
||||
}
|
||||
result = p._parse(response)
|
||||
assert result.usage["cached_tokens"] == 1200
|
||||
assert result.usage["prompt_tokens"] == 2000
|
||||
|
||||
|
||||
def test_extract_usage_deepseek_cached_tokens_dict():
|
||||
"""prompt_cache_hit_tokens from a DeepSeek dict response."""
|
||||
p = _provider()
|
||||
response = {
|
||||
"choices": [_DICT_CHOICE],
|
||||
"usage": {
|
||||
"prompt_tokens": 1500,
|
||||
"completion_tokens": 200,
|
||||
"total_tokens": 1700,
|
||||
"prompt_cache_hit_tokens": 1200,
|
||||
"prompt_cache_miss_tokens": 300,
|
||||
}
|
||||
}
|
||||
result = p._parse(response)
|
||||
assert result.usage["cached_tokens"] == 1200
|
||||
|
||||
|
||||
def test_extract_usage_no_cached_tokens_dict():
|
||||
"""Response without any cache fields -> no cached_tokens key."""
|
||||
p = _provider()
|
||||
response = {
|
||||
"choices": [_DICT_CHOICE],
|
||||
"usage": {
|
||||
"prompt_tokens": 1000,
|
||||
"completion_tokens": 200,
|
||||
"total_tokens": 1200,
|
||||
}
|
||||
}
|
||||
result = p._parse(response)
|
||||
assert "cached_tokens" not in result.usage
|
||||
|
||||
|
||||
def test_extract_usage_openai_cached_zero_dict():
|
||||
"""cached_tokens=0 should NOT be included (same as existing fields)."""
|
||||
p = _provider()
|
||||
response = {
|
||||
"choices": [_DICT_CHOICE],
|
||||
"usage": {
|
||||
"prompt_tokens": 2000,
|
||||
"completion_tokens": 300,
|
||||
"total_tokens": 2300,
|
||||
"prompt_tokens_details": {"cached_tokens": 0},
|
||||
}
|
||||
}
|
||||
result = p._parse(response)
|
||||
assert "cached_tokens" not in result.usage
|
||||
|
||||
|
||||
# --- object-based response (OpenAI SDK Pydantic model) ---
|
||||
|
||||
def test_extract_usage_openai_cached_tokens_obj():
|
||||
"""prompt_tokens_details.cached_tokens from an SDK object response."""
|
||||
p = _provider()
|
||||
usage_obj = FakeUsage(
|
||||
prompt_tokens=2000,
|
||||
completion_tokens=300,
|
||||
total_tokens=2300,
|
||||
prompt_tokens_details=FakePromptDetails(cached_tokens=1200),
|
||||
)
|
||||
response = FakeUsage(choices=[_FakeChoice()], usage=usage_obj)
|
||||
result = p._parse(response)
|
||||
assert result.usage["cached_tokens"] == 1200
|
||||
|
||||
|
||||
def test_extract_usage_deepseek_cached_tokens_obj():
|
||||
"""prompt_cache_hit_tokens from a DeepSeek SDK object response."""
|
||||
p = _provider()
|
||||
usage_obj = FakeUsage(
|
||||
prompt_tokens=1500,
|
||||
completion_tokens=200,
|
||||
total_tokens=1700,
|
||||
prompt_cache_hit_tokens=1200,
|
||||
)
|
||||
response = FakeUsage(choices=[_FakeChoice()], usage=usage_obj)
|
||||
result = p._parse(response)
|
||||
assert result.usage["cached_tokens"] == 1200
|
||||
|
||||
|
||||
def test_extract_usage_stepfun_top_level_cached_tokens_dict():
|
||||
"""StepFun/Moonshot: usage.cached_tokens at top level (not nested)."""
|
||||
p = _provider()
|
||||
response = {
|
||||
"choices": [_DICT_CHOICE],
|
||||
"usage": {
|
||||
"prompt_tokens": 591,
|
||||
"completion_tokens": 120,
|
||||
"total_tokens": 711,
|
||||
"cached_tokens": 512,
|
||||
}
|
||||
}
|
||||
result = p._parse(response)
|
||||
assert result.usage["cached_tokens"] == 512
|
||||
|
||||
|
||||
def test_extract_usage_stepfun_top_level_cached_tokens_obj():
|
||||
"""StepFun/Moonshot: usage.cached_tokens as SDK object attribute."""
|
||||
p = _provider()
|
||||
usage_obj = FakeUsage(
|
||||
prompt_tokens=591,
|
||||
completion_tokens=120,
|
||||
total_tokens=711,
|
||||
cached_tokens=512,
|
||||
)
|
||||
response = FakeUsage(choices=[_FakeChoice()], usage=usage_obj)
|
||||
result = p._parse(response)
|
||||
assert result.usage["cached_tokens"] == 512
|
||||
|
||||
|
||||
def test_extract_usage_priority_nested_over_top_level_dict():
|
||||
"""When both nested and top-level cached_tokens exist, nested wins."""
|
||||
p = _provider()
|
||||
response = {
|
||||
"choices": [_DICT_CHOICE],
|
||||
"usage": {
|
||||
"prompt_tokens": 2000,
|
||||
"completion_tokens": 300,
|
||||
"total_tokens": 2300,
|
||||
"prompt_tokens_details": {"cached_tokens": 100},
|
||||
"cached_tokens": 500,
|
||||
}
|
||||
}
|
||||
result = p._parse(response)
|
||||
assert result.usage["cached_tokens"] == 100
|
||||
|
||||
|
||||
def test_anthropic_maps_cache_fields_to_cached_tokens():
|
||||
"""Anthropic's cache_read_input_tokens should map to cached_tokens."""
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
|
||||
usage_obj = FakeUsage(
|
||||
input_tokens=800,
|
||||
output_tokens=200,
|
||||
cache_creation_input_tokens=300,
|
||||
cache_read_input_tokens=1200,
|
||||
)
|
||||
content_block = FakeUsage(type="text", text="hello")
|
||||
response = FakeUsage(
|
||||
id="msg_1",
|
||||
type="message",
|
||||
stop_reason="end_turn",
|
||||
content=[content_block],
|
||||
usage=usage_obj,
|
||||
)
|
||||
result = AnthropicProvider._parse_response(response)
|
||||
assert result.usage["cached_tokens"] == 1200
|
||||
assert result.usage["prompt_tokens"] == 2300
|
||||
assert result.usage["total_tokens"] == 2500
|
||||
assert result.usage["cache_creation_input_tokens"] == 300
|
||||
|
||||
|
||||
def test_anthropic_no_cache_fields():
|
||||
"""Anthropic response without cache fields should not have cached_tokens."""
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
|
||||
usage_obj = FakeUsage(input_tokens=800, output_tokens=200)
|
||||
content_block = FakeUsage(type="text", text="hello")
|
||||
response = FakeUsage(
|
||||
id="msg_1",
|
||||
type="message",
|
||||
stop_reason="end_turn",
|
||||
content=[content_block],
|
||||
usage=usage_obj,
|
||||
)
|
||||
result = AnthropicProvider._parse_response(response)
|
||||
assert "cached_tokens" not in result.usage
|
||||
@ -11,6 +11,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None:
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers.anthropic_provider", raising=False)
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers.openai_compat_provider", raising=False)
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers.openai_codex_provider", raising=False)
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers.github_copilot_provider", raising=False)
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers.azure_openai_provider", raising=False)
|
||||
|
||||
providers = importlib.import_module("nanobot.providers")
|
||||
@ -18,6 +19,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None:
|
||||
assert "nanobot.providers.anthropic_provider" not in sys.modules
|
||||
assert "nanobot.providers.openai_compat_provider" not in sys.modules
|
||||
assert "nanobot.providers.openai_codex_provider" not in sys.modules
|
||||
assert "nanobot.providers.github_copilot_provider" not in sys.modules
|
||||
assert "nanobot.providers.azure_openai_provider" not in sys.modules
|
||||
assert providers.__all__ == [
|
||||
"LLMProvider",
|
||||
@ -25,6 +27,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None:
|
||||
"AnthropicProvider",
|
||||
"OpenAICompatProvider",
|
||||
"OpenAICodexProvider",
|
||||
"GitHubCopilotProvider",
|
||||
"AzureOpenAIProvider",
|
||||
]
|
||||
|
||||
|
||||
59
tests/test_build_status.py
Normal file
59
tests/test_build_status.py
Normal file
@ -0,0 +1,59 @@
|
||||
"""Tests for build_status_content cache hit rate display."""
|
||||
|
||||
from nanobot.utils.helpers import build_status_content
|
||||
|
||||
|
||||
def test_status_shows_cache_hit_rate():
|
||||
content = build_status_content(
|
||||
version="0.1.0",
|
||||
model="glm-4-plus",
|
||||
start_time=1000000.0,
|
||||
last_usage={"prompt_tokens": 2000, "completion_tokens": 300, "cached_tokens": 1200},
|
||||
context_window_tokens=128000,
|
||||
session_msg_count=10,
|
||||
context_tokens_estimate=5000,
|
||||
)
|
||||
assert "60% cached" in content
|
||||
assert "2000 in / 300 out" in content
|
||||
|
||||
|
||||
def test_status_no_cache_info():
|
||||
"""Without cached_tokens, display should not show cache percentage."""
|
||||
content = build_status_content(
|
||||
version="0.1.0",
|
||||
model="glm-4-plus",
|
||||
start_time=1000000.0,
|
||||
last_usage={"prompt_tokens": 2000, "completion_tokens": 300},
|
||||
context_window_tokens=128000,
|
||||
session_msg_count=10,
|
||||
context_tokens_estimate=5000,
|
||||
)
|
||||
assert "cached" not in content.lower()
|
||||
assert "2000 in / 300 out" in content
|
||||
|
||||
|
||||
def test_status_zero_cached_tokens():
|
||||
"""cached_tokens=0 should not show cache percentage."""
|
||||
content = build_status_content(
|
||||
version="0.1.0",
|
||||
model="glm-4-plus",
|
||||
start_time=1000000.0,
|
||||
last_usage={"prompt_tokens": 2000, "completion_tokens": 300, "cached_tokens": 0},
|
||||
context_window_tokens=128000,
|
||||
session_msg_count=10,
|
||||
context_tokens_estimate=5000,
|
||||
)
|
||||
assert "cached" not in content.lower()
|
||||
|
||||
|
||||
def test_status_100_percent_cached():
|
||||
content = build_status_content(
|
||||
version="0.1.0",
|
||||
model="glm-4-plus",
|
||||
start_time=1000000.0,
|
||||
last_usage={"prompt_tokens": 1000, "completion_tokens": 100, "cached_tokens": 1000},
|
||||
context_window_tokens=128000,
|
||||
session_msg_count=5,
|
||||
context_tokens_estimate=3000,
|
||||
)
|
||||
assert "100% cached" in content
|
||||
168
tests/test_nanobot_facade.py
Normal file
168
tests/test_nanobot_facade.py
Normal file
@ -0,0 +1,168 @@
|
||||
"""Tests for the Nanobot programmatic facade."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.nanobot import Nanobot, RunResult
|
||||
|
||||
|
||||
def _write_config(tmp_path: Path, overrides: dict | None = None) -> Path:
|
||||
data = {
|
||||
"providers": {"openrouter": {"apiKey": "sk-test-key"}},
|
||||
"agents": {"defaults": {"model": "openai/gpt-4.1"}},
|
||||
}
|
||||
if overrides:
|
||||
data.update(overrides)
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(json.dumps(data))
|
||||
return config_path
|
||||
|
||||
|
||||
def test_from_config_missing_file():
|
||||
with pytest.raises(FileNotFoundError):
|
||||
Nanobot.from_config("/nonexistent/config.json")
|
||||
|
||||
|
||||
def test_from_config_creates_instance(tmp_path):
|
||||
config_path = _write_config(tmp_path)
|
||||
bot = Nanobot.from_config(config_path, workspace=tmp_path)
|
||||
assert bot._loop is not None
|
||||
assert bot._loop.workspace == tmp_path
|
||||
|
||||
|
||||
def test_from_config_default_path():
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
with patch("nanobot.config.loader.load_config") as mock_load, \
|
||||
patch("nanobot.nanobot._make_provider") as mock_prov:
|
||||
mock_load.return_value = Config()
|
||||
mock_prov.return_value = MagicMock()
|
||||
mock_prov.return_value.get_default_model.return_value = "test"
|
||||
mock_prov.return_value.generation.max_tokens = 4096
|
||||
Nanobot.from_config()
|
||||
mock_load.assert_called_once_with(None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_returns_result(tmp_path):
|
||||
config_path = _write_config(tmp_path)
|
||||
bot = Nanobot.from_config(config_path, workspace=tmp_path)
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
|
||||
mock_response = OutboundMessage(
|
||||
channel="cli", chat_id="direct", content="Hello back!"
|
||||
)
|
||||
bot._loop.process_direct = AsyncMock(return_value=mock_response)
|
||||
|
||||
result = await bot.run("hi")
|
||||
|
||||
assert isinstance(result, RunResult)
|
||||
assert result.content == "Hello back!"
|
||||
bot._loop.process_direct.assert_awaited_once_with("hi", session_key="sdk:default")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_with_hooks(tmp_path):
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
|
||||
config_path = _write_config(tmp_path)
|
||||
bot = Nanobot.from_config(config_path, workspace=tmp_path)
|
||||
|
||||
class TestHook(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
pass
|
||||
|
||||
mock_response = OutboundMessage(
|
||||
channel="cli", chat_id="direct", content="done"
|
||||
)
|
||||
bot._loop.process_direct = AsyncMock(return_value=mock_response)
|
||||
|
||||
result = await bot.run("hi", hooks=[TestHook()])
|
||||
|
||||
assert result.content == "done"
|
||||
assert bot._loop._extra_hooks == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_hooks_restored_on_error(tmp_path):
|
||||
config_path = _write_config(tmp_path)
|
||||
bot = Nanobot.from_config(config_path, workspace=tmp_path)
|
||||
|
||||
from nanobot.agent.hook import AgentHook
|
||||
|
||||
bot._loop.process_direct = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
original_hooks = bot._loop._extra_hooks
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await bot.run("hi", hooks=[AgentHook()])
|
||||
|
||||
assert bot._loop._extra_hooks is original_hooks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_none_response(tmp_path):
|
||||
config_path = _write_config(tmp_path)
|
||||
bot = Nanobot.from_config(config_path, workspace=tmp_path)
|
||||
bot._loop.process_direct = AsyncMock(return_value=None)
|
||||
|
||||
result = await bot.run("hi")
|
||||
assert result.content == ""
|
||||
|
||||
|
||||
def test_workspace_override(tmp_path):
|
||||
config_path = _write_config(tmp_path)
|
||||
custom_ws = tmp_path / "custom_workspace"
|
||||
custom_ws.mkdir()
|
||||
|
||||
bot = Nanobot.from_config(config_path, workspace=custom_ws)
|
||||
assert bot._loop.workspace == custom_ws
|
||||
|
||||
|
||||
def test_sdk_make_provider_uses_github_copilot_backend():
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.nanobot import _make_provider
|
||||
|
||||
config = Config.model_validate(
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"provider": "github-copilot",
|
||||
"model": "github-copilot/gpt-4.1",
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = _make_provider(config)
|
||||
|
||||
assert provider.__class__.__name__ == "GitHubCopilotProvider"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_custom_session_key(tmp_path):
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
|
||||
config_path = _write_config(tmp_path)
|
||||
bot = Nanobot.from_config(config_path, workspace=tmp_path)
|
||||
|
||||
mock_response = OutboundMessage(
|
||||
channel="cli", chat_id="direct", content="ok"
|
||||
)
|
||||
bot._loop.process_direct = AsyncMock(return_value=mock_response)
|
||||
|
||||
await bot.run("hi", session_key="user-alice")
|
||||
bot._loop.process_direct.assert_awaited_once_with("hi", session_key="user-alice")
|
||||
|
||||
|
||||
def test_import_from_top_level():
|
||||
from nanobot import Nanobot as N, RunResult as R
|
||||
assert N is Nanobot
|
||||
assert R is RunResult
|
||||
371
tests/test_openai_api.py
Normal file
371
tests/test_openai_api.py
Normal file
@ -0,0 +1,371 @@
|
||||
"""Focused tests for the fixed-session OpenAI-compatible API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from nanobot.api.server import (
|
||||
API_CHAT_ID,
|
||||
API_SESSION_KEY,
|
||||
_chat_completion_response,
|
||||
_error_json,
|
||||
create_app,
|
||||
handle_chat_completions,
|
||||
)
|
||||
|
||||
try:
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
HAS_AIOHTTP = True
|
||||
except ImportError:
|
||||
HAS_AIOHTTP = False
|
||||
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
|
||||
|
||||
def _make_mock_agent(response_text: str = "mock response") -> MagicMock:
|
||||
agent = MagicMock()
|
||||
agent.process_direct = AsyncMock(return_value=response_text)
|
||||
agent._connect_mcp = AsyncMock()
|
||||
agent.close_mcp = AsyncMock()
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
return _make_mock_agent()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(mock_agent):
|
||||
return create_app(mock_agent, model_name="test-model", request_timeout=10.0)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def aiohttp_client():
|
||||
clients: list[TestClient] = []
|
||||
|
||||
async def _make_client(app):
|
||||
client = TestClient(TestServer(app))
|
||||
await client.start_server()
|
||||
clients.append(client)
|
||||
return client
|
||||
|
||||
try:
|
||||
yield _make_client
|
||||
finally:
|
||||
for client in clients:
|
||||
await client.close()
|
||||
|
||||
|
||||
def test_error_json() -> None:
|
||||
resp = _error_json(400, "bad request")
|
||||
assert resp.status == 400
|
||||
body = json.loads(resp.body)
|
||||
assert body["error"]["message"] == "bad request"
|
||||
assert body["error"]["code"] == 400
|
||||
|
||||
|
||||
def test_chat_completion_response() -> None:
|
||||
result = _chat_completion_response("hello world", "test-model")
|
||||
assert result["object"] == "chat.completion"
|
||||
assert result["model"] == "test-model"
|
||||
assert result["choices"][0]["message"]["content"] == "hello world"
|
||||
assert result["choices"][0]["finish_reason"] == "stop"
|
||||
assert result["id"].startswith("chatcmpl-")
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_messages_returns_400(aiohttp_client, app) -> None:
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post("/v1/chat/completions", json={"model": "test"})
|
||||
assert resp.status == 400
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_user_message_returns_400(aiohttp_client, app) -> None:
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "system", "content": "you are a bot"}]},
|
||||
)
|
||||
assert resp.status == 400
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_true_returns_400(aiohttp_client, app) -> None:
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hello"}], "stream": True},
|
||||
)
|
||||
assert resp.status == 400
|
||||
body = await resp.json()
|
||||
assert "stream" in body["error"]["message"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_mismatch_returns_400() -> None:
|
||||
request = MagicMock()
|
||||
request.json = AsyncMock(
|
||||
return_value={
|
||||
"model": "other-model",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
}
|
||||
)
|
||||
request.app = {
|
||||
"agent_loop": _make_mock_agent(),
|
||||
"model_name": "test-model",
|
||||
"request_timeout": 10.0,
|
||||
"session_lock": asyncio.Lock(),
|
||||
}
|
||||
|
||||
resp = await handle_chat_completions(request)
|
||||
assert resp.status == 400
|
||||
body = json.loads(resp.body)
|
||||
assert "test-model" in body["error"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_user_message_required() -> None:
|
||||
request = MagicMock()
|
||||
request.json = AsyncMock(
|
||||
return_value={
|
||||
"messages": [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "previous reply"},
|
||||
],
|
||||
}
|
||||
)
|
||||
request.app = {
|
||||
"agent_loop": _make_mock_agent(),
|
||||
"model_name": "test-model",
|
||||
"request_timeout": 10.0,
|
||||
"session_lock": asyncio.Lock(),
|
||||
}
|
||||
|
||||
resp = await handle_chat_completions(request)
|
||||
assert resp.status == 400
|
||||
body = json.loads(resp.body)
|
||||
assert "single user message" in body["error"]["message"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_user_message_must_have_user_role() -> None:
|
||||
request = MagicMock()
|
||||
request.json = AsyncMock(
|
||||
return_value={
|
||||
"messages": [{"role": "system", "content": "you are a bot"}],
|
||||
}
|
||||
)
|
||||
request.app = {
|
||||
"agent_loop": _make_mock_agent(),
|
||||
"model_name": "test-model",
|
||||
"request_timeout": 10.0,
|
||||
"session_lock": asyncio.Lock(),
|
||||
}
|
||||
|
||||
resp = await handle_chat_completions(request)
|
||||
assert resp.status == 400
|
||||
body = json.loads(resp.body)
|
||||
assert "single user message" in body["error"]["message"].lower()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_request_uses_fixed_api_session(aiohttp_client, mock_agent) -> None:
|
||||
app = create_app(mock_agent, model_name="test-model")
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hello"}]},
|
||||
)
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["choices"][0]["message"]["content"] == "mock response"
|
||||
assert body["model"] == "test-model"
|
||||
mock_agent.process_direct.assert_called_once_with(
|
||||
content="hello",
|
||||
session_key=API_SESSION_KEY,
|
||||
channel="api",
|
||||
chat_id=API_CHAT_ID,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_followup_requests_share_same_session_key(aiohttp_client) -> None:
|
||||
call_log: list[str] = []
|
||||
|
||||
async def fake_process(content, session_key="", channel="", chat_id=""):
|
||||
call_log.append(session_key)
|
||||
return f"reply to {content}"
|
||||
|
||||
agent = MagicMock()
|
||||
agent.process_direct = fake_process
|
||||
agent._connect_mcp = AsyncMock()
|
||||
agent.close_mcp = AsyncMock()
|
||||
|
||||
app = create_app(agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
r1 = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "first"}]},
|
||||
)
|
||||
r2 = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "second"}]},
|
||||
)
|
||||
|
||||
assert r1.status == 200
|
||||
assert r2.status == 200
|
||||
assert call_log == [API_SESSION_KEY, API_SESSION_KEY]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_fixed_session_requests_are_serialized(aiohttp_client) -> None:
|
||||
order: list[str] = []
|
||||
|
||||
async def slow_process(content, session_key="", channel="", chat_id=""):
|
||||
order.append(f"start:{content}")
|
||||
await asyncio.sleep(0.1)
|
||||
order.append(f"end:{content}")
|
||||
return content
|
||||
|
||||
agent = MagicMock()
|
||||
agent.process_direct = slow_process
|
||||
agent._connect_mcp = AsyncMock()
|
||||
agent.close_mcp = AsyncMock()
|
||||
|
||||
app = create_app(agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
async def send(msg: str):
|
||||
return await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": msg}]},
|
||||
)
|
||||
|
||||
r1, r2 = await asyncio.gather(send("first"), send("second"))
|
||||
assert r1.status == 200
|
||||
assert r2.status == 200
|
||||
# Verify serialization: one process must fully finish before the other starts
|
||||
if order[0] == "start:first":
|
||||
assert order.index("end:first") < order.index("start:second")
|
||||
else:
|
||||
assert order.index("end:second") < order.index("start:first")
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_models_endpoint(aiohttp_client, app) -> None:
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/v1/models")
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["object"] == "list"
|
||||
assert body["data"][0]["id"] == "test-model"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_endpoint(aiohttp_client, app) -> None:
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.get("/health")
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["status"] == "ok"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_multimodal_content_extracts_text(aiohttp_client, mock_agent) -> None:
|
||||
app = create_app(mock_agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "describe this"},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
|
||||
],
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
assert resp.status == 200
|
||||
mock_agent.process_direct.assert_called_once_with(
|
||||
content="describe this",
|
||||
session_key=API_SESSION_KEY,
|
||||
channel="api",
|
||||
chat_id=API_CHAT_ID,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_response_retry_then_success(aiohttp_client) -> None:
|
||||
call_count = 0
|
||||
|
||||
async def sometimes_empty(content, session_key="", channel="", chat_id=""):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return ""
|
||||
return "recovered response"
|
||||
|
||||
agent = MagicMock()
|
||||
agent.process_direct = sometimes_empty
|
||||
agent._connect_mcp = AsyncMock()
|
||||
agent.close_mcp = AsyncMock()
|
||||
|
||||
app = create_app(agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hello"}]},
|
||||
)
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["choices"][0]["message"]["content"] == "recovered response"
|
||||
assert call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_response_falls_back(aiohttp_client) -> None:
|
||||
call_count = 0
|
||||
|
||||
async def always_empty(content, session_key="", channel="", chat_id=""):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return ""
|
||||
|
||||
agent = MagicMock()
|
||||
agent.process_direct = always_empty
|
||||
agent._connect_mcp = AsyncMock()
|
||||
agent.close_mcp = AsyncMock()
|
||||
|
||||
app = create_app(agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hello"}]},
|
||||
)
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["choices"][0]["message"]["content"] == "I've completed processing but have no response to give."
|
||||
assert call_count == 2
|
||||
@ -95,6 +95,14 @@ def test_exec_extract_absolute_paths_keeps_full_windows_path() -> None:
|
||||
assert paths == [r"C:\user\workspace\txt"]
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_captures_windows_drive_root_path() -> None:
|
||||
"""Windows drive root paths like `E:\\` must be extracted for workspace guarding."""
|
||||
# Note: raw strings cannot end with a single backslash.
|
||||
cmd = "dir E:\\"
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
assert paths == ["E:\\"]
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_ignores_relative_posix_segments() -> None:
|
||||
cmd = ".venv/bin/python script.py"
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
@ -134,6 +142,45 @@ def test_exec_guard_blocks_quoted_home_path_outside_workspace(tmp_path) -> None:
|
||||
assert error == "Error: Command blocked by safety guard (path outside working dir)"
|
||||
|
||||
|
||||
def test_exec_guard_blocks_windows_drive_root_outside_workspace(monkeypatch) -> None:
|
||||
import nanobot.agent.tools.shell as shell_mod
|
||||
|
||||
class FakeWindowsPath:
|
||||
def __init__(self, raw: str) -> None:
|
||||
self.raw = raw.rstrip("\\") + ("\\" if raw.endswith("\\") else "")
|
||||
|
||||
def resolve(self) -> "FakeWindowsPath":
|
||||
return self
|
||||
|
||||
def expanduser(self) -> "FakeWindowsPath":
|
||||
return self
|
||||
|
||||
def is_absolute(self) -> bool:
|
||||
return len(self.raw) >= 3 and self.raw[1:3] == ":\\"
|
||||
|
||||
@property
|
||||
def parents(self) -> list["FakeWindowsPath"]:
|
||||
if not self.is_absolute():
|
||||
return []
|
||||
trimmed = self.raw.rstrip("\\")
|
||||
if len(trimmed) <= 2:
|
||||
return []
|
||||
idx = trimmed.rfind("\\")
|
||||
if idx <= 2:
|
||||
return [FakeWindowsPath(trimmed[:2] + "\\")]
|
||||
parent = FakeWindowsPath(trimmed[:idx])
|
||||
return [parent, *parent.parents]
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return isinstance(other, FakeWindowsPath) and self.raw.lower() == other.raw.lower()
|
||||
|
||||
monkeypatch.setattr(shell_mod, "Path", FakeWindowsPath)
|
||||
|
||||
tool = ExecTool(restrict_to_workspace=True)
|
||||
error = tool._guard_command("dir E:\\", "E:\\workspace")
|
||||
assert error == "Error: Command blocked by safety guard (path outside working dir)"
|
||||
|
||||
|
||||
# --- cast_params tests ---
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user