Merge origin/main into feat/api-file-upload

Keep the API file upload branch current with main, enforce the documented JSON base64 per-file limit, and avoid leaking document extraction error strings into user prompts.

Made-with: Cursor
This commit is contained in:
Xubin Ren 2026-04-14 12:29:43 +00:00
commit 2502fc616b
93 changed files with 14003 additions and 761 deletions

151
README.md
View File

@ -394,7 +394,8 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso
"enabled": true,
"token": "YOUR_BOT_TOKEN",
"allowFrom": ["YOUR_USER_ID"],
"groupPolicy": "mention"
"groupPolicy": "mention",
"streaming": true
}
}
}
@ -405,6 +406,7 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso
> - `"open"` — Respond to all messages
> DMs always respond when the sender is in `allowFrom`.
> - If you set group policy to open create new threads as private threads and then @ the bot into it. Otherwise the thread itself and the channel in which you spawned it will spawn a bot session.
> `streaming` defaults to `true`. Disable it only if you explicitly want non-streaming replies.
**5. Invite the bot**
- OAuth2 → URL Generator
@ -558,7 +560,11 @@ Uses **WebSocket** long connection — no public IP required.
"verificationToken": "",
"allowFrom": ["ou_YOUR_OPEN_ID"],
"groupPolicy": "mention",
"streaming": true
"reactEmoji": "OnIt",
"doneEmoji": "DONE",
"toolHintPrefix": "🔧",
"streaming": true,
"domain": "feishu"
}
}
}
@ -568,6 +574,10 @@ Uses **WebSocket** long connection — no public IP required.
> `encryptKey` and `verificationToken` are optional for Long Connection mode.
> `allowFrom`: Add your open_id (find it in nanobot logs when you message the bot). Use `["*"]` to allow all users.
> `groupPolicy`: `"mention"` (default — respond only when @mentioned), `"open"` (respond to all group messages). Private chats always respond.
> `reactEmoji`: Emoji for "processing" status (default: `OnIt`). See [available emojis](https://open.larkoffice.com/document/server-docs/im-v1/message-reaction/emojis-introduce).
> `doneEmoji`: Optional emoji for "completed" status (e.g., `DONE`, `OK`, `HEART`). When set, bot adds this reaction after removing `reactEmoji`.
> `toolHintPrefix`: Prefix for inline tool hints in streaming cards (default: `🔧`).
> `domain`: `"feishu"` (default) for China (open.feishu.cn), `"lark"` for international Lark (open.larksuite.com).
**3. Run**
@ -1043,6 +1053,30 @@ Connects directly to any OpenAI-compatible endpoint — LM Studio, llama.cpp, To
```
> For local servers that don't require a key, set `apiKey` to any non-empty string (e.g. `"no-key"`).
>
> `custom` is the right choice for providers that expose an OpenAI-compatible **chat completions** API. It does **not** force third-party endpoints onto the OpenAI/Azure **Responses API**.
>
> If your proxy or gateway is specifically Responses-API-compatible, use the `azure_openai` provider shape instead and point `apiBase` at that endpoint:
>
> ```json
> {
> "providers": {
> "azure_openai": {
> "apiKey": "your-api-key",
> "apiBase": "https://api.your-provider.com",
> "defaultModel": "your-model-name"
> }
> },
> "agents": {
> "defaults": {
> "provider": "azure_openai",
> "model": "your-model-name"
> }
> }
> }
> ```
>
> In short: **chat-completions-compatible endpoint → `custom`**; **Responses-compatible endpoint → `azure_openai`**.
</details>
@ -1304,6 +1338,7 @@ If you need to allow trusted private ranges such as Tailscale / CGNAT addresses,
| `brave` | `apiKey` | `BRAVE_API_KEY` | No |
| `tavily` | `apiKey` | `TAVILY_API_KEY` | No |
| `jina` | `apiKey` | `JINA_API_KEY` | Free tier (10M tokens) |
| `kagi` | `apiKey` | `KAGI_API_KEY` | No |
| `searxng` | `baseUrl` | `SEARXNG_BASE_URL` | Yes (self-hosted) |
| `duckduckgo` (default) | — | — | Yes |
@ -1360,6 +1395,20 @@ If you need to allow trusted private ranges such as Tailscale / CGNAT addresses,
}
```
**Kagi:**
```json
{
"tools": {
"web": {
"search": {
"provider": "kagi",
"apiKey": "your-kagi-api-key"
}
}
}
}
```
**SearXNG** (self-hosted, no API key needed):
```json
{
@ -1495,6 +1544,35 @@ MCP tools are automatically discovered and registered on startup. The LLM can us
**Docker security**: The official Docker image runs as a non-root user (`nanobot`, UID 1000) with bubblewrap pre-installed. When using `docker-compose.yml`, the container drops all Linux capabilities except `SYS_ADMIN` (required for bwrap's namespace isolation).
### Auto Compact
When a user is idle for longer than a configured threshold, nanobot **proactively** compresses the older part of the session context into a summary while keeping a recent legal suffix of live messages. This reduces token cost and first-token latency when the user returns — instead of re-processing a long stale context with an expired KV cache, the model receives a compact summary, the most recent live context, and fresh input.
```json
{
"agents": {
"defaults": {
"idleCompactAfterMinutes": 15
}
}
}
```
| Option | Default | Description |
|--------|---------|-------------|
| `agents.defaults.idleCompactAfterMinutes` | `0` (disabled) | Minutes of idle time before auto-compaction starts. Set to `0` to disable. Recommended: `15` — close to a typical LLM KV cache expiry window, so stale sessions get compacted before the user returns. |
`sessionTtlMinutes` remains accepted as a legacy alias for backward compatibility, but `idleCompactAfterMinutes` is the preferred config key going forward.
How it works:
1. **Idle detection**: On each idle tick (~1 s), checks all sessions for expiration.
2. **Background compaction**: Idle sessions summarize the older live prefix via LLM and keep the most recent legal suffix (currently 8 messages).
3. **Summary injection**: When the user returns, the summary is injected as runtime context (one-shot, not persisted) alongside the retained recent suffix.
4. **Restart-safe resume**: The summary is also mirrored into session metadata so it can still be recovered after a process restart.
> [!TIP]
> Think of auto compact as "summarize older context, keep the freshest live turns." It is not a hard session reset.
### Timezone
Time is context. Context should be precise.
@ -1517,6 +1595,52 @@ Common examples: `UTC`, `America/New_York`, `America/Los_Angeles`, `Europe/Londo
> Need another timezone? Browse the full [IANA Time Zone Database](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones).
### Unified Session
By default, each channel × chat ID combination gets its own session. If you use nanobot across multiple channels (e.g. Telegram + Discord + CLI) and want them to share the same conversation, enable `unifiedSession`:
```json
{
"agents": {
"defaults": {
"unifiedSession": true
}
}
}
```
When enabled, all incoming messages — regardless of which channel they arrive on — are routed into a single shared session. Switching from Telegram to Discord (or any other channel) continues the same conversation seamlessly.
| Behavior | `false` (default) | `true` |
|----------|-------------------|--------|
| Session key | `channel:chat_id` | `unified:default` |
| Cross-channel continuity | No | Yes |
| `/new` clears | Current channel session | Shared session |
| `/stop` finds tasks | By channel session | By shared session |
| Existing `session_key_override` (e.g. Telegram thread) | Respected | Still respected — not overwritten |
> This is designed for single-user, multi-device setups. It is **off by default** — existing users see zero behavior change.
### Disabled Skills
nanobot ships with built-in skills, and your workspace can also define custom skills under `skills/`. If you want to hide specific skills from the agent, set `agents.defaults.disabledSkills` to a list of skill directory names:
```json
{
"agents": {
"defaults": {
"disabledSkills": ["github", "weather"]
}
}
}
```
Disabled skills are excluded from the main agent's skill summary, from always-on skill injection, and from subagent skill summaries. This is useful when some bundled skills are unnecessary for your deployment or should not be exposed to end users.
| Option | Default | Description |
|--------|---------|-------------|
| `agents.defaults.disabledSkills` | `[]` | List of skill directory names to exclude from loading. Applies to both built-in skills and workspace skills. |
## 🧩 Multiple Instances
Run multiple nanobot instances simultaneously with separate configs and runtime data. Use `--config` as the main entrypoint. Optionally pass `--workspace` during `onboard` when you want to initialize or update the saved workspace for a specific instance.
@ -1603,6 +1727,7 @@ Example config:
}
},
"gateway": {
"host": "127.0.0.1",
"port": 18790
}
}
@ -1615,6 +1740,14 @@ nanobot gateway --config ~/.nanobot-telegram/config.json
nanobot gateway --config ~/.nanobot-discord/config.json
```
Each gateway instance also exposes a lightweight HTTP health endpoint on
`gateway.host:gateway.port`. By default, the gateway binds to `127.0.0.1`,
so the endpoint stays local unless you explicitly set `gateway.host` to a
public or LAN-facing address.
- `GET /health` returns `{"status":"ok"}`
- Other paths return `404`
Override workspace for one-off runs when needed:
```bash
@ -1642,6 +1775,7 @@ time.
- `memory/history.jsonl` stores append-only summarized history
- `SOUL.md`, `USER.md`, and `memory/MEMORY.md` store long-term knowledge managed by Dream
- `Dream` can also promote repeated workflows into reusable workspace skills under `skills/`
- `Dream` runs on a schedule and can also be triggered manually
- memory changes can be inspected and restored with built-in commands
@ -1758,6 +1892,19 @@ By default, the API binds to `127.0.0.1:8900`. You can change this in `config.js
- Fixed model: omit `model`, or pass the same model shown by `/v1/models`
- No streaming: `stream=true` is not supported
- **File uploads**: supports images, PDF, Word (.docx), Excel (.xlsx), PowerPoint (.pptx) via JSON base64 or `multipart/form-data` (max 10MB per file)
- API requests run in the synthetic `api` channel, so the `message` tool does **not** automatically deliver to Telegram/Discord/etc. To proactively send to another chat, call `message` with an explicit `channel` and `chat_id` for an enabled channel.
Example tool call for cross-channel delivery from an API session:
```json
{
"content": "Build finished successfully.",
"channel": "telegram",
"chat_id": "123456789"
}
```
If `channel` points to a channel that is not enabled in your config, nanobot will queue the outbound event but no platform delivery will occur.
### Endpoints

View File

@ -290,7 +290,6 @@ async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] |
|------|---------|
| `_stream_delta: True` | A content chunk (delta contains the new text) |
| `_stream_end: True` | Streaming finished (delta is empty) |
| `_resuming: True` | More streaming rounds coming (e.g. tool call then another response) |
### Example: Webhook with Streaming

331
docs/WEBSOCKET.md Normal file
View File

@ -0,0 +1,331 @@
# WebSocket Server Channel
Nanobot can act as a WebSocket server, allowing external clients (web apps, CLIs, scripts) to interact with the agent in real time via persistent connections.
## Features
- Bidirectional real-time communication over WebSocket
- Streaming support — receive agent responses token by token
- Token-based authentication (static tokens and short-lived issued tokens)
- Per-connection sessions — each connection gets a unique `chat_id`
- TLS/SSL support (WSS) with enforced TLSv1.2 minimum
- Client allow-list via `allowFrom`
- Auto-cleanup of dead connections
## Quick Start
### 1. Configure
Add to `config.json` under `channels.websocket`:
```json
{
"channels": {
"websocket": {
"enabled": true,
"host": "127.0.0.1",
"port": 8765,
"path": "/",
"websocketRequiresToken": false,
"allowFrom": ["*"],
"streaming": true
}
}
}
```
### 2. Start nanobot
```bash
nanobot gateway
```
You should see:
```
WebSocket server listening on ws://127.0.0.1:8765/
```
### 3. Connect a client
```bash
# Using websocat
websocat ws://127.0.0.1:8765/?client_id=alice
# Using Python
import asyncio, json, websockets
async def main():
async with websockets.connect("ws://127.0.0.1:8765/?client_id=alice") as ws:
ready = json.loads(await ws.recv())
print(ready) # {"event": "ready", "chat_id": "...", "client_id": "alice"}
await ws.send(json.dumps({"content": "Hello nanobot!"}))
reply = json.loads(await ws.recv())
print(reply["text"])
asyncio.run(main())
```
## Connection URL
```
ws://{host}:{port}{path}?client_id={id}&token={token}
```
| Parameter | Required | Description |
|-----------|----------|-------------|
| `client_id` | No | Identifier for `allowFrom` authorization. Auto-generated as `anon-xxxxxxxxxxxx` if omitted. Truncated to 128 chars. |
| `token` | Conditional | Authentication token. Required when `websocketRequiresToken` is `true` or `token` (static secret) is configured. |
## Wire Protocol
All frames are JSON text. Each message has an `event` field.
### Server → Client
**`ready`** — sent immediately after connection is established:
```json
{
"event": "ready",
"chat_id": "uuid-v4",
"client_id": "alice"
}
```
**`message`** — full agent response:
```json
{
"event": "message",
"text": "Hello! How can I help?",
"media": ["/tmp/image.png"],
"reply_to": "msg-id"
}
```
`media` and `reply_to` are only present when applicable.
**`delta`** — streaming text chunk (only when `streaming: true`):
```json
{
"event": "delta",
"text": "Hello",
"stream_id": "s1"
}
```
**`stream_end`** — signals the end of a streaming segment:
```json
{
"event": "stream_end",
"stream_id": "s1"
}
```
### Client → Server
Send plain text:
```json
"Hello nanobot!"
```
Or send a JSON object with a recognized text field:
```json
{"content": "Hello nanobot!"}
```
Recognized fields: `content`, `text`, `message` (checked in that order). Invalid JSON is treated as plain text.
## Configuration Reference
All fields go under `channels.websocket` in `config.json`.
### Connection
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `enabled` | bool | `false` | Enable the WebSocket server. |
| `host` | string | `"127.0.0.1"` | Bind address. Use `"0.0.0.0"` to accept external connections. |
| `port` | int | `8765` | Listen port. |
| `path` | string | `"/"` | WebSocket upgrade path. Trailing slashes are normalized (root `/` is preserved). |
| `maxMessageBytes` | int | `1048576` | Maximum inbound message size in bytes (1 KB 16 MB). |
### Authentication
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `token` | string | `""` | Static shared secret. When set, clients must provide `?token=<value>` matching this secret (timing-safe comparison). Issued tokens are also accepted as a fallback. |
| `websocketRequiresToken` | bool | `true` | When `true` and no static `token` is configured, clients must still present a valid issued token. Set to `false` to allow unauthenticated connections (only safe for local/trusted networks). |
| `tokenIssuePath` | string | `""` | HTTP path for issuing short-lived tokens. Must differ from `path`. See [Token Issuance](#token-issuance). |
| `tokenIssueSecret` | string | `""` | Secret required to obtain tokens via the issue endpoint. If empty, any client can obtain tokens (logged as a warning). |
| `tokenTtlS` | int | `300` | Time-to-live for issued tokens in seconds (30 86,400). |
### Access Control
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `allowFrom` | list of string | `["*"]` | Allowed `client_id` values. `"*"` allows all; `[]` denies all. |
### Streaming
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `streaming` | bool | `true` | Enable streaming mode. The agent sends `delta` + `stream_end` frames instead of a single `message`. |
### Keep-alive
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `pingIntervalS` | float | `20.0` | WebSocket ping interval in seconds (5 300). |
| `pingTimeoutS` | float | `20.0` | Time to wait for a pong before closing the connection (5 300). |
### TLS/SSL
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `sslCertfile` | string | `""` | Path to the TLS certificate file (PEM). Both `sslCertfile` and `sslKeyfile` must be set to enable WSS. |
| `sslKeyfile` | string | `""` | Path to the TLS private key file (PEM). Minimum TLS version is enforced as TLSv1.2. |
## Token Issuance
For production deployments where `websocketRequiresToken: true`, use short-lived tokens instead of embedding static secrets in clients.
### How it works
1. Client sends `GET {tokenIssuePath}` with `Authorization: Bearer {tokenIssueSecret}` (or `X-Nanobot-Auth` header).
2. Server responds with a one-time-use token:
```json
{"token": "nbwt_aBcDeFg...", "expires_in": 300}
```
3. Client opens WebSocket with `?token=nbwt_aBcDeFg...&client_id=...`.
4. The token is consumed (single use) and cannot be reused.
### Example setup
```json
{
"channels": {
"websocket": {
"enabled": true,
"port": 8765,
"path": "/ws",
"tokenIssuePath": "/auth/token",
"tokenIssueSecret": "your-secret-here",
"tokenTtlS": 300,
"websocketRequiresToken": true,
"allowFrom": ["*"],
"streaming": true
}
}
}
```
Client flow:
```bash
# 1. Obtain a token
curl -H "Authorization: Bearer your-secret-here" http://127.0.0.1:8765/auth/token
# 2. Connect using the token
websocat "ws://127.0.0.1:8765/ws?client_id=alice&token=nbwt_aBcDeFg..."
```
### Limits
- Issued tokens are single-use — each token can only complete one handshake.
- Outstanding tokens are capped at 10,000. Requests beyond this return HTTP 429.
- Expired tokens are purged lazily on each issue or validation request.
## Security Notes
- **Timing-safe comparison**: Static token validation uses `hmac.compare_digest` to prevent timing attacks.
- **Defense in depth**: `allowFrom` is checked at both the HTTP handshake level and the message level.
- **Token isolation**: Each WebSocket connection gets a unique `chat_id`. Clients cannot access other sessions.
- **TLS enforcement**: When SSL is enabled, TLSv1.2 is the minimum allowed version.
- **Default-secure**: `websocketRequiresToken` defaults to `true`. Explicitly set it to `false` only on trusted networks.
## Media Files
Outbound `message` events may include a `media` field containing local filesystem paths. Remote clients cannot access these files directly — they need either:
- A shared filesystem mount, or
- An HTTP file server serving the nanobot media directory
## Common Patterns
### Trusted local network (no auth)
```json
{
"channels": {
"websocket": {
"enabled": true,
"host": "0.0.0.0",
"port": 8765,
"websocketRequiresToken": false,
"allowFrom": ["*"],
"streaming": true
}
}
}
```
### Static token (simple auth)
```json
{
"channels": {
"websocket": {
"enabled": true,
"token": "my-shared-secret",
"allowFrom": ["alice", "bob"]
}
}
}
```
Clients connect with `?token=my-shared-secret&client_id=alice`.
### Public endpoint with issued tokens
```json
{
"channels": {
"websocket": {
"enabled": true,
"host": "0.0.0.0",
"port": 8765,
"path": "/ws",
"tokenIssuePath": "/auth/token",
"tokenIssueSecret": "production-secret",
"websocketRequiresToken": true,
"sslCertfile": "/etc/ssl/certs/server.pem",
"sslKeyfile": "/etc/ssl/private/server-key.pem",
"allowFrom": ["*"]
}
}
}
```
### Custom path
```json
{
"channels": {
"websocket": {
"enabled": true,
"path": "/chat/ws",
"allowFrom": ["*"]
}
}
}
```
Clients connect to `ws://127.0.0.1:8765/chat/ws?client_id=...`. Trailing slashes are normalized, so `/chat/ws/` works the same.

View File

@ -2,7 +2,29 @@
nanobot - A lightweight AI agent framework
"""
__version__ = "0.1.5"
from importlib.metadata import PackageNotFoundError, version as _pkg_version
from pathlib import Path
import tomllib
def _read_pyproject_version() -> str | None:
"""Read the source-tree version when package metadata is unavailable."""
pyproject = Path(__file__).resolve().parent.parent / "pyproject.toml"
if not pyproject.exists():
return None
data = tomllib.loads(pyproject.read_text(encoding="utf-8"))
return data.get("project", {}).get("version")
def _resolve_version() -> str:
try:
return _pkg_version("nanobot-ai")
except PackageNotFoundError:
# Source checkouts often import nanobot without installed dist-info.
return _read_pyproject_version() or "0.1.5"
__version__ = _resolve_version()
__logo__ = "🐈"
from nanobot.nanobot import Nanobot, RunResult

View File

@ -0,0 +1,123 @@
"""Auto compact: proactive compression of idle sessions to reduce token cost and latency."""
from __future__ import annotations
from collections.abc import Collection
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Coroutine
from loguru import logger
from nanobot.session.manager import Session, SessionManager
if TYPE_CHECKING:
from nanobot.agent.memory import Consolidator
class AutoCompact:
_RECENT_SUFFIX_MESSAGES = 8
def __init__(self, sessions: SessionManager, consolidator: Consolidator,
session_ttl_minutes: int = 0):
self.sessions = sessions
self.consolidator = consolidator
self._ttl = session_ttl_minutes
self._archiving: set[str] = set()
self._summaries: dict[str, tuple[str, datetime]] = {}
def _is_expired(self, ts: datetime | str | None,
now: datetime | None = None) -> bool:
if self._ttl <= 0 or not ts:
return False
if isinstance(ts, str):
ts = datetime.fromisoformat(ts)
return ((now or datetime.now()) - ts).total_seconds() >= self._ttl * 60
@staticmethod
def _format_summary(text: str, last_active: datetime) -> str:
idle_min = int((datetime.now() - last_active).total_seconds() / 60)
return f"Inactive for {idle_min} minutes.\nPrevious conversation summary: {text}"
def _split_unconsolidated(
self, session: Session,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
"""Split live session tail into archiveable prefix and retained recent suffix."""
tail = list(session.messages[session.last_consolidated:])
if not tail:
return [], []
probe = Session(
key=session.key,
messages=tail.copy(),
created_at=session.created_at,
updated_at=session.updated_at,
metadata={},
last_consolidated=0,
)
probe.retain_recent_legal_suffix(self._RECENT_SUFFIX_MESSAGES)
kept = probe.messages
cut = len(tail) - len(kept)
return tail[:cut], kept
def check_expired(self, schedule_background: Callable[[Coroutine], None],
active_session_keys: Collection[str] = ()) -> None:
"""Schedule archival for idle sessions, skipping those with in-flight agent tasks."""
now = datetime.now()
for info in self.sessions.list_sessions():
key = info.get("key", "")
if not key or key in self._archiving:
continue
if key in active_session_keys:
continue
if self._is_expired(info.get("updated_at"), now):
self._archiving.add(key)
schedule_background(self._archive(key))
async def _archive(self, key: str) -> None:
try:
self.sessions.invalidate(key)
session = self.sessions.get_or_create(key)
archive_msgs, kept_msgs = self._split_unconsolidated(session)
if not archive_msgs and not kept_msgs:
session.updated_at = datetime.now()
self.sessions.save(session)
return
last_active = session.updated_at
summary = ""
if archive_msgs:
summary = await self.consolidator.archive(archive_msgs) or ""
if summary and summary != "(nothing)":
self._summaries[key] = (summary, last_active)
session.metadata["_last_summary"] = {"text": summary, "last_active": last_active.isoformat()}
session.messages = kept_msgs
session.last_consolidated = 0
session.updated_at = datetime.now()
self.sessions.save(session)
if archive_msgs:
logger.info(
"Auto-compact: archived {} (archived={}, kept={}, summary={})",
key,
len(archive_msgs),
len(kept_msgs),
bool(summary),
)
except Exception:
logger.exception("Auto-compact: failed for {}", key)
finally:
self._archiving.discard(key)
def prepare_session(self, session: Session, key: str) -> tuple[Session, str | None]:
if key in self._archiving or self._is_expired(session.updated_at):
logger.info("Auto-compact: reloading session {} (archiving={})", key, key in self._archiving)
session = self.sessions.get_or_create(key)
# Hot path: summary from in-memory dict (process hasn't restarted).
# Also clean metadata copy so stale _last_summary never leaks to disk.
entry = self._summaries.pop(key, None)
if entry:
session.metadata.pop("_last_summary", None)
return session, self._format_summary(entry[0], entry[1])
if "_last_summary" in session.metadata:
meta = session.metadata.pop("_last_summary")
self.sessions.save(session)
return session, self._format_summary(meta["text"], datetime.fromisoformat(meta["last_active"]))
return session, None

View File

@ -6,12 +6,10 @@ import platform
from pathlib import Path
from typing import Any
from nanobot.utils.helpers import current_time_str
from nanobot.agent.memory import MemoryStore
from nanobot.utils.prompt_templates import render_template
from nanobot.agent.skills import SkillsLoader
from nanobot.utils.helpers import build_assistant_message, detect_image_mime
from nanobot.utils.helpers import build_assistant_message, current_time_str, detect_image_mime
from nanobot.utils.prompt_templates import render_template
class ContextBuilder:
@ -20,12 +18,13 @@ class ContextBuilder:
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md"]
_RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]"
_MAX_RECENT_HISTORY = 50
_RUNTIME_CONTEXT_END = "[/Runtime Context]"
def __init__(self, workspace: Path, timezone: str | None = None):
def __init__(self, workspace: Path, timezone: str | None = None, disabled_skills: list[str] | None = None):
self.workspace = workspace
self.timezone = timezone
self.memory = MemoryStore(workspace)
self.skills = SkillsLoader(workspace)
self.skills = SkillsLoader(workspace, disabled_skills=set(disabled_skills) if disabled_skills else None)
def build_system_prompt(
self,
@ -79,12 +78,15 @@ class ContextBuilder:
@staticmethod
def _build_runtime_context(
channel: str | None, chat_id: str | None, timezone: str | None = None,
session_summary: str | None = None,
) -> str:
"""Build untrusted runtime metadata block for injection before the user message."""
lines = [f"Current Time: {current_time_str(timezone)}"]
if channel and chat_id:
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines)
if session_summary:
lines += ["", "[Resumed Session]", session_summary]
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines) + "\n" + ContextBuilder._RUNTIME_CONTEXT_END
@staticmethod
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
@ -121,9 +123,10 @@ class ContextBuilder:
channel: str | None = None,
chat_id: str | None = None,
current_role: str = "user",
session_summary: str | None = None,
) -> list[dict[str, Any]]:
"""Build the complete message list for an LLM call."""
runtime_ctx = self._build_runtime_context(channel, chat_id, self.timezone)
runtime_ctx = self._build_runtime_context(channel, chat_id, self.timezone, session_summary=session_summary)
user_content = self._build_user_content(current_message, media)
# Merge runtime context and user content into a single user message
@ -176,7 +179,7 @@ class ContextBuilder:
# Try document text extraction
from nanobot.utils.document import extract_text
extracted = extract_text(p)
if extracted and not extracted.startswith("Error"):
if extracted and not extracted.startswith("[error:"):
doc_texts.append(f"[File: {p.name}]\n{extracted}")
# Build final content

View File

@ -29,6 +29,9 @@ class AgentHookContext:
class AgentHook:
"""Minimal lifecycle surface for shared runner customization."""
def __init__(self, reraise: bool = False) -> None:
self._reraise = reraise
def wants_streaming(self) -> bool:
return False
@ -62,6 +65,7 @@ class CompositeHook(AgentHook):
__slots__ = ("_hooks",)
def __init__(self, hooks: list[AgentHook]) -> None:
super().__init__()
self._hooks = list(hooks)
def wants_streaming(self) -> bool:
@ -69,6 +73,10 @@ class CompositeHook(AgentHook):
async def _for_each_hook_safe(self, method_name: str, *args: Any, **kwargs: Any) -> None:
for h in self._hooks:
if getattr(h, "_reraise", False):
await getattr(h, method_name)(*args, **kwargs)
continue
try:
await getattr(h, method_name)(*args, **kwargs)
except Exception:

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
import dataclasses
import json
import os
import time
@ -12,27 +13,30 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
from loguru import logger
from nanobot.agent.autocompact import AutoCompact
from nanobot.agent.context import ContextBuilder
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
from nanobot.agent.memory import Consolidator, Dream
from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
from nanobot.agent.subagent import SubagentManager
from nanobot.agent.tools.cron import CronTool
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
from nanobot.agent.tools.message import MessageTool
from nanobot.agent.tools.notebook import NotebookEditTool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.search import GlobTool, GrepTool
from nanobot.agent.tools.shell import ExecTool
from nanobot.agent.tools.spawn import SpawnTool
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
from nanobot.bus.events import InboundMessage, OutboundMessage
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
from nanobot.bus.queue import MessageBus
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
from nanobot.config.schema import AgentDefaults
from nanobot.providers.base import LLMProvider
from nanobot.session.manager import Session, SessionManager
from nanobot.utils.helpers import image_placeholder_text, truncate_text
from nanobot.utils.helpers import image_placeholder_text
from nanobot.utils.helpers import truncate_text as truncate_text_fn
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
if TYPE_CHECKING:
@ -40,6 +44,9 @@ if TYPE_CHECKING:
from nanobot.cron.service import CronService
UNIFIED_SESSION_KEY = "unified:default"
class _LoopHook(AgentHook):
"""Core hook for the main loop."""
@ -54,6 +61,7 @@ class _LoopHook(AgentHook):
chat_id: str = "direct",
message_id: str | None = None,
) -> None:
super().__init__(reraise=True)
self._loop = agent_loop
self._on_progress = on_progress
self._on_stream = on_stream
@ -72,7 +80,7 @@ class _LoopHook(AgentHook):
prev_clean = strip_think(self._stream_buf)
self._stream_buf += delta
new_clean = strip_think(self._stream_buf)
incremental = new_clean[len(prev_clean):]
incremental = new_clean[len(prev_clean) :]
if incremental and self._on_stream:
await self._on_stream(incremental)
@ -109,43 +117,6 @@ class _LoopHook(AgentHook):
return self._loop._strip_think(content)
class _LoopHookChain(AgentHook):
"""Run the core hook before extra hooks."""
__slots__ = ("_primary", "_extras")
def __init__(self, primary: AgentHook, extra_hooks: list[AgentHook]) -> None:
self._primary = primary
self._extras = CompositeHook(extra_hooks)
def wants_streaming(self) -> bool:
return self._primary.wants_streaming() or self._extras.wants_streaming()
async def before_iteration(self, context: AgentHookContext) -> None:
await self._primary.before_iteration(context)
await self._extras.before_iteration(context)
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
await self._primary.on_stream(context, delta)
await self._extras.on_stream(context, delta)
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
await self._primary.on_stream_end(context, resuming=resuming)
await self._extras.on_stream_end(context, resuming=resuming)
async def before_execute_tools(self, context: AgentHookContext) -> None:
await self._primary.before_execute_tools(context)
await self._extras.before_execute_tools(context)
async def after_iteration(self, context: AgentHookContext) -> None:
await self._primary.after_iteration(context)
await self._extras.after_iteration(context)
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
content = self._primary.finalize_content(context, content)
return self._extras.finalize_content(context, content)
class AgentLoop:
"""
The agent loop is the core processing engine.
@ -159,6 +130,7 @@ class AgentLoop:
"""
_RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
_PENDING_USER_TURN_KEY = "pending_user_turn"
def __init__(
self,
@ -179,7 +151,10 @@ class AgentLoop:
mcp_servers: dict | None = None,
channels_config: ChannelsConfig | None = None,
timezone: str | None = None,
session_ttl_minutes: int = 0,
hooks: list[AgentHook] | None = None,
unified_session: bool = False,
disabled_skills: list[str] | None = None,
):
from nanobot.config.schema import ExecToolConfig, WebToolsConfig
@ -212,7 +187,7 @@ class AgentLoop:
self._last_usage: dict[str, int] = {}
self._extra_hooks: list[AgentHook] = hooks or []
self.context = ContextBuilder(workspace, timezone=timezone)
self.context = ContextBuilder(workspace, timezone=timezone, disabled_skills=disabled_skills)
self.sessions = session_manager or SessionManager(workspace)
self.tools = ToolRegistry()
self.runner = AgentRunner(provider)
@ -225,16 +200,21 @@ class AgentLoop:
max_tool_result_chars=self.max_tool_result_chars,
exec_config=self.exec_config,
restrict_to_workspace=restrict_to_workspace,
disabled_skills=disabled_skills,
)
self._unified_session = unified_session
self._running = False
self._mcp_servers = mcp_servers or {}
self._mcp_stack: AsyncExitStack | None = None
self._mcp_stacks: dict[str, AsyncExitStack] = {}
self._mcp_connected = False
self._mcp_connecting = False
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
self._background_tasks: list[asyncio.Task] = []
self._session_locks: dict[str, asyncio.Lock] = {}
# Per-session pending queues for mid-turn message injection.
# When a session has an active task, new messages for that session
# are routed here instead of creating a new task.
self._pending_queues: dict[str, asyncio.Queue] = {}
# NANOBOT_MAX_CONCURRENT_REQUESTS: <=0 means unlimited; default 3.
_max = int(os.environ.get("NANOBOT_MAX_CONCURRENT_REQUESTS", "3"))
self._concurrency_gate: asyncio.Semaphore | None = (
@ -250,6 +230,11 @@ class AgentLoop:
get_tool_definitions=self.tools.get_definitions,
max_completion_tokens=provider.generation.max_tokens,
)
self.auto_compact = AutoCompact(
sessions=self.sessions,
consolidator=self.consolidator,
session_ttl_minutes=session_ttl_minutes,
)
self.dream = Dream(
store=self.context.memory,
provider=provider,
@ -261,23 +246,35 @@ class AgentLoop:
def _register_default_tools(self) -> None:
"""Register the default set of tools."""
allowed_dir = self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
allowed_dir = (
self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
)
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
self.tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read))
self.tools.register(
ReadFileTool(
workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read
)
)
for cls in (WriteFileTool, EditFileTool, ListDirTool):
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
for cls in (GlobTool, GrepTool):
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
self.tools.register(NotebookEditTool(workspace=self.workspace, allowed_dir=allowed_dir))
if self.exec_config.enable:
self.tools.register(ExecTool(
working_dir=str(self.workspace),
timeout=self.exec_config.timeout,
restrict_to_workspace=self.restrict_to_workspace,
sandbox=self.exec_config.sandbox,
path_append=self.exec_config.path_append,
))
self.tools.register(
ExecTool(
working_dir=str(self.workspace),
timeout=self.exec_config.timeout,
restrict_to_workspace=self.restrict_to_workspace,
sandbox=self.exec_config.sandbox,
path_append=self.exec_config.path_append,
allowed_env_keys=self.exec_config.allowed_env_keys,
)
)
if self.web_config.enable:
self.tools.register(WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy))
self.tools.register(
WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy)
)
self.tools.register(WebFetchTool(proxy=self.web_config.proxy))
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
self.tools.register(SpawnTool(manager=self.subagents))
@ -292,19 +289,19 @@ class AgentLoop:
return
self._mcp_connecting = True
from nanobot.agent.tools.mcp import connect_mcp_servers
try:
self._mcp_stack = AsyncExitStack()
await self._mcp_stack.__aenter__()
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
self._mcp_connected = True
self._mcp_stacks = await connect_mcp_servers(self._mcp_servers, self.tools)
if self._mcp_stacks:
self._mcp_connected = True
else:
logger.warning("No MCP servers connected successfully (will retry next message)")
except asyncio.CancelledError:
logger.warning("MCP connection cancelled (will retry next message)")
self._mcp_stacks.clear()
except BaseException as e:
logger.error("Failed to connect MCP servers (will retry next message): {}", e)
if self._mcp_stack:
try:
await self._mcp_stack.aclose()
except Exception:
pass
self._mcp_stack = None
self._mcp_stacks.clear()
finally:
self._mcp_connecting = False
@ -321,6 +318,7 @@ class AgentLoop:
if not text:
return None
from nanobot.utils.helpers import strip_think
return strip_think(text) or None
@staticmethod
@ -330,6 +328,12 @@ class AgentLoop:
return format_tool_hints(tool_calls)
def _effective_session_key(self, msg: InboundMessage) -> str:
"""Return the session key used for task routing and mid-turn injections."""
if self._unified_session and not msg.session_key_override:
return UNIFIED_SESSION_KEY
return msg.session_key
async def _run_agent_loop(
self,
initial_messages: list[dict],
@ -341,13 +345,16 @@ class AgentLoop:
channel: str = "cli",
chat_id: str = "direct",
message_id: str | None = None,
) -> tuple[str | None, list[str], list[dict]]:
pending_queue: asyncio.Queue | None = None,
) -> tuple[str | None, list[str], list[dict], str, bool]:
"""Run the agent iteration loop.
*on_stream*: called with each content delta during streaming.
*on_stream_end(resuming)*: called when a streaming session finishes.
``resuming=True`` means tool calls follow (spinner should restart);
``resuming=False`` means this is the final response.
Returns (final_content, tools_used, messages, stop_reason, had_injections).
"""
loop_hook = _LoopHook(
self,
@ -359,9 +366,7 @@ class AgentLoop:
message_id=message_id,
)
hook: AgentHook = (
_LoopHookChain(loop_hook, self._extra_hooks)
if self._extra_hooks
else loop_hook
CompositeHook([loop_hook] + self._extra_hooks) if self._extra_hooks else loop_hook
)
async def _checkpoint(payload: dict[str, Any]) -> None:
@ -369,6 +374,32 @@ class AgentLoop:
return
self._set_runtime_checkpoint(session, payload)
async def _drain_pending(*, limit: int = _MAX_INJECTIONS_PER_TURN) -> list[dict[str, Any]]:
"""Non-blocking drain of follow-up messages from the pending queue."""
if pending_queue is None:
return []
items: list[dict[str, Any]] = []
while len(items) < limit:
try:
pending_msg = pending_queue.get_nowait()
except asyncio.QueueEmpty:
break
user_content = self.context._build_user_content(
pending_msg.content,
pending_msg.media if pending_msg.media else None,
)
runtime_ctx = self.context._build_runtime_context(
pending_msg.channel,
pending_msg.chat_id,
self.context.timezone,
)
if isinstance(user_content, str):
merged: str | list[dict[str, Any]] = f"{runtime_ctx}\n\n{user_content}"
else:
merged = [{"type": "text", "text": runtime_ctx}] + user_content
items.append({"role": "user", "content": merged})
return items
result = await self.runner.run(AgentRunSpec(
initial_messages=initial_messages,
tools=self.tools,
@ -385,13 +416,14 @@ class AgentLoop:
provider_retry_mode=self.provider_retry_mode,
progress_callback=on_progress,
checkpoint_callback=_checkpoint,
injection_callback=_drain_pending,
))
self._last_usage = result.usage
if result.stop_reason == "max_iterations":
logger.warning("Max iterations ({}) reached", self.max_iterations)
elif result.stop_reason == "error":
logger.error("LLM returned error: {}", (result.final_content or "")[:200])
return result.final_content, result.tools_used, result.messages
return result.final_content, result.tools_used, result.messages, result.stop_reason, result.had_injections
async def run(self) -> None:
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
@ -403,6 +435,10 @@ class AgentLoop:
try:
msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
except asyncio.TimeoutError:
self.auto_compact.check_expired(
self._schedule_background,
active_session_keys=self._pending_queues.keys(),
)
continue
except asyncio.CancelledError:
# Preserve real task cancellation so shutdown can complete cleanly.
@ -421,79 +457,140 @@ class AgentLoop:
if result:
await self.bus.publish_outbound(result)
continue
effective_key = self._effective_session_key(msg)
# If this session already has an active pending queue (i.e. a task
# is processing this session), route the message there for mid-turn
# injection instead of creating a competing task.
if effective_key in self._pending_queues:
pending_msg = msg
if effective_key != msg.session_key:
pending_msg = dataclasses.replace(
msg,
session_key_override=effective_key,
)
try:
self._pending_queues[effective_key].put_nowait(pending_msg)
except asyncio.QueueFull:
logger.warning(
"Pending queue full for session {}, falling back to queued task",
effective_key,
)
else:
logger.info(
"Routed follow-up message to pending queue for session {}",
effective_key,
)
continue
# Compute the effective session key before dispatching
# This ensures /stop command can find tasks correctly when unified session is enabled
task = asyncio.create_task(self._dispatch(msg))
self._active_tasks.setdefault(msg.session_key, []).append(task)
task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
self._active_tasks.setdefault(effective_key, []).append(task)
task.add_done_callback(
lambda t, k=effective_key: self._active_tasks.get(k, [])
and self._active_tasks[k].remove(t)
if t in self._active_tasks.get(k, [])
else None
)
async def _dispatch(self, msg: InboundMessage) -> None:
"""Process a message: per-session serial, cross-session concurrent."""
lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock())
session_key = self._effective_session_key(msg)
if session_key != msg.session_key:
msg = dataclasses.replace(msg, session_key_override=session_key)
lock = self._session_locks.setdefault(session_key, asyncio.Lock())
gate = self._concurrency_gate or nullcontext()
async with lock, gate:
try:
on_stream = on_stream_end = None
if msg.metadata.get("_wants_stream"):
# Split one answer into distinct stream segments.
stream_base_id = f"{msg.session_key}:{time.time_ns()}"
stream_segment = 0
def _current_stream_id() -> str:
return f"{stream_base_id}:{stream_segment}"
# Register a pending queue so follow-up messages for this session are
# routed here (mid-turn injection) instead of spawning a new task.
pending = asyncio.Queue(maxsize=20)
self._pending_queues[session_key] = pending
async def on_stream(delta: str) -> None:
meta = dict(msg.metadata or {})
meta["_stream_delta"] = True
meta["_stream_id"] = _current_stream_id()
try:
async with lock, gate:
try:
on_stream = on_stream_end = None
if msg.metadata.get("_wants_stream"):
# Split one answer into distinct stream segments.
stream_base_id = f"{msg.session_key}:{time.time_ns()}"
stream_segment = 0
def _current_stream_id() -> str:
return f"{stream_base_id}:{stream_segment}"
async def on_stream(delta: str) -> None:
meta = dict(msg.metadata or {})
meta["_stream_delta"] = True
meta["_stream_id"] = _current_stream_id()
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content=delta,
metadata=meta,
))
async def on_stream_end(*, resuming: bool = False) -> None:
nonlocal stream_segment
meta = dict(msg.metadata or {})
meta["_stream_end"] = True
meta["_resuming"] = resuming
meta["_stream_id"] = _current_stream_id()
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="",
metadata=meta,
))
stream_segment += 1
response = await self._process_message(
msg, on_stream=on_stream, on_stream_end=on_stream_end,
pending_queue=pending,
)
if response is not None:
await self.bus.publish_outbound(response)
elif msg.channel == "cli":
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content=delta,
metadata=meta,
content="", metadata=msg.metadata or {},
))
async def on_stream_end(*, resuming: bool = False) -> None:
nonlocal stream_segment
meta = dict(msg.metadata or {})
meta["_stream_end"] = True
meta["_resuming"] = resuming
meta["_stream_id"] = _current_stream_id()
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="",
metadata=meta,
))
stream_segment += 1
response = await self._process_message(
msg, on_stream=on_stream, on_stream_end=on_stream_end,
)
if response is not None:
await self.bus.publish_outbound(response)
elif msg.channel == "cli":
except asyncio.CancelledError:
logger.info("Task cancelled for session {}", session_key)
raise
except Exception:
logger.exception("Error processing message for session {}", session_key)
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="", metadata=msg.metadata or {},
content="Sorry, I encountered an error.",
))
except asyncio.CancelledError:
logger.info("Task cancelled for session {}", msg.session_key)
raise
except Exception:
logger.exception("Error processing message for session {}", msg.session_key)
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="Sorry, I encountered an error.",
))
finally:
# Drain any messages still in the pending queue and re-publish
# them to the bus so they are processed as fresh inbound messages
# rather than silently lost.
queue = self._pending_queues.pop(session_key, None)
if queue is not None:
leftover = 0
while True:
try:
item = queue.get_nowait()
except asyncio.QueueEmpty:
break
await self.bus.publish_inbound(item)
leftover += 1
if leftover:
logger.info(
"Re-published {} leftover message(s) to bus for session {}",
leftover, session_key,
)
async def close_mcp(self) -> None:
"""Drain pending background archives, then close MCP connections."""
if self._background_tasks:
await asyncio.gather(*self._background_tasks, return_exceptions=True)
self._background_tasks.clear()
if self._mcp_stack:
for name, stack in self._mcp_stacks.items():
try:
await self._mcp_stack.aclose()
await stack.aclose()
except (RuntimeError, BaseExceptionGroup):
pass # MCP SDK cancel scope cleanup is noisy but harmless
self._mcp_stack = None
logger.debug("MCP server '{}' cleanup error (can be ignored)", name)
self._mcp_stacks.clear()
def _schedule_background(self, coro) -> None:
"""Schedule a coroutine as a tracked background task (drained on shutdown)."""
@ -513,27 +610,36 @@ class AgentLoop:
on_progress: Callable[[str], Awaitable[None]] | None = None,
on_stream: Callable[[str], Awaitable[None]] | None = None,
on_stream_end: Callable[..., Awaitable[None]] | None = None,
pending_queue: asyncio.Queue | None = None,
) -> OutboundMessage | None:
"""Process a single inbound message and return the response."""
# System messages: parse origin from chat_id ("channel:chat_id")
if msg.channel == "system":
channel, chat_id = (msg.chat_id.split(":", 1) if ":" in msg.chat_id
else ("cli", msg.chat_id))
channel, chat_id = (
msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id)
)
logger.info("Processing system message from {}", msg.sender_id)
key = f"{channel}:{chat_id}"
session = self.sessions.get_or_create(key)
if self._restore_runtime_checkpoint(session):
self.sessions.save(session)
if self._restore_pending_user_turn(session):
self.sessions.save(session)
session, pending = self.auto_compact.prepare_session(session, key)
await self.consolidator.maybe_consolidate_by_tokens(session)
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
history = session.get_history(max_messages=0)
current_role = "assistant" if msg.sender_id == "subagent" else "user"
messages = self.context.build_messages(
history=history,
current_message=msg.content, channel=channel, chat_id=chat_id,
session_summary=pending,
current_role=current_role,
)
final_content, _, all_msgs = await self._run_agent_loop(
final_content, _, all_msgs, _, _ = await self._run_agent_loop(
messages, session=session, channel=channel, chat_id=chat_id,
message_id=msg.metadata.get("message_id"),
)
@ -541,8 +647,11 @@ class AgentLoop:
self._clear_runtime_checkpoint(session)
self.sessions.save(session)
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
return OutboundMessage(channel=channel, chat_id=chat_id,
content=final_content or "Background task completed.")
return OutboundMessage(
channel=channel,
chat_id=chat_id,
content=final_content or "Background task completed.",
)
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
@ -551,6 +660,10 @@ class AgentLoop:
session = self.sessions.get_or_create(key)
if self._restore_runtime_checkpoint(session):
self.sessions.save(session)
if self._restore_pending_user_turn(session):
self.sessions.save(session)
session, pending = self.auto_compact.prepare_session(session, key)
# Slash commands
raw = msg.content.strip()
@ -566,50 +679,85 @@ class AgentLoop:
message_tool.start_turn()
history = session.get_history(max_messages=0)
initial_messages = self.context.build_messages(
history=history,
current_message=msg.content,
session_summary=pending,
media=msg.media if msg.media else None,
channel=msg.channel, chat_id=msg.chat_id,
channel=msg.channel,
chat_id=msg.chat_id,
)
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
meta = dict(msg.metadata or {})
meta["_progress"] = True
meta["_tool_hint"] = tool_hint
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
))
await self.bus.publish_outbound(
OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content=content,
metadata=meta,
)
)
final_content, _, all_msgs = await self._run_agent_loop(
# Persist the triggering user message immediately, before running the
# agent loop. If the process is killed mid-turn (OOM, SIGKILL, self-
# restart, etc.), the existing runtime_checkpoint preserves the
# in-flight assistant/tool state but NOT the user message itself, so
# the user's prompt is silently lost on recovery. Saving it up front
# makes recovery possible from the session log alone.
user_persisted_early = False
if isinstance(msg.content, str) and msg.content.strip():
session.add_message("user", msg.content)
self._mark_pending_user_turn(session)
self.sessions.save(session)
user_persisted_early = True
final_content, _, all_msgs, stop_reason, had_injections = await self._run_agent_loop(
initial_messages,
on_progress=on_progress or _bus_progress,
on_stream=on_stream,
on_stream_end=on_stream_end,
session=session,
channel=msg.channel, chat_id=msg.chat_id,
channel=msg.channel,
chat_id=msg.chat_id,
message_id=msg.metadata.get("message_id"),
pending_queue=pending_queue,
)
if final_content is None or not final_content.strip():
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
self._save_turn(session, all_msgs, 1 + len(history))
# Skip the already-persisted user message when saving the turn
save_skip = 1 + len(history) + (1 if user_persisted_early else 0)
self._save_turn(session, all_msgs, save_skip)
self._clear_pending_user_turn(session)
self._clear_runtime_checkpoint(session)
self.sessions.save(session)
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
# When follow-up messages were injected mid-turn, a later natural
# language reply may address those follow-ups and should not be
# suppressed just because MessageTool was used earlier in the turn.
# However, if the turn falls back to the empty-final-response
# placeholder, suppress it when the real user-visible output already
# came from MessageTool.
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
return None
if not had_injections or stop_reason == "empty_final_response":
return None
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
meta = dict(msg.metadata or {})
if on_stream is not None:
if on_stream is not None and stop_reason != "error":
meta["_streamed"] = True
return OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
channel=msg.channel,
chat_id=msg.chat_id,
content=final_content,
metadata=meta,
)
@ -617,7 +765,7 @@ class AgentLoop:
self,
content: list[dict[str, Any]],
*,
truncate_text: bool = False,
should_truncate_text: bool = False,
drop_runtime: bool = False,
) -> list[dict[str, Any]]:
"""Strip volatile multimodal payloads before writing session history."""
@ -635,18 +783,17 @@ class AgentLoop:
):
continue
if (
block.get("type") == "image_url"
and block.get("image_url", {}).get("url", "").startswith("data:image/")
):
if block.get("type") == "image_url" and block.get("image_url", {}).get(
"url", ""
).startswith("data:image/"):
path = (block.get("_meta") or {}).get("path", "")
filtered.append({"type": "text", "text": image_placeholder_text(path)})
continue
if block.get("type") == "text" and isinstance(block.get("text"), str):
text = block["text"]
if truncate_text and len(text) > self.max_tool_result_chars:
text = truncate_text(text, self.max_tool_result_chars)
if should_truncate_text and len(text) > self.max_tool_result_chars:
text = truncate_text_fn(text, self.max_tool_result_chars)
filtered.append({**block, "text": text})
continue
@ -657,6 +804,7 @@ class AgentLoop:
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
"""Save new-turn messages into session, truncating large tool results."""
from datetime import datetime
for m in messages[skip:]:
entry = dict(m)
role, content = entry.get("role"), entry.get("content")
@ -664,20 +812,31 @@ class AgentLoop:
continue # skip empty assistant messages — they poison session context
if role == "tool":
if isinstance(content, str) and len(content) > self.max_tool_result_chars:
entry["content"] = truncate_text(content, self.max_tool_result_chars)
entry["content"] = truncate_text_fn(content, self.max_tool_result_chars)
elif isinstance(content, list):
filtered = self._sanitize_persisted_blocks(content, truncate_text=True)
filtered = self._sanitize_persisted_blocks(content, should_truncate_text=True)
if not filtered:
continue
entry["content"] = filtered
elif role == "user":
if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
# Strip the runtime-context prefix, keep only the user text.
parts = content.split("\n\n", 1)
if len(parts) > 1 and parts[1].strip():
entry["content"] = parts[1]
# Strip the entire runtime-context block (including any session summary).
# The block is bounded by _RUNTIME_CONTEXT_TAG and _RUNTIME_CONTEXT_END.
end_marker = ContextBuilder._RUNTIME_CONTEXT_END
end_pos = content.find(end_marker)
if end_pos >= 0:
after = content[end_pos + len(end_marker):].lstrip("\n")
if after:
entry["content"] = after
else:
continue
else:
continue
# Fallback: no end marker found, strip the tag prefix
after_tag = content[len(ContextBuilder._RUNTIME_CONTEXT_TAG):].lstrip("\n")
if after_tag.strip():
entry["content"] = after_tag
else:
continue
if isinstance(content, list):
filtered = self._sanitize_persisted_blocks(content, drop_runtime=True)
if not filtered:
@ -692,6 +851,12 @@ class AgentLoop:
session.metadata[self._RUNTIME_CHECKPOINT_KEY] = payload
self.sessions.save(session)
def _mark_pending_user_turn(self, session: Session) -> None:
session.metadata[self._PENDING_USER_TURN_KEY] = True
def _clear_pending_user_turn(self, session: Session) -> None:
session.metadata.pop(self._PENDING_USER_TURN_KEY, None)
def _clear_runtime_checkpoint(self, session: Session) -> None:
if self._RUNTIME_CHECKPOINT_KEY in session.metadata:
session.metadata.pop(self._RUNTIME_CHECKPOINT_KEY, None)
@ -735,13 +900,15 @@ class AgentLoop:
continue
tool_id = tool_call.get("id")
name = ((tool_call.get("function") or {}).get("name")) or "tool"
restored_messages.append({
"role": "tool",
"tool_call_id": tool_id,
"name": name,
"content": "Error: Task interrupted before this tool finished.",
"timestamp": datetime.now().isoformat(),
})
restored_messages.append(
{
"role": "tool",
"tool_call_id": tool_id,
"name": name,
"content": "Error: Task interrupted before this tool finished.",
"timestamp": datetime.now().isoformat(),
}
)
overlap = 0
max_overlap = min(len(session.messages), len(restored_messages))
@ -756,9 +923,30 @@ class AgentLoop:
break
session.messages.extend(restored_messages[overlap:])
self._clear_pending_user_turn(session)
self._clear_runtime_checkpoint(session)
return True
def _restore_pending_user_turn(self, session: Session) -> bool:
"""Close a turn that only persisted the user message before crashing."""
from datetime import datetime
if not session.metadata.get(self._PENDING_USER_TURN_KEY):
return False
if session.messages and session.messages[-1].get("role") == "user":
session.messages.append(
{
"role": "assistant",
"content": "Error: Task interrupted before a response was generated.",
"timestamp": datetime.now().isoformat(),
}
)
session.updated_at = datetime.now()
self._clear_pending_user_turn(session)
return True
async def process_direct(
self,
content: str,
@ -777,6 +965,9 @@ class AgentLoop:
content=content, media=media or [],
)
return await self._process_message(
msg, session_key=session_key, on_progress=on_progress,
on_stream=on_stream, on_stream_end=on_stream_end,
msg,
session_key=session_key,
on_progress=on_progress,
on_stream=on_stream,
on_stream_end=on_stream_end,
)

View File

@ -290,7 +290,7 @@ class MemoryStore:
if not lines:
return None
return json.loads(lines[-1])
except (FileNotFoundError, json.JSONDecodeError):
except (FileNotFoundError, json.JSONDecodeError, UnicodeDecodeError):
return None
def _write_entries(self, entries: list[dict[str, Any]]) -> None:
@ -347,6 +347,7 @@ class Consolidator:
"""Lightweight consolidation: summarizes evicted messages into history.jsonl."""
_MAX_CONSOLIDATION_ROUNDS = 5
_MAX_CHUNK_MESSAGES = 60 # hard cap per consolidation round
_SAFETY_BUFFER = 1024 # extra headroom for tokenizer estimation drift
@ -399,6 +400,22 @@ class Consolidator:
return last_boundary
def _cap_consolidation_boundary(
self,
session: Session,
end_idx: int,
) -> int | None:
"""Clamp the chunk size without breaking the user-turn boundary."""
start = session.last_consolidated
if end_idx - start <= self._MAX_CHUNK_MESSAGES:
return end_idx
capped_end = start + self._MAX_CHUNK_MESSAGES
for idx in range(capped_end, start, -1):
if session.messages[idx].get("role") == "user":
return idx
return None
def estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]:
"""Estimate current prompt size for the normal session history view."""
history = session.get_history(max_messages=0)
@ -416,13 +433,13 @@ class Consolidator:
self._get_tool_definitions(),
)
async def archive(self, messages: list[dict]) -> bool:
async def archive(self, messages: list[dict]) -> str | None:
"""Summarize messages via LLM and append to history.jsonl.
Returns True on success (or degraded success), False if nothing to do.
Returns the summary text on success, None if nothing to archive.
"""
if not messages:
return False
return None
try:
formatted = MemoryStore._format_messages(messages)
response = await self.provider.chat_with_retry(
@ -442,11 +459,11 @@ class Consolidator:
)
summary = response.content or "[no summary]"
self.store.append_history(summary)
return True
return summary
except Exception:
logger.warning("Consolidation LLM call failed, raw-dumping to history")
self.store.raw_archive(messages)
return True
return None
async def maybe_consolidate_by_tokens(self, session: Session) -> None:
"""Loop: archive old messages until prompt fits within safe budget.
@ -461,16 +478,22 @@ class Consolidator:
async with lock:
budget = self.context_window_tokens - self.max_completion_tokens - self._SAFETY_BUFFER
target = budget // 2
estimated, source = self.estimate_session_prompt_tokens(session)
try:
estimated, source = self.estimate_session_prompt_tokens(session)
except Exception:
logger.exception("Token estimation failed for {}", session.key)
estimated, source = 0, "error"
if estimated <= 0:
return
if estimated < budget:
unconsolidated_count = len(session.messages) - session.last_consolidated
logger.debug(
"Token consolidation idle {}: {}/{} via {}",
"Token consolidation idle {}: {}/{} via {}, msgs={}",
session.key,
estimated,
self.context_window_tokens,
source,
unconsolidated_count,
)
return
@ -488,6 +511,15 @@ class Consolidator:
return
end_idx = boundary[0]
end_idx = self._cap_consolidation_boundary(session, end_idx)
if end_idx is None:
logger.debug(
"Token consolidation: no capped boundary for {} (round {})",
session.key,
round_num,
)
return
chunk = session.messages[session.last_consolidated:end_idx]
if not chunk:
return
@ -506,7 +538,11 @@ class Consolidator:
session.last_consolidated = end_idx
self.sessions.save(session)
estimated, source = self.estimate_session_prompt_tokens(session)
try:
estimated, source = self.estimate_session_prompt_tokens(session)
except Exception:
logger.exception("Token estimation failed for {}", session.key)
estimated, source = 0, "error"
if estimated <= 0:
return
@ -546,18 +582,60 @@ class Dream:
def _build_tools(self) -> ToolRegistry:
"""Build a minimal tool registry for the Dream agent."""
from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, WriteFileTool
tools = ToolRegistry()
workspace = self.store.workspace
tools.register(ReadFileTool(workspace=workspace, allowed_dir=workspace))
# Allow reading builtin skills for reference during skill creation
extra_read = [BUILTIN_SKILLS_DIR] if BUILTIN_SKILLS_DIR.exists() else None
tools.register(ReadFileTool(
workspace=workspace,
allowed_dir=workspace,
extra_allowed_dirs=extra_read,
))
tools.register(EditFileTool(workspace=workspace, allowed_dir=workspace))
# write_file resolves relative paths from workspace root, but can only
# write under skills/ so the prompt can safely use skills/<name>/SKILL.md.
skills_dir = workspace / "skills"
skills_dir.mkdir(parents=True, exist_ok=True)
tools.register(WriteFileTool(workspace=workspace, allowed_dir=skills_dir))
return tools
# -- skill listing --------------------------------------------------------
def _list_existing_skills(self) -> list[str]:
"""List existing skills as 'name — description' for dedup context."""
import re as _re
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
_DESC_RE = _re.compile(r"^description:\s*(.+)$", _re.MULTILINE | _re.IGNORECASE)
entries: dict[str, str] = {}
for base in (self.store.workspace / "skills", BUILTIN_SKILLS_DIR):
if not base.exists():
continue
for d in base.iterdir():
if not d.is_dir():
continue
skill_md = d / "SKILL.md"
if not skill_md.exists():
continue
# Prefer workspace skills over builtin (same name)
if d.name in entries and base == BUILTIN_SKILLS_DIR:
continue
content = skill_md.read_text(encoding="utf-8")[:500]
m = _DESC_RE.search(content)
desc = m.group(1).strip() if m else "(no description)"
entries[d.name] = desc
return [f"{name}{desc}" for name, desc in sorted(entries.items())]
# -- main entry ----------------------------------------------------------
async def run(self) -> bool:
"""Process unprocessed history entries. Returns True if work was done."""
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
last_cursor = self.store.get_last_dream_cursor()
entries = self.store.read_unprocessed_history(since_cursor=last_cursor)
if not entries:
@ -579,6 +657,7 @@ class Dream:
current_memory = self.store.read_memory() or "(empty)"
current_soul = self.store.read_soul() or "(empty)"
current_user = self.store.read_user() or "(empty)"
file_context = (
f"## Current Date\n{current_date}\n\n"
f"## Current MEMORY.md ({len(current_memory)} chars)\n{current_memory}\n\n"
@ -586,7 +665,7 @@ class Dream:
f"## Current USER.md ({len(current_user)} chars)\n{current_user}"
)
# Phase 1: Analyze
# Phase 1: Analyze (no skills list — dedup is Phase 2's job)
phase1_prompt = (
f"## Conversation History\n{history_text}\n\n{file_context}"
)
@ -611,13 +690,25 @@ class Dream:
return False
# Phase 2: Delegate to AgentRunner with read_file / edit_file
phase2_prompt = f"## Analysis Result\n{analysis}\n\n{file_context}"
existing_skills = self._list_existing_skills()
skills_section = ""
if existing_skills:
skills_section = (
"\n\n## Existing Skills\n"
+ "\n".join(f"- {s}" for s in existing_skills)
)
phase2_prompt = f"## Analysis Result\n{analysis}\n\n{file_context}{skills_section}"
tools = self._tools
skill_creator_path = BUILTIN_SKILLS_DIR / "skill-creator" / "SKILL.md"
messages: list[dict[str, Any]] = [
{
"role": "system",
"content": render_template("agent/dream_phase2.md", strip=True),
"content": render_template(
"agent/dream_phase2.md",
strip=True,
skill_creator_path=str(skill_creator_path),
),
},
{"role": "user", "content": phase2_prompt},
]

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
import inspect
from pathlib import Path
from typing import Any
@ -31,8 +32,11 @@ from nanobot.utils.runtime import (
)
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
_PERSISTED_MODEL_ERROR_PLACEHOLDER = "[Assistant reply unavailable due to model error.]"
_MAX_EMPTY_RETRIES = 2
_MAX_LENGTH_RECOVERIES = 3
_MAX_INJECTIONS_PER_TURN = 3
_MAX_INJECTION_CYCLES = 5
_SNIP_SAFETY_BUFFER = 1024
_MICROCOMPACT_KEEP_RECENT = 10
_MICROCOMPACT_MIN_CHARS = 500
@ -41,6 +45,9 @@ _COMPACTABLE_TOOLS = frozenset({
"web_search", "web_fetch", "list_dir",
})
_BACKFILL_CONTENT = "[Tool result unavailable — call was interrupted or lost]"
@dataclass(slots=True)
class AgentRunSpec:
"""Configuration for a single agent execution."""
@ -65,6 +72,7 @@ class AgentRunSpec:
provider_retry_mode: str = "standard"
progress_callback: Any | None = None
checkpoint_callback: Any | None = None
injection_callback: Any | None = None
@dataclass(slots=True)
@ -78,6 +86,7 @@ class AgentRunResult:
stop_reason: str = "completed"
error: str | None = None
tool_events: list[dict[str, str]] = field(default_factory=list)
had_injections: bool = False
class AgentRunner:
@ -86,6 +95,134 @@ class AgentRunner:
def __init__(self, provider: LLMProvider):
self.provider = provider
@staticmethod
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
if isinstance(left, str) and isinstance(right, str):
return f"{left}\n\n{right}" if left else right
def _to_blocks(value: Any) -> list[dict[str, Any]]:
if isinstance(value, list):
return [
item if isinstance(item, dict) else {"type": "text", "text": str(item)}
for item in value
]
if value is None:
return []
return [{"type": "text", "text": str(value)}]
return _to_blocks(left) + _to_blocks(right)
@classmethod
def _append_injected_messages(
cls,
messages: list[dict[str, Any]],
injections: list[dict[str, Any]],
) -> None:
"""Append injected user messages while preserving role alternation."""
for injection in injections:
if (
messages
and injection.get("role") == "user"
and messages[-1].get("role") == "user"
):
merged = dict(messages[-1])
merged["content"] = cls._merge_message_content(
merged.get("content"),
injection.get("content"),
)
messages[-1] = merged
continue
messages.append(injection)
async def _try_drain_injections(
self,
spec: AgentRunSpec,
messages: list[dict[str, Any]],
assistant_message: dict[str, Any] | None,
injection_cycles: int,
*,
phase: str = "after error",
iteration: int | None = None,
) -> tuple[bool, int]:
"""Drain pending injections. Returns (should_continue, updated_cycles).
If injections are found and we haven't exceeded _MAX_INJECTION_CYCLES,
append them to *messages* (and emit a checkpoint if *assistant_message*
and *iteration* are both provided) and return (True, cycles+1) so the
caller continues the iteration loop. Otherwise return (False, cycles).
"""
if injection_cycles >= _MAX_INJECTION_CYCLES:
return False, injection_cycles
injections = await self._drain_injections(spec)
if not injections:
return False, injection_cycles
injection_cycles += 1
if assistant_message is not None:
messages.append(assistant_message)
if iteration is not None:
await self._emit_checkpoint(
spec,
{
"phase": "final_response",
"iteration": iteration,
"model": spec.model,
"assistant_message": assistant_message,
"completed_tool_results": [],
"pending_tool_calls": [],
},
)
self._append_injected_messages(messages, injections)
logger.info(
"Injected {} follow-up message(s) {} ({}/{})",
len(injections), phase, injection_cycles, _MAX_INJECTION_CYCLES,
)
return True, injection_cycles
async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]:
"""Drain pending user messages via the injection callback.
Returns normalized user messages (capped by
``_MAX_INJECTIONS_PER_TURN``), or an empty list when there is
nothing to inject. Messages beyond the cap are logged so they
are not silently lost.
"""
if spec.injection_callback is None:
return []
try:
signature = inspect.signature(spec.injection_callback)
accepts_limit = (
"limit" in signature.parameters
or any(
parameter.kind is inspect.Parameter.VAR_KEYWORD
for parameter in signature.parameters.values()
)
)
if accepts_limit:
items = await spec.injection_callback(limit=_MAX_INJECTIONS_PER_TURN)
else:
items = await spec.injection_callback()
except Exception:
logger.exception("injection_callback failed")
return []
if not items:
return []
injected_messages: list[dict[str, Any]] = []
for item in items:
if isinstance(item, dict) and item.get("role") == "user" and "content" in item:
injected_messages.append(item)
continue
text = getattr(item, "content", str(item))
if text.strip():
injected_messages.append({"role": "user", "content": text})
if len(injected_messages) > _MAX_INJECTIONS_PER_TURN:
dropped = len(injected_messages) - _MAX_INJECTIONS_PER_TURN
logger.warning(
"Injection callback returned {} messages, capping to {} ({} dropped)",
len(injected_messages), _MAX_INJECTIONS_PER_TURN, dropped,
)
injected_messages = injected_messages[:_MAX_INJECTIONS_PER_TURN]
return injected_messages
async def run(self, spec: AgentRunSpec) -> AgentRunResult:
hook = spec.hook or AgentHook()
messages = list(spec.initial_messages)
@ -98,21 +235,35 @@ class AgentRunner:
external_lookup_counts: dict[str, int] = {}
empty_content_retries = 0
length_recovery_count = 0
had_injections = False
injection_cycles = 0
for iteration in range(spec.max_iterations):
try:
messages = self._backfill_missing_tool_results(messages)
messages = self._microcompact(messages)
messages = self._apply_tool_result_budget(spec, messages)
messages_for_model = self._snip_history(spec, messages)
# Keep the persisted conversation untouched. Context governance
# may repair or compact historical messages for the model, but
# those synthetic edits must not shift the append boundary used
# later when the caller saves only the new turn.
messages_for_model = self._drop_orphan_tool_results(messages)
messages_for_model = self._backfill_missing_tool_results(messages_for_model)
messages_for_model = self._microcompact(messages_for_model)
messages_for_model = self._apply_tool_result_budget(spec, messages_for_model)
messages_for_model = self._snip_history(spec, messages_for_model)
# Snipping may have created new orphans; clean them up.
messages_for_model = self._drop_orphan_tool_results(messages_for_model)
messages_for_model = self._backfill_missing_tool_results(messages_for_model)
except Exception as exc:
logger.warning(
"Context governance failed on turn {} for {}: {}; using raw messages",
"Context governance failed on turn {} for {}: {}; applying minimal repair",
iteration,
spec.session_key or "default",
exc,
)
messages_for_model = messages
try:
messages_for_model = self._drop_orphan_tool_results(messages)
messages_for_model = self._backfill_missing_tool_results(messages_for_model)
except Exception:
messages_for_model = messages
context = AgentHookContext(iteration=iteration, messages=messages)
await hook.before_iteration(context)
response = await self._request_model(spec, messages_for_model, hook, context)
@ -156,16 +307,6 @@ class AgentRunner:
tool_events.extend(new_events)
context.tool_results = list(results)
context.tool_events = list(new_events)
if fatal_error is not None:
error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
final_content = error
stop_reason = "tool_error"
self._append_final_message(messages, final_content)
context.final_content = final_content
context.error = error
context.stop_reason = stop_reason
await hook.after_iteration(context)
break
completed_tool_results: list[dict[str, Any]] = []
for tool_call, result in zip(response.tool_calls, results):
tool_message = {
@ -181,6 +322,23 @@ class AgentRunner:
}
messages.append(tool_message)
completed_tool_results.append(tool_message)
if fatal_error is not None:
error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
final_content = error
stop_reason = "tool_error"
self._append_final_message(messages, final_content)
context.final_content = final_content
context.error = error
context.stop_reason = stop_reason
await hook.after_iteration(context)
should_continue, injection_cycles = await self._try_drain_injections(
spec, messages, None, injection_cycles,
phase="after tool error",
)
if should_continue:
had_injections = True
continue
break
await self._emit_checkpoint(
spec,
{
@ -194,6 +352,13 @@ class AgentRunner:
)
empty_content_retries = 0
length_recovery_count = 0
# Checkpoint 1: drain injections after tools, before next LLM call
_drained, injection_cycles = await self._try_drain_injections(
spec, messages, None, injection_cycles,
phase="after tool execution",
)
if _drained:
had_injections = True
await hook.after_iteration(context)
continue
@ -250,18 +415,48 @@ class AgentRunner:
await hook.after_iteration(context)
continue
assistant_message: dict[str, Any] | None = None
if response.finish_reason != "error" and not is_blank_text(clean):
assistant_message = build_assistant_message(
clean,
reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks,
)
# Check for mid-turn injections BEFORE signaling stream end.
# If injections are found we keep the stream alive (resuming=True)
# so streaming channels don't prematurely finalize the card.
should_continue, injection_cycles = await self._try_drain_injections(
spec, messages, assistant_message, injection_cycles,
phase="after final response",
iteration=iteration,
)
if should_continue:
had_injections = True
if hook.wants_streaming():
await hook.on_stream_end(context, resuming=False)
await hook.on_stream_end(context, resuming=should_continue)
if should_continue:
await hook.after_iteration(context)
continue
if response.finish_reason == "error":
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
stop_reason = "error"
error = final_content
self._append_final_message(messages, final_content)
self._append_model_error_placeholder(messages)
context.final_content = final_content
context.error = error
context.stop_reason = stop_reason
await hook.after_iteration(context)
should_continue, injection_cycles = await self._try_drain_injections(
spec, messages, None, injection_cycles,
phase="after LLM error",
)
if should_continue:
had_injections = True
continue
break
if is_blank_text(clean):
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
@ -272,9 +467,16 @@ class AgentRunner:
context.error = error
context.stop_reason = stop_reason
await hook.after_iteration(context)
should_continue, injection_cycles = await self._try_drain_injections(
spec, messages, None, injection_cycles,
phase="after empty response",
)
if should_continue:
had_injections = True
continue
break
messages.append(build_assistant_message(
messages.append(assistant_message or build_assistant_message(
clean,
reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks,
@ -308,6 +510,17 @@ class AgentRunner:
max_iterations=spec.max_iterations,
)
self._append_final_message(messages, final_content)
# Drain any remaining injections so they are appended to the
# conversation history instead of being re-published as
# independent inbound messages by _dispatch's finally block.
# We ignore should_continue here because the for-loop has already
# exhausted all iterations.
drained_after_max_iterations, injection_cycles = await self._try_drain_injections(
spec, messages, None, injection_cycles,
phase="after max_iterations",
)
if drained_after_max_iterations:
had_injections = True
return AgentRunResult(
final_content=final_content,
@ -317,6 +530,7 @@ class AgentRunner:
stop_reason=stop_reason,
error=error,
tool_events=tool_events,
had_injections=had_injections,
)
def _build_request_kwargs(
@ -521,6 +735,12 @@ class AgentRunner:
return
messages.append(build_assistant_message(content))
@staticmethod
def _append_model_error_placeholder(messages: list[dict[str, Any]]) -> None:
if messages and messages[-1].get("role") == "assistant" and not messages[-1].get("tool_calls"):
return
messages.append(build_assistant_message(_PERSISTED_MODEL_ERROR_PLACEHOLDER))
def _normalize_tool_result(
self,
spec: AgentRunSpec,
@ -549,6 +769,32 @@ class AgentRunner:
return truncate_text(content, spec.max_tool_result_chars)
return content
@staticmethod
def _drop_orphan_tool_results(
messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Drop tool results that have no matching assistant tool_call earlier in the history."""
declared: set[str] = set()
updated: list[dict[str, Any]] | None = None
for idx, msg in enumerate(messages):
role = msg.get("role")
if role == "assistant":
for tc in msg.get("tool_calls") or []:
if isinstance(tc, dict) and tc.get("id"):
declared.add(str(tc["id"]))
if role == "tool":
tid = msg.get("tool_call_id")
if tid and str(tid) not in declared:
if updated is None:
updated = [dict(m) for m in messages[:idx]]
continue
if updated is not None:
updated.append(dict(msg))
if updated is None:
return messages
return updated
@staticmethod
def _backfill_missing_tool_results(
messages: list[dict[str, Any]],

View File

@ -28,10 +28,11 @@ class SkillsLoader:
specific tools or perform certain tasks.
"""
def __init__(self, workspace: Path, builtin_skills_dir: Path | None = None):
def __init__(self, workspace: Path, builtin_skills_dir: Path | None = None, disabled_skills: set[str] | None = None):
self.workspace = workspace
self.workspace_skills = workspace / "skills"
self.builtin_skills = builtin_skills_dir or BUILTIN_SKILLS_DIR
self.disabled_skills = disabled_skills or set()
def _skill_entries_from_dir(self, base: Path, source: str, *, skip_names: set[str] | None = None) -> list[dict[str, str]]:
if not base.exists():
@ -66,6 +67,9 @@ class SkillsLoader:
self._skill_entries_from_dir(self.builtin_skills, "builtin", skip_names=workspace_names)
)
if self.disabled_skills:
skills = [s for s in skills if s["name"] not in self.disabled_skills]
if filter_unavailable:
return [skill for skill in skills if self._check_requirements(self._get_skill_meta(skill["name"]))]
return skills

View File

@ -27,6 +27,7 @@ class _SubagentHook(AgentHook):
"""Logging-only hook for subagent execution."""
def __init__(self, task_id: str) -> None:
super().__init__()
self._task_id = task_id
async def before_execute_tools(self, context: AgentHookContext) -> None:
@ -51,6 +52,7 @@ class SubagentManager:
web_config: "WebToolsConfig | None" = None,
exec_config: "ExecToolConfig | None" = None,
restrict_to_workspace: bool = False,
disabled_skills: list[str] | None = None,
):
from nanobot.config.schema import ExecToolConfig
@ -62,6 +64,7 @@ class SubagentManager:
self.max_tool_result_chars = max_tool_result_chars
self.exec_config = exec_config or ExecToolConfig()
self.restrict_to_workspace = restrict_to_workspace
self.disabled_skills = set(disabled_skills or [])
self.runner = AgentRunner(provider)
self._running_tasks: dict[str, asyncio.Task[None]] = {}
self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...}
@ -235,7 +238,10 @@ class SubagentManager:
from nanobot.agent.skills import SkillsLoader
time_ctx = ContextBuilder._build_runtime_context(None, None)
skills_summary = SkillsLoader(self.workspace).build_skills_summary()
skills_summary = SkillsLoader(
self.workspace,
disabled_skills=self.disabled_skills,
).build_skills_summary()
return render_template(
"agent/subagent_system.md",
time_ctx=time_ctx,

View File

@ -0,0 +1,105 @@
"""Track file-read state for read-before-edit warnings and read deduplication."""
from __future__ import annotations
import hashlib
import os
from dataclasses import dataclass
from pathlib import Path
@dataclass(slots=True)
class ReadState:
mtime: float
offset: int
limit: int | None
content_hash: str | None
can_dedup: bool
_state: dict[str, ReadState] = {}
def _hash_file(p: str) -> str | None:
try:
return hashlib.sha256(Path(p).read_bytes()).hexdigest()
except OSError:
return None
def record_read(path: str | Path, offset: int = 1, limit: int | None = None) -> None:
"""Record that a file was read (called after successful read)."""
p = str(Path(path).resolve())
try:
mtime = os.path.getmtime(p)
except OSError:
return
_state[p] = ReadState(
mtime=mtime,
offset=offset,
limit=limit,
content_hash=_hash_file(p),
can_dedup=True,
)
def record_write(path: str | Path) -> None:
"""Record that a file was written (updates mtime in state)."""
p = str(Path(path).resolve())
try:
mtime = os.path.getmtime(p)
except OSError:
_state.pop(p, None)
return
_state[p] = ReadState(
mtime=mtime,
offset=1,
limit=None,
content_hash=_hash_file(p),
can_dedup=False,
)
def check_read(path: str | Path) -> str | None:
"""Check if a file has been read and is fresh.
Returns None if OK, or a warning string.
When mtime changed but file content is identical (e.g. touch, editor save),
the check passes to avoid false-positive staleness warnings.
"""
p = str(Path(path).resolve())
entry = _state.get(p)
if entry is None:
return "Warning: file has not been read yet. Read it first to verify content before editing."
try:
current_mtime = os.path.getmtime(p)
except OSError:
return None
if current_mtime != entry.mtime:
if entry.content_hash and _hash_file(p) == entry.content_hash:
entry.mtime = current_mtime
return None
return "Warning: file has been modified since last read. Re-read to verify content before editing."
return None
def is_unchanged(path: str | Path, offset: int = 1, limit: int | None = None) -> bool:
"""Return True if file was previously read with same params and mtime is unchanged."""
p = str(Path(path).resolve())
entry = _state.get(p)
if entry is None:
return False
if not entry.can_dedup:
return False
if entry.offset != offset or entry.limit != limit:
return False
try:
current_mtime = os.path.getmtime(p)
except OSError:
return False
return current_mtime == entry.mtime
def clear() -> None:
"""Clear all tracked state (useful for testing)."""
_state.clear()

View File

@ -2,11 +2,13 @@
import difflib
import mimetypes
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from nanobot.agent.tools.base import Tool, tool_parameters
from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema
from nanobot.agent.tools import file_state
from nanobot.utils.helpers import build_image_content_blocks, detect_image_mime
from nanobot.config.paths import get_media_dir
@ -60,6 +62,36 @@ class _FsTool(Tool):
# ---------------------------------------------------------------------------
_BLOCKED_DEVICE_PATHS = frozenset({
"/dev/zero", "/dev/random", "/dev/urandom", "/dev/full",
"/dev/stdin", "/dev/stdout", "/dev/stderr",
"/dev/tty", "/dev/console",
"/dev/fd/0", "/dev/fd/1", "/dev/fd/2",
})
def _is_blocked_device(path: str | Path) -> bool:
"""Check if path is a blocked device that could hang or produce infinite output."""
import re
raw = str(path)
if raw in _BLOCKED_DEVICE_PATHS:
return True
if re.match(r"/proc/\d+/fd/[012]$", raw) or re.match(r"/proc/self/fd/[012]$", raw):
return True
return False
def _parse_page_range(pages: str, total: int) -> tuple[int, int]:
"""Parse a page range like '2-5' into 0-based (start, end) inclusive."""
parts = pages.strip().split("-")
if len(parts) == 1:
p = int(parts[0])
return max(0, p - 1), min(p - 1, total - 1)
start = int(parts[0])
end = int(parts[1])
return max(0, start - 1), min(end - 1, total - 1)
@tool_parameters(
tool_parameters_schema(
path=StringSchema("The file path to read"),
@ -73,6 +105,7 @@ class _FsTool(Tool):
description="Maximum number of lines to read (default 2000)",
minimum=1,
),
pages=StringSchema("Page range for PDF files, e.g. '1-5' (default: all, max 20 pages)"),
required=["path"],
)
)
@ -81,6 +114,7 @@ class ReadFileTool(_FsTool):
_MAX_CHARS = 128_000
_DEFAULT_LIMIT = 2000
_MAX_PDF_PAGES = 20
@property
def name(self) -> str:
@ -89,9 +123,10 @@ class ReadFileTool(_FsTool):
@property
def description(self) -> str:
return (
"Read a text file. Output format: LINE_NUM|CONTENT. "
"Read a file (text or image). Text output format: LINE_NUM|CONTENT. "
"Images return visual content for analysis. "
"Use offset and limit for large files. "
"Cannot read binary files or images. "
"Cannot read non-image binary files. "
"Reads exceeding ~128K chars are truncated."
)
@ -99,16 +134,27 @@ class ReadFileTool(_FsTool):
def read_only(self) -> bool:
return True
async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, **kwargs: Any) -> Any:
async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, pages: str | None = None, **kwargs: Any) -> Any:
try:
if not path:
return "Error reading file: Unknown path"
# Device path blacklist
if _is_blocked_device(path):
return f"Error: Reading {path} is blocked (device path that could hang or produce infinite output)."
fp = self._resolve(path)
if _is_blocked_device(fp):
return f"Error: Reading {fp} is blocked (device path that could hang or produce infinite output)."
if not fp.exists():
return f"Error: File not found: {path}"
if not fp.is_file():
return f"Error: Not a file: {path}"
# PDF support
if fp.suffix.lower() == ".pdf":
return self._read_pdf(fp, pages)
raw = fp.read_bytes()
if not raw:
return f"(Empty file: {path})"
@ -117,6 +163,10 @@ class ReadFileTool(_FsTool):
if mime and mime.startswith("image/"):
return build_image_content_blocks(raw, mime, str(fp), f"(Image file: {path})")
# Read dedup: same path + offset + limit + unchanged mtime → stub
if file_state.is_unchanged(fp, offset=offset, limit=limit):
return f"[File unchanged since last read: {path}]"
try:
text_content = raw.decode("utf-8")
except UnicodeDecodeError:
@ -149,12 +199,59 @@ class ReadFileTool(_FsTool):
result += f"\n\n(Showing lines {offset}-{end} of {total}. Use offset={end + 1} to continue.)"
else:
result += f"\n\n(End of file — {total} lines total)"
file_state.record_read(fp, offset=offset, limit=limit)
return result
except PermissionError as e:
return f"Error: {e}"
except Exception as e:
return f"Error reading file: {e}"
def _read_pdf(self, fp: Path, pages: str | None) -> str:
try:
import fitz # pymupdf
except ImportError:
return "Error: PDF reading requires pymupdf. Install with: pip install pymupdf"
try:
doc = fitz.open(str(fp))
except Exception as e:
return f"Error reading PDF: {e}"
total_pages = len(doc)
if pages:
try:
start, end = _parse_page_range(pages, total_pages)
except (ValueError, IndexError):
doc.close()
return f"Error: Invalid page range '{pages}'. Use format like '1-5'."
if start > end or start >= total_pages:
doc.close()
return f"Error: Page range '{pages}' is out of bounds (document has {total_pages} pages)."
else:
start = 0
end = min(total_pages - 1, self._MAX_PDF_PAGES - 1)
if end - start + 1 > self._MAX_PDF_PAGES:
end = start + self._MAX_PDF_PAGES - 1
parts: list[str] = []
for i in range(start, end + 1):
page = doc[i]
text = page.get_text().strip()
if text:
parts.append(f"--- Page {i + 1} ---\n{text}")
doc.close()
if not parts:
return f"(PDF has no extractable text: {fp})"
result = "\n\n".join(parts)
if end < total_pages - 1:
result += f"\n\n(Showing pages {start + 1}-{end + 1} of {total_pages}. Use pages='{end + 2}-{min(end + 1 + self._MAX_PDF_PAGES, total_pages)}' to continue.)"
if len(result) > self._MAX_CHARS:
result = result[:self._MAX_CHARS] + "\n\n(PDF text truncated at ~128K chars)"
return result
# ---------------------------------------------------------------------------
# write_file
@ -192,6 +289,7 @@ class WriteFileTool(_FsTool):
fp = self._resolve(path)
fp.parent.mkdir(parents=True, exist_ok=True)
fp.write_text(content, encoding="utf-8")
file_state.record_write(fp)
return f"Successfully wrote {len(content)} characters to {fp}"
except PermissionError as e:
return f"Error: {e}"
@ -203,30 +301,269 @@ class WriteFileTool(_FsTool):
# edit_file
# ---------------------------------------------------------------------------
_QUOTE_TABLE = str.maketrans({
"\u2018": "'", "\u2019": "'", # curly single → straight
"\u201c": '"', "\u201d": '"', # curly double → straight
"'": "'", '"': '"', # identity (kept for completeness)
})
def _normalize_quotes(s: str) -> str:
return s.translate(_QUOTE_TABLE)
def _curly_double_quotes(text: str) -> str:
parts: list[str] = []
opening = True
for ch in text:
if ch == '"':
parts.append("\u201c" if opening else "\u201d")
opening = not opening
else:
parts.append(ch)
return "".join(parts)
def _curly_single_quotes(text: str) -> str:
parts: list[str] = []
opening = True
for i, ch in enumerate(text):
if ch != "'":
parts.append(ch)
continue
prev_ch = text[i - 1] if i > 0 else ""
next_ch = text[i + 1] if i + 1 < len(text) else ""
if prev_ch.isalnum() and next_ch.isalnum():
parts.append("\u2019")
continue
parts.append("\u2018" if opening else "\u2019")
opening = not opening
return "".join(parts)
def _preserve_quote_style(old_text: str, actual_text: str, new_text: str) -> str:
"""Preserve curly quote style when a quote-normalized fallback matched."""
if _normalize_quotes(old_text.strip()) != _normalize_quotes(actual_text.strip()) or old_text == actual_text:
return new_text
styled = new_text
if any(ch in actual_text for ch in ("\u201c", "\u201d")) and '"' in styled:
styled = _curly_double_quotes(styled)
if any(ch in actual_text for ch in ("\u2018", "\u2019")) and "'" in styled:
styled = _curly_single_quotes(styled)
return styled
def _leading_ws(line: str) -> str:
return line[: len(line) - len(line.lstrip(" \t"))]
def _reindent_like_match(old_text: str, actual_text: str, new_text: str) -> str:
"""Preserve the outer indentation from the actual matched block."""
old_lines = old_text.split("\n")
actual_lines = actual_text.split("\n")
if len(old_lines) != len(actual_lines):
return new_text
comparable = [
(old_line, actual_line)
for old_line, actual_line in zip(old_lines, actual_lines)
if old_line.strip() and actual_line.strip()
]
if not comparable or any(
_normalize_quotes(old_line.strip()) != _normalize_quotes(actual_line.strip())
for old_line, actual_line in comparable
):
return new_text
old_ws = _leading_ws(comparable[0][0])
actual_ws = _leading_ws(comparable[0][1])
if actual_ws == old_ws:
return new_text
if old_ws:
if not actual_ws.startswith(old_ws):
return new_text
delta = actual_ws[len(old_ws):]
else:
delta = actual_ws
if not delta:
return new_text
return "\n".join((delta + line) if line else line for line in new_text.split("\n"))
@dataclass(slots=True)
class _MatchSpan:
start: int
end: int
text: str
line: int
def _find_exact_matches(content: str, old_text: str) -> list[_MatchSpan]:
matches: list[_MatchSpan] = []
start = 0
while True:
idx = content.find(old_text, start)
if idx == -1:
break
matches.append(
_MatchSpan(
start=idx,
end=idx + len(old_text),
text=content[idx : idx + len(old_text)],
line=content.count("\n", 0, idx) + 1,
)
)
start = idx + max(1, len(old_text))
return matches
def _find_trim_matches(content: str, old_text: str, *, normalize_quotes: bool = False) -> list[_MatchSpan]:
old_lines = old_text.splitlines()
if not old_lines:
return []
content_lines = content.splitlines()
content_lines_keepends = content.splitlines(keepends=True)
if len(content_lines) < len(old_lines):
return []
offsets: list[int] = []
pos = 0
for line in content_lines_keepends:
offsets.append(pos)
pos += len(line)
offsets.append(pos)
if normalize_quotes:
stripped_old = [_normalize_quotes(line.strip()) for line in old_lines]
else:
stripped_old = [line.strip() for line in old_lines]
matches: list[_MatchSpan] = []
window_size = len(stripped_old)
for i in range(len(content_lines) - window_size + 1):
window = content_lines[i : i + window_size]
if normalize_quotes:
comparable = [_normalize_quotes(line.strip()) for line in window]
else:
comparable = [line.strip() for line in window]
if comparable != stripped_old:
continue
start = offsets[i]
end = offsets[i + window_size]
if content_lines_keepends[i + window_size - 1].endswith("\n"):
end -= 1
matches.append(
_MatchSpan(
start=start,
end=end,
text=content[start:end],
line=i + 1,
)
)
return matches
def _find_quote_matches(content: str, old_text: str) -> list[_MatchSpan]:
norm_content = _normalize_quotes(content)
norm_old = _normalize_quotes(old_text)
matches: list[_MatchSpan] = []
start = 0
while True:
idx = norm_content.find(norm_old, start)
if idx == -1:
break
matches.append(
_MatchSpan(
start=idx,
end=idx + len(old_text),
text=content[idx : idx + len(old_text)],
line=content.count("\n", 0, idx) + 1,
)
)
start = idx + max(1, len(norm_old))
return matches
def _find_matches(content: str, old_text: str) -> list[_MatchSpan]:
"""Locate all matches using progressively looser strategies."""
for matcher in (
lambda: _find_exact_matches(content, old_text),
lambda: _find_trim_matches(content, old_text),
lambda: _find_trim_matches(content, old_text, normalize_quotes=True),
lambda: _find_quote_matches(content, old_text),
):
matches = matcher()
if matches:
return matches
return []
def _find_match_line_numbers(content: str, old_text: str) -> list[int]:
"""Return 1-based starting line numbers for the current matching strategies."""
return [match.line for match in _find_matches(content, old_text)]
def _collapse_internal_whitespace(text: str) -> str:
return "\n".join(" ".join(line.split()) for line in text.splitlines())
def _diagnose_near_match(old_text: str, actual_text: str) -> list[str]:
"""Return actionable hints describing why text was close but not exact."""
hints: list[str] = []
if old_text.lower() == actual_text.lower() and old_text != actual_text:
hints.append("letter case differs")
if _collapse_internal_whitespace(old_text) == _collapse_internal_whitespace(actual_text) and old_text != actual_text:
hints.append("whitespace differs")
if old_text.rstrip("\n") == actual_text.rstrip("\n") and old_text != actual_text:
hints.append("trailing newline differs")
if _normalize_quotes(old_text) == _normalize_quotes(actual_text) and old_text != actual_text:
hints.append("quote style differs")
return hints
def _best_window(old_text: str, content: str) -> tuple[float, int, list[str], list[str]]:
"""Find the closest line-window match and return ratio/start/snippet/hints."""
lines = content.splitlines(keepends=True)
old_lines = old_text.splitlines(keepends=True)
window = max(1, len(old_lines))
best_ratio, best_start = -1.0, 0
best_window_lines: list[str] = []
for i in range(max(1, len(lines) - window + 1)):
current = lines[i : i + window]
ratio = difflib.SequenceMatcher(None, old_lines, current).ratio()
if ratio > best_ratio:
best_ratio, best_start = ratio, i
best_window_lines = current
actual_text = "".join(best_window_lines).replace("\r\n", "\n").rstrip("\n")
hints = _diagnose_near_match(old_text.replace("\r\n", "\n").rstrip("\n"), actual_text)
return best_ratio, best_start, best_window_lines, hints
def _find_match(content: str, old_text: str) -> tuple[str | None, int]:
"""Locate old_text in content: exact first, then line-trimmed sliding window.
"""Locate old_text in content with a multi-level fallback chain:
1. Exact substring match
2. Line-trimmed sliding window (handles indentation differences)
3. Smart quote normalization (curly straight quotes)
Both inputs should use LF line endings (caller normalises CRLF).
Returns (matched_fragment, count) or (None, 0).
"""
if old_text in content:
return old_text, content.count(old_text)
old_lines = old_text.splitlines()
if not old_lines:
matches = _find_matches(content, old_text)
if not matches:
return None, 0
stripped_old = [l.strip() for l in old_lines]
content_lines = content.splitlines()
candidates = []
for i in range(len(content_lines) - len(stripped_old) + 1):
window = content_lines[i : i + len(stripped_old)]
if [l.strip() for l in window] == stripped_old:
candidates.append("\n".join(window))
if candidates:
return candidates[0], len(candidates)
return None, 0
return matches[0].text, len(matches)
@tool_parameters(
@ -241,6 +578,9 @@ def _find_match(content: str, old_text: str) -> tuple[str | None, int]:
class EditFileTool(_FsTool):
"""Edit a file by replacing text with fallback matching."""
_MAX_EDIT_FILE_SIZE = 1024 * 1024 * 1024 # 1 GiB
_MARKDOWN_EXTS = frozenset({".md", ".mdx", ".markdown"})
@property
def name(self) -> str:
return "edit_file"
@ -249,11 +589,16 @@ class EditFileTool(_FsTool):
def description(self) -> str:
return (
"Edit a file by replacing old_text with new_text. "
"Tolerates minor whitespace/indentation differences. "
"Tolerates minor whitespace/indentation differences and curly/straight quote mismatches. "
"If old_text matches multiple times, you must provide more context "
"or set replace_all=true. Shows a diff of the closest match on failure."
)
@staticmethod
def _strip_trailing_ws(text: str) -> str:
"""Strip trailing whitespace from each line."""
return "\n".join(line.rstrip() for line in text.split("\n"))
async def execute(
self, path: str | None = None, old_text: str | None = None,
new_text: str | None = None,
@ -267,55 +612,133 @@ class EditFileTool(_FsTool):
if new_text is None:
raise ValueError("Unknown new_text")
# .ipynb detection
if path.endswith(".ipynb"):
return "Error: This is a Jupyter notebook. Use the notebook_edit tool instead of edit_file."
fp = self._resolve(path)
# Create-file semantics: old_text='' + file doesn't exist → create
if not fp.exists():
return f"Error: File not found: {path}"
if old_text == "":
fp.parent.mkdir(parents=True, exist_ok=True)
fp.write_text(new_text, encoding="utf-8")
file_state.record_write(fp)
return f"Successfully created {fp}"
return self._file_not_found_msg(path, fp)
# File size protection
try:
fsize = fp.stat().st_size
except OSError:
fsize = 0
if fsize > self._MAX_EDIT_FILE_SIZE:
return f"Error: File too large to edit ({fsize / (1024**3):.1f} GiB). Maximum is 1 GiB."
# Create-file: old_text='' but file exists and not empty → reject
if old_text == "":
raw = fp.read_bytes()
content = raw.decode("utf-8")
if content.strip():
return f"Error: Cannot create file — {path} already exists and is not empty."
fp.write_text(new_text, encoding="utf-8")
file_state.record_write(fp)
return f"Successfully edited {fp}"
# Read-before-edit check
warning = file_state.check_read(fp)
raw = fp.read_bytes()
uses_crlf = b"\r\n" in raw
content = raw.decode("utf-8").replace("\r\n", "\n")
match, count = _find_match(content, old_text.replace("\r\n", "\n"))
norm_old = old_text.replace("\r\n", "\n")
matches = _find_matches(content, norm_old)
if match is None:
if not matches:
return self._not_found_msg(old_text, content, path)
count = len(matches)
if count > 1 and not replace_all:
line_numbers = [match.line for match in matches]
preview = ", ".join(f"line {n}" for n in line_numbers[:3])
if len(line_numbers) > 3:
preview += ", ..."
location_hint = f" at {preview}" if preview else ""
return (
f"Warning: old_text appears {count} times. "
f"Warning: old_text appears {count} times{location_hint}. "
"Provide more context to make it unique, or set replace_all=true."
)
norm_new = new_text.replace("\r\n", "\n")
new_content = content.replace(match, norm_new) if replace_all else content.replace(match, norm_new, 1)
# Trailing whitespace stripping (skip markdown to preserve double-space line breaks)
if fp.suffix.lower() not in self._MARKDOWN_EXTS:
norm_new = self._strip_trailing_ws(norm_new)
selected = matches if replace_all else matches[:1]
new_content = content
for match in reversed(selected):
replacement = _preserve_quote_style(norm_old, match.text, norm_new)
replacement = _reindent_like_match(norm_old, match.text, replacement)
# Delete-line cleanup: when deleting text (new_text=''), consume trailing
# newline to avoid leaving a blank line
end = match.end
if replacement == "" and not match.text.endswith("\n") and content[end:end + 1] == "\n":
end += 1
new_content = new_content[: match.start] + replacement + new_content[end:]
if uses_crlf:
new_content = new_content.replace("\n", "\r\n")
fp.write_bytes(new_content.encode("utf-8"))
return f"Successfully edited {fp}"
file_state.record_write(fp)
msg = f"Successfully edited {fp}"
if warning:
msg = f"{warning}\n{msg}"
return msg
except PermissionError as e:
return f"Error: {e}"
except Exception as e:
return f"Error editing file: {e}"
def _file_not_found_msg(self, path: str, fp: Path) -> str:
"""Build an error message with 'Did you mean ...?' suggestions."""
parent = fp.parent
suggestions: list[str] = []
if parent.is_dir():
siblings = [f.name for f in parent.iterdir() if f.is_file()]
close = difflib.get_close_matches(fp.name, siblings, n=3, cutoff=0.6)
suggestions = [str(parent / c) for c in close]
parts = [f"Error: File not found: {path}"]
if suggestions:
parts.append("Did you mean: " + ", ".join(suggestions) + "?")
return "\n".join(parts)
@staticmethod
def _not_found_msg(old_text: str, content: str, path: str) -> str:
lines = content.splitlines(keepends=True)
old_lines = old_text.splitlines(keepends=True)
window = len(old_lines)
best_ratio, best_start = 0.0, 0
for i in range(max(1, len(lines) - window + 1)):
ratio = difflib.SequenceMatcher(None, old_lines, lines[i : i + window]).ratio()
if ratio > best_ratio:
best_ratio, best_start = ratio, i
best_ratio, best_start, best_window_lines, hints = _best_window(old_text, content)
if best_ratio > 0.5:
diff = "\n".join(difflib.unified_diff(
old_lines, lines[best_start : best_start + window],
old_text.splitlines(keepends=True),
best_window_lines,
fromfile="old_text (provided)",
tofile=f"{path} (actual, line {best_start + 1})",
lineterm="",
))
return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
hint_text = ""
if hints:
hint_text = "\nPossible cause: " + ", ".join(hints) + "."
return (
f"Error: old_text not found in {path}."
f"{hint_text}\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
)
if hints:
return (
f"Error: old_text not found in {path}. "
f"Possible cause: {', '.join(hints)}. "
"Copy the exact text from read_file and try again."
)
return f"Error: old_text not found in {path}. No similar text found. Verify the file content."

View File

@ -57,9 +57,7 @@ def _normalize_schema_for_openai(schema: Any) -> dict[str, Any]:
if "properties" in normalized and isinstance(normalized["properties"], dict):
normalized["properties"] = {
name: _normalize_schema_for_openai(prop)
if isinstance(prop, dict)
else prop
name: _normalize_schema_for_openai(prop) if isinstance(prop, dict) else prop
for name, prop in normalized["properties"].items()
}
@ -138,9 +136,7 @@ class MCPToolWrapper(Tool):
class MCPResourceWrapper(Tool):
"""Wraps an MCP resource URI as a read-only nanobot Tool."""
def __init__(
self, session, server_name: str, resource_def, resource_timeout: int = 30
):
def __init__(self, session, server_name: str, resource_def, resource_timeout: int = 30):
self._session = session
self._uri = resource_def.uri
self._name = f"mcp_{server_name}_resource_{resource_def.name}"
@ -211,9 +207,7 @@ class MCPResourceWrapper(Tool):
class MCPPromptWrapper(Tool):
"""Wraps an MCP prompt as a read-only nanobot Tool."""
def __init__(
self, session, server_name: str, prompt_def, prompt_timeout: int = 30
):
def __init__(self, session, server_name: str, prompt_def, prompt_timeout: int = 30):
self._session = session
self._prompt_name = prompt_def.name
self._name = f"mcp_{server_name}_prompt_{prompt_def.name}"
@ -266,9 +260,7 @@ class MCPPromptWrapper(Tool):
timeout=self._prompt_timeout,
)
except asyncio.TimeoutError:
logger.warning(
"MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout
)
logger.warning("MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout)
return f"(MCP prompt call timed out after {self._prompt_timeout}s)"
except asyncio.CancelledError:
task = asyncio.current_task()
@ -279,13 +271,17 @@ class MCPPromptWrapper(Tool):
except McpError as exc:
logger.error(
"MCP prompt '{}' failed: code={} message={}",
self._name, exc.error.code, exc.error.message,
self._name,
exc.error.code,
exc.error.message,
)
return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])"
except Exception as exc:
logger.exception(
"MCP prompt '{}' failed: {}: {}",
self._name, type(exc).__name__, exc,
self._name,
type(exc).__name__,
exc,
)
return f"(MCP prompt call failed: {type(exc).__name__})"
@ -307,35 +303,44 @@ class MCPPromptWrapper(Tool):
async def connect_mcp_servers(
mcp_servers: dict, registry: ToolRegistry, stack: AsyncExitStack
) -> None:
"""Connect to configured MCP servers and register their tools, resources, and prompts."""
mcp_servers: dict, registry: ToolRegistry
) -> dict[str, AsyncExitStack]:
"""Connect to configured MCP servers and register their tools, resources, prompts.
Returns a dict mapping server name -> its dedicated AsyncExitStack.
Each server gets its own stack and runs in its own task to prevent
cancel scope conflicts when multiple MCP servers are configured.
"""
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamable_http_client
for name, cfg in mcp_servers.items():
async def connect_single_server(name: str, cfg) -> tuple[str, AsyncExitStack | None]:
server_stack = AsyncExitStack()
await server_stack.__aenter__()
try:
transport_type = cfg.type
if not transport_type:
if cfg.command:
transport_type = "stdio"
elif cfg.url:
# Convention: URLs ending with /sse use SSE transport; others use streamableHttp
transport_type = (
"sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp"
)
else:
logger.warning("MCP server '{}': no command or url configured, skipping", name)
continue
await server_stack.aclose()
return name, None
if transport_type == "stdio":
params = StdioServerParameters(
command=cfg.command, args=cfg.args, env=cfg.env or None
)
read, write = await stack.enter_async_context(stdio_client(params))
read, write = await server_stack.enter_async_context(stdio_client(params))
elif transport_type == "sse":
def httpx_client_factory(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
@ -353,27 +358,26 @@ async def connect_mcp_servers(
auth=auth,
)
read, write = await stack.enter_async_context(
read, write = await server_stack.enter_async_context(
sse_client(cfg.url, httpx_client_factory=httpx_client_factory)
)
elif transport_type == "streamableHttp":
# Always provide an explicit httpx client so MCP HTTP transport does not
# inherit httpx's default 5s timeout and preempt the higher-level tool timeout.
http_client = await stack.enter_async_context(
http_client = await server_stack.enter_async_context(
httpx.AsyncClient(
headers=cfg.headers or None,
follow_redirects=True,
timeout=None,
)
)
read, write, _ = await stack.enter_async_context(
read, write, _ = await server_stack.enter_async_context(
streamable_http_client(cfg.url, http_client=http_client)
)
else:
logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type)
continue
await server_stack.aclose()
return name, None
session = await stack.enter_async_context(ClientSession(read, write))
session = await server_stack.enter_async_context(ClientSession(read, write))
await session.initialize()
tools = await session.list_tools()
@ -418,7 +422,6 @@ async def connect_mcp_servers(
", ".join(available_wrapped_names) or "(none)",
)
# --- Register resources ---
try:
resources_result = await session.list_resources()
for resource in resources_result.resources:
@ -433,7 +436,6 @@ async def connect_mcp_servers(
except Exception as e:
logger.debug("MCP server '{}': resources not supported or failed: {}", name, e)
# --- Register prompts ---
try:
prompts_result = await session.list_prompts()
for prompt in prompts_result.prompts:
@ -442,14 +444,54 @@ async def connect_mcp_servers(
)
registry.register(wrapper)
registered_count += 1
logger.debug(
"MCP: registered prompt '{}' from server '{}'", wrapper.name, name
)
logger.debug("MCP: registered prompt '{}' from server '{}'", wrapper.name, name)
except Exception as e:
logger.debug("MCP server '{}': prompts not supported or failed: {}", name, e)
logger.info(
"MCP server '{}': connected, {} capabilities registered", name, registered_count
)
return name, server_stack
except Exception as e:
logger.error("MCP server '{}': failed to connect: {}", name, e)
hint = ""
text = str(e).lower()
if any(
marker in text
for marker in (
"parse error",
"invalid json",
"unexpected token",
"jsonrpc",
"content-length",
)
):
hint = (
" Hint: this looks like stdio protocol pollution. Make sure the MCP server writes "
"only JSON-RPC to stdout and sends logs/debug output to stderr instead."
)
logger.error("MCP server '{}': failed to connect: {}{}", name, e, hint)
try:
await server_stack.aclose()
except Exception:
pass
return name, None
server_stacks: dict[str, AsyncExitStack] = {}
tasks: list[asyncio.Task] = []
for name, cfg in mcp_servers.items():
task = asyncio.create_task(connect_single_server(name, cfg))
tasks.append(task)
results = await asyncio.gather(*tasks, return_exceptions=True)
for i, result in enumerate(results):
name = list(mcp_servers.keys())[i]
if isinstance(result, BaseException):
if not isinstance(result, asyncio.CancelledError):
logger.error("MCP server '{}' connection task failed: {}", name, result)
elif result is not None and result[1] is not None:
server_stacks[result[0]] = result[1]
return server_stacks

View File

@ -0,0 +1,161 @@
"""NotebookEditTool — edit Jupyter .ipynb notebooks."""
from __future__ import annotations
import json
import uuid
from typing import Any
from nanobot.agent.tools.base import tool_parameters
from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema
from nanobot.agent.tools.filesystem import _FsTool
def _new_cell(source: str, cell_type: str = "code", generate_id: bool = False) -> dict:
cell: dict[str, Any] = {
"cell_type": cell_type,
"source": source,
"metadata": {},
}
if cell_type == "code":
cell["outputs"] = []
cell["execution_count"] = None
if generate_id:
cell["id"] = uuid.uuid4().hex[:8]
return cell
def _make_empty_notebook() -> dict:
return {
"nbformat": 4,
"nbformat_minor": 5,
"metadata": {
"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
"language_info": {"name": "python"},
},
"cells": [],
}
@tool_parameters(
tool_parameters_schema(
path=StringSchema("Path to the .ipynb notebook file"),
cell_index=IntegerSchema(0, description="0-based index of the cell to edit", minimum=0),
new_source=StringSchema("New source content for the cell"),
cell_type=StringSchema(
"Cell type: 'code' or 'markdown' (default: code)",
enum=["code", "markdown"],
),
edit_mode=StringSchema(
"Mode: 'replace' (default), 'insert' (after target), or 'delete'",
enum=["replace", "insert", "delete"],
),
required=["path", "cell_index"],
)
)
class NotebookEditTool(_FsTool):
"""Edit Jupyter notebook cells: replace, insert, or delete."""
_VALID_CELL_TYPES = frozenset({"code", "markdown"})
_VALID_EDIT_MODES = frozenset({"replace", "insert", "delete"})
@property
def name(self) -> str:
return "notebook_edit"
@property
def description(self) -> str:
return (
"Edit a Jupyter notebook (.ipynb) cell. "
"Modes: replace (default) replaces cell content, "
"insert adds a new cell after the target index, "
"delete removes the cell at the index. "
"cell_index is 0-based."
)
async def execute(
self,
path: str | None = None,
cell_index: int = 0,
new_source: str = "",
cell_type: str = "code",
edit_mode: str = "replace",
**kwargs: Any,
) -> str:
try:
if not path:
return "Error: path is required"
if not path.endswith(".ipynb"):
return "Error: notebook_edit only works on .ipynb files. Use edit_file for other files."
if edit_mode not in self._VALID_EDIT_MODES:
return (
f"Error: Invalid edit_mode '{edit_mode}'. "
"Use one of: replace, insert, delete."
)
if cell_type not in self._VALID_CELL_TYPES:
return (
f"Error: Invalid cell_type '{cell_type}'. "
"Use one of: code, markdown."
)
fp = self._resolve(path)
# Create new notebook if file doesn't exist and mode is insert
if not fp.exists():
if edit_mode != "insert":
return f"Error: File not found: {path}"
nb = _make_empty_notebook()
cell = _new_cell(new_source, cell_type, generate_id=True)
nb["cells"].append(cell)
fp.parent.mkdir(parents=True, exist_ok=True)
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
return f"Successfully created {fp} with 1 cell"
try:
nb = json.loads(fp.read_text(encoding="utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError) as e:
return f"Error: Failed to parse notebook: {e}"
cells = nb.get("cells", [])
nbformat_minor = nb.get("nbformat_minor", 0)
generate_id = nb.get("nbformat", 0) >= 4 and nbformat_minor >= 5
if edit_mode == "delete":
if cell_index < 0 or cell_index >= len(cells):
return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)"
cells.pop(cell_index)
nb["cells"] = cells
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
return f"Successfully deleted cell {cell_index} from {fp}"
if edit_mode == "insert":
insert_at = min(cell_index + 1, len(cells))
cell = _new_cell(new_source, cell_type, generate_id=generate_id)
cells.insert(insert_at, cell)
nb["cells"] = cells
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
return f"Successfully inserted cell at index {insert_at} in {fp}"
# Default: replace
if cell_index < 0 or cell_index >= len(cells):
return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)"
cells[cell_index]["source"] = new_source
if cell_type and cells[cell_index].get("cell_type") != cell_type:
cells[cell_index]["cell_type"] = cell_type
if cell_type == "code":
cells[cell_index].setdefault("outputs", [])
cells[cell_index].setdefault("execution_count", None)
elif "outputs" in cells[cell_index]:
del cells[cell_index]["outputs"]
cells[cell_index].pop("execution_count", None)
nb["cells"] = cells
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
return f"Successfully edited cell {cell_index} in {fp}"
except PermissionError as e:
return f"Error: {e}"
except Exception as e:
return f"Error editing notebook: {e}"

View File

@ -68,6 +68,13 @@ class ToolRegistry:
params: dict[str, Any],
) -> tuple[Tool | None, dict[str, Any], str | None]:
"""Resolve, cast, and validate one tool call."""
# Guard against invalid parameter types (e.g., list instead of dict)
if not isinstance(params, dict) and name in ('write_file', 'read_file'):
return None, params, (
f"Error: Tool '{name}' parameters must be a JSON object, got {type(params).__name__}. "
"Use named parameters: tool_name(param1=\"value1\", param2=\"value2\")"
)
tool = self._tools.get(name)
if not tool:
return None, params, (

View File

@ -46,6 +46,7 @@ class ExecTool(Tool):
restrict_to_workspace: bool = False,
sandbox: str = "",
path_append: str = "",
allowed_env_keys: list[str] | None = None,
):
self.timeout = timeout
self.working_dir = working_dir
@ -60,10 +61,19 @@ class ExecTool(Tool):
r">\s*/dev/sd", # write to disk
r"\b(shutdown|reboot|poweroff)\b", # system power
r":\(\)\s*\{.*\};\s*:", # fork bomb
# Block writes to nanobot internal state files (#2989).
# history.jsonl / .dream_cursor are managed by append_history();
# direct writes corrupt the cursor format and crash /dream.
r">>?\s*\S*(?:history\.jsonl|\.dream_cursor)", # > / >> redirect
r"\btee\b[^|;&<>]*(?:history\.jsonl|\.dream_cursor)", # tee / tee -a
r"\b(?:cp|mv)\b(?:\s+[^\s|;&<>]+)+\s+\S*(?:history\.jsonl|\.dream_cursor)", # cp/mv target
r"\bdd\b[^|;&<>]*\bof=\S*(?:history\.jsonl|\.dream_cursor)", # dd of=
r"\bsed\s+-i[^|;&<>]*(?:history\.jsonl|\.dream_cursor)", # sed -i
]
self.allow_patterns = allow_patterns or []
self.restrict_to_workspace = restrict_to_workspace
self.path_append = path_append
self.allowed_env_keys = allowed_env_keys or []
@property
def name(self) -> str:
@ -91,6 +101,21 @@ class ExecTool(Tool):
timeout: int | None = None, **kwargs: Any,
) -> str:
cwd = working_dir or self.working_dir or os.getcwd()
# Prevent an LLM-supplied working_dir from escaping the configured
# workspace when restrict_to_workspace is enabled (#2826). Without
# this, a caller can pass working_dir="/etc" and then all absolute
# paths under /etc would pass the _guard_command check that anchors
# on cwd.
if self.restrict_to_workspace and self.working_dir:
try:
requested = Path(cwd).expanduser().resolve()
workspace_root = Path(self.working_dir).expanduser().resolve()
except Exception:
return "Error: working_dir could not be resolved"
if requested != workspace_root and workspace_root not in requested.parents:
return "Error: working_dir is outside the configured workspace"
guard_error = self._guard_command(command, cwd)
if guard_error:
return guard_error
@ -208,7 +233,7 @@ class ExecTool(Tool):
"""
if _IS_WINDOWS:
sr = os.environ.get("SYSTEMROOT", r"C:\Windows")
return {
env = {
"SYSTEMROOT": sr,
"COMSPEC": os.environ.get("COMSPEC", f"{sr}\\system32\\cmd.exe"),
"USERPROFILE": os.environ.get("USERPROFILE", ""),
@ -218,13 +243,29 @@ class ExecTool(Tool):
"TMP": os.environ.get("TMP", f"{sr}\\Temp"),
"PATHEXT": os.environ.get("PATHEXT", ".COM;.EXE;.BAT;.CMD"),
"PATH": os.environ.get("PATH", f"{sr}\\system32;{sr}"),
"APPDATA": os.environ.get("APPDATA", ""),
"LOCALAPPDATA": os.environ.get("LOCALAPPDATA", ""),
"ProgramData": os.environ.get("ProgramData", ""),
"ProgramFiles": os.environ.get("ProgramFiles", ""),
"ProgramFiles(x86)": os.environ.get("ProgramFiles(x86)", ""),
"ProgramW6432": os.environ.get("ProgramW6432", ""),
}
for key in self.allowed_env_keys:
val = os.environ.get(key)
if val is not None:
env[key] = val
return env
home = os.environ.get("HOME", "/tmp")
return {
env = {
"HOME": home,
"LANG": os.environ.get("LANG", "C.UTF-8"),
"TERM": os.environ.get("TERM", "dumb"),
}
for key in self.allowed_env_keys:
val = os.environ.get(key)
if val is not None:
env[key] = val
return env
def _guard_command(self, command: str, cwd: str) -> str | None:
"""Best-effort safety guard for potentially destructive commands."""

View File

@ -96,10 +96,37 @@ class WebSearchTool(Tool):
self.config = config if config is not None else WebSearchConfig()
self.proxy = proxy
def _effective_provider(self) -> str:
"""Resolve the backend that execute() will actually use."""
provider = self.config.provider.strip().lower() or "brave"
if provider == "duckduckgo":
return "duckduckgo"
if provider == "brave":
api_key = self.config.api_key or os.environ.get("BRAVE_API_KEY", "")
return "brave" if api_key else "duckduckgo"
if provider == "tavily":
api_key = self.config.api_key or os.environ.get("TAVILY_API_KEY", "")
return "tavily" if api_key else "duckduckgo"
if provider == "searxng":
base_url = (self.config.base_url or os.environ.get("SEARXNG_BASE_URL", "")).strip()
return "searxng" if base_url else "duckduckgo"
if provider == "jina":
api_key = self.config.api_key or os.environ.get("JINA_API_KEY", "")
return "jina" if api_key else "duckduckgo"
if provider == "kagi":
api_key = self.config.api_key or os.environ.get("KAGI_API_KEY", "")
return "kagi" if api_key else "duckduckgo"
return provider
@property
def read_only(self) -> bool:
return True
@property
def exclusive(self) -> bool:
"""DuckDuckGo searches are serialized because ddgs is not concurrency-safe."""
return self._effective_provider() == "duckduckgo"
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
provider = self.config.provider.strip().lower() or "brave"
n = min(max(count or self.config.max_results, 1), 10)
@ -114,6 +141,8 @@ class WebSearchTool(Tool):
return await self._search_jina(query, n)
elif provider == "brave":
return await self._search_brave(query, n)
elif provider == "kagi":
return await self._search_kagi(query, n)
else:
return f"Error: unknown search provider '{provider}'"
@ -204,6 +233,29 @@ class WebSearchTool(Tool):
logger.warning("Jina search failed ({}), falling back to DuckDuckGo", e)
return await self._search_duckduckgo(query, n)
async def _search_kagi(self, query: str, n: int) -> str:
api_key = self.config.api_key or os.environ.get("KAGI_API_KEY", "")
if not api_key:
logger.warning("KAGI_API_KEY not set, falling back to DuckDuckGo")
return await self._search_duckduckgo(query, n)
try:
async with httpx.AsyncClient(proxy=self.proxy) as client:
r = await client.get(
"https://kagi.com/api/v0/search",
params={"q": query, "limit": n},
headers={"Authorization": f"Bot {api_key}"},
timeout=10.0,
)
r.raise_for_status()
# t=0 items are search results; other values are related searches, etc.
items = [
{"title": d.get("title", ""), "url": d.get("url", ""), "content": d.get("snippet", "")}
for d in r.json().get("data", []) if d.get("t") == 0
]
return _format_results(query, items, n)
except Exception as e:
return f"Error: {e}"
async def _search_duckduckgo(self, query: str, n: int) -> str:
try:
# Note: duckduckgo_search is synchronous and does its own requests

View File

@ -84,6 +84,10 @@ def _save_base64_data_url(data_url: str, media_dir: Path) -> str | None:
raw = base64.b64decode(b64_payload)
except Exception:
return None
if len(raw) > MAX_FILE_SIZE:
raise _FileSizeExceeded(
f"File exceeds {MAX_FILE_SIZE // (1024 * 1024)}MB limit"
)
ext = mimetypes.guess_extension(mime_type) or ".bin"
filename = f"{uuid.uuid4().hex[:12]}{ext}"
dest = media_dir / safe_filename(filename)

View File

@ -5,6 +5,8 @@ import json
import mimetypes
import os
import time
import zipfile
from io import BytesIO
from pathlib import Path
from typing import Any
from urllib.parse import unquote, urlparse
@ -171,6 +173,7 @@ class DingTalkChannel(BaseChannel):
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"}
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"}
_ZIP_BEFORE_UPLOAD_EXTS = {".htm", ".html"}
@classmethod
def default_config(cls) -> dict[str, Any]:
@ -287,6 +290,31 @@ class DingTalkChannel(BaseChannel):
name = os.path.basename(urlparse(media_ref).path)
return name or {"image": "image.jpg", "voice": "audio.amr", "video": "video.mp4"}.get(upload_type, "file.bin")
@staticmethod
def _zip_bytes(filename: str, data: bytes) -> tuple[bytes, str, str]:
stem = Path(filename).stem or "attachment"
safe_name = filename or "attachment.bin"
zip_name = f"{stem}.zip"
buffer = BytesIO()
with zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
archive.writestr(safe_name, data)
return buffer.getvalue(), zip_name, "application/zip"
def _normalize_upload_payload(
self,
filename: str,
data: bytes,
content_type: str | None,
) -> tuple[bytes, str, str | None]:
ext = Path(filename).suffix.lower()
if ext in self._ZIP_BEFORE_UPLOAD_EXTS or content_type == "text/html":
logger.info(
"DingTalk does not accept raw HTML attachments, zipping {} before upload",
filename,
)
return self._zip_bytes(filename, data)
return data, filename, content_type
async def _read_media_bytes(
self,
media_ref: str,
@ -309,6 +337,9 @@ class DingTalkChannel(BaseChannel):
content_type = (resp.headers.get("content-type") or "").split(";")[0].strip()
filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
return resp.content, filename, content_type or None
except httpx.TransportError as e:
logger.error("DingTalk media download network error ref={} err={}", media_ref, e)
raise
except Exception as e:
logger.error("DingTalk media download error ref={} err={}", media_ref, e)
return None, None, None
@ -360,6 +391,9 @@ class DingTalkChannel(BaseChannel):
logger.error("DingTalk media upload missing media_id body={}", text[:500])
return None
return str(media_id)
except httpx.TransportError as e:
logger.error("DingTalk media upload network error type={} err={}", media_type, e)
raise
except Exception as e:
logger.error("DingTalk media upload error type={} err={}", media_type, e)
return None
@ -409,6 +443,9 @@ class DingTalkChannel(BaseChannel):
return False
logger.debug("DingTalk message sent to {} with msgKey={}", chat_id, msg_key)
return True
except httpx.TransportError as e:
logger.error("DingTalk network error sending message msgKey={} err={}", msg_key, e)
raise
except Exception as e:
logger.error("Error sending DingTalk message msgKey={} err={}", msg_key, e)
return False
@ -444,6 +481,7 @@ class DingTalkChannel(BaseChannel):
return False
filename = filename or self._guess_filename(media_ref, upload_type)
data, filename, content_type = self._normalize_upload_payload(filename, data, content_type)
file_type = Path(filename).suffix.lower().lstrip(".")
if not file_type:
guessed = mimetypes.guess_extension(content_type or "")

View File

@ -4,6 +4,8 @@ from __future__ import annotations
import asyncio
import importlib.util
import time
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
@ -20,6 +22,7 @@ from nanobot.utils.helpers import safe_filename, split_message
DISCORD_AVAILABLE = importlib.util.find_spec("discord") is not None
if TYPE_CHECKING:
import aiohttp
import discord
from discord import app_commands
from discord.abc import Messageable
@ -34,6 +37,16 @@ MAX_MESSAGE_LEN = 2000 # Discord message character limit
TYPING_INTERVAL_S = 8
@dataclass
class _StreamBuf:
"""Per-chat streaming accumulator for progressive Discord message edits."""
text: str = ""
message: Any | None = None
last_edit: float = 0.0
stream_id: str | None = None
class DiscordConfig(Base):
"""Discord channel configuration."""
@ -45,6 +58,10 @@ class DiscordConfig(Base):
read_receipt_emoji: str = "👀"
working_emoji: str = "🔧"
working_emoji_delay: float = 2.0
streaming: bool = True
proxy: str | None = None
proxy_username: str | None = None
proxy_password: str | None = None
if DISCORD_AVAILABLE:
@ -52,8 +69,15 @@ if DISCORD_AVAILABLE:
class DiscordBotClient(discord.Client):
"""discord.py client that forwards events to the channel."""
def __init__(self, channel: DiscordChannel, *, intents: discord.Intents) -> None:
super().__init__(intents=intents)
def __init__(
self,
channel: DiscordChannel,
*,
intents: discord.Intents,
proxy: str | None = None,
proxy_auth: aiohttp.BasicAuth | None = None,
) -> None:
super().__init__(intents=intents, proxy=proxy, proxy_auth=proxy_auth)
self._channel = channel
self.tree = app_commands.CommandTree(self)
self._register_app_commands()
@ -117,6 +141,7 @@ if DISCORD_AVAILABLE:
)
for name, description, command_text in commands:
@self.tree.command(name=name, description=description)
async def command_handler(
interaction: discord.Interaction,
@ -173,7 +198,9 @@ if DISCORD_AVAILABLE:
else:
failed_media.append(Path(media_path).name)
for index, chunk in enumerate(self._build_chunks(msg.content or "", failed_media, sent_media)):
for index, chunk in enumerate(
self._build_chunks(msg.content or "", failed_media, sent_media)
):
kwargs: dict[str, Any] = {"content": chunk}
if index == 0 and reference is not None and not sent_media:
kwargs["reference"] = reference
@ -242,6 +269,7 @@ class DiscordChannel(BaseChannel):
name = "discord"
display_name = "Discord"
_STREAM_EDIT_INTERVAL = 0.8
@classmethod
def default_config(cls) -> dict[str, Any]:
@ -263,6 +291,7 @@ class DiscordChannel(BaseChannel):
self._bot_user_id: str | None = None
self._pending_reactions: dict[str, Any] = {} # chat_id -> message object
self._working_emoji_tasks: dict[str, asyncio.Task[None]] = {}
self._stream_bufs: dict[str, _StreamBuf] = {}
async def start(self) -> None:
"""Start the Discord client."""
@ -277,7 +306,29 @@ class DiscordChannel(BaseChannel):
try:
intents = discord.Intents.none()
intents.value = self.config.intents
self._client = DiscordBotClient(self, intents=intents)
proxy_auth = None
has_user = bool(self.config.proxy_username)
has_pass = bool(self.config.proxy_password)
if has_user and has_pass:
import aiohttp
proxy_auth = aiohttp.BasicAuth(
login=self.config.proxy_username,
password=self.config.proxy_password,
)
elif has_user != has_pass:
logger.warning(
"Discord proxy auth incomplete: both proxy_username and "
"proxy_password must be set; ignoring partial credentials",
)
self._client = DiscordBotClient(
self,
intents=intents,
proxy=self.config.proxy,
proxy_auth=proxy_auth,
)
except Exception as e:
logger.error("Failed to initialize Discord client: {}", e)
self._client = None
@ -315,11 +366,71 @@ class DiscordChannel(BaseChannel):
await client.send_outbound(msg)
except Exception as e:
logger.error("Error sending Discord message: {}", e)
raise
finally:
if not is_progress:
await self._stop_typing(msg.chat_id)
await self._clear_reactions(msg.chat_id)
async def send_delta(
self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None
) -> None:
"""Progressive Discord delivery: send once, then edit until the stream ends."""
client = self._client
if client is None or not client.is_ready():
logger.warning("Discord client not ready; dropping stream delta")
return
meta = metadata or {}
stream_id = meta.get("_stream_id")
if meta.get("_stream_end"):
buf = self._stream_bufs.get(chat_id)
if not buf or buf.message is None or not buf.text:
return
if stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id:
return
await self._finalize_stream(chat_id, buf)
return
buf = self._stream_bufs.get(chat_id)
if buf is None or (
stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id
):
buf = _StreamBuf(stream_id=stream_id)
self._stream_bufs[chat_id] = buf
elif buf.stream_id is None:
buf.stream_id = stream_id
buf.text += delta
if not buf.text.strip():
return
target = await self._resolve_channel(chat_id)
if target is None:
logger.warning("Discord stream target {} unavailable", chat_id)
return
now = time.monotonic()
if buf.message is None:
try:
buf.message = await target.send(content=buf.text)
buf.last_edit = now
except Exception as e:
logger.warning("Discord stream initial send failed: {}", e)
raise
return
if (now - buf.last_edit) < self._STREAM_EDIT_INTERVAL:
return
try:
await buf.message.edit(content=DiscordBotClient._build_chunks(buf.text, [], False)[0])
buf.last_edit = now
except Exception as e:
logger.warning("Discord stream edit failed: {}", e)
raise
async def _handle_discord_message(self, message: discord.Message) -> None:
"""Handle incoming Discord messages from discord.py."""
if message.author.bot:
@ -373,6 +484,47 @@ class DiscordChannel(BaseChannel):
"""Backward-compatible alias for legacy tests/callers."""
await self._handle_discord_message(message)
async def _resolve_channel(self, chat_id: str) -> Any | None:
"""Resolve a Discord channel from cache first, then network fetch."""
client = self._client
if client is None or not client.is_ready():
return None
channel_id = int(chat_id)
channel = client.get_channel(channel_id)
if channel is not None:
return channel
try:
return await client.fetch_channel(channel_id)
except Exception as e:
logger.warning("Discord channel {} unavailable: {}", chat_id, e)
return None
async def _finalize_stream(self, chat_id: str, buf: _StreamBuf) -> None:
"""Commit the final streamed content and flush overflow chunks."""
chunks = DiscordBotClient._build_chunks(buf.text, [], False)
if not chunks:
self._stream_bufs.pop(chat_id, None)
return
try:
await buf.message.edit(content=chunks[0])
except Exception as e:
logger.warning("Discord final stream edit failed: {}", e)
raise
target = getattr(buf.message, "channel", None) or await self._resolve_channel(chat_id)
if target is None:
logger.warning("Discord stream follow-up target {} unavailable", chat_id)
self._stream_bufs.pop(chat_id, None)
return
for extra_chunk in chunks[1:]:
await target.send(content=extra_chunk)
self._stream_bufs.pop(chat_id, None)
await self._stop_typing(chat_id)
await self._clear_reactions(chat_id)
def _should_accept_inbound(
self,
message: discord.Message,
@ -423,7 +575,11 @@ class DiscordChannel(BaseChannel):
@staticmethod
def _build_inbound_metadata(message: discord.Message) -> dict[str, str | None]:
"""Build metadata for inbound Discord messages."""
reply_to = str(message.reference.message_id) if message.reference and message.reference.message_id else None
reply_to = (
str(message.reference.message_id)
if message.reference and message.reference.message_id
else None
)
return {
"message_id": str(message.id),
"guild_id": str(message.guild.id) if message.guild else None,
@ -438,7 +594,9 @@ class DiscordChannel(BaseChannel):
if self.config.group_policy == "mention":
bot_user_id = self._bot_user_id
if bot_user_id is None:
logger.debug("Discord message in {} ignored (bot identity unavailable)", message.channel.id)
logger.debug(
"Discord message in {} ignored (bot identity unavailable)", message.channel.id
)
return False
if any(str(user.id) == bot_user_id for user in message.mentions):
@ -480,7 +638,6 @@ class DiscordChannel(BaseChannel):
except asyncio.CancelledError:
pass
async def _clear_reactions(self, chat_id: str) -> None:
"""Remove all pending reactions after bot replies."""
# Cancel delayed working emoji if it hasn't fired yet
@ -507,6 +664,7 @@ class DiscordChannel(BaseChannel):
async def _reset_runtime_state(self, close_client: bool) -> None:
"""Reset client and typing state."""
await self._cancel_all_typing()
self._stream_bufs.clear()
if close_client and self._client is not None and not self._client.is_closed():
try:
await self._client.close()

View File

@ -22,6 +22,8 @@ from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_media_dir
from nanobot.config.schema import Base
from lark_oapi.core.const import FEISHU_DOMAIN, LARK_DOMAIN
FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
# Message type display mapping
@ -250,9 +252,12 @@ class FeishuConfig(Base):
verification_token: str = ""
allow_from: list[str] = Field(default_factory=list)
react_emoji: str = "THUMBSUP"
done_emoji: str | None = None # Emoji to show when task is completed (e.g., "DONE", "OK")
tool_hint_prefix: str = "\U0001f527" # Prefix for inline tool hints (default: 🔧)
group_policy: Literal["open", "mention"] = "mention"
reply_to_message: bool = False # If True, bot replies quote the user's original message
streaming: bool = True
domain: Literal["feishu", "lark"] = "feishu" # Set to "lark" for international Lark
_STREAM_ELEMENT_ID = "streaming_md"
@ -326,10 +331,12 @@ class FeishuChannel(BaseChannel):
self._loop = asyncio.get_running_loop()
# Create Lark client for sending messages
domain = LARK_DOMAIN if self.config.domain == "lark" else FEISHU_DOMAIN
self._client = (
lark.Client.builder()
.app_id(self.config.app_id)
.app_secret(self.config.app_secret)
.domain(domain)
.log_level(lark.LogLevel.INFO)
.build()
)
@ -357,6 +364,7 @@ class FeishuChannel(BaseChannel):
self._ws_client = lark.ws.Client(
self.config.app_id,
self.config.app_secret,
domain=domain,
event_handler=event_handler,
log_level=lark.LogLevel.INFO,
)
@ -1012,14 +1020,29 @@ class FeishuChannel(BaseChannel):
elif msg_type in ("audio", "file", "media"):
file_key = content_json.get("file_key")
if file_key and message_id:
data, filename = await loop.run_in_executor(
None, self._download_file_sync, message_id, file_key, msg_type
)
if not filename:
filename = file_key[:16]
if msg_type == "audio" and not filename.endswith(".opus"):
filename = f"{filename}.opus"
if not file_key:
logger.warning("Feishu {} message missing file_key: {}", msg_type, content_json)
return None, f"[{msg_type}: missing file_key]"
if not message_id:
logger.warning("Feishu {} message missing message_id", msg_type)
return None, f"[{msg_type}: missing message_id]"
data, filename = await loop.run_in_executor(
None, self._download_file_sync, message_id, file_key, msg_type
)
if not data:
logger.warning("Feishu {} download failed: file_key={}", msg_type, file_key)
return None, f"[{msg_type}: download failed]"
if not filename:
filename = file_key[:16]
# Feishu voice messages are opus in OGG container.
# Use .ogg extension for better Whisper compatibility.
if msg_type == "audio":
if not any(filename.endswith(ext) for ext in (".opus", ".ogg", ".oga")):
filename = f"{filename}.ogg"
if data and filename:
file_path = media_dir / filename
@ -1263,7 +1286,14 @@ class FeishuChannel(BaseChannel):
async def send_delta(
self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None
) -> None:
"""Progressive streaming via CardKit: create card on first delta, stream-update on subsequent."""
"""Progressive streaming via CardKit: create card on first delta, stream-update on subsequent.
Supported metadata keys:
_stream_end: Finalize the streaming card.
_tool_hint: Delta is a formatted tool hint (for display only).
message_id: Original message id (used with _stream_end for reaction cleanup).
reaction_id: Reaction id to remove on stream end.
"""
if not self._client:
return
meta = metadata or {}
@ -1274,38 +1304,48 @@ class FeishuChannel(BaseChannel):
if meta.get("_stream_end"):
if (message_id := meta.get("message_id")) and (reaction_id := meta.get("reaction_id")):
await self._remove_reaction(message_id, reaction_id)
# Add completion emoji if configured
if self.config.done_emoji and message_id:
await self._add_reaction(message_id, self.config.done_emoji)
buf = self._stream_bufs.pop(chat_id, None)
if not buf or not buf.text:
return
# Try to finalize via streaming card; if that fails (e.g.
# streaming mode was closed by Feishu due to timeout), fall
# back to sending a regular interactive card.
if buf.card_id:
buf.sequence += 1
await loop.run_in_executor(
ok = await loop.run_in_executor(
None,
self._stream_update_text_sync,
buf.card_id,
buf.text,
buf.sequence,
)
# Required so the chat list preview exits the streaming placeholder (Feishu streaming card docs).
buf.sequence += 1
await loop.run_in_executor(
None,
self._close_streaming_mode_sync,
buf.card_id,
buf.sequence,
)
else:
for chunk in self._split_elements_by_table_limit(
self._build_card_elements(buf.text)
):
card = json.dumps(
{"config": {"wide_screen_mode": True}, "elements": chunk},
ensure_ascii=False,
)
if ok:
buf.sequence += 1
await loop.run_in_executor(
None, self._send_message_sync, rid_type, chat_id, "interactive", card
None,
self._close_streaming_mode_sync,
buf.card_id,
buf.sequence,
)
return
logger.warning(
"Streaming card {} final update failed, falling back to regular card",
buf.card_id,
)
for chunk in self._split_elements_by_table_limit(
self._build_card_elements(buf.text)
):
card = json.dumps(
{"config": {"wide_screen_mode": True}, "elements": chunk},
ensure_ascii=False,
)
await loop.run_in_executor(
None, self._send_message_sync, rid_type, chat_id, "interactive", card
)
return
# --- accumulate delta ---
@ -1346,13 +1386,33 @@ class FeishuChannel(BaseChannel):
receive_id_type = "chat_id" if msg.chat_id.startswith("oc_") else "open_id"
loop = asyncio.get_running_loop()
# Handle tool hint messages as code blocks in interactive cards.
# These are progress-only messages and should bypass normal reply routing.
# Handle tool hint messages. When a streaming card is active for
# this chat, inline the hint into the card instead of sending a
# separate message so the user experience stays cohesive.
if msg.metadata.get("_tool_hint"):
if msg.content and msg.content.strip():
await self._send_tool_hint_card(
receive_id_type, msg.chat_id, msg.content.strip()
hint = (msg.content or "").strip()
if not hint:
return
buf = self._stream_bufs.get(msg.chat_id)
if buf and buf.card_id:
# Delegate to send_delta so tool hints get the same
# throttling (and card creation) as regular text deltas.
await self.send_delta(
msg.chat_id,
"\n\n" + self._format_tool_hint_delta(hint) + "\n\n",
)
return
# No active streaming card — send as a regular
# interactive card with the same 🔧 prefix style.
card = json.dumps(
{"config": {"wide_screen_mode": True}, "elements": [
{"tag": "markdown", "content": self._format_tool_hint_delta(hint)},
]},
ensure_ascii=False,
)
await loop.run_in_executor(
None, self._send_message_sync, receive_id_type, msg.chat_id, "interactive", card
)
return
# Determine whether the first message should quote the user's message.
@ -1648,33 +1708,9 @@ class FeishuChannel(BaseChannel):
return "\n".join(part for part in parts if part)
async def _send_tool_hint_card(
self, receive_id_type: str, receive_id: str, tool_hint: str
) -> None:
"""Send tool hint as an interactive card with formatted code block.
Args:
receive_id_type: "chat_id" or "open_id"
receive_id: The target chat or user ID
tool_hint: Formatted tool hint string (e.g., 'web_search("q"), read_file("path")')
"""
loop = asyncio.get_running_loop()
# Put each top-level tool call on its own line without altering commas inside arguments.
formatted_code = self._format_tool_hint_lines(tool_hint)
card = {
"config": {"wide_screen_mode": True},
"elements": [
{"tag": "markdown", "content": f"**Tool Calls**\n\n```text\n{formatted_code}\n```"}
],
}
await loop.run_in_executor(
None,
self._send_message_sync,
receive_id_type,
receive_id,
"interactive",
json.dumps(card, ensure_ascii=False),
def _format_tool_hint_delta(self, tool_hint: str) -> str:
"""Format a tool hint string with the 🔧 prefix for each line."""
lines = self.__class__._format_tool_hint_lines(tool_hint).split("\n")
return "\n".join(
f"{self.config.tool_hint_prefix} {ln}" for ln in lines if ln.strip()
)

View File

@ -242,43 +242,49 @@ class QQChannel(BaseChannel):
async def send(self, msg: OutboundMessage) -> None:
"""Send attachments first, then text."""
if not self._client:
logger.warning("QQ client not initialized")
return
try:
if not self._client:
logger.warning("QQ client not initialized")
return
msg_id = msg.metadata.get("message_id")
chat_type = self._chat_type_cache.get(msg.chat_id, "c2c")
is_group = chat_type == "group"
msg_id = msg.metadata.get("message_id")
chat_type = self._chat_type_cache.get(msg.chat_id, "c2c")
is_group = chat_type == "group"
# 1) Send media
for media_ref in msg.media or []:
ok = await self._send_media(
chat_id=msg.chat_id,
media_ref=media_ref,
msg_id=msg_id,
is_group=is_group,
)
if not ok:
filename = (
os.path.basename(urlparse(media_ref).path)
or os.path.basename(media_ref)
or "file"
# 1) Send media
for media_ref in msg.media or []:
ok = await self._send_media(
chat_id=msg.chat_id,
media_ref=media_ref,
msg_id=msg_id,
is_group=is_group,
)
if not ok:
filename = (
os.path.basename(urlparse(media_ref).path)
or os.path.basename(media_ref)
or "file"
)
await self._send_text_only(
chat_id=msg.chat_id,
is_group=is_group,
msg_id=msg_id,
content=f"[Attachment send failed: {filename}]",
)
# 2) Send text
if msg.content and msg.content.strip():
await self._send_text_only(
chat_id=msg.chat_id,
is_group=is_group,
msg_id=msg_id,
content=f"[Attachment send failed: {filename}]",
content=msg.content.strip(),
)
# 2) Send text
if msg.content and msg.content.strip():
await self._send_text_only(
chat_id=msg.chat_id,
is_group=is_group,
msg_id=msg_id,
content=msg.content.strip(),
)
except (aiohttp.ClientError, OSError):
# Network / transport errors — propagate so ChannelManager can retry
raise
except Exception:
logger.exception("Error sending QQ message to chat_id={}", msg.chat_id)
async def _send_text_only(
self,
@ -359,7 +365,12 @@ class QQChannel(BaseChannel):
logger.info("QQ media sent: {}", filename)
return True
except (aiohttp.ClientError, OSError) as e:
# Network / transport errors — propagate for retry by caller
logger.warning("QQ send media network error filename={} err={}", filename, e)
raise
except Exception as e:
# API-level or other non-network errors — return False so send() can fallback
logger.error("QQ send media failed filename={} err={}", filename, e)
return False
@ -438,15 +449,26 @@ class QQChannel(BaseChannel):
endpoint = "/v2/users/{openid}/files"
id_key = "openid"
payload = {
payload: dict[str, Any] = {
id_key: chat_id,
"file_type": file_type,
"file_data": file_data,
"file_name": file_name,
"srv_send_msg": srv_send_msg,
}
# Only pass file_name for non-image types (file_type=4).
# Passing file_name for images causes QQ client to render them as
# file attachments instead of inline images.
if file_type != QQ_FILE_TYPE_IMAGE and file_name:
payload["file_name"] = file_name
route = Route("POST", endpoint, **{id_key: chat_id})
return await self._client.api._http.request(route, json=payload)
result = await self._client.api._http.request(route, json=payload)
# Extract only the file_info field to avoid extra fields (file_uuid, ttl, etc.)
# that may confuse QQ client when sending the media object.
if isinstance(result, dict) and "file_info" in result:
return {"file_info": result["file_info"]}
return result
# ---------------------------
# Inbound (receive)
@ -454,58 +476,68 @@ class QQChannel(BaseChannel):
async def _on_message(self, data: C2CMessage | GroupMessage, is_group: bool = False) -> None:
"""Parse inbound message, download attachments, and publish to the bus."""
if data.id in self._processed_ids:
return
self._processed_ids.append(data.id)
try:
if data.id in self._processed_ids:
return
self._processed_ids.append(data.id)
if is_group:
chat_id = data.group_openid
user_id = data.author.member_openid
self._chat_type_cache[chat_id] = "group"
else:
chat_id = str(
getattr(data.author, "id", None) or getattr(data.author, "user_openid", "unknown")
)
user_id = chat_id
self._chat_type_cache[chat_id] = "c2c"
content = (data.content or "").strip()
# the data used by tests don't contain attachments property
# so we use getattr with a default of [] to avoid AttributeError in tests
attachments = getattr(data, "attachments", None) or []
media_paths, recv_lines, att_meta = await self._handle_attachments(attachments)
# Compose content that always contains actionable saved paths
if recv_lines:
tag = "[Image]" if any(_is_image_name(Path(p).name) for p in media_paths) else "[File]"
file_block = "Received files:\n" + "\n".join(recv_lines)
content = f"{content}\n\n{file_block}".strip() if content else f"{tag}\n{file_block}"
if not content and not media_paths:
return
if self.config.ack_message:
try:
await self._send_text_only(
chat_id=chat_id,
is_group=is_group,
msg_id=data.id,
content=self.config.ack_message,
if is_group:
chat_id = data.group_openid
user_id = data.author.member_openid
self._chat_type_cache[chat_id] = "group"
else:
chat_id = str(
getattr(data.author, "id", None)
or getattr(data.author, "user_openid", "unknown")
)
except Exception:
logger.debug("QQ ack message failed for chat_id={}", chat_id)
user_id = chat_id
self._chat_type_cache[chat_id] = "c2c"
await self._handle_message(
sender_id=user_id,
chat_id=chat_id,
content=content,
media=media_paths if media_paths else None,
metadata={
"message_id": data.id,
"attachments": att_meta,
},
)
content = (data.content or "").strip()
# the data used by tests don't contain attachments property
# so we use getattr with a default of [] to avoid AttributeError in tests
attachments = getattr(data, "attachments", None) or []
media_paths, recv_lines, att_meta = await self._handle_attachments(attachments)
# Compose content that always contains actionable saved paths
if recv_lines:
tag = (
"[Image]"
if any(_is_image_name(Path(p).name) for p in media_paths)
else "[File]"
)
file_block = "Received files:\n" + "\n".join(recv_lines)
content = (
f"{content}\n\n{file_block}".strip() if content else f"{tag}\n{file_block}"
)
if not content and not media_paths:
return
if self.config.ack_message:
try:
await self._send_text_only(
chat_id=chat_id,
is_group=is_group,
msg_id=data.id,
content=self.config.ack_message,
)
except Exception:
logger.debug("QQ ack message failed for chat_id={}", chat_id)
await self._handle_message(
sender_id=user_id,
chat_id=chat_id,
content=content,
media=media_paths if media_paths else None,
metadata={
"message_id": data.id,
"attachments": att_meta,
},
)
except Exception:
logger.exception("Error handling QQ inbound message id={}", getattr(data, "id", "?"))
async def _handle_attachments(
self,
@ -520,7 +552,9 @@ class QQChannel(BaseChannel):
return media_paths, recv_lines, att_meta
for att in attachments:
url, filename, ctype = att.url, att.filename, att.content_type
url = getattr(att, "url", None) or ""
filename = getattr(att, "filename", None) or ""
ctype = getattr(att, "content_type", None) or ""
logger.info("Downloading file from QQ: {}", filename or url)
local_path = await self._download_to_media_dir_chunked(url, filename_hint=filename)
@ -555,6 +589,10 @@ class QQChannel(BaseChannel):
Enforces a max download size and writes to a .part temp file
that is atomically renamed on success.
"""
# Handle protocol-relative URLs (e.g. "//multimedia.nt.qq.com/...")
if url.startswith("//"):
url = f"https:{url}"
if not self._http:
self._http = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120))

View File

@ -5,6 +5,7 @@ import re
from typing import Any
from loguru import logger
from pydantic import Field
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from slack_sdk.socket_mode.websockets import SocketModeClient
@ -13,8 +14,6 @@ from slackify_markdown import slackify_markdown
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from pydantic import Field
from nanobot.channels.base import BaseChannel
from nanobot.config.schema import Base
@ -50,6 +49,9 @@ class SlackChannel(BaseChannel):
name = "slack"
display_name = "Slack"
_SLACK_ID_RE = re.compile(r"^[CDGUW][A-Z0-9]{2,}$")
_SLACK_CHANNEL_REF_RE = re.compile(r"^<#([A-Z0-9]+)(?:\|[^>]+)?>$")
_SLACK_USER_REF_RE = re.compile(r"^<@([A-Z0-9]+)(?:\|[^>]+)?>$")
@classmethod
def default_config(cls) -> dict[str, Any]:
@ -63,6 +65,7 @@ class SlackChannel(BaseChannel):
self._web_client: AsyncWebClient | None = None
self._socket_client: SocketModeClient | None = None
self._bot_user_id: str | None = None
self._target_cache: dict[str, str] = {}
async def start(self) -> None:
"""Start the Slack Socket Mode client."""
@ -113,17 +116,23 @@ class SlackChannel(BaseChannel):
logger.warning("Slack client not running")
return
try:
target_chat_id = await self._resolve_target_chat_id(msg.chat_id)
slack_meta = msg.metadata.get("slack", {}) if msg.metadata else {}
thread_ts = slack_meta.get("thread_ts")
channel_type = slack_meta.get("channel_type")
origin_chat_id = str((slack_meta.get("event", {}) or {}).get("channel") or msg.chat_id)
# Slack DMs don't use threads; channel/group replies may keep thread_ts.
thread_ts_param = thread_ts if thread_ts and channel_type != "im" else None
thread_ts_param = (
thread_ts
if thread_ts and channel_type != "im" and target_chat_id == origin_chat_id
else None
)
# Slack rejects empty text payloads. Keep media-only messages media-only,
# but send a single blank message when the bot has no text or files to send.
if msg.content or not (msg.media or []):
await self._web_client.chat_postMessage(
channel=msg.chat_id,
channel=target_chat_id,
text=self._to_mrkdwn(msg.content) if msg.content else " ",
thread_ts=thread_ts_param,
)
@ -131,7 +140,7 @@ class SlackChannel(BaseChannel):
for media_path in msg.media or []:
try:
await self._web_client.files_upload_v2(
channel=msg.chat_id,
channel=target_chat_id,
file=media_path,
thread_ts=thread_ts_param,
)
@ -141,12 +150,123 @@ class SlackChannel(BaseChannel):
# Update reaction emoji when the final (non-progress) response is sent
if not (msg.metadata or {}).get("_progress"):
event = slack_meta.get("event", {})
await self._update_react_emoji(msg.chat_id, event.get("ts"))
await self._update_react_emoji(origin_chat_id, event.get("ts"))
except Exception as e:
logger.error("Error sending Slack message: {}", e)
raise
async def _resolve_target_chat_id(self, target: str) -> str:
"""Resolve human-friendly Slack targets to concrete IDs when needed."""
if not self._web_client:
return target
target = target.strip()
if not target:
return target
if match := self._SLACK_CHANNEL_REF_RE.fullmatch(target):
return match.group(1)
if match := self._SLACK_USER_REF_RE.fullmatch(target):
return await self._open_dm_for_user(match.group(1))
if self._SLACK_ID_RE.fullmatch(target):
if target.startswith(("U", "W")):
return await self._open_dm_for_user(target)
return target
if target.startswith("#"):
return await self._resolve_channel_name(target[1:])
if target.startswith("@"):
return await self._resolve_user_handle(target[1:])
try:
return await self._resolve_channel_name(target)
except ValueError:
return await self._resolve_user_handle(target)
async def _resolve_channel_name(self, name: str) -> str:
normalized = self._normalize_target_name(name)
if not normalized:
raise ValueError("Slack target channel name is empty")
cache_key = f"channel:{normalized}"
if cache_key in self._target_cache:
return self._target_cache[cache_key]
cursor: str | None = None
while True:
response = await self._web_client.conversations_list(
types="public_channel,private_channel",
exclude_archived=True,
limit=200,
cursor=cursor,
)
for channel in response.get("channels", []):
if self._normalize_target_name(str(channel.get("name") or "")) == normalized:
channel_id = str(channel.get("id") or "")
if channel_id:
self._target_cache[cache_key] = channel_id
return channel_id
cursor = ((response.get("response_metadata") or {}).get("next_cursor") or "").strip()
if not cursor:
break
raise ValueError(
f"Slack channel '{name}' was not found. Use a joined channel name like "
f"'#general' or a concrete channel ID."
)
async def _resolve_user_handle(self, handle: str) -> str:
normalized = self._normalize_target_name(handle)
if not normalized:
raise ValueError("Slack target user handle is empty")
cache_key = f"user:{normalized}"
if cache_key in self._target_cache:
return self._target_cache[cache_key]
cursor: str | None = None
while True:
response = await self._web_client.users_list(limit=200, cursor=cursor)
for member in response.get("members", []):
if self._member_matches_handle(member, normalized):
user_id = str(member.get("id") or "")
if not user_id:
continue
dm_id = await self._open_dm_for_user(user_id)
self._target_cache[cache_key] = dm_id
return dm_id
cursor = ((response.get("response_metadata") or {}).get("next_cursor") or "").strip()
if not cursor:
break
raise ValueError(
f"Slack user '{handle}' was not found. Use '@name' or a concrete DM/channel ID."
)
async def _open_dm_for_user(self, user_id: str) -> str:
response = await self._web_client.conversations_open(users=user_id)
channel_id = str(((response.get("channel") or {}).get("id")) or "")
if not channel_id:
raise ValueError(f"Slack DM target for user '{user_id}' could not be opened.")
return channel_id
@staticmethod
def _normalize_target_name(value: str) -> str:
return value.strip().lstrip("#@").lower()
@classmethod
def _member_matches_handle(cls, member: dict[str, Any], normalized: str) -> bool:
profile = member.get("profile") or {}
candidates = {
str(member.get("name") or ""),
str(profile.get("display_name") or ""),
str(profile.get("display_name_normalized") or ""),
str(profile.get("real_name") or ""),
str(profile.get("real_name_normalized") or ""),
}
return normalized in {cls._normalize_target_name(candidate) for candidate in candidates if candidate}
async def _on_socket_request(
self,
client: SocketModeClient,

View File

@ -166,6 +166,7 @@ def _markdown_to_telegram_html(text: str) -> str:
_SEND_MAX_RETRIES = 3
_SEND_RETRY_BASE_DELAY = 0.5 # seconds, doubled each retry
_STREAM_EDIT_INTERVAL_DEFAULT = 0.6 # min seconds between edit_message_text calls
@dataclass
@ -190,6 +191,7 @@ class TelegramConfig(Base):
connection_pool_size: int = 32
pool_timeout: float = 5.0
streaming: bool = True
stream_edit_interval: float = Field(default=_STREAM_EDIT_INTERVAL_DEFAULT, ge=0.1)
class TelegramChannel(BaseChannel):
@ -219,8 +221,6 @@ class TelegramChannel(BaseChannel):
def default_config(cls) -> dict[str, Any]:
return TelegramConfig().model_dump(by_alias=True)
_STREAM_EDIT_INTERVAL = 0.6 # min seconds between edit_message_text calls
def __init__(self, config: Any, bus: MessageBus):
if isinstance(config, dict):
config = TelegramConfig.model_validate(config)
@ -520,7 +520,10 @@ class TelegramChannel(BaseChannel):
reply_parameters=reply_params,
**(thread_kwargs or {}),
)
except Exception as e:
except BadRequest as e:
# Only fall back to plain text on actual HTML parse/format errors.
# Network errors (TimedOut, NetworkError) should propagate immediately
# to avoid doubling connection demand during pool exhaustion.
logger.warning("HTML parse failed, falling back to plain text: {}", e)
try:
await self._call_with_retry(
@ -567,7 +570,10 @@ class TelegramChannel(BaseChannel):
chat_id=int_chat_id, message_id=buf.message_id,
text=html, parse_mode="HTML",
)
except Exception as e:
except BadRequest as e:
# Only fall back to plain text on actual HTML parse/format errors.
# Network errors (TimedOut, NetworkError) should propagate immediately
# to avoid doubling connection demand during pool exhaustion.
if self._is_not_modified_error(e):
logger.debug("Final stream edit already applied for {}", chat_id)
self._stream_bufs.pop(chat_id, None)
@ -619,7 +625,7 @@ class TelegramChannel(BaseChannel):
except Exception as e:
logger.warning("Stream initial send failed: {}", e)
raise # Let ChannelManager handle retry
elif (now - buf.last_edit) >= self._STREAM_EDIT_INTERVAL:
elif (now - buf.last_edit) >= self.config.stream_edit_interval:
try:
await self._call_with_retry(
self._app.bot.edit_message_text,

View File

@ -0,0 +1,457 @@
"""WebSocket server channel: nanobot acts as a WebSocket server and serves connected clients."""
from __future__ import annotations
import asyncio
import email.utils
import hmac
import http
import json
import secrets
import ssl
import time
import uuid
from typing import Any, Self
from urllib.parse import parse_qs, urlparse
from loguru import logger
from pydantic import Field, field_validator, model_validator
from websockets.asyncio.server import ServerConnection, serve
from websockets.datastructures import Headers
from websockets.exceptions import ConnectionClosed
from websockets.http11 import Request as WsRequest, Response
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.schema import Base
def _strip_trailing_slash(path: str) -> str:
if len(path) > 1 and path.endswith("/"):
return path.rstrip("/")
return path or "/"
def _normalize_config_path(path: str) -> str:
return _strip_trailing_slash(path)
class WebSocketConfig(Base):
"""WebSocket server channel configuration.
Clients connect with URLs like ``ws://{host}:{port}{path}?client_id=...&token=...``.
- ``client_id``: Used for ``allow_from`` authorization; if omitted, a value is generated and logged.
- ``token``: If non-empty, the ``token`` query param may match this static secret; short-lived tokens
from ``token_issue_path`` are also accepted.
- ``token_issue_path``: If non-empty, **GET** (HTTP/1.1) to this path returns JSON
``{"token": "...", "expires_in": <seconds>}``; use ``?token=...`` when opening the WebSocket.
Must differ from ``path`` (the WS upgrade path). If the client runs in the **same process** as
nanobot and shares the asyncio loop, use a thread or async HTTP client for GETdo not call
blocking ``urllib`` or synchronous ``httpx`` from inside a coroutine.
- ``token_issue_secret``: If non-empty, token requests must send ``Authorization: Bearer <secret>`` or
``X-Nanobot-Auth: <secret>``.
- ``websocket_requires_token``: If True, the handshake must include a valid token (static or issued and not expired).
- Each connection has its own session: a unique ``chat_id`` maps to the agent session internally.
- ``media`` field in outbound messages contains local filesystem paths; remote clients need a
shared filesystem or an HTTP file server to access these files.
"""
enabled: bool = False
host: str = "127.0.0.1"
port: int = 8765
path: str = "/"
token: str = ""
token_issue_path: str = ""
token_issue_secret: str = ""
token_ttl_s: int = Field(default=300, ge=30, le=86_400)
websocket_requires_token: bool = True
allow_from: list[str] = Field(default_factory=lambda: ["*"])
streaming: bool = True
max_message_bytes: int = Field(default=1_048_576, ge=1024, le=16_777_216)
ping_interval_s: float = Field(default=20.0, ge=5.0, le=300.0)
ping_timeout_s: float = Field(default=20.0, ge=5.0, le=300.0)
ssl_certfile: str = ""
ssl_keyfile: str = ""
@field_validator("path")
@classmethod
def path_must_start_with_slash(cls, value: str) -> str:
if not value.startswith("/"):
raise ValueError('path must start with "/"')
return _normalize_config_path(value)
@field_validator("token_issue_path")
@classmethod
def token_issue_path_format(cls, value: str) -> str:
value = value.strip()
if not value:
return ""
if not value.startswith("/"):
raise ValueError('token_issue_path must start with "/"')
return _normalize_config_path(value)
@model_validator(mode="after")
def token_issue_path_differs_from_ws_path(self) -> Self:
if not self.token_issue_path:
return self
if _normalize_config_path(self.token_issue_path) == _normalize_config_path(self.path):
raise ValueError("token_issue_path must differ from path (the WebSocket upgrade path)")
return self
def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
body = json.dumps(data, ensure_ascii=False).encode("utf-8")
headers = Headers(
[
("Date", email.utils.formatdate(usegmt=True)),
("Connection", "close"),
("Content-Length", str(len(body))),
("Content-Type", "application/json; charset=utf-8"),
]
)
reason = http.HTTPStatus(status).phrase
return Response(status, reason, headers, body)
def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]:
"""Parse normalized path and query parameters in one pass."""
parsed = urlparse("ws://x" + path_with_query)
path = _strip_trailing_slash(parsed.path or "/")
return path, parse_qs(parsed.query)
def _normalize_http_path(path_with_query: str) -> str:
"""Return the path component (no query string), with trailing slash normalized (root stays ``/``)."""
return _parse_request_path(path_with_query)[0]
def _parse_query(path_with_query: str) -> dict[str, list[str]]:
return _parse_request_path(path_with_query)[1]
def _query_first(query: dict[str, list[str]], key: str) -> str | None:
"""Return the first value for *key*, or None."""
values = query.get(key)
return values[0] if values else None
def _parse_inbound_payload(raw: str) -> str | None:
"""Parse a client frame into text; return None for empty or unrecognized content."""
text = raw.strip()
if not text:
return None
if text.startswith("{"):
try:
data = json.loads(text)
except json.JSONDecodeError:
return text
if isinstance(data, dict):
for key in ("content", "text", "message"):
value = data.get(key)
if isinstance(value, str) and value.strip():
return value
return None
return None
return text
def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool:
"""Return True if the token-issue HTTP request carries credentials matching ``token_issue_secret``."""
if not configured_secret:
return True
authorization = headers.get("Authorization") or headers.get("authorization")
if authorization and authorization.lower().startswith("bearer "):
supplied = authorization[7:].strip()
return hmac.compare_digest(supplied, configured_secret)
header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth")
if not header_token:
return False
return hmac.compare_digest(header_token.strip(), configured_secret)
class WebSocketChannel(BaseChannel):
"""Run a local WebSocket server; forward text/JSON messages to the message bus."""
name = "websocket"
display_name = "WebSocket"
def __init__(self, config: Any, bus: MessageBus):
if isinstance(config, dict):
config = WebSocketConfig.model_validate(config)
super().__init__(config, bus)
self.config: WebSocketConfig = config
self._connections: dict[str, Any] = {}
self._issued_tokens: dict[str, float] = {}
self._stop_event: asyncio.Event | None = None
self._server_task: asyncio.Task[None] | None = None
@classmethod
def default_config(cls) -> dict[str, Any]:
return WebSocketConfig().model_dump(by_alias=True)
def _expected_path(self) -> str:
return _normalize_config_path(self.config.path)
def _build_ssl_context(self) -> ssl.SSLContext | None:
cert = self.config.ssl_certfile.strip()
key = self.config.ssl_keyfile.strip()
if not cert and not key:
return None
if not cert or not key:
raise ValueError(
"websocket: ssl_certfile and ssl_keyfile must both be set for WSS, or both left empty"
)
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
ctx.load_cert_chain(certfile=cert, keyfile=key)
return ctx
_MAX_ISSUED_TOKENS = 10_000
def _purge_expired_issued_tokens(self) -> None:
now = time.monotonic()
for token_key, expiry in list(self._issued_tokens.items()):
if now > expiry:
self._issued_tokens.pop(token_key, None)
def _take_issued_token_if_valid(self, token_value: str | None) -> bool:
"""Validate and consume one issued token (single use per connection attempt).
Uses single-step pop to minimize the window between lookup and removal;
safe under asyncio's single-threaded cooperative model.
"""
if not token_value:
return False
self._purge_expired_issued_tokens()
expiry = self._issued_tokens.pop(token_value, None)
if expiry is None:
return False
if time.monotonic() > expiry:
return False
return True
def _handle_token_issue_http(self, connection: Any, request: Any) -> Any:
secret = self.config.token_issue_secret.strip()
if secret:
if not _issue_route_secret_matches(request.headers, secret):
return connection.respond(401, "Unauthorized")
else:
logger.warning(
"websocket: token_issue_path is set but token_issue_secret is empty; "
"any client can obtain connection tokens — set token_issue_secret for production."
)
self._purge_expired_issued_tokens()
if len(self._issued_tokens) >= self._MAX_ISSUED_TOKENS:
logger.error(
"websocket: too many outstanding issued tokens ({}), rejecting issuance",
len(self._issued_tokens),
)
return _http_json_response({"error": "too many outstanding tokens"}, status=429)
token_value = f"nbwt_{secrets.token_urlsafe(32)}"
self._issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s)
return _http_json_response(
{"token": token_value, "expires_in": self.config.token_ttl_s}
)
def _authorize_websocket_handshake(self, connection: Any, query: dict[str, list[str]]) -> Any:
supplied = _query_first(query, "token")
static_token = self.config.token.strip()
if static_token:
if supplied and hmac.compare_digest(supplied, static_token):
return None
if supplied and self._take_issued_token_if_valid(supplied):
return None
return connection.respond(401, "Unauthorized")
if self.config.websocket_requires_token:
if supplied and self._take_issued_token_if_valid(supplied):
return None
return connection.respond(401, "Unauthorized")
if supplied:
self._take_issued_token_if_valid(supplied)
return None
async def start(self) -> None:
self._running = True
self._stop_event = asyncio.Event()
ssl_context = self._build_ssl_context()
scheme = "wss" if ssl_context else "ws"
async def process_request(
connection: ServerConnection,
request: WsRequest,
) -> Any:
got, _ = _parse_request_path(request.path)
if self.config.token_issue_path:
issue_expected = _normalize_config_path(self.config.token_issue_path)
if got == issue_expected:
return self._handle_token_issue_http(connection, request)
expected_ws = self._expected_path()
if got != expected_ws:
return connection.respond(404, "Not Found")
# Early reject before WebSocket upgrade to avoid unnecessary overhead;
# _handle_message() performs a second check as defense-in-depth.
query = _parse_query(request.path)
client_id = _query_first(query, "client_id") or ""
if len(client_id) > 128:
client_id = client_id[:128]
if not self.is_allowed(client_id):
return connection.respond(403, "Forbidden")
return self._authorize_websocket_handshake(connection, query)
async def handler(connection: ServerConnection) -> None:
await self._connection_loop(connection)
logger.info(
"WebSocket server listening on {}://{}:{}{}",
scheme,
self.config.host,
self.config.port,
self.config.path,
)
if self.config.token_issue_path:
logger.info(
"WebSocket token issue route: {}://{}:{}{}",
scheme,
self.config.host,
self.config.port,
_normalize_config_path(self.config.token_issue_path),
)
async def runner() -> None:
async with serve(
handler,
self.config.host,
self.config.port,
process_request=process_request,
max_size=self.config.max_message_bytes,
ping_interval=self.config.ping_interval_s,
ping_timeout=self.config.ping_timeout_s,
ssl=ssl_context,
):
assert self._stop_event is not None
await self._stop_event.wait()
self._server_task = asyncio.create_task(runner())
await self._server_task
async def _connection_loop(self, connection: Any) -> None:
request = connection.request
path_part = request.path if request else "/"
_, query = _parse_request_path(path_part)
client_id_raw = _query_first(query, "client_id")
client_id = client_id_raw.strip() if client_id_raw else ""
if not client_id:
client_id = f"anon-{uuid.uuid4().hex[:12]}"
elif len(client_id) > 128:
logger.warning("websocket: client_id too long ({} chars), truncating", len(client_id))
client_id = client_id[:128]
chat_id = str(uuid.uuid4())
try:
await connection.send(
json.dumps(
{
"event": "ready",
"chat_id": chat_id,
"client_id": client_id,
},
ensure_ascii=False,
)
)
# Register only after ready is successfully sent to avoid out-of-order sends
self._connections[chat_id] = connection
async for raw in connection:
if isinstance(raw, bytes):
try:
raw = raw.decode("utf-8")
except UnicodeDecodeError:
logger.warning("websocket: ignoring non-utf8 binary frame")
continue
content = _parse_inbound_payload(raw)
if content is None:
continue
await self._handle_message(
sender_id=client_id,
chat_id=chat_id,
content=content,
metadata={"remote": getattr(connection, "remote_address", None)},
)
except Exception as e:
logger.debug("websocket connection ended: {}", e)
finally:
self._connections.pop(chat_id, None)
async def stop(self) -> None:
if not self._running:
return
self._running = False
if self._stop_event:
self._stop_event.set()
if self._server_task:
try:
await self._server_task
except Exception as e:
logger.warning("websocket: server task error during shutdown: {}", e)
self._server_task = None
self._connections.clear()
self._issued_tokens.clear()
async def _safe_send(self, chat_id: str, raw: str, *, label: str = "") -> None:
"""Send a raw frame, cleaning up dead connections on ConnectionClosed."""
connection = self._connections.get(chat_id)
if connection is None:
return
try:
await connection.send(raw)
except ConnectionClosed:
self._connections.pop(chat_id, None)
logger.warning("websocket{}connection gone for chat_id={}", label, chat_id)
except Exception as e:
logger.error("websocket{}send failed: {}", label, e)
raise
async def send(self, msg: OutboundMessage) -> None:
connection = self._connections.get(msg.chat_id)
if connection is None:
logger.warning("websocket: no active connection for chat_id={}", msg.chat_id)
return
payload: dict[str, Any] = {
"event": "message",
"text": msg.content,
}
if msg.media:
payload["media"] = msg.media
if msg.reply_to:
payload["reply_to"] = msg.reply_to
raw = json.dumps(payload, ensure_ascii=False)
await self._safe_send(msg.chat_id, raw, label=" ")
async def send_delta(
self,
chat_id: str,
delta: str,
metadata: dict[str, Any] | None = None,
) -> None:
if self._connections.get(chat_id) is None:
return
meta = metadata or {}
if meta.get("_stream_end"):
body: dict[str, Any] = {"event": "stream_end"}
else:
body = {
"event": "delta",
"text": delta,
}
if meta.get("_stream_id") is not None:
body["stream_id"] = meta["_stream_id"]
raw = json.dumps(body, ensure_ascii=False)
await self._safe_send(chat_id, raw, label=" stream ")

View File

@ -1,9 +1,13 @@
"""WeCom (Enterprise WeChat) channel implementation using wecom_aibot_sdk."""
import asyncio
import base64
import hashlib
import importlib.util
import os
import re
from collections import OrderedDict
from pathlib import Path
from typing import Any
from loguru import logger
@ -17,6 +21,37 @@ from pydantic import Field
WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
# Upload safety limits (matching QQ channel defaults)
WECOM_UPLOAD_MAX_BYTES = 1024 * 1024 * 200 # 200MB
# Replace unsafe characters with "_", keep Chinese and common safe punctuation.
_SAFE_NAME_RE = re.compile(r"[^\w.\-()\[\]()【】\u4e00-\u9fff]+", re.UNICODE)
def _sanitize_filename(name: str) -> str:
"""Sanitize filename to avoid traversal and problematic chars."""
name = (name or "").strip()
name = Path(name).name
name = _SAFE_NAME_RE.sub("_", name).strip("._ ")
return name
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}
_VIDEO_EXTS = {".mp4", ".avi", ".mov"}
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg"}
def _guess_wecom_media_type(filename: str) -> str:
"""Classify file extension as WeCom media_type string."""
ext = Path(filename).suffix.lower()
if ext in _IMAGE_EXTS:
return "image"
if ext in _VIDEO_EXTS:
return "video"
if ext in _AUDIO_EXTS:
return "voice"
return "file"
class WecomConfig(Base):
"""WeCom (Enterprise WeChat) AI Bot channel configuration."""
@ -217,6 +252,7 @@ class WecomChannel(BaseChannel):
chat_id = body.get("chatid", sender_id)
content_parts = []
media_paths: list[str] = []
if msg_type == "text":
text = body.get("text", {}).get("content", "")
@ -232,7 +268,8 @@ class WecomChannel(BaseChannel):
file_path = await self._download_and_save_media(file_url, aes_key, "image")
if file_path:
filename = os.path.basename(file_path)
content_parts.append(f"[image: {filename}]\n[Image: source: {file_path}]")
content_parts.append(f"[image: {filename}]")
media_paths.append(file_path)
else:
content_parts.append("[image: download failed]")
else:
@ -256,7 +293,8 @@ class WecomChannel(BaseChannel):
if file_url and aes_key:
file_path = await self._download_and_save_media(file_url, aes_key, "file", file_name)
if file_path:
content_parts.append(f"[file: {file_name}]\n[File: source: {file_path}]")
content_parts.append(f"[file: {file_name}]")
media_paths.append(file_path)
else:
content_parts.append(f"[file: {file_name}: download failed]")
else:
@ -286,12 +324,11 @@ class WecomChannel(BaseChannel):
self._chat_frames[chat_id] = frame
# Forward to message bus
# Note: media paths are included in content for broader model compatibility
await self._handle_message(
sender_id=sender_id,
chat_id=chat_id,
content=content,
media=None,
media=media_paths or None,
metadata={
"message_id": msg_id,
"msg_type": msg_type,
@ -322,13 +359,21 @@ class WecomChannel(BaseChannel):
logger.warning("Failed to download media from WeCom")
return None
if len(data) > WECOM_UPLOAD_MAX_BYTES:
logger.warning(
"WeCom inbound media too large: {} bytes (max {})",
len(data),
WECOM_UPLOAD_MAX_BYTES,
)
return None
media_dir = get_media_dir("wecom")
if not filename:
filename = fname or f"{media_type}_{hash(file_url) % 100000}"
filename = os.path.basename(filename)
filename = _sanitize_filename(filename)
file_path = media_dir / filename
file_path.write_bytes(data)
await asyncio.to_thread(file_path.write_bytes, data)
logger.debug("Downloaded {} to {}", media_type, file_path)
return str(file_path)
@ -336,6 +381,100 @@ class WecomChannel(BaseChannel):
logger.error("Error downloading media: {}", e)
return None
async def _upload_media_ws(
self, client: Any, file_path: str,
) -> "tuple[str, str] | tuple[None, None]":
"""Upload a local file to WeCom via WebSocket 3-step protocol (base64).
Uses the WeCom WebSocket upload commands directly via
``client._ws_manager.send_reply()``:
``aibot_upload_media_init`` upload_id
``aibot_upload_media_chunk`` × N (512 KB raw per chunk, base64)
``aibot_upload_media_finish`` media_id
Returns (media_id, media_type) on success, (None, None) on failure.
"""
from wecom_aibot_sdk.utils import generate_req_id as _gen_req_id
try:
fname = os.path.basename(file_path)
media_type = _guess_wecom_media_type(fname)
# Read file size and data in a thread to avoid blocking the event loop
def _read_file():
file_size = os.path.getsize(file_path)
if file_size > WECOM_UPLOAD_MAX_BYTES:
raise ValueError(
f"File too large: {file_size} bytes (max {WECOM_UPLOAD_MAX_BYTES})"
)
with open(file_path, "rb") as f:
return file_size, f.read()
file_size, data = await asyncio.to_thread(_read_file)
# MD5 is used for file integrity only, not cryptographic security
md5_hash = hashlib.md5(data).hexdigest()
CHUNK_SIZE = 512 * 1024 # 512 KB raw (before base64)
mv = memoryview(data)
chunk_list = [bytes(mv[i : i + CHUNK_SIZE]) for i in range(0, file_size, CHUNK_SIZE)]
n_chunks = len(chunk_list)
del mv, data
# Step 1: init
req_id = _gen_req_id("upload_init")
resp = await client._ws_manager.send_reply(req_id, {
"type": media_type,
"filename": fname,
"total_size": file_size,
"total_chunks": n_chunks,
"md5": md5_hash,
}, "aibot_upload_media_init")
if resp.errcode != 0:
logger.warning("WeCom upload init failed ({}): {}", resp.errcode, resp.errmsg)
return None, None
upload_id = resp.body.get("upload_id") if resp.body else None
if not upload_id:
logger.warning("WeCom upload init: no upload_id in response")
return None, None
# Step 2: send chunks
for i, chunk in enumerate(chunk_list):
req_id = _gen_req_id("upload_chunk")
resp = await client._ws_manager.send_reply(req_id, {
"upload_id": upload_id,
"chunk_index": i,
"base64_data": base64.b64encode(chunk).decode(),
}, "aibot_upload_media_chunk")
if resp.errcode != 0:
logger.warning("WeCom upload chunk {} failed ({}): {}", i, resp.errcode, resp.errmsg)
return None, None
# Step 3: finish
req_id = _gen_req_id("upload_finish")
resp = await client._ws_manager.send_reply(req_id, {
"upload_id": upload_id,
}, "aibot_upload_media_finish")
if resp.errcode != 0:
logger.warning("WeCom upload finish failed ({}): {}", resp.errcode, resp.errmsg)
return None, None
media_id = resp.body.get("media_id") if resp.body else None
if not media_id:
logger.warning("WeCom upload finish: no media_id in response body={}", resp.body)
return None, None
suffix = "..." if len(media_id) > 16 else ""
logger.debug("WeCom uploaded {} ({}) → media_id={}", fname, media_type, media_id[:16] + suffix)
return media_id, media_type
except ValueError as e:
logger.warning("WeCom upload skipped for {}: {}", file_path, e)
return None, None
except Exception as e:
logger.error("WeCom _upload_media_ws error for {}: {}", file_path, e)
return None, None
async def send(self, msg: OutboundMessage) -> None:
"""Send a message through WeCom."""
if not self._client:
@ -343,29 +482,59 @@ class WecomChannel(BaseChannel):
return
try:
content = msg.content.strip()
if not content:
return
content = (msg.content or "").strip()
is_progress = bool(msg.metadata.get("_progress"))
# Get the stored frame for this chat
frame = self._chat_frames.get(msg.chat_id)
if not frame:
logger.warning("No frame found for chat {}, cannot reply", msg.chat_id)
# Send media files via WebSocket upload
for file_path in msg.media or []:
if not os.path.isfile(file_path):
logger.warning("WeCom media file not found: {}", file_path)
continue
media_id, media_type = await self._upload_media_ws(self._client, file_path)
if media_id:
if frame:
await self._client.reply(frame, {
"msgtype": media_type,
media_type: {"media_id": media_id},
})
else:
await self._client.send_message(msg.chat_id, {
"msgtype": media_type,
media_type: {"media_id": media_id},
})
logger.debug("WeCom sent {}{}", media_type, msg.chat_id)
else:
content += f"\n[file upload failed: {os.path.basename(file_path)}]"
if not content:
return
# Use streaming reply for better UX
stream_id = self._generate_req_id("stream")
if frame:
# Both progress and final messages must use reply_stream (cmd="aibot_respond_msg").
# The plain reply() uses cmd="reply" which does not support "text" msgtype
# and causes errcode=40008 from WeCom API.
stream_id = self._generate_req_id("stream")
await self._client.reply_stream(
frame,
stream_id,
content,
finish=not is_progress,
)
logger.debug(
"WeCom {} sent to {}",
"progress" if is_progress else "message",
msg.chat_id,
)
else:
# No frame (e.g. cron push): proactive send only supports markdown
await self._client.send_message(msg.chat_id, {
"msgtype": "markdown",
"markdown": {"content": content},
})
logger.info("WeCom proactive send to {}", msg.chat_id)
# Send as streaming message with finish=True
await self._client.reply_stream(
frame,
stream_id,
content,
finish=True,
)
logger.debug("WeCom message sent to {}", msg.chat_id)
except Exception as e:
logger.error("Error sending WeCom message: {}", e)
raise
except Exception:
logger.exception("Error sending WeCom message to chat_id={}", msg.chat_id)

View File

@ -985,7 +985,43 @@ class WeixinChannel(BaseChannel):
for media_path in (msg.media or []):
try:
await self._send_media_file(msg.chat_id, media_path, ctx_token)
except (httpx.TimeoutException, httpx.TransportError) as net_err:
# Network/transport errors: do NOT fall back to text —
# the text send would also likely fail, and the outer
# except will re-raise so ChannelManager retries properly.
logger.error(
"Network error sending WeChat media {}: {}",
media_path,
net_err,
)
raise
except httpx.HTTPStatusError as http_err:
status_code = (
http_err.response.status_code
if http_err.response is not None
else 0
)
if status_code >= 500:
# Server-side / retryable HTTP error — same as network.
logger.error(
"Server error ({} {}) sending WeChat media {}: {}",
status_code,
http_err.response.reason_phrase
if http_err.response is not None
else "",
media_path,
http_err,
)
raise
# 4xx client errors are NOT retryable — fall back to text.
filename = Path(media_path).name
logger.error("Failed to send WeChat media {}: {}", media_path, http_err)
await self._send_text(
msg.chat_id, f"[Failed to send: {filename}]", ctx_token,
)
except Exception as e:
# Non-network errors (format, file-not-found, etc.):
# notify the user via text fallback.
filename = Path(media_path).name
logger.error("Failed to send WeChat media {}: {}", media_path, e)
# Notify user about failure via text

View File

@ -590,6 +590,9 @@ def serve(
mcp_servers=runtime_config.tools.mcp_servers,
channels_config=runtime_config.channels,
timezone=runtime_config.agents.defaults.timezone,
unified_session=runtime_config.agents.defaults.unified_session,
disabled_skills=runtime_config.agents.defaults.disabled_skills,
session_ttl_minutes=runtime_config.agents.defaults.session_ttl_minutes,
)
model_name = runtime_config.agents.defaults.model
@ -681,6 +684,9 @@ def gateway(
mcp_servers=config.tools.mcp_servers,
channels_config=config.channels,
timezone=config.agents.defaults.timezone,
unified_session=config.agents.defaults.unified_session,
disabled_skills=config.agents.defaults.disabled_skills,
session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
)
# Set cron callback (needs agent)
@ -815,6 +821,48 @@ def gateway(
console.print(f"[green]✓[/green] Heartbeat: every {hb_cfg.interval_s}s")
async def _health_server(host: str, health_port: int):
"""Lightweight HTTP health endpoint on the gateway port."""
import json as _json
async def handle(reader, writer):
try:
data = await asyncio.wait_for(reader.read(4096), timeout=5)
except (asyncio.TimeoutError, ConnectionError):
writer.close()
return
request_line = data.split(b"\r\n", 1)[0].decode("utf-8", errors="replace")
method, path = "", ""
parts = request_line.split(" ")
if len(parts) >= 2:
method, path = parts[0], parts[1]
if method == "GET" and path == "/health":
body = _json.dumps({"status": "ok"})
resp = (
f"HTTP/1.0 200 OK\r\n"
f"Content-Type: application/json\r\n"
f"Content-Length: {len(body)}\r\n"
f"\r\n{body}"
)
else:
body = "Not Found"
resp = (
f"HTTP/1.0 404 Not Found\r\n"
f"Content-Type: text/plain\r\n"
f"Content-Length: {len(body)}\r\n"
f"\r\n{body}"
)
writer.write(resp.encode())
await writer.drain()
writer.close()
server = await asyncio.start_server(handle, host, health_port)
console.print(f"[green]✓[/green] Health endpoint: http://{host}:{health_port}/health")
async with server:
await server.serve_forever()
# Register Dream system job (always-on, idempotent on restart)
dream_cfg = config.agents.defaults.dream
if dream_cfg.model_override:
@ -837,6 +885,7 @@ def gateway(
await asyncio.gather(
agent.run(),
channels.start_all(),
_health_server(config.gateway.host, port),
)
except KeyboardInterrupt:
console.print("\nShutting down...")
@ -912,6 +961,9 @@ def agent(
mcp_servers=config.tools.mcp_servers,
channels_config=config.channels,
timezone=config.agents.defaults.timezone,
unified_session=config.agents.defaults.unified_session,
disabled_skills=config.agents.defaults.disabled_skills,
session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
)
restart_notice = consume_restart_notice_from_env()
if restart_notice and should_show_cli_restart_notice(restart_notice, session_id):
@ -1116,7 +1168,7 @@ def channels_status(
table = Table(title="Channel Status")
table.add_column("Channel", style="cyan")
table.add_column("Enabled", style="green")
table.add_column("Enabled")
for name, cls in sorted(discover_all().items()):
section = getattr(config.channels, name, None)
@ -1251,7 +1303,7 @@ def plugins_list():
table = Table(title="Channel Plugins")
table.add_column("Name", style="cyan")
table.add_column("Source", style="magenta")
table.add_column("Enabled", style="green")
table.add_column("Enabled")
for name in sorted(all_channels):
cls = all_channels[name]

View File

@ -76,6 +76,14 @@ class AgentDefaults(Base):
provider_retry_mode: Literal["standard", "persistent"] = "standard"
reasoning_effort: str | None = None # low / medium / high / adaptive - enables LLM thinking mode
timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"
unified_session: bool = False # Share one session across all channels (single-user multi-device)
disabled_skills: list[str] = Field(default_factory=list) # Skill names to exclude from loading (e.g. ["summarize", "skill-creator"])
session_ttl_minutes: int = Field(
default=0,
ge=0,
validation_alias=AliasChoices("idleCompactAfterMinutes", "sessionTtlMinutes"),
serialization_alias="idleCompactAfterMinutes",
) # Auto-compact idle threshold in minutes (0 = disabled)
dream: DreamConfig = Field(default_factory=DreamConfig)
@ -144,7 +152,7 @@ class ApiConfig(Base):
class GatewayConfig(Base):
"""Gateway/server configuration."""
host: str = "0.0.0.0"
host: str = "127.0.0.1" # Safer default: local-only bind.
port: int = 18790
heartbeat: HeartbeatConfig = Field(default_factory=HeartbeatConfig)
@ -152,7 +160,7 @@ class GatewayConfig(Base):
class WebSearchConfig(Base):
"""Web search tool configuration."""
provider: str = "duckduckgo" # brave, tavily, duckduckgo, searxng, jina
provider: str = "duckduckgo" # brave, tavily, duckduckgo, searxng, jina, kagi
api_key: str = ""
base_url: str = "" # SearXNG base URL
max_results: int = 5
@ -176,6 +184,7 @@ class ExecToolConfig(Base):
timeout: int = 60
path_append: str = ""
sandbox: str = "" # sandbox backend: "" (none) or "bwrap"
allowed_env_keys: list[str] = Field(default_factory=list) # Env var names to pass through to subprocess (e.g. ["GOPATH", "JAVA_HOME"])
class MCPServerConfig(Base):
"""MCP server connection configuration (stdio or HTTP)."""

View File

@ -4,10 +4,12 @@ import asyncio
import json
import time
import uuid
from dataclasses import asdict
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Coroutine, Literal
from filelock import FileLock
from loguru import logger
from nanobot.cron.types import CronJob, CronJobState, CronPayload, CronRunRecord, CronSchedule, CronStore
@ -69,28 +71,26 @@ class CronService:
self,
store_path: Path,
on_job: Callable[[CronJob], Coroutine[Any, Any, str | None]] | None = None,
max_sleep_ms: int = 300_000, # 5 minutes
):
self.store_path = store_path
self._action_path = store_path.parent / "action.jsonl"
self._lock = FileLock(str(self._action_path.parent) + ".lock")
self.on_job = on_job
self._store: CronStore | None = None
self._last_mtime: float = 0.0
self._timer_task: asyncio.Task | None = None
self._running = False
self._timer_active = False
self.max_sleep_ms = max_sleep_ms
def _load_store(self) -> CronStore:
"""Load jobs from disk. Reloads automatically if file was modified externally."""
if self._store and self.store_path.exists():
mtime = self.store_path.stat().st_mtime
if mtime != self._last_mtime:
logger.info("Cron: jobs.json modified externally, reloading")
self._store = None
if self._store:
return self._store
def _load_jobs(self) -> tuple[list[CronJob], int]:
jobs = []
version = 1
if self.store_path.exists():
try:
data = json.loads(self.store_path.read_text(encoding="utf-8"))
jobs = []
version = data.get("version", 1)
for j in data.get("jobs", []):
jobs.append(CronJob(
id=j["id"],
@ -129,13 +129,57 @@ class CronService:
updated_at_ms=j.get("updatedAtMs", 0),
delete_after_run=j.get("deleteAfterRun", False),
))
self._store = CronStore(jobs=jobs)
self._last_mtime = self.store_path.stat().st_mtime
except Exception as e:
logger.warning("Failed to load cron store: {}", e)
self._store = CronStore()
else:
self._store = CronStore()
return jobs, version
def _merge_action(self):
if not self._action_path.exists():
return
jobs_map = {j.id: j for j in self._store.jobs}
def _update(params: dict):
j = CronJob.from_dict(params)
jobs_map[j.id] = j
def _del(params: dict):
if job_id := params.get("job_id"):
jobs_map.pop(job_id)
with self._lock:
with open(self._action_path, "r", encoding="utf-8") as f:
changed = False
for line in f:
try:
line = line.strip()
action = json.loads(line)
if "action" not in action:
continue
if action["action"] == "del":
_del(action.get("params", {}))
else:
_update(action.get("params", {}))
changed = True
except Exception as exp:
logger.debug(f"load action line error: {exp}")
continue
self._store.jobs = list(jobs_map.values())
if self._running and changed:
self._action_path.write_text("", encoding="utf-8")
self._save_store()
return
def _load_store(self) -> CronStore:
"""Load jobs from disk. Reloads automatically if file was modified externally.
- Reload every time because it needs to merge operations on the jobs object from other instances.
- During _on_timer execution, return the existing store to prevent concurrent
_load_store calls (e.g. from list_jobs polling) from replacing it mid-execution.
"""
if self._timer_active and self._store:
return self._store
jobs, version = self._load_jobs()
self._store = CronStore(version=version, jobs=jobs)
self._merge_action()
return self._store
@ -230,11 +274,14 @@ class CronService:
if self._timer_task:
self._timer_task.cancel()
next_wake = self._get_next_wake_ms()
if not next_wake or not self._running:
if not self._running:
return
delay_ms = max(0, next_wake - _now_ms())
next_wake = self._get_next_wake_ms()
if next_wake is None:
delay_ms = self.max_sleep_ms
else:
delay_ms = min(self.max_sleep_ms, max(0, next_wake - _now_ms()))
delay_s = delay_ms / 1000
async def tick():
@ -248,18 +295,23 @@ class CronService:
"""Handle timer tick - run due jobs."""
self._load_store()
if not self._store:
self._arm_timer()
return
now = _now_ms()
due_jobs = [
j for j in self._store.jobs
if j.enabled and j.state.next_run_at_ms and now >= j.state.next_run_at_ms
]
self._timer_active = True
try:
now = _now_ms()
due_jobs = [
j for j in self._store.jobs
if j.enabled and j.state.next_run_at_ms and now >= j.state.next_run_at_ms
]
for job in due_jobs:
await self._execute_job(job)
for job in due_jobs:
await self._execute_job(job)
self._save_store()
self._save_store()
finally:
self._timer_active = False
self._arm_timer()
async def _execute_job(self, job: CronJob) -> None:
@ -303,6 +355,13 @@ class CronService:
# Compute next run
job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
def _append_action(self, action: Literal["add", "del", "update"], params: dict):
self.store_path.parent.mkdir(parents=True, exist_ok=True)
with self._lock:
with open(self._action_path, "a", encoding="utf-8") as f:
f.write(json.dumps({"action": action, "params": params}, ensure_ascii=False) + "\n")
# ========== Public API ==========
def list_jobs(self, include_disabled: bool = False) -> list[CronJob]:
@ -322,7 +381,6 @@ class CronService:
delete_after_run: bool = False,
) -> CronJob:
"""Add a new job."""
store = self._load_store()
_validate_schedule_for_add(schedule)
now = _now_ms()
@ -343,10 +401,13 @@ class CronService:
updated_at_ms=now,
delete_after_run=delete_after_run,
)
store.jobs.append(job)
self._save_store()
self._arm_timer()
if self._running:
store = self._load_store()
store.jobs.append(job)
self._save_store()
self._arm_timer()
else:
self._append_action("add", asdict(job))
logger.info("Cron: added job '{}' ({})", name, job.id)
return job
@ -380,8 +441,11 @@ class CronService:
removed = len(store.jobs) < before
if removed:
self._save_store()
self._arm_timer()
if self._running:
self._save_store()
self._arm_timer()
else:
self._append_action("del", {"job_id": job_id})
logger.info("Cron: removed job {}", job_id)
return "removed"
@ -398,23 +462,85 @@ class CronService:
job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
else:
job.state.next_run_at_ms = None
self._save_store()
self._arm_timer()
if self._running:
self._save_store()
self._arm_timer()
else:
self._append_action("update", asdict(job))
return job
return None
async def run_job(self, job_id: str, force: bool = False) -> bool:
"""Manually run a job."""
def update_job(
self,
job_id: str,
*,
name: str | None = None,
schedule: CronSchedule | None = None,
message: str | None = None,
deliver: bool | None = None,
channel: str | None = ...,
to: str | None = ...,
delete_after_run: bool | None = None,
) -> CronJob | Literal["not_found", "protected"]:
"""Update mutable fields of an existing job. System jobs cannot be updated.
For ``channel`` and ``to``, pass an explicit value (including ``None``)
to update; omit (sentinel ``...``) to leave unchanged.
"""
store = self._load_store()
for job in store.jobs:
if job.id == job_id:
if not force and not job.enabled:
return False
await self._execute_job(job)
self._save_store()
job = next((j for j in store.jobs if j.id == job_id), None)
if job is None:
return "not_found"
if job.payload.kind == "system_event":
return "protected"
if schedule is not None:
_validate_schedule_for_add(schedule)
job.schedule = schedule
if name is not None:
job.name = name
if message is not None:
job.payload.message = message
if deliver is not None:
job.payload.deliver = deliver
if channel is not ...:
job.payload.channel = channel
if to is not ...:
job.payload.to = to
if delete_after_run is not None:
job.delete_after_run = delete_after_run
job.updated_at_ms = _now_ms()
if job.enabled:
job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
if self._running:
self._save_store()
self._arm_timer()
else:
self._append_action("update", asdict(job))
logger.info("Cron: updated job '{}' ({})", job.name, job.id)
return job
async def run_job(self, job_id: str, force: bool = False) -> bool:
"""Manually run a job without disturbing the service's running state."""
was_running = self._running
self._running = True
try:
store = self._load_store()
for job in store.jobs:
if job.id == job_id:
if not force and not job.enabled:
return False
await self._execute_job(job)
self._save_store()
return True
return False
finally:
self._running = was_running
if was_running:
self._arm_timer()
return True
return False
def get_job(self, job_id: str) -> CronJob | None:
"""Get a job by ID."""

View File

@ -61,6 +61,18 @@ class CronJob:
updated_at_ms: int = 0
delete_after_run: bool = False
@classmethod
def from_dict(cls, kwargs: dict):
state_kwargs = dict(kwargs.get("state", {}))
state_kwargs["run_history"] = [
record if isinstance(record, CronRunRecord) else CronRunRecord(**record)
for record in state_kwargs.get("run_history", [])
]
kwargs["schedule"] = CronSchedule(**kwargs.get("schedule", {"kind": "every"}))
kwargs["payload"] = CronPayload(**kwargs.get("payload", {}))
kwargs["state"] = CronJobState(**state_kwargs)
return cls(**kwargs)
@dataclass
class CronStore:

View File

@ -81,6 +81,9 @@ class Nanobot:
restrict_to_workspace=config.tools.restrict_to_workspace,
mcp_servers=config.tools.mcp_servers,
timezone=defaults.timezone,
unified_session=defaults.unified_session,
disabled_skills=defaults.disabled_skills,
session_ttl_minutes=defaults.session_ttl_minutes,
)
return cls(loop)

View File

@ -353,6 +353,64 @@ class LLMProvider(ABC):
# Unknown 429 defaults to WAIT+retry.
return True
@staticmethod
def _enforce_role_alternation(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Merge consecutive same-role messages and drop trailing assistant messages.
Some providers (OpenAI-compat, Azure, vLLM, Ollama, etc.) reject requests
where the last message is 'assistant' (prefill not supported) or two
consecutive non-system messages share the same role.
"""
if not messages:
return messages
merged: list[dict[str, Any]] = []
for msg in messages:
role = msg.get("role")
if (
merged
and role != "system"
and role not in ("tool",)
and merged[-1].get("role") == role
and role in ("user", "assistant")
):
prev = merged[-1]
if role == "assistant":
prev_has_tools = bool(prev.get("tool_calls"))
curr_has_tools = bool(msg.get("tool_calls"))
if curr_has_tools:
merged[-1] = dict(msg)
continue
if prev_has_tools:
continue
prev_content = prev.get("content") or ""
curr_content = msg.get("content") or ""
if isinstance(prev_content, str) and isinstance(curr_content, str):
prev["content"] = (prev_content + "\n\n" + curr_content).strip()
else:
merged[-1] = dict(msg)
else:
merged.append(dict(msg))
last_popped = None
while merged and merged[-1].get("role") == "assistant":
last_popped = merged.pop()
# If removing trailing assistant messages left only system messages,
# the request would be invalid for most providers (e.g. Zhipu/GLM
# error 1214). Recover by converting the last popped assistant
# message to a user message so the LLM can still see the content.
if (
merged
and last_popped is not None
and not any(m.get("role") in ("user", "tool") for m in merged)
):
recovered = dict(last_popped)
recovered["role"] = "user"
merged.append(recovered)
return merged
@staticmethod
def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
"""Replace image_url blocks with text placeholder. Returns None if no images found."""
@ -375,6 +433,26 @@ class LLMProvider(ABC):
result.append(msg)
return result if found else None
@staticmethod
def _strip_image_content_inplace(messages: list[dict[str, Any]]) -> bool:
"""Replace image_url blocks with text placeholder *in-place*.
Mutates the content lists of the original message dicts so that
callers holding references to those dicts also see the stripped
version.
"""
found = False
for msg in messages:
content = msg.get("content")
if isinstance(content, list):
for i, b in enumerate(content):
if isinstance(b, dict) and b.get("type") == "image_url":
path = (b.get("_meta") or {}).get("path", "")
placeholder = image_placeholder_text(path, empty="[image omitted]")
content[i] = {"type": "text", "text": placeholder}
found = True
return found
async def _safe_chat(self, **kwargs: Any) -> LLMResponse:
"""Call chat() and convert unexpected exceptions to error responses."""
try:
@ -626,7 +704,12 @@ class LLMProvider(ABC):
)
retry_kw = dict(kw)
retry_kw["messages"] = stripped
return await call(**retry_kw)
result = await call(**retry_kw)
# Permanently strip images from the original messages so
# subsequent iterations do not repeat the error-retry cycle.
if result.finish_reason != "error":
self._strip_image_content_inplace(original_messages)
return result
return response
if persistent and identical_error_count >= self._PERSISTENT_IDENTICAL_ERROR_LIMIT:

View File

@ -26,6 +26,12 @@ else:
from openai import AsyncOpenAI
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.openai_responses import (
consume_sdk_stream,
convert_messages,
convert_tools,
parse_response_output,
)
if TYPE_CHECKING:
from nanobot.providers.registry import ProviderSpec
@ -113,6 +119,14 @@ def _uses_openrouter_attribution(spec: "ProviderSpec | None", api_base: str | No
return bool(api_base and "openrouter" in api_base.lower())
def _is_direct_openai_base(api_base: str | None) -> bool:
"""Return True for direct OpenAI endpoints, not generic OpenAI-compatible gateways."""
if not api_base:
return True
normalized = api_base.strip().lower().rstrip("/")
return "api.openai.com" in normalized and "openrouter" not in normalized
class OpenAICompatProvider(LLMProvider):
"""Unified provider for all OpenAI-compatible APIs.
@ -137,6 +151,7 @@ class OpenAICompatProvider(LLMProvider):
self._setup_env(api_key, api_base)
effective_base = api_base or (spec.default_api_base if spec else None) or None
self._effective_base = effective_base
default_headers = {"x-session-affinity": uuid.uuid4().hex}
if _uses_openrouter_attribution(spec, effective_base):
default_headers.update(_DEFAULT_OPENROUTER_HEADERS)
@ -228,9 +243,13 @@ class OpenAICompatProvider(LLMProvider):
tc_clean["id"] = map_id(tc_clean.get("id"))
normalized.append(tc_clean)
clean["tool_calls"] = normalized
if clean.get("role") == "assistant":
# Some OpenAI-compatible gateways reject assistant messages
# that mix non-empty content with tool_calls.
clean["content"] = None
if "tool_call_id" in clean and clean["tool_call_id"]:
clean["tool_call_id"] = map_id(clean["tool_call_id"])
return sanitized
return self._enforce_role_alternation(sanitized)
# ------------------------------------------------------------------
# Build kwargs
@ -321,6 +340,88 @@ class OpenAICompatProvider(LLMProvider):
return kwargs
def _should_use_responses_api(
self,
model: str | None,
reasoning_effort: str | None,
) -> bool:
"""Use Responses API only for direct OpenAI requests that benefit from it."""
if self._spec and self._spec.name != "openai":
return False
if not _is_direct_openai_base(self._effective_base):
return False
model_name = (model or self.default_model).lower()
if reasoning_effort and reasoning_effort.lower() != "none":
return True
return any(token in model_name for token in ("gpt-5", "o1", "o3", "o4"))
@staticmethod
def _should_fallback_from_responses_error(e: Exception) -> bool:
"""Fallback only for likely Responses API compatibility errors."""
response = getattr(e, "response", None)
status_code = getattr(e, "status_code", None)
if status_code is None and response is not None:
status_code = getattr(response, "status_code", None)
if status_code not in {400, 404, 422}:
return False
body = (
getattr(e, "body", None)
or getattr(e, "doc", None)
or getattr(response, "text", None)
)
body_text = str(body).lower() if body is not None else ""
compatibility_markers = (
"responses",
"response api",
"max_output_tokens",
"instructions",
"previous_response",
"unsupported",
"not supported",
"unknown parameter",
"unrecognized request argument",
)
return any(marker in body_text for marker in compatibility_markers)
def _build_responses_body(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None,
model: str | None,
max_tokens: int,
temperature: float,
reasoning_effort: str | None,
tool_choice: str | dict[str, Any] | None,
) -> dict[str, Any]:
"""Build a Responses API body for direct OpenAI requests."""
model_name = model or self.default_model
sanitized_messages = self._sanitize_messages(self._sanitize_empty_content(messages))
instructions, input_items = convert_messages(sanitized_messages)
body: dict[str, Any] = {
"model": model_name,
"instructions": instructions or None,
"input": input_items,
"max_output_tokens": max(1, max_tokens),
"store": False,
"stream": False,
}
if self._supports_temperature(model_name, reasoning_effort):
body["temperature"] = temperature
if reasoning_effort and reasoning_effort.lower() != "none":
body["reasoning"] = {"effort": reasoning_effort}
body["include"] = ["reasoning.encrypted_content"]
if tools:
body["tools"] = convert_tools(tools)
body["tool_choice"] = tool_choice or "auto"
return body
# ------------------------------------------------------------------
# Response parsing
# ------------------------------------------------------------------
@ -698,7 +799,12 @@ class OpenAICompatProvider(LLMProvider):
}
@staticmethod
def _handle_error(e: Exception) -> LLMResponse:
def _handle_error(
e: Exception,
*,
spec: ProviderSpec | None = None,
api_base: str | None = None,
) -> LLMResponse:
body = (
getattr(e, "doc", None)
or getattr(e, "body", None)
@ -706,6 +812,15 @@ class OpenAICompatProvider(LLMProvider):
)
body_text = body if isinstance(body, str) else str(body) if body is not None else ""
msg = f"Error: {body_text.strip()[:500]}" if body_text.strip() else f"Error calling LLM: {e}"
text = f"{body_text} {e}".lower()
if spec and spec.is_local and ("502" in text or "connection" in text or "refused" in text):
msg += (
"\nHint: this is a local model endpoint. Check that the local server is reachable at "
f"{api_base or spec.default_api_base}, and if you are using a proxy/tunnel, make sure it "
"can reach your local Ollama/vLLM service instead of routing localhost through the remote host."
)
response = getattr(e, "response", None)
retry_after = LLMProvider._extract_retry_after_from_headers(getattr(response, "headers", None))
if retry_after is None:
@ -731,14 +846,25 @@ class OpenAICompatProvider(LLMProvider):
reasoning_effort: str | None = None,
tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
kwargs = self._build_kwargs(
messages, tools, model, max_tokens, temperature,
reasoning_effort, tool_choice,
)
try:
if self._should_use_responses_api(model, reasoning_effort):
try:
body = self._build_responses_body(
messages, tools, model, max_tokens, temperature,
reasoning_effort, tool_choice,
)
return parse_response_output(await self._client.responses.create(**body))
except Exception as responses_error:
if not self._should_fallback_from_responses_error(responses_error):
raise
kwargs = self._build_kwargs(
messages, tools, model, max_tokens, temperature,
reasoning_effort, tool_choice,
)
return self._parse(await self._client.chat.completions.create(**kwargs))
except Exception as e:
return self._handle_error(e)
return self._handle_error(e, spec=self._spec, api_base=self.api_base)
async def chat_stream(
self,
@ -751,14 +877,49 @@ class OpenAICompatProvider(LLMProvider):
tool_choice: str | dict[str, Any] | None = None,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
kwargs = self._build_kwargs(
messages, tools, model, max_tokens, temperature,
reasoning_effort, tool_choice,
)
kwargs["stream"] = True
kwargs["stream_options"] = {"include_usage": True}
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
try:
if self._should_use_responses_api(model, reasoning_effort):
try:
body = self._build_responses_body(
messages, tools, model, max_tokens, temperature,
reasoning_effort, tool_choice,
)
body["stream"] = True
stream = await self._client.responses.create(**body)
async def _timed_stream():
stream_iter = stream.__aiter__()
while True:
try:
yield await asyncio.wait_for(
stream_iter.__anext__(),
timeout=idle_timeout_s,
)
except StopAsyncIteration:
break
content, tool_calls, finish_reason, usage, reasoning_content = await consume_sdk_stream(
_timed_stream(),
on_content_delta,
)
return LLMResponse(
content=content or None,
tool_calls=tool_calls,
finish_reason=finish_reason,
usage=usage,
reasoning_content=reasoning_content,
)
except Exception as responses_error:
if not self._should_fallback_from_responses_error(responses_error):
raise
kwargs = self._build_kwargs(
messages, tools, model, max_tokens, temperature,
reasoning_effort, tool_choice,
)
kwargs["stream"] = True
kwargs["stream_options"] = {"include_usage": True}
stream = await self._client.chat.completions.create(**kwargs)
chunks: list[Any] = []
stream_iter = stream.__aiter__()
@ -786,7 +947,7 @@ class OpenAICompatProvider(LLMProvider):
error_kind="timeout",
)
except Exception as e:
return self._handle_error(e)
return self._handle_error(e, spec=self._spec, api_base=self.api_base)
def get_default_model(self) -> str:
return self.default_model

View File

@ -155,6 +155,7 @@ class SessionManager:
messages = []
metadata = {}
created_at = None
updated_at = None
last_consolidated = 0
with open(path, encoding="utf-8") as f:
@ -168,6 +169,7 @@ class SessionManager:
if data.get("_type") == "metadata":
metadata = data.get("metadata", {})
created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None
updated_at = datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else None
last_consolidated = data.get("last_consolidated", 0)
else:
messages.append(data)
@ -176,6 +178,7 @@ class SessionManager:
key=key,
messages=messages,
created_at=created_at or datetime.now(),
updated_at=updated_at or datetime.now(),
metadata=metadata,
last_consolidated=last_consolidated
)

View File

@ -3,6 +3,7 @@ Compare conversation history against current memory files. Also scan memory file
Output one line per finding:
[FILE] atomic fact (not already in memory)
[FILE-REMOVE] reason for removal
[SKILL] kebab-case-name: one-line description of the reusable pattern
Files: USER (identity, preferences), SOUL (bot behavior, tone), MEMORY (knowledge, project context)
@ -18,6 +19,12 @@ Staleness — flag for [FILE-REMOVE]:
- Detailed incident info after 14 days — reduce to one-line summary
- Superseded: approaches replaced by newer solutions, deprecated dependencies
Skill discovery — flag [SKILL] when ALL of these are true:
- A specific, repeatable workflow appeared 2+ times in the conversation history
- It involves clear steps (not vague preferences like "likes concise answers")
- It is substantial enough to warrant its own instruction set (not trivial like "read a file")
- Do not worry about duplicates — the next phase will check against existing skills
Do not add: current weather, transient status, temporary errors, conversational filler.
[SKIP] if nothing needs updating.

View File

@ -1,11 +1,13 @@
Update memory files based on the analysis below.
- [FILE] entries: add the described content to the appropriate file
- [FILE-REMOVE] entries: delete the corresponding content from memory files
- [SKILL] entries: create a new skill under skills/<name>/SKILL.md using write_file
## File paths (relative to workspace root)
- SOUL.md
- USER.md
- memory/MEMORY.md
- skills/<name>/SKILL.md (for [SKILL] entries only)
Do NOT guess paths.
@ -17,6 +19,17 @@ Do NOT guess paths.
- Surgical edits only — never rewrite entire files
- If nothing to update, stop without calling tools
## Skill creation rules (for [SKILL] entries)
- Use write_file to create skills/<name>/SKILL.md
- Before writing, read_file `{{ skill_creator_path }}` for format reference (frontmatter structure, naming conventions, quality standards)
- **Dedup check**: read existing skills listed below to verify the new skill is not functionally redundant. Skip creation if an existing skill already covers the same workflow.
- Include YAML frontmatter with name and description fields
- Keep SKILL.md under 2000 words — concise and actionable
- Include: when to use, steps, output format, at least one example
- Do NOT overwrite existing skills — skip if the skill directory already exists
- Reference specific tools the agent has access to (read_file, write_file, exec, web_search, etc.)
- Skills are instruction sets, not code — do not include implementation code
## Quality
- Every line must carry standalone value
- Concise bullets under clear headers

View File

@ -15,9 +15,12 @@ from loguru import logger
def strip_think(text: str) -> str:
"""Remove <think>…</think> blocks and any unclosed trailing <think> tag."""
"""Remove thinking blocks and any unclosed trailing tag."""
text = re.sub(r"<think>[\s\S]*?</think>", "", text)
text = re.sub(r"<think>[\s\S]*$", "", text)
text = re.sub(r"^\s*<think>[\s\S]*$", "", text)
# Gemma 4 and similar models use <thought>...</thought> blocks
text = re.sub(r"<thought>[\s\S]*?</thought>", "", text)
text = re.sub(r"^\s*<thought>[\s\S]*$", "", text)
return text.strip()
@ -272,7 +275,7 @@ def build_assistant_message(
thinking_blocks: list[dict] | None = None,
) -> dict[str, Any]:
"""Build a provider-safe assistant message with optional reasoning fields."""
msg: dict[str, Any] = {"role": "assistant", "content": content}
msg: dict[str, Any] = {"role": "assistant", "content": content or ""}
if tool_calls:
msg["tool_calls"] = tool_calls
if reasoning_content is not None or thinking_blocks:
@ -417,7 +420,7 @@ def build_status_content(
ctx_total = max(context_window_tokens, 0)
ctx_pct = int((context_tokens_estimate / ctx_total) * 100) if ctx_total > 0 else 0
ctx_used_str = f"{context_tokens_estimate // 1000}k" if context_tokens_estimate >= 1000 else str(context_tokens_estimate)
ctx_total_str = f"{ctx_total // 1024}k" if ctx_total > 0 else "n/a"
ctx_total_str = f"{ctx_total // 1000}k" if ctx_total > 0 else "n/a"
token_line = f"\U0001f4ca Tokens: {last_in} in / {last_out} out"
if cached and last_in:
token_line += f" ({cached * 100 // last_in}% cached)"

View File

@ -2,6 +2,8 @@
from __future__ import annotations
import re
from nanobot.utils.path import abbreviate_path
# Registry: tool_name -> (key_args, template, is_path, is_command)
@ -17,27 +19,39 @@ _TOOL_FORMATS: dict[str, tuple[list[str], str, bool, bool]] = {
"list_dir": (["path"], "ls {}", True, False),
}
# Matches file paths embedded in shell commands, including quoted paths with spaces.
_PATH_IN_CMD_RE = re.compile(
r'"(?P<double>(?:[A-Za-z]:[/\\]|~/|/)[^"]+)"'
r"|'(?P<single>(?:[A-Za-z]:[/\\]|~/|/)[^']+)'"
r"|(?P<bare>(?:[A-Za-z]:[/\\]|~/|(?<=\s)/)[^\s;&|<>\"']+)"
)
def format_tool_hints(tool_calls: list) -> str:
"""Format tool calls as concise hints with smart abbreviation."""
if not tool_calls:
return ""
hints = []
for name, count, example_tc in _group_consecutive(tool_calls):
fmt = _TOOL_FORMATS.get(name)
formatted = []
for tc in tool_calls:
fmt = _TOOL_FORMATS.get(tc.name)
if fmt:
hint = _fmt_known(example_tc, fmt)
elif name.startswith("mcp_"):
hint = _fmt_mcp(example_tc)
formatted.append(_fmt_known(tc, fmt))
elif tc.name.startswith("mcp_"):
formatted.append(_fmt_mcp(tc))
else:
hint = _fmt_fallback(example_tc)
formatted.append(_fmt_fallback(tc))
if count > 1:
hint = f"{hint} \u00d7 {count}"
hints.append(hint)
hints = []
for hint in formatted:
if hints and hints[-1][0] == hint:
hints[-1] = (hint, hints[-1][1] + 1)
else:
hints.append((hint, 1))
return ", ".join(hints)
return ", ".join(
f"{h} \u00d7 {c}" if c > 1 else h for h, c in hints
)
def _get_args(tc) -> dict:
@ -51,17 +65,6 @@ def _get_args(tc) -> dict:
return {}
def _group_consecutive(calls: list) -> list[tuple[str, int, object]]:
"""Group consecutive calls to the same tool: [(name, count, first), ...]."""
groups: list[tuple[str, int, object]] = []
for tc in calls:
if groups and groups[-1][0] == tc.name:
groups[-1] = (groups[-1][0], groups[-1][1] + 1, groups[-1][2])
else:
groups.append((tc.name, 1, tc))
return groups
def _extract_arg(tc, key_args: list[str]) -> str | None:
"""Extract the first available value from preferred key names."""
args = _get_args(tc)
@ -85,10 +88,25 @@ def _fmt_known(tc, fmt: tuple) -> str:
if fmt[2]: # is_path
val = abbreviate_path(val)
elif fmt[3]: # is_command
val = val[:40] + "\u2026" if len(val) > 40 else val
val = _abbreviate_command(val)
return fmt[1].format(val)
def _abbreviate_command(cmd: str, max_len: int = 40) -> str:
"""Abbreviate paths in a command string, then truncate."""
def _replace_path(match: re.Match[str]) -> str:
if match.group("double") is not None:
return f'"{abbreviate_path(match.group("double"), max_len=25)}"'
if match.group("single") is not None:
return f"'{abbreviate_path(match.group('single'), max_len=25)}'"
return abbreviate_path(match.group("bare"), max_len=25)
abbreviated = _PATH_IN_CMD_RE.sub(_replace_path, cmd)
if len(abbreviated) <= max_len:
return abbreviated
return abbreviated[:max_len - 1] + "\u2026"
def _fmt_mcp(tc) -> str:
"""Format MCP tool as server::tool."""
name = tc.name

View File

@ -54,6 +54,7 @@ dependencies = [
"python-docx>=1.1.0,<2.0.0",
"openpyxl>=3.1.0,<4.0.0",
"python-pptx>=1.0.0,<2.0.0",
"filelock>=3.25.2",
]
[project.optional-dependencies]
@ -79,12 +80,16 @@ discord = [
langsmith = [
"langsmith>=0.1.0",
]
pdf = [
"pymupdf>=1.25.0",
]
dev = [
"pytest>=9.0.0,<10.0.0",
"pytest-asyncio>=1.3.0,<2.0.0",
"aiohttp>=3.9.0,<4.0.0",
"pytest-cov>=6.0.0,<7.0.0",
"ruff>=0.1.0",
"pymupdf>=1.25.0",
]
[project.scripts]

File diff suppressed because it is too large Load Diff

View File

@ -46,7 +46,7 @@ class TestConsolidatorSummarize:
{"role": "assistant", "content": "Done, fixed the race condition."},
]
result = await consolidator.archive(messages)
assert result is True
assert result == "User fixed a bug in the auth module."
entries = store.read_unprocessed_history(since_cursor=0)
assert len(entries) == 1
@ -55,14 +55,14 @@ class TestConsolidatorSummarize:
mock_provider.chat_with_retry.side_effect = Exception("API error")
messages = [{"role": "user", "content": "hello"}]
result = await consolidator.archive(messages)
assert result is True # always succeeds
assert result is None # no summary on raw dump fallback
entries = store.read_unprocessed_history(since_cursor=0)
assert len(entries) == 1
assert "[RAW]" in entries[0]["content"]
async def test_summarize_skips_empty_messages(self, consolidator):
result = await consolidator.archive([])
assert result is False
assert result is None
class TestConsolidatorTokenBudget:
@ -76,3 +76,52 @@ class TestConsolidatorTokenBudget:
consolidator.archive = AsyncMock(return_value=True)
await consolidator.maybe_consolidate_by_tokens(session)
consolidator.archive.assert_not_called()
async def test_chunk_cap_preserves_user_turn_boundary(self, consolidator):
"""Chunk cap should rewind to the last user boundary within the cap."""
consolidator._SAFETY_BUFFER = 0
session = MagicMock()
session.last_consolidated = 0
session.key = "test:key"
session.messages = [
{
"role": "user" if i in {0, 50, 61} else "assistant",
"content": f"m{i}",
}
for i in range(70)
]
consolidator.estimate_session_prompt_tokens = MagicMock(
side_effect=[(1200, "tiktoken"), (400, "tiktoken")]
)
consolidator.pick_consolidation_boundary = MagicMock(return_value=(61, 999))
consolidator.archive = AsyncMock(return_value=True)
await consolidator.maybe_consolidate_by_tokens(session)
archived_chunk = consolidator.archive.await_args.args[0]
assert len(archived_chunk) == 50
assert archived_chunk[0]["content"] == "m0"
assert archived_chunk[-1]["content"] == "m49"
assert session.last_consolidated == 50
async def test_chunk_cap_skips_when_no_user_boundary_within_cap(self, consolidator):
"""If the cap would cut mid-turn, consolidation should skip that round."""
consolidator._SAFETY_BUFFER = 0
session = MagicMock()
session.last_consolidated = 0
session.key = "test:key"
session.messages = [
{
"role": "user" if i in {0, 61} else "assistant",
"content": f"m{i}",
}
for i in range(70)
]
consolidator.estimate_session_prompt_tokens = MagicMock(return_value=(1200, "tiktoken"))
consolidator.pick_consolidation_boundary = MagicMock(return_value=(61, 999))
consolidator.archive = AsyncMock(return_value=True)
await consolidator.maybe_consolidate_by_tokens(session)
consolidator.archive.assert_not_awaited()
assert session.last_consolidated == 0

View File

@ -6,6 +6,7 @@ from unittest.mock import AsyncMock, MagicMock
from nanobot.agent.memory import Dream, MemoryStore
from nanobot.agent.runner import AgentRunResult
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
@pytest.fixture
@ -95,3 +96,30 @@ class TestDreamRun:
entries = store.read_unprocessed_history(since_cursor=0)
assert all(e["cursor"] > 0 for e in entries)
async def test_skill_phase_uses_builtin_skill_creator_path(self, dream, mock_provider, mock_runner, store):
"""Dream should point skill creation guidance at the builtin skill-creator template."""
store.append_history("Repeated workflow one")
store.append_history("Repeated workflow two")
mock_provider.chat_with_retry.return_value = MagicMock(content="[SKILL] test-skill: test description")
mock_runner.run = AsyncMock(return_value=_make_run_result())
await dream.run()
spec = mock_runner.run.call_args[0][0]
system_prompt = spec.initial_messages[0]["content"]
expected = str(BUILTIN_SKILLS_DIR / "skill-creator" / "SKILL.md")
assert expected in system_prompt
async def test_skill_write_tool_accepts_workspace_relative_skill_path(self, dream, store):
"""Dream skill creation should allow skills/<name>/SKILL.md relative to workspace root."""
write_tool = dream._tools.get("write_file")
assert write_tool is not None
result = await write_tool.execute(
path="skills/test-skill/SKILL.md",
content="---\nname: test-skill\ndescription: Test\n---\n",
)
assert "Successfully wrote" in result
assert (store.workspace / "skills" / "test-skill" / "SKILL.md").exists()

View File

@ -184,17 +184,22 @@ def test_stale_extra_content_in_tool_calls_survives_sanitize() -> None:
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider()
messages = [{
"role": "assistant",
"content": None,
"tool_calls": [{
"id": "call_1",
"type": "function",
"function": {"name": "fn", "arguments": "{}"},
"extra_content": GEMINI_EXTRA,
}],
}]
messages = [
{"role": "user", "content": "hi"},
{
"role": "assistant",
"content": None,
"tool_calls": [{
"id": "call_1",
"type": "function",
"function": {"name": "fn", "arguments": "{}"},
"extra_content": GEMINI_EXTRA,
}],
},
{"role": "tool", "content": "ok", "tool_call_id": "call_1"},
{"role": "user", "content": "thanks"},
]
sanitized = provider._sanitize_messages(messages)
assert sanitized[0]["tool_calls"][0]["extra_content"] == GEMINI_EXTRA
assert sanitized[1]["tool_calls"][0]["extra_content"] == GEMINI_EXTRA

View File

@ -232,6 +232,35 @@ async def test_composite_empty_hooks_no_ops():
assert hook.finalize_content(ctx, "test") == "test"
@pytest.mark.asyncio
async def test_composite_supports_legacy_hook_init_without_super():
calls: list[str] = []
class LegacyHook(AgentHook):
def __init__(self, label: str) -> None:
self.label = label
async def before_iteration(self, context: AgentHookContext) -> None:
calls.append(self.label)
hook = CompositeHook([LegacyHook("legacy")])
await hook.before_iteration(_ctx())
assert calls == ["legacy"]
@pytest.mark.asyncio
async def test_composite_can_wrap_another_composite():
calls: list[str] = []
class Inner(AgentHook):
async def before_iteration(self, context: AgentHookContext) -> None:
calls.append("inner")
hook = CompositeHook([CompositeHook([Inner()])])
await hook.before_iteration(_ctx())
assert calls == ["inner"]
# ---------------------------------------------------------------------------
# Integration: AgentLoop with extra hooks
# ---------------------------------------------------------------------------
@ -278,7 +307,7 @@ async def test_agent_loop_extra_hook_receives_calls(tmp_path):
)
loop.tools.get_definitions = MagicMock(return_value=[])
content, tools_used, messages = await loop._run_agent_loop(
content, tools_used, messages, _, _ = await loop._run_agent_loop(
[{"role": "user", "content": "hi"}]
)
@ -302,7 +331,7 @@ async def test_agent_loop_extra_hook_error_isolation(tmp_path):
)
loop.tools.get_definitions = MagicMock(return_value=[])
content, _, _ = await loop._run_agent_loop(
content, _, _, _, _ = await loop._run_agent_loop(
[{"role": "user", "content": "hi"}]
)
@ -344,7 +373,7 @@ async def test_agent_loop_no_hooks_backward_compat(tmp_path):
loop.tools.execute = AsyncMock(return_value="ok")
loop.max_iterations = 2
content, tools_used, _ = await loop._run_agent_loop([])
content, tools_used, _, _, _ = await loop._run_agent_loop([])
assert content == (
"I reached the maximum number of tool call iterations (2) "
"without completing the task. You can try breaking the task into smaller steps."

View File

@ -1,5 +1,13 @@
import asyncio
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
from nanobot.agent.context import ContextBuilder
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.session.manager import Session
@ -11,6 +19,12 @@ def _mk_loop() -> AgentLoop:
return loop
def _make_full_loop(tmp_path: Path) -> AgentLoop:
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
return AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model")
def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None:
loop = _mk_loop()
session = Session(key="test:runtime-only")
@ -200,3 +214,206 @@ def test_restore_runtime_checkpoint_dedupes_overlapping_tail() -> None:
assert session.messages[0]["role"] == "assistant"
assert session.messages[1]["tool_call_id"] == "call_done"
assert session.messages[2]["tool_call_id"] == "call_pending"
@pytest.mark.asyncio
async def test_process_message_persists_user_message_before_turn_completes(tmp_path: Path) -> None:
loop = _make_full_loop(tmp_path)
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
loop._run_agent_loop = AsyncMock(side_effect=RuntimeError("boom")) # type: ignore[method-assign]
msg = InboundMessage(channel="feishu", sender_id="u1", chat_id="c1", content="persist me")
with pytest.raises(RuntimeError, match="boom"):
await loop._process_message(msg)
loop.sessions.invalidate("feishu:c1")
persisted = loop.sessions.get_or_create("feishu:c1")
assert [m["role"] for m in persisted.messages] == ["user"]
assert persisted.messages[0]["content"] == "persist me"
assert persisted.metadata.get(AgentLoop._PENDING_USER_TURN_KEY) is True
assert persisted.updated_at >= persisted.created_at
@pytest.mark.asyncio
async def test_process_message_does_not_duplicate_early_persisted_user_message(tmp_path: Path) -> None:
loop = _make_full_loop(tmp_path)
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
loop._run_agent_loop = AsyncMock(return_value=(
"done",
None,
[
{"role": "system", "content": "system"},
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "done"},
],
"stop",
False,
)) # type: ignore[method-assign]
result = await loop._process_message(
InboundMessage(channel="feishu", sender_id="u1", chat_id="c2", content="hello")
)
assert result is not None
assert result.content == "done"
session = loop.sessions.get_or_create("feishu:c2")
assert [
{k: v for k, v in m.items() if k in {"role", "content"}}
for m in session.messages
] == [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "done"},
]
assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata
@pytest.mark.asyncio
async def test_next_turn_after_crash_closes_pending_user_turn_before_new_input(tmp_path: Path) -> None:
loop = _make_full_loop(tmp_path)
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
loop.provider.chat_with_retry = AsyncMock(return_value=MagicMock()) # unused because _run_agent_loop is stubbed
session = loop.sessions.get_or_create("feishu:c3")
session.add_message("user", "old question")
session.metadata[AgentLoop._PENDING_USER_TURN_KEY] = True
loop.sessions.save(session)
loop._run_agent_loop = AsyncMock(return_value=(
"new answer",
None,
[
{"role": "system", "content": "system"},
{"role": "user", "content": "old question"},
{"role": "assistant", "content": "Error: Task interrupted before a response was generated."},
{"role": "user", "content": "new question"},
{"role": "assistant", "content": "new answer"},
],
"stop",
False,
)) # type: ignore[method-assign]
result = await loop._process_message(
InboundMessage(channel="feishu", sender_id="u1", chat_id="c3", content="new question")
)
assert result is not None
assert result.content == "new answer"
session = loop.sessions.get_or_create("feishu:c3")
assert [
{k: v for k, v in m.items() if k in {"role", "content"}}
for m in session.messages
] == [
{"role": "user", "content": "old question"},
{"role": "assistant", "content": "Error: Task interrupted before a response was generated."},
{"role": "user", "content": "new question"},
{"role": "assistant", "content": "new answer"},
]
assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata
@pytest.mark.asyncio
async def test_stop_preserves_runtime_checkpoint_for_next_turn(tmp_path: Path) -> None:
from nanobot.command.builtin import cmd_stop
from nanobot.command.router import CommandContext
loop = _make_full_loop(tmp_path)
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
checkpoint_saved = asyncio.Event()
async def interrupted_run_agent_loop(_initial_messages, *, session=None, **_kwargs):
assert session is not None
loop._set_runtime_checkpoint(
session,
{
"assistant_message": {
"role": "assistant",
"content": "working",
"tool_calls": [
{
"id": "call_done",
"type": "function",
"function": {"name": "read_file", "arguments": "{}"},
},
{
"id": "call_pending",
"type": "function",
"function": {"name": "exec", "arguments": "{}"},
},
],
},
"completed_tool_results": [
{
"role": "tool",
"tool_call_id": "call_done",
"name": "read_file",
"content": "ok",
}
],
"pending_tool_calls": [
{
"id": "call_pending",
"type": "function",
"function": {"name": "exec", "arguments": "{}"},
}
],
},
)
checkpoint_saved.set()
await asyncio.Event().wait()
loop._run_agent_loop = interrupted_run_agent_loop # type: ignore[method-assign]
first_msg = InboundMessage(channel="feishu", sender_id="u1", chat_id="c4", content="keep progress")
task = asyncio.create_task(loop._process_message(first_msg))
loop._active_tasks[first_msg.session_key] = [task]
await asyncio.wait_for(checkpoint_saved.wait(), timeout=1.0)
stop_msg = InboundMessage(channel="feishu", sender_id="u1", chat_id="c4", content="/stop")
stop_ctx = CommandContext(msg=stop_msg, session=None, key=stop_msg.session_key, raw="/stop", loop=loop)
stop_result = await cmd_stop(stop_ctx)
assert "Stopped 1 task" in stop_result.content
assert task.done()
loop.sessions.invalidate("feishu:c4")
interrupted = loop.sessions.get_or_create("feishu:c4")
assert interrupted.metadata.get(AgentLoop._PENDING_USER_TURN_KEY) is True
assert interrupted.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is not None
async def resumed_run_agent_loop(initial_messages, **_kwargs):
return (
"next answer",
None,
[*initial_messages, {"role": "assistant", "content": "next answer"}],
"stop",
False,
)
loop._run_agent_loop = resumed_run_agent_loop # type: ignore[method-assign]
result = await loop._process_message(
InboundMessage(channel="feishu", sender_id="u1", chat_id="c4", content="continue here")
)
assert result is not None
assert result.content == "next answer"
session = loop.sessions.get_or_create("feishu:c4")
assert [
{k: v for k, v in m.items() if k in {"role", "content", "tool_call_id", "name"}}
for m in session.messages
] == [
{"role": "user", "content": "keep progress"},
{"role": "assistant", "content": "working"},
{"role": "tool", "tool_call_id": "call_done", "name": "read_file", "content": "ok"},
{
"role": "tool",
"tool_call_id": "call_pending",
"name": "exec",
"content": "Error: Task interrupted before this tool finished.",
},
{"role": "user", "content": "continue here"},
{"role": "assistant", "content": "next answer"},
]
assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata
assert AgentLoop._RUNTIME_CHECKPOINT_KEY not in session.metadata

View File

@ -0,0 +1,44 @@
"""Tests for MCP connection lifecycle in AgentLoop."""
from __future__ import annotations
from unittest.mock import MagicMock
import pytest
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
def _make_loop(tmp_path, *, mcp_servers: dict | None = None) -> AgentLoop:
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
provider.generation.max_tokens = 4096
return AgentLoop(
bus=bus,
provider=provider,
workspace=tmp_path,
model="test-model",
mcp_servers=mcp_servers or {"test": object()},
)
@pytest.mark.asyncio
async def test_connect_mcp_retries_when_no_servers_connect(tmp_path, monkeypatch: pytest.MonkeyPatch):
loop = _make_loop(tmp_path)
attempts = 0
async def _fake_connect(_servers, _registry):
nonlocal attempts
attempts += 1
return {}
monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _fake_connect)
await loop._connect_mcp()
await loop._connect_mcp()
assert attempts == 2
assert loop._mcp_connected is False
assert loop._mcp_stacks == {}

File diff suppressed because it is too large Load Diff

View File

@ -250,3 +250,63 @@ def test_list_skills_openclaw_metadata_parsed_for_requirements(
assert entries == [
{"name": "openclaw_skill", "path": str(skill_path), "source": "workspace"},
]
def test_disabled_skills_excluded_from_list(tmp_path: Path) -> None:
workspace = tmp_path / "ws"
ws_skills = workspace / "skills"
ws_skills.mkdir(parents=True)
_write_skill(ws_skills, "alpha", body="# Alpha")
beta_path = _write_skill(ws_skills, "beta", body="# Beta")
builtin = tmp_path / "builtin"
builtin.mkdir()
loader = SkillsLoader(workspace, builtin_skills_dir=builtin, disabled_skills={"alpha"})
entries = loader.list_skills(filter_unavailable=False)
assert len(entries) == 1
assert entries[0]["name"] == "beta"
assert entries[0]["path"] == str(beta_path)
def test_disabled_skills_empty_set_no_effect(tmp_path: Path) -> None:
workspace = tmp_path / "ws"
ws_skills = workspace / "skills"
ws_skills.mkdir(parents=True)
_write_skill(ws_skills, "alpha", body="# Alpha")
_write_skill(ws_skills, "beta", body="# Beta")
builtin = tmp_path / "builtin"
builtin.mkdir()
loader = SkillsLoader(workspace, builtin_skills_dir=builtin, disabled_skills=set())
entries = loader.list_skills(filter_unavailable=False)
assert len(entries) == 2
def test_disabled_skills_excluded_from_build_skills_summary(tmp_path: Path) -> None:
workspace = tmp_path / "ws"
ws_skills = workspace / "skills"
ws_skills.mkdir(parents=True)
_write_skill(ws_skills, "alpha", body="# Alpha")
_write_skill(ws_skills, "beta", body="# Beta")
builtin = tmp_path / "builtin"
builtin.mkdir()
loader = SkillsLoader(workspace, builtin_skills_dir=builtin, disabled_skills={"alpha"})
summary = loader.build_skills_summary()
assert "alpha" not in summary
assert "beta" in summary
def test_disabled_skills_excluded_from_get_always_skills(tmp_path: Path) -> None:
workspace = tmp_path / "ws"
ws_skills = workspace / "skills"
ws_skills.mkdir(parents=True)
_write_skill(ws_skills, "alpha", metadata_json={"always": True}, body="# Alpha")
_write_skill(ws_skills, "beta", metadata_json={"always": True}, body="# Beta")
builtin = tmp_path / "builtin"
builtin.mkdir()
loader = SkillsLoader(workspace, builtin_skills_dir=builtin, disabled_skills={"alpha"})
always = loader.get_always_skills()
assert "alpha" not in always
assert "beta" in always

View File

@ -52,6 +52,53 @@ class TestToolHintKnownTools:
assert result.startswith("$ ")
assert len(result) <= 50 # reasonable limit
def test_exec_abbreviates_paths_in_command(self):
"""Windows paths in exec commands should be folded, not blindly truncated."""
cmd = "cd D:\\Documents\\GitHub\\nanobot\\.worktree\\tomain\\nanobot && git diff origin/main...pr-2706 --name-only 2>&1"
result = _hint([_tc("exec", {"command": cmd})])
assert "\u2026/" in result # path should be folded with …/
assert "worktree" not in result # middle segments should be collapsed
def test_exec_abbreviates_linux_paths(self):
"""Unix absolute paths in exec commands should be folded."""
cmd = "cd /home/user/projects/nanobot/.worktree/tomain && make build"
result = _hint([_tc("exec", {"command": cmd})])
assert "\u2026/" in result
assert "projects" not in result
def test_exec_abbreviates_home_paths(self):
"""~/ paths in exec commands should be folded."""
cmd = "cd ~/projects/nanobot/workspace && pytest tests/"
result = _hint([_tc("exec", {"command": cmd})])
assert "\u2026/" in result
def test_exec_abbreviates_quoted_linux_paths_with_spaces(self):
"""Quoted Unix paths with spaces should still be folded."""
cmd = 'cd "/home/user/My Documents/project" && pytest tests/'
result = _hint([_tc("exec", {"command": cmd})])
assert "\u2026/" in result
assert '"/home/user/My Documents/project"' not in result
assert '"' in result
def test_exec_abbreviates_quoted_windows_paths_with_spaces(self):
"""Quoted Windows paths with spaces should still be folded."""
cmd = 'cd "C:/Program Files/Git/project" && git status'
result = _hint([_tc("exec", {"command": cmd})])
assert "\u2026/" in result
assert '"C:/Program Files/Git/project"' not in result
assert '"' in result
def test_exec_short_command_unchanged(self):
result = _hint([_tc("exec", {"command": "npm install typescript"})])
assert result == "$ npm install typescript"
def test_exec_chained_commands_truncated_not_mid_path(self):
"""Long chained commands should truncate preserving abbreviated paths."""
cmd = "cd D:\\Documents\\GitHub\\project && npm run build && npm test"
result = _hint([_tc("exec", {"command": cmd})])
assert "\u2026/" in result # path folded
assert "npm" in result # chained command still visible
def test_web_search(self):
result = _hint([_tc("web_search", {"query": "Claude 4 vs GPT-4"})])
assert result == 'search "Claude 4 vs GPT-4"'
@ -105,22 +152,30 @@ class TestToolHintFolding:
result = _hint(calls)
assert "\u00d7" not in result
def test_two_consecutive_same_folded(self):
def test_two_consecutive_different_args_not_folded(self):
calls = [
_tc("grep", {"pattern": "*.py"}),
_tc("grep", {"pattern": "*.ts"}),
]
result = _hint(calls)
assert "\u00d7" not in result
def test_two_consecutive_same_args_folded(self):
calls = [
_tc("grep", {"pattern": "TODO"}),
_tc("grep", {"pattern": "TODO"}),
]
result = _hint(calls)
assert "\u00d7 2" in result
def test_three_consecutive_same_folded(self):
def test_three_consecutive_different_args_not_folded(self):
calls = [
_tc("read_file", {"path": "a.py"}),
_tc("read_file", {"path": "b.py"}),
_tc("read_file", {"path": "c.py"}),
]
result = _hint(calls)
assert "\u00d7 3" in result
assert "\u00d7" not in result
def test_different_tools_not_folded(self):
calls = [
@ -187,7 +242,7 @@ class TestToolHintMixedFolding:
"""G4: Mixed folding groups with interleaved same-tool segments."""
def test_read_read_grep_grep_read(self):
"""read×2, grep×2, read — should produce two separate groups."""
"""All different args — each hint listed separately."""
calls = [
_tc("read_file", {"path": "a.py"}),
_tc("read_file", {"path": "b.py"}),
@ -196,7 +251,6 @@ class TestToolHintMixedFolding:
_tc("read_file", {"path": "c.py"}),
]
result = _hint(calls)
assert "\u00d7 2" in result
# Should have 3 groups: read×2, grep×2, read
assert "\u00d7" not in result
parts = result.split(", ")
assert len(parts) == 3
assert len(parts) == 5

View File

@ -0,0 +1,502 @@
"""Tests for unified_session feature.
Covers:
- AgentLoop._dispatch() rewrites session_key to "unified:default" when enabled
- Existing session_key_override is respected (not overwritten)
- Feature is off by default (no behavior change for existing users)
- Config schema serialises unified_session as camelCase "unifiedSession"
- onboard-generated config.json contains "unifiedSession" key
- /new command correctly clears the shared session in unified mode
- /new is NOT a priority command (goes through _dispatch, key rewrite applies)
- Context window consolidation is unaffected by unified_session
"""
import asyncio
import json
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.command.builtin import cmd_new, register_builtin_commands
from nanobot.command.router import CommandContext, CommandRouter
from nanobot.config.schema import AgentDefaults, Config
from nanobot.session.manager import Session, SessionManager
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_loop(tmp_path: Path, unified_session: bool = False) -> AgentLoop:
"""Create a minimal AgentLoop for dispatch-level tests."""
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
with patch("nanobot.agent.loop.SessionManager"), \
patch("nanobot.agent.loop.SubagentManager") as MockSubMgr, \
patch("nanobot.agent.loop.Dream"):
MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0)
loop = AgentLoop(
bus=bus,
provider=provider,
workspace=tmp_path,
unified_session=unified_session,
)
return loop
def _make_msg(channel: str = "telegram", chat_id: str = "111",
session_key_override: str | None = None) -> InboundMessage:
return InboundMessage(
channel=channel,
chat_id=chat_id,
sender_id="user1",
content="hello",
session_key_override=session_key_override,
)
# ---------------------------------------------------------------------------
# TestUnifiedSessionDispatch — core behaviour
# ---------------------------------------------------------------------------
class TestUnifiedSessionDispatch:
"""AgentLoop._dispatch() session key rewriting logic."""
@pytest.mark.asyncio
async def test_unified_session_rewrites_key_to_unified_default(self, tmp_path: Path):
"""When unified_session=True, all messages use 'unified:default' as session key."""
loop = _make_loop(tmp_path, unified_session=True)
captured: list[str] = []
async def fake_process(msg, **kwargs):
captured.append(msg.session_key)
return None
loop._process_message = fake_process # type: ignore[method-assign]
msg = _make_msg(channel="telegram", chat_id="111")
await loop._dispatch(msg)
assert captured == ["unified:default"]
@pytest.mark.asyncio
async def test_unified_session_different_channels_share_same_key(self, tmp_path: Path):
"""Messages from different channels all resolve to the same session key."""
loop = _make_loop(tmp_path, unified_session=True)
captured: list[str] = []
async def fake_process(msg, **kwargs):
captured.append(msg.session_key)
return None
loop._process_message = fake_process # type: ignore[method-assign]
await loop._dispatch(_make_msg(channel="telegram", chat_id="111"))
await loop._dispatch(_make_msg(channel="discord", chat_id="222"))
await loop._dispatch(_make_msg(channel="cli", chat_id="direct"))
assert captured == ["unified:default", "unified:default", "unified:default"]
@pytest.mark.asyncio
async def test_unified_session_disabled_preserves_original_key(self, tmp_path: Path):
"""When unified_session=False (default), session key is channel:chat_id as usual."""
loop = _make_loop(tmp_path, unified_session=False)
captured: list[str] = []
async def fake_process(msg, **kwargs):
captured.append(msg.session_key)
return None
loop._process_message = fake_process # type: ignore[method-assign]
msg = _make_msg(channel="telegram", chat_id="999")
await loop._dispatch(msg)
assert captured == ["telegram:999"]
@pytest.mark.asyncio
async def test_unified_session_respects_existing_override(self, tmp_path: Path):
"""If session_key_override is already set (e.g. Telegram thread), it is NOT overwritten."""
loop = _make_loop(tmp_path, unified_session=True)
captured: list[str] = []
async def fake_process(msg, **kwargs):
captured.append(msg.session_key)
return None
loop._process_message = fake_process # type: ignore[method-assign]
msg = _make_msg(channel="telegram", chat_id="111", session_key_override="telegram:thread:42")
await loop._dispatch(msg)
assert captured == ["telegram:thread:42"]
def test_unified_session_default_is_false(self, tmp_path: Path):
"""unified_session defaults to False — no behavior change for existing users."""
loop = _make_loop(tmp_path)
assert loop._unified_session is False
# ---------------------------------------------------------------------------
# TestUnifiedSessionConfig — schema & serialisation
# ---------------------------------------------------------------------------
class TestUnifiedSessionConfig:
"""Config schema and onboard serialisation for unified_session."""
def test_agent_defaults_unified_session_default_is_false(self):
"""AgentDefaults.unified_session defaults to False."""
defaults = AgentDefaults()
assert defaults.unified_session is False
def test_agent_defaults_unified_session_can_be_enabled(self):
"""AgentDefaults.unified_session can be set to True."""
defaults = AgentDefaults(unified_session=True)
assert defaults.unified_session is True
def test_config_serialises_unified_session_as_camel_case(self):
"""model_dump(by_alias=True) outputs 'unifiedSession' (camelCase) for JSON."""
config = Config()
data = config.model_dump(mode="json", by_alias=True)
agents_defaults = data["agents"]["defaults"]
assert "unifiedSession" in agents_defaults
assert agents_defaults["unifiedSession"] is False
def test_config_parses_unified_session_from_camel_case(self):
"""Config can be loaded from JSON with camelCase 'unifiedSession'."""
raw = {"agents": {"defaults": {"unifiedSession": True}}}
config = Config.model_validate(raw)
assert config.agents.defaults.unified_session is True
def test_config_parses_unified_session_from_snake_case(self):
"""Config also accepts snake_case 'unified_session' (populate_by_name=True)."""
raw = {"agents": {"defaults": {"unified_session": True}}}
config = Config.model_validate(raw)
assert config.agents.defaults.unified_session is True
def test_onboard_generated_config_contains_unified_session(self, tmp_path: Path):
"""save_config() writes 'unifiedSession' into config.json (simulates nanobot onboard)."""
from nanobot.config.loader import save_config
config = Config()
config_path = tmp_path / "config.json"
save_config(config, config_path)
with open(config_path, encoding="utf-8") as f:
data = json.load(f)
agents_defaults = data["agents"]["defaults"]
assert "unifiedSession" in agents_defaults, (
"onboard-generated config.json must contain 'unifiedSession' key"
)
assert agents_defaults["unifiedSession"] is False
# ---------------------------------------------------------------------------
# TestCmdNewUnifiedSession — /new command behaviour in unified mode
# ---------------------------------------------------------------------------
class TestCmdNewUnifiedSession:
"""/new command routing and session-clear behaviour in unified mode."""
def test_new_is_not_a_priority_command(self):
"""/new must NOT be in the priority table — it must go through _dispatch()
so the unified session key rewrite applies before cmd_new runs."""
router = CommandRouter()
register_builtin_commands(router)
assert router.is_priority("/new") is False
def test_new_is_an_exact_command(self):
"""/new must be registered as an exact command."""
router = CommandRouter()
register_builtin_commands(router)
assert "/new" in router._exact
@pytest.mark.asyncio
async def test_cmd_new_clears_unified_session(self, tmp_path: Path):
"""cmd_new called with key='unified:default' clears the shared session."""
sessions = SessionManager(tmp_path)
# Pre-populate the shared session with some messages
shared = sessions.get_or_create("unified:default")
shared.add_message("user", "hello from telegram")
shared.add_message("assistant", "hi there")
sessions.save(shared)
assert len(sessions.get_or_create("unified:default").messages) == 2
# _schedule_background is a *sync* method that schedules a coroutine via
# asyncio.create_task(). Mirror that exactly so the coroutine is consumed
# and no RuntimeWarning is emitted.
loop = SimpleNamespace(
sessions=sessions,
consolidator=SimpleNamespace(archive=AsyncMock(return_value=True)),
)
loop._schedule_background = lambda coro: asyncio.ensure_future(coro)
msg = InboundMessage(
channel="telegram", sender_id="user1", chat_id="111", content="/new",
session_key_override="unified:default", # as _dispatch() would set it
)
ctx = CommandContext(msg=msg, session=None, key="unified:default", raw="/new", loop=loop)
result = await cmd_new(ctx)
assert "New session started" in result.content
# Invalidate cache and reload from disk to confirm persistence
sessions.invalidate("unified:default")
reloaded = sessions.get_or_create("unified:default")
assert reloaded.messages == []
@pytest.mark.asyncio
async def test_cmd_new_in_unified_mode_does_not_affect_other_sessions(self, tmp_path: Path):
"""Clearing unified:default must not touch other sessions on disk."""
sessions = SessionManager(tmp_path)
other = sessions.get_or_create("discord:999")
other.add_message("user", "discord message")
sessions.save(other)
shared = sessions.get_or_create("unified:default")
shared.add_message("user", "shared message")
sessions.save(shared)
loop = SimpleNamespace(
sessions=sessions,
consolidator=SimpleNamespace(archive=AsyncMock(return_value=True)),
)
loop._schedule_background = lambda coro: asyncio.ensure_future(coro)
msg = InboundMessage(
channel="telegram", sender_id="user1", chat_id="111", content="/new",
session_key_override="unified:default",
)
ctx = CommandContext(msg=msg, session=None, key="unified:default", raw="/new", loop=loop)
await cmd_new(ctx)
sessions.invalidate("unified:default")
sessions.invalidate("discord:999")
assert sessions.get_or_create("unified:default").messages == []
assert len(sessions.get_or_create("discord:999").messages) == 1
# ---------------------------------------------------------------------------
# TestConsolidationUnaffectedByUnifiedSession — consolidation is key-agnostic
# ---------------------------------------------------------------------------
class TestConsolidationUnaffectedByUnifiedSession:
"""maybe_consolidate_by_tokens() behaviour is identical regardless of session key."""
@pytest.mark.asyncio
async def test_consolidation_skips_empty_session_for_unified_key(self):
"""Empty unified:default session → consolidation exits immediately, archive not called."""
from nanobot.agent.memory import Consolidator, MemoryStore
store = MagicMock(spec=MemoryStore)
mock_provider = MagicMock()
mock_provider.chat_with_retry = AsyncMock(return_value=MagicMock(content="summary"))
# Use spec= so MagicMock doesn't auto-generate AsyncMock for non-async methods,
# which would leave unawaited coroutines and trigger RuntimeWarning.
sessions = MagicMock(spec=SessionManager)
consolidator = Consolidator(
store=store,
provider=mock_provider,
model="test-model",
sessions=sessions,
context_window_tokens=1000,
build_messages=MagicMock(return_value=[]),
get_tool_definitions=MagicMock(return_value=[]),
max_completion_tokens=100,
)
consolidator.archive = AsyncMock()
session = Session(key="unified:default")
session.messages = []
await consolidator.maybe_consolidate_by_tokens(session)
consolidator.archive.assert_not_called()
@pytest.mark.asyncio
async def test_consolidation_behaviour_identical_for_any_key(self):
"""archive call count is the same for 'telegram:123' and 'unified:default'
under identical token conditions."""
from nanobot.agent.memory import Consolidator, MemoryStore
archive_calls: dict[str, int] = {}
for key in ("telegram:123", "unified:default"):
store = MagicMock(spec=MemoryStore)
mock_provider = MagicMock()
mock_provider.chat_with_retry = AsyncMock(return_value=MagicMock(content="summary"))
sessions = MagicMock(spec=SessionManager)
consolidator = Consolidator(
store=store,
provider=mock_provider,
model="test-model",
sessions=sessions,
context_window_tokens=1000,
build_messages=MagicMock(return_value=[]),
get_tool_definitions=MagicMock(return_value=[]),
max_completion_tokens=100,
)
session = Session(key=key)
session.messages = [] # empty → exits immediately for both keys
consolidator.archive = AsyncMock()
await consolidator.maybe_consolidate_by_tokens(session)
archive_calls[key] = consolidator.archive.call_count
assert archive_calls["telegram:123"] == archive_calls["unified:default"] == 0
@pytest.mark.asyncio
async def test_consolidation_triggers_when_over_budget_unified_key(self):
"""When tokens exceed budget, consolidation attempts to find a boundary —
behaviour is identical to any other session key."""
from nanobot.agent.memory import Consolidator, MemoryStore
store = MagicMock(spec=MemoryStore)
mock_provider = MagicMock()
sessions = MagicMock(spec=SessionManager)
consolidator = Consolidator(
store=store,
provider=mock_provider,
model="test-model",
sessions=sessions,
context_window_tokens=1000,
build_messages=MagicMock(return_value=[]),
get_tool_definitions=MagicMock(return_value=[]),
max_completion_tokens=100,
)
session = Session(key="unified:default")
session.messages = [{"role": "user", "content": "msg"}]
# Simulate over-budget: estimated > budget
consolidator.estimate_session_prompt_tokens = MagicMock(return_value=(950, "tiktoken"))
# No valid boundary found → returns gracefully without archiving
consolidator.pick_consolidation_boundary = MagicMock(return_value=None)
consolidator.archive = AsyncMock()
await consolidator.maybe_consolidate_by_tokens(session)
# estimate was called (consolidation was attempted)
consolidator.estimate_session_prompt_tokens.assert_called_once_with(session)
# but archive was not called (no valid boundary)
consolidator.archive.assert_not_called()
# ---------------------------------------------------------------------------
# TestStopCommandWithUnifiedSession — /stop command integration
# ---------------------------------------------------------------------------
class TestStopCommandWithUnifiedSession:
"""Verify /stop command works correctly with unified session enabled."""
@pytest.mark.asyncio
async def test_active_tasks_use_effective_key_in_unified_mode(self, tmp_path: Path):
"""When unified_session=True, tasks are stored under UNIFIED_SESSION_KEY."""
from nanobot.agent.loop import UNIFIED_SESSION_KEY
loop = _make_loop(tmp_path, unified_session=True)
# Create a message from telegram channel
msg = _make_msg(channel="telegram", chat_id="123456")
# Mock _dispatch to complete immediately
async def fake_dispatch(m):
pass
loop._dispatch = fake_dispatch # type: ignore[method-assign]
# Simulate the task creation flow (from _run loop)
effective_key = UNIFIED_SESSION_KEY if loop._unified_session and not msg.session_key_override else msg.session_key
task = asyncio.create_task(loop._dispatch(msg))
loop._active_tasks.setdefault(effective_key, []).append(task)
# Wait for task to complete
await task
# Verify the task is stored under UNIFIED_SESSION_KEY, not the original channel:chat_id
assert UNIFIED_SESSION_KEY in loop._active_tasks
assert "telegram:123456" not in loop._active_tasks
@pytest.mark.asyncio
async def test_stop_command_finds_task_in_unified_mode(self, tmp_path: Path):
"""cmd_stop can cancel tasks when unified_session=True."""
from nanobot.agent.loop import UNIFIED_SESSION_KEY
from nanobot.command.builtin import cmd_stop
loop = _make_loop(tmp_path, unified_session=True)
# Create a long-running task stored under UNIFIED_SESSION_KEY
async def long_running():
await asyncio.sleep(10) # Will be cancelled
task = asyncio.create_task(long_running())
loop._active_tasks[UNIFIED_SESSION_KEY] = [task]
# Create a message that would have session_key=UNIFIED_SESSION_KEY after dispatch
msg = InboundMessage(
channel="telegram",
chat_id="123456",
sender_id="user1",
content="/stop",
session_key_override=UNIFIED_SESSION_KEY, # Simulate post-dispatch state
)
ctx = CommandContext(msg=msg, session=None, key=UNIFIED_SESSION_KEY, raw="/stop", loop=loop)
# Execute /stop
result = await cmd_stop(ctx)
# Verify task was cancelled
assert task.cancelled() or task.done()
assert "Stopped 1 task" in result.content
@pytest.mark.asyncio
async def test_stop_command_cross_channel_in_unified_mode(self, tmp_path: Path):
"""In unified mode, /stop from one channel cancels tasks from another channel."""
from nanobot.agent.loop import UNIFIED_SESSION_KEY
from nanobot.command.builtin import cmd_stop
loop = _make_loop(tmp_path, unified_session=True)
# Create tasks from different channels, all stored under UNIFIED_SESSION_KEY
async def long_running():
await asyncio.sleep(10)
task1 = asyncio.create_task(long_running())
task2 = asyncio.create_task(long_running())
loop._active_tasks[UNIFIED_SESSION_KEY] = [task1, task2]
# /stop from discord should cancel tasks started from telegram
msg = InboundMessage(
channel="discord",
chat_id="789012",
sender_id="user2",
content="/stop",
session_key_override=UNIFIED_SESSION_KEY,
)
ctx = CommandContext(msg=msg, session=None, key=UNIFIED_SESSION_KEY, raw="/stop", loop=loop)
result = await cmd_stop(ctx)
# Both tasks should be cancelled
assert "Stopped 2 task" in result.content

View File

@ -1,6 +1,10 @@
import asyncio
import zipfile
from io import BytesIO
from types import SimpleNamespace
from unittest.mock import AsyncMock
import httpx
import pytest
# Check optional dingtalk dependencies before running tests
@ -50,6 +54,21 @@ class _FakeHttp:
return self._next_response()
class _NetworkErrorHttp:
"""HTTP client stub that raises httpx.TransportError on every request."""
def __init__(self) -> None:
self.calls: list[dict] = []
async def post(self, url: str, json=None, headers=None, **kwargs):
self.calls.append({"method": "POST", "url": url, "json": json, "headers": headers})
raise httpx.ConnectError("Connection refused")
async def get(self, url: str, **kwargs):
self.calls.append({"method": "GET", "url": url})
raise httpx.ConnectError("Connection refused")
@pytest.mark.asyncio
async def test_group_message_keeps_sender_id_and_routes_chat_id() -> None:
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"])
@ -221,3 +240,216 @@ async def test_download_dingtalk_file(tmp_path, monkeypatch) -> None:
assert "messageFiles/download" in channel._http.calls[0]["url"]
assert channel._http.calls[0]["json"]["downloadCode"] == "code123"
assert channel._http.calls[1]["method"] == "GET"
def test_normalize_upload_payload_zips_html_attachment() -> None:
channel = DingTalkChannel(
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]),
MessageBus(),
)
data, filename, content_type = channel._normalize_upload_payload(
"report.html",
b"<html><body>Hello</body></html>",
"text/html",
)
assert filename == "report.zip"
assert content_type == "application/zip"
archive = zipfile.ZipFile(BytesIO(data))
assert archive.namelist() == ["report.html"]
assert archive.read("report.html") == b"<html><body>Hello</body></html>"
@pytest.mark.asyncio
async def test_send_media_ref_zips_html_before_upload(tmp_path, monkeypatch) -> None:
channel = DingTalkChannel(
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]),
MessageBus(),
)
html_path = tmp_path / "report.html"
html_path.write_text("<html><body>Hello</body></html>", encoding="utf-8")
captured: dict[str, object] = {}
async def fake_upload_media(*, token, data, media_type, filename, content_type):
captured.update(
{
"token": token,
"data": data,
"media_type": media_type,
"filename": filename,
"content_type": content_type,
}
)
return "media-123"
async def fake_send_batch_message(token, chat_id, msg_key, msg_param):
captured.update(
{
"sent_token": token,
"chat_id": chat_id,
"msg_key": msg_key,
"msg_param": msg_param,
}
)
return True
monkeypatch.setattr(channel, "_upload_media", fake_upload_media)
monkeypatch.setattr(channel, "_send_batch_message", fake_send_batch_message)
ok = await channel._send_media_ref("token-123", "user-1", str(html_path))
assert ok is True
assert captured["media_type"] == "file"
assert captured["filename"] == "report.zip"
assert captured["content_type"] == "application/zip"
assert captured["msg_key"] == "sampleFile"
assert captured["msg_param"] == {
"mediaId": "media-123",
"fileName": "report.zip",
"fileType": "zip",
}
archive = zipfile.ZipFile(BytesIO(captured["data"]))
assert archive.namelist() == ["report.html"]
# ── Exception handling tests ──────────────────────────────────────────
@pytest.mark.asyncio
async def test_send_batch_message_propagates_transport_error() -> None:
"""Network/transport errors must re-raise so callers can retry."""
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
channel = DingTalkChannel(config, MessageBus())
channel._http = _NetworkErrorHttp()
with pytest.raises(httpx.ConnectError, match="Connection refused"):
await channel._send_batch_message(
"token",
"user123",
"sampleMarkdown",
{"text": "hello", "title": "Nanobot Reply"},
)
# The POST was attempted exactly once
assert len(channel._http.calls) == 1
assert channel._http.calls[0]["method"] == "POST"
@pytest.mark.asyncio
async def test_send_batch_message_returns_false_on_api_error() -> None:
"""DingTalk API-level errors (non-200 status, errcode != 0) should return False."""
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
channel = DingTalkChannel(config, MessageBus())
# Non-200 status code → API error → return False
channel._http = _FakeHttp(responses=[_FakeResponse(400, {"errcode": 400})])
result = await channel._send_batch_message(
"token", "user123", "sampleMarkdown", {"text": "hello"}
)
assert result is False
# 200 with non-zero errcode → API error → return False
channel._http = _FakeHttp(responses=[_FakeResponse(200, {"errcode": 100})])
result = await channel._send_batch_message(
"token", "user123", "sampleMarkdown", {"text": "hello"}
)
assert result is False
# 200 with errcode=0 → success → return True
channel._http = _FakeHttp(responses=[_FakeResponse(200, {"errcode": 0})])
result = await channel._send_batch_message(
"token", "user123", "sampleMarkdown", {"text": "hello"}
)
assert result is True
@pytest.mark.asyncio
async def test_send_media_ref_short_circuits_on_transport_error() -> None:
"""When the first send fails with a transport error, _send_media_ref must
re-raise immediately instead of trying download+upload+fallback."""
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
channel = DingTalkChannel(config, MessageBus())
channel._http = _NetworkErrorHttp()
# An image URL triggers the sampleImageMsg path first
with pytest.raises(httpx.ConnectError, match="Connection refused"):
await channel._send_media_ref("token", "user123", "https://example.com/photo.jpg")
# Only one POST should have been attempted — no download/upload/fallback
assert len(channel._http.calls) == 1
assert channel._http.calls[0]["method"] == "POST"
@pytest.mark.asyncio
async def test_send_media_ref_short_circuits_on_download_transport_error() -> None:
"""When the image URL send returns an API error (False) but the download
for the fallback hits a transport error, it must re-raise rather than
silently returning False."""
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
channel = DingTalkChannel(config, MessageBus())
# First POST (sampleImageMsg) returns API error → False, then GET (download) raises transport error
class _MixedHttp:
def __init__(self) -> None:
self.calls: list[dict] = []
async def post(self, url, json=None, headers=None, **kwargs):
self.calls.append({"method": "POST", "url": url})
# API-level failure: 200 with errcode != 0
return _FakeResponse(200, {"errcode": 100})
async def get(self, url, **kwargs):
self.calls.append({"method": "GET", "url": url})
raise httpx.ConnectError("Connection refused")
channel._http = _MixedHttp()
with pytest.raises(httpx.ConnectError, match="Connection refused"):
await channel._send_media_ref("token", "user123", "https://example.com/photo.jpg")
# Should have attempted POST (image URL) and GET (download), but NOT upload
assert len(channel._http.calls) == 2
assert channel._http.calls[0]["method"] == "POST"
assert channel._http.calls[1]["method"] == "GET"
@pytest.mark.asyncio
async def test_send_media_ref_short_circuits_on_upload_transport_error() -> None:
"""When download succeeds but upload hits a transport error, must re-raise."""
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
channel = DingTalkChannel(config, MessageBus())
image_bytes = b"\xff\xd8\xff\xe0" + b"\x00" * 100 # minimal JPEG-ish data
class _UploadFailsHttp:
def __init__(self) -> None:
self.calls: list[dict] = []
async def post(self, url, json=None, headers=None, files=None, **kwargs):
self.calls.append({"method": "POST", "url": url})
# If it's the upload endpoint, raise transport error
if "media/upload" in url:
raise httpx.ConnectError("Connection refused")
# Otherwise (sampleImageMsg), return API error to trigger fallback
return _FakeResponse(200, {"errcode": 100})
async def get(self, url, **kwargs):
self.calls.append({"method": "GET", "url": url})
resp = _FakeResponse(200)
resp.content = image_bytes
resp.headers = {"content-type": "image/jpeg"}
return resp
channel._http = _UploadFailsHttp()
with pytest.raises(httpx.ConnectError, match="Connection refused"):
await channel._send_media_ref("token", "user123", "https://example.com/photo.jpg")
# POST (image URL), GET (download), POST (upload) attempted — no further sends
methods = [c["method"] for c in channel._http.calls]
assert methods == ["POST", "GET", "POST"]

View File

@ -5,11 +5,17 @@ from pathlib import Path
from types import SimpleNamespace
import pytest
discord = pytest.importorskip("discord")
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.discord import DiscordBotClient, DiscordChannel, DiscordConfig
from nanobot.channels.discord import (
MAX_MESSAGE_LEN,
DiscordBotClient,
DiscordChannel,
DiscordConfig,
)
from nanobot.command.builtin import build_help_text
@ -18,9 +24,11 @@ class _FakeDiscordClient:
instances: list["_FakeDiscordClient"] = []
start_error: Exception | None = None
def __init__(self, owner, *, intents) -> None:
def __init__(self, owner, *, intents, proxy=None, proxy_auth=None) -> None:
self.owner = owner
self.intents = intents
self.proxy = proxy
self.proxy_auth = proxy_auth
self.closed = False
self.ready = True
self.channels: dict[int, object] = {}
@ -53,7 +61,9 @@ class _FakeDiscordClient:
class _FakeAttachment:
# Attachment double that can simulate successful or failing save() calls.
def __init__(self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False) -> None:
def __init__(
self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False
) -> None:
self.id = attachment_id
self.filename = filename
self.size = size
@ -71,11 +81,25 @@ class _FakePartialMessage:
self.id = message_id
class _FakeSentMessage:
# Sent-message double supporting edit() for streaming tests.
def __init__(self, channel, content: str) -> None:
self.channel = channel
self.content = content
self.edits: list[dict] = []
async def edit(self, **kwargs) -> None:
self.edits.append(dict(kwargs))
if "content" in kwargs:
self.content = kwargs["content"]
class _FakeChannel:
# Channel double that records outbound payloads and typing activity.
def __init__(self, channel_id: int = 123) -> None:
self.id = channel_id
self.sent_payloads: list[dict] = []
self.sent_messages: list[_FakeSentMessage] = []
self.trigger_typing_calls = 0
self.typing_enter_hook = None
@ -85,6 +109,9 @@ class _FakeChannel:
payload["file_name"] = payload["file"].filename
del payload["file"]
self.sent_payloads.append(payload)
message = _FakeSentMessage(self, payload.get("content", ""))
self.sent_messages.append(message)
return message
def get_partial_message(self, message_id: int) -> _FakePartialMessage:
return _FakePartialMessage(message_id)
@ -194,7 +221,7 @@ async def test_start_handles_client_construction_failure(monkeypatch) -> None:
MessageBus(),
)
def _boom(owner, *, intents):
def _boom(owner, *, intents, proxy=None, proxy_auth=None):
raise RuntimeError("bad client")
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _boom)
@ -427,6 +454,60 @@ async def test_send_fetches_channel_when_not_cached() -> None:
assert target.sent_payloads == [{"content": "hello"}]
def test_supports_streaming_enabled_by_default() -> None:
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
assert channel.supports_streaming is True
@pytest.mark.asyncio
async def test_send_delta_streams_by_editing_message(monkeypatch) -> None:
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
client = _FakeDiscordClient(owner, intents=None)
owner._client = client
owner._running = True
target = _FakeChannel(channel_id=123)
client.channels[123] = target
times = iter([1.0, 3.0, 5.0])
monkeypatch.setattr("nanobot.channels.discord.time.monotonic", lambda: next(times, 5.0))
await owner.send_delta("123", "hel", {"_stream_delta": True, "_stream_id": "s1"})
await owner.send_delta("123", "lo", {"_stream_delta": True, "_stream_id": "s1"})
await owner.send_delta("123", "", {"_stream_end": True, "_stream_id": "s1"})
assert target.sent_payloads[0] == {"content": "hel"}
assert target.sent_messages[0].edits == [{"content": "hello"}, {"content": "hello"}]
assert owner._stream_bufs == {}
@pytest.mark.asyncio
async def test_send_delta_stream_end_splits_oversized_reply(monkeypatch) -> None:
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
client = _FakeDiscordClient(owner, intents=None)
owner._client = client
owner._running = True
target = _FakeChannel(channel_id=123)
client.channels[123] = target
prefix = "a" * (MAX_MESSAGE_LEN - 100)
suffix = "b" * 150
full_text = prefix + suffix
chunks = DiscordBotClient._build_chunks(full_text, [], False)
assert len(chunks) == 2
times = iter([1.0, 3.0])
monkeypatch.setattr("nanobot.channels.discord.time.monotonic", lambda: next(times, 3.0))
await owner.send_delta("123", prefix, {"_stream_delta": True, "_stream_id": "s1"})
await owner.send_delta("123", suffix, {"_stream_delta": True, "_stream_id": "s1"})
await owner.send_delta("123", "", {"_stream_end": True, "_stream_id": "s1"})
assert target.sent_payloads == [{"content": prefix}, {"content": chunks[1]}]
assert target.sent_messages[0].edits == [{"content": chunks[0]}, {"content": chunks[0]}]
assert owner._stream_bufs == {}
@pytest.mark.asyncio
async def test_slash_new_forwards_when_user_is_allowlisted() -> None:
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus())
@ -443,9 +524,7 @@ async def test_slash_new_forwards_when_user_is_allowlisted() -> None:
assert new_cmd is not None
await new_cmd.callback(interaction)
assert interaction.response.messages == [
{"content": "Processing /new...", "ephemeral": True}
]
assert interaction.response.messages == [{"content": "Processing /new...", "ephemeral": True}]
assert len(handled) == 1
assert handled[0]["content"] == "/new"
assert handled[0]["sender_id"] == "123"
@ -519,9 +598,7 @@ async def test_slash_help_returns_ephemeral_help_text() -> None:
assert help_cmd is not None
await help_cmd.callback(interaction)
assert interaction.response.messages == [
{"content": build_help_text(), "ephemeral": True}
]
assert interaction.response.messages == [{"content": build_help_text(), "ephemeral": True}]
assert handled == []
@ -656,11 +733,13 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() ->
def typing(self):
async def _waiter():
await release.wait()
# Hold the loop so task remains active until explicitly stopped.
class _Ctx(_TypingCtx):
async def __aenter__(self):
await super().__aenter__()
await _waiter()
return _Ctx()
typing_channel = _NoTriggerChannel(channel_id=123)
@ -674,3 +753,214 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() ->
await asyncio.sleep(0)
assert channel._typing_tasks == {}
def test_config_accepts_proxy_fields() -> None:
config = DiscordConfig(
enabled=True,
token="token",
allow_from=["*"],
proxy="http://127.0.0.1:7890",
proxy_username="user",
proxy_password="pass",
)
assert config.proxy == "http://127.0.0.1:7890"
assert config.proxy_username == "user"
assert config.proxy_password == "pass"
def test_config_proxy_defaults_to_none() -> None:
config = DiscordConfig(enabled=True, token="token", allow_from=["*"])
assert config.proxy is None
assert config.proxy_username is None
assert config.proxy_password is None
@pytest.mark.asyncio
async def test_start_passes_proxy_to_client(monkeypatch) -> None:
_FakeDiscordClient.instances.clear()
channel = DiscordChannel(
DiscordConfig(
enabled=True,
token="token",
allow_from=["*"],
proxy="http://127.0.0.1:7890",
),
MessageBus(),
)
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
await channel.start()
assert channel.is_running is False
assert len(_FakeDiscordClient.instances) == 1
assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890"
assert _FakeDiscordClient.instances[0].proxy_auth is None
@pytest.mark.asyncio
async def test_start_passes_proxy_auth_when_credentials_provided(monkeypatch) -> None:
aiohttp = pytest.importorskip("aiohttp")
_FakeDiscordClient.instances.clear()
channel = DiscordChannel(
DiscordConfig(
enabled=True,
token="token",
allow_from=["*"],
proxy="http://127.0.0.1:7890",
proxy_username="user",
proxy_password="pass",
),
MessageBus(),
)
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
await channel.start()
assert channel.is_running is False
assert len(_FakeDiscordClient.instances) == 1
assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890"
assert _FakeDiscordClient.instances[0].proxy_auth is not None
assert isinstance(_FakeDiscordClient.instances[0].proxy_auth, aiohttp.BasicAuth)
assert _FakeDiscordClient.instances[0].proxy_auth.login == "user"
assert _FakeDiscordClient.instances[0].proxy_auth.password == "pass"
@pytest.mark.asyncio
async def test_start_no_proxy_auth_when_only_username(monkeypatch) -> None:
_FakeDiscordClient.instances.clear()
channel = DiscordChannel(
DiscordConfig(
enabled=True,
token="token",
allow_from=["*"],
proxy="http://127.0.0.1:7890",
proxy_username="user",
),
MessageBus(),
)
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
await channel.start()
assert channel.is_running is False
assert _FakeDiscordClient.instances[0].proxy_auth is None
@pytest.mark.asyncio
async def test_start_no_proxy_auth_when_only_password(monkeypatch) -> None:
_FakeDiscordClient.instances.clear()
channel = DiscordChannel(
DiscordConfig(
enabled=True,
token="token",
allow_from=["*"],
proxy="http://127.0.0.1:7890",
proxy_password="pass",
),
MessageBus(),
)
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
await channel.start()
assert channel.is_running is False
assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890"
assert _FakeDiscordClient.instances[0].proxy_auth is None
# ---------------------------------------------------------------------------
# Tests for the send() exception propagation fix
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_send_re_raises_network_error() -> None:
"""Network errors during send must propagate so ChannelManager can retry."""
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
client = _FakeDiscordClient(channel, intents=None)
channel._client = client
channel._running = True
async def _failing_send_outbound(msg: OutboundMessage) -> None:
raise ConnectionError("network unreachable")
client.send_outbound = _failing_send_outbound # type: ignore[method-assign]
with pytest.raises(ConnectionError, match="network unreachable"):
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
@pytest.mark.asyncio
async def test_send_re_raises_generic_exception() -> None:
"""Any exception from send_outbound must propagate, not be swallowed."""
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
client = _FakeDiscordClient(channel, intents=None)
channel._client = client
channel._running = True
async def _failing_send_outbound(msg: OutboundMessage) -> None:
raise RuntimeError("discord API failure")
client.send_outbound = _failing_send_outbound # type: ignore[method-assign]
with pytest.raises(RuntimeError, match="discord API failure"):
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
@pytest.mark.asyncio
async def test_send_still_stops_typing_on_error() -> None:
"""Typing cleanup must still run in the finally block even when send raises."""
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
client = _FakeDiscordClient(channel, intents=None)
channel._client = client
channel._running = True
# Start a typing task so we can verify it gets cleaned up
start = asyncio.Event()
release = asyncio.Event()
async def slow_typing() -> None:
start.set()
await release.wait()
typing_channel = _FakeChannel(channel_id=123)
typing_channel.typing_enter_hook = slow_typing
await channel._start_typing(typing_channel)
await asyncio.wait_for(start.wait(), timeout=1.0)
async def _failing_send_outbound(msg: OutboundMessage) -> None:
raise ConnectionError("timeout")
client.send_outbound = _failing_send_outbound # type: ignore[method-assign]
with pytest.raises(ConnectionError, match="timeout"):
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
release.set()
await asyncio.sleep(0)
# Typing should have been cleaned up by the finally block
assert channel._typing_tasks == {}
@pytest.mark.asyncio
async def test_send_succeeds_normally() -> None:
"""Successful sends should work without raising."""
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
client = _FakeDiscordClient(channel, intents=None)
channel._client = client
channel._running = True
sent_messages: list[OutboundMessage] = []
async def _capture_send_outbound(msg: OutboundMessage) -> None:
sent_messages.append(msg)
client.send_outbound = _capture_send_outbound # type: ignore[method-assign]
msg = OutboundMessage(channel="discord", chat_id="123", content="hello world")
await channel.send(msg)
assert len(sent_messages) == 1
assert sent_messages[0].content == "hello world"
assert sent_messages[0].chat_id == "123"

View File

@ -0,0 +1,48 @@
"""Tests for Feishu/Lark domain configuration."""
from unittest.mock import MagicMock
import pytest
from nanobot.bus.queue import MessageBus
from nanobot.channels.feishu import FeishuChannel, FeishuConfig
def _make_channel(domain: str = "feishu") -> FeishuChannel:
config = FeishuConfig(
enabled=True,
app_id="cli_test",
app_secret="secret",
allow_from=["*"],
domain=domain,
)
ch = FeishuChannel(config, MessageBus())
ch._client = MagicMock()
ch._loop = None
return ch
class TestFeishuConfigDomain:
def test_domain_default_is_feishu(self):
config = FeishuConfig()
assert config.domain == "feishu"
def test_domain_accepts_lark(self):
config = FeishuConfig(domain="lark")
assert config.domain == "lark"
def test_domain_accepts_feishu(self):
config = FeishuConfig(domain="feishu")
assert config.domain == "feishu"
def test_default_config_includes_domain(self):
default_cfg = FeishuChannel.default_config()
assert "domain" in default_cfg
assert default_cfg["domain"] == "feishu"
def test_channel_persists_domain_from_config(self):
ch = _make_channel(domain="lark")
assert ch.config.domain == "lark"
def test_channel_persists_feishu_domain_from_config(self):
ch = _make_channel(domain="feishu")
assert ch.config.domain == "feishu"

View File

@ -5,6 +5,7 @@ from unittest.mock import MagicMock
import pytest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.feishu import FeishuChannel, FeishuConfig, _FeishuStreamBuf
@ -203,6 +204,24 @@ class TestSendDelta:
ch._client.cardkit.v1.card_element.content.assert_not_called()
ch._client.im.v1.message.create.assert_called_once()
@pytest.mark.asyncio
async def test_stream_end_fallback_when_final_update_fails(self):
"""If streaming mode was closed (e.g. Feishu timeout), fall back to a regular card."""
ch = _make_channel()
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
text="Lost content", card_id="card_1", sequence=3, last_edit=0.0,
)
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response(success=False)
ch._client.im.v1.message.create.return_value = _mock_send_response("om_fb")
await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True})
assert "oc_chat1" not in ch._stream_bufs
# Should NOT attempt to close streaming mode since update failed
ch._client.cardkit.v1.card.settings.assert_not_called()
# Should fall back to sending a regular interactive card
ch._client.im.v1.message.create.assert_called_once()
@pytest.mark.asyncio
async def test_stream_end_without_buf_is_noop(self):
ch = _make_channel()
@ -239,6 +258,130 @@ class TestSendDelta:
assert buf.sequence == 7
class TestToolHintInlineStreaming:
"""Tool hint messages should be inlined into active streaming cards."""
@pytest.mark.asyncio
async def test_tool_hint_inlined_when_stream_active(self):
"""With an active streaming buffer, tool hint appends to the card."""
ch = _make_channel()
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
text="Partial answer", card_id="card_1", sequence=2, last_edit=0.0,
)
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
msg = OutboundMessage(
channel="feishu", chat_id="oc_chat1",
content='web_fetch("https://example.com")',
metadata={"_tool_hint": True},
)
await ch.send(msg)
buf = ch._stream_bufs["oc_chat1"]
assert '🔧 web_fetch("https://example.com")' in buf.text
assert buf.sequence == 3
ch._client.cardkit.v1.card_element.content.assert_called_once()
ch._client.im.v1.message.create.assert_not_called()
@pytest.mark.asyncio
async def test_tool_hint_preserved_on_next_delta(self):
"""When new delta arrives, the tool hint is kept as permanent content and delta appends after it."""
ch = _make_channel()
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
text="Partial answer\n\n🔧 web_fetch(\"url\")\n\n",
card_id="card_1", sequence=3, last_edit=0.0,
)
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
await ch.send_delta("oc_chat1", " continued")
buf = ch._stream_bufs["oc_chat1"]
assert "Partial answer" in buf.text
assert "🔧 web_fetch" in buf.text
assert buf.text.endswith(" continued")
@pytest.mark.asyncio
async def test_tool_hint_fallback_when_no_stream(self):
"""Without an active buffer, tool hint falls back to a standalone card."""
ch = _make_channel()
ch._client.im.v1.message.create.return_value = _mock_send_response("om_hint")
msg = OutboundMessage(
channel="feishu", chat_id="oc_chat1",
content='read_file("path")',
metadata={"_tool_hint": True},
)
await ch.send(msg)
assert "oc_chat1" not in ch._stream_bufs
ch._client.im.v1.message.create.assert_called_once()
@pytest.mark.asyncio
async def test_consecutive_tool_hints_append(self):
"""When multiple tool hints arrive consecutively, each appends to the card."""
ch = _make_channel()
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
text="Partial answer", card_id="card_1", sequence=2, last_edit=0.0,
)
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
msg1 = OutboundMessage(
channel="feishu", chat_id="oc_chat1",
content='$ cd /project', metadata={"_tool_hint": True},
)
await ch.send(msg1)
msg2 = OutboundMessage(
channel="feishu", chat_id="oc_chat1",
content='$ git status', metadata={"_tool_hint": True},
)
await ch.send(msg2)
buf = ch._stream_bufs["oc_chat1"]
assert "$ cd /project" in buf.text
assert "$ git status" in buf.text
assert buf.text.startswith("Partial answer")
assert "🔧 $ cd /project" in buf.text
assert "🔧 $ git status" in buf.text
@pytest.mark.asyncio
async def test_tool_hint_preserved_on_final_stream_end(self):
"""When final _stream_end closes the card, tool hint is kept in the final text."""
ch = _make_channel()
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
text="Final content\n\n🔧 web_fetch(\"url\")\n\n",
card_id="card_1", sequence=3, last_edit=0.0,
)
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
ch._client.cardkit.v1.card.settings.return_value = _mock_content_response()
await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True})
assert "oc_chat1" not in ch._stream_bufs
update_call = ch._client.cardkit.v1.card_element.content.call_args[0][0]
assert "🔧" in update_call.body.content
@pytest.mark.asyncio
async def test_empty_tool_hint_is_noop(self):
"""Empty or whitespace-only tool hint content is silently ignored."""
ch = _make_channel()
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
text="Partial answer", card_id="card_1", sequence=2, last_edit=0.0,
)
for content in ("", " ", "\t\n"):
msg = OutboundMessage(
channel="feishu", chat_id="oc_chat1",
content=content, metadata={"_tool_hint": True},
)
await ch.send(msg)
buf = ch._stream_bufs["oc_chat1"]
assert buf.text == "Partial answer"
assert buf.sequence == 2
ch._client.cardkit.v1.card_element.content.assert_not_called()
class TestSendMessageReturnsId:
def test_returns_message_id_on_success(self):
ch = _make_channel()

View File

@ -1,6 +1,7 @@
"""Tests for FeishuChannel tool hint code block formatting."""
"""Tests for FeishuChannel tool hint formatting."""
import json
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
@ -28,15 +29,24 @@ def mock_feishu_channel():
config.app_secret = "test_app_secret"
config.encrypt_key = None
config.verification_token = None
config.tool_hint_prefix = "\U0001f527" # 🔧
bus = MagicMock()
channel = FeishuChannel(config, bus)
channel._client = MagicMock() # Simulate initialized client
channel._client = MagicMock()
return channel
def _get_tool_hint_card(mock_send):
"""Extract the interactive card from _send_message_sync calls."""
call_args = mock_send.call_args[0]
_, _, msg_type, content = call_args
assert msg_type == "interactive"
return json.loads(content)
@mark.asyncio
async def test_tool_hint_sends_code_message(mock_feishu_channel):
"""Tool hint messages should be sent as interactive cards with code blocks."""
async def test_tool_hint_sends_interactive_card(mock_feishu_channel):
"""Tool hint without active buffer sends an interactive card with 🔧 style."""
msg = OutboundMessage(
channel="feishu",
chat_id="oc_123456",
@ -47,23 +57,12 @@ async def test_tool_hint_sends_code_message(mock_feishu_channel):
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
await mock_feishu_channel.send(msg)
# Verify interactive message with card was sent
assert mock_send.call_count == 1
call_args = mock_send.call_args[0]
receive_id_type, receive_id, msg_type, content = call_args
assert receive_id_type == "chat_id"
assert receive_id == "oc_123456"
assert msg_type == "interactive"
# Parse content to verify card structure
card = json.loads(content)
card = _get_tool_hint_card(mock_send)
assert card["config"]["wide_screen_mode"] is True
assert len(card["elements"]) == 1
assert card["elements"][0]["tag"] == "markdown"
# Check that code block is properly formatted with language hint
expected_md = "**Tool Calls**\n\n```text\nweb_search(\"test query\")\n```"
assert card["elements"][0]["content"] == expected_md
md = card["elements"][0]["content"]
assert "\U0001f527" in md
assert "web_search" in md
@mark.asyncio
@ -78,8 +77,6 @@ async def test_tool_hint_empty_content_does_not_send(mock_feishu_channel):
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
await mock_feishu_channel.send(msg)
# Should not send any message
mock_send.assert_not_called()
@ -96,7 +93,6 @@ async def test_tool_hint_without_metadata_sends_as_normal(mock_feishu_channel):
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
await mock_feishu_channel.send(msg)
# Should send as text message (detected format)
assert mock_send.call_count == 1
call_args = mock_send.call_args[0]
_, _, msg_type, content = call_args
@ -106,7 +102,7 @@ async def test_tool_hint_without_metadata_sends_as_normal(mock_feishu_channel):
@mark.asyncio
async def test_tool_hint_multiple_tools_in_one_message(mock_feishu_channel):
"""Multiple tool calls should be displayed each on its own line in a code block."""
"""Multiple tool calls should each get the 🔧 prefix."""
msg = OutboundMessage(
channel="feishu",
chat_id="oc_123456",
@ -117,13 +113,11 @@ async def test_tool_hint_multiple_tools_in_one_message(mock_feishu_channel):
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
await mock_feishu_channel.send(msg)
call_args = mock_send.call_args[0]
msg_type = call_args[2]
content = json.loads(call_args[3])
assert msg_type == "interactive"
# Each tool call should be on its own line
expected_md = "**Tool Calls**\n\n```text\nweb_search(\"query\"),\nread_file(\"/path/to/file\")\n```"
assert content["elements"][0]["content"] == expected_md
card = _get_tool_hint_card(mock_send)
md = card["elements"][0]["content"]
assert "web_search" in md
assert "read_file" in md
assert "\U0001f527" in md
@mark.asyncio
@ -139,8 +133,8 @@ async def test_tool_hint_new_format_basic(mock_feishu_channel):
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
await mock_feishu_channel.send(msg)
content = json.loads(mock_send.call_args[0][3])
md = content["elements"][0]["content"]
card = _get_tool_hint_card(mock_send)
md = card["elements"][0]["content"]
assert "read src/main.py" in md
assert 'grep "TODO"' in md
@ -158,16 +152,15 @@ async def test_tool_hint_new_format_with_comma_in_quotes(mock_feishu_channel):
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
await mock_feishu_channel.send(msg)
content = json.loads(mock_send.call_args[0][3])
md = content["elements"][0]["content"]
# The comma inside quotes should NOT cause a line break
card = _get_tool_hint_card(mock_send)
md = card["elements"][0]["content"]
assert 'grep "hello, world"' in md
assert "$ echo test" in md
@mark.asyncio
async def test_tool_hint_new_format_with_folding(mock_feishu_channel):
"""Folded calls (× N) should display on separate lines."""
"""Folded calls (× N) should display correctly."""
msg = OutboundMessage(
channel="feishu",
chat_id="oc_123456",
@ -178,8 +171,8 @@ async def test_tool_hint_new_format_with_folding(mock_feishu_channel):
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
await mock_feishu_channel.send(msg)
content = json.loads(mock_send.call_args[0][3])
md = content["elements"][0]["content"]
card = _get_tool_hint_card(mock_send)
md = card["elements"][0]["content"]
assert "\u00d7 3" in md
assert 'grep "pattern"' in md
@ -197,9 +190,12 @@ async def test_tool_hint_new_format_mcp(mock_feishu_channel):
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
await mock_feishu_channel.send(msg)
content = json.loads(mock_send.call_args[0][3])
md = content["elements"][0]["content"]
card = _get_tool_hint_card(mock_send)
md = card["elements"][0]["content"]
assert "4_5v::analyze_image" in md
@mark.asyncio
async def test_tool_hint_keeps_commas_inside_arguments(mock_feishu_channel):
"""Commas inside a single tool argument must not be split onto a new line."""
msg = OutboundMessage(
@ -212,10 +208,7 @@ async def test_tool_hint_keeps_commas_inside_arguments(mock_feishu_channel):
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
await mock_feishu_channel.send(msg)
content = json.loads(mock_send.call_args[0][3])
expected_md = (
"**Tool Calls**\n\n```text\n"
"web_search(\"foo, bar\"),\n"
"read_file(\"/path/to/file\")\n```"
)
assert content["elements"][0]["content"] == expected_md
card = _get_tool_hint_card(mock_send)
md = card["elements"][0]["content"]
assert 'web_search("foo, bar")' in md
assert 'read_file("/path/to/file")' in md

View File

@ -1,6 +1,7 @@
import tempfile
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
@ -14,6 +15,8 @@ except ImportError:
if not QQ_AVAILABLE:
pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True)
import aiohttp
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.qq import QQChannel, QQConfig
@ -170,3 +173,221 @@ async def test_read_media_bytes_missing_file() -> None:
data, filename = await channel._read_media_bytes("/nonexistent/path/image.png")
assert data is None
assert filename is None
# -------------------------------------------------------
# Tests for _send_media exception handling
# -------------------------------------------------------
def _make_channel_with_local_file(suffix: str = ".png", content: bytes = b"\x89PNG\r\n"):
"""Create a QQChannel with a fake client and a temp file for media."""
channel = QQChannel(
QQConfig(app_id="app", secret="secret", allow_from=["*"]),
MessageBus(),
)
channel._client = _FakeClient()
channel._chat_type_cache["user1"] = "c2c"
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
tmp.write(content)
tmp.close()
return channel, tmp.name
@pytest.mark.asyncio
async def test_send_media_network_error_propagates() -> None:
"""aiohttp.ClientError (network/transport) should re-raise, not return False."""
channel, tmp_path = _make_channel_with_local_file()
# Make the base64 upload raise a network error
channel._client.api._http = SimpleNamespace()
channel._client.api._http.request = AsyncMock(
side_effect=aiohttp.ServerDisconnectedError("connection lost"),
)
with pytest.raises(aiohttp.ServerDisconnectedError):
await channel._send_media(
chat_id="user1",
media_ref=tmp_path,
msg_id="msg1",
is_group=False,
)
@pytest.mark.asyncio
async def test_send_media_client_connector_error_propagates() -> None:
"""aiohttp.ClientConnectorError (DNS/connection refused) should re-raise."""
channel, tmp_path = _make_channel_with_local_file()
from aiohttp.client_reqrep import ConnectionKey
conn_key = ConnectionKey("api.qq.com", 443, True, None, None, None, None)
connector_error = aiohttp.ClientConnectorError(
connection_key=conn_key,
os_error=OSError("Connection refused"),
)
channel._client.api._http = SimpleNamespace()
channel._client.api._http.request = AsyncMock(
side_effect=connector_error,
)
with pytest.raises(aiohttp.ClientConnectorError):
await channel._send_media(
chat_id="user1",
media_ref=tmp_path,
msg_id="msg1",
is_group=False,
)
@pytest.mark.asyncio
async def test_send_media_oserror_propagates() -> None:
"""OSError (low-level I/O) should re-raise for retry."""
channel, tmp_path = _make_channel_with_local_file()
channel._client.api._http = SimpleNamespace()
channel._client.api._http.request = AsyncMock(
side_effect=OSError("Network is unreachable"),
)
with pytest.raises(OSError):
await channel._send_media(
chat_id="user1",
media_ref=tmp_path,
msg_id="msg1",
is_group=False,
)
@pytest.mark.asyncio
async def test_send_media_api_error_returns_false() -> None:
"""API-level errors (botpy RuntimeError subclasses) should return False, not raise."""
channel, tmp_path = _make_channel_with_local_file()
# Simulate a botpy API error (e.g. ServerError is a RuntimeError subclass)
from botpy.errors import ServerError
channel._client.api._http = SimpleNamespace()
channel._client.api._http.request = AsyncMock(
side_effect=ServerError("internal server error"),
)
result = await channel._send_media(
chat_id="user1",
media_ref=tmp_path,
msg_id="msg1",
is_group=False,
)
assert result is False
@pytest.mark.asyncio
async def test_send_media_generic_runtime_error_returns_false() -> None:
"""Generic RuntimeError (not network) should return False."""
channel, tmp_path = _make_channel_with_local_file()
channel._client.api._http = SimpleNamespace()
channel._client.api._http.request = AsyncMock(
side_effect=RuntimeError("some API error"),
)
result = await channel._send_media(
chat_id="user1",
media_ref=tmp_path,
msg_id="msg1",
is_group=False,
)
assert result is False
@pytest.mark.asyncio
async def test_send_media_value_error_returns_false() -> None:
"""ValueError (bad API response data) should return False."""
channel, tmp_path = _make_channel_with_local_file()
channel._client.api._http = SimpleNamespace()
channel._client.api._http.request = AsyncMock(
side_effect=ValueError("bad response data"),
)
result = await channel._send_media(
chat_id="user1",
media_ref=tmp_path,
msg_id="msg1",
is_group=False,
)
assert result is False
@pytest.mark.asyncio
async def test_send_media_timeout_error_propagates() -> None:
"""asyncio.TimeoutError inherits from Exception but not ClientError/OSError.
However, aiohttp.ServerTimeoutError IS a ClientError subclass, so that propagates.
For a plain TimeoutError (which is also OSError in Python 3.11+), it should propagate."""
channel, tmp_path = _make_channel_with_local_file()
channel._client.api._http = SimpleNamespace()
channel._client.api._http.request = AsyncMock(
side_effect=aiohttp.ServerTimeoutError("request timed out"),
)
with pytest.raises(aiohttp.ServerTimeoutError):
await channel._send_media(
chat_id="user1",
media_ref=tmp_path,
msg_id="msg1",
is_group=False,
)
@pytest.mark.asyncio
async def test_send_fallback_text_on_api_error() -> None:
"""When _send_media returns False (API error), send() should emit fallback text."""
channel, tmp_path = _make_channel_with_local_file()
from botpy.errors import ServerError
channel._client.api._http = SimpleNamespace()
channel._client.api._http.request = AsyncMock(
side_effect=ServerError("internal server error"),
)
await channel.send(
OutboundMessage(
channel="qq",
chat_id="user1",
content="",
media=[tmp_path],
metadata={"message_id": "msg1"},
)
)
# Should have sent a fallback text message
assert len(channel._client.api.c2c_calls) == 1
fallback_content = channel._client.api.c2c_calls[0]["content"]
assert "Attachment send failed" in fallback_content
@pytest.mark.asyncio
async def test_send_propagates_network_error_no_fallback() -> None:
"""When _send_media raises a network error, send() should NOT silently fallback."""
channel, tmp_path = _make_channel_with_local_file()
channel._client.api._http = SimpleNamespace()
channel._client.api._http.request = AsyncMock(
side_effect=aiohttp.ServerDisconnectedError("connection lost"),
)
with pytest.raises(aiohttp.ServerDisconnectedError):
await channel.send(
OutboundMessage(
channel="qq",
chat_id="user1",
content="hello",
media=[tmp_path],
metadata={"message_id": "msg1"},
)
)
# No fallback text should have been sent
assert len(channel._client.api.c2c_calls) == 0

View File

@ -0,0 +1,304 @@
"""Tests for QQ channel media support: helpers, send, inbound, and upload."""
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
try:
from nanobot.channels import qq
QQ_AVAILABLE = getattr(qq, "QQ_AVAILABLE", False)
except ImportError:
QQ_AVAILABLE = False
if not QQ_AVAILABLE:
pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True)
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.qq import (
QQ_FILE_TYPE_FILE,
QQ_FILE_TYPE_IMAGE,
QQChannel,
QQConfig,
_guess_send_file_type,
_is_image_name,
_sanitize_filename,
)
class _FakeApi:
def __init__(self) -> None:
self.c2c_calls: list[dict] = []
self.group_calls: list[dict] = []
async def post_c2c_message(self, **kwargs) -> None:
self.c2c_calls.append(kwargs)
async def post_group_message(self, **kwargs) -> None:
self.group_calls.append(kwargs)
class _FakeHttp:
"""Fake _http for _post_base64file tests."""
def __init__(self, return_value: dict | None = None) -> None:
self.return_value = return_value or {}
self.calls: list[tuple] = []
async def request(self, route, **kwargs):
self.calls.append((route, kwargs))
return self.return_value
class _FakeClient:
def __init__(self, http_return: dict | None = None) -> None:
self.api = _FakeApi()
self.api._http = _FakeHttp(http_return)
# ── Helper function tests (pure, no async) ──────────────────────────
def test_sanitize_filename_strips_path_traversal() -> None:
assert _sanitize_filename("../../etc/passwd") == "passwd"
def test_sanitize_filename_keeps_chinese_chars() -> None:
assert _sanitize_filename("文件1.jpg") == "文件1.jpg"
def test_sanitize_filename_strips_unsafe_chars() -> None:
result = _sanitize_filename('file<>:"|?*.txt')
# All unsafe chars replaced with "_", but * is replaced too
assert result.startswith("file")
assert result.endswith(".txt")
assert "<" not in result
assert ">" not in result
assert '"' not in result
assert "|" not in result
assert "?" not in result
def test_sanitize_filename_empty_input() -> None:
assert _sanitize_filename("") == ""
def test_is_image_name_with_known_extensions() -> None:
for ext in (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".tif", ".tiff", ".ico", ".svg"):
assert _is_image_name(f"photo{ext}") is True
def test_is_image_name_with_unknown_extension() -> None:
for ext in (".pdf", ".txt", ".mp3", ".mp4"):
assert _is_image_name(f"doc{ext}") is False
def test_guess_send_file_type_image() -> None:
assert _guess_send_file_type("photo.png") == QQ_FILE_TYPE_IMAGE
assert _guess_send_file_type("pic.jpg") == QQ_FILE_TYPE_IMAGE
def test_guess_send_file_type_file() -> None:
assert _guess_send_file_type("doc.pdf") == QQ_FILE_TYPE_FILE
def test_guess_send_file_type_by_mime() -> None:
# A filename with no known extension but whose mime type is image/*
assert _guess_send_file_type("photo.xyz_image_test") == QQ_FILE_TYPE_FILE
# ── send() exception handling ───────────────────────────────────────
@pytest.mark.asyncio
async def test_send_exception_caught_not_raised() -> None:
"""Exceptions inside send() must not propagate."""
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient()
with patch.object(channel, "_send_text_only", new_callable=AsyncMock, side_effect=RuntimeError("boom")):
await channel.send(
OutboundMessage(channel="qq", chat_id="user1", content="hello")
)
# No exception raised — test passes if we get here.
@pytest.mark.asyncio
async def test_send_media_then_text() -> None:
"""Media is sent before text when both are present."""
import tempfile
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient()
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
f.write(b"\x89PNG\r\n")
tmp = f.name
try:
with patch.object(channel, "_post_base64file", new_callable=AsyncMock, return_value={"file_info": "1"}) as mock_upload:
await channel.send(
OutboundMessage(
channel="qq",
chat_id="user1",
content="text after image",
media=[tmp],
metadata={"message_id": "m1"},
)
)
assert mock_upload.called
# Text should have been sent via c2c (default chat type)
text_calls = [c for c in channel._client.api.c2c_calls if c.get("msg_type") == 0]
assert len(text_calls) >= 1
assert text_calls[-1]["content"] == "text after image"
finally:
import os
os.unlink(tmp)
@pytest.mark.asyncio
async def test_send_media_failure_falls_back_to_text() -> None:
"""When _send_media returns False, a failure notice is appended."""
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient()
with patch.object(channel, "_send_media", new_callable=AsyncMock, return_value=False):
await channel.send(
OutboundMessage(
channel="qq",
chat_id="user1",
content="hello",
media=["https://example.com/bad.png"],
metadata={"message_id": "m1"},
)
)
# Should have the failure text among the c2c calls
failure_calls = [c for c in channel._client.api.c2c_calls if "Attachment send failed" in c.get("content", "")]
assert len(failure_calls) == 1
assert "bad.png" in failure_calls[0]["content"]
# ── _on_message() exception handling ────────────────────────────────
@pytest.mark.asyncio
async def test_on_message_exception_caught_not_raised() -> None:
"""Missing required attributes should not crash _on_message."""
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient()
# Construct a message-like object that lacks 'author' — triggers AttributeError
bad_data = SimpleNamespace(id="x1", content="hi")
# Should not raise
await channel._on_message(bad_data, is_group=False)
@pytest.mark.asyncio
async def test_on_message_with_attachments() -> None:
"""Messages with attachments produce media_paths and formatted content."""
import tempfile
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient()
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
f.write(b"\x89PNG\r\n")
saved_path = f.name
att = SimpleNamespace(url="", filename="screenshot.png", content_type="image/png")
# Patch _download_to_media_dir_chunked to return the temp file path
async def fake_download(url, filename_hint=""):
return saved_path
try:
with patch.object(channel, "_download_to_media_dir_chunked", side_effect=fake_download):
data = SimpleNamespace(
id="att1",
content="look at this",
author=SimpleNamespace(user_openid="u1"),
attachments=[att],
)
await channel._on_message(data, is_group=False)
msg = await channel.bus.consume_inbound()
assert "look at this" in msg.content
assert "screenshot.png" in msg.content
assert "Received files:" in msg.content
assert len(msg.media) == 1
assert msg.media[0] == saved_path
finally:
import os
os.unlink(saved_path)
# ── _post_base64file() ─────────────────────────────────────────────
@pytest.mark.asyncio
async def test_post_base64file_omits_file_name_for_images() -> None:
"""file_type=1 (image) → payload must not contain file_name."""
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
channel._client = _FakeClient(http_return={"file_info": "img_abc"})
await channel._post_base64file(
chat_id="user1",
is_group=False,
file_type=QQ_FILE_TYPE_IMAGE,
file_data="ZmFrZQ==",
file_name="photo.png",
)
http = channel._client.api._http
assert len(http.calls) == 1
payload = http.calls[0][1]["json"]
assert "file_name" not in payload
assert payload["file_type"] == QQ_FILE_TYPE_IMAGE
@pytest.mark.asyncio
async def test_post_base64file_includes_file_name_for_files() -> None:
"""file_type=4 (file) → payload must contain file_name."""
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
channel._client = _FakeClient(http_return={"file_info": "file_abc"})
await channel._post_base64file(
chat_id="user1",
is_group=False,
file_type=QQ_FILE_TYPE_FILE,
file_data="ZmFrZQ==",
file_name="report.pdf",
)
http = channel._client.api._http
assert len(http.calls) == 1
payload = http.calls[0][1]["json"]
assert payload["file_name"] == "report.pdf"
assert payload["file_type"] == QQ_FILE_TYPE_FILE
@pytest.mark.asyncio
async def test_post_base64file_filters_response_to_file_info() -> None:
"""Response with file_info + extra fields must be filtered to only file_info."""
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
channel._client = _FakeClient(http_return={
"file_info": "fi_123",
"file_uuid": "uuid_xxx",
"ttl": 3600,
})
result = await channel._post_base64file(
chat_id="user1",
is_group=False,
file_type=QQ_FILE_TYPE_FILE,
file_data="ZmFrZQ==",
file_name="doc.pdf",
)
assert result == {"file_info": "fi_123"}
assert "file_uuid" not in result
assert "ttl" not in result

View File

@ -10,8 +10,7 @@ except ImportError:
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.slack import SlackChannel
from nanobot.channels.slack import SlackConfig
from nanobot.channels.slack import SlackChannel, SlackConfig
class _FakeAsyncWebClient:
@ -20,6 +19,12 @@ class _FakeAsyncWebClient:
self.file_upload_calls: list[dict[str, object | None]] = []
self.reactions_add_calls: list[dict[str, object | None]] = []
self.reactions_remove_calls: list[dict[str, object | None]] = []
self.conversations_list_calls: list[dict[str, object | None]] = []
self.users_list_calls: list[dict[str, object | None]] = []
self.conversations_open_calls: list[dict[str, object | None]] = []
self._conversations_pages: list[dict[str, object]] = []
self._users_pages: list[dict[str, object]] = []
self._open_dm_response: dict[str, object] = {"channel": {"id": "D_OPENED"}}
async def chat_postMessage(
self,
@ -81,6 +86,22 @@ class _FakeAsyncWebClient:
}
)
async def conversations_list(self, **kwargs):
self.conversations_list_calls.append(kwargs)
if self._conversations_pages:
return self._conversations_pages.pop(0)
return {"channels": [], "response_metadata": {"next_cursor": ""}}
async def users_list(self, **kwargs):
self.users_list_calls.append(kwargs)
if self._users_pages:
return self._users_pages.pop(0)
return {"members": [], "response_metadata": {"next_cursor": ""}}
async def conversations_open(self, **kwargs):
self.conversations_open_calls.append(kwargs)
return self._open_dm_response
@pytest.mark.asyncio
async def test_send_uses_thread_for_channel_messages() -> None:
@ -151,3 +172,147 @@ async def test_send_updates_reaction_when_final_response_sent() -> None:
assert fake_web.reactions_add_calls == [
{"channel": "C123", "name": "white_check_mark", "timestamp": "1700000000.000100"}
]
@pytest.mark.asyncio
async def test_send_resolves_channel_name_to_channel_id() -> None:
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
fake_web = _FakeAsyncWebClient()
fake_web._conversations_pages = [
{
"channels": [{"id": "C999", "name": "channel_x"}],
"response_metadata": {"next_cursor": ""},
}
]
channel._web_client = fake_web
await channel.send(
OutboundMessage(
channel="slack",
chat_id="#channel_x",
content="hello",
)
)
assert fake_web.chat_post_calls == [
{"channel": "C999", "text": "hello\n", "thread_ts": None}
]
assert len(fake_web.conversations_list_calls) == 1
@pytest.mark.asyncio
async def test_send_resolves_user_handle_to_dm_channel() -> None:
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
fake_web = _FakeAsyncWebClient()
fake_web._users_pages = [
{
"members": [
{
"id": "U234",
"name": "alice",
"profile": {"display_name": "Alice"},
}
],
"response_metadata": {"next_cursor": ""},
}
]
fake_web._open_dm_response = {"channel": {"id": "D234"}}
channel._web_client = fake_web
await channel.send(
OutboundMessage(
channel="slack",
chat_id="@alice",
content="hello",
)
)
assert fake_web.conversations_open_calls == [{"users": "U234"}]
assert fake_web.chat_post_calls == [
{"channel": "D234", "text": "hello\n", "thread_ts": None}
]
@pytest.mark.asyncio
async def test_send_updates_reaction_on_origin_channel_for_cross_channel_send() -> None:
channel = SlackChannel(SlackConfig(enabled=True, react_emoji="eyes"), MessageBus())
fake_web = _FakeAsyncWebClient()
fake_web._conversations_pages = [
{
"channels": [{"id": "C999", "name": "channel_x"}],
"response_metadata": {"next_cursor": ""},
}
]
channel._web_client = fake_web
await channel.send(
OutboundMessage(
channel="slack",
chat_id="channel_x",
content="done",
metadata={
"slack": {
"event": {"ts": "1700000000.000100", "channel": "D_ORIGIN"},
"channel_type": "im",
},
},
)
)
assert fake_web.chat_post_calls == [
{"channel": "C999", "text": "done\n", "thread_ts": None}
]
assert fake_web.reactions_remove_calls == [
{"channel": "D_ORIGIN", "name": "eyes", "timestamp": "1700000000.000100"}
]
assert fake_web.reactions_add_calls == [
{"channel": "D_ORIGIN", "name": "white_check_mark", "timestamp": "1700000000.000100"}
]
@pytest.mark.asyncio
async def test_send_does_not_reuse_origin_thread_ts_for_cross_channel_send() -> None:
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
fake_web = _FakeAsyncWebClient()
fake_web._conversations_pages = [
{
"channels": [{"id": "C999", "name": "channel_x"}],
"response_metadata": {"next_cursor": ""},
}
]
channel._web_client = fake_web
await channel.send(
OutboundMessage(
channel="slack",
chat_id="channel_x",
content="done",
metadata={
"slack": {
"event": {"ts": "1700000000.000100", "channel": "C_ORIGIN"},
"thread_ts": "1700000000.000200",
"channel_type": "channel",
},
},
)
)
assert fake_web.chat_post_calls == [
{"channel": "C999", "text": "done\n", "thread_ts": None}
]
@pytest.mark.asyncio
async def test_send_raises_when_named_target_cannot_be_resolved() -> None:
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
fake_web = _FakeAsyncWebClient()
channel._web_client = fake_web
with pytest.raises(ValueError, match="was not found"):
await channel.send(
OutboundMessage(
channel="slack",
chat_id="#missing-channel",
content="hello",
)
)

View File

@ -387,6 +387,84 @@ async def test_send_delta_stream_end_treats_not_modified_as_success() -> None:
assert "123" not in channel._stream_bufs
@pytest.mark.asyncio
async def test_send_delta_stream_end_does_not_fallback_on_network_timeout() -> None:
"""TimedOut during HTML edit should propagate, never fall back to plain text."""
from telegram.error import TimedOut
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
# _call_with_retry retries TimedOut up to 3 times, so the mock will be called
# multiple times but all calls must be with parse_mode="HTML" (no plain fallback).
channel._app.bot.edit_message_text = AsyncMock(side_effect=TimedOut("network timeout"))
channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0)
with pytest.raises(TimedOut, match="network timeout"):
await channel.send_delta("123", "", {"_stream_end": True})
# Every call to edit_message_text must have used parse_mode="HTML" —
# no plain-text fallback call should have been made.
for call in channel._app.bot.edit_message_text.call_args_list:
assert call.kwargs.get("parse_mode") == "HTML"
# Buffer should still be present (not cleaned up on error)
assert "123" in channel._stream_bufs
@pytest.mark.asyncio
async def test_send_delta_stream_end_does_not_fallback_on_network_error() -> None:
"""NetworkError during HTML edit should propagate, never fall back to plain text."""
from telegram.error import NetworkError
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
channel._app.bot.edit_message_text = AsyncMock(side_effect=NetworkError("connection reset"))
channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0)
with pytest.raises(NetworkError, match="connection reset"):
await channel.send_delta("123", "", {"_stream_end": True})
# Every call to edit_message_text must have used parse_mode="HTML" —
# no plain-text fallback call should have been made.
for call in channel._app.bot.edit_message_text.call_args_list:
assert call.kwargs.get("parse_mode") == "HTML"
# Buffer should still be present (not cleaned up on error)
assert "123" in channel._stream_bufs
@pytest.mark.asyncio
async def test_send_delta_stream_end_falls_back_on_bad_request() -> None:
"""BadRequest (HTML parse error) should still trigger plain-text fallback."""
from telegram.error import BadRequest
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
# First call (HTML) raises BadRequest, second call (plain) succeeds
channel._app.bot.edit_message_text = AsyncMock(
side_effect=[BadRequest("Can't parse entities"), None]
)
channel._stream_bufs["123"] = _StreamBuf(text="hello <bad>", message_id=7, last_edit=0.0)
await channel.send_delta("123", "", {"_stream_end": True})
# edit_message_text should have been called twice: once for HTML, once for plain fallback
assert channel._app.bot.edit_message_text.call_count == 2
# Second call should not use parse_mode="HTML"
second_call_kwargs = channel._app.bot.edit_message_text.call_args_list[1].kwargs
assert "parse_mode" not in second_call_kwargs or second_call_kwargs.get("parse_mode") is None
# Buffer should be cleaned up on success
assert "123" not in channel._stream_bufs
@pytest.mark.asyncio
async def test_send_delta_stream_end_splits_oversized_reply() -> None:
"""Final streamed reply exceeding Telegram limit is split into chunks."""
@ -1159,3 +1237,159 @@ async def test_on_message_location_with_text() -> None:
assert len(handled) == 1
assert "meet me here" in handled[0]["content"]
assert "[location: 51.5074, -0.1278]" in handled[0]["content"]
# ---------------------------------------------------------------------------
# Tests for retry amplification fix (issue #3050)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_send_text_does_not_fallback_on_network_timeout() -> None:
"""TimedOut should propagate immediately, NOT trigger plain-text fallback.
Before the fix, _send_text caught ALL exceptions (including TimedOut)
and retried as plain text, doubling connection demand during pool
exhaustion see issue #3050.
"""
from telegram.error import TimedOut
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
call_count = 0
async def always_timeout(**kwargs):
nonlocal call_count
call_count += 1
raise TimedOut()
channel._app.bot.send_message = always_timeout
import nanobot.channels.telegram as tg_mod
orig_delay = tg_mod._SEND_RETRY_BASE_DELAY
tg_mod._SEND_RETRY_BASE_DELAY = 0.01
try:
with pytest.raises(TimedOut):
await channel._send_text(123, "hello", None, {})
finally:
tg_mod._SEND_RETRY_BASE_DELAY = orig_delay
# With the fix: only _call_with_retry's 3 HTML attempts (no plain fallback).
# Before the fix: 3 HTML + 3 plain = 6 attempts.
assert call_count == 3, (
f"Expected 3 calls (HTML retries only), got {call_count} "
"(plain-text fallback should not trigger on TimedOut)"
)
@pytest.mark.asyncio
async def test_send_text_does_not_fallback_on_network_error() -> None:
"""NetworkError should propagate immediately, NOT trigger plain-text fallback."""
from telegram.error import NetworkError
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
call_count = 0
async def always_network_error(**kwargs):
nonlocal call_count
call_count += 1
raise NetworkError("Connection reset")
channel._app.bot.send_message = always_network_error
import nanobot.channels.telegram as tg_mod
orig_delay = tg_mod._SEND_RETRY_BASE_DELAY
tg_mod._SEND_RETRY_BASE_DELAY = 0.01
try:
with pytest.raises(NetworkError):
await channel._send_text(123, "hello", None, {})
finally:
tg_mod._SEND_RETRY_BASE_DELAY = orig_delay
# _call_with_retry does NOT retry NetworkError (only TimedOut/RetryAfter),
# so it raises after 1 attempt. The fix prevents plain-text fallback.
# Before the fix: 1 HTML + 1 plain = 2. After the fix: 1 HTML only.
assert call_count == 1, (
f"Expected 1 call (HTML only, no plain fallback), got {call_count}"
)
@pytest.mark.asyncio
async def test_send_text_falls_back_on_bad_request() -> None:
"""BadRequest (HTML parse error) should still trigger plain-text fallback."""
from telegram.error import BadRequest
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
original_send = channel._app.bot.send_message
html_call_count = 0
async def html_fails(**kwargs):
nonlocal html_call_count
if kwargs.get("parse_mode") == "HTML":
html_call_count += 1
raise BadRequest("Can't parse entities")
return await original_send(**kwargs)
channel._app.bot.send_message = html_fails
import nanobot.channels.telegram as tg_mod
orig_delay = tg_mod._SEND_RETRY_BASE_DELAY
tg_mod._SEND_RETRY_BASE_DELAY = 0.01
try:
await channel._send_text(123, "hello **world**", None, {})
finally:
tg_mod._SEND_RETRY_BASE_DELAY = orig_delay
# HTML attempt failed with BadRequest → fallback to plain text succeeds.
assert html_call_count == 1, f"Expected 1 HTML attempt, got {html_call_count}"
assert len(channel._app.bot.sent_messages) == 1
# Plain text send should NOT have parse_mode
assert channel._app.bot.sent_messages[0].get("parse_mode") is None
@pytest.mark.asyncio
async def test_send_text_bad_request_plain_fallback_exhausted() -> None:
"""When both HTML and plain-text fallback fail with BadRequest, the error propagates."""
from telegram.error import BadRequest
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
call_count = 0
async def always_bad_request(**kwargs):
nonlocal call_count
call_count += 1
raise BadRequest("Bad request")
channel._app.bot.send_message = always_bad_request
import nanobot.channels.telegram as tg_mod
orig_delay = tg_mod._SEND_RETRY_BASE_DELAY
tg_mod._SEND_RETRY_BASE_DELAY = 0.01
try:
with pytest.raises(BadRequest):
await channel._send_text(123, "hello", None, {})
finally:
tg_mod._SEND_RETRY_BASE_DELAY = orig_delay
# _call_with_retry does NOT retry BadRequest (only TimedOut/RetryAfter),
# so HTML fails after 1 attempt → fallback to plain also fails after 1 attempt.
# Before the fix: 2 total. After the fix: still 2 (BadRequest SHOULD fallback).
assert call_count == 2, f"Expected 2 calls (1 HTML + 1 plain), got {call_count}"

View File

@ -0,0 +1,598 @@
"""Unit and lightweight integration tests for the WebSocket channel."""
import asyncio
import functools
import json
import time
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import httpx
import pytest
import websockets
from websockets.exceptions import ConnectionClosed
from websockets.frames import Close
from nanobot.bus.events import OutboundMessage
from nanobot.channels.websocket import (
WebSocketChannel,
WebSocketConfig,
_issue_route_secret_matches,
_normalize_config_path,
_normalize_http_path,
_parse_inbound_payload,
_parse_query,
_parse_request_path,
)
# -- Shared helpers (aligned with test_websocket_integration.py) ---------------
_PORT = 29876
def _ch(bus: Any, **kw: Any) -> WebSocketChannel:
cfg: dict[str, Any] = {
"enabled": True,
"allowFrom": ["*"],
"host": "127.0.0.1",
"port": _PORT,
"path": "/ws",
"websocketRequiresToken": False,
}
cfg.update(kw)
return WebSocketChannel(cfg, bus)
@pytest.fixture()
def bus() -> MagicMock:
b = MagicMock()
b.publish_inbound = AsyncMock()
return b
async def _http_get(url: str, headers: dict[str, str] | None = None) -> httpx.Response:
"""Run GET in a thread to avoid blocking the asyncio loop shared with websockets."""
return await asyncio.to_thread(
functools.partial(httpx.get, url, headers=headers or {}, timeout=5.0)
)
def test_normalize_http_path_strips_trailing_slash_except_root() -> None:
assert _normalize_http_path("/chat/") == "/chat"
assert _normalize_http_path("/chat?x=1") == "/chat"
assert _normalize_http_path("/") == "/"
def test_parse_request_path_matches_normalize_and_query() -> None:
path, query = _parse_request_path("/ws/?token=secret&client_id=u1")
assert path == _normalize_http_path("/ws/?token=secret&client_id=u1")
assert query == _parse_query("/ws/?token=secret&client_id=u1")
def test_normalize_config_path_matches_request() -> None:
assert _normalize_config_path("/ws/") == "/ws"
assert _normalize_config_path("/") == "/"
def test_parse_query_extracts_token_and_client_id() -> None:
query = _parse_query("/?token=secret&client_id=u1")
assert query.get("token") == ["secret"]
assert query.get("client_id") == ["u1"]
@pytest.mark.parametrize(
("raw", "expected"),
[
("plain", "plain"),
('{"content": "hi"}', "hi"),
('{"text": "there"}', "there"),
('{"message": "x"}', "x"),
(" ", None),
("{}", None),
],
)
def test_parse_inbound_payload(raw: str, expected: str | None) -> None:
assert _parse_inbound_payload(raw) == expected
def test_parse_inbound_invalid_json_falls_back_to_raw_string() -> None:
assert _parse_inbound_payload("{not json") == "{not json"
@pytest.mark.parametrize(
("raw", "expected"),
[
('{"content": ""}', None), # empty string content
('{"content": 123}', None), # non-string content
('{"content": " "}', None), # whitespace-only content
('["hello"]', '["hello"]'), # JSON array: not a dict, treated as plain text
('{"unknown_key": "val"}', None), # unrecognized key
('{"content": null}', None), # null content
],
)
def test_parse_inbound_payload_edge_cases(raw: str, expected: str | None) -> None:
assert _parse_inbound_payload(raw) == expected
def test_web_socket_config_path_must_start_with_slash() -> None:
with pytest.raises(ValueError, match='path must start with "/"'):
WebSocketConfig(path="bad")
def test_ssl_context_requires_both_cert_and_key_files() -> None:
bus = MagicMock()
channel = WebSocketChannel(
{"enabled": True, "allowFrom": ["*"], "sslCertfile": "/tmp/c.pem", "sslKeyfile": ""},
bus,
)
with pytest.raises(ValueError, match="ssl_certfile and ssl_keyfile"):
channel._build_ssl_context()
def test_default_config_includes_safe_bind_and_streaming() -> None:
defaults = WebSocketChannel.default_config()
assert defaults["enabled"] is False
assert defaults["host"] == "127.0.0.1"
assert defaults["streaming"] is True
assert defaults["allowFrom"] == ["*"]
assert defaults.get("tokenIssuePath", "") == ""
def test_token_issue_path_must_differ_from_websocket_path() -> None:
with pytest.raises(ValueError, match="token_issue_path must differ"):
WebSocketConfig(path="/ws", token_issue_path="/ws")
def test_issue_route_secret_matches_bearer_and_header() -> None:
from websockets.datastructures import Headers
secret = "my-secret"
bearer_headers = Headers([("Authorization", "Bearer my-secret")])
assert _issue_route_secret_matches(bearer_headers, secret) is True
x_headers = Headers([("X-Nanobot-Auth", "my-secret")])
assert _issue_route_secret_matches(x_headers, secret) is True
wrong = Headers([("Authorization", "Bearer other")])
assert _issue_route_secret_matches(wrong, secret) is False
def test_issue_route_secret_matches_empty_secret() -> None:
from websockets.datastructures import Headers
# Empty secret always returns True regardless of headers
assert _issue_route_secret_matches(Headers([]), "") is True
assert _issue_route_secret_matches(Headers([("Authorization", "Bearer anything")]), "") is True
@pytest.mark.asyncio
async def test_send_delivers_json_message_with_media_and_reply() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
mock_ws = AsyncMock()
channel._connections["chat-1"] = mock_ws
msg = OutboundMessage(
channel="websocket",
chat_id="chat-1",
content="hello",
reply_to="m1",
media=["/tmp/a.png"],
)
await channel.send(msg)
mock_ws.send.assert_awaited_once()
payload = json.loads(mock_ws.send.call_args[0][0])
assert payload["event"] == "message"
assert payload["text"] == "hello"
assert payload["reply_to"] == "m1"
assert payload["media"] == ["/tmp/a.png"]
@pytest.mark.asyncio
async def test_send_missing_connection_is_noop_without_error() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
msg = OutboundMessage(channel="websocket", chat_id="missing", content="x")
await channel.send(msg)
@pytest.mark.asyncio
async def test_send_removes_connection_on_connection_closed() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
mock_ws = AsyncMock()
mock_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True)
channel._connections["chat-1"] = mock_ws
msg = OutboundMessage(channel="websocket", chat_id="chat-1", content="hello")
await channel.send(msg)
assert "chat-1" not in channel._connections
@pytest.mark.asyncio
async def test_send_delta_removes_connection_on_connection_closed() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus)
mock_ws = AsyncMock()
mock_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True)
channel._connections["chat-1"] = mock_ws
await channel.send_delta("chat-1", "chunk", {"_stream_delta": True, "_stream_id": "s1"})
assert "chat-1" not in channel._connections
@pytest.mark.asyncio
async def test_send_delta_emits_delta_and_stream_end() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus)
mock_ws = AsyncMock()
channel._connections["chat-1"] = mock_ws
await channel.send_delta("chat-1", "part", {"_stream_delta": True, "_stream_id": "sid"})
await channel.send_delta("chat-1", "", {"_stream_end": True, "_stream_id": "sid"})
assert mock_ws.send.await_count == 2
first = json.loads(mock_ws.send.call_args_list[0][0][0])
second = json.loads(mock_ws.send.call_args_list[1][0][0])
assert first["event"] == "delta"
assert first["text"] == "part"
assert first["stream_id"] == "sid"
assert second["event"] == "stream_end"
assert second["stream_id"] == "sid"
@pytest.mark.asyncio
async def test_send_non_connection_closed_exception_is_raised() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
mock_ws = AsyncMock()
mock_ws.send.side_effect = RuntimeError("unexpected")
channel._connections["chat-1"] = mock_ws
msg = OutboundMessage(channel="websocket", chat_id="chat-1", content="hello")
with pytest.raises(RuntimeError, match="unexpected"):
await channel.send(msg)
@pytest.mark.asyncio
async def test_send_delta_missing_connection_is_noop() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus)
# No exception, no error — just a no-op
await channel.send_delta("nonexistent", "chunk", {"_stream_delta": True, "_stream_id": "s1"})
@pytest.mark.asyncio
async def test_stop_is_idempotent() -> None:
bus = MagicMock()
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
# stop() before start() should not raise
await channel.stop()
await channel.stop()
@pytest.mark.asyncio
async def test_end_to_end_client_receives_ready_and_agent_sees_inbound(bus: MagicMock) -> None:
port = 29876
channel = _ch(bus, port=port)
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=tester") as client:
ready_raw = await client.recv()
ready = json.loads(ready_raw)
assert ready["event"] == "ready"
assert ready["client_id"] == "tester"
chat_id = ready["chat_id"]
await client.send(json.dumps({"content": "ping from client"}))
await asyncio.sleep(0.08)
bus.publish_inbound.assert_awaited()
inbound = bus.publish_inbound.call_args[0][0]
assert inbound.channel == "websocket"
assert inbound.sender_id == "tester"
assert inbound.chat_id == chat_id
assert inbound.content == "ping from client"
await client.send("plain text frame")
await asyncio.sleep(0.08)
assert bus.publish_inbound.await_count >= 2
second = [c[0][0] for c in bus.publish_inbound.call_args_list][-1]
assert second.content == "plain text frame"
finally:
await channel.stop()
await server_task
@pytest.mark.asyncio
async def test_token_rejects_handshake_when_mismatch(bus: MagicMock) -> None:
port = 29877
channel = _ch(bus, port=port, path="/", token="secret")
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
with pytest.raises(websockets.exceptions.InvalidStatus) as excinfo:
async with websockets.connect(f"ws://127.0.0.1:{port}/?token=wrong"):
pass
assert excinfo.value.response.status_code == 401
finally:
await channel.stop()
await server_task
@pytest.mark.asyncio
async def test_wrong_path_returns_404(bus: MagicMock) -> None:
port = 29878
channel = _ch(bus, port=port)
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
with pytest.raises(websockets.exceptions.InvalidStatus) as excinfo:
async with websockets.connect(f"ws://127.0.0.1:{port}/other"):
pass
assert excinfo.value.response.status_code == 404
finally:
await channel.stop()
await server_task
def test_registry_discovers_websocket_channel() -> None:
from nanobot.channels.registry import load_channel_class
cls = load_channel_class("websocket")
assert cls.name == "websocket"
@pytest.mark.asyncio
async def test_http_route_issues_token_then_websocket_requires_it(bus: MagicMock) -> None:
port = 29879
channel = _ch(
bus, port=port,
tokenIssuePath="/auth/token",
tokenIssueSecret="route-secret",
websocketRequiresToken=True,
)
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
deny = await _http_get(f"http://127.0.0.1:{port}/auth/token")
assert deny.status_code == 401
issue = await _http_get(
f"http://127.0.0.1:{port}/auth/token",
headers={"Authorization": "Bearer route-secret"},
)
assert issue.status_code == 200
token = issue.json()["token"]
assert token.startswith("nbwt_")
with pytest.raises(websockets.exceptions.InvalidStatus) as missing_token:
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=x"):
pass
assert missing_token.value.response.status_code == 401
uri = f"ws://127.0.0.1:{port}/ws?token={token}&client_id=caller"
async with websockets.connect(uri) as client:
ready = json.loads(await client.recv())
assert ready["event"] == "ready"
assert ready["client_id"] == "caller"
with pytest.raises(websockets.exceptions.InvalidStatus) as reuse:
async with websockets.connect(uri):
pass
assert reuse.value.response.status_code == 401
finally:
await channel.stop()
await server_task
@pytest.mark.asyncio
async def test_end_to_end_server_pushes_streaming_deltas_to_client(bus: MagicMock) -> None:
port = 29880
channel = _ch(bus, port=port, streaming=True)
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=stream-tester") as client:
ready_raw = await client.recv()
ready = json.loads(ready_raw)
chat_id = ready["chat_id"]
# Server pushes deltas directly
await channel.send_delta(
chat_id, "Hello ", {"_stream_delta": True, "_stream_id": "s1"}
)
await channel.send_delta(
chat_id, "world", {"_stream_delta": True, "_stream_id": "s1"}
)
await channel.send_delta(
chat_id, "", {"_stream_end": True, "_stream_id": "s1"}
)
delta1 = json.loads(await client.recv())
assert delta1["event"] == "delta"
assert delta1["text"] == "Hello "
assert delta1["stream_id"] == "s1"
delta2 = json.loads(await client.recv())
assert delta2["event"] == "delta"
assert delta2["text"] == "world"
assert delta2["stream_id"] == "s1"
end = json.loads(await client.recv())
assert end["event"] == "stream_end"
assert end["stream_id"] == "s1"
finally:
await channel.stop()
await server_task
@pytest.mark.asyncio
async def test_token_issue_rejects_when_at_capacity(bus: MagicMock) -> None:
port = 29881
channel = _ch(bus, port=port, tokenIssuePath="/auth/token", tokenIssueSecret="s")
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
# Fill issued tokens to capacity
channel._issued_tokens = {
f"nbwt_fill_{i}": time.monotonic() + 300 for i in range(channel._MAX_ISSUED_TOKENS)
}
resp = await _http_get(
f"http://127.0.0.1:{port}/auth/token",
headers={"Authorization": "Bearer s"},
)
assert resp.status_code == 429
data = resp.json()
assert "error" in data
finally:
await channel.stop()
await server_task
@pytest.mark.asyncio
async def test_allow_from_rejects_unauthorized_client_id(bus: MagicMock) -> None:
port = 29882
channel = _ch(bus, port=port, allowFrom=["alice", "bob"])
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=eve"):
pass
assert exc_info.value.response.status_code == 403
finally:
await channel.stop()
await server_task
@pytest.mark.asyncio
async def test_client_id_truncation(bus: MagicMock) -> None:
port = 29883
channel = _ch(bus, port=port)
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
long_id = "x" * 200
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id={long_id}") as client:
ready = json.loads(await client.recv())
assert ready["client_id"] == "x" * 128
assert len(ready["client_id"]) == 128
finally:
await channel.stop()
await server_task
@pytest.mark.asyncio
async def test_non_utf8_binary_frame_ignored(bus: MagicMock) -> None:
port = 29884
channel = _ch(bus, port=port)
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=bin-test") as client:
await client.recv() # consume ready
# Send non-UTF-8 bytes
await client.send(b"\xff\xfe\xfd")
await asyncio.sleep(0.05)
# publish_inbound should NOT have been called
bus.publish_inbound.assert_not_awaited()
finally:
await channel.stop()
await server_task
@pytest.mark.asyncio
async def test_static_token_accepts_issued_token_as_fallback(bus: MagicMock) -> None:
port = 29885
channel = _ch(
bus, port=port,
token="static-secret",
tokenIssuePath="/auth/token",
tokenIssueSecret="route-secret",
)
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
# Get an issued token
resp = await _http_get(
f"http://127.0.0.1:{port}/auth/token",
headers={"Authorization": "Bearer route-secret"},
)
assert resp.status_code == 200
issued_token = resp.json()["token"]
# Connect using issued token (not the static one)
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?token={issued_token}&client_id=caller") as client:
ready = json.loads(await client.recv())
assert ready["event"] == "ready"
finally:
await channel.stop()
await server_task
@pytest.mark.asyncio
async def test_allow_from_empty_list_denies_all(bus: MagicMock) -> None:
port = 29886
channel = _ch(bus, port=port, allowFrom=[])
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=anyone"):
pass
assert exc_info.value.response.status_code == 403
finally:
await channel.stop()
await server_task
@pytest.mark.asyncio
async def test_websocket_requires_token_without_issue_path(bus: MagicMock) -> None:
"""When websocket_requires_token is True but no token or issue path configured, all connections are rejected."""
port = 29887
channel = _ch(bus, port=port, websocketRequiresToken=True)
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
# No token at all → 401
with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=u"):
pass
assert exc_info.value.response.status_code == 401
# Wrong token → 401
with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=u&token=wrong"):
pass
assert exc_info.value.response.status_code == 401
finally:
await channel.stop()
await server_task

View File

@ -0,0 +1,478 @@
"""Integration tests for the WebSocket channel using WsTestClient.
Complements the unit/lightweight tests in test_websocket_channel.py by covering
multi-client scenarios, edge cases, and realistic usage patterns.
"""
from __future__ import annotations
import asyncio
import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
import websockets
from nanobot.channels.websocket import WebSocketChannel
from nanobot.bus.events import OutboundMessage
from ws_test_client import WsTestClient, issue_token, issue_token_ok
def _ch(bus: Any, port: int, **kw: Any) -> WebSocketChannel:
cfg: dict[str, Any] = {
"enabled": True,
"allowFrom": ["*"],
"host": "127.0.0.1",
"port": port,
"path": "/",
"websocketRequiresToken": False,
}
cfg.update(kw)
return WebSocketChannel(cfg, bus)
@pytest.fixture()
def bus() -> MagicMock:
b = MagicMock()
b.publish_inbound = AsyncMock()
return b
# -- Connection basics ----------------------------------------------------
@pytest.mark.asyncio
async def test_ready_event_fields(bus: MagicMock) -> None:
ch = _ch(bus, 29901)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29901/", client_id="c1") as c:
r = await c.recv_ready()
assert r.event == "ready"
assert len(r.chat_id) == 36
assert r.client_id == "c1"
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_anonymous_client_gets_generated_id(bus: MagicMock) -> None:
ch = _ch(bus, 29902)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29902/", client_id="") as c:
r = await c.recv_ready()
assert r.client_id.startswith("anon-")
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_each_connection_unique_chat_id(bus: MagicMock) -> None:
ch = _ch(bus, 29903)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29903/", client_id="a") as c1:
async with WsTestClient("ws://127.0.0.1:29903/", client_id="b") as c2:
assert (await c1.recv_ready()).chat_id != (await c2.recv_ready()).chat_id
finally:
await ch.stop(); await t
# -- Inbound messages (client -> server) ----------------------------------
@pytest.mark.asyncio
async def test_plain_text(bus: MagicMock) -> None:
ch = _ch(bus, 29904)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29904/", client_id="p") as c:
await c.recv_ready()
await c.send_text("hello world")
await asyncio.sleep(0.1)
inbound = bus.publish_inbound.call_args[0][0]
assert inbound.content == "hello world"
assert inbound.sender_id == "p"
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_json_content_field(bus: MagicMock) -> None:
ch = _ch(bus, 29905)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29905/", client_id="j") as c:
await c.recv_ready()
await c.send_json({"content": "structured"})
await asyncio.sleep(0.1)
assert bus.publish_inbound.call_args[0][0].content == "structured"
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_json_text_and_message_fields(bus: MagicMock) -> None:
ch = _ch(bus, 29906)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29906/", client_id="x") as c:
await c.recv_ready()
await c.send_json({"text": "via text"})
await asyncio.sleep(0.1)
assert bus.publish_inbound.call_args[0][0].content == "via text"
await c.send_json({"message": "via message"})
await asyncio.sleep(0.1)
assert bus.publish_inbound.call_args[0][0].content == "via message"
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_empty_payload_ignored(bus: MagicMock) -> None:
ch = _ch(bus, 29907)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29907/", client_id="e") as c:
await c.recv_ready()
await c.send_text(" ")
await c.send_json({})
await asyncio.sleep(0.1)
bus.publish_inbound.assert_not_awaited()
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_messages_preserve_order(bus: MagicMock) -> None:
ch = _ch(bus, 29908)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29908/", client_id="o") as c:
await c.recv_ready()
for i in range(5):
await c.send_text(f"msg-{i}")
await asyncio.sleep(0.2)
contents = [call[0][0].content for call in bus.publish_inbound.call_args_list]
assert contents == [f"msg-{i}" for i in range(5)]
finally:
await ch.stop(); await t
# -- Outbound messages (server -> client) ---------------------------------
@pytest.mark.asyncio
async def test_server_send_message(bus: MagicMock) -> None:
ch = _ch(bus, 29909)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29909/", client_id="r") as c:
ready = await c.recv_ready()
await ch.send(OutboundMessage(
channel="websocket", chat_id=ready.chat_id, content="reply",
))
msg = await c.recv_message()
assert msg.text == "reply"
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_server_send_with_media_and_reply(bus: MagicMock) -> None:
ch = _ch(bus, 29910)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29910/", client_id="m") as c:
ready = await c.recv_ready()
await ch.send(OutboundMessage(
channel="websocket", chat_id=ready.chat_id, content="img",
media=["/tmp/a.png"], reply_to="m1",
))
msg = await c.recv_message()
assert msg.text == "img"
assert msg.media == ["/tmp/a.png"]
assert msg.reply_to == "m1"
finally:
await ch.stop(); await t
# -- Streaming ------------------------------------------------------------
@pytest.mark.asyncio
async def test_streaming_deltas_and_end(bus: MagicMock) -> None:
ch = _ch(bus, 29911, streaming=True)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29911/", client_id="s") as c:
cid = (await c.recv_ready()).chat_id
for part in ("Hello", " ", "world", "!"):
await ch.send_delta(cid, part, {"_stream_delta": True, "_stream_id": "s1"})
await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "s1"})
msgs = await c.collect_stream()
deltas = [m for m in msgs if m.event == "delta"]
assert "".join(d.text for d in deltas) == "Hello world!"
ends = [m for m in msgs if m.event == "stream_end"]
assert len(ends) == 1
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_interleaved_streams(bus: MagicMock) -> None:
ch = _ch(bus, 29912, streaming=True)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29912/", client_id="i") as c:
cid = (await c.recv_ready()).chat_id
await ch.send_delta(cid, "A1", {"_stream_delta": True, "_stream_id": "sa"})
await ch.send_delta(cid, "B1", {"_stream_delta": True, "_stream_id": "sb"})
await ch.send_delta(cid, "A2", {"_stream_delta": True, "_stream_id": "sa"})
await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "sa"})
await ch.send_delta(cid, "B2", {"_stream_delta": True, "_stream_id": "sb"})
await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "sb"})
msgs = await c.recv_n(6)
sa = "".join(m.text for m in msgs if m.event == "delta" and m.stream_id == "sa")
sb = "".join(m.text for m in msgs if m.event == "delta" and m.stream_id == "sb")
assert sa == "A1A2"
assert sb == "B1B2"
finally:
await ch.stop(); await t
# -- Multi-client ---------------------------------------------------------
@pytest.mark.asyncio
async def test_independent_sessions(bus: MagicMock) -> None:
ch = _ch(bus, 29913)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29913/", client_id="u1") as c1:
async with WsTestClient("ws://127.0.0.1:29913/", client_id="u2") as c2:
r1, r2 = await c1.recv_ready(), await c2.recv_ready()
await ch.send(OutboundMessage(
channel="websocket", chat_id=r1.chat_id, content="for-u1",
))
assert (await c1.recv_message()).text == "for-u1"
await ch.send(OutboundMessage(
channel="websocket", chat_id=r2.chat_id, content="for-u2",
))
assert (await c2.recv_message()).text == "for-u2"
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_disconnected_client_cleanup(bus: MagicMock) -> None:
ch = _ch(bus, 29914)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29914/", client_id="tmp") as c:
chat_id = (await c.recv_ready()).chat_id
# disconnected
await ch.send(OutboundMessage(
channel="websocket", chat_id=chat_id, content="orphan",
))
assert chat_id not in ch._connections
finally:
await ch.stop(); await t
# -- Authentication -------------------------------------------------------
@pytest.mark.asyncio
async def test_static_token_accepted(bus: MagicMock) -> None:
ch = _ch(bus, 29915, token="secret")
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29915/", client_id="a", token="secret") as c:
assert (await c.recv_ready()).client_id == "a"
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_static_token_rejected(bus: MagicMock) -> None:
ch = _ch(bus, 29916, token="correct")
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
async with WsTestClient("ws://127.0.0.1:29916/", client_id="b", token="wrong"):
pass
assert exc.value.response.status_code == 401
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_token_issue_full_flow(bus: MagicMock) -> None:
ch = _ch(bus, 29917, path="/ws",
tokenIssuePath="/auth/token", tokenIssueSecret="s",
websocketRequiresToken=True)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
# no secret -> 401
_, status = await issue_token(port=29917, issue_path="/auth/token")
assert status == 401
# with secret -> token
token = await issue_token_ok(port=29917, issue_path="/auth/token", secret="s")
# no token -> 401
with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="x"):
pass
assert exc.value.response.status_code == 401
# valid token -> ok
async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="ok", token=token) as c:
assert (await c.recv_ready()).client_id == "ok"
# reuse -> 401
with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="r", token=token):
pass
assert exc.value.response.status_code == 401
finally:
await ch.stop(); await t
# -- Path routing ---------------------------------------------------------
@pytest.mark.asyncio
async def test_custom_path(bus: MagicMock) -> None:
ch = _ch(bus, 29918, path="/my-chat")
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29918/my-chat", client_id="p") as c:
assert (await c.recv_ready()).event == "ready"
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_wrong_path_404(bus: MagicMock) -> None:
ch = _ch(bus, 29919, path="/ws")
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
async with WsTestClient("ws://127.0.0.1:29919/wrong", client_id="x"):
pass
assert exc.value.response.status_code == 404
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_trailing_slash_normalized(bus: MagicMock) -> None:
ch = _ch(bus, 29920, path="/ws")
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29920/ws/", client_id="s") as c:
assert (await c.recv_ready()).event == "ready"
finally:
await ch.stop(); await t
# -- Edge cases -----------------------------------------------------------
@pytest.mark.asyncio
async def test_large_message(bus: MagicMock) -> None:
ch = _ch(bus, 29921)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29921/", client_id="big") as c:
await c.recv_ready()
big = "x" * 100_000
await c.send_text(big)
await asyncio.sleep(0.2)
assert bus.publish_inbound.call_args[0][0].content == big
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_unicode_roundtrip(bus: MagicMock) -> None:
ch = _ch(bus, 29922)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29922/", client_id="u") as c:
ready = await c.recv_ready()
text = "你好世界 🌍 日本語テスト"
await c.send_text(text)
await asyncio.sleep(0.1)
assert bus.publish_inbound.call_args[0][0].content == text
await ch.send(OutboundMessage(
channel="websocket", chat_id=ready.chat_id, content=text,
))
assert (await c.recv_message()).text == text
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_rapid_fire(bus: MagicMock) -> None:
ch = _ch(bus, 29923)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29923/", client_id="r") as c:
ready = await c.recv_ready()
for i in range(50):
await c.send_text(f"in-{i}")
await asyncio.sleep(0.5)
assert bus.publish_inbound.await_count == 50
for i in range(50):
await ch.send(OutboundMessage(
channel="websocket", chat_id=ready.chat_id, content=f"out-{i}",
))
received = [(await c.recv_message()).text for _ in range(50)]
assert received == [f"out-{i}" for i in range(50)]
finally:
await ch.stop(); await t
@pytest.mark.asyncio
async def test_invalid_json_as_plain_text(bus: MagicMock) -> None:
ch = _ch(bus, 29924)
t = asyncio.create_task(ch.start())
await asyncio.sleep(0.3)
try:
async with WsTestClient("ws://127.0.0.1:29924/", client_id="j") as c:
await c.recv_ready()
await c.send_text("{broken json")
await asyncio.sleep(0.1)
assert bus.publish_inbound.call_args[0][0].content == "{broken json"
finally:
await ch.stop(); await t

View File

@ -0,0 +1,584 @@
"""Tests for WeCom channel: helpers, download, upload, send, and message processing."""
import os
import tempfile
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
try:
import importlib.util
WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
except ImportError:
WECOM_AVAILABLE = False
if not WECOM_AVAILABLE:
pytest.skip("WeCom dependencies not installed (wecom_aibot_sdk)", allow_module_level=True)
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.wecom import (
WecomChannel,
WecomConfig,
_guess_wecom_media_type,
_sanitize_filename,
)
# Try to import the real response class; fall back to a stub if unavailable.
try:
from wecom_aibot_sdk.utils import WsResponse
_RealWsResponse = WsResponse
except ImportError:
_RealWsResponse = None
class _FakeResponse:
"""Minimal stand-in for wecom_aibot_sdk WsResponse."""
def __init__(self, errcode: int = 0, body: dict | None = None, errmsg: str = "ok"):
self.errcode = errcode
self.errmsg = errmsg
self.body = body or {}
class _FakeWsManager:
"""Tracks send_reply calls and returns configurable responses."""
def __init__(self, responses: list[_FakeResponse] | None = None):
self.responses = responses or []
self.calls: list[tuple[str, dict, str]] = []
self._idx = 0
async def send_reply(self, req_id: str, data: dict, cmd: str) -> _FakeResponse:
self.calls.append((req_id, data, cmd))
if self._idx < len(self.responses):
resp = self.responses[self._idx]
self._idx += 1
return resp
return _FakeResponse()
class _FakeFrame:
"""Minimal frame object with a body dict."""
def __init__(self, body: dict | None = None):
self.body = body or {}
class _FakeWeComClient:
"""Fake WeCom client with mock methods."""
def __init__(self, ws_responses: list[_FakeResponse] | None = None):
self._ws_manager = _FakeWsManager(ws_responses)
self.download_file = AsyncMock(return_value=(None, None))
self.reply = AsyncMock()
self.reply_stream = AsyncMock()
self.send_message = AsyncMock()
self.reply_welcome = AsyncMock()
# ── Helper function tests (pure, no async) ──────────────────────────
def test_sanitize_filename_strips_path_traversal() -> None:
assert _sanitize_filename("../../etc/passwd") == "passwd"
def test_sanitize_filename_keeps_chinese_chars() -> None:
assert _sanitize_filename("文件1.jpg") == "文件1.jpg"
def test_sanitize_filename_empty_input() -> None:
assert _sanitize_filename("") == ""
def test_guess_wecom_media_type_image() -> None:
for ext in (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"):
assert _guess_wecom_media_type(f"photo{ext}") == "image"
def test_guess_wecom_media_type_video() -> None:
for ext in (".mp4", ".avi", ".mov"):
assert _guess_wecom_media_type(f"video{ext}") == "video"
def test_guess_wecom_media_type_voice() -> None:
for ext in (".amr", ".mp3", ".wav", ".ogg"):
assert _guess_wecom_media_type(f"audio{ext}") == "voice"
def test_guess_wecom_media_type_file_fallback() -> None:
for ext in (".pdf", ".doc", ".xlsx", ".zip"):
assert _guess_wecom_media_type(f"doc{ext}") == "file"
def test_guess_wecom_media_type_case_insensitive() -> None:
assert _guess_wecom_media_type("photo.PNG") == "image"
assert _guess_wecom_media_type("photo.Jpg") == "image"
# ── _download_and_save_media() ──────────────────────────────────────
@pytest.mark.asyncio
async def test_download_and_save_success() -> None:
"""Successful download writes file and returns sanitized path."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
fake_data = b"\x89PNG\r\nfake image"
client.download_file.return_value = (fake_data, "raw_photo.png")
with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(tempfile.gettempdir())):
path = await channel._download_and_save_media("https://example.com/img.png", "aes_key", "image", "photo.png")
assert path is not None
assert os.path.isfile(path)
assert os.path.basename(path) == "photo.png"
# Cleanup
os.unlink(path)
@pytest.mark.asyncio
async def test_download_and_save_oversized_rejected() -> None:
"""Data exceeding 200MB is rejected → returns None."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
big_data = b"\x00" * (200 * 1024 * 1024 + 1) # 200MB + 1 byte
client.download_file.return_value = (big_data, "big.bin")
with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(tempfile.gettempdir())):
result = await channel._download_and_save_media("https://example.com/big.bin", "key", "file", "big.bin")
assert result is None
@pytest.mark.asyncio
async def test_download_and_save_failure() -> None:
"""SDK returns None data → returns None."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
client.download_file.return_value = (None, None)
with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(tempfile.gettempdir())):
result = await channel._download_and_save_media("https://example.com/fail.png", "key", "image")
assert result is None
# ── _upload_media_ws() ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_upload_media_ws_success() -> None:
"""Happy path: init → chunk → finish → returns (media_id, media_type)."""
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
f.write(b"\x89PNG\r\n")
tmp = f.name
try:
responses = [
_FakeResponse(errcode=0, body={"upload_id": "up_1"}),
_FakeResponse(errcode=0, body={}),
_FakeResponse(errcode=0, body={"media_id": "media_abc"}),
]
client = _FakeWeComClient(responses)
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
channel._client = client
with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
media_id, media_type = await channel._upload_media_ws(client, tmp)
assert media_id == "media_abc"
assert media_type == "image"
finally:
os.unlink(tmp)
@pytest.mark.asyncio
async def test_upload_media_ws_oversized_file() -> None:
"""File >200MB triggers ValueError → returns (None, None)."""
# Instead of creating a real 200MB+ file, mock os.path.getsize and open
with patch("os.path.getsize", return_value=200 * 1024 * 1024 + 1), \
patch("builtins.open", MagicMock()):
client = _FakeWeComClient()
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
channel._client = client
result = await channel._upload_media_ws(client, "/fake/large.bin")
assert result == (None, None)
@pytest.mark.asyncio
async def test_upload_media_ws_init_failure() -> None:
"""Init step returns errcode != 0 → returns (None, None)."""
with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f:
f.write(b"hello")
tmp = f.name
try:
responses = [
_FakeResponse(errcode=50001, errmsg="invalid"),
]
client = _FakeWeComClient(responses)
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
channel._client = client
with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
result = await channel._upload_media_ws(client, tmp)
assert result == (None, None)
finally:
os.unlink(tmp)
@pytest.mark.asyncio
async def test_upload_media_ws_chunk_failure() -> None:
"""Chunk step returns errcode != 0 → returns (None, None)."""
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
f.write(b"\x89PNG\r\n")
tmp = f.name
try:
responses = [
_FakeResponse(errcode=0, body={"upload_id": "up_1"}),
_FakeResponse(errcode=50002, errmsg="chunk fail"),
]
client = _FakeWeComClient(responses)
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
channel._client = client
with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
result = await channel._upload_media_ws(client, tmp)
assert result == (None, None)
finally:
os.unlink(tmp)
@pytest.mark.asyncio
async def test_upload_media_ws_finish_no_media_id() -> None:
"""Finish step returns empty media_id → returns (None, None)."""
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
f.write(b"\x89PNG\r\n")
tmp = f.name
try:
responses = [
_FakeResponse(errcode=0, body={"upload_id": "up_1"}),
_FakeResponse(errcode=0, body={}),
_FakeResponse(errcode=0, body={}), # no media_id
]
client = _FakeWeComClient(responses)
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
channel._client = client
with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
result = await channel._upload_media_ws(client, tmp)
assert result == (None, None)
finally:
os.unlink(tmp)
# ── send() ──────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_send_text_with_frame() -> None:
"""When frame is stored, send uses reply_stream for final text."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
channel._generate_req_id = lambda x: f"req_{x}"
channel._chat_frames["chat1"] = _FakeFrame()
await channel.send(
OutboundMessage(channel="wecom", chat_id="chat1", content="hello")
)
client.reply_stream.assert_called_once()
call_args = client.reply_stream.call_args
assert call_args[0][2] == "hello" # content arg
@pytest.mark.asyncio
async def test_send_progress_with_frame() -> None:
"""When metadata has _progress, send uses reply_stream with finish=False."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
channel._generate_req_id = lambda x: f"req_{x}"
channel._chat_frames["chat1"] = _FakeFrame()
await channel.send(
OutboundMessage(channel="wecom", chat_id="chat1", content="thinking...", metadata={"_progress": True})
)
client.reply_stream.assert_called_once()
call_args = client.reply_stream.call_args
assert call_args[0][2] == "thinking..." # content arg
assert call_args[1]["finish"] is False
@pytest.mark.asyncio
async def test_send_proactive_without_frame() -> None:
"""Without stored frame, send uses send_message with markdown."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
await channel.send(
OutboundMessage(channel="wecom", chat_id="chat1", content="proactive msg")
)
client.send_message.assert_called_once()
call_args = client.send_message.call_args
assert call_args[0][0] == "chat1"
assert call_args[0][1]["msgtype"] == "markdown"
@pytest.mark.asyncio
async def test_send_media_then_text() -> None:
"""Media files are uploaded and sent before text content."""
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
f.write(b"\x89PNG\r\n")
tmp = f.name
try:
responses = [
_FakeResponse(errcode=0, body={"upload_id": "up_1"}),
_FakeResponse(errcode=0, body={}),
_FakeResponse(errcode=0, body={"media_id": "media_123"}),
]
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
client = _FakeWeComClient(responses)
channel._client = client
channel._generate_req_id = lambda x: f"req_{x}"
channel._chat_frames["chat1"] = _FakeFrame()
await channel.send(
OutboundMessage(channel="wecom", chat_id="chat1", content="see image", media=[tmp])
)
# Media should have been sent via reply
media_calls = [c for c in client.reply.call_args_list if c[0][1].get("msgtype") == "image"]
assert len(media_calls) == 1
assert media_calls[0][0][1]["image"]["media_id"] == "media_123"
# Text should have been sent via reply_stream
client.reply_stream.assert_called_once()
finally:
os.unlink(tmp)
@pytest.mark.asyncio
async def test_send_media_file_not_found() -> None:
"""Non-existent media path is skipped with a warning."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
channel._generate_req_id = lambda x: f"req_{x}"
channel._chat_frames["chat1"] = _FakeFrame()
await channel.send(
OutboundMessage(channel="wecom", chat_id="chat1", content="hello", media=["/nonexistent/file.png"])
)
# reply_stream should still be called for the text part
client.reply_stream.assert_called_once()
# No media reply should happen
media_calls = [c for c in client.reply.call_args_list if c[0][1].get("msgtype") in ("image", "file", "video")]
assert len(media_calls) == 0
@pytest.mark.asyncio
async def test_send_exception_caught_not_raised() -> None:
"""Exceptions inside send() must not propagate."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
channel._generate_req_id = lambda x: f"req_{x}"
channel._chat_frames["chat1"] = _FakeFrame()
# Make reply_stream raise
client.reply_stream.side_effect = RuntimeError("boom")
await channel.send(
OutboundMessage(channel="wecom", chat_id="chat1", content="fail test")
)
# No exception — test passes if we reach here.
# ── _process_message() ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_process_text_message() -> None:
"""Text message is routed to bus with correct fields."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
frame = _FakeFrame(body={
"msgid": "msg_text_1",
"chatid": "chat1",
"chattype": "single",
"from": {"userid": "user1"},
"text": {"content": "hello wecom"},
})
await channel._process_message(frame, "text")
msg = await channel.bus.consume_inbound()
assert msg.sender_id == "user1"
assert msg.chat_id == "chat1"
assert msg.content == "hello wecom"
assert msg.metadata["msg_type"] == "text"
@pytest.mark.asyncio
async def test_process_image_message() -> None:
"""Image message: download success → media_paths non-empty."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
client = _FakeWeComClient()
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
f.write(b"\x89PNG\r\n")
saved = f.name
client.download_file.return_value = (b"\x89PNG\r\n", "photo.png")
channel._client = client
try:
with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(os.path.dirname(saved))):
frame = _FakeFrame(body={
"msgid": "msg_img_1",
"chatid": "chat1",
"from": {"userid": "user1"},
"image": {"url": "https://example.com/img.png", "aeskey": "key123"},
})
await channel._process_message(frame, "image")
msg = await channel.bus.consume_inbound()
assert len(msg.media) == 1
assert msg.media[0].endswith("photo.png")
assert "[image:" in msg.content
finally:
if os.path.exists(saved):
pass # may have been overwritten; clean up if exists
# Clean up any photo.png in tempdir
p = os.path.join(os.path.dirname(saved), "photo.png")
if os.path.exists(p):
os.unlink(p)
@pytest.mark.asyncio
async def test_process_file_message() -> None:
"""File message: download success → media_paths non-empty (critical fix verification)."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
client = _FakeWeComClient()
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
f.write(b"%PDF-1.4 fake")
saved = f.name
client.download_file.return_value = (b"%PDF-1.4 fake", "report.pdf")
channel._client = client
try:
with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(os.path.dirname(saved))):
frame = _FakeFrame(body={
"msgid": "msg_file_1",
"chatid": "chat1",
"from": {"userid": "user1"},
"file": {"url": "https://example.com/report.pdf", "aeskey": "key456", "name": "report.pdf"},
})
await channel._process_message(frame, "file")
msg = await channel.bus.consume_inbound()
assert len(msg.media) == 1
assert msg.media[0].endswith("report.pdf")
assert "[file: report.pdf]" in msg.content
finally:
p = os.path.join(os.path.dirname(saved), "report.pdf")
if os.path.exists(p):
os.unlink(p)
@pytest.mark.asyncio
async def test_process_voice_message() -> None:
"""Voice message: transcribed text is included in content."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
frame = _FakeFrame(body={
"msgid": "msg_voice_1",
"chatid": "chat1",
"from": {"userid": "user1"},
"voice": {"content": "transcribed text here"},
})
await channel._process_message(frame, "voice")
msg = await channel.bus.consume_inbound()
assert "transcribed text here" in msg.content
assert "[voice]" in msg.content
@pytest.mark.asyncio
async def test_process_message_deduplication() -> None:
"""Same msg_id is not processed twice."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
frame = _FakeFrame(body={
"msgid": "msg_dup_1",
"chatid": "chat1",
"from": {"userid": "user1"},
"text": {"content": "once"},
})
await channel._process_message(frame, "text")
await channel._process_message(frame, "text")
msg = await channel.bus.consume_inbound()
assert msg.content == "once"
# Second message should not appear on the bus
assert channel.bus.inbound.empty()
@pytest.mark.asyncio
async def test_process_message_empty_content_skipped() -> None:
"""Message with empty content produces no bus message."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
frame = _FakeFrame(body={
"msgid": "msg_empty_1",
"chatid": "chat1",
"from": {"userid": "user1"},
"text": {"content": ""},
})
await channel._process_message(frame, "text")
assert channel.bus.inbound.empty()

View File

@ -1003,3 +1003,185 @@ async def test_download_media_item_non_image_requires_aes_key_even_with_full_url
assert saved_path is None
channel._client.get.assert_not_awaited()
# ---------------------------------------------------------------------------
# Tests for media-send error classification (network vs non-network errors)
# ---------------------------------------------------------------------------
def _make_outbound_msg(chat_id: str = "wx-user", content: str = "", media: list | None = None):
"""Build a minimal OutboundMessage-like object for send() tests."""
from nanobot.bus.events import OutboundMessage
return OutboundMessage(
channel="weixin",
chat_id=chat_id,
content=content,
media=media or [],
metadata={},
)
@pytest.mark.asyncio
async def test_send_media_timeout_error_propagates_without_text_fallback() -> None:
"""httpx.TimeoutException during media send must re-raise immediately,
NOT fall back to _send_text (which would also fail during network issues)."""
channel, _bus = _make_channel()
channel._client = object()
channel._token = "token"
channel._context_tokens["wx-user"] = "ctx-1"
channel._send_media_file = AsyncMock(side_effect=httpx.TimeoutException("timed out"))
channel._send_text = AsyncMock()
msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"])
with pytest.raises(httpx.TimeoutException, match="timed out"):
await channel.send(msg)
# _send_text must NOT have been called as a fallback
channel._send_text.assert_not_awaited()
@pytest.mark.asyncio
async def test_send_media_transport_error_propagates_without_text_fallback() -> None:
"""httpx.TransportError during media send must re-raise immediately."""
channel, _bus = _make_channel()
channel._client = object()
channel._token = "token"
channel._context_tokens["wx-user"] = "ctx-1"
channel._send_media_file = AsyncMock(
side_effect=httpx.TransportError("connection reset")
)
channel._send_text = AsyncMock()
msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"])
with pytest.raises(httpx.TransportError, match="connection reset"):
await channel.send(msg)
channel._send_text.assert_not_awaited()
@pytest.mark.asyncio
async def test_send_media_5xx_http_status_error_propagates_without_text_fallback() -> None:
"""httpx.HTTPStatusError with a 5xx status must re-raise immediately."""
channel, _bus = _make_channel()
channel._client = object()
channel._token = "token"
channel._context_tokens["wx-user"] = "ctx-1"
fake_response = httpx.Response(
status_code=503,
request=httpx.Request("POST", "https://example.test/upload"),
)
channel._send_media_file = AsyncMock(
side_effect=httpx.HTTPStatusError(
"Service Unavailable", request=fake_response.request, response=fake_response
)
)
channel._send_text = AsyncMock()
msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"])
with pytest.raises(httpx.HTTPStatusError, match="Service Unavailable"):
await channel.send(msg)
channel._send_text.assert_not_awaited()
@pytest.mark.asyncio
async def test_send_media_4xx_http_status_error_falls_back_to_text() -> None:
"""httpx.HTTPStatusError with a 4xx status should fall back to text, not re-raise."""
channel, _bus = _make_channel()
channel._client = object()
channel._token = "token"
channel._context_tokens["wx-user"] = "ctx-1"
fake_response = httpx.Response(
status_code=400,
request=httpx.Request("POST", "https://example.test/upload"),
)
channel._send_media_file = AsyncMock(
side_effect=httpx.HTTPStatusError(
"Bad Request", request=fake_response.request, response=fake_response
)
)
channel._send_text = AsyncMock()
msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"])
# Should NOT raise — 4xx is a client error, non-retryable
await channel.send(msg)
# _send_text should have been called with the fallback message
channel._send_text.assert_awaited_once_with(
"wx-user", "[Failed to send: photo.jpg]", "ctx-1"
)
@pytest.mark.asyncio
async def test_send_media_file_not_found_falls_back_to_text() -> None:
"""FileNotFoundError (a non-network error) should fall back to text."""
channel, _bus = _make_channel()
channel._client = object()
channel._token = "token"
channel._context_tokens["wx-user"] = "ctx-1"
channel._send_media_file = AsyncMock(
side_effect=FileNotFoundError("Media file not found: /tmp/missing.jpg")
)
channel._send_text = AsyncMock()
msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/missing.jpg"])
# Should NOT raise
await channel.send(msg)
channel._send_text.assert_awaited_once_with(
"wx-user", "[Failed to send: missing.jpg]", "ctx-1"
)
@pytest.mark.asyncio
async def test_send_media_value_error_falls_back_to_text() -> None:
"""ValueError (e.g. unsupported format) should fall back to text."""
channel, _bus = _make_channel()
channel._client = object()
channel._token = "token"
channel._context_tokens["wx-user"] = "ctx-1"
channel._send_media_file = AsyncMock(
side_effect=ValueError("Unsupported media format")
)
channel._send_text = AsyncMock()
msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/file.xyz"])
# Should NOT raise
await channel.send(msg)
channel._send_text.assert_awaited_once_with(
"wx-user", "[Failed to send: file.xyz]", "ctx-1"
)
@pytest.mark.asyncio
async def test_send_media_network_error_does_not_double_api_calls() -> None:
"""During network issues, media send should make exactly 1 API call attempt,
not 2 (media + text fallback). Verify total call count."""
channel, _bus = _make_channel()
channel._client = object()
channel._token = "token"
channel._context_tokens["wx-user"] = "ctx-1"
channel._send_media_file = AsyncMock(
side_effect=httpx.ConnectError("connection refused")
)
channel._send_text = AsyncMock()
msg = _make_outbound_msg(chat_id="wx-user", content="hello", media=["/tmp/img.png"])
with pytest.raises(httpx.ConnectError):
await channel.send(msg)
# _send_media_file called once, _send_text never called
channel._send_media_file.assert_awaited_once()
channel._send_text.assert_not_awaited()

View File

@ -0,0 +1,227 @@
"""Lightweight WebSocket test client for integration testing the nanobot WebSocket channel.
Provides an async ``WsTestClient`` class and token-issuance helpers that
integration tests can import and use directly::
from ws_test_client import WsTestClient
async with WsTestClient("ws://127.0.0.1:8765/", client_id="t") as c:
ready = await c.recv_ready()
await c.send_text("hello")
msg = await c.recv_message()
"""
from __future__ import annotations
import asyncio
import json
from dataclasses import dataclass, field
from typing import Any
import httpx
import websockets
from websockets.asyncio.client import ClientConnection
@dataclass
class WsMessage:
"""A parsed message received from the WebSocket server."""
event: str
raw: dict[str, Any] = field(repr=False)
@property
def text(self) -> str | None:
return self.raw.get("text")
@property
def chat_id(self) -> str | None:
return self.raw.get("chat_id")
@property
def client_id(self) -> str | None:
return self.raw.get("client_id")
@property
def media(self) -> list[str] | None:
return self.raw.get("media")
@property
def reply_to(self) -> str | None:
return self.raw.get("reply_to")
@property
def stream_id(self) -> str | None:
return self.raw.get("stream_id")
def __eq__(self, other: object) -> bool:
if not isinstance(other, WsMessage):
return NotImplemented
return self.event == other.event and self.raw == other.raw
class WsTestClient:
"""Async WebSocket test client with helper methods for common operations.
Usage::
async with WsTestClient("ws://127.0.0.1:8765/", client_id="tester") as client:
ready = await client.recv_ready()
await client.send_text("hello")
msg = await client.recv_message(timeout=5.0)
"""
def __init__(
self,
uri: str,
*,
client_id: str = "test-client",
token: str = "",
extra_headers: dict[str, str] | None = None,
) -> None:
params: list[str] = []
if client_id:
params.append(f"client_id={client_id}")
if token:
params.append(f"token={token}")
sep = "&" if "?" in uri else "?"
self._uri = uri + sep + "&".join(params) if params else uri
self._extra_headers = extra_headers
self._ws: ClientConnection | None = None
async def connect(self) -> None:
self._ws = await websockets.connect(
self._uri,
additional_headers=self._extra_headers,
)
async def close(self) -> None:
if self._ws:
await self._ws.close()
self._ws = None
async def __aenter__(self) -> WsTestClient:
await self.connect()
return self
async def __aexit__(self, *args: Any) -> None:
await self.close()
@property
def ws(self) -> ClientConnection:
assert self._ws is not None, "Client is not connected"
return self._ws
# -- Receiving --------------------------------------------------------
async def recv_raw(self, timeout: float = 10.0) -> dict[str, Any]:
"""Receive and parse one raw JSON message with timeout."""
raw = await asyncio.wait_for(self.ws.recv(), timeout=timeout)
return json.loads(raw)
async def recv(self, timeout: float = 10.0) -> WsMessage:
"""Receive one message, returning a WsMessage wrapper."""
data = await self.recv_raw(timeout)
return WsMessage(event=data.get("event", ""), raw=data)
async def recv_ready(self, timeout: float = 5.0) -> WsMessage:
"""Receive and validate the 'ready' event."""
msg = await self.recv(timeout)
assert msg.event == "ready", f"Expected 'ready' event, got '{msg.event}'"
return msg
async def recv_message(self, timeout: float = 10.0) -> WsMessage:
"""Receive and validate a 'message' event."""
msg = await self.recv(timeout)
assert msg.event == "message", f"Expected 'message' event, got '{msg.event}'"
return msg
async def recv_delta(self, timeout: float = 10.0) -> WsMessage:
"""Receive and validate a 'delta' event."""
msg = await self.recv(timeout)
assert msg.event == "delta", f"Expected 'delta' event, got '{msg.event}'"
return msg
async def recv_stream_end(self, timeout: float = 10.0) -> WsMessage:
"""Receive and validate a 'stream_end' event."""
msg = await self.recv(timeout)
assert msg.event == "stream_end", f"Expected 'stream_end' event, got '{msg.event}'"
return msg
async def collect_stream(self, timeout: float = 10.0) -> list[WsMessage]:
"""Collect all deltas and the final stream_end into a list."""
messages: list[WsMessage] = []
while True:
msg = await self.recv(timeout)
messages.append(msg)
if msg.event == "stream_end":
break
return messages
async def recv_n(self, n: int, timeout: float = 10.0) -> list[WsMessage]:
"""Receive exactly *n* messages."""
return [await self.recv(timeout) for _ in range(n)]
# -- Sending ----------------------------------------------------------
async def send_text(self, text: str) -> None:
"""Send a plain text frame."""
await self.ws.send(text)
async def send_json(self, data: dict[str, Any]) -> None:
"""Send a JSON frame."""
await self.ws.send(json.dumps(data, ensure_ascii=False))
async def send_content(self, content: str) -> None:
"""Send content in the preferred JSON format ``{"content": ...}``."""
await self.send_json({"content": content})
# -- Connection introspection -----------------------------------------
@property
def closed(self) -> bool:
return self._ws is None or self._ws.closed
# -- Token issuance helpers -----------------------------------------------
async def issue_token(
host: str = "127.0.0.1",
port: int = 8765,
issue_path: str = "/auth/token",
secret: str = "",
) -> tuple[dict[str, Any] | None, int]:
"""Request a short-lived token from the token-issue HTTP endpoint.
Returns ``(parsed_json_or_None, status_code)``.
"""
url = f"http://{host}:{port}{issue_path}"
headers: dict[str, str] = {}
if secret:
headers["Authorization"] = f"Bearer {secret}"
loop = asyncio.get_running_loop()
resp = await loop.run_in_executor(
None, lambda: httpx.get(url, headers=headers, timeout=5.0)
)
try:
data = resp.json()
except Exception:
data = None
return data, resp.status_code
async def issue_token_ok(
host: str = "127.0.0.1",
port: int = 8765,
issue_path: str = "/auth/token",
secret: str = "",
) -> str:
"""Request a token, asserting success, and return the token string."""
(data, status) = await issue_token(host, port, issue_path, secret)
assert status == 200, f"Token issue failed with status {status}"
assert data is not None
token = data["token"]
assert token.startswith("nbwt_"), f"Unexpected token format: {token}"
return token

View File

@ -1126,6 +1126,153 @@ def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path)
assert "port 18792" in result.stdout
def test_gateway_health_endpoint_binds_and_serves_expected_responses(
monkeypatch, tmp_path: Path
) -> None:
config_file = _write_instance_config(tmp_path)
config = Config()
config.gateway.port = 18791
captured: dict[str, object] = {}
class _FakeDream:
model = None
max_batch_size = 0
max_iterations = 0
async def run(self) -> None:
return None
class _FakeAgentLoop:
def __init__(self, **_kwargs) -> None:
self.model = "test-model"
self.dream = _FakeDream()
async def run(self) -> None:
await asyncio.Event().wait()
async def close_mcp(self) -> None:
return None
def stop(self) -> None:
return None
class _FakeChannelManager:
def __init__(self, _config, _bus) -> None:
self.enabled_channels = ["telegram", "discord"]
async def start_all(self) -> None:
await asyncio.Event().wait()
async def stop_all(self) -> None:
return None
class _FakeCronService:
def __init__(self, _store_path: Path) -> None:
self.on_job = None
async def start(self) -> None:
return None
def stop(self) -> None:
return None
def status(self) -> dict[str, int]:
return {"jobs": 0}
def register_system_job(self, _job) -> None:
return None
class _FakeHeartbeatService:
def __init__(self, **_kwargs) -> None:
return None
async def start(self) -> None:
return None
def stop(self) -> None:
return None
class _FakeServer:
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb) -> bool:
return False
async def serve_forever(self) -> None:
raise _StopGatewayError("stop")
async def _fake_start_server(handler, host: str, port: int):
captured["handler"] = handler
captured["host"] = host
captured["port"] = port
return _FakeServer()
class _FakeReader:
def __init__(self, payload: bytes) -> None:
self.payload = payload
async def read(self, _size: int) -> bytes:
return self.payload
class _FakeWriter:
def __init__(self) -> None:
self.output = b""
self.closed = False
def write(self, data: bytes) -> None:
self.output += data
async def drain(self) -> None:
return None
def close(self) -> None:
self.closed = True
_patch_cli_command_runtime(
monkeypatch,
config,
message_bus=lambda: object(),
session_manager=lambda _workspace: object(),
)
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _FakeChannelManager)
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCronService)
monkeypatch.setattr("nanobot.heartbeat.service.HeartbeatService", _FakeHeartbeatService)
monkeypatch.setattr("asyncio.start_server", _fake_start_server)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert result.exit_code == 0
assert captured["host"] == "127.0.0.1"
assert captured["port"] == 18791
assert "Health endpoint: http://127.0.0.1:18791/health" in result.stdout
def _call_handler(path: str) -> tuple[str, _FakeWriter]:
request = f"GET {path} HTTP/1.1\r\nHost: localhost\r\n\r\n".encode()
writer = _FakeWriter()
handler = captured["handler"]
assert callable(handler)
asyncio.run(handler(_FakeReader(request), writer))
return writer.output.decode(), writer
root_response, root_writer = _call_handler("/")
assert root_writer.closed is True
assert "HTTP/1.0 404 Not Found" in root_response
assert root_response.endswith("\r\n\r\nNot Found")
health_response, health_writer = _call_handler("/health")
assert health_writer.closed is True
assert "HTTP/1.0 200 OK" in health_response
health_body = json.loads(health_response.split("\r\n\r\n", 1)[1])
assert health_body == {"status": "ok"}
missing_response, missing_writer = _call_handler("/missing")
assert missing_writer.closed is True
assert "HTTP/1.0 404 Not Found" in missing_response
assert missing_response.endswith("\r\n\r\nNot Found")
def test_serve_uses_api_config_defaults_and_workspace_override(
monkeypatch, tmp_path: Path
) -> None:

View File

@ -148,7 +148,7 @@ class TestRestartCommand:
assert response is not None
assert "Model: test-model" in response.content
assert "Tokens: 0 in / 0 out" in response.content
assert "Context: 20k/64k (31%)" in response.content
assert "Context: 20k/65k (31%)" in response.content
assert "Session: 3 messages" in response.content
assert "Uptime: 2m 5s" in response.content
assert response.metadata == {"render_as": "text"}
@ -186,7 +186,7 @@ class TestRestartCommand:
assert response is not None
assert "Tokens: 1200 in / 34 out" in response.content
assert "Context: 1k/64k (1%)" in response.content
assert "Context: 1k/65k (1%)" in response.content
@pytest.mark.asyncio
async def test_process_direct_preserves_render_metadata(self):

View File

@ -1,5 +1,6 @@
import asyncio
import json
import time
import pytest
@ -114,6 +115,41 @@ async def test_run_history_persisted_to_disk(tmp_path) -> None:
assert loaded.state.run_history[0].status == "ok"
@pytest.mark.asyncio
async def test_run_job_disabled_does_not_flip_running_state(tmp_path) -> None:
store_path = tmp_path / "cron" / "jobs.json"
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
job = service.add_job(
name="disabled",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello",
)
service.enable_job(job.id, enabled=False)
result = await service.run_job(job.id)
assert result is False
assert service._running is False
@pytest.mark.asyncio
async def test_run_job_preserves_running_service_state(tmp_path) -> None:
store_path = tmp_path / "cron" / "jobs.json"
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
service._running = True
job = service.add_job(
name="manual",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello",
)
result = await service.run_job(job.id, force=True)
assert result is True
assert service._running is True
service.stop()
@pytest.mark.asyncio
async def test_running_service_honors_external_disable(tmp_path) -> None:
store_path = tmp_path / "cron" / "jobs.json"
@ -158,24 +194,49 @@ def test_remove_job_refuses_system_jobs(tmp_path) -> None:
assert service.get_job("dream") is not None
def test_reload_jobs(tmp_path):
@pytest.mark.asyncio
async def test_start_server_not_jobs(tmp_path):
store_path = tmp_path / "cron" / "jobs.json"
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
service.add_job(
name="hist",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello",
)
called = []
async def on_job(job):
called.append(job.name)
assert len(service.list_jobs()) == 1
service = CronService(store_path, on_job=on_job, max_sleep_ms=1000)
await service.start()
assert len(service.list_jobs()) == 0
service2 = CronService(tmp_path / "cron" / "jobs.json")
service2.add_job(
name="hist2",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello2",
name="hist",
schedule=CronSchedule(kind="every", every_ms=500),
message="hello",
)
assert len(service.list_jobs()) == 2
assert len(service.list_jobs()) == 1
await asyncio.sleep(2)
assert len(called) != 0
service.stop()
@pytest.mark.asyncio
async def test_subsecond_job_not_delayed_to_one_second(tmp_path):
store_path = tmp_path / "cron" / "jobs.json"
called = []
async def on_job(job):
called.append(job.name)
service = CronService(store_path, on_job=on_job, max_sleep_ms=5000)
service.add_job(
name="fast",
schedule=CronSchedule(kind="every", every_ms=100),
message="hello",
)
await service.start()
try:
await asyncio.sleep(0.35)
assert called
finally:
service.stop()
@pytest.mark.asyncio
@ -204,7 +265,302 @@ async def test_running_service_picks_up_external_add(tmp_path):
message="ping",
)
await asyncio.sleep(0.6)
await asyncio.sleep(2)
assert "external" in called
finally:
service.stop()
@pytest.mark.asyncio
async def test_add_job_during_jobs_exec(tmp_path):
store_path = tmp_path / "cron" / "jobs.json"
run_once = True
async def on_job(job):
nonlocal run_once
if run_once:
service2 = CronService(store_path, on_job=lambda x: asyncio.sleep(0))
service2.add_job(
name="test",
schedule=CronSchedule(kind="every", every_ms=150),
message="tick",
)
run_once = False
service = CronService(store_path, on_job=on_job)
service.add_job(
name="heartbeat",
schedule=CronSchedule(kind="every", every_ms=150),
message="tick",
)
assert len(service.list_jobs()) == 1
await service.start()
try:
await asyncio.sleep(3)
jobs = service.list_jobs()
assert len(jobs) == 2
assert "test" in [j.name for j in jobs]
finally:
service.stop()
@pytest.mark.asyncio
async def test_external_update_preserves_run_history_records(tmp_path):
store_path = tmp_path / "cron" / "jobs.json"
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
job = service.add_job(
name="history",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello",
)
await service.run_job(job.id, force=True)
external = CronService(store_path)
updated = external.enable_job(job.id, enabled=False)
assert updated is not None
fresh = CronService(store_path)
loaded = fresh.get_job(job.id)
assert loaded is not None
assert loaded.state.run_history
assert loaded.state.run_history[0].status == "ok"
fresh._running = True
fresh._save_store()
# ── timer race regression tests ──
@pytest.mark.asyncio
async def test_timer_execution_is_not_rolled_back_by_list_jobs_reload(tmp_path):
"""list_jobs() during _on_timer should not replace the active store and re-run the same due job."""
store_path = tmp_path / "cron" / "jobs.json"
calls: list[str] = []
async def on_job(job):
calls.append(job.id)
# Simulate frontend polling list_jobs while the timer callback is mid-execution.
service.list_jobs(include_disabled=True)
await asyncio.sleep(0)
service = CronService(store_path, on_job=on_job)
service._running = True
service._load_store()
service._arm_timer = lambda: None
job = service.add_job(
name="race",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello",
)
job.state.next_run_at_ms = max(1, int(time.time() * 1000) - 1_000)
service._save_store()
await service._on_timer()
await service._on_timer()
assert calls == [job.id]
loaded = service.get_job(job.id)
assert loaded is not None
assert loaded.state.last_run_at_ms is not None
assert loaded.state.next_run_at_ms is not None
assert loaded.state.next_run_at_ms > loaded.state.last_run_at_ms
# ── update_job tests ──
def test_update_job_changes_name(tmp_path) -> None:
service = CronService(tmp_path / "cron" / "jobs.json")
job = service.add_job(
name="old name",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello",
)
result = service.update_job(job.id, name="new name")
assert isinstance(result, CronJob)
assert result.name == "new name"
assert result.payload.message == "hello"
def test_update_job_changes_schedule(tmp_path) -> None:
service = CronService(tmp_path / "cron" / "jobs.json")
job = service.add_job(
name="sched",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello",
)
old_next = job.state.next_run_at_ms
new_sched = CronSchedule(kind="every", every_ms=120_000)
result = service.update_job(job.id, schedule=new_sched)
assert isinstance(result, CronJob)
assert result.schedule.every_ms == 120_000
assert result.state.next_run_at_ms != old_next
def test_update_job_changes_message(tmp_path) -> None:
service = CronService(tmp_path / "cron" / "jobs.json")
job = service.add_job(
name="msg",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="old message",
)
result = service.update_job(job.id, message="new message")
assert isinstance(result, CronJob)
assert result.payload.message == "new message"
def test_update_job_changes_cron_expression(tmp_path) -> None:
service = CronService(tmp_path / "cron" / "jobs.json")
job = service.add_job(
name="cron-job",
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"),
message="hello",
)
result = service.update_job(
job.id,
schedule=CronSchedule(kind="cron", expr="0 18 * * *", tz="UTC"),
)
assert isinstance(result, CronJob)
assert result.schedule.expr == "0 18 * * *"
assert result.state.next_run_at_ms is not None
def test_update_job_not_found(tmp_path) -> None:
service = CronService(tmp_path / "cron" / "jobs.json")
result = service.update_job("nonexistent", name="x")
assert result == "not_found"
def test_update_job_rejects_system_job(tmp_path) -> None:
service = CronService(tmp_path / "cron" / "jobs.json")
service.register_system_job(CronJob(
id="dream",
name="dream",
schedule=CronSchedule(kind="cron", expr="0 */2 * * *", tz="UTC"),
payload=CronPayload(kind="system_event"),
))
result = service.update_job("dream", name="hacked")
assert result == "protected"
assert service.get_job("dream").name == "dream"
def test_update_job_validates_schedule(tmp_path) -> None:
service = CronService(tmp_path / "cron" / "jobs.json")
job = service.add_job(
name="validate",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello",
)
with pytest.raises(ValueError, match="unknown timezone"):
service.update_job(
job.id,
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="Bad/Zone"),
)
@pytest.mark.asyncio
async def test_update_job_preserves_run_history(tmp_path) -> None:
import asyncio
store_path = tmp_path / "cron" / "jobs.json"
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
job = service.add_job(
name="hist",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello",
)
await service.run_job(job.id)
result = service.update_job(job.id, name="renamed")
assert isinstance(result, CronJob)
assert len(result.state.run_history) == 1
assert result.state.run_history[0].status == "ok"
def test_update_job_offline_writes_action(tmp_path) -> None:
service = CronService(tmp_path / "cron" / "jobs.json")
job = service.add_job(
name="offline",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello",
)
service.update_job(job.id, name="updated-offline")
action_path = tmp_path / "cron" / "action.jsonl"
assert action_path.exists()
lines = [l for l in action_path.read_text().strip().split("\n") if l]
last = json.loads(lines[-1])
assert last["action"] == "update"
assert last["params"]["name"] == "updated-offline"
def test_update_job_sentinel_channel_and_to(tmp_path) -> None:
"""Passing None clears channel/to; omitting leaves them unchanged."""
service = CronService(tmp_path / "cron" / "jobs.json")
job = service.add_job(
name="sentinel",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello",
channel="telegram",
to="user123",
)
assert job.payload.channel == "telegram"
assert job.payload.to == "user123"
result = service.update_job(job.id, name="renamed")
assert isinstance(result, CronJob)
assert result.payload.channel == "telegram"
assert result.payload.to == "user123"
result = service.update_job(job.id, channel=None, to=None)
assert isinstance(result, CronJob)
assert result.payload.channel is None
assert result.payload.to is None
@pytest.mark.asyncio
async def test_list_jobs_during_on_job_does_not_cause_stale_reload(tmp_path) -> None:
"""Regression: if the bot calls list_jobs (which reloads from disk) during
on_job execution, the in-memory next_run_at_ms update must not be lost.
Previously this caused an infinite re-trigger loop."""
store_path = tmp_path / "cron" / "jobs.json"
execution_count = 0
async def on_job_that_lists(job):
nonlocal execution_count
execution_count += 1
# Simulate the bot calling cron(action=list) mid-execution
service.list_jobs()
service = CronService(store_path, on_job=on_job_that_lists, max_sleep_ms=100)
await service.start()
# Add two jobs scheduled in the past so they're immediately due
now_ms = int(time.time() * 1000)
for name in ("job-a", "job-b"):
service.add_job(
name=name,
schedule=CronSchedule(kind="every", every_ms=3_600_000),
message="test",
)
# Force next_run to the past so _on_timer picks them up
for job in service._store.jobs:
job.state.next_run_at_ms = now_ms - 1000
service._save_store()
service._arm_timer()
# Let the timer fire once
await asyncio.sleep(0.3)
service.stop()
# Each job should have run exactly once, not looped
assert execution_count == 2
# Verify next_run_at_ms was persisted correctly (in the future)
raw = json.loads(store_path.read_text())
for j in raw["jobs"]:
next_run = j["state"]["nextRunAtMs"]
assert next_run is not None
assert next_run > now_ms, f"Job '{j['name']}' next_run should be in the future"

View File

@ -2,9 +2,12 @@
from datetime import datetime, timezone
import pytest
from nanobot.agent.tools.cron import CronTool
from nanobot.cron.service import CronService
from nanobot.cron.types import CronJob, CronJobState, CronPayload, CronSchedule
from tests.test_openai_api import pytest_plugins
def _make_tool(tmp_path) -> CronTool:
@ -215,8 +218,10 @@ def test_list_at_job_shows_iso_timestamp(tmp_path) -> None:
assert "Asia/Shanghai" in result
def test_list_shows_last_run_state(tmp_path) -> None:
@pytest.mark.asyncio
async def test_list_shows_last_run_state(tmp_path) -> None:
tool = _make_tool(tmp_path)
tool._cron._running = True
job = tool._cron.add_job(
name="Stateful job",
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"),
@ -232,9 +237,10 @@ def test_list_shows_last_run_state(tmp_path) -> None:
assert "ok" in result
assert "(UTC)" in result
def test_list_shows_error_message(tmp_path) -> None:
@pytest.mark.asyncio
async def test_list_shows_error_message(tmp_path) -> None:
tool = _make_tool(tmp_path)
tool._cron._running = True
job = tool._cron.add_job(
name="Failed job",
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"),

View File

@ -4,6 +4,7 @@ from types import SimpleNamespace
from unittest.mock import patch
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
from nanobot.providers.registry import find_by_name
def test_custom_provider_parse_handles_empty_choices() -> None:
@ -53,3 +54,20 @@ def test_custom_provider_parse_chunks_accepts_plain_text_chunks() -> None:
assert result.finish_reason == "stop"
assert result.content == "hello world"
def test_local_provider_502_error_includes_reachability_hint() -> None:
spec = find_by_name("ollama")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider(api_base="http://localhost:11434/v1", spec=spec)
result = provider._handle_error(
Exception("Error code: 502"),
spec=spec,
api_base="http://localhost:11434/v1",
)
assert result.finish_reason == "error"
assert "local model endpoint" in result.content
assert "http://localhost:11434/v1" in result.content
assert "proxy/tunnel" in result.content

View File

@ -0,0 +1,197 @@
"""Tests for LLMProvider._enforce_role_alternation."""
from nanobot.providers.base import LLMProvider
class TestEnforceRoleAlternation:
"""Verify trailing-assistant removal and consecutive same-role merging."""
def test_empty_messages(self):
assert LLMProvider._enforce_role_alternation([]) == []
def test_no_change_needed(self):
msgs = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hello!"},
{"role": "user", "content": "Bye"},
]
result = LLMProvider._enforce_role_alternation(msgs)
assert len(result) == 4
assert result[-1]["role"] == "user"
def test_trailing_assistant_removed(self):
msgs = [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hello!"},
]
result = LLMProvider._enforce_role_alternation(msgs)
assert len(result) == 1
assert result[0]["role"] == "user"
def test_multiple_trailing_assistants_removed(self):
msgs = [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "A"},
{"role": "assistant", "content": "B"},
]
result = LLMProvider._enforce_role_alternation(msgs)
assert len(result) == 1
assert result[0]["role"] == "user"
def test_consecutive_user_messages_merged(self):
msgs = [
{"role": "user", "content": "Hello"},
{"role": "user", "content": "How are you?"},
]
result = LLMProvider._enforce_role_alternation(msgs)
assert len(result) == 1
assert "Hello" in result[0]["content"]
assert "How are you?" in result[0]["content"]
def test_consecutive_assistant_messages_merged(self):
msgs = [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hello!"},
{"role": "assistant", "content": "How can I help?"},
{"role": "user", "content": "Thanks"},
]
result = LLMProvider._enforce_role_alternation(msgs)
assert len(result) == 3
assert "Hello!" in result[1]["content"]
assert "How can I help?" in result[1]["content"]
def test_system_messages_not_merged(self):
msgs = [
{"role": "system", "content": "System A"},
{"role": "system", "content": "System B"},
{"role": "user", "content": "Hi"},
]
result = LLMProvider._enforce_role_alternation(msgs)
assert len(result) == 3
assert result[0]["content"] == "System A"
assert result[1]["content"] == "System B"
def test_tool_messages_not_merged(self):
msgs = [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]},
{"role": "tool", "content": "result1", "tool_call_id": "1"},
{"role": "tool", "content": "result2", "tool_call_id": "2"},
{"role": "user", "content": "Next"},
]
result = LLMProvider._enforce_role_alternation(msgs)
tool_msgs = [m for m in result if m["role"] == "tool"]
assert len(tool_msgs) == 2
def test_consecutive_assistant_keeps_later_tool_call_message(self):
msgs = [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Previous reply"},
{"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]},
{"role": "tool", "content": "result1", "tool_call_id": "1"},
{"role": "user", "content": "Next"},
]
result = LLMProvider._enforce_role_alternation(msgs)
assert result[1]["role"] == "assistant"
assert result[1]["tool_calls"] == [{"id": "1"}]
assert result[1]["content"] is None
assert result[2]["role"] == "tool"
def test_consecutive_assistant_does_not_overwrite_existing_tool_call_message(self):
msgs = [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]},
{"role": "assistant", "content": "Later plain assistant"},
{"role": "tool", "content": "result1", "tool_call_id": "1"},
{"role": "user", "content": "Next"},
]
result = LLMProvider._enforce_role_alternation(msgs)
assert result[1]["role"] == "assistant"
assert result[1]["tool_calls"] == [{"id": "1"}]
assert result[1]["content"] is None
assert result[2]["role"] == "tool"
def test_non_string_content_uses_latest(self):
msgs = [
{"role": "user", "content": [{"type": "text", "text": "A"}]},
{"role": "user", "content": "B"},
]
result = LLMProvider._enforce_role_alternation(msgs)
assert len(result) == 1
assert result[0]["content"] == "B"
def test_original_messages_not_mutated(self):
msgs = [
{"role": "user", "content": "Hello"},
{"role": "user", "content": "World"},
]
original_first = dict(msgs[0])
LLMProvider._enforce_role_alternation(msgs)
assert msgs[0] == original_first
assert len(msgs) == 2
def test_trailing_assistant_recovered_as_user_when_only_system_remains(self):
"""Subagent result injected as assistant message must not be silently dropped.
When build_messages(current_role="assistant") produces [system, assistant],
_enforce_role_alternation would drop the assistant, leaving only [system].
Most providers (e.g. Zhipu/GLM error 1214) reject such requests.
The trailing assistant should be recovered as a user message instead.
"""
msgs = [
{"role": "system", "content": "You are helpful."},
{"role": "assistant", "content": "Subagent completed successfully."},
]
result = LLMProvider._enforce_role_alternation(msgs)
assert len(result) == 2
assert result[0]["role"] == "system"
assert result[1]["role"] == "user"
assert "Subagent completed successfully." in result[1]["content"]
def test_trailing_assistant_not_recovered_when_user_message_present(self):
"""Recovery should NOT happen when a user message already exists."""
msgs = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hello!"},
]
result = LLMProvider._enforce_role_alternation(msgs)
assert len(result) == 2
assert result[-1]["role"] == "user"
def test_trailing_assistant_recovered_with_tool_result_preceding(self):
"""When only [system, tool, assistant] remains, recovery is not needed
because tool messages are valid non-system content."""
msgs = [
{"role": "system", "content": "You are helpful."},
{"role": "tool", "content": "result", "tool_call_id": "1"},
{"role": "assistant", "content": "Done."},
]
result = LLMProvider._enforce_role_alternation(msgs)
assert len(result) == 2
assert result[-1]["role"] == "tool"
def test_only_assistant_messages(self):
msgs = [
{"role": "assistant", "content": "A"},
{"role": "assistant", "content": "B"},
]
result = LLMProvider._enforce_role_alternation(msgs)
assert result == []
def test_realistic_conversation(self):
msgs = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
{"role": "user", "content": "And 3+3?"},
{"role": "user", "content": "(please be quick)"},
{"role": "assistant", "content": "6"},
]
result = LLMProvider._enforce_role_alternation(msgs)
assert len(result) == 4
assert result[2]["role"] == "assistant"
assert result[3]["role"] == "user"
assert "And 3+3?" in result[3]["content"]
assert "(please be quick)" in result[3]["content"]

View File

@ -10,7 +10,7 @@ from __future__ import annotations
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@ -54,6 +54,57 @@ def _fake_tool_call_response() -> SimpleNamespace:
return SimpleNamespace(choices=[choice], usage=usage)
def _fake_responses_response(content: str = "ok") -> MagicMock:
"""Build a minimal Responses API response object."""
resp = MagicMock()
resp.model_dump.return_value = {
"output": [{
"type": "message",
"role": "assistant",
"content": [{"type": "output_text", "text": content}],
}],
"status": "completed",
"usage": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
}
return resp
def _fake_responses_stream(text: str = "ok"):
async def _stream():
yield SimpleNamespace(type="response.output_text.delta", delta=text)
yield SimpleNamespace(
type="response.completed",
response=SimpleNamespace(
status="completed",
usage=SimpleNamespace(input_tokens=10, output_tokens=5, total_tokens=15),
output=[],
),
)
return _stream()
def _fake_chat_stream(text: str = "ok"):
async def _stream():
yield SimpleNamespace(
choices=[SimpleNamespace(finish_reason=None, delta=SimpleNamespace(content=text, reasoning_content=None, tool_calls=None))],
usage=None,
)
yield SimpleNamespace(
choices=[SimpleNamespace(finish_reason="stop", delta=SimpleNamespace(content=None, reasoning_content=None, tool_calls=None))],
usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15),
)
return _stream()
class _FakeResponsesError(Exception):
def __init__(self, status_code: int, text: str):
super().__init__(text)
self.status_code = status_code
self.response = SimpleNamespace(status_code=status_code, text=text, headers={})
class _StalledStream:
def __aiter__(self):
return self
@ -226,6 +277,224 @@ def test_openai_model_passthrough() -> None:
assert provider.get_default_model() == "gpt-4o"
@pytest.mark.asyncio
async def test_direct_openai_gpt5_uses_responses_api() -> None:
mock_chat = AsyncMock(return_value=_fake_chat_response())
mock_responses = AsyncMock(return_value=_fake_responses_response("from responses"))
spec = find_by_name("openai")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
client_instance = MockClient.return_value
client_instance.chat.completions.create = mock_chat
client_instance.responses.create = mock_responses
provider = OpenAICompatProvider(
api_key="sk-test-key",
default_model="gpt-5-chat",
spec=spec,
)
result = await provider.chat(
messages=[{"role": "user", "content": "hello"}],
model="gpt-5-chat",
)
assert result.content == "from responses"
mock_responses.assert_awaited_once()
mock_chat.assert_not_awaited()
call_kwargs = mock_responses.call_args.kwargs
assert call_kwargs["model"] == "gpt-5-chat"
assert call_kwargs["max_output_tokens"] == 4096
assert "input" in call_kwargs
assert "messages" not in call_kwargs
@pytest.mark.asyncio
async def test_direct_openai_reasoning_prefers_responses_api() -> None:
mock_chat = AsyncMock(return_value=_fake_chat_response())
mock_responses = AsyncMock(return_value=_fake_responses_response("reasoned"))
spec = find_by_name("openai")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
client_instance = MockClient.return_value
client_instance.chat.completions.create = mock_chat
client_instance.responses.create = mock_responses
provider = OpenAICompatProvider(
api_key="sk-test-key",
default_model="gpt-4o",
spec=spec,
)
await provider.chat(
messages=[{"role": "user", "content": "hello"}],
model="gpt-4o",
reasoning_effort="medium",
)
mock_responses.assert_awaited_once()
mock_chat.assert_not_awaited()
call_kwargs = mock_responses.call_args.kwargs
assert call_kwargs["reasoning"] == {"effort": "medium"}
assert call_kwargs["include"] == ["reasoning.encrypted_content"]
@pytest.mark.asyncio
async def test_direct_openai_gpt4o_stays_on_chat_completions() -> None:
mock_chat = AsyncMock(return_value=_fake_chat_response())
mock_responses = AsyncMock(return_value=_fake_responses_response())
spec = find_by_name("openai")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
client_instance = MockClient.return_value
client_instance.chat.completions.create = mock_chat
client_instance.responses.create = mock_responses
provider = OpenAICompatProvider(
api_key="sk-test-key",
default_model="gpt-4o",
spec=spec,
)
await provider.chat(
messages=[{"role": "user", "content": "hello"}],
model="gpt-4o",
)
mock_chat.assert_awaited_once()
mock_responses.assert_not_awaited()
@pytest.mark.asyncio
async def test_openrouter_gpt5_stays_on_chat_completions() -> None:
mock_chat = AsyncMock(return_value=_fake_chat_response())
mock_responses = AsyncMock(return_value=_fake_responses_response())
spec = find_by_name("openrouter")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
client_instance = MockClient.return_value
client_instance.chat.completions.create = mock_chat
client_instance.responses.create = mock_responses
provider = OpenAICompatProvider(
api_key="sk-or-test-key",
api_base="https://openrouter.ai/api/v1",
default_model="openai/gpt-5",
spec=spec,
)
await provider.chat(
messages=[{"role": "user", "content": "hello"}],
model="openai/gpt-5",
)
mock_chat.assert_awaited_once()
mock_responses.assert_not_awaited()
@pytest.mark.asyncio
async def test_direct_openai_streaming_gpt5_uses_responses_api() -> None:
mock_chat = AsyncMock(return_value=_StalledStream())
mock_responses = AsyncMock(return_value=_fake_responses_stream("hi"))
spec = find_by_name("openai")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
client_instance = MockClient.return_value
client_instance.chat.completions.create = mock_chat
client_instance.responses.create = mock_responses
provider = OpenAICompatProvider(
api_key="sk-test-key",
default_model="gpt-5-chat",
spec=spec,
)
result = await provider.chat_stream(
messages=[{"role": "user", "content": "hello"}],
model="gpt-5-chat",
)
assert result.content == "hi"
assert result.finish_reason == "stop"
mock_responses.assert_awaited_once()
mock_chat.assert_not_awaited()
@pytest.mark.asyncio
async def test_direct_openai_responses_404_falls_back_to_chat_completions() -> None:
mock_chat = AsyncMock(return_value=_fake_chat_response("from chat"))
mock_responses = AsyncMock(side_effect=_FakeResponsesError(404, "Responses endpoint not supported"))
spec = find_by_name("openai")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
client_instance = MockClient.return_value
client_instance.chat.completions.create = mock_chat
client_instance.responses.create = mock_responses
provider = OpenAICompatProvider(
api_key="sk-test-key",
default_model="gpt-5-chat",
spec=spec,
)
result = await provider.chat(
messages=[{"role": "user", "content": "hello"}],
model="gpt-5-chat",
)
assert result.content == "from chat"
mock_responses.assert_awaited_once()
mock_chat.assert_awaited_once()
@pytest.mark.asyncio
async def test_direct_openai_stream_responses_unsupported_param_falls_back() -> None:
mock_chat = AsyncMock(return_value=_fake_chat_stream("fallback stream"))
mock_responses = AsyncMock(
side_effect=_FakeResponsesError(400, "Unknown parameter: max_output_tokens for Responses API")
)
spec = find_by_name("openai")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
client_instance = MockClient.return_value
client_instance.chat.completions.create = mock_chat
client_instance.responses.create = mock_responses
provider = OpenAICompatProvider(
api_key="sk-test-key",
default_model="gpt-5-chat",
spec=spec,
)
result = await provider.chat_stream(
messages=[{"role": "user", "content": "hello"}],
model="gpt-5-chat",
)
assert result.content == "fallback stream"
mock_responses.assert_awaited_once()
mock_chat.assert_awaited_once()
@pytest.mark.asyncio
async def test_direct_openai_responses_rate_limit_does_not_fallback() -> None:
mock_chat = AsyncMock(return_value=_fake_chat_response("from chat"))
mock_responses = AsyncMock(side_effect=_FakeResponsesError(429, "rate limit"))
spec = find_by_name("openai")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
client_instance = MockClient.return_value
client_instance.chat.completions.create = mock_chat
client_instance.responses.create = mock_responses
provider = OpenAICompatProvider(
api_key="sk-test-key",
default_model="gpt-5-chat",
spec=spec,
)
result = await provider.chat(
messages=[{"role": "user", "content": "hello"}],
model="gpt-5-chat",
)
assert result.finish_reason == "error"
mock_responses.assert_awaited_once()
mock_chat.assert_not_awaited()
def test_openai_compat_supports_temperature_matches_reasoning_model_rules() -> None:
assert OpenAICompatProvider._supports_temperature("gpt-4o") is True
assert OpenAICompatProvider._supports_temperature("gpt-5-chat") is False
@ -263,6 +532,7 @@ def test_openai_compat_preserves_message_level_reasoning_fields() -> None:
provider = OpenAICompatProvider()
sanitized = provider._sanitize_messages([
{"role": "user", "content": "hi"},
{
"role": "assistant",
"content": "done",
@ -276,12 +546,42 @@ def test_openai_compat_preserves_message_level_reasoning_fields() -> None:
"extra_content": {"google": {"thought_signature": "sig"}},
}
],
}
},
{"role": "user", "content": "thanks"},
])
assert sanitized[0]["reasoning_content"] == "hidden"
assert sanitized[0]["extra_content"] == {"debug": True}
assert sanitized[0]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
assert sanitized[1]["content"] is None
assert sanitized[1]["reasoning_content"] == "hidden"
assert sanitized[1]["extra_content"] == {"debug": True}
assert sanitized[1]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -> None:
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider()
sanitized = provider._sanitize_messages([
{"role": "user", "content": "不错"},
{"role": "assistant", "content": "对,破 4 万指日可待"},
{
"role": "assistant",
"content": "<think>我再查一下</think>",
"tool_calls": [
{
"id": "call_function_akxp3wqzn7ph_1",
"type": "function",
"function": {"name": "exec", "arguments": "{}"},
}
],
},
{"role": "tool", "tool_call_id": "call_function_akxp3wqzn7ph_1", "name": "exec", "content": "ok"},
{"role": "user", "content": "多少star了呢"},
])
assert sanitized[1]["role"] == "assistant"
assert sanitized[1]["content"] is None
assert sanitized[1]["tool_calls"][0]["id"] == "3ec83c30d"
assert sanitized[2]["tool_call_id"] == "3ec83c30d"
@pytest.mark.asyncio

View File

@ -1,4 +1,5 @@
import asyncio
import copy
import pytest
@ -152,7 +153,7 @@ async def test_non_transient_error_with_images_retries_without_images() -> None:
LLMResponse(content="ok, no image"),
])
response = await provider.chat_with_retry(messages=_IMAGE_MSG)
response = await provider.chat_with_retry(messages=copy.deepcopy(_IMAGE_MSG))
assert response.content == "ok, no image"
assert provider.calls == 2
@ -164,6 +165,24 @@ async def test_non_transient_error_with_images_retries_without_images() -> None:
assert any("[image: /media/test.png]" in (b.get("text") or "") for b in content)
@pytest.mark.asyncio
async def test_successful_image_retry_mutates_original_messages_in_place() -> None:
"""Successful no-image retry should update the caller's message history."""
provider = ScriptedProvider([
LLMResponse(content="model does not support images", finish_reason="error"),
LLMResponse(content="ok, no image"),
])
messages = copy.deepcopy(_IMAGE_MSG)
response = await provider.chat_with_retry(messages=messages)
assert response.content == "ok, no image"
content = messages[0]["content"]
assert isinstance(content, list)
assert all(block.get("type") != "image_url" for block in content)
assert any("[image: /media/test.png]" in (block.get("text") or "") for block in content)
@pytest.mark.asyncio
async def test_non_transient_error_without_images_no_retry() -> None:
"""Non-transient errors without image content are returned immediately."""
@ -187,7 +206,7 @@ async def test_image_fallback_returns_error_on_second_failure() -> None:
LLMResponse(content="still failing", finish_reason="error"),
])
response = await provider.chat_with_retry(messages=_IMAGE_MSG)
response = await provider.chat_with_retry(messages=copy.deepcopy(_IMAGE_MSG))
assert provider.calls == 2
assert response.content == "still failing"
@ -202,7 +221,7 @@ async def test_image_fallback_without_meta_uses_default_placeholder() -> None:
LLMResponse(content="ok"),
])
response = await provider.chat_with_retry(messages=_IMAGE_MSG_NO_META)
response = await provider.chat_with_retry(messages=copy.deepcopy(_IMAGE_MSG_NO_META))
assert response.content == "ok"
assert provider.calls == 2

View File

@ -10,8 +10,7 @@ import pytest
import pytest_asyncio
from nanobot.api.server import (
API_CHAT_ID,
API_SESSION_KEY,
_FileSizeExceeded,
_parse_json_content,
_save_base64_data_url,
create_app,
@ -91,6 +90,15 @@ def test_save_base64_data_url_handles_unknown_mime(tmp_path) -> None:
assert result.endswith(".bin")
def test_save_base64_data_url_rejects_oversized_payload(tmp_path) -> None:
"""Base64 uploads should respect the same per-file limit as multipart."""
large_payload = base64.b64encode(b"x" * (11 * 1024 * 1024)).decode()
data_url = f"data:image/png;base64,{large_payload}"
with pytest.raises(_FileSizeExceeded, match="10MB limit"):
_save_base64_data_url(data_url, tmp_path)
def test_parse_json_content_extracts_text_and_media(tmp_path) -> None:
"""Parse JSON with text + base64 image saves image and returns paths."""
b64_data = base64.b64encode(b"img").decode()
@ -144,6 +152,31 @@ def test_parse_json_content_validates_user_role() -> None:
_parse_json_content(body)
def test_parse_json_content_rejects_oversized_base64_file(tmp_path) -> None:
"""Oversized JSON data URLs should fail before writing to disk."""
large_payload = base64.b64encode(b"x" * (11 * 1024 * 1024)).decode()
body = {
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "describe"},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{large_payload}"}},
],
}
]
}
import os
original_cwd = os.getcwd()
os.chdir(tmp_path)
try:
with pytest.raises(_FileSizeExceeded, match="10MB limit"):
_parse_json_content(body)
finally:
os.chdir(original_cwd)
# ---------------------------------------------------------------------------
# Multipart upload tests
# ---------------------------------------------------------------------------

View File

@ -2,7 +2,6 @@
from __future__ import annotations
import pytest
from pathlib import Path
from nanobot.agent.context import ContextBuilder
@ -64,3 +63,19 @@ def test_build_user_content_mixed_image_and_document(tmp_path: Path) -> None:
assert any(b["type"] == "image_url" for b in result)
text_parts = [b.get("text", "") for b in result if b.get("type") == "text"]
assert any("Report text here" in t for t in text_parts)
def test_build_user_content_skips_document_extraction_errors(tmp_path: Path, monkeypatch) -> None:
"""Document extraction errors should not be embedded into the user prompt."""
docx_path = tmp_path / "broken.docx"
docx_path.write_text("not a real docx", encoding="utf-8")
builder = _make_builder(tmp_path)
monkeypatch.setattr(
"nanobot.utils.document.extract_text",
lambda _path: "[error: failed to extract DOCX: boom]",
)
result = builder._build_user_content("summarize this", [str(docx_path)])
assert result == "summarize this"

View File

@ -1,10 +1,7 @@
"""Tests for document text extraction utilities."""
import io
from pathlib import Path
import pytest
from nanobot.utils.document import (
SUPPORTED_EXTENSIONS,
_is_text_extension,

View File

@ -0,0 +1,41 @@
from __future__ import annotations
import subprocess
import sys
import textwrap
from pathlib import Path
import tomllib
def test_source_checkout_import_uses_pyproject_version_without_metadata() -> None:
repo_root = Path(__file__).resolve().parents[1]
expected = tomllib.loads((repo_root / "pyproject.toml").read_text(encoding="utf-8"))["project"][
"version"
]
script = textwrap.dedent(
f"""
import sys
import types
sys.path.insert(0, {str(repo_root)!r})
fake = types.ModuleType("nanobot.nanobot")
fake.Nanobot = object
fake.RunResult = object
sys.modules["nanobot.nanobot"] = fake
import nanobot
print(nanobot.__version__)
"""
)
proc = subprocess.run(
[sys.executable, "-S", "-c", script],
capture_output=True,
text=True,
check=False,
)
assert proc.returncode == 0, proc.stderr
assert proc.stdout.strip() == expected

View File

@ -0,0 +1,31 @@
import inspect
from types import SimpleNamespace
def test_sanitize_persisted_blocks_truncate_text_shadowing_regression() -> None:
"""Regression: avoid bool param shadowing imported truncate_text.
Buggy behavior (historical):
- loop.py imports `truncate_text` from helpers
- `_sanitize_persisted_blocks(..., truncate_text: bool=...)` uses same name
- when called with `truncate_text=True`, function body executes `truncate_text(text, ...)`
which resolves to bool and raises `TypeError: 'bool' object is not callable`.
This test asserts the fixed API exists and truncation works without raising.
"""
from nanobot.agent.loop import AgentLoop
sig = inspect.signature(AgentLoop._sanitize_persisted_blocks)
assert "should_truncate_text" in sig.parameters
assert "truncate_text" not in sig.parameters
dummy = SimpleNamespace(max_tool_result_chars=5)
content = [{"type": "text", "text": "0123456789"}]
out = AgentLoop._sanitize_persisted_blocks(dummy, content, should_truncate_text=True)
assert isinstance(out, list)
assert out and out[0]["type"] == "text"
assert isinstance(out[0]["text"], str)
assert out[0]["text"] != content[0]["text"]

View File

@ -0,0 +1,423 @@
"""Tests for advanced EditFileTool enhancements inspired by claude-code:
- Delete-line newline cleanup
- Smart quote normalization (curly straight)
- Quote style preservation in replacements
- Indentation preservation when fallback match is trimmed
- Trailing whitespace stripping for new_text
- File size protection
- Stale detection with content-equality fallback
"""
import os
import time
import pytest
from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, _find_match
from nanobot.agent.tools import file_state
@pytest.fixture(autouse=True)
def _clear_file_state():
file_state.clear()
yield
file_state.clear()
# ---------------------------------------------------------------------------
# Delete-line newline cleanup
# ---------------------------------------------------------------------------
class TestDeleteLineCleanup:
"""When new_text='' and deleting a line, trailing newline should be consumed."""
@pytest.fixture()
def tool(self, tmp_path):
return EditFileTool(workspace=tmp_path)
@pytest.mark.asyncio
async def test_delete_line_consumes_trailing_newline(self, tool, tmp_path):
f = tmp_path / "a.py"
f.write_text("line1\nline2\nline3\n", encoding="utf-8")
result = await tool.execute(path=str(f), old_text="line2", new_text="")
assert "Successfully" in result
content = f.read_text()
# Should not leave a blank line where line2 was
assert content == "line1\nline3\n"
@pytest.mark.asyncio
async def test_delete_line_with_explicit_newline_in_old_text(self, tool, tmp_path):
f = tmp_path / "a.py"
f.write_text("line1\nline2\nline3\n", encoding="utf-8")
result = await tool.execute(path=str(f), old_text="line2\n", new_text="")
assert "Successfully" in result
assert f.read_text() == "line1\nline3\n"
@pytest.mark.asyncio
async def test_delete_preserves_content_when_not_trailing_newline(self, tool, tmp_path):
"""Deleting a word mid-line should not consume extra characters."""
f = tmp_path / "a.py"
f.write_text("hello world here\n", encoding="utf-8")
result = await tool.execute(path=str(f), old_text="world ", new_text="")
assert "Successfully" in result
assert f.read_text() == "hello here\n"
# ---------------------------------------------------------------------------
# Smart quote normalization
# ---------------------------------------------------------------------------
class TestSmartQuoteNormalization:
"""_find_match should handle curly ↔ straight quote fallback."""
def test_curly_double_quotes_match_straight(self):
content = 'She said \u201chello\u201d to him'
old_text = 'She said "hello" to him'
match, count = _find_match(content, old_text)
assert match is not None
assert count == 1
# Returned match should be the ORIGINAL content with curly quotes
assert "\u201c" in match
def test_curly_single_quotes_match_straight(self):
content = "it\u2019s a test"
old_text = "it's a test"
match, count = _find_match(content, old_text)
assert match is not None
assert count == 1
assert "\u2019" in match
def test_straight_matches_curly_in_old_text(self):
content = 'x = "hello"'
old_text = 'x = \u201chello\u201d'
match, count = _find_match(content, old_text)
assert match is not None
assert count == 1
def test_exact_match_still_preferred_over_quote_normalization(self):
content = 'x = "hello"'
old_text = 'x = "hello"'
match, count = _find_match(content, old_text)
assert match == old_text
assert count == 1
class TestQuoteStylePreservation:
"""When quote-normalized matching occurs, replacement should preserve actual quote style."""
@pytest.fixture()
def tool(self, tmp_path):
return EditFileTool(workspace=tmp_path)
@pytest.mark.asyncio
async def test_replacement_preserves_curly_double_quotes(self, tool, tmp_path):
f = tmp_path / "quotes.txt"
f.write_text('message = “hello”\n', encoding="utf-8")
result = await tool.execute(
path=str(f),
old_text='message = "hello"',
new_text='message = "goodbye"',
)
assert "Successfully" in result
assert f.read_text(encoding="utf-8") == 'message = “goodbye”\n'
@pytest.mark.asyncio
async def test_replacement_preserves_curly_apostrophe(self, tool, tmp_path):
f = tmp_path / "apostrophe.txt"
f.write_text("its fine\n", encoding="utf-8")
result = await tool.execute(
path=str(f),
old_text="it's fine",
new_text="it's better",
)
assert "Successfully" in result
assert f.read_text(encoding="utf-8") == "its better\n"
# ---------------------------------------------------------------------------
# Indentation preservation
# ---------------------------------------------------------------------------
class TestIndentationPreservation:
"""Replacement should keep outer indentation when trim fallback matched."""
@pytest.fixture()
def tool(self, tmp_path):
return EditFileTool(workspace=tmp_path)
@pytest.mark.asyncio
async def test_trim_fallback_preserves_outer_indentation(self, tool, tmp_path):
f = tmp_path / "indent.py"
f.write_text(
"if True:\n"
" def foo():\n"
" pass\n",
encoding="utf-8",
)
result = await tool.execute(
path=str(f),
old_text="def foo():\n pass",
new_text="def bar():\n return 1",
)
assert "Successfully" in result
assert f.read_text(encoding="utf-8") == (
"if True:\n"
" def bar():\n"
" return 1\n"
)
# ---------------------------------------------------------------------------
# Failure diagnostics
# ---------------------------------------------------------------------------
class TestEditDiagnostics:
"""Failure paths should offer actionable hints."""
@pytest.fixture()
def tool(self, tmp_path):
return EditFileTool(workspace=tmp_path)
@pytest.mark.asyncio
async def test_ambiguous_match_reports_candidate_lines(self, tool, tmp_path):
f = tmp_path / "dup.py"
f.write_text("aaa\nbbb\naaa\nbbb\n", encoding="utf-8")
result = await tool.execute(path=str(f), old_text="aaa\nbbb", new_text="xxx")
assert "appears 2 times" in result.lower()
assert "line 1" in result.lower()
assert "line 3" in result.lower()
assert "replace_all=true" in result
@pytest.mark.asyncio
async def test_not_found_reports_whitespace_hint(self, tool, tmp_path):
f = tmp_path / "space.py"
f.write_text("value = 1\n", encoding="utf-8")
result = await tool.execute(path=str(f), old_text="value = 1", new_text="value = 2")
assert "Error" in result
assert "whitespace" in result.lower()
@pytest.mark.asyncio
async def test_not_found_reports_case_hint(self, tool, tmp_path):
f = tmp_path / "case.py"
f.write_text("HelloWorld\n", encoding="utf-8")
result = await tool.execute(path=str(f), old_text="helloworld", new_text="goodbye")
assert "Error" in result
assert "letter case differs" in result.lower()
# ---------------------------------------------------------------------------
# Advanced fallback replacement behavior
# ---------------------------------------------------------------------------
class TestAdvancedReplaceAll:
"""replace_all should work correctly for fallback-based matches too."""
@pytest.fixture()
def tool(self, tmp_path):
return EditFileTool(workspace=tmp_path)
@pytest.mark.asyncio
async def test_replace_all_preserves_each_match_indentation(self, tool, tmp_path):
f = tmp_path / "indent_multi.py"
f.write_text(
"if a:\n"
" def foo():\n"
" pass\n"
"if b:\n"
" def foo():\n"
" pass\n",
encoding="utf-8",
)
result = await tool.execute(
path=str(f),
old_text="def foo():\n pass",
new_text="def bar():\n return 1",
replace_all=True,
)
assert "Successfully" in result
assert f.read_text(encoding="utf-8") == (
"if a:\n"
" def bar():\n"
" return 1\n"
"if b:\n"
" def bar():\n"
" return 1\n"
)
@pytest.mark.asyncio
async def test_trim_and_quote_fallback_match_succeeds(self, tool, tmp_path):
f = tmp_path / "quote_indent.py"
f.write_text(" message = “hello”\n", encoding="utf-8")
result = await tool.execute(
path=str(f),
old_text='message = "hello"',
new_text='message = "goodbye"',
)
assert "Successfully" in result
assert f.read_text(encoding="utf-8") == " message = “goodbye”\n"
# ---------------------------------------------------------------------------
# Advanced fallback replacement behavior
# ---------------------------------------------------------------------------
class TestAdvancedReplaceAll:
"""replace_all should work correctly for fallback-based matches too."""
@pytest.fixture()
def tool(self, tmp_path):
return EditFileTool(workspace=tmp_path)
@pytest.mark.asyncio
async def test_replace_all_preserves_each_match_indentation(self, tool, tmp_path):
f = tmp_path / "indent_multi.py"
f.write_text(
"if a:\n"
" def foo():\n"
" pass\n"
"if b:\n"
" def foo():\n"
" pass\n",
encoding="utf-8",
)
result = await tool.execute(
path=str(f),
old_text="def foo():\n pass",
new_text="def bar():\n return 1",
replace_all=True,
)
assert "Successfully" in result
assert f.read_text(encoding="utf-8") == (
"if a:\n"
" def bar():\n"
" return 1\n"
"if b:\n"
" def bar():\n"
" return 1\n"
)
@pytest.mark.asyncio
async def test_trim_and_quote_fallback_match_succeeds(self, tool, tmp_path):
f = tmp_path / "quote_indent.py"
f.write_text(" message = “hello”\n", encoding="utf-8")
result = await tool.execute(
path=str(f),
old_text='message = "hello"',
new_text='message = "goodbye"',
)
assert "Successfully" in result
assert f.read_text(encoding="utf-8") == " message = “goodbye”\n"
# ---------------------------------------------------------------------------
# Trailing whitespace stripping on new_text
# ---------------------------------------------------------------------------
class TestTrailingWhitespaceStrip:
"""new_text trailing whitespace should be stripped (except .md files)."""
@pytest.fixture()
def tool(self, tmp_path):
return EditFileTool(workspace=tmp_path)
@pytest.mark.asyncio
async def test_strips_trailing_whitespace_from_new_text(self, tool, tmp_path):
f = tmp_path / "a.py"
f.write_text("x = 1\n", encoding="utf-8")
result = await tool.execute(
path=str(f), old_text="x = 1", new_text="x = 2 \ny = 3 ",
)
assert "Successfully" in result
content = f.read_text()
assert "x = 2\ny = 3\n" == content
@pytest.mark.asyncio
async def test_preserves_trailing_whitespace_in_markdown(self, tool, tmp_path):
f = tmp_path / "doc.md"
f.write_text("# Title\n", encoding="utf-8")
# Markdown uses trailing double-space for line breaks
result = await tool.execute(
path=str(f), old_text="# Title", new_text="# Title \nSubtitle ",
)
assert "Successfully" in result
content = f.read_text()
# Trailing spaces should be preserved for markdown
assert "Title " in content
assert "Subtitle " in content
# ---------------------------------------------------------------------------
# File size protection
# ---------------------------------------------------------------------------
class TestFileSizeProtection:
"""Editing extremely large files should be rejected."""
@pytest.fixture()
def tool(self, tmp_path):
return EditFileTool(workspace=tmp_path)
@pytest.mark.asyncio
async def test_rejects_file_over_size_limit(self, tool, tmp_path):
f = tmp_path / "huge.txt"
f.write_text("x", encoding="utf-8")
# Monkey-patch the file size check by creating a stat mock
original_stat = f.stat
class FakeStat:
def __init__(self, real_stat):
self._real = real_stat
def __getattr__(self, name):
return getattr(self._real, name)
@property
def st_size(self):
return 2 * 1024 * 1024 * 1024 # 2 GiB
import unittest.mock
with unittest.mock.patch.object(type(f), 'stat', return_value=FakeStat(f.stat())):
result = await tool.execute(path=str(f), old_text="x", new_text="y")
assert "Error" in result
assert "too large" in result.lower() or "size" in result.lower()
# ---------------------------------------------------------------------------
# Stale detection with content-equality fallback
# ---------------------------------------------------------------------------
class TestStaleDetectionContentFallback:
"""When mtime changed but file content is unchanged, edit should proceed without warning."""
@pytest.fixture()
def read_tool(self, tmp_path):
return ReadFileTool(workspace=tmp_path)
@pytest.fixture()
def edit_tool(self, tmp_path):
return EditFileTool(workspace=tmp_path)
@pytest.mark.asyncio
async def test_mtime_bump_same_content_no_warning(self, read_tool, edit_tool, tmp_path):
f = tmp_path / "a.py"
f.write_text("hello world", encoding="utf-8")
await read_tool.execute(path=str(f))
# Touch the file to bump mtime without changing content
time.sleep(0.05)
original_content = f.read_text()
f.write_text(original_content, encoding="utf-8")
result = await edit_tool.execute(path=str(f), old_text="world", new_text="earth")
assert "Successfully" in result
# Should NOT warn about modification since content is the same
assert "modified" not in result.lower()

View File

@ -0,0 +1,152 @@
"""Tests for EditFileTool enhancements: read-before-edit tracking, path suggestions,
.ipynb detection, and create-file semantics."""
import pytest
from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, WriteFileTool
from nanobot.agent.tools import file_state
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _clear_file_state():
"""Reset global read-state between tests."""
file_state.clear()
yield
file_state.clear()
# ---------------------------------------------------------------------------
# Read-before-edit tracking
# ---------------------------------------------------------------------------
class TestEditReadTracking:
"""edit_file should warn when file hasn't been read first."""
@pytest.fixture()
def read_tool(self, tmp_path):
return ReadFileTool(workspace=tmp_path)
@pytest.fixture()
def edit_tool(self, tmp_path):
return EditFileTool(workspace=tmp_path)
@pytest.mark.asyncio
async def test_edit_warns_if_file_not_read_first(self, edit_tool, tmp_path):
f = tmp_path / "a.py"
f.write_text("hello world", encoding="utf-8")
result = await edit_tool.execute(path=str(f), old_text="world", new_text="earth")
# Should still succeed but include a warning
assert "Successfully" in result
assert "not been read" in result.lower() or "warning" in result.lower()
@pytest.mark.asyncio
async def test_edit_succeeds_cleanly_after_read(self, read_tool, edit_tool, tmp_path):
f = tmp_path / "a.py"
f.write_text("hello world", encoding="utf-8")
await read_tool.execute(path=str(f))
result = await edit_tool.execute(path=str(f), old_text="world", new_text="earth")
assert "Successfully" in result
# No warning when file was read first
assert "not been read" not in result.lower()
assert f.read_text() == "hello earth"
@pytest.mark.asyncio
async def test_edit_warns_if_file_modified_since_read(self, read_tool, edit_tool, tmp_path):
f = tmp_path / "a.py"
f.write_text("hello world", encoding="utf-8")
await read_tool.execute(path=str(f))
# External modification
f.write_text("hello universe", encoding="utf-8")
result = await edit_tool.execute(path=str(f), old_text="universe", new_text="earth")
assert "Successfully" in result
assert "modified" in result.lower() or "warning" in result.lower()
# ---------------------------------------------------------------------------
# Create-file semantics
# ---------------------------------------------------------------------------
class TestEditCreateFile:
"""edit_file with old_text='' creates new file if not exists."""
@pytest.fixture()
def tool(self, tmp_path):
return EditFileTool(workspace=tmp_path)
@pytest.mark.asyncio
async def test_create_new_file_with_empty_old_text(self, tool, tmp_path):
f = tmp_path / "subdir" / "new.py"
result = await tool.execute(path=str(f), old_text="", new_text="print('hi')")
assert "created" in result.lower() or "Successfully" in result
assert f.exists()
assert f.read_text() == "print('hi')"
@pytest.mark.asyncio
async def test_create_fails_if_file_already_exists_and_not_empty(self, tool, tmp_path):
f = tmp_path / "existing.py"
f.write_text("existing content", encoding="utf-8")
result = await tool.execute(path=str(f), old_text="", new_text="new content")
assert "Error" in result or "already exists" in result.lower()
# File should be unchanged
assert f.read_text() == "existing content"
@pytest.mark.asyncio
async def test_create_succeeds_if_file_exists_but_empty(self, tool, tmp_path):
f = tmp_path / "empty.py"
f.write_text("", encoding="utf-8")
result = await tool.execute(path=str(f), old_text="", new_text="print('hi')")
assert "Successfully" in result
assert f.read_text() == "print('hi')"
# ---------------------------------------------------------------------------
# .ipynb detection
# ---------------------------------------------------------------------------
class TestEditIpynbDetection:
"""edit_file should refuse .ipynb and suggest notebook_edit."""
@pytest.fixture()
def tool(self, tmp_path):
return EditFileTool(workspace=tmp_path)
@pytest.mark.asyncio
async def test_ipynb_rejected_with_suggestion(self, tool, tmp_path):
f = tmp_path / "analysis.ipynb"
f.write_text('{"cells": []}', encoding="utf-8")
result = await tool.execute(path=str(f), old_text="x", new_text="y")
assert "notebook" in result.lower()
# ---------------------------------------------------------------------------
# Path suggestion on not-found
# ---------------------------------------------------------------------------
class TestEditPathSuggestion:
"""edit_file should suggest similar paths on not-found."""
@pytest.fixture()
def tool(self, tmp_path):
return EditFileTool(workspace=tmp_path)
@pytest.mark.asyncio
async def test_suggests_similar_filename(self, tool, tmp_path):
f = tmp_path / "config.py"
f.write_text("x = 1", encoding="utf-8")
# Typo: conifg.py
result = await tool.execute(
path=str(tmp_path / "conifg.py"), old_text="x = 1", new_text="x = 2",
)
assert "Error" in result
assert "config.py" in result
@pytest.mark.asyncio
async def test_shows_cwd_in_error(self, tool, tmp_path):
result = await tool.execute(
path=str(tmp_path / "nonexistent.py"), old_text="a", new_text="b",
)
assert "Error" in result

View File

@ -43,3 +43,34 @@ async def test_exec_path_append_preserves_system_path():
tool = ExecTool(path_append="/opt/custom/bin")
result = await tool.execute(command="ls /")
assert "Exit code: 0" in result
@_UNIX_ONLY
@pytest.mark.asyncio
async def test_exec_allowed_env_keys_passthrough(monkeypatch):
"""Env vars listed in allowed_env_keys should be visible to commands."""
monkeypatch.setenv("MY_CUSTOM_VAR", "hello-from-config")
tool = ExecTool(allowed_env_keys=["MY_CUSTOM_VAR"])
result = await tool.execute(command="printenv MY_CUSTOM_VAR")
assert "hello-from-config" in result
@_UNIX_ONLY
@pytest.mark.asyncio
async def test_exec_allowed_env_keys_does_not_leak_others(monkeypatch):
"""Env vars NOT in allowed_env_keys should still be blocked."""
monkeypatch.setenv("MY_CUSTOM_VAR", "hello-from-config")
monkeypatch.setenv("MY_SECRET_VAR", "secret-value")
tool = ExecTool(allowed_env_keys=["MY_CUSTOM_VAR"])
result = await tool.execute(command="printenv MY_SECRET_VAR")
assert "secret-value" not in result
@_UNIX_ONLY
@pytest.mark.asyncio
async def test_exec_allowed_env_keys_missing_var_ignored(monkeypatch):
"""If an allowed key is not set in the parent process, it should be silently skipped."""
monkeypatch.delenv("NONEXISTENT_VAR_12345", raising=False)
tool = ExecTool(allowed_env_keys=["NONEXISTENT_VAR_12345"])
result = await tool.execute(command="printenv NONEXISTENT_VAR_12345")
assert "Exit code: 1" in result

View File

@ -5,12 +5,18 @@ strategy, and sandbox behaviour per platform — without actually running
platform-specific binaries (all subprocess calls are mocked).
"""
import sys
from unittest.mock import AsyncMock, patch
import pytest
from nanobot.agent.tools.shell import ExecTool
_WINDOWS_ENV_KEYS = {
"APPDATA", "LOCALAPPDATA", "ProgramData",
"ProgramFiles", "ProgramFiles(x86)", "ProgramW6432",
}
# ---------------------------------------------------------------------------
# _build_env
@ -21,7 +27,10 @@ class TestBuildEnvUnix:
def test_expected_keys(self):
with patch("nanobot.agent.tools.shell._IS_WINDOWS", False):
env = ExecTool()._build_env()
assert set(env) == {"HOME", "LANG", "TERM"}
expected = {"HOME", "LANG", "TERM"}
assert expected <= set(env)
if sys.platform != "win32":
assert set(env) == expected
def test_home_from_environ(self, monkeypatch):
monkeypatch.setenv("HOME", "/Users/dev")
@ -45,6 +54,7 @@ class TestBuildEnvWindows:
_EXPECTED_KEYS = {
"SYSTEMROOT", "COMSPEC", "USERPROFILE", "HOMEDRIVE",
"HOMEPATH", "TEMP", "TMP", "PATHEXT", "PATH",
*_WINDOWS_ENV_KEYS,
}
def test_expected_keys(self):

View File

@ -67,3 +67,118 @@ async def test_exec_blocks_chained_internal_url():
command="echo start && curl http://169.254.169.254/latest/meta-data/ && echo done"
)
assert "Error" in result
# --- #2989: block writes to nanobot internal state files -----------------
@pytest.mark.parametrize(
"command",
[
"cat foo >> history.jsonl",
"echo '{}' > history.jsonl",
"echo '{}' > memory/history.jsonl",
"echo '{}' > ./workspace/memory/history.jsonl",
"tee -a history.jsonl < foo",
"tee history.jsonl",
"cp /tmp/fake.jsonl history.jsonl",
"mv backup.jsonl memory/history.jsonl",
"dd if=/dev/zero of=memory/history.jsonl",
"sed -i 's/old/new/' history.jsonl",
"echo x > .dream_cursor",
"cp /tmp/x memory/.dream_cursor",
],
)
def test_exec_blocks_writes_to_history_jsonl(command):
"""Direct writes to history.jsonl / .dream_cursor must be blocked (#2989)."""
tool = ExecTool()
result = tool._guard_command(command, "/tmp")
assert result is not None
assert "dangerous pattern" in result.lower()
@pytest.mark.parametrize(
"command",
[
"cat history.jsonl",
"wc -l history.jsonl",
"tail -n 5 history.jsonl",
"grep foo history.jsonl",
"cp history.jsonl /tmp/history.backup",
"ls memory/",
"echo history.jsonl",
],
)
def test_exec_allows_reads_of_history_jsonl(command):
"""Read-only access to history.jsonl must still be allowed."""
tool = ExecTool()
result = tool._guard_command(command, "/tmp")
assert result is None
# --- #2826: working_dir must not escape the configured workspace ---------
@pytest.mark.asyncio
async def test_exec_blocks_working_dir_outside_workspace(tmp_path):
"""An LLM-supplied working_dir outside the workspace must be rejected."""
workspace = tmp_path / "workspace"
workspace.mkdir()
tool = ExecTool(working_dir=str(workspace), restrict_to_workspace=True)
result = await tool.execute(command="rm calendar.ics", working_dir="/etc")
assert "outside the configured workspace" in result
@pytest.mark.asyncio
async def test_exec_blocks_absolute_rm_via_hijacked_working_dir(tmp_path):
"""Regression for #2826: `rm /abs/path` via working_dir hijack."""
workspace = tmp_path / "workspace"
workspace.mkdir()
victim_dir = tmp_path / "outside"
victim_dir.mkdir()
victim = victim_dir / "file.ics"
victim.write_text("data")
tool = ExecTool(working_dir=str(workspace), restrict_to_workspace=True)
result = await tool.execute(
command=f"rm {victim}",
working_dir=str(victim_dir),
)
assert "outside the configured workspace" in result
assert victim.exists(), "victim file must not have been deleted"
@pytest.mark.asyncio
async def test_exec_allows_working_dir_within_workspace(tmp_path):
"""A working_dir that is a subdirectory of the workspace is fine."""
workspace = tmp_path / "workspace"
subdir = workspace / "project"
subdir.mkdir(parents=True)
tool = ExecTool(working_dir=str(workspace), restrict_to_workspace=True, timeout=5)
result = await tool.execute(command="echo ok", working_dir=str(subdir))
assert "ok" in result
assert "outside the configured workspace" not in result
@pytest.mark.asyncio
async def test_exec_allows_working_dir_equal_to_workspace(tmp_path):
"""Passing working_dir equal to the workspace root must be allowed."""
workspace = tmp_path / "workspace"
workspace.mkdir()
tool = ExecTool(working_dir=str(workspace), restrict_to_workspace=True, timeout=5)
result = await tool.execute(command="echo ok", working_dir=str(workspace))
assert "ok" in result
assert "outside the configured workspace" not in result
@pytest.mark.asyncio
async def test_exec_ignores_workspace_check_when_not_restricted(tmp_path):
"""Without restrict_to_workspace, the LLM may still choose any working_dir."""
workspace = tmp_path / "workspace"
workspace.mkdir()
other = tmp_path / "other"
other.mkdir()
tool = ExecTool(working_dir=str(workspace), restrict_to_workspace=False, timeout=5)
result = await tool.execute(command="echo ok", working_dir=str(other))
assert "ok" in result
assert "outside the configured workspace" not in result

View File

@ -271,15 +271,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_raw_names(
) -> None:
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
registry = ToolRegistry()
stack = AsyncExitStack()
await stack.__aenter__()
try:
await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=["demo"])},
registry,
stack,
)
finally:
stacks = await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=["demo"])},
registry,
)
for stack in stacks.values():
await stack.aclose()
assert registry.tool_names == ["mcp_test_demo"]
@ -291,15 +287,11 @@ async def test_connect_mcp_servers_enabled_tools_defaults_to_all(
) -> None:
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
registry = ToolRegistry()
stack = AsyncExitStack()
await stack.__aenter__()
try:
await connect_mcp_servers(
{"test": MCPServerConfig(command="fake")},
registry,
stack,
)
finally:
stacks = await connect_mcp_servers(
{"test": MCPServerConfig(command="fake")},
registry,
)
for stack in stacks.values():
await stack.aclose()
assert registry.tool_names == ["mcp_test_demo", "mcp_test_other"]
@ -311,15 +303,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_wrapped_names(
) -> None:
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
registry = ToolRegistry()
stack = AsyncExitStack()
await stack.__aenter__()
try:
await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])},
registry,
stack,
)
finally:
stacks = await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])},
registry,
)
for stack in stacks.values():
await stack.aclose()
assert registry.tool_names == ["mcp_test_demo"]
@ -331,15 +319,11 @@ async def test_connect_mcp_servers_enabled_tools_empty_list_registers_none(
) -> None:
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
registry = ToolRegistry()
stack = AsyncExitStack()
await stack.__aenter__()
try:
await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=[])},
registry,
stack,
)
finally:
stacks = await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=[])},
registry,
)
for stack in stacks.values():
await stack.aclose()
assert registry.tool_names == []
@ -358,15 +342,11 @@ async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries(
monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning)
stack = AsyncExitStack()
await stack.__aenter__()
try:
await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])},
registry,
stack,
)
finally:
stacks = await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])},
registry,
)
for stack in stacks.values():
await stack.aclose()
assert registry.tool_names == []
@ -376,6 +356,73 @@ async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries(
assert "Available wrapped names: mcp_test_demo" in warnings[-1]
@pytest.mark.asyncio
async def test_connect_mcp_servers_logs_stdio_pollution_hint(
monkeypatch: pytest.MonkeyPatch,
) -> None:
messages: list[str] = []
def _error(message: str, *args: object) -> None:
messages.append(message.format(*args))
@asynccontextmanager
async def _broken_stdio_client(_params: object):
raise RuntimeError("Parse error: Unexpected token 'INFO' before JSON-RPC headers")
yield # pragma: no cover
monkeypatch.setattr(sys.modules["mcp.client.stdio"], "stdio_client", _broken_stdio_client)
monkeypatch.setattr("nanobot.agent.tools.mcp.logger.error", _error)
registry = ToolRegistry()
stacks = await connect_mcp_servers({"gh": MCPServerConfig(command="github-mcp")}, registry)
assert stacks == {}
assert messages
assert "stdio protocol pollution" in messages[-1]
assert "stdout" in messages[-1]
assert "stderr" in messages[-1]
@pytest.mark.asyncio
async def test_connect_mcp_servers_one_failure_does_not_block_others(
monkeypatch: pytest.MonkeyPatch,
) -> None:
sessions = {"good": _make_fake_session(["demo"])}
class _SelectiveClientSession:
def __init__(self, read: object, _write: object) -> None:
self._session = sessions[read]
async def __aenter__(self) -> object:
return self._session
async def __aexit__(self, exc_type, exc, tb) -> bool:
return False
@asynccontextmanager
async def _selective_stdio_client(params: object):
if params.command == "bad":
raise RuntimeError("boom")
yield params.command, object()
monkeypatch.setattr(sys.modules["mcp"], "ClientSession", _SelectiveClientSession)
monkeypatch.setattr(sys.modules["mcp.client.stdio"], "stdio_client", _selective_stdio_client)
registry = ToolRegistry()
stacks = await connect_mcp_servers(
{
"good": MCPServerConfig(command="good"),
"bad": MCPServerConfig(command="bad"),
},
registry,
)
for stack in stacks.values():
await stack.aclose()
assert registry.tool_names == ["mcp_good_demo"]
assert set(stacks) == {"good"}
# ---------------------------------------------------------------------------
# MCPResourceWrapper tests
# ---------------------------------------------------------------------------
@ -389,9 +436,7 @@ def _make_resource_def(
return SimpleNamespace(name=name, uri=uri, description=description)
def _make_resource_wrapper(
session: object, *, timeout: float = 0.1
) -> MCPResourceWrapper:
def _make_resource_wrapper(session: object, *, timeout: float = 0.1) -> MCPResourceWrapper:
return MCPResourceWrapper(session, "srv", _make_resource_def(), resource_timeout=timeout)
@ -434,9 +479,7 @@ async def test_resource_wrapper_execute_handles_timeout() -> None:
await asyncio.sleep(1)
return SimpleNamespace(contents=[])
wrapper = _make_resource_wrapper(
SimpleNamespace(read_resource=read_resource), timeout=0.01
)
wrapper = _make_resource_wrapper(SimpleNamespace(read_resource=read_resource), timeout=0.01)
result = await wrapper.execute()
assert result == "(MCP resource read timed out after 0.01s)"
@ -464,20 +507,14 @@ def _make_prompt_def(
return SimpleNamespace(name=name, description=description, arguments=arguments)
def _make_prompt_wrapper(
session: object, *, timeout: float = 0.1
) -> MCPPromptWrapper:
return MCPPromptWrapper(
session, "srv", _make_prompt_def(), prompt_timeout=timeout
)
def _make_prompt_wrapper(session: object, *, timeout: float = 0.1) -> MCPPromptWrapper:
return MCPPromptWrapper(session, "srv", _make_prompt_def(), prompt_timeout=timeout)
def test_prompt_wrapper_properties() -> None:
arg1 = SimpleNamespace(name="topic", required=True)
arg2 = SimpleNamespace(name="style", required=False)
wrapper = MCPPromptWrapper(
None, "myserver", _make_prompt_def(arguments=[arg1, arg2])
)
wrapper = MCPPromptWrapper(None, "myserver", _make_prompt_def(arguments=[arg1, arg2]))
assert wrapper.name == "mcp_myserver_prompt_myprompt"
assert "[MCP Prompt]" in wrapper.description
assert "A test prompt" in wrapper.description
@ -528,9 +565,7 @@ async def test_prompt_wrapper_execute_handles_timeout() -> None:
await asyncio.sleep(1)
return SimpleNamespace(messages=[])
wrapper = _make_prompt_wrapper(
SimpleNamespace(get_prompt=get_prompt), timeout=0.01
)
wrapper = _make_prompt_wrapper(SimpleNamespace(get_prompt=get_prompt), timeout=0.01)
result = await wrapper.execute()
assert result == "(MCP prompt call timed out after 0.01s)"
@ -616,15 +651,11 @@ async def test_connect_registers_resources_and_prompts(
prompt_names=["prompt_c"],
)
registry = ToolRegistry()
stack = AsyncExitStack()
await stack.__aenter__()
try:
await connect_mcp_servers(
{"test": MCPServerConfig(command="fake")},
registry,
stack,
)
finally:
stacks = await connect_mcp_servers(
{"test": MCPServerConfig(command="fake")},
registry,
)
for stack in stacks.values():
await stack.aclose()
assert "mcp_test_tool_a" in registry.tool_names

View File

@ -1,5 +1,6 @@
"""Test message tool suppress logic for final replies."""
import asyncio
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
@ -86,6 +87,42 @@ class TestMessageToolSuppressLogic:
assert result is not None
assert "Hello" in result.content
@pytest.mark.asyncio
async def test_injected_followup_with_message_tool_does_not_emit_empty_fallback(
self, tmp_path: Path
) -> None:
loop = _make_loop(tmp_path)
tool_call = ToolCallRequest(
id="call1", name="message",
arguments={"content": "Tool reply", "channel": "feishu", "chat_id": "chat123"},
)
calls = iter([
LLMResponse(content="First answer", tool_calls=[]),
LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="", tool_calls=[]),
LLMResponse(content="", tool_calls=[]),
LLMResponse(content="", tool_calls=[]),
])
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
sent: list[OutboundMessage] = []
mt = loop.tools.get("message")
if isinstance(mt, MessageTool):
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
pending_queue = asyncio.Queue()
await pending_queue.put(
InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="follow-up")
)
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Start")
result = await loop._process_message(msg, pending_queue=pending_queue)
assert len(sent) == 1
assert sent[0].content == "Tool reply"
assert result is None
async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path)
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
@ -107,7 +144,7 @@ class TestMessageToolSuppressLogic:
async def on_progress(content: str, *, tool_hint: bool = False) -> None:
progress.append((content, tool_hint))
final_content, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
assert final_content == "Done"
assert progress == [

View File

@ -0,0 +1,147 @@
"""Tests for NotebookEditTool — Jupyter .ipynb editing."""
import json
import pytest
from nanobot.agent.tools.notebook import NotebookEditTool
def _make_notebook(cells: list[dict] | None = None, nbformat: int = 4, nbformat_minor: int = 5) -> dict:
"""Build a minimal valid .ipynb structure."""
return {
"nbformat": nbformat,
"nbformat_minor": nbformat_minor,
"metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}},
"cells": cells or [],
}
def _code_cell(source: str, cell_id: str | None = None) -> dict:
cell = {"cell_type": "code", "source": source, "metadata": {}, "outputs": [], "execution_count": None}
if cell_id:
cell["id"] = cell_id
return cell
def _md_cell(source: str, cell_id: str | None = None) -> dict:
cell = {"cell_type": "markdown", "source": source, "metadata": {}}
if cell_id:
cell["id"] = cell_id
return cell
def _write_nb(tmp_path, name: str, nb: dict) -> str:
p = tmp_path / name
p.write_text(json.dumps(nb), encoding="utf-8")
return str(p)
class TestNotebookEdit:
@pytest.fixture()
def tool(self, tmp_path):
return NotebookEditTool(workspace=tmp_path)
@pytest.mark.asyncio
async def test_replace_cell_content(self, tool, tmp_path):
nb = _make_notebook([_code_cell("print('hello')"), _code_cell("x = 1")])
path = _write_nb(tmp_path, "test.ipynb", nb)
result = await tool.execute(path=path, cell_index=0, new_source="print('world')")
assert "Successfully" in result
saved = json.loads((tmp_path / "test.ipynb").read_text())
assert saved["cells"][0]["source"] == "print('world')"
assert saved["cells"][1]["source"] == "x = 1"
@pytest.mark.asyncio
async def test_insert_cell_after_target(self, tool, tmp_path):
nb = _make_notebook([_code_cell("cell 0"), _code_cell("cell 1")])
path = _write_nb(tmp_path, "test.ipynb", nb)
result = await tool.execute(path=path, cell_index=0, new_source="inserted", edit_mode="insert")
assert "Successfully" in result
saved = json.loads((tmp_path / "test.ipynb").read_text())
assert len(saved["cells"]) == 3
assert saved["cells"][0]["source"] == "cell 0"
assert saved["cells"][1]["source"] == "inserted"
assert saved["cells"][2]["source"] == "cell 1"
@pytest.mark.asyncio
async def test_delete_cell(self, tool, tmp_path):
nb = _make_notebook([_code_cell("A"), _code_cell("B"), _code_cell("C")])
path = _write_nb(tmp_path, "test.ipynb", nb)
result = await tool.execute(path=path, cell_index=1, edit_mode="delete")
assert "Successfully" in result
saved = json.loads((tmp_path / "test.ipynb").read_text())
assert len(saved["cells"]) == 2
assert saved["cells"][0]["source"] == "A"
assert saved["cells"][1]["source"] == "C"
@pytest.mark.asyncio
async def test_create_new_notebook_from_scratch(self, tool, tmp_path):
path = str(tmp_path / "new.ipynb")
result = await tool.execute(path=path, cell_index=0, new_source="# Hello", edit_mode="insert", cell_type="markdown")
assert "Successfully" in result or "created" in result.lower()
saved = json.loads((tmp_path / "new.ipynb").read_text())
assert saved["nbformat"] == 4
assert len(saved["cells"]) == 1
assert saved["cells"][0]["cell_type"] == "markdown"
assert saved["cells"][0]["source"] == "# Hello"
@pytest.mark.asyncio
async def test_invalid_cell_index_error(self, tool, tmp_path):
nb = _make_notebook([_code_cell("only cell")])
path = _write_nb(tmp_path, "test.ipynb", nb)
result = await tool.execute(path=path, cell_index=5, new_source="x")
assert "Error" in result
@pytest.mark.asyncio
async def test_non_ipynb_rejected(self, tool, tmp_path):
f = tmp_path / "script.py"
f.write_text("pass")
result = await tool.execute(path=str(f), cell_index=0, new_source="x")
assert "Error" in result
assert ".ipynb" in result
@pytest.mark.asyncio
async def test_preserves_metadata_and_outputs(self, tool, tmp_path):
cell = _code_cell("old")
cell["outputs"] = [{"output_type": "stream", "text": "hello\n"}]
cell["execution_count"] = 42
nb = _make_notebook([cell])
path = _write_nb(tmp_path, "test.ipynb", nb)
await tool.execute(path=path, cell_index=0, new_source="new")
saved = json.loads((tmp_path / "test.ipynb").read_text())
assert saved["metadata"]["kernelspec"]["language"] == "python"
@pytest.mark.asyncio
async def test_nbformat_45_generates_cell_id(self, tool, tmp_path):
nb = _make_notebook([], nbformat_minor=5)
path = _write_nb(tmp_path, "test.ipynb", nb)
await tool.execute(path=path, cell_index=0, new_source="x = 1", edit_mode="insert")
saved = json.loads((tmp_path / "test.ipynb").read_text())
assert "id" in saved["cells"][0]
assert len(saved["cells"][0]["id"]) > 0
@pytest.mark.asyncio
async def test_insert_with_cell_type_markdown(self, tool, tmp_path):
nb = _make_notebook([_code_cell("code")])
path = _write_nb(tmp_path, "test.ipynb", nb)
await tool.execute(path=path, cell_index=0, new_source="# Title", edit_mode="insert", cell_type="markdown")
saved = json.loads((tmp_path / "test.ipynb").read_text())
assert saved["cells"][1]["cell_type"] == "markdown"
@pytest.mark.asyncio
async def test_invalid_edit_mode_rejected(self, tool, tmp_path):
nb = _make_notebook([_code_cell("code")])
path = _write_nb(tmp_path, "test.ipynb", nb)
result = await tool.execute(path=path, cell_index=0, new_source="x", edit_mode="replcae")
assert "Error" in result
assert "edit_mode" in result
@pytest.mark.asyncio
async def test_invalid_cell_type_rejected(self, tool, tmp_path):
nb = _make_notebook([_code_cell("code")])
path = _write_nb(tmp_path, "test.ipynb", nb)
result = await tool.execute(path=path, cell_index=0, new_source="x", cell_type="raw")
assert "Error" in result
assert "cell_type" in result

View File

@ -0,0 +1,180 @@
"""Tests for ReadFileTool enhancements: description fix, read dedup, PDF support, device blacklist."""
import pytest
from nanobot.agent.tools.filesystem import ReadFileTool, WriteFileTool
from nanobot.agent.tools import file_state
@pytest.fixture(autouse=True)
def _clear_file_state():
file_state.clear()
yield
file_state.clear()
# ---------------------------------------------------------------------------
# Description fix
# ---------------------------------------------------------------------------
class TestReadDescriptionFix:
def test_description_mentions_image_support(self):
tool = ReadFileTool()
assert "image" in tool.description.lower()
def test_description_no_longer_says_cannot_read_images(self):
tool = ReadFileTool()
assert "cannot read binary files or images" not in tool.description.lower()
# ---------------------------------------------------------------------------
# Read deduplication
# ---------------------------------------------------------------------------
class TestReadDedup:
"""Same file + same offset/limit + unchanged mtime -> short stub."""
@pytest.fixture()
def tool(self, tmp_path):
return ReadFileTool(workspace=tmp_path)
@pytest.fixture()
def write_tool(self, tmp_path):
return WriteFileTool(workspace=tmp_path)
@pytest.mark.asyncio
async def test_second_read_returns_unchanged_stub(self, tool, tmp_path):
f = tmp_path / "data.txt"
f.write_text("\n".join(f"line {i}" for i in range(100)), encoding="utf-8")
first = await tool.execute(path=str(f))
assert "line 0" in first
second = await tool.execute(path=str(f))
assert "unchanged" in second.lower()
# Stub should not contain file content
assert "line 0" not in second
@pytest.mark.asyncio
async def test_read_after_external_modification_returns_full(self, tool, tmp_path):
f = tmp_path / "data.txt"
f.write_text("original", encoding="utf-8")
await tool.execute(path=str(f))
# Modify the file externally
f.write_text("modified content", encoding="utf-8")
second = await tool.execute(path=str(f))
assert "modified content" in second
@pytest.mark.asyncio
async def test_different_offset_returns_full(self, tool, tmp_path):
f = tmp_path / "data.txt"
f.write_text("\n".join(f"line {i}" for i in range(1, 21)), encoding="utf-8")
await tool.execute(path=str(f), offset=1, limit=5)
second = await tool.execute(path=str(f), offset=6, limit=5)
# Different offset → full read, not stub
assert "line 6" in second
@pytest.mark.asyncio
async def test_first_read_after_write_returns_full_content(self, tool, write_tool, tmp_path):
f = tmp_path / "fresh.txt"
result = await write_tool.execute(path=str(f), content="hello")
assert "Successfully" in result
read_result = await tool.execute(path=str(f))
assert "hello" in read_result
assert "unchanged" not in read_result.lower()
@pytest.mark.asyncio
async def test_dedup_does_not_apply_to_images(self, tool, tmp_path):
f = tmp_path / "img.png"
f.write_bytes(b"\x89PNG\r\n\x1a\nfake-png-data")
first = await tool.execute(path=str(f))
assert isinstance(first, list)
second = await tool.execute(path=str(f))
# Images should always return full content blocks, not a stub
assert isinstance(second, list)
# ---------------------------------------------------------------------------
# PDF support
# ---------------------------------------------------------------------------
class TestReadPdf:
@pytest.fixture()
def tool(self, tmp_path):
return ReadFileTool(workspace=tmp_path)
@pytest.mark.asyncio
async def test_pdf_returns_text_content(self, tool, tmp_path):
fitz = pytest.importorskip("fitz")
pdf_path = tmp_path / "test.pdf"
doc = fitz.open()
page = doc.new_page()
page.insert_text((72, 72), "Hello PDF World")
doc.save(str(pdf_path))
doc.close()
result = await tool.execute(path=str(pdf_path))
assert "Hello PDF World" in result
@pytest.mark.asyncio
async def test_pdf_pages_parameter(self, tool, tmp_path):
fitz = pytest.importorskip("fitz")
pdf_path = tmp_path / "multi.pdf"
doc = fitz.open()
for i in range(5):
page = doc.new_page()
page.insert_text((72, 72), f"Page {i + 1} content")
doc.save(str(pdf_path))
doc.close()
result = await tool.execute(path=str(pdf_path), pages="2-3")
assert "Page 2 content" in result
assert "Page 3 content" in result
assert "Page 1 content" not in result
@pytest.mark.asyncio
async def test_pdf_file_not_found_error(self, tool, tmp_path):
result = await tool.execute(path=str(tmp_path / "nope.pdf"))
assert "Error" in result
assert "not found" in result
# ---------------------------------------------------------------------------
# Device path blacklist
# ---------------------------------------------------------------------------
class TestReadDeviceBlacklist:
@pytest.fixture()
def tool(self):
return ReadFileTool()
@pytest.mark.asyncio
async def test_dev_random_blocked(self, tool):
result = await tool.execute(path="/dev/random")
assert "Error" in result
assert "blocked" in result.lower() or "device" in result.lower()
@pytest.mark.asyncio
async def test_dev_urandom_blocked(self, tool):
result = await tool.execute(path="/dev/urandom")
assert "Error" in result
@pytest.mark.asyncio
async def test_dev_zero_blocked(self, tool):
result = await tool.execute(path="/dev/zero")
assert "Error" in result
@pytest.mark.asyncio
async def test_proc_fd_blocked(self, tool):
result = await tool.execute(path="/proc/self/fd/0")
assert "Error" in result
@pytest.mark.asyncio
async def test_symlink_to_dev_zero_blocked(self, tmp_path):
tool = ReadFileTool(workspace=tmp_path)
link = tmp_path / "zero-link"
link.symlink_to("/dev/zero")
result = await tool.execute(path=str(link))
assert "Error" in result
assert "blocked" in result.lower() or "device" in result.lower()

View File

@ -323,3 +323,27 @@ async def test_subagent_registers_grep_and_glob(tmp_path: Path) -> None:
assert "grep" in captured["tool_names"]
assert "glob" in captured["tool_names"]
def test_subagent_prompt_respects_disabled_skills(tmp_path: Path) -> None:
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
skills_dir = tmp_path / "skills"
(skills_dir / "alpha").mkdir(parents=True)
(skills_dir / "alpha" / "SKILL.md").write_text("# Alpha\n\nhidden\n", encoding="utf-8")
(skills_dir / "beta").mkdir(parents=True)
(skills_dir / "beta" / "SKILL.md").write_text("# Beta\n\nshown\n", encoding="utf-8")
mgr = SubagentManager(
provider=provider,
workspace=tmp_path,
bus=bus,
max_tool_result_chars=4096,
disabled_skills=["alpha"],
)
prompt = mgr._build_subagent_prompt()
assert "alpha" not in prompt
assert "beta" in prompt

View File

@ -47,3 +47,27 @@ def test_get_definitions_orders_builtins_then_mcp_tools() -> None:
"mcp_fs_list",
"mcp_git_status",
]
def test_prepare_call_read_file_rejects_non_object_params_with_actionable_hint() -> None:
registry = ToolRegistry()
registry.register(_FakeTool("read_file"))
tool, params, error = registry.prepare_call("read_file", ["foo.txt"])
assert tool is None
assert params == ["foo.txt"]
assert error is not None
assert "must be a JSON object" in error
assert "Use named parameters" in error
def test_prepare_call_other_tools_keep_generic_object_validation() -> None:
registry = ToolRegistry()
registry.register(_FakeTool("grep"))
tool, params, error = registry.prepare_call("grep", ["TODO"])
assert tool is not None
assert params == ["TODO"]
assert error == "Error: Invalid parameters for tool 'grep': parameters must be an object, got list"

View File

@ -1,7 +1,5 @@
"""Tests for multi-provider web search."""
import asyncio
import httpx
import pytest
@ -20,6 +18,25 @@ def _response(status: int = 200, json: dict | None = None) -> httpx.Response:
return r
def test_duckduckgo_search_is_exclusive():
tool = _tool(provider="duckduckgo")
assert tool.exclusive is True
assert tool.concurrency_safe is False
def test_brave_with_api_key_remains_concurrency_safe():
tool = _tool(provider="brave", api_key="brave-key")
assert tool.exclusive is False
assert tool.concurrency_safe is True
def test_brave_without_api_key_is_treated_as_duckduckgo_for_concurrency(monkeypatch):
monkeypatch.delenv("BRAVE_API_KEY", raising=False)
tool = _tool(provider="brave", api_key="")
assert tool.exclusive is True
assert tool.concurrency_safe is False
@pytest.mark.asyncio
async def test_brave_search(monkeypatch):
async def mock_get(self, url, **kw):
@ -79,7 +96,6 @@ async def test_duckduckgo_search(monkeypatch):
import nanobot.agent.tools.web as web_mod
monkeypatch.setattr(web_mod, "DDGS", MockDDGS, raising=False)
from ddgs import DDGS
monkeypatch.setattr("ddgs.DDGS", MockDDGS)
tool = _tool(provider="duckduckgo")
@ -120,6 +136,27 @@ async def test_jina_search(monkeypatch):
assert "https://jina.ai" in result
@pytest.mark.asyncio
async def test_kagi_search(monkeypatch):
async def mock_get(self, url, **kw):
assert "kagi.com/api/v0/search" in url
assert kw["headers"]["Authorization"] == "Bot kagi-key"
assert kw["params"] == {"q": "test", "limit": 2}
return _response(json={
"data": [
{"t": 0, "title": "Kagi Result", "url": "https://kagi.com", "snippet": "Premium search"},
{"t": 1, "list": ["ignored related search"]},
]
})
monkeypatch.setattr(httpx.AsyncClient, "get", mock_get)
tool = _tool(provider="kagi", api_key="kagi-key")
result = await tool.execute(query="test", count=2)
assert "Kagi Result" in result
assert "https://kagi.com" in result
assert "ignored related search" not in result
@pytest.mark.asyncio
async def test_unknown_provider():
tool = _tool(provider="unknown")
@ -189,6 +226,23 @@ async def test_jina_422_falls_back_to_duckduckgo(monkeypatch):
assert "DuckDuckGo fallback" in result
@pytest.mark.asyncio
async def test_kagi_fallback_to_duckduckgo_when_no_key(monkeypatch):
class MockDDGS:
def __init__(self, **kw):
pass
def text(self, query, max_results=5):
return [{"title": "Fallback", "href": "https://ddg.example", "body": "DuckDuckGo fallback"}]
monkeypatch.setattr("ddgs.DDGS", MockDDGS)
monkeypatch.delenv("KAGI_API_KEY", raising=False)
tool = _tool(provider="kagi", api_key="")
result = await tool.execute(query="test")
assert "Fallback" in result
@pytest.mark.asyncio
async def test_jina_search_uses_path_encoded_query(monkeypatch):
calls = {}
@ -227,5 +281,3 @@ async def test_duckduckgo_timeout_returns_error(monkeypatch):
result = await tool.execute(query="test")
gate.set()
assert "Error" in result

View File

@ -0,0 +1,65 @@
import pytest
from nanobot.utils.helpers import strip_think
class TestStripThinkTag:
"""Test <thought>...</thought> block stripping (Gemma 4 and similar models)."""
def test_closed_tag(self):
assert strip_think("Hello <thought>reasoning</thought> World") == "Hello World"
def test_unclosed_trailing_tag(self):
assert strip_think("<thought>ongoing...") == ""
def test_multiline_tag(self):
assert strip_think("<thought>\nline1\nline2\n</thought>End") == "End"
def test_tag_with_nested_angle_brackets(self):
text = "<thought>a < 3 and b > 2</thought>result"
assert strip_think(text) == "result"
def test_multiple_tag_blocks(self):
text = "A<thought>x</thought>B<thought>y</thought>C"
assert strip_think(text) == "ABC"
def test_tag_only_whitespace_inside(self):
assert strip_think("before<thought> </thought>after") == "beforeafter"
def test_self_closing_tag_not_matched(self):
assert strip_think("<thought/>some text") == "<thought/>some text"
def test_normal_text_unchanged(self):
assert strip_think("Just normal text") == "Just normal text"
def test_empty_string(self):
assert strip_think("") == ""
class TestStripThinkFalsePositive:
"""Ensure mid-content <think>/<thought> tags are NOT stripped (#3004)."""
def test_backtick_think_tag_preserved(self):
text = "*Think Stripping:* A new utility to strip `<think>` tags from output."
assert strip_think(text) == text
def test_prose_think_tag_preserved(self):
text = "The model emits <think> at the start of its response."
assert strip_think(text) == text
def test_code_block_think_tag_preserved(self):
text = "Example:\n```\ntext = re.sub(r\"<think>[\\s\\S]*\", \"\", text)\n```\nDone."
assert strip_think(text) == text
def test_backtick_thought_tag_preserved(self):
text = "Gemma 4 uses `<thought>` blocks for reasoning."
assert strip_think(text) == text
def test_prefix_unclosed_think_still_stripped(self):
assert strip_think("<think>reasoning without closing") == ""
def test_prefix_unclosed_think_with_whitespace(self):
assert strip_think(" <think>reasoning...") == ""
def test_prefix_unclosed_thought_still_stripped(self):
assert strip_think("<thought>reasoning without closing") == ""