mirror of https://github.com/HKUDS/nanobot.git
synced 2026-04-14 23:19:55 +00:00

Merge origin/main into pr-2959

Resolve the config plumbing conflicts and keep disabled skill filtering consistent for subagent prompts after syncing with main.

Made-with: Cursor

This commit is contained in: commit 09c238ca0f

189  .docs/NEW_THINGS_SHOULD_ADD_TO_WEB.md  Normal file

@@ -0,0 +1,189 @@
# Pending Web Documentation Updates

Items that need to be synced to `.web/nanobot-web/nanobot-web-page/docs/` when ready.

## Merged (ready to update)

### 1. Anthropic adaptive thinking mode (PR #2882)

- **What changed:** `reasoning_effort` now supports `"adaptive"` in addition to `"low"` / `"medium"` / `"high"`.
  When set to `"adaptive"`, the model decides when and how much to think (supported on claude-sonnet-4-6, claude-opus-4-6).
- **Where to update:**
  - `content.js` → `agents.defaults` reference section → `reasoningEffort` field description:
    current text lists `"low"`, `"medium"`, `"high"`, or `null` — add `"adaptive"`.
  - All 6 locale files (`zh-CN.js`, `zh-TW.js`, `ja.js`, `ko.js`, `es.js`, `fr.js`) → same field.
- **Source:** `nanobot/config/schema.py` line comment, `nanobot/providers/anthropic_provider.py` `_build_kwargs`.

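For the web docs update, a minimal config fragment in the README's style could look like this (the camelCase `reasoningEffort` spelling follows the `agents.defaults` reference named above — treat the exact key spelling as an assumption to verify against `content.js`):

```json
{
  "agents": {
    "defaults": {
      "reasoningEffort": "adaptive"
    }
  }
}
```
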
## Not yet merged (update after merge)

### 2. Windows shell / cross-platform exec tool (PR #2926 + PR #2941)

- **What changed:** `exec` tool now works on Windows via `cmd.exe /c`. Environment isolation is
  platform-aware: Unix passes `HOME`/`LANG`/`TERM` (bash -l handles PATH); Windows passes a curated
  set of 15 system variables (`PATH`, `SYSTEMROOT`, `COMSPEC`, `USERPROFILE`, `HOMEDRIVE`,
  `HOMEPATH`, `TEMP`, `TMP`, `PATHEXT`, `APPDATA`, `LOCALAPPDATA`, `ProgramData`, `ProgramFiles`,
  `ProgramFiles(x86)`, `ProgramW6432`) while still excluding secrets. `bwrap` sandbox is gracefully
  skipped on Windows with a warning.
- **Where to update:**
  - `content.js` → Security section → exec tool environment description:
    current text says "only HOME, LANG, TERM" — needs platform-specific note listing the 15 Windows variables.
  - All 6 locale files → same section.
- **Source:** `nanobot/agent/tools/shell.py` `_build_env`, `_spawn`.

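The platform-aware allow-list described above can be sketched as follows. This is an illustrative sketch, not the real implementation in `nanobot/agent/tools/shell.py` — the function and constant names here are assumptions:

```python
import sys

# Per the notes above: Unix keeps a minimal set (bash -l rebuilds PATH itself);
# Windows forwards 15 curated system variables. Secrets are excluded simply by
# not being on the allow-list.
UNIX_KEYS = {"HOME", "LANG", "TERM"}
WINDOWS_KEYS = {
    "PATH", "SYSTEMROOT", "COMSPEC", "USERPROFILE", "HOMEDRIVE",
    "HOMEPATH", "TEMP", "TMP", "PATHEXT", "APPDATA", "LOCALAPPDATA",
    "ProgramData", "ProgramFiles", "ProgramFiles(x86)", "ProgramW6432",
}

def build_env(host_env: dict[str, str], platform: str = sys.platform) -> dict[str, str]:
    """Return a filtered environment for the exec subprocess."""
    allowed = WINDOWS_KEYS if platform.startswith("win") else UNIX_KEYS
    return {k: v for k, v in host_env.items() if k in allowed}
```

Anything not on the active allow-list — API keys, tokens, cloud credentials — is silently dropped before the subprocess spawns.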
### 3. Channel Plugin Guide — Pydantic config requirement (PR #2850)

- **Status:** ✅ Already updated in this batch (v=20260407d).
- Code examples updated to use the `WebhookConfig(Base)` Pydantic model.
- Warning note added in all 7 languages explaining the `is_allowed()` silent failure with a plain dict.

### 4. Telegram location sharing support (PR #2910)

- **What changed:** Telegram channel now handles location messages. When a user shares a
  location pin, coordinates are forwarded to the agent as `[location: lat, lon]` — consistent
  with the existing `[image: ...]` / `[transcription: ...]` conventions. This enables MCP tools
  that accept geo coordinates (maps, weather, nearby search) to be triggered from a Telegram
  location share.
- **Where to update:**
  - `content.js` → Telegram channel section → supported message types:
    current text lists text, images, voice, audio, documents — add location pins.
  - All 6 locale files → same section.
- **Source:** `nanobot/channels/telegram.py` — `filters.LOCATION` in handler, `message.location` extraction in `_on_message`.

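The inline-attachment convention above is just string formatting; a tiny sketch (the helper name is hypothetical, not from the source):

```python
def format_location(latitude: float, longitude: float) -> str:
    """Render a shared pin in the [location: lat, lon] convention noted above,
    mirroring the existing [image: ...] / [transcription: ...] markers."""
    return f"[location: {latitude}, {longitude}]"
```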
### 5. Tool hint formatting for exec paths and dedup (PR #2926)

- **What changed:** Tool hints now fold file paths embedded in `exec` commands instead of blindly
  truncating them mid-path. This includes quoted paths with spaces on Unix and Windows. Consecutive
  hints are also deduplicated by the final formatted hint string, so different arguments are shown
  separately while truly identical calls still fold as `× N`.
- **Where to update:**
  - `content.js` → Agent loop / tool hint display section:
    explain that exec command previews abbreviate embedded paths for readability and that folding
    happens only for repeated identical rendered hints.
  - All 6 locale files → same section.
- **Source:** `nanobot/utils/tool_hints.py`, `tests/agent/test_tool_hint.py`.

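The dedup behavior ("fold only truly identical rendered hints as `× N`") can be sketched like this — a minimal illustration, not the actual `nanobot/utils/tool_hints.py` code:

```python
def fold_hints(hints: list[str]) -> list[str]:
    """Collapse runs of consecutive identical rendered hints into 'hint × N'.

    Different arguments render to different hint strings, so they stay
    separate; only byte-identical consecutive hints fold.
    """
    out: list[str] = []
    prev: str | None = None
    count = 0
    for h in hints:
        if h == prev:
            count += 1
        else:
            if prev is not None:
                out.append(prev if count == 1 else f"{prev} × {count}")
            prev, count = h, 1
    if prev is not None:
        out.append(prev if count == 1 else f"{prev} × {count}")
    return out
```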
### 6. Discord streaming replies enabled by default (PR #2939)

- **What changed:** Discord now supports the streaming reply path used by Telegram, and Discord
  config gains a `streaming` flag that defaults to `true`. This avoids the previous non-streaming
  fallback path that could end in an empty final response with some OpenAI-compatible gateways.
- **Where to update:**
  - `content.js` → Discord channel section → config reference:
    add the `streaming` field, note that it defaults to `true`, and explain it can be disabled to
    force non-streaming replies.
  - All 6 locale files → same section.
- **Source:** `nanobot/channels/discord.py`, `tests/channels/test_discord_channel.py`, `README.md`.

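A hedged sketch of how the new flag might appear in `config.json` — nesting under `channels.discord` is an assumption based on the other channel examples in the README:

```json
{
  "channels": {
    "discord": {
      "streaming": true
    }
  }
}
```

Set it to `false` to force the old single-message reply path.
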
### 7. WebSocket server channel (PR #2964)

- **What changed:** New `websocket` channel that runs a WebSocket server, allowing external clients
  (web apps, CLIs, Chrome extensions, scripts) to interact with the agent in real time via persistent
  connections. Supports streaming (`delta` + `stream_end` events), token-based authentication
  (static tokens and short-lived issued tokens via HTTP endpoint), per-connection sessions,
  TLS/SSL (WSS), and client allow-list.
- **Where to update:**
  - `content.js` → Channels section:
    add a new WebSocket channel subsection covering configuration (`channels.websocket`), wire
    protocol (`ready`, `message`, `delta`, `stream_end` events), authentication modes (static token,
    issued tokens via `tokenIssuePath`), and common deployment patterns.
  - All 6 locale files → same section.
  - README → supported channels list: add WebSocket.
- **Source:** `nanobot/channels/websocket.py`, `docs/WEBSOCKET.md` (comprehensive standalone doc).

### 8. Exec tool `allowed_env_keys` config (PR #2962)

- **What changed:** New `allowed_env_keys` field in `tools.exec` config. Users can list host
  environment variable names (e.g. `["GOPATH", "JAVA_HOME"]`) to selectively forward into the
  sandboxed subprocess. Default is an empty list — no behavior change for existing users. Works
  on both Unix and Windows.
- **Where to update:**
  - `content.js` → Security section → exec tool environment description:
    current text describes the default allow-list (HOME/LANG/TERM on Unix, 15 vars on Windows).
    Add a note about `allowed_env_keys` for passing additional env vars.
  - All 6 locale files → same section.
- **Source:** `nanobot/config/schema.py` (`ExecToolConfig.allowed_env_keys`), `nanobot/agent/tools/shell.py` (`_build_env`).

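A possible config fragment, using the example variables from the note above — the JSON key spelling (snake_case vs. camelCase) is an assumption and should be confirmed against the schema before it goes into the web docs:

```json
{
  "tools": {
    "exec": {
      "allowed_env_keys": ["GOPATH", "JAVA_HOME"]
    }
  }
}
```
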
### 9. Discord proxy support (PR #2960)

- **What changed:** Discord channel config gains `proxy`, `proxy_username`, and `proxy_password`
  fields. When set, the Discord bot connection is routed through the specified HTTP proxy,
  optionally with BasicAuth. Partial credentials (only username or only password) are logged
  as a warning and ignored.
- **Where to update:**
  - `content.js` → Discord channel section → config reference:
    add the three proxy fields, note that `proxy_username`/`proxy_password` are both required
    for auth, and that partial credentials are ignored with a warning.
  - All 6 locale files → same section.
- **Source:** `nanobot/channels/discord.py` (`DiscordConfig`, `DiscordChannel.start`).

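A hypothetical config fragment with all three fields (key spelling taken from the field names above; the proxy URL is a placeholder):

```json
{
  "channels": {
    "discord": {
      "proxy": "http://proxy.example.com:8080",
      "proxy_username": "user",
      "proxy_password": "pass"
    }
  }
}
```

Per the note above, supplying only one of `proxy_username`/`proxy_password` is treated as no credentials at all, with a warning in the logs.
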
### 10. Feishu streaming enhancements: resuming, inline tool hints, done emoji (PR #2993)

- **What changed:** Three Feishu channel improvements:
  1. `doneEmoji` config field — optional completion emoji (e.g. `"DONE"`) added after `reactEmoji` is removed when the bot finishes processing.
  2. `toolHintPrefix` config field — configurable prefix for inline tool hints (default: `🔧`).
  3. Streaming resuming — mid-turn tool calls flush text to the streaming card without closing it, so the next text segment continues on the same card. Tool hints are inlined into active streaming cards instead of sent as separate messages.
- **Where to update:**
  - `content.js` → Feishu channel section → config reference:
    add `doneEmoji` (optional string, emoji name for completion reaction) and `toolHintPrefix` (string, default `🔧`).
    Note streaming resuming behavior for mid-turn tool calls.
  - All 6 locale files → same section.
  - README → already updated in this PR with config example.
- **Source:** `nanobot/channels/feishu.py` (`FeishuConfig.done_emoji`, `FeishuConfig.tool_hint_prefix`, `send_delta` resuming logic, `send` tool hint inline logic).

### 11. Unified session across channels (PR #2900)

- **What changed:** New `unifiedSession` toggle in `config.json` (`agents.defaults`). When set to
  `true`, all incoming messages — regardless of which channel they arrive on — share a single
  session key (`unified:default`). Switching from Telegram to Discord continues the same
  conversation. Defaults to `false` — zero behavior change for existing users. Existing
  `session_key_override` (e.g. Telegram thread) is respected and not overwritten.
- **Where to update:**
  - `content.js` → `agents.defaults` reference section:
    add `unifiedSession` field, type `boolean`, default `false`, explain single-user multi-device
    use case and that it merges all channel sessions into one.
  - All 6 locale files → same section.
  - README → config example or feature list, mention cross-channel unified session.
- **Source:** `nanobot/config/schema.py` (`unified_session`), `nanobot/agent/loop.py` (`UNIFIED_SESSION_KEY`, `_dispatch`).

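The toggle itself is one line of config; a fragment in the README's style (field name and location per the notes above):

```json
{
  "agents": {
    "defaults": {
      "unifiedSession": true
    }
  }
}
```
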
### 12. Auto compact config rename + recent live suffix retention (PR #3007)

- **What changed:** Auto compact now preserves a recent legal suffix of live session messages while
  summarizing the older unconsolidated prefix, instead of clearing the entire live session. The
  preferred config key is now `idleCompactAfterMinutes`; legacy `sessionTtlMinutes` remains accepted
  as a backward-compatible alias.
- **Where to update:**
  - `content.js` → `agents.defaults` reference section:
    rename the field to `idleCompactAfterMinutes`, note that `sessionTtlMinutes` is a legacy alias,
    and explain that auto compact keeps recent live context instead of replacing the whole session
    with only a summary.
  - All 6 locale files → same section.
  - Any auto-compact behavior notes:
    update wording from "session cleared" to "older context summarized, recent live suffix retained".
- **Source:** `nanobot/config/schema.py` (`AgentDefaults.session_ttl_minutes` aliases),
  `nanobot/agent/auto_compact.py` (`_split_unconsolidated`, `_archive`), `README.md` Auto Compact section.

### 13. Kagi web search provider (PR #2945)

- **What changed:** `tools.web.search.provider` now accepts `kagi`, using `apiKey` / `KAGI_API_KEY`
  to call Kagi's Search API through the built-in `web_search` tool.
- **Where to update:**
  - `content.js` → web tools / search provider section:
    add `kagi` to the provider list, note that it uses the standard `apiKey` field or `KAGI_API_KEY`.
  - All 6 locale files → same section.
  - Any provider comparison tables:
    add Kagi alongside Brave, Tavily, Jina, SearXNG, and DuckDuckGo.
- **Source:** `nanobot/agent/tools/web.py` (`_search_kagi`),
  `nanobot/config/schema.py` (`WebSearchConfig.provider` comment), `README.md` web tools section.

### 14. Mid-turn follow-up injection for active agent runs (PR #3042)

- **What changed:** If a user sends another message while the agent is still working on the same
  session, the follow-up can now be injected into the current agent turn instead of waiting behind
  the per-session lock as a separate later turn. Streaming channels keep the active reply open when
  the turn resumes, so the follow-up answer can continue in the same live response flow.
- **Where to update:**
  - `content.js` → agent loop / streaming behavior section:
    explain that same-session follow-ups during an active turn may be folded into the in-flight
    response instead of always starting a brand-new queued turn.
  - All 6 locale files → same section.
- **Source:** `nanobot/agent/loop.py` (`_pending_queues`, unified-session routing, leftover re-publish),
  `nanobot/agent/runner.py` (injection checkpoints, resumed stream end handling).

### 15. Disable built-in/workspace skills via config (PR #2959)

- **What changed:** New `disabledSkills` field under `agents.defaults`. Users can provide a list of
  skill directory names to exclude from loading, so selected built-in or workspace skills no longer
  appear in the main agent or subagent skill summaries and are not auto-injected as always-on skills.
- **Where to update:**
  - `content.js` → `agents.defaults` reference section:
    add `disabledSkills` as an array of skill names, explain that names match skill directory names,
    and note that disabled skills are hidden from both the main agent and subagents.
  - All 6 locale files → same section.
- **Source:** `nanobot/config/schema.py` (`AgentDefaults.disabled_skills`),
  `nanobot/agent/context.py` (`ContextBuilder`), `nanobot/agent/subagent.py` (`SubagentManager._build_subagent_prompt`),
  `nanobot/agent/skills.py` (`SkillsLoader` filtering).

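A fragment for the web docs (field name and location per the notes above; the skill names are hypothetical placeholders — real values match skill directory names):

```json
{
  "agents": {
    "defaults": {
      "disabledSkills": ["skill-a", "skill-b"]
    }
  }
}
```
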
54  README.md

@@ -560,7 +560,11 @@ Uses **WebSocket** long connection — no public IP required.
      "verificationToken": "",
      "allowFrom": ["ou_YOUR_OPEN_ID"],
      "groupPolicy": "mention",
      "streaming": true
      "reactEmoji": "OnIt",
      "doneEmoji": "DONE",
      "toolHintPrefix": "🔧",
      "streaming": true,
      "domain": "feishu"
    }
  }
}
@@ -570,6 +574,10 @@ Uses **WebSocket** long connection — no public IP required.

> `encryptKey` and `verificationToken` are optional for Long Connection mode.
> `allowFrom`: Add your open_id (find it in nanobot logs when you message the bot). Use `["*"]` to allow all users.
> `groupPolicy`: `"mention"` (default — respond only when @mentioned), `"open"` (respond to all group messages). Private chats always respond.
> `reactEmoji`: Emoji for "processing" status (default: `OnIt`). See [available emojis](https://open.larkoffice.com/document/server-docs/im-v1/message-reaction/emojis-introduce).
> `doneEmoji`: Optional emoji for "completed" status (e.g., `DONE`, `OK`, `HEART`). When set, bot adds this reaction after removing `reactEmoji`.
> `toolHintPrefix`: Prefix for inline tool hints in streaming cards (default: `🔧`).
> `domain`: `"feishu"` (default) for China (open.feishu.cn), `"lark"` for international Lark (open.larksuite.com).

**3. Run**

@@ -1306,6 +1314,7 @@ If you need to allow trusted private ranges such as Tailscale / CGNAT addresses,

| `brave` | `apiKey` | `BRAVE_API_KEY` | No |
| `tavily` | `apiKey` | `TAVILY_API_KEY` | No |
| `jina` | `apiKey` | `JINA_API_KEY` | Free tier (10M tokens) |
| `kagi` | `apiKey` | `KAGI_API_KEY` | No |
| `searxng` | `baseUrl` | `SEARXNG_BASE_URL` | Yes (self-hosted) |
| `duckduckgo` (default) | — | — | Yes |

@@ -1362,6 +1371,20 @@ If you need to allow trusted private ranges such as Tailscale / CGNAT addresses,
}
```

**Kagi:**
```json
{
  "tools": {
    "web": {
      "search": {
        "provider": "kagi",
        "apiKey": "your-kagi-api-key"
      }
    }
  }
}
```

**SearXNG** (self-hosted, no API key needed):
```json
{
@@ -1497,6 +1520,35 @@ MCP tools are automatically discovered and registered on startup. The LLM can us

**Docker security**: The official Docker image runs as a non-root user (`nanobot`, UID 1000) with bubblewrap pre-installed. When using `docker-compose.yml`, the container drops all Linux capabilities except `SYS_ADMIN` (required for bwrap's namespace isolation).

### Auto Compact

When a user is idle for longer than a configured threshold, nanobot **proactively** compresses the older part of the session context into a summary while keeping a recent legal suffix of live messages. This reduces token cost and first-token latency when the user returns — instead of re-processing a long stale context with an expired KV cache, the model receives a compact summary, the most recent live context, and fresh input.

```json
{
  "agents": {
    "defaults": {
      "idleCompactAfterMinutes": 15
    }
  }
}
```

| Option | Default | Description |
|--------|---------|-------------|
| `agents.defaults.idleCompactAfterMinutes` | `0` (disabled) | Minutes of idle time before auto-compaction starts. Set to `0` to disable. Recommended: `15` — close to a typical LLM KV cache expiry window, so stale sessions get compacted before the user returns. |

`sessionTtlMinutes` remains accepted as a legacy alias for backward compatibility, but `idleCompactAfterMinutes` is the preferred config key going forward.

How it works:
1. **Idle detection**: On each idle tick (~1 s), checks all sessions for expiration.
2. **Background compaction**: Idle sessions summarize the older live prefix via LLM and keep the most recent legal suffix (currently 8 messages).
3. **Summary injection**: When the user returns, the summary is injected as runtime context (one-shot, not persisted) alongside the retained recent suffix.
4. **Restart-safe resume**: The summary is also mirrored into session metadata so it can still be recovered after a process restart.

> [!TIP]
> Think of auto compact as "summarize older context, keep the freshest live turns." It is not a hard session reset.

### Timezone

Time is context. Context should be precise.

331  docs/WEBSOCKET.md  Normal file

@@ -0,0 +1,331 @@
# WebSocket Server Channel

Nanobot can act as a WebSocket server, allowing external clients (web apps, CLIs, scripts) to interact with the agent in real time via persistent connections.

## Features

- Bidirectional real-time communication over WebSocket
- Streaming support — receive agent responses token by token
- Token-based authentication (static tokens and short-lived issued tokens)
- Per-connection sessions — each connection gets a unique `chat_id`
- TLS/SSL support (WSS) with enforced TLSv1.2 minimum
- Client allow-list via `allowFrom`
- Auto-cleanup of dead connections

## Quick Start

### 1. Configure

Add to `config.json` under `channels.websocket`:

```json
{
  "channels": {
    "websocket": {
      "enabled": true,
      "host": "127.0.0.1",
      "port": 8765,
      "path": "/",
      "websocketRequiresToken": false,
      "allowFrom": ["*"],
      "streaming": true
    }
  }
}
```

### 2. Start nanobot

```bash
nanobot gateway
```

You should see:

```
WebSocket server listening on ws://127.0.0.1:8765/
```

### 3. Connect a client

Using websocat:

```bash
websocat ws://127.0.0.1:8765/?client_id=alice
```

Using Python:

```python
import asyncio, json, websockets

async def main():
    async with websockets.connect("ws://127.0.0.1:8765/?client_id=alice") as ws:
        ready = json.loads(await ws.recv())
        print(ready)  # {"event": "ready", "chat_id": "...", "client_id": "alice"}
        await ws.send(json.dumps({"content": "Hello nanobot!"}))
        reply = json.loads(await ws.recv())
        print(reply["text"])

asyncio.run(main())
```

## Connection URL

```
ws://{host}:{port}{path}?client_id={id}&token={token}
```

| Parameter | Required | Description |
|-----------|----------|-------------|
| `client_id` | No | Identifier for `allowFrom` authorization. Auto-generated as `anon-xxxxxxxxxxxx` if omitted. Truncated to 128 chars. |
| `token` | Conditional | Authentication token. Required when `websocketRequiresToken` is `true` or `token` (static secret) is configured. |

## Wire Protocol

All frames are JSON text. Each message has an `event` field.

### Server → Client

**`ready`** — sent immediately after connection is established:

```json
{
  "event": "ready",
  "chat_id": "uuid-v4",
  "client_id": "alice"
}
```

**`message`** — full agent response:

```json
{
  "event": "message",
  "text": "Hello! How can I help?",
  "media": ["/tmp/image.png"],
  "reply_to": "msg-id"
}
```

`media` and `reply_to` are only present when applicable.

**`delta`** — streaming text chunk (only when `streaming: true`):

```json
{
  "event": "delta",
  "text": "Hello",
  "stream_id": "s1"
}
```

**`stream_end`** — signals the end of a streaming segment:

```json
{
  "event": "stream_end",
  "stream_id": "s1"
}
```

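A client reassembles a streamed reply by buffering `delta` frames per `stream_id` until the matching `stream_end` arrives. A minimal sketch of that consumer side (function name is illustrative, not part of the protocol):

```python
def assemble(frames: list[dict]) -> list[str]:
    """Collect delta frames into complete texts, keyed by stream_id.

    Returns one string per completed stream, in stream_end order.
    Non-streaming events (ready, message) are ignored here.
    """
    buffers: dict[str, list[str]] = {}
    done: list[str] = []
    for f in frames:
        if f.get("event") == "delta":
            buffers.setdefault(f["stream_id"], []).append(f["text"])
        elif f.get("event") == "stream_end":
            done.append("".join(buffers.pop(f["stream_id"], [])))
    return done
```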
### Client → Server

Send plain text:

```json
"Hello nanobot!"
```

Or send a JSON object with a recognized text field:

```json
{"content": "Hello nanobot!"}
```

Recognized fields: `content`, `text`, `message` (checked in that order). Invalid JSON is treated as plain text.

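The inbound parsing rule above can be sketched as follows — an illustration of the documented behavior, not the actual server code:

```python
import json

def extract_text(raw: str) -> str:
    """Parse an inbound frame: JSON with content/text/message (in that order),
    a bare JSON string, or plain text when JSON parsing fails."""
    try:
        obj = json.loads(raw)
    except (ValueError, TypeError):
        return raw  # invalid JSON -> treat as plain text
    if isinstance(obj, dict):
        for key in ("content", "text", "message"):
            val = obj.get(key)
            if isinstance(val, str):
                return val
        return raw  # no recognized field
    if isinstance(obj, str):
        return obj  # bare JSON string frame like "Hello nanobot!"
    return raw
```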
## Configuration Reference

All fields go under `channels.websocket` in `config.json`.

### Connection

| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `enabled` | bool | `false` | Enable the WebSocket server. |
| `host` | string | `"127.0.0.1"` | Bind address. Use `"0.0.0.0"` to accept external connections. |
| `port` | int | `8765` | Listen port. |
| `path` | string | `"/"` | WebSocket upgrade path. Trailing slashes are normalized (root `/` is preserved). |
| `maxMessageBytes` | int | `1048576` | Maximum inbound message size in bytes (1 KB – 16 MB). |

### Authentication

| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `token` | string | `""` | Static shared secret. When set, clients must provide `?token=<value>` matching this secret (timing-safe comparison). Issued tokens are also accepted as a fallback. |
| `websocketRequiresToken` | bool | `true` | When `true` and no static `token` is configured, clients must still present a valid issued token. Set to `false` to allow unauthenticated connections (only safe for local/trusted networks). |
| `tokenIssuePath` | string | `""` | HTTP path for issuing short-lived tokens. Must differ from `path`. See [Token Issuance](#token-issuance). |
| `tokenIssueSecret` | string | `""` | Secret required to obtain tokens via the issue endpoint. If empty, any client can obtain tokens (logged as a warning). |
| `tokenTtlS` | int | `300` | Time-to-live for issued tokens in seconds (30 – 86,400). |

### Access Control

| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `allowFrom` | list of string | `["*"]` | Allowed `client_id` values. `"*"` allows all; `[]` denies all. |

### Streaming

| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `streaming` | bool | `true` | Enable streaming mode. The agent sends `delta` + `stream_end` frames instead of a single `message`. |

### Keep-alive

| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `pingIntervalS` | float | `20.0` | WebSocket ping interval in seconds (5 – 300). |
| `pingTimeoutS` | float | `20.0` | Time to wait for a pong before closing the connection (5 – 300). |

### TLS/SSL

| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `sslCertfile` | string | `""` | Path to the TLS certificate file (PEM). Both `sslCertfile` and `sslKeyfile` must be set to enable WSS. |
| `sslKeyfile` | string | `""` | Path to the TLS private key file (PEM). Minimum TLS version is enforced as TLSv1.2. |

## Token Issuance

For production deployments where `websocketRequiresToken: true`, use short-lived tokens instead of embedding static secrets in clients.

### How it works

1. Client sends `GET {tokenIssuePath}` with `Authorization: Bearer {tokenIssueSecret}` (or `X-Nanobot-Auth` header).
2. Server responds with a one-time-use token:

   ```json
   {"token": "nbwt_aBcDeFg...", "expires_in": 300}
   ```

3. Client opens WebSocket with `?token=nbwt_aBcDeFg...&client_id=...`.
4. The token is consumed (single use) and cannot be reused.

### Example setup

```json
{
  "channels": {
    "websocket": {
      "enabled": true,
      "port": 8765,
      "path": "/ws",
      "tokenIssuePath": "/auth/token",
      "tokenIssueSecret": "your-secret-here",
      "tokenTtlS": 300,
      "websocketRequiresToken": true,
      "allowFrom": ["*"],
      "streaming": true
    }
  }
}
```

Client flow:

```bash
# 1. Obtain a token
curl -H "Authorization: Bearer your-secret-here" http://127.0.0.1:8765/auth/token

# 2. Connect using the token
websocat "ws://127.0.0.1:8765/ws?client_id=alice&token=nbwt_aBcDeFg..."
```

### Limits

- Issued tokens are single-use — each token can only complete one handshake.
- Outstanding tokens are capped at 10,000. Requests beyond this return HTTP 429.
- Expired tokens are purged lazily on each issue or validation request.

## Security Notes

- **Timing-safe comparison**: Static token validation uses `hmac.compare_digest` to prevent timing attacks.
- **Defense in depth**: `allowFrom` is checked at both the HTTP handshake level and the message level.
- **Token isolation**: Each WebSocket connection gets a unique `chat_id`. Clients cannot access other sessions.
- **TLS enforcement**: When SSL is enabled, TLSv1.2 is the minimum allowed version.
- **Default-secure**: `websocketRequiresToken` defaults to `true`. Explicitly set it to `false` only on trusted networks.

## Media Files

Outbound `message` events may include a `media` field containing local filesystem paths. Remote clients cannot access these files directly — they need either:

- A shared filesystem mount, or
- An HTTP file server serving the nanobot media directory

## Common Patterns

### Trusted local network (no auth)

```json
{
  "channels": {
    "websocket": {
      "enabled": true,
      "host": "0.0.0.0",
      "port": 8765,
      "websocketRequiresToken": false,
      "allowFrom": ["*"],
      "streaming": true
    }
  }
}
```

### Static token (simple auth)

```json
{
  "channels": {
    "websocket": {
      "enabled": true,
      "token": "my-shared-secret",
      "allowFrom": ["alice", "bob"]
    }
  }
}
```

Clients connect with `?token=my-shared-secret&client_id=alice`.

### Public endpoint with issued tokens

```json
{
  "channels": {
    "websocket": {
      "enabled": true,
      "host": "0.0.0.0",
      "port": 8765,
      "path": "/ws",
      "tokenIssuePath": "/auth/token",
      "tokenIssueSecret": "production-secret",
      "websocketRequiresToken": true,
      "sslCertfile": "/etc/ssl/certs/server.pem",
      "sslKeyfile": "/etc/ssl/private/server-key.pem",
      "allowFrom": ["*"]
    }
  }
}
```

### Custom path

```json
{
  "channels": {
    "websocket": {
      "enabled": true,
      "path": "/chat/ws",
      "allowFrom": ["*"]
    }
  }
}
```

Clients connect to `ws://127.0.0.1:8765/chat/ws?client_id=...`. Trailing slashes are normalized, so `/chat/ws/` works the same.

115  nanobot/agent/autocompact.py  Normal file

@@ -0,0 +1,115 @@
"""Auto compact: proactive compression of idle sessions to reduce token cost and latency."""

from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Coroutine

from loguru import logger
from nanobot.session.manager import Session, SessionManager

if TYPE_CHECKING:
    from nanobot.agent.memory import Consolidator


class AutoCompact:
    _RECENT_SUFFIX_MESSAGES = 8

    def __init__(self, sessions: SessionManager, consolidator: Consolidator,
                 session_ttl_minutes: int = 0):
        self.sessions = sessions
        self.consolidator = consolidator
        self._ttl = session_ttl_minutes
        self._archiving: set[str] = set()
        self._summaries: dict[str, tuple[str, datetime]] = {}

    def _is_expired(self, ts: datetime | str | None) -> bool:
        if self._ttl <= 0 or not ts:
            return False
        if isinstance(ts, str):
            ts = datetime.fromisoformat(ts)
        return (datetime.now() - ts).total_seconds() >= self._ttl * 60

    @staticmethod
    def _format_summary(text: str, last_active: datetime) -> str:
        idle_min = int((datetime.now() - last_active).total_seconds() / 60)
        return f"Inactive for {idle_min} minutes.\nPrevious conversation summary: {text}"

    def _split_unconsolidated(
        self, session: Session,
    ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
        """Split live session tail into archiveable prefix and retained recent suffix."""
        tail = list(session.messages[session.last_consolidated:])
        if not tail:
            return [], []

        probe = Session(
            key=session.key,
            messages=tail.copy(),
            created_at=session.created_at,
            updated_at=session.updated_at,
            metadata={},
            last_consolidated=0,
        )
        probe.retain_recent_legal_suffix(self._RECENT_SUFFIX_MESSAGES)
        kept = probe.messages
        cut = len(tail) - len(kept)
        return tail[:cut], kept

    def check_expired(self, schedule_background: Callable[[Coroutine], None]) -> None:
        for info in self.sessions.list_sessions():
            key = info.get("key", "")
            if key and key not in self._archiving and self._is_expired(info.get("updated_at")):
                self._archiving.add(key)
                logger.debug("Auto-compact: scheduling archival for {} (idle > {} min)", key, self._ttl)
                schedule_background(self._archive(key))

    async def _archive(self, key: str) -> None:
        try:
            self.sessions.invalidate(key)
            session = self.sessions.get_or_create(key)
            archive_msgs, kept_msgs = self._split_unconsolidated(session)
            if not archive_msgs and not kept_msgs:
                logger.debug("Auto-compact: skipping {}, no un-consolidated messages", key)
                session.updated_at = datetime.now()
                self.sessions.save(session)
                return

            last_active = session.updated_at
            summary = ""
            if archive_msgs:
                summary = await self.consolidator.archive(archive_msgs) or ""
            if summary and summary != "(nothing)":
                self._summaries[key] = (summary, last_active)
                session.metadata["_last_summary"] = {"text": summary, "last_active": last_active.isoformat()}
            session.messages = kept_msgs
            session.last_consolidated = 0
            session.updated_at = datetime.now()
            self.sessions.save(session)
            logger.info(
                "Auto-compact: archived {} (archived={}, kept={}, summary={})",
                key,
                len(archive_msgs),
                len(kept_msgs),
                bool(summary),
            )
        except Exception:
            logger.exception("Auto-compact: failed for {}", key)
        finally:
            self._archiving.discard(key)

    def prepare_session(self, session: Session, key: str) -> tuple[Session, str | None]:
        if key in self._archiving or self._is_expired(session.updated_at):
            logger.info("Auto-compact: reloading session {} (archiving={})", key, key in self._archiving)
|
||||
session = self.sessions.get_or_create(key)
|
||||
# Hot path: summary from in-memory dict (process hasn't restarted).
|
||||
# Also clean metadata copy so stale _last_summary never leaks to disk.
|
||||
entry = self._summaries.pop(key, None)
|
||||
if entry:
|
||||
session.metadata.pop("_last_summary", None)
|
||||
return session, self._format_summary(entry[0], entry[1])
|
||||
if "_last_summary" in session.metadata:
|
||||
meta = session.metadata.pop("_last_summary")
|
||||
self.sessions.save(session)
|
||||
return session, self._format_summary(meta["text"], datetime.fromisoformat(meta["last_active"]))
|
||||
return session, None
|
||||
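The TTL rule in `_is_expired` above is the crux of `check_expired`: the configured minutes are converted to seconds and compared against the last `updated_at`, and a non-positive TTL disables expiry entirely. A standalone sketch of the same rule (function name hypothetical):

```python
from datetime import datetime, timedelta

def is_idle_expired(updated_at: datetime, ttl_minutes: int, now: datetime) -> bool:
    # ttl_minutes <= 0 disables expiry, mirroring AutoCompact._is_expired.
    if ttl_minutes <= 0:
        return False
    return (now - updated_at).total_seconds() >= ttl_minutes * 60
```

Note the comparison is `>=`, so a session idle for exactly the TTL already counts as expired.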
@@ -20,6 +20,7 @@ class ContextBuilder:
    BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md"]
    _RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]"
    _MAX_RECENT_HISTORY = 50
    _RUNTIME_CONTEXT_END = "[/Runtime Context]"

    def __init__(self, workspace: Path, timezone: str | None = None, disabled_skills: list[str] | None = None):
        self.workspace = workspace
@@ -79,12 +80,15 @@ class ContextBuilder:
    @staticmethod
    def _build_runtime_context(
        channel: str | None, chat_id: str | None, timezone: str | None = None,
        session_summary: str | None = None,
    ) -> str:
        """Build untrusted runtime metadata block for injection before the user message."""
        lines = [f"Current Time: {current_time_str(timezone)}"]
        if channel and chat_id:
            lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
        return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines)
        if session_summary:
            lines += ["", "[Resumed Session]", session_summary]
        return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines) + "\n" + ContextBuilder._RUNTIME_CONTEXT_END

    @staticmethod
    def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
@@ -121,9 +125,10 @@ class ContextBuilder:
        channel: str | None = None,
        chat_id: str | None = None,
        current_role: str = "user",
        session_summary: str | None = None,
    ) -> list[dict[str, Any]]:
        """Build the complete message list for an LLM call."""
        runtime_ctx = self._build_runtime_context(channel, chat_id, self.timezone)
        runtime_ctx = self._build_runtime_context(channel, chat_id, self.timezone, session_summary=session_summary)
        user_content = self._build_user_content(current_message, media)

        # Merge runtime context and user content into a single user message
@@ -13,15 +13,17 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable

from loguru import logger

from nanobot.agent.autocompact import AutoCompact
from nanobot.agent.context import ContextBuilder
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
from nanobot.agent.memory import Consolidator, Dream
from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunSpec, AgentRunner
from nanobot.agent.subagent import SubagentManager
from nanobot.agent.tools.cron import CronTool
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
from nanobot.agent.tools.message import MessageTool
from nanobot.agent.tools.notebook import NotebookEditTool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.search import GlobTool, GrepTool
from nanobot.agent.tools.shell import ExecTool
@@ -33,7 +35,7 @@ from nanobot.bus.queue import MessageBus
from nanobot.config.schema import AgentDefaults
from nanobot.providers.base import LLMProvider
from nanobot.session.manager import Session, SessionManager
from nanobot.utils.helpers import image_placeholder_text, truncate_text
from nanobot.utils.helpers import image_placeholder_text, truncate_text as truncate_text_fn
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE

if TYPE_CHECKING:
@@ -43,6 +45,7 @@ if TYPE_CHECKING:

UNIFIED_SESSION_KEY = "unified:default"


class _LoopHook(AgentHook):
    """Core hook for the main loop."""

@@ -76,7 +79,7 @@ class _LoopHook(AgentHook):
        prev_clean = strip_think(self._stream_buf)
        self._stream_buf += delta
        new_clean = strip_think(self._stream_buf)
        incremental = new_clean[len(prev_clean):]
        incremental = new_clean[len(prev_clean) :]
        if incremental and self._on_stream:
            await self._on_stream(incremental)

@@ -112,6 +115,7 @@ class _LoopHook(AgentHook):
    def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
        return self._loop._strip_think(content)


class AgentLoop:
    """
    The agent loop is the core processing engine.
@@ -145,6 +149,7 @@ class AgentLoop:
        mcp_servers: dict | None = None,
        channels_config: ChannelsConfig | None = None,
        timezone: str | None = None,
        session_ttl_minutes: int = 0,
        hooks: list[AgentHook] | None = None,
        unified_session: bool = False,
        disabled_skills: list[str] | None = None,
@@ -193,16 +198,21 @@ class AgentLoop:
            max_tool_result_chars=self.max_tool_result_chars,
            exec_config=self.exec_config,
            restrict_to_workspace=restrict_to_workspace,
            disabled_skills=disabled_skills,
        )
        self._unified_session = unified_session
        self._running = False
        self._mcp_servers = mcp_servers or {}
        self._mcp_stack: AsyncExitStack | None = None
        self._mcp_stacks: dict[str, AsyncExitStack] = {}
        self._mcp_connected = False
        self._mcp_connecting = False
        self._active_tasks: dict[str, list[asyncio.Task]] = {}  # session_key -> tasks
        self._background_tasks: list[asyncio.Task] = []
        self._session_locks: dict[str, asyncio.Lock] = {}
        # Per-session pending queues for mid-turn message injection.
        # When a session has an active task, new messages for that session
        # are routed here instead of creating a new task.
        self._pending_queues: dict[str, asyncio.Queue] = {}
        # NANOBOT_MAX_CONCURRENT_REQUESTS: <=0 means unlimited; default 3.
        _max = int(os.environ.get("NANOBOT_MAX_CONCURRENT_REQUESTS", "3"))
        self._concurrency_gate: asyncio.Semaphore | None = (
@@ -218,6 +228,11 @@ class AgentLoop:
            get_tool_definitions=self.tools.get_definitions,
            max_completion_tokens=provider.generation.max_tokens,
        )
        self.auto_compact = AutoCompact(
            sessions=self.sessions,
            consolidator=self.consolidator,
            session_ttl_minutes=session_ttl_minutes,
        )
        self.dream = Dream(
            store=self.context.memory,
            provider=provider,
@@ -229,23 +244,35 @@ class AgentLoop:

    def _register_default_tools(self) -> None:
        """Register the default set of tools."""
        allowed_dir = self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
        allowed_dir = (
            self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
        )
        extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
        self.tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read))
        self.tools.register(
            ReadFileTool(
                workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read
            )
        )
        for cls in (WriteFileTool, EditFileTool, ListDirTool):
            self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
        for cls in (GlobTool, GrepTool):
            self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
        self.tools.register(NotebookEditTool(workspace=self.workspace, allowed_dir=allowed_dir))
        if self.exec_config.enable:
            self.tools.register(ExecTool(
                working_dir=str(self.workspace),
                timeout=self.exec_config.timeout,
                restrict_to_workspace=self.restrict_to_workspace,
                sandbox=self.exec_config.sandbox,
                path_append=self.exec_config.path_append,
            ))
            self.tools.register(
                ExecTool(
                    working_dir=str(self.workspace),
                    timeout=self.exec_config.timeout,
                    restrict_to_workspace=self.restrict_to_workspace,
                    sandbox=self.exec_config.sandbox,
                    path_append=self.exec_config.path_append,
                    allowed_env_keys=self.exec_config.allowed_env_keys,
                )
            )
        if self.web_config.enable:
            self.tools.register(WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy))
            self.tools.register(
                WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy)
            )
            self.tools.register(WebFetchTool(proxy=self.web_config.proxy))
        self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
        self.tools.register(SpawnTool(manager=self.subagents))
@@ -260,19 +287,19 @@ class AgentLoop:
            return
        self._mcp_connecting = True
        from nanobot.agent.tools.mcp import connect_mcp_servers

        try:
            self._mcp_stack = AsyncExitStack()
            await self._mcp_stack.__aenter__()
            await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
            self._mcp_connected = True
            self._mcp_stacks = await connect_mcp_servers(self._mcp_servers, self.tools)
            if self._mcp_stacks:
                self._mcp_connected = True
            else:
                logger.warning("No MCP servers connected successfully (will retry next message)")
        except asyncio.CancelledError:
            logger.warning("MCP connection cancelled (will retry next message)")
            self._mcp_stacks.clear()
        except BaseException as e:
            logger.error("Failed to connect MCP servers (will retry next message): {}", e)
            if self._mcp_stack:
                try:
                    await self._mcp_stack.aclose()
                except Exception:
                    pass
            self._mcp_stack = None
            self._mcp_stacks.clear()
        finally:
            self._mcp_connecting = False

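The switch above from one shared `AsyncExitStack` to a dict of per-server stacks means a single failing MCP server no longer tears down connections that already succeeded. A minimal sketch of that isolation pattern (assumed shape; the internals of `connect_mcp_servers` are not shown in this diff):

```python
import asyncio
import contextlib

async def connect_each(factories: dict) -> dict:
    # One AsyncExitStack per server: a failure closes only that server's
    # stack, while already-connected servers keep theirs in the result.
    stacks = {}
    for name, factory in factories.items():
        stack = contextlib.AsyncExitStack()
        try:
            await stack.enter_async_context(factory())
        except Exception:
            await stack.aclose()
            continue
        stacks[name] = stack
    return stacks
```

Shutdown then iterates the surviving stacks and closes each independently, matching the reworked `close_mcp` below.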
@@ -289,6 +316,7 @@ class AgentLoop:
        if not text:
            return None
        from nanobot.utils.helpers import strip_think

        return strip_think(text) or None

    @staticmethod
@@ -298,6 +326,12 @@ class AgentLoop:

        return format_tool_hints(tool_calls)

    def _effective_session_key(self, msg: InboundMessage) -> str:
        """Return the session key used for task routing and mid-turn injections."""
        if self._unified_session and not msg.session_key_override:
            return UNIFIED_SESSION_KEY
        return msg.session_key

    async def _run_agent_loop(
        self,
        initial_messages: list[dict],
@@ -309,13 +343,16 @@ class AgentLoop:
        channel: str = "cli",
        chat_id: str = "direct",
        message_id: str | None = None,
    ) -> tuple[str | None, list[str], list[dict]]:
        pending_queue: asyncio.Queue | None = None,
    ) -> tuple[str | None, list[str], list[dict], str, bool]:
        """Run the agent iteration loop.

        *on_stream*: called with each content delta during streaming.
        *on_stream_end(resuming)*: called when a streaming session finishes.
        ``resuming=True`` means tool calls follow (spinner should restart);
        ``resuming=False`` means this is the final response.

        Returns (final_content, tools_used, messages, stop_reason, had_injections).
        """
        loop_hook = _LoopHook(
            self,
@@ -327,9 +364,7 @@ class AgentLoop:
            message_id=message_id,
        )
        hook: AgentHook = (
            CompositeHook([loop_hook] + self._extra_hooks)
            if self._extra_hooks
            else loop_hook
            CompositeHook([loop_hook] + self._extra_hooks) if self._extra_hooks else loop_hook
        )

        async def _checkpoint(payload: dict[str, Any]) -> None:
@@ -337,6 +372,32 @@ class AgentLoop:
                return
            self._set_runtime_checkpoint(session, payload)

        async def _drain_pending(*, limit: int = _MAX_INJECTIONS_PER_TURN) -> list[dict[str, Any]]:
            """Non-blocking drain of follow-up messages from the pending queue."""
            if pending_queue is None:
                return []
            items: list[dict[str, Any]] = []
            while len(items) < limit:
                try:
                    pending_msg = pending_queue.get_nowait()
                except asyncio.QueueEmpty:
                    break
                user_content = self.context._build_user_content(
                    pending_msg.content,
                    pending_msg.media if pending_msg.media else None,
                )
                runtime_ctx = self.context._build_runtime_context(
                    pending_msg.channel,
                    pending_msg.chat_id,
                    self.context.timezone,
                )
                if isinstance(user_content, str):
                    merged: str | list[dict[str, Any]] = f"{runtime_ctx}\n\n{user_content}"
                else:
                    merged = [{"type": "text", "text": runtime_ctx}] + user_content
                items.append({"role": "user", "content": merged})
            return items

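`_drain_pending` above never awaits: it pulls whatever follow-ups are already queued, capped by the per-turn injection limit, and lets the runner inject them between iterations. The same non-blocking drain pattern in isolation (names hypothetical):

```python
import asyncio

def drain_nowait(q: asyncio.Queue, limit: int) -> list:
    # Pull up to `limit` items without awaiting; stop as soon as the
    # queue is empty rather than blocking the agent turn.
    items = []
    while len(items) < limit:
        try:
            items.append(q.get_nowait())
        except asyncio.QueueEmpty:
            break
    return items
```

Because `get_nowait` raises `QueueEmpty` instead of suspending, the drain is safe to call from inside a running turn without yielding control.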
        result = await self.runner.run(AgentRunSpec(
            initial_messages=initial_messages,
            tools=self.tools,
@@ -353,13 +414,14 @@ class AgentLoop:
            provider_retry_mode=self.provider_retry_mode,
            progress_callback=on_progress,
            checkpoint_callback=_checkpoint,
            injection_callback=_drain_pending,
        ))
        self._last_usage = result.usage
        if result.stop_reason == "max_iterations":
            logger.warning("Max iterations ({}) reached", self.max_iterations)
        elif result.stop_reason == "error":
            logger.error("LLM returned error: {}", (result.final_content or "")[:200])
        return result.final_content, result.tools_used, result.messages
        return result.final_content, result.tools_used, result.messages, result.stop_reason, result.had_injections

    async def run(self) -> None:
        """Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
@@ -371,6 +433,7 @@ class AgentLoop:
            try:
                msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
            except asyncio.TimeoutError:
                self.auto_compact.check_expired(self._schedule_background)
                continue
            except asyncio.CancelledError:
                # Preserve real task cancellation so shutdown can complete cleanly.
@@ -389,84 +452,140 @@ class AgentLoop:
            if result:
                await self.bus.publish_outbound(result)
                continue
            effective_key = self._effective_session_key(msg)
            # If this session already has an active pending queue (i.e. a task
            # is processing this session), route the message there for mid-turn
            # injection instead of creating a competing task.
            if effective_key in self._pending_queues:
                pending_msg = msg
                if effective_key != msg.session_key:
                    pending_msg = dataclasses.replace(
                        msg,
                        session_key_override=effective_key,
                    )
                try:
                    self._pending_queues[effective_key].put_nowait(pending_msg)
                except asyncio.QueueFull:
                    logger.warning(
                        "Pending queue full for session {}, falling back to queued task",
                        effective_key,
                    )
                else:
                    logger.info(
                        "Routed follow-up message to pending queue for session {}",
                        effective_key,
                    )
                    continue
            # Compute the effective session key before dispatching
            # This ensures /stop command can find tasks correctly when unified session is enabled
            effective_key = UNIFIED_SESSION_KEY if self._unified_session and not msg.session_key_override else msg.session_key
            task = asyncio.create_task(self._dispatch(msg))
            self._active_tasks.setdefault(effective_key, []).append(task)
            task.add_done_callback(lambda t, k=effective_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
            task.add_done_callback(
                lambda t, k=effective_key: self._active_tasks.get(k, [])
                and self._active_tasks[k].remove(t)
                if t in self._active_tasks.get(k, [])
                else None
            )

    async def _dispatch(self, msg: InboundMessage) -> None:
        """Process a message: per-session serial, cross-session concurrent."""
        if self._unified_session and not msg.session_key_override:
            msg = dataclasses.replace(msg, session_key_override=UNIFIED_SESSION_KEY)
        lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock())
        session_key = self._effective_session_key(msg)
        if session_key != msg.session_key:
            msg = dataclasses.replace(msg, session_key_override=session_key)
        lock = self._session_locks.setdefault(session_key, asyncio.Lock())
        gate = self._concurrency_gate or nullcontext()
        async with lock, gate:
            try:
                on_stream = on_stream_end = None
                if msg.metadata.get("_wants_stream"):
                    # Split one answer into distinct stream segments.
                    stream_base_id = f"{msg.session_key}:{time.time_ns()}"
                    stream_segment = 0

                    def _current_stream_id() -> str:
                        return f"{stream_base_id}:{stream_segment}"
        # Register a pending queue so follow-up messages for this session are
        # routed here (mid-turn injection) instead of spawning a new task.
        pending = asyncio.Queue(maxsize=20)
        self._pending_queues[session_key] = pending

                    async def on_stream(delta: str) -> None:
                        meta = dict(msg.metadata or {})
                        meta["_stream_delta"] = True
                        meta["_stream_id"] = _current_stream_id()
        try:
            async with lock, gate:
                try:
                    on_stream = on_stream_end = None
                    if msg.metadata.get("_wants_stream"):
                        # Split one answer into distinct stream segments.
                        stream_base_id = f"{msg.session_key}:{time.time_ns()}"
                        stream_segment = 0

                        def _current_stream_id() -> str:
                            return f"{stream_base_id}:{stream_segment}"

                        async def on_stream(delta: str) -> None:
                            meta = dict(msg.metadata or {})
                            meta["_stream_delta"] = True
                            meta["_stream_id"] = _current_stream_id()
                            await self.bus.publish_outbound(OutboundMessage(
                                channel=msg.channel, chat_id=msg.chat_id,
                                content=delta,
                                metadata=meta,
                            ))

                        async def on_stream_end(*, resuming: bool = False) -> None:
                            nonlocal stream_segment
                            meta = dict(msg.metadata or {})
                            meta["_stream_end"] = True
                            meta["_resuming"] = resuming
                            meta["_stream_id"] = _current_stream_id()
                            await self.bus.publish_outbound(OutboundMessage(
                                channel=msg.channel, chat_id=msg.chat_id,
                                content="",
                                metadata=meta,
                            ))
                            stream_segment += 1

                    response = await self._process_message(
                        msg, on_stream=on_stream, on_stream_end=on_stream_end,
                        pending_queue=pending,
                    )
                    if response is not None:
                        await self.bus.publish_outbound(response)
                    elif msg.channel == "cli":
                        await self.bus.publish_outbound(OutboundMessage(
                            channel=msg.channel, chat_id=msg.chat_id,
                        content=delta,
                        metadata=meta,
                            content="", metadata=msg.metadata or {},
                        ))

                    async def on_stream_end(*, resuming: bool = False) -> None:
                        nonlocal stream_segment
                        meta = dict(msg.metadata or {})
                        meta["_stream_end"] = True
                        meta["_resuming"] = resuming
                        meta["_stream_id"] = _current_stream_id()
                        await self.bus.publish_outbound(OutboundMessage(
                            channel=msg.channel, chat_id=msg.chat_id,
                            content="",
                            metadata=meta,
                        ))
                        stream_segment += 1

                response = await self._process_message(
                    msg, on_stream=on_stream, on_stream_end=on_stream_end,
                )
                if response is not None:
                    await self.bus.publish_outbound(response)
                elif msg.channel == "cli":
                except asyncio.CancelledError:
                    logger.info("Task cancelled for session {}", session_key)
                    raise
                except Exception:
                    logger.exception("Error processing message for session {}", session_key)
                    await self.bus.publish_outbound(OutboundMessage(
                        channel=msg.channel, chat_id=msg.chat_id,
                        content="", metadata=msg.metadata or {},
                        content="Sorry, I encountered an error.",
                    ))
            except asyncio.CancelledError:
                logger.info("Task cancelled for session {}", msg.session_key)
                raise
            except Exception:
                logger.exception("Error processing message for session {}", msg.session_key)
                await self.bus.publish_outbound(OutboundMessage(
                    channel=msg.channel, chat_id=msg.chat_id,
                    content="Sorry, I encountered an error.",
                ))
        finally:
            # Drain any messages still in the pending queue and re-publish
            # them to the bus so they are processed as fresh inbound messages
            # rather than silently lost.
            queue = self._pending_queues.pop(session_key, None)
            if queue is not None:
                leftover = 0
                while True:
                    try:
                        item = queue.get_nowait()
                    except asyncio.QueueEmpty:
                        break
                    await self.bus.publish_inbound(item)
                    leftover += 1
                if leftover:
                    logger.info(
                        "Re-published {} leftover message(s) to bus for session {}",
                        leftover, session_key,
                    )

    async def close_mcp(self) -> None:
        """Drain pending background archives, then close MCP connections."""
        if self._background_tasks:
            await asyncio.gather(*self._background_tasks, return_exceptions=True)
            self._background_tasks.clear()
        if self._mcp_stack:
        for name, stack in self._mcp_stacks.items():
            try:
                await self._mcp_stack.aclose()
                await stack.aclose()
            except (RuntimeError, BaseExceptionGroup):
                pass  # MCP SDK cancel scope cleanup is noisy but harmless
            self._mcp_stack = None
                logger.debug("MCP server '{}' cleanup error (can be ignored)", name)
        self._mcp_stacks.clear()

    def _schedule_background(self, coro) -> None:
        """Schedule a coroutine as a tracked background task (drained on shutdown)."""
@@ -486,27 +605,34 @@ class AgentLoop:
        on_progress: Callable[[str], Awaitable[None]] | None = None,
        on_stream: Callable[[str], Awaitable[None]] | None = None,
        on_stream_end: Callable[..., Awaitable[None]] | None = None,
        pending_queue: asyncio.Queue | None = None,
    ) -> OutboundMessage | None:
        """Process a single inbound message and return the response."""
        # System messages: parse origin from chat_id ("channel:chat_id")
        if msg.channel == "system":
            channel, chat_id = (msg.chat_id.split(":", 1) if ":" in msg.chat_id
                                else ("cli", msg.chat_id))
            channel, chat_id = (
                msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id)
            )
            logger.info("Processing system message from {}", msg.sender_id)
            key = f"{channel}:{chat_id}"
            session = self.sessions.get_or_create(key)
            if self._restore_runtime_checkpoint(session):
                self.sessions.save(session)

            session, pending = self.auto_compact.prepare_session(session, key)

            await self.consolidator.maybe_consolidate_by_tokens(session)
            self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
            history = session.get_history(max_messages=0)
            current_role = "assistant" if msg.sender_id == "subagent" else "user"

            messages = self.context.build_messages(
                history=history,
                current_message=msg.content, channel=channel, chat_id=chat_id,
                session_summary=pending,
                current_role=current_role,
            )
            final_content, _, all_msgs = await self._run_agent_loop(
            final_content, _, all_msgs, _, _ = await self._run_agent_loop(
                messages, session=session, channel=channel, chat_id=chat_id,
                message_id=msg.metadata.get("message_id"),
            )
@@ -514,8 +640,11 @@ class AgentLoop:
            self._clear_runtime_checkpoint(session)
            self.sessions.save(session)
            self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
            return OutboundMessage(channel=channel, chat_id=chat_id,
                                   content=final_content or "Background task completed.")
            return OutboundMessage(
                channel=channel,
                chat_id=chat_id,
                content=final_content or "Background task completed.",
            )

        preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
        logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
@@ -525,6 +654,8 @@ class AgentLoop:
        if self._restore_runtime_checkpoint(session):
            self.sessions.save(session)

        session, pending = self.auto_compact.prepare_session(session, key)

        # Slash commands
        raw = msg.content.strip()
        ctx = CommandContext(msg=msg, session=session, key=key, raw=raw, loop=self)
@@ -539,29 +670,39 @@ class AgentLoop:
            message_tool.start_turn()

        history = session.get_history(max_messages=0)

        initial_messages = self.context.build_messages(
            history=history,
            current_message=msg.content,
            session_summary=pending,
            media=msg.media if msg.media else None,
            channel=msg.channel, chat_id=msg.chat_id,
            channel=msg.channel,
            chat_id=msg.chat_id,
        )

        async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
            meta = dict(msg.metadata or {})
            meta["_progress"] = True
            meta["_tool_hint"] = tool_hint
            await self.bus.publish_outbound(OutboundMessage(
                channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
            ))
            await self.bus.publish_outbound(
                OutboundMessage(
                    channel=msg.channel,
                    chat_id=msg.chat_id,
                    content=content,
                    metadata=meta,
                )
            )

        final_content, _, all_msgs = await self._run_agent_loop(
        final_content, _, all_msgs, stop_reason, had_injections = await self._run_agent_loop(
            initial_messages,
            on_progress=on_progress or _bus_progress,
            on_stream=on_stream,
            on_stream_end=on_stream_end,
            session=session,
            channel=msg.channel, chat_id=msg.chat_id,
            channel=msg.channel,
            chat_id=msg.chat_id,
            message_id=msg.metadata.get("message_id"),
            pending_queue=pending_queue,
        )

        if final_content is None or not final_content.strip():
@@ -572,17 +713,26 @@ class AgentLoop:
        self.sessions.save(session)
        self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))

        # When follow-up messages were injected mid-turn, a later natural
        # language reply may address those follow-ups and should not be
        # suppressed just because MessageTool was used earlier in the turn.
        # However, if the turn falls back to the empty-final-response
        # placeholder, suppress it when the real user-visible output already
        # came from MessageTool.
        if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
            return None
            if not had_injections or stop_reason == "empty_final_response":
                return None

        preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
        logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)

        meta = dict(msg.metadata or {})
        if on_stream is not None:
        if on_stream is not None and stop_reason != "error":
            meta["_streamed"] = True
        return OutboundMessage(
            channel=msg.channel, chat_id=msg.chat_id, content=final_content,
            channel=msg.channel,
            chat_id=msg.chat_id,
            content=final_content,
            metadata=meta,
        )

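The reply-suppression change above reads as a small predicate: a reply already sent via MessageTool suppresses the final text only when no follow-ups were injected mid-turn, or when the final text is just the empty-response placeholder. Restated standalone (assumed semantics of the two flags; names hypothetical):

```python
def suppress_final_reply(sent_via_message_tool: bool, had_injections: bool,
                         stop_reason: str) -> bool:
    # Without a tool-sent reply there is nothing to deduplicate.
    if not sent_via_message_tool:
        return False
    # Injected follow-ups deserve the later natural-language answer,
    # unless that answer is only the empty-final-response placeholder.
    return not had_injections or stop_reason == "empty_final_response"
```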
@@ -590,7 +740,7 @@ class AgentLoop:
        self,
        content: list[dict[str, Any]],
        *,
        truncate_text: bool = False,
        should_truncate_text: bool = False,
        drop_runtime: bool = False,
    ) -> list[dict[str, Any]]:
        """Strip volatile multimodal payloads before writing session history."""
@@ -608,18 +758,17 @@ class AgentLoop:
            ):
                continue

            if (
                block.get("type") == "image_url"
                and block.get("image_url", {}).get("url", "").startswith("data:image/")
            ):
            if block.get("type") == "image_url" and block.get("image_url", {}).get(
                "url", ""
            ).startswith("data:image/"):
                path = (block.get("_meta") or {}).get("path", "")
                filtered.append({"type": "text", "text": image_placeholder_text(path)})
                continue

            if block.get("type") == "text" and isinstance(block.get("text"), str):
                text = block["text"]
                if truncate_text and len(text) > self.max_tool_result_chars:
                    text = truncate_text(text, self.max_tool_result_chars)
                if should_truncate_text and len(text) > self.max_tool_result_chars:
                    text = truncate_text_fn(text, self.max_tool_result_chars)
                filtered.append({**block, "text": text})
                continue

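The `truncate_text` → `should_truncate_text` rename above (paired with the `truncate_text as truncate_text_fn` import alias) fixes a parameter shadowing the imported helper: inside the old body, `truncate_text(text, n)` resolved to the bool flag and would raise `TypeError`. A minimal reproduction of the hazard and the fix (names hypothetical):

```python
def truncate_text(text: str, limit: int) -> str:
    # Stand-in for the nanobot.utils.helpers truncation helper.
    return text[:limit] + "..." if len(text) > limit else text

def sanitize(text: str, should_truncate_text: bool = False, limit: int = 8) -> str:
    # Had the flag been named `truncate_text`, the call below would have
    # resolved to the bool parameter instead of the helper above.
    if should_truncate_text and len(text) > limit:
        return truncate_text(text, limit)
    return text
```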
@@ -630,6 +779,7 @@ class AgentLoop:
    def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
        """Save new-turn messages into session, truncating large tool results."""
+       from datetime import datetime

        for m in messages[skip:]:
            entry = dict(m)
            role, content = entry.get("role"), entry.get("content")
@@ -637,20 +787,31 @@ class AgentLoop:
                continue  # skip empty assistant messages — they poison session context
            if role == "tool":
                if isinstance(content, str) and len(content) > self.max_tool_result_chars:
-                   entry["content"] = truncate_text(content, self.max_tool_result_chars)
+                   entry["content"] = truncate_text_fn(content, self.max_tool_result_chars)
                elif isinstance(content, list):
-                   filtered = self._sanitize_persisted_blocks(content, truncate_text=True)
+                   filtered = self._sanitize_persisted_blocks(content, should_truncate_text=True)
                    if not filtered:
                        continue
                    entry["content"] = filtered
            elif role == "user":
                if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
-                   # Strip the runtime-context prefix, keep only the user text.
-                   parts = content.split("\n\n", 1)
-                   if len(parts) > 1 and parts[1].strip():
-                       entry["content"] = parts[1]
-                   else:
-                       continue
+                   # Strip the entire runtime-context block (including any session summary).
+                   # The block is bounded by _RUNTIME_CONTEXT_TAG and _RUNTIME_CONTEXT_END.
+                   end_marker = ContextBuilder._RUNTIME_CONTEXT_END
+                   end_pos = content.find(end_marker)
+                   if end_pos >= 0:
+                       after = content[end_pos + len(end_marker):].lstrip("\n")
+                       if after:
+                           entry["content"] = after
+                       else:
+                           continue
+                   else:
+                       # Fallback: no end marker found, strip the tag prefix
+                       after_tag = content[len(ContextBuilder._RUNTIME_CONTEXT_TAG):].lstrip("\n")
+                       if after_tag.strip():
+                           entry["content"] = after_tag
+                       else:
+                           continue
                if isinstance(content, list):
                    filtered = self._sanitize_persisted_blocks(content, drop_runtime=True)
                    if not filtered:
@@ -708,13 +869,15 @@ class AgentLoop:
                continue
            tool_id = tool_call.get("id")
            name = ((tool_call.get("function") or {}).get("name")) or "tool"
-           restored_messages.append({
-               "role": "tool",
-               "tool_call_id": tool_id,
-               "name": name,
-               "content": "Error: Task interrupted before this tool finished.",
-               "timestamp": datetime.now().isoformat(),
-           })
+           restored_messages.append(
+               {
+                   "role": "tool",
+                   "tool_call_id": tool_id,
+                   "name": name,
+                   "content": "Error: Task interrupted before this tool finished.",
+                   "timestamp": datetime.now().isoformat(),
+               }
+           )

        overlap = 0
        max_overlap = min(len(session.messages), len(restored_messages))
@@ -746,6 +909,9 @@ class AgentLoop:
        await self._connect_mcp()
        msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
        return await self._process_message(
-           msg, session_key=session_key, on_progress=on_progress,
-           on_stream=on_stream, on_stream_end=on_stream_end,
+           msg,
+           session_key=session_key,
+           on_progress=on_progress,
+           on_stream=on_stream,
+           on_stream_end=on_stream_end,
        )

@@ -290,7 +290,7 @@ class MemoryStore:
            if not lines:
                return None
            return json.loads(lines[-1])
-       except (FileNotFoundError, json.JSONDecodeError):
+       except (FileNotFoundError, json.JSONDecodeError, UnicodeDecodeError):
            return None

    def _write_entries(self, entries: list[dict[str, Any]]) -> None:
@@ -347,6 +347,7 @@ class Consolidator:
    """Lightweight consolidation: summarizes evicted messages into history.jsonl."""

    _MAX_CONSOLIDATION_ROUNDS = 5
+   _MAX_CHUNK_MESSAGES = 60  # hard cap per consolidation round

    _SAFETY_BUFFER = 1024  # extra headroom for tokenizer estimation drift

@@ -399,6 +400,22 @@ class Consolidator:

        return last_boundary

+   def _cap_consolidation_boundary(
+       self,
+       session: Session,
+       end_idx: int,
+   ) -> int | None:
+       """Clamp the chunk size without breaking the user-turn boundary."""
+       start = session.last_consolidated
+       if end_idx - start <= self._MAX_CHUNK_MESSAGES:
+           return end_idx
+
+       capped_end = start + self._MAX_CHUNK_MESSAGES
+       for idx in range(capped_end, start, -1):
+           if session.messages[idx].get("role") == "user":
+               return idx
+       return None
+
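The capping rule above is easy to check on a toy message list. This sketch drops the `Session` class and uses a tiny cap so the walk-back to a user turn is visible; the constant value is illustrative only.

```python
# Sketch of _cap_consolidation_boundary: clamp the chunk to MAX_CHUNK
# messages, then walk backwards to the nearest "user" message so the
# consolidated chunk still ends on a turn boundary.
MAX_CHUNK = 4  # illustrative; the diff uses 60

def cap_boundary(messages, start, end_idx):
    if end_idx - start <= MAX_CHUNK:
        return end_idx
    capped_end = start + MAX_CHUNK
    for idx in range(capped_end, start, -1):
        if messages[idx]["role"] == "user":
            return idx
    return None  # no user turn inside the cap → skip this round

roles = ["user", "assistant", "user", "assistant", "assistant", "user", "assistant"]
msgs = [{"role": r} for r in roles]
print(cap_boundary(msgs, 0, 6))  # → 2
```

Returning `None` rather than an arbitrary index is deliberate: consolidating a chunk that ends mid-turn would split an assistant reply from the user message that prompted it.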
    def estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]:
        """Estimate current prompt size for the normal session history view."""
        history = session.get_history(max_messages=0)
@@ -416,13 +433,13 @@ class Consolidator:
            self._get_tool_definitions(),
        )

-   async def archive(self, messages: list[dict]) -> bool:
+   async def archive(self, messages: list[dict]) -> str | None:
        """Summarize messages via LLM and append to history.jsonl.

-       Returns True on success (or degraded success), False if nothing to do.
+       Returns the summary text on success, None if nothing to archive.
        """
        if not messages:
-           return False
+           return None
        try:
            formatted = MemoryStore._format_messages(messages)
            response = await self.provider.chat_with_retry(
@@ -442,11 +459,11 @@ class Consolidator:
            )
            summary = response.content or "[no summary]"
            self.store.append_history(summary)
-           return True
+           return summary
        except Exception:
            logger.warning("Consolidation LLM call failed, raw-dumping to history")
            self.store.raw_archive(messages)
-           return True
+           return None

    async def maybe_consolidate_by_tokens(self, session: Session) -> None:
        """Loop: archive old messages until prompt fits within safe budget.
@@ -461,16 +478,22 @@ class Consolidator:
        async with lock:
            budget = self.context_window_tokens - self.max_completion_tokens - self._SAFETY_BUFFER
            target = budget // 2
-           estimated, source = self.estimate_session_prompt_tokens(session)
+           try:
+               estimated, source = self.estimate_session_prompt_tokens(session)
+           except Exception:
+               logger.exception("Token estimation failed for {}", session.key)
+               estimated, source = 0, "error"
            if estimated <= 0:
                return
            if estimated < budget:
+               unconsolidated_count = len(session.messages) - session.last_consolidated
                logger.debug(
-                   "Token consolidation idle {}: {}/{} via {}",
+                   "Token consolidation idle {}: {}/{} via {}, msgs={}",
                    session.key,
                    estimated,
                    self.context_window_tokens,
                    source,
+                   unconsolidated_count,
                )
                return

@@ -488,6 +511,15 @@ class Consolidator:
            return

        end_idx = boundary[0]
+       end_idx = self._cap_consolidation_boundary(session, end_idx)
+       if end_idx is None:
+           logger.debug(
+               "Token consolidation: no capped boundary for {} (round {})",
+               session.key,
+               round_num,
+           )
+           return
+
        chunk = session.messages[session.last_consolidated:end_idx]
        if not chunk:
            return
@@ -506,7 +538,11 @@ class Consolidator:
            session.last_consolidated = end_idx
            self.sessions.save(session)

-           estimated, source = self.estimate_session_prompt_tokens(session)
+           try:
+               estimated, source = self.estimate_session_prompt_tokens(session)
+           except Exception:
+               logger.exception("Token estimation failed for {}", session.key)
+               estimated, source = 0, "error"
            if estimated <= 0:
                return

@@ -4,6 +4,7 @@ from __future__ import annotations

import asyncio
from dataclasses import dataclass, field
+import inspect
from pathlib import Path
from typing import Any

@@ -31,8 +32,11 @@ from nanobot.utils.runtime import (
)

_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
+_PERSISTED_MODEL_ERROR_PLACEHOLDER = "[Assistant reply unavailable due to model error.]"
_MAX_EMPTY_RETRIES = 2
_MAX_LENGTH_RECOVERIES = 3
+_MAX_INJECTIONS_PER_TURN = 3
+_MAX_INJECTION_CYCLES = 5
_SNIP_SAFETY_BUFFER = 1024
_MICROCOMPACT_KEEP_RECENT = 10
_MICROCOMPACT_MIN_CHARS = 500
@@ -41,6 +45,9 @@ _COMPACTABLE_TOOLS = frozenset({
    "web_search", "web_fetch", "list_dir",
})
+_BACKFILL_CONTENT = "[Tool result unavailable — call was interrupted or lost]"


@dataclass(slots=True)
class AgentRunSpec:
    """Configuration for a single agent execution."""
@@ -65,6 +72,7 @@ class AgentRunSpec:
    provider_retry_mode: str = "standard"
    progress_callback: Any | None = None
    checkpoint_callback: Any | None = None
+   injection_callback: Any | None = None


@dataclass(slots=True)
@@ -78,6 +86,7 @@ class AgentRunResult:
    stop_reason: str = "completed"
    error: str | None = None
    tool_events: list[dict[str, str]] = field(default_factory=list)
+   had_injections: bool = False


class AgentRunner:
@@ -86,6 +95,90 @@ class AgentRunner:
    def __init__(self, provider: LLMProvider):
        self.provider = provider

+   @staticmethod
+   def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
+       if isinstance(left, str) and isinstance(right, str):
+           return f"{left}\n\n{right}" if left else right
+
+       def _to_blocks(value: Any) -> list[dict[str, Any]]:
+           if isinstance(value, list):
+               return [
+                   item if isinstance(item, dict) else {"type": "text", "text": str(item)}
+                   for item in value
+               ]
+           if value is None:
+               return []
+           return [{"type": "text", "text": str(value)}]
+
+       return _to_blocks(left) + _to_blocks(right)
+
+   @classmethod
+   def _append_injected_messages(
+       cls,
+       messages: list[dict[str, Any]],
+       injections: list[dict[str, Any]],
+   ) -> None:
+       """Append injected user messages while preserving role alternation."""
+       for injection in injections:
+           if (
+               messages
+               and injection.get("role") == "user"
+               and messages[-1].get("role") == "user"
+           ):
+               merged = dict(messages[-1])
+               merged["content"] = cls._merge_message_content(
+                   merged.get("content"),
+                   injection.get("content"),
+               )
+               messages[-1] = merged
+               continue
+           messages.append(injection)
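Providers generally reject two consecutive `user` messages, which is why the merge above exists. A string-only sketch of the same idea (the diff additionally handles block-list content via `_merge_message_content`):

```python
# Sketch of role-preserving injection: a second consecutive "user"
# message is folded into the previous one instead of appended.
def append_injected(messages, injections):
    for injection in injections:
        if (messages and injection.get("role") == "user"
                and messages[-1].get("role") == "user"):
            left = messages[-1].get("content") or ""
            right = injection.get("content") or ""
            messages[-1] = {"role": "user",
                            "content": f"{left}\n\n{right}" if left else right}
            continue
        messages.append(injection)

msgs = [{"role": "user", "content": "first"}]
append_injected(msgs, [{"role": "user", "content": "second"}])
print(len(msgs))  # → 1
```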

+   async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]:
+       """Drain pending user messages via the injection callback.
+
+       Returns normalized user messages (capped by
+       ``_MAX_INJECTIONS_PER_TURN``), or an empty list when there is
+       nothing to inject. Messages beyond the cap are logged so they
+       are not silently lost.
+       """
+       if spec.injection_callback is None:
+           return []
+       try:
+           signature = inspect.signature(spec.injection_callback)
+           accepts_limit = (
+               "limit" in signature.parameters
+               or any(
+                   parameter.kind is inspect.Parameter.VAR_KEYWORD
+                   for parameter in signature.parameters.values()
+               )
+           )
+           if accepts_limit:
+               items = await spec.injection_callback(limit=_MAX_INJECTIONS_PER_TURN)
+           else:
+               items = await spec.injection_callback()
+       except Exception:
+           logger.exception("injection_callback failed")
+           return []
+       if not items:
+           return []
+       injected_messages: list[dict[str, Any]] = []
+       for item in items:
+           if isinstance(item, dict) and item.get("role") == "user" and "content" in item:
+               injected_messages.append(item)
+               continue
+           text = getattr(item, "content", str(item))
+           if text.strip():
+               injected_messages.append({"role": "user", "content": text})
+       if len(injected_messages) > _MAX_INJECTIONS_PER_TURN:
+           dropped = len(injected_messages) - _MAX_INJECTIONS_PER_TURN
+           logger.warning(
+               "Injection callback returned {} messages, capping to {} ({} dropped)",
+               len(injected_messages), _MAX_INJECTIONS_PER_TURN, dropped,
+           )
+           injected_messages = injected_messages[:_MAX_INJECTIONS_PER_TURN]
+       return injected_messages

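The signature probe used when draining injections is a common backward-compatibility trick: only pass the `limit` kwarg to callbacks that can accept it. The same probe on plain functions:

```python
# Feature-detect a `limit` kwarg via inspect.signature, as the diff does
# for injection callbacks (older callbacks take no arguments).
import inspect

def accepts_limit(fn) -> bool:
    sig = inspect.signature(fn)
    return ("limit" in sig.parameters
            or any(p.kind is inspect.Parameter.VAR_KEYWORD
                   for p in sig.parameters.values()))

def old_cb():          # legacy callback, no kwargs
    return []

def new_cb(limit=3):   # newer callback, explicit limit
    return []

def kw_cb(**kwargs):   # catch-all also qualifies via VAR_KEYWORD
    return []

print(accepts_limit(old_cb), accepts_limit(new_cb), accepts_limit(kw_cb))
# → False True True
```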
    async def run(self, spec: AgentRunSpec) -> AgentRunResult:
        hook = spec.hook or AgentHook()
        messages = list(spec.initial_messages)
@@ -98,21 +191,35 @@ class AgentRunner:
        external_lookup_counts: dict[str, int] = {}
        empty_content_retries = 0
        length_recovery_count = 0
+       had_injections = False
+       injection_cycles = 0

        for iteration in range(spec.max_iterations):
            try:
-               messages = self._backfill_missing_tool_results(messages)
-               messages = self._microcompact(messages)
-               messages = self._apply_tool_result_budget(spec, messages)
-               messages_for_model = self._snip_history(spec, messages)
+               # Keep the persisted conversation untouched. Context governance
+               # may repair or compact historical messages for the model, but
+               # those synthetic edits must not shift the append boundary used
+               # later when the caller saves only the new turn.
+               messages_for_model = self._drop_orphan_tool_results(messages)
+               messages_for_model = self._backfill_missing_tool_results(messages_for_model)
+               messages_for_model = self._microcompact(messages_for_model)
+               messages_for_model = self._apply_tool_result_budget(spec, messages_for_model)
+               messages_for_model = self._snip_history(spec, messages_for_model)
+               # Snipping may have created new orphans; clean them up.
+               messages_for_model = self._drop_orphan_tool_results(messages_for_model)
+               messages_for_model = self._backfill_missing_tool_results(messages_for_model)
            except Exception as exc:
                logger.warning(
-                   "Context governance failed on turn {} for {}: {}; using raw messages",
+                   "Context governance failed on turn {} for {}: {}; applying minimal repair",
                    iteration,
                    spec.session_key or "default",
                    exc,
                )
-               messages_for_model = messages
+               try:
+                   messages_for_model = self._drop_orphan_tool_results(messages)
+                   messages_for_model = self._backfill_missing_tool_results(messages_for_model)
+               except Exception:
+                   messages_for_model = messages
            context = AgentHookContext(iteration=iteration, messages=messages)
            await hook.before_iteration(context)
            response = await self._request_model(spec, messages_for_model, hook, context)
@@ -156,16 +263,6 @@ class AgentRunner:
                tool_events.extend(new_events)
                context.tool_results = list(results)
                context.tool_events = list(new_events)
-               if fatal_error is not None:
-                   error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
-                   final_content = error
-                   stop_reason = "tool_error"
-                   self._append_final_message(messages, final_content)
-                   context.final_content = final_content
-                   context.error = error
-                   context.stop_reason = stop_reason
-                   await hook.after_iteration(context)
-                   break
                completed_tool_results: list[dict[str, Any]] = []
                for tool_call, result in zip(response.tool_calls, results):
                    tool_message = {
@@ -181,6 +278,16 @@ class AgentRunner:
                    }
                    messages.append(tool_message)
                    completed_tool_results.append(tool_message)
+               if fatal_error is not None:
+                   error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
+                   final_content = error
+                   stop_reason = "tool_error"
+                   self._append_final_message(messages, final_content)
+                   context.final_content = final_content
+                   context.error = error
+                   context.stop_reason = stop_reason
+                   await hook.after_iteration(context)
+                   break
                await self._emit_checkpoint(
                    spec,
                    {
@@ -194,6 +301,17 @@ class AgentRunner:
                )
                empty_content_retries = 0
                length_recovery_count = 0
+               # Checkpoint 1: drain injections after tools, before next LLM call
+               if injection_cycles < _MAX_INJECTION_CYCLES:
+                   injections = await self._drain_injections(spec)
+                   if injections:
+                       had_injections = True
+                       injection_cycles += 1
+                       self._append_injected_messages(messages, injections)
+                       logger.info(
+                           "Injected {} follow-up message(s) after tool execution ({}/{})",
+                           len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
+                       )
                await hook.after_iteration(context)
                continue

@@ -250,14 +368,55 @@ class AgentRunner:
                await hook.after_iteration(context)
                continue

+           assistant_message: dict[str, Any] | None = None
+           if response.finish_reason != "error" and not is_blank_text(clean):
+               assistant_message = build_assistant_message(
+                   clean,
+                   reasoning_content=response.reasoning_content,
+                   thinking_blocks=response.thinking_blocks,
+               )
+
+           # Check for mid-turn injections BEFORE signaling stream end.
+           # If injections are found we keep the stream alive (resuming=True)
+           # so streaming channels don't prematurely finalize the card.
+           _injected_after_final = False
+           if injection_cycles < _MAX_INJECTION_CYCLES:
+               injections = await self._drain_injections(spec)
+               if injections:
+                   had_injections = True
+                   injection_cycles += 1
+                   _injected_after_final = True
+                   if assistant_message is not None:
+                       messages.append(assistant_message)
+                       await self._emit_checkpoint(
+                           spec,
+                           {
+                               "phase": "final_response",
+                               "iteration": iteration,
+                               "model": spec.model,
+                               "assistant_message": assistant_message,
+                               "completed_tool_results": [],
+                               "pending_tool_calls": [],
+                           },
+                       )
+                   self._append_injected_messages(messages, injections)
+                   logger.info(
+                       "Injected {} follow-up message(s) after final response ({}/{})",
+                       len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
+                   )
+
            if hook.wants_streaming():
-               await hook.on_stream_end(context, resuming=False)
+               await hook.on_stream_end(context, resuming=_injected_after_final)
+
+           if _injected_after_final:
+               await hook.after_iteration(context)
+               continue

            if response.finish_reason == "error":
                final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
                stop_reason = "error"
                error = final_content
                self._append_final_message(messages, final_content)
+               self._append_model_error_placeholder(messages)
                context.final_content = final_content
                context.error = error
                context.stop_reason = stop_reason
@@ -274,7 +433,7 @@ class AgentRunner:
                await hook.after_iteration(context)
                break

-           messages.append(build_assistant_message(
+           messages.append(assistant_message or build_assistant_message(
                clean,
                reasoning_content=response.reasoning_content,
                thinking_blocks=response.thinking_blocks,
@@ -317,6 +476,7 @@ class AgentRunner:
            stop_reason=stop_reason,
            error=error,
            tool_events=tool_events,
+           had_injections=had_injections,
        )

    def _build_request_kwargs(
@@ -521,6 +681,12 @@ class AgentRunner:
            return
        messages.append(build_assistant_message(content))

+   @staticmethod
+   def _append_model_error_placeholder(messages: list[dict[str, Any]]) -> None:
+       if messages and messages[-1].get("role") == "assistant" and not messages[-1].get("tool_calls"):
+           return
+       messages.append(build_assistant_message(_PERSISTED_MODEL_ERROR_PLACEHOLDER))
+
    def _normalize_tool_result(
        self,
        spec: AgentRunSpec,
@@ -549,6 +715,32 @@ class AgentRunner:
            return truncate_text(content, spec.max_tool_result_chars)
        return content

+   @staticmethod
+   def _drop_orphan_tool_results(
+       messages: list[dict[str, Any]],
+   ) -> list[dict[str, Any]]:
+       """Drop tool results that have no matching assistant tool_call earlier in the history."""
+       declared: set[str] = set()
+       updated: list[dict[str, Any]] | None = None
+       for idx, msg in enumerate(messages):
+           role = msg.get("role")
+           if role == "assistant":
+               for tc in msg.get("tool_calls") or []:
+                   if isinstance(tc, dict) and tc.get("id"):
+                       declared.add(str(tc["id"]))
+           if role == "tool":
+               tid = msg.get("tool_call_id")
+               if tid and str(tid) not in declared:
+                   if updated is None:
+                       updated = [dict(m) for m in messages[:idx]]
+                   continue
+           if updated is not None:
+               updated.append(dict(msg))
+
+       if updated is None:
+           return messages
+       return updated
+
+   @staticmethod
+   def _backfill_missing_tool_results(
+       messages: list[dict[str, Any]],

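The orphan-dropping pass uses a copy-on-write list: nothing is allocated until the first orphan is found. Reduced to its core:

```python
# A "tool" message whose tool_call_id was never declared by an earlier
# assistant tool_call is removed; the input list is only copied once an
# orphan is actually encountered.
def drop_orphans(messages):
    declared = set()
    updated = None
    for idx, msg in enumerate(messages):
        if msg.get("role") == "assistant":
            for tc in msg.get("tool_calls") or []:
                if isinstance(tc, dict) and tc.get("id"):
                    declared.add(str(tc["id"]))
        if msg.get("role") == "tool":
            tid = msg.get("tool_call_id")
            if tid and str(tid) not in declared:
                if updated is None:
                    updated = [dict(m) for m in messages[:idx]]
                continue  # skip the orphan
        if updated is not None:
            updated.append(dict(msg))
    return messages if updated is None else updated

msgs = [
    {"role": "assistant", "tool_calls": [{"id": "a"}]},
    {"role": "tool", "tool_call_id": "a", "content": "ok"},
    {"role": "tool", "tool_call_id": "ghost", "content": "orphan"},
]
print(len(drop_orphans(msgs)))  # → 2
```

Orphans arise when history is snipped or a run is interrupted between the assistant's tool_call and its result; providers reject such dangling `tool` messages.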
@@ -52,6 +52,7 @@ class SubagentManager:
        web_config: "WebToolsConfig | None" = None,
        exec_config: "ExecToolConfig | None" = None,
        restrict_to_workspace: bool = False,
+       disabled_skills: list[str] | None = None,
    ):
        from nanobot.config.schema import ExecToolConfig

@@ -63,6 +64,7 @@ class SubagentManager:
        self.max_tool_result_chars = max_tool_result_chars
        self.exec_config = exec_config or ExecToolConfig()
        self.restrict_to_workspace = restrict_to_workspace
+       self.disabled_skills = set(disabled_skills or [])
        self.runner = AgentRunner(provider)
        self._running_tasks: dict[str, asyncio.Task[None]] = {}
        self._session_tasks: dict[str, set[str]] = {}  # session_key -> {task_id, ...}
@@ -236,7 +238,10 @@ class SubagentManager:
        from nanobot.agent.skills import SkillsLoader

        time_ctx = ContextBuilder._build_runtime_context(None, None)
-       skills_summary = SkillsLoader(self.workspace).build_skills_summary()
+       skills_summary = SkillsLoader(
+           self.workspace,
+           disabled_skills=self.disabled_skills,
+       ).build_skills_summary()
        return render_template(
            "agent/subagent_system.md",
            time_ctx=time_ctx,

105
nanobot/agent/tools/file_state.py
Normal file
@@ -0,0 +1,105 @@
"""Track file-read state for read-before-edit warnings and read deduplication."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ReadState:
|
||||
mtime: float
|
||||
offset: int
|
||||
limit: int | None
|
||||
content_hash: str | None
|
||||
can_dedup: bool
|
||||
|
||||
|
||||
_state: dict[str, ReadState] = {}
|
||||
|
||||
|
||||
def _hash_file(p: str) -> str | None:
|
||||
try:
|
||||
return hashlib.sha256(Path(p).read_bytes()).hexdigest()
|
||||
except OSError:
|
||||
return None
|
||||
|
||||
|
||||
def record_read(path: str | Path, offset: int = 1, limit: int | None = None) -> None:
|
||||
"""Record that a file was read (called after successful read)."""
|
||||
p = str(Path(path).resolve())
|
||||
try:
|
||||
mtime = os.path.getmtime(p)
|
||||
except OSError:
|
||||
return
|
||||
_state[p] = ReadState(
|
||||
mtime=mtime,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
content_hash=_hash_file(p),
|
||||
can_dedup=True,
|
||||
)
|
||||
|
||||
|
||||
def record_write(path: str | Path) -> None:
|
||||
"""Record that a file was written (updates mtime in state)."""
|
||||
p = str(Path(path).resolve())
|
||||
try:
|
||||
mtime = os.path.getmtime(p)
|
||||
except OSError:
|
||||
_state.pop(p, None)
|
||||
return
|
||||
_state[p] = ReadState(
|
||||
mtime=mtime,
|
||||
offset=1,
|
||||
limit=None,
|
||||
content_hash=_hash_file(p),
|
||||
can_dedup=False,
|
||||
)
|
||||
|
||||
|
||||
def check_read(path: str | Path) -> str | None:
|
||||
"""Check if a file has been read and is fresh.
|
||||
|
||||
Returns None if OK, or a warning string.
|
||||
When mtime changed but file content is identical (e.g. touch, editor save),
|
||||
the check passes to avoid false-positive staleness warnings.
|
||||
"""
|
||||
p = str(Path(path).resolve())
|
||||
entry = _state.get(p)
|
||||
if entry is None:
|
||||
return "Warning: file has not been read yet. Read it first to verify content before editing."
|
||||
try:
|
||||
current_mtime = os.path.getmtime(p)
|
||||
except OSError:
|
||||
return None
|
||||
if current_mtime != entry.mtime:
|
||||
if entry.content_hash and _hash_file(p) == entry.content_hash:
|
||||
entry.mtime = current_mtime
|
||||
return None
|
||||
return "Warning: file has been modified since last read. Re-read to verify content before editing."
|
||||
return None
|
||||
|
||||
|
||||
def is_unchanged(path: str | Path, offset: int = 1, limit: int | None = None) -> bool:
|
||||
"""Return True if file was previously read with same params and mtime is unchanged."""
|
||||
p = str(Path(path).resolve())
|
||||
entry = _state.get(p)
|
||||
if entry is None:
|
||||
return False
|
||||
if not entry.can_dedup:
|
||||
return False
|
||||
if entry.offset != offset or entry.limit != limit:
|
||||
return False
|
||||
try:
|
||||
current_mtime = os.path.getmtime(p)
|
||||
except OSError:
|
||||
return False
|
||||
return current_mtime == entry.mtime
|
||||
|
||||
|
||||
def clear() -> None:
|
||||
"""Clear all tracked state (useful for testing)."""
|
||||
_state.clear()
|
||||
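The reason the module pairs mtime with a content hash is the touch/editor-save false positive: mtime changes but bytes do not. A compressed, self-contained demo of that behavior (not the module itself):

```python
# mtime + sha256 staleness check, as in file_state: a touch that leaves
# the bytes identical refreshes the recorded mtime instead of warning.
import hashlib
import os
import tempfile

state = {}

def record_read(p):
    state[p] = (os.path.getmtime(p),
                hashlib.sha256(open(p, "rb").read()).hexdigest())

def check_read(p):
    if p not in state:
        return "not read yet"
    mtime, digest = state[p]
    if os.path.getmtime(p) != mtime:
        if hashlib.sha256(open(p, "rb").read()).hexdigest() == digest:
            state[p] = (os.path.getmtime(p), digest)  # refresh mtime only
            return None
        return "modified since last read"
    return None

with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("hello")
    path = f.name
record_read(path)
print(check_read(path))       # → None
os.utime(path, (0, 0))        # touch: mtime changes, bytes do not
print(check_read(path))       # → None  (hash matched, mtime refreshed)
os.remove(path)
```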
@@ -2,11 +2,13 @@

import difflib
+import mimetypes
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from nanobot.agent.tools.base import Tool, tool_parameters
from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema
+from nanobot.agent.tools import file_state
from nanobot.utils.helpers import build_image_content_blocks, detect_image_mime
from nanobot.config.paths import get_media_dir

@@ -60,6 +62,36 @@ class _FsTool(Tool):
# ---------------------------------------------------------------------------


+_BLOCKED_DEVICE_PATHS = frozenset({
+    "/dev/zero", "/dev/random", "/dev/urandom", "/dev/full",
+    "/dev/stdin", "/dev/stdout", "/dev/stderr",
+    "/dev/tty", "/dev/console",
+    "/dev/fd/0", "/dev/fd/1", "/dev/fd/2",
+})
+
+
+def _is_blocked_device(path: str | Path) -> bool:
+    """Check if path is a blocked device that could hang or produce infinite output."""
+    import re
+    raw = str(path)
+    if raw in _BLOCKED_DEVICE_PATHS:
+        return True
+    if re.match(r"/proc/\d+/fd/[012]$", raw) or re.match(r"/proc/self/fd/[012]$", raw):
+        return True
+    return False

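The guard above blocks paths like `/dev/zero` (infinite output) and `/dev/stdin` (can hang). As a standalone predicate, mirroring the diff's path set and regexes:

```python
# Device-path guard: exact blacklist plus /proc/<pid>/fd/{0,1,2} aliases
# for the standard streams.
import re

BLOCKED = {"/dev/zero", "/dev/random", "/dev/urandom", "/dev/full",
           "/dev/stdin", "/dev/stdout", "/dev/stderr",
           "/dev/tty", "/dev/console",
           "/dev/fd/0", "/dev/fd/1", "/dev/fd/2"}

def is_blocked_device(raw: str) -> bool:
    if raw in BLOCKED:
        return True
    return bool(re.match(r"/proc/\d+/fd/[012]$", raw)
                or re.match(r"/proc/self/fd/[012]$", raw))

print(is_blocked_device("/dev/zero"),
      is_blocked_device("/proc/42/fd/1"),
      is_blocked_device("/tmp/notes.txt"))
# → True True False
```

Note the diff checks both the raw argument and the resolved path, since symlink resolution (e.g. `/dev/fd/1`) can change which form matches.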
+def _parse_page_range(pages: str, total: int) -> tuple[int, int]:
+    """Parse a page range like '2-5' into 0-based (start, end) inclusive."""
+    parts = pages.strip().split("-")
+    if len(parts) == 1:
+        p = int(parts[0])
+        return max(0, p - 1), min(p - 1, total - 1)
+    start = int(parts[0])
+    end = int(parts[1])
+    return max(0, start - 1), min(end - 1, total - 1)
+
+
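The 1-based-to-0-based conversion in the parser is easy to get wrong, so here it is reproduced standalone with sample inputs:

```python
# Same parser as above: "2-5" over a 10-page document maps to 0-based
# indices (1, 4); a single page "3" maps to (2, 2).
def parse_page_range(pages: str, total: int) -> tuple[int, int]:
    parts = pages.strip().split("-")
    if len(parts) == 1:
        p = int(parts[0])
        return max(0, p - 1), min(p - 1, total - 1)
    start = int(parts[0])
    end = int(parts[1])
    return max(0, start - 1), min(end - 1, total - 1)

print(parse_page_range("2-5", 10))  # → (1, 4)
print(parse_page_range("3", 10))    # → (2, 2)
```

Non-numeric input raises `ValueError`, which the caller in the diff catches alongside `IndexError` and turns into a user-facing error message.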
@tool_parameters(
|
||||
tool_parameters_schema(
|
||||
path=StringSchema("The file path to read"),
|
||||
@ -73,6 +105,7 @@ class _FsTool(Tool):
|
||||
description="Maximum number of lines to read (default 2000)",
|
||||
minimum=1,
|
||||
),
|
||||
pages=StringSchema("Page range for PDF files, e.g. '1-5' (default: all, max 20 pages)"),
|
||||
required=["path"],
|
||||
)
|
||||
)
|
||||
@ -81,6 +114,7 @@ class ReadFileTool(_FsTool):
|
||||
|
||||
_MAX_CHARS = 128_000
|
||||
_DEFAULT_LIMIT = 2000
|
||||
_MAX_PDF_PAGES = 20
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@ -89,9 +123,10 @@ class ReadFileTool(_FsTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Read a text file. Output format: LINE_NUM|CONTENT. "
|
||||
"Read a file (text or image). Text output format: LINE_NUM|CONTENT. "
|
||||
"Images return visual content for analysis. "
|
||||
"Use offset and limit for large files. "
|
||||
"Cannot read binary files or images. "
|
||||
"Cannot read non-image binary files. "
|
||||
"Reads exceeding ~128K chars are truncated."
|
||||
)
|
||||
|
||||
@ -99,16 +134,27 @@ class ReadFileTool(_FsTool):
|
||||
def read_only(self) -> bool:
|
||||
return True
|
||||
|
||||
async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, **kwargs: Any) -> Any:
|
||||
async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, pages: str | None = None, **kwargs: Any) -> Any:
|
||||
try:
|
||||
if not path:
|
||||
return "Error reading file: Unknown path"
|
||||
|
||||
# Device path blacklist
|
||||
if _is_blocked_device(path):
|
||||
return f"Error: Reading {path} is blocked (device path that could hang or produce infinite output)."
|
||||
|
||||
fp = self._resolve(path)
|
||||
if _is_blocked_device(fp):
|
||||
return f"Error: Reading {fp} is blocked (device path that could hang or produce infinite output)."
|
||||
if not fp.exists():
|
||||
return f"Error: File not found: {path}"
|
||||
if not fp.is_file():
|
||||
return f"Error: Not a file: {path}"
|
||||
|
||||
# PDF support
|
||||
if fp.suffix.lower() == ".pdf":
|
||||
return self._read_pdf(fp, pages)
|
||||
|
||||
raw = fp.read_bytes()
|
||||
if not raw:
|
||||
return f"(Empty file: {path})"
|
||||
@ -117,6 +163,10 @@ class ReadFileTool(_FsTool):
|
||||
if mime and mime.startswith("image/"):
|
||||
return build_image_content_blocks(raw, mime, str(fp), f"(Image file: {path})")
|
||||
|
||||
# Read dedup: same path + offset + limit + unchanged mtime → stub
|
||||
if file_state.is_unchanged(fp, offset=offset, limit=limit):
|
||||
return f"[File unchanged since last read: {path}]"
|
||||
|
||||
try:
|
||||
text_content = raw.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
@ -149,12 +199,59 @@ class ReadFileTool(_FsTool):
|
||||
result += f"\n\n(Showing lines {offset}-{end} of {total}. Use offset={end + 1} to continue.)"
|
||||
else:
|
||||
result += f"\n\n(End of file — {total} lines total)"
|
||||
file_state.record_read(fp, offset=offset, limit=limit)
|
||||
return result
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error reading file: {e}"
|
||||
|
||||
def _read_pdf(self, fp: Path, pages: str | None) -> str:
try:
import fitz # pymupdf
except ImportError:
return "Error: PDF reading requires pymupdf. Install with: pip install pymupdf"

try:
doc = fitz.open(str(fp))
except Exception as e:
return f"Error reading PDF: {e}"

total_pages = len(doc)
if pages:
try:
start, end = _parse_page_range(pages, total_pages)
except (ValueError, IndexError):
doc.close()
return f"Error: Invalid page range '{pages}'. Use format like '1-5'."
if start > end or start >= total_pages:
doc.close()
return f"Error: Page range '{pages}' is out of bounds (document has {total_pages} pages)."
else:
start = 0
end = min(total_pages - 1, self._MAX_PDF_PAGES - 1)

if end - start + 1 > self._MAX_PDF_PAGES:
end = start + self._MAX_PDF_PAGES - 1

parts: list[str] = []
for i in range(start, end + 1):
page = doc[i]
text = page.get_text().strip()
if text:
parts.append(f"--- Page {i + 1} ---\n{text}")
doc.close()

if not parts:
return f"(PDF has no extractable text: {fp})"

result = "\n\n".join(parts)
if end < total_pages - 1:
result += f"\n\n(Showing pages {start + 1}-{end + 1} of {total_pages}. Use pages='{end + 2}-{min(end + 1 + self._MAX_PDF_PAGES, total_pages)}' to continue.)"
if len(result) > self._MAX_CHARS:
result = result[:self._MAX_CHARS] + "\n\n(PDF text truncated at ~128K chars)"
return result


# ---------------------------------------------------------------------------
# write_file
@ -192,6 +289,7 @@ class WriteFileTool(_FsTool):
fp = self._resolve(path)
fp.parent.mkdir(parents=True, exist_ok=True)
fp.write_text(content, encoding="utf-8")
file_state.record_write(fp)
return f"Successfully wrote {len(content)} characters to {fp}"
except PermissionError as e:
return f"Error: {e}"
@ -203,30 +301,269 @@ class WriteFileTool(_FsTool):
# edit_file
# ---------------------------------------------------------------------------

_QUOTE_TABLE = str.maketrans({
"\u2018": "'", "\u2019": "'", # curly single → straight
"\u201c": '"', "\u201d": '"', # curly double → straight
"'": "'", '"': '"', # identity (kept for completeness)
})


def _normalize_quotes(s: str) -> str:
return s.translate(_QUOTE_TABLE)


def _curly_double_quotes(text: str) -> str:
parts: list[str] = []
opening = True
for ch in text:
if ch == '"':
parts.append("\u201c" if opening else "\u201d")
opening = not opening
else:
parts.append(ch)
return "".join(parts)


def _curly_single_quotes(text: str) -> str:
parts: list[str] = []
opening = True
for i, ch in enumerate(text):
if ch != "'":
parts.append(ch)
continue
prev_ch = text[i - 1] if i > 0 else ""
next_ch = text[i + 1] if i + 1 < len(text) else ""
if prev_ch.isalnum() and next_ch.isalnum():
parts.append("\u2019")
continue
parts.append("\u2018" if opening else "\u2019")
opening = not opening
return "".join(parts)


def _preserve_quote_style(old_text: str, actual_text: str, new_text: str) -> str:
"""Preserve curly quote style when a quote-normalized fallback matched."""
if _normalize_quotes(old_text.strip()) != _normalize_quotes(actual_text.strip()) or old_text == actual_text:
return new_text

styled = new_text
if any(ch in actual_text for ch in ("\u201c", "\u201d")) and '"' in styled:
styled = _curly_double_quotes(styled)
if any(ch in actual_text for ch in ("\u2018", "\u2019")) and "'" in styled:
styled = _curly_single_quotes(styled)
return styled
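The curly-quote handling in the hunk above can be exercised in isolation. The following is a minimal standalone sketch of the same idea (names here are illustrative re-creations, not the module's API): normalize curly quotes to straight ones for matching, and re-curl straight double quotes by alternating open/close marks.

```python
# Sketch of the quote-normalization idea from the diff above.
# QUOTE_TABLE / normalize_quotes / curly_double_quotes are illustrative names.
QUOTE_TABLE = str.maketrans({
    "\u2018": "'", "\u2019": "'",  # curly single -> straight
    "\u201c": '"', "\u201d": '"',  # curly double -> straight
})

def normalize_quotes(s: str) -> str:
    return s.translate(QUOTE_TABLE)

def curly_double_quotes(text: str) -> str:
    # Alternate opening/closing curly marks for each straight double quote.
    parts, opening = [], True
    for ch in text:
        if ch == '"':
            parts.append("\u201c" if opening else "\u201d")
            opening = not opening
        else:
            parts.append(ch)
    return "".join(parts)

print(normalize_quotes("\u201chi\u201d"))  # "hi"
```

This is why the fallback matcher can find text a model typed with straight quotes inside a file that uses typographic quotes, and still write the replacement back in the file's original style.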


def _leading_ws(line: str) -> str:
return line[: len(line) - len(line.lstrip(" \t"))]


def _reindent_like_match(old_text: str, actual_text: str, new_text: str) -> str:
"""Preserve the outer indentation from the actual matched block."""
old_lines = old_text.split("\n")
actual_lines = actual_text.split("\n")
if len(old_lines) != len(actual_lines):
return new_text

comparable = [
(old_line, actual_line)
for old_line, actual_line in zip(old_lines, actual_lines)
if old_line.strip() and actual_line.strip()
]
if not comparable or any(
_normalize_quotes(old_line.strip()) != _normalize_quotes(actual_line.strip())
for old_line, actual_line in comparable
):
return new_text

old_ws = _leading_ws(comparable[0][0])
actual_ws = _leading_ws(comparable[0][1])
if actual_ws == old_ws:
return new_text

if old_ws:
if not actual_ws.startswith(old_ws):
return new_text
delta = actual_ws[len(old_ws):]
else:
delta = actual_ws

if not delta:
return new_text

return "\n".join((delta + line) if line else line for line in new_text.split("\n"))


@dataclass(slots=True)
class _MatchSpan:
start: int
end: int
text: str
line: int


def _find_exact_matches(content: str, old_text: str) -> list[_MatchSpan]:
matches: list[_MatchSpan] = []
start = 0
while True:
idx = content.find(old_text, start)
if idx == -1:
break
matches.append(
_MatchSpan(
start=idx,
end=idx + len(old_text),
text=content[idx : idx + len(old_text)],
line=content.count("\n", 0, idx) + 1,
)
)
start = idx + max(1, len(old_text))
return matches


def _find_trim_matches(content: str, old_text: str, *, normalize_quotes: bool = False) -> list[_MatchSpan]:
old_lines = old_text.splitlines()
if not old_lines:
return []

content_lines = content.splitlines()
content_lines_keepends = content.splitlines(keepends=True)
if len(content_lines) < len(old_lines):
return []

offsets: list[int] = []
pos = 0
for line in content_lines_keepends:
offsets.append(pos)
pos += len(line)
offsets.append(pos)

if normalize_quotes:
stripped_old = [_normalize_quotes(line.strip()) for line in old_lines]
else:
stripped_old = [line.strip() for line in old_lines]

matches: list[_MatchSpan] = []
window_size = len(stripped_old)
for i in range(len(content_lines) - window_size + 1):
window = content_lines[i : i + window_size]
if normalize_quotes:
comparable = [_normalize_quotes(line.strip()) for line in window]
else:
comparable = [line.strip() for line in window]
if comparable != stripped_old:
continue

start = offsets[i]
end = offsets[i + window_size]
if content_lines_keepends[i + window_size - 1].endswith("\n"):
end -= 1
matches.append(
_MatchSpan(
start=start,
end=end,
text=content[start:end],
line=i + 1,
)
)
return matches


def _find_quote_matches(content: str, old_text: str) -> list[_MatchSpan]:
norm_content = _normalize_quotes(content)
norm_old = _normalize_quotes(old_text)
matches: list[_MatchSpan] = []
start = 0
while True:
idx = norm_content.find(norm_old, start)
if idx == -1:
break
matches.append(
_MatchSpan(
start=idx,
end=idx + len(old_text),
text=content[idx : idx + len(old_text)],
line=content.count("\n", 0, idx) + 1,
)
)
start = idx + max(1, len(norm_old))
return matches


def _find_matches(content: str, old_text: str) -> list[_MatchSpan]:
"""Locate all matches using progressively looser strategies."""
for matcher in (
lambda: _find_exact_matches(content, old_text),
lambda: _find_trim_matches(content, old_text),
lambda: _find_trim_matches(content, old_text, normalize_quotes=True),
lambda: _find_quote_matches(content, old_text),
):
matches = matcher()
if matches:
return matches
return []
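The progressive fallback above (exact substring first, then a whitespace-tolerant line window, then quote normalization) can be sketched independently of the tool. This toy version uses two simplified matchers with hypothetical names and skips the span bookkeeping; it returns whatever the first successful strategy finds.

```python
# Toy sketch of the progressive-fallback matching strategy shown above.
def find_exact(content: str, old: str) -> list[int]:
    # Character offsets of every exact occurrence.
    hits, start = [], 0
    while (idx := content.find(old, start)) != -1:
        hits.append(idx)
        start = idx + max(1, len(old))
    return hits

def find_trimmed(content: str, old: str) -> list[int]:
    # Line indices where a window matches after stripping whitespace per line.
    old_lines = [l.strip() for l in old.splitlines()]
    lines = content.splitlines()
    return [
        i for i in range(len(lines) - len(old_lines) + 1)
        if [l.strip() for l in lines[i:i + len(old_lines)]] == old_lines
    ]

def find_matches(content: str, old: str) -> list[int]:
    # Try the strictest matcher first; fall through to looser ones.
    for matcher in (find_exact, find_trimmed):
        hits = matcher(content, old)
        if hits:
            return hits
    return []

content = "def f():\n        return 1\n"
print(find_matches(content, "def f():\n    return 1"))  # [0]
```

Note that each toy matcher reports positions in its own units (offsets vs. line indices); the real implementation avoids that by returning uniform `_MatchSpan` records from every strategy.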


def _find_match_line_numbers(content: str, old_text: str) -> list[int]:
"""Return 1-based starting line numbers for the current matching strategies."""
return [match.line for match in _find_matches(content, old_text)]


def _collapse_internal_whitespace(text: str) -> str:
return "\n".join(" ".join(line.split()) for line in text.splitlines())


def _diagnose_near_match(old_text: str, actual_text: str) -> list[str]:
"""Return actionable hints describing why text was close but not exact."""
hints: list[str] = []

if old_text.lower() == actual_text.lower() and old_text != actual_text:
hints.append("letter case differs")
if _collapse_internal_whitespace(old_text) == _collapse_internal_whitespace(actual_text) and old_text != actual_text:
hints.append("whitespace differs")
if old_text.rstrip("\n") == actual_text.rstrip("\n") and old_text != actual_text:
hints.append("trailing newline differs")
if _normalize_quotes(old_text) == _normalize_quotes(actual_text) and old_text != actual_text:
hints.append("quote style differs")

return hints


def _best_window(old_text: str, content: str) -> tuple[float, int, list[str], list[str]]:
"""Find the closest line-window match and return ratio/start/snippet/hints."""
lines = content.splitlines(keepends=True)
old_lines = old_text.splitlines(keepends=True)
window = max(1, len(old_lines))

best_ratio, best_start = -1.0, 0
best_window_lines: list[str] = []

for i in range(max(1, len(lines) - window + 1)):
current = lines[i : i + window]
ratio = difflib.SequenceMatcher(None, old_lines, current).ratio()
if ratio > best_ratio:
best_ratio, best_start = ratio, i
best_window_lines = current

actual_text = "".join(best_window_lines).replace("\r\n", "\n").rstrip("\n")
hints = _diagnose_near_match(old_text.replace("\r\n", "\n").rstrip("\n"), actual_text)
return best_ratio, best_start, best_window_lines, hints


def _find_match(content: str, old_text: str) -> tuple[str | None, int]:
"""Locate old_text in content: exact first, then line-trimmed sliding window.
"""Locate old_text in content with a multi-level fallback chain:

1. Exact substring match
2. Line-trimmed sliding window (handles indentation differences)
3. Smart quote normalization (curly ↔ straight quotes)

Both inputs should use LF line endings (caller normalises CRLF).
Returns (matched_fragment, count) or (None, 0).
"""
if old_text in content:
return old_text, content.count(old_text)

old_lines = old_text.splitlines()
if not old_lines:
matches = _find_matches(content, old_text)
if not matches:
return None, 0
stripped_old = [l.strip() for l in old_lines]
content_lines = content.splitlines()

candidates = []
for i in range(len(content_lines) - len(stripped_old) + 1):
window = content_lines[i : i + len(stripped_old)]
if [l.strip() for l in window] == stripped_old:
candidates.append("\n".join(window))

if candidates:
return candidates[0], len(candidates)
return None, 0
return matches[0].text, len(matches)


@tool_parameters(
@ -241,6 +578,9 @@ def _find_match(content: str, old_text: str) -> tuple[str | None, int]:
class EditFileTool(_FsTool):
"""Edit a file by replacing text with fallback matching."""

_MAX_EDIT_FILE_SIZE = 1024 * 1024 * 1024 # 1 GiB
_MARKDOWN_EXTS = frozenset({".md", ".mdx", ".markdown"})

@property
def name(self) -> str:
return "edit_file"
@ -249,11 +589,16 @@ class EditFileTool(_FsTool):
def description(self) -> str:
return (
"Edit a file by replacing old_text with new_text. "
"Tolerates minor whitespace/indentation differences. "
"Tolerates minor whitespace/indentation differences and curly/straight quote mismatches. "
"If old_text matches multiple times, you must provide more context "
"or set replace_all=true. Shows a diff of the closest match on failure."
)

@staticmethod
def _strip_trailing_ws(text: str) -> str:
"""Strip trailing whitespace from each line."""
return "\n".join(line.rstrip() for line in text.split("\n"))

async def execute(
self, path: str | None = None, old_text: str | None = None,
new_text: str | None = None,
@ -267,55 +612,133 @@ class EditFileTool(_FsTool):
if new_text is None:
raise ValueError("Unknown new_text")

# .ipynb detection
if path.endswith(".ipynb"):
return "Error: This is a Jupyter notebook. Use the notebook_edit tool instead of edit_file."

fp = self._resolve(path)

# Create-file semantics: old_text='' + file doesn't exist → create
if not fp.exists():
return f"Error: File not found: {path}"
if old_text == "":
fp.parent.mkdir(parents=True, exist_ok=True)
fp.write_text(new_text, encoding="utf-8")
file_state.record_write(fp)
return f"Successfully created {fp}"
return self._file_not_found_msg(path, fp)

# File size protection
try:
fsize = fp.stat().st_size
except OSError:
fsize = 0
if fsize > self._MAX_EDIT_FILE_SIZE:
return f"Error: File too large to edit ({fsize / (1024**3):.1f} GiB). Maximum is 1 GiB."

# Create-file: old_text='' but file exists and not empty → reject
if old_text == "":
raw = fp.read_bytes()
content = raw.decode("utf-8")
if content.strip():
return f"Error: Cannot create file — {path} already exists and is not empty."
fp.write_text(new_text, encoding="utf-8")
file_state.record_write(fp)
return f"Successfully edited {fp}"

# Read-before-edit check
warning = file_state.check_read(fp)

raw = fp.read_bytes()
uses_crlf = b"\r\n" in raw
content = raw.decode("utf-8").replace("\r\n", "\n")
match, count = _find_match(content, old_text.replace("\r\n", "\n"))
norm_old = old_text.replace("\r\n", "\n")
matches = _find_matches(content, norm_old)

if match is None:
if not matches:
return self._not_found_msg(old_text, content, path)
count = len(matches)
if count > 1 and not replace_all:
line_numbers = [match.line for match in matches]
preview = ", ".join(f"line {n}" for n in line_numbers[:3])
if len(line_numbers) > 3:
preview += ", ..."
location_hint = f" at {preview}" if preview else ""
return (
f"Warning: old_text appears {count} times. "
f"Warning: old_text appears {count} times{location_hint}. "
"Provide more context to make it unique, or set replace_all=true."
)

norm_new = new_text.replace("\r\n", "\n")
new_content = content.replace(match, norm_new) if replace_all else content.replace(match, norm_new, 1)

# Trailing whitespace stripping (skip markdown to preserve double-space line breaks)
if fp.suffix.lower() not in self._MARKDOWN_EXTS:
norm_new = self._strip_trailing_ws(norm_new)

selected = matches if replace_all else matches[:1]
new_content = content
for match in reversed(selected):
replacement = _preserve_quote_style(norm_old, match.text, norm_new)
replacement = _reindent_like_match(norm_old, match.text, replacement)

# Delete-line cleanup: when deleting text (new_text=''), consume trailing
# newline to avoid leaving a blank line
end = match.end
if replacement == "" and not match.text.endswith("\n") and content[end:end + 1] == "\n":
end += 1

new_content = new_content[: match.start] + replacement + new_content[end:]
if uses_crlf:
new_content = new_content.replace("\n", "\r\n")

fp.write_bytes(new_content.encode("utf-8"))
return f"Successfully edited {fp}"
file_state.record_write(fp)
msg = f"Successfully edited {fp}"
if warning:
msg = f"{warning}\n{msg}"
return msg
except PermissionError as e:
return f"Error: {e}"
except Exception as e:
return f"Error editing file: {e}"

def _file_not_found_msg(self, path: str, fp: Path) -> str:
"""Build an error message with 'Did you mean ...?' suggestions."""
parent = fp.parent
suggestions: list[str] = []
if parent.is_dir():
siblings = [f.name for f in parent.iterdir() if f.is_file()]
close = difflib.get_close_matches(fp.name, siblings, n=3, cutoff=0.6)
suggestions = [str(parent / c) for c in close]
parts = [f"Error: File not found: {path}"]
if suggestions:
parts.append("Did you mean: " + ", ".join(suggestions) + "?")
return "\n".join(parts)

@staticmethod
def _not_found_msg(old_text: str, content: str, path: str) -> str:
lines = content.splitlines(keepends=True)
old_lines = old_text.splitlines(keepends=True)
window = len(old_lines)

best_ratio, best_start = 0.0, 0
for i in range(max(1, len(lines) - window + 1)):
ratio = difflib.SequenceMatcher(None, old_lines, lines[i : i + window]).ratio()
if ratio > best_ratio:
best_ratio, best_start = ratio, i

best_ratio, best_start, best_window_lines, hints = _best_window(old_text, content)
if best_ratio > 0.5:
diff = "\n".join(difflib.unified_diff(
old_lines, lines[best_start : best_start + window],
old_text.splitlines(keepends=True),
best_window_lines,
fromfile="old_text (provided)",
tofile=f"{path} (actual, line {best_start + 1})",
lineterm="",
))
return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
hint_text = ""
if hints:
hint_text = "\nPossible cause: " + ", ".join(hints) + "."
return (
f"Error: old_text not found in {path}."
f"{hint_text}\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
)

if hints:
return (
f"Error: old_text not found in {path}. "
f"Possible cause: {', '.join(hints)}. "
"Copy the exact text from read_file and try again."
)
return f"Error: old_text not found in {path}. No similar text found. Verify the file content."



@ -57,9 +57,7 @@ def _normalize_schema_for_openai(schema: Any) -> dict[str, Any]:

if "properties" in normalized and isinstance(normalized["properties"], dict):
normalized["properties"] = {
name: _normalize_schema_for_openai(prop)
if isinstance(prop, dict)
else prop
name: _normalize_schema_for_openai(prop) if isinstance(prop, dict) else prop
for name, prop in normalized["properties"].items()
}

@ -138,9 +136,7 @@ class MCPToolWrapper(Tool):
class MCPResourceWrapper(Tool):
"""Wraps an MCP resource URI as a read-only nanobot Tool."""

def __init__(
self, session, server_name: str, resource_def, resource_timeout: int = 30
):
def __init__(self, session, server_name: str, resource_def, resource_timeout: int = 30):
self._session = session
self._uri = resource_def.uri
self._name = f"mcp_{server_name}_resource_{resource_def.name}"
@ -211,9 +207,7 @@ class MCPResourceWrapper(Tool):
class MCPPromptWrapper(Tool):
"""Wraps an MCP prompt as a read-only nanobot Tool."""

def __init__(
self, session, server_name: str, prompt_def, prompt_timeout: int = 30
):
def __init__(self, session, server_name: str, prompt_def, prompt_timeout: int = 30):
self._session = session
self._prompt_name = prompt_def.name
self._name = f"mcp_{server_name}_prompt_{prompt_def.name}"
@ -266,9 +260,7 @@ class MCPPromptWrapper(Tool):
timeout=self._prompt_timeout,
)
except asyncio.TimeoutError:
logger.warning(
"MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout
)
logger.warning("MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout)
return f"(MCP prompt call timed out after {self._prompt_timeout}s)"
except asyncio.CancelledError:
task = asyncio.current_task()
@ -279,13 +271,17 @@ class MCPPromptWrapper(Tool):
except McpError as exc:
logger.error(
"MCP prompt '{}' failed: code={} message={}",
self._name, exc.error.code, exc.error.message,
self._name,
exc.error.code,
exc.error.message,
)
return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])"
except Exception as exc:
logger.exception(
"MCP prompt '{}' failed: {}: {}",
self._name, type(exc).__name__, exc,
self._name,
type(exc).__name__,
exc,
)
return f"(MCP prompt call failed: {type(exc).__name__})"

@ -307,35 +303,44 @@ class MCPPromptWrapper(Tool):


async def connect_mcp_servers(
mcp_servers: dict, registry: ToolRegistry, stack: AsyncExitStack
) -> None:
"""Connect to configured MCP servers and register their tools, resources, and prompts."""
mcp_servers: dict, registry: ToolRegistry
) -> dict[str, AsyncExitStack]:
"""Connect to configured MCP servers and register their tools, resources, prompts.

Returns a dict mapping server name -> its dedicated AsyncExitStack.
Each server gets its own stack and runs in its own task to prevent
cancel scope conflicts when multiple MCP servers are configured.
"""
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamable_http_client

for name, cfg in mcp_servers.items():
async def connect_single_server(name: str, cfg) -> tuple[str, AsyncExitStack | None]:
server_stack = AsyncExitStack()
await server_stack.__aenter__()

try:
transport_type = cfg.type
if not transport_type:
if cfg.command:
transport_type = "stdio"
elif cfg.url:
# Convention: URLs ending with /sse use SSE transport; others use streamableHttp
transport_type = (
"sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp"
)
else:
logger.warning("MCP server '{}': no command or url configured, skipping", name)
continue
await server_stack.aclose()
return name, None

if transport_type == "stdio":
params = StdioServerParameters(
command=cfg.command, args=cfg.args, env=cfg.env or None
)
read, write = await stack.enter_async_context(stdio_client(params))
read, write = await server_stack.enter_async_context(stdio_client(params))
elif transport_type == "sse":

def httpx_client_factory(
headers: dict[str, str] | None = None,
timeout: httpx.Timeout | None = None,
@ -353,27 +358,26 @@ async def connect_mcp_servers(
auth=auth,
)

read, write = await stack.enter_async_context(
read, write = await server_stack.enter_async_context(
sse_client(cfg.url, httpx_client_factory=httpx_client_factory)
)
elif transport_type == "streamableHttp":
# Always provide an explicit httpx client so MCP HTTP transport does not
# inherit httpx's default 5s timeout and preempt the higher-level tool timeout.
http_client = await stack.enter_async_context(
http_client = await server_stack.enter_async_context(
httpx.AsyncClient(
headers=cfg.headers or None,
follow_redirects=True,
timeout=None,
)
)
read, write, _ = await stack.enter_async_context(
read, write, _ = await server_stack.enter_async_context(
streamable_http_client(cfg.url, http_client=http_client)
)
else:
logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type)
continue
await server_stack.aclose()
return name, None

session = await stack.enter_async_context(ClientSession(read, write))
session = await server_stack.enter_async_context(ClientSession(read, write))
await session.initialize()

tools = await session.list_tools()
@ -418,7 +422,6 @@ async def connect_mcp_servers(
", ".join(available_wrapped_names) or "(none)",
)

# --- Register resources ---
try:
resources_result = await session.list_resources()
for resource in resources_result.resources:
@ -433,7 +436,6 @@ async def connect_mcp_servers(
except Exception as e:
logger.debug("MCP server '{}': resources not supported or failed: {}", name, e)

# --- Register prompts ---
try:
prompts_result = await session.list_prompts()
for prompt in prompts_result.prompts:
@ -442,14 +444,38 @@ async def connect_mcp_servers(
)
registry.register(wrapper)
registered_count += 1
logger.debug(
"MCP: registered prompt '{}' from server '{}'", wrapper.name, name
)
logger.debug("MCP: registered prompt '{}' from server '{}'", wrapper.name, name)
except Exception as e:
logger.debug("MCP server '{}': prompts not supported or failed: {}", name, e)

logger.info(
"MCP server '{}': connected, {} capabilities registered", name, registered_count
)
return name, server_stack

except Exception as e:
logger.error("MCP server '{}': failed to connect: {}", name, e)
try:
await server_stack.aclose()
except Exception:
pass
return name, None

server_stacks: dict[str, AsyncExitStack] = {}

tasks: list[asyncio.Task] = []
for name, cfg in mcp_servers.items():
task = asyncio.create_task(connect_single_server(name, cfg))
tasks.append(task)

results = await asyncio.gather(*tasks, return_exceptions=True)

for i, result in enumerate(results):
name = list(mcp_servers.keys())[i]
if isinstance(result, BaseException):
if not isinstance(result, asyncio.CancelledError):
logger.error("MCP server '{}' connection task failed: {}", name, result)
elif result is not None and result[1] is not None:
server_stacks[result[0]] = result[1]

return server_stacks

161
nanobot/agent/tools/notebook.py
Normal file
@ -0,0 +1,161 @@
"""NotebookEditTool — edit Jupyter .ipynb notebooks."""

from __future__ import annotations

import json
import uuid
from typing import Any

from nanobot.agent.tools.base import tool_parameters
from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema
from nanobot.agent.tools.filesystem import _FsTool


def _new_cell(source: str, cell_type: str = "code", generate_id: bool = False) -> dict:
cell: dict[str, Any] = {
"cell_type": cell_type,
"source": source,
"metadata": {},
}
if cell_type == "code":
cell["outputs"] = []
cell["execution_count"] = None
if generate_id:
cell["id"] = uuid.uuid4().hex[:8]
return cell


def _make_empty_notebook() -> dict:
return {
"nbformat": 4,
"nbformat_minor": 5,
"metadata": {
"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
"language_info": {"name": "python"},
},
"cells": [],
}


@tool_parameters(
tool_parameters_schema(
path=StringSchema("Path to the .ipynb notebook file"),
cell_index=IntegerSchema(0, description="0-based index of the cell to edit", minimum=0),
new_source=StringSchema("New source content for the cell"),
cell_type=StringSchema(
"Cell type: 'code' or 'markdown' (default: code)",
enum=["code", "markdown"],
),
edit_mode=StringSchema(
"Mode: 'replace' (default), 'insert' (after target), or 'delete'",
enum=["replace", "insert", "delete"],
),
required=["path", "cell_index"],
)
)
class NotebookEditTool(_FsTool):
"""Edit Jupyter notebook cells: replace, insert, or delete."""

_VALID_CELL_TYPES = frozenset({"code", "markdown"})
_VALID_EDIT_MODES = frozenset({"replace", "insert", "delete"})

@property
def name(self) -> str:
return "notebook_edit"

@property
def description(self) -> str:
return (
"Edit a Jupyter notebook (.ipynb) cell. "
"Modes: replace (default) replaces cell content, "
"insert adds a new cell after the target index, "
"delete removes the cell at the index. "
"cell_index is 0-based."
)

async def execute(
self,
path: str | None = None,
cell_index: int = 0,
new_source: str = "",
cell_type: str = "code",
edit_mode: str = "replace",
**kwargs: Any,
) -> str:
try:
if not path:
return "Error: path is required"

if not path.endswith(".ipynb"):
return "Error: notebook_edit only works on .ipynb files. Use edit_file for other files."

if edit_mode not in self._VALID_EDIT_MODES:
return (
f"Error: Invalid edit_mode '{edit_mode}'. "
"Use one of: replace, insert, delete."
)

if cell_type not in self._VALID_CELL_TYPES:
return (
f"Error: Invalid cell_type '{cell_type}'. "
"Use one of: code, markdown."
)

fp = self._resolve(path)

# Create new notebook if file doesn't exist and mode is insert
if not fp.exists():
if edit_mode != "insert":
return f"Error: File not found: {path}"
nb = _make_empty_notebook()
cell = _new_cell(new_source, cell_type, generate_id=True)
nb["cells"].append(cell)
fp.parent.mkdir(parents=True, exist_ok=True)
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
return f"Successfully created {fp} with 1 cell"

try:
nb = json.loads(fp.read_text(encoding="utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError) as e:
return f"Error: Failed to parse notebook: {e}"

cells = nb.get("cells", [])
nbformat_minor = nb.get("nbformat_minor", 0)
generate_id = nb.get("nbformat", 0) >= 4 and nbformat_minor >= 5

if edit_mode == "delete":
if cell_index < 0 or cell_index >= len(cells):
return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)"
cells.pop(cell_index)
nb["cells"] = cells
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
return f"Successfully deleted cell {cell_index} from {fp}"

if edit_mode == "insert":
insert_at = min(cell_index + 1, len(cells))
cell = _new_cell(new_source, cell_type, generate_id=generate_id)
cells.insert(insert_at, cell)
nb["cells"] = cells
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
return f"Successfully inserted cell at index {insert_at} in {fp}"

# Default: replace
if cell_index < 0 or cell_index >= len(cells):
return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)"
cells[cell_index]["source"] = new_source
if cell_type and cells[cell_index].get("cell_type") != cell_type:
cells[cell_index]["cell_type"] = cell_type
if cell_type == "code":
cells[cell_index].setdefault("outputs", [])
cells[cell_index].setdefault("execution_count", None)
elif "outputs" in cells[cell_index]:
del cells[cell_index]["outputs"]
cells[cell_index].pop("execution_count", None)
nb["cells"] = cells
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
return f"Successfully edited cell {cell_index} in {fp}"

except PermissionError as e:
return f"Error: {e}"
except Exception as e:
return f"Error editing notebook: {e}"
|
||||
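The tool above edits raw nbformat JSON on disk. As a rough, self-contained sketch of the document shape it manipulates (the project's actual `_make_empty_notebook` / `_new_cell` helpers are not shown in this hunk, so these stand-ins are assumptions), a v4.5-style notebook and cell look like this:

```python
import uuid


def make_empty_notebook() -> dict:
    # Minimal nbformat 4.5 skeleton: a cell list plus version metadata.
    return {"cells": [], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}


def new_cell(source: str, cell_type: str, generate_id: bool) -> dict:
    cell = {"cell_type": cell_type, "metadata": {}, "source": source}
    if cell_type == "code":
        # Code cells additionally carry outputs and an execution counter.
        cell["outputs"] = []
        cell["execution_count"] = None
    if generate_id:
        # nbformat >= 4.5 gives every cell a stable id.
        cell["id"] = uuid.uuid4().hex[:8]
    return cell


nb = make_empty_notebook()
nb["cells"].append(new_cell("print('hi')", "code", generate_id=True))
print(len(nb["cells"]))             # 1
print(nb["cells"][0]["cell_type"])  # code
```

This is why the replace branch above toggles `outputs` / `execution_count` when a cell changes type: those keys only belong on code cells.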
@@ -46,6 +46,7 @@ class ExecTool(Tool):
        restrict_to_workspace: bool = False,
        sandbox: str = "",
        path_append: str = "",
        allowed_env_keys: list[str] | None = None,
    ):
        self.timeout = timeout
        self.working_dir = working_dir
@@ -64,6 +65,7 @@ class ExecTool(Tool):
        self.allow_patterns = allow_patterns or []
        self.restrict_to_workspace = restrict_to_workspace
        self.path_append = path_append
        self.allowed_env_keys = allowed_env_keys or []

    @property
    def name(self) -> str:
@@ -208,7 +210,7 @@ class ExecTool(Tool):
        """
        if _IS_WINDOWS:
            sr = os.environ.get("SYSTEMROOT", r"C:\Windows")
            return {
            env = {
                "SYSTEMROOT": sr,
                "COMSPEC": os.environ.get("COMSPEC", f"{sr}\\system32\\cmd.exe"),
                "USERPROFILE": os.environ.get("USERPROFILE", ""),
@@ -225,12 +227,22 @@ class ExecTool(Tool):
                "ProgramFiles(x86)": os.environ.get("ProgramFiles(x86)", ""),
                "ProgramW6432": os.environ.get("ProgramW6432", ""),
            }
            for key in self.allowed_env_keys:
                val = os.environ.get(key)
                if val is not None:
                    env[key] = val
            return env
        home = os.environ.get("HOME", "/tmp")
        return {
        env = {
            "HOME": home,
            "LANG": os.environ.get("LANG", "C.UTF-8"),
            "TERM": os.environ.get("TERM", "dumb"),
        }
        for key in self.allowed_env_keys:
            val = os.environ.get(key)
            if val is not None:
                env[key] = val
        return env

    def _guard_command(self, command: str, cwd: str) -> str | None:
        """Best-effort safety guard for potentially destructive commands."""
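The `allowed_env_keys` passthrough above follows one pattern on both platforms: build a deliberately small base environment, then copy only explicitly opted-in variables. A standalone sketch of the Unix branch (the function name and demo variable names here are illustrative, not the tool's API):

```python
import os


def build_child_env(allowed_keys: list[str]) -> dict[str, str]:
    # Base: a deliberately small environment, mirroring the Unix branch.
    env = {
        "HOME": os.environ.get("HOME", "/tmp"),
        "LANG": os.environ.get("LANG", "C.UTF-8"),
        "TERM": os.environ.get("TERM", "dumb"),
    }
    # Opt-in passthrough: only explicitly allowed variables are copied,
    # so secrets in the parent environment never leak by default.
    for key in allowed_keys:
        val = os.environ.get(key)
        if val is not None:
            env[key] = val
    return env


os.environ["DEMO_ALLOWED"] = "yes"
os.environ["DEMO_SECRET"] = "hunter2"
env = build_child_env(["DEMO_ALLOWED", "DEMO_MISSING"])
print("DEMO_ALLOWED" in env, "DEMO_SECRET" in env)  # True False
```

Note that a key on the allowlist but absent from the parent environment is simply skipped rather than set to an empty string.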
@@ -114,6 +114,8 @@ class WebSearchTool(Tool):
            return await self._search_jina(query, n)
        elif provider == "brave":
            return await self._search_brave(query, n)
        elif provider == "kagi":
            return await self._search_kagi(query, n)
        else:
            return f"Error: unknown search provider '{provider}'"

@@ -204,6 +206,29 @@ class WebSearchTool(Tool):
            logger.warning("Jina search failed ({}), falling back to DuckDuckGo", e)
            return await self._search_duckduckgo(query, n)

    async def _search_kagi(self, query: str, n: int) -> str:
        api_key = self.config.api_key or os.environ.get("KAGI_API_KEY", "")
        if not api_key:
            logger.warning("KAGI_API_KEY not set, falling back to DuckDuckGo")
            return await self._search_duckduckgo(query, n)
        try:
            async with httpx.AsyncClient(proxy=self.proxy) as client:
                r = await client.get(
                    "https://kagi.com/api/v0/search",
                    params={"q": query, "limit": n},
                    headers={"Authorization": f"Bot {api_key}"},
                    timeout=10.0,
                )
            r.raise_for_status()
            # t=0 items are search results; other values are related searches, etc.
            items = [
                {"title": d.get("title", ""), "url": d.get("url", ""), "content": d.get("snippet", "")}
                for d in r.json().get("data", []) if d.get("t") == 0
            ]
            return _format_results(query, items, n)
        except Exception as e:
            return f"Error: {e}"

    async def _search_duckduckgo(self, query: str, n: int) -> str:
        try:
            # Note: duckduckgo_search is synchronous and does its own requests
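The `t == 0` filter is the subtle part of the Kagi integration above. Running the same list comprehension over a canned payload (field names follow the request in the diff; the sample data itself is made up) shows what survives the filter:

```python
# A canned response shaped like the Kagi payload the diff parses:
# t == 0 entries are web results, other t values are related searches etc.
payload = {
    "data": [
        {"t": 0, "title": "Result A", "url": "https://a.example", "snippet": "first"},
        {"t": 1, "list": ["related query"]},
        {"t": 0, "title": "Result B", "url": "https://b.example", "snippet": "second"},
    ]
}

# Same comprehension as _search_kagi: keep only real results, renaming
# "snippet" to the internal "content" key.
items = [
    {"title": d.get("title", ""), "url": d.get("url", ""), "content": d.get("snippet", "")}
    for d in payload.get("data", [])
    if d.get("t") == 0
]
print(len(items))         # 2
print(items[0]["title"])  # Result A
```

Without the filter, the related-search entry (which has no `title`/`url`/`snippet`) would produce an empty result row.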
@@ -22,6 +22,7 @@ from nanobot.utils.helpers import safe_filename, split_message

DISCORD_AVAILABLE = importlib.util.find_spec("discord") is not None
if TYPE_CHECKING:
    import aiohttp
    import discord
    from discord import app_commands
    from discord.abc import Messageable
@@ -58,6 +59,9 @@ class DiscordConfig(Base):
    working_emoji: str = "🔧"
    working_emoji_delay: float = 2.0
    streaming: bool = True
    proxy: str | None = None
    proxy_username: str | None = None
    proxy_password: str | None = None


if DISCORD_AVAILABLE:
@@ -65,8 +69,15 @@ if DISCORD_AVAILABLE:
    class DiscordBotClient(discord.Client):
        """discord.py client that forwards events to the channel."""

        def __init__(self, channel: DiscordChannel, *, intents: discord.Intents) -> None:
            super().__init__(intents=intents)
        def __init__(
            self,
            channel: DiscordChannel,
            *,
            intents: discord.Intents,
            proxy: str | None = None,
            proxy_auth: aiohttp.BasicAuth | None = None,
        ) -> None:
            super().__init__(intents=intents, proxy=proxy, proxy_auth=proxy_auth)
            self._channel = channel
            self.tree = app_commands.CommandTree(self)
            self._register_app_commands()
@@ -130,6 +141,7 @@ if DISCORD_AVAILABLE:
            )

            for name, description, command_text in commands:

                @self.tree.command(name=name, description=description)
                async def command_handler(
                    interaction: discord.Interaction,
@@ -186,7 +198,9 @@ if DISCORD_AVAILABLE:
                else:
                    failed_media.append(Path(media_path).name)

            for index, chunk in enumerate(self._build_chunks(msg.content or "", failed_media, sent_media)):
            for index, chunk in enumerate(
                self._build_chunks(msg.content or "", failed_media, sent_media)
            ):
                kwargs: dict[str, Any] = {"content": chunk}
                if index == 0 and reference is not None and not sent_media:
                    kwargs["reference"] = reference
@@ -292,7 +306,29 @@ class DiscordChannel(BaseChannel):
        try:
            intents = discord.Intents.none()
            intents.value = self.config.intents
            self._client = DiscordBotClient(self, intents=intents)

            proxy_auth = None
            has_user = bool(self.config.proxy_username)
            has_pass = bool(self.config.proxy_password)
            if has_user and has_pass:
                import aiohttp

                proxy_auth = aiohttp.BasicAuth(
                    login=self.config.proxy_username,
                    password=self.config.proxy_password,
                )
            elif has_user != has_pass:
                logger.warning(
                    "Discord proxy auth incomplete: both proxy_username and "
                    "proxy_password must be set; ignoring partial credentials",
                )

            self._client = DiscordBotClient(
                self,
                intents=intents,
                proxy=self.config.proxy,
                proxy_auth=proxy_auth,
            )
        except Exception as e:
            logger.error("Failed to initialize Discord client: {}", e)
            self._client = None
@@ -335,7 +371,9 @@ class DiscordChannel(BaseChannel):
        await self._stop_typing(msg.chat_id)
        await self._clear_reactions(msg.chat_id)

    async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
    async def send_delta(
        self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None
    ) -> None:
        """Progressive Discord delivery: send once, then edit until the stream ends."""
        client = self._client
        if client is None or not client.is_ready():
@@ -355,7 +393,9 @@ class DiscordChannel(BaseChannel):
            return

        buf = self._stream_bufs.get(chat_id)
        if buf is None or (stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id):
        if buf is None or (
            stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id
        ):
            buf = _StreamBuf(stream_id=stream_id)
            self._stream_bufs[chat_id] = buf
        elif buf.stream_id is None:
@@ -534,7 +574,11 @@ class DiscordChannel(BaseChannel):
    @staticmethod
    def _build_inbound_metadata(message: discord.Message) -> dict[str, str | None]:
        """Build metadata for inbound Discord messages."""
        reply_to = str(message.reference.message_id) if message.reference and message.reference.message_id else None
        reply_to = (
            str(message.reference.message_id)
            if message.reference and message.reference.message_id
            else None
        )
        return {
            "message_id": str(message.id),
            "guild_id": str(message.guild.id) if message.guild else None,
@@ -549,7 +593,9 @@ class DiscordChannel(BaseChannel):
        if self.config.group_policy == "mention":
            bot_user_id = self._bot_user_id
            if bot_user_id is None:
                logger.debug("Discord message in {} ignored (bot identity unavailable)", message.channel.id)
                logger.debug(
                    "Discord message in {} ignored (bot identity unavailable)", message.channel.id
                )
                return False

            if any(str(user.id) == bot_user_id for user in message.mentions):
@@ -591,7 +637,6 @@ class DiscordChannel(BaseChannel):
        except asyncio.CancelledError:
            pass


    async def _clear_reactions(self, chat_id: str) -> None:
        """Remove all pending reactions after bot replies."""
        # Cancel delayed working emoji if it hasn't fired yet
@@ -22,6 +22,8 @@ from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_media_dir
from nanobot.config.schema import Base

from lark_oapi.core.const import FEISHU_DOMAIN, LARK_DOMAIN

FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None

# Message type display mapping
@@ -250,9 +252,12 @@ class FeishuConfig(Base):
    verification_token: str = ""
    allow_from: list[str] = Field(default_factory=list)
    react_emoji: str = "THUMBSUP"
    done_emoji: str | None = None  # Emoji to show when task is completed (e.g., "DONE", "OK")
    tool_hint_prefix: str = "\U0001f527"  # Prefix for inline tool hints (default: 🔧)
    group_policy: Literal["open", "mention"] = "mention"
    reply_to_message: bool = False  # If True, bot replies quote the user's original message
    streaming: bool = True
    domain: Literal["feishu", "lark"] = "feishu"  # Set to "lark" for international Lark


_STREAM_ELEMENT_ID = "streaming_md"
@@ -326,10 +331,12 @@ class FeishuChannel(BaseChannel):
        self._loop = asyncio.get_running_loop()

        # Create Lark client for sending messages
        domain = LARK_DOMAIN if self.config.domain == "lark" else FEISHU_DOMAIN
        self._client = (
            lark.Client.builder()
            .app_id(self.config.app_id)
            .app_secret(self.config.app_secret)
            .domain(domain)
            .log_level(lark.LogLevel.INFO)
            .build()
        )
@@ -357,6 +364,7 @@ class FeishuChannel(BaseChannel):
        self._ws_client = lark.ws.Client(
            self.config.app_id,
            self.config.app_secret,
            domain=domain,
            event_handler=event_handler,
            log_level=lark.LogLevel.INFO,
        )
@@ -1012,14 +1020,29 @@ class FeishuChannel(BaseChannel):

        elif msg_type in ("audio", "file", "media"):
            file_key = content_json.get("file_key")
            if file_key and message_id:
                data, filename = await loop.run_in_executor(
                    None, self._download_file_sync, message_id, file_key, msg_type
                )
                if not filename:
                    filename = file_key[:16]
                if msg_type == "audio" and not filename.endswith(".opus"):
                    filename = f"(unknown).opus"
            if not file_key:
                logger.warning("Feishu {} message missing file_key: {}", msg_type, content_json)
                return None, f"[{msg_type}: missing file_key]"
            if not message_id:
                logger.warning("Feishu {} message missing message_id", msg_type)
                return None, f"[{msg_type}: missing message_id]"

            data, filename = await loop.run_in_executor(
                None, self._download_file_sync, message_id, file_key, msg_type
            )

            if not data:
                logger.warning("Feishu {} download failed: file_key={}", msg_type, file_key)
                return None, f"[{msg_type}: download failed]"

            if not filename:
                filename = file_key[:16]

            # Feishu voice messages are opus in OGG container.
            # Use .ogg extension for better Whisper compatibility.
            if msg_type == "audio":
                if not any(filename.endswith(ext) for ext in (".opus", ".ogg", ".oga")):
                    filename = f"(unknown).ogg"

            if data and filename:
                file_path = media_dir / filename
@@ -1263,7 +1286,15 @@ class FeishuChannel(BaseChannel):
    async def send_delta(
        self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None
    ) -> None:
        """Progressive streaming via CardKit: create card on first delta, stream-update on subsequent."""
        """Progressive streaming via CardKit: create card on first delta, stream-update on subsequent.

        Supported metadata keys:
            _stream_end: Finalize the streaming card.
            _resuming: Mid-turn pause – flush but keep the buffer alive.
            _tool_hint: Delta is a formatted tool hint (for display only).
            message_id: Original message id (used with _stream_end for reaction cleanup).
            reaction_id: Reaction id to remove on stream end.
        """
        if not self._client:
            return
        meta = metadata or {}
@@ -1274,6 +1305,22 @@ class FeishuChannel(BaseChannel):
        if meta.get("_stream_end"):
            if (message_id := meta.get("message_id")) and (reaction_id := meta.get("reaction_id")):
                await self._remove_reaction(message_id, reaction_id)
            # Add completion emoji if configured
            if self.config.done_emoji and message_id:
                await self._add_reaction(message_id, self.config.done_emoji)

            resuming = meta.get("_resuming", False)
            if resuming:
                # Mid-turn pause (e.g. tool call between streaming segments).
                # Flush current text to card but keep the buffer alive so the
                # next segment appends to the same card.
                buf = self._stream_bufs.get(chat_id)
                if buf and buf.card_id and buf.text:
                    buf.sequence += 1
                    await loop.run_in_executor(
                        None, self._stream_update_text_sync, buf.card_id, buf.text, buf.sequence,
                    )
                return

            buf = self._stream_bufs.pop(chat_id, None)
            if not buf or not buf.text:
@@ -1346,13 +1393,26 @@ class FeishuChannel(BaseChannel):
        receive_id_type = "chat_id" if msg.chat_id.startswith("oc_") else "open_id"
        loop = asyncio.get_running_loop()

        # Handle tool hint messages as code blocks in interactive cards.
        # These are progress-only messages and should bypass normal reply routing.
        # Handle tool hint messages. When a streaming card is active for
        # this chat, inline the hint into the card instead of sending a
        # separate message so the user experience stays cohesive.
        if msg.metadata.get("_tool_hint"):
            if msg.content and msg.content.strip():
                await self._send_tool_hint_card(
                    receive_id_type, msg.chat_id, msg.content.strip()
                )
            hint = (msg.content or "").strip()
            if not hint:
                return
            buf = self._stream_bufs.get(msg.chat_id)
            if buf and buf.card_id:
                # Delegate to send_delta so tool hints get the same
                # throttling (and card creation) as regular text deltas.
                lines = self.__class__._format_tool_hint_lines(hint).split("\n")
                delta = "\n\n" + "\n".join(
                    f"{self.config.tool_hint_prefix} {ln}" for ln in lines if ln.strip()
                ) + "\n\n"
                await self.send_delta(msg.chat_id, delta)
                return
            await self._send_tool_hint_card(
                receive_id_type, msg.chat_id, hint
            )
            return

        # Determine whether the first message should quote the user's message.
@@ -1661,7 +1721,7 @@ class FeishuChannel(BaseChannel):
        loop = asyncio.get_running_loop()

        # Put each top-level tool call on its own line without altering commas inside arguments.
        formatted_code = self._format_tool_hint_lines(tool_hint)
        formatted_code = self.__class__._format_tool_hint_lines(tool_hint)

        card = {
            "config": {"wide_screen_mode": True},
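The audio-extension fallback in the Feishu hunk above is easy to exercise in isolation. `normalize_audio_filename` and `fallback_stem` are illustrative names (the actual fallback stem is redacted in this diff); the extension tuple matches the code:

```python
def normalize_audio_filename(filename: str, fallback_stem: str) -> str:
    # Feishu voice messages are opus audio in an OGG container; when the
    # reported name already has a recognizable audio extension, keep it,
    # otherwise fall back to a .ogg name so downstream transcribers
    # (e.g. Whisper) accept the file.
    if any(filename.endswith(ext) for ext in (".opus", ".ogg", ".oga")):
        return filename
    return f"{fallback_stem}.ogg"


print(normalize_audio_filename("voice.opus", "abc123"))  # voice.opus
print(normalize_audio_filename("clip", "abc123"))        # abc123.ogg
```

Note the widened check: the earlier revision only accepted `.opus`, which would have renamed already-valid `.ogg` files.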
@@ -242,43 +242,46 @@ class QQChannel(BaseChannel):

    async def send(self, msg: OutboundMessage) -> None:
        """Send attachments first, then text."""
        if not self._client:
            logger.warning("QQ client not initialized")
            return
        try:
            if not self._client:
                logger.warning("QQ client not initialized")
                return

        msg_id = msg.metadata.get("message_id")
        chat_type = self._chat_type_cache.get(msg.chat_id, "c2c")
        is_group = chat_type == "group"
            msg_id = msg.metadata.get("message_id")
            chat_type = self._chat_type_cache.get(msg.chat_id, "c2c")
            is_group = chat_type == "group"

        # 1) Send media
        for media_ref in msg.media or []:
            ok = await self._send_media(
                chat_id=msg.chat_id,
                media_ref=media_ref,
                msg_id=msg_id,
                is_group=is_group,
            )
            if not ok:
                filename = (
                    os.path.basename(urlparse(media_ref).path)
                    or os.path.basename(media_ref)
                    or "file"
            # 1) Send media
            for media_ref in msg.media or []:
                ok = await self._send_media(
                    chat_id=msg.chat_id,
                    media_ref=media_ref,
                    msg_id=msg_id,
                    is_group=is_group,
                )
                if not ok:
                    filename = (
                        os.path.basename(urlparse(media_ref).path)
                        or os.path.basename(media_ref)
                        or "file"
                    )
                    await self._send_text_only(
                        chat_id=msg.chat_id,
                        is_group=is_group,
                        msg_id=msg_id,
                        content=f"[Attachment send failed: (unknown)]",
                    )

        # 2) Send text
        if msg.content and msg.content.strip():
            await self._send_text_only(
                chat_id=msg.chat_id,
                is_group=is_group,
                msg_id=msg_id,
                content=f"[Attachment send failed: (unknown)]",
                content=msg.content.strip(),
            )

            # 2) Send text
            if msg.content and msg.content.strip():
                await self._send_text_only(
                    chat_id=msg.chat_id,
                    is_group=is_group,
                    msg_id=msg_id,
                    content=msg.content.strip(),
                )
        except Exception:
            logger.exception("Error sending QQ message to chat_id={}", msg.chat_id)

    async def _send_text_only(
        self,
@@ -438,15 +441,26 @@ class QQChannel(BaseChannel):
            endpoint = "/v2/users/{openid}/files"
            id_key = "openid"

        payload = {
        payload: dict[str, Any] = {
            id_key: chat_id,
            "file_type": file_type,
            "file_data": file_data,
            "file_name": file_name,
            "srv_send_msg": srv_send_msg,
        }
        # Only pass file_name for non-image types (file_type=4).
        # Passing file_name for images causes QQ client to render them as
        # file attachments instead of inline images.
        if file_type != QQ_FILE_TYPE_IMAGE and file_name:
            payload["file_name"] = file_name

        route = Route("POST", endpoint, **{id_key: chat_id})
        return await self._client.api._http.request(route, json=payload)
        result = await self._client.api._http.request(route, json=payload)

        # Extract only the file_info field to avoid extra fields (file_uuid, ttl, etc.)
        # that may confuse QQ client when sending the media object.
        if isinstance(result, dict) and "file_info" in result:
            return {"file_info": result["file_info"]}
        return result
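Trimming the upload response down to `file_info` is a pure dict operation, so it can be shown standalone. The helper name and the sample payload here are made up; the extra fields (`file_uuid`, `ttl`) are the ones the diff's comment mentions:

```python
def trim_upload_result(result: object) -> object:
    # Keep only file_info from the upload response: extra fields such as
    # file_uuid or ttl can confuse the QQ client when the media object is
    # echoed back in a later send call.
    if isinstance(result, dict) and "file_info" in result:
        return {"file_info": result["file_info"]}
    return result


raw = {"file_info": "CAESBXRva2Vu", "file_uuid": "u-1", "ttl": 600}
print(trim_upload_result(raw))     # {'file_info': 'CAESBXRva2Vu'}
print(trim_upload_result("oops"))  # oops
```

Non-dict results (error strings, None) pass through untouched, which keeps the caller's error handling unchanged.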
    # ---------------------------
    # Inbound (receive)
@@ -454,58 +468,68 @@ class QQChannel(BaseChannel):

    async def _on_message(self, data: C2CMessage | GroupMessage, is_group: bool = False) -> None:
        """Parse inbound message, download attachments, and publish to the bus."""
        if data.id in self._processed_ids:
            return
        self._processed_ids.append(data.id)
        try:
            if data.id in self._processed_ids:
                return
            self._processed_ids.append(data.id)

        if is_group:
            chat_id = data.group_openid
            user_id = data.author.member_openid
            self._chat_type_cache[chat_id] = "group"
        else:
            chat_id = str(
                getattr(data.author, "id", None) or getattr(data.author, "user_openid", "unknown")
            )
            user_id = chat_id
            self._chat_type_cache[chat_id] = "c2c"

        content = (data.content or "").strip()

        # the data used by tests don't contain attachments property
        # so we use getattr with a default of [] to avoid AttributeError in tests
        attachments = getattr(data, "attachments", None) or []
        media_paths, recv_lines, att_meta = await self._handle_attachments(attachments)

        # Compose content that always contains actionable saved paths
        if recv_lines:
            tag = "[Image]" if any(_is_image_name(Path(p).name) for p in media_paths) else "[File]"
            file_block = "Received files:\n" + "\n".join(recv_lines)
            content = f"{content}\n\n{file_block}".strip() if content else f"{tag}\n{file_block}"

        if not content and not media_paths:
            return

        if self.config.ack_message:
            try:
                await self._send_text_only(
                    chat_id=chat_id,
                    is_group=is_group,
                    msg_id=data.id,
                    content=self.config.ack_message,
            if is_group:
                chat_id = data.group_openid
                user_id = data.author.member_openid
                self._chat_type_cache[chat_id] = "group"
            else:
                chat_id = str(
                    getattr(data.author, "id", None)
                    or getattr(data.author, "user_openid", "unknown")
                )
            except Exception:
                logger.debug("QQ ack message failed for chat_id={}", chat_id)
                user_id = chat_id
                self._chat_type_cache[chat_id] = "c2c"

        await self._handle_message(
            sender_id=user_id,
            chat_id=chat_id,
            content=content,
            media=media_paths if media_paths else None,
            metadata={
                "message_id": data.id,
                "attachments": att_meta,
            },
        )
            content = (data.content or "").strip()

            # the data used by tests don't contain attachments property
            # so we use getattr with a default of [] to avoid AttributeError in tests
            attachments = getattr(data, "attachments", None) or []
            media_paths, recv_lines, att_meta = await self._handle_attachments(attachments)

            # Compose content that always contains actionable saved paths
            if recv_lines:
                tag = (
                    "[Image]"
                    if any(_is_image_name(Path(p).name) for p in media_paths)
                    else "[File]"
                )
                file_block = "Received files:\n" + "\n".join(recv_lines)
                content = (
                    f"{content}\n\n{file_block}".strip() if content else f"{tag}\n{file_block}"
                )

            if not content and not media_paths:
                return

            if self.config.ack_message:
                try:
                    await self._send_text_only(
                        chat_id=chat_id,
                        is_group=is_group,
                        msg_id=data.id,
                        content=self.config.ack_message,
                    )
                except Exception:
                    logger.debug("QQ ack message failed for chat_id={}", chat_id)

            await self._handle_message(
                sender_id=user_id,
                chat_id=chat_id,
                content=content,
                media=media_paths if media_paths else None,
                metadata={
                    "message_id": data.id,
                    "attachments": att_meta,
                },
            )
        except Exception:
            logger.exception("Error handling QQ inbound message id={}", getattr(data, "id", "?"))

    async def _handle_attachments(
        self,
@@ -520,7 +544,9 @@ class QQChannel(BaseChannel):
            return media_paths, recv_lines, att_meta

        for att in attachments:
            url, filename, ctype = att.url, att.filename, att.content_type
            url = getattr(att, "url", None) or ""
            filename = getattr(att, "filename", None) or ""
            ctype = getattr(att, "content_type", None) or ""

            logger.info("Downloading file from QQ: {}", filename or url)
            local_path = await self._download_to_media_dir_chunked(url, filename_hint=filename)
@@ -555,6 +581,10 @@ class QQChannel(BaseChannel):
        Enforces a max download size and writes to a .part temp file
        that is atomically renamed on success.
        """
        # Handle protocol-relative URLs (e.g. "//multimedia.nt.qq.com/...")
        if url.startswith("//"):
            url = f"https:{url}"

        if not self._http:
            self._http = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120))
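The protocol-relative URL fix-up above is worth a tiny standalone helper (the name is hypothetical; the behavior matches the two lines in the diff):

```python
def normalize_attachment_url(url: str) -> str:
    # QQ attachment URLs are sometimes protocol-relative
    # (e.g. "//multimedia.nt.qq.com/..."); force https so a plain
    # HTTP client can fetch them.
    if url.startswith("//"):
        return f"https:{url}"
    return url


print(normalize_attachment_url("//multimedia.nt.qq.com/x"))  # https://multimedia.nt.qq.com/x
print(normalize_attachment_url("https://a.example/f"))       # unchanged
```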
457
nanobot/channels/websocket.py
Normal file
@@ -0,0 +1,457 @@
"""WebSocket server channel: nanobot acts as a WebSocket server and serves connected clients."""

from __future__ import annotations

import asyncio
import email.utils
import hmac
import http
import json
import secrets
import ssl
import time
import uuid
from typing import Any, Self
from urllib.parse import parse_qs, urlparse

from loguru import logger
from pydantic import Field, field_validator, model_validator
from websockets.asyncio.server import ServerConnection, serve
from websockets.datastructures import Headers
from websockets.exceptions import ConnectionClosed
from websockets.http11 import Request as WsRequest, Response

from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.schema import Base


def _strip_trailing_slash(path: str) -> str:
    if len(path) > 1 and path.endswith("/"):
        return path.rstrip("/")
    return path or "/"


def _normalize_config_path(path: str) -> str:
    return _strip_trailing_slash(path)


class WebSocketConfig(Base):
    """WebSocket server channel configuration.

    Clients connect with URLs like ``ws://{host}:{port}{path}?client_id=...&token=...``.
    - ``client_id``: Used for ``allow_from`` authorization; if omitted, a value is generated and logged.
    - ``token``: If non-empty, the ``token`` query param may match this static secret; short-lived tokens
      from ``token_issue_path`` are also accepted.
    - ``token_issue_path``: If non-empty, **GET** (HTTP/1.1) to this path returns JSON
      ``{"token": "...", "expires_in": <seconds>}``; use ``?token=...`` when opening the WebSocket.
      Must differ from ``path`` (the WS upgrade path). If the client runs in the **same process** as
      nanobot and shares the asyncio loop, use a thread or async HTTP client for GET—do not call
      blocking ``urllib`` or synchronous ``httpx`` from inside a coroutine.
    - ``token_issue_secret``: If non-empty, token requests must send ``Authorization: Bearer <secret>`` or
      ``X-Nanobot-Auth: <secret>``.
    - ``websocket_requires_token``: If True, the handshake must include a valid token (static or issued and not expired).
    - Each connection has its own session: a unique ``chat_id`` maps to the agent session internally.
    - ``media`` field in outbound messages contains local filesystem paths; remote clients need a
      shared filesystem or an HTTP file server to access these files.
    """

    enabled: bool = False
    host: str = "127.0.0.1"
    port: int = 8765
    path: str = "/"
    token: str = ""
    token_issue_path: str = ""
    token_issue_secret: str = ""
    token_ttl_s: int = Field(default=300, ge=30, le=86_400)
    websocket_requires_token: bool = True
    allow_from: list[str] = Field(default_factory=lambda: ["*"])
    streaming: bool = True
    max_message_bytes: int = Field(default=1_048_576, ge=1024, le=16_777_216)
    ping_interval_s: float = Field(default=20.0, ge=5.0, le=300.0)
    ping_timeout_s: float = Field(default=20.0, ge=5.0, le=300.0)
    ssl_certfile: str = ""
    ssl_keyfile: str = ""

    @field_validator("path")
    @classmethod
    def path_must_start_with_slash(cls, value: str) -> str:
        if not value.startswith("/"):
            raise ValueError('path must start with "/"')
        return _normalize_config_path(value)

    @field_validator("token_issue_path")
    @classmethod
    def token_issue_path_format(cls, value: str) -> str:
        value = value.strip()
        if not value:
            return ""
        if not value.startswith("/"):
            raise ValueError('token_issue_path must start with "/"')
        return _normalize_config_path(value)

    @model_validator(mode="after")
    def token_issue_path_differs_from_ws_path(self) -> Self:
        if not self.token_issue_path:
            return self
        if _normalize_config_path(self.token_issue_path) == _normalize_config_path(self.path):
            raise ValueError("token_issue_path must differ from path (the WebSocket upgrade path)")
        return self


def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
    body = json.dumps(data, ensure_ascii=False).encode("utf-8")
    headers = Headers(
        [
            ("Date", email.utils.formatdate(usegmt=True)),
            ("Connection", "close"),
            ("Content-Length", str(len(body))),
            ("Content-Type", "application/json; charset=utf-8"),
        ]
    )
    reason = http.HTTPStatus(status).phrase
    return Response(status, reason, headers, body)


def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]:
    """Parse normalized path and query parameters in one pass."""
    parsed = urlparse("ws://x" + path_with_query)
    path = _strip_trailing_slash(parsed.path or "/")
    return path, parse_qs(parsed.query)


def _normalize_http_path(path_with_query: str) -> str:
    """Return the path component (no query string), with trailing slash normalized (root stays ``/``)."""
    return _parse_request_path(path_with_query)[0]


def _parse_query(path_with_query: str) -> dict[str, list[str]]:
    return _parse_request_path(path_with_query)[1]


def _query_first(query: dict[str, list[str]], key: str) -> str | None:
    """Return the first value for *key*, or None."""
    values = query.get(key)
    return values[0] if values else None


def _parse_inbound_payload(raw: str) -> str | None:
    """Parse a client frame into text; return None for empty or unrecognized content."""
    text = raw.strip()
    if not text:
        return None
    if text.startswith("{"):
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            return text
        if isinstance(data, dict):
            for key in ("content", "text", "message"):
                value = data.get(key)
                if isinstance(value, str) and value.strip():
                    return value
            return None
        return None
    return text
def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool:
    """Return True if the token-issue HTTP request carries credentials matching ``token_issue_secret``."""
    if not configured_secret:
        return True
    authorization = headers.get("Authorization") or headers.get("authorization")
    if authorization and authorization.lower().startswith("bearer "):
        supplied = authorization[7:].strip()
        return hmac.compare_digest(supplied, configured_secret)
    header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth")
    if not header_token:
        return False
    return hmac.compare_digest(header_token.strip(), configured_secret)
class WebSocketChannel(BaseChannel):
|
||||
"""Run a local WebSocket server; forward text/JSON messages to the message bus."""
|
||||
|
||||
name = "websocket"
|
||||
display_name = "WebSocket"
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus):
|
||||
if isinstance(config, dict):
|
||||
config = WebSocketConfig.model_validate(config)
|
||||
super().__init__(config, bus)
|
||||
self.config: WebSocketConfig = config
|
||||
self._connections: dict[str, Any] = {}
|
||||
self._issued_tokens: dict[str, float] = {}
|
||||
self._stop_event: asyncio.Event | None = None
|
||||
self._server_task: asyncio.Task[None] | None = None
|
||||
|
||||
@classmethod
|
||||
def default_config(cls) -> dict[str, Any]:
|
||||
return WebSocketConfig().model_dump(by_alias=True)
|
||||
|
||||
def _expected_path(self) -> str:
|
||||
return _normalize_config_path(self.config.path)
|
||||
|
||||
def _build_ssl_context(self) -> ssl.SSLContext | None:
|
||||
cert = self.config.ssl_certfile.strip()
|
||||
key = self.config.ssl_keyfile.strip()
|
||||
if not cert and not key:
|
||||
return None
|
||||
if not cert or not key:
|
||||
raise ValueError(
|
||||
"websocket: ssl_certfile and ssl_keyfile must both be set for WSS, or both left empty"
|
||||
)
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
ctx.load_cert_chain(certfile=cert, keyfile=key)
|
||||
return ctx
|
||||
|
||||
_MAX_ISSUED_TOKENS = 10_000
|
||||
|
||||
def _purge_expired_issued_tokens(self) -> None:
|
||||
now = time.monotonic()
|
||||
for token_key, expiry in list(self._issued_tokens.items()):
|
||||
if now > expiry:
|
||||
self._issued_tokens.pop(token_key, None)
|
||||
|
||||
def _take_issued_token_if_valid(self, token_value: str | None) -> bool:
|
||||
"""Validate and consume one issued token (single use per connection attempt).
|
||||
|
||||
Uses single-step pop to minimize the window between lookup and removal;
|
||||
safe under asyncio's single-threaded cooperative model.
|
||||
"""
|
||||
if not token_value:
|
||||
return False
|
||||
self._purge_expired_issued_tokens()
|
||||
expiry = self._issued_tokens.pop(token_value, None)
|
||||
if expiry is None:
|
||||
return False
|
||||
if time.monotonic() > expiry:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _handle_token_issue_http(self, connection: Any, request: Any) -> Any:
|
||||
secret = self.config.token_issue_secret.strip()
|
||||
if secret:
|
||||
if not _issue_route_secret_matches(request.headers, secret):
|
||||
return connection.respond(401, "Unauthorized")
|
||||
else:
|
||||
logger.warning(
|
||||
"websocket: token_issue_path is set but token_issue_secret is empty; "
|
||||
"any client can obtain connection tokens — set token_issue_secret for production."
|
||||
)
|
||||
self._purge_expired_issued_tokens()
|
||||
if len(self._issued_tokens) >= self._MAX_ISSUED_TOKENS:
|
||||
logger.error(
|
||||
"websocket: too many outstanding issued tokens ({}), rejecting issuance",
|
||||
len(self._issued_tokens),
|
||||
)
|
||||
return _http_json_response({"error": "too many outstanding tokens"}, status=429)
|
||||
token_value = f"nbwt_{secrets.token_urlsafe(32)}"
|
||||
self._issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s)
|
||||
|
||||
return _http_json_response(
|
||||
{"token": token_value, "expires_in": self.config.token_ttl_s}
|
||||
)
|
||||
|
||||
def _authorize_websocket_handshake(self, connection: Any, query: dict[str, list[str]]) -> Any:
|
||||
supplied = _query_first(query, "token")
|
||||
static_token = self.config.token.strip()
|
||||
|
||||
if static_token:
|
||||
if supplied and hmac.compare_digest(supplied, static_token):
|
||||
return None
|
||||
if supplied and self._take_issued_token_if_valid(supplied):
|
||||
return None
|
||||
return connection.respond(401, "Unauthorized")
|
||||
|
||||
if self.config.websocket_requires_token:
|
||||
if supplied and self._take_issued_token_if_valid(supplied):
|
||||
return None
|
||||
return connection.respond(401, "Unauthorized")
|
||||
|
||||
if supplied:
|
||||
self._take_issued_token_if_valid(supplied)
|
||||
return None

    async def start(self) -> None:
        self._running = True
        self._stop_event = asyncio.Event()

        ssl_context = self._build_ssl_context()
        scheme = "wss" if ssl_context else "ws"

        async def process_request(
            connection: ServerConnection,
            request: WsRequest,
        ) -> Any:
            got, _ = _parse_request_path(request.path)
            if self.config.token_issue_path:
                issue_expected = _normalize_config_path(self.config.token_issue_path)
                if got == issue_expected:
                    return self._handle_token_issue_http(connection, request)

            expected_ws = self._expected_path()
            if got != expected_ws:
                return connection.respond(404, "Not Found")
            # Early reject before WebSocket upgrade to avoid unnecessary overhead;
            # _handle_message() performs a second check as defense-in-depth.
            query = _parse_query(request.path)
            client_id = _query_first(query, "client_id") or ""
            if len(client_id) > 128:
                client_id = client_id[:128]
            if not self.is_allowed(client_id):
                return connection.respond(403, "Forbidden")
            return self._authorize_websocket_handshake(connection, query)

        async def handler(connection: ServerConnection) -> None:
            await self._connection_loop(connection)

        logger.info(
            "WebSocket server listening on {}://{}:{}{}",
            scheme,
            self.config.host,
            self.config.port,
            self.config.path,
        )
        if self.config.token_issue_path:
            logger.info(
                "WebSocket token issue route: {}://{}:{}{}",
                scheme,
                self.config.host,
                self.config.port,
                _normalize_config_path(self.config.token_issue_path),
            )

        async def runner() -> None:
            async with serve(
                handler,
                self.config.host,
                self.config.port,
                process_request=process_request,
                max_size=self.config.max_message_bytes,
                ping_interval=self.config.ping_interval_s,
                ping_timeout=self.config.ping_timeout_s,
                ssl=ssl_context,
            ):
                assert self._stop_event is not None
                await self._stop_event.wait()

        self._server_task = asyncio.create_task(runner())
        await self._server_task

    async def _connection_loop(self, connection: Any) -> None:
        request = connection.request
        path_part = request.path if request else "/"
        _, query = _parse_request_path(path_part)
        client_id_raw = _query_first(query, "client_id")
        client_id = client_id_raw.strip() if client_id_raw else ""
        if not client_id:
            client_id = f"anon-{uuid.uuid4().hex[:12]}"
        elif len(client_id) > 128:
            logger.warning("websocket: client_id too long ({} chars), truncating", len(client_id))
            client_id = client_id[:128]

        chat_id = str(uuid.uuid4())

        try:
            await connection.send(
                json.dumps(
                    {
                        "event": "ready",
                        "chat_id": chat_id,
                        "client_id": client_id,
                    },
                    ensure_ascii=False,
                )
            )
            # Register only after ready is successfully sent to avoid out-of-order sends
            self._connections[chat_id] = connection

            async for raw in connection:
                if isinstance(raw, bytes):
                    try:
                        raw = raw.decode("utf-8")
                    except UnicodeDecodeError:
                        logger.warning("websocket: ignoring non-utf8 binary frame")
                        continue
                content = _parse_inbound_payload(raw)
                if content is None:
                    continue
                await self._handle_message(
                    sender_id=client_id,
                    chat_id=chat_id,
                    content=content,
                    metadata={"remote": getattr(connection, "remote_address", None)},
                )
        except Exception as e:
            logger.debug("websocket connection ended: {}", e)
        finally:
            self._connections.pop(chat_id, None)

    async def stop(self) -> None:
        if not self._running:
            return
        self._running = False
        if self._stop_event:
            self._stop_event.set()
        if self._server_task:
            try:
                await self._server_task
            except Exception as e:
                logger.warning("websocket: server task error during shutdown: {}", e)
            self._server_task = None
        self._connections.clear()
        self._issued_tokens.clear()

    async def _safe_send(self, chat_id: str, raw: str, *, label: str = "") -> None:
        """Send a raw frame, cleaning up dead connections on ConnectionClosed."""
        connection = self._connections.get(chat_id)
        if connection is None:
            return
        try:
            await connection.send(raw)
        except ConnectionClosed:
            self._connections.pop(chat_id, None)
            logger.warning("websocket{}connection gone for chat_id={}", label, chat_id)
        except Exception as e:
            logger.error("websocket{}send failed: {}", label, e)
            raise

    async def send(self, msg: OutboundMessage) -> None:
        connection = self._connections.get(msg.chat_id)
        if connection is None:
            logger.warning("websocket: no active connection for chat_id={}", msg.chat_id)
            return
        payload: dict[str, Any] = {
            "event": "message",
            "text": msg.content,
        }
        if msg.media:
            payload["media"] = msg.media
        if msg.reply_to:
            payload["reply_to"] = msg.reply_to
        raw = json.dumps(payload, ensure_ascii=False)
        await self._safe_send(msg.chat_id, raw, label=" ")

    async def send_delta(
        self,
        chat_id: str,
        delta: str,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        if self._connections.get(chat_id) is None:
            return
        meta = metadata or {}
        if meta.get("_stream_end"):
            body: dict[str, Any] = {"event": "stream_end"}
        else:
            body = {
                "event": "delta",
                "text": delta,
            }
        if meta.get("_stream_id") is not None:
            body["stream_id"] = meta["_stream_id"]
        raw = json.dumps(body, ensure_ascii=False)
        await self._safe_send(chat_id, raw, label=" stream ")
@@ -1,9 +1,13 @@
"""WeCom (Enterprise WeChat) channel implementation using wecom_aibot_sdk."""

import asyncio
import base64
import hashlib
import importlib.util
import os
import re
from collections import OrderedDict
from pathlib import Path
from typing import Any

from loguru import logger
@@ -17,6 +21,37 @@ from pydantic import Field

WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None

# Upload safety limits (matching QQ channel defaults)
WECOM_UPLOAD_MAX_BYTES = 1024 * 1024 * 200  # 200MB

# Replace unsafe characters with "_", keep Chinese and common safe punctuation.
_SAFE_NAME_RE = re.compile(r"[^\w.\-()\[\]()【】\u4e00-\u9fff]+", re.UNICODE)


def _sanitize_filename(name: str) -> str:
    """Sanitize filename to avoid traversal and problematic chars."""
    name = (name or "").strip()
    name = Path(name).name
    name = _SAFE_NAME_RE.sub("_", name).strip("._ ")
    return name


_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}
_VIDEO_EXTS = {".mp4", ".avi", ".mov"}
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg"}


def _guess_wecom_media_type(filename: str) -> str:
    """Classify file extension as WeCom media_type string."""
    ext = Path(filename).suffix.lower()
    if ext in _IMAGE_EXTS:
        return "image"
    if ext in _VIDEO_EXTS:
        return "video"
    if ext in _AUDIO_EXTS:
        return "voice"
    return "file"


class WecomConfig(Base):
    """WeCom (Enterprise WeChat) AI Bot channel configuration."""

@@ -217,6 +252,7 @@ class WecomChannel(BaseChannel):
        chat_id = body.get("chatid", sender_id)

        content_parts = []
        media_paths: list[str] = []

        if msg_type == "text":
            text = body.get("text", {}).get("content", "")
@@ -232,7 +268,8 @@ class WecomChannel(BaseChannel):
                file_path = await self._download_and_save_media(file_url, aes_key, "image")
                if file_path:
                    filename = os.path.basename(file_path)
                    content_parts.append(f"[image: {filename}]\n[Image: source: {file_path}]")
                    content_parts.append(f"[image: {filename}]")
                    media_paths.append(file_path)
                else:
                    content_parts.append("[image: download failed]")
            else:
@@ -256,7 +293,8 @@ class WecomChannel(BaseChannel):
            if file_url and aes_key:
                file_path = await self._download_and_save_media(file_url, aes_key, "file", file_name)
                if file_path:
                    content_parts.append(f"[file: {file_name}]\n[File: source: {file_path}]")
                    content_parts.append(f"[file: {file_name}]")
                    media_paths.append(file_path)
                else:
                    content_parts.append(f"[file: {file_name}: download failed]")
            else:
@@ -286,12 +324,11 @@ class WecomChannel(BaseChannel):
        self._chat_frames[chat_id] = frame

        # Forward to message bus
        # Note: media paths are included in content for broader model compatibility
        await self._handle_message(
            sender_id=sender_id,
            chat_id=chat_id,
            content=content,
            media=None,
            media=media_paths or None,
            metadata={
                "message_id": msg_id,
                "msg_type": msg_type,
@@ -322,13 +359,21 @@ class WecomChannel(BaseChannel):
            logger.warning("Failed to download media from WeCom")
            return None

        if len(data) > WECOM_UPLOAD_MAX_BYTES:
            logger.warning(
                "WeCom inbound media too large: {} bytes (max {})",
                len(data),
                WECOM_UPLOAD_MAX_BYTES,
            )
            return None

        media_dir = get_media_dir("wecom")
        if not filename:
            filename = fname or f"{media_type}_{hash(file_url) % 100000}"
        filename = os.path.basename(filename)
        filename = _sanitize_filename(filename)

        file_path = media_dir / filename
        file_path.write_bytes(data)
        await asyncio.to_thread(file_path.write_bytes, data)
        logger.debug("Downloaded {} to {}", media_type, file_path)
        return str(file_path)

@@ -336,6 +381,100 @@ class WecomChannel(BaseChannel):
        logger.error("Error downloading media: {}", e)
        return None

    async def _upload_media_ws(
        self, client: Any, file_path: str,
    ) -> "tuple[str, str] | tuple[None, None]":
        """Upload a local file to WeCom via WebSocket 3-step protocol (base64).

        Uses the WeCom WebSocket upload commands directly via
        ``client._ws_manager.send_reply()``:

        ``aibot_upload_media_init`` → upload_id
        ``aibot_upload_media_chunk`` × N (≤512 KB raw per chunk, base64)
        ``aibot_upload_media_finish`` → media_id

        Returns (media_id, media_type) on success, (None, None) on failure.
        """
        from wecom_aibot_sdk.utils import generate_req_id as _gen_req_id

        try:
            fname = os.path.basename(file_path)
            media_type = _guess_wecom_media_type(fname)

            # Read file size and data in a thread to avoid blocking the event loop
            def _read_file():
                file_size = os.path.getsize(file_path)
                if file_size > WECOM_UPLOAD_MAX_BYTES:
                    raise ValueError(
                        f"File too large: {file_size} bytes (max {WECOM_UPLOAD_MAX_BYTES})"
                    )
                with open(file_path, "rb") as f:
                    return file_size, f.read()

            file_size, data = await asyncio.to_thread(_read_file)
            # MD5 is used for file integrity only, not cryptographic security
            md5_hash = hashlib.md5(data).hexdigest()

            CHUNK_SIZE = 512 * 1024  # 512 KB raw (before base64)
            mv = memoryview(data)
            chunk_list = [bytes(mv[i : i + CHUNK_SIZE]) for i in range(0, file_size, CHUNK_SIZE)]
            n_chunks = len(chunk_list)
            del mv, data

            # Step 1: init
            req_id = _gen_req_id("upload_init")
            resp = await client._ws_manager.send_reply(req_id, {
                "type": media_type,
                "filename": fname,
                "total_size": file_size,
                "total_chunks": n_chunks,
                "md5": md5_hash,
            }, "aibot_upload_media_init")
            if resp.errcode != 0:
                logger.warning("WeCom upload init failed ({}): {}", resp.errcode, resp.errmsg)
                return None, None
            upload_id = resp.body.get("upload_id") if resp.body else None
            if not upload_id:
                logger.warning("WeCom upload init: no upload_id in response")
                return None, None

            # Step 2: send chunks
            for i, chunk in enumerate(chunk_list):
                req_id = _gen_req_id("upload_chunk")
                resp = await client._ws_manager.send_reply(req_id, {
                    "upload_id": upload_id,
                    "chunk_index": i,
                    "base64_data": base64.b64encode(chunk).decode(),
                }, "aibot_upload_media_chunk")
                if resp.errcode != 0:
                    logger.warning("WeCom upload chunk {} failed ({}): {}", i, resp.errcode, resp.errmsg)
                    return None, None

            # Step 3: finish
            req_id = _gen_req_id("upload_finish")
            resp = await client._ws_manager.send_reply(req_id, {
                "upload_id": upload_id,
            }, "aibot_upload_media_finish")
            if resp.errcode != 0:
                logger.warning("WeCom upload finish failed ({}): {}", resp.errcode, resp.errmsg)
                return None, None

            media_id = resp.body.get("media_id") if resp.body else None
            if not media_id:
                logger.warning("WeCom upload finish: no media_id in response body={}", resp.body)
                return None, None

            suffix = "..." if len(media_id) > 16 else ""
            logger.debug("WeCom uploaded {} ({}) → media_id={}", fname, media_type, media_id[:16] + suffix)
            return media_id, media_type

        except ValueError as e:
            logger.warning("WeCom upload skipped for {}: {}", file_path, e)
            return None, None
        except Exception as e:
            logger.error("WeCom _upload_media_ws error for {}: {}", file_path, e)
            return None, None
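The 3-step upload's init payload is fully determined by the file bytes: chunk count, total size, MD5, and the base64-inflated wire size. A standalone sketch of that derivation (the `plan_upload` helper is hypothetical, shown only to illustrate the chunking math):

```python
import base64
import hashlib

CHUNK_SIZE = 512 * 1024  # 512 KB raw per chunk, before base64


def plan_upload(data: bytes) -> dict:
    """Derive the fields the init step needs from the raw file bytes."""
    chunks = [data[i:i + CHUNK_SIZE] for i in range(0, len(data), CHUNK_SIZE)]
    return {
        "total_size": len(data),
        "total_chunks": len(chunks),
        "md5": hashlib.md5(data).hexdigest(),  # integrity check only, not security
        # each chunk ships base64-encoded, so wire size grows by roughly 4/3
        "wire_bytes": sum(len(base64.b64encode(c)) for c in chunks),
    }


plan = plan_upload(b"x" * (CHUNK_SIZE + 1))
print(plan["total_chunks"])  # 2: one full chunk plus one single trailing byte
```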

    async def send(self, msg: OutboundMessage) -> None:
        """Send a message through WeCom."""
        if not self._client:
@@ -343,29 +482,59 @@ class WecomChannel(BaseChannel):
            return

        try:
            content = msg.content.strip()
            if not content:
                return
            content = (msg.content or "").strip()
            is_progress = bool(msg.metadata.get("_progress"))

            # Get the stored frame for this chat
            frame = self._chat_frames.get(msg.chat_id)
            if not frame:
                logger.warning("No frame found for chat {}, cannot reply", msg.chat_id)

            # Send media files via WebSocket upload
            for file_path in msg.media or []:
                if not os.path.isfile(file_path):
                    logger.warning("WeCom media file not found: {}", file_path)
                    continue
                media_id, media_type = await self._upload_media_ws(self._client, file_path)
                if media_id:
                    if frame:
                        await self._client.reply(frame, {
                            "msgtype": media_type,
                            media_type: {"media_id": media_id},
                        })
                    else:
                        await self._client.send_message(msg.chat_id, {
                            "msgtype": media_type,
                            media_type: {"media_id": media_id},
                        })
                    logger.debug("WeCom sent {} → {}", media_type, msg.chat_id)
                else:
                    content += f"\n[file upload failed: {os.path.basename(file_path)}]"

            if not content:
                return

            # Use streaming reply for better UX
            stream_id = self._generate_req_id("stream")
            if frame:
                # Both progress and final messages must use reply_stream (cmd="aibot_respond_msg").
                # The plain reply() uses cmd="reply" which does not support "text" msgtype
                # and causes errcode=40008 from WeCom API.
                stream_id = self._generate_req_id("stream")
                await self._client.reply_stream(
                    frame,
                    stream_id,
                    content,
                    finish=not is_progress,
                )
                logger.debug(
                    "WeCom {} sent to {}",
                    "progress" if is_progress else "message",
                    msg.chat_id,
                )
            else:
                # No frame (e.g. cron push): proactive send only supports markdown
                await self._client.send_message(msg.chat_id, {
                    "msgtype": "markdown",
                    "markdown": {"content": content},
                })
                logger.info("WeCom proactive send to {}", msg.chat_id)

            # Send as streaming message with finish=True
            await self._client.reply_stream(
                frame,
                stream_id,
                content,
                finish=True,
            )

            logger.debug("WeCom message sent to {}", msg.chat_id)

        except Exception as e:
            logger.error("Error sending WeCom message: {}", e)
            raise
        except Exception:
            logger.exception("Error sending WeCom message to chat_id={}", msg.chat_id)

@@ -592,6 +592,7 @@ def serve(
        timezone=runtime_config.agents.defaults.timezone,
        unified_session=runtime_config.agents.defaults.unified_session,
        disabled_skills=runtime_config.agents.defaults.disabled_skills,
        session_ttl_minutes=runtime_config.agents.defaults.session_ttl_minutes,
    )

    model_name = runtime_config.agents.defaults.model
@@ -685,6 +686,7 @@ def gateway(
        timezone=config.agents.defaults.timezone,
        unified_session=config.agents.defaults.unified_session,
        disabled_skills=config.agents.defaults.disabled_skills,
        session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
    )

    # Set cron callback (needs agent)
@@ -918,6 +920,7 @@ def agent(
        timezone=config.agents.defaults.timezone,
        unified_session=config.agents.defaults.unified_session,
        disabled_skills=config.agents.defaults.disabled_skills,
        session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
    )
    restart_notice = consume_restart_notice_from_env()
    if restart_notice and should_show_cli_restart_notice(restart_notice, session_id):

@@ -78,6 +78,12 @@ class AgentDefaults(Base):
    timezone: str = "UTC"  # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"
    unified_session: bool = False  # Share one session across all channels (single-user multi-device)
    disabled_skills: list[str] = Field(default_factory=list)  # Skill names to exclude from loading (e.g. ["summarize", "skill-creator"])
    session_ttl_minutes: int = Field(
        default=0,
        ge=0,
        validation_alias=AliasChoices("idleCompactAfterMinutes", "sessionTtlMinutes"),
        serialization_alias="idleCompactAfterMinutes",
    )  # Auto-compact idle threshold in minutes (0 = disabled)
    dream: DreamConfig = Field(default_factory=DreamConfig)


@@ -154,7 +160,7 @@ class GatewayConfig(Base):
class WebSearchConfig(Base):
    """Web search tool configuration."""

    provider: str = "duckduckgo"  # brave, tavily, duckduckgo, searxng, jina
    provider: str = "duckduckgo"  # brave, tavily, duckduckgo, searxng, jina, kagi
    api_key: str = ""
    base_url: str = ""  # SearXNG base URL
    max_results: int = 5
@@ -178,6 +184,7 @@ class ExecToolConfig(Base):
    timeout: int = 60
    path_append: str = ""
    sandbox: str = ""  # sandbox backend: "" (none) or "bwrap"
    allowed_env_keys: list[str] = Field(default_factory=list)  # Env var names to pass through to subprocess (e.g. ["GOPATH", "JAVA_HOME"])

class MCPServerConfig(Base):
    """MCP server connection configuration (stdio or HTTP)."""

@@ -80,6 +80,7 @@ class CronService:
        self._store: CronStore | None = None
        self._timer_task: asyncio.Task | None = None
        self._running = False
        self._timer_active = False
        self.max_sleep_ms = max_sleep_ms

    def _load_jobs(self) -> tuple[list[CronJob], int]:
@@ -171,7 +172,11 @@ class CronService:
    def _load_store(self) -> CronStore:
        """Load jobs from disk. Reloads automatically if file was modified externally.

        - Reload every time because it needs to merge operations on the jobs object from other instances.
        - During _on_timer execution, return the existing store to prevent concurrent
          _load_store calls (e.g. from list_jobs polling) from replacing it mid-execution.
        """
        if self._timer_active and self._store:
            return self._store
        jobs, version = self._load_jobs()
        self._store = CronStore(version=version, jobs=jobs)
        self._merge_action()
@@ -290,18 +295,23 @@ class CronService:
        """Handle timer tick - run due jobs."""
        self._load_store()
        if not self._store:
            self._arm_timer()
            return

        now = _now_ms()
        due_jobs = [
            j for j in self._store.jobs
            if j.enabled and j.state.next_run_at_ms and now >= j.state.next_run_at_ms
        ]
        self._timer_active = True
        try:
            now = _now_ms()
            due_jobs = [
                j for j in self._store.jobs
                if j.enabled and j.state.next_run_at_ms and now >= j.state.next_run_at_ms
            ]

            for job in due_jobs:
                await self._execute_job(job)
            for job in due_jobs:
                await self._execute_job(job)

            self._save_store()
            self._save_store()
        finally:
            self._timer_active = False
        self._arm_timer()

    async def _execute_job(self, job: CronJob) -> None:
@@ -460,6 +470,59 @@ class CronService:
            return job
        return None

    def update_job(
        self,
        job_id: str,
        *,
        name: str | None = None,
        schedule: CronSchedule | None = None,
        message: str | None = None,
        deliver: bool | None = None,
        channel: str | None = ...,
        to: str | None = ...,
        delete_after_run: bool | None = None,
    ) -> CronJob | Literal["not_found", "protected"]:
        """Update mutable fields of an existing job. System jobs cannot be updated.

        For ``channel`` and ``to``, pass an explicit value (including ``None``)
        to update; omit (sentinel ``...``) to leave unchanged.
        """
        store = self._load_store()
        job = next((j for j in store.jobs if j.id == job_id), None)
        if job is None:
            return "not_found"
        if job.payload.kind == "system_event":
            return "protected"

        if schedule is not None:
            _validate_schedule_for_add(schedule)
            job.schedule = schedule
        if name is not None:
            job.name = name
        if message is not None:
            job.payload.message = message
        if deliver is not None:
            job.payload.deliver = deliver
        if channel is not ...:
            job.payload.channel = channel
        if to is not ...:
            job.payload.to = to
        if delete_after_run is not None:
            job.delete_after_run = delete_after_run

        job.updated_at_ms = _now_ms()
        if job.enabled:
            job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())

        if self._running:
            self._save_store()
            self._arm_timer()
        else:
            self._append_action("update", asdict(job))

        logger.info("Cron: updated job '{}' ({})", job.name, job.id)
        return job

@@ -83,6 +83,7 @@ class Nanobot:
            timezone=defaults.timezone,
            unified_session=defaults.unified_session,
            disabled_skills=defaults.disabled_skills,
            session_ttl_minutes=defaults.session_ttl_minutes,
        )
        return cls(loop)


@@ -375,6 +375,14 @@ class LLMProvider(ABC):
            and role in ("user", "assistant")
        ):
            prev = merged[-1]
            if role == "assistant":
                prev_has_tools = bool(prev.get("tool_calls"))
                curr_has_tools = bool(msg.get("tool_calls"))
                if curr_has_tools:
                    merged[-1] = dict(msg)
                    continue
                if prev_has_tools:
                    continue
            prev_content = prev.get("content") or ""
            curr_content = msg.get("content") or ""
            if isinstance(prev_content, str) and isinstance(curr_content, str):

@@ -243,6 +243,10 @@ class OpenAICompatProvider(LLMProvider):
                    tc_clean["id"] = map_id(tc_clean.get("id"))
                    normalized.append(tc_clean)
                clean["tool_calls"] = normalized
                if clean.get("role") == "assistant":
                    # Some OpenAI-compatible gateways reject assistant messages
                    # that mix non-empty content with tool_calls.
                    clean["content"] = None
            if "tool_call_id" in clean and clean["tool_call_id"]:
                clean["tool_call_id"] = map_id(clean["tool_call_id"])
        return self._enforce_role_alternation(sanitized)

@@ -155,6 +155,7 @@ class SessionManager:
        messages = []
        metadata = {}
        created_at = None
        updated_at = None
        last_consolidated = 0

        with open(path, encoding="utf-8") as f:
@@ -168,6 +169,7 @@ class SessionManager:
                if data.get("_type") == "metadata":
                    metadata = data.get("metadata", {})
                    created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None
                    updated_at = datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else None
                    last_consolidated = data.get("last_consolidated", 0)
                else:
                    messages.append(data)
@@ -176,6 +178,7 @@ class SessionManager:
            key=key,
            messages=messages,
            created_at=created_at or datetime.now(),
            updated_at=updated_at or datetime.now(),
            metadata=metadata,
            last_consolidated=last_consolidated
        )

@@ -15,9 +15,12 @@ from loguru import logger


def strip_think(text: str) -> str:
    """Remove <think>…</think> blocks and any unclosed trailing <think> tag."""
    """Remove thinking blocks and any unclosed trailing tag."""
    text = re.sub(r"<think>[\s\S]*?</think>", "", text)
    text = re.sub(r"<think>[\s\S]*$", "", text)
    text = re.sub(r"^\s*<think>[\s\S]*$", "", text)
    # Gemma 4 and similar models use <thought>...</thought> blocks
    text = re.sub(r"<thought>[\s\S]*?</thought>", "", text)
    text = re.sub(r"^\s*<thought>[\s\S]*$", "", text)
    return text.strip()
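One plausible reading of the revised strip_think, kept standalone so its behavior can be checked directly (this sketch keeps only the `^\s*`-anchored unclosed-tag rule from the new lines; whether the unanchored variant also survives the merge is not shown by the hunk):

```python
import re


def strip_think(text: str) -> str:
    """Remove closed thinking blocks; a text that *starts* with an unclosed tag is wiped."""
    # closed <think>…</think> blocks anywhere in the text
    text = re.sub(r"<think>[\s\S]*?</think>", "", text)
    # unclosed <think> at the start of the text consumes everything after it
    text = re.sub(r"^\s*<think>[\s\S]*$", "", text)
    # <thought>…</thought> variants (used by some Gemma builds) get the same treatment
    text = re.sub(r"<thought>[\s\S]*?</thought>", "", text)
    text = re.sub(r"^\s*<thought>[\s\S]*$", "", text)
    return text.strip()


print(strip_think("<think>plan the answer</think>answer"))  # → answer
print(strip_think("<think>never closed"))                   # → (empty string)
```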
|
||||
|
||||
|
||||
|
||||
@ -76,12 +76,16 @@ discord = [
|
||||
langsmith = [
|
||||
"langsmith>=0.1.0",
|
||||
]
|
||||
pdf = [
|
||||
"pymupdf>=1.25.0",
|
||||
]
|
||||
dev = [
|
||||
"pytest>=9.0.0,<10.0.0",
|
||||
"pytest-asyncio>=1.3.0,<2.0.0",
|
||||
"aiohttp>=3.9.0,<4.0.0",
|
||||
"pytest-cov>=6.0.0,<7.0.0",
|
||||
"ruff>=0.1.0",
|
||||
"pymupdf>=1.25.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
914 tests/agent/test_auto_compact.py Normal file
@@ -0,0 +1,914 @@
|
||||
"""Tests for auto compact (idle TTL) feature."""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.command import CommandContext
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
|
||||
def _make_loop(tmp_path: Path, session_ttl_minutes: int = 15) -> AgentLoop:
|
||||
"""Create a minimal AgentLoop for testing."""
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.estimate_prompt_tokens.return_value = (10_000, "test")
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
|
||||
provider.generation.max_tokens = 4096
|
||||
loop = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
context_window_tokens=128_000,
|
||||
session_ttl_minutes=session_ttl_minutes,
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
return loop
|
||||
|
||||
|
||||
def _add_turns(session, turns: int, *, prefix: str = "msg") -> None:
|
||||
"""Append simple user/assistant turns to a session."""
|
||||
for i in range(turns):
|
||||
session.add_message("user", f"{prefix} user {i}")
|
||||
session.add_message("assistant", f"{prefix} assistant {i}")
|
||||
|
||||
|
||||
class TestSessionTTLConfig:
    """Test session TTL configuration."""

    def test_default_ttl_is_zero(self):
        """Default TTL should be 0 (disabled)."""
        defaults = AgentDefaults()
        assert defaults.session_ttl_minutes == 0

    def test_custom_ttl(self):
        """Custom TTL should be stored correctly."""
        defaults = AgentDefaults(session_ttl_minutes=30)
        assert defaults.session_ttl_minutes == 30

    def test_user_friendly_alias_is_supported(self):
        """Config should accept idleCompactAfterMinutes as the preferred JSON key."""
        defaults = AgentDefaults.model_validate({"idleCompactAfterMinutes": 30})
        assert defaults.session_ttl_minutes == 30

    def test_legacy_alias_is_still_supported(self):
        """Config should still accept the old sessionTtlMinutes key for compatibility."""
        defaults = AgentDefaults.model_validate({"sessionTtlMinutes": 30})
        assert defaults.session_ttl_minutes == 30

    def test_serializes_with_user_friendly_alias(self):
        """Config dumps should use idleCompactAfterMinutes for JSON output."""
        defaults = AgentDefaults(session_ttl_minutes=30)
        data = defaults.model_dump(mode="json", by_alias=True)
        assert data["idleCompactAfterMinutes"] == 30
        assert "sessionTtlMinutes" not in data


class TestAgentLoopTTLParam:
    """Test that AutoCompact receives and stores session_ttl_minutes."""

    def test_loop_stores_ttl(self, tmp_path):
        """AutoCompact should store the TTL value."""
        loop = _make_loop(tmp_path, session_ttl_minutes=25)
        assert loop.auto_compact._ttl == 25

    def test_loop_default_ttl_zero(self, tmp_path):
        """AutoCompact default TTL should be 0 (disabled)."""
        loop = _make_loop(tmp_path, session_ttl_minutes=0)
        assert loop.auto_compact._ttl == 0
class TestAutoCompact:
    """Test the _archive method."""

    @pytest.mark.asyncio
    async def test_is_expired_boundary(self, tmp_path):
        """Exactly at TTL boundary should be expired (>= not >)."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        ts = datetime.now() - timedelta(minutes=15)
        assert loop.auto_compact._is_expired(ts) is True
        ts2 = datetime.now() - timedelta(minutes=14, seconds=59)
        assert loop.auto_compact._is_expired(ts2) is False
        await loop.close_mcp()

    @pytest.mark.asyncio
    async def test_is_expired_string_timestamp(self, tmp_path):
        """_is_expired should parse ISO string timestamps."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        ts = (datetime.now() - timedelta(minutes=20)).isoformat()
        assert loop.auto_compact._is_expired(ts) is True
        assert loop.auto_compact._is_expired(None) is False
        assert loop.auto_compact._is_expired("") is False
        await loop.close_mcp()

    @pytest.mark.asyncio
    async def test_check_expired_only_archives_expired_sessions(self, tmp_path):
        """With multiple sessions, only the expired one should be archived."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        # Expired session
        s1 = loop.sessions.get_or_create("cli:expired")
        s1.add_message("user", "old")
        s1.updated_at = datetime.now() - timedelta(minutes=20)
        loop.sessions.save(s1)
        # Active session
        s2 = loop.sessions.get_or_create("cli:active")
        s2.add_message("user", "recent")
        loop.sessions.save(s2)

        async def _fake_archive(messages):
            return "Summary."

        loop.consolidator.archive = _fake_archive
        loop.auto_compact.check_expired(loop._schedule_background)
        await asyncio.sleep(0.1)

        active_after = loop.sessions.get_or_create("cli:active")
        assert len(active_after.messages) == 1
        assert active_after.messages[0]["content"] == "recent"
        await loop.close_mcp()

    @pytest.mark.asyncio
    async def test_auto_compact_archives_prefix_and_keeps_recent_suffix(self, tmp_path):
        """_archive should summarize the old prefix and keep a recent legal suffix."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        session = loop.sessions.get_or_create("cli:test")
        _add_turns(session, 6)
        loop.sessions.save(session)

        archived_messages = []

        async def _fake_archive(messages):
            archived_messages.extend(messages)
            return "Summary."

        loop.consolidator.archive = _fake_archive

        await loop.auto_compact._archive("cli:test")

        assert len(archived_messages) == 4
        session_after = loop.sessions.get_or_create("cli:test")
        assert len(session_after.messages) == loop.auto_compact._RECENT_SUFFIX_MESSAGES
        assert session_after.messages[0]["content"] == "msg user 2"
        assert session_after.messages[-1]["content"] == "msg assistant 5"
        await loop.close_mcp()

    @pytest.mark.asyncio
    async def test_auto_compact_stores_summary(self, tmp_path):
        """_archive should store the summary in _summaries."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        session = loop.sessions.get_or_create("cli:test")
        _add_turns(session, 6, prefix="hello")
        loop.sessions.save(session)

        async def _fake_archive(messages):
            return "User said hello."

        loop.consolidator.archive = _fake_archive

        await loop.auto_compact._archive("cli:test")

        entry = loop.auto_compact._summaries.get("cli:test")
        assert entry is not None
        assert entry[0] == "User said hello."
        session_after = loop.sessions.get_or_create("cli:test")
        assert len(session_after.messages) == loop.auto_compact._RECENT_SUFFIX_MESSAGES
        await loop.close_mcp()

    @pytest.mark.asyncio
    async def test_auto_compact_empty_session(self, tmp_path):
        """_archive on empty session should not archive."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        session = loop.sessions.get_or_create("cli:test")

        archive_called = False

        async def _fake_archive(messages):
            nonlocal archive_called
            archive_called = True
            return "Summary."

        loop.consolidator.archive = _fake_archive

        await loop.auto_compact._archive("cli:test")

        assert not archive_called
        session_after = loop.sessions.get_or_create("cli:test")
        assert len(session_after.messages) == 0
        await loop.close_mcp()

    @pytest.mark.asyncio
    async def test_auto_compact_respects_last_consolidated(self, tmp_path):
        """_archive should only archive un-consolidated messages."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        session = loop.sessions.get_or_create("cli:test")
        _add_turns(session, 14)
        session.last_consolidated = 18
        loop.sessions.save(session)

        archived_count = 0

        async def _fake_archive(messages):
            nonlocal archived_count
            archived_count = len(messages)
            return "Summary."

        loop.consolidator.archive = _fake_archive

        await loop.auto_compact._archive("cli:test")

        assert archived_count == 2
        await loop.close_mcp()
class TestAutoCompactIdleDetection:
    """Test idle detection triggers auto-new in _process_message."""

    @pytest.mark.asyncio
    async def test_no_auto_compact_when_ttl_disabled(self, tmp_path):
        """No auto-new should happen when TTL is 0 (disabled)."""
        loop = _make_loop(tmp_path, session_ttl_minutes=0)
        session = loop.sessions.get_or_create("cli:test")
        session.add_message("user", "old message")
        session.updated_at = datetime.now() - timedelta(minutes=30)
        loop.sessions.save(session)

        msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="new msg")
        await loop._process_message(msg)

        session_after = loop.sessions.get_or_create("cli:test")
        assert any(m["content"] == "old message" for m in session_after.messages)
        await loop.close_mcp()

    @pytest.mark.asyncio
    async def test_auto_compact_triggers_on_idle(self, tmp_path):
        """Proactive auto-new archives expired session; _process_message reloads it."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        session = loop.sessions.get_or_create("cli:test")
        _add_turns(session, 6, prefix="old")
        session.updated_at = datetime.now() - timedelta(minutes=20)
        loop.sessions.save(session)

        archived_messages = []

        async def _fake_archive(messages):
            archived_messages.extend(messages)
            return "Summary."

        loop.consolidator.archive = _fake_archive

        # Simulate proactive archive completing before message arrives
        await loop.auto_compact._archive("cli:test")

        msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="new msg")
        await loop._process_message(msg)

        session_after = loop.sessions.get_or_create("cli:test")
        assert len(archived_messages) == 4
        assert not any(m["content"] == "old user 0" for m in session_after.messages)
        assert any(m["content"] == "new msg" for m in session_after.messages)
        await loop.close_mcp()

    @pytest.mark.asyncio
    async def test_no_auto_compact_when_active(self, tmp_path):
        """No auto-new should happen when session is recently active."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        session = loop.sessions.get_or_create("cli:test")
        session.add_message("user", "recent message")
        loop.sessions.save(session)

        msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="new msg")
        await loop._process_message(msg)

        session_after = loop.sessions.get_or_create("cli:test")
        assert any(m["content"] == "recent message" for m in session_after.messages)
        await loop.close_mcp()

    @pytest.mark.asyncio
    async def test_auto_compact_does_not_affect_priority_commands(self, tmp_path):
        """Priority commands (/stop, /restart) bypass _process_message entirely via run()."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        session = loop.sessions.get_or_create("cli:test")
        session.add_message("user", "old message")
        session.updated_at = datetime.now() - timedelta(minutes=20)
        loop.sessions.save(session)

        # Priority commands are dispatched in run() before _process_message is called.
        # Simulate that path directly via dispatch_priority.
        raw = "/stop"
        msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content=raw)
        ctx = CommandContext(msg=msg, session=session, key="cli:test", raw=raw, loop=loop)
        result = await loop.commands.dispatch_priority(ctx)
        assert result is not None
        assert "stopped" in result.content.lower() or "no active task" in result.content.lower()

        # Session should be untouched since priority commands skip _process_message
        session_after = loop.sessions.get_or_create("cli:test")
        assert any(m["content"] == "old message" for m in session_after.messages)
        await loop.close_mcp()

    @pytest.mark.asyncio
    async def test_auto_compact_with_slash_new(self, tmp_path):
        """Auto-new fires before /new dispatches; session is cleared twice but idempotent."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        session = loop.sessions.get_or_create("cli:test")
        for i in range(4):
            session.add_message("user", f"msg{i}")
            session.add_message("assistant", f"resp{i}")
        session.updated_at = datetime.now() - timedelta(minutes=20)
        loop.sessions.save(session)

        async def _fake_archive(messages):
            return "Summary."

        loop.consolidator.archive = _fake_archive

        msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
        response = await loop._process_message(msg)

        assert response is not None
        assert "new session started" in response.content.lower()

        session_after = loop.sessions.get_or_create("cli:test")
        # Session is empty (auto-new archived and cleared, /new cleared again)
        assert len(session_after.messages) == 0
        await loop.close_mcp()
class TestAutoCompactSystemMessages:
    """Test that auto-new also works for system messages."""

    @pytest.mark.asyncio
    async def test_auto_compact_triggers_for_system_messages(self, tmp_path):
        """Proactive auto-new archives expired session; system messages reload it."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        session = loop.sessions.get_or_create("cli:test")
        _add_turns(session, 6, prefix="old")
        session.updated_at = datetime.now() - timedelta(minutes=20)
        loop.sessions.save(session)

        async def _fake_archive(messages):
            return "Summary."

        loop.consolidator.archive = _fake_archive

        # Simulate proactive archive completing before system message arrives
        await loop.auto_compact._archive("cli:test")

        msg = InboundMessage(
            channel="system", sender_id="subagent", chat_id="cli:test",
            content="subagent result",
        )
        await loop._process_message(msg)

        session_after = loop.sessions.get_or_create("cli:test")
        assert not any(
            m["content"] == "old user 0"
            for m in session_after.messages
        )
        await loop.close_mcp()
class TestAutoCompactEdgeCases:
    """Edge cases for auto session new."""

    @pytest.mark.asyncio
    async def test_auto_compact_with_nothing_summary(self, tmp_path):
        """Auto-new should not inject when archive produces '(nothing)'."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        session = loop.sessions.get_or_create("cli:test")
        _add_turns(session, 6, prefix="thanks")
        session.updated_at = datetime.now() - timedelta(minutes=20)
        loop.sessions.save(session)

        loop.provider.chat_with_retry = AsyncMock(
            return_value=LLMResponse(content="(nothing)", tool_calls=[])
        )

        await loop.auto_compact._archive("cli:test")

        session_after = loop.sessions.get_or_create("cli:test")
        assert len(session_after.messages) == loop.auto_compact._RECENT_SUFFIX_MESSAGES
        # "(nothing)" summary should not be stored
        assert "cli:test" not in loop.auto_compact._summaries

        await loop.close_mcp()

    @pytest.mark.asyncio
    async def test_auto_compact_archive_failure_still_keeps_recent_suffix(self, tmp_path):
        """Auto-new should keep the recent suffix even if LLM archive falls back to raw dump."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        session = loop.sessions.get_or_create("cli:test")
        _add_turns(session, 6, prefix="important")
        session.updated_at = datetime.now() - timedelta(minutes=20)
        loop.sessions.save(session)

        loop.provider.chat_with_retry = AsyncMock(side_effect=Exception("API down"))

        # Should not raise
        await loop.auto_compact._archive("cli:test")

        session_after = loop.sessions.get_or_create("cli:test")
        assert len(session_after.messages) == loop.auto_compact._RECENT_SUFFIX_MESSAGES

        await loop.close_mcp()

    @pytest.mark.asyncio
    async def test_auto_compact_preserves_runtime_checkpoint_before_check(self, tmp_path):
        """Short expired sessions keep recent messages; checkpoint restore still works on resume."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        session = loop.sessions.get_or_create("cli:test")
        session.metadata[AgentLoop._RUNTIME_CHECKPOINT_KEY] = {
            "assistant_message": {"role": "assistant", "content": "interrupted response"},
            "completed_tool_results": [],
            "pending_tool_calls": [],
        }
        session.add_message("user", "previous message")
        session.updated_at = datetime.now() - timedelta(minutes=20)
        loop.sessions.save(session)

        archived_messages = []

        async def _fake_archive(messages):
            archived_messages.extend(messages)
            return "Summary."

        loop.consolidator.archive = _fake_archive

        # Simulate proactive archive completing before message arrives
        await loop.auto_compact._archive("cli:test")

        msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="continue")
        await loop._process_message(msg)

        session_after = loop.sessions.get_or_create("cli:test")
        assert archived_messages == []
        assert any(m["content"] == "previous message" for m in session_after.messages)
        assert any(m["content"] == "interrupted response" for m in session_after.messages)

        await loop.close_mcp()
class TestAutoCompactIntegration:
    """End-to-end test of auto session new feature."""

    @pytest.mark.asyncio
    async def test_full_lifecycle(self, tmp_path):
        """
        Full lifecycle: messages -> idle -> auto-new -> archive -> clear -> summary injected as runtime context.
        """
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        session = loop.sessions.get_or_create("cli:test")

        # Phase 1: User has a conversation longer than the retained recent suffix
        session.add_message("user", "I'm learning English, teach me past tense")
        session.add_message("assistant", "Past tense is used for actions completed in the past...")
        session.add_message("user", "Give me an example")
        session.add_message("assistant", '"I walked to the store yesterday."')
        session.add_message("user", "Give me another example")
        session.add_message("assistant", '"She visited Paris last year."')
        session.add_message("user", "Quiz me")
        session.add_message("assistant", "What is the past tense of go?")
        session.add_message("user", "I think it is went")
        session.add_message("assistant", "Correct.")
        loop.sessions.save(session)

        # Phase 2: Time passes (simulate idle)
        session.updated_at = datetime.now() - timedelta(minutes=20)
        loop.sessions.save(session)

        # Phase 3: User returns with a new message
        loop.provider.chat_with_retry = AsyncMock(
            return_value=LLMResponse(
                content="User is learning English past tense. Example: 'I walked to the store yesterday.'",
                tool_calls=[],
            )
        )

        msg = InboundMessage(
            channel="cli", sender_id="user", chat_id="test",
            content="Let's continue, teach me present perfect",
        )
        response = await loop._process_message(msg)

        # Phase 4: Verify
        session_after = loop.sessions.get_or_create("cli:test")

        # The oldest messages should be trimmed from live session history
        assert not any(
            "past tense is used" in str(m.get("content", "")) for m in session_after.messages
        )

        # Summary should NOT be persisted in session (ephemeral, one-shot)
        assert not any(
            "[Resumed Session]" in str(m.get("content", "")) for m in session_after.messages
        )
        # Runtime context end marker should NOT be persisted
        assert not any(
            "[/Runtime Context]" in str(m.get("content", "")) for m in session_after.messages
        )

        # Pending summary should be consumed (one-shot)
        assert "cli:test" not in loop.auto_compact._summaries

        # The new message should be processed (response exists)
        assert response is not None

        await loop.close_mcp()

    @pytest.mark.asyncio
    async def test_runtime_context_markers_not_persisted_for_multi_paragraph_turn(self, tmp_path):
        """Auto-compact resume context must not leak runtime markers into persisted session history."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        session = loop.sessions.get_or_create("cli:test")
        session.add_message("user", "old message")
        session.updated_at = datetime.now() - timedelta(minutes=20)
        loop.sessions.save(session)

        async def _fake_archive(messages):
            return "Summary."

        loop.consolidator.archive = _fake_archive

        # Simulate proactive archive completing before message arrives
        await loop.auto_compact._archive("cli:test")

        msg = InboundMessage(
            channel="cli", sender_id="user", chat_id="test",
            content="Paragraph one\n\nParagraph two\n\nParagraph three",
        )
        await loop._process_message(msg)

        session_after = loop.sessions.get_or_create("cli:test")
        assert any(m.get("content") == "old message" for m in session_after.messages)
        for persisted in session_after.messages:
            content = str(persisted.get("content", ""))
            assert "[Runtime Context" not in content
            assert "[/Runtime Context]" not in content
        await loop.close_mcp()
class TestProactiveAutoCompact:
    """Test proactive auto-new on idle ticks (TimeoutError path in run loop)."""

    @staticmethod
    async def _run_check_expired(loop):
        """Helper: run check_expired via callback and wait for background tasks."""
        loop.auto_compact.check_expired(loop._schedule_background)
        await asyncio.sleep(0.1)

    @pytest.mark.asyncio
    async def test_no_check_when_ttl_disabled(self, tmp_path):
        """check_expired should be a no-op when TTL is 0."""
        loop = _make_loop(tmp_path, session_ttl_minutes=0)
        session = loop.sessions.get_or_create("cli:test")
        session.add_message("user", "old message")
        session.updated_at = datetime.now() - timedelta(minutes=30)
        loop.sessions.save(session)

        await self._run_check_expired(loop)

        session_after = loop.sessions.get_or_create("cli:test")
        assert len(session_after.messages) == 1
        await loop.close_mcp()

    @pytest.mark.asyncio
    async def test_proactive_archive_on_idle_tick(self, tmp_path):
        """Expired session should be archived during idle tick."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        session = loop.sessions.get_or_create("cli:test")
        _add_turns(session, 5, prefix="old")
        session.updated_at = datetime.now() - timedelta(minutes=20)
        loop.sessions.save(session)

        archived_messages = []

        async def _fake_archive(messages):
            archived_messages.extend(messages)
            return "User chatted about old things."

        loop.consolidator.archive = _fake_archive

        await self._run_check_expired(loop)

        session_after = loop.sessions.get_or_create("cli:test")
        assert len(session_after.messages) == loop.auto_compact._RECENT_SUFFIX_MESSAGES
        assert len(archived_messages) == 2
        entry = loop.auto_compact._summaries.get("cli:test")
        assert entry is not None
        assert entry[0] == "User chatted about old things."
        await loop.close_mcp()

    @pytest.mark.asyncio
    async def test_no_proactive_archive_when_active(self, tmp_path):
        """Recently active session should NOT be archived on idle tick."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        session = loop.sessions.get_or_create("cli:test")
        session.add_message("user", "recent message")
        loop.sessions.save(session)

        await self._run_check_expired(loop)

        session_after = loop.sessions.get_or_create("cli:test")
        assert len(session_after.messages) == 1
        await loop.close_mcp()

    @pytest.mark.asyncio
    async def test_no_duplicate_archive(self, tmp_path):
        """Should not archive the same session twice if already in progress."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        session = loop.sessions.get_or_create("cli:test")
        _add_turns(session, 6, prefix="old")
        session.updated_at = datetime.now() - timedelta(minutes=20)
        loop.sessions.save(session)

        archive_count = 0
        started = asyncio.Event()
        block_forever = asyncio.Event()

        async def _slow_archive(messages):
            nonlocal archive_count
            archive_count += 1
            started.set()
            await block_forever.wait()
            return "Summary."

        loop.consolidator.archive = _slow_archive

        # First call starts archiving via callback
        loop.auto_compact.check_expired(loop._schedule_background)
        await started.wait()
        assert archive_count == 1

        # Second call should skip (key is in _archiving)
        loop.auto_compact.check_expired(loop._schedule_background)
        await asyncio.sleep(0.05)
        assert archive_count == 1

        # Clean up
        block_forever.set()
        await asyncio.sleep(0.1)
        await loop.close_mcp()

    @pytest.mark.asyncio
    async def test_proactive_archive_error_does_not_block(self, tmp_path):
        """Proactive archive failure should be caught and not block future ticks."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        session = loop.sessions.get_or_create("cli:test")
        _add_turns(session, 6, prefix="old")
        session.updated_at = datetime.now() - timedelta(minutes=20)
        loop.sessions.save(session)

        async def _failing_archive(messages):
            raise RuntimeError("LLM down")

        loop.consolidator.archive = _failing_archive

        # Should not raise
        await self._run_check_expired(loop)

        # Key should be removed from _archiving (finally block)
        assert "cli:test" not in loop.auto_compact._archiving
        await loop.close_mcp()

    @pytest.mark.asyncio
    async def test_proactive_archive_skips_empty_sessions(self, tmp_path):
        """Proactive archive should not call LLM for sessions with no un-consolidated messages."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        session = loop.sessions.get_or_create("cli:test")
        session.updated_at = datetime.now() - timedelta(minutes=20)
        loop.sessions.save(session)

        archive_called = False

        async def _fake_archive(messages):
            nonlocal archive_called
            archive_called = True
            return "Summary."

        loop.consolidator.archive = _fake_archive

        await self._run_check_expired(loop)

        assert not archive_called
        await loop.close_mcp()

    @pytest.mark.asyncio
    async def test_no_reschedule_after_successful_archive(self, tmp_path):
        """Already-archived session should NOT be re-scheduled on subsequent ticks."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        session = loop.sessions.get_or_create("cli:test")
        _add_turns(session, 5, prefix="old")
        session.updated_at = datetime.now() - timedelta(minutes=20)
        loop.sessions.save(session)

        archive_count = 0

        async def _fake_archive(messages):
            nonlocal archive_count
            archive_count += 1
            return "Summary."

        loop.consolidator.archive = _fake_archive

        # First tick: archives the session
        await self._run_check_expired(loop)
        assert archive_count == 1

        # Second tick: should NOT re-schedule (updated_at is fresh after clear)
        await self._run_check_expired(loop)
        assert archive_count == 1  # Still 1, not re-scheduled
        await loop.close_mcp()

    @pytest.mark.asyncio
    async def test_empty_skip_refreshes_updated_at_prevents_reschedule(self, tmp_path):
        """Empty session skip refreshes updated_at, preventing immediate re-scheduling."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        session = loop.sessions.get_or_create("cli:test")
        session.updated_at = datetime.now() - timedelta(minutes=20)
        loop.sessions.save(session)

        archive_count = 0

        async def _fake_archive(messages):
            nonlocal archive_count
            archive_count += 1
            return "Summary."

        loop.consolidator.archive = _fake_archive

        # First tick: skips (no messages), refreshes updated_at
        await self._run_check_expired(loop)
        assert archive_count == 0

        # Second tick: should NOT re-schedule because updated_at is fresh
        await self._run_check_expired(loop)
        assert archive_count == 0
        await loop.close_mcp()

    @pytest.mark.asyncio
    async def test_session_can_be_compacted_again_after_new_messages(self, tmp_path):
        """After successful compact + user sends new messages + idle again, should compact again."""
        loop = _make_loop(tmp_path, session_ttl_minutes=15)
        session = loop.sessions.get_or_create("cli:test")
        _add_turns(session, 5, prefix="first")
        session.updated_at = datetime.now() - timedelta(minutes=20)
        loop.sessions.save(session)

        archive_count = 0

        async def _fake_archive(messages):
            nonlocal archive_count
            archive_count += 1
            return "Summary."

        loop.consolidator.archive = _fake_archive

        # First compact cycle
        await loop.auto_compact._archive("cli:test")
        assert archive_count == 1

        # User returns, sends new messages
        msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="second topic")
        await loop._process_message(msg)

        # Simulate idle again
        loop.sessions.invalidate("cli:test")
        session2 = loop.sessions.get_or_create("cli:test")
        session2.updated_at = datetime.now() - timedelta(minutes=20)
        loop.sessions.save(session2)

        # Second compact cycle should succeed
        await loop.auto_compact._archive("cli:test")
        assert archive_count == 2
        await loop.close_mcp()
class TestSummaryPersistence:
|
||||
"""Test that summary survives restart via session metadata."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summary_persisted_in_session_metadata(self, tmp_path):
|
||||
"""After archive, _last_summary should be in session metadata."""
|
||||
loop = _make_loop(tmp_path, session_ttl_minutes=15)
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
_add_turns(session, 6, prefix="hello")
|
||||
session.updated_at = datetime.now() - timedelta(minutes=20)
|
||||
loop.sessions.save(session)
|
||||
|
||||
async def _fake_archive(messages):
|
||||
return "User said hello."
|
||||
|
||||
loop.consolidator.archive = _fake_archive
|
||||
|
||||
await loop.auto_compact._archive("cli:test")
|
||||
|
||||
# Summary should be persisted in session metadata
|
||||
session_after = loop.sessions.get_or_create("cli:test")
|
||||
meta = session_after.metadata.get("_last_summary")
|
||||
assert meta is not None
|
||||
assert meta["text"] == "User said hello."
|
||||
assert "last_active" in meta
|
||||
await loop.close_mcp()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summary_recovered_after_restart(self, tmp_path):
|
||||
"""Summary should be recovered from metadata when _summaries is empty (simulates restart)."""
|
||||
loop = _make_loop(tmp_path, session_ttl_minutes=15)
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
_add_turns(session, 6, prefix="hello")
|
||||
last_active = datetime.now() - timedelta(minutes=20)
|
||||
session.updated_at = last_active
|
||||
loop.sessions.save(session)
|
||||
|
||||
async def _fake_archive(messages):
|
||||
return "User said hello."
|
||||
|
||||
loop.consolidator.archive = _fake_archive
|
||||
|
||||
# Archive
|
||||
await loop.auto_compact._archive("cli:test")
|
||||
|
||||
# Simulate restart: clear in-memory state
|
||||
loop.auto_compact._summaries.clear()
|
||||
loop.sessions.invalidate("cli:test")
|
||||
|
||||
# prepare_session should recover summary from metadata
|
||||
reloaded = loop.sessions.get_or_create("cli:test")
|
||||
assert len(reloaded.messages) == loop.auto_compact._RECENT_SUFFIX_MESSAGES
|
||||
_, summary = loop.auto_compact.prepare_session(reloaded, "cli:test")
|
||||
|
||||
assert summary is not None
|
||||
assert "User said hello." in summary
|
||||
assert "Inactive for" in summary
|
||||
# Metadata should be cleaned up after consumption
|
||||
assert "_last_summary" not in reloaded.metadata
|
||||
await loop.close_mcp()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metadata_cleanup_no_leak(self, tmp_path):
|
||||
"""_last_summary should be removed from metadata after being consumed."""
|
||||
loop = _make_loop(tmp_path, session_ttl_minutes=15)
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
_add_turns(session, 6, prefix="hello")
|
||||
session.updated_at = datetime.now() - timedelta(minutes=20)
|
||||
loop.sessions.save(session)
|
||||
|
||||
async def _fake_archive(messages):
|
||||
return "Summary."
|
||||
|
||||
loop.consolidator.archive = _fake_archive
|
||||
|
||||
await loop.auto_compact._archive("cli:test")
|
||||
|
||||
# Clear in-memory to force metadata path
|
||||
loop.auto_compact._summaries.clear()
|
||||
loop.sessions.invalidate("cli:test")
|
||||
reloaded = loop.sessions.get_or_create("cli:test")
|
||||
|
||||
# First call: consumes from metadata
|
||||
_, summary = loop.auto_compact.prepare_session(reloaded, "cli:test")
|
||||
assert summary is not None
|
||||
|
||||
# Second call: no summary (already consumed)
|
||||
_, summary2 = loop.auto_compact.prepare_session(reloaded, "cli:test")
|
||||
assert summary2 is None
|
||||
assert "_last_summary" not in reloaded.metadata
|
||||
await loop.close_mcp()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metadata_cleanup_on_inmemory_path(self, tmp_path):
|
||||
"""In-memory _summaries path should also clean up _last_summary from metadata."""
|
||||
loop = _make_loop(tmp_path, session_ttl_minutes=15)
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
_add_turns(session, 6, prefix="hello")
|
||||
session.updated_at = datetime.now() - timedelta(minutes=20)
|
||||
loop.sessions.save(session)
|
||||
|
||||
async def _fake_archive(messages):
|
||||
return "Summary."
|
||||
|
||||
loop.consolidator.archive = _fake_archive
|
||||
|
||||
await loop.auto_compact._archive("cli:test")
|
||||
|
||||
# Both _summaries and metadata have the summary
|
||||
assert "cli:test" in loop.auto_compact._summaries
|
||||
loop.sessions.invalidate("cli:test")
|
||||
reloaded = loop.sessions.get_or_create("cli:test")
|
||||
assert "_last_summary" in reloaded.metadata
|
||||
|
||||
# In-memory path is taken (no restart)
|
||||
_, summary = loop.auto_compact.prepare_session(reloaded, "cli:test")
|
||||
assert summary is not None
|
||||
# Metadata should also be cleaned up
|
||||
assert "_last_summary" not in reloaded.metadata
|
||||
await loop.close_mcp()
|
||||
@@ -46,7 +46,7 @@ class TestConsolidatorSummarize:
             {"role": "assistant", "content": "Done, fixed the race condition."},
         ]
         result = await consolidator.archive(messages)
-        assert result is True
+        assert result == "User fixed a bug in the auth module."
         entries = store.read_unprocessed_history(since_cursor=0)
         assert len(entries) == 1
@@ -55,14 +55,14 @@ class TestConsolidatorSummarize:
         mock_provider.chat_with_retry.side_effect = Exception("API error")
         messages = [{"role": "user", "content": "hello"}]
         result = await consolidator.archive(messages)
-        assert result is True  # always succeeds
+        assert result is None  # no summary on raw dump fallback
         entries = store.read_unprocessed_history(since_cursor=0)
         assert len(entries) == 1
         assert "[RAW]" in entries[0]["content"]

     async def test_summarize_skips_empty_messages(self, consolidator):
         result = await consolidator.archive([])
-        assert result is False
+        assert result is None


 class TestConsolidatorTokenBudget:
@@ -76,3 +76,52 @@ class TestConsolidatorTokenBudget:
         consolidator.archive = AsyncMock(return_value=True)
         await consolidator.maybe_consolidate_by_tokens(session)
         consolidator.archive.assert_not_called()
+
+    async def test_chunk_cap_preserves_user_turn_boundary(self, consolidator):
+        """Chunk cap should rewind to the last user boundary within the cap."""
+        consolidator._SAFETY_BUFFER = 0
+        session = MagicMock()
+        session.last_consolidated = 0
+        session.key = "test:key"
+        session.messages = [
+            {
+                "role": "user" if i in {0, 50, 61} else "assistant",
+                "content": f"m{i}",
+            }
+            for i in range(70)
+        ]
+        consolidator.estimate_session_prompt_tokens = MagicMock(
+            side_effect=[(1200, "tiktoken"), (400, "tiktoken")]
+        )
+        consolidator.pick_consolidation_boundary = MagicMock(return_value=(61, 999))
+        consolidator.archive = AsyncMock(return_value=True)
+
+        await consolidator.maybe_consolidate_by_tokens(session)
+
+        archived_chunk = consolidator.archive.await_args.args[0]
+        assert len(archived_chunk) == 50
+        assert archived_chunk[0]["content"] == "m0"
+        assert archived_chunk[-1]["content"] == "m49"
+        assert session.last_consolidated == 50
+
+    async def test_chunk_cap_skips_when_no_user_boundary_within_cap(self, consolidator):
+        """If the cap would cut mid-turn, consolidation should skip that round."""
+        consolidator._SAFETY_BUFFER = 0
+        session = MagicMock()
+        session.last_consolidated = 0
+        session.key = "test:key"
+        session.messages = [
+            {
+                "role": "user" if i in {0, 61} else "assistant",
+                "content": f"m{i}",
+            }
+            for i in range(70)
+        ]
+        consolidator.estimate_session_prompt_tokens = MagicMock(return_value=(1200, "tiktoken"))
+        consolidator.pick_consolidation_boundary = MagicMock(return_value=(61, 999))
+        consolidator.archive = AsyncMock(return_value=True)
+
+        await consolidator.maybe_consolidate_by_tokens(session)
+
+        consolidator.archive.assert_not_awaited()
+        assert session.last_consolidated == 0
@@ -307,7 +307,7 @@ async def test_agent_loop_extra_hook_receives_calls(tmp_path):
     )
     loop.tools.get_definitions = MagicMock(return_value=[])

-    content, tools_used, messages = await loop._run_agent_loop(
+    content, tools_used, messages, _, _ = await loop._run_agent_loop(
         [{"role": "user", "content": "hi"}]
     )

@@ -331,7 +331,7 @@ async def test_agent_loop_extra_hook_error_isolation(tmp_path):
     )
     loop.tools.get_definitions = MagicMock(return_value=[])

-    content, _, _ = await loop._run_agent_loop(
+    content, _, _, _, _ = await loop._run_agent_loop(
         [{"role": "user", "content": "hi"}]
     )

@@ -373,7 +373,7 @@ async def test_agent_loop_no_hooks_backward_compat(tmp_path):
     loop.tools.execute = AsyncMock(return_value="ok")
     loop.max_iterations = 2

-    content, tools_used, _ = await loop._run_agent_loop([])
+    content, tools_used, _, _, _ = await loop._run_agent_loop([])
     assert content == (
         "I reached the maximum number of tool call iterations (2) "
         "without completing the task. You can try breaking the task into smaller steps."
tests/agent/test_mcp_connection.py (new file, 44 lines)
@@ -0,0 +1,44 @@
|
||||
"""Tests for MCP connection lifecycle in AgentLoop."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
|
||||
def _make_loop(tmp_path, *, mcp_servers: dict | None = None) -> AgentLoop:
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.generation.max_tokens = 4096
|
||||
return AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
mcp_servers=mcp_servers or {"test": object()},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_mcp_retries_when_no_servers_connect(tmp_path, monkeypatch: pytest.MonkeyPatch):
|
||||
loop = _make_loop(tmp_path)
|
||||
attempts = 0
|
||||
|
||||
async def _fake_connect(_servers, _registry):
|
||||
nonlocal attempts
|
||||
attempts += 1
|
||||
return {}
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _fake_connect)
|
||||
|
||||
await loop._connect_mcp()
|
||||
await loop._connect_mcp()
|
||||
|
||||
assert attempts == 2
|
||||
assert loop._mcp_connected is False
|
||||
assert loop._mcp_stacks == {}
|
||||
File diff suppressed because it is too large
@@ -5,11 +5,17 @@ from pathlib import Path
 from types import SimpleNamespace

 import pytest

 discord = pytest.importorskip("discord")

 from nanobot.bus.events import OutboundMessage
 from nanobot.bus.queue import MessageBus
-from nanobot.channels.discord import MAX_MESSAGE_LEN, DiscordBotClient, DiscordChannel, DiscordConfig
+from nanobot.channels.discord import (
+    MAX_MESSAGE_LEN,
+    DiscordBotClient,
+    DiscordChannel,
+    DiscordConfig,
+)
 from nanobot.command.builtin import build_help_text


@@ -18,9 +24,11 @@ class _FakeDiscordClient:
     instances: list["_FakeDiscordClient"] = []
     start_error: Exception | None = None

-    def __init__(self, owner, *, intents) -> None:
+    def __init__(self, owner, *, intents, proxy=None, proxy_auth=None) -> None:
         self.owner = owner
         self.intents = intents
+        self.proxy = proxy
+        self.proxy_auth = proxy_auth
         self.closed = False
         self.ready = True
         self.channels: dict[int, object] = {}

@@ -53,7 +61,9 @@ class _FakeDiscordClient:

 class _FakeAttachment:
     # Attachment double that can simulate successful or failing save() calls.
-    def __init__(self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False) -> None:
+    def __init__(
+        self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False
+    ) -> None:
         self.id = attachment_id
         self.filename = filename
         self.size = size

@@ -211,7 +221,7 @@ async def test_start_handles_client_construction_failure(monkeypatch) -> None:
         MessageBus(),
     )

-    def _boom(owner, *, intents):
+    def _boom(owner, *, intents, proxy=None, proxy_auth=None):
         raise RuntimeError("bad client")

     monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _boom)

@@ -514,9 +524,7 @@ async def test_slash_new_forwards_when_user_is_allowlisted() -> None:
     assert new_cmd is not None
     await new_cmd.callback(interaction)

-    assert interaction.response.messages == [
-        {"content": "Processing /new...", "ephemeral": True}
-    ]
+    assert interaction.response.messages == [{"content": "Processing /new...", "ephemeral": True}]
     assert len(handled) == 1
     assert handled[0]["content"] == "/new"
     assert handled[0]["sender_id"] == "123"

@@ -590,9 +598,7 @@ async def test_slash_help_returns_ephemeral_help_text() -> None:
     assert help_cmd is not None
     await help_cmd.callback(interaction)

-    assert interaction.response.messages == [
-        {"content": build_help_text(), "ephemeral": True}
-    ]
+    assert interaction.response.messages == [{"content": build_help_text(), "ephemeral": True}]
     assert handled == []


@@ -727,11 +733,13 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() ->
         def typing(self):
             async def _waiter():
                 await release.wait()

+            # Hold the loop so task remains active until explicitly stopped.
             class _Ctx(_TypingCtx):
                 async def __aenter__(self):
                     await super().__aenter__()
+                    await _waiter()

             return _Ctx()

     typing_channel = _NoTriggerChannel(channel_id=123)

@@ -745,3 +753,117 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() ->
     await asyncio.sleep(0)

     assert channel._typing_tasks == {}
+
+
+def test_config_accepts_proxy_fields() -> None:
+    config = DiscordConfig(
+        enabled=True,
+        token="token",
+        allow_from=["*"],
+        proxy="http://127.0.0.1:7890",
+        proxy_username="user",
+        proxy_password="pass",
+    )
+    assert config.proxy == "http://127.0.0.1:7890"
+    assert config.proxy_username == "user"
+    assert config.proxy_password == "pass"
+
+
+def test_config_proxy_defaults_to_none() -> None:
+    config = DiscordConfig(enabled=True, token="token", allow_from=["*"])
+    assert config.proxy is None
+    assert config.proxy_username is None
+    assert config.proxy_password is None
+
+
+@pytest.mark.asyncio
+async def test_start_passes_proxy_to_client(monkeypatch) -> None:
+    _FakeDiscordClient.instances.clear()
+    channel = DiscordChannel(
+        DiscordConfig(
+            enabled=True,
+            token="token",
+            allow_from=["*"],
+            proxy="http://127.0.0.1:7890",
+        ),
+        MessageBus(),
+    )
+    monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
+
+    await channel.start()
+
+    assert channel.is_running is False
+    assert len(_FakeDiscordClient.instances) == 1
+    assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890"
+    assert _FakeDiscordClient.instances[0].proxy_auth is None
+
+
+@pytest.mark.asyncio
+async def test_start_passes_proxy_auth_when_credentials_provided(monkeypatch) -> None:
+    aiohttp = pytest.importorskip("aiohttp")
+    _FakeDiscordClient.instances.clear()
+    channel = DiscordChannel(
+        DiscordConfig(
+            enabled=True,
+            token="token",
+            allow_from=["*"],
+            proxy="http://127.0.0.1:7890",
+            proxy_username="user",
+            proxy_password="pass",
+        ),
+        MessageBus(),
+    )
+    monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
+
+    await channel.start()
+
+    assert channel.is_running is False
+    assert len(_FakeDiscordClient.instances) == 1
+    assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890"
+    assert _FakeDiscordClient.instances[0].proxy_auth is not None
+    assert isinstance(_FakeDiscordClient.instances[0].proxy_auth, aiohttp.BasicAuth)
+    assert _FakeDiscordClient.instances[0].proxy_auth.login == "user"
+    assert _FakeDiscordClient.instances[0].proxy_auth.password == "pass"
+
+
+@pytest.mark.asyncio
+async def test_start_no_proxy_auth_when_only_username(monkeypatch) -> None:
+    _FakeDiscordClient.instances.clear()
+    channel = DiscordChannel(
+        DiscordConfig(
+            enabled=True,
+            token="token",
+            allow_from=["*"],
+            proxy="http://127.0.0.1:7890",
+            proxy_username="user",
+        ),
+        MessageBus(),
+    )
+    monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
+
+    await channel.start()
+
+    assert channel.is_running is False
+    assert _FakeDiscordClient.instances[0].proxy_auth is None
+
+
+@pytest.mark.asyncio
+async def test_start_no_proxy_auth_when_only_password(monkeypatch) -> None:
+    _FakeDiscordClient.instances.clear()
+    channel = DiscordChannel(
+        DiscordConfig(
+            enabled=True,
+            token="token",
+            allow_from=["*"],
+            proxy="http://127.0.0.1:7890",
+            proxy_password="pass",
+        ),
+        MessageBus(),
+    )
+    monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
+
+    await channel.start()
+
+    assert channel.is_running is False
+    assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890"
+    assert _FakeDiscordClient.instances[0].proxy_auth is None
tests/channels/test_feishu_domain.py (new file, 48 lines)
@@ -0,0 +1,48 @@
|
||||
"""Tests for Feishu/Lark domain configuration."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.feishu import FeishuChannel, FeishuConfig
|
||||
|
||||
|
||||
def _make_channel(domain: str = "feishu") -> FeishuChannel:
|
||||
config = FeishuConfig(
|
||||
enabled=True,
|
||||
app_id="cli_test",
|
||||
app_secret="secret",
|
||||
allow_from=["*"],
|
||||
domain=domain,
|
||||
)
|
||||
ch = FeishuChannel(config, MessageBus())
|
||||
ch._client = MagicMock()
|
||||
ch._loop = None
|
||||
return ch
|
||||
|
||||
|
||||
class TestFeishuConfigDomain:
|
||||
def test_domain_default_is_feishu(self):
|
||||
config = FeishuConfig()
|
||||
assert config.domain == "feishu"
|
||||
|
||||
def test_domain_accepts_lark(self):
|
||||
config = FeishuConfig(domain="lark")
|
||||
assert config.domain == "lark"
|
||||
|
||||
def test_domain_accepts_feishu(self):
|
||||
config = FeishuConfig(domain="feishu")
|
||||
assert config.domain == "feishu"
|
||||
|
||||
def test_default_config_includes_domain(self):
|
||||
default_cfg = FeishuChannel.default_config()
|
||||
assert "domain" in default_cfg
|
||||
assert default_cfg["domain"] == "feishu"
|
||||
|
||||
def test_channel_persists_domain_from_config(self):
|
||||
ch = _make_channel(domain="lark")
|
||||
assert ch.config.domain == "lark"
|
||||
|
||||
def test_channel_persists_feishu_domain_from_config(self):
|
||||
ch = _make_channel(domain="feishu")
|
||||
assert ch.config.domain == "feishu"
|
||||
@@ -5,6 +5,7 @@ from unittest.mock import MagicMock

 import pytest

+from nanobot.bus.events import OutboundMessage
 from nanobot.bus.queue import MessageBus
 from nanobot.channels.feishu import FeishuChannel, FeishuConfig, _FeishuStreamBuf


@@ -203,6 +204,55 @@ class TestSendDelta:
         ch._client.cardkit.v1.card_element.content.assert_not_called()
         ch._client.im.v1.message.create.assert_called_once()

+    @pytest.mark.asyncio
+    async def test_stream_end_resuming_keeps_buffer(self):
+        """_resuming=True flushes text to card but keeps the buffer for the next segment."""
+        ch = _make_channel()
+        ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
+            text="Partial answer", card_id="card_1", sequence=2, last_edit=0.0,
+        )
+        ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
+
+        await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True, "_resuming": True})
+
+        assert "oc_chat1" in ch._stream_bufs
+        buf = ch._stream_bufs["oc_chat1"]
+        assert buf.card_id == "card_1"
+        assert buf.sequence == 3
+        ch._client.cardkit.v1.card_element.content.assert_called_once()
+        ch._client.cardkit.v1.card.settings.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_stream_end_resuming_then_final_end(self):
+        """Full multi-segment flow: resuming mid-turn, then final end closes the card."""
+        ch = _make_channel()
+        ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
+            text="Seg1", card_id="card_1", sequence=1, last_edit=0.0,
+        )
+        ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
+        ch._client.cardkit.v1.card.settings.return_value = _mock_content_response()
+
+        await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True, "_resuming": True})
+        assert "oc_chat1" in ch._stream_bufs
+
+        ch._stream_bufs["oc_chat1"].text += " Seg2"
+        await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True})
+
+        assert "oc_chat1" not in ch._stream_bufs
+        ch._client.cardkit.v1.card.settings.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_stream_end_resuming_no_card_is_noop(self):
+        """_resuming with no card_id (card creation failed) is a safe no-op."""
+        ch = _make_channel()
+        ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
+            text="text", card_id=None, sequence=0, last_edit=0.0,
+        )
+        await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True, "_resuming": True})
+
+        assert "oc_chat1" in ch._stream_bufs
+        ch._client.cardkit.v1.card_element.content.assert_not_called()
+
     @pytest.mark.asyncio
     async def test_stream_end_without_buf_is_noop(self):
         ch = _make_channel()

@@ -239,6 +289,146 @@ class TestSendDelta:
         assert buf.sequence == 7


+class TestToolHintInlineStreaming:
+    """Tool hint messages should be inlined into active streaming cards."""
+
+    @pytest.mark.asyncio
+    async def test_tool_hint_inlined_when_stream_active(self):
+        """With an active streaming buffer, tool hint appends to the card."""
+        ch = _make_channel()
+        ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
+            text="Partial answer", card_id="card_1", sequence=2, last_edit=0.0,
+        )
+        ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
+
+        msg = OutboundMessage(
+            channel="feishu", chat_id="oc_chat1",
+            content='web_fetch("https://example.com")',
+            metadata={"_tool_hint": True},
+        )
+        await ch.send(msg)
+
+        buf = ch._stream_bufs["oc_chat1"]
+        assert '🔧 web_fetch("https://example.com")' in buf.text
+        assert buf.sequence == 3
+        ch._client.cardkit.v1.card_element.content.assert_called_once()
+        ch._client.im.v1.message.create.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_tool_hint_preserved_on_next_delta(self):
+        """When new delta arrives, the tool hint is kept as permanent content and delta appends after it."""
+        ch = _make_channel()
+        ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
+            text="Partial answer\n\n🔧 web_fetch(\"url\")\n\n",
+            card_id="card_1", sequence=3, last_edit=0.0,
+        )
+        ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
+
+        await ch.send_delta("oc_chat1", " continued")
+
+        buf = ch._stream_bufs["oc_chat1"]
+        assert "Partial answer" in buf.text
+        assert "🔧 web_fetch" in buf.text
+        assert buf.text.endswith(" continued")
+
+    @pytest.mark.asyncio
+    async def test_tool_hint_fallback_when_no_stream(self):
+        """Without an active buffer, tool hint falls back to a standalone card."""
+        ch = _make_channel()
+        ch._client.im.v1.message.create.return_value = _mock_send_response("om_hint")
+
+        msg = OutboundMessage(
+            channel="feishu", chat_id="oc_chat1",
+            content='read_file("path")',
+            metadata={"_tool_hint": True},
+        )
+        await ch.send(msg)
+
+        assert "oc_chat1" not in ch._stream_bufs
+        ch._client.im.v1.message.create.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_consecutive_tool_hints_append(self):
+        """When multiple tool hints arrive consecutively, each appends to the card."""
+        ch = _make_channel()
+        ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
+            text="Partial answer", card_id="card_1", sequence=2, last_edit=0.0,
+        )
+        ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
+
+        msg1 = OutboundMessage(
+            channel="feishu", chat_id="oc_chat1",
+            content='$ cd /project', metadata={"_tool_hint": True},
+        )
+        await ch.send(msg1)
+
+        msg2 = OutboundMessage(
+            channel="feishu", chat_id="oc_chat1",
+            content='$ git status', metadata={"_tool_hint": True},
+        )
+        await ch.send(msg2)
+
+        buf = ch._stream_bufs["oc_chat1"]
+        assert "$ cd /project" in buf.text
+        assert "$ git status" in buf.text
+        assert buf.text.startswith("Partial answer")
+        assert "🔧 $ cd /project" in buf.text
+        assert "🔧 $ git status" in buf.text
+
+    @pytest.mark.asyncio
+    async def test_tool_hint_preserved_on_resuming_flush(self):
+        """When _resuming flushes the buffer, tool hint is kept as permanent content."""
+        ch = _make_channel()
+        ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
+            text="Partial answer\n\n🔧 $ cd /project\n\n",
+            card_id="card_1", sequence=2, last_edit=0.0,
+        )
+        ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
+
+        await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True, "_resuming": True})
+
+        buf = ch._stream_bufs["oc_chat1"]
+        assert "Partial answer" in buf.text
+        assert "🔧 $ cd /project" in buf.text
+
+    @pytest.mark.asyncio
+    async def test_tool_hint_preserved_on_final_stream_end(self):
+        """When final _stream_end closes the card, tool hint is kept in the final text."""
+        ch = _make_channel()
+        ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
+            text="Final content\n\n🔧 web_fetch(\"url\")\n\n",
+            card_id="card_1", sequence=3, last_edit=0.0,
+        )
+        ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
+        ch._client.cardkit.v1.card.settings.return_value = _mock_content_response()
+
+        await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True})
+
+        assert "oc_chat1" not in ch._stream_bufs
+        update_call = ch._client.cardkit.v1.card_element.content.call_args[0][0]
+        assert "🔧" in update_call.body.content
+
+    @pytest.mark.asyncio
+    async def test_empty_tool_hint_is_noop(self):
+        """Empty or whitespace-only tool hint content is silently ignored."""
+        ch = _make_channel()
+        ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
+            text="Partial answer", card_id="card_1", sequence=2, last_edit=0.0,
+        )
+
+        for content in ("", " ", "\t\n"):
+            msg = OutboundMessage(
+                channel="feishu", chat_id="oc_chat1",
+                content=content, metadata={"_tool_hint": True},
+            )
+            await ch.send(msg)
+
+        buf = ch._stream_bufs["oc_chat1"]
+        assert buf.text == "Partial answer"
+        assert buf.sequence == 2
+        ch._client.cardkit.v1.card_element.content.assert_not_called()
+
+
 class TestSendMessageReturnsId:
     def test_returns_message_id_on_success(self):
         ch = _make_channel()
tests/channels/test_qq_media.py (new file, 304 lines)
@@ -0,0 +1,304 @@
|
||||
"""Tests for QQ channel media support: helpers, send, inbound, and upload."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
try:
|
||||
from nanobot.channels import qq
|
||||
|
||||
QQ_AVAILABLE = getattr(qq, "QQ_AVAILABLE", False)
|
||||
except ImportError:
|
||||
QQ_AVAILABLE = False
|
||||
|
||||
if not QQ_AVAILABLE:
|
||||
pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True)
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.qq import (
|
||||
QQ_FILE_TYPE_FILE,
|
||||
QQ_FILE_TYPE_IMAGE,
|
||||
QQChannel,
|
||||
QQConfig,
|
||||
_guess_send_file_type,
|
||||
_is_image_name,
|
||||
_sanitize_filename,
|
||||
)
|
||||
|
||||
|
||||
class _FakeApi:
|
||||
def __init__(self) -> None:
|
||||
self.c2c_calls: list[dict] = []
|
||||
self.group_calls: list[dict] = []
|
||||
|
||||
async def post_c2c_message(self, **kwargs) -> None:
|
||||
self.c2c_calls.append(kwargs)
|
||||
|
||||
async def post_group_message(self, **kwargs) -> None:
|
||||
self.group_calls.append(kwargs)
|
||||
|
||||
|
||||
class _FakeHttp:
|
||||
"""Fake _http for _post_base64file tests."""
|
||||
|
||||
def __init__(self, return_value: dict | None = None) -> None:
|
||||
self.return_value = return_value or {}
|
||||
self.calls: list[tuple] = []
|
||||
|
||||
async def request(self, route, **kwargs):
|
||||
self.calls.append((route, kwargs))
|
||||
return self.return_value
|
||||
|
||||
|
||||
class _FakeClient:
|
||||
def __init__(self, http_return: dict | None = None) -> None:
|
||||
self.api = _FakeApi()
|
||||
self.api._http = _FakeHttp(http_return)
|
||||
|
||||
|
||||
# ── Helper function tests (pure, no async) ──────────────────────────


def test_sanitize_filename_strips_path_traversal() -> None:
    assert _sanitize_filename("../../etc/passwd") == "passwd"


def test_sanitize_filename_keeps_chinese_chars() -> None:
    assert _sanitize_filename("文件(1).jpg") == "文件(1).jpg"


def test_sanitize_filename_strips_unsafe_chars() -> None:
    result = _sanitize_filename('file<>:"|?*.txt')
    # All unsafe chars are replaced with "_", including *
    assert result.startswith("file")
    assert result.endswith(".txt")
    assert "<" not in result
    assert ">" not in result
    assert '"' not in result
    assert "|" not in result
    assert "?" not in result


def test_sanitize_filename_empty_input() -> None:
    assert _sanitize_filename("") == ""


def test_is_image_name_with_known_extensions() -> None:
    for ext in (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".tif", ".tiff", ".ico", ".svg"):
        assert _is_image_name(f"photo{ext}") is True


def test_is_image_name_with_unknown_extension() -> None:
    for ext in (".pdf", ".txt", ".mp3", ".mp4"):
        assert _is_image_name(f"doc{ext}") is False


def test_guess_send_file_type_image() -> None:
    assert _guess_send_file_type("photo.png") == QQ_FILE_TYPE_IMAGE
    assert _guess_send_file_type("pic.jpg") == QQ_FILE_TYPE_IMAGE


def test_guess_send_file_type_file() -> None:
    assert _guess_send_file_type("doc.pdf") == QQ_FILE_TYPE_FILE


def test_guess_send_file_type_by_mime() -> None:
    # A filename with an unknown extension whose MIME type cannot be guessed
    # falls back to the generic file type.
    assert _guess_send_file_type("photo.xyz_image_test") == QQ_FILE_TYPE_FILE

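The extension and MIME checks exercised above can be sketched as a standalone function. This is a hypothetical reconstruction of the behavior the tests pin down, not nanobot's actual implementation; the constant values 1 and 4 come from the `file_type=1 (image)` / `file_type=4 (file)` docstrings later in this file.

```python
# Hypothetical sketch of the classification behavior the tests above pin down;
# the real nanobot helper may differ in detail.
import mimetypes
import os

IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".tif", ".tiff", ".ico", ".svg"}
FILE_TYPE_IMAGE, FILE_TYPE_FILE = 1, 4  # values implied by the docstrings below


def guess_send_file_type_sketch(name: str) -> int:
    # Known image extension → image; otherwise try the MIME type; fall back to file.
    ext = os.path.splitext(name)[1].lower()
    if ext in IMAGE_EXTS:
        return FILE_TYPE_IMAGE
    mime, _ = mimetypes.guess_type(name)
    if mime and mime.startswith("image/"):
        return FILE_TYPE_IMAGE
    return FILE_TYPE_FILE


assert guess_send_file_type_sketch("photo.png") == FILE_TYPE_IMAGE
assert guess_send_file_type_sketch("doc.pdf") == FILE_TYPE_FILE
assert guess_send_file_type_sketch("photo.xyz_image_test") == FILE_TYPE_FILE
```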
# ── send() exception handling ───────────────────────────────────────


@pytest.mark.asyncio
async def test_send_exception_caught_not_raised() -> None:
    """Exceptions inside send() must not propagate."""
    channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
    channel._client = _FakeClient()

    with patch.object(channel, "_send_text_only", new_callable=AsyncMock, side_effect=RuntimeError("boom")):
        await channel.send(
            OutboundMessage(channel="qq", chat_id="user1", content="hello")
        )
    # No exception raised — test passes if we get here.


@pytest.mark.asyncio
async def test_send_media_then_text() -> None:
    """Media is sent before text when both are present."""
    import tempfile

    channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
    channel._client = _FakeClient()

    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
        f.write(b"\x89PNG\r\n")
        tmp = f.name

    try:
        with patch.object(channel, "_post_base64file", new_callable=AsyncMock, return_value={"file_info": "1"}) as mock_upload:
            await channel.send(
                OutboundMessage(
                    channel="qq",
                    chat_id="user1",
                    content="text after image",
                    media=[tmp],
                    metadata={"message_id": "m1"},
                )
            )
        assert mock_upload.called

        # Text should have been sent via c2c (default chat type)
        text_calls = [c for c in channel._client.api.c2c_calls if c.get("msg_type") == 0]
        assert len(text_calls) >= 1
        assert text_calls[-1]["content"] == "text after image"
    finally:
        import os
        os.unlink(tmp)


@pytest.mark.asyncio
async def test_send_media_failure_falls_back_to_text() -> None:
    """When _send_media returns False, a failure notice is appended."""
    channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
    channel._client = _FakeClient()

    with patch.object(channel, "_send_media", new_callable=AsyncMock, return_value=False):
        await channel.send(
            OutboundMessage(
                channel="qq",
                chat_id="user1",
                content="hello",
                media=["https://example.com/bad.png"],
                metadata={"message_id": "m1"},
            )
        )

    # Should have the failure text among the c2c calls
    failure_calls = [c for c in channel._client.api.c2c_calls if "Attachment send failed" in c.get("content", "")]
    assert len(failure_calls) == 1
    assert "bad.png" in failure_calls[0]["content"]

# ── _on_message() exception handling ────────────────────────────────


@pytest.mark.asyncio
async def test_on_message_exception_caught_not_raised() -> None:
    """Missing required attributes should not crash _on_message."""
    channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
    channel._client = _FakeClient()

    # Construct a message-like object that lacks 'author' — triggers AttributeError
    bad_data = SimpleNamespace(id="x1", content="hi")
    # Should not raise
    await channel._on_message(bad_data, is_group=False)


@pytest.mark.asyncio
async def test_on_message_with_attachments() -> None:
    """Messages with attachments produce media_paths and formatted content."""
    import tempfile

    channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
    channel._client = _FakeClient()

    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
        f.write(b"\x89PNG\r\n")
        saved_path = f.name

    att = SimpleNamespace(url="", filename="screenshot.png", content_type="image/png")

    # Patch _download_to_media_dir_chunked to return the temp file path
    async def fake_download(url, filename_hint=""):
        return saved_path

    try:
        with patch.object(channel, "_download_to_media_dir_chunked", side_effect=fake_download):
            data = SimpleNamespace(
                id="att1",
                content="look at this",
                author=SimpleNamespace(user_openid="u1"),
                attachments=[att],
            )
            await channel._on_message(data, is_group=False)

        msg = await channel.bus.consume_inbound()
        assert "look at this" in msg.content
        assert "screenshot.png" in msg.content
        assert "Received files:" in msg.content
        assert len(msg.media) == 1
        assert msg.media[0] == saved_path
    finally:
        import os
        os.unlink(saved_path)

# ── _post_base64file() ─────────────────────────────────────────────


@pytest.mark.asyncio
async def test_post_base64file_omits_file_name_for_images() -> None:
    """file_type=1 (image) → payload must not contain file_name."""
    channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
    channel._client = _FakeClient(http_return={"file_info": "img_abc"})

    await channel._post_base64file(
        chat_id="user1",
        is_group=False,
        file_type=QQ_FILE_TYPE_IMAGE,
        file_data="ZmFrZQ==",
        file_name="photo.png",
    )

    http = channel._client.api._http
    assert len(http.calls) == 1
    payload = http.calls[0][1]["json"]
    assert "file_name" not in payload
    assert payload["file_type"] == QQ_FILE_TYPE_IMAGE


@pytest.mark.asyncio
async def test_post_base64file_includes_file_name_for_files() -> None:
    """file_type=4 (file) → payload must contain file_name."""
    channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
    channel._client = _FakeClient(http_return={"file_info": "file_abc"})

    await channel._post_base64file(
        chat_id="user1",
        is_group=False,
        file_type=QQ_FILE_TYPE_FILE,
        file_data="ZmFrZQ==",
        file_name="report.pdf",
    )

    http = channel._client.api._http
    assert len(http.calls) == 1
    payload = http.calls[0][1]["json"]
    assert payload["file_name"] == "report.pdf"
    assert payload["file_type"] == QQ_FILE_TYPE_FILE


@pytest.mark.asyncio
async def test_post_base64file_filters_response_to_file_info() -> None:
    """Response with file_info + extra fields must be filtered to only file_info."""
    channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
    channel._client = _FakeClient(http_return={
        "file_info": "fi_123",
        "file_uuid": "uuid_xxx",
        "ttl": 3600,
    })

    result = await channel._post_base64file(
        chat_id="user1",
        is_group=False,
        file_type=QQ_FILE_TYPE_FILE,
        file_data="ZmFrZQ==",
        file_name="doc.pdf",
    )

    assert result == {"file_info": "fi_123"}
    assert "file_uuid" not in result
    assert "ttl" not in result
598
tests/channels/test_websocket_channel.py
Normal file
@ -0,0 +1,598 @@
|
||||
"""Unit and lightweight integration tests for the WebSocket channel."""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import websockets
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
from websockets.frames import Close
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.channels.websocket import (
|
||||
WebSocketChannel,
|
||||
WebSocketConfig,
|
||||
_issue_route_secret_matches,
|
||||
_normalize_config_path,
|
||||
_normalize_http_path,
|
||||
_parse_inbound_payload,
|
||||
_parse_query,
|
||||
_parse_request_path,
|
||||
)
|
||||
|
||||
# -- Shared helpers (aligned with test_websocket_integration.py) ---------------
|
||||
|
||||
_PORT = 29876
|
||||
|
||||
|
||||
def _ch(bus: Any, **kw: Any) -> WebSocketChannel:
|
||||
cfg: dict[str, Any] = {
|
||||
"enabled": True,
|
||||
"allowFrom": ["*"],
|
||||
"host": "127.0.0.1",
|
||||
"port": _PORT,
|
||||
"path": "/ws",
|
||||
"websocketRequiresToken": False,
|
||||
}
|
||||
cfg.update(kw)
|
||||
return WebSocketChannel(cfg, bus)
|
||||
|
||||
|
||||
@pytest.fixture()
def bus() -> MagicMock:
    b = MagicMock()
    b.publish_inbound = AsyncMock()
    return b


async def _http_get(url: str, headers: dict[str, str] | None = None) -> httpx.Response:
    """Run GET in a thread to avoid blocking the asyncio loop shared with websockets."""
    return await asyncio.to_thread(
        functools.partial(httpx.get, url, headers=headers or {}, timeout=5.0)
    )


def test_normalize_http_path_strips_trailing_slash_except_root() -> None:
    assert _normalize_http_path("/chat/") == "/chat"
    assert _normalize_http_path("/chat?x=1") == "/chat"
    assert _normalize_http_path("/") == "/"


def test_parse_request_path_matches_normalize_and_query() -> None:
    path, query = _parse_request_path("/ws/?token=secret&client_id=u1")
    assert path == _normalize_http_path("/ws/?token=secret&client_id=u1")
    assert query == _parse_query("/ws/?token=secret&client_id=u1")


def test_normalize_config_path_matches_request() -> None:
    assert _normalize_config_path("/ws/") == "/ws"
    assert _normalize_config_path("/") == "/"


def test_parse_query_extracts_token_and_client_id() -> None:
    query = _parse_query("/?token=secret&client_id=u1")
    assert query.get("token") == ["secret"]
    assert query.get("client_id") == ["u1"]


@pytest.mark.parametrize(
    ("raw", "expected"),
    [
        ("plain", "plain"),
        ('{"content": "hi"}', "hi"),
        ('{"text": "there"}', "there"),
        ('{"message": "x"}', "x"),
        (" ", None),
        ("{}", None),
    ],
)
def test_parse_inbound_payload(raw: str, expected: str | None) -> None:
    assert _parse_inbound_payload(raw) == expected


def test_parse_inbound_invalid_json_falls_back_to_raw_string() -> None:
    assert _parse_inbound_payload("{not json") == "{not json"


@pytest.mark.parametrize(
    ("raw", "expected"),
    [
        ('{"content": ""}', None),  # empty string content
        ('{"content": 123}', None),  # non-string content
        ('{"content": " "}', None),  # whitespace-only content
        ('["hello"]', '["hello"]'),  # JSON array: not a dict, treated as plain text
        ('{"unknown_key": "val"}', None),  # unrecognized key
        ('{"content": null}', None),  # null content
    ],
)
def test_parse_inbound_payload_edge_cases(raw: str, expected: str | None) -> None:
    assert _parse_inbound_payload(raw) == expected

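The parametrized cases above fully pin down the payload fallback order: plain text passes through, JSON objects yield their first non-empty string field among `content`/`text`/`message`, and everything else collapses to `None`. A minimal sketch of that logic, assuming this is the whole contract (the real `_parse_inbound_payload` in nanobot may handle more):

```python
import json


def parse_inbound_payload_sketch(raw: str):
    # Whitespace-only frames carry no message.
    if not raw.strip():
        return None
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return raw  # invalid JSON falls back to the raw string
    if not isinstance(data, dict):
        return raw  # arrays and scalars are treated as plain text
    # First non-empty string field wins, in priority order.
    for key in ("content", "text", "message"):
        value = data.get(key)
        if isinstance(value, str) and value.strip():
            return value
    return None


assert parse_inbound_payload_sketch("plain") == "plain"
assert parse_inbound_payload_sketch('{"content": "hi"}') == "hi"
assert parse_inbound_payload_sketch("{not json") == "{not json"
assert parse_inbound_payload_sketch('{"content": 123}') is None
assert parse_inbound_payload_sketch('["hello"]') == '["hello"]'
```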
def test_web_socket_config_path_must_start_with_slash() -> None:
    with pytest.raises(ValueError, match='path must start with "/"'):
        WebSocketConfig(path="bad")


def test_ssl_context_requires_both_cert_and_key_files() -> None:
    bus = MagicMock()
    channel = WebSocketChannel(
        {"enabled": True, "allowFrom": ["*"], "sslCertfile": "/tmp/c.pem", "sslKeyfile": ""},
        bus,
    )
    with pytest.raises(ValueError, match="ssl_certfile and ssl_keyfile"):
        channel._build_ssl_context()


def test_default_config_includes_safe_bind_and_streaming() -> None:
    defaults = WebSocketChannel.default_config()
    assert defaults["enabled"] is False
    assert defaults["host"] == "127.0.0.1"
    assert defaults["streaming"] is True
    assert defaults["allowFrom"] == ["*"]
    assert defaults.get("tokenIssuePath", "") == ""


def test_token_issue_path_must_differ_from_websocket_path() -> None:
    with pytest.raises(ValueError, match="token_issue_path must differ"):
        WebSocketConfig(path="/ws", token_issue_path="/ws")


def test_issue_route_secret_matches_bearer_and_header() -> None:
    from websockets.datastructures import Headers

    secret = "my-secret"
    bearer_headers = Headers([("Authorization", "Bearer my-secret")])
    assert _issue_route_secret_matches(bearer_headers, secret) is True
    x_headers = Headers([("X-Nanobot-Auth", "my-secret")])
    assert _issue_route_secret_matches(x_headers, secret) is True
    wrong = Headers([("Authorization", "Bearer other")])
    assert _issue_route_secret_matches(wrong, secret) is False


def test_issue_route_secret_matches_empty_secret() -> None:
    from websockets.datastructures import Headers

    # Empty secret always returns True regardless of headers
    assert _issue_route_secret_matches(Headers([]), "") is True
    assert _issue_route_secret_matches(Headers([("Authorization", "Bearer anything")]), "") is True

@pytest.mark.asyncio
async def test_send_delivers_json_message_with_media_and_reply() -> None:
    bus = MagicMock()
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
    mock_ws = AsyncMock()
    channel._connections["chat-1"] = mock_ws

    msg = OutboundMessage(
        channel="websocket",
        chat_id="chat-1",
        content="hello",
        reply_to="m1",
        media=["/tmp/a.png"],
    )
    await channel.send(msg)

    mock_ws.send.assert_awaited_once()
    payload = json.loads(mock_ws.send.call_args[0][0])
    assert payload["event"] == "message"
    assert payload["text"] == "hello"
    assert payload["reply_to"] == "m1"
    assert payload["media"] == ["/tmp/a.png"]


@pytest.mark.asyncio
async def test_send_missing_connection_is_noop_without_error() -> None:
    bus = MagicMock()
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
    msg = OutboundMessage(channel="websocket", chat_id="missing", content="x")
    await channel.send(msg)


@pytest.mark.asyncio
async def test_send_removes_connection_on_connection_closed() -> None:
    bus = MagicMock()
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
    mock_ws = AsyncMock()
    mock_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True)
    channel._connections["chat-1"] = mock_ws

    msg = OutboundMessage(channel="websocket", chat_id="chat-1", content="hello")
    await channel.send(msg)

    assert "chat-1" not in channel._connections


@pytest.mark.asyncio
async def test_send_delta_removes_connection_on_connection_closed() -> None:
    bus = MagicMock()
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus)
    mock_ws = AsyncMock()
    mock_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True)
    channel._connections["chat-1"] = mock_ws

    await channel.send_delta("chat-1", "chunk", {"_stream_delta": True, "_stream_id": "s1"})

    assert "chat-1" not in channel._connections


@pytest.mark.asyncio
async def test_send_delta_emits_delta_and_stream_end() -> None:
    bus = MagicMock()
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus)
    mock_ws = AsyncMock()
    channel._connections["chat-1"] = mock_ws

    await channel.send_delta("chat-1", "part", {"_stream_delta": True, "_stream_id": "sid"})
    await channel.send_delta("chat-1", "", {"_stream_end": True, "_stream_id": "sid"})

    assert mock_ws.send.await_count == 2
    first = json.loads(mock_ws.send.call_args_list[0][0][0])
    second = json.loads(mock_ws.send.call_args_list[1][0][0])
    assert first["event"] == "delta"
    assert first["text"] == "part"
    assert first["stream_id"] == "sid"
    assert second["event"] == "stream_end"
    assert second["stream_id"] == "sid"


@pytest.mark.asyncio
async def test_send_non_connection_closed_exception_is_raised() -> None:
    bus = MagicMock()
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
    mock_ws = AsyncMock()
    mock_ws.send.side_effect = RuntimeError("unexpected")
    channel._connections["chat-1"] = mock_ws

    msg = OutboundMessage(channel="websocket", chat_id="chat-1", content="hello")
    with pytest.raises(RuntimeError, match="unexpected"):
        await channel.send(msg)


@pytest.mark.asyncio
async def test_send_delta_missing_connection_is_noop() -> None:
    bus = MagicMock()
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus)
    # No exception, no error — just a no-op
    await channel.send_delta("nonexistent", "chunk", {"_stream_delta": True, "_stream_id": "s1"})


@pytest.mark.asyncio
async def test_stop_is_idempotent() -> None:
    bus = MagicMock()
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
    # stop() before start() should not raise
    await channel.stop()
    await channel.stop()

@pytest.mark.asyncio
async def test_end_to_end_client_receives_ready_and_agent_sees_inbound(bus: MagicMock) -> None:
    port = 29876
    channel = _ch(bus, port=port)

    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=tester") as client:
            ready_raw = await client.recv()
            ready = json.loads(ready_raw)
            assert ready["event"] == "ready"
            assert ready["client_id"] == "tester"
            chat_id = ready["chat_id"]

            await client.send(json.dumps({"content": "ping from client"}))
            await asyncio.sleep(0.08)

            bus.publish_inbound.assert_awaited()
            inbound = bus.publish_inbound.call_args[0][0]
            assert inbound.channel == "websocket"
            assert inbound.sender_id == "tester"
            assert inbound.chat_id == chat_id
            assert inbound.content == "ping from client"

            await client.send("plain text frame")
            await asyncio.sleep(0.08)
            assert bus.publish_inbound.await_count >= 2
            second = [c[0][0] for c in bus.publish_inbound.call_args_list][-1]
            assert second.content == "plain text frame"
    finally:
        await channel.stop()
        await server_task


@pytest.mark.asyncio
async def test_token_rejects_handshake_when_mismatch(bus: MagicMock) -> None:
    port = 29877
    channel = _ch(bus, port=port, path="/", token="secret")

    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        with pytest.raises(websockets.exceptions.InvalidStatus) as excinfo:
            async with websockets.connect(f"ws://127.0.0.1:{port}/?token=wrong"):
                pass
        assert excinfo.value.response.status_code == 401
    finally:
        await channel.stop()
        await server_task


@pytest.mark.asyncio
async def test_wrong_path_returns_404(bus: MagicMock) -> None:
    port = 29878
    channel = _ch(bus, port=port)

    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        with pytest.raises(websockets.exceptions.InvalidStatus) as excinfo:
            async with websockets.connect(f"ws://127.0.0.1:{port}/other"):
                pass
        assert excinfo.value.response.status_code == 404
    finally:
        await channel.stop()
        await server_task


def test_registry_discovers_websocket_channel() -> None:
    from nanobot.channels.registry import load_channel_class

    cls = load_channel_class("websocket")
    assert cls.name == "websocket"


@pytest.mark.asyncio
async def test_http_route_issues_token_then_websocket_requires_it(bus: MagicMock) -> None:
    port = 29879
    channel = _ch(
        bus, port=port,
        tokenIssuePath="/auth/token",
        tokenIssueSecret="route-secret",
        websocketRequiresToken=True,
    )

    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        deny = await _http_get(f"http://127.0.0.1:{port}/auth/token")
        assert deny.status_code == 401

        issue = await _http_get(
            f"http://127.0.0.1:{port}/auth/token",
            headers={"Authorization": "Bearer route-secret"},
        )
        assert issue.status_code == 200
        token = issue.json()["token"]
        assert token.startswith("nbwt_")

        with pytest.raises(websockets.exceptions.InvalidStatus) as missing_token:
            async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=x"):
                pass
        assert missing_token.value.response.status_code == 401

        uri = f"ws://127.0.0.1:{port}/ws?token={token}&client_id=caller"
        async with websockets.connect(uri) as client:
            ready = json.loads(await client.recv())
            assert ready["event"] == "ready"
            assert ready["client_id"] == "caller"

        with pytest.raises(websockets.exceptions.InvalidStatus) as reuse:
            async with websockets.connect(uri):
                pass
        assert reuse.value.response.status_code == 401
    finally:
        await channel.stop()
        await server_task

@pytest.mark.asyncio
async def test_end_to_end_server_pushes_streaming_deltas_to_client(bus: MagicMock) -> None:
    port = 29880
    channel = _ch(bus, port=port, streaming=True)

    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=stream-tester") as client:
            ready_raw = await client.recv()
            ready = json.loads(ready_raw)
            chat_id = ready["chat_id"]

            # Server pushes deltas directly
            await channel.send_delta(
                chat_id, "Hello ", {"_stream_delta": True, "_stream_id": "s1"}
            )
            await channel.send_delta(
                chat_id, "world", {"_stream_delta": True, "_stream_id": "s1"}
            )
            await channel.send_delta(
                chat_id, "", {"_stream_end": True, "_stream_id": "s1"}
            )

            delta1 = json.loads(await client.recv())
            assert delta1["event"] == "delta"
            assert delta1["text"] == "Hello "
            assert delta1["stream_id"] == "s1"

            delta2 = json.loads(await client.recv())
            assert delta2["event"] == "delta"
            assert delta2["text"] == "world"
            assert delta2["stream_id"] == "s1"

            end = json.loads(await client.recv())
            assert end["event"] == "stream_end"
            assert end["stream_id"] == "s1"
    finally:
        await channel.stop()
        await server_task


@pytest.mark.asyncio
async def test_token_issue_rejects_when_at_capacity(bus: MagicMock) -> None:
    port = 29881
    channel = _ch(bus, port=port, tokenIssuePath="/auth/token", tokenIssueSecret="s")

    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        # Fill issued tokens to capacity
        channel._issued_tokens = {
            f"nbwt_fill_{i}": time.monotonic() + 300 for i in range(channel._MAX_ISSUED_TOKENS)
        }

        resp = await _http_get(
            f"http://127.0.0.1:{port}/auth/token",
            headers={"Authorization": "Bearer s"},
        )
        assert resp.status_code == 429
        data = resp.json()
        assert "error" in data
    finally:
        await channel.stop()
        await server_task


@pytest.mark.asyncio
async def test_allow_from_rejects_unauthorized_client_id(bus: MagicMock) -> None:
    port = 29882
    channel = _ch(bus, port=port, allowFrom=["alice", "bob"])

    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
            async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=eve"):
                pass
        assert exc_info.value.response.status_code == 403
    finally:
        await channel.stop()
        await server_task


@pytest.mark.asyncio
async def test_client_id_truncation(bus: MagicMock) -> None:
    port = 29883
    channel = _ch(bus, port=port)

    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        long_id = "x" * 200
        async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id={long_id}") as client:
            ready = json.loads(await client.recv())
            assert ready["client_id"] == "x" * 128
            assert len(ready["client_id"]) == 128
    finally:
        await channel.stop()
        await server_task


@pytest.mark.asyncio
async def test_non_utf8_binary_frame_ignored(bus: MagicMock) -> None:
    port = 29884
    channel = _ch(bus, port=port)

    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=bin-test") as client:
            await client.recv()  # consume ready
            # Send non-UTF-8 bytes
            await client.send(b"\xff\xfe\xfd")
            await asyncio.sleep(0.05)
            # publish_inbound should NOT have been called
            bus.publish_inbound.assert_not_awaited()
    finally:
        await channel.stop()
        await server_task


@pytest.mark.asyncio
async def test_static_token_accepts_issued_token_as_fallback(bus: MagicMock) -> None:
    port = 29885
    channel = _ch(
        bus, port=port,
        token="static-secret",
        tokenIssuePath="/auth/token",
        tokenIssueSecret="route-secret",
    )

    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        # Get an issued token
        resp = await _http_get(
            f"http://127.0.0.1:{port}/auth/token",
            headers={"Authorization": "Bearer route-secret"},
        )
        assert resp.status_code == 200
        issued_token = resp.json()["token"]

        # Connect using issued token (not the static one)
        async with websockets.connect(f"ws://127.0.0.1:{port}/ws?token={issued_token}&client_id=caller") as client:
            ready = json.loads(await client.recv())
            assert ready["event"] == "ready"
    finally:
        await channel.stop()
        await server_task


@pytest.mark.asyncio
async def test_allow_from_empty_list_denies_all(bus: MagicMock) -> None:
    port = 29886
    channel = _ch(bus, port=port, allowFrom=[])

    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
            async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=anyone"):
                pass
        assert exc_info.value.response.status_code == 403
    finally:
        await channel.stop()
        await server_task


@pytest.mark.asyncio
async def test_websocket_requires_token_without_issue_path(bus: MagicMock) -> None:
    """When websocket_requires_token is True but no token or issue path is configured, all connections are rejected."""
    port = 29887
    channel = _ch(bus, port=port, websocketRequiresToken=True)

    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        # No token at all → 401
        with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
            async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=u"):
                pass
        assert exc_info.value.response.status_code == 401

        # Wrong token → 401
        with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
            async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=u&token=wrong"):
                pass
        assert exc_info.value.response.status_code == 401
    finally:
        await channel.stop()
        await server_task
478
tests/channels/test_websocket_integration.py
Normal file
@ -0,0 +1,478 @@
"""Integration tests for the WebSocket channel using WsTestClient.

Complements the unit/lightweight tests in test_websocket_channel.py by covering
multi-client scenarios, edge cases, and realistic usage patterns.
"""

from __future__ import annotations

import asyncio
import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock

import pytest
import websockets

from nanobot.channels.websocket import WebSocketChannel
from nanobot.bus.events import OutboundMessage
from ws_test_client import WsTestClient, issue_token, issue_token_ok


def _ch(bus: Any, port: int, **kw: Any) -> WebSocketChannel:
    cfg: dict[str, Any] = {
        "enabled": True,
        "allowFrom": ["*"],
        "host": "127.0.0.1",
        "port": port,
        "path": "/",
        "websocketRequiresToken": False,
    }
    cfg.update(kw)
    return WebSocketChannel(cfg, bus)


@pytest.fixture()
def bus() -> MagicMock:
    b = MagicMock()
    b.publish_inbound = AsyncMock()
    return b


# -- Connection basics ----------------------------------------------------


@pytest.mark.asyncio
async def test_ready_event_fields(bus: MagicMock) -> None:
    ch = _ch(bus, 29901)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29901/", client_id="c1") as c:
            r = await c.recv_ready()
            assert r.event == "ready"
            assert len(r.chat_id) == 36
            assert r.client_id == "c1"
    finally:
        await ch.stop(); await t


@pytest.mark.asyncio
async def test_anonymous_client_gets_generated_id(bus: MagicMock) -> None:
    ch = _ch(bus, 29902)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29902/", client_id="") as c:
            r = await c.recv_ready()
            assert r.client_id.startswith("anon-")
    finally:
        await ch.stop(); await t


@pytest.mark.asyncio
async def test_each_connection_unique_chat_id(bus: MagicMock) -> None:
    ch = _ch(bus, 29903)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29903/", client_id="a") as c1:
            async with WsTestClient("ws://127.0.0.1:29903/", client_id="b") as c2:
                assert (await c1.recv_ready()).chat_id != (await c2.recv_ready()).chat_id
    finally:
        await ch.stop(); await t


# -- Inbound messages (client -> server) ----------------------------------


@pytest.mark.asyncio
async def test_plain_text(bus: MagicMock) -> None:
    ch = _ch(bus, 29904)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29904/", client_id="p") as c:
            await c.recv_ready()
            await c.send_text("hello world")
            await asyncio.sleep(0.1)
            inbound = bus.publish_inbound.call_args[0][0]
            assert inbound.content == "hello world"
            assert inbound.sender_id == "p"
    finally:
        await ch.stop(); await t


@pytest.mark.asyncio
async def test_json_content_field(bus: MagicMock) -> None:
    ch = _ch(bus, 29905)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29905/", client_id="j") as c:
            await c.recv_ready()
            await c.send_json({"content": "structured"})
            await asyncio.sleep(0.1)
            assert bus.publish_inbound.call_args[0][0].content == "structured"
    finally:
        await ch.stop(); await t


@pytest.mark.asyncio
async def test_json_text_and_message_fields(bus: MagicMock) -> None:
    ch = _ch(bus, 29906)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29906/", client_id="x") as c:
            await c.recv_ready()
            await c.send_json({"text": "via text"})
            await asyncio.sleep(0.1)
            assert bus.publish_inbound.call_args[0][0].content == "via text"
            await c.send_json({"message": "via message"})
            await asyncio.sleep(0.1)
            assert bus.publish_inbound.call_args[0][0].content == "via message"
    finally:
        await ch.stop(); await t


@pytest.mark.asyncio
async def test_empty_payload_ignored(bus: MagicMock) -> None:
    ch = _ch(bus, 29907)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29907/", client_id="e") as c:
            await c.recv_ready()
            await c.send_text(" ")
            await c.send_json({})
            await asyncio.sleep(0.1)
            bus.publish_inbound.assert_not_awaited()
    finally:
        await ch.stop(); await t


@pytest.mark.asyncio
async def test_messages_preserve_order(bus: MagicMock) -> None:
    ch = _ch(bus, 29908)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29908/", client_id="o") as c:
            await c.recv_ready()
            for i in range(5):
                await c.send_text(f"msg-{i}")
            await asyncio.sleep(0.2)
            contents = [call[0][0].content for call in bus.publish_inbound.call_args_list]
            assert contents == [f"msg-{i}" for i in range(5)]
    finally:
        await ch.stop(); await t


# -- Outbound messages (server -> client) ---------------------------------


@pytest.mark.asyncio
async def test_server_send_message(bus: MagicMock) -> None:
    ch = _ch(bus, 29909)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29909/", client_id="r") as c:
            ready = await c.recv_ready()
            await ch.send(OutboundMessage(
                channel="websocket", chat_id=ready.chat_id, content="reply",
            ))
            msg = await c.recv_message()
            assert msg.text == "reply"
    finally:
        await ch.stop(); await t


@pytest.mark.asyncio
async def test_server_send_with_media_and_reply(bus: MagicMock) -> None:
    ch = _ch(bus, 29910)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29910/", client_id="m") as c:
            ready = await c.recv_ready()
            await ch.send(OutboundMessage(
                channel="websocket", chat_id=ready.chat_id, content="img",
                media=["/tmp/a.png"], reply_to="m1",
            ))
            msg = await c.recv_message()
            assert msg.text == "img"
            assert msg.media == ["/tmp/a.png"]
            assert msg.reply_to == "m1"
    finally:
        await ch.stop(); await t


# -- Streaming ------------------------------------------------------------


@pytest.mark.asyncio
async def test_streaming_deltas_and_end(bus: MagicMock) -> None:
    ch = _ch(bus, 29911, streaming=True)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29911/", client_id="s") as c:
            cid = (await c.recv_ready()).chat_id
            for part in ("Hello", " ", "world", "!"):
                await ch.send_delta(cid, part, {"_stream_delta": True, "_stream_id": "s1"})
            await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "s1"})

            msgs = await c.collect_stream()
            deltas = [m for m in msgs if m.event == "delta"]
            assert "".join(d.text for d in deltas) == "Hello world!"
            ends = [m for m in msgs if m.event == "stream_end"]
            assert len(ends) == 1
    finally:
        await ch.stop(); await t


@pytest.mark.asyncio
async def test_interleaved_streams(bus: MagicMock) -> None:
    ch = _ch(bus, 29912, streaming=True)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29912/", client_id="i") as c:
            cid = (await c.recv_ready()).chat_id
            await ch.send_delta(cid, "A1", {"_stream_delta": True, "_stream_id": "sa"})
            await ch.send_delta(cid, "B1", {"_stream_delta": True, "_stream_id": "sb"})
            await ch.send_delta(cid, "A2", {"_stream_delta": True, "_stream_id": "sa"})
            await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "sa"})
            await ch.send_delta(cid, "B2", {"_stream_delta": True, "_stream_id": "sb"})
            await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "sb"})

            msgs = await c.recv_n(6)
            sa = "".join(m.text for m in msgs if m.event == "delta" and m.stream_id == "sa")
            sb = "".join(m.text for m in msgs if m.event == "delta" and m.stream_id == "sb")
            assert sa == "A1A2"
            assert sb == "B1B2"
    finally:
        await ch.stop(); await t


# -- Multi-client ---------------------------------------------------------


@pytest.mark.asyncio
async def test_independent_sessions(bus: MagicMock) -> None:
    ch = _ch(bus, 29913)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29913/", client_id="u1") as c1:
            async with WsTestClient("ws://127.0.0.1:29913/", client_id="u2") as c2:
                r1, r2 = await c1.recv_ready(), await c2.recv_ready()
                await ch.send(OutboundMessage(
                    channel="websocket", chat_id=r1.chat_id, content="for-u1",
                ))
                assert (await c1.recv_message()).text == "for-u1"
                await ch.send(OutboundMessage(
                    channel="websocket", chat_id=r2.chat_id, content="for-u2",
                ))
                assert (await c2.recv_message()).text == "for-u2"
    finally:
        await ch.stop(); await t


@pytest.mark.asyncio
async def test_disconnected_client_cleanup(bus: MagicMock) -> None:
    ch = _ch(bus, 29914)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29914/", client_id="tmp") as c:
            chat_id = (await c.recv_ready()).chat_id
        # disconnected
        await ch.send(OutboundMessage(
            channel="websocket", chat_id=chat_id, content="orphan",
        ))
        assert chat_id not in ch._connections
    finally:
        await ch.stop(); await t


# -- Authentication -------------------------------------------------------


@pytest.mark.asyncio
async def test_static_token_accepted(bus: MagicMock) -> None:
    ch = _ch(bus, 29915, token="secret")
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29915/", client_id="a", token="secret") as c:
            assert (await c.recv_ready()).client_id == "a"
    finally:
        await ch.stop(); await t


@pytest.mark.asyncio
async def test_static_token_rejected(bus: MagicMock) -> None:
    ch = _ch(bus, 29916, token="correct")
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
            async with WsTestClient("ws://127.0.0.1:29916/", client_id="b", token="wrong"):
                pass
        assert exc.value.response.status_code == 401
    finally:
        await ch.stop(); await t


@pytest.mark.asyncio
async def test_token_issue_full_flow(bus: MagicMock) -> None:
    ch = _ch(bus, 29917, path="/ws",
             tokenIssuePath="/auth/token", tokenIssueSecret="s",
             websocketRequiresToken=True)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        # no secret -> 401
        _, status = await issue_token(port=29917, issue_path="/auth/token")
        assert status == 401

        # with secret -> token
        token = await issue_token_ok(port=29917, issue_path="/auth/token", secret="s")

        # no token -> 401
        with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
            async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="x"):
                pass
        assert exc.value.response.status_code == 401

        # valid token -> ok
        async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="ok", token=token) as c:
            assert (await c.recv_ready()).client_id == "ok"

        # reuse -> 401
        with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
            async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="r", token=token):
                pass
        assert exc.value.response.status_code == 401
    finally:
        await ch.stop(); await t


# -- Path routing ---------------------------------------------------------


@pytest.mark.asyncio
async def test_custom_path(bus: MagicMock) -> None:
    ch = _ch(bus, 29918, path="/my-chat")
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29918/my-chat", client_id="p") as c:
            assert (await c.recv_ready()).event == "ready"
    finally:
        await ch.stop(); await t


@pytest.mark.asyncio
async def test_wrong_path_404(bus: MagicMock) -> None:
    ch = _ch(bus, 29919, path="/ws")
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
            async with WsTestClient("ws://127.0.0.1:29919/wrong", client_id="x"):
                pass
        assert exc.value.response.status_code == 404
    finally:
        await ch.stop(); await t


@pytest.mark.asyncio
async def test_trailing_slash_normalized(bus: MagicMock) -> None:
    ch = _ch(bus, 29920, path="/ws")
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29920/ws/", client_id="s") as c:
            assert (await c.recv_ready()).event == "ready"
    finally:
        await ch.stop(); await t


# -- Edge cases -----------------------------------------------------------


@pytest.mark.asyncio
async def test_large_message(bus: MagicMock) -> None:
    ch = _ch(bus, 29921)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29921/", client_id="big") as c:
            await c.recv_ready()
            big = "x" * 100_000
            await c.send_text(big)
            await asyncio.sleep(0.2)
            assert bus.publish_inbound.call_args[0][0].content == big
    finally:
        await ch.stop(); await t


@pytest.mark.asyncio
async def test_unicode_roundtrip(bus: MagicMock) -> None:
    ch = _ch(bus, 29922)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29922/", client_id="u") as c:
            ready = await c.recv_ready()
            text = "你好世界 🌍 日本語テスト"
            await c.send_text(text)
            await asyncio.sleep(0.1)
            assert bus.publish_inbound.call_args[0][0].content == text
            await ch.send(OutboundMessage(
                channel="websocket", chat_id=ready.chat_id, content=text,
            ))
            assert (await c.recv_message()).text == text
    finally:
        await ch.stop(); await t


@pytest.mark.asyncio
async def test_rapid_fire(bus: MagicMock) -> None:
    ch = _ch(bus, 29923)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29923/", client_id="r") as c:
            ready = await c.recv_ready()
            for i in range(50):
                await c.send_text(f"in-{i}")
            await asyncio.sleep(0.5)
            assert bus.publish_inbound.await_count == 50
            for i in range(50):
                await ch.send(OutboundMessage(
                    channel="websocket", chat_id=ready.chat_id, content=f"out-{i}",
                ))
            received = [(await c.recv_message()).text for _ in range(50)]
            assert received == [f"out-{i}" for i in range(50)]
    finally:
        await ch.stop(); await t


@pytest.mark.asyncio
async def test_invalid_json_as_plain_text(bus: MagicMock) -> None:
    ch = _ch(bus, 29924)
    t = asyncio.create_task(ch.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29924/", client_id="j") as c:
            await c.recv_ready()
            await c.send_text("{broken json")
            await asyncio.sleep(0.1)
            assert bus.publish_inbound.call_args[0][0].content == "{broken json"
    finally:
        await ch.stop(); await t
584
tests/channels/test_wecom_channel.py
Normal file
@ -0,0 +1,584 @@
"""Tests for WeCom channel: helpers, download, upload, send, and message processing."""

import os
import tempfile
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

try:
    import importlib.util

    WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
except ImportError:
    WECOM_AVAILABLE = False

if not WECOM_AVAILABLE:
    pytest.skip("WeCom dependencies not installed (wecom_aibot_sdk)", allow_module_level=True)

from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.wecom import (
    WecomChannel,
    WecomConfig,
    _guess_wecom_media_type,
    _sanitize_filename,
)

# Try to import the real response class; fall back to a stub if unavailable.
try:
    from wecom_aibot_sdk.utils import WsResponse

    _RealWsResponse = WsResponse
except ImportError:
    _RealWsResponse = None


class _FakeResponse:
    """Minimal stand-in for wecom_aibot_sdk WsResponse."""

    def __init__(self, errcode: int = 0, body: dict | None = None, errmsg: str = "ok"):
        self.errcode = errcode
        self.errmsg = errmsg
        self.body = body or {}


class _FakeWsManager:
    """Tracks send_reply calls and returns configurable responses."""

    def __init__(self, responses: list[_FakeResponse] | None = None):
        self.responses = responses or []
        self.calls: list[tuple[str, dict, str]] = []
        self._idx = 0

    async def send_reply(self, req_id: str, data: dict, cmd: str) -> _FakeResponse:
        self.calls.append((req_id, data, cmd))
        if self._idx < len(self.responses):
            resp = self.responses[self._idx]
            self._idx += 1
            return resp
        return _FakeResponse()


class _FakeFrame:
    """Minimal frame object with a body dict."""

    def __init__(self, body: dict | None = None):
        self.body = body or {}


class _FakeWeComClient:
    """Fake WeCom client with mock methods."""

    def __init__(self, ws_responses: list[_FakeResponse] | None = None):
        self._ws_manager = _FakeWsManager(ws_responses)
        self.download_file = AsyncMock(return_value=(None, None))
        self.reply = AsyncMock()
        self.reply_stream = AsyncMock()
        self.send_message = AsyncMock()
        self.reply_welcome = AsyncMock()


# ── Helper function tests (pure, no async) ──────────────────────────


def test_sanitize_filename_strips_path_traversal() -> None:
    assert _sanitize_filename("../../etc/passwd") == "passwd"


def test_sanitize_filename_keeps_chinese_chars() -> None:
    assert _sanitize_filename("文件(1).jpg") == "文件(1).jpg"


def test_sanitize_filename_empty_input() -> None:
    assert _sanitize_filename("") == ""


def test_guess_wecom_media_type_image() -> None:
    for ext in (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"):
        assert _guess_wecom_media_type(f"photo{ext}") == "image"


def test_guess_wecom_media_type_video() -> None:
    for ext in (".mp4", ".avi", ".mov"):
        assert _guess_wecom_media_type(f"video{ext}") == "video"


def test_guess_wecom_media_type_voice() -> None:
    for ext in (".amr", ".mp3", ".wav", ".ogg"):
        assert _guess_wecom_media_type(f"audio{ext}") == "voice"


def test_guess_wecom_media_type_file_fallback() -> None:
    for ext in (".pdf", ".doc", ".xlsx", ".zip"):
        assert _guess_wecom_media_type(f"doc{ext}") == "file"


def test_guess_wecom_media_type_case_insensitive() -> None:
    assert _guess_wecom_media_type("photo.PNG") == "image"
    assert _guess_wecom_media_type("photo.Jpg") == "image"


# ── _download_and_save_media() ──────────────────────────────────────


@pytest.mark.asyncio
async def test_download_and_save_success() -> None:
    """Successful download writes file and returns sanitized path."""
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
    client = _FakeWeComClient()
    channel._client = client

    fake_data = b"\x89PNG\r\nfake image"
    client.download_file.return_value = (fake_data, "raw_photo.png")

    with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(tempfile.gettempdir())):
        path = await channel._download_and_save_media("https://example.com/img.png", "aes_key", "image", "photo.png")

    assert path is not None
    assert os.path.isfile(path)
    assert os.path.basename(path) == "photo.png"
    # Cleanup
    os.unlink(path)


@pytest.mark.asyncio
async def test_download_and_save_oversized_rejected() -> None:
    """Data exceeding 200MB is rejected → returns None."""
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
    client = _FakeWeComClient()
    channel._client = client

    big_data = b"\x00" * (200 * 1024 * 1024 + 1)  # 200MB + 1 byte
    client.download_file.return_value = (big_data, "big.bin")

    with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(tempfile.gettempdir())):
        result = await channel._download_and_save_media("https://example.com/big.bin", "key", "file", "big.bin")

    assert result is None


@pytest.mark.asyncio
async def test_download_and_save_failure() -> None:
    """SDK returns None data → returns None."""
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
    client = _FakeWeComClient()
    channel._client = client

    client.download_file.return_value = (None, None)

    with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(tempfile.gettempdir())):
        result = await channel._download_and_save_media("https://example.com/fail.png", "key", "image")

    assert result is None


# ── _upload_media_ws() ──────────────────────────────────────────────


@pytest.mark.asyncio
async def test_upload_media_ws_success() -> None:
    """Happy path: init → chunk → finish → returns (media_id, media_type)."""
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
        f.write(b"\x89PNG\r\n")
        tmp = f.name

    try:
        responses = [
            _FakeResponse(errcode=0, body={"upload_id": "up_1"}),
            _FakeResponse(errcode=0, body={}),
            _FakeResponse(errcode=0, body={"media_id": "media_abc"}),
        ]

        client = _FakeWeComClient(responses)
        channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
        channel._client = client

        with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
            media_id, media_type = await channel._upload_media_ws(client, tmp)

        assert media_id == "media_abc"
        assert media_type == "image"
    finally:
        os.unlink(tmp)


@pytest.mark.asyncio
async def test_upload_media_ws_oversized_file() -> None:
    """File >200MB triggers ValueError → returns (None, None)."""
    # Instead of creating a real 200MB+ file, mock os.path.getsize and open
    with patch("os.path.getsize", return_value=200 * 1024 * 1024 + 1), \
         patch("builtins.open", MagicMock()):
        client = _FakeWeComClient()
        channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
        channel._client = client

        result = await channel._upload_media_ws(client, "/fake/large.bin")
        assert result == (None, None)


@pytest.mark.asyncio
async def test_upload_media_ws_init_failure() -> None:
    """Init step returns errcode != 0 → returns (None, None)."""
    with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f:
        f.write(b"hello")
        tmp = f.name

    try:
        responses = [
            _FakeResponse(errcode=50001, errmsg="invalid"),
        ]

        client = _FakeWeComClient(responses)
        channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
        channel._client = client

        with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
            result = await channel._upload_media_ws(client, tmp)

        assert result == (None, None)
    finally:
        os.unlink(tmp)


@pytest.mark.asyncio
async def test_upload_media_ws_chunk_failure() -> None:
    """Chunk step returns errcode != 0 → returns (None, None)."""
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
        f.write(b"\x89PNG\r\n")
        tmp = f.name

    try:
        responses = [
            _FakeResponse(errcode=0, body={"upload_id": "up_1"}),
            _FakeResponse(errcode=50002, errmsg="chunk fail"),
        ]

        client = _FakeWeComClient(responses)
        channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
        channel._client = client

        with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
            result = await channel._upload_media_ws(client, tmp)

        assert result == (None, None)
    finally:
        os.unlink(tmp)


@pytest.mark.asyncio
async def test_upload_media_ws_finish_no_media_id() -> None:
    """Finish step returns empty media_id → returns (None, None)."""
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
        f.write(b"\x89PNG\r\n")
        tmp = f.name

    try:
        responses = [
            _FakeResponse(errcode=0, body={"upload_id": "up_1"}),
            _FakeResponse(errcode=0, body={}),
            _FakeResponse(errcode=0, body={}),  # no media_id
        ]

        client = _FakeWeComClient(responses)
        channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
        channel._client = client

        with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
            result = await channel._upload_media_ws(client, tmp)

        assert result == (None, None)
    finally:
        os.unlink(tmp)


# ── send() ──────────────────────────────────────────────────────────


@pytest.mark.asyncio
async def test_send_text_with_frame() -> None:
    """When frame is stored, send uses reply_stream for final text."""
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
    client = _FakeWeComClient()
    channel._client = client
    channel._generate_req_id = lambda x: f"req_{x}"
    channel._chat_frames["chat1"] = _FakeFrame()

    await channel.send(
        OutboundMessage(channel="wecom", chat_id="chat1", content="hello")
    )

    client.reply_stream.assert_called_once()
    call_args = client.reply_stream.call_args
    assert call_args[0][2] == "hello"  # content arg


@pytest.mark.asyncio
async def test_send_progress_with_frame() -> None:
    """When metadata has _progress, send uses reply_stream with finish=False."""
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
    client = _FakeWeComClient()
    channel._client = client
    channel._generate_req_id = lambda x: f"req_{x}"
    channel._chat_frames["chat1"] = _FakeFrame()

    await channel.send(
        OutboundMessage(channel="wecom", chat_id="chat1", content="thinking...", metadata={"_progress": True})
    )

    client.reply_stream.assert_called_once()
    call_args = client.reply_stream.call_args
    assert call_args[0][2] == "thinking..."  # content arg
    assert call_args[1]["finish"] is False


@pytest.mark.asyncio
async def test_send_proactive_without_frame() -> None:
    """Without stored frame, send uses send_message with markdown."""
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
    client = _FakeWeComClient()
    channel._client = client

    await channel.send(
        OutboundMessage(channel="wecom", chat_id="chat1", content="proactive msg")
    )

    client.send_message.assert_called_once()
    call_args = client.send_message.call_args
    assert call_args[0][0] == "chat1"
    assert call_args[0][1]["msgtype"] == "markdown"


@pytest.mark.asyncio
async def test_send_media_then_text() -> None:
    """Media files are uploaded and sent before text content."""
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
        f.write(b"\x89PNG\r\n")
        tmp = f.name

    try:
        responses = [
            _FakeResponse(errcode=0, body={"upload_id": "up_1"}),
            _FakeResponse(errcode=0, body={}),
            _FakeResponse(errcode=0, body={"media_id": "media_123"}),
        ]

        channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
        client = _FakeWeComClient(responses)
        channel._client = client
        channel._generate_req_id = lambda x: f"req_{x}"
        channel._chat_frames["chat1"] = _FakeFrame()

        await channel.send(
            OutboundMessage(channel="wecom", chat_id="chat1", content="see image", media=[tmp])
        )

        # Media should have been sent via reply
        media_calls = [c for c in client.reply.call_args_list if c[0][1].get("msgtype") == "image"]
        assert len(media_calls) == 1
        assert media_calls[0][0][1]["image"]["media_id"] == "media_123"

        # Text should have been sent via reply_stream
        client.reply_stream.assert_called_once()
    finally:
        os.unlink(tmp)


@pytest.mark.asyncio
async def test_send_media_file_not_found() -> None:
    """Non-existent media path is skipped with a warning."""
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
    client = _FakeWeComClient()
    channel._client = client
    channel._generate_req_id = lambda x: f"req_{x}"
    channel._chat_frames["chat1"] = _FakeFrame()

    await channel.send(
        OutboundMessage(channel="wecom", chat_id="chat1", content="hello", media=["/nonexistent/file.png"])
    )

    # reply_stream should still be called for the text part
    client.reply_stream.assert_called_once()
    # No media reply should happen
    media_calls = [c for c in client.reply.call_args_list if c[0][1].get("msgtype") in ("image", "file", "video")]
    assert len(media_calls) == 0


@pytest.mark.asyncio
async def test_send_exception_caught_not_raised() -> None:
    """Exceptions inside send() must not propagate."""
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
    client = _FakeWeComClient()
    channel._client = client
    channel._generate_req_id = lambda x: f"req_{x}"
    channel._chat_frames["chat1"] = _FakeFrame()

    # Make reply_stream raise
    client.reply_stream.side_effect = RuntimeError("boom")

    await channel.send(
        OutboundMessage(channel="wecom", chat_id="chat1", content="fail test")
    )
    # No exception — test passes if we reach here.


# ── _process_message() ──────────────────────────────────────────────


@pytest.mark.asyncio
async def test_process_text_message() -> None:
    """Text message is routed to bus with correct fields."""
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
    client = _FakeWeComClient()
    channel._client = client

    frame = _FakeFrame(body={
        "msgid": "msg_text_1",
        "chatid": "chat1",
        "chattype": "single",
        "from": {"userid": "user1"},
        "text": {"content": "hello wecom"},
    })

    await channel._process_message(frame, "text")

    msg = await channel.bus.consume_inbound()
    assert msg.sender_id == "user1"
    assert msg.chat_id == "chat1"
    assert msg.content == "hello wecom"
    assert msg.metadata["msg_type"] == "text"


@pytest.mark.asyncio
async def test_process_image_message() -> None:
    """Image message: download success → media_paths non-empty."""
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
    client = _FakeWeComClient()

    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
        f.write(b"\x89PNG\r\n")
        saved = f.name

    client.download_file.return_value = (b"\x89PNG\r\n", "photo.png")
    channel._client = client

    try:
        with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(os.path.dirname(saved))):
            frame = _FakeFrame(body={
                "msgid": "msg_img_1",
                "chatid": "chat1",
                "from": {"userid": "user1"},
                "image": {"url": "https://example.com/img.png", "aeskey": "key123"},
            })
            await channel._process_message(frame, "image")

        msg = await channel.bus.consume_inbound()
        assert len(msg.media) == 1
        assert msg.media[0].endswith("photo.png")
        assert "[image:" in msg.content
    finally:
        if os.path.exists(saved):
            pass  # may have been overwritten; clean up if exists
        # Clean up any photo.png in tempdir
||||
p = os.path.join(os.path.dirname(saved), "photo.png")
|
||||
if os.path.exists(p):
|
||||
os.unlink(p)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_file_message() -> None:
|
||||
"""File message: download success → media_paths non-empty (critical fix verification)."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
|
||||
f.write(b"%PDF-1.4 fake")
|
||||
saved = f.name
|
||||
|
||||
client.download_file.return_value = (b"%PDF-1.4 fake", "report.pdf")
|
||||
channel._client = client
|
||||
|
||||
try:
|
||||
with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(os.path.dirname(saved))):
|
||||
frame = _FakeFrame(body={
|
||||
"msgid": "msg_file_1",
|
||||
"chatid": "chat1",
|
||||
"from": {"userid": "user1"},
|
||||
"file": {"url": "https://example.com/report.pdf", "aeskey": "key456", "name": "report.pdf"},
|
||||
})
|
||||
await channel._process_message(frame, "file")
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert len(msg.media) == 1
|
||||
assert msg.media[0].endswith("report.pdf")
|
||||
assert "[file: report.pdf]" in msg.content
|
||||
finally:
|
||||
p = os.path.join(os.path.dirname(saved), "report.pdf")
|
||||
if os.path.exists(p):
|
||||
os.unlink(p)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_voice_message() -> None:
|
||||
"""Voice message: transcribed text is included in content."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
|
||||
frame = _FakeFrame(body={
|
||||
"msgid": "msg_voice_1",
|
||||
"chatid": "chat1",
|
||||
"from": {"userid": "user1"},
|
||||
"voice": {"content": "transcribed text here"},
|
||||
})
|
||||
|
||||
await channel._process_message(frame, "voice")
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert "transcribed text here" in msg.content
|
||||
assert "[voice]" in msg.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_deduplication() -> None:
|
||||
"""Same msg_id is not processed twice."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
|
||||
frame = _FakeFrame(body={
|
||||
"msgid": "msg_dup_1",
|
||||
"chatid": "chat1",
|
||||
"from": {"userid": "user1"},
|
||||
"text": {"content": "once"},
|
||||
})
|
||||
|
||||
await channel._process_message(frame, "text")
|
||||
await channel._process_message(frame, "text")
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert msg.content == "once"
|
||||
|
||||
# Second message should not appear on the bus
|
||||
assert channel.bus.inbound.empty()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_empty_content_skipped() -> None:
|
||||
"""Message with empty content produces no bus message."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
|
||||
frame = _FakeFrame(body={
|
||||
"msgid": "msg_empty_1",
|
||||
"chatid": "chat1",
|
||||
"from": {"userid": "user1"},
|
||||
"text": {"content": ""},
|
||||
})
|
||||
|
||||
await channel._process_message(frame, "text")
|
||||
|
||||
assert channel.bus.inbound.empty()
|
||||
227
tests/channels/ws_test_client.py
Normal file
@ -0,0 +1,227 @@
"""Lightweight WebSocket test client for integration testing the nanobot WebSocket channel.

Provides an async ``WsTestClient`` class and token-issuance helpers that
integration tests can import and use directly::

    from ws_test_client import WsTestClient

    async with WsTestClient("ws://127.0.0.1:8765/", client_id="t") as c:
        ready = await c.recv_ready()
        await c.send_text("hello")
        msg = await c.recv_message()
"""

from __future__ import annotations

import asyncio
import json
from dataclasses import dataclass, field
from typing import Any

import httpx
import websockets
from websockets.asyncio.client import ClientConnection


@dataclass
class WsMessage:
    """A parsed message received from the WebSocket server."""

    event: str
    raw: dict[str, Any] = field(repr=False)

    @property
    def text(self) -> str | None:
        return self.raw.get("text")

    @property
    def chat_id(self) -> str | None:
        return self.raw.get("chat_id")

    @property
    def client_id(self) -> str | None:
        return self.raw.get("client_id")

    @property
    def media(self) -> list[str] | None:
        return self.raw.get("media")

    @property
    def reply_to(self) -> str | None:
        return self.raw.get("reply_to")

    @property
    def stream_id(self) -> str | None:
        return self.raw.get("stream_id")

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, WsMessage):
            return NotImplemented
        return self.event == other.event and self.raw == other.raw


class WsTestClient:
    """Async WebSocket test client with helper methods for common operations.

    Usage::

        async with WsTestClient("ws://127.0.0.1:8765/", client_id="tester") as client:
            ready = await client.recv_ready()
            await client.send_text("hello")
            msg = await client.recv_message(timeout=5.0)
    """

    def __init__(
        self,
        uri: str,
        *,
        client_id: str = "test-client",
        token: str = "",
        extra_headers: dict[str, str] | None = None,
    ) -> None:
        params: list[str] = []
        if client_id:
            params.append(f"client_id={client_id}")
        if token:
            params.append(f"token={token}")
        sep = "&" if "?" in uri else "?"
        self._uri = uri + sep + "&".join(params) if params else uri
        self._extra_headers = extra_headers
        self._ws: ClientConnection | None = None

    async def connect(self) -> None:
        self._ws = await websockets.connect(
            self._uri,
            additional_headers=self._extra_headers,
        )

    async def close(self) -> None:
        if self._ws:
            await self._ws.close()
            self._ws = None

    async def __aenter__(self) -> WsTestClient:
        await self.connect()
        return self

    async def __aexit__(self, *args: Any) -> None:
        await self.close()

    @property
    def ws(self) -> ClientConnection:
        assert self._ws is not None, "Client is not connected"
        return self._ws

    # -- Receiving --------------------------------------------------------

    async def recv_raw(self, timeout: float = 10.0) -> dict[str, Any]:
        """Receive and parse one raw JSON message with timeout."""
        raw = await asyncio.wait_for(self.ws.recv(), timeout=timeout)
        return json.loads(raw)

    async def recv(self, timeout: float = 10.0) -> WsMessage:
        """Receive one message, returning a WsMessage wrapper."""
        data = await self.recv_raw(timeout)
        return WsMessage(event=data.get("event", ""), raw=data)

    async def recv_ready(self, timeout: float = 5.0) -> WsMessage:
        """Receive and validate the 'ready' event."""
        msg = await self.recv(timeout)
        assert msg.event == "ready", f"Expected 'ready' event, got '{msg.event}'"
        return msg

    async def recv_message(self, timeout: float = 10.0) -> WsMessage:
        """Receive and validate a 'message' event."""
        msg = await self.recv(timeout)
        assert msg.event == "message", f"Expected 'message' event, got '{msg.event}'"
        return msg

    async def recv_delta(self, timeout: float = 10.0) -> WsMessage:
        """Receive and validate a 'delta' event."""
        msg = await self.recv(timeout)
        assert msg.event == "delta", f"Expected 'delta' event, got '{msg.event}'"
        return msg

    async def recv_stream_end(self, timeout: float = 10.0) -> WsMessage:
        """Receive and validate a 'stream_end' event."""
        msg = await self.recv(timeout)
        assert msg.event == "stream_end", f"Expected 'stream_end' event, got '{msg.event}'"
        return msg

    async def collect_stream(self, timeout: float = 10.0) -> list[WsMessage]:
        """Collect all deltas and the final stream_end into a list."""
        messages: list[WsMessage] = []
        while True:
            msg = await self.recv(timeout)
            messages.append(msg)
            if msg.event == "stream_end":
                break
        return messages

    async def recv_n(self, n: int, timeout: float = 10.0) -> list[WsMessage]:
        """Receive exactly *n* messages."""
        return [await self.recv(timeout) for _ in range(n)]

    # -- Sending ----------------------------------------------------------

    async def send_text(self, text: str) -> None:
        """Send a plain text frame."""
        await self.ws.send(text)

    async def send_json(self, data: dict[str, Any]) -> None:
        """Send a JSON frame."""
        await self.ws.send(json.dumps(data, ensure_ascii=False))

    async def send_content(self, content: str) -> None:
        """Send content in the preferred JSON format ``{"content": ...}``."""
        await self.send_json({"content": content})

    # -- Connection introspection -----------------------------------------

    @property
    def closed(self) -> bool:
        return self._ws is None or self._ws.closed


# -- Token issuance helpers -----------------------------------------------


async def issue_token(
    host: str = "127.0.0.1",
    port: int = 8765,
    issue_path: str = "/auth/token",
    secret: str = "",
) -> tuple[dict[str, Any] | None, int]:
    """Request a short-lived token from the token-issue HTTP endpoint.

    Returns ``(parsed_json_or_None, status_code)``.
    """
    url = f"http://{host}:{port}{issue_path}"
    headers: dict[str, str] = {}
    if secret:
        headers["Authorization"] = f"Bearer {secret}"

    loop = asyncio.get_running_loop()
    resp = await loop.run_in_executor(
        None, lambda: httpx.get(url, headers=headers, timeout=5.0)
    )
    try:
        data = resp.json()
    except Exception:
        data = None
    return data, resp.status_code


async def issue_token_ok(
    host: str = "127.0.0.1",
    port: int = 8765,
    issue_path: str = "/auth/token",
    secret: str = "",
) -> str:
    """Request a token, asserting success, and return the token string."""
    data, status = await issue_token(host, port, issue_path, secret)
    assert status == 200, f"Token issue failed with status {status}"
    assert data is not None
    token = data["token"]
    assert token.startswith("nbwt_"), f"Unexpected token format: {token}"
    return token
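The streaming helpers above all follow the same drain-until-terminator pattern: keep receiving until a `stream_end` event arrives. A minimal standalone sketch of that loop, using an `asyncio.Queue` in place of a live WebSocket (the event shapes are assumptions modeled on the events these helpers validate, not the channel's actual wire format):

```python
import asyncio


async def collect_stream(recv):
    """Drain events from `recv` until a 'stream_end' event arrives."""
    messages = []
    while True:
        msg = await recv()
        messages.append(msg)
        if msg["event"] == "stream_end":
            break
    return messages


async def scripted_session():
    # Script a delta/delta/stream_end sequence, as a streaming server would emit.
    q = asyncio.Queue()
    for event in ({"event": "delta", "text": "He"},
                  {"event": "delta", "text": "llo"},
                  {"event": "stream_end", "text": "Hello"}):
        q.put_nowait(event)
    return await collect_stream(q.get)


msgs = asyncio.run(scripted_session())
print([m["event"] for m in msgs])  # → ['delta', 'delta', 'stream_end']
```

Because the loop only stops on `stream_end`, a missing terminator surfaces as a timeout in the real client rather than a silent truncation.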
@ -327,3 +327,240 @@ async def test_external_update_preserves_run_history_records(tmp_path):
    fresh._running = True
    fresh._save_store()


# ── timer race regression tests ──


@pytest.mark.asyncio
async def test_timer_execution_is_not_rolled_back_by_list_jobs_reload(tmp_path):
    """list_jobs() during _on_timer should not replace the active store and re-run the same due job."""
    store_path = tmp_path / "cron" / "jobs.json"
    calls: list[str] = []

    async def on_job(job):
        calls.append(job.id)
        # Simulate frontend polling list_jobs while the timer callback is mid-execution.
        service.list_jobs(include_disabled=True)
        await asyncio.sleep(0)

    service = CronService(store_path, on_job=on_job)
    service._running = True
    service._load_store()
    service._arm_timer = lambda: None

    job = service.add_job(
        name="race",
        schedule=CronSchedule(kind="every", every_ms=60_000),
        message="hello",
    )
    job.state.next_run_at_ms = max(1, int(time.time() * 1000) - 1_000)
    service._save_store()

    await service._on_timer()
    await service._on_timer()

    assert calls == [job.id]
    loaded = service.get_job(job.id)
    assert loaded is not None
    assert loaded.state.last_run_at_ms is not None
    assert loaded.state.next_run_at_ms is not None
    assert loaded.state.next_run_at_ms > loaded.state.last_run_at_ms


# ── update_job tests ──


def test_update_job_changes_name(tmp_path) -> None:
    service = CronService(tmp_path / "cron" / "jobs.json")
    job = service.add_job(
        name="old name",
        schedule=CronSchedule(kind="every", every_ms=60_000),
        message="hello",
    )
    result = service.update_job(job.id, name="new name")
    assert isinstance(result, CronJob)
    assert result.name == "new name"
    assert result.payload.message == "hello"


def test_update_job_changes_schedule(tmp_path) -> None:
    service = CronService(tmp_path / "cron" / "jobs.json")
    job = service.add_job(
        name="sched",
        schedule=CronSchedule(kind="every", every_ms=60_000),
        message="hello",
    )
    old_next = job.state.next_run_at_ms

    new_sched = CronSchedule(kind="every", every_ms=120_000)
    result = service.update_job(job.id, schedule=new_sched)
    assert isinstance(result, CronJob)
    assert result.schedule.every_ms == 120_000
    assert result.state.next_run_at_ms != old_next


def test_update_job_changes_message(tmp_path) -> None:
    service = CronService(tmp_path / "cron" / "jobs.json")
    job = service.add_job(
        name="msg",
        schedule=CronSchedule(kind="every", every_ms=60_000),
        message="old message",
    )
    result = service.update_job(job.id, message="new message")
    assert isinstance(result, CronJob)
    assert result.payload.message == "new message"


def test_update_job_changes_cron_expression(tmp_path) -> None:
    service = CronService(tmp_path / "cron" / "jobs.json")
    job = service.add_job(
        name="cron-job",
        schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"),
        message="hello",
    )
    result = service.update_job(
        job.id,
        schedule=CronSchedule(kind="cron", expr="0 18 * * *", tz="UTC"),
    )
    assert isinstance(result, CronJob)
    assert result.schedule.expr == "0 18 * * *"
    assert result.state.next_run_at_ms is not None


def test_update_job_not_found(tmp_path) -> None:
    service = CronService(tmp_path / "cron" / "jobs.json")
    result = service.update_job("nonexistent", name="x")
    assert result == "not_found"


def test_update_job_rejects_system_job(tmp_path) -> None:
    service = CronService(tmp_path / "cron" / "jobs.json")
    service.register_system_job(CronJob(
        id="dream",
        name="dream",
        schedule=CronSchedule(kind="cron", expr="0 */2 * * *", tz="UTC"),
        payload=CronPayload(kind="system_event"),
    ))
    result = service.update_job("dream", name="hacked")
    assert result == "protected"
    assert service.get_job("dream").name == "dream"


def test_update_job_validates_schedule(tmp_path) -> None:
    service = CronService(tmp_path / "cron" / "jobs.json")
    job = service.add_job(
        name="validate",
        schedule=CronSchedule(kind="every", every_ms=60_000),
        message="hello",
    )
    with pytest.raises(ValueError, match="unknown timezone"):
        service.update_job(
            job.id,
            schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="Bad/Zone"),
        )


@pytest.mark.asyncio
async def test_update_job_preserves_run_history(tmp_path) -> None:
    import asyncio
    store_path = tmp_path / "cron" / "jobs.json"
    service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
    job = service.add_job(
        name="hist",
        schedule=CronSchedule(kind="every", every_ms=60_000),
        message="hello",
    )
    await service.run_job(job.id)

    result = service.update_job(job.id, name="renamed")
    assert isinstance(result, CronJob)
    assert len(result.state.run_history) == 1
    assert result.state.run_history[0].status == "ok"


def test_update_job_offline_writes_action(tmp_path) -> None:
    service = CronService(tmp_path / "cron" / "jobs.json")
    job = service.add_job(
        name="offline",
        schedule=CronSchedule(kind="every", every_ms=60_000),
        message="hello",
    )
    service.update_job(job.id, name="updated-offline")

    action_path = tmp_path / "cron" / "action.jsonl"
    assert action_path.exists()
    lines = [line for line in action_path.read_text().strip().split("\n") if line]
    last = json.loads(lines[-1])
    assert last["action"] == "update"
    assert last["params"]["name"] == "updated-offline"


def test_update_job_sentinel_channel_and_to(tmp_path) -> None:
    """Passing None clears channel/to; omitting leaves them unchanged."""
    service = CronService(tmp_path / "cron" / "jobs.json")
    job = service.add_job(
        name="sentinel",
        schedule=CronSchedule(kind="every", every_ms=60_000),
        message="hello",
        channel="telegram",
        to="user123",
    )
    assert job.payload.channel == "telegram"
    assert job.payload.to == "user123"

    result = service.update_job(job.id, name="renamed")
    assert isinstance(result, CronJob)
    assert result.payload.channel == "telegram"
    assert result.payload.to == "user123"

    result = service.update_job(job.id, channel=None, to=None)
    assert isinstance(result, CronJob)
    assert result.payload.channel is None
    assert result.payload.to is None


@pytest.mark.asyncio
async def test_list_jobs_during_on_job_does_not_cause_stale_reload(tmp_path) -> None:
    """Regression: if the bot calls list_jobs (which reloads from disk) during
    on_job execution, the in-memory next_run_at_ms update must not be lost.
    Previously this caused an infinite re-trigger loop."""
    store_path = tmp_path / "cron" / "jobs.json"
    execution_count = 0

    async def on_job_that_lists(job):
        nonlocal execution_count
        execution_count += 1
        # Simulate the bot calling cron(action=list) mid-execution
        service.list_jobs()

    service = CronService(store_path, on_job=on_job_that_lists, max_sleep_ms=100)
    await service.start()

    # Add two jobs scheduled in the past so they're immediately due
    now_ms = int(time.time() * 1000)
    for name in ("job-a", "job-b"):
        service.add_job(
            name=name,
            schedule=CronSchedule(kind="every", every_ms=3_600_000),
            message="test",
        )
    # Force next_run to the past so _on_timer picks them up
    for job in service._store.jobs:
        job.state.next_run_at_ms = now_ms - 1000
    service._save_store()
    service._arm_timer()

    # Let the timer fire once
    await asyncio.sleep(0.3)
    service.stop()

    # Each job should have run exactly once, not looped
    assert execution_count == 2

    # Verify next_run_at_ms was persisted correctly (in the future)
    raw = json.loads(store_path.read_text())
    for j in raw["jobs"]:
        next_run = j["state"]["nextRunAtMs"]
        assert next_run is not None
        assert next_run > now_ms, f"Job '{j['name']}' next_run should be in the future"
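The stale-reload race these regression tests guard against is easy to reproduce in isolation: a timer advances `next_run` in memory, a concurrent list call reloads state from disk, and the not-yet-persisted update is rolled back, so the job looks due again. A toy sketch (all class and field names here are hypothetical, not CronService's real API):

```python
import json
import os
import tempfile


class Store:
    """Toy job store illustrating the stale-reload race (names are hypothetical)."""

    def __init__(self, path):
        self.path = path
        self.jobs = {"job-a": {"next_run_ms": 0}}
        self.save()

    def save(self):
        with open(self.path, "w") as fh:
            json.dump(self.jobs, fh)

    def reload_naive(self):
        # Buggy pattern: unconditionally replace live in-memory state from disk.
        with open(self.path) as fh:
            self.jobs = json.load(fh)

    def read_snapshot(self):
        # Fixed pattern: read-only snapshot; the live dict stays authoritative.
        with open(self.path) as fh:
            return json.load(fh)


store = Store(os.path.join(tempfile.mkdtemp(), "jobs.json"))

# Timer fires: next_run is advanced in memory, but save() has not run yet.
store.jobs["job-a"]["next_run_ms"] = 99_999
store.reload_naive()  # a list call polls mid-execution
rolled_back = store.jobs["job-a"]["next_run_ms"]  # back to 0 -> job would re-fire

store.jobs["job-a"]["next_run_ms"] = 99_999
snapshot = store.read_snapshot()  # disk view, still 0; memory keeps 99999
print(rolled_back, store.jobs["job-a"]["next_run_ms"])  # → 0 99999
```

The fix shape mirrored by the tests: treat the in-memory store as authoritative during the timer callback, and have listing reads return a snapshot instead of replacing it.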
@ -84,6 +84,34 @@ class TestEnforceRoleAlternation:
        tool_msgs = [m for m in result if m["role"] == "tool"]
        assert len(tool_msgs) == 2

    def test_consecutive_assistant_keeps_later_tool_call_message(self):
        msgs = [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Previous reply"},
            {"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]},
            {"role": "tool", "content": "result1", "tool_call_id": "1"},
            {"role": "user", "content": "Next"},
        ]
        result = LLMProvider._enforce_role_alternation(msgs)
        assert result[1]["role"] == "assistant"
        assert result[1]["tool_calls"] == [{"id": "1"}]
        assert result[1]["content"] is None
        assert result[2]["role"] == "tool"

    def test_consecutive_assistant_does_not_overwrite_existing_tool_call_message(self):
        msgs = [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]},
            {"role": "assistant", "content": "Later plain assistant"},
            {"role": "tool", "content": "result1", "tool_call_id": "1"},
            {"role": "user", "content": "Next"},
        ]
        result = LLMProvider._enforce_role_alternation(msgs)
        assert result[1]["role"] == "assistant"
        assert result[1]["tool_calls"] == [{"id": "1"}]
        assert result[1]["content"] is None
        assert result[2]["role"] == "tool"

    def test_non_string_content_uses_latest(self):
        msgs = [
            {"role": "user", "content": [{"type": "text", "text": "A"}]},
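The behavior these two tests pin down is: when two assistant messages appear back to back, the merged message must be the one carrying `tool_calls`, regardless of which came first, so the following `tool` message keeps a matching `tool_call_id`. A minimal sketch of that merge rule (an assumption inferred from the assertions above; the real `LLMProvider._enforce_role_alternation` may differ in detail):

```python
def enforce_role_alternation(msgs):
    """Collapse consecutive assistant messages, preferring the one with tool_calls."""
    out = []
    for msg in msgs:
        if out and msg["role"] == "assistant" and out[-1]["role"] == "assistant":
            # Keep whichever message carries tool_calls, so the following
            # "tool" message still has a matching tool_call_id.
            if msg.get("tool_calls") or not out[-1].get("tool_calls"):
                out[-1] = msg
            continue
        out.append(msg)
    return out


msgs = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Previous reply"},
    {"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]},
    {"role": "tool", "content": "result1", "tool_call_id": "1"},
]
merged = enforce_role_alternation(msgs)
print(merged[1].get("tool_calls"))  # → [{'id': '1'}]
```

Flipping the order of the two assistant messages leaves the tool-call message in place as well, matching the "does not overwrite" test.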
@ -550,11 +550,40 @@ def test_openai_compat_preserves_message_level_reasoning_fields() -> None:
        {"role": "user", "content": "thanks"},
    ])

    assert sanitized[1]["content"] is None
    assert sanitized[1]["reasoning_content"] == "hidden"
    assert sanitized[1]["extra_content"] == {"debug": True}
    assert sanitized[1]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}


def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -> None:
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        provider = OpenAICompatProvider()

    sanitized = provider._sanitize_messages([
        {"role": "user", "content": "Nice."},
        {"role": "assistant", "content": "Yes, breaking 40k is within reach."},
        {
            "role": "assistant",
            "content": "<think>Let me check again.</think>",
            "tool_calls": [
                {
                    "id": "call_function_akxp3wqzn7ph_1",
                    "type": "function",
                    "function": {"name": "exec", "arguments": "{}"},
                }
            ],
        },
        {"role": "tool", "tool_call_id": "call_function_akxp3wqzn7ph_1", "name": "exec", "content": "ok"},
        {"role": "user", "content": "How many stars now?"},
    ])

    assert sanitized[1]["role"] == "assistant"
    assert sanitized[1]["content"] is None
    assert sanitized[1]["tool_calls"][0]["id"] == "3ec83c30d"
    assert sanitized[2]["tool_call_id"] == "3ec83c30d"


@pytest.mark.asyncio
async def test_openai_compat_stream_watchdog_returns_error_on_stall(monkeypatch) -> None:
    monkeypatch.setenv("NANOBOT_STREAM_IDLE_TIMEOUT_S", "0")
31
tests/test_truncate_text_shadowing.py
Normal file
@ -0,0 +1,31 @@
import inspect
from types import SimpleNamespace


def test_sanitize_persisted_blocks_truncate_text_shadowing_regression() -> None:
    """Regression: avoid bool param shadowing imported truncate_text.

    Buggy behavior (historical):
    - loop.py imports `truncate_text` from helpers
    - `_sanitize_persisted_blocks(..., truncate_text: bool=...)` uses the same name
    - when called with `truncate_text=True`, the function body executes `truncate_text(text, ...)`,
      which resolves to the bool and raises `TypeError: 'bool' object is not callable`.

    This test asserts the fixed API exists and truncation works without raising.
    """

    from nanobot.agent.loop import AgentLoop

    sig = inspect.signature(AgentLoop._sanitize_persisted_blocks)
    assert "should_truncate_text" in sig.parameters
    assert "truncate_text" not in sig.parameters

    dummy = SimpleNamespace(max_tool_result_chars=5)
    content = [{"type": "text", "text": "0123456789"}]

    out = AgentLoop._sanitize_persisted_blocks(dummy, content, should_truncate_text=True)
    assert isinstance(out, list)
    assert out and out[0]["type"] == "text"
    assert isinstance(out[0]["text"], str)
    assert out[0]["text"] != content[0]["text"]
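The shadowing pitfall this regression test describes can be reproduced in a few lines, independent of nanobot. The helper names below are stand-ins for the ones in loop.py:

```python
def truncate_text(text, limit):
    """Stand-in for the module-level helper imported by loop.py."""
    return text[:limit]


def sanitize_buggy(text, truncate_text=True):
    # The bool parameter shadows the module-level function of the same name...
    if truncate_text:
        return truncate_text(text, 5)  # ...so this call hits the bool, not the helper
    return text


def sanitize_fixed(text, should_truncate_text=True):
    # Renaming the flag lets the name resolve to the real helper again.
    if should_truncate_text:
        return truncate_text(text, 5)
    return text


try:
    sanitize_buggy("0123456789")
    error = None
except TypeError as exc:
    error = exc

fixed = sanitize_fixed("0123456789")
print(type(error).__name__, fixed)  # → TypeError 01234
```

Function parameters live in the local scope, so inside `sanitize_buggy` the name `truncate_text` can never reach the module-level function; the signature check in the test above guards against reintroducing that collision.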
423
tests/tools/test_edit_advanced.py
Normal file
@ -0,0 +1,423 @@
"""Tests for advanced EditFileTool enhancements inspired by claude-code:
- Delete-line newline cleanup
- Smart quote normalization (curly ↔ straight)
- Quote style preservation in replacements
- Indentation preservation when fallback match is trimmed
- Trailing whitespace stripping for new_text
- File size protection
- Stale detection with content-equality fallback
"""

import os
import time

import pytest

from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, _find_match
from nanobot.agent.tools import file_state


@pytest.fixture(autouse=True)
def _clear_file_state():
    file_state.clear()
    yield
    file_state.clear()


# ---------------------------------------------------------------------------
# Delete-line newline cleanup
# ---------------------------------------------------------------------------


class TestDeleteLineCleanup:
    """When new_text='' and deleting a line, the trailing newline should be consumed."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_delete_line_consumes_trailing_newline(self, tool, tmp_path):
        f = tmp_path / "a.py"
        f.write_text("line1\nline2\nline3\n", encoding="utf-8")
        result = await tool.execute(path=str(f), old_text="line2", new_text="")
        assert "Successfully" in result
        content = f.read_text()
        # Should not leave a blank line where line2 was
        assert content == "line1\nline3\n"

    @pytest.mark.asyncio
    async def test_delete_line_with_explicit_newline_in_old_text(self, tool, tmp_path):
        f = tmp_path / "a.py"
        f.write_text("line1\nline2\nline3\n", encoding="utf-8")
        result = await tool.execute(path=str(f), old_text="line2\n", new_text="")
        assert "Successfully" in result
        assert f.read_text() == "line1\nline3\n"

    @pytest.mark.asyncio
    async def test_delete_preserves_content_when_not_trailing_newline(self, tool, tmp_path):
        """Deleting a word mid-line should not consume extra characters."""
        f = tmp_path / "a.py"
        f.write_text("hello world here\n", encoding="utf-8")
        result = await tool.execute(path=str(f), old_text="world ", new_text="")
        assert "Successfully" in result
        assert f.read_text() == "hello here\n"


# ---------------------------------------------------------------------------
# Smart quote normalization
# ---------------------------------------------------------------------------


class TestSmartQuoteNormalization:
    """_find_match should handle curly ↔ straight quote fallback."""

    def test_curly_double_quotes_match_straight(self):
        content = 'She said \u201chello\u201d to him'
        old_text = 'She said "hello" to him'
        match, count = _find_match(content, old_text)
        assert match is not None
        assert count == 1
        # Returned match should be the ORIGINAL content with curly quotes
        assert "\u201c" in match

    def test_curly_single_quotes_match_straight(self):
        content = "it\u2019s a test"
        old_text = "it's a test"
        match, count = _find_match(content, old_text)
        assert match is not None
        assert count == 1
        assert "\u2019" in match

    def test_straight_matches_curly_in_old_text(self):
        content = 'x = "hello"'
        old_text = 'x = \u201chello\u201d'
        match, count = _find_match(content, old_text)
        assert match is not None
        assert count == 1

    def test_exact_match_still_preferred_over_quote_normalization(self):
        content = 'x = "hello"'
        old_text = 'x = "hello"'
        match, count = _find_match(content, old_text)
        assert match == old_text
        assert count == 1
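The quote-normalized fallback these tests exercise boils down to: try an exact match first, then compare with curly quotes mapped to straight ones, and on success return the span from the *original* text so the caller replaces what is actually in the file. A self-contained sketch of that idea (this is an illustration, not the repo's `_find_match`; it relies on the quote map being one character to one character so offsets line up):

```python
# Map curly quotes to their straight ASCII equivalents (1:1, so offsets are stable).
_QUOTE_MAP = str.maketrans({"\u201c": '"', "\u201d": '"', "\u2018": "'", "\u2019": "'"})


def find_match(content, old_text):
    """Return (matched_original_span, count); exact match first, then quote-normalized."""
    exact = content.count(old_text)
    if exact:
        return old_text, exact
    norm_content = content.translate(_QUOTE_MAP)
    norm_old = old_text.translate(_QUOTE_MAP)
    count = norm_content.count(norm_old)
    if count == 0:
        return None, 0
    # Recover the original (possibly curly-quoted) span from the source text.
    start = norm_content.index(norm_old)
    return content[start:start + len(norm_old)], count


match, count = find_match("She said \u201chello\u201d to him", 'She said "hello" to him')
print(count, "\u201c" in match)  # → 1 True
```

Returning the original span is what makes quote-style preservation (tested next) possible: the replacement step can diff the straight-quoted `old_text` against the curly-quoted span it actually matched.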
class TestQuoteStylePreservation:
    """When quote-normalized matching occurs, replacement should preserve actual quote style."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_replacement_preserves_curly_double_quotes(self, tool, tmp_path):
        f = tmp_path / "quotes.txt"
        f.write_text('message = “hello”\n', encoding="utf-8")
        result = await tool.execute(
            path=str(f),
            old_text='message = "hello"',
            new_text='message = "goodbye"',
        )
        assert "Successfully" in result
        assert f.read_text(encoding="utf-8") == 'message = “goodbye”\n'

    @pytest.mark.asyncio
    async def test_replacement_preserves_curly_apostrophe(self, tool, tmp_path):
        f = tmp_path / "apostrophe.txt"
        f.write_text("it’s fine\n", encoding="utf-8")
        result = await tool.execute(
            path=str(f),
            old_text="it's fine",
            new_text="it's better",
        )
        assert "Successfully" in result
        assert f.read_text(encoding="utf-8") == "it’s better\n"


# ---------------------------------------------------------------------------
# Indentation preservation
# ---------------------------------------------------------------------------


class TestIndentationPreservation:
    """Replacement should keep outer indentation when trim fallback matched."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_trim_fallback_preserves_outer_indentation(self, tool, tmp_path):
        f = tmp_path / "indent.py"
        f.write_text(
            "if True:\n"
            "    def foo():\n"
            "        pass\n",
            encoding="utf-8",
        )
        result = await tool.execute(
            path=str(f),
            old_text="def foo():\n    pass",
            new_text="def bar():\n    return 1",
        )
        assert "Successfully" in result
        assert f.read_text(encoding="utf-8") == (
            "if True:\n"
            "    def bar():\n"
            "        return 1\n"
        )


# ---------------------------------------------------------------------------
# Failure diagnostics
# ---------------------------------------------------------------------------


class TestEditDiagnostics:
    """Failure paths should offer actionable hints."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_ambiguous_match_reports_candidate_lines(self, tool, tmp_path):
        f = tmp_path / "dup.py"
        f.write_text("aaa\nbbb\naaa\nbbb\n", encoding="utf-8")
        result = await tool.execute(path=str(f), old_text="aaa\nbbb", new_text="xxx")
        assert "appears 2 times" in result.lower()
        assert "line 1" in result.lower()
        assert "line 3" in result.lower()
        assert "replace_all=true" in result

    @pytest.mark.asyncio
    async def test_not_found_reports_whitespace_hint(self, tool, tmp_path):
        f = tmp_path / "space.py"
        f.write_text("value  = 1\n", encoding="utf-8")
        result = await tool.execute(path=str(f), old_text="value = 1", new_text="value = 2")
        assert "Error" in result
        assert "whitespace" in result.lower()

    @pytest.mark.asyncio
    async def test_not_found_reports_case_hint(self, tool, tmp_path):
        f = tmp_path / "case.py"
        f.write_text("HelloWorld\n", encoding="utf-8")
        result = await tool.execute(path=str(f), old_text="helloworld", new_text="goodbye")
        assert "Error" in result
        assert "letter case differs" in result.lower()


# ---------------------------------------------------------------------------
# Advanced fallback replacement behavior
# ---------------------------------------------------------------------------


class TestAdvancedReplaceAll:
    """replace_all should work correctly for fallback-based matches too."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_replace_all_preserves_each_match_indentation(self, tool, tmp_path):
        f = tmp_path / "indent_multi.py"
        f.write_text(
            "if a:\n"
            "    def foo():\n"
            "        pass\n"
            "if b:\n"
            "    def foo():\n"
            "        pass\n",
            encoding="utf-8",
        )
        result = await tool.execute(
            path=str(f),
            old_text="def foo():\n    pass",
            new_text="def bar():\n    return 1",
            replace_all=True,
        )
        assert "Successfully" in result
        assert f.read_text(encoding="utf-8") == (
            "if a:\n"
            "    def bar():\n"
            "        return 1\n"
            "if b:\n"
            "    def bar():\n"
            "        return 1\n"
        )

    @pytest.mark.asyncio
    async def test_trim_and_quote_fallback_match_succeeds(self, tool, tmp_path):
        f = tmp_path / "quote_indent.py"
        f.write_text("    message = “hello”\n", encoding="utf-8")
        result = await tool.execute(
            path=str(f),
            old_text='message = "hello"',
            new_text='message = "goodbye"',
        )
        assert "Successfully" in result
        assert f.read_text(encoding="utf-8") == "    message = “goodbye”\n"


# ---------------------------------------------------------------------------
# Trailing whitespace stripping on new_text
# ---------------------------------------------------------------------------


class TestTrailingWhitespaceStrip:
    """new_text trailing whitespace should be stripped (except .md files)."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_strips_trailing_whitespace_from_new_text(self, tool, tmp_path):
        f = tmp_path / "a.py"
        f.write_text("x = 1\n", encoding="utf-8")
        result = await tool.execute(
            path=str(f), old_text="x = 1", new_text="x = 2 \ny = 3 ",
        )
        assert "Successfully" in result
        content = f.read_text()
        assert "x = 2\ny = 3\n" == content

    @pytest.mark.asyncio
    async def test_preserves_trailing_whitespace_in_markdown(self, tool, tmp_path):
        f = tmp_path / "doc.md"
        f.write_text("# Title\n", encoding="utf-8")
        # Markdown uses trailing double-space for line breaks
        result = await tool.execute(
            path=str(f), old_text="# Title", new_text="# Title  \nSubtitle  ",
        )
        assert "Successfully" in result
        content = f.read_text()
        # Trailing spaces should be preserved for markdown
        assert "Title  " in content
        assert "Subtitle  " in content


# ---------------------------------------------------------------------------
# File size protection
# ---------------------------------------------------------------------------


class TestFileSizeProtection:
    """Editing extremely large files should be rejected."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_rejects_file_over_size_limit(self, tool, tmp_path):
        f = tmp_path / "huge.txt"
        f.write_text("x", encoding="utf-8")
        # Monkey-patch the file size check by creating a stat mock
        original_stat = f.stat

        class FakeStat:
            def __init__(self, real_stat):
                self._real = real_stat

            def __getattr__(self, name):
                return getattr(self._real, name)

            @property
            def st_size(self):
                return 2 * 1024 * 1024 * 1024  # 2 GiB

        import unittest.mock
        with unittest.mock.patch.object(type(f), 'stat', return_value=FakeStat(f.stat())):
            result = await tool.execute(path=str(f), old_text="x", new_text="y")
        assert "Error" in result
        assert "too large" in result.lower() or "size" in result.lower()


# ---------------------------------------------------------------------------
# Stale detection with content-equality fallback
# ---------------------------------------------------------------------------


class TestStaleDetectionContentFallback:
    """When mtime changed but file content is unchanged, edit should proceed without warning."""

    @pytest.fixture()
    def read_tool(self, tmp_path):
        return ReadFileTool(workspace=tmp_path)

    @pytest.fixture()
    def edit_tool(self, tmp_path):
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_mtime_bump_same_content_no_warning(self, read_tool, edit_tool, tmp_path):
        f = tmp_path / "a.py"
        f.write_text("hello world", encoding="utf-8")
        await read_tool.execute(path=str(f))

        # Touch the file to bump mtime without changing content
        time.sleep(0.05)
        original_content = f.read_text()
        f.write_text(original_content, encoding="utf-8")

        result = await edit_tool.execute(path=str(f), old_text="world", new_text="earth")
        assert "Successfully" in result
        # Should NOT warn about modification since content is the same
        assert "modified" not in result.lower()

152
tests/tools/test_edit_enhancements.py
Normal file
@@ -0,0 +1,152 @@
"""Tests for EditFileTool enhancements: read-before-edit tracking, path suggestions,
.ipynb detection, and create-file semantics."""

import pytest

from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, WriteFileTool
from nanobot.agent.tools import file_state


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

@pytest.fixture(autouse=True)
def _clear_file_state():
    """Reset global read-state between tests."""
    file_state.clear()
    yield
    file_state.clear()


# ---------------------------------------------------------------------------
# Read-before-edit tracking
# ---------------------------------------------------------------------------

class TestEditReadTracking:
    """edit_file should warn when file hasn't been read first."""

    @pytest.fixture()
    def read_tool(self, tmp_path):
        return ReadFileTool(workspace=tmp_path)

    @pytest.fixture()
    def edit_tool(self, tmp_path):
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_edit_warns_if_file_not_read_first(self, edit_tool, tmp_path):
        f = tmp_path / "a.py"
        f.write_text("hello world", encoding="utf-8")
        result = await edit_tool.execute(path=str(f), old_text="world", new_text="earth")
        # Should still succeed but include a warning
        assert "Successfully" in result
        assert "not been read" in result.lower() or "warning" in result.lower()

    @pytest.mark.asyncio
    async def test_edit_succeeds_cleanly_after_read(self, read_tool, edit_tool, tmp_path):
        f = tmp_path / "a.py"
        f.write_text("hello world", encoding="utf-8")
        await read_tool.execute(path=str(f))
        result = await edit_tool.execute(path=str(f), old_text="world", new_text="earth")
        assert "Successfully" in result
        # No warning when file was read first
        assert "not been read" not in result.lower()
        assert f.read_text() == "hello earth"

    @pytest.mark.asyncio
    async def test_edit_warns_if_file_modified_since_read(self, read_tool, edit_tool, tmp_path):
        f = tmp_path / "a.py"
        f.write_text("hello world", encoding="utf-8")
        await read_tool.execute(path=str(f))
        # External modification
        f.write_text("hello universe", encoding="utf-8")
        result = await edit_tool.execute(path=str(f), old_text="universe", new_text="earth")
        assert "Successfully" in result
        assert "modified" in result.lower() or "warning" in result.lower()


||||
# ---------------------------------------------------------------------------
|
||||
# Create-file semantics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEditCreateFile:
|
||||
"""edit_file with old_text='' creates new file if not exists."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_new_file_with_empty_old_text(self, tool, tmp_path):
|
||||
f = tmp_path / "subdir" / "new.py"
|
||||
result = await tool.execute(path=str(f), old_text="", new_text="print('hi')")
|
||||
assert "created" in result.lower() or "Successfully" in result
|
||||
assert f.exists()
|
||||
assert f.read_text() == "print('hi')"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_fails_if_file_already_exists_and_not_empty(self, tool, tmp_path):
|
||||
f = tmp_path / "existing.py"
|
||||
f.write_text("existing content", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="", new_text="new content")
|
||||
assert "Error" in result or "already exists" in result.lower()
|
||||
# File should be unchanged
|
||||
assert f.read_text() == "existing content"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_succeeds_if_file_exists_but_empty(self, tool, tmp_path):
|
||||
f = tmp_path / "empty.py"
|
||||
f.write_text("", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="", new_text="print('hi')")
|
||||
assert "Successfully" in result
|
||||
assert f.read_text() == "print('hi')"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# .ipynb detection
# ---------------------------------------------------------------------------

class TestEditIpynbDetection:
    """edit_file should refuse .ipynb and suggest notebook_edit."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_ipynb_rejected_with_suggestion(self, tool, tmp_path):
        f = tmp_path / "analysis.ipynb"
        f.write_text('{"cells": []}', encoding="utf-8")
        result = await tool.execute(path=str(f), old_text="x", new_text="y")
        assert "notebook" in result.lower()


# ---------------------------------------------------------------------------
# Path suggestion on not-found
# ---------------------------------------------------------------------------

class TestEditPathSuggestion:
    """edit_file should suggest similar paths on not-found."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_suggests_similar_filename(self, tool, tmp_path):
        f = tmp_path / "config.py"
        f.write_text("x = 1", encoding="utf-8")
        # Typo: conifg.py
        result = await tool.execute(
            path=str(tmp_path / "conifg.py"), old_text="x = 1", new_text="x = 2",
        )
        assert "Error" in result
        assert "config.py" in result

    @pytest.mark.asyncio
    async def test_shows_cwd_in_error(self, tool, tmp_path):
        result = await tool.execute(
            path=str(tmp_path / "nonexistent.py"), old_text="a", new_text="b",
        )
        assert "Error" in result
@@ -43,3 +43,34 @@ async def test_exec_path_append_preserves_system_path():
    tool = ExecTool(path_append="/opt/custom/bin")
    result = await tool.execute(command="ls /")
    assert "Exit code: 0" in result


@_UNIX_ONLY
@pytest.mark.asyncio
async def test_exec_allowed_env_keys_passthrough(monkeypatch):
    """Env vars listed in allowed_env_keys should be visible to commands."""
    monkeypatch.setenv("MY_CUSTOM_VAR", "hello-from-config")
    tool = ExecTool(allowed_env_keys=["MY_CUSTOM_VAR"])
    result = await tool.execute(command="printenv MY_CUSTOM_VAR")
    assert "hello-from-config" in result


@_UNIX_ONLY
@pytest.mark.asyncio
async def test_exec_allowed_env_keys_does_not_leak_others(monkeypatch):
    """Env vars NOT in allowed_env_keys should still be blocked."""
    monkeypatch.setenv("MY_CUSTOM_VAR", "hello-from-config")
    monkeypatch.setenv("MY_SECRET_VAR", "secret-value")
    tool = ExecTool(allowed_env_keys=["MY_CUSTOM_VAR"])
    result = await tool.execute(command="printenv MY_SECRET_VAR")
    assert "secret-value" not in result


@_UNIX_ONLY
@pytest.mark.asyncio
async def test_exec_allowed_env_keys_missing_var_ignored(monkeypatch):
    """If an allowed key is not set in the parent process, it should be silently skipped."""
    monkeypatch.delenv("NONEXISTENT_VAR_12345", raising=False)
    tool = ExecTool(allowed_env_keys=["NONEXISTENT_VAR_12345"])
    result = await tool.execute(command="printenv NONEXISTENT_VAR_12345")
    assert "Exit code: 1" in result
@@ -271,15 +271,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_raw_names(
 ) -> None:
     fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
     registry = ToolRegistry()
-    stack = AsyncExitStack()
-    await stack.__aenter__()
-    try:
-        await connect_mcp_servers(
-            {"test": MCPServerConfig(command="fake", enabled_tools=["demo"])},
-            registry,
-            stack,
-        )
-    finally:
-        await stack.aclose()
+    stacks = await connect_mcp_servers(
+        {"test": MCPServerConfig(command="fake", enabled_tools=["demo"])},
+        registry,
+    )
+    for stack in stacks.values():
+        await stack.aclose()
 
     assert registry.tool_names == ["mcp_test_demo"]
@@ -291,15 +287,11 @@ async def test_connect_mcp_servers_enabled_tools_defaults_to_all(
 ) -> None:
     fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
     registry = ToolRegistry()
-    stack = AsyncExitStack()
-    await stack.__aenter__()
-    try:
-        await connect_mcp_servers(
-            {"test": MCPServerConfig(command="fake")},
-            registry,
-            stack,
-        )
-    finally:
-        await stack.aclose()
+    stacks = await connect_mcp_servers(
+        {"test": MCPServerConfig(command="fake")},
+        registry,
+    )
+    for stack in stacks.values():
+        await stack.aclose()
 
     assert registry.tool_names == ["mcp_test_demo", "mcp_test_other"]
@@ -311,15 +303,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_wrapped_names(
 ) -> None:
     fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
     registry = ToolRegistry()
-    stack = AsyncExitStack()
-    await stack.__aenter__()
-    try:
-        await connect_mcp_servers(
-            {"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])},
-            registry,
-            stack,
-        )
-    finally:
-        await stack.aclose()
+    stacks = await connect_mcp_servers(
+        {"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])},
+        registry,
+    )
+    for stack in stacks.values():
+        await stack.aclose()
 
     assert registry.tool_names == ["mcp_test_demo"]
@@ -331,15 +319,11 @@ async def test_connect_mcp_servers_enabled_tools_empty_list_registers_none(
 ) -> None:
     fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
     registry = ToolRegistry()
-    stack = AsyncExitStack()
-    await stack.__aenter__()
-    try:
-        await connect_mcp_servers(
-            {"test": MCPServerConfig(command="fake", enabled_tools=[])},
-            registry,
-            stack,
-        )
-    finally:
-        await stack.aclose()
+    stacks = await connect_mcp_servers(
+        {"test": MCPServerConfig(command="fake", enabled_tools=[])},
+        registry,
+    )
+    for stack in stacks.values():
+        await stack.aclose()
 
     assert registry.tool_names == []
@@ -358,15 +342,11 @@ async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries(
 
     monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning)
 
-    stack = AsyncExitStack()
-    await stack.__aenter__()
-    try:
-        await connect_mcp_servers(
-            {"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])},
-            registry,
-            stack,
-        )
-    finally:
-        await stack.aclose()
+    stacks = await connect_mcp_servers(
+        {"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])},
+        registry,
+    )
+    for stack in stacks.values():
+        await stack.aclose()
 
     assert registry.tool_names == []
@@ -376,6 +356,46 @@ async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries(
    assert "Available wrapped names: mcp_test_demo" in warnings[-1]


@pytest.mark.asyncio
async def test_connect_mcp_servers_one_failure_does_not_block_others(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    sessions = {"good": _make_fake_session(["demo"])}

    class _SelectiveClientSession:
        def __init__(self, read: object, _write: object) -> None:
            self._session = sessions[read]

        async def __aenter__(self) -> object:
            return self._session

        async def __aexit__(self, exc_type, exc, tb) -> bool:
            return False

    @asynccontextmanager
    async def _selective_stdio_client(params: object):
        if params.command == "bad":
            raise RuntimeError("boom")
        yield params.command, object()

    monkeypatch.setattr(sys.modules["mcp"], "ClientSession", _SelectiveClientSession)
    monkeypatch.setattr(sys.modules["mcp.client.stdio"], "stdio_client", _selective_stdio_client)

    registry = ToolRegistry()
    stacks = await connect_mcp_servers(
        {
            "good": MCPServerConfig(command="good"),
            "bad": MCPServerConfig(command="bad"),
        },
        registry,
    )
    for stack in stacks.values():
        await stack.aclose()

    assert registry.tool_names == ["mcp_good_demo"]
    assert set(stacks) == {"good"}


# ---------------------------------------------------------------------------
# MCPResourceWrapper tests
# ---------------------------------------------------------------------------
@@ -389,9 +409,7 @@ def _make_resource_def(
     return SimpleNamespace(name=name, uri=uri, description=description)
 
 
-def _make_resource_wrapper(
-    session: object, *, timeout: float = 0.1
-) -> MCPResourceWrapper:
+def _make_resource_wrapper(session: object, *, timeout: float = 0.1) -> MCPResourceWrapper:
     return MCPResourceWrapper(session, "srv", _make_resource_def(), resource_timeout=timeout)
@@ -434,9 +452,7 @@ async def test_resource_wrapper_execute_handles_timeout() -> None:
         await asyncio.sleep(1)
         return SimpleNamespace(contents=[])
 
-    wrapper = _make_resource_wrapper(
-        SimpleNamespace(read_resource=read_resource), timeout=0.01
-    )
+    wrapper = _make_resource_wrapper(SimpleNamespace(read_resource=read_resource), timeout=0.01)
     result = await wrapper.execute()
     assert result == "(MCP resource read timed out after 0.01s)"
@@ -464,20 +480,14 @@ def _make_prompt_def(
     return SimpleNamespace(name=name, description=description, arguments=arguments)
 
 
-def _make_prompt_wrapper(
-    session: object, *, timeout: float = 0.1
-) -> MCPPromptWrapper:
-    return MCPPromptWrapper(
-        session, "srv", _make_prompt_def(), prompt_timeout=timeout
-    )
+def _make_prompt_wrapper(session: object, *, timeout: float = 0.1) -> MCPPromptWrapper:
+    return MCPPromptWrapper(session, "srv", _make_prompt_def(), prompt_timeout=timeout)
 
 
 def test_prompt_wrapper_properties() -> None:
     arg1 = SimpleNamespace(name="topic", required=True)
     arg2 = SimpleNamespace(name="style", required=False)
-    wrapper = MCPPromptWrapper(
-        None, "myserver", _make_prompt_def(arguments=[arg1, arg2])
-    )
+    wrapper = MCPPromptWrapper(None, "myserver", _make_prompt_def(arguments=[arg1, arg2]))
     assert wrapper.name == "mcp_myserver_prompt_myprompt"
     assert "[MCP Prompt]" in wrapper.description
     assert "A test prompt" in wrapper.description
@@ -528,9 +538,7 @@ async def test_prompt_wrapper_execute_handles_timeout() -> None:
         await asyncio.sleep(1)
         return SimpleNamespace(messages=[])
 
-    wrapper = _make_prompt_wrapper(
-        SimpleNamespace(get_prompt=get_prompt), timeout=0.01
-    )
+    wrapper = _make_prompt_wrapper(SimpleNamespace(get_prompt=get_prompt), timeout=0.01)
     result = await wrapper.execute()
     assert result == "(MCP prompt call timed out after 0.01s)"
@@ -616,15 +624,11 @@ async def test_connect_registers_resources_and_prompts(
         prompt_names=["prompt_c"],
     )
     registry = ToolRegistry()
-    stack = AsyncExitStack()
-    await stack.__aenter__()
-    try:
-        await connect_mcp_servers(
-            {"test": MCPServerConfig(command="fake")},
-            registry,
-            stack,
-        )
-    finally:
-        await stack.aclose()
+    stacks = await connect_mcp_servers(
+        {"test": MCPServerConfig(command="fake")},
+        registry,
+    )
+    for stack in stacks.values():
+        await stack.aclose()
 
     assert "mcp_test_tool_a" in registry.tool_names
@@ -1,5 +1,6 @@
 """Test message tool suppress logic for final replies."""
 
+import asyncio
 from pathlib import Path
 from unittest.mock import AsyncMock, MagicMock
@@ -86,6 +87,42 @@ class TestMessageToolSuppressLogic:
        assert result is not None
        assert "Hello" in result.content

    @pytest.mark.asyncio
    async def test_injected_followup_with_message_tool_does_not_emit_empty_fallback(
        self, tmp_path: Path
    ) -> None:
        loop = _make_loop(tmp_path)
        tool_call = ToolCallRequest(
            id="call1", name="message",
            arguments={"content": "Tool reply", "channel": "feishu", "chat_id": "chat123"},
        )
        calls = iter([
            LLMResponse(content="First answer", tool_calls=[]),
            LLMResponse(content="", tool_calls=[tool_call]),
            LLMResponse(content="", tool_calls=[]),
            LLMResponse(content="", tool_calls=[]),
            LLMResponse(content="", tool_calls=[]),
        ])
        loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
        loop.tools.get_definitions = MagicMock(return_value=[])

        sent: list[OutboundMessage] = []
        mt = loop.tools.get("message")
        if isinstance(mt, MessageTool):
            mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))

        pending_queue = asyncio.Queue()
        await pending_queue.put(
            InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="follow-up")
        )

        msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Start")
        result = await loop._process_message(msg, pending_queue=pending_queue)

        assert len(sent) == 1
        assert sent[0].content == "Tool reply"
        assert result is None

    async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
        loop = _make_loop(tmp_path)
        tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
@@ -107,7 +144,7 @@ class TestMessageToolSuppressLogic:
         async def on_progress(content: str, *, tool_hint: bool = False) -> None:
             progress.append((content, tool_hint))
 
-        final_content, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
+        final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
 
         assert final_content == "Done"
         assert progress == [
147
tests/tools/test_notebook_tool.py
Normal file
@@ -0,0 +1,147 @@
"""Tests for NotebookEditTool — Jupyter .ipynb editing."""

import json

import pytest

from nanobot.agent.tools.notebook import NotebookEditTool


def _make_notebook(cells: list[dict] | None = None, nbformat: int = 4, nbformat_minor: int = 5) -> dict:
    """Build a minimal valid .ipynb structure."""
    return {
        "nbformat": nbformat,
        "nbformat_minor": nbformat_minor,
        "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}},
        "cells": cells or [],
    }


def _code_cell(source: str, cell_id: str | None = None) -> dict:
    cell = {"cell_type": "code", "source": source, "metadata": {}, "outputs": [], "execution_count": None}
    if cell_id:
        cell["id"] = cell_id
    return cell


def _md_cell(source: str, cell_id: str | None = None) -> dict:
    cell = {"cell_type": "markdown", "source": source, "metadata": {}}
    if cell_id:
        cell["id"] = cell_id
    return cell


def _write_nb(tmp_path, name: str, nb: dict) -> str:
    p = tmp_path / name
    p.write_text(json.dumps(nb), encoding="utf-8")
    return str(p)


class TestNotebookEdit:

    @pytest.fixture()
    def tool(self, tmp_path):
        return NotebookEditTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_replace_cell_content(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("print('hello')"), _code_cell("x = 1")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        result = await tool.execute(path=path, cell_index=0, new_source="print('world')")
        assert "Successfully" in result
        saved = json.loads((tmp_path / "test.ipynb").read_text())
        assert saved["cells"][0]["source"] == "print('world')"
        assert saved["cells"][1]["source"] == "x = 1"

    @pytest.mark.asyncio
    async def test_insert_cell_after_target(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("cell 0"), _code_cell("cell 1")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        result = await tool.execute(path=path, cell_index=0, new_source="inserted", edit_mode="insert")
        assert "Successfully" in result
        saved = json.loads((tmp_path / "test.ipynb").read_text())
        assert len(saved["cells"]) == 3
        assert saved["cells"][0]["source"] == "cell 0"
        assert saved["cells"][1]["source"] == "inserted"
        assert saved["cells"][2]["source"] == "cell 1"

    @pytest.mark.asyncio
    async def test_delete_cell(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("A"), _code_cell("B"), _code_cell("C")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        result = await tool.execute(path=path, cell_index=1, edit_mode="delete")
        assert "Successfully" in result
        saved = json.loads((tmp_path / "test.ipynb").read_text())
        assert len(saved["cells"]) == 2
        assert saved["cells"][0]["source"] == "A"
        assert saved["cells"][1]["source"] == "C"

    @pytest.mark.asyncio
    async def test_create_new_notebook_from_scratch(self, tool, tmp_path):
        path = str(tmp_path / "new.ipynb")
        result = await tool.execute(path=path, cell_index=0, new_source="# Hello", edit_mode="insert", cell_type="markdown")
        assert "Successfully" in result or "created" in result.lower()
        saved = json.loads((tmp_path / "new.ipynb").read_text())
        assert saved["nbformat"] == 4
        assert len(saved["cells"]) == 1
        assert saved["cells"][0]["cell_type"] == "markdown"
        assert saved["cells"][0]["source"] == "# Hello"

    @pytest.mark.asyncio
    async def test_invalid_cell_index_error(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("only cell")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        result = await tool.execute(path=path, cell_index=5, new_source="x")
|
||||
assert "Error" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_ipynb_rejected(self, tool, tmp_path):
|
||||
f = tmp_path / "script.py"
|
||||
f.write_text("pass")
|
||||
result = await tool.execute(path=str(f), cell_index=0, new_source="x")
|
||||
assert "Error" in result
|
||||
assert ".ipynb" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserves_metadata_and_outputs(self, tool, tmp_path):
|
||||
cell = _code_cell("old")
|
||||
cell["outputs"] = [{"output_type": "stream", "text": "hello\n"}]
|
||||
cell["execution_count"] = 42
|
||||
nb = _make_notebook([cell])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
await tool.execute(path=path, cell_index=0, new_source="new")
|
||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
||||
assert saved["metadata"]["kernelspec"]["language"] == "python"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nbformat_45_generates_cell_id(self, tool, tmp_path):
|
||||
nb = _make_notebook([], nbformat_minor=5)
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
await tool.execute(path=path, cell_index=0, new_source="x = 1", edit_mode="insert")
|
||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
||||
assert "id" in saved["cells"][0]
|
||||
assert len(saved["cells"][0]["id"]) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insert_with_cell_type_markdown(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("code")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
await tool.execute(path=path, cell_index=0, new_source="# Title", edit_mode="insert", cell_type="markdown")
|
||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
||||
assert saved["cells"][1]["cell_type"] == "markdown"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_edit_mode_rejected(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("code")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
result = await tool.execute(path=path, cell_index=0, new_source="x", edit_mode="replcae")
|
||||
assert "Error" in result
|
||||
assert "edit_mode" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_cell_type_rejected(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("code")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
result = await tool.execute(path=path, cell_index=0, new_source="x", cell_type="raw")
|
||||
assert "Error" in result
|
||||
assert "cell_type" in result
|
||||
180
tests/tools/test_read_enhancements.py
Normal file
@ -0,0 +1,180 @@
"""Tests for ReadFileTool enhancements: description fix, read dedup, PDF support, device blacklist."""

import pytest

from nanobot.agent.tools.filesystem import ReadFileTool, WriteFileTool
from nanobot.agent.tools import file_state


@pytest.fixture(autouse=True)
def _clear_file_state():
    file_state.clear()
    yield
    file_state.clear()


# ---------------------------------------------------------------------------
# Description fix
# ---------------------------------------------------------------------------

class TestReadDescriptionFix:

    def test_description_mentions_image_support(self):
        tool = ReadFileTool()
        assert "image" in tool.description.lower()

    def test_description_no_longer_says_cannot_read_images(self):
        tool = ReadFileTool()
        assert "cannot read binary files or images" not in tool.description.lower()


# ---------------------------------------------------------------------------
# Read deduplication
# ---------------------------------------------------------------------------

class TestReadDedup:
    """Same file + same offset/limit + unchanged mtime -> short stub."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return ReadFileTool(workspace=tmp_path)

    @pytest.fixture()
    def write_tool(self, tmp_path):
        return WriteFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_second_read_returns_unchanged_stub(self, tool, tmp_path):
        f = tmp_path / "data.txt"
        f.write_text("\n".join(f"line {i}" for i in range(100)), encoding="utf-8")
        first = await tool.execute(path=str(f))
        assert "line 0" in first
        second = await tool.execute(path=str(f))
        assert "unchanged" in second.lower()
        # Stub should not contain file content
        assert "line 0" not in second

    @pytest.mark.asyncio
    async def test_read_after_external_modification_returns_full(self, tool, tmp_path):
        f = tmp_path / "data.txt"
        f.write_text("original", encoding="utf-8")
        await tool.execute(path=str(f))
        # Modify the file externally
        f.write_text("modified content", encoding="utf-8")
        second = await tool.execute(path=str(f))
        assert "modified content" in second

    @pytest.mark.asyncio
    async def test_different_offset_returns_full(self, tool, tmp_path):
        f = tmp_path / "data.txt"
        f.write_text("\n".join(f"line {i}" for i in range(1, 21)), encoding="utf-8")
        await tool.execute(path=str(f), offset=1, limit=5)
        second = await tool.execute(path=str(f), offset=6, limit=5)
        # Different offset → full read, not stub
        assert "line 6" in second

    @pytest.mark.asyncio
    async def test_first_read_after_write_returns_full_content(self, tool, write_tool, tmp_path):
        f = tmp_path / "fresh.txt"
        result = await write_tool.execute(path=str(f), content="hello")
        assert "Successfully" in result
        read_result = await tool.execute(path=str(f))
        assert "hello" in read_result
        assert "unchanged" not in read_result.lower()

    @pytest.mark.asyncio
    async def test_dedup_does_not_apply_to_images(self, tool, tmp_path):
        f = tmp_path / "img.png"
        f.write_bytes(b"\x89PNG\r\n\x1a\nfake-png-data")
        first = await tool.execute(path=str(f))
        assert isinstance(first, list)
        second = await tool.execute(path=str(f))
        # Images should always return full content blocks, not a stub
        assert isinstance(second, list)


# ---------------------------------------------------------------------------
# PDF support
# ---------------------------------------------------------------------------

class TestReadPdf:

    @pytest.fixture()
    def tool(self, tmp_path):
        return ReadFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_pdf_returns_text_content(self, tool, tmp_path):
        fitz = pytest.importorskip("fitz")
        pdf_path = tmp_path / "test.pdf"
        doc = fitz.open()
        page = doc.new_page()
        page.insert_text((72, 72), "Hello PDF World")
        doc.save(str(pdf_path))
        doc.close()

        result = await tool.execute(path=str(pdf_path))
        assert "Hello PDF World" in result

    @pytest.mark.asyncio
    async def test_pdf_pages_parameter(self, tool, tmp_path):
        fitz = pytest.importorskip("fitz")
        pdf_path = tmp_path / "multi.pdf"
        doc = fitz.open()
        for i in range(5):
            page = doc.new_page()
            page.insert_text((72, 72), f"Page {i + 1} content")
        doc.save(str(pdf_path))
        doc.close()

        result = await tool.execute(path=str(pdf_path), pages="2-3")
        assert "Page 2 content" in result
        assert "Page 3 content" in result
        assert "Page 1 content" not in result

    @pytest.mark.asyncio
    async def test_pdf_file_not_found_error(self, tool, tmp_path):
        result = await tool.execute(path=str(tmp_path / "nope.pdf"))
        assert "Error" in result
        assert "not found" in result


# ---------------------------------------------------------------------------
# Device path blacklist
# ---------------------------------------------------------------------------

class TestReadDeviceBlacklist:

    @pytest.fixture()
    def tool(self):
        return ReadFileTool()

    @pytest.mark.asyncio
    async def test_dev_random_blocked(self, tool):
        result = await tool.execute(path="/dev/random")
        assert "Error" in result
        assert "blocked" in result.lower() or "device" in result.lower()

    @pytest.mark.asyncio
    async def test_dev_urandom_blocked(self, tool):
        result = await tool.execute(path="/dev/urandom")
        assert "Error" in result

    @pytest.mark.asyncio
    async def test_dev_zero_blocked(self, tool):
        result = await tool.execute(path="/dev/zero")
        assert "Error" in result

    @pytest.mark.asyncio
    async def test_proc_fd_blocked(self, tool):
        result = await tool.execute(path="/proc/self/fd/0")
        assert "Error" in result

    @pytest.mark.asyncio
    async def test_symlink_to_dev_zero_blocked(self, tmp_path):
        tool = ReadFileTool(workspace=tmp_path)
        link = tmp_path / "zero-link"
        link.symlink_to("/dev/zero")
        result = await tool.execute(path=str(link))
        assert "Error" in result
        assert "blocked" in result.lower() or "device" in result.lower()
@ -323,3 +323,27 @@ async def test_subagent_registers_grep_and_glob(tmp_path: Path) -> None:
    assert "grep" in captured["tool_names"]
    assert "glob" in captured["tool_names"]


def test_subagent_prompt_respects_disabled_skills(tmp_path: Path) -> None:
    bus = MessageBus()
    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    skills_dir = tmp_path / "skills"
    (skills_dir / "alpha").mkdir(parents=True)
    (skills_dir / "alpha" / "SKILL.md").write_text("# Alpha\n\nhidden\n", encoding="utf-8")
    (skills_dir / "beta").mkdir(parents=True)
    (skills_dir / "beta" / "SKILL.md").write_text("# Beta\n\nshown\n", encoding="utf-8")

    mgr = SubagentManager(
        provider=provider,
        workspace=tmp_path,
        bus=bus,
        max_tool_result_chars=4096,
        disabled_skills=["alpha"],
    )

    prompt = mgr._build_subagent_prompt()

    assert "alpha" not in prompt
    assert "beta" in prompt
@ -120,6 +120,27 @@ async def test_jina_search(monkeypatch):
    assert "https://jina.ai" in result


@pytest.mark.asyncio
async def test_kagi_search(monkeypatch):
    async def mock_get(self, url, **kw):
        assert "kagi.com/api/v0/search" in url
        assert kw["headers"]["Authorization"] == "Bot kagi-key"
        assert kw["params"] == {"q": "test", "limit": 2}
        return _response(json={
            "data": [
                {"t": 0, "title": "Kagi Result", "url": "https://kagi.com", "snippet": "Premium search"},
                {"t": 1, "list": ["ignored related search"]},
            ]
        })

    monkeypatch.setattr(httpx.AsyncClient, "get", mock_get)
    tool = _tool(provider="kagi", api_key="kagi-key")
    result = await tool.execute(query="test", count=2)
    assert "Kagi Result" in result
    assert "https://kagi.com" in result
    assert "ignored related search" not in result


@pytest.mark.asyncio
async def test_unknown_provider():
    tool = _tool(provider="unknown")
@ -189,6 +210,23 @@ async def test_jina_422_falls_back_to_duckduckgo(monkeypatch):
    assert "DuckDuckGo fallback" in result


@pytest.mark.asyncio
async def test_kagi_fallback_to_duckduckgo_when_no_key(monkeypatch):
    class MockDDGS:
        def __init__(self, **kw):
            pass

        def text(self, query, max_results=5):
            return [{"title": "Fallback", "href": "https://ddg.example", "body": "DuckDuckGo fallback"}]

    monkeypatch.setattr("ddgs.DDGS", MockDDGS)
    monkeypatch.delenv("KAGI_API_KEY", raising=False)

    tool = _tool(provider="kagi", api_key="")
    result = await tool.execute(query="test")
    assert "Fallback" in result


@pytest.mark.asyncio
async def test_jina_search_uses_path_encoded_query(monkeypatch):
    calls = {}
65
tests/utils/test_strip_think.py
Normal file
@ -0,0 +1,65 @@
import pytest

from nanobot.utils.helpers import strip_think


class TestStripThinkTag:
    """Test <thought>...</thought> block stripping (Gemma 4 and similar models)."""

    def test_closed_tag(self):
        assert strip_think("Hello <thought>reasoning</thought> World") == "Hello World"

    def test_unclosed_trailing_tag(self):
        assert strip_think("<thought>ongoing...") == ""

    def test_multiline_tag(self):
        assert strip_think("<thought>\nline1\nline2\n</thought>End") == "End"

    def test_tag_with_nested_angle_brackets(self):
        text = "<thought>a < 3 and b > 2</thought>result"
        assert strip_think(text) == "result"

    def test_multiple_tag_blocks(self):
        text = "A<thought>x</thought>B<thought>y</thought>C"
        assert strip_think(text) == "ABC"

    def test_tag_only_whitespace_inside(self):
        assert strip_think("before<thought> </thought>after") == "beforeafter"

    def test_self_closing_tag_not_matched(self):
        assert strip_think("<thought/>some text") == "<thought/>some text"

    def test_normal_text_unchanged(self):
        assert strip_think("Just normal text") == "Just normal text"

    def test_empty_string(self):
        assert strip_think("") == ""


class TestStripThinkFalsePositive:
    """Ensure mid-content <think>/<thought> tags are NOT stripped (#3004)."""

    def test_backtick_think_tag_preserved(self):
        text = "*Think Stripping:* A new utility to strip `<think>` tags from output."
        assert strip_think(text) == text

    def test_prose_think_tag_preserved(self):
        text = "The model emits <think> at the start of its response."
        assert strip_think(text) == text

    def test_code_block_think_tag_preserved(self):
        text = "Example:\n```\ntext = re.sub(r\"<think>[\\s\\S]*\", \"\", text)\n```\nDone."
        assert strip_think(text) == text

    def test_backtick_thought_tag_preserved(self):
        text = "Gemma 4 uses `<thought>` blocks for reasoning."
        assert strip_think(text) == text

    def test_prefix_unclosed_think_still_stripped(self):
        assert strip_think("<think>reasoning without closing") == ""

    def test_prefix_unclosed_think_with_whitespace(self):
        assert strip_think(" <think>reasoning...") == ""

    def test_prefix_unclosed_thought_still_stripped(self):
        assert strip_think("<thought>reasoning without closing") == ""