Mirror of https://github.com/HKUDS/nanobot.git (synced 2026-04-09 20:53:38 +00:00)

Merge remote-tracking branch 'origin/main' into nightly

This commit is contained in: commit ba38d41ad1
.gitattributes (vendored, Normal file, 2 lines changed)
@@ -0,0 +1,2 @@
+# Ensure shell scripts always use LF line endings (Docker/Linux compat)
+*.sh text eol=lf
.github/workflows/ci.yml (vendored, 3 lines changed)
@@ -30,5 +30,8 @@ jobs:
       - name: Install all dependencies
        run: uv sync --all-extras

+      - name: Lint with ruff
+        run: uv run ruff check nanobot --select F401,F841
+
       - name: Run tests
        run: uv run pytest tests/
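The two selected ruff rules are narrow: F401 flags imports that are never used, and F841 flags local variables that are assigned but never read. A minimal sketch of code that would trip both (module and function names here are hypothetical, not from nanobot):

```python
import json  # F401: imported but never used in this module


def parse_payload(payload: str) -> dict:
    decoded = payload.strip()  # F841: assigned but never used
    return {"raw": payload}


print(parse_payload("  hi  "))
```

Running `ruff check --select F401,F841` on a file like this reports both lines; fixing them (drop the import, drop the dead assignment) makes the check pass.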
Dockerfile (16 lines changed)
@@ -2,7 +2,7 @@ FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim

 # Install Node.js 20 for the WhatsApp bridge
 RUN apt-get update && \
-    apt-get install -y --no-install-recommends curl ca-certificates gnupg git openssh-client && \
+    apt-get install -y --no-install-recommends curl ca-certificates gnupg git bubblewrap openssh-client && \
     mkdir -p /etc/apt/keyrings && \
     curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg && \
     echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" > /etc/apt/sources.list.d/nodesource.list && \
@@ -32,11 +32,19 @@ RUN git config --global --add url."https://github.com/".insteadOf ssh://git@gith
     npm install && npm run build
 WORKDIR /app

-# Create config directory
-RUN mkdir -p /root/.nanobot
+# Create non-root user and config directory
+RUN useradd -m -u 1000 -s /bin/bash nanobot && \
+    mkdir -p /home/nanobot/.nanobot && \
+    chown -R nanobot:nanobot /home/nanobot /app
+
+COPY entrypoint.sh /usr/local/bin/entrypoint.sh
+RUN sed -i 's/\r$//' /usr/local/bin/entrypoint.sh && chmod +x /usr/local/bin/entrypoint.sh
+
+USER nanobot
+ENV HOME=/home/nanobot

 # Gateway default port
 EXPOSE 18790

-ENTRYPOINT ["nanobot"]
+ENTRYPOINT ["entrypoint.sh"]
 CMD ["status"]
README.md (103 lines changed)
@@ -1,39 +1,44 @@
 <div align="center">
 <img src="nanobot_logo.png" alt="nanobot" width="500">
-<h1>nanobot: Ultra-Lightweight Personal AI Assistant</h1>
+<h1>nanobot: Ultra-Lightweight Personal AI Agent</h1>
 <p>
 <a href="https://pypi.org/project/nanobot-ai/"><img src="https://img.shields.io/pypi/v/nanobot-ai" alt="PyPI"></a>
 <a href="https://pepy.tech/project/nanobot-ai"><img src="https://static.pepy.tech/badge/nanobot-ai" alt="Downloads"></a>
 <img src="https://img.shields.io/badge/python-≥3.11-blue" alt="Python">
 <img src="https://img.shields.io/badge/license-MIT-green" alt="License">
 <a href="https://nanobot.wiki/docs/0.1.5/getting-started/nanobot-overview"><img src="https://img.shields.io/badge/Docs-nanobot.wiki-blue?style=flat&logo=readthedocs&logoColor=white" alt="Docs"></a>
 <a href="./COMMUNICATION.md"><img src="https://img.shields.io/badge/Feishu-Group-E9DBFC?style=flat&logo=feishu&logoColor=white" alt="Feishu"></a>
 <a href="./COMMUNICATION.md"><img src="https://img.shields.io/badge/WeChat-Group-C5EAB4?style=flat&logo=wechat&logoColor=white" alt="WeChat"></a>
 <a href="https://discord.gg/MnCvHqpUGB"><img src="https://img.shields.io/badge/Discord-Community-5865F2?style=flat&logo=discord&logoColor=white" alt="Discord"></a>
 </p>
 </div>

-🐈 **nanobot** is an **ultra-lightweight** personal AI assistant inspired by [OpenClaw](https://github.com/openclaw/openclaw).
+🐈 **nanobot** is an **ultra-lightweight** personal AI agent inspired by [OpenClaw](https://github.com/openclaw/openclaw).

-⚡️ Delivers core agent functionality with **99% fewer lines of code** than OpenClaw.
+⚡️ Delivers core agent functionality with **99% fewer lines of code**.

 📏 Real-time line count: run `bash core_agent_lines.sh` to verify anytime.

 ## 📢 News

-- **2026-04-02** 🧱 **Long-running tasks** run more reliably — core runtime hardening.
+- **2026-04-05** 🚀 Released **v0.1.5** — sturdier long-running tasks, Dream two-stage memory, production-ready sandboxing and programming Agent SDK. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.5) for details.
+- **2026-04-04** 🚀 Jinja2 response templates, Dream memory hardened, smarter retry handling.
+- **2026-04-03** 🧠 Xiaomi MiMo provider, chain-of-thought reasoning visible, Telegram UX polish.
+- **2026-04-02** 🧱 Long-running tasks run more reliably — core runtime hardening.
 - **2026-04-01** 🔑 GitHub Copilot auth restored; stricter workspace paths; OpenRouter Claude caching fix.
 - **2026-03-31** 🛰️ WeChat multimodal alignment, Discord/Matrix polish, Python SDK facade, MCP and tool fixes.
 - **2026-03-30** 🧩 OpenAI-compatible API tightened; composable agent lifecycle hooks.
 - **2026-03-29** 💬 WeChat voice, typing, QR/media resilience; fixed-session OpenAI-compatible API.
 - **2026-03-28** 📚 Provider docs refresh; skill template wording fix.
 - **2026-03-27** 🚀 Released **v0.1.4.post6** — architecture decoupling, litellm removal, end-to-end streaming, WeChat channel, and a security fix. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post6) for details.
-- **2026-03-26** 🏗️ Agent runner extracted and lifecycle hooks unified; stream delta coalescing at boundaries.
-- **2026-03-25** 🌏 StepFun provider, configurable timezone, Gemini thought signatures.
-- **2026-03-24** 🔧 WeChat compatibility, Feishu CardKit streaming, test suite restructured.

 <details>
 <summary>Earlier news</summary>

+- **2026-03-26** 🏗️ Agent runner extracted and lifecycle hooks unified; stream delta coalescing at boundaries.
+- **2026-03-25** 🌏 StepFun provider, configurable timezone, Gemini thought signatures.
+- **2026-03-24** 🔧 WeChat compatibility, Feishu CardKit streaming, test suite restructured.
 - **2026-03-23** 🔧 Command routing refactored for plugins, WhatsApp/WeChat media, unified channel login CLI.
 - **2026-03-22** ⚡ End-to-end streaming, WeChat channel, Anthropic cache optimization, `/status` command.
 - **2026-03-21** 🔒 Replace `litellm` with native `openai` + `anthropic` SDKs. Please see [commit](https://github.com/HKUDS/nanobot/commit/3dfdab7).
@@ -91,7 +96,7 @@

 ## Key Features of nanobot:

-🪶 **Ultra-Lightweight**: A super lightweight implementation of OpenClaw — 99% smaller, significantly faster.
+🪶 **Ultra-Lightweight**: A lightweight implementation built for stable, long-running AI agents.

 🔬 **Research-Ready**: Clean, readable code that's easy to understand, modify, and extend for research.
@@ -140,7 +145,7 @@
 <tr>
 <td align="center"><p align="center"><img src="case/search.gif" width="180" height="400"></p></td>
 <td align="center"><p align="center"><img src="case/code.gif" width="180" height="400"></p></td>
-<td align="center"><p align="center"><img src="case/scedule.gif" width="180" height="400"></p></td>
+<td align="center"><p align="center"><img src="case/schedule.gif" width="180" height="400"></p></td>
 <td align="center"><p align="center"><img src="case/memory.gif" width="180" height="400"></p></td>
 </tr>
 <tr>
@@ -252,7 +257,7 @@ Configure these **two parts** in your config (other options have defaults).
 nanobot agent
 ```

-That's it! You have a working AI assistant in 2 minutes.
+That's it! You have a working AI agent in 2 minutes.

 ## 💬 Chat Apps
@@ -433,9 +438,11 @@ pip install nanobot-ai[matrix]

 - You need:
   - `userId` (example: `@nanobot:matrix.org`)
   - `accessToken`
-  - `deviceId` (recommended so sync tokens can be restored across restarts)
-- You can obtain these from your homeserver login API (`/_matrix/client/v3/login`) or from your client's advanced session settings.
+  - `password`

+(Note: `accessToken` and `deviceId` are still supported for legacy reasons, but
+for reliable encryption, password login is recommended instead. If the
+`password` is provided, `accessToken` and `deviceId` will be ignored.)

 **3. Configure**
@@ -446,8 +453,7 @@ pip install nanobot-ai[matrix]
 "enabled": true,
 "homeserver": "https://matrix.org",
 "userId": "@nanobot:matrix.org",
-"accessToken": "syt_xxx",
-"deviceId": "NANOBOT01",
+"password": "mypasswordhere",
 "e2eeEnabled": true,
 "allowFrom": ["@your_user:matrix.org"],
 "groupPolicy": "open",
@@ -459,7 +465,7 @@ pip install nanobot-ai[matrix]
 }
 ```

-> Keep a persistent `matrix-store` and stable `deviceId` — encrypted session state is lost if these change across restarts.
+> Keep a persistent `matrix-store` — encrypted session state is lost if it changes across restarts.

 | Option | Description |
 |--------|-------------|
@@ -720,6 +726,9 @@ Give nanobot its own email account. It polls **IMAP** for incoming mail and repl
 > - `allowFrom`: Add your email address. Use `["*"]` to accept emails from anyone.
 > - `smtpUseTls` and `smtpUseSsl` default to `true` / `false` respectively, which is correct for Gmail (port 587 + STARTTLS). No need to set them explicitly.
 > - Set `"autoReplyEnabled": false` if you only want to read/analyze emails without sending automatic replies.
+> - `allowedAttachmentTypes`: Save inbound attachments matching these MIME types — `["*"]` for all, e.g. `["application/pdf", "image/*"]` (default `[]` = disabled).
+> - `maxAttachmentSize`: Max size per attachment in bytes (default `2000000` / 2MB).
+> - `maxAttachmentsPerEmail`: Max attachments to save per email (default `5`).

 ```json
 {
@@ -736,7 +745,8 @@ Give nanobot its own email account. It polls **IMAP** for incoming mail and repl
 "smtpUsername": "my-nanobot@gmail.com",
 "smtpPassword": "your-app-password",
 "fromAddress": "my-nanobot@gmail.com",
-"allowFrom": ["your-real-email@gmail.com"]
+"allowFrom": ["your-real-email@gmail.com"],
+"allowedAttachmentTypes": ["application/pdf", "image/*"]
 }
 }
 }
@@ -861,10 +871,45 @@ Config file: `~/.nanobot/config.json`
 > run `nanobot onboard`, then answer `N` when asked whether to overwrite the config.
 > nanobot will merge in missing default fields and keep your current settings.

+### Environment Variables for Secrets
+
+Instead of storing secrets directly in `config.json`, you can use `${VAR_NAME}` references that are resolved from environment variables at startup:
+
+```json
+{
+  "channels": {
+    "telegram": { "token": "${TELEGRAM_TOKEN}" },
+    "email": {
+      "imapPassword": "${IMAP_PASSWORD}",
+      "smtpPassword": "${SMTP_PASSWORD}"
+    }
+  },
+  "providers": {
+    "groq": { "apiKey": "${GROQ_API_KEY}" }
+  }
+}
+```
+
+For **systemd** deployments, use `EnvironmentFile=` in the service unit to load variables from a file that only the deploying user can read:
+
+```ini
+# /etc/systemd/system/nanobot.service (excerpt)
+[Service]
+EnvironmentFile=/home/youruser/nanobot_secrets.env
+User=nanobot
+ExecStart=...
+```
+
+```bash
+# /home/youruser/nanobot_secrets.env (mode 600, owned by youruser)
+TELEGRAM_TOKEN=your-token-here
+IMAP_PASSWORD=your-password-here
+```
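The `${VAR_NAME}` substitution can be pictured as a recursive walk over the parsed JSON. The sketch below is an illustration of that idea, not nanobot's actual implementation; it assumes whole-value references are looked up in `os.environ` and that unknown variables are left untouched:

```python
import os
import re

# Matches a ${VAR_NAME} reference (uppercase letters, digits, underscores).
_REF = re.compile(r"\$\{([A-Z0-9_]+)\}")


def resolve_secrets(node):
    """Recursively replace ${VAR} references with environment values."""
    if isinstance(node, dict):
        return {k: resolve_secrets(v) for k, v in node.items()}
    if isinstance(node, list):
        return [resolve_secrets(v) for v in node]
    if isinstance(node, str):
        # Unknown variables fall back to the original "${...}" text.
        return _REF.sub(lambda m: os.environ.get(m.group(1), m.group(0)), node)
    return node


os.environ["TELEGRAM_TOKEN"] = "tok-123"
config = {"channels": {"telegram": {"token": "${TELEGRAM_TOKEN}"}}}
print(resolve_secrets(config))
```

Leaving unresolved references in place (rather than substituting an empty string) makes a missing environment variable visible at startup instead of silently producing an empty token.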

 ### Providers

 > [!TIP]
-> - **Groq** provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed.
+> - **Voice transcription**: Voice messages (Telegram, WhatsApp) are automatically transcribed using Whisper. By default Groq is used (free tier). Set `"transcriptionProvider": "openai"` under `channels` to use OpenAI Whisper instead — the API key is picked from the matching provider config.
 > - **MiniMax Coding Plan**: Exclusive discount links for the nanobot community: [Overseas](https://platform.minimax.io/subscribe/coding-plan?code=9txpdXw04g&source=link) · [Mainland China](https://platform.minimaxi.com/subscribe/token-plan?code=GILTJpMTqZ&source=link)
 > - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config.
 > - **VolcEngine / BytePlus Coding Plan**: Use dedicated providers `volcengineCodingPlan` or `byteplusCodingPlan` instead of the pay-per-use `volcengine` / `byteplus` providers.
@@ -880,9 +925,9 @@ Config file: `~/.nanobot/config.json`
 | `byteplus` | LLM (VolcEngine international, pay-per-use) | [Coding Plan](https://www.byteplus.com/en/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) · [byteplus.com](https://www.byteplus.com) |
 | `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
 | `azure_openai` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) |
-| `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
+| `openai` | LLM + Voice transcription (Whisper) | [platform.openai.com](https://platform.openai.com) |
 | `deepseek` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
-| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
+| `groq` | LLM + Voice transcription (Whisper, default) | [console.groq.com](https://console.groq.com) |
 | `minimax` | LLM (MiniMax direct) | [platform.minimaxi.com](https://platform.minimaxi.com) |
 | `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) |
 | `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) |
@@ -1197,6 +1242,7 @@ Global settings that apply to all channels. Configure under the `channels` secti
 "sendProgress": true,
 "sendToolHints": false,
 "sendMaxRetries": 3,
+"transcriptionProvider": "groq",
 "telegram": { ... }
 }
 }
@@ -1207,6 +1253,7 @@ Global settings that apply to all channels. Configure under the `channels` secti
 | `sendProgress` | `true` | Stream agent's text progress to the channel |
 | `sendToolHints` | `false` | Stream tool-call hints (e.g. `read_file("…")`) |
 | `sendMaxRetries` | `3` | Max delivery attempts per outbound message, including the initial send (0-10 configured, minimum 1 actual attempt) |
+| `transcriptionProvider` | `"groq"` | Voice transcription backend: `"groq"` (free tier, default) or `"openai"`. API key is auto-resolved from the matching provider config. |
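The `sendMaxRetries` bounds described above read as a clamp: the configured value is capped to the 0-10 range, but at least one delivery attempt always happens. A small sketch of that interpretation (the helper name is hypothetical, not nanobot's API):

```python
def effective_attempts(send_max_retries: int) -> int:
    """Clamp the configured value to 0-10, then guarantee >= 1 attempt."""
    configured = max(0, min(10, send_max_retries))
    return max(1, configured)


print(effective_attempts(3), effective_attempts(0), effective_attempts(99))
```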

 #### Retry Behavior
@@ -1434,16 +1481,19 @@ MCP tools are automatically discovered and registered on startup. The LLM can us
 ### Security

 > [!TIP]
-> For production deployments, set `"restrictToWorkspace": true` in your config to sandbox the agent.
+> For production deployments, set `"restrictToWorkspace": true` and `"tools.exec.sandbox": "bwrap"` in your config to sandbox the agent.
 > In `v0.1.4.post3` and earlier, an empty `allowFrom` allowed all senders. Since `v0.1.4.post4`, empty `allowFrom` denies all access by default. To allow all senders, set `"allowFrom": ["*"]`.

 | Option | Default | Description |
 |--------|---------|-------------|
 | `tools.restrictToWorkspace` | `false` | When `true`, restricts **all** agent tools (shell, file read/write/edit, list) to the workspace directory. Prevents path traversal and out-of-scope access. |
+| `tools.exec.sandbox` | `""` | Sandbox backend for shell commands. Set to `"bwrap"` to wrap exec calls in a [bubblewrap](https://github.com/containers/bubblewrap) sandbox — the process can only see the workspace (read-write) and media directory (read-only); config files and API keys are hidden. Automatically enables `restrictToWorkspace` for file tools. **Linux only** — requires `bwrap` installed (`apt install bubblewrap`; pre-installed in the Docker image). Not available on macOS or Windows (bwrap depends on Linux kernel namespaces). |
 | `tools.exec.enable` | `true` | When `false`, the shell `exec` tool is not registered at all. Use this to completely disable shell command execution. |
 | `tools.exec.pathAppend` | `""` | Extra directories to append to `PATH` when running shell commands (e.g. `/usr/sbin` for `ufw`). |
 | `channels.*.allowFrom` | `[]` (deny all) | Whitelist of user IDs. Empty denies all; use `["*"]` to allow everyone. |

+**Docker security**: The official Docker image runs as a non-root user (`nanobot`, UID 1000) with bubblewrap pre-installed. When using `docker-compose.yml`, the container drops all Linux capabilities except `SYS_ADMIN` (required for bwrap's namespace isolation).
 ### Timezone

@@ -1763,7 +1813,8 @@ print(resp.choices[0].message.content)
 ## 🐳 Docker

 > [!TIP]
-> The `-v ~/.nanobot:/root/.nanobot` flag mounts your local config directory into the container, so your config and workspace persist across container restarts.
+> The `-v ~/.nanobot:/home/nanobot/.nanobot` flag mounts your local config directory into the container, so your config and workspace persist across container restarts.
+> The container runs as user `nanobot` (UID 1000). If you get **Permission denied**, fix ownership on the host first: `sudo chown -R 1000:1000 ~/.nanobot`, or pass `--user $(id -u):$(id -g)` to match your host UID. Podman users can use `--userns=keep-id` instead.

 ### Docker Compose
@@ -1786,17 +1837,17 @@ docker compose down # stop
 docker build -t nanobot .

 # Initialize config (first time only)
-docker run -v ~/.nanobot:/root/.nanobot --rm nanobot onboard
+docker run -v ~/.nanobot:/home/nanobot/.nanobot --rm nanobot onboard

 # Edit config on host to add API keys
 vim ~/.nanobot/config.json

 # Run gateway (connects to enabled channels, e.g. Telegram/Discord/Mochat)
-docker run -v ~/.nanobot:/root/.nanobot -p 18790:18790 nanobot gateway
+docker run -v ~/.nanobot:/home/nanobot/.nanobot -p 18790:18790 nanobot gateway

 # Or run a single command
-docker run -v ~/.nanobot:/root/.nanobot --rm nanobot agent -m "Hello!"
-docker run -v ~/.nanobot:/root/.nanobot --rm nanobot status
+docker run -v ~/.nanobot:/home/nanobot/.nanobot --rm nanobot agent -m "Hello!"
+docker run -v ~/.nanobot:/home/nanobot/.nanobot --rm nanobot status
 ```

 ## 🐧 Linux Service
SECURITY.md (20 lines changed)
@@ -64,6 +64,7 @@ chmod 600 ~/.nanobot/config.json

 The `exec` tool can execute shell commands. While dangerous command patterns are blocked, you should:

+- ✅ **Enable the bwrap sandbox** (`"tools.exec.sandbox": "bwrap"`) for kernel-level isolation (Linux only)
 - ✅ Review all tool usage in agent logs
 - ✅ Understand what commands the agent is running
 - ✅ Use a dedicated user account with limited privileges
@@ -71,6 +72,19 @@ The `exec` tool can execute shell commands. While dangerous command patterns are
 - ❌ Don't disable security checks
 - ❌ Don't run on systems with sensitive data without careful review

+**Exec sandbox (bwrap):**
+
+On Linux, set `"tools.exec.sandbox": "bwrap"` to wrap every shell command in a [bubblewrap](https://github.com/containers/bubblewrap) sandbox. This uses Linux kernel namespaces to restrict what the process can see:
+
+- Workspace directory → **read-write** (agent works normally)
+- Media directory → **read-only** (can read uploaded attachments)
+- System directories (`/usr`, `/bin`, `/lib`) → **read-only** (commands still work)
+- Config files and API keys (`~/.nanobot/config.json`) → **hidden** (masked by tmpfs)
+
+Requires `bwrap` installed (`apt install bubblewrap`). Pre-installed in the official Docker image. **Not available on macOS or Windows** — bubblewrap depends on Linux kernel namespaces.
+
+Enabling the sandbox also automatically activates `restrictToWorkspace` for file tools.
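The mount layout above can be sketched as the argv such a wrapper might build. This is an illustration of the described policy using standard bubblewrap flags, not nanobot's actual command line:

```python
from pathlib import Path


def bwrap_argv(workspace: Path, media: Path, command: list[str]) -> list[str]:
    """Illustrative bubblewrap invocation: workspace rw, media ro, config hidden."""
    return [
        "bwrap",
        "--ro-bind", "/usr", "/usr",               # system dirs stay readable
        "--ro-bind", "/bin", "/bin",
        "--ro-bind", "/lib", "/lib",
        "--bind", str(workspace), str(workspace),  # read-write workspace
        "--ro-bind", str(media), str(media),       # read-only media directory
        "--tmpfs", str(Path.home() / ".nanobot"),  # mask config and API keys
        "--unshare-all",                           # new namespaces (net, pid, ...)
        "--die-with-parent",
        *command,
    ]


argv = bwrap_argv(Path("/work"), Path("/media"), ["sh", "-c", "ls"])
print(argv[0], argv[-1])
```

The key property is the `--tmpfs` over `~/.nanobot`: even a command that tries `cat ~/.nanobot/config.json` sees an empty directory.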

 **Blocked patterns:**
 - `rm -rf /` - Root filesystem deletion
 - Fork bombs
@@ -82,6 +96,7 @@ The `exec` tool can execute shell commands. While dangerous command patterns are

 File operations have path traversal protection, but:

+- ✅ Enable `restrictToWorkspace` or the bwrap sandbox to confine file access
 - ✅ Run nanobot with a dedicated user account
 - ✅ Use filesystem permissions to protect sensitive directories
 - ✅ Regularly audit file operations in logs
@@ -232,7 +247,7 @@ If you suspect a security breach:
 1. **No Rate Limiting** - Users can send unlimited messages (add your own if needed)
 2. **Plain Text Config** - API keys stored in plain text (use keyring for production)
 3. **No Session Management** - No automatic session expiry
-4. **Limited Command Filtering** - Only blocks obvious dangerous patterns
+4. **Limited Command Filtering** - Only blocks obvious dangerous patterns (enable the bwrap sandbox for kernel-level isolation on Linux)
 5. **No Audit Trail** - Limited security event logging (enhance as needed)

 ## Security Checklist
@@ -243,6 +258,7 @@ Before deploying nanobot:
 - [ ] Config file permissions set to 0600
 - [ ] `allowFrom` lists configured for all channels
 - [ ] Running as non-root user
+- [ ] Exec sandbox enabled (`"tools.exec.sandbox": "bwrap"`) on Linux deployments
 - [ ] File system permissions properly restricted
 - [ ] Dependencies updated to latest secure versions
 - [ ] Logs monitored for security events
@@ -252,7 +268,7 @@ Before deploying nanobot:

 ## Updates

-**Last Updated**: 2026-02-03
+**Last Updated**: 2026-04-05

 For the latest security updates and announcements, check:
 - GitHub Security Advisories: https://github.com/HKUDS/nanobot/security/advisories
(Binary image file changed — size 6.8 MiB before and after.)
@@ -3,7 +3,14 @@ x-common-config: &common-config
     context: .
     dockerfile: Dockerfile
   volumes:
-    - ~/.nanobot:/root/.nanobot
+    - ~/.nanobot:/home/nanobot/.nanobot
+  cap_drop:
+    - ALL
+  cap_add:
+    - SYS_ADMIN
+  security_opt:
+    - apparmor=unconfined
+    - seccomp=unconfined

 services:
   nanobot-gateway:
@@ -16,12 +23,29 @@ services:
     deploy:
       resources:
         limits:
-          cpus: '1'
+          cpus: "1"
           memory: 1G
         reservations:
-          cpus: '0.25'
+          cpus: "0.25"
           memory: 256M

+  nanobot-api:
+    container_name: nanobot-api
+    <<: *common-config
+    command:
+      ["serve", "--host", "0.0.0.0", "-w", "/home/nanobot/.nanobot/api-workspace"]
+    restart: unless-stopped
+    ports:
+      - 127.0.0.1:8900:8900
+    deploy:
+      resources:
+        limits:
+          cpus: "1"
+          memory: 1G
+        reservations:
+          cpus: "0.25"
+          memory: 256M
+
   nanobot-cli:
     <<: *common-config
     profiles:
entrypoint.sh (Executable file, 15 lines changed)
@@ -0,0 +1,15 @@
+#!/bin/sh
+dir="$HOME/.nanobot"
+if [ -d "$dir" ] && [ ! -w "$dir" ]; then
+    owner_uid=$(stat -c %u "$dir" 2>/dev/null || stat -f %u "$dir" 2>/dev/null)
+    cat >&2 <<EOF
+Error: $dir is not writable (owned by UID $owner_uid, running as UID $(id -u)).
+
+Fix (pick one):
+  Host:   sudo chown -R 1000:1000 ~/.nanobot
+  Docker: docker run --user \$(id -u):\$(id -g) ...
+  Podman: podman run --userns=keep-id ...
+EOF
+    exit 1
+fi
+exec nanobot "$@"
@@ -2,7 +2,7 @@
 nanobot - A lightweight AI agent framework
 """

-__version__ = "0.1.4.post6"
+__version__ = "0.1.5"
 __logo__ = "🐈"

 from nanobot.nanobot import Nanobot, RunResult
@@ -3,7 +3,7 @@
 from nanobot.agent.context import ContextBuilder
 from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
 from nanobot.agent.loop import AgentLoop
-from nanobot.agent.memory import Consolidator, Dream, MemoryStore
+from nanobot.agent.memory import Dream, MemoryStore
 from nanobot.agent.skills import SkillsLoader
 from nanobot.agent.subagent import SubagentManager
@@ -67,40 +67,27 @@ class CompositeHook(AgentHook):
     def wants_streaming(self) -> bool:
         return any(h.wants_streaming() for h in self._hooks)

-    async def before_iteration(self, context: AgentHookContext) -> None:
+    async def _for_each_hook_safe(self, method_name: str, *args: Any, **kwargs: Any) -> None:
         for h in self._hooks:
             try:
-                await h.before_iteration(context)
+                await getattr(h, method_name)(*args, **kwargs)
             except Exception:
-                logger.exception("AgentHook.before_iteration error in {}", type(h).__name__)
+                logger.exception("AgentHook.{} error in {}", method_name, type(h).__name__)
+
+    async def before_iteration(self, context: AgentHookContext) -> None:
+        await self._for_each_hook_safe("before_iteration", context)

     async def on_stream(self, context: AgentHookContext, delta: str) -> None:
-        for h in self._hooks:
-            try:
-                await h.on_stream(context, delta)
-            except Exception:
-                logger.exception("AgentHook.on_stream error in {}", type(h).__name__)
+        await self._for_each_hook_safe("on_stream", context, delta)

     async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
-        for h in self._hooks:
-            try:
-                await h.on_stream_end(context, resuming=resuming)
-            except Exception:
-                logger.exception("AgentHook.on_stream_end error in {}", type(h).__name__)
+        await self._for_each_hook_safe("on_stream_end", context, resuming=resuming)

     async def before_execute_tools(self, context: AgentHookContext) -> None:
-        for h in self._hooks:
-            try:
-                await h.before_execute_tools(context)
-            except Exception:
-                logger.exception("AgentHook.before_execute_tools error in {}", type(h).__name__)
+        await self._for_each_hook_safe("before_execute_tools", context)

     async def after_iteration(self, context: AgentHookContext) -> None:
-        for h in self._hooks:
-            try:
-                await h.after_iteration(context)
-            except Exception:
-                logger.exception("AgentHook.after_iteration error in {}", type(h).__name__)
+        await self._for_each_hook_safe("after_iteration", context)

     def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
         for h in self._hooks:
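The refactored dispatch pattern is easy to demonstrate outside the diff: one failing hook is logged and skipped while the remaining hooks still run. A self-contained sketch of the same `_for_each_hook_safe` idea (the hook classes here are hypothetical stand-ins, and `print` stands in for the logger):

```python
import asyncio


class LoudHook:
    async def before_iteration(self, context):
        context["seen"].append("loud")


class BrokenHook:
    async def before_iteration(self, context):
        raise RuntimeError("boom")


class CompositeHook:
    def __init__(self, hooks):
        self._hooks = hooks

    async def _for_each_hook_safe(self, method_name, *args, **kwargs):
        for h in self._hooks:
            try:
                await getattr(h, method_name)(*args, **kwargs)
            except Exception as exc:  # one bad hook never blocks the others
                print(f"{type(h).__name__}.{method_name} failed: {exc}")

    async def before_iteration(self, context):
        await self._for_each_hook_safe("before_iteration", context)


ctx = {"seen": []}
asyncio.run(CompositeHook([BrokenHook(), LoudHook()]).before_iteration(ctx))
print(ctx["seen"])
```

Even though `BrokenHook` raises first, `LoudHook` still records its entry — which is the property the dedicated helper preserves across all five lifecycle methods.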
@@ -4,7 +4,6 @@ from __future__ import annotations

 import asyncio
 import json
-import re
 import os
 import time
 from contextlib import AsyncExitStack, nullcontext
@@ -262,7 +261,7 @@ class AgentLoop:

     def _register_default_tools(self) -> None:
         """Register the default set of tools."""
-        allowed_dir = self.workspace if self.restrict_to_workspace else None
+        allowed_dir = self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
         extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
         self.tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read))
         for cls in (WriteFileTool, EditFileTool, ListDirTool):
@@ -274,6 +273,7 @@ class AgentLoop:
             working_dir=str(self.workspace),
             timeout=self.exec_config.timeout,
             restrict_to_workspace=self.restrict_to_workspace,
+            sandbox=self.exec_config.sandbox,
             path_append=self.exec_config.path_append,
         ))
         if self.web_config.enable:
@@ -325,14 +325,10 @@ class AgentLoop:

     @staticmethod
     def _tool_hint(tool_calls: list) -> str:
-        """Format tool calls as concise hint, e.g. 'web_search("query")'."""
-        def _fmt(tc):
-            args = (tc.arguments[0] if isinstance(tc.arguments, list) else tc.arguments) or {}
-            val = next(iter(args.values()), None) if isinstance(args, dict) else None
-            if not isinstance(val, str):
-                return tc.name
-            return f'{tc.name}("{val[:40]}…")' if len(val) > 40 else f'{tc.name}("{val}")'
-        return ", ".join(_fmt(tc) for tc in tool_calls)
+        """Format tool calls as concise hints with smart abbreviation."""
+        from nanobot.utils.tool_hints import format_tool_hints
+
+        return format_tool_hints(tool_calls)

     async def _run_agent_loop(
         self,
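The abbreviation behavior being moved into `format_tool_hints` can be shown standalone: take the first string argument and truncate it to 40 characters with an ellipsis. This sketch mirrors the old inline `_fmt` shown above, not the new helper's exact output:

```python
def tool_hint(name: str, args: dict) -> str:
    """Render a concise hint like 'web_search("query")'."""
    val = next(iter(args.values()), None) if isinstance(args, dict) else None
    if not isinstance(val, str):
        return name  # no string argument worth showing
    shown = f"{val[:40]}…" if len(val) > 40 else val
    return f'{name}("{shown}")'


print(tool_hint("read_file", {"path": "notes.md"}))
```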
@@ -30,6 +30,7 @@ from nanobot.utils.runtime import (
 )

 _DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
+_MAX_EMPTY_RETRIES = 2
 _SNIP_SAFETY_BUFFER = 1024

 @dataclass(slots=True)
 class AgentRunSpec:
@@ -86,6 +87,7 @@ class AgentRunner:
         stop_reason = "completed"
         tool_events: list[dict[str, str]] = []
         external_lookup_counts: dict[str, int] = {}
+        empty_content_retries = 0

         for iteration in range(spec.max_iterations):
             try:
@@ -178,15 +180,30 @@ class AgentRunner:
                         "pending_tool_calls": [],
                     },
                 )
+                empty_content_retries = 0
                 await hook.after_iteration(context)
                 continue

             clean = hook.finalize_content(context, response.content)
             if response.finish_reason != "error" and is_blank_text(clean):
+                empty_content_retries += 1
+                if empty_content_retries < _MAX_EMPTY_RETRIES:
+                    logger.warning(
+                        "Empty response on turn {} for {} ({}/{}); retrying",
+                        iteration,
+                        spec.session_key or "default",
+                        empty_content_retries,
+                        _MAX_EMPTY_RETRIES,
+                    )
+                    if hook.wants_streaming():
+                        await hook.on_stream_end(context, resuming=False)
+                    await hook.after_iteration(context)
+                    continue
                 logger.warning(
-                    "Empty final response on turn {} for {}; retrying with explicit finalization prompt",
+                    "Empty response on turn {} for {} after {} retries; attempting finalization",
                     iteration,
                     spec.session_key or "default",
+                    empty_content_retries,
                 )
                 if hook.wants_streaming():
                     await hook.on_stream_end(context, resuming=False)
@@ -9,6 +9,16 @@ from pathlib import Path
 # Default builtin skills directory (relative to this file)
 BUILTIN_SKILLS_DIR = Path(__file__).parent.parent / "skills"

+# Opening ---, YAML body (group 1), closing --- on its own line; supports CRLF.
+_STRIP_SKILL_FRONTMATTER = re.compile(
+    r"^---\s*\r?\n(.*?)\r?\n---\s*\r?\n?",
+    re.DOTALL,
+)
def _escape_xml(text: str) -> str:
|
||||
return text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
|
||||
|
||||
class SkillsLoader:
|
||||
"""
|
||||
@ -23,6 +33,22 @@ class SkillsLoader:
|
||||
self.workspace_skills = workspace / "skills"
|
||||
self.builtin_skills = builtin_skills_dir or BUILTIN_SKILLS_DIR
|
||||
|
||||
def _skill_entries_from_dir(self, base: Path, source: str, *, skip_names: set[str] | None = None) -> list[dict[str, str]]:
|
||||
if not base.exists():
|
||||
return []
|
||||
entries: list[dict[str, str]] = []
|
||||
for skill_dir in base.iterdir():
|
||||
if not skill_dir.is_dir():
|
||||
continue
|
||||
skill_file = skill_dir / "SKILL.md"
|
||||
if not skill_file.exists():
|
||||
continue
|
||||
name = skill_dir.name
|
||||
if skip_names is not None and name in skip_names:
|
||||
continue
|
||||
entries.append({"name": name, "path": str(skill_file), "source": source})
|
||||
return entries
|
||||
|
||||
def list_skills(self, filter_unavailable: bool = True) -> list[dict[str, str]]:
|
||||
"""
|
||||
List all available skills.
|
||||
@ -33,27 +59,15 @@ class SkillsLoader:
|
||||
Returns:
|
||||
List of skill info dicts with 'name', 'path', 'source'.
|
||||
"""
|
||||
skills = []
|
||||
|
||||
# Workspace skills (highest priority)
|
||||
if self.workspace_skills.exists():
|
||||
for skill_dir in self.workspace_skills.iterdir():
|
||||
if skill_dir.is_dir():
|
||||
skill_file = skill_dir / "SKILL.md"
|
||||
if skill_file.exists():
|
||||
skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "workspace"})
|
||||
|
||||
# Built-in skills
|
||||
skills = self._skill_entries_from_dir(self.workspace_skills, "workspace")
|
||||
workspace_names = {entry["name"] for entry in skills}
|
||||
if self.builtin_skills and self.builtin_skills.exists():
|
||||
for skill_dir in self.builtin_skills.iterdir():
|
||||
if skill_dir.is_dir():
|
||||
skill_file = skill_dir / "SKILL.md"
|
||||
if skill_file.exists() and not any(s["name"] == skill_dir.name for s in skills):
|
||||
skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "builtin"})
|
||||
skills.extend(
|
||||
self._skill_entries_from_dir(self.builtin_skills, "builtin", skip_names=workspace_names)
|
||||
)
|
||||
|
||||
# Filter by requirements
|
||||
if filter_unavailable:
|
||||
return [s for s in skills if self._check_requirements(self._get_skill_meta(s["name"]))]
|
||||
return [skill for skill in skills if self._check_requirements(self._get_skill_meta(skill["name"]))]
|
||||
return skills
|
||||
|
||||
def load_skill(self, name: str) -> str | None:
@@ -66,17 +80,13 @@
Returns:
Skill content or None if not found.
"""
# Check workspace first
workspace_skill = self.workspace_skills / name / "SKILL.md"
if workspace_skill.exists():
return workspace_skill.read_text(encoding="utf-8")

# Check built-in
roots = [self.workspace_skills]
if self.builtin_skills:
builtin_skill = self.builtin_skills / name / "SKILL.md"
if builtin_skill.exists():
return builtin_skill.read_text(encoding="utf-8")

roots.append(self.builtin_skills)
for root in roots:
path = root / name / "SKILL.md"
if path.exists():
return path.read_text(encoding="utf-8")
return None

def load_skills_for_context(self, skill_names: list[str]) -> str:
@@ -89,14 +99,12 @@
Returns:
Formatted skills content.
"""
parts = []
for name in skill_names:
content = self.load_skill(name)
if content:
content = self._strip_frontmatter(content)
parts.append(f"### Skill: {name}\n\n{content}")

return "\n\n---\n\n".join(parts) if parts else ""
parts = [
f"### Skill: {name}\n\n{self._strip_frontmatter(markdown)}"
for name in skill_names
if (markdown := self.load_skill(name))
]
return "\n\n---\n\n".join(parts)

def build_skills_summary(self) -> str:
"""
@@ -112,44 +120,36 @@
if not all_skills:
return ""

def escape_xml(s: str) -> str:
return s.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")

lines = ["<skills>"]
for s in all_skills:
name = escape_xml(s["name"])
path = s["path"]
desc = escape_xml(self._get_skill_description(s["name"]))
skill_meta = self._get_skill_meta(s["name"])
available = self._check_requirements(skill_meta)

lines.append(f" <skill available=\"{str(available).lower()}\">")
lines.append(f" <name>{name}</name>")
lines.append(f" <description>{desc}</description>")
lines.append(f" <location>{path}</location>")

# Show missing requirements for unavailable skills
lines: list[str] = ["<skills>"]
for entry in all_skills:
skill_name = entry["name"]
meta = self._get_skill_meta(skill_name)
available = self._check_requirements(meta)
lines.extend(
[
f' <skill available="{str(available).lower()}">',
f" <name>{_escape_xml(skill_name)}</name>",
f" <description>{_escape_xml(self._get_skill_description(skill_name))}</description>",
f" <location>{entry['path']}</location>",
]
)
if not available:
missing = self._get_missing_requirements(skill_meta)
missing = self._get_missing_requirements(meta)
if missing:
lines.append(f" <requires>{escape_xml(missing)}</requires>")

lines.append(f" <requires>{_escape_xml(missing)}</requires>")
lines.append(" </skill>")
lines.append("</skills>")

return "\n".join(lines)

def _get_missing_requirements(self, skill_meta: dict) -> str:
"""Get a description of missing requirements."""
missing = []
requires = skill_meta.get("requires", {})
for b in requires.get("bins", []):
if not shutil.which(b):
missing.append(f"CLI: {b}")
for env in requires.get("env", []):
if not os.environ.get(env):
missing.append(f"ENV: {env}")
return ", ".join(missing)
required_bins = requires.get("bins", [])
required_env_vars = requires.get("env", [])
return ", ".join(
[f"CLI: {command_name}" for command_name in required_bins if not shutil.which(command_name)]
+ [f"ENV: {env_name}" for env_name in required_env_vars if not os.environ.get(env_name)]
)

def _get_skill_description(self, name: str) -> str:
"""Get the description of a skill from its frontmatter."""
@@ -160,30 +160,32 @@

def _strip_frontmatter(self, content: str) -> str:
"""Remove YAML frontmatter from markdown content."""
if content.startswith("---"):
match = re.match(r"^---\n.*?\n---\n", content, re.DOTALL)
if match:
return content[match.end():].strip()
if not content.startswith("---"):
return content
match = _STRIP_SKILL_FRONTMATTER.match(content)
if match:
return content[match.end():].strip()
return content

def _parse_nanobot_metadata(self, raw: str) -> dict:
"""Parse skill metadata JSON from frontmatter (supports nanobot and openclaw keys)."""
try:
data = json.loads(raw)
return data.get("nanobot", data.get("openclaw", {})) if isinstance(data, dict) else {}
except (json.JSONDecodeError, TypeError):
return {}
if not isinstance(data, dict):
return {}
payload = data.get("nanobot", data.get("openclaw", {}))
return payload if isinstance(payload, dict) else {}

def _check_requirements(self, skill_meta: dict) -> bool:
"""Check if skill requirements are met (bins, env vars)."""
requires = skill_meta.get("requires", {})
for b in requires.get("bins", []):
if not shutil.which(b):
return False
for env in requires.get("env", []):
if not os.environ.get(env):
return False
return True
required_bins = requires.get("bins", [])
required_env_vars = requires.get("env", [])
return all(shutil.which(cmd) for cmd in required_bins) and all(
os.environ.get(var) for var in required_env_vars
)

def _get_skill_meta(self, name: str) -> dict:
"""Get nanobot metadata for a skill (cached in frontmatter)."""
@@ -192,13 +194,15 @@

def get_always_skills(self) -> list[str]:
"""Get skills marked as always=true that meet requirements."""
result = []
for s in self.list_skills(filter_unavailable=True):
meta = self.get_skill_metadata(s["name"]) or {}
skill_meta = self._parse_nanobot_metadata(meta.get("metadata", ""))
if skill_meta.get("always") or meta.get("always"):
result.append(s["name"])
return result
return [
entry["name"]
for entry in self.list_skills(filter_unavailable=True)
if (meta := self.get_skill_metadata(entry["name"]) or {})
and (
self._parse_nanobot_metadata(meta.get("metadata", "")).get("always")
or meta.get("always")
)
]

def get_skill_metadata(self, name: str) -> dict | None:
"""
@@ -211,18 +215,15 @@
Metadata dict or None.
"""
content = self.load_skill(name)
if not content:
if not content or not content.startswith("---"):
return None

if content.startswith("---"):
match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL)
if match:
# Simple YAML parsing
metadata = {}
for line in match.group(1).split("\n"):
if ":" in line:
key, value = line.split(":", 1)
metadata[key.strip()] = value.strip().strip('"\'')
return metadata

return None
match = _STRIP_SKILL_FRONTMATTER.match(content)
if not match:
return None
metadata: dict[str, str] = {}
for line in match.group(1).splitlines():
if ":" not in line:
continue
key, value = line.split(":", 1)
metadata[key.strip()] = value.strip().strip('"\'')
return metadata

@@ -111,7 +111,7 @@ class SubagentManager:
try:
# Build subagent tools (no message tool, no spawn tool)
tools = ToolRegistry()
allowed_dir = self.workspace if self.restrict_to_workspace else None
allowed_dir = self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read))
tools.register(WriteFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
@@ -124,6 +124,7 @@ class SubagentManager:
working_dir=str(self.workspace),
timeout=self.exec_config.timeout,
restrict_to_workspace=self.restrict_to_workspace,
sandbox=self.exec_config.sandbox,
path_append=self.exec_config.path_append,
))
if self.web_config.enable:

@@ -13,6 +13,10 @@ from nanobot.cron.types import CronJob, CronJobState, CronSchedule
@tool_parameters(
tool_parameters_schema(
action=StringSchema("Action to perform", enum=["add", "list", "remove"]),
name=StringSchema(
"Optional short human-readable label for the job "
"(e.g., 'weather-monitor', 'daily-standup'). Defaults to first 30 chars of message."
),
message=StringSchema(
"Instruction for the agent to execute when the job triggers "
"(e.g., 'Send a reminder to WeChat: xxx' or 'Check system status and report')"
@@ -93,6 +97,7 @@ class CronTool(Tool):
async def execute(
self,
action: str,
name: str | None = None,
message: str = "",
every_seconds: int | None = None,
cron_expr: str | None = None,
@@ -105,7 +110,7 @@ class CronTool(Tool):
if action == "add":
if self._in_cron_context.get():
return "Error: cannot schedule new jobs from within a cron job execution"
return self._add_job(message, every_seconds, cron_expr, tz, at, deliver)
return self._add_job(name, message, every_seconds, cron_expr, tz, at, deliver)
elif action == "list":
return self._list_jobs()
elif action == "remove":
@@ -114,6 +119,7 @@ class CronTool(Tool):

def _add_job(
self,
name: str | None,
message: str,
every_seconds: int | None,
cron_expr: str | None,
@@ -158,7 +164,7 @@ class CronTool(Tool):
return "Error: either every_seconds, cron_expr, or at is required"

job = self._cron.add_job(
name=message[:30],
name=name or message[:30],
schedule=schedule,
message=message,
deliver=deliver,

@@ -186,7 +186,7 @@ class WriteFileTool(_FsTool):
fp = self._resolve(path)
fp.parent.mkdir(parents=True, exist_ok=True)
fp.write_text(content, encoding="utf-8")
return f"Successfully wrote {len(content)} bytes to {fp}"
return f"Successfully wrote {len(content)} characters to {fp}"
except PermissionError as e:
return f"Error: {e}"
except Exception as e:

55
nanobot/agent/tools/sandbox.py
Normal file
@@ -0,0 +1,55 @@
"""Sandbox backends for shell command execution.

To add a new backend, implement a function with the signature:
_<name>(command: str, workspace: str, cwd: str) -> str
and register it under <name> in _BACKENDS below.
"""

import shlex
from pathlib import Path

from nanobot.config.paths import get_media_dir


def _bwrap(command: str, workspace: str, cwd: str) -> str:
"""Wrap command in a bubblewrap sandbox (requires bwrap in container).

Only the workspace is bind-mounted read-write; its parent dir (which holds
config.json) is hidden behind a fresh tmpfs. The media directory is
bind-mounted read-only so exec commands can read uploaded attachments.
"""
ws = Path(workspace).resolve()
media = get_media_dir().resolve()

try:
sandbox_cwd = str(ws / Path(cwd).resolve().relative_to(ws))
except ValueError:
sandbox_cwd = str(ws)

required = ["/usr"]
optional = ["/bin", "/lib", "/lib64", "/etc/alternatives",
"/etc/ssl/certs", "/etc/resolv.conf", "/etc/ld.so.cache"]

args = ["bwrap", "--new-session", "--die-with-parent"]
for p in required: args += ["--ro-bind", p, p]
for p in optional: args += ["--ro-bind-try", p, p]
args += [
"--proc", "/proc", "--dev", "/dev", "--tmpfs", "/tmp",
"--tmpfs", str(ws.parent),  # mask config dir
"--dir", str(ws),  # recreate workspace mount point
"--bind", str(ws), str(ws),
"--ro-bind-try", str(media), str(media),  # read-only access to media
"--chdir", sandbox_cwd,
"--", "sh", "-c", command,
]
return shlex.join(args)


_BACKENDS = {"bwrap": _bwrap}


def wrap_command(sandbox: str, command: str, workspace: str, cwd: str) -> str:
"""Wrap *command* using the named sandbox backend."""
if backend := _BACKENDS.get(sandbox):
return backend(command, workspace, cwd)
raise ValueError(f"Unknown sandbox backend {sandbox!r}. Available: {list(_BACKENDS)}")

@@ -3,6 +3,7 @@
import asyncio
import os
import re
import shutil
import sys
from pathlib import Path
from typing import Any
@@ -10,6 +11,7 @@ from typing import Any
from loguru import logger

from nanobot.agent.tools.base import Tool, tool_parameters
from nanobot.agent.tools.sandbox import wrap_command
from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema
from nanobot.config.paths import get_media_dir

@@ -40,10 +42,12 @@ class ExecTool(Tool):
deny_patterns: list[str] | None = None,
allow_patterns: list[str] | None = None,
restrict_to_workspace: bool = False,
sandbox: str = "",
path_append: str = "",
):
self.timeout = timeout
self.working_dir = working_dir
self.sandbox = sandbox
self.deny_patterns = deny_patterns or [
r"\brm\s+-[rf]{1,2}\b",  # rm -r, rm -rf, rm -fr
r"\bdel\s+/[fq]\b",  # del /f, del /q
@@ -83,15 +87,23 @@ class ExecTool(Tool):
if guard_error:
return guard_error

if self.sandbox:
workspace = self.working_dir or cwd
command = wrap_command(self.sandbox, command, workspace, cwd)
cwd = str(Path(workspace).resolve())

effective_timeout = min(timeout or self.timeout, self._MAX_TIMEOUT)

env = os.environ.copy()
env = self._build_env()

if self.path_append:
env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append
command = f'export PATH="$PATH:{self.path_append}"; {command}'

bash = shutil.which("bash") or "/bin/bash"

try:
process = await asyncio.create_subprocess_shell(
command,
process = await asyncio.create_subprocess_exec(
bash, "-l", "-c", command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=cwd,
@@ -104,18 +116,11 @@ class ExecTool(Tool):
timeout=effective_timeout,
)
except asyncio.TimeoutError:
process.kill()
try:
await asyncio.wait_for(process.wait(), timeout=5.0)
except asyncio.TimeoutError:
pass
finally:
if sys.platform != "win32":
try:
os.waitpid(process.pid, os.WNOHANG)
except (ProcessLookupError, ChildProcessError) as e:
logger.debug("Process already reaped or not found: {}", e)
await self._kill_process(process)
return f"Error: Command timed out after {effective_timeout} seconds"
except asyncio.CancelledError:
await self._kill_process(process)
raise

output_parts = []

@@ -146,6 +151,36 @@ class ExecTool(Tool):
except Exception as e:
return f"Error executing command: {str(e)}"

@staticmethod
async def _kill_process(process: asyncio.subprocess.Process) -> None:
"""Kill a subprocess and reap it to prevent zombies."""
process.kill()
try:
await asyncio.wait_for(process.wait(), timeout=5.0)
except asyncio.TimeoutError:
pass
finally:
if sys.platform != "win32":
try:
os.waitpid(process.pid, os.WNOHANG)
except (ProcessLookupError, ChildProcessError) as e:
logger.debug("Process already reaped or not found: {}", e)

def _build_env(self) -> dict[str, str]:
"""Build a minimal environment for subprocess execution.

Uses HOME so that ``bash -l`` sources the user's profile (which sets
PATH and other essentials). Only PATH is extended with *path_append*;
the parent process's environment is **not** inherited, preventing
secrets in env vars from leaking to LLM-generated commands.
"""
home = os.environ.get("HOME", "/tmp")
return {
"HOME": home,
"LANG": os.environ.get("LANG", "C.UTF-8"),
"TERM": os.environ.get("TERM", "dumb"),
}

def _guard_command(self, command: str, cwd: str) -> str | None:
"""Best-effort safety guard for potentially destructive commands."""
cmd = command.strip()

@@ -8,7 +8,7 @@ import json
import os
import re
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
from urllib.parse import quote, urlparse

import httpx
from loguru import logger
@@ -182,10 +182,10 @@ class WebSearchTool(Tool):
return await self._search_duckduckgo(query, n)
try:
headers = {"Accept": "application/json", "Authorization": f"Bearer {api_key}"}
encoded_query = quote(query, safe="")
async with httpx.AsyncClient(proxy=self.proxy) as client:
r = await client.get(
f"https://s.jina.ai/",
params={"q": query},
f"https://s.jina.ai/{encoded_query}",
headers=headers,
timeout=15.0,
)
@@ -197,7 +197,8 @@
]
return _format_results(query, items, n)
except Exception as e:
return f"Error: {e}"
logger.warning("Jina search failed ({}), falling back to DuckDuckGo", e)
return await self._search_duckduckgo(query, n)

async def _search_duckduckgo(self, query: str, n: int) -> str:
try:
@@ -206,7 +207,10 @@
from ddgs import DDGS

ddgs = DDGS(timeout=10)
raw = await asyncio.to_thread(ddgs.text, query, max_results=n)
raw = await asyncio.wait_for(
asyncio.to_thread(ddgs.text, query, max_results=n),
timeout=self.config.timeout,
)
if not raw:
return f"No results for: {query}"
items = [

@@ -22,6 +22,7 @@ class BaseChannel(ABC):

name: str = "base"
display_name: str = "Base"
transcription_provider: str = "groq"
transcription_api_key: str = ""

def __init__(self, config: Any, bus: MessageBus):
@@ -37,13 +38,16 @@ class BaseChannel(ABC):
self._running = False

async def transcribe_audio(self, file_path: str | Path) -> str:
"""Transcribe an audio file via Groq Whisper. Returns empty string on failure."""
"""Transcribe an audio file via Whisper (OpenAI or Groq). Returns empty string on failure."""
if not self.transcription_api_key:
return ""
try:
from nanobot.providers.transcription import GroqTranscriptionProvider

provider = GroqTranscriptionProvider(api_key=self.transcription_api_key)
if self.transcription_provider == "openai":
from nanobot.providers.transcription import OpenAITranscriptionProvider
provider = OpenAITranscriptionProvider(api_key=self.transcription_api_key)
else:
from nanobot.providers.transcription import GroqTranscriptionProvider
provider = GroqTranscriptionProvider(api_key=self.transcription_api_key)
return await provider.transcribe(file_path)
except Exception as e:
logger.warning("{}: audio transcription failed: {}", self.name, e)

@@ -12,6 +12,8 @@ from email.header import decode_header, make_header
from email.message import EmailMessage
from email.parser import BytesParser
from email.utils import parseaddr
from fnmatch import fnmatch
from pathlib import Path
from typing import Any

from loguru import logger
@@ -20,7 +22,9 @@ from pydantic import Field
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_media_dir
from nanobot.config.schema import Base
from nanobot.utils.helpers import safe_filename


class EmailConfig(Base):
@@ -55,6 +59,11 @@ class EmailConfig(Base):
verify_dkim: bool = True  # Require Authentication-Results with dkim=pass
verify_spf: bool = True  # Require Authentication-Results with spf=pass

# Attachment handling — set allowed types to enable (e.g. ["application/pdf", "image/*"], or ["*"] for all)
allowed_attachment_types: list[str] = Field(default_factory=list)
max_attachment_size: int = 2_000_000  # 2MB per attachment
max_attachments_per_email: int = 5


class EmailChannel(BaseChannel):
"""
@@ -153,6 +162,7 @@ class EmailChannel(BaseChannel):
sender_id=sender,
chat_id=sender,
content=item["content"],
media=item.get("media") or None,
metadata=item.get("metadata", {}),
)
except Exception as e:
@@ -404,6 +414,20 @@ class EmailChannel(BaseChannel):
f"{body}"
)

# --- Attachment extraction ---
attachment_paths: list[str] = []
if self.config.allowed_attachment_types:
saved = self._extract_attachments(
parsed,
uid or "noid",
allowed_types=self.config.allowed_attachment_types,
max_size=self.config.max_attachment_size,
max_count=self.config.max_attachments_per_email,
)
for p in saved:
attachment_paths.append(str(p))
content += f"\n[attachment: {p.name} — saved to {p}]"

metadata = {
"message_id": message_id,
"subject": subject,
@@ -418,6 +442,7 @@ class EmailChannel(BaseChannel):
"message_id": message_id,
"content": content,
"metadata": metadata,
"media": attachment_paths,
}
)

@@ -537,6 +562,61 @@ class EmailChannel(BaseChannel):
dkim_pass = True
return spf_pass, dkim_pass

@classmethod
def _extract_attachments(
cls,
msg: Any,
uid: str,
*,
allowed_types: list[str],
max_size: int,
max_count: int,
) -> list[Path]:
"""Extract and save email attachments to the media directory.

Returns list of saved file paths.
"""
if not msg.is_multipart():
return []

saved: list[Path] = []
media_dir = get_media_dir("email")

for part in msg.walk():
if len(saved) >= max_count:
break
if part.get_content_disposition() != "attachment":
continue

content_type = part.get_content_type()
if not any(fnmatch(content_type, pat) for pat in allowed_types):
logger.debug("Email attachment skipped (type {}): not in allowed list", content_type)
continue

payload = part.get_payload(decode=True)
if payload is None:
continue
if len(payload) > max_size:
logger.warning(
"Email attachment skipped: size {} exceeds limit {}",
len(payload),
max_size,
)
continue

raw_name = part.get_filename() or "attachment"
sanitized = safe_filename(raw_name) or "attachment"
dest = media_dir / f"{uid}_{sanitized}"

try:
dest.write_bytes(payload)
saved.append(dest)
logger.info("Email attachment saved: {}", dest)
except Exception as exc:
logger.warning("Failed to save email attachment {}: {}", dest, exc)

return saved

@staticmethod
def _html_to_text(raw_html: str) -> str:
text = re.sub(r"<\s*br\s*/?>", "\n", raw_html, flags=re.IGNORECASE)

@@ -1,6 +1,7 @@
"""Feishu/Lark channel implementation using lark-oapi SDK with WebSocket long connection."""

import asyncio
import importlib.util
import json
import os
import re
@@ -9,19 +10,17 @@ import time
import uuid
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal

from lark_oapi.api.im.v1.model import MentionEvent, P2ImMessageReceiveV1
from loguru import logger
from pydantic import Field

from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_media_dir
from nanobot.config.schema import Base
from pydantic import Field

import importlib.util

FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None

@@ -76,7 +75,9 @@ def _extract_interactive_content(content: dict) -> list[str]:
elif isinstance(title, str):
parts.append(f"title: {title}")

for elements in content.get("elements", []) if isinstance(content.get("elements"), list) else []:
for elements in (
content.get("elements", []) if isinstance(content.get("elements"), list) else []
):
for element in elements:
parts.extend(_extract_element_content(element))

@@ -260,6 +261,7 @@ _STREAM_ELEMENT_ID = "streaming_md"
@dataclass
class _FeishuStreamBuf:
"""Per-chat streaming accumulator using CardKit streaming API."""

text: str = ""
card_id: str | None = None
sequence: int = 0
@@ -288,16 +290,19 @@ class FeishuChannel(BaseChannel):
return FeishuConfig().model_dump(by_alias=True)

def __init__(self, config: Any, bus: MessageBus):
import lark_oapi as lark

if isinstance(config, dict):
config = FeishuConfig.model_validate(config)
super().__init__(config, bus)
self.config: FeishuConfig = config
self._client: Any = None
self._client: lark.Client = None
self._ws_client: Any = None
self._ws_thread: threading.Thread | None = None
self._processed_message_ids: OrderedDict[str, None] = OrderedDict()  # Ordered dedup cache
self._loop: asyncio.AbstractEventLoop | None = None
self._stream_bufs: dict[str, _FeishuStreamBuf] = {}
self._bot_open_id: str | None = None

@staticmethod
def _register_optional_event(builder: Any, method_name: str, handler: Any) -> Any:
@@ -316,24 +321,28 @@ class FeishuChannel(BaseChannel):
return

import lark_oapi as lark

self._running = True
self._loop = asyncio.get_running_loop()

# Create Lark client for sending messages
self._client = lark.Client.builder() \
.app_id(self.config.app_id) \
.app_secret(self.config.app_secret) \
.log_level(lark.LogLevel.INFO) \
self._client = (
lark.Client.builder()
.app_id(self.config.app_id)
.app_secret(self.config.app_secret)
.log_level(lark.LogLevel.INFO)
.build()
)

builder = lark.EventDispatcherHandler.builder(
|
||||
self.config.encrypt_key or "",
|
||||
self.config.verification_token or "",
|
||||
).register_p2_im_message_receive_v1(
|
||||
self._on_message_sync
|
||||
)
|
||||
).register_p2_im_message_receive_v1(self._on_message_sync)
|
||||
builder = self._register_optional_event(
|
||||
builder, "register_p2_im_message_reaction_created_v1", self._on_reaction_created
|
||||
)
|
||||
builder = self._register_optional_event(
|
||||
builder, "register_p2_im_message_reaction_deleted_v1", self._on_reaction_deleted
|
||||
)
|
||||
builder = self._register_optional_event(
|
||||
builder, "register_p2_im_message_message_read_v1", self._on_message_read
|
||||
)
|
||||
@ -349,7 +358,7 @@ class FeishuChannel(BaseChannel):
|
||||
self.config.app_id,
|
||||
self.config.app_secret,
|
||||
event_handler=event_handler,
|
||||
log_level=lark.LogLevel.INFO
|
||||
log_level=lark.LogLevel.INFO,
|
||||
)
|
||||
|
||||
# Start WebSocket client in a separate thread with reconnect loop.
|
||||
@ -359,7 +368,9 @@ class FeishuChannel(BaseChannel):
|
||||
# "This event loop is already running" errors.
|
||||
def run_ws():
|
||||
import time
|
||||
|
||||
import lark_oapi.ws.client as _lark_ws_client
|
||||
|
||||
ws_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(ws_loop)
|
||||
# Patch the module-level loop used by lark's ws Client.start()
|
||||
@ -378,6 +389,15 @@ class FeishuChannel(BaseChannel):
|
||||
self._ws_thread = threading.Thread(target=run_ws, daemon=True)
|
||||
self._ws_thread.start()
|
||||
|
||||
# Fetch bot's own open_id for accurate @mention matching
|
||||
self._bot_open_id = await asyncio.get_running_loop().run_in_executor(
|
||||
None, self._fetch_bot_open_id
|
||||
)
|
||||
if self._bot_open_id:
|
||||
logger.info("Feishu bot open_id: {}", self._bot_open_id)
|
||||
else:
|
||||
logger.warning("Could not fetch bot open_id; @mention matching may be inaccurate")
|
||||
|
||||
logger.info("Feishu bot started with WebSocket long connection")
|
||||
logger.info("No public IP required - using WebSocket to receive events")
|
||||
|
||||
@ -396,6 +416,70 @@ class FeishuChannel(BaseChannel):
self._running = False
logger.info("Feishu bot stopped")

def _fetch_bot_open_id(self) -> str | None:
"""Fetch the bot's own open_id via GET /open-apis/bot/v3/info."""
try:
import lark_oapi as lark

request = (
lark.BaseRequest.builder()
.http_method(lark.HttpMethod.GET)
.uri("/open-apis/bot/v3/info")
.token_types({lark.AccessTokenType.APP})
.build()
)
response = self._client.request(request)
if response.success():
import json

data = json.loads(response.raw.content)
bot = (data.get("data") or data).get("bot") or data.get("bot") or {}
return bot.get("open_id")
logger.warning("Failed to get bot info: code={}, msg={}", response.code, response.msg)
return None
except Exception as e:
logger.warning("Error fetching bot info: {}", e)
return None

@staticmethod
def _resolve_mentions(text: str, mentions: list[MentionEvent] | None) -> str:
"""Replace @_user_n placeholders with actual user info from mentions.

Args:
text: The message text containing @_user_n placeholders
mentions: List of mention objects from Feishu message

Returns:
Text with placeholders replaced by @姓名 (open_id)
"""
if not mentions or not text:
return text

for mention in mentions:
key = mention.key or None
if not key or key not in text:
continue

user_id_obj = mention.id or None
if not user_id_obj:
continue

open_id = user_id_obj.open_id
user_id = user_id_obj.user_id
name = mention.name or key

# Format: @姓名 (open_id, user_id: xxx)
if open_id and user_id:
replacement = f"@{name} ({open_id}, user id: {user_id})"
elif open_id:
replacement = f"@{name} ({open_id})"
else:
replacement = f"@{name}"

text = text.replace(key, replacement)

return text

def _is_bot_mentioned(self, message: Any) -> bool:
"""Check if the bot is @mentioned in the message."""
raw_content = message.content or ""
@ -406,9 +490,14 @@ class FeishuChannel(BaseChannel):
mid = getattr(mention, "id", None)
if not mid:
continue
# Bot mentions have no user_id (None or "") but a valid open_id
if not getattr(mid, "user_id", None) and (getattr(mid, "open_id", None) or "").startswith("ou_"):
return True
mention_open_id = getattr(mid, "open_id", None) or ""
if self._bot_open_id:
if mention_open_id == self._bot_open_id:
return True
else:
# Fallback heuristic when bot open_id is unavailable
if not getattr(mid, "user_id", None) and mention_open_id.startswith("ou_"):
return True
return False

def _is_group_message_for_bot(self, message: Any) -> bool:
@ -419,20 +508,30 @@ class FeishuChannel(BaseChannel):

def _add_reaction_sync(self, message_id: str, emoji_type: str) -> str | None:
"""Sync helper for adding reaction (runs in thread pool)."""
from lark_oapi.api.im.v1 import CreateMessageReactionRequest, CreateMessageReactionRequestBody, Emoji
from lark_oapi.api.im.v1 import (
CreateMessageReactionRequest,
CreateMessageReactionRequestBody,
Emoji,
)

try:
request = CreateMessageReactionRequest.builder() \
.message_id(message_id) \
request = (
CreateMessageReactionRequest.builder()
.message_id(message_id)
.request_body(
CreateMessageReactionRequestBody.builder()
.reaction_type(Emoji.builder().emoji_type(emoji_type).build())
.build()
).build()
)
.build()
)

response = self._client.im.v1.message_reaction.create(request)

if not response.success():
logger.warning("Failed to add reaction: code={}, msg={}", response.code, response.msg)
logger.warning(
"Failed to add reaction: code={}, msg={}", response.code, response.msg
)
return None
else:
logger.debug("Added {} reaction to message {}", emoji_type, message_id)
@ -456,17 +555,22 @@ class FeishuChannel(BaseChannel):
def _remove_reaction_sync(self, message_id: str, reaction_id: str) -> None:
"""Sync helper for removing reaction (runs in thread pool)."""
from lark_oapi.api.im.v1 import DeleteMessageReactionRequest

try:
request = DeleteMessageReactionRequest.builder() \
.message_id(message_id) \
.reaction_id(reaction_id) \
request = (
DeleteMessageReactionRequest.builder()
.message_id(message_id)
.reaction_id(reaction_id)
.build()
)

response = self._client.im.v1.message_reaction.delete(request)
if response.success():
logger.debug("Removed reaction {} from message {}", reaction_id, message_id)
else:
logger.debug("Failed to remove reaction: code={}, msg={}", response.code, response.msg)
logger.debug(
"Failed to remove reaction: code={}, msg={}", response.code, response.msg
)
except Exception as e:
logger.debug("Error removing reaction: {}", e)

@ -521,27 +625,35 @@ class FeishuChannel(BaseChannel):
lines = [_line.strip() for _line in table_text.strip().split("\n") if _line.strip()]
if len(lines) < 3:
return None

def split(_line: str) -> list[str]:
return [c.strip() for c in _line.strip("|").split("|")]

headers = [cls._strip_md_formatting(h) for h in split(lines[0])]
rows = [[cls._strip_md_formatting(c) for c in split(_line)] for _line in lines[2:]]
columns = [{"tag": "column", "name": f"c{i}", "display_name": h, "width": "auto"}
for i, h in enumerate(headers)]
columns = [
{"tag": "column", "name": f"c{i}", "display_name": h, "width": "auto"}
for i, h in enumerate(headers)
]
return {
"tag": "table",
"page_size": len(rows) + 1,
"columns": columns,
"rows": [{f"c{i}": r[i] if i < len(r) else "" for i in range(len(headers))} for r in rows],
"rows": [
{f"c{i}": r[i] if i < len(r) else "" for i in range(len(headers))} for r in rows
],
}

def _build_card_elements(self, content: str) -> list[dict]:
"""Split content into div/markdown + table elements for Feishu card."""
elements, last_end = [], 0
for m in self._TABLE_RE.finditer(content):
before = content[last_end:m.start()]
before = content[last_end : m.start()]
if before.strip():
elements.extend(self._split_headings(before))
elements.append(self._parse_md_table(m.group(1)) or {"tag": "markdown", "content": m.group(1)})
elements.append(
self._parse_md_table(m.group(1)) or {"tag": "markdown", "content": m.group(1)}
)
last_end = m.end()
remaining = content[last_end:]
if remaining.strip():
@ -549,7 +661,9 @@ class FeishuChannel(BaseChannel):
return elements or [{"tag": "markdown", "content": content}]

@staticmethod
def _split_elements_by_table_limit(elements: list[dict], max_tables: int = 1) -> list[list[dict]]:
def _split_elements_by_table_limit(
elements: list[dict], max_tables: int = 1
) -> list[list[dict]]:
"""Split card elements into groups with at most *max_tables* table elements each.

Feishu cards have a hard limit of one table per card (API error 11310).
@ -582,23 +696,25 @@ class FeishuChannel(BaseChannel):
code_blocks = []
for m in self._CODE_BLOCK_RE.finditer(content):
code_blocks.append(m.group(1))
protected = protected.replace(m.group(1), f"\x00CODE{len(code_blocks)-1}\x00", 1)
protected = protected.replace(m.group(1), f"\x00CODE{len(code_blocks) - 1}\x00", 1)

elements = []
last_end = 0
for m in self._HEADING_RE.finditer(protected):
before = protected[last_end:m.start()].strip()
before = protected[last_end : m.start()].strip()
if before:
elements.append({"tag": "markdown", "content": before})
text = self._strip_md_formatting(m.group(2).strip())
display_text = f"**{text}**" if text else ""
elements.append({
"tag": "div",
"text": {
"tag": "lark_md",
"content": display_text,
},
})
elements.append(
{
"tag": "div",
"text": {
"tag": "lark_md",
"content": display_text,
},
}
)
last_end = m.end()
remaining = protected[last_end:].strip()
if remaining:
@ -614,19 +730,19 @@ class FeishuChannel(BaseChannel):
# ── Smart format detection ──────────────────────────────────────────
# Patterns that indicate "complex" markdown needing card rendering
_COMPLEX_MD_RE = re.compile(
r"```" # fenced code block
r"```" # fenced code block
r"|^\|.+\|.*\n\s*\|[-:\s|]+\|" # markdown table (header + separator)
r"|^#{1,6}\s+" # headings
, re.MULTILINE,
r"|^#{1,6}\s+", # headings
re.MULTILINE,
)

# Simple markdown patterns (bold, italic, strikethrough)
_SIMPLE_MD_RE = re.compile(
r"\*\*.+?\*\*" # **bold**
r"|__.+?__" # __bold__
r"\*\*.+?\*\*" # **bold**
r"|__.+?__" # __bold__
r"|(?<!\*)\*(?!\*)(.+?)(?<!\*)\*(?!\*)" # *italic* (single *)
r"|~~.+?~~" # ~~strikethrough~~
, re.DOTALL,
r"|~~.+?~~", # ~~strikethrough~~
re.DOTALL,
)

# Markdown link: [text](url)
@ -698,14 +814,16 @@ class FeishuChannel(BaseChannel):

for m in cls._MD_LINK_RE.finditer(line):
# Text before this link
before = line[last_end:m.start()]
before = line[last_end : m.start()]
if before:
elements.append({"tag": "text", "text": before})
elements.append({
"tag": "a",
"text": m.group(1),
"href": m.group(2),
})
elements.append(
{
"tag": "a",
"text": m.group(1),
"href": m.group(2),
}
)
last_end = m.end()

# Remaining text after last link
@ -730,29 +848,39 @@ class FeishuChannel(BaseChannel):
_AUDIO_EXTS = {".opus"}
_VIDEO_EXTS = {".mp4", ".mov", ".avi"}
_FILE_TYPE_MAP = {
".opus": "opus", ".mp4": "mp4", ".pdf": "pdf", ".doc": "doc", ".docx": "doc",
".xls": "xls", ".xlsx": "xls", ".ppt": "ppt", ".pptx": "ppt",
".opus": "opus",
".mp4": "mp4",
".pdf": "pdf",
".doc": "doc",
".docx": "doc",
".xls": "xls",
".xlsx": "xls",
".ppt": "ppt",
".pptx": "ppt",
}

def _upload_image_sync(self, file_path: str) -> str | None:
"""Upload an image to Feishu and return the image_key."""
from lark_oapi.api.im.v1 import CreateImageRequest, CreateImageRequestBody

try:
with open(file_path, "rb") as f:
request = CreateImageRequest.builder() \
request = (
CreateImageRequest.builder()
.request_body(
CreateImageRequestBody.builder()
.image_type("message")
.image(f)
.build()
).build()
CreateImageRequestBody.builder().image_type("message").image(f).build()
)
.build()
)
response = self._client.im.v1.image.create(request)
if response.success():
image_key = response.data.image_key
logger.debug("Uploaded image {}: {}", os.path.basename(file_path), image_key)
return image_key
else:
logger.error("Failed to upload image: code={}, msg={}", response.code, response.msg)
logger.error(
"Failed to upload image: code={}, msg={}", response.code, response.msg
)
return None
except Exception as e:
logger.error("Error uploading image {}: {}", file_path, e)
@ -761,49 +889,62 @@ class FeishuChannel(BaseChannel):
def _upload_file_sync(self, file_path: str) -> str | None:
"""Upload a file to Feishu and return the file_key."""
from lark_oapi.api.im.v1 import CreateFileRequest, CreateFileRequestBody

ext = os.path.splitext(file_path)[1].lower()
file_type = self._FILE_TYPE_MAP.get(ext, "stream")
file_name = os.path.basename(file_path)
try:
with open(file_path, "rb") as f:
request = CreateFileRequest.builder() \
request = (
CreateFileRequest.builder()
.request_body(
CreateFileRequestBody.builder()
.file_type(file_type)
.file_name(file_name)
.file(f)
.build()
).build()
)
.build()
)
response = self._client.im.v1.file.create(request)
if response.success():
file_key = response.data.file_key
logger.debug("Uploaded file {}: {}", file_name, file_key)
return file_key
else:
logger.error("Failed to upload file: code={}, msg={}", response.code, response.msg)
logger.error(
"Failed to upload file: code={}, msg={}", response.code, response.msg
)
return None
except Exception as e:
logger.error("Error uploading file {}: {}", file_path, e)
return None

def _download_image_sync(self, message_id: str, image_key: str) -> tuple[bytes | None, str | None]:
def _download_image_sync(
self, message_id: str, image_key: str
) -> tuple[bytes | None, str | None]:
"""Download an image from Feishu message by message_id and image_key."""
from lark_oapi.api.im.v1 import GetMessageResourceRequest

try:
request = GetMessageResourceRequest.builder() \
.message_id(message_id) \
.file_key(image_key) \
.type("image") \
request = (
GetMessageResourceRequest.builder()
.message_id(message_id)
.file_key(image_key)
.type("image")
.build()
)
response = self._client.im.v1.message_resource.get(request)
if response.success():
file_data = response.file
# GetMessageResourceRequest returns BytesIO, need to read bytes
if hasattr(file_data, 'read'):
if hasattr(file_data, "read"):
file_data = file_data.read()
return file_data, response.file_name
else:
logger.error("Failed to download image: code={}, msg={}", response.code, response.msg)
logger.error(
"Failed to download image: code={}, msg={}", response.code, response.msg
)
return None, None
except Exception as e:
logger.error("Error downloading image {}: {}", image_key, e)
@ -835,17 +976,19 @@ class FeishuChannel(BaseChannel):
file_data = file_data.read()
return file_data, response.file_name
else:
logger.error("Failed to download {}: code={}, msg={}", resource_type, response.code, response.msg)
logger.error(
"Failed to download {}: code={}, msg={}",
resource_type,
response.code,
response.msg,
)
return None, None
except Exception:
logger.exception("Error downloading {} {}", resource_type, file_key)
return None, None

async def _download_and_save_media(
self,
msg_type: str,
content_json: dict,
message_id: str | None = None
self, msg_type: str, content_json: dict, message_id: str | None = None
) -> tuple[str | None, str]:
"""
Download media from Feishu and save to local disk.
@ -894,13 +1037,16 @@ class FeishuChannel(BaseChannel):
Returns a "[Reply to: ...]" context string, or None on failure.
"""
from lark_oapi.api.im.v1 import GetMessageRequest

try:
request = GetMessageRequest.builder().message_id(message_id).build()
response = self._client.im.v1.message.get(request)
if not response.success():
logger.debug(
"Feishu: could not fetch parent message {}: code={}, msg={}",
message_id, response.code, response.msg,
message_id,
response.code,
response.msg,
)
return None
items = getattr(response.data, "items", None)
@ -935,20 +1081,24 @@ class FeishuChannel(BaseChannel):
def _reply_message_sync(self, parent_message_id: str, msg_type: str, content: str) -> bool:
"""Reply to an existing Feishu message using the Reply API (synchronous)."""
from lark_oapi.api.im.v1 import ReplyMessageRequest, ReplyMessageRequestBody

try:
request = ReplyMessageRequest.builder() \
.message_id(parent_message_id) \
request = (
ReplyMessageRequest.builder()
.message_id(parent_message_id)
.request_body(
ReplyMessageRequestBody.builder()
.msg_type(msg_type)
.content(content)
.build()
).build()
ReplyMessageRequestBody.builder().msg_type(msg_type).content(content).build()
)
.build()
)
response = self._client.im.v1.message.reply(request)
if not response.success():
logger.error(
"Failed to reply to Feishu message {}: code={}, msg={}, log_id={}",
parent_message_id, response.code, response.msg, response.get_log_id()
parent_message_id,
response.code,
response.msg,
response.get_log_id(),
)
return False
logger.debug("Feishu reply sent to message {}", parent_message_id)
@ -957,24 +1107,33 @@ class FeishuChannel(BaseChannel):
logger.error("Error replying to Feishu message {}: {}", parent_message_id, e)
return False

def _send_message_sync(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> str | None:
def _send_message_sync(
self, receive_id_type: str, receive_id: str, msg_type: str, content: str
) -> str | None:
"""Send a single message and return the message_id on success."""
from lark_oapi.api.im.v1 import CreateMessageRequest, CreateMessageRequestBody

try:
request = CreateMessageRequest.builder() \
.receive_id_type(receive_id_type) \
request = (
CreateMessageRequest.builder()
.receive_id_type(receive_id_type)
.request_body(
CreateMessageRequestBody.builder()
.receive_id(receive_id)
.msg_type(msg_type)
.content(content)
.build()
).build()
)
.build()
)
response = self._client.im.v1.message.create(request)
if not response.success():
logger.error(
"Failed to send Feishu {} message: code={}, msg={}, log_id={}",
msg_type, response.code, response.msg, response.get_log_id()
msg_type,
response.code,
response.msg,
response.get_log_id(),
)
return None
msg_id = getattr(response.data, "message_id", None)
@ -987,31 +1146,44 @@ class FeishuChannel(BaseChannel):
def _create_streaming_card_sync(self, receive_id_type: str, chat_id: str) -> str | None:
"""Create a CardKit streaming card, send it to chat, return card_id."""
from lark_oapi.api.cardkit.v1 import CreateCardRequest, CreateCardRequestBody

card_json = {
"schema": "2.0",
"config": {"wide_screen_mode": True, "update_multi": True, "streaming_mode": True},
"body": {"elements": [{"tag": "markdown", "content": "", "element_id": _STREAM_ELEMENT_ID}]},
"body": {
"elements": [{"tag": "markdown", "content": "", "element_id": _STREAM_ELEMENT_ID}]
},
}
try:
request = CreateCardRequest.builder().request_body(
CreateCardRequestBody.builder()
.type("card_json")
.data(json.dumps(card_json, ensure_ascii=False))
request = (
CreateCardRequest.builder()
.request_body(
CreateCardRequestBody.builder()
.type("card_json")
.data(json.dumps(card_json, ensure_ascii=False))
.build()
)
.build()
).build()
)
response = self._client.cardkit.v1.card.create(request)
if not response.success():
logger.warning("Failed to create streaming card: code={}, msg={}", response.code, response.msg)
logger.warning(
"Failed to create streaming card: code={}, msg={}", response.code, response.msg
)
return None
card_id = getattr(response.data, "card_id", None)
if card_id:
message_id = self._send_message_sync(
receive_id_type, chat_id, "interactive",
receive_id_type,
chat_id,
"interactive",
json.dumps({"type": "card", "data": {"card_id": card_id}}),
)
if message_id:
return card_id
logger.warning("Created streaming card {} but failed to send it to {}", card_id, chat_id)
logger.warning(
"Created streaming card {} but failed to send it to {}", card_id, chat_id
)
return None
except Exception as e:
logger.warning("Error creating streaming card: {}", e)
@ -1019,18 +1191,32 @@ class FeishuChannel(BaseChannel):

def _stream_update_text_sync(self, card_id: str, content: str, sequence: int) -> bool:
"""Stream-update the markdown element on a CardKit card (typewriter effect)."""
from lark_oapi.api.cardkit.v1 import ContentCardElementRequest, ContentCardElementRequestBody
from lark_oapi.api.cardkit.v1 import (
ContentCardElementRequest,
ContentCardElementRequestBody,
)

try:
request = ContentCardElementRequest.builder() \
.card_id(card_id) \
.element_id(_STREAM_ELEMENT_ID) \
request = (
ContentCardElementRequest.builder()
.card_id(card_id)
.element_id(_STREAM_ELEMENT_ID)
.request_body(
ContentCardElementRequestBody.builder()
.content(content).sequence(sequence).build()
).build()
.content(content)
.sequence(sequence)
.build()
)
.build()
)
response = self._client.cardkit.v1.card_element.content(request)
if not response.success():
logger.warning("Failed to stream-update card {}: code={}, msg={}", card_id, response.code, response.msg)
logger.warning(
"Failed to stream-update card {}: code={}, msg={}",
card_id,
response.code,
response.msg,
)
return False
return True
except Exception as e:
@ -1045,22 +1231,28 @@ class FeishuChannel(BaseChannel):
Sequence must strictly exceed the previous card OpenAPI operation on this entity.
"""
from lark_oapi.api.cardkit.v1 import SettingsCardRequest, SettingsCardRequestBody

settings_payload = json.dumps({"config": {"streaming_mode": False}}, ensure_ascii=False)
try:
request = SettingsCardRequest.builder() \
.card_id(card_id) \
request = (
SettingsCardRequest.builder()
.card_id(card_id)
.request_body(
SettingsCardRequestBody.builder()
.settings(settings_payload)
.sequence(sequence)
.uuid(str(uuid.uuid4()))
.build()
).build()
)
.build()
)
response = self._client.cardkit.v1.card.settings(request)
if not response.success():
logger.warning(
"Failed to close streaming on card {}: code={}, msg={}",
card_id, response.code, response.msg,
card_id,
response.code,
response.msg,
)
return False
return True
@ -1068,7 +1260,9 @@ class FeishuChannel(BaseChannel):
logger.warning("Error closing streaming on card {}: {}", card_id, e)
return False

async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
async def send_delta(
self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None
) -> None:
"""Progressive streaming via CardKit: create card on first delta, stream-update on subsequent."""
if not self._client:
return
@ -1087,17 +1281,31 @@ class FeishuChannel(BaseChannel):
if buf.card_id:
buf.sequence += 1
await loop.run_in_executor(
None, self._stream_update_text_sync, buf.card_id, buf.text, buf.sequence,
None,
self._stream_update_text_sync,
buf.card_id,
buf.text,
buf.sequence,
)
# Required so the chat list preview exits the streaming placeholder (Feishu streaming card docs).
buf.sequence += 1
await loop.run_in_executor(
None, self._close_streaming_mode_sync, buf.card_id, buf.sequence,
None,
self._close_streaming_mode_sync,
buf.card_id,
buf.sequence,
)
else:
for chunk in self._split_elements_by_table_limit(self._build_card_elements(buf.text)):
card = json.dumps({"config": {"wide_screen_mode": True}, "elements": chunk}, ensure_ascii=False)
await loop.run_in_executor(None, self._send_message_sync, rid_type, chat_id, "interactive", card)
for chunk in self._split_elements_by_table_limit(
self._build_card_elements(buf.text)
):
card = json.dumps(
{"config": {"wide_screen_mode": True}, "elements": chunk},
ensure_ascii=False,
)
await loop.run_in_executor(
None, self._send_message_sync, rid_type, chat_id, "interactive", card
)
return

# --- accumulate delta ---
@ -1111,15 +1319,21 @@ class FeishuChannel(BaseChannel):

now = time.monotonic()
if buf.card_id is None:
card_id = await loop.run_in_executor(None, self._create_streaming_card_sync, rid_type, chat_id)
card_id = await loop.run_in_executor(
None, self._create_streaming_card_sync, rid_type, chat_id
)
if card_id:
buf.card_id = card_id
buf.sequence = 1
await loop.run_in_executor(None, self._stream_update_text_sync, card_id, buf.text, 1)
await loop.run_in_executor(
None, self._stream_update_text_sync, card_id, buf.text, 1
)
buf.last_edit = now
elif (now - buf.last_edit) >= self._STREAM_EDIT_INTERVAL:
buf.sequence += 1
await loop.run_in_executor(None, self._stream_update_text_sync, buf.card_id, buf.text, buf.sequence)
await loop.run_in_executor(
None, self._stream_update_text_sync, buf.card_id, buf.text, buf.sequence
)
buf.last_edit = now

async def send(self, msg: OutboundMessage) -> None:
@ -1145,14 +1359,13 @@ class FeishuChannel(BaseChannel):
# Only the very first send (media or text) in this call uses reply; subsequent
# chunks/media fall back to plain create to avoid redundant quote bubbles.
reply_message_id: str | None = None
if (
self.config.reply_to_message
and not msg.metadata.get("_progress", False)
):
if self.config.reply_to_message and not msg.metadata.get("_progress", False):
reply_message_id = msg.metadata.get("message_id") or None
# For topic group messages, always reply to keep context in thread
elif msg.metadata.get("thread_id"):
reply_message_id = msg.metadata.get("root_id") or msg.metadata.get("message_id") or None
reply_message_id = (
msg.metadata.get("root_id") or msg.metadata.get("message_id") or None
)

first_send = True  # tracks whether the reply has already been used

@ -1176,8 +1389,10 @@ class FeishuChannel(BaseChannel):
key = await loop.run_in_executor(None, self._upload_image_sync, file_path)
if key:
await loop.run_in_executor(
None, _do_send,
"image", json.dumps({"image_key": key}, ensure_ascii=False),
None,
_do_send,
"image",
json.dumps({"image_key": key}, ensure_ascii=False),
)
else:
key = await loop.run_in_executor(None, self._upload_file_sync, file_path)
@ -1192,8 +1407,10 @@ class FeishuChannel(BaseChannel):
else:
media_type = "file"
await loop.run_in_executor(
None, _do_send,
media_type, json.dumps({"file_key": key}, ensure_ascii=False),
None,
_do_send,
media_type,
json.dumps({"file_key": key}, ensure_ascii=False),
)

if msg.content and msg.content.strip():
@ -1215,8 +1432,10 @@ class FeishuChannel(BaseChannel):
for chunk in self._split_elements_by_table_limit(elements):
card = {"config": {"wide_screen_mode": True}, "elements": chunk}
await loop.run_in_executor(
None, _do_send,
"interactive", json.dumps(card, ensure_ascii=False),
None,
_do_send,
"interactive",
json.dumps(card, ensure_ascii=False),
)

except Exception as e:
@ -1231,13 +1450,16 @@ class FeishuChannel(BaseChannel):
if self._loop and self._loop.is_running():
asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop)

async def _on_message(self, data: Any) -> None:
async def _on_message(self, data: P2ImMessageReceiveV1) -> None:
"""Handle incoming message from Feishu."""
try:
event = data.event
message = event.message
sender = event.sender

logger.debug("Feishu raw message: {}", message.content)
logger.debug("Feishu mentions: {}", getattr(message, "mentions", None))

# Deduplication check
message_id = message.message_id
if message_id in self._processed_message_ids:
@ -1276,6 +1498,8 @@ class FeishuChannel(BaseChannel):
if msg_type == "text":
text = content_json.get("text", "")
if text:
mentions = getattr(message, "mentions", None)
text = self._resolve_mentions(text, mentions)
content_parts.append(text)

elif msg_type == "post":
@ -1292,7 +1516,9 @@ class FeishuChannel(BaseChannel):
content_parts.append(content_text)

elif msg_type in ("image", "audio", "file", "media"):
file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id)
file_path, content_text = await self._download_and_save_media(
msg_type, content_json, message_id
)
if file_path:
media_paths.append(file_path)

@ -1303,7 +1529,14 @@ class FeishuChannel(BaseChannel):

content_parts.append(content_text)

elif msg_type in ("share_chat", "share_user", "interactive", "share_calendar_event", "system", "merge_forward"):
elif msg_type in (
"share_chat",
"share_user",
"interactive",
"share_calendar_event",
"system",
"merge_forward",
):
# Handle share cards and interactive messages
text = _extract_share_card_content(content_json, msg_type)
if text:
@ -1346,7 +1579,7 @@ class FeishuChannel(BaseChannel):
"parent_id": parent_id,
"root_id": root_id,
"thread_id": thread_id,
}
},
)

except Exception as e:
@ -1356,6 +1589,10 @@ class FeishuChannel(BaseChannel):
"""Ignore reaction events so they do not generate SDK noise."""
pass

def _on_reaction_deleted(self, data: Any) -> None:
"""Ignore reaction deleted events so they do not generate SDK noise."""
pass

def _on_message_read(self, data: Any) -> None:
"""Ignore read events so they do not generate SDK noise."""
pass
@ -1411,7 +1648,9 @@ class FeishuChannel(BaseChannel):

return "\n".join(part for part in parts if part)

async def _send_tool_hint_card(self, receive_id_type: str, receive_id: str, tool_hint: str) -> None:
async def _send_tool_hint_card(
self, receive_id_type: str, receive_id: str, tool_hint: str
) -> None:
"""Send tool hint as an interactive card with formatted code block.

Args:
@ -1427,15 +1666,15 @@ class FeishuChannel(BaseChannel):
card = {
"config": {"wide_screen_mode": True},
"elements": [
{
"tag": "markdown",
"content": f"**Tool Calls**\n\n```text\n{formatted_code}\n```"
}
]
{"tag": "markdown", "content": f"**Tool Calls**\n\n```text\n{formatted_code}\n```"}
],
}

await loop.run_in_executor(
None, self._send_message_sync,
receive_id_type, receive_id, "interactive",
None,
self._send_message_sync,
receive_id_type,
receive_id,
"interactive",
|
||||
json.dumps(card, ensure_ascii=False),
|
||||
)
|
||||
|
||||
@@ -39,7 +39,8 @@ class ChannelManager:
"""Initialize channels discovered via pkgutil scan + entry_points plugins."""
from nanobot.channels.registry import discover_all

groq_key = self.config.providers.groq.api_key
transcription_provider = self.config.channels.transcription_provider
transcription_key = self._resolve_transcription_key(transcription_provider)

for name, cls in discover_all().items():
section = getattr(self.config.channels, name, None)
@@ -54,7 +55,8 @@ class ChannelManager:
continue
try:
channel = cls(section, self.bus)
channel.transcription_api_key = groq_key
channel.transcription_provider = transcription_provider
channel.transcription_api_key = transcription_key
self.channels[name] = channel
logger.info("{} channel enabled", cls.display_name)
except Exception as e:
@@ -62,6 +64,15 @@ class ChannelManager:

self._validate_allow_from()

def _resolve_transcription_key(self, provider: str) -> str:
"""Pick the API key for the configured transcription provider."""
try:
if provider == "openai":
return self.config.providers.openai.api_key
return self.config.providers.groq.api_key
except AttributeError:
return ""

def _validate_allow_from(self) -> None:
for name, ch in self.channels.items():
if getattr(ch.config, "allow_from", None) == []:
@@ -1,6 +1,7 @@
"""Matrix (Element) channel — inbound sync + outbound message/media delivery."""

import asyncio
import json
import logging
import mimetypes
import time
@@ -17,10 +18,10 @@ try:
from nio import (
AsyncClient,
AsyncClientConfig,
ContentRepositoryConfigError,
DownloadError,
InviteEvent,
JoinError,
LoginResponse,
MatrixRoom,
MemoryDownloadResponse,
RoomEncryptedMedia,
@@ -203,10 +204,11 @@ class MatrixConfig(Base):

enabled: bool = False
homeserver: str = "https://matrix.org"
access_token: str = ""
user_id: str = ""
password: str = ""
access_token: str = ""
device_id: str = ""
e2ee_enabled: bool = True
e2ee_enabled: bool = Field(default=True, alias="e2eeEnabled")
sync_stop_grace_seconds: int = 2
max_media_bytes: int = 20 * 1024 * 1024
allow_from: list[str] = Field(default_factory=list)
@@ -256,17 +258,15 @@ class MatrixChannel(BaseChannel):
self._running = True
_configure_nio_logging_bridge()

store_path = get_data_dir() / "matrix-store"
store_path.mkdir(parents=True, exist_ok=True)
self.store_path = get_data_dir() / "matrix-store"
self.store_path.mkdir(parents=True, exist_ok=True)
self.session_path = self.store_path / "session.json"

self.client = AsyncClient(
homeserver=self.config.homeserver, user=self.config.user_id,
store_path=store_path,
store_path=self.store_path,
config=AsyncClientConfig(store_sync_tokens=True, encryption_enabled=self.config.e2ee_enabled),
)
self.client.user_id = self.config.user_id
self.client.access_token = self.config.access_token
self.client.device_id = self.config.device_id

self._register_event_callbacks()
self._register_response_callbacks()
@@ -274,13 +274,49 @@ class MatrixChannel(BaseChannel):
if not self.config.e2ee_enabled:
logger.warning("Matrix E2EE disabled; encrypted rooms may be undecryptable.")

if self.config.device_id:
if self.config.password:
if self.config.access_token or self.config.device_id:
logger.warning("Password-based Matrix login active; access_token and device_id fields will be ignored.")

create_new_session = True
if self.session_path.exists():
logger.info("Found session.json at {}; attempting to use existing session...", self.session_path)
try:
with open(self.session_path, "r", encoding="utf-8") as f:
session = json.load(f)
self.client.user_id = self.config.user_id
self.client.access_token = session["access_token"]
self.client.device_id = session["device_id"]
self.client.load_store()
logger.info("Successfully loaded from existing session")
create_new_session = False
except Exception as e:
logger.warning("Failed to load from existing session: {}", e)
logger.info("Falling back to password login...")

if create_new_session:
logger.info("Using password login...")
resp = await self.client.login(self.config.password)
if isinstance(resp, LoginResponse):
logger.info("Logged in using a password; saving details to disk")
self._write_session_to_disk(resp)
else:
logger.error("Failed to log in: {}", resp)
return

elif self.config.access_token and self.config.device_id:
try:
self.client.user_id = self.config.user_id
self.client.access_token = self.config.access_token
self.client.device_id = self.config.device_id
self.client.load_store()
except Exception:
logger.exception("Matrix store load failed; restart may replay recent messages.")
logger.info("Successfully loaded from existing session")
except Exception as e:
logger.warning("Failed to load from existing session: {}", e)

else:
logger.warning("Matrix device_id empty; restart may replay recent messages.")
logger.warning("Unable to load a Matrix session due to missing password, access_token, or device_id; encryption may not work")
return

self._sync_task = asyncio.create_task(self._sync_loop())

@@ -304,6 +340,19 @@ class MatrixChannel(BaseChannel):
if self.client:
await self.client.close()

def _write_session_to_disk(self, resp: LoginResponse) -> None:
"""Save login session to disk for persistence across restarts."""
session = {
"access_token": resp.access_token,
"device_id": resp.device_id,
}
try:
with open(self.session_path, "w", encoding="utf-8") as f:
json.dump(session, f, indent=2)
logger.info("Session saved to {}", self.session_path)
except Exception as e:
logger.warning("Failed to save session: {}", e)

def _is_workspace_path_allowed(self, path: Path) -> bool:
"""Check path is inside workspace (when restriction enabled)."""
if not self._restrict_to_workspace or not self._workspace:
@@ -6,14 +6,14 @@ import asyncio
import re
import time
import unicodedata
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import Any, Literal

from loguru import logger
from pydantic import Field
from telegram import BotCommand, ReactionTypeEmoji, ReplyParameters, Update
from telegram.error import BadRequest, NetworkError, TimedOut
from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
from telegram.ext import Application, ContextTypes, MessageHandler, filters
from telegram.request import HTTPXRequest

from nanobot.bus.events import OutboundMessage
@@ -558,8 +558,10 @@ class TelegramChannel(BaseChannel):
await self._remove_reaction(chat_id, int(reply_to_message_id))
except ValueError:
pass
chunks = split_message(buf.text, TELEGRAM_MAX_MESSAGE_LEN)
primary_text = chunks[0] if chunks else buf.text
try:
html = _markdown_to_telegram_html(buf.text)
html = _markdown_to_telegram_html(primary_text)
await self._call_with_retry(
self._app.bot.edit_message_text,
chat_id=int_chat_id, message_id=buf.message_id,
@@ -575,15 +577,18 @@ class TelegramChannel(BaseChannel):
await self._call_with_retry(
self._app.bot.edit_message_text,
chat_id=int_chat_id, message_id=buf.message_id,
text=buf.text,
text=primary_text,
)
except Exception as e2:
if self._is_not_modified_error(e2):
logger.debug("Final stream plain edit already applied for {}", chat_id)
self._stream_bufs.pop(chat_id, None)
return
logger.warning("Final stream edit failed: {}", e2)
raise # Let ChannelManager handle retry
else:
logger.warning("Final stream edit failed: {}", e2)
raise # Let ChannelManager handle retry
# If final content exceeds Telegram limit, keep the first chunk in
# the edited stream message and send the rest as follow-up messages.
for extra_chunk in chunks[1:]:
await self._send_text(int_chat_id, extra_chunk)
self._stream_bufs.pop(chat_id, None)
return

@@ -599,11 +604,15 @@ class TelegramChannel(BaseChannel):
return

now = time.monotonic()
thread_kwargs = {}
if message_thread_id := meta.get("message_thread_id"):
thread_kwargs["message_thread_id"] = message_thread_id
if buf.message_id is None:
try:
sent = await self._call_with_retry(
self._app.bot.send_message,
chat_id=int_chat_id, text=buf.text,
**thread_kwargs,
)
buf.message_id = sent.message_id
buf.last_edit = now
@@ -651,9 +660,9 @@ class TelegramChannel(BaseChannel):

@staticmethod
def _derive_topic_session_key(message) -> str | None:
"""Derive topic-scoped session key for non-private Telegram chats."""
"""Derive topic-scoped session key for Telegram chats with threads."""
message_thread_id = getattr(message, "message_thread_id", None)
if message.chat.type == "private" or message_thread_id is None:
if message_thread_id is None:
return None
return f"telegram:{message.chat_id}:topic:{message_thread_id}"

@@ -815,7 +824,7 @@ class TelegramChannel(BaseChannel):
return bool(bot_id and reply_user and reply_user.id == bot_id)

def _remember_thread_context(self, message) -> None:
"""Cache topic thread id by chat/message id for follow-up replies."""
"""Cache Telegram thread context by chat/message id for follow-up replies."""
message_thread_id = getattr(message, "message_thread_id", None)
if message_thread_id is None:
return
@@ -484,7 +484,7 @@ class WeixinChannel(BaseChannel):
except httpx.TimeoutException:
# Normal for long-poll, just retry
continue
except Exception as e:
except Exception:
if not self._running:
break
consecutive_failures += 1
@@ -75,6 +75,7 @@ class WhatsAppChannel(BaseChannel):
self._ws = None
self._connected = False
self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
self._lid_to_phone: dict[str, str] = {}
self._bridge_token: str | None = None

def _effective_bridge_token(self) -> str:
@@ -228,21 +229,45 @@ class WhatsAppChannel(BaseChannel):
if not was_mentioned:
return

user_id = pn if pn else sender
sender_id = user_id.split("@")[0] if "@" in user_id else user_id
logger.info("Sender {}", sender)
# Classify by JID suffix: @s.whatsapp.net = phone, @lid.whatsapp.net = LID
# The bridge's pn/sender fields don't consistently map to phone/LID across versions.
raw_a = pn or ""
raw_b = sender or ""
id_a = raw_a.split("@")[0] if "@" in raw_a else raw_a
id_b = raw_b.split("@")[0] if "@" in raw_b else raw_b

# Handle voice transcription if it's a voice message
if content == "[Voice Message]":
logger.info(
"Voice message received from {}, but direct download from bridge is not yet supported.",
sender_id,
)
content = "[Voice Message: Transcription not available for WhatsApp yet]"
phone_id = ""
lid_id = ""
for raw, extracted in [(raw_a, id_a), (raw_b, id_b)]:
if "@s.whatsapp.net" in raw:
phone_id = extracted
elif "@lid.whatsapp.net" in raw:
lid_id = extracted
elif extracted and not phone_id:
phone_id = extracted # best guess for bare values

if phone_id and lid_id:
self._lid_to_phone[lid_id] = phone_id
sender_id = phone_id or self._lid_to_phone.get(lid_id, "") or lid_id or id_a or id_b

logger.info("Sender phone={} lid={} → sender_id={}", phone_id or "(empty)", lid_id or "(empty)", sender_id)

# Extract media paths (images/documents/videos downloaded by the bridge)
media_paths = data.get("media") or []

# Handle voice transcription if it's a voice message
if content == "[Voice Message]":
if media_paths:
logger.info("Transcribing voice message from {}...", sender_id)
transcription = await self.transcribe_audio(media_paths[0])
if transcription:
content = transcription
logger.info("Transcribed voice from {}: {}...", sender_id, transcription[:50])
else:
content = "[Voice Message: Transcription failed]"
else:
content = "[Voice Message: Audio not available]"

# Build content tags matching Telegram's pattern: [image: /path] or [file: /path]
if media_paths:
for p in media_paths:
@@ -1,12 +1,11 @@
"""CLI commands for nanobot."""

import asyncio
from contextlib import contextmanager, nullcontext

import os
import select
import signal
import sys
from contextlib import nullcontext
from pathlib import Path
from typing import Any

@@ -34,6 +33,19 @@ from rich.table import Table
from rich.text import Text

from nanobot import __logo__, __version__

class SafeFileHistory(FileHistory):
"""FileHistory subclass that sanitizes surrogate characters on write.

On Windows, special Unicode input (emoji, mixed-script) can produce
surrogate characters that crash prompt_toolkit's file write.
See issue #2846.
"""

def store_string(self, string: str) -> None:
safe = string.encode("utf-8", errors="surrogateescape").decode("utf-8", errors="replace")
super().store_string(safe)
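The sanitizing round-trip used by `SafeFileHistory.store_string` can be demonstrated without prompt_toolkit. A lone low surrogate (the kind Windows console input can inject) survives the `surrogateescape` encode as a raw byte, and the `errors="replace"` decode then turns it into U+FFFD, so the history write never raises:

```python
# Sketch of the sanitization round-trip from SafeFileHistory.store_string:
# encode with surrogateescape (lone \udc80-\udcff surrogates become raw bytes),
# then decode with errors="replace" so invalid bytes become U+FFFD.
def sanitize(string: str) -> str:
    return string.encode("utf-8", errors="surrogateescape").decode("utf-8", errors="replace")


broken = "hi \udcff there"   # lone low surrogate; plain .encode("utf-8") would raise
clean = sanitize(broken)
print(clean)                 # surrogate replaced with U+FFFD, rest intact
clean.encode("utf-8")        # now encodes without raising
```

Note this handles the `\udc80`–`\udcff` range that `surrogateescape` round-trips; it is a sketch of the failure mode from the issue, not a claim about every possible malformed input.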
from nanobot.cli.stream import StreamRenderer, ThinkingSpinner
from nanobot.config.paths import get_workspace_path, is_default_workspace
from nanobot.config.schema import Config
@@ -73,6 +85,7 @@ def _flush_pending_tty_input() -> None:

try:
import termios

termios.tcflush(fd, termios.TCIFLUSH)
return
except Exception:
@@ -95,6 +108,7 @@ def _restore_terminal() -> None:
return
try:
import termios

termios.tcsetattr(sys.stdin.fileno(), termios.TCSADRAIN, _SAVED_TERM_ATTRS)
except Exception:
pass
@@ -107,6 +121,7 @@ def _init_prompt_session() -> None:
# Save terminal state so we can restore it on exit
try:
import termios

_SAVED_TERM_ATTRS = termios.tcgetattr(sys.stdin.fileno())
except Exception:
pass
@@ -117,9 +132,9 @@ def _init_prompt_session() -> None:
history_file.parent.mkdir(parents=True, exist_ok=True)

_PROMPT_SESSION = PromptSession(
history=FileHistory(str(history_file)),
history=SafeFileHistory(str(history_file)),
enable_open_in_editor=False,
multiline=False, # Enter submits (single line mode)
multiline=False, # Enter submits (single line mode)
)

@@ -231,7 +246,6 @@ async def _read_interactive_input_async() -> str:
raise KeyboardInterrupt from exc

def version_callback(value: bool):
if value:
console.print(f"{__logo__} nanobot v{__version__}")
@@ -281,8 +295,12 @@ def onboard(
config = _apply_workspace_override(load_config(config_path))
else:
console.print(f"[yellow]Config already exists at {config_path}[/yellow]")
console.print(" [bold]y[/bold] = overwrite with defaults (existing values will be lost)")
console.print(" [bold]N[/bold] = refresh config, keeping existing values and adding new fields")
console.print(
" [bold]y[/bold] = overwrite with defaults (existing values will be lost)"
)
console.print(
" [bold]N[/bold] = refresh config, keeping existing values and adding new fields"
)
if typer.confirm("Overwrite?"):
config = _apply_workspace_override(Config())
save_config(config, config_path)
@@ -290,7 +308,9 @@ def onboard(
else:
config = _apply_workspace_override(load_config(config_path))
save_config(config, config_path)
console.print(f"[green]✓[/green] Config refreshed at {config_path} (existing values preserved)")
console.print(
f"[green]✓[/green] Config refreshed at {config_path} (existing values preserved)"
)
else:
config = _apply_workspace_override(Config())
# In wizard mode, don't save yet - the wizard will handle saving if should_save=True
@@ -340,7 +360,9 @@ def onboard(
console.print(f" 1. Add your API key to [cyan]{config_path}[/cyan]")
console.print(" Get one at: https://openrouter.ai/keys")
console.print(f" 2. Chat: [cyan]{agent_cmd}[/cyan]")
console.print("\n[dim]Want Telegram/WhatsApp? See: https://github.com/HKUDS/nanobot#-chat-apps[/dim]")
console.print(
"\n[dim]Want Telegram/WhatsApp? See: https://github.com/HKUDS/nanobot#-chat-apps[/dim]"
)
def _merge_missing_defaults(existing: Any, defaults: Any) -> Any:
@@ -413,9 +435,11 @@ def _make_provider(config: Config):
# --- instantiation by backend ---
if backend == "openai_codex":
from nanobot.providers.openai_codex_provider import OpenAICodexProvider

provider = OpenAICodexProvider(default_model=model)
elif backend == "azure_openai":
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider

provider = AzureOpenAIProvider(
api_key=p.api_key,
api_base=p.api_base,
@@ -426,6 +450,7 @@ def _make_provider(config: Config):
provider = GitHubCopilotProvider(default_model=model)
elif backend == "anthropic":
from nanobot.providers.anthropic_provider import AnthropicProvider

provider = AnthropicProvider(
api_key=p.api_key if p else None,
api_base=config.get_api_base(model),
@@ -434,6 +459,7 @@ def _make_provider(config: Config):
)
else:
from nanobot.providers.openai_compat_provider import OpenAICompatProvider

provider = OpenAICompatProvider(
api_key=p.api_key if p else None,
api_base=config.get_api_base(model),
@@ -453,7 +479,7 @@ def _make_provider(config: Config):

def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
"""Load config and optionally override the active workspace."""
from nanobot.config.loader import load_config, set_config_path
from nanobot.config.loader import load_config, resolve_config_env_vars, set_config_path

config_path = None
if config:
@@ -464,7 +490,11 @@ def _load_runtime_config(config: str | None = None, workspace: str | None = None
set_config_path(config_path)
console.print(f"[dim]Using config: {config_path}[/dim]")

loaded = load_config(config_path)
try:
loaded = resolve_config_env_vars(load_config(config_path))
except ValueError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1)
_warn_deprecated_config_keys(config_path)
if workspace:
loaded.agents.defaults.workspace = workspace
@@ -474,6 +504,7 @@ def _load_runtime_config(config: str | None = None, workspace: str | None = None
def _warn_deprecated_config_keys(config_path: Path | None) -> None:
"""Hint users to remove obsolete keys from their config file."""
import json

from nanobot.config.loader import get_config_path

path = config_path or get_config_path()
@@ -497,6 +528,7 @@ def _migrate_cron_store(config: "Config") -> None:
if legacy_path.is_file() and not new_path.exists():
new_path.parent.mkdir(parents=True, exist_ok=True)
import shutil

shutil.move(str(legacy_path), str(new_path))

@@ -610,6 +642,7 @@ def gateway(

if verbose:
import logging

logging.basicConfig(level=logging.DEBUG)

config = _load_runtime_config(config, workspace)
@@ -695,7 +728,7 @@ def gateway(

if job.payload.deliver and job.payload.to and response:
should_notify = await evaluate_response(
response, job.payload.message, provider, agent.model,
response, reminder_note, provider, agent.model,
)
if should_notify:
from nanobot.bus.events import OutboundMessage
@@ -705,6 +738,7 @@ def gateway(
content=response,
))
return response

cron.on_job = on_cron_job

# Create channel manager
@@ -808,6 +842,7 @@ def gateway(
console.print("\nShutting down...")
except Exception:
import traceback

console.print("\n[red]Error: Gateway crashed unexpectedly[/red]")
console.print(traceback.format_exc())
finally:
@@ -820,8 +855,6 @@ def gateway(
asyncio.run(run())

# ============================================================================
# Agent Commands
# ============================================================================
@@ -1296,6 +1329,7 @@ def _register_login(name: str):
def decorator(fn):
_LOGIN_HANDLERS[name] = fn
return fn

return decorator

@@ -1326,6 +1360,7 @@ def provider_login(
def _login_openai_codex() -> None:
try:
from oauth_cli_kit import get_token, login_oauth_interactive

token = None
try:
token = get_token()
@@ -60,6 +60,20 @@ async def cmd_status(ctx: CommandContext) -> OutboundMessage:
pass
if ctx_est <= 0:
ctx_est = loop._last_usage.get("prompt_tokens", 0)

# Fetch web search provider usage (best-effort, never blocks the response)
search_usage_text: str | None = None
try:
from nanobot.utils.searchusage import fetch_search_usage
web_cfg = getattr(loop, "web_config", None)
search_cfg = getattr(web_cfg, "search", None) if web_cfg else None
if search_cfg is not None:
provider = getattr(search_cfg, "provider", "duckduckgo")
api_key = getattr(search_cfg, "api_key", "") or None
usage = await fetch_search_usage(provider=provider, api_key=api_key)
search_usage_text = usage.format()
except Exception:
pass # Never let usage fetch break /status
return OutboundMessage(
channel=ctx.msg.channel,
chat_id=ctx.msg.chat_id,
@@ -69,6 +83,7 @@ async def cmd_status(ctx: CommandContext) -> OutboundMessage:
context_window_tokens=loop.context_window_tokens,
session_msg_count=len(session.get_history(max_messages=0)),
context_tokens_estimate=ctx_est,
search_usage_text=search_usage_text,
),
metadata={**dict(ctx.msg.metadata or {}), "render_as": "text"},
)
@@ -93,14 +108,30 @@ async def cmd_new(ctx: CommandContext) -> OutboundMessage:

async def cmd_dream(ctx: CommandContext) -> OutboundMessage:
"""Manually trigger a Dream consolidation run."""
import time

loop = ctx.loop
try:
did_work = await loop.dream.run()
content = "Dream completed." if did_work else "Dream: nothing to process."
except Exception as e:
content = f"Dream failed: {e}"
msg = ctx.msg

async def _run_dream():
t0 = time.monotonic()
try:
did_work = await loop.dream.run()
elapsed = time.monotonic() - t0
if did_work:
content = f"Dream completed in {elapsed:.1f}s."
else:
content = "Dream: nothing to process."
except Exception as e:
elapsed = time.monotonic() - t0
content = f"Dream failed after {elapsed:.1f}s: {e}"
await loop.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, content=content,
))

asyncio.create_task(_run_dream())
return OutboundMessage(
channel=ctx.msg.channel, chat_id=ctx.msg.chat_id, content=content,
channel=msg.channel, chat_id=msg.chat_id, content="Dreaming...",
)
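The reworked `/dream` handler replies immediately and lets a fire-and-forget task publish the real result later. The pattern can be sketched generically; all names here (`slow_job`, `handle_command`, the `results` list standing in for the outbound bus) are illustrative, not nanobot's API.

```python
import asyncio
import time

# Generic sketch of the fire-and-forget pattern used by cmd_dream:
# acknowledge right away, report the real outcome from a background task.
results: list[str] = []  # stands in for publishing to the outbound bus


async def slow_job() -> bool:
    await asyncio.sleep(0.05)  # stands in for loop.dream.run()
    return True


async def handle_command() -> str:
    async def _run() -> None:
        t0 = time.monotonic()
        try:
            did_work = await slow_job()
            elapsed = time.monotonic() - t0
            results.append(
                f"Dream completed in {elapsed:.1f}s." if did_work else "Dream: nothing to process."
            )
        except Exception as e:
            results.append(f"Dream failed: {e}")

    asyncio.create_task(_run())  # deliberately not awaited: command returns now
    return "Dreaming..."


async def main() -> None:
    ack = await handle_command()
    print(ack)
    await asyncio.sleep(0.2)  # give the background task time to finish
    print(results[0])


asyncio.run(main())
```

The design choice this mirrors: the original handler blocked the command response on the whole consolidation run; moving the work into a task keeps the channel responsive while still delivering a completion (or failure) message with timing.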
@@ -1,6 +1,8 @@
"""Configuration loading utilities."""

import json
import os
import re
from pathlib import Path

import pydantic
@@ -76,6 +78,38 @@ def save_config(config: Config, config_path: Path | None = None) -> None:
json.dump(data, f, indent=2, ensure_ascii=False)

def resolve_config_env_vars(config: Config) -> Config:
"""Return a copy of *config* with ``${VAR}`` env-var references resolved.

Only string values are affected; other types pass through unchanged.
Raises :class:`ValueError` if a referenced variable is not set.
"""
data = config.model_dump(mode="json", by_alias=True)
data = _resolve_env_vars(data)
return Config.model_validate(data)

def _resolve_env_vars(obj: object) -> object:
"""Recursively resolve ``${VAR}`` patterns in string values."""
if isinstance(obj, str):
return re.sub(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}", _env_replace, obj)
if isinstance(obj, dict):
return {k: _resolve_env_vars(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_resolve_env_vars(v) for v in obj]
return obj

def _env_replace(match: re.Match[str]) -> str:
name = match.group(1)
value = os.environ.get(name)
if value is None:
raise ValueError(
f"Environment variable '{name}' referenced in config is not set"
)
return value
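The new resolver walks the dumped config and substitutes `${VAR}` references in string values only, raising when a variable is missing. The behavior can be sketched without the `Config` model; the plain dict below stands in for `config.model_dump()` output, and `resolve` is an illustrative name for the recursive walk:

```python
import os
import re

# Standalone sketch of the ${VAR} resolution above; the dict stands in
# for the dumped Config model, and `resolve` mirrors _resolve_env_vars.
_PATTERN = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")


def _env_replace(match: re.Match) -> str:
    name = match.group(1)
    value = os.environ.get(name)
    if value is None:
        raise ValueError(f"Environment variable '{name}' referenced in config is not set")
    return value


def resolve(obj):
    if isinstance(obj, str):
        return _PATTERN.sub(_env_replace, obj)  # substitute inside strings only
    if isinstance(obj, dict):
        return {k: resolve(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [resolve(v) for v in obj]
    return obj  # ints, bools, None pass through unchanged


os.environ["NANOBOT_KEY"] = "sk-test"
cfg = {"providers": {"openai": {"apiKey": "${NANOBOT_KEY}", "maxTokens": 4096}}}
print(resolve(cfg))
```

Because substitution happens inside a larger string (`re.sub`, not whole-value replacement), values like `"Bearer ${NANOBOT_KEY}"` also resolve, while non-string leaves are returned untouched.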
def _migrate_config(data: dict) -> dict:
"""Migrate old config formats to current."""
# Move tools.exec.restrictToWorkspace → tools.restrictToWorkspace
@@ -28,6 +28,7 @@ class ChannelsConfig(Base):
send_progress: bool = True # stream agent's text progress to the channel
send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…"))
send_max_retries: int = Field(default=3, ge=0, le=10) # Max delivery attempts (initial send included)
transcription_provider: str = "groq" # Voice transcription backend: "groq" or "openai"

class DreamConfig(Base):
@@ -155,6 +156,7 @@ class WebSearchConfig(Base):
api_key: str = ""
base_url: str = "" # SearXNG base URL
max_results: int = 5
timeout: int = 30 # Wall-clock timeout (seconds) for search operations

class WebToolsConfig(Base):
@@ -173,6 +175,7 @@ class ExecToolConfig(Base):
enable: bool = True
timeout: int = 60
path_append: str = ""
sandbox: str = "" # sandbox backend: "" (none) or "bwrap"

class MCPServerConfig(Base):
"""MCP server connection configuration (stdio or HTTP)."""
@@ -191,7 +194,7 @@ class ToolsConfig(Base):

web: WebToolsConfig = Field(default_factory=WebToolsConfig)
exec: ExecToolConfig = Field(default_factory=ExecToolConfig)
restrict_to_workspace: bool = False # If true, restrict all tool access to workspace directory
restrict_to_workspace: bool = False # restrict all tool access to workspace directory
mcp_servers: dict[str, MCPServerConfig] = Field(default_factory=dict)
ssrf_whitelist: list[str] = Field(default_factory=list) # CIDR ranges to exempt from SSRF blocking (e.g. ["100.64.0.0/10"] for Tailscale)

@@ -47,7 +47,7 @@ class Nanobot:
``~/.nanobot/config.json``.
workspace: Override the workspace directory from config.
"""
from nanobot.config.loader import load_config
from nanobot.config.loader import load_config, resolve_config_env_vars
from nanobot.config.schema import Config

resolved: Path | None = None
@@ -56,7 +56,7 @@ class Nanobot:
if not resolved.exists():
raise FileNotFoundError(f"Config not found: {resolved}")

config: Config = load_config(resolved)
config: Config = resolve_config_env_vars(load_config(resolved))
if workspace is not None:
config.agents.defaults.workspace = str(
Path(workspace).expanduser().resolve()
@ -52,6 +52,62 @@ class AnthropicProvider(LLMProvider):
            client_kw["max_retries"] = 0
        self._client = AsyncAnthropic(**client_kw)

    @classmethod
    def _handle_error(cls, e: Exception) -> LLMResponse:
        response = getattr(e, "response", None)
        headers = getattr(response, "headers", None)
        payload = (
            getattr(e, "body", None)
            or getattr(e, "doc", None)
            or getattr(response, "text", None)
        )
        if payload is None and response is not None:
            response_json = getattr(response, "json", None)
            if callable(response_json):
                try:
                    payload = response_json()
                except Exception:
                    payload = None
        payload_text = payload if isinstance(payload, str) else str(payload) if payload is not None else ""
        msg = f"Error: {payload_text.strip()[:500]}" if payload_text.strip() else f"Error calling LLM: {e}"
        retry_after = cls._extract_retry_after_from_headers(headers)
        if retry_after is None:
            retry_after = LLMProvider._extract_retry_after(msg)

        status_code = getattr(e, "status_code", None)
        if status_code is None and response is not None:
            status_code = getattr(response, "status_code", None)

        should_retry: bool | None = None
        if headers is not None:
            raw = headers.get("x-should-retry")
            if isinstance(raw, str):
                lowered = raw.strip().lower()
                if lowered == "true":
                    should_retry = True
                elif lowered == "false":
                    should_retry = False

        error_kind: str | None = None
        error_name = e.__class__.__name__.lower()
        if "timeout" in error_name:
            error_kind = "timeout"
        elif "connection" in error_name:
            error_kind = "connection"
        error_type, error_code = LLMProvider._extract_error_type_code(payload)

        return LLMResponse(
            content=msg,
            finish_reason="error",
            retry_after=retry_after,
            error_status_code=int(status_code) if status_code is not None else None,
            error_kind=error_kind,
            error_type=error_type,
            error_code=error_code,
            error_retry_after_s=retry_after,
            error_should_retry=should_retry,
        )

    @staticmethod
    def _strip_prefix(model: str) -> str:
        if model.startswith("anthropic/"):
@ -404,15 +460,6 @@ class AnthropicProvider(LLMProvider):
    # Public API
    # ------------------------------------------------------------------

    @staticmethod
    def _handle_error(e: Exception) -> LLMResponse:
        msg = f"Error calling LLM: {e}"
        response = getattr(e, "response", None)
        retry_after = LLMProvider._extract_retry_after_from_headers(getattr(response, "headers", None))
        if retry_after is None:
            retry_after = LLMProvider._extract_retry_after(msg)
        return LLMResponse(content=msg, finish_reason="error", retry_after=retry_after)

    async def chat(
        self,
        messages: list[dict[str, Any]],
@ -474,6 +521,7 @@ class AnthropicProvider(LLMProvider):
                    f"{idle_timeout_s} seconds"
                ),
                finish_reason="error",
                error_kind="timeout",
            )
        except Exception as e:
            return self._handle_error(e)

@ -54,6 +54,13 @@ class LLMResponse:
    retry_after: float | None = None  # Provider supplied retry wait in seconds.
    reasoning_content: str | None = None  # Kimi, DeepSeek-R1, MiMo etc.
    thinking_blocks: list[dict] | None = None  # Anthropic extended thinking
    # Structured error metadata used by retry policy when finish_reason == "error".
    error_status_code: int | None = None
    error_kind: str | None = None  # e.g. "timeout", "connection"
    error_type: str | None = None  # Provider/type semantic, e.g. insufficient_quota.
    error_code: str | None = None  # Provider/code semantic, e.g. rate_limit_exceeded.
    error_retry_after_s: float | None = None
    error_should_retry: bool | None = None

    @property
    def has_tool_calls(self) -> bool:
@ -91,6 +98,52 @@ class LLMProvider(ABC):
        "server error",
        "temporarily unavailable",
    )
    _RETRYABLE_STATUS_CODES = frozenset({408, 409, 429})
    _TRANSIENT_ERROR_KINDS = frozenset({"timeout", "connection"})
    _NON_RETRYABLE_429_ERROR_TOKENS = frozenset({
        "insufficient_quota",
        "quota_exceeded",
        "quota_exhausted",
        "billing_hard_limit_reached",
        "insufficient_balance",
        "credit_balance_too_low",
        "billing_not_active",
        "payment_required",
    })
    _RETRYABLE_429_ERROR_TOKENS = frozenset({
        "rate_limit_exceeded",
        "rate_limit_error",
        "too_many_requests",
        "request_limit_exceeded",
        "requests_limit_exceeded",
        "overloaded_error",
    })
    _NON_RETRYABLE_429_TEXT_MARKERS = (
        "insufficient_quota",
        "insufficient quota",
        "quota exceeded",
        "quota exhausted",
        "billing hard limit",
        "billing_hard_limit_reached",
        "billing not active",
        "insufficient balance",
        "insufficient_balance",
        "credit balance too low",
        "payment required",
        "out of credits",
        "out of quota",
        "exceeded your current quota",
    )
    _RETRYABLE_429_TEXT_MARKERS = (
        "rate limit",
        "rate_limit",
        "too many requests",
        "retry after",
        "try again in",
        "temporarily unavailable",
        "overloaded",
        "concurrency limit",
    )

    _SENTINEL = object()

@ -226,6 +279,80 @@ class LLMProvider(ABC):
        err = (content or "").lower()
        return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)

    @classmethod
    def _is_transient_response(cls, response: LLMResponse) -> bool:
        """Prefer structured error metadata, fallback to text markers for legacy providers."""
        if response.error_should_retry is not None:
            return bool(response.error_should_retry)

        if response.error_status_code is not None:
            status = int(response.error_status_code)
            if status == 429:
                return cls._is_retryable_429_response(response)
            if status in cls._RETRYABLE_STATUS_CODES or status >= 500:
                return True

        kind = (response.error_kind or "").strip().lower()
        if kind in cls._TRANSIENT_ERROR_KINDS:
            return True

        return cls._is_transient_error(response.content)

    @staticmethod
    def _normalize_error_token(value: Any) -> str | None:
        if value is None:
            return None
        token = str(value).strip().lower()
        return token or None

    @classmethod
    def _extract_error_type_code(cls, payload: Any) -> tuple[str | None, str | None]:
        data: dict[str, Any] | None = None
        if isinstance(payload, dict):
            data = payload
        elif isinstance(payload, str):
            text = payload.strip()
            if text:
                try:
                    parsed = json.loads(text)
                except Exception:
                    parsed = None
                if isinstance(parsed, dict):
                    data = parsed
        if not isinstance(data, dict):
            return None, None

        error_obj = data.get("error")
        type_value = data.get("type")
        code_value = data.get("code")
        if isinstance(error_obj, dict):
            type_value = error_obj.get("type") or type_value
            code_value = error_obj.get("code") or code_value

        return cls._normalize_error_token(type_value), cls._normalize_error_token(code_value)

    @classmethod
    def _is_retryable_429_response(cls, response: LLMResponse) -> bool:
        type_token = cls._normalize_error_token(response.error_type)
        code_token = cls._normalize_error_token(response.error_code)
        semantic_tokens = {
            token for token in (type_token, code_token)
            if token is not None
        }
        if any(token in cls._NON_RETRYABLE_429_ERROR_TOKENS for token in semantic_tokens):
            return False

        content = (response.content or "").lower()
        if any(marker in content for marker in cls._NON_RETRYABLE_429_TEXT_MARKERS):
            return False

        if any(token in cls._RETRYABLE_429_ERROR_TOKENS for token in semantic_tokens):
            return True
        if any(marker in content for marker in cls._RETRYABLE_429_TEXT_MARKERS):
            return True
        # Unknown 429 defaults to WAIT+retry.
        return True

    @staticmethod
    def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
        """Replace image_url blocks with text placeholder. Returns None if no images found."""
@ -397,14 +524,28 @@ class LLMProvider(ABC):
    def _extract_retry_after_from_headers(cls, headers: Any) -> float | None:
        if not headers:
            return None
        retry_after: Any = None
        if hasattr(headers, "get"):
            retry_after = headers.get("retry-after") or headers.get("Retry-After")
        if retry_after is None and isinstance(headers, dict):
            for key, value in headers.items():
                if isinstance(key, str) and key.lower() == "retry-after":
                    retry_after = value
                    break

        def _header_value(name: str) -> Any:
            if hasattr(headers, "get"):
                value = headers.get(name) or headers.get(name.title())
                if value is not None:
                    return value
            if isinstance(headers, dict):
                for key, value in headers.items():
                    if isinstance(key, str) and key.lower() == name.lower():
                        return value
            return None

        try:
            retry_ms = _header_value("retry-after-ms")
            if retry_ms is not None:
                value = float(retry_ms) / 1000.0
                if value > 0:
                    return value
        except (TypeError, ValueError):
            pass

        retry_after = _header_value("retry-after")
        if retry_after is None:
            return None
        retry_after_text = str(retry_after).strip()
@ -421,6 +562,14 @@ class LLMProvider(ABC):
        remaining = (retry_at - datetime.now(retry_at.tzinfo)).total_seconds()
        return max(0.1, remaining)

    @classmethod
    def _extract_retry_after_from_response(cls, response: LLMResponse) -> float | None:
        if response.error_retry_after_s is not None and response.error_retry_after_s > 0:
            return response.error_retry_after_s
        if response.retry_after is not None and response.retry_after > 0:
            return response.retry_after
        return cls._extract_retry_after(response.content)

    async def _sleep_with_heartbeat(
        self,
        delay: float,
@ -469,7 +618,7 @@ class LLMProvider(ABC):
                last_error_key = error_key
                identical_error_count = 1 if error_key else 0

            if not self._is_transient_error(response.content):
            if not self._is_transient_response(response):
                stripped = self._strip_image_content(original_messages)
                if stripped is not None and stripped != kw["messages"]:
                    logger.warning(
@ -492,7 +641,7 @@ class LLMProvider(ABC):
                break

            base_delay = delays[min(attempt - 1, len(delays) - 1)]
            delay = response.retry_after or self._extract_retry_after(response.content) or base_delay
            delay = self._extract_retry_after_from_response(response) or base_delay
            if persistent:
                delay = min(delay, self._PERSISTENT_MAX_DELAY)

@ -4,6 +4,7 @@ from __future__ import annotations

import asyncio
import hashlib
import importlib.util
import os
import secrets
import string
@ -12,7 +13,17 @@ from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any

import json_repair
from openai import AsyncOpenAI

if os.environ.get("LANGFUSE_SECRET_KEY") and importlib.util.find_spec("langfuse"):
    from langfuse.openai import AsyncOpenAI
else:
    if os.environ.get("LANGFUSE_SECRET_KEY"):
        import logging
        logging.getLogger(__name__).warning(
            "LANGFUSE_SECRET_KEY is set but langfuse is not installed; "
            "install with `pip install langfuse` to enable tracing"
        )
    from openai import AsyncOpenAI

from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest

@ -286,6 +297,24 @@ class OpenAICompatProvider(LLMProvider):
        if reasoning_effort:
            kwargs["reasoning_effort"] = reasoning_effort

        # Provider-specific thinking parameters.
        # Only sent when reasoning_effort is explicitly configured so that
        # the provider default is preserved otherwise.
        if spec and reasoning_effort is not None:
            thinking_enabled = reasoning_effort.lower() != "minimal"
            extra: dict[str, Any] | None = None
            if spec.name == "dashscope":
                extra = {"enable_thinking": thinking_enabled}
            elif spec.name in (
                "volcengine", "volcengine_coding_plan",
                "byteplus", "byteplus_coding_plan",
            ):
                extra = {
                    "thinking": {"type": "enabled" if thinking_enabled else "disabled"}
                }
            if extra:
                kwargs.setdefault("extra_body", {}).update(extra)

        if tools:
            kwargs["tools"] = tools
            kwargs["tool_choice"] = tool_choice or "auto"
@ -605,16 +634,73 @@ class OpenAICompatProvider(LLMProvider):
            reasoning_content="".join(reasoning_parts) or None,
        )

    @classmethod
    def _extract_error_metadata(cls, e: Exception) -> dict[str, Any]:
        response = getattr(e, "response", None)
        headers = getattr(response, "headers", None)
        payload = (
            getattr(e, "body", None)
            or getattr(e, "doc", None)
            or getattr(response, "text", None)
        )
        if payload is None and response is not None:
            response_json = getattr(response, "json", None)
            if callable(response_json):
                try:
                    payload = response_json()
                except Exception:
                    payload = None
        error_type, error_code = LLMProvider._extract_error_type_code(payload)

        status_code = getattr(e, "status_code", None)
        if status_code is None and response is not None:
            status_code = getattr(response, "status_code", None)

        should_retry: bool | None = None
        if headers is not None:
            raw = headers.get("x-should-retry")
            if isinstance(raw, str):
                lowered = raw.strip().lower()
                if lowered == "true":
                    should_retry = True
                elif lowered == "false":
                    should_retry = False

        error_kind: str | None = None
        error_name = e.__class__.__name__.lower()
        if "timeout" in error_name:
            error_kind = "timeout"
        elif "connection" in error_name:
            error_kind = "connection"

        return {
            "error_status_code": int(status_code) if status_code is not None else None,
            "error_kind": error_kind,
            "error_type": error_type,
            "error_code": error_code,
            "error_retry_after_s": cls._extract_retry_after_from_headers(headers),
            "error_should_retry": should_retry,
        }

    @staticmethod
    def _handle_error(e: Exception) -> LLMResponse:
        body = (
            getattr(e, "doc", None)
            or getattr(e, "body", None)
            or getattr(getattr(e, "response", None), "text", None)
        )
        body_text = body if isinstance(body, str) else str(body) if body is not None else ""
        msg = f"Error: {body_text.strip()[:500]}" if body_text.strip() else f"Error calling LLM: {e}"
        response = getattr(e, "response", None)
        body = getattr(e, "doc", None) or getattr(response, "text", None)
        body_text = str(body).strip() if body is not None else ""
        msg = f"Error: {body_text[:500]}" if body_text else f"Error calling LLM: {e}"
        retry_after = LLMProvider._extract_retry_after_from_headers(getattr(response, "headers", None))
        if retry_after is None:
            retry_after = LLMProvider._extract_retry_after(msg)
        return LLMResponse(content=msg, finish_reason="error", retry_after=retry_after)
        return LLMResponse(
            content=msg,
            finish_reason="error",
            retry_after=retry_after,
            **OpenAICompatProvider._extract_error_metadata(e),
        )

    # ------------------------------------------------------------------
    # Public API
@ -682,6 +768,7 @@ class OpenAICompatProvider(LLMProvider):
                    f"{idle_timeout_s} seconds"
                ),
                finish_reason="error",
                error_kind="timeout",
            )
        except Exception as e:
            return self._handle_error(e)

@ -1,4 +1,4 @@
"""Voice transcription provider using Groq."""
"""Voice transcription providers (Groq and OpenAI Whisper)."""

import os
from pathlib import Path
@ -7,6 +7,36 @@ import httpx
from loguru import logger


class OpenAITranscriptionProvider:
    """Voice transcription provider using OpenAI's Whisper API."""

    def __init__(self, api_key: str | None = None):
        self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
        self.api_url = "https://api.openai.com/v1/audio/transcriptions"

    async def transcribe(self, file_path: str | Path) -> str:
        if not self.api_key:
            logger.warning("OpenAI API key not configured for transcription")
            return ""
        path = Path(file_path)
        if not path.exists():
            logger.error("Audio file not found: {}", file_path)
            return ""
        try:
            async with httpx.AsyncClient() as client:
                with open(path, "rb") as f:
                    files = {"file": (path.name, f), "model": (None, "whisper-1")}
                    headers = {"Authorization": f"Bearer {self.api_key}"}
                    response = await client.post(
                        self.api_url, headers=headers, files=files, timeout=60.0,
                    )
                    response.raise_for_status()
                    return response.json().get("text", "")
        except Exception as e:
            logger.error("OpenAI transcription error: {}", e)
            return ""


class GroqTranscriptionProvider:
    """
    Voice transcription provider using Groq's Whisper API.

@ -1,7 +1,9 @@
{% if part == 'system' %}
You are a notification gate for a background agent. You will be given the original task and the agent's response. Call the evaluate_notification tool to decide whether the user should be notified.

Notify when the response contains actionable information, errors, completed deliverables, or anything the user explicitly asked to be reminded about.
Notify when the response contains actionable information, errors, completed deliverables, scheduled reminder/timer completions, or anything the user explicitly asked to be reminded about.

A user-scheduled reminder should usually notify even when the response is brief or mostly repeats the original reminder.

Suppress when the response is a routine status check with nothing new, a confirmation that everything is normal, or essentially empty.
{% elif part == 'user' %}

@ -1,5 +1,6 @@
"""Utility functions for nanobot."""

from nanobot.utils.helpers import ensure_dir
from nanobot.utils.path import abbreviate_path

__all__ = ["ensure_dir"]
__all__ = ["ensure_dir", "abbreviate_path"]

@ -396,8 +396,15 @@ def build_status_content(
    context_window_tokens: int,
    session_msg_count: int,
    context_tokens_estimate: int,
    search_usage_text: str | None = None,
) -> str:
    """Build a human-readable runtime status snapshot."""
    """Build a human-readable runtime status snapshot.

    Args:
        search_usage_text: Optional pre-formatted web search usage string
            (produced by SearchUsageInfo.format()). When provided
            it is appended as an extra section.
    """
    uptime_s = int(time.time() - start_time)
    uptime = (
        f"{uptime_s // 3600}h {(uptime_s % 3600) // 60}m"
@ -414,14 +421,17 @@ def build_status_content(
    token_line = f"\U0001f4ca Tokens: {last_in} in / {last_out} out"
    if cached and last_in:
        token_line += f" ({cached * 100 // last_in}% cached)"
    return "\n".join([
    lines = [
        f"\U0001f408 nanobot v{version}",
        f"\U0001f9e0 Model: {model}",
        token_line,
        f"\U0001f4da Context: {ctx_used_str}/{ctx_total_str} ({ctx_pct}%)",
        f"\U0001f4ac Session: {session_msg_count} messages",
        f"\u23f1 Uptime: {uptime}",
    ])
    ]
    if search_usage_text:
        lines.append(search_usage_text)
    return "\n".join(lines)


def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:

107
nanobot/utils/path.py
Normal file
@ -0,0 +1,107 @@
"""Path abbreviation utilities for display."""

from __future__ import annotations

import os
import re
from urllib.parse import urlparse


def abbreviate_path(path: str, max_len: int = 40) -> str:
    """Abbreviate a file path or URL, preserving basename and key directories.

    Strategy:
    1. Return as-is if short enough
    2. Replace home directory with ~/
    3. From right, keep basename + parent dirs until budget exhausted
    4. Prefix with …/
    """
    if not path:
        return path

    # Handle URLs: preserve scheme://domain + filename
    if re.match(r"https?://", path):
        return _abbreviate_url(path, max_len)

    # Normalize separators to /
    normalized = path.replace("\\", "/")

    # Replace home directory
    home = os.path.expanduser("~").replace("\\", "/")
    if normalized.startswith(home + "/"):
        normalized = "~" + normalized[len(home):]
    elif normalized == home:
        normalized = "~"

    # Return early only after normalization and home replacement
    if len(normalized) <= max_len:
        return normalized

    # Split into segments
    parts = normalized.rstrip("/").split("/")
    if len(parts) <= 1:
        return normalized[:max_len - 1] + "\u2026"

    # Always keep the basename
    basename = parts[-1]
    # Budget: max_len minus "…/" prefix (2 chars) minus "/" separator minus basename
    budget = max_len - len(basename) - 3  # -3 for "…/" + final "/"

    # Walk backwards from parent, collecting segments
    kept: list[str] = []
    for seg in reversed(parts[:-1]):
        needed = len(seg) + 1  # segment + "/"
        if not kept and needed <= budget:
            kept.append(seg)
            budget -= needed
        elif kept:
            needed_with_sep = len(seg) + 1
            if needed_with_sep <= budget:
                kept.append(seg)
                budget -= needed_with_sep
            else:
                break
        else:
            break

    kept.reverse()
    if kept:
        return "\u2026/" + "/".join(kept) + "/" + basename
    return "\u2026/" + basename


def _abbreviate_url(url: str, max_len: int = 40) -> str:
    """Abbreviate a URL keeping domain and filename."""
    if len(url) <= max_len:
        return url

    parsed = urlparse(url)
    domain = parsed.netloc  # e.g. "example.com"
    path_part = parsed.path  # e.g. "/api/v2/resource.json"

    # Extract filename from path
    segments = path_part.rstrip("/").split("/")
    basename = segments[-1] if segments else ""

    if not basename:
        # No filename, truncate URL
        return url[: max_len - 1] + "\u2026"

    budget = max_len - len(domain) - len(basename) - 4  # "…/" + "/"
    if budget < 0:
        trunc = max_len - len(domain) - 5  # "…/" + "/"
        return domain + "/\u2026/" + (basename[:trunc] if trunc > 0 else "")

    # Build abbreviated path
    kept: list[str] = []
    for seg in reversed(segments[:-1]):
        if len(seg) + 1 <= budget:
            kept.append(seg)
            budget -= len(seg) + 1
        else:
            break

    kept.reverse()
    if kept:
        return domain + "/\u2026/" + "/".join(kept) + "/" + basename
    return domain + "/\u2026/" + basename
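The new `abbreviate_path` keeps the basename unconditionally and walks parent directories right-to-left until the character budget runs out. A standalone condensation of that core strategy (URL and home-directory handling omitted), to show the budget arithmetic:

```python
# Standalone condensation of the right-to-left abbreviation strategy above;
# URL and home-directory handling are omitted.
def abbreviate(path: str, max_len: int = 40) -> str:
    if len(path) <= max_len:
        return path
    parts = path.rstrip("/").split("/")
    if len(parts) <= 1:
        return path[:max_len - 1] + "\u2026"
    basename = parts[-1]
    budget = max_len - len(basename) - 3  # room for the "…/" prefix and final "/"
    kept: list[str] = []
    for seg in reversed(parts[:-1]):
        if len(seg) + 1 <= budget:  # segment plus its "/" separator
            kept.append(seg)
            budget -= len(seg) + 1
        else:
            break
    kept.reverse()
    if kept:
        return "\u2026/" + "/".join(kept) + "/" + basename
    return "\u2026/" + basename
```

Because the walk starts at the immediate parent, the most specific directories survive and only the distant prefix collapses into "…/".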
@ -16,8 +16,7 @@ EMPTY_FINAL_RESPONSE_MESSAGE = (
)

FINALIZATION_RETRY_PROMPT = (
    "You have already finished the tool work. Do not call any more tools. "
    "Using only the conversation and tool results above, provide the final answer for the user now."
    "Please provide your response to the user based on the conversation above."
)

168
nanobot/utils/searchusage.py
Normal file
@ -0,0 +1,168 @@
"""Web search provider usage fetchers for /status command."""

from __future__ import annotations

import os
from dataclasses import dataclass
from typing import Any


@dataclass
class SearchUsageInfo:
    """Structured usage info returned by a provider fetcher."""

    provider: str
    supported: bool = False  # True if the provider has a usage API
    error: str | None = None  # Set when the API call failed

    # Usage counters (None = not available for this provider)
    used: int | None = None
    limit: int | None = None
    remaining: int | None = None
    reset_date: str | None = None  # ISO date string, e.g. "2026-05-01"

    # Tavily-specific breakdown
    search_used: int | None = None
    extract_used: int | None = None
    crawl_used: int | None = None

    def format(self) -> str:
        """Return a human-readable multi-line string for /status output."""
        lines = [f"🔍 Web Search: {self.provider}"]

        if not self.supported:
            lines.append(" Usage tracking: not available for this provider")
            return "\n".join(lines)

        if self.error:
            lines.append(f" Usage: unavailable ({self.error})")
            return "\n".join(lines)

        if self.used is not None and self.limit is not None:
            lines.append(f" Usage: {self.used} / {self.limit} requests")
        elif self.used is not None:
            lines.append(f" Usage: {self.used} requests")

        # Tavily breakdown
        breakdown_parts = []
        if self.search_used is not None:
            breakdown_parts.append(f"Search: {self.search_used}")
        if self.extract_used is not None:
            breakdown_parts.append(f"Extract: {self.extract_used}")
        if self.crawl_used is not None:
            breakdown_parts.append(f"Crawl: {self.crawl_used}")
        if breakdown_parts:
            lines.append(f" Breakdown: {' | '.join(breakdown_parts)}")

        if self.remaining is not None:
            lines.append(f" Remaining: {self.remaining} requests")

        if self.reset_date:
            lines.append(f" Resets: {self.reset_date}")

        return "\n".join(lines)


async def fetch_search_usage(
    provider: str,
    api_key: str | None = None,
) -> SearchUsageInfo:
    """
    Fetch usage info for the configured web search provider.

    Args:
        provider: Provider name (e.g. "tavily", "brave", "duckduckgo").
        api_key: API key for the provider (falls back to env vars).

    Returns:
        SearchUsageInfo with populated fields where available.
    """
    p = (provider or "duckduckgo").strip().lower()

    if p == "tavily":
        return await _fetch_tavily_usage(api_key)
    else:
        # brave, duckduckgo, searxng, jina, unknown — no usage API
        return SearchUsageInfo(provider=p, supported=False)


# ---------------------------------------------------------------------------
# Tavily
# ---------------------------------------------------------------------------

async def _fetch_tavily_usage(api_key: str | None) -> SearchUsageInfo:
    """Fetch usage from GET https://api.tavily.com/usage."""
    import httpx

    key = api_key or os.environ.get("TAVILY_API_KEY", "")
    if not key:
        return SearchUsageInfo(
            provider="tavily",
            supported=True,
            error="TAVILY_API_KEY not configured",
        )

    try:
        async with httpx.AsyncClient(timeout=8.0) as client:
            r = await client.get(
                "https://api.tavily.com/usage",
                headers={"Authorization": f"Bearer {key}"},
            )
            r.raise_for_status()
            data: dict[str, Any] = r.json()
            return _parse_tavily_usage(data)
    except httpx.HTTPStatusError as e:
        return SearchUsageInfo(
            provider="tavily",
            supported=True,
            error=f"HTTP {e.response.status_code}",
        )
    except Exception as e:
        return SearchUsageInfo(
            provider="tavily",
            supported=True,
            error=str(e)[:80],
        )


def _parse_tavily_usage(data: dict[str, Any]) -> SearchUsageInfo:
    """
    Parse Tavily /usage response.

    Actual API response shape:
    {
        "account": {
            "current_plan": "Researcher",
            "plan_usage": 20,
            "plan_limit": 1000,
            "search_usage": 20,
            "crawl_usage": 0,
            "extract_usage": 0,
            "map_usage": 0,
            "research_usage": 0,
            "paygo_usage": 0,
            "paygo_limit": null
        }
    }
    """
    account = data.get("account") or {}
    used = account.get("plan_usage")
    limit = account.get("plan_limit")

    # Compute remaining
    remaining = None
    if used is not None and limit is not None:
        remaining = max(0, limit - used)

    return SearchUsageInfo(
        provider="tavily",
        supported=True,
        used=used,
        limit=limit,
        remaining=remaining,
        search_used=account.get("search_usage"),
        extract_used=account.get("extract_usage"),
        crawl_used=account.get("crawl_usage"),
    )

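The parsing step above is easy to exercise against the documented response shape. A hedged re-sketch of `_parse_tavily_usage` that returns a plain dict instead of `SearchUsageInfo`, fed with a sample payload taken from the docstring:

```python
# Re-sketch of the Tavily /usage parsing above, returning a plain dict
# instead of SearchUsageInfo for brevity.
def parse_tavily_usage(data: dict) -> dict:
    account = data.get("account") or {}
    used = account.get("plan_usage")
    limit = account.get("plan_limit")
    remaining = None
    if used is not None and limit is not None:
        remaining = max(0, limit - used)  # clamp in case usage exceeds the plan
    return {
        "used": used,
        "limit": limit,
        "remaining": remaining,
        "search_used": account.get("search_usage"),
    }

# Sample payload matching the shape documented in the docstring above.
sample = {"account": {"plan_usage": 20, "plan_limit": 1000, "search_usage": 20}}
```

Guarding with `data.get("account") or {}` keeps a missing or null `account` key from raising, which matters since the error path already reports HTTP failures separately.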
119
nanobot/utils/tool_hints.py
Normal file
@ -0,0 +1,119 @@
"""Tool hint formatting for concise, human-readable tool call display."""

from __future__ import annotations

from nanobot.utils.path import abbreviate_path

# Registry: tool_name -> (key_args, template, is_path, is_command)
_TOOL_FORMATS: dict[str, tuple[list[str], str, bool, bool]] = {
    "read_file": (["path", "file_path"], "read {}", True, False),
    "write_file": (["path", "file_path"], "write {}", True, False),
    "edit": (["file_path", "path"], "edit {}", True, False),
    "glob": (["pattern"], 'glob "{}"', False, False),
    "grep": (["pattern"], 'grep "{}"', False, False),
    "exec": (["command"], "$ {}", False, True),
    "web_search": (["query"], 'search "{}"', False, False),
    "web_fetch": (["url"], "fetch {}", True, False),
    "list_dir": (["path"], "ls {}", True, False),
}


def format_tool_hints(tool_calls: list) -> str:
    """Format tool calls as concise hints with smart abbreviation."""
    if not tool_calls:
        return ""

    hints = []
    for name, count, example_tc in _group_consecutive(tool_calls):
        fmt = _TOOL_FORMATS.get(name)
        if fmt:
            hint = _fmt_known(example_tc, fmt)
        elif name.startswith("mcp_"):
            hint = _fmt_mcp(example_tc)
        else:
            hint = _fmt_fallback(example_tc)

        if count > 1:
            hint = f"{hint} \u00d7 {count}"
        hints.append(hint)

    return ", ".join(hints)


def _get_args(tc) -> dict:
    """Extract args dict from tc.arguments, handling list/dict/None/empty."""
    if tc.arguments is None:
        return {}
    if isinstance(tc.arguments, list):
        return tc.arguments[0] if tc.arguments else {}
    if isinstance(tc.arguments, dict):
        return tc.arguments
    return {}


def _group_consecutive(calls: list) -> list[tuple[str, int, object]]:
    """Group consecutive calls to the same tool: [(name, count, first), ...]."""
    groups: list[tuple[str, int, object]] = []
    for tc in calls:
        if groups and groups[-1][0] == tc.name:
            groups[-1] = (groups[-1][0], groups[-1][1] + 1, groups[-1][2])
        else:
            groups.append((tc.name, 1, tc))
    return groups

def _extract_arg(tc, key_args: list[str]) -> str | None:
|
||||
"""Extract the first available value from preferred key names."""
|
||||
args = _get_args(tc)
|
||||
if not isinstance(args, dict):
|
||||
return None
|
||||
for key in key_args:
|
||||
val = args.get(key)
|
||||
if isinstance(val, str) and val:
|
||||
return val
|
||||
for val in args.values():
|
||||
if isinstance(val, str) and val:
|
||||
return val
|
||||
return None
|
||||
|
||||
|
||||
def _fmt_known(tc, fmt: tuple) -> str:
|
||||
"""Format a registered tool using its template."""
|
||||
val = _extract_arg(tc, fmt[0])
|
||||
if val is None:
|
||||
return tc.name
|
||||
if fmt[2]: # is_path
|
||||
val = abbreviate_path(val)
|
||||
elif fmt[3]: # is_command
|
||||
val = val[:40] + "\u2026" if len(val) > 40 else val
|
||||
return fmt[1].format(val)
|
||||
|
||||
|
||||
def _fmt_mcp(tc) -> str:
|
||||
"""Format MCP tool as server::tool."""
|
||||
name = tc.name
|
||||
if "__" in name:
|
||||
parts = name.split("__", 1)
|
||||
server = parts[0].removeprefix("mcp_")
|
||||
tool = parts[1]
|
||||
else:
|
||||
rest = name.removeprefix("mcp_")
|
||||
parts = rest.split("_", 1)
|
||||
server = parts[0] if parts else rest
|
||||
tool = parts[1] if len(parts) > 1 else ""
|
||||
if not tool:
|
||||
return name
|
||||
args = _get_args(tc)
|
||||
val = next((v for v in args.values() if isinstance(v, str) and v), None)
|
||||
if val is None:
|
||||
return f"{server}::{tool}"
|
||||
return f'{server}::{tool}("{abbreviate_path(val, 40)}")'
|
||||
|
||||
|
||||
def _fmt_fallback(tc) -> str:
|
||||
"""Original formatting logic for unregistered tools."""
|
||||
args = _get_args(tc)
|
||||
val = next(iter(args.values()), None) if isinstance(args, dict) else None
|
||||
if not isinstance(val, str):
|
||||
return tc.name
|
||||
return f'{tc.name}("{abbreviate_path(val, 40)}")' if len(val) > 40 else f'{tc.name}("{val}")'
|
||||
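As a quick illustration of the folding behavior this file introduces, the consecutive-grouping step can be exercised in isolation. The sketch below re-implements `_group_consecutive` standalone, using a `SimpleNamespace` stand-in for `ToolCallRequest` (only `.name` is read by the grouping logic):

```python
from types import SimpleNamespace


def group_consecutive(calls):
    """Fold consecutive calls to the same tool into (name, count, first_call)."""
    groups = []
    for tc in calls:
        if groups and groups[-1][0] == tc.name:
            groups[-1] = (groups[-1][0], groups[-1][1] + 1, groups[-1][2])
        else:
            groups.append((tc.name, 1, tc))
    return groups


calls = [
    SimpleNamespace(name="read_file", arguments={"path": "a.py"}),
    SimpleNamespace(name="read_file", arguments={"path": "b.py"}),
    SimpleNamespace(name="grep", arguments={"pattern": "TODO"}),
]
print([(name, count) for name, count, _ in group_consecutive(calls)])
# → [('read_file', 2), ('grep', 1)]
```

Two consecutive `read_file` calls fold into one group with a count of 2, which `format_tool_hints` then renders as `read a.py × 2`.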
@ -1,6 +1,6 @@
[project]
name = "nanobot-ai"
version = "0.1.4.post6"
version = "0.1.5"
description = "A lightweight personal AI assistant framework"
readme = { file = "README.md", content-type = "text/markdown" }
requires-python = ">=3.11"
@ -458,6 +458,7 @@ async def test_runner_uses_raw_messages_when_context_governance_fails():


@pytest.mark.asyncio
async def test_runner_retries_empty_final_response_with_summary_prompt():
    """Empty responses get 2 silent retries before finalization kicks in."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
@ -465,11 +466,11 @@ async def test_runner_retries_empty_final_response_with_summary_prompt():

    async def chat_with_retry(*, messages, tools=None, **kwargs):
        calls.append({"messages": messages, "tools": tools})
        if len(calls) == 1:
        if len(calls) <= 2:
            return LLMResponse(
                content=None,
                tool_calls=[],
                usage={"prompt_tokens": 10, "completion_tokens": 1},
                usage={"prompt_tokens": 5, "completion_tokens": 1},
            )
        return LLMResponse(
            content="final answer",
@ -486,20 +487,23 @@ async def test_runner_retries_empty_final_response_with_summary_prompt():
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    ))

    assert result.final_content == "final answer"
    assert len(calls) == 2
    assert calls[1]["tools"] is None
    assert "Do not call any more tools" in calls[1]["messages"][-1]["content"]
    # 2 silent retries (iterations 0,1) + finalization on iteration 1
    assert len(calls) == 3
    assert calls[0]["tools"] is not None
    assert calls[1]["tools"] is not None
    assert calls[2]["tools"] is None
    assert result.usage["prompt_tokens"] == 13
    assert result.usage["completion_tokens"] == 8
    assert result.usage["completion_tokens"] == 9


@pytest.mark.asyncio
async def test_runner_uses_specific_message_after_empty_finalization_retry():
    """After silent retries + finalization all return empty, stop_reason is empty_final_response."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE

@ -517,7 +521,7 @@ async def test_runner_uses_specific_message_after_empty_finalization_retry():
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_iterations=3,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    ))

@ -525,6 +529,66 @@
    assert result.stop_reason == "empty_final_response"


@pytest.mark.asyncio
async def test_runner_empty_response_does_not_break_tool_chain():
    """An empty intermediate response must not kill an ongoing tool chain.

    Sequence: tool_call → empty → tool_call → final text.
    The runner should recover via silent retry and complete normally.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    call_count = 0

    async def chat_with_retry(*, messages, tools=None, **kwargs):
        nonlocal call_count
        call_count += 1
        if call_count == 1:
            return LLMResponse(
                content=None,
                tool_calls=[ToolCallRequest(id="tc1", name="read_file", arguments={"path": "a.txt"})],
                usage={"prompt_tokens": 10, "completion_tokens": 5},
            )
        if call_count == 2:
            return LLMResponse(content=None, tool_calls=[], usage={"prompt_tokens": 10, "completion_tokens": 1})
        if call_count == 3:
            return LLMResponse(
                content=None,
                tool_calls=[ToolCallRequest(id="tc2", name="read_file", arguments={"path": "b.txt"})],
                usage={"prompt_tokens": 10, "completion_tokens": 5},
            )
        return LLMResponse(
            content="Here are the results.",
            tool_calls=[],
            usage={"prompt_tokens": 10, "completion_tokens": 10},
        )

    provider.chat_with_retry = chat_with_retry
    provider.chat_stream_with_retry = chat_with_retry

    async def fake_tool(name, args, **kw):
        return "file content"

    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = [{"type": "function", "function": {"name": "read_file"}}]
    tool_registry.execute = AsyncMock(side_effect=fake_tool)

    runner = AgentRunner(provider)
    result = await runner.run(AgentRunSpec(
        initial_messages=[{"role": "user", "content": "read both files"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=10,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    ))

    assert result.final_content == "Here are the results."
    assert result.stop_reason == "completed"
    assert call_count == 4
    assert "read_file" in result.tools_used


def test_snip_history_drops_orphaned_tool_results_from_trimmed_slice(monkeypatch):
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

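The retry tests above pin down the policy: an empty response is retried silently (with tools still enabled) up to two times, then a finalization call is made with `tools=None` to force a text answer. A hedged sketch of that control flow, with illustrative names only (not the actual `AgentRunner` internals):

```python
import asyncio


async def run_with_empty_retry(chat, max_silent_attempts: int = 2):
    """Illustrative sketch: tolerate empty responses, then finalize without tools."""
    for _ in range(max_silent_attempts):
        resp = await chat(tools="enabled")  # normal call, tools available
        if resp:  # non-empty content: done
            return resp
    # Every attempt came back empty: one last call with tools disabled.
    return await chat(tools=None)


async def demo():
    calls = []

    async def fake_chat(tools):
        calls.append(tools)
        # Stay empty until the third call, mirroring the test sequence.
        return "final answer" if len(calls) == 3 else ""

    return await run_with_empty_retry(fake_chat), calls


result, calls = asyncio.run(demo())
print(result)  # → final answer
print(calls)   # → ['enabled', 'enabled', None]
```

This matches the assertions in the test: the first two calls still pass tools, and only the third call disables them.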
252
tests/agent/test_skills_loader.py
Normal file
@ -0,0 +1,252 @@
"""Tests for nanobot.agent.skills.SkillsLoader."""

from __future__ import annotations

import json
from pathlib import Path

import pytest

from nanobot.agent.skills import SkillsLoader


def _write_skill(
    base: Path,
    name: str,
    *,
    metadata_json: dict | None = None,
    body: str = "# Skill\n",
) -> Path:
    """Create ``base / name / SKILL.md`` with optional nanobot metadata JSON."""
    skill_dir = base / name
    skill_dir.mkdir(parents=True)
    lines = ["---"]
    if metadata_json is not None:
        payload = json.dumps({"nanobot": metadata_json}, separators=(",", ":"))
        lines.append(f'metadata: {payload}')
    lines.extend(["---", "", body])
    path = skill_dir / "SKILL.md"
    path.write_text("\n".join(lines), encoding="utf-8")
    return path


def test_list_skills_empty_when_skills_dir_missing(tmp_path: Path) -> None:
    workspace = tmp_path / "ws"
    workspace.mkdir()
    builtin = tmp_path / "builtin"
    builtin.mkdir()
    loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
    assert loader.list_skills(filter_unavailable=False) == []


def test_list_skills_empty_when_skills_dir_exists_but_empty(tmp_path: Path) -> None:
    workspace = tmp_path / "ws"
    (workspace / "skills").mkdir(parents=True)
    builtin = tmp_path / "builtin"
    builtin.mkdir()
    loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
    assert loader.list_skills(filter_unavailable=False) == []


def test_list_skills_workspace_entry_shape_and_source(tmp_path: Path) -> None:
    workspace = tmp_path / "ws"
    skills_root = workspace / "skills"
    skills_root.mkdir(parents=True)
    skill_path = _write_skill(skills_root, "alpha", body="# Alpha")
    builtin = tmp_path / "builtin"
    builtin.mkdir()

    loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
    entries = loader.list_skills(filter_unavailable=False)
    assert entries == [
        {"name": "alpha", "path": str(skill_path), "source": "workspace"},
    ]


def test_list_skills_skips_non_directories_and_missing_skill_md(tmp_path: Path) -> None:
    workspace = tmp_path / "ws"
    skills_root = workspace / "skills"
    skills_root.mkdir(parents=True)
    (skills_root / "not_a_dir.txt").write_text("x", encoding="utf-8")
    (skills_root / "no_skill_md").mkdir()
    ok_path = _write_skill(skills_root, "ok", body="# Ok")
    builtin = tmp_path / "builtin"
    builtin.mkdir()

    loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
    entries = loader.list_skills(filter_unavailable=False)
    names = {entry["name"] for entry in entries}
    assert names == {"ok"}
    assert entries[0]["path"] == str(ok_path)


def test_list_skills_workspace_shadows_builtin_same_name(tmp_path: Path) -> None:
    workspace = tmp_path / "ws"
    ws_skills = workspace / "skills"
    ws_skills.mkdir(parents=True)
    ws_path = _write_skill(ws_skills, "dup", body="# Workspace wins")

    builtin = tmp_path / "builtin"
    _write_skill(builtin, "dup", body="# Builtin")

    loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
    entries = loader.list_skills(filter_unavailable=False)
    assert len(entries) == 1
    assert entries[0]["source"] == "workspace"
    assert entries[0]["path"] == str(ws_path)


def test_list_skills_merges_workspace_and_builtin(tmp_path: Path) -> None:
    workspace = tmp_path / "ws"
    ws_skills = workspace / "skills"
    ws_skills.mkdir(parents=True)
    ws_path = _write_skill(ws_skills, "ws_only", body="# W")
    builtin = tmp_path / "builtin"
    bi_path = _write_skill(builtin, "bi_only", body="# B")

    loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
    entries = sorted(loader.list_skills(filter_unavailable=False), key=lambda item: item["name"])
    assert entries == [
        {"name": "bi_only", "path": str(bi_path), "source": "builtin"},
        {"name": "ws_only", "path": str(ws_path), "source": "workspace"},
    ]


def test_list_skills_builtin_omitted_when_dir_missing(tmp_path: Path) -> None:
    workspace = tmp_path / "ws"
    ws_skills = workspace / "skills"
    ws_skills.mkdir(parents=True)
    ws_path = _write_skill(ws_skills, "solo", body="# S")
    missing_builtin = tmp_path / "no_such_builtin"

    loader = SkillsLoader(workspace, builtin_skills_dir=missing_builtin)
    entries = loader.list_skills(filter_unavailable=False)
    assert entries == [{"name": "solo", "path": str(ws_path), "source": "workspace"}]


def test_list_skills_filter_unavailable_excludes_unmet_bin_requirement(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    workspace = tmp_path / "ws"
    skills_root = workspace / "skills"
    skills_root.mkdir(parents=True)
    _write_skill(
        skills_root,
        "needs_bin",
        metadata_json={"requires": {"bins": ["nanobot_test_fake_binary"]}},
    )
    builtin = tmp_path / "builtin"
    builtin.mkdir()

    def fake_which(cmd: str) -> str | None:
        if cmd == "nanobot_test_fake_binary":
            return None
        return "/usr/bin/true"

    monkeypatch.setattr("nanobot.agent.skills.shutil.which", fake_which)

    loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
    assert loader.list_skills(filter_unavailable=True) == []


def test_list_skills_filter_unavailable_includes_when_bin_requirement_met(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    workspace = tmp_path / "ws"
    skills_root = workspace / "skills"
    skills_root.mkdir(parents=True)
    skill_path = _write_skill(
        skills_root,
        "has_bin",
        metadata_json={"requires": {"bins": ["nanobot_test_fake_binary"]}},
    )
    builtin = tmp_path / "builtin"
    builtin.mkdir()

    def fake_which(cmd: str) -> str | None:
        if cmd == "nanobot_test_fake_binary":
            return "/fake/nanobot_test_fake_binary"
        return None

    monkeypatch.setattr("nanobot.agent.skills.shutil.which", fake_which)

    loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
    entries = loader.list_skills(filter_unavailable=True)
    assert entries == [
        {"name": "has_bin", "path": str(skill_path), "source": "workspace"},
    ]


def test_list_skills_filter_unavailable_false_keeps_unmet_requirements(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    workspace = tmp_path / "ws"
    skills_root = workspace / "skills"
    skills_root.mkdir(parents=True)
    skill_path = _write_skill(
        skills_root,
        "blocked",
        metadata_json={"requires": {"bins": ["nanobot_test_fake_binary"]}},
    )
    builtin = tmp_path / "builtin"
    builtin.mkdir()

    monkeypatch.setattr("nanobot.agent.skills.shutil.which", lambda _cmd: None)

    loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
    entries = loader.list_skills(filter_unavailable=False)
    assert entries == [
        {"name": "blocked", "path": str(skill_path), "source": "workspace"},
    ]


def test_list_skills_filter_unavailable_excludes_unmet_env_requirement(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    workspace = tmp_path / "ws"
    skills_root = workspace / "skills"
    skills_root.mkdir(parents=True)
    _write_skill(
        skills_root,
        "needs_env",
        metadata_json={"requires": {"env": ["NANOBOT_SKILLS_TEST_ENV_VAR"]}},
    )
    builtin = tmp_path / "builtin"
    builtin.mkdir()

    monkeypatch.delenv("NANOBOT_SKILLS_TEST_ENV_VAR", raising=False)

    loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
    assert loader.list_skills(filter_unavailable=True) == []


def test_list_skills_openclaw_metadata_parsed_for_requirements(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    workspace = tmp_path / "ws"
    skills_root = workspace / "skills"
    skills_root.mkdir(parents=True)
    skill_dir = skills_root / "openclaw_skill"
    skill_dir.mkdir(parents=True)
    skill_path = skill_dir / "SKILL.md"
    oc_payload = json.dumps({"openclaw": {"requires": {"bins": ["nanobot_oc_bin"]}}}, separators=(",", ":"))
    skill_path.write_text(
        "\n".join(["---", f"metadata: {oc_payload}", "---", "", "# OC"]),
        encoding="utf-8",
    )
    builtin = tmp_path / "builtin"
    builtin.mkdir()

    monkeypatch.setattr("nanobot.agent.skills.shutil.which", lambda _cmd: None)

    loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
    assert loader.list_skills(filter_unavailable=True) == []

    monkeypatch.setattr(
        "nanobot.agent.skills.shutil.which",
        lambda cmd: "/x" if cmd == "nanobot_oc_bin" else None,
    )
    entries = loader.list_skills(filter_unavailable=True)
    assert entries == [
        {"name": "openclaw_skill", "path": str(skill_path), "source": "workspace"},
    ]
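The tests above encode skill requirements as JSON under a `metadata:` frontmatter key and monkeypatch `shutil.which` and the environment. A minimal sketch of the availability check they exercise, assuming (as the monkeypatches suggest) that binaries are resolved via PATH lookup and env vars via presence in the environment:

```python
import os
import shutil


def requirements_met(requires: dict) -> bool:
    """Every required binary must resolve on PATH and every required
    environment variable must be set for a skill to count as available."""
    bins_ok = all(shutil.which(b) is not None for b in requires.get("bins", []))
    env_ok = all(v in os.environ for v in requires.get("env", []))
    return bins_ok and env_ok


print(requirements_met({}))  # no requirements → True
print(requirements_met({"bins": ["nanobot_test_fake_binary"]}))
```

This is an illustrative re-statement of the behavior the tests assert, not the actual `SkillsLoader` implementation.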
202
tests/agent/test_tool_hint.py
Normal file
@ -0,0 +1,202 @@
"""Tests for tool hint formatting (nanobot.utils.tool_hints)."""

from nanobot.utils.tool_hints import format_tool_hints
from nanobot.providers.base import ToolCallRequest


def _tc(name: str, args) -> ToolCallRequest:
    return ToolCallRequest(id="c1", name=name, arguments=args)


def _hint(calls):
    """Shortcut for format_tool_hints."""
    return format_tool_hints(calls)


class TestToolHintKnownTools:
    """Test registered tool types produce correct formatted output."""

    def test_read_file_short_path(self):
        result = _hint([_tc("read_file", {"path": "foo.txt"})])
        assert result == 'read foo.txt'

    def test_read_file_long_path(self):
        result = _hint([_tc("read_file", {"path": "/home/user/.local/share/uv/tools/nanobot/agent/loop.py"})])
        assert "loop.py" in result
        assert "read " in result

    def test_write_file_shows_path_not_content(self):
        result = _hint([_tc("write_file", {"path": "docs/api.md", "content": "# API Reference\n\nLong content..."})])
        assert result == "write docs/api.md"

    def test_edit_shows_path(self):
        result = _hint([_tc("edit", {"file_path": "src/main.py", "old_string": "x", "new_string": "y"})])
        assert "main.py" in result
        assert "edit " in result

    def test_glob_shows_pattern(self):
        result = _hint([_tc("glob", {"pattern": "**/*.py", "path": "src"})])
        assert result == 'glob "**/*.py"'

    def test_grep_shows_pattern(self):
        result = _hint([_tc("grep", {"pattern": "TODO|FIXME", "path": "src"})])
        assert result == 'grep "TODO|FIXME"'

    def test_exec_shows_command(self):
        result = _hint([_tc("exec", {"command": "npm install typescript"})])
        assert result == "$ npm install typescript"

    def test_exec_truncates_long_command(self):
        cmd = "cd /very/long/path && cat file && echo done && sleep 1 && ls -la"
        result = _hint([_tc("exec", {"command": cmd})])
        assert result.startswith("$ ")
        assert len(result) <= 50  # reasonable limit

    def test_web_search(self):
        result = _hint([_tc("web_search", {"query": "Claude 4 vs GPT-4"})])
        assert result == 'search "Claude 4 vs GPT-4"'

    def test_web_fetch(self):
        result = _hint([_tc("web_fetch", {"url": "https://example.com/page"})])
        assert result == "fetch https://example.com/page"


class TestToolHintMCP:
    """Test MCP tools are abbreviated to server::tool format."""

    def test_mcp_standard_format(self):
        result = _hint([_tc("mcp_4_5v_mcp__analyze_image", {"imageSource": "https://img.jpg", "prompt": "describe"})])
        assert "4_5v" in result
        assert "analyze_image" in result

    def test_mcp_simple_name(self):
        result = _hint([_tc("mcp_github__create_issue", {"title": "Bug fix"})])
        assert "github" in result
        assert "create_issue" in result


class TestToolHintFallback:
    """Test unknown tools fall back to original behavior."""

    def test_unknown_tool_with_string_arg(self):
        result = _hint([_tc("custom_tool", {"data": "hello world"})])
        assert result == 'custom_tool("hello world")'

    def test_unknown_tool_with_long_arg_truncates(self):
        long_val = "a" * 60
        result = _hint([_tc("custom_tool", {"data": long_val})])
        assert len(result) < 80
        assert "\u2026" in result

    def test_unknown_tool_no_string_arg(self):
        result = _hint([_tc("custom_tool", {"count": 42})])
        assert result == "custom_tool"

    def test_empty_tool_calls(self):
        result = _hint([])
        assert result == ""


class TestToolHintFolding:
    """Test consecutive same-tool calls are folded."""

    def test_single_call_no_fold(self):
        calls = [_tc("grep", {"pattern": "*.py"})]
        result = _hint(calls)
        assert "\u00d7" not in result

    def test_two_consecutive_same_folded(self):
        calls = [
            _tc("grep", {"pattern": "*.py"}),
            _tc("grep", {"pattern": "*.ts"}),
        ]
        result = _hint(calls)
        assert "\u00d7 2" in result

    def test_three_consecutive_same_folded(self):
        calls = [
            _tc("read_file", {"path": "a.py"}),
            _tc("read_file", {"path": "b.py"}),
            _tc("read_file", {"path": "c.py"}),
        ]
        result = _hint(calls)
        assert "\u00d7 3" in result

    def test_different_tools_not_folded(self):
        calls = [
            _tc("grep", {"pattern": "TODO"}),
            _tc("read_file", {"path": "a.py"}),
        ]
        result = _hint(calls)
        assert "\u00d7" not in result

    def test_interleaved_same_tools_not_folded(self):
        calls = [
            _tc("grep", {"pattern": "a"}),
            _tc("read_file", {"path": "f.py"}),
            _tc("grep", {"pattern": "b"}),
        ]
        result = _hint(calls)
        assert "\u00d7" not in result


class TestToolHintMultipleCalls:
    """Test multiple different tool calls are comma-separated."""

    def test_two_different_tools(self):
        calls = [
            _tc("grep", {"pattern": "TODO"}),
            _tc("read_file", {"path": "main.py"}),
        ]
        result = _hint(calls)
        assert 'grep "TODO"' in result
        assert "read main.py" in result
        assert ", " in result


class TestToolHintEdgeCases:
    """Test edge cases and defensive handling (G1, G2)."""

    def test_known_tool_empty_list_args(self):
        """C1/G1: Empty list arguments should not crash."""
        result = _hint([_tc("read_file", [])])
        assert result == "read_file"

    def test_known_tool_none_args(self):
        """G2: None arguments should not crash."""
        result = _hint([_tc("read_file", None)])
        assert result == "read_file"

    def test_fallback_empty_list_args(self):
        """C1: Empty list args in fallback should not crash."""
        result = _hint([_tc("custom_tool", [])])
        assert result == "custom_tool"

    def test_fallback_none_args(self):
        """G2: None args in fallback should not crash."""
        result = _hint([_tc("custom_tool", None)])
        assert result == "custom_tool"

    def test_list_dir_registered(self):
        """S2: list_dir should use 'ls' format."""
        result = _hint([_tc("list_dir", {"path": "/tmp"})])
        assert result == "ls /tmp"


class TestToolHintMixedFolding:
    """G4: Mixed folding groups with interleaved same-tool segments."""

    def test_read_read_grep_grep_read(self):
        """read×2, grep×2, read should fold into three separate groups."""
        calls = [
            _tc("read_file", {"path": "a.py"}),
            _tc("read_file", {"path": "b.py"}),
            _tc("grep", {"pattern": "x"}),
            _tc("grep", {"pattern": "y"}),
            _tc("read_file", {"path": "c.py"}),
        ]
        result = _hint(calls)
        assert "\u00d7 2" in result
        # Should have 3 groups: read×2, grep×2, read
        parts = result.split(", ")
        assert len(parts) == 3
@ -1,5 +1,6 @@
from email.message import EmailMessage
from datetime import date
from pathlib import Path
import imaplib

import pytest
@ -650,3 +651,224 @@ def test_check_authentication_results_method() -> None:
    spf, dkim = EmailChannel._check_authentication_results(parsed)
    assert spf is False
    assert dkim is True


# ---------------------------------------------------------------------------
# Attachment extraction tests
# ---------------------------------------------------------------------------


def _make_raw_email_with_attachment(
    from_addr: str = "alice@example.com",
    subject: str = "With attachment",
    body: str = "See attached.",
    attachment_name: str = "doc.pdf",
    attachment_content: bytes = b"%PDF-1.4 fake pdf content",
    attachment_mime: str = "application/pdf",
    auth_results: str | None = None,
) -> bytes:
    msg = EmailMessage()
    msg["From"] = from_addr
    msg["To"] = "bot@example.com"
    msg["Subject"] = subject
    msg["Message-ID"] = "<m1@example.com>"
    if auth_results:
        msg["Authentication-Results"] = auth_results
    msg.set_content(body)
    maintype, subtype = attachment_mime.split("/", 1)
    msg.add_attachment(
        attachment_content,
        maintype=maintype,
        subtype=subtype,
        filename=attachment_name,
    )
    return msg.as_bytes()


def test_extract_attachments_saves_pdf(tmp_path, monkeypatch) -> None:
    """PDF attachment is saved to media dir and path returned in media list."""
    monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path)

    raw = _make_raw_email_with_attachment()
    fake = _make_fake_imap(raw)
    monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)

    cfg = _make_config(allowed_attachment_types=["application/pdf"], verify_dkim=False, verify_spf=False)
    channel = EmailChannel(cfg, MessageBus())
    items = channel._fetch_new_messages()

    assert len(items) == 1
    assert len(items[0]["media"]) == 1
    saved_path = Path(items[0]["media"][0])
    assert saved_path.exists()
    assert saved_path.read_bytes() == b"%PDF-1.4 fake pdf content"
    assert "500_doc.pdf" in saved_path.name
    assert "[attachment:" in items[0]["content"]


def test_extract_attachments_disabled_by_default(monkeypatch) -> None:
    """With no allowed_attachment_types (default), no attachments are extracted."""
    raw = _make_raw_email_with_attachment()
    fake = _make_fake_imap(raw)
    monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)

    cfg = _make_config(verify_dkim=False, verify_spf=False)
    assert cfg.allowed_attachment_types == []
    channel = EmailChannel(cfg, MessageBus())
    items = channel._fetch_new_messages()

    assert len(items) == 1
    assert items[0]["media"] == []
    assert "[attachment:" not in items[0]["content"]


def test_extract_attachments_mime_type_filter(tmp_path, monkeypatch) -> None:
    """Non-allowed MIME types are skipped."""
    monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path)

    raw = _make_raw_email_with_attachment(
        attachment_name="image.png",
        attachment_content=b"\x89PNG fake",
        attachment_mime="image/png",
    )
    fake = _make_fake_imap(raw)
    monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)

    cfg = _make_config(
        allowed_attachment_types=["application/pdf"],
        verify_dkim=False,
        verify_spf=False,
    )
    channel = EmailChannel(cfg, MessageBus())
    items = channel._fetch_new_messages()

    assert len(items) == 1
    assert items[0]["media"] == []


def test_extract_attachments_empty_allowed_types_rejects_all(tmp_path, monkeypatch) -> None:
    """Empty allowed_attachment_types means no types are accepted."""
    monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path)

    raw = _make_raw_email_with_attachment(
        attachment_name="image.png",
        attachment_content=b"\x89PNG fake",
        attachment_mime="image/png",
    )
    fake = _make_fake_imap(raw)
    monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)

    cfg = _make_config(
        allowed_attachment_types=[],
        verify_dkim=False,
        verify_spf=False,
    )
    channel = EmailChannel(cfg, MessageBus())
    items = channel._fetch_new_messages()

    assert len(items) == 1
    assert items[0]["media"] == []


def test_extract_attachments_wildcard_pattern(tmp_path, monkeypatch) -> None:
    """Glob patterns like 'image/*' match attachment MIME types."""
    monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path)

    raw = _make_raw_email_with_attachment(
        attachment_name="photo.jpg",
        attachment_content=b"\xff\xd8\xff fake jpeg",
        attachment_mime="image/jpeg",
    )
    fake = _make_fake_imap(raw)
    monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)

    cfg = _make_config(
        allowed_attachment_types=["image/*"],
        verify_dkim=False,
        verify_spf=False,
    )
    channel = EmailChannel(cfg, MessageBus())
    items = channel._fetch_new_messages()

    assert len(items) == 1
    assert len(items[0]["media"]) == 1
|
||||
"""Attachments exceeding max_attachment_size are skipped."""
|
||||
monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path)
|
||||
|
||||
raw = _make_raw_email_with_attachment(
|
||||
attachment_content=b"x" * 1000,
|
||||
)
|
||||
fake = _make_fake_imap(raw)
|
||||
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
|
||||
|
||||
cfg = _make_config(
|
||||
allowed_attachment_types=["*"],
|
||||
max_attachment_size=500,
|
||||
verify_dkim=False,
|
||||
verify_spf=False,
|
||||
)
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
items = channel._fetch_new_messages()
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0]["media"] == []
|
||||
|
||||
|
||||
def test_extract_attachments_max_count(tmp_path, monkeypatch) -> None:
|
||||
"""Only max_attachments_per_email are saved."""
|
||||
monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path)
|
||||
|
||||
# Build email with 3 attachments
|
||||
msg = EmailMessage()
|
||||
msg["From"] = "alice@example.com"
|
||||
msg["To"] = "bot@example.com"
|
||||
msg["Subject"] = "Many attachments"
|
||||
msg["Message-ID"] = "<m1@example.com>"
|
||||
msg.set_content("See attached.")
|
||||
for i in range(3):
|
||||
msg.add_attachment(
|
||||
f"content {i}".encode(),
|
||||
maintype="application",
|
||||
subtype="pdf",
|
||||
filename=f"doc{i}.pdf",
|
||||
)
|
||||
raw = msg.as_bytes()
|
||||
|
||||
fake = _make_fake_imap(raw)
|
||||
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
|
||||
|
||||
cfg = _make_config(
|
||||
allowed_attachment_types=["*"],
|
||||
max_attachments_per_email=2,
|
||||
verify_dkim=False,
|
||||
verify_spf=False,
|
||||
)
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
items = channel._fetch_new_messages()
|
||||
|
||||
assert len(items) == 1
|
||||
assert len(items[0]["media"]) == 2
|
||||
|
||||
|
||||
def test_extract_attachments_sanitizes_filename(tmp_path, monkeypatch) -> None:
|
||||
"""Path traversal in filenames is neutralized."""
|
||||
monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path)
|
||||
|
||||
raw = _make_raw_email_with_attachment(
|
||||
attachment_name="../../../etc/passwd",
|
||||
)
|
||||
fake = _make_fake_imap(raw)
|
||||
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
|
||||
|
||||
cfg = _make_config(allowed_attachment_types=["*"], verify_dkim=False, verify_spf=False)
|
||||
channel = EmailChannel(cfg, MessageBus())
|
||||
items = channel._fetch_new_messages()
|
||||
|
||||
assert len(items) == 1
|
||||
assert len(items[0]["media"]) == 1
|
||||
saved_path = Path(items[0]["media"][0])
|
||||
# File must be inside the media dir, not escaped via path traversal
|
||||
assert saved_path.parent == tmp_path

62
tests/channels/test_feishu_mention.py
Normal file
@@ -0,0 +1,62 @@
"""Tests for Feishu _is_bot_mentioned logic."""

from types import SimpleNamespace

import pytest

from nanobot.channels.feishu import FeishuChannel


def _make_channel(bot_open_id: str | None = None) -> FeishuChannel:
    config = SimpleNamespace(
        app_id="test_id",
        app_secret="test_secret",
        verification_token="",
        event_encrypt_key="",
        group_policy="mention",
    )
    ch = FeishuChannel.__new__(FeishuChannel)
    ch.config = config
    ch._bot_open_id = bot_open_id
    return ch


def _make_message(mentions=None, content="hello"):
    return SimpleNamespace(content=content, mentions=mentions)


def _make_mention(open_id: str, user_id: str | None = None):
    mid = SimpleNamespace(open_id=open_id, user_id=user_id)
    return SimpleNamespace(id=mid)


class TestIsBotMentioned:
    def test_exact_match_with_bot_open_id(self):
        ch = _make_channel(bot_open_id="ou_bot123")
        msg = _make_message(mentions=[_make_mention("ou_bot123")])
        assert ch._is_bot_mentioned(msg) is True

    def test_no_match_different_bot(self):
        ch = _make_channel(bot_open_id="ou_bot123")
        msg = _make_message(mentions=[_make_mention("ou_other_bot")])
        assert ch._is_bot_mentioned(msg) is False

    def test_at_all_always_matches(self):
        ch = _make_channel(bot_open_id="ou_bot123")
        msg = _make_message(content="@_all hello")
        assert ch._is_bot_mentioned(msg) is True

    def test_fallback_heuristic_when_no_bot_open_id(self):
        ch = _make_channel(bot_open_id=None)
        msg = _make_message(mentions=[_make_mention("ou_some_bot", user_id=None)])
        assert ch._is_bot_mentioned(msg) is True

    def test_fallback_ignores_user_mentions(self):
        ch = _make_channel(bot_open_id=None)
        msg = _make_message(mentions=[_make_mention("ou_user", user_id="u_12345")])
        assert ch._is_bot_mentioned(msg) is False

    def test_no_mentions_returns_false(self):
        ch = _make_channel(bot_open_id="ou_bot123")
        msg = _make_message(mentions=None)
        assert ch._is_bot_mentioned(msg) is False

59
tests/channels/test_feishu_mentions.py
Normal file
@@ -0,0 +1,59 @@
"""Tests for FeishuChannel._resolve_mentions."""

from types import SimpleNamespace

from nanobot.channels.feishu import FeishuChannel


def _mention(key: str, name: str, open_id: str = "", user_id: str = ""):
    """Build a mock MentionEvent-like object."""
    id_obj = SimpleNamespace(open_id=open_id, user_id=user_id) if (open_id or user_id) else None
    return SimpleNamespace(key=key, name=name, id=id_obj)


class TestResolveMentions:
    def test_single_mention_replaced(self):
        text = "hello @_user_1 how are you"
        mentions = [_mention("@_user_1", "Alice", open_id="ou_abc123")]
        result = FeishuChannel._resolve_mentions(text, mentions)
        assert "@Alice (ou_abc123)" in result
        assert "@_user_1" not in result

    def test_mention_with_both_ids(self):
        text = "@_user_1 said hi"
        mentions = [_mention("@_user_1", "Bob", open_id="ou_abc", user_id="uid_456")]
        result = FeishuChannel._resolve_mentions(text, mentions)
        assert "@Bob (ou_abc, user id: uid_456)" in result

    def test_mention_no_id_skipped(self):
        """When mention has no id object, the placeholder is left unchanged."""
        text = "@_user_1 said hi"
        mentions = [SimpleNamespace(key="@_user_1", name="Charlie", id=None)]
        result = FeishuChannel._resolve_mentions(text, mentions)
        assert result == "@_user_1 said hi"

    def test_multiple_mentions(self):
        text = "@_user_1 and @_user_2 are here"
        mentions = [
            _mention("@_user_1", "Alice", open_id="ou_a"),
            _mention("@_user_2", "Bob", open_id="ou_b"),
        ]
        result = FeishuChannel._resolve_mentions(text, mentions)
        assert "@Alice (ou_a)" in result
        assert "@Bob (ou_b)" in result
        assert "@_user_1" not in result
        assert "@_user_2" not in result

    def test_no_mentions_returns_text(self):
        assert FeishuChannel._resolve_mentions("hello world", None) == "hello world"
        assert FeishuChannel._resolve_mentions("hello world", []) == "hello world"

    def test_empty_text_returns_empty(self):
        mentions = [_mention("@_user_1", "Alice", open_id="ou_a")]
        assert FeishuChannel._resolve_mentions("", mentions) == ""

    def test_mention_key_not_in_text_skipped(self):
        text = "hello world"
        mentions = [_mention("@_user_99", "Ghost", open_id="ou_ghost")]
        result = FeishuChannel._resolve_mentions(text, mentions)
        assert result == "hello world"

@@ -127,6 +127,79 @@ async def test_tool_hint_multiple_tools_in_one_message(mock_feishu_channel):


@mark.asyncio
async def test_tool_hint_new_format_basic(mock_feishu_channel):
    """New format hints (read path, grep "pattern") should parse correctly."""
    msg = OutboundMessage(
        channel="feishu",
        chat_id="oc_123456",
        content='read src/main.py, grep "TODO"',
        metadata={"_tool_hint": True}
    )

    with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
        await mock_feishu_channel.send(msg)

    content = json.loads(mock_send.call_args[0][3])
    md = content["elements"][0]["content"]
    assert "read src/main.py" in md
    assert 'grep "TODO"' in md


@mark.asyncio
async def test_tool_hint_new_format_with_comma_in_quotes(mock_feishu_channel):
    """Commas inside quoted arguments must not cause incorrect line splits."""
    msg = OutboundMessage(
        channel="feishu",
        chat_id="oc_123456",
        content='grep "hello, world", $ echo test',
        metadata={"_tool_hint": True}
    )

    with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
        await mock_feishu_channel.send(msg)

    content = json.loads(mock_send.call_args[0][3])
    md = content["elements"][0]["content"]
    # The comma inside quotes should NOT cause a line break
    assert 'grep "hello, world"' in md
    assert "$ echo test" in md


@mark.asyncio
async def test_tool_hint_new_format_with_folding(mock_feishu_channel):
    """Folded calls (× N) should display on separate lines."""
    msg = OutboundMessage(
        channel="feishu",
        chat_id="oc_123456",
        content='read path × 3, grep "pattern"',
        metadata={"_tool_hint": True}
    )

    with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
        await mock_feishu_channel.send(msg)

    content = json.loads(mock_send.call_args[0][3])
    md = content["elements"][0]["content"]
    assert "\u00d7 3" in md
    assert 'grep "pattern"' in md


@mark.asyncio
async def test_tool_hint_new_format_mcp(mock_feishu_channel):
    """MCP tool format (server::tool) should parse correctly."""
    msg = OutboundMessage(
        channel="feishu",
        chat_id="oc_123456",
        content='4_5v::analyze_image("photo.jpg")',
        metadata={"_tool_hint": True}
    )

    with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
        await mock_feishu_channel.send(msg)

    content = json.loads(mock_send.call_args[0][3])
    md = content["elements"][0]["content"]
    assert "4_5v::analyze_image" in md
async def test_tool_hint_keeps_commas_inside_arguments(mock_feishu_channel):
    """Commas inside a single tool argument must not be split onto a new line."""
    msg = OutboundMessage(

@@ -385,6 +385,32 @@ async def test_send_delta_stream_end_treats_not_modified_as_success() -> None:
    assert "123" not in channel._stream_bufs


@pytest.mark.asyncio
async def test_send_delta_stream_end_splits_oversized_reply() -> None:
    """Final streamed reply exceeding Telegram limit is split into chunks."""
    from nanobot.channels.telegram import TELEGRAM_MAX_MESSAGE_LEN

    channel = TelegramChannel(
        TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
        MessageBus(),
    )
    channel._app = _FakeApp(lambda: None)
    channel._app.bot.edit_message_text = AsyncMock()
    channel._app.bot.send_message = AsyncMock(return_value=SimpleNamespace(message_id=99))

    oversized = "x" * (TELEGRAM_MAX_MESSAGE_LEN + 500)
    channel._stream_bufs["123"] = _StreamBuf(text=oversized, message_id=7, last_edit=0.0)

    await channel.send_delta("123", "", {"_stream_end": True})

    channel._app.bot.edit_message_text.assert_called_once()
    edit_text = channel._app.bot.edit_message_text.call_args.kwargs.get("text", "")
    assert len(edit_text) <= TELEGRAM_MAX_MESSAGE_LEN

    channel._app.bot.send_message.assert_called_once()
    assert "123" not in channel._stream_bufs


@pytest.mark.asyncio
async def test_send_delta_new_stream_id_replaces_stale_buffer() -> None:
    channel = TelegramChannel(

@@ -424,6 +450,23 @@ async def test_send_delta_incremental_edit_treats_not_modified_as_success() -> N
    assert channel._stream_bufs["123"].last_edit > 0.0


@pytest.mark.asyncio
async def test_send_delta_initial_send_keeps_message_in_thread() -> None:
    channel = TelegramChannel(
        TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
        MessageBus(),
    )
    channel._app = _FakeApp(lambda: None)

    await channel.send_delta(
        "123",
        "hello",
        {"_stream_delta": True, "_stream_id": "s:0", "message_thread_id": 42},
    )

    assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42


def test_derive_topic_session_key_uses_thread_id() -> None:
    message = SimpleNamespace(
        chat=SimpleNamespace(type="supergroup"),

@@ -434,6 +477,27 @@ def test_derive_topic_session_key_uses_thread_id() -> None:
    assert TelegramChannel._derive_topic_session_key(message) == "telegram:-100123:topic:42"


def test_derive_topic_session_key_private_dm_thread() -> None:
    """Private DM threads (Telegram Threaded Mode) must get their own session key."""
    message = SimpleNamespace(
        chat=SimpleNamespace(type="private"),
        chat_id=999,
        message_thread_id=7,
    )
    assert TelegramChannel._derive_topic_session_key(message) == "telegram:999:topic:7"


def test_derive_topic_session_key_none_without_thread() -> None:
    """No thread id → no topic session key, regardless of chat type."""
    for chat_type in ("private", "supergroup", "group"):
        message = SimpleNamespace(
            chat=SimpleNamespace(type=chat_type),
            chat_id=123,
            message_thread_id=None,
        )
        assert TelegramChannel._derive_topic_session_key(message) is None


def test_get_extension_falls_back_to_original_filename() -> None:
    channel = TelegramChannel(TelegramConfig(), MessageBus())

@@ -163,6 +163,107 @@ async def test_group_policy_mention_accepts_mentioned_group_message():
    assert kwargs["sender_id"] == "user"


@pytest.mark.asyncio
async def test_sender_id_prefers_phone_jid_over_lid():
    """sender_id should resolve to phone number when @s.whatsapp.net JID is present."""
    ch = WhatsAppChannel({"enabled": True}, MagicMock())
    ch._handle_message = AsyncMock()

    await ch._handle_bridge_message(
        json.dumps({
            "type": "message",
            "id": "lid1",
            "sender": "ABC123@lid.whatsapp.net",
            "pn": "5551234@s.whatsapp.net",
            "content": "hi",
            "timestamp": 1,
        })
    )

    kwargs = ch._handle_message.await_args.kwargs
    assert kwargs["sender_id"] == "5551234"


@pytest.mark.asyncio
async def test_lid_to_phone_cache_resolves_lid_only_messages():
    """When only LID is present, a cached LID→phone mapping should be used."""
    ch = WhatsAppChannel({"enabled": True}, MagicMock())
    ch._handle_message = AsyncMock()

    # First message: both phone and LID → builds cache
    await ch._handle_bridge_message(
        json.dumps({
            "type": "message",
            "id": "c1",
            "sender": "LID99@lid.whatsapp.net",
            "pn": "5559999@s.whatsapp.net",
            "content": "first",
            "timestamp": 1,
        })
    )
    # Second message: only LID, no phone
    await ch._handle_bridge_message(
        json.dumps({
            "type": "message",
            "id": "c2",
            "sender": "LID99@lid.whatsapp.net",
            "pn": "",
            "content": "second",
            "timestamp": 2,
        })
    )

    second_kwargs = ch._handle_message.await_args_list[1].kwargs
    assert second_kwargs["sender_id"] == "5559999"


@pytest.mark.asyncio
async def test_voice_message_transcription_uses_media_path():
    """Voice messages are transcribed when media path is available."""
    ch = WhatsAppChannel({"enabled": True}, MagicMock())
    ch.transcription_provider = "openai"
    ch.transcription_api_key = "sk-test"
    ch._handle_message = AsyncMock()
    ch.transcribe_audio = AsyncMock(return_value="Hello world")

    await ch._handle_bridge_message(
        json.dumps({
            "type": "message",
            "id": "v1",
            "sender": "12345@s.whatsapp.net",
            "pn": "",
            "content": "[Voice Message]",
            "timestamp": 1,
            "media": ["/tmp/voice.ogg"],
        })
    )

    ch.transcribe_audio.assert_awaited_once_with("/tmp/voice.ogg")
    kwargs = ch._handle_message.await_args.kwargs
    assert kwargs["content"].startswith("Hello world")


@pytest.mark.asyncio
async def test_voice_message_no_media_shows_not_available():
    """Voice messages without media produce a fallback placeholder."""
    ch = WhatsAppChannel({"enabled": True}, MagicMock())
    ch._handle_message = AsyncMock()

    await ch._handle_bridge_message(
        json.dumps({
            "type": "message",
            "id": "v2",
            "sender": "12345@s.whatsapp.net",
            "pn": "",
            "content": "[Voice Message]",
            "timestamp": 1,
        })
    )

    kwargs = ch._handle_message.await_args.kwargs
    assert kwargs["content"] == "[Voice Message: Audio not available]"


def test_load_or_create_bridge_token_persists_generated_secret(tmp_path):
    token_path = tmp_path / "whatsapp-auth" / "bridge-token"

@@ -1,5 +1,7 @@
import asyncio
import json
import re
import shutil
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

@@ -9,6 +11,7 @@ from typer.testing import CliRunner
from nanobot.bus.events import OutboundMessage
from nanobot.cli.commands import _make_provider, app
from nanobot.config.schema import Config
from nanobot.cron.types import CronJob, CronPayload
from nanobot.providers.openai_codex_provider import _strip_model_prefix
from nanobot.providers.registry import find_by_name

@@ -19,11 +22,6 @@ class _StopGatewayError(RuntimeError):
    pass


import shutil

import pytest


@pytest.fixture
def mock_paths():
    """Mock config/workspace paths for test isolation."""
@@ -31,7 +29,6 @@ def mock_paths():
        patch("nanobot.config.loader.save_config") as mock_sc, \
        patch("nanobot.config.loader.load_config") as mock_lc, \
        patch("nanobot.cli.commands.get_workspace_path") as mock_ws:

        base_dir = Path("./test_onboard_data")
        if base_dir.exists():
            shutil.rmtree(base_dir)
@@ -425,13 +422,13 @@ def mock_agent_runtime(tmp_path):
    config.agents.defaults.workspace = str(tmp_path / "default-workspace")

    with patch("nanobot.config.loader.load_config", return_value=config) as mock_load_config, \
        patch("nanobot.config.loader.resolve_config_env_vars", side_effect=lambda c: c), \
        patch("nanobot.cli.commands.sync_workspace_templates") as mock_sync_templates, \
        patch("nanobot.cli.commands._make_provider", return_value=object()), \
        patch("nanobot.cli.commands._print_agent_response") as mock_print_response, \
        patch("nanobot.bus.queue.MessageBus"), \
        patch("nanobot.cron.service.CronService"), \
        patch("nanobot.agent.loop.AgentLoop") as mock_agent_loop_cls:

        agent_loop = MagicMock()
        agent_loop.channels_config = None
        agent_loop.process_direct = AsyncMock(
@@ -656,7 +653,9 @@ def test_agent_custom_config_workspace_does_not_migrate_legacy_cron(

    monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
    monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
    monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
    monkeypatch.setattr(
        "nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None
    )

    result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])

@@ -739,6 +738,7 @@ def _patch_cli_command_runtime(
        set_config_path or (lambda _path: None),
    )
    monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
    monkeypatch.setattr("nanobot.config.loader.resolve_config_env_vars", lambda c: c)
    monkeypatch.setattr(
        "nanobot.cli.commands.sync_workspace_templates",
        sync_templates or (lambda _path: None),
@@ -868,6 +868,115 @@ def test_gateway_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path:
    assert seen["cron_store"] == config.workspace_path / "cron" / "jobs.json"


def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
    monkeypatch, tmp_path: Path
) -> None:
    config_file = tmp_path / "instance" / "config.json"
    config_file.parent.mkdir(parents=True)
    config_file.write_text("{}")

    config = Config()
    config.agents.defaults.workspace = str(tmp_path / "config-workspace")
    provider = object()
    bus = MagicMock()
    bus.publish_outbound = AsyncMock()
    seen: dict[str, object] = {}

    monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
    monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
    monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
    monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: provider)
    monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus)
    monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())

    class _FakeCron:
        def __init__(self, _store_path: Path) -> None:
            self.on_job = None
            seen["cron"] = self

    class _FakeAgentLoop:
        def __init__(self, *args, **kwargs) -> None:
            self.model = "test-model"
            self.tools = {}

        async def process_direct(self, *_args, **_kwargs):
            return OutboundMessage(
                channel="telegram",
                chat_id="user-1",
                content="Time to stretch.",
            )

        async def close_mcp(self) -> None:
            return None

        async def run(self) -> None:
            return None

        def stop(self) -> None:
            return None

    class _StopAfterCronSetup:
        def __init__(self, *_args, **_kwargs) -> None:
            raise _StopGatewayError("stop")

    async def _capture_evaluate_response(
        response: str,
        task_context: str,
        provider_arg: object,
        model: str,
    ) -> bool:
        seen["response"] = response
        seen["task_context"] = task_context
        seen["provider"] = provider_arg
        seen["model"] = model
        return True

    monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
    monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
    monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _StopAfterCronSetup)
    monkeypatch.setattr(
        "nanobot.utils.evaluator.evaluate_response",
        _capture_evaluate_response,
    )

    result = runner.invoke(app, ["gateway", "--config", str(config_file)])

    assert isinstance(result.exception, _StopGatewayError)
    cron = seen["cron"]
    assert isinstance(cron, _FakeCron)
    assert cron.on_job is not None

    job = CronJob(
        id="cron-1",
        name="stretch",
        payload=CronPayload(
            message="Remind me to stretch.",
            deliver=True,
            channel="telegram",
            to="user-1",
        ),
    )

    response = asyncio.run(cron.on_job(job))

    assert response == "Time to stretch."
    assert seen["response"] == "Time to stretch."
    assert seen["provider"] is provider
    assert seen["model"] == "test-model"
    assert seen["task_context"] == (
        "[Scheduled Task] Timer finished.\n\n"
        "Task 'stretch' has been triggered.\n"
        "Scheduled instruction: Remind me to stretch."
    )
    bus.publish_outbound.assert_awaited_once_with(
        OutboundMessage(
            channel="telegram",
            chat_id="user-1",
            content="Time to stretch.",
        )
    )


def test_gateway_workspace_override_does_not_migrate_legacy_cron(
    monkeypatch, tmp_path: Path
) -> None:

44
tests/cli/test_safe_file_history.py
Normal file
@@ -0,0 +1,44 @@
"""Regression tests for SafeFileHistory (issue #2846).

Surrogate characters in CLI input must not crash history file writes.
"""

from nanobot.cli.commands import SafeFileHistory


class TestSafeFileHistory:
    def test_surrogate_replaced(self, tmp_path):
        """Surrogate code points are replaced with U+FFFD instead of crashing."""
        hist = SafeFileHistory(str(tmp_path / "history"))
        hist.store_string("hello \udce9 world")
        entries = list(hist.load_history_strings())
        assert len(entries) == 1
        assert "\udce9" not in entries[0]
        assert "hello" in entries[0]
        assert "world" in entries[0]

    def test_normal_text_unchanged(self, tmp_path):
        hist = SafeFileHistory(str(tmp_path / "history"))
        hist.store_string("normal ascii text")
        entries = list(hist.load_history_strings())
        assert entries[0] == "normal ascii text"

    def test_emoji_preserved(self, tmp_path):
        hist = SafeFileHistory(str(tmp_path / "history"))
        hist.store_string("hello 🐈 nanobot")
        entries = list(hist.load_history_strings())
        assert entries[0] == "hello 🐈 nanobot"

    def test_mixed_unicode_preserved(self, tmp_path):
        """CJK + emoji + latin should all pass through cleanly."""
        hist = SafeFileHistory(str(tmp_path / "history"))
        hist.store_string("你好 hello こんにちは 🎉")
        entries = list(hist.load_history_strings())
        assert entries[0] == "你好 hello こんにちは 🎉"

    def test_multiple_surrogates(self, tmp_path):
        hist = SafeFileHistory(str(tmp_path / "history"))
        hist.store_string("\udce9\udcf1\udcff")
        entries = list(hist.load_history_strings())
        assert len(entries) == 1
        assert "\udce9" not in entries[0]
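The core of the surrogate handling these tests describe can be sketched as a standalone scrubber (a hedged illustration; SafeFileHistory itself presumably applies something like this before delegating to the underlying history writer):

```python
def scrub_surrogates(text: str) -> str:
    """Replace lone surrogate code points (U+D800..U+DFFF) with U+FFFD.

    Lone surrogates, e.g. from undecodable terminal bytes, cannot be
    encoded as UTF-8, so writing them to the history file would raise
    UnicodeEncodeError; scrubbing first makes the write safe.
    """
    return "".join(
        "\ufffd" if 0xD800 <= ord(c) <= 0xDFFF else c
        for c in text
    )
```

Regular Unicode, including CJK text and emoji, sits outside the surrogate range and passes through unchanged.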

82
tests/config/test_env_interpolation.py
Normal file
@@ -0,0 +1,82 @@
import json

import pytest

from nanobot.config.loader import (
    _resolve_env_vars,
    load_config,
    resolve_config_env_vars,
    save_config,
)


class TestResolveEnvVars:
    def test_replaces_string_value(self, monkeypatch):
        monkeypatch.setenv("MY_SECRET", "hunter2")
        assert _resolve_env_vars("${MY_SECRET}") == "hunter2"

    def test_partial_replacement(self, monkeypatch):
        monkeypatch.setenv("HOST", "example.com")
        assert _resolve_env_vars("https://${HOST}/api") == "https://example.com/api"

    def test_multiple_vars_in_one_string(self, monkeypatch):
        monkeypatch.setenv("USER", "alice")
        monkeypatch.setenv("PASS", "secret")
        assert _resolve_env_vars("${USER}:${PASS}") == "alice:secret"

    def test_nested_dicts(self, monkeypatch):
        monkeypatch.setenv("TOKEN", "abc123")
        data = {"channels": {"telegram": {"token": "${TOKEN}"}}}
        result = _resolve_env_vars(data)
        assert result["channels"]["telegram"]["token"] == "abc123"

    def test_lists(self, monkeypatch):
        monkeypatch.setenv("VAL", "x")
        assert _resolve_env_vars(["${VAL}", "plain"]) == ["x", "plain"]

    def test_ignores_non_strings(self):
        assert _resolve_env_vars(42) == 42
        assert _resolve_env_vars(True) is True
        assert _resolve_env_vars(None) is None
        assert _resolve_env_vars(3.14) == 3.14

    def test_plain_strings_unchanged(self):
        assert _resolve_env_vars("no vars here") == "no vars here"

    def test_missing_var_raises(self):
        with pytest.raises(ValueError, match="DOES_NOT_EXIST"):
            _resolve_env_vars("${DOES_NOT_EXIST}")


class TestResolveConfig:
    def test_resolves_env_vars_in_config(self, tmp_path, monkeypatch):
        monkeypatch.setenv("TEST_API_KEY", "resolved-key")
        config_path = tmp_path / "config.json"
        config_path.write_text(
            json.dumps(
                {"providers": {"groq": {"apiKey": "${TEST_API_KEY}"}}}
            ),
            encoding="utf-8",
        )

        raw = load_config(config_path)
        assert raw.providers.groq.api_key == "${TEST_API_KEY}"

        resolved = resolve_config_env_vars(raw)
        assert resolved.providers.groq.api_key == "resolved-key"

    def test_save_preserves_templates(self, tmp_path, monkeypatch):
        monkeypatch.setenv("MY_TOKEN", "real-token")
        config_path = tmp_path / "config.json"
        config_path.write_text(
            json.dumps(
                {"channels": {"telegram": {"token": "${MY_TOKEN}"}}}
            ),
            encoding="utf-8",
        )

        raw = load_config(config_path)
        save_config(raw, config_path)

        saved = json.loads(config_path.read_text(encoding="utf-8"))
        assert saved["channels"]["telegram"]["token"] == "${MY_TOKEN}"
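The recursive `${VAR}` interpolation these tests exercise can be sketched roughly as follows (names mirror the tests; this is an assumed shape, not necessarily the loader's exact code):

```python
import os
import re

_ENV_RE = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")


def resolve_env_vars(value):
    """Recursively expand ${VAR} placeholders in strings, dicts and lists."""
    if isinstance(value, str):
        def _sub(match: re.Match) -> str:
            name = match.group(1)
            if name not in os.environ:
                raise ValueError(f"Environment variable not set: {name}")
            return os.environ[name]
        return _ENV_RE.sub(_sub, value)
    if isinstance(value, dict):
        return {k: resolve_env_vars(v) for k, v in value.items()}
    if isinstance(value, list):
        return [resolve_env_vars(v) for v in value]
    # ints, bools, None, floats pass through untouched
    return value
```

Resolving at load time while saving the unresolved model is what lets `save_config` round-trip the `${MY_TOKEN}` template instead of leaking the real secret to disk.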

@@ -299,7 +299,7 @@ def test_add_cron_job_defaults_to_tool_timezone(tmp_path) -> None:
    tool = _make_tool_with_tz(tmp_path, "Asia/Shanghai")
    tool.set_context("telegram", "chat-1")

    result = tool._add_job("Morning standup", None, "0 8 * * *", None, None)
    result = tool._add_job(None, "Morning standup", None, "0 8 * * *", None, None)

    assert result.startswith("Created job")
    job = tool._cron.list_jobs()[0]
@@ -310,7 +310,7 @@ def test_add_at_job_uses_default_timezone_for_naive_datetime(tmp_path) -> None:
    tool = _make_tool_with_tz(tmp_path, "Asia/Shanghai")
    tool.set_context("telegram", "chat-1")

    result = tool._add_job("Morning reminder", None, None, None, "2026-03-25T08:00:00")
    result = tool._add_job(None, "Morning reminder", None, None, None, "2026-03-25T08:00:00")

    assert result.startswith("Created job")
    job = tool._cron.list_jobs()[0]
@@ -322,7 +322,7 @@ def test_add_job_delivers_by_default(tmp_path) -> None:
    tool = _make_tool(tmp_path)
    tool.set_context("telegram", "chat-1")

    result = tool._add_job("Morning standup", 60, None, None, None)
    result = tool._add_job(None, "Morning standup", 60, None, None, None)

    assert result.startswith("Created job")
    job = tool._cron.list_jobs()[0]
@@ -333,7 +333,7 @@ def test_add_job_can_disable_delivery(tmp_path) -> None:
    tool = _make_tool(tmp_path)
    tool.set_context("telegram", "chat-1")

    result = tool._add_job("Background refresh", 60, None, None, None, deliver=False)
    result = tool._add_job(None, "Background refresh", 60, None, None, None, deliver=False)

    assert result.startswith("Created job")
    job = tool._cron.list_jobs()[0]

@@ -307,3 +307,54 @@ async def test_openai_compat_stream_watchdog_returns_error_on_stall(monkeypatch)
    assert result.finish_reason == "error"
    assert result.content is not None
    assert "stream stalled" in result.content


# ---------------------------------------------------------------------------
# Provider-specific thinking parameters (extra_body)
# ---------------------------------------------------------------------------

def _build_kwargs_for(provider_name: str, model: str, reasoning_effort=None):
    spec = find_by_name(provider_name)
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        p = OpenAICompatProvider(api_key="k", default_model=model, spec=spec)
        return p._build_kwargs(
            messages=[{"role": "user", "content": "hi"}],
            tools=None, model=model, max_tokens=1024, temperature=0.7,
            reasoning_effort=reasoning_effort, tool_choice=None,
        )


def test_dashscope_thinking_enabled_with_reasoning_effort() -> None:
    kw = _build_kwargs_for("dashscope", "qwen3-plus", reasoning_effort="medium")
    assert kw["extra_body"] == {"enable_thinking": True}


def test_dashscope_thinking_disabled_for_minimal() -> None:
    kw = _build_kwargs_for("dashscope", "qwen3-plus", reasoning_effort="minimal")
    assert kw["extra_body"] == {"enable_thinking": False}


def test_dashscope_no_extra_body_when_reasoning_effort_none() -> None:
    kw = _build_kwargs_for("dashscope", "qwen-turbo", reasoning_effort=None)
    assert "extra_body" not in kw


def test_volcengine_thinking_enabled() -> None:
    kw = _build_kwargs_for("volcengine", "doubao-seed-2-0-pro", reasoning_effort="high")
    assert kw["extra_body"] == {"thinking": {"type": "enabled"}}


def test_byteplus_thinking_disabled_for_minimal() -> None:
    kw = _build_kwargs_for("byteplus", "doubao-seed-2-0-pro", reasoning_effort="minimal")
    assert kw["extra_body"] == {"thinking": {"type": "disabled"}}


def test_byteplus_no_extra_body_when_reasoning_effort_none() -> None:
    kw = _build_kwargs_for("byteplus", "doubao-seed-2-0-pro", reasoning_effort=None)
    assert "extra_body" not in kw


def test_openai_no_thinking_extra_body() -> None:
    """Non-thinking providers should never get extra_body for thinking."""
    kw = _build_kwargs_for("openai", "gpt-4o", reasoning_effort="medium")
    assert "extra_body" not in kw
|
||||
|
||||
81
tests/providers/test_provider_error_metadata.py
Normal file
@@ -0,0 +1,81 @@
+from types import SimpleNamespace
+
+from nanobot.providers.anthropic_provider import AnthropicProvider
+from nanobot.providers.openai_compat_provider import OpenAICompatProvider
+
+
+def _fake_response(
+    *,
+    status_code: int,
+    headers: dict[str, str] | None = None,
+    text: str = "",
+) -> SimpleNamespace:
+    return SimpleNamespace(
+        status_code=status_code,
+        headers=headers or {},
+        text=text,
+    )
+
+
+def test_openai_handle_error_extracts_structured_metadata() -> None:
+    class FakeStatusError(Exception):
+        pass
+
+    err = FakeStatusError("boom")
+    err.status_code = 409
+    err.response = _fake_response(
+        status_code=409,
+        headers={"retry-after-ms": "250", "x-should-retry": "false"},
+        text='{"error":{"type":"rate_limit_exceeded","code":"rate_limit_exceeded"}}',
+    )
+    err.body = {"error": {"type": "rate_limit_exceeded", "code": "rate_limit_exceeded"}}
+
+    response = OpenAICompatProvider._handle_error(err)
+
+    assert response.finish_reason == "error"
+    assert response.error_status_code == 409
+    assert response.error_type == "rate_limit_exceeded"
+    assert response.error_code == "rate_limit_exceeded"
+    assert response.error_retry_after_s == 0.25
+    assert response.error_should_retry is False
+
+
+def test_openai_handle_error_marks_timeout_kind() -> None:
+    class FakeTimeoutError(Exception):
+        pass
+
+    response = OpenAICompatProvider._handle_error(FakeTimeoutError("timeout"))
+
+    assert response.finish_reason == "error"
+    assert response.error_kind == "timeout"
+
+
+def test_anthropic_handle_error_extracts_structured_metadata() -> None:
+    class FakeStatusError(Exception):
+        pass
+
+    err = FakeStatusError("boom")
+    err.status_code = 408
+    err.response = _fake_response(
+        status_code=408,
+        headers={"retry-after": "1.5", "x-should-retry": "true"},
+    )
+    err.body = {"type": "error", "error": {"type": "rate_limit_error"}}
+
+    response = AnthropicProvider._handle_error(err)
+
+    assert response.finish_reason == "error"
+    assert response.error_status_code == 408
+    assert response.error_type == "rate_limit_error"
+    assert response.error_retry_after_s == 1.5
+    assert response.error_should_retry is True
+
+
+def test_anthropic_handle_error_marks_connection_kind() -> None:
+    class FakeConnectionError(Exception):
+        pass
+
+    response = AnthropicProvider._handle_error(FakeConnectionError("connection"))
+
+    assert response.finish_reason == "error"
+    assert response.error_kind == "connection"
@@ -254,6 +254,14 @@ def test_extract_retry_after_from_headers_supports_numeric_and_http_date() -> No
     ) == 0.1


+def test_extract_retry_after_from_headers_supports_retry_after_ms() -> None:
+    assert LLMProvider._extract_retry_after_from_headers({"retry-after-ms": "250"}) == 0.25
+    assert LLMProvider._extract_retry_after_from_headers({"Retry-After-Ms": "1000"}) == 1.0
+    assert LLMProvider._extract_retry_after_from_headers(
+        {"retry-after-ms": "500", "retry-after": "10"},
+    ) == 0.5
+
+
 @pytest.mark.asyncio
 async def test_chat_with_retry_prefers_structured_retry_after_when_present(monkeypatch) -> None:
     provider = ScriptedProvider([
@@ -273,6 +281,153 @@ async def test_chat_with_retry_prefers_structured_retry_after_when_present(monke
     assert delays == [9.0]


+@pytest.mark.asyncio
+async def test_chat_with_retry_retries_structured_status_code_without_keyword(monkeypatch) -> None:
+    provider = ScriptedProvider([
+        LLMResponse(
+            content="request failed",
+            finish_reason="error",
+            error_status_code=409,
+        ),
+        LLMResponse(content="ok"),
+    ])
+    delays: list[float] = []
+
+    async def _fake_sleep(delay: float) -> None:
+        delays.append(delay)
+
+    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
+
+    response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
+
+    assert response.content == "ok"
+    assert provider.calls == 2
+    assert delays == [1]
+
+
+@pytest.mark.asyncio
+async def test_chat_with_retry_stops_on_429_quota_exhausted(monkeypatch) -> None:
+    provider = ScriptedProvider([
+        LLMResponse(
+            content='{"error":{"type":"insufficient_quota","code":"insufficient_quota"}}',
+            finish_reason="error",
+            error_status_code=429,
+            error_type="insufficient_quota",
+            error_code="insufficient_quota",
+        ),
+        LLMResponse(content="ok"),
+    ])
+    delays: list[float] = []
+
+    async def _fake_sleep(delay: float) -> None:
+        delays.append(delay)
+
+    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
+
+    response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
+
+    assert response.finish_reason == "error"
+    assert provider.calls == 1
+    assert delays == []
+
+
+@pytest.mark.asyncio
+async def test_chat_with_retry_retries_429_transient_rate_limit(monkeypatch) -> None:
+    provider = ScriptedProvider([
+        LLMResponse(
+            content='{"error":{"type":"rate_limit_exceeded","code":"rate_limit_exceeded"}}',
+            finish_reason="error",
+            error_status_code=429,
+            error_type="rate_limit_exceeded",
+            error_code="rate_limit_exceeded",
+            error_retry_after_s=0.2,
+        ),
+        LLMResponse(content="ok"),
+    ])
+    delays: list[float] = []
+
+    async def _fake_sleep(delay: float) -> None:
+        delays.append(delay)
+
+    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
+
+    response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
+
+    assert response.content == "ok"
+    assert provider.calls == 2
+    assert delays == [0.2]
+
+
+@pytest.mark.asyncio
+async def test_chat_with_retry_retries_structured_timeout_kind(monkeypatch) -> None:
+    provider = ScriptedProvider([
+        LLMResponse(
+            content="request failed",
+            finish_reason="error",
+            error_kind="timeout",
+        ),
+        LLMResponse(content="ok"),
+    ])
+    delays: list[float] = []
+
+    async def _fake_sleep(delay: float) -> None:
+        delays.append(delay)
+
+    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
+
+    response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
+
+    assert response.content == "ok"
+    assert provider.calls == 2
+    assert delays == [1]
+
+
+@pytest.mark.asyncio
+async def test_chat_with_retry_structured_should_retry_false_disables_retry(monkeypatch) -> None:
+    provider = ScriptedProvider([
+        LLMResponse(
+            content="429 rate limit",
+            finish_reason="error",
+            error_should_retry=False,
+        ),
+    ])
+    delays: list[float] = []
+
+    async def _fake_sleep(delay: float) -> None:
+        delays.append(delay)
+
+    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
+
+    response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
+
+    assert response.finish_reason == "error"
+    assert provider.calls == 1
+    assert delays == []
+
+
+@pytest.mark.asyncio
+async def test_chat_with_retry_prefers_structured_retry_after(monkeypatch) -> None:
+    provider = ScriptedProvider([
+        LLMResponse(
+            content="429 rate limit, retry after 99s",
+            finish_reason="error",
+            error_retry_after_s=0.2,
+        ),
+        LLMResponse(content="ok"),
+    ])
+    delays: list[float] = []
+
+    async def _fake_sleep(delay: float) -> None:
+        delays.append(delay)
+
+    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
+
+    response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
+
+    assert response.content == "ok"
+    assert delays == [0.2]
+
+
 @pytest.mark.asyncio
 async def test_persistent_retry_aborts_after_ten_identical_transient_errors(monkeypatch) -> None:
     provider = ScriptedProvider([
@@ -295,4 +450,3 @@ async def test_persistent_retry_aborts_after_ten_identical_transient_errors(monk
     assert response.content == "429 rate limit"
     assert provider.calls == 10
     assert delays == [1, 2, 4, 4, 4, 4, 4, 4, 4]
-
0
tests/tools/__init__.py
Normal file
38
tests/tools/test_exec_env.py
Normal file
@@ -0,0 +1,38 @@
+"""Tests for exec tool environment isolation."""
+
+import pytest
+
+from nanobot.agent.tools.shell import ExecTool
+
+
+@pytest.mark.asyncio
+async def test_exec_does_not_leak_parent_env(monkeypatch):
+    """Env vars from the parent process must not be visible to commands."""
+    monkeypatch.setenv("NANOBOT_SECRET_TOKEN", "super-secret-value")
+    tool = ExecTool()
+    result = await tool.execute(command="printenv NANOBOT_SECRET_TOKEN")
+    assert "super-secret-value" not in result
+
+
+@pytest.mark.asyncio
+async def test_exec_has_working_path():
+    """Basic commands should be available via the login shell's PATH."""
+    tool = ExecTool()
+    result = await tool.execute(command="echo hello")
+    assert "hello" in result
+
+
+@pytest.mark.asyncio
+async def test_exec_path_append():
+    """The pathAppend config should be available in the command's PATH."""
+    tool = ExecTool(path_append="/opt/custom/bin")
+    result = await tool.execute(command="echo $PATH")
+    assert "/opt/custom/bin" in result
+
+
+@pytest.mark.asyncio
+async def test_exec_path_append_preserves_system_path():
+    """pathAppend must not clobber standard system paths."""
+    tool = ExecTool(path_append="/opt/custom/bin")
+    result = await tool.execute(command="ls /")
+    assert "Exit code: 0" in result
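The isolation the tests above check amounts to launching the child with a constructed environment rather than inheriting the parent's. A minimal sketch under assumed names (`run_isolated` and the `ALLOWED` allowlist are hypothetical, not `ExecTool`'s actual mechanism; POSIX only):

```python
import asyncio
import os

ALLOWED = {"PATH", "HOME", "LANG", "TERM"}  # hypothetical allowlist


async def run_isolated(command: str) -> str:
    """Run a shell command with a minimal allowlisted env, not the parent's."""
    env = {k: v for k, v in os.environ.items() if k in ALLOWED}
    proc = await asyncio.create_subprocess_shell(
        command,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,
        env=env,  # anything outside ALLOWED is simply absent in the child
    )
    out, _ = await proc.communicate()
    return out.decode()
```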
@@ -112,7 +112,7 @@ class TestMessageToolSuppressLogic:
         assert final_content == "Done"
         assert progress == [
             ("Visible", False),
-            ('read_file("foo.txt")', True),
+            ('read foo.txt', True),
         ]
121
tests/tools/test_sandbox.py
Normal file
@@ -0,0 +1,121 @@
+"""Tests for nanobot.agent.tools.sandbox."""
+
+import shlex
+
+import pytest
+
+from nanobot.agent.tools.sandbox import wrap_command
+
+
+def _parse(cmd: str) -> list[str]:
+    """Split a wrapped command back into tokens for assertion."""
+    return shlex.split(cmd)
+
+
+class TestBwrapBackend:
+    def test_basic_structure(self, tmp_path):
+        ws = str(tmp_path / "project")
+        result = wrap_command("bwrap", "echo hi", ws, ws)
+        tokens = _parse(result)
+
+        assert tokens[0] == "bwrap"
+        assert "--new-session" in tokens
+        assert "--die-with-parent" in tokens
+        assert "--ro-bind" in tokens
+        assert "--proc" in tokens
+        assert "--dev" in tokens
+        assert "--tmpfs" in tokens
+
+        sep = tokens.index("--")
+        assert tokens[sep + 1:] == ["sh", "-c", "echo hi"]
+
+    def test_workspace_bind_mounted_rw(self, tmp_path):
+        ws = str(tmp_path / "project")
+        result = wrap_command("bwrap", "ls", ws, ws)
+        tokens = _parse(result)
+
+        bind_idx = [i for i, t in enumerate(tokens) if t == "--bind"]
+        assert any(tokens[i + 1] == ws and tokens[i + 2] == ws for i in bind_idx)
+
+    def test_parent_dir_masked_with_tmpfs(self, tmp_path):
+        ws = tmp_path / "project"
+        result = wrap_command("bwrap", "ls", str(ws), str(ws))
+        tokens = _parse(result)
+
+        tmpfs_indices = [i for i, t in enumerate(tokens) if t == "--tmpfs"]
+        tmpfs_targets = {tokens[i + 1] for i in tmpfs_indices}
+        assert str(ws.parent) in tmpfs_targets
+
+    def test_cwd_inside_workspace(self, tmp_path):
+        ws = tmp_path / "project"
+        sub = ws / "src" / "lib"
+        result = wrap_command("bwrap", "pwd", str(ws), str(sub))
+        tokens = _parse(result)
+
+        chdir_idx = tokens.index("--chdir")
+        assert tokens[chdir_idx + 1] == str(sub)
+
+    def test_cwd_outside_workspace_falls_back(self, tmp_path):
+        ws = tmp_path / "project"
+        outside = tmp_path / "other"
+        result = wrap_command("bwrap", "pwd", str(ws), str(outside))
+        tokens = _parse(result)
+
+        chdir_idx = tokens.index("--chdir")
+        assert tokens[chdir_idx + 1] == str(ws.resolve())
+
+    def test_command_with_special_characters(self, tmp_path):
+        ws = str(tmp_path / "project")
+        cmd = "echo 'hello world' && cat \"file with spaces.txt\""
+        result = wrap_command("bwrap", cmd, ws, ws)
+        tokens = _parse(result)
+
+        sep = tokens.index("--")
+        assert tokens[sep + 1:] == ["sh", "-c", cmd]
+
+    def test_system_dirs_ro_bound(self, tmp_path):
+        ws = str(tmp_path / "project")
+        result = wrap_command("bwrap", "ls", ws, ws)
+        tokens = _parse(result)
+
+        ro_bind_indices = [i for i, t in enumerate(tokens) if t == "--ro-bind"]
+        ro_targets = {tokens[i + 1] for i in ro_bind_indices}
+        assert "/usr" in ro_targets
+
+    def test_optional_dirs_use_ro_bind_try(self, tmp_path):
+        ws = str(tmp_path / "project")
+        result = wrap_command("bwrap", "ls", ws, ws)
+        tokens = _parse(result)
+
+        try_indices = [i for i, t in enumerate(tokens) if t == "--ro-bind-try"]
+        try_targets = {tokens[i + 1] for i in try_indices}
+        assert "/bin" in try_targets
+        assert "/etc/ssl/certs" in try_targets
+
+    def test_media_dir_ro_bind(self, tmp_path, monkeypatch):
+        """Media directory should be read-only mounted inside the sandbox."""
+        fake_media = tmp_path / "media"
+        fake_media.mkdir()
+        monkeypatch.setattr(
+            "nanobot.agent.tools.sandbox.get_media_dir",
+            lambda: fake_media,
+        )
+        ws = str(tmp_path / "project")
+        result = wrap_command("bwrap", "ls", ws, ws)
+        tokens = _parse(result)
+
+        try_indices = [i for i, t in enumerate(tokens) if t == "--ro-bind-try"]
+        try_pairs = {(tokens[i + 1], tokens[i + 2]) for i in try_indices}
+        assert (str(fake_media), str(fake_media)) in try_pairs
+
+
+class TestUnknownBackend:
+    def test_raises_value_error(self, tmp_path):
+        ws = str(tmp_path / "project")
+        with pytest.raises(ValueError, match="Unknown sandbox backend"):
+            wrap_command("nonexistent", "ls", ws, ws)
+
+    def test_empty_string_raises(self, tmp_path):
+        ws = str(tmp_path / "project")
+        with pytest.raises(ValueError):
+            wrap_command("", "ls", ws, ws)
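The sandbox tests above imply a wrapper that read-only-binds system dirs, bind-mounts the workspace read-write, and clamps the cwd to the workspace. A minimal sketch of such a wrapper (hypothetical `wrap_bwrap`, much simpler than the repo's `wrap_command`; it only builds the command string, bwrap itself need not be installed to construct it):

```python
import shlex


def wrap_bwrap(command: str, workspace: str, cwd: str) -> str:
    """Sketch: RO system dirs, RW workspace, tmpfs /tmp, cwd clamped to workspace."""
    args = [
        "bwrap", "--new-session", "--die-with-parent",
        "--ro-bind", "/usr", "/usr",
        "--ro-bind-try", "/bin", "/bin",      # --ro-bind-try tolerates absence
        "--proc", "/proc", "--dev", "/dev",
        "--tmpfs", "/tmp",
        "--bind", workspace, workspace,       # read-write workspace
        "--chdir", cwd if cwd.startswith(workspace) else workspace,
        "--", "sh", "-c", command,
    ]
    return " ".join(shlex.quote(a) for a in args)
```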
@@ -1,3 +1,6 @@
+import shlex
+import subprocess
+import sys
 from typing import Any

 from nanobot.agent.tools import (
@@ -546,10 +549,15 @@ async def test_exec_head_tail_truncation() -> None:
     """Long output should preserve both head and tail."""
     tool = ExecTool()
-    # Generate output that exceeds _MAX_OUTPUT (10_000 chars)
-    # Use python to generate output to avoid command line length limits
-    result = await tool.execute(
-        command="python -c \"print('A' * 6000 + '\\n' + 'B' * 6000)\""
-    )
+    # Use current interpreter (PATH may not have `python`). ExecTool uses
+    # create_subprocess_shell: POSIX needs shlex.quote; Windows uses cmd.exe
+    # rules, so list2cmdline is appropriate there.
+    script = "print('A' * 6000 + '\\n' + 'B' * 6000)"
+    if sys.platform == "win32":
+        command = subprocess.list2cmdline([sys.executable, "-c", script])
+    else:
+        command = f"{shlex.quote(sys.executable)} -c {shlex.quote(script)}"
+    result = await tool.execute(command=command)
     assert "chars truncated" in result
     # Head portion should start with As
     assert result.startswith("A")
@@ -1,5 +1,7 @@
 """Tests for multi-provider web search."""

+import asyncio
+
 import httpx
 import pytest

@@ -160,3 +162,70 @@ async def test_searxng_invalid_url():
     tool = _tool(provider="searxng", base_url="not-a-url")
     result = await tool.execute(query="test")
     assert "Error" in result
+
+
+@pytest.mark.asyncio
+async def test_jina_422_falls_back_to_duckduckgo(monkeypatch):
+    class MockDDGS:
+        def __init__(self, **kw):
+            pass
+
+        def text(self, query, max_results=5):
+            return [{"title": "Fallback", "href": "https://ddg.example", "body": "DuckDuckGo fallback"}]
+
+    async def mock_get(self, url, **kw):
+        assert "s.jina.ai" in str(url)
+        raise httpx.HTTPStatusError(
+            "422 Unprocessable Entity",
+            request=httpx.Request("GET", str(url)),
+            response=httpx.Response(422, request=httpx.Request("GET", str(url))),
+        )
+
+    monkeypatch.setattr(httpx.AsyncClient, "get", mock_get)
+    monkeypatch.setattr("ddgs.DDGS", MockDDGS)
+
+    tool = _tool(provider="jina", api_key="jina-key")
+    result = await tool.execute(query="test")
+    assert "DuckDuckGo fallback" in result
+
+
+@pytest.mark.asyncio
+async def test_jina_search_uses_path_encoded_query(monkeypatch):
+    calls = {}
+
+    async def mock_get(self, url, **kw):
+        calls["url"] = str(url)
+        calls["params"] = kw.get("params")
+        return _response(json={
+            "data": [{"title": "Jina Result", "url": "https://jina.ai", "content": "AI search"}]
+        })
+
+    monkeypatch.setattr(httpx.AsyncClient, "get", mock_get)
+    tool = _tool(provider="jina", api_key="jina-key")
+    await tool.execute(query="hello world")
+    assert calls["url"].rstrip("/") == "https://s.jina.ai/hello%20world"
+    assert calls["params"] in (None, {})
+
+
+@pytest.mark.asyncio
+async def test_duckduckgo_timeout_returns_error(monkeypatch):
+    """asyncio.wait_for guard should fire when DDG search hangs."""
+    import threading
+    gate = threading.Event()
+
+    class HangingDDGS:
+        def __init__(self, **kw):
+            pass
+
+        def text(self, query, max_results=5):
+            gate.wait(timeout=10)
+            return []
+
+    monkeypatch.setattr("ddgs.DDGS", HangingDDGS)
+    tool = _tool(provider="duckduckgo")
+    tool.config.timeout = 0.2
+    result = await tool.execute(query="test")
+    gate.set()
+    assert "Error" in result
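The path-encoding test above asserts the query is percent-encoded into the URL path rather than sent as a query parameter. A one-function sketch of building that URL (hypothetical `jina_search_url` helper; the endpoint host comes from the test itself):

```python
from urllib.parse import quote


def jina_search_url(query: str) -> str:
    """Per the test: query goes percent-encoded into the path of s.jina.ai."""
    return "https://s.jina.ai/" + quote(query, safe="")
```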
105
tests/utils/test_abbreviate_path.py
Normal file
@@ -0,0 +1,105 @@
+"""Tests for abbreviate_path utility."""
+
+import os
+from nanobot.utils.path import abbreviate_path
+
+
+class TestAbbreviatePathShort:
+    def test_short_path_unchanged(self):
+        assert abbreviate_path("/home/user/file.py") == "/home/user/file.py"
+
+    def test_exact_max_len_unchanged(self):
+        path = "/a/b/c"  # 7 chars
+        assert abbreviate_path("/a/b/c", max_len=7) == "/a/b/c"
+
+    def test_basename_only(self):
+        assert abbreviate_path("file.py") == "file.py"
+
+    def test_empty_string(self):
+        assert abbreviate_path("") == ""
+
+
+class TestAbbreviatePathHome:
+    def test_home_replacement(self):
+        home = os.path.expanduser("~")
+        result = abbreviate_path(f"{home}/project/file.py")
+        assert result.startswith("~/")
+        assert result.endswith("file.py")
+
+    def test_home_preserves_short_path(self):
+        home = os.path.expanduser("~")
+        result = abbreviate_path(f"{home}/a.py")
+        assert result == "~/a.py"
+
+
+class TestAbbreviatePathLong:
+    def test_long_path_keeps_basename(self):
+        path = "/a/b/c/d/e/f/g/h/very_long_filename.py"
+        result = abbreviate_path(path, max_len=30)
+        assert result.endswith("very_long_filename.py")
+        assert "\u2026" in result
+
+    def test_long_path_keeps_parent_dir(self):
+        path = "/a/b/c/d/e/f/g/h/src/loop.py"
+        result = abbreviate_path(path, max_len=30)
+        assert "loop.py" in result
+        assert "src" in result
+
+    def test_very_long_path_just_basename(self):
+        path = "/a/b/c/d/e/f/g/h/i/j/k/l/m/n/o/p/q/r/s/t/u/v/w/x/y/z/file.py"
+        result = abbreviate_path(path, max_len=20)
+        assert result.endswith("file.py")
+        assert len(result) <= 20
+
+
+class TestAbbreviatePathWindows:
+    def test_windows_drive_path(self):
+        path = "D:\\Documents\\GitHub\\nanobot\\src\\utils\\helpers.py"
+        result = abbreviate_path(path, max_len=40)
+        assert result.endswith("helpers.py")
+        assert "nanobot" in result
+
+    def test_windows_home(self):
+        home = os.path.expanduser("~")
+        path = os.path.join(home, ".nanobot", "workspace", "log.txt")
+        result = abbreviate_path(path)
+        assert result.startswith("~/")
+        assert "log.txt" in result
+
+
+class TestAbbreviatePathURLs:
+    def test_url_keeps_domain_and_filename(self):
+        url = "https://example.com/api/v2/long/path/resource.json"
+        result = abbreviate_path(url, max_len=40)
+        assert "resource.json" in result
+        assert "example.com" in result
+
+    def test_short_url_unchanged(self):
+        url = "https://example.com/api"
+        assert abbreviate_path(url) == url
+
+    def test_url_no_path_just_domain(self):
+        """G3: URL with no path should return as-is if short enough."""
+        url = "https://example.com"
+        assert abbreviate_path(url) == url
+
+    def test_url_with_query_string(self):
+        """G3: URL with query params should abbreviate path part."""
+        url = "https://example.com/api/v2/endpoint?key=value&other=123"
+        result = abbreviate_path(url, max_len=40)
+        assert "example.com" in result
+        assert "\u2026" in result
+
+    def test_url_very_long_basename(self):
+        """G3: URL with very long basename should truncate basename."""
+        url = "https://example.com/path/very_long_resource_name_file.json"
+        result = abbreviate_path(url, max_len=35)
+        assert "example.com" in result
+        assert "\u2026" in result
+
+    def test_url_negative_budget_consistent_format(self):
+        """I3: Negative budget should still produce domain/…/basename format."""
+        url = "https://a.co/very/deep/path/with/lots/of/segments/and/a/long/basename.txt"
+        result = abbreviate_path(url, max_len=20)
+        assert "a.co" in result
+        assert "/\u2026/" in result
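The behaviors these tests pin down (short input unchanged, keep the basename, keep the parent when the budget allows, elide the middle with "…") can be sketched in a few lines. This is a hypothetical `abbreviate` toy, not the repo's `abbreviate_path`, and it skips the `~` substitution and URL handling:

```python
def abbreviate(path: str, max_len: int = 40) -> str:
    """Keep head and basename, eliding middle segments with an ellipsis."""
    if len(path) <= max_len:
        return path  # short paths pass through unchanged
    parts = path.replace("\\", "/").split("/")
    base = parts[-1]
    if len(parts) >= 3:
        with_parent = f"{parts[0]}/\u2026/{parts[-2]}/{base}"
        if len(with_parent) <= max_len:
            return with_parent  # head + parent dir + basename
    return f"\u2026/{base}"  # last resort: just the basename
```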
306
tests/utils/test_searchusage.py
Normal file
@@ -0,0 +1,306 @@
+"""Tests for web search provider usage fetching and /status integration."""
+
+from __future__ import annotations
+
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+
+from nanobot.utils.searchusage import (
+    SearchUsageInfo,
+    _parse_tavily_usage,
+    fetch_search_usage,
+)
+from nanobot.utils.helpers import build_status_content
+
+
+# ---------------------------------------------------------------------------
+# SearchUsageInfo.format() tests
+# ---------------------------------------------------------------------------
+
+class TestSearchUsageInfoFormat:
+    def test_unsupported_provider_shows_no_tracking(self):
+        info = SearchUsageInfo(provider="duckduckgo", supported=False)
+        text = info.format()
+        assert "duckduckgo" in text
+        assert "not available" in text
+
+    def test_supported_with_error(self):
+        info = SearchUsageInfo(provider="tavily", supported=True, error="HTTP 401")
+        text = info.format()
+        assert "tavily" in text
+        assert "HTTP 401" in text
+        assert "unavailable" in text
+
+    def test_full_tavily_usage(self):
+        info = SearchUsageInfo(
+            provider="tavily",
+            supported=True,
+            used=142,
+            limit=1000,
+            remaining=858,
+            reset_date="2026-05-01",
+            search_used=120,
+            extract_used=15,
+            crawl_used=7,
+        )
+        text = info.format()
+        assert "tavily" in text
+        assert "142 / 1000" in text
+        assert "858" in text
+        assert "2026-05-01" in text
+        assert "Search: 120" in text
+        assert "Extract: 15" in text
+        assert "Crawl: 7" in text
+
+    def test_usage_without_limit(self):
+        info = SearchUsageInfo(provider="tavily", supported=True, used=50)
+        text = info.format()
+        assert "50 requests" in text
+        assert "/" not in text.split("Usage:")[1].split("\n")[0]
+
+    def test_no_breakdown_when_none(self):
+        info = SearchUsageInfo(
+            provider="tavily", supported=True, used=10, limit=100, remaining=90
+        )
+        text = info.format()
+        assert "Breakdown" not in text
+
+    def test_brave_unsupported(self):
+        info = SearchUsageInfo(provider="brave", supported=False)
+        text = info.format()
+        assert "brave" in text
+        assert "not available" in text
+
+
+# ---------------------------------------------------------------------------
+# _parse_tavily_usage tests
+# ---------------------------------------------------------------------------
+
+class TestParseTavilyUsage:
+    def test_full_response(self):
+        data = {
+            "account": {
+                "current_plan": "Researcher",
+                "plan_usage": 142,
+                "plan_limit": 1000,
+                "search_usage": 120,
+                "extract_usage": 15,
+                "crawl_usage": 7,
+                "map_usage": 0,
+                "research_usage": 0,
+                "paygo_usage": 0,
+                "paygo_limit": None,
+            },
+        }
+        info = _parse_tavily_usage(data)
+        assert info.provider == "tavily"
+        assert info.supported is True
+        assert info.used == 142
+        assert info.limit == 1000
+        assert info.remaining == 858
+        assert info.search_used == 120
+        assert info.extract_used == 15
+        assert info.crawl_used == 7
+
+    def test_remaining_computed(self):
+        data = {"account": {"plan_usage": 300, "plan_limit": 1000}}
+        info = _parse_tavily_usage(data)
+        assert info.remaining == 700
+
+    def test_remaining_not_negative(self):
+        data = {"account": {"plan_usage": 1100, "plan_limit": 1000}}
+        info = _parse_tavily_usage(data)
+        assert info.remaining == 0
+
+    def test_empty_response(self):
+        info = _parse_tavily_usage({})
+        assert info.provider == "tavily"
+        assert info.supported is True
+        assert info.used is None
+        assert info.limit is None
+
+    def test_no_breakdown_fields(self):
+        data = {"account": {"plan_usage": 5, "plan_limit": 50}}
+        info = _parse_tavily_usage(data)
+        assert info.search_used is None
+        assert info.extract_used is None
+        assert info.crawl_used is None
+
+
+# ---------------------------------------------------------------------------
+# fetch_search_usage routing tests
+# ---------------------------------------------------------------------------
+
+class TestFetchSearchUsageRouting:
+    @pytest.mark.asyncio
+    async def test_duckduckgo_returns_unsupported(self):
+        info = await fetch_search_usage("duckduckgo")
+        assert info.provider == "duckduckgo"
+        assert info.supported is False
+
+    @pytest.mark.asyncio
+    async def test_searxng_returns_unsupported(self):
+        info = await fetch_search_usage("searxng")
+        assert info.supported is False
+
+    @pytest.mark.asyncio
+    async def test_jina_returns_unsupported(self):
+        info = await fetch_search_usage("jina")
+        assert info.supported is False
+
+    @pytest.mark.asyncio
+    async def test_brave_returns_unsupported(self):
+        info = await fetch_search_usage("brave")
+        assert info.provider == "brave"
+        assert info.supported is False
+
+    @pytest.mark.asyncio
+    async def test_unknown_provider_returns_unsupported(self):
+        info = await fetch_search_usage("some_unknown_provider")
+        assert info.supported is False
+
+    @pytest.mark.asyncio
+    async def test_tavily_no_api_key_returns_error(self):
+        with patch.dict("os.environ", {}, clear=True):
+            # Ensure TAVILY_API_KEY is not set
+            import os
+            os.environ.pop("TAVILY_API_KEY", None)
+            info = await fetch_search_usage("tavily", api_key=None)
+            assert info.provider == "tavily"
+            assert info.supported is True
+            assert info.error is not None
+            assert "not configured" in info.error
+
+    @pytest.mark.asyncio
+    async def test_tavily_success(self):
+        mock_response = MagicMock()
+        mock_response.status_code = 200
+        mock_response.json.return_value = {
+            "account": {
+                "current_plan": "Researcher",
+                "plan_usage": 142,
+                "plan_limit": 1000,
+                "search_usage": 120,
+                "extract_usage": 15,
+                "crawl_usage": 7,
+            },
+        }
+        mock_response.raise_for_status = MagicMock()
+
+        mock_client = AsyncMock()
+        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+        mock_client.__aexit__ = AsyncMock(return_value=False)
+        mock_client.get = AsyncMock(return_value=mock_response)
+
+        with patch("httpx.AsyncClient", return_value=mock_client):
+            info = await fetch_search_usage("tavily", api_key="test-key")
+
+        assert info.provider == "tavily"
+        assert info.supported is True
+        assert info.error is None
+        assert info.used == 142
+        assert info.limit == 1000
+        assert info.remaining == 858
+        assert info.search_used == 120
+
+    @pytest.mark.asyncio
+    async def test_tavily_http_error(self):
+        import httpx
+
+        mock_response = MagicMock()
+        mock_response.status_code = 401
+        mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
+            "401", request=MagicMock(), response=mock_response
+        )
+
+        mock_client = AsyncMock()
+        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+        mock_client.__aexit__ = AsyncMock(return_value=False)
+        mock_client.get = AsyncMock(return_value=mock_response)
+
+        with patch("httpx.AsyncClient", return_value=mock_client):
+            info = await fetch_search_usage("tavily", api_key="bad-key")
+
+        assert info.supported is True
+        assert info.error == "HTTP 401"
+
+    @pytest.mark.asyncio
+    async def test_tavily_network_error(self):
+        import httpx
+
+        mock_client = AsyncMock()
+        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
+        mock_client.__aexit__ = AsyncMock(return_value=False)
+        mock_client.get = AsyncMock(side_effect=httpx.ConnectError("timeout"))
+
+        with patch("httpx.AsyncClient", return_value=mock_client):
+            info = await fetch_search_usage("tavily", api_key="test-key")
+
+        assert info.supported is True
+        assert info.error is not None
+
+    @pytest.mark.asyncio
+    async def test_provider_name_case_insensitive(self):
|
||||
info = await fetch_search_usage("Tavily", api_key=None)
|
||||
assert info.provider == "tavily"
|
||||
assert info.supported is True
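
The three Tavily tests above hand-build the same `httpx.AsyncClient` mock: an `AsyncMock` wired up as an async context manager whose `.get()` is stubbed. A small factory could factor out that boilerplate; `make_mock_async_client` is a hypothetical helper name, not part of this test suite:

```python
from unittest.mock import AsyncMock


def make_mock_async_client(response=None, side_effect=None):
    """Build an AsyncMock usable as `async with httpx.AsyncClient() as c:`.

    `response` is returned from `c.get(...)`; if `side_effect` is given
    (e.g. httpx.ConnectError("timeout")), it is raised instead.
    """
    client = AsyncMock()
    client.__aenter__ = AsyncMock(return_value=client)
    client.__aexit__ = AsyncMock(return_value=False)
    client.get = AsyncMock(return_value=response, side_effect=side_effect)
    return client
```

With this, each test body would shrink to `mock_client = make_mock_async_client(response=mock_response)` (or `side_effect=...` for the network-error case) before the `patch("httpx.AsyncClient", ...)` block.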

# ---------------------------------------------------------------------------
# build_status_content integration tests
# ---------------------------------------------------------------------------


class TestBuildStatusContentWithSearchUsage:
    _BASE_KWARGS = dict(
        version="0.1.0",
        model="claude-opus-4-5",
        start_time=1_000_000.0,
        last_usage={"prompt_tokens": 1000, "completion_tokens": 200},
        context_window_tokens=65536,
        session_msg_count=5,
        context_tokens_estimate=3000,
    )

    def test_no_search_usage_unchanged(self):
        """Omitting search_usage_text keeps existing behaviour."""
        content = build_status_content(**self._BASE_KWARGS)
        assert "🔍" not in content
        assert "Web Search" not in content

    def test_search_usage_none_unchanged(self):
        content = build_status_content(**self._BASE_KWARGS, search_usage_text=None)
        assert "🔍" not in content

    def test_search_usage_appended(self):
        usage_text = "🔍 Web Search: tavily\n Usage: 142 / 1000 requests"
        content = build_status_content(**self._BASE_KWARGS, search_usage_text=usage_text)
        assert "🔍 Web Search: tavily" in content
        assert "142 / 1000" in content

    def test_existing_fields_still_present(self):
        usage_text = "🔍 Web Search: duckduckgo\n Usage tracking: not available"
        content = build_status_content(**self._BASE_KWARGS, search_usage_text=usage_text)
        # Original fields must still be present
        assert "nanobot v0.1.0" in content
        assert "claude-opus-4-5" in content
        assert "1000 in / 200 out" in content
        # New field appended
        assert "duckduckgo" in content

    def test_full_tavily_in_status(self):
        info = SearchUsageInfo(
            provider="tavily",
            supported=True,
            used=142,
            limit=1000,
            remaining=858,
            reset_date="2026-05-01",
            search_used=120,
            extract_used=15,
            crawl_used=7,
        )
        content = build_status_content(**self._BASE_KWARGS, search_usage_text=info.format())
        assert "142 / 1000" in content
        assert "858" in content
        assert "2026-05-01" in content
        assert "Search: 120" in content