mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 07:14:08 +00:00
Merge origin/main into feat/api-file-upload
Keep the API file upload branch current with main, enforce the documented JSON base64 per-file limit, and avoid leaking document extraction error strings into user prompts.

Made-with: Cursor
commit
2502fc616b
151
README.md
@ -394,7 +394,8 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso
|
||||
"enabled": true,
|
||||
"token": "YOUR_BOT_TOKEN",
|
||||
"allowFrom": ["YOUR_USER_ID"],
|
||||
"groupPolicy": "mention"
|
||||
"groupPolicy": "mention",
|
||||
"streaming": true
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -405,6 +406,7 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso
|
||||
> - `"open"` — Respond to all messages
|
||||
> DMs always respond when the sender is in `allowFrom`.
|
||||
> - If you set the group policy to `"open"`, create new threads as private threads and then @mention the bot into them. Otherwise, both the thread and the channel in which you created it will spawn a bot session.
|
||||
> `streaming` defaults to `true`. Disable it only if you explicitly want non-streaming replies.
|
||||
|
||||
**5. Invite the bot**
|
||||
- OAuth2 → URL Generator
|
||||
@ -558,7 +560,11 @@ Uses **WebSocket** long connection — no public IP required.
|
||||
"verificationToken": "",
|
||||
"allowFrom": ["ou_YOUR_OPEN_ID"],
|
||||
"groupPolicy": "mention",
|
||||
"streaming": true
|
||||
"reactEmoji": "OnIt",
|
||||
"doneEmoji": "DONE",
|
||||
"toolHintPrefix": "🔧",
|
||||
"streaming": true,
|
||||
"domain": "feishu"
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -568,6 +574,10 @@ Uses **WebSocket** long connection — no public IP required.
|
||||
> `encryptKey` and `verificationToken` are optional for Long Connection mode.
|
||||
> `allowFrom`: Add your open_id (find it in nanobot logs when you message the bot). Use `["*"]` to allow all users.
|
||||
> `groupPolicy`: `"mention"` (default; respond only when @mentioned) or `"open"` (respond to all group messages). The bot always responds in private chats.
> `reactEmoji`: Emoji for "processing" status (default: `OnIt`). See [available emojis](https://open.larkoffice.com/document/server-docs/im-v1/message-reaction/emojis-introduce).
> `doneEmoji`: Optional emoji for "completed" status (e.g., `DONE`, `OK`, `HEART`). When set, the bot adds this reaction after removing `reactEmoji`.
|
||||
> `toolHintPrefix`: Prefix for inline tool hints in streaming cards (default: `🔧`).
|
||||
> `domain`: `"feishu"` (default) for China (open.feishu.cn), `"lark"` for international Lark (open.larksuite.com).
|
||||
|
||||
**3. Run**
|
||||
|
||||
@ -1043,6 +1053,30 @@ Connects directly to any OpenAI-compatible endpoint — LM Studio, llama.cpp, To
|
||||
```
|
||||
|
||||
> For local servers that don't require a key, set `apiKey` to any non-empty string (e.g. `"no-key"`).
|
||||
>
|
||||
> `custom` is the right choice for providers that expose an OpenAI-compatible **chat completions** API. It does **not** force third-party endpoints onto the OpenAI/Azure **Responses API**.
|
||||
>
|
||||
> If your proxy or gateway is specifically Responses-API-compatible, use the `azure_openai` provider shape instead and point `apiBase` at that endpoint:
|
||||
>
|
||||
> ```json
|
||||
> {
|
||||
> "providers": {
|
||||
> "azure_openai": {
|
||||
> "apiKey": "your-api-key",
|
||||
> "apiBase": "https://api.your-provider.com",
|
||||
> "defaultModel": "your-model-name"
|
||||
> }
|
||||
> },
|
||||
> "agents": {
|
||||
> "defaults": {
|
||||
> "provider": "azure_openai",
|
||||
> "model": "your-model-name"
|
||||
> }
|
||||
> }
|
||||
> }
|
||||
> ```
|
||||
>
|
||||
> In short: **chat-completions-compatible endpoint → `custom`**; **Responses-compatible endpoint → `azure_openai`**.
|
||||
|
||||
</details>
|
||||
|
||||
@ -1304,6 +1338,7 @@ If you need to allow trusted private ranges such as Tailscale / CGNAT addresses,
|
||||
| `brave` | `apiKey` | `BRAVE_API_KEY` | No |
| `tavily` | `apiKey` | `TAVILY_API_KEY` | No |
| `jina` | `apiKey` | `JINA_API_KEY` | Free tier (10M tokens) |
| `kagi` | `apiKey` | `KAGI_API_KEY` | No |
| `searxng` | `baseUrl` | `SEARXNG_BASE_URL` | Yes (self-hosted) |
| `duckduckgo` (default) | — | — | Yes |
|
||||
|
||||
@ -1360,6 +1395,20 @@ If you need to allow trusted private ranges such as Tailscale / CGNAT addresses,
|
||||
}
|
||||
```
|
||||
|
||||
**Kagi:**
|
||||
```json
|
||||
{
|
||||
"tools": {
|
||||
"web": {
|
||||
"search": {
|
||||
"provider": "kagi",
|
||||
"apiKey": "your-kagi-api-key"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**SearXNG** (self-hosted, no API key needed):
|
||||
```json
|
||||
{
|
||||
@ -1495,6 +1544,35 @@ MCP tools are automatically discovered and registered on startup. The LLM can us
|
||||
**Docker security**: The official Docker image runs as a non-root user (`nanobot`, UID 1000) with bubblewrap pre-installed. When using `docker-compose.yml`, the container drops all Linux capabilities except `SYS_ADMIN` (required for bwrap's namespace isolation).
|
||||
|
||||
|
||||
### Auto Compact
|
||||
|
||||
When a user is idle for longer than a configured threshold, nanobot **proactively** compresses the older part of the session context into a summary while keeping a recent legal suffix of live messages. This reduces token cost and first-token latency when the user returns — instead of re-processing a long stale context with an expired KV cache, the model receives a compact summary, the most recent live context, and fresh input.
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"idleCompactAfterMinutes": 15
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| Option | Default | Description |
|--------|---------|-------------|
| `agents.defaults.idleCompactAfterMinutes` | `0` (disabled) | Minutes of idle time before auto-compaction starts. Set to `0` to disable. Recommended: `15`, which is close to a typical LLM KV cache expiry window, so stale sessions get compacted before the user returns. |
|
||||
|
||||
`sessionTtlMinutes` remains accepted as a legacy alias for backward compatibility, but `idleCompactAfterMinutes` is the preferred config key going forward.
|
||||
|
||||
How it works:
1. **Idle detection**: On each idle tick (~1 s), nanobot checks all sessions for expiration.
2. **Background compaction**: For each idle session, an LLM summarizes the older live prefix, and the most recent legal suffix (currently 8 messages) is kept.
3. **Summary injection**: When the user returns, the summary is injected as runtime context (one-shot, not persisted) alongside the retained recent suffix.
4. **Restart-safe resume**: The summary is also mirrored into session metadata so it can still be recovered after a process restart.
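
The idle check and the prefix/suffix split described above can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not nanobot's implementation: `is_idle` and `split_for_compaction` are hypothetical helper names, and the split is a plain slice, whereas the real code keeps a *legal* suffix, which this sketch does not attempt to model.

```python
from datetime import datetime, timedelta

RECENT_SUFFIX_MESSAGES = 8  # mirrors the "currently 8 messages" noted above


def is_idle(last_active: datetime, threshold_minutes: int, now: datetime) -> bool:
    """Return True once a session has been idle for at least the threshold."""
    if threshold_minutes <= 0:  # 0 disables auto compact
        return False
    return now - last_active >= timedelta(minutes=threshold_minutes)


def split_for_compaction(messages: list, keep: int = RECENT_SUFFIX_MESSAGES):
    """Split messages into (older prefix to summarize, recent suffix to keep)."""
    cut = max(len(messages) - keep, 0)
    return messages[:cut], messages[cut:]


now = datetime(2026, 1, 1, 12, 0)
history = [{"role": "user", "content": f"message {i}"} for i in range(20)]
if is_idle(now - timedelta(minutes=20), threshold_minutes=15, now=now):
    older, recent = split_for_compaction(history)
    print(len(older), len(recent))  # 12 8
```

Only the `older` part would be handed to the summarizer; the `recent` part stays in the session unchanged.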
|
||||
|
||||
> [!TIP]
|
||||
> Think of auto compact as "summarize older context, keep the freshest live turns." It is not a hard session reset.
|
||||
|
||||
### Timezone
|
||||
|
||||
Time is context. Context should be precise.
|
||||
@ -1517,6 +1595,52 @@ Common examples: `UTC`, `America/New_York`, `America/Los_Angeles`, `Europe/Londo
|
||||
|
||||
> Need another timezone? Browse the full [IANA Time Zone Database](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones).
|
||||
|
||||
### Unified Session
|
||||
|
||||
By default, each channel × chat ID combination gets its own session. If you use nanobot across multiple channels (e.g. Telegram + Discord + CLI) and want them to share the same conversation, enable `unifiedSession`:
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"unifiedSession": true
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
When enabled, all incoming messages — regardless of which channel they arrive on — are routed into a single shared session. Switching from Telegram to Discord (or any other channel) continues the same conversation seamlessly.
|
||||
|
||||
| Behavior | `false` (default) | `true` |
|----------|-------------------|--------|
| Session key | `channel:chat_id` | `unified:default` |
| Cross-channel continuity | No | Yes |
| `/new` clears | Current channel session | Shared session |
| `/stop` finds tasks | By channel session | By shared session |
| Existing `session_key_override` (e.g. Telegram thread) | Respected | Still respected (not overwritten) |
|
||||
|
||||
> This is designed for single-user, multi-device setups. It is **off by default** — existing users see zero behavior change.
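
The routing rule in the table above can be read as a small decision function. The sketch below is illustrative only: `resolve_session_key` is a hypothetical name, and the precedence (override first, then the unified key, then the per-channel key) is inferred from the table rather than taken from nanobot's source.

```python
from typing import Optional


def resolve_session_key(channel: str, chat_id: str, unified: bool = False,
                        session_key_override: Optional[str] = None) -> str:
    """Pick the session key for an incoming message."""
    if session_key_override:
        # An existing override (e.g. a Telegram thread) is never overwritten.
        return session_key_override
    if unified:
        return "unified:default"
    return f"{channel}:{chat_id}"


print(resolve_session_key("telegram", "123"))                # telegram:123
print(resolve_session_key("discord", "456", unified=True))   # unified:default
```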
|
||||
|
||||
### Disabled Skills
|
||||
|
||||
nanobot ships with built-in skills, and your workspace can also define custom skills under `skills/`. If you want to hide specific skills from the agent, set `agents.defaults.disabledSkills` to a list of skill directory names:
|
||||
|
||||
```json
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"disabledSkills": ["github", "weather"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Disabled skills are excluded from the main agent's skill summary, from always-on skill injection, and from subagent skill summaries. This is useful when some bundled skills are unnecessary for your deployment or should not be exposed to end users.
|
||||
|
||||
| Option | Default | Description |
|--------|---------|-------------|
| `agents.defaults.disabledSkills` | `[]` | List of skill directory names to exclude from loading. Applies to both built-in skills and workspace skills. |
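
Because entries are matched against skill directory names, the filtering can be pictured as a simple set exclusion. The following is a rough sketch, assuming each skill lives in its own directory under a `skills/` root; `visible_skills` is a hypothetical helper, not part of nanobot.

```python
import tempfile
from pathlib import Path


def visible_skills(skills_root: Path, disabled: set) -> list:
    """List skill directory names under skills_root, minus the disabled ones."""
    if not skills_root.is_dir():
        return []
    return sorted(
        entry.name
        for entry in skills_root.iterdir()
        if entry.is_dir() and entry.name not in disabled
    )


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "skills"
    for name in ("github", "weather", "notes"):
        (root / name).mkdir(parents=True)
    print(visible_skills(root, disabled={"github", "weather"}))  # ['notes']
```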
|
||||
|
||||
## 🧩 Multiple Instances
|
||||
|
||||
Run multiple nanobot instances simultaneously with separate configs and runtime data. Use `--config` as the main entrypoint. Optionally pass `--workspace` during `onboard` when you want to initialize or update the saved workspace for a specific instance.
|
||||
@ -1603,6 +1727,7 @@ Example config:
|
||||
}
|
||||
},
|
||||
"gateway": {
|
||||
"host": "127.0.0.1",
|
||||
"port": 18790
|
||||
}
|
||||
}
|
||||
@ -1615,6 +1740,14 @@ nanobot gateway --config ~/.nanobot-telegram/config.json
|
||||
nanobot gateway --config ~/.nanobot-discord/config.json
|
||||
```
|
||||
|
||||
Each gateway instance also exposes a lightweight HTTP health endpoint on
|
||||
`gateway.host:gateway.port`. By default, the gateway binds to `127.0.0.1`,
|
||||
so the endpoint stays local unless you explicitly set `gateway.host` to a
|
||||
public or LAN-facing address.
|
||||
|
||||
- `GET /health` returns `{"status":"ok"}`
|
||||
- Other paths return `404`
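
A health probe against this endpoint needs nothing beyond the standard library. The sketch below assumes the gateway from the example config (`127.0.0.1:18790`); `probe` and `is_healthy_response` are hypothetical helper names chosen for illustration.

```python
import json
import urllib.error
import urllib.request


def is_healthy_response(status: int, body: bytes) -> bool:
    """True only for HTTP 200 with a JSON body of {"status": "ok"}."""
    if status != 200:
        return False
    try:
        payload = json.loads(body)
    except ValueError:
        return False
    return isinstance(payload, dict) and payload.get("status") == "ok"


def probe(base_url: str = "http://127.0.0.1:18790", timeout: float = 2.0) -> bool:
    """Call GET /health; any network or HTTP error counts as unhealthy."""
    try:
        with urllib.request.urlopen(f"{base_url}/health", timeout=timeout) as resp:
            return is_healthy_response(resp.status, resp.read())
    except (urllib.error.URLError, OSError):
        return False
```

Calling `probe()` returns `False` when no gateway is listening, so it can be used directly in a readiness loop.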
|
||||
|
||||
Override workspace for one-off runs when needed:
|
||||
|
||||
```bash
|
||||
@ -1642,6 +1775,7 @@ time.
|
||||
|
||||
- `memory/history.jsonl` stores append-only summarized history
|
||||
- `SOUL.md`, `USER.md`, and `memory/MEMORY.md` store long-term knowledge managed by Dream
|
||||
- `Dream` can also promote repeated workflows into reusable workspace skills under `skills/`
|
||||
- `Dream` runs on a schedule and can also be triggered manually
|
||||
- memory changes can be inspected and restored with built-in commands
|
||||
|
||||
@ -1758,6 +1892,19 @@ By default, the API binds to `127.0.0.1:8900`. You can change this in `config.js
|
||||
- Fixed model: omit `model`, or pass the same model shown by `/v1/models`
|
||||
- No streaming: `stream=true` is not supported
|
||||
- **File uploads**: supports images, PDF, Word (.docx), Excel (.xlsx), PowerPoint (.pptx) via JSON base64 or `multipart/form-data` (max 10MB per file)
|
||||
- API requests run in the synthetic `api` channel, so the `message` tool does **not** automatically deliver to Telegram/Discord/etc. To proactively send to another chat, call `message` with an explicit `channel` and `chat_id` for an enabled channel.
|
||||
|
||||
Example tool call for cross-channel delivery from an API session:
|
||||
|
||||
```json
|
||||
{
|
||||
"content": "Build finished successfully.",
|
||||
"channel": "telegram",
|
||||
"chat_id": "123456789"
|
||||
}
|
||||
```
|
||||
|
||||
If `channel` points to a channel that is not enabled in your config, nanobot will queue the outbound event but no platform delivery will occur.
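
For the per-file upload limit listed above, a client can check sizes before sending base64 data, and a server can reject oversized data before decoding it. This is a sketch under stated assumptions: "10MB" is read as 10 MiB, the helper names are hypothetical, and the request schema itself is not shown because it is not described here.

```python
import base64
import binascii

MAX_FILE_BYTES = 10 * 1024 * 1024  # assumption: "10MB" is read as 10 MiB


def encode_for_upload(data: bytes, max_bytes: int = MAX_FILE_BYTES) -> str:
    """Base64-encode a file, refusing anything over the per-file limit."""
    if len(data) > max_bytes:
        raise ValueError(f"file is {len(data)} bytes; limit is {max_bytes}")
    return base64.b64encode(data).decode("ascii")


def decode_upload(b64_text: str, max_bytes: int = MAX_FILE_BYTES) -> bytes:
    """Decode a base64 file, enforcing the limit before and after decoding."""
    # Every 3 raw bytes become 4 base64 characters, so this bounds the input.
    max_chars = 4 * ((max_bytes + 2) // 3)
    if len(b64_text) > max_chars:
        raise ValueError("encoded file exceeds the per-file limit")
    try:
        data = base64.b64decode(b64_text, validate=True)
    except binascii.Error as exc:
        raise ValueError("invalid base64 payload") from exc
    if len(data) > max_bytes:
        raise ValueError("decoded file exceeds the per-file limit")
    return data


print(decode_upload(encode_for_upload(b"hello")))  # b'hello'
```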
|
||||
|
||||
### Endpoints
|
||||
|
||||
|
||||
@ -290,7 +290,6 @@ async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] |
|
||||
|------|---------|
|
||||
| `_stream_delta: True` | A content chunk (delta contains the new text) |
|
||||
| `_stream_end: True` | Streaming finished (delta is empty) |
|
||||
| `_resuming: True` | More streaming rounds coming (e.g. tool call then another response) |
|
||||
|
||||
### Example: Webhook with Streaming
|
||||
|
||||
|
||||
331
docs/WEBSOCKET.md
Normal file
@ -0,0 +1,331 @@
|
||||
# WebSocket Server Channel
|
||||
|
||||
Nanobot can act as a WebSocket server, allowing external clients (web apps, CLIs, scripts) to interact with the agent in real time via persistent connections.
|
||||
|
||||
## Features
|
||||
|
||||
- Bidirectional real-time communication over WebSocket
|
||||
- Streaming support — receive agent responses token by token
|
||||
- Token-based authentication (static tokens and short-lived issued tokens)
|
||||
- Per-connection sessions — each connection gets a unique `chat_id`
|
||||
- TLS/SSL support (WSS) with enforced TLSv1.2 minimum
|
||||
- Client allow-list via `allowFrom`
|
||||
- Auto-cleanup of dead connections
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Configure
|
||||
|
||||
Add to `config.json` under `channels.websocket`:
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"websocket": {
|
||||
"enabled": true,
|
||||
"host": "127.0.0.1",
|
||||
"port": 8765,
|
||||
"path": "/",
|
||||
"websocketRequiresToken": false,
|
||||
"allowFrom": ["*"],
|
||||
"streaming": true
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Start nanobot
|
||||
|
||||
```bash
|
||||
nanobot gateway
|
||||
```
|
||||
|
||||
You should see:
|
||||
|
||||
```
|
||||
WebSocket server listening on ws://127.0.0.1:8765/
|
||||
```
|
||||
|
||||
### 3. Connect a client
|
||||
|
||||
```bash
# Using websocat (quote the URL so the shell does not interpret "?")
websocat "ws://127.0.0.1:8765/?client_id=alice"
```

```python
# Using Python
import asyncio
import json

import websockets


async def main():
    async with websockets.connect("ws://127.0.0.1:8765/?client_id=alice") as ws:
        ready = json.loads(await ws.recv())
        print(ready)  # {"event": "ready", "chat_id": "...", "client_id": "alice"}
        await ws.send(json.dumps({"content": "Hello nanobot!"}))
        # With streaming enabled, this first frame may be a single `delta`
        # chunk rather than the full reply.
        reply = json.loads(await ws.recv())
        print(reply["text"])


asyncio.run(main())
```
|
||||
|
||||
## Connection URL
|
||||
|
||||
```
|
||||
ws://{host}:{port}{path}?client_id={id}&token={token}
|
||||
```
|
||||
|
||||
| Parameter | Required | Description |
|-----------|----------|-------------|
| `client_id` | No | Identifier for `allowFrom` authorization. Auto-generated as `anon-xxxxxxxxxxxx` if omitted. Truncated to 128 chars. |
| `token` | Conditional | Authentication token. Required when `websocketRequiresToken` is `true` or `token` (static secret) is configured. |
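
Building the connection URL by hand is error-prone once tokens or client IDs contain special characters. A small helper such as the one below (a hypothetical `build_ws_url`, not part of nanobot) lets `urllib.parse.urlencode` handle the escaping.

```python
from typing import Optional
from urllib.parse import urlencode


def build_ws_url(host: str, port: int, path: str = "/",
                 client_id: Optional[str] = None, token: Optional[str] = None,
                 secure: bool = False) -> str:
    """Assemble ws://{host}:{port}{path}?client_id=...&token=..."""
    scheme = "wss" if secure else "ws"
    params = [("client_id", client_id), ("token", token)]
    query = urlencode([(key, value) for key, value in params if value])
    url = f"{scheme}://{host}:{port}{path}"
    return f"{url}?{query}" if query else url


print(build_ws_url("127.0.0.1", 8765, "/ws", client_id="alice", token="nbwt_example"))
# ws://127.0.0.1:8765/ws?client_id=alice&token=nbwt_example
```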
|
||||
|
||||
## Wire Protocol
|
||||
|
||||
All frames are JSON text. Each message has an `event` field.
|
||||
|
||||
### Server → Client
|
||||
|
||||
**`ready`** — sent immediately after connection is established:
|
||||
|
||||
```json
|
||||
{
|
||||
"event": "ready",
|
||||
"chat_id": "uuid-v4",
|
||||
"client_id": "alice"
|
||||
}
|
||||
```
|
||||
|
||||
**`message`** — full agent response:
|
||||
|
||||
```json
|
||||
{
|
||||
"event": "message",
|
||||
"text": "Hello! How can I help?",
|
||||
"media": ["/tmp/image.png"],
|
||||
"reply_to": "msg-id"
|
||||
}
|
||||
```
|
||||
|
||||
`media` and `reply_to` are only present when applicable.
|
||||
|
||||
**`delta`** — streaming text chunk (only when `streaming: true`):
|
||||
|
||||
```json
|
||||
{
|
||||
"event": "delta",
|
||||
"text": "Hello",
|
||||
"stream_id": "s1"
|
||||
}
|
||||
```
|
||||
|
||||
**`stream_end`** — signals the end of a streaming segment:
|
||||
|
||||
```json
|
||||
{
|
||||
"event": "stream_end",
|
||||
"stream_id": "s1"
|
||||
}
|
||||
```
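
On the client side, `delta` frames have to be stitched back together until the matching `stream_end` arrives. The sketch below shows one way to do that, keyed by `stream_id`; `StreamAssembler` is a hypothetical class written for illustration and relies only on the frame shapes shown above.

```python
import json
from collections import defaultdict
from typing import Optional


class StreamAssembler:
    """Collect delta chunks per stream_id and emit the text on stream_end."""

    def __init__(self) -> None:
        self._buffers = defaultdict(list)

    def feed(self, frame: str) -> Optional[str]:
        """Process one JSON frame; return completed text, else None."""
        msg = json.loads(frame)
        event = msg.get("event")
        if event == "delta":
            self._buffers[msg.get("stream_id")].append(msg.get("text", ""))
            return None
        if event == "stream_end":
            return "".join(self._buffers.pop(msg.get("stream_id"), []))
        if event == "message":
            return msg.get("text", "")
        return None  # e.g. "ready" carries no reply text


assembler = StreamAssembler()
for raw in (
    '{"event": "delta", "text": "Hel", "stream_id": "s1"}',
    '{"event": "delta", "text": "lo", "stream_id": "s1"}',
    '{"event": "stream_end", "stream_id": "s1"}',
):
    done = assembler.feed(raw)
    if done is not None:
        print(done)  # Hello
```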
|
||||
|
||||
### Client → Server
|
||||
|
||||
Send plain text:
|
||||
|
||||
```json
|
||||
"Hello nanobot!"
|
||||
```
|
||||
|
||||
Or send a JSON object with a recognized text field:
|
||||
|
||||
```json
|
||||
{"content": "Hello nanobot!"}
|
||||
```
|
||||
|
||||
Recognized fields: `content`, `text`, `message` (checked in that order). Invalid JSON is treated as plain text.
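
The lookup order can be sketched as follows. This is an illustration rather than the server's code: `extract_text` is a hypothetical name, and the handling of a JSON string literal and of other non-object JSON values is an assumption that goes beyond what is stated above.

```python
import json
from typing import Optional

TEXT_FIELDS = ("content", "text", "message")  # checked in this order


def extract_text(raw: str) -> Optional[str]:
    """Return the user text carried by an inbound frame, if any."""
    try:
        data = json.loads(raw)
    except ValueError:
        return raw  # invalid JSON is treated as plain text
    if isinstance(data, dict):
        for field in TEXT_FIELDS:
            value = data.get(field)
            if isinstance(value, str):
                return value
        return None  # JSON object without a recognized text field
    if isinstance(data, str):
        return data  # assumption: a JSON string literal is unwrapped
    return raw  # assumption: other JSON values fall back to the raw text


print(extract_text('{"text": "second", "content": "first"}'))  # first
print(extract_text("Hello nanobot!"))  # Hello nanobot!
```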
|
||||
|
||||
## Configuration Reference
|
||||
|
||||
All fields go under `channels.websocket` in `config.json`.
|
||||
|
||||
### Connection
|
||||
|
||||
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `enabled` | bool | `false` | Enable the WebSocket server. |
| `host` | string | `"127.0.0.1"` | Bind address. Use `"0.0.0.0"` to accept external connections. |
| `port` | int | `8765` | Listen port. |
| `path` | string | `"/"` | WebSocket upgrade path. Trailing slashes are normalized (root `/` is preserved). |
| `maxMessageBytes` | int | `1048576` | Maximum inbound message size in bytes (1 KB – 16 MB). |
|
||||
|
||||
### Authentication
|
||||
|
||||
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `token` | string | `""` | Static shared secret. When set, clients must provide `?token=<value>` matching this secret (timing-safe comparison). Issued tokens are also accepted as a fallback. |
| `websocketRequiresToken` | bool | `true` | When `true` and no static `token` is configured, clients must still present a valid issued token. Set to `false` to allow unauthenticated connections (only safe for local/trusted networks). |
| `tokenIssuePath` | string | `""` | HTTP path for issuing short-lived tokens. Must differ from `path`. See [Token Issuance](#token-issuance). |
| `tokenIssueSecret` | string | `""` | Secret required to obtain tokens via the issue endpoint. If empty, any client can obtain tokens (logged as a warning). |
| `tokenTtlS` | int | `300` | Time-to-live for issued tokens in seconds (30 – 86,400). |
|
||||
|
||||
### Access Control
|
||||
|
||||
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `allowFrom` | list of string | `["*"]` | Allowed `client_id` values. `"*"` allows all; `[]` denies all. |
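
The allow-list semantics fit in one expression. The function below is a hypothetical sketch of the rule as described in the table, not nanobot's own check.

```python
def is_allowed(client_id: str, allow_from: list) -> bool:
    """'*' admits everyone; an empty list admits no one."""
    return "*" in allow_from or client_id in allow_from


print(is_allowed("alice", ["*"]))           # True
print(is_allowed("alice", ["alice", "bob"]))  # True
print(is_allowed("mallory", ["alice"]))     # False
print(is_allowed("alice", []))              # False
```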
|
||||
|
||||
### Streaming
|
||||
|
||||
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `streaming` | bool | `true` | Enable streaming mode. The agent sends `delta` + `stream_end` frames instead of a single `message`. |
|
||||
|
||||
### Keep-alive
|
||||
|
||||
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `pingIntervalS` | float | `20.0` | WebSocket ping interval in seconds (5 – 300). |
| `pingTimeoutS` | float | `20.0` | Time to wait for a pong before closing the connection (5 – 300). |
|
||||
|
||||
### TLS/SSL
|
||||
|
||||
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `sslCertfile` | string | `""` | Path to the TLS certificate file (PEM). Both `sslCertfile` and `sslKeyfile` must be set to enable WSS. |
| `sslKeyfile` | string | `""` | Path to the TLS private key file (PEM). Minimum TLS version is enforced as TLSv1.2. |
|
||||
|
||||
## Token Issuance
|
||||
|
||||
For production deployments where `websocketRequiresToken: true`, use short-lived tokens instead of embedding static secrets in clients.
|
||||
|
||||
### How it works
|
||||
|
||||
1. Client sends `GET {tokenIssuePath}` with `Authorization: Bearer {tokenIssueSecret}` (or `X-Nanobot-Auth` header).
2. Server responds with a one-time-use token:

   ```json
   {"token": "nbwt_aBcDeFg...", "expires_in": 300}
   ```

3. Client opens WebSocket with `?token=nbwt_aBcDeFg...&client_id=...`.
4. The token is consumed (single use) and cannot be reused.
|
||||
|
||||
### Example setup
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"websocket": {
|
||||
"enabled": true,
|
||||
"port": 8765,
|
||||
"path": "/ws",
|
||||
"tokenIssuePath": "/auth/token",
|
||||
"tokenIssueSecret": "your-secret-here",
|
||||
"tokenTtlS": 300,
|
||||
"websocketRequiresToken": true,
|
||||
"allowFrom": ["*"],
|
||||
"streaming": true
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Client flow:
|
||||
|
||||
```bash
|
||||
# 1. Obtain a token
|
||||
curl -H "Authorization: Bearer your-secret-here" http://127.0.0.1:8765/auth/token
|
||||
|
||||
# 2. Connect using the token
|
||||
websocat "ws://127.0.0.1:8765/ws?client_id=alice&token=nbwt_aBcDeFg..."
|
||||
```
|
||||
|
||||
### Limits
|
||||
|
||||
- Issued tokens are single-use — each token can only complete one handshake.
|
||||
- Outstanding tokens are capped at 10,000. Requests beyond this return HTTP 429.
|
||||
- Expired tokens are purged lazily on each issue or validation request.
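
Taken together, these limits describe a small in-memory token table. The sketch below shows one possible shape for it, assuming a plain dictionary keyed by token; `IssuedTokens` and its methods are hypothetical names, and only the `nbwt_` prefix, the default TTL, the cap, and the lazy purge are taken from the text above.

```python
import secrets
import time
from typing import Optional


class IssuedTokens:
    """Single-use tokens with a TTL, an outstanding cap, and lazy purging."""

    def __init__(self, ttl_s: float = 300, max_outstanding: int = 10_000) -> None:
        self._ttl_s = ttl_s
        self._max = max_outstanding
        self._expiry = {}  # token -> expiry timestamp

    def _purge(self, now: float) -> None:
        for token in [t for t, exp in self._expiry.items() if exp <= now]:
            del self._expiry[token]

    def issue(self, now: Optional[float] = None) -> Optional[str]:
        """Return a new token, or None when the cap is hit (HTTP 429)."""
        now = time.monotonic() if now is None else now
        self._purge(now)
        if len(self._expiry) >= self._max:
            return None
        token = "nbwt_" + secrets.token_urlsafe(32)
        self._expiry[token] = now + self._ttl_s
        return token

    def consume(self, token: str, now: Optional[float] = None) -> bool:
        """Validate and remove a token; a second attempt always fails."""
        now = time.monotonic() if now is None else now
        self._purge(now)
        return self._expiry.pop(token, None) is not None


store = IssuedTokens()
tok = store.issue()
print(store.consume(tok), store.consume(tok))  # True False
```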
|
||||
|
||||
## Security Notes
|
||||
|
||||
- **Timing-safe comparison**: Static token validation uses `hmac.compare_digest` to prevent timing attacks.
|
||||
- **Defense in depth**: `allowFrom` is checked at both the HTTP handshake level and the message level.
|
||||
- **Session isolation**: Each WebSocket connection gets a unique `chat_id`. Clients cannot access other sessions.
|
||||
- **TLS enforcement**: When SSL is enabled, TLSv1.2 is the minimum allowed version.
|
||||
- **Default-secure**: `websocketRequiresToken` defaults to `true`. Explicitly set it to `false` only on trusted networks.
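
For readers unfamiliar with timing-safe comparison, the idea is to compare secrets with `hmac.compare_digest` instead of `==`, so that the time taken does not depend on where the first mismatching character is. A minimal sketch, with `token_matches` as a hypothetical helper name:

```python
import hmac


def token_matches(presented: str, expected: str) -> bool:
    """Compare two secrets without leaking the mismatch position via timing."""
    if not expected:
        return False  # no static token configured, nothing can match
    # Encode to bytes so non-ASCII input is handled as well.
    return hmac.compare_digest(presented.encode("utf-8"), expected.encode("utf-8"))


print(token_matches("my-shared-secret", "my-shared-secret"))  # True
print(token_matches("guess", "my-shared-secret"))             # False
```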
|
||||
|
||||
## Media Files
|
||||
|
||||
Outbound `message` events may include a `media` field containing local filesystem paths. Remote clients cannot access these files directly — they need either:
|
||||
|
||||
- A shared filesystem mount, or
|
||||
- An HTTP file server serving the nanobot media directory
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Trusted local network (no auth)
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"websocket": {
|
||||
"enabled": true,
|
||||
"host": "0.0.0.0",
|
||||
"port": 8765,
|
||||
"websocketRequiresToken": false,
|
||||
"allowFrom": ["*"],
|
||||
"streaming": true
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Static token (simple auth)
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"websocket": {
|
||||
"enabled": true,
|
||||
"token": "my-shared-secret",
|
||||
"allowFrom": ["alice", "bob"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Clients connect with `?token=my-shared-secret&client_id=alice`.
|
||||
|
||||
### Public endpoint with issued tokens
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"websocket": {
|
||||
"enabled": true,
|
||||
"host": "0.0.0.0",
|
||||
"port": 8765,
|
||||
"path": "/ws",
|
||||
"tokenIssuePath": "/auth/token",
|
||||
"tokenIssueSecret": "production-secret",
|
||||
"websocketRequiresToken": true,
|
||||
"sslCertfile": "/etc/ssl/certs/server.pem",
|
||||
"sslKeyfile": "/etc/ssl/private/server-key.pem",
|
||||
"allowFrom": ["*"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Custom path
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"websocket": {
|
||||
"enabled": true,
|
||||
"path": "/chat/ws",
|
||||
"allowFrom": ["*"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Clients connect to `ws://127.0.0.1:8765/chat/ws?client_id=...`. Trailing slashes are normalized, so `/chat/ws/` works the same.
|
||||
@ -2,7 +2,29 @@
 nanobot - A lightweight AI agent framework
 """
 
-__version__ = "0.1.5"
+from importlib.metadata import PackageNotFoundError, version as _pkg_version
+from pathlib import Path
+import tomllib
+
+
+def _read_pyproject_version() -> str | None:
+    """Read the source-tree version when package metadata is unavailable."""
+    pyproject = Path(__file__).resolve().parent.parent / "pyproject.toml"
+    if not pyproject.exists():
+        return None
+    data = tomllib.loads(pyproject.read_text(encoding="utf-8"))
+    return data.get("project", {}).get("version")
+
+
+def _resolve_version() -> str:
+    try:
+        return _pkg_version("nanobot-ai")
+    except PackageNotFoundError:
+        # Source checkouts often import nanobot without installed dist-info.
+        return _read_pyproject_version() or "0.1.5"
+
+
+__version__ = _resolve_version()
 __logo__ = "🐈"
 
 from nanobot.nanobot import Nanobot, RunResult
|
||||
|
||||
123
nanobot/agent/autocompact.py
Normal file
@ -0,0 +1,123 @@
"""Auto compact: proactive compression of idle sessions to reduce token cost and latency."""

from __future__ import annotations

from collections.abc import Collection
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Coroutine

from loguru import logger
from nanobot.session.manager import Session, SessionManager

if TYPE_CHECKING:
    from nanobot.agent.memory import Consolidator


class AutoCompact:
    _RECENT_SUFFIX_MESSAGES = 8

    def __init__(self, sessions: SessionManager, consolidator: Consolidator,
                 session_ttl_minutes: int = 0):
        self.sessions = sessions
        self.consolidator = consolidator
        self._ttl = session_ttl_minutes
        self._archiving: set[str] = set()
        self._summaries: dict[str, tuple[str, datetime]] = {}

    def _is_expired(self, ts: datetime | str | None,
                    now: datetime | None = None) -> bool:
        if self._ttl <= 0 or not ts:
            return False
        if isinstance(ts, str):
            ts = datetime.fromisoformat(ts)
        return ((now or datetime.now()) - ts).total_seconds() >= self._ttl * 60

    @staticmethod
    def _format_summary(text: str, last_active: datetime) -> str:
        idle_min = int((datetime.now() - last_active).total_seconds() / 60)
        return f"Inactive for {idle_min} minutes.\nPrevious conversation summary: {text}"

    def _split_unconsolidated(
        self, session: Session,
    ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
        """Split live session tail into archiveable prefix and retained recent suffix."""
        tail = list(session.messages[session.last_consolidated:])
        if not tail:
            return [], []

        probe = Session(
            key=session.key,
            messages=tail.copy(),
            created_at=session.created_at,
            updated_at=session.updated_at,
            metadata={},
            last_consolidated=0,
        )
        probe.retain_recent_legal_suffix(self._RECENT_SUFFIX_MESSAGES)
        kept = probe.messages
        cut = len(tail) - len(kept)
        return tail[:cut], kept

    def check_expired(self, schedule_background: Callable[[Coroutine], None],
                      active_session_keys: Collection[str] = ()) -> None:
        """Schedule archival for idle sessions, skipping those with in-flight agent tasks."""
        now = datetime.now()
        for info in self.sessions.list_sessions():
            key = info.get("key", "")
            if not key or key in self._archiving:
                continue
            if key in active_session_keys:
                continue
            if self._is_expired(info.get("updated_at"), now):
                self._archiving.add(key)
                schedule_background(self._archive(key))

    async def _archive(self, key: str) -> None:
        try:
            self.sessions.invalidate(key)
            session = self.sessions.get_or_create(key)
            archive_msgs, kept_msgs = self._split_unconsolidated(session)
            if not archive_msgs and not kept_msgs:
                session.updated_at = datetime.now()
                self.sessions.save(session)
                return

            last_active = session.updated_at
            summary = ""
            if archive_msgs:
                summary = await self.consolidator.archive(archive_msgs) or ""
            if summary and summary != "(nothing)":
                self._summaries[key] = (summary, last_active)
                session.metadata["_last_summary"] = {"text": summary, "last_active": last_active.isoformat()}
            session.messages = kept_msgs
            session.last_consolidated = 0
            session.updated_at = datetime.now()
            self.sessions.save(session)
            if archive_msgs:
                logger.info(
                    "Auto-compact: archived {} (archived={}, kept={}, summary={})",
                    key,
                    len(archive_msgs),
                    len(kept_msgs),
                    bool(summary),
                )
        except Exception:
            logger.exception("Auto-compact: failed for {}", key)
        finally:
            self._archiving.discard(key)

    def prepare_session(self, session: Session, key: str) -> tuple[Session, str | None]:
        if key in self._archiving or self._is_expired(session.updated_at):
            logger.info("Auto-compact: reloading session {} (archiving={})", key, key in self._archiving)
            session = self.sessions.get_or_create(key)
        # Hot path: summary from in-memory dict (process hasn't restarted).
        # Also clean metadata copy so stale _last_summary never leaks to disk.
        entry = self._summaries.pop(key, None)
        if entry:
            session.metadata.pop("_last_summary", None)
            return session, self._format_summary(entry[0], entry[1])
        if "_last_summary" in session.metadata:
            meta = session.metadata.pop("_last_summary")
            self.sessions.save(session)
            return session, self._format_summary(meta["text"], datetime.fromisoformat(meta["last_active"]))
        return session, None
|
||||
@ -6,12 +6,10 @@ import platform
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.utils.helpers import current_time_str
|
||||
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.utils.prompt_templates import render_template
|
||||
from nanobot.agent.skills import SkillsLoader
|
||||
from nanobot.utils.helpers import build_assistant_message, detect_image_mime
|
||||
from nanobot.utils.helpers import build_assistant_message, current_time_str, detect_image_mime
|
||||
from nanobot.utils.prompt_templates import render_template
|
||||
|
||||
|
||||
class ContextBuilder:
|
||||
@ -20,12 +18,13 @@ class ContextBuilder:
|
||||
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md"]
|
||||
_RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]"
|
||||
_MAX_RECENT_HISTORY = 50
|
||||
_RUNTIME_CONTEXT_END = "[/Runtime Context]"
|
||||
|
||||
def __init__(self, workspace: Path, timezone: str | None = None):
|
||||
def __init__(self, workspace: Path, timezone: str | None = None, disabled_skills: list[str] | None = None):
|
||||
self.workspace = workspace
|
||||
self.timezone = timezone
|
||||
self.memory = MemoryStore(workspace)
|
||||
self.skills = SkillsLoader(workspace)
|
||||
self.skills = SkillsLoader(workspace, disabled_skills=set(disabled_skills) if disabled_skills else None)
|
||||
|
||||
def build_system_prompt(
|
||||
self,
|
||||
@ -79,12 +78,15 @@ class ContextBuilder:
|
||||
@staticmethod
|
||||
def _build_runtime_context(
|
||||
channel: str | None, chat_id: str | None, timezone: str | None = None,
|
||||
session_summary: str | None = None,
|
||||
) -> str:
|
||||
"""Build untrusted runtime metadata block for injection before the user message."""
|
||||
lines = [f"Current Time: {current_time_str(timezone)}"]
|
||||
if channel and chat_id:
|
||||
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
|
||||
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines)
|
||||
if session_summary:
|
||||
lines += ["", "[Resumed Session]", session_summary]
|
||||
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines) + "\n" + ContextBuilder._RUNTIME_CONTEXT_END
|
||||
|
||||
@staticmethod
|
||||
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
|
||||
@ -121,9 +123,10 @@ class ContextBuilder:
|
||||
channel: str | None = None,
|
||||
chat_id: str | None = None,
|
||||
current_role: str = "user",
|
||||
session_summary: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build the complete message list for an LLM call."""
|
||||
runtime_ctx = self._build_runtime_context(channel, chat_id, self.timezone)
|
||||
runtime_ctx = self._build_runtime_context(channel, chat_id, self.timezone, session_summary=session_summary)
|
||||
user_content = self._build_user_content(current_message, media)
|
||||
|
||||
# Merge runtime context and user content into a single user message
|
||||
@ -176,7 +179,7 @@ class ContextBuilder:
|
||||
# Try document text extraction
|
||||
from nanobot.utils.document import extract_text
|
||||
extracted = extract_text(p)
|
||||
if extracted and not extracted.startswith("Error"):
|
||||
if extracted and not extracted.startswith("[error:"):
|
||||
doc_texts.append(f"[File: {p.name}]\n{extracted}")
|
||||
|
||||
# Build final content
|
||||
|
||||
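The `_build_runtime_context` change above closes the metadata block with an end tag and optionally appends a resumed-session summary. A minimal standalone sketch of the resulting shape follows. The function name and the `now` parameter are hypothetical stand-ins (the real method calls `current_time_str(timezone)`); the tag strings are copied from the diff.

```python
from __future__ import annotations

# Tag strings copied from the diff above.
RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]"
RUNTIME_CONTEXT_END = "[/Runtime Context]"


def build_runtime_context(
    now: str,  # hypothetical: stands in for current_time_str(timezone)
    channel: str | None = None,
    chat_id: str | None = None,
    session_summary: str | None = None,
) -> str:
    """Return the tagged metadata block that precedes the user message."""
    lines = [f"Current Time: {now}"]
    if channel and chat_id:
        lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
    if session_summary:
        # A blank line separates the metadata from the resumed-session summary.
        lines += ["", "[Resumed Session]", session_summary]
    return RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines) + "\n" + RUNTIME_CONTEXT_END


block = build_runtime_context("2026-01-01 12:00", "cli", "direct", "Earlier we discussed X.")
minimal = build_runtime_context("t")
```

The explicit end tag presumably makes the block easier to strip or recognize later, since its extent no longer depends on where the user content begins.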
@ -29,6 +29,9 @@ class AgentHookContext:
|
||||
class AgentHook:
|
||||
"""Minimal lifecycle surface for shared runner customization."""
|
||||
|
||||
def __init__(self, reraise: bool = False) -> None:
|
||||
self._reraise = reraise
|
||||
|
||||
def wants_streaming(self) -> bool:
|
||||
return False
|
||||
|
||||
@ -62,6 +65,7 @@ class CompositeHook(AgentHook):
|
||||
__slots__ = ("_hooks",)
|
||||
|
||||
def __init__(self, hooks: list[AgentHook]) -> None:
|
||||
super().__init__()
|
||||
self._hooks = list(hooks)
|
||||
|
||||
def wants_streaming(self) -> bool:
|
||||
@ -69,6 +73,10 @@ class CompositeHook(AgentHook):
|
||||
|
||||
async def _for_each_hook_safe(self, method_name: str, *args: Any, **kwargs: Any) -> None:
|
||||
for h in self._hooks:
|
||||
if getattr(h, "_reraise", False):
|
||||
await getattr(h, method_name)(*args, **kwargs)
|
||||
continue
|
||||
|
||||
try:
|
||||
await getattr(h, method_name)(*args, **kwargs)
|
||||
except Exception:
|
||||
|
||||
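The `reraise` flag added above lets a hook opt out of the composite's error isolation: hooks with `_reraise` set propagate exceptions, while all others are guarded by `try/except`. A self-contained sketch of that behavior follows. The class names `Hook`, `FailingHook` and `Composite` are hypothetical, and the logging call in the `except` branch is an assumption, since the hunk is cut off at that point.

```python
from __future__ import annotations

import asyncio
import logging
from typing import Any

log = logging.getLogger("hook-sketch")


class Hook:
    """Stand-in for AgentHook with a single lifecycle method."""

    def __init__(self, reraise: bool = False) -> None:
        self._reraise = reraise

    async def before_iteration(self, context: dict) -> None:
        pass


class FailingHook(Hook):
    async def before_iteration(self, context: dict) -> None:
        raise ValueError("hook failed")


class Composite(Hook):
    def __init__(self, hooks: list[Hook]) -> None:
        super().__init__()
        self._hooks = list(hooks)

    async def _for_each_hook_safe(self, method_name: str, *args: Any, **kwargs: Any) -> None:
        for h in self._hooks:
            if getattr(h, "_reraise", False):
                # Core hooks: let errors propagate to the caller.
                await getattr(h, method_name)(*args, **kwargs)
                continue
            try:
                await getattr(h, method_name)(*args, **kwargs)
            except Exception:
                # Assumption: the real code logs here; the diff does not show it.
                log.exception("hook %s failed in %s", type(h).__name__, method_name)

    async def before_iteration(self, context: dict) -> None:
        await self._for_each_hook_safe("before_iteration", context)


async def demo() -> tuple[bool, bool]:
    swallowed = True
    try:
        await Composite([FailingHook()]).before_iteration({})
    except ValueError:
        swallowed = False
    propagated = False
    try:
        await Composite([FailingHook(reraise=True)]).before_iteration({})
    except ValueError:
        propagated = True
    return swallowed, propagated


swallowed, propagated = asyncio.run(demo())
```

This appears to be what allows `_LoopHook` (constructed with `reraise=True`) to sit in a plain `CompositeHook` alongside extra hooks, replacing the removed `_LoopHookChain`.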
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
@ -12,27 +13,30 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.autocompact import AutoCompact
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
||||
from nanobot.agent.memory import Consolidator, Dream
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec
|
||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.agent.tools.notebook import NotebookEditTool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.search import GlobTool, GrepTool
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
from nanobot.agent.tools.spawn import SpawnTool
|
||||
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
from nanobot.utils.helpers import image_placeholder_text, truncate_text
|
||||
from nanobot.utils.helpers import image_placeholder_text
|
||||
from nanobot.utils.helpers import truncate_text as truncate_text_fn
|
||||
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -40,6 +44,9 @@ if TYPE_CHECKING:
|
||||
from nanobot.cron.service import CronService
|
||||
|
||||
|
||||
UNIFIED_SESSION_KEY = "unified:default"
|
||||
|
||||
|
||||
class _LoopHook(AgentHook):
|
||||
"""Core hook for the main loop."""
|
||||
|
||||
@ -54,6 +61,7 @@ class _LoopHook(AgentHook):
|
||||
chat_id: str = "direct",
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(reraise=True)
|
||||
self._loop = agent_loop
|
||||
self._on_progress = on_progress
|
||||
self._on_stream = on_stream
|
||||
@ -72,7 +80,7 @@ class _LoopHook(AgentHook):
|
||||
prev_clean = strip_think(self._stream_buf)
|
||||
self._stream_buf += delta
|
||||
new_clean = strip_think(self._stream_buf)
|
||||
incremental = new_clean[len(prev_clean):]
|
||||
incremental = new_clean[len(prev_clean) :]
|
||||
if incremental and self._on_stream:
|
||||
await self._on_stream(incremental)
|
||||
|
||||
@ -109,43 +117,6 @@ class _LoopHook(AgentHook):
|
||||
return self._loop._strip_think(content)
|
||||
|
||||
|
||||
class _LoopHookChain(AgentHook):
|
||||
"""Run the core hook before extra hooks."""
|
||||
|
||||
__slots__ = ("_primary", "_extras")
|
||||
|
||||
def __init__(self, primary: AgentHook, extra_hooks: list[AgentHook]) -> None:
|
||||
self._primary = primary
|
||||
self._extras = CompositeHook(extra_hooks)
|
||||
|
||||
def wants_streaming(self) -> bool:
|
||||
return self._primary.wants_streaming() or self._extras.wants_streaming()
|
||||
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
await self._primary.before_iteration(context)
|
||||
await self._extras.before_iteration(context)
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
await self._primary.on_stream(context, delta)
|
||||
await self._extras.on_stream(context, delta)
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
await self._primary.on_stream_end(context, resuming=resuming)
|
||||
await self._extras.on_stream_end(context, resuming=resuming)
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
await self._primary.before_execute_tools(context)
|
||||
await self._extras.before_execute_tools(context)
|
||||
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
||||
await self._primary.after_iteration(context)
|
||||
await self._extras.after_iteration(context)
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
content = self._primary.finalize_content(context, content)
|
||||
return self._extras.finalize_content(context, content)
|
||||
|
||||
|
||||
class AgentLoop:
|
||||
"""
|
||||
The agent loop is the core processing engine.
|
||||
@ -159,6 +130,7 @@ class AgentLoop:
|
||||
"""
|
||||
|
||||
_RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
|
||||
_PENDING_USER_TURN_KEY = "pending_user_turn"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -179,7 +151,10 @@ class AgentLoop:
|
||||
mcp_servers: dict | None = None,
|
||||
channels_config: ChannelsConfig | None = None,
|
||||
timezone: str | None = None,
|
||||
session_ttl_minutes: int = 0,
|
||||
hooks: list[AgentHook] | None = None,
|
||||
unified_session: bool = False,
|
||||
disabled_skills: list[str] | None = None,
|
||||
):
|
||||
from nanobot.config.schema import ExecToolConfig, WebToolsConfig
|
||||
|
||||
@ -212,7 +187,7 @@ class AgentLoop:
|
||||
self._last_usage: dict[str, int] = {}
|
||||
self._extra_hooks: list[AgentHook] = hooks or []
|
||||
|
||||
self.context = ContextBuilder(workspace, timezone=timezone)
|
||||
self.context = ContextBuilder(workspace, timezone=timezone, disabled_skills=disabled_skills)
|
||||
self.sessions = session_manager or SessionManager(workspace)
|
||||
self.tools = ToolRegistry()
|
||||
self.runner = AgentRunner(provider)
|
||||
@ -225,16 +200,21 @@ class AgentLoop:
|
||||
max_tool_result_chars=self.max_tool_result_chars,
|
||||
exec_config=self.exec_config,
|
||||
restrict_to_workspace=restrict_to_workspace,
|
||||
disabled_skills=disabled_skills,
|
||||
)
|
||||
|
||||
self._unified_session = unified_session
|
||||
self._running = False
|
||||
self._mcp_servers = mcp_servers or {}
|
||||
self._mcp_stack: AsyncExitStack | None = None
|
||||
self._mcp_stacks: dict[str, AsyncExitStack] = {}
|
||||
self._mcp_connected = False
|
||||
self._mcp_connecting = False
|
||||
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||
self._background_tasks: list[asyncio.Task] = []
|
||||
self._session_locks: dict[str, asyncio.Lock] = {}
|
||||
# Per-session pending queues for mid-turn message injection.
|
||||
# When a session has an active task, new messages for that session
|
||||
# are routed here instead of creating a new task.
|
||||
self._pending_queues: dict[str, asyncio.Queue] = {}
|
||||
# NANOBOT_MAX_CONCURRENT_REQUESTS: <=0 means unlimited; default 3.
|
||||
_max = int(os.environ.get("NANOBOT_MAX_CONCURRENT_REQUESTS", "3"))
|
||||
self._concurrency_gate: asyncio.Semaphore | None = (
|
||||
@ -250,6 +230,11 @@ class AgentLoop:
|
||||
get_tool_definitions=self.tools.get_definitions,
|
||||
max_completion_tokens=provider.generation.max_tokens,
|
||||
)
|
||||
self.auto_compact = AutoCompact(
|
||||
sessions=self.sessions,
|
||||
consolidator=self.consolidator,
|
||||
session_ttl_minutes=session_ttl_minutes,
|
||||
)
|
||||
self.dream = Dream(
|
||||
store=self.context.memory,
|
||||
provider=provider,
|
||||
@ -261,23 +246,35 @@ class AgentLoop:
|
||||
|
||||
def _register_default_tools(self) -> None:
|
||||
"""Register the default set of tools."""
|
||||
allowed_dir = self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
|
||||
allowed_dir = (
|
||||
self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
|
||||
)
|
||||
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
|
||||
self.tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read))
|
||||
self.tools.register(
|
||||
ReadFileTool(
|
||||
workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read
|
||||
)
|
||||
)
|
||||
for cls in (WriteFileTool, EditFileTool, ListDirTool):
|
||||
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
for cls in (GlobTool, GrepTool):
|
||||
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
self.tools.register(NotebookEditTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
if self.exec_config.enable:
|
||||
self.tools.register(ExecTool(
|
||||
working_dir=str(self.workspace),
|
||||
timeout=self.exec_config.timeout,
|
||||
restrict_to_workspace=self.restrict_to_workspace,
|
||||
sandbox=self.exec_config.sandbox,
|
||||
path_append=self.exec_config.path_append,
|
||||
))
|
||||
self.tools.register(
|
||||
ExecTool(
|
||||
working_dir=str(self.workspace),
|
||||
timeout=self.exec_config.timeout,
|
||||
restrict_to_workspace=self.restrict_to_workspace,
|
||||
sandbox=self.exec_config.sandbox,
|
||||
path_append=self.exec_config.path_append,
|
||||
allowed_env_keys=self.exec_config.allowed_env_keys,
|
||||
)
|
||||
)
|
||||
if self.web_config.enable:
|
||||
self.tools.register(WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy))
|
||||
self.tools.register(
|
||||
WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy)
|
||||
)
|
||||
self.tools.register(WebFetchTool(proxy=self.web_config.proxy))
|
||||
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
|
||||
self.tools.register(SpawnTool(manager=self.subagents))
|
||||
@ -292,19 +289,19 @@ class AgentLoop:
|
||||
return
|
||||
self._mcp_connecting = True
|
||||
from nanobot.agent.tools.mcp import connect_mcp_servers
|
||||
|
||||
try:
|
||||
self._mcp_stack = AsyncExitStack()
|
||||
await self._mcp_stack.__aenter__()
|
||||
await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
|
||||
self._mcp_connected = True
|
||||
self._mcp_stacks = await connect_mcp_servers(self._mcp_servers, self.tools)
|
||||
if self._mcp_stacks:
|
||||
self._mcp_connected = True
|
||||
else:
|
||||
logger.warning("No MCP servers connected successfully (will retry next message)")
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("MCP connection cancelled (will retry next message)")
|
||||
self._mcp_stacks.clear()
|
||||
except BaseException as e:
|
||||
logger.error("Failed to connect MCP servers (will retry next message): {}", e)
|
||||
if self._mcp_stack:
|
||||
try:
|
||||
await self._mcp_stack.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
self._mcp_stack = None
|
||||
self._mcp_stacks.clear()
|
||||
finally:
|
||||
self._mcp_connecting = False
|
||||
|
||||
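The hunk above replaces the single shared `AsyncExitStack` with one stack per MCP server, so a failed server no longer tears down the others. A rough sketch of that pattern follows, using a fake server context manager. `fake_server` and `connect_all` are hypothetical; the real `connect_mcp_servers` takes the server config and tool registry, and its internals are not shown in this diff.

```python
from __future__ import annotations

import asyncio
from contextlib import AsyncExitStack, asynccontextmanager

closed: list[str] = []  # records which fake servers were cleaned up


@asynccontextmanager
async def fake_server(name: str, fail: bool):
    """Hypothetical stand-in for an MCP client session."""
    if fail:
        raise ConnectionError(name)
    try:
        yield name
    finally:
        closed.append(name)


async def connect_all(servers: dict[str, bool]) -> dict[str, AsyncExitStack]:
    """Return one exit stack per server that connected successfully."""
    stacks: dict[str, AsyncExitStack] = {}
    for name, fail in servers.items():
        stack = AsyncExitStack()
        try:
            await stack.enter_async_context(fake_server(name, fail))
        except ConnectionError:
            await stack.aclose()  # nothing registered, but close for symmetry
            continue
        stacks[name] = stack
    return stacks


async def demo() -> list[str]:
    stacks = await connect_all({"a": False, "b": True})
    connected = sorted(stacks)
    # Mirrors close_mcp: close each stack independently, then clear the dict.
    for _name, stack in stacks.items():
        await stack.aclose()
    stacks.clear()
    return connected


connected = asyncio.run(demo())
```

An empty result maps to the "No MCP servers connected successfully" warning in the diff, which leaves `_mcp_connected` false so the next message retries.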
@ -321,6 +318,7 @@ class AgentLoop:
|
||||
if not text:
|
||||
return None
|
||||
from nanobot.utils.helpers import strip_think
|
||||
|
||||
return strip_think(text) or None
|
||||
|
||||
@staticmethod
|
||||
@ -330,6 +328,12 @@ class AgentLoop:
|
||||
|
||||
return format_tool_hints(tool_calls)
|
||||
|
||||
def _effective_session_key(self, msg: InboundMessage) -> str:
|
||||
"""Return the session key used for task routing and mid-turn injections."""
|
||||
if self._unified_session and not msg.session_key_override:
|
||||
return UNIFIED_SESSION_KEY
|
||||
return msg.session_key
|
||||
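To make the routing rule in `_effective_session_key` concrete, here is a small sketch with a hypothetical `Inbound` dataclass. The `session_key` property (override, else `channel:chat_id`) is an assumption about `InboundMessage`, which this diff does not show.

```python
from __future__ import annotations

import dataclasses
from dataclasses import dataclass

UNIFIED_SESSION_KEY = "unified:default"


@dataclass
class Inbound:
    """Hypothetical stand-in for InboundMessage."""

    channel: str
    chat_id: str
    session_key_override: str | None = None

    @property
    def session_key(self) -> str:
        # Assumed behavior: an explicit override wins over the derived key.
        return self.session_key_override or f"{self.channel}:{self.chat_id}"


def effective_session_key(msg: Inbound, unified_session: bool) -> str:
    """All messages share one key in unified mode unless they carry an override."""
    if unified_session and not msg.session_key_override:
        return UNIFIED_SESSION_KEY
    return msg.session_key


msg = Inbound("telegram", "42")
per_chat = effective_session_key(msg, unified_session=False)
unified = effective_session_key(msg, unified_session=True)
pinned = effective_session_key(
    dataclasses.replace(msg, session_key_override="custom"), unified_session=True
)
```

`dataclasses.replace` is the same call the diff uses to stamp the effective key onto a message before it is queued or dispatched.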
|
||||
async def _run_agent_loop(
|
||||
self,
|
||||
initial_messages: list[dict],
|
||||
@ -341,13 +345,16 @@ class AgentLoop:
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
message_id: str | None = None,
|
||||
) -> tuple[str | None, list[str], list[dict]]:
|
||||
pending_queue: asyncio.Queue | None = None,
|
||||
) -> tuple[str | None, list[str], list[dict], str, bool]:
|
||||
"""Run the agent iteration loop.
|
||||
|
||||
*on_stream*: called with each content delta during streaming.
|
||||
*on_stream_end(resuming)*: called when a streaming session finishes.
|
||||
``resuming=True`` means tool calls follow (spinner should restart);
|
||||
``resuming=False`` means this is the final response.
|
||||
|
||||
Returns (final_content, tools_used, messages, stop_reason, had_injections).
|
||||
"""
|
||||
loop_hook = _LoopHook(
|
||||
self,
|
||||
@ -359,9 +366,7 @@ class AgentLoop:
|
||||
message_id=message_id,
|
||||
)
|
||||
hook: AgentHook = (
|
||||
_LoopHookChain(loop_hook, self._extra_hooks)
|
||||
if self._extra_hooks
|
||||
else loop_hook
|
||||
CompositeHook([loop_hook] + self._extra_hooks) if self._extra_hooks else loop_hook
|
||||
)
|
||||
|
||||
async def _checkpoint(payload: dict[str, Any]) -> None:
|
||||
@ -369,6 +374,32 @@ class AgentLoop:
|
||||
return
|
||||
self._set_runtime_checkpoint(session, payload)
|
||||
|
||||
async def _drain_pending(*, limit: int = _MAX_INJECTIONS_PER_TURN) -> list[dict[str, Any]]:
|
||||
"""Non-blocking drain of follow-up messages from the pending queue."""
|
||||
if pending_queue is None:
|
||||
return []
|
||||
items: list[dict[str, Any]] = []
|
||||
while len(items) < limit:
|
||||
try:
|
||||
pending_msg = pending_queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
user_content = self.context._build_user_content(
|
||||
pending_msg.content,
|
||||
pending_msg.media if pending_msg.media else None,
|
||||
)
|
||||
runtime_ctx = self.context._build_runtime_context(
|
||||
pending_msg.channel,
|
||||
pending_msg.chat_id,
|
||||
self.context.timezone,
|
||||
)
|
||||
if isinstance(user_content, str):
|
||||
merged: str | list[dict[str, Any]] = f"{runtime_ctx}\n\n{user_content}"
|
||||
else:
|
||||
merged = [{"type": "text", "text": runtime_ctx}] + user_content
|
||||
items.append({"role": "user", "content": merged})
|
||||
return items
|
||||
|
||||
result = await self.runner.run(AgentRunSpec(
|
||||
initial_messages=initial_messages,
|
||||
tools=self.tools,
|
||||
@ -385,13 +416,14 @@ class AgentLoop:
|
||||
provider_retry_mode=self.provider_retry_mode,
|
||||
progress_callback=on_progress,
|
||||
checkpoint_callback=_checkpoint,
|
||||
injection_callback=_drain_pending,
|
||||
))
|
||||
self._last_usage = result.usage
|
||||
if result.stop_reason == "max_iterations":
|
||||
logger.warning("Max iterations ({}) reached", self.max_iterations)
|
||||
elif result.stop_reason == "error":
|
||||
logger.error("LLM returned error: {}", (result.final_content or "")[:200])
|
||||
return result.final_content, result.tools_used, result.messages
|
||||
return result.final_content, result.tools_used, result.messages, result.stop_reason, result.had_injections
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
|
||||
@ -403,6 +435,10 @@ class AgentLoop:
|
||||
try:
|
||||
msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
self.auto_compact.check_expired(
|
||||
self._schedule_background,
|
||||
active_session_keys=self._pending_queues.keys(),
|
||||
)
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
# Preserve real task cancellation so shutdown can complete cleanly.
|
||||
@ -421,79 +457,140 @@ class AgentLoop:
|
||||
if result:
|
||||
await self.bus.publish_outbound(result)
|
||||
continue
|
||||
effective_key = self._effective_session_key(msg)
|
||||
# If this session already has an active pending queue (i.e. a task
|
||||
# is processing this session), route the message there for mid-turn
|
||||
# injection instead of creating a competing task.
|
||||
if effective_key in self._pending_queues:
|
||||
pending_msg = msg
|
||||
if effective_key != msg.session_key:
|
||||
pending_msg = dataclasses.replace(
|
||||
msg,
|
||||
session_key_override=effective_key,
|
||||
)
|
||||
try:
|
||||
self._pending_queues[effective_key].put_nowait(pending_msg)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning(
|
||||
"Pending queue full for session {}, falling back to queued task",
|
||||
effective_key,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Routed follow-up message to pending queue for session {}",
|
||||
effective_key,
|
||||
)
|
||||
continue
|
||||
# Track the task under the effective session key computed above,
|
||||
# so the /stop command can find it when unified session is enabled.
|
||||
task = asyncio.create_task(self._dispatch(msg))
|
||||
self._active_tasks.setdefault(msg.session_key, []).append(task)
|
||||
task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
|
||||
self._active_tasks.setdefault(effective_key, []).append(task)
|
||||
task.add_done_callback(
|
||||
lambda t, k=effective_key: self._active_tasks.get(k, [])
|
||||
and self._active_tasks[k].remove(t)
|
||||
if t in self._active_tasks.get(k, [])
|
||||
else None
|
||||
)
|
||||
|
||||
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||
"""Process a message: per-session serial, cross-session concurrent."""
|
||||
lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock())
|
||||
session_key = self._effective_session_key(msg)
|
||||
if session_key != msg.session_key:
|
||||
msg = dataclasses.replace(msg, session_key_override=session_key)
|
||||
lock = self._session_locks.setdefault(session_key, asyncio.Lock())
|
||||
gate = self._concurrency_gate or nullcontext()
|
||||
async with lock, gate:
|
||||
try:
|
||||
on_stream = on_stream_end = None
|
||||
if msg.metadata.get("_wants_stream"):
|
||||
# Split one answer into distinct stream segments.
|
||||
stream_base_id = f"{msg.session_key}:{time.time_ns()}"
|
||||
stream_segment = 0
|
||||
|
||||
def _current_stream_id() -> str:
|
||||
return f"{stream_base_id}:{stream_segment}"
|
||||
# Register a pending queue so follow-up messages for this session are
|
||||
# routed here (mid-turn injection) instead of spawning a new task.
|
||||
pending = asyncio.Queue(maxsize=20)
|
||||
self._pending_queues[session_key] = pending
|
||||
|
||||
async def on_stream(delta: str) -> None:
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_stream_delta"] = True
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
try:
|
||||
async with lock, gate:
|
||||
try:
|
||||
on_stream = on_stream_end = None
|
||||
if msg.metadata.get("_wants_stream"):
|
||||
# Split one answer into distinct stream segments.
|
||||
stream_base_id = f"{msg.session_key}:{time.time_ns()}"
|
||||
stream_segment = 0
|
||||
|
||||
def _current_stream_id() -> str:
|
||||
return f"{stream_base_id}:{stream_segment}"
|
||||
|
||||
async def on_stream(delta: str) -> None:
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_stream_delta"] = True
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content=delta,
|
||||
metadata=meta,
|
||||
))
|
||||
|
||||
async def on_stream_end(*, resuming: bool = False) -> None:
|
||||
nonlocal stream_segment
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_stream_end"] = True
|
||||
meta["_resuming"] = resuming
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="",
|
||||
metadata=meta,
|
||||
))
|
||||
stream_segment += 1
|
||||
|
||||
response = await self._process_message(
|
||||
msg, on_stream=on_stream, on_stream_end=on_stream_end,
|
||||
pending_queue=pending,
|
||||
)
|
||||
if response is not None:
|
||||
await self.bus.publish_outbound(response)
|
||||
elif msg.channel == "cli":
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content=delta,
|
||||
metadata=meta,
|
||||
content="", metadata=msg.metadata or {},
|
||||
))
|
||||
|
||||
async def on_stream_end(*, resuming: bool = False) -> None:
|
||||
nonlocal stream_segment
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_stream_end"] = True
|
||||
meta["_resuming"] = resuming
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="",
|
||||
metadata=meta,
|
||||
))
|
||||
stream_segment += 1
|
||||
|
||||
response = await self._process_message(
|
||||
msg, on_stream=on_stream, on_stream_end=on_stream_end,
|
||||
)
|
||||
if response is not None:
|
||||
await self.bus.publish_outbound(response)
|
||||
elif msg.channel == "cli":
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Task cancelled for session {}", session_key)
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Error processing message for session {}", session_key)
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="", metadata=msg.metadata or {},
|
||||
content="Sorry, I encountered an error.",
|
||||
))
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Task cancelled for session {}", msg.session_key)
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Error processing message for session {}", msg.session_key)
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="Sorry, I encountered an error.",
|
||||
))
|
||||
finally:
|
||||
# Drain any messages still in the pending queue and re-publish
|
||||
# them to the bus so they are processed as fresh inbound messages
|
||||
# rather than silently lost.
|
||||
queue = self._pending_queues.pop(session_key, None)
|
||||
if queue is not None:
|
||||
leftover = 0
|
||||
while True:
|
||||
try:
|
||||
item = queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
await self.bus.publish_inbound(item)
|
||||
leftover += 1
|
||||
if leftover:
|
||||
logger.info(
|
||||
"Re-published {} leftover message(s) to bus for session {}",
|
||||
leftover, session_key,
|
||||
)
|
||||
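The mid-turn injection flow above has three parts: route follow-ups into a per-session queue while a turn is active, drain a bounded number of them per injection, and hand any leftovers back when the turn ends. A compact sketch under simplifying assumptions follows. Messages are plain strings, `MAX_INJECTIONS_PER_TURN = 3` is a made-up value (the real constant is imported from `nanobot.agent.runner`), and leftovers are returned rather than re-published to a bus.

```python
from __future__ import annotations

import asyncio

MAX_INJECTIONS_PER_TURN = 3  # hypothetical value for illustration

pending_queues: dict[str, asyncio.Queue] = {}


def route(session_key: str, message: str) -> bool:
    """Queue a follow-up for an active turn; False means 'dispatch a new task'."""
    queue = pending_queues.get(session_key)
    if queue is None:
        return False
    try:
        queue.put_nowait(message)
    except asyncio.QueueFull:
        return False  # fall back to a normally queued task
    return True


def drain(queue: asyncio.Queue, limit: int = MAX_INJECTIONS_PER_TURN) -> list[str]:
    """Non-blocking drain of at most `limit` items."""
    items: list[str] = []
    while len(items) < limit:
        try:
            items.append(queue.get_nowait())
        except asyncio.QueueEmpty:
            break
    return items


async def demo() -> tuple[bool, list[str], list[str]]:
    routed_early = route("cli:direct", "too early")  # no active turn yet
    pending_queues["cli:direct"] = asyncio.Queue(maxsize=20)  # turn starts
    for i in range(5):
        route("cli:direct", f"follow-up {i}")
    injected = drain(pending_queues["cli:direct"])
    # Turn ends: pop the queue so new messages start fresh tasks, keep leftovers.
    leftover = drain(pending_queues.pop("cli:direct"), limit=20)
    return routed_early, injected, leftover


routed_early, injected, leftover = asyncio.run(demo())
```

Popping the queue in `finally` before draining matters: once the key is gone, new arrivals are dispatched as tasks instead of landing in a queue nobody will read.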
|
||||
async def close_mcp(self) -> None:
|
||||
"""Drain pending background archives, then close MCP connections."""
|
||||
if self._background_tasks:
|
||||
await asyncio.gather(*self._background_tasks, return_exceptions=True)
|
||||
self._background_tasks.clear()
|
||||
if self._mcp_stack:
|
||||
for name, stack in self._mcp_stacks.items():
|
||||
try:
|
||||
await self._mcp_stack.aclose()
|
||||
await stack.aclose()
|
||||
except (RuntimeError, BaseExceptionGroup):
|
||||
pass # MCP SDK cancel scope cleanup is noisy but harmless
|
||||
self._mcp_stack = None
|
||||
logger.debug("MCP server '{}' cleanup error (can be ignored)", name)
|
||||
self._mcp_stacks.clear()
|
||||
|
||||
def _schedule_background(self, coro) -> None:
|
||||
"""Schedule a coroutine as a tracked background task (drained on shutdown)."""
|
||||
@ -513,27 +610,36 @@ class AgentLoop:
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||
pending_queue: asyncio.Queue | None = None,
|
||||
) -> OutboundMessage | None:
|
||||
"""Process a single inbound message and return the response."""
|
||||
# System messages: parse origin from chat_id ("channel:chat_id")
|
||||
if msg.channel == "system":
|
||||
channel, chat_id = (msg.chat_id.split(":", 1) if ":" in msg.chat_id
|
||||
else ("cli", msg.chat_id))
|
||||
channel, chat_id = (
|
||||
msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id)
|
||||
)
|
||||
logger.info("Processing system message from {}", msg.sender_id)
|
||||
key = f"{channel}:{chat_id}"
|
||||
session = self.sessions.get_or_create(key)
|
||||
if self._restore_runtime_checkpoint(session):
|
||||
self.sessions.save(session)
|
||||
if self._restore_pending_user_turn(session):
|
||||
self.sessions.save(session)
|
||||
|
||||
session, pending = self.auto_compact.prepare_session(session, key)
|
||||
|
||||
await self.consolidator.maybe_consolidate_by_tokens(session)
|
||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||
history = session.get_history(max_messages=0)
|
||||
current_role = "assistant" if msg.sender_id == "subagent" else "user"
|
||||
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=msg.content, channel=channel, chat_id=chat_id,
|
||||
session_summary=pending,
|
||||
current_role=current_role,
|
||||
)
|
||||
final_content, _, all_msgs = await self._run_agent_loop(
|
||||
final_content, _, all_msgs, _, _ = await self._run_agent_loop(
|
||||
messages, session=session, channel=channel, chat_id=chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
)
|
||||
@ -541,8 +647,11 @@ class AgentLoop:
|
||||
self._clear_runtime_checkpoint(session)
|
||||
self.sessions.save(session)
|
||||
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
||||
content=final_content or "Background task completed.")
|
||||
return OutboundMessage(
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
content=final_content or "Background task completed.",
|
||||
)
|
||||
|
||||
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
|
||||
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||
@ -551,6 +660,10 @@ class AgentLoop:
|
||||
session = self.sessions.get_or_create(key)
|
||||
if self._restore_runtime_checkpoint(session):
|
||||
self.sessions.save(session)
|
||||
if self._restore_pending_user_turn(session):
|
||||
self.sessions.save(session)
|
||||
|
||||
session, pending = self.auto_compact.prepare_session(session, key)
|
||||
|
||||
# Slash commands
|
||||
raw = msg.content.strip()
|
||||
@ -566,50 +679,85 @@ class AgentLoop:
|
||||
message_tool.start_turn()
|
||||
|
||||
history = session.get_history(max_messages=0)
|
||||
|
||||
initial_messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=msg.content,
|
||||
session_summary=pending,
|
||||
media=msg.media if msg.media else None,
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
)
|
||||
|
||||
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_progress"] = True
|
||||
meta["_tool_hint"] = tool_hint
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
|
||||
))
|
||||
await self.bus.publish_outbound(
|
||||
OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content=content,
|
||||
metadata=meta,
|
||||
)
|
||||
)
|
||||
|
||||
final_content, _, all_msgs = await self._run_agent_loop(
|
||||
# Persist the triggering user message immediately, before running the
|
||||
# agent loop. If the process is killed mid-turn (OOM, SIGKILL, self-
|
||||
# restart, etc.), the existing runtime_checkpoint preserves the
|
||||
# in-flight assistant/tool state but NOT the user message itself, so
|
||||
# the user's prompt is silently lost on recovery. Saving it up front
|
||||
# makes recovery possible from the session log alone.
|
||||
user_persisted_early = False
|
||||
if isinstance(msg.content, str) and msg.content.strip():
|
||||
session.add_message("user", msg.content)
|
||||
self._mark_pending_user_turn(session)
|
||||
self.sessions.save(session)
|
||||
user_persisted_early = True
|
||||
|
||||
final_content, _, all_msgs, stop_reason, had_injections = await self._run_agent_loop(
|
||||
initial_messages,
|
||||
on_progress=on_progress or _bus_progress,
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
session=session,
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
pending_queue=pending_queue,
|
||||
)
|
||||
|
||||
if final_content is None or not final_content.strip():
|
||||
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
# Skip the already-persisted user message when saving the turn
|
||||
save_skip = 1 + len(history) + (1 if user_persisted_early else 0)
|
||||
self._save_turn(session, all_msgs, save_skip)
|
||||
self._clear_pending_user_turn(session)
|
||||
self._clear_runtime_checkpoint(session)
|
||||
self.sessions.save(session)
|
||||
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
||||
|
||||
# When follow-up messages were injected mid-turn, a later natural
|
||||
# language reply may address those follow-ups and should not be
|
||||
# suppressed just because MessageTool was used earlier in the turn.
|
||||
# However, if the turn falls back to the empty-final-response
|
||||
# placeholder, suppress it when the real user-visible output already
|
||||
# came from MessageTool.
|
||||
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||
return None
|
||||
if not had_injections or stop_reason == "empty_final_response":
|
||||
return None
|
||||
|
||||
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
||||
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||
|
||||
meta = dict(msg.metadata or {})
|
||||
if on_stream is not None:
|
||||
if on_stream is not None and stop_reason != "error":
|
||||
meta["_streamed"] = True
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content=final_content,
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content=final_content,
|
||||
metadata=meta,
|
||||
)
|
||||
|
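The skip offset computed before `_save_turn` can be illustrated in isolation. This is a minimal sketch, assuming `all_msgs` is laid out as one leading system message, then the prior history, then the current user message, then whatever the turn produced. `compute_save_skip` is a hypothetical helper for illustration and is not part of nanobot.

```python
def compute_save_skip(history_len: int, user_persisted_early: bool) -> int:
    """Return how many leading messages to skip when saving a turn."""
    # The leading 1 accounts for the assumed system message.
    skip = 1 + history_len
    # If the user message was already written to the session, skip it too
    # so it is not saved a second time.
    if user_persisted_early:
        skip += 1
    return skip


all_msgs = [
    {"role": "system", "content": "sys"},
    {"role": "user", "content": "old question"},
    {"role": "assistant", "content": "old answer"},
    {"role": "user", "content": "new question"},
    {"role": "assistant", "content": "new answer"},
]
history_len = 2  # the two "old" messages

new_if_not_persisted = all_msgs[compute_save_skip(history_len, False):]
new_if_persisted = all_msgs[compute_save_skip(history_len, True):]
```

With early persistence the slice starts one message later, so only the assistant reply is appended to the session.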
||||
@ -617,7 +765,7 @@ class AgentLoop:
|
||||
self,
|
||||
content: list[dict[str, Any]],
|
||||
*,
|
||||
truncate_text: bool = False,
|
||||
should_truncate_text: bool = False,
|
||||
drop_runtime: bool = False,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Strip volatile multimodal payloads before writing session history."""
|
||||
@ -635,18 +783,17 @@ class AgentLoop:
|
||||
):
|
||||
continue
|
||||
|
||||
if (
|
||||
block.get("type") == "image_url"
|
||||
and block.get("image_url", {}).get("url", "").startswith("data:image/")
|
||||
):
|
||||
if block.get("type") == "image_url" and block.get("image_url", {}).get(
|
||||
"url", ""
|
||||
).startswith("data:image/"):
|
||||
path = (block.get("_meta") or {}).get("path", "")
|
||||
filtered.append({"type": "text", "text": image_placeholder_text(path)})
|
||||
continue
|
||||
|
||||
if block.get("type") == "text" and isinstance(block.get("text"), str):
|
||||
text = block["text"]
|
||||
if truncate_text and len(text) > self.max_tool_result_chars:
|
||||
text = truncate_text(text, self.max_tool_result_chars)
|
||||
if should_truncate_text and len(text) > self.max_tool_result_chars:
|
||||
text = truncate_text_fn(text, self.max_tool_result_chars)
|
||||
filtered.append({**block, "text": text})
|
||||
continue
|
||||
|
||||
@ -657,6 +804,7 @@ class AgentLoop:
|
||||
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
|
||||
"""Save new-turn messages into session, truncating large tool results."""
|
||||
from datetime import datetime
|
||||
|
||||
for m in messages[skip:]:
|
||||
entry = dict(m)
|
||||
role, content = entry.get("role"), entry.get("content")
|
||||
@ -664,20 +812,31 @@ class AgentLoop:
|
||||
continue # skip empty assistant messages — they poison session context
|
||||
if role == "tool":
|
||||
if isinstance(content, str) and len(content) > self.max_tool_result_chars:
|
||||
entry["content"] = truncate_text(content, self.max_tool_result_chars)
|
||||
entry["content"] = truncate_text_fn(content, self.max_tool_result_chars)
|
||||
elif isinstance(content, list):
|
||||
filtered = self._sanitize_persisted_blocks(content, truncate_text=True)
|
||||
filtered = self._sanitize_persisted_blocks(content, should_truncate_text=True)
|
||||
if not filtered:
|
||||
continue
|
||||
entry["content"] = filtered
|
||||
elif role == "user":
|
||||
if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
|
||||
# Strip the runtime-context prefix, keep only the user text.
|
||||
parts = content.split("\n\n", 1)
|
||||
if len(parts) > 1 and parts[1].strip():
|
||||
entry["content"] = parts[1]
|
||||
# Strip the entire runtime-context block (including any session summary).
|
||||
# The block is bounded by _RUNTIME_CONTEXT_TAG and _RUNTIME_CONTEXT_END.
|
||||
end_marker = ContextBuilder._RUNTIME_CONTEXT_END
|
||||
end_pos = content.find(end_marker)
|
||||
if end_pos >= 0:
|
||||
after = content[end_pos + len(end_marker):].lstrip("\n")
|
||||
if after:
|
||||
entry["content"] = after
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
# Fallback: no end marker found, strip the tag prefix
|
||||
after_tag = content[len(ContextBuilder._RUNTIME_CONTEXT_TAG):].lstrip("\n")
|
||||
if after_tag.strip():
|
||||
entry["content"] = after_tag
|
||||
else:
|
||||
continue
|
||||
if isinstance(content, list):
|
||||
filtered = self._sanitize_persisted_blocks(content, drop_runtime=True)
|
||||
if not filtered:
|
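The new runtime-context stripping in `_save_turn` can be sketched as a standalone function. The marker strings below are hypothetical placeholders (the real values live in `ContextBuilder._RUNTIME_CONTEXT_TAG` and `ContextBuilder._RUNTIME_CONTEXT_END` and are not shown in this diff), and the function returns `None` where the original loop would `continue`.

```python
from typing import Optional

RUNTIME_TAG = "[Runtime Context]"   # hypothetical stand-in for _RUNTIME_CONTEXT_TAG
RUNTIME_END = "[/Runtime Context]"  # hypothetical stand-in for _RUNTIME_CONTEXT_END


def strip_runtime_context(content: str) -> Optional[str]:
    """Return the user text without the runtime block, or None if nothing remains."""
    if not content.startswith(RUNTIME_TAG):
        return content
    end_pos = content.find(RUNTIME_END)
    if end_pos >= 0:
        # Drop everything up to and including the end marker.
        after = content[end_pos + len(RUNTIME_END):].lstrip("\n")
        return after or None
    # Fallback: no end marker, so strip only the tag prefix.
    after_tag = content[len(RUNTIME_TAG):].lstrip("\n")
    return after_tag if after_tag.strip() else None


wrapped = "[Runtime Context]\ntime: 12:00\n[/Runtime Context]\n\nhi there"
stripped = strip_runtime_context(wrapped)
```

Bounding the block by an explicit end marker means extra lines inside it (such as a session summary) are removed as well, which a plain split on the first blank line would not guarantee.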
||||
@ -692,6 +851,12 @@ class AgentLoop:
|
||||
session.metadata[self._RUNTIME_CHECKPOINT_KEY] = payload
|
||||
self.sessions.save(session)
|
||||
|
||||
def _mark_pending_user_turn(self, session: Session) -> None:
|
||||
session.metadata[self._PENDING_USER_TURN_KEY] = True
|
||||
|
||||
def _clear_pending_user_turn(self, session: Session) -> None:
|
||||
session.metadata.pop(self._PENDING_USER_TURN_KEY, None)
|
||||
|
||||
def _clear_runtime_checkpoint(self, session: Session) -> None:
|
||||
if self._RUNTIME_CHECKPOINT_KEY in session.metadata:
|
||||
session.metadata.pop(self._RUNTIME_CHECKPOINT_KEY, None)
|
||||
@ -735,13 +900,15 @@ class AgentLoop:
|
||||
continue
|
||||
tool_id = tool_call.get("id")
|
||||
name = ((tool_call.get("function") or {}).get("name")) or "tool"
|
||||
restored_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_id,
|
||||
"name": name,
|
||||
"content": "Error: Task interrupted before this tool finished.",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
})
|
||||
restored_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_id,
|
||||
"name": name,
|
||||
"content": "Error: Task interrupted before this tool finished.",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
overlap = 0
|
||||
max_overlap = min(len(session.messages), len(restored_messages))
|
||||
@ -756,9 +923,30 @@ class AgentLoop:
|
||||
break
|
||||
session.messages.extend(restored_messages[overlap:])
|
||||
|
||||
self._clear_pending_user_turn(session)
|
||||
self._clear_runtime_checkpoint(session)
|
||||
return True
|
||||
|
||||
def _restore_pending_user_turn(self, session: Session) -> bool:
|
||||
"""Close a turn that only persisted the user message before crashing."""
|
||||
from datetime import datetime
|
||||
|
||||
if not session.metadata.get(self._PENDING_USER_TURN_KEY):
|
||||
return False
|
||||
|
||||
if session.messages and session.messages[-1].get("role") == "user":
|
||||
session.messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Error: Task interrupted before a response was generated.",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
)
|
||||
session.updated_at = datetime.now()
|
||||
|
||||
self._clear_pending_user_turn(session)
|
||||
return True
|
||||
|
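The recovery path for a turn that crashed after persisting only the user message can be sketched on plain data structures. This assumes a session is reduced to a message list plus a metadata dict. `PENDING_KEY` is a hypothetical key name, and the `updated_at` bookkeeping from the original is omitted.

```python
from datetime import datetime
from typing import Any

PENDING_KEY = "pending_user_turn"  # hypothetical; the real key is _PENDING_USER_TURN_KEY


def restore_pending_user_turn(
    messages: list[dict[str, Any]], metadata: dict[str, Any]
) -> bool:
    """Close a dangling user turn with a synthetic assistant error message."""
    if not metadata.get(PENDING_KEY):
        return False
    if messages and messages[-1].get("role") == "user":
        messages.append(
            {
                "role": "assistant",
                "content": "Error: Task interrupted before a response was generated.",
                "timestamp": datetime.now().isoformat(),
            }
        )
    # Clear the flag whether or not a message was appended.
    metadata.pop(PENDING_KEY, None)
    return True


msgs = [{"role": "user", "content": "hello"}]
meta = {PENDING_KEY: True}
restored = restore_pending_user_turn(msgs, meta)
```

Appending the placeholder keeps the history alternating between user and assistant roles, so the next user message does not follow another user message directly.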
||||
async def process_direct(
|
||||
self,
|
||||
content: str,
|
||||
@ -777,6 +965,9 @@ class AgentLoop:
|
||||
content=content, media=media or [],
|
||||
)
|
||||
return await self._process_message(
|
||||
msg, session_key=session_key, on_progress=on_progress,
|
||||
on_stream=on_stream, on_stream_end=on_stream_end,
|
||||
msg,
|
||||
session_key=session_key,
|
||||
on_progress=on_progress,
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
)
|
||||
|
||||
@ -290,7 +290,7 @@ class MemoryStore:
|
||||
if not lines:
|
||||
return None
|
||||
return json.loads(lines[-1])
|
||||
except (FileNotFoundError, json.JSONDecodeError):
|
||||
except (FileNotFoundError, json.JSONDecodeError, UnicodeDecodeError):
|
||||
return None
|
||||
|
||||
def _write_entries(self, entries: list[dict[str, Any]]) -> None:
|
||||
@ -347,6 +347,7 @@ class Consolidator:
|
||||
"""Lightweight consolidation: summarizes evicted messages into history.jsonl."""
|
||||
|
||||
_MAX_CONSOLIDATION_ROUNDS = 5
|
||||
_MAX_CHUNK_MESSAGES = 60 # hard cap per consolidation round
|
||||
|
||||
_SAFETY_BUFFER = 1024 # extra headroom for tokenizer estimation drift
|
||||
|
||||
@ -399,6 +400,22 @@ class Consolidator:
|
||||
|
||||
return last_boundary
|
||||
|
||||
def _cap_consolidation_boundary(
|
||||
self,
|
||||
session: Session,
|
||||
end_idx: int,
|
||||
) -> int | None:
|
||||
"""Clamp the chunk size without breaking the user-turn boundary."""
|
||||
start = session.last_consolidated
|
||||
if end_idx - start <= self._MAX_CHUNK_MESSAGES:
|
||||
return end_idx
|
||||
|
||||
capped_end = start + self._MAX_CHUNK_MESSAGES
|
||||
for idx in range(capped_end, start, -1):
|
||||
if session.messages[idx].get("role") == "user":
|
||||
return idx
|
||||
return None
|
||||
|
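The clamping done by `_cap_consolidation_boundary` can be tried out on a plain list. The sketch below mirrors the loop above but takes the message list and start index as arguments. `cap_boundary` and its `max_chunk` parameter are illustrative names, not part of the codebase.

```python
from typing import Any, Optional

MAX_CHUNK_MESSAGES = 60  # same value as _MAX_CHUNK_MESSAGES in the diff


def cap_boundary(
    messages: list[dict[str, Any]],
    start: int,
    end_idx: int,
    max_chunk: int = MAX_CHUNK_MESSAGES,
) -> Optional[int]:
    """Clamp end_idx to at most max_chunk messages, ending on a user message."""
    if end_idx - start <= max_chunk:
        return end_idx
    capped_end = start + max_chunk
    # Walk backwards from the cap until a user message is found.
    for idx in range(capped_end, start, -1):
        if messages[idx].get("role") == "user":
            return idx
    return None


# Ten alternating messages: even indexes are user, odd are assistant.
history = [
    {"role": "user" if i % 2 == 0 else "assistant", "content": str(i)}
    for i in range(10)
]
capped = cap_boundary(history, start=0, end_idx=8, max_chunk=5)
```

Here the cap of 5 lands on an assistant message at index 5, so the boundary moves back to the user message at index 4 and the archived chunk becomes `history[0:4]`.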
||||
def estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]:
|
||||
"""Estimate current prompt size for the normal session history view."""
|
||||
history = session.get_history(max_messages=0)
|
||||
@ -416,13 +433,13 @@ class Consolidator:
|
||||
self._get_tool_definitions(),
|
||||
)
|
||||
|
||||
async def archive(self, messages: list[dict]) -> bool:
|
||||
async def archive(self, messages: list[dict]) -> str | None:
|
||||
"""Summarize messages via LLM and append to history.jsonl.
|
||||
|
||||
Returns True on success (or degraded success), False if nothing to do.
|
||||
Returns the summary text on success, None if nothing to archive.
|
||||
"""
|
||||
if not messages:
|
||||
return False
|
||||
return None
|
||||
try:
|
||||
formatted = MemoryStore._format_messages(messages)
|
||||
response = await self.provider.chat_with_retry(
|
||||
@ -442,11 +459,11 @@ class Consolidator:
|
||||
)
|
||||
summary = response.content or "[no summary]"
|
||||
self.store.append_history(summary)
|
||||
return True
|
||||
return summary
|
||||
except Exception:
|
||||
logger.warning("Consolidation LLM call failed, raw-dumping to history")
|
||||
self.store.raw_archive(messages)
|
||||
return True
|
||||
return None
|
||||
|
||||
async def maybe_consolidate_by_tokens(self, session: Session) -> None:
|
||||
"""Loop: archive old messages until prompt fits within safe budget.
|
||||
@ -461,16 +478,22 @@ class Consolidator:
|
||||
async with lock:
|
||||
budget = self.context_window_tokens - self.max_completion_tokens - self._SAFETY_BUFFER
|
||||
target = budget // 2
|
||||
estimated, source = self.estimate_session_prompt_tokens(session)
|
||||
try:
|
||||
estimated, source = self.estimate_session_prompt_tokens(session)
|
||||
except Exception:
|
||||
logger.exception("Token estimation failed for {}", session.key)
|
||||
estimated, source = 0, "error"
|
||||
if estimated <= 0:
|
||||
return
|
||||
if estimated < budget:
|
||||
unconsolidated_count = len(session.messages) - session.last_consolidated
|
||||
logger.debug(
|
||||
"Token consolidation idle {}: {}/{} via {}",
|
||||
"Token consolidation idle {}: {}/{} via {}, msgs={}",
|
||||
session.key,
|
||||
estimated,
|
||||
self.context_window_tokens,
|
||||
source,
|
||||
unconsolidated_count,
|
||||
)
|
||||
return
|
||||
|
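The budget arithmetic at the top of `maybe_consolidate_by_tokens` is easy to check with concrete numbers. This is a small sketch under assumed values: the window and completion sizes are made up for illustration, and `consolidation_budget` and `should_consolidate` are hypothetical helper names.

```python
SAFETY_BUFFER = 1024  # same value as _SAFETY_BUFFER in the diff


def consolidation_budget(
    context_window_tokens: int, max_completion_tokens: int
) -> tuple[int, int]:
    """Return (budget, target) as computed in maybe_consolidate_by_tokens."""
    budget = context_window_tokens - max_completion_tokens - SAFETY_BUFFER
    target = budget // 2
    return budget, target


def should_consolidate(estimated: int, budget: int) -> bool:
    """Mirror the early returns: skip when the estimate is unusable or under budget."""
    if estimated <= 0:
        return False
    return estimated >= budget


# Assumed example values, not taken from the nanobot configuration.
budget, target = consolidation_budget(128_000, 8_192)
```

A failed token estimate is mapped to `0` in the diff, so it takes the same early-return path as an empty session and consolidation is skipped for that call.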
||||
@ -488,6 +511,15 @@ class Consolidator:
|
||||
return
|
||||
|
||||
end_idx = boundary[0]
|
||||
end_idx = self._cap_consolidation_boundary(session, end_idx)
|
||||
if end_idx is None:
|
||||
logger.debug(
|
||||
"Token consolidation: no capped boundary for {} (round {})",
|
||||
session.key,
|
||||
round_num,
|
||||
)
|
||||
return
|
||||
|
||||
chunk = session.messages[session.last_consolidated:end_idx]
|
||||
if not chunk:
|
||||
return
|
||||
@ -506,7 +538,11 @@ class Consolidator:
|
||||
session.last_consolidated = end_idx
|
||||
self.sessions.save(session)
|
||||
|
||||
estimated, source = self.estimate_session_prompt_tokens(session)
|
||||
try:
|
||||
estimated, source = self.estimate_session_prompt_tokens(session)
|
||||
except Exception:
|
||||
logger.exception("Token estimation failed for {}", session.key)
|
||||
estimated, source = 0, "error"
|
||||
if estimated <= 0:
|
||||
return
|
||||
|
||||
@ -546,18 +582,60 @@ class Dream:
|
||||
|
||||
def _build_tools(self) -> ToolRegistry:
|
||||
"""Build a minimal tool registry for the Dream agent."""
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool
|
||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, WriteFileTool
|
||||
|
||||
tools = ToolRegistry()
|
||||
workspace = self.store.workspace
|
||||
tools.register(ReadFileTool(workspace=workspace, allowed_dir=workspace))
|
||||
# Allow reading builtin skills for reference during skill creation
|
||||
extra_read = [BUILTIN_SKILLS_DIR] if BUILTIN_SKILLS_DIR.exists() else None
|
||||
tools.register(ReadFileTool(
|
||||
workspace=workspace,
|
||||
allowed_dir=workspace,
|
||||
extra_allowed_dirs=extra_read,
|
||||
))
|
||||
tools.register(EditFileTool(workspace=workspace, allowed_dir=workspace))
|
||||
# write_file resolves relative paths from workspace root, but can only
|
||||
# write under skills/ so the prompt can safely use skills/<name>/SKILL.md.
|
||||
skills_dir = workspace / "skills"
|
||||
skills_dir.mkdir(parents=True, exist_ok=True)
|
||||
tools.register(WriteFileTool(workspace=workspace, allowed_dir=skills_dir))
|
||||
return tools
|
||||
|
||||
# -- skill listing --------------------------------------------------------
|
||||
|
||||
def _list_existing_skills(self) -> list[str]:
|
||||
"""List existing skills as 'name — description' for dedup context."""
|
||||
import re as _re
|
||||
|
||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||
|
||||
_DESC_RE = _re.compile(r"^description:\s*(.+)$", _re.MULTILINE | _re.IGNORECASE)
|
||||
entries: dict[str, str] = {}
|
||||
for base in (self.store.workspace / "skills", BUILTIN_SKILLS_DIR):
|
||||
if not base.exists():
|
||||
continue
|
||||
for d in base.iterdir():
|
||||
if not d.is_dir():
|
||||
continue
|
||||
skill_md = d / "SKILL.md"
|
||||
if not skill_md.exists():
|
||||
continue
|
||||
# Prefer workspace skills over builtin (same name)
|
||||
if d.name in entries and base == BUILTIN_SKILLS_DIR:
|
||||
continue
|
||||
content = skill_md.read_text(encoding="utf-8")[:500]
|
||||
m = _DESC_RE.search(content)
|
||||
desc = m.group(1).strip() if m else "(no description)"
|
||||
entries[d.name] = desc
|
||||
return [f"{name} — {desc}" for name, desc in sorted(entries.items())]
|
||||
|
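The description lookup in `_list_existing_skills` relies on a multiline, case-insensitive regex applied to the first 500 characters of `SKILL.md`. The sketch below shows that matching on an in-memory string. The sample front matter is invented for illustration, and `extract_description` is a hypothetical helper.

```python
import re

DESC_RE = re.compile(r"^description:\s*(.+)$", re.MULTILINE | re.IGNORECASE)


def extract_description(skill_md_text: str) -> str:
    """Return the description line from the head of a SKILL.md, if any."""
    match = DESC_RE.search(skill_md_text[:500])
    return match.group(1).strip() if match else "(no description)"


# Invented sample content; real SKILL.md files may be laid out differently.
sample = "---\nname: weather\nDescription: Look up the forecast.  \n---\n# Weather\n"
found = extract_description(sample)
missing = extract_description("# Untitled skill\nSome body text.\n")
```

Because of `re.MULTILINE`, `^` and `$` anchor to individual lines, so the key may appear anywhere in the scanned prefix rather than only on the first line.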
||||
# -- main entry ----------------------------------------------------------
|
||||
|
||||
async def run(self) -> bool:
|
||||
"""Process unprocessed history entries. Returns True if work was done."""
|
||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||
|
||||
last_cursor = self.store.get_last_dream_cursor()
|
||||
entries = self.store.read_unprocessed_history(since_cursor=last_cursor)
|
||||
if not entries:
|
||||
@ -579,6 +657,7 @@ class Dream:
|
||||
current_memory = self.store.read_memory() or "(empty)"
|
||||
current_soul = self.store.read_soul() or "(empty)"
|
||||
current_user = self.store.read_user() or "(empty)"
|
||||
|
||||
file_context = (
|
||||
f"## Current Date\n{current_date}\n\n"
|
||||
f"## Current MEMORY.md ({len(current_memory)} chars)\n{current_memory}\n\n"
|
||||
@ -586,7 +665,7 @@ class Dream:
|
||||
f"## Current USER.md ({len(current_user)} chars)\n{current_user}"
|
||||
)
|
||||
|
||||
# Phase 1: Analyze
|
||||
# Phase 1: Analyze (no skills list — dedup is Phase 2's job)
|
||||
phase1_prompt = (
|
||||
f"## Conversation History\n{history_text}\n\n{file_context}"
|
||||
)
|
||||
@ -611,13 +690,25 @@ class Dream:
|
||||
return False
|
||||
|
||||
# Phase 2: Delegate to AgentRunner with read_file / edit_file
|
||||
phase2_prompt = f"## Analysis Result\n{analysis}\n\n{file_context}"
|
||||
existing_skills = self._list_existing_skills()
|
||||
skills_section = ""
|
||||
if existing_skills:
|
||||
skills_section = (
|
||||
"\n\n## Existing Skills\n"
|
||||
+ "\n".join(f"- {s}" for s in existing_skills)
|
||||
)
|
||||
phase2_prompt = f"## Analysis Result\n{analysis}\n\n{file_context}{skills_section}"
|
||||
|
||||
tools = self._tools
|
||||
skill_creator_path = BUILTIN_SKILLS_DIR / "skill-creator" / "SKILL.md"
|
||||
messages: list[dict[str, Any]] = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": render_template("agent/dream_phase2.md", strip=True),
|
||||
"content": render_template(
|
||||
"agent/dream_phase2.md",
|
||||
strip=True,
|
||||
skill_creator_path=str(skill_creator_path),
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": phase2_prompt},
|
||||
]
|
||||
|
||||
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@ -31,8 +32,11 @@ from nanobot.utils.runtime import (
|
||||
)
|
||||
|
||||
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
|
||||
_PERSISTED_MODEL_ERROR_PLACEHOLDER = "[Assistant reply unavailable due to model error.]"
|
||||
_MAX_EMPTY_RETRIES = 2
|
||||
_MAX_LENGTH_RECOVERIES = 3
|
||||
_MAX_INJECTIONS_PER_TURN = 3
|
||||
_MAX_INJECTION_CYCLES = 5
|
||||
_SNIP_SAFETY_BUFFER = 1024
|
||||
_MICROCOMPACT_KEEP_RECENT = 10
|
||||
_MICROCOMPACT_MIN_CHARS = 500
|
||||
@ -41,6 +45,9 @@ _COMPACTABLE_TOOLS = frozenset({
|
||||
"web_search", "web_fetch", "list_dir",
|
||||
})
|
||||
_BACKFILL_CONTENT = "[Tool result unavailable — call was interrupted or lost]"
|
||||
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AgentRunSpec:
|
||||
"""Configuration for a single agent execution."""
|
||||
@ -65,6 +72,7 @@ class AgentRunSpec:
|
||||
provider_retry_mode: str = "standard"
|
||||
progress_callback: Any | None = None
|
||||
checkpoint_callback: Any | None = None
|
||||
injection_callback: Any | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -78,6 +86,7 @@ class AgentRunResult:
|
||||
stop_reason: str = "completed"
|
||||
error: str | None = None
|
||||
tool_events: list[dict[str, str]] = field(default_factory=list)
|
||||
had_injections: bool = False
|
||||
|
||||
|
||||
class AgentRunner:
|
||||
@ -86,6 +95,134 @@ class AgentRunner:
|
||||
def __init__(self, provider: LLMProvider):
|
||||
self.provider = provider
|
||||
|
||||
@staticmethod
|
||||
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
|
||||
if isinstance(left, str) and isinstance(right, str):
|
||||
return f"{left}\n\n{right}" if left else right
|
||||
|
||||
def _to_blocks(value: Any) -> list[dict[str, Any]]:
|
||||
if isinstance(value, list):
|
||||
return [
|
||||
item if isinstance(item, dict) else {"type": "text", "text": str(item)}
|
||||
for item in value
|
||||
]
|
||||
if value is None:
|
||||
return []
|
||||
return [{"type": "text", "text": str(value)}]
|
||||
|
||||
return _to_blocks(left) + _to_blocks(right)
|
||||
|
||||
@classmethod
|
||||
def _append_injected_messages(
|
||||
cls,
|
||||
messages: list[dict[str, Any]],
|
||||
injections: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""Append injected user messages while preserving role alternation."""
|
||||
for injection in injections:
|
||||
if (
|
||||
messages
|
||||
and injection.get("role") == "user"
|
||||
and messages[-1].get("role") == "user"
|
||||
):
|
||||
merged = dict(messages[-1])
|
||||
merged["content"] = cls._merge_message_content(
|
||||
merged.get("content"),
|
||||
injection.get("content"),
|
||||
)
|
||||
messages[-1] = merged
|
||||
continue
|
||||
messages.append(injection)
|
||||
|
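The effect of `_append_injected_messages` is easiest to see on a small history. The sketch below re-implements the two helpers above as free functions (`merge_content` and `append_injected` are illustrative names) and runs them on sample messages.

```python
from typing import Any


def merge_content(left: Any, right: Any) -> Any:
    """Join two message contents, falling back to block lists for mixed types."""
    if isinstance(left, str) and isinstance(right, str):
        return f"{left}\n\n{right}" if left else right

    def to_blocks(value: Any) -> list[dict[str, Any]]:
        if isinstance(value, list):
            return [
                item if isinstance(item, dict) else {"type": "text", "text": str(item)}
                for item in value
            ]
        if value is None:
            return []
        return [{"type": "text", "text": str(value)}]

    return to_blocks(left) + to_blocks(right)


def append_injected(messages: list[dict[str, Any]], injections: list[dict[str, Any]]) -> None:
    """Append injections, merging into a trailing user message when present."""
    for injection in injections:
        if messages and injection.get("role") == "user" and messages[-1].get("role") == "user":
            merged = dict(messages[-1])
            merged["content"] = merge_content(merged.get("content"), injection.get("content"))
            messages[-1] = merged
            continue
        messages.append(injection)


history = [{"role": "user", "content": "first"}]
append_injected(history, [{"role": "user", "content": "second"}])
```

After the call, `history` still holds a single user message whose content is the two texts joined by a blank line, so the conversation never contains two consecutive user turns.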
||||
async def _try_drain_injections(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
messages: list[dict[str, Any]],
|
||||
assistant_message: dict[str, Any] | None,
|
||||
injection_cycles: int,
|
||||
*,
|
||||
phase: str = "after error",
|
||||
iteration: int | None = None,
|
||||
) -> tuple[bool, int]:
|
||||
"""Drain pending injections. Returns (should_continue, updated_cycles).
|
||||
|
||||
If injections are found and we haven't exceeded _MAX_INJECTION_CYCLES,
|
||||
append them to *messages* (and emit a checkpoint if *assistant_message*
|
||||
and *iteration* are both provided) and return (True, cycles+1) so the
|
||||
caller continues the iteration loop. Otherwise return (False, cycles).
|
||||
"""
|
||||
if injection_cycles >= _MAX_INJECTION_CYCLES:
|
||||
return False, injection_cycles
|
||||
injections = await self._drain_injections(spec)
|
||||
if not injections:
|
||||
return False, injection_cycles
|
||||
injection_cycles += 1
|
||||
if assistant_message is not None:
|
||||
messages.append(assistant_message)
|
||||
if iteration is not None:
|
||||
await self._emit_checkpoint(
|
||||
spec,
|
||||
{
|
||||
"phase": "final_response",
|
||||
"iteration": iteration,
|
||||
"model": spec.model,
|
||||
"assistant_message": assistant_message,
|
||||
"completed_tool_results": [],
|
||||
"pending_tool_calls": [],
|
||||
},
|
||||
)
|
||||
self._append_injected_messages(messages, injections)
|
||||
logger.info(
|
||||
"Injected {} follow-up message(s) {} ({}/{})",
|
||||
len(injections), phase, injection_cycles, _MAX_INJECTION_CYCLES,
|
||||
)
|
||||
return True, injection_cycles
|
||||
|
||||
async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]:
|
||||
"""Drain pending user messages via the injection callback.
|
||||
|
||||
Returns normalized user messages (capped by
|
||||
``_MAX_INJECTIONS_PER_TURN``), or an empty list when there is
|
||||
nothing to inject. Messages beyond the cap are logged so they
|
||||
are not silently lost.
|
||||
"""
|
||||
if spec.injection_callback is None:
|
||||
return []
|
||||
try:
|
||||
signature = inspect.signature(spec.injection_callback)
|
||||
accepts_limit = (
|
||||
"limit" in signature.parameters
|
||||
or any(
|
||||
parameter.kind is inspect.Parameter.VAR_KEYWORD
|
||||
for parameter in signature.parameters.values()
|
||||
)
|
||||
)
|
||||
if accepts_limit:
|
||||
items = await spec.injection_callback(limit=_MAX_INJECTIONS_PER_TURN)
|
||||
else:
|
||||
items = await spec.injection_callback()
|
||||
except Exception:
|
||||
logger.exception("injection_callback failed")
|
||||
return []
|
||||
if not items:
|
||||
return []
|
||||
injected_messages: list[dict[str, Any]] = []
|
||||
for item in items:
|
||||
if isinstance(item, dict) and item.get("role") == "user" and "content" in item:
|
||||
injected_messages.append(item)
|
||||
continue
|
||||
text = getattr(item, "content", str(item))
|
||||
if text.strip():
|
||||
injected_messages.append({"role": "user", "content": text})
|
||||
if len(injected_messages) > _MAX_INJECTIONS_PER_TURN:
|
||||
dropped = len(injected_messages) - _MAX_INJECTIONS_PER_TURN
|
||||
logger.warning(
|
||||
"Injection callback returned {} messages, capping to {} ({} dropped)",
|
||||
len(injected_messages), _MAX_INJECTIONS_PER_TURN, dropped,
|
||||
)
|
||||
injected_messages = injected_messages[:_MAX_INJECTIONS_PER_TURN]
|
||||
return injected_messages
|
||||
|
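The `limit` detection in `_drain_injections` uses `inspect.signature` so that callbacks written before the parameter existed keep working. A minimal sketch of that pattern follows. The three callbacks are invented for the demonstration, and `call_with_optional_limit` is a hypothetical name.

```python
import asyncio
import inspect
from typing import Any, Awaitable, Callable

MAX_PER_TURN = 3  # same value as _MAX_INJECTIONS_PER_TURN in the diff


async def call_with_optional_limit(callback: Callable[..., Awaitable[Any]]) -> Any:
    """Pass limit= only when the callback can accept it."""
    params = inspect.signature(callback).parameters
    accepts_limit = "limit" in params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    if accepts_limit:
        return await callback(limit=MAX_PER_TURN)
    return await callback()


async def with_limit(limit: int) -> list[str]:
    return [f"msg{i}" for i in range(limit)]


async def without_limit() -> list[str]:
    return ["only"]


async def with_kwargs(**kwargs: Any) -> list[str]:
    return sorted(kwargs)


async def _demo() -> tuple[Any, Any, Any]:
    return (
        await call_with_optional_limit(with_limit),
        await call_with_optional_limit(without_limit),
        await call_with_optional_limit(with_kwargs),
    )


limited, plain, via_kwargs = asyncio.run(_demo())
```

The runner still slices the result to the cap afterwards, so a callback that ignores `limit` cannot inject more than the allowed number of messages.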
||||
async def run(self, spec: AgentRunSpec) -> AgentRunResult:
|
||||
hook = spec.hook or AgentHook()
|
||||
messages = list(spec.initial_messages)
|
||||
@ -98,21 +235,35 @@ class AgentRunner:
|
||||
external_lookup_counts: dict[str, int] = {}
|
||||
empty_content_retries = 0
|
||||
length_recovery_count = 0
|
||||
had_injections = False
|
||||
injection_cycles = 0
|
||||
|
||||
for iteration in range(spec.max_iterations):
|
||||
try:
|
||||
messages = self._backfill_missing_tool_results(messages)
|
||||
messages = self._microcompact(messages)
|
||||
messages = self._apply_tool_result_budget(spec, messages)
|
||||
messages_for_model = self._snip_history(spec, messages)
|
||||
# Keep the persisted conversation untouched. Context governance
|
||||
# may repair or compact historical messages for the model, but
|
||||
# those synthetic edits must not shift the append boundary used
|
||||
# later when the caller saves only the new turn.
|
||||
messages_for_model = self._drop_orphan_tool_results(messages)
|
||||
messages_for_model = self._backfill_missing_tool_results(messages_for_model)
|
||||
messages_for_model = self._microcompact(messages_for_model)
|
||||
messages_for_model = self._apply_tool_result_budget(spec, messages_for_model)
|
||||
messages_for_model = self._snip_history(spec, messages_for_model)
|
||||
# Snipping may have created new orphans; clean them up.
|
||||
messages_for_model = self._drop_orphan_tool_results(messages_for_model)
|
||||
messages_for_model = self._backfill_missing_tool_results(messages_for_model)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Context governance failed on turn {} for {}: {}; using raw messages",
|
||||
"Context governance failed on turn {} for {}: {}; applying minimal repair",
|
||||
iteration,
|
||||
spec.session_key or "default",
|
||||
exc,
|
||||
)
|
||||
messages_for_model = messages
|
||||
try:
|
||||
messages_for_model = self._drop_orphan_tool_results(messages)
|
||||
messages_for_model = self._backfill_missing_tool_results(messages_for_model)
|
||||
except Exception:
|
||||
messages_for_model = messages
|
||||
context = AgentHookContext(iteration=iteration, messages=messages)
|
||||
await hook.before_iteration(context)
|
||||
response = await self._request_model(spec, messages_for_model, hook, context)
|
||||
@ -156,16 +307,6 @@ class AgentRunner:
|
||||
tool_events.extend(new_events)
|
||||
context.tool_results = list(results)
|
||||
context.tool_events = list(new_events)
|
||||
if fatal_error is not None:
|
||||
error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
|
||||
final_content = error
|
||||
stop_reason = "tool_error"
|
||||
self._append_final_message(messages, final_content)
|
||||
context.final_content = final_content
|
||||
context.error = error
|
||||
context.stop_reason = stop_reason
|
||||
await hook.after_iteration(context)
|
||||
break
|
||||
completed_tool_results: list[dict[str, Any]] = []
|
||||
for tool_call, result in zip(response.tool_calls, results):
|
||||
tool_message = {
|
||||
@ -181,6 +322,23 @@ class AgentRunner:
|
||||
}
|
||||
messages.append(tool_message)
|
||||
completed_tool_results.append(tool_message)
|
||||
if fatal_error is not None:
|
||||
error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
|
||||
final_content = error
|
||||
stop_reason = "tool_error"
|
||||
self._append_final_message(messages, final_content)
|
||||
context.final_content = final_content
|
||||
context.error = error
|
||||
context.stop_reason = stop_reason
|
||||
await hook.after_iteration(context)
|
||||
should_continue, injection_cycles = await self._try_drain_injections(
|
||||
spec, messages, None, injection_cycles,
|
||||
phase="after tool error",
|
||||
)
|
||||
if should_continue:
|
||||
had_injections = True
|
||||
continue
|
||||
break
|
||||
await self._emit_checkpoint(
|
||||
spec,
|
||||
{
|
||||
@ -194,6 +352,13 @@ class AgentRunner:
|
||||
)
|
||||
empty_content_retries = 0
|
||||
length_recovery_count = 0
|
||||
# Checkpoint 1: drain injections after tools, before next LLM call
|
||||
_drained, injection_cycles = await self._try_drain_injections(
|
||||
spec, messages, None, injection_cycles,
|
||||
phase="after tool execution",
|
||||
)
|
||||
if _drained:
|
||||
had_injections = True
|
||||
await hook.after_iteration(context)
|
||||
continue
|
||||
|
||||
@ -250,18 +415,48 @@ class AgentRunner:
|
||||
await hook.after_iteration(context)
|
||||
continue
|
||||
|
||||
assistant_message: dict[str, Any] | None = None
|
||||
if response.finish_reason != "error" and not is_blank_text(clean):
|
||||
assistant_message = build_assistant_message(
|
||||
clean,
|
||||
reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
)
|
||||
|
||||
# Check for mid-turn injections BEFORE signaling stream end.
|
||||
# If injections are found we keep the stream alive (resuming=True)
|
||||
# so streaming channels don't prematurely finalize the card.
|
||||
should_continue, injection_cycles = await self._try_drain_injections(
|
||||
spec, messages, assistant_message, injection_cycles,
|
||||
phase="after final response",
|
||||
iteration=iteration,
|
||||
)
|
||||
if should_continue:
|
||||
had_injections = True
|
||||
|
||||
if hook.wants_streaming():
|
||||
await hook.on_stream_end(context, resuming=False)
|
||||
await hook.on_stream_end(context, resuming=should_continue)
|
||||
|
||||
if should_continue:
|
||||
await hook.after_iteration(context)
|
||||
continue
|
||||
|
||||
if response.finish_reason == "error":
|
||||
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
|
||||
stop_reason = "error"
|
||||
error = final_content
|
||||
self._append_final_message(messages, final_content)
|
||||
self._append_model_error_placeholder(messages)
|
||||
context.final_content = final_content
|
||||
context.error = error
|
||||
context.stop_reason = stop_reason
|
||||
await hook.after_iteration(context)
|
||||
should_continue, injection_cycles = await self._try_drain_injections(
|
||||
spec, messages, None, injection_cycles,
|
||||
phase="after LLM error",
|
||||
)
|
||||
if should_continue:
|
||||
had_injections = True
|
||||
continue
|
||||
break
|
||||
if is_blank_text(clean):
|
||||
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
@ -272,9 +467,16 @@ class AgentRunner:
|
||||
context.error = error
|
||||
context.stop_reason = stop_reason
|
||||
await hook.after_iteration(context)
|
||||
should_continue, injection_cycles = await self._try_drain_injections(
|
||||
spec, messages, None, injection_cycles,
|
||||
phase="after empty response",
|
||||
)
|
||||
if should_continue:
|
||||
had_injections = True
|
||||
continue
|
||||
break
|
||||
|
||||
messages.append(build_assistant_message(
|
||||
messages.append(assistant_message or build_assistant_message(
|
||||
clean,
|
||||
reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
@ -308,6 +510,17 @@ class AgentRunner:
|
||||
max_iterations=spec.max_iterations,
|
||||
)
|
||||
self._append_final_message(messages, final_content)
|
||||
# Drain any remaining injections so they are appended to the
|
||||
# conversation history instead of being re-published as
|
||||
# independent inbound messages by _dispatch's finally block.
|
||||
# We ignore should_continue here because the for-loop has already
|
||||
# exhausted all iterations.
|
||||
drained_after_max_iterations, injection_cycles = await self._try_drain_injections(
|
||||
spec, messages, None, injection_cycles,
|
||||
phase="after max_iterations",
|
||||
)
|
||||
if drained_after_max_iterations:
|
||||
had_injections = True
|
||||
|
||||
return AgentRunResult(
|
||||
final_content=final_content,
|
||||
@ -317,6 +530,7 @@ class AgentRunner:
|
||||
stop_reason=stop_reason,
|
||||
error=error,
|
||||
tool_events=tool_events,
|
||||
had_injections=had_injections,
|
||||
)
|
||||
|
||||
def _build_request_kwargs(
|
||||
@ -521,6 +735,12 @@ class AgentRunner:
|
||||
return
|
||||
messages.append(build_assistant_message(content))
|
||||
|
||||
@staticmethod
|
||||
def _append_model_error_placeholder(messages: list[dict[str, Any]]) -> None:
|
||||
if messages and messages[-1].get("role") == "assistant" and not messages[-1].get("tool_calls"):
|
||||
return
|
||||
messages.append(build_assistant_message(_PERSISTED_MODEL_ERROR_PLACEHOLDER))
|
||||
|
||||
def _normalize_tool_result(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
@ -549,6 +769,32 @@ class AgentRunner:
|
||||
return truncate_text(content, spec.max_tool_result_chars)
|
||||
return content
|
||||
|
||||
@staticmethod
|
||||
def _drop_orphan_tool_results(
|
||||
messages: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Drop tool results that have no matching assistant tool_call earlier in the history."""
|
||||
declared: set[str] = set()
|
||||
updated: list[dict[str, Any]] | None = None
|
||||
for idx, msg in enumerate(messages):
|
||||
role = msg.get("role")
|
||||
if role == "assistant":
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
if isinstance(tc, dict) and tc.get("id"):
|
||||
declared.add(str(tc["id"]))
|
||||
if role == "tool":
|
||||
tid = msg.get("tool_call_id")
|
||||
if tid and str(tid) not in declared:
|
||||
if updated is None:
|
||||
updated = [dict(m) for m in messages[:idx]]
|
||||
continue
|
||||
if updated is not None:
|
||||
updated.append(dict(msg))
|
||||
|
||||
if updated is None:
|
||||
return messages
|
||||
return updated
|
||||
|
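The orphan filter can be exercised on a short history. The sketch below is a simplified variant of `_drop_orphan_tool_results`: it always builds a new list instead of copying lazily, which is an assumption made for brevity and changes nothing about which messages are kept.

```python
from typing import Any


def drop_orphan_tool_results(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Keep tool results only if an earlier assistant message declared their id."""
    declared: set[str] = set()
    kept: list[dict[str, Any]] = []
    for msg in messages:
        role = msg.get("role")
        if role == "assistant":
            for tool_call in msg.get("tool_calls") or []:
                if isinstance(tool_call, dict) and tool_call.get("id"):
                    declared.add(str(tool_call["id"]))
        if role == "tool":
            tool_call_id = msg.get("tool_call_id")
            if tool_call_id and str(tool_call_id) not in declared:
                continue  # orphan: no matching tool_call seen so far
        kept.append(msg)
    return kept


history = [
    {"role": "user", "content": "list files"},
    {"role": "assistant", "content": None, "tool_calls": [{"id": "call_a"}]},
    {"role": "tool", "tool_call_id": "call_a", "content": "a.txt"},
    {"role": "tool", "tool_call_id": "call_b", "content": "stale result"},
]
cleaned = drop_orphan_tool_results(history)
```

Only the `call_b` result is removed, since no assistant message before it declared that id. The ordering matters: a result that appears before its declaring assistant message is treated as an orphan too.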
||||
@staticmethod
|
||||
def _backfill_missing_tool_results(
|
||||
messages: list[dict[str, Any]],
|
||||
|
||||
@ -28,10 +28,11 @@ class SkillsLoader:
|
||||
specific tools or perform certain tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, workspace: Path, builtin_skills_dir: Path | None = None):
|
||||
def __init__(self, workspace: Path, builtin_skills_dir: Path | None = None, disabled_skills: set[str] | None = None):
|
||||
self.workspace = workspace
|
||||
self.workspace_skills = workspace / "skills"
|
||||
self.builtin_skills = builtin_skills_dir or BUILTIN_SKILLS_DIR
|
||||
self.disabled_skills = disabled_skills or set()
|
||||
|
||||
def _skill_entries_from_dir(self, base: Path, source: str, *, skip_names: set[str] | None = None) -> list[dict[str, str]]:
|
||||
if not base.exists():
|
||||
@ -66,6 +67,9 @@ class SkillsLoader:
|
||||
self._skill_entries_from_dir(self.builtin_skills, "builtin", skip_names=workspace_names)
|
||||
)
|
||||
|
||||
if self.disabled_skills:
|
||||
skills = [s for s in skills if s["name"] not in self.disabled_skills]
|
||||
|
||||
if filter_unavailable:
|
||||
return [skill for skill in skills if self._check_requirements(self._get_skill_meta(skill["name"]))]
|
||||
return skills
|
||||
|
||||
@ -27,6 +27,7 @@ class _SubagentHook(AgentHook):
|
||||
"""Logging-only hook for subagent execution."""
|
||||
|
||||
def __init__(self, task_id: str) -> None:
|
||||
super().__init__()
|
||||
self._task_id = task_id
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
@ -51,6 +52,7 @@ class SubagentManager:
|
||||
web_config: "WebToolsConfig | None" = None,
|
||||
exec_config: "ExecToolConfig | None" = None,
|
||||
restrict_to_workspace: bool = False,
|
||||
disabled_skills: list[str] | None = None,
|
||||
):
|
||||
from nanobot.config.schema import ExecToolConfig
|
||||
|
||||
@ -62,6 +64,7 @@ class SubagentManager:
|
||||
self.max_tool_result_chars = max_tool_result_chars
|
||||
self.exec_config = exec_config or ExecToolConfig()
|
||||
self.restrict_to_workspace = restrict_to_workspace
|
||||
self.disabled_skills = set(disabled_skills or [])
|
||||
self.runner = AgentRunner(provider)
|
||||
self._running_tasks: dict[str, asyncio.Task[None]] = {}
|
||||
self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...}
|
||||
@ -235,7 +238,10 @@ class SubagentManager:
|
||||
from nanobot.agent.skills import SkillsLoader
|
||||
|
||||
time_ctx = ContextBuilder._build_runtime_context(None, None)
|
||||
skills_summary = SkillsLoader(self.workspace).build_skills_summary()
|
||||
skills_summary = SkillsLoader(
|
||||
self.workspace,
|
||||
disabled_skills=self.disabled_skills,
|
||||
).build_skills_summary()
|
||||
return render_template(
|
||||
"agent/subagent_system.md",
|
||||
time_ctx=time_ctx,
|
||||
|
||||
105
nanobot/agent/tools/file_state.py
Normal file
@ -0,0 +1,105 @@
"""Track file-read state for read-before-edit warnings and read deduplication."""

from __future__ import annotations

import hashlib
import os
from dataclasses import dataclass
from pathlib import Path


@dataclass(slots=True)
class ReadState:
    mtime: float
    offset: int
    limit: int | None
    content_hash: str | None
    can_dedup: bool


_state: dict[str, ReadState] = {}


def _hash_file(p: str) -> str | None:
    try:
        return hashlib.sha256(Path(p).read_bytes()).hexdigest()
    except OSError:
        return None


def record_read(path: str | Path, offset: int = 1, limit: int | None = None) -> None:
    """Record that a file was read (called after successful read)."""
    p = str(Path(path).resolve())
    try:
        mtime = os.path.getmtime(p)
    except OSError:
        return
    _state[p] = ReadState(
        mtime=mtime,
        offset=offset,
        limit=limit,
        content_hash=_hash_file(p),
        can_dedup=True,
    )


def record_write(path: str | Path) -> None:
    """Record that a file was written (updates mtime in state)."""
    p = str(Path(path).resolve())
    try:
        mtime = os.path.getmtime(p)
    except OSError:
        _state.pop(p, None)
        return
    _state[p] = ReadState(
        mtime=mtime,
        offset=1,
        limit=None,
        content_hash=_hash_file(p),
        can_dedup=False,
    )


def check_read(path: str | Path) -> str | None:
    """Check if a file has been read and is fresh.

    Returns None if OK, or a warning string.
    When mtime changed but file content is identical (e.g. touch, editor save),
    the check passes to avoid false-positive staleness warnings.
    """
    p = str(Path(path).resolve())
    entry = _state.get(p)
    if entry is None:
        return "Warning: file has not been read yet. Read it first to verify content before editing."
    try:
        current_mtime = os.path.getmtime(p)
    except OSError:
        return None
    if current_mtime != entry.mtime:
        if entry.content_hash and _hash_file(p) == entry.content_hash:
            entry.mtime = current_mtime
            return None
        return "Warning: file has been modified since last read. Re-read to verify content before editing."
    return None


def is_unchanged(path: str | Path, offset: int = 1, limit: int | None = None) -> bool:
    """Return True if file was previously read with same params and mtime is unchanged."""
    p = str(Path(path).resolve())
    entry = _state.get(p)
    if entry is None:
        return False
    if not entry.can_dedup:
        return False
    if entry.offset != offset or entry.limit != limit:
        return False
    try:
        current_mtime = os.path.getmtime(p)
    except OSError:
        return False
    return current_mtime == entry.mtime


def clear() -> None:
    """Clear all tracked state (useful for testing)."""
    _state.clear()
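The staleness rule in `check_read` (mtime changed, but identical content still counts as fresh) can be tried in isolation. The sketch below is not the module itself: `sha256_of` and `is_stale` are hypothetical helper names, and the temporary file stands in for a file the agent has read.

```python
import hashlib
import os
import tempfile
from pathlib import Path


def sha256_of(path: Path) -> str:
    """Hash the full file content, mirroring what the diff's _hash_file does."""
    return hashlib.sha256(path.read_bytes()).hexdigest()


def is_stale(path: Path, seen_mtime: float, seen_hash: str) -> bool:
    """True only if the mtime moved AND the bytes differ from what was read."""
    if os.path.getmtime(path) == seen_mtime:
        return False
    return sha256_of(path) != seen_hash


with tempfile.TemporaryDirectory() as tmp:
    f = Path(tmp) / "note.txt"
    f.write_text("hello", encoding="utf-8")
    seen_mtime, seen_hash = os.path.getmtime(f), sha256_of(f)

    # Simulate `touch`: bump the mtime without changing the content.
    os.utime(f, (seen_mtime + 10, seen_mtime + 10))
    touched_is_stale = is_stale(f, seen_mtime, seen_hash)

    # A real edit changes the bytes, so the hash no longer matches.
    f.write_text("hello, world", encoding="utf-8")
    os.utime(f, (seen_mtime + 20, seen_mtime + 20))
    edited_is_stale = is_stale(f, seen_mtime, seen_hash)

print(touched_is_stale, edited_is_stale)  # False True
```

Comparing the mtime first keeps the common case cheap: the file is hashed again only when the timestamp has actually moved.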
@ -2,11 +2,13 @@
|
||||
|
||||
import difflib
|
||||
import mimetypes
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||
from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema
|
||||
from nanobot.agent.tools import file_state
|
||||
from nanobot.utils.helpers import build_image_content_blocks, detect_image_mime
|
||||
from nanobot.config.paths import get_media_dir
|
||||
|
||||
@ -60,6 +62,36 @@ class _FsTool(Tool):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_BLOCKED_DEVICE_PATHS = frozenset({
|
||||
"/dev/zero", "/dev/random", "/dev/urandom", "/dev/full",
|
||||
"/dev/stdin", "/dev/stdout", "/dev/stderr",
|
||||
"/dev/tty", "/dev/console",
|
||||
"/dev/fd/0", "/dev/fd/1", "/dev/fd/2",
|
||||
})
|
||||
|
||||
|
||||
def _is_blocked_device(path: str | Path) -> bool:
|
||||
"""Check if path is a blocked device that could hang or produce infinite output."""
|
||||
import re
|
||||
raw = str(path)
|
||||
if raw in _BLOCKED_DEVICE_PATHS:
|
||||
return True
|
||||
if re.match(r"/proc/\d+/fd/[012]$", raw) or re.match(r"/proc/self/fd/[012]$", raw):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _parse_page_range(pages: str, total: int) -> tuple[int, int]:
|
||||
"""Parse a page range like '2-5' into 0-based (start, end) inclusive."""
|
||||
parts = pages.strip().split("-")
|
||||
if len(parts) == 1:
|
||||
p = int(parts[0])
|
||||
return max(0, p - 1), min(p - 1, total - 1)
|
||||
start = int(parts[0])
|
||||
end = int(parts[1])
|
||||
return max(0, start - 1), min(end - 1, total - 1)
|
||||
|
||||
|
||||
@tool_parameters(
|
||||
tool_parameters_schema(
|
||||
path=StringSchema("The file path to read"),
|
||||
@ -73,6 +105,7 @@ class _FsTool(Tool):
|
||||
description="Maximum number of lines to read (default 2000)",
|
||||
minimum=1,
|
||||
),
|
||||
pages=StringSchema("Page range for PDF files, e.g. '1-5' (default: all, max 20 pages)"),
|
||||
required=["path"],
|
||||
)
|
||||
)
|
||||
@ -81,6 +114,7 @@ class ReadFileTool(_FsTool):
|
||||
|
||||
_MAX_CHARS = 128_000
|
||||
_DEFAULT_LIMIT = 2000
|
||||
_MAX_PDF_PAGES = 20
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@ -89,9 +123,10 @@ class ReadFileTool(_FsTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Read a text file. Output format: LINE_NUM|CONTENT. "
|
||||
"Read a file (text or image). Text output format: LINE_NUM|CONTENT. "
|
||||
"Images return visual content for analysis. "
|
||||
"Use offset and limit for large files. "
|
||||
"Cannot read binary files or images. "
|
||||
"Cannot read non-image binary files. "
|
||||
"Reads exceeding ~128K chars are truncated."
|
||||
)
|
||||
|
||||
@ -99,16 +134,27 @@ class ReadFileTool(_FsTool):
|
||||
def read_only(self) -> bool:
|
||||
return True
|
||||
|
||||
async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, **kwargs: Any) -> Any:
|
||||
async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, pages: str | None = None, **kwargs: Any) -> Any:
|
||||
try:
|
||||
if not path:
|
||||
return "Error reading file: Unknown path"
|
||||
|
||||
# Device path blacklist
|
||||
if _is_blocked_device(path):
|
||||
return f"Error: Reading {path} is blocked (device path that could hang or produce infinite output)."
|
||||
|
||||
fp = self._resolve(path)
|
||||
if _is_blocked_device(fp):
|
||||
return f"Error: Reading {fp} is blocked (device path that could hang or produce infinite output)."
|
||||
if not fp.exists():
|
||||
return f"Error: File not found: {path}"
|
||||
if not fp.is_file():
|
||||
return f"Error: Not a file: {path}"
|
||||
|
||||
# PDF support
|
||||
if fp.suffix.lower() == ".pdf":
|
||||
return self._read_pdf(fp, pages)
|
||||
|
||||
raw = fp.read_bytes()
|
||||
if not raw:
|
||||
return f"(Empty file: {path})"
|
||||
@ -117,6 +163,10 @@ class ReadFileTool(_FsTool):
|
||||
if mime and mime.startswith("image/"):
|
||||
return build_image_content_blocks(raw, mime, str(fp), f"(Image file: {path})")
|
||||
|
||||
# Read dedup: same path + offset + limit + unchanged mtime → stub
|
||||
if file_state.is_unchanged(fp, offset=offset, limit=limit):
|
||||
return f"[File unchanged since last read: {path}]"
|
||||
|
||||
try:
|
||||
text_content = raw.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
@ -149,12 +199,59 @@ class ReadFileTool(_FsTool):
|
||||
result += f"\n\n(Showing lines {offset}-{end} of {total}. Use offset={end + 1} to continue.)"
|
||||
else:
|
||||
result += f"\n\n(End of file — {total} lines total)"
|
||||
file_state.record_read(fp, offset=offset, limit=limit)
|
||||
return result
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error reading file: {e}"
|
||||
|
||||
def _read_pdf(self, fp: Path, pages: str | None) -> str:
|
||||
try:
|
||||
import fitz # pymupdf
|
||||
except ImportError:
|
||||
return "Error: PDF reading requires pymupdf. Install with: pip install pymupdf"
|
||||
|
||||
try:
|
||||
doc = fitz.open(str(fp))
|
||||
except Exception as e:
|
||||
return f"Error reading PDF: {e}"
|
||||
|
||||
total_pages = len(doc)
|
||||
if pages:
|
||||
try:
|
||||
start, end = _parse_page_range(pages, total_pages)
|
||||
except (ValueError, IndexError):
|
||||
doc.close()
|
||||
return f"Error: Invalid page range '{pages}'. Use format like '1-5'."
|
||||
if start > end or start >= total_pages:
|
||||
doc.close()
|
||||
return f"Error: Page range '{pages}' is out of bounds (document has {total_pages} pages)."
|
||||
else:
|
||||
start = 0
|
||||
end = min(total_pages - 1, self._MAX_PDF_PAGES - 1)
|
||||
|
||||
if end - start + 1 > self._MAX_PDF_PAGES:
|
||||
end = start + self._MAX_PDF_PAGES - 1
|
||||
|
||||
parts: list[str] = []
|
||||
for i in range(start, end + 1):
|
||||
page = doc[i]
|
||||
text = page.get_text().strip()
|
||||
if text:
|
||||
parts.append(f"--- Page {i + 1} ---\n{text}")
|
||||
doc.close()
|
||||
|
||||
if not parts:
|
||||
return f"(PDF has no extractable text: {fp})"
|
||||
|
||||
result = "\n\n".join(parts)
|
||||
if end < total_pages - 1:
|
||||
result += f"\n\n(Showing pages {start + 1}-{end + 1} of {total_pages}. Use pages='{end + 2}-{min(end + 1 + self._MAX_PDF_PAGES, total_pages)}' to continue.)"
|
||||
if len(result) > self._MAX_CHARS:
|
||||
result = result[:self._MAX_CHARS] + "\n\n(PDF text truncated at ~128K chars)"
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# write_file
|
||||
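The page handling in `_read_pdf` combines two steps: converting a 1-based range string into 0-based inclusive indices, then capping the window at `_MAX_PDF_PAGES`. A minimal sketch of both steps follows. `parse_page_range` restates `_parse_page_range` from the diff in a slightly condensed form, and `clamp_window` is a hypothetical name for the inline cap.

```python
def parse_page_range(pages: str, total: int) -> tuple[int, int]:
    """Parse '3' or '2-5' (1-based) into 0-based inclusive (start, end)."""
    parts = pages.strip().split("-")
    first = int(parts[0])
    last = int(parts[1]) if len(parts) > 1 else first
    # Clamp the start at page 0 and the end at the last page of the document.
    return max(0, first - 1), min(last - 1, total - 1)


def clamp_window(start: int, end: int, max_pages: int) -> tuple[int, int]:
    """Shrink the window so it never spans more than max_pages pages."""
    if end - start + 1 > max_pages:
        end = start + max_pages - 1
    return start, end


print(parse_page_range("2-5", total=10))   # (1, 4)
print(parse_page_range("7", total=10))     # (6, 6)
print(parse_page_range("8-50", total=10))  # (7, 9)
print(clamp_window(0, 49, max_pages=20))   # (0, 19)
```

A non-numeric input such as `"a-b"` raises `ValueError` from `int()`, which is the case the `except (ValueError, IndexError)` branch in the diff turns into an "Invalid page range" message.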
@ -192,6 +289,7 @@ class WriteFileTool(_FsTool):
|
||||
fp = self._resolve(path)
|
||||
fp.parent.mkdir(parents=True, exist_ok=True)
|
||||
fp.write_text(content, encoding="utf-8")
|
||||
file_state.record_write(fp)
|
||||
return f"Successfully wrote {len(content)} characters to {fp}"
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
@ -203,30 +301,269 @@ class WriteFileTool(_FsTool):
|
||||
# edit_file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_QUOTE_TABLE = str.maketrans({
|
||||
"\u2018": "'", "\u2019": "'", # curly single → straight
|
||||
"\u201c": '"', "\u201d": '"', # curly double → straight
|
||||
"'": "'", '"': '"', # identity (kept for completeness)
|
||||
})
|
||||
|
||||
|
||||
def _normalize_quotes(s: str) -> str:
|
||||
return s.translate(_QUOTE_TABLE)
|
||||
|
||||
|
||||
def _curly_double_quotes(text: str) -> str:
|
||||
parts: list[str] = []
|
||||
opening = True
|
||||
for ch in text:
|
||||
if ch == '"':
|
||||
parts.append("\u201c" if opening else "\u201d")
|
||||
opening = not opening
|
||||
else:
|
||||
parts.append(ch)
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def _curly_single_quotes(text: str) -> str:
|
||||
parts: list[str] = []
|
||||
opening = True
|
||||
for i, ch in enumerate(text):
|
||||
if ch != "'":
|
||||
parts.append(ch)
|
||||
continue
|
||||
prev_ch = text[i - 1] if i > 0 else ""
|
||||
next_ch = text[i + 1] if i + 1 < len(text) else ""
|
||||
if prev_ch.isalnum() and next_ch.isalnum():
|
||||
parts.append("\u2019")
|
||||
continue
|
||||
parts.append("\u2018" if opening else "\u2019")
|
||||
opening = not opening
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def _preserve_quote_style(old_text: str, actual_text: str, new_text: str) -> str:
|
||||
"""Preserve curly quote style when a quote-normalized fallback matched."""
|
||||
if _normalize_quotes(old_text.strip()) != _normalize_quotes(actual_text.strip()) or old_text == actual_text:
|
||||
return new_text
|
||||
|
||||
styled = new_text
|
||||
if any(ch in actual_text for ch in ("\u201c", "\u201d")) and '"' in styled:
|
||||
styled = _curly_double_quotes(styled)
|
||||
if any(ch in actual_text for ch in ("\u2018", "\u2019")) and "'" in styled:
|
||||
styled = _curly_single_quotes(styled)
|
||||
return styled
|
||||
|
||||
|
||||
def _leading_ws(line: str) -> str:
|
||||
return line[: len(line) - len(line.lstrip(" \t"))]
|
||||
|
||||
|
||||
def _reindent_like_match(old_text: str, actual_text: str, new_text: str) -> str:
|
||||
"""Preserve the outer indentation from the actual matched block."""
|
||||
old_lines = old_text.split("\n")
|
||||
actual_lines = actual_text.split("\n")
|
||||
if len(old_lines) != len(actual_lines):
|
||||
return new_text
|
||||
|
||||
comparable = [
|
||||
(old_line, actual_line)
|
||||
for old_line, actual_line in zip(old_lines, actual_lines)
|
||||
if old_line.strip() and actual_line.strip()
|
||||
]
|
||||
if not comparable or any(
|
||||
_normalize_quotes(old_line.strip()) != _normalize_quotes(actual_line.strip())
|
||||
for old_line, actual_line in comparable
|
||||
):
|
||||
return new_text
|
||||
|
||||
old_ws = _leading_ws(comparable[0][0])
|
||||
actual_ws = _leading_ws(comparable[0][1])
|
||||
if actual_ws == old_ws:
|
||||
return new_text
|
||||
|
||||
if old_ws:
|
||||
if not actual_ws.startswith(old_ws):
|
||||
return new_text
|
||||
delta = actual_ws[len(old_ws):]
|
||||
else:
|
||||
delta = actual_ws
|
||||
|
||||
if not delta:
|
||||
return new_text
|
||||
|
||||
return "\n".join((delta + line) if line else line for line in new_text.split("\n"))
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _MatchSpan:
|
||||
start: int
|
||||
end: int
|
||||
text: str
|
||||
line: int
|
||||
|
||||
|
||||
def _find_exact_matches(content: str, old_text: str) -> list[_MatchSpan]:
|
||||
matches: list[_MatchSpan] = []
|
||||
start = 0
|
||||
while True:
|
||||
idx = content.find(old_text, start)
|
||||
if idx == -1:
|
||||
break
|
||||
matches.append(
|
||||
_MatchSpan(
|
||||
start=idx,
|
||||
end=idx + len(old_text),
|
||||
text=content[idx : idx + len(old_text)],
|
||||
line=content.count("\n", 0, idx) + 1,
|
||||
)
|
||||
)
|
||||
start = idx + max(1, len(old_text))
|
||||
return matches
|
||||
|
||||
|
||||
def _find_trim_matches(content: str, old_text: str, *, normalize_quotes: bool = False) -> list[_MatchSpan]:
|
||||
old_lines = old_text.splitlines()
|
||||
if not old_lines:
|
||||
return []
|
||||
|
||||
content_lines = content.splitlines()
|
||||
content_lines_keepends = content.splitlines(keepends=True)
|
||||
if len(content_lines) < len(old_lines):
|
||||
return []
|
||||
|
||||
offsets: list[int] = []
|
||||
pos = 0
|
||||
for line in content_lines_keepends:
|
||||
offsets.append(pos)
|
||||
pos += len(line)
|
||||
offsets.append(pos)
|
||||
|
||||
if normalize_quotes:
|
||||
stripped_old = [_normalize_quotes(line.strip()) for line in old_lines]
|
||||
else:
|
||||
stripped_old = [line.strip() for line in old_lines]
|
||||
|
||||
matches: list[_MatchSpan] = []
|
||||
window_size = len(stripped_old)
|
||||
for i in range(len(content_lines) - window_size + 1):
|
||||
window = content_lines[i : i + window_size]
|
||||
if normalize_quotes:
|
||||
comparable = [_normalize_quotes(line.strip()) for line in window]
|
||||
else:
|
||||
comparable = [line.strip() for line in window]
|
||||
if comparable != stripped_old:
|
||||
continue
|
||||
|
||||
start = offsets[i]
|
||||
end = offsets[i + window_size]
|
||||
if content_lines_keepends[i + window_size - 1].endswith("\n"):
|
||||
end -= 1
|
||||
matches.append(
|
||||
_MatchSpan(
|
||||
start=start,
|
||||
end=end,
|
||||
text=content[start:end],
|
||||
line=i + 1,
|
||||
)
|
||||
)
|
||||
return matches
|
||||
|
||||
|
||||
def _find_quote_matches(content: str, old_text: str) -> list[_MatchSpan]:
|
||||
norm_content = _normalize_quotes(content)
|
||||
norm_old = _normalize_quotes(old_text)
|
||||
matches: list[_MatchSpan] = []
|
||||
start = 0
|
||||
while True:
|
||||
idx = norm_content.find(norm_old, start)
|
||||
if idx == -1:
|
||||
break
|
||||
matches.append(
|
||||
_MatchSpan(
|
||||
start=idx,
|
||||
end=idx + len(old_text),
|
||||
text=content[idx : idx + len(old_text)],
|
||||
line=content.count("\n", 0, idx) + 1,
|
||||
)
|
||||
)
|
||||
start = idx + max(1, len(norm_old))
|
||||
return matches
|
||||
|
||||
|
||||
def _find_matches(content: str, old_text: str) -> list[_MatchSpan]:
|
||||
"""Locate all matches using progressively looser strategies."""
|
||||
for matcher in (
|
||||
lambda: _find_exact_matches(content, old_text),
|
||||
lambda: _find_trim_matches(content, old_text),
|
||||
lambda: _find_trim_matches(content, old_text, normalize_quotes=True),
|
||||
lambda: _find_quote_matches(content, old_text),
|
||||
):
|
||||
matches = matcher()
|
||||
if matches:
|
||||
return matches
|
||||
return []
|
||||
|
||||
|
||||
def _find_match_line_numbers(content: str, old_text: str) -> list[int]:
|
||||
"""Return 1-based starting line numbers for the current matching strategies."""
|
||||
return [match.line for match in _find_matches(content, old_text)]
|
||||
|
||||
|
||||
def _collapse_internal_whitespace(text: str) -> str:
|
||||
return "\n".join(" ".join(line.split()) for line in text.splitlines())
|
||||
|
||||
|
||||
def _diagnose_near_match(old_text: str, actual_text: str) -> list[str]:
|
||||
"""Return actionable hints describing why text was close but not exact."""
|
||||
hints: list[str] = []
|
||||
|
||||
if old_text.lower() == actual_text.lower() and old_text != actual_text:
|
||||
hints.append("letter case differs")
|
||||
if _collapse_internal_whitespace(old_text) == _collapse_internal_whitespace(actual_text) and old_text != actual_text:
|
||||
hints.append("whitespace differs")
|
||||
if old_text.rstrip("\n") == actual_text.rstrip("\n") and old_text != actual_text:
|
||||
hints.append("trailing newline differs")
|
||||
if _normalize_quotes(old_text) == _normalize_quotes(actual_text) and old_text != actual_text:
|
||||
hints.append("quote style differs")
|
||||
|
||||
return hints
|
||||
|
||||
|
||||
def _best_window(old_text: str, content: str) -> tuple[float, int, list[str], list[str]]:
|
||||
"""Find the closest line-window match and return ratio/start/snippet/hints."""
|
||||
lines = content.splitlines(keepends=True)
|
||||
old_lines = old_text.splitlines(keepends=True)
|
||||
window = max(1, len(old_lines))
|
||||
|
||||
best_ratio, best_start = -1.0, 0
|
||||
best_window_lines: list[str] = []
|
||||
|
||||
for i in range(max(1, len(lines) - window + 1)):
|
||||
current = lines[i : i + window]
|
||||
ratio = difflib.SequenceMatcher(None, old_lines, current).ratio()
|
||||
if ratio > best_ratio:
|
||||
best_ratio, best_start = ratio, i
|
||||
best_window_lines = current
|
||||
|
||||
actual_text = "".join(best_window_lines).replace("\r\n", "\n").rstrip("\n")
|
||||
hints = _diagnose_near_match(old_text.replace("\r\n", "\n").rstrip("\n"), actual_text)
|
||||
return best_ratio, best_start, best_window_lines, hints
|
||||
|
||||
|
||||
def _find_match(content: str, old_text: str) -> tuple[str | None, int]:
|
||||
"""Locate old_text in content: exact first, then line-trimmed sliding window.
|
||||
"""Locate old_text in content with a multi-level fallback chain:
|
||||
|
||||
1. Exact substring match
|
||||
2. Line-trimmed sliding window (handles indentation differences)
|
||||
3. Smart quote normalization (curly ↔ straight quotes)
|
||||
|
||||
Both inputs should use LF line endings (caller normalises CRLF).
|
||||
Returns (matched_fragment, count) or (None, 0).
|
||||
"""
|
||||
if old_text in content:
|
||||
return old_text, content.count(old_text)
|
||||
|
||||
old_lines = old_text.splitlines()
|
||||
if not old_lines:
|
||||
matches = _find_matches(content, old_text)
|
||||
if not matches:
|
||||
return None, 0
|
||||
stripped_old = [l.strip() for l in old_lines]
|
||||
content_lines = content.splitlines()
|
||||
|
||||
candidates = []
|
||||
for i in range(len(content_lines) - len(stripped_old) + 1):
|
||||
window = content_lines[i : i + len(stripped_old)]
|
||||
if [l.strip() for l in window] == stripped_old:
|
||||
candidates.append("\n".join(window))
|
||||
|
||||
if candidates:
|
||||
return candidates[0], len(candidates)
|
||||
return None, 0
|
||||
return matches[0].text, len(matches)
|
||||
|
||||
|
||||
@tool_parameters(
|
||||
@ -241,6 +578,9 @@ def _find_match(content: str, old_text: str) -> tuple[str | None, int]:
|
||||
class EditFileTool(_FsTool):
|
||||
"""Edit a file by replacing text with fallback matching."""
|
||||
|
||||
_MAX_EDIT_FILE_SIZE = 1024 * 1024 * 1024 # 1 GiB
|
||||
_MARKDOWN_EXTS = frozenset({".md", ".mdx", ".markdown"})
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "edit_file"
|
||||
@ -249,11 +589,16 @@ class EditFileTool(_FsTool):
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit a file by replacing old_text with new_text. "
|
||||
"Tolerates minor whitespace/indentation differences. "
|
||||
"Tolerates minor whitespace/indentation differences and curly/straight quote mismatches. "
|
||||
"If old_text matches multiple times, you must provide more context "
|
||||
"or set replace_all=true. Shows a diff of the closest match on failure."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _strip_trailing_ws(text: str) -> str:
|
||||
"""Strip trailing whitespace from each line."""
|
||||
return "\n".join(line.rstrip() for line in text.split("\n"))
|
||||
|
||||
async def execute(
|
||||
self, path: str | None = None, old_text: str | None = None,
|
||||
new_text: str | None = None,
|
||||
@ -267,55 +612,133 @@ class EditFileTool(_FsTool):
|
||||
if new_text is None:
|
||||
raise ValueError("Unknown new_text")
|
||||
|
||||
# .ipynb detection
|
||||
if path.endswith(".ipynb"):
|
||||
return "Error: This is a Jupyter notebook. Use the notebook_edit tool instead of edit_file."
|
||||
|
||||
fp = self._resolve(path)
|
||||
|
||||
# Create-file semantics: old_text='' + file doesn't exist → create
|
||||
if not fp.exists():
|
||||
return f"Error: File not found: {path}"
|
||||
if old_text == "":
|
||||
fp.parent.mkdir(parents=True, exist_ok=True)
|
||||
fp.write_text(new_text, encoding="utf-8")
|
||||
file_state.record_write(fp)
|
||||
return f"Successfully created {fp}"
|
||||
return self._file_not_found_msg(path, fp)
|
||||
|
||||
# File size protection
|
||||
try:
|
||||
fsize = fp.stat().st_size
|
||||
except OSError:
|
||||
fsize = 0
|
||||
if fsize > self._MAX_EDIT_FILE_SIZE:
|
||||
return f"Error: File too large to edit ({fsize / (1024**3):.1f} GiB). Maximum is 1 GiB."
|
||||
|
||||
# Create-file: old_text='' but file exists and not empty → reject
|
||||
if old_text == "":
|
||||
raw = fp.read_bytes()
|
||||
content = raw.decode("utf-8")
|
||||
if content.strip():
|
||||
return f"Error: Cannot create file — {path} already exists and is not empty."
|
||||
fp.write_text(new_text, encoding="utf-8")
|
||||
file_state.record_write(fp)
|
||||
return f"Successfully edited {fp}"
|
||||
|
||||
# Read-before-edit check
|
||||
warning = file_state.check_read(fp)
|
||||
|
||||
raw = fp.read_bytes()
|
||||
uses_crlf = b"\r\n" in raw
|
||||
content = raw.decode("utf-8").replace("\r\n", "\n")
|
||||
match, count = _find_match(content, old_text.replace("\r\n", "\n"))
|
||||
norm_old = old_text.replace("\r\n", "\n")
|
||||
matches = _find_matches(content, norm_old)
|
||||
|
||||
if match is None:
|
||||
if not matches:
|
||||
return self._not_found_msg(old_text, content, path)
|
||||
count = len(matches)
|
||||
if count > 1 and not replace_all:
|
||||
line_numbers = [match.line for match in matches]
|
||||
preview = ", ".join(f"line {n}" for n in line_numbers[:3])
|
||||
if len(line_numbers) > 3:
|
||||
preview += ", ..."
|
||||
location_hint = f" at {preview}" if preview else ""
|
||||
return (
|
||||
f"Warning: old_text appears {count} times. "
|
||||
f"Warning: old_text appears {count} times{location_hint}. "
|
||||
"Provide more context to make it unique, or set replace_all=true."
|
||||
)
|
||||
|
||||
norm_new = new_text.replace("\r\n", "\n")
|
||||
new_content = content.replace(match, norm_new) if replace_all else content.replace(match, norm_new, 1)
|
||||
|
||||
# Trailing whitespace stripping (skip markdown to preserve double-space line breaks)
|
||||
if fp.suffix.lower() not in self._MARKDOWN_EXTS:
|
||||
norm_new = self._strip_trailing_ws(norm_new)
|
||||
|
||||
selected = matches if replace_all else matches[:1]
|
||||
new_content = content
|
||||
for match in reversed(selected):
|
||||
replacement = _preserve_quote_style(norm_old, match.text, norm_new)
|
||||
replacement = _reindent_like_match(norm_old, match.text, replacement)
|
||||
|
||||
# Delete-line cleanup: when deleting text (new_text=''), consume trailing
|
||||
# newline to avoid leaving a blank line
|
||||
end = match.end
|
||||
if replacement == "" and not match.text.endswith("\n") and content[end:end + 1] == "\n":
|
||||
end += 1
|
||||
|
||||
new_content = new_content[: match.start] + replacement + new_content[end:]
|
||||
if uses_crlf:
|
||||
new_content = new_content.replace("\n", "\r\n")
|
||||
|
||||
fp.write_bytes(new_content.encode("utf-8"))
|
||||
return f"Successfully edited {fp}"
|
||||
file_state.record_write(fp)
|
||||
msg = f"Successfully edited {fp}"
|
||||
if warning:
|
||||
msg = f"{warning}\n{msg}"
|
||||
return msg
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error editing file: {e}"
|
||||
|
||||
def _file_not_found_msg(self, path: str, fp: Path) -> str:
|
||||
"""Build an error message with 'Did you mean ...?' suggestions."""
|
||||
parent = fp.parent
|
||||
suggestions: list[str] = []
|
||||
if parent.is_dir():
|
||||
siblings = [f.name for f in parent.iterdir() if f.is_file()]
|
||||
close = difflib.get_close_matches(fp.name, siblings, n=3, cutoff=0.6)
|
||||
suggestions = [str(parent / c) for c in close]
|
||||
parts = [f"Error: File not found: {path}"]
|
||||
if suggestions:
|
||||
parts.append("Did you mean: " + ", ".join(suggestions) + "?")
|
||||
return "\n".join(parts)
|
||||
|
||||
@staticmethod
|
||||
def _not_found_msg(old_text: str, content: str, path: str) -> str:
|
||||
lines = content.splitlines(keepends=True)
|
||||
old_lines = old_text.splitlines(keepends=True)
|
||||
window = len(old_lines)
|
||||
|
||||
best_ratio, best_start = 0.0, 0
|
||||
for i in range(max(1, len(lines) - window + 1)):
|
||||
ratio = difflib.SequenceMatcher(None, old_lines, lines[i : i + window]).ratio()
|
||||
if ratio > best_ratio:
|
||||
best_ratio, best_start = ratio, i
|
||||
|
||||
best_ratio, best_start, best_window_lines, hints = _best_window(old_text, content)
|
||||
if best_ratio > 0.5:
|
||||
diff = "\n".join(difflib.unified_diff(
|
||||
old_lines, lines[best_start : best_start + window],
|
||||
old_text.splitlines(keepends=True),
|
||||
best_window_lines,
|
||||
fromfile="old_text (provided)",
|
||||
tofile=f"{path} (actual, line {best_start + 1})",
|
||||
lineterm="",
|
||||
))
|
||||
return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
|
||||
hint_text = ""
|
||||
if hints:
|
||||
hint_text = "\nPossible cause: " + ", ".join(hints) + "."
|
||||
return (
|
||||
f"Error: old_text not found in {path}."
|
||||
f"{hint_text}\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
|
||||
)
|
||||
|
||||
if hints:
|
||||
return (
|
||||
f"Error: old_text not found in {path}. "
|
||||
f"Possible cause: {', '.join(hints)}. "
|
||||
"Copy the exact text from read_file and try again."
|
||||
)
|
||||
return f"Error: old_text not found in {path}. No similar text found. Verify the file content."
|
||||
|
||||
|
||||
|
||||
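The matching chain behind `edit_file` tries progressively looser strategies and stops at the first one that finds anything. The sketch below condenses that idea into one function for illustration. It is an assumption-based simplification, not the code in the diff: `find_block` and `QUOTES` are hypothetical names, it reports only the first hit, and it omits the final character-level quote fallback.

```python
from __future__ import annotations

# Map curly quotes to their straight equivalents (one character each).
QUOTES = str.maketrans({"\u2018": "'", "\u2019": "'", "\u201c": '"', "\u201d": '"'})


def find_block(content: str, old_text: str) -> tuple[str, int] | None:
    """Return (strategy, 1-based line) of the first match, loosest strategy last."""
    idx = content.find(old_text)
    if idx != -1:
        return "exact", content.count("\n", 0, idx) + 1

    wanted = [line.strip() for line in old_text.splitlines()]
    lines = content.splitlines()
    for normalize in (False, True):
        target = [w.translate(QUOTES) for w in wanted] if normalize else wanted
        # Slide a window of len(target) lines over the file, comparing trimmed lines.
        for i in range(len(lines) - len(target) + 1):
            window = [line.strip() for line in lines[i : i + len(target)]]
            if normalize:
                window = [w.translate(QUOTES) for w in window]
            if target and window == target:
                return ("trimmed+quotes" if normalize else "trimmed"), i + 1
    return None


source = "def greet():\n    print(\u201chi\u201d)\n    return 1\n"
print(find_block(source, "    return 1"))                       # ('exact', 3)
print(find_block(source, "print(\u201chi\u201d)\n  return 1"))  # ('trimmed', 2)
print(find_block(source, 'print("hi")'))                        # ('trimmed+quotes', 2)
print(find_block(source, 'print("bye")'))                       # None
```

Ordering the strategies from strict to loose means an exact hit is never overridden by a fuzzier one, so the looser passes only run when the caller's `old_text` is slightly off.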
@ -57,9 +57,7 @@ def _normalize_schema_for_openai(schema: Any) -> dict[str, Any]:
|
||||
|
||||
if "properties" in normalized and isinstance(normalized["properties"], dict):
|
||||
normalized["properties"] = {
|
||||
name: _normalize_schema_for_openai(prop)
|
||||
if isinstance(prop, dict)
|
||||
else prop
|
||||
name: _normalize_schema_for_openai(prop) if isinstance(prop, dict) else prop
|
||||
for name, prop in normalized["properties"].items()
|
||||
}
|
||||
|
||||
@ -138,9 +136,7 @@ class MCPToolWrapper(Tool):
|
||||
class MCPResourceWrapper(Tool):
|
||||
"""Wraps an MCP resource URI as a read-only nanobot Tool."""
|
||||
|
||||
def __init__(
|
||||
self, session, server_name: str, resource_def, resource_timeout: int = 30
|
||||
):
|
||||
def __init__(self, session, server_name: str, resource_def, resource_timeout: int = 30):
|
||||
self._session = session
|
||||
self._uri = resource_def.uri
|
||||
self._name = f"mcp_{server_name}_resource_{resource_def.name}"
|
||||
@ -211,9 +207,7 @@ class MCPResourceWrapper(Tool):
|
||||
class MCPPromptWrapper(Tool):
|
||||
"""Wraps an MCP prompt as a read-only nanobot Tool."""
|
||||
|
||||
def __init__(
|
||||
self, session, server_name: str, prompt_def, prompt_timeout: int = 30
|
||||
):
|
||||
def __init__(self, session, server_name: str, prompt_def, prompt_timeout: int = 30):
|
||||
self._session = session
|
||||
self._prompt_name = prompt_def.name
|
||||
self._name = f"mcp_{server_name}_prompt_{prompt_def.name}"
|
||||
@@ -266,9 +260,7 @@ class MCPPromptWrapper(Tool):
                timeout=self._prompt_timeout,
            )
        except asyncio.TimeoutError:
            logger.warning(
                "MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout
            )
            logger.warning("MCP prompt '{}' timed out after {}s", self._name, self._prompt_timeout)
            return f"(MCP prompt call timed out after {self._prompt_timeout}s)"
        except asyncio.CancelledError:
            task = asyncio.current_task()
@@ -279,13 +271,17 @@ class MCPPromptWrapper(Tool):
        except McpError as exc:
            logger.error(
                "MCP prompt '{}' failed: code={} message={}",
                self._name, exc.error.code, exc.error.message,
                self._name,
                exc.error.code,
                exc.error.message,
            )
            return f"(MCP prompt call failed: {exc.error.message} [code {exc.error.code}])"
        except Exception as exc:
            logger.exception(
                "MCP prompt '{}' failed: {}: {}",
                self._name, type(exc).__name__, exc,
                self._name,
                type(exc).__name__,
                exc,
            )
            return f"(MCP prompt call failed: {type(exc).__name__})"

@@ -307,35 +303,44 @@ class MCPPromptWrapper(Tool):


async def connect_mcp_servers(
    mcp_servers: dict, registry: ToolRegistry, stack: AsyncExitStack
) -> None:
    """Connect to configured MCP servers and register their tools, resources, and prompts."""
    mcp_servers: dict, registry: ToolRegistry
) -> dict[str, AsyncExitStack]:
    """Connect to configured MCP servers and register their tools, resources, prompts.

    Returns a dict mapping server name -> its dedicated AsyncExitStack.
    Each server gets its own stack and runs in its own task to prevent
    cancel scope conflicts when multiple MCP servers are configured.
    """
    from mcp import ClientSession, StdioServerParameters
    from mcp.client.sse import sse_client
    from mcp.client.stdio import stdio_client
    from mcp.client.streamable_http import streamable_http_client

    for name, cfg in mcp_servers.items():
    async def connect_single_server(name: str, cfg) -> tuple[str, AsyncExitStack | None]:
        server_stack = AsyncExitStack()
        await server_stack.__aenter__()

        try:
            transport_type = cfg.type
            if not transport_type:
                if cfg.command:
                    transport_type = "stdio"
                elif cfg.url:
                    # Convention: URLs ending with /sse use SSE transport; others use streamableHttp
                    transport_type = (
                        "sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp"
                    )
                else:
                    logger.warning("MCP server '{}': no command or url configured, skipping", name)
                    continue
                    await server_stack.aclose()
                    return name, None

            if transport_type == "stdio":
                params = StdioServerParameters(
                    command=cfg.command, args=cfg.args, env=cfg.env or None
                )
                read, write = await stack.enter_async_context(stdio_client(params))
                read, write = await server_stack.enter_async_context(stdio_client(params))
            elif transport_type == "sse":

                def httpx_client_factory(
                    headers: dict[str, str] | None = None,
                    timeout: httpx.Timeout | None = None,
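The transport selection in this hunk can be sketched in isolation. This is a minimal illustration, assuming only the convention stated in the comment above (URLs ending in `/sse` use SSE, other URLs use streamable HTTP); `infer_transport` is a hypothetical helper name and is not part of nanobot.

```python
from __future__ import annotations


def infer_transport(command: str | None, url: str | None) -> str | None:
    """Guess the MCP transport when the config gives no explicit type.

    Hypothetical helper mirroring the branch order shown in the diff.
    """
    if command:
        return "stdio"
    if url:
        # A trailing slash is ignored before checking for the /sse suffix.
        return "sse" if url.rstrip("/").endswith("/sse") else "streamableHttp"
    return None  # neither command nor url: the caller would skip this server


for cmd, url in [("npx", None), (None, "https://example.com/sse/"), (None, "https://example.com/mcp")]:
    print(cmd, url, "->", infer_transport(cmd, url))
```

Returning `None` for the unconfigured case keeps the "skip this server" decision with the caller, which is roughly what the `return name, None` branch does in the new code.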
@@ -353,27 +358,26 @@ async def connect_mcp_servers(
                        auth=auth,
                    )

                read, write = await stack.enter_async_context(
                read, write = await server_stack.enter_async_context(
                    sse_client(cfg.url, httpx_client_factory=httpx_client_factory)
                )
            elif transport_type == "streamableHttp":
                # Always provide an explicit httpx client so MCP HTTP transport does not
                # inherit httpx's default 5s timeout and preempt the higher-level tool timeout.
                http_client = await stack.enter_async_context(
                http_client = await server_stack.enter_async_context(
                    httpx.AsyncClient(
                        headers=cfg.headers or None,
                        follow_redirects=True,
                        timeout=None,
                    )
                )
                read, write, _ = await stack.enter_async_context(
                read, write, _ = await server_stack.enter_async_context(
                    streamable_http_client(cfg.url, http_client=http_client)
                )
            else:
                logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type)
                continue
                await server_stack.aclose()
                return name, None

            session = await stack.enter_async_context(ClientSession(read, write))
            session = await server_stack.enter_async_context(ClientSession(read, write))
            await session.initialize()

            tools = await session.list_tools()
@@ -418,7 +422,6 @@ async def connect_mcp_servers(
                        ", ".join(available_wrapped_names) or "(none)",
                    )

            # --- Register resources ---
            try:
                resources_result = await session.list_resources()
                for resource in resources_result.resources:
@@ -433,7 +436,6 @@ async def connect_mcp_servers(
            except Exception as e:
                logger.debug("MCP server '{}': resources not supported or failed: {}", name, e)

            # --- Register prompts ---
            try:
                prompts_result = await session.list_prompts()
                for prompt in prompts_result.prompts:
@@ -442,14 +444,54 @@ async def connect_mcp_servers(
                    )
                    registry.register(wrapper)
                    registered_count += 1
                    logger.debug(
                        "MCP: registered prompt '{}' from server '{}'", wrapper.name, name
                    )
                    logger.debug("MCP: registered prompt '{}' from server '{}'", wrapper.name, name)
            except Exception as e:
                logger.debug("MCP server '{}': prompts not supported or failed: {}", name, e)

            logger.info(
                "MCP server '{}': connected, {} capabilities registered", name, registered_count
            )
            return name, server_stack

        except Exception as e:
            logger.error("MCP server '{}': failed to connect: {}", name, e)
            hint = ""
            text = str(e).lower()
            if any(
                marker in text
                for marker in (
                    "parse error",
                    "invalid json",
                    "unexpected token",
                    "jsonrpc",
                    "content-length",
                )
            ):
                hint = (
                    " Hint: this looks like stdio protocol pollution. Make sure the MCP server writes "
                    "only JSON-RPC to stdout and sends logs/debug output to stderr instead."
                )
            logger.error("MCP server '{}': failed to connect: {}{}", name, e, hint)
            try:
                await server_stack.aclose()
            except Exception:
                pass
            return name, None

    server_stacks: dict[str, AsyncExitStack] = {}

    tasks: list[asyncio.Task] = []
    for name, cfg in mcp_servers.items():
        task = asyncio.create_task(connect_single_server(name, cfg))
        tasks.append(task)

    results = await asyncio.gather(*tasks, return_exceptions=True)

    for i, result in enumerate(results):
        name = list(mcp_servers.keys())[i]
        if isinstance(result, BaseException):
            if not isinstance(result, asyncio.CancelledError):
                logger.error("MCP server '{}' connection task failed: {}", name, result)
        elif result is not None and result[1] is not None:
            server_stacks[result[0]] = result[1]

    return server_stacks

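The shape of the new `connect_mcp_servers` (one `AsyncExitStack` per server, one task per connection, results collected with `asyncio.gather`) can be tried without any MCP dependency. The sketch below is only an illustration under stated assumptions: `fake_session`, `connect_one`, and `connect_all` are hypothetical names, and the fake context manager stands in for the real transport and `ClientSession` contexts.

```python
from __future__ import annotations

import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


@asynccontextmanager
async def fake_session(name: str, fail: bool, log: list[str]):
    """Hypothetical stand-in for an MCP transport/session context manager."""
    if fail:
        raise ConnectionError(f"{name} unreachable")
    log.append(f"open {name}")
    try:
        yield name
    finally:
        log.append(f"close {name}")


async def connect_one(name: str, fail: bool, log: list[str]):
    """Open one server on its own stack; close the stack again on failure."""
    stack = AsyncExitStack()
    try:
        await stack.enter_async_context(fake_session(name, fail, log))
        return name, stack
    except Exception:
        await stack.aclose()
        return name, None


async def connect_all(servers: dict[str, bool], log: list[str]) -> dict[str, AsyncExitStack]:
    """Connect every server in its own task and keep only the stacks that opened."""
    tasks = [asyncio.create_task(connect_one(n, f, log)) for n, f in servers.items()]
    results = await asyncio.gather(*tasks, return_exceptions=True)
    stacks: dict[str, AsyncExitStack] = {}
    for result in results:
        if isinstance(result, BaseException):
            continue  # a crashed task must not take the other servers down
        name, stack = result
        if stack is not None:
            stacks[name] = stack
    return stacks


async def main():
    log: list[str] = []
    stacks = await connect_all({"good": False, "bad": True}, log)
    for stack in stacks.values():  # the caller owns shutdown of each server
        await stack.aclose()
    return sorted(stacks), log


connected, events = asyncio.run(main())
print(connected, events)
```

Because each server owns its stack, a failed connection is cleaned up on its own and the caller can later close any single server without touching the rest.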
161
nanobot/agent/tools/notebook.py
Normal file
@@ -0,0 +1,161 @@
"""NotebookEditTool — edit Jupyter .ipynb notebooks."""

from __future__ import annotations

import json
import uuid
from typing import Any

from nanobot.agent.tools.base import tool_parameters
from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema
from nanobot.agent.tools.filesystem import _FsTool


def _new_cell(source: str, cell_type: str = "code", generate_id: bool = False) -> dict:
    cell: dict[str, Any] = {
        "cell_type": cell_type,
        "source": source,
        "metadata": {},
    }
    if cell_type == "code":
        cell["outputs"] = []
        cell["execution_count"] = None
    if generate_id:
        cell["id"] = uuid.uuid4().hex[:8]
    return cell


def _make_empty_notebook() -> dict:
    return {
        "nbformat": 4,
        "nbformat_minor": 5,
        "metadata": {
            "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
            "language_info": {"name": "python"},
        },
        "cells": [],
    }


@tool_parameters(
    tool_parameters_schema(
        path=StringSchema("Path to the .ipynb notebook file"),
        cell_index=IntegerSchema(0, description="0-based index of the cell to edit", minimum=0),
        new_source=StringSchema("New source content for the cell"),
        cell_type=StringSchema(
            "Cell type: 'code' or 'markdown' (default: code)",
            enum=["code", "markdown"],
        ),
        edit_mode=StringSchema(
            "Mode: 'replace' (default), 'insert' (after target), or 'delete'",
            enum=["replace", "insert", "delete"],
        ),
        required=["path", "cell_index"],
    )
)
class NotebookEditTool(_FsTool):
    """Edit Jupyter notebook cells: replace, insert, or delete."""

    _VALID_CELL_TYPES = frozenset({"code", "markdown"})
    _VALID_EDIT_MODES = frozenset({"replace", "insert", "delete"})

    @property
    def name(self) -> str:
        return "notebook_edit"

    @property
    def description(self) -> str:
        return (
            "Edit a Jupyter notebook (.ipynb) cell. "
            "Modes: replace (default) replaces cell content, "
            "insert adds a new cell after the target index, "
            "delete removes the cell at the index. "
            "cell_index is 0-based."
        )

    async def execute(
        self,
        path: str | None = None,
        cell_index: int = 0,
        new_source: str = "",
        cell_type: str = "code",
        edit_mode: str = "replace",
        **kwargs: Any,
    ) -> str:
        try:
            if not path:
                return "Error: path is required"

            if not path.endswith(".ipynb"):
                return "Error: notebook_edit only works on .ipynb files. Use edit_file for other files."

            if edit_mode not in self._VALID_EDIT_MODES:
                return (
                    f"Error: Invalid edit_mode '{edit_mode}'. "
                    "Use one of: replace, insert, delete."
                )

            if cell_type not in self._VALID_CELL_TYPES:
                return (
                    f"Error: Invalid cell_type '{cell_type}'. "
                    "Use one of: code, markdown."
                )

            fp = self._resolve(path)

            # Create new notebook if file doesn't exist and mode is insert
            if not fp.exists():
                if edit_mode != "insert":
                    return f"Error: File not found: {path}"
                nb = _make_empty_notebook()
                cell = _new_cell(new_source, cell_type, generate_id=True)
                nb["cells"].append(cell)
                fp.parent.mkdir(parents=True, exist_ok=True)
                fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
                return f"Successfully created {fp} with 1 cell"

            try:
                nb = json.loads(fp.read_text(encoding="utf-8"))
            except (json.JSONDecodeError, UnicodeDecodeError) as e:
                return f"Error: Failed to parse notebook: {e}"

            cells = nb.get("cells", [])
            nbformat_minor = nb.get("nbformat_minor", 0)
            generate_id = nb.get("nbformat", 0) >= 4 and nbformat_minor >= 5

            if edit_mode == "delete":
                if cell_index < 0 or cell_index >= len(cells):
                    return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)"
                cells.pop(cell_index)
                nb["cells"] = cells
                fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
                return f"Successfully deleted cell {cell_index} from {fp}"

            if edit_mode == "insert":
                insert_at = min(cell_index + 1, len(cells))
                cell = _new_cell(new_source, cell_type, generate_id=generate_id)
                cells.insert(insert_at, cell)
                nb["cells"] = cells
                fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
                return f"Successfully inserted cell at index {insert_at} in {fp}"

            # Default: replace
            if cell_index < 0 or cell_index >= len(cells):
                return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)"
            cells[cell_index]["source"] = new_source
            if cell_type and cells[cell_index].get("cell_type") != cell_type:
                cells[cell_index]["cell_type"] = cell_type
                if cell_type == "code":
                    cells[cell_index].setdefault("outputs", [])
                    cells[cell_index].setdefault("execution_count", None)
                elif "outputs" in cells[cell_index]:
                    del cells[cell_index]["outputs"]
                    cells[cell_index].pop("execution_count", None)
            nb["cells"] = cells
            fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
            return f"Successfully edited cell {cell_index} in {fp}"

        except PermissionError as e:
            return f"Error: {e}"
        except Exception as e:
            return f"Error editing notebook: {e}"
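The insert semantics of the new tool (the cell lands after `cell_index`, clamped to the end of the notebook, and gets an `id` only for nbformat 4.5 or later) can be exercised on an in-memory dict. This is a minimal sketch that copies the logic shown in the diff; `new_cell` and `insert_after` are hypothetical standalone names and no file I/O or path resolution is involved.

```python
from __future__ import annotations

import uuid
from typing import Any


def new_cell(source: str, cell_type: str = "code", with_id: bool = False) -> dict[str, Any]:
    """Build a minimal notebook cell; code cells carry outputs and execution_count."""
    cell: dict[str, Any] = {"cell_type": cell_type, "source": source, "metadata": {}}
    if cell_type == "code":
        cell["outputs"] = []
        cell["execution_count"] = None
    if with_id:
        cell["id"] = uuid.uuid4().hex[:8]
    return cell


def insert_after(nb: dict[str, Any], index: int, source: str, cell_type: str = "code") -> int:
    """Insert a cell after `index`, clamping to the end; return the position used."""
    cells = nb.setdefault("cells", [])
    # Same version check as in the diff: ids are generated for nbformat 4.5+.
    with_id = nb.get("nbformat", 0) >= 4 and nb.get("nbformat_minor", 0) >= 5
    insert_at = min(index + 1, len(cells))
    cells.insert(insert_at, new_cell(source, cell_type, with_id))
    return insert_at


nb = {"nbformat": 4, "nbformat_minor": 5, "metadata": {}, "cells": [new_cell("a = 1")]}
pos = insert_after(nb, 0, "# Notes", "markdown")  # goes right after cell 0
far = insert_after(nb, 99, "print(a)")            # index past the end is clamped
print(pos, far, [c["cell_type"] for c in nb["cells"]])
```

Clamping rather than rejecting an out-of-range index matches the `min(cell_index + 1, len(cells))` line in the tool, so an insert never fails on the index alone.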
@@ -68,6 +68,13 @@ class ToolRegistry:
        params: dict[str, Any],
    ) -> tuple[Tool | None, dict[str, Any], str | None]:
        """Resolve, cast, and validate one tool call."""
        # Guard against invalid parameter types (e.g., list instead of dict)
        if not isinstance(params, dict) and name in ('write_file', 'read_file'):
            return None, params, (
                f"Error: Tool '{name}' parameters must be a JSON object, got {type(params).__name__}. "
                "Use named parameters: tool_name(param1=\"value1\", param2=\"value2\")"
            )

        tool = self._tools.get(name)
        if not tool:
            return None, params, (

@@ -46,6 +46,7 @@ class ExecTool(Tool):
        restrict_to_workspace: bool = False,
        sandbox: str = "",
        path_append: str = "",
        allowed_env_keys: list[str] | None = None,
    ):
        self.timeout = timeout
        self.working_dir = working_dir
@@ -60,10 +61,19 @@ class ExecTool(Tool):
            r">\s*/dev/sd", # write to disk
            r"\b(shutdown|reboot|poweroff)\b", # system power
            r":\(\)\s*\{.*\};\s*:", # fork bomb
            # Block writes to nanobot internal state files (#2989).
            # history.jsonl / .dream_cursor are managed by append_history();
            # direct writes corrupt the cursor format and crash /dream.
            r">>?\s*\S*(?:history\.jsonl|\.dream_cursor)", # > / >> redirect
            r"\btee\b[^|;&<>]*(?:history\.jsonl|\.dream_cursor)", # tee / tee -a
            r"\b(?:cp|mv)\b(?:\s+[^\s|;&<>]+)+\s+\S*(?:history\.jsonl|\.dream_cursor)", # cp/mv target
            r"\bdd\b[^|;&<>]*\bof=\S*(?:history\.jsonl|\.dream_cursor)", # dd of=
            r"\bsed\s+-i[^|;&<>]*(?:history\.jsonl|\.dream_cursor)", # sed -i
        ]
        self.allow_patterns = allow_patterns or []
        self.restrict_to_workspace = restrict_to_workspace
        self.path_append = path_append
        self.allowed_env_keys = allowed_env_keys or []

    @property
    def name(self) -> str:
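A quick way to see what the new deny patterns catch is to run a few of them against sample commands. The sketch below copies three patterns verbatim from the hunk above; it assumes they are matched with `re.search` against the raw command string, which this excerpt does not show, and `is_blocked` is a hypothetical helper.

```python
import re

# Three of the patterns added in the diff, copied verbatim.
DENY_PATTERNS = [
    r">>?\s*\S*(?:history\.jsonl|\.dream_cursor)",          # > / >> redirect
    r"\btee\b[^|;&<>]*(?:history\.jsonl|\.dream_cursor)",   # tee / tee -a
    r"\bsed\s+-i[^|;&<>]*(?:history\.jsonl|\.dream_cursor)",  # sed -i
]


def is_blocked(command: str) -> bool:
    """Return True if any deny pattern matches (assumed re.search semantics)."""
    return any(re.search(pattern, command) for pattern in DENY_PATTERNS)


samples = [
    "echo x >> memory/history.jsonl",
    "echo hi | tee -a .dream_cursor",
    "sed -i 's/a/b/' history.jsonl",
    "cat memory/history.jsonl",
]
for cmd in samples:
    print(f"{cmd!r}: blocked={is_blocked(cmd)}")
```

Reads such as `cat` stay allowed because every pattern keys on a write operator or a writing command in front of the protected file name.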
@@ -91,6 +101,21 @@ class ExecTool(Tool):
        timeout: int | None = None, **kwargs: Any,
    ) -> str:
        cwd = working_dir or self.working_dir or os.getcwd()

        # Prevent an LLM-supplied working_dir from escaping the configured
        # workspace when restrict_to_workspace is enabled (#2826). Without
        # this, a caller can pass working_dir="/etc" and then all absolute
        # paths under /etc would pass the _guard_command check that anchors
        # on cwd.
        if self.restrict_to_workspace and self.working_dir:
            try:
                requested = Path(cwd).expanduser().resolve()
                workspace_root = Path(self.working_dir).expanduser().resolve()
            except Exception:
                return "Error: working_dir could not be resolved"
            if requested != workspace_root and workspace_root not in requested.parents:
                return "Error: working_dir is outside the configured workspace"

        guard_error = self._guard_command(command, cwd)
        if guard_error:
            return guard_error
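The containment test added here relies on `Path.resolve()` to collapse `..` segments before comparing against the workspace root. A minimal standalone sketch of the same check follows; `is_inside_workspace` is a hypothetical name and the example paths are made up.

```python
from pathlib import Path


def is_inside_workspace(requested: str, workspace: str) -> bool:
    """True if `requested` is the workspace root or lies somewhere below it."""
    req = Path(requested).expanduser().resolve()
    root = Path(workspace).expanduser().resolve()
    # Same comparison as the diff, with the condition inverted.
    return req == root or root in req.parents


print(is_inside_workspace("/nanobot-demo-ws/project", "/nanobot-demo-ws"))
print(is_inside_workspace("/nanobot-demo-ws/../etc", "/nanobot-demo-ws"))
```

Comparing resolved `Path` objects instead of string prefixes avoids the classic mistake where `/workspace-evil` passes a `startswith("/workspace")` test.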
@@ -208,7 +233,7 @@ class ExecTool(Tool):
        """
        if _IS_WINDOWS:
            sr = os.environ.get("SYSTEMROOT", r"C:\Windows")
            return {
            env = {
                "SYSTEMROOT": sr,
                "COMSPEC": os.environ.get("COMSPEC", f"{sr}\\system32\\cmd.exe"),
                "USERPROFILE": os.environ.get("USERPROFILE", ""),
@@ -218,13 +243,29 @@ class ExecTool(Tool):
                "TMP": os.environ.get("TMP", f"{sr}\\Temp"),
                "PATHEXT": os.environ.get("PATHEXT", ".COM;.EXE;.BAT;.CMD"),
                "PATH": os.environ.get("PATH", f"{sr}\\system32;{sr}"),
                "APPDATA": os.environ.get("APPDATA", ""),
                "LOCALAPPDATA": os.environ.get("LOCALAPPDATA", ""),
                "ProgramData": os.environ.get("ProgramData", ""),
                "ProgramFiles": os.environ.get("ProgramFiles", ""),
                "ProgramFiles(x86)": os.environ.get("ProgramFiles(x86)", ""),
                "ProgramW6432": os.environ.get("ProgramW6432", ""),
            }
            for key in self.allowed_env_keys:
                val = os.environ.get(key)
                if val is not None:
                    env[key] = val
            return env
        home = os.environ.get("HOME", "/tmp")
        return {
        env = {
            "HOME": home,
            "LANG": os.environ.get("LANG", "C.UTF-8"),
            "TERM": os.environ.get("TERM", "dumb"),
        }
        for key in self.allowed_env_keys:
            val = os.environ.get(key)
            if val is not None:
                env[key] = val
        return env

    def _guard_command(self, command: str, cwd: str) -> str | None:
        """Best-effort safety guard for potentially destructive commands."""

@@ -96,10 +96,37 @@ class WebSearchTool(Tool):
        self.config = config if config is not None else WebSearchConfig()
        self.proxy = proxy

    def _effective_provider(self) -> str:
        """Resolve the backend that execute() will actually use."""
        provider = self.config.provider.strip().lower() or "brave"
        if provider == "duckduckgo":
            return "duckduckgo"
        if provider == "brave":
            api_key = self.config.api_key or os.environ.get("BRAVE_API_KEY", "")
            return "brave" if api_key else "duckduckgo"
        if provider == "tavily":
            api_key = self.config.api_key or os.environ.get("TAVILY_API_KEY", "")
            return "tavily" if api_key else "duckduckgo"
        if provider == "searxng":
            base_url = (self.config.base_url or os.environ.get("SEARXNG_BASE_URL", "")).strip()
            return "searxng" if base_url else "duckduckgo"
        if provider == "jina":
            api_key = self.config.api_key or os.environ.get("JINA_API_KEY", "")
            return "jina" if api_key else "duckduckgo"
        if provider == "kagi":
            api_key = self.config.api_key or os.environ.get("KAGI_API_KEY", "")
            return "kagi" if api_key else "duckduckgo"
        return provider

    @property
    def read_only(self) -> bool:
        return True

    @property
    def exclusive(self) -> bool:
        """DuckDuckGo searches are serialized because ddgs is not concurrency-safe."""
        return self._effective_provider() == "duckduckgo"

    async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
        provider = self.config.provider.strip().lower() or "brave"
        n = min(max(count or self.config.max_results, 1), 10)
@@ -114,6 +141,8 @@ class WebSearchTool(Tool):
            return await self._search_jina(query, n)
        elif provider == "brave":
            return await self._search_brave(query, n)
        elif provider == "kagi":
            return await self._search_kagi(query, n)
        else:
            return f"Error: unknown search provider '{provider}'"

@@ -204,6 +233,29 @@ class WebSearchTool(Tool):
            logger.warning("Jina search failed ({}), falling back to DuckDuckGo", e)
            return await self._search_duckduckgo(query, n)

    async def _search_kagi(self, query: str, n: int) -> str:
        api_key = self.config.api_key or os.environ.get("KAGI_API_KEY", "")
        if not api_key:
            logger.warning("KAGI_API_KEY not set, falling back to DuckDuckGo")
            return await self._search_duckduckgo(query, n)
        try:
            async with httpx.AsyncClient(proxy=self.proxy) as client:
                r = await client.get(
                    "https://kagi.com/api/v0/search",
                    params={"q": query, "limit": n},
                    headers={"Authorization": f"Bot {api_key}"},
                    timeout=10.0,
                )
            r.raise_for_status()
            # t=0 items are search results; other values are related searches, etc.
            items = [
                {"title": d.get("title", ""), "url": d.get("url", ""), "content": d.get("snippet", "")}
                for d in r.json().get("data", []) if d.get("t") == 0
            ]
            return _format_results(query, items, n)
        except Exception as e:
            return f"Error: {e}"

    async def _search_duckduckgo(self, query: str, n: int) -> str:
        try:
            # Note: duckduckgo_search is synchronous and does its own requests

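The key-based fallback in `_effective_provider` above follows one rule for most backends: use the configured provider if a key is available, otherwise fall back to DuckDuckGo. The table-driven sketch below illustrates that rule only; `effective_provider` and `KEY_ENV` are hypothetical names, the `env` parameter exists just to make the function testable, and the `searxng` branch (which checks a base URL instead of a key) is deliberately left out.

```python
from __future__ import annotations

import os
from collections.abc import Mapping

# Environment variable consulted per key-based provider, as named in the diff.
KEY_ENV = {
    "brave": "BRAVE_API_KEY",
    "tavily": "TAVILY_API_KEY",
    "jina": "JINA_API_KEY",
    "kagi": "KAGI_API_KEY",
}


def effective_provider(provider: str, api_key: str = "", env: Mapping[str, str] | None = None) -> str:
    """Resolve which backend a search would actually use."""
    env = os.environ if env is None else env
    provider = provider.strip().lower() or "brave"
    env_var = KEY_ENV.get(provider)
    if env_var is None:
        return provider  # duckduckgo or an unknown name passes through unchanged
    return provider if (api_key or env.get(env_var, "")) else "duckduckgo"


print(effective_provider("kagi", env={}))
print(effective_provider("Kagi ", api_key="secret", env={}))
```

Resolving the backend up front is what lets the `exclusive` property report whether a call will end up on the DuckDuckGo path, which the docstring above says must be serialized.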
@@ -84,6 +84,10 @@ def _save_base64_data_url(data_url: str, media_dir: Path) -> str | None:
        raw = base64.b64decode(b64_payload)
    except Exception:
        return None
    if len(raw) > MAX_FILE_SIZE:
        raise _FileSizeExceeded(
            f"File exceeds {MAX_FILE_SIZE // (1024 * 1024)}MB limit"
        )
    ext = mimetypes.guess_extension(mime_type) or ".bin"
    filename = f"{uuid.uuid4().hex[:12]}{ext}"
    dest = media_dir / safe_filename(filename)

@@ -5,6 +5,8 @@ import json
import mimetypes
import os
import time
import zipfile
from io import BytesIO
from pathlib import Path
from typing import Any
from urllib.parse import unquote, urlparse
@@ -171,6 +173,7 @@ class DingTalkChannel(BaseChannel):
    _IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
    _AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"}
    _VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"}
    _ZIP_BEFORE_UPLOAD_EXTS = {".htm", ".html"}

    @classmethod
    def default_config(cls) -> dict[str, Any]:
@@ -287,6 +290,31 @@ class DingTalkChannel(BaseChannel):
        name = os.path.basename(urlparse(media_ref).path)
        return name or {"image": "image.jpg", "voice": "audio.amr", "video": "video.mp4"}.get(upload_type, "file.bin")

    @staticmethod
    def _zip_bytes(filename: str, data: bytes) -> tuple[bytes, str, str]:
        stem = Path(filename).stem or "attachment"
        safe_name = filename or "attachment.bin"
        zip_name = f"{stem}.zip"
        buffer = BytesIO()
        with zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
            archive.writestr(safe_name, data)
        return buffer.getvalue(), zip_name, "application/zip"

    def _normalize_upload_payload(
        self,
        filename: str,
        data: bytes,
        content_type: str | None,
    ) -> tuple[bytes, str, str | None]:
        ext = Path(filename).suffix.lower()
        if ext in self._ZIP_BEFORE_UPLOAD_EXTS or content_type == "text/html":
            logger.info(
                "DingTalk does not accept raw HTML attachments, zipping {} before upload",
                filename,
            )
            return self._zip_bytes(filename, data)
        return data, filename, content_type

    async def _read_media_bytes(
        self,
        media_ref: str,
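The in-memory zipping done by `_zip_bytes` above can be checked with a round trip through the standard library. This is a small sketch of the same technique; `zip_bytes` is a standalone copy for illustration and the file name and payload are made up.

```python
from __future__ import annotations

import zipfile
from io import BytesIO
from pathlib import Path


def zip_bytes(filename: str, data: bytes) -> tuple[bytes, str, str]:
    """Wrap one file in a deflate-compressed zip held entirely in memory."""
    stem = Path(filename).stem or "attachment"
    buffer = BytesIO()
    with zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
        archive.writestr(filename or "attachment.bin", data)
    # getvalue() is read after the context manager closes the archive,
    # so the central directory has already been written to the buffer.
    return buffer.getvalue(), f"{stem}.zip", "application/zip"


payload, name, mime = zip_bytes("report.html", b"<h1>hi</h1>")
with zipfile.ZipFile(BytesIO(payload)) as archive:
    members = archive.namelist()
    restored = archive.read("report.html")
print(name, mime, members, restored)
```

Reading the buffer only after the `with` block exits matters: the archive is not a valid zip until `ZipFile` has been closed.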
@@ -309,6 +337,9 @@ class DingTalkChannel(BaseChannel):
            content_type = (resp.headers.get("content-type") or "").split(";")[0].strip()
            filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
            return resp.content, filename, content_type or None
        except httpx.TransportError as e:
            logger.error("DingTalk media download network error ref={} err={}", media_ref, e)
            raise
        except Exception as e:
            logger.error("DingTalk media download error ref={} err={}", media_ref, e)
            return None, None, None
@@ -360,6 +391,9 @@ class DingTalkChannel(BaseChannel):
                logger.error("DingTalk media upload missing media_id body={}", text[:500])
                return None
            return str(media_id)
        except httpx.TransportError as e:
            logger.error("DingTalk media upload network error type={} err={}", media_type, e)
            raise
        except Exception as e:
            logger.error("DingTalk media upload error type={} err={}", media_type, e)
            return None
@@ -409,6 +443,9 @@ class DingTalkChannel(BaseChannel):
                return False
            logger.debug("DingTalk message sent to {} with msgKey={}", chat_id, msg_key)
            return True
        except httpx.TransportError as e:
            logger.error("DingTalk network error sending message msgKey={} err={}", msg_key, e)
            raise
        except Exception as e:
            logger.error("Error sending DingTalk message msgKey={} err={}", msg_key, e)
            return False
@@ -444,6 +481,7 @@ class DingTalkChannel(BaseChannel):
            return False

        filename = filename or self._guess_filename(media_ref, upload_type)
        data, filename, content_type = self._normalize_upload_payload(filename, data, content_type)
        file_type = Path(filename).suffix.lower().lstrip(".")
        if not file_type:
            guessed = mimetypes.guess_extension(content_type or "")

@@ -4,6 +4,8 @@ from __future__ import annotations

import asyncio
import importlib.util
import time
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

@@ -20,6 +22,7 @@ from nanobot.utils.helpers import safe_filename, split_message

DISCORD_AVAILABLE = importlib.util.find_spec("discord") is not None
if TYPE_CHECKING:
    import aiohttp
    import discord
    from discord import app_commands
    from discord.abc import Messageable
@@ -34,6 +37,16 @@ MAX_MESSAGE_LEN = 2000 # Discord message character limit
TYPING_INTERVAL_S = 8


@dataclass
class _StreamBuf:
    """Per-chat streaming accumulator for progressive Discord message edits."""

    text: str = ""
    message: Any | None = None
    last_edit: float = 0.0
    stream_id: str | None = None


class DiscordConfig(Base):
    """Discord channel configuration."""

@@ -45,6 +58,10 @@ class DiscordConfig(Base):
    read_receipt_emoji: str = "👀"
    working_emoji: str = "🔧"
    working_emoji_delay: float = 2.0
    streaming: bool = True
    proxy: str | None = None
    proxy_username: str | None = None
    proxy_password: str | None = None


if DISCORD_AVAILABLE:
@@ -52,8 +69,15 @@ if DISCORD_AVAILABLE:
    class DiscordBotClient(discord.Client):
        """discord.py client that forwards events to the channel."""

        def __init__(self, channel: DiscordChannel, *, intents: discord.Intents) -> None:
            super().__init__(intents=intents)
        def __init__(
            self,
            channel: DiscordChannel,
            *,
            intents: discord.Intents,
            proxy: str | None = None,
            proxy_auth: aiohttp.BasicAuth | None = None,
        ) -> None:
            super().__init__(intents=intents, proxy=proxy, proxy_auth=proxy_auth)
            self._channel = channel
            self.tree = app_commands.CommandTree(self)
            self._register_app_commands()
@@ -117,6 +141,7 @@ if DISCORD_AVAILABLE:
            )

            for name, description, command_text in commands:

                @self.tree.command(name=name, description=description)
                async def command_handler(
                    interaction: discord.Interaction,
@@ -173,7 +198,9 @@ if DISCORD_AVAILABLE:
                else:
                    failed_media.append(Path(media_path).name)

            for index, chunk in enumerate(self._build_chunks(msg.content or "", failed_media, sent_media)):
            for index, chunk in enumerate(
                self._build_chunks(msg.content or "", failed_media, sent_media)
            ):
                kwargs: dict[str, Any] = {"content": chunk}
                if index == 0 and reference is not None and not sent_media:
                    kwargs["reference"] = reference
@@ -242,6 +269,7 @@ class DiscordChannel(BaseChannel):

    name = "discord"
    display_name = "Discord"
    _STREAM_EDIT_INTERVAL = 0.8

    @classmethod
    def default_config(cls) -> dict[str, Any]:
@@ -263,6 +291,7 @@ class DiscordChannel(BaseChannel):
        self._bot_user_id: str | None = None
        self._pending_reactions: dict[str, Any] = {} # chat_id -> message object
        self._working_emoji_tasks: dict[str, asyncio.Task[None]] = {}
        self._stream_bufs: dict[str, _StreamBuf] = {}

    async def start(self) -> None:
        """Start the Discord client."""
@@ -277,7 +306,29 @@ class DiscordChannel(BaseChannel):
        try:
            intents = discord.Intents.none()
            intents.value = self.config.intents
            self._client = DiscordBotClient(self, intents=intents)

            proxy_auth = None
            has_user = bool(self.config.proxy_username)
            has_pass = bool(self.config.proxy_password)
            if has_user and has_pass:
                import aiohttp

                proxy_auth = aiohttp.BasicAuth(
                    login=self.config.proxy_username,
                    password=self.config.proxy_password,
                )
            elif has_user != has_pass:
                logger.warning(
                    "Discord proxy auth incomplete: both proxy_username and "
                    "proxy_password must be set; ignoring partial credentials",
                )

            self._client = DiscordBotClient(
                self,
                intents=intents,
                proxy=self.config.proxy,
                proxy_auth=proxy_auth,
            )
        except Exception as e:
            logger.error("Failed to initialize Discord client: {}", e)
            self._client = None
@@ -315,11 +366,71 @@ class DiscordChannel(BaseChannel):
            await client.send_outbound(msg)
        except Exception as e:
            logger.error("Error sending Discord message: {}", e)
            raise
        finally:
            if not is_progress:
                await self._stop_typing(msg.chat_id)
                await self._clear_reactions(msg.chat_id)

    async def send_delta(
        self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None
    ) -> None:
        """Progressive Discord delivery: send once, then edit until the stream ends."""
        client = self._client
        if client is None or not client.is_ready():
            logger.warning("Discord client not ready; dropping stream delta")
            return

        meta = metadata or {}
        stream_id = meta.get("_stream_id")

        if meta.get("_stream_end"):
            buf = self._stream_bufs.get(chat_id)
            if not buf or buf.message is None or not buf.text:
                return
            if stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id:
                return
            await self._finalize_stream(chat_id, buf)
            return

        buf = self._stream_bufs.get(chat_id)
        if buf is None or (
            stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id
        ):
            buf = _StreamBuf(stream_id=stream_id)
            self._stream_bufs[chat_id] = buf
        elif buf.stream_id is None:
            buf.stream_id = stream_id

        buf.text += delta
        if not buf.text.strip():
            return

        target = await self._resolve_channel(chat_id)
        if target is None:
            logger.warning("Discord stream target {} unavailable", chat_id)
            return

        now = time.monotonic()
        if buf.message is None:
            try:
                buf.message = await target.send(content=buf.text)
                buf.last_edit = now
            except Exception as e:
                logger.warning("Discord stream initial send failed: {}", e)
                raise
            return

        if (now - buf.last_edit) < self._STREAM_EDIT_INTERVAL:
            return

        try:
            await buf.message.edit(content=DiscordBotClient._build_chunks(buf.text, [], False)[0])
            buf.last_edit = now
        except Exception as e:
            logger.warning("Discord stream edit failed: {}", e)
            raise

    async def _handle_discord_message(self, message: discord.Message) -> None:
        """Handle incoming Discord messages from discord.py."""
        if message.author.bot:
@ -373,6 +484,47 @@ class DiscordChannel(BaseChannel):
|
||||
"""Backward-compatible alias for legacy tests/callers."""
|
||||
await self._handle_discord_message(message)
|
||||
|
||||
async def _resolve_channel(self, chat_id: str) -> Any | None:
|
||||
"""Resolve a Discord channel from cache first, then network fetch."""
|
||||
client = self._client
|
||||
if client is None or not client.is_ready():
|
||||
return None
|
||||
channel_id = int(chat_id)
|
||||
channel = client.get_channel(channel_id)
|
||||
if channel is not None:
|
||||
return channel
|
||||
try:
|
||||
return await client.fetch_channel(channel_id)
|
||||
except Exception as e:
|
||||
logger.warning("Discord channel {} unavailable: {}", chat_id, e)
|
||||
return None
|
||||
|
||||
async def _finalize_stream(self, chat_id: str, buf: _StreamBuf) -> None:
|
||||
"""Commit the final streamed content and flush overflow chunks."""
|
||||
chunks = DiscordBotClient._build_chunks(buf.text, [], False)
|
||||
if not chunks:
|
||||
self._stream_bufs.pop(chat_id, None)
|
||||
return
|
||||
|
||||
try:
|
||||
await buf.message.edit(content=chunks[0])
|
||||
except Exception as e:
|
||||
logger.warning("Discord final stream edit failed: {}", e)
|
||||
raise
|
||||
|
||||
target = getattr(buf.message, "channel", None) or await self._resolve_channel(chat_id)
|
||||
if target is None:
|
||||
logger.warning("Discord stream follow-up target {} unavailable", chat_id)
|
||||
self._stream_bufs.pop(chat_id, None)
|
||||
return
|
||||
|
||||
for extra_chunk in chunks[1:]:
|
||||
await target.send(content=extra_chunk)
|
||||
|
||||
self._stream_bufs.pop(chat_id, None)
|
||||
await self._stop_typing(chat_id)
|
||||
await self._clear_reactions(chat_id)
|
||||
|
||||
def _should_accept_inbound(
|
||||
self,
|
||||
message: discord.Message,
|
||||
@ -423,7 +575,11 @@ class DiscordChannel(BaseChannel):
|
||||
@staticmethod
|
||||
def _build_inbound_metadata(message: discord.Message) -> dict[str, str | None]:
|
||||
"""Build metadata for inbound Discord messages."""
|
||||
reply_to = str(message.reference.message_id) if message.reference and message.reference.message_id else None
|
||||
reply_to = (
|
||||
str(message.reference.message_id)
|
||||
if message.reference and message.reference.message_id
|
||||
else None
|
||||
)
|
||||
return {
|
||||
"message_id": str(message.id),
|
||||
"guild_id": str(message.guild.id) if message.guild else None,
|
||||
@ -438,7 +594,9 @@ class DiscordChannel(BaseChannel):
|
||||
if self.config.group_policy == "mention":
|
||||
bot_user_id = self._bot_user_id
|
||||
if bot_user_id is None:
|
||||
logger.debug("Discord message in {} ignored (bot identity unavailable)", message.channel.id)
|
||||
logger.debug(
|
||||
"Discord message in {} ignored (bot identity unavailable)", message.channel.id
|
||||
)
|
||||
return False
|
||||
|
||||
if any(str(user.id) == bot_user_id for user in message.mentions):
|
||||
@ -480,7 +638,6 @@ class DiscordChannel(BaseChannel):
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
async def _clear_reactions(self, chat_id: str) -> None:
|
||||
"""Remove all pending reactions after bot replies."""
|
||||
# Cancel delayed working emoji if it hasn't fired yet
|
||||
@ -507,6 +664,7 @@ class DiscordChannel(BaseChannel):
|
||||
async def _reset_runtime_state(self, close_client: bool) -> None:
|
||||
"""Reset client and typing state."""
|
||||
await self._cancel_all_typing()
|
||||
self._stream_bufs.clear()
|
||||
if close_client and self._client is not None and not self._client.is_closed():
|
||||
try:
|
||||
await self._client.close()
|
||||
|
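The `send_delta` / `_finalize_stream` changes above follow a "send once, then edit on a timer, then commit" pattern. A minimal sketch of that pattern is shown here, assuming a 1-second edit interval (the actual `_STREAM_EDIT_INTERVAL` value is not visible in this diff) and using hypothetical names (`ThrottledStream`, `StreamBuf`, `actions`) in place of the Discord client calls:

```python
import time
from dataclasses import dataclass


@dataclass
class StreamBuf:
    """Accumulated text plus the time of the last outgoing write."""
    text: str = ""
    sent: bool = False
    last_edit: float = 0.0


class ThrottledStream:
    """Send the first visible delta, then edit at most once per interval."""

    def __init__(self, interval: float = 1.0, clock=time.monotonic):
        self.interval = interval  # assumed value, not taken from the diff
        self.clock = clock
        self.buf = StreamBuf()
        self.actions = []  # stand-in for real send/edit network calls

    def push(self, delta: str) -> None:
        self.buf.text += delta
        if not self.buf.text.strip():
            return  # nothing visible to show yet
        now = self.clock()
        if not self.buf.sent:
            self.actions.append(("send", self.buf.text))
            self.buf.sent = True
            self.buf.last_edit = now
        elif now - self.buf.last_edit >= self.interval:
            self.actions.append(("edit", self.buf.text))
            self.buf.last_edit = now

    def finish(self) -> None:
        # Always commit the full text at the end, regardless of throttling.
        if self.buf.sent and self.buf.text:
            self.actions.append(("edit", self.buf.text))
        self.buf = StreamBuf()
```

The final unconditional edit matters because throttling may have skipped the last few deltas; without it the visible message could lag behind the buffer.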
||||
@ -22,6 +22,8 @@ from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import Base
|
||||
|
||||
from lark_oapi.core.const import FEISHU_DOMAIN, LARK_DOMAIN
|
||||
|
||||
FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
|
||||
|
||||
# Message type display mapping
|
||||
@ -250,9 +252,12 @@ class FeishuConfig(Base):
|
||||
verification_token: str = ""
|
||||
allow_from: list[str] = Field(default_factory=list)
|
||||
react_emoji: str = "THUMBSUP"
|
||||
done_emoji: str | None = None # Emoji to show when task is completed (e.g., "DONE", "OK")
|
||||
tool_hint_prefix: str = "\U0001f527" # Prefix for inline tool hints (default: 🔧)
|
||||
group_policy: Literal["open", "mention"] = "mention"
|
||||
reply_to_message: bool = False # If True, bot replies quote the user's original message
|
||||
streaming: bool = True
|
||||
domain: Literal["feishu", "lark"] = "feishu" # Set to "lark" for international Lark
|
||||
|
||||
|
||||
_STREAM_ELEMENT_ID = "streaming_md"
|
||||
@ -326,10 +331,12 @@ class FeishuChannel(BaseChannel):
|
||||
self._loop = asyncio.get_running_loop()
|
||||
|
||||
# Create Lark client for sending messages
|
||||
domain = LARK_DOMAIN if self.config.domain == "lark" else FEISHU_DOMAIN
|
||||
self._client = (
|
||||
lark.Client.builder()
|
||||
.app_id(self.config.app_id)
|
||||
.app_secret(self.config.app_secret)
|
||||
.domain(domain)
|
||||
.log_level(lark.LogLevel.INFO)
|
||||
.build()
|
||||
)
|
||||
@ -357,6 +364,7 @@ class FeishuChannel(BaseChannel):
|
||||
self._ws_client = lark.ws.Client(
|
||||
self.config.app_id,
|
||||
self.config.app_secret,
|
||||
domain=domain,
|
||||
event_handler=event_handler,
|
||||
log_level=lark.LogLevel.INFO,
|
||||
)
|
||||
@ -1012,14 +1020,29 @@ class FeishuChannel(BaseChannel):
|
||||
|
||||
elif msg_type in ("audio", "file", "media"):
|
||||
file_key = content_json.get("file_key")
|
||||
if file_key and message_id:
|
||||
data, filename = await loop.run_in_executor(
|
||||
None, self._download_file_sync, message_id, file_key, msg_type
|
||||
)
|
||||
if not filename:
|
||||
filename = file_key[:16]
|
||||
if msg_type == "audio" and not filename.endswith(".opus"):
|
||||
filename = f"{filename}.opus"
|
||||
if not file_key:
|
||||
logger.warning("Feishu {} message missing file_key: {}", msg_type, content_json)
|
||||
return None, f"[{msg_type}: missing file_key]"
|
||||
if not message_id:
|
||||
logger.warning("Feishu {} message missing message_id", msg_type)
|
||||
return None, f"[{msg_type}: missing message_id]"
|
||||
|
||||
data, filename = await loop.run_in_executor(
|
||||
None, self._download_file_sync, message_id, file_key, msg_type
|
||||
)
|
||||
|
||||
if not data:
|
||||
logger.warning("Feishu {} download failed: file_key={}", msg_type, file_key)
|
||||
return None, f"[{msg_type}: download failed]"
|
||||
|
||||
if not filename:
|
||||
filename = file_key[:16]
|
||||
|
||||
# Feishu voice messages are opus in OGG container.
|
||||
# Use .ogg extension for better Whisper compatibility.
|
||||
if msg_type == "audio":
|
||||
if not any(filename.endswith(ext) for ext in (".opus", ".ogg", ".oga")):
|
||||
filename = f"{filename}.ogg"
|
||||
|
||||
if data and filename:
|
||||
file_path = media_dir / filename
|
||||
@ -1263,7 +1286,14 @@ class FeishuChannel(BaseChannel):
|
||||
async def send_delta(
|
||||
self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None
|
||||
) -> None:
|
||||
"""Progressive streaming via CardKit: create card on first delta, stream-update on subsequent."""
|
||||
"""Progressive streaming via CardKit: create card on first delta, stream-update on subsequent.
|
||||
|
||||
Supported metadata keys:
|
||||
_stream_end: Finalize the streaming card.
|
||||
_tool_hint: Delta is a formatted tool hint (for display only).
|
||||
message_id: Original message id (used with _stream_end for reaction cleanup).
|
||||
reaction_id: Reaction id to remove on stream end.
|
||||
"""
|
||||
if not self._client:
|
||||
return
|
||||
meta = metadata or {}
|
||||
@ -1274,38 +1304,48 @@ class FeishuChannel(BaseChannel):
|
||||
if meta.get("_stream_end"):
|
||||
if (message_id := meta.get("message_id")) and (reaction_id := meta.get("reaction_id")):
|
||||
await self._remove_reaction(message_id, reaction_id)
|
||||
# Add completion emoji if configured
|
||||
if self.config.done_emoji and message_id:
|
||||
await self._add_reaction(message_id, self.config.done_emoji)
|
||||
|
||||
buf = self._stream_bufs.pop(chat_id, None)
|
||||
if not buf or not buf.text:
|
||||
return
|
||||
# Try to finalize via streaming card; if that fails (e.g.
|
||||
# streaming mode was closed by Feishu due to timeout), fall
|
||||
# back to sending a regular interactive card.
|
||||
if buf.card_id:
|
||||
buf.sequence += 1
|
||||
await loop.run_in_executor(
|
||||
ok = await loop.run_in_executor(
|
||||
None,
|
||||
self._stream_update_text_sync,
|
||||
buf.card_id,
|
||||
buf.text,
|
||||
buf.sequence,
|
||||
)
|
||||
# Required so the chat list preview exits the streaming placeholder (Feishu streaming card docs).
|
||||
buf.sequence += 1
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
self._close_streaming_mode_sync,
|
||||
buf.card_id,
|
||||
buf.sequence,
|
||||
)
|
||||
else:
|
||||
for chunk in self._split_elements_by_table_limit(
|
||||
self._build_card_elements(buf.text)
|
||||
):
|
||||
card = json.dumps(
|
||||
{"config": {"wide_screen_mode": True}, "elements": chunk},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
if ok:
|
||||
buf.sequence += 1
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync, rid_type, chat_id, "interactive", card
|
||||
None,
|
||||
self._close_streaming_mode_sync,
|
||||
buf.card_id,
|
||||
buf.sequence,
|
||||
)
|
||||
return
|
||||
logger.warning(
|
||||
"Streaming card {} final update failed, falling back to regular card",
|
||||
buf.card_id,
|
||||
)
|
||||
for chunk in self._split_elements_by_table_limit(
|
||||
self._build_card_elements(buf.text)
|
||||
):
|
||||
card = json.dumps(
|
||||
{"config": {"wide_screen_mode": True}, "elements": chunk},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync, rid_type, chat_id, "interactive", card
|
||||
)
|
||||
return
|
||||
|
||||
# --- accumulate delta ---
|
||||
@ -1346,13 +1386,33 @@ class FeishuChannel(BaseChannel):
|
||||
receive_id_type = "chat_id" if msg.chat_id.startswith("oc_") else "open_id"
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# Handle tool hint messages as code blocks in interactive cards.
|
||||
# These are progress-only messages and should bypass normal reply routing.
|
||||
# Handle tool hint messages. When a streaming card is active for
|
||||
# this chat, inline the hint into the card instead of sending a
|
||||
# separate message so the user experience stays cohesive.
|
||||
if msg.metadata.get("_tool_hint"):
|
||||
if msg.content and msg.content.strip():
|
||||
await self._send_tool_hint_card(
|
||||
receive_id_type, msg.chat_id, msg.content.strip()
|
||||
hint = (msg.content or "").strip()
|
||||
if not hint:
|
||||
return
|
||||
buf = self._stream_bufs.get(msg.chat_id)
|
||||
if buf and buf.card_id:
|
||||
# Delegate to send_delta so tool hints get the same
|
||||
# throttling (and card creation) as regular text deltas.
|
||||
await self.send_delta(
|
||||
msg.chat_id,
|
||||
"\n\n" + self._format_tool_hint_delta(hint) + "\n\n",
|
||||
)
|
||||
return
|
||||
# No active streaming card — send as a regular
|
||||
# interactive card with the same 🔧 prefix style.
|
||||
card = json.dumps(
|
||||
{"config": {"wide_screen_mode": True}, "elements": [
|
||||
{"tag": "markdown", "content": self._format_tool_hint_delta(hint)},
|
||||
]},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync, receive_id_type, msg.chat_id, "interactive", card
|
||||
)
|
||||
return
|
||||
|
||||
# Determine whether the first message should quote the user's message.
|
||||
@ -1648,33 +1708,9 @@ class FeishuChannel(BaseChannel):
|
||||
|
||||
return "\n".join(part for part in parts if part)
|
||||
|
||||
async def _send_tool_hint_card(
|
||||
self, receive_id_type: str, receive_id: str, tool_hint: str
|
||||
) -> None:
|
||||
"""Send tool hint as an interactive card with formatted code block.
|
||||
|
||||
Args:
|
||||
receive_id_type: "chat_id" or "open_id"
|
||||
receive_id: The target chat or user ID
|
||||
tool_hint: Formatted tool hint string (e.g., 'web_search("q"), read_file("path")')
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
# Put each top-level tool call on its own line without altering commas inside arguments.
|
||||
formatted_code = self._format_tool_hint_lines(tool_hint)
|
||||
|
||||
card = {
|
||||
"config": {"wide_screen_mode": True},
|
||||
"elements": [
|
||||
{"tag": "markdown", "content": f"**Tool Calls**\n\n```text\n{formatted_code}\n```"}
|
||||
],
|
||||
}
|
||||
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
self._send_message_sync,
|
||||
receive_id_type,
|
||||
receive_id,
|
||||
"interactive",
|
||||
json.dumps(card, ensure_ascii=False),
|
||||
def _format_tool_hint_delta(self, tool_hint: str) -> str:
|
||||
"""Format a tool hint string with the 🔧 prefix for each line."""
|
||||
lines = self.__class__._format_tool_hint_lines(tool_hint).split("\n")
|
||||
return "\n".join(
|
||||
f"{self.config.tool_hint_prefix} {ln}" for ln in lines if ln.strip()
|
||||
)
|
||||
|
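The audio/file download hunk above picks a fallback filename and appends `.ogg` to voice messages that lack a known audio extension. A small sketch of just that naming rule follows; the function name `normalize_media_filename` and the constant `AUDIO_EXTENSIONS` are hypothetical, and the download and error-reporting steps are left out:

```python
from typing import Optional

# Extensions the diff treats as already-acceptable audio names.
AUDIO_EXTENSIONS = (".opus", ".ogg", ".oga")


def normalize_media_filename(filename: Optional[str], file_key: str, msg_type: str) -> str:
    """Return a usable filename for a downloaded Feishu attachment."""
    # Fall back to a prefix of the file key when no name was provided.
    name = filename or file_key[:16]
    # Voice messages get .ogg unless they already carry an audio extension.
    if msg_type == "audio" and not name.endswith(AUDIO_EXTENSIONS):
        name = f"{name}.ogg"
    return name
```

For instance, `normalize_media_filename("voice.opus", "k", "audio")` leaves the name unchanged, while a missing filename on an audio message yields the truncated key plus `.ogg`.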
||||
@ -242,43 +242,49 @@ class QQChannel(BaseChannel):
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send attachments first, then text."""
|
||||
if not self._client:
|
||||
logger.warning("QQ client not initialized")
|
||||
return
|
||||
try:
|
||||
if not self._client:
|
||||
logger.warning("QQ client not initialized")
|
||||
return
|
||||
|
||||
msg_id = msg.metadata.get("message_id")
|
||||
chat_type = self._chat_type_cache.get(msg.chat_id, "c2c")
|
||||
is_group = chat_type == "group"
|
||||
msg_id = msg.metadata.get("message_id")
|
||||
chat_type = self._chat_type_cache.get(msg.chat_id, "c2c")
|
||||
is_group = chat_type == "group"
|
||||
|
||||
# 1) Send media
|
||||
for media_ref in msg.media or []:
|
||||
ok = await self._send_media(
|
||||
chat_id=msg.chat_id,
|
||||
media_ref=media_ref,
|
||||
msg_id=msg_id,
|
||||
is_group=is_group,
|
||||
)
|
||||
if not ok:
|
||||
filename = (
|
||||
os.path.basename(urlparse(media_ref).path)
|
||||
or os.path.basename(media_ref)
|
||||
or "file"
|
||||
# 1) Send media
|
||||
for media_ref in msg.media or []:
|
||||
ok = await self._send_media(
|
||||
chat_id=msg.chat_id,
|
||||
media_ref=media_ref,
|
||||
msg_id=msg_id,
|
||||
is_group=is_group,
|
||||
)
|
||||
if not ok:
|
||||
filename = (
|
||||
os.path.basename(urlparse(media_ref).path)
|
||||
or os.path.basename(media_ref)
|
||||
or "file"
|
||||
)
|
||||
await self._send_text_only(
|
||||
chat_id=msg.chat_id,
|
||||
is_group=is_group,
|
||||
msg_id=msg_id,
|
||||
content=f"[Attachment send failed: {filename}]",
|
||||
)
|
||||
|
||||
# 2) Send text
|
||||
if msg.content and msg.content.strip():
|
||||
await self._send_text_only(
|
||||
chat_id=msg.chat_id,
|
||||
is_group=is_group,
|
||||
msg_id=msg_id,
|
||||
content=f"[Attachment send failed: {filename}]",
|
||||
content=msg.content.strip(),
|
||||
)
|
||||
|
||||
# 2) Send text
|
||||
if msg.content and msg.content.strip():
|
||||
await self._send_text_only(
|
||||
chat_id=msg.chat_id,
|
||||
is_group=is_group,
|
||||
msg_id=msg_id,
|
||||
content=msg.content.strip(),
|
||||
)
|
||||
except (aiohttp.ClientError, OSError):
|
||||
# Network / transport errors — propagate so ChannelManager can retry
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Error sending QQ message to chat_id={}", msg.chat_id)
|
||||
|
||||
async def _send_text_only(
|
||||
self,
|
||||
@ -359,7 +365,12 @@ class QQChannel(BaseChannel):
|
||||
|
||||
logger.info("QQ media sent: {}", filename)
|
||||
return True
|
||||
except (aiohttp.ClientError, OSError) as e:
|
||||
# Network / transport errors — propagate for retry by caller
|
||||
logger.warning("QQ send media network error filename={} err={}", filename, e)
|
||||
raise
|
||||
except Exception as e:
|
||||
# API-level or other non-network errors — return False so send() can fallback
|
||||
logger.error("QQ send media failed filename={} err={}", filename, e)
|
||||
return False
|
||||
|
||||
@ -438,15 +449,26 @@ class QQChannel(BaseChannel):
|
||||
endpoint = "/v2/users/{openid}/files"
|
||||
id_key = "openid"
|
||||
|
||||
payload = {
|
||||
payload: dict[str, Any] = {
|
||||
id_key: chat_id,
|
||||
"file_type": file_type,
|
||||
"file_data": file_data,
|
||||
"file_name": file_name,
|
||||
"srv_send_msg": srv_send_msg,
|
||||
}
|
||||
# Only pass file_name for non-image types (file_type=4).
|
||||
# Passing file_name for images causes QQ client to render them as
|
||||
# file attachments instead of inline images.
|
||||
if file_type != QQ_FILE_TYPE_IMAGE and file_name:
|
||||
payload["file_name"] = file_name
|
||||
|
||||
route = Route("POST", endpoint, **{id_key: chat_id})
|
||||
return await self._client.api._http.request(route, json=payload)
|
||||
result = await self._client.api._http.request(route, json=payload)
|
||||
|
||||
# Extract only the file_info field to avoid extra fields (file_uuid, ttl, etc.)
|
||||
# that may confuse QQ client when sending the media object.
|
||||
if isinstance(result, dict) and "file_info" in result:
|
||||
return {"file_info": result["file_info"]}
|
||||
return result
|
||||
|
||||
# ---------------------------
|
||||
# Inbound (receive)
|
||||
@ -454,58 +476,68 @@ class QQChannel(BaseChannel):
|
||||
|
||||
async def _on_message(self, data: C2CMessage | GroupMessage, is_group: bool = False) -> None:
|
||||
"""Parse inbound message, download attachments, and publish to the bus."""
|
||||
if data.id in self._processed_ids:
|
||||
return
|
||||
self._processed_ids.append(data.id)
|
||||
try:
|
||||
if data.id in self._processed_ids:
|
||||
return
|
||||
self._processed_ids.append(data.id)
|
||||
|
||||
if is_group:
|
||||
chat_id = data.group_openid
|
||||
user_id = data.author.member_openid
|
||||
self._chat_type_cache[chat_id] = "group"
|
||||
else:
|
||||
chat_id = str(
|
||||
getattr(data.author, "id", None) or getattr(data.author, "user_openid", "unknown")
|
||||
)
|
||||
user_id = chat_id
|
||||
self._chat_type_cache[chat_id] = "c2c"
|
||||
|
||||
content = (data.content or "").strip()
|
||||
|
||||
# The data used by tests doesn't contain an attachments property,
|
||||
# so we use getattr with a default of [] to avoid AttributeError in tests.
|
||||
attachments = getattr(data, "attachments", None) or []
|
||||
media_paths, recv_lines, att_meta = await self._handle_attachments(attachments)
|
||||
|
||||
# Compose content that always contains actionable saved paths
|
||||
if recv_lines:
|
||||
tag = "[Image]" if any(_is_image_name(Path(p).name) for p in media_paths) else "[File]"
|
||||
file_block = "Received files:\n" + "\n".join(recv_lines)
|
||||
content = f"{content}\n\n{file_block}".strip() if content else f"{tag}\n{file_block}"
|
||||
|
||||
if not content and not media_paths:
|
||||
return
|
||||
|
||||
if self.config.ack_message:
|
||||
try:
|
||||
await self._send_text_only(
|
||||
chat_id=chat_id,
|
||||
is_group=is_group,
|
||||
msg_id=data.id,
|
||||
content=self.config.ack_message,
|
||||
if is_group:
|
||||
chat_id = data.group_openid
|
||||
user_id = data.author.member_openid
|
||||
self._chat_type_cache[chat_id] = "group"
|
||||
else:
|
||||
chat_id = str(
|
||||
getattr(data.author, "id", None)
|
||||
or getattr(data.author, "user_openid", "unknown")
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("QQ ack message failed for chat_id={}", chat_id)
|
||||
user_id = chat_id
|
||||
self._chat_type_cache[chat_id] = "c2c"
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=user_id,
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
media=media_paths if media_paths else None,
|
||||
metadata={
|
||||
"message_id": data.id,
|
||||
"attachments": att_meta,
|
||||
},
|
||||
)
|
||||
content = (data.content or "").strip()
|
||||
|
||||
# The data used by tests doesn't contain an attachments property,
|
||||
# so we use getattr with a default of [] to avoid AttributeError in tests.
|
||||
attachments = getattr(data, "attachments", None) or []
|
||||
media_paths, recv_lines, att_meta = await self._handle_attachments(attachments)
|
||||
|
||||
# Compose content that always contains actionable saved paths
|
||||
if recv_lines:
|
||||
tag = (
|
||||
"[Image]"
|
||||
if any(_is_image_name(Path(p).name) for p in media_paths)
|
||||
else "[File]"
|
||||
)
|
||||
file_block = "Received files:\n" + "\n".join(recv_lines)
|
||||
content = (
|
||||
f"{content}\n\n{file_block}".strip() if content else f"{tag}\n{file_block}"
|
||||
)
|
||||
|
||||
if not content and not media_paths:
|
||||
return
|
||||
|
||||
if self.config.ack_message:
|
||||
try:
|
||||
await self._send_text_only(
|
||||
chat_id=chat_id,
|
||||
is_group=is_group,
|
||||
msg_id=data.id,
|
||||
content=self.config.ack_message,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("QQ ack message failed for chat_id={}", chat_id)
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=user_id,
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
media=media_paths if media_paths else None,
|
||||
metadata={
|
||||
"message_id": data.id,
|
||||
"attachments": att_meta,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error handling QQ inbound message id={}", getattr(data, "id", "?"))
|
||||
|
||||
async def _handle_attachments(
|
||||
self,
|
||||
@ -520,7 +552,9 @@ class QQChannel(BaseChannel):
|
||||
return media_paths, recv_lines, att_meta
|
||||
|
||||
for att in attachments:
|
||||
url, filename, ctype = att.url, att.filename, att.content_type
|
||||
url = getattr(att, "url", None) or ""
|
||||
filename = getattr(att, "filename", None) or ""
|
||||
ctype = getattr(att, "content_type", None) or ""
|
||||
|
||||
logger.info("Downloading file from QQ: {}", filename or url)
|
||||
local_path = await self._download_to_media_dir_chunked(url, filename_hint=filename)
|
||||
@ -555,6 +589,10 @@ class QQChannel(BaseChannel):
|
||||
Enforces a max download size and writes to a .part temp file
|
||||
that is atomically renamed on success.
|
||||
"""
|
||||
# Handle protocol-relative URLs (e.g. "//multimedia.nt.qq.com/...")
|
||||
if url.startswith("//"):
|
||||
url = f"https:{url}"
|
||||
|
||||
if not self._http:
|
||||
self._http = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120))
|
||||
|
||||
|
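The QQ hunks above separate two failure classes: transport errors are re-raised so the caller can retry, while API-level errors are reported as `False` so `send()` can fall back to a text notice. A minimal sketch of that split, assuming only the standard library (the diff also catches `aiohttp.ClientError`, omitted here) and a hypothetical `upload` callable:

```python
import asyncio


async def send_media(upload) -> bool:
    """Run an upload coroutine function and classify its failure mode."""
    try:
        await upload()
        return True
    except OSError:
        # Network / transport problem: propagate so a retry layer can act.
        raise
    except Exception:
        # Any other failure: signal "not sent" so the caller can fall back.
        return False
```

The order of the `except` clauses is what makes this work: `OSError` is matched first, so only the remaining exception types are converted into a `False` result.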
||||
@ -5,6 +5,7 @@ import re
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from slack_sdk.socket_mode.websockets import SocketModeClient
|
||||
@ -13,8 +14,6 @@ from slackify_markdown import slackify_markdown
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from pydantic import Field
|
||||
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import Base
|
||||
|
||||
@ -50,6 +49,9 @@ class SlackChannel(BaseChannel):
|
||||
|
||||
name = "slack"
|
||||
display_name = "Slack"
|
||||
_SLACK_ID_RE = re.compile(r"^[CDGUW][A-Z0-9]{2,}$")
|
||||
_SLACK_CHANNEL_REF_RE = re.compile(r"^<#([A-Z0-9]+)(?:\|[^>]+)?>$")
|
||||
_SLACK_USER_REF_RE = re.compile(r"^<@([A-Z0-9]+)(?:\|[^>]+)?>$")
|
||||
|
||||
@classmethod
|
||||
def default_config(cls) -> dict[str, Any]:
|
||||
@ -63,6 +65,7 @@ class SlackChannel(BaseChannel):
|
||||
self._web_client: AsyncWebClient | None = None
|
||||
self._socket_client: SocketModeClient | None = None
|
||||
self._bot_user_id: str | None = None
|
||||
self._target_cache: dict[str, str] = {}
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Slack Socket Mode client."""
|
||||
@ -113,17 +116,23 @@ class SlackChannel(BaseChannel):
|
||||
logger.warning("Slack client not running")
|
||||
return
|
||||
try:
|
||||
target_chat_id = await self._resolve_target_chat_id(msg.chat_id)
|
||||
slack_meta = msg.metadata.get("slack", {}) if msg.metadata else {}
|
||||
thread_ts = slack_meta.get("thread_ts")
|
||||
channel_type = slack_meta.get("channel_type")
|
||||
origin_chat_id = str((slack_meta.get("event", {}) or {}).get("channel") or msg.chat_id)
|
||||
# Slack DMs don't use threads; channel/group replies may keep thread_ts.
|
||||
thread_ts_param = thread_ts if thread_ts and channel_type != "im" else None
|
||||
thread_ts_param = (
|
||||
thread_ts
|
||||
if thread_ts and channel_type != "im" and target_chat_id == origin_chat_id
|
||||
else None
|
||||
)
|
||||
|
||||
# Slack rejects empty text payloads. Keep media-only messages media-only,
|
||||
# but send a single blank message when the bot has no text or files to send.
|
||||
if msg.content or not (msg.media or []):
|
||||
await self._web_client.chat_postMessage(
|
||||
channel=msg.chat_id,
|
||||
channel=target_chat_id,
|
||||
text=self._to_mrkdwn(msg.content) if msg.content else " ",
|
||||
thread_ts=thread_ts_param,
|
||||
)
|
||||
@ -131,7 +140,7 @@ class SlackChannel(BaseChannel):
|
||||
for media_path in msg.media or []:
|
||||
try:
|
||||
await self._web_client.files_upload_v2(
|
||||
channel=msg.chat_id,
|
||||
channel=target_chat_id,
|
||||
file=media_path,
|
||||
thread_ts=thread_ts_param,
|
||||
)
|
||||
@ -141,12 +150,123 @@ class SlackChannel(BaseChannel):
|
||||
# Update reaction emoji when the final (non-progress) response is sent
|
||||
if not (msg.metadata or {}).get("_progress"):
|
||||
event = slack_meta.get("event", {})
|
||||
await self._update_react_emoji(msg.chat_id, event.get("ts"))
|
||||
await self._update_react_emoji(origin_chat_id, event.get("ts"))
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error sending Slack message: {}", e)
|
||||
raise
|
||||
|
||||
async def _resolve_target_chat_id(self, target: str) -> str:
|
||||
"""Resolve human-friendly Slack targets to concrete IDs when needed."""
|
||||
if not self._web_client:
|
||||
return target
|
||||
|
||||
target = target.strip()
|
||||
if not target:
|
||||
return target
|
||||
|
||||
if match := self._SLACK_CHANNEL_REF_RE.fullmatch(target):
|
||||
return match.group(1)
|
||||
if match := self._SLACK_USER_REF_RE.fullmatch(target):
|
||||
return await self._open_dm_for_user(match.group(1))
|
||||
if self._SLACK_ID_RE.fullmatch(target):
|
||||
if target.startswith(("U", "W")):
|
||||
return await self._open_dm_for_user(target)
|
||||
return target
|
||||
|
||||
if target.startswith("#"):
|
||||
return await self._resolve_channel_name(target[1:])
|
||||
if target.startswith("@"):
|
||||
return await self._resolve_user_handle(target[1:])
|
||||
|
||||
try:
|
||||
return await self._resolve_channel_name(target)
|
||||
except ValueError:
|
||||
return await self._resolve_user_handle(target)
|
||||
|
||||
async def _resolve_channel_name(self, name: str) -> str:
|
||||
normalized = self._normalize_target_name(name)
|
||||
if not normalized:
|
||||
raise ValueError("Slack target channel name is empty")
|
||||
|
||||
cache_key = f"channel:{normalized}"
|
||||
if cache_key in self._target_cache:
|
||||
return self._target_cache[cache_key]
|
||||
|
||||
cursor: str | None = None
|
||||
while True:
|
||||
response = await self._web_client.conversations_list(
|
||||
types="public_channel,private_channel",
|
||||
exclude_archived=True,
|
||||
limit=200,
|
||||
cursor=cursor,
|
||||
)
|
||||
for channel in response.get("channels", []):
|
||||
if self._normalize_target_name(str(channel.get("name") or "")) == normalized:
|
||||
channel_id = str(channel.get("id") or "")
|
||||
if channel_id:
|
||||
self._target_cache[cache_key] = channel_id
|
||||
return channel_id
|
||||
cursor = ((response.get("response_metadata") or {}).get("next_cursor") or "").strip()
|
||||
if not cursor:
|
||||
break
|
||||
|
||||
raise ValueError(
|
||||
f"Slack channel '{name}' was not found. Use a joined channel name like "
|
||||
f"'#general' or a concrete channel ID."
|
||||
)
|
||||
|
||||
async def _resolve_user_handle(self, handle: str) -> str:
|
||||
normalized = self._normalize_target_name(handle)
|
||||
if not normalized:
|
||||
raise ValueError("Slack target user handle is empty")
|
||||
|
||||
cache_key = f"user:{normalized}"
|
||||
if cache_key in self._target_cache:
|
||||
return self._target_cache[cache_key]
|
||||
|
||||
cursor: str | None = None
|
||||
while True:
|
||||
response = await self._web_client.users_list(limit=200, cursor=cursor)
|
||||
for member in response.get("members", []):
|
||||
if self._member_matches_handle(member, normalized):
|
||||
user_id = str(member.get("id") or "")
|
||||
if not user_id:
|
||||
continue
|
||||
dm_id = await self._open_dm_for_user(user_id)
|
||||
self._target_cache[cache_key] = dm_id
|
||||
return dm_id
|
||||
cursor = ((response.get("response_metadata") or {}).get("next_cursor") or "").strip()
|
||||
if not cursor:
|
||||
break
|
||||
|
||||
raise ValueError(
|
||||
f"Slack user '{handle}' was not found. Use '@name' or a concrete DM/channel ID."
|
||||
)
|
||||
|
||||
async def _open_dm_for_user(self, user_id: str) -> str:
|
||||
response = await self._web_client.conversations_open(users=user_id)
|
||||
channel_id = str(((response.get("channel") or {}).get("id")) or "")
|
||||
if not channel_id:
|
||||
raise ValueError(f"Slack DM target for user '{user_id}' could not be opened.")
|
||||
return channel_id
|
||||
|
||||
@staticmethod
|
||||
def _normalize_target_name(value: str) -> str:
|
||||
return value.strip().lstrip("#@").lower()
|
||||
|
||||
@classmethod
|
||||
def _member_matches_handle(cls, member: dict[str, Any], normalized: str) -> bool:
|
||||
profile = member.get("profile") or {}
|
||||
candidates = {
|
||||
str(member.get("name") or ""),
|
||||
str(profile.get("display_name") or ""),
|
||||
str(profile.get("display_name_normalized") or ""),
|
||||
str(profile.get("real_name") or ""),
|
||||
str(profile.get("real_name_normalized") or ""),
|
||||
}
|
||||
return normalized in {cls._normalize_target_name(candidate) for candidate in candidates if candidate}
|
||||
|
||||
async def _on_socket_request(
|
||||
self,
|
||||
client: SocketModeClient,
|
||||
|
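The order of checks in `_resolve_target_chat_id` above can be illustrated without any Slack API calls. The sketch below reuses the three regular expressions from the diff and only classifies a target string; the function name `classify_target` and the returned labels are hypothetical, and the real method goes on to open DMs or list channels and users:

```python
import re

SLACK_ID_RE = re.compile(r"^[CDGUW][A-Z0-9]{2,}$")
SLACK_CHANNEL_REF_RE = re.compile(r"^<#([A-Z0-9]+)(?:\|[^>]+)?>$")
SLACK_USER_REF_RE = re.compile(r"^<@([A-Z0-9]+)(?:\|[^>]+)?>$")


def classify_target(target: str):
    """Return a (kind, value) pair describing how a target would be resolved."""
    target = target.strip()
    if m := SLACK_CHANNEL_REF_RE.fullmatch(target):
        return ("channel_id", m.group(1))
    if m := SLACK_USER_REF_RE.fullmatch(target):
        return ("user_id", m.group(1))
    if SLACK_ID_RE.fullmatch(target):
        # U... and W... IDs denote users; the rest are conversation IDs.
        kind = "user_id" if target.startswith(("U", "W")) else "channel_id"
        return (kind, target)
    if target.startswith("#"):
        return ("channel_name", target[1:].lower())
    if target.startswith("@"):
        return ("user_handle", target[1:].lower())
    # Bare names: the diff tries a channel lookup first, then a user lookup.
    return ("name", target.lower())
```

One consequence of checking the ID pattern before names is that a bare all-uppercase name starting with C, D, G, U, or W is treated as an ID rather than looked up by name.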
||||
@ -166,6 +166,7 @@ def _markdown_to_telegram_html(text: str) -> str:
|
||||
|
||||
_SEND_MAX_RETRIES = 3
|
||||
_SEND_RETRY_BASE_DELAY = 0.5 # seconds, doubled each retry
|
||||
_STREAM_EDIT_INTERVAL_DEFAULT = 0.6 # min seconds between edit_message_text calls
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -190,6 +191,7 @@ class TelegramConfig(Base):
|
||||
connection_pool_size: int = 32
|
||||
pool_timeout: float = 5.0
|
||||
streaming: bool = True
|
||||
stream_edit_interval: float = Field(default=_STREAM_EDIT_INTERVAL_DEFAULT, ge=0.1)
|
||||
|
||||
|
||||
class TelegramChannel(BaseChannel):
|
||||
@ -219,8 +221,6 @@ class TelegramChannel(BaseChannel):
|
||||
def default_config(cls) -> dict[str, Any]:
|
||||
return TelegramConfig().model_dump(by_alias=True)
|
||||
|
||||
_STREAM_EDIT_INTERVAL = 0.6 # min seconds between edit_message_text calls
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus):
|
||||
if isinstance(config, dict):
|
||||
config = TelegramConfig.model_validate(config)
|
||||
@ -520,7 +520,10 @@ class TelegramChannel(BaseChannel):
|
||||
reply_parameters=reply_params,
|
||||
**(thread_kwargs or {}),
|
||||
)
|
||||
- except Exception as e:
+ except BadRequest as e:
+ # Only fall back to plain text on actual HTML parse/format errors.
+ # Network errors (TimedOut, NetworkError) should propagate immediately
+ # to avoid doubling connection demand during pool exhaustion.
|
||||
logger.warning("HTML parse failed, falling back to plain text: {}", e)
|
||||
try:
|
||||
await self._call_with_retry(
|
||||
@ -567,7 +570,10 @@ class TelegramChannel(BaseChannel):
|
||||
chat_id=int_chat_id, message_id=buf.message_id,
|
||||
text=html, parse_mode="HTML",
|
||||
)
|
||||
- except Exception as e:
+ except BadRequest as e:
+ # Only fall back to plain text on actual HTML parse/format errors.
+ # Network errors (TimedOut, NetworkError) should propagate immediately
+ # to avoid doubling connection demand during pool exhaustion.
|
||||
if self._is_not_modified_error(e):
|
||||
logger.debug("Final stream edit already applied for {}", chat_id)
|
||||
self._stream_bufs.pop(chat_id, None)
|
||||
@ -619,7 +625,7 @@ class TelegramChannel(BaseChannel):
|
||||
except Exception as e:
|
||||
logger.warning("Stream initial send failed: {}", e)
|
||||
raise # Let ChannelManager handle retry
|
||||
- elif (now - buf.last_edit) >= self._STREAM_EDIT_INTERVAL:
+ elif (now - buf.last_edit) >= self.config.stream_edit_interval:
|
||||
try:
|
||||
await self._call_with_retry(
|
||||
self._app.bot.edit_message_text,
|
||||
|
||||
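The last Telegram hunk swaps the hard-coded `_STREAM_EDIT_INTERVAL` for the validated `stream_edit_interval` setting. A minimal sketch of that throttle, assuming a simplified buffer (`StreamBuffer` and `push_delta` are hypothetical names, not part of the channel):

```python
from dataclasses import dataclass


@dataclass
class StreamBuffer:
    """Hypothetical stand-in for the per-chat stream buffer."""

    text: str = ""
    last_edit: float = 0.0


def push_delta(buf: StreamBuffer, delta: str, now: float, interval: float) -> bool:
    """Append a delta; return True when an edit call should be issued now.

    `interval` plays the role of `stream_edit_interval`: edits are skipped
    until at least that many seconds have passed since the previous one.
    """
    buf.text += delta
    if (now - buf.last_edit) >= interval:
        buf.last_edit = now
        return True
    return False
```

With the default of 0.6 s, deltas that arrive faster than the interval are only buffered, and the next delta after the interval flushes the accumulated text in a single edit.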
457
nanobot/channels/websocket.py
Normal file
@ -0,0 +1,457 @@
|
||||
"""WebSocket server channel: nanobot acts as a WebSocket server and serves connected clients."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import email.utils
|
||||
import hmac
|
||||
import http
|
||||
import json
|
||||
import secrets
|
||||
import ssl
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Self
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import Field, field_validator, model_validator
|
||||
from websockets.asyncio.server import ServerConnection, serve
|
||||
from websockets.datastructures import Headers
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
from websockets.http11 import Request as WsRequest, Response
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import Base
|
||||
|
||||
|
||||
def _strip_trailing_slash(path: str) -> str:
|
||||
if len(path) > 1 and path.endswith("/"):
|
||||
return path.rstrip("/")
|
||||
return path or "/"
|
||||
|
||||
|
||||
def _normalize_config_path(path: str) -> str:
|
||||
return _strip_trailing_slash(path)
|
||||
|
||||
|
||||
class WebSocketConfig(Base):
|
||||
"""WebSocket server channel configuration.
|
||||
|
||||
Clients connect with URLs like ``ws://{host}:{port}{path}?client_id=...&token=...``.
|
||||
- ``client_id``: Used for ``allow_from`` authorization; if omitted, a value is generated and logged.
|
||||
- ``token``: If non-empty, the ``token`` query param may match this static secret; short-lived tokens
|
||||
from ``token_issue_path`` are also accepted.
|
||||
- ``token_issue_path``: If non-empty, **GET** (HTTP/1.1) to this path returns JSON
|
||||
``{"token": "...", "expires_in": <seconds>}``; use ``?token=...`` when opening the WebSocket.
|
||||
Must differ from ``path`` (the WS upgrade path). If the client runs in the **same process** as
|
||||
nanobot and shares the asyncio loop, use a thread or async HTTP client for GET—do not call
|
||||
blocking ``urllib`` or synchronous ``httpx`` from inside a coroutine.
|
||||
- ``token_issue_secret``: If non-empty, token requests must send ``Authorization: Bearer <secret>`` or
|
||||
``X-Nanobot-Auth: <secret>``.
|
||||
- ``websocket_requires_token``: If True, the handshake must include a valid token (static or issued and not expired).
|
||||
- Each connection has its own session: a unique ``chat_id`` maps to the agent session internally.
|
||||
- ``media`` field in outbound messages contains local filesystem paths; remote clients need a
|
||||
shared filesystem or an HTTP file server to access these files.
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 8765
|
||||
path: str = "/"
|
||||
token: str = ""
|
||||
token_issue_path: str = ""
|
||||
token_issue_secret: str = ""
|
||||
token_ttl_s: int = Field(default=300, ge=30, le=86_400)
|
||||
websocket_requires_token: bool = True
|
||||
allow_from: list[str] = Field(default_factory=lambda: ["*"])
|
||||
streaming: bool = True
|
||||
max_message_bytes: int = Field(default=1_048_576, ge=1024, le=16_777_216)
|
||||
ping_interval_s: float = Field(default=20.0, ge=5.0, le=300.0)
|
||||
ping_timeout_s: float = Field(default=20.0, ge=5.0, le=300.0)
|
||||
ssl_certfile: str = ""
|
||||
ssl_keyfile: str = ""
|
||||
|
||||
@field_validator("path")
|
||||
@classmethod
|
||||
def path_must_start_with_slash(cls, value: str) -> str:
|
||||
if not value.startswith("/"):
|
||||
raise ValueError('path must start with "/"')
|
||||
return _normalize_config_path(value)
|
||||
|
||||
@field_validator("token_issue_path")
|
||||
@classmethod
|
||||
def token_issue_path_format(cls, value: str) -> str:
|
||||
value = value.strip()
|
||||
if not value:
|
||||
return ""
|
||||
if not value.startswith("/"):
|
||||
raise ValueError('token_issue_path must start with "/"')
|
||||
return _normalize_config_path(value)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def token_issue_path_differs_from_ws_path(self) -> Self:
|
||||
if not self.token_issue_path:
|
||||
return self
|
||||
if _normalize_config_path(self.token_issue_path) == _normalize_config_path(self.path):
|
||||
raise ValueError("token_issue_path must differ from path (the WebSocket upgrade path)")
|
||||
return self
|
||||
|
||||
|
||||
def _http_json_response(data: dict[str, Any], *, status: int = 200) -> Response:
|
||||
body = json.dumps(data, ensure_ascii=False).encode("utf-8")
|
||||
headers = Headers(
|
||||
[
|
||||
("Date", email.utils.formatdate(usegmt=True)),
|
||||
("Connection", "close"),
|
||||
("Content-Length", str(len(body))),
|
||||
("Content-Type", "application/json; charset=utf-8"),
|
||||
]
|
||||
)
|
||||
reason = http.HTTPStatus(status).phrase
|
||||
return Response(status, reason, headers, body)
|
||||
|
||||
|
||||
def _parse_request_path(path_with_query: str) -> tuple[str, dict[str, list[str]]]:
|
||||
"""Parse normalized path and query parameters in one pass."""
|
||||
parsed = urlparse("ws://x" + path_with_query)
|
||||
path = _strip_trailing_slash(parsed.path or "/")
|
||||
return path, parse_qs(parsed.query)
|
||||
|
||||
|
||||
def _normalize_http_path(path_with_query: str) -> str:
|
||||
"""Return the path component (no query string), with trailing slash normalized (root stays ``/``)."""
|
||||
return _parse_request_path(path_with_query)[0]
|
||||
|
||||
|
||||
def _parse_query(path_with_query: str) -> dict[str, list[str]]:
|
||||
return _parse_request_path(path_with_query)[1]
|
||||
|
||||
|
||||
def _query_first(query: dict[str, list[str]], key: str) -> str | None:
|
||||
"""Return the first value for *key*, or None."""
|
||||
values = query.get(key)
|
||||
return values[0] if values else None
|
||||
|
||||
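The parsing helpers above all go through `urlparse` and `parse_qs`. The sketch below shows the same idea in one standalone function (`split_path_and_query` is a hypothetical name). Unlike `_strip_trailing_slash`, which reduces an all-slash path such as `//` to an empty string, this variant maps it to `/`; that is an assumption about the desired behavior, not what the diff does.

```python
from urllib.parse import parse_qs, urlparse


def split_path_and_query(path_with_query: str) -> tuple[str, dict[str, list[str]]]:
    """Split a request target into a normalized path and its query parameters."""
    # A dummy scheme and host let urlparse treat the input as a full URL.
    parsed = urlparse("ws://x" + path_with_query)
    path = parsed.path.rstrip("/") or "/"
    # parse_qs drops blank values by default, so "?token=" yields no "token" key.
    return path, parse_qs(parsed.query)
```

Repeated parameters are kept as lists, which is why `_query_first` takes only the first value for keys such as `token`.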
|
||||
def _parse_inbound_payload(raw: str) -> str | None:
|
||||
"""Parse a client frame into text; return None for empty or unrecognized content."""
|
||||
text = raw.strip()
|
||||
if not text:
|
||||
return None
|
||||
if text.startswith("{"):
|
||||
try:
|
||||
data = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return text
|
||||
if isinstance(data, dict):
|
||||
for key in ("content", "text", "message"):
|
||||
value = data.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value
|
||||
return None
|
||||
return None
|
||||
return text
|
||||
|
||||
|
||||
def _issue_route_secret_matches(headers: Any, configured_secret: str) -> bool:
|
||||
"""Return True if the token-issue HTTP request carries credentials matching ``token_issue_secret``."""
|
||||
if not configured_secret:
|
||||
return True
|
||||
authorization = headers.get("Authorization") or headers.get("authorization")
|
||||
if authorization and authorization.lower().startswith("bearer "):
|
||||
supplied = authorization[7:].strip()
|
||||
return hmac.compare_digest(supplied, configured_secret)
|
||||
header_token = headers.get("X-Nanobot-Auth") or headers.get("x-nanobot-auth")
|
||||
if not header_token:
|
||||
return False
|
||||
return hmac.compare_digest(header_token.strip(), configured_secret)
|
||||
|
||||
|
||||
class WebSocketChannel(BaseChannel):
|
||||
"""Run a local WebSocket server; forward text/JSON messages to the message bus."""
|
||||
|
||||
name = "websocket"
|
||||
display_name = "WebSocket"
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus):
|
||||
if isinstance(config, dict):
|
||||
config = WebSocketConfig.model_validate(config)
|
||||
super().__init__(config, bus)
|
||||
self.config: WebSocketConfig = config
|
||||
self._connections: dict[str, Any] = {}
|
||||
self._issued_tokens: dict[str, float] = {}
|
||||
self._stop_event: asyncio.Event | None = None
|
||||
self._server_task: asyncio.Task[None] | None = None
|
||||
|
||||
@classmethod
|
||||
def default_config(cls) -> dict[str, Any]:
|
||||
return WebSocketConfig().model_dump(by_alias=True)
|
||||
|
||||
def _expected_path(self) -> str:
|
||||
return _normalize_config_path(self.config.path)
|
||||
|
||||
def _build_ssl_context(self) -> ssl.SSLContext | None:
|
||||
cert = self.config.ssl_certfile.strip()
|
||||
key = self.config.ssl_keyfile.strip()
|
||||
if not cert and not key:
|
||||
return None
|
||||
if not cert or not key:
|
||||
raise ValueError(
|
||||
"websocket: ssl_certfile and ssl_keyfile must both be set for WSS, or both left empty"
|
||||
)
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
ctx.minimum_version = ssl.TLSVersion.TLSv1_2
|
||||
ctx.load_cert_chain(certfile=cert, keyfile=key)
|
||||
return ctx
|
||||
|
||||
_MAX_ISSUED_TOKENS = 10_000
|
||||
|
||||
def _purge_expired_issued_tokens(self) -> None:
|
||||
now = time.monotonic()
|
||||
for token_key, expiry in list(self._issued_tokens.items()):
|
||||
if now > expiry:
|
||||
self._issued_tokens.pop(token_key, None)
|
||||
|
||||
def _take_issued_token_if_valid(self, token_value: str | None) -> bool:
|
||||
"""Validate and consume one issued token (single use per connection attempt).
|
||||
|
||||
Uses single-step pop to minimize the window between lookup and removal;
|
||||
safe under asyncio's single-threaded cooperative model.
|
||||
"""
|
||||
if not token_value:
|
||||
return False
|
||||
self._purge_expired_issued_tokens()
|
||||
expiry = self._issued_tokens.pop(token_value, None)
|
||||
if expiry is None:
|
||||
return False
|
||||
if time.monotonic() > expiry:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _handle_token_issue_http(self, connection: Any, request: Any) -> Any:
|
||||
secret = self.config.token_issue_secret.strip()
|
||||
if secret:
|
||||
if not _issue_route_secret_matches(request.headers, secret):
|
||||
return connection.respond(401, "Unauthorized")
|
||||
else:
|
||||
logger.warning(
|
||||
"websocket: token_issue_path is set but token_issue_secret is empty; "
|
||||
"any client can obtain connection tokens — set token_issue_secret for production."
|
||||
)
|
||||
self._purge_expired_issued_tokens()
|
||||
if len(self._issued_tokens) >= self._MAX_ISSUED_TOKENS:
|
||||
logger.error(
|
||||
"websocket: too many outstanding issued tokens ({}), rejecting issuance",
|
||||
len(self._issued_tokens),
|
||||
)
|
||||
return _http_json_response({"error": "too many outstanding tokens"}, status=429)
|
||||
token_value = f"nbwt_{secrets.token_urlsafe(32)}"
|
||||
self._issued_tokens[token_value] = time.monotonic() + float(self.config.token_ttl_s)
|
||||
|
||||
return _http_json_response(
|
||||
{"token": token_value, "expires_in": self.config.token_ttl_s}
|
||||
)
|
||||
|
||||
def _authorize_websocket_handshake(self, connection: Any, query: dict[str, list[str]]) -> Any:
|
||||
supplied = _query_first(query, "token")
|
||||
static_token = self.config.token.strip()
|
||||
|
||||
if static_token:
|
||||
if supplied and hmac.compare_digest(supplied, static_token):
|
||||
return None
|
||||
if supplied and self._take_issued_token_if_valid(supplied):
|
||||
return None
|
||||
return connection.respond(401, "Unauthorized")
|
||||
|
||||
if self.config.websocket_requires_token:
|
||||
if supplied and self._take_issued_token_if_valid(supplied):
|
||||
return None
|
||||
return connection.respond(401, "Unauthorized")
|
||||
|
||||
if supplied:
|
||||
self._take_issued_token_if_valid(supplied)
|
||||
return None
|
||||
|
||||
async def start(self) -> None:
|
||||
self._running = True
|
||||
self._stop_event = asyncio.Event()
|
||||
|
||||
ssl_context = self._build_ssl_context()
|
||||
scheme = "wss" if ssl_context else "ws"
|
||||
|
||||
async def process_request(
|
||||
connection: ServerConnection,
|
||||
request: WsRequest,
|
||||
) -> Any:
|
||||
got, _ = _parse_request_path(request.path)
|
||||
if self.config.token_issue_path:
|
||||
issue_expected = _normalize_config_path(self.config.token_issue_path)
|
||||
if got == issue_expected:
|
||||
return self._handle_token_issue_http(connection, request)
|
||||
|
||||
expected_ws = self._expected_path()
|
||||
if got != expected_ws:
|
||||
return connection.respond(404, "Not Found")
|
||||
# Early reject before WebSocket upgrade to avoid unnecessary overhead;
|
||||
# _handle_message() performs a second check as defense-in-depth.
|
||||
query = _parse_query(request.path)
|
||||
client_id = _query_first(query, "client_id") or ""
|
||||
if len(client_id) > 128:
|
||||
client_id = client_id[:128]
|
||||
if not self.is_allowed(client_id):
|
||||
return connection.respond(403, "Forbidden")
|
||||
return self._authorize_websocket_handshake(connection, query)
|
||||
|
||||
async def handler(connection: ServerConnection) -> None:
|
||||
await self._connection_loop(connection)
|
||||
|
||||
logger.info(
|
||||
"WebSocket server listening on {}://{}:{}{}",
|
||||
scheme,
|
||||
self.config.host,
|
||||
self.config.port,
|
||||
self.config.path,
|
||||
)
|
||||
if self.config.token_issue_path:
|
||||
logger.info(
|
||||
"WebSocket token issue route: {}://{}:{}{}",
|
||||
scheme,
|
||||
self.config.host,
|
||||
self.config.port,
|
||||
_normalize_config_path(self.config.token_issue_path),
|
||||
)
|
||||
|
||||
async def runner() -> None:
|
||||
async with serve(
|
||||
handler,
|
||||
self.config.host,
|
||||
self.config.port,
|
||||
process_request=process_request,
|
||||
max_size=self.config.max_message_bytes,
|
||||
ping_interval=self.config.ping_interval_s,
|
||||
ping_timeout=self.config.ping_timeout_s,
|
||||
ssl=ssl_context,
|
||||
):
|
||||
assert self._stop_event is not None
|
||||
await self._stop_event.wait()
|
||||
|
||||
self._server_task = asyncio.create_task(runner())
|
||||
await self._server_task
|
||||
|
||||
async def _connection_loop(self, connection: Any) -> None:
|
||||
request = connection.request
|
||||
path_part = request.path if request else "/"
|
||||
_, query = _parse_request_path(path_part)
|
||||
client_id_raw = _query_first(query, "client_id")
|
||||
client_id = client_id_raw.strip() if client_id_raw else ""
|
||||
if not client_id:
|
||||
client_id = f"anon-{uuid.uuid4().hex[:12]}"
|
||||
elif len(client_id) > 128:
|
||||
logger.warning("websocket: client_id too long ({} chars), truncating", len(client_id))
|
||||
client_id = client_id[:128]
|
||||
|
||||
chat_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
await connection.send(
|
||||
json.dumps(
|
||||
{
|
||||
"event": "ready",
|
||||
"chat_id": chat_id,
|
||||
"client_id": client_id,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
# Register only after ready is successfully sent to avoid out-of-order sends
|
||||
self._connections[chat_id] = connection
|
||||
|
||||
async for raw in connection:
|
||||
if isinstance(raw, bytes):
|
||||
try:
|
||||
raw = raw.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
logger.warning("websocket: ignoring non-utf8 binary frame")
|
||||
continue
|
||||
content = _parse_inbound_payload(raw)
|
||||
if content is None:
|
||||
continue
|
||||
await self._handle_message(
|
||||
sender_id=client_id,
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
metadata={"remote": getattr(connection, "remote_address", None)},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("websocket connection ended: {}", e)
|
||||
finally:
|
||||
self._connections.pop(chat_id, None)
|
||||
|
||||
async def stop(self) -> None:
|
||||
if not self._running:
|
||||
return
|
||||
self._running = False
|
||||
if self._stop_event:
|
||||
self._stop_event.set()
|
||||
if self._server_task:
|
||||
try:
|
||||
await self._server_task
|
||||
except Exception as e:
|
||||
logger.warning("websocket: server task error during shutdown: {}", e)
|
||||
self._server_task = None
|
||||
self._connections.clear()
|
||||
self._issued_tokens.clear()
|
||||
|
||||
async def _safe_send(self, chat_id: str, raw: str, *, label: str = "") -> None:
|
||||
"""Send a raw frame, cleaning up dead connections on ConnectionClosed."""
|
||||
connection = self._connections.get(chat_id)
|
||||
if connection is None:
|
||||
return
|
||||
try:
|
||||
await connection.send(raw)
|
||||
except ConnectionClosed:
|
||||
self._connections.pop(chat_id, None)
|
||||
logger.warning("websocket{}connection gone for chat_id={}", label, chat_id)
|
||||
except Exception as e:
|
||||
logger.error("websocket{}send failed: {}", label, e)
|
||||
raise
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
connection = self._connections.get(msg.chat_id)
|
||||
if connection is None:
|
||||
logger.warning("websocket: no active connection for chat_id={}", msg.chat_id)
|
||||
return
|
||||
payload: dict[str, Any] = {
|
||||
"event": "message",
|
||||
"text": msg.content,
|
||||
}
|
||||
if msg.media:
|
||||
payload["media"] = msg.media
|
||||
if msg.reply_to:
|
||||
payload["reply_to"] = msg.reply_to
|
||||
raw = json.dumps(payload, ensure_ascii=False)
|
||||
await self._safe_send(msg.chat_id, raw, label=" ")
|
||||
|
||||
async def send_delta(
|
||||
self,
|
||||
chat_id: str,
|
||||
delta: str,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
if self._connections.get(chat_id) is None:
|
||||
return
|
||||
meta = metadata or {}
|
||||
if meta.get("_stream_end"):
|
||||
body: dict[str, Any] = {"event": "stream_end"}
|
||||
else:
|
||||
body = {
|
||||
"event": "delta",
|
||||
"text": delta,
|
||||
}
|
||||
if meta.get("_stream_id") is not None:
|
||||
body["stream_id"] = meta["_stream_id"]
|
||||
raw = json.dumps(body, ensure_ascii=False)
|
||||
await self._safe_send(chat_id, raw, label=" stream ")
|
||||
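On the wire, the channel emits JSON frames with an `event` field (`ready`, `message`, `delta`, `stream_end`) and accepts either plain text or a JSON object carrying `content`, `text`, or `message`. A client-side sketch of both directions follows; `encode_user_frame` and `collect_replies` are hypothetical names, and the sketch assumes only one stream is active per connection at a time.

```python
import json


def encode_user_frame(text: str) -> str:
    """Wrap user text in the JSON shape the server accepts ("content" key)."""
    return json.dumps({"content": text}, ensure_ascii=False)


def collect_replies(frames: list[str]) -> list[str]:
    """Turn a sequence of server frames into complete reply strings."""
    replies: list[str] = []
    parts: list[str] = []
    for raw in frames:
        event = json.loads(raw)
        kind = event.get("event")
        if kind == "delta":
            parts.append(event.get("text", ""))
        elif kind == "stream_end":
            replies.append("".join(parts))
            parts = []
        elif kind == "message":
            replies.append(event.get("text", ""))
        # "ready" carries chat_id and client_id and is ignored here.
    return replies
```

Whether a final `message` frame also follows a completed stream is not visible in this diff, so a real client may need to deduplicate; frames that include `stream_id` could be grouped by that key instead of by arrival order.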
@ -1,9 +1,13 @@
|
||||
"""WeCom (Enterprise WeChat) channel implementation using wecom_aibot_sdk."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import importlib.util
|
||||
import os
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
@ -17,6 +21,37 @@ from pydantic import Field
|
||||
|
||||
WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
|
||||
|
||||
# Upload safety limits (matching QQ channel defaults)
|
||||
WECOM_UPLOAD_MAX_BYTES = 1024 * 1024 * 200 # 200MB
|
||||
|
||||
# Replace unsafe characters with "_", keep Chinese and common safe punctuation.
|
||||
_SAFE_NAME_RE = re.compile(r"[^\w.\-()\[\]()【】\u4e00-\u9fff]+", re.UNICODE)
|
||||
|
||||
|
||||
def _sanitize_filename(name: str) -> str:
|
||||
"""Sanitize filename to avoid traversal and problematic chars."""
|
||||
name = (name or "").strip()
|
||||
name = Path(name).name
|
||||
name = _SAFE_NAME_RE.sub("_", name).strip("._ ")
|
||||
return name
|
||||
|
||||
|
||||
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}
|
||||
_VIDEO_EXTS = {".mp4", ".avi", ".mov"}
|
||||
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg"}
|
||||
|
||||
|
||||
def _guess_wecom_media_type(filename: str) -> str:
|
||||
"""Classify file extension as WeCom media_type string."""
|
||||
ext = Path(filename).suffix.lower()
|
||||
if ext in _IMAGE_EXTS:
|
||||
return "image"
|
||||
if ext in _VIDEO_EXTS:
|
||||
return "video"
|
||||
if ext in _AUDIO_EXTS:
|
||||
return "voice"
|
||||
return "file"
|
||||
|
||||
class WecomConfig(Base):
|
||||
"""WeCom (Enterprise WeChat) AI Bot channel configuration."""
|
||||
|
||||
@ -217,6 +252,7 @@ class WecomChannel(BaseChannel):
|
||||
chat_id = body.get("chatid", sender_id)
|
||||
|
||||
content_parts = []
|
||||
media_paths: list[str] = []
|
||||
|
||||
if msg_type == "text":
|
||||
text = body.get("text", {}).get("content", "")
|
||||
@ -232,7 +268,8 @@ class WecomChannel(BaseChannel):
|
||||
file_path = await self._download_and_save_media(file_url, aes_key, "image")
|
||||
if file_path:
|
||||
filename = os.path.basename(file_path)
|
||||
- content_parts.append(f"[image: {filename}]\n[Image: source: {file_path}]")
+ content_parts.append(f"[image: {filename}]")
+ media_paths.append(file_path)
|
||||
else:
|
||||
content_parts.append("[image: download failed]")
|
||||
else:
|
||||
@ -256,7 +293,8 @@ class WecomChannel(BaseChannel):
|
||||
if file_url and aes_key:
|
||||
file_path = await self._download_and_save_media(file_url, aes_key, "file", file_name)
|
||||
if file_path:
|
||||
- content_parts.append(f"[file: {file_name}]\n[File: source: {file_path}]")
+ content_parts.append(f"[file: {file_name}]")
+ media_paths.append(file_path)
|
||||
else:
|
||||
content_parts.append(f"[file: {file_name}: download failed]")
|
||||
else:
|
||||
@ -286,12 +324,11 @@ class WecomChannel(BaseChannel):
|
||||
self._chat_frames[chat_id] = frame
|
||||
|
||||
# Forward to message bus
|
||||
- # Note: media paths are included in content for broader model compatibility
  await self._handle_message(
  sender_id=sender_id,
  chat_id=chat_id,
  content=content,
- media=None,
+ media=media_paths or None,
|
||||
metadata={
|
||||
"message_id": msg_id,
|
||||
"msg_type": msg_type,
|
||||
@ -322,13 +359,21 @@ class WecomChannel(BaseChannel):
|
||||
logger.warning("Failed to download media from WeCom")
|
||||
return None
|
||||
|
||||
+ if len(data) > WECOM_UPLOAD_MAX_BYTES:
+ logger.warning(
+ "WeCom inbound media too large: {} bytes (max {})",
+ len(data),
+ WECOM_UPLOAD_MAX_BYTES,
+ )
+ return None

  media_dir = get_media_dir("wecom")
  if not filename:
  filename = fname or f"{media_type}_{hash(file_url) % 100000}"
- filename = os.path.basename(filename)
+ filename = _sanitize_filename(filename)

  file_path = media_dir / filename
- file_path.write_bytes(data)
+ await asyncio.to_thread(file_path.write_bytes, data)
|
||||
logger.debug("Downloaded {} to {}", media_type, file_path)
|
||||
return str(file_path)
|
||||
|
||||
@ -336,6 +381,100 @@ class WecomChannel(BaseChannel):
|
||||
logger.error("Error downloading media: {}", e)
|
||||
return None
|
||||
|
||||
async def _upload_media_ws(
|
||||
self, client: Any, file_path: str,
|
||||
) -> "tuple[str, str] | tuple[None, None]":
|
||||
"""Upload a local file to WeCom via WebSocket 3-step protocol (base64).
|
||||
|
||||
Uses the WeCom WebSocket upload commands directly via
|
||||
``client._ws_manager.send_reply()``:
|
||||
|
||||
``aibot_upload_media_init`` → upload_id
|
||||
``aibot_upload_media_chunk`` × N (≤512 KB raw per chunk, base64)
|
||||
``aibot_upload_media_finish`` → media_id
|
||||
|
||||
Returns (media_id, media_type) on success, (None, None) on failure.
|
||||
"""
|
||||
from wecom_aibot_sdk.utils import generate_req_id as _gen_req_id
|
||||
|
||||
try:
|
||||
fname = os.path.basename(file_path)
|
||||
media_type = _guess_wecom_media_type(fname)
|
||||
|
||||
# Read file size and data in a thread to avoid blocking the event loop
|
||||
def _read_file():
|
||||
file_size = os.path.getsize(file_path)
|
||||
if file_size > WECOM_UPLOAD_MAX_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large: {file_size} bytes (max {WECOM_UPLOAD_MAX_BYTES})"
|
||||
)
|
||||
with open(file_path, "rb") as f:
|
||||
return file_size, f.read()
|
||||
|
||||
file_size, data = await asyncio.to_thread(_read_file)
|
||||
# MD5 is used for file integrity only, not cryptographic security
|
||||
md5_hash = hashlib.md5(data).hexdigest()
|
||||
|
||||
CHUNK_SIZE = 512 * 1024 # 512 KB raw (before base64)
|
||||
mv = memoryview(data)
|
||||
chunk_list = [bytes(mv[i : i + CHUNK_SIZE]) for i in range(0, file_size, CHUNK_SIZE)]
|
||||
n_chunks = len(chunk_list)
|
||||
del mv, data
|
||||
|
||||
# Step 1: init
|
||||
req_id = _gen_req_id("upload_init")
|
||||
resp = await client._ws_manager.send_reply(req_id, {
|
||||
"type": media_type,
|
||||
"filename": fname,
|
||||
"total_size": file_size,
|
||||
"total_chunks": n_chunks,
|
||||
"md5": md5_hash,
|
||||
}, "aibot_upload_media_init")
|
||||
if resp.errcode != 0:
|
||||
logger.warning("WeCom upload init failed ({}): {}", resp.errcode, resp.errmsg)
|
||||
return None, None
|
||||
upload_id = resp.body.get("upload_id") if resp.body else None
|
||||
if not upload_id:
|
||||
logger.warning("WeCom upload init: no upload_id in response")
|
||||
return None, None
|
||||
|
||||
# Step 2: send chunks
|
||||
for i, chunk in enumerate(chunk_list):
|
||||
req_id = _gen_req_id("upload_chunk")
|
||||
resp = await client._ws_manager.send_reply(req_id, {
|
||||
"upload_id": upload_id,
|
||||
"chunk_index": i,
|
||||
"base64_data": base64.b64encode(chunk).decode(),
|
||||
}, "aibot_upload_media_chunk")
|
||||
if resp.errcode != 0:
|
||||
logger.warning("WeCom upload chunk {} failed ({}): {}", i, resp.errcode, resp.errmsg)
|
||||
return None, None
|
||||
|
||||
# Step 3: finish
|
||||
req_id = _gen_req_id("upload_finish")
|
||||
resp = await client._ws_manager.send_reply(req_id, {
|
||||
"upload_id": upload_id,
|
||||
}, "aibot_upload_media_finish")
|
||||
if resp.errcode != 0:
|
||||
logger.warning("WeCom upload finish failed ({}): {}", resp.errcode, resp.errmsg)
|
||||
return None, None
|
||||
|
||||
media_id = resp.body.get("media_id") if resp.body else None
|
||||
if not media_id:
|
||||
logger.warning("WeCom upload finish: no media_id in response body={}", resp.body)
|
||||
return None, None
|
||||
|
||||
suffix = "..." if len(media_id) > 16 else ""
|
||||
logger.debug("WeCom uploaded {} ({}) → media_id={}", fname, media_type, media_id[:16] + suffix)
|
||||
return media_id, media_type
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning("WeCom upload skipped for {}: {}", file_path, e)
|
||||
return None, None
|
||||
except Exception as e:
|
||||
logger.error("WeCom _upload_media_ws error for {}: {}", file_path, e)
|
||||
return None, None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through WeCom."""
|
||||
if not self._client:
|
||||
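The three-step upload in `_upload_media_ws` depends on splitting the file into raw chunks of at most 512 KB, base64-encoding each one, and announcing the chunk count and MD5 up front. The sketch below builds those payloads without sending anything; `plan_upload` is a hypothetical name, and the field names are copied from the diff rather than from WeCom documentation.

```python
import base64
import hashlib

CHUNK_SIZE = 512 * 1024  # raw bytes per chunk, before base64 encoding


def plan_upload(
    data: bytes,
    filename: str,
    media_type: str,
    chunk_size: int = CHUNK_SIZE,
) -> tuple[dict, list[dict]]:
    """Return the init payload and the per-chunk payloads for one file."""
    chunks = [data[i : i + chunk_size] for i in range(0, len(data), chunk_size)]
    init = {
        "type": media_type,
        "filename": filename,
        "total_size": len(data),
        "total_chunks": len(chunks),
        # MD5 serves as an integrity checksum here, not as a security measure.
        "md5": hashlib.md5(data).hexdigest(),
    }
    chunk_bodies = [
        {"chunk_index": i, "base64_data": base64.b64encode(chunk).decode()}
        for i, chunk in enumerate(chunks)
    ]
    return init, chunk_bodies
```

Base64 expands each 512 KB chunk to 699,052 characters (about 683 KiB), so the size of a frame on the wire is roughly a third larger than the raw chunk. An empty file produces zero chunks; how the WeCom side treats that case is not shown in this diff.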
@ -343,29 +482,59 @@ class WecomChannel(BaseChannel):
|
||||
return
|
||||
|
||||
try:
|
||||
- content = msg.content.strip()
- if not content:
- return
+ content = (msg.content or "").strip()
+ is_progress = bool(msg.metadata.get("_progress"))

  # Get the stored frame for this chat
  frame = self._chat_frames.get(msg.chat_id)
- if not frame:
- logger.warning("No frame found for chat {}, cannot reply", msg.chat_id)

+ # Send media files via WebSocket upload
+ for file_path in msg.media or []:
+ if not os.path.isfile(file_path):
+ logger.warning("WeCom media file not found: {}", file_path)
+ continue
+ media_id, media_type = await self._upload_media_ws(self._client, file_path)
+ if media_id:
+ if frame:
+ await self._client.reply(frame, {
+ "msgtype": media_type,
+ media_type: {"media_id": media_id},
+ })
+ else:
+ await self._client.send_message(msg.chat_id, {
+ "msgtype": media_type,
+ media_type: {"media_id": media_id},
+ })
+ logger.debug("WeCom sent {} → {}", media_type, msg.chat_id)
+ else:
+ content += f"\n[file upload failed: {os.path.basename(file_path)}]"

+ if not content:
  return

- # Use streaming reply for better UX
- stream_id = self._generate_req_id("stream")
+ if frame:
+ # Both progress and final messages must use reply_stream (cmd="aibot_respond_msg").
+ # The plain reply() uses cmd="reply" which does not support "text" msgtype
+ # and causes errcode=40008 from WeCom API.
+ stream_id = self._generate_req_id("stream")
+ await self._client.reply_stream(
+ frame,
+ stream_id,
+ content,
+ finish=not is_progress,
+ )
+ logger.debug(
+ "WeCom {} sent to {}",
+ "progress" if is_progress else "message",
+ msg.chat_id,
+ )
+ else:
+ # No frame (e.g. cron push): proactive send only supports markdown
+ await self._client.send_message(msg.chat_id, {
+ "msgtype": "markdown",
+ "markdown": {"content": content},
+ })
+ logger.info("WeCom proactive send to {}", msg.chat_id)

- # Send as streaming message with finish=True
- await self._client.reply_stream(
- frame,
- stream_id,
- content,
- finish=True,
- )

- logger.debug("WeCom message sent to {}", msg.chat_id)

- except Exception as e:
- logger.error("Error sending WeCom message: {}", e)
- raise
+ except Exception:
+ logger.exception("Error sending WeCom message to chat_id={}", msg.chat_id)
|
||||
|
||||
@ -985,7 +985,43 @@ class WeixinChannel(BaseChannel):
|
||||
for media_path in (msg.media or []):
|
||||
try:
|
||||
await self._send_media_file(msg.chat_id, media_path, ctx_token)
|
||||
except (httpx.TimeoutException, httpx.TransportError) as net_err:
|
||||
# Network/transport errors: do NOT fall back to text —
|
||||
# the text send would also likely fail, and the outer
|
||||
# except will re-raise so ChannelManager retries properly.
|
||||
logger.error(
|
||||
"Network error sending WeChat media {}: {}",
|
||||
media_path,
|
||||
net_err,
|
||||
)
|
||||
raise
|
||||
except httpx.HTTPStatusError as http_err:
|
||||
status_code = (
|
||||
http_err.response.status_code
|
||||
if http_err.response is not None
|
||||
else 0
|
||||
)
|
||||
if status_code >= 500:
|
||||
# Server-side / retryable HTTP error — same as network.
|
||||
logger.error(
|
||||
"Server error ({} {}) sending WeChat media {}: {}",
|
||||
status_code,
|
||||
http_err.response.reason_phrase
|
||||
if http_err.response is not None
|
||||
else "",
|
||||
media_path,
|
||||
http_err,
|
||||
)
|
||||
raise
|
||||
# 4xx client errors are NOT retryable — fall back to text.
|
||||
filename = Path(media_path).name
|
||||
logger.error("Failed to send WeChat media {}: {}", media_path, http_err)
|
||||
await self._send_text(
|
||||
msg.chat_id, f"[Failed to send: {filename}]", ctx_token,
|
||||
)
|
||||
except Exception as e:
|
||||
# Non-network errors (format, file-not-found, etc.):
|
||||
# notify the user via text fallback.
|
||||
filename = Path(media_path).name
|
||||
logger.error("Failed to send WeChat media {}: {}", media_path, e)
|
||||
# Notify user about failure via text
|
||||
|
||||
@ -590,6 +590,9 @@ def serve(
|
||||
mcp_servers=runtime_config.tools.mcp_servers,
|
||||
channels_config=runtime_config.channels,
|
||||
timezone=runtime_config.agents.defaults.timezone,
|
||||
unified_session=runtime_config.agents.defaults.unified_session,
|
||||
disabled_skills=runtime_config.agents.defaults.disabled_skills,
|
||||
session_ttl_minutes=runtime_config.agents.defaults.session_ttl_minutes,
|
||||
)
|
||||
|
||||
model_name = runtime_config.agents.defaults.model
|
||||
@ -681,6 +684,9 @@ def gateway(
|
||||
mcp_servers=config.tools.mcp_servers,
|
||||
channels_config=config.channels,
|
||||
timezone=config.agents.defaults.timezone,
|
||||
unified_session=config.agents.defaults.unified_session,
|
||||
disabled_skills=config.agents.defaults.disabled_skills,
|
||||
session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
|
||||
)
|
||||
|
||||
# Set cron callback (needs agent)
|
||||
@ -815,6 +821,48 @@ def gateway(

    console.print(f"[green]✓[/green] Heartbeat: every {hb_cfg.interval_s}s")

    async def _health_server(host: str, health_port: int):
        """Lightweight HTTP health endpoint on the gateway port."""
        import json as _json

        async def handle(reader, writer):
            try:
                data = await asyncio.wait_for(reader.read(4096), timeout=5)
            except (asyncio.TimeoutError, ConnectionError):
                writer.close()
                return

            request_line = data.split(b"\r\n", 1)[0].decode("utf-8", errors="replace")
            method, path = "", ""
            parts = request_line.split(" ")
            if len(parts) >= 2:
                method, path = parts[0], parts[1]

            if method == "GET" and path == "/health":
                body = _json.dumps({"status": "ok"})
                resp = (
                    f"HTTP/1.0 200 OK\r\n"
                    f"Content-Type: application/json\r\n"
                    f"Content-Length: {len(body)}\r\n"
                    f"\r\n{body}"
                )
            else:
                body = "Not Found"
                resp = (
                    f"HTTP/1.0 404 Not Found\r\n"
                    f"Content-Type: text/plain\r\n"
                    f"Content-Length: {len(body)}\r\n"
                    f"\r\n{body}"
                )

            writer.write(resp.encode())
            await writer.drain()
            writer.close()

        server = await asyncio.start_server(handle, host, health_port)
        console.print(f"[green]✓[/green] Health endpoint: http://{host}:{health_port}/health")
        async with server:
            await server.serve_forever()
    # Register Dream system job (always-on, idempotent on restart)
    dream_cfg = config.agents.defaults.dream
    if dream_cfg.model_override:
|
||||
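Review note: one way to exercise a health endpoint of this shape is a small asyncio probe. The sketch below is self-contained: `handle` is a condensed copy of the handler logic in the hunk above, while `probe` and `demo` are hypothetical helper names introduced only for illustration. It binds to port 0 so the OS picks a free port.

```python
import asyncio
import json


async def handle(reader, writer):
    """Answer GET /health with a JSON body; anything else gets 404."""
    data = await reader.read(4096)
    request_line = data.split(b"\r\n", 1)[0].decode("utf-8", errors="replace")
    parts = request_line.split(" ")
    ok = len(parts) >= 2 and parts[0] == "GET" and parts[1] == "/health"
    body = json.dumps({"status": "ok"}) if ok else "Not Found"
    status = "200 OK" if ok else "404 Not Found"
    ctype = "application/json" if ok else "text/plain"
    writer.write(
        f"HTTP/1.0 {status}\r\nContent-Type: {ctype}\r\n"
        f"Content-Length: {len(body)}\r\n\r\n{body}".encode()
    )
    await writer.drain()
    writer.close()


async def probe(host: str, port: int, path: str = "/health"):
    """Send one HTTP/1.0 GET and return (status_code, body)."""
    reader, writer = await asyncio.open_connection(host, port)
    writer.write(f"GET {path} HTTP/1.0\r\n\r\n".encode())
    await writer.drain()
    raw = await reader.read()  # the server closes the connection after replying
    writer.close()
    head, _, body = raw.decode().partition("\r\n\r\n")
    return int(head.split(" ")[1]), body


async def demo():
    """Start the server on a free port and probe a good and a bad path."""
    server = await asyncio.start_server(handle, "127.0.0.1", 0)
    port = server.sockets[0].getsockname()[1]
    async with server:
        return [await probe("127.0.0.1", port), await probe("127.0.0.1", port, "/nope")]
```

Running `asyncio.run(demo())` returns one result for `/health` and one for an unknown path, which is roughly what a container liveness check would see.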
@ -837,6 +885,7 @@ def gateway(
|
||||
await asyncio.gather(
|
||||
agent.run(),
|
||||
channels.start_all(),
|
||||
_health_server(config.gateway.host, port),
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
console.print("\nShutting down...")
|
||||
@ -912,6 +961,9 @@ def agent(
|
||||
mcp_servers=config.tools.mcp_servers,
|
||||
channels_config=config.channels,
|
||||
timezone=config.agents.defaults.timezone,
|
||||
unified_session=config.agents.defaults.unified_session,
|
||||
disabled_skills=config.agents.defaults.disabled_skills,
|
||||
session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
|
||||
)
|
||||
restart_notice = consume_restart_notice_from_env()
|
||||
if restart_notice and should_show_cli_restart_notice(restart_notice, session_id):
|
||||
@ -1116,7 +1168,7 @@ def channels_status(
|
||||
|
||||
table = Table(title="Channel Status")
|
||||
table.add_column("Channel", style="cyan")
|
||||
table.add_column("Enabled", style="green")
|
||||
table.add_column("Enabled")
|
||||
|
||||
for name, cls in sorted(discover_all().items()):
|
||||
section = getattr(config.channels, name, None)
|
||||
@ -1251,7 +1303,7 @@ def plugins_list():
|
||||
table = Table(title="Channel Plugins")
|
||||
table.add_column("Name", style="cyan")
|
||||
table.add_column("Source", style="magenta")
|
||||
table.add_column("Enabled", style="green")
|
||||
table.add_column("Enabled")
|
||||
|
||||
for name in sorted(all_channels):
|
||||
cls = all_channels[name]
|
||||
|
||||
@ -76,6 +76,14 @@ class AgentDefaults(Base):
|
||||
provider_retry_mode: Literal["standard", "persistent"] = "standard"
|
||||
reasoning_effort: str | None = None # low / medium / high / adaptive - enables LLM thinking mode
|
||||
timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"
|
||||
unified_session: bool = False # Share one session across all channels (single-user multi-device)
|
||||
disabled_skills: list[str] = Field(default_factory=list) # Skill names to exclude from loading (e.g. ["summarize", "skill-creator"])
|
||||
session_ttl_minutes: int = Field(
|
||||
default=0,
|
||||
ge=0,
|
||||
validation_alias=AliasChoices("idleCompactAfterMinutes", "sessionTtlMinutes"),
|
||||
serialization_alias="idleCompactAfterMinutes",
|
||||
) # Auto-compact idle threshold in minutes (0 = disabled)
|
||||
dream: DreamConfig = Field(default_factory=DreamConfig)
|
||||
|
||||
|
||||
@ -144,7 +152,7 @@ class ApiConfig(Base):
|
||||
class GatewayConfig(Base):
|
||||
"""Gateway/server configuration."""
|
||||
|
||||
host: str = "0.0.0.0"
|
||||
host: str = "127.0.0.1" # Safer default: local-only bind.
|
||||
port: int = 18790
|
||||
heartbeat: HeartbeatConfig = Field(default_factory=HeartbeatConfig)
|
||||
|
||||
@ -152,7 +160,7 @@ class GatewayConfig(Base):
|
||||
class WebSearchConfig(Base):
|
||||
"""Web search tool configuration."""
|
||||
|
||||
provider: str = "duckduckgo" # brave, tavily, duckduckgo, searxng, jina
|
||||
provider: str = "duckduckgo" # brave, tavily, duckduckgo, searxng, jina, kagi
|
||||
api_key: str = ""
|
||||
base_url: str = "" # SearXNG base URL
|
||||
max_results: int = 5
|
||||
@ -176,6 +184,7 @@ class ExecToolConfig(Base):
|
||||
timeout: int = 60
|
||||
path_append: str = ""
|
||||
sandbox: str = "" # sandbox backend: "" (none) or "bwrap"
|
||||
allowed_env_keys: list[str] = Field(default_factory=list) # Env var names to pass through to subprocess (e.g. ["GOPATH", "JAVA_HOME"])
|
||||
|
||||
class MCPServerConfig(Base):
|
||||
"""MCP server connection configuration (stdio or HTTP)."""
|
||||
|
||||
@ -4,10 +4,12 @@ import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import asdict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Coroutine, Literal
|
||||
|
||||
from filelock import FileLock
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.cron.types import CronJob, CronJobState, CronPayload, CronRunRecord, CronSchedule, CronStore
|
||||
@ -69,28 +71,26 @@ class CronService:
|
||||
self,
|
||||
store_path: Path,
|
||||
on_job: Callable[[CronJob], Coroutine[Any, Any, str | None]] | None = None,
|
||||
max_sleep_ms: int = 300_000, # 5 minutes
|
||||
):
|
||||
self.store_path = store_path
|
||||
self._action_path = store_path.parent / "action.jsonl"
|
||||
self._lock = FileLock(str(self._action_path.parent) + ".lock")
|
||||
self.on_job = on_job
|
||||
self._store: CronStore | None = None
|
||||
self._last_mtime: float = 0.0
|
||||
self._timer_task: asyncio.Task | None = None
|
||||
self._running = False
|
||||
self._timer_active = False
|
||||
self.max_sleep_ms = max_sleep_ms
|
||||
|
||||
def _load_store(self) -> CronStore:
|
||||
"""Load jobs from disk. Reloads automatically if file was modified externally."""
|
||||
if self._store and self.store_path.exists():
|
||||
mtime = self.store_path.stat().st_mtime
|
||||
if mtime != self._last_mtime:
|
||||
logger.info("Cron: jobs.json modified externally, reloading")
|
||||
self._store = None
|
||||
if self._store:
|
||||
return self._store
|
||||
|
||||
def _load_jobs(self) -> tuple[list[CronJob], int]:
|
||||
jobs = []
|
||||
version = 1
|
||||
if self.store_path.exists():
|
||||
try:
|
||||
data = json.loads(self.store_path.read_text(encoding="utf-8"))
|
||||
jobs = []
|
||||
version = data.get("version", 1)
|
||||
for j in data.get("jobs", []):
|
||||
jobs.append(CronJob(
|
||||
id=j["id"],
|
||||
@ -129,13 +129,57 @@ class CronService:
|
||||
updated_at_ms=j.get("updatedAtMs", 0),
|
||||
delete_after_run=j.get("deleteAfterRun", False),
|
||||
))
|
||||
self._store = CronStore(jobs=jobs)
|
||||
self._last_mtime = self.store_path.stat().st_mtime
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load cron store: {}", e)
|
||||
self._store = CronStore()
|
||||
else:
|
||||
self._store = CronStore()
|
||||
return jobs, version
|
||||
|
||||
def _merge_action(self):
|
||||
if not self._action_path.exists():
|
||||
return
|
||||
|
||||
jobs_map = {j.id: j for j in self._store.jobs}
|
||||
def _update(params: dict):
|
||||
j = CronJob.from_dict(params)
|
||||
jobs_map[j.id] = j
|
||||
|
||||
def _del(params: dict):
|
||||
if job_id := params.get("job_id"):
|
||||
jobs_map.pop(job_id)
|
||||
|
||||
with self._lock:
|
||||
with open(self._action_path, "r", encoding="utf-8") as f:
|
||||
changed = False
|
||||
for line in f:
|
||||
try:
|
||||
line = line.strip()
|
||||
action = json.loads(line)
|
||||
if "action" not in action:
|
||||
continue
|
||||
if action["action"] == "del":
|
||||
_del(action.get("params", {}))
|
||||
else:
|
||||
_update(action.get("params", {}))
|
||||
changed = True
|
||||
except Exception as exp:
|
||||
logger.debug(f"load action line error: {exp}")
|
||||
continue
|
||||
self._store.jobs = list(jobs_map.values())
|
||||
if self._running and changed:
|
||||
self._action_path.write_text("", encoding="utf-8")
|
||||
self._save_store()
|
||||
return
|
||||
|
||||
def _load_store(self) -> CronStore:
|
||||
"""Load jobs from disk. Reloads automatically if file was modified externally.
|
||||
- Reload every time because it needs to merge operations on the jobs object from other instances.
|
||||
- During _on_timer execution, return the existing store to prevent concurrent
|
||||
_load_store calls (e.g. from list_jobs polling) from replacing it mid-execution.
|
||||
"""
|
||||
if self._timer_active and self._store:
|
||||
return self._store
|
||||
jobs, version = self._load_jobs()
|
||||
self._store = CronStore(version=version, jobs=jobs)
|
||||
self._merge_action()
|
||||
|
||||
return self._store
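Review note: `_merge_action` replays an append-only `action.jsonl` file on top of the jobs loaded from disk. The sketch below shows the same replay idea on plain dicts. `replay_actions` is a hypothetical name, jobs are simplified to dicts keyed by `id`, and unlike the hunk it deletes with a default so a missing id is silently ignored rather than raising.

```python
import json


def replay_actions(jobs: dict, lines: list) -> dict:
    """Apply add/update/del actions, in order, to a copy of the job map."""
    result = dict(jobs)
    for line in lines:
        try:
            action = json.loads(line.strip())
        except json.JSONDecodeError:
            continue  # skip malformed lines instead of aborting the merge
        if not isinstance(action, dict) or "action" not in action:
            continue
        params = action.get("params", {})
        if action["action"] == "del":
            result.pop(params.get("job_id"), None)
        elif "id" in params:
            result[params["id"]] = params  # add and update both overwrite by id
    return result
```

Because later lines win, an `add` followed by an `update` for the same id leaves only the updated job, and a `del` removes it regardless of how it got there.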
|
||||
|
||||
@ -230,11 +274,14 @@ class CronService:
|
||||
if self._timer_task:
|
||||
self._timer_task.cancel()
|
||||
|
||||
next_wake = self._get_next_wake_ms()
|
||||
if not next_wake or not self._running:
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
delay_ms = max(0, next_wake - _now_ms())
|
||||
next_wake = self._get_next_wake_ms()
|
||||
if next_wake is None:
|
||||
delay_ms = self.max_sleep_ms
|
||||
else:
|
||||
delay_ms = min(self.max_sleep_ms, max(0, next_wake - _now_ms()))
|
||||
delay_s = delay_ms / 1000
|
||||
|
||||
async def tick():
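Review note: the new `_arm_timer` logic clamps the sleep to `max_sleep_ms`, so the service wakes periodically even when no job is scheduled and can pick up actions written by other instances. A minimal sketch of that arithmetic, with `next_delay_ms` as a hypothetical helper name:

```python
from typing import Optional


def next_delay_ms(now_ms: int, next_wake_ms: Optional[int], max_sleep_ms: int = 300_000) -> int:
    """Return how long to sleep before the next timer tick, in milliseconds."""
    if next_wake_ms is None:
        return max_sleep_ms  # nothing scheduled: still wake up to re-check the store
    # Never negative (overdue jobs fire immediately), never longer than the cap.
    return min(max_sleep_ms, max(0, next_wake_ms - now_ms))
```

With the default cap of five minutes, a job due in 500 ms yields 500, an overdue job yields 0, and a job due tomorrow yields 300000.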
|
||||
@ -248,18 +295,23 @@ class CronService:
|
||||
"""Handle timer tick - run due jobs."""
|
||||
self._load_store()
|
||||
if not self._store:
|
||||
self._arm_timer()
|
||||
return
|
||||
|
||||
now = _now_ms()
|
||||
due_jobs = [
|
||||
j for j in self._store.jobs
|
||||
if j.enabled and j.state.next_run_at_ms and now >= j.state.next_run_at_ms
|
||||
]
|
||||
self._timer_active = True
|
||||
try:
|
||||
now = _now_ms()
|
||||
due_jobs = [
|
||||
j for j in self._store.jobs
|
||||
if j.enabled and j.state.next_run_at_ms and now >= j.state.next_run_at_ms
|
||||
]
|
||||
|
||||
for job in due_jobs:
|
||||
await self._execute_job(job)
|
||||
for job in due_jobs:
|
||||
await self._execute_job(job)
|
||||
|
||||
self._save_store()
|
||||
self._save_store()
|
||||
finally:
|
||||
self._timer_active = False
|
||||
self._arm_timer()
|
||||
|
||||
async def _execute_job(self, job: CronJob) -> None:
|
||||
@ -303,6 +355,13 @@ class CronService:
|
||||
# Compute next run
|
||||
job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
|
||||
|
||||
def _append_action(self, action: Literal["add", "del", "update"], params: dict):
|
||||
self.store_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with self._lock:
|
||||
with open(self._action_path, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps({"action": action, "params": params}, ensure_ascii=False) + "\n")
|
||||
|
||||
|
||||
# ========== Public API ==========
|
||||
|
||||
def list_jobs(self, include_disabled: bool = False) -> list[CronJob]:
|
||||
@ -322,7 +381,6 @@ class CronService:
|
||||
delete_after_run: bool = False,
|
||||
) -> CronJob:
|
||||
"""Add a new job."""
|
||||
store = self._load_store()
|
||||
_validate_schedule_for_add(schedule)
|
||||
now = _now_ms()
|
||||
|
||||
@ -343,10 +401,13 @@ class CronService:
|
||||
updated_at_ms=now,
|
||||
delete_after_run=delete_after_run,
|
||||
)
|
||||
|
||||
store.jobs.append(job)
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
if self._running:
|
||||
store = self._load_store()
|
||||
store.jobs.append(job)
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
else:
|
||||
self._append_action("add", asdict(job))
|
||||
|
||||
logger.info("Cron: added job '{}' ({})", name, job.id)
|
||||
return job
|
||||
@ -380,8 +441,11 @@ class CronService:
|
||||
removed = len(store.jobs) < before
|
||||
|
||||
if removed:
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
if self._running:
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
else:
|
||||
self._append_action("del", {"job_id": job_id})
|
||||
logger.info("Cron: removed job {}", job_id)
|
||||
return "removed"
|
||||
|
||||
@ -398,23 +462,85 @@ class CronService:
|
||||
job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
|
||||
else:
|
||||
job.state.next_run_at_ms = None
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
if self._running:
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
else:
|
||||
self._append_action("update", asdict(job))
|
||||
return job
|
||||
return None
|
||||
|
||||
async def run_job(self, job_id: str, force: bool = False) -> bool:
|
||||
"""Manually run a job."""
|
||||
def update_job(
|
||||
self,
|
||||
job_id: str,
|
||||
*,
|
||||
name: str | None = None,
|
||||
schedule: CronSchedule | None = None,
|
||||
message: str | None = None,
|
||||
deliver: bool | None = None,
|
||||
channel: str | None = ...,
|
||||
to: str | None = ...,
|
||||
delete_after_run: bool | None = None,
|
||||
) -> CronJob | Literal["not_found", "protected"]:
|
||||
"""Update mutable fields of an existing job. System jobs cannot be updated.
|
||||
|
||||
For ``channel`` and ``to``, pass an explicit value (including ``None``)
|
||||
to update; omit (sentinel ``...``) to leave unchanged.
|
||||
"""
|
||||
store = self._load_store()
|
||||
for job in store.jobs:
|
||||
if job.id == job_id:
|
||||
if not force and not job.enabled:
|
||||
return False
|
||||
await self._execute_job(job)
|
||||
self._save_store()
|
||||
job = next((j for j in store.jobs if j.id == job_id), None)
|
||||
if job is None:
|
||||
return "not_found"
|
||||
if job.payload.kind == "system_event":
|
||||
return "protected"
|
||||
|
||||
if schedule is not None:
|
||||
_validate_schedule_for_add(schedule)
|
||||
job.schedule = schedule
|
||||
if name is not None:
|
||||
job.name = name
|
||||
if message is not None:
|
||||
job.payload.message = message
|
||||
if deliver is not None:
|
||||
job.payload.deliver = deliver
|
||||
if channel is not ...:
|
||||
job.payload.channel = channel
|
||||
if to is not ...:
|
||||
job.payload.to = to
|
||||
if delete_after_run is not None:
|
||||
job.delete_after_run = delete_after_run
|
||||
|
||||
job.updated_at_ms = _now_ms()
|
||||
if job.enabled:
|
||||
job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
|
||||
|
||||
if self._running:
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
else:
|
||||
self._append_action("update", asdict(job))
|
||||
|
||||
logger.info("Cron: updated job '{}' ({})", job.name, job.id)
|
||||
return job
|
||||
|
||||
async def run_job(self, job_id: str, force: bool = False) -> bool:
|
||||
"""Manually run a job without disturbing the service's running state."""
|
||||
was_running = self._running
|
||||
self._running = True
|
||||
try:
|
||||
store = self._load_store()
|
||||
for job in store.jobs:
|
||||
if job.id == job_id:
|
||||
if not force and not job.enabled:
|
||||
return False
|
||||
await self._execute_job(job)
|
||||
self._save_store()
|
||||
return True
|
||||
return False
|
||||
finally:
|
||||
self._running = was_running
|
||||
if was_running:
|
||||
self._arm_timer()
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_job(self, job_id: str) -> CronJob | None:
|
||||
"""Get a job by ID."""
|
||||
|
||||
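Review note: `update_job` uses `...` (Ellipsis) as the default for `channel` and `to` so that an explicit `None` can mean "clear this field" while omission means "leave it alone". The sketch below isolates that pattern. `Payload` and `update_payload` are hypothetical names, not the nanobot types.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Payload:
    message: str = ""
    channel: Optional[str] = None


def update_payload(payload: Payload, *, message=None, channel=...) -> Payload:
    """Update only the fields the caller actually passed."""
    if message is not None:
        payload.message = message  # None means "unchanged" for non-nullable fields
    if channel is not ...:
        payload.channel = channel  # an explicit None clears the channel
    return payload
```

The Ellipsis default is needed only for fields where `None` is itself a valid stored value; for the others, `None` can double as "not provided".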
@ -61,6 +61,18 @@ class CronJob:
    updated_at_ms: int = 0
    delete_after_run: bool = False

    @classmethod
    def from_dict(cls, kwargs: dict):
        state_kwargs = dict(kwargs.get("state", {}))
        state_kwargs["run_history"] = [
            record if isinstance(record, CronRunRecord) else CronRunRecord(**record)
            for record in state_kwargs.get("run_history", [])
        ]
        kwargs["schedule"] = CronSchedule(**kwargs.get("schedule", {"kind": "every"}))
        kwargs["payload"] = CronPayload(**kwargs.get("payload", {}))
        kwargs["state"] = CronJobState(**state_kwargs)
        return cls(**kwargs)


@dataclass
class CronStore:
|
||||
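Review note: `from_dict` exists because `dataclasses.asdict` converts nested dataclasses into plain dicts, so the nested objects must be rebuilt when reading an action back. A reduced round-trip sketch, using hypothetical `Schedule` and `Job` classes rather than the real cron types:

```python
from dataclasses import asdict, dataclass, field


@dataclass
class Schedule:
    kind: str = "every"
    every_ms: int = 0


@dataclass
class Job:
    id: str
    schedule: Schedule = field(default_factory=Schedule)

    @classmethod
    def from_dict(cls, data: dict) -> "Job":
        data = dict(data)  # shallow copy so the caller's dict is not mutated
        data["schedule"] = Schedule(**data.get("schedule", {}))
        return cls(**data)
```

The copy on the first line of `from_dict` is a choice made in this sketch; the method in the hunk writes the rebuilt objects back into the dict it was given.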
@ -81,6 +81,9 @@ class Nanobot:
|
||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||
mcp_servers=config.tools.mcp_servers,
|
||||
timezone=defaults.timezone,
|
||||
unified_session=defaults.unified_session,
|
||||
disabled_skills=defaults.disabled_skills,
|
||||
session_ttl_minutes=defaults.session_ttl_minutes,
|
||||
)
|
||||
return cls(loop)
|
||||
|
||||
|
||||
@ -353,6 +353,64 @@ class LLMProvider(ABC):
        # Unknown 429 defaults to WAIT+retry.
        return True

    @staticmethod
    def _enforce_role_alternation(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Merge consecutive same-role messages and drop trailing assistant messages.

        Some providers (OpenAI-compat, Azure, vLLM, Ollama, etc.) reject requests
        where the last message is 'assistant' (prefill not supported) or two
        consecutive non-system messages share the same role.
        """
        if not messages:
            return messages

        merged: list[dict[str, Any]] = []
        for msg in messages:
            role = msg.get("role")
            if (
                merged
                and role != "system"
                and role not in ("tool",)
                and merged[-1].get("role") == role
                and role in ("user", "assistant")
            ):
                prev = merged[-1]
                if role == "assistant":
                    prev_has_tools = bool(prev.get("tool_calls"))
                    curr_has_tools = bool(msg.get("tool_calls"))
                    if curr_has_tools:
                        merged[-1] = dict(msg)
                        continue
                    if prev_has_tools:
                        continue
                prev_content = prev.get("content") or ""
                curr_content = msg.get("content") or ""
                if isinstance(prev_content, str) and isinstance(curr_content, str):
                    prev["content"] = (prev_content + "\n\n" + curr_content).strip()
                else:
                    merged[-1] = dict(msg)
            else:
                merged.append(dict(msg))

        last_popped = None
        while merged and merged[-1].get("role") == "assistant":
            last_popped = merged.pop()

        # If removing trailing assistant messages left only system messages,
        # the request would be invalid for most providers (e.g. Zhipu/GLM
        # error 1214). Recover by converting the last popped assistant
        # message to a user message so the LLM can still see the content.
        if (
            merged
            and last_popped is not None
            and not any(m.get("role") in ("user", "tool") for m in merged)
        ):
            recovered = dict(last_popped)
            recovered["role"] = "user"
            merged.append(recovered)

        return merged

    @staticmethod
    def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
        """Replace image_url blocks with text placeholder. Returns None if no images found."""
|
||||
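Review note: the core of `_enforce_role_alternation` can be seen on plain text messages. The sketch below is a simplified version under stated assumptions: it handles only string content, ignores `tool_calls`, and omits the "recover the last assistant message as a user message" step. `merge_same_role` is a hypothetical name.

```python
def merge_same_role(messages: list) -> list:
    """Merge consecutive user/assistant messages and drop trailing assistant turns."""
    merged = []
    for msg in messages:
        role = msg.get("role")
        if merged and role in ("user", "assistant") and merged[-1].get("role") == role:
            prev = merged[-1]  # a copy made below, so the input is not mutated
            joined = (prev.get("content") or "") + "\n\n" + (msg.get("content") or "")
            prev["content"] = joined.strip()
        else:
            merged.append(dict(msg))
    while merged and merged[-1].get("role") == "assistant":
        merged.pop()  # providers without prefill reject a trailing assistant turn
    return merged
```

Given a system prompt, two user messages in a row, and a trailing assistant message, this returns the system prompt plus one combined user message.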
@ -375,6 +433,26 @@ class LLMProvider(ABC):
            result.append(msg)
        return result if found else None

    @staticmethod
    def _strip_image_content_inplace(messages: list[dict[str, Any]]) -> bool:
        """Replace image_url blocks with text placeholder *in-place*.

        Mutates the content lists of the original message dicts so that
        callers holding references to those dicts also see the stripped
        version.
        """
        found = False
        for msg in messages:
            content = msg.get("content")
            if isinstance(content, list):
                for i, b in enumerate(content):
                    if isinstance(b, dict) and b.get("type") == "image_url":
                        path = (b.get("_meta") or {}).get("path", "")
                        placeholder = image_placeholder_text(path, empty="[image omitted]")
                        content[i] = {"type": "text", "text": placeholder}
                        found = True
        return found

    async def _safe_chat(self, **kwargs: Any) -> LLMResponse:
        """Call chat() and convert unexpected exceptions to error responses."""
        try:
|
||||
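Review note: the point of the in-place variant is aliasing: any caller that still holds the original content list sees the stripped blocks. A small sketch of that effect, assuming a fixed placeholder string instead of the `image_placeholder_text` helper; `strip_images_inplace` is a hypothetical name.

```python
def strip_images_inplace(messages: list) -> bool:
    """Replace image_url blocks with a text placeholder, mutating the lists."""
    found = False
    for msg in messages:
        content = msg.get("content")
        if isinstance(content, list):
            for i, block in enumerate(content):
                if isinstance(block, dict) and block.get("type") == "image_url":
                    content[i] = {"type": "text", "text": "[image omitted]"}
                    found = True
    return found
```

A second call on the same history returns `False`, which is why stripping once after a successful retry avoids repeating the error-retry cycle on later iterations.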
@ -626,7 +704,12 @@ class LLMProvider(ABC):
|
||||
)
|
||||
retry_kw = dict(kw)
|
||||
retry_kw["messages"] = stripped
|
||||
return await call(**retry_kw)
|
||||
result = await call(**retry_kw)
|
||||
# Permanently strip images from the original messages so
|
||||
# subsequent iterations do not repeat the error-retry cycle.
|
||||
if result.finish_reason != "error":
|
||||
self._strip_image_content_inplace(original_messages)
|
||||
return result
|
||||
return response
|
||||
|
||||
if persistent and identical_error_count >= self._PERSISTENT_IDENTICAL_ERROR_LIMIT:
|
||||
|
||||
@ -26,6 +26,12 @@ else:
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.openai_responses import (
|
||||
consume_sdk_stream,
|
||||
convert_messages,
|
||||
convert_tools,
|
||||
parse_response_output,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.providers.registry import ProviderSpec
|
||||
@ -113,6 +119,14 @@ def _uses_openrouter_attribution(spec: "ProviderSpec | None", api_base: str | No
|
||||
return bool(api_base and "openrouter" in api_base.lower())
|
||||
|
||||
|
||||
def _is_direct_openai_base(api_base: str | None) -> bool:
|
||||
"""Return True for direct OpenAI endpoints, not generic OpenAI-compatible gateways."""
|
||||
if not api_base:
|
||||
return True
|
||||
normalized = api_base.strip().lower().rstrip("/")
|
||||
return "api.openai.com" in normalized and "openrouter" not in normalized
|
||||
|
||||
|
||||
class OpenAICompatProvider(LLMProvider):
|
||||
"""Unified provider for all OpenAI-compatible APIs.
|
||||
|
||||
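Review note: `_is_direct_openai_base` uses a substring check, which also accepts any URL that merely contains `api.openai.com`. The sketch below places the same check next to a stricter hostname comparison. `is_direct_openai_base` mirrors the logic in the hunk, and `is_direct_openai_host` is a hypothetical alternative that is not part of this change.

```python
from typing import Optional
from urllib.parse import urlparse


def is_direct_openai_base(api_base: Optional[str]) -> bool:
    """Substring check, as in the hunk: no base URL counts as direct OpenAI."""
    if not api_base:
        return True
    normalized = api_base.strip().lower().rstrip("/")
    return "api.openai.com" in normalized and "openrouter" not in normalized


def is_direct_openai_host(api_base: Optional[str]) -> bool:
    """Stricter variant (assumption): compare the parsed hostname exactly."""
    if not api_base:
        return True
    host = urlparse(api_base.strip()).hostname or ""
    return host == "api.openai.com"
```

The two agree on ordinary inputs and differ only on look-alike hosts. The stricter variant also rejects a base URL given without a scheme, because `urlparse` finds no hostname there, so which trade-off is preferable depends on how users write `api_base` in practice.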
@ -137,6 +151,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
self._setup_env(api_key, api_base)
|
||||
|
||||
effective_base = api_base or (spec.default_api_base if spec else None) or None
|
||||
self._effective_base = effective_base
|
||||
default_headers = {"x-session-affinity": uuid.uuid4().hex}
|
||||
if _uses_openrouter_attribution(spec, effective_base):
|
||||
default_headers.update(_DEFAULT_OPENROUTER_HEADERS)
|
||||
@ -228,9 +243,13 @@ class OpenAICompatProvider(LLMProvider):
|
||||
tc_clean["id"] = map_id(tc_clean.get("id"))
|
||||
normalized.append(tc_clean)
|
||||
clean["tool_calls"] = normalized
|
||||
if clean.get("role") == "assistant":
|
||||
# Some OpenAI-compatible gateways reject assistant messages
|
||||
# that mix non-empty content with tool_calls.
|
||||
clean["content"] = None
|
||||
if "tool_call_id" in clean and clean["tool_call_id"]:
|
||||
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
||||
return sanitized
|
||||
return self._enforce_role_alternation(sanitized)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Build kwargs
|
||||
@ -321,6 +340,88 @@ class OpenAICompatProvider(LLMProvider):
|
||||
|
||||
return kwargs
|
||||
|
||||
def _should_use_responses_api(
|
||||
self,
|
||||
model: str | None,
|
||||
reasoning_effort: str | None,
|
||||
) -> bool:
|
||||
"""Use Responses API only for direct OpenAI requests that benefit from it."""
|
||||
if self._spec and self._spec.name != "openai":
|
||||
return False
|
||||
if not _is_direct_openai_base(self._effective_base):
|
||||
return False
|
||||
|
||||
model_name = (model or self.default_model).lower()
|
||||
if reasoning_effort and reasoning_effort.lower() != "none":
|
||||
return True
|
||||
return any(token in model_name for token in ("gpt-5", "o1", "o3", "o4"))
|
||||
|
||||
@staticmethod
|
||||
def _should_fallback_from_responses_error(e: Exception) -> bool:
|
||||
"""Fallback only for likely Responses API compatibility errors."""
|
||||
response = getattr(e, "response", None)
|
||||
status_code = getattr(e, "status_code", None)
|
||||
if status_code is None and response is not None:
|
||||
status_code = getattr(response, "status_code", None)
|
||||
if status_code not in {400, 404, 422}:
|
||||
return False
|
||||
|
||||
body = (
|
||||
getattr(e, "body", None)
|
||||
or getattr(e, "doc", None)
|
||||
or getattr(response, "text", None)
|
||||
)
|
||||
body_text = str(body).lower() if body is not None else ""
|
||||
compatibility_markers = (
|
||||
"responses",
|
||||
"response api",
|
||||
"max_output_tokens",
|
||||
"instructions",
|
||||
"previous_response",
|
||||
"unsupported",
|
||||
"not supported",
|
||||
"unknown parameter",
|
||||
"unrecognized request argument",
|
||||
)
|
||||
return any(marker in body_text for marker in compatibility_markers)
|
||||
|
||||
def _build_responses_body(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None,
|
||||
model: str | None,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
reasoning_effort: str | None,
|
||||
tool_choice: str | dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build a Responses API body for direct OpenAI requests."""
|
||||
model_name = model or self.default_model
|
||||
sanitized_messages = self._sanitize_messages(self._sanitize_empty_content(messages))
|
||||
instructions, input_items = convert_messages(sanitized_messages)
|
||||
|
||||
body: dict[str, Any] = {
|
||||
"model": model_name,
|
||||
"instructions": instructions or None,
|
||||
"input": input_items,
|
||||
"max_output_tokens": max(1, max_tokens),
|
||||
"store": False,
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
if self._supports_temperature(model_name, reasoning_effort):
|
||||
body["temperature"] = temperature
|
||||
|
||||
if reasoning_effort and reasoning_effort.lower() != "none":
|
||||
body["reasoning"] = {"effort": reasoning_effort}
|
||||
body["include"] = ["reasoning.encrypted_content"]
|
||||
|
||||
if tools:
|
||||
body["tools"] = convert_tools(tools)
|
||||
body["tool_choice"] = tool_choice or "auto"
|
||||
|
||||
return body
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Response parsing
|
||||
# ------------------------------------------------------------------
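Review note: `_should_fallback_from_responses_error` falls back to Chat Completions only when the status code and the error body both suggest a Responses API incompatibility. The sketch below reduces that to a status check plus a marker scan. `FakeAPIError`, `MARKERS` (a shortened marker list), and `should_fallback` are hypothetical names for illustration.

```python
class FakeAPIError(Exception):
    """Hypothetical error carrying a status code and a response body."""

    def __init__(self, status_code: int, body: str):
        super().__init__(body)
        self.status_code = status_code
        self.body = body


MARKERS = ("responses", "max_output_tokens", "unsupported", "unknown parameter")


def should_fallback(e: Exception) -> bool:
    """Fall back only for 400/404/422 errors whose body mentions a marker."""
    if getattr(e, "status_code", None) not in {400, 404, 422}:
        return False  # auth, rate-limit, and server errors are surfaced as-is
    body_text = str(getattr(e, "body", "") or "").lower()
    return any(marker in body_text for marker in MARKERS)
```

Requiring both conditions keeps genuine request errors, such as an invalid API key, from being silently retried on a different endpoint.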
|
||||
@ -698,7 +799,12 @@ class OpenAICompatProvider(LLMProvider):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _handle_error(e: Exception) -> LLMResponse:
|
||||
def _handle_error(
|
||||
e: Exception,
|
||||
*,
|
||||
spec: ProviderSpec | None = None,
|
||||
api_base: str | None = None,
|
||||
) -> LLMResponse:
|
||||
body = (
|
||||
getattr(e, "doc", None)
|
||||
or getattr(e, "body", None)
|
||||
@ -706,6 +812,15 @@ class OpenAICompatProvider(LLMProvider):
|
||||
)
|
||||
body_text = body if isinstance(body, str) else str(body) if body is not None else ""
|
||||
msg = f"Error: {body_text.strip()[:500]}" if body_text.strip() else f"Error calling LLM: {e}"
|
||||
|
||||
text = f"{body_text} {e}".lower()
|
||||
if spec and spec.is_local and ("502" in text or "connection" in text or "refused" in text):
|
||||
msg += (
|
||||
"\nHint: this is a local model endpoint. Check that the local server is reachable at "
|
||||
f"{api_base or spec.default_api_base}, and if you are using a proxy/tunnel, make sure it "
|
||||
"can reach your local Ollama/vLLM service instead of routing localhost through the remote host."
|
||||
)
|
||||
|
||||
response = getattr(e, "response", None)
|
||||
retry_after = LLMProvider._extract_retry_after_from_headers(getattr(response, "headers", None))
|
||||
if retry_after is None:
|
||||
@ -731,14 +846,25 @@ class OpenAICompatProvider(LLMProvider):
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
) -> LLMResponse:
|
||||
kwargs = self._build_kwargs(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
try:
|
||||
if self._should_use_responses_api(model, reasoning_effort):
|
||||
try:
|
||||
body = self._build_responses_body(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
return parse_response_output(await self._client.responses.create(**body))
|
||||
except Exception as responses_error:
|
||||
if not self._should_fallback_from_responses_error(responses_error):
|
||||
raise
|
||||
|
||||
kwargs = self._build_kwargs(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
return self._parse(await self._client.chat.completions.create(**kwargs))
|
||||
except Exception as e:
|
||||
return self._handle_error(e)
|
||||
return self._handle_error(e, spec=self._spec, api_base=self.api_base)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
@ -751,14 +877,49 @@ class OpenAICompatProvider(LLMProvider):
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
kwargs = self._build_kwargs(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
kwargs["stream"] = True
|
||||
kwargs["stream_options"] = {"include_usage": True}
|
||||
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
|
||||
try:
|
||||
if self._should_use_responses_api(model, reasoning_effort):
|
||||
try:
|
||||
body = self._build_responses_body(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
body["stream"] = True
|
||||
stream = await self._client.responses.create(**body)
|
||||
|
||||
async def _timed_stream():
|
||||
stream_iter = stream.__aiter__()
|
||||
while True:
|
||||
try:
|
||||
yield await asyncio.wait_for(
|
||||
stream_iter.__anext__(),
|
||||
timeout=idle_timeout_s,
|
||||
)
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
content, tool_calls, finish_reason, usage, reasoning_content = await consume_sdk_stream(
|
||||
_timed_stream(),
|
||||
on_content_delta,
|
||||
)
|
||||
return LLMResponse(
|
||||
content=content or None,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
reasoning_content=reasoning_content,
|
||||
)
|
||||
except Exception as responses_error:
|
||||
if not self._should_fallback_from_responses_error(responses_error):
|
||||
raise
|
||||
|
||||
kwargs = self._build_kwargs(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
kwargs["stream"] = True
|
||||
kwargs["stream_options"] = {"include_usage": True}
|
||||
stream = await self._client.chat.completions.create(**kwargs)
|
||||
chunks: list[Any] = []
|
||||
stream_iter = stream.__aiter__()
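Review note: `_timed_stream` wraps each `__anext__()` call in `asyncio.wait_for`, so the timeout applies to the gap between chunks rather than to the whole stream. A self-contained sketch of that wrapper follows; `with_idle_timeout`, `chunks`, and `collect` are hypothetical names, and `chunks` simulates a stream with configurable delays.

```python
import asyncio


async def with_idle_timeout(stream, idle_timeout_s: float):
    """Yield items from an async iterator, failing if one takes too long to arrive."""
    stream_iter = stream.__aiter__()
    while True:
        try:
            yield await asyncio.wait_for(stream_iter.__anext__(), timeout=idle_timeout_s)
        except StopAsyncIteration:
            break


async def chunks(delays):
    """Simulated stream: wait `delay` seconds before emitting each index."""
    for i, delay in enumerate(delays):
        await asyncio.sleep(delay)
        yield i


async def collect(delays, idle_timeout_s: float):
    """Drain the wrapped stream, recording a marker if it stalls."""
    out = []
    try:
        async for item in with_idle_timeout(chunks(delays), idle_timeout_s):
            out.append(item)
    except asyncio.TimeoutError:
        out.append("timeout")
    return out
```

A long but steadily flowing stream completes normally under this wrapper, while a stream that stalls longer than `idle_timeout_s` between two chunks raises `asyncio.TimeoutError` to the consumer.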
|
||||
@ -786,7 +947,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
error_kind="timeout",
|
||||
)
|
||||
except Exception as e:
|
||||
return self._handle_error(e)
|
||||
return self._handle_error(e, spec=self._spec, api_base=self.api_base)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return self.default_model
|
||||
|
||||
@ -155,6 +155,7 @@ class SessionManager:
|
||||
messages = []
|
||||
metadata = {}
|
||||
created_at = None
|
||||
updated_at = None
|
||||
last_consolidated = 0
|
||||
|
||||
with open(path, encoding="utf-8") as f:
|
||||
@ -168,6 +169,7 @@ class SessionManager:
|
||||
if data.get("_type") == "metadata":
|
||||
metadata = data.get("metadata", {})
|
||||
created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None
|
||||
updated_at = datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else None
|
||||
last_consolidated = data.get("last_consolidated", 0)
|
||||
else:
|
||||
messages.append(data)
|
||||
@ -176,6 +178,7 @@ class SessionManager:
|
||||
key=key,
|
||||
messages=messages,
|
||||
created_at=created_at or datetime.now(),
|
||||
updated_at=updated_at or datetime.now(),
|
||||
metadata=metadata,
|
||||
last_consolidated=last_consolidated
|
||||
)
|
||||
|
||||
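The `SessionManager` hunks above read `last_consolidated` from the metadata record of the session file and default it to `0` when it is missing. The sketch below mirrors that parsing on in-memory lines, assuming one JSON object per line; `parse_session_lines` and `SAMPLE` are hypothetical names, and the returned dict stands in for the real `Session` object.

```python
import json
from datetime import datetime


def parse_session_lines(lines: list[str]) -> dict:
    """Split JSONL session lines into a metadata record and plain messages."""
    messages: list[dict] = []
    metadata: dict = {}
    created_at = None
    updated_at = None
    last_consolidated = 0  # default when the metadata record lacks the field

    for raw in lines:
        raw = raw.strip()
        if not raw:
            continue
        data = json.loads(raw)
        if data.get("_type") == "metadata":
            metadata = data.get("metadata", {})
            if data.get("created_at"):
                created_at = datetime.fromisoformat(data["created_at"])
            if data.get("updated_at"):
                updated_at = datetime.fromisoformat(data["updated_at"])
            last_consolidated = data.get("last_consolidated", 0)
        else:
            messages.append(data)

    now = datetime.now()
    return {
        "messages": messages,
        "metadata": metadata,
        "created_at": created_at or now,  # fall back only when absent
        "updated_at": updated_at or now,
        "last_consolidated": last_consolidated,
    }


SAMPLE = [
    '{"_type": "metadata", "created_at": "2026-01-01T10:00:00", '
    '"updated_at": "2026-01-02T12:30:00", "last_consolidated": 3}',
    '{"role": "user", "content": "hi"}',
]
```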
@ -3,6 +3,7 @@ Compare conversation history against current memory files. Also scan memory file
Output one line per finding:
[FILE] atomic fact (not already in memory)
[FILE-REMOVE] reason for removal
[SKILL] kebab-case-name: one-line description of the reusable pattern

Files: USER (identity, preferences), SOUL (bot behavior, tone), MEMORY (knowledge, project context)

@ -18,6 +19,12 @@ Staleness — flag for [FILE-REMOVE]:
- Detailed incident info after 14 days — reduce to one-line summary
- Superseded: approaches replaced by newer solutions, deprecated dependencies

Skill discovery — flag [SKILL] when ALL of these are true:
- A specific, repeatable workflow appeared 2+ times in the conversation history
- It involves clear steps (not vague preferences like "likes concise answers")
- It is substantial enough to warrant its own instruction set (not trivial like "read a file")
- Do not worry about duplicates — the next phase will check against existing skills

Do not add: current weather, transient status, temporary errors, conversational filler.

[SKIP] if nothing needs updating.
@ -1,11 +1,13 @@
Update memory files based on the analysis below.
- [FILE] entries: add the described content to the appropriate file
- [FILE-REMOVE] entries: delete the corresponding content from memory files
- [SKILL] entries: create a new skill under skills/<name>/SKILL.md using write_file

## File paths (relative to workspace root)
- SOUL.md
- USER.md
- memory/MEMORY.md
- skills/<name>/SKILL.md (for [SKILL] entries only)

Do NOT guess paths.

@ -17,6 +19,17 @@ Do NOT guess paths.
- Surgical edits only — never rewrite entire files
- If nothing to update, stop without calling tools

## Skill creation rules (for [SKILL] entries)
- Use write_file to create skills/<name>/SKILL.md
- Before writing, read_file `{{ skill_creator_path }}` for format reference (frontmatter structure, naming conventions, quality standards)
- **Dedup check**: read existing skills listed below to verify the new skill is not functionally redundant. Skip creation if an existing skill already covers the same workflow.
- Include YAML frontmatter with name and description fields
- Keep SKILL.md under 2000 words — concise and actionable
- Include: when to use, steps, output format, at least one example
- Do NOT overwrite existing skills — skip if the skill directory already exists
- Reference specific tools the agent has access to (read_file, write_file, exec, web_search, etc.)
- Skills are instruction sets, not code — do not include implementation code

## Quality
- Every line must carry standalone value
- Concise bullets under clear headers
@ -15,9 +15,12 @@ from loguru import logger
|
||||
|
||||
|
||||
def strip_think(text: str) -> str:
|
||||
"""Remove <think>…</think> blocks and any unclosed trailing <think> tag."""
|
||||
"""Remove thinking blocks and any unclosed trailing tag."""
|
||||
text = re.sub(r"<think>[\s\S]*?</think>", "", text)
|
||||
text = re.sub(r"<think>[\s\S]*$", "", text)
|
||||
text = re.sub(r"^\s*<think>[\s\S]*$", "", text)
|
||||
# Gemma 4 and similar models use <thought>...</thought> blocks
|
||||
text = re.sub(r"<thought>[\s\S]*?</thought>", "", text)
|
||||
text = re.sub(r"^\s*<thought>[\s\S]*$", "", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
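The updated `strip_think` removes closed `<think>` and `<thought>` blocks anywhere in the text, but drops an unclosed tag only when it opens the text. The sketch below copies the four patterns from the hunk above into a standalone function to show the effect of the new `^\s*` anchor; `strip_think_sketch` is an illustrative name.

```python
import re


def strip_think_sketch(text: str) -> str:
    """Standalone copy of the patterns shown in the hunk above."""
    # Closed blocks are removed wherever they occur.
    text = re.sub(r"<think>[\s\S]*?</think>", "", text)
    # An unclosed tag is dropped only when it starts the text, so a
    # literal "<think>" mentioned mid-sentence survives.
    text = re.sub(r"^\s*<think>[\s\S]*$", "", text)
    text = re.sub(r"<thought>[\s\S]*?</thought>", "", text)
    text = re.sub(r"^\s*<thought>[\s\S]*$", "", text)
    return text.strip()
```

With the earlier unanchored pattern `<think>[\s\S]*$`, a reply such as `Use the <think> tag` would have been cut down to `Use the`; with the anchored pattern it is returned unchanged.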
@ -272,7 +275,7 @@ def build_assistant_message(
|
||||
thinking_blocks: list[dict] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build a provider-safe assistant message with optional reasoning fields."""
|
||||
msg: dict[str, Any] = {"role": "assistant", "content": content}
|
||||
msg: dict[str, Any] = {"role": "assistant", "content": content or ""}
|
||||
if tool_calls:
|
||||
msg["tool_calls"] = tool_calls
|
||||
if reasoning_content is not None or thinking_blocks:
|
||||
@ -417,7 +420,7 @@ def build_status_content(
|
||||
ctx_total = max(context_window_tokens, 0)
|
||||
ctx_pct = int((context_tokens_estimate / ctx_total) * 100) if ctx_total > 0 else 0
|
||||
ctx_used_str = f"{context_tokens_estimate // 1000}k" if context_tokens_estimate >= 1000 else str(context_tokens_estimate)
|
||||
ctx_total_str = f"{ctx_total // 1024}k" if ctx_total > 0 else "n/a"
|
||||
ctx_total_str = f"{ctx_total // 1000}k" if ctx_total > 0 else "n/a"
|
||||
token_line = f"\U0001f4ca Tokens: {last_in} in / {last_out} out"
|
||||
if cached and last_in:
|
||||
token_line += f" ({cached * 100 // last_in}% cached)"
|
||||
|
||||
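The change from `// 1024` to `// 1000` makes the context-window total use the same divisor as the used-token figure. A small worked sketch of that formatting follows; `format_context_usage` is a hypothetical helper and its output layout is an assumption, not the real status line.

```python
def format_context_usage(used: int, total: int) -> str:
    """Format used/total context tokens with a shared decimal 'k' unit."""
    total = max(total, 0)
    pct = int((used / total) * 100) if total > 0 else 0
    used_str = f"{used // 1000}k" if used >= 1000 else str(used)
    total_str = f"{total // 1000}k" if total > 0 else "n/a"
    return f"{used_str}/{total_str} ({pct}%)"
```

For a 128,000-token window, `128000 // 1024` gives 125 while `128000 // 1000` gives 128, so the old code displayed `125k` for a window that is normally described as 128k.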
@ -2,6 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from nanobot.utils.path import abbreviate_path
|
||||
|
||||
# Registry: tool_name -> (key_args, template, is_path, is_command)
|
||||
@ -17,27 +19,39 @@ _TOOL_FORMATS: dict[str, tuple[list[str], str, bool, bool]] = {
|
||||
"list_dir": (["path"], "ls {}", True, False),
|
||||
}
|
||||
|
||||
# Matches file paths embedded in shell commands, including quoted paths with spaces.
|
||||
_PATH_IN_CMD_RE = re.compile(
|
||||
r'"(?P<double>(?:[A-Za-z]:[/\\]|~/|/)[^"]+)"'
|
||||
r"|'(?P<single>(?:[A-Za-z]:[/\\]|~/|/)[^']+)'"
|
||||
r"|(?P<bare>(?:[A-Za-z]:[/\\]|~/|(?<=\s)/)[^\s;&|<>\"']+)"
|
||||
)
|
||||
|
||||
|
||||
def format_tool_hints(tool_calls: list) -> str:
|
||||
"""Format tool calls as concise hints with smart abbreviation."""
|
||||
if not tool_calls:
|
||||
return ""
|
||||
|
||||
hints = []
|
||||
for name, count, example_tc in _group_consecutive(tool_calls):
|
||||
fmt = _TOOL_FORMATS.get(name)
|
||||
formatted = []
|
||||
for tc in tool_calls:
|
||||
fmt = _TOOL_FORMATS.get(tc.name)
|
||||
if fmt:
|
||||
hint = _fmt_known(example_tc, fmt)
|
||||
elif name.startswith("mcp_"):
|
||||
hint = _fmt_mcp(example_tc)
|
||||
formatted.append(_fmt_known(tc, fmt))
|
||||
elif tc.name.startswith("mcp_"):
|
||||
formatted.append(_fmt_mcp(tc))
|
||||
else:
|
||||
hint = _fmt_fallback(example_tc)
|
||||
formatted.append(_fmt_fallback(tc))
|
||||
|
||||
if count > 1:
|
||||
hint = f"{hint} \u00d7 {count}"
|
||||
hints.append(hint)
|
||||
hints = []
|
||||
for hint in formatted:
|
||||
if hints and hints[-1][0] == hint:
|
||||
hints[-1] = (hint, hints[-1][1] + 1)
|
||||
else:
|
||||
hints.append((hint, 1))
|
||||
|
||||
return ", ".join(hints)
|
||||
return ", ".join(
|
||||
f"{h} \u00d7 {c}" if c > 1 else h for h, c in hints
|
||||
)
|
||||
|
||||
|
||||
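The rewritten `format_tool_hints` formats every call first and then collapses consecutive identical hints, so two calls to the same tool with different arguments are no longer merged. The collapsing step can be sketched on plain strings as follows; `collapse_consecutive` is an illustrative name and the hint strings are made up.

```python
def collapse_consecutive(hints: list[str]) -> str:
    """Run-length collapse adjacent identical hints into 'hint × N'."""
    groups: list[tuple[str, int]] = []
    for hint in hints:
        if groups and groups[-1][0] == hint:
            # Same formatted text as the previous hint: bump its count.
            groups[-1] = (hint, groups[-1][1] + 1)
        else:
            groups.append((hint, 1))
    return ", ".join(f"{h} \u00d7 {c}" if c > 1 else h for h, c in groups)
```

Only adjacent duplicates are merged; a repeated hint that appears later in the sequence starts a new group.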
def _get_args(tc) -> dict:
|
||||
@ -51,17 +65,6 @@ def _get_args(tc) -> dict:
|
||||
return {}
|
||||
|
||||
|
||||
def _group_consecutive(calls: list) -> list[tuple[str, int, object]]:
|
||||
"""Group consecutive calls to the same tool: [(name, count, first), ...]."""
|
||||
groups: list[tuple[str, int, object]] = []
|
||||
for tc in calls:
|
||||
if groups and groups[-1][0] == tc.name:
|
||||
groups[-1] = (groups[-1][0], groups[-1][1] + 1, groups[-1][2])
|
||||
else:
|
||||
groups.append((tc.name, 1, tc))
|
||||
return groups
|
||||
|
||||
|
||||
def _extract_arg(tc, key_args: list[str]) -> str | None:
|
||||
"""Extract the first available value from preferred key names."""
|
||||
args = _get_args(tc)
|
||||
@ -85,10 +88,25 @@ def _fmt_known(tc, fmt: tuple) -> str:
|
||||
if fmt[2]: # is_path
|
||||
val = abbreviate_path(val)
|
||||
elif fmt[3]: # is_command
|
||||
val = val[:40] + "\u2026" if len(val) > 40 else val
|
||||
val = _abbreviate_command(val)
|
||||
return fmt[1].format(val)
|
||||
|
||||
|
||||
def _abbreviate_command(cmd: str, max_len: int = 40) -> str:
|
||||
"""Abbreviate paths in a command string, then truncate."""
|
||||
def _replace_path(match: re.Match[str]) -> str:
|
||||
if match.group("double") is not None:
|
||||
return f'"{abbreviate_path(match.group("double"), max_len=25)}"'
|
||||
if match.group("single") is not None:
|
||||
return f"'{abbreviate_path(match.group('single'), max_len=25)}'"
|
||||
return abbreviate_path(match.group("bare"), max_len=25)
|
||||
|
||||
abbreviated = _PATH_IN_CMD_RE.sub(_replace_path, cmd)
|
||||
if len(abbreviated) <= max_len:
|
||||
return abbreviated
|
||||
return abbreviated[:max_len - 1] + "\u2026"
|
||||
|
||||
|
||||
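`_abbreviate_command` folds paths inside a command before truncating, which keeps the tail of the command visible. The sketch below reuses the regular expression from the hunk above; `shorten_path` is a hypothetical stand-in for `abbreviate_path` (whose implementation is not shown in this diff) and simply keeps the last path segment.

```python
import re

# Same pattern as _PATH_IN_CMD_RE in the hunk above.
PATH_RE = re.compile(
    r'"(?P<double>(?:[A-Za-z]:[/\\]|~/|/)[^"]+)"'
    r"|'(?P<single>(?:[A-Za-z]:[/\\]|~/|/)[^']+)'"
    r"|(?P<bare>(?:[A-Za-z]:[/\\]|~/|(?<=\s)/)[^\s;&|<>\"']+)"
)


def shorten_path(path: str) -> str:
    """Hypothetical stand-in for abbreviate_path: keep the last segment."""
    parts = [p for p in re.split(r"[/\\]", path) if p]
    return path if len(parts) <= 2 else "\u2026/" + parts[-1]


def abbreviate_command(cmd: str, max_len: int = 40) -> str:
    """Fold embedded paths, then truncate the command to max_len characters."""

    def replace(match: re.Match) -> str:
        # Re-emit the same quoting style that surrounded the original path.
        for group, quote in (("double", '"'), ("single", "'"), ("bare", "")):
            value = match.group(group)
            if value is not None:
                return f"{quote}{shorten_path(value)}{quote}"
        return match.group(0)

    abbreviated = PATH_RE.sub(replace, cmd)
    if len(abbreviated) <= max_len:
        return abbreviated
    return abbreviated[: max_len - 1] + "\u2026"
```

Folding first means a command such as `cd /home/user/projects/nanobot && make build` keeps `make build` in the hint instead of losing it to a blind 40-character cut.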
def _fmt_mcp(tc) -> str:
|
||||
"""Format MCP tool as server::tool."""
|
||||
name = tc.name
|
||||
|
||||
@ -54,6 +54,7 @@ dependencies = [
|
||||
"python-docx>=1.1.0,<2.0.0",
|
||||
"openpyxl>=3.1.0,<4.0.0",
|
||||
"python-pptx>=1.0.0,<2.0.0",
|
||||
"filelock>=3.25.2",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
@ -79,12 +80,16 @@ discord = [
|
||||
langsmith = [
|
||||
"langsmith>=0.1.0",
|
||||
]
|
||||
pdf = [
|
||||
"pymupdf>=1.25.0",
|
||||
]
|
||||
dev = [
|
||||
"pytest>=9.0.0,<10.0.0",
|
||||
"pytest-asyncio>=1.3.0,<2.0.0",
|
||||
"aiohttp>=3.9.0,<4.0.0",
|
||||
"pytest-cov>=6.0.0,<7.0.0",
|
||||
"ruff>=0.1.0",
|
||||
"pymupdf>=1.25.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
1010
tests/agent/test_auto_compact.py
Normal file
File diff suppressed because it is too large
@ -46,7 +46,7 @@ class TestConsolidatorSummarize:
|
||||
{"role": "assistant", "content": "Done, fixed the race condition."},
|
||||
]
|
||||
result = await consolidator.archive(messages)
|
||||
assert result is True
|
||||
assert result == "User fixed a bug in the auth module."
|
||||
entries = store.read_unprocessed_history(since_cursor=0)
|
||||
assert len(entries) == 1
|
||||
|
||||
@ -55,14 +55,14 @@ class TestConsolidatorSummarize:
|
||||
mock_provider.chat_with_retry.side_effect = Exception("API error")
|
||||
messages = [{"role": "user", "content": "hello"}]
|
||||
result = await consolidator.archive(messages)
|
||||
assert result is True # always succeeds
|
||||
assert result is None # no summary on raw dump fallback
|
||||
entries = store.read_unprocessed_history(since_cursor=0)
|
||||
assert len(entries) == 1
|
||||
assert "[RAW]" in entries[0]["content"]
|
||||
|
||||
async def test_summarize_skips_empty_messages(self, consolidator):
|
||||
result = await consolidator.archive([])
|
||||
assert result is False
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestConsolidatorTokenBudget:
|
||||
@ -76,3 +76,52 @@ class TestConsolidatorTokenBudget:
|
||||
consolidator.archive = AsyncMock(return_value=True)
|
||||
await consolidator.maybe_consolidate_by_tokens(session)
|
||||
consolidator.archive.assert_not_called()
|
||||
|
||||
async def test_chunk_cap_preserves_user_turn_boundary(self, consolidator):
|
||||
"""Chunk cap should rewind to the last user boundary within the cap."""
|
||||
consolidator._SAFETY_BUFFER = 0
|
||||
session = MagicMock()
|
||||
session.last_consolidated = 0
|
||||
session.key = "test:key"
|
||||
session.messages = [
|
||||
{
|
||||
"role": "user" if i in {0, 50, 61} else "assistant",
|
||||
"content": f"m{i}",
|
||||
}
|
||||
for i in range(70)
|
||||
]
|
||||
consolidator.estimate_session_prompt_tokens = MagicMock(
|
||||
side_effect=[(1200, "tiktoken"), (400, "tiktoken")]
|
||||
)
|
||||
consolidator.pick_consolidation_boundary = MagicMock(return_value=(61, 999))
|
||||
consolidator.archive = AsyncMock(return_value=True)
|
||||
|
||||
await consolidator.maybe_consolidate_by_tokens(session)
|
||||
|
||||
archived_chunk = consolidator.archive.await_args.args[0]
|
||||
assert len(archived_chunk) == 50
|
||||
assert archived_chunk[0]["content"] == "m0"
|
||||
assert archived_chunk[-1]["content"] == "m49"
|
||||
assert session.last_consolidated == 50
|
||||
|
||||
async def test_chunk_cap_skips_when_no_user_boundary_within_cap(self, consolidator):
|
||||
"""If the cap would cut mid-turn, consolidation should skip that round."""
|
||||
consolidator._SAFETY_BUFFER = 0
|
||||
session = MagicMock()
|
||||
session.last_consolidated = 0
|
||||
session.key = "test:key"
|
||||
session.messages = [
|
||||
{
|
||||
"role": "user" if i in {0, 61} else "assistant",
|
||||
"content": f"m{i}",
|
||||
}
|
||||
for i in range(70)
|
||||
]
|
||||
consolidator.estimate_session_prompt_tokens = MagicMock(return_value=(1200, "tiktoken"))
|
||||
consolidator.pick_consolidation_boundary = MagicMock(return_value=(61, 999))
|
||||
consolidator.archive = AsyncMock(return_value=True)
|
||||
|
||||
await consolidator.maybe_consolidate_by_tokens(session)
|
||||
|
||||
consolidator.archive.assert_not_awaited()
|
||||
assert session.last_consolidated == 0
|
||||
|
||||
@ -6,6 +6,7 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from nanobot.agent.memory import Dream, MemoryStore
|
||||
from nanobot.agent.runner import AgentRunResult
|
||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -95,3 +96,30 @@ class TestDreamRun:
|
||||
entries = store.read_unprocessed_history(since_cursor=0)
|
||||
assert all(e["cursor"] > 0 for e in entries)
|
||||
|
||||
async def test_skill_phase_uses_builtin_skill_creator_path(self, dream, mock_provider, mock_runner, store):
|
||||
"""Dream should point skill creation guidance at the builtin skill-creator template."""
|
||||
store.append_history("Repeated workflow one")
|
||||
store.append_history("Repeated workflow two")
|
||||
mock_provider.chat_with_retry.return_value = MagicMock(content="[SKILL] test-skill: test description")
|
||||
mock_runner.run = AsyncMock(return_value=_make_run_result())
|
||||
|
||||
await dream.run()
|
||||
|
||||
spec = mock_runner.run.call_args[0][0]
|
||||
system_prompt = spec.initial_messages[0]["content"]
|
||||
expected = str(BUILTIN_SKILLS_DIR / "skill-creator" / "SKILL.md")
|
||||
assert expected in system_prompt
|
||||
|
||||
async def test_skill_write_tool_accepts_workspace_relative_skill_path(self, dream, store):
|
||||
"""Dream skill creation should allow skills/<name>/SKILL.md relative to workspace root."""
|
||||
write_tool = dream._tools.get("write_file")
|
||||
assert write_tool is not None
|
||||
|
||||
result = await write_tool.execute(
|
||||
path="skills/test-skill/SKILL.md",
|
||||
content="---\nname: test-skill\ndescription: Test\n---\n",
|
||||
)
|
||||
|
||||
assert "Successfully wrote" in result
|
||||
assert (store.workspace / "skills" / "test-skill" / "SKILL.md").exists()
|
||||
|
||||
|
||||
@ -184,17 +184,22 @@ def test_stale_extra_content_in_tool_calls_survives_sanitize() -> None:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
messages = [{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "fn", "arguments": "{}"},
|
||||
"extra_content": GEMINI_EXTRA,
|
||||
}],
|
||||
}]
|
||||
messages = [
|
||||
{"role": "user", "content": "hi"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "fn", "arguments": "{}"},
|
||||
"extra_content": GEMINI_EXTRA,
|
||||
}],
|
||||
},
|
||||
{"role": "tool", "content": "ok", "tool_call_id": "call_1"},
|
||||
{"role": "user", "content": "thanks"},
|
||||
]
|
||||
|
||||
sanitized = provider._sanitize_messages(messages)
|
||||
|
||||
assert sanitized[0]["tool_calls"][0]["extra_content"] == GEMINI_EXTRA
|
||||
assert sanitized[1]["tool_calls"][0]["extra_content"] == GEMINI_EXTRA
|
||||
|
||||
@ -232,6 +232,35 @@ async def test_composite_empty_hooks_no_ops():
|
||||
assert hook.finalize_content(ctx, "test") == "test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_supports_legacy_hook_init_without_super():
|
||||
calls: list[str] = []
|
||||
|
||||
class LegacyHook(AgentHook):
|
||||
def __init__(self, label: str) -> None:
|
||||
self.label = label
|
||||
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
calls.append(self.label)
|
||||
|
||||
hook = CompositeHook([LegacyHook("legacy")])
|
||||
await hook.before_iteration(_ctx())
|
||||
assert calls == ["legacy"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_can_wrap_another_composite():
|
||||
calls: list[str] = []
|
||||
|
||||
class Inner(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
calls.append("inner")
|
||||
|
||||
hook = CompositeHook([CompositeHook([Inner()])])
|
||||
await hook.before_iteration(_ctx())
|
||||
assert calls == ["inner"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: AgentLoop with extra hooks
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -278,7 +307,7 @@ async def test_agent_loop_extra_hook_receives_calls(tmp_path):
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
content, tools_used, messages = await loop._run_agent_loop(
|
||||
content, tools_used, messages, _, _ = await loop._run_agent_loop(
|
||||
[{"role": "user", "content": "hi"}]
|
||||
)
|
||||
|
||||
@ -302,7 +331,7 @@ async def test_agent_loop_extra_hook_error_isolation(tmp_path):
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
content, _, _ = await loop._run_agent_loop(
|
||||
content, _, _, _, _ = await loop._run_agent_loop(
|
||||
[{"role": "user", "content": "hi"}]
|
||||
)
|
||||
|
||||
@ -344,7 +373,7 @@ async def test_agent_loop_no_hooks_backward_compat(tmp_path):
|
||||
loop.tools.execute = AsyncMock(return_value="ok")
|
||||
loop.max_iterations = 2
|
||||
|
||||
content, tools_used, _ = await loop._run_agent_loop([])
|
||||
content, tools_used, _, _, _ = await loop._run_agent_loop([])
|
||||
assert content == (
|
||||
"I reached the maximum number of tool call iterations (2) "
|
||||
"without completing the task. You can try breaking the task into smaller steps."
|
||||
|
||||
@ -1,5 +1,13 @@
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.session.manager import Session
|
||||
|
||||
|
||||
@ -11,6 +19,12 @@ def _mk_loop() -> AgentLoop:
|
||||
return loop
|
||||
|
||||
|
||||
def _make_full_loop(tmp_path: Path) -> AgentLoop:
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
return AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model")
|
||||
|
||||
|
||||
def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None:
|
||||
loop = _mk_loop()
|
||||
session = Session(key="test:runtime-only")
|
||||
@ -200,3 +214,206 @@ def test_restore_runtime_checkpoint_dedupes_overlapping_tail() -> None:
|
||||
assert session.messages[0]["role"] == "assistant"
|
||||
assert session.messages[1]["tool_call_id"] == "call_done"
|
||||
assert session.messages[2]["tool_call_id"] == "call_pending"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_persists_user_message_before_turn_completes(tmp_path: Path) -> None:
|
||||
loop = _make_full_loop(tmp_path)
|
||||
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
|
||||
loop._run_agent_loop = AsyncMock(side_effect=RuntimeError("boom")) # type: ignore[method-assign]
|
||||
|
||||
msg = InboundMessage(channel="feishu", sender_id="u1", chat_id="c1", content="persist me")
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
await loop._process_message(msg)
|
||||
|
||||
loop.sessions.invalidate("feishu:c1")
|
||||
persisted = loop.sessions.get_or_create("feishu:c1")
|
||||
assert [m["role"] for m in persisted.messages] == ["user"]
|
||||
assert persisted.messages[0]["content"] == "persist me"
|
||||
assert persisted.metadata.get(AgentLoop._PENDING_USER_TURN_KEY) is True
|
||||
assert persisted.updated_at >= persisted.created_at
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_does_not_duplicate_early_persisted_user_message(tmp_path: Path) -> None:
|
||||
loop = _make_full_loop(tmp_path)
|
||||
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
|
||||
loop._run_agent_loop = AsyncMock(return_value=(
|
||||
"done",
|
||||
None,
|
||||
[
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "done"},
|
||||
],
|
||||
"stop",
|
||||
False,
|
||||
)) # type: ignore[method-assign]
|
||||
|
||||
result = await loop._process_message(
|
||||
InboundMessage(channel="feishu", sender_id="u1", chat_id="c2", content="hello")
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.content == "done"
|
||||
session = loop.sessions.get_or_create("feishu:c2")
|
||||
assert [
|
||||
{k: v for k, v in m.items() if k in {"role", "content"}}
|
||||
for m in session.messages
|
||||
] == [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "done"},
|
||||
]
|
||||
assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_next_turn_after_crash_closes_pending_user_turn_before_new_input(tmp_path: Path) -> None:
|
||||
loop = _make_full_loop(tmp_path)
|
||||
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
|
||||
loop.provider.chat_with_retry = AsyncMock(return_value=MagicMock()) # unused because _run_agent_loop is stubbed
|
||||
|
||||
session = loop.sessions.get_or_create("feishu:c3")
|
||||
session.add_message("user", "old question")
|
||||
session.metadata[AgentLoop._PENDING_USER_TURN_KEY] = True
|
||||
loop.sessions.save(session)
|
||||
|
||||
loop._run_agent_loop = AsyncMock(return_value=(
|
||||
"new answer",
|
||||
None,
|
||||
[
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "old question"},
|
||||
{"role": "assistant", "content": "Error: Task interrupted before a response was generated."},
|
||||
{"role": "user", "content": "new question"},
|
||||
{"role": "assistant", "content": "new answer"},
|
||||
],
|
||||
"stop",
|
||||
False,
|
||||
)) # type: ignore[method-assign]
|
||||
|
||||
result = await loop._process_message(
|
||||
InboundMessage(channel="feishu", sender_id="u1", chat_id="c3", content="new question")
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.content == "new answer"
|
||||
session = loop.sessions.get_or_create("feishu:c3")
|
||||
assert [
|
||||
{k: v for k, v in m.items() if k in {"role", "content"}}
|
||||
for m in session.messages
|
||||
] == [
|
||||
{"role": "user", "content": "old question"},
|
||||
{"role": "assistant", "content": "Error: Task interrupted before a response was generated."},
|
||||
{"role": "user", "content": "new question"},
|
||||
{"role": "assistant", "content": "new answer"},
|
||||
]
|
||||
assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_preserves_runtime_checkpoint_for_next_turn(tmp_path: Path) -> None:
|
||||
from nanobot.command.builtin import cmd_stop
|
||||
from nanobot.command.router import CommandContext
|
||||
|
||||
loop = _make_full_loop(tmp_path)
|
||||
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
|
||||
|
||||
checkpoint_saved = asyncio.Event()
|
||||
|
||||
async def interrupted_run_agent_loop(_initial_messages, *, session=None, **_kwargs):
|
||||
assert session is not None
|
||||
loop._set_runtime_checkpoint(
|
||||
session,
|
||||
{
|
||||
"assistant_message": {
|
||||
"role": "assistant",
|
||||
"content": "working",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_done",
|
||||
"type": "function",
|
||||
"function": {"name": "read_file", "arguments": "{}"},
|
||||
},
|
||||
{
|
||||
"id": "call_pending",
|
||||
"type": "function",
|
||||
"function": {"name": "exec", "arguments": "{}"},
|
||||
},
|
||||
],
|
||||
},
|
||||
"completed_tool_results": [
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_done",
|
||||
"name": "read_file",
|
||||
"content": "ok",
|
||||
}
|
||||
],
|
||||
"pending_tool_calls": [
|
||||
{
|
||||
"id": "call_pending",
|
||||
"type": "function",
|
||||
"function": {"name": "exec", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
checkpoint_saved.set()
|
||||
await asyncio.Event().wait()
|
||||
|
||||
loop._run_agent_loop = interrupted_run_agent_loop # type: ignore[method-assign]
|
||||
|
||||
first_msg = InboundMessage(channel="feishu", sender_id="u1", chat_id="c4", content="keep progress")
|
||||
task = asyncio.create_task(loop._process_message(first_msg))
|
||||
loop._active_tasks[first_msg.session_key] = [task]
|
||||
await asyncio.wait_for(checkpoint_saved.wait(), timeout=1.0)
|
||||
|
||||
stop_msg = InboundMessage(channel="feishu", sender_id="u1", chat_id="c4", content="/stop")
|
||||
stop_ctx = CommandContext(msg=stop_msg, session=None, key=stop_msg.session_key, raw="/stop", loop=loop)
|
||||
stop_result = await cmd_stop(stop_ctx)
|
||||
|
||||
assert "Stopped 1 task" in stop_result.content
|
||||
assert task.done()
|
||||
|
||||
loop.sessions.invalidate("feishu:c4")
|
||||
interrupted = loop.sessions.get_or_create("feishu:c4")
|
||||
assert interrupted.metadata.get(AgentLoop._PENDING_USER_TURN_KEY) is True
|
||||
assert interrupted.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is not None
|
||||
|
||||
async def resumed_run_agent_loop(initial_messages, **_kwargs):
|
||||
return (
|
||||
"next answer",
|
||||
None,
|
||||
[*initial_messages, {"role": "assistant", "content": "next answer"}],
|
||||
"stop",
|
||||
False,
|
||||
)
|
||||
|
||||
loop._run_agent_loop = resumed_run_agent_loop # type: ignore[method-assign]
|
||||
result = await loop._process_message(
|
||||
InboundMessage(channel="feishu", sender_id="u1", chat_id="c4", content="continue here")
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.content == "next answer"
|
||||
|
||||
session = loop.sessions.get_or_create("feishu:c4")
|
||||
assert [
|
||||
{k: v for k, v in m.items() if k in {"role", "content", "tool_call_id", "name"}}
|
||||
for m in session.messages
|
||||
] == [
|
||||
{"role": "user", "content": "keep progress"},
|
||||
{"role": "assistant", "content": "working"},
|
||||
{"role": "tool", "tool_call_id": "call_done", "name": "read_file", "content": "ok"},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_pending",
|
||||
"name": "exec",
|
||||
"content": "Error: Task interrupted before this tool finished.",
|
||||
},
|
||||
{"role": "user", "content": "continue here"},
|
||||
{"role": "assistant", "content": "next answer"},
|
||||
]
|
||||
assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata
|
||||
assert AgentLoop._RUNTIME_CHECKPOINT_KEY not in session.metadata
|
||||
|
||||
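The `/stop` test above expects an interrupted checkpoint to be turned back into ordinary history: the assistant message, the completed tool results, and a placeholder error result for each tool call that never finished. A minimal sketch of that conversion follows, assuming the checkpoint shape used in the test; `materialize_checkpoint` and `EXAMPLE_CHECKPOINT` are hypothetical names, and the real restore logic in `AgentLoop` may differ.

```python
INTERRUPTED_TOOL_RESULT = "Error: Task interrupted before this tool finished."


def materialize_checkpoint(checkpoint: dict) -> list[dict]:
    """Expand a runtime checkpoint into a list of session messages."""
    messages = [checkpoint["assistant_message"]]
    messages.extend(checkpoint.get("completed_tool_results", []))
    for call in checkpoint.get("pending_tool_calls", []):
        # Every tool call needs a matching tool message, so unfinished
        # calls get an explicit error result instead of being dropped.
        messages.append({
            "role": "tool",
            "tool_call_id": call["id"],
            "name": call["function"]["name"],
            "content": INTERRUPTED_TOOL_RESULT,
        })
    return messages


EXAMPLE_CHECKPOINT = {
    "assistant_message": {"role": "assistant", "content": "working"},
    "completed_tool_results": [
        {"role": "tool", "tool_call_id": "call_done", "name": "read_file", "content": "ok"},
    ],
    "pending_tool_calls": [
        {"id": "call_pending", "type": "function", "function": {"name": "exec", "arguments": "{}"}},
    ],
}
```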
44
tests/agent/test_mcp_connection.py
Normal file
@ -0,0 +1,44 @@
|
||||
"""Tests for MCP connection lifecycle in AgentLoop."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
|
||||
def _make_loop(tmp_path, *, mcp_servers: dict | None = None) -> AgentLoop:
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.generation.max_tokens = 4096
|
||||
return AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
mcp_servers=mcp_servers or {"test": object()},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_mcp_retries_when_no_servers_connect(tmp_path, monkeypatch: pytest.MonkeyPatch):
|
||||
loop = _make_loop(tmp_path)
|
||||
attempts = 0
|
||||
|
||||
async def _fake_connect(_servers, _registry):
|
||||
nonlocal attempts
|
||||
attempts += 1
|
||||
return {}
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.mcp.connect_mcp_servers", _fake_connect)
|
||||
|
||||
await loop._connect_mcp()
|
||||
await loop._connect_mcp()
|
||||
|
||||
assert attempts == 2
|
||||
assert loop._mcp_connected is False
|
||||
assert loop._mcp_stacks == {}
|
||||
File diff suppressed because it is too large
@ -250,3 +250,63 @@ def test_list_skills_openclaw_metadata_parsed_for_requirements(
|
||||
assert entries == [
|
||||
{"name": "openclaw_skill", "path": str(skill_path), "source": "workspace"},
|
||||
]
|
||||
|
||||
|
||||
def test_disabled_skills_excluded_from_list(tmp_path: Path) -> None:
|
||||
workspace = tmp_path / "ws"
|
||||
ws_skills = workspace / "skills"
|
||||
ws_skills.mkdir(parents=True)
|
||||
_write_skill(ws_skills, "alpha", body="# Alpha")
|
||||
    beta_path = _write_skill(ws_skills, "beta", body="# Beta")
    builtin = tmp_path / "builtin"
    builtin.mkdir()

    loader = SkillsLoader(workspace, builtin_skills_dir=builtin, disabled_skills={"alpha"})
    entries = loader.list_skills(filter_unavailable=False)
    assert len(entries) == 1
    assert entries[0]["name"] == "beta"
    assert entries[0]["path"] == str(beta_path)


def test_disabled_skills_empty_set_no_effect(tmp_path: Path) -> None:
    workspace = tmp_path / "ws"
    ws_skills = workspace / "skills"
    ws_skills.mkdir(parents=True)
    _write_skill(ws_skills, "alpha", body="# Alpha")
    _write_skill(ws_skills, "beta", body="# Beta")
    builtin = tmp_path / "builtin"
    builtin.mkdir()

    loader = SkillsLoader(workspace, builtin_skills_dir=builtin, disabled_skills=set())
    entries = loader.list_skills(filter_unavailable=False)
    assert len(entries) == 2


def test_disabled_skills_excluded_from_build_skills_summary(tmp_path: Path) -> None:
    workspace = tmp_path / "ws"
    ws_skills = workspace / "skills"
    ws_skills.mkdir(parents=True)
    _write_skill(ws_skills, "alpha", body="# Alpha")
    _write_skill(ws_skills, "beta", body="# Beta")
    builtin = tmp_path / "builtin"
    builtin.mkdir()

    loader = SkillsLoader(workspace, builtin_skills_dir=builtin, disabled_skills={"alpha"})
    summary = loader.build_skills_summary()
    assert "alpha" not in summary
    assert "beta" in summary


def test_disabled_skills_excluded_from_get_always_skills(tmp_path: Path) -> None:
    workspace = tmp_path / "ws"
    ws_skills = workspace / "skills"
    ws_skills.mkdir(parents=True)
    _write_skill(ws_skills, "alpha", metadata_json={"always": True}, body="# Alpha")
    _write_skill(ws_skills, "beta", metadata_json={"always": True}, body="# Beta")
    builtin = tmp_path / "builtin"
    builtin.mkdir()

    loader = SkillsLoader(workspace, builtin_skills_dir=builtin, disabled_skills={"alpha"})
    always = loader.get_always_skills()
    assert "alpha" not in always
    assert "beta" in always
@ -52,6 +52,53 @@ class TestToolHintKnownTools:
        assert result.startswith("$ ")
        assert len(result) <= 50  # reasonable limit

    def test_exec_abbreviates_paths_in_command(self):
        """Windows paths in exec commands should be folded, not blindly truncated."""
        cmd = "cd D:\\Documents\\GitHub\\nanobot\\.worktree\\tomain\\nanobot && git diff origin/main...pr-2706 --name-only 2>&1"
        result = _hint([_tc("exec", {"command": cmd})])
        assert "\u2026/" in result  # path should be folded with …/
        assert "worktree" not in result  # middle segments should be collapsed

    def test_exec_abbreviates_linux_paths(self):
        """Unix absolute paths in exec commands should be folded."""
        cmd = "cd /home/user/projects/nanobot/.worktree/tomain && make build"
        result = _hint([_tc("exec", {"command": cmd})])
        assert "\u2026/" in result
        assert "projects" not in result

    def test_exec_abbreviates_home_paths(self):
        """~/ paths in exec commands should be folded."""
        cmd = "cd ~/projects/nanobot/workspace && pytest tests/"
        result = _hint([_tc("exec", {"command": cmd})])
        assert "\u2026/" in result

    def test_exec_abbreviates_quoted_linux_paths_with_spaces(self):
        """Quoted Unix paths with spaces should still be folded."""
        cmd = 'cd "/home/user/My Documents/project" && pytest tests/'
        result = _hint([_tc("exec", {"command": cmd})])
        assert "\u2026/" in result
        assert '"/home/user/My Documents/project"' not in result
        assert '"' in result

    def test_exec_abbreviates_quoted_windows_paths_with_spaces(self):
        """Quoted Windows paths with spaces should still be folded."""
        cmd = 'cd "C:/Program Files/Git/project" && git status'
        result = _hint([_tc("exec", {"command": cmd})])
        assert "\u2026/" in result
        assert '"C:/Program Files/Git/project"' not in result
        assert '"' in result

    def test_exec_short_command_unchanged(self):
        result = _hint([_tc("exec", {"command": "npm install typescript"})])
        assert result == "$ npm install typescript"

    def test_exec_chained_commands_truncated_not_mid_path(self):
        """Long chained commands should truncate preserving abbreviated paths."""
        cmd = "cd D:\\Documents\\GitHub\\project && npm run build && npm test"
        result = _hint([_tc("exec", {"command": cmd})])
        assert "\u2026/" in result  # path folded
        assert "npm" in result  # chained command still visible

    def test_web_search(self):
        result = _hint([_tc("web_search", {"query": "Claude 4 vs GPT-4"})])
        assert result == 'search "Claude 4 vs GPT-4"'
@ -105,22 +152,30 @@ class TestToolHintFolding:
        result = _hint(calls)
        assert "\u00d7" not in result

    def test_two_consecutive_same_folded(self):
    def test_two_consecutive_different_args_not_folded(self):
        calls = [
            _tc("grep", {"pattern": "*.py"}),
            _tc("grep", {"pattern": "*.ts"}),
        ]
        result = _hint(calls)
        assert "\u00d7" not in result

    def test_two_consecutive_same_args_folded(self):
        calls = [
            _tc("grep", {"pattern": "TODO"}),
            _tc("grep", {"pattern": "TODO"}),
        ]
        result = _hint(calls)
        assert "\u00d7 2" in result

    def test_three_consecutive_same_folded(self):
    def test_three_consecutive_different_args_not_folded(self):
        calls = [
            _tc("read_file", {"path": "a.py"}),
            _tc("read_file", {"path": "b.py"}),
            _tc("read_file", {"path": "c.py"}),
        ]
        result = _hint(calls)
        assert "\u00d7 3" in result
        assert "\u00d7" not in result

    def test_different_tools_not_folded(self):
        calls = [
@ -187,7 +242,7 @@ class TestToolHintMixedFolding:
    """G4: Mixed folding groups with interleaved same-tool segments."""

    def test_read_read_grep_grep_read(self):
        """read×2, grep×2, read — should produce two separate groups."""
        """All different args — each hint listed separately."""
        calls = [
            _tc("read_file", {"path": "a.py"}),
            _tc("read_file", {"path": "b.py"}),
@ -196,7 +251,6 @@ class TestToolHintMixedFolding:
            _tc("read_file", {"path": "c.py"}),
        ]
        result = _hint(calls)
        assert "\u00d7 2" in result
        # Should have 3 groups: read×2, grep×2, read
        assert "\u00d7" not in result
        parts = result.split(", ")
        assert len(parts) == 3
        assert len(parts) == 5
502
tests/agent/test_unified_session.py
Normal file
@ -0,0 +1,502 @@
"""Tests for unified_session feature.

Covers:
- AgentLoop._dispatch() rewrites session_key to "unified:default" when enabled
- Existing session_key_override is respected (not overwritten)
- Feature is off by default (no behavior change for existing users)
- Config schema serialises unified_session as camelCase "unifiedSession"
- onboard-generated config.json contains "unifiedSession" key
- /new command correctly clears the shared session in unified mode
- /new is NOT a priority command (goes through _dispatch, key rewrite applies)
- Context window consolidation is unaffected by unified_session
"""

import asyncio
import json
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.command.builtin import cmd_new, register_builtin_commands
from nanobot.command.router import CommandContext, CommandRouter
from nanobot.config.schema import AgentDefaults, Config
from nanobot.session.manager import Session, SessionManager


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_loop(tmp_path: Path, unified_session: bool = False) -> AgentLoop:
    """Create a minimal AgentLoop for dispatch-level tests."""
    bus = MessageBus()
    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"

    with patch("nanobot.agent.loop.SessionManager"), \
         patch("nanobot.agent.loop.SubagentManager") as MockSubMgr, \
         patch("nanobot.agent.loop.Dream"):
        MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0)
        loop = AgentLoop(
            bus=bus,
            provider=provider,
            workspace=tmp_path,
            unified_session=unified_session,
        )
    return loop


def _make_msg(channel: str = "telegram", chat_id: str = "111",
              session_key_override: str | None = None) -> InboundMessage:
    return InboundMessage(
        channel=channel,
        chat_id=chat_id,
        sender_id="user1",
        content="hello",
        session_key_override=session_key_override,
    )


# ---------------------------------------------------------------------------
# TestUnifiedSessionDispatch — core behaviour
# ---------------------------------------------------------------------------

class TestUnifiedSessionDispatch:
    """AgentLoop._dispatch() session key rewriting logic."""

    @pytest.mark.asyncio
    async def test_unified_session_rewrites_key_to_unified_default(self, tmp_path: Path):
        """When unified_session=True, all messages use 'unified:default' as session key."""
        loop = _make_loop(tmp_path, unified_session=True)

        captured: list[str] = []

        async def fake_process(msg, **kwargs):
            captured.append(msg.session_key)
            return None

        loop._process_message = fake_process  # type: ignore[method-assign]

        msg = _make_msg(channel="telegram", chat_id="111")
        await loop._dispatch(msg)

        assert captured == ["unified:default"]

    @pytest.mark.asyncio
    async def test_unified_session_different_channels_share_same_key(self, tmp_path: Path):
        """Messages from different channels all resolve to the same session key."""
        loop = _make_loop(tmp_path, unified_session=True)

        captured: list[str] = []

        async def fake_process(msg, **kwargs):
            captured.append(msg.session_key)
            return None

        loop._process_message = fake_process  # type: ignore[method-assign]

        await loop._dispatch(_make_msg(channel="telegram", chat_id="111"))
        await loop._dispatch(_make_msg(channel="discord", chat_id="222"))
        await loop._dispatch(_make_msg(channel="cli", chat_id="direct"))

        assert captured == ["unified:default", "unified:default", "unified:default"]

    @pytest.mark.asyncio
    async def test_unified_session_disabled_preserves_original_key(self, tmp_path: Path):
        """When unified_session=False (default), session key is channel:chat_id as usual."""
        loop = _make_loop(tmp_path, unified_session=False)

        captured: list[str] = []

        async def fake_process(msg, **kwargs):
            captured.append(msg.session_key)
            return None

        loop._process_message = fake_process  # type: ignore[method-assign]

        msg = _make_msg(channel="telegram", chat_id="999")
        await loop._dispatch(msg)

        assert captured == ["telegram:999"]

    @pytest.mark.asyncio
    async def test_unified_session_respects_existing_override(self, tmp_path: Path):
        """If session_key_override is already set (e.g. Telegram thread), it is NOT overwritten."""
        loop = _make_loop(tmp_path, unified_session=True)

        captured: list[str] = []

        async def fake_process(msg, **kwargs):
            captured.append(msg.session_key)
            return None

        loop._process_message = fake_process  # type: ignore[method-assign]

        msg = _make_msg(channel="telegram", chat_id="111", session_key_override="telegram:thread:42")
        await loop._dispatch(msg)

        assert captured == ["telegram:thread:42"]

    def test_unified_session_default_is_false(self, tmp_path: Path):
        """unified_session defaults to False — no behavior change for existing users."""
        loop = _make_loop(tmp_path)
        assert loop._unified_session is False


# ---------------------------------------------------------------------------
# TestUnifiedSessionConfig — schema & serialisation
# ---------------------------------------------------------------------------

class TestUnifiedSessionConfig:
    """Config schema and onboard serialisation for unified_session."""

    def test_agent_defaults_unified_session_default_is_false(self):
        """AgentDefaults.unified_session defaults to False."""
        defaults = AgentDefaults()
        assert defaults.unified_session is False

    def test_agent_defaults_unified_session_can_be_enabled(self):
        """AgentDefaults.unified_session can be set to True."""
        defaults = AgentDefaults(unified_session=True)
        assert defaults.unified_session is True

    def test_config_serialises_unified_session_as_camel_case(self):
        """model_dump(by_alias=True) outputs 'unifiedSession' (camelCase) for JSON."""
        config = Config()
        data = config.model_dump(mode="json", by_alias=True)
        agents_defaults = data["agents"]["defaults"]
        assert "unifiedSession" in agents_defaults
        assert agents_defaults["unifiedSession"] is False

    def test_config_parses_unified_session_from_camel_case(self):
        """Config can be loaded from JSON with camelCase 'unifiedSession'."""
        raw = {"agents": {"defaults": {"unifiedSession": True}}}
        config = Config.model_validate(raw)
        assert config.agents.defaults.unified_session is True

    def test_config_parses_unified_session_from_snake_case(self):
        """Config also accepts snake_case 'unified_session' (populate_by_name=True)."""
        raw = {"agents": {"defaults": {"unified_session": True}}}
        config = Config.model_validate(raw)
        assert config.agents.defaults.unified_session is True

    def test_onboard_generated_config_contains_unified_session(self, tmp_path: Path):
        """save_config() writes 'unifiedSession' into config.json (simulates nanobot onboard)."""
        from nanobot.config.loader import save_config

        config = Config()
        config_path = tmp_path / "config.json"
        save_config(config, config_path)

        with open(config_path, encoding="utf-8") as f:
            data = json.load(f)

        agents_defaults = data["agents"]["defaults"]
        assert "unifiedSession" in agents_defaults, (
            "onboard-generated config.json must contain 'unifiedSession' key"
        )
        assert agents_defaults["unifiedSession"] is False


# ---------------------------------------------------------------------------
# TestCmdNewUnifiedSession — /new command behaviour in unified mode
# ---------------------------------------------------------------------------

class TestCmdNewUnifiedSession:
    """/new command routing and session-clear behaviour in unified mode."""

    def test_new_is_not_a_priority_command(self):
        """/new must NOT be in the priority table — it must go through _dispatch()
        so the unified session key rewrite applies before cmd_new runs."""
        router = CommandRouter()
        register_builtin_commands(router)
        assert router.is_priority("/new") is False

    def test_new_is_an_exact_command(self):
        """/new must be registered as an exact command."""
        router = CommandRouter()
        register_builtin_commands(router)
        assert "/new" in router._exact

    @pytest.mark.asyncio
    async def test_cmd_new_clears_unified_session(self, tmp_path: Path):
        """cmd_new called with key='unified:default' clears the shared session."""
        sessions = SessionManager(tmp_path)

        # Pre-populate the shared session with some messages
        shared = sessions.get_or_create("unified:default")
        shared.add_message("user", "hello from telegram")
        shared.add_message("assistant", "hi there")
        sessions.save(shared)
        assert len(sessions.get_or_create("unified:default").messages) == 2

        # _schedule_background is a *sync* method that schedules a coroutine via
        # asyncio.create_task(). Mirror that exactly so the coroutine is consumed
        # and no RuntimeWarning is emitted.
        loop = SimpleNamespace(
            sessions=sessions,
            consolidator=SimpleNamespace(archive=AsyncMock(return_value=True)),
        )
        loop._schedule_background = lambda coro: asyncio.ensure_future(coro)

        msg = InboundMessage(
            channel="telegram", sender_id="user1", chat_id="111", content="/new",
            session_key_override="unified:default",  # as _dispatch() would set it
        )
        ctx = CommandContext(msg=msg, session=None, key="unified:default", raw="/new", loop=loop)

        result = await cmd_new(ctx)

        assert "New session started" in result.content
        # Invalidate cache and reload from disk to confirm persistence
        sessions.invalidate("unified:default")
        reloaded = sessions.get_or_create("unified:default")
        assert reloaded.messages == []

    @pytest.mark.asyncio
    async def test_cmd_new_in_unified_mode_does_not_affect_other_sessions(self, tmp_path: Path):
        """Clearing unified:default must not touch other sessions on disk."""
        sessions = SessionManager(tmp_path)

        other = sessions.get_or_create("discord:999")
        other.add_message("user", "discord message")
        sessions.save(other)

        shared = sessions.get_or_create("unified:default")
        shared.add_message("user", "shared message")
        sessions.save(shared)

        loop = SimpleNamespace(
            sessions=sessions,
            consolidator=SimpleNamespace(archive=AsyncMock(return_value=True)),
        )
        loop._schedule_background = lambda coro: asyncio.ensure_future(coro)

        msg = InboundMessage(
            channel="telegram", sender_id="user1", chat_id="111", content="/new",
            session_key_override="unified:default",
        )
        ctx = CommandContext(msg=msg, session=None, key="unified:default", raw="/new", loop=loop)
        await cmd_new(ctx)

        sessions.invalidate("unified:default")
        sessions.invalidate("discord:999")
        assert sessions.get_or_create("unified:default").messages == []
        assert len(sessions.get_or_create("discord:999").messages) == 1


# ---------------------------------------------------------------------------
# TestConsolidationUnaffectedByUnifiedSession — consolidation is key-agnostic
# ---------------------------------------------------------------------------

class TestConsolidationUnaffectedByUnifiedSession:
    """maybe_consolidate_by_tokens() behaviour is identical regardless of session key."""

    @pytest.mark.asyncio
    async def test_consolidation_skips_empty_session_for_unified_key(self):
        """Empty unified:default session → consolidation exits immediately, archive not called."""
        from nanobot.agent.memory import Consolidator, MemoryStore

        store = MagicMock(spec=MemoryStore)
        mock_provider = MagicMock()
        mock_provider.chat_with_retry = AsyncMock(return_value=MagicMock(content="summary"))
        # Use spec= so MagicMock doesn't auto-generate AsyncMock for non-async methods,
        # which would leave unawaited coroutines and trigger RuntimeWarning.
        sessions = MagicMock(spec=SessionManager)

        consolidator = Consolidator(
            store=store,
            provider=mock_provider,
            model="test-model",
            sessions=sessions,
            context_window_tokens=1000,
            build_messages=MagicMock(return_value=[]),
            get_tool_definitions=MagicMock(return_value=[]),
            max_completion_tokens=100,
        )
        consolidator.archive = AsyncMock()

        session = Session(key="unified:default")
        session.messages = []

        await consolidator.maybe_consolidate_by_tokens(session)

        consolidator.archive.assert_not_called()

    @pytest.mark.asyncio
    async def test_consolidation_behaviour_identical_for_any_key(self):
        """archive call count is the same for 'telegram:123' and 'unified:default'
        under identical token conditions."""
        from nanobot.agent.memory import Consolidator, MemoryStore

        archive_calls: dict[str, int] = {}

        for key in ("telegram:123", "unified:default"):
            store = MagicMock(spec=MemoryStore)
            mock_provider = MagicMock()
            mock_provider.chat_with_retry = AsyncMock(return_value=MagicMock(content="summary"))
            sessions = MagicMock(spec=SessionManager)

            consolidator = Consolidator(
                store=store,
                provider=mock_provider,
                model="test-model",
                sessions=sessions,
                context_window_tokens=1000,
                build_messages=MagicMock(return_value=[]),
                get_tool_definitions=MagicMock(return_value=[]),
                max_completion_tokens=100,
            )

            session = Session(key=key)
            session.messages = []  # empty → exits immediately for both keys

            consolidator.archive = AsyncMock()
            await consolidator.maybe_consolidate_by_tokens(session)
            archive_calls[key] = consolidator.archive.call_count

        assert archive_calls["telegram:123"] == archive_calls["unified:default"] == 0

    @pytest.mark.asyncio
    async def test_consolidation_triggers_when_over_budget_unified_key(self):
        """When tokens exceed budget, consolidation attempts to find a boundary —
        behaviour is identical to any other session key."""
        from nanobot.agent.memory import Consolidator, MemoryStore

        store = MagicMock(spec=MemoryStore)
        mock_provider = MagicMock()
        sessions = MagicMock(spec=SessionManager)

        consolidator = Consolidator(
            store=store,
            provider=mock_provider,
            model="test-model",
            sessions=sessions,
            context_window_tokens=1000,
            build_messages=MagicMock(return_value=[]),
            get_tool_definitions=MagicMock(return_value=[]),
            max_completion_tokens=100,
        )

        session = Session(key="unified:default")
        session.messages = [{"role": "user", "content": "msg"}]

        # Simulate over-budget: estimated > budget
        consolidator.estimate_session_prompt_tokens = MagicMock(return_value=(950, "tiktoken"))
        # No valid boundary found → returns gracefully without archiving
        consolidator.pick_consolidation_boundary = MagicMock(return_value=None)
        consolidator.archive = AsyncMock()

        await consolidator.maybe_consolidate_by_tokens(session)

        # estimate was called (consolidation was attempted)
        consolidator.estimate_session_prompt_tokens.assert_called_once_with(session)
        # but archive was not called (no valid boundary)
        consolidator.archive.assert_not_called()


# ---------------------------------------------------------------------------
# TestStopCommandWithUnifiedSession — /stop command integration
# ---------------------------------------------------------------------------


class TestStopCommandWithUnifiedSession:
    """Verify /stop command works correctly with unified session enabled."""

    @pytest.mark.asyncio
    async def test_active_tasks_use_effective_key_in_unified_mode(self, tmp_path: Path):
        """When unified_session=True, tasks are stored under UNIFIED_SESSION_KEY."""
        from nanobot.agent.loop import UNIFIED_SESSION_KEY

        loop = _make_loop(tmp_path, unified_session=True)

        # Create a message from telegram channel
        msg = _make_msg(channel="telegram", chat_id="123456")

        # Mock _dispatch to complete immediately
        async def fake_dispatch(m):
            pass

        loop._dispatch = fake_dispatch  # type: ignore[method-assign]

        # Simulate the task creation flow (from _run loop)
        effective_key = UNIFIED_SESSION_KEY if loop._unified_session and not msg.session_key_override else msg.session_key
        task = asyncio.create_task(loop._dispatch(msg))
        loop._active_tasks.setdefault(effective_key, []).append(task)

        # Wait for task to complete
        await task

        # Verify the task is stored under UNIFIED_SESSION_KEY, not the original channel:chat_id
        assert UNIFIED_SESSION_KEY in loop._active_tasks
        assert "telegram:123456" not in loop._active_tasks

    @pytest.mark.asyncio
    async def test_stop_command_finds_task_in_unified_mode(self, tmp_path: Path):
        """cmd_stop can cancel tasks when unified_session=True."""
        from nanobot.agent.loop import UNIFIED_SESSION_KEY
        from nanobot.command.builtin import cmd_stop

        loop = _make_loop(tmp_path, unified_session=True)

        # Create a long-running task stored under UNIFIED_SESSION_KEY
        async def long_running():
            await asyncio.sleep(10)  # Will be cancelled

        task = asyncio.create_task(long_running())
        loop._active_tasks[UNIFIED_SESSION_KEY] = [task]

        # Create a message that would have session_key=UNIFIED_SESSION_KEY after dispatch
        msg = InboundMessage(
            channel="telegram",
            chat_id="123456",
            sender_id="user1",
            content="/stop",
            session_key_override=UNIFIED_SESSION_KEY,  # Simulate post-dispatch state
        )

        ctx = CommandContext(msg=msg, session=None, key=UNIFIED_SESSION_KEY, raw="/stop", loop=loop)

        # Execute /stop
        result = await cmd_stop(ctx)

        # Verify task was cancelled
        assert task.cancelled() or task.done()
        assert "Stopped 1 task" in result.content

    @pytest.mark.asyncio
    async def test_stop_command_cross_channel_in_unified_mode(self, tmp_path: Path):
        """In unified mode, /stop from one channel cancels tasks from another channel."""
        from nanobot.agent.loop import UNIFIED_SESSION_KEY
        from nanobot.command.builtin import cmd_stop

        loop = _make_loop(tmp_path, unified_session=True)

        # Create tasks from different channels, all stored under UNIFIED_SESSION_KEY
        async def long_running():
            await asyncio.sleep(10)

        task1 = asyncio.create_task(long_running())
        task2 = asyncio.create_task(long_running())
        loop._active_tasks[UNIFIED_SESSION_KEY] = [task1, task2]

        # /stop from discord should cancel tasks started from telegram
        msg = InboundMessage(
            channel="discord",
            chat_id="789012",
            sender_id="user2",
            content="/stop",
            session_key_override=UNIFIED_SESSION_KEY,
        )

        ctx = CommandContext(msg=msg, session=None, key=UNIFIED_SESSION_KEY, raw="/stop", loop=loop)

        result = await cmd_stop(ctx)

        # Both tasks should be cancelled
        assert "Stopped 2 task" in result.content
@ -1,6 +1,10 @@
import asyncio
import zipfile
from io import BytesIO
from types import SimpleNamespace
from unittest.mock import AsyncMock

import httpx
import pytest

# Check optional dingtalk dependencies before running tests
@ -50,6 +54,21 @@ class _FakeHttp:
        return self._next_response()


class _NetworkErrorHttp:
    """HTTP client stub that raises httpx.TransportError on every request."""

    def __init__(self) -> None:
        self.calls: list[dict] = []

    async def post(self, url: str, json=None, headers=None, **kwargs):
        self.calls.append({"method": "POST", "url": url, "json": json, "headers": headers})
        raise httpx.ConnectError("Connection refused")

    async def get(self, url: str, **kwargs):
        self.calls.append({"method": "GET", "url": url})
        raise httpx.ConnectError("Connection refused")


@pytest.mark.asyncio
async def test_group_message_keeps_sender_id_and_routes_chat_id() -> None:
    config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["user1"])
@ -221,3 +240,216 @@ async def test_download_dingtalk_file(tmp_path, monkeypatch) -> None:
    assert "messageFiles/download" in channel._http.calls[0]["url"]
    assert channel._http.calls[0]["json"]["downloadCode"] == "code123"
    assert channel._http.calls[1]["method"] == "GET"


def test_normalize_upload_payload_zips_html_attachment() -> None:
    channel = DingTalkChannel(
        DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]),
        MessageBus(),
    )

    data, filename, content_type = channel._normalize_upload_payload(
        "report.html",
        b"<html><body>Hello</body></html>",
        "text/html",
    )

    assert filename == "report.zip"
    assert content_type == "application/zip"

    archive = zipfile.ZipFile(BytesIO(data))
    assert archive.namelist() == ["report.html"]
    assert archive.read("report.html") == b"<html><body>Hello</body></html>"


@pytest.mark.asyncio
async def test_send_media_ref_zips_html_before_upload(tmp_path, monkeypatch) -> None:
    channel = DingTalkChannel(
        DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]),
        MessageBus(),
    )

    html_path = tmp_path / "report.html"
    html_path.write_text("<html><body>Hello</body></html>", encoding="utf-8")

    captured: dict[str, object] = {}

    async def fake_upload_media(*, token, data, media_type, filename, content_type):
        captured.update(
            {
                "token": token,
                "data": data,
                "media_type": media_type,
                "filename": filename,
                "content_type": content_type,
            }
        )
        return "media-123"

    async def fake_send_batch_message(token, chat_id, msg_key, msg_param):
        captured.update(
            {
                "sent_token": token,
                "chat_id": chat_id,
                "msg_key": msg_key,
                "msg_param": msg_param,
            }
        )
        return True

    monkeypatch.setattr(channel, "_upload_media", fake_upload_media)
    monkeypatch.setattr(channel, "_send_batch_message", fake_send_batch_message)

    ok = await channel._send_media_ref("token-123", "user-1", str(html_path))

    assert ok is True
    assert captured["media_type"] == "file"
    assert captured["filename"] == "report.zip"
    assert captured["content_type"] == "application/zip"
    assert captured["msg_key"] == "sampleFile"
    assert captured["msg_param"] == {
        "mediaId": "media-123",
        "fileName": "report.zip",
        "fileType": "zip",
    }

    archive = zipfile.ZipFile(BytesIO(captured["data"]))
    assert archive.namelist() == ["report.html"]


# ── Exception handling tests ──────────────────────────────────────────


@pytest.mark.asyncio
async def test_send_batch_message_propagates_transport_error() -> None:
    """Network/transport errors must re-raise so callers can retry."""
    config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
    channel = DingTalkChannel(config, MessageBus())
    channel._http = _NetworkErrorHttp()

    with pytest.raises(httpx.ConnectError, match="Connection refused"):
        await channel._send_batch_message(
            "token",
            "user123",
            "sampleMarkdown",
            {"text": "hello", "title": "Nanobot Reply"},
        )

    # The POST was attempted exactly once
    assert len(channel._http.calls) == 1
    assert channel._http.calls[0]["method"] == "POST"


@pytest.mark.asyncio
async def test_send_batch_message_returns_false_on_api_error() -> None:
    """DingTalk API-level errors (non-200 status, errcode != 0) should return False."""
    config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
    channel = DingTalkChannel(config, MessageBus())

    # Non-200 status code → API error → return False
    channel._http = _FakeHttp(responses=[_FakeResponse(400, {"errcode": 400})])
    result = await channel._send_batch_message(
        "token", "user123", "sampleMarkdown", {"text": "hello"}
    )
    assert result is False

    # 200 with non-zero errcode → API error → return False
    channel._http = _FakeHttp(responses=[_FakeResponse(200, {"errcode": 100})])
    result = await channel._send_batch_message(
        "token", "user123", "sampleMarkdown", {"text": "hello"}
    )
    assert result is False

    # 200 with errcode=0 → success → return True
    channel._http = _FakeHttp(responses=[_FakeResponse(200, {"errcode": 0})])
    result = await channel._send_batch_message(
        "token", "user123", "sampleMarkdown", {"text": "hello"}
    )
    assert result is True


@pytest.mark.asyncio
async def test_send_media_ref_short_circuits_on_transport_error() -> None:
    """When the first send fails with a transport error, _send_media_ref must
    re-raise immediately instead of trying download+upload+fallback."""
    config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
    channel = DingTalkChannel(config, MessageBus())
    channel._http = _NetworkErrorHttp()

    # An image URL triggers the sampleImageMsg path first
    with pytest.raises(httpx.ConnectError, match="Connection refused"):
        await channel._send_media_ref("token", "user123", "https://example.com/photo.jpg")

    # Only one POST should have been attempted — no download/upload/fallback
    assert len(channel._http.calls) == 1
    assert channel._http.calls[0]["method"] == "POST"


@pytest.mark.asyncio
async def test_send_media_ref_short_circuits_on_download_transport_error() -> None:
    """When the image URL send returns an API error (False) but the download
    for the fallback hits a transport error, it must re-raise rather than
    silently returning False."""
    config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
    channel = DingTalkChannel(config, MessageBus())

    # First POST (sampleImageMsg) returns API error → False, then GET (download) raises transport error
    class _MixedHttp:
        def __init__(self) -> None:
            self.calls: list[dict] = []

        async def post(self, url, json=None, headers=None, **kwargs):
            self.calls.append({"method": "POST", "url": url})
            # API-level failure: 200 with errcode != 0
            return _FakeResponse(200, {"errcode": 100})
|
||||
|
||||
async def get(self, url, **kwargs):
|
||||
self.calls.append({"method": "GET", "url": url})
|
||||
raise httpx.ConnectError("Connection refused")
|
||||
|
||||
channel._http = _MixedHttp()
|
||||
|
||||
with pytest.raises(httpx.ConnectError, match="Connection refused"):
|
||||
await channel._send_media_ref("token", "user123", "https://example.com/photo.jpg")
|
||||
|
||||
# Should have attempted POST (image URL) and GET (download), but NOT upload
|
||||
assert len(channel._http.calls) == 2
|
||||
assert channel._http.calls[0]["method"] == "POST"
|
||||
assert channel._http.calls[1]["method"] == "GET"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_ref_short_circuits_on_upload_transport_error() -> None:
|
||||
"""When download succeeds but upload hits a transport error, must re-raise."""
|
||||
config = DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"])
|
||||
channel = DingTalkChannel(config, MessageBus())
|
||||
|
||||
image_bytes = b"\xff\xd8\xff\xe0" + b"\x00" * 100 # minimal JPEG-ish data
|
||||
|
||||
class _UploadFailsHttp:
|
||||
def __init__(self) -> None:
|
||||
self.calls: list[dict] = []
|
||||
|
||||
async def post(self, url, json=None, headers=None, files=None, **kwargs):
|
||||
self.calls.append({"method": "POST", "url": url})
|
||||
# If it's the upload endpoint, raise transport error
|
||||
if "media/upload" in url:
|
||||
raise httpx.ConnectError("Connection refused")
|
||||
# Otherwise (sampleImageMsg), return API error to trigger fallback
|
||||
return _FakeResponse(200, {"errcode": 100})
|
||||
|
||||
async def get(self, url, **kwargs):
|
||||
self.calls.append({"method": "GET", "url": url})
|
||||
resp = _FakeResponse(200)
|
||||
resp.content = image_bytes
|
||||
resp.headers = {"content-type": "image/jpeg"}
|
||||
return resp
|
||||
|
||||
channel._http = _UploadFailsHttp()
|
||||
|
||||
with pytest.raises(httpx.ConnectError, match="Connection refused"):
|
||||
await channel._send_media_ref("token", "user123", "https://example.com/photo.jpg")
|
||||
|
||||
# POST (image URL), GET (download), POST (upload) attempted — no further sends
|
||||
methods = [c["method"] for c in channel._http.calls]
|
||||
assert methods == ["POST", "GET", "POST"]
|
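The DingTalk tests above pin down a two-tier error contract: transport failures propagate to the caller so it can retry, while API-level failures are reported as `False`. The following is a minimal sketch of that contract, not nanobot's actual implementation. `TransportError`, `FakeResponse`, `FakeHttp`, and `send_batch` are hypothetical names introduced for illustration, and treating a missing `errcode` as success is an assumption.

```python
import asyncio


class TransportError(Exception):
    """Hypothetical stand-in for httpx.ConnectError and similar failures."""


class FakeResponse:
    """Minimal response double exposing status_code and json()."""

    def __init__(self, status_code, body=None):
        self.status_code = status_code
        self._body = body or {}

    def json(self):
        return self._body


class FakeHttp:
    """HTTP double: replays queued responses, or raises a queued exception."""

    def __init__(self, outcomes):
        self._outcomes = list(outcomes)
        self.calls = []

    async def post(self, url, json=None):
        self.calls.append(url)
        outcome = self._outcomes.pop(0)
        if isinstance(outcome, Exception):
            raise outcome
        return outcome


async def send_batch(http, payload):
    """Return True on success and False on API errors; never catch transport errors."""
    resp = await http.post("/send", json=payload)  # TransportError propagates from here
    if resp.status_code != 200:
        return False
    # Assumption: a missing errcode counts as success.
    return resp.json().get("errcode", 0) == 0
```

Keeping the two failure classes distinct lets a caller retry only when a retry could plausibly help; an API rejection would most likely fail again with the same payload.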
||||
|
||||
@ -5,11 +5,17 @@ from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
discord = pytest.importorskip("discord")
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.discord import DiscordBotClient, DiscordChannel, DiscordConfig
|
||||
from nanobot.channels.discord import (
|
||||
MAX_MESSAGE_LEN,
|
||||
DiscordBotClient,
|
||||
DiscordChannel,
|
||||
DiscordConfig,
|
||||
)
|
||||
from nanobot.command.builtin import build_help_text
|
||||
|
||||
|
||||
@ -18,9 +24,11 @@ class _FakeDiscordClient:
|
||||
instances: list["_FakeDiscordClient"] = []
|
||||
start_error: Exception | None = None
|
||||
|
||||
def __init__(self, owner, *, intents) -> None:
|
||||
def __init__(self, owner, *, intents, proxy=None, proxy_auth=None) -> None:
|
||||
self.owner = owner
|
||||
self.intents = intents
|
||||
self.proxy = proxy
|
||||
self.proxy_auth = proxy_auth
|
||||
self.closed = False
|
||||
self.ready = True
|
||||
self.channels: dict[int, object] = {}
|
||||
@ -53,7 +61,9 @@ class _FakeDiscordClient:
|
||||
|
||||
class _FakeAttachment:
|
||||
# Attachment double that can simulate successful or failing save() calls.
|
||||
def __init__(self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False) -> None:
|
||||
def __init__(
|
||||
self, attachment_id: int, filename: str, *, size: int = 1, fail: bool = False
|
||||
) -> None:
|
||||
self.id = attachment_id
|
||||
self.filename = filename
|
||||
self.size = size
|
||||
@ -71,11 +81,25 @@ class _FakePartialMessage:
|
||||
        self.id = message_id


class _FakeSentMessage:
    # Sent-message double supporting edit() for streaming tests.
    def __init__(self, channel, content: str) -> None:
        self.channel = channel
        self.content = content
        self.edits: list[dict] = []

    async def edit(self, **kwargs) -> None:
        self.edits.append(dict(kwargs))
        if "content" in kwargs:
            self.content = kwargs["content"]


class _FakeChannel:
    # Channel double that records outbound payloads and typing activity.
    def __init__(self, channel_id: int = 123) -> None:
        self.id = channel_id
        self.sent_payloads: list[dict] = []
        self.sent_messages: list[_FakeSentMessage] = []
        self.trigger_typing_calls = 0
        self.typing_enter_hook = None
|
||||
|
||||
@ -85,6 +109,9 @@ class _FakeChannel:
|
||||
payload["file_name"] = payload["file"].filename
|
||||
del payload["file"]
|
||||
self.sent_payloads.append(payload)
|
||||
message = _FakeSentMessage(self, payload.get("content", ""))
|
||||
self.sent_messages.append(message)
|
||||
return message
|
||||
|
||||
def get_partial_message(self, message_id: int) -> _FakePartialMessage:
|
||||
return _FakePartialMessage(message_id)
|
||||
@ -194,7 +221,7 @@ async def test_start_handles_client_construction_failure(monkeypatch) -> None:
|
||||
MessageBus(),
|
||||
)
|
||||
|
||||
def _boom(owner, *, intents):
|
||||
def _boom(owner, *, intents, proxy=None, proxy_auth=None):
|
||||
raise RuntimeError("bad client")
|
||||
|
||||
monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _boom)
|
||||
@ -427,6 +454,60 @@ async def test_send_fetches_channel_when_not_cached() -> None:
|
||||
    assert target.sent_payloads == [{"content": "hello"}]


def test_supports_streaming_enabled_by_default() -> None:
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())

    assert channel.supports_streaming is True


@pytest.mark.asyncio
async def test_send_delta_streams_by_editing_message(monkeypatch) -> None:
    owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    client = _FakeDiscordClient(owner, intents=None)
    owner._client = client
    owner._running = True
    target = _FakeChannel(channel_id=123)
    client.channels[123] = target

    times = iter([1.0, 3.0, 5.0])
    monkeypatch.setattr("nanobot.channels.discord.time.monotonic", lambda: next(times, 5.0))

    await owner.send_delta("123", "hel", {"_stream_delta": True, "_stream_id": "s1"})
    await owner.send_delta("123", "lo", {"_stream_delta": True, "_stream_id": "s1"})
    await owner.send_delta("123", "", {"_stream_end": True, "_stream_id": "s1"})

    assert target.sent_payloads[0] == {"content": "hel"}
    assert target.sent_messages[0].edits == [{"content": "hello"}, {"content": "hello"}]
    assert owner._stream_bufs == {}


@pytest.mark.asyncio
async def test_send_delta_stream_end_splits_oversized_reply(monkeypatch) -> None:
    owner = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    client = _FakeDiscordClient(owner, intents=None)
    owner._client = client
    owner._running = True
    target = _FakeChannel(channel_id=123)
    client.channels[123] = target

    prefix = "a" * (MAX_MESSAGE_LEN - 100)
    suffix = "b" * 150
    full_text = prefix + suffix
    chunks = DiscordBotClient._build_chunks(full_text, [], False)
    assert len(chunks) == 2

    times = iter([1.0, 3.0])
    monkeypatch.setattr("nanobot.channels.discord.time.monotonic", lambda: next(times, 3.0))

    await owner.send_delta("123", prefix, {"_stream_delta": True, "_stream_id": "s1"})
    await owner.send_delta("123", suffix, {"_stream_delta": True, "_stream_id": "s1"})
    await owner.send_delta("123", "", {"_stream_end": True, "_stream_id": "s1"})

    assert target.sent_payloads == [{"content": prefix}, {"content": chunks[1]}]
    assert target.sent_messages[0].edits == [{"content": chunks[0]}, {"content": chunks[0]}]
    assert owner._stream_bufs == {}
|
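The two streaming tests above describe an edit-in-place pattern: the first delta sends a message, later deltas edit it, and the end-of-stream event sends any overflow as extra messages. The sketch below illustrates that pattern under stated assumptions. `StreamBuffer`, `SentMessage`, and `split_chunks` are hypothetical names, the 2000-character limit is assumed, the fixed-width split is a simplification of whatever `_build_chunks` really does, and the time-based edit throttling that the tests drive through `time.monotonic` is left out.

```python
import asyncio

MAX_LEN = 2000  # assumed per-message limit for this sketch


def split_chunks(text, limit=MAX_LEN):
    """Naive fixed-width split; a real chunker may prefer line or word boundaries."""
    return [text[i:i + limit] for i in range(0, len(text), limit)]


class SentMessage:
    """Message double whose content can be edited after sending."""

    def __init__(self, content):
        self.content = content

    async def edit(self, content):
        self.content = content


class StreamBuffer:
    """Accumulates deltas and mirrors the first chunk into one edited message."""

    def __init__(self, send):
        self._send = send  # async callable: content -> SentMessage
        self.text = ""
        self.message = None

    async def delta(self, piece):
        if not piece:
            return
        self.text += piece
        head = split_chunks(self.text)[0]
        if self.message is None:
            self.message = await self._send(head)  # first delta creates the message
        else:
            await self.message.edit(content=head)  # later deltas edit it in place

    async def end(self):
        chunks = split_chunks(self.text)
        if self.message is not None and chunks:
            await self.message.edit(content=chunks[0])
        for extra in chunks[1:]:
            await self._send(extra)  # overflow goes out as follow-up messages
        self.text, self.message = "", None
```

Editing only the first chunk while streaming keeps the visible message within the length limit, and deferring the overflow to the end avoids sending partial follow-up messages that would need editing too.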
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_slash_new_forwards_when_user_is_allowlisted() -> None:
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["123"]), MessageBus())
|
||||
@ -443,9 +524,7 @@ async def test_slash_new_forwards_when_user_is_allowlisted() -> None:
|
||||
assert new_cmd is not None
|
||||
await new_cmd.callback(interaction)
|
||||
|
||||
assert interaction.response.messages == [
|
||||
{"content": "Processing /new...", "ephemeral": True}
|
||||
]
|
||||
assert interaction.response.messages == [{"content": "Processing /new...", "ephemeral": True}]
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["content"] == "/new"
|
||||
assert handled[0]["sender_id"] == "123"
|
||||
@ -519,9 +598,7 @@ async def test_slash_help_returns_ephemeral_help_text() -> None:
|
||||
assert help_cmd is not None
|
||||
await help_cmd.callback(interaction)
|
||||
|
||||
assert interaction.response.messages == [
|
||||
{"content": build_help_text(), "ephemeral": True}
|
||||
]
|
||||
assert interaction.response.messages == [{"content": build_help_text(), "ephemeral": True}]
|
||||
assert handled == []
|
||||
|
||||
|
||||
@ -656,11 +733,13 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() ->
|
||||
def typing(self):
|
||||
async def _waiter():
|
||||
await release.wait()
|
||||
|
||||
# Hold the loop so task remains active until explicitly stopped.
|
||||
class _Ctx(_TypingCtx):
|
||||
async def __aenter__(self):
|
||||
await super().__aenter__()
|
||||
await _waiter()
|
||||
|
||||
return _Ctx()
|
||||
|
||||
typing_channel = _NoTriggerChannel(channel_id=123)
|
||||
@ -674,3 +753,214 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() ->
|
||||
    await asyncio.sleep(0)

    assert channel._typing_tasks == {}


def test_config_accepts_proxy_fields() -> None:
    config = DiscordConfig(
        enabled=True,
        token="token",
        allow_from=["*"],
        proxy="http://127.0.0.1:7890",
        proxy_username="user",
        proxy_password="pass",
    )
    assert config.proxy == "http://127.0.0.1:7890"
    assert config.proxy_username == "user"
    assert config.proxy_password == "pass"


def test_config_proxy_defaults_to_none() -> None:
    config = DiscordConfig(enabled=True, token="token", allow_from=["*"])
    assert config.proxy is None
    assert config.proxy_username is None
    assert config.proxy_password is None


@pytest.mark.asyncio
async def test_start_passes_proxy_to_client(monkeypatch) -> None:
    _FakeDiscordClient.instances.clear()
    channel = DiscordChannel(
        DiscordConfig(
            enabled=True,
            token="token",
            allow_from=["*"],
            proxy="http://127.0.0.1:7890",
        ),
        MessageBus(),
    )
    monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)

    await channel.start()

    assert channel.is_running is False
    assert len(_FakeDiscordClient.instances) == 1
    assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890"
    assert _FakeDiscordClient.instances[0].proxy_auth is None


@pytest.mark.asyncio
async def test_start_passes_proxy_auth_when_credentials_provided(monkeypatch) -> None:
    aiohttp = pytest.importorskip("aiohttp")
    _FakeDiscordClient.instances.clear()
    channel = DiscordChannel(
        DiscordConfig(
            enabled=True,
            token="token",
            allow_from=["*"],
            proxy="http://127.0.0.1:7890",
            proxy_username="user",
            proxy_password="pass",
        ),
        MessageBus(),
    )
    monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)

    await channel.start()

    assert channel.is_running is False
    assert len(_FakeDiscordClient.instances) == 1
    assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890"
    assert _FakeDiscordClient.instances[0].proxy_auth is not None
    assert isinstance(_FakeDiscordClient.instances[0].proxy_auth, aiohttp.BasicAuth)
    assert _FakeDiscordClient.instances[0].proxy_auth.login == "user"
    assert _FakeDiscordClient.instances[0].proxy_auth.password == "pass"


@pytest.mark.asyncio
async def test_start_no_proxy_auth_when_only_username(monkeypatch) -> None:
    _FakeDiscordClient.instances.clear()
    channel = DiscordChannel(
        DiscordConfig(
            enabled=True,
            token="token",
            allow_from=["*"],
            proxy="http://127.0.0.1:7890",
            proxy_username="user",
        ),
        MessageBus(),
    )
    monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)

    await channel.start()

    assert channel.is_running is False
    assert _FakeDiscordClient.instances[0].proxy_auth is None


@pytest.mark.asyncio
async def test_start_no_proxy_auth_when_only_password(monkeypatch) -> None:
    _FakeDiscordClient.instances.clear()
    channel = DiscordChannel(
        DiscordConfig(
            enabled=True,
            token="token",
            allow_from=["*"],
            proxy="http://127.0.0.1:7890",
            proxy_password="pass",
        ),
        MessageBus(),
    )
    monkeypatch.setattr("nanobot.channels.discord.DiscordBotClient", _FakeDiscordClient)

    await channel.start()

    assert channel.is_running is False
    assert _FakeDiscordClient.instances[0].proxy == "http://127.0.0.1:7890"
    assert _FakeDiscordClient.instances[0].proxy_auth is None
|
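The proxy tests above imply a simple rule: proxy credentials are attached only when both a username and a password are configured, and a lone username or password yields no auth object. A minimal sketch of that rule follows. `build_proxy_auth` is a hypothetical helper, and the `BasicAuth` tuple here is a stdlib stand-in for `aiohttp.BasicAuth` so the example runs without third-party packages; how nanobot actually performs the check is not shown in this diff.

```python
from typing import NamedTuple, Optional


class BasicAuth(NamedTuple):
    """Hypothetical stand-in for aiohttp.BasicAuth (login, password)."""

    login: str
    password: str


def build_proxy_auth(username: Optional[str], password: Optional[str]) -> Optional[BasicAuth]:
    """Build credentials only when both halves are present."""
    if username and password:
        return BasicAuth(login=username, password=password)
    return None
```

Returning `None` for partial credentials avoids sending a half-filled `Proxy-Authorization` header, which a proxy would reject anyway.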
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Tests for the send() exception propagation fix
# ---------------------------------------------------------------------------

@pytest.mark.asyncio
async def test_send_re_raises_network_error() -> None:
    """Network errors during send must propagate so ChannelManager can retry."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    client = _FakeDiscordClient(channel, intents=None)
    channel._client = client
    channel._running = True

    async def _failing_send_outbound(msg: OutboundMessage) -> None:
        raise ConnectionError("network unreachable")

    client.send_outbound = _failing_send_outbound  # type: ignore[method-assign]

    with pytest.raises(ConnectionError, match="network unreachable"):
        await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))


@pytest.mark.asyncio
async def test_send_re_raises_generic_exception() -> None:
    """Any exception from send_outbound must propagate, not be swallowed."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    client = _FakeDiscordClient(channel, intents=None)
    channel._client = client
    channel._running = True

    async def _failing_send_outbound(msg: OutboundMessage) -> None:
        raise RuntimeError("discord API failure")

    client.send_outbound = _failing_send_outbound  # type: ignore[method-assign]

    with pytest.raises(RuntimeError, match="discord API failure"):
        await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))


@pytest.mark.asyncio
async def test_send_still_stops_typing_on_error() -> None:
    """Typing cleanup must still run in the finally block even when send raises."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    client = _FakeDiscordClient(channel, intents=None)
    channel._client = client
    channel._running = True

    # Start a typing task so we can verify it gets cleaned up
    start = asyncio.Event()
    release = asyncio.Event()

    async def slow_typing() -> None:
        start.set()
        await release.wait()

    typing_channel = _FakeChannel(channel_id=123)
    typing_channel.typing_enter_hook = slow_typing
    await channel._start_typing(typing_channel)
    await asyncio.wait_for(start.wait(), timeout=1.0)

    async def _failing_send_outbound(msg: OutboundMessage) -> None:
        raise ConnectionError("timeout")

    client.send_outbound = _failing_send_outbound  # type: ignore[method-assign]

    with pytest.raises(ConnectionError, match="timeout"):
        await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))

    release.set()
    await asyncio.sleep(0)

    # Typing should have been cleaned up by the finally block
    assert channel._typing_tasks == {}


@pytest.mark.asyncio
async def test_send_succeeds_normally() -> None:
    """Successful sends should work without raising."""
    channel = DiscordChannel(DiscordConfig(enabled=True, allow_from=["*"]), MessageBus())
    client = _FakeDiscordClient(channel, intents=None)
    channel._client = client
    channel._running = True

    sent_messages: list[OutboundMessage] = []

    async def _capture_send_outbound(msg: OutboundMessage) -> None:
        sent_messages.append(msg)

    client.send_outbound = _capture_send_outbound  # type: ignore[method-assign]

    msg = OutboundMessage(channel="discord", chat_id="123", content="hello world")
    await channel.send(msg)

    assert len(sent_messages) == 1
    assert sent_messages[0].content == "hello world"
    assert sent_messages[0].chat_id == "123"
|
||||
|
||||
48
tests/channels/test_feishu_domain.py
Normal file
@ -0,0 +1,48 @@
|
||||
"""Tests for Feishu/Lark domain configuration."""
from unittest.mock import MagicMock

import pytest

from nanobot.bus.queue import MessageBus
from nanobot.channels.feishu import FeishuChannel, FeishuConfig


def _make_channel(domain: str = "feishu") -> FeishuChannel:
    config = FeishuConfig(
        enabled=True,
        app_id="cli_test",
        app_secret="secret",
        allow_from=["*"],
        domain=domain,
    )
    ch = FeishuChannel(config, MessageBus())
    ch._client = MagicMock()
    ch._loop = None
    return ch


class TestFeishuConfigDomain:
    def test_domain_default_is_feishu(self):
        config = FeishuConfig()
        assert config.domain == "feishu"

    def test_domain_accepts_lark(self):
        config = FeishuConfig(domain="lark")
        assert config.domain == "lark"

    def test_domain_accepts_feishu(self):
        config = FeishuConfig(domain="feishu")
        assert config.domain == "feishu"

    def test_default_config_includes_domain(self):
        default_cfg = FeishuChannel.default_config()
        assert "domain" in default_cfg
        assert default_cfg["domain"] == "feishu"

    def test_channel_persists_domain_from_config(self):
        ch = _make_channel(domain="lark")
        assert ch.config.domain == "lark"

    def test_channel_persists_feishu_domain_from_config(self):
        ch = _make_channel(domain="feishu")
        assert ch.config.domain == "feishu"
|
||||
@ -5,6 +5,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.feishu import FeishuChannel, FeishuConfig, _FeishuStreamBuf
|
||||
|
||||
@ -203,6 +204,24 @@ class TestSendDelta:
|
||||
        ch._client.cardkit.v1.card_element.content.assert_not_called()
        ch._client.im.v1.message.create.assert_called_once()

    @pytest.mark.asyncio
    async def test_stream_end_fallback_when_final_update_fails(self):
        """If streaming mode was closed (e.g. Feishu timeout), fall back to a regular card."""
        ch = _make_channel()
        ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
            text="Lost content", card_id="card_1", sequence=3, last_edit=0.0,
        )
        ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response(success=False)
        ch._client.im.v1.message.create.return_value = _mock_send_response("om_fb")

        await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True})

        assert "oc_chat1" not in ch._stream_bufs
        # Should NOT attempt to close streaming mode since update failed
        ch._client.cardkit.v1.card.settings.assert_not_called()
        # Should fall back to sending a regular interactive card
        ch._client.im.v1.message.create.assert_called_once()

    @pytest.mark.asyncio
    async def test_stream_end_without_buf_is_noop(self):
        ch = _make_channel()
|
||||
@ -239,6 +258,130 @@ class TestSendDelta:
|
||||
        assert buf.sequence == 7


class TestToolHintInlineStreaming:
    """Tool hint messages should be inlined into active streaming cards."""

    @pytest.mark.asyncio
    async def test_tool_hint_inlined_when_stream_active(self):
        """With an active streaming buffer, tool hint appends to the card."""
        ch = _make_channel()
        ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
            text="Partial answer", card_id="card_1", sequence=2, last_edit=0.0,
        )
        ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()

        msg = OutboundMessage(
            channel="feishu", chat_id="oc_chat1",
            content='web_fetch("https://example.com")',
            metadata={"_tool_hint": True},
        )
        await ch.send(msg)

        buf = ch._stream_bufs["oc_chat1"]
        assert '🔧 web_fetch("https://example.com")' in buf.text
        assert buf.sequence == 3
        ch._client.cardkit.v1.card_element.content.assert_called_once()
        ch._client.im.v1.message.create.assert_not_called()

    @pytest.mark.asyncio
    async def test_tool_hint_preserved_on_next_delta(self):
        """When new delta arrives, the tool hint is kept as permanent content and delta appends after it."""
        ch = _make_channel()
        ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
            text="Partial answer\n\n🔧 web_fetch(\"url\")\n\n",
            card_id="card_1", sequence=3, last_edit=0.0,
        )
        ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()

        await ch.send_delta("oc_chat1", " continued")

        buf = ch._stream_bufs["oc_chat1"]
        assert "Partial answer" in buf.text
        assert "🔧 web_fetch" in buf.text
        assert buf.text.endswith(" continued")

    @pytest.mark.asyncio
    async def test_tool_hint_fallback_when_no_stream(self):
        """Without an active buffer, tool hint falls back to a standalone card."""
        ch = _make_channel()
        ch._client.im.v1.message.create.return_value = _mock_send_response("om_hint")

        msg = OutboundMessage(
            channel="feishu", chat_id="oc_chat1",
            content='read_file("path")',
            metadata={"_tool_hint": True},
        )
        await ch.send(msg)

        assert "oc_chat1" not in ch._stream_bufs
        ch._client.im.v1.message.create.assert_called_once()

    @pytest.mark.asyncio
    async def test_consecutive_tool_hints_append(self):
        """When multiple tool hints arrive consecutively, each appends to the card."""
        ch = _make_channel()
        ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
            text="Partial answer", card_id="card_1", sequence=2, last_edit=0.0,
        )
        ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()

        msg1 = OutboundMessage(
            channel="feishu", chat_id="oc_chat1",
            content='$ cd /project', metadata={"_tool_hint": True},
        )
        await ch.send(msg1)

        msg2 = OutboundMessage(
            channel="feishu", chat_id="oc_chat1",
            content='$ git status', metadata={"_tool_hint": True},
        )
        await ch.send(msg2)

        buf = ch._stream_bufs["oc_chat1"]
        assert "$ cd /project" in buf.text
        assert "$ git status" in buf.text
        assert buf.text.startswith("Partial answer")
        assert "🔧 $ cd /project" in buf.text
        assert "🔧 $ git status" in buf.text

    @pytest.mark.asyncio
    async def test_tool_hint_preserved_on_final_stream_end(self):
        """When final _stream_end closes the card, tool hint is kept in the final text."""
        ch = _make_channel()
        ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
            text="Final content\n\n🔧 web_fetch(\"url\")\n\n",
            card_id="card_1", sequence=3, last_edit=0.0,
        )
        ch._client.cardkit.v1.card_element.content.return_value = _mock_content_response()
        ch._client.cardkit.v1.card.settings.return_value = _mock_content_response()

        await ch.send_delta("oc_chat1", "", metadata={"_stream_end": True})

        assert "oc_chat1" not in ch._stream_bufs
        update_call = ch._client.cardkit.v1.card_element.content.call_args[0][0]
        assert "🔧" in update_call.body.content

    @pytest.mark.asyncio
    async def test_empty_tool_hint_is_noop(self):
        """Empty or whitespace-only tool hint content is silently ignored."""
        ch = _make_channel()
        ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
            text="Partial answer", card_id="card_1", sequence=2, last_edit=0.0,
        )

        for content in ("", " ", "\t\n"):
            msg = OutboundMessage(
                channel="feishu", chat_id="oc_chat1",
                content=content, metadata={"_tool_hint": True},
            )
            await ch.send(msg)

        buf = ch._stream_bufs["oc_chat1"]
        assert buf.text == "Partial answer"
        assert buf.sequence == 2
        ch._client.cardkit.v1.card_element.content.assert_not_called()
|
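The `TestToolHintInlineStreaming` cases above describe how a tool hint is folded into an active streaming buffer: blank hints are ignored, non-blank hints are appended with the configured prefix, and the sequence counter advances once per appended hint. The sketch below mirrors that behaviour under stated assumptions. `StreamBuf` and `append_tool_hint` are hypothetical names, and the exact spacing (a blank line before and after each hint) is inferred from the fixture strings in the tests rather than taken from the channel code.

```python
from dataclasses import dataclass


@dataclass
class StreamBuf:
    """Hypothetical stand-in for the streaming buffer used in the tests."""

    text: str
    sequence: int


def append_tool_hint(buf: StreamBuf, hint: str, prefix: str = "🔧") -> bool:
    """Append a prefixed hint to the buffer; return False for blank hints."""
    hint = hint.strip()
    if not hint:
        return False  # empty or whitespace-only hints leave the buffer untouched
    # Assumed layout: one blank line around each hint, as in the test fixtures.
    buf.text = f"{buf.text.rstrip()}\n\n{prefix} {hint}\n\n"
    buf.sequence += 1  # each card update consumes one sequence number
    return True
```

Because the hint becomes part of `buf.text`, later deltas simply append after it, which matches the "preserved on next delta" expectation above.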
||||
|
||||
|
||||
class TestSendMessageReturnsId:
    def test_returns_message_id_on_success(self):
        ch = _make_channel()
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
"""Tests for FeishuChannel tool hint code block formatting."""
|
||||
"""Tests for FeishuChannel tool hint formatting."""
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -28,15 +29,24 @@ def mock_feishu_channel():
|
||||
config.app_secret = "test_app_secret"
|
||||
config.encrypt_key = None
|
||||
config.verification_token = None
|
||||
config.tool_hint_prefix = "\U0001f527" # 🔧
|
||||
bus = MagicMock()
|
||||
channel = FeishuChannel(config, bus)
|
||||
channel._client = MagicMock() # Simulate initialized client
|
||||
channel._client = MagicMock()
|
||||
return channel
|
||||
|
||||
|
||||
def _get_tool_hint_card(mock_send):
|
||||
"""Extract the interactive card from _send_message_sync calls."""
|
||||
call_args = mock_send.call_args[0]
|
||||
_, _, msg_type, content = call_args
|
||||
assert msg_type == "interactive"
|
||||
return json.loads(content)
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_tool_hint_sends_code_message(mock_feishu_channel):
|
||||
"""Tool hint messages should be sent as interactive cards with code blocks."""
|
||||
async def test_tool_hint_sends_interactive_card(mock_feishu_channel):
|
||||
"""Tool hint without active buffer sends an interactive card with 🔧 style."""
|
||||
msg = OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_123456",
|
||||
@ -47,23 +57,12 @@ async def test_tool_hint_sends_code_message(mock_feishu_channel):
|
||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||
await mock_feishu_channel.send(msg)
|
||||
|
||||
# Verify interactive message with card was sent
|
||||
assert mock_send.call_count == 1
|
||||
call_args = mock_send.call_args[0]
|
||||
receive_id_type, receive_id, msg_type, content = call_args
|
||||
|
||||
assert receive_id_type == "chat_id"
|
||||
assert receive_id == "oc_123456"
|
||||
assert msg_type == "interactive"
|
||||
|
||||
# Parse content to verify card structure
|
||||
card = json.loads(content)
|
||||
card = _get_tool_hint_card(mock_send)
|
||||
assert card["config"]["wide_screen_mode"] is True
|
||||
assert len(card["elements"]) == 1
|
||||
assert card["elements"][0]["tag"] == "markdown"
|
||||
# Check that code block is properly formatted with language hint
|
||||
expected_md = "**Tool Calls**\n\n```text\nweb_search(\"test query\")\n```"
|
||||
assert card["elements"][0]["content"] == expected_md
|
||||
md = card["elements"][0]["content"]
|
||||
assert "\U0001f527" in md
|
||||
assert "web_search" in md
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
@ -78,8 +77,6 @@ async def test_tool_hint_empty_content_does_not_send(mock_feishu_channel):
|
||||
|
||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||
await mock_feishu_channel.send(msg)
|
||||
|
||||
# Should not send any message
|
||||
mock_send.assert_not_called()
|
||||
|
||||
|
||||
@ -96,7 +93,6 @@ async def test_tool_hint_without_metadata_sends_as_normal(mock_feishu_channel):
|
||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||
await mock_feishu_channel.send(msg)
|
||||
|
||||
# Should send as text message (detected format)
|
||||
assert mock_send.call_count == 1
|
||||
call_args = mock_send.call_args[0]
|
||||
_, _, msg_type, content = call_args
|
||||
@ -106,7 +102,7 @@ async def test_tool_hint_without_metadata_sends_as_normal(mock_feishu_channel):
|
||||
|
||||
@mark.asyncio
|
||||
async def test_tool_hint_multiple_tools_in_one_message(mock_feishu_channel):
|
||||
"""Multiple tool calls should be displayed each on its own line in a code block."""
|
||||
"""Multiple tool calls should each get the 🔧 prefix."""
|
||||
msg = OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_123456",
|
||||
@ -117,13 +113,11 @@ async def test_tool_hint_multiple_tools_in_one_message(mock_feishu_channel):
|
||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||
await mock_feishu_channel.send(msg)
|
||||
|
||||
call_args = mock_send.call_args[0]
|
||||
msg_type = call_args[2]
|
||||
content = json.loads(call_args[3])
|
||||
assert msg_type == "interactive"
|
||||
# Each tool call should be on its own line
|
||||
expected_md = "**Tool Calls**\n\n```text\nweb_search(\"query\"),\nread_file(\"/path/to/file\")\n```"
|
||||
assert content["elements"][0]["content"] == expected_md
|
||||
card = _get_tool_hint_card(mock_send)
|
||||
md = card["elements"][0]["content"]
|
||||
assert "web_search" in md
|
||||
assert "read_file" in md
|
||||
assert "\U0001f527" in md
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
@ -139,8 +133,8 @@ async def test_tool_hint_new_format_basic(mock_feishu_channel):
|
||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||
await mock_feishu_channel.send(msg)
|
||||
|
||||
content = json.loads(mock_send.call_args[0][3])
|
||||
md = content["elements"][0]["content"]
|
||||
card = _get_tool_hint_card(mock_send)
|
||||
md = card["elements"][0]["content"]
|
||||
assert "read src/main.py" in md
|
||||
assert 'grep "TODO"' in md
|
||||
|
||||
@ -158,16 +152,15 @@ async def test_tool_hint_new_format_with_comma_in_quotes(mock_feishu_channel):
|
||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||
await mock_feishu_channel.send(msg)
|
||||
|
||||
content = json.loads(mock_send.call_args[0][3])
|
||||
md = content["elements"][0]["content"]
|
||||
# The comma inside quotes should NOT cause a line break
|
||||
card = _get_tool_hint_card(mock_send)
|
||||
md = card["elements"][0]["content"]
|
||||
assert 'grep "hello, world"' in md
|
||||
assert "$ echo test" in md
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_tool_hint_new_format_with_folding(mock_feishu_channel):
|
||||
"""Folded calls (× N) should display on separate lines."""
|
||||
"""Folded calls (× N) should display correctly."""
|
||||
msg = OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_123456",
|
||||
@ -178,8 +171,8 @@ async def test_tool_hint_new_format_with_folding(mock_feishu_channel):
|
||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||
await mock_feishu_channel.send(msg)
|
||||
|
||||
content = json.loads(mock_send.call_args[0][3])
|
||||
md = content["elements"][0]["content"]
|
||||
card = _get_tool_hint_card(mock_send)
|
||||
md = card["elements"][0]["content"]
|
||||
assert "\u00d7 3" in md
|
||||
assert 'grep "pattern"' in md
|
||||
|
||||
@ -197,9 +190,12 @@ async def test_tool_hint_new_format_mcp(mock_feishu_channel):
|
||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||
await mock_feishu_channel.send(msg)
|
||||
|
||||
content = json.loads(mock_send.call_args[0][3])
|
||||
md = content["elements"][0]["content"]
|
||||
card = _get_tool_hint_card(mock_send)
|
||||
md = card["elements"][0]["content"]
|
||||
assert "4_5v::analyze_image" in md
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
async def test_tool_hint_keeps_commas_inside_arguments(mock_feishu_channel):
|
||||
"""Commas inside a single tool argument must not be split onto a new line."""
|
||||
msg = OutboundMessage(
|
||||
@ -212,10 +208,7 @@ async def test_tool_hint_keeps_commas_inside_arguments(mock_feishu_channel):
|
||||
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
|
||||
await mock_feishu_channel.send(msg)
|
||||
|
||||
content = json.loads(mock_send.call_args[0][3])
|
||||
expected_md = (
|
||||
"**Tool Calls**\n\n```text\n"
|
||||
"web_search(\"foo, bar\"),\n"
|
||||
"read_file(\"/path/to/file\")\n```"
|
||||
)
|
||||
assert content["elements"][0]["content"] == expected_md
|
||||
card = _get_tool_hint_card(mock_send)
|
||||
md = card["elements"][0]["content"]
|
||||
assert 'web_search("foo, bar")' in md
|
||||
assert 'read_file("/path/to/file")' in md
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -14,6 +15,8 @@ except ImportError:
|
||||
if not QQ_AVAILABLE:
|
||||
pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True)
|
||||
|
||||
import aiohttp
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.qq import QQChannel, QQConfig
|
||||
@ -170,3 +173,221 @@ async def test_read_media_bytes_missing_file() -> None:
|
||||
data, filename = await channel._read_media_bytes("/nonexistent/path/image.png")
|
||||
assert data is None
|
||||
assert filename is None
|
||||
|
||||
|
||||
# -------------------------------------------------------
|
||||
# Tests for _send_media exception handling
|
||||
# -------------------------------------------------------
|
||||
|
||||
def _make_channel_with_local_file(suffix: str = ".png", content: bytes = b"\x89PNG\r\n"):
    """Create a QQChannel with a fake client and a temp file for media."""
    channel = QQChannel(
        QQConfig(app_id="app", secret="secret", allow_from=["*"]),
        MessageBus(),
    )
    channel._client = _FakeClient()
    channel._chat_type_cache["user1"] = "c2c"

    tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    tmp.write(content)
    tmp.close()
    return channel, tmp.name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_network_error_propagates() -> None:
|
||||
"""aiohttp.ClientError (network/transport) should re-raise, not return False."""
|
||||
channel, tmp_path = _make_channel_with_local_file()
|
||||
|
||||
# Make the base64 upload raise a network error
|
||||
channel._client.api._http = SimpleNamespace()
|
||||
channel._client.api._http.request = AsyncMock(
|
||||
side_effect=aiohttp.ServerDisconnectedError("connection lost"),
|
||||
)
|
||||
|
||||
with pytest.raises(aiohttp.ServerDisconnectedError):
|
||||
await channel._send_media(
|
||||
chat_id="user1",
|
||||
media_ref=tmp_path,
|
||||
msg_id="msg1",
|
||||
is_group=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_client_connector_error_propagates() -> None:
|
||||
"""aiohttp.ClientConnectorError (DNS/connection refused) should re-raise."""
|
||||
channel, tmp_path = _make_channel_with_local_file()
|
||||
|
||||
from aiohttp.client_reqrep import ConnectionKey
|
||||
conn_key = ConnectionKey("api.qq.com", 443, True, None, None, None, None)
|
||||
connector_error = aiohttp.ClientConnectorError(
|
||||
connection_key=conn_key,
|
||||
os_error=OSError("Connection refused"),
|
||||
)
|
||||
|
||||
channel._client.api._http = SimpleNamespace()
|
||||
channel._client.api._http.request = AsyncMock(
|
||||
side_effect=connector_error,
|
||||
)
|
||||
|
||||
with pytest.raises(aiohttp.ClientConnectorError):
|
||||
await channel._send_media(
|
||||
chat_id="user1",
|
||||
media_ref=tmp_path,
|
||||
msg_id="msg1",
|
||||
is_group=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_oserror_propagates() -> None:
|
||||
"""OSError (low-level I/O) should re-raise for retry."""
|
||||
channel, tmp_path = _make_channel_with_local_file()
|
||||
|
||||
channel._client.api._http = SimpleNamespace()
|
||||
channel._client.api._http.request = AsyncMock(
|
||||
side_effect=OSError("Network is unreachable"),
|
||||
)
|
||||
|
||||
with pytest.raises(OSError):
|
||||
await channel._send_media(
|
||||
chat_id="user1",
|
||||
media_ref=tmp_path,
|
||||
msg_id="msg1",
|
||||
is_group=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_api_error_returns_false() -> None:
|
||||
"""API-level errors (botpy RuntimeError subclasses) should return False, not raise."""
|
||||
channel, tmp_path = _make_channel_with_local_file()
|
||||
|
||||
# Simulate a botpy API error (e.g. ServerError is a RuntimeError subclass)
|
||||
from botpy.errors import ServerError
|
||||
|
||||
channel._client.api._http = SimpleNamespace()
|
||||
channel._client.api._http.request = AsyncMock(
|
||||
side_effect=ServerError("internal server error"),
|
||||
)
|
||||
|
||||
result = await channel._send_media(
|
||||
chat_id="user1",
|
||||
media_ref=tmp_path,
|
||||
msg_id="msg1",
|
||||
is_group=False,
|
||||
)
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_generic_runtime_error_returns_false() -> None:
|
||||
"""Generic RuntimeError (not network) should return False."""
|
||||
channel, tmp_path = _make_channel_with_local_file()
|
||||
|
||||
channel._client.api._http = SimpleNamespace()
|
||||
channel._client.api._http.request = AsyncMock(
|
||||
side_effect=RuntimeError("some API error"),
|
||||
)
|
||||
|
||||
result = await channel._send_media(
|
||||
chat_id="user1",
|
||||
media_ref=tmp_path,
|
||||
msg_id="msg1",
|
||||
is_group=False,
|
||||
)
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_value_error_returns_false() -> None:
|
||||
"""ValueError (bad API response data) should return False."""
|
||||
channel, tmp_path = _make_channel_with_local_file()
|
||||
|
||||
channel._client.api._http = SimpleNamespace()
|
||||
channel._client.api._http.request = AsyncMock(
|
||||
side_effect=ValueError("bad response data"),
|
||||
)
|
||||
|
||||
result = await channel._send_media(
|
||||
chat_id="user1",
|
||||
media_ref=tmp_path,
|
||||
msg_id="msg1",
|
||||
is_group=False,
|
||||
)
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_media_timeout_error_propagates() -> None:
    """Timeout errors raised by the transport should propagate.

    aiohttp.ServerTimeoutError is a ClientError subclass, so it is re-raised like
    other network errors. The built-in TimeoutError is an OSError subclass (and
    asyncio.TimeoutError is an alias of it since Python 3.11), so it would
    propagate through the OSError path as well.
    """
    channel, tmp_path = _make_channel_with_local_file()

    channel._client.api._http = SimpleNamespace()
    channel._client.api._http.request = AsyncMock(
        side_effect=aiohttp.ServerTimeoutError("request timed out"),
    )

    with pytest.raises(aiohttp.ServerTimeoutError):
        await channel._send_media(
            chat_id="user1",
            media_ref=tmp_path,
            msg_id="msg1",
            is_group=False,
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_fallback_text_on_api_error() -> None:
|
||||
"""When _send_media returns False (API error), send() should emit fallback text."""
|
||||
channel, tmp_path = _make_channel_with_local_file()
|
||||
|
||||
from botpy.errors import ServerError
|
||||
|
||||
channel._client.api._http = SimpleNamespace()
|
||||
channel._client.api._http.request = AsyncMock(
|
||||
side_effect=ServerError("internal server error"),
|
||||
)
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="qq",
|
||||
chat_id="user1",
|
||||
content="",
|
||||
media=[tmp_path],
|
||||
metadata={"message_id": "msg1"},
|
||||
)
|
||||
)
|
||||
|
||||
# Should have sent a fallback text message
|
||||
assert len(channel._client.api.c2c_calls) == 1
|
||||
fallback_content = channel._client.api.c2c_calls[0]["content"]
|
||||
assert "Attachment send failed" in fallback_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_propagates_network_error_no_fallback() -> None:
|
||||
"""When _send_media raises a network error, send() should NOT silently fallback."""
|
||||
channel, tmp_path = _make_channel_with_local_file()
|
||||
|
||||
channel._client.api._http = SimpleNamespace()
|
||||
channel._client.api._http.request = AsyncMock(
|
||||
side_effect=aiohttp.ServerDisconnectedError("connection lost"),
|
||||
)
|
||||
|
||||
with pytest.raises(aiohttp.ServerDisconnectedError):
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="qq",
|
||||
chat_id="user1",
|
||||
content="hello",
|
||||
media=[tmp_path],
|
||||
metadata={"message_id": "msg1"},
|
||||
)
|
||||
)
|
||||
|
||||
# No fallback text should have been sent
|
||||
assert len(channel._client.api.c2c_calls) == 0
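
The tests above pin down a two-way split in `_send_media`: transport failures are re-raised so the caller can retry, while API-level failures return `False` so `send()` can emit a fallback notice. The sketch below illustrates that split with the standard library only. It is not the nanobot implementation: `FakeNetworkError`, `send_media_sketch`, and the three helper coroutines are hypothetical names, and `FakeNetworkError` merely stands in for a transport error such as `aiohttp.ClientError`.

```python
import asyncio


class FakeNetworkError(Exception):
    """Hypothetical stand-in for a transport error such as aiohttp.ClientError."""


async def send_media_sketch(upload) -> bool:
    """Return True on success, False on API-level errors; re-raise transport errors."""
    try:
        await upload()
        return True
    except (FakeNetworkError, OSError):
        # Transport failures propagate so the caller can retry the whole send.
        raise
    except Exception:
        # API-level failures (bad response data, server-side rejection) are
        # reported as False so the caller can send a fallback text instead.
        return False


async def upload_ok() -> None:
    return None


async def upload_api_error() -> None:
    raise ValueError("bad response data")


async def upload_network_error() -> None:
    raise FakeNetworkError("connection lost")
```

The order of the `except` clauses matters here: the transport clause must come first, otherwise the broad `except Exception` would swallow the network error and turn a retryable failure into a silent fallback.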
|
||||
|
||||
304
tests/channels/test_qq_media.py
Normal file
@ -0,0 +1,304 @@
|
||||
"""Tests for QQ channel media support: helpers, send, inbound, and upload."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
try:
|
||||
from nanobot.channels import qq
|
||||
|
||||
QQ_AVAILABLE = getattr(qq, "QQ_AVAILABLE", False)
|
||||
except ImportError:
|
||||
QQ_AVAILABLE = False
|
||||
|
||||
if not QQ_AVAILABLE:
|
||||
pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True)
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.qq import (
|
||||
QQ_FILE_TYPE_FILE,
|
||||
QQ_FILE_TYPE_IMAGE,
|
||||
QQChannel,
|
||||
QQConfig,
|
||||
_guess_send_file_type,
|
||||
_is_image_name,
|
||||
_sanitize_filename,
|
||||
)
|
||||
|
||||
|
||||
class _FakeApi:
    def __init__(self) -> None:
        self.c2c_calls: list[dict] = []
        self.group_calls: list[dict] = []

    async def post_c2c_message(self, **kwargs) -> None:
        self.c2c_calls.append(kwargs)

    async def post_group_message(self, **kwargs) -> None:
        self.group_calls.append(kwargs)


class _FakeHttp:
    """Fake _http for _post_base64file tests."""

    def __init__(self, return_value: dict | None = None) -> None:
        self.return_value = return_value or {}
        self.calls: list[tuple] = []

    async def request(self, route, **kwargs):
        self.calls.append((route, kwargs))
        return self.return_value


class _FakeClient:
    def __init__(self, http_return: dict | None = None) -> None:
        self.api = _FakeApi()
        self.api._http = _FakeHttp(http_return)
|
||||
|
||||
|
||||
# ── Helper function tests (pure, no async) ──────────────────────────
|
||||
|
||||
|
||||
def test_sanitize_filename_strips_path_traversal() -> None:
|
||||
assert _sanitize_filename("../../etc/passwd") == "passwd"
|
||||
|
||||
|
||||
def test_sanitize_filename_keeps_chinese_chars() -> None:
|
||||
assert _sanitize_filename("文件(1).jpg") == "文件(1).jpg"
|
||||
|
||||
|
||||
def test_sanitize_filename_strips_unsafe_chars() -> None:
    result = _sanitize_filename('file<>:"|?*.txt')
    # Every unsafe character, including "*", is replaced with "_".
    assert result.startswith("file")
    assert result.endswith(".txt")
    assert "<" not in result
    assert ">" not in result
    assert '"' not in result
    assert "|" not in result
    assert "?" not in result
|
||||
|
||||
|
||||
def test_sanitize_filename_empty_input() -> None:
|
||||
assert _sanitize_filename("") == ""
|
||||
|
||||
|
||||
def test_is_image_name_with_known_extensions() -> None:
|
||||
for ext in (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".tif", ".tiff", ".ico", ".svg"):
|
||||
assert _is_image_name(f"photo{ext}") is True
|
||||
|
||||
|
||||
def test_is_image_name_with_unknown_extension() -> None:
|
||||
for ext in (".pdf", ".txt", ".mp3", ".mp4"):
|
||||
assert _is_image_name(f"doc{ext}") is False
|
||||
|
||||
|
||||
def test_guess_send_file_type_image() -> None:
|
||||
assert _guess_send_file_type("photo.png") == QQ_FILE_TYPE_IMAGE
|
||||
assert _guess_send_file_type("pic.jpg") == QQ_FILE_TYPE_IMAGE
|
||||
|
||||
|
||||
def test_guess_send_file_type_file() -> None:
|
||||
assert _guess_send_file_type("doc.pdf") == QQ_FILE_TYPE_FILE
|
||||
|
||||
|
||||
def test_guess_send_file_type_by_mime() -> None:
    # An unrecognized extension cannot be identified as an image,
    # so the generic file type is expected.
    assert _guess_send_file_type("photo.xyz_image_test") == QQ_FILE_TYPE_FILE
|
||||
|
||||
|
||||
# ── send() exception handling ───────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_exception_caught_not_raised() -> None:
|
||||
"""Exceptions inside send() must not propagate."""
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||
channel._client = _FakeClient()
|
||||
|
||||
with patch.object(channel, "_send_text_only", new_callable=AsyncMock, side_effect=RuntimeError("boom")):
|
||||
await channel.send(
|
||||
OutboundMessage(channel="qq", chat_id="user1", content="hello")
|
||||
)
|
||||
# No exception raised — test passes if we get here.
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_then_text() -> None:
|
||||
"""Media is sent before text when both are present."""
|
||||
import tempfile
|
||||
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||
channel._client = _FakeClient()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||
f.write(b"\x89PNG\r\n")
|
||||
tmp = f.name
|
||||
|
||||
try:
|
||||
with patch.object(channel, "_post_base64file", new_callable=AsyncMock, return_value={"file_info": "1"}) as mock_upload:
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="qq",
|
||||
chat_id="user1",
|
||||
content="text after image",
|
||||
media=[tmp],
|
||||
metadata={"message_id": "m1"},
|
||||
)
|
||||
)
|
||||
assert mock_upload.called
|
||||
|
||||
# Text should have been sent via c2c (default chat type)
|
||||
text_calls = [c for c in channel._client.api.c2c_calls if c.get("msg_type") == 0]
|
||||
assert len(text_calls) >= 1
|
||||
assert text_calls[-1]["content"] == "text after image"
|
||||
finally:
|
||||
import os
|
||||
os.unlink(tmp)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_failure_falls_back_to_text() -> None:
|
||||
"""When _send_media returns False, a failure notice is appended."""
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||
channel._client = _FakeClient()
|
||||
|
||||
with patch.object(channel, "_send_media", new_callable=AsyncMock, return_value=False):
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="qq",
|
||||
chat_id="user1",
|
||||
content="hello",
|
||||
media=["https://example.com/bad.png"],
|
||||
metadata={"message_id": "m1"},
|
||||
)
|
||||
)
|
||||
|
||||
# Should have the failure text among the c2c calls
|
||||
failure_calls = [c for c in channel._client.api.c2c_calls if "Attachment send failed" in c.get("content", "")]
|
||||
assert len(failure_calls) == 1
|
||||
assert "bad.png" in failure_calls[0]["content"]
|
||||
|
||||
|
||||
# ── _on_message() exception handling ────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_exception_caught_not_raised() -> None:
|
||||
"""Missing required attributes should not crash _on_message."""
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||
channel._client = _FakeClient()
|
||||
|
||||
# Construct a message-like object that lacks 'author' — triggers AttributeError
|
||||
bad_data = SimpleNamespace(id="x1", content="hi")
|
||||
# Should not raise
|
||||
await channel._on_message(bad_data, is_group=False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_with_attachments() -> None:
|
||||
"""Messages with attachments produce media_paths and formatted content."""
|
||||
import tempfile
|
||||
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||
channel._client = _FakeClient()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||
f.write(b"\x89PNG\r\n")
|
||||
saved_path = f.name
|
||||
|
||||
att = SimpleNamespace(url="", filename="screenshot.png", content_type="image/png")
|
||||
|
||||
# Patch _download_to_media_dir_chunked to return the temp file path
|
||||
async def fake_download(url, filename_hint=""):
|
||||
return saved_path
|
||||
|
||||
try:
|
||||
with patch.object(channel, "_download_to_media_dir_chunked", side_effect=fake_download):
|
||||
data = SimpleNamespace(
|
||||
id="att1",
|
||||
content="look at this",
|
||||
author=SimpleNamespace(user_openid="u1"),
|
||||
attachments=[att],
|
||||
)
|
||||
await channel._on_message(data, is_group=False)
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert "look at this" in msg.content
|
||||
assert "screenshot.png" in msg.content
|
||||
assert "Received files:" in msg.content
|
||||
assert len(msg.media) == 1
|
||||
assert msg.media[0] == saved_path
|
||||
finally:
|
||||
import os
|
||||
os.unlink(saved_path)
|
||||
|
||||
|
||||
# ── _post_base64file() ─────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_base64file_omits_file_name_for_images() -> None:
|
||||
"""file_type=1 (image) → payload must not contain file_name."""
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
|
||||
channel._client = _FakeClient(http_return={"file_info": "img_abc"})
|
||||
|
||||
await channel._post_base64file(
|
||||
chat_id="user1",
|
||||
is_group=False,
|
||||
file_type=QQ_FILE_TYPE_IMAGE,
|
||||
file_data="ZmFrZQ==",
|
||||
file_name="photo.png",
|
||||
)
|
||||
|
||||
http = channel._client.api._http
|
||||
assert len(http.calls) == 1
|
||||
payload = http.calls[0][1]["json"]
|
||||
assert "file_name" not in payload
|
||||
assert payload["file_type"] == QQ_FILE_TYPE_IMAGE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_base64file_includes_file_name_for_files() -> None:
|
||||
"""file_type=4 (file) → payload must contain file_name."""
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
|
||||
channel._client = _FakeClient(http_return={"file_info": "file_abc"})
|
||||
|
||||
await channel._post_base64file(
|
||||
chat_id="user1",
|
||||
is_group=False,
|
||||
file_type=QQ_FILE_TYPE_FILE,
|
||||
file_data="ZmFrZQ==",
|
||||
file_name="report.pdf",
|
||||
)
|
||||
|
||||
http = channel._client.api._http
|
||||
assert len(http.calls) == 1
|
||||
payload = http.calls[0][1]["json"]
|
||||
assert payload["file_name"] == "report.pdf"
|
||||
assert payload["file_type"] == QQ_FILE_TYPE_FILE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_base64file_filters_response_to_file_info() -> None:
|
||||
"""Response with file_info + extra fields must be filtered to only file_info."""
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
|
||||
channel._client = _FakeClient(http_return={
|
||||
"file_info": "fi_123",
|
||||
"file_uuid": "uuid_xxx",
|
||||
"ttl": 3600,
|
||||
})
|
||||
|
||||
result = await channel._post_base64file(
|
||||
chat_id="user1",
|
||||
is_group=False,
|
||||
file_type=QQ_FILE_TYPE_FILE,
|
||||
file_data="ZmFrZQ==",
|
||||
file_name="doc.pdf",
|
||||
)
|
||||
|
||||
assert result == {"file_info": "fi_123"}
|
||||
assert "file_uuid" not in result
|
||||
assert "ttl" not in result
|
||||
@ -10,8 +10,7 @@ except ImportError:
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.slack import SlackChannel
|
||||
from nanobot.channels.slack import SlackConfig
|
||||
from nanobot.channels.slack import SlackChannel, SlackConfig
|
||||
|
||||
|
||||
class _FakeAsyncWebClient:
|
||||
@ -20,6 +19,12 @@ class _FakeAsyncWebClient:
|
||||
self.file_upload_calls: list[dict[str, object | None]] = []
|
||||
self.reactions_add_calls: list[dict[str, object | None]] = []
|
||||
self.reactions_remove_calls: list[dict[str, object | None]] = []
|
||||
self.conversations_list_calls: list[dict[str, object | None]] = []
|
||||
self.users_list_calls: list[dict[str, object | None]] = []
|
||||
self.conversations_open_calls: list[dict[str, object | None]] = []
|
||||
self._conversations_pages: list[dict[str, object]] = []
|
||||
self._users_pages: list[dict[str, object]] = []
|
||||
self._open_dm_response: dict[str, object] = {"channel": {"id": "D_OPENED"}}
|
||||
|
||||
async def chat_postMessage(
|
||||
self,
|
||||
@ -81,6 +86,22 @@ class _FakeAsyncWebClient:
|
||||
}
|
||||
)
|
||||
|
||||
async def conversations_list(self, **kwargs):
|
||||
self.conversations_list_calls.append(kwargs)
|
||||
if self._conversations_pages:
|
||||
return self._conversations_pages.pop(0)
|
||||
return {"channels": [], "response_metadata": {"next_cursor": ""}}
|
||||
|
||||
async def users_list(self, **kwargs):
|
||||
self.users_list_calls.append(kwargs)
|
||||
if self._users_pages:
|
||||
return self._users_pages.pop(0)
|
||||
return {"members": [], "response_metadata": {"next_cursor": ""}}
|
||||
|
||||
async def conversations_open(self, **kwargs):
|
||||
self.conversations_open_calls.append(kwargs)
|
||||
return self._open_dm_response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_uses_thread_for_channel_messages() -> None:
|
||||
@ -151,3 +172,147 @@ async def test_send_updates_reaction_when_final_response_sent() -> None:
|
||||
assert fake_web.reactions_add_calls == [
|
||||
{"channel": "C123", "name": "white_check_mark", "timestamp": "1700000000.000100"}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_resolves_channel_name_to_channel_id() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
fake_web._conversations_pages = [
|
||||
{
|
||||
"channels": [{"id": "C999", "name": "channel_x"}],
|
||||
"response_metadata": {"next_cursor": ""},
|
||||
}
|
||||
]
|
||||
channel._web_client = fake_web
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="slack",
|
||||
chat_id="#channel_x",
|
||||
content="hello",
|
||||
)
|
||||
)
|
||||
|
||||
assert fake_web.chat_post_calls == [
|
||||
{"channel": "C999", "text": "hello\n", "thread_ts": None}
|
||||
]
|
||||
assert len(fake_web.conversations_list_calls) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_resolves_user_handle_to_dm_channel() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
fake_web._users_pages = [
|
||||
{
|
||||
"members": [
|
||||
{
|
||||
"id": "U234",
|
||||
"name": "alice",
|
||||
"profile": {"display_name": "Alice"},
|
||||
}
|
||||
],
|
||||
"response_metadata": {"next_cursor": ""},
|
||||
}
|
||||
]
|
||||
fake_web._open_dm_response = {"channel": {"id": "D234"}}
|
||||
channel._web_client = fake_web
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="slack",
|
||||
chat_id="@alice",
|
||||
content="hello",
|
||||
)
|
||||
)
|
||||
|
||||
assert fake_web.conversations_open_calls == [{"users": "U234"}]
|
||||
assert fake_web.chat_post_calls == [
|
||||
{"channel": "D234", "text": "hello\n", "thread_ts": None}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_updates_reaction_on_origin_channel_for_cross_channel_send() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True, react_emoji="eyes"), MessageBus())
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
fake_web._conversations_pages = [
|
||||
{
|
||||
"channels": [{"id": "C999", "name": "channel_x"}],
|
||||
"response_metadata": {"next_cursor": ""},
|
||||
}
|
||||
]
|
||||
channel._web_client = fake_web
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="slack",
|
||||
chat_id="channel_x",
|
||||
content="done",
|
||||
metadata={
|
||||
"slack": {
|
||||
"event": {"ts": "1700000000.000100", "channel": "D_ORIGIN"},
|
||||
"channel_type": "im",
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert fake_web.chat_post_calls == [
|
||||
{"channel": "C999", "text": "done\n", "thread_ts": None}
|
||||
]
|
||||
assert fake_web.reactions_remove_calls == [
|
||||
{"channel": "D_ORIGIN", "name": "eyes", "timestamp": "1700000000.000100"}
|
||||
]
|
||||
assert fake_web.reactions_add_calls == [
|
||||
{"channel": "D_ORIGIN", "name": "white_check_mark", "timestamp": "1700000000.000100"}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_does_not_reuse_origin_thread_ts_for_cross_channel_send() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
fake_web._conversations_pages = [
|
||||
{
|
||||
"channels": [{"id": "C999", "name": "channel_x"}],
|
||||
"response_metadata": {"next_cursor": ""},
|
||||
}
|
||||
]
|
||||
channel._web_client = fake_web
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="slack",
|
||||
chat_id="channel_x",
|
||||
content="done",
|
||||
metadata={
|
||||
"slack": {
|
||||
"event": {"ts": "1700000000.000100", "channel": "C_ORIGIN"},
|
||||
"thread_ts": "1700000000.000200",
|
||||
"channel_type": "channel",
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert fake_web.chat_post_calls == [
|
||||
{"channel": "C999", "text": "done\n", "thread_ts": None}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_raises_when_named_target_cannot_be_resolved() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
channel._web_client = fake_web
|
||||
|
||||
with pytest.raises(ValueError, match="was not found"):
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="slack",
|
||||
chat_id="#missing-channel",
|
||||
content="hello",
|
||||
)
|
||||
)
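
The Slack tests above rely on a fake `conversations_list` that returns pages with a `response_metadata.next_cursor` field, and they expect a `ValueError` containing "was not found" for an unknown name. A minimal sketch of cursor-based name resolution under those assumptions follows. `FakePagedClient` and `resolve_channel_id` are hypothetical names for illustration; the actual lookup inside `SlackChannel` is not shown in this diff and may differ.

```python
import asyncio


class FakePagedClient:
    """Hypothetical stand-in for a Slack web client with cursor pagination."""

    def __init__(self, pages):
        self._pages = list(pages)
        self.calls = []

    async def conversations_list(self, **kwargs):
        self.calls.append(kwargs)
        if self._pages:
            return self._pages.pop(0)
        return {"channels": [], "response_metadata": {"next_cursor": ""}}


async def resolve_channel_id(client, target: str) -> str:
    """Walk the paginated channel list until the name matches or pages run out."""
    name = target.lstrip("#")
    cursor = ""
    while True:
        # Only pass a cursor once the previous page handed one back.
        kwargs = {"cursor": cursor} if cursor else {}
        page = await client.conversations_list(**kwargs)
        for ch in page.get("channels", []):
            if ch.get("name") == name:
                return ch["id"]
        cursor = page.get("response_metadata", {}).get("next_cursor", "")
        if not cursor:
            raise ValueError(f"Slack channel {target!r} was not found")


pages = [
    {"channels": [{"id": "C1", "name": "general"}],
     "response_metadata": {"next_cursor": "abc"}},
    {"channels": [{"id": "C999", "name": "channel_x"}],
     "response_metadata": {"next_cursor": ""}},
]
```

Stopping on an empty `next_cursor` rather than on an empty `channels` list keeps the loop correct when a page happens to contain no matching entries but more pages remain.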
|
||||
|
||||
@ -387,6 +387,84 @@ async def test_send_delta_stream_end_treats_not_modified_as_success() -> None:
|
||||
assert "123" not in channel._stream_bufs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_stream_end_does_not_fallback_on_network_timeout() -> None:
|
||||
"""TimedOut during HTML edit should propagate, never fall back to plain text."""
|
||||
from telegram.error import TimedOut
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
# _call_with_retry retries TimedOut up to 3 times, so the mock will be called
|
||||
# multiple times – but all calls must be with parse_mode="HTML" (no plain fallback).
|
||||
channel._app.bot.edit_message_text = AsyncMock(side_effect=TimedOut("network timeout"))
|
||||
channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0)
|
||||
|
||||
with pytest.raises(TimedOut, match="network timeout"):
|
||||
await channel.send_delta("123", "", {"_stream_end": True})
|
||||
|
||||
# Every call to edit_message_text must have used parse_mode="HTML" —
|
||||
# no plain-text fallback call should have been made.
|
||||
for call in channel._app.bot.edit_message_text.call_args_list:
|
||||
assert call.kwargs.get("parse_mode") == "HTML"
|
||||
# Buffer should still be present (not cleaned up on error)
|
||||
assert "123" in channel._stream_bufs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_stream_end_does_not_fallback_on_network_error() -> None:
|
||||
"""NetworkError during HTML edit should propagate, never fall back to plain text."""
|
||||
from telegram.error import NetworkError
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
channel._app.bot.edit_message_text = AsyncMock(side_effect=NetworkError("connection reset"))
|
||||
channel._stream_bufs["123"] = _StreamBuf(text="hello", message_id=7, last_edit=0.0)
|
||||
|
||||
with pytest.raises(NetworkError, match="connection reset"):
|
||||
await channel.send_delta("123", "", {"_stream_end": True})
|
||||
|
||||
# Every call to edit_message_text must have used parse_mode="HTML" —
|
||||
# no plain-text fallback call should have been made.
|
||||
for call in channel._app.bot.edit_message_text.call_args_list:
|
||||
assert call.kwargs.get("parse_mode") == "HTML"
|
||||
# Buffer should still be present (not cleaned up on error)
|
||||
assert "123" in channel._stream_bufs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_delta_stream_end_falls_back_on_bad_request() -> None:
    """BadRequest (HTML parse error) should still trigger plain-text fallback."""
    from telegram.error import BadRequest

    channel = TelegramChannel(
        TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
        MessageBus(),
    )
    channel._app = _FakeApp(lambda: None)

    # First call (HTML) raises BadRequest, second call (plain) succeeds
    channel._app.bot.edit_message_text = AsyncMock(
        side_effect=[BadRequest("Can't parse entities"), None]
    )
    channel._stream_bufs["123"] = _StreamBuf(text="hello <bad>", message_id=7, last_edit=0.0)

    await channel.send_delta("123", "", {"_stream_end": True})

    # edit_message_text should have been called twice: once for HTML, once for plain fallback
    assert channel._app.bot.edit_message_text.call_count == 2
    # Second call should not use parse_mode="HTML"
    second_call_kwargs = channel._app.bot.edit_message_text.call_args_list[1].kwargs
    assert "parse_mode" not in second_call_kwargs or second_call_kwargs.get("parse_mode") is None
    # Buffer should be cleaned up on success
    assert "123" not in channel._stream_bufs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_stream_end_splits_oversized_reply() -> None:
|
||||
"""Final streamed reply exceeding Telegram limit is split into chunks."""
|
||||
@ -1159,3 +1237,159 @@ async def test_on_message_location_with_text() -> None:
|
||||
assert len(handled) == 1
|
||||
assert "meet me here" in handled[0]["content"]
|
||||
assert "[location: 51.5074, -0.1278]" in handled[0]["content"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for retry amplification fix (issue #3050)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_text_does_not_fallback_on_network_timeout() -> None:
    """TimedOut should propagate once retries are exhausted, NOT trigger plain-text fallback.

    Before the fix, _send_text caught ALL exceptions (including TimedOut)
    and retried as plain text, doubling connection demand during pool
    exhaustion (see issue #3050).
    """
    from telegram.error import TimedOut

    channel = TelegramChannel(
        TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
        MessageBus(),
    )
    channel._app = _FakeApp(lambda: None)

    call_count = 0

    async def always_timeout(**kwargs):
        nonlocal call_count
        call_count += 1
        raise TimedOut()

    channel._app.bot.send_message = always_timeout

    import nanobot.channels.telegram as tg_mod
    orig_delay = tg_mod._SEND_RETRY_BASE_DELAY
    tg_mod._SEND_RETRY_BASE_DELAY = 0.01
    try:
        with pytest.raises(TimedOut):
            await channel._send_text(123, "hello", None, {})
    finally:
        tg_mod._SEND_RETRY_BASE_DELAY = orig_delay

    # With the fix: only _call_with_retry's 3 HTML attempts (no plain fallback).
    # Before the fix: 3 HTML + 3 plain = 6 attempts.
    assert call_count == 3, (
        f"Expected 3 calls (HTML retries only), got {call_count} "
        "(plain-text fallback should not trigger on TimedOut)"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_text_does_not_fallback_on_network_error() -> None:
|
||||
"""NetworkError should propagate immediately, NOT trigger plain-text fallback."""
|
||||
from telegram.error import NetworkError
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def always_network_error(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise NetworkError("Connection reset")
|
||||
|
||||
channel._app.bot.send_message = always_network_error
|
||||
|
||||
import nanobot.channels.telegram as tg_mod
|
||||
orig_delay = tg_mod._SEND_RETRY_BASE_DELAY
|
||||
tg_mod._SEND_RETRY_BASE_DELAY = 0.01
|
||||
try:
|
||||
with pytest.raises(NetworkError):
|
||||
await channel._send_text(123, "hello", None, {})
|
||||
finally:
|
||||
tg_mod._SEND_RETRY_BASE_DELAY = orig_delay
|
||||
|
||||
# _call_with_retry does NOT retry NetworkError (only TimedOut/RetryAfter),
|
||||
# so it raises after 1 attempt. The fix prevents plain-text fallback.
|
||||
# Before the fix: 1 HTML + 1 plain = 2. After the fix: 1 HTML only.
|
||||
assert call_count == 1, (
|
||||
f"Expected 1 call (HTML only, no plain fallback), got {call_count}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_text_falls_back_on_bad_request() -> None:
|
||||
"""BadRequest (HTML parse error) should still trigger plain-text fallback."""
|
||||
from telegram.error import BadRequest
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
original_send = channel._app.bot.send_message
|
||||
html_call_count = 0
|
||||
|
||||
async def html_fails(**kwargs):
|
||||
nonlocal html_call_count
|
||||
if kwargs.get("parse_mode") == "HTML":
|
||||
html_call_count += 1
|
||||
raise BadRequest("Can't parse entities")
|
||||
return await original_send(**kwargs)
|
||||
|
||||
channel._app.bot.send_message = html_fails
|
||||
|
||||
import nanobot.channels.telegram as tg_mod
|
||||
orig_delay = tg_mod._SEND_RETRY_BASE_DELAY
|
||||
tg_mod._SEND_RETRY_BASE_DELAY = 0.01
|
||||
try:
|
||||
await channel._send_text(123, "hello **world**", None, {})
|
||||
finally:
|
||||
tg_mod._SEND_RETRY_BASE_DELAY = orig_delay
|
||||
|
||||
# HTML attempt failed with BadRequest → fallback to plain text succeeds.
|
||||
assert html_call_count == 1, f"Expected 1 HTML attempt, got {html_call_count}"
|
||||
assert len(channel._app.bot.sent_messages) == 1
|
||||
# Plain text send should NOT have parse_mode
|
||||
assert channel._app.bot.sent_messages[0].get("parse_mode") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_text_bad_request_plain_fallback_exhausted() -> None:
|
||||
"""When both HTML and plain-text fallback fail with BadRequest, the error propagates."""
|
||||
from telegram.error import BadRequest
|
||||
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def always_bad_request(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise BadRequest("Bad request")
|
||||
|
||||
channel._app.bot.send_message = always_bad_request
|
||||
|
||||
import nanobot.channels.telegram as tg_mod
|
||||
orig_delay = tg_mod._SEND_RETRY_BASE_DELAY
|
||||
tg_mod._SEND_RETRY_BASE_DELAY = 0.01
|
||||
try:
|
||||
with pytest.raises(BadRequest):
|
||||
await channel._send_text(123, "hello", None, {})
|
||||
finally:
|
||||
tg_mod._SEND_RETRY_BASE_DELAY = orig_delay
|
||||
|
||||
# _call_with_retry does NOT retry BadRequest (only TimedOut/RetryAfter),
|
||||
# so HTML fails after 1 attempt → fallback to plain also fails after 1 attempt.
|
||||
# Before the fix: 2 total. After the fix: still 2 (BadRequest SHOULD fallback).
|
||||
assert call_count == 2, f"Expected 2 calls (1 HTML + 1 plain), got {call_count}"
|
||||
|
||||
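The Telegram tests above pin down a fallback policy: `TimedOut` is retried but never downgraded to plain text, `NetworkError` propagates after a single attempt, and only `BadRequest` triggers the plain-text resend. The sketch below illustrates that policy in isolation. It is not nanobot's `_call_with_retry` or `_send_text`: the exception classes are stand-ins for the `telegram.error` types, and `call_with_retry`, `send_text`, and the linear backoff are hypothetical names and choices made for illustration.

```python
import asyncio


# Stand-in exception types (assumption); the real tests import telegram.error classes.
class TimedOut(Exception):
    pass


class NetworkError(Exception):
    pass


class BadRequest(Exception):
    pass


async def call_with_retry(fn, *, attempts=3, base_delay=0.0, **kwargs):
    """Retry only on TimedOut; every other exception propagates at once."""
    for attempt in range(1, attempts + 1):
        try:
            return await fn(**kwargs)
        except TimedOut:
            if attempt == attempts:
                raise
            # Assumed backoff shape; the real delay schedule is not shown in the diff.
            await asyncio.sleep(base_delay * attempt)


async def send_text(send, text):
    """Try HTML first; fall back to plain text only when the HTML is rejected."""
    try:
        return await call_with_retry(send, text=text, parse_mode="HTML")
    except BadRequest:
        # A parse error is a content problem, so a plain-text resend can help.
        # Timeouts and network errors are deliberately not caught here:
        # resending would only add load to an already struggling connection.
        return await call_with_retry(send, text=text)
```

Catching only the content-related error keeps the worst case at three attempts for a timeout, instead of the six the tests describe for the pre-fix behaviour.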
598
tests/channels/test_websocket_channel.py
Normal file
@ -0,0 +1,598 @@
|
||||
"""Unit and lightweight integration tests for the WebSocket channel."""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import websockets
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
from websockets.frames import Close
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.channels.websocket import (
|
||||
WebSocketChannel,
|
||||
WebSocketConfig,
|
||||
_issue_route_secret_matches,
|
||||
_normalize_config_path,
|
||||
_normalize_http_path,
|
||||
_parse_inbound_payload,
|
||||
_parse_query,
|
||||
_parse_request_path,
|
||||
)
|
||||
|
||||
# -- Shared helpers (aligned with test_websocket_integration.py) ---------------
|
||||
|
||||
_PORT = 29876
|
||||
|
||||
|
||||
def _ch(bus: Any, **kw: Any) -> WebSocketChannel:
    cfg: dict[str, Any] = {
        "enabled": True,
        "allowFrom": ["*"],
        "host": "127.0.0.1",
        "port": _PORT,
        "path": "/ws",
        "websocketRequiresToken": False,
    }
    cfg.update(kw)
    return WebSocketChannel(cfg, bus)


@pytest.fixture()
def bus() -> MagicMock:
    b = MagicMock()
    b.publish_inbound = AsyncMock()
    return b


async def _http_get(url: str, headers: dict[str, str] | None = None) -> httpx.Response:
    """Run GET in a thread to avoid blocking the asyncio loop shared with websockets."""
    return await asyncio.to_thread(
        functools.partial(httpx.get, url, headers=headers or {}, timeout=5.0)
    )
|
||||
|
||||
|
||||
def test_normalize_http_path_strips_trailing_slash_except_root() -> None:
|
||||
assert _normalize_http_path("/chat/") == "/chat"
|
||||
assert _normalize_http_path("/chat?x=1") == "/chat"
|
||||
assert _normalize_http_path("/") == "/"
|
||||
|
||||
|
||||
def test_parse_request_path_matches_normalize_and_query() -> None:
|
||||
path, query = _parse_request_path("/ws/?token=secret&client_id=u1")
|
||||
assert path == _normalize_http_path("/ws/?token=secret&client_id=u1")
|
||||
assert query == _parse_query("/ws/?token=secret&client_id=u1")
|
||||
|
||||
|
||||
def test_normalize_config_path_matches_request() -> None:
|
||||
assert _normalize_config_path("/ws/") == "/ws"
|
||||
assert _normalize_config_path("/") == "/"
|
||||
|
||||
|
||||
def test_parse_query_extracts_token_and_client_id() -> None:
|
||||
query = _parse_query("/?token=secret&client_id=u1")
|
||||
assert query.get("token") == ["secret"]
|
||||
assert query.get("client_id") == ["u1"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("raw", "expected"),
    [
        ("plain", "plain"),
        ('{"content": "hi"}', "hi"),
        ('{"text": "there"}', "there"),
        ('{"message": "x"}', "x"),
        (" ", None),
        ("{}", None),
    ],
)
def test_parse_inbound_payload(raw: str, expected: str | None) -> None:
    assert _parse_inbound_payload(raw) == expected
|
||||
|
||||
|
||||
def test_parse_inbound_invalid_json_falls_back_to_raw_string() -> None:
|
||||
assert _parse_inbound_payload("{not json") == "{not json"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("raw", "expected"),
|
||||
[
|
||||
('{"content": ""}', None), # empty string content
|
||||
('{"content": 123}', None), # non-string content
|
||||
('{"content": " "}', None), # whitespace-only content
|
||||
('["hello"]', '["hello"]'), # JSON array: not a dict, treated as plain text
|
||||
('{"unknown_key": "val"}', None), # unrecognized key
|
||||
('{"content": null}', None), # null content
|
||||
],
|
||||
)
|
||||
def test_parse_inbound_payload_edge_cases(raw: str, expected: str | None) -> None:
|
||||
assert _parse_inbound_payload(raw) == expected
|
||||
|
||||
|
||||
def test_web_socket_config_path_must_start_with_slash() -> None:
|
||||
with pytest.raises(ValueError, match='path must start with "/"'):
|
||||
WebSocketConfig(path="bad")
|
||||
|
||||
|
||||
def test_ssl_context_requires_both_cert_and_key_files() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel(
|
||||
{"enabled": True, "allowFrom": ["*"], "sslCertfile": "/tmp/c.pem", "sslKeyfile": ""},
|
||||
bus,
|
||||
)
|
||||
with pytest.raises(ValueError, match="ssl_certfile and ssl_keyfile"):
|
||||
channel._build_ssl_context()
|
||||
|
||||
|
||||
def test_default_config_includes_safe_bind_and_streaming() -> None:
|
||||
defaults = WebSocketChannel.default_config()
|
||||
assert defaults["enabled"] is False
|
||||
assert defaults["host"] == "127.0.0.1"
|
||||
assert defaults["streaming"] is True
|
||||
assert defaults["allowFrom"] == ["*"]
|
||||
assert defaults.get("tokenIssuePath", "") == ""
|
||||
|
||||
|
||||
def test_token_issue_path_must_differ_from_websocket_path() -> None:
|
||||
with pytest.raises(ValueError, match="token_issue_path must differ"):
|
||||
WebSocketConfig(path="/ws", token_issue_path="/ws")
|
||||
|
||||
|
||||
def test_issue_route_secret_matches_bearer_and_header() -> None:
|
||||
from websockets.datastructures import Headers
|
||||
|
||||
secret = "my-secret"
|
||||
bearer_headers = Headers([("Authorization", "Bearer my-secret")])
|
||||
assert _issue_route_secret_matches(bearer_headers, secret) is True
|
||||
x_headers = Headers([("X-Nanobot-Auth", "my-secret")])
|
||||
assert _issue_route_secret_matches(x_headers, secret) is True
|
||||
wrong = Headers([("Authorization", "Bearer other")])
|
||||
assert _issue_route_secret_matches(wrong, secret) is False
|
||||
|
||||
|
||||
def test_issue_route_secret_matches_empty_secret() -> None:
|
||||
from websockets.datastructures import Headers
|
||||
|
||||
# Empty secret always returns True regardless of headers
|
||||
assert _issue_route_secret_matches(Headers([]), "") is True
|
||||
assert _issue_route_secret_matches(Headers([("Authorization", "Bearer anything")]), "") is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delivers_json_message_with_media_and_reply() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
|
||||
mock_ws = AsyncMock()
|
||||
channel._connections["chat-1"] = mock_ws
|
||||
|
||||
msg = OutboundMessage(
|
||||
channel="websocket",
|
||||
chat_id="chat-1",
|
||||
content="hello",
|
||||
reply_to="m1",
|
||||
media=["/tmp/a.png"],
|
||||
)
|
||||
await channel.send(msg)
|
||||
|
||||
mock_ws.send.assert_awaited_once()
|
||||
payload = json.loads(mock_ws.send.call_args[0][0])
|
||||
assert payload["event"] == "message"
|
||||
assert payload["text"] == "hello"
|
||||
assert payload["reply_to"] == "m1"
|
||||
assert payload["media"] == ["/tmp/a.png"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_missing_connection_is_noop_without_error() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
|
||||
msg = OutboundMessage(channel="websocket", chat_id="missing", content="x")
|
||||
await channel.send(msg)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_removes_connection_on_connection_closed() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True)
|
||||
channel._connections["chat-1"] = mock_ws
|
||||
|
||||
msg = OutboundMessage(channel="websocket", chat_id="chat-1", content="hello")
|
||||
await channel.send(msg)
|
||||
|
||||
assert "chat-1" not in channel._connections
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_removes_connection_on_connection_closed() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus)
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.send.side_effect = ConnectionClosed(Close(1006, ""), Close(1006, ""), True)
|
||||
channel._connections["chat-1"] = mock_ws
|
||||
|
||||
await channel.send_delta("chat-1", "chunk", {"_stream_delta": True, "_stream_id": "s1"})
|
||||
|
||||
assert "chat-1" not in channel._connections
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_emits_delta_and_stream_end() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus)
|
||||
mock_ws = AsyncMock()
|
||||
channel._connections["chat-1"] = mock_ws
|
||||
|
||||
await channel.send_delta("chat-1", "part", {"_stream_delta": True, "_stream_id": "sid"})
|
||||
await channel.send_delta("chat-1", "", {"_stream_end": True, "_stream_id": "sid"})
|
||||
|
||||
assert mock_ws.send.await_count == 2
|
||||
first = json.loads(mock_ws.send.call_args_list[0][0][0])
|
||||
second = json.loads(mock_ws.send.call_args_list[1][0][0])
|
||||
assert first["event"] == "delta"
|
||||
assert first["text"] == "part"
|
||||
assert first["stream_id"] == "sid"
|
||||
assert second["event"] == "stream_end"
|
||||
assert second["stream_id"] == "sid"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_non_connection_closed_exception_is_raised() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
|
||||
mock_ws = AsyncMock()
|
||||
mock_ws.send.side_effect = RuntimeError("unexpected")
|
||||
channel._connections["chat-1"] = mock_ws
|
||||
|
||||
msg = OutboundMessage(channel="websocket", chat_id="chat-1", content="hello")
|
||||
with pytest.raises(RuntimeError, match="unexpected"):
|
||||
await channel.send(msg)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_delta_missing_connection_is_noop() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"], "streaming": True}, bus)
|
||||
# No exception, no error — just a no-op
|
||||
await channel.send_delta("nonexistent", "chunk", {"_stream_delta": True, "_stream_id": "s1"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_is_idempotent() -> None:
|
||||
bus = MagicMock()
|
||||
channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus)
|
||||
# stop() before start() should not raise
|
||||
await channel.stop()
|
||||
await channel.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end_client_receives_ready_and_agent_sees_inbound(bus: MagicMock) -> None:
|
||||
port = 29876
|
||||
channel = _ch(bus, port=port)
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=tester") as client:
|
||||
ready_raw = await client.recv()
|
||||
ready = json.loads(ready_raw)
|
||||
assert ready["event"] == "ready"
|
||||
assert ready["client_id"] == "tester"
|
||||
chat_id = ready["chat_id"]
|
||||
|
||||
await client.send(json.dumps({"content": "ping from client"}))
|
||||
await asyncio.sleep(0.08)
|
||||
|
||||
bus.publish_inbound.assert_awaited()
|
||||
inbound = bus.publish_inbound.call_args[0][0]
|
||||
assert inbound.channel == "websocket"
|
||||
assert inbound.sender_id == "tester"
|
||||
assert inbound.chat_id == chat_id
|
||||
assert inbound.content == "ping from client"
|
||||
|
||||
await client.send("plain text frame")
|
||||
await asyncio.sleep(0.08)
|
||||
assert bus.publish_inbound.await_count >= 2
|
||||
second = [c[0][0] for c in bus.publish_inbound.call_args_list][-1]
|
||||
assert second.content == "plain text frame"
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_rejects_handshake_when_mismatch(bus: MagicMock) -> None:
|
||||
port = 29877
|
||||
channel = _ch(bus, port=port, path="/", token="secret")
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as excinfo:
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/?token=wrong"):
|
||||
pass
|
||||
assert excinfo.value.response.status_code == 401
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wrong_path_returns_404(bus: MagicMock) -> None:
|
||||
port = 29878
|
||||
channel = _ch(bus, port=port)
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as excinfo:
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/other"):
|
||||
pass
|
||||
assert excinfo.value.response.status_code == 404
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
def test_registry_discovers_websocket_channel() -> None:
|
||||
from nanobot.channels.registry import load_channel_class
|
||||
|
||||
cls = load_channel_class("websocket")
|
||||
assert cls.name == "websocket"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_route_issues_token_then_websocket_requires_it(bus: MagicMock) -> None:
|
||||
port = 29879
|
||||
channel = _ch(
|
||||
bus, port=port,
|
||||
tokenIssuePath="/auth/token",
|
||||
tokenIssueSecret="route-secret",
|
||||
websocketRequiresToken=True,
|
||||
)
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
deny = await _http_get(f"http://127.0.0.1:{port}/auth/token")
|
||||
assert deny.status_code == 401
|
||||
|
||||
issue = await _http_get(
|
||||
f"http://127.0.0.1:{port}/auth/token",
|
||||
headers={"Authorization": "Bearer route-secret"},
|
||||
)
|
||||
assert issue.status_code == 200
|
||||
token = issue.json()["token"]
|
||||
assert token.startswith("nbwt_")
|
||||
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as missing_token:
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=x"):
|
||||
pass
|
||||
assert missing_token.value.response.status_code == 401
|
||||
|
||||
uri = f"ws://127.0.0.1:{port}/ws?token={token}&client_id=caller"
|
||||
async with websockets.connect(uri) as client:
|
||||
ready = json.loads(await client.recv())
|
||||
assert ready["event"] == "ready"
|
||||
assert ready["client_id"] == "caller"
|
||||
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as reuse:
|
||||
async with websockets.connect(uri):
|
||||
pass
|
||||
assert reuse.value.response.status_code == 401
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end_server_pushes_streaming_deltas_to_client(bus: MagicMock) -> None:
|
||||
port = 29880
|
||||
channel = _ch(bus, port=port, streaming=True)
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=stream-tester") as client:
|
||||
ready_raw = await client.recv()
|
||||
ready = json.loads(ready_raw)
|
||||
chat_id = ready["chat_id"]
|
||||
|
||||
# Server pushes deltas directly
|
||||
await channel.send_delta(
|
||||
chat_id, "Hello ", {"_stream_delta": True, "_stream_id": "s1"}
|
||||
)
|
||||
await channel.send_delta(
|
||||
chat_id, "world", {"_stream_delta": True, "_stream_id": "s1"}
|
||||
)
|
||||
await channel.send_delta(
|
||||
chat_id, "", {"_stream_end": True, "_stream_id": "s1"}
|
||||
)
|
||||
|
||||
delta1 = json.loads(await client.recv())
|
||||
assert delta1["event"] == "delta"
|
||||
assert delta1["text"] == "Hello "
|
||||
assert delta1["stream_id"] == "s1"
|
||||
|
||||
delta2 = json.loads(await client.recv())
|
||||
assert delta2["event"] == "delta"
|
||||
assert delta2["text"] == "world"
|
||||
assert delta2["stream_id"] == "s1"
|
||||
|
||||
end = json.loads(await client.recv())
|
||||
assert end["event"] == "stream_end"
|
||||
assert end["stream_id"] == "s1"
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_issue_rejects_when_at_capacity(bus: MagicMock) -> None:
|
||||
port = 29881
|
||||
channel = _ch(bus, port=port, tokenIssuePath="/auth/token", tokenIssueSecret="s")
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
# Fill issued tokens to capacity
|
||||
channel._issued_tokens = {
|
||||
f"nbwt_fill_{i}": time.monotonic() + 300 for i in range(channel._MAX_ISSUED_TOKENS)
|
||||
}
|
||||
|
||||
resp = await _http_get(
|
||||
f"http://127.0.0.1:{port}/auth/token",
|
||||
headers={"Authorization": "Bearer s"},
|
||||
)
|
||||
assert resp.status_code == 429
|
||||
data = resp.json()
|
||||
assert "error" in data
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allow_from_rejects_unauthorized_client_id(bus: MagicMock) -> None:
|
||||
port = 29882
|
||||
channel = _ch(bus, port=port, allowFrom=["alice", "bob"])
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=eve"):
|
||||
pass
|
||||
assert exc_info.value.response.status_code == 403
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_client_id_truncation(bus: MagicMock) -> None:
|
||||
port = 29883
|
||||
channel = _ch(bus, port=port)
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
long_id = "x" * 200
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id={long_id}") as client:
|
||||
ready = json.loads(await client.recv())
|
||||
assert ready["client_id"] == "x" * 128
|
||||
assert len(ready["client_id"]) == 128
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_utf8_binary_frame_ignored(bus: MagicMock) -> None:
|
||||
port = 29884
|
||||
channel = _ch(bus, port=port)
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=bin-test") as client:
|
||||
await client.recv() # consume ready
|
||||
# Send non-UTF-8 bytes
|
||||
await client.send(b"\xff\xfe\xfd")
|
||||
await asyncio.sleep(0.05)
|
||||
# publish_inbound should NOT have been called
|
||||
bus.publish_inbound.assert_not_awaited()
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_static_token_accepts_issued_token_as_fallback(bus: MagicMock) -> None:
|
||||
port = 29885
|
||||
channel = _ch(
|
||||
bus, port=port,
|
||||
token="static-secret",
|
||||
tokenIssuePath="/auth/token",
|
||||
tokenIssueSecret="route-secret",
|
||||
)
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
# Get an issued token
|
||||
resp = await _http_get(
|
||||
f"http://127.0.0.1:{port}/auth/token",
|
||||
headers={"Authorization": "Bearer route-secret"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
issued_token = resp.json()["token"]
|
||||
|
||||
# Connect using issued token (not the static one)
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?token={issued_token}&client_id=caller") as client:
|
||||
ready = json.loads(await client.recv())
|
||||
assert ready["event"] == "ready"
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allow_from_empty_list_denies_all(bus: MagicMock) -> None:
|
||||
port = 29886
|
||||
channel = _ch(bus, port=port, allowFrom=[])
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=anyone"):
|
||||
pass
|
||||
assert exc_info.value.response.status_code == 403
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_requires_token_without_issue_path(bus: MagicMock) -> None:
|
||||
"""When websocket_requires_token is True but no token or issue path configured, all connections are rejected."""
|
||||
port = 29887
|
||||
channel = _ch(bus, port=port, websocketRequiresToken=True)
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
# No token at all → 401
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=u"):
|
||||
pass
|
||||
assert exc_info.value.response.status_code == 401
|
||||
|
||||
# Wrong token → 401
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as exc_info:
|
||||
async with websockets.connect(f"ws://127.0.0.1:{port}/ws?client_id=u&token=wrong"):
|
||||
pass
|
||||
assert exc_info.value.response.status_code == 401
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
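The parametrized `_parse_inbound_payload` cases in `tests/channels/test_websocket_channel.py` describe how an inbound frame is reduced to message text: plain strings pass through, JSON objects are searched for `content`, `text`, or `message`, and empty or non-string values yield `None`. The function below is a sketch that satisfies those listed cases. It is not the implementation in `nanobot/channels/websocket.py`; the name `parse_inbound_payload`, the key order, and the choice to return the stripped frame for plain text are assumptions.

```python
import json
from typing import Optional


def parse_inbound_payload(raw: str) -> Optional[str]:
    """Return the message text carried by a frame, or None if there is none."""
    text = raw.strip()
    if not text:
        return None  # whitespace-only frames carry no message
    if text.startswith("{"):
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            return text  # looks like JSON but is not: treat as plain text
        if isinstance(data, dict):
            # Assumed lookup order; the tests only exercise one key at a time.
            for key in ("content", "text", "message"):
                value = data.get(key)
                if isinstance(value, str) and value.strip():
                    return value
            return None  # a JSON object with no usable text field
    # Anything else, including JSON arrays, is passed through as plain text.
    return text
```

Treating malformed JSON as plain text means a client that sends a literal `{` still reaches the agent rather than being dropped.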
478
tests/channels/test_websocket_integration.py
Normal file
@ -0,0 +1,478 @@
|
||||
"""Integration tests for the WebSocket channel using WsTestClient.
|
||||
|
||||
Complements the unit/lightweight tests in test_websocket_channel.py by covering
|
||||
multi-client scenarios, edge cases, and realistic usage patterns.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import websockets
|
||||
|
||||
from nanobot.channels.websocket import WebSocketChannel
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from ws_test_client import WsTestClient, issue_token, issue_token_ok
|
||||
|
||||
|
||||
def _ch(bus: Any, port: int, **kw: Any) -> WebSocketChannel:
|
||||
cfg: dict[str, Any] = {
|
||||
"enabled": True,
|
||||
"allowFrom": ["*"],
|
||||
"host": "127.0.0.1",
|
||||
"port": port,
|
||||
"path": "/",
|
||||
"websocketRequiresToken": False,
|
||||
}
|
||||
cfg.update(kw)
|
||||
return WebSocketChannel(cfg, bus)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def bus() -> MagicMock:
|
||||
b = MagicMock()
|
||||
b.publish_inbound = AsyncMock()
|
||||
return b
|
||||
|
||||
|
||||
# -- Connection basics ----------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ready_event_fields(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29901)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29901/", client_id="c1") as c:
|
||||
r = await c.recv_ready()
|
||||
assert r.event == "ready"
|
||||
assert len(r.chat_id) == 36  # 36 characters is the length of a canonical UUID string
|
||||
assert r.client_id == "c1"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anonymous_client_gets_generated_id(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29902)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29902/", client_id="") as c:
|
||||
r = await c.recv_ready()
|
||||
assert r.client_id.startswith("anon-")
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_each_connection_unique_chat_id(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29903)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29903/", client_id="a") as c1:
|
||||
async with WsTestClient("ws://127.0.0.1:29903/", client_id="b") as c2:
|
||||
assert (await c1.recv_ready()).chat_id != (await c2.recv_ready()).chat_id
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
# -- Inbound messages (client -> server) ----------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plain_text(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29904)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29904/", client_id="p") as c:
|
||||
await c.recv_ready()
|
||||
await c.send_text("hello world")
|
||||
await asyncio.sleep(0.1)
|
||||
inbound = bus.publish_inbound.call_args[0][0]
|
||||
assert inbound.content == "hello world"
|
||||
assert inbound.sender_id == "p"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_json_content_field(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29905)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29905/", client_id="j") as c:
|
||||
await c.recv_ready()
|
||||
await c.send_json({"content": "structured"})
|
||||
await asyncio.sleep(0.1)
|
||||
assert bus.publish_inbound.call_args[0][0].content == "structured"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_json_text_and_message_fields(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29906)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29906/", client_id="x") as c:
|
||||
await c.recv_ready()
|
||||
await c.send_json({"text": "via text"})
|
||||
await asyncio.sleep(0.1)
|
||||
assert bus.publish_inbound.call_args[0][0].content == "via text"
|
||||
await c.send_json({"message": "via message"})
|
||||
await asyncio.sleep(0.1)
|
||||
assert bus.publish_inbound.call_args[0][0].content == "via message"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_payload_ignored(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29907)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29907/", client_id="e") as c:
|
||||
await c.recv_ready()
|
||||
await c.send_text(" ")
|
||||
await c.send_json({})
|
||||
await asyncio.sleep(0.1)
|
||||
bus.publish_inbound.assert_not_awaited()
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_messages_preserve_order(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29908)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29908/", client_id="o") as c:
|
||||
await c.recv_ready()
|
||||
for i in range(5):
|
||||
await c.send_text(f"msg-{i}")
|
||||
await asyncio.sleep(0.2)
|
||||
contents = [call[0][0].content for call in bus.publish_inbound.call_args_list]
|
||||
assert contents == [f"msg-{i}" for i in range(5)]
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
# -- Outbound messages (server -> client) ---------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_server_send_message(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29909)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29909/", client_id="r") as c:
|
||||
ready = await c.recv_ready()
|
||||
await ch.send(OutboundMessage(
|
||||
channel="websocket", chat_id=ready.chat_id, content="reply",
|
||||
))
|
||||
msg = await c.recv_message()
|
||||
assert msg.text == "reply"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_server_send_with_media_and_reply(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29910)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29910/", client_id="m") as c:
|
||||
ready = await c.recv_ready()
|
||||
await ch.send(OutboundMessage(
|
||||
channel="websocket", chat_id=ready.chat_id, content="img",
|
||||
media=["/tmp/a.png"], reply_to="m1",
|
||||
))
|
||||
msg = await c.recv_message()
|
||||
assert msg.text == "img"
|
||||
assert msg.media == ["/tmp/a.png"]
|
||||
assert msg.reply_to == "m1"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
# -- Streaming ------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_deltas_and_end(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29911, streaming=True)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29911/", client_id="s") as c:
|
||||
cid = (await c.recv_ready()).chat_id
|
||||
for part in ("Hello", " ", "world", "!"):
|
||||
await ch.send_delta(cid, part, {"_stream_delta": True, "_stream_id": "s1"})
|
||||
await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "s1"})
|
||||
|
||||
msgs = await c.collect_stream()
|
||||
deltas = [m for m in msgs if m.event == "delta"]
|
||||
assert "".join(d.text for d in deltas) == "Hello world!"
|
||||
ends = [m for m in msgs if m.event == "stream_end"]
|
||||
assert len(ends) == 1
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_interleaved_streams(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29912, streaming=True)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29912/", client_id="i") as c:
|
||||
cid = (await c.recv_ready()).chat_id
|
||||
await ch.send_delta(cid, "A1", {"_stream_delta": True, "_stream_id": "sa"})
|
||||
await ch.send_delta(cid, "B1", {"_stream_delta": True, "_stream_id": "sb"})
|
||||
await ch.send_delta(cid, "A2", {"_stream_delta": True, "_stream_id": "sa"})
|
||||
await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "sa"})
|
||||
await ch.send_delta(cid, "B2", {"_stream_delta": True, "_stream_id": "sb"})
|
||||
await ch.send_delta(cid, "", {"_stream_end": True, "_stream_id": "sb"})
|
||||
|
||||
msgs = await c.recv_n(6)
|
||||
sa = "".join(m.text for m in msgs if m.event == "delta" and m.stream_id == "sa")
|
||||
sb = "".join(m.text for m in msgs if m.event == "delta" and m.stream_id == "sb")
|
||||
assert sa == "A1A2"
|
||||
assert sb == "B1B2"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
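The two streaming tests above depend on a client being able to regroup `delta` frames by `stream_id` even when streams interleave. A minimal sketch of that bookkeeping follows. It assumes each frame is a plain dict with `event`, `text`, and `stream_id` keys (mirroring the attributes the tests read from `WsTestClient` messages); `StreamAssembler` is a hypothetical name and is not part of nanobot.

```python
from __future__ import annotations

from collections import defaultdict


class StreamAssembler:
    """Accumulate delta text per stream until its stream_end frame arrives."""

    def __init__(self) -> None:
        # Streams that have received deltas but no stream_end yet.
        self._open: dict[str, list[str]] = defaultdict(list)
        # Completed streams, keyed by stream id.
        self.finished: dict[str, str] = {}

    def feed(self, frame: dict) -> None:
        stream_id = frame.get("stream_id", "")
        event = frame.get("event")
        if event == "delta":
            self._open[stream_id].append(frame.get("text", ""))
        elif event == "stream_end":
            self.finished[stream_id] = "".join(self._open.pop(stream_id, []))


# Same interleaving as test_interleaved_streams: two streams share one socket.
frames = [
    {"event": "delta", "stream_id": "sa", "text": "A1"},
    {"event": "delta", "stream_id": "sb", "text": "B1"},
    {"event": "delta", "stream_id": "sa", "text": "A2"},
    {"event": "stream_end", "stream_id": "sa"},
    {"event": "delta", "stream_id": "sb", "text": "B2"},
    {"event": "stream_end", "stream_id": "sb"},
]
assembler = StreamAssembler()
for frame in frames:
    assembler.feed(frame)
```

Keying the buffers by stream id is what lets stream `sa` finish before `sb` without losing the text already buffered for `sb`.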
|
||||
|
||||
# -- Multi-client ---------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_independent_sessions(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29913)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29913/", client_id="u1") as c1:
|
||||
async with WsTestClient("ws://127.0.0.1:29913/", client_id="u2") as c2:
|
||||
r1, r2 = await c1.recv_ready(), await c2.recv_ready()
|
||||
await ch.send(OutboundMessage(
|
||||
channel="websocket", chat_id=r1.chat_id, content="for-u1",
|
||||
))
|
||||
assert (await c1.recv_message()).text == "for-u1"
|
||||
await ch.send(OutboundMessage(
|
||||
channel="websocket", chat_id=r2.chat_id, content="for-u2",
|
||||
))
|
||||
assert (await c2.recv_message()).text == "for-u2"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnected_client_cleanup(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29914)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29914/", client_id="tmp") as c:
|
||||
chat_id = (await c.recv_ready()).chat_id
|
||||
# The client context has exited, so the socket is closed; send() should prune the stale entry.
|
||||
await ch.send(OutboundMessage(
|
||||
channel="websocket", chat_id=chat_id, content="orphan",
|
||||
))
|
||||
assert chat_id not in ch._connections
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
# -- Authentication -------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_static_token_accepted(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29915, token="secret")
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29915/", client_id="a", token="secret") as c:
|
||||
assert (await c.recv_ready()).client_id == "a"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_static_token_rejected(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29916, token="correct")
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
|
||||
async with WsTestClient("ws://127.0.0.1:29916/", client_id="b", token="wrong"):
|
||||
pass
|
||||
assert exc.value.response.status_code == 401
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_issue_full_flow(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29917, path="/ws",
|
||||
tokenIssuePath="/auth/token", tokenIssueSecret="s",
|
||||
websocketRequiresToken=True)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
# no secret -> 401
|
||||
_, status = await issue_token(port=29917, issue_path="/auth/token")
|
||||
assert status == 401
|
||||
|
||||
# with secret -> token
|
||||
token = await issue_token_ok(port=29917, issue_path="/auth/token", secret="s")
|
||||
|
||||
# no token -> 401
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
|
||||
async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="x"):
|
||||
pass
|
||||
assert exc.value.response.status_code == 401
|
||||
|
||||
# valid token -> ok
|
||||
async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="ok", token=token) as c:
|
||||
assert (await c.recv_ready()).client_id == "ok"
|
||||
|
||||
# reuse -> 401
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
|
||||
async with WsTestClient("ws://127.0.0.1:29917/ws", client_id="r", token=token):
|
||||
pass
|
||||
assert exc.value.response.status_code == 401
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
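The full-flow test above expects an issued token to work for exactly one handshake and to be rejected on reuse. One way such single-use behaviour could be modelled is sketched below. This is an illustration only: `OneTimeTokens`, its `ttl_seconds` default, and the in-memory dict are assumptions, not the channel's actual implementation.

```python
from __future__ import annotations

import secrets
import time


class OneTimeTokens:
    """Hypothetical in-memory store for tokens that can be redeemed once."""

    def __init__(self, ttl_seconds: float = 60.0) -> None:
        self._ttl = ttl_seconds
        self._expiry: dict[str, float] = {}

    def issue(self) -> str:
        token = secrets.token_urlsafe(32)
        self._expiry[token] = time.monotonic() + self._ttl
        return token

    def consume(self, token: str) -> bool:
        # pop() removes the token, so a second call with the same value fails.
        expires_at = self._expiry.pop(token, None)
        return expires_at is not None and time.monotonic() <= expires_at


store = OneTimeTokens()
token = store.issue()
first_use = store.consume(token)    # valid token, first handshake
second_use = store.consume(token)   # same token again
unknown = store.consume("wrong")    # never issued
```

Removing the token at the moment it is checked keeps the accept and invalidate steps in one operation, which is the property the "reuse -> 401" step of the test exercises.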
|
||||
|
||||
# -- Path routing ---------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_path(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29918, path="/my-chat")
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29918/my-chat", client_id="p") as c:
|
||||
assert (await c.recv_ready()).event == "ready"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wrong_path_404(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29919, path="/ws")
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
with pytest.raises(websockets.exceptions.InvalidStatus) as exc:
|
||||
async with WsTestClient("ws://127.0.0.1:29919/wrong", client_id="x"):
|
||||
pass
|
||||
assert exc.value.response.status_code == 404
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trailing_slash_normalized(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29920, path="/ws")
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29920/ws/", client_id="s") as c:
|
||||
assert (await c.recv_ready()).event == "ready"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
# -- Edge cases -----------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_message(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29921)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29921/", client_id="big") as c:
|
||||
await c.recv_ready()
|
||||
big = "x" * 100_000
|
||||
await c.send_text(big)
|
||||
await asyncio.sleep(0.2)
|
||||
assert bus.publish_inbound.call_args[0][0].content == big
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unicode_roundtrip(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29922)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29922/", client_id="u") as c:
|
||||
ready = await c.recv_ready()
|
||||
text = "你好世界 🌍 日本語テスト"
|
||||
await c.send_text(text)
|
||||
await asyncio.sleep(0.1)
|
||||
assert bus.publish_inbound.call_args[0][0].content == text
|
||||
await ch.send(OutboundMessage(
|
||||
channel="websocket", chat_id=ready.chat_id, content=text,
|
||||
))
|
||||
assert (await c.recv_message()).text == text
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rapid_fire(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29923)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29923/", client_id="r") as c:
|
||||
ready = await c.recv_ready()
|
||||
for i in range(50):
|
||||
await c.send_text(f"in-{i}")
|
||||
await asyncio.sleep(0.5)
|
||||
assert bus.publish_inbound.await_count == 50
|
||||
for i in range(50):
|
||||
await ch.send(OutboundMessage(
|
||||
channel="websocket", chat_id=ready.chat_id, content=f"out-{i}",
|
||||
))
|
||||
received = [(await c.recv_message()).text for _ in range(50)]
|
||||
assert received == [f"out-{i}" for i in range(50)]
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_json_as_plain_text(bus: MagicMock) -> None:
|
||||
ch = _ch(bus, 29924)
|
||||
t = asyncio.create_task(ch.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
async with WsTestClient("ws://127.0.0.1:29924/", client_id="j") as c:
|
||||
await c.recv_ready()
|
||||
await c.send_text("{broken json")
|
||||
await asyncio.sleep(0.1)
|
||||
assert bus.publish_inbound.call_args[0][0].content == "{broken json"
|
||||
finally:
|
||||
await ch.stop(); await t
|
||||
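Every test in this file waits a fixed `asyncio.sleep(0.3)` for the server to come up. A polling helper is one alternative that returns as soon as the port accepts connections. The sketch below is an assumption-laden illustration: `wait_for_port` is a hypothetical helper, not part of nanobot or `ws_test_client`, and it only checks that a TCP connect succeeds on loopback, not that the WebSocket handshake works.

```python
from __future__ import annotations

import asyncio


async def wait_for_port(host: str, port: int, timeout: float = 2.0) -> bool:
    """Return True once a TCP connect to host:port succeeds, False on timeout."""
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while loop.time() < deadline:
        try:
            _, writer = await asyncio.open_connection(host, port)
        except OSError:
            # Nothing is listening yet; retry shortly.
            await asyncio.sleep(0.02)
            continue
        writer.close()
        return True
    return False


async def _demo() -> tuple[bool, bool]:
    # Throwaway TCP server on an OS-assigned port stands in for the channel.
    server = await asyncio.start_server(lambda r, w: w.close(), "127.0.0.1", 0)
    port = server.sockets[0].getsockname()[1]
    up = await wait_for_port("127.0.0.1", port)
    server.close()  # stop listening; later connects are refused
    down = await wait_for_port("127.0.0.1", port, timeout=0.2)
    return up, down


up, down = asyncio.run(_demo())
```

Binding to port `0` in the demo also shows how a test could avoid hard-coded ports such as 29901 to 29924, at the cost of having to read the assigned port back from the server.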
584
tests/channels/test_wecom_channel.py
Normal file
@ -0,0 +1,584 @@
|
||||
"""Tests for WeCom channel: helpers, download, upload, send, and message processing."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
try:
|
||||
import importlib.util
|
||||
|
||||
WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
|
||||
except ImportError:
|
||||
WECOM_AVAILABLE = False
|
||||
|
||||
if not WECOM_AVAILABLE:
|
||||
pytest.skip("WeCom dependencies not installed (wecom_aibot_sdk)", allow_module_level=True)
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.wecom import (
|
||||
WecomChannel,
|
||||
WecomConfig,
|
||||
_guess_wecom_media_type,
|
||||
_sanitize_filename,
|
||||
)
|
||||
|
||||
# Try to import the real response class; fall back to a stub if unavailable.
|
||||
try:
|
||||
from wecom_aibot_sdk.utils import WsResponse
|
||||
|
||||
_RealWsResponse = WsResponse
|
||||
except ImportError:
|
||||
_RealWsResponse = None
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
"""Minimal stand-in for wecom_aibot_sdk WsResponse."""
|
||||
|
||||
def __init__(self, errcode: int = 0, body: dict | None = None, errmsg: str = "ok"):
|
||||
self.errcode = errcode
|
||||
self.errmsg = errmsg
|
||||
self.body = body or {}
|
||||
|
||||
|
||||
class _FakeWsManager:
|
||||
"""Tracks send_reply calls and returns configurable responses."""
|
||||
|
||||
def __init__(self, responses: list[_FakeResponse] | None = None):
|
||||
self.responses = responses or []
|
||||
self.calls: list[tuple[str, dict, str]] = []
|
||||
self._idx = 0
|
||||
|
||||
async def send_reply(self, req_id: str, data: dict, cmd: str) -> _FakeResponse:
|
||||
self.calls.append((req_id, data, cmd))
|
||||
if self._idx < len(self.responses):
|
||||
resp = self.responses[self._idx]
|
||||
self._idx += 1
|
||||
return resp
|
||||
return _FakeResponse()
|
||||
|
||||
|
||||
class _FakeFrame:
|
||||
"""Minimal frame object with a body dict."""
|
||||
|
||||
def __init__(self, body: dict | None = None):
|
||||
self.body = body or {}
|
||||
|
||||
|
||||
class _FakeWeComClient:
|
||||
"""Fake WeCom client with mock methods."""
|
||||
|
||||
def __init__(self, ws_responses: list[_FakeResponse] | None = None):
|
||||
self._ws_manager = _FakeWsManager(ws_responses)
|
||||
self.download_file = AsyncMock(return_value=(None, None))
|
||||
self.reply = AsyncMock()
|
||||
self.reply_stream = AsyncMock()
|
||||
self.send_message = AsyncMock()
|
||||
self.reply_welcome = AsyncMock()
|
||||
|
||||
|
||||
# ── Helper function tests (pure, no async) ──────────────────────────
|
||||
|
||||
|
||||
def test_sanitize_filename_strips_path_traversal() -> None:
|
||||
assert _sanitize_filename("../../etc/passwd") == "passwd"
|
||||
|
||||
|
||||
def test_sanitize_filename_keeps_chinese_chars() -> None:
|
||||
assert _sanitize_filename("文件(1).jpg") == "文件(1).jpg"
|
||||
|
||||
|
||||
def test_sanitize_filename_empty_input() -> None:
|
||||
assert _sanitize_filename("") == ""
|
||||
|
||||
|
||||
def test_guess_wecom_media_type_image() -> None:
|
||||
for ext in (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"):
|
||||
assert _guess_wecom_media_type(f"photo{ext}") == "image"
|
||||
|
||||
|
||||
def test_guess_wecom_media_type_video() -> None:
|
||||
for ext in (".mp4", ".avi", ".mov"):
|
||||
assert _guess_wecom_media_type(f"video{ext}") == "video"
|
||||
|
||||
|
||||
def test_guess_wecom_media_type_voice() -> None:
|
||||
for ext in (".amr", ".mp3", ".wav", ".ogg"):
|
||||
assert _guess_wecom_media_type(f"audio{ext}") == "voice"
|
||||
|
||||
|
||||
def test_guess_wecom_media_type_file_fallback() -> None:
|
||||
for ext in (".pdf", ".doc", ".xlsx", ".zip"):
|
||||
assert _guess_wecom_media_type(f"doc{ext}") == "file"
|
||||
|
||||
|
||||
def test_guess_wecom_media_type_case_insensitive() -> None:
|
||||
assert _guess_wecom_media_type("photo.PNG") == "image"
|
||||
assert _guess_wecom_media_type("photo.Jpg") == "image"
|
||||
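The helper tests above pin down the expected behaviour fairly completely: extension-based classification that ignores case, and filename sanitising that drops directory components. The sketch below is one way to satisfy those assertions. It is not the code in `nanobot.channels.wecom`; the names `guess_media_type` and `sanitize_filename` and the extension sets are reconstructed from the tests alone, and the real helpers may do more (for example, stripping unsafe characters).

```python
from __future__ import annotations

import os

# Extension sets taken from the parametrised loops in the tests.
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}
_VIDEO_EXTS = {".mp4", ".avi", ".mov"}
_VOICE_EXTS = {".amr", ".mp3", ".wav", ".ogg"}


def guess_media_type(filename: str) -> str:
    """Classify a file by extension, falling back to the generic 'file' type."""
    ext = os.path.splitext(filename)[1].lower()
    if ext in _IMAGE_EXTS:
        return "image"
    if ext in _VIDEO_EXTS:
        return "video"
    if ext in _VOICE_EXTS:
        return "voice"
    return "file"


def sanitize_filename(name: str) -> str:
    """Keep only the final path component so the name cannot leave the media dir."""
    return os.path.basename(name.replace("\\", "/"))
```

For example, `guess_media_type("photo.PNG")` classifies by the lower-cased `.png` suffix, and `sanitize_filename("../../etc/passwd")` keeps only the last component.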
|
||||
|
||||
# ── _download_and_save_media() ──────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_and_save_success() -> None:
|
||||
"""Successful download writes file and returns sanitized path."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
|
||||
fake_data = b"\x89PNG\r\nfake image"
|
||||
client.download_file.return_value = (fake_data, "raw_photo.png")
|
||||
|
||||
with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(tempfile.gettempdir())):
|
||||
path = await channel._download_and_save_media("https://example.com/img.png", "aes_key", "image", "photo.png")
|
||||
|
||||
assert path is not None
|
||||
assert os.path.isfile(path)
|
||||
assert os.path.basename(path) == "photo.png"
|
||||
# Cleanup
|
||||
os.unlink(path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_and_save_oversized_rejected() -> None:
|
||||
"""Data exceeding 200MB is rejected → returns None."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
|
||||
big_data = b"\x00" * (200 * 1024 * 1024 + 1) # 200MB + 1 byte
|
||||
client.download_file.return_value = (big_data, "big.bin")
|
||||
|
||||
with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(tempfile.gettempdir())):
|
||||
result = await channel._download_and_save_media("https://example.com/big.bin", "key", "file", "big.bin")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_and_save_failure() -> None:
|
||||
"""SDK returns None data → returns None."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
|
||||
client.download_file.return_value = (None, None)
|
||||
|
||||
with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(tempfile.gettempdir())):
|
||||
result = await channel._download_and_save_media("https://example.com/fail.png", "key", "image")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
# ── _upload_media_ws() ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_media_ws_success() -> None:
|
||||
"""Happy path: init → chunk → finish → returns (media_id, media_type)."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||
f.write(b"\x89PNG\r\n")
|
||||
tmp = f.name
|
||||
|
||||
try:
|
||||
responses = [
|
||||
_FakeResponse(errcode=0, body={"upload_id": "up_1"}),
|
||||
_FakeResponse(errcode=0, body={}),
|
||||
_FakeResponse(errcode=0, body={"media_id": "media_abc"}),
|
||||
]
|
||||
|
||||
client = _FakeWeComClient(responses)
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
channel._client = client
|
||||
|
||||
with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
|
||||
media_id, media_type = await channel._upload_media_ws(client, tmp)
|
||||
|
||||
assert media_id == "media_abc"
|
||||
assert media_type == "image"
|
||||
finally:
|
||||
os.unlink(tmp)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_media_ws_oversized_file() -> None:
|
||||
"""File >200MB triggers ValueError → returns (None, None)."""
|
||||
# Instead of creating a real 200MB+ file, mock os.path.getsize and open
|
||||
with patch("os.path.getsize", return_value=200 * 1024 * 1024 + 1), \
|
||||
patch("builtins.open", MagicMock()):
|
||||
client = _FakeWeComClient()
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
channel._client = client
|
||||
|
||||
result = await channel._upload_media_ws(client, "/fake/large.bin")
|
||||
assert result == (None, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_media_ws_init_failure() -> None:
|
||||
"""Init step returns errcode != 0 → returns (None, None)."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f:
|
||||
f.write(b"hello")
|
||||
tmp = f.name
|
||||
|
||||
try:
|
||||
responses = [
|
||||
_FakeResponse(errcode=50001, errmsg="invalid"),
|
||||
]
|
||||
|
||||
client = _FakeWeComClient(responses)
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
channel._client = client
|
||||
|
||||
with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
|
||||
result = await channel._upload_media_ws(client, tmp)
|
||||
|
||||
assert result == (None, None)
|
||||
finally:
|
||||
os.unlink(tmp)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_media_ws_chunk_failure() -> None:
|
||||
"""Chunk step returns errcode != 0 → returns (None, None)."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||
f.write(b"\x89PNG\r\n")
|
||||
tmp = f.name
|
||||
|
||||
try:
|
||||
responses = [
|
||||
_FakeResponse(errcode=0, body={"upload_id": "up_1"}),
|
||||
_FakeResponse(errcode=50002, errmsg="chunk fail"),
|
||||
]
|
||||
|
||||
client = _FakeWeComClient(responses)
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
channel._client = client
|
||||
|
||||
with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
|
||||
result = await channel._upload_media_ws(client, tmp)
|
||||
|
||||
assert result == (None, None)
|
||||
finally:
|
||||
os.unlink(tmp)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_media_ws_finish_no_media_id() -> None:
|
||||
"""Finish step returns empty media_id → returns (None, None)."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||
f.write(b"\x89PNG\r\n")
|
||||
tmp = f.name
|
||||
|
||||
try:
|
||||
responses = [
|
||||
_FakeResponse(errcode=0, body={"upload_id": "up_1"}),
|
||||
_FakeResponse(errcode=0, body={}),
|
||||
_FakeResponse(errcode=0, body={}), # no media_id
|
||||
]
|
||||
|
||||
client = _FakeWeComClient(responses)
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
channel._client = client
|
||||
|
||||
with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
|
||||
result = await channel._upload_media_ws(client, tmp)
|
||||
|
||||
assert result == (None, None)
|
||||
finally:
|
||||
os.unlink(tmp)
|
||||
|
||||
|
||||
# ── send() ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_text_with_frame() -> None:
|
||||
"""When frame is stored, send uses reply_stream for final text."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
channel._generate_req_id = lambda x: f"req_{x}"
|
||||
channel._chat_frames["chat1"] = _FakeFrame()
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(channel="wecom", chat_id="chat1", content="hello")
|
||||
)
|
||||
|
||||
client.reply_stream.assert_called_once()
|
||||
call_args = client.reply_stream.call_args
|
||||
assert call_args[0][2] == "hello" # content arg
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_progress_with_frame() -> None:
|
||||
"""When metadata has _progress, send uses reply_stream with finish=False."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
channel._generate_req_id = lambda x: f"req_{x}"
|
||||
channel._chat_frames["chat1"] = _FakeFrame()
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(channel="wecom", chat_id="chat1", content="thinking...", metadata={"_progress": True})
|
||||
)
|
||||
|
||||
client.reply_stream.assert_called_once()
|
||||
call_args = client.reply_stream.call_args
|
||||
assert call_args[0][2] == "thinking..." # content arg
|
||||
assert call_args[1]["finish"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_proactive_without_frame() -> None:
|
||||
"""Without stored frame, send uses send_message with markdown."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(channel="wecom", chat_id="chat1", content="proactive msg")
|
||||
)
|
||||
|
||||
client.send_message.assert_called_once()
|
||||
call_args = client.send_message.call_args
|
||||
assert call_args[0][0] == "chat1"
|
||||
assert call_args[0][1]["msgtype"] == "markdown"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_then_text() -> None:
|
||||
"""Media files are uploaded and sent before text content."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||
f.write(b"\x89PNG\r\n")
|
||||
tmp = f.name
|
||||
|
||||
try:
|
||||
responses = [
|
||||
_FakeResponse(errcode=0, body={"upload_id": "up_1"}),
|
||||
_FakeResponse(errcode=0, body={}),
|
||||
_FakeResponse(errcode=0, body={"media_id": "media_123"}),
|
||||
]
|
||||
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
client = _FakeWeComClient(responses)
|
||||
channel._client = client
|
||||
channel._generate_req_id = lambda x: f"req_{x}"
|
||||
channel._chat_frames["chat1"] = _FakeFrame()
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(channel="wecom", chat_id="chat1", content="see image", media=[tmp])
|
||||
)
|
||||
|
||||
# Media should have been sent via reply
|
||||
media_calls = [c for c in client.reply.call_args_list if c[0][1].get("msgtype") == "image"]
|
||||
assert len(media_calls) == 1
|
||||
assert media_calls[0][0][1]["image"]["media_id"] == "media_123"
|
||||
|
||||
# Text should have been sent via reply_stream
|
||||
client.reply_stream.assert_called_once()
|
||||
finally:
|
||||
os.unlink(tmp)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_file_not_found() -> None:
|
||||
"""Non-existent media path is skipped with a warning."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
channel._generate_req_id = lambda x: f"req_{x}"
|
||||
channel._chat_frames["chat1"] = _FakeFrame()
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(channel="wecom", chat_id="chat1", content="hello", media=["/nonexistent/file.png"])
|
||||
)
|
||||
|
||||
# reply_stream should still be called for the text part
|
||||
client.reply_stream.assert_called_once()
|
||||
# No media reply should happen
|
||||
media_calls = [c for c in client.reply.call_args_list if c[0][1].get("msgtype") in ("image", "file", "video")]
|
||||
assert len(media_calls) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_exception_caught_not_raised() -> None:
|
||||
"""Exceptions inside send() must not propagate."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
channel._generate_req_id = lambda x: f"req_{x}"
|
||||
channel._chat_frames["chat1"] = _FakeFrame()
|
||||
|
||||
# Make reply_stream raise
|
||||
client.reply_stream.side_effect = RuntimeError("boom")
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(channel="wecom", chat_id="chat1", content="fail test")
|
||||
)
|
||||
# No exception — test passes if we reach here.
|
||||
|
||||
|
||||
# ── _process_message() ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_text_message() -> None:
|
||||
"""Text message is routed to bus with correct fields."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
|
||||
frame = _FakeFrame(body={
|
||||
"msgid": "msg_text_1",
|
||||
"chatid": "chat1",
|
||||
"chattype": "single",
|
||||
"from": {"userid": "user1"},
|
||||
"text": {"content": "hello wecom"},
|
||||
})
|
||||
|
||||
await channel._process_message(frame, "text")
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert msg.sender_id == "user1"
|
||||
assert msg.chat_id == "chat1"
|
||||
assert msg.content == "hello wecom"
|
||||
assert msg.metadata["msg_type"] == "text"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_image_message() -> None:
|
||||
"""Image message: download success → media_paths non-empty."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||
f.write(b"\x89PNG\r\n")
|
||||
saved = f.name
|
||||
|
||||
client.download_file.return_value = (b"\x89PNG\r\n", "photo.png")
|
||||
channel._client = client
|
||||
|
||||
try:
|
||||
with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(os.path.dirname(saved))):
|
||||
frame = _FakeFrame(body={
|
||||
"msgid": "msg_img_1",
|
||||
"chatid": "chat1",
|
||||
"from": {"userid": "user1"},
|
||||
"image": {"url": "https://example.com/img.png", "aeskey": "key123"},
|
||||
})
|
||||
await channel._process_message(frame, "image")
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert len(msg.media) == 1
|
||||
assert msg.media[0].endswith("photo.png")
|
||||
assert "[image:" in msg.content
|
||||
finally:
|
||||
if os.path.exists(saved):
|
||||
os.unlink(saved)  # remove the placeholder temp file so the test leaves nothing behind
|
||||
# Clean up any photo.png in tempdir
|
||||
p = os.path.join(os.path.dirname(saved), "photo.png")
|
||||
if os.path.exists(p):
|
||||
os.unlink(p)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_file_message() -> None:
|
||||
"""File message: download success → media_paths non-empty (critical fix verification)."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
|
||||
f.write(b"%PDF-1.4 fake")
|
||||
saved = f.name
|
||||
|
||||
client.download_file.return_value = (b"%PDF-1.4 fake", "report.pdf")
|
||||
channel._client = client
|
||||
|
||||
try:
|
||||
with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(os.path.dirname(saved))):
|
||||
frame = _FakeFrame(body={
|
||||
"msgid": "msg_file_1",
|
||||
"chatid": "chat1",
|
||||
"from": {"userid": "user1"},
|
||||
"file": {"url": "https://example.com/report.pdf", "aeskey": "key456", "name": "report.pdf"},
|
||||
})
|
||||
await channel._process_message(frame, "file")
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert len(msg.media) == 1
|
||||
assert msg.media[0].endswith("report.pdf")
|
||||
assert "[file: report.pdf]" in msg.content
|
||||
finally:
|
||||
p = os.path.join(os.path.dirname(saved), "report.pdf")
|
||||
if os.path.exists(p):
|
||||
os.unlink(p)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_voice_message() -> None:
|
||||
"""Voice message: transcribed text is included in content."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
|
||||
frame = _FakeFrame(body={
|
||||
"msgid": "msg_voice_1",
|
||||
"chatid": "chat1",
|
||||
"from": {"userid": "user1"},
|
||||
"voice": {"content": "transcribed text here"},
|
||||
})
|
||||
|
||||
await channel._process_message(frame, "voice")
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert "transcribed text here" in msg.content
|
||||
assert "[voice]" in msg.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_deduplication() -> None:
|
||||
"""Same msg_id is not processed twice."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
|
||||
frame = _FakeFrame(body={
|
||||
"msgid": "msg_dup_1",
|
||||
"chatid": "chat1",
|
||||
"from": {"userid": "user1"},
|
||||
"text": {"content": "once"},
|
||||
})
|
||||
|
||||
await channel._process_message(frame, "text")
|
||||
await channel._process_message(frame, "text")
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert msg.content == "once"
|
||||
|
||||
# Second message should not appear on the bus
|
||||
assert channel.bus.inbound.empty()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_empty_content_skipped() -> None:
|
||||
"""Message with empty content produces no bus message."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
|
||||
frame = _FakeFrame(body={
|
||||
"msgid": "msg_empty_1",
|
||||
"chatid": "chat1",
|
||||
"from": {"userid": "user1"},
|
||||
"text": {"content": ""},
|
||||
})
|
||||
|
||||
await channel._process_message(frame, "text")
|
||||
|
||||
assert channel.bus.inbound.empty()
|
||||
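The deduplication test above only pins down the observable behaviour: a second frame with the same `msgid` must not reach the bus. A minimal sketch of one way to get that behaviour with bounded memory follows. The class name `SeenMessages`, its `max_size` parameter, and the eviction policy are assumptions for illustration and are not taken from `WecomChannel`.

```python
from __future__ import annotations

from collections import OrderedDict


class SeenMessages:
    """Bounded, insertion-ordered record of message IDs already handled (hypothetical helper)."""

    def __init__(self, max_size: int = 1000) -> None:
        self._seen: OrderedDict[str, None] = OrderedDict()
        self._max_size = max_size

    def first_time(self, msg_id: str) -> bool:
        """Return True the first time an ID is seen and False for repeats."""
        if msg_id in self._seen:
            return False
        self._seen[msg_id] = None
        # Evict the oldest IDs so memory stays bounded on a long-running bot.
        while len(self._seen) > self._max_size:
            self._seen.popitem(last=False)
        return True
```

A channel could call `first_time(msg_id)` at the top of its message handler and return early on `False`. The trade-off of a bounded set is that a very old ID can be accepted again once it has been evicted.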
@ -1003,3 +1003,185 @@ async def test_download_media_item_non_image_requires_aes_key_even_with_full_url
|
||||
|
||||
assert saved_path is None
|
||||
channel._client.get.assert_not_awaited()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for media-send error classification (network vs non-network errors)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_outbound_msg(chat_id: str = "wx-user", content: str = "", media: list | None = None):
|
||||
"""Build a minimal OutboundMessage for send() tests."""
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
|
||||
return OutboundMessage(
|
||||
channel="weixin",
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
media=media or [],
|
||||
metadata={},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_timeout_error_propagates_without_text_fallback() -> None:
|
||||
"""httpx.TimeoutException during media send must re-raise immediately,
|
||||
NOT fall back to _send_text (which would also fail during network issues)."""
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._context_tokens["wx-user"] = "ctx-1"
|
||||
channel._send_media_file = AsyncMock(side_effect=httpx.TimeoutException("timed out"))
|
||||
channel._send_text = AsyncMock()
|
||||
|
||||
msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"])
|
||||
|
||||
with pytest.raises(httpx.TimeoutException, match="timed out"):
|
||||
await channel.send(msg)
|
||||
|
||||
# _send_text must NOT have been called as a fallback
|
||||
channel._send_text.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_transport_error_propagates_without_text_fallback() -> None:
|
||||
"""httpx.TransportError during media send must re-raise immediately."""
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._context_tokens["wx-user"] = "ctx-1"
|
||||
channel._send_media_file = AsyncMock(
|
||||
side_effect=httpx.TransportError("connection reset")
|
||||
)
|
||||
channel._send_text = AsyncMock()
|
||||
|
||||
msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"])
|
||||
|
||||
with pytest.raises(httpx.TransportError, match="connection reset"):
|
||||
await channel.send(msg)
|
||||
|
||||
channel._send_text.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_5xx_http_status_error_propagates_without_text_fallback() -> None:
|
||||
"""httpx.HTTPStatusError with a 5xx status must re-raise immediately."""
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._context_tokens["wx-user"] = "ctx-1"
|
||||
|
||||
fake_response = httpx.Response(
|
||||
status_code=503,
|
||||
request=httpx.Request("POST", "https://example.test/upload"),
|
||||
)
|
||||
channel._send_media_file = AsyncMock(
|
||||
side_effect=httpx.HTTPStatusError(
|
||||
"Service Unavailable", request=fake_response.request, response=fake_response
|
||||
)
|
||||
)
|
||||
channel._send_text = AsyncMock()
|
||||
|
||||
msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"])
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError, match="Service Unavailable"):
|
||||
await channel.send(msg)
|
||||
|
||||
channel._send_text.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_4xx_http_status_error_falls_back_to_text() -> None:
|
||||
"""httpx.HTTPStatusError with a 4xx status should fall back to text, not re-raise."""
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._context_tokens["wx-user"] = "ctx-1"
|
||||
|
||||
fake_response = httpx.Response(
|
||||
status_code=400,
|
||||
request=httpx.Request("POST", "https://example.test/upload"),
|
||||
)
|
||||
channel._send_media_file = AsyncMock(
|
||||
side_effect=httpx.HTTPStatusError(
|
||||
"Bad Request", request=fake_response.request, response=fake_response
|
||||
)
|
||||
)
|
||||
channel._send_text = AsyncMock()
|
||||
|
||||
msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/photo.jpg"])
|
||||
|
||||
# Should NOT raise — 4xx is a client error, non-retryable
|
||||
await channel.send(msg)
|
||||
|
||||
# _send_text should have been called with the fallback message
|
||||
channel._send_text.assert_awaited_once_with(
|
||||
"wx-user", "[Failed to send: photo.jpg]", "ctx-1"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_file_not_found_falls_back_to_text() -> None:
|
||||
"""FileNotFoundError (a non-network error) should fall back to text."""
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._context_tokens["wx-user"] = "ctx-1"
|
||||
channel._send_media_file = AsyncMock(
|
||||
side_effect=FileNotFoundError("Media file not found: /tmp/missing.jpg")
|
||||
)
|
||||
channel._send_text = AsyncMock()
|
||||
|
||||
msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/missing.jpg"])
|
||||
|
||||
# Should NOT raise
|
||||
await channel.send(msg)
|
||||
|
||||
channel._send_text.assert_awaited_once_with(
|
||||
"wx-user", "[Failed to send: missing.jpg]", "ctx-1"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_value_error_falls_back_to_text() -> None:
|
||||
"""ValueError (e.g. unsupported format) should fall back to text."""
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._context_tokens["wx-user"] = "ctx-1"
|
||||
channel._send_media_file = AsyncMock(
|
||||
side_effect=ValueError("Unsupported media format")
|
||||
)
|
||||
channel._send_text = AsyncMock()
|
||||
|
||||
msg = _make_outbound_msg(chat_id="wx-user", media=["/tmp/file.xyz"])
|
||||
|
||||
# Should NOT raise
|
||||
await channel.send(msg)
|
||||
|
||||
channel._send_text.assert_awaited_once_with(
|
||||
"wx-user", "[Failed to send: file.xyz]", "ctx-1"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_network_error_does_not_double_api_calls() -> None:
|
||||
"""During network issues, media send should make exactly 1 API call attempt,
|
||||
not 2 (media + text fallback). Verify total call count."""
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._context_tokens["wx-user"] = "ctx-1"
|
||||
channel._send_media_file = AsyncMock(
|
||||
side_effect=httpx.ConnectError("connection refused")
|
||||
)
|
||||
channel._send_text = AsyncMock()
|
||||
|
||||
msg = _make_outbound_msg(chat_id="wx-user", content="hello", media=["/tmp/img.png"])
|
||||
|
||||
with pytest.raises(httpx.ConnectError):
|
||||
await channel.send(msg)
|
||||
|
||||
# _send_media_file called once, _send_text never called
|
||||
channel._send_media_file.assert_awaited_once()
|
||||
channel._send_text.assert_not_awaited()
|
||||
|
||||
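The tests above split media-send failures into two groups: network-level errors (timeouts, transport errors, HTTP 5xx) propagate, while everything else (HTTP 4xx, `FileNotFoundError`, `ValueError`) falls back to a `[Failed to send: ...]` text message. A minimal sketch of that decision follows. To keep it runnable without third-party packages it uses stand-in exception classes that mirror the relevant part of the `httpx` hierarchy; the function name `should_propagate` and the stand-ins are assumptions, not the channel's actual code.

```python
from dataclasses import dataclass


class TransportError(Exception):
    """Stand-in for httpx.TransportError (illustration only)."""


class TimeoutException(TransportError):
    """Stand-in for httpx.TimeoutException, a TransportError subclass."""


@dataclass
class FakeResponse:
    """Stand-in for the response attached to an HTTP status error."""

    status_code: int


class HTTPStatusError(Exception):
    """Stand-in for httpx.HTTPStatusError; carries the failing response."""

    def __init__(self, message: str, response: FakeResponse) -> None:
        super().__init__(message)
        self.response = response


def should_propagate(exc: Exception) -> bool:
    """Return True when a text fallback would likely hit the same failure."""
    if isinstance(exc, TransportError):
        # Covers timeouts and connection failures alike.
        return True
    if isinstance(exc, HTTPStatusError):
        # 5xx points at the server; 4xx points at this particular request.
        return exc.response.status_code >= 500
    return False
```

In a `send()` implementation the `except` branch would re-raise when `should_propagate(exc)` is true and send the fallback text otherwise, which avoids a second doomed API call during an outage.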
227
tests/channels/ws_test_client.py
Normal file
@ -0,0 +1,227 @@
|
||||
"""Lightweight WebSocket test client for integration testing the nanobot WebSocket channel.
|
||||
|
||||
Provides an async ``WsTestClient`` class and token-issuance helpers that
|
||||
integration tests can import and use directly::
|
||||
|
||||
from ws_test_client import WsTestClient
|
||||
|
||||
async with WsTestClient("ws://127.0.0.1:8765/", client_id="t") as c:
|
||||
ready = await c.recv_ready()
|
||||
await c.send_text("hello")
|
||||
msg = await c.recv_message()
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import websockets
|
||||
from websockets.asyncio.client import ClientConnection
|
||||
|
||||
|
||||
@dataclass
|
||||
class WsMessage:
|
||||
"""A parsed message received from the WebSocket server."""
|
||||
|
||||
event: str
|
||||
raw: dict[str, Any] = field(repr=False)
|
||||
|
||||
@property
|
||||
def text(self) -> str | None:
|
||||
return self.raw.get("text")
|
||||
|
||||
@property
|
||||
def chat_id(self) -> str | None:
|
||||
return self.raw.get("chat_id")
|
||||
|
||||
@property
|
||||
def client_id(self) -> str | None:
|
||||
return self.raw.get("client_id")
|
||||
|
||||
@property
|
||||
def media(self) -> list[str] | None:
|
||||
return self.raw.get("media")
|
||||
|
||||
@property
|
||||
def reply_to(self) -> str | None:
|
||||
return self.raw.get("reply_to")
|
||||
|
||||
@property
|
||||
def stream_id(self) -> str | None:
|
||||
return self.raw.get("stream_id")
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, WsMessage):
|
||||
return NotImplemented
|
||||
return self.event == other.event and self.raw == other.raw
|
||||
|
||||
|
||||
class WsTestClient:
|
||||
"""Async WebSocket test client with helper methods for common operations.
|
||||
|
||||
Usage::
|
||||
|
||||
async with WsTestClient("ws://127.0.0.1:8765/", client_id="tester") as client:
|
||||
ready = await client.recv_ready()
|
||||
await client.send_text("hello")
|
||||
msg = await client.recv_message(timeout=5.0)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
*,
|
||||
client_id: str = "test-client",
|
||||
token: str = "",
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
params: list[str] = []
|
||||
if client_id:
|
||||
params.append(f"client_id={client_id}")
|
||||
if token:
|
||||
params.append(f"token={token}")
|
||||
sep = "&" if "?" in uri else "?"
|
||||
self._uri = uri + sep + "&".join(params) if params else uri
|
||||
self._extra_headers = extra_headers
|
||||
self._ws: ClientConnection | None = None
|
||||
|
||||
async def connect(self) -> None:
|
||||
self._ws = await websockets.connect(
|
||||
self._uri,
|
||||
additional_headers=self._extra_headers,
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
|
||||
async def __aenter__(self) -> WsTestClient:
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args: Any) -> None:
|
||||
await self.close()
|
||||
|
||||
@property
|
||||
def ws(self) -> ClientConnection:
|
||||
assert self._ws is not None, "Client is not connected"
|
||||
return self._ws
|
||||
|
||||
# -- Receiving --------------------------------------------------------
|
||||
|
||||
async def recv_raw(self, timeout: float = 10.0) -> dict[str, Any]:
|
||||
"""Receive and parse one raw JSON message with timeout."""
|
||||
raw = await asyncio.wait_for(self.ws.recv(), timeout=timeout)
|
||||
return json.loads(raw)
|
||||
|
||||
async def recv(self, timeout: float = 10.0) -> WsMessage:
|
||||
"""Receive one message, returning a WsMessage wrapper."""
|
||||
data = await self.recv_raw(timeout)
|
||||
return WsMessage(event=data.get("event", ""), raw=data)
|
||||
|
||||
async def recv_ready(self, timeout: float = 5.0) -> WsMessage:
|
||||
"""Receive and validate the 'ready' event."""
|
||||
msg = await self.recv(timeout)
|
||||
assert msg.event == "ready", f"Expected 'ready' event, got '{msg.event}'"
|
||||
return msg
|
||||
|
||||
async def recv_message(self, timeout: float = 10.0) -> WsMessage:
|
||||
"""Receive and validate a 'message' event."""
|
||||
msg = await self.recv(timeout)
|
||||
assert msg.event == "message", f"Expected 'message' event, got '{msg.event}'"
|
||||
return msg
|
||||
|
||||
async def recv_delta(self, timeout: float = 10.0) -> WsMessage:
|
||||
"""Receive and validate a 'delta' event."""
|
||||
msg = await self.recv(timeout)
|
||||
assert msg.event == "delta", f"Expected 'delta' event, got '{msg.event}'"
|
||||
return msg
|
||||
|
||||
async def recv_stream_end(self, timeout: float = 10.0) -> WsMessage:
|
||||
"""Receive and validate a 'stream_end' event."""
|
||||
msg = await self.recv(timeout)
|
||||
assert msg.event == "stream_end", f"Expected 'stream_end' event, got '{msg.event}'"
|
||||
return msg
|
||||
|
||||
async def collect_stream(self, timeout: float = 10.0) -> list[WsMessage]:
|
||||
"""Collect all deltas and the final stream_end into a list."""
|
||||
messages: list[WsMessage] = []
|
||||
while True:
|
||||
msg = await self.recv(timeout)
|
||||
messages.append(msg)
|
||||
if msg.event == "stream_end":
|
||||
break
|
||||
return messages
|
||||
|
||||
async def recv_n(self, n: int, timeout: float = 10.0) -> list[WsMessage]:
|
||||
"""Receive exactly *n* messages."""
|
||||
return [await self.recv(timeout) for _ in range(n)]
|
||||
|
||||
# -- Sending ----------------------------------------------------------
|
||||
|
||||
async def send_text(self, text: str) -> None:
|
||||
"""Send a plain text frame."""
|
||||
await self.ws.send(text)
|
||||
|
||||
async def send_json(self, data: dict[str, Any]) -> None:
|
||||
"""Send a JSON frame."""
|
||||
await self.ws.send(json.dumps(data, ensure_ascii=False))
|
||||
|
||||
async def send_content(self, content: str) -> None:
|
||||
"""Send content in the preferred JSON format ``{"content": ...}``."""
|
||||
await self.send_json({"content": content})
|
||||
|
||||
# -- Connection introspection -----------------------------------------
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return self._ws is None or self._ws.state.name == "CLOSED"  # the new asyncio ClientConnection exposes `state`, not `closed`
|
||||
|
||||
|
||||
# -- Token issuance helpers -----------------------------------------------
|
||||
|
||||
|
||||
async def issue_token(
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 8765,
|
||||
issue_path: str = "/auth/token",
|
||||
secret: str = "",
|
||||
) -> tuple[dict[str, Any] | None, int]:
|
||||
"""Request a short-lived token from the token-issue HTTP endpoint.
|
||||
|
||||
Returns ``(parsed_json_or_None, status_code)``.
|
||||
"""
|
||||
url = f"http://{host}:{port}{issue_path}"
|
||||
headers: dict[str, str] = {}
|
||||
if secret:
|
||||
headers["Authorization"] = f"Bearer {secret}"
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
resp = await loop.run_in_executor(
|
||||
None, lambda: httpx.get(url, headers=headers, timeout=5.0)
|
||||
)
|
||||
try:
|
||||
data = resp.json()
|
||||
except Exception:
|
||||
data = None
|
||||
return data, resp.status_code
|
||||
|
||||
|
||||
async def issue_token_ok(
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 8765,
|
||||
issue_path: str = "/auth/token",
|
||||
secret: str = "",
|
||||
) -> str:
|
||||
"""Request a token, asserting success, and return the token string."""
|
||||
(data, status) = await issue_token(host, port, issue_path, secret)
|
||||
assert status == 200, f"Token issue failed with status {status}"
|
||||
assert data is not None
|
||||
token = data["token"]
|
||||
assert token.startswith("nbwt_"), f"Unexpected token format: {token}"
|
||||
return token
|
||||
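`WsTestClient.__init__` above joins `client_id` and `token` into the query string verbatim, which is fine for the simple values used in these tests. If a test ever needs values containing `&`, `=`, or spaces, one option is to let `urllib.parse` do the escaping. The helper name `build_ws_uri` is hypothetical and is not part of this commit.

```python
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit


def build_ws_uri(uri: str, client_id: str = "test-client", token: str = "") -> str:
    """Append client_id/token to a WebSocket URI with proper percent-encoding."""
    parts = urlsplit(uri)
    # Keep any query parameters that were already present on the URI.
    query = parse_qsl(parts.query, keep_blank_values=True)
    if client_id:
        query.append(("client_id", client_id))
    if token:
        query.append(("token", token))
    return urlunsplit(parts._replace(query=urlencode(query)))
```

Note that `urlencode` encodes spaces as `+`, so this sketch assumes the server decodes query strings in the usual form-encoded way.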
@ -1126,6 +1126,153 @@ def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path)
|
||||
assert "port 18792" in result.stdout
|
||||
|
||||
|
||||
def test_gateway_health_endpoint_binds_and_serves_expected_responses(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
config = Config()
|
||||
config.gateway.port = 18791
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
class _FakeDream:
|
||||
model = None
|
||||
max_batch_size = 0
|
||||
max_iterations = 0
|
||||
|
||||
async def run(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeAgentLoop:
|
||||
def __init__(self, **_kwargs) -> None:
|
||||
self.model = "test-model"
|
||||
self.dream = _FakeDream()
|
||||
|
||||
async def run(self) -> None:
|
||||
await asyncio.Event().wait()
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
return None
|
||||
|
||||
def stop(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeChannelManager:
|
||||
def __init__(self, _config, _bus) -> None:
|
||||
self.enabled_channels = ["telegram", "discord"]
|
||||
|
||||
async def start_all(self) -> None:
|
||||
await asyncio.Event().wait()
|
||||
|
||||
async def stop_all(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeCronService:
|
||||
def __init__(self, _store_path: Path) -> None:
|
||||
self.on_job = None
|
||||
|
||||
async def start(self) -> None:
|
||||
return None
|
||||
|
||||
def stop(self) -> None:
|
||||
return None
|
||||
|
||||
def status(self) -> dict[str, int]:
|
||||
return {"jobs": 0}
|
||||
|
||||
def register_system_job(self, _job) -> None:
|
||||
return None
|
||||
|
||||
class _FakeHeartbeatService:
|
||||
def __init__(self, **_kwargs) -> None:
|
||||
return None
|
||||
|
||||
async def start(self) -> None:
|
||||
return None
|
||||
|
||||
def stop(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeServer:
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
return False
|
||||
|
||||
async def serve_forever(self) -> None:
|
||||
raise _StopGatewayError("stop")
|
||||
|
||||
async def _fake_start_server(handler, host: str, port: int):
|
||||
captured["handler"] = handler
|
||||
captured["host"] = host
|
||||
captured["port"] = port
|
||||
return _FakeServer()
|
||||
|
||||
class _FakeReader:
|
||||
def __init__(self, payload: bytes) -> None:
|
||||
self.payload = payload
|
||||
|
||||
async def read(self, _size: int) -> bytes:
|
||||
return self.payload
|
||||
|
||||
class _FakeWriter:
|
||||
def __init__(self) -> None:
|
||||
self.output = b""
|
||||
self.closed = False
|
||||
|
||||
def write(self, data: bytes) -> None:
|
||||
self.output += data
|
||||
|
||||
async def drain(self) -> None:
|
||||
return None
|
||||
|
||||
def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
message_bus=lambda: object(),
|
||||
session_manager=lambda _workspace: object(),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _FakeChannelManager)
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCronService)
|
||||
monkeypatch.setattr("nanobot.heartbeat.service.HeartbeatService", _FakeHeartbeatService)
|
||||
monkeypatch.setattr("asyncio.start_server", _fake_start_server)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert captured["host"] == "127.0.0.1"
|
||||
assert captured["port"] == 18791
|
||||
assert "Health endpoint: http://127.0.0.1:18791/health" in result.stdout
|
||||
|
||||
def _call_handler(path: str) -> tuple[str, _FakeWriter]:
|
||||
request = f"GET {path} HTTP/1.1\r\nHost: localhost\r\n\r\n".encode()
|
||||
writer = _FakeWriter()
|
||||
handler = captured["handler"]
|
||||
assert callable(handler)
|
||||
asyncio.run(handler(_FakeReader(request), writer))
|
||||
return writer.output.decode(), writer
|
||||
|
||||
root_response, root_writer = _call_handler("/")
|
||||
assert root_writer.closed is True
|
||||
assert "HTTP/1.0 404 Not Found" in root_response
|
||||
assert root_response.endswith("\r\n\r\nNot Found")
|
||||
|
||||
health_response, health_writer = _call_handler("/health")
|
||||
assert health_writer.closed is True
|
||||
assert "HTTP/1.0 200 OK" in health_response
|
||||
health_body = json.loads(health_response.split("\r\n\r\n", 1)[1])
|
||||
assert health_body == {"status": "ok"}
|
||||
|
||||
missing_response, missing_writer = _call_handler("/missing")
|
||||
assert missing_writer.closed is True
|
||||
assert "HTTP/1.0 404 Not Found" in missing_response
|
||||
assert missing_response.endswith("\r\n\r\nNot Found")
|
||||
|
||||
|
||||
def test_serve_uses_api_config_defaults_and_workspace_override(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
|
||||
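The health-endpoint test above drives the captured handler with a fake reader and writer and expects `HTTP/1.0 200 OK` with `{"status": "ok"}` for `/health` and `HTTP/1.0 404 Not Found` with a `Not Found` body for any other path. A minimal handler that would satisfy those assertions is sketched below. The name `handle_health`, the header set, and the 1024-byte read size are assumptions; the gateway's real handler is not shown in this diff.

```python
import asyncio
import json


async def handle_health(reader, writer) -> None:
    """Answer GET /health with a JSON body and every other path with 404."""
    request = await reader.read(1024)
    # The request line looks like: GET /health HTTP/1.1
    parts = request.split(b"\r\n", 1)[0].decode("latin-1").split(" ")
    path = parts[1] if len(parts) >= 2 else ""
    if path == "/health":
        status, ctype, body = "200 OK", "application/json", json.dumps({"status": "ok"})
    else:
        status, ctype, body = "404 Not Found", "text/plain", "Not Found"
    payload = body.encode()
    head = (
        f"HTTP/1.0 {status}\r\n"
        f"Content-Type: {ctype}\r\n"
        f"Content-Length: {len(payload)}\r\n"
        "\r\n"
    )
    writer.write(head.encode() + payload)
    await writer.drain()
    writer.close()


class FakeReader:
    """In-memory reader that returns one canned request."""

    def __init__(self, payload: bytes) -> None:
        self.payload = payload

    async def read(self, _size: int) -> bytes:
        return self.payload


class FakeWriter:
    """In-memory writer that records output and the close() call."""

    def __init__(self) -> None:
        self.output = b""
        self.closed = False

    def write(self, data: bytes) -> None:
        self.output += data

    async def drain(self) -> None:
        return None

    def close(self) -> None:
        self.closed = True


def call(path: str) -> FakeWriter:
    """Run the handler once against a synthetic GET request."""
    writer = FakeWriter()
    request = f"GET {path} HTTP/1.1\r\nHost: localhost\r\n\r\n".encode()
    asyncio.run(handle_health(FakeReader(request), writer))
    return writer
```

With real sockets the same coroutine could be passed to `asyncio.start_server(handle_health, host, port)`, which is the call the test monkeypatches.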
@ -148,7 +148,7 @@ class TestRestartCommand:
|
||||
assert response is not None
|
||||
assert "Model: test-model" in response.content
|
||||
assert "Tokens: 0 in / 0 out" in response.content
|
||||
assert "Context: 20k/64k (31%)" in response.content
|
||||
assert "Context: 20k/65k (31%)" in response.content
|
||||
assert "Session: 3 messages" in response.content
|
||||
assert "Uptime: 2m 5s" in response.content
|
||||
assert response.metadata == {"render_as": "text"}
|
||||
@ -186,7 +186,7 @@ class TestRestartCommand:
|
||||
|
||||
assert response is not None
|
||||
assert "Tokens: 1200 in / 34 out" in response.content
|
||||
assert "Context: 1k/64k (1%)" in response.content
|
||||
assert "Context: 1k/65k (1%)" in response.content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_direct_preserves_render_metadata(self):
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
@ -114,6 +115,41 @@ async def test_run_history_persisted_to_disk(tmp_path) -> None:
|
||||
assert loaded.state.run_history[0].status == "ok"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_job_disabled_does_not_flip_running_state(tmp_path) -> None:
|
||||
store_path = tmp_path / "cron" / "jobs.json"
|
||||
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
|
||||
job = service.add_job(
|
||||
name="disabled",
|
||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||
message="hello",
|
||||
)
|
||||
service.enable_job(job.id, enabled=False)
|
||||
|
||||
result = await service.run_job(job.id)
|
||||
|
||||
assert result is False
|
||||
assert service._running is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_job_preserves_running_service_state(tmp_path) -> None:
|
||||
store_path = tmp_path / "cron" / "jobs.json"
|
||||
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
|
||||
service._running = True
|
||||
job = service.add_job(
|
||||
name="manual",
|
||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||
message="hello",
|
||||
)
|
||||
|
||||
result = await service.run_job(job.id, force=True)
|
||||
|
||||
assert result is True
|
||||
assert service._running is True
|
||||
service.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_running_service_honors_external_disable(tmp_path) -> None:
|
||||
store_path = tmp_path / "cron" / "jobs.json"
|
||||
@ -158,24 +194,49 @@ def test_remove_job_refuses_system_jobs(tmp_path) -> None:
|
||||
assert service.get_job("dream") is not None
|
||||
|
||||
|
||||
def test_reload_jobs(tmp_path):
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_server_not_jobs(tmp_path):
|
||||
store_path = tmp_path / "cron" / "jobs.json"
|
||||
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
|
||||
service.add_job(
|
||||
name="hist",
|
||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||
message="hello",
|
||||
)
|
||||
called = []
|
||||
async def on_job(job):
|
||||
called.append(job.name)
|
||||
|
||||
assert len(service.list_jobs()) == 1
|
||||
service = CronService(store_path, on_job=on_job, max_sleep_ms=1000)
|
||||
await service.start()
|
||||
assert len(service.list_jobs()) == 0
|
||||
|
||||
service2 = CronService(tmp_path / "cron" / "jobs.json")
|
||||
service2.add_job(
|
||||
name="hist2",
|
||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||
message="hello2",
|
||||
name="hist",
|
||||
schedule=CronSchedule(kind="every", every_ms=500),
|
||||
message="hello",
|
||||
)
|
||||
assert len(service.list_jobs()) == 2
|
||||
assert len(service.list_jobs()) == 1
|
||||
await asyncio.sleep(2)
|
||||
assert len(called) != 0
|
||||
service.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subsecond_job_not_delayed_to_one_second(tmp_path):
|
||||
store_path = tmp_path / "cron" / "jobs.json"
|
||||
called = []
|
||||
|
||||
async def on_job(job):
|
||||
called.append(job.name)
|
||||
|
||||
service = CronService(store_path, on_job=on_job, max_sleep_ms=5000)
|
||||
service.add_job(
|
||||
name="fast",
|
||||
schedule=CronSchedule(kind="every", every_ms=100),
|
||||
message="hello",
|
||||
)
|
||||
await service.start()
|
||||
try:
|
||||
await asyncio.sleep(0.35)
|
||||
assert called
|
||||
finally:
|
||||
service.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
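The two timing tests above constrain the scheduler from both sides: a 100 ms job must fire well before one second has passed even with `max_sleep_ms=5000`, and a job added by another process must be noticed within roughly `max_sleep_ms`. One way to meet both is to sleep until the earliest due time, capped by `max_sleep_ms`. The sketch below illustrates that arithmetic only; the function name and signature are assumptions, and `CronService` may compute its timer differently.

```python
from __future__ import annotations


def next_sleep_ms(now_ms: int, next_run_times_ms: list[int], max_sleep_ms: int) -> int:
    """Return how long a timer loop could sleep before re-checking the job store."""
    if not next_run_times_ms:
        # Nothing scheduled: wake up periodically to notice external changes.
        return max_sleep_ms
    until_due = min(next_run_times_ms) - now_ms
    # Clamp to [0, max_sleep_ms]: overdue jobs run immediately, and the cap
    # bounds how long an externally added job can go unnoticed.
    return max(0, min(until_due, max_sleep_ms))
```

Because the result is not rounded up to a whole second, a sub-second interval keeps its cadence.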
@ -204,7 +265,302 @@ async def test_running_service_picks_up_external_add(tmp_path):
|
||||
message="ping",
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.6)
|
||||
await asyncio.sleep(2)
|
||||
assert "external" in called
|
||||
finally:
|
||||
service.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_job_during_jobs_exec(tmp_path):
|
||||
store_path = tmp_path / "cron" / "jobs.json"
|
||||
run_once = True
|
||||
|
||||
async def on_job(job):
|
||||
nonlocal run_once
|
||||
if run_once:
|
||||
service2 = CronService(store_path, on_job=lambda x: asyncio.sleep(0))
|
||||
service2.add_job(
|
||||
name="test",
|
||||
schedule=CronSchedule(kind="every", every_ms=150),
|
||||
message="tick",
|
||||
)
|
||||
run_once = False
|
||||
|
||||
service = CronService(store_path, on_job=on_job)
|
||||
service.add_job(
|
||||
name="heartbeat",
|
||||
schedule=CronSchedule(kind="every", every_ms=150),
|
||||
message="tick",
|
||||
)
|
||||
assert len(service.list_jobs()) == 1
|
||||
await service.start()
|
||||
try:
|
||||
await asyncio.sleep(3)
|
||||
jobs = service.list_jobs()
|
||||
assert len(jobs) == 2
|
||||
assert "test" in [j.name for j in jobs]
|
||||
finally:
|
||||
service.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_external_update_preserves_run_history_records(tmp_path):
|
||||
store_path = tmp_path / "cron" / "jobs.json"
|
||||
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
|
||||
job = service.add_job(
|
||||
name="history",
|
||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||
message="hello",
|
||||
)
|
||||
await service.run_job(job.id, force=True)
|
||||
|
||||
external = CronService(store_path)
|
||||
updated = external.enable_job(job.id, enabled=False)
|
||||
assert updated is not None
|
||||
|
||||
fresh = CronService(store_path)
|
||||
loaded = fresh.get_job(job.id)
|
||||
assert loaded is not None
|
||||
assert loaded.state.run_history
|
||||
assert loaded.state.run_history[0].status == "ok"
|
||||
|
||||
fresh._running = True
|
||||
fresh._save_store()
|
||||
|
||||
|
||||
# ── timer race regression tests ──
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timer_execution_is_not_rolled_back_by_list_jobs_reload(tmp_path):
|
||||
"""list_jobs() during _on_timer should not replace the active store and re-run the same due job."""
|
||||
store_path = tmp_path / "cron" / "jobs.json"
|
||||
calls: list[str] = []
|
||||
|
||||
async def on_job(job):
|
||||
calls.append(job.id)
|
||||
# Simulate frontend polling list_jobs while the timer callback is mid-execution.
|
||||
service.list_jobs(include_disabled=True)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
service = CronService(store_path, on_job=on_job)
|
||||
service._running = True
|
||||
service._load_store()
|
||||
service._arm_timer = lambda: None
|
||||
|
||||
job = service.add_job(
|
||||
name="race",
|
||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||
message="hello",
|
||||
)
|
||||
job.state.next_run_at_ms = max(1, int(time.time() * 1000) - 1_000)
|
||||
service._save_store()
|
||||
|
||||
await service._on_timer()
|
||||
await service._on_timer()
|
||||
|
||||
assert calls == [job.id]
|
||||
loaded = service.get_job(job.id)
|
||||
assert loaded is not None
|
||||
assert loaded.state.last_run_at_ms is not None
|
||||
assert loaded.state.next_run_at_ms is not None
|
||||
assert loaded.state.next_run_at_ms > loaded.state.last_run_at_ms
|
||||
|
||||
|
||||
# ── update_job tests ──
|
||||
|
||||
|
||||
def test_update_job_changes_name(tmp_path) -> None:
|
||||
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||
job = service.add_job(
|
||||
name="old name",
|
||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||
message="hello",
|
||||
)
|
||||
result = service.update_job(job.id, name="new name")
|
||||
assert isinstance(result, CronJob)
|
||||
assert result.name == "new name"
|
||||
assert result.payload.message == "hello"
|
||||
|
||||
|
||||
def test_update_job_changes_schedule(tmp_path) -> None:
|
||||
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||
job = service.add_job(
|
||||
name="sched",
|
||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||
message="hello",
|
||||
)
|
||||
old_next = job.state.next_run_at_ms
|
||||
|
||||
new_sched = CronSchedule(kind="every", every_ms=120_000)
|
||||
result = service.update_job(job.id, schedule=new_sched)
|
||||
assert isinstance(result, CronJob)
|
||||
assert result.schedule.every_ms == 120_000
|
||||
assert result.state.next_run_at_ms != old_next
|
||||
|
||||
|
||||
def test_update_job_changes_message(tmp_path) -> None:
|
||||
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||
job = service.add_job(
|
||||
name="msg",
|
||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||
message="old message",
|
||||
)
|
||||
result = service.update_job(job.id, message="new message")
|
||||
assert isinstance(result, CronJob)
|
||||
assert result.payload.message == "new message"
|
||||
|
||||
|
||||
def test_update_job_changes_cron_expression(tmp_path) -> None:
|
||||
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||
job = service.add_job(
|
||||
name="cron-job",
|
||||
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"),
|
||||
message="hello",
|
||||
)
|
||||
result = service.update_job(
|
||||
job.id,
|
||||
schedule=CronSchedule(kind="cron", expr="0 18 * * *", tz="UTC"),
|
||||
)
|
||||
assert isinstance(result, CronJob)
|
||||
assert result.schedule.expr == "0 18 * * *"
|
||||
assert result.state.next_run_at_ms is not None
|
||||
|
||||
|
||||
def test_update_job_not_found(tmp_path) -> None:
|
||||
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||
result = service.update_job("nonexistent", name="x")
|
||||
assert result == "not_found"
|
||||
|
||||
|
||||
def test_update_job_rejects_system_job(tmp_path) -> None:
|
||||
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||
service.register_system_job(CronJob(
|
||||
id="dream",
|
||||
name="dream",
|
||||
schedule=CronSchedule(kind="cron", expr="0 */2 * * *", tz="UTC"),
|
||||
payload=CronPayload(kind="system_event"),
|
||||
))
|
||||
result = service.update_job("dream", name="hacked")
|
||||
assert result == "protected"
|
||||
assert service.get_job("dream").name == "dream"
|
||||
|
||||
|
||||
def test_update_job_validates_schedule(tmp_path) -> None:
|
||||
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||
job = service.add_job(
|
||||
name="validate",
|
||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||
message="hello",
|
||||
)
|
||||
with pytest.raises(ValueError, match="unknown timezone"):
|
||||
service.update_job(
|
||||
job.id,
|
||||
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="Bad/Zone"),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_job_preserves_run_history(tmp_path) -> None:
|
||||
import asyncio
|
||||
store_path = tmp_path / "cron" / "jobs.json"
|
||||
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
|
||||
job = service.add_job(
|
||||
name="hist",
|
||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||
message="hello",
|
||||
)
|
||||
await service.run_job(job.id)
|
||||
|
||||
result = service.update_job(job.id, name="renamed")
|
||||
assert isinstance(result, CronJob)
|
||||
assert len(result.state.run_history) == 1
|
||||
assert result.state.run_history[0].status == "ok"
|
||||
|
||||
|
||||
def test_update_job_offline_writes_action(tmp_path) -> None:
|
||||
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||
job = service.add_job(
|
||||
name="offline",
|
||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||
message="hello",
|
||||
)
|
||||
service.update_job(job.id, name="updated-offline")
|
||||
|
||||
action_path = tmp_path / "cron" / "action.jsonl"
|
||||
assert action_path.exists()
|
||||
lines = [l for l in action_path.read_text().strip().split("\n") if l]
|
||||
last = json.loads(lines[-1])
|
||||
assert last["action"] == "update"
|
||||
assert last["params"]["name"] == "updated-offline"
|
||||
|
||||
|
||||
def test_update_job_sentinel_channel_and_to(tmp_path) -> None:
|
||||
"""Passing None clears channel/to; omitting leaves them unchanged."""
|
||||
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||
job = service.add_job(
|
||||
name="sentinel",
|
||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||
message="hello",
|
||||
channel="telegram",
|
||||
to="user123",
|
||||
)
|
||||
assert job.payload.channel == "telegram"
|
||||
assert job.payload.to == "user123"
|
||||
|
||||
result = service.update_job(job.id, name="renamed")
|
||||
assert isinstance(result, CronJob)
|
||||
assert result.payload.channel == "telegram"
|
||||
assert result.payload.to == "user123"
|
||||
|
||||
result = service.update_job(job.id, channel=None, to=None)
|
||||
assert isinstance(result, CronJob)
|
||||
assert result.payload.channel is None
|
||||
assert result.payload.to is None
|
||||
|
||||
|
||||
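Note on `test_update_job_sentinel_channel_and_to`: the test relies on `update_job` telling "argument omitted" apart from "argument explicitly `None`". One common way to get that behavior is a module-level sentinel object. The sketch below is a minimal illustration under that assumption. `_UNSET`, `Payload`, and `update_payload` are hypothetical names for this example and are not taken from nanobot's `CronService`.

```python
from dataclasses import dataclass
from typing import Any, Optional

# Hypothetical sentinel: a unique object that no caller can pass by accident,
# so "omitted" and "explicit None" stay distinguishable.
_UNSET: Any = object()


@dataclass
class Payload:
    """Illustrative stand-in for a job payload (not nanobot's CronPayload)."""

    message: str
    channel: Optional[str] = None
    to: Optional[str] = None


def update_payload(
    payload: Payload,
    *,
    message: Any = _UNSET,
    channel: Any = _UNSET,
    to: Any = _UNSET,
) -> Payload:
    """Apply only the fields the caller actually passed."""
    if message is not _UNSET:
        payload.message = message
    if channel is not _UNSET:
        payload.channel = channel  # None here clears the field
    if to is not _UNSET:
        payload.to = to  # None here clears the field
    return payload
```

With this shape, `update_payload(p, message="x")` leaves `channel` and `to` untouched, while `update_payload(p, channel=None, to=None)` clears them, which mirrors the two halves of the test above.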
@pytest.mark.asyncio
|
||||
async def test_list_jobs_during_on_job_does_not_cause_stale_reload(tmp_path) -> None:
|
||||
"""Regression: if the bot calls list_jobs (which reloads from disk) during
|
||||
on_job execution, the in-memory next_run_at_ms update must not be lost.
|
||||
Previously this caused an infinite re-trigger loop."""
|
||||
store_path = tmp_path / "cron" / "jobs.json"
|
||||
execution_count = 0
|
||||
|
||||
async def on_job_that_lists(job):
|
||||
nonlocal execution_count
|
||||
execution_count += 1
|
||||
# Simulate the bot calling cron(action=list) mid-execution
|
||||
service.list_jobs()
|
||||
|
||||
service = CronService(store_path, on_job=on_job_that_lists, max_sleep_ms=100)
|
||||
await service.start()
|
||||
|
||||
# Add two jobs scheduled in the past so they're immediately due
|
||||
now_ms = int(time.time() * 1000)
|
||||
for name in ("job-a", "job-b"):
|
||||
service.add_job(
|
||||
name=name,
|
||||
schedule=CronSchedule(kind="every", every_ms=3_600_000),
|
||||
message="test",
|
||||
)
|
||||
# Force next_run to the past so _on_timer picks them up
|
||||
for job in service._store.jobs:
|
||||
job.state.next_run_at_ms = now_ms - 1000
|
||||
service._save_store()
|
||||
service._arm_timer()
|
||||
|
||||
# Let the timer fire once
|
||||
await asyncio.sleep(0.3)
|
||||
service.stop()
|
||||
|
||||
# Each job should have run exactly once, not looped
|
||||
assert execution_count == 2
|
||||
|
||||
# Verify next_run_at_ms was persisted correctly (in the future)
|
||||
raw = json.loads(store_path.read_text())
|
||||
for j in raw["jobs"]:
|
||||
next_run = j["state"]["nextRunAtMs"]
|
||||
assert next_run is not None
|
||||
assert next_run > now_ms, f"Job '{j['name']}' next_run should be in the future"
|
||||
|
||||
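Note on the two regression tests above (`test_timer_execution_is_not_rolled_back_by_list_jobs_reload` and `test_list_jobs_during_on_job_does_not_cause_stale_reload`): both guard against a disk reload replacing in-memory schedule updates while due jobs are still running. The sketch below shows one possible guard, assuming a simple "executing" flag that suppresses reloads. `TinyStore` and all of its members are hypothetical and only illustrate the idea. The diff shown here does not include the actual `CronService` fix, so this should not be read as its implementation.

```python
import json
from pathlib import Path


class TinyStore:
    """Hypothetical stand-in for a cron store (not nanobot's CronService)."""

    def __init__(self, path: Path) -> None:
        self.path = path
        self.jobs = {}            # job name -> next_run_at_ms
        self._executing = False   # True while due jobs are being run

    def save(self) -> None:
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self.path.write_text(json.dumps(self.jobs))

    def reload(self) -> None:
        # Skip the reload while a run is in flight, so the stale file
        # contents cannot overwrite schedule updates made in memory.
        if self._executing or not self.path.exists():
            return
        self.jobs = json.loads(self.path.read_text())

    def list_jobs(self) -> list:
        self.reload()
        return sorted(self.jobs)

    def run_due(self, now_ms: int, every_ms: int, on_job) -> int:
        """Run every due job once and return how many ran."""
        self._executing = True
        ran = 0
        try:
            for name, next_run in list(self.jobs.items()):
                if next_run <= now_ms:
                    # Advance the schedule before invoking the callback.
                    self.jobs[name] = now_ms + every_ms
                    on_job(name)
                    ran += 1
        finally:
            self._executing = False
        self.save()
        return ran
```

If `reload()` did not check `_executing`, a callback that calls `list_jobs()` would restore the old `next_run_at_ms` values from disk, and the same jobs would look due again on the next timer tick.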
@ -2,9 +2,12 @@
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.cron.types import CronJob, CronJobState, CronPayload, CronSchedule
|
||||
from tests.test_openai_api import pytest_plugins
|
||||
|
||||
|
||||
def _make_tool(tmp_path) -> CronTool:
|
||||
@ -215,8 +218,10 @@ def test_list_at_job_shows_iso_timestamp(tmp_path) -> None:
|
||||
assert "Asia/Shanghai" in result
|
||||
|
||||
|
||||
def test_list_shows_last_run_state(tmp_path) -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_shows_last_run_state(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
tool._cron._running = True
|
||||
job = tool._cron.add_job(
|
||||
name="Stateful job",
|
||||
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"),
|
||||
@ -232,9 +237,10 @@ def test_list_shows_last_run_state(tmp_path) -> None:
|
||||
assert "ok" in result
|
||||
assert "(UTC)" in result
|
||||
|
||||
|
||||
def test_list_shows_error_message(tmp_path) -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_shows_error_message(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
tool._cron._running = True
|
||||
job = tool._cron.add_job(
|
||||
name="Failed job",
|
||||
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"),
|
||||
|
||||
@ -4,6 +4,7 @@ from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
|
||||
def test_custom_provider_parse_handles_empty_choices() -> None:
|
||||
@ -53,3 +54,20 @@ def test_custom_provider_parse_chunks_accepts_plain_text_chunks() -> None:
|
||||
|
||||
assert result.finish_reason == "stop"
|
||||
assert result.content == "hello world"
|
||||
|
||||
|
||||
def test_local_provider_502_error_includes_reachability_hint() -> None:
|
||||
spec = find_by_name("ollama")
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider(api_base="http://localhost:11434/v1", spec=spec)
|
||||
|
||||
result = provider._handle_error(
|
||||
Exception("Error code: 502"),
|
||||
spec=spec,
|
||||
api_base="http://localhost:11434/v1",
|
||||
)
|
||||
|
||||
assert result.finish_reason == "error"
|
||||
assert "local model endpoint" in result.content
|
||||
assert "http://localhost:11434/v1" in result.content
|
||||
assert "proxy/tunnel" in result.content
|
||||
|
||||
197
tests/providers/test_enforce_role_alternation.py
Normal file
@ -0,0 +1,197 @@
|
||||
"""Tests for LLMProvider._enforce_role_alternation."""
|
||||
|
||||
from nanobot.providers.base import LLMProvider
|
||||
|
||||
|
||||
class TestEnforceRoleAlternation:
|
||||
"""Verify trailing-assistant removal and consecutive same-role merging."""
|
||||
|
||||
def test_empty_messages(self):
|
||||
assert LLMProvider._enforce_role_alternation([]) == []
|
||||
|
||||
def test_no_change_needed(self):
|
||||
msgs = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hi"},
|
||||
{"role": "assistant", "content": "Hello!"},
|
||||
{"role": "user", "content": "Bye"},
|
||||
]
|
||||
result = LLMProvider._enforce_role_alternation(msgs)
|
||||
assert len(result) == 4
|
||||
assert result[-1]["role"] == "user"
|
||||
|
||||
def test_trailing_assistant_removed(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": "Hi"},
|
||||
{"role": "assistant", "content": "Hello!"},
|
||||
]
|
||||
result = LLMProvider._enforce_role_alternation(msgs)
|
||||
assert len(result) == 1
|
||||
assert result[0]["role"] == "user"
|
||||
|
||||
def test_multiple_trailing_assistants_removed(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": "Hi"},
|
||||
{"role": "assistant", "content": "A"},
|
||||
{"role": "assistant", "content": "B"},
|
||||
]
|
||||
result = LLMProvider._enforce_role_alternation(msgs)
|
||||
assert len(result) == 1
|
||||
assert result[0]["role"] == "user"
|
||||
|
||||
def test_consecutive_user_messages_merged(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "user", "content": "How are you?"},
|
||||
]
|
||||
result = LLMProvider._enforce_role_alternation(msgs)
|
||||
assert len(result) == 1
|
||||
assert "Hello" in result[0]["content"]
|
||||
assert "How are you?" in result[0]["content"]
|
||||
|
||||
def test_consecutive_assistant_messages_merged(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": "Hi"},
|
||||
{"role": "assistant", "content": "Hello!"},
|
||||
{"role": "assistant", "content": "How can I help?"},
|
||||
{"role": "user", "content": "Thanks"},
|
||||
]
|
||||
result = LLMProvider._enforce_role_alternation(msgs)
|
||||
assert len(result) == 3
|
||||
assert "Hello!" in result[1]["content"]
|
||||
assert "How can I help?" in result[1]["content"]
|
||||
|
||||
def test_system_messages_not_merged(self):
|
||||
msgs = [
|
||||
{"role": "system", "content": "System A"},
|
||||
{"role": "system", "content": "System B"},
|
||||
{"role": "user", "content": "Hi"},
|
||||
]
|
||||
result = LLMProvider._enforce_role_alternation(msgs)
|
||||
assert len(result) == 3
|
||||
assert result[0]["content"] == "System A"
|
||||
assert result[1]["content"] == "System B"
|
||||
|
||||
def test_tool_messages_not_merged(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": "Hi"},
|
||||
{"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]},
|
||||
{"role": "tool", "content": "result1", "tool_call_id": "1"},
|
||||
{"role": "tool", "content": "result2", "tool_call_id": "2"},
|
||||
{"role": "user", "content": "Next"},
|
||||
]
|
||||
result = LLMProvider._enforce_role_alternation(msgs)
|
||||
tool_msgs = [m for m in result if m["role"] == "tool"]
|
||||
assert len(tool_msgs) == 2
|
||||
|
||||
def test_consecutive_assistant_keeps_later_tool_call_message(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": "Hi"},
|
||||
{"role": "assistant", "content": "Previous reply"},
|
||||
{"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]},
|
||||
{"role": "tool", "content": "result1", "tool_call_id": "1"},
|
||||
{"role": "user", "content": "Next"},
|
||||
]
|
||||
result = LLMProvider._enforce_role_alternation(msgs)
|
||||
assert result[1]["role"] == "assistant"
|
||||
assert result[1]["tool_calls"] == [{"id": "1"}]
|
||||
assert result[1]["content"] is None
|
||||
assert result[2]["role"] == "tool"
|
||||
|
||||
def test_consecutive_assistant_does_not_overwrite_existing_tool_call_message(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": "Hi"},
|
||||
{"role": "assistant", "content": None, "tool_calls": [{"id": "1"}]},
|
||||
{"role": "assistant", "content": "Later plain assistant"},
|
||||
{"role": "tool", "content": "result1", "tool_call_id": "1"},
|
||||
{"role": "user", "content": "Next"},
|
||||
]
|
||||
result = LLMProvider._enforce_role_alternation(msgs)
|
||||
assert result[1]["role"] == "assistant"
|
||||
assert result[1]["tool_calls"] == [{"id": "1"}]
|
||||
assert result[1]["content"] is None
|
||||
assert result[2]["role"] == "tool"
|
||||
|
||||
def test_non_string_content_uses_latest(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": [{"type": "text", "text": "A"}]},
|
||||
{"role": "user", "content": "B"},
|
||||
]
|
||||
result = LLMProvider._enforce_role_alternation(msgs)
|
||||
assert len(result) == 1
|
||||
assert result[0]["content"] == "B"
|
||||
|
||||
def test_original_messages_not_mutated(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "user", "content": "World"},
|
||||
]
|
||||
original_first = dict(msgs[0])
|
||||
LLMProvider._enforce_role_alternation(msgs)
|
||||
assert msgs[0] == original_first
|
||||
assert len(msgs) == 2
|
||||
|
||||
def test_trailing_assistant_recovered_as_user_when_only_system_remains(self):
|
||||
"""Subagent result injected as assistant message must not be silently dropped.
|
||||
|
||||
When build_messages(current_role="assistant") produces [system, assistant],
|
||||
_enforce_role_alternation would drop the assistant, leaving only [system].
|
||||
Most providers (e.g. Zhipu/GLM error 1214) reject such requests.
|
||||
The trailing assistant should be recovered as a user message instead.
|
||||
"""
|
||||
msgs = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "assistant", "content": "Subagent completed successfully."},
|
||||
]
|
||||
result = LLMProvider._enforce_role_alternation(msgs)
|
||||
assert len(result) == 2
|
||||
assert result[0]["role"] == "system"
|
||||
assert result[1]["role"] == "user"
|
||||
assert "Subagent completed successfully." in result[1]["content"]
|
||||
|
||||
def test_trailing_assistant_not_recovered_when_user_message_present(self):
|
||||
"""Recovery should NOT happen when a user message already exists."""
|
||||
msgs = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hi"},
|
||||
{"role": "assistant", "content": "Hello!"},
|
||||
]
|
||||
result = LLMProvider._enforce_role_alternation(msgs)
|
||||
assert len(result) == 2
|
||||
assert result[-1]["role"] == "user"
|
||||
|
||||
def test_trailing_assistant_recovered_with_tool_result_preceding(self):
|
||||
"""When only [system, tool, assistant] remains, recovery is not needed
|
||||
because tool messages are valid non-system content."""
|
||||
msgs = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "tool", "content": "result", "tool_call_id": "1"},
|
||||
{"role": "assistant", "content": "Done."},
|
||||
]
|
||||
result = LLMProvider._enforce_role_alternation(msgs)
|
||||
assert len(result) == 2
|
||||
assert result[-1]["role"] == "tool"
|
||||
|
||||
def test_only_assistant_messages(self):
|
||||
msgs = [
|
||||
{"role": "assistant", "content": "A"},
|
||||
{"role": "assistant", "content": "B"},
|
||||
]
|
||||
result = LLMProvider._enforce_role_alternation(msgs)
|
||||
assert result == []
|
||||
|
||||
def test_realistic_conversation(self):
|
||||
msgs = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "What is 2+2?"},
|
||||
{"role": "assistant", "content": "4"},
|
||||
{"role": "user", "content": "And 3+3?"},
|
||||
{"role": "user", "content": "(please be quick)"},
|
||||
{"role": "assistant", "content": "6"},
|
||||
]
|
||||
result = LLMProvider._enforce_role_alternation(msgs)
|
||||
assert len(result) == 4
|
||||
assert result[2]["role"] == "assistant"
|
||||
assert result[3]["role"] == "user"
|
||||
assert "And 3+3?" in result[3]["content"]
|
||||
assert "(please be quick)" in result[3]["content"]
|
||||
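Note on `tests/providers/test_enforce_role_alternation.py`: the body of `LLMProvider._enforce_role_alternation` is not part of this diff, so the following is only a sketch of one function that would satisfy the behaviors the tests above describe. The function name, the `"\n\n"` join separator, and the choice to let a later tool-call message win when both messages carry `tool_calls` are assumptions made for illustration.

```python
from typing import Any


def enforce_role_alternation_sketch(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Illustrative only: merge same-role runs, then drop a trailing assistant."""
    merged: list[dict[str, Any]] = []
    for msg in messages:
        role = msg.get("role")
        prev = merged[-1] if merged else None
        same_role = prev is not None and prev.get("role") == role
        # Only user/assistant runs are collapsed; system and tool messages pass through.
        if same_role and role in ("user", "assistant"):
            if msg.get("tool_calls"):
                # A later tool-call message replaces the earlier plain one.
                merged[-1] = dict(msg)
            elif prev.get("tool_calls"):
                # Keep the earlier tool-call message; drop the later plain one.
                pass
            else:
                old, new = prev.get("content"), msg.get("content")
                if isinstance(old, str) and isinstance(new, str):
                    prev["content"] = f"{old}\n\n{new}"  # separator is an assumption
                else:
                    # Non-string content cannot be joined, so the latest wins.
                    merged[-1] = dict(msg)
            continue
        merged.append(dict(msg))  # shallow copy so the input is not mutated

    dropped = None
    while merged and merged[-1].get("role") == "assistant":
        dropped = merged.pop()
    # If only system messages remain, re-add the dropped text as a user turn.
    if dropped is not None and merged and all(m.get("role") == "system" for m in merged):
        merged.append({"role": "user", "content": dropped.get("content")})
    return merged
```

The recovery step at the end corresponds to `test_trailing_assistant_recovered_as_user_when_only_system_remains`: without it, a `[system, assistant]` input would shrink to `[system]`, which the docstring above says many providers reject.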
@ -10,7 +10,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -54,6 +54,57 @@ def _fake_tool_call_response() -> SimpleNamespace:
|
||||
return SimpleNamespace(choices=[choice], usage=usage)
|
||||
|
||||
|
||||
def _fake_responses_response(content: str = "ok") -> MagicMock:
|
||||
"""Build a minimal Responses API response object."""
|
||||
resp = MagicMock()
|
||||
resp.model_dump.return_value = {
|
||||
"output": [{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": content}],
|
||||
}],
|
||||
"status": "completed",
|
||||
"usage": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
|
||||
}
|
||||
return resp
|
||||
|
||||
|
||||
def _fake_responses_stream(text: str = "ok"):
|
||||
async def _stream():
|
||||
yield SimpleNamespace(type="response.output_text.delta", delta=text)
|
||||
yield SimpleNamespace(
|
||||
type="response.completed",
|
||||
response=SimpleNamespace(
|
||||
status="completed",
|
||||
usage=SimpleNamespace(input_tokens=10, output_tokens=5, total_tokens=15),
|
||||
output=[],
|
||||
),
|
||||
)
|
||||
|
||||
return _stream()
|
||||
|
||||
|
||||
def _fake_chat_stream(text: str = "ok"):
|
||||
async def _stream():
|
||||
yield SimpleNamespace(
|
||||
choices=[SimpleNamespace(finish_reason=None, delta=SimpleNamespace(content=text, reasoning_content=None, tool_calls=None))],
|
||||
usage=None,
|
||||
)
|
||||
yield SimpleNamespace(
|
||||
choices=[SimpleNamespace(finish_reason="stop", delta=SimpleNamespace(content=None, reasoning_content=None, tool_calls=None))],
|
||||
usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15),
|
||||
)
|
||||
|
||||
return _stream()
|
||||
|
||||
|
||||
class _FakeResponsesError(Exception):
|
||||
def __init__(self, status_code: int, text: str):
|
||||
super().__init__(text)
|
||||
self.status_code = status_code
|
||||
self.response = SimpleNamespace(status_code=status_code, text=text, headers={})
|
||||
|
||||
|
||||
class _StalledStream:
|
||||
def __aiter__(self):
|
||||
return self
|
||||
@ -226,6 +277,224 @@ def test_openai_model_passthrough() -> None:
|
||||
assert provider.get_default_model() == "gpt-4o"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_direct_openai_gpt5_uses_responses_api() -> None:
|
||||
mock_chat = AsyncMock(return_value=_fake_chat_response())
|
||||
mock_responses = AsyncMock(return_value=_fake_responses_response("from responses"))
|
||||
spec = find_by_name("openai")
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||
client_instance = MockClient.return_value
|
||||
client_instance.chat.completions.create = mock_chat
|
||||
client_instance.responses.create = mock_responses
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="sk-test-key",
|
||||
default_model="gpt-5-chat",
|
||||
spec=spec,
|
||||
)
|
||||
result = await provider.chat(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
model="gpt-5-chat",
|
||||
)
|
||||
|
||||
assert result.content == "from responses"
|
||||
mock_responses.assert_awaited_once()
|
||||
mock_chat.assert_not_awaited()
|
||||
call_kwargs = mock_responses.call_args.kwargs
|
||||
assert call_kwargs["model"] == "gpt-5-chat"
|
||||
assert call_kwargs["max_output_tokens"] == 4096
|
||||
assert "input" in call_kwargs
|
||||
assert "messages" not in call_kwargs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_direct_openai_reasoning_prefers_responses_api() -> None:
|
||||
mock_chat = AsyncMock(return_value=_fake_chat_response())
|
||||
mock_responses = AsyncMock(return_value=_fake_responses_response("reasoned"))
|
||||
spec = find_by_name("openai")
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||
client_instance = MockClient.return_value
|
||||
client_instance.chat.completions.create = mock_chat
|
||||
client_instance.responses.create = mock_responses
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="sk-test-key",
|
||||
default_model="gpt-4o",
|
||||
spec=spec,
|
||||
)
|
||||
await provider.chat(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
model="gpt-4o",
|
||||
reasoning_effort="medium",
|
||||
)
|
||||
|
||||
mock_responses.assert_awaited_once()
|
||||
mock_chat.assert_not_awaited()
|
||||
call_kwargs = mock_responses.call_args.kwargs
|
||||
assert call_kwargs["reasoning"] == {"effort": "medium"}
|
||||
assert call_kwargs["include"] == ["reasoning.encrypted_content"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_direct_openai_gpt4o_stays_on_chat_completions() -> None:
|
||||
mock_chat = AsyncMock(return_value=_fake_chat_response())
|
||||
mock_responses = AsyncMock(return_value=_fake_responses_response())
|
||||
spec = find_by_name("openai")
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||
client_instance = MockClient.return_value
|
||||
client_instance.chat.completions.create = mock_chat
|
||||
client_instance.responses.create = mock_responses
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="sk-test-key",
|
||||
default_model="gpt-4o",
|
||||
spec=spec,
|
||||
)
|
||||
await provider.chat(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
mock_chat.assert_awaited_once()
|
||||
mock_responses.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openrouter_gpt5_stays_on_chat_completions() -> None:
|
||||
mock_chat = AsyncMock(return_value=_fake_chat_response())
|
||||
mock_responses = AsyncMock(return_value=_fake_responses_response())
|
||||
spec = find_by_name("openrouter")
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||
client_instance = MockClient.return_value
|
||||
client_instance.chat.completions.create = mock_chat
|
||||
client_instance.responses.create = mock_responses
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="sk-or-test-key",
|
||||
api_base="https://openrouter.ai/api/v1",
|
||||
default_model="openai/gpt-5",
|
||||
spec=spec,
|
||||
)
|
||||
await provider.chat(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
model="openai/gpt-5",
|
||||
)
|
||||
|
||||
mock_chat.assert_awaited_once()
|
||||
mock_responses.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_direct_openai_streaming_gpt5_uses_responses_api() -> None:
|
||||
mock_chat = AsyncMock(return_value=_StalledStream())
|
||||
mock_responses = AsyncMock(return_value=_fake_responses_stream("hi"))
|
||||
spec = find_by_name("openai")
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||
client_instance = MockClient.return_value
|
||||
client_instance.chat.completions.create = mock_chat
|
||||
client_instance.responses.create = mock_responses
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="sk-test-key",
|
||||
default_model="gpt-5-chat",
|
||||
spec=spec,
|
||||
)
|
||||
result = await provider.chat_stream(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
model="gpt-5-chat",
|
||||
)
|
||||
|
||||
assert result.content == "hi"
|
||||
assert result.finish_reason == "stop"
|
||||
mock_responses.assert_awaited_once()
|
||||
mock_chat.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_direct_openai_responses_404_falls_back_to_chat_completions() -> None:
|
||||
mock_chat = AsyncMock(return_value=_fake_chat_response("from chat"))
|
||||
mock_responses = AsyncMock(side_effect=_FakeResponsesError(404, "Responses endpoint not supported"))
|
||||
spec = find_by_name("openai")
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||
client_instance = MockClient.return_value
|
||||
client_instance.chat.completions.create = mock_chat
|
||||
client_instance.responses.create = mock_responses
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="sk-test-key",
|
||||
default_model="gpt-5-chat",
|
||||
spec=spec,
|
||||
)
|
||||
result = await provider.chat(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
model="gpt-5-chat",
|
||||
)
|
||||
|
||||
assert result.content == "from chat"
|
||||
mock_responses.assert_awaited_once()
|
||||
mock_chat.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_direct_openai_stream_responses_unsupported_param_falls_back() -> None:
|
||||
mock_chat = AsyncMock(return_value=_fake_chat_stream("fallback stream"))
|
||||
mock_responses = AsyncMock(
|
||||
side_effect=_FakeResponsesError(400, "Unknown parameter: max_output_tokens for Responses API")
|
||||
)
|
||||
spec = find_by_name("openai")
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||
client_instance = MockClient.return_value
|
||||
client_instance.chat.completions.create = mock_chat
|
||||
client_instance.responses.create = mock_responses
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="sk-test-key",
|
||||
default_model="gpt-5-chat",
|
||||
spec=spec,
|
||||
)
|
||||
result = await provider.chat_stream(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
model="gpt-5-chat",
|
||||
)
|
||||
|
||||
assert result.content == "fallback stream"
|
||||
mock_responses.assert_awaited_once()
|
||||
mock_chat.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_direct_openai_responses_rate_limit_does_not_fallback() -> None:
|
||||
mock_chat = AsyncMock(return_value=_fake_chat_response("from chat"))
|
||||
mock_responses = AsyncMock(side_effect=_FakeResponsesError(429, "rate limit"))
|
||||
spec = find_by_name("openai")
|
||||
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||
client_instance = MockClient.return_value
|
||||
client_instance.chat.completions.create = mock_chat
|
||||
client_instance.responses.create = mock_responses
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="sk-test-key",
|
||||
default_model="gpt-5-chat",
|
||||
spec=spec,
|
||||
)
|
||||
result = await provider.chat(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
model="gpt-5-chat",
|
||||
)
|
||||
|
||||
assert result.finish_reason == "error"
|
||||
mock_responses.assert_awaited_once()
|
||||
mock_chat.assert_not_awaited()
|
||||
|
||||
|
||||
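Note on the Responses API tests above: taken together they describe two decisions. The first is when a request prefers the Responses API: direct OpenAI with a `gpt-5` model or with `reasoning_effort` set, but not OpenRouter. The second is when a Responses failure falls back to Chat Completions: on 404 and on an unsupported-parameter 400, but not on 429. The sketch below restates those rules as two small predicates. Both function names, the `gpt-5` prefix check, and the error-text markers are assumptions inferred from the test inputs, not code from `OpenAICompatProvider`.

```python
from typing import Optional


def prefers_responses_api(
    provider: str,
    model: str,
    reasoning_effort: Optional[str] = None,
) -> bool:
    """Hypothetical routing rule inferred from the tests above."""
    if provider != "openai":
        # Gateways such as OpenRouter stay on Chat Completions.
        return False
    # Assumed heuristic: gpt-5 family, or any request that asks for reasoning.
    return model.startswith("gpt-5") or reasoning_effort is not None


def should_fall_back_to_chat(status_code: int, message: str) -> bool:
    """Hypothetical fallback rule: only compatibility errors trigger a retry."""
    if status_code == 404:
        # The endpoint itself is missing.
        return True
    if status_code == 400:
        lowered = message.lower()
        # Marker strings are assumptions based on the error text used in the tests.
        return any(marker in lowered for marker in ("unknown parameter", "unsupported parameter"))
    # Rate limits, auth failures, and server errors surface as errors instead.
    return False
```

Keeping 429 out of the fallback set matches `test_direct_openai_responses_rate_limit_does_not_fallback`: retrying a rate-limited request on a second endpoint would hide the real error and double the request volume.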
def test_openai_compat_supports_temperature_matches_reasoning_model_rules() -> None:
|
||||
assert OpenAICompatProvider._supports_temperature("gpt-4o") is True
|
||||
assert OpenAICompatProvider._supports_temperature("gpt-5-chat") is False
|
||||
@ -263,6 +532,7 @@ def test_openai_compat_preserves_message_level_reasoning_fields() -> None:
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
sanitized = provider._sanitize_messages([
|
||||
{"role": "user", "content": "hi"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "done",
|
||||
@ -276,12 +546,42 @@ def test_openai_compat_preserves_message_level_reasoning_fields() -> None:
|
||||
"extra_content": {"google": {"thought_signature": "sig"}},
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
{"role": "user", "content": "thanks"},
|
||||
])
|
||||
|
||||
- assert sanitized[0]["reasoning_content"] == "hidden"
- assert sanitized[0]["extra_content"] == {"debug": True}
- assert sanitized[0]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
+ assert sanitized[1]["content"] is None
+ assert sanitized[1]["reasoning_content"] == "hidden"
+ assert sanitized[1]["extra_content"] == {"debug": True}
+ assert sanitized[1]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
|
||||
|
||||
|
||||
def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -> None:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
sanitized = provider._sanitize_messages([
|
||||
{"role": "user", "content": "不错"},  # "Not bad"
|
||||
{"role": "assistant", "content": "对,破 4 万指日可待"},  # "Right, passing 40k is just around the corner"
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "<think>我再查一下</think>",  # "Let me check again"
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_function_akxp3wqzn7ph_1",
|
||||
"type": "function",
|
||||
"function": {"name": "exec", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_function_akxp3wqzn7ph_1", "name": "exec", "content": "ok"},
|
||||
{"role": "user", "content": "多少star了呢"},  # "How many stars is it at now?"
|
||||
])
|
||||
|
||||
assert sanitized[1]["role"] == "assistant"
|
||||
assert sanitized[1]["content"] is None
|
||||
assert sanitized[1]["tool_calls"][0]["id"] == "3ec83c30d"
|
||||
assert sanitized[2]["tool_call_id"] == "3ec83c30d"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import copy
|
||||
|
||||
import pytest
|
||||
|
||||
@ -152,7 +153,7 @@ async def test_non_transient_error_with_images_retries_without_images() -> None:
|
||||
LLMResponse(content="ok, no image"),
|
||||
])
|
||||
|
||||
- response = await provider.chat_with_retry(messages=_IMAGE_MSG)
+ response = await provider.chat_with_retry(messages=copy.deepcopy(_IMAGE_MSG))
|
||||
|
||||
assert response.content == "ok, no image"
|
||||
assert provider.calls == 2
|
||||
@ -164,6 +165,24 @@ async def test_non_transient_error_with_images_retries_without_images() -> None:
|
||||
assert any("[image: /media/test.png]" in (b.get("text") or "") for b in content)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_image_retry_mutates_original_messages_in_place() -> None:
|
||||
"""Successful no-image retry should update the caller's message history."""
|
||||
provider = ScriptedProvider([
|
||||
LLMResponse(content="model does not support images", finish_reason="error"),
|
||||
LLMResponse(content="ok, no image"),
|
||||
])
|
||||
messages = copy.deepcopy(_IMAGE_MSG)
|
||||
|
||||
response = await provider.chat_with_retry(messages=messages)
|
||||
|
||||
assert response.content == "ok, no image"
|
||||
content = messages[0]["content"]
|
||||
assert isinstance(content, list)
|
||||
assert all(block.get("type") != "image_url" for block in content)
|
||||
assert any("[image: /media/test.png]" in (block.get("text") or "") for block in content)
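
The in-place behaviour this test checks can be illustrated with a minimal sketch. It is not the provider's actual implementation: `strip_images_in_place`, the `_meta`/`path` layout, and the `[image omitted]` default placeholder are assumptions made for illustration. Only the `[image: <path>]` placeholder format is taken from the assertions above.

```python
from typing import Any, Dict, List


def strip_images_in_place(messages: List[Dict[str, Any]]) -> bool:
    """Replace image_url blocks with text placeholders, mutating `messages`.

    Returns True if at least one block was replaced.
    """
    changed = False
    for message in messages:
        content = message.get("content")
        if not isinstance(content, list):
            continue  # plain-string content carries no image blocks
        for index, block in enumerate(content):
            if isinstance(block, dict) and block.get("type") == "image_url":
                # Assumed metadata layout; the real key names may differ.
                path = (block.get("_meta") or {}).get("path")
                label = f"[image: {path}]" if path else "[image omitted]"
                content[index] = {"type": "text", "text": label}
                changed = True
    return changed


messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "describe"},
            {
                "type": "image_url",
                "image_url": {"url": "data:image/png;base64,AAAA"},
                "_meta": {"path": "/media/test.png"},
            },
        ],
    }
]
changed = strip_images_in_place(messages)
```

Because the caller's list is rewritten rather than copied, tests that reuse a module-level fixture need `copy.deepcopy`, which is what the neighbouring hunks switch to.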
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_transient_error_without_images_no_retry() -> None:
|
||||
"""Non-transient errors without image content are returned immediately."""
|
||||
@ -187,7 +206,7 @@ async def test_image_fallback_returns_error_on_second_failure() -> None:
|
||||
LLMResponse(content="still failing", finish_reason="error"),
|
||||
])
|
||||
|
||||
- response = await provider.chat_with_retry(messages=_IMAGE_MSG)
+ response = await provider.chat_with_retry(messages=copy.deepcopy(_IMAGE_MSG))
|
||||
|
||||
assert provider.calls == 2
|
||||
assert response.content == "still failing"
|
||||
@ -202,7 +221,7 @@ async def test_image_fallback_without_meta_uses_default_placeholder() -> None:
|
||||
LLMResponse(content="ok"),
|
||||
])
|
||||
|
||||
- response = await provider.chat_with_retry(messages=_IMAGE_MSG_NO_META)
+ response = await provider.chat_with_retry(messages=copy.deepcopy(_IMAGE_MSG_NO_META))
|
||||
|
||||
assert response.content == "ok"
|
||||
assert provider.calls == 2
|
||||
|
||||
@ -10,8 +10,7 @@ import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from nanobot.api.server import (
|
||||
API_CHAT_ID,
|
||||
API_SESSION_KEY,
|
||||
_FileSizeExceeded,
|
||||
_parse_json_content,
|
||||
_save_base64_data_url,
|
||||
create_app,
|
||||
@ -91,6 +90,15 @@ def test_save_base64_data_url_handles_unknown_mime(tmp_path) -> None:
|
||||
assert result.endswith(".bin")
|
||||
|
||||
|
||||
def test_save_base64_data_url_rejects_oversized_payload(tmp_path) -> None:
|
||||
"""Base64 uploads should respect the same per-file limit as multipart."""
|
||||
large_payload = base64.b64encode(b"x" * (11 * 1024 * 1024)).decode()
|
||||
data_url = f"data:image/png;base64,{large_payload}"
|
||||
|
||||
with pytest.raises(_FileSizeExceeded, match="10MB limit"):
|
||||
_save_base64_data_url(data_url, tmp_path)
|
||||
|
||||
|
||||
def test_parse_json_content_extracts_text_and_media(tmp_path) -> None:
|
||||
"""Parse JSON with text + base64 image saves image and returns paths."""
|
||||
b64_data = base64.b64encode(b"img").decode()
|
||||
@ -144,6 +152,31 @@ def test_parse_json_content_validates_user_role() -> None:
|
||||
_parse_json_content(body)
|
||||
|
||||
|
||||
def test_parse_json_content_rejects_oversized_base64_file(tmp_path) -> None:
|
||||
"""Oversized JSON data URLs should fail before writing to disk."""
|
||||
large_payload = base64.b64encode(b"x" * (11 * 1024 * 1024)).decode()
|
||||
body = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "describe"},
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{large_payload}"}},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
import os
|
||||
original_cwd = os.getcwd()
|
||||
os.chdir(tmp_path)
|
||||
|
||||
try:
|
||||
with pytest.raises(_FileSizeExceeded, match="10MB limit"):
|
||||
_parse_json_content(body)
|
||||
finally:
|
||||
os.chdir(original_cwd)
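
A rough sketch of how a per-file limit can be enforced before anything is decoded or written, which is the property these two tests check. The names `decode_data_url`, `decoded_size`, and `FileSizeExceededSketch` are hypothetical stand-ins, not the helpers in `nanobot.api.server`. The 10 MB figure mirrors the `"10MB limit"` message matched above.

```python
import base64

MAX_FILE_BYTES = 10 * 1024 * 1024  # mirrors the "10MB limit" in the tests


class FileSizeExceededSketch(Exception):
    """Hypothetical stand-in for the server's size-limit exception."""


def decoded_size(b64_payload: str) -> int:
    """Size in bytes of the decoded payload, computed without decoding.

    Every 4 base64 characters carry 3 bytes; '=' padding carries none.
    """
    return len(b64_payload.rstrip("=")) * 3 // 4


def decode_data_url(data_url: str, limit: int = MAX_FILE_BYTES) -> bytes:
    header, sep, payload = data_url.partition(",")
    if not sep or not header.endswith(";base64"):
        raise ValueError("expected a base64 data URL")
    if decoded_size(payload) > limit:
        # Reject before allocating the decoded bytes or touching the disk.
        raise FileSizeExceededSketch(f"file exceeds {limit} byte limit")
    return base64.b64decode(payload, validate=True)


small_url = "data:image/png;base64," + base64.b64encode(b"img").decode()
small_bytes = decode_data_url(small_url)
```

Checking the encoded length first keeps an oversized upload from being materialised in memory at all.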
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multipart upload tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
@ -64,3 +63,19 @@ def test_build_user_content_mixed_image_and_document(tmp_path: Path) -> None:
|
||||
assert any(b["type"] == "image_url" for b in result)
|
||||
text_parts = [b.get("text", "") for b in result if b.get("type") == "text"]
|
||||
assert any("Report text here" in t for t in text_parts)
|
||||
|
||||
|
||||
def test_build_user_content_skips_document_extraction_errors(tmp_path: Path, monkeypatch) -> None:
|
||||
"""Document extraction errors should not be embedded into the user prompt."""
|
||||
docx_path = tmp_path / "broken.docx"
|
||||
docx_path.write_text("not a real docx", encoding="utf-8")
|
||||
|
||||
builder = _make_builder(tmp_path)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"nanobot.utils.document.extract_text",
|
||||
lambda _path: "[error: failed to extract DOCX: boom]",
|
||||
)
|
||||
|
||||
result = builder._build_user_content("summarize this", [str(docx_path)])
|
||||
assert result == "summarize this"
|
||||
|
||||
@ -1,10 +1,7 @@
|
||||
"""Tests for document text extraction utilities."""
|
||||
|
||||
import io
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.utils.document import (
|
||||
SUPPORTED_EXTENSIONS,
|
||||
_is_text_extension,
|
||||
|
||||
41
tests/test_package_version.py
Normal file
@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
import tomllib
|
||||
|
||||
|
||||
def test_source_checkout_import_uses_pyproject_version_without_metadata() -> None:
|
||||
repo_root = Path(__file__).resolve().parents[1]
|
||||
expected = tomllib.loads((repo_root / "pyproject.toml").read_text(encoding="utf-8"))["project"][
|
||||
"version"
|
||||
]
|
||||
script = textwrap.dedent(
|
||||
f"""
|
||||
import sys
|
||||
import types
|
||||
|
||||
sys.path.insert(0, {str(repo_root)!r})
|
||||
fake = types.ModuleType("nanobot.nanobot")
|
||||
fake.Nanobot = object
|
||||
fake.RunResult = object
|
||||
sys.modules["nanobot.nanobot"] = fake
|
||||
|
||||
import nanobot
|
||||
|
||||
print(nanobot.__version__)
|
||||
"""
|
||||
)
|
||||
|
||||
proc = subprocess.run(
|
||||
[sys.executable, "-S", "-c", script],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
|
||||
assert proc.returncode == 0, proc.stderr
|
||||
assert proc.stdout.strip() == expected
|
||||
31
tests/test_truncate_text_shadowing.py
Normal file
@ -0,0 +1,31 @@
|
||||
import inspect
|
||||
from types import SimpleNamespace
|
||||
|
||||
|
||||
def test_sanitize_persisted_blocks_truncate_text_shadowing_regression() -> None:
|
||||
"""Regression: avoid bool param shadowing imported truncate_text.
|
||||
|
||||
Buggy behavior (historical):
|
||||
- loop.py imports `truncate_text` from helpers
|
||||
- `_sanitize_persisted_blocks(..., truncate_text: bool=...)` uses same name
|
||||
- when called with `truncate_text=True`, function body executes `truncate_text(text, ...)`
|
||||
which resolves to bool and raises `TypeError: 'bool' object is not callable`.
|
||||
|
||||
This test asserts the fixed API exists and truncation works without raising.
|
||||
"""
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
|
||||
sig = inspect.signature(AgentLoop._sanitize_persisted_blocks)
|
||||
assert "should_truncate_text" in sig.parameters
|
||||
assert "truncate_text" not in sig.parameters
|
||||
|
||||
dummy = SimpleNamespace(max_tool_result_chars=5)
|
||||
content = [{"type": "text", "text": "0123456789"}]
|
||||
|
||||
out = AgentLoop._sanitize_persisted_blocks(dummy, content, should_truncate_text=True)
|
||||
assert isinstance(out, list)
|
||||
assert out and out[0]["type"] == "text"
|
||||
assert isinstance(out[0]["text"], str)
|
||||
assert out[0]["text"] != content[0]["text"]
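
The shadowing bug described in the docstring above can be reproduced in isolation. This is a minimal sketch with stand-in functions, not the code in `loop.py`: `truncate_text` here is a simplified helper, and `sanitize_buggy` / `sanitize_fixed` are hypothetical names.

```python
def truncate_text(text: str, limit: int) -> str:
    """Simplified stand-in for the imported helper."""
    return text if len(text) <= limit else text[:limit] + "..."


def sanitize_buggy(text: str, truncate_text: bool = False) -> str:
    # The bool parameter shadows the module-level function of the same name,
    # so the call below resolves to the bool and raises TypeError.
    if truncate_text:
        return truncate_text(text, 5)
    return text


def sanitize_fixed(text: str, should_truncate_text: bool = False) -> str:
    # Renaming the flag leaves `truncate_text` bound to the helper.
    if should_truncate_text:
        return truncate_text(text, 5)
    return text


fixed_result = sanitize_fixed("0123456789", should_truncate_text=True)

try:
    sanitize_buggy("0123456789", truncate_text=True)
    buggy_error = None
except TypeError as exc:
    buggy_error = str(exc)
```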
|
||||
|
||||
423
tests/tools/test_edit_advanced.py
Normal file
@ -0,0 +1,423 @@
|
||||
"""Tests for advanced EditFileTool enhancements inspired by claude-code:
|
||||
- Delete-line newline cleanup
|
||||
- Smart quote normalization (curly ↔ straight)
|
||||
- Quote style preservation in replacements
|
||||
- Indentation preservation when fallback match is trimmed
|
||||
- Trailing whitespace stripping for new_text
|
||||
- File size protection
|
||||
- Stale detection with content-equality fallback
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, _find_match
|
||||
from nanobot.agent.tools import file_state
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_file_state():
|
||||
file_state.clear()
|
||||
yield
|
||||
file_state.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Delete-line newline cleanup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeleteLineCleanup:
|
||||
"""When new_text='' and deleting a line, trailing newline should be consumed."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_line_consumes_trailing_newline(self, tool, tmp_path):
|
||||
f = tmp_path / "a.py"
|
||||
f.write_text("line1\nline2\nline3\n", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="line2", new_text="")
|
||||
assert "Successfully" in result
|
||||
content = f.read_text()
|
||||
# Should not leave a blank line where line2 was
|
||||
assert content == "line1\nline3\n"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_line_with_explicit_newline_in_old_text(self, tool, tmp_path):
|
||||
f = tmp_path / "a.py"
|
||||
f.write_text("line1\nline2\nline3\n", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="line2\n", new_text="")
|
||||
assert "Successfully" in result
|
||||
assert f.read_text() == "line1\nline3\n"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_preserves_content_when_not_trailing_newline(self, tool, tmp_path):
|
||||
"""Deleting a word mid-line should not consume extra characters."""
|
||||
f = tmp_path / "a.py"
|
||||
f.write_text("hello world here\n", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="world ", new_text="")
|
||||
assert "Successfully" in result
|
||||
assert f.read_text() == "hello here\n"
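
One way to get the behaviour these three tests describe is sketched below. It is an assumption about the approach, not the `EditFileTool` implementation, and `delete_span` is a hypothetical name. The newline is consumed only when the deleted text starts a line and ends right before the line break.

```python
def delete_span(content: str, old_text: str) -> str:
    """Delete the first occurrence of old_text, tidying whole-line deletions."""
    start = content.find(old_text)
    if start < 0:
        return content
    end = start + len(old_text)
    at_line_start = start == 0 or content[start - 1] == "\n"
    # Consume the trailing newline only for a full-line deletion that did not
    # already include it, so no blank line is left behind.
    if at_line_start and not old_text.endswith("\n") and content[end:end + 1] == "\n":
        end += 1
    return content[:start] + content[end:]


whole_line = delete_span("line1\nline2\nline3\n", "line2")
explicit_newline = delete_span("line1\nline2\nline3\n", "line2\n")
mid_line = delete_span("hello world here\n", "world ")
```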
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Smart quote normalization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSmartQuoteNormalization:
|
||||
"""_find_match should handle curly ↔ straight quote fallback."""
|
||||
|
||||
def test_curly_double_quotes_match_straight(self):
|
||||
content = 'She said \u201chello\u201d to him'
|
||||
old_text = 'She said "hello" to him'
|
||||
match, count = _find_match(content, old_text)
|
||||
assert match is not None
|
||||
assert count == 1
|
||||
# Returned match should be the ORIGINAL content with curly quotes
|
||||
assert "\u201c" in match
|
||||
|
||||
def test_curly_single_quotes_match_straight(self):
|
||||
content = "it\u2019s a test"
|
||||
old_text = "it's a test"
|
||||
match, count = _find_match(content, old_text)
|
||||
assert match is not None
|
||||
assert count == 1
|
||||
assert "\u2019" in match
|
||||
|
||||
def test_straight_matches_curly_in_old_text(self):
|
||||
content = 'x = "hello"'
|
||||
old_text = 'x = \u201chello\u201d'
|
||||
match, count = _find_match(content, old_text)
|
||||
assert match is not None
|
||||
assert count == 1
|
||||
|
||||
def test_exact_match_still_preferred_over_quote_normalization(self):
|
||||
content = 'x = "hello"'
|
||||
old_text = 'x = "hello"'
|
||||
match, count = _find_match(content, old_text)
|
||||
assert match == old_text
|
||||
assert count == 1
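
The quote fallback exercised above can be sketched as follows. This is not the real `_find_match`, which also has other fallbacks. `find_match_sketch` is a hypothetical name, and the approach shown (a one-to-one character translation so that indices line up with the original text) is an assumption.

```python
from typing import Optional, Tuple

# One-to-one mapping from curly quotes to straight quotes, so offsets in the
# normalized text are also valid offsets in the original text.
_QUOTE_MAP = str.maketrans({
    "\u201c": '"', "\u201d": '"',
    "\u2018": "'", "\u2019": "'",
})


def find_match_sketch(content: str, old_text: str) -> Tuple[Optional[str], int]:
    """Return (matched text as it appears in content, number of matches)."""
    if old_text in content:
        return old_text, content.count(old_text)  # exact match is preferred
    norm_content = content.translate(_QUOTE_MAP)
    norm_old = old_text.translate(_QUOTE_MAP)
    count = norm_content.count(norm_old)
    if count == 0:
        return None, 0
    start = norm_content.find(norm_old)
    # Slice the ORIGINAL content so the caller sees the file's own quote style.
    return content[start:start + len(norm_old)], count


curly_match, curly_count = find_match_sketch(
    "She said \u201chello\u201d to him", 'She said "hello" to him'
)
```

Returning the original slice rather than the normalized one is what lets a later replacement keep the file's quote style.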
|
||||
|
||||
|
||||
class TestQuoteStylePreservation:
|
||||
"""When quote-normalized matching occurs, replacement should preserve actual quote style."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replacement_preserves_curly_double_quotes(self, tool, tmp_path):
|
||||
f = tmp_path / "quotes.txt"
|
||||
f.write_text('message = “hello”\n', encoding="utf-8")
|
||||
result = await tool.execute(
|
||||
path=str(f),
|
||||
old_text='message = "hello"',
|
||||
new_text='message = "goodbye"',
|
||||
)
|
||||
assert "Successfully" in result
|
||||
assert f.read_text(encoding="utf-8") == 'message = “goodbye”\n'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replacement_preserves_curly_apostrophe(self, tool, tmp_path):
|
||||
f = tmp_path / "apostrophe.txt"
|
||||
f.write_text("it’s fine\n", encoding="utf-8")
|
||||
result = await tool.execute(
|
||||
path=str(f),
|
||||
old_text="it's fine",
|
||||
new_text="it's better",
|
||||
)
|
||||
assert "Successfully" in result
|
||||
assert f.read_text(encoding="utf-8") == "it’s better\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Indentation preservation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIndentationPreservation:
|
||||
"""Replacement should keep outer indentation when trim fallback matched."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trim_fallback_preserves_outer_indentation(self, tool, tmp_path):
|
||||
f = tmp_path / "indent.py"
|
||||
        f.write_text(
            "if True:\n"
            "    def foo():\n"
            "        pass\n",
            encoding="utf-8",
        )
        result = await tool.execute(
            path=str(f),
            old_text="def foo():\n    pass",
            new_text="def bar():\n    return 1",
        )
        assert "Successfully" in result
        assert f.read_text(encoding="utf-8") == (
            "if True:\n"
            "    def bar():\n"
            "        return 1\n"
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Failure diagnostics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEditDiagnostics:
|
||||
"""Failure paths should offer actionable hints."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ambiguous_match_reports_candidate_lines(self, tool, tmp_path):
|
||||
f = tmp_path / "dup.py"
|
||||
f.write_text("aaa\nbbb\naaa\nbbb\n", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="aaa\nbbb", new_text="xxx")
|
||||
assert "appears 2 times" in result.lower()
|
||||
assert "line 1" in result.lower()
|
||||
assert "line 3" in result.lower()
|
||||
assert "replace_all=true" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found_reports_whitespace_hint(self, tool, tmp_path):
|
||||
f = tmp_path / "space.py"
|
||||
f.write_text("value  =  1\n", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="value = 1", new_text="value = 2")
|
||||
assert "Error" in result
|
||||
assert "whitespace" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found_reports_case_hint(self, tool, tmp_path):
|
||||
f = tmp_path / "case.py"
|
||||
f.write_text("HelloWorld\n", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="helloworld", new_text="goodbye")
|
||||
assert "Error" in result
|
||||
assert "letter case differs" in result.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Advanced fallback replacement behavior
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAdvancedReplaceAll:
|
||||
"""replace_all should work correctly for fallback-based matches too."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replace_all_preserves_each_match_indentation(self, tool, tmp_path):
|
||||
f = tmp_path / "indent_multi.py"
|
||||
        f.write_text(
            "if a:\n"
            "    def foo():\n"
            "        pass\n"
            "if b:\n"
            "    def foo():\n"
            "        pass\n",
            encoding="utf-8",
        )
        result = await tool.execute(
            path=str(f),
            old_text="def foo():\n    pass",
            new_text="def bar():\n    return 1",
            replace_all=True,
        )
        assert "Successfully" in result
        assert f.read_text(encoding="utf-8") == (
            "if a:\n"
            "    def bar():\n"
            "        return 1\n"
            "if b:\n"
            "    def bar():\n"
            "        return 1\n"
        )
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trim_and_quote_fallback_match_succeeds(self, tool, tmp_path):
|
||||
f = tmp_path / "quote_indent.py"
|
||||
f.write_text(" message = “hello”\n", encoding="utf-8")
|
||||
result = await tool.execute(
|
||||
path=str(f),
|
||||
old_text='message = "hello"',
|
||||
new_text='message = "goodbye"',
|
||||
)
|
||||
assert "Successfully" in result
|
||||
assert f.read_text(encoding="utf-8") == " message = “goodbye”\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trailing whitespace stripping on new_text
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTrailingWhitespaceStrip:
|
||||
"""new_text trailing whitespace should be stripped (except .md files)."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_strips_trailing_whitespace_from_new_text(self, tool, tmp_path):
|
||||
f = tmp_path / "a.py"
|
||||
f.write_text("x = 1\n", encoding="utf-8")
|
||||
result = await tool.execute(
|
||||
path=str(f), old_text="x = 1", new_text="x = 2 \ny = 3 ",
|
||||
)
|
||||
assert "Successfully" in result
|
||||
content = f.read_text()
|
||||
assert "x = 2\ny = 3\n" == content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserves_trailing_whitespace_in_markdown(self, tool, tmp_path):
|
||||
f = tmp_path / "doc.md"
|
||||
f.write_text("# Title\n", encoding="utf-8")
|
||||
# Markdown uses trailing double-space for line breaks
|
||||
        result = await tool.execute(
            path=str(f), old_text="# Title", new_text="# Title  \nSubtitle  ",
        )
        assert "Successfully" in result
        content = f.read_text()
        # Trailing spaces should be preserved for markdown
        assert "Title  " in content
        assert "Subtitle  " in content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File size protection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFileSizeProtection:
|
||||
"""Editing extremely large files should be rejected."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_file_over_size_limit(self, tool, tmp_path):
|
||||
f = tmp_path / "huge.txt"
|
||||
f.write_text("x", encoding="utf-8")
|
||||
# Monkey-patch the file size check by creating a stat mock
|
||||
original_stat = f.stat
|
||||
|
||||
class FakeStat:
|
||||
def __init__(self, real_stat):
|
||||
self._real = real_stat
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._real, name)
|
||||
|
||||
@property
|
||||
def st_size(self):
|
||||
return 2 * 1024 * 1024 * 1024 # 2 GiB
|
||||
|
||||
import unittest.mock
|
||||
with unittest.mock.patch.object(type(f), 'stat', return_value=FakeStat(f.stat())):
|
||||
result = await tool.execute(path=str(f), old_text="x", new_text="y")
|
||||
assert "Error" in result
|
||||
assert "too large" in result.lower() or "size" in result.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stale detection with content-equality fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStaleDetectionContentFallback:
|
||||
"""When mtime changed but file content is unchanged, edit should proceed without warning."""
|
||||
|
||||
@pytest.fixture()
|
||||
def read_tool(self, tmp_path):
|
||||
return ReadFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.fixture()
|
||||
def edit_tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mtime_bump_same_content_no_warning(self, read_tool, edit_tool, tmp_path):
|
||||
f = tmp_path / "a.py"
|
||||
f.write_text("hello world", encoding="utf-8")
|
||||
await read_tool.execute(path=str(f))
|
||||
|
||||
# Touch the file to bump mtime without changing content
|
||||
time.sleep(0.05)
|
||||
original_content = f.read_text()
|
||||
f.write_text(original_content, encoding="utf-8")
|
||||
|
||||
result = await edit_tool.execute(path=str(f), old_text="world", new_text="earth")
|
||||
assert "Successfully" in result
|
||||
# Should NOT warn about modification since content is the same
|
||||
assert "modified" not in result.lower()
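
The content-equality fallback can be sketched like this. It is an illustration under assumptions, not the `file_state` module: `ReadRecord`, `record_read`, and `is_stale` are hypothetical names, and SHA-256 is just one reasonable choice of digest.

```python
import hashlib
import os
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple


@dataclass
class ReadRecord:
    """Hypothetical snapshot taken when a file is read."""
    mtime_ns: int
    digest: str


def _digest(path: Path) -> str:
    return hashlib.sha256(path.read_bytes()).hexdigest()


def record_read(path: Path) -> ReadRecord:
    return ReadRecord(path.stat().st_mtime_ns, _digest(path))


def is_stale(path: Path, record: ReadRecord) -> bool:
    """True only if the file content changed since it was read."""
    if path.stat().st_mtime_ns == record.mtime_ns:
        return False  # cheap path: timestamp untouched
    # mtime moved; fall back to comparing content before warning.
    return _digest(path) != record.digest


def demo() -> Tuple[bool, bool]:
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "a.py"
        path.write_text("hello world", encoding="utf-8")
        record = record_read(path)
        later = record.mtime_ns + 10_000_000_000  # ten seconds later
        os.utime(path, ns=(later, later))  # bump mtime, same content
        touched = is_stale(path, record)
        path.write_text("hello universe", encoding="utf-8")
        os.utime(path, ns=(later, later))
        edited = is_stale(path, record)
    return touched, edited
```

In `demo`, a bare mtime bump is not reported as stale, while a real content change is.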
|
||||
152
tests/tools/test_edit_enhancements.py
Normal file
@ -0,0 +1,152 @@
|
||||
"""Tests for EditFileTool enhancements: read-before-edit tracking, path suggestions,
|
||||
.ipynb detection, and create-file semantics."""
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools import file_state
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_file_state():
|
||||
"""Reset global read-state between tests."""
|
||||
file_state.clear()
|
||||
yield
|
||||
file_state.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Read-before-edit tracking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEditReadTracking:
|
||||
"""edit_file should warn when file hasn't been read first."""
|
||||
|
||||
@pytest.fixture()
|
||||
def read_tool(self, tmp_path):
|
||||
return ReadFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.fixture()
|
||||
def edit_tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_edit_warns_if_file_not_read_first(self, edit_tool, tmp_path):
|
||||
f = tmp_path / "a.py"
|
||||
f.write_text("hello world", encoding="utf-8")
|
||||
result = await edit_tool.execute(path=str(f), old_text="world", new_text="earth")
|
||||
# Should still succeed but include a warning
|
||||
assert "Successfully" in result
|
||||
assert "not been read" in result.lower() or "warning" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_edit_succeeds_cleanly_after_read(self, read_tool, edit_tool, tmp_path):
|
||||
f = tmp_path / "a.py"
|
||||
f.write_text("hello world", encoding="utf-8")
|
||||
await read_tool.execute(path=str(f))
|
||||
result = await edit_tool.execute(path=str(f), old_text="world", new_text="earth")
|
||||
assert "Successfully" in result
|
||||
# No warning when file was read first
|
||||
assert "not been read" not in result.lower()
|
||||
assert f.read_text() == "hello earth"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_edit_warns_if_file_modified_since_read(self, read_tool, edit_tool, tmp_path):
|
||||
f = tmp_path / "a.py"
|
||||
f.write_text("hello world", encoding="utf-8")
|
||||
await read_tool.execute(path=str(f))
|
||||
# External modification
|
||||
f.write_text("hello universe", encoding="utf-8")
|
||||
result = await edit_tool.execute(path=str(f), old_text="universe", new_text="earth")
|
||||
assert "Successfully" in result
|
||||
assert "modified" in result.lower() or "warning" in result.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Create-file semantics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEditCreateFile:
|
||||
"""edit_file with old_text='' creates new file if not exists."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_new_file_with_empty_old_text(self, tool, tmp_path):
|
||||
f = tmp_path / "subdir" / "new.py"
|
||||
result = await tool.execute(path=str(f), old_text="", new_text="print('hi')")
|
||||
assert "created" in result.lower() or "Successfully" in result
|
||||
assert f.exists()
|
||||
assert f.read_text() == "print('hi')"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_fails_if_file_already_exists_and_not_empty(self, tool, tmp_path):
|
||||
f = tmp_path / "existing.py"
|
||||
f.write_text("existing content", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="", new_text="new content")
|
||||
assert "Error" in result or "already exists" in result.lower()
|
||||
# File should be unchanged
|
||||
assert f.read_text() == "existing content"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_succeeds_if_file_exists_but_empty(self, tool, tmp_path):
|
||||
f = tmp_path / "empty.py"
|
||||
f.write_text("", encoding="utf-8")
|
||||
        result = await tool.execute(path=str(f), old_text="", new_text="print('hi')")
        assert "Successfully" in result
        assert f.read_text() == "print('hi')"


# ---------------------------------------------------------------------------
# .ipynb detection
# ---------------------------------------------------------------------------

class TestEditIpynbDetection:
    """edit_file should refuse .ipynb and suggest notebook_edit."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_ipynb_rejected_with_suggestion(self, tool, tmp_path):
        f = tmp_path / "analysis.ipynb"
        f.write_text('{"cells": []}', encoding="utf-8")
        result = await tool.execute(path=str(f), old_text="x", new_text="y")
        assert "notebook" in result.lower()


# ---------------------------------------------------------------------------
# Path suggestion on not-found
# ---------------------------------------------------------------------------

class TestEditPathSuggestion:
    """edit_file should suggest similar paths on not-found."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_suggests_similar_filename(self, tool, tmp_path):
        f = tmp_path / "config.py"
        f.write_text("x = 1", encoding="utf-8")
        # Typo: conifg.py
        result = await tool.execute(
            path=str(tmp_path / "conifg.py"), old_text="x = 1", new_text="x = 2",
        )
        assert "Error" in result
        assert "config.py" in result

    @pytest.mark.asyncio
    async def test_shows_cwd_in_error(self, tool, tmp_path):
        result = await tool.execute(
            path=str(tmp_path / "nonexistent.py"), old_text="a", new_text="b",
        )
        assert "Error" in result
@ -43,3 +43,34 @@ async def test_exec_path_append_preserves_system_path():
    tool = ExecTool(path_append="/opt/custom/bin")
    result = await tool.execute(command="ls /")
    assert "Exit code: 0" in result


@_UNIX_ONLY
@pytest.mark.asyncio
async def test_exec_allowed_env_keys_passthrough(monkeypatch):
    """Env vars listed in allowed_env_keys should be visible to commands."""
    monkeypatch.setenv("MY_CUSTOM_VAR", "hello-from-config")
    tool = ExecTool(allowed_env_keys=["MY_CUSTOM_VAR"])
    result = await tool.execute(command="printenv MY_CUSTOM_VAR")
    assert "hello-from-config" in result


@_UNIX_ONLY
@pytest.mark.asyncio
async def test_exec_allowed_env_keys_does_not_leak_others(monkeypatch):
    """Env vars NOT in allowed_env_keys should still be blocked."""
    monkeypatch.setenv("MY_CUSTOM_VAR", "hello-from-config")
    monkeypatch.setenv("MY_SECRET_VAR", "secret-value")
    tool = ExecTool(allowed_env_keys=["MY_CUSTOM_VAR"])
    result = await tool.execute(command="printenv MY_SECRET_VAR")
    assert "secret-value" not in result


@_UNIX_ONLY
@pytest.mark.asyncio
async def test_exec_allowed_env_keys_missing_var_ignored(monkeypatch):
    """If an allowed key is not set in the parent process, it should be silently skipped."""
    monkeypatch.delenv("NONEXISTENT_VAR_12345", raising=False)
    tool = ExecTool(allowed_env_keys=["NONEXISTENT_VAR_12345"])
    result = await tool.execute(command="printenv NONEXISTENT_VAR_12345")
    assert "Exit code: 1" in result
@ -5,12 +5,18 @@ strategy, and sandbox behaviour per platform — without actually running
platform-specific binaries (all subprocess calls are mocked).
"""

import sys
from unittest.mock import AsyncMock, patch

import pytest

from nanobot.agent.tools.shell import ExecTool

_WINDOWS_ENV_KEYS = {
    "APPDATA", "LOCALAPPDATA", "ProgramData",
    "ProgramFiles", "ProgramFiles(x86)", "ProgramW6432",
}


# ---------------------------------------------------------------------------
# _build_env
@ -21,7 +27,10 @@ class TestBuildEnvUnix:
     def test_expected_keys(self):
         with patch("nanobot.agent.tools.shell._IS_WINDOWS", False):
             env = ExecTool()._build_env()
-        assert set(env) == {"HOME", "LANG", "TERM"}
+        expected = {"HOME", "LANG", "TERM"}
+        assert expected <= set(env)
+        if sys.platform != "win32":
+            assert set(env) == expected

     def test_home_from_environ(self, monkeypatch):
         monkeypatch.setenv("HOME", "/Users/dev")
@ -45,6 +54,7 @@ class TestBuildEnvWindows:
    _EXPECTED_KEYS = {
        "SYSTEMROOT", "COMSPEC", "USERPROFILE", "HOMEDRIVE",
        "HOMEPATH", "TEMP", "TMP", "PATHEXT", "PATH",
        *_WINDOWS_ENV_KEYS,
    }

    def test_expected_keys(self):
@ -67,3 +67,118 @@ async def test_exec_blocks_chained_internal_url():
        command="echo start && curl http://169.254.169.254/latest/meta-data/ && echo done"
    )
    assert "Error" in result


# --- #2989: block writes to nanobot internal state files -----------------


@pytest.mark.parametrize(
    "command",
    [
        "cat foo >> history.jsonl",
        "echo '{}' > history.jsonl",
        "echo '{}' > memory/history.jsonl",
        "echo '{}' > ./workspace/memory/history.jsonl",
        "tee -a history.jsonl < foo",
        "tee history.jsonl",
        "cp /tmp/fake.jsonl history.jsonl",
        "mv backup.jsonl memory/history.jsonl",
        "dd if=/dev/zero of=memory/history.jsonl",
        "sed -i 's/old/new/' history.jsonl",
        "echo x > .dream_cursor",
        "cp /tmp/x memory/.dream_cursor",
    ],
)
def test_exec_blocks_writes_to_history_jsonl(command):
    """Direct writes to history.jsonl / .dream_cursor must be blocked (#2989)."""
    tool = ExecTool()
    result = tool._guard_command(command, "/tmp")
    assert result is not None
    assert "dangerous pattern" in result.lower()


@pytest.mark.parametrize(
    "command",
    [
        "cat history.jsonl",
        "wc -l history.jsonl",
        "tail -n 5 history.jsonl",
        "grep foo history.jsonl",
        "cp history.jsonl /tmp/history.backup",
        "ls memory/",
        "echo history.jsonl",
    ],
)
def test_exec_allows_reads_of_history_jsonl(command):
    """Read-only access to history.jsonl must still be allowed."""
    tool = ExecTool()
    result = tool._guard_command(command, "/tmp")
    assert result is None


# --- #2826: working_dir must not escape the configured workspace ---------


@pytest.mark.asyncio
async def test_exec_blocks_working_dir_outside_workspace(tmp_path):
    """An LLM-supplied working_dir outside the workspace must be rejected."""
    workspace = tmp_path / "workspace"
    workspace.mkdir()
    tool = ExecTool(working_dir=str(workspace), restrict_to_workspace=True)
    result = await tool.execute(command="rm calendar.ics", working_dir="/etc")
    assert "outside the configured workspace" in result


@pytest.mark.asyncio
async def test_exec_blocks_absolute_rm_via_hijacked_working_dir(tmp_path):
    """Regression for #2826: `rm /abs/path` via working_dir hijack."""
    workspace = tmp_path / "workspace"
    workspace.mkdir()
    victim_dir = tmp_path / "outside"
    victim_dir.mkdir()
    victim = victim_dir / "file.ics"
    victim.write_text("data")

    tool = ExecTool(working_dir=str(workspace), restrict_to_workspace=True)
    result = await tool.execute(
        command=f"rm {victim}",
        working_dir=str(victim_dir),
    )
    assert "outside the configured workspace" in result
    assert victim.exists(), "victim file must not have been deleted"


@pytest.mark.asyncio
async def test_exec_allows_working_dir_within_workspace(tmp_path):
    """A working_dir that is a subdirectory of the workspace is fine."""
    workspace = tmp_path / "workspace"
    subdir = workspace / "project"
    subdir.mkdir(parents=True)
    tool = ExecTool(working_dir=str(workspace), restrict_to_workspace=True, timeout=5)
    result = await tool.execute(command="echo ok", working_dir=str(subdir))
    assert "ok" in result
    assert "outside the configured workspace" not in result


@pytest.mark.asyncio
async def test_exec_allows_working_dir_equal_to_workspace(tmp_path):
    """Passing working_dir equal to the workspace root must be allowed."""
    workspace = tmp_path / "workspace"
    workspace.mkdir()
    tool = ExecTool(working_dir=str(workspace), restrict_to_workspace=True, timeout=5)
    result = await tool.execute(command="echo ok", working_dir=str(workspace))
    assert "ok" in result
    assert "outside the configured workspace" not in result


@pytest.mark.asyncio
async def test_exec_ignores_workspace_check_when_not_restricted(tmp_path):
    """Without restrict_to_workspace, the LLM may still choose any working_dir."""
    workspace = tmp_path / "workspace"
    workspace.mkdir()
    other = tmp_path / "other"
    other.mkdir()
    tool = ExecTool(working_dir=str(workspace), restrict_to_workspace=False, timeout=5)
    result = await tool.execute(command="echo ok", working_dir=str(other))
    assert "ok" in result
    assert "outside the configured workspace" not in result
@ -271,15 +271,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_raw_names(
 ) -> None:
     fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
     registry = ToolRegistry()
-    stack = AsyncExitStack()
-    await stack.__aenter__()
-    try:
-        await connect_mcp_servers(
-            {"test": MCPServerConfig(command="fake", enabled_tools=["demo"])},
-            registry,
-            stack,
-        )
-    finally:
+    stacks = await connect_mcp_servers(
+        {"test": MCPServerConfig(command="fake", enabled_tools=["demo"])},
+        registry,
+    )
+    for stack in stacks.values():
         await stack.aclose()

     assert registry.tool_names == ["mcp_test_demo"]
@ -291,15 +287,11 @@ async def test_connect_mcp_servers_enabled_tools_defaults_to_all(
 ) -> None:
     fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
     registry = ToolRegistry()
-    stack = AsyncExitStack()
-    await stack.__aenter__()
-    try:
-        await connect_mcp_servers(
-            {"test": MCPServerConfig(command="fake")},
-            registry,
-            stack,
-        )
-    finally:
+    stacks = await connect_mcp_servers(
+        {"test": MCPServerConfig(command="fake")},
+        registry,
+    )
+    for stack in stacks.values():
         await stack.aclose()

     assert registry.tool_names == ["mcp_test_demo", "mcp_test_other"]
@ -311,15 +303,11 @@ async def test_connect_mcp_servers_enabled_tools_supports_wrapped_names(
 ) -> None:
     fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
     registry = ToolRegistry()
-    stack = AsyncExitStack()
-    await stack.__aenter__()
-    try:
-        await connect_mcp_servers(
-            {"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])},
-            registry,
-            stack,
-        )
-    finally:
+    stacks = await connect_mcp_servers(
+        {"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_demo"])},
+        registry,
+    )
+    for stack in stacks.values():
         await stack.aclose()

     assert registry.tool_names == ["mcp_test_demo"]
@ -331,15 +319,11 @@ async def test_connect_mcp_servers_enabled_tools_empty_list_registers_none(
 ) -> None:
     fake_mcp_runtime["session"] = _make_fake_session(["demo", "other"])
     registry = ToolRegistry()
-    stack = AsyncExitStack()
-    await stack.__aenter__()
-    try:
-        await connect_mcp_servers(
-            {"test": MCPServerConfig(command="fake", enabled_tools=[])},
-            registry,
-            stack,
-        )
-    finally:
+    stacks = await connect_mcp_servers(
+        {"test": MCPServerConfig(command="fake", enabled_tools=[])},
+        registry,
+    )
+    for stack in stacks.values():
         await stack.aclose()

     assert registry.tool_names == []
@ -358,15 +342,11 @@ async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries(

     monkeypatch.setattr("nanobot.agent.tools.mcp.logger.warning", _warning)

-    stack = AsyncExitStack()
-    await stack.__aenter__()
-    try:
-        await connect_mcp_servers(
-            {"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])},
-            registry,
-            stack,
-        )
-    finally:
+    stacks = await connect_mcp_servers(
+        {"test": MCPServerConfig(command="fake", enabled_tools=["unknown"])},
+        registry,
+    )
+    for stack in stacks.values():
         await stack.aclose()

     assert registry.tool_names == []
@ -376,6 +356,73 @@ async def test_connect_mcp_servers_enabled_tools_warns_on_unknown_entries(
    assert "Available wrapped names: mcp_test_demo" in warnings[-1]


@pytest.mark.asyncio
async def test_connect_mcp_servers_logs_stdio_pollution_hint(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    messages: list[str] = []

    def _error(message: str, *args: object) -> None:
        messages.append(message.format(*args))

    @asynccontextmanager
    async def _broken_stdio_client(_params: object):
        raise RuntimeError("Parse error: Unexpected token 'INFO' before JSON-RPC headers")
        yield  # pragma: no cover

    monkeypatch.setattr(sys.modules["mcp.client.stdio"], "stdio_client", _broken_stdio_client)
    monkeypatch.setattr("nanobot.agent.tools.mcp.logger.error", _error)

    registry = ToolRegistry()
    stacks = await connect_mcp_servers({"gh": MCPServerConfig(command="github-mcp")}, registry)

    assert stacks == {}
    assert messages
    assert "stdio protocol pollution" in messages[-1]
    assert "stdout" in messages[-1]
    assert "stderr" in messages[-1]


@pytest.mark.asyncio
async def test_connect_mcp_servers_one_failure_does_not_block_others(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    sessions = {"good": _make_fake_session(["demo"])}

    class _SelectiveClientSession:
        def __init__(self, read: object, _write: object) -> None:
            self._session = sessions[read]

        async def __aenter__(self) -> object:
            return self._session

        async def __aexit__(self, exc_type, exc, tb) -> bool:
            return False

    @asynccontextmanager
    async def _selective_stdio_client(params: object):
        if params.command == "bad":
            raise RuntimeError("boom")
        yield params.command, object()

    monkeypatch.setattr(sys.modules["mcp"], "ClientSession", _SelectiveClientSession)
    monkeypatch.setattr(sys.modules["mcp.client.stdio"], "stdio_client", _selective_stdio_client)

    registry = ToolRegistry()
    stacks = await connect_mcp_servers(
        {
            "good": MCPServerConfig(command="good"),
            "bad": MCPServerConfig(command="bad"),
        },
        registry,
    )
    for stack in stacks.values():
        await stack.aclose()

    assert registry.tool_names == ["mcp_good_demo"]
    assert set(stacks) == {"good"}


# ---------------------------------------------------------------------------
# MCPResourceWrapper tests
# ---------------------------------------------------------------------------
@ -389,9 +436,7 @@ def _make_resource_def(
     return SimpleNamespace(name=name, uri=uri, description=description)


-def _make_resource_wrapper(
-    session: object, *, timeout: float = 0.1
-) -> MCPResourceWrapper:
+def _make_resource_wrapper(session: object, *, timeout: float = 0.1) -> MCPResourceWrapper:
     return MCPResourceWrapper(session, "srv", _make_resource_def(), resource_timeout=timeout)

@ -434,9 +479,7 @@ async def test_resource_wrapper_execute_handles_timeout() -> None:
         await asyncio.sleep(1)
         return SimpleNamespace(contents=[])

-    wrapper = _make_resource_wrapper(
-        SimpleNamespace(read_resource=read_resource), timeout=0.01
-    )
+    wrapper = _make_resource_wrapper(SimpleNamespace(read_resource=read_resource), timeout=0.01)
     result = await wrapper.execute()
     assert result == "(MCP resource read timed out after 0.01s)"
@ -464,20 +507,14 @@ def _make_prompt_def(
     return SimpleNamespace(name=name, description=description, arguments=arguments)


-def _make_prompt_wrapper(
-    session: object, *, timeout: float = 0.1
-) -> MCPPromptWrapper:
-    return MCPPromptWrapper(
-        session, "srv", _make_prompt_def(), prompt_timeout=timeout
-    )
+def _make_prompt_wrapper(session: object, *, timeout: float = 0.1) -> MCPPromptWrapper:
+    return MCPPromptWrapper(session, "srv", _make_prompt_def(), prompt_timeout=timeout)


 def test_prompt_wrapper_properties() -> None:
     arg1 = SimpleNamespace(name="topic", required=True)
     arg2 = SimpleNamespace(name="style", required=False)
-    wrapper = MCPPromptWrapper(
-        None, "myserver", _make_prompt_def(arguments=[arg1, arg2])
-    )
+    wrapper = MCPPromptWrapper(None, "myserver", _make_prompt_def(arguments=[arg1, arg2]))
     assert wrapper.name == "mcp_myserver_prompt_myprompt"
     assert "[MCP Prompt]" in wrapper.description
     assert "A test prompt" in wrapper.description
@ -528,9 +565,7 @@ async def test_prompt_wrapper_execute_handles_timeout() -> None:
         await asyncio.sleep(1)
         return SimpleNamespace(messages=[])

-    wrapper = _make_prompt_wrapper(
-        SimpleNamespace(get_prompt=get_prompt), timeout=0.01
-    )
+    wrapper = _make_prompt_wrapper(SimpleNamespace(get_prompt=get_prompt), timeout=0.01)
     result = await wrapper.execute()
     assert result == "(MCP prompt call timed out after 0.01s)"
@ -616,15 +651,11 @@ async def test_connect_registers_resources_and_prompts(
         prompt_names=["prompt_c"],
     )
     registry = ToolRegistry()
-    stack = AsyncExitStack()
-    await stack.__aenter__()
-    try:
-        await connect_mcp_servers(
-            {"test": MCPServerConfig(command="fake")},
-            registry,
-            stack,
-        )
-    finally:
+    stacks = await connect_mcp_servers(
+        {"test": MCPServerConfig(command="fake")},
+        registry,
+    )
+    for stack in stacks.values():
         await stack.aclose()

     assert "mcp_test_tool_a" in registry.tool_names
@ -1,5 +1,6 @@
"""Test message tool suppress logic for final replies."""

import asyncio
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
@ -86,6 +87,42 @@ class TestMessageToolSuppressLogic:
        assert result is not None
        assert "Hello" in result.content

    @pytest.mark.asyncio
    async def test_injected_followup_with_message_tool_does_not_emit_empty_fallback(
        self, tmp_path: Path
    ) -> None:
        loop = _make_loop(tmp_path)
        tool_call = ToolCallRequest(
            id="call1", name="message",
            arguments={"content": "Tool reply", "channel": "feishu", "chat_id": "chat123"},
        )
        calls = iter([
            LLMResponse(content="First answer", tool_calls=[]),
            LLMResponse(content="", tool_calls=[tool_call]),
            LLMResponse(content="", tool_calls=[]),
            LLMResponse(content="", tool_calls=[]),
            LLMResponse(content="", tool_calls=[]),
        ])
        loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
        loop.tools.get_definitions = MagicMock(return_value=[])

        sent: list[OutboundMessage] = []
        mt = loop.tools.get("message")
        if isinstance(mt, MessageTool):
            mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))

        pending_queue = asyncio.Queue()
        await pending_queue.put(
            InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="follow-up")
        )

        msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Start")
        result = await loop._process_message(msg, pending_queue=pending_queue)

        assert len(sent) == 1
        assert sent[0].content == "Tool reply"
        assert result is None

    async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
        loop = _make_loop(tmp_path)
        tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
@ -107,7 +144,7 @@ class TestMessageToolSuppressLogic:
         async def on_progress(content: str, *, tool_hint: bool = False) -> None:
             progress.append((content, tool_hint))

-        final_content, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
+        final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress)

         assert final_content == "Done"
         assert progress == [
147
tests/tools/test_notebook_tool.py
Normal file
@ -0,0 +1,147 @@
"""Tests for NotebookEditTool — Jupyter .ipynb editing."""

import json

import pytest

from nanobot.agent.tools.notebook import NotebookEditTool


def _make_notebook(cells: list[dict] | None = None, nbformat: int = 4, nbformat_minor: int = 5) -> dict:
    """Build a minimal valid .ipynb structure."""
    return {
        "nbformat": nbformat,
        "nbformat_minor": nbformat_minor,
        "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}},
        "cells": cells or [],
    }


def _code_cell(source: str, cell_id: str | None = None) -> dict:
    cell = {"cell_type": "code", "source": source, "metadata": {}, "outputs": [], "execution_count": None}
    if cell_id:
        cell["id"] = cell_id
    return cell


def _md_cell(source: str, cell_id: str | None = None) -> dict:
    cell = {"cell_type": "markdown", "source": source, "metadata": {}}
    if cell_id:
        cell["id"] = cell_id
    return cell


def _write_nb(tmp_path, name: str, nb: dict) -> str:
    p = tmp_path / name
    p.write_text(json.dumps(nb), encoding="utf-8")
    return str(p)


class TestNotebookEdit:

    @pytest.fixture()
    def tool(self, tmp_path):
        return NotebookEditTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_replace_cell_content(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("print('hello')"), _code_cell("x = 1")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        result = await tool.execute(path=path, cell_index=0, new_source="print('world')")
        assert "Successfully" in result
        saved = json.loads((tmp_path / "test.ipynb").read_text())
        assert saved["cells"][0]["source"] == "print('world')"
        assert saved["cells"][1]["source"] == "x = 1"

    @pytest.mark.asyncio
    async def test_insert_cell_after_target(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("cell 0"), _code_cell("cell 1")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        result = await tool.execute(path=path, cell_index=0, new_source="inserted", edit_mode="insert")
        assert "Successfully" in result
        saved = json.loads((tmp_path / "test.ipynb").read_text())
        assert len(saved["cells"]) == 3
        assert saved["cells"][0]["source"] == "cell 0"
        assert saved["cells"][1]["source"] == "inserted"
        assert saved["cells"][2]["source"] == "cell 1"

    @pytest.mark.asyncio
    async def test_delete_cell(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("A"), _code_cell("B"), _code_cell("C")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        result = await tool.execute(path=path, cell_index=1, edit_mode="delete")
        assert "Successfully" in result
        saved = json.loads((tmp_path / "test.ipynb").read_text())
        assert len(saved["cells"]) == 2
        assert saved["cells"][0]["source"] == "A"
        assert saved["cells"][1]["source"] == "C"

    @pytest.mark.asyncio
    async def test_create_new_notebook_from_scratch(self, tool, tmp_path):
        path = str(tmp_path / "new.ipynb")
        result = await tool.execute(path=path, cell_index=0, new_source="# Hello", edit_mode="insert", cell_type="markdown")
        assert "Successfully" in result or "created" in result.lower()
        saved = json.loads((tmp_path / "new.ipynb").read_text())
        assert saved["nbformat"] == 4
        assert len(saved["cells"]) == 1
        assert saved["cells"][0]["cell_type"] == "markdown"
        assert saved["cells"][0]["source"] == "# Hello"

    @pytest.mark.asyncio
    async def test_invalid_cell_index_error(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("only cell")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        result = await tool.execute(path=path, cell_index=5, new_source="x")
        assert "Error" in result

    @pytest.mark.asyncio
    async def test_non_ipynb_rejected(self, tool, tmp_path):
        f = tmp_path / "script.py"
        f.write_text("pass")
        result = await tool.execute(path=str(f), cell_index=0, new_source="x")
        assert "Error" in result
        assert ".ipynb" in result

    @pytest.mark.asyncio
    async def test_preserves_metadata_and_outputs(self, tool, tmp_path):
        cell = _code_cell("old")
        cell["outputs"] = [{"output_type": "stream", "text": "hello\n"}]
        cell["execution_count"] = 42
        nb = _make_notebook([cell])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        await tool.execute(path=path, cell_index=0, new_source="new")
        saved = json.loads((tmp_path / "test.ipynb").read_text())
        assert saved["metadata"]["kernelspec"]["language"] == "python"

    @pytest.mark.asyncio
    async def test_nbformat_45_generates_cell_id(self, tool, tmp_path):
        nb = _make_notebook([], nbformat_minor=5)
        path = _write_nb(tmp_path, "test.ipynb", nb)
        await tool.execute(path=path, cell_index=0, new_source="x = 1", edit_mode="insert")
        saved = json.loads((tmp_path / "test.ipynb").read_text())
        assert "id" in saved["cells"][0]
        assert len(saved["cells"][0]["id"]) > 0

    @pytest.mark.asyncio
    async def test_insert_with_cell_type_markdown(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("code")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        await tool.execute(path=path, cell_index=0, new_source="# Title", edit_mode="insert", cell_type="markdown")
        saved = json.loads((tmp_path / "test.ipynb").read_text())
        assert saved["cells"][1]["cell_type"] == "markdown"

    @pytest.mark.asyncio
    async def test_invalid_edit_mode_rejected(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("code")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        result = await tool.execute(path=path, cell_index=0, new_source="x", edit_mode="replcae")
        assert "Error" in result
        assert "edit_mode" in result

    @pytest.mark.asyncio
    async def test_invalid_cell_type_rejected(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("code")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        result = await tool.execute(path=path, cell_index=0, new_source="x", cell_type="raw")
        assert "Error" in result
        assert "cell_type" in result
180
tests/tools/test_read_enhancements.py
Normal file
@ -0,0 +1,180 @@
|
||||
"""Tests for ReadFileTool enhancements: description fix, read dedup, PDF support, device blacklist."""
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.filesystem import ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools import file_state
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_file_state():
|
||||
file_state.clear()
|
||||
yield
|
||||
file_state.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Description fix
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadDescriptionFix:
|
||||
|
||||
def test_description_mentions_image_support(self):
|
||||
tool = ReadFileTool()
|
||||
assert "image" in tool.description.lower()
|
||||
|
||||
def test_description_no_longer_says_cannot_read_images(self):
|
||||
tool = ReadFileTool()
|
||||
assert "cannot read binary files or images" not in tool.description.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Read deduplication
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadDedup:
|
||||
"""Same file + same offset/limit + unchanged mtime -> short stub."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return ReadFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.fixture()
|
||||
def write_tool(self, tmp_path):
|
||||
return WriteFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_second_read_returns_unchanged_stub(self, tool, tmp_path):
|
||||
f = tmp_path / "data.txt"
|
||||
f.write_text("\n".join(f"line {i}" for i in range(100)), encoding="utf-8")
|
||||
first = await tool.execute(path=str(f))
|
||||
assert "line 0" in first
|
||||
second = await tool.execute(path=str(f))
|
||||
assert "unchanged" in second.lower()
|
||||
# Stub should not contain file content
|
||||
assert "line 0" not in second
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_after_external_modification_returns_full(self, tool, tmp_path):
|
||||
f = tmp_path / "data.txt"
|
||||
f.write_text("original", encoding="utf-8")
|
||||
await tool.execute(path=str(f))
|
||||
# Modify the file externally
|
||||
f.write_text("modified content", encoding="utf-8")
|
||||
second = await tool.execute(path=str(f))
|
||||
assert "modified content" in second
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_offset_returns_full(self, tool, tmp_path):
|
||||
f = tmp_path / "data.txt"
|
||||
f.write_text("\n".join(f"line {i}" for i in range(1, 21)), encoding="utf-8")
|
||||
await tool.execute(path=str(f), offset=1, limit=5)
|
||||
second = await tool.execute(path=str(f), offset=6, limit=5)
|
||||
# Different offset → full read, not stub
|
||||
assert "line 6" in second
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_read_after_write_returns_full_content(self, tool, write_tool, tmp_path):
|
||||
f = tmp_path / "fresh.txt"
|
||||
result = await write_tool.execute(path=str(f), content="hello")
|
||||
assert "Successfully" in result
|
||||
read_result = await tool.execute(path=str(f))
|
||||
assert "hello" in read_result
|
||||
assert "unchanged" not in read_result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dedup_does_not_apply_to_images(self, tool, tmp_path):
|
||||
f = tmp_path / "img.png"
|
||||
f.write_bytes(b"\x89PNG\r\n\x1a\nfake-png-data")
|
||||
first = await tool.execute(path=str(f))
|
||||
assert isinstance(first, list)
|
||||
second = await tool.execute(path=str(f))
|
||||
# Images should always return full content blocks, not a stub
|
||||
assert isinstance(second, list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# PDF support
# ---------------------------------------------------------------------------

class TestReadPdf:

    @pytest.fixture()
    def tool(self, tmp_path):
        return ReadFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_pdf_returns_text_content(self, tool, tmp_path):
        fitz = pytest.importorskip("fitz")
        pdf_path = tmp_path / "test.pdf"
        doc = fitz.open()
        page = doc.new_page()
        page.insert_text((72, 72), "Hello PDF World")
        doc.save(str(pdf_path))
        doc.close()

        result = await tool.execute(path=str(pdf_path))
        assert "Hello PDF World" in result

    @pytest.mark.asyncio
    async def test_pdf_pages_parameter(self, tool, tmp_path):
        fitz = pytest.importorskip("fitz")
        pdf_path = tmp_path / "multi.pdf"
        doc = fitz.open()
        for i in range(5):
            page = doc.new_page()
            page.insert_text((72, 72), f"Page {i + 1} content")
        doc.save(str(pdf_path))
        doc.close()

        result = await tool.execute(path=str(pdf_path), pages="2-3")
        assert "Page 2 content" in result
        assert "Page 3 content" in result
        assert "Page 1 content" not in result

    @pytest.mark.asyncio
    async def test_pdf_file_not_found_error(self, tool, tmp_path):
        result = await tool.execute(path=str(tmp_path / "nope.pdf"))
        assert "Error" in result
        assert "not found" in result


# ---------------------------------------------------------------------------
# Device path blacklist
# ---------------------------------------------------------------------------

class TestReadDeviceBlacklist:

    @pytest.fixture()
    def tool(self):
        return ReadFileTool()

    @pytest.mark.asyncio
    async def test_dev_random_blocked(self, tool):
        result = await tool.execute(path="/dev/random")
        assert "Error" in result
        assert "blocked" in result.lower() or "device" in result.lower()

    @pytest.mark.asyncio
    async def test_dev_urandom_blocked(self, tool):
        result = await tool.execute(path="/dev/urandom")
        assert "Error" in result

    @pytest.mark.asyncio
    async def test_dev_zero_blocked(self, tool):
        result = await tool.execute(path="/dev/zero")
        assert "Error" in result

    @pytest.mark.asyncio
    async def test_proc_fd_blocked(self, tool):
        result = await tool.execute(path="/proc/self/fd/0")
        assert "Error" in result

    @pytest.mark.asyncio
    async def test_symlink_to_dev_zero_blocked(self, tmp_path):
        tool = ReadFileTool(workspace=tmp_path)
        link = tmp_path / "zero-link"
        link.symlink_to("/dev/zero")
        result = await tool.execute(path=str(link))
        assert "Error" in result
        assert "blocked" in result.lower() or "device" in result.lower()
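The device-path tests above only pin down observable behavior: reads of `/dev/random`, `/dev/urandom`, `/dev/zero`, `/proc/self/fd/0`, and a symlink to `/dev/zero` must return an error. A minimal sketch of a check that would satisfy them follows. The function name `is_blocked_device_path`, the blocklist contents, and the `/proc/.../fd/N` pattern are assumptions for illustration, not the actual `ReadFileTool` implementation.

```python
import os
import re

# Hypothetical blocklist: only the devices exercised by the tests above.
_BLOCKED_DEVICES = frozenset({"/dev/random", "/dev/urandom", "/dev/zero"})
# Hypothetical pattern for /proc/self/fd/N and /proc/<pid>/fd/N aliases.
_PROC_FD = re.compile(r"^/proc/(?:self|\d+)/fd/\d+$")


def is_blocked_device_path(path: str) -> bool:
    """Return True if the path, or the target it resolves to, is blocked."""
    # Check both the literal path and its symlink-resolved form, so a link
    # inside the workspace cannot be used to reach a blocked device.
    candidates = {os.path.normpath(path), os.path.realpath(path)}
    return any(
        c in _BLOCKED_DEVICES or _PROC_FD.match(c) is not None
        for c in candidates
    )
```

Checking the literal path as well as the resolved one matters for `/proc/self/fd/0`, which usually resolves to a terminal or pipe rather than to a path on the blocklist.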
@ -323,3 +323,27 @@ async def test_subagent_registers_grep_and_glob(tmp_path: Path) -> None:

    assert "grep" in captured["tool_names"]
    assert "glob" in captured["tool_names"]


def test_subagent_prompt_respects_disabled_skills(tmp_path: Path) -> None:
    bus = MessageBus()
    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    skills_dir = tmp_path / "skills"
    (skills_dir / "alpha").mkdir(parents=True)
    (skills_dir / "alpha" / "SKILL.md").write_text("# Alpha\n\nhidden\n", encoding="utf-8")
    (skills_dir / "beta").mkdir(parents=True)
    (skills_dir / "beta" / "SKILL.md").write_text("# Beta\n\nshown\n", encoding="utf-8")

    mgr = SubagentManager(
        provider=provider,
        workspace=tmp_path,
        bus=bus,
        max_tool_result_chars=4096,
        disabled_skills=["alpha"],
    )

    prompt = mgr._build_subagent_prompt()

    assert "alpha" not in prompt
    assert "beta" in prompt
@ -47,3 +47,27 @@ def test_get_definitions_orders_builtins_then_mcp_tools() -> None:
        "mcp_fs_list",
        "mcp_git_status",
    ]


def test_prepare_call_read_file_rejects_non_object_params_with_actionable_hint() -> None:
    registry = ToolRegistry()
    registry.register(_FakeTool("read_file"))

    tool, params, error = registry.prepare_call("read_file", ["foo.txt"])

    assert tool is None
    assert params == ["foo.txt"]
    assert error is not None
    assert "must be a JSON object" in error
    assert "Use named parameters" in error


def test_prepare_call_other_tools_keep_generic_object_validation() -> None:
    registry = ToolRegistry()
    registry.register(_FakeTool("grep"))

    tool, params, error = registry.prepare_call("grep", ["TODO"])

    assert tool is not None
    assert params == ["TODO"]
    assert error == "Error: Invalid parameters for tool 'grep': parameters must be an object, got list"
@ -1,7 +1,5 @@
"""Tests for multi-provider web search."""

import asyncio

import httpx
import pytest

@ -20,6 +18,25 @@ def _response(status: int = 200, json: dict | None = None) -> httpx.Response:
    return r


def test_duckduckgo_search_is_exclusive():
    tool = _tool(provider="duckduckgo")
    assert tool.exclusive is True
    assert tool.concurrency_safe is False


def test_brave_with_api_key_remains_concurrency_safe():
    tool = _tool(provider="brave", api_key="brave-key")
    assert tool.exclusive is False
    assert tool.concurrency_safe is True


def test_brave_without_api_key_is_treated_as_duckduckgo_for_concurrency(monkeypatch):
    monkeypatch.delenv("BRAVE_API_KEY", raising=False)
    tool = _tool(provider="brave", api_key="")
    assert tool.exclusive is True
    assert tool.concurrency_safe is False


@pytest.mark.asyncio
async def test_brave_search(monkeypatch):
    async def mock_get(self, url, **kw):
@ -79,7 +96,6 @@ async def test_duckduckgo_search(monkeypatch):
    import nanobot.agent.tools.web as web_mod
    monkeypatch.setattr(web_mod, "DDGS", MockDDGS, raising=False)

    from ddgs import DDGS
    monkeypatch.setattr("ddgs.DDGS", MockDDGS)

    tool = _tool(provider="duckduckgo")
@ -120,6 +136,27 @@ async def test_jina_search(monkeypatch):
    assert "https://jina.ai" in result


@pytest.mark.asyncio
async def test_kagi_search(monkeypatch):
    async def mock_get(self, url, **kw):
        assert "kagi.com/api/v0/search" in url
        assert kw["headers"]["Authorization"] == "Bot kagi-key"
        assert kw["params"] == {"q": "test", "limit": 2}
        return _response(json={
            "data": [
                {"t": 0, "title": "Kagi Result", "url": "https://kagi.com", "snippet": "Premium search"},
                {"t": 1, "list": ["ignored related search"]},
            ]
        })

    monkeypatch.setattr(httpx.AsyncClient, "get", mock_get)
    tool = _tool(provider="kagi", api_key="kagi-key")
    result = await tool.execute(query="test", count=2)
    assert "Kagi Result" in result
    assert "https://kagi.com" in result
    assert "ignored related search" not in result


@pytest.mark.asyncio
async def test_unknown_provider():
    tool = _tool(provider="unknown")
@ -189,6 +226,23 @@ async def test_jina_422_falls_back_to_duckduckgo(monkeypatch):
    assert "DuckDuckGo fallback" in result


@pytest.mark.asyncio
async def test_kagi_fallback_to_duckduckgo_when_no_key(monkeypatch):
    class MockDDGS:
        def __init__(self, **kw):
            pass

        def text(self, query, max_results=5):
            return [{"title": "Fallback", "href": "https://ddg.example", "body": "DuckDuckGo fallback"}]

    monkeypatch.setattr("ddgs.DDGS", MockDDGS)
    monkeypatch.delenv("KAGI_API_KEY", raising=False)

    tool = _tool(provider="kagi", api_key="")
    result = await tool.execute(query="test")
    assert "Fallback" in result


@pytest.mark.asyncio
async def test_jina_search_uses_path_encoded_query(monkeypatch):
    calls = {}
@ -227,5 +281,3 @@ async def test_duckduckgo_timeout_returns_error(monkeypatch):
    result = await tool.execute(query="test")
    gate.set()
    assert "Error" in result

65
tests/utils/test_strip_think.py
Normal file
@ -0,0 +1,65 @@
import pytest

from nanobot.utils.helpers import strip_think


class TestStripThinkTag:
    """Test <thought>...</thought> block stripping (Gemma 4 and similar models)."""

    def test_closed_tag(self):
        assert strip_think("Hello <thought>reasoning</thought> World") == "Hello World"

    def test_unclosed_trailing_tag(self):
        assert strip_think("<thought>ongoing...") == ""

    def test_multiline_tag(self):
        assert strip_think("<thought>\nline1\nline2\n</thought>End") == "End"

    def test_tag_with_nested_angle_brackets(self):
        text = "<thought>a < 3 and b > 2</thought>result"
        assert strip_think(text) == "result"

    def test_multiple_tag_blocks(self):
        text = "A<thought>x</thought>B<thought>y</thought>C"
        assert strip_think(text) == "ABC"

    def test_tag_only_whitespace_inside(self):
        assert strip_think("before<thought> </thought>after") == "beforeafter"

    def test_self_closing_tag_not_matched(self):
        assert strip_think("<thought/>some text") == "<thought/>some text"

    def test_normal_text_unchanged(self):
        assert strip_think("Just normal text") == "Just normal text"

    def test_empty_string(self):
        assert strip_think("") == ""


class TestStripThinkFalsePositive:
    """Ensure mid-content <think>/<thought> tags are NOT stripped (#3004)."""

    def test_backtick_think_tag_preserved(self):
        text = "*Think Stripping:* A new utility to strip `<think>` tags from output."
        assert strip_think(text) == text

    def test_prose_think_tag_preserved(self):
        text = "The model emits <think> at the start of its response."
        assert strip_think(text) == text

    def test_code_block_think_tag_preserved(self):
        text = "Example:\n```\ntext = re.sub(r\"<think>[\\s\\S]*\", \"\", text)\n```\nDone."
        assert strip_think(text) == text

    def test_backtick_thought_tag_preserved(self):
        text = "Gemma 4 uses `<thought>` blocks for reasoning."
        assert strip_think(text) == text

    def test_prefix_unclosed_think_still_stripped(self):
        assert strip_think("<think>reasoning without closing") == ""

    def test_prefix_unclosed_think_with_whitespace(self):
        assert strip_think(" <think>reasoning...") == ""

    def test_prefix_unclosed_thought_still_stripped(self):
        assert strip_think("<thought>reasoning without closing") == ""
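The two test classes above describe a contract: closed `<think>`/`<thought>` blocks are removed wherever they appear, while an unclosed opening tag is treated as reasoning only when it starts the text. One way to express that rule is sketched below. The name `strip_think_sketch` and both regular expressions are assumptions for illustration; this is not the code in `nanobot.utils.helpers`.

```python
import re

# Closed block anywhere in the text; \1 forces the closing tag to match
# the opening one. [\s\S] lets the lazy match span newlines.
_CLOSED = re.compile(r"<(think|thought)>[\s\S]*?</\1>")
# Unclosed block only at the very start (leading whitespace allowed).
# Without re.MULTILINE, ^ anchors at the start of the string only.
_UNCLOSED_PREFIX = re.compile(r"^\s*<(?:think|thought)>[\s\S]*$")


def strip_think_sketch(text: str) -> str:
    """Hypothetical stand-in for strip_think, for illustration only."""
    text = _CLOSED.sub("", text)
    text = _UNCLOSED_PREFIX.sub("", text)
    return text.strip()
```

Anchoring the unclosed case to the start of the string is what keeps prose such as "The model emits <think> at the start of its response." intact, since the tag there is neither closed nor leading.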