mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-19 17:24:02 +00:00
Merge origin/main into feat/api-file-upload
Keep the API file upload branch current with main, enforce the documented JSON base64 per-file limit, and avoid leaking document extraction error strings into user prompts. Made-with: Cursor
This commit is contained in:
commit
2502fc616b
151
README.md
151
README.md
@ -394,7 +394,8 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso
|
|||||||
"enabled": true,
|
"enabled": true,
|
||||||
"token": "YOUR_BOT_TOKEN",
|
"token": "YOUR_BOT_TOKEN",
|
||||||
"allowFrom": ["YOUR_USER_ID"],
|
"allowFrom": ["YOUR_USER_ID"],
|
||||||
"groupPolicy": "mention"
|
"groupPolicy": "mention",
|
||||||
|
"streaming": true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -405,6 +406,7 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso
|
|||||||
> - `"open"` — Respond to all messages
|
> - `"open"` — Respond to all messages
|
||||||
> DMs always respond when the sender is in `allowFrom`.
|
> DMs always respond when the sender is in `allowFrom`.
|
||||||
> - If you set the group policy to `open`, create new threads as private threads and then @mention the bot in them. Otherwise, both the thread itself and the channel in which you spawned it will spawn a bot session.
|
> - If you set the group policy to `open`, create new threads as private threads and then @mention the bot in them. Otherwise, both the thread itself and the channel in which you spawned it will spawn a bot session.
|
||||||
|
> `streaming` defaults to `true`. Disable it only if you explicitly want non-streaming replies.
|
||||||
|
|
||||||
**5. Invite the bot**
|
**5. Invite the bot**
|
||||||
- OAuth2 → URL Generator
|
- OAuth2 → URL Generator
|
||||||
@ -558,7 +560,11 @@ Uses **WebSocket** long connection — no public IP required.
|
|||||||
"verificationToken": "",
|
"verificationToken": "",
|
||||||
"allowFrom": ["ou_YOUR_OPEN_ID"],
|
"allowFrom": ["ou_YOUR_OPEN_ID"],
|
||||||
"groupPolicy": "mention",
|
"groupPolicy": "mention",
|
||||||
"streaming": true
|
"reactEmoji": "OnIt",
|
||||||
|
"doneEmoji": "DONE",
|
||||||
|
"toolHintPrefix": "🔧",
|
||||||
|
"streaming": true,
|
||||||
|
"domain": "feishu"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -568,6 +574,10 @@ Uses **WebSocket** long connection — no public IP required.
|
|||||||
> `encryptKey` and `verificationToken` are optional for Long Connection mode.
|
> `encryptKey` and `verificationToken` are optional for Long Connection mode.
|
||||||
> `allowFrom`: Add your open_id (find it in nanobot logs when you message the bot). Use `["*"]` to allow all users.
|
> `allowFrom`: Add your open_id (find it in nanobot logs when you message the bot). Use `["*"]` to allow all users.
|
||||||
> `groupPolicy`: `"mention"` (default — respond only when @mentioned), `"open"` (respond to all group messages). Private chats always respond.
|
> `groupPolicy`: `"mention"` (default — respond only when @mentioned), `"open"` (respond to all group messages). Private chats always respond.
|
||||||
|
> `reactEmoji`: Emoji for "processing" status (default: `OnIt`). See [available emojis](https://open.larkoffice.com/document/server-docs/im-v1/message-reaction/emojis-introduce).
|
||||||
|
> `doneEmoji`: Optional emoji for "completed" status (e.g., `DONE`, `OK`, `HEART`). When set, bot adds this reaction after removing `reactEmoji`.
|
||||||
|
> `toolHintPrefix`: Prefix for inline tool hints in streaming cards (default: `🔧`).
|
||||||
|
> `domain`: `"feishu"` (default) for China (open.feishu.cn), `"lark"` for international Lark (open.larksuite.com).
|
||||||
|
|
||||||
**3. Run**
|
**3. Run**
|
||||||
|
|
||||||
@ -1043,6 +1053,30 @@ Connects directly to any OpenAI-compatible endpoint — LM Studio, llama.cpp, To
|
|||||||
```
|
```
|
||||||
|
|
||||||
> For local servers that don't require a key, set `apiKey` to any non-empty string (e.g. `"no-key"`).
|
> For local servers that don't require a key, set `apiKey` to any non-empty string (e.g. `"no-key"`).
|
||||||
|
>
|
||||||
|
> `custom` is the right choice for providers that expose an OpenAI-compatible **chat completions** API. It does **not** force third-party endpoints onto the OpenAI/Azure **Responses API**.
|
||||||
|
>
|
||||||
|
> If your proxy or gateway is specifically Responses-API-compatible, use the `azure_openai` provider shape instead and point `apiBase` at that endpoint:
|
||||||
|
>
|
||||||
|
> ```json
|
||||||
|
> {
|
||||||
|
> "providers": {
|
||||||
|
> "azure_openai": {
|
||||||
|
> "apiKey": "your-api-key",
|
||||||
|
> "apiBase": "https://api.your-provider.com",
|
||||||
|
> "defaultModel": "your-model-name"
|
||||||
|
> }
|
||||||
|
> },
|
||||||
|
> "agents": {
|
||||||
|
> "defaults": {
|
||||||
|
> "provider": "azure_openai",
|
||||||
|
> "model": "your-model-name"
|
||||||
|
> }
|
||||||
|
> }
|
||||||
|
> }
|
||||||
|
> ```
|
||||||
|
>
|
||||||
|
> In short: **chat-completions-compatible endpoint → `custom`**; **Responses-compatible endpoint → `azure_openai`**.
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
@ -1304,6 +1338,7 @@ If you need to allow trusted private ranges such as Tailscale / CGNAT addresses,
|
|||||||
| `brave` | `apiKey` | `BRAVE_API_KEY` | No |
|
| `brave` | `apiKey` | `BRAVE_API_KEY` | No |
|
||||||
| `tavily` | `apiKey` | `TAVILY_API_KEY` | No |
|
| `tavily` | `apiKey` | `TAVILY_API_KEY` | No |
|
||||||
| `jina` | `apiKey` | `JINA_API_KEY` | Free tier (10M tokens) |
|
| `jina` | `apiKey` | `JINA_API_KEY` | Free tier (10M tokens) |
|
||||||
|
| `kagi` | `apiKey` | `KAGI_API_KEY` | No |
|
||||||
| `searxng` | `baseUrl` | `SEARXNG_BASE_URL` | Yes (self-hosted) |
|
| `searxng` | `baseUrl` | `SEARXNG_BASE_URL` | Yes (self-hosted) |
|
||||||
| `duckduckgo` (default) | — | — | Yes |
|
| `duckduckgo` (default) | — | — | Yes |
|
||||||
|
|
||||||
@ -1360,6 +1395,20 @@ If you need to allow trusted private ranges such as Tailscale / CGNAT addresses,
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Kagi:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"tools": {
|
||||||
|
"web": {
|
||||||
|
"search": {
|
||||||
|
"provider": "kagi",
|
||||||
|
"apiKey": "your-kagi-api-key"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
**SearXNG** (self-hosted, no API key needed):
|
**SearXNG** (self-hosted, no API key needed):
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
@ -1495,6 +1544,35 @@ MCP tools are automatically discovered and registered on startup. The LLM can us
|
|||||||
**Docker security**: The official Docker image runs as a non-root user (`nanobot`, UID 1000) with bubblewrap pre-installed. When using `docker-compose.yml`, the container drops all Linux capabilities except `SYS_ADMIN` (required for bwrap's namespace isolation).
|
**Docker security**: The official Docker image runs as a non-root user (`nanobot`, UID 1000) with bubblewrap pre-installed. When using `docker-compose.yml`, the container drops all Linux capabilities except `SYS_ADMIN` (required for bwrap's namespace isolation).
|
||||||
|
|
||||||
|
|
||||||
|
### Auto Compact
|
||||||
|
|
||||||
|
When a user is idle for longer than a configured threshold, nanobot **proactively** compresses the older part of the session context into a summary while keeping a recent legal suffix of live messages. This reduces token cost and first-token latency when the user returns — instead of re-processing a long stale context with an expired KV cache, the model receives a compact summary, the most recent live context, and fresh input.
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"agents": {
|
||||||
|
"defaults": {
|
||||||
|
"idleCompactAfterMinutes": 15
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
| Option | Default | Description |
|
||||||
|
|--------|---------|-------------|
|
||||||
|
| `agents.defaults.idleCompactAfterMinutes` | `0` (disabled) | Minutes of idle time before auto-compaction starts. Set to `0` to disable. Recommended: `15` — close to a typical LLM KV cache expiry window, so stale sessions get compacted before the user returns. |
|
||||||
|
|
||||||
|
`sessionTtlMinutes` remains accepted as a legacy alias for backward compatibility, but `idleCompactAfterMinutes` is the preferred config key going forward.
|
||||||
|
|
||||||
|
How it works:
|
||||||
|
1. **Idle detection**: On each idle tick (~1 s), checks all sessions for expiration.
|
||||||
|
2. **Background compaction**: Idle sessions summarize the older live prefix via LLM and keep the most recent legal suffix (currently 8 messages).
|
||||||
|
3. **Summary injection**: When the user returns, the summary is injected as runtime context (one-shot, not persisted) alongside the retained recent suffix.
|
||||||
|
4. **Restart-safe resume**: The summary is also mirrored into session metadata so it can still be recovered after a process restart.
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> Think of auto compact as "summarize older context, keep the freshest live turns." It is not a hard session reset.
|
||||||
|
|
||||||
### Timezone
|
### Timezone
|
||||||
|
|
||||||
Time is context. Context should be precise.
|
Time is context. Context should be precise.
|
||||||
@ -1517,6 +1595,52 @@ Common examples: `UTC`, `America/New_York`, `America/Los_Angeles`, `Europe/Londo
|
|||||||
|
|
||||||
> Need another timezone? Browse the full [IANA Time Zone Database](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones).
|
> Need another timezone? Browse the full [IANA Time Zone Database](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones).
|
||||||
|
|
||||||
|
### Unified Session
|
||||||
|
|
||||||
|
By default, each channel × chat ID combination gets its own session. If you use nanobot across multiple channels (e.g. Telegram + Discord + CLI) and want them to share the same conversation, enable `unifiedSession`:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"agents": {
|
||||||
|
"defaults": {
|
||||||
|
"unifiedSession": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
When enabled, all incoming messages — regardless of which channel they arrive on — are routed into a single shared session. Switching from Telegram to Discord (or any other channel) continues the same conversation seamlessly.
|
||||||
|
|
||||||
|
| Behavior | `false` (default) | `true` |
|
||||||
|
|----------|-------------------|--------|
|
||||||
|
| Session key | `channel:chat_id` | `unified:default` |
|
||||||
|
| Cross-channel continuity | No | Yes |
|
||||||
|
| `/new` clears | Current channel session | Shared session |
|
||||||
|
| `/stop` finds tasks | By channel session | By shared session |
|
||||||
|
| Existing `session_key_override` (e.g. Telegram thread) | Respected | Still respected — not overwritten |
|
||||||
|
|
||||||
|
> This is designed for single-user, multi-device setups. It is **off by default** — existing users see zero behavior change.
|
||||||
|
|
||||||
|
### Disabled Skills
|
||||||
|
|
||||||
|
nanobot ships with built-in skills, and your workspace can also define custom skills under `skills/`. If you want to hide specific skills from the agent, set `agents.defaults.disabledSkills` to a list of skill directory names:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"agents": {
|
||||||
|
"defaults": {
|
||||||
|
"disabledSkills": ["github", "weather"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Disabled skills are excluded from the main agent's skill summary, from always-on skill injection, and from subagent skill summaries. This is useful when some bundled skills are unnecessary for your deployment or should not be exposed to end users.
|
||||||
|
|
||||||
|
| Option | Default | Description |
|
||||||
|
|--------|---------|-------------|
|
||||||
|
| `agents.defaults.disabledSkills` | `[]` | List of skill directory names to exclude from loading. Applies to both built-in skills and workspace skills. |
|
||||||
|
|
||||||
## 🧩 Multiple Instances
|
## 🧩 Multiple Instances
|
||||||
|
|
||||||
Run multiple nanobot instances simultaneously with separate configs and runtime data. Use `--config` as the main entrypoint. Optionally pass `--workspace` during `onboard` when you want to initialize or update the saved workspace for a specific instance.
|
Run multiple nanobot instances simultaneously with separate configs and runtime data. Use `--config` as the main entrypoint. Optionally pass `--workspace` during `onboard` when you want to initialize or update the saved workspace for a specific instance.
|
||||||
@ -1603,6 +1727,7 @@ Example config:
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"gateway": {
|
"gateway": {
|
||||||
|
"host": "127.0.0.1",
|
||||||
"port": 18790
|
"port": 18790
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1615,6 +1740,14 @@ nanobot gateway --config ~/.nanobot-telegram/config.json
|
|||||||
nanobot gateway --config ~/.nanobot-discord/config.json
|
nanobot gateway --config ~/.nanobot-discord/config.json
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Each gateway instance also exposes a lightweight HTTP health endpoint on
|
||||||
|
`gateway.host:gateway.port`. By default, the gateway binds to `127.0.0.1`,
|
||||||
|
so the endpoint stays local unless you explicitly set `gateway.host` to a
|
||||||
|
public or LAN-facing address.
|
||||||
|
|
||||||
|
- `GET /health` returns `{"status":"ok"}`
|
||||||
|
- Other paths return `404`
|
||||||
|
|
||||||
Override workspace for one-off runs when needed:
|
Override workspace for one-off runs when needed:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@ -1642,6 +1775,7 @@ time.
|
|||||||
|
|
||||||
- `memory/history.jsonl` stores append-only summarized history
|
- `memory/history.jsonl` stores append-only summarized history
|
||||||
- `SOUL.md`, `USER.md`, and `memory/MEMORY.md` store long-term knowledge managed by Dream
|
- `SOUL.md`, `USER.md`, and `memory/MEMORY.md` store long-term knowledge managed by Dream
|
||||||
|
- `Dream` can also promote repeated workflows into reusable workspace skills under `skills/`
|
||||||
- `Dream` runs on a schedule and can also be triggered manually
|
- `Dream` runs on a schedule and can also be triggered manually
|
||||||
- memory changes can be inspected and restored with built-in commands
|
- memory changes can be inspected and restored with built-in commands
|
||||||
|
|
||||||
@ -1758,6 +1892,19 @@ By default, the API binds to `127.0.0.1:8900`. You can change this in `config.js
|
|||||||
- Fixed model: omit `model`, or pass the same model shown by `/v1/models`
|
- Fixed model: omit `model`, or pass the same model shown by `/v1/models`
|
||||||
- No streaming: `stream=true` is not supported
|
- No streaming: `stream=true` is not supported
|
||||||
- **File uploads**: supports images, PDF, Word (.docx), Excel (.xlsx), PowerPoint (.pptx) via JSON base64 or `multipart/form-data` (max 10MB per file)
|
- **File uploads**: supports images, PDF, Word (.docx), Excel (.xlsx), PowerPoint (.pptx) via JSON base64 or `multipart/form-data` (max 10MB per file)
|
||||||
|
- API requests run in the synthetic `api` channel, so the `message` tool does **not** automatically deliver to Telegram/Discord/etc. To proactively send to another chat, call `message` with an explicit `channel` and `chat_id` for an enabled channel.
|
||||||
|
|
||||||
|
Example tool call for cross-channel delivery from an API session:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"content": "Build finished successfully.",
|
||||||
|
"channel": "telegram",
|
||||||
|
"chat_id": "123456789"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
If `channel` points to a channel that is not enabled in your config, nanobot will queue the outbound event but no platform delivery will occur.
|
||||||
|
|
||||||
### Endpoints
|
### Endpoints
|
||||||
|
|
||||||
|
|||||||
@ -290,7 +290,6 @@ async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] |
|
|||||||
|------|---------|
|
|------|---------|
|
||||||
| `_stream_delta: True` | A content chunk (delta contains the new text) |
|
| `_stream_delta: True` | A content chunk (delta contains the new text) |
|
||||||
| `_stream_end: True` | Streaming finished (delta is empty) |
|
| `_stream_end: True` | Streaming finished (delta is empty) |
|
||||||
| `_resuming: True` | More streaming rounds coming (e.g. tool call then another response) |
|
|
||||||
|
|
||||||
### Example: Webhook with Streaming
|
### Example: Webhook with Streaming
|
||||||
|
|
||||||
|
|||||||
331
docs/WEBSOCKET.md
Normal file
331
docs/WEBSOCKET.md
Normal file
@ -0,0 +1,331 @@
|
|||||||
|
# WebSocket Server Channel
|
||||||
|
|
||||||
|
Nanobot can act as a WebSocket server, allowing external clients (web apps, CLIs, scripts) to interact with the agent in real time via persistent connections.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- Bidirectional real-time communication over WebSocket
|
||||||
|
- Streaming support — receive agent responses token by token
|
||||||
|
- Token-based authentication (static tokens and short-lived issued tokens)
|
||||||
|
- Per-connection sessions — each connection gets a unique `chat_id`
|
||||||
|
- TLS/SSL support (WSS) with enforced TLSv1.2 minimum
|
||||||
|
- Client allow-list via `allowFrom`
|
||||||
|
- Auto-cleanup of dead connections
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### 1. Configure
|
||||||
|
|
||||||
|
Add to `config.json` under `channels.websocket`:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"channels": {
|
||||||
|
"websocket": {
|
||||||
|
"enabled": true,
|
||||||
|
"host": "127.0.0.1",
|
||||||
|
"port": 8765,
|
||||||
|
"path": "/",
|
||||||
|
"websocketRequiresToken": false,
|
||||||
|
"allowFrom": ["*"],
|
||||||
|
"streaming": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Start nanobot
|
||||||
|
|
||||||
|
```bash
|
||||||
|
nanobot gateway
|
||||||
|
```
|
||||||
|
|
||||||
|
You should see:
|
||||||
|
|
||||||
|
```
|
||||||
|
WebSocket server listening on ws://127.0.0.1:8765/
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Connect a client
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Using websocat
|
||||||
|
websocat ws://127.0.0.1:8765/?client_id=alice
|
||||||
|
|
||||||
|
```

```python
# Using Python
|
||||||
|
import asyncio, json, websockets
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
async with websockets.connect("ws://127.0.0.1:8765/?client_id=alice") as ws:
|
||||||
|
ready = json.loads(await ws.recv())
|
||||||
|
print(ready) # {"event": "ready", "chat_id": "...", "client_id": "alice"}
|
||||||
|
await ws.send(json.dumps({"content": "Hello nanobot!"}))
|
||||||
|
reply = json.loads(await ws.recv())
|
||||||
|
print(reply["text"])
|
||||||
|
|
||||||
|
asyncio.run(main())
|
||||||
|
```
|
||||||
|
|
||||||
|
## Connection URL
|
||||||
|
|
||||||
|
```
|
||||||
|
ws://{host}:{port}{path}?client_id={id}&token={token}
|
||||||
|
```
|
||||||
|
|
||||||
|
| Parameter | Required | Description |
|
||||||
|
|-----------|----------|-------------|
|
||||||
|
| `client_id` | No | Identifier for `allowFrom` authorization. Auto-generated as `anon-xxxxxxxxxxxx` if omitted. Truncated to 128 chars. |
|
||||||
|
| `token` | Conditional | Authentication token. Required when `websocketRequiresToken` is `true` or `token` (static secret) is configured. |
|
||||||
|
|
||||||
|
## Wire Protocol
|
||||||
|
|
||||||
|
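Server-to-client frames are JSON text, and each one carries an `event` field. Client-to-server frames may be either plain text or a JSON object (see [Client → Server](#client--server)).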
|
||||||
|
|
||||||
|
### Server → Client
|
||||||
|
|
||||||
|
**`ready`** — sent immediately after connection is established:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"event": "ready",
|
||||||
|
"chat_id": "uuid-v4",
|
||||||
|
"client_id": "alice"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**`message`** — full agent response:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"event": "message",
|
||||||
|
"text": "Hello! How can I help?",
|
||||||
|
"media": ["/tmp/image.png"],
|
||||||
|
"reply_to": "msg-id"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
`media` and `reply_to` are only present when applicable.
|
||||||
|
|
||||||
|
**`delta`** — streaming text chunk (only when `streaming: true`):
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"event": "delta",
|
||||||
|
"text": "Hello",
|
||||||
|
"stream_id": "s1"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**`stream_end`** — signals the end of a streaming segment:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"event": "stream_end",
|
||||||
|
"stream_id": "s1"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Client → Server
|
||||||
|
|
||||||
|
Send plain text:
|
||||||
|
|
||||||
|
```json
|
||||||
|
"Hello nanobot!"
|
||||||
|
```
|
||||||
|
|
||||||
|
Or send a JSON object with a recognized text field:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{"content": "Hello nanobot!"}
|
||||||
|
```
|
||||||
|
|
||||||
|
Recognized fields: `content`, `text`, `message` (checked in that order). Invalid JSON is treated as plain text.
|
||||||
|
|
||||||
|
## Configuration Reference
|
||||||
|
|
||||||
|
All fields go under `channels.websocket` in `config.json`.
|
||||||
|
|
||||||
|
### Connection
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `enabled` | bool | `false` | Enable the WebSocket server. |
|
||||||
|
| `host` | string | `"127.0.0.1"` | Bind address. Use `"0.0.0.0"` to accept external connections. |
|
||||||
|
| `port` | int | `8765` | Listen port. |
|
||||||
|
| `path` | string | `"/"` | WebSocket upgrade path. Trailing slashes are normalized (root `/` is preserved). |
|
||||||
|
| `maxMessageBytes` | int | `1048576` | Maximum inbound message size in bytes (1 KB – 16 MB). |
|
||||||
|
|
||||||
|
### Authentication
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `token` | string | `""` | Static shared secret. When set, clients must provide `?token=<value>` matching this secret (timing-safe comparison). Issued tokens are also accepted as a fallback. |
|
||||||
|
| `websocketRequiresToken` | bool | `true` | When `true` and no static `token` is configured, clients must still present a valid issued token. Set to `false` to allow unauthenticated connections (only safe for local/trusted networks). |
|
||||||
|
| `tokenIssuePath` | string | `""` | HTTP path for issuing short-lived tokens. Must differ from `path`. See [Token Issuance](#token-issuance). |
|
||||||
|
| `tokenIssueSecret` | string | `""` | Secret required to obtain tokens via the issue endpoint. If empty, any client can obtain tokens (logged as a warning). |
|
||||||
|
| `tokenTtlS` | int | `300` | Time-to-live for issued tokens in seconds (30 – 86,400). |
|
||||||
|
|
||||||
|
### Access Control
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `allowFrom` | list of string | `["*"]` | Allowed `client_id` values. `"*"` allows all; `[]` denies all. |
|
||||||
|
|
||||||
|
### Streaming
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `streaming` | bool | `true` | Enable streaming mode. The agent sends `delta` + `stream_end` frames instead of a single `message`. |
|
||||||
|
|
||||||
|
### Keep-alive
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `pingIntervalS` | float | `20.0` | WebSocket ping interval in seconds (5 – 300). |
|
||||||
|
| `pingTimeoutS` | float | `20.0` | Time to wait for a pong before closing the connection (5 – 300). |
|
||||||
|
|
||||||
|
### TLS/SSL
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `sslCertfile` | string | `""` | Path to the TLS certificate file (PEM). Both `sslCertfile` and `sslKeyfile` must be set to enable WSS. |
|
||||||
|
| `sslKeyfile` | string | `""` | Path to the TLS private key file (PEM). Minimum TLS version is enforced as TLSv1.2. |
|
||||||
|
|
||||||
|
## Token Issuance
|
||||||
|
|
||||||
|
For production deployments where `websocketRequiresToken: true`, use short-lived tokens instead of embedding static secrets in clients.
|
||||||
|
|
||||||
|
### How it works
|
||||||
|
|
||||||
|
1. Client sends `GET {tokenIssuePath}` with `Authorization: Bearer {tokenIssueSecret}` (or `X-Nanobot-Auth` header).
|
||||||
|
2. Server responds with a one-time-use token:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{"token": "nbwt_aBcDeFg...", "expires_in": 300}
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Client opens WebSocket with `?token=nbwt_aBcDeFg...&client_id=...`.
|
||||||
|
4. The token is consumed (single use) and cannot be reused.
|
||||||
|
|
||||||
|
### Example setup
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"channels": {
|
||||||
|
"websocket": {
|
||||||
|
"enabled": true,
|
||||||
|
"port": 8765,
|
||||||
|
"path": "/ws",
|
||||||
|
"tokenIssuePath": "/auth/token",
|
||||||
|
"tokenIssueSecret": "your-secret-here",
|
||||||
|
"tokenTtlS": 300,
|
||||||
|
"websocketRequiresToken": true,
|
||||||
|
"allowFrom": ["*"],
|
||||||
|
"streaming": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Client flow:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 1. Obtain a token
|
||||||
|
curl -H "Authorization: Bearer your-secret-here" http://127.0.0.1:8765/auth/token
|
||||||
|
|
||||||
|
# 2. Connect using the token
|
||||||
|
websocat "ws://127.0.0.1:8765/ws?client_id=alice&token=nbwt_aBcDeFg..."
|
||||||
|
```
|
||||||
|
|
||||||
|
### Limits
|
||||||
|
|
||||||
|
- Issued tokens are single-use — each token can only complete one handshake.
|
||||||
|
- Outstanding tokens are capped at 10,000. Requests beyond this return HTTP 429.
|
||||||
|
- Expired tokens are purged lazily on each issue or validation request.
|
||||||
|
|
||||||
|
## Security Notes
|
||||||
|
|
||||||
|
- **Timing-safe comparison**: Static token validation uses `hmac.compare_digest` to prevent timing attacks.
|
||||||
|
- **Defense in depth**: `allowFrom` is checked at both the HTTP handshake level and the message level.
|
||||||
|
- **Session isolation**: Each WebSocket connection gets a unique `chat_id`. Clients cannot access other sessions.
|
||||||
|
- **TLS enforcement**: When SSL is enabled, TLSv1.2 is the minimum allowed version.
|
||||||
|
- **Default-secure**: `websocketRequiresToken` defaults to `true`. Explicitly set it to `false` only on trusted networks.
|
||||||
|
|
||||||
|
## Media Files
|
||||||
|
|
||||||
|
Outbound `message` events may include a `media` field containing local filesystem paths. Remote clients cannot access these files directly — they need either:
|
||||||
|
|
||||||
|
- A shared filesystem mount, or
|
||||||
|
- An HTTP file server serving the nanobot media directory
|
||||||
|
|
||||||
|
## Common Patterns
|
||||||
|
|
||||||
|
### Trusted local network (no auth)
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"channels": {
|
||||||
|
"websocket": {
|
||||||
|
"enabled": true,
|
||||||
|
"host": "0.0.0.0",
|
||||||
|
"port": 8765,
|
||||||
|
"websocketRequiresToken": false,
|
||||||
|
"allowFrom": ["*"],
|
||||||
|
"streaming": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Static token (simple auth)
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"channels": {
|
||||||
|
"websocket": {
|
||||||
|
"enabled": true,
|
||||||
|
"token": "my-shared-secret",
|
||||||
|
"allowFrom": ["alice", "bob"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Clients connect with `?token=my-shared-secret&client_id=alice`.
|
||||||
|
|
||||||
|
### Public endpoint with issued tokens
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"channels": {
|
||||||
|
"websocket": {
|
||||||
|
"enabled": true,
|
||||||
|
"host": "0.0.0.0",
|
||||||
|
"port": 8765,
|
||||||
|
"path": "/ws",
|
||||||
|
"tokenIssuePath": "/auth/token",
|
||||||
|
"tokenIssueSecret": "production-secret",
|
||||||
|
"websocketRequiresToken": true,
|
||||||
|
"sslCertfile": "/etc/ssl/certs/server.pem",
|
||||||
|
"sslKeyfile": "/etc/ssl/private/server-key.pem",
|
||||||
|
"allowFrom": ["*"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Custom path
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"channels": {
|
||||||
|
"websocket": {
|
||||||
|
"enabled": true,
|
||||||
|
"path": "/chat/ws",
|
||||||
|
"allowFrom": ["*"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Clients connect to `ws://127.0.0.1:8765/chat/ws?client_id=...`. Trailing slashes are normalized, so `/chat/ws/` works the same.
|
||||||
@ -2,7 +2,29 @@
|
|||||||
nanobot - A lightweight AI agent framework
|
nanobot - A lightweight AI agent framework
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.1.5"
|
from importlib.metadata import PackageNotFoundError, version as _pkg_version
|
||||||
|
from pathlib import Path
|
||||||
|
import tomllib
|
||||||
|
|
||||||
|
|
||||||
|
def _read_pyproject_version() -> str | None:
|
||||||
|
"""Read the source-tree version when package metadata is unavailable."""
|
||||||
|
pyproject = Path(__file__).resolve().parent.parent / "pyproject.toml"
|
||||||
|
if not pyproject.exists():
|
||||||
|
return None
|
||||||
|
data = tomllib.loads(pyproject.read_text(encoding="utf-8"))
|
||||||
|
return data.get("project", {}).get("version")
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_version() -> str:
    """Return the nanobot version string.

    Installed distribution metadata for ``nanobot-ai`` wins. When that
    distribution is not installed, the version from the source tree's
    ``pyproject.toml`` is used, and ``"0.1.5"`` is the last-resort default.
    """
    try:
        installed = _pkg_version("nanobot-ai")
    except PackageNotFoundError:
        # Source checkouts often import nanobot without installed dist-info.
        from_source = _read_pyproject_version()
        return from_source if from_source else "0.1.5"
    return installed
|
||||||
|
|
||||||
|
|
||||||
|
__version__ = _resolve_version()
|
||||||
__logo__ = "🐈"
|
__logo__ = "🐈"
|
||||||
|
|
||||||
from nanobot.nanobot import Nanobot, RunResult
|
from nanobot.nanobot import Nanobot, RunResult
|
||||||
|
|||||||
123
nanobot/agent/autocompact.py
Normal file
123
nanobot/agent/autocompact.py
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
"""Auto compact: proactive compression of idle sessions to reduce token cost and latency."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Collection
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, Coroutine
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
from nanobot.session.manager import Session, SessionManager
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from nanobot.agent.memory import Consolidator
|
||||||
|
|
||||||
|
|
||||||
|
class AutoCompact:
|
||||||
|
_RECENT_SUFFIX_MESSAGES = 8
|
||||||
|
|
||||||
|
def __init__(self, sessions: SessionManager, consolidator: Consolidator,
             session_ttl_minutes: int = 0):
    """Store the collaborators and start with empty bookkeeping.

    Args:
        sessions: Session store that is listed and loaded during archival.
        consolidator: Object whose ``archive`` coroutine summarizes messages.
        session_ttl_minutes: Idle threshold in minutes; values ``<= 0``
            make ``_is_expired`` always return ``False``.
    """
    # key -> (summary text, timestamp); presumably the pending summaries
    # awaiting injection -- the consumer is not visible here, TODO confirm.
    self._summaries: dict[str, tuple[str, datetime]] = {}
    # Session keys for which an archive task has already been scheduled.
    self._archiving: set[str] = set()
    self._ttl = session_ttl_minutes
    self.consolidator = consolidator
    self.sessions = sessions
|
||||||
|
|
||||||
|
def _is_expired(self, ts: datetime | str | None,
|
||||||
|
now: datetime | None = None) -> bool:
|
||||||
|
if self._ttl <= 0 or not ts:
|
||||||
|
return False
|
||||||
|
if isinstance(ts, str):
|
||||||
|
ts = datetime.fromisoformat(ts)
|
||||||
|
return ((now or datetime.now()) - ts).total_seconds() >= self._ttl * 60
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _format_summary(text: str, last_active: datetime) -> str:
|
||||||
|
idle_min = int((datetime.now() - last_active).total_seconds() / 60)
|
||||||
|
return f"Inactive for {idle_min} minutes.\nPrevious conversation summary: {text}"
|
||||||
|
|
||||||
|
def _split_unconsolidated(
|
||||||
|
self, session: Session,
|
||||||
|
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||||
|
"""Split live session tail into archiveable prefix and retained recent suffix."""
|
||||||
|
tail = list(session.messages[session.last_consolidated:])
|
||||||
|
if not tail:
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
probe = Session(
|
||||||
|
key=session.key,
|
||||||
|
messages=tail.copy(),
|
||||||
|
created_at=session.created_at,
|
||||||
|
updated_at=session.updated_at,
|
||||||
|
metadata={},
|
||||||
|
last_consolidated=0,
|
||||||
|
)
|
||||||
|
probe.retain_recent_legal_suffix(self._RECENT_SUFFIX_MESSAGES)
|
||||||
|
kept = probe.messages
|
||||||
|
cut = len(tail) - len(kept)
|
||||||
|
return tail[:cut], kept
|
||||||
|
|
||||||
|
def check_expired(self, schedule_background: Callable[[Coroutine], None],
|
||||||
|
active_session_keys: Collection[str] = ()) -> None:
|
||||||
|
"""Schedule archival for idle sessions, skipping those with in-flight agent tasks."""
|
||||||
|
now = datetime.now()
|
||||||
|
for info in self.sessions.list_sessions():
|
||||||
|
key = info.get("key", "")
|
||||||
|
if not key or key in self._archiving:
|
||||||
|
continue
|
||||||
|
if key in active_session_keys:
|
||||||
|
continue
|
||||||
|
if self._is_expired(info.get("updated_at"), now):
|
||||||
|
self._archiving.add(key)
|
||||||
|
schedule_background(self._archive(key))
|
||||||
|
|
||||||
|
async def _archive(self, key: str) -> None:
|
||||||
|
try:
|
||||||
|
self.sessions.invalidate(key)
|
||||||
|
session = self.sessions.get_or_create(key)
|
||||||
|
archive_msgs, kept_msgs = self._split_unconsolidated(session)
|
||||||
|
if not archive_msgs and not kept_msgs:
|
||||||
|
session.updated_at = datetime.now()
|
||||||
|
self.sessions.save(session)
|
||||||
|
return
|
||||||
|
|
||||||
|
last_active = session.updated_at
|
||||||
|
summary = ""
|
||||||
|
if archive_msgs:
|
||||||
|
summary = await self.consolidator.archive(archive_msgs) or ""
|
||||||
|
if summary and summary != "(nothing)":
|
||||||
|
self._summaries[key] = (summary, last_active)
|
||||||
|
session.metadata["_last_summary"] = {"text": summary, "last_active": last_active.isoformat()}
|
||||||
|
session.messages = kept_msgs
|
||||||
|
session.last_consolidated = 0
|
||||||
|
session.updated_at = datetime.now()
|
||||||
|
self.sessions.save(session)
|
||||||
|
if archive_msgs:
|
||||||
|
logger.info(
|
||||||
|
"Auto-compact: archived {} (archived={}, kept={}, summary={})",
|
||||||
|
key,
|
||||||
|
len(archive_msgs),
|
||||||
|
len(kept_msgs),
|
||||||
|
bool(summary),
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Auto-compact: failed for {}", key)
|
||||||
|
finally:
|
||||||
|
self._archiving.discard(key)
|
||||||
|
|
||||||
|
def prepare_session(self, session: Session, key: str) -> tuple[Session, str | None]:
|
||||||
|
if key in self._archiving or self._is_expired(session.updated_at):
|
||||||
|
logger.info("Auto-compact: reloading session {} (archiving={})", key, key in self._archiving)
|
||||||
|
session = self.sessions.get_or_create(key)
|
||||||
|
# Hot path: summary from in-memory dict (process hasn't restarted).
|
||||||
|
# Also clean metadata copy so stale _last_summary never leaks to disk.
|
||||||
|
entry = self._summaries.pop(key, None)
|
||||||
|
if entry:
|
||||||
|
session.metadata.pop("_last_summary", None)
|
||||||
|
return session, self._format_summary(entry[0], entry[1])
|
||||||
|
if "_last_summary" in session.metadata:
|
||||||
|
meta = session.metadata.pop("_last_summary")
|
||||||
|
self.sessions.save(session)
|
||||||
|
return session, self._format_summary(meta["text"], datetime.fromisoformat(meta["last_active"]))
|
||||||
|
return session, None
|
||||||
@ -6,12 +6,10 @@ import platform
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from nanobot.utils.helpers import current_time_str
|
|
||||||
|
|
||||||
from nanobot.agent.memory import MemoryStore
|
from nanobot.agent.memory import MemoryStore
|
||||||
from nanobot.utils.prompt_templates import render_template
|
|
||||||
from nanobot.agent.skills import SkillsLoader
|
from nanobot.agent.skills import SkillsLoader
|
||||||
from nanobot.utils.helpers import build_assistant_message, detect_image_mime
|
from nanobot.utils.helpers import build_assistant_message, current_time_str, detect_image_mime
|
||||||
|
from nanobot.utils.prompt_templates import render_template
|
||||||
|
|
||||||
|
|
||||||
class ContextBuilder:
|
class ContextBuilder:
|
||||||
@ -20,12 +18,13 @@ class ContextBuilder:
|
|||||||
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md"]
|
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md"]
|
||||||
_RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]"
|
_RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]"
|
||||||
_MAX_RECENT_HISTORY = 50
|
_MAX_RECENT_HISTORY = 50
|
||||||
|
_RUNTIME_CONTEXT_END = "[/Runtime Context]"
|
||||||
|
|
||||||
def __init__(self, workspace: Path, timezone: str | None = None):
|
def __init__(self, workspace: Path, timezone: str | None = None, disabled_skills: list[str] | None = None):
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.timezone = timezone
|
self.timezone = timezone
|
||||||
self.memory = MemoryStore(workspace)
|
self.memory = MemoryStore(workspace)
|
||||||
self.skills = SkillsLoader(workspace)
|
self.skills = SkillsLoader(workspace, disabled_skills=set(disabled_skills) if disabled_skills else None)
|
||||||
|
|
||||||
def build_system_prompt(
|
def build_system_prompt(
|
||||||
self,
|
self,
|
||||||
@ -79,12 +78,15 @@ class ContextBuilder:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_runtime_context(
|
def _build_runtime_context(
|
||||||
channel: str | None, chat_id: str | None, timezone: str | None = None,
|
channel: str | None, chat_id: str | None, timezone: str | None = None,
|
||||||
|
session_summary: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Build untrusted runtime metadata block for injection before the user message."""
|
"""Build untrusted runtime metadata block for injection before the user message."""
|
||||||
lines = [f"Current Time: {current_time_str(timezone)}"]
|
lines = [f"Current Time: {current_time_str(timezone)}"]
|
||||||
if channel and chat_id:
|
if channel and chat_id:
|
||||||
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
|
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
|
||||||
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines)
|
if session_summary:
|
||||||
|
lines += ["", "[Resumed Session]", session_summary]
|
||||||
|
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines) + "\n" + ContextBuilder._RUNTIME_CONTEXT_END
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
|
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
|
||||||
@ -121,9 +123,10 @@ class ContextBuilder:
|
|||||||
channel: str | None = None,
|
channel: str | None = None,
|
||||||
chat_id: str | None = None,
|
chat_id: str | None = None,
|
||||||
current_role: str = "user",
|
current_role: str = "user",
|
||||||
|
session_summary: str | None = None,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""Build the complete message list for an LLM call."""
|
"""Build the complete message list for an LLM call."""
|
||||||
runtime_ctx = self._build_runtime_context(channel, chat_id, self.timezone)
|
runtime_ctx = self._build_runtime_context(channel, chat_id, self.timezone, session_summary=session_summary)
|
||||||
user_content = self._build_user_content(current_message, media)
|
user_content = self._build_user_content(current_message, media)
|
||||||
|
|
||||||
# Merge runtime context and user content into a single user message
|
# Merge runtime context and user content into a single user message
|
||||||
@ -176,7 +179,7 @@ class ContextBuilder:
|
|||||||
# Try document text extraction
|
# Try document text extraction
|
||||||
from nanobot.utils.document import extract_text
|
from nanobot.utils.document import extract_text
|
||||||
extracted = extract_text(p)
|
extracted = extract_text(p)
|
||||||
if extracted and not extracted.startswith("Error"):
|
if extracted and not extracted.startswith("[error:"):
|
||||||
doc_texts.append(f"[File: {p.name}]\n{extracted}")
|
doc_texts.append(f"[File: {p.name}]\n{extracted}")
|
||||||
|
|
||||||
# Build final content
|
# Build final content
|
||||||
|
|||||||
@ -29,6 +29,9 @@ class AgentHookContext:
|
|||||||
class AgentHook:
|
class AgentHook:
|
||||||
"""Minimal lifecycle surface for shared runner customization."""
|
"""Minimal lifecycle surface for shared runner customization."""
|
||||||
|
|
||||||
|
def __init__(self, reraise: bool = False) -> None:
|
||||||
|
self._reraise = reraise
|
||||||
|
|
||||||
def wants_streaming(self) -> bool:
|
def wants_streaming(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -62,6 +65,7 @@ class CompositeHook(AgentHook):
|
|||||||
__slots__ = ("_hooks",)
|
__slots__ = ("_hooks",)
|
||||||
|
|
||||||
def __init__(self, hooks: list[AgentHook]) -> None:
|
def __init__(self, hooks: list[AgentHook]) -> None:
|
||||||
|
super().__init__()
|
||||||
self._hooks = list(hooks)
|
self._hooks = list(hooks)
|
||||||
|
|
||||||
def wants_streaming(self) -> bool:
|
def wants_streaming(self) -> bool:
|
||||||
@ -69,6 +73,10 @@ class CompositeHook(AgentHook):
|
|||||||
|
|
||||||
async def _for_each_hook_safe(self, method_name: str, *args: Any, **kwargs: Any) -> None:
|
async def _for_each_hook_safe(self, method_name: str, *args: Any, **kwargs: Any) -> None:
|
||||||
for h in self._hooks:
|
for h in self._hooks:
|
||||||
|
if getattr(h, "_reraise", False):
|
||||||
|
await getattr(h, method_name)(*args, **kwargs)
|
||||||
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await getattr(h, method_name)(*args, **kwargs)
|
await getattr(h, method_name)(*args, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import dataclasses
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
@ -12,27 +13,30 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
|||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from nanobot.agent.autocompact import AutoCompact
|
||||||
from nanobot.agent.context import ContextBuilder
|
from nanobot.agent.context import ContextBuilder
|
||||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
||||||
from nanobot.agent.memory import Consolidator, Dream
|
from nanobot.agent.memory import Consolidator, Dream
|
||||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec
|
||||||
|
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||||
from nanobot.agent.subagent import SubagentManager
|
from nanobot.agent.subagent import SubagentManager
|
||||||
from nanobot.agent.tools.cron import CronTool
|
from nanobot.agent.tools.cron import CronTool
|
||||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
|
||||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||||
from nanobot.agent.tools.message import MessageTool
|
from nanobot.agent.tools.message import MessageTool
|
||||||
|
from nanobot.agent.tools.notebook import NotebookEditTool
|
||||||
from nanobot.agent.tools.registry import ToolRegistry
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
from nanobot.agent.tools.search import GlobTool, GrepTool
|
from nanobot.agent.tools.search import GlobTool, GrepTool
|
||||||
from nanobot.agent.tools.shell import ExecTool
|
from nanobot.agent.tools.shell import ExecTool
|
||||||
from nanobot.agent.tools.spawn import SpawnTool
|
from nanobot.agent.tools.spawn import SpawnTool
|
||||||
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
||||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||||
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
|
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
|
||||||
from nanobot.config.schema import AgentDefaults
|
from nanobot.config.schema import AgentDefaults
|
||||||
from nanobot.providers.base import LLMProvider
|
from nanobot.providers.base import LLMProvider
|
||||||
from nanobot.session.manager import Session, SessionManager
|
from nanobot.session.manager import Session, SessionManager
|
||||||
from nanobot.utils.helpers import image_placeholder_text, truncate_text
|
from nanobot.utils.helpers import image_placeholder_text
|
||||||
|
from nanobot.utils.helpers import truncate_text as truncate_text_fn
|
||||||
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -40,6 +44,9 @@ if TYPE_CHECKING:
|
|||||||
from nanobot.cron.service import CronService
|
from nanobot.cron.service import CronService
|
||||||
|
|
||||||
|
|
||||||
|
UNIFIED_SESSION_KEY = "unified:default"
|
||||||
|
|
||||||
|
|
||||||
class _LoopHook(AgentHook):
|
class _LoopHook(AgentHook):
|
||||||
"""Core hook for the main loop."""
|
"""Core hook for the main loop."""
|
||||||
|
|
||||||
@ -54,6 +61,7 @@ class _LoopHook(AgentHook):
|
|||||||
chat_id: str = "direct",
|
chat_id: str = "direct",
|
||||||
message_id: str | None = None,
|
message_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
super().__init__(reraise=True)
|
||||||
self._loop = agent_loop
|
self._loop = agent_loop
|
||||||
self._on_progress = on_progress
|
self._on_progress = on_progress
|
||||||
self._on_stream = on_stream
|
self._on_stream = on_stream
|
||||||
@ -72,7 +80,7 @@ class _LoopHook(AgentHook):
|
|||||||
prev_clean = strip_think(self._stream_buf)
|
prev_clean = strip_think(self._stream_buf)
|
||||||
self._stream_buf += delta
|
self._stream_buf += delta
|
||||||
new_clean = strip_think(self._stream_buf)
|
new_clean = strip_think(self._stream_buf)
|
||||||
incremental = new_clean[len(prev_clean):]
|
incremental = new_clean[len(prev_clean) :]
|
||||||
if incremental and self._on_stream:
|
if incremental and self._on_stream:
|
||||||
await self._on_stream(incremental)
|
await self._on_stream(incremental)
|
||||||
|
|
||||||
@ -109,43 +117,6 @@ class _LoopHook(AgentHook):
|
|||||||
return self._loop._strip_think(content)
|
return self._loop._strip_think(content)
|
||||||
|
|
||||||
|
|
||||||
class _LoopHookChain(AgentHook):
|
|
||||||
"""Run the core hook before extra hooks."""
|
|
||||||
|
|
||||||
__slots__ = ("_primary", "_extras")
|
|
||||||
|
|
||||||
def __init__(self, primary: AgentHook, extra_hooks: list[AgentHook]) -> None:
|
|
||||||
self._primary = primary
|
|
||||||
self._extras = CompositeHook(extra_hooks)
|
|
||||||
|
|
||||||
def wants_streaming(self) -> bool:
|
|
||||||
return self._primary.wants_streaming() or self._extras.wants_streaming()
|
|
||||||
|
|
||||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
|
||||||
await self._primary.before_iteration(context)
|
|
||||||
await self._extras.before_iteration(context)
|
|
||||||
|
|
||||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
|
||||||
await self._primary.on_stream(context, delta)
|
|
||||||
await self._extras.on_stream(context, delta)
|
|
||||||
|
|
||||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
|
||||||
await self._primary.on_stream_end(context, resuming=resuming)
|
|
||||||
await self._extras.on_stream_end(context, resuming=resuming)
|
|
||||||
|
|
||||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
|
||||||
await self._primary.before_execute_tools(context)
|
|
||||||
await self._extras.before_execute_tools(context)
|
|
||||||
|
|
||||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
|
||||||
await self._primary.after_iteration(context)
|
|
||||||
await self._extras.after_iteration(context)
|
|
||||||
|
|
||||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
|
||||||
content = self._primary.finalize_content(context, content)
|
|
||||||
return self._extras.finalize_content(context, content)
|
|
||||||
|
|
||||||
|
|
||||||
class AgentLoop:
|
class AgentLoop:
|
||||||
"""
|
"""
|
||||||
The agent loop is the core processing engine.
|
The agent loop is the core processing engine.
|
||||||
@ -159,6 +130,7 @@ class AgentLoop:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
_RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
|
_RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
|
||||||
|
_PENDING_USER_TURN_KEY = "pending_user_turn"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -179,7 +151,10 @@ class AgentLoop:
|
|||||||
mcp_servers: dict | None = None,
|
mcp_servers: dict | None = None,
|
||||||
channels_config: ChannelsConfig | None = None,
|
channels_config: ChannelsConfig | None = None,
|
||||||
timezone: str | None = None,
|
timezone: str | None = None,
|
||||||
|
session_ttl_minutes: int = 0,
|
||||||
hooks: list[AgentHook] | None = None,
|
hooks: list[AgentHook] | None = None,
|
||||||
|
unified_session: bool = False,
|
||||||
|
disabled_skills: list[str] | None = None,
|
||||||
):
|
):
|
||||||
from nanobot.config.schema import ExecToolConfig, WebToolsConfig
|
from nanobot.config.schema import ExecToolConfig, WebToolsConfig
|
||||||
|
|
||||||
@ -212,7 +187,7 @@ class AgentLoop:
|
|||||||
self._last_usage: dict[str, int] = {}
|
self._last_usage: dict[str, int] = {}
|
||||||
self._extra_hooks: list[AgentHook] = hooks or []
|
self._extra_hooks: list[AgentHook] = hooks or []
|
||||||
|
|
||||||
self.context = ContextBuilder(workspace, timezone=timezone)
|
self.context = ContextBuilder(workspace, timezone=timezone, disabled_skills=disabled_skills)
|
||||||
self.sessions = session_manager or SessionManager(workspace)
|
self.sessions = session_manager or SessionManager(workspace)
|
||||||
self.tools = ToolRegistry()
|
self.tools = ToolRegistry()
|
||||||
self.runner = AgentRunner(provider)
|
self.runner = AgentRunner(provider)
|
||||||
@ -225,16 +200,21 @@ class AgentLoop:
|
|||||||
max_tool_result_chars=self.max_tool_result_chars,
|
max_tool_result_chars=self.max_tool_result_chars,
|
||||||
exec_config=self.exec_config,
|
exec_config=self.exec_config,
|
||||||
restrict_to_workspace=restrict_to_workspace,
|
restrict_to_workspace=restrict_to_workspace,
|
||||||
|
disabled_skills=disabled_skills,
|
||||||
)
|
)
|
||||||
|
self._unified_session = unified_session
|
||||||
self._running = False
|
self._running = False
|
||||||
self._mcp_servers = mcp_servers or {}
|
self._mcp_servers = mcp_servers or {}
|
||||||
self._mcp_stack: AsyncExitStack | None = None
|
self._mcp_stacks: dict[str, AsyncExitStack] = {}
|
||||||
self._mcp_connected = False
|
self._mcp_connected = False
|
||||||
self._mcp_connecting = False
|
self._mcp_connecting = False
|
||||||
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||||
self._background_tasks: list[asyncio.Task] = []
|
self._background_tasks: list[asyncio.Task] = []
|
||||||
self._session_locks: dict[str, asyncio.Lock] = {}
|
self._session_locks: dict[str, asyncio.Lock] = {}
|
||||||
|
# Per-session pending queues for mid-turn message injection.
|
||||||
|
# When a session has an active task, new messages for that session
|
||||||
|
# are routed here instead of creating a new task.
|
||||||
|
self._pending_queues: dict[str, asyncio.Queue] = {}
|
||||||
# NANOBOT_MAX_CONCURRENT_REQUESTS: <=0 means unlimited; default 3.
|
# NANOBOT_MAX_CONCURRENT_REQUESTS: <=0 means unlimited; default 3.
|
||||||
_max = int(os.environ.get("NANOBOT_MAX_CONCURRENT_REQUESTS", "3"))
|
_max = int(os.environ.get("NANOBOT_MAX_CONCURRENT_REQUESTS", "3"))
|
||||||
self._concurrency_gate: asyncio.Semaphore | None = (
|
self._concurrency_gate: asyncio.Semaphore | None = (
|
||||||
@ -250,6 +230,11 @@ class AgentLoop:
|
|||||||
get_tool_definitions=self.tools.get_definitions,
|
get_tool_definitions=self.tools.get_definitions,
|
||||||
max_completion_tokens=provider.generation.max_tokens,
|
max_completion_tokens=provider.generation.max_tokens,
|
||||||
)
|
)
|
||||||
|
self.auto_compact = AutoCompact(
|
||||||
|
sessions=self.sessions,
|
||||||
|
consolidator=self.consolidator,
|
||||||
|
session_ttl_minutes=session_ttl_minutes,
|
||||||
|
)
|
||||||
self.dream = Dream(
|
self.dream = Dream(
|
||||||
store=self.context.memory,
|
store=self.context.memory,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
@ -261,23 +246,35 @@ class AgentLoop:
|
|||||||
|
|
||||||
def _register_default_tools(self) -> None:
|
def _register_default_tools(self) -> None:
|
||||||
"""Register the default set of tools."""
|
"""Register the default set of tools."""
|
||||||
allowed_dir = self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
|
allowed_dir = (
|
||||||
|
self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
|
||||||
|
)
|
||||||
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
|
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
|
||||||
self.tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read))
|
self.tools.register(
|
||||||
|
ReadFileTool(
|
||||||
|
workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read
|
||||||
|
)
|
||||||
|
)
|
||||||
for cls in (WriteFileTool, EditFileTool, ListDirTool):
|
for cls in (WriteFileTool, EditFileTool, ListDirTool):
|
||||||
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||||
for cls in (GlobTool, GrepTool):
|
for cls in (GlobTool, GrepTool):
|
||||||
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||||
|
self.tools.register(NotebookEditTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||||
if self.exec_config.enable:
|
if self.exec_config.enable:
|
||||||
self.tools.register(ExecTool(
|
self.tools.register(
|
||||||
working_dir=str(self.workspace),
|
ExecTool(
|
||||||
timeout=self.exec_config.timeout,
|
working_dir=str(self.workspace),
|
||||||
restrict_to_workspace=self.restrict_to_workspace,
|
timeout=self.exec_config.timeout,
|
||||||
sandbox=self.exec_config.sandbox,
|
restrict_to_workspace=self.restrict_to_workspace,
|
||||||
path_append=self.exec_config.path_append,
|
sandbox=self.exec_config.sandbox,
|
||||||
))
|
path_append=self.exec_config.path_append,
|
||||||
|
allowed_env_keys=self.exec_config.allowed_env_keys,
|
||||||
|
)
|
||||||
|
)
|
||||||
if self.web_config.enable:
|
if self.web_config.enable:
|
||||||
self.tools.register(WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy))
|
self.tools.register(
|
||||||
|
WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy)
|
||||||
|
)
|
||||||
self.tools.register(WebFetchTool(proxy=self.web_config.proxy))
|
self.tools.register(WebFetchTool(proxy=self.web_config.proxy))
|
||||||
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
|
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
|
||||||
self.tools.register(SpawnTool(manager=self.subagents))
|
self.tools.register(SpawnTool(manager=self.subagents))
|
||||||
@ -292,19 +289,19 @@ class AgentLoop:
|
|||||||
return
|
return
|
||||||
self._mcp_connecting = True
|
self._mcp_connecting = True
|
||||||
from nanobot.agent.tools.mcp import connect_mcp_servers
|
from nanobot.agent.tools.mcp import connect_mcp_servers
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._mcp_stack = AsyncExitStack()
|
self._mcp_stacks = await connect_mcp_servers(self._mcp_servers, self.tools)
|
||||||
await self._mcp_stack.__aenter__()
|
if self._mcp_stacks:
|
||||||
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
|
self._mcp_connected = True
|
||||||
self._mcp_connected = True
|
else:
|
||||||
|
logger.warning("No MCP servers connected successfully (will retry next message)")
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.warning("MCP connection cancelled (will retry next message)")
|
||||||
|
self._mcp_stacks.clear()
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.error("Failed to connect MCP servers (will retry next message): {}", e)
|
logger.error("Failed to connect MCP servers (will retry next message): {}", e)
|
||||||
if self._mcp_stack:
|
self._mcp_stacks.clear()
|
||||||
try:
|
|
||||||
await self._mcp_stack.aclose()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
self._mcp_stack = None
|
|
||||||
finally:
|
finally:
|
||||||
self._mcp_connecting = False
|
self._mcp_connecting = False
|
||||||
|
|
||||||
@ -321,6 +318,7 @@ class AgentLoop:
|
|||||||
if not text:
|
if not text:
|
||||||
return None
|
return None
|
||||||
from nanobot.utils.helpers import strip_think
|
from nanobot.utils.helpers import strip_think
|
||||||
|
|
||||||
return strip_think(text) or None
|
return strip_think(text) or None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -330,6 +328,12 @@ class AgentLoop:
|
|||||||
|
|
||||||
return format_tool_hints(tool_calls)
|
return format_tool_hints(tool_calls)
|
||||||
|
|
||||||
|
def _effective_session_key(self, msg: InboundMessage) -> str:
    """Pick the key under which *msg* is routed to tasks and pending queues.

    When unified-session mode is on and the message carries no explicit
    ``session_key_override``, the shared unified key is used; in every other
    case the message's own ``session_key`` is returned.
    """
    use_unified = self._unified_session and not msg.session_key_override
    return UNIFIED_SESSION_KEY if use_unified else msg.session_key
|
||||||
|
|
||||||
async def _run_agent_loop(
|
async def _run_agent_loop(
|
||||||
self,
|
self,
|
||||||
initial_messages: list[dict],
|
initial_messages: list[dict],
|
||||||
@ -341,13 +345,16 @@ class AgentLoop:
|
|||||||
channel: str = "cli",
|
channel: str = "cli",
|
||||||
chat_id: str = "direct",
|
chat_id: str = "direct",
|
||||||
message_id: str | None = None,
|
message_id: str | None = None,
|
||||||
) -> tuple[str | None, list[str], list[dict]]:
|
pending_queue: asyncio.Queue | None = None,
|
||||||
|
) -> tuple[str | None, list[str], list[dict], str, bool]:
|
||||||
"""Run the agent iteration loop.
|
"""Run the agent iteration loop.
|
||||||
|
|
||||||
*on_stream*: called with each content delta during streaming.
|
*on_stream*: called with each content delta during streaming.
|
||||||
*on_stream_end(resuming)*: called when a streaming session finishes.
|
*on_stream_end(resuming)*: called when a streaming session finishes.
|
||||||
``resuming=True`` means tool calls follow (spinner should restart);
|
``resuming=True`` means tool calls follow (spinner should restart);
|
||||||
``resuming=False`` means this is the final response.
|
``resuming=False`` means this is the final response.
|
||||||
|
|
||||||
|
Returns (final_content, tools_used, messages, stop_reason, had_injections).
|
||||||
"""
|
"""
|
||||||
loop_hook = _LoopHook(
|
loop_hook = _LoopHook(
|
||||||
self,
|
self,
|
||||||
@ -359,9 +366,7 @@ class AgentLoop:
|
|||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
)
|
)
|
||||||
hook: AgentHook = (
|
hook: AgentHook = (
|
||||||
_LoopHookChain(loop_hook, self._extra_hooks)
|
CompositeHook([loop_hook] + self._extra_hooks) if self._extra_hooks else loop_hook
|
||||||
if self._extra_hooks
|
|
||||||
else loop_hook
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _checkpoint(payload: dict[str, Any]) -> None:
|
async def _checkpoint(payload: dict[str, Any]) -> None:
|
||||||
@ -369,6 +374,32 @@ class AgentLoop:
|
|||||||
return
|
return
|
||||||
self._set_runtime_checkpoint(session, payload)
|
self._set_runtime_checkpoint(session, payload)
|
||||||
|
|
||||||
|
async def _drain_pending(*, limit: int = _MAX_INJECTIONS_PER_TURN) -> list[dict[str, Any]]:
    """Collect queued follow-up messages without waiting for new ones.

    Takes at most *limit* messages off ``pending_queue`` and converts each
    into a ``{"role": "user", ...}`` message whose content is the runtime
    context block followed by the user's content. Returns an empty list when
    no pending queue was supplied or the queue is already empty.
    """
    if pending_queue is None:
        return []

    drained: list[dict[str, Any]] = []
    while len(drained) < limit:
        try:
            follow_up = pending_queue.get_nowait()
        except asyncio.QueueEmpty:
            # Queue exhausted before reaching the limit.
            break

        body = self.context._build_user_content(follow_up.content, follow_up.media or None)
        ctx_block = self.context._build_runtime_context(
            follow_up.channel, follow_up.chat_id, self.context.timezone
        )
        # Text content is joined into one string; any other content is
        # treated as a list of parts and gets the context as a leading text part.
        content: str | list[dict[str, Any]] = (
            f"{ctx_block}\n\n{body}"
            if isinstance(body, str)
            else [{"type": "text", "text": ctx_block}] + body
        )
        drained.append({"role": "user", "content": content})

    return drained
|
||||||
|
|
||||||
result = await self.runner.run(AgentRunSpec(
|
result = await self.runner.run(AgentRunSpec(
|
||||||
initial_messages=initial_messages,
|
initial_messages=initial_messages,
|
||||||
tools=self.tools,
|
tools=self.tools,
|
||||||
@ -385,13 +416,14 @@ class AgentLoop:
|
|||||||
provider_retry_mode=self.provider_retry_mode,
|
provider_retry_mode=self.provider_retry_mode,
|
||||||
progress_callback=on_progress,
|
progress_callback=on_progress,
|
||||||
checkpoint_callback=_checkpoint,
|
checkpoint_callback=_checkpoint,
|
||||||
|
injection_callback=_drain_pending,
|
||||||
))
|
))
|
||||||
self._last_usage = result.usage
|
self._last_usage = result.usage
|
||||||
if result.stop_reason == "max_iterations":
|
if result.stop_reason == "max_iterations":
|
||||||
logger.warning("Max iterations ({}) reached", self.max_iterations)
|
logger.warning("Max iterations ({}) reached", self.max_iterations)
|
||||||
elif result.stop_reason == "error":
|
elif result.stop_reason == "error":
|
||||||
logger.error("LLM returned error: {}", (result.final_content or "")[:200])
|
logger.error("LLM returned error: {}", (result.final_content or "")[:200])
|
||||||
return result.final_content, result.tools_used, result.messages
|
return result.final_content, result.tools_used, result.messages, result.stop_reason, result.had_injections
|
||||||
|
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
|
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
|
||||||
@ -403,6 +435,10 @@ class AgentLoop:
|
|||||||
try:
|
try:
|
||||||
msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
|
msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
self.auto_compact.check_expired(
|
||||||
|
self._schedule_background,
|
||||||
|
active_session_keys=self._pending_queues.keys(),
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
# Preserve real task cancellation so shutdown can complete cleanly.
|
# Preserve real task cancellation so shutdown can complete cleanly.
|
||||||
@ -421,79 +457,140 @@ class AgentLoop:
|
|||||||
if result:
|
if result:
|
||||||
await self.bus.publish_outbound(result)
|
await self.bus.publish_outbound(result)
|
||||||
continue
|
continue
|
||||||
|
effective_key = self._effective_session_key(msg)
|
||||||
|
# If this session already has an active pending queue (i.e. a task
|
||||||
|
# is processing this session), route the message there for mid-turn
|
||||||
|
# injection instead of creating a competing task.
|
||||||
|
if effective_key in self._pending_queues:
|
||||||
|
pending_msg = msg
|
||||||
|
if effective_key != msg.session_key:
|
||||||
|
pending_msg = dataclasses.replace(
|
||||||
|
msg,
|
||||||
|
session_key_override=effective_key,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
self._pending_queues[effective_key].put_nowait(pending_msg)
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
logger.warning(
|
||||||
|
"Pending queue full for session {}, falling back to queued task",
|
||||||
|
effective_key,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Routed follow-up message to pending queue for session {}",
|
||||||
|
effective_key,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
# Compute the effective session key before dispatching
|
||||||
|
# This ensures /stop command can find tasks correctly when unified session is enabled
|
||||||
task = asyncio.create_task(self._dispatch(msg))
|
task = asyncio.create_task(self._dispatch(msg))
|
||||||
self._active_tasks.setdefault(msg.session_key, []).append(task)
|
self._active_tasks.setdefault(effective_key, []).append(task)
|
||||||
task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
|
task.add_done_callback(
|
||||||
|
lambda t, k=effective_key: self._active_tasks.get(k, [])
|
||||||
|
and self._active_tasks[k].remove(t)
|
||||||
|
if t in self._active_tasks.get(k, [])
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
async def _dispatch(self, msg: InboundMessage) -> None:
|
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||||
"""Process a message: per-session serial, cross-session concurrent."""
|
"""Process a message: per-session serial, cross-session concurrent."""
|
||||||
lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock())
|
session_key = self._effective_session_key(msg)
|
||||||
|
if session_key != msg.session_key:
|
||||||
|
msg = dataclasses.replace(msg, session_key_override=session_key)
|
||||||
|
lock = self._session_locks.setdefault(session_key, asyncio.Lock())
|
||||||
gate = self._concurrency_gate or nullcontext()
|
gate = self._concurrency_gate or nullcontext()
|
||||||
async with lock, gate:
|
|
||||||
try:
|
|
||||||
on_stream = on_stream_end = None
|
|
||||||
if msg.metadata.get("_wants_stream"):
|
|
||||||
# Split one answer into distinct stream segments.
|
|
||||||
stream_base_id = f"{msg.session_key}:{time.time_ns()}"
|
|
||||||
stream_segment = 0
|
|
||||||
|
|
||||||
def _current_stream_id() -> str:
|
# Register a pending queue so follow-up messages for this session are
|
||||||
return f"{stream_base_id}:{stream_segment}"
|
# routed here (mid-turn injection) instead of spawning a new task.
|
||||||
|
pending = asyncio.Queue(maxsize=20)
|
||||||
|
self._pending_queues[session_key] = pending
|
||||||
|
|
||||||
async def on_stream(delta: str) -> None:
|
try:
|
||||||
meta = dict(msg.metadata or {})
|
async with lock, gate:
|
||||||
meta["_stream_delta"] = True
|
try:
|
||||||
meta["_stream_id"] = _current_stream_id()
|
on_stream = on_stream_end = None
|
||||||
|
if msg.metadata.get("_wants_stream"):
|
||||||
|
# Split one answer into distinct stream segments.
|
||||||
|
stream_base_id = f"{msg.session_key}:{time.time_ns()}"
|
||||||
|
stream_segment = 0
|
||||||
|
|
||||||
|
def _current_stream_id() -> str:
|
||||||
|
return f"{stream_base_id}:{stream_segment}"
|
||||||
|
|
||||||
|
async def on_stream(delta: str) -> None:
|
||||||
|
meta = dict(msg.metadata or {})
|
||||||
|
meta["_stream_delta"] = True
|
||||||
|
meta["_stream_id"] = _current_stream_id()
|
||||||
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
|
content=delta,
|
||||||
|
metadata=meta,
|
||||||
|
))
|
||||||
|
|
||||||
|
async def on_stream_end(*, resuming: bool = False) -> None:
|
||||||
|
nonlocal stream_segment
|
||||||
|
meta = dict(msg.metadata or {})
|
||||||
|
meta["_stream_end"] = True
|
||||||
|
meta["_resuming"] = resuming
|
||||||
|
meta["_stream_id"] = _current_stream_id()
|
||||||
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
|
content="",
|
||||||
|
metadata=meta,
|
||||||
|
))
|
||||||
|
stream_segment += 1
|
||||||
|
|
||||||
|
response = await self._process_message(
|
||||||
|
msg, on_stream=on_stream, on_stream_end=on_stream_end,
|
||||||
|
pending_queue=pending,
|
||||||
|
)
|
||||||
|
if response is not None:
|
||||||
|
await self.bus.publish_outbound(response)
|
||||||
|
elif msg.channel == "cli":
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
content=delta,
|
content="", metadata=msg.metadata or {},
|
||||||
metadata=meta,
|
|
||||||
))
|
))
|
||||||
|
except asyncio.CancelledError:
|
||||||
async def on_stream_end(*, resuming: bool = False) -> None:
|
logger.info("Task cancelled for session {}", session_key)
|
||||||
nonlocal stream_segment
|
raise
|
||||||
meta = dict(msg.metadata or {})
|
except Exception:
|
||||||
meta["_stream_end"] = True
|
logger.exception("Error processing message for session {}", session_key)
|
||||||
meta["_resuming"] = resuming
|
|
||||||
meta["_stream_id"] = _current_stream_id()
|
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
|
||||||
content="",
|
|
||||||
metadata=meta,
|
|
||||||
))
|
|
||||||
stream_segment += 1
|
|
||||||
|
|
||||||
response = await self._process_message(
|
|
||||||
msg, on_stream=on_stream, on_stream_end=on_stream_end,
|
|
||||||
)
|
|
||||||
if response is not None:
|
|
||||||
await self.bus.publish_outbound(response)
|
|
||||||
elif msg.channel == "cli":
|
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
content="", metadata=msg.metadata or {},
|
content="Sorry, I encountered an error.",
|
||||||
))
|
))
|
||||||
except asyncio.CancelledError:
|
finally:
|
||||||
logger.info("Task cancelled for session {}", msg.session_key)
|
# Drain any messages still in the pending queue and re-publish
|
||||||
raise
|
# them to the bus so they are processed as fresh inbound messages
|
||||||
except Exception:
|
# rather than silently lost.
|
||||||
logger.exception("Error processing message for session {}", msg.session_key)
|
queue = self._pending_queues.pop(session_key, None)
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
if queue is not None:
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
leftover = 0
|
||||||
content="Sorry, I encountered an error.",
|
while True:
|
||||||
))
|
try:
|
||||||
|
item = queue.get_nowait()
|
||||||
|
except asyncio.QueueEmpty:
|
||||||
|
break
|
||||||
|
await self.bus.publish_inbound(item)
|
||||||
|
leftover += 1
|
||||||
|
if leftover:
|
||||||
|
logger.info(
|
||||||
|
"Re-published {} leftover message(s) to bus for session {}",
|
||||||
|
leftover, session_key,
|
||||||
|
)
|
||||||
|
|
||||||
async def close_mcp(self) -> None:
|
async def close_mcp(self) -> None:
|
||||||
"""Drain pending background archives, then close MCP connections."""
|
"""Drain pending background archives, then close MCP connections."""
|
||||||
if self._background_tasks:
|
if self._background_tasks:
|
||||||
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||||
self._background_tasks.clear()
|
self._background_tasks.clear()
|
||||||
if self._mcp_stack:
|
for name, stack in self._mcp_stacks.items():
|
||||||
try:
|
try:
|
||||||
await self._mcp_stack.aclose()
|
await stack.aclose()
|
||||||
except (RuntimeError, BaseExceptionGroup):
|
except (RuntimeError, BaseExceptionGroup):
|
||||||
pass # MCP SDK cancel scope cleanup is noisy but harmless
|
logger.debug("MCP server '{}' cleanup error (can be ignored)", name)
|
||||||
self._mcp_stack = None
|
self._mcp_stacks.clear()
|
||||||
|
|
||||||
def _schedule_background(self, coro) -> None:
|
def _schedule_background(self, coro) -> None:
|
||||||
"""Schedule a coroutine as a tracked background task (drained on shutdown)."""
|
"""Schedule a coroutine as a tracked background task (drained on shutdown)."""
|
||||||
@ -513,27 +610,36 @@ class AgentLoop:
|
|||||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||||
|
pending_queue: asyncio.Queue | None = None,
|
||||||
) -> OutboundMessage | None:
|
) -> OutboundMessage | None:
|
||||||
"""Process a single inbound message and return the response."""
|
"""Process a single inbound message and return the response."""
|
||||||
# System messages: parse origin from chat_id ("channel:chat_id")
|
# System messages: parse origin from chat_id ("channel:chat_id")
|
||||||
if msg.channel == "system":
|
if msg.channel == "system":
|
||||||
channel, chat_id = (msg.chat_id.split(":", 1) if ":" in msg.chat_id
|
channel, chat_id = (
|
||||||
else ("cli", msg.chat_id))
|
msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id)
|
||||||
|
)
|
||||||
logger.info("Processing system message from {}", msg.sender_id)
|
logger.info("Processing system message from {}", msg.sender_id)
|
||||||
key = f"{channel}:{chat_id}"
|
key = f"{channel}:{chat_id}"
|
||||||
session = self.sessions.get_or_create(key)
|
session = self.sessions.get_or_create(key)
|
||||||
if self._restore_runtime_checkpoint(session):
|
if self._restore_runtime_checkpoint(session):
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
if self._restore_pending_user_turn(session):
|
||||||
|
self.sessions.save(session)
|
||||||
|
|
||||||
|
session, pending = self.auto_compact.prepare_session(session, key)
|
||||||
|
|
||||||
await self.consolidator.maybe_consolidate_by_tokens(session)
|
await self.consolidator.maybe_consolidate_by_tokens(session)
|
||||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||||
history = session.get_history(max_messages=0)
|
history = session.get_history(max_messages=0)
|
||||||
current_role = "assistant" if msg.sender_id == "subagent" else "user"
|
current_role = "assistant" if msg.sender_id == "subagent" else "user"
|
||||||
|
|
||||||
messages = self.context.build_messages(
|
messages = self.context.build_messages(
|
||||||
history=history,
|
history=history,
|
||||||
current_message=msg.content, channel=channel, chat_id=chat_id,
|
current_message=msg.content, channel=channel, chat_id=chat_id,
|
||||||
|
session_summary=pending,
|
||||||
current_role=current_role,
|
current_role=current_role,
|
||||||
)
|
)
|
||||||
final_content, _, all_msgs = await self._run_agent_loop(
|
final_content, _, all_msgs, _, _ = await self._run_agent_loop(
|
||||||
messages, session=session, channel=channel, chat_id=chat_id,
|
messages, session=session, channel=channel, chat_id=chat_id,
|
||||||
message_id=msg.metadata.get("message_id"),
|
message_id=msg.metadata.get("message_id"),
|
||||||
)
|
)
|
||||||
@ -541,8 +647,11 @@ class AgentLoop:
|
|||||||
self._clear_runtime_checkpoint(session)
|
self._clear_runtime_checkpoint(session)
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
||||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
return OutboundMessage(
|
||||||
content=final_content or "Background task completed.")
|
channel=channel,
|
||||||
|
chat_id=chat_id,
|
||||||
|
content=final_content or "Background task completed.",
|
||||||
|
)
|
||||||
|
|
||||||
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
|
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
|
||||||
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
|
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||||
@ -551,6 +660,10 @@ class AgentLoop:
|
|||||||
session = self.sessions.get_or_create(key)
|
session = self.sessions.get_or_create(key)
|
||||||
if self._restore_runtime_checkpoint(session):
|
if self._restore_runtime_checkpoint(session):
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
if self._restore_pending_user_turn(session):
|
||||||
|
self.sessions.save(session)
|
||||||
|
|
||||||
|
session, pending = self.auto_compact.prepare_session(session, key)
|
||||||
|
|
||||||
# Slash commands
|
# Slash commands
|
||||||
raw = msg.content.strip()
|
raw = msg.content.strip()
|
||||||
@ -566,50 +679,85 @@ class AgentLoop:
|
|||||||
message_tool.start_turn()
|
message_tool.start_turn()
|
||||||
|
|
||||||
history = session.get_history(max_messages=0)
|
history = session.get_history(max_messages=0)
|
||||||
|
|
||||||
initial_messages = self.context.build_messages(
|
initial_messages = self.context.build_messages(
|
||||||
history=history,
|
history=history,
|
||||||
current_message=msg.content,
|
current_message=msg.content,
|
||||||
|
session_summary=pending,
|
||||||
media=msg.media if msg.media else None,
|
media=msg.media if msg.media else None,
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
channel=msg.channel,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
|
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||||
meta = dict(msg.metadata or {})
|
meta = dict(msg.metadata or {})
|
||||||
meta["_progress"] = True
|
meta["_progress"] = True
|
||||||
meta["_tool_hint"] = tool_hint
|
meta["_tool_hint"] = tool_hint
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(
|
||||||
channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
|
OutboundMessage(
|
||||||
))
|
channel=msg.channel,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
|
content=content,
|
||||||
|
metadata=meta,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
final_content, _, all_msgs = await self._run_agent_loop(
|
# Persist the triggering user message immediately, before running the
|
||||||
|
# agent loop. If the process is killed mid-turn (OOM, SIGKILL, self-
|
||||||
|
# restart, etc.), the existing runtime_checkpoint preserves the
|
||||||
|
# in-flight assistant/tool state but NOT the user message itself, so
|
||||||
|
# the user's prompt is silently lost on recovery. Saving it up front
|
||||||
|
# makes recovery possible from the session log alone.
|
||||||
|
user_persisted_early = False
|
||||||
|
if isinstance(msg.content, str) and msg.content.strip():
|
||||||
|
session.add_message("user", msg.content)
|
||||||
|
self._mark_pending_user_turn(session)
|
||||||
|
self.sessions.save(session)
|
||||||
|
user_persisted_early = True
|
||||||
|
|
||||||
|
final_content, _, all_msgs, stop_reason, had_injections = await self._run_agent_loop(
|
||||||
initial_messages,
|
initial_messages,
|
||||||
on_progress=on_progress or _bus_progress,
|
on_progress=on_progress or _bus_progress,
|
||||||
on_stream=on_stream,
|
on_stream=on_stream,
|
||||||
on_stream_end=on_stream_end,
|
on_stream_end=on_stream_end,
|
||||||
session=session,
|
session=session,
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
channel=msg.channel,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
message_id=msg.metadata.get("message_id"),
|
message_id=msg.metadata.get("message_id"),
|
||||||
|
pending_queue=pending_queue,
|
||||||
)
|
)
|
||||||
|
|
||||||
if final_content is None or not final_content.strip():
|
if final_content is None or not final_content.strip():
|
||||||
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
|
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||||
|
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
# Skip the already-persisted user message when saving the turn
|
||||||
|
save_skip = 1 + len(history) + (1 if user_persisted_early else 0)
|
||||||
|
self._save_turn(session, all_msgs, save_skip)
|
||||||
|
self._clear_pending_user_turn(session)
|
||||||
self._clear_runtime_checkpoint(session)
|
self._clear_runtime_checkpoint(session)
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
||||||
|
|
||||||
|
# When follow-up messages were injected mid-turn, a later natural
|
||||||
|
# language reply may address those follow-ups and should not be
|
||||||
|
# suppressed just because MessageTool was used earlier in the turn.
|
||||||
|
# However, if the turn falls back to the empty-final-response
|
||||||
|
# placeholder, suppress it when the real user-visible output already
|
||||||
|
# came from MessageTool.
|
||||||
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||||
return None
|
if not had_injections or stop_reason == "empty_final_response":
|
||||||
|
return None
|
||||||
|
|
||||||
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
||||||
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||||
|
|
||||||
meta = dict(msg.metadata or {})
|
meta = dict(msg.metadata or {})
|
||||||
if on_stream is not None:
|
if on_stream is not None and stop_reason != "error":
|
||||||
meta["_streamed"] = True
|
meta["_streamed"] = True
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
|
channel=msg.channel,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
|
content=final_content,
|
||||||
metadata=meta,
|
metadata=meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -617,7 +765,7 @@ class AgentLoop:
|
|||||||
self,
|
self,
|
||||||
content: list[dict[str, Any]],
|
content: list[dict[str, Any]],
|
||||||
*,
|
*,
|
||||||
truncate_text: bool = False,
|
should_truncate_text: bool = False,
|
||||||
drop_runtime: bool = False,
|
drop_runtime: bool = False,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
"""Strip volatile multimodal payloads before writing session history."""
|
"""Strip volatile multimodal payloads before writing session history."""
|
||||||
@ -635,18 +783,17 @@ class AgentLoop:
|
|||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if (
|
if block.get("type") == "image_url" and block.get("image_url", {}).get(
|
||||||
block.get("type") == "image_url"
|
"url", ""
|
||||||
and block.get("image_url", {}).get("url", "").startswith("data:image/")
|
).startswith("data:image/"):
|
||||||
):
|
|
||||||
path = (block.get("_meta") or {}).get("path", "")
|
path = (block.get("_meta") or {}).get("path", "")
|
||||||
filtered.append({"type": "text", "text": image_placeholder_text(path)})
|
filtered.append({"type": "text", "text": image_placeholder_text(path)})
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if block.get("type") == "text" and isinstance(block.get("text"), str):
|
if block.get("type") == "text" and isinstance(block.get("text"), str):
|
||||||
text = block["text"]
|
text = block["text"]
|
||||||
if truncate_text and len(text) > self.max_tool_result_chars:
|
if should_truncate_text and len(text) > self.max_tool_result_chars:
|
||||||
text = truncate_text(text, self.max_tool_result_chars)
|
text = truncate_text_fn(text, self.max_tool_result_chars)
|
||||||
filtered.append({**block, "text": text})
|
filtered.append({**block, "text": text})
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -657,6 +804,7 @@ class AgentLoop:
|
|||||||
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
|
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
|
||||||
"""Save new-turn messages into session, truncating large tool results."""
|
"""Save new-turn messages into session, truncating large tool results."""
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
for m in messages[skip:]:
|
for m in messages[skip:]:
|
||||||
entry = dict(m)
|
entry = dict(m)
|
||||||
role, content = entry.get("role"), entry.get("content")
|
role, content = entry.get("role"), entry.get("content")
|
||||||
@ -664,20 +812,31 @@ class AgentLoop:
|
|||||||
continue # skip empty assistant messages — they poison session context
|
continue # skip empty assistant messages — they poison session context
|
||||||
if role == "tool":
|
if role == "tool":
|
||||||
if isinstance(content, str) and len(content) > self.max_tool_result_chars:
|
if isinstance(content, str) and len(content) > self.max_tool_result_chars:
|
||||||
entry["content"] = truncate_text(content, self.max_tool_result_chars)
|
entry["content"] = truncate_text_fn(content, self.max_tool_result_chars)
|
||||||
elif isinstance(content, list):
|
elif isinstance(content, list):
|
||||||
filtered = self._sanitize_persisted_blocks(content, truncate_text=True)
|
filtered = self._sanitize_persisted_blocks(content, should_truncate_text=True)
|
||||||
if not filtered:
|
if not filtered:
|
||||||
continue
|
continue
|
||||||
entry["content"] = filtered
|
entry["content"] = filtered
|
||||||
elif role == "user":
|
elif role == "user":
|
||||||
if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
|
if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
|
||||||
# Strip the runtime-context prefix, keep only the user text.
|
# Strip the entire runtime-context block (including any session summary).
|
||||||
parts = content.split("\n\n", 1)
|
# The block is bounded by _RUNTIME_CONTEXT_TAG and _RUNTIME_CONTEXT_END.
|
||||||
if len(parts) > 1 and parts[1].strip():
|
end_marker = ContextBuilder._RUNTIME_CONTEXT_END
|
||||||
entry["content"] = parts[1]
|
end_pos = content.find(end_marker)
|
||||||
|
if end_pos >= 0:
|
||||||
|
after = content[end_pos + len(end_marker):].lstrip("\n")
|
||||||
|
if after:
|
||||||
|
entry["content"] = after
|
||||||
|
else:
|
||||||
|
continue
|
||||||
else:
|
else:
|
||||||
continue
|
# Fallback: no end marker found, strip the tag prefix
|
||||||
|
after_tag = content[len(ContextBuilder._RUNTIME_CONTEXT_TAG):].lstrip("\n")
|
||||||
|
if after_tag.strip():
|
||||||
|
entry["content"] = after_tag
|
||||||
|
else:
|
||||||
|
continue
|
||||||
if isinstance(content, list):
|
if isinstance(content, list):
|
||||||
filtered = self._sanitize_persisted_blocks(content, drop_runtime=True)
|
filtered = self._sanitize_persisted_blocks(content, drop_runtime=True)
|
||||||
if not filtered:
|
if not filtered:
|
||||||
@ -692,6 +851,12 @@ class AgentLoop:
|
|||||||
session.metadata[self._RUNTIME_CHECKPOINT_KEY] = payload
|
session.metadata[self._RUNTIME_CHECKPOINT_KEY] = payload
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
|
||||||
|
def _mark_pending_user_turn(self, session: Session) -> None:
|
||||||
|
session.metadata[self._PENDING_USER_TURN_KEY] = True
|
||||||
|
|
||||||
|
def _clear_pending_user_turn(self, session: Session) -> None:
|
||||||
|
session.metadata.pop(self._PENDING_USER_TURN_KEY, None)
|
||||||
|
|
||||||
def _clear_runtime_checkpoint(self, session: Session) -> None:
|
def _clear_runtime_checkpoint(self, session: Session) -> None:
|
||||||
if self._RUNTIME_CHECKPOINT_KEY in session.metadata:
|
if self._RUNTIME_CHECKPOINT_KEY in session.metadata:
|
||||||
session.metadata.pop(self._RUNTIME_CHECKPOINT_KEY, None)
|
session.metadata.pop(self._RUNTIME_CHECKPOINT_KEY, None)
|
||||||
@ -735,13 +900,15 @@ class AgentLoop:
|
|||||||
continue
|
continue
|
||||||
tool_id = tool_call.get("id")
|
tool_id = tool_call.get("id")
|
||||||
name = ((tool_call.get("function") or {}).get("name")) or "tool"
|
name = ((tool_call.get("function") or {}).get("name")) or "tool"
|
||||||
restored_messages.append({
|
restored_messages.append(
|
||||||
"role": "tool",
|
{
|
||||||
"tool_call_id": tool_id,
|
"role": "tool",
|
||||||
"name": name,
|
"tool_call_id": tool_id,
|
||||||
"content": "Error: Task interrupted before this tool finished.",
|
"name": name,
|
||||||
"timestamp": datetime.now().isoformat(),
|
"content": "Error: Task interrupted before this tool finished.",
|
||||||
})
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
overlap = 0
|
overlap = 0
|
||||||
max_overlap = min(len(session.messages), len(restored_messages))
|
max_overlap = min(len(session.messages), len(restored_messages))
|
||||||
@ -756,9 +923,30 @@ class AgentLoop:
|
|||||||
break
|
break
|
||||||
session.messages.extend(restored_messages[overlap:])
|
session.messages.extend(restored_messages[overlap:])
|
||||||
|
|
||||||
|
self._clear_pending_user_turn(session)
|
||||||
self._clear_runtime_checkpoint(session)
|
self._clear_runtime_checkpoint(session)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _restore_pending_user_turn(self, session: Session) -> bool:
|
||||||
|
"""Close a turn that only persisted the user message before crashing."""
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
if not session.metadata.get(self._PENDING_USER_TURN_KEY):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if session.messages and session.messages[-1].get("role") == "user":
|
||||||
|
session.messages.append(
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Error: Task interrupted before a response was generated.",
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
session.updated_at = datetime.now()
|
||||||
|
|
||||||
|
self._clear_pending_user_turn(session)
|
||||||
|
return True
|
||||||
|
|
||||||
async def process_direct(
|
async def process_direct(
|
||||||
self,
|
self,
|
||||||
content: str,
|
content: str,
|
||||||
@ -777,6 +965,9 @@ class AgentLoop:
|
|||||||
content=content, media=media or [],
|
content=content, media=media or [],
|
||||||
)
|
)
|
||||||
return await self._process_message(
|
return await self._process_message(
|
||||||
msg, session_key=session_key, on_progress=on_progress,
|
msg,
|
||||||
on_stream=on_stream, on_stream_end=on_stream_end,
|
session_key=session_key,
|
||||||
|
on_progress=on_progress,
|
||||||
|
on_stream=on_stream,
|
||||||
|
on_stream_end=on_stream_end,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -290,7 +290,7 @@ class MemoryStore:
|
|||||||
if not lines:
|
if not lines:
|
||||||
return None
|
return None
|
||||||
return json.loads(lines[-1])
|
return json.loads(lines[-1])
|
||||||
except (FileNotFoundError, json.JSONDecodeError):
|
except (FileNotFoundError, json.JSONDecodeError, UnicodeDecodeError):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _write_entries(self, entries: list[dict[str, Any]]) -> None:
|
def _write_entries(self, entries: list[dict[str, Any]]) -> None:
|
||||||
@ -347,6 +347,7 @@ class Consolidator:
|
|||||||
"""Lightweight consolidation: summarizes evicted messages into history.jsonl."""
|
"""Lightweight consolidation: summarizes evicted messages into history.jsonl."""
|
||||||
|
|
||||||
_MAX_CONSOLIDATION_ROUNDS = 5
|
_MAX_CONSOLIDATION_ROUNDS = 5
|
||||||
|
_MAX_CHUNK_MESSAGES = 60 # hard cap per consolidation round
|
||||||
|
|
||||||
_SAFETY_BUFFER = 1024 # extra headroom for tokenizer estimation drift
|
_SAFETY_BUFFER = 1024 # extra headroom for tokenizer estimation drift
|
||||||
|
|
||||||
@ -399,6 +400,22 @@ class Consolidator:
|
|||||||
|
|
||||||
return last_boundary
|
return last_boundary
|
||||||
|
|
||||||
|
def _cap_consolidation_boundary(
|
||||||
|
self,
|
||||||
|
session: Session,
|
||||||
|
end_idx: int,
|
||||||
|
) -> int | None:
|
||||||
|
"""Clamp the chunk size without breaking the user-turn boundary."""
|
||||||
|
start = session.last_consolidated
|
||||||
|
if end_idx - start <= self._MAX_CHUNK_MESSAGES:
|
||||||
|
return end_idx
|
||||||
|
|
||||||
|
capped_end = start + self._MAX_CHUNK_MESSAGES
|
||||||
|
for idx in range(capped_end, start, -1):
|
||||||
|
if session.messages[idx].get("role") == "user":
|
||||||
|
return idx
|
||||||
|
return None
|
||||||
|
|
||||||
def estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]:
|
def estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]:
|
||||||
"""Estimate current prompt size for the normal session history view."""
|
"""Estimate current prompt size for the normal session history view."""
|
||||||
history = session.get_history(max_messages=0)
|
history = session.get_history(max_messages=0)
|
||||||
@ -416,13 +433,13 @@ class Consolidator:
|
|||||||
self._get_tool_definitions(),
|
self._get_tool_definitions(),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def archive(self, messages: list[dict]) -> bool:
|
async def archive(self, messages: list[dict]) -> str | None:
|
||||||
"""Summarize messages via LLM and append to history.jsonl.
|
"""Summarize messages via LLM and append to history.jsonl.
|
||||||
|
|
||||||
Returns True on success (or degraded success), False if nothing to do.
|
Returns the summary text on success, None if nothing to archive.
|
||||||
"""
|
"""
|
||||||
if not messages:
|
if not messages:
|
||||||
return False
|
return None
|
||||||
try:
|
try:
|
||||||
formatted = MemoryStore._format_messages(messages)
|
formatted = MemoryStore._format_messages(messages)
|
||||||
response = await self.provider.chat_with_retry(
|
response = await self.provider.chat_with_retry(
|
||||||
@ -442,11 +459,11 @@ class Consolidator:
|
|||||||
)
|
)
|
||||||
summary = response.content or "[no summary]"
|
summary = response.content or "[no summary]"
|
||||||
self.store.append_history(summary)
|
self.store.append_history(summary)
|
||||||
return True
|
return summary
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Consolidation LLM call failed, raw-dumping to history")
|
logger.warning("Consolidation LLM call failed, raw-dumping to history")
|
||||||
self.store.raw_archive(messages)
|
self.store.raw_archive(messages)
|
||||||
return True
|
return None
|
||||||
|
|
||||||
async def maybe_consolidate_by_tokens(self, session: Session) -> None:
|
async def maybe_consolidate_by_tokens(self, session: Session) -> None:
|
||||||
"""Loop: archive old messages until prompt fits within safe budget.
|
"""Loop: archive old messages until prompt fits within safe budget.
|
||||||
@ -461,16 +478,22 @@ class Consolidator:
|
|||||||
async with lock:
|
async with lock:
|
||||||
budget = self.context_window_tokens - self.max_completion_tokens - self._SAFETY_BUFFER
|
budget = self.context_window_tokens - self.max_completion_tokens - self._SAFETY_BUFFER
|
||||||
target = budget // 2
|
target = budget // 2
|
||||||
estimated, source = self.estimate_session_prompt_tokens(session)
|
try:
|
||||||
|
estimated, source = self.estimate_session_prompt_tokens(session)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Token estimation failed for {}", session.key)
|
||||||
|
estimated, source = 0, "error"
|
||||||
if estimated <= 0:
|
if estimated <= 0:
|
||||||
return
|
return
|
||||||
if estimated < budget:
|
if estimated < budget:
|
||||||
|
unconsolidated_count = len(session.messages) - session.last_consolidated
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Token consolidation idle {}: {}/{} via {}",
|
"Token consolidation idle {}: {}/{} via {}, msgs={}",
|
||||||
session.key,
|
session.key,
|
||||||
estimated,
|
estimated,
|
||||||
self.context_window_tokens,
|
self.context_window_tokens,
|
||||||
source,
|
source,
|
||||||
|
unconsolidated_count,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -488,6 +511,15 @@ class Consolidator:
|
|||||||
return
|
return
|
||||||
|
|
||||||
end_idx = boundary[0]
|
end_idx = boundary[0]
|
||||||
|
end_idx = self._cap_consolidation_boundary(session, end_idx)
|
||||||
|
if end_idx is None:
|
||||||
|
logger.debug(
|
||||||
|
"Token consolidation: no capped boundary for {} (round {})",
|
||||||
|
session.key,
|
||||||
|
round_num,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
chunk = session.messages[session.last_consolidated:end_idx]
|
chunk = session.messages[session.last_consolidated:end_idx]
|
||||||
if not chunk:
|
if not chunk:
|
||||||
return
|
return
|
||||||
@ -506,7 +538,11 @@ class Consolidator:
|
|||||||
session.last_consolidated = end_idx
|
session.last_consolidated = end_idx
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
|
||||||
estimated, source = self.estimate_session_prompt_tokens(session)
|
try:
|
||||||
|
estimated, source = self.estimate_session_prompt_tokens(session)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Token estimation failed for {}", session.key)
|
||||||
|
estimated, source = 0, "error"
|
||||||
if estimated <= 0:
|
if estimated <= 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -546,18 +582,60 @@ class Dream:
|
|||||||
|
|
||||||
def _build_tools(self) -> ToolRegistry:
|
def _build_tools(self) -> ToolRegistry:
|
||||||
"""Build a minimal tool registry for the Dream agent."""
|
"""Build a minimal tool registry for the Dream agent."""
|
||||||
from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool
|
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||||
|
from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, WriteFileTool
|
||||||
|
|
||||||
tools = ToolRegistry()
|
tools = ToolRegistry()
|
||||||
workspace = self.store.workspace
|
workspace = self.store.workspace
|
||||||
tools.register(ReadFileTool(workspace=workspace, allowed_dir=workspace))
|
# Allow reading builtin skills for reference during skill creation
|
||||||
|
extra_read = [BUILTIN_SKILLS_DIR] if BUILTIN_SKILLS_DIR.exists() else None
|
||||||
|
tools.register(ReadFileTool(
|
||||||
|
workspace=workspace,
|
||||||
|
allowed_dir=workspace,
|
||||||
|
extra_allowed_dirs=extra_read,
|
||||||
|
))
|
||||||
tools.register(EditFileTool(workspace=workspace, allowed_dir=workspace))
|
tools.register(EditFileTool(workspace=workspace, allowed_dir=workspace))
|
||||||
|
# write_file resolves relative paths from workspace root, but can only
|
||||||
|
# write under skills/ so the prompt can safely use skills/<name>/SKILL.md.
|
||||||
|
skills_dir = workspace / "skills"
|
||||||
|
skills_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
tools.register(WriteFileTool(workspace=workspace, allowed_dir=skills_dir))
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
|
# -- skill listing --------------------------------------------------------
|
||||||
|
|
||||||
|
def _list_existing_skills(self) -> list[str]:
|
||||||
|
"""List existing skills as 'name — description' for dedup context."""
|
||||||
|
import re as _re
|
||||||
|
|
||||||
|
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||||
|
|
||||||
|
_DESC_RE = _re.compile(r"^description:\s*(.+)$", _re.MULTILINE | _re.IGNORECASE)
|
||||||
|
entries: dict[str, str] = {}
|
||||||
|
for base in (self.store.workspace / "skills", BUILTIN_SKILLS_DIR):
|
||||||
|
if not base.exists():
|
||||||
|
continue
|
||||||
|
for d in base.iterdir():
|
||||||
|
if not d.is_dir():
|
||||||
|
continue
|
||||||
|
skill_md = d / "SKILL.md"
|
||||||
|
if not skill_md.exists():
|
||||||
|
continue
|
||||||
|
# Prefer workspace skills over builtin (same name)
|
||||||
|
if d.name in entries and base == BUILTIN_SKILLS_DIR:
|
||||||
|
continue
|
||||||
|
content = skill_md.read_text(encoding="utf-8")[:500]
|
||||||
|
m = _DESC_RE.search(content)
|
||||||
|
desc = m.group(1).strip() if m else "(no description)"
|
||||||
|
entries[d.name] = desc
|
||||||
|
return [f"{name} — {desc}" for name, desc in sorted(entries.items())]
|
||||||
|
|
||||||
# -- main entry ----------------------------------------------------------
|
# -- main entry ----------------------------------------------------------
|
||||||
|
|
||||||
async def run(self) -> bool:
|
async def run(self) -> bool:
|
||||||
"""Process unprocessed history entries. Returns True if work was done."""
|
"""Process unprocessed history entries. Returns True if work was done."""
|
||||||
|
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||||
|
|
||||||
last_cursor = self.store.get_last_dream_cursor()
|
last_cursor = self.store.get_last_dream_cursor()
|
||||||
entries = self.store.read_unprocessed_history(since_cursor=last_cursor)
|
entries = self.store.read_unprocessed_history(since_cursor=last_cursor)
|
||||||
if not entries:
|
if not entries:
|
||||||
@ -579,6 +657,7 @@ class Dream:
|
|||||||
current_memory = self.store.read_memory() or "(empty)"
|
current_memory = self.store.read_memory() or "(empty)"
|
||||||
current_soul = self.store.read_soul() or "(empty)"
|
current_soul = self.store.read_soul() or "(empty)"
|
||||||
current_user = self.store.read_user() or "(empty)"
|
current_user = self.store.read_user() or "(empty)"
|
||||||
|
|
||||||
file_context = (
|
file_context = (
|
||||||
f"## Current Date\n{current_date}\n\n"
|
f"## Current Date\n{current_date}\n\n"
|
||||||
f"## Current MEMORY.md ({len(current_memory)} chars)\n{current_memory}\n\n"
|
f"## Current MEMORY.md ({len(current_memory)} chars)\n{current_memory}\n\n"
|
||||||
@ -586,7 +665,7 @@ class Dream:
|
|||||||
f"## Current USER.md ({len(current_user)} chars)\n{current_user}"
|
f"## Current USER.md ({len(current_user)} chars)\n{current_user}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Phase 1: Analyze
|
# Phase 1: Analyze (no skills list — dedup is Phase 2's job)
|
||||||
phase1_prompt = (
|
phase1_prompt = (
|
||||||
f"## Conversation History\n{history_text}\n\n{file_context}"
|
f"## Conversation History\n{history_text}\n\n{file_context}"
|
||||||
)
|
)
|
||||||
@ -611,13 +690,25 @@ class Dream:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Phase 2: Delegate to AgentRunner with read_file / edit_file
|
# Phase 2: Delegate to AgentRunner with read_file / edit_file
|
||||||
phase2_prompt = f"## Analysis Result\n{analysis}\n\n{file_context}"
|
existing_skills = self._list_existing_skills()
|
||||||
|
skills_section = ""
|
||||||
|
if existing_skills:
|
||||||
|
skills_section = (
|
||||||
|
"\n\n## Existing Skills\n"
|
||||||
|
+ "\n".join(f"- {s}" for s in existing_skills)
|
||||||
|
)
|
||||||
|
phase2_prompt = f"## Analysis Result\n{analysis}\n\n{file_context}{skills_section}"
|
||||||
|
|
||||||
tools = self._tools
|
tools = self._tools
|
||||||
|
skill_creator_path = BUILTIN_SKILLS_DIR / "skill-creator" / "SKILL.md"
|
||||||
messages: list[dict[str, Any]] = [
|
messages: list[dict[str, Any]] = [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": render_template("agent/dream_phase2.md", strip=True),
|
"content": render_template(
|
||||||
|
"agent/dream_phase2.md",
|
||||||
|
strip=True,
|
||||||
|
skill_creator_path=str(skill_creator_path),
|
||||||
|
),
|
||||||
},
|
},
|
||||||
{"role": "user", "content": phase2_prompt},
|
{"role": "user", "content": phase2_prompt},
|
||||||
]
|
]
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
import inspect
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@ -31,8 +32,11 @@ from nanobot.utils.runtime import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
|
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
|
||||||
|
_PERSISTED_MODEL_ERROR_PLACEHOLDER = "[Assistant reply unavailable due to model error.]"
|
||||||
_MAX_EMPTY_RETRIES = 2
|
_MAX_EMPTY_RETRIES = 2
|
||||||
_MAX_LENGTH_RECOVERIES = 3
|
_MAX_LENGTH_RECOVERIES = 3
|
||||||
|
_MAX_INJECTIONS_PER_TURN = 3
|
||||||
|
_MAX_INJECTION_CYCLES = 5
|
||||||
_SNIP_SAFETY_BUFFER = 1024
|
_SNIP_SAFETY_BUFFER = 1024
|
||||||
_MICROCOMPACT_KEEP_RECENT = 10
|
_MICROCOMPACT_KEEP_RECENT = 10
|
||||||
_MICROCOMPACT_MIN_CHARS = 500
|
_MICROCOMPACT_MIN_CHARS = 500
|
||||||
@ -41,6 +45,9 @@ _COMPACTABLE_TOOLS = frozenset({
|
|||||||
"web_search", "web_fetch", "list_dir",
|
"web_search", "web_fetch", "list_dir",
|
||||||
})
|
})
|
||||||
_BACKFILL_CONTENT = "[Tool result unavailable — call was interrupted or lost]"
|
_BACKFILL_CONTENT = "[Tool result unavailable — call was interrupted or lost]"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class AgentRunSpec:
|
class AgentRunSpec:
|
||||||
"""Configuration for a single agent execution."""
|
"""Configuration for a single agent execution."""
|
||||||
@ -65,6 +72,7 @@ class AgentRunSpec:
|
|||||||
provider_retry_mode: str = "standard"
|
provider_retry_mode: str = "standard"
|
||||||
progress_callback: Any | None = None
|
progress_callback: Any | None = None
|
||||||
checkpoint_callback: Any | None = None
|
checkpoint_callback: Any | None = None
|
||||||
|
injection_callback: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
@ -78,6 +86,7 @@ class AgentRunResult:
|
|||||||
stop_reason: str = "completed"
|
stop_reason: str = "completed"
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
tool_events: list[dict[str, str]] = field(default_factory=list)
|
tool_events: list[dict[str, str]] = field(default_factory=list)
|
||||||
|
had_injections: bool = False
|
||||||
|
|
||||||
|
|
||||||
class AgentRunner:
|
class AgentRunner:
|
||||||
@ -86,6 +95,134 @@ class AgentRunner:
|
|||||||
def __init__(self, provider: LLMProvider):
|
def __init__(self, provider: LLMProvider):
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
|
||||||
|
if isinstance(left, str) and isinstance(right, str):
|
||||||
|
return f"{left}\n\n{right}" if left else right
|
||||||
|
|
||||||
|
def _to_blocks(value: Any) -> list[dict[str, Any]]:
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [
|
||||||
|
item if isinstance(item, dict) else {"type": "text", "text": str(item)}
|
||||||
|
for item in value
|
||||||
|
]
|
||||||
|
if value is None:
|
||||||
|
return []
|
||||||
|
return [{"type": "text", "text": str(value)}]
|
||||||
|
|
||||||
|
return _to_blocks(left) + _to_blocks(right)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _append_injected_messages(
|
||||||
|
cls,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
injections: list[dict[str, Any]],
|
||||||
|
) -> None:
|
||||||
|
"""Append injected user messages while preserving role alternation."""
|
||||||
|
for injection in injections:
|
||||||
|
if (
|
||||||
|
messages
|
||||||
|
and injection.get("role") == "user"
|
||||||
|
and messages[-1].get("role") == "user"
|
||||||
|
):
|
||||||
|
merged = dict(messages[-1])
|
||||||
|
merged["content"] = cls._merge_message_content(
|
||||||
|
merged.get("content"),
|
||||||
|
injection.get("content"),
|
||||||
|
)
|
||||||
|
messages[-1] = merged
|
||||||
|
continue
|
||||||
|
messages.append(injection)
|
||||||
|
|
||||||
|
async def _try_drain_injections(
|
||||||
|
self,
|
||||||
|
spec: AgentRunSpec,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
assistant_message: dict[str, Any] | None,
|
||||||
|
injection_cycles: int,
|
||||||
|
*,
|
||||||
|
phase: str = "after error",
|
||||||
|
iteration: int | None = None,
|
||||||
|
) -> tuple[bool, int]:
|
||||||
|
"""Drain pending injections. Returns (should_continue, updated_cycles).
|
||||||
|
|
||||||
|
If injections are found and we haven't exceeded _MAX_INJECTION_CYCLES,
|
||||||
|
append them to *messages* (and emit a checkpoint if *assistant_message*
|
||||||
|
and *iteration* are both provided) and return (True, cycles+1) so the
|
||||||
|
caller continues the iteration loop. Otherwise return (False, cycles).
|
||||||
|
"""
|
||||||
|
if injection_cycles >= _MAX_INJECTION_CYCLES:
|
||||||
|
return False, injection_cycles
|
||||||
|
injections = await self._drain_injections(spec)
|
||||||
|
if not injections:
|
||||||
|
return False, injection_cycles
|
||||||
|
injection_cycles += 1
|
||||||
|
if assistant_message is not None:
|
||||||
|
messages.append(assistant_message)
|
||||||
|
if iteration is not None:
|
||||||
|
await self._emit_checkpoint(
|
||||||
|
spec,
|
||||||
|
{
|
||||||
|
"phase": "final_response",
|
||||||
|
"iteration": iteration,
|
||||||
|
"model": spec.model,
|
||||||
|
"assistant_message": assistant_message,
|
||||||
|
"completed_tool_results": [],
|
||||||
|
"pending_tool_calls": [],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self._append_injected_messages(messages, injections)
|
||||||
|
logger.info(
|
||||||
|
"Injected {} follow-up message(s) {} ({}/{})",
|
||||||
|
len(injections), phase, injection_cycles, _MAX_INJECTION_CYCLES,
|
||||||
|
)
|
||||||
|
return True, injection_cycles
|
||||||
|
|
||||||
|
async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]:
|
||||||
|
"""Drain pending user messages via the injection callback.
|
||||||
|
|
||||||
|
Returns normalized user messages (capped by
|
||||||
|
``_MAX_INJECTIONS_PER_TURN``), or an empty list when there is
|
||||||
|
nothing to inject. Messages beyond the cap are logged so they
|
||||||
|
are not silently lost.
|
||||||
|
"""
|
||||||
|
if spec.injection_callback is None:
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
signature = inspect.signature(spec.injection_callback)
|
||||||
|
accepts_limit = (
|
||||||
|
"limit" in signature.parameters
|
||||||
|
or any(
|
||||||
|
parameter.kind is inspect.Parameter.VAR_KEYWORD
|
||||||
|
for parameter in signature.parameters.values()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if accepts_limit:
|
||||||
|
items = await spec.injection_callback(limit=_MAX_INJECTIONS_PER_TURN)
|
||||||
|
else:
|
||||||
|
items = await spec.injection_callback()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("injection_callback failed")
|
||||||
|
return []
|
||||||
|
if not items:
|
||||||
|
return []
|
||||||
|
injected_messages: list[dict[str, Any]] = []
|
||||||
|
for item in items:
|
||||||
|
if isinstance(item, dict) and item.get("role") == "user" and "content" in item:
|
||||||
|
injected_messages.append(item)
|
||||||
|
continue
|
||||||
|
text = getattr(item, "content", str(item))
|
||||||
|
if text.strip():
|
||||||
|
injected_messages.append({"role": "user", "content": text})
|
||||||
|
if len(injected_messages) > _MAX_INJECTIONS_PER_TURN:
|
||||||
|
dropped = len(injected_messages) - _MAX_INJECTIONS_PER_TURN
|
||||||
|
logger.warning(
|
||||||
|
"Injection callback returned {} messages, capping to {} ({} dropped)",
|
||||||
|
len(injected_messages), _MAX_INJECTIONS_PER_TURN, dropped,
|
||||||
|
)
|
||||||
|
injected_messages = injected_messages[:_MAX_INJECTIONS_PER_TURN]
|
||||||
|
return injected_messages
|
||||||
|
|
||||||
async def run(self, spec: AgentRunSpec) -> AgentRunResult:
|
async def run(self, spec: AgentRunSpec) -> AgentRunResult:
|
||||||
hook = spec.hook or AgentHook()
|
hook = spec.hook or AgentHook()
|
||||||
messages = list(spec.initial_messages)
|
messages = list(spec.initial_messages)
|
||||||
@ -98,21 +235,35 @@ class AgentRunner:
|
|||||||
external_lookup_counts: dict[str, int] = {}
|
external_lookup_counts: dict[str, int] = {}
|
||||||
empty_content_retries = 0
|
empty_content_retries = 0
|
||||||
length_recovery_count = 0
|
length_recovery_count = 0
|
||||||
|
had_injections = False
|
||||||
|
injection_cycles = 0
|
||||||
|
|
||||||
for iteration in range(spec.max_iterations):
|
for iteration in range(spec.max_iterations):
|
||||||
try:
|
try:
|
||||||
messages = self._backfill_missing_tool_results(messages)
|
# Keep the persisted conversation untouched. Context governance
|
||||||
messages = self._microcompact(messages)
|
# may repair or compact historical messages for the model, but
|
||||||
messages = self._apply_tool_result_budget(spec, messages)
|
# those synthetic edits must not shift the append boundary used
|
||||||
messages_for_model = self._snip_history(spec, messages)
|
# later when the caller saves only the new turn.
|
||||||
|
messages_for_model = self._drop_orphan_tool_results(messages)
|
||||||
|
messages_for_model = self._backfill_missing_tool_results(messages_for_model)
|
||||||
|
messages_for_model = self._microcompact(messages_for_model)
|
||||||
|
messages_for_model = self._apply_tool_result_budget(spec, messages_for_model)
|
||||||
|
messages_for_model = self._snip_history(spec, messages_for_model)
|
||||||
|
# Snipping may have created new orphans; clean them up.
|
||||||
|
messages_for_model = self._drop_orphan_tool_results(messages_for_model)
|
||||||
|
messages_for_model = self._backfill_missing_tool_results(messages_for_model)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Context governance failed on turn {} for {}: {}; using raw messages",
|
"Context governance failed on turn {} for {}: {}; applying minimal repair",
|
||||||
iteration,
|
iteration,
|
||||||
spec.session_key or "default",
|
spec.session_key or "default",
|
||||||
exc,
|
exc,
|
||||||
)
|
)
|
||||||
messages_for_model = messages
|
try:
|
||||||
|
messages_for_model = self._drop_orphan_tool_results(messages)
|
||||||
|
messages_for_model = self._backfill_missing_tool_results(messages_for_model)
|
||||||
|
except Exception:
|
||||||
|
messages_for_model = messages
|
||||||
context = AgentHookContext(iteration=iteration, messages=messages)
|
context = AgentHookContext(iteration=iteration, messages=messages)
|
||||||
await hook.before_iteration(context)
|
await hook.before_iteration(context)
|
||||||
response = await self._request_model(spec, messages_for_model, hook, context)
|
response = await self._request_model(spec, messages_for_model, hook, context)
|
||||||
@ -156,16 +307,6 @@ class AgentRunner:
|
|||||||
tool_events.extend(new_events)
|
tool_events.extend(new_events)
|
||||||
context.tool_results = list(results)
|
context.tool_results = list(results)
|
||||||
context.tool_events = list(new_events)
|
context.tool_events = list(new_events)
|
||||||
if fatal_error is not None:
|
|
||||||
error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
|
|
||||||
final_content = error
|
|
||||||
stop_reason = "tool_error"
|
|
||||||
self._append_final_message(messages, final_content)
|
|
||||||
context.final_content = final_content
|
|
||||||
context.error = error
|
|
||||||
context.stop_reason = stop_reason
|
|
||||||
await hook.after_iteration(context)
|
|
||||||
break
|
|
||||||
completed_tool_results: list[dict[str, Any]] = []
|
completed_tool_results: list[dict[str, Any]] = []
|
||||||
for tool_call, result in zip(response.tool_calls, results):
|
for tool_call, result in zip(response.tool_calls, results):
|
||||||
tool_message = {
|
tool_message = {
|
||||||
@ -181,6 +322,23 @@ class AgentRunner:
|
|||||||
}
|
}
|
||||||
messages.append(tool_message)
|
messages.append(tool_message)
|
||||||
completed_tool_results.append(tool_message)
|
completed_tool_results.append(tool_message)
|
||||||
|
if fatal_error is not None:
|
||||||
|
error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
|
||||||
|
final_content = error
|
||||||
|
stop_reason = "tool_error"
|
||||||
|
self._append_final_message(messages, final_content)
|
||||||
|
context.final_content = final_content
|
||||||
|
context.error = error
|
||||||
|
context.stop_reason = stop_reason
|
||||||
|
await hook.after_iteration(context)
|
||||||
|
should_continue, injection_cycles = await self._try_drain_injections(
|
||||||
|
spec, messages, None, injection_cycles,
|
||||||
|
phase="after tool error",
|
||||||
|
)
|
||||||
|
if should_continue:
|
||||||
|
had_injections = True
|
||||||
|
continue
|
||||||
|
break
|
||||||
await self._emit_checkpoint(
|
await self._emit_checkpoint(
|
||||||
spec,
|
spec,
|
||||||
{
|
{
|
||||||
@ -194,6 +352,13 @@ class AgentRunner:
|
|||||||
)
|
)
|
||||||
empty_content_retries = 0
|
empty_content_retries = 0
|
||||||
length_recovery_count = 0
|
length_recovery_count = 0
|
||||||
|
# Checkpoint 1: drain injections after tools, before next LLM call
|
||||||
|
_drained, injection_cycles = await self._try_drain_injections(
|
||||||
|
spec, messages, None, injection_cycles,
|
||||||
|
phase="after tool execution",
|
||||||
|
)
|
||||||
|
if _drained:
|
||||||
|
had_injections = True
|
||||||
await hook.after_iteration(context)
|
await hook.after_iteration(context)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -250,18 +415,48 @@ class AgentRunner:
|
|||||||
await hook.after_iteration(context)
|
await hook.after_iteration(context)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
assistant_message: dict[str, Any] | None = None
|
||||||
|
if response.finish_reason != "error" and not is_blank_text(clean):
|
||||||
|
assistant_message = build_assistant_message(
|
||||||
|
clean,
|
||||||
|
reasoning_content=response.reasoning_content,
|
||||||
|
thinking_blocks=response.thinking_blocks,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for mid-turn injections BEFORE signaling stream end.
|
||||||
|
# If injections are found we keep the stream alive (resuming=True)
|
||||||
|
# so streaming channels don't prematurely finalize the card.
|
||||||
|
should_continue, injection_cycles = await self._try_drain_injections(
|
||||||
|
spec, messages, assistant_message, injection_cycles,
|
||||||
|
phase="after final response",
|
||||||
|
iteration=iteration,
|
||||||
|
)
|
||||||
|
if should_continue:
|
||||||
|
had_injections = True
|
||||||
|
|
||||||
if hook.wants_streaming():
|
if hook.wants_streaming():
|
||||||
await hook.on_stream_end(context, resuming=False)
|
await hook.on_stream_end(context, resuming=should_continue)
|
||||||
|
|
||||||
|
if should_continue:
|
||||||
|
await hook.after_iteration(context)
|
||||||
|
continue
|
||||||
|
|
||||||
if response.finish_reason == "error":
|
if response.finish_reason == "error":
|
||||||
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
|
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
|
||||||
stop_reason = "error"
|
stop_reason = "error"
|
||||||
error = final_content
|
error = final_content
|
||||||
self._append_final_message(messages, final_content)
|
self._append_model_error_placeholder(messages)
|
||||||
context.final_content = final_content
|
context.final_content = final_content
|
||||||
context.error = error
|
context.error = error
|
||||||
context.stop_reason = stop_reason
|
context.stop_reason = stop_reason
|
||||||
await hook.after_iteration(context)
|
await hook.after_iteration(context)
|
||||||
|
should_continue, injection_cycles = await self._try_drain_injections(
|
||||||
|
spec, messages, None, injection_cycles,
|
||||||
|
phase="after LLM error",
|
||||||
|
)
|
||||||
|
if should_continue:
|
||||||
|
had_injections = True
|
||||||
|
continue
|
||||||
break
|
break
|
||||||
if is_blank_text(clean):
|
if is_blank_text(clean):
|
||||||
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
|
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||||
@ -272,9 +467,16 @@ class AgentRunner:
|
|||||||
context.error = error
|
context.error = error
|
||||||
context.stop_reason = stop_reason
|
context.stop_reason = stop_reason
|
||||||
await hook.after_iteration(context)
|
await hook.after_iteration(context)
|
||||||
|
should_continue, injection_cycles = await self._try_drain_injections(
|
||||||
|
spec, messages, None, injection_cycles,
|
||||||
|
phase="after empty response",
|
||||||
|
)
|
||||||
|
if should_continue:
|
||||||
|
had_injections = True
|
||||||
|
continue
|
||||||
break
|
break
|
||||||
|
|
||||||
messages.append(build_assistant_message(
|
messages.append(assistant_message or build_assistant_message(
|
||||||
clean,
|
clean,
|
||||||
reasoning_content=response.reasoning_content,
|
reasoning_content=response.reasoning_content,
|
||||||
thinking_blocks=response.thinking_blocks,
|
thinking_blocks=response.thinking_blocks,
|
||||||
@ -308,6 +510,17 @@ class AgentRunner:
|
|||||||
max_iterations=spec.max_iterations,
|
max_iterations=spec.max_iterations,
|
||||||
)
|
)
|
||||||
self._append_final_message(messages, final_content)
|
self._append_final_message(messages, final_content)
|
||||||
|
# Drain any remaining injections so they are appended to the
|
||||||
|
# conversation history instead of being re-published as
|
||||||
|
# independent inbound messages by _dispatch's finally block.
|
||||||
|
# We ignore should_continue here because the for-loop has already
|
||||||
|
# exhausted all iterations.
|
||||||
|
drained_after_max_iterations, injection_cycles = await self._try_drain_injections(
|
||||||
|
spec, messages, None, injection_cycles,
|
||||||
|
phase="after max_iterations",
|
||||||
|
)
|
||||||
|
if drained_after_max_iterations:
|
||||||
|
had_injections = True
|
||||||
|
|
||||||
return AgentRunResult(
|
return AgentRunResult(
|
||||||
final_content=final_content,
|
final_content=final_content,
|
||||||
@ -317,6 +530,7 @@ class AgentRunner:
|
|||||||
stop_reason=stop_reason,
|
stop_reason=stop_reason,
|
||||||
error=error,
|
error=error,
|
||||||
tool_events=tool_events,
|
tool_events=tool_events,
|
||||||
|
had_injections=had_injections,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _build_request_kwargs(
|
def _build_request_kwargs(
|
||||||
@ -521,6 +735,12 @@ class AgentRunner:
|
|||||||
return
|
return
|
||||||
messages.append(build_assistant_message(content))
|
messages.append(build_assistant_message(content))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _append_model_error_placeholder(messages: list[dict[str, Any]]) -> None:
|
||||||
|
if messages and messages[-1].get("role") == "assistant" and not messages[-1].get("tool_calls"):
|
||||||
|
return
|
||||||
|
messages.append(build_assistant_message(_PERSISTED_MODEL_ERROR_PLACEHOLDER))
|
||||||
|
|
||||||
def _normalize_tool_result(
|
def _normalize_tool_result(
|
||||||
self,
|
self,
|
||||||
spec: AgentRunSpec,
|
spec: AgentRunSpec,
|
||||||
@ -549,6 +769,32 @@ class AgentRunner:
|
|||||||
return truncate_text(content, spec.max_tool_result_chars)
|
return truncate_text(content, spec.max_tool_result_chars)
|
||||||
return content
|
return content
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _drop_orphan_tool_results(
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Drop tool results that have no matching assistant tool_call earlier in the history."""
|
||||||
|
declared: set[str] = set()
|
||||||
|
updated: list[dict[str, Any]] | None = None
|
||||||
|
for idx, msg in enumerate(messages):
|
||||||
|
role = msg.get("role")
|
||||||
|
if role == "assistant":
|
||||||
|
for tc in msg.get("tool_calls") or []:
|
||||||
|
if isinstance(tc, dict) and tc.get("id"):
|
||||||
|
declared.add(str(tc["id"]))
|
||||||
|
if role == "tool":
|
||||||
|
tid = msg.get("tool_call_id")
|
||||||
|
if tid and str(tid) not in declared:
|
||||||
|
if updated is None:
|
||||||
|
updated = [dict(m) for m in messages[:idx]]
|
||||||
|
continue
|
||||||
|
if updated is not None:
|
||||||
|
updated.append(dict(msg))
|
||||||
|
|
||||||
|
if updated is None:
|
||||||
|
return messages
|
||||||
|
return updated
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _backfill_missing_tool_results(
|
def _backfill_missing_tool_results(
|
||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
|
|||||||
@ -28,10 +28,11 @@ class SkillsLoader:
|
|||||||
specific tools or perform certain tasks.
|
specific tools or perform certain tasks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, workspace: Path, builtin_skills_dir: Path | None = None):
|
def __init__(self, workspace: Path, builtin_skills_dir: Path | None = None, disabled_skills: set[str] | None = None):
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.workspace_skills = workspace / "skills"
|
self.workspace_skills = workspace / "skills"
|
||||||
self.builtin_skills = builtin_skills_dir or BUILTIN_SKILLS_DIR
|
self.builtin_skills = builtin_skills_dir or BUILTIN_SKILLS_DIR
|
||||||
|
self.disabled_skills = disabled_skills or set()
|
||||||
|
|
||||||
def _skill_entries_from_dir(self, base: Path, source: str, *, skip_names: set[str] | None = None) -> list[dict[str, str]]:
|
def _skill_entries_from_dir(self, base: Path, source: str, *, skip_names: set[str] | None = None) -> list[dict[str, str]]:
|
||||||
if not base.exists():
|
if not base.exists():
|
||||||
@ -66,6 +67,9 @@ class SkillsLoader:
|
|||||||
self._skill_entries_from_dir(self.builtin_skills, "builtin", skip_names=workspace_names)
|
self._skill_entries_from_dir(self.builtin_skills, "builtin", skip_names=workspace_names)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.disabled_skills:
|
||||||
|
skills = [s for s in skills if s["name"] not in self.disabled_skills]
|
||||||
|
|
||||||
if filter_unavailable:
|
if filter_unavailable:
|
||||||
return [skill for skill in skills if self._check_requirements(self._get_skill_meta(skill["name"]))]
|
return [skill for skill in skills if self._check_requirements(self._get_skill_meta(skill["name"]))]
|
||||||
return skills
|
return skills
|
||||||
|
|||||||
@ -27,6 +27,7 @@ class _SubagentHook(AgentHook):
|
|||||||
"""Logging-only hook for subagent execution."""
|
"""Logging-only hook for subagent execution."""
|
||||||
|
|
||||||
def __init__(self, task_id: str) -> None:
|
def __init__(self, task_id: str) -> None:
|
||||||
|
super().__init__()
|
||||||
self._task_id = task_id
|
self._task_id = task_id
|
||||||
|
|
||||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||||
@ -51,6 +52,7 @@ class SubagentManager:
|
|||||||
web_config: "WebToolsConfig | None" = None,
|
web_config: "WebToolsConfig | None" = None,
|
||||||
exec_config: "ExecToolConfig | None" = None,
|
exec_config: "ExecToolConfig | None" = None,
|
||||||
restrict_to_workspace: bool = False,
|
restrict_to_workspace: bool = False,
|
||||||
|
disabled_skills: list[str] | None = None,
|
||||||
):
|
):
|
||||||
from nanobot.config.schema import ExecToolConfig
|
from nanobot.config.schema import ExecToolConfig
|
||||||
|
|
||||||
@ -62,6 +64,7 @@ class SubagentManager:
|
|||||||
self.max_tool_result_chars = max_tool_result_chars
|
self.max_tool_result_chars = max_tool_result_chars
|
||||||
self.exec_config = exec_config or ExecToolConfig()
|
self.exec_config = exec_config or ExecToolConfig()
|
||||||
self.restrict_to_workspace = restrict_to_workspace
|
self.restrict_to_workspace = restrict_to_workspace
|
||||||
|
self.disabled_skills = set(disabled_skills or [])
|
||||||
self.runner = AgentRunner(provider)
|
self.runner = AgentRunner(provider)
|
||||||
self._running_tasks: dict[str, asyncio.Task[None]] = {}
|
self._running_tasks: dict[str, asyncio.Task[None]] = {}
|
||||||
self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...}
|
self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...}
|
||||||
@ -235,7 +238,10 @@ class SubagentManager:
|
|||||||
from nanobot.agent.skills import SkillsLoader
|
from nanobot.agent.skills import SkillsLoader
|
||||||
|
|
||||||
time_ctx = ContextBuilder._build_runtime_context(None, None)
|
time_ctx = ContextBuilder._build_runtime_context(None, None)
|
||||||
skills_summary = SkillsLoader(self.workspace).build_skills_summary()
|
skills_summary = SkillsLoader(
|
||||||
|
self.workspace,
|
||||||
|
disabled_skills=self.disabled_skills,
|
||||||
|
).build_skills_summary()
|
||||||
return render_template(
|
return render_template(
|
||||||
"agent/subagent_system.md",
|
"agent/subagent_system.md",
|
||||||
time_ctx=time_ctx,
|
time_ctx=time_ctx,
|
||||||
|
|||||||
105
nanobot/agent/tools/file_state.py
Normal file
105
nanobot/agent/tools/file_state.py
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
"""Track file-read state for read-before-edit warnings and read deduplication."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class ReadState:
|
||||||
|
mtime: float
|
||||||
|
offset: int
|
||||||
|
limit: int | None
|
||||||
|
content_hash: str | None
|
||||||
|
can_dedup: bool
|
||||||
|
|
||||||
|
|
||||||
|
_state: dict[str, ReadState] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_file(p: str) -> str | None:
    """Return the SHA-256 hex digest of the file at *p*.

    The file is fed to the hash in fixed-size chunks so that large files are
    never held in memory in one piece.

    Returns:
        The hex digest, or None when the file cannot be opened or read
        (any ``OSError``).
    """
    digest = hashlib.sha256()
    try:
        with open(p, "rb") as fh:
            # 1 MiB chunks: big enough to be fast, small enough to stay cheap.
            while chunk := fh.read(1024 * 1024):
                digest.update(chunk)
    except OSError:
        return None
    return digest.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def record_read(path: str | Path, offset: int = 1, limit: int | None = None) -> None:
    """Remember that *path* was read with the given ``offset``/``limit``.

    The entry is keyed by the resolved path and marked as eligible for read
    deduplication. If the modification time cannot be obtained, nothing is
    stored and any previous entry is left as it is.
    """
    resolved = str(Path(path).resolve())
    try:
        modified_at = os.path.getmtime(resolved)
    except OSError:
        return
    digest = _hash_file(resolved)
    _state[resolved] = ReadState(
        mtime=modified_at,
        offset=offset,
        limit=limit,
        content_hash=digest,
        can_dedup=True,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def record_write(path: str | Path) -> None:
    """Refresh the tracked state of *path* after it has been written.

    The new entry is not eligible for read deduplication. When the
    modification time cannot be obtained, the entry for the path is dropped.
    """
    resolved = str(Path(path).resolve())
    try:
        modified_at = os.path.getmtime(resolved)
    except OSError:
        # No usable mtime: forget whatever was tracked for this path.
        _state.pop(resolved, None)
        return
    digest = _hash_file(resolved)
    _state[resolved] = ReadState(
        mtime=modified_at,
        offset=1,
        limit=None,
        content_hash=digest,
        can_dedup=False,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def check_read(path: str | Path) -> str | None:
    """Tell whether *path* has been read and is still fresh.

    Returns:
        None when editing may proceed, otherwise a warning string.

    A changed mtime alone is not treated as staleness: when the current
    content hash equals the recorded one (e.g. after ``touch`` or an editor
    save without changes), the stored mtime is refreshed and the check passes.
    """
    resolved = str(Path(path).resolve())
    tracked = _state.get(resolved)
    if tracked is None:
        return "Warning: file has not been read yet. Read it first to verify content before editing."

    try:
        mtime_now = os.path.getmtime(resolved)
    except OSError:
        # mtime unavailable: no freshness verdict can be made, so do not warn.
        return None

    if mtime_now == tracked.mtime:
        return None

    same_content = bool(tracked.content_hash) and _hash_file(resolved) == tracked.content_hash
    if not same_content:
        return "Warning: file has been modified since last read. Re-read to verify content before editing."

    tracked.mtime = mtime_now
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def is_unchanged(path: str | Path, offset: int = 1, limit: int | None = None) -> bool:
    """Report whether a repeated read of *path* would return the same content.

    True only when the path has a dedup-eligible entry recorded with the same
    ``offset`` and ``limit`` and the file's current mtime equals the recorded one.
    """
    resolved = str(Path(path).resolve())
    tracked = _state.get(resolved)
    if (
        tracked is None
        or not tracked.can_dedup
        or tracked.offset != offset
        or tracked.limit != limit
    ):
        return False
    try:
        mtime_now = os.path.getmtime(resolved)
    except OSError:
        return False
    return mtime_now == tracked.mtime
|
||||||
|
|
||||||
|
|
||||||
|
def clear() -> None:
    """Forget every tracked file (handy for resetting state between tests)."""
    # Emptied in place so existing references to the dict observe the reset.
    _state.clear()
|
||||||
@ -2,11 +2,13 @@
|
|||||||
|
|
||||||
import difflib
|
import difflib
|
||||||
import mimetypes
|
import mimetypes
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||||
from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema
|
from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema
|
||||||
|
from nanobot.agent.tools import file_state
|
||||||
from nanobot.utils.helpers import build_image_content_blocks, detect_image_mime
|
from nanobot.utils.helpers import build_image_content_blocks, detect_image_mime
|
||||||
from nanobot.config.paths import get_media_dir
|
from nanobot.config.paths import get_media_dir
|
||||||
|
|
||||||
@ -60,6 +62,36 @@ class _FsTool(Tool):
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
# Device paths that are refused for reading.
_BLOCKED_DEVICE_PATHS = frozenset((
    "/dev/zero",
    "/dev/random",
    "/dev/urandom",
    "/dev/full",
    "/dev/stdin",
    "/dev/stdout",
    "/dev/stderr",
    "/dev/tty",
    "/dev/console",
    "/dev/fd/0",
    "/dev/fd/1",
    "/dev/fd/2",
))


def _is_blocked_device(path: str | Path) -> bool:
    """Return True when *path* names a device that could hang or never end.

    The textual form of the path is compared against the blocked set and
    against the ``/proc/<pid>/fd/{0,1,2}`` and ``/proc/self/fd/{0,1,2}``
    patterns; the path is not resolved or normalized here.
    """
    import re

    candidate = str(path)
    proc_fd_patterns = (r"/proc/\d+/fd/[012]$", r"/proc/self/fd/[012]$")
    return candidate in _BLOCKED_DEVICE_PATHS or any(
        re.match(pattern, candidate) is not None for pattern in proc_fd_patterns
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_page_range(pages: str, total: int) -> tuple[int, int]:
    """Parse a page spec such as ``'2-5'`` or ``'3'`` into 0-based bounds.

    Args:
        pages: A single 1-based page number or a ``first-last`` range.
        total: Number of pages in the document; the end bound is clamped to it.

    Returns:
        ``(start, end)``, both 0-based and inclusive. ``start`` is clamped to
        0 and ``end`` to ``total - 1``; the result may have ``start > end``
        when the request lies outside the document, which the caller must check.

    Raises:
        ValueError: If a component is not an integer or the spec has more
            than one ``-`` separator (e.g. ``'1-2-3'``).
    """
    parts = pages.strip().split("-")
    if len(parts) > 2:
        # Previously the extra components were silently ignored.
        raise ValueError(f"invalid page range: {pages!r}")
    first = int(parts[0])
    last = int(parts[1]) if len(parts) == 2 else first
    return max(0, first - 1), min(last - 1, total - 1)
|
||||||
|
|
||||||
|
|
||||||
@tool_parameters(
|
@tool_parameters(
|
||||||
tool_parameters_schema(
|
tool_parameters_schema(
|
||||||
path=StringSchema("The file path to read"),
|
path=StringSchema("The file path to read"),
|
||||||
@ -73,6 +105,7 @@ class _FsTool(Tool):
|
|||||||
description="Maximum number of lines to read (default 2000)",
|
description="Maximum number of lines to read (default 2000)",
|
||||||
minimum=1,
|
minimum=1,
|
||||||
),
|
),
|
||||||
|
pages=StringSchema("Page range for PDF files, e.g. '1-5' (default: all, max 20 pages)"),
|
||||||
required=["path"],
|
required=["path"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -81,6 +114,7 @@ class ReadFileTool(_FsTool):
|
|||||||
|
|
||||||
_MAX_CHARS = 128_000
|
_MAX_CHARS = 128_000
|
||||||
_DEFAULT_LIMIT = 2000
|
_DEFAULT_LIMIT = 2000
|
||||||
|
_MAX_PDF_PAGES = 20
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
@ -89,9 +123,10 @@ class ReadFileTool(_FsTool):
|
|||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return (
|
return (
|
||||||
"Read a text file. Output format: LINE_NUM|CONTENT. "
|
"Read a file (text or image). Text output format: LINE_NUM|CONTENT. "
|
||||||
|
"Images return visual content for analysis. "
|
||||||
"Use offset and limit for large files. "
|
"Use offset and limit for large files. "
|
||||||
"Cannot read binary files or images. "
|
"Cannot read non-image binary files. "
|
||||||
"Reads exceeding ~128K chars are truncated."
|
"Reads exceeding ~128K chars are truncated."
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -99,16 +134,27 @@ class ReadFileTool(_FsTool):
|
|||||||
def read_only(self) -> bool:
|
def read_only(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, **kwargs: Any) -> Any:
|
async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, pages: str | None = None, **kwargs: Any) -> Any:
|
||||||
try:
|
try:
|
||||||
if not path:
|
if not path:
|
||||||
return "Error reading file: Unknown path"
|
return "Error reading file: Unknown path"
|
||||||
|
|
||||||
|
# Device path blacklist
|
||||||
|
if _is_blocked_device(path):
|
||||||
|
return f"Error: Reading {path} is blocked (device path that could hang or produce infinite output)."
|
||||||
|
|
||||||
fp = self._resolve(path)
|
fp = self._resolve(path)
|
||||||
|
if _is_blocked_device(fp):
|
||||||
|
return f"Error: Reading {fp} is blocked (device path that could hang or produce infinite output)."
|
||||||
if not fp.exists():
|
if not fp.exists():
|
||||||
return f"Error: File not found: {path}"
|
return f"Error: File not found: {path}"
|
||||||
if not fp.is_file():
|
if not fp.is_file():
|
||||||
return f"Error: Not a file: {path}"
|
return f"Error: Not a file: {path}"
|
||||||
|
|
||||||
|
# PDF support
|
||||||
|
if fp.suffix.lower() == ".pdf":
|
||||||
|
return self._read_pdf(fp, pages)
|
||||||
|
|
||||||
raw = fp.read_bytes()
|
raw = fp.read_bytes()
|
||||||
if not raw:
|
if not raw:
|
||||||
return f"(Empty file: {path})"
|
return f"(Empty file: {path})"
|
||||||
@ -117,6 +163,10 @@ class ReadFileTool(_FsTool):
|
|||||||
if mime and mime.startswith("image/"):
|
if mime and mime.startswith("image/"):
|
||||||
return build_image_content_blocks(raw, mime, str(fp), f"(Image file: {path})")
|
return build_image_content_blocks(raw, mime, str(fp), f"(Image file: {path})")
|
||||||
|
|
||||||
|
# Read dedup: same path + offset + limit + unchanged mtime → stub
|
||||||
|
if file_state.is_unchanged(fp, offset=offset, limit=limit):
|
||||||
|
return f"[File unchanged since last read: {path}]"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
text_content = raw.decode("utf-8")
|
text_content = raw.decode("utf-8")
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
@ -149,12 +199,59 @@ class ReadFileTool(_FsTool):
|
|||||||
result += f"\n\n(Showing lines {offset}-{end} of {total}. Use offset={end + 1} to continue.)"
|
result += f"\n\n(Showing lines {offset}-{end} of {total}. Use offset={end + 1} to continue.)"
|
||||||
else:
|
else:
|
||||||
result += f"\n\n(End of file — {total} lines total)"
|
result += f"\n\n(End of file — {total} lines total)"
|
||||||
|
file_state.record_read(fp, offset=offset, limit=limit)
|
||||||
return result
|
return result
|
||||||
except PermissionError as e:
|
except PermissionError as e:
|
||||||
return f"Error: {e}"
|
return f"Error: {e}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error reading file: {e}"
|
return f"Error reading file: {e}"
|
||||||
|
|
||||||
|
def _read_pdf(self, fp: Path, pages: str | None) -> str:
    """Extract text from a PDF, at most ``_MAX_PDF_PAGES`` pages per call.

    Args:
        fp: Path of the PDF file.
        pages: Optional 1-based page spec such as ``'1-5'``; when omitted,
            reading starts at the first page.

    Returns:
        The extracted text with a ``--- Page N ---`` header per non-empty
        page, followed by a continuation hint when pages remain, or an
        error / informational message string.
    """
    try:
        import fitz  # pymupdf
    except ImportError:
        return "Error: PDF reading requires pymupdf. Install with: pip install pymupdf"

    try:
        doc = fitz.open(str(fp))
    except Exception as e:
        return f"Error reading PDF: {e}"

    # Everything that touches the open document lives inside try/finally so
    # the handle is released even when page extraction raises.
    try:
        total_pages = len(doc)
        if pages:
            try:
                start, end = _parse_page_range(pages, total_pages)
            except (ValueError, IndexError):
                return f"Error: Invalid page range '{pages}'. Use format like '1-5'."
            if start > end or start >= total_pages:
                return f"Error: Page range '{pages}' is out of bounds (document has {total_pages} pages)."
        else:
            start = 0
            end = min(total_pages - 1, self._MAX_PDF_PAGES - 1)

        # Cap an explicit range to the per-call page budget.
        if end - start + 1 > self._MAX_PDF_PAGES:
            end = start + self._MAX_PDF_PAGES - 1

        parts: list[str] = []
        for i in range(start, end + 1):
            text = doc[i].get_text().strip()
            if text:
                parts.append(f"--- Page {i + 1} ---\n{text}")
    finally:
        doc.close()

    if not parts:
        return f"(PDF has no extractable text: {fp})"

    result = "\n\n".join(parts)
    if end < total_pages - 1:
        result += f"\n\n(Showing pages {start + 1}-{end + 1} of {total_pages}. Use pages='{end + 2}-{min(end + 1 + self._MAX_PDF_PAGES, total_pages)}' to continue.)"
    if len(result) > self._MAX_CHARS:
        result = result[:self._MAX_CHARS] + "\n\n(PDF text truncated at ~128K chars)"
    return result
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# write_file
|
# write_file
|
||||||
@ -192,6 +289,7 @@ class WriteFileTool(_FsTool):
|
|||||||
fp = self._resolve(path)
|
fp = self._resolve(path)
|
||||||
fp.parent.mkdir(parents=True, exist_ok=True)
|
fp.parent.mkdir(parents=True, exist_ok=True)
|
||||||
fp.write_text(content, encoding="utf-8")
|
fp.write_text(content, encoding="utf-8")
|
||||||
|
file_state.record_write(fp)
|
||||||
return f"Successfully wrote {len(content)} characters to {fp}"
|
return f"Successfully wrote {len(content)} characters to {fp}"
|
||||||
except PermissionError as e:
|
except PermissionError as e:
|
||||||
return f"Error: {e}"
|
return f"Error: {e}"
|
||||||
@ -203,30 +301,269 @@ class WriteFileTool(_FsTool):
|
|||||||
# edit_file
|
# edit_file
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Translation table used to compare text regardless of quote style.
_QUOTE_TABLE = str.maketrans(
    {
        # curly single → straight
        "\u2018": "'",
        "\u2019": "'",
        # curly double → straight
        "\u201c": '"',
        "\u201d": '"',
        # identity (kept for completeness)
        "'": "'",
        '"': '"',
    }
)


def _normalize_quotes(s: str) -> str:
    """Return *s* with curly single/double quotes replaced by straight ones.

    Every mapping is one character to one character, so the result has the
    same length as the input.
    """
    normalized = s.translate(_QUOTE_TABLE)
    return normalized
|
||||||
|
|
||||||
|
|
||||||
|
def _curly_double_quotes(text: str) -> str:
    """Replace straight double quotes with alternating curly open/close quotes.

    The first ``"`` becomes an opening quote (U+201C), the next a closing
    quote (U+201D), and so on through the string.
    """
    head, *rest = text.split('"')
    pieces = [head]
    for index, segment in enumerate(rest):
        # Even-numbered quotes (0-based) open, odd-numbered ones close.
        pieces.append("\u201c" if index % 2 == 0 else "\u201d")
        pieces.append(segment)
    return "".join(pieces)
|
||||||
|
|
||||||
|
|
||||||
|
def _curly_single_quotes(text: str) -> str:
    """Replace straight single quotes with curly ones.

    A quote with an alphanumeric character on both sides is treated as an
    apostrophe and always becomes U+2019. Every other quote alternates
    between opening (U+2018) and closing (U+2019), starting with opening.
    """
    out: list[str] = []
    next_is_opening = True
    last_index = len(text) - 1
    for pos, ch in enumerate(text):
        if ch == "'":
            before = text[pos - 1] if pos else ""
            after = text[pos + 1] if pos < last_index else ""
            if before.isalnum() and after.isalnum():
                # Apostrophe inside a word: does not affect open/close pairing.
                ch = "\u2019"
            else:
                ch = "\u2018" if next_is_opening else "\u2019"
                next_is_opening = not next_is_opening
        out.append(ch)
    return "".join(out)
|
||||||
|
|
||||||
|
|
||||||
|
def _preserve_quote_style(old_text: str, actual_text: str, new_text: str) -> str:
    """Carry the matched text's curly quote style over to the replacement.

    *new_text* is returned untouched unless *old_text* and *actual_text*
    differ but are equal once stripped and quote-normalized. In that case
    straight quotes in *new_text* are curled for each quote family (double,
    single) that appears in curly form in *actual_text*.
    """
    if old_text == actual_text:
        return new_text
    if _normalize_quotes(old_text.strip()) != _normalize_quotes(actual_text.strip()):
        return new_text

    result = new_text
    has_curly_double = "\u201c" in actual_text or "\u201d" in actual_text
    if has_curly_double and '"' in result:
        result = _curly_double_quotes(result)
    has_curly_single = "\u2018" in actual_text or "\u2019" in actual_text
    if has_curly_single and "'" in result:
        result = _curly_single_quotes(result)
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def _leading_ws(line: str) -> str:
    """Return the run of spaces and tabs at the start of *line*."""
    width = 0
    for ch in line:
        if ch not in " \t":
            break
        width += 1
    return line[:width]
|
||||||
|
|
||||||
|
|
||||||
|
def _reindent_like_match(old_text: str, actual_text: str, new_text: str) -> str:
    """Shift *new_text* to the indentation of the block that actually matched.

    Applies only when *old_text* and *actual_text* have the same number of
    lines, their non-blank lines agree after stripping and quote
    normalization, and the first comparable actual line is indented deeper
    than (as an extension of) the corresponding old line. The extra
    indentation is then prefixed to every non-empty line of *new_text*;
    otherwise *new_text* is returned unchanged.
    """
    wanted = old_text.split("\n")
    found = actual_text.split("\n")
    if len(wanted) != len(found):
        return new_text

    # Only line pairs where both sides carry content are compared.
    pairs = [(w, f) for w, f in zip(wanted, found) if w.strip() and f.strip()]
    if not pairs:
        return new_text
    for w, f in pairs:
        if _normalize_quotes(w.strip()) != _normalize_quotes(f.strip()):
            return new_text

    old_indent = _leading_ws(pairs[0][0])
    actual_indent = _leading_ws(pairs[0][1])
    if not actual_indent.startswith(old_indent):
        return new_text
    extra = actual_indent[len(old_indent):]
    if not extra:
        return new_text

    return "\n".join(extra + line if line else line for line in new_text.split("\n"))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class _MatchSpan:
|
||||||
|
start: int
|
||||||
|
end: int
|
||||||
|
text: str
|
||||||
|
line: int
|
||||||
|
|
||||||
|
|
||||||
|
def _find_exact_matches(content: str, old_text: str) -> list[_MatchSpan]:
    """Return every non-overlapping exact occurrence of *old_text* in *content*."""
    width = len(old_text)
    # Always advance by at least one character so an empty needle terminates.
    stride = max(1, width)
    spans: list[_MatchSpan] = []
    cursor = content.find(old_text, 0)
    while cursor != -1:
        stop = cursor + width
        spans.append(
            _MatchSpan(
                start=cursor,
                end=stop,
                text=content[cursor:stop],
                line=content.count("\n", 0, cursor) + 1,
            )
        )
        cursor = content.find(old_text, cursor + stride)
    return spans
|
||||||
|
|
||||||
|
|
||||||
|
def _find_trim_matches(content: str, old_text: str, *, normalize_quotes: bool = False) -> list[_MatchSpan]:
    """Find line windows of *content* equal to *old_text* ignoring edge whitespace.

    Each line of both texts is stripped (and, when *normalize_quotes* is
    true, quote-normalized) before comparison, so indentation and trailing
    whitespace differences do not prevent a match.

    Returns:
        One span per matching window. A span covers the original (unstripped)
        lines of the window, excluding the final line's trailing ``"\\n"``.
    """
    old_lines = old_text.splitlines()
    if not old_lines:
        return []

    content_lines = content.splitlines()
    window_size = len(old_lines)
    if len(content_lines) < window_size:
        return []
    content_lines_keepends = content.splitlines(keepends=True)

    # offsets[i] is the character index where line i starts; the final entry
    # is the total length, so offsets[i + k] is the end of a k-line window.
    offsets = [0]
    for line in content_lines_keepends:
        offsets.append(offsets[-1] + len(line))

    def comparable(line: str) -> str:
        stripped = line.strip()
        return _normalize_quotes(stripped) if normalize_quotes else stripped

    wanted = [comparable(line) for line in old_lines]
    # Normalize every content line once instead of once per window.
    normalized = [comparable(line) for line in content_lines]

    matches: list[_MatchSpan] = []
    for i in range(len(content_lines) - window_size + 1):
        if normalized[i : i + window_size] != wanted:
            continue

        start = offsets[i]
        end = offsets[i + window_size]
        if content_lines_keepends[i + window_size - 1].endswith("\n"):
            end -= 1
        matches.append(
            _MatchSpan(
                start=start,
                end=end,
                text=content[start:end],
                line=i + 1,
            )
        )
    return matches
|
||||||
|
|
||||||
|
|
||||||
|
def _find_quote_matches(content: str, old_text: str) -> list[_MatchSpan]:
    """Find occurrences of *old_text* in *content* ignoring quote style.

    The search runs on quote-normalized copies of both strings; the indices
    found there are used directly to slice the original *content*, so the
    returned spans carry the file's real (possibly curly) quotes.
    """
    haystack = _normalize_quotes(content)
    needle = _normalize_quotes(old_text)
    width = len(old_text)
    stride = max(1, len(needle))

    found: list[_MatchSpan] = []
    pos = haystack.find(needle, 0)
    while pos != -1:
        found.append(
            _MatchSpan(
                start=pos,
                end=pos + width,
                text=content[pos : pos + width],
                line=content.count("\n", 0, pos) + 1,
            )
        )
        pos = haystack.find(needle, pos + stride)
    return found
|
||||||
|
|
||||||
|
|
||||||
|
def _find_matches(content: str, old_text: str) -> list[_MatchSpan]:
    """Return the matches of the first strategy that finds any.

    Strategies are tried from strictest to loosest: exact substring,
    whitespace-trimmed line window, trimmed window with quote normalization,
    and quote-normalized substring. Later strategies run only when all
    earlier ones found nothing; an empty list means no strategy matched.
    """
    return (
        _find_exact_matches(content, old_text)
        or _find_trim_matches(content, old_text)
        or _find_trim_matches(content, old_text, normalize_quotes=True)
        or _find_quote_matches(content, old_text)
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _find_match_line_numbers(content: str, old_text: str) -> list[int]:
    """Return the 1-based starting line of every match found by ``_find_matches``."""
    line_numbers: list[int] = []
    for span in _find_matches(content, old_text):
        line_numbers.append(span.line)
    return line_numbers
|
||||||
|
|
||||||
|
|
||||||
|
def _collapse_internal_whitespace(text: str) -> str:
    """Collapse each line's whitespace runs to single spaces and trim its ends.

    Lines are re-joined with ``"\\n"``; a trailing line break in *text* is
    not preserved.
    """
    collapsed: list[str] = []
    for line in text.splitlines():
        words = line.split()
        collapsed.append(" ".join(words))
    return "\n".join(collapsed)
|
||||||
|
|
||||||
|
|
||||||
|
def _diagnose_near_match(old_text: str, actual_text: str) -> list[str]:
    """List the reasons why two different texts are nevertheless close.

    Each hint names one kind of difference (letter case, whitespace, trailing
    newline, quote style) which, if ignored, makes the texts equal. Identical
    texts yield no hints.
    """
    if old_text == actual_text:
        return []

    # (hint, predicate) pairs, reported in this fixed order.
    checks = (
        ("letter case differs", lambda: old_text.lower() == actual_text.lower()),
        (
            "whitespace differs",
            lambda: _collapse_internal_whitespace(old_text) == _collapse_internal_whitespace(actual_text),
        ),
        ("trailing newline differs", lambda: old_text.rstrip("\n") == actual_text.rstrip("\n")),
        ("quote style differs", lambda: _normalize_quotes(old_text) == _normalize_quotes(actual_text)),
    )
    return [hint for hint, holds in checks if holds()]
|
||||||
|
|
||||||
|
|
||||||
|
def _best_window(old_text: str, content: str) -> tuple[float, int, list[str], list[str]]:
    """Locate the window of *content* lines most similar to *old_text*.

    Returns:
        ``(ratio, start, window_lines, hints)``: the best similarity ratio,
        the 0-based index of the window's first line, the window's lines
        (line endings kept), and near-match hints for that window. On ties
        the earliest window wins.
    """
    haystack = content.splitlines(keepends=True)
    needle = old_text.splitlines(keepends=True)
    span = max(1, len(needle))

    top_ratio = -1.0
    top_start = 0
    top_lines: list[str] = []
    for begin in range(max(1, len(haystack) - span + 1)):
        candidate = haystack[begin : begin + span]
        score = difflib.SequenceMatcher(None, needle, candidate).ratio()
        if score <= top_ratio:
            continue
        top_ratio, top_start, top_lines = score, begin, candidate

    def unify(text: str) -> str:
        # Compare with LF line endings and without trailing newlines.
        return text.replace("\r\n", "\n").rstrip("\n")

    hints = _diagnose_near_match(unify(old_text), unify("".join(top_lines)))
    return top_ratio, top_start, top_lines, hints
|
||||||
|
|
||||||
|
|
||||||
def _find_match(content: str, old_text: str) -> tuple[str | None, int]:
|
def _find_match(content: str, old_text: str) -> tuple[str | None, int]:
|
||||||
"""Locate old_text in content: exact first, then line-trimmed sliding window.
|
"""Locate old_text in content with a multi-level fallback chain:
|
||||||
|
|
||||||
|
1. Exact substring match
|
||||||
|
2. Line-trimmed sliding window (handles indentation differences)
|
||||||
|
3. Smart quote normalization (curly ↔ straight quotes)
|
||||||
|
|
||||||
Both inputs should use LF line endings (caller normalises CRLF).
|
Both inputs should use LF line endings (caller normalises CRLF).
|
||||||
Returns (matched_fragment, count) or (None, 0).
|
Returns (matched_fragment, count) or (None, 0).
|
||||||
"""
|
"""
|
||||||
if old_text in content:
|
matches = _find_matches(content, old_text)
|
||||||
return old_text, content.count(old_text)
|
if not matches:
|
||||||
|
|
||||||
old_lines = old_text.splitlines()
|
|
||||||
if not old_lines:
|
|
||||||
return None, 0
|
return None, 0
|
||||||
stripped_old = [l.strip() for l in old_lines]
|
return matches[0].text, len(matches)
|
||||||
content_lines = content.splitlines()
|
|
||||||
|
|
||||||
candidates = []
|
|
||||||
for i in range(len(content_lines) - len(stripped_old) + 1):
|
|
||||||
window = content_lines[i : i + len(stripped_old)]
|
|
||||||
if [l.strip() for l in window] == stripped_old:
|
|
||||||
candidates.append("\n".join(window))
|
|
||||||
|
|
||||||
if candidates:
|
|
||||||
return candidates[0], len(candidates)
|
|
||||||
return None, 0
|
|
||||||
|
|
||||||
|
|
||||||
@tool_parameters(
|
@tool_parameters(
|
||||||
@ -241,6 +578,9 @@ def _find_match(content: str, old_text: str) -> tuple[str | None, int]:
|
|||||||
class EditFileTool(_FsTool):
|
class EditFileTool(_FsTool):
|
||||||
"""Edit a file by replacing text with fallback matching."""
|
"""Edit a file by replacing text with fallback matching."""
|
||||||
|
|
||||||
|
_MAX_EDIT_FILE_SIZE = 1024 * 1024 * 1024 # 1 GiB
|
||||||
|
_MARKDOWN_EXTS = frozenset({".md", ".mdx", ".markdown"})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return "edit_file"
|
return "edit_file"
|
||||||
@ -249,11 +589,16 @@ class EditFileTool(_FsTool):
|
|||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return (
|
return (
|
||||||
"Edit a file by replacing old_text with new_text. "
|
"Edit a file by replacing old_text with new_text. "
|
||||||
"Tolerates minor whitespace/indentation differences. "
|
"Tolerates minor whitespace/indentation differences and curly/straight quote mismatches. "
|
||||||
"If old_text matches multiple times, you must provide more context "
|
"If old_text matches multiple times, you must provide more context "
|
||||||
"or set replace_all=true. Shows a diff of the closest match on failure."
|
"or set replace_all=true. Shows a diff of the closest match on failure."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _strip_trailing_ws(text: str) -> str:
|
||||||
|
"""Strip trailing whitespace from each line."""
|
||||||
|
return "\n".join(line.rstrip() for line in text.split("\n"))
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self, path: str | None = None, old_text: str | None = None,
|
self, path: str | None = None, old_text: str | None = None,
|
||||||
new_text: str | None = None,
|
new_text: str | None = None,
|
||||||
@ -267,55 +612,133 @@ class EditFileTool(_FsTool):
|
|||||||
if new_text is None:
|
if new_text is None:
|
||||||
raise ValueError("Unknown new_text")
|
raise ValueError("Unknown new_text")
|
||||||
|
|
||||||
|
# .ipynb detection
|
||||||
|
if path.endswith(".ipynb"):
|
||||||
|
return "Error: This is a Jupyter notebook. Use the notebook_edit tool instead of edit_file."
|
||||||
|
|
||||||
fp = self._resolve(path)
|
fp = self._resolve(path)
|
||||||
|
|
||||||
|
# Create-file semantics: old_text='' + file doesn't exist → create
|
||||||
if not fp.exists():
|
if not fp.exists():
|
||||||
return f"Error: File not found: {path}"
|
if old_text == "":
|
||||||
|
fp.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fp.write_text(new_text, encoding="utf-8")
|
||||||
|
file_state.record_write(fp)
|
||||||
|
return f"Successfully created {fp}"
|
||||||
|
return self._file_not_found_msg(path, fp)
|
||||||
|
|
||||||
|
# File size protection
|
||||||
|
try:
|
||||||
|
fsize = fp.stat().st_size
|
||||||
|
except OSError:
|
||||||
|
fsize = 0
|
||||||
|
if fsize > self._MAX_EDIT_FILE_SIZE:
|
||||||
|
return f"Error: File too large to edit ({fsize / (1024**3):.1f} GiB). Maximum is 1 GiB."
|
||||||
|
|
||||||
|
# Create-file: old_text='' but file exists and not empty → reject
|
||||||
|
if old_text == "":
|
||||||
|
raw = fp.read_bytes()
|
||||||
|
content = raw.decode("utf-8")
|
||||||
|
if content.strip():
|
||||||
|
return f"Error: Cannot create file — {path} already exists and is not empty."
|
||||||
|
fp.write_text(new_text, encoding="utf-8")
|
||||||
|
file_state.record_write(fp)
|
||||||
|
return f"Successfully edited {fp}"
|
||||||
|
|
||||||
|
# Read-before-edit check
|
||||||
|
warning = file_state.check_read(fp)
|
||||||
|
|
||||||
raw = fp.read_bytes()
|
raw = fp.read_bytes()
|
||||||
uses_crlf = b"\r\n" in raw
|
uses_crlf = b"\r\n" in raw
|
||||||
content = raw.decode("utf-8").replace("\r\n", "\n")
|
content = raw.decode("utf-8").replace("\r\n", "\n")
|
||||||
match, count = _find_match(content, old_text.replace("\r\n", "\n"))
|
norm_old = old_text.replace("\r\n", "\n")
|
||||||
|
matches = _find_matches(content, norm_old)
|
||||||
|
|
||||||
if match is None:
|
if not matches:
|
||||||
return self._not_found_msg(old_text, content, path)
|
return self._not_found_msg(old_text, content, path)
|
||||||
|
count = len(matches)
|
||||||
if count > 1 and not replace_all:
|
if count > 1 and not replace_all:
|
||||||
|
line_numbers = [match.line for match in matches]
|
||||||
|
preview = ", ".join(f"line {n}" for n in line_numbers[:3])
|
||||||
|
if len(line_numbers) > 3:
|
||||||
|
preview += ", ..."
|
||||||
|
location_hint = f" at {preview}" if preview else ""
|
||||||
return (
|
return (
|
||||||
f"Warning: old_text appears {count} times. "
|
f"Warning: old_text appears {count} times{location_hint}. "
|
||||||
"Provide more context to make it unique, or set replace_all=true."
|
"Provide more context to make it unique, or set replace_all=true."
|
||||||
)
|
)
|
||||||
|
|
||||||
norm_new = new_text.replace("\r\n", "\n")
|
norm_new = new_text.replace("\r\n", "\n")
|
||||||
new_content = content.replace(match, norm_new) if replace_all else content.replace(match, norm_new, 1)
|
|
||||||
|
# Trailing whitespace stripping (skip markdown to preserve double-space line breaks)
|
||||||
|
if fp.suffix.lower() not in self._MARKDOWN_EXTS:
|
||||||
|
norm_new = self._strip_trailing_ws(norm_new)
|
||||||
|
|
||||||
|
selected = matches if replace_all else matches[:1]
|
||||||
|
new_content = content
|
||||||
|
for match in reversed(selected):
|
||||||
|
replacement = _preserve_quote_style(norm_old, match.text, norm_new)
|
||||||
|
replacement = _reindent_like_match(norm_old, match.text, replacement)
|
||||||
|
|
||||||
|
# Delete-line cleanup: when deleting text (new_text=''), consume trailing
|
||||||
|
# newline to avoid leaving a blank line
|
||||||
|
end = match.end
|
||||||
|
if replacement == "" and not match.text.endswith("\n") and content[end:end + 1] == "\n":
|
||||||
|
end += 1
|
||||||
|
|
||||||
|
new_content = new_content[: match.start] + replacement + new_content[end:]
|
||||||
if uses_crlf:
|
if uses_crlf:
|
||||||
new_content = new_content.replace("\n", "\r\n")
|
new_content = new_content.replace("\n", "\r\n")
|
||||||
|
|
||||||
fp.write_bytes(new_content.encode("utf-8"))
|
fp.write_bytes(new_content.encode("utf-8"))
|
||||||
return f"Successfully edited {fp}"
|
file_state.record_write(fp)
|
||||||
|
msg = f"Successfully edited {fp}"
|
||||||
|
if warning:
|
||||||
|
msg = f"{warning}\n{msg}"
|
||||||
|
return msg
|
||||||
except PermissionError as e:
|
except PermissionError as e:
|
||||||
return f"Error: {e}"
|
return f"Error: {e}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return f"Error editing file: {e}"
|
return f"Error editing file: {e}"
|
||||||
|
|
||||||
|
def _file_not_found_msg(self, path: str, fp: Path) -> str:
|
||||||
|
"""Build an error message with 'Did you mean ...?' suggestions."""
|
||||||
|
parent = fp.parent
|
||||||
|
suggestions: list[str] = []
|
||||||
|
if parent.is_dir():
|
||||||
|
siblings = [f.name for f in parent.iterdir() if f.is_file()]
|
||||||
|
close = difflib.get_close_matches(fp.name, siblings, n=3, cutoff=0.6)
|
||||||
|
suggestions = [str(parent / c) for c in close]
|
||||||
|
parts = [f"Error: File not found: {path}"]
|
||||||
|
if suggestions:
|
||||||
|
parts.append("Did you mean: " + ", ".join(suggestions) + "?")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _not_found_msg(old_text: str, content: str, path: str) -> str:
|
def _not_found_msg(old_text: str, content: str, path: str) -> str:
|
||||||
lines = content.splitlines(keepends=True)
|
best_ratio, best_start, best_window_lines, hints = _best_window(old_text, content)
|
||||||
old_lines = old_text.splitlines(keepends=True)
|
|
||||||
window = len(old_lines)
|
|
||||||
|
|
||||||
best_ratio, best_start = 0.0, 0
|
|
||||||
for i in range(max(1, len(lines) - window + 1)):
|
|
||||||
ratio = difflib.SequenceMatcher(None, old_lines, lines[i : i + window]).ratio()
|
|
||||||
if ratio > best_ratio:
|
|
||||||
best_ratio, best_start = ratio, i
|
|
||||||
|
|
||||||
if best_ratio > 0.5:
|
if best_ratio > 0.5:
|
||||||
diff = "\n".join(difflib.unified_diff(
|
diff = "\n".join(difflib.unified_diff(
|
||||||
old_lines, lines[best_start : best_start + window],
|
old_text.splitlines(keepends=True),
|
||||||
|
best_window_lines,
|
||||||
fromfile="old_text (provided)",
|
fromfile="old_text (provided)",
|
||||||
tofile=f"{path} (actual, line {best_start + 1})",
|
tofile=f"{path} (actual, line {best_start + 1})",
|
||||||
lineterm="",
|
lineterm="",
|
||||||
))
|
))
|
||||||
return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
|
hint_text = ""
|
||||||
|
if hints:
|
||||||
|
hint_text = "\nPossible cause: " + ", ".join(hints) + "."
|
||||||
|
return (
|
||||||
|
f"Error: old_text not found in {path}."
|
||||||
|
f"{hint_text}\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if hints:
|
||||||
|
return (
|
||||||
|
f"Error: old_text not found in {path}. "
|
||||||
|
f"Possible cause: {', '.join(hints)}. "
|
||||||
|
"Copy the exact text from read_file and try again."
|
||||||
|
)
|
||||||
return f"Error: old_text not found in {path}. No similar text found. Verify the file content."
|
return f"Error: old_text not found in {path}. No similar text found. Verify the file content."
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -57,9 +57,7 @@ def _normalize_schema_for_openai(schema: Any) -> dict[str, Any]:
|
|||||||
|
|
||||||
if "properties" in normalized and isinstance(normalized["properties"], dict):
|
if "properties" in normalized and isinstance(normalized["properties"], dict):
|
||||||
normalized["properties"] = {
|
normalized["properties"] = {
|
||||||
name: _normalize_schema_for_openai(prop)
|
name: _normalize_schema_for_openai(prop) if isinstance(prop, dict) else prop
|
||||||
if isinstance(prop, dict)
|
|
||||||
else prop
|
|
||||||
for name, prop in normalized["properties"].items()
|
for name, prop in normalized["properties"].items()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -138,9 +136,7 @@ class MCPToolWrapper(Tool):
|
|||||||
class MCPResourceWrapper(Tool):
|
class MCPResourceWrapper(Tool):
|
||||||
"""Wraps an MCP resource URI as a read-only nanobot Tool."""
|
"""Wraps an MCP resource URI as a read-only nanobot Tool."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, session, server_name: str, resource_def, resource_timeout: int = 30):
|
||||||
self, session, server_name: str, resource_def, resource_timeout: int = 30
|
|
||||||
):
|
|
||||||
self._session = session
|
self._session = session
|
||||||
self._uri = resource_def.uri
|
self._uri = resource_def.uri
|
||||||
self._name = f"mcp_{server_name}_resource_{resource_def.name}"
|
self._name = f"mcp_{server_name}_resource_{resource_def.name}"
|
||||||
@ -211,9 +207,7 @@ class MCPResourceWrapper(Tool):
|
|||||||
class MCPPromptWrapper(Tool):
|
class MCPPromptWrapper(Tool):
|
||||||
"""Wraps an MCP prompt as a read-only nanobot Tool."""
|
"""Wraps an MCP prompt as a read-only nanobot Tool."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, session, server_name: str, prompt_def, prompt_timeout: int = 30):
|
||||||
self, session, server_name: str, prompt_def, prompt_timeout: int = 30
|
|
||||||
):
|
|
||||||
self._session = session
|
self._session = session
|
||||||
self._prompt_name = prompt_def.name
|
self._prompt_name = prompt_def.name
|
||||||
self._name = f"mcp_{server_name}_prompt_{prompt_def.name}"
|
self._name = f"mcp_{server_name}_prompt_{prompt_def.name}"
|
||||||
@ -266,9 +260,7 @@ class MCPPromptWrapper(Tool):
|
|||||||
timeout=self._prompt_timeout,
|
timeout=self._prompt_timeout,
|
||||||
)
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.warning(
|
logger.warning("MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout)
|
||||||
"MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout
|
|
||||||
)
|
|
||||||
return f"(MCP prompt call timed out after {self._prompt_timeout}s)"
|
return f"(MCP prompt call timed out after {self._prompt_timeout}s)"
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
task = asyncio.current_task()
|
task = asyncio.current_task()
|
||||||
@ -279,13 +271,17 @@ class MCPPromptWrapper(Tool):
|
|||||||
except McpError as exc:
|
except McpError as exc:
|
||||||
logger.error(
|
logger.error(
|
||||||
"MCP prompt '{}' failed: code={} message={}",
|
"MCP prompt '{}' failed: code={} message={}",
|
||||||
self._name, exc.error.code, exc.error.message,
|
self._name,
|
||||||
|
exc.error.code,
|
||||||
|
exc.error.message,
|
||||||
)
|
)
|
||||||
return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])"
|
return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])"
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"MCP prompt '{}' failed: {}: {}",
|
"MCP prompt '{}' failed: {}: {}",
|
||||||
self._name, type(exc).__name__, exc,
|
self._name,
|
||||||
|
type(exc).__name__,
|
||||||
|
exc,
|
||||||
)
|
)
|
||||||
return f"(MCP prompt call failed: {type(exc).__name__})"
|
return f"(MCP prompt call failed: {type(exc).__name__})"
|
||||||
|
|
||||||
@ -307,35 +303,44 @@ class MCPPromptWrapper(Tool):
|
|||||||
|
|
||||||
|
|
||||||
async def connect_mcp_servers(
|
async def connect_mcp_servers(
|
||||||
mcp_servers: dict, registry: ToolRegistry, stack: AsyncExitStack
|
mcp_servers: dict, registry: ToolRegistry
|
||||||
) -> None:
|
) -> dict[str, AsyncExitStack]:
|
||||||
"""Connect to configured MCP servers and register their tools, resources, and prompts."""
|
"""Connect to configured MCP servers and register their tools, resources, prompts.
|
||||||
|
|
||||||
|
Returns a dict mapping server name -> its dedicated AsyncExitStack.
|
||||||
|
Each server gets its own stack and runs in its own task to prevent
|
||||||
|
cancel scope conflicts when multiple MCP servers are configured.
|
||||||
|
"""
|
||||||
from mcp import ClientSession, StdioServerParameters
|
from mcp import ClientSession, StdioServerParameters
|
||||||
from mcp.client.sse import sse_client
|
from mcp.client.sse import sse_client
|
||||||
from mcp.client.stdio import stdio_client
|
from mcp.client.stdio import stdio_client
|
||||||
from mcp.client.streamable_http import streamable_http_client
|
from mcp.client.streamable_http import streamable_http_client
|
||||||
|
|
||||||
for name, cfg in mcp_servers.items():
|
async def connect_single_server(name: str, cfg) -> tuple[str, AsyncExitStack | None]:
|
||||||
|
server_stack = AsyncExitStack()
|
||||||
|
await server_stack.__aenter__()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
transport_type = cfg.type
|
transport_type = cfg.type
|
||||||
if not transport_type:
|
if not transport_type:
|
||||||
if cfg.command:
|
if cfg.command:
|
||||||
transport_type = "stdio"
|
transport_type = "stdio"
|
||||||
elif cfg.url:
|
elif cfg.url:
|
||||||
# Convention: URLs ending with /sse use SSE transport; others use streamableHttp
|
|
||||||
transport_type = (
|
transport_type = (
|
||||||
"sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp"
|
"sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning("MCP server '{}': no command or url configured, skipping", name)
|
logger.warning("MCP server '{}': no command or url configured, skipping", name)
|
||||||
continue
|
await server_stack.aclose()
|
||||||
|
return name, None
|
||||||
|
|
||||||
if transport_type == "stdio":
|
if transport_type == "stdio":
|
||||||
params = StdioServerParameters(
|
params = StdioServerParameters(
|
||||||
command=cfg.command, args=cfg.args, env=cfg.env or None
|
command=cfg.command, args=cfg.args, env=cfg.env or None
|
||||||
)
|
)
|
||||||
read, write = await stack.enter_async_context(stdio_client(params))
|
read, write = await server_stack.enter_async_context(stdio_client(params))
|
||||||
elif transport_type == "sse":
|
elif transport_type == "sse":
|
||||||
|
|
||||||
def httpx_client_factory(
|
def httpx_client_factory(
|
||||||
headers: dict[str, str] | None = None,
|
headers: dict[str, str] | None = None,
|
||||||
timeout: httpx.Timeout | None = None,
|
timeout: httpx.Timeout | None = None,
|
||||||
@ -353,27 +358,26 @@ async def connect_mcp_servers(
|
|||||||
auth=auth,
|
auth=auth,
|
||||||
)
|
)
|
||||||
|
|
||||||
read, write = await stack.enter_async_context(
|
read, write = await server_stack.enter_async_context(
|
||||||
sse_client(cfg.url, httpx_client_factory=httpx_client_factory)
|
sse_client(cfg.url, httpx_client_factory=httpx_client_factory)
|
||||||
)
|
)
|
||||||
elif transport_type == "streamableHttp":
|
elif transport_type == "streamableHttp":
|
||||||
# Always provide an explicit httpx client so MCP HTTP transport does not
|
http_client = await server_stack.enter_async_context(
|
||||||
# inherit httpx's default 5s timeout and preempt the higher-level tool timeout.
|
|
||||||
http_client = await stack.enter_async_context(
|
|
||||||
httpx.AsyncClient(
|
httpx.AsyncClient(
|
||||||
headers=cfg.headers or None,
|
headers=cfg.headers or None,
|
||||||
follow_redirects=True,
|
follow_redirects=True,
|
||||||
timeout=None,
|
timeout=None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
read, write, _ = await stack.enter_async_context(
|
read, write, _ = await server_stack.enter_async_context(
|
||||||
streamable_http_client(cfg.url, http_client=http_client)
|
streamable_http_client(cfg.url, http_client=http_client)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type)
|
logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type)
|
||||||
continue
|
await server_stack.aclose()
|
||||||
|
return name, None
|
||||||
|
|
||||||
session = await stack.enter_async_context(ClientSession(read, write))
|
session = await server_stack.enter_async_context(ClientSession(read, write))
|
||||||
await session.initialize()
|
await session.initialize()
|
||||||
|
|
||||||
tools = await session.list_tools()
|
tools = await session.list_tools()
|
||||||
@ -418,7 +422,6 @@ async def connect_mcp_servers(
|
|||||||
", ".join(available_wrapped_names) or "(none)",
|
", ".join(available_wrapped_names) or "(none)",
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- Register resources ---
|
|
||||||
try:
|
try:
|
||||||
resources_result = await session.list_resources()
|
resources_result = await session.list_resources()
|
||||||
for resource in resources_result.resources:
|
for resource in resources_result.resources:
|
||||||
@ -433,7 +436,6 @@ async def connect_mcp_servers(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("MCP server '{}': resources not supported or failed: {}", name, e)
|
logger.debug("MCP server '{}': resources not supported or failed: {}", name, e)
|
||||||
|
|
||||||
# --- Register prompts ---
|
|
||||||
try:
|
try:
|
||||||
prompts_result = await session.list_prompts()
|
prompts_result = await session.list_prompts()
|
||||||
for prompt in prompts_result.prompts:
|
for prompt in prompts_result.prompts:
|
||||||
@ -442,14 +444,54 @@ async def connect_mcp_servers(
|
|||||||
)
|
)
|
||||||
registry.register(wrapper)
|
registry.register(wrapper)
|
||||||
registered_count += 1
|
registered_count += 1
|
||||||
logger.debug(
|
logger.debug("MCP: registered prompt '{}' from server '{}'", wrapper.name, name)
|
||||||
"MCP: registered prompt '{}' from server '{}'", wrapper.name, name
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("MCP server '{}': prompts not supported or failed: {}", name, e)
|
logger.debug("MCP server '{}': prompts not supported or failed: {}", name, e)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"MCP server '{}': connected, {} capabilities registered", name, registered_count
|
"MCP server '{}': connected, {} capabilities registered", name, registered_count
|
||||||
)
|
)
|
||||||
|
return name, server_stack
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("MCP server '{}': failed to connect: {}", name, e)
|
hint = ""
|
||||||
|
text = str(e).lower()
|
||||||
|
if any(
|
||||||
|
marker in text
|
||||||
|
for marker in (
|
||||||
|
"parse error",
|
||||||
|
"invalid json",
|
||||||
|
"unexpected token",
|
||||||
|
"jsonrpc",
|
||||||
|
"content-length",
|
||||||
|
)
|
||||||
|
):
|
||||||
|
hint = (
|
||||||
|
" Hint: this looks like stdio protocol pollution. Make sure the MCP server writes "
|
||||||
|
"only JSON-RPC to stdout and sends logs/debug output to stderr instead."
|
||||||
|
)
|
||||||
|
logger.error("MCP server '{}': failed to connect: {}{}", name, e, hint)
|
||||||
|
try:
|
||||||
|
await server_stack.aclose()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return name, None
|
||||||
|
|
||||||
|
server_stacks: dict[str, AsyncExitStack] = {}
|
||||||
|
|
||||||
|
tasks: list[asyncio.Task] = []
|
||||||
|
for name, cfg in mcp_servers.items():
|
||||||
|
task = asyncio.create_task(connect_single_server(name, cfg))
|
||||||
|
tasks.append(task)
|
||||||
|
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
name = list(mcp_servers.keys())[i]
|
||||||
|
if isinstance(result, BaseException):
|
||||||
|
if not isinstance(result, asyncio.CancelledError):
|
||||||
|
logger.error("MCP server '{}' connection task failed: {}", name, result)
|
||||||
|
elif result is not None and result[1] is not None:
|
||||||
|
server_stacks[result[0]] = result[1]
|
||||||
|
|
||||||
|
return server_stacks
|
||||||
|
|||||||
161
nanobot/agent/tools/notebook.py
Normal file
161
nanobot/agent/tools/notebook.py
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
"""NotebookEditTool — edit Jupyter .ipynb notebooks."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from nanobot.agent.tools.base import tool_parameters
|
||||||
|
from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema
|
||||||
|
from nanobot.agent.tools.filesystem import _FsTool
|
||||||
|
|
||||||
|
|
||||||
|
def _new_cell(source: str, cell_type: str = "code", generate_id: bool = False) -> dict:
|
||||||
|
cell: dict[str, Any] = {
|
||||||
|
"cell_type": cell_type,
|
||||||
|
"source": source,
|
||||||
|
"metadata": {},
|
||||||
|
}
|
||||||
|
if cell_type == "code":
|
||||||
|
cell["outputs"] = []
|
||||||
|
cell["execution_count"] = None
|
||||||
|
if generate_id:
|
||||||
|
cell["id"] = uuid.uuid4().hex[:8]
|
||||||
|
return cell
|
||||||
|
|
||||||
|
|
||||||
|
def _make_empty_notebook() -> dict:
|
||||||
|
return {
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5,
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
|
||||||
|
"language_info": {"name": "python"},
|
||||||
|
},
|
||||||
|
"cells": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@tool_parameters(
    tool_parameters_schema(
        path=StringSchema("Path to the .ipynb notebook file"),
        cell_index=IntegerSchema(0, description="0-based index of the cell to edit", minimum=0),
        new_source=StringSchema("New source content for the cell"),
        cell_type=StringSchema(
            "Cell type: 'code' or 'markdown' (default: code)",
            enum=["code", "markdown"],
        ),
        edit_mode=StringSchema(
            "Mode: 'replace' (default), 'insert' (after target), or 'delete'",
            enum=["replace", "insert", "delete"],
        ),
        required=["path", "cell_index"],
    )
)
class NotebookEditTool(_FsTool):
    """Edit Jupyter notebook cells: replace, insert, or delete.

    The notebook is read as JSON, modified in memory and written back with
    ``indent=1`` and ``ensure_ascii=False``. Every failure is reported as an
    ``"Error: ..."`` string instead of being raised.
    """

    _VALID_CELL_TYPES = frozenset({"code", "markdown"})
    _VALID_EDIT_MODES = frozenset({"replace", "insert", "delete"})

    @property
    def name(self) -> str:
        return "notebook_edit"

    @property
    def description(self) -> str:
        return (
            "Edit a Jupyter notebook (.ipynb) cell. "
            "Modes: replace (default) replaces cell content, "
            "insert adds a new cell after the target index, "
            "delete removes the cell at the index. "
            "cell_index is 0-based."
        )

    @staticmethod
    def _write_notebook(fp: Any, nb: dict) -> None:
        """Serialize *nb* to *fp* as UTF-8 JSON, keeping non-ASCII text readable."""
        fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")

    async def execute(
        self,
        path: str | None = None,
        cell_index: int = 0,
        new_source: str = "",
        cell_type: str = "code",
        edit_mode: str = "replace",
        **kwargs: Any,
    ) -> str:
        """Apply one cell edit to the notebook at *path*.

        Args:
            path: Notebook path; must end with ``.ipynb``.
            cell_index: 0-based target cell. In insert mode the new cell is
                placed after this index, clamped to the notebook bounds.
            new_source: Source text for the replaced or inserted cell.
            cell_type: ``"code"`` or ``"markdown"``.
            edit_mode: ``"replace"``, ``"insert"`` or ``"delete"``.
            **kwargs: Ignored.

        Returns:
            A ``"Successfully ..."`` message, or an ``"Error: ..."`` string
            describing why the edit was not applied.
        """
        try:
            if not path:
                return "Error: path is required"

            if not path.endswith(".ipynb"):
                return "Error: notebook_edit only works on .ipynb files. Use edit_file for other files."

            if edit_mode not in self._VALID_EDIT_MODES:
                return (
                    f"Error: Invalid edit_mode '{edit_mode}'. "
                    "Use one of: replace, insert, delete."
                )

            if cell_type not in self._VALID_CELL_TYPES:
                return (
                    f"Error: Invalid cell_type '{cell_type}'. "
                    "Use one of: code, markdown."
                )

            # _resolve comes from _FsTool; presumably it maps the path into the
            # tool's workspace (verify against the base class).
            fp = self._resolve(path)

            # Create new notebook if file doesn't exist and mode is insert
            if not fp.exists():
                if edit_mode != "insert":
                    return f"Error: File not found: {path}"
                nb = _make_empty_notebook()
                nb["cells"].append(_new_cell(new_source, cell_type, generate_id=True))
                fp.parent.mkdir(parents=True, exist_ok=True)
                self._write_notebook(fp, nb)
                return f"Successfully created {fp} with 1 cell"

            try:
                nb = json.loads(fp.read_text(encoding="utf-8"))
            except (json.JSONDecodeError, UnicodeDecodeError) as e:
                return f"Error: Failed to parse notebook: {e}"

            # Valid JSON is not necessarily a notebook: reject unexpected
            # shapes explicitly instead of failing later with an opaque
            # AttributeError/TypeError message.
            if not isinstance(nb, dict):
                return (
                    "Error: Failed to parse notebook: expected a JSON object, "
                    f"got {type(nb).__name__}"
                )
            cells = nb.get("cells", [])
            if not isinstance(cells, list):
                return "Error: Failed to parse notebook: 'cells' must be a list"

            # Cell ids are only generated for nbformat >= 4 with minor >= 5.
            generate_id = nb.get("nbformat", 0) >= 4 and nb.get("nbformat_minor", 0) >= 5

            if edit_mode == "delete":
                if cell_index < 0 or cell_index >= len(cells):
                    return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)"
                cells.pop(cell_index)
                nb["cells"] = cells
                self._write_notebook(fp, nb)
                return f"Successfully deleted cell {cell_index} from {fp}"

            if edit_mode == "insert":
                # Clamp on both sides: a large index appends, and a negative
                # index inserts at the front rather than being interpreted as
                # an offset from the end by list.insert.
                insert_at = max(0, min(cell_index + 1, len(cells)))
                cells.insert(insert_at, _new_cell(new_source, cell_type, generate_id=generate_id))
                nb["cells"] = cells
                self._write_notebook(fp, nb)
                return f"Successfully inserted cell at index {insert_at} in {fp}"

            # Default: replace
            if cell_index < 0 or cell_index >= len(cells):
                return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)"
            cell = cells[cell_index]
            cell["source"] = new_source
            # NOTE(review): cell_type defaults to "code", so replacing a
            # markdown cell without passing cell_type converts it to a code
            # cell. Confirm this is intended before changing the default.
            if cell_type and cell.get("cell_type") != cell_type:
                cell["cell_type"] = cell_type
                if cell_type == "code":
                    cell.setdefault("outputs", [])
                    cell.setdefault("execution_count", None)
                else:
                    # Markdown cells carry neither field; drop both even when
                    # only one of them is present.
                    cell.pop("outputs", None)
                    cell.pop("execution_count", None)
            nb["cells"] = cells
            self._write_notebook(fp, nb)
            return f"Successfully edited cell {cell_index} in {fp}"

        except PermissionError as e:
            return f"Error: {e}"
        except Exception as e:
            return f"Error editing notebook: {e}"
|
||||||
@ -68,6 +68,13 @@ class ToolRegistry:
|
|||||||
params: dict[str, Any],
|
params: dict[str, Any],
|
||||||
) -> tuple[Tool | None, dict[str, Any], str | None]:
|
) -> tuple[Tool | None, dict[str, Any], str | None]:
|
||||||
"""Resolve, cast, and validate one tool call."""
|
"""Resolve, cast, and validate one tool call."""
|
||||||
|
# Guard against invalid parameter types (e.g., list instead of dict)
|
||||||
|
if not isinstance(params, dict) and name in ('write_file', 'read_file'):
|
||||||
|
return None, params, (
|
||||||
|
f"Error: Tool '{name}' parameters must be a JSON object, got {type(params).__name__}. "
|
||||||
|
"Use named parameters: tool_name(param1=\"value1\", param2=\"value2\")"
|
||||||
|
)
|
||||||
|
|
||||||
tool = self._tools.get(name)
|
tool = self._tools.get(name)
|
||||||
if not tool:
|
if not tool:
|
||||||
return None, params, (
|
return None, params, (
|
||||||
|
|||||||
@ -46,6 +46,7 @@ class ExecTool(Tool):
|
|||||||
restrict_to_workspace: bool = False,
|
restrict_to_workspace: bool = False,
|
||||||
sandbox: str = "",
|
sandbox: str = "",
|
||||||
path_append: str = "",
|
path_append: str = "",
|
||||||
|
allowed_env_keys: list[str] | None = None,
|
||||||
):
|
):
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.working_dir = working_dir
|
self.working_dir = working_dir
|
||||||
@ -60,10 +61,19 @@ class ExecTool(Tool):
|
|||||||
r">\s*/dev/sd", # write to disk
|
r">\s*/dev/sd", # write to disk
|
||||||
r"\b(shutdown|reboot|poweroff)\b", # system power
|
r"\b(shutdown|reboot|poweroff)\b", # system power
|
||||||
r":\(\)\s*\{.*\};\s*:", # fork bomb
|
r":\(\)\s*\{.*\};\s*:", # fork bomb
|
||||||
|
# Block writes to nanobot internal state files (#2989).
|
||||||
|
# history.jsonl / .dream_cursor are managed by append_history();
|
||||||
|
# direct writes corrupt the cursor format and crash /dream.
|
||||||
|
r">>?\s*\S*(?:history\.jsonl|\.dream_cursor)", # > / >> redirect
|
||||||
|
r"\btee\b[^|;&<>]*(?:history\.jsonl|\.dream_cursor)", # tee / tee -a
|
||||||
|
r"\b(?:cp|mv)\b(?:\s+[^\s|;&<>]+)+\s+\S*(?:history\.jsonl|\.dream_cursor)", # cp/mv target
|
||||||
|
r"\bdd\b[^|;&<>]*\bof=\S*(?:history\.jsonl|\.dream_cursor)", # dd of=
|
||||||
|
r"\bsed\s+-i[^|;&<>]*(?:history\.jsonl|\.dream_cursor)", # sed -i
|
||||||
]
|
]
|
||||||
self.allow_patterns = allow_patterns or []
|
self.allow_patterns = allow_patterns or []
|
||||||
self.restrict_to_workspace = restrict_to_workspace
|
self.restrict_to_workspace = restrict_to_workspace
|
||||||
self.path_append = path_append
|
self.path_append = path_append
|
||||||
|
self.allowed_env_keys = allowed_env_keys or []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
@ -91,6 +101,21 @@ class ExecTool(Tool):
|
|||||||
timeout: int | None = None, **kwargs: Any,
|
timeout: int | None = None, **kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
cwd = working_dir or self.working_dir or os.getcwd()
|
cwd = working_dir or self.working_dir or os.getcwd()
|
||||||
|
|
||||||
|
# Prevent an LLM-supplied working_dir from escaping the configured
|
||||||
|
# workspace when restrict_to_workspace is enabled (#2826). Without
|
||||||
|
# this, a caller can pass working_dir="/etc" and then all absolute
|
||||||
|
# paths under /etc would pass the _guard_command check that anchors
|
||||||
|
# on cwd.
|
||||||
|
if self.restrict_to_workspace and self.working_dir:
|
||||||
|
try:
|
||||||
|
requested = Path(cwd).expanduser().resolve()
|
||||||
|
workspace_root = Path(self.working_dir).expanduser().resolve()
|
||||||
|
except Exception:
|
||||||
|
return "Error: working_dir could not be resolved"
|
||||||
|
if requested != workspace_root and workspace_root not in requested.parents:
|
||||||
|
return "Error: working_dir is outside the configured workspace"
|
||||||
|
|
||||||
guard_error = self._guard_command(command, cwd)
|
guard_error = self._guard_command(command, cwd)
|
||||||
if guard_error:
|
if guard_error:
|
||||||
return guard_error
|
return guard_error
|
||||||
@ -208,7 +233,7 @@ class ExecTool(Tool):
|
|||||||
"""
|
"""
|
||||||
if _IS_WINDOWS:
|
if _IS_WINDOWS:
|
||||||
sr = os.environ.get("SYSTEMROOT", r"C:\Windows")
|
sr = os.environ.get("SYSTEMROOT", r"C:\Windows")
|
||||||
return {
|
env = {
|
||||||
"SYSTEMROOT": sr,
|
"SYSTEMROOT": sr,
|
||||||
"COMSPEC": os.environ.get("COMSPEC", f"{sr}\\system32\\cmd.exe"),
|
"COMSPEC": os.environ.get("COMSPEC", f"{sr}\\system32\\cmd.exe"),
|
||||||
"USERPROFILE": os.environ.get("USERPROFILE", ""),
|
"USERPROFILE": os.environ.get("USERPROFILE", ""),
|
||||||
@ -218,13 +243,29 @@ class ExecTool(Tool):
|
|||||||
"TMP": os.environ.get("TMP", f"{sr}\\Temp"),
|
"TMP": os.environ.get("TMP", f"{sr}\\Temp"),
|
||||||
"PATHEXT": os.environ.get("PATHEXT", ".COM;.EXE;.BAT;.CMD"),
|
"PATHEXT": os.environ.get("PATHEXT", ".COM;.EXE;.BAT;.CMD"),
|
||||||
"PATH": os.environ.get("PATH", f"{sr}\\system32;{sr}"),
|
"PATH": os.environ.get("PATH", f"{sr}\\system32;{sr}"),
|
||||||
|
"APPDATA": os.environ.get("APPDATA", ""),
|
||||||
|
"LOCALAPPDATA": os.environ.get("LOCALAPPDATA", ""),
|
||||||
|
"ProgramData": os.environ.get("ProgramData", ""),
|
||||||
|
"ProgramFiles": os.environ.get("ProgramFiles", ""),
|
||||||
|
"ProgramFiles(x86)": os.environ.get("ProgramFiles(x86)", ""),
|
||||||
|
"ProgramW6432": os.environ.get("ProgramW6432", ""),
|
||||||
}
|
}
|
||||||
|
for key in self.allowed_env_keys:
|
||||||
|
val = os.environ.get(key)
|
||||||
|
if val is not None:
|
||||||
|
env[key] = val
|
||||||
|
return env
|
||||||
home = os.environ.get("HOME", "/tmp")
|
home = os.environ.get("HOME", "/tmp")
|
||||||
return {
|
env = {
|
||||||
"HOME": home,
|
"HOME": home,
|
||||||
"LANG": os.environ.get("LANG", "C.UTF-8"),
|
"LANG": os.environ.get("LANG", "C.UTF-8"),
|
||||||
"TERM": os.environ.get("TERM", "dumb"),
|
"TERM": os.environ.get("TERM", "dumb"),
|
||||||
}
|
}
|
||||||
|
for key in self.allowed_env_keys:
|
||||||
|
val = os.environ.get(key)
|
||||||
|
if val is not None:
|
||||||
|
env[key] = val
|
||||||
|
return env
|
||||||
|
|
||||||
def _guard_command(self, command: str, cwd: str) -> str | None:
|
def _guard_command(self, command: str, cwd: str) -> str | None:
|
||||||
"""Best-effort safety guard for potentially destructive commands."""
|
"""Best-effort safety guard for potentially destructive commands."""
|
||||||
|
|||||||
@ -96,10 +96,37 @@ class WebSearchTool(Tool):
|
|||||||
self.config = config if config is not None else WebSearchConfig()
|
self.config = config if config is not None else WebSearchConfig()
|
||||||
self.proxy = proxy
|
self.proxy = proxy
|
||||||
|
|
||||||
|
def _effective_provider(self) -> str:
|
||||||
|
"""Resolve the backend that execute() will actually use."""
|
||||||
|
provider = self.config.provider.strip().lower() or "brave"
|
||||||
|
if provider == "duckduckgo":
|
||||||
|
return "duckduckgo"
|
||||||
|
if provider == "brave":
|
||||||
|
api_key = self.config.api_key or os.environ.get("BRAVE_API_KEY", "")
|
||||||
|
return "brave" if api_key else "duckduckgo"
|
||||||
|
if provider == "tavily":
|
||||||
|
api_key = self.config.api_key or os.environ.get("TAVILY_API_KEY", "")
|
||||||
|
return "tavily" if api_key else "duckduckgo"
|
||||||
|
if provider == "searxng":
|
||||||
|
base_url = (self.config.base_url or os.environ.get("SEARXNG_BASE_URL", "")).strip()
|
||||||
|
return "searxng" if base_url else "duckduckgo"
|
||||||
|
if provider == "jina":
|
||||||
|
api_key = self.config.api_key or os.environ.get("JINA_API_KEY", "")
|
||||||
|
return "jina" if api_key else "duckduckgo"
|
||||||
|
if provider == "kagi":
|
||||||
|
api_key = self.config.api_key or os.environ.get("KAGI_API_KEY", "")
|
||||||
|
return "kagi" if api_key else "duckduckgo"
|
||||||
|
return provider
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def read_only(self) -> bool:
|
def read_only(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@property
def exclusive(self) -> bool:
    """Whether this tool must run exclusively.

    True only when the resolved backend is DuckDuckGo: those searches are
    serialized because ddgs is not concurrency-safe.
    """
    backend = self._effective_provider()
    return backend == "duckduckgo"
|
||||||
|
|
||||||
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
|
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
|
||||||
provider = self.config.provider.strip().lower() or "brave"
|
provider = self.config.provider.strip().lower() or "brave"
|
||||||
n = min(max(count or self.config.max_results, 1), 10)
|
n = min(max(count or self.config.max_results, 1), 10)
|
||||||
@ -114,6 +141,8 @@ class WebSearchTool(Tool):
|
|||||||
return await self._search_jina(query, n)
|
return await self._search_jina(query, n)
|
||||||
elif provider == "brave":
|
elif provider == "brave":
|
||||||
return await self._search_brave(query, n)
|
return await self._search_brave(query, n)
|
||||||
|
elif provider == "kagi":
|
||||||
|
return await self._search_kagi(query, n)
|
||||||
else:
|
else:
|
||||||
return f"Error: unknown search provider '{provider}'"
|
return f"Error: unknown search provider '{provider}'"
|
||||||
|
|
||||||
@ -204,6 +233,29 @@ class WebSearchTool(Tool):
|
|||||||
logger.warning("Jina search failed ({}), falling back to DuckDuckGo", e)
|
logger.warning("Jina search failed ({}), falling back to DuckDuckGo", e)
|
||||||
return await self._search_duckduckgo(query, n)
|
return await self._search_duckduckgo(query, n)
|
||||||
|
|
||||||
|
async def _search_kagi(self, query: str, n: int) -> str:
    """Search via the Kagi Search API and return formatted results.

    The API key comes from the config or the ``KAGI_API_KEY`` environment
    variable; without one, the search is delegated to DuckDuckGo. Any
    exception raised while querying or parsing is returned as an
    ``"Error: ..."`` string.
    """
    token = self.config.api_key or os.environ.get("KAGI_API_KEY", "")
    if not token:
        logger.warning("KAGI_API_KEY not set, falling back to DuckDuckGo")
        return await self._search_duckduckgo(query, n)

    try:
        async with httpx.AsyncClient(proxy=self.proxy) as client:
            response = await client.get(
                "https://kagi.com/api/v0/search",
                params={"q": query, "limit": n},
                headers={"Authorization": f"Bot {token}"},
                timeout=10.0,
            )
            response.raise_for_status()

        hits = []
        for entry in response.json().get("data", []):
            # t=0 items are search results; other values are related searches, etc.
            if entry.get("t") != 0:
                continue
            hits.append(
                {
                    "title": entry.get("title", ""),
                    "url": entry.get("url", ""),
                    "content": entry.get("snippet", ""),
                }
            )
        return _format_results(query, hits, n)
    except Exception as e:
        return f"Error: {e}"
|
||||||
|
|
||||||
async def _search_duckduckgo(self, query: str, n: int) -> str:
|
async def _search_duckduckgo(self, query: str, n: int) -> str:
|
||||||
try:
|
try:
|
||||||
# Note: duckduckgo_search is synchronous and does its own requests
|
# Note: duckduckgo_search is synchronous and does its own requests
|
||||||
|
|||||||
@ -84,6 +84,10 @@ def _save_base64_data_url(data_url: str, media_dir: Path) -> str | None:
|
|||||||
raw = base64.b64decode(b64_payload)
|
raw = base64.b64decode(b64_payload)
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
if len(raw) > MAX_FILE_SIZE:
|
||||||
|
raise _FileSizeExceeded(
|
||||||
|
f"File exceeds {MAX_FILE_SIZE // (1024 * 1024)}MB limit"
|
||||||
|
)
|
||||||
ext = mimetypes.guess_extension(mime_type) or ".bin"
|
ext = mimetypes.guess_extension(mime_type) or ".bin"
|
||||||
filename = f"{uuid.uuid4().hex[:12]}{ext}"
|
filename = f"{uuid.uuid4().hex[:12]}{ext}"
|
||||||
dest = media_dir / safe_filename(filename)
|
dest = media_dir / safe_filename(filename)
|
||||||
|
|||||||
@ -5,6 +5,8 @@ import json
|
|||||||
import mimetypes
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import zipfile
|
||||||
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from urllib.parse import unquote, urlparse
|
from urllib.parse import unquote, urlparse
|
||||||
@ -171,6 +173,7 @@ class DingTalkChannel(BaseChannel):
|
|||||||
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
|
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
|
||||||
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"}
|
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"}
|
||||||
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"}
|
_VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"}
|
||||||
|
_ZIP_BEFORE_UPLOAD_EXTS = {".htm", ".html"}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default_config(cls) -> dict[str, Any]:
|
def default_config(cls) -> dict[str, Any]:
|
||||||
@ -287,6 +290,31 @@ class DingTalkChannel(BaseChannel):
|
|||||||
name = os.path.basename(urlparse(media_ref).path)
|
name = os.path.basename(urlparse(media_ref).path)
|
||||||
return name or {"image": "image.jpg", "voice": "audio.amr", "video": "video.mp4"}.get(upload_type, "file.bin")
|
return name or {"image": "image.jpg", "voice": "audio.amr", "video": "video.mp4"}.get(upload_type, "file.bin")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _zip_bytes(filename: str, data: bytes) -> tuple[bytes, str, str]:
|
||||||
|
stem = Path(filename).stem or "attachment"
|
||||||
|
safe_name = filename or "attachment.bin"
|
||||||
|
zip_name = f"{stem}.zip"
|
||||||
|
buffer = BytesIO()
|
||||||
|
with zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
|
||||||
|
archive.writestr(safe_name, data)
|
||||||
|
return buffer.getvalue(), zip_name, "application/zip"
|
||||||
|
|
||||||
|
def _normalize_upload_payload(
|
||||||
|
self,
|
||||||
|
filename: str,
|
||||||
|
data: bytes,
|
||||||
|
content_type: str | None,
|
||||||
|
) -> tuple[bytes, str, str | None]:
|
||||||
|
ext = Path(filename).suffix.lower()
|
||||||
|
if ext in self._ZIP_BEFORE_UPLOAD_EXTS or content_type == "text/html":
|
||||||
|
logger.info(
|
||||||
|
"DingTalk does not accept raw HTML attachments, zipping {} before upload",
|
||||||
|
filename,
|
||||||
|
)
|
||||||
|
return self._zip_bytes(filename, data)
|
||||||
|
return data, filename, content_type
|
||||||
|
|
||||||
async def _read_media_bytes(
|
async def _read_media_bytes(
|
||||||
self,
|
self,
|
||||||
media_ref: str,
|
media_ref: str,
|
||||||
@ -309,6 +337,9 @@ class DingTalkChannel(BaseChannel):
|
|||||||
content_type = (resp.headers.get("content-type") or "").split(";")[0].strip()
|
content_type = (resp.headers.get("content-type") or "").split(";")[0].strip()
|
||||||
filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
|
filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
|
||||||
return resp.content, filename, content_type or None
|
return resp.content, filename, content_type or None
|
||||||
|
except httpx.TransportError as e:
|
||||||
|
logger.error("DingTalk media download network error ref={} err={}", media_ref, e)
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("DingTalk media download error ref={} err={}", media_ref, e)
|
logger.error("DingTalk media download error ref={} err={}", media_ref, e)
|
||||||
return None, None, None
|
return None, None, None
|
||||||
@ -360,6 +391,9 @@ class DingTalkChannel(BaseChannel):
|
|||||||
logger.error("DingTalk media upload missing media_id body={}", text[:500])
|
logger.error("DingTalk media upload missing media_id body={}", text[:500])
|
||||||
return None
|
return None
|
||||||
return str(media_id)
|
return str(media_id)
|
||||||
|
except httpx.TransportError as e:
|
||||||
|
logger.error("DingTalk media upload network error type={} err={}", media_type, e)
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("DingTalk media upload error type={} err={}", media_type, e)
|
logger.error("DingTalk media upload error type={} err={}", media_type, e)
|
||||||
return None
|
return None
|
||||||
@ -409,6 +443,9 @@ class DingTalkChannel(BaseChannel):
|
|||||||
return False
|
return False
|
||||||
logger.debug("DingTalk message sent to {} with msgKey={}", chat_id, msg_key)
|
logger.debug("DingTalk message sent to {} with msgKey={}", chat_id, msg_key)
|
||||||
return True
|
return True
|
||||||
|
except httpx.TransportError as e:
|
||||||
|
logger.error("DingTalk network error sending message msgKey={} err={}", msg_key, e)
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error sending DingTalk message msgKey={} err={}", msg_key, e)
|
logger.error("Error sending DingTalk message msgKey={} err={}", msg_key, e)
|
||||||
return False
|
return False
|
||||||
@ -444,6 +481,7 @@ class DingTalkChannel(BaseChannel):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
filename = filename or self._guess_filename(media_ref, upload_type)
|
filename = filename or self._guess_filename(media_ref, upload_type)
|
||||||
|
data, filename, content_type = self._normalize_upload_payload(filename, data, content_type)
|
||||||
file_type = Path(filename).suffix.lower().lstrip(".")
|
file_type = Path(filename).suffix.lower().lstrip(".")
|
||||||
if not file_type:
|
if not file_type:
|
||||||
guessed = mimetypes.guess_extension(content_type or "")
|
guessed = mimetypes.guess_extension(content_type or "")
|
||||||
|
|||||||
@ -4,6 +4,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import importlib.util
|
import importlib.util
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Literal
|
from typing import TYPE_CHECKING, Any, Literal
|
||||||
|
|
||||||
@ -20,6 +22,7 @@ from nanobot.utils.helpers import safe_filename, split_message
|
|||||||
|
|
||||||
DISCORD_AVAILABLE = importlib.util.find_spec("discord") is not None
|
DISCORD_AVAILABLE = importlib.util.find_spec("discord") is not None
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
import aiohttp
|
||||||
import discord
|
import discord
|
||||||
from discord import app_commands
|
from discord import app_commands
|
||||||
from discord.abc import Messageable
|
from discord.abc import Messageable
|
||||||
@ -34,6 +37,16 @@ MAX_MESSAGE_LEN = 2000 # Discord message character limit
|
|||||||
TYPING_INTERVAL_S = 8
|
TYPING_INTERVAL_S = 8
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _StreamBuf:
|
||||||
|
"""Per-chat streaming accumulator for progressive Discord message edits."""
|
||||||
|
|
||||||
|
text: str = ""
|
||||||
|
message: Any | None = None
|
||||||
|
last_edit: float = 0.0
|
||||||
|
stream_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class DiscordConfig(Base):
|
class DiscordConfig(Base):
|
||||||
"""Discord channel configuration."""
|
"""Discord channel configuration."""
|
||||||
|
|
||||||
@ -45,6 +58,10 @@ class DiscordConfig(Base):
|
|||||||
read_receipt_emoji: str = "👀"
|
read_receipt_emoji: str = "👀"
|
||||||
working_emoji: str = "🔧"
|
working_emoji: str = "🔧"
|
||||||
working_emoji_delay: float = 2.0
|
working_emoji_delay: float = 2.0
|
||||||
|
streaming: bool = True
|
||||||
|
proxy: str | None = None
|
||||||
|
proxy_username: str | None = None
|
||||||
|
proxy_password: str | None = None
|
||||||
|
|
||||||
|
|
||||||
if DISCORD_AVAILABLE:
|
if DISCORD_AVAILABLE:
|
||||||
@ -52,8 +69,15 @@ if DISCORD_AVAILABLE:
|
|||||||
class DiscordBotClient(discord.Client):
|
class DiscordBotClient(discord.Client):
|
||||||
"""discord.py client that forwards events to the channel."""
|
"""discord.py client that forwards events to the channel."""
|
||||||
|
|
||||||
def __init__(self, channel: DiscordChannel, *, intents: discord.Intents) -> None:
|
def __init__(
|
||||||
super().__init__(intents=intents)
|
self,
|
||||||
|
channel: DiscordChannel,
|
||||||
|
*,
|
||||||
|
intents: discord.Intents,
|
||||||
|
proxy: str | None = None,
|
||||||
|
proxy_auth: aiohttp.BasicAuth | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(intents=intents, proxy=proxy, proxy_auth=proxy_auth)
|
||||||
self._channel = channel
|
self._channel = channel
|
||||||
self.tree = app_commands.CommandTree(self)
|
self.tree = app_commands.CommandTree(self)
|
||||||
self._register_app_commands()
|
self._register_app_commands()
|
||||||
@ -117,6 +141,7 @@ if DISCORD_AVAILABLE:
|
|||||||
)
|
)
|
||||||
|
|
||||||
for name, description, command_text in commands:
|
for name, description, command_text in commands:
|
||||||
|
|
||||||
@self.tree.command(name=name, description=description)
|
@self.tree.command(name=name, description=description)
|
||||||
async def command_handler(
|
async def command_handler(
|
||||||
interaction: discord.Interaction,
|
interaction: discord.Interaction,
|
||||||
@ -173,7 +198,9 @@ if DISCORD_AVAILABLE:
|
|||||||
else:
|
else:
|
||||||
failed_media.append(Path(media_path).name)
|
failed_media.append(Path(media_path).name)
|
||||||
|
|
||||||
for index, chunk in enumerate(self._build_chunks(msg.content or "", failed_media, sent_media)):
|
for index, chunk in enumerate(
|
||||||
|
self._build_chunks(msg.content or "", failed_media, sent_media)
|
||||||
|
):
|
||||||
kwargs: dict[str, Any] = {"content": chunk}
|
kwargs: dict[str, Any] = {"content": chunk}
|
||||||
if index == 0 and reference is not None and not sent_media:
|
if index == 0 and reference is not None and not sent_media:
|
||||||
kwargs["reference"] = reference
|
kwargs["reference"] = reference
|
||||||
@ -242,6 +269,7 @@ class DiscordChannel(BaseChannel):
|
|||||||
|
|
||||||
name = "discord"
|
name = "discord"
|
||||||
display_name = "Discord"
|
display_name = "Discord"
|
||||||
|
_STREAM_EDIT_INTERVAL = 0.8
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default_config(cls) -> dict[str, Any]:
|
def default_config(cls) -> dict[str, Any]:
|
||||||
@ -263,6 +291,7 @@ class DiscordChannel(BaseChannel):
|
|||||||
self._bot_user_id: str | None = None
|
self._bot_user_id: str | None = None
|
||||||
self._pending_reactions: dict[str, Any] = {} # chat_id -> message object
|
self._pending_reactions: dict[str, Any] = {} # chat_id -> message object
|
||||||
self._working_emoji_tasks: dict[str, asyncio.Task[None]] = {}
|
self._working_emoji_tasks: dict[str, asyncio.Task[None]] = {}
|
||||||
|
self._stream_bufs: dict[str, _StreamBuf] = {}
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the Discord client."""
|
"""Start the Discord client."""
|
||||||
@ -277,7 +306,29 @@ class DiscordChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
intents = discord.Intents.none()
|
intents = discord.Intents.none()
|
||||||
intents.value = self.config.intents
|
intents.value = self.config.intents
|
||||||
self._client = DiscordBotClient(self, intents=intents)
|
|
||||||
|
proxy_auth = None
|
||||||
|
has_user = bool(self.config.proxy_username)
|
||||||
|
has_pass = bool(self.config.proxy_password)
|
||||||
|
if has_user and has_pass:
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
proxy_auth = aiohttp.BasicAuth(
|
||||||
|
login=self.config.proxy_username,
|
||||||
|
password=self.config.proxy_password,
|
||||||
|
)
|
||||||
|
elif has_user != has_pass:
|
||||||
|
logger.warning(
|
||||||
|
"Discord proxy auth incomplete: both proxy_username and "
|
||||||
|
"proxy_password must be set; ignoring partial credentials",
|
||||||
|
)
|
||||||
|
|
||||||
|
self._client = DiscordBotClient(
|
||||||
|
self,
|
||||||
|
intents=intents,
|
||||||
|
proxy=self.config.proxy,
|
||||||
|
proxy_auth=proxy_auth,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to initialize Discord client: {}", e)
|
logger.error("Failed to initialize Discord client: {}", e)
|
||||||
self._client = None
|
self._client = None
|
||||||
@ -315,11 +366,71 @@ class DiscordChannel(BaseChannel):
|
|||||||
await client.send_outbound(msg)
|
await client.send_outbound(msg)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error sending Discord message: {}", e)
|
logger.error("Error sending Discord message: {}", e)
|
||||||
|
raise
|
||||||
finally:
|
finally:
|
||||||
if not is_progress:
|
if not is_progress:
|
||||||
await self._stop_typing(msg.chat_id)
|
await self._stop_typing(msg.chat_id)
|
||||||
await self._clear_reactions(msg.chat_id)
|
await self._clear_reactions(msg.chat_id)
|
||||||
|
|
||||||
|
async def send_delta(
|
||||||
|
self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None
|
||||||
|
) -> None:
|
||||||
|
"""Progressive Discord delivery: send once, then edit until the stream ends."""
|
||||||
|
client = self._client
|
||||||
|
if client is None or not client.is_ready():
|
||||||
|
logger.warning("Discord client not ready; dropping stream delta")
|
||||||
|
return
|
||||||
|
|
||||||
|
meta = metadata or {}
|
||||||
|
stream_id = meta.get("_stream_id")
|
||||||
|
|
||||||
|
if meta.get("_stream_end"):
|
||||||
|
buf = self._stream_bufs.get(chat_id)
|
||||||
|
if not buf or buf.message is None or not buf.text:
|
||||||
|
return
|
||||||
|
if stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id:
|
||||||
|
return
|
||||||
|
await self._finalize_stream(chat_id, buf)
|
||||||
|
return
|
||||||
|
|
||||||
|
buf = self._stream_bufs.get(chat_id)
|
||||||
|
if buf is None or (
|
||||||
|
stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id
|
||||||
|
):
|
||||||
|
buf = _StreamBuf(stream_id=stream_id)
|
||||||
|
self._stream_bufs[chat_id] = buf
|
||||||
|
elif buf.stream_id is None:
|
||||||
|
buf.stream_id = stream_id
|
||||||
|
|
||||||
|
buf.text += delta
|
||||||
|
if not buf.text.strip():
|
||||||
|
return
|
||||||
|
|
||||||
|
target = await self._resolve_channel(chat_id)
|
||||||
|
if target is None:
|
||||||
|
logger.warning("Discord stream target {} unavailable", chat_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
now = time.monotonic()
|
||||||
|
if buf.message is None:
|
||||||
|
try:
|
||||||
|
buf.message = await target.send(content=buf.text)
|
||||||
|
buf.last_edit = now
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Discord stream initial send failed: {}", e)
|
||||||
|
raise
|
||||||
|
return
|
||||||
|
|
||||||
|
if (now - buf.last_edit) < self._STREAM_EDIT_INTERVAL:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await buf.message.edit(content=DiscordBotClient._build_chunks(buf.text, [], False)[0])
|
||||||
|
buf.last_edit = now
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Discord stream edit failed: {}", e)
|
||||||
|
raise
|
||||||
|
|
||||||
async def _handle_discord_message(self, message: discord.Message) -> None:
|
async def _handle_discord_message(self, message: discord.Message) -> None:
|
||||||
"""Handle incoming Discord messages from discord.py."""
|
"""Handle incoming Discord messages from discord.py."""
|
||||||
if message.author.bot:
|
if message.author.bot:
|
||||||
@ -373,6 +484,47 @@ class DiscordChannel(BaseChannel):
|
|||||||
"""Backward-compatible alias for legacy tests/callers."""
|
"""Backward-compatible alias for legacy tests/callers."""
|
||||||
await self._handle_discord_message(message)
|
await self._handle_discord_message(message)
|
||||||
|
|
||||||
|
async def _resolve_channel(self, chat_id: str) -> Any | None:
|
||||||
|
"""Resolve a Discord channel from cache first, then network fetch."""
|
||||||
|
client = self._client
|
||||||
|
if client is None or not client.is_ready():
|
||||||
|
return None
|
||||||
|
channel_id = int(chat_id)
|
||||||
|
channel = client.get_channel(channel_id)
|
||||||
|
if channel is not None:
|
||||||
|
return channel
|
||||||
|
try:
|
||||||
|
return await client.fetch_channel(channel_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Discord channel {} unavailable: {}", chat_id, e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _finalize_stream(self, chat_id: str, buf: _StreamBuf) -> None:
|
||||||
|
"""Commit the final streamed content and flush overflow chunks."""
|
||||||
|
chunks = DiscordBotClient._build_chunks(buf.text, [], False)
|
||||||
|
if not chunks:
|
||||||
|
self._stream_bufs.pop(chat_id, None)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await buf.message.edit(content=chunks[0])
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Discord final stream edit failed: {}", e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
target = getattr(buf.message, "channel", None) or await self._resolve_channel(chat_id)
|
||||||
|
if target is None:
|
||||||
|
logger.warning("Discord stream follow-up target {} unavailable", chat_id)
|
||||||
|
self._stream_bufs.pop(chat_id, None)
|
||||||
|
return
|
||||||
|
|
||||||
|
for extra_chunk in chunks[1:]:
|
||||||
|
await target.send(content=extra_chunk)
|
||||||
|
|
||||||
|
self._stream_bufs.pop(chat_id, None)
|
||||||
|
await self._stop_typing(chat_id)
|
||||||
|
await self._clear_reactions(chat_id)
|
||||||
|
|
||||||
def _should_accept_inbound(
|
def _should_accept_inbound(
|
||||||
self,
|
self,
|
||||||
message: discord.Message,
|
message: discord.Message,
|
||||||
@ -423,7 +575,11 @@ class DiscordChannel(BaseChannel):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_inbound_metadata(message: discord.Message) -> dict[str, str | None]:
|
def _build_inbound_metadata(message: discord.Message) -> dict[str, str | None]:
|
||||||
"""Build metadata for inbound Discord messages."""
|
"""Build metadata for inbound Discord messages."""
|
||||||
reply_to = str(message.reference.message_id) if message.reference and message.reference.message_id else None
|
reply_to = (
|
||||||
|
str(message.reference.message_id)
|
||||||
|
if message.reference and message.reference.message_id
|
||||||
|
else None
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"message_id": str(message.id),
|
"message_id": str(message.id),
|
||||||
"guild_id": str(message.guild.id) if message.guild else None,
|
"guild_id": str(message.guild.id) if message.guild else None,
|
||||||
@ -438,7 +594,9 @@ class DiscordChannel(BaseChannel):
|
|||||||
if self.config.group_policy == "mention":
|
if self.config.group_policy == "mention":
|
||||||
bot_user_id = self._bot_user_id
|
bot_user_id = self._bot_user_id
|
||||||
if bot_user_id is None:
|
if bot_user_id is None:
|
||||||
logger.debug("Discord message in {} ignored (bot identity unavailable)", message.channel.id)
|
logger.debug(
|
||||||
|
"Discord message in {} ignored (bot identity unavailable)", message.channel.id
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if any(str(user.id) == bot_user_id for user in message.mentions):
|
if any(str(user.id) == bot_user_id for user in message.mentions):
|
||||||
@ -480,7 +638,6 @@ class DiscordChannel(BaseChannel):
|
|||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
async def _clear_reactions(self, chat_id: str) -> None:
|
async def _clear_reactions(self, chat_id: str) -> None:
|
||||||
"""Remove all pending reactions after bot replies."""
|
"""Remove all pending reactions after bot replies."""
|
||||||
# Cancel delayed working emoji if it hasn't fired yet
|
# Cancel delayed working emoji if it hasn't fired yet
|
||||||
@ -507,6 +664,7 @@ class DiscordChannel(BaseChannel):
|
|||||||
async def _reset_runtime_state(self, close_client: bool) -> None:
|
async def _reset_runtime_state(self, close_client: bool) -> None:
|
||||||
"""Reset client and typing state."""
|
"""Reset client and typing state."""
|
||||||
await self._cancel_all_typing()
|
await self._cancel_all_typing()
|
||||||
|
self._stream_bufs.clear()
|
||||||
if close_client and self._client is not None and not self._client.is_closed():
|
if close_client and self._client is not None and not self._client.is_closed():
|
||||||
try:
|
try:
|
||||||
await self._client.close()
|
await self._client.close()
|
||||||
|
|||||||
@ -22,6 +22,8 @@ from nanobot.channels.base import BaseChannel
|
|||||||
from nanobot.config.paths import get_media_dir
|
from nanobot.config.paths import get_media_dir
|
||||||
from nanobot.config.schema import Base
|
from nanobot.config.schema import Base
|
||||||
|
|
||||||
|
from lark_oapi.core.const import FEISHU_DOMAIN, LARK_DOMAIN
|
||||||
|
|
||||||
FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
|
FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
|
||||||
|
|
||||||
# Message type display mapping
|
# Message type display mapping
|
||||||
@ -250,9 +252,12 @@ class FeishuConfig(Base):
|
|||||||
verification_token: str = ""
|
verification_token: str = ""
|
||||||
allow_from: list[str] = Field(default_factory=list)
|
allow_from: list[str] = Field(default_factory=list)
|
||||||
react_emoji: str = "THUMBSUP"
|
react_emoji: str = "THUMBSUP"
|
||||||
|
done_emoji: str | None = None # Emoji to show when task is completed (e.g., "DONE", "OK")
|
||||||
|
tool_hint_prefix: str = "\U0001f527" # Prefix for inline tool hints (default: 🔧)
|
||||||
group_policy: Literal["open", "mention"] = "mention"
|
group_policy: Literal["open", "mention"] = "mention"
|
||||||
reply_to_message: bool = False # If True, bot replies quote the user's original message
|
reply_to_message: bool = False # If True, bot replies quote the user's original message
|
||||||
streaming: bool = True
|
streaming: bool = True
|
||||||
|
domain: Literal["feishu", "lark"] = "feishu" # Set to "lark" for international Lark
|
||||||
|
|
||||||
|
|
||||||
_STREAM_ELEMENT_ID = "streaming_md"
|
_STREAM_ELEMENT_ID = "streaming_md"
|
||||||
@ -326,10 +331,12 @@ class FeishuChannel(BaseChannel):
|
|||||||
self._loop = asyncio.get_running_loop()
|
self._loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
# Create Lark client for sending messages
|
# Create Lark client for sending messages
|
||||||
|
domain = LARK_DOMAIN if self.config.domain == "lark" else FEISHU_DOMAIN
|
||||||
self._client = (
|
self._client = (
|
||||||
lark.Client.builder()
|
lark.Client.builder()
|
||||||
.app_id(self.config.app_id)
|
.app_id(self.config.app_id)
|
||||||
.app_secret(self.config.app_secret)
|
.app_secret(self.config.app_secret)
|
||||||
|
.domain(domain)
|
||||||
.log_level(lark.LogLevel.INFO)
|
.log_level(lark.LogLevel.INFO)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
@ -357,6 +364,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
self._ws_client = lark.ws.Client(
|
self._ws_client = lark.ws.Client(
|
||||||
self.config.app_id,
|
self.config.app_id,
|
||||||
self.config.app_secret,
|
self.config.app_secret,
|
||||||
|
domain=domain,
|
||||||
event_handler=event_handler,
|
event_handler=event_handler,
|
||||||
log_level=lark.LogLevel.INFO,
|
log_level=lark.LogLevel.INFO,
|
||||||
)
|
)
|
||||||
@ -1012,14 +1020,29 @@ class FeishuChannel(BaseChannel):
|
|||||||
|
|
||||||
elif msg_type in ("audio", "file", "media"):
|
elif msg_type in ("audio", "file", "media"):
|
||||||
file_key = content_json.get("file_key")
|
file_key = content_json.get("file_key")
|
||||||
if file_key and message_id:
|
if not file_key:
|
||||||
data, filename = await loop.run_in_executor(
|
logger.warning("Feishu {} message missing file_key: {}", msg_type, content_json)
|
||||||
None, self._download_file_sync, message_id, file_key, msg_type
|
return None, f"[{msg_type}: missing file_key]"
|
||||||
)
|
if not message_id:
|
||||||
if not filename:
|
logger.warning("Feishu {} message missing message_id", msg_type)
|
||||||
filename = file_key[:16]
|
return None, f"[{msg_type}: missing message_id]"
|
||||||
if msg_type == "audio" and not filename.endswith(".opus"):
|
|
||||||
filename = f"{filename}.opus"
|
data, filename = await loop.run_in_executor(
|
||||||
|
None, self._download_file_sync, message_id, file_key, msg_type
|
||||||
|
)
|
||||||
|
|
||||||
|
if not data:
|
||||||
|
logger.warning("Feishu {} download failed: file_key={}", msg_type, file_key)
|
||||||
|
return None, f"[{msg_type}: download failed]"
|
||||||
|
|
||||||
|
if not filename:
|
||||||
|
filename = file_key[:16]
|
||||||
|
|
||||||
|
# Feishu voice messages are opus in OGG container.
|
||||||
|
# Use .ogg extension for better Whisper compatibility.
|
||||||
|
if msg_type == "audio":
|
||||||
|
if not any(filename.endswith(ext) for ext in (".opus", ".ogg", ".oga")):
|
||||||
|
filename = f"{filename}.ogg"
|
||||||
|
|
||||||
if data and filename:
|
if data and filename:
|
||||||
file_path = media_dir / filename
|
file_path = media_dir / filename
|
||||||
@ -1263,7 +1286,14 @@ class FeishuChannel(BaseChannel):
|
|||||||
async def send_delta(
|
async def send_delta(
|
||||||
self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None
|
self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Progressive streaming via CardKit: create card on first delta, stream-update on subsequent."""
|
"""Progressive streaming via CardKit: create card on first delta, stream-update on subsequent.
|
||||||
|
|
||||||
|
Supported metadata keys:
|
||||||
|
_stream_end: Finalize the streaming card.
|
||||||
|
_tool_hint: Delta is a formatted tool hint (for display only).
|
||||||
|
message_id: Original message id (used with _stream_end for reaction cleanup).
|
||||||
|
reaction_id: Reaction id to remove on stream end.
|
||||||
|
"""
|
||||||
if not self._client:
|
if not self._client:
|
||||||
return
|
return
|
||||||
meta = metadata or {}
|
meta = metadata or {}
|
||||||
@ -1274,38 +1304,48 @@ class FeishuChannel(BaseChannel):
|
|||||||
if meta.get("_stream_end"):
|
if meta.get("_stream_end"):
|
||||||
if (message_id := meta.get("message_id")) and (reaction_id := meta.get("reaction_id")):
|
if (message_id := meta.get("message_id")) and (reaction_id := meta.get("reaction_id")):
|
||||||
await self._remove_reaction(message_id, reaction_id)
|
await self._remove_reaction(message_id, reaction_id)
|
||||||
|
# Add completion emoji if configured
|
||||||
|
if self.config.done_emoji and message_id:
|
||||||
|
await self._add_reaction(message_id, self.config.done_emoji)
|
||||||
|
|
||||||
buf = self._stream_bufs.pop(chat_id, None)
|
buf = self._stream_bufs.pop(chat_id, None)
|
||||||
if not buf or not buf.text:
|
if not buf or not buf.text:
|
||||||
return
|
return
|
||||||
|
# Try to finalize via streaming card; if that fails (e.g.
|
||||||
|
# streaming mode was closed by Feishu due to timeout), fall
|
||||||
|
# back to sending a regular interactive card.
|
||||||
if buf.card_id:
|
if buf.card_id:
|
||||||
buf.sequence += 1
|
buf.sequence += 1
|
||||||
await loop.run_in_executor(
|
ok = await loop.run_in_executor(
|
||||||
None,
|
None,
|
||||||
self._stream_update_text_sync,
|
self._stream_update_text_sync,
|
||||||
buf.card_id,
|
buf.card_id,
|
||||||
buf.text,
|
buf.text,
|
||||||
buf.sequence,
|
buf.sequence,
|
||||||
)
|
)
|
||||||
# Required so the chat list preview exits the streaming placeholder (Feishu streaming card docs).
|
if ok:
|
||||||
buf.sequence += 1
|
buf.sequence += 1
|
||||||
await loop.run_in_executor(
|
|
||||||
None,
|
|
||||||
self._close_streaming_mode_sync,
|
|
||||||
buf.card_id,
|
|
||||||
buf.sequence,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
for chunk in self._split_elements_by_table_limit(
|
|
||||||
self._build_card_elements(buf.text)
|
|
||||||
):
|
|
||||||
card = json.dumps(
|
|
||||||
{"config": {"wide_screen_mode": True}, "elements": chunk},
|
|
||||||
ensure_ascii=False,
|
|
||||||
)
|
|
||||||
await loop.run_in_executor(
|
await loop.run_in_executor(
|
||||||
None, self._send_message_sync, rid_type, chat_id, "interactive", card
|
None,
|
||||||
|
self._close_streaming_mode_sync,
|
||||||
|
buf.card_id,
|
||||||
|
buf.sequence,
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
logger.warning(
|
||||||
|
"Streaming card {} final update failed, falling back to regular card",
|
||||||
|
buf.card_id,
|
||||||
|
)
|
||||||
|
for chunk in self._split_elements_by_table_limit(
|
||||||
|
self._build_card_elements(buf.text)
|
||||||
|
):
|
||||||
|
card = json.dumps(
|
||||||
|
{"config": {"wide_screen_mode": True}, "elements": chunk},
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
await loop.run_in_executor(
|
||||||
|
None, self._send_message_sync, rid_type, chat_id, "interactive", card
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# --- accumulate delta ---
|
# --- accumulate delta ---
|
||||||
@ -1346,13 +1386,33 @@ class FeishuChannel(BaseChannel):
|
|||||||
receive_id_type = "chat_id" if msg.chat_id.startswith("oc_") else "open_id"
|
receive_id_type = "chat_id" if msg.chat_id.startswith("oc_") else "open_id"
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
# Handle tool hint messages as code blocks in interactive cards.
|
# Handle tool hint messages. When a streaming card is active for
|
||||||
# These are progress-only messages and should bypass normal reply routing.
|
# this chat, inline the hint into the card instead of sending a
|
||||||
|
# separate message so the user experience stays cohesive.
|
||||||
if msg.metadata.get("_tool_hint"):
|
if msg.metadata.get("_tool_hint"):
|
||||||
if msg.content and msg.content.strip():
|
hint = (msg.content or "").strip()
|
||||||
await self._send_tool_hint_card(
|
if not hint:
|
||||||
receive_id_type, msg.chat_id, msg.content.strip()
|
return
|
||||||
|
buf = self._stream_bufs.get(msg.chat_id)
|
||||||
|
if buf and buf.card_id:
|
||||||
|
# Delegate to send_delta so tool hints get the same
|
||||||
|
# throttling (and card creation) as regular text deltas.
|
||||||
|
await self.send_delta(
|
||||||
|
msg.chat_id,
|
||||||
|
"\n\n" + self._format_tool_hint_delta(hint) + "\n\n",
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
# No active streaming card — send as a regular
|
||||||
|
# interactive card with the same 🔧 prefix style.
|
||||||
|
card = json.dumps(
|
||||||
|
{"config": {"wide_screen_mode": True}, "elements": [
|
||||||
|
{"tag": "markdown", "content": self._format_tool_hint_delta(hint)},
|
||||||
|
]},
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
await loop.run_in_executor(
|
||||||
|
None, self._send_message_sync, receive_id_type, msg.chat_id, "interactive", card
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Determine whether the first message should quote the user's message.
|
# Determine whether the first message should quote the user's message.
|
||||||
@ -1648,33 +1708,9 @@ class FeishuChannel(BaseChannel):
|
|||||||
|
|
||||||
return "\n".join(part for part in parts if part)
|
return "\n".join(part for part in parts if part)
|
||||||
|
|
||||||
async def _send_tool_hint_card(
|
def _format_tool_hint_delta(self, tool_hint: str) -> str:
|
||||||
self, receive_id_type: str, receive_id: str, tool_hint: str
|
"""Format a tool hint string with the 🔧 prefix for each line."""
|
||||||
) -> None:
|
lines = self.__class__._format_tool_hint_lines(tool_hint).split("\n")
|
||||||
"""Send tool hint as an interactive card with formatted code block.
|
return "\n".join(
|
||||||
|
f"{self.config.tool_hint_prefix} {ln}" for ln in lines if ln.strip()
|
||||||
Args:
|
|
||||||
receive_id_type: "chat_id" or "open_id"
|
|
||||||
receive_id: The target chat or user ID
|
|
||||||
tool_hint: Formatted tool hint string (e.g., 'web_search("q"), read_file("path")')
|
|
||||||
"""
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
|
|
||||||
# Put each top-level tool call on its own line without altering commas inside arguments.
|
|
||||||
formatted_code = self._format_tool_hint_lines(tool_hint)
|
|
||||||
|
|
||||||
card = {
|
|
||||||
"config": {"wide_screen_mode": True},
|
|
||||||
"elements": [
|
|
||||||
{"tag": "markdown", "content": f"**Tool Calls**\n\n```text\n{formatted_code}\n```"}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
await loop.run_in_executor(
|
|
||||||
None,
|
|
||||||
self._send_message_sync,
|
|
||||||
receive_id_type,
|
|
||||||
receive_id,
|
|
||||||
"interactive",
|
|
||||||
json.dumps(card, ensure_ascii=False),
|
|
||||||
)
|
)
|
||||||
|
|||||||
@ -242,43 +242,49 @@ class QQChannel(BaseChannel):
|
|||||||
|
|
||||||
async def send(self, msg: OutboundMessage) -> None:
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
"""Send attachments first, then text."""
|
"""Send attachments first, then text."""
|
||||||
if not self._client:
|
try:
|
||||||
logger.warning("QQ client not initialized")
|
if not self._client:
|
||||||
return
|
logger.warning("QQ client not initialized")
|
||||||
|
return
|
||||||
|
|
||||||
msg_id = msg.metadata.get("message_id")
|
msg_id = msg.metadata.get("message_id")
|
||||||
chat_type = self._chat_type_cache.get(msg.chat_id, "c2c")
|
chat_type = self._chat_type_cache.get(msg.chat_id, "c2c")
|
||||||
is_group = chat_type == "group"
|
is_group = chat_type == "group"
|
||||||
|
|
||||||
# 1) Send media
|
# 1) Send media
|
||||||
for media_ref in msg.media or []:
|
for media_ref in msg.media or []:
|
||||||
ok = await self._send_media(
|
ok = await self._send_media(
|
||||||
chat_id=msg.chat_id,
|
chat_id=msg.chat_id,
|
||||||
media_ref=media_ref,
|
media_ref=media_ref,
|
||||||
msg_id=msg_id,
|
msg_id=msg_id,
|
||||||
is_group=is_group,
|
is_group=is_group,
|
||||||
)
|
|
||||||
if not ok:
|
|
||||||
filename = (
|
|
||||||
os.path.basename(urlparse(media_ref).path)
|
|
||||||
or os.path.basename(media_ref)
|
|
||||||
or "file"
|
|
||||||
)
|
)
|
||||||
|
if not ok:
|
||||||
|
filename = (
|
||||||
|
os.path.basename(urlparse(media_ref).path)
|
||||||
|
or os.path.basename(media_ref)
|
||||||
|
or "file"
|
||||||
|
)
|
||||||
|
await self._send_text_only(
|
||||||
|
chat_id=msg.chat_id,
|
||||||
|
is_group=is_group,
|
||||||
|
msg_id=msg_id,
|
||||||
|
content=f"[Attachment send failed: {filename}]",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2) Send text
|
||||||
|
if msg.content and msg.content.strip():
|
||||||
await self._send_text_only(
|
await self._send_text_only(
|
||||||
chat_id=msg.chat_id,
|
chat_id=msg.chat_id,
|
||||||
is_group=is_group,
|
is_group=is_group,
|
||||||
msg_id=msg_id,
|
msg_id=msg_id,
|
||||||
content=f"[Attachment send failed: {filename}]",
|
content=msg.content.strip(),
|
||||||
)
|
)
|
||||||
|
except (aiohttp.ClientError, OSError):
|
||||||
# 2) Send text
|
# Network / transport errors — propagate so ChannelManager can retry
|
||||||
if msg.content and msg.content.strip():
|
raise
|
||||||
await self._send_text_only(
|
except Exception:
|
||||||
chat_id=msg.chat_id,
|
logger.exception("Error sending QQ message to chat_id={}", msg.chat_id)
|
||||||
is_group=is_group,
|
|
||||||
msg_id=msg_id,
|
|
||||||
content=msg.content.strip(),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _send_text_only(
|
async def _send_text_only(
|
||||||
self,
|
self,
|
||||||
@ -359,7 +365,12 @@ class QQChannel(BaseChannel):
|
|||||||
|
|
||||||
logger.info("QQ media sent: {}", filename)
|
logger.info("QQ media sent: {}", filename)
|
||||||
return True
|
return True
|
||||||
|
except (aiohttp.ClientError, OSError) as e:
|
||||||
|
# Network / transport errors — propagate for retry by caller
|
||||||
|
logger.warning("QQ send media network error filename={} err={}", filename, e)
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
# API-level or other non-network errors — return False so send() can fallback
|
||||||
logger.error("QQ send media failed filename={} err={}", filename, e)
|
logger.error("QQ send media failed filename={} err={}", filename, e)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -438,15 +449,26 @@ class QQChannel(BaseChannel):
|
|||||||
endpoint = "/v2/users/{openid}/files"
|
endpoint = "/v2/users/{openid}/files"
|
||||||
id_key = "openid"
|
id_key = "openid"
|
||||||
|
|
||||||
payload = {
|
payload: dict[str, Any] = {
|
||||||
id_key: chat_id,
|
id_key: chat_id,
|
||||||
"file_type": file_type,
|
"file_type": file_type,
|
||||||
"file_data": file_data,
|
"file_data": file_data,
|
||||||
"file_name": file_name,
|
|
||||||
"srv_send_msg": srv_send_msg,
|
"srv_send_msg": srv_send_msg,
|
||||||
}
|
}
|
||||||
|
# Only pass file_name for non-image types (file_type=4).
|
||||||
|
# Passing file_name for images causes QQ client to render them as
|
||||||
|
# file attachments instead of inline images.
|
||||||
|
if file_type != QQ_FILE_TYPE_IMAGE and file_name:
|
||||||
|
payload["file_name"] = file_name
|
||||||
|
|
||||||
route = Route("POST", endpoint, **{id_key: chat_id})
|
route = Route("POST", endpoint, **{id_key: chat_id})
|
||||||
return await self._client.api._http.request(route, json=payload)
|
result = await self._client.api._http.request(route, json=payload)
|
||||||
|
|
||||||
|
# Extract only the file_info field to avoid extra fields (file_uuid, ttl, etc.)
|
||||||
|
# that may confuse QQ client when sending the media object.
|
||||||
|
if isinstance(result, dict) and "file_info" in result:
|
||||||
|
return {"file_info": result["file_info"]}
|
||||||
|
return result
|
||||||
|
|
||||||
# ---------------------------
|
# ---------------------------
|
||||||
# Inbound (receive)
|
# Inbound (receive)
|
||||||
@ -454,58 +476,68 @@ class QQChannel(BaseChannel):
|
|||||||
|
|
||||||
async def _on_message(self, data: C2CMessage | GroupMessage, is_group: bool = False) -> None:
|
async def _on_message(self, data: C2CMessage | GroupMessage, is_group: bool = False) -> None:
|
||||||
"""Parse inbound message, download attachments, and publish to the bus."""
|
"""Parse inbound message, download attachments, and publish to the bus."""
|
||||||
if data.id in self._processed_ids:
|
try:
|
||||||
return
|
if data.id in self._processed_ids:
|
||||||
self._processed_ids.append(data.id)
|
return
|
||||||
|
self._processed_ids.append(data.id)
|
||||||
|
|
||||||
if is_group:
|
if is_group:
|
||||||
chat_id = data.group_openid
|
chat_id = data.group_openid
|
||||||
user_id = data.author.member_openid
|
user_id = data.author.member_openid
|
||||||
self._chat_type_cache[chat_id] = "group"
|
self._chat_type_cache[chat_id] = "group"
|
||||||
else:
|
else:
|
||||||
chat_id = str(
|
chat_id = str(
|
||||||
getattr(data.author, "id", None) or getattr(data.author, "user_openid", "unknown")
|
getattr(data.author, "id", None)
|
||||||
)
|
or getattr(data.author, "user_openid", "unknown")
|
||||||
user_id = chat_id
|
|
||||||
self._chat_type_cache[chat_id] = "c2c"
|
|
||||||
|
|
||||||
content = (data.content or "").strip()
|
|
||||||
|
|
||||||
# the data used by tests don't contain attachments property
|
|
||||||
# so we use getattr with a default of [] to avoid AttributeError in tests
|
|
||||||
attachments = getattr(data, "attachments", None) or []
|
|
||||||
media_paths, recv_lines, att_meta = await self._handle_attachments(attachments)
|
|
||||||
|
|
||||||
# Compose content that always contains actionable saved paths
|
|
||||||
if recv_lines:
|
|
||||||
tag = "[Image]" if any(_is_image_name(Path(p).name) for p in media_paths) else "[File]"
|
|
||||||
file_block = "Received files:\n" + "\n".join(recv_lines)
|
|
||||||
content = f"{content}\n\n{file_block}".strip() if content else f"{tag}\n{file_block}"
|
|
||||||
|
|
||||||
if not content and not media_paths:
|
|
||||||
return
|
|
||||||
|
|
||||||
if self.config.ack_message:
|
|
||||||
try:
|
|
||||||
await self._send_text_only(
|
|
||||||
chat_id=chat_id,
|
|
||||||
is_group=is_group,
|
|
||||||
msg_id=data.id,
|
|
||||||
content=self.config.ack_message,
|
|
||||||
)
|
)
|
||||||
except Exception:
|
user_id = chat_id
|
||||||
logger.debug("QQ ack message failed for chat_id={}", chat_id)
|
self._chat_type_cache[chat_id] = "c2c"
|
||||||
|
|
||||||
await self._handle_message(
|
content = (data.content or "").strip()
|
||||||
sender_id=user_id,
|
|
||||||
chat_id=chat_id,
|
# the data used by tests don't contain attachments property
|
||||||
content=content,
|
# so we use getattr with a default of [] to avoid AttributeError in tests
|
||||||
media=media_paths if media_paths else None,
|
attachments = getattr(data, "attachments", None) or []
|
||||||
metadata={
|
media_paths, recv_lines, att_meta = await self._handle_attachments(attachments)
|
||||||
"message_id": data.id,
|
|
||||||
"attachments": att_meta,
|
# Compose content that always contains actionable saved paths
|
||||||
},
|
if recv_lines:
|
||||||
)
|
tag = (
|
||||||
|
"[Image]"
|
||||||
|
if any(_is_image_name(Path(p).name) for p in media_paths)
|
||||||
|
else "[File]"
|
||||||
|
)
|
||||||
|
file_block = "Received files:\n" + "\n".join(recv_lines)
|
||||||
|
content = (
|
||||||
|
f"{content}\n\n{file_block}".strip() if content else f"{tag}\n{file_block}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not content and not media_paths:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.config.ack_message:
|
||||||
|
try:
|
||||||
|
await self._send_text_only(
|
||||||
|
chat_id=chat_id,
|
||||||
|
is_group=is_group,
|
||||||
|
msg_id=data.id,
|
||||||
|
content=self.config.ack_message,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("QQ ack message failed for chat_id={}", chat_id)
|
||||||
|
|
||||||
|
await self._handle_message(
|
||||||
|
sender_id=user_id,
|
||||||
|
chat_id=chat_id,
|
||||||
|
content=content,
|
||||||
|
media=media_paths if media_paths else None,
|
||||||
|
metadata={
|
||||||
|
"message_id": data.id,
|
||||||
|
"attachments": att_meta,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error handling QQ inbound message id={}", getattr(data, "id", "?"))
|
||||||
|
|
||||||
async def _handle_attachments(
|
async def _handle_attachments(
|
||||||
self,
|
self,
|
||||||
@ -520,7 +552,9 @@ class QQChannel(BaseChannel):
|
|||||||
return media_paths, recv_lines, att_meta
|
return media_paths, recv_lines, att_meta
|
||||||
|
|
||||||
for att in attachments:
|
for att in attachments:
|
||||||
url, filename, ctype = att.url, att.filename, att.content_type
|
url = getattr(att, "url", None) or ""
|
||||||
|
filename = getattr(att, "filename", None) or ""
|
||||||
|
ctype = getattr(att, "content_type", None) or ""
|
||||||
|
|
||||||
logger.info("Downloading file from QQ: {}", filename or url)
|
logger.info("Downloading file from QQ: {}", filename or url)
|
||||||
local_path = await self._download_to_media_dir_chunked(url, filename_hint=filename)
|
local_path = await self._download_to_media_dir_chunked(url, filename_hint=filename)
|
||||||
@ -555,6 +589,10 @@ class QQChannel(BaseChannel):
|
|||||||
Enforces a max download size and writes to a .part temp file
|
Enforces a max download size and writes to a .part temp file
|
||||||
that is atomically renamed on success.
|
that is atomically renamed on success.
|
||||||
"""
|
"""
|
||||||
|
# Handle protocol-relative URLs (e.g. "//multimedia.nt.qq.com/...")
|
||||||
|
if url.startswith("//"):
|
||||||
|
url = f"https:{url}"
|
||||||
|
|
||||||
if not self._http:
|
if not self._http:
|
||||||
self._http = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120))
|
self._http = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120))
|
||||||
|
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import re
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from pydantic import Field
|
||||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||||
from slack_sdk.socket_mode.websockets import SocketModeClient
|
from slack_sdk.socket_mode.websockets import SocketModeClient
|
||||||
@ -13,8 +14,6 @@ from slackify_markdown import slackify_markdown
|
|||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from nanobot.channels.base import BaseChannel
|
from nanobot.channels.base import BaseChannel
|
||||||
from nanobot.config.schema import Base
|
from nanobot.config.schema import Base
|
||||||
|
|
||||||
@ -50,6 +49,9 @@ class SlackChannel(BaseChannel):
|
|||||||
|
|
||||||
name = "slack"
|
name = "slack"
|
||||||
display_name = "Slack"
|
display_name = "Slack"
|
||||||
|
_SLACK_ID_RE = re.compile(r"^[CDGUW][A-Z0-9]{2,}$")
|
||||||
|
_SLACK_CHANNEL_REF_RE = re.compile(r"^<#([A-Z0-9]+)(?:\|[^>]+)?>$")
|
||||||
|
_SLACK_USER_REF_RE = re.compile(r"^<@([A-Z0-9]+)(?:\|[^>]+)?>$")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default_config(cls) -> dict[str, Any]:
|
def default_config(cls) -> dict[str, Any]:
|
||||||
@ -63,6 +65,7 @@ class SlackChannel(BaseChannel):
|
|||||||
self._web_client: AsyncWebClient | None = None
|
self._web_client: AsyncWebClient | None = None
|
||||||
self._socket_client: SocketModeClient | None = None
|
self._socket_client: SocketModeClient | None = None
|
||||||
self._bot_user_id: str | None = None
|
self._bot_user_id: str | None = None
|
||||||
|
self._target_cache: dict[str, str] = {}
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the Slack Socket Mode client."""
|
"""Start the Slack Socket Mode client."""
|
||||||
@ -113,17 +116,23 @@ class SlackChannel(BaseChannel):
|
|||||||
logger.warning("Slack client not running")
|
logger.warning("Slack client not running")
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
|
target_chat_id = await self._resolve_target_chat_id(msg.chat_id)
|
||||||
slack_meta = msg.metadata.get("slack", {}) if msg.metadata else {}
|
slack_meta = msg.metadata.get("slack", {}) if msg.metadata else {}
|
||||||
thread_ts = slack_meta.get("thread_ts")
|
thread_ts = slack_meta.get("thread_ts")
|
||||||
channel_type = slack_meta.get("channel_type")
|
channel_type = slack_meta.get("channel_type")
|
||||||
|
origin_chat_id = str((slack_meta.get("event", {}) or {}).get("channel") or msg.chat_id)
|
||||||
# Slack DMs don't use threads; channel/group replies may keep thread_ts.
|
# Slack DMs don't use threads; channel/group replies may keep thread_ts.
|
||||||
thread_ts_param = thread_ts if thread_ts and channel_type != "im" else None
|
thread_ts_param = (
|
||||||
|
thread_ts
|
||||||
|
if thread_ts and channel_type != "im" and target_chat_id == origin_chat_id
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
# Slack rejects empty text payloads. Keep media-only messages media-only,
|
# Slack rejects empty text payloads. Keep media-only messages media-only,
|
||||||
# but send a single blank message when the bot has no text or files to send.
|
# but send a single blank message when the bot has no text or files to send.
|
||||||
if msg.content or not (msg.media or []):
|
if msg.content or not (msg.media or []):
|
||||||
await self._web_client.chat_postMessage(
|
await self._web_client.chat_postMessage(
|
||||||
channel=msg.chat_id,
|
channel=target_chat_id,
|
||||||
text=self._to_mrkdwn(msg.content) if msg.content else " ",
|
text=self._to_mrkdwn(msg.content) if msg.content else " ",
|
||||||
thread_ts=thread_ts_param,
|
thread_ts=thread_ts_param,
|
||||||
)
|
)
|
||||||
@ -131,7 +140,7 @@ class SlackChannel(BaseChannel):
|
|||||||
for media_path in msg.media or []:
|
for media_path in msg.media or []:
|
||||||
try:
|
try:
|
||||||
await self._web_client.files_upload_v2(
|
await self._web_client.files_upload_v2(
|
||||||
channel=msg.chat_id,
|
channel=target_chat_id,
|
||||||
file=media_path,
|
file=media_path,
|
||||||
thread_ts=thread_ts_param,
|
thread_ts=thread_ts_param,
|
||||||
)
|
)
|
||||||
@ -141,12 +150,123 @@ class SlackChannel(BaseChannel):
|
|||||||
# Update reaction emoji when the final (non-progress) response is sent
|
# Update reaction emoji when the final (non-progress) response is sent
|
||||||
if not (msg.metadata or {}).get("_progress"):
|
if not (msg.metadata or {}).get("_progress"):
|
||||||
event = slack_meta.get("event", {})
|
event = slack_meta.get("event", {})
|
||||||
await self._update_react_emoji(msg.chat_id, event.get("ts"))
|
await self._update_react_emoji(origin_chat_id, event.get("ts"))
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error sending Slack message: {}", e)
|
logger.error("Error sending Slack message: {}", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def _resolve_target_chat_id(self, target: str) -> str:
|
||||||
|
"""Resolve human-friendly Slack targets to concrete IDs when needed."""
|
||||||
|
if not self._web_client:
|
||||||
|
return target
|
||||||
|
|
||||||
|
target = target.strip()
|
||||||
|
if not target:
|
||||||
|
return target
|
||||||
|
|
||||||
|
if match := self._SLACK_CHANNEL_REF_RE.fullmatch(target):
|
||||||
|
return match.group(1)
|
||||||
|
if match := self._SLACK_USER_REF_RE.fullmatch(target):
|
||||||
|
return await self._open_dm_for_user(match.group(1))
|
||||||
|
if self._SLACK_ID_RE.fullmatch(target):
|
||||||
|
if target.startswith(("U", "W")):
|
||||||
|
return await self._open_dm_for_user(target)
|
||||||
|
return target
|
||||||
|
|
||||||
|
if target.startswith("#"):
|
||||||
|
return await self._resolve_channel_name(target[1:])
|
||||||
|
if target.startswith("@"):
|
||||||
|
return await self._resolve_user_handle(target[1:])
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await self._resolve_channel_name(target)
|
||||||
|
except ValueError:
|
||||||
|
return await self._resolve_user_handle(target)
|
||||||
|
|
||||||
|
async def _resolve_channel_name(self, name: str) -> str:
|
||||||
|
normalized = self._normalize_target_name(name)
|
||||||
|
if not normalized:
|
||||||
|
raise ValueError("Slack target channel name is empty")
|
||||||
|
|
||||||
|
cache_key = f"channel:{normalized}"
|
||||||
|
if cache_key in self._target_cache:
|
||||||
|
return self._target_cache[cache_key]
|
||||||
|
|
||||||
|
cursor: str | None = None
|
||||||
|
while True:
|
||||||
|
response = await self._web_client.conversations_list(
|
||||||
|
types="public_channel,private_channel",
|
||||||
|
exclude_archived=True,
|
||||||
|
limit=200,
|
||||||
|
cursor=cursor,
|
||||||
|
)
|
||||||
|
for channel in response.get("channels", []):
|
||||||
|
if self._normalize_target_name(str(channel.get("name") or "")) == normalized:
|
||||||
|
channel_id = str(channel.get("id") or "")
|
||||||
|
if channel_id:
|
||||||
|
self._target_cache[cache_key] = channel_id
|
||||||
|
return channel_id
|
||||||
|
cursor = ((response.get("response_metadata") or {}).get("next_cursor") or "").strip()
|
||||||
|
if not cursor:
|
||||||
|
break
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Slack channel '{name}' was not found. Use a joined channel name like "
|
||||||
|
f"'#general' or a concrete channel ID."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _resolve_user_handle(self, handle: str) -> str:
|
||||||
|
normalized = self._normalize_target_name(handle)
|
||||||
|
if not normalized:
|
||||||
|
raise ValueError("Slack target user handle is empty")
|
||||||
|
|
||||||
|
cache_key = f"user:{normalized}"
|
||||||
|
if cache_key in self._target_cache:
|
||||||
|
return self._target_cache[cache_key]
|
||||||
|
|
||||||
|
cursor: str | None = None
|
||||||
|
while True:
|
||||||
|
response = await self._web_client.users_list(limit=200, cursor=cursor)
|
||||||
|
for member in response.get("members", []):
|
||||||
|
if self._member_matches_handle(member, normalized):
|
||||||
|
user_id = str(member.get("id") or "")
|
||||||
|
if not user_id:
|
||||||
|
continue
|
||||||
|
dm_id = await self._open_dm_for_user(user_id)
|
||||||
|
self._target_cache[cache_key] = dm_id
|
||||||
|
return dm_id
|
||||||
|
cursor = ((response.get("response_metadata") or {}).get("next_cursor") or "").strip()
|
||||||
|
if not cursor:
|
||||||
|
break
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"Slack user '{handle}' was not found. Use '@name' or a concrete DM/channel ID."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _open_dm_for_user(self, user_id: str) -> str:
|
||||||
|
response = await self._web_client.conversations_open(users=user_id)
|
||||||
|
channel_id = str(((response.get("channel") or {}).get("id")) or "")
|
||||||
|
if not channel_id:
|
||||||
|
raise ValueError(f"Slack DM target for user '{user_id}' could not be opened.")
|
||||||
|
return channel_id
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_target_name(value: str) -> str:
|
||||||
|
return value.strip().lstrip("#@").lower()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _member_matches_handle(cls, member: dict[str, Any], normalized: str) -> bool:
|
||||||
|
profile = member.get("profile") or {}
|
||||||
|
candidates = {
|
||||||
|
str(member.get("name") or ""),
|
||||||
|
str(profile.get("display_name") or ""),
|
||||||
|
str(profile.get("display_name_normalized") or ""),
|
||||||
|
str(profile.get("real_name") or ""),
|
||||||
|
str(profile.get("real_name_normalized") or ""),
|
||||||
|
}
|
||||||
|
return normalized in {cls._normalize_target_name(candidate) for candidate in candidates if candidate}
|
||||||
|
|
||||||
async def _on_socket_request(
|
async def _on_socket_request(
|
||||||
self,
|
self,
|
||||||
client: SocketModeClient,
|
client: SocketModeClient,
|
||||||
|
|||||||
@ -166,6 +166,7 @@ def _markdown_to_telegram_html(text: str) -> str:
|
|||||||
|
|
||||||
_SEND_MAX_RETRIES = 3
|
_SEND_MAX_RETRIES = 3
|
||||||
_SEND_RETRY_BASE_DELAY = 0.5 # seconds, doubled each retry
|
_SEND_RETRY_BASE_DELAY = 0.5 # seconds, doubled each retry
|
||||||
|
_STREAM_EDIT_INTERVAL_DEFAULT = 0.6 # min seconds between edit_message_text calls
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -190,6 +191,7 @@ class TelegramConfig(Base):
|
|||||||
connection_pool_size: int = 32
|
connection_pool_size: int = 32
|
||||||
pool_timeout: float = 5.0
|
pool_timeout: float = 5.0
|
||||||
streaming: bool = True
|
streaming: bool = True
|
||||||
|
stream_edit_interval: float = Field(default=_STREAM_EDIT_INTERVAL_DEFAULT, ge=0.1)
|
||||||
|
|
||||||
|
|
||||||
class TelegramChannel(BaseChannel):
|
class TelegramChannel(BaseChannel):
|
||||||
@ -219,8 +221,6 @@ class TelegramChannel(BaseChannel):
|
|||||||
def default_config(cls) -> dict[str, Any]:
|
def default_config(cls) -> dict[str, Any]:
|
||||||
return TelegramConfig().model_dump(by_alias=True)
|
return TelegramConfig().model_dump(by_alias=True)
|
||||||
|
|
||||||
_STREAM_EDIT_INTERVAL = 0.6 # min seconds between edit_message_text calls
|
|
||||||
|
|
||||||
def __init__(self, config: Any, bus: MessageBus):
|
def __init__(self, config: Any, bus: MessageBus):
|
||||||
if isinstance(config, dict):
|
if isinstance(config, dict):
|
||||||
config = TelegramConfig.model_validate(config)
|
config = TelegramConfig.model_validate(config)
|
||||||
@ -520,7 +520,10 @@ class TelegramChannel(BaseChannel):
|
|||||||
reply_parameters=reply_params,
|
reply_parameters=reply_params,
|
||||||
**(thread_kwargs or {}),
|
**(thread_kwargs or {}),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except BadRequest as e:
|
||||||
|
# Only fall back to plain text on actual HTML parse/format errors.
|
||||||
|
# Network errors (TimedOut, NetworkError) should propagate immediately
|
||||||
|
# to avoid doubling connection demand during pool exhaustion.
|
||||||
logger.warning("HTML parse failed, falling back to plain text: {}", e)
|
logger.warning("HTML parse failed, falling back to plain text: {}", e)
|
||||||
try:
|
try:
|
||||||
await self._call_with_retry(
|
await self._call_with_retry(
|
||||||
@ -567,7 +570,10 @@ class TelegramChannel(BaseChannel):
|
|||||||
chat_id=int_chat_id, message_id=buf.message_id,
|
chat_id=int_chat_id, message_id=buf.message_id,
|
||||||
text=html, parse_mode="HTML",
|
text=html, parse_mode="HTML",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except BadRequest as e:
|
||||||
|
# Only fall back to plain text on actual HTML parse/format errors.
|
||||||
|
# Network errors (TimedOut, NetworkError) should propagate immediately
|
||||||
|
# to avoid doubling connection demand during pool exhaustion.
|
||||||
if self._is_not_modified_error(e):
|
if self._is_not_modified_error(e):
|
||||||
logger.debug("Final stream edit already applied for {}", chat_id)
|
logger.debug("Final stream edit already applied for {}", chat_id)
|
||||||
self._stream_bufs.pop(chat_id, None)
|
self._stream_bufs.pop(chat_id, None)
|
||||||
@ -619,7 +625,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Stream initial send failed: {}", e)
|
logger.warning("Stream initial send failed: {}", e)
|
||||||
raise # Let ChannelManager handle retry
|
raise # Let ChannelManager handle retry
|
||||||
elif (now - buf.last_edit) >= self._STREAM_EDIT_INTERVAL:
|
elif (now - buf.last_edit) >= self.config.stream_edit_interval:
|
||||||
try:
|
try:
|
||||||
await self._call_with_retry(
|
await self._call_with_retry(
|
||||||
self._app.bot.edit_message_text,
|
self._app.bot.edit_message_text,
|
||||||
|
|||||||
457
nanobot/channels/websocket.py
Normal file
457
nanobot/channels/websocket.py
Normal file
@ -0,0 +1,457 @@
|
|||||||
|
"""WebSocket server channel: nanobot acts as a WebSocket server and serves connected clients."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import email.utils
|
||||||
|
import hmac
|
||||||
|
import http
|
||||||
|
import json
|
||||||
|
import secrets
|
||||||
|
import ssl
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Self
|
||||||
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
from pydantic import Field, field_validator, model_validator
|
||||||
|
from websockets.asyncio.server import ServerConnection, serve
|
||||||
|
from websockets.datastructures import Headers
|
||||||
|
from websockets.exceptions import ConnectionClosed
|
||||||
|
from websockets.http11 import Request as WsRequest, Response
|
||||||
|
|
||||||
|
from nanobot.bus.events import OutboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.channels.base import BaseChannel
|
||||||
|
from nanobot.config.schema import Base
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_trailing_slash(path: str) -> str:
|
||||||
|
if len(path) > 1 and path.endswith("/"):
|
||||||
|
return path.rstrip("/")
|
||||||
|
return path or "/"
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_config_path(path: str) -> str:
|
||||||
|
return _strip_trailing_slash(path)
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketConfig(Base):
|
||||||
|
"""WebSocket server channel configuration.
|
||||||
|
|
||||||
|
Clients connect with URLs like ``ws://{host}:{port}{path}?client_id=...&token=...``.
|
||||||
|
- ``client_id``: Used for ``allow_from`` authorization; if omitted, a value is generated and logged.
|
||||||
|
- ``token``: If non-empty, the ``token`` query param may match this static secret; short-lived tokens
|
||||||
|
from ``token_issue_path`` are also accepted.
|
||||||
|
- ``token_issue_path``: If non-empty, **GET** (HTTP/1.1) to this path returns JSON
|
||||||
|
``{"token": "...", "expires_in": <seconds>}``; use ``?token=...`` when opening the WebSocket.
|
||||||
|
Must differ from ``path`` (the WS upgrade path). If the client runs in the **same process** as
|
||||||
|
nanobot and shares the asyncio loop, use a thread or async HTTP client for GET—do not call
|
||||||
|
blocking ``urllib`` or synchronous ``httpx`` from inside a coroutine.
|
||||||
|
- ``token_issue_secret``: If non-empty, token requests must send ``Authorization: Bearer <secret>`` or
|
||||||
|
``X-Nanobot-Auth: <secret>``.
|
||||||
|
- ``websocket_requires_token``: If True, the handshake must include a valid token (static or issued and not expired).
|
||||||
|
- Each connection has its own session: a unique ``chat_id`` maps to the agent session internally.
|
||||||
|
- ``media`` field in outbound messages contains local filesystem paths; remote clients need a
|
||||||
|
shared filesystem or an HTTP file server to access these files.
|
||||||
|
"""
|
||||||
|
|
||||||
|
enabled: bool = False
|
||||||
|
host: str = "127.0.0.1"
|
||||||
|
port: int = 8765
|
||||||
|
path: str = "/"
|
||||||
|
token: str = ""
|
||||||
|
token_issue_path: str = ""
|
||||||
|
token_issue_secret: str = ""
|
||||||
|
token_ttl_s: int = Field(default=300, ge=30, le=86_400)
|
||||||
|
websocket_requires_token: bool = True
|
||||||
|
allow_from: list[str] = Field(default_factory=lambda: ["*"])
|
||||||
|
streaming: bool = True
|
||||||
|
max_message_bytes: int = Field(default=1_048_576, ge=1024, le=16_777_216)
|
||||||
|
ping_interval_s: float = Field(default=20.0, ge=5.0, le=300.0)
|
||||||
|
ping_timeout_s: float = Field(default=20.0, ge=5.0, le=300.0)
|
||||||
|
ssl_certfile: str = ""
|
||||||
|
ssl_keyfile: str = ""
|
||||||
|
|
||||||
|
@field_validator("path")
|
||||||
|
@classmethod
|
||||||
|
def path_must_start_with_slash(cls, value: str) -> str:
|
||||||
|
if not value.startswith("/"):
|
||||||
|
raise ValueError('path must start with "/"')
|
||||||
|
return _normalize_config_path(value)
|
||||||
|
|
||||||
|
@field_validator("token_issue_path")
|
||||||
|
@classmethod
|
||||||
|
def token_issue_path_format(cls, value: str) -> str:
|
||||||
|
value = value.strip()
|
||||||
|
if not value:
|
||||||
|
return ""
|
||||||
|
if not value.startswith("/"):
|
||||||
|
raise ValueError('token_issue_path must start with "/"')
|
||||||
|
return _normalize_config_path(value)
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def token_issue_path_differs_from_ws_path(self) -> Self:
|
||||||
|
if not self.token_issue_path:
|
||||||
|
return self
|
||||||
|
if _normalize_config_path(self.token_issue_path) == _normalize_config_path(self.path):
|
||||||
|
raise ValueError("token_issue_path must differ from path (the WebSocket upgrade path)")
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
    """Build a plain HTTP response whose body is *data* serialized as UTF-8 JSON.

    Args:
        data: JSON-serializable mapping used as the response body.
        status: HTTP status code; the reason phrase is looked up via ``http.HTTPStatus``.

    Returns:
        A ``Response`` carrying ``Date``, ``Connection: close``, ``Content-Length``
        and ``Content-Type`` headers plus the encoded body.
    """
    payload = json.dumps(data, ensure_ascii=False).encode("utf-8")
    header_items = [
        ("Date", email.utils.formatdate(usegmt=True)),
        ("Connection", "close"),
        # Length is taken from the encoded bytes, not the character count.
        ("Content-Length", str(len(payload))),
        ("Content-Type", "application/json; charset=utf-8"),
    ]
    response_headers = Headers(header_items)
    phrase = http.HTTPStatus(status).phrase
    return Response(status, phrase, response_headers, payload)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]:
    """Split a request target into its normalized path and parsed query parameters.

    Returns:
        ``(path, query)`` where *path* has gone through ``_strip_trailing_slash``
        (an empty path is treated as ``"/"``) and *query* is the ``parse_qs`` mapping.
    """
    # A dummy scheme/host lets urlparse treat the input as a full URL.
    url = urlparse("ws://x" + path_with_query)
    raw_path = url.path if url.path else "/"
    normalized = _strip_trailing_slash(raw_path)
    params = parse_qs(url.query)
    return normalized, params
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_http_path(path_with_query: str) -> str:
    """Return only the normalized path component of a request target (query dropped)."""
    path, _query = _parse_request_path(path_with_query)
    return path
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_query(path_with_query: str) -> dict[str, list[str]]:
    """Return only the parsed query parameters of a request target."""
    _path, params = _parse_request_path(path_with_query)
    return params
|
||||||
|
|
||||||
|
|
||||||
|
def _query_first(query: dict[str, list[str]], key: str) -> str | None:
|
||||||
|
"""Return the first value for *key*, or None."""
|
||||||
|
values = query.get(key)
|
||||||
|
return values[0] if values else None
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_inbound_payload(raw: str) -> str | None:
|
||||||
|
"""Parse a client frame into text; return None for empty or unrecognized content."""
|
||||||
|
text = raw.strip()
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
|
if text.startswith("{"):
|
||||||
|
try:
|
||||||
|
data = json.loads(text)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return text
|
||||||
|
if isinstance(data, dict):
|
||||||
|
for key in ("content", "text", "message"):
|
||||||
|
value = data.get(key)
|
||||||
|
if isinstance(value, str) and value.strip():
|
||||||
|
return value
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool:
|
||||||
|
"""Return True if the token-issue HTTP request carries credentials matching ``token_issue_secret``."""
|
||||||
|
if not configured_secret:
|
||||||
|
return True
|
||||||
|
authorization = headers.get("Authorization") or headers.get("authorization")
|
||||||
|
if authorization and authorization.lower().startswith("bearer "):
|
||||||
|
supplied = authorization[7:].strip()
|
||||||
|
return hmac.compare_digest(supplied, configured_secret)
|
||||||
|
header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth")
|
||||||
|
if not header_token:
|
||||||
|
return False
|
||||||
|
return hmac.compare_digest(header_token.strip(), configured_secret)
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketChannel(BaseChannel):
    """Run a local WebSocket server; forward text/JSON messages to the message bus.

    Each accepted connection gets a fresh ``chat_id`` and is greeted with a
    ``ready`` event. Inbound frames are parsed with ``_parse_inbound_payload`` and
    handed to ``_handle_message``; outbound messages and stream deltas are sent
    as JSON frames to the connection registered under the message's ``chat_id``.
    An optional HTTP route on the same port issues short-lived, single-use
    connection tokens.
    """

    name = "websocket"
    display_name = "WebSocket"

    # Cap on outstanding issued tokens; issuance answers 429 once it is reached.
    _MAX_ISSUED_TOKENS = 10_000

    def __init__(self, config: Any, bus: MessageBus):
        """Accept either a ``WebSocketConfig`` or a plain dict validated into one."""
        if isinstance(config, dict):
            config = WebSocketConfig.model_validate(config)
        super().__init__(config, bus)
        self.config: WebSocketConfig = config
        # chat_id -> live connection object
        self._connections: dict[str, Any] = {}
        # issued token -> expiry on the time.monotonic() clock
        self._issued_tokens: dict[str, float] = {}
        self._stop_event: asyncio.Event | None = None
        self._server_task: asyncio.Task[None] | None = None

    @classmethod
    def default_config(cls) -> dict[str, Any]:
        """Return the default configuration dumped by alias."""
        return WebSocketConfig().model_dump(by_alias=True)

    def _expected_path(self) -> str:
        """Return the normalized WebSocket upgrade path from the configuration."""
        return _normalize_config_path(self.config.path)

    def _build_ssl_context(self) -> ssl.SSLContext | None:
        """Create the server TLS context, or return ``None`` when TLS is not configured.

        Raises:
            ValueError: if only one of ``ssl_certfile`` / ``ssl_keyfile`` is set.
        """
        cert = self.config.ssl_certfile.strip()
        key = self.config.ssl_keyfile.strip()
        if not cert and not key:
            return None
        if not cert or not key:
            raise ValueError(
                "websocket: ssl_certfile and ssl_keyfile must both be set for WSS, or both left empty"
            )
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ctx.minimum_version = ssl.TLSVersion.TLSv1_2
        ctx.load_cert_chain(certfile=cert, keyfile=key)
        return ctx

    @staticmethod
    def _tokens_equal(supplied: str, expected: str) -> bool:
        """Compare two tokens in constant time without failing on non-ASCII input.

        ``hmac.compare_digest`` raises ``TypeError`` when given ``str`` values
        containing non-ASCII characters. The supplied token comes from the query
        string, so both sides are encoded to bytes first ("surrogatepass" makes
        the encoding succeed for every possible ``str``).
        """
        return hmac.compare_digest(
            supplied.encode("utf-8", "surrogatepass"),
            expected.encode("utf-8", "surrogatepass"),
        )

    def _purge_expired_issued_tokens(self) -> None:
        """Drop every issued token whose expiry has passed."""
        now = time.monotonic()
        # Iterate over a snapshot because entries are removed during the loop.
        for token_key, expiry in list(self._issued_tokens.items()):
            if now > expiry:
                self._issued_tokens.pop(token_key, None)

    def _take_issued_token_if_valid(self, token_value: str | None) -> bool:
        """Validate and consume one issued token (single use per connection attempt).

        Uses single-step pop to minimize the window between lookup and removal;
        safe under asyncio's single-threaded cooperative model.
        """
        if not token_value:
            return False
        self._purge_expired_issued_tokens()
        expiry = self._issued_tokens.pop(token_value, None)
        if expiry is None:
            return False
        if time.monotonic() > expiry:
            return False
        return True

    def _handle_token_issue_http(self, connection: Any, request: Any) -> Any:
        """Serve the token-issue route: authenticate, then mint a single-use token.

        Returns a 401 response when a secret is configured and not matched, a 429
        JSON response when too many tokens are outstanding, otherwise a JSON body
        with ``token`` and ``expires_in``.
        """
        secret = self.config.token_issue_secret.strip()
        if secret:
            if not _issue_route_secret_matches(request.headers, secret):
                return connection.respond(401, "Unauthorized")
        else:
            logger.warning(
                "websocket: token_issue_path is set but token_issue_secret is empty; "
                "any client can obtain connection tokens — set token_issue_secret for production."
            )
        self._purge_expired_issued_tokens()
        if len(self._issued_tokens) >= self._MAX_ISSUED_TOKENS:
            logger.error(
                "websocket: too many outstanding issued tokens ({}), rejecting issuance",
                len(self._issued_tokens),
            )
            return _http_json_response({"error": "too many outstanding tokens"}, status=429)
        token_value = f"nbwt_{secrets.token_urlsafe(32)}"
        self._issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s)

        return _http_json_response(
            {"token": token_value, "expires_in": self.config.token_ttl_s}
        )

    def _authorize_websocket_handshake(self, connection: Any, query: dict[str, list[str]]) -> Any:
        """Decide whether a WebSocket upgrade may proceed.

        Returns ``None`` to accept the handshake or a 401 response to reject it.
        With a static ``token`` configured, either that token or a valid issued
        token is accepted. Without one, an issued token is required only when
        ``websocket_requires_token`` is true.
        """
        supplied = _query_first(query, "token")
        static_token = self.config.token.strip()

        if static_token:
            if supplied and self._tokens_equal(supplied, static_token):
                return None
            if supplied and self._take_issued_token_if_valid(supplied):
                return None
            return connection.respond(401, "Unauthorized")

        if self.config.websocket_requires_token:
            if supplied and self._take_issued_token_if_valid(supplied):
                return None
            return connection.respond(401, "Unauthorized")

        if supplied:
            # Token not required, but consume it anyway so it cannot be reused.
            self._take_issued_token_if_valid(supplied)
        return None

    async def start(self) -> None:
        """Start the server and block until ``stop()`` is called.

        Raises:
            ValueError: propagated from ``_build_ssl_context`` on a half-configured TLS setup.
        """
        self._running = True
        # Keep a local reference so the runner never needs to re-check the attribute.
        stop_event = asyncio.Event()
        self._stop_event = stop_event

        ssl_context = self._build_ssl_context()
        scheme = "wss" if ssl_context else "ws"

        async def process_request(
            connection: ServerConnection,
            request: WsRequest,
        ) -> Any:
            # One parse yields both the normalized path and the query parameters.
            got, query = _parse_request_path(request.path)
            if self.config.token_issue_path:
                issue_expected = _normalize_config_path(self.config.token_issue_path)
                if got == issue_expected:
                    return self._handle_token_issue_http(connection, request)

            expected_ws = self._expected_path()
            if got != expected_ws:
                return connection.respond(404, "Not Found")
            # Early reject before WebSocket upgrade to avoid unnecessary overhead;
            # _handle_message() performs a second check as defense-in-depth.
            # Strip and truncate exactly like _connection_loop() so the allow-list
            # is evaluated against the identity the connection will actually use.
            client_id = (_query_first(query, "client_id") or "").strip()
            if len(client_id) > 128:
                client_id = client_id[:128]
            if not self.is_allowed(client_id):
                return connection.respond(403, "Forbidden")
            return self._authorize_websocket_handshake(connection, query)

        async def handler(connection: ServerConnection) -> None:
            await self._connection_loop(connection)

        logger.info(
            "WebSocket server listening on {}://{}:{}{}",
            scheme,
            self.config.host,
            self.config.port,
            self.config.path,
        )
        if self.config.token_issue_path:
            logger.info(
                "WebSocket token issue route: {}://{}:{}{}",
                scheme,
                self.config.host,
                self.config.port,
                _normalize_config_path(self.config.token_issue_path),
            )

        async def runner() -> None:
            async with serve(
                handler,
                self.config.host,
                self.config.port,
                process_request=process_request,
                max_size=self.config.max_message_bytes,
                ping_interval=self.config.ping_interval_s,
                ping_timeout=self.config.ping_timeout_s,
                ssl=ssl_context,
            ):
                await stop_event.wait()

        self._server_task = asyncio.create_task(runner())
        await self._server_task

    async def _connection_loop(self, connection: Any) -> None:
        """Serve one accepted connection until it closes.

        Sends the ``ready`` event, registers the connection under a new
        ``chat_id``, forwards every parseable frame to ``_handle_message`` and
        always unregisters the connection on exit.
        """
        request = connection.request
        path_part = request.path if request else "/"
        _, query = _parse_request_path(path_part)
        client_id_raw = _query_first(query, "client_id")
        client_id = client_id_raw.strip() if client_id_raw else ""
        if not client_id:
            client_id = f"anon-{uuid.uuid4().hex[:12]}"
        elif len(client_id) > 128:
            logger.warning("websocket: client_id too long ({} chars), truncating", len(client_id))
            client_id = client_id[:128]

        chat_id = str(uuid.uuid4())

        try:
            await connection.send(
                json.dumps(
                    {
                        "event": "ready",
                        "chat_id": chat_id,
                        "client_id": client_id,
                    },
                    ensure_ascii=False,
                )
            )
            # Register only after ready is successfully sent to avoid out-of-order sends
            self._connections[chat_id] = connection

            async for raw in connection:
                if isinstance(raw, bytes):
                    try:
                        raw = raw.decode("utf-8")
                    except UnicodeDecodeError:
                        logger.warning("websocket: ignoring non-utf8 binary frame")
                        continue
                content = _parse_inbound_payload(raw)
                if content is None:
                    continue
                await self._handle_message(
                    sender_id=client_id,
                    chat_id=chat_id,
                    content=content,
                    metadata={"remote": getattr(connection, "remote_address", None)},
                )
        except Exception as e:
            logger.debug("websocket connection ended: {}", e)
        finally:
            self._connections.pop(chat_id, None)

    async def stop(self) -> None:
        """Signal the server to shut down, wait for it, and clear all state."""
        if not self._running:
            return
        self._running = False
        if self._stop_event:
            self._stop_event.set()
        if self._server_task:
            try:
                await self._server_task
            except Exception as e:
                logger.warning("websocket: server task error during shutdown: {}", e)
            self._server_task = None
        self._connections.clear()
        self._issued_tokens.clear()

    async def _safe_send(self, chat_id: str, raw: str, *, label: str = "") -> None:
        """Send a raw frame, cleaning up dead connections on ConnectionClosed.

        A missing connection is a silent no-op. Any other send error is logged
        and re-raised.
        """
        connection = self._connections.get(chat_id)
        if connection is None:
            return
        try:
            await connection.send(raw)
        except ConnectionClosed:
            self._connections.pop(chat_id, None)
            logger.warning("websocket{}connection gone for chat_id={}", label, chat_id)
        except Exception as e:
            logger.error("websocket{}send failed: {}", label, e)
            raise

    async def send(self, msg: OutboundMessage) -> None:
        """Send *msg* as a ``message`` event to the connection for ``msg.chat_id``.

        ``media`` and ``reply_to`` are included only when set. Logs a warning and
        returns if no connection is registered for the chat.
        """
        connection = self._connections.get(msg.chat_id)
        if connection is None:
            logger.warning("websocket: no active connection for chat_id={}", msg.chat_id)
            return
        payload: dict[str, Any] = {
            "event": "message",
            "text": msg.content,
        }
        if msg.media:
            payload["media"] = msg.media
        if msg.reply_to:
            payload["reply_to"] = msg.reply_to
        raw = json.dumps(payload, ensure_ascii=False)
        await self._safe_send(msg.chat_id, raw, label=" ")

    async def send_delta(
        self,
        chat_id: str,
        delta: str,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Send one streaming frame: a ``delta`` event, or ``stream_end`` when flagged.

        ``metadata["_stream_end"]`` selects the end-of-stream event and
        ``metadata["_stream_id"]``, when not ``None``, is forwarded as ``stream_id``.
        Does nothing if no connection is registered for *chat_id*.
        """
        if self._connections.get(chat_id) is None:
            return
        meta = metadata or {}
        if meta.get("_stream_end"):
            body: dict[str, Any] = {"event": "stream_end"}
        else:
            body = {
                "event": "delta",
                "text": delta,
            }
        if meta.get("_stream_id") is not None:
            body["stream_id"] = meta["_stream_id"]
        raw = json.dumps(body, ensure_ascii=False)
        await self._safe_send(chat_id, raw, label=" stream ")
|
||||||
@ -1,9 +1,13 @@
|
|||||||
"""WeCom (Enterprise WeChat) channel implementation using wecom_aibot_sdk."""
|
"""WeCom (Enterprise WeChat) channel implementation using wecom_aibot_sdk."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@ -17,6 +21,37 @@ from pydantic import Field
|
|||||||
|
|
||||||
WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
|
WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
|
||||||
|
|
||||||
|
# Upload safety limits (matching QQ channel defaults)
|
||||||
|
WECOM_UPLOAD_MAX_BYTES = 1024 * 1024 * 200 # 200MB
|
||||||
|
|
||||||
|
# Replace unsafe characters with "_", keep Chinese and common safe punctuation.
|
||||||
|
_SAFE_NAME_RE = re.compile(r"[^\w.\-()\[\]()【】\u4e00-\u9fff]+", re.UNICODE)
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_filename(name: str) -> str:
|
||||||
|
"""Sanitize filename to avoid traversal and problematic chars."""
|
||||||
|
name = (name or "").strip()
|
||||||
|
name = Path(name).name
|
||||||
|
name = _SAFE_NAME_RE.sub("_", name).strip("._ ")
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}
|
||||||
|
_VIDEO_EXTS = {".mp4", ".avi", ".mov"}
|
||||||
|
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg"}
|
||||||
|
|
||||||
|
|
||||||
|
def _guess_wecom_media_type(filename: str) -> str:
|
||||||
|
"""Classify file extension as WeCom media_type string."""
|
||||||
|
ext = Path(filename).suffix.lower()
|
||||||
|
if ext in _IMAGE_EXTS:
|
||||||
|
return "image"
|
||||||
|
if ext in _VIDEO_EXTS:
|
||||||
|
return "video"
|
||||||
|
if ext in _AUDIO_EXTS:
|
||||||
|
return "voice"
|
||||||
|
return "file"
|
||||||
|
|
||||||
class WecomConfig(Base):
|
class WecomConfig(Base):
|
||||||
"""WeCom (Enterprise WeChat) AI Bot channel configuration."""
|
"""WeCom (Enterprise WeChat) AI Bot channel configuration."""
|
||||||
|
|
||||||
@ -217,6 +252,7 @@ class WecomChannel(BaseChannel):
|
|||||||
chat_id = body.get("chatid", sender_id)
|
chat_id = body.get("chatid", sender_id)
|
||||||
|
|
||||||
content_parts = []
|
content_parts = []
|
||||||
|
media_paths: list[str] = []
|
||||||
|
|
||||||
if msg_type == "text":
|
if msg_type == "text":
|
||||||
text = body.get("text", {}).get("content", "")
|
text = body.get("text", {}).get("content", "")
|
||||||
@ -232,7 +268,8 @@ class WecomChannel(BaseChannel):
|
|||||||
file_path = await self._download_and_save_media(file_url, aes_key, "image")
|
file_path = await self._download_and_save_media(file_url, aes_key, "image")
|
||||||
if file_path:
|
if file_path:
|
||||||
filename = os.path.basename(file_path)
|
filename = os.path.basename(file_path)
|
||||||
content_parts.append(f"[image: {filename}]\n[Image: source: {file_path}]")
|
content_parts.append(f"[image: {filename}]")
|
||||||
|
media_paths.append(file_path)
|
||||||
else:
|
else:
|
||||||
content_parts.append("[image: download failed]")
|
content_parts.append("[image: download failed]")
|
||||||
else:
|
else:
|
||||||
@ -256,7 +293,8 @@ class WecomChannel(BaseChannel):
|
|||||||
if file_url and aes_key:
|
if file_url and aes_key:
|
||||||
file_path = await self._download_and_save_media(file_url, aes_key, "file", file_name)
|
file_path = await self._download_and_save_media(file_url, aes_key, "file", file_name)
|
||||||
if file_path:
|
if file_path:
|
||||||
content_parts.append(f"[file: {file_name}]\n[File: source: {file_path}]")
|
content_parts.append(f"[file: {file_name}]")
|
||||||
|
media_paths.append(file_path)
|
||||||
else:
|
else:
|
||||||
content_parts.append(f"[file: {file_name}: download failed]")
|
content_parts.append(f"[file: {file_name}: download failed]")
|
||||||
else:
|
else:
|
||||||
@ -286,12 +324,11 @@ class WecomChannel(BaseChannel):
|
|||||||
self._chat_frames[chat_id] = frame
|
self._chat_frames[chat_id] = frame
|
||||||
|
|
||||||
# Forward to message bus
|
# Forward to message bus
|
||||||
# Note: media paths are included in content for broader model compatibility
|
|
||||||
await self._handle_message(
|
await self._handle_message(
|
||||||
sender_id=sender_id,
|
sender_id=sender_id,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
content=content,
|
content=content,
|
||||||
media=None,
|
media=media_paths or None,
|
||||||
metadata={
|
metadata={
|
||||||
"message_id": msg_id,
|
"message_id": msg_id,
|
||||||
"msg_type": msg_type,
|
"msg_type": msg_type,
|
||||||
@ -322,13 +359,21 @@ class WecomChannel(BaseChannel):
|
|||||||
logger.warning("Failed to download media from WeCom")
|
logger.warning("Failed to download media from WeCom")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if len(data) > WECOM_UPLOAD_MAX_BYTES:
|
||||||
|
logger.warning(
|
||||||
|
"WeCom inbound media too large: {} bytes (max {})",
|
||||||
|
len(data),
|
||||||
|
WECOM_UPLOAD_MAX_BYTES,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
media_dir = get_media_dir("wecom")
|
media_dir = get_media_dir("wecom")
|
||||||
if not filename:
|
if not filename:
|
||||||
filename = fname or f"{media_type}_{hash(file_url) % 100000}"
|
filename = fname or f"{media_type}_{hash(file_url) % 100000}"
|
||||||
filename = os.path.basename(filename)
|
filename = _sanitize_filename(filename)
|
||||||
|
|
||||||
file_path = media_dir / filename
|
file_path = media_dir / filename
|
||||||
file_path.write_bytes(data)
|
await asyncio.to_thread(file_path.write_bytes, data)
|
||||||
logger.debug("Downloaded {} to {}", media_type, file_path)
|
logger.debug("Downloaded {} to {}", media_type, file_path)
|
||||||
return str(file_path)
|
return str(file_path)
|
||||||
|
|
||||||
@ -336,6 +381,100 @@ class WecomChannel(BaseChannel):
|
|||||||
logger.error("Error downloading media: {}", e)
|
logger.error("Error downloading media: {}", e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def _upload_media_ws(
    self, client: Any, file_path: str,
) -> "tuple[str, str] | tuple[None, None]":
    """Upload a local file to WeCom via WebSocket 3-step protocol (base64).

    Uses the WeCom WebSocket upload commands directly via
    ``client._ws_manager.send_reply()``:

    ``aibot_upload_media_init`` → upload_id
    ``aibot_upload_media_chunk`` × N (≤512 KB raw per chunk, base64)
    ``aibot_upload_media_finish`` → media_id

    Args:
        client: SDK client object exposing ``_ws_manager.send_reply``.
        file_path: Path of the local file to upload.

    Returns (media_id, media_type) on success, (None, None) on failure.
    Every failure (oversized file, I/O error, non-zero ``errcode``, missing
    ids) is logged and reported through the return value, not raised.
    """
    from wecom_aibot_sdk.utils import generate_req_id as _gen_req_id

    CHUNK_SIZE = 512 * 1024  # 512 KB raw (before base64)

    try:
        fname = os.path.basename(file_path)
        media_type = _guess_wecom_media_type(fname)

        # Read and hash in a worker thread so neither the disk read nor the MD5
        # pass over a potentially large file blocks the event loop.
        def _read_file() -> "tuple[bytes, str]":
            file_size = os.path.getsize(file_path)
            if file_size > WECOM_UPLOAD_MAX_BYTES:
                raise ValueError(
                    f"File too large: {file_size} bytes (max {WECOM_UPLOAD_MAX_BYTES})"
                )
            with open(file_path, "rb") as f:
                # Bounded read: the file may have grown since getsize().
                payload = f.read(WECOM_UPLOAD_MAX_BYTES + 1)
            if len(payload) > WECOM_UPLOAD_MAX_BYTES:
                raise ValueError(
                    f"File too large: more than {WECOM_UPLOAD_MAX_BYTES} bytes"
                )
            # MD5 is used for file integrity only, not cryptographic security
            return payload, hashlib.md5(payload).hexdigest()

        data, md5_hash = await asyncio.to_thread(_read_file)

        # Derive size and chunk count from the bytes actually read so that
        # total_size, total_chunks, the chunks and the MD5 always agree.
        file_size = len(data)
        n_chunks = (file_size + CHUNK_SIZE - 1) // CHUNK_SIZE
        view = memoryview(data)  # chunks are sliced lazily, without a second copy

        # Step 1: init
        req_id = _gen_req_id("upload_init")
        resp = await client._ws_manager.send_reply(req_id, {
            "type": media_type,
            "filename": fname,
            "total_size": file_size,
            "total_chunks": n_chunks,
            "md5": md5_hash,
        }, "aibot_upload_media_init")
        if resp.errcode != 0:
            logger.warning("WeCom upload init failed ({}): {}", resp.errcode, resp.errmsg)
            return None, None
        upload_id = resp.body.get("upload_id") if resp.body else None
        if not upload_id:
            logger.warning("WeCom upload init: no upload_id in response")
            return None, None

        # Step 2: send chunks
        for i in range(n_chunks):
            chunk = view[i * CHUNK_SIZE : (i + 1) * CHUNK_SIZE]
            req_id = _gen_req_id("upload_chunk")
            resp = await client._ws_manager.send_reply(req_id, {
                "upload_id": upload_id,
                "chunk_index": i,
                "base64_data": base64.b64encode(chunk).decode(),
            }, "aibot_upload_media_chunk")
            if resp.errcode != 0:
                logger.warning("WeCom upload chunk {} failed ({}): {}", i, resp.errcode, resp.errmsg)
                return None, None

        # Step 3: finish
        req_id = _gen_req_id("upload_finish")
        resp = await client._ws_manager.send_reply(req_id, {
            "upload_id": upload_id,
        }, "aibot_upload_media_finish")
        if resp.errcode != 0:
            logger.warning("WeCom upload finish failed ({}): {}", resp.errcode, resp.errmsg)
            return None, None

        media_id = resp.body.get("media_id") if resp.body else None
        if not media_id:
            logger.warning("WeCom upload finish: no media_id in response body={}", resp.body)
            return None, None

        # Log only a prefix of the media_id.
        suffix = "..." if len(media_id) > 16 else ""
        logger.debug("WeCom uploaded {} ({}) → media_id={}", fname, media_type, media_id[:16] + suffix)
        return media_id, media_type

    except ValueError as e:
        logger.warning("WeCom upload skipped for {}: {}", file_path, e)
        return None, None
    except Exception as e:
        logger.error("WeCom _upload_media_ws error for {}: {}", file_path, e)
        return None, None
|
||||||
|
|
||||||
async def send(self, msg: OutboundMessage) -> None:
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
"""Send a message through WeCom."""
|
"""Send a message through WeCom."""
|
||||||
if not self._client:
|
if not self._client:
|
||||||
@ -343,29 +482,59 @@ class WecomChannel(BaseChannel):
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
content = msg.content.strip()
|
content = (msg.content or "").strip()
|
||||||
if not content:
|
is_progress = bool(msg.metadata.get("_progress"))
|
||||||
return
|
|
||||||
|
|
||||||
# Get the stored frame for this chat
|
# Get the stored frame for this chat
|
||||||
frame = self._chat_frames.get(msg.chat_id)
|
frame = self._chat_frames.get(msg.chat_id)
|
||||||
if not frame:
|
|
||||||
logger.warning("No frame found for chat {}, cannot reply", msg.chat_id)
|
# Send media files via WebSocket upload
|
||||||
|
for file_path in msg.media or []:
|
||||||
|
if not os.path.isfile(file_path):
|
||||||
|
logger.warning("WeCom media file not found: {}", file_path)
|
||||||
|
continue
|
||||||
|
media_id, media_type = await self._upload_media_ws(self._client, file_path)
|
||||||
|
if media_id:
|
||||||
|
if frame:
|
||||||
|
await self._client.reply(frame, {
|
||||||
|
"msgtype": media_type,
|
||||||
|
media_type: {"media_id": media_id},
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
await self._client.send_message(msg.chat_id, {
|
||||||
|
"msgtype": media_type,
|
||||||
|
media_type: {"media_id": media_id},
|
||||||
|
})
|
||||||
|
logger.debug("WeCom sent {} → {}", media_type, msg.chat_id)
|
||||||
|
else:
|
||||||
|
content += f"\n[file upload failed: {os.path.basename(file_path)}]"
|
||||||
|
|
||||||
|
if not content:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Use streaming reply for better UX
|
if frame:
|
||||||
stream_id = self._generate_req_id("stream")
|
# Both progress and final messages must use reply_stream (cmd="aibot_respond_msg").
|
||||||
|
# The plain reply() uses cmd="reply" which does not support "text" msgtype
|
||||||
|
# and causes errcode=40008 from WeCom API.
|
||||||
|
stream_id = self._generate_req_id("stream")
|
||||||
|
await self._client.reply_stream(
|
||||||
|
frame,
|
||||||
|
stream_id,
|
||||||
|
content,
|
||||||
|
finish=not is_progress,
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
"WeCom {} sent to {}",
|
||||||
|
"progress" if is_progress else "message",
|
||||||
|
msg.chat_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# No frame (e.g. cron push): proactive send only supports markdown
|
||||||
|
await self._client.send_message(msg.chat_id, {
|
||||||
|
"msgtype": "markdown",
|
||||||
|
"markdown": {"content": content},
|
||||||
|
})
|
||||||
|
logger.info("WeCom proactive send to {}", msg.chat_id)
|
||||||
|
|
||||||
# Send as streaming message with finish=True
|
except Exception:
|
||||||
await self._client.reply_stream(
|
logger.exception("Error sending WeCom message to chat_id={}", msg.chat_id)
|
||||||
frame,
|
|
||||||
stream_id,
|
|
||||||
content,
|
|
||||||
finish=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug("WeCom message sent to {}", msg.chat_id)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Error sending WeCom message: {}", e)
|
|
||||||
raise
|
|
||||||
|
|||||||
@ -985,7 +985,43 @@ class WeixinChannel(BaseChannel):
|
|||||||
for media_path in (msg.media or []):
|
for media_path in (msg.media or []):
|
||||||
try:
|
try:
|
||||||
await self._send_media_file(msg.chat_id, media_path, ctx_token)
|
await self._send_media_file(msg.chat_id, media_path, ctx_token)
|
||||||
|
except (httpx.TimeoutException, httpx.TransportError) as net_err:
|
||||||
|
# Network/transport errors: do NOT fall back to text —
|
||||||
|
# the text send would also likely fail, and the outer
|
||||||
|
# except will re-raise so ChannelManager retries properly.
|
||||||
|
logger.error(
|
||||||
|
"Network error sending WeChat media {}: {}",
|
||||||
|
media_path,
|
||||||
|
net_err,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
except httpx.HTTPStatusError as http_err:
|
||||||
|
status_code = (
|
||||||
|
http_err.response.status_code
|
||||||
|
if http_err.response is not None
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
if status_code >= 500:
|
||||||
|
# Server-side / retryable HTTP error — same as network.
|
||||||
|
logger.error(
|
||||||
|
"Server error ({} {}) sending WeChat media {}: {}",
|
||||||
|
status_code,
|
||||||
|
http_err.response.reason_phrase
|
||||||
|
if http_err.response is not None
|
||||||
|
else "",
|
||||||
|
media_path,
|
||||||
|
http_err,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
# 4xx client errors are NOT retryable — fall back to text.
|
||||||
|
filename = Path(media_path).name
|
||||||
|
logger.error("Failed to send WeChat media {}: {}", media_path, http_err)
|
||||||
|
await self._send_text(
|
||||||
|
msg.chat_id, f"[Failed to send: {filename}]", ctx_token,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
# Non-network errors (format, file-not-found, etc.):
|
||||||
|
# notify the user via text fallback.
|
||||||
filename = Path(media_path).name
|
filename = Path(media_path).name
|
||||||
logger.error("Failed to send WeChat media {}: {}", media_path, e)
|
logger.error("Failed to send WeChat media {}: {}", media_path, e)
|
||||||
# Notify user about failure via text
|
# Notify user about failure via text
|
||||||
|
|||||||
@ -590,6 +590,9 @@ def serve(
|
|||||||
mcp_servers=runtime_config.tools.mcp_servers,
|
mcp_servers=runtime_config.tools.mcp_servers,
|
||||||
channels_config=runtime_config.channels,
|
channels_config=runtime_config.channels,
|
||||||
timezone=runtime_config.agents.defaults.timezone,
|
timezone=runtime_config.agents.defaults.timezone,
|
||||||
|
unified_session=runtime_config.agents.defaults.unified_session,
|
||||||
|
disabled_skills=runtime_config.agents.defaults.disabled_skills,
|
||||||
|
session_ttl_minutes=runtime_config.agents.defaults.session_ttl_minutes,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_name = runtime_config.agents.defaults.model
|
model_name = runtime_config.agents.defaults.model
|
||||||
@ -681,6 +684,9 @@ def gateway(
|
|||||||
mcp_servers=config.tools.mcp_servers,
|
mcp_servers=config.tools.mcp_servers,
|
||||||
channels_config=config.channels,
|
channels_config=config.channels,
|
||||||
timezone=config.agents.defaults.timezone,
|
timezone=config.agents.defaults.timezone,
|
||||||
|
unified_session=config.agents.defaults.unified_session,
|
||||||
|
disabled_skills=config.agents.defaults.disabled_skills,
|
||||||
|
session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set cron callback (needs agent)
|
# Set cron callback (needs agent)
|
||||||
@ -815,6 +821,48 @@ def gateway(
|
|||||||
|
|
||||||
console.print(f"[green]✓[/green] Heartbeat: every {hb_cfg.interval_s}s")
|
console.print(f"[green]✓[/green] Heartbeat: every {hb_cfg.interval_s}s")
|
||||||
|
|
||||||
|
async def _health_server(host: str, health_port: int):
    """Lightweight HTTP health endpoint on the gateway port.

    Answers ``GET /health`` (a query string is ignored) with ``{"status": "ok"}``
    and every other request with a plain-text 404. Each connection serves one
    request and is then closed. Runs until cancelled.

    Args:
        host: Interface to bind.
        health_port: TCP port to listen on.
    """
    import json as _json

    async def handle(reader, writer):
        try:
            try:
                data = await asyncio.wait_for(reader.read(4096), timeout=5)
            except (asyncio.TimeoutError, ConnectionError):
                return

            # Only the request line matters; headers and body are ignored.
            request_line = data.split(b"\r\n", 1)[0].decode("utf-8", errors="replace")
            method, path = "", ""
            parts = request_line.split(" ")
            if len(parts) >= 2:
                method, path = parts[0], parts[1]

            # Compare the path without its query string so probes such as
            # "/health?ts=1" are still recognised.
            if method == "GET" and path.split("?", 1)[0] == "/health":
                status_line = "200 OK"
                content_type = "application/json"
                body = _json.dumps({"status": "ok"}).encode()
            else:
                status_line = "404 Not Found"
                content_type = "text/plain"
                body = b"Not Found"

            head = (
                f"HTTP/1.0 {status_line}\r\n"
                f"Content-Type: {content_type}\r\n"
                # Content-Length counts bytes of the encoded body.
                f"Content-Length: {len(body)}\r\n"
                "\r\n"
            )
            writer.write(head.encode() + body)
            await writer.drain()
        except ConnectionError:
            # The peer disconnected mid-response; nothing left to send.
            pass
        finally:
            # Always release the socket, including on early return or error.
            writer.close()
            try:
                await writer.wait_closed()
            except OSError:
                pass

    server = await asyncio.start_server(handle, host, health_port)
    console.print(f"[green]✓[/green] Health endpoint: http://{host}:{health_port}/health")
    async with server:
        await server.serve_forever()
|
||||||
# Register Dream system job (always-on, idempotent on restart)
|
# Register Dream system job (always-on, idempotent on restart)
|
||||||
dream_cfg = config.agents.defaults.dream
|
dream_cfg = config.agents.defaults.dream
|
||||||
if dream_cfg.model_override:
|
if dream_cfg.model_override:
|
||||||
@ -837,6 +885,7 @@ def gateway(
|
|||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
agent.run(),
|
agent.run(),
|
||||||
channels.start_all(),
|
channels.start_all(),
|
||||||
|
_health_server(config.gateway.host, port),
|
||||||
)
|
)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
console.print("\nShutting down...")
|
console.print("\nShutting down...")
|
||||||
@ -912,6 +961,9 @@ def agent(
|
|||||||
mcp_servers=config.tools.mcp_servers,
|
mcp_servers=config.tools.mcp_servers,
|
||||||
channels_config=config.channels,
|
channels_config=config.channels,
|
||||||
timezone=config.agents.defaults.timezone,
|
timezone=config.agents.defaults.timezone,
|
||||||
|
unified_session=config.agents.defaults.unified_session,
|
||||||
|
disabled_skills=config.agents.defaults.disabled_skills,
|
||||||
|
session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
|
||||||
)
|
)
|
||||||
restart_notice = consume_restart_notice_from_env()
|
restart_notice = consume_restart_notice_from_env()
|
||||||
if restart_notice and should_show_cli_restart_notice(restart_notice, session_id):
|
if restart_notice and should_show_cli_restart_notice(restart_notice, session_id):
|
||||||
@ -1116,7 +1168,7 @@ def channels_status(
|
|||||||
|
|
||||||
table = Table(title="Channel Status")
|
table = Table(title="Channel Status")
|
||||||
table.add_column("Channel", style="cyan")
|
table.add_column("Channel", style="cyan")
|
||||||
table.add_column("Enabled", style="green")
|
table.add_column("Enabled")
|
||||||
|
|
||||||
for name, cls in sorted(discover_all().items()):
|
for name, cls in sorted(discover_all().items()):
|
||||||
section = getattr(config.channels, name, None)
|
section = getattr(config.channels, name, None)
|
||||||
@ -1251,7 +1303,7 @@ def plugins_list():
|
|||||||
table = Table(title="Channel Plugins")
|
table = Table(title="Channel Plugins")
|
||||||
table.add_column("Name", style="cyan")
|
table.add_column("Name", style="cyan")
|
||||||
table.add_column("Source", style="magenta")
|
table.add_column("Source", style="magenta")
|
||||||
table.add_column("Enabled", style="green")
|
table.add_column("Enabled")
|
||||||
|
|
||||||
for name in sorted(all_channels):
|
for name in sorted(all_channels):
|
||||||
cls = all_channels[name]
|
cls = all_channels[name]
|
||||||
|
|||||||
@ -76,6 +76,14 @@ class AgentDefaults(Base):
|
|||||||
provider_retry_mode: Literal["standard", "persistent"] = "standard"
|
provider_retry_mode: Literal["standard", "persistent"] = "standard"
|
||||||
reasoning_effort: str | None = None # low / medium / high / adaptive - enables LLM thinking mode
|
reasoning_effort: str | None = None # low / medium / high / adaptive - enables LLM thinking mode
|
||||||
timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"
|
timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"
|
||||||
|
unified_session: bool = False # Share one session across all channels (single-user multi-device)
|
||||||
|
disabled_skills: list[str] = Field(default_factory=list) # Skill names to exclude from loading (e.g. ["summarize", "skill-creator"])
|
||||||
|
session_ttl_minutes: int = Field(
|
||||||
|
default=0,
|
||||||
|
ge=0,
|
||||||
|
validation_alias=AliasChoices("idleCompactAfterMinutes", "sessionTtlMinutes"),
|
||||||
|
serialization_alias="idleCompactAfterMinutes",
|
||||||
|
) # Auto-compact idle threshold in minutes (0 = disabled)
|
||||||
dream: DreamConfig = Field(default_factory=DreamConfig)
|
dream: DreamConfig = Field(default_factory=DreamConfig)
|
||||||
|
|
||||||
|
|
||||||
@ -144,7 +152,7 @@ class ApiConfig(Base):
|
|||||||
class GatewayConfig(Base):
|
class GatewayConfig(Base):
|
||||||
"""Gateway/server configuration."""
|
"""Gateway/server configuration."""
|
||||||
|
|
||||||
host: str = "0.0.0.0"
|
host: str = "127.0.0.1" # Safer default: local-only bind.
|
||||||
port: int = 18790
|
port: int = 18790
|
||||||
heartbeat: HeartbeatConfig = Field(default_factory=HeartbeatConfig)
|
heartbeat: HeartbeatConfig = Field(default_factory=HeartbeatConfig)
|
||||||
|
|
||||||
@ -152,7 +160,7 @@ class GatewayConfig(Base):
|
|||||||
class WebSearchConfig(Base):
|
class WebSearchConfig(Base):
|
||||||
"""Web search tool configuration."""
|
"""Web search tool configuration."""
|
||||||
|
|
||||||
provider: str = "duckduckgo" # brave, tavily, duckduckgo, searxng, jina
|
provider: str = "duckduckgo" # brave, tavily, duckduckgo, searxng, jina, kagi
|
||||||
api_key: str = ""
|
api_key: str = ""
|
||||||
base_url: str = "" # SearXNG base URL
|
base_url: str = "" # SearXNG base URL
|
||||||
max_results: int = 5
|
max_results: int = 5
|
||||||
@ -176,6 +184,7 @@ class ExecToolConfig(Base):
|
|||||||
timeout: int = 60
|
timeout: int = 60
|
||||||
path_append: str = ""
|
path_append: str = ""
|
||||||
sandbox: str = "" # sandbox backend: "" (none) or "bwrap"
|
sandbox: str = "" # sandbox backend: "" (none) or "bwrap"
|
||||||
|
allowed_env_keys: list[str] = Field(default_factory=list) # Env var names to pass through to subprocess (e.g. ["GOPATH", "JAVA_HOME"])
|
||||||
|
|
||||||
class MCPServerConfig(Base):
|
class MCPServerConfig(Base):
|
||||||
"""MCP server connection configuration (stdio or HTTP)."""
|
"""MCP server connection configuration (stdio or HTTP)."""
|
||||||
|
|||||||
@ -4,10 +4,12 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
from dataclasses import asdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Coroutine, Literal
|
from typing import Any, Callable, Coroutine, Literal
|
||||||
|
|
||||||
|
from filelock import FileLock
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.cron.types import CronJob, CronJobState, CronPayload, CronRunRecord, CronSchedule, CronStore
|
from nanobot.cron.types import CronJob, CronJobState, CronPayload, CronRunRecord, CronSchedule, CronStore
|
||||||
@ -69,28 +71,26 @@ class CronService:
|
|||||||
self,
|
self,
|
||||||
store_path: Path,
|
store_path: Path,
|
||||||
on_job: Callable[[CronJob], Coroutine[Any, Any, str | None]] | None = None,
|
on_job: Callable[[CronJob], Coroutine[Any, Any, str | None]] | None = None,
|
||||||
|
max_sleep_ms: int = 300_000, # 5 minutes
|
||||||
):
|
):
|
||||||
self.store_path = store_path
|
self.store_path = store_path
|
||||||
|
self._action_path = store_path.parent / "action.jsonl"
|
||||||
|
self._lock = FileLock(str(self._action_path.parent) + ".lock")
|
||||||
self.on_job = on_job
|
self.on_job = on_job
|
||||||
self._store: CronStore | None = None
|
self._store: CronStore | None = None
|
||||||
self._last_mtime: float = 0.0
|
|
||||||
self._timer_task: asyncio.Task | None = None
|
self._timer_task: asyncio.Task | None = None
|
||||||
self._running = False
|
self._running = False
|
||||||
|
self._timer_active = False
|
||||||
|
self.max_sleep_ms = max_sleep_ms
|
||||||
|
|
||||||
def _load_store(self) -> CronStore:
|
def _load_jobs(self) -> tuple[list[CronJob], int]:
|
||||||
"""Load jobs from disk. Reloads automatically if file was modified externally."""
|
jobs = []
|
||||||
if self._store and self.store_path.exists():
|
version = 1
|
||||||
mtime = self.store_path.stat().st_mtime
|
|
||||||
if mtime != self._last_mtime:
|
|
||||||
logger.info("Cron: jobs.json modified externally, reloading")
|
|
||||||
self._store = None
|
|
||||||
if self._store:
|
|
||||||
return self._store
|
|
||||||
|
|
||||||
if self.store_path.exists():
|
if self.store_path.exists():
|
||||||
try:
|
try:
|
||||||
data = json.loads(self.store_path.read_text(encoding="utf-8"))
|
data = json.loads(self.store_path.read_text(encoding="utf-8"))
|
||||||
jobs = []
|
jobs = []
|
||||||
|
version = data.get("version", 1)
|
||||||
for j in data.get("jobs", []):
|
for j in data.get("jobs", []):
|
||||||
jobs.append(CronJob(
|
jobs.append(CronJob(
|
||||||
id=j["id"],
|
id=j["id"],
|
||||||
@ -129,13 +129,57 @@ class CronService:
|
|||||||
updated_at_ms=j.get("updatedAtMs", 0),
|
updated_at_ms=j.get("updatedAtMs", 0),
|
||||||
delete_after_run=j.get("deleteAfterRun", False),
|
delete_after_run=j.get("deleteAfterRun", False),
|
||||||
))
|
))
|
||||||
self._store = CronStore(jobs=jobs)
|
|
||||||
self._last_mtime = self.store_path.stat().st_mtime
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Failed to load cron store: {}", e)
|
logger.warning("Failed to load cron store: {}", e)
|
||||||
self._store = CronStore()
|
return jobs, version
|
||||||
else:
|
|
||||||
self._store = CronStore()
|
def _merge_action(self):
|
||||||
|
if not self._action_path.exists():
|
||||||
|
return
|
||||||
|
|
||||||
|
jobs_map = {j.id: j for j in self._store.jobs}
|
||||||
|
def _update(params: dict):
|
||||||
|
j = CronJob.from_dict(params)
|
||||||
|
jobs_map[j.id] = j
|
||||||
|
|
||||||
|
def _del(params: dict):
|
||||||
|
if job_id := params.get("job_id"):
|
||||||
|
jobs_map.pop(job_id)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
with open(self._action_path, "r", encoding="utf-8") as f:
|
||||||
|
changed = False
|
||||||
|
for line in f:
|
||||||
|
try:
|
||||||
|
line = line.strip()
|
||||||
|
action = json.loads(line)
|
||||||
|
if "action" not in action:
|
||||||
|
continue
|
||||||
|
if action["action"] == "del":
|
||||||
|
_del(action.get("params", {}))
|
||||||
|
else:
|
||||||
|
_update(action.get("params", {}))
|
||||||
|
changed = True
|
||||||
|
except Exception as exp:
|
||||||
|
logger.debug(f"load action line error: {exp}")
|
||||||
|
continue
|
||||||
|
self._store.jobs = list(jobs_map.values())
|
||||||
|
if self._running and changed:
|
||||||
|
self._action_path.write_text("", encoding="utf-8")
|
||||||
|
self._save_store()
|
||||||
|
return
|
||||||
|
|
||||||
|
def _load_store(self) -> CronStore:
|
||||||
|
"""Load jobs from disk. Reloads automatically if file was modified externally.
|
||||||
|
- Reload every time because it needs to merge operations on the jobs object from other instances.
|
||||||
|
- During _on_timer execution, return the existing store to prevent concurrent
|
||||||
|
_load_store calls (e.g. from list_jobs polling) from replacing it mid-execution.
|
||||||
|
"""
|
||||||
|
if self._timer_active and self._store:
|
||||||
|
return self._store
|
||||||
|
jobs, version = self._load_jobs()
|
||||||
|
self._store = CronStore(version=version, jobs=jobs)
|
||||||
|
self._merge_action()
|
||||||
|
|
||||||
return self._store
|
return self._store
|
||||||
|
|
||||||
@ -230,11 +274,14 @@ class CronService:
|
|||||||
if self._timer_task:
|
if self._timer_task:
|
||||||
self._timer_task.cancel()
|
self._timer_task.cancel()
|
||||||
|
|
||||||
next_wake = self._get_next_wake_ms()
|
if not self._running:
|
||||||
if not next_wake or not self._running:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
delay_ms = max(0, next_wake - _now_ms())
|
next_wake = self._get_next_wake_ms()
|
||||||
|
if next_wake is None:
|
||||||
|
delay_ms = self.max_sleep_ms
|
||||||
|
else:
|
||||||
|
delay_ms = min(self.max_sleep_ms, max(0, next_wake - _now_ms()))
|
||||||
delay_s = delay_ms / 1000
|
delay_s = delay_ms / 1000
|
||||||
|
|
||||||
async def tick():
|
async def tick():
|
||||||
@ -248,18 +295,23 @@ class CronService:
|
|||||||
"""Handle timer tick - run due jobs."""
|
"""Handle timer tick - run due jobs."""
|
||||||
self._load_store()
|
self._load_store()
|
||||||
if not self._store:
|
if not self._store:
|
||||||
|
self._arm_timer()
|
||||||
return
|
return
|
||||||
|
|
||||||
now = _now_ms()
|
self._timer_active = True
|
||||||
due_jobs = [
|
try:
|
||||||
j for j in self._store.jobs
|
now = _now_ms()
|
||||||
if j.enabled and j.state.next_run_at_ms and now >= j.state.next_run_at_ms
|
due_jobs = [
|
||||||
]
|
j for j in self._store.jobs
|
||||||
|
if j.enabled and j.state.next_run_at_ms and now >= j.state.next_run_at_ms
|
||||||
|
]
|
||||||
|
|
||||||
for job in due_jobs:
|
for job in due_jobs:
|
||||||
await self._execute_job(job)
|
await self._execute_job(job)
|
||||||
|
|
||||||
self._save_store()
|
self._save_store()
|
||||||
|
finally:
|
||||||
|
self._timer_active = False
|
||||||
self._arm_timer()
|
self._arm_timer()
|
||||||
|
|
||||||
async def _execute_job(self, job: CronJob) -> None:
|
async def _execute_job(self, job: CronJob) -> None:
|
||||||
@ -303,6 +355,13 @@ class CronService:
|
|||||||
# Compute next run
|
# Compute next run
|
||||||
job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
|
job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
|
||||||
|
|
||||||
|
def _append_action(self, action: Literal["add", "del", "update"], params: dict):
|
||||||
|
self.store_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with self._lock:
|
||||||
|
with open(self._action_path, "a", encoding="utf-8") as f:
|
||||||
|
f.write(json.dumps({"action": action, "params": params}, ensure_ascii=False) + "\n")
|
||||||
|
|
||||||
|
|
||||||
# ========== Public API ==========
|
# ========== Public API ==========
|
||||||
|
|
||||||
def list_jobs(self, include_disabled: bool = False) -> list[CronJob]:
|
def list_jobs(self, include_disabled: bool = False) -> list[CronJob]:
|
||||||
@ -322,7 +381,6 @@ class CronService:
|
|||||||
delete_after_run: bool = False,
|
delete_after_run: bool = False,
|
||||||
) -> CronJob:
|
) -> CronJob:
|
||||||
"""Add a new job."""
|
"""Add a new job."""
|
||||||
store = self._load_store()
|
|
||||||
_validate_schedule_for_add(schedule)
|
_validate_schedule_for_add(schedule)
|
||||||
now = _now_ms()
|
now = _now_ms()
|
||||||
|
|
||||||
@ -343,10 +401,13 @@ class CronService:
|
|||||||
updated_at_ms=now,
|
updated_at_ms=now,
|
||||||
delete_after_run=delete_after_run,
|
delete_after_run=delete_after_run,
|
||||||
)
|
)
|
||||||
|
if self._running:
|
||||||
store.jobs.append(job)
|
store = self._load_store()
|
||||||
self._save_store()
|
store.jobs.append(job)
|
||||||
self._arm_timer()
|
self._save_store()
|
||||||
|
self._arm_timer()
|
||||||
|
else:
|
||||||
|
self._append_action("add", asdict(job))
|
||||||
|
|
||||||
logger.info("Cron: added job '{}' ({})", name, job.id)
|
logger.info("Cron: added job '{}' ({})", name, job.id)
|
||||||
return job
|
return job
|
||||||
@ -380,8 +441,11 @@ class CronService:
|
|||||||
removed = len(store.jobs) < before
|
removed = len(store.jobs) < before
|
||||||
|
|
||||||
if removed:
|
if removed:
|
||||||
self._save_store()
|
if self._running:
|
||||||
self._arm_timer()
|
self._save_store()
|
||||||
|
self._arm_timer()
|
||||||
|
else:
|
||||||
|
self._append_action("del", {"job_id": job_id})
|
||||||
logger.info("Cron: removed job {}", job_id)
|
logger.info("Cron: removed job {}", job_id)
|
||||||
return "removed"
|
return "removed"
|
||||||
|
|
||||||
@ -398,23 +462,85 @@ class CronService:
|
|||||||
job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
|
job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
|
||||||
else:
|
else:
|
||||||
job.state.next_run_at_ms = None
|
job.state.next_run_at_ms = None
|
||||||
self._save_store()
|
if self._running:
|
||||||
self._arm_timer()
|
self._save_store()
|
||||||
|
self._arm_timer()
|
||||||
|
else:
|
||||||
|
self._append_action("update", asdict(job))
|
||||||
return job
|
return job
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def run_job(self, job_id: str, force: bool = False) -> bool:
|
def update_job(
|
||||||
"""Manually run a job."""
|
self,
|
||||||
|
job_id: str,
|
||||||
|
*,
|
||||||
|
name: str | None = None,
|
||||||
|
schedule: CronSchedule | None = None,
|
||||||
|
message: str | None = None,
|
||||||
|
deliver: bool | None = None,
|
||||||
|
channel: str | None = ...,
|
||||||
|
to: str | None = ...,
|
||||||
|
delete_after_run: bool | None = None,
|
||||||
|
) -> CronJob | Literal["not_found", "protected"]:
|
||||||
|
"""Update mutable fields of an existing job. System jobs cannot be updated.
|
||||||
|
|
||||||
|
For ``channel`` and ``to``, pass an explicit value (including ``None``)
|
||||||
|
to update; omit (sentinel ``...``) to leave unchanged.
|
||||||
|
"""
|
||||||
store = self._load_store()
|
store = self._load_store()
|
||||||
for job in store.jobs:
|
job = next((j for j in store.jobs if j.id == job_id), None)
|
||||||
if job.id == job_id:
|
if job is None:
|
||||||
if not force and not job.enabled:
|
return "not_found"
|
||||||
return False
|
if job.payload.kind == "system_event":
|
||||||
await self._execute_job(job)
|
return "protected"
|
||||||
self._save_store()
|
|
||||||
|
if schedule is not None:
|
||||||
|
_validate_schedule_for_add(schedule)
|
||||||
|
job.schedule = schedule
|
||||||
|
if name is not None:
|
||||||
|
job.name = name
|
||||||
|
if message is not None:
|
||||||
|
job.payload.message = message
|
||||||
|
if deliver is not None:
|
||||||
|
job.payload.deliver = deliver
|
||||||
|
if channel is not ...:
|
||||||
|
job.payload.channel = channel
|
||||||
|
if to is not ...:
|
||||||
|
job.payload.to = to
|
||||||
|
if delete_after_run is not None:
|
||||||
|
job.delete_after_run = delete_after_run
|
||||||
|
|
||||||
|
job.updated_at_ms = _now_ms()
|
||||||
|
if job.enabled:
|
||||||
|
job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
|
||||||
|
|
||||||
|
if self._running:
|
||||||
|
self._save_store()
|
||||||
|
self._arm_timer()
|
||||||
|
else:
|
||||||
|
self._append_action("update", asdict(job))
|
||||||
|
|
||||||
|
logger.info("Cron: updated job '{}' ({})", job.name, job.id)
|
||||||
|
return job
|
||||||
|
|
||||||
|
async def run_job(self, job_id: str, force: bool = False) -> bool:
|
||||||
|
"""Manually run a job without disturbing the service's running state."""
|
||||||
|
was_running = self._running
|
||||||
|
self._running = True
|
||||||
|
try:
|
||||||
|
store = self._load_store()
|
||||||
|
for job in store.jobs:
|
||||||
|
if job.id == job_id:
|
||||||
|
if not force and not job.enabled:
|
||||||
|
return False
|
||||||
|
await self._execute_job(job)
|
||||||
|
self._save_store()
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
self._running = was_running
|
||||||
|
if was_running:
|
||||||
self._arm_timer()
|
self._arm_timer()
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_job(self, job_id: str) -> CronJob | None:
|
def get_job(self, job_id: str) -> CronJob | None:
|
||||||
"""Get a job by ID."""
|
"""Get a job by ID."""
|
||||||
|
|||||||
@ -61,6 +61,18 @@ class CronJob:
|
|||||||
updated_at_ms: int = 0
|
updated_at_ms: int = 0
|
||||||
delete_after_run: bool = False
|
delete_after_run: bool = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, kwargs: dict):
|
||||||
|
state_kwargs = dict(kwargs.get("state", {}))
|
||||||
|
state_kwargs["run_history"] = [
|
||||||
|
record if isinstance(record, CronRunRecord) else CronRunRecord(**record)
|
||||||
|
for record in state_kwargs.get("run_history", [])
|
||||||
|
]
|
||||||
|
kwargs["schedule"] = CronSchedule(**kwargs.get("schedule", {"kind": "every"}))
|
||||||
|
kwargs["payload"] = CronPayload(**kwargs.get("payload", {}))
|
||||||
|
kwargs["state"] = CronJobState(**state_kwargs)
|
||||||
|
return cls(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CronStore:
|
class CronStore:
|
||||||
|
|||||||
@ -81,6 +81,9 @@ class Nanobot:
|
|||||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||||
mcp_servers=config.tools.mcp_servers,
|
mcp_servers=config.tools.mcp_servers,
|
||||||
timezone=defaults.timezone,
|
timezone=defaults.timezone,
|
||||||
|
unified_session=defaults.unified_session,
|
||||||
|
disabled_skills=defaults.disabled_skills,
|
||||||
|
session_ttl_minutes=defaults.session_ttl_minutes,
|
||||||
)
|
)
|
||||||
return cls(loop)
|
return cls(loop)
|
||||||
|
|
||||||
|
|||||||
@ -353,6 +353,64 @@ class LLMProvider(ABC):
|
|||||||
# Unknown 429 defaults to WAIT+retry.
|
# Unknown 429 defaults to WAIT+retry.
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _enforce_role_alternation(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||||
|
"""Merge consecutive same-role messages and drop trailing assistant messages.
|
||||||
|
|
||||||
|
Some providers (OpenAI-compat, Azure, vLLM, Ollama, etc.) reject requests
|
||||||
|
where the last message is 'assistant' (prefill not supported) or two
|
||||||
|
consecutive non-system messages share the same role.
|
||||||
|
"""
|
||||||
|
if not messages:
|
||||||
|
return messages
|
||||||
|
|
||||||
|
merged: list[dict[str, Any]] = []
|
||||||
|
for msg in messages:
|
||||||
|
role = msg.get("role")
|
||||||
|
if (
|
||||||
|
merged
|
||||||
|
and role != "system"
|
||||||
|
and role not in ("tool",)
|
||||||
|
and merged[-1].get("role") == role
|
||||||
|
and role in ("user", "assistant")
|
||||||
|
):
|
||||||
|
prev = merged[-1]
|
||||||
|
if role == "assistant":
|
||||||
|
prev_has_tools = bool(prev.get("tool_calls"))
|
||||||
|
curr_has_tools = bool(msg.get("tool_calls"))
|
||||||
|
if curr_has_tools:
|
||||||
|
merged[-1] = dict(msg)
|
||||||
|
continue
|
||||||
|
if prev_has_tools:
|
||||||
|
continue
|
||||||
|
prev_content = prev.get("content") or ""
|
||||||
|
curr_content = msg.get("content") or ""
|
||||||
|
if isinstance(prev_content, str) and isinstance(curr_content, str):
|
||||||
|
prev["content"] = (prev_content + "\n\n" + curr_content).strip()
|
||||||
|
else:
|
||||||
|
merged[-1] = dict(msg)
|
||||||
|
else:
|
||||||
|
merged.append(dict(msg))
|
||||||
|
|
||||||
|
last_popped = None
|
||||||
|
while merged and merged[-1].get("role") == "assistant":
|
||||||
|
last_popped = merged.pop()
|
||||||
|
|
||||||
|
# If removing trailing assistant messages left only system messages,
|
||||||
|
# the request would be invalid for most providers (e.g. Zhipu/GLM
|
||||||
|
# error 1214). Recover by converting the last popped assistant
|
||||||
|
# message to a user message so the LLM can still see the content.
|
||||||
|
if (
|
||||||
|
merged
|
||||||
|
and last_popped is not None
|
||||||
|
and not any(m.get("role") in ("user", "tool") for m in merged)
|
||||||
|
):
|
||||||
|
recovered = dict(last_popped)
|
||||||
|
recovered["role"] = "user"
|
||||||
|
merged.append(recovered)
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
|
def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
|
||||||
"""Replace image_url blocks with text placeholder. Returns None if no images found."""
|
"""Replace image_url blocks with text placeholder. Returns None if no images found."""
|
||||||
@ -375,6 +433,26 @@ class LLMProvider(ABC):
|
|||||||
result.append(msg)
|
result.append(msg)
|
||||||
return result if found else None
|
return result if found else None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _strip_image_content_inplace(messages: list[dict[str, Any]]) -> bool:
|
||||||
|
"""Replace image_url blocks with text placeholder *in-place*.
|
||||||
|
|
||||||
|
Mutates the content lists of the original message dicts so that
|
||||||
|
callers holding references to those dicts also see the stripped
|
||||||
|
version.
|
||||||
|
"""
|
||||||
|
found = False
|
||||||
|
for msg in messages:
|
||||||
|
content = msg.get("content")
|
||||||
|
if isinstance(content, list):
|
||||||
|
for i, b in enumerate(content):
|
||||||
|
if isinstance(b, dict) and b.get("type") == "image_url":
|
||||||
|
path = (b.get("_meta") or {}).get("path", "")
|
||||||
|
placeholder = image_placeholder_text(path, empty="[image omitted]")
|
||||||
|
content[i] = {"type": "text", "text": placeholder}
|
||||||
|
found = True
|
||||||
|
return found
|
||||||
|
|
||||||
async def _safe_chat(self, **kwargs: Any) -> LLMResponse:
|
async def _safe_chat(self, **kwargs: Any) -> LLMResponse:
|
||||||
"""Call chat() and convert unexpected exceptions to error responses."""
|
"""Call chat() and convert unexpected exceptions to error responses."""
|
||||||
try:
|
try:
|
||||||
@ -626,7 +704,12 @@ class LLMProvider(ABC):
|
|||||||
)
|
)
|
||||||
retry_kw = dict(kw)
|
retry_kw = dict(kw)
|
||||||
retry_kw["messages"] = stripped
|
retry_kw["messages"] = stripped
|
||||||
return await call(**retry_kw)
|
result = await call(**retry_kw)
|
||||||
|
# Permanently strip images from the original messages so
|
||||||
|
# subsequent iterations do not repeat the error-retry cycle.
|
||||||
|
if result.finish_reason != "error":
|
||||||
|
self._strip_image_content_inplace(original_messages)
|
||||||
|
return result
|
||||||
return response
|
return response
|
||||||
|
|
||||||
if persistent and identical_error_count >= self._PERSISTENT_IDENTICAL_ERROR_LIMIT:
|
if persistent and identical_error_count >= self._PERSISTENT_IDENTICAL_ERROR_LIMIT:
|
||||||
|
|||||||
@ -26,6 +26,12 @@ else:
|
|||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
|
from nanobot.providers.openai_responses import (
|
||||||
|
consume_sdk_stream,
|
||||||
|
convert_messages,
|
||||||
|
convert_tools,
|
||||||
|
parse_response_output,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from nanobot.providers.registry import ProviderSpec
|
from nanobot.providers.registry import ProviderSpec
|
||||||
@ -113,6 +119,14 @@ def _uses_openrouter_attribution(spec: "ProviderSpec | None", api_base: str | No
|
|||||||
return bool(api_base and "openrouter" in api_base.lower())
|
return bool(api_base and "openrouter" in api_base.lower())
|
||||||
|
|
||||||
|
|
||||||
|
def _is_direct_openai_base(api_base: str | None) -> bool:
|
||||||
|
"""Return True for direct OpenAI endpoints, not generic OpenAI-compatible gateways."""
|
||||||
|
if not api_base:
|
||||||
|
return True
|
||||||
|
normalized = api_base.strip().lower().rstrip("/")
|
||||||
|
return "api.openai.com" in normalized and "openrouter" not in normalized
|
||||||
|
|
||||||
|
|
||||||
class OpenAICompatProvider(LLMProvider):
|
class OpenAICompatProvider(LLMProvider):
|
||||||
"""Unified provider for all OpenAI-compatible APIs.
|
"""Unified provider for all OpenAI-compatible APIs.
|
||||||
|
|
||||||
@ -137,6 +151,7 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
self._setup_env(api_key, api_base)
|
self._setup_env(api_key, api_base)
|
||||||
|
|
||||||
effective_base = api_base or (spec.default_api_base if spec else None) or None
|
effective_base = api_base or (spec.default_api_base if spec else None) or None
|
||||||
|
self._effective_base = effective_base
|
||||||
default_headers = {"x-session-affinity": uuid.uuid4().hex}
|
default_headers = {"x-session-affinity": uuid.uuid4().hex}
|
||||||
if _uses_openrouter_attribution(spec, effective_base):
|
if _uses_openrouter_attribution(spec, effective_base):
|
||||||
default_headers.update(_DEFAULT_OPENROUTER_HEADERS)
|
default_headers.update(_DEFAULT_OPENROUTER_HEADERS)
|
||||||
@ -228,9 +243,13 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
tc_clean["id"] = map_id(tc_clean.get("id"))
|
tc_clean["id"] = map_id(tc_clean.get("id"))
|
||||||
normalized.append(tc_clean)
|
normalized.append(tc_clean)
|
||||||
clean["tool_calls"] = normalized
|
clean["tool_calls"] = normalized
|
||||||
|
if clean.get("role") == "assistant":
|
||||||
|
# Some OpenAI-compatible gateways reject assistant messages
|
||||||
|
# that mix non-empty content with tool_calls.
|
||||||
|
clean["content"] = None
|
||||||
if "tool_call_id" in clean and clean["tool_call_id"]:
|
if "tool_call_id" in clean and clean["tool_call_id"]:
|
||||||
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
||||||
return sanitized
|
return self._enforce_role_alternation(sanitized)
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Build kwargs
|
# Build kwargs
|
||||||
@ -321,6 +340,88 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
|
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
|
def _should_use_responses_api(
|
||||||
|
self,
|
||||||
|
model: str | None,
|
||||||
|
reasoning_effort: str | None,
|
||||||
|
) -> bool:
|
||||||
|
"""Use Responses API only for direct OpenAI requests that benefit from it."""
|
||||||
|
if self._spec and self._spec.name != "openai":
|
||||||
|
return False
|
||||||
|
if not _is_direct_openai_base(self._effective_base):
|
||||||
|
return False
|
||||||
|
|
||||||
|
model_name = (model or self.default_model).lower()
|
||||||
|
if reasoning_effort and reasoning_effort.lower() != "none":
|
||||||
|
return True
|
||||||
|
return any(token in model_name for token in ("gpt-5", "o1", "o3", "o4"))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _should_fallback_from_responses_error(e: Exception) -> bool:
|
||||||
|
"""Fallback only for likely Responses API compatibility errors."""
|
||||||
|
response = getattr(e, "response", None)
|
||||||
|
status_code = getattr(e, "status_code", None)
|
||||||
|
if status_code is None and response is not None:
|
||||||
|
status_code = getattr(response, "status_code", None)
|
||||||
|
if status_code not in {400, 404, 422}:
|
||||||
|
return False
|
||||||
|
|
||||||
|
body = (
|
||||||
|
getattr(e, "body", None)
|
||||||
|
or getattr(e, "doc", None)
|
||||||
|
or getattr(response, "text", None)
|
||||||
|
)
|
||||||
|
body_text = str(body).lower() if body is not None else ""
|
||||||
|
compatibility_markers = (
|
||||||
|
"responses",
|
||||||
|
"response api",
|
||||||
|
"max_output_tokens",
|
||||||
|
"instructions",
|
||||||
|
"previous_response",
|
||||||
|
"unsupported",
|
||||||
|
"not supported",
|
||||||
|
"unknown parameter",
|
||||||
|
"unrecognized request argument",
|
||||||
|
)
|
||||||
|
return any(marker in body_text for marker in compatibility_markers)
|
||||||
|
|
||||||
|
def _build_responses_body(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None,
    model: str | None,
    max_tokens: int,
    temperature: float,
    reasoning_effort: str | None,
    tool_choice: str | dict[str, Any] | None,
) -> dict[str, Any]:
    """Assemble the keyword arguments for a Responses API request.

    Messages are passed through ``_sanitize_empty_content`` and
    ``_sanitize_messages`` and then split by ``convert_messages`` into
    instructions and input items. The returned body is non-streaming and
    sets ``store`` to False; callers that stream override ``stream``.
    Temperature, reasoning settings and tools are added only when applicable.
    """
    target_model = model or self.default_model
    cleaned = self._sanitize_empty_content(messages)
    cleaned = self._sanitize_messages(cleaned)
    instructions, input_items = convert_messages(cleaned)

    body: dict[str, Any] = dict(
        model=target_model,
        instructions=instructions or None,
        input=input_items,
        # Clamp so a zero or negative budget never reaches the API.
        max_output_tokens=max(1, max_tokens),
        store=False,
        stream=False,
    )

    if self._supports_temperature(target_model, reasoning_effort):
        body["temperature"] = temperature

    reasoning_requested = bool(reasoning_effort) and reasoning_effort.lower() != "none"
    if reasoning_requested:
        body["reasoning"] = {"effort": reasoning_effort}
        body["include"] = ["reasoning.encrypted_content"]

    if tools:
        body["tools"] = convert_tools(tools)
        body["tool_choice"] = tool_choice or "auto"

    return body
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Response parsing
|
# Response parsing
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@ -698,7 +799,12 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _handle_error(e: Exception) -> LLMResponse:
|
def _handle_error(
|
||||||
|
e: Exception,
|
||||||
|
*,
|
||||||
|
spec: ProviderSpec | None = None,
|
||||||
|
api_base: str | None = None,
|
||||||
|
) -> LLMResponse:
|
||||||
body = (
|
body = (
|
||||||
getattr(e, "doc", None)
|
getattr(e, "doc", None)
|
||||||
or getattr(e, "body", None)
|
or getattr(e, "body", None)
|
||||||
@ -706,6 +812,15 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
)
|
)
|
||||||
body_text = body if isinstance(body, str) else str(body) if body is not None else ""
|
body_text = body if isinstance(body, str) else str(body) if body is not None else ""
|
||||||
msg = f"Error: {body_text.strip()[:500]}" if body_text.strip() else f"Error calling LLM: {e}"
|
msg = f"Error: {body_text.strip()[:500]}" if body_text.strip() else f"Error calling LLM: {e}"
|
||||||
|
|
||||||
|
text = f"{body_text} {e}".lower()
|
||||||
|
if spec and spec.is_local and ("502" in text or "connection" in text or "refused" in text):
|
||||||
|
msg += (
|
||||||
|
"\nHint: this is a local model endpoint. Check that the local server is reachable at "
|
||||||
|
f"{api_base or spec.default_api_base}, and if you are using a proxy/tunnel, make sure it "
|
||||||
|
"can reach your local Ollama/vLLM service instead of routing localhost through the remote host."
|
||||||
|
)
|
||||||
|
|
||||||
response = getattr(e, "response", None)
|
response = getattr(e, "response", None)
|
||||||
retry_after = LLMProvider._extract_retry_after_from_headers(getattr(response, "headers", None))
|
retry_after = LLMProvider._extract_retry_after_from_headers(getattr(response, "headers", None))
|
||||||
if retry_after is None:
|
if retry_after is None:
|
||||||
@ -731,14 +846,25 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
reasoning_effort: str | None = None,
|
reasoning_effort: str | None = None,
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
kwargs = self._build_kwargs(
|
|
||||||
messages, tools, model, max_tokens, temperature,
|
|
||||||
reasoning_effort, tool_choice,
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
|
if self._should_use_responses_api(model, reasoning_effort):
|
||||||
|
try:
|
||||||
|
body = self._build_responses_body(
|
||||||
|
messages, tools, model, max_tokens, temperature,
|
||||||
|
reasoning_effort, tool_choice,
|
||||||
|
)
|
||||||
|
return parse_response_output(await self._client.responses.create(**body))
|
||||||
|
except Exception as responses_error:
|
||||||
|
if not self._should_fallback_from_responses_error(responses_error):
|
||||||
|
raise
|
||||||
|
|
||||||
|
kwargs = self._build_kwargs(
|
||||||
|
messages, tools, model, max_tokens, temperature,
|
||||||
|
reasoning_effort, tool_choice,
|
||||||
|
)
|
||||||
return self._parse(await self._client.chat.completions.create(**kwargs))
|
return self._parse(await self._client.chat.completions.create(**kwargs))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return self._handle_error(e)
|
return self._handle_error(e, spec=self._spec, api_base=self.api_base)
|
||||||
|
|
||||||
async def chat_stream(
|
async def chat_stream(
|
||||||
self,
|
self,
|
||||||
@ -751,14 +877,49 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
kwargs = self._build_kwargs(
|
|
||||||
messages, tools, model, max_tokens, temperature,
|
|
||||||
reasoning_effort, tool_choice,
|
|
||||||
)
|
|
||||||
kwargs["stream"] = True
|
|
||||||
kwargs["stream_options"] = {"include_usage": True}
|
|
||||||
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
|
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
|
||||||
try:
|
try:
|
||||||
|
if self._should_use_responses_api(model, reasoning_effort):
|
||||||
|
try:
|
||||||
|
body = self._build_responses_body(
|
||||||
|
messages, tools, model, max_tokens, temperature,
|
||||||
|
reasoning_effort, tool_choice,
|
||||||
|
)
|
||||||
|
body["stream"] = True
|
||||||
|
stream = await self._client.responses.create(**body)
|
||||||
|
|
||||||
|
async def _timed_stream():
|
||||||
|
stream_iter = stream.__aiter__()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
yield await asyncio.wait_for(
|
||||||
|
stream_iter.__anext__(),
|
||||||
|
timeout=idle_timeout_s,
|
||||||
|
)
|
||||||
|
except StopAsyncIteration:
|
||||||
|
break
|
||||||
|
|
||||||
|
content, tool_calls, finish_reason, usage, reasoning_content = await consume_sdk_stream(
|
||||||
|
_timed_stream(),
|
||||||
|
on_content_delta,
|
||||||
|
)
|
||||||
|
return LLMResponse(
|
||||||
|
content=content or None,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
finish_reason=finish_reason,
|
||||||
|
usage=usage,
|
||||||
|
reasoning_content=reasoning_content,
|
||||||
|
)
|
||||||
|
except Exception as responses_error:
|
||||||
|
if not self._should_fallback_from_responses_error(responses_error):
|
||||||
|
raise
|
||||||
|
|
||||||
|
kwargs = self._build_kwargs(
|
||||||
|
messages, tools, model, max_tokens, temperature,
|
||||||
|
reasoning_effort, tool_choice,
|
||||||
|
)
|
||||||
|
kwargs["stream"] = True
|
||||||
|
kwargs["stream_options"] = {"include_usage": True}
|
||||||
stream = await self._client.chat.completions.create(**kwargs)
|
stream = await self._client.chat.completions.create(**kwargs)
|
||||||
chunks: list[Any] = []
|
chunks: list[Any] = []
|
||||||
stream_iter = stream.__aiter__()
|
stream_iter = stream.__aiter__()
|
||||||
@ -786,7 +947,7 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
error_kind="timeout",
|
error_kind="timeout",
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return self._handle_error(e)
|
return self._handle_error(e, spec=self._spec, api_base=self.api_base)
|
||||||
|
|
||||||
def get_default_model(self) -> str:
    """Return the model name stored in ``self.default_model``."""
    default = self.default_model
    return default
|
||||||
|
|||||||
@ -155,6 +155,7 @@ class SessionManager:
|
|||||||
messages = []
|
messages = []
|
||||||
metadata = {}
|
metadata = {}
|
||||||
created_at = None
|
created_at = None
|
||||||
|
updated_at = None
|
||||||
last_consolidated = 0
|
last_consolidated = 0
|
||||||
|
|
||||||
with open(path, encoding="utf-8") as f:
|
with open(path, encoding="utf-8") as f:
|
||||||
@ -168,6 +169,7 @@ class SessionManager:
|
|||||||
if data.get("_type") == "metadata":
|
if data.get("_type") == "metadata":
|
||||||
metadata = data.get("metadata", {})
|
metadata = data.get("metadata", {})
|
||||||
created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None
|
created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None
|
||||||
|
updated_at = datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else None
|
||||||
last_consolidated = data.get("last_consolidated", 0)
|
last_consolidated = data.get("last_consolidated", 0)
|
||||||
else:
|
else:
|
||||||
messages.append(data)
|
messages.append(data)
|
||||||
@ -176,6 +178,7 @@ class SessionManager:
|
|||||||
key=key,
|
key=key,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
created_at=created_at or datetime.now(),
|
created_at=created_at or datetime.now(),
|
||||||
|
updated_at=updated_at or datetime.now(),
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
last_consolidated=last_consolidated
|
last_consolidated=last_consolidated
|
||||||
)
|
)
|
||||||
|
|||||||
@ -3,6 +3,7 @@ Compare conversation history against current memory files. Also scan memory file
|
|||||||
Output one line per finding:
|
Output one line per finding:
|
||||||
[FILE] atomic fact (not already in memory)
|
[FILE] atomic fact (not already in memory)
|
||||||
[FILE-REMOVE] reason for removal
|
[FILE-REMOVE] reason for removal
|
||||||
|
[SKILL] kebab-case-name: one-line description of the reusable pattern
|
||||||
|
|
||||||
Files: USER (identity, preferences), SOUL (bot behavior, tone), MEMORY (knowledge, project context)
|
Files: USER (identity, preferences), SOUL (bot behavior, tone), MEMORY (knowledge, project context)
|
||||||
|
|
||||||
@ -18,6 +19,12 @@ Staleness — flag for [FILE-REMOVE]:
|
|||||||
- Detailed incident info after 14 days — reduce to one-line summary
|
- Detailed incident info after 14 days — reduce to one-line summary
|
||||||
- Superseded: approaches replaced by newer solutions, deprecated dependencies
|
- Superseded: approaches replaced by newer solutions, deprecated dependencies
|
||||||
|
|
||||||
|
Skill discovery — flag [SKILL] when ALL of these are true:
|
||||||
|
- A specific, repeatable workflow appeared 2+ times in the conversation history
|
||||||
|
- It involves clear steps (not vague preferences like "likes concise answers")
|
||||||
|
- It is substantial enough to warrant its own instruction set (not trivial like "read a file")
|
||||||
|
- Do not worry about duplicates — the next phase will check against existing skills
|
||||||
|
|
||||||
Do not add: current weather, transient status, temporary errors, conversational filler.
|
Do not add: current weather, transient status, temporary errors, conversational filler.
|
||||||
|
|
||||||
[SKIP] if nothing needs updating.
|
[SKIP] if nothing needs updating.
|
||||||
|
|||||||
@ -1,11 +1,13 @@
|
|||||||
Update memory files based on the analysis below.
|
Update memory files based on the analysis below.
|
||||||
- [FILE] entries: add the described content to the appropriate file
|
- [FILE] entries: add the described content to the appropriate file
|
||||||
- [FILE-REMOVE] entries: delete the corresponding content from memory files
|
- [FILE-REMOVE] entries: delete the corresponding content from memory files
|
||||||
|
- [SKILL] entries: create a new skill under skills/<name>/SKILL.md using write_file
|
||||||
|
|
||||||
## File paths (relative to workspace root)
|
## File paths (relative to workspace root)
|
||||||
- SOUL.md
|
- SOUL.md
|
||||||
- USER.md
|
- USER.md
|
||||||
- memory/MEMORY.md
|
- memory/MEMORY.md
|
||||||
|
- skills/<name>/SKILL.md (for [SKILL] entries only)
|
||||||
|
|
||||||
Do NOT guess paths.
|
Do NOT guess paths.
|
||||||
|
|
||||||
@ -17,6 +19,17 @@ Do NOT guess paths.
|
|||||||
- Surgical edits only — never rewrite entire files
|
- Surgical edits only — never rewrite entire files
|
||||||
- If nothing to update, stop without calling tools
|
- If nothing to update, stop without calling tools
|
||||||
|
|
||||||
|
## Skill creation rules (for [SKILL] entries)
|
||||||
|
- Use write_file to create skills/<name>/SKILL.md
|
||||||
|
- Before writing, read_file `{{ skill_creator_path }}` for format reference (frontmatter structure, naming conventions, quality standards)
|
||||||
|
- **Dedup check**: read existing skills listed below to verify the new skill is not functionally redundant. Skip creation if an existing skill already covers the same workflow.
|
||||||
|
- Include YAML frontmatter with name and description fields
|
||||||
|
- Keep SKILL.md under 2000 words — concise and actionable
|
||||||
|
- Include: when to use, steps, output format, at least one example
|
||||||
|
- Do NOT overwrite existing skills — skip if the skill directory already exists
|
||||||
|
- Reference specific tools the agent has access to (read_file, write_file, exec, web_search, etc.)
|
||||||
|
- Skills are instruction sets, not code — do not include implementation code
|
||||||
|
|
||||||
## Quality
|
## Quality
|
||||||
- Every line must carry standalone value
|
- Every line must carry standalone value
|
||||||
- Concise bullets under clear headers
|
- Concise bullets under clear headers
|
||||||
|
|||||||
@ -15,9 +15,12 @@ from loguru import logger
|
|||||||
|
|
||||||
|
|
||||||
def strip_think(text: str) -> str:
    """Strip model reasoning markup from *text*.

    Closed ``<think>…</think>`` and ``<thought>…</thought>`` blocks are removed
    wherever they occur. An unclosed opening tag is removed, together with
    everything after it, only when it is the first non-whitespace content of
    the text; an unclosed tag that appears after other content is left intact.
    The result is returned with surrounding whitespace stripped.
    """
    # "thought" covers Gemma 4 and similar models that use <thought>...</thought>.
    for tag in ("think", "thought"):
        text = re.sub(rf"<{tag}>[\s\S]*?</{tag}>", "", text)
        # Anchored at the start of the string (no MULTILINE flag): only a
        # response that opens with an unterminated tag is wiped.
        text = re.sub(rf"^\s*<{tag}>[\s\S]*$", "", text)
    return text.strip()
|
||||||
|
|
||||||
|
|
||||||
@ -272,7 +275,7 @@ def build_assistant_message(
|
|||||||
thinking_blocks: list[dict] | None = None,
|
thinking_blocks: list[dict] | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Build a provider-safe assistant message with optional reasoning fields."""
|
"""Build a provider-safe assistant message with optional reasoning fields."""
|
||||||
msg: dict[str, Any] = {"role": "assistant", "content": content}
|
msg: dict[str, Any] = {"role": "assistant", "content": content or ""}
|
||||||
if tool_calls:
|
if tool_calls:
|
||||||
msg["tool_calls"] = tool_calls
|
msg["tool_calls"] = tool_calls
|
||||||
if reasoning_content is not None or thinking_blocks:
|
if reasoning_content is not None or thinking_blocks:
|
||||||
@ -417,7 +420,7 @@ def build_status_content(
|
|||||||
ctx_total = max(context_window_tokens, 0)
|
ctx_total = max(context_window_tokens, 0)
|
||||||
ctx_pct = int((context_tokens_estimate / ctx_total) * 100) if ctx_total > 0 else 0
|
ctx_pct = int((context_tokens_estimate / ctx_total) * 100) if ctx_total > 0 else 0
|
||||||
ctx_used_str = f"{context_tokens_estimate // 1000}k" if context_tokens_estimate >= 1000 else str(context_tokens_estimate)
|
ctx_used_str = f"{context_tokens_estimate // 1000}k" if context_tokens_estimate >= 1000 else str(context_tokens_estimate)
|
||||||
ctx_total_str = f"{ctx_total // 1024}k" if ctx_total > 0 else "n/a"
|
ctx_total_str = f"{ctx_total // 1000}k" if ctx_total > 0 else "n/a"
|
||||||
token_line = f"\U0001f4ca Tokens: {last_in} in / {last_out} out"
|
token_line = f"\U0001f4ca Tokens: {last_in} in / {last_out} out"
|
||||||
if cached and last_in:
|
if cached and last_in:
|
||||||
token_line += f" ({cached * 100 // last_in}% cached)"
|
token_line += f" ({cached * 100 // last_in}% cached)"
|
||||||
|
|||||||
@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
from nanobot.utils.path import abbreviate_path
|
from nanobot.utils.path import abbreviate_path
|
||||||
|
|
||||||
# Registry: tool_name -> (key_args, template, is_path, is_command)
|
# Registry: tool_name -> (key_args, template, is_path, is_command)
|
||||||
@ -17,27 +19,39 @@ _TOOL_FORMATS: dict[str, tuple[list[str], str, bool, bool]] = {
|
|||||||
"list_dir": (["path"], "ls {}", True, False),
|
"list_dir": (["path"], "ls {}", True, False),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Matches file paths embedded in shell commands, including quoted paths with spaces.
|
||||||
|
_PATH_IN_CMD_RE = re.compile(
|
||||||
|
r'"(?P<double>(?:[A-Za-z]:[/\\]|~/|/)[^"]+)"'
|
||||||
|
r"|'(?P<single>(?:[A-Za-z]:[/\\]|~/|/)[^']+)'"
|
||||||
|
r"|(?P<bare>(?:[A-Za-z]:[/\\]|~/|(?<=\s)/)[^\s;&|<>\"']+)"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def format_tool_hints(tool_calls: list) -> str:
    """Render tool calls as a short, comma-separated hint string.

    Each call is formatted on its own (registered format, ``mcp_`` format or
    fallback). Consecutive calls whose rendered hint is identical are then
    collapsed into one entry suffixed with a multiplication sign and the count.
    Returns an empty string when there are no tool calls.
    """
    if not tool_calls:
        return ""

    def _render(tc) -> str:
        fmt = _TOOL_FORMATS.get(tc.name)
        if fmt:
            return _fmt_known(tc, fmt)
        if tc.name.startswith("mcp_"):
            return _fmt_mcp(tc)
        return _fmt_fallback(tc)

    # Run-length encode the rendered hints; grouping is by rendered text, not
    # by tool name, so different arguments to the same tool stay separate.
    runs: list[list] = []
    for rendered in [_render(tc) for tc in tool_calls]:
        if runs and runs[-1][0] == rendered:
            runs[-1][1] += 1
        else:
            runs.append([rendered, 1])

    pieces = [
        label if count <= 1 else f"{label} \u00d7 {count}"
        for label, count in runs
    ]
    return ", ".join(pieces)
|
||||||
|
|
||||||
|
|
||||||
def _get_args(tc) -> dict:
|
def _get_args(tc) -> dict:
|
||||||
@ -51,17 +65,6 @@ def _get_args(tc) -> dict:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def _group_consecutive(calls: list) -> list[tuple[str, int, object]]:
|
|
||||||
"""Group consecutive calls to the same tool: [(name, count, first), ...]."""
|
|
||||||
groups: list[tuple[str, int, object]] = []
|
|
||||||
for tc in calls:
|
|
||||||
if groups and groups[-1][0] == tc.name:
|
|
||||||
groups[-1] = (groups[-1][0], groups[-1][1] + 1, groups[-1][2])
|
|
||||||
else:
|
|
||||||
groups.append((tc.name, 1, tc))
|
|
||||||
return groups
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_arg(tc, key_args: list[str]) -> str | None:
|
def _extract_arg(tc, key_args: list[str]) -> str | None:
|
||||||
"""Extract the first available value from preferred key names."""
|
"""Extract the first available value from preferred key names."""
|
||||||
args = _get_args(tc)
|
args = _get_args(tc)
|
||||||
@ -85,10 +88,25 @@ def _fmt_known(tc, fmt: tuple) -> str:
|
|||||||
if fmt[2]: # is_path
|
if fmt[2]: # is_path
|
||||||
val = abbreviate_path(val)
|
val = abbreviate_path(val)
|
||||||
elif fmt[3]: # is_command
|
elif fmt[3]: # is_command
|
||||||
val = val[:40] + "\u2026" if len(val) > 40 else val
|
val = _abbreviate_command(val)
|
||||||
return fmt[1].format(val)
|
return fmt[1].format(val)
|
||||||
|
|
||||||
|
|
||||||
|
def _abbreviate_command(cmd: str, max_len: int = 40) -> str:
    """Shorten a shell command for display.

    Every path found by ``_PATH_IN_CMD_RE`` is replaced with
    ``abbreviate_path(path, max_len=25)``, keeping the original quoting.
    If the result is still longer than *max_len* it is cut to
    ``max_len - 1`` characters and an ellipsis is appended.
    """
    def _shorten(found: re.Match[str]) -> str:
        # Quoted alternatives are re-wrapped in the same quote character.
        for group_name, quote in (("double", '"'), ("single", "'")):
            inner = found.group(group_name)
            if inner is not None:
                return f"{quote}{abbreviate_path(inner, max_len=25)}{quote}"
        return abbreviate_path(found.group("bare"), max_len=25)

    shortened = _PATH_IN_CMD_RE.sub(_shorten, cmd)
    if len(shortened) > max_len:
        return shortened[:max_len - 1] + "\u2026"
    return shortened
|
||||||
|
|
||||||
|
|
||||||
def _fmt_mcp(tc) -> str:
|
def _fmt_mcp(tc) -> str:
|
||||||
"""Format MCP tool as server::tool."""
|
"""Format MCP tool as server::tool."""
|
||||||
name = tc.name
|
name = tc.name
|
||||||
|
|||||||
@ -54,6 +54,7 @@ dependencies = [
|
|||||||
"python-docx>=1.1.0,<2.0.0",
|
"python-docx>=1.1.0,<2.0.0",
|
||||||
"openpyxl>=3.1.0,<4.0.0",
|
"openpyxl>=3.1.0,<4.0.0",
|
||||||
"python-pptx>=1.0.0,<2.0.0",
|
"python-pptx>=1.0.0,<2.0.0",
|
||||||
|
"filelock>=3.25.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
@ -79,12 +80,16 @@ discord = [
|
|||||||
langsmith = [
|
langsmith = [
|
||||||
"langsmith>=0.1.0",
|
"langsmith>=0.1.0",
|
||||||
]
|
]
|
||||||
|
pdf = [
|
||||||
|
"pymupdf>=1.25.0",
|
||||||
|
]
|
||||||
dev = [
|
dev = [
|
||||||
"pytest>=9.0.0,<10.0.0",
|
"pytest>=9.0.0,<10.0.0",
|
||||||
"pytest-asyncio>=1.3.0,<2.0.0",
|
"pytest-asyncio>=1.3.0,<2.0.0",
|
||||||
"aiohttp>=3.9.0,<4.0.0",
|
"aiohttp>=3.9.0,<4.0.0",
|
||||||
"pytest-cov>=6.0.0,<7.0.0",
|
"pytest-cov>=6.0.0,<7.0.0",
|
||||||
"ruff>=0.1.0",
|
"ruff>=0.1.0",
|
||||||
|
"pymupdf>=1.25.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|||||||
1010
tests/agent/test_auto_compact.py
Normal file
1010
tests/agent/test_auto_compact.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -46,7 +46,7 @@ class TestConsolidatorSummarize:
|
|||||||
{"role": "assistant", "content": "Done, fixed the race condition."},
|
{"role": "assistant", "content": "Done, fixed the race condition."},
|
||||||
]
|
]
|
||||||
result = await consolidator.archive(messages)
|
result = await consolidator.archive(messages)
|
||||||
assert result is True
|
assert result == "User fixed a bug in the auth module."
|
||||||
entries = store.read_unprocessed_history(since_cursor=0)
|
entries = store.read_unprocessed_history(since_cursor=0)
|
||||||
assert len(entries) == 1
|
assert len(entries) == 1
|
||||||
|
|
||||||
@ -55,14 +55,14 @@ class TestConsolidatorSummarize:
|
|||||||
mock_provider.chat_with_retry.side_effect = Exception("API error")
|
mock_provider.chat_with_retry.side_effect = Exception("API error")
|
||||||
messages = [{"role": "user", "content": "hello"}]
|
messages = [{"role": "user", "content": "hello"}]
|
||||||
result = await consolidator.archive(messages)
|
result = await consolidator.archive(messages)
|
||||||
assert result is True # always succeeds
|
assert result is None # no summary on raw dump fallback
|
||||||
entries = store.read_unprocessed_history(since_cursor=0)
|
entries = store.read_unprocessed_history(since_cursor=0)
|
||||||
assert len(entries) == 1
|
assert len(entries) == 1
|
||||||
assert "[RAW]" in entries[0]["content"]
|
assert "[RAW]" in entries[0]["content"]
|
||||||
|
|
||||||
async def test_summarize_skips_empty_messages(self, consolidator):
|
async def test_summarize_skips_empty_messages(self, consolidator):
|
||||||
result = await consolidator.archive([])
|
result = await consolidator.archive([])
|
||||||
assert result is False
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
class TestConsolidatorTokenBudget:
|
class TestConsolidatorTokenBudget:
|
||||||
@ -76,3 +76,52 @@ class TestConsolidatorTokenBudget:
|
|||||||
consolidator.archive = AsyncMock(return_value=True)
|
consolidator.archive = AsyncMock(return_value=True)
|
||||||
await consolidator.maybe_consolidate_by_tokens(session)
|
await consolidator.maybe_consolidate_by_tokens(session)
|
||||||
consolidator.archive.assert_not_called()
|
consolidator.archive.assert_not_called()
|
||||||
|
|
||||||
|
async def test_chunk_cap_preserves_user_turn_boundary(self, consolidator):
    """Chunk cap should rewind to the last user boundary within the cap.

    The session has user turns at indices 0, 50 and 61 and the picked
    boundary is 61; the archived chunk is expected to be m0..m49.
    """
    user_turns = {0, 50, 61}
    history = []
    for idx in range(70):
        role = "user" if idx in user_turns else "assistant"
        history.append({"role": role, "content": f"m{idx}"})

    consolidator._SAFETY_BUFFER = 0
    session = MagicMock()
    session.last_consolidated = 0
    session.key = "test:key"
    session.messages = history

    # First estimate is over budget, second one is back under it.
    token_estimates = [(1200, "tiktoken"), (400, "tiktoken")]
    consolidator.estimate_session_prompt_tokens = MagicMock(side_effect=token_estimates)
    consolidator.pick_consolidation_boundary = MagicMock(return_value=(61, 999))
    consolidator.archive = AsyncMock(return_value=True)

    await consolidator.maybe_consolidate_by_tokens(session)

    archived_chunk = consolidator.archive.await_args.args[0]
    assert len(archived_chunk) == 50
    first, last = archived_chunk[0], archived_chunk[-1]
    assert first["content"] == "m0"
    assert last["content"] == "m49"
    assert session.last_consolidated == 50
|
||||||
|
|
||||||
|
async def test_chunk_cap_skips_when_no_user_boundary_within_cap(self, consolidator):
    """If the cap would cut mid-turn, consolidation should skip that round.

    User turns exist only at indices 0 and 61, with the picked boundary at 61;
    the test expects no archive call and an unchanged consolidation cursor.
    """
    user_turns = {0, 61}
    history = []
    for idx in range(70):
        role = "user" if idx in user_turns else "assistant"
        history.append({"role": role, "content": f"m{idx}"})

    consolidator._SAFETY_BUFFER = 0
    session = MagicMock()
    session.last_consolidated = 0
    session.key = "test:key"
    session.messages = history

    over_budget = (1200, "tiktoken")
    consolidator.estimate_session_prompt_tokens = MagicMock(return_value=over_budget)
    consolidator.pick_consolidation_boundary = MagicMock(return_value=(61, 999))
    consolidator.archive = AsyncMock(return_value=True)

    await consolidator.maybe_consolidate_by_tokens(session)

    consolidator.archive.assert_not_awaited()
    assert session.last_consolidated == 0
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from unittest.mock import AsyncMock, MagicMock
|
|||||||
|
|
||||||
from nanobot.agent.memory import Dream, MemoryStore
|
from nanobot.agent.memory import Dream, MemoryStore
|
||||||
from nanobot.agent.runner import AgentRunResult
|
from nanobot.agent.runner import AgentRunResult
|
||||||
|
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -95,3 +96,30 @@ class TestDreamRun:
|
|||||||
entries = store.read_unprocessed_history(since_cursor=0)
|
entries = store.read_unprocessed_history(since_cursor=0)
|
||||||
assert all(e["cursor"] > 0 for e in entries)
|
assert all(e["cursor"] > 0 for e in entries)
|
||||||
|
|
||||||
|
async def test_skill_phase_uses_builtin_skill_creator_path(self, dream, mock_provider, mock_runner, store):
    """Dream should point skill creation guidance at the builtin skill-creator template."""
    for entry in ("Repeated workflow one", "Repeated workflow two"):
        store.append_history(entry)
    analysis = MagicMock(content="[SKILL] test-skill: test description")
    mock_provider.chat_with_retry.return_value = analysis
    mock_runner.run = AsyncMock(return_value=_make_run_result())

    await dream.run()

    # The first positional argument of the runner call carries the messages;
    # the system prompt is the content of the first of them.
    run_spec = mock_runner.run.call_args[0][0]
    system_prompt = run_spec.initial_messages[0]["content"]
    template_path = BUILTIN_SKILLS_DIR / "skill-creator" / "SKILL.md"
    assert str(template_path) in system_prompt
|
||||||
|
|
||||||
|
async def test_skill_write_tool_accepts_workspace_relative_skill_path(self, dream, store):
    """Dream skill creation should allow skills/<name>/SKILL.md relative to workspace root."""
    write_tool = dream._tools.get("write_file")
    assert write_tool is not None

    relative_path = "skills/test-skill/SKILL.md"
    frontmatter = "---\nname: test-skill\ndescription: Test\n---\n"
    result = await write_tool.execute(path=relative_path, content=frontmatter)

    assert "Successfully wrote" in result
    written = store.workspace / "skills" / "test-skill" / "SKILL.md"
    assert written.exists()
|
||||||
|
|
||||||
|
|||||||
@ -184,17 +184,22 @@ def test_stale_extra_content_in_tool_calls_survives_sanitize() -> None:
|
|||||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
provider = OpenAICompatProvider()
|
provider = OpenAICompatProvider()
|
||||||
|
|
||||||
messages = [{
|
messages = [
|
||||||
"role": "assistant",
|
{"role": "user", "content": "hi"},
|
||||||
"content": None,
|
{
|
||||||
"tool_calls": [{
|
"role": "assistant",
|
||||||
"id": "call_1",
|
"content": None,
|
||||||
"type": "function",
|
"tool_calls": [{
|
||||||
"function": {"name": "fn", "arguments": "{}"},
|
"id": "call_1",
|
||||||
"extra_content": GEMINI_EXTRA,
|
"type": "function",
|
||||||
}],
|
"function": {"name": "fn", "arguments": "{}"},
|
||||||
}]
|
"extra_content": GEMINI_EXTRA,
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
{"role": "tool", "content": "ok", "tool_call_id": "call_1"},
|
||||||
|
{"role": "user", "content": "thanks"},
|
||||||
|
]
|
||||||
|
|
||||||
sanitized = provider._sanitize_messages(messages)
|
sanitized = provider._sanitize_messages(messages)
|
||||||
|
|
||||||
assert sanitized[0]["tool_calls"][0]["extra_content"] == GEMINI_EXTRA
|
assert sanitized[1]["tool_calls"][0]["extra_content"] == GEMINI_EXTRA
|
||||||
|
|||||||
@ -232,6 +232,35 @@ async def test_composite_empty_hooks_no_ops():
|
|||||||
assert hook.finalize_content(ctx, "test") == "test"
|
assert hook.finalize_content(ctx, "test") == "test"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_composite_supports_legacy_hook_init_without_super():
    """A hook whose ``__init__`` skips ``super().__init__()`` still works inside CompositeHook."""
    seen: list[str] = []

    class LegacyHook(AgentHook):
        # Deliberately does not call super().__init__().
        def __init__(self, label: str) -> None:
            self.label = label

        async def before_iteration(self, context: AgentHookContext) -> None:
            seen.append(self.label)

    legacy = LegacyHook("legacy")
    composite = CompositeHook([legacy])
    await composite.before_iteration(_ctx())
    assert seen == ["legacy"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_composite_can_wrap_another_composite():
    """A CompositeHook nested inside another still forwards ``before_iteration``."""
    seen: list[str] = []

    class Inner(AgentHook):
        async def before_iteration(self, context: AgentHookContext) -> None:
            seen.append("inner")

    nested = CompositeHook([Inner()])
    outer = CompositeHook([nested])
    await outer.before_iteration(_ctx())
    assert seen == ["inner"]
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Integration: AgentLoop with extra hooks
|
# Integration: AgentLoop with extra hooks
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@ -278,7 +307,7 @@ async def test_agent_loop_extra_hook_receives_calls(tmp_path):
|
|||||||
)
|
)
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
content, tools_used, messages = await loop._run_agent_loop(
|
content, tools_used, messages, _, _ = await loop._run_agent_loop(
|
||||||
[{"role": "user", "content": "hi"}]
|
[{"role": "user", "content": "hi"}]
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -302,7 +331,7 @@ async def test_agent_loop_extra_hook_error_isolation(tmp_path):
|
|||||||
)
|
)
|
||||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
content, _, _ = await loop._run_agent_loop(
|
content, _, _, _, _ = await loop._run_agent_loop(
|
||||||
[{"role": "user", "content": "hi"}]
|
[{"role": "user", "content": "hi"}]
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -344,7 +373,7 @@ async def test_agent_loop_no_hooks_backward_compat(tmp_path):
|
|||||||
loop.tools.execute = AsyncMock(return_value="ok")
|
loop.tools.execute = AsyncMock(return_value="ok")
|
||||||
loop.max_iterations = 2
|
loop.max_iterations = 2
|
||||||
|
|
||||||
content, tools_used, _ = await loop._run_agent_loop([])
|
content, tools_used, _, _, _ = await loop._run_agent_loop([])
|
||||||
assert content == (
|
assert content == (
|
||||||
"I reached the maximum number of tool call iterations (2) "
|
"I reached the maximum number of tool call iterations (2) "
|
||||||
"without completing the task. You can try breaking the task into smaller steps."
|
"without completing the task. You can try breaking the task into smaller steps."
|
||||||
|
|||||||
@ -1,5 +1,13 @@
|
|||||||
|
import asyncio
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from nanobot.agent.context import ContextBuilder
|
from nanobot.agent.context import ContextBuilder
|
||||||
from nanobot.agent.loop import AgentLoop
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.session.manager import Session
|
from nanobot.session.manager import Session
|
||||||
|
|
||||||
|
|
||||||
@ -11,6 +19,12 @@ def _mk_loop() -> AgentLoop:
|
|||||||
return loop
|
return loop
|
||||||
|
|
||||||
|
|
||||||
|
def _make_full_loop(tmp_path: Path) -> AgentLoop:
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
return AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model")
|
||||||
|
|
||||||
|
|
||||||
def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None:
|
def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None:
|
||||||
loop = _mk_loop()
|
loop = _mk_loop()
|
||||||
session = Session(key="test:runtime-only")
|
session = Session(key="test:runtime-only")
|
||||||
@ -200,3 +214,206 @@ def test_restore_runtime_checkpoint_dedupes_overlapping_tail() -> None:
|
|||||||
assert session.messages[0]["role"] == "assistant"
|
assert session.messages[0]["role"] == "assistant"
|
||||||
assert session.messages[1]["tool_call_id"] == "call_done"
|
assert session.messages[1]["tool_call_id"] == "call_done"
|
||||||
assert session.messages[2]["tool_call_id"] == "call_pending"
|
assert session.messages[2]["tool_call_id"] == "call_pending"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_process_message_persists_user_message_before_turn_completes(tmp_path: Path) -> None:
|
||||||
|
loop = _make_full_loop(tmp_path)
|
||||||
|
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
|
||||||
|
loop._run_agent_loop = AsyncMock(side_effect=RuntimeError("boom")) # type: ignore[method-assign]
|
||||||
|
|
||||||
|
msg = InboundMessage(channel="feishu", sender_id="u1", chat_id="c1", content="persist me")
|
||||||
|
with pytest.raises(RuntimeError, match="boom"):
|
||||||
|
await loop._process_message(msg)
|
||||||
|
|
||||||
|
loop.sessions.invalidate("feishu:c1")
|
||||||
|
persisted = loop.sessions.get_or_create("feishu:c1")
|
||||||
|
assert [m["role"] for m in persisted.messages] == ["user"]
|
||||||
|
assert persisted.messages[0]["content"] == "persist me"
|
||||||
|
assert persisted.metadata.get(AgentLoop._PENDING_USER_TURN_KEY) is True
|
||||||
|
assert persisted.updated_at >= persisted.created_at
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_process_message_does_not_duplicate_early_persisted_user_message(tmp_path: Path) -> None:
|
||||||
|
loop = _make_full_loop(tmp_path)
|
||||||
|
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
|
||||||
|
loop._run_agent_loop = AsyncMock(return_value=(
|
||||||
|
"done",
|
||||||
|
None,
|
||||||
|
[
|
||||||
|
{"role": "system", "content": "system"},
|
||||||
|
{"role": "user", "content": "hello"},
|
||||||
|
{"role": "assistant", "content": "done"},
|
||||||
|
],
|
||||||
|
"stop",
|
||||||
|
False,
|
||||||
|
)) # type: ignore[method-assign]
|
||||||
|
|
||||||
|
result = await loop._process_message(
|
||||||
|
InboundMessage(channel="feishu", sender_id="u1", chat_id="c2", content="hello")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.content == "done"
|
||||||
|
session = loop.sessions.get_or_create("feishu:c2")
|
||||||
|
assert [
|
||||||
|
{k: v for k, v in m.items() if k in {"role", "content"}}
|
||||||
|
for m in session.messages
|
||||||
|
] == [
|
||||||
|
{"role": "user", "content": "hello"},
|
||||||
|
{"role": "assistant", "content": "done"},
|
||||||
|
]
|
||||||
|
assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_next_turn_after_crash_closes_pending_user_turn_before_new_input(tmp_path: Path) -> None:
|
||||||
|
loop = _make_full_loop(tmp_path)
|
||||||
|
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
|
||||||
|
loop.provider.chat_with_retry = AsyncMock(return_value=MagicMock()) # unused because _run_agent_loop is stubbed
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("feishu:c3")
|
||||||
|
session.add_message("user", "old question")
|
||||||
|
session.metadata[AgentLoop._PENDING_USER_TURN_KEY] = True
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
loop._run_agent_loop = AsyncMock(return_value=(
|
||||||
|
"new answer",
|
||||||
|
None,
|
||||||
|
[
|
||||||
|
{"role": "system", "content": "system"},
|
||||||
|
{"role": "user", "content": "old question"},
|
||||||
|
{"role": "assistant", "content": "Error: Task interrupted before a response was generated."},
|
||||||
|
{"role": "user", "content": "new question"},
|
||||||
|
{"role": "assistant", "content": "new answer"},
|
||||||
|
],
|
||||||
|
"stop",
|
||||||
|
False,
|
||||||
|
)) # type: ignore[method-assign]
|
||||||
|
|
||||||
|
result = await loop._process_message(
|
||||||
|
InboundMessage(channel="feishu", sender_id="u1", chat_id="c3", content="new question")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.content == "new answer"
|
||||||
|
session = loop.sessions.get_or_create("feishu:c3")
|
||||||
|
assert [
|
||||||
|
{k: v for k, v in m.items() if k in {"role", "content"}}
|
||||||
|
for m in session.messages
|
||||||
|
] == [
|
||||||
|
{"role": "user", "content": "old question"},
|
||||||
|
{"role": "assistant", "content": "Error: Task interrupted before a response was generated."},
|
||||||
|
{"role": "user", "content": "new question"},
|
||||||
|
{"role": "assistant", "content": "new answer"},
|
||||||
|
]
|
||||||
|
assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_preserves_runtime_checkpoint_for_next_turn(tmp_path: Path) -> None:
|
||||||
|
from nanobot.command.builtin import cmd_stop
|
||||||
|
from nanobot.command.router import CommandContext
|
||||||
|
|
||||||
|
loop = _make_full_loop(tmp_path)
|
||||||
|
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
|
||||||
|
|
||||||
|
checkpoint_saved = asyncio.Event()
|
||||||
|
|
||||||
|
async def interrupted_run_agent_loop(_initial_messages, *, session=None, **_kwargs):
|
||||||
|
assert session is not None
|
||||||
|
loop._set_runtime_checkpoint(
|
||||||
|
session,
|
||||||
|
{
|
||||||
|
"assistant_message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "working",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_done",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "read_file", "arguments": "{}"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "call_pending",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "exec", "arguments": "{}"},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"completed_tool_results": [
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_done",
|
||||||
|
"name": "read_file",
|
||||||
|
"content": "ok",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"pending_tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_pending",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "exec", "arguments": "{}"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
checkpoint_saved.set()
|
||||||
|
await asyncio.Event().wait()
|
||||||
|
|
||||||
|
loop._run_agent_loop = interrupted_run_agent_loop # type: ignore[method-assign]
|
||||||
|
|
||||||
|
first_msg = InboundMessage(channel="feishu", sender_id="u1", chat_id="c4", content="keep progress")
|
||||||
|
task = asyncio.create_task(loop._process_message(first_msg))
|
||||||
|
loop._active_tasks[first_msg.session_key] = [task]
|
||||||
|
await asyncio.wait_for(checkpoint_saved.wait(), timeout=1.0)
|
||||||
|
|
||||||
|
stop_msg = InboundMessage(channel="feishu", sender_id="u1", chat_id="c4", content="/stop")
|
||||||
|
stop_ctx = CommandContext(msg=stop_msg, session=None, key=stop_msg.session_key, raw="/stop", loop=loop)
|
||||||
|
stop_result = await cmd_stop(stop_ctx)
|
||||||
|
|
||||||
|
assert "Stopped 1 task" in stop_result.content
|
||||||
|
assert task.done()
|
||||||
|
|
||||||
|
loop.sessions.invalidate("feishu:c4")
|
||||||
|
interrupted = loop.sessions.get_or_create("feishu:c4")
|
||||||
|
assert interrupted.metadata.get(AgentLoop._PENDING_USER_TURN_KEY) is True
|
||||||
|
assert interrupted.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is not None
|
||||||
|
|
||||||
|
async def resumed_run_agent_loop(initial_messages, **_kwargs):
|
||||||
|
return (
|
||||||
|
"next answer",
|
||||||
|
None,
|
||||||
|
[*initial_messages, {"role": "assistant", "content": "next answer"}],
|
||||||
|
"stop",
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
|
||||||
|
loop._run_agent_loop = resumed_run_agent_loop # type: ignore[method-assign]
|
||||||
|
result = await loop._process_message(
|
||||||
|
InboundMessage(channel="feishu", sender_id="u1", chat_id="c4", content="continue here")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.content == "next answer"
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("feishu:c4")
|
||||||
|
assert [
|
||||||
|
{k: v for k, v in m.items() if k in {"role", "content", "tool_call_id", "name"}}
|
||||||
|
for m in session.messages
|
||||||
|
] == [
|
||||||
|
{"role": "user", "content": "keep progress"},
|
||||||
|
{"role": "assistant", "content": "working"},
|
||||||
|
{"role": "tool", "tool_call_id": "call_done", "name": "read_file", "content": "ok"},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_pending",
|
||||||
|
"name": "exec",
|
||||||
|
"content": "Error: Task interrupted before this tool finished.",
|
||||||
|
},
|
||||||
|
{"role": "user", "content": "continue here"},
|
||||||
|
{"role": "assistant", "content": "next answer"},
|
||||||
|
]
|
||||||
|
assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata
|
||||||
|
assert AgentLoop._RUNTIME_CHECKPOINT_KEY not in session.metadata
|
||||||
|
|||||||
44
tests/agent/test_mcp_connection.py
Normal file
44
tests/agent/test_mcp_connection.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
"""Tests for MCP connection lifecycle in AgentLoop."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
|
||||||
|
|
||||||
|
def _make_loop(tmp_path, *, mcp_servers: dict | None = None) -> AgentLoop:
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
provider.generation.max_tokens = 4096
|
||||||
|
return AgentLoop(
|
||||||
|
bus=bus,
|
||||||
|
provider=provider,
|
||||||
|
workspace=tmp_path,
|
||||||
|
model="test-model",
|
||||||
|
mcp_servers=mcp_servers or {"test": object()},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_connect_mcp_retries_when_no_servers_connect(tmp_path, monkeypatch: pytest.MonkeyPatch):
|
||||||
|
loop = _make_loop(tmp_path)
|
||||||
|
attempts = 0
|
||||||
|
|
||||||
|
async def _fake_connect(_servers, _registry):
|
||||||
|
nonlocal attempts
|
||||||
|
attempts += 1
|
||||||
|
return {}
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _fake_connect)
|
||||||
|
|
||||||
|
await loop._connect_mcp()
|
||||||
|
await loop._connect_mcp()
|
||||||
|
|
||||||
|
assert attempts == 2
|
||||||
|
assert loop._mcp_connected is False
|
||||||
|
assert loop._mcp_stacks == {}
|
||||||
File diff suppressed because it is too large
Load Diff
@ -250,3 +250,63 @@ def test_list_skills_openclaw_metadata_parsed_for_requirements(
|
|||||||
assert entries == [
|
assert entries == [
|
||||||
{"name": "openclaw_skill", "path": str(skill_path), "source": "workspace"},
|
{"name": "openclaw_skill", "path": str(skill_path), "source": "workspace"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_disabled_skills_excluded_from_list(tmp_path: Path) -> None:
|
||||||
|
workspace = tmp_path / "ws"
|
||||||
|
ws_skills = workspace / "skills"
|
||||||
|
ws_skills.mkdir(parents=True)
|
||||||
|
_write_skill(ws_skills, "alpha", body="# Alpha")
|
||||||
|
beta_path = _write_skill(ws_skills, "beta", body="# Beta")
|
||||||
|
builtin = tmp_path / "builtin"
|
||||||
|
builtin.mkdir()
|
||||||
|
|
||||||
|
loader = SkillsLoader(workspace, builtin_skills_dir=builtin, disabled_skills={"alpha"})
|
||||||
|
entries = loader.list_skills(filter_unavailable=False)
|
||||||
|
assert len(entries) == 1
|
||||||
|
assert entries[0]["name"] == "beta"
|
||||||
|
assert entries[0]["path"] == str(beta_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_disabled_skills_empty_set_no_effect(tmp_path: Path) -> None:
|
||||||
|
workspace = tmp_path / "ws"
|
||||||
|
ws_skills = workspace / "skills"
|
||||||
|
ws_skills.mkdir(parents=True)
|
||||||
|
_write_skill(ws_skills, "alpha", body="# Alpha")
|
||||||
|
_write_skill(ws_skills, "beta", body="# Beta")
|
||||||
|
builtin = tmp_path / "builtin"
|
||||||
|
builtin.mkdir()
|
||||||
|
|
||||||
|
loader = SkillsLoader(workspace, builtin_skills_dir=builtin, disabled_skills=set())
|
||||||
|
entries = loader.list_skills(filter_unavailable=False)
|
||||||
|
assert len(entries) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_disabled_skills_excluded_from_build_skills_summary(tmp_path: Path) -> None:
|
||||||
|
workspace = tmp_path / "ws"
|
||||||
|
ws_skills = workspace / "skills"
|
||||||
|
ws_skills.mkdir(parents=True)
|
||||||
|
_write_skill(ws_skills, "alpha", body="# Alpha")
|
||||||
|
_write_skill(ws_skills, "beta", body="# Beta")
|
||||||
|
builtin = tmp_path / "builtin"
|
||||||
|
builtin.mkdir()
|
||||||
|
|
||||||
|
loader = SkillsLoader(workspace, builtin_skills_dir=builtin, disabled_skills={"alpha"})
|
||||||
|
summary = loader.build_skills_summary()
|
||||||
|
assert "alpha" not in summary
|
||||||
|
assert "beta" in summary
|
||||||
|
|
||||||
|
|
||||||
|
def test_disabled_skills_excluded_from_get_always_skills(tmp_path: Path) -> None:
|
||||||
|
workspace = tmp_path / "ws"
|
||||||
|
ws_skills = workspace / "skills"
|
||||||
|
ws_skills.mkdir(parents=True)
|
||||||
|
_write_skill(ws_skills, "alpha", metadata_json={"always": True}, body="# Alpha")
|
||||||
|
_write_skill(ws_skills, "beta", metadata_json={"always": True}, body="# Beta")
|
||||||
|
builtin = tmp_path / "builtin"
|
||||||
|
builtin.mkdir()
|
||||||
|
|
||||||
|
loader = SkillsLoader(workspace, builtin_skills_dir=builtin, disabled_skills={"alpha"})
|
||||||
|
always = loader.get_always_skills()
|
||||||
|
assert "alpha" not in always
|
||||||
|
assert "beta" in always
|
||||||
|
|||||||
@ -52,6 +52,53 @@ class TestToolHintKnownTools:
|
|||||||
assert result.startswith("$ ")
|
assert result.startswith("$ ")
|
||||||
assert len(result) <= 50 # reasonable limit
|
assert len(result) <= 50 # reasonable limit
|
||||||
|
|
||||||
|
def test_exec_abbreviates_paths_in_command(self):
|
||||||
|
"""Windows paths in exec commands should be folded, not blindly truncated."""
|
||||||
|
cmd = "cd D:\\Documents\\GitHub\\nanobot\\.worktree\\tomain\\nanobot && git diff origin/main...pr-2706 --name-only 2>&1"
|
||||||
|
result = _hint([_tc("exec", {"command": cmd})])
|
||||||
|
assert "\u2026/" in result # path should be folded with …/
|
||||||
|
assert "worktree" not in result # middle segments should be collapsed
|
||||||
|
|
||||||
|
def test_exec_abbreviates_linux_paths(self):
|
||||||
|
"""Unix absolute paths in exec commands should be folded."""
|
||||||
|
cmd = "cd /home/user/projects/nanobot/.worktree/tomain && make build"
|
||||||
|
result = _hint([_tc("exec", {"command": cmd})])
|
||||||
|
assert "\u2026/" in result
|
||||||
|
assert "projects" not in result
|
||||||
|
|
||||||
|
def test_exec_abbreviates_home_paths(self):
|
||||||
|
"""~/ paths in exec commands should be folded."""
|
||||||
|
cmd = "cd ~/projects/nanobot/workspace && pytest tests/"
|
||||||
|
result = _hint([_tc("exec", {"command": cmd})])
|
||||||
|
assert "\u2026/" in result
|
||||||
|
|
||||||
|
def test_exec_abbreviates_quoted_linux_paths_with_spaces(self):
|
||||||
|
"""Quoted Unix paths with spaces should still be folded."""
|
||||||
|
cmd = 'cd "/home/user/My Documents/project" && pytest tests/'
|
||||||
|
result = _hint([_tc("exec", {"command": cmd})])
|
||||||
|
assert "\u2026/" in result
|
||||||
|
assert '"/home/user/My Documents/project"' not in result
|
||||||
|
assert '"' in result
|
||||||
|
|
||||||
|
def test_exec_abbreviates_quoted_windows_paths_with_spaces(self):
|
||||||
|
"""Quoted Windows paths with spaces should still be folded."""
|
||||||
|
cmd = 'cd "C:/Program Files/Git/project" && git status'
|
||||||
|
result = _hint([_tc("exec", {"command": cmd})])
|
||||||
|
assert "\u2026/" in result
|
||||||
|
assert '"C:/Program Files/Git/project"' not in result
|
||||||
|
assert '"' in result
|
||||||
|
|
||||||
|
def test_exec_short_command_unchanged(self):
|
||||||
|
result = _hint([_tc("exec", {"command": "npm install typescript"})])
|
||||||
|
assert result == "$ npm install typescript"
|
||||||
|
|
||||||
|
def test_exec_chained_commands_truncated_not_mid_path(self):
|
||||||
|
"""Long chained commands should truncate preserving abbreviated paths."""
|
||||||
|
cmd = "cd D:\\Documents\\GitHub\\project && npm run build && npm test"
|
||||||
|
result = _hint([_tc("exec", {"command": cmd})])
|
||||||
|
assert "\u2026/" in result # path folded
|
||||||
|
assert "npm" in result # chained command still visible
|
||||||
|
|
||||||
def test_web_search(self):
|
def test_web_search(self):
|
||||||
result = _hint([_tc("web_search", {"query": "Claude 4 vs GPT-4"})])
|
result = _hint([_tc("web_search", {"query": "Claude 4 vs GPT-4"})])
|
||||||
assert result == 'search "Claude 4 vs GPT-4"'
|
assert result == 'search "Claude 4 vs GPT-4"'
|
||||||
@ -105,22 +152,30 @@ class TestToolHintFolding:
|
|||||||
result = _hint(calls)
|
result = _hint(calls)
|
||||||
assert "\u00d7" not in result
|
assert "\u00d7" not in result
|
||||||
|
|
||||||
def test_two_consecutive_same_folded(self):
|
def test_two_consecutive_different_args_not_folded(self):
|
||||||
calls = [
|
calls = [
|
||||||
_tc("grep", {"pattern": "*.py"}),
|
_tc("grep", {"pattern": "*.py"}),
|
||||||
_tc("grep", {"pattern": "*.ts"}),
|
_tc("grep", {"pattern": "*.ts"}),
|
||||||
]
|
]
|
||||||
result = _hint(calls)
|
result = _hint(calls)
|
||||||
|
assert "\u00d7" not in result
|
||||||
|
|
||||||
|
def test_two_consecutive_same_args_folded(self):
|
||||||
|
calls = [
|
||||||
|
_tc("grep", {"pattern": "TODO"}),
|
||||||
|
_tc("grep", {"pattern": "TODO"}),
|
||||||
|
]
|
||||||
|
result = _hint(calls)
|
||||||
assert "\u00d7 2" in result
|
assert "\u00d7 2" in result
|
||||||
|
|
||||||
def test_three_consecutive_same_folded(self):
|
def test_three_consecutive_different_args_not_folded(self):
|
||||||
calls = [
|
calls = [
|
||||||
_tc("read_file", {"path": "a.py"}),
|
_tc("read_file", {"path": "a.py"}),
|
||||||
_tc("read_file", {"path": "b.py"}),
|
_tc("read_file", {"path": "b.py"}),
|
||||||
_tc("read_file", {"path": "c.py"}),
|
_tc("read_file", {"path": "c.py"}),
|
||||||
]
|
]
|
||||||
result = _hint(calls)
|
result = _hint(calls)
|
||||||
assert "\u00d7 3" in result
|
assert "\u00d7" not in result
|
||||||
|
|
||||||
def test_different_tools_not_folded(self):
|
def test_different_tools_not_folded(self):
|
||||||
calls = [
|
calls = [
|
||||||
@ -187,7 +242,7 @@ class TestToolHintMixedFolding:
|
|||||||
"""G4: Mixed folding groups with interleaved same-tool segments."""
|
"""G4: Mixed folding groups with interleaved same-tool segments."""
|
||||||
|
|
||||||
def test_read_read_grep_grep_read(self):
|
def test_read_read_grep_grep_read(self):
|
||||||
"""read×2, grep×2, read — should produce two separate groups."""
|
"""All different args — each hint listed separately."""
|
||||||
calls = [
|
calls = [
|
||||||
_tc("read_file", {"path": "a.py"}),
|
_tc("read_file", {"path": "a.py"}),
|
||||||
_tc("read_file", {"path": "b.py"}),
|
_tc("read_file", {"path": "b.py"}),
|
||||||
@ -196,7 +251,6 @@ class TestToolHintMixedFolding:
|
|||||||
_tc("read_file", {"path": "c.py"}),
|
_tc("read_file", {"path": "c.py"}),
|
||||||
]
|
]
|
||||||
result = _hint(calls)
|
result = _hint(calls)
|
||||||
assert "\u00d7 2" in result
|
assert "\u00d7" not in result
|
||||||
# Should have 3 groups: read×2, grep×2, read
|
|
||||||
parts = result.split(", ")
|
parts = result.split(", ")
|
||||||
assert len(parts) == 3
|
assert len(parts) == 5
|
||||||
|
|||||||
502
tests/agent/test_unified_session.py
Normal file
502
tests/agent/test_unified_session.py
Normal file
@ -0,0 +1,502 @@
|
|||||||
|
"""Tests for unified_session feature.
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- AgentLoop._dispatch() rewrites session_key to "unified:default" when enabled
|
||||||
|
- Existing session_key_override is respected (not overwritten)
|
||||||
|
- Feature is off by default (no behavior change for existing users)
|
||||||
|
- Config schema serialises unified_session as camelCase "unifiedSession"
|
||||||
|
- onboard-generated config.json contains "unifiedSession" key
|
||||||
|
- /new command correctly clears the shared session in unified mode
|
||||||
|
- /new is NOT a priority command (goes through _dispatch, key rewrite applies)
|
||||||
|
- Context window consolidation is unaffected by unified_session
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.command.builtin import cmd_new, register_builtin_commands
|
||||||
|
from nanobot.command.router import CommandContext, CommandRouter
|
||||||
|
from nanobot.config.schema import AgentDefaults, Config
|
||||||
|
from nanobot.session.manager import Session, SessionManager
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_loop(tmp_path: Path, unified_session: bool = False) -> AgentLoop:
|
||||||
|
"""Create a minimal AgentLoop for dispatch-level tests."""
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
|
||||||
|
with patch("nanobot.agent.loop.SessionManager"), \
|
||||||
|
patch("nanobot.agent.loop.SubagentManager") as MockSubMgr, \
|
||||||
|
patch("nanobot.agent.loop.Dream"):
|
||||||
|
MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0)
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=bus,
|
||||||
|
provider=provider,
|
||||||
|
workspace=tmp_path,
|
||||||
|
unified_session=unified_session,
|
||||||
|
)
|
||||||
|
return loop
|
||||||
|
|
||||||
|
|
||||||
|
def _make_msg(channel: str = "telegram", chat_id: str = "111",
|
||||||
|
session_key_override: str | None = None) -> InboundMessage:
|
||||||
|
return InboundMessage(
|
||||||
|
channel=channel,
|
||||||
|
chat_id=chat_id,
|
||||||
|
sender_id="user1",
|
||||||
|
content="hello",
|
||||||
|
session_key_override=session_key_override,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# TestUnifiedSessionDispatch — core behaviour
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestUnifiedSessionDispatch:
|
||||||
|
"""AgentLoop._dispatch() session key rewriting logic."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unified_session_rewrites_key_to_unified_default(self, tmp_path: Path):
|
||||||
|
"""When unified_session=True, all messages use 'unified:default' as session key."""
|
||||||
|
loop = _make_loop(tmp_path, unified_session=True)
|
||||||
|
|
||||||
|
captured: list[str] = []
|
||||||
|
|
||||||
|
async def fake_process(msg, **kwargs):
|
||||||
|
captured.append(msg.session_key)
|
||||||
|
return None
|
||||||
|
|
||||||
|
loop._process_message = fake_process # type: ignore[method-assign]
|
||||||
|
|
||||||
|
msg = _make_msg(channel="telegram", chat_id="111")
|
||||||
|
await loop._dispatch(msg)
|
||||||
|
|
||||||
|
assert captured == ["unified:default"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unified_session_different_channels_share_same_key(self, tmp_path: Path):
|
||||||
|
"""Messages from different channels all resolve to the same session key."""
|
||||||
|
loop = _make_loop(tmp_path, unified_session=True)
|
||||||
|
|
||||||
|
captured: list[str] = []
|
||||||
|
|
||||||
|
async def fake_process(msg, **kwargs):
|
||||||
|
captured.append(msg.session_key)
|
||||||
|
return None
|
||||||
|
|
||||||
|
loop._process_message = fake_process # type: ignore[method-assign]
|
||||||
|
|
||||||
|
await loop._dispatch(_make_msg(channel="telegram", chat_id="111"))
|
||||||
|
await loop._dispatch(_make_msg(channel="discord", chat_id="222"))
|
||||||
|
await loop._dispatch(_make_msg(channel="cli", chat_id="direct"))
|
||||||
|
|
||||||
|
assert captured == ["unified:default", "unified:default", "unified:default"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unified_session_disabled_preserves_original_key(self, tmp_path: Path):
|
||||||
|
"""When unified_session=False (default), session key is channel:chat_id as usual."""
|
||||||
|
loop = _make_loop(tmp_path, unified_session=False)
|
||||||
|
|
||||||
|
captured: list[str] = []
|
||||||
|
|
||||||
|
async def fake_process(msg, **kwargs):
|
||||||
|
captured.append(msg.session_key)
|
||||||
|
return None
|
||||||
|
|
||||||
|
loop._process_message = fake_process # type: ignore[method-assign]
|
||||||
|
|
||||||
|
msg = _make_msg(channel="telegram", chat_id="999")
|
||||||
|
await loop._dispatch(msg)
|
||||||
|
|
||||||
|
assert captured == ["telegram:999"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unified_session_respects_existing_override(self, tmp_path: Path):
|
||||||
|
"""If session_key_override is already set (e.g. Telegram thread), it is NOT overwritten."""
|
||||||
|
loop = _make_loop(tmp_path, unified_session=True)
|
||||||
|
|
||||||
|
captured: list[str] = []
|
||||||
|
|
||||||
|
async def fake_process(msg, **kwargs):
|
||||||
|
captured.append(msg.session_key)
|
||||||
|
return None
|
||||||
|
|
||||||
|
loop._process_message = fake_process # type: ignore[method-assign]
|
||||||
|
|
||||||
|
msg = _make_msg(channel="telegram", chat_id="111", session_key_override="telegram:thread:42")
|
||||||
|
await loop._dispatch(msg)
|
||||||
|
|
||||||
|
assert captured == ["telegram:thread:42"]
|
||||||
|
|
||||||
|
def test_unified_session_default_is_false(self, tmp_path: Path):
|
||||||
|
"""unified_session defaults to False — no behavior change for existing users."""
|
||||||
|
loop = _make_loop(tmp_path)
|
||||||
|
assert loop._unified_session is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# TestUnifiedSessionConfig — schema & serialisation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestUnifiedSessionConfig:
|
||||||
|
"""Config schema and onboard serialisation for unified_session."""
|
||||||
|
|
||||||
|
def test_agent_defaults_unified_session_default_is_false(self):
|
||||||
|
"""AgentDefaults.unified_session defaults to False."""
|
||||||
|
defaults = AgentDefaults()
|
||||||
|
assert defaults.unified_session is False
|
||||||
|
|
||||||
|
def test_agent_defaults_unified_session_can_be_enabled(self):
|
||||||
|
"""AgentDefaults.unified_session can be set to True."""
|
||||||
|
defaults = AgentDefaults(unified_session=True)
|
||||||
|
assert defaults.unified_session is True
|
||||||
|
|
||||||
|
def test_config_serialises_unified_session_as_camel_case(self):
|
||||||
|
"""model_dump(by_alias=True) outputs 'unifiedSession' (camelCase) for JSON."""
|
||||||
|
config = Config()
|
||||||
|
data = config.model_dump(mode="json", by_alias=True)
|
||||||
|
agents_defaults = data["agents"]["defaults"]
|
||||||
|
assert "unifiedSession" in agents_defaults
|
||||||
|
assert agents_defaults["unifiedSession"] is False
|
||||||
|
|
||||||
|
def test_config_parses_unified_session_from_camel_case(self):
|
||||||
|
"""Config can be loaded from JSON with camelCase 'unifiedSession'."""
|
||||||
|
raw = {"agents": {"defaults": {"unifiedSession": True}}}
|
||||||
|
config = Config.model_validate(raw)
|
||||||
|
assert config.agents.defaults.unified_session is True
|
||||||
|
|
||||||
|
def test_config_parses_unified_session_from_snake_case(self):
|
||||||
|
"""Config also accepts snake_case 'unified_session' (populate_by_name=True)."""
|
||||||
|
raw = {"agents": {"defaults": {"unified_session": True}}}
|
||||||
|
config = Config.model_validate(raw)
|
||||||
|
assert config.agents.defaults.unified_session is True
|
||||||
|
|
||||||
|
def test_onboard_generated_config_contains_unified_session(self, tmp_path: Path):
|
||||||
|
"""save_config() writes 'unifiedSession' into config.json (simulates nanobot onboard)."""
|
||||||
|
from nanobot.config.loader import save_config
|
||||||
|
|
||||||
|
config = Config()
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
save_config(config, config_path)
|
||||||
|
|
||||||
|
with open(config_path, encoding="utf-8") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
agents_defaults = data["agents"]["defaults"]
|
||||||
|
assert "unifiedSession" in agents_defaults, (
|
||||||
|
"onboard-generated config.json must contain 'unifiedSession' key"
|
||||||
|
)
|
||||||
|
assert agents_defaults["unifiedSession"] is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# TestCmdNewUnifiedSession — /new command behaviour in unified mode
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestCmdNewUnifiedSession:
|
||||||
|
"""/new command routing and session-clear behaviour in unified mode."""
|
||||||
|
|
||||||
|
def test_new_is_not_a_priority_command(self):
|
||||||
|
"""/new must NOT be in the priority table — it must go through _dispatch()
|
||||||
|
so the unified session key rewrite applies before cmd_new runs."""
|
||||||
|
router = CommandRouter()
|
||||||
|
register_builtin_commands(router)
|
||||||
|
assert router.is_priority("/new") is False
|
||||||
|
|
||||||
|
def test_new_is_an_exact_command(self):
|
||||||
|
"""/new must be registered as an exact command."""
|
||||||
|
router = CommandRouter()
|
||||||
|
register_builtin_commands(router)
|
||||||
|
assert "/new" in router._exact
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cmd_new_clears_unified_session(self, tmp_path: Path):
|
||||||
|
"""cmd_new called with key='unified:default' clears the shared session."""
|
||||||
|
sessions = SessionManager(tmp_path)
|
||||||
|
|
||||||
|
# Pre-populate the shared session with some messages
|
||||||
|
shared = sessions.get_or_create("unified:default")
|
||||||
|
shared.add_message("user", "hello from telegram")
|
||||||
|
shared.add_message("assistant", "hi there")
|
||||||
|
sessions.save(shared)
|
||||||
|
assert len(sessions.get_or_create("unified:default").messages) == 2
|
||||||
|
|
||||||
|
# _schedule_background is a *sync* method that schedules a coroutine via
|
||||||
|
# asyncio.create_task(). Mirror that exactly so the coroutine is consumed
|
||||||
|
# and no RuntimeWarning is emitted.
|
||||||
|
loop = SimpleNamespace(
|
||||||
|
sessions=sessions,
|
||||||
|
consolidator=SimpleNamespace(archive=AsyncMock(return_value=True)),
|
||||||
|
)
|
||||||
|
loop._schedule_background = lambda coro: asyncio.ensure_future(coro)
|
||||||
|
|
||||||
|
msg = InboundMessage(
|
||||||
|
channel="telegram", sender_id="user1", chat_id="111", content="/new",
|
||||||
|
session_key_override="unified:default", # as _dispatch() would set it
|
||||||
|
)
|
||||||
|
ctx = CommandContext(msg=msg, session=None, key="unified:default", raw="/new", loop=loop)
|
||||||
|
|
||||||
|
result = await cmd_new(ctx)
|
||||||
|
|
||||||
|
assert "New session started" in result.content
|
||||||
|
# Invalidate cache and reload from disk to confirm persistence
|
||||||
|
sessions.invalidate("unified:default")
|
||||||
|
reloaded = sessions.get_or_create("unified:default")
|
||||||
|
assert reloaded.messages == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cmd_new_in_unified_mode_does_not_affect_other_sessions(self, tmp_path: Path):
|
||||||
|
"""Clearing unified:default must not touch other sessions on disk."""
|
||||||
|
sessions = SessionManager(tmp_path)
|
||||||
|
|
||||||
|
other = sessions.get_or_create("discord:999")
|
||||||
|
other.add_message("user", "discord message")
|
||||||
|
sessions.save(other)
|
||||||
|
|
||||||
|
shared = sessions.get_or_create("unified:default")
|
||||||
|
shared.add_message("user", "shared message")
|
||||||
|
sessions.save(shared)
|
||||||
|
|
||||||
|
loop = SimpleNamespace(
|
||||||
|
sessions=sessions,
|
||||||
|
consolidator=SimpleNamespace(archive=AsyncMock(return_value=True)),
|
||||||
|
)
|
||||||
|
loop._schedule_background = lambda coro: asyncio.ensure_future(coro)
|
||||||
|
|
||||||
|
msg = InboundMessage(
|
||||||
|
channel="telegram", sender_id="user1", chat_id="111", content="/new",
|
||||||
|
session_key_override="unified:default",
|
||||||
|
)
|
||||||
|
ctx = CommandContext(msg=msg, session=None, key="unified:default", raw="/new", loop=loop)
|
||||||
|
await cmd_new(ctx)
|
||||||
|
|
||||||
|
sessions.invalidate("unified:default")
|
||||||
|
sessions.invalidate("discord:999")
|
||||||
|
assert sessions.get_or_create("unified:default").messages == []
|
||||||
|
assert len(sessions.get_or_create("discord:999").messages) == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# TestConsolidationUnaffectedByUnifiedSession — consolidation is key-agnostic
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestConsolidationUnaffectedByUnifiedSession:
|
||||||
|
"""maybe_consolidate_by_tokens() behaviour is identical regardless of session key."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_consolidation_skips_empty_session_for_unified_key(self):
|
||||||
|
"""Empty unified:default session → consolidation exits immediately, archive not called."""
|
||||||
|
from nanobot.agent.memory import Consolidator, MemoryStore
|
||||||
|
|
||||||
|
store = MagicMock(spec=MemoryStore)
|
||||||
|
mock_provider = MagicMock()
|
||||||
|
mock_provider.chat_with_retry = AsyncMock(return_value=MagicMock(content="summary"))
|
||||||
|
# Use spec= so MagicMock doesn't auto-generate AsyncMock for non-async methods,
|
||||||
|
# which would leave unawaited coroutines and trigger RuntimeWarning.
|
||||||
|
sessions = MagicMock(spec=SessionManager)
|
||||||
|
|
||||||
|
consolidator = Consolidator(
|
||||||
|
store=store,
|
||||||
|
provider=mock_provider,
|
||||||
|
model="test-model",
|
||||||
|
sessions=sessions,
|
||||||
|
context_window_tokens=1000,
|
||||||
|
build_messages=MagicMock(return_value=[]),
|
||||||
|
get_tool_definitions=MagicMock(return_value=[]),
|
||||||
|
max_completion_tokens=100,
|
||||||
|
)
|
||||||
|
consolidator.archive = AsyncMock()
|
||||||
|
|
||||||
|
session = Session(key="unified:default")
|
||||||
|
session.messages = []
|
||||||
|
|
||||||
|
await consolidator.maybe_consolidate_by_tokens(session)
|
||||||
|
|
||||||
|
consolidator.archive.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_consolidation_behaviour_identical_for_any_key(self):
|
||||||
|
"""archive call count is the same for 'telegram:123' and 'unified:default'
|
||||||
|
under identical token conditions."""
|
||||||
|
from nanobot.agent.memory import Consolidator, MemoryStore
|
||||||
|
|
||||||
|
archive_calls: dict[str, int] = {}
|
||||||
|
|
||||||
|
for key in ("telegram:123", "unified:default"):
|
||||||
|
store = MagicMock(spec=MemoryStore)
|
||||||
|
mock_provider = MagicMock()
|
||||||
|
mock_provider.chat_with_retry = AsyncMock(return_value=MagicMock(content="summary"))
|
||||||
|
sessions = MagicMock(spec=SessionManager)
|
||||||
|
|
||||||
|
consolidator = Consolidator(
|
||||||
|
store=store,
|
||||||
|
provider=mock_provider,
|
||||||
|
model="test-model",
|
||||||
|
sessions=sessions,
|
||||||
|
context_window_tokens=1000,
|
||||||
|
build_messages=MagicMock(return_value=[]),
|
||||||
|
get_tool_definitions=MagicMock(return_value=[]),
|
||||||
|
max_completion_tokens=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
session = Session(key=key)
|
||||||
|
session.messages = [] # empty → exits immediately for both keys
|
||||||
|
|
||||||
|
consolidator.archive = AsyncMock()
|
||||||
|
await consolidator.maybe_consolidate_by_tokens(session)
|
||||||
|
archive_calls[key] = consolidator.archive.call_count
|
||||||
|
|
||||||
|
assert archive_calls["telegram:123"] == archive_calls["unified:default"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_consolidation_triggers_when_over_budget_unified_key(self):
|
||||||
|
"""When tokens exceed budget, consolidation attempts to find a boundary —
|
||||||
|
behaviour is identical to any other session key."""
|
||||||
|
from nanobot.agent.memory import Consolidator, MemoryStore
|
||||||
|
|
||||||
|
store = MagicMock(spec=MemoryStore)
|
||||||
|
mock_provider = MagicMock()
|
||||||
|
sessions = MagicMock(spec=SessionManager)
|
||||||
|
|
||||||
|
consolidator = Consolidator(
|
||||||
|
store=store,
|
||||||
|
provider=mock_provider,
|
||||||
|
model="test-model",
|
||||||
|
sessions=sessions,
|
||||||
|
context_window_tokens=1000,
|
||||||
|
build_messages=MagicMock(return_value=[]),
|
||||||
|
get_tool_definitions=MagicMock(return_value=[]),
|
||||||
|
max_completion_tokens=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
session = Session(key="unified:default")
|
||||||
|
session.messages = [{"role": "user", "content": "msg"}]
|
||||||
|
|
||||||
|
# Simulate over-budget: estimated > budget
|
||||||
|
consolidator.estimate_session_prompt_tokens = MagicMock(return_value=(950, "tiktoken"))
|
||||||
|
# No valid boundary found → returns gracefully without archiving
|
||||||
|
consolidator.pick_consolidation_boundary = MagicMock(return_value=None)
|
||||||
|
consolidator.archive = AsyncMock()
|
||||||
|
|
||||||
|
await consolidator.maybe_consolidate_by_tokens(session)
|
||||||
|
|
||||||
|
# estimate was called (consolidation was attempted)
|
||||||
|
consolidator.estimate_session_prompt_tokens.assert_called_once_with(session)
|
||||||
|
# but archive was not called (no valid boundary)
|
||||||
|
consolidator.archive.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# TestStopCommandWithUnifiedSession — /stop command integration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestStopCommandWithUnifiedSession:
|
||||||
|
"""Verify /stop command works correctly with unified session enabled."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_active_tasks_use_effective_key_in_unified_mode(self, tmp_path: Path):
|
||||||
|
"""When unified_session=True, tasks are stored under UNIFIED_SESSION_KEY."""
|
||||||
|
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
||||||
|
|
||||||
|
loop = _make_loop(tmp_path, unified_session=True)
|
||||||
|
|
||||||
|
# Create a message from telegram channel
|
||||||
|
msg = _make_msg(channel="telegram", chat_id="123456")
|
||||||
|
|
||||||
|
# Mock _dispatch to complete immediately
|
||||||
|
async def fake_dispatch(m):
|
||||||
|
pass
|
||||||
|
|
||||||
|
loop._dispatch = fake_dispatch # type: ignore[method-assign]
|
||||||
|
|
||||||
|
# Simulate the task creation flow (from _run loop)
|
||||||
|
effective_key = UNIFIED_SESSION_KEY if loop._unified_session and not msg.session_key_override else msg.session_key
|
||||||
|
task = asyncio.create_task(loop._dispatch(msg))
|
||||||
|
loop._active_tasks.setdefault(effective_key, []).append(task)
|
||||||
|
|
||||||
|
# Wait for task to complete
|
||||||
|
await task
|
||||||
|
|
||||||
|
# Verify the task is stored under UNIFIED_SESSION_KEY, not the original channel:chat_id
|
||||||
|
assert UNIFIED_SESSION_KEY in loop._active_tasks
|
||||||
|
assert "telegram:123456" not in loop._active_tasks
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_command_finds_task_in_unified_mode(self, tmp_path: Path):
|
||||||
|
"""cmd_stop can cancel tasks when unified_session=True."""
|
||||||
|
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
||||||
|
from nanobot.command.builtin import cmd_stop
|
||||||
|
|
||||||
|
loop = _make_loop(tmp_path, unified_session=True)
|
||||||
|
|
||||||
|
# Create a long-running task stored under UNIFIED_SESSION_KEY
|
||||||
|
async def long_running():
|
||||||
|
await asyncio.sleep(10) # Will be cancelled
|
||||||
|
|
||||||
|
task = asyncio.create_task(long_running())
|
||||||
|
loop._active_tasks[UNIFIED_SESSION_KEY] = [task]
|
||||||
|
|
||||||
|
# Create a message that would have session_key=UNIFIED_SESSION_KEY after dispatch
|
||||||
|
msg = InboundMessage(
|
||||||
|
channel="telegram",
|
||||||
|
chat_id="123456",
|
||||||
|
sender_id="user1",
|
||||||
|
content="/stop",
|
||||||
|
session_key_override=UNIFIED_SESSION_KEY, # Simulate post-dispatch state
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx = CommandContext(msg=msg, session=None, key=UNIFIED_SESSION_KEY, raw="/stop", loop=loop)
|
||||||
|
|
||||||
|
# Execute /stop
|
||||||
|
result = await cmd_stop(ctx)
|
||||||
|
|
||||||
|
# Verify task was cancelled
|
||||||
|
assert task.cancelled() or task.done()
|
||||||
|
assert "Stopped 1 task" in result.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_command_cross_channel_in_unified_mode(self, tmp_path: Path):
|
||||||
|
"""In unified mode, /stop from one channel cancels tasks from another channel."""
|
||||||
|
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
||||||
|
from nanobot.command.builtin import cmd_stop
|
||||||
|
|
||||||
|
loop = _make_loop(tmp_path, unified_session=True)
|
||||||
|
|
||||||
|
# Create tasks from different channels, all stored under UNIFIED_SESSION_KEY
|
||||||
|
async def long_running():
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
|
||||||
|
task1 = asyncio.create_task(long_running())
|
||||||
|
task2 = asyncio.create_task(long_running())
|
||||||
|
loop._active_tasks[UNIFIED_SESSION_KEY] = [task1, task2]
|
||||||
|
|
||||||
|
# /stop from discord should cancel tasks started from telegram
|
||||||
|
msg = InboundMessage(
|
||||||
|
channel="discord",
|
||||||
|
chat_id="789012",
|
||||||
|
sender_id="user2",
|
||||||
|
content="/stop",
|
||||||
|
session_key_override=UNIFIED_SESSION_KEY,
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx = CommandContext(msg=msg, session=None, key=UNIFIED_SESSION_KEY, raw="/stop", loop=loop)
|
||||||
|
|
||||||
|
result = await cmd_stop(ctx)
|
||||||
|
|
||||||
|
# Both tasks should be cancelled
|
||||||
|
assert "Stopped 2 task" in result.content
|
||||||
@ -1,6 +1,10 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import zipfile
|
||||||
|
from io import BytesIO
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
# Check optional dingtalk dependencies before running tests
|
# Check optional dingtalk dependencies before running tests
|
||||||
@ -50,6 +54,21 @@ class _FakeHttp:
|
|||||||
return self._next_response()
|
return self._next_response()
|
||||||
|
|
||||||
|
|
||||||
|
class _NetworkErrorHttp:
|
||||||
|
"""HTTP client stub that raises httpx.TransportError on every request."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.calls: list[dict] = []
|
||||||
|
|
||||||
|
async def post(self, url: str, json=None, headers=None, **kwargs):
|
||||||
|
self.calls.append({"method": "POST", "url": url, "json": json, "headers": headers})
|
||||||
|
raise httpx.ConnectError("Connection refused")
|
||||||
|
|
||||||
|
async def get(self, url: str, **kwargs):
|
||||||
|
self.calls.append({"method": "GET", "url": url})
|
||||||
|
raise httpx.ConnectError("Connection refused")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_group_message_keeps_sender_id_and_routes_chat_id() -> None:
|
async def test_group_message_keeps_sender_id_and_routes_chat_id() -> None:
|
||||||
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"])
|
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"])
|
||||||
@ -221,3 +240,216 @@ async def test_download_dingtalk_file(tmp_path, monkeypatch) -> None:
|
|||||||
assert "messageFiles/download" in channel._http.calls[0]["url"]
|
assert "messageFiles/download" in channel._http.calls[0]["url"]
|
||||||
assert channel._http.calls[0]["json"]["downloadCode"] == "code123"
|
assert channel._http.calls[0]["json"]["downloadCode"] == "code123"
|
||||||
assert channel._http.calls[1]["method"] == "GET"
|
assert channel._http.calls[1]["method"] == "GET"
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_upload_payload_zips_html_attachment() -> None:
|
||||||
|
channel = DingTalkChannel(
|
||||||
|
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
|
||||||
|
data, filename, content_type = channel._normalize_upload_payload(
|
||||||
|
"report.html",
|
||||||
|
b"<html><body>Hello</body></html>",
|
||||||
|
"text/html",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert filename == "report.zip"
|
||||||
|
assert content_type == "application/zip"
|
||||||
|
|
||||||
|
archive = zipfile.ZipFile(BytesIO(data))
|
||||||
|
assert archive.namelist() == ["report.html"]
|
||||||
|
assert archive.read("report.html") == b"<html><body>Hello</body></html>"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_media_ref_zips_html_before_upload(tmp_path, monkeypatch) -> None:
|
||||||
|
channel = DingTalkChannel(
|
||||||
|
DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
|
||||||
|
html_path = tmp_path / "report.html"
|
||||||
|
html_path.write_text("<html><body>Hello</body></html>", encoding="utf-8")
|
||||||
|
|
||||||
|
captured: dict[str, object] = {}
|
||||||
|
|
||||||
|
async def fake_upload_media(*, token, data, media_type, filename, content_type):
|
||||||
|
captured.update(
|
||||||
|
{
|
||||||
|
"token": token,
|
||||||
|
"data": data,
|
||||||
|
"media_type": media_type,
|
||||||
|
"filename": filename,
|
||||||
|
"content_type": content_type,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return "media-123"
|
||||||
|
|
||||||
|
async def fake_send_batch_message(token, chat_id, msg_key, msg_param):
|
||||||
|
captured.update(
|
||||||
|
{
|
||||||
|
"sent_token": token,
|
||||||
|
"chat_id": chat_id,
|
||||||
|
"msg_key": msg_key,
|
||||||
|
"msg_param": msg_param,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
monkeypatch.setattr(channel, "_upload_media", fake_upload_media)
|
||||||
|
monkeypatch.setattr(channel, "_send_batch_message", fake_send_batch_message)
|
||||||
|
|
||||||
|
ok = await channel._send_media_ref("token-123", "user-1", str(html_path))
|
||||||
|
|
||||||
|
assert ok is True
|
||||||
|
assert captured["media_type"] == "file"
|
||||||
|
assert captured["filename"] == "report.zip"
|
||||||
|
assert captured["content_type"] == "application/zip"
|
||||||
|
assert captured["msg_key"] == "sampleFile"
|
||||||
|
assert captured["msg_param"] == {
|
||||||
|
"mediaId": "media-123",
|
||||||
|
"fileName": "report.zip",
|
||||||
|
"fileType": "zip",
|
||||||
|
}
|
||||||
|
|
||||||
|
archive = zipfile.ZipFile(BytesIO(captured["data"]))
|
||||||
|
assert archive.namelist() == ["report.html"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Exception handling tests ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_batch_message_propagates_transport_error() -> None:
|
||||||
|
"""Network/transport errors must re-raise so callers can retry."""
|
||||||
|
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
|
||||||
|
channel = DingTalkChannel(config, MessageBus())
|
||||||
|
channel._http = _NetworkErrorHttp()
|
||||||
|
|
||||||
|
with pytest.raises(httpx.ConnectError, match="Connection refused"):
|
||||||
|
await channel._send_batch_message(
|
||||||
|
"token",
|
||||||
|
"user123",
|
||||||
|
"sampleMarkdown",
|
||||||
|
{"text": "hello", "title": "Nanobot Reply"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# The POST was attempted exactly once
|
||||||
|
assert len(channel._http.calls) == 1
|
||||||
|
assert channel._http.calls[0]["method"] == "POST"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_batch_message_returns_false_on_api_error() -> None:
|
||||||
|
"""DingTalk API-level errors (non-200 status, errcode != 0) should return False."""
|
||||||
|
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
|
||||||
|
channel = DingTalkChannel(config, MessageBus())
|
||||||
|
|
||||||
|
# Non-200 status code → API error → return False
|
||||||
|
channel._http = _FakeHttp(responses=[_FakeResponse(400, {"errcode": 400})])
|
||||||
|
result = await channel._send_batch_message(
|
||||||
|
"token", "user123", "sampleMarkdown", {"text": "hello"}
|
||||||
|
)
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
# 200 with non-zero errcode → API error → return False
|
||||||
|
channel._http = _FakeHttp(responses=[_FakeResponse(200, {"errcode": 100})])
|
||||||
|
result = await channel._send_batch_message(
|
||||||
|
"token", "user123", "sampleMarkdown", {"text": "hello"}
|
||||||
|
)
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
# 200 with errcode=0 → success → return True
|
||||||
|
channel._http = _FakeHttp(responses=[_FakeResponse(200, {"errcode": 0})])
|
||||||
|
result = await channel._send_batch_message(
|
||||||
|
"token", "user123", "sampleMarkdown", {"text": "hello"}
|
||||||
|
)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_media_ref_short_circuits_on_transport_error() -> None:
|
||||||
|
"""When the first send fails with a transport error, _send_media_ref must
|
||||||
|
re-raise immediately instead of trying download+upload+fallback."""
|
||||||
|
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
|
||||||
|
channel = DingTalkChannel(config, MessageBus())
|
||||||
|
channel._http = _NetworkErrorHttp()
|
||||||
|
|
||||||
|
# An image URL triggers the sampleImageMsg path first
|
||||||
|
with pytest.raises(httpx.ConnectError, match="Connection refused"):
|
||||||
|
await channel._send_media_ref("token", "user123", "https://example.com/photo.jpg")
|
||||||
|
|
||||||
|
# Only one POST should have been attempted — no download/upload/fallback
|
||||||
|
assert len(channel._http.calls) == 1
|
||||||
|
assert channel._http.calls[0]["method"] == "POST"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_media_ref_short_circuits_on_download_transport_error() -> None:
|
||||||
|
"""When the image URL send returns an API error (False) but the download
|
||||||
|
for the fallback hits a transport error, it must re-raise rather than
|
||||||
|
silently returning False."""
|
||||||
|
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
|
||||||
|
channel = DingTalkChannel(config, MessageBus())
|
||||||
|
|
||||||
|
# First POST (sampleImageMsg) returns API error → False, then GET (download) raises transport error
|
||||||
|
class _MixedHttp:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.calls: list[dict] = []
|
||||||
|
|
||||||
|
async def post(self, url, json=None, headers=None, **kwargs):
|
||||||
|
self.calls.append({"method": "POST", "url": url})
|
||||||
|
# API-level failure: 200 with errcode != 0
|
||||||
|
return _FakeResponse(200, {"errcode": 100})
|
||||||
|
|
||||||
|
async def get(self, url, **kwargs):
|
||||||
|
self.calls.append({"method": "GET", "url": url})
|
||||||
|
raise httpx.ConnectError("Connection refused")
|
||||||
|
|
||||||
|
channel._http = _MixedHttp()
|
||||||
|
|
||||||
|
with pytest.raises(httpx.ConnectError, match="Connection refused"):
|
||||||
|
await channel._send_media_ref("token", "user123", "https://example.com/photo.jpg")
|
||||||
|
|
||||||
|
# Should have attempted POST (image URL) and GET (download), but NOT upload
|
||||||
|
assert len(channel._http.calls) == 2
|
||||||
|
assert channel._http.calls[0]["method"] == "POST"
|
||||||
|
assert channel._http.calls[1]["method"] == "GET"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_media_ref_short_circuits_on_upload_transport_error() -> None:
|
||||||
|
"""When download succeeds but upload hits a transport error, must re-raise."""
|
||||||
|
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
|
||||||
|
channel = DingTalkChannel(config, MessageBus())
|
||||||
|
|
||||||
|
image_bytes = b"\xff\xd8\xff\xe0" + b"\x00" * 100 # minimal JPEG-ish data
|
||||||
|
|
||||||
|
class _UploadFailsHttp:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.calls: list[dict] = []
|
||||||
|
|
||||||
|
async def post(self, url, json=None, headers=None, files=None, **kwargs):
|
||||||
|
self.calls.append({"method": "POST", "url": url})
|
||||||
|
# If it's the upload endpoint, raise transport error
|
||||||
|
if "media/upload" in url:
|
||||||
|
raise httpx.ConnectError("Connection refused")
|
||||||
|
# Otherwise (sampleImageMsg), return API error to trigger fallback
|
||||||
|
return _FakeResponse(200, {"errcode": 100})
|
||||||
|
|
||||||
|
async def get(self, url, **kwargs):
|
||||||
|
self.calls.append({"method": "GET", "url": url})
|
||||||
|
resp = _FakeResponse(200)
|
||||||
|
resp.content = image_bytes
|
||||||
|
resp.headers = {"content-type": "image/jpeg"}
|
||||||
|
return resp
|
||||||
|
|
||||||
|
channel._http = _UploadFailsHttp()
|
||||||
|
|
||||||
|
with pytest.raises(httpx.ConnectError, match="Connection refused"):
|
||||||
|
await channel._send_media_ref("token", "user123", "https://example.com/photo.jpg")
|
||||||
|
|
||||||
|
# POST (image URL), GET (download), POST (upload) attempted — no further sends
|
||||||
|
methods = [c["method"] for c in channel._http.calls]
|
||||||
|
assert methods == ["POST", "GET", "POST"]
|
||||||
|
|||||||
@ -5,11 +5,17 @@ from pathlib import Path
|
|||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
discord = pytest.importorskip("discord")
|
discord = pytest.importorskip("discord")
|
||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.discord import DiscordBotClient, DiscordChannel, DiscordConfig
|
from nanobot.channels.discord import (
|
||||||
|
MAX_MESSAGE_LEN,
|
||||||
|
DiscordBotClient,
|
||||||
|
DiscordChannel,
|
||||||
|
DiscordConfig,
|
||||||
|
)
|
||||||
from nanobot.command.builtin import build_help_text
|
from nanobot.command.builtin import build_help_text
|
||||||
|
|
||||||
|
|
||||||
@ -18,9 +24,11 @@ class _FakeDiscordClient:
|
|||||||
instances: list["_FakeDiscordClient"] = []
|
instances: list["_FakeDiscordClient"] = []
|
||||||
start_error: Exception | None = None
|
start_error: Exception | None = None
|
||||||
|
|
||||||
def __init__(self, owner, *, intents) -> None:
|
def __init__(self, owner, *, intents, proxy=None, proxy_auth=None) -> None:
|
||||||
self.owner = owner
|
self.owner = owner
|
||||||
self.intents = intents
|
self.intents = intents
|
||||||
|
self.proxy = proxy
|
||||||
|
self.proxy_auth = proxy_auth
|
||||||
self.closed = False
|
self.closed = False
|
||||||
self.ready = True
|
self.ready = True
|
||||||
self.channels: dict[int, object] = {}
|
self.channels: dict[int, object] = {}
|
||||||
@ -53,7 +61,9 @@ class _FakeDiscordClient:
|
|||||||
|
|
||||||
class _FakeAttachment:
|
class _FakeAttachment:
|
||||||
# Attachment double that can simulate successful or failing save() calls.
|
# Attachment double that can simulate successful or failing save() calls.
|
||||||
def __init__(self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False) -> None:
|
def __init__(
|
||||||
|
self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False
|
||||||
|
) -> None:
|
||||||
self.id = attachment_id
|
self.id = attachment_id
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
self.size = size
|
self.size = size
|
||||||
@ -71,11 +81,25 @@ class _FakePartialMessage:
|
|||||||
self.id = message_id
|
self.id = message_id
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSentMessage:
|
||||||
|
# Sent-message double supporting edit() for streaming tests.
|
||||||
|
def __init__(self, channel, content: str) -> None:
|
||||||
|
self.channel = channel
|
||||||
|
self.content = content
|
||||||
|
self.edits: list[dict] = []
|
||||||
|
|
||||||
|
async def edit(self, **kwargs) -> None:
|
||||||
|
self.edits.append(dict(kwargs))
|
||||||
|
if "content" in kwargs:
|
||||||
|
self.content = kwargs["content"]
|
||||||
|
|
||||||
|
|
||||||
class _FakeChannel:
|
class _FakeChannel:
|
||||||
# Channel double that records outbound payloads and typing activity.
|
# Channel double that records outbound payloads and typing activity.
|
||||||
def __init__(self, channel_id: int = 123) -> None:
|
def __init__(self, channel_id: int = 123) -> None:
|
||||||
self.id = channel_id
|
self.id = channel_id
|
||||||
self.sent_payloads: list[dict] = []
|
self.sent_payloads: list[dict] = []
|
||||||
|
self.sent_messages: list[_FakeSentMessage] = []
|
||||||
self.trigger_typing_calls = 0
|
self.trigger_typing_calls = 0
|
||||||
self.typing_enter_hook = None
|
self.typing_enter_hook = None
|
||||||
|
|
||||||
@ -85,6 +109,9 @@ class _FakeChannel:
|
|||||||
payload["file_name"] = payload["file"].filename
|
payload["file_name"] = payload["file"].filename
|
||||||
del payload["file"]
|
del payload["file"]
|
||||||
self.sent_payloads.append(payload)
|
self.sent_payloads.append(payload)
|
||||||
|
message = _FakeSentMessage(self, payload.get("content", ""))
|
||||||
|
self.sent_messages.append(message)
|
||||||
|
return message
|
||||||
|
|
||||||
def get_partial_message(self, message_id: int) -> _FakePartialMessage:
|
def get_partial_message(self, message_id: int) -> _FakePartialMessage:
|
||||||
return _FakePartialMessage(message_id)
|
return _FakePartialMessage(message_id)
|
||||||
@ -194,7 +221,7 @@ async def test_start_handles_client_construction_failure(monkeypatch) -> None:
|
|||||||
MessageBus(),
|
MessageBus(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _boom(owner, *, intents):
|
def _boom(owner, *, intents, proxy=None, proxy_auth=None):
|
||||||
raise RuntimeError("bad client")
|
raise RuntimeError("bad client")
|
||||||
|
|
||||||
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _boom)
|
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _boom)
|
||||||
@ -427,6 +454,60 @@ async def test_send_fetches_channel_when_not_cached() -> None:
|
|||||||
assert target.sent_payloads == [{"content": "hello"}]
|
assert target.sent_payloads == [{"content": "hello"}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_supports_streaming_enabled_by_default() -> None:
|
||||||
|
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||||
|
|
||||||
|
assert channel.supports_streaming is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_delta_streams_by_editing_message(monkeypatch) -> None:
|
||||||
|
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||||
|
client = _FakeDiscordClient(owner, intents=None)
|
||||||
|
owner._client = client
|
||||||
|
owner._running = True
|
||||||
|
target = _FakeChannel(channel_id=123)
|
||||||
|
client.channels[123] = target
|
||||||
|
|
||||||
|
times = iter([1.0, 3.0, 5.0])
|
||||||
|
monkeypatch.setattr("nanobot.channels.discord.time.monotonic", lambda: next(times, 5.0))
|
||||||
|
|
||||||
|
await owner.send_delta("123", "hel", {"_stream_delta": True, "_stream_id": "s1"})
|
||||||
|
await owner.send_delta("123", "lo", {"_stream_delta": True, "_stream_id": "s1"})
|
||||||
|
await owner.send_delta("123", "", {"_stream_end": True, "_stream_id": "s1"})
|
||||||
|
|
||||||
|
assert target.sent_payloads[0] == {"content": "hel"}
|
||||||
|
assert target.sent_messages[0].edits == [{"content": "hello"}, {"content": "hello"}]
|
||||||
|
assert owner._stream_bufs == {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_delta_stream_end_splits_oversized_reply(monkeypatch) -> None:
|
||||||
|
owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||||
|
client = _FakeDiscordClient(owner, intents=None)
|
||||||
|
owner._client = client
|
||||||
|
owner._running = True
|
||||||
|
target = _FakeChannel(channel_id=123)
|
||||||
|
client.channels[123] = target
|
||||||
|
|
||||||
|
prefix = "a" * (MAX_MESSAGE_LEN - 100)
|
||||||
|
suffix = "b" * 150
|
||||||
|
full_text = prefix + suffix
|
||||||
|
chunks = DiscordBotClient._build_chunks(full_text, [], False)
|
||||||
|
assert len(chunks) == 2
|
||||||
|
|
||||||
|
times = iter([1.0, 3.0])
|
||||||
|
monkeypatch.setattr("nanobot.channels.discord.time.monotonic", lambda: next(times, 3.0))
|
||||||
|
|
||||||
|
await owner.send_delta("123", prefix, {"_stream_delta": True, "_stream_id": "s1"})
|
||||||
|
await owner.send_delta("123", suffix, {"_stream_delta": True, "_stream_id": "s1"})
|
||||||
|
await owner.send_delta("123", "", {"_stream_end": True, "_stream_id": "s1"})
|
||||||
|
|
||||||
|
assert target.sent_payloads == [{"content": prefix}, {"content": chunks[1]}]
|
||||||
|
assert target.sent_messages[0].edits == [{"content": chunks[0]}, {"content": chunks[0]}]
|
||||||
|
assert owner._stream_bufs == {}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_slash_new_forwards_when_user_is_allowlisted() -> None:
|
async def test_slash_new_forwards_when_user_is_allowlisted() -> None:
|
||||||
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus())
|
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus())
|
||||||
@ -443,9 +524,7 @@ async def test_slash_new_forwards_when_user_is_allowlisted() -> None:
|
|||||||
assert new_cmd is not None
|
assert new_cmd is not None
|
||||||
await new_cmd.callback(interaction)
|
await new_cmd.callback(interaction)
|
||||||
|
|
||||||
assert interaction.response.messages == [
|
assert interaction.response.messages == [{"content": "Processing /new...", "ephemeral": True}]
|
||||||
{"content": "Processing /new...", "ephemeral": True}
|
|
||||||
]
|
|
||||||
assert len(handled) == 1
|
assert len(handled) == 1
|
||||||
assert handled[0]["content"] == "/new"
|
assert handled[0]["content"] == "/new"
|
||||||
assert handled[0]["sender_id"] == "123"
|
assert handled[0]["sender_id"] == "123"
|
||||||
@ -519,9 +598,7 @@ async def test_slash_help_returns_ephemeral_help_text() -> None:
|
|||||||
assert help_cmd is not None
|
assert help_cmd is not None
|
||||||
await help_cmd.callback(interaction)
|
await help_cmd.callback(interaction)
|
||||||
|
|
||||||
assert interaction.response.messages == [
|
assert interaction.response.messages == [{"content": build_help_text(), "ephemeral": True}]
|
||||||
{"content": build_help_text(), "ephemeral": True}
|
|
||||||
]
|
|
||||||
assert handled == []
|
assert handled == []
|
||||||
|
|
||||||
|
|
||||||
@ -656,11 +733,13 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() ->
|
|||||||
def typing(self):
|
def typing(self):
|
||||||
async def _waiter():
|
async def _waiter():
|
||||||
await release.wait()
|
await release.wait()
|
||||||
|
|
||||||
# Hold the loop so task remains active until explicitly stopped.
|
# Hold the loop so task remains active until explicitly stopped.
|
||||||
class _Ctx(_TypingCtx):
|
class _Ctx(_TypingCtx):
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
await super().__aenter__()
|
await super().__aenter__()
|
||||||
await _waiter()
|
await _waiter()
|
||||||
|
|
||||||
return _Ctx()
|
return _Ctx()
|
||||||
|
|
||||||
typing_channel = _NoTriggerChannel(channel_id=123)
|
typing_channel = _NoTriggerChannel(channel_id=123)
|
||||||
@ -674,3 +753,214 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() ->
|
|||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
assert channel._typing_tasks == {}
|
assert channel._typing_tasks == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_accepts_proxy_fields() -> None:
|
||||||
|
config = DiscordConfig(
|
||||||
|
enabled=True,
|
||||||
|
token="token",
|
||||||
|
allow_from=["*"],
|
||||||
|
proxy="http://127.0.0.1:7890",
|
||||||
|
proxy_username="user",
|
||||||
|
proxy_password="pass",
|
||||||
|
)
|
||||||
|
assert config.proxy == "http://127.0.0.1:7890"
|
||||||
|
assert config.proxy_username == "user"
|
||||||
|
assert config.proxy_password == "pass"
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_proxy_defaults_to_none() -> None:
|
||||||
|
config = DiscordConfig(enabled=True, token="token", allow_from=["*"])
|
||||||
|
assert config.proxy is None
|
||||||
|
assert config.proxy_username is None
|
||||||
|
assert config.proxy_password is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_passes_proxy_to_client(monkeypatch) -> None:
|
||||||
|
_FakeDiscordClient.instances.clear()
|
||||||
|
channel = DiscordChannel(
|
||||||
|
DiscordConfig(
|
||||||
|
enabled=True,
|
||||||
|
token="token",
|
||||||
|
allow_from=["*"],
|
||||||
|
proxy="http://127.0.0.1:7890",
|
||||||
|
),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
|
||||||
|
|
||||||
|
await channel.start()
|
||||||
|
|
||||||
|
assert channel.is_running is False
|
||||||
|
assert len(_FakeDiscordClient.instances) == 1
|
||||||
|
assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890"
|
||||||
|
assert _FakeDiscordClient.instances[0].proxy_auth is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_passes_proxy_auth_when_credentials_provided(monkeypatch) -> None:
|
||||||
|
aiohttp = pytest.importorskip("aiohttp")
|
||||||
|
_FakeDiscordClient.instances.clear()
|
||||||
|
channel = DiscordChannel(
|
||||||
|
DiscordConfig(
|
||||||
|
enabled=True,
|
||||||
|
token="token",
|
||||||
|
allow_from=["*"],
|
||||||
|
proxy="http://127.0.0.1:7890",
|
||||||
|
proxy_username="user",
|
||||||
|
proxy_password="pass",
|
||||||
|
),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
|
||||||
|
|
||||||
|
await channel.start()
|
||||||
|
|
||||||
|
assert channel.is_running is False
|
||||||
|
assert len(_FakeDiscordClient.instances) == 1
|
||||||
|
assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890"
|
||||||
|
assert _FakeDiscordClient.instances[0].proxy_auth is not None
|
||||||
|
assert isinstance(_FakeDiscordClient.instances[0].proxy_auth, aiohttp.BasicAuth)
|
||||||
|
assert _FakeDiscordClient.instances[0].proxy_auth.login == "user"
|
||||||
|
assert _FakeDiscordClient.instances[0].proxy_auth.password == "pass"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_no_proxy_auth_when_only_username(monkeypatch) -> None:
|
||||||
|
_FakeDiscordClient.instances.clear()
|
||||||
|
channel = DiscordChannel(
|
||||||
|
DiscordConfig(
|
||||||
|
enabled=True,
|
||||||
|
token="token",
|
||||||
|
allow_from=["*"],
|
||||||
|
proxy="http://127.0.0.1:7890",
|
||||||
|
proxy_username="user",
|
||||||
|
),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
|
||||||
|
|
||||||
|
await channel.start()
|
||||||
|
|
||||||
|
assert channel.is_running is False
|
||||||
|
assert _FakeDiscordClient.instances[0].proxy_auth is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_no_proxy_auth_when_only_password(monkeypatch) -> None:
|
||||||
|
_FakeDiscordClient.instances.clear()
|
||||||
|
channel = DiscordChannel(
|
||||||
|
DiscordConfig(
|
||||||
|
enabled=True,
|
||||||
|
token="token",
|
||||||
|
allow_from=["*"],
|
||||||
|
proxy="http://127.0.0.1:7890",
|
||||||
|
proxy_password="pass",
|
||||||
|
),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)
|
||||||
|
|
||||||
|
await channel.start()
|
||||||
|
|
||||||
|
assert channel.is_running is False
|
||||||
|
assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890"
|
||||||
|
assert _FakeDiscordClient.instances[0].proxy_auth is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests for the send() exception propagation fix
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_re_raises_network_error() -> None:
|
||||||
|
"""Network errors during send must propagate so ChannelManager can retry."""
|
||||||
|
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||||
|
client = _FakeDiscordClient(channel, intents=None)
|
||||||
|
channel._client = client
|
||||||
|
channel._running = True
|
||||||
|
|
||||||
|
async def _failing_send_outbound(msg: OutboundMessage) -> None:
|
||||||
|
raise ConnectionError("network unreachable")
|
||||||
|
|
||||||
|
client.send_outbound = _failing_send_outbound # type: ignore[method-assign]
|
||||||
|
|
||||||
|
with pytest.raises(ConnectionError, match="network unreachable"):
|
||||||
|
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_re_raises_generic_exception() -> None:
|
||||||
|
"""Any exception from send_outbound must propagate, not be swallowed."""
|
||||||
|
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||||
|
client = _FakeDiscordClient(channel, intents=None)
|
||||||
|
channel._client = client
|
||||||
|
channel._running = True
|
||||||
|
|
||||||
|
async def _failing_send_outbound(msg: OutboundMessage) -> None:
|
||||||
|
raise RuntimeError("discord API failure")
|
||||||
|
|
||||||
|
client.send_outbound = _failing_send_outbound # type: ignore[method-assign]
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="discord API failure"):
|
||||||
|
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_still_stops_typing_on_error() -> None:
|
||||||
|
"""Typing cleanup must still run in the finally block even when send raises."""
|
||||||
|
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||||
|
client = _FakeDiscordClient(channel, intents=None)
|
||||||
|
channel._client = client
|
||||||
|
channel._running = True
|
||||||
|
|
||||||
|
# Start a typing task so we can verify it gets cleaned up
|
||||||
|
start = asyncio.Event()
|
||||||
|
release = asyncio.Event()
|
||||||
|
|
||||||
|
async def slow_typing() -> None:
|
||||||
|
start.set()
|
||||||
|
await release.wait()
|
||||||
|
|
||||||
|
typing_channel = _FakeChannel(channel_id=123)
|
||||||
|
typing_channel.typing_enter_hook = slow_typing
|
||||||
|
await channel._start_typing(typing_channel)
|
||||||
|
await asyncio.wait_for(start.wait(), timeout=1.0)
|
||||||
|
|
||||||
|
async def _failing_send_outbound(msg: OutboundMessage) -> None:
|
||||||
|
raise ConnectionError("timeout")
|
||||||
|
|
||||||
|
client.send_outbound = _failing_send_outbound # type: ignore[method-assign]
|
||||||
|
|
||||||
|
with pytest.raises(ConnectionError, match="timeout"):
|
||||||
|
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
|
||||||
|
|
||||||
|
release.set()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
# Typing should have been cleaned up by the finally block
|
||||||
|
assert channel._typing_tasks == {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_succeeds_normally() -> None:
|
||||||
|
"""Successful sends should work without raising."""
|
||||||
|
channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
|
||||||
|
client = _FakeDiscordClient(channel, intents=None)
|
||||||
|
channel._client = client
|
||||||
|
channel._running = True
|
||||||
|
|
||||||
|
sent_messages: list[OutboundMessage] = []
|
||||||
|
|
||||||
|
async def _capture_send_outbound(msg: OutboundMessage) -> None:
|
||||||
|
sent_messages.append(msg)
|
||||||
|
|
||||||
|
client.send_outbound = _capture_send_outbound # type: ignore[method-assign]
|
||||||
|
|
||||||
|
msg = OutboundMessage(channel="discord", chat_id="123", content="hello world")
|
||||||
|
await channel.send(msg)
|
||||||
|
|
||||||
|
assert len(sent_messages) == 1
|
||||||
|
assert sent_messages[0].content == "hello world"
|
||||||
|
assert sent_messages[0].chat_id == "123"
|
||||||
|
|||||||
48
tests/channels/test_feishu_domain.py
Normal file
48
tests/channels/test_feishu_domain.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
"""Tests for Feishu/Lark domain configuration."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.channels.feishu import FeishuChannel, FeishuConfig
|
||||||
|
|
||||||
|
|
||||||
|
def _make_channel(domain: str = "feishu") -> FeishuChannel:
|
||||||
|
config = FeishuConfig(
|
||||||
|
enabled=True,
|
||||||
|
app_id="cli_test",
|
||||||
|
app_secret="secret",
|
||||||
|
allow_from=["*"],
|
||||||
|
domain=domain,
|
||||||
|
)
|
||||||
|
ch = FeishuChannel(config, MessageBus())
|
||||||
|
ch._client = MagicMock()
|
||||||
|
ch._loop = None
|
||||||
|
return ch
|
||||||
|
|
||||||
|
|
||||||
|
class TestFeishuConfigDomain:
|
||||||
|
def test_domain_default_is_feishu(self):
|
||||||
|
config = FeishuConfig()
|
||||||
|
assert config.domain == "feishu"
|
||||||
|
|
||||||
|
def test_domain_accepts_lark(self):
|
||||||
|
config = FeishuConfig(domain="lark")
|
||||||
|
assert config.domain == "lark"
|
||||||
|
|
||||||
|
def test_domain_accepts_feishu(self):
|
||||||
|
config = FeishuConfig(domain="feishu")
|
||||||
|
assert config.domain == "feishu"
|
||||||
|
|
||||||
|
def test_default_config_includes_domain(self):
|
||||||
|
default_cfg = FeishuChannel.default_config()
|
||||||
|
assert "domain" in default_cfg
|
||||||
|
assert default_cfg["domain"] == "feishu"
|
||||||
|
|
||||||
|
def test_channel_persists_domain_from_config(self):
|
||||||
|
ch = _make_channel(domain="lark")
|
||||||
|
assert ch.config.domain == "lark"
|
||||||
|
|
||||||
|
def test_channel_persists_feishu_domain_from_config(self):
|
||||||
|
ch = _make_channel(domain="feishu")
|
||||||
|
assert ch.config.domain == "feishu"
|
||||||
@ -5,6 +5,7 @@ from unittest.mock import MagicMock
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.bus.events import OutboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.feishu import FeishuChannel, FeishuConfig, _FeishuStreamBuf
|
from nanobot.channels.feishu import FeishuChannel, FeishuConfig, _FeishuStreamBuf
|
||||||
|
|
||||||
@ -203,6 +204,24 @@ class TestSendDelta:
|
|||||||
ch._client.cardkit.v1.card_element.content.assert_not_called()
|
ch._client.cardkit.v1.card_element.content.assert_not_called()
|
||||||
ch._client.im.v1.message.create.assert_called_once()
|
ch._client.im.v1.message.create.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_end_fallback_when_final_update_fails(self):
|
||||||
|
"""If streaming mode was closed (e.g. Feishu timeout), fall back to a regular card."""
|
||||||
|
ch = _make_channel()
|
||||||
|
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
|
||||||
|
text="Lost content", card_id="card_1", sequence=3, last_edit=0.0,
|
||||||
|
)
|
||||||
|
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response(success=False)
|
||||||
|
ch._client.im.v1.message.create.return_value = _mock_send_response("om_fb")
|
||||||
|
|
||||||
|
await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True})
|
||||||
|
|
||||||
|
assert "oc_chat1" not in ch._stream_bufs
|
||||||
|
# Should NOT attempt to close streaming mode since update failed
|
||||||
|
ch._client.cardkit.v1.card.settings.assert_not_called()
|
||||||
|
# Should fall back to sending a regular interactive card
|
||||||
|
ch._client.im.v1.message.create.assert_called_once()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stream_end_without_buf_is_noop(self):
|
async def test_stream_end_without_buf_is_noop(self):
|
||||||
ch = _make_channel()
|
ch = _make_channel()
|
||||||
@ -239,6 +258,130 @@ class TestSendDelta:
|
|||||||
assert buf.sequence == 7
|
assert buf.sequence == 7
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolHintInlineStreaming:
|
||||||
|
"""Tool hint messages should be inlined into active streaming cards."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_hint_inlined_when_stream_active(self):
|
||||||
|
"""With an active streaming buffer, tool hint appends to the card."""
|
||||||
|
ch = _make_channel()
|
||||||
|
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
|
||||||
|
text="Partial answer", card_id="card_1", sequence=2, last_edit=0.0,
|
||||||
|
)
|
||||||
|
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
|
||||||
|
|
||||||
|
msg = OutboundMessage(
|
||||||
|
channel="feishu", chat_id="oc_chat1",
|
||||||
|
content='web_fetch("https://example.com")',
|
||||||
|
metadata={"_tool_hint": True},
|
||||||
|
)
|
||||||
|
await ch.send(msg)
|
||||||
|
|
||||||
|
buf = ch._stream_bufs["oc_chat1"]
|
||||||
|
assert '🔧 web_fetch("https://example.com")' in buf.text
|
||||||
|
assert buf.sequence == 3
|
||||||
|
ch._client.cardkit.v1.card_element.content.assert_called_once()
|
||||||
|
ch._client.im.v1.message.create.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_hint_preserved_on_next_delta(self):
|
||||||
|
"""When new delta arrives, the tool hint is kept as permanent content and delta appends after it."""
|
||||||
|
ch = _make_channel()
|
||||||
|
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
|
||||||
|
text="Partial answer\n\n🔧 web_fetch(\"url\")\n\n",
|
||||||
|
card_id="card_1", sequence=3, last_edit=0.0,
|
||||||
|
)
|
||||||
|
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
|
||||||
|
|
||||||
|
await ch.send_delta("oc_chat1", " continued")
|
||||||
|
|
||||||
|
buf = ch._stream_bufs["oc_chat1"]
|
||||||
|
assert "Partial answer" in buf.text
|
||||||
|
assert "🔧 web_fetch" in buf.text
|
||||||
|
assert buf.text.endswith(" continued")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_hint_fallback_when_no_stream(self):
|
||||||
|
"""Without an active buffer, tool hint falls back to a standalone card."""
|
||||||
|
ch = _make_channel()
|
||||||
|
ch._client.im.v1.message.create.return_value = _mock_send_response("om_hint")
|
||||||
|
|
||||||
|
msg = OutboundMessage(
|
||||||
|
channel="feishu", chat_id="oc_chat1",
|
||||||
|
content='read_file("path")',
|
||||||
|
metadata={"_tool_hint": True},
|
||||||
|
)
|
||||||
|
await ch.send(msg)
|
||||||
|
|
||||||
|
assert "oc_chat1" not in ch._stream_bufs
|
||||||
|
ch._client.im.v1.message.create.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_consecutive_tool_hints_append(self):
|
||||||
|
"""When multiple tool hints arrive consecutively, each appends to the card."""
|
||||||
|
ch = _make_channel()
|
||||||
|
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
|
||||||
|
text="Partial answer", card_id="card_1", sequence=2, last_edit=0.0,
|
||||||
|
)
|
||||||
|
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
|
||||||
|
|
||||||
|
msg1 = OutboundMessage(
|
||||||
|
channel="feishu", chat_id="oc_chat1",
|
||||||
|
content='$ cd /project', metadata={"_tool_hint": True},
|
||||||
|
)
|
||||||
|
await ch.send(msg1)
|
||||||
|
|
||||||
|
msg2 = OutboundMessage(
|
||||||
|
channel="feishu", chat_id="oc_chat1",
|
||||||
|
content='$ git status', metadata={"_tool_hint": True},
|
||||||
|
)
|
||||||
|
await ch.send(msg2)
|
||||||
|
|
||||||
|
buf = ch._stream_bufs["oc_chat1"]
|
||||||
|
assert "$ cd /project" in buf.text
|
||||||
|
assert "$ git status" in buf.text
|
||||||
|
assert buf.text.startswith("Partial answer")
|
||||||
|
assert "🔧 $ cd /project" in buf.text
|
||||||
|
assert "🔧 $ git status" in buf.text
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_hint_preserved_on_final_stream_end(self):
|
||||||
|
"""When final _stream_end closes the card, tool hint is kept in the final text."""
|
||||||
|
ch = _make_channel()
|
||||||
|
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
|
||||||
|
text="Final content\n\n🔧 web_fetch(\"url\")\n\n",
|
||||||
|
card_id="card_1", sequence=3, last_edit=0.0,
|
||||||
|
)
|
||||||
|
ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
|
||||||
|
ch._client.cardkit.v1.card.settings.return_value = _mock_content_response()
|
||||||
|
|
||||||
|
await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True})
|
||||||
|
|
||||||
|
assert "oc_chat1" not in ch._stream_bufs
|
||||||
|
update_call = ch._client.cardkit.v1.card_element.content.call_args[0][0]
|
||||||
|
assert "🔧" in update_call.body.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_tool_hint_is_noop(self):
|
||||||
|
"""Empty or whitespace-only tool hint content is silently ignored."""
|
||||||
|
ch = _make_channel()
|
||||||
|
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
|
||||||
|
text="Partial answer", card_id="card_1", sequence=2, last_edit=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
for content in ("", " ", "\t\n"):
|
||||||
|
msg = OutboundMessage(
|
||||||
|
channel="feishu", chat_id="oc_chat1",
|
||||||
|
content=content, metadata={"_tool_hint": True},
|
||||||
|
)
|
||||||
|
await ch.send(msg)
|
||||||
|
|
||||||
|
buf = ch._stream_bufs["oc_chat1"]
|
||||||
|
assert buf.text == "Partial answer"
|
||||||
|
assert buf.sequence == 2
|
||||||
|
ch._client.cardkit.v1.card_element.content.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
class TestSendMessageReturnsId:
|
class TestSendMessageReturnsId:
|
||||||
def test_returns_message_id_on_success(self):
|
def test_returns_message_id_on_success(self):
|
||||||
ch = _make_channel()
|
ch = _make_channel()
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
"""Tests for FeishuChannel tool hint code block formatting."""
|
"""Tests for FeishuChannel tool hint formatting."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from types import SimpleNamespace
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -28,15 +29,24 @@ def mock_feishu_channel():
|
|||||||
config.app_secret = "test_app_secret"
|
config.app_secret = "test_app_secret"
|
||||||
config.encrypt_key = None
|
config.encrypt_key = None
|
||||||
config.verification_token = None
|
config.verification_token = None
|
||||||
|
config.tool_hint_prefix = "\U0001f527" # 🔧
|
||||||
bus = MagicMock()
|
bus = MagicMock()
|
||||||
channel = FeishuChannel(config, bus)
|
channel = FeishuChannel(config, bus)
|
||||||
channel._client = MagicMock() # Simulate initialized client
|
channel._client = MagicMock()
|
||||||
return channel
|
return channel
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tool_hint_card(mock_send):
|
||||||
|
"""Extract the interactive card from _send_message_sync calls."""
|
||||||
|
call_args = mock_send.call_args[0]
|
||||||
|
_, _, msg_type, content = call_args
|
||||||
|
assert msg_type == "interactive"
|
||||||
|
return json.loads(content)
|
||||||
|
|
||||||
|
|
||||||
@mark.asyncio
|
@mark.asyncio
|
||||||
async def test_tool_hint_sends_code_message(mock_feishu_channel):
|
async def test_tool_hint_sends_interactive_card(mock_feishu_channel):
|
||||||
"""Tool hint messages should be sent as interactive cards with code blocks."""
|
"""Tool hint without active buffer sends an interactive card with 🔧 style."""
|
||||||
msg = OutboundMessage(
|
msg = OutboundMessage(
|
||||||
channel="feishu",
|
channel="feishu",
|
||||||
chat_id="oc_123456",
|
chat_id="oc_123456",
|
||||||
@ -47,23 +57,12 @@ async def test_tool_hint_sends_code_message(mock_feishu_channel):
|
|||||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||||
await mock_feishu_channel.send(msg)
|
await mock_feishu_channel.send(msg)
|
||||||
|
|
||||||
# Verify interactive message with card was sent
|
|
||||||
assert mock_send.call_count == 1
|
assert mock_send.call_count == 1
|
||||||
call_args = mock_send.call_args[0]
|
card = _get_tool_hint_card(mock_send)
|
||||||
receive_id_type, receive_id, msg_type, content = call_args
|
|
||||||
|
|
||||||
assert receive_id_type == "chat_id"
|
|
||||||
assert receive_id == "oc_123456"
|
|
||||||
assert msg_type == "interactive"
|
|
||||||
|
|
||||||
# Parse content to verify card structure
|
|
||||||
card = json.loads(content)
|
|
||||||
assert card["config"]["wide_screen_mode"] is True
|
assert card["config"]["wide_screen_mode"] is True
|
||||||
assert len(card["elements"]) == 1
|
md = card["elements"][0]["content"]
|
||||||
assert card["elements"][0]["tag"] == "markdown"
|
assert "\U0001f527" in md
|
||||||
# Check that code block is properly formatted with language hint
|
assert "web_search" in md
|
||||||
expected_md = "**Tool Calls**\n\n```text\nweb_search(\"test query\")\n```"
|
|
||||||
assert card["elements"][0]["content"] == expected_md
|
|
||||||
|
|
||||||
|
|
||||||
@mark.asyncio
|
@mark.asyncio
|
||||||
@ -78,8 +77,6 @@ async def test_tool_hint_empty_content_does_not_send(mock_feishu_channel):
|
|||||||
|
|
||||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||||
await mock_feishu_channel.send(msg)
|
await mock_feishu_channel.send(msg)
|
||||||
|
|
||||||
# Should not send any message
|
|
||||||
mock_send.assert_not_called()
|
mock_send.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
@ -96,7 +93,6 @@ async def test_tool_hint_without_metadata_sends_as_normal(mock_feishu_channel):
|
|||||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||||
await mock_feishu_channel.send(msg)
|
await mock_feishu_channel.send(msg)
|
||||||
|
|
||||||
# Should send as text message (detected format)
|
|
||||||
assert mock_send.call_count == 1
|
assert mock_send.call_count == 1
|
||||||
call_args = mock_send.call_args[0]
|
call_args = mock_send.call_args[0]
|
||||||
_, _, msg_type, content = call_args
|
_, _, msg_type, content = call_args
|
||||||
@ -106,7 +102,7 @@ async def test_tool_hint_without_metadata_sends_as_normal(mock_feishu_channel):
|
|||||||
|
|
||||||
@mark.asyncio
|
@mark.asyncio
|
||||||
async def test_tool_hint_multiple_tools_in_one_message(mock_feishu_channel):
|
async def test_tool_hint_multiple_tools_in_one_message(mock_feishu_channel):
|
||||||
"""Multiple tool calls should be displayed each on its own line in a code block."""
|
"""Multiple tool calls should each get the 🔧 prefix."""
|
||||||
msg = OutboundMessage(
|
msg = OutboundMessage(
|
||||||
channel="feishu",
|
channel="feishu",
|
||||||
chat_id="oc_123456",
|
chat_id="oc_123456",
|
||||||
@ -117,13 +113,11 @@ async def test_tool_hint_multiple_tools_in_one_message(mock_feishu_channel):
|
|||||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||||
await mock_feishu_channel.send(msg)
|
await mock_feishu_channel.send(msg)
|
||||||
|
|
||||||
call_args = mock_send.call_args[0]
|
card = _get_tool_hint_card(mock_send)
|
||||||
msg_type = call_args[2]
|
md = card["elements"][0]["content"]
|
||||||
content = json.loads(call_args[3])
|
assert "web_search" in md
|
||||||
assert msg_type == "interactive"
|
assert "read_file" in md
|
||||||
# Each tool call should be on its own line
|
assert "\U0001f527" in md
|
||||||
expected_md = "**Tool Calls**\n\n```text\nweb_search(\"query\"),\nread_file(\"/path/to/file\")\n```"
|
|
||||||
assert content["elements"][0]["content"] == expected_md
|
|
||||||
|
|
||||||
|
|
||||||
@mark.asyncio
|
@mark.asyncio
|
||||||
@ -139,8 +133,8 @@ async def test_tool_hint_new_format_basic(mock_feishu_channel):
|
|||||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||||
await mock_feishu_channel.send(msg)
|
await mock_feishu_channel.send(msg)
|
||||||
|
|
||||||
content = json.loads(mock_send.call_args[0][3])
|
card = _get_tool_hint_card(mock_send)
|
||||||
md = content["elements"][0]["content"]
|
md = card["elements"][0]["content"]
|
||||||
assert "read src/main.py" in md
|
assert "read src/main.py" in md
|
||||||
assert 'grep "TODO"' in md
|
assert 'grep "TODO"' in md
|
||||||
|
|
||||||
@ -158,16 +152,15 @@ async def test_tool_hint_new_format_with_comma_in_quotes(mock_feishu_channel):
|
|||||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||||
await mock_feishu_channel.send(msg)
|
await mock_feishu_channel.send(msg)
|
||||||
|
|
||||||
content = json.loads(mock_send.call_args[0][3])
|
card = _get_tool_hint_card(mock_send)
|
||||||
md = content["elements"][0]["content"]
|
md = card["elements"][0]["content"]
|
||||||
# The comma inside quotes should NOT cause a line break
|
|
||||||
assert 'grep "hello, world"' in md
|
assert 'grep "hello, world"' in md
|
||||||
assert "$ echo test" in md
|
assert "$ echo test" in md
|
||||||
|
|
||||||
|
|
||||||
@mark.asyncio
|
@mark.asyncio
|
||||||
async def test_tool_hint_new_format_with_folding(mock_feishu_channel):
|
async def test_tool_hint_new_format_with_folding(mock_feishu_channel):
|
||||||
"""Folded calls (× N) should display on separate lines."""
|
"""Folded calls (× N) should display correctly."""
|
||||||
msg = OutboundMessage(
|
msg = OutboundMessage(
|
||||||
channel="feishu",
|
channel="feishu",
|
||||||
chat_id="oc_123456",
|
chat_id="oc_123456",
|
||||||
@ -178,8 +171,8 @@ async def test_tool_hint_new_format_with_folding(mock_feishu_channel):
|
|||||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||||
await mock_feishu_channel.send(msg)
|
await mock_feishu_channel.send(msg)
|
||||||
|
|
||||||
content = json.loads(mock_send.call_args[0][3])
|
card = _get_tool_hint_card(mock_send)
|
||||||
md = content["elements"][0]["content"]
|
md = card["elements"][0]["content"]
|
||||||
assert "\u00d7 3" in md
|
assert "\u00d7 3" in md
|
||||||
assert 'grep "pattern"' in md
|
assert 'grep "pattern"' in md
|
||||||
|
|
||||||
@ -197,9 +190,12 @@ async def test_tool_hint_new_format_mcp(mock_feishu_channel):
|
|||||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||||
await mock_feishu_channel.send(msg)
|
await mock_feishu_channel.send(msg)
|
||||||
|
|
||||||
content = json.loads(mock_send.call_args[0][3])
|
card = _get_tool_hint_card(mock_send)
|
||||||
md = content["elements"][0]["content"]
|
md = card["elements"][0]["content"]
|
||||||
assert "4_5v::analyze_image" in md
|
assert "4_5v::analyze_image" in md
|
||||||
|
|
||||||
|
|
||||||
|
@mark.asyncio
|
||||||
async def test_tool_hint_keeps_commas_inside_arguments(mock_feishu_channel):
|
async def test_tool_hint_keeps_commas_inside_arguments(mock_feishu_channel):
|
||||||
"""Commas inside a single tool argument must not be split onto a new line."""
|
"""Commas inside a single tool argument must not be split onto a new line."""
|
||||||
msg = OutboundMessage(
|
msg = OutboundMessage(
|
||||||
@ -212,10 +208,7 @@ async def test_tool_hint_keeps_commas_inside_arguments(mock_feishu_channel):
|
|||||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||||
await mock_feishu_channel.send(msg)
|
await mock_feishu_channel.send(msg)
|
||||||
|
|
||||||
content = json.loads(mock_send.call_args[0][3])
|
card = _get_tool_hint_card(mock_send)
|
||||||
expected_md = (
|
md = card["elements"][0]["content"]
|
||||||
"**Tool Calls**\n\n```text\n"
|
assert 'web_search("foo, bar")' in md
|
||||||
"web_search(\"foo, bar\"),\n"
|
assert 'read_file("/path/to/file")' in md
|
||||||
"read_file(\"/path/to/file\")\n```"
|
|
||||||
)
|
|
||||||
assert content["elements"][0]["content"] == expected_md
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -14,6 +15,8 @@ except ImportError:
|
|||||||
if not QQ_AVAILABLE:
|
if not QQ_AVAILABLE:
|
||||||
pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True)
|
pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True)
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.qq import QQChannel, QQConfig
|
from nanobot.channels.qq import QQChannel, QQConfig
|
||||||
@ -170,3 +173,221 @@ async def test_read_media_bytes_missing_file() -> None:
|
|||||||
data, filename = await channel._read_media_bytes("/nonexistent/path/image.png")
|
data, filename = await channel._read_media_bytes("/nonexistent/path/image.png")
|
||||||
assert data is None
|
assert data is None
|
||||||
assert filename is None
|
assert filename is None
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------------------------------------------
|
||||||
|
# Tests for _send_media exception handling
|
||||||
|
# -------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_channel_with_local_file(suffix: str = ".png", content: bytes = b"\x89PNG\r\n"):
|
||||||
|
"""Create a QQChannel with a fake client and a temp file for media."""
|
||||||
|
channel = QQChannel(
|
||||||
|
QQConfig(app_id="app", secret="secret", allow_from=["*"]),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
channel._client = _FakeClient()
|
||||||
|
channel._chat_type_cache["user1"] = "c2c"
|
||||||
|
|
||||||
|
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
|
||||||
|
tmp.write(content)
|
||||||
|
tmp.close()
|
||||||
|
return channel, tmp.name
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_media_network_error_propagates() -> None:
|
||||||
|
"""aiohttp.ClientError (network/transport) should re-raise, not return False."""
|
||||||
|
channel, tmp_path = _make_channel_with_local_file()
|
||||||
|
|
||||||
|
# Make the base64 upload raise a network error
|
||||||
|
channel._client.api._http = SimpleNamespace()
|
||||||
|
channel._client.api._http.request = AsyncMock(
|
||||||
|
side_effect=aiohttp.ServerDisconnectedError("connection lost"),
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(aiohttp.ServerDisconnectedError):
|
||||||
|
await channel._send_media(
|
||||||
|
chat_id="user1",
|
||||||
|
media_ref=tmp_path,
|
||||||
|
msg_id="msg1",
|
||||||
|
is_group=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_media_client_connector_error_propagates() -> None:
|
||||||
|
"""aiohttp.ClientConnectorError (DNS/connection refused) should re-raise."""
|
||||||
|
channel, tmp_path = _make_channel_with_local_file()
|
||||||
|
|
||||||
|
from aiohttp.client_reqrep import ConnectionKey
|
||||||
|
conn_key = ConnectionKey("api.qq.com", 443, True, None, None, None, None)
|
||||||
|
connector_error = aiohttp.ClientConnectorError(
|
||||||
|
connection_key=conn_key,
|
||||||
|
os_error=OSError("Connection refused"),
|
||||||
|
)
|
||||||
|
|
||||||
|
channel._client.api._http = SimpleNamespace()
|
||||||
|
channel._client.api._http.request = AsyncMock(
|
||||||
|
side_effect=connector_error,
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(aiohttp.ClientConnectorError):
|
||||||
|
await channel._send_media(
|
||||||
|
chat_id="user1",
|
||||||
|
media_ref=tmp_path,
|
||||||
|
msg_id="msg1",
|
||||||
|
is_group=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_media_oserror_propagates() -> None:
|
||||||
|
"""OSError (low-level I/O) should re-raise for retry."""
|
||||||
|
channel, tmp_path = _make_channel_with_local_file()
|
||||||
|
|
||||||
|
channel._client.api._http = SimpleNamespace()
|
||||||
|
channel._client.api._http.request = AsyncMock(
|
||||||
|
side_effect=OSError("Network is unreachable"),
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(OSError):
|
||||||
|
await channel._send_media(
|
||||||
|
chat_id="user1",
|
||||||
|
media_ref=tmp_path,
|
||||||
|
msg_id="msg1",
|
||||||
|
is_group=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_media_api_error_returns_false() -> None:
|
||||||
|
"""API-level errors (botpy RuntimeError subclasses) should return False, not raise."""
|
||||||
|
channel, tmp_path = _make_channel_with_local_file()
|
||||||
|
|
||||||
|
# Simulate a botpy API error (e.g. ServerError is a RuntimeError subclass)
|
||||||
|
from botpy.errors import ServerError
|
||||||
|
|
||||||
|
channel._client.api._http = SimpleNamespace()
|
||||||
|
channel._client.api._http.request = AsyncMock(
|
||||||
|
side_effect=ServerError("internal server error"),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await channel._send_media(
|
||||||
|
chat_id="user1",
|
||||||
|
media_ref=tmp_path,
|
||||||
|
msg_id="msg1",
|
||||||
|
is_group=False,
|
||||||
|
)
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_media_generic_runtime_error_returns_false() -> None:
|
||||||
|
"""Generic RuntimeError (not network) should return False."""
|
||||||
|
channel, tmp_path = _make_channel_with_local_file()
|
||||||
|
|
||||||
|
channel._client.api._http = SimpleNamespace()
|
||||||
|
channel._client.api._http.request = AsyncMock(
|
||||||
|
side_effect=RuntimeError("some API error"),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await channel._send_media(
|
||||||
|
chat_id="user1",
|
||||||
|
media_ref=tmp_path,
|
||||||
|
msg_id="msg1",
|
||||||
|
is_group=False,
|
||||||
|
)
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_media_value_error_returns_false() -> None:
|
||||||
|
"""ValueError (bad API response data) should return False."""
|
||||||
|
channel, tmp_path = _make_channel_with_local_file()
|
||||||
|
|
||||||
|
channel._client.api._http = SimpleNamespace()
|
||||||
|
channel._client.api._http.request = AsyncMock(
|
||||||
|
side_effect=ValueError("bad response data"),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await channel._send_media(
|
||||||
|
chat_id="user1",
|
||||||
|
media_ref=tmp_path,
|
||||||
|
msg_id="msg1",
|
||||||
|
is_group=False,
|
||||||
|
)
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_media_timeout_error_propagates() -> None:
|
||||||
|
"""asyncio.TimeoutError inherits from Exception but not ClientError/OSError.
|
||||||
|
However, aiohttp.ServerTimeoutError IS a ClientError subclass, so that propagates.
|
||||||
|
For a plain asyncio.TimeoutError (an alias of the builtin TimeoutError, and therefore an OSError, on Python 3.11+), it should propagate too."""
|
||||||
|
channel, tmp_path = _make_channel_with_local_file()
|
||||||
|
|
||||||
|
channel._client.api._http = SimpleNamespace()
|
||||||
|
channel._client.api._http.request = AsyncMock(
|
||||||
|
side_effect=aiohttp.ServerTimeoutError("request timed out"),
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(aiohttp.ServerTimeoutError):
|
||||||
|
await channel._send_media(
|
||||||
|
chat_id="user1",
|
||||||
|
media_ref=tmp_path,
|
||||||
|
msg_id="msg1",
|
||||||
|
is_group=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_fallback_text_on_api_error() -> None:
|
||||||
|
"""When _send_media returns False (API error), send() should emit fallback text."""
|
||||||
|
channel, tmp_path = _make_channel_with_local_file()
|
||||||
|
|
||||||
|
from botpy.errors import ServerError
|
||||||
|
|
||||||
|
channel._client.api._http = SimpleNamespace()
|
||||||
|
channel._client.api._http.request = AsyncMock(
|
||||||
|
side_effect=ServerError("internal server error"),
|
||||||
|
)
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
OutboundMessage(
|
||||||
|
channel="qq",
|
||||||
|
chat_id="user1",
|
||||||
|
content="",
|
||||||
|
media=[tmp_path],
|
||||||
|
metadata={"message_id": "msg1"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should have sent a fallback text message
|
||||||
|
assert len(channel._client.api.c2c_calls) == 1
|
||||||
|
fallback_content = channel._client.api.c2c_calls[0]["content"]
|
||||||
|
assert "Attachment send failed" in fallback_content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_propagates_network_error_no_fallback() -> None:
|
||||||
|
"""When _send_media raises a network error, send() should NOT silently fallback."""
|
||||||
|
channel, tmp_path = _make_channel_with_local_file()
|
||||||
|
|
||||||
|
channel._client.api._http = SimpleNamespace()
|
||||||
|
channel._client.api._http.request = AsyncMock(
|
||||||
|
side_effect=aiohttp.ServerDisconnectedError("connection lost"),
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(aiohttp.ServerDisconnectedError):
|
||||||
|
await channel.send(
|
||||||
|
OutboundMessage(
|
||||||
|
channel="qq",
|
||||||
|
chat_id="user1",
|
||||||
|
content="hello",
|
||||||
|
media=[tmp_path],
|
||||||
|
metadata={"message_id": "msg1"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# No fallback text should have been sent
|
||||||
|
assert len(channel._client.api.c2c_calls) == 0
|
||||||
|
|||||||
304
tests/channels/test_qq_media.py
Normal file
304
tests/channels/test_qq_media.py
Normal file
@ -0,0 +1,304 @@
|
|||||||
|
"""Tests for QQ channel media support: helpers, send, inbound, and upload."""
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
try:
|
||||||
|
from nanobot.channels import qq
|
||||||
|
|
||||||
|
QQ_AVAILABLE = getattr(qq, "QQ_AVAILABLE", False)
|
||||||
|
except ImportError:
|
||||||
|
QQ_AVAILABLE = False
|
||||||
|
|
||||||
|
if not QQ_AVAILABLE:
|
||||||
|
pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True)
|
||||||
|
|
||||||
|
from nanobot.bus.events import OutboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.channels.qq import (
|
||||||
|
QQ_FILE_TYPE_FILE,
|
||||||
|
QQ_FILE_TYPE_IMAGE,
|
||||||
|
QQChannel,
|
||||||
|
QQConfig,
|
||||||
|
_guess_send_file_type,
|
||||||
|
_is_image_name,
|
||||||
|
_sanitize_filename,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeApi:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.c2c_calls: list[dict] = []
|
||||||
|
self.group_calls: list[dict] = []
|
||||||
|
|
||||||
|
async def post_c2c_message(self, **kwargs) -> None:
|
||||||
|
self.c2c_calls.append(kwargs)
|
||||||
|
|
||||||
|
async def post_group_message(self, **kwargs) -> None:
|
||||||
|
self.group_calls.append(kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeHttp:
|
||||||
|
"""Fake _http for _post_base64file tests."""
|
||||||
|
|
||||||
|
def __init__(self, return_value: dict | None = None) -> None:
|
||||||
|
self.return_value = return_value or {}
|
||||||
|
self.calls: list[tuple] = []
|
||||||
|
|
||||||
|
async def request(self, route, **kwargs):
|
||||||
|
self.calls.append((route, kwargs))
|
||||||
|
return self.return_value
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeClient:
|
||||||
|
def __init__(self, http_return: dict | None = None) -> None:
|
||||||
|
self.api = _FakeApi()
|
||||||
|
self.api._http = _FakeHttp(http_return)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helper function tests (pure, no async) ──────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_filename_strips_path_traversal() -> None:
|
||||||
|
assert _sanitize_filename("../../etc/passwd") == "passwd"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_filename_keeps_chinese_chars() -> None:
|
||||||
|
assert _sanitize_filename("文件(1).jpg") == "文件(1).jpg"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_filename_strips_unsafe_chars() -> None:
|
||||||
|
result = _sanitize_filename('file<>:"|?*.txt')
|
||||||
|
# Every unsafe char (including "*") is replaced with "_"
|
||||||
|
assert result.startswith("file")
|
||||||
|
assert result.endswith(".txt")
|
||||||
|
assert "<" not in result
|
||||||
|
assert ">" not in result
|
||||||
|
assert '"' not in result
|
||||||
|
assert "|" not in result
|
||||||
|
assert "?" not in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_filename_empty_input() -> None:
|
||||||
|
assert _sanitize_filename("") == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_image_name_with_known_extensions() -> None:
|
||||||
|
for ext in (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".tif", ".tiff", ".ico", ".svg"):
|
||||||
|
assert _is_image_name(f"photo{ext}") is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_image_name_with_unknown_extension() -> None:
|
||||||
|
for ext in (".pdf", ".txt", ".mp3", ".mp4"):
|
||||||
|
assert _is_image_name(f"doc{ext}") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_guess_send_file_type_image() -> None:
|
||||||
|
assert _guess_send_file_type("photo.png") == QQ_FILE_TYPE_IMAGE
|
||||||
|
assert _guess_send_file_type("pic.jpg") == QQ_FILE_TYPE_IMAGE
|
||||||
|
|
||||||
|
|
||||||
|
def test_guess_send_file_type_file() -> None:
|
||||||
|
assert _guess_send_file_type("doc.pdf") == QQ_FILE_TYPE_FILE
|
||||||
|
|
||||||
|
|
||||||
|
def test_guess_send_file_type_by_mime() -> None:
|
||||||
|
# A filename with an unrecognized extension (so no image/* mime type can be inferred) falls back to the generic file type
|
||||||
|
assert _guess_send_file_type("photo.xyz_image_test") == QQ_FILE_TYPE_FILE
|
||||||
|
|
||||||
|
|
||||||
|
# ── send() exception handling ───────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_exception_caught_not_raised() -> None:
|
||||||
|
"""Exceptions inside send() must not propagate."""
|
||||||
|
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||||
|
channel._client = _FakeClient()
|
||||||
|
|
||||||
|
with patch.object(channel, "_send_text_only", new_callable=AsyncMock, side_effect=RuntimeError("boom")):
|
||||||
|
await channel.send(
|
||||||
|
OutboundMessage(channel="qq", chat_id="user1", content="hello")
|
||||||
|
)
|
||||||
|
# No exception raised — test passes if we get here.
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_media_then_text() -> None:
|
||||||
|
"""Media is sent before text when both are present."""
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||||
|
channel._client = _FakeClient()
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||||
|
f.write(b"\x89PNG\r\n")
|
||||||
|
tmp = f.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
with patch.object(channel, "_post_base64file", new_callable=AsyncMock, return_value={"file_info": "1"}) as mock_upload:
|
||||||
|
await channel.send(
|
||||||
|
OutboundMessage(
|
||||||
|
channel="qq",
|
||||||
|
chat_id="user1",
|
||||||
|
content="text after image",
|
||||||
|
media=[tmp],
|
||||||
|
metadata={"message_id": "m1"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert mock_upload.called
|
||||||
|
|
||||||
|
# Text should have been sent via c2c (default chat type)
|
||||||
|
text_calls = [c for c in channel._client.api.c2c_calls if c.get("msg_type") == 0]
|
||||||
|
assert len(text_calls) >= 1
|
||||||
|
assert text_calls[-1]["content"] == "text after image"
|
||||||
|
finally:
|
||||||
|
import os
|
||||||
|
os.unlink(tmp)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_media_failure_falls_back_to_text() -> None:
|
||||||
|
"""When _send_media returns False, a failure notice is appended."""
|
||||||
|
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||||
|
channel._client = _FakeClient()
|
||||||
|
|
||||||
|
with patch.object(channel, "_send_media", new_callable=AsyncMock, return_value=False):
|
||||||
|
await channel.send(
|
||||||
|
OutboundMessage(
|
||||||
|
channel="qq",
|
||||||
|
chat_id="user1",
|
||||||
|
content="hello",
|
||||||
|
media=["https://example.com/bad.png"],
|
||||||
|
metadata={"message_id": "m1"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should have the failure text among the c2c calls
|
||||||
|
failure_calls = [c for c in channel._client.api.c2c_calls if "Attachment send failed" in c.get("content", "")]
|
||||||
|
assert len(failure_calls) == 1
|
||||||
|
assert "bad.png" in failure_calls[0]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── _on_message() exception handling ────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_message_exception_caught_not_raised() -> None:
|
||||||
|
"""Missing required attributes should not crash _on_message."""
|
||||||
|
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||||
|
channel._client = _FakeClient()
|
||||||
|
|
||||||
|
# Construct a message-like object that lacks 'author' — triggers AttributeError
|
||||||
|
bad_data = SimpleNamespace(id="x1", content="hi")
|
||||||
|
# Should not raise
|
||||||
|
await channel._on_message(bad_data, is_group=False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_message_with_attachments() -> None:
|
||||||
|
"""Messages with attachments produce media_paths and formatted content."""
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||||
|
channel._client = _FakeClient()
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||||
|
f.write(b"\x89PNG\r\n")
|
||||||
|
saved_path = f.name
|
||||||
|
|
||||||
|
att = SimpleNamespace(url="", filename="screenshot.png", content_type="image/png")
|
||||||
|
|
||||||
|
# Patch _download_to_media_dir_chunked to return the temp file path
|
||||||
|
async def fake_download(url, filename_hint=""):
|
||||||
|
return saved_path
|
||||||
|
|
||||||
|
try:
|
||||||
|
with patch.object(channel, "_download_to_media_dir_chunked", side_effect=fake_download):
|
||||||
|
data = SimpleNamespace(
|
||||||
|
id="att1",
|
||||||
|
content="look at this",
|
||||||
|
author=SimpleNamespace(user_openid="u1"),
|
||||||
|
attachments=[att],
|
||||||
|
)
|
||||||
|
await channel._on_message(data, is_group=False)
|
||||||
|
|
||||||
|
msg = await channel.bus.consume_inbound()
|
||||||
|
assert "look at this" in msg.content
|
||||||
|
assert "screenshot.png" in msg.content
|
||||||
|
assert "Received files:" in msg.content
|
||||||
|
assert len(msg.media) == 1
|
||||||
|
assert msg.media[0] == saved_path
|
||||||
|
finally:
|
||||||
|
import os
|
||||||
|
os.unlink(saved_path)
|
||||||
|
|
||||||
|
|
||||||
|
# ── _post_base64file() ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_post_base64file_omits_file_name_for_images() -> None:
|
||||||
|
"""file_type=1 (image) → payload must not contain file_name."""
|
||||||
|
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
|
||||||
|
channel._client = _FakeClient(http_return={"file_info": "img_abc"})
|
||||||
|
|
||||||
|
await channel._post_base64file(
|
||||||
|
chat_id="user1",
|
||||||
|
is_group=False,
|
||||||
|
file_type=QQ_FILE_TYPE_IMAGE,
|
||||||
|
file_data="ZmFrZQ==",
|
||||||
|
file_name="photo.png",
|
||||||
|
)
|
||||||
|
|
||||||
|
http = channel._client.api._http
|
||||||
|
assert len(http.calls) == 1
|
||||||
|
payload = http.calls[0][1]["json"]
|
||||||
|
assert "file_name" not in payload
|
||||||
|
assert payload["file_type"] == QQ_FILE_TYPE_IMAGE
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_post_base64file_includes_file_name_for_files() -> None:
|
||||||
|
"""file_type=4 (file) → payload must contain file_name."""
|
||||||
|
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
|
||||||
|
channel._client = _FakeClient(http_return={"file_info": "file_abc"})
|
||||||
|
|
||||||
|
await channel._post_base64file(
|
||||||
|
chat_id="user1",
|
||||||
|
is_group=False,
|
||||||
|
file_type=QQ_FILE_TYPE_FILE,
|
||||||
|
file_data="ZmFrZQ==",
|
||||||
|
file_name="report.pdf",
|
||||||
|
)
|
||||||
|
|
||||||
|
http = channel._client.api._http
|
||||||
|
assert len(http.calls) == 1
|
||||||
|
payload = http.calls[0][1]["json"]
|
||||||
|
assert payload["file_name"] == "report.pdf"
|
||||||
|
assert payload["file_type"] == QQ_FILE_TYPE_FILE
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_post_base64file_filters_response_to_file_info() -> None:
|
||||||
|
"""Response with file_info + extra fields must be filtered to only file_info."""
|
||||||
|
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
|
||||||
|
channel._client = _FakeClient(http_return={
|
||||||
|
"file_info": "fi_123",
|
||||||
|
"file_uuid": "uuid_xxx",
|
||||||
|
"ttl": 3600,
|
||||||
|
})
|
||||||
|
|
||||||
|
result = await channel._post_base64file(
|
||||||
|
chat_id="user1",
|
||||||
|
is_group=False,
|
||||||
|
file_type=QQ_FILE_TYPE_FILE,
|
||||||
|
file_data="ZmFrZQ==",
|
||||||
|
file_name="doc.pdf",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {"file_info": "fi_123"}
|
||||||
|
assert "file_uuid" not in result
|
||||||
|
assert "ttl" not in result
|
||||||
@ -10,8 +10,7 @@ except ImportError:
|
|||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.slack import SlackChannel
|
from nanobot.channels.slack import SlackChannel, SlackConfig
|
||||||
from nanobot.channels.slack import SlackConfig
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeAsyncWebClient:
|
class _FakeAsyncWebClient:
|
||||||
@ -20,6 +19,12 @@ class _FakeAsyncWebClient:
|
|||||||
self.file_upload_calls: list[dict[str, object | None]] = []
|
self.file_upload_calls: list[dict[str, object | None]] = []
|
||||||
self.reactions_add_calls: list[dict[str, object | None]] = []
|
self.reactions_add_calls: list[dict[str, object | None]] = []
|
||||||
self.reactions_remove_calls: list[dict[str, object | None]] = []
|
self.reactions_remove_calls: list[dict[str, object | None]] = []
|
||||||
|
self.conversations_list_calls: list[dict[str, object | None]] = []
|
||||||
|
self.users_list_calls: list[dict[str, object | None]] = []
|
||||||
|
self.conversations_open_calls: list[dict[str, object | None]] = []
|
||||||
|
self._conversations_pages: list[dict[str, object]] = []
|
||||||
|
self._users_pages: list[dict[str, object]] = []
|
||||||
|
self._open_dm_response: dict[str, object] = {"channel": {"id": "D_OPENED"}}
|
||||||
|
|
||||||
async def chat_postMessage(
|
async def chat_postMessage(
|
||||||
self,
|
self,
|
||||||
@ -81,6 +86,22 @@ class _FakeAsyncWebClient:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def conversations_list(self, **kwargs):
|
||||||
|
self.conversations_list_calls.append(kwargs)
|
||||||
|
if self._conversations_pages:
|
||||||
|
return self._conversations_pages.pop(0)
|
||||||
|
return {"channels": [], "response_metadata": {"next_cursor": ""}}
|
||||||
|
|
||||||
|
async def users_list(self, **kwargs):
|
||||||
|
self.users_list_calls.append(kwargs)
|
||||||
|
if self._users_pages:
|
||||||
|
return self._users_pages.pop(0)
|
||||||
|
return {"members": [], "response_metadata": {"next_cursor": ""}}
|
||||||
|
|
||||||
|
async def conversations_open(self, **kwargs):
|
||||||
|
self.conversations_open_calls.append(kwargs)
|
||||||
|
return self._open_dm_response
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_uses_thread_for_channel_messages() -> None:
|
async def test_send_uses_thread_for_channel_messages() -> None:
|
||||||
@ -151,3 +172,147 @@ async def test_send_updates_reaction_when_final_response_sent() -> None:
|
|||||||
assert fake_web.reactions_add_calls == [
|
assert fake_web.reactions_add_calls == [
|
||||||
{"channel": "C123", "name": "white_check_mark", "timestamp": "1700000000.000100"}
|
{"channel": "C123", "name": "white_check_mark", "timestamp": "1700000000.000100"}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_resolves_channel_name_to_channel_id() -> None:
|
||||||
|
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||||
|
fake_web = _FakeAsyncWebClient()
|
||||||
|
fake_web._conversations_pages = [
|
||||||
|
{
|
||||||
|
"channels": [{"id": "C999", "name": "channel_x"}],
|
||||||
|
"response_metadata": {"next_cursor": ""},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
channel._web_client = fake_web
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
OutboundMessage(
|
||||||
|
channel="slack",
|
||||||
|
chat_id="#channel_x",
|
||||||
|
content="hello",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert fake_web.chat_post_calls == [
|
||||||
|
{"channel": "C999", "text": "hello\n", "thread_ts": None}
|
||||||
|
]
|
||||||
|
assert len(fake_web.conversations_list_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_resolves_user_handle_to_dm_channel() -> None:
|
||||||
|
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||||
|
fake_web = _FakeAsyncWebClient()
|
||||||
|
fake_web._users_pages = [
|
||||||
|
{
|
||||||
|
"members": [
|
||||||
|
{
|
||||||
|
"id": "U234",
|
||||||
|
"name": "alice",
|
||||||
|
"profile": {"display_name": "Alice"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"response_metadata": {"next_cursor": ""},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
fake_web._open_dm_response = {"channel": {"id": "D234"}}
|
||||||
|
channel._web_client = fake_web
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
OutboundMessage(
|
||||||
|
channel="slack",
|
||||||
|
chat_id="@alice",
|
||||||
|
content="hello",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert fake_web.conversations_open_calls == [{"users": "U234"}]
|
||||||
|
assert fake_web.chat_post_calls == [
|
||||||
|
{"channel": "D234", "text": "hello\n", "thread_ts": None}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_updates_reaction_on_origin_channel_for_cross_channel_send() -> None:
|
||||||
|
channel = SlackChannel(SlackConfig(enabled=True, react_emoji="eyes"), MessageBus())
|
||||||
|
fake_web = _FakeAsyncWebClient()
|
||||||
|
fake_web._conversations_pages = [
|
||||||
|
{
|
||||||
|
"channels": [{"id": "C999", "name": "channel_x"}],
|
||||||
|
"response_metadata": {"next_cursor": ""},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
channel._web_client = fake_web
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
OutboundMessage(
|
||||||
|
channel="slack",
|
||||||
|
chat_id="channel_x",
|
||||||
|
content="done",
|
||||||
|
metadata={
|
||||||
|
"slack": {
|
||||||
|
"event": {"ts": "1700000000.000100", "channel": "D_ORIGIN"},
|
||||||
|
"channel_type": "im",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert fake_web.chat_post_calls == [
|
||||||
|
{"channel": "C999", "text": "done\n", "thread_ts": None}
|
||||||
|
]
|
||||||
|
assert fake_web.reactions_remove_calls == [
|
||||||
|
{"channel": "D_ORIGIN", "name": "eyes", "timestamp": "1700000000.000100"}
|
||||||
|
]
|
||||||
|
assert fake_web.reactions_add_calls == [
|
||||||
|
{"channel": "D_ORIGIN", "name": "white_check_mark", "timestamp": "1700000000.000100"}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_does_not_reuse_origin_thread_ts_for_cross_channel_send() -> None:
|
||||||
|
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||||
|
fake_web = _FakeAsyncWebClient()
|
||||||
|
fake_web._conversations_pages = [
|
||||||
|
{
|
||||||
|
"channels": [{"id": "C999", "name": "channel_x"}],
|
||||||
|
"response_metadata": {"next_cursor": ""},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
channel._web_client = fake_web
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
OutboundMessage(
|
||||||
|
channel="slack",
|
||||||
|
chat_id="channel_x",
|
||||||
|
content="done",
|
||||||
|
metadata={
|
||||||
|
"slack": {
|
||||||
|
"event": {"ts": "1700000000.000100", "channel": "C_ORIGIN"},
|
||||||
|
"thread_ts": "1700000000.000200",
|
||||||
|
"channel_type": "channel",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert fake_web.chat_post_calls == [
|
||||||
|
{"channel": "C999", "text": "done\n", "thread_ts": None}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_raises_when_named_target_cannot_be_resolved() -> None:
|
||||||
|
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||||
|
fake_web = _FakeAsyncWebClient()
|
||||||
|
channel._web_client = fake_web
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="was not found"):
|
||||||
|
await channel.send(
|
||||||
|
OutboundMessage(
|
||||||
|
channel="slack",
|
||||||
|
chat_id="#missing-channel",
|
||||||
|
content="hello",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|||||||
@ -387,6 +387,84 @@ async def test_send_delta_stream_end_treats_not_modified_as_success() -> None:
|
|||||||
assert "123" not in channel._stream_bufs
|
assert "123" not in channel._stream_bufs
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_delta_stream_end_does_not_fallback_on_network_timeout() -> None:
    """TimedOut during HTML edit should propagate, never fall back to plain text."""
    from telegram.error import TimedOut

    channel = TelegramChannel(
        TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
        MessageBus(),
    )
    channel._app = _FakeApp(lambda: None)
    # _call_with_retry retries TimedOut up to 3 times, so the mock will be called
    # multiple times – but all calls must be with parse_mode="HTML" (no plain fallback).
    channel._app.bot.edit_message_text = AsyncMock(side_effect=TimedOut("network timeout"))
    channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0)

    with pytest.raises(TimedOut, match="network timeout"):
        await channel.send_delta("123", "", {"_stream_end": True})

    edit_calls = channel._app.bot.edit_message_text.call_args_list
    # Guard against a vacuous pass: the loop below asserts nothing when no
    # edit was recorded, so require at least one attempt first.
    assert edit_calls, "edit_message_text was never called"
    # Every call to edit_message_text must have used parse_mode="HTML" —
    # no plain-text fallback call should have been made.
    for call in edit_calls:
        assert call.kwargs.get("parse_mode") == "HTML"
    # Buffer should still be present (not cleaned up on error)
    assert "123" in channel._stream_bufs
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_delta_stream_end_does_not_fallback_on_network_error() -> None:
    """NetworkError during HTML edit should propagate, never fall back to plain text."""
    from telegram.error import NetworkError

    channel = TelegramChannel(
        TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
        MessageBus(),
    )
    channel._app = _FakeApp(lambda: None)
    channel._app.bot.edit_message_text = AsyncMock(side_effect=NetworkError("connection reset"))
    channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0)

    with pytest.raises(NetworkError, match="connection reset"):
        await channel.send_delta("123", "", {"_stream_end": True})

    edit_calls = channel._app.bot.edit_message_text.call_args_list
    # Guard against a vacuous pass: the loop below asserts nothing when no
    # edit was recorded, so require at least one attempt first.
    assert edit_calls, "edit_message_text was never called"
    # Every call to edit_message_text must have used parse_mode="HTML" —
    # no plain-text fallback call should have been made.
    for call in edit_calls:
        assert call.kwargs.get("parse_mode") == "HTML"
    # Buffer should still be present (not cleaned up on error)
    assert "123" in channel._stream_bufs
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_delta_stream_end_falls_back_on_bad_request() -> None:
    """A BadRequest on the HTML edit is followed by a second, plain-text edit."""
    from telegram.error import BadRequest

    tg = TelegramChannel(
        TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
        MessageBus(),
    )
    tg._app = _FakeApp(lambda: None)

    # The HTML attempt raises BadRequest; the follow-up attempt returns None.
    outcomes = [BadRequest("Can't parse entities"), None]
    tg._app.bot.edit_message_text = AsyncMock(side_effect=outcomes)
    tg._stream_bufs["123"] = _StreamBuf(text="hello <bad>", message_id=7, last_edit=0.0)

    await tg.send_delta("123", "", {"_stream_end": True})

    edits = tg._app.bot.edit_message_text
    # Exactly two edits: the HTML one and the plain fallback.
    assert edits.call_count == 2
    # The fallback edit carries no HTML parse mode (absent or None).
    fallback_kwargs = edits.call_args_list[1].kwargs
    assert fallback_kwargs.get("parse_mode") is None
    # A successful finish removes the stream buffer.
    assert "123" not in tg._stream_bufs
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_delta_stream_end_splits_oversized_reply() -> None:
|
async def test_send_delta_stream_end_splits_oversized_reply() -> None:
|
||||||
"""Final streamed reply exceeding Telegram limit is split into chunks."""
|
"""Final streamed reply exceeding Telegram limit is split into chunks."""
|
||||||
@ -1159,3 +1237,159 @@ async def test_on_message_location_with_text() -> None:
|
|||||||
assert len(handled) == 1
|
assert len(handled) == 1
|
||||||
assert "meet me here" in handled[0]["content"]
|
assert "meet me here" in handled[0]["content"]
|
||||||
assert "[location: 51.5074, -0.1278]" in handled[0]["content"]
|
assert "[location: 51.5074, -0.1278]" in handled[0]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests for retry amplification fix (issue #3050)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_text_does_not_fallback_on_network_timeout() -> None:
    """A TimedOut from send_message propagates without a plain-text retry.

    Regression test for issue #3050: the send helper used to catch every
    exception (TimedOut included) and resend as plain text, which doubled the
    number of requests while the connection pool was already exhausted.
    """
    from telegram.error import TimedOut

    tg = TelegramChannel(
        TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
        MessageBus(),
    )
    tg._app = _FakeApp(lambda: None)

    call_count = 0

    async def always_timeout(**kwargs):
        # Count every attempt, then fail it.
        nonlocal call_count
        call_count += 1
        raise TimedOut()

    tg._app.bot.send_message = always_timeout

    import nanobot.channels.telegram as tg_mod

    # Shrink the retry back-off so the test stays fast; restored in ``finally``.
    saved_delay = tg_mod._SEND_RETRY_BASE_DELAY
    tg_mod._SEND_RETRY_BASE_DELAY = 0.01
    try:
        with pytest.raises(TimedOut):
            await tg._send_text(123, "hello", None, {})
    finally:
        tg_mod._SEND_RETRY_BASE_DELAY = saved_delay

    # Fixed behaviour: 3 HTML attempts from the retry wrapper and nothing else.
    # Old behaviour: 3 HTML + 3 plain = 6 attempts.
    assert call_count == 3, (
        f"Expected 3 calls (HTML retries only), got {call_count} "
        "(plain-text fallback should not trigger on TimedOut)"
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_text_does_not_fallback_on_network_error() -> None:
    """A NetworkError from send_message propagates without a plain-text retry."""
    from telegram.error import NetworkError

    tg = TelegramChannel(
        TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
        MessageBus(),
    )
    tg._app = _FakeApp(lambda: None)

    call_count = 0

    async def always_network_error(**kwargs):
        # Count every attempt, then fail it.
        nonlocal call_count
        call_count += 1
        raise NetworkError("Connection reset")

    tg._app.bot.send_message = always_network_error

    import nanobot.channels.telegram as tg_mod

    # Shrink the retry back-off so the test stays fast; restored in ``finally``.
    saved_delay = tg_mod._SEND_RETRY_BASE_DELAY
    tg_mod._SEND_RETRY_BASE_DELAY = 0.01
    try:
        with pytest.raises(NetworkError):
            await tg._send_text(123, "hello", None, {})
    finally:
        tg_mod._SEND_RETRY_BASE_DELAY = saved_delay

    # NetworkError is not retried by the retry wrapper (only TimedOut/RetryAfter),
    # so a single HTML attempt is expected. Old behaviour: 1 HTML + 1 plain = 2.
    assert call_count == 1, (
        f"Expected 1 call (HTML only, no plain fallback), got {call_count}"
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_text_falls_back_on_bad_request() -> None:
    """A BadRequest on the HTML send is followed by a successful plain-text send."""
    from telegram.error import BadRequest

    tg = TelegramChannel(
        TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
        MessageBus(),
    )
    tg._app = _FakeApp(lambda: None)

    real_send = tg._app.bot.send_message
    html_call_count = 0

    async def html_fails(**kwargs):
        # Reject HTML sends; delegate everything else to the fake bot's send.
        nonlocal html_call_count
        if kwargs.get("parse_mode") != "HTML":
            return await real_send(**kwargs)
        html_call_count += 1
        raise BadRequest("Can't parse entities")

    tg._app.bot.send_message = html_fails

    import nanobot.channels.telegram as tg_mod

    # Shrink the retry back-off so the test stays fast; restored in ``finally``.
    saved_delay = tg_mod._SEND_RETRY_BASE_DELAY
    tg_mod._SEND_RETRY_BASE_DELAY = 0.01
    try:
        await tg._send_text(123, "hello **world**", None, {})
    finally:
        tg_mod._SEND_RETRY_BASE_DELAY = saved_delay

    # One failed HTML attempt, then exactly one message recorded by the fake bot.
    assert html_call_count == 1, f"Expected 1 HTML attempt, got {html_call_count}"
    delivered = tg._app.bot.sent_messages
    assert len(delivered) == 1
    # The delivered message is the plain one: no parse_mode set.
    assert delivered[0].get("parse_mode") is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_text_bad_request_plain_fallback_exhausted() -> None:
    """If the HTML send and the plain fallback both raise BadRequest, it propagates."""
    from telegram.error import BadRequest

    tg = TelegramChannel(
        TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
        MessageBus(),
    )
    tg._app = _FakeApp(lambda: None)

    call_count = 0

    async def always_bad_request(**kwargs):
        # Count every attempt, then fail it.
        nonlocal call_count
        call_count += 1
        raise BadRequest("Bad request")

    tg._app.bot.send_message = always_bad_request

    import nanobot.channels.telegram as tg_mod

    # Shrink the retry back-off so the test stays fast; restored in ``finally``.
    saved_delay = tg_mod._SEND_RETRY_BASE_DELAY
    tg_mod._SEND_RETRY_BASE_DELAY = 0.01
    try:
        with pytest.raises(BadRequest):
            await tg._send_text(123, "hello", None, {})
    finally:
        tg_mod._SEND_RETRY_BASE_DELAY = saved_delay

    # BadRequest is not retried by the retry wrapper (only TimedOut/RetryAfter):
    # one HTML attempt fails, then one plain attempt fails. The count is 2 both
    # before and after the fix, because BadRequest is meant to fall back.
    assert call_count == 2, f"Expected 2 calls (1 HTML + 1 plain), got {call_count}"
|
||||||
|
|||||||
598
tests/channels/test_websocket_channel.py
Normal file
598
tests/channels/test_websocket_channel.py
Normal file
@ -0,0 +1,598 @@
|
|||||||
|
"""Unit and lightweight integration tests for the WebSocket channel."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import functools
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
import websockets
|
||||||
|
from websockets.exceptions import ConnectionClosed
|
||||||
|
from websockets.frames import Close
|
||||||
|
|
||||||
|
from nanobot.bus.events import OutboundMessage
|
||||||
|
from nanobot.channels.websocket import (
|
||||||
|
WebSocketChannel,
|
||||||
|
WebSocketConfig,
|
||||||
|
_issue_route_secret_matches,
|
||||||
|
_normalize_config_path,
|
||||||
|
_normalize_http_path,
|
||||||
|
_parse_inbound_payload,
|
||||||
|
_parse_query,
|
||||||
|
_parse_request_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
# -- Shared helpers (aligned with test_websocket_integration.py) ---------------
|
||||||
|
|
||||||
|
_PORT = 29876


def _ch(bus: Any, **kw: Any) -> WebSocketChannel:
    """Build a WebSocketChannel from a baseline test config.

    The baseline enables the channel on 127.0.0.1:_PORT at path "/ws",
    allows every sender and does not require a token. Keyword arguments
    override or extend the baseline keys.
    """
    return WebSocketChannel(
        {
            "enabled": True,
            "allowFrom": ["*"],
            "host": "127.0.0.1",
            "port": _PORT,
            "path": "/ws",
            "websocketRequiresToken": False,
            **kw,
        },
        bus,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def bus() -> MagicMock:
    """Provide a mock message bus whose ``publish_inbound`` can be awaited."""
    fake_bus = MagicMock()
    fake_bus.publish_inbound = AsyncMock()
    return fake_bus
|
||||||
|
|
||||||
|
|
||||||
|
async def _http_get(url: str, headers: dict[str, str] | None = None) -> httpx.Response:
    """Issue a blocking ``httpx.get`` on a worker thread.

    Running the request off the event loop keeps the loop free for the
    websocket server that shares it. Uses a 5 second timeout and sends an
    empty header dict when ``headers`` is not given.
    """
    blocking_get = functools.partial(httpx.get, url, headers=headers or {}, timeout=5.0)
    return await asyncio.to_thread(blocking_get)
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_http_path_strips_trailing_slash_except_root() -> None:
    """Trailing slash and query string are dropped; the root path stays "/"."""
    cases = [
        ("/chat/", "/chat"),
        ("/chat?x=1", "/chat"),
        ("/", "/"),
    ]
    for raw_path, normalized in cases:
        assert _normalize_http_path(raw_path) == normalized
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_request_path_matches_normalize_and_query() -> None:
    """``_parse_request_path`` agrees with the separate path and query helpers."""
    parsed_path, parsed_query = _parse_request_path("/ws/?token=secret&client_id=u1")
    assert parsed_path == _normalize_http_path("/ws/?token=secret&client_id=u1")
    assert parsed_query == _parse_query("/ws/?token=secret&client_id=u1")
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_config_path_matches_request() -> None:
    """Configured paths lose a trailing slash, except for the root path."""
    for configured, normalized in (("/ws/", "/ws"), ("/", "/")):
        assert _normalize_config_path(configured) == normalized
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_query_extracts_token_and_client_id() -> None:
    """Each query parameter is returned as a single-element list of values."""
    params = _parse_query("/?token=secret&client_id=u1")
    token_values = params.get("token")
    client_values = params.get("client_id")
    assert token_values == ["secret"]
    assert client_values == ["u1"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    ("raw", "expected"),
    [
        ("plain", "plain"),
        ('{"content": "hi"}', "hi"),
        ('{"text": "there"}', "there"),
        ('{"message": "x"}', "x"),
        ("   ", None),
        ("{}", None),
    ],
)
def test_parse_inbound_payload(raw: str, expected: str | None) -> None:
    """Plain text and the content/text/message JSON keys yield text; blanks yield None."""
    parsed = _parse_inbound_payload(raw)
    assert parsed == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_inbound_invalid_json_falls_back_to_raw_string() -> None:
    """A frame that is not valid JSON comes back unchanged as text."""
    malformed = "{not json"
    assert _parse_inbound_payload("{not json") == malformed
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    ("raw", "expected"),
    [
        ('{"content": ""}', None),  # content is an empty string
        ('{"content": 123}', None),  # content is not a string
        ('{"content": "   "}', None),  # content is only whitespace
        ('["hello"]', '["hello"]'),  # JSON array is not a dict: kept as plain text
        ('{"unknown_key": "val"}', None),  # no recognized key present
        ('{"content": null}', None),  # content is null
    ],
)
def test_parse_inbound_payload_edge_cases(raw: str, expected: str | None) -> None:
    """Unusable JSON content yields None; non-dict JSON is returned as raw text."""
    parsed = _parse_inbound_payload(raw)
    assert parsed == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_web_socket_config_path_must_start_with_slash() -> None:
    """Constructing the config with a path lacking a leading slash raises ValueError."""
    expected_message = 'path must start with "/"'
    with pytest.raises(ValueError, match=expected_message):
        WebSocketConfig(path="bad")
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssl_context_requires_both_cert_and_key_files() -> None:
    """A cert file with an empty key file makes ``_build_ssl_context`` raise ValueError."""
    half_configured = {
        "enabled": True,
        "allowFrom": ["*"],
        "sslCertfile": "/tmp/c.pem",
        "sslKeyfile": "",
    }
    bus = MagicMock()
    channel = WebSocketChannel(half_configured, bus)
    with pytest.raises(ValueError, match="ssl_certfile and ssl_keyfile"):
        channel._build_ssl_context()
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_config_includes_safe_bind_and_streaming() -> None:
    """Defaults: disabled, bound to loopback, streaming on, open allow-list, no issue path."""
    cfg = WebSocketChannel.default_config()
    assert cfg["enabled"] is False
    assert cfg["host"] == "127.0.0.1"
    assert cfg["streaming"] is True
    assert cfg["allowFrom"] == ["*"]
    # The issue path may be absent altogether; absent counts as empty.
    issue_path = cfg.get("tokenIssuePath", "")
    assert issue_path == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_issue_path_must_differ_from_websocket_path() -> None:
    """Using the same value for ``path`` and ``token_issue_path`` raises ValueError."""
    shared_path = "/ws"
    with pytest.raises(ValueError, match="token_issue_path must differ"):
        WebSocketConfig(path=shared_path, token_issue_path=shared_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_issue_route_secret_matches_bearer_and_header() -> None:
    """The secret matches via a Bearer token or X-Nanobot-Auth; a wrong Bearer fails."""
    from websockets.datastructures import Headers

    secret = "my-secret"

    via_bearer = Headers([("Authorization", "Bearer my-secret")])
    assert _issue_route_secret_matches(via_bearer, secret) is True

    via_custom_header = Headers([("X-Nanobot-Auth", "my-secret")])
    assert _issue_route_secret_matches(via_custom_header, secret) is True

    mismatched = Headers([("Authorization", "Bearer other")])
    assert _issue_route_secret_matches(mismatched, secret) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_issue_route_secret_matches_empty_secret() -> None:
    """With an empty secret the check returns True, with or without auth headers."""
    from websockets.datastructures import Headers

    no_headers = Headers([])
    some_bearer = Headers([("Authorization", "Bearer anything")])
    assert _issue_route_secret_matches(no_headers, "") is True
    assert _issue_route_secret_matches(some_bearer, "") is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_delivers_json_message_with_media_and_reply() -> None:
    """``send`` writes one JSON frame carrying event, text, reply_to and media."""
    bus = MagicMock()
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
    connection = AsyncMock()
    channel._connections["chat-1"] = connection

    outbound = OutboundMessage(
        channel="websocket",
        chat_id="chat-1",
        content="hello",
        reply_to="m1",
        media=["/tmp/a.png"],
    )
    await channel.send(outbound)

    connection.send.assert_awaited_once()
    frame = json.loads(connection.send.call_args[0][0])
    assert frame["event"] == "message"
    assert frame["text"] == "hello"
    assert frame["reply_to"] == "m1"
    assert frame["media"] == ["/tmp/a.png"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_missing_connection_is_noop_without_error() -> None:
    """``send`` to a chat id with no registered connection completes without raising."""
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())
    orphan = OutboundMessage(channel="websocket", chat_id="missing", content="x")
    await channel.send(orphan)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_removes_connection_on_connection_closed() -> None:
    """When the socket's send raises ConnectionClosed, ``send`` forgets the connection."""
    bus = MagicMock()
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)

    dead_socket = AsyncMock()
    dead_socket.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True)
    channel._connections["chat-1"] = dead_socket

    outbound = OutboundMessage(channel="websocket", chat_id="chat-1", content="hello")
    await channel.send(outbound)

    assert "chat-1" not in channel._connections
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_delta_removes_connection_on_connection_closed() -> None:
    """When the socket's send raises ConnectionClosed, ``send_delta`` forgets the connection."""
    bus = MagicMock()
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus)

    dead_socket = AsyncMock()
    dead_socket.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True)
    channel._connections["chat-1"] = dead_socket

    delta_meta = {"_stream_delta": True, "_stream_id": "s1"}
    await channel.send_delta("chat-1", "chunk", delta_meta)

    assert "chat-1" not in channel._connections
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_delta_emits_delta_and_stream_end() -> None:
    """A delta call and a stream-end call produce two frames with matching stream ids."""
    bus = MagicMock()
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus)
    connection = AsyncMock()
    channel._connections["chat-1"] = connection

    await channel.send_delta("chat-1", "part", {"_stream_delta": True, "_stream_id": "sid"})
    await channel.send_delta("chat-1", "", {"_stream_end": True, "_stream_id": "sid"})

    assert connection.send.await_count == 2
    recorded = connection.send.call_args_list
    delta_frame = json.loads(recorded[0][0][0])
    end_frame = json.loads(recorded[1][0][0])

    assert delta_frame["event"] == "delta"
    assert delta_frame["text"] == "part"
    assert delta_frame["stream_id"] == "sid"
    assert end_frame["event"] == "stream_end"
    assert end_frame["stream_id"] == "sid"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_non_connection_closed_exception_is_raised() -> None:
    """An exception other than ConnectionClosed from the socket propagates out of ``send``."""
    bus = MagicMock()
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)

    broken_socket = AsyncMock()
    broken_socket.send.side_effect = RuntimeError("unexpected")
    channel._connections["chat-1"] = broken_socket

    outbound = OutboundMessage(channel="websocket", chat_id="chat-1", content="hello")
    with pytest.raises(RuntimeError, match="unexpected"):
        await channel.send(outbound)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_delta_missing_connection_is_noop() -> None:
    """``send_delta`` to an unknown chat id completes without raising."""
    config = {"enabled": True, "allowFrom": ["*"], "streaming": True}
    channel = WebSocketChannel(config, MagicMock())
    # Nothing is registered under this chat id, so the call must simply return.
    await channel.send_delta("nonexistent", "chunk", {"_stream_delta": True, "_stream_id": "s1"})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_stop_is_idempotent() -> None:
    """Calling ``stop`` twice on a channel that was never started does not raise."""
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())
    for _ in range(2):
        await channel.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_end_to_end_client_receives_ready_and_agent_sees_inbound(bus: MagicMock) -> None:
    """A real client gets a ``ready`` frame and its frames are published to the bus."""
    port = 29876
    channel = _ch(bus, port=port)

    serving = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)  # give the server time to start listening

    try:
        async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=tester") as ws:
            hello = json.loads(await ws.recv())
            assert hello["event"] == "ready"
            assert hello["client_id"] == "tester"
            assigned_chat = hello["chat_id"]

            # A JSON frame: its "content" value becomes the inbound content.
            await ws.send(json.dumps({"content": "ping from client"}))
            await asyncio.sleep(0.08)

            bus.publish_inbound.assert_awaited()
            first_inbound = bus.publish_inbound.call_args[0][0]
            assert first_inbound.channel == "websocket"
            assert first_inbound.sender_id == "tester"
            assert first_inbound.chat_id == assigned_chat
            assert first_inbound.content == "ping from client"

            # A non-JSON frame: forwarded as-is.
            await ws.send("plain text frame")
            await asyncio.sleep(0.08)
            assert bus.publish_inbound.await_count >= 2
            published = [call[0][0] for call in bus.publish_inbound.call_args_list]
            assert published[-1].content == "plain text frame"
    finally:
        await channel.stop()
        await serving
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_token_rejects_handshake_when_mismatch(bus: MagicMock) -> None:
    """Connecting with a token that differs from the configured one yields HTTP 401."""
    port = 29877
    channel = _ch(bus, port=port, path="/", token="secret")

    serving = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)  # give the server time to start listening

    try:
        with pytest.raises(websockets.exceptions.InvalidStatus) as rejected:
            async with websockets.connect(f"ws://127.0.0.1:{port}/?token=wrong"):
                pass
        status = rejected.value.response.status_code
        assert status == 401
    finally:
        await channel.stop()
        await serving
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_wrong_path_returns_404(bus: MagicMock) -> None:
    """Connecting on a path other than the configured "/ws" yields HTTP 404."""
    port = 29878
    channel = _ch(bus, port=port)

    serving = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)  # give the server time to start listening

    try:
        with pytest.raises(websockets.exceptions.InvalidStatus) as rejected:
            async with websockets.connect(f"ws://127.0.0.1:{port}/other"):
                pass
        status = rejected.value.response.status_code
        assert status == 404
    finally:
        await channel.stop()
        await serving
|
||||||
|
|
||||||
|
|
||||||
|
def test_registry_discovers_websocket_channel() -> None:
    """The channel registry resolves "websocket" to a class named "websocket"."""
    from nanobot.channels.registry import load_channel_class

    channel_cls = load_channel_class("websocket")
    assert channel_cls.name == "websocket"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_http_route_issues_token_then_websocket_requires_it(bus: MagicMock) -> None:
    """The issue route hands out a token that is required and usable only once.

    Steps checked: the route rejects a request without the secret (401), issues
    an ``nbwt_`` token with it (200), the websocket rejects a connection without
    a token (401), accepts the issued token, and rejects it when reused (401).
    """
    port = 29879
    channel = _ch(
        bus,
        port=port,
        tokenIssuePath="/auth/token",
        tokenIssueSecret="route-secret",
        websocketRequiresToken=True,
    )

    serving = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)  # give the server time to start listening

    try:
        # No credentials on the issue route.
        unauthenticated = await _http_get(f"http://127.0.0.1:{port}/auth/token")
        assert unauthenticated.status_code == 401

        # Correct bearer secret on the issue route.
        issued = await _http_get(
            f"http://127.0.0.1:{port}/auth/token",
            headers={"Authorization": "Bearer route-secret"},
        )
        assert issued.status_code == 200
        token = issued.json()["token"]
        assert token.startswith("nbwt_")

        # Websocket handshake without any token.
        with pytest.raises(websockets.exceptions.InvalidStatus) as missing_token:
            async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=x"):
                pass
        assert missing_token.value.response.status_code == 401

        # First use of the issued token succeeds.
        uri = f"ws://127.0.0.1:{port}/ws?token={token}&client_id=caller"
        async with websockets.connect(uri) as ws:
            hello = json.loads(await ws.recv())
            assert hello["event"] == "ready"
            assert hello["client_id"] == "caller"

        # Second use of the same token is refused.
        with pytest.raises(websockets.exceptions.InvalidStatus) as reuse:
            async with websockets.connect(uri):
                pass
        assert reuse.value.response.status_code == 401
    finally:
        await channel.stop()
        await serving
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_end_to_end_server_pushes_streaming_deltas_to_client(bus: MagicMock) -> None:
    """Deltas and the stream end pushed by the server reach a real client in order."""
    port = 29880
    channel = _ch(bus, port=port, streaming=True)

    serving = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)  # give the server time to start listening

    try:
        async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=stream-tester") as ws:
            hello = json.loads(await ws.recv())
            chat_id = hello["chat_id"]

            # Push two deltas and the end marker straight through the channel.
            await channel.send_delta(chat_id, "Hello ", {"_stream_delta": True, "_stream_id": "s1"})
            await channel.send_delta(chat_id, "world", {"_stream_delta": True, "_stream_id": "s1"})
            await channel.send_delta(chat_id, "", {"_stream_end": True, "_stream_id": "s1"})

            first_delta = json.loads(await ws.recv())
            assert first_delta["event"] == "delta"
            assert first_delta["text"] == "Hello "
            assert first_delta["stream_id"] == "s1"

            second_delta = json.loads(await ws.recv())
            assert second_delta["event"] == "delta"
            assert second_delta["text"] == "world"
            assert second_delta["stream_id"] == "s1"

            end_frame = json.loads(await ws.recv())
            assert end_frame["event"] == "stream_end"
            assert end_frame["stream_id"] == "s1"
    finally:
        await channel.stop()
        await serving
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_token_issue_rejects_when_at_capacity(bus: MagicMock) -> None:
    """With ``_MAX_ISSUED_TOKENS`` live tokens, the issue route answers 429 with an error."""
    port = 29881
    channel = _ch(bus, port=port, tokenIssuePath="/auth/token", tokenIssueSecret="s")

    serving = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)  # give the server time to start listening

    try:
        # Pre-fill the issued-token table with entries stamped 300s ahead of
        # the monotonic clock.
        channel._issued_tokens = {
            f"nbwt_fill_{i}": time.monotonic() + 300 for i in range(channel._MAX_ISSUED_TOKENS)
        }

        response = await _http_get(
            f"http://127.0.0.1:{port}/auth/token",
            headers={"Authorization": "Bearer s"},
        )
        assert response.status_code == 429
        body = response.json()
        assert "error" in body
    finally:
        await channel.stop()
        await serving
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_allow_from_rejects_unauthorized_client_id(bus: MagicMock) -> None:
    """A client id missing from ``allowFrom`` is refused with HTTP 403."""
    port = 29882
    channel = _ch(bus, port=port, allowFrom=["alice", "bob"])

    serving = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)  # give the server time to start listening

    try:
        with pytest.raises(websockets.exceptions.InvalidStatus) as rejected:
            async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=eve"):
                pass
        status = rejected.value.response.status_code
        assert status == 403
    finally:
        await channel.stop()
        await serving
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_client_id_truncation(bus: MagicMock) -> None:
    """A 200-character client id is reported back cut to 128 characters."""
    port = 29883
    channel = _ch(bus, port=port)

    serving = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)  # give the server time to start listening

    try:
        long_id = "x" * 200
        async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id={long_id}") as ws:
            hello = json.loads(await ws.recv())
            reported_id = hello["client_id"]
            assert reported_id == "x" * 128
            assert len(reported_id) == 128
    finally:
        await channel.stop()
        await serving
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_non_utf8_binary_frame_ignored(bus: MagicMock) -> None:
    """A binary frame that is not valid UTF-8 is not published to the bus."""
    port = 29884
    channel = _ch(bus, port=port)

    serving = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)  # give the server time to start listening

    try:
        async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=bin-test") as ws:
            await ws.recv()  # discard the ready frame
            undecodable = b"\xff\xfe\xfd"
            await ws.send(undecodable)
            await asyncio.sleep(0.05)
            # Nothing may have reached the bus.
            bus.publish_inbound.assert_not_awaited()
    finally:
        await channel.stop()
        await serving
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_static_token_accepts_issued_token_as_fallback(bus: MagicMock) -> None:
    """With a static token configured, a token from the issue endpoint is still accepted."""
    port = 29885
    channel = _ch(
        bus,
        port=port,
        token="static-secret",
        tokenIssuePath="/auth/token",
        tokenIssueSecret="route-secret",
    )
    runner = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        # Ask the HTTP issue route for a token, authenticating with the route secret.
        issue_resp = await _http_get(
            f"http://127.0.0.1:{port}/auth/token",
            headers={"Authorization": "Bearer route-secret"},
        )
        assert issue_resp.status_code == 200
        issued_token = issue_resp.json()["token"]

        # Open the WebSocket with the issued token instead of the static one.
        async with websockets.connect(f"ws://127.0.0.1:{port}/ws?token={issued_token}&client_id=caller") as client:
            ready_event = json.loads(await client.recv())
            assert ready_event["event"] == "ready"
    finally:
        await channel.stop()
        await runner
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_allow_from_empty_list_denies_all(bus: MagicMock) -> None:
    """An empty ``allowFrom`` list causes every client to be refused with HTTP 403."""
    port = 29886
    channel = _ch(bus, port=port, allowFrom=[])
    runner = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        with pytest.raises(websockets.exceptions.InvalidStatus) as rejected:
            async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=anyone"):
                pass
        assert rejected.value.response.status_code == 403
    finally:
        await channel.stop()
        await runner
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_websocket_requires_token_without_issue_path(bus: MagicMock) -> None:
    """When websocket_requires_token is True but no token or issue path configured, all connections are rejected."""
    port = 29887
    channel = _ch(bus, port=port, websocketRequiresToken=True)
    runner = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)

    try:
        # First URL carries no token at all, the second an unknown one; both must yield 401.
        attempts = (
            f"ws://127.0.0.1:{port}/ws?client_id=u",
            f"ws://127.0.0.1:{port}/ws?client_id=u&token=wrong",
        )
        for url in attempts:
            with pytest.raises(websockets.exceptions.InvalidStatus) as rejected:
                async with websockets.connect(url):
                    pass
            assert rejected.value.response.status_code == 401
    finally:
        await channel.stop()
        await runner
|
||||||
478
tests/channels/test_websocket_integration.py
Normal file
478
tests/channels/test_websocket_integration.py
Normal file
@ -0,0 +1,478 @@
|
|||||||
|
"""Integration tests for the WebSocket channel using WsTestClient.
|
||||||
|
|
||||||
|
Complements the unit/lightweight tests in test_websocket_channel.py by covering
|
||||||
|
multi-client scenarios, edge cases, and realistic usage patterns.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import websockets
|
||||||
|
|
||||||
|
from nanobot.channels.websocket import WebSocketChannel
|
||||||
|
from nanobot.bus.events import OutboundMessage
|
||||||
|
from ws_test_client import WsTestClient, issue_token, issue_token_ok
|
||||||
|
|
||||||
|
|
||||||
|
def _ch(bus: Any, port: int, **kw: Any) -> WebSocketChannel:
    """Build a WebSocketChannel bound to 127.0.0.1:*port* with permissive test defaults.

    Any keyword argument overrides (or adds to) the default config keys.
    """
    defaults: dict[str, Any] = {
        "enabled": True,
        "allowFrom": ["*"],
        "host": "127.0.0.1",
        "port": port,
        "path": "/",
        "websocketRequiresToken": False,
    }
    # Later entries win, so caller-supplied keys replace the defaults.
    return WebSocketChannel({**defaults, **kw}, bus)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def bus() -> MagicMock:
    """Provide a mock message bus whose ``publish_inbound`` can be awaited."""
    # Keyword arguments to MagicMock are set as attributes on the new mock.
    return MagicMock(publish_inbound=AsyncMock())
|
||||||
|
|
||||||
|
|
||||||
|
# -- Connection basics ----------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_ready_event_fields(bus: MagicMock) -> None:
    """The ready event reports event "ready", a 36-character chat_id and the given client_id."""
    channel = _ch(bus, 29901)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29901/", client_id="c1") as client:
            ready = await client.recv_ready()
            assert ready.event == "ready"
            assert len(ready.chat_id) == 36
            assert ready.client_id == "c1"
    finally:
        await channel.stop()
        await server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_anonymous_client_gets_generated_id(bus: MagicMock) -> None:
    """Connecting with an empty client_id yields an id that starts with "anon-"."""
    channel = _ch(bus, 29902)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29902/", client_id="") as client:
            ready = await client.recv_ready()
            assert ready.client_id.startswith("anon-")
    finally:
        await channel.stop()
        await server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_each_connection_unique_chat_id(bus: MagicMock) -> None:
    """Two simultaneous connections are given different chat_ids."""
    channel = _ch(bus, 29903)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29903/", client_id="a") as first:
            async with WsTestClient("ws://127.0.0.1:29903/", client_id="b") as second:
                first_ready = await first.recv_ready()
                second_ready = await second.recv_ready()
                assert first_ready.chat_id != second_ready.chat_id
    finally:
        await channel.stop()
        await server
|
||||||
|
|
||||||
|
|
||||||
|
# -- Inbound messages (client -> server) ----------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_plain_text(bus: MagicMock) -> None:
    """A plain text frame is published inbound with the client_id as sender_id."""
    channel = _ch(bus, 29904)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29904/", client_id="p") as client:
            await client.recv_ready()
            await client.send_text("hello world")
            await asyncio.sleep(0.1)
            # First positional argument of the most recent publish_inbound call.
            published = bus.publish_inbound.call_args.args[0]
            assert published.content == "hello world"
            assert published.sender_id == "p"
    finally:
        await channel.stop()
        await server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_json_content_field(bus: MagicMock) -> None:
    """The ``content`` key of a JSON frame becomes the inbound message content."""
    channel = _ch(bus, 29905)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29905/", client_id="j") as client:
            await client.recv_ready()
            await client.send_json({"content": "structured"})
            await asyncio.sleep(0.1)
            published = bus.publish_inbound.call_args.args[0]
            assert published.content == "structured"
    finally:
        await channel.stop()
        await server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_json_text_and_message_fields(bus: MagicMock) -> None:
    """JSON frames keyed by ``text`` or ``message`` are also accepted as inbound content."""
    channel = _ch(bus, 29906)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29906/", client_id="x") as client:
            await client.recv_ready()

            await client.send_json({"text": "via text"})
            await asyncio.sleep(0.1)
            latest = bus.publish_inbound.call_args.args[0]
            assert latest.content == "via text"

            await client.send_json({"message": "via message"})
            await asyncio.sleep(0.1)
            latest = bus.publish_inbound.call_args.args[0]
            assert latest.content == "via message"
    finally:
        await channel.stop()
        await server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_empty_payload_ignored(bus: MagicMock) -> None:
    """Whitespace-only text and an empty JSON object publish nothing to the bus."""
    channel = _ch(bus, 29907)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29907/", client_id="e") as client:
            await client.recv_ready()
            await client.send_text(" ")
            await client.send_json({})
            await asyncio.sleep(0.1)
            bus.publish_inbound.assert_not_awaited()
    finally:
        await channel.stop()
        await server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_messages_preserve_order(bus: MagicMock) -> None:
    """Five text frames sent back to back reach the bus in the order they were sent."""
    channel = _ch(bus, 29908)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29908/", client_id="o") as client:
            await client.recv_ready()
            for i in range(5):
                await client.send_text(f"msg-{i}")
            await asyncio.sleep(0.2)
            # Content of every publish_inbound call, in call order.
            published = [
                recorded.args[0].content
                for recorded in bus.publish_inbound.call_args_list
            ]
            assert published == [f"msg-{i}" for i in range(5)]
    finally:
        await channel.stop()
        await server
|
||||||
|
|
||||||
|
|
||||||
|
# -- Outbound messages (server -> client) ---------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_server_send_message(bus: MagicMock) -> None:
    """An OutboundMessage addressed to the connection's chat_id is delivered to that client."""
    channel = _ch(bus, 29909)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29909/", client_id="r") as client:
            ready = await client.recv_ready()
            outbound = OutboundMessage(
                channel="websocket", chat_id=ready.chat_id, content="reply",
            )
            await channel.send(outbound)
            delivered = await client.recv_message()
            assert delivered.text == "reply"
    finally:
        await channel.stop()
        await server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_server_send_with_media_and_reply(bus: MagicMock) -> None:
    """Outbound ``media`` and ``reply_to`` values arrive at the client alongside the text."""
    channel = _ch(bus, 29910)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29910/", client_id="m") as client:
            ready = await client.recv_ready()
            outbound = OutboundMessage(
                channel="websocket", chat_id=ready.chat_id, content="img",
                media=["/tmp/a.png"], reply_to="m1",
            )
            await channel.send(outbound)
            delivered = await client.recv_message()
            assert delivered.text == "img"
            assert delivered.media == ["/tmp/a.png"]
            assert delivered.reply_to == "m1"
    finally:
        await channel.stop()
        await server
|
||||||
|
|
||||||
|
|
||||||
|
# -- Streaming ------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_streaming_deltas_and_end(bus: MagicMock) -> None:
    """Streamed delta chunks concatenate to the full text, followed by exactly one stream_end."""
    channel = _ch(bus, 29911, streaming=True)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29911/", client_id="s") as client:
            ready = await client.recv_ready()
            chat_id = ready.chat_id

            for chunk in ("Hello", " ", "world", "!"):
                await channel.send_delta(chat_id, chunk, {"_stream_delta": True, "_stream_id": "s1"})
            # An empty delta flagged _stream_end closes stream "s1".
            await channel.send_delta(chat_id, "", {"_stream_end": True, "_stream_id": "s1"})

            events = await client.collect_stream()
            streamed_text = "".join(ev.text for ev in events if ev.event == "delta")
            assert streamed_text == "Hello world!"
            end_events = [ev for ev in events if ev.event == "stream_end"]
            assert len(end_events) == 1
    finally:
        await channel.stop()
        await server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_interleaved_streams(bus: MagicMock) -> None:
    """Deltas of two interleaved streams can be reassembled per stream using stream_id."""
    channel = _ch(bus, 29912, streaming=True)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29912/", client_id="i") as client:
            ready = await client.recv_ready()
            chat_id = ready.chat_id

            # (text, metadata) pairs, sent strictly in this order.
            script = [
                ("A1", {"_stream_delta": True, "_stream_id": "sa"}),
                ("B1", {"_stream_delta": True, "_stream_id": "sb"}),
                ("A2", {"_stream_delta": True, "_stream_id": "sa"}),
                ("", {"_stream_end": True, "_stream_id": "sa"}),
                ("B2", {"_stream_delta": True, "_stream_id": "sb"}),
                ("", {"_stream_end": True, "_stream_id": "sb"}),
            ]
            for text, metadata in script:
                await channel.send_delta(chat_id, text, metadata)

            events = await client.recv_n(6)
            deltas = [ev for ev in events if ev.event == "delta"]
            stream_a = "".join(ev.text for ev in deltas if ev.stream_id == "sa")
            stream_b = "".join(ev.text for ev in deltas if ev.stream_id == "sb")
            assert stream_a == "A1A2"
            assert stream_b == "B1B2"
    finally:
        await channel.stop()
        await server
|
||||||
|
|
||||||
|
|
||||||
|
# -- Multi-client ---------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_independent_sessions(bus: MagicMock) -> None:
    """Each of two connected clients receives the message addressed to its own chat_id."""
    channel = _ch(bus, 29913)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29913/", client_id="u1") as first:
            async with WsTestClient("ws://127.0.0.1:29913/", client_id="u2") as second:
                first_ready = await first.recv_ready()
                second_ready = await second.recv_ready()

                await channel.send(OutboundMessage(
                    channel="websocket", chat_id=first_ready.chat_id, content="for-u1",
                ))
                first_msg = await first.recv_message()
                assert first_msg.text == "for-u1"

                await channel.send(OutboundMessage(
                    channel="websocket", chat_id=second_ready.chat_id, content="for-u2",
                ))
                second_msg = await second.recv_message()
                assert second_msg.text == "for-u2"
    finally:
        await channel.stop()
        await server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_disconnected_client_cleanup(bus: MagicMock) -> None:
    """After a client disconnects, sending to its chat_id leaves no entry in ``_connections``."""
    channel = _ch(bus, 29914)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29914/", client_id="tmp") as client:
            ready = await client.recv_ready()
            chat_id = ready.chat_id
        # Leaving the ``async with`` block above has closed the client connection.
        orphan = OutboundMessage(
            channel="websocket", chat_id=chat_id, content="orphan",
        )
        await channel.send(orphan)
        assert chat_id not in channel._connections
    finally:
        await channel.stop()
        await server
|
||||||
|
|
||||||
|
|
||||||
|
# -- Authentication -------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_static_token_accepted(bus: MagicMock) -> None:
    """A client presenting the configured static token connects and gets its ready event."""
    channel = _ch(bus, 29915, token="secret")
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29915/", client_id="a", token="secret") as client:
            ready = await client.recv_ready()
            assert ready.client_id == "a"
    finally:
        await channel.stop()
        await server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_static_token_rejected(bus: MagicMock) -> None:
    """A client presenting a token other than the configured one is refused with HTTP 401."""
    channel = _ch(bus, 29916, token="correct")
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        with pytest.raises(websockets.exceptions.InvalidStatus) as rejected:
            async with WsTestClient("ws://127.0.0.1:29916/", client_id="b", token="wrong"):
                pass
        assert rejected.value.response.status_code == 401
    finally:
        await channel.stop()
        await server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_token_issue_full_flow(bus: MagicMock) -> None:
    """Walk the issued-token flow: secret-gated issuance, required token, single use."""
    channel = _ch(
        bus,
        29917,
        path="/ws",
        tokenIssuePath="/auth/token",
        tokenIssueSecret="s",
        websocketRequiresToken=True,
    )
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        # Requesting a token without the secret is refused.
        _, issue_status = await issue_token(port=29917, issue_path="/auth/token")
        assert issue_status == 401

        # Supplying the secret yields a token.
        token = await issue_token_ok(port=29917, issue_path="/auth/token", secret="s")

        # Connecting without any token is refused.
        with pytest.raises(websockets.exceptions.InvalidStatus) as rejected:
            async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="x"):
                pass
        assert rejected.value.response.status_code == 401

        # The issued token admits a connection.
        async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="ok", token=token) as client:
            ready = await client.recv_ready()
            assert ready.client_id == "ok"

        # Presenting the same token a second time is refused.
        with pytest.raises(websockets.exceptions.InvalidStatus) as rejected:
            async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="r", token=token):
                pass
        assert rejected.value.response.status_code == 401
    finally:
        await channel.stop()
        await server
|
||||||
|
|
||||||
|
|
||||||
|
# -- Path routing ---------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_custom_path(bus: MagicMock) -> None:
    """A channel configured with path "/my-chat" accepts connections on that path."""
    channel = _ch(bus, 29918, path="/my-chat")
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29918/my-chat", client_id="p") as client:
            ready = await client.recv_ready()
            assert ready.event == "ready"
    finally:
        await channel.stop()
        await server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_wrong_path_404(bus: MagicMock) -> None:
    """Connecting on a path other than the configured one is refused with HTTP 404."""
    channel = _ch(bus, 29919, path="/ws")
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        with pytest.raises(websockets.exceptions.InvalidStatus) as rejected:
            async with WsTestClient("ws://127.0.0.1:29919/wrong", client_id="x"):
                pass
        assert rejected.value.response.status_code == 404
    finally:
        await channel.stop()
        await server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_trailing_slash_normalized(bus: MagicMock) -> None:
    """A request for "/ws/" is accepted when the configured path is "/ws"."""
    channel = _ch(bus, 29920, path="/ws")
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29920/ws/", client_id="s") as client:
            ready = await client.recv_ready()
            assert ready.event == "ready"
    finally:
        await channel.stop()
        await server
|
||||||
|
|
||||||
|
|
||||||
|
# -- Edge cases -----------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_large_message(bus: MagicMock) -> None:
    """A 100,000-character text frame is published inbound unchanged."""
    channel = _ch(bus, 29921)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29921/", client_id="big") as client:
            await client.recv_ready()
            payload = "x" * 100_000
            await client.send_text(payload)
            await asyncio.sleep(0.2)
            published = bus.publish_inbound.call_args.args[0]
            assert published.content == payload
    finally:
        await channel.stop()
        await server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_unicode_roundtrip(bus: MagicMock) -> None:
    """Non-ASCII text (CJK plus an emoji) survives both the inbound and outbound directions."""
    channel = _ch(bus, 29922)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29922/", client_id="u") as client:
            ready = await client.recv_ready()
            text = "你好世界 🌍 日本語テスト"

            # Client -> server.
            await client.send_text(text)
            await asyncio.sleep(0.1)
            published = bus.publish_inbound.call_args.args[0]
            assert published.content == text

            # Server -> client.
            await channel.send(OutboundMessage(
                channel="websocket", chat_id=ready.chat_id, content=text,
            ))
            echoed = await client.recv_message()
            assert echoed.text == text
    finally:
        await channel.stop()
        await server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_rapid_fire(bus: MagicMock) -> None:
    """Fifty inbound frames are all published; fifty outbound messages arrive in send order."""
    channel = _ch(bus, 29923)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29923/", client_id="r") as client:
            ready = await client.recv_ready()

            for i in range(50):
                await client.send_text(f"in-{i}")
            await asyncio.sleep(0.5)
            assert bus.publish_inbound.await_count == 50

            for i in range(50):
                await channel.send(OutboundMessage(
                    channel="websocket", chat_id=ready.chat_id, content=f"out-{i}",
                ))
            received = []
            for _ in range(50):
                message = await client.recv_message()
                received.append(message.text)
            assert received == [f"out-{i}" for i in range(50)]
    finally:
        await channel.stop()
        await server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_invalid_json_as_plain_text(bus: MagicMock) -> None:
    """Text that starts like JSON but does not parse is published verbatim."""
    channel = _ch(bus, 29924)
    server = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        async with WsTestClient("ws://127.0.0.1:29924/", client_id="j") as client:
            await client.recv_ready()
            await client.send_text("{broken json")
            await asyncio.sleep(0.1)
            published = bus.publish_inbound.call_args.args[0]
            assert published.content == "{broken json"
    finally:
        await channel.stop()
        await server
|
||||||
584
tests/channels/test_wecom_channel.py
Normal file
584
tests/channels/test_wecom_channel.py
Normal file
@ -0,0 +1,584 @@
|
|||||||
|
"""Tests for WeCom channel: helpers, download, upload, send, and message processing."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
try:
|
||||||
|
import importlib.util
|
||||||
|
|
||||||
|
WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
|
||||||
|
except ImportError:
|
||||||
|
WECOM_AVAILABLE = False
|
||||||
|
|
||||||
|
if not WECOM_AVAILABLE:
|
||||||
|
pytest.skip("WeCom dependencies not installed (wecom_aibot_sdk)", allow_module_level=True)
|
||||||
|
|
||||||
|
from nanobot.bus.events import OutboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.channels.wecom import (
|
||||||
|
WecomChannel,
|
||||||
|
WecomConfig,
|
||||||
|
_guess_wecom_media_type,
|
||||||
|
_sanitize_filename,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to import the real response class; fall back to a stub if unavailable.
|
||||||
|
try:
|
||||||
|
from wecom_aibot_sdk.utils import WsResponse
|
||||||
|
|
||||||
|
_RealWsResponse = WsResponse
|
||||||
|
except ImportError:
|
||||||
|
_RealWsResponse = None
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeResponse:
|
||||||
|
"""Minimal stand-in for wecom_aibot_sdk WsResponse."""
|
||||||
|
|
||||||
|
def __init__(self, errcode: int = 0, body: dict | None = None, errmsg: str = "ok"):
|
||||||
|
self.errcode = errcode
|
||||||
|
self.errmsg = errmsg
|
||||||
|
self.body = body or {}
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeWsManager:
    """Tracks send_reply calls and returns configurable responses.

    Scripted responses are handed out one per call in list order; once they
    run out, every further call gets a default ``_FakeResponse()``.
    """

    def __init__(self, responses: list[_FakeResponse] | None = None):
        self.responses = responses if responses else []
        # Every (req_id, data, cmd) triple passed to send_reply, in call order.
        self.calls: list[tuple[str, dict, str]] = []
        # Index of the next scripted response to return.
        self._idx = 0

    async def send_reply(self, req_id: str, data: dict, cmd: str) -> _FakeResponse:
        """Record the call and return the next scripted response, or a default one."""
        self.calls.append((req_id, data, cmd))
        if self._idx >= len(self.responses):
            return _FakeResponse()
        scripted = self.responses[self._idx]
        self._idx += 1
        return scripted
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeFrame:
|
||||||
|
"""Minimal frame object with a body dict."""
|
||||||
|
|
||||||
|
def __init__(self, body: dict | None = None):
|
||||||
|
self.body = body or {}
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeWeComClient:
    """Fake WeCom client with mock methods.

    Exposes a ``_FakeWsManager`` as ``_ws_manager`` and ``AsyncMock``
    attributes for ``download_file``, ``reply``, ``reply_stream``,
    ``send_message`` and ``reply_welcome``.
    """

    def __init__(self, ws_responses: list[_FakeResponse] | None = None):
        self._ws_manager = _FakeWsManager(ws_responses)
        # By default a download resolves to a (None, None) pair.
        self.download_file = AsyncMock(return_value=(None, None))
        # The remaining calls are plain awaitable mocks.
        for method_name in ("reply", "reply_stream", "send_message", "reply_welcome"):
            setattr(self, method_name, AsyncMock())
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helper function tests (pure, no async) ──────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_filename_strips_path_traversal() -> None:
    """Leading ``../`` components are dropped, leaving only the final name."""
    sanitized = _sanitize_filename("../../etc/passwd")
    assert sanitized == "passwd"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_filename_keeps_chinese_chars() -> None:
    """A name with Chinese characters and full-width parentheses is returned unchanged."""
    original = "文件(1).jpg"
    assert _sanitize_filename(original) == "文件(1).jpg"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_filename_empty_input() -> None:
    """An empty name sanitizes to an empty string."""
    sanitized = _sanitize_filename("")
    assert sanitized == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_guess_wecom_media_type_image() -> None:
    """Common image extensions are classified as "image"."""
    image_exts = (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
    for ext in image_exts:
        guessed = _guess_wecom_media_type(f"photo{ext}")
        assert guessed == "image"
|
||||||
|
|
||||||
|
|
||||||
|
def test_guess_wecom_media_type_video() -> None:
    """The .mp4, .avi and .mov extensions are classified as "video"."""
    video_exts = (".mp4", ".avi", ".mov")
    for ext in video_exts:
        guessed = _guess_wecom_media_type(f"video{ext}")
        assert guessed == "video"
|
||||||
|
|
||||||
|
|
||||||
|
def test_guess_wecom_media_type_voice() -> None:
    """The .amr, .mp3, .wav and .ogg extensions are classified as "voice"."""
    voice_exts = (".amr", ".mp3", ".wav", ".ogg")
    for ext in voice_exts:
        guessed = _guess_wecom_media_type(f"audio{ext}")
        assert guessed == "voice"
|
||||||
|
|
||||||
|
|
||||||
|
def test_guess_wecom_media_type_file_fallback() -> None:
    """Document and archive extensions are classified as "file"."""
    other_exts = (".pdf", ".doc", ".xlsx", ".zip")
    for ext in other_exts:
        guessed = _guess_wecom_media_type(f"doc{ext}")
        assert guessed == "file"
|
||||||
|
|
||||||
|
|
||||||
|
def test_guess_wecom_media_type_case_insensitive() -> None:
    """Upper-case and mixed-case image extensions are still classified as "image"."""
    upper_case = _guess_wecom_media_type("photo.PNG")
    assert upper_case == "image"
    mixed_case = _guess_wecom_media_type("photo.Jpg")
    assert mixed_case == "image"
|
||||||
|
|
||||||
|
|
||||||
|
# ── _download_and_save_media() ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_download_and_save_success() -> None:
    """Successful download writes file and returns sanitized path."""
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
    client = _FakeWeComClient()
    channel._client = client

    fake_data = b"\x89PNG\r\nfake image"
    client.download_file.return_value = (fake_data, "raw_photo.png")

    # Use a private temporary directory as the media dir: it cannot collide with an
    # existing "photo.png" in the shared system temp dir, and it is removed even when
    # an assertion below fails (the old trailing os.unlink was skipped in that case).
    with tempfile.TemporaryDirectory() as media_dir:
        with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(media_dir)):
            path = await channel._download_and_save_media(
                "https://example.com/img.png", "aes_key", "image", "photo.png"
            )

        assert path is not None
        assert os.path.isfile(path)
        # The explicitly requested name is used, not the "raw_photo.png" reported by the download.
        assert os.path.basename(path) == "photo.png"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_download_and_save_oversized_rejected() -> None:
    """Data exceeding 200MB is rejected → returns None."""
    config = WecomConfig(bot_id="b", secret="s", allow_from=["*"])
    channel = WecomChannel(config, MessageBus())
    client = _FakeWeComClient()
    channel._client = client

    # 200 MB plus one byte of zeros.
    oversized = bytes(200 * 1024 * 1024 + 1)
    client.download_file.return_value = (oversized, "big.bin")

    media_dir = Path(tempfile.gettempdir())
    with patch("nanobot.channels.wecom.get_media_dir", return_value=media_dir):
        outcome = await channel._download_and_save_media("https://example.com/big.bin", "key", "file", "big.bin")

    assert outcome is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_download_and_save_failure() -> None:
    """SDK returns None data → returns None."""
    config = WecomConfig(bot_id="b", secret="s", allow_from=["*"])
    channel = WecomChannel(config, MessageBus())
    client = _FakeWeComClient()
    channel._client = client

    # The download yields neither data nor a file name.
    client.download_file.return_value = (None, None)

    media_dir = Path(tempfile.gettempdir())
    with patch("nanobot.channels.wecom.get_media_dir", return_value=media_dir):
        outcome = await channel._download_and_save_media("https://example.com/fail.png", "key", "image")

    assert outcome is None
|
||||||
|
|
||||||
|
|
||||||
|
# ── _upload_media_ws() ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_upload_media_ws_success() -> None:
    """Happy path: init → chunk → finish → returns (media_id, media_type)."""
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
        handle.write(b"\x89PNG\r\n")
        tmp = handle.name

    try:
        # One scripted reply per send_reply call, consumed in order.
        scripted = [
            _FakeResponse(errcode=0, body={"upload_id": "up_1"}),
            _FakeResponse(errcode=0, body={}),
            _FakeResponse(errcode=0, body={"media_id": "media_abc"}),
        ]
        client = _FakeWeComClient(scripted)
        config = WecomConfig(bot_id="b", secret="s", allow_from=["*"])
        channel = WecomChannel(config, MessageBus())
        channel._client = client

        with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
            media_id, media_type = await channel._upload_media_ws(client, tmp)

        assert media_id == "media_abc"
        assert media_type == "image"
    finally:
        # delete=False above means the file must be removed by hand.
        os.unlink(tmp)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_upload_media_ws_oversized_file() -> None:
    """File >200MB triggers ValueError → returns (None, None)."""
    # Rather than creating a real 200MB+ file, fake the reported size and stub out open().
    fake_size = 200 * 1024 * 1024 + 1
    with patch("os.path.getsize", return_value=fake_size):
        with patch("builtins.open", MagicMock()):
            client = _FakeWeComClient()
            config = WecomConfig(bot_id="b", secret="s", allow_from=["*"])
            channel = WecomChannel(config, MessageBus())
            channel._client = client

            outcome = await channel._upload_media_ws(client, "/fake/large.bin")
            assert outcome == (None, None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_upload_media_ws_init_failure() -> None:
    """Init step returns errcode != 0 → returns (None, None)."""
    with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as handle:
        handle.write(b"hello")
        tmp = handle.name

    try:
        # The very first reply already carries an error code.
        scripted = [_FakeResponse(errcode=50001, errmsg="invalid")]
        client = _FakeWeComClient(scripted)
        config = WecomConfig(bot_id="b", secret="s", allow_from=["*"])
        channel = WecomChannel(config, MessageBus())
        channel._client = client

        with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
            outcome = await channel._upload_media_ws(client, tmp)

        assert outcome == (None, None)
    finally:
        # delete=False above means the file must be removed by hand.
        os.unlink(tmp)
|
||||||
|
finally:
|
||||||
|
os.unlink(tmp)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_upload_media_ws_chunk_failure() -> None:
    """A non-zero errcode on the second reply makes the upload return (None, None)."""
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as png_file:
        png_file.write(b"\x89PNG\r\n")
        png_path = png_file.name

    try:
        # First reply succeeds and hands out an upload id; the second one fails.
        fake_client = _FakeWeComClient(
            [
                _FakeResponse(errcode=0, body={"upload_id": "up_1"}),
                _FakeResponse(errcode=50002, errmsg="chunk fail"),
            ]
        )
        config = WecomConfig(bot_id="b", secret="s", allow_from=["*"])
        channel = WecomChannel(config, MessageBus())
        channel._client = fake_client

        with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
            outcome = await channel._upload_media_ws(fake_client, png_path)

        assert outcome == (None, None)
    finally:
        os.unlink(png_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_upload_media_ws_finish_no_media_id() -> None:
    """When the last reply carries no media_id the upload returns (None, None)."""
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as png_file:
        png_file.write(b"\x89PNG\r\n")
        png_path = png_file.name

    try:
        fake_client = _FakeWeComClient(
            [
                _FakeResponse(errcode=0, body={"upload_id": "up_1"}),
                _FakeResponse(errcode=0, body={}),
                _FakeResponse(errcode=0, body={}),  # no media_id
            ]
        )
        config = WecomConfig(bot_id="b", secret="s", allow_from=["*"])
        channel = WecomChannel(config, MessageBus())
        channel._client = fake_client

        with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
            outcome = await channel._upload_media_ws(fake_client, png_path)

        assert outcome == (None, None)
    finally:
        os.unlink(png_path)
|
||||||
|
|
||||||
|
|
||||||
|
# ── send() ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_text_with_frame() -> None:
    """With a frame stored for the chat, the text is delivered via reply_stream."""
    config = WecomConfig(bot_id="b", secret="s", allow_from=["*"])
    channel = WecomChannel(config, MessageBus())
    fake_client = _FakeWeComClient()
    channel._client = fake_client
    channel._generate_req_id = lambda x: f"req_{x}"
    channel._chat_frames["chat1"] = _FakeFrame()

    outbound = OutboundMessage(channel="wecom", chat_id="chat1", content="hello")
    await channel.send(outbound)

    fake_client.reply_stream.assert_called_once()
    positional = fake_client.reply_stream.call_args.args
    assert positional[2] == "hello"  # content arg
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_progress_with_frame() -> None:
    """A message flagged ``_progress`` goes through reply_stream with finish=False."""
    config = WecomConfig(bot_id="b", secret="s", allow_from=["*"])
    channel = WecomChannel(config, MessageBus())
    fake_client = _FakeWeComClient()
    channel._client = fake_client
    channel._generate_req_id = lambda x: f"req_{x}"
    channel._chat_frames["chat1"] = _FakeFrame()

    progress_update = OutboundMessage(
        channel="wecom",
        chat_id="chat1",
        content="thinking...",
        metadata={"_progress": True},
    )
    await channel.send(progress_update)

    fake_client.reply_stream.assert_called_once()
    recorded = fake_client.reply_stream.call_args
    assert recorded.args[2] == "thinking..."  # content arg
    assert recorded.kwargs["finish"] is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_proactive_without_frame() -> None:
    """With no stored frame, send() calls send_message with a markdown payload."""
    config = WecomConfig(bot_id="b", secret="s", allow_from=["*"])
    channel = WecomChannel(config, MessageBus())
    fake_client = _FakeWeComClient()
    channel._client = fake_client

    outbound = OutboundMessage(channel="wecom", chat_id="chat1", content="proactive msg")
    await channel.send(outbound)

    fake_client.send_message.assert_called_once()
    target_chat, payload = fake_client.send_message.call_args.args[:2]
    assert target_chat == "chat1"
    assert payload["msgtype"] == "markdown"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_media_then_text() -> None:
    """An attached image is uploaded and replied, and the text goes via reply_stream."""
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as png_file:
        png_file.write(b"\x89PNG\r\n")
        png_path = png_file.name

    try:
        upload_replies = [
            _FakeResponse(errcode=0, body={"upload_id": "up_1"}),
            _FakeResponse(errcode=0, body={}),
            _FakeResponse(errcode=0, body={"media_id": "media_123"}),
        ]

        config = WecomConfig(bot_id="b", secret="s", allow_from=["*"])
        channel = WecomChannel(config, MessageBus())
        fake_client = _FakeWeComClient(upload_replies)
        channel._client = fake_client
        channel._generate_req_id = lambda x: f"req_{x}"
        channel._chat_frames["chat1"] = _FakeFrame()

        outbound = OutboundMessage(
            channel="wecom", chat_id="chat1", content="see image", media=[png_path]
        )
        await channel.send(outbound)

        # Media should have been sent via reply
        image_replies = [
            call
            for call in fake_client.reply.call_args_list
            if call.args[1].get("msgtype") == "image"
        ]
        assert len(image_replies) == 1
        assert image_replies[0].args[1]["image"]["media_id"] == "media_123"

        # Text should have been sent via reply_stream
        fake_client.reply_stream.assert_called_once()
    finally:
        os.unlink(png_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_media_file_not_found() -> None:
    """A media path that does not exist yields no media reply; the text is still sent."""
    config = WecomConfig(bot_id="b", secret="s", allow_from=["*"])
    channel = WecomChannel(config, MessageBus())
    fake_client = _FakeWeComClient()
    channel._client = fake_client
    channel._generate_req_id = lambda x: f"req_{x}"
    channel._chat_frames["chat1"] = _FakeFrame()

    outbound = OutboundMessage(
        channel="wecom", chat_id="chat1", content="hello", media=["/nonexistent/file.png"]
    )
    await channel.send(outbound)

    # reply_stream should still be called for the text part
    fake_client.reply_stream.assert_called_once()

    # No media reply should happen
    media_kinds = ("image", "file", "video")
    media_replies = [
        call
        for call in fake_client.reply.call_args_list
        if call.args[1].get("msgtype") in media_kinds
    ]
    assert len(media_replies) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_exception_caught_not_raised() -> None:
    """An error raised by reply_stream must not escape from send()."""
    config = WecomConfig(bot_id="b", secret="s", allow_from=["*"])
    channel = WecomChannel(config, MessageBus())
    fake_client = _FakeWeComClient()
    channel._client = fake_client
    channel._generate_req_id = lambda x: f"req_{x}"
    channel._chat_frames["chat1"] = _FakeFrame()

    # Make reply_stream raise
    fake_client.reply_stream.side_effect = RuntimeError("boom")

    outbound = OutboundMessage(channel="wecom", chat_id="chat1", content="fail test")
    await channel.send(outbound)
    # Reaching this point without an exception is the whole assertion.
|
||||||
|
|
||||||
|
|
||||||
|
# ── _process_message() ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_process_text_message() -> None:
    """A text frame is published on the bus with sender, chat, content and msg_type."""
    config = WecomConfig(bot_id="b", secret="s", allow_from=["user1"])
    channel = WecomChannel(config, MessageBus())
    channel._client = _FakeWeComClient()

    body = {
        "msgid": "msg_text_1",
        "chatid": "chat1",
        "chattype": "single",
        "from": {"userid": "user1"},
        "text": {"content": "hello wecom"},
    }
    await channel._process_message(_FakeFrame(body=body), "text")

    inbound = await channel.bus.consume_inbound()
    assert inbound.sender_id == "user1"
    assert inbound.chat_id == "chat1"
    assert inbound.content == "hello wecom"
    assert inbound.metadata["msg_type"] == "text"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_process_image_message() -> None:
    """Image message: download success → one media path, tagged in the content.

    The media directory is a private temporary directory that is removed when
    the test ends. The previous version created a placeholder file with
    ``delete=False`` only to borrow its parent directory, never deleted that
    placeholder, and let the channel write ``photo.png`` straight into the
    shared system temp dir.
    """
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
    client = _FakeWeComClient()
    # The fake download hands back the bytes plus the file name to store them under.
    client.download_file.return_value = (b"\x89PNG\r\n", "photo.png")
    channel._client = client

    frame = _FakeFrame(body={
        "msgid": "msg_img_1",
        "chatid": "chat1",
        "from": {"userid": "user1"},
        "image": {"url": "https://example.com/img.png", "aeskey": "key123"},
    })

    with tempfile.TemporaryDirectory() as media_dir, patch(
        "nanobot.channels.wecom.get_media_dir", return_value=Path(media_dir)
    ):
        await channel._process_message(frame, "image")

        msg = await channel.bus.consume_inbound()
        assert len(msg.media) == 1
        assert msg.media[0].endswith("photo.png")
        assert "[image:" in msg.content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_process_file_message() -> None:
    """File message: download success → one media path, tagged in the content.

    The media directory is a private temporary directory that is removed when
    the test ends. The previous version left its ``delete=False`` placeholder
    PDF behind and let the channel write ``report.pdf`` into the shared system
    temp dir.
    """
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
    client = _FakeWeComClient()
    # The fake download hands back the bytes plus the file name to store them under.
    client.download_file.return_value = (b"%PDF-1.4 fake", "report.pdf")
    channel._client = client

    frame = _FakeFrame(body={
        "msgid": "msg_file_1",
        "chatid": "chat1",
        "from": {"userid": "user1"},
        "file": {"url": "https://example.com/report.pdf", "aeskey": "key456", "name": "report.pdf"},
    })

    with tempfile.TemporaryDirectory() as media_dir, patch(
        "nanobot.channels.wecom.get_media_dir", return_value=Path(media_dir)
    ):
        await channel._process_message(frame, "file")

        msg = await channel.bus.consume_inbound()
        assert len(msg.media) == 1
        assert msg.media[0].endswith("report.pdf")
        assert "[file: report.pdf]" in msg.content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_process_voice_message() -> None:
    """A voice frame's text appears in the bus message along with a [voice] tag."""
    config = WecomConfig(bot_id="b", secret="s", allow_from=["user1"])
    channel = WecomChannel(config, MessageBus())
    channel._client = _FakeWeComClient()

    body = {
        "msgid": "msg_voice_1",
        "chatid": "chat1",
        "from": {"userid": "user1"},
        "voice": {"content": "transcribed text here"},
    }
    await channel._process_message(_FakeFrame(body=body), "voice")

    inbound = await channel.bus.consume_inbound()
    assert "transcribed text here" in inbound.content
    assert "[voice]" in inbound.content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_process_message_deduplication() -> None:
    """Processing the same frame twice publishes only one bus message."""
    config = WecomConfig(bot_id="b", secret="s", allow_from=["user1"])
    channel = WecomChannel(config, MessageBus())
    channel._client = _FakeWeComClient()

    duplicate_frame = _FakeFrame(
        body={
            "msgid": "msg_dup_1",
            "chatid": "chat1",
            "from": {"userid": "user1"},
            "text": {"content": "once"},
        }
    )

    # Deliver the identical frame (same msgid) two times in a row.
    for _ in range(2):
        await channel._process_message(duplicate_frame, "text")

    first = await channel.bus.consume_inbound()
    assert first.content == "once"

    # Second message should not appear on the bus
    assert channel.bus.inbound.empty()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_process_message_empty_content_skipped() -> None:
    """A text frame whose content is the empty string publishes nothing on the bus."""
    config = WecomConfig(bot_id="b", secret="s", allow_from=["user1"])
    channel = WecomChannel(config, MessageBus())
    channel._client = _FakeWeComClient()

    body = {
        "msgid": "msg_empty_1",
        "chatid": "chat1",
        "from": {"userid": "user1"},
        "text": {"content": ""},
    }
    await channel._process_message(_FakeFrame(body=body), "text")

    assert channel.bus.inbound.empty()
|
||||||
@ -1003,3 +1003,185 @@ async def test_download_media_item_non_image_requires_aes_key_even_with_full_url
|
|||||||
|
|
||||||
assert saved_path is None
|
assert saved_path is None
|
||||||
channel._client.get.assert_not_awaited()
|
channel._client.get.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests for media-send error classification (network vs non-network errors)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_outbound_msg(chat_id: str = "wx-user", content: str = "", media: list | None = None):
    """Return an OutboundMessage on the "weixin" channel for send() tests.

    ``media`` falls back to an empty list when omitted (or otherwise falsy),
    and ``metadata`` is always an empty dict.
    """
    from nanobot.bus.events import OutboundMessage

    fields = {
        "channel": "weixin",
        "chat_id": chat_id,
        "content": content,
        "media": media or [],
        "metadata": {},
    }
    return OutboundMessage(**fields)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_media_timeout_error_propagates_without_text_fallback() -> None:
    """An httpx.TimeoutException from the media send is re-raised by send().

    The text fallback (_send_text) must not be attempted in that case.
    """
    channel, _bus = _make_channel()
    channel._client = object()
    channel._token = "token"
    channel._context_tokens["wx-user"] = "ctx-1"
    timeout_error = httpx.TimeoutException("timed out")
    channel._send_media_file = AsyncMock(side_effect=timeout_error)
    channel._send_text = AsyncMock()

    outbound = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"])

    with pytest.raises(httpx.TimeoutException, match="timed out"):
        await channel.send(outbound)

    # _send_text must NOT have been called as a fallback
    channel._send_text.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_media_transport_error_propagates_without_text_fallback() -> None:
    """An httpx.TransportError from the media send is re-raised; no text fallback."""
    channel, _bus = _make_channel()
    channel._client = object()
    channel._token = "token"
    channel._context_tokens["wx-user"] = "ctx-1"
    transport_error = httpx.TransportError("connection reset")
    channel._send_media_file = AsyncMock(side_effect=transport_error)
    channel._send_text = AsyncMock()

    outbound = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"])

    with pytest.raises(httpx.TransportError, match="connection reset"):
        await channel.send(outbound)

    channel._send_text.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_media_5xx_http_status_error_propagates_without_text_fallback() -> None:
    """An httpx.HTTPStatusError carrying a 503 response is re-raised; no text fallback."""
    channel, _bus = _make_channel()
    channel._client = object()
    channel._token = "token"
    channel._context_tokens["wx-user"] = "ctx-1"

    upload_request = httpx.Request("POST", "https://example.test/upload")
    unavailable = httpx.Response(status_code=503, request=upload_request)
    status_error = httpx.HTTPStatusError(
        "Service Unavailable", request=unavailable.request, response=unavailable
    )
    channel._send_media_file = AsyncMock(side_effect=status_error)
    channel._send_text = AsyncMock()

    outbound = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"])

    with pytest.raises(httpx.HTTPStatusError, match="Service Unavailable"):
        await channel.send(outbound)

    channel._send_text.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_media_4xx_http_status_error_falls_back_to_text() -> None:
    """An httpx.HTTPStatusError carrying a 400 response is swallowed by send().

    Instead of re-raising, send() reports the failed file through _send_text.
    """
    channel, _bus = _make_channel()
    channel._client = object()
    channel._token = "token"
    channel._context_tokens["wx-user"] = "ctx-1"

    upload_request = httpx.Request("POST", "https://example.test/upload")
    bad_request = httpx.Response(status_code=400, request=upload_request)
    status_error = httpx.HTTPStatusError(
        "Bad Request", request=bad_request.request, response=bad_request
    )
    channel._send_media_file = AsyncMock(side_effect=status_error)
    channel._send_text = AsyncMock()

    outbound = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"])

    # Should NOT raise — 4xx is a client error, non-retryable
    await channel.send(outbound)

    # _send_text should have been called with the fallback message
    expected_args = ("wx-user", "[Failed to send: photo.jpg]", "ctx-1")
    channel._send_text.assert_awaited_once_with(*expected_args)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_media_file_not_found_falls_back_to_text() -> None:
    """A FileNotFoundError from the media send is reported via _send_text, not raised."""
    channel, _bus = _make_channel()
    channel._client = object()
    channel._token = "token"
    channel._context_tokens["wx-user"] = "ctx-1"
    missing_error = FileNotFoundError("Media file not found: /tmp/missing.jpg")
    channel._send_media_file = AsyncMock(side_effect=missing_error)
    channel._send_text = AsyncMock()

    outbound = _make_outbound_msg(chat_id="wx-user", media=["/tmp/missing.jpg"])

    # Should NOT raise
    await channel.send(outbound)

    expected_args = ("wx-user", "[Failed to send: missing.jpg]", "ctx-1")
    channel._send_text.assert_awaited_once_with(*expected_args)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_media_value_error_falls_back_to_text() -> None:
    """A ValueError from the media send is reported via _send_text, not raised."""
    channel, _bus = _make_channel()
    channel._client = object()
    channel._token = "token"
    channel._context_tokens["wx-user"] = "ctx-1"
    format_error = ValueError("Unsupported media format")
    channel._send_media_file = AsyncMock(side_effect=format_error)
    channel._send_text = AsyncMock()

    outbound = _make_outbound_msg(chat_id="wx-user", media=["/tmp/file.xyz"])

    # Should NOT raise
    await channel.send(outbound)

    expected_args = ("wx-user", "[Failed to send: file.xyz]", "ctx-1")
    channel._send_text.assert_awaited_once_with(*expected_args)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_send_media_network_error_does_not_double_api_calls() -> None:
    """An httpx.ConnectError results in one media attempt and no text attempt.

    The message carries both text content and a media file; after the media
    send fails, send() must raise without calling _send_text at all.
    """
    channel, _bus = _make_channel()
    channel._client = object()
    channel._token = "token"
    channel._context_tokens["wx-user"] = "ctx-1"
    connect_error = httpx.ConnectError("connection refused")
    channel._send_media_file = AsyncMock(side_effect=connect_error)
    channel._send_text = AsyncMock()

    outbound = _make_outbound_msg(chat_id="wx-user", content="hello", media=["/tmp/img.png"])

    with pytest.raises(httpx.ConnectError):
        await channel.send(outbound)

    # _send_media_file called once, _send_text never called
    channel._send_media_file.assert_awaited_once()
    channel._send_text.assert_not_awaited()
|
||||||
|
|||||||
227
tests/channels/ws_test_client.py
Normal file
227
tests/channels/ws_test_client.py
Normal file
@ -0,0 +1,227 @@
|
|||||||
|
"""Lightweight WebSocket test client for integration testing the nanobot WebSocket channel.
|
||||||
|
|
||||||
|
Provides an async ``WsTestClient`` class and token-issuance helpers that
|
||||||
|
integration tests can import and use directly::
|
||||||
|
|
||||||
|
from ws_test_client import WsTestClient
|
||||||
|
|
||||||
|
async with WsTestClient("ws://127.0.0.1:8765/", client_id="t") as c:
|
||||||
|
ready = await c.recv_ready()
|
||||||
|
await c.send_text("hello")
|
||||||
|
msg = await c.recv_message()
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import websockets
|
||||||
|
from websockets.asyncio.client import ClientConnection
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class WsMessage:
    """One decoded server message: its event name plus the full JSON payload.

    The properties below are shortcuts into ``raw``; each returns ``None``
    when the corresponding key is absent.
    """

    # Event name of the message.
    event: str
    # Complete decoded payload; left out of the generated repr.
    raw: dict[str, Any] = field(repr=False)

    def _lookup(self, key: str) -> Any:
        """Return ``raw[key]`` or ``None`` when the key is missing."""
        return self.raw.get(key)

    @property
    def text(self) -> str | None:
        return self._lookup("text")

    @property
    def chat_id(self) -> str | None:
        return self._lookup("chat_id")

    @property
    def client_id(self) -> str | None:
        return self._lookup("client_id")

    @property
    def media(self) -> list[str] | None:
        return self._lookup("media")

    @property
    def reply_to(self) -> str | None:
        return self._lookup("reply_to")

    @property
    def stream_id(self) -> str | None:
        return self._lookup("stream_id")

    def __eq__(self, other: object) -> bool:
        """Compare by event name and payload; defer for non-WsMessage operands."""
        if isinstance(other, WsMessage):
            return self.event == other.event and self.raw == other.raw
        return NotImplemented
|
||||||
|
|
||||||
|
|
||||||
|
class WsTestClient:
    """Async WebSocket test client with helper methods for common operations.

    Usage::

        async with WsTestClient("ws://127.0.0.1:8765/", client_id="tester") as client:
            ready = await client.recv_ready()
            await client.send_text("hello")
            msg = await client.recv_message(timeout=5.0)
    """

    def __init__(
        self,
        uri: str,
        *,
        client_id: str = "test-client",
        token: str = "",
        extra_headers: dict[str, str] | None = None,
    ) -> None:
        """Build the connection URI; no network activity happens here.

        Args:
            uri: Base WebSocket URI; it may already contain a query string.
            client_id: Sent as the ``client_id`` query parameter unless empty.
            token: Sent as the ``token`` query parameter unless empty.
            extra_headers: Extra HTTP headers passed on when connecting.
        """
        # Local import keeps this class independent of the module's import block.
        from urllib.parse import quote, urlencode

        params: list[tuple[str, str]] = []
        if client_id:
            params.append(("client_id", client_id))
        if token:
            params.append(("token", token))
        if params:
            # Percent-encode values so characters such as "&", "#" or spaces
            # cannot corrupt the query string; plain values are unchanged.
            sep = "&" if "?" in uri else "?"
            self._uri = uri + sep + urlencode(params, quote_via=quote)
        else:
            self._uri = uri
        self._extra_headers = extra_headers
        self._ws: ClientConnection | None = None

    async def connect(self) -> None:
        """Open the WebSocket connection to the URI built in ``__init__``."""
        self._ws = await websockets.connect(
            self._uri,
            additional_headers=self._extra_headers,
        )

    async def close(self) -> None:
        """Close the connection if one is open and forget it."""
        if self._ws:
            await self._ws.close()
            self._ws = None

    async def __aenter__(self) -> WsTestClient:
        await self.connect()
        return self

    async def __aexit__(self, *args: Any) -> None:
        await self.close()

    @property
    def ws(self) -> ClientConnection:
        """The live connection; asserts that ``connect()`` has been called."""
        assert self._ws is not None, "Client is not connected"
        return self._ws

    # -- Receiving --------------------------------------------------------

    async def recv_raw(self, timeout: float = 10.0) -> dict[str, Any]:
        """Receive and parse one raw JSON message with timeout."""
        raw = await asyncio.wait_for(self.ws.recv(), timeout=timeout)
        return json.loads(raw)

    async def recv(self, timeout: float = 10.0) -> WsMessage:
        """Receive one message, returning a WsMessage wrapper.

        A payload without an ``event`` key gets the empty string as its event.
        """
        data = await self.recv_raw(timeout)
        return WsMessage(event=data.get("event", ""), raw=data)

    async def recv_ready(self, timeout: float = 5.0) -> WsMessage:
        """Receive and validate the 'ready' event."""
        msg = await self.recv(timeout)
        assert msg.event == "ready", f"Expected 'ready' event, got '{msg.event}'"
        return msg

    async def recv_message(self, timeout: float = 10.0) -> WsMessage:
        """Receive and validate a 'message' event."""
        msg = await self.recv(timeout)
        assert msg.event == "message", f"Expected 'message' event, got '{msg.event}'"
        return msg

    async def recv_delta(self, timeout: float = 10.0) -> WsMessage:
        """Receive and validate a 'delta' event."""
        msg = await self.recv(timeout)
        assert msg.event == "delta", f"Expected 'delta' event, got '{msg.event}'"
        return msg

    async def recv_stream_end(self, timeout: float = 10.0) -> WsMessage:
        """Receive and validate a 'stream_end' event."""
        msg = await self.recv(timeout)
        assert msg.event == "stream_end", f"Expected 'stream_end' event, got '{msg.event}'"
        return msg

    async def collect_stream(self, timeout: float = 10.0) -> list[WsMessage]:
        """Collect every message up to and including the first 'stream_end'.

        ``timeout`` applies to each individual receive, not to the whole stream.
        """
        messages: list[WsMessage] = []
        while True:
            msg = await self.recv(timeout)
            messages.append(msg)
            if msg.event == "stream_end":
                break
        return messages

    async def recv_n(self, n: int, timeout: float = 10.0) -> list[WsMessage]:
        """Receive exactly *n* messages, each with its own *timeout*."""
        return [await self.recv(timeout) for _ in range(n)]

    # -- Sending ----------------------------------------------------------

    async def send_text(self, text: str) -> None:
        """Send a plain text frame."""
        await self.ws.send(text)

    async def send_json(self, data: dict[str, Any]) -> None:
        """Send a JSON frame."""
        await self.ws.send(json.dumps(data, ensure_ascii=False))

    async def send_content(self, content: str) -> None:
        """Send content in the preferred JSON format ``{"content": ...}``."""
        await self.send_json({"content": content})

    # -- Connection introspection -----------------------------------------

    @property
    def closed(self) -> bool:
        """True when there is no connection or the connection is fully closed."""
        if self._ws is None:
            return True
        # The asyncio ClientConnection imported by this module exposes no
        # ``closed`` attribute; its ``state`` is the supported check.
        from websockets.protocol import State

        return self._ws.state is State.CLOSED
|
||||||
|
|
||||||
|
|
||||||
|
# -- Token issuance helpers -----------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def issue_token(
    host: str = "127.0.0.1",
    port: int = 8765,
    issue_path: str = "/auth/token",
    secret: str = "",
) -> tuple[dict[str, Any] | None, int]:
    """Fetch a token from the HTTP token-issue endpoint.

    A ``Bearer`` Authorization header is sent only when *secret* is non-empty.
    The blocking ``httpx.get`` call runs in the default executor.

    Returns:
        ``(parsed_json_or_None, status_code)``; the first element is ``None``
        when the response body cannot be parsed as JSON.
    """
    endpoint = f"http://{host}:{port}{issue_path}"
    auth_headers: dict[str, str] = {"Authorization": f"Bearer {secret}"} if secret else {}

    def _fetch():
        return httpx.get(endpoint, headers=auth_headers, timeout=5.0)

    response = await asyncio.get_running_loop().run_in_executor(None, _fetch)

    try:
        payload = response.json()
    except Exception:
        payload = None
    return payload, response.status_code
|
||||||
|
|
||||||
|
|
||||||
|
async def issue_token_ok(
    host: str = "127.0.0.1",
    port: int = 8765,
    issue_path: str = "/auth/token",
    secret: str = "",
) -> str:
    """Issue a token via ``issue_token`` and return it after sanity checks.

    Asserts a 200 status, a parsed JSON body, and a token that starts with
    ``nbwt_``.
    """
    payload, status_code = await issue_token(host, port, issue_path, secret)
    assert status_code == 200, f"Token issue failed with status {status_code}"
    assert payload is not None

    issued = payload["token"]
    assert issued.startswith("nbwt_"), f"Unexpected token format: {issued}"
    return issued
|
||||||
@ -1126,6 +1126,153 @@ def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path)
|
|||||||
assert "port 18792" in result.stdout
|
assert "port 18792" in result.stdout
|
||||||
|
|
||||||
|
|
||||||
|
def test_gateway_health_endpoint_binds_and_serves_expected_responses(
|
||||||
|
monkeypatch, tmp_path: Path
|
||||||
|
) -> None:
|
||||||
|
config_file = _write_instance_config(tmp_path)
|
||||||
|
config = Config()
|
||||||
|
config.gateway.port = 18791
|
||||||
|
captured: dict[str, object] = {}
|
||||||
|
|
||||||
|
class _FakeDream:
|
||||||
|
model = None
|
||||||
|
max_batch_size = 0
|
||||||
|
max_iterations = 0
|
||||||
|
|
||||||
|
async def run(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
class _FakeAgentLoop:
|
||||||
|
def __init__(self, **_kwargs) -> None:
|
||||||
|
self.model = "test-model"
|
||||||
|
self.dream = _FakeDream()
|
||||||
|
|
||||||
|
async def run(self) -> None:
|
||||||
|
await asyncio.Event().wait()
|
||||||
|
|
||||||
|
async def close_mcp(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
class _FakeChannelManager:
|
||||||
|
def __init__(self, _config, _bus) -> None:
|
||||||
|
self.enabled_channels = ["telegram", "discord"]
|
||||||
|
|
||||||
|
async def start_all(self) -> None:
|
||||||
|
await asyncio.Event().wait()
|
||||||
|
|
||||||
|
async def stop_all(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
class _FakeCronService:
|
||||||
|
def __init__(self, _store_path: Path) -> None:
|
||||||
|
self.on_job = None
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def status(self) -> dict[str, int]:
|
||||||
|
return {"jobs": 0}
|
||||||
|
|
||||||
|
def register_system_job(self, _job) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
class _FakeHeartbeatService:
|
||||||
|
def __init__(self, **_kwargs) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
class _FakeServer:
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def serve_forever(self) -> None:
|
||||||
|
raise _StopGatewayError("stop")
|
||||||
|
|
||||||
|
async def _fake_start_server(handler, host: str, port: int):
|
||||||
|
captured["handler"] = handler
|
||||||
|
captured["host"] = host
|
||||||
|
captured["port"] = port
|
||||||
|
return _FakeServer()
|
||||||
|
|
||||||
|
class _FakeReader:
|
||||||
|
def __init__(self, payload: bytes) -> None:
|
||||||
|
self.payload = payload
|
||||||
|
|
||||||
|
async def read(self, _size: int) -> bytes:
|
||||||
|
return self.payload
|
||||||
|
|
||||||
|
class _FakeWriter:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.output = b""
|
||||||
|
self.closed = False
|
||||||
|
|
||||||
|
def write(self, data: bytes) -> None:
|
||||||
|
self.output += data
|
||||||
|
|
||||||
|
async def drain(self) -> None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
self.closed = True
|
||||||
|
|
||||||
|
_patch_cli_command_runtime(
|
||||||
|
monkeypatch,
|
||||||
|
config,
|
||||||
|
message_bus=lambda: object(),
|
||||||
|
session_manager=lambda _workspace: object(),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
||||||
|
monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _FakeChannelManager)
|
||||||
|
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCronService)
|
||||||
|
monkeypatch.setattr("nanobot.heartbeat.service.HeartbeatService", _FakeHeartbeatService)
|
||||||
|
monkeypatch.setattr("asyncio.start_server", _fake_start_server)
|
||||||
|
|
||||||
|
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert captured["host"] == "127.0.0.1"
|
||||||
|
assert captured["port"] == 18791
|
||||||
|
assert "Health endpoint: http://127.0.0.1:18791/health" in result.stdout
|
||||||
|
|
||||||
|
def _call_handler(path: str) -> tuple[str, _FakeWriter]:
|
||||||
|
request = f"GET {path} HTTP/1.1\r\nHost: localhost\r\n\r\n".encode()
|
||||||
|
writer = _FakeWriter()
|
||||||
|
handler = captured["handler"]
|
||||||
|
assert callable(handler)
|
||||||
|
asyncio.run(handler(_FakeReader(request), writer))
|
||||||
|
return writer.output.decode(), writer
|
||||||
|
|
||||||
|
root_response, root_writer = _call_handler("/")
|
||||||
|
assert root_writer.closed is True
|
||||||
|
assert "HTTP/1.0 404 Not Found" in root_response
|
||||||
|
assert root_response.endswith("\r\n\r\nNot Found")
|
||||||
|
|
||||||
|
health_response, health_writer = _call_handler("/health")
|
||||||
|
assert health_writer.closed is True
|
||||||
|
assert "HTTP/1.0 200 OK" in health_response
|
||||||
|
health_body = json.loads(health_response.split("\r\n\r\n", 1)[1])
|
||||||
|
assert health_body == {"status": "ok"}
|
||||||
|
|
||||||
|
missing_response, missing_writer = _call_handler("/missing")
|
||||||
|
assert missing_writer.closed is True
|
||||||
|
assert "HTTP/1.0 404 Not Found" in missing_response
|
||||||
|
assert missing_response.endswith("\r\n\r\nNot Found")
|
||||||
|
|
||||||
|
|
||||||
def test_serve_uses_api_config_defaults_and_workspace_override(
|
def test_serve_uses_api_config_defaults_and_workspace_override(
|
||||||
monkeypatch, tmp_path: Path
|
monkeypatch, tmp_path: Path
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
@ -148,7 +148,7 @@ class TestRestartCommand:
|
|||||||
assert response is not None
|
assert response is not None
|
||||||
assert "Model: test-model" in response.content
|
assert "Model: test-model" in response.content
|
||||||
assert "Tokens: 0 in / 0 out" in response.content
|
assert "Tokens: 0 in / 0 out" in response.content
|
||||||
assert "Context: 20k/64k (31%)" in response.content
|
assert "Context: 20k/65k (31%)" in response.content
|
||||||
assert "Session: 3 messages" in response.content
|
assert "Session: 3 messages" in response.content
|
||||||
assert "Uptime: 2m 5s" in response.content
|
assert "Uptime: 2m 5s" in response.content
|
||||||
assert response.metadata == {"render_as": "text"}
|
assert response.metadata == {"render_as": "text"}
|
||||||
@ -186,7 +186,7 @@ class TestRestartCommand:
|
|||||||
|
|
||||||
assert response is not None
|
assert response is not None
|
||||||
assert "Tokens: 1200 in / 34 out" in response.content
|
assert "Tokens: 1200 in / 34 out" in response.content
|
||||||
assert "Context: 1k/64k (1%)" in response.content
|
assert "Context: 1k/65k (1%)" in response.content
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_process_direct_preserves_render_metadata(self):
|
async def test_process_direct_preserves_render_metadata(self):
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -114,6 +115,41 @@ async def test_run_history_persisted_to_disk(tmp_path) -> None:
|
|||||||
assert loaded.state.run_history[0].status == "ok"
|
assert loaded.state.run_history[0].status == "ok"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_job_disabled_does_not_flip_running_state(tmp_path) -> None:
|
||||||
|
store_path = tmp_path / "cron" / "jobs.json"
|
||||||
|
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
|
||||||
|
job = service.add_job(
|
||||||
|
name="disabled",
|
||||||
|
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||||
|
message="hello",
|
||||||
|
)
|
||||||
|
service.enable_job(job.id, enabled=False)
|
||||||
|
|
||||||
|
result = await service.run_job(job.id)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert service._running is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_job_preserves_running_service_state(tmp_path) -> None:
|
||||||
|
store_path = tmp_path / "cron" / "jobs.json"
|
||||||
|
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
|
||||||
|
service._running = True
|
||||||
|
job = service.add_job(
|
||||||
|
name="manual",
|
||||||
|
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||||
|
message="hello",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await service.run_job(job.id, force=True)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert service._running is True
|
||||||
|
service.stop()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_running_service_honors_external_disable(tmp_path) -> None:
|
async def test_running_service_honors_external_disable(tmp_path) -> None:
|
||||||
store_path = tmp_path / "cron" / "jobs.json"
|
store_path = tmp_path / "cron" / "jobs.json"
|
||||||
@ -158,24 +194,49 @@ def test_remove_job_refuses_system_jobs(tmp_path) -> None:
|
|||||||
assert service.get_job("dream") is not None
|
assert service.get_job("dream") is not None
|
||||||
|
|
||||||
|
|
||||||
def test_reload_jobs(tmp_path):
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_server_not_jobs(tmp_path):
|
||||||
store_path = tmp_path / "cron" / "jobs.json"
|
store_path = tmp_path / "cron" / "jobs.json"
|
||||||
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
|
called = []
|
||||||
service.add_job(
|
async def on_job(job):
|
||||||
name="hist",
|
called.append(job.name)
|
||||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
|
||||||
message="hello",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(service.list_jobs()) == 1
|
service = CronService(store_path, on_job=on_job, max_sleep_ms=1000)
|
||||||
|
await service.start()
|
||||||
|
assert len(service.list_jobs()) == 0
|
||||||
|
|
||||||
service2 = CronService(tmp_path / "cron" / "jobs.json")
|
service2 = CronService(tmp_path / "cron" / "jobs.json")
|
||||||
service2.add_job(
|
service2.add_job(
|
||||||
name="hist2",
|
name="hist",
|
||||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
schedule=CronSchedule(kind="every", every_ms=500),
|
||||||
message="hello2",
|
message="hello",
|
||||||
)
|
)
|
||||||
assert len(service.list_jobs()) == 2
|
assert len(service.list_jobs()) == 1
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
assert len(called) != 0
|
||||||
|
service.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_subsecond_job_not_delayed_to_one_second(tmp_path):
|
||||||
|
store_path = tmp_path / "cron" / "jobs.json"
|
||||||
|
called = []
|
||||||
|
|
||||||
|
async def on_job(job):
|
||||||
|
called.append(job.name)
|
||||||
|
|
||||||
|
service = CronService(store_path, on_job=on_job, max_sleep_ms=5000)
|
||||||
|
service.add_job(
|
||||||
|
name="fast",
|
||||||
|
schedule=CronSchedule(kind="every", every_ms=100),
|
||||||
|
message="hello",
|
||||||
|
)
|
||||||
|
await service.start()
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(0.35)
|
||||||
|
assert called
|
||||||
|
finally:
|
||||||
|
service.stop()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -204,7 +265,302 @@ async def test_running_service_picks_up_external_add(tmp_path):
|
|||||||
message="ping",
|
message="ping",
|
||||||
)
|
)
|
||||||
|
|
||||||
await asyncio.sleep(0.6)
|
await asyncio.sleep(2)
|
||||||
assert "external" in called
|
assert "external" in called
|
||||||
finally:
|
finally:
|
||||||
service.stop()
|
service.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_job_during_jobs_exec(tmp_path):
|
||||||
|
store_path = tmp_path / "cron" / "jobs.json"
|
||||||
|
run_once = True
|
||||||
|
|
||||||
|
async def on_job(job):
|
||||||
|
nonlocal run_once
|
||||||
|
if run_once:
|
||||||
|
service2 = CronService(store_path, on_job=lambda x: asyncio.sleep(0))
|
||||||
|
service2.add_job(
|
||||||
|
name="test",
|
||||||
|
schedule=CronSchedule(kind="every", every_ms=150),
|
||||||
|
message="tick",
|
||||||
|
)
|
||||||
|
run_once = False
|
||||||
|
|
||||||
|
service = CronService(store_path, on_job=on_job)
|
||||||
|
service.add_job(
|
||||||
|
name="heartbeat",
|
||||||
|
schedule=CronSchedule(kind="every", every_ms=150),
|
||||||
|
message="tick",
|
||||||
|
)
|
||||||
|
assert len(service.list_jobs()) == 1
|
||||||
|
await service.start()
|
||||||
|
try:
|
||||||
|
await asyncio.sleep(3)
|
||||||
|
jobs = service.list_jobs()
|
||||||
|
assert len(jobs) == 2
|
||||||
|
assert "test" in [j.name for j in jobs]
|
||||||
|
finally:
|
||||||
|
service.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_external_update_preserves_run_history_records(tmp_path):
|
||||||
|
store_path = tmp_path / "cron" / "jobs.json"
|
||||||
|
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
|
||||||
|
job = service.add_job(
|
||||||
|
name="history",
|
||||||
|
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||||
|
message="hello",
|
||||||
|
)
|
||||||
|
await service.run_job(job.id, force=True)
|
||||||
|
|
||||||
|
external = CronService(store_path)
|
||||||
|
updated = external.enable_job(job.id, enabled=False)
|
||||||
|
assert updated is not None
|
||||||
|
|
||||||
|
fresh = CronService(store_path)
|
||||||
|
loaded = fresh.get_job(job.id)
|
||||||
|
assert loaded is not None
|
||||||
|
assert loaded.state.run_history
|
||||||
|
assert loaded.state.run_history[0].status == "ok"
|
||||||
|
|
||||||
|
fresh._running = True
|
||||||
|
fresh._save_store()
|
||||||
|
|
||||||
|
|
||||||
|
# ── timer race regression tests ──
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_timer_execution_is_not_rolled_back_by_list_jobs_reload(tmp_path):
|
||||||
|
"""list_jobs() during _on_timer should not replace the active store and re-run the same due job."""
|
||||||
|
store_path = tmp_path / "cron" / "jobs.json"
|
||||||
|
calls: list[str] = []
|
||||||
|
|
||||||
|
async def on_job(job):
|
||||||
|
calls.append(job.id)
|
||||||
|
# Simulate frontend polling list_jobs while the timer callback is mid-execution.
|
||||||
|
service.list_jobs(include_disabled=True)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
service = CronService(store_path, on_job=on_job)
|
||||||
|
service._running = True
|
||||||
|
service._load_store()
|
||||||
|
service._arm_timer = lambda: None
|
||||||
|
|
||||||
|
job = service.add_job(
|
||||||
|
name="race",
|
||||||
|
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||||
|
message="hello",
|
||||||
|
)
|
||||||
|
job.state.next_run_at_ms = max(1, int(time.time() * 1000) - 1_000)
|
||||||
|
service._save_store()
|
||||||
|
|
||||||
|
await service._on_timer()
|
||||||
|
await service._on_timer()
|
||||||
|
|
||||||
|
assert calls == [job.id]
|
||||||
|
loaded = service.get_job(job.id)
|
||||||
|
assert loaded is not None
|
||||||
|
assert loaded.state.last_run_at_ms is not None
|
||||||
|
assert loaded.state.next_run_at_ms is not None
|
||||||
|
assert loaded.state.next_run_at_ms > loaded.state.last_run_at_ms
|
||||||
|
|
||||||
|
|
||||||
|
# ── update_job tests ──
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_job_changes_name(tmp_path) -> None:
|
||||||
|
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||||
|
job = service.add_job(
|
||||||
|
name="old name",
|
||||||
|
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||||
|
message="hello",
|
||||||
|
)
|
||||||
|
result = service.update_job(job.id, name="new name")
|
||||||
|
assert isinstance(result, CronJob)
|
||||||
|
assert result.name == "new name"
|
||||||
|
assert result.payload.message == "hello"
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_job_changes_schedule(tmp_path) -> None:
|
||||||
|
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||||
|
job = service.add_job(
|
||||||
|
name="sched",
|
||||||
|
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||||
|
message="hello",
|
||||||
|
)
|
||||||
|
old_next = job.state.next_run_at_ms
|
||||||
|
|
||||||
|
new_sched = CronSchedule(kind="every", every_ms=120_000)
|
||||||
|
result = service.update_job(job.id, schedule=new_sched)
|
||||||
|
assert isinstance(result, CronJob)
|
||||||
|
assert result.schedule.every_ms == 120_000
|
||||||
|
assert result.state.next_run_at_ms != old_next
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_job_changes_message(tmp_path) -> None:
|
||||||
|
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||||
|
job = service.add_job(
|
||||||
|
name="msg",
|
||||||
|
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||||
|
message="old message",
|
||||||
|
)
|
||||||
|
result = service.update_job(job.id, message="new message")
|
||||||
|
assert isinstance(result, CronJob)
|
||||||
|
assert result.payload.message == "new message"
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_job_changes_cron_expression(tmp_path) -> None:
|
||||||
|
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||||
|
job = service.add_job(
|
||||||
|
name="cron-job",
|
||||||
|
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"),
|
||||||
|
message="hello",
|
||||||
|
)
|
||||||
|
result = service.update_job(
|
||||||
|
job.id,
|
||||||
|
schedule=CronSchedule(kind="cron", expr="0 18 * * *", tz="UTC"),
|
||||||
|
)
|
||||||
|
assert isinstance(result, CronJob)
|
||||||
|
assert result.schedule.expr == "0 18 * * *"
|
||||||
|
assert result.state.next_run_at_ms is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_job_not_found(tmp_path) -> None:
|
||||||
|
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||||
|
result = service.update_job("nonexistent", name="x")
|
||||||
|
assert result == "not_found"
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_job_rejects_system_job(tmp_path) -> None:
|
||||||
|
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||||
|
service.register_system_job(CronJob(
|
||||||
|
id="dream",
|
||||||
|
name="dream",
|
||||||
|
schedule=CronSchedule(kind="cron", expr="0 */2 * * *", tz="UTC"),
|
||||||
|
payload=CronPayload(kind="system_event"),
|
||||||
|
))
|
||||||
|
result = service.update_job("dream", name="hacked")
|
||||||
|
assert result == "protected"
|
||||||
|
assert service.get_job("dream").name == "dream"
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_job_validates_schedule(tmp_path) -> None:
|
||||||
|
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||||
|
job = service.add_job(
|
||||||
|
name="validate",
|
||||||
|
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||||
|
message="hello",
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match="unknown timezone"):
|
||||||
|
service.update_job(
|
||||||
|
job.id,
|
||||||
|
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="Bad/Zone"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_job_preserves_run_history(tmp_path) -> None:
|
||||||
|
import asyncio
|
||||||
|
store_path = tmp_path / "cron" / "jobs.json"
|
||||||
|
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
|
||||||
|
job = service.add_job(
|
||||||
|
name="hist",
|
||||||
|
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||||
|
message="hello",
|
||||||
|
)
|
||||||
|
await service.run_job(job.id)
|
||||||
|
|
||||||
|
result = service.update_job(job.id, name="renamed")
|
||||||
|
assert isinstance(result, CronJob)
|
||||||
|
assert len(result.state.run_history) == 1
|
||||||
|
assert result.state.run_history[0].status == "ok"
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_job_offline_writes_action(tmp_path) -> None:
|
||||||
|
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||||
|
job = service.add_job(
|
||||||
|
name="offline",
|
||||||
|
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||||
|
message="hello",
|
||||||
|
)
|
||||||
|
service.update_job(job.id, name="updated-offline")
|
||||||
|
|
||||||
|
action_path = tmp_path / "cron" / "action.jsonl"
|
||||||
|
assert action_path.exists()
|
||||||
|
lines = [l for l in action_path.read_text().strip().split("\n") if l]
|
||||||
|
last = json.loads(lines[-1])
|
||||||
|
assert last["action"] == "update"
|
||||||
|
assert last["params"]["name"] == "updated-offline"
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_job_sentinel_channel_and_to(tmp_path) -> None:
|
||||||
|
"""Passing None clears channel/to; omitting leaves them unchanged."""
|
||||||
|
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||||
|
job = service.add_job(
|
||||||
|
name="sentinel",
|
||||||
|
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||||
|
message="hello",
|
||||||
|
channel="telegram",
|
||||||
|
to="user123",
|
||||||
|
)
|
||||||
|
assert job.payload.channel == "telegram"
|
||||||
|
assert job.payload.to == "user123"
|
||||||
|
|
||||||
|
result = service.update_job(job.id, name="renamed")
|
||||||
|
assert isinstance(result, CronJob)
|
||||||
|
assert result.payload.channel == "telegram"
|
||||||
|
assert result.payload.to == "user123"
|
||||||
|
|
||||||
|
result = service.update_job(job.id, channel=None, to=None)
|
||||||
|
assert isinstance(result, CronJob)
|
||||||
|
assert result.payload.channel is None
|
||||||
|
assert result.payload.to is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_jobs_during_on_job_does_not_cause_stale_reload(tmp_path) -> None:
|
||||||
|
"""Regression: if the bot calls list_jobs (which reloads from disk) during
|
||||||
|
on_job execution, the in-memory next_run_at_ms update must not be lost.
|
||||||
|
Previously this caused an infinite re-trigger loop."""
|
||||||
|
store_path = tmp_path / "cron" / "jobs.json"
|
||||||
|
execution_count = 0
|
||||||
|
|
||||||
|
async def on_job_that_lists(job):
|
||||||
|
nonlocal execution_count
|
||||||
|
execution_count += 1
|
||||||
|
# Simulate the bot calling cron(action=list) mid-execution
|
||||||
|
service.list_jobs()
|
||||||
|
|
||||||
|
service = CronService(store_path, on_job=on_job_that_lists, max_sleep_ms=100)
|
||||||
|
await service.start()
|
||||||
|
|
||||||
|
# Add two jobs scheduled in the past so they're immediately due
|
||||||
|
now_ms = int(time.time() * 1000)
|
||||||
|
for name in ("job-a", "job-b"):
|
||||||
|
service.add_job(
|
||||||
|
name=name,
|
||||||
|
schedule=CronSchedule(kind="every", every_ms=3_600_000),
|
||||||
|
message="test",
|
||||||
|
)
|
||||||
|
# Force next_run to the past so _on_timer picks them up
|
||||||
|
for job in service._store.jobs:
|
||||||
|
job.state.next_run_at_ms = now_ms - 1000
|
||||||
|
service._save_store()
|
||||||
|
service._arm_timer()
|
||||||
|
|
||||||
|
# Let the timer fire once
|
||||||
|
await asyncio.sleep(0.3)
|
||||||
|
service.stop()
|
||||||
|
|
||||||
|
# Each job should have run exactly once, not looped
|
||||||
|
assert execution_count == 2
|
||||||
|
|
||||||
|
# Verify next_run_at_ms was persisted correctly (in the future)
|
||||||
|
raw = json.loads(store_path.read_text())
|
||||||
|
for j in raw["jobs"]:
|
||||||
|
next_run = j["state"]["nextRunAtMs"]
|
||||||
|
assert next_run is not None
|
||||||
|
assert next_run > now_ms, f"Job '{j['name']}' next_run should be in the future"
|
||||||
|
|||||||
@ -2,9 +2,12 @@
|
|||||||
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from nanobot.agent.tools.cron import CronTool
|
from nanobot.agent.tools.cron import CronTool
|
||||||
from nanobot.cron.service import CronService
|
from nanobot.cron.service import CronService
|
||||||
from nanobot.cron.types import CronJob, CronJobState, CronPayload, CronSchedule
|
from nanobot.cron.types import CronJob, CronJobState, CronPayload, CronSchedule
|
||||||
|
from tests.test_openai_api import pytest_plugins
|
||||||
|
|
||||||
|
|
||||||
def _make_tool(tmp_path) -> CronTool:
|
def _make_tool(tmp_path) -> CronTool:
|
||||||
@ -215,8 +218,10 @@ def test_list_at_job_shows_iso_timestamp(tmp_path) -> None:
|
|||||||
assert "Asia/Shanghai" in result
|
assert "Asia/Shanghai" in result
|
||||||
|
|
||||||
|
|
||||||
def test_list_shows_last_run_state(tmp_path) -> None:
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_shows_last_run_state(tmp_path) -> None:
|
||||||
tool = _make_tool(tmp_path)
|
tool = _make_tool(tmp_path)
|
||||||
|
tool._cron._running = True
|
||||||
job = tool._cron.add_job(
|
job = tool._cron.add_job(
|
||||||
name="Stateful job",
|
name="Stateful job",
|
||||||
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"),
|
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"),
|
||||||
@ -232,9 +237,10 @@ def test_list_shows_last_run_state(tmp_path) -> None:
|
|||||||
assert "ok" in result
|
assert "ok" in result
|
||||||
assert "(UTC)" in result
|
assert "(UTC)" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
def test_list_shows_error_message(tmp_path) -> None:
|
async def test_list_shows_error_message(tmp_path) -> None:
|
||||||
tool = _make_tool(tmp_path)
|
tool = _make_tool(tmp_path)
|
||||||
|
tool._cron._running = True
|
||||||
job = tool._cron.add_job(
|
job = tool._cron.add_job(
|
||||||
name="Failed job",
|
name="Failed job",
|
||||||
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"),
|
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"),
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from types import SimpleNamespace
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||||
|
from nanobot.providers.registry import find_by_name
|
||||||
|
|
||||||
|
|
||||||
def test_custom_provider_parse_handles_empty_choices() -> None:
|
def test_custom_provider_parse_handles_empty_choices() -> None:
|
||||||
@ -53,3 +54,20 @@ def test_custom_provider_parse_chunks_accepts_plain_text_chunks() -> None:
|
|||||||
|
|
||||||
assert result.finish_reason == "stop"
|
assert result.finish_reason == "stop"
|
||||||
assert result.content == "hello world"
|
assert result.content == "hello world"
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_provider_502_error_includes_reachability_hint() -> None:
|
||||||
|
spec = find_by_name("ollama")
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
provider = OpenAICompatProvider(api_base="http://localhost:11434/v1", spec=spec)
|
||||||
|
|
||||||
|
result = provider._handle_error(
|
||||||
|
Exception("Error code: 502"),
|
||||||
|
spec=spec,
|
||||||
|
api_base="http://localhost:11434/v1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.finish_reason == "error"
|
||||||
|
assert "local model endpoint" in result.content
|
||||||
|
assert "http://localhost:11434/v1" in result.content
|
||||||
|
assert "proxy/tunnel" in result.content
|
||||||
|
|||||||
197
tests/providers/test_enforce_role_alternation.py
Normal file
197
tests/providers/test_enforce_role_alternation.py
Normal file
@ -0,0 +1,197 @@
|
|||||||
|
"""Tests for LLMProvider._enforce_role_alternation."""
|
||||||
|
|
||||||
|
from nanobot.providers.base import LLMProvider
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnforceRoleAlternation:
|
||||||
|
"""Verify trailing-assistant removal and consecutive same-role merging."""
|
||||||
|
|
||||||
|
def test_empty_messages(self):
|
||||||
|
assert LLMProvider._enforce_role_alternation([]) == []
|
||||||
|
|
||||||
|
def test_no_change_needed(self):
|
||||||
|
msgs = [
|
||||||
|
{"role": "system", "content": "You are helpful."},
|
||||||
|
{"role": "user", "content": "Hi"},
|
||||||
|
{"role": "assistant", "content": "Hello!"},
|
||||||
|
{"role": "user", "content": "Bye"},
|
||||||
|
]
|
||||||
|
result = LLMProvider._enforce_role_alternation(msgs)
|
||||||
|
assert len(result) == 4
|
||||||
|
assert result[-1]["role"] == "user"
|
||||||
|
|
||||||
|
def test_trailing_assistant_removed(self):
|
||||||
|
msgs = [
|
||||||
|
{"role": "user", "content": "Hi"},
|
||||||
|
{"role": "assistant", "content": "Hello!"},
|
||||||
|
]
|
||||||
|
result = LLMProvider._enforce_role_alternation(msgs)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["role"] == "user"
|
||||||
|
|
||||||
|
def test_multiple_trailing_assistants_removed(self):
|
||||||
|
msgs = [
|
||||||
|
{"role": "user", "content": "Hi"},
|
||||||
|
{"role": "assistant", "content": "A"},
|
||||||
|
{"role": "assistant", "content": "B"},
|
||||||
|
]
|
||||||
|
result = LLMProvider._enforce_role_alternation(msgs)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert result[0]["role"] == "user"
|
||||||
|
|
||||||
|
def test_consecutive_user_messages_merged(self):
|
||||||
|
msgs = [
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
{"role": "user", "content": "How are you?"},
|
||||||
|
]
|
||||||
|
result = LLMProvider._enforce_role_alternation(msgs)
|
||||||
|
assert len(result) == 1
|
||||||
|
assert "Hello" in result[0]["content"]
|
||||||
|
assert "How are you?" in result[0]["content"]
|
||||||
|
|
||||||
|
def test_consecutive_assistant_messages_merged(self):
|
||||||
|
msgs = [
|
||||||
|
{"role": "user", "content": "Hi"},
|
||||||
|
{"role": "assistant", "content": "Hello!"},
|
||||||
|
{"role": "assistant", "content": "How can I help?"},
|
||||||
|
{"role": "user", "content": "Thanks"},
|
||||||
|
]
|
||||||
|
result = LLMProvider._enforce_role_alternation(msgs)
|
||||||
|
assert len(result) == 3
|
||||||
|
assert "Hello!" in result[1]["content"]
|
||||||
|
assert "How can I help?" in result[1]["content"]
|
||||||
|
|
||||||
|
def test_system_messages_not_merged(self):
|
||||||
|
msgs = [
|
||||||
|
{"role": "system", "content": "System A"},
|
||||||
|
{"role": "system", "content": "System B"},
|
||||||
|
{"role": "user", "content": "Hi"},
|
||||||
|
]
|
||||||
|
result = LLMProvider._enforce_role_alternation(msgs)
|
||||||
|
assert len(result) == 3
|
||||||
|
assert result[0]["content"] == "System A"
|
||||||
|
assert result[1]["content"] == "System B"
|
||||||
|
|
||||||
|
def test_tool_messages_not_merged(self):
|
||||||
|
msgs = [
|
||||||
|
{"role": "user", "content": "Hi"},
|
||||||
|
{"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]},
|
||||||
|
{"role": "tool", "content": "result1", "tool_call_id": "1"},
|
||||||
|
{"role": "tool", "content": "result2", "tool_call_id": "2"},
|
||||||
|
{"role": "user", "content": "Next"},
|
||||||
|
]
|
||||||
|
result = LLMProvider._enforce_role_alternation(msgs)
|
||||||
|
tool_msgs = [m for m in result if m["role"] == "tool"]
|
||||||
|
assert len(tool_msgs) == 2
|
||||||
|
|
||||||
|
def test_consecutive_assistant_keeps_later_tool_call_message(self):
|
||||||
|
msgs = [
|
||||||
|
{"role": "user", "content": "Hi"},
|
||||||
|
{"role": "assistant", "content": "Previous reply"},
|
||||||
|
{"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]},
|
||||||
|
{"role": "tool", "content": "result1", "tool_call_id": "1"},
|
||||||
|
{"role": "user", "content": "Next"},
|
||||||
|
]
|
||||||
|
result = LLMProvider._enforce_role_alternation(msgs)
|
||||||
|
assert result[1]["role"] == "assistant"
|
||||||
|
assert result[1]["tool_calls"] == [{"id": "1"}]
|
||||||
|
assert result[1]["content"] is None
|
||||||
|
assert result[2]["role"] == "tool"
|
||||||
|
|
||||||
|
def test_consecutive_assistant_does_not_overwrite_existing_tool_call_message(self):
    """Tool-call assistant followed by a plain assistant: the tool call still wins."""
    tool_call_turn = {"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]}
    conversation = [
        {"role": "user", "content": "Hi"},
        tool_call_turn,
        {"role": "assistant", "content": "Later plain assistant"},
        {"role": "tool", "content": "result1", "tool_call_id": "1"},
        {"role": "user", "content": "Next"},
    ]

    normalized = LLMProvider._enforce_role_alternation(conversation)

    surviving = normalized[1]
    assert surviving["role"] == "assistant"
    assert surviving["tool_calls"] == [{"id": "1"}]
    assert surviving["content"] is None
    # The later plain-text assistant must not sit between the call and its result.
    assert normalized[2]["role"] == "tool"
|
||||||
|
|
||||||
|
def test_non_string_content_uses_latest(self):
    """When one of two same-role messages has list content, the later message is kept."""
    block_message = {"role": "user", "content": [{"type": "text", "text": "A"}]}
    text_message = {"role": "user", "content": "B"}

    normalized = LLMProvider._enforce_role_alternation([block_message, text_message])

    assert len(normalized) == 1
    assert normalized[0]["content"] == "B"
|
||||||
|
|
||||||
|
def test_original_messages_not_mutated(self):
    """The caller's list and every message dict in it must be left untouched.

    The input contains two consecutive user messages. The whole list is
    compared against a deep copy taken beforehand, so a change to either
    message (not just the first) or to the list length fails the test.
    """
    import copy  # local import: only this block needs it

    msgs = [
        {"role": "user", "content": "Hello"},
        {"role": "user", "content": "World"},
    ]
    snapshot = copy.deepcopy(msgs)

    LLMProvider._enforce_role_alternation(msgs)

    assert len(msgs) == 2
    assert msgs[0] == snapshot[0]
    # Stronger than the first-element check: nothing anywhere in the input changed.
    assert msgs == snapshot
|
||||||
|
|
||||||
|
def test_trailing_assistant_recovered_as_user_when_only_system_remains(self):
    """A lone trailing assistant message is re-labelled as user instead of being dropped.

    Scenario: build_messages(current_role="assistant") produces
    [system, assistant] when a subagent result is injected. Dropping the
    trailing assistant would leave only [system], which most providers
    reject (e.g. Zhipu/GLM error 1214), so the content must come back
    under the "user" role.
    """
    conversation = [
        {"role": "system", "content": "You are helpful."},
        {"role": "assistant", "content": "Subagent completed successfully."},
    ]

    normalized = LLMProvider._enforce_role_alternation(conversation)

    assert len(normalized) == 2
    assert normalized[0]["role"] == "system"
    recovered = normalized[1]
    assert recovered["role"] == "user"
    assert "Subagent completed successfully." in recovered["content"]
|
||||||
|
|
||||||
|
def test_trailing_assistant_not_recovered_when_user_message_present(self):
    """With a user message already present, the trailing assistant is simply dropped."""
    conversation = [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
    ]

    normalized = LLMProvider._enforce_role_alternation(conversation)

    assert len(normalized) == 2
    last_entry = normalized[-1]
    assert last_entry["role"] == "user"
|
||||||
|
|
||||||
|
def test_trailing_assistant_recovered_with_tool_result_preceding(self):
    """No recovery for [system, tool, assistant]: the tool message already counts
    as valid non-system content, so the trailing assistant is dropped."""
    conversation = [
        {"role": "system", "content": "You are helpful."},
        {"role": "tool", "content": "result", "tool_call_id": "1"},
        {"role": "assistant", "content": "Done."},
    ]

    normalized = LLMProvider._enforce_role_alternation(conversation)

    assert len(normalized) == 2
    last_entry = normalized[-1]
    assert last_entry["role"] == "tool"
|
||||||
|
|
||||||
|
def test_only_assistant_messages(self):
    """A conversation made only of assistant messages normalizes to an empty list."""
    assistant_only = [{"role": "assistant", "content": text} for text in ("A", "B")]

    normalized = LLMProvider._enforce_role_alternation(assistant_only)

    assert normalized == []
|
||||||
|
|
||||||
|
def test_realistic_conversation(self):
    """End-to-end case: consecutive user turns merge and the trailing assistant is removed."""
    conversation = [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "What is 2+2?"},
        {"role": "assistant", "content": "4"},
        {"role": "user", "content": "And 3+3?"},
        {"role": "user", "content": "(please be quick)"},
        {"role": "assistant", "content": "6"},
    ]

    normalized = LLMProvider._enforce_role_alternation(conversation)

    assert len(normalized) == 4
    assert normalized[2]["role"] == "assistant"
    merged_user = normalized[3]
    assert merged_user["role"] == "user"
    # Both original user texts must be present in the merged message.
    for fragment in ("And 3+3?", "(please be quick)"):
        assert fragment in merged_user["content"]
|
||||||
@ -10,7 +10,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -54,6 +54,57 @@ def _fake_tool_call_response() -> SimpleNamespace:
|
|||||||
return SimpleNamespace(choices=[choice], usage=usage)
|
return SimpleNamespace(choices=[choice], usage=usage)
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_responses_response(content: str = "ok") -> MagicMock:
|
||||||
|
"""Build a minimal Responses API response object."""
|
||||||
|
resp = MagicMock()
|
||||||
|
resp.model_dump.return_value = {
|
||||||
|
"output": [{
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [{"type": "output_text", "text": content}],
|
||||||
|
}],
|
||||||
|
"status": "completed",
|
||||||
|
"usage": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_responses_stream(text: str = "ok"):
|
||||||
|
async def _stream():
|
||||||
|
yield SimpleNamespace(type="response.output_text.delta", delta=text)
|
||||||
|
yield SimpleNamespace(
|
||||||
|
type="response.completed",
|
||||||
|
response=SimpleNamespace(
|
||||||
|
status="completed",
|
||||||
|
usage=SimpleNamespace(input_tokens=10, output_tokens=5, total_tokens=15),
|
||||||
|
output=[],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return _stream()
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_chat_stream(text: str = "ok"):
|
||||||
|
async def _stream():
|
||||||
|
yield SimpleNamespace(
|
||||||
|
choices=[SimpleNamespace(finish_reason=None, delta=SimpleNamespace(content=text, reasoning_content=None, tool_calls=None))],
|
||||||
|
usage=None,
|
||||||
|
)
|
||||||
|
yield SimpleNamespace(
|
||||||
|
choices=[SimpleNamespace(finish_reason="stop", delta=SimpleNamespace(content=None, reasoning_content=None, tool_calls=None))],
|
||||||
|
usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15),
|
||||||
|
)
|
||||||
|
|
||||||
|
return _stream()
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeResponsesError(Exception):
|
||||||
|
def __init__(self, status_code: int, text: str):
|
||||||
|
super().__init__(text)
|
||||||
|
self.status_code = status_code
|
||||||
|
self.response = SimpleNamespace(status_code=status_code, text=text, headers={})
|
||||||
|
|
||||||
|
|
||||||
class _StalledStream:
|
class _StalledStream:
|
||||||
def __aiter__(self):
|
def __aiter__(self):
|
||||||
return self
|
return self
|
||||||
@ -226,6 +277,224 @@ def test_openai_model_passthrough() -> None:
|
|||||||
assert provider.get_default_model() == "gpt-4o"
|
assert provider.get_default_model() == "gpt-4o"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_direct_openai_gpt5_uses_responses_api() -> None:
    """With the "openai" spec and a gpt-5 model, chat() calls responses.create only."""
    chat_create = AsyncMock(return_value=_fake_chat_response())
    responses_create = AsyncMock(return_value=_fake_responses_response("from responses"))
    openai_spec = find_by_name("openai")
    model_name = "gpt-5-chat"

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as client_cls:
        fake_client = client_cls.return_value
        fake_client.chat.completions.create = chat_create
        fake_client.responses.create = responses_create

        provider = OpenAICompatProvider(
            api_key="sk-test-key",
            default_model=model_name,
            spec=openai_spec,
        )
        reply = await provider.chat(
            messages=[{"role": "user", "content": "hello"}],
            model=model_name,
        )

    assert reply.content == "from responses"
    responses_create.assert_awaited_once()
    chat_create.assert_not_awaited()

    # The request must use Responses-style keywords, not chat-completions ones.
    sent = responses_create.call_args.kwargs
    assert sent["model"] == "gpt-5-chat"
    assert sent["max_output_tokens"] == 4096
    assert "input" in sent
    assert "messages" not in sent
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_direct_openai_reasoning_prefers_responses_api() -> None:
    """Passing reasoning_effort on the "openai" spec routes even gpt-4o to responses.create."""
    chat_create = AsyncMock(return_value=_fake_chat_response())
    responses_create = AsyncMock(return_value=_fake_responses_response("reasoned"))
    openai_spec = find_by_name("openai")
    model_name = "gpt-4o"

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as client_cls:
        fake_client = client_cls.return_value
        fake_client.chat.completions.create = chat_create
        fake_client.responses.create = responses_create

        provider = OpenAICompatProvider(
            api_key="sk-test-key",
            default_model=model_name,
            spec=openai_spec,
        )
        await provider.chat(
            messages=[{"role": "user", "content": "hello"}],
            model=model_name,
            reasoning_effort="medium",
        )

    responses_create.assert_awaited_once()
    chat_create.assert_not_awaited()

    sent = responses_create.call_args.kwargs
    assert sent["reasoning"] == {"effort": "medium"}
    assert sent["include"] == ["reasoning.encrypted_content"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_direct_openai_gpt4o_stays_on_chat_completions() -> None:
    """Without reasoning options, gpt-4o on the "openai" spec uses chat.completions.create."""
    chat_create = AsyncMock(return_value=_fake_chat_response())
    responses_create = AsyncMock(return_value=_fake_responses_response())
    openai_spec = find_by_name("openai")
    model_name = "gpt-4o"

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as client_cls:
        fake_client = client_cls.return_value
        fake_client.chat.completions.create = chat_create
        fake_client.responses.create = responses_create

        provider = OpenAICompatProvider(
            api_key="sk-test-key",
            default_model=model_name,
            spec=openai_spec,
        )
        await provider.chat(
            messages=[{"role": "user", "content": "hello"}],
            model=model_name,
        )

    chat_create.assert_awaited_once()
    responses_create.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_openrouter_gpt5_stays_on_chat_completions() -> None:
    """A gpt-5 model served through the "openrouter" spec still uses chat.completions.create."""
    chat_create = AsyncMock(return_value=_fake_chat_response())
    responses_create = AsyncMock(return_value=_fake_responses_response())
    openrouter_spec = find_by_name("openrouter")
    model_name = "openai/gpt-5"

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as client_cls:
        fake_client = client_cls.return_value
        fake_client.chat.completions.create = chat_create
        fake_client.responses.create = responses_create

        provider = OpenAICompatProvider(
            api_key="sk-or-test-key",
            api_base="https://openrouter.ai/api/v1",
            default_model=model_name,
            spec=openrouter_spec,
        )
        await provider.chat(
            messages=[{"role": "user", "content": "hello"}],
            model=model_name,
        )

    chat_create.assert_awaited_once()
    responses_create.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_direct_openai_streaming_gpt5_uses_responses_api() -> None:
    """chat_stream() with a gpt-5 model on the "openai" spec consumes the Responses stream."""
    # The chat-completions mock would hand back a stalled stream; it must never be awaited.
    chat_create = AsyncMock(return_value=_StalledStream())
    responses_create = AsyncMock(return_value=_fake_responses_stream("hi"))
    openai_spec = find_by_name("openai")
    model_name = "gpt-5-chat"

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as client_cls:
        fake_client = client_cls.return_value
        fake_client.chat.completions.create = chat_create
        fake_client.responses.create = responses_create

        provider = OpenAICompatProvider(
            api_key="sk-test-key",
            default_model=model_name,
            spec=openai_spec,
        )
        reply = await provider.chat_stream(
            messages=[{"role": "user", "content": "hello"}],
            model=model_name,
        )

    assert reply.content == "hi"
    assert reply.finish_reason == "stop"
    responses_create.assert_awaited_once()
    chat_create.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_direct_openai_responses_404_falls_back_to_chat_completions() -> None:
    """A 404 from responses.create makes chat() retry once through chat.completions.create."""
    chat_create = AsyncMock(return_value=_fake_chat_response("from chat"))
    not_found = _FakeResponsesError(404, "Responses endpoint not supported")
    responses_create = AsyncMock(side_effect=not_found)
    openai_spec = find_by_name("openai")
    model_name = "gpt-5-chat"

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as client_cls:
        fake_client = client_cls.return_value
        fake_client.chat.completions.create = chat_create
        fake_client.responses.create = responses_create

        provider = OpenAICompatProvider(
            api_key="sk-test-key",
            default_model=model_name,
            spec=openai_spec,
        )
        reply = await provider.chat(
            messages=[{"role": "user", "content": "hello"}],
            model=model_name,
        )

    assert reply.content == "from chat"
    # Each endpoint is hit exactly once: the failed attempt, then the fallback.
    responses_create.assert_awaited_once()
    chat_create.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_direct_openai_stream_responses_unsupported_param_falls_back() -> None:
    """A 400 "unknown parameter" error from the Responses API makes chat_stream() fall back."""
    chat_create = AsyncMock(return_value=_fake_chat_stream("fallback stream"))
    unsupported_param = _FakeResponsesError(
        400, "Unknown parameter: max_output_tokens for Responses API"
    )
    responses_create = AsyncMock(side_effect=unsupported_param)
    openai_spec = find_by_name("openai")
    model_name = "gpt-5-chat"

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as client_cls:
        fake_client = client_cls.return_value
        fake_client.chat.completions.create = chat_create
        fake_client.responses.create = responses_create

        provider = OpenAICompatProvider(
            api_key="sk-test-key",
            default_model=model_name,
            spec=openai_spec,
        )
        reply = await provider.chat_stream(
            messages=[{"role": "user", "content": "hello"}],
            model=model_name,
        )

    assert reply.content == "fallback stream"
    responses_create.assert_awaited_once()
    chat_create.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_direct_openai_responses_rate_limit_does_not_fallback() -> None:
    """A 429 from responses.create is reported as an error; chat completions is not tried."""
    chat_create = AsyncMock(return_value=_fake_chat_response("from chat"))
    rate_limited = _FakeResponsesError(429, "rate limit")
    responses_create = AsyncMock(side_effect=rate_limited)
    openai_spec = find_by_name("openai")
    model_name = "gpt-5-chat"

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as client_cls:
        fake_client = client_cls.return_value
        fake_client.chat.completions.create = chat_create
        fake_client.responses.create = responses_create

        provider = OpenAICompatProvider(
            api_key="sk-test-key",
            default_model=model_name,
            spec=openai_spec,
        )
        reply = await provider.chat(
            messages=[{"role": "user", "content": "hello"}],
            model=model_name,
        )

    assert reply.finish_reason == "error"
    responses_create.assert_awaited_once()
    chat_create.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
def test_openai_compat_supports_temperature_matches_reasoning_model_rules() -> None:
|
def test_openai_compat_supports_temperature_matches_reasoning_model_rules() -> None:
|
||||||
assert OpenAICompatProvider._supports_temperature("gpt-4o") is True
|
assert OpenAICompatProvider._supports_temperature("gpt-4o") is True
|
||||||
assert OpenAICompatProvider._supports_temperature("gpt-5-chat") is False
|
assert OpenAICompatProvider._supports_temperature("gpt-5-chat") is False
|
||||||
@ -263,6 +532,7 @@ def test_openai_compat_preserves_message_level_reasoning_fields() -> None:
|
|||||||
provider = OpenAICompatProvider()
|
provider = OpenAICompatProvider()
|
||||||
|
|
||||||
sanitized = provider._sanitize_messages([
|
sanitized = provider._sanitize_messages([
|
||||||
|
{"role": "user", "content": "hi"},
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": "done",
|
"content": "done",
|
||||||
@ -276,12 +546,42 @@ def test_openai_compat_preserves_message_level_reasoning_fields() -> None:
|
|||||||
"extra_content": {"google": {"thought_signature": "sig"}},
|
"extra_content": {"google": {"thought_signature": "sig"}},
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
}
|
},
|
||||||
|
{"role": "user", "content": "thanks"},
|
||||||
])
|
])
|
||||||
|
|
||||||
assert sanitized[0]["reasoning_content"] == "hidden"
|
assert sanitized[1]["content"] is None
|
||||||
assert sanitized[0]["extra_content"] == {"debug": True}
|
assert sanitized[1]["reasoning_content"] == "hidden"
|
||||||
assert sanitized[0]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
|
assert sanitized[1]["extra_content"] == {"debug": True}
|
||||||
|
assert sanitized[1]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -> None:
    """Two assistant turns in a row: the one carrying tool_calls must survive sanitizing.

    The test expects the surviving assistant to have ``content`` None and the
    id ``3ec83c30d`` on both the tool call and the matching tool result.
    """
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        provider = OpenAICompatProvider()

    original_call_id = "call_function_akxp3wqzn7ph_1"
    tool_call_turn = {
        "role": "assistant",
        "content": "<think>我再查一下</think>",
        "tool_calls": [
            {
                "id": original_call_id,
                "type": "function",
                "function": {"name": "exec", "arguments": "{}"},
            }
        ],
    }
    history = [
        {"role": "user", "content": "不错"},
        {"role": "assistant", "content": "对,破 4 万指日可待"},
        tool_call_turn,
        {"role": "tool", "tool_call_id": original_call_id, "name": "exec", "content": "ok"},
        {"role": "user", "content": "多少star了呢"},
    ]

    sanitized = provider._sanitize_messages(history)

    surviving = sanitized[1]
    assert surviving["role"] == "assistant"
    assert surviving["content"] is None
    assert surviving["tool_calls"][0]["id"] == "3ec83c30d"
    # The tool result must point at the same rewritten id as the call.
    assert sanitized[2]["tool_call_id"] == "3ec83c30d"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import copy
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -152,7 +153,7 @@ async def test_non_transient_error_with_images_retries_without_images() -> None:
|
|||||||
LLMResponse(content="ok, no image"),
|
LLMResponse(content="ok, no image"),
|
||||||
])
|
])
|
||||||
|
|
||||||
response = await provider.chat_with_retry(messages=_IMAGE_MSG)
|
response = await provider.chat_with_retry(messages=copy.deepcopy(_IMAGE_MSG))
|
||||||
|
|
||||||
assert response.content == "ok, no image"
|
assert response.content == "ok, no image"
|
||||||
assert provider.calls == 2
|
assert provider.calls == 2
|
||||||
@ -164,6 +165,24 @@ async def test_non_transient_error_with_images_retries_without_images() -> None:
|
|||||||
assert any("[image: /media/test.png]" in (b.get("text") or "") for b in content)
|
assert any("[image: /media/test.png]" in (b.get("text") or "") for b in content)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_successful_image_retry_mutates_original_messages_in_place() -> None:
    """Successful no-image retry should update the caller's message history."""
    scripted_replies = [
        LLMResponse(content="model does not support images", finish_reason="error"),
        LLMResponse(content="ok, no image"),
    ]
    provider = ScriptedProvider(scripted_replies)
    # Work on a private copy so the module-level fixture message stays pristine.
    messages = copy.deepcopy(_IMAGE_MSG)

    response = await provider.chat_with_retry(messages=messages)

    assert response.content == "ok, no image"
    blocks = messages[0]["content"]
    assert isinstance(blocks, list)
    # The image block is gone from the caller's list and a text placeholder took its place.
    assert not any(block.get("type") == "image_url" for block in blocks)
    assert any("[image: /media/test.png]" in (block.get("text") or "") for block in blocks)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_non_transient_error_without_images_no_retry() -> None:
|
async def test_non_transient_error_without_images_no_retry() -> None:
|
||||||
"""Non-transient errors without image content are returned immediately."""
|
"""Non-transient errors without image content are returned immediately."""
|
||||||
@ -187,7 +206,7 @@ async def test_image_fallback_returns_error_on_second_failure() -> None:
|
|||||||
LLMResponse(content="still failing", finish_reason="error"),
|
LLMResponse(content="still failing", finish_reason="error"),
|
||||||
])
|
])
|
||||||
|
|
||||||
response = await provider.chat_with_retry(messages=_IMAGE_MSG)
|
response = await provider.chat_with_retry(messages=copy.deepcopy(_IMAGE_MSG))
|
||||||
|
|
||||||
assert provider.calls == 2
|
assert provider.calls == 2
|
||||||
assert response.content == "still failing"
|
assert response.content == "still failing"
|
||||||
@ -202,7 +221,7 @@ async def test_image_fallback_without_meta_uses_default_placeholder() -> None:
|
|||||||
LLMResponse(content="ok"),
|
LLMResponse(content="ok"),
|
||||||
])
|
])
|
||||||
|
|
||||||
response = await provider.chat_with_retry(messages=_IMAGE_MSG_NO_META)
|
response = await provider.chat_with_retry(messages=copy.deepcopy(_IMAGE_MSG_NO_META))
|
||||||
|
|
||||||
assert response.content == "ok"
|
assert response.content == "ok"
|
||||||
assert provider.calls == 2
|
assert provider.calls == 2
|
||||||
|
|||||||
@ -10,8 +10,7 @@ import pytest
|
|||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
|
||||||
from nanobot.api.server import (
|
from nanobot.api.server import (
|
||||||
API_CHAT_ID,
|
_FileSizeExceeded,
|
||||||
API_SESSION_KEY,
|
|
||||||
_parse_json_content,
|
_parse_json_content,
|
||||||
_save_base64_data_url,
|
_save_base64_data_url,
|
||||||
create_app,
|
create_app,
|
||||||
@ -91,6 +90,15 @@ def test_save_base64_data_url_handles_unknown_mime(tmp_path) -> None:
|
|||||||
assert result.endswith(".bin")
|
assert result.endswith(".bin")
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_base64_data_url_rejects_oversized_payload(tmp_path) -> None:
    """Base64 uploads should respect the same per-file limit as multipart."""
    # 11 MiB of raw bytes, above the 10MB limit named in the expected error.
    raw_bytes = b"x" * (11 * 1024 * 1024)
    encoded = base64.b64encode(raw_bytes).decode()
    data_url = "data:image/png;base64," + encoded

    with pytest.raises(_FileSizeExceeded, match="10MB limit"):
        _save_base64_data_url(data_url, tmp_path)
|
||||||
|
|
||||||
|
|
||||||
def test_parse_json_content_extracts_text_and_media(tmp_path) -> None:
|
def test_parse_json_content_extracts_text_and_media(tmp_path) -> None:
|
||||||
"""Parse JSON with text + base64 image saves image and returns paths."""
|
"""Parse JSON with text + base64 image saves image and returns paths."""
|
||||||
b64_data = base64.b64encode(b"img").decode()
|
b64_data = base64.b64encode(b"img").decode()
|
||||||
@ -144,6 +152,31 @@ def test_parse_json_content_validates_user_role() -> None:
|
|||||||
_parse_json_content(body)
|
_parse_json_content(body)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_json_content_rejects_oversized_base64_file(tmp_path) -> None:
    """Oversized JSON data URLs should fail before writing to disk."""
    import os

    encoded = base64.b64encode(b"x" * (11 * 1024 * 1024)).decode()
    image_part = {
        "type": "image_url",
        "image_url": {"url": f"data:image/png;base64,{encoded}"},
    }
    user_message = {
        "role": "user",
        "content": [{"type": "text", "text": "describe"}, image_part],
    }
    body = {"messages": [user_message]}

    # Run from the temp dir and always restore the previous working directory.
    previous_cwd = os.getcwd()
    os.chdir(tmp_path)
    try:
        with pytest.raises(_FileSizeExceeded, match="10MB limit"):
            _parse_json_content(body)
    finally:
        os.chdir(previous_cwd)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Multipart upload tests
|
# Multipart upload tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from nanobot.agent.context import ContextBuilder
|
from nanobot.agent.context import ContextBuilder
|
||||||
@ -64,3 +63,19 @@ def test_build_user_content_mixed_image_and_document(tmp_path: Path) -> None:
|
|||||||
assert any(b["type"] == "image_url" for b in result)
|
assert any(b["type"] == "image_url" for b in result)
|
||||||
text_parts = [b.get("text", "") for b in result if b.get("type") == "text"]
|
text_parts = [b.get("text", "") for b in result if b.get("type") == "text"]
|
||||||
assert any("Report text here" in t for t in text_parts)
|
assert any("Report text here" in t for t in text_parts)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_user_content_skips_document_extraction_errors(tmp_path: Path, monkeypatch) -> None:
    """Document extraction errors should not be embedded into the user prompt."""
    broken_doc = tmp_path / "broken.docx"
    broken_doc.write_text("not a real docx", encoding="utf-8")

    builder = _make_builder(tmp_path)

    def _failing_extract(_path):
        # Stand-in for the extractor: always reports an error string.
        return "[error: failed to extract DOCX: boom]"

    monkeypatch.setattr("nanobot.utils.document.extract_text", _failing_extract)

    user_content = builder._build_user_content("summarize this", [str(broken_doc)])

    # Only the user's own text is expected back, with no error marker.
    assert user_content == "summarize this"
|
||||||
|
|||||||
@ -1,10 +1,7 @@
|
|||||||
"""Tests for document text extraction utilities."""
|
"""Tests for document text extraction utilities."""
|
||||||
|
|
||||||
import io
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from nanobot.utils.document import (
|
from nanobot.utils.document import (
|
||||||
SUPPORTED_EXTENSIONS,
|
SUPPORTED_EXTENSIONS,
|
||||||
_is_text_extension,
|
_is_text_extension,
|
||||||
|
|||||||
41
tests/test_package_version.py
Normal file
41
tests/test_package_version.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import textwrap
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import tomllib
|
||||||
|
|
||||||
|
|
||||||
|
def test_source_checkout_import_uses_pyproject_version_without_metadata() -> None:
    """Importing ``nanobot`` from a source checkout reports the pyproject.toml version.

    A child interpreter is started with ``-S`` and the repository root is put
    first on ``sys.path``. ``nanobot.nanobot`` is replaced with a stub module
    before the import, and the printed ``__version__`` is compared with
    ``[project].version`` from pyproject.toml.
    """
    repo_root = Path(__file__).resolve().parents[1]
    pyproject_text = (repo_root / "pyproject.toml").read_text(encoding="utf-8")
    expected = tomllib.loads(pyproject_text)["project"]["version"]

    script = textwrap.dedent(
        f"""
        import sys
        import types

        sys.path.insert(0, {str(repo_root)!r})
        fake = types.ModuleType("nanobot.nanobot")
        fake.Nanobot = object
        fake.RunResult = object
        sys.modules["nanobot.nanobot"] = fake

        import nanobot

        print(nanobot.__version__)
        """
    )

    command = [sys.executable, "-S", "-c", script]
    completed = subprocess.run(command, capture_output=True, text=True, check=False)

    # Surface the child's stderr when the import itself fails.
    assert completed.returncode == 0, completed.stderr
    assert completed.stdout.strip() == expected
|
||||||
31
tests/test_truncate_text_shadowing.py
Normal file
31
tests/test_truncate_text_shadowing.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import inspect
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_persisted_blocks_truncate_text_shadowing_regression() -> None:
    """Regression: a bool parameter must not shadow the imported ``truncate_text`` helper.

    Historical bug:
    - loop.py imports `truncate_text` from helpers
    - `_sanitize_persisted_blocks(..., truncate_text: bool=...)` reused that name
    - calling with `truncate_text=True` made the body run `truncate_text(text, ...)`
      on the bool, raising `TypeError: 'bool' object is not callable`.

    The test checks that the renamed parameter exists and that truncation
    runs without raising.
    """
    from nanobot.agent.loop import AgentLoop

    parameters = inspect.signature(AgentLoop._sanitize_persisted_blocks).parameters
    assert "should_truncate_text" in parameters
    assert "truncate_text" not in parameters

    # Stand-in for ``self``: only ``max_tool_result_chars`` is supplied.
    stand_in = SimpleNamespace(max_tool_result_chars=5)
    content = [{"type": "text", "text": "0123456789"}]

    out = AgentLoop._sanitize_persisted_blocks(stand_in, content, should_truncate_text=True)

    assert isinstance(out, list)
    assert out and out[0]["type"] == "text"
    first_text = out[0]["text"]
    assert isinstance(first_text, str)
    # The 10-character input must have been altered by truncation.
    assert first_text != content[0]["text"]
|
||||||
|
|
||||||
423
tests/tools/test_edit_advanced.py
Normal file
423
tests/tools/test_edit_advanced.py
Normal file
@ -0,0 +1,423 @@
|
|||||||
|
"""Tests for advanced EditFileTool enhancements inspired by claude-code:
|
||||||
|
- Delete-line newline cleanup
|
||||||
|
- Smart quote normalization (curly ↔ straight)
|
||||||
|
- Quote style preservation in replacements
|
||||||
|
- Indentation preservation when fallback match is trimmed
|
||||||
|
- Trailing whitespace stripping for new_text
|
||||||
|
- File size protection
|
||||||
|
- Stale detection with content-equality fallback
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, _find_match
|
||||||
|
from nanobot.agent.tools import file_state
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _clear_file_state():
    """Autouse fixture: call ``file_state.clear()`` before and after every test."""
    # Setup: start the test from a cleared state.
    file_state.clear()

    yield

    # Teardown: clear again so nothing carries over to the next test.
    file_state.clear()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Delete-line newline cleanup
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeleteLineCleanup:
    """Deleting text with ``new_text=''`` should not leave an empty line behind."""

    @pytest.fixture()
    def tool(self, tmp_path):
        """Return an EditFileTool whose workspace is the per-test temp directory."""
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_delete_line_consumes_trailing_newline(self, tool, tmp_path):
        """Deleting ``line2`` without its newline still removes the whole line."""
        target = tmp_path / "a.py"
        target.write_text("line1\nline2\nline3\n", encoding="utf-8")

        outcome = await tool.execute(path=str(target), old_text="line2", new_text="")

        assert "Successfully" in outcome
        remaining = target.read_text()
        # Should not leave a blank line where line2 was
        assert remaining == "line1\nline3\n"

    @pytest.mark.asyncio
    async def test_delete_line_with_explicit_newline_in_old_text(self, tool, tmp_path):
        """Including the newline in ``old_text`` yields the same final content."""
        target = tmp_path / "a.py"
        target.write_text("line1\nline2\nline3\n", encoding="utf-8")

        outcome = await tool.execute(path=str(target), old_text="line2\n", new_text="")

        assert "Successfully" in outcome
        assert target.read_text() == "line1\nline3\n"

    @pytest.mark.asyncio
    async def test_delete_preserves_content_when_not_trailing_newline(self, tool, tmp_path):
        """Deleting a word mid-line should not consume extra characters."""
        target = tmp_path / "a.py"
        target.write_text("hello world here\n", encoding="utf-8")

        outcome = await tool.execute(path=str(target), old_text="world ", new_text="")

        assert "Successfully" in outcome
        assert target.read_text() == "hello here\n"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Smart quote normalization
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSmartQuoteNormalization:
    """``_find_match`` should fall back to matching curly and straight quotes."""

    def test_curly_double_quotes_match_straight(self):
        """Straight double quotes in the query match curly ones in the content."""
        haystack = 'She said \u201chello\u201d to him'
        needle = 'She said "hello" to him'

        found, hits = _find_match(haystack, needle)

        assert found is not None
        assert hits == 1
        # Returned match should be the ORIGINAL content with curly quotes
        assert "\u201c" in found

    def test_curly_single_quotes_match_straight(self):
        """A straight apostrophe in the query matches a curly one in the content."""
        haystack = "it\u2019s a test"
        needle = "it's a test"

        found, hits = _find_match(haystack, needle)

        assert found is not None
        assert hits == 1
        assert "\u2019" in found

    def test_straight_matches_curly_in_old_text(self):
        """Curly quotes in the query match straight quotes in the content."""
        haystack = 'x = "hello"'
        needle = 'x = \u201chello\u201d'

        found, hits = _find_match(haystack, needle)

        assert found is not None
        assert hits == 1

    def test_exact_match_still_preferred_over_quote_normalization(self):
        """When the text matches verbatim, the match equals the query itself."""
        haystack = 'x = "hello"'
        needle = 'x = "hello"'

        found, hits = _find_match(haystack, needle)

        assert found == needle
        assert hits == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestQuoteStylePreservation:
    """A quote-normalized match should keep the file's own quote characters."""

    @pytest.fixture()
    def tool(self, tmp_path):
        """Return an EditFileTool whose workspace is the per-test temp directory."""
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_replacement_preserves_curly_double_quotes(self, tool, tmp_path):
        """Straight-quoted old/new text applied to a curly-quoted file stays curly."""
        target = tmp_path / "quotes.txt"
        target.write_text('message = “hello”\n', encoding="utf-8")

        edit_args = {
            "path": str(target),
            "old_text": 'message = "hello"',
            "new_text": 'message = "goodbye"',
        }
        outcome = await tool.execute(**edit_args)

        assert "Successfully" in outcome
        assert target.read_text(encoding="utf-8") == 'message = “goodbye”\n'

    @pytest.mark.asyncio
    async def test_replacement_preserves_curly_apostrophe(self, tool, tmp_path):
        """A curly apostrophe in the file survives a straight-apostrophe edit."""
        target = tmp_path / "apostrophe.txt"
        target.write_text("it’s fine\n", encoding="utf-8")

        edit_args = {
            "path": str(target),
            "old_text": "it's fine",
            "new_text": "it's better",
        }
        outcome = await tool.execute(**edit_args)

        assert "Successfully" in outcome
        assert target.read_text(encoding="utf-8") == "it’s better\n"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Indentation preservation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestIndentationPreservation:
    """Replacement should keep outer indentation when trim fallback matched."""

    @pytest.fixture()
    def tool(self, tmp_path):
        """Return an EditFileTool whose workspace is the per-test temp directory."""
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_trim_fallback_preserves_outer_indentation(self, tool, tmp_path):
        """Replace a nested function and compare the whole file afterwards."""
        # NOTE(review): these literals are whitespace-sensitive; confirm the
        # leading indentation inside them matches the intended fixture.
        before = (
            "if True:\n"
            " def foo():\n"
            " pass\n"
        )
        after = (
            "if True:\n"
            " def bar():\n"
            " return 1\n"
        )
        target = tmp_path / "indent.py"
        target.write_text(before, encoding="utf-8")

        outcome = await tool.execute(
            path=str(target),
            old_text="def foo():\n pass",
            new_text="def bar():\n return 1",
        )

        assert "Successfully" in outcome
        assert target.read_text(encoding="utf-8") == after
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Failure diagnostics
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEditDiagnostics:
    """Failure paths should offer actionable hints."""

    @pytest.fixture()
    def tool(self, tmp_path):
        """Return an EditFileTool whose workspace is the per-test temp directory."""
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_ambiguous_match_reports_candidate_lines(self, tool, tmp_path):
        """Two occurrences of old_text: the message names both lines and replace_all."""
        target = tmp_path / "dup.py"
        target.write_text("aaa\nbbb\naaa\nbbb\n", encoding="utf-8")

        outcome = await tool.execute(path=str(target), old_text="aaa\nbbb", new_text="xxx")

        lowered = outcome.lower()
        assert "appears 2 times" in lowered
        assert "line 1" in lowered
        assert "line 3" in lowered
        # Checked against the original casing, not the lowered copy.
        assert "replace_all=true" in outcome

    @pytest.mark.asyncio
    async def test_not_found_reports_whitespace_hint(self, tool, tmp_path):
        """A whitespace-only mismatch should produce an error mentioning whitespace."""
        # NOTE(review): as written, old_text is an exact substring of the file
        # content, so the edit would be expected to succeed. Presumably the two
        # literals are meant to differ in spacing — verify the fixture.
        target = tmp_path / "space.py"
        target.write_text("value = 1\n", encoding="utf-8")

        outcome = await tool.execute(path=str(target), old_text="value = 1", new_text="value = 2")

        assert "Error" in outcome
        assert "whitespace" in outcome.lower()

    @pytest.mark.asyncio
    async def test_not_found_reports_case_hint(self, tool, tmp_path):
        """A case-only mismatch should produce an error mentioning letter case."""
        target = tmp_path / "case.py"
        target.write_text("HelloWorld\n", encoding="utf-8")

        outcome = await tool.execute(path=str(target), old_text="helloworld", new_text="goodbye")

        assert "Error" in outcome
        assert "letter case differs" in outcome.lower()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Advanced fallback replacement behavior
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdvancedReplaceAll:
    """replace_all should work correctly for fallback-based matches too."""

    @pytest.fixture()
    def tool(self, tmp_path):
        """Return an EditFileTool whose workspace is the per-test temp directory."""
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_replace_all_preserves_each_match_indentation(self, tool, tmp_path):
        """With replace_all=True both ``foo`` definitions are rewritten in place."""
        f = tmp_path / "indent_multi.py"
        f.write_text(
            "if a:\n"
            " def foo():\n"
            " pass\n"
            "if b:\n"
            " def foo():\n"
            " pass\n",
            encoding="utf-8",
        )
        result = await tool.execute(
            path=str(f),
            old_text="def foo():\n pass",
            new_text="def bar():\n return 1",
            replace_all=True,
        )
        assert "Successfully" in result
        # The whole file is compared so any change outside the two matches fails.
        assert f.read_text(encoding="utf-8") == (
            "if a:\n"
            " def bar():\n"
            " return 1\n"
            "if b:\n"
            " def bar():\n"
            " return 1\n"
        )

    @pytest.mark.asyncio
    async def test_trim_and_quote_fallback_match_succeeds(self, tool, tmp_path):
        """An indented, curly-quoted line is edited using unindented straight quotes."""
        f = tmp_path / "quote_indent.py"
        f.write_text(" message = “hello”\n", encoding="utf-8")
        result = await tool.execute(
            path=str(f),
            old_text='message = "hello"',
            new_text='message = "goodbye"',
        )
        assert "Successfully" in result
        # Leading whitespace and the curly quote characters must both survive.
        assert f.read_text(encoding="utf-8") == " message = “goodbye”\n"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Trailing whitespace stripping on new_text
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTrailingWhitespaceStrip:
    """new_text trailing whitespace should be stripped (except .md files)."""

    @pytest.fixture()
    def tool(self, tmp_path):
        """Return an EditFileTool whose workspace is the per-test temp directory."""
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_strips_trailing_whitespace_from_new_text(self, tool, tmp_path):
        """Trailing spaces on each line of new_text are absent from a .py file."""
        target = tmp_path / "a.py"
        target.write_text("x = 1\n", encoding="utf-8")

        replacement = "x = 2 \ny = 3 "
        outcome = await tool.execute(path=str(target), old_text="x = 1", new_text=replacement)

        assert "Successfully" in outcome
        written = target.read_text()
        assert written == "x = 2\ny = 3\n"

    @pytest.mark.asyncio
    async def test_preserves_trailing_whitespace_in_markdown(self, tool, tmp_path):
        """Trailing spaces in new_text are kept when the target is a .md file."""
        target = tmp_path / "doc.md"
        target.write_text("# Title\n", encoding="utf-8")

        # Markdown uses trailing double-space for line breaks
        replacement = "# Title \nSubtitle "
        outcome = await tool.execute(path=str(target), old_text="# Title", new_text=replacement)

        assert "Successfully" in outcome
        written = target.read_text()
        # Trailing spaces should be preserved for markdown
        assert "Title " in written
        assert "Subtitle " in written
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# File size protection
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileSizeProtection:
    """Editing extremely large files should be rejected."""

    @pytest.fixture()
    def tool(self, tmp_path):
        """Return an EditFileTool whose workspace is the per-test temp directory."""
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_rejects_file_over_size_limit(self, tool, tmp_path):
        """A file whose stat reports 2 GiB must be refused with a size-related error."""
        from unittest import mock

        f = tmp_path / "huge.txt"
        f.write_text("x", encoding="utf-8")

        class FakeStat:
            """Stat result proxy that overrides only ``st_size``."""

            def __init__(self, real_stat):
                self._real = real_stat

            def __getattr__(self, name):
                # Only reached for attributes not defined on FakeStat itself,
                # so everything except st_size comes from the real result.
                return getattr(self._real, name)

            @property
            def st_size(self):
                return 2 * 1024 * 1024 * 1024  # 2 GiB

        # Take the real stat result before patching; once the patch is active
        # every stat() call on this path class returns the fake.
        fake_stat = FakeStat(f.stat())

        # presumably the tool reads the size through Path.stat() — verify
        # against EditFileTool if this test stops exercising the size check.
        with mock.patch.object(type(f), "stat", return_value=fake_stat):
            result = await tool.execute(path=str(f), old_text="x", new_text="y")

        assert "Error" in result
        assert "too large" in result.lower() or "size" in result.lower()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Stale detection with content-equality fallback
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestStaleDetectionContentFallback:
    """When mtime changed but file content is unchanged, edit should proceed without warning."""

    @pytest.fixture()
    def read_tool(self, tmp_path):
        """Return a ReadFileTool whose workspace is the per-test temp directory."""
        return ReadFileTool(workspace=tmp_path)

    @pytest.fixture()
    def edit_tool(self, tmp_path):
        """Return an EditFileTool whose workspace is the per-test temp directory."""
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_mtime_bump_same_content_no_warning(self, read_tool, edit_tool, tmp_path):
        """Read, rewrite identical bytes, then edit: no 'modified' text in the result."""
        target = tmp_path / "a.py"
        target.write_text("hello world", encoding="utf-8")
        await read_tool.execute(path=str(target))

        # Touch the file to bump mtime without changing content
        time.sleep(0.05)
        unchanged = target.read_text()
        target.write_text(unchanged, encoding="utf-8")

        outcome = await edit_tool.execute(path=str(target), old_text="world", new_text="earth")

        assert "Successfully" in outcome
        # Should NOT warn about modification since content is the same
        assert "modified" not in outcome.lower()
|
||||||
152
tests/tools/test_edit_enhancements.py
Normal file
152
tests/tools/test_edit_enhancements.py
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
"""Tests for EditFileTool enhancements: read-before-edit tracking, path suggestions,
|
||||||
|
.ipynb detection, and create-file semantics."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, WriteFileTool
|
||||||
|
from nanobot.agent.tools import file_state
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _clear_file_state():
    """Reset global read-state between tests.

    ``autouse=True`` applies the fixture to every test in this module without
    it being requested explicitly.
    """
    file_state.clear()
    yield
    # Teardown: clear again so nothing recorded by this test reaches the next.
    file_state.clear()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Read-before-edit tracking
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestEditReadTracking:
    """edit_file should warn when file hasn't been read first."""

    @pytest.fixture()
    def read_tool(self, tmp_path):
        """Return a ReadFileTool whose workspace is the per-test temp directory."""
        return ReadFileTool(workspace=tmp_path)

    @pytest.fixture()
    def edit_tool(self, tmp_path):
        """Return an EditFileTool whose workspace is the per-test temp directory."""
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_edit_warns_if_file_not_read_first(self, edit_tool, tmp_path):
        """Editing without a prior read succeeds but carries a warning."""
        target = tmp_path / "a.py"
        target.write_text("hello world", encoding="utf-8")

        outcome = await edit_tool.execute(path=str(target), old_text="world", new_text="earth")

        # Should still succeed but include a warning
        assert "Successfully" in outcome
        lowered = outcome.lower()
        assert "not been read" in lowered or "warning" in lowered

    @pytest.mark.asyncio
    async def test_edit_succeeds_cleanly_after_read(self, read_tool, edit_tool, tmp_path):
        """Reading first removes the 'not been read' warning and the edit lands."""
        target = tmp_path / "a.py"
        target.write_text("hello world", encoding="utf-8")
        await read_tool.execute(path=str(target))

        outcome = await edit_tool.execute(path=str(target), old_text="world", new_text="earth")

        assert "Successfully" in outcome
        # No warning when file was read first
        assert "not been read" not in outcome.lower()
        assert target.read_text() == "hello earth"

    @pytest.mark.asyncio
    async def test_edit_warns_if_file_modified_since_read(self, read_tool, edit_tool, tmp_path):
        """A content change between read and edit is reported in the result."""
        target = tmp_path / "a.py"
        target.write_text("hello world", encoding="utf-8")
        await read_tool.execute(path=str(target))

        # External modification
        target.write_text("hello universe", encoding="utf-8")

        outcome = await edit_tool.execute(path=str(target), old_text="universe", new_text="earth")

        assert "Successfully" in outcome
        lowered = outcome.lower()
        assert "modified" in lowered or "warning" in lowered
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Create-file semantics
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestEditCreateFile:
    """edit_file with old_text='' creates new file if not exists."""

    @pytest.fixture()
    def tool(self, tmp_path):
        """Return an EditFileTool whose workspace is the per-test temp directory."""
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_create_new_file_with_empty_old_text(self, tool, tmp_path):
        """A missing file in a missing subdirectory is created with new_text."""
        target = tmp_path / "subdir" / "new.py"

        outcome = await tool.execute(path=str(target), old_text="", new_text="print('hi')")

        assert "created" in outcome.lower() or "Successfully" in outcome
        assert target.exists()
        assert target.read_text() == "print('hi')"

    @pytest.mark.asyncio
    async def test_create_fails_if_file_already_exists_and_not_empty(self, tool, tmp_path):
        """Empty old_text against a non-empty file is refused and changes nothing."""
        target = tmp_path / "existing.py"
        target.write_text("existing content", encoding="utf-8")

        outcome = await tool.execute(path=str(target), old_text="", new_text="new content")

        assert "Error" in outcome or "already exists" in outcome.lower()
        # File should be unchanged
        assert target.read_text() == "existing content"

    @pytest.mark.asyncio
    async def test_create_succeeds_if_file_exists_but_empty(self, tool, tmp_path):
        """Empty old_text against an existing empty file writes new_text."""
        target = tmp_path / "empty.py"
        target.write_text("", encoding="utf-8")

        outcome = await tool.execute(path=str(target), old_text="", new_text="print('hi')")

        assert "Successfully" in outcome
        assert target.read_text() == "print('hi')"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# .ipynb detection
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestEditIpynbDetection:
    """edit_file should refuse .ipynb and suggest notebook_edit."""

    @pytest.fixture()
    def tool(self, tmp_path):
        """Return an EditFileTool whose workspace is the per-test temp directory."""
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_ipynb_rejected_with_suggestion(self, tool, tmp_path):
        """Editing a .ipynb path returns a message that mentions 'notebook'."""
        notebook = tmp_path / "analysis.ipynb"
        notebook.write_text('{"cells": []}', encoding="utf-8")

        outcome = await tool.execute(path=str(notebook), old_text="x", new_text="y")

        assert "notebook" in outcome.lower()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Path suggestion on not-found
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestEditPathSuggestion:
    """edit_file should suggest similar paths on not-found."""

    @pytest.fixture()
    def tool(self, tmp_path):
        """Return an EditFileTool whose workspace is the per-test temp directory."""
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_suggests_similar_filename(self, tool, tmp_path):
        """A misspelled filename yields an error that names the real file."""
        existing = tmp_path / "config.py"
        existing.write_text("x = 1", encoding="utf-8")

        # Typo: conifg.py
        misspelled = tmp_path / "conifg.py"
        outcome = await tool.execute(path=str(misspelled), old_text="x = 1", new_text="x = 2")

        assert "Error" in outcome
        assert "config.py" in outcome

    @pytest.mark.asyncio
    async def test_shows_cwd_in_error(self, tool, tmp_path):
        """Editing a nonexistent path reports an error."""
        missing = tmp_path / "nonexistent.py"

        outcome = await tool.execute(path=str(missing), old_text="a", new_text="b")

        assert "Error" in outcome
|
||||||
@ -43,3 +43,34 @@ async def test_exec_path_append_preserves_system_path():
|
|||||||
tool = ExecTool(path_append="/opt/custom/bin")
|
tool = ExecTool(path_append="/opt/custom/bin")
|
||||||
result = await tool.execute(command="ls /")
|
result = await tool.execute(command="ls /")
|
||||||
assert "Exit code: 0" in result
|
assert "Exit code: 0" in result
|
||||||
|
|
||||||
|
|
||||||
|
@_UNIX_ONLY
@pytest.mark.asyncio
async def test_exec_allowed_env_keys_passthrough(monkeypatch):
    """Env vars listed in allowed_env_keys should be visible to commands."""
    var_name, var_value = "MY_CUSTOM_VAR", "hello-from-config"
    monkeypatch.setenv(var_name, var_value)

    exec_tool = ExecTool(allowed_env_keys=[var_name])
    output = await exec_tool.execute(command="printenv MY_CUSTOM_VAR")

    assert var_value in output
|
||||||
|
|
||||||
|
|
||||||
|
@_UNIX_ONLY
@pytest.mark.asyncio
async def test_exec_allowed_env_keys_does_not_leak_others(monkeypatch):
    """Env vars NOT in allowed_env_keys should still be blocked."""
    parent_env = {
        "MY_CUSTOM_VAR": "hello-from-config",
        "MY_SECRET_VAR": "secret-value",
    }
    for key, value in parent_env.items():
        monkeypatch.setenv(key, value)

    # Only one of the two variables is allowed through.
    exec_tool = ExecTool(allowed_env_keys=["MY_CUSTOM_VAR"])
    output = await exec_tool.execute(command="printenv MY_SECRET_VAR")

    assert "secret-value" not in output
|
||||||
|
|
||||||
|
|
||||||
|
@_UNIX_ONLY
@pytest.mark.asyncio
async def test_exec_allowed_env_keys_missing_var_ignored(monkeypatch):
    """If an allowed key is not set in the parent process, it should be silently skipped."""
    absent = "NONEXISTENT_VAR_12345"
    # raising=False: it is fine if the variable was never set.
    monkeypatch.delenv(absent, raising=False)

    exec_tool = ExecTool(allowed_env_keys=[absent])
    output = await exec_tool.execute(command="printenv NONEXISTENT_VAR_12345")

    assert "Exit code: 1" in output
|
||||||
|
|||||||
@ -5,12 +5,18 @@ strategy, and sandbox behaviour per platform — without actually running
|
|||||||
platform-specific binaries (all subprocess calls are mocked).
|
platform-specific binaries (all subprocess calls are mocked).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from nanobot.agent.tools.shell import ExecTool
|
from nanobot.agent.tools.shell import ExecTool
|
||||||
|
|
||||||
|
# Additional Windows environment variable names; unpacked into
# TestBuildEnvWindows._EXPECTED_KEYS further down in this module.
_WINDOWS_ENV_KEYS = {
    "APPDATA", "LOCALAPPDATA", "ProgramData",
    "ProgramFiles", "ProgramFiles(x86)", "ProgramW6432",
}
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# _build_env
|
# _build_env
|
||||||
@ -21,7 +27,10 @@ class TestBuildEnvUnix:
|
|||||||
def test_expected_keys(self):
|
def test_expected_keys(self):
|
||||||
with patch("nanobot.agent.tools.shell._IS_WINDOWS", False):
|
with patch("nanobot.agent.tools.shell._IS_WINDOWS", False):
|
||||||
env = ExecTool()._build_env()
|
env = ExecTool()._build_env()
|
||||||
assert set(env) == {"HOME", "LANG", "TERM"}
|
expected = {"HOME", "LANG", "TERM"}
|
||||||
|
assert expected <= set(env)
|
||||||
|
if sys.platform != "win32":
|
||||||
|
assert set(env) == expected
|
||||||
|
|
||||||
def test_home_from_environ(self, monkeypatch):
|
def test_home_from_environ(self, monkeypatch):
|
||||||
monkeypatch.setenv("HOME", "/Users/dev")
|
monkeypatch.setenv("HOME", "/Users/dev")
|
||||||
@ -45,6 +54,7 @@ class TestBuildEnvWindows:
|
|||||||
_EXPECTED_KEYS = {
|
_EXPECTED_KEYS = {
|
||||||
"SYSTEMROOT", "COMSPEC", "USERPROFILE", "HOMEDRIVE",
|
"SYSTEMROOT", "COMSPEC", "USERPROFILE", "HOMEDRIVE",
|
||||||
"HOMEPATH", "TEMP", "TMP", "PATHEXT", "PATH",
|
"HOMEPATH", "TEMP", "TMP", "PATHEXT", "PATH",
|
||||||
|
*_WINDOWS_ENV_KEYS,
|
||||||
}
|
}
|
||||||
|
|
||||||
def test_expected_keys(self):
|
def test_expected_keys(self):
|
||||||
|
|||||||
@ -67,3 +67,118 @@ async def test_exec_blocks_chained_internal_url():
|
|||||||
command="echo start && curl http://169.254.169.254/latest/meta-data/ && echo done"
|
command="echo start && curl http://169.254.169.254/latest/meta-data/ && echo done"
|
||||||
)
|
)
|
||||||
assert "Error" in result
|
assert "Error" in result
|
||||||
|
|
||||||
|
|
||||||
|
# --- #2989: block writes to nanobot internal state files -----------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "command",
    [
        # shell redirection
        "cat foo >> history.jsonl",
        "echo '{}' > history.jsonl",
        "echo '{}' > memory/history.jsonl",
        "echo '{}' > ./workspace/memory/history.jsonl",
        # tee
        "tee -a history.jsonl < foo",
        "tee history.jsonl",
        # copy / move / dd / in-place sed with the state file as destination
        "cp /tmp/fake.jsonl history.jsonl",
        "mv backup.jsonl memory/history.jsonl",
        "dd if=/dev/zero of=memory/history.jsonl",
        "sed -i 's/old/new/' history.jsonl",
        # .dream_cursor
        "echo x > .dream_cursor",
        "cp /tmp/x memory/.dream_cursor",
    ],
)
def test_exec_blocks_writes_to_history_jsonl(command):
    """Direct writes to history.jsonl / .dream_cursor must be blocked (#2989)."""
    verdict = ExecTool()._guard_command(command, "/tmp")

    assert verdict is not None
    assert "dangerous pattern" in verdict.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "command",
    [
        "cat history.jsonl",
        "wc -l history.jsonl",
        "tail -n 5 history.jsonl",
        "grep foo history.jsonl",
        # history.jsonl is the source here, not the destination
        "cp history.jsonl /tmp/history.backup",
        "ls memory/",
        "echo history.jsonl",
    ],
)
def test_exec_allows_reads_of_history_jsonl(command):
    """Read-only access to history.jsonl must still be allowed."""
    verdict = ExecTool()._guard_command(command, "/tmp")

    # The test treats a None return as "not blocked".
    assert verdict is None
|
||||||
|
|
||||||
|
|
||||||
|
# --- #2826: working_dir must not escape the configured workspace ---------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_exec_blocks_working_dir_outside_workspace(tmp_path):
    """An LLM-supplied working_dir outside the workspace must be rejected."""
    sandbox = tmp_path / "workspace"
    sandbox.mkdir()
    exec_tool = ExecTool(working_dir=str(sandbox), restrict_to_workspace=True)

    output = await exec_tool.execute(command="rm calendar.ics", working_dir="/etc")

    assert "outside the configured workspace" in output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_exec_blocks_absolute_rm_via_hijacked_working_dir(tmp_path):
    """Regression for #2826: `rm /abs/path` via working_dir hijack."""
    sandbox = tmp_path / "workspace"
    sandbox.mkdir()

    # A sibling directory of the workspace holding the file the command targets.
    victim_dir = tmp_path / "outside"
    victim_dir.mkdir()
    victim = victim_dir / "file.ics"
    victim.write_text("data")

    exec_tool = ExecTool(working_dir=str(sandbox), restrict_to_workspace=True)
    call_args = {"command": f"rm {victim}", "working_dir": str(victim_dir)}
    output = await exec_tool.execute(**call_args)

    assert "outside the configured workspace" in output
    assert victim.exists(), "victim file must not have been deleted"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_exec_allows_working_dir_within_workspace(tmp_path):
    """A working_dir that is a subdirectory of the workspace is fine."""
    sandbox = tmp_path / "workspace"
    nested = sandbox / "project"
    # parents=True creates the workspace directory as well.
    nested.mkdir(parents=True)
    exec_tool = ExecTool(working_dir=str(sandbox), restrict_to_workspace=True, timeout=5)

    output = await exec_tool.execute(command="echo ok", working_dir=str(nested))

    assert "ok" in output
    assert "outside the configured workspace" not in output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_exec_allows_working_dir_equal_to_workspace(tmp_path):
    """Passing working_dir equal to the workspace root must be allowed."""
    sandbox = tmp_path / "workspace"
    sandbox.mkdir()
    root = str(sandbox)
    exec_tool = ExecTool(working_dir=root, restrict_to_workspace=True, timeout=5)

    output = await exec_tool.execute(command="echo ok", working_dir=root)

    assert "ok" in output
    assert "outside the configured workspace" not in output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_exec_ignores_workspace_check_when_not_restricted(tmp_path):
    """Without restrict_to_workspace, the LLM may still choose any working_dir."""
    workspace = tmp_path / "workspace"
    workspace.mkdir()
    other = tmp_path / "other"
    other.mkdir()
    tool = ExecTool(working_dir=str(workspace), restrict_to_workspace=False, timeout=5)
    result = await tool.execute(command="echo ok", working_dir=str(other))
    assert "ok" in result
    assert "outside the configured workspace" not in result
|
||||||
|
|||||||
@ -271,15 +271,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_raw_names(
|
|||||||
) -> None:
|
) -> None:
|
||||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
stack = AsyncExitStack()
|
stacks = await connect_mcp_servers(
|
||||||
await stack.__aenter__()
|
{"test": MCPServerConfig(command="fake", enabled_tools=["demo"])},
|
||||||
try:
|
registry,
|
||||||
await connect_mcp_servers(
|
)
|
||||||
{"test": MCPServerConfig(command="fake", enabled_tools=["demo"])},
|
for stack in stacks.values():
|
||||||
registry,
|
|
||||||
stack,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
await stack.aclose()
|
await stack.aclose()
|
||||||
|
|
||||||
assert registry.tool_names == ["mcp_test_demo"]
|
assert registry.tool_names == ["mcp_test_demo"]
|
||||||
@ -291,15 +287,11 @@ async def test_connect_mcp_servers_enabled_tools_defaults_to_all(
|
|||||||
) -> None:
|
) -> None:
|
||||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
stack = AsyncExitStack()
|
stacks = await connect_mcp_servers(
|
||||||
await stack.__aenter__()
|
{"test": MCPServerConfig(command="fake")},
|
||||||
try:
|
registry,
|
||||||
await connect_mcp_servers(
|
)
|
||||||
{"test": MCPServerConfig(command="fake")},
|
for stack in stacks.values():
|
||||||
registry,
|
|
||||||
stack,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
await stack.aclose()
|
await stack.aclose()
|
||||||
|
|
||||||
assert registry.tool_names == ["mcp_test_demo", "mcp_test_other"]
|
assert registry.tool_names == ["mcp_test_demo", "mcp_test_other"]
|
||||||
@ -311,15 +303,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_wrapped_names(
|
|||||||
) -> None:
|
) -> None:
|
||||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
stack = AsyncExitStack()
|
stacks = await connect_mcp_servers(
|
||||||
await stack.__aenter__()
|
{"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])},
|
||||||
try:
|
registry,
|
||||||
await connect_mcp_servers(
|
)
|
||||||
{"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])},
|
for stack in stacks.values():
|
||||||
registry,
|
|
||||||
stack,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
await stack.aclose()
|
await stack.aclose()
|
||||||
|
|
||||||
assert registry.tool_names == ["mcp_test_demo"]
|
assert registry.tool_names == ["mcp_test_demo"]
|
||||||
@ -331,15 +319,11 @@ async def test_connect_mcp_servers_enabled_tools_empty_list_registers_none(
|
|||||||
) -> None:
|
) -> None:
|
||||||
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
stack = AsyncExitStack()
|
stacks = await connect_mcp_servers(
|
||||||
await stack.__aenter__()
|
{"test": MCPServerConfig(command="fake", enabled_tools=[])},
|
||||||
try:
|
registry,
|
||||||
await connect_mcp_servers(
|
)
|
||||||
{"test": MCPServerConfig(command="fake", enabled_tools=[])},
|
for stack in stacks.values():
|
||||||
registry,
|
|
||||||
stack,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
await stack.aclose()
|
await stack.aclose()
|
||||||
|
|
||||||
assert registry.tool_names == []
|
assert registry.tool_names == []
|
||||||
@ -358,15 +342,11 @@ async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries(
|
|||||||
|
|
||||||
monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning)
|
monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning)
|
||||||
|
|
||||||
stack = AsyncExitStack()
|
stacks = await connect_mcp_servers(
|
||||||
await stack.__aenter__()
|
{"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])},
|
||||||
try:
|
registry,
|
||||||
await connect_mcp_servers(
|
)
|
||||||
{"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])},
|
for stack in stacks.values():
|
||||||
registry,
|
|
||||||
stack,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
await stack.aclose()
|
await stack.aclose()
|
||||||
|
|
||||||
assert registry.tool_names == []
|
assert registry.tool_names == []
|
||||||
@ -376,6 +356,73 @@ async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries(
|
|||||||
assert "Available wrapped names: mcp_test_demo" in warnings[-1]
|
assert "Available wrapped names: mcp_test_demo" in warnings[-1]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_connect_mcp_servers_logs_stdio_pollution_hint(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A stdio client that fails with a parse error yields a pollution hint in the error log."""
    messages: list[str] = []

    def _error(message: str, *args: object) -> None:
        # Capture the rendered log line; the logger is called with a
        # ``str.format``-style template plus positional arguments.
        messages.append(message.format(*args))

    @asynccontextmanager
    async def _broken_stdio_client(_params: object):
        raise RuntimeError("Parse error: Unexpected token 'INFO' before JSON-RPC headers")
        # Unreachable, but required so this is an async generator.
        yield  # pragma: no cover

    monkeypatch.setattr(sys.modules["mcp.client.stdio"], "stdio_client", _broken_stdio_client)
    monkeypatch.setattr("nanobot.agent.tools.mcp.logger.error", _error)

    registry = ToolRegistry()
    stacks = await connect_mcp_servers({"gh": MCPServerConfig(command="github-mcp")}, registry)

    assert stacks == {}
    assert messages
    assert "stdio protocol pollution" in messages[-1]
    assert "stdout" in messages[-1]
    assert "stderr" in messages[-1]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_connect_mcp_servers_one_failure_does_not_block_others(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """One server failing to connect must not prevent the other from registering tools."""
    sessions = {"good": _make_fake_session(["demo"])}

    class _SelectiveClientSession:
        """Stand-in for ``ClientSession`` that looks its session up by the ``read`` handle."""

        def __init__(self, read: object, _write: object) -> None:
            # ``read`` is the command name yielded by ``_selective_stdio_client``.
            self._session = sessions[read]

        async def __aenter__(self) -> object:
            return self._session

        async def __aexit__(self, exc_type, exc, tb) -> bool:
            return False

    @asynccontextmanager
    async def _selective_stdio_client(params: object):
        if params.command == "bad":
            raise RuntimeError("boom")
        yield params.command, object()

    monkeypatch.setattr(sys.modules["mcp"], "ClientSession", _SelectiveClientSession)
    monkeypatch.setattr(sys.modules["mcp.client.stdio"], "stdio_client", _selective_stdio_client)

    registry = ToolRegistry()
    stacks = await connect_mcp_servers(
        {
            "good": MCPServerConfig(command="good"),
            "bad": MCPServerConfig(command="bad"),
        },
        registry,
    )
    for stack in stacks.values():
        await stack.aclose()

    assert registry.tool_names == ["mcp_good_demo"]
    assert set(stacks) == {"good"}
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# MCPResourceWrapper tests
|
# MCPResourceWrapper tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@ -389,9 +436,7 @@ def _make_resource_def(
|
|||||||
return SimpleNamespace(name=name, uri=uri, description=description)
|
return SimpleNamespace(name=name, uri=uri, description=description)
|
||||||
|
|
||||||
|
|
||||||
def _make_resource_wrapper(session: object, *, timeout: float = 0.1) -> MCPResourceWrapper:
    """Build an ``MCPResourceWrapper`` for server ``"srv"`` around a default resource definition."""
    return MCPResourceWrapper(session, "srv", _make_resource_def(), resource_timeout=timeout)
|
||||||
|
|
||||||
|
|
||||||
@ -434,9 +479,7 @@ async def test_resource_wrapper_execute_handles_timeout() -> None:
|
|||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
return SimpleNamespace(contents=[])
|
return SimpleNamespace(contents=[])
|
||||||
|
|
||||||
wrapper = _make_resource_wrapper(
|
wrapper = _make_resource_wrapper(SimpleNamespace(read_resource=read_resource), timeout=0.01)
|
||||||
SimpleNamespace(read_resource=read_resource), timeout=0.01
|
|
||||||
)
|
|
||||||
result = await wrapper.execute()
|
result = await wrapper.execute()
|
||||||
assert result == "(MCP resource read timed out after 0.01s)"
|
assert result == "(MCP resource read timed out after 0.01s)"
|
||||||
|
|
||||||
@ -464,20 +507,14 @@ def _make_prompt_def(
|
|||||||
return SimpleNamespace(name=name, description=description, arguments=arguments)
|
return SimpleNamespace(name=name, description=description, arguments=arguments)
|
||||||
|
|
||||||
|
|
||||||
def _make_prompt_wrapper(session: object, *, timeout: float = 0.1) -> MCPPromptWrapper:
    """Build an ``MCPPromptWrapper`` for server ``"srv"`` around a default prompt definition."""
    return MCPPromptWrapper(session, "srv", _make_prompt_def(), prompt_timeout=timeout)
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_wrapper_properties() -> None:
|
def test_prompt_wrapper_properties() -> None:
|
||||||
arg1 = SimpleNamespace(name="topic", required=True)
|
arg1 = SimpleNamespace(name="topic", required=True)
|
||||||
arg2 = SimpleNamespace(name="style", required=False)
|
arg2 = SimpleNamespace(name="style", required=False)
|
||||||
wrapper = MCPPromptWrapper(
|
wrapper = MCPPromptWrapper(None, "myserver", _make_prompt_def(arguments=[arg1, arg2]))
|
||||||
None, "myserver", _make_prompt_def(arguments=[arg1, arg2])
|
|
||||||
)
|
|
||||||
assert wrapper.name == "mcp_myserver_prompt_myprompt"
|
assert wrapper.name == "mcp_myserver_prompt_myprompt"
|
||||||
assert "[MCP Prompt]" in wrapper.description
|
assert "[MCP Prompt]" in wrapper.description
|
||||||
assert "A test prompt" in wrapper.description
|
assert "A test prompt" in wrapper.description
|
||||||
@ -528,9 +565,7 @@ async def test_prompt_wrapper_execute_handles_timeout() -> None:
|
|||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
return SimpleNamespace(messages=[])
|
return SimpleNamespace(messages=[])
|
||||||
|
|
||||||
wrapper = _make_prompt_wrapper(
|
wrapper = _make_prompt_wrapper(SimpleNamespace(get_prompt=get_prompt), timeout=0.01)
|
||||||
SimpleNamespace(get_prompt=get_prompt), timeout=0.01
|
|
||||||
)
|
|
||||||
result = await wrapper.execute()
|
result = await wrapper.execute()
|
||||||
assert result == "(MCP prompt call timed out after 0.01s)"
|
assert result == "(MCP prompt call timed out after 0.01s)"
|
||||||
|
|
||||||
@ -616,15 +651,11 @@ async def test_connect_registers_resources_and_prompts(
|
|||||||
prompt_names=["prompt_c"],
|
prompt_names=["prompt_c"],
|
||||||
)
|
)
|
||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
stack = AsyncExitStack()
|
stacks = await connect_mcp_servers(
|
||||||
await stack.__aenter__()
|
{"test": MCPServerConfig(command="fake")},
|
||||||
try:
|
registry,
|
||||||
await connect_mcp_servers(
|
)
|
||||||
{"test": MCPServerConfig(command="fake")},
|
for stack in stacks.values():
|
||||||
registry,
|
|
||||||
stack,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
await stack.aclose()
|
await stack.aclose()
|
||||||
|
|
||||||
assert "mcp_test_tool_a" in registry.tool_names
|
assert "mcp_test_tool_a" in registry.tool_names
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
"""Test message tool suppress logic for final replies."""
|
"""Test message tool suppress logic for final replies."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
@ -86,6 +87,42 @@ class TestMessageToolSuppressLogic:
|
|||||||
assert result is not None
|
assert result is not None
|
||||||
assert "Hello" in result.content
|
assert "Hello" in result.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_injected_followup_with_message_tool_does_not_emit_empty_fallback(
    self, tmp_path: Path
) -> None:
    """A queued follow-up answered via the message tool produces no extra final reply.

    Exactly one outbound message ("Tool reply") must be sent through the
    message tool, and ``_process_message`` must return ``None``.
    """
    loop = _make_loop(tmp_path)
    tool_call = ToolCallRequest(
        id="call1", name="message",
        arguments={"content": "Tool reply", "channel": "feishu", "chat_id": "chat123"},
    )
    # Scripted provider responses, consumed one per chat_with_retry call.
    calls = iter([
        LLMResponse(content="First answer", tool_calls=[]),
        LLMResponse(content="", tool_calls=[tool_call]),
        LLMResponse(content="", tool_calls=[]),
        LLMResponse(content="", tool_calls=[]),
        LLMResponse(content="", tool_calls=[]),
    ])
    loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
    loop.tools.get_definitions = MagicMock(return_value=[])

    sent: list[OutboundMessage] = []
    mt = loop.tools.get("message")
    if isinstance(mt, MessageTool):
        mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))

    # Pre-load a follow-up message so it is pending while "Start" is processed.
    pending_queue = asyncio.Queue()
    await pending_queue.put(
        InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="follow-up")
    )

    msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Start")
    result = await loop._process_message(msg, pending_queue=pending_queue)

    assert len(sent) == 1
    assert sent[0].content == "Tool reply"
    assert result is None
|
||||||
|
|
||||||
async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
|
async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
|
||||||
loop = _make_loop(tmp_path)
|
loop = _make_loop(tmp_path)
|
||||||
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
|
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
|
||||||
@ -107,7 +144,7 @@ class TestMessageToolSuppressLogic:
|
|||||||
async def on_progress(content: str, *, tool_hint: bool = False) -> None:
|
async def on_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||||
progress.append((content, tool_hint))
|
progress.append((content, tool_hint))
|
||||||
|
|
||||||
final_content, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
|
final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
|
||||||
|
|
||||||
assert final_content == "Done"
|
assert final_content == "Done"
|
||||||
assert progress == [
|
assert progress == [
|
||||||
|
|||||||
147
tests/tools/test_notebook_tool.py
Normal file
147
tests/tools/test_notebook_tool.py
Normal file
@ -0,0 +1,147 @@
|
|||||||
|
"""Tests for NotebookEditTool — Jupyter .ipynb editing."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.tools.notebook import NotebookEditTool
|
||||||
|
|
||||||
|
|
||||||
|
def _make_notebook(cells: list[dict] | None = None, nbformat: int = 4, nbformat_minor: int = 5) -> dict:
|
||||||
|
"""Build a minimal valid .ipynb structure."""
|
||||||
|
return {
|
||||||
|
"nbformat": nbformat,
|
||||||
|
"nbformat_minor": nbformat_minor,
|
||||||
|
"metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}},
|
||||||
|
"cells": cells or [],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _code_cell(source: str, cell_id: str | None = None) -> dict:
|
||||||
|
cell = {"cell_type": "code", "source": source, "metadata": {}, "outputs": [], "execution_count": None}
|
||||||
|
if cell_id:
|
||||||
|
cell["id"] = cell_id
|
||||||
|
return cell
|
||||||
|
|
||||||
|
|
||||||
|
def _md_cell(source: str, cell_id: str | None = None) -> dict:
|
||||||
|
cell = {"cell_type": "markdown", "source": source, "metadata": {}}
|
||||||
|
if cell_id:
|
||||||
|
cell["id"] = cell_id
|
||||||
|
return cell
|
||||||
|
|
||||||
|
|
||||||
|
def _write_nb(tmp_path, name: str, nb: dict) -> str:
|
||||||
|
p = tmp_path / name
|
||||||
|
p.write_text(json.dumps(nb), encoding="utf-8")
|
||||||
|
return str(p)
|
||||||
|
|
||||||
|
|
||||||
|
class TestNotebookEdit:
    """Exercise NotebookEditTool replace / insert / delete on .ipynb files in a temp workspace."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return NotebookEditTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_replace_cell_content(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("print('hello')"), _code_cell("x = 1")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        result = await tool.execute(path=path, cell_index=0, new_source="print('world')")
        assert "Successfully" in result
        saved = json.loads((tmp_path / "test.ipynb").read_text(encoding="utf-8"))
        assert saved["cells"][0]["source"] == "print('world')"
        assert saved["cells"][1]["source"] == "x = 1"

    @pytest.mark.asyncio
    async def test_insert_cell_after_target(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("cell 0"), _code_cell("cell 1")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        result = await tool.execute(path=path, cell_index=0, new_source="inserted", edit_mode="insert")
        assert "Successfully" in result
        saved = json.loads((tmp_path / "test.ipynb").read_text(encoding="utf-8"))
        assert len(saved["cells"]) == 3
        assert saved["cells"][0]["source"] == "cell 0"
        assert saved["cells"][1]["source"] == "inserted"
        assert saved["cells"][2]["source"] == "cell 1"

    @pytest.mark.asyncio
    async def test_delete_cell(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("A"), _code_cell("B"), _code_cell("C")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        result = await tool.execute(path=path, cell_index=1, edit_mode="delete")
        assert "Successfully" in result
        saved = json.loads((tmp_path / "test.ipynb").read_text(encoding="utf-8"))
        assert len(saved["cells"]) == 2
        assert saved["cells"][0]["source"] == "A"
        assert saved["cells"][1]["source"] == "C"

    @pytest.mark.asyncio
    async def test_create_new_notebook_from_scratch(self, tool, tmp_path):
        # The target file does not exist before the call.
        path = str(tmp_path / "new.ipynb")
        result = await tool.execute(
            path=path, cell_index=0, new_source="# Hello", edit_mode="insert", cell_type="markdown"
        )
        assert "Successfully" in result or "created" in result.lower()
        saved = json.loads((tmp_path / "new.ipynb").read_text(encoding="utf-8"))
        assert saved["nbformat"] == 4
        assert len(saved["cells"]) == 1
        assert saved["cells"][0]["cell_type"] == "markdown"
        assert saved["cells"][0]["source"] == "# Hello"

    @pytest.mark.asyncio
    async def test_invalid_cell_index_error(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("only cell")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        result = await tool.execute(path=path, cell_index=5, new_source="x")
        assert "Error" in result

    @pytest.mark.asyncio
    async def test_non_ipynb_rejected(self, tool, tmp_path):
        f = tmp_path / "script.py"
        f.write_text("pass")
        result = await tool.execute(path=str(f), cell_index=0, new_source="x")
        assert "Error" in result
        assert ".ipynb" in result

    @pytest.mark.asyncio
    async def test_preserves_metadata_and_outputs(self, tool, tmp_path):
        cell = _code_cell("old")
        cell["outputs"] = [{"output_type": "stream", "text": "hello\n"}]
        cell["execution_count"] = 42
        nb = _make_notebook([cell])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        await tool.execute(path=path, cell_index=0, new_source="new")
        saved = json.loads((tmp_path / "test.ipynb").read_text(encoding="utf-8"))
        # NOTE(review): only notebook-level metadata is asserted here; the
        # outputs / execution_count set above are never checked. Confirm the
        # tool's intended behavior before adding assertions for them.
        assert saved["metadata"]["kernelspec"]["language"] == "python"

    @pytest.mark.asyncio
    async def test_nbformat_45_generates_cell_id(self, tool, tmp_path):
        nb = _make_notebook([], nbformat_minor=5)
        path = _write_nb(tmp_path, "test.ipynb", nb)
        await tool.execute(path=path, cell_index=0, new_source="x = 1", edit_mode="insert")
        saved = json.loads((tmp_path / "test.ipynb").read_text(encoding="utf-8"))
        assert "id" in saved["cells"][0]
        assert len(saved["cells"][0]["id"]) > 0

    @pytest.mark.asyncio
    async def test_insert_with_cell_type_markdown(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("code")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        await tool.execute(path=path, cell_index=0, new_source="# Title", edit_mode="insert", cell_type="markdown")
        saved = json.loads((tmp_path / "test.ipynb").read_text(encoding="utf-8"))
        assert saved["cells"][1]["cell_type"] == "markdown"

    @pytest.mark.asyncio
    async def test_invalid_edit_mode_rejected(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("code")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        # "replcae" is a deliberately misspelled mode.
        result = await tool.execute(path=path, cell_index=0, new_source="x", edit_mode="replcae")
        assert "Error" in result
        assert "edit_mode" in result

    @pytest.mark.asyncio
    async def test_invalid_cell_type_rejected(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("code")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        result = await tool.execute(path=path, cell_index=0, new_source="x", cell_type="raw")
        assert "Error" in result
        assert "cell_type" in result
|
||||||
180
tests/tools/test_read_enhancements.py
Normal file
180
tests/tools/test_read_enhancements.py
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
"""Tests for ReadFileTool enhancements: description fix, read dedup, PDF support, device blacklist."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.tools.filesystem import ReadFileTool, WriteFileTool
|
||||||
|
from nanobot.agent.tools import file_state
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _clear_file_state():
    """Clear ``file_state`` before and after every test in this module."""
    file_state.clear()
    yield
    file_state.clear()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Description fix
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestReadDescriptionFix:
    """The ReadFileTool description must advertise image support."""

    def test_description_mentions_image_support(self):
        tool = ReadFileTool()
        assert "image" in tool.description.lower()

    def test_description_no_longer_says_cannot_read_images(self):
        tool = ReadFileTool()
        assert "cannot read binary files or images" not in tool.description.lower()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Read deduplication
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestReadDedup:
    """Same file + same offset/limit + unchanged mtime -> short stub."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return ReadFileTool(workspace=tmp_path)

    @pytest.fixture()
    def write_tool(self, tmp_path):
        return WriteFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_second_read_returns_unchanged_stub(self, tool, tmp_path):
        f = tmp_path / "data.txt"
        f.write_text("\n".join(f"line {i}" for i in range(100)), encoding="utf-8")
        first = await tool.execute(path=str(f))
        assert "line 0" in first
        second = await tool.execute(path=str(f))
        assert "unchanged" in second.lower()
        # Stub should not contain file content
        assert "line 0" not in second

    @pytest.mark.asyncio
    async def test_read_after_external_modification_returns_full(self, tool, tmp_path):
        import os

        f = tmp_path / "data.txt"
        f.write_text("original", encoding="utf-8")
        await tool.execute(path=str(f))
        before = f.stat()
        # Modify the file externally
        f.write_text("modified content", encoding="utf-8")
        # Two writes in quick succession can land on the same mtime on
        # filesystems with coarse timestamp granularity; force a strictly
        # newer mtime so the test does not depend on wall-clock timing.
        os.utime(f, ns=(before.st_atime_ns, before.st_mtime_ns + 1_000_000_000))
        second = await tool.execute(path=str(f))
        assert "modified content" in second

    @pytest.mark.asyncio
    async def test_different_offset_returns_full(self, tool, tmp_path):
        f = tmp_path / "data.txt"
        f.write_text("\n".join(f"line {i}" for i in range(1, 21)), encoding="utf-8")
        await tool.execute(path=str(f), offset=1, limit=5)
        second = await tool.execute(path=str(f), offset=6, limit=5)
        # Different offset → full read, not stub
        assert "line 6" in second

    @pytest.mark.asyncio
    async def test_first_read_after_write_returns_full_content(self, tool, write_tool, tmp_path):
        f = tmp_path / "fresh.txt"
        result = await write_tool.execute(path=str(f), content="hello")
        assert "Successfully" in result
        read_result = await tool.execute(path=str(f))
        assert "hello" in read_result
        assert "unchanged" not in read_result.lower()

    @pytest.mark.asyncio
    async def test_dedup_does_not_apply_to_images(self, tool, tmp_path):
        f = tmp_path / "img.png"
        # PNG signature followed by filler bytes; not a decodable image.
        f.write_bytes(b"\x89PNG\r\n\x1a\nfake-png-data")
        first = await tool.execute(path=str(f))
        assert isinstance(first, list)
        second = await tool.execute(path=str(f))
        # Images should always return full content blocks, not a stub
        assert isinstance(second, list)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# PDF support
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestReadPdf:
    """ReadFileTool on PDF files; tests that build PDFs are skipped when ``fitz`` is unavailable."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return ReadFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_pdf_returns_text_content(self, tool, tmp_path):
        fitz = pytest.importorskip("fitz")
        pdf_path = tmp_path / "test.pdf"
        doc = fitz.open()
        page = doc.new_page()
        page.insert_text((72, 72), "Hello PDF World")
        doc.save(str(pdf_path))
        doc.close()

        result = await tool.execute(path=str(pdf_path))
        assert "Hello PDF World" in result

    @pytest.mark.asyncio
    async def test_pdf_pages_parameter(self, tool, tmp_path):
        fitz = pytest.importorskip("fitz")
        pdf_path = tmp_path / "multi.pdf"
        doc = fitz.open()
        # Five pages, each labelled with its 1-based page number.
        for i in range(5):
            page = doc.new_page()
            page.insert_text((72, 72), f"Page {i + 1} content")
        doc.save(str(pdf_path))
        doc.close()

        result = await tool.execute(path=str(pdf_path), pages="2-3")
        assert "Page 2 content" in result
        assert "Page 3 content" in result
        assert "Page 1 content" not in result

    @pytest.mark.asyncio
    async def test_pdf_file_not_found_error(self, tool, tmp_path):
        result = await tool.execute(path=str(tmp_path / "nope.pdf"))
        assert "Error" in result
        assert "not found" in result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Device path blacklist
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestReadDeviceBlacklist:
    """Reading device-like paths must return an error instead of content."""

    @pytest.fixture()
    def tool(self):
        return ReadFileTool()

    @pytest.mark.asyncio
    async def test_dev_random_blocked(self, tool):
        result = await tool.execute(path="/dev/random")
        assert "Error" in result
        assert "blocked" in result.lower() or "device" in result.lower()

    @pytest.mark.asyncio
    async def test_dev_urandom_blocked(self, tool):
        result = await tool.execute(path="/dev/urandom")
        assert "Error" in result

    @pytest.mark.asyncio
    async def test_dev_zero_blocked(self, tool):
        result = await tool.execute(path="/dev/zero")
        assert "Error" in result

    @pytest.mark.asyncio
    async def test_proc_fd_blocked(self, tool):
        result = await tool.execute(path="/proc/self/fd/0")
        assert "Error" in result

    @pytest.mark.asyncio
    async def test_symlink_to_dev_zero_blocked(self, tmp_path):
        tool = ReadFileTool(workspace=tmp_path)
        # The link lives inside the workspace but points at a device file.
        link = tmp_path / "zero-link"
        link.symlink_to("/dev/zero")
        result = await tool.execute(path=str(link))
        assert "Error" in result
        assert "blocked" in result.lower() or "device" in result.lower()
|
||||||
@ -323,3 +323,27 @@ async def test_subagent_registers_grep_and_glob(tmp_path: Path) -> None:
|
|||||||
|
|
||||||
assert "grep" in captured["tool_names"]
|
assert "grep" in captured["tool_names"]
|
||||||
assert "glob" in captured["tool_names"]
|
assert "glob" in captured["tool_names"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_subagent_prompt_respects_disabled_skills(tmp_path: Path) -> None:
    """A skill listed in ``disabled_skills`` must not appear in the subagent prompt.

    Two skills ("alpha" and "beta") are written under ``<workspace>/skills``;
    only "alpha" is disabled, so only "beta" may show up in the prompt text.
    """
    bus = MessageBus()
    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"

    # Lay out one SKILL.md per skill directory inside the workspace.
    skills_root = tmp_path / "skills"
    skill_files = (
        ("alpha", "# Alpha\n\nhidden\n"),
        ("beta", "# Beta\n\nshown\n"),
    )
    for skill_name, skill_markdown in skill_files:
        skill_dir = skills_root / skill_name
        skill_dir.mkdir(parents=True)
        (skill_dir / "SKILL.md").write_text(skill_markdown, encoding="utf-8")

    manager_kwargs = {
        "provider": provider,
        "workspace": tmp_path,
        "bus": bus,
        "max_tool_result_chars": 4096,
        "disabled_skills": ["alpha"],
    }
    manager = SubagentManager(**manager_kwargs)

    rendered = manager._build_subagent_prompt()

    assert "alpha" not in rendered
    assert "beta" in rendered
|
||||||
|
|||||||
@ -47,3 +47,27 @@ def test_get_definitions_orders_builtins_then_mcp_tools() -> None:
|
|||||||
"mcp_fs_list",
|
"mcp_fs_list",
|
||||||
"mcp_git_status",
|
"mcp_git_status",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_call_read_file_rejects_non_object_params_with_actionable_hint() -> None:
    """A list passed as ``read_file`` params is rejected with a guiding message.

    The call must return no tool, hand the params back unchanged, and report
    an error that mentions both the JSON-object requirement and named parameters.
    """
    reg = ToolRegistry()
    reg.register(_FakeTool("read_file"))

    prepared = reg.prepare_call("read_file", ["foo.txt"])
    resolved_tool, returned_params, message = prepared

    assert resolved_tool is None
    assert returned_params == ["foo.txt"]
    assert message is not None
    for expected_fragment in ("must be a JSON object", "Use named parameters"):
        assert expected_fragment in message
|
||||||
|
|
||||||
|
|
||||||
|
def test_prepare_call_other_tools_keep_generic_object_validation() -> None:
    """For a tool other than ``read_file``, list params get the generic error.

    The tool is still resolved, the params come back unchanged, and the error
    text is the exact generic "parameters must be an object" message.
    """
    reg = ToolRegistry()
    reg.register(_FakeTool("grep"))

    prepared = reg.prepare_call("grep", ["TODO"])
    resolved_tool, returned_params, message = prepared

    expected_message = (
        "Error: Invalid parameters for tool 'grep': parameters must be an object, got list"
    )
    assert resolved_tool is not None
    assert returned_params == ["TODO"]
    assert message == expected_message
|
||||||
|
|||||||
@ -1,7 +1,5 @@
|
|||||||
"""Tests for multi-provider web search."""
|
"""Tests for multi-provider web search."""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -20,6 +18,25 @@ def _response(status: int = 200, json: dict | None = None) -> httpx.Response:
|
|||||||
return r
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
def test_duckduckgo_search_is_exclusive():
    """A DuckDuckGo-configured tool reports exclusive=True and concurrency_safe=False."""
    search_tool = _tool(provider="duckduckgo")
    flags = (search_tool.exclusive, search_tool.concurrency_safe)
    assert flags[0] is True
    assert flags[1] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_brave_with_api_key_remains_concurrency_safe():
    """A Brave tool given an API key reports exclusive=False and concurrency_safe=True."""
    search_tool = _tool(provider="brave", api_key="brave-key")
    flags = (search_tool.exclusive, search_tool.concurrency_safe)
    assert flags[0] is False
    assert flags[1] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_brave_without_api_key_is_treated_as_duckduckgo_for_concurrency(monkeypatch):
    """Brave with no key (argument empty, BRAVE_API_KEY unset) reports the exclusive flags.

    The expected flags are exclusive=True and concurrency_safe=False.
    """
    # Make sure a key in the real environment cannot leak into the test.
    monkeypatch.delenv("BRAVE_API_KEY", raising=False)

    search_tool = _tool(provider="brave", api_key="")
    flags = (search_tool.exclusive, search_tool.concurrency_safe)
    assert flags[0] is True
    assert flags[1] is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_brave_search(monkeypatch):
|
async def test_brave_search(monkeypatch):
|
||||||
async def mock_get(self, url, **kw):
|
async def mock_get(self, url, **kw):
|
||||||
@ -79,7 +96,6 @@ async def test_duckduckgo_search(monkeypatch):
|
|||||||
import nanobot.agent.tools.web as web_mod
|
import nanobot.agent.tools.web as web_mod
|
||||||
monkeypatch.setattr(web_mod, "DDGS", MockDDGS, raising=False)
|
monkeypatch.setattr(web_mod, "DDGS", MockDDGS, raising=False)
|
||||||
|
|
||||||
from ddgs import DDGS
|
|
||||||
monkeypatch.setattr("ddgs.DDGS", MockDDGS)
|
monkeypatch.setattr("ddgs.DDGS", MockDDGS)
|
||||||
|
|
||||||
tool = _tool(provider="duckduckgo")
|
tool = _tool(provider="duckduckgo")
|
||||||
@ -120,6 +136,27 @@ async def test_jina_search(monkeypatch):
|
|||||||
assert "https://jina.ai" in result
|
assert "https://jina.ai" in result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_kagi_search(monkeypatch):
    """A Kagi search sends the expected request and renders only result items.

    The stubbed HTTP GET checks the URL, the ``Bot`` authorization header and
    the query params, then returns one result entry plus one related-search
    entry; only the former may appear in the tool output.
    """

    async def fake_get(self, url, **kw):
        # Validate the outgoing request before answering with canned data.
        assert "kagi.com/api/v0/search" in url
        auth_header = kw["headers"]["Authorization"]
        assert auth_header == "Bot kagi-key"
        assert kw["params"] == {"q": "test", "limit": 2}
        search_hit = {
            "t": 0,
            "title": "Kagi Result",
            "url": "https://kagi.com",
            "snippet": "Premium search",
        }
        related_searches = {"t": 1, "list": ["ignored related search"]}
        return _response(json={"data": [search_hit, related_searches]})

    monkeypatch.setattr(httpx.AsyncClient, "get", fake_get)
    search_tool = _tool(provider="kagi", api_key="kagi-key")

    output = await search_tool.execute(query="test", count=2)

    for expected in ("Kagi Result", "https://kagi.com"):
        assert expected in output
    assert "ignored related search" not in output
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_unknown_provider():
|
async def test_unknown_provider():
|
||||||
tool = _tool(provider="unknown")
|
tool = _tool(provider="unknown")
|
||||||
@ -189,6 +226,23 @@ async def test_jina_422_falls_back_to_duckduckgo(monkeypatch):
|
|||||||
assert "DuckDuckGo fallback" in result
|
assert "DuckDuckGo fallback" in result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_kagi_fallback_to_duckduckgo_when_no_key(monkeypatch):
    """With no Kagi key available, the result comes from the stubbed DDGS client."""

    class StubDDGS:
        """Minimal DDGS stand-in that returns one fixed hit from ``text``."""

        def __init__(self, **kw):
            pass

        def text(self, query, max_results=5):
            single_hit = {
                "title": "Fallback",
                "href": "https://ddg.example",
                "body": "DuckDuckGo fallback",
            }
            return [single_hit]

    monkeypatch.setattr("ddgs.DDGS", StubDDGS)
    # Drop any key from the environment so only the empty argument remains.
    monkeypatch.delenv("KAGI_API_KEY", raising=False)

    search_tool = _tool(provider="kagi", api_key="")
    output = await search_tool.execute(query="test")

    assert "Fallback" in output
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_jina_search_uses_path_encoded_query(monkeypatch):
|
async def test_jina_search_uses_path_encoded_query(monkeypatch):
|
||||||
calls = {}
|
calls = {}
|
||||||
@ -227,5 +281,3 @@ async def test_duckduckgo_timeout_returns_error(monkeypatch):
|
|||||||
result = await tool.execute(query="test")
|
result = await tool.execute(query="test")
|
||||||
gate.set()
|
gate.set()
|
||||||
assert "Error" in result
|
assert "Error" in result
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
65
tests/utils/test_strip_think.py
Normal file
65
tests/utils/test_strip_think.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.utils.helpers import strip_think
|
||||||
|
|
||||||
|
|
||||||
|
class TestStripThinkTag:
    """Check how strip_think handles <thought>...</thought> blocks (Gemma 4 and similar models)."""

    def test_closed_tag(self):
        """A closed block in the middle of a sentence is removed."""
        cleaned = strip_think("Hello <thought>reasoning</thought> World")
        assert cleaned == "Hello World"

    def test_unclosed_trailing_tag(self):
        """Text that is only an unclosed block collapses to the empty string."""
        cleaned = strip_think("<thought>ongoing...")
        assert cleaned == ""

    def test_multiline_tag(self):
        """A block spanning several lines is removed in full."""
        cleaned = strip_think("<thought>\nline1\nline2\n</thought>End")
        assert cleaned == "End"

    def test_tag_with_nested_angle_brackets(self):
        """Stray '<' and '>' inside the block do not confuse the stripping."""
        raw = "<thought>a < 3 and b > 2</thought>result"
        cleaned = strip_think(raw)
        assert cleaned == "result"

    def test_multiple_tag_blocks(self):
        """Every block is removed when several occur in one string."""
        raw = "A<thought>x</thought>B<thought>y</thought>C"
        cleaned = strip_think(raw)
        assert cleaned == "ABC"

    def test_tag_only_whitespace_inside(self):
        """A block holding only whitespace is removed like any other."""
        cleaned = strip_think("before<thought> </thought>after")
        assert cleaned == "beforeafter"

    def test_self_closing_tag_not_matched(self):
        """A self-closing <thought/> is left untouched."""
        raw = "<thought/>some text"
        assert strip_think(raw) == "<thought/>some text"

    def test_normal_text_unchanged(self):
        """Text without any tag is returned as-is."""
        raw = "Just normal text"
        assert strip_think(raw) == "Just normal text"

    def test_empty_string(self):
        """The empty string maps to the empty string."""
        cleaned = strip_think("")
        assert cleaned == ""
|
||||||
|
|
||||||
|
|
||||||
|
class TestStripThinkFalsePositive:
    """Guard against over-stripping of <think>/<thought> mentions (#3004).

    Tags that merely appear mid-content must survive, while an unclosed tag
    at the very start of the text is still stripped.
    """

    def test_backtick_think_tag_preserved(self):
        """An inline-code mention of <think> is kept verbatim."""
        raw = "*Think Stripping:* A new utility to strip `<think>` tags from output."
        cleaned = strip_think(raw)
        assert cleaned == raw

    def test_prose_think_tag_preserved(self):
        """A <think> mention inside ordinary prose is kept verbatim."""
        raw = "The model emits <think> at the start of its response."
        cleaned = strip_think(raw)
        assert cleaned == raw

    def test_code_block_think_tag_preserved(self):
        """A <think> pattern inside a fenced code block is kept verbatim."""
        raw = "Example:\n```\ntext = re.sub(r\"<think>[\\s\\S]*\", \"\", text)\n```\nDone."
        cleaned = strip_think(raw)
        assert cleaned == raw

    def test_backtick_thought_tag_preserved(self):
        """An inline-code mention of <thought> is kept verbatim."""
        raw = "Gemma 4 uses `<thought>` blocks for reasoning."
        cleaned = strip_think(raw)
        assert cleaned == raw

    def test_prefix_unclosed_think_still_stripped(self):
        """An unclosed <think> at the start removes everything."""
        cleaned = strip_think("<think>reasoning without closing")
        assert cleaned == ""

    def test_prefix_unclosed_think_with_whitespace(self):
        """Leading whitespace before an unclosed <think> does not prevent stripping."""
        cleaned = strip_think(" <think>reasoning...")
        assert cleaned == ""

    def test_prefix_unclosed_thought_still_stripped(self):
        """An unclosed <thought> at the start removes everything."""
        cleaned = strip_think("<thought>reasoning without closing")
        assert cleaned == ""
|
||||||
Loading…
x
Reference in New Issue
Block a user