Merge remote-tracking branch 'origin/main' into nightly

Commit: ba38d41ad1
Author: chengyongru
Date: 2026-04-07 20:49:53 +08:00
69 changed files with 4190 additions and 462 deletions

.gitattributes (vendored, new file, 2 lines)

@ -0,0 +1,2 @@
# Ensure shell scripts always use LF line endings (Docker/Linux compat)
*.sh text eol=lf


@ -30,5 +30,8 @@ jobs:
- name: Install all dependencies
run: uv sync --all-extras
- name: Lint with ruff
run: uv run ruff check nanobot --select F401,F841
- name: Run tests
run: uv run pytest tests/


@ -2,7 +2,7 @@ FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim
# Install Node.js 20 for the WhatsApp bridge
RUN apt-get update && \
apt-get install -y --no-install-recommends curl ca-certificates gnupg git openssh-client && \
apt-get install -y --no-install-recommends curl ca-certificates gnupg git bubblewrap openssh-client && \
mkdir -p /etc/apt/keyrings && \
curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg && \
echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" > /etc/apt/sources.list.d/nodesource.list && \
@ -32,11 +32,19 @@ RUN git config --global --add url."https://github.com/".insteadOf ssh://git@gith
npm install && npm run build
WORKDIR /app
# Create config directory
RUN mkdir -p /root/.nanobot
# Create non-root user and config directory
RUN useradd -m -u 1000 -s /bin/bash nanobot && \
mkdir -p /home/nanobot/.nanobot && \
chown -R nanobot:nanobot /home/nanobot /app
COPY entrypoint.sh /usr/local/bin/entrypoint.sh
RUN sed -i 's/\r$//' /usr/local/bin/entrypoint.sh && chmod +x /usr/local/bin/entrypoint.sh
USER nanobot
ENV HOME=/home/nanobot
# Gateway default port
EXPOSE 18790
ENTRYPOINT ["nanobot"]
ENTRYPOINT ["entrypoint.sh"]
CMD ["status"]

README.md (103 changed lines)

@ -1,39 +1,44 @@
<div align="center">
<img src="nanobot_logo.png" alt="nanobot" width="500">
<h1>nanobot: Ultra-Lightweight Personal AI Assistant</h1>
<h1>nanobot: Ultra-Lightweight Personal AI Agent</h1>
<p>
<a href="https://pypi.org/project/nanobot-ai/"><img src="https://img.shields.io/pypi/v/nanobot-ai" alt="PyPI"></a>
<a href="https://pepy.tech/project/nanobot-ai"><img src="https://static.pepy.tech/badge/nanobot-ai" alt="Downloads"></a>
<img src="https://img.shields.io/badge/python-≥3.11-blue" alt="Python">
<img src="https://img.shields.io/badge/license-MIT-green" alt="License">
<a href="https://nanobot.wiki/docs/0.1.5/getting-started/nanobot-overview"><img src="https://img.shields.io/badge/Docs-nanobot.wiki-blue?style=flat&logo=readthedocs&logoColor=white" alt="Docs"></a>
<a href="./COMMUNICATION.md"><img src="https://img.shields.io/badge/Feishu-Group-E9DBFC?style=flat&logo=feishu&logoColor=white" alt="Feishu"></a>
<a href="./COMMUNICATION.md"><img src="https://img.shields.io/badge/WeChat-Group-C5EAB4?style=flat&logo=wechat&logoColor=white" alt="WeChat"></a>
<a href="https://discord.gg/MnCvHqpUGB"><img src="https://img.shields.io/badge/Discord-Community-5865F2?style=flat&logo=discord&logoColor=white" alt="Discord"></a>
</p>
</div>
🐈 **nanobot** is an **ultra-lightweight** personal AI assistant inspired by [OpenClaw](https://github.com/openclaw/openclaw).
🐈 **nanobot** is an **ultra-lightweight** personal AI agent inspired by [OpenClaw](https://github.com/openclaw/openclaw).
⚡️ Delivers core agent functionality with **99% fewer lines of code** than OpenClaw.
⚡️ Delivers core agent functionality with **99% fewer lines of code**.
📏 Real-time line count: run `bash core_agent_lines.sh` to verify anytime.
## 📢 News
- **2026-04-02** 🧱 **Long-running tasks** run more reliably — core runtime hardening.
- **2026-04-05** 🚀 Released **v0.1.5** — sturdier long-running tasks, Dream two-stage memory, production-ready sandboxing, and a programming Agent SDK. See the [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.5) for details.
- **2026-04-04** 🚀 Jinja2 response templates, Dream memory hardened, smarter retry handling.
- **2026-04-03** 🧠 Xiaomi MiMo provider, chain-of-thought reasoning visible, Telegram UX polish.
- **2026-04-02** 🧱 Long-running tasks run more reliably — core runtime hardening.
- **2026-04-01** 🔑 GitHub Copilot auth restored; stricter workspace paths; OpenRouter Claude caching fix.
- **2026-03-31** 🛰️ WeChat multimodal alignment, Discord/Matrix polish, Python SDK facade, MCP and tool fixes.
- **2026-03-30** 🧩 OpenAI-compatible API tightened; composable agent lifecycle hooks.
- **2026-03-29** 💬 WeChat voice, typing, QR/media resilience; fixed-session OpenAI-compatible API.
- **2026-03-28** 📚 Provider docs refresh; skill template wording fix.
- **2026-03-27** 🚀 Released **v0.1.4.post6** — architecture decoupling, litellm removal, end-to-end streaming, WeChat channel, and a security fix. See the [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post6) for details.
- **2026-03-26** 🏗️ Agent runner extracted and lifecycle hooks unified; stream delta coalescing at boundaries.
- **2026-03-25** 🌏 StepFun provider, configurable timezone, Gemini thought signatures.
- **2026-03-24** 🔧 WeChat compatibility, Feishu CardKit streaming, test suite restructured.
<details>
<summary>Earlier news</summary>
- **2026-03-26** 🏗️ Agent runner extracted and lifecycle hooks unified; stream delta coalescing at boundaries.
- **2026-03-25** 🌏 StepFun provider, configurable timezone, Gemini thought signatures.
- **2026-03-24** 🔧 WeChat compatibility, Feishu CardKit streaming, test suite restructured.
- **2026-03-23** 🔧 Command routing refactored for plugins, WhatsApp/WeChat media, unified channel login CLI.
- **2026-03-22** ⚡ End-to-end streaming, WeChat channel, Anthropic cache optimization, `/status` command.
- **2026-03-21** 🔒 Replace `litellm` with native `openai` + `anthropic` SDKs. Please see [commit](https://github.com/HKUDS/nanobot/commit/3dfdab7).
@ -91,7 +96,7 @@
## Key Features of nanobot:
🪶 **Ultra-Lightweight**: A super lightweight implementation of OpenClaw — 99% smaller, significantly faster.
🪶 **Ultra-Lightweight**: A lightweight implementation built for stable, long-running AI agents.
🔬 **Research-Ready**: Clean, readable code that's easy to understand, modify, and extend for research.
@ -140,7 +145,7 @@
<tr>
<td align="center"><p align="center"><img src="case/search.gif" width="180" height="400"></p></td>
<td align="center"><p align="center"><img src="case/code.gif" width="180" height="400"></p></td>
<td align="center"><p align="center"><img src="case/scedule.gif" width="180" height="400"></p></td>
<td align="center"><p align="center"><img src="case/schedule.gif" width="180" height="400"></p></td>
<td align="center"><p align="center"><img src="case/memory.gif" width="180" height="400"></p></td>
</tr>
<tr>
@ -252,7 +257,7 @@ Configure these **two parts** in your config (other options have defaults).
nanobot agent
```
That's it! You have a working AI assistant in 2 minutes.
That's it! You have a working AI agent in 2 minutes.
## 💬 Chat Apps
@ -433,9 +438,11 @@ pip install nanobot-ai[matrix]
- You need:
- `userId` (example: `@nanobot:matrix.org`)
- `accessToken`
- `deviceId` (recommended so sync tokens can be restored across restarts)
- You can obtain these from your homeserver login API (`/_matrix/client/v3/login`) or from your client's advanced session settings.
- `password`
(Note: `accessToken` and `deviceId` are still supported for legacy reasons, but
password login is recommended for reliable encryption. If `password` is
provided, `accessToken` and `deviceId` are ignored.)
**3. Configure**
@ -446,8 +453,7 @@ pip install nanobot-ai[matrix]
"enabled": true,
"homeserver": "https://matrix.org",
"userId": "@nanobot:matrix.org",
"accessToken": "syt_xxx",
"deviceId": "NANOBOT01",
"password": "mypasswordhere",
"e2eeEnabled": true,
"allowFrom": ["@your_user:matrix.org"],
"groupPolicy": "open",
@ -459,7 +465,7 @@ pip install nanobot-ai[matrix]
}
```
> Keep a persistent `matrix-store` and stable `deviceId` — encrypted session state is lost if these change across restarts.
> Keep a persistent `matrix-store` — encrypted session state is lost if it changes across restarts.
| Option | Description |
|--------|-------------|
@ -720,6 +726,9 @@ Give nanobot its own email account. It polls **IMAP** for incoming mail and repl
> - `allowFrom`: Add your email address. Use `["*"]` to accept emails from anyone.
> - `smtpUseTls` and `smtpUseSsl` default to `true` / `false` respectively, which is correct for Gmail (port 587 + STARTTLS). No need to set them explicitly.
> - Set `"autoReplyEnabled": false` if you only want to read/analyze emails without sending automatic replies.
> - `allowedAttachmentTypes`: Save inbound attachments matching these MIME types — `["*"]` for all, e.g. `["application/pdf", "image/*"]` (default `[]` = disabled).
> - `maxAttachmentSize`: Max size per attachment in bytes (default `2000000` / 2MB).
> - `maxAttachmentsPerEmail`: Max attachments to save per email (default `5`).
```json
{
@ -736,7 +745,8 @@ Give nanobot its own email account. It polls **IMAP** for incoming mail and repl
"smtpUsername": "my-nanobot@gmail.com",
"smtpPassword": "your-app-password",
"fromAddress": "my-nanobot@gmail.com",
"allowFrom": ["your-real-email@gmail.com"]
"allowFrom": ["your-real-email@gmail.com"],
"allowedAttachmentTypes": ["application/pdf", "image/*"]
}
}
}
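The attachment rules above (MIME patterns such as `image/*`, a size cap, and a per-email limit) amount to a small filter. The sketch below is illustrative only — `attachment_allowed` is a hypothetical helper name, not nanobot's actual API:

```python
from fnmatch import fnmatch


def attachment_allowed(
    mime_type: str,
    allowed_types: list[str],
    size: int,
    max_size: int = 2_000_000,  # documented default: 2 MB
) -> bool:
    """Hypothetical sketch of the documented attachment policy."""
    if not allowed_types:  # default [] disables attachment saving entirely
        return False
    if size > max_size:
        return False
    # "*" matches everything; "image/*" matches any image subtype
    return any(fnmatch(mime_type, pattern) for pattern in allowed_types)
```

With this shape, `["*"]` accepts any type, `[]` rejects everything, and oversized files are dropped regardless of type.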
@ -861,10 +871,45 @@ Config file: `~/.nanobot/config.json`
> run `nanobot onboard`, then answer `N` when asked whether to overwrite the config.
> nanobot will merge in missing default fields and keep your current settings.
### Environment Variables for Secrets
Instead of storing secrets directly in `config.json`, you can use `${VAR_NAME}` references that are resolved from environment variables at startup:
```json
{
"channels": {
"telegram": { "token": "${TELEGRAM_TOKEN}" },
"email": {
"imapPassword": "${IMAP_PASSWORD}",
"smtpPassword": "${SMTP_PASSWORD}"
}
},
"providers": {
"groq": { "apiKey": "${GROQ_API_KEY}" }
}
}
```
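Conceptually, resolving `${VAR_NAME}` references is a recursive walk over the parsed config, substituting from the environment. This is a minimal sketch (`resolve_secrets` is a hypothetical name, not nanobot's actual function), which leaves unresolved references untouched:

```python
import os
import re

_VAR = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")


def resolve_secrets(value):
    """Recursively replace ${VAR} references with environment values."""
    if isinstance(value, dict):
        return {k: resolve_secrets(v) for k, v in value.items()}
    if isinstance(value, list):
        return [resolve_secrets(v) for v in value]
    if isinstance(value, str):
        # Keep the literal "${VAR}" text when the variable is unset,
        # so a missing secret is visible rather than silently blanked.
        return _VAR.sub(lambda m: os.environ.get(m.group(1), m.group(0)), value)
    return value
```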
For **systemd** deployments, use `EnvironmentFile=` in the service unit to load variables from a file that only the deploying user can read:
```ini
# /etc/systemd/system/nanobot.service (excerpt)
[Service]
EnvironmentFile=/home/youruser/nanobot_secrets.env
User=nanobot
ExecStart=...
```
```bash
# /home/youruser/nanobot_secrets.env (mode 600, owned by youruser)
TELEGRAM_TOKEN=your-token-here
IMAP_PASSWORD=your-password-here
```
### Providers
> [!TIP]
> - **Groq** provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed.
> - **Voice transcription**: Voice messages (Telegram, WhatsApp) are automatically transcribed using Whisper. By default Groq is used (free tier). Set `"transcriptionProvider": "openai"` under `channels` to use OpenAI Whisper instead — the API key is picked from the matching provider config.
> - **MiniMax Coding Plan**: Exclusive discount links for the nanobot community: [Overseas](https://platform.minimax.io/subscribe/coding-plan?code=9txpdXw04g&source=link) · [Mainland China](https://platform.minimaxi.com/subscribe/token-plan?code=GILTJpMTqZ&source=link)
> - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config.
> - **VolcEngine / BytePlus Coding Plan**: Use dedicated providers `volcengineCodingPlan` or `byteplusCodingPlan` instead of the pay-per-use `volcengine` / `byteplus` providers.
@ -880,9 +925,9 @@ Config file: `~/.nanobot/config.json`
| `byteplus` | LLM (VolcEngine international, pay-per-use) | [Coding Plan](https://www.byteplus.com/en/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) · [byteplus.com](https://www.byteplus.com) |
| `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
| `azure_openai` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) |
| `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
| `openai` | LLM + Voice transcription (Whisper) | [platform.openai.com](https://platform.openai.com) |
| `deepseek` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
| `groq` | LLM + Voice transcription (Whisper, default) | [console.groq.com](https://console.groq.com) |
| `minimax` | LLM (MiniMax direct) | [platform.minimaxi.com](https://platform.minimaxi.com) |
| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) |
| `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) |
@ -1197,6 +1242,7 @@ Global settings that apply to all channels. Configure under the `channels` secti
"sendProgress": true,
"sendToolHints": false,
"sendMaxRetries": 3,
"transcriptionProvider": "groq",
"telegram": { ... }
}
}
@ -1207,6 +1253,7 @@ Global settings that apply to all channels. Configure under the `channels` secti
| `sendProgress` | `true` | Stream agent's text progress to the channel |
| `sendToolHints` | `false` | Stream tool-call hints (e.g. `read_file("…")`) |
| `sendMaxRetries` | `3` | Max delivery attempts per outbound message, including the initial send (configurable 0-10; at least 1 attempt is always made) |
| `transcriptionProvider` | `"groq"` | Voice transcription backend: `"groq"` (free tier, default) or `"openai"`. API key is auto-resolved from the matching provider config. |
#### Retry Behavior
@ -1434,16 +1481,19 @@ MCP tools are automatically discovered and registered on startup. The LLM can us
### Security
> [!TIP]
> For production deployments, set `"restrictToWorkspace": true` in your config to sandbox the agent.
> For production deployments, set `"restrictToWorkspace": true` and `"tools.exec.sandbox": "bwrap"` in your config to sandbox the agent.
> In `v0.1.4.post3` and earlier, an empty `allowFrom` allowed all senders. Since `v0.1.4.post4`, empty `allowFrom` denies all access by default. To allow all senders, set `"allowFrom": ["*"]`.
| Option | Default | Description |
|--------|---------|-------------|
| `tools.restrictToWorkspace` | `false` | When `true`, restricts **all** agent tools (shell, file read/write/edit, list) to the workspace directory. Prevents path traversal and out-of-scope access. |
| `tools.exec.sandbox` | `""` | Sandbox backend for shell commands. Set to `"bwrap"` to wrap exec calls in a [bubblewrap](https://github.com/containers/bubblewrap) sandbox — the process can only see the workspace (read-write) and media directory (read-only); config files and API keys are hidden. Automatically enables `restrictToWorkspace` for file tools. **Linux only** — requires `bwrap` installed (`apt install bubblewrap`; pre-installed in the Docker image). Not available on macOS or Windows (bwrap depends on Linux kernel namespaces). |
| `tools.exec.enable` | `true` | When `false`, the shell `exec` tool is not registered at all. Use this to completely disable shell command execution. |
| `tools.exec.pathAppend` | `""` | Extra directories to append to `PATH` when running shell commands (e.g. `/usr/sbin` for `ufw`). |
| `channels.*.allowFrom` | `[]` (deny all) | Whitelist of user IDs. Empty denies all; use `["*"]` to allow everyone. |
**Docker security**: The official Docker image runs as a non-root user (`nanobot`, UID 1000) with bubblewrap pre-installed. When using `docker-compose.yml`, the container drops all Linux capabilities except `SYS_ADMIN` (required for bwrap's namespace isolation).
### Timezone
@ -1763,7 +1813,8 @@ print(resp.choices[0].message.content)
## 🐳 Docker
> [!TIP]
> The `-v ~/.nanobot:/root/.nanobot` flag mounts your local config directory into the container, so your config and workspace persist across container restarts.
> The `-v ~/.nanobot:/home/nanobot/.nanobot` flag mounts your local config directory into the container, so your config and workspace persist across container restarts.
> The container runs as user `nanobot` (UID 1000). If you get **Permission denied**, fix ownership on the host first: `sudo chown -R 1000:1000 ~/.nanobot`, or pass `--user $(id -u):$(id -g)` to match your host UID. Podman users can use `--userns=keep-id` instead.
### Docker Compose
@ -1786,17 +1837,17 @@ docker compose down # stop
docker build -t nanobot .
# Initialize config (first time only)
docker run -v ~/.nanobot:/root/.nanobot --rm nanobot onboard
docker run -v ~/.nanobot:/home/nanobot/.nanobot --rm nanobot onboard
# Edit config on host to add API keys
vim ~/.nanobot/config.json
# Run gateway (connects to enabled channels, e.g. Telegram/Discord/Mochat)
docker run -v ~/.nanobot:/root/.nanobot -p 18790:18790 nanobot gateway
docker run -v ~/.nanobot:/home/nanobot/.nanobot -p 18790:18790 nanobot gateway
# Or run a single command
docker run -v ~/.nanobot:/root/.nanobot --rm nanobot agent -m "Hello!"
docker run -v ~/.nanobot:/root/.nanobot --rm nanobot status
docker run -v ~/.nanobot:/home/nanobot/.nanobot --rm nanobot agent -m "Hello!"
docker run -v ~/.nanobot:/home/nanobot/.nanobot --rm nanobot status
```
## 🐧 Linux Service


@ -64,6 +64,7 @@ chmod 600 ~/.nanobot/config.json
The `exec` tool can execute shell commands. While dangerous command patterns are blocked, you should:
- ✅ **Enable the bwrap sandbox** (`"tools.exec.sandbox": "bwrap"`) for kernel-level isolation (Linux only)
- ✅ Review all tool usage in agent logs
- ✅ Understand what commands the agent is running
- ✅ Use a dedicated user account with limited privileges
@ -71,6 +72,19 @@ The `exec` tool can execute shell commands. While dangerous command patterns are
- ❌ Don't disable security checks
- ❌ Don't run on systems with sensitive data without careful review
**Exec sandbox (bwrap):**
On Linux, set `"tools.exec.sandbox": "bwrap"` to wrap every shell command in a [bubblewrap](https://github.com/containers/bubblewrap) sandbox. This uses Linux kernel namespaces to restrict what the process can see:
- Workspace directory → **read-write** (agent works normally)
- Media directory → **read-only** (can read uploaded attachments)
- System directories (`/usr`, `/bin`, `/lib`) → **read-only** (commands still work)
- Config files and API keys (`~/.nanobot/config.json`) → **hidden** (masked by tmpfs)
Requires `bwrap` installed (`apt install bubblewrap`). Pre-installed in the official Docker image. **Not available on macOS or Windows** — bubblewrap depends on Linux kernel namespaces.
Enabling the sandbox also automatically activates `restrictToWorkspace` for file tools.
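Conceptually, such a wrapper builds a `bwrap` argument vector along these lines. The flags shown (`--ro-bind`, `--bind`, `--tmpfs`, `--unshare-all`) are real bubblewrap options, but `build_bwrap_argv` and the exact bind layout are an illustrative sketch, not nanobot's actual implementation:

```python
from pathlib import Path


def build_bwrap_argv(command: str, workspace: Path, media_dir: Path) -> list[str]:
    """Illustrative sketch: wrap a shell command in a bubblewrap sandbox."""
    home = Path.home()
    return [
        "bwrap",
        "--ro-bind", "/usr", "/usr",                  # system dirs: read-only
        "--ro-bind", "/bin", "/bin",
        "--ro-bind", "/lib", "/lib",
        "--bind", str(workspace), str(workspace),     # workspace: read-write
        "--ro-bind", str(media_dir), str(media_dir),  # media: read-only
        "--tmpfs", str(home / ".nanobot"),            # mask config + API keys
        "--unshare-all", "--share-net",               # new namespaces, keep network
        "--die-with-parent",
        "/bin/sh", "-c", command,
    ]
```

Running the returned argv executes the command in a mount namespace where `~/.nanobot` is replaced by an empty tmpfs, so secrets never enter the sandboxed process's view of the filesystem.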
**Blocked patterns:**
- `rm -rf /` - Root filesystem deletion
- Fork bombs
@ -82,6 +96,7 @@ The `exec` tool can execute shell commands. While dangerous command patterns are
File operations have path traversal protection, but:
- ✅ Enable `restrictToWorkspace` or the bwrap sandbox to confine file access
- ✅ Run nanobot with a dedicated user account
- ✅ Use filesystem permissions to protect sensitive directories
- ✅ Regularly audit file operations in logs
@ -232,7 +247,7 @@ If you suspect a security breach:
1. **No Rate Limiting** - Users can send unlimited messages (add your own if needed)
2. **Plain Text Config** - API keys stored in plain text (use keyring for production)
3. **No Session Management** - No automatic session expiry
4. **Limited Command Filtering** - Only blocks obvious dangerous patterns
4. **Limited Command Filtering** - Only blocks obvious dangerous patterns (enable the bwrap sandbox for kernel-level isolation on Linux)
5. **No Audit Trail** - Limited security event logging (enhance as needed)
## Security Checklist
@ -243,6 +258,7 @@ Before deploying nanobot:
- [ ] Config file permissions set to 0600
- [ ] `allowFrom` lists configured for all channels
- [ ] Running as non-root user
- [ ] Exec sandbox enabled (`"tools.exec.sandbox": "bwrap"`) on Linux deployments
- [ ] File system permissions properly restricted
- [ ] Dependencies updated to latest secure versions
- [ ] Logs monitored for security events
@ -252,7 +268,7 @@ Before deploying nanobot:
## Updates
**Last Updated**: 2026-02-03
**Last Updated**: 2026-04-05
For the latest security updates and announcements, check:
- GitHub Security Advisories: https://github.com/HKUDS/nanobot/security/advisories

Binary image changed (6.8 MiB before and after).


@ -3,7 +3,14 @@ x-common-config: &common-config
context: .
dockerfile: Dockerfile
volumes:
- ~/.nanobot:/root/.nanobot
- ~/.nanobot:/home/nanobot/.nanobot
cap_drop:
- ALL
cap_add:
- SYS_ADMIN
security_opt:
- apparmor=unconfined
- seccomp=unconfined
services:
nanobot-gateway:
@ -16,12 +23,29 @@ services:
deploy:
resources:
limits:
cpus: '1'
cpus: "1"
memory: 1G
reservations:
cpus: '0.25'
cpus: "0.25"
memory: 256M
nanobot-api:
container_name: nanobot-api
<<: *common-config
command:
["serve", "--host", "0.0.0.0", "-w", "/home/nanobot/.nanobot/api-workspace"]
restart: unless-stopped
ports:
- 127.0.0.1:8900:8900
deploy:
resources:
limits:
cpus: "1"
memory: 1G
reservations:
cpus: "0.25"
memory: 256M
nanobot-cli:
<<: *common-config
profiles:

entrypoint.sh (new executable file, 15 lines)

@ -0,0 +1,15 @@
#!/bin/sh
dir="$HOME/.nanobot"
if [ -d "$dir" ] && [ ! -w "$dir" ]; then
owner_uid=$(stat -c %u "$dir" 2>/dev/null || stat -f %u "$dir" 2>/dev/null)
cat >&2 <<EOF
Error: $dir is not writable (owned by UID $owner_uid, running as UID $(id -u)).
Fix (pick one):
Host: sudo chown -R 1000:1000 ~/.nanobot
Docker: docker run --user \$(id -u):\$(id -g) ...
Podman: podman run --userns=keep-id ...
EOF
exit 1
fi
exec nanobot "$@"


@ -2,7 +2,7 @@
nanobot - A lightweight AI agent framework
"""
__version__ = "0.1.4.post6"
__version__ = "0.1.5"
__logo__ = "🐈"
from nanobot.nanobot import Nanobot, RunResult


@ -3,7 +3,7 @@
from nanobot.agent.context import ContextBuilder
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
from nanobot.agent.loop import AgentLoop
from nanobot.agent.memory import Consolidator, Dream, MemoryStore
from nanobot.agent.memory import Dream, MemoryStore
from nanobot.agent.skills import SkillsLoader
from nanobot.agent.subagent import SubagentManager


@ -67,40 +67,27 @@ class CompositeHook(AgentHook):
def wants_streaming(self) -> bool:
return any(h.wants_streaming() for h in self._hooks)
async def before_iteration(self, context: AgentHookContext) -> None:
async def _for_each_hook_safe(self, method_name: str, *args: Any, **kwargs: Any) -> None:
for h in self._hooks:
try:
await h.before_iteration(context)
await getattr(h, method_name)(*args, **kwargs)
except Exception:
logger.exception("AgentHook.before_iteration error in {}", type(h).__name__)
logger.exception("AgentHook.{} error in {}", method_name, type(h).__name__)
async def before_iteration(self, context: AgentHookContext) -> None:
await self._for_each_hook_safe("before_iteration", context)
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
for h in self._hooks:
try:
await h.on_stream(context, delta)
except Exception:
logger.exception("AgentHook.on_stream error in {}", type(h).__name__)
await self._for_each_hook_safe("on_stream", context, delta)
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
for h in self._hooks:
try:
await h.on_stream_end(context, resuming=resuming)
except Exception:
logger.exception("AgentHook.on_stream_end error in {}", type(h).__name__)
await self._for_each_hook_safe("on_stream_end", context, resuming=resuming)
async def before_execute_tools(self, context: AgentHookContext) -> None:
for h in self._hooks:
try:
await h.before_execute_tools(context)
except Exception:
logger.exception("AgentHook.before_execute_tools error in {}", type(h).__name__)
await self._for_each_hook_safe("before_execute_tools", context)
async def after_iteration(self, context: AgentHookContext) -> None:
for h in self._hooks:
try:
await h.after_iteration(context)
except Exception:
logger.exception("AgentHook.after_iteration error in {}", type(h).__name__)
await self._for_each_hook_safe("after_iteration", context)
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
for h in self._hooks:


@ -4,7 +4,6 @@ from __future__ import annotations
import asyncio
import json
import re
import os
import time
from contextlib import AsyncExitStack, nullcontext
@ -262,7 +261,7 @@ class AgentLoop:
def _register_default_tools(self) -> None:
"""Register the default set of tools."""
allowed_dir = self.workspace if self.restrict_to_workspace else None
allowed_dir = self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
self.tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read))
for cls in (WriteFileTool, EditFileTool, ListDirTool):
@ -274,6 +273,7 @@ class AgentLoop:
working_dir=str(self.workspace),
timeout=self.exec_config.timeout,
restrict_to_workspace=self.restrict_to_workspace,
sandbox=self.exec_config.sandbox,
path_append=self.exec_config.path_append,
))
if self.web_config.enable:
@ -325,14 +325,10 @@ class AgentLoop:
@staticmethod
def _tool_hint(tool_calls: list) -> str:
"""Format tool calls as concise hint, e.g. 'web_search("query")'."""
def _fmt(tc):
args = (tc.arguments[0] if isinstance(tc.arguments, list) else tc.arguments) or {}
val = next(iter(args.values()), None) if isinstance(args, dict) else None
if not isinstance(val, str):
return tc.name
return f'{tc.name}("{val[:40]}")' if len(val) > 40 else f'{tc.name}("{val}")'
return ", ".join(_fmt(tc) for tc in tool_calls)
"""Format tool calls as concise hints with smart abbreviation."""
from nanobot.utils.tool_hints import format_tool_hints
return format_tool_hints(tool_calls)
async def _run_agent_loop(
self,


@ -30,6 +30,7 @@ from nanobot.utils.runtime import (
)
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
_MAX_EMPTY_RETRIES = 2
_SNIP_SAFETY_BUFFER = 1024
@dataclass(slots=True)
class AgentRunSpec:
@ -86,6 +87,7 @@ class AgentRunner:
stop_reason = "completed"
tool_events: list[dict[str, str]] = []
external_lookup_counts: dict[str, int] = {}
empty_content_retries = 0
for iteration in range(spec.max_iterations):
try:
@ -178,15 +180,30 @@ class AgentRunner:
"pending_tool_calls": [],
},
)
empty_content_retries = 0
await hook.after_iteration(context)
continue
clean = hook.finalize_content(context, response.content)
if response.finish_reason != "error" and is_blank_text(clean):
empty_content_retries += 1
if empty_content_retries < _MAX_EMPTY_RETRIES:
logger.warning(
"Empty response on turn {} for {} ({}/{}); retrying",
iteration,
spec.session_key or "default",
empty_content_retries,
_MAX_EMPTY_RETRIES,
)
if hook.wants_streaming():
await hook.on_stream_end(context, resuming=False)
await hook.after_iteration(context)
continue
logger.warning(
"Empty final response on turn {} for {}; retrying with explicit finalization prompt",
"Empty response on turn {} for {} after {} retries; attempting finalization",
iteration,
spec.session_key or "default",
empty_content_retries,
)
if hook.wants_streaming():
await hook.on_stream_end(context, resuming=False)


@ -9,6 +9,16 @@ from pathlib import Path
# Default builtin skills directory (relative to this file)
BUILTIN_SKILLS_DIR = Path(__file__).parent.parent / "skills"
# Opening ---, YAML body (group 1), closing --- on its own line; supports CRLF.
_STRIP_SKILL_FRONTMATTER = re.compile(
r"^---\s*\r?\n(.*?)\r?\n---\s*\r?\n?",
re.DOTALL,
)
def _escape_xml(text: str) -> str:
return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
class SkillsLoader:
"""
@ -23,6 +33,22 @@ class SkillsLoader:
self.workspace_skills = workspace / "skills"
self.builtin_skills = builtin_skills_dir or BUILTIN_SKILLS_DIR
def _skill_entries_from_dir(self, base: Path, source: str, *, skip_names: set[str] | None = None) -> list[dict[str, str]]:
if not base.exists():
return []
entries: list[dict[str, str]] = []
for skill_dir in base.iterdir():
if not skill_dir.is_dir():
continue
skill_file = skill_dir / "SKILL.md"
if not skill_file.exists():
continue
name = skill_dir.name
if skip_names is not None and name in skip_names:
continue
entries.append({"name": name, "path": str(skill_file), "source": source})
return entries
def list_skills(self, filter_unavailable: bool = True) -> list[dict[str, str]]:
"""
List all available skills.
@ -33,27 +59,15 @@ class SkillsLoader:
Returns:
List of skill info dicts with 'name', 'path', 'source'.
"""
skills = []
# Workspace skills (highest priority)
if self.workspace_skills.exists():
for skill_dir in self.workspace_skills.iterdir():
if skill_dir.is_dir():
skill_file = skill_dir / "SKILL.md"
if skill_file.exists():
skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "workspace"})
# Built-in skills
skills = self._skill_entries_from_dir(self.workspace_skills, "workspace")
workspace_names = {entry["name"] for entry in skills}
if self.builtin_skills and self.builtin_skills.exists():
for skill_dir in self.builtin_skills.iterdir():
if skill_dir.is_dir():
skill_file = skill_dir / "SKILL.md"
if skill_file.exists() and not any(s["name"] == skill_dir.name for s in skills):
skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "builtin"})
skills.extend(
self._skill_entries_from_dir(self.builtin_skills, "builtin", skip_names=workspace_names)
)
# Filter by requirements
if filter_unavailable:
return [s for s in skills if self._check_requirements(self._get_skill_meta(s["name"]))]
return [skill for skill in skills if self._check_requirements(self._get_skill_meta(skill["name"]))]
return skills
def load_skill(self, name: str) -> str | None:
@ -66,17 +80,13 @@ class SkillsLoader:
Returns:
Skill content or None if not found.
"""
# Check workspace first
workspace_skill = self.workspace_skills / name / "SKILL.md"
if workspace_skill.exists():
return workspace_skill.read_text(encoding="utf-8")
# Check built-in
roots = [self.workspace_skills]
if self.builtin_skills:
builtin_skill = self.builtin_skills / name / "SKILL.md"
if builtin_skill.exists():
return builtin_skill.read_text(encoding="utf-8")
roots.append(self.builtin_skills)
for root in roots:
path = root / name / "SKILL.md"
if path.exists():
return path.read_text(encoding="utf-8")
return None
def load_skills_for_context(self, skill_names: list[str]) -> str:
@ -89,14 +99,12 @@ class SkillsLoader:
Returns:
Formatted skills content.
"""
parts = []
for name in skill_names:
content = self.load_skill(name)
if content:
content = self._strip_frontmatter(content)
parts.append(f"### Skill: {name}\n\n{content}")
return "\n\n---\n\n".join(parts) if parts else ""
parts = [
f"### Skill: {name}\n\n{self._strip_frontmatter(markdown)}"
for name in skill_names
if (markdown := self.load_skill(name))
]
return "\n\n---\n\n".join(parts)
def build_skills_summary(self) -> str:
"""
@ -112,44 +120,36 @@ class SkillsLoader:
if not all_skills:
return ""
def escape_xml(s: str) -> str:
return s.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
lines = ["<skills>"]
for s in all_skills:
name = escape_xml(s["name"])
path = s["path"]
desc = escape_xml(self._get_skill_description(s["name"]))
skill_meta = self._get_skill_meta(s["name"])
available = self._check_requirements(skill_meta)
lines.append(f" <skill available=\"{str(available).lower()}\">")
lines.append(f" <name>{name}</name>")
lines.append(f" <description>{desc}</description>")
lines.append(f" <location>{path}</location>")
# Show missing requirements for unavailable skills
lines: list[str] = ["<skills>"]
for entry in all_skills:
skill_name = entry["name"]
meta = self._get_skill_meta(skill_name)
available = self._check_requirements(meta)
lines.extend(
[
f' <skill available="{str(available).lower()}">',
f" <name>{_escape_xml(skill_name)}</name>",
f" <description>{_escape_xml(self._get_skill_description(skill_name))}</description>",
f" <location>{entry['path']}</location>",
]
)
if not available:
missing = self._get_missing_requirements(skill_meta)
missing = self._get_missing_requirements(meta)
if missing:
lines.append(f" <requires>{escape_xml(missing)}</requires>")
lines.append(f" <requires>{_escape_xml(missing)}</requires>")
lines.append(" </skill>")
lines.append("</skills>")
return "\n".join(lines)
def _get_missing_requirements(self, skill_meta: dict) -> str:
"""Get a description of missing requirements."""
missing = []
requires = skill_meta.get("requires", {})
for b in requires.get("bins", []):
if not shutil.which(b):
missing.append(f"CLI: {b}")
for env in requires.get("env", []):
if not os.environ.get(env):
missing.append(f"ENV: {env}")
return ", ".join(missing)
required_bins = requires.get("bins", [])
required_env_vars = requires.get("env", [])
return ", ".join(
[f"CLI: {command_name}" for command_name in required_bins if not shutil.which(command_name)]
+ [f"ENV: {env_name}" for env_name in required_env_vars if not os.environ.get(env_name)]
)
def _get_skill_description(self, name: str) -> str:
"""Get the description of a skill from its frontmatter."""
@ -160,30 +160,32 @@ class SkillsLoader:
def _strip_frontmatter(self, content: str) -> str:
"""Remove YAML frontmatter from markdown content."""
if content.startswith("---"):
match = re.match(r"^---\n.*?\n---\n", content, re.DOTALL)
if match:
return content[match.end():].strip()
if not content.startswith("---"):
return content
match = _STRIP_SKILL_FRONTMATTER.match(content)
if match:
return content[match.end():].strip()
return content
def _parse_nanobot_metadata(self, raw: str) -> dict:
"""Parse skill metadata JSON from frontmatter (supports nanobot and openclaw keys)."""
try:
data = json.loads(raw)
return data.get("nanobot", data.get("openclaw", {})) if isinstance(data, dict) else {}
except (json.JSONDecodeError, TypeError):
return {}
if not isinstance(data, dict):
return {}
payload = data.get("nanobot", data.get("openclaw", {}))
return payload if isinstance(payload, dict) else {}
def _check_requirements(self, skill_meta: dict) -> bool:
"""Check if skill requirements are met (bins, env vars)."""
requires = skill_meta.get("requires", {})
for b in requires.get("bins", []):
if not shutil.which(b):
return False
for env in requires.get("env", []):
if not os.environ.get(env):
return False
return True
required_bins = requires.get("bins", [])
required_env_vars = requires.get("env", [])
return all(shutil.which(cmd) for cmd in required_bins) and all(
os.environ.get(var) for var in required_env_vars
)
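A standalone sketch of the requirement check, assuming the same `requires` shape (`{"bins": [...], "env": [...]}`) as the skill metadata; `requirements_met` is an illustration-only name, not the method above:

```python
import os
import shutil

def requirements_met(requires: dict) -> bool:
    # A skill is available only if every required CLI is on PATH and every
    # required environment variable is set (same logic as _check_requirements).
    bins = requires.get("bins", [])
    envs = requires.get("env", [])
    return all(shutil.which(b) for b in bins) and all(os.environ.get(v) for v in envs)

print(requirements_met({"bins": ["sh"], "env": []}))                        # True on POSIX
print(requirements_met({"bins": ["definitely-not-a-real-cli-xyz"], "env": []}))  # False
```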
def _get_skill_meta(self, name: str) -> dict:
"""Get nanobot metadata for a skill (cached in frontmatter)."""
@ -192,13 +194,15 @@ class SkillsLoader:
def get_always_skills(self) -> list[str]:
"""Get skills marked as always=true that meet requirements."""
result = []
for s in self.list_skills(filter_unavailable=True):
meta = self.get_skill_metadata(s["name"]) or {}
skill_meta = self._parse_nanobot_metadata(meta.get("metadata", ""))
if skill_meta.get("always") or meta.get("always"):
result.append(s["name"])
return result
return [
entry["name"]
for entry in self.list_skills(filter_unavailable=True)
if (meta := self.get_skill_metadata(entry["name"]) or {})
and (
self._parse_nanobot_metadata(meta.get("metadata", "")).get("always")
or meta.get("always")
)
]
def get_skill_metadata(self, name: str) -> dict | None:
"""
@ -211,18 +215,15 @@ class SkillsLoader:
Metadata dict or None.
"""
content = self.load_skill(name)
if not content:
if not content or not content.startswith("---"):
return None
if content.startswith("---"):
match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL)
if match:
# Simple YAML parsing
metadata = {}
for line in match.group(1).split("\n"):
if ":" in line:
key, value = line.split(":", 1)
metadata[key.strip()] = value.strip().strip('"\'')
return metadata
return None
match = _STRIP_SKILL_FRONTMATTER.match(content)
if not match:
return None
metadata: dict[str, str] = {}
for line in match.group(1).splitlines():
if ":" not in line:
continue
key, value = line.split(":", 1)
metadata[key.strip()] = value.strip().strip('"\'')
return metadata
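The simple key/value frontmatter parsing in `get_skill_metadata` can be exercised on its own; the `weather` skill below is a hypothetical example, not a bundled skill:

```python
import re

# Same shape as the module's frontmatter matcher: captures everything
# between the opening and closing "---" fences.
_FRONTMATTER = re.compile(r"^---\n(.*?)\n---", re.DOTALL)

content = """---
name: weather
description: "Fetch a forecast"
---
# Weather skill
"""

metadata: dict[str, str] = {}
if m := _FRONTMATTER.match(content):
    for line in m.group(1).splitlines():
        if ":" not in line:
            continue
        key, value = line.split(":", 1)
        # strip('"\'') removes surrounding quotes from YAML-style values
        metadata[key.strip()] = value.strip().strip('"\'')

print(metadata)  # → {'name': 'weather', 'description': 'Fetch a forecast'}
```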

View File

@ -111,7 +111,7 @@ class SubagentManager:
try:
# Build subagent tools (no message tool, no spawn tool)
tools = ToolRegistry()
allowed_dir = self.workspace if self.restrict_to_workspace else None
allowed_dir = self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read))
tools.register(WriteFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
@ -124,6 +124,7 @@ class SubagentManager:
working_dir=str(self.workspace),
timeout=self.exec_config.timeout,
restrict_to_workspace=self.restrict_to_workspace,
sandbox=self.exec_config.sandbox,
path_append=self.exec_config.path_append,
))
if self.web_config.enable:

View File

@ -13,6 +13,10 @@ from nanobot.cron.types import CronJob, CronJobState, CronSchedule
@tool_parameters(
tool_parameters_schema(
action=StringSchema("Action to perform", enum=["add", "list", "remove"]),
name=StringSchema(
"Optional short human-readable label for the job "
"(e.g., 'weather-monitor', 'daily-standup'). Defaults to first 30 chars of message."
),
message=StringSchema(
"Instruction for the agent to execute when the job triggers "
"(e.g., 'Send a reminder to WeChat: xxx' or 'Check system status and report')"
@ -93,6 +97,7 @@ class CronTool(Tool):
async def execute(
self,
action: str,
name: str | None = None,
message: str = "",
every_seconds: int | None = None,
cron_expr: str | None = None,
@ -105,7 +110,7 @@ class CronTool(Tool):
if action == "add":
if self._in_cron_context.get():
return "Error: cannot schedule new jobs from within a cron job execution"
return self._add_job(message, every_seconds, cron_expr, tz, at, deliver)
return self._add_job(name, message, every_seconds, cron_expr, tz, at, deliver)
elif action == "list":
return self._list_jobs()
elif action == "remove":
@ -114,6 +119,7 @@ class CronTool(Tool):
def _add_job(
self,
name: str | None,
message: str,
every_seconds: int | None,
cron_expr: str | None,
@ -158,7 +164,7 @@ class CronTool(Tool):
return "Error: either every_seconds, cron_expr, or at is required"
job = self._cron.add_job(
name=message[:30],
name=name or message[:30],
schedule=schedule,
message=message,
deliver=deliver,

View File

@ -186,7 +186,7 @@ class WriteFileTool(_FsTool):
fp = self._resolve(path)
fp.parent.mkdir(parents=True, exist_ok=True)
fp.write_text(content, encoding="utf-8")
return f"Successfully wrote {len(content)} bytes to {fp}"
return f"Successfully wrote {len(content)} characters to {fp}"
except PermissionError as e:
return f"Error: {e}"
except Exception as e:

View File

@ -0,0 +1,55 @@
"""Sandbox backends for shell command execution.
To add a new backend, implement a function with the signature:
_wrap_<name>(command: str, workspace: str, cwd: str) -> str
and register it in _BACKENDS below.
"""
import shlex
from pathlib import Path
from nanobot.config.paths import get_media_dir
def _bwrap(command: str, workspace: str, cwd: str) -> str:
"""Wrap command in a bubblewrap sandbox (requires bwrap in container).
Only the workspace is bind-mounted read-write; its parent dir (which holds
config.json) is hidden behind a fresh tmpfs. The media directory is
bind-mounted read-only so exec commands can read uploaded attachments.
"""
ws = Path(workspace).resolve()
media = get_media_dir().resolve()
try:
sandbox_cwd = str(ws / Path(cwd).resolve().relative_to(ws))
except ValueError:
sandbox_cwd = str(ws)
required = ["/usr"]
optional = ["/bin", "/lib", "/lib64", "/etc/alternatives",
"/etc/ssl/certs", "/etc/resolv.conf", "/etc/ld.so.cache"]
args = ["bwrap", "--new-session", "--die-with-parent"]
for p in required: args += ["--ro-bind", p, p]
for p in optional: args += ["--ro-bind-try", p, p]
args += [
"--proc", "/proc", "--dev", "/dev", "--tmpfs", "/tmp",
"--tmpfs", str(ws.parent), # mask config dir
"--dir", str(ws), # recreate workspace mount point
"--bind", str(ws), str(ws),
"--ro-bind-try", str(media), str(media), # read-only access to media
"--chdir", sandbox_cwd,
"--", "sh", "-c", command,
]
return shlex.join(args)
_BACKENDS = {"bwrap": _bwrap}
def wrap_command(sandbox: str, command: str, workspace: str, cwd: str) -> str:
"""Wrap *command* using the named sandbox backend."""
if backend := _BACKENDS.get(sandbox):
return backend(command, workspace, cwd)
raise ValueError(f"Unknown sandbox backend {sandbox!r}. Available: {list(_BACKENDS)}")
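The registry pattern described in the module docstring (implement `_wrap_<name>` and add it to `_BACKENDS`) can be sketched standalone; `_wrap_env` here is a made-up backend so the example runs without bubblewrap installed:

```python
import shlex

def _wrap_env(command: str, workspace: str, cwd: str) -> str:
    # Trivial illustration-only backend: run the command via `env -C <cwd>`
    # so the wrapping is visible. The real module registers only "bwrap".
    return shlex.join(["env", "-C", cwd, "sh", "-c", command])

_BACKENDS = {"env": _wrap_env}

def wrap_command(sandbox: str, command: str, workspace: str, cwd: str) -> str:
    # Same dispatch shape as the module above: look up the backend by name,
    # otherwise fail loudly with the list of known backends.
    if backend := _BACKENDS.get(sandbox):
        return backend(command, workspace, cwd)
    raise ValueError(f"Unknown sandbox backend {sandbox!r}. Available: {list(_BACKENDS)}")

print(wrap_command("env", "echo hi", "/tmp/ws", "/tmp/ws"))  # → env -C /tmp/ws sh -c 'echo hi'
```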

View File

@ -3,6 +3,7 @@
import asyncio
import os
import re
import shutil
import sys
from pathlib import Path
from typing import Any
@ -10,6 +11,7 @@ from typing import Any
from loguru import logger
from nanobot.agent.tools.base import Tool, tool_parameters
from nanobot.agent.tools.sandbox import wrap_command
from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema
from nanobot.config.paths import get_media_dir
@ -40,10 +42,12 @@ class ExecTool(Tool):
deny_patterns: list[str] | None = None,
allow_patterns: list[str] | None = None,
restrict_to_workspace: bool = False,
sandbox: str = "",
path_append: str = "",
):
self.timeout = timeout
self.working_dir = working_dir
self.sandbox = sandbox
self.deny_patterns = deny_patterns or [
r"\brm\s+-[rf]{1,2}\b", # rm -r, rm -rf, rm -fr
r"\bdel\s+/[fq]\b", # del /f, del /q
@ -83,15 +87,23 @@ class ExecTool(Tool):
if guard_error:
return guard_error
if self.sandbox:
workspace = self.working_dir or cwd
command = wrap_command(self.sandbox, command, workspace, cwd)
cwd = str(Path(workspace).resolve())
effective_timeout = min(timeout or self.timeout, self._MAX_TIMEOUT)
env = os.environ.copy()
env = self._build_env()
if self.path_append:
env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append
command = f'export PATH="$PATH:{self.path_append}"; {command}'
bash = shutil.which("bash") or "/bin/bash"
try:
process = await asyncio.create_subprocess_shell(
command,
process = await asyncio.create_subprocess_exec(
bash, "-l", "-c", command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=cwd,
@ -104,18 +116,11 @@ class ExecTool(Tool):
timeout=effective_timeout,
)
except asyncio.TimeoutError:
process.kill()
try:
await asyncio.wait_for(process.wait(), timeout=5.0)
except asyncio.TimeoutError:
pass
finally:
if sys.platform != "win32":
try:
os.waitpid(process.pid, os.WNOHANG)
except (ProcessLookupError, ChildProcessError) as e:
logger.debug("Process already reaped or not found: {}", e)
await self._kill_process(process)
return f"Error: Command timed out after {effective_timeout} seconds"
except asyncio.CancelledError:
await self._kill_process(process)
raise
output_parts = []
@ -146,6 +151,36 @@ class ExecTool(Tool):
except Exception as e:
return f"Error executing command: {str(e)}"
@staticmethod
async def _kill_process(process: asyncio.subprocess.Process) -> None:
"""Kill a subprocess and reap it to prevent zombies."""
process.kill()
try:
await asyncio.wait_for(process.wait(), timeout=5.0)
except asyncio.TimeoutError:
pass
finally:
if sys.platform != "win32":
try:
os.waitpid(process.pid, os.WNOHANG)
except (ProcessLookupError, ChildProcessError) as e:
logger.debug("Process already reaped or not found: {}", e)
def _build_env(self) -> dict[str, str]:
"""Build a minimal environment for subprocess execution.
Uses HOME so that ``bash -l`` sources the user's profile (which sets
PATH and other essentials). Only PATH is extended with *path_append*;
the parent process's environment is **not** inherited, preventing
secrets in env vars from leaking to LLM-generated commands.
"""
home = os.environ.get("HOME", "/tmp")
return {
"HOME": home,
"LANG": os.environ.get("LANG", "C.UTF-8"),
"TERM": os.environ.get("TERM", "dumb"),
}
def _guard_command(self, command: str, cwd: str) -> str | None:
"""Best-effort safety guard for potentially destructive commands."""
cmd = command.strip()
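A minimal standalone sketch of the env-isolation idea in `_build_env`: the child only sees HOME, LANG, and TERM, so secrets held in the parent's environment never reach LLM-generated commands. `FAKE_SECRET` and `build_minimal_env` are illustration-only names:

```python
import os
import subprocess

def build_minimal_env() -> dict[str, str]:
    # Mirror of the minimal environment built above: only HOME, LANG, TERM.
    return {
        "HOME": os.environ.get("HOME", "/tmp"),
        "LANG": os.environ.get("LANG", "C.UTF-8"),
        "TERM": os.environ.get("TERM", "dumb"),
    }

# FAKE_SECRET is illustration-only; with env=None the child shell would inherit it.
os.environ["FAKE_SECRET"] = "hunter2"
out = subprocess.run(
    ["/bin/sh", "-c", "echo ${FAKE_SECRET:-unset}"],
    env=build_minimal_env(),
    capture_output=True,
    text=True,
).stdout.strip()
print(out)  # → unset
```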

View File

@ -8,7 +8,7 @@ import json
import os
import re
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
from urllib.parse import quote, urlparse
import httpx
from loguru import logger
@ -182,10 +182,10 @@ class WebSearchTool(Tool):
return await self._search_duckduckgo(query, n)
try:
headers = {"Accept": "application/json", "Authorization": f"Bearer {api_key}"}
encoded_query = quote(query, safe="")
async with httpx.AsyncClient(proxy=self.proxy) as client:
r = await client.get(
f"https://s.jina.ai/",
params={"q": query},
f"https://s.jina.ai/{encoded_query}",
headers=headers,
timeout=15.0,
)
@ -197,7 +197,8 @@ class WebSearchTool(Tool):
]
return _format_results(query, items, n)
except Exception as e:
return f"Error: {e}"
logger.warning("Jina search failed ({}), falling back to DuckDuckGo", e)
return await self._search_duckduckgo(query, n)
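The switch from a `q=` query parameter to a path segment is why the query must be percent-encoded with `safe=""`: any literal `/` or `&` in the query would otherwise be misread as URL structure. A quick check:

```python
from urllib.parse import quote

query = "async/await & more?"
# safe="" forces "/" to be encoded too, which the default safe="/" would not.
encoded = quote(query, safe="")
print(encoded)                       # → async%2Fawait%20%26%20more%3F
print(f"https://s.jina.ai/{encoded}")
```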
async def _search_duckduckgo(self, query: str, n: int) -> str:
try:
@ -206,7 +207,10 @@ class WebSearchTool(Tool):
from ddgs import DDGS
ddgs = DDGS(timeout=10)
raw = await asyncio.to_thread(ddgs.text, query, max_results=n)
raw = await asyncio.wait_for(
asyncio.to_thread(ddgs.text, query, max_results=n),
timeout=self.config.timeout,
)
if not raw:
return f"No results for: {query}"
items = [

View File

@ -22,6 +22,7 @@ class BaseChannel(ABC):
name: str = "base"
display_name: str = "Base"
transcription_provider: str = "groq"
transcription_api_key: str = ""
def __init__(self, config: Any, bus: MessageBus):
@ -37,13 +38,16 @@ class BaseChannel(ABC):
self._running = False
async def transcribe_audio(self, file_path: str | Path) -> str:
"""Transcribe an audio file via Groq Whisper. Returns empty string on failure."""
"""Transcribe an audio file via Whisper (OpenAI or Groq). Returns empty string on failure."""
if not self.transcription_api_key:
return ""
try:
from nanobot.providers.transcription import GroqTranscriptionProvider
provider = GroqTranscriptionProvider(api_key=self.transcription_api_key)
if self.transcription_provider == "openai":
from nanobot.providers.transcription import OpenAITranscriptionProvider
provider = OpenAITranscriptionProvider(api_key=self.transcription_api_key)
else:
from nanobot.providers.transcription import GroqTranscriptionProvider
provider = GroqTranscriptionProvider(api_key=self.transcription_api_key)
return await provider.transcribe(file_path)
except Exception as e:
logger.warning("{}: audio transcription failed: {}", self.name, e)

View File

@ -12,6 +12,8 @@ from email.header import decode_header, make_header
from email.message import EmailMessage
from email.parser import BytesParser
from email.utils import parseaddr
from fnmatch import fnmatch
from pathlib import Path
from typing import Any
from loguru import logger
@ -20,7 +22,9 @@ from pydantic import Field
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_media_dir
from nanobot.config.schema import Base
from nanobot.utils.helpers import safe_filename
class EmailConfig(Base):
@ -55,6 +59,11 @@ class EmailConfig(Base):
verify_dkim: bool = True # Require Authentication-Results with dkim=pass
verify_spf: bool = True # Require Authentication-Results with spf=pass
# Attachment handling — set allowed types to enable (e.g. ["application/pdf", "image/*"], or ["*"] for all)
allowed_attachment_types: list[str] = Field(default_factory=list)
max_attachment_size: int = 2_000_000 # 2MB per attachment
max_attachments_per_email: int = 5
class EmailChannel(BaseChannel):
"""
@ -153,6 +162,7 @@ class EmailChannel(BaseChannel):
sender_id=sender,
chat_id=sender,
content=item["content"],
media=item.get("media") or None,
metadata=item.get("metadata", {}),
)
except Exception as e:
@ -404,6 +414,20 @@ class EmailChannel(BaseChannel):
f"{body}"
)
# --- Attachment extraction ---
attachment_paths: list[str] = []
if self.config.allowed_attachment_types:
saved = self._extract_attachments(
parsed,
uid or "noid",
allowed_types=self.config.allowed_attachment_types,
max_size=self.config.max_attachment_size,
max_count=self.config.max_attachments_per_email,
)
for p in saved:
attachment_paths.append(str(p))
content += f"\n[attachment: {p.name} — saved to {p}]"
metadata = {
"message_id": message_id,
"subject": subject,
@ -418,6 +442,7 @@ class EmailChannel(BaseChannel):
"message_id": message_id,
"content": content,
"metadata": metadata,
"media": attachment_paths,
}
)
@ -537,6 +562,61 @@ class EmailChannel(BaseChannel):
dkim_pass = True
return spf_pass, dkim_pass
@classmethod
def _extract_attachments(
cls,
msg: Any,
uid: str,
*,
allowed_types: list[str],
max_size: int,
max_count: int,
) -> list[Path]:
"""Extract and save email attachments to the media directory.
Returns list of saved file paths.
"""
if not msg.is_multipart():
return []
saved: list[Path] = []
media_dir = get_media_dir("email")
for part in msg.walk():
if len(saved) >= max_count:
break
if part.get_content_disposition() != "attachment":
continue
content_type = part.get_content_type()
if not any(fnmatch(content_type, pat) for pat in allowed_types):
logger.debug("Email attachment skipped (type {}): not in allowed list", content_type)
continue
payload = part.get_payload(decode=True)
if payload is None:
continue
if len(payload) > max_size:
logger.warning(
"Email attachment skipped: size {} exceeds limit {}",
len(payload),
max_size,
)
continue
raw_name = part.get_filename() or "attachment"
sanitized = safe_filename(raw_name) or "attachment"
dest = media_dir / f"{uid}_{sanitized}"
try:
dest.write_bytes(payload)
saved.append(dest)
logger.info("Email attachment saved: {}", dest)
except Exception as exc:
logger.warning("Failed to save email attachment {}: {}", dest, exc)
return saved
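The `fnmatch`-based content-type filter used in `_extract_attachments` can be tried in isolation; `allowed` mirrors the `allowed_attachment_types` config example (`["application/pdf", "image/*"]`):

```python
from fnmatch import fnmatch

allowed = ["application/pdf", "image/*"]

def is_allowed(content_type: str) -> bool:
    # A MIME type passes if it matches any configured glob pattern.
    return any(fnmatch(content_type, pat) for pat in allowed)

print(is_allowed("image/png"))        # → True
print(is_allowed("application/pdf"))  # → True
print(is_allowed("text/plain"))       # → False
```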
@staticmethod
def _html_to_text(raw_html: str) -> str:
text = re.sub(r"<\s*br\s*/?>", "\n", raw_html, flags=re.IGNORECASE)

View File

@ -1,6 +1,7 @@
"""Feishu/Lark channel implementation using lark-oapi SDK with WebSocket long connection."""
import asyncio
import importlib.util
import json
import os
import re
@ -9,19 +10,17 @@ import time
import uuid
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal
from lark_oapi.api.im.v1.model import MentionEvent, P2ImMessageReceiveV1
from loguru import logger
from pydantic import Field
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_media_dir
from nanobot.config.schema import Base
from pydantic import Field
import importlib.util
FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
@ -76,7 +75,9 @@ def _extract_interactive_content(content: dict) -> list[str]:
elif isinstance(title, str):
parts.append(f"title: {title}")
for elements in content.get("elements", []) if isinstance(content.get("elements"), list) else []:
for elements in (
content.get("elements", []) if isinstance(content.get("elements"), list) else []
):
for element in elements:
parts.extend(_extract_element_content(element))
@ -260,6 +261,7 @@ _STREAM_ELEMENT_ID = "streaming_md"
@dataclass
class _FeishuStreamBuf:
"""Per-chat streaming accumulator using CardKit streaming API."""
text: str = ""
card_id: str | None = None
sequence: int = 0
@ -288,16 +290,19 @@ class FeishuChannel(BaseChannel):
return FeishuConfig().model_dump(by_alias=True)
def __init__(self, config: Any, bus: MessageBus):
import lark_oapi as lark
if isinstance(config, dict):
config = FeishuConfig.model_validate(config)
super().__init__(config, bus)
self.config: FeishuConfig = config
self._client: Any = None
self._client: lark.Client | None = None
self._ws_client: Any = None
self._ws_thread: threading.Thread | None = None
self._processed_message_ids: OrderedDict[str, None] = OrderedDict() # Ordered dedup cache
self._loop: asyncio.AbstractEventLoop | None = None
self._stream_bufs: dict[str, _FeishuStreamBuf] = {}
self._bot_open_id: str | None = None
@staticmethod
def _register_optional_event(builder: Any, method_name: str, handler: Any) -> Any:
@ -316,24 +321,28 @@ class FeishuChannel(BaseChannel):
return
import lark_oapi as lark
self._running = True
self._loop = asyncio.get_running_loop()
# Create Lark client for sending messages
self._client = lark.Client.builder() \
.app_id(self.config.app_id) \
.app_secret(self.config.app_secret) \
.log_level(lark.LogLevel.INFO) \
self._client = (
lark.Client.builder()
.app_id(self.config.app_id)
.app_secret(self.config.app_secret)
.log_level(lark.LogLevel.INFO)
.build()
)
builder = lark.EventDispatcherHandler.builder(
self.config.encrypt_key or "",
self.config.verification_token or "",
).register_p2_im_message_receive_v1(
self._on_message_sync
)
).register_p2_im_message_receive_v1(self._on_message_sync)
builder = self._register_optional_event(
builder, "register_p2_im_message_reaction_created_v1", self._on_reaction_created
)
builder = self._register_optional_event(
builder, "register_p2_im_message_reaction_deleted_v1", self._on_reaction_deleted
)
builder = self._register_optional_event(
builder, "register_p2_im_message_message_read_v1", self._on_message_read
)
@ -349,7 +358,7 @@ class FeishuChannel(BaseChannel):
self.config.app_id,
self.config.app_secret,
event_handler=event_handler,
log_level=lark.LogLevel.INFO
log_level=lark.LogLevel.INFO,
)
# Start WebSocket client in a separate thread with reconnect loop.
@ -359,7 +368,9 @@ class FeishuChannel(BaseChannel):
# "This event loop is already running" errors.
def run_ws():
import time
import lark_oapi.ws.client as _lark_ws_client
ws_loop = asyncio.new_event_loop()
asyncio.set_event_loop(ws_loop)
# Patch the module-level loop used by lark's ws Client.start()
@ -378,6 +389,15 @@ class FeishuChannel(BaseChannel):
self._ws_thread = threading.Thread(target=run_ws, daemon=True)
self._ws_thread.start()
# Fetch bot's own open_id for accurate @mention matching
self._bot_open_id = await asyncio.get_running_loop().run_in_executor(
None, self._fetch_bot_open_id
)
if self._bot_open_id:
logger.info("Feishu bot open_id: {}", self._bot_open_id)
else:
logger.warning("Could not fetch bot open_id; @mention matching may be inaccurate")
logger.info("Feishu bot started with WebSocket long connection")
logger.info("No public IP required - using WebSocket to receive events")
@ -396,6 +416,70 @@ class FeishuChannel(BaseChannel):
self._running = False
logger.info("Feishu bot stopped")
def _fetch_bot_open_id(self) -> str | None:
"""Fetch the bot's own open_id via GET /open-apis/bot/v3/info."""
try:
import lark_oapi as lark
request = (
lark.BaseRequest.builder()
.http_method(lark.HttpMethod.GET)
.uri("/open-apis/bot/v3/info")
.token_types({lark.AccessTokenType.APP})
.build()
)
response = self._client.request(request)
if response.success():
import json
data = json.loads(response.raw.content)
bot = (data.get("data") or data).get("bot") or data.get("bot") or {}
return bot.get("open_id")
logger.warning("Failed to get bot info: code={}, msg={}", response.code, response.msg)
return None
except Exception as e:
logger.warning("Error fetching bot info: {}", e)
return None
@staticmethod
def _resolve_mentions(text: str, mentions: list[MentionEvent] | None) -> str:
"""Replace @_user_n placeholders with actual user info from mentions.
Args:
text: The message text containing @_user_n placeholders
mentions: List of mention objects from Feishu message
Returns:
Text with placeholders replaced by @name (open_id)
"""
if not mentions or not text:
return text
for mention in mentions:
key = mention.key or None
if not key or key not in text:
continue
user_id_obj = mention.id or None
if not user_id_obj:
continue
open_id = user_id_obj.open_id
user_id = user_id_obj.user_id
name = mention.name or key
# Format: @name (open_id, user_id: xxx)
if open_id and user_id:
replacement = f"@{name} ({open_id}, user id: {user_id})"
elif open_id:
replacement = f"@{name} ({open_id})"
else:
replacement = f"@{name}"
text = text.replace(key, replacement)
return text
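A toy sketch of the placeholder substitution performed by `_resolve_mentions`, using made-up mention data in place of Feishu's `MentionEvent` objects:

```python
# Illustration-only data: Feishu delivers @_user_n keys plus mention objects
# carrying name, open_id, and (for human users) user_id.
text = "hello @_user_1, please review"
mentions = {"@_user_1": ("Alice", "ou_abc123", "u_001")}

for key, (name, open_id, user_id) in mentions.items():
    if key in text:
        # Same target format as above: @Name (open_id, user id: xxx)
        text = text.replace(key, f"@{name} ({open_id}, user id: {user_id})")

print(text)  # → hello @Alice (ou_abc123, user id: u_001), please review
```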
def _is_bot_mentioned(self, message: Any) -> bool:
"""Check if the bot is @mentioned in the message."""
raw_content = message.content or ""
@ -406,9 +490,14 @@ class FeishuChannel(BaseChannel):
mid = getattr(mention, "id", None)
if not mid:
continue
# Bot mentions have no user_id (None or "") but a valid open_id
if not getattr(mid, "user_id", None) and (getattr(mid, "open_id", None) or "").startswith("ou_"):
return True
mention_open_id = getattr(mid, "open_id", None) or ""
if self._bot_open_id:
if mention_open_id == self._bot_open_id:
return True
else:
# Fallback heuristic when bot open_id is unavailable
if not getattr(mid, "user_id", None) and mention_open_id.startswith("ou_"):
return True
return False
def _is_group_message_for_bot(self, message: Any) -> bool:
@ -419,20 +508,30 @@ class FeishuChannel(BaseChannel):
def _add_reaction_sync(self, message_id: str, emoji_type: str) -> str | None:
"""Sync helper for adding reaction (runs in thread pool)."""
from lark_oapi.api.im.v1 import CreateMessageReactionRequest, CreateMessageReactionRequestBody, Emoji
from lark_oapi.api.im.v1 import (
CreateMessageReactionRequest,
CreateMessageReactionRequestBody,
Emoji,
)
try:
request = CreateMessageReactionRequest.builder() \
.message_id(message_id) \
request = (
CreateMessageReactionRequest.builder()
.message_id(message_id)
.request_body(
CreateMessageReactionRequestBody.builder()
.reaction_type(Emoji.builder().emoji_type(emoji_type).build())
.build()
).build()
)
.build()
)
response = self._client.im.v1.message_reaction.create(request)
if not response.success():
logger.warning("Failed to add reaction: code={}, msg={}", response.code, response.msg)
logger.warning(
"Failed to add reaction: code={}, msg={}", response.code, response.msg
)
return None
else:
logger.debug("Added {} reaction to message {}", emoji_type, message_id)
@ -456,17 +555,22 @@ class FeishuChannel(BaseChannel):
def _remove_reaction_sync(self, message_id: str, reaction_id: str) -> None:
"""Sync helper for removing reaction (runs in thread pool)."""
from lark_oapi.api.im.v1 import DeleteMessageReactionRequest
try:
request = DeleteMessageReactionRequest.builder() \
.message_id(message_id) \
.reaction_id(reaction_id) \
request = (
DeleteMessageReactionRequest.builder()
.message_id(message_id)
.reaction_id(reaction_id)
.build()
)
response = self._client.im.v1.message_reaction.delete(request)
if response.success():
logger.debug("Removed reaction {} from message {}", reaction_id, message_id)
else:
logger.debug("Failed to remove reaction: code={}, msg={}", response.code, response.msg)
logger.debug(
"Failed to remove reaction: code={}, msg={}", response.code, response.msg
)
except Exception as e:
logger.debug("Error removing reaction: {}", e)
@ -521,27 +625,35 @@ class FeishuChannel(BaseChannel):
lines = [_line.strip() for _line in table_text.strip().split("\n") if _line.strip()]
if len(lines) < 3:
return None
def split(_line: str) -> list[str]:
return [c.strip() for c in _line.strip("|").split("|")]
headers = [cls._strip_md_formatting(h) for h in split(lines[0])]
rows = [[cls._strip_md_formatting(c) for c in split(_line)] for _line in lines[2:]]
columns = [{"tag": "column", "name": f"c{i}", "display_name": h, "width": "auto"}
for i, h in enumerate(headers)]
columns = [
{"tag": "column", "name": f"c{i}", "display_name": h, "width": "auto"}
for i, h in enumerate(headers)
]
return {
"tag": "table",
"page_size": len(rows) + 1,
"columns": columns,
"rows": [{f"c{i}": r[i] if i < len(r) else "" for i in range(len(headers))} for r in rows],
"rows": [
{f"c{i}": r[i] if i < len(r) else "" for i in range(len(headers))} for r in rows
],
}
def _build_card_elements(self, content: str) -> list[dict]:
"""Split content into div/markdown + table elements for Feishu card."""
elements, last_end = [], 0
for m in self._TABLE_RE.finditer(content):
before = content[last_end:m.start()]
before = content[last_end : m.start()]
if before.strip():
elements.extend(self._split_headings(before))
elements.append(self._parse_md_table(m.group(1)) or {"tag": "markdown", "content": m.group(1)})
elements.append(
self._parse_md_table(m.group(1)) or {"tag": "markdown", "content": m.group(1)}
)
last_end = m.end()
remaining = content[last_end:]
if remaining.strip():
@ -549,7 +661,9 @@ class FeishuChannel(BaseChannel):
return elements or [{"tag": "markdown", "content": content}]
@staticmethod
def _split_elements_by_table_limit(elements: list[dict], max_tables: int = 1) -> list[list[dict]]:
def _split_elements_by_table_limit(
elements: list[dict], max_tables: int = 1
) -> list[list[dict]]:
"""Split card elements into groups with at most *max_tables* table elements each.
Feishu cards have a hard limit of one table per card (API error 11310).
@ -582,23 +696,25 @@ class FeishuChannel(BaseChannel):
code_blocks = []
for m in self._CODE_BLOCK_RE.finditer(content):
code_blocks.append(m.group(1))
protected = protected.replace(m.group(1), f"\x00CODE{len(code_blocks)-1}\x00", 1)
protected = protected.replace(m.group(1), f"\x00CODE{len(code_blocks) - 1}\x00", 1)
elements = []
last_end = 0
for m in self._HEADING_RE.finditer(protected):
before = protected[last_end:m.start()].strip()
before = protected[last_end : m.start()].strip()
if before:
elements.append({"tag": "markdown", "content": before})
text = self._strip_md_formatting(m.group(2).strip())
display_text = f"**{text}**" if text else ""
elements.append({
"tag": "div",
"text": {
"tag": "lark_md",
"content": display_text,
},
})
elements.append(
{
"tag": "div",
"text": {
"tag": "lark_md",
"content": display_text,
},
}
)
last_end = m.end()
remaining = protected[last_end:].strip()
if remaining:
@ -614,19 +730,19 @@ class FeishuChannel(BaseChannel):
# ── Smart format detection ──────────────────────────────────────────
# Patterns that indicate "complex" markdown needing card rendering
_COMPLEX_MD_RE = re.compile(
r"```" # fenced code block
r"```" # fenced code block
r"|^\|.+\|.*\n\s*\|[-:\s|]+\|" # markdown table (header + separator)
r"|^#{1,6}\s+" # headings
, re.MULTILINE,
r"|^#{1,6}\s+", # headings
re.MULTILINE,
)
# Simple markdown patterns (bold, italic, strikethrough)
_SIMPLE_MD_RE = re.compile(
r"\*\*.+?\*\*" # **bold**
r"|__.+?__" # __bold__
r"\*\*.+?\*\*" # **bold**
r"|__.+?__" # __bold__
r"|(?<!\*)\*(?!\*)(.+?)(?<!\*)\*(?!\*)" # *italic* (single *)
r"|~~.+?~~" # ~~strikethrough~~
, re.DOTALL,
r"|~~.+?~~", # ~~strikethrough~~
re.DOTALL,
)
# Markdown link: [text](url)
@ -698,14 +814,16 @@ class FeishuChannel(BaseChannel):
for m in cls._MD_LINK_RE.finditer(line):
# Text before this link
before = line[last_end:m.start()]
before = line[last_end : m.start()]
if before:
elements.append({"tag": "text", "text": before})
elements.append({
"tag": "a",
"text": m.group(1),
"href": m.group(2),
})
elements.append(
{
"tag": "a",
"text": m.group(1),
"href": m.group(2),
}
)
last_end = m.end()
# Remaining text after last link
@ -730,29 +848,39 @@ class FeishuChannel(BaseChannel):
_AUDIO_EXTS = {".opus"}
_VIDEO_EXTS = {".mp4", ".mov", ".avi"}
_FILE_TYPE_MAP = {
".opus": "opus", ".mp4": "mp4", ".pdf": "pdf", ".doc": "doc", ".docx": "doc",
".xls": "xls", ".xlsx": "xls", ".ppt": "ppt", ".pptx": "ppt",
".opus": "opus",
".mp4": "mp4",
".pdf": "pdf",
".doc": "doc",
".docx": "doc",
".xls": "xls",
".xlsx": "xls",
".ppt": "ppt",
".pptx": "ppt",
}
def _upload_image_sync(self, file_path: str) -> str | None:
"""Upload an image to Feishu and return the image_key."""
from lark_oapi.api.im.v1 import CreateImageRequest, CreateImageRequestBody
try:
with open(file_path, "rb") as f:
request = CreateImageRequest.builder() \
request = (
CreateImageRequest.builder()
.request_body(
CreateImageRequestBody.builder()
.image_type("message")
.image(f)
.build()
).build()
CreateImageRequestBody.builder().image_type("message").image(f).build()
)
.build()
)
response = self._client.im.v1.image.create(request)
if response.success():
image_key = response.data.image_key
logger.debug("Uploaded image {}: {}", os.path.basename(file_path), image_key)
return image_key
else:
logger.error("Failed to upload image: code={}, msg={}", response.code, response.msg)
logger.error(
"Failed to upload image: code={}, msg={}", response.code, response.msg
)
return None
except Exception as e:
logger.error("Error uploading image {}: {}", file_path, e)
@ -761,49 +889,62 @@ class FeishuChannel(BaseChannel):
def _upload_file_sync(self, file_path: str) -> str | None:
"""Upload a file to Feishu and return the file_key."""
from lark_oapi.api.im.v1 import CreateFileRequest, CreateFileRequestBody
ext = os.path.splitext(file_path)[1].lower()
file_type = self._FILE_TYPE_MAP.get(ext, "stream")
file_name = os.path.basename(file_path)
try:
with open(file_path, "rb") as f:
request = CreateFileRequest.builder() \
request = (
CreateFileRequest.builder()
.request_body(
CreateFileRequestBody.builder()
.file_type(file_type)
.file_name(file_name)
.file(f)
.build()
).build()
)
.build()
)
response = self._client.im.v1.file.create(request)
if response.success():
file_key = response.data.file_key
logger.debug("Uploaded file {}: {}", file_name, file_key)
return file_key
else:
logger.error("Failed to upload file: code={}, msg={}", response.code, response.msg)
logger.error(
"Failed to upload file: code={}, msg={}", response.code, response.msg
)
return None
except Exception as e:
logger.error("Error uploading file {}: {}", file_path, e)
return None
def _download_image_sync(self, message_id: str, image_key: str) -> tuple[bytes | None, str | None]:
def _download_image_sync(
self, message_id: str, image_key: str
) -> tuple[bytes | None, str | None]:
"""Download an image from Feishu message by message_id and image_key."""
from lark_oapi.api.im.v1 import GetMessageResourceRequest
try:
request = GetMessageResourceRequest.builder() \
.message_id(message_id) \
.file_key(image_key) \
.type("image") \
request = (
GetMessageResourceRequest.builder()
.message_id(message_id)
.file_key(image_key)
.type("image")
.build()
)
response = self._client.im.v1.message_resource.get(request)
if response.success():
file_data = response.file
# GetMessageResourceRequest returns BytesIO, need to read bytes
if hasattr(file_data, 'read'):
if hasattr(file_data, "read"):
file_data = file_data.read()
return file_data, response.file_name
else:
logger.error("Failed to download image: code={}, msg={}", response.code, response.msg)
logger.error(
"Failed to download image: code={}, msg={}", response.code, response.msg
)
return None, None
except Exception as e:
logger.error("Error downloading image {}: {}", image_key, e)
@ -835,17 +976,19 @@ class FeishuChannel(BaseChannel):
file_data = file_data.read()
return file_data, response.file_name
else:
logger.error("Failed to download {}: code={}, msg={}", resource_type, response.code, response.msg)
logger.error(
"Failed to download {}: code={}, msg={}",
resource_type,
response.code,
response.msg,
)
return None, None
except Exception:
logger.exception("Error downloading {} {}", resource_type, file_key)
return None, None
async def _download_and_save_media(
self,
msg_type: str,
content_json: dict,
message_id: str | None = None
self, msg_type: str, content_json: dict, message_id: str | None = None
) -> tuple[str | None, str]:
"""
Download media from Feishu and save to local disk.
@ -894,13 +1037,16 @@ class FeishuChannel(BaseChannel):
Returns a "[Reply to: ...]" context string, or None on failure.
"""
from lark_oapi.api.im.v1 import GetMessageRequest
try:
request = GetMessageRequest.builder().message_id(message_id).build()
response = self._client.im.v1.message.get(request)
if not response.success():
logger.debug(
"Feishu: could not fetch parent message {}: code={}, msg={}",
message_id, response.code, response.msg,
message_id,
response.code,
response.msg,
)
return None
items = getattr(response.data, "items", None)
@ -935,20 +1081,24 @@ class FeishuChannel(BaseChannel):
def _reply_message_sync(self, parent_message_id: str, msg_type: str, content: str) -> bool:
"""Reply to an existing Feishu message using the Reply API (synchronous)."""
from lark_oapi.api.im.v1 import ReplyMessageRequest, ReplyMessageRequestBody
try:
request = ReplyMessageRequest.builder() \
.message_id(parent_message_id) \
request = (
ReplyMessageRequest.builder()
.message_id(parent_message_id)
.request_body(
ReplyMessageRequestBody.builder()
.msg_type(msg_type)
.content(content)
.build()
).build()
ReplyMessageRequestBody.builder().msg_type(msg_type).content(content).build()
)
.build()
)
response = self._client.im.v1.message.reply(request)
if not response.success():
logger.error(
"Failed to reply to Feishu message {}: code={}, msg={}, log_id={}",
parent_message_id, response.code, response.msg, response.get_log_id()
parent_message_id,
response.code,
response.msg,
response.get_log_id(),
)
return False
logger.debug("Feishu reply sent to message {}", parent_message_id)
@ -957,24 +1107,33 @@ class FeishuChannel(BaseChannel):
logger.error("Error replying to Feishu message {}: {}", parent_message_id, e)
return False
def _send_message_sync(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> str | None:
def _send_message_sync(
self, receive_id_type: str, receive_id: str, msg_type: str, content: str
) -> str | None:
"""Send a single message and return the message_id on success."""
from lark_oapi.api.im.v1 import CreateMessageRequest, CreateMessageRequestBody
try:
request = CreateMessageRequest.builder() \
.receive_id_type(receive_id_type) \
request = (
CreateMessageRequest.builder()
.receive_id_type(receive_id_type)
.request_body(
CreateMessageRequestBody.builder()
.receive_id(receive_id)
.msg_type(msg_type)
.content(content)
.build()
).build()
)
.build()
)
response = self._client.im.v1.message.create(request)
if not response.success():
logger.error(
"Failed to send Feishu {} message: code={}, msg={}, log_id={}",
msg_type, response.code, response.msg, response.get_log_id()
msg_type,
response.code,
response.msg,
response.get_log_id(),
)
return None
msg_id = getattr(response.data, "message_id", None)
@ -987,31 +1146,44 @@ class FeishuChannel(BaseChannel):
def _create_streaming_card_sync(self, receive_id_type: str, chat_id: str) -> str | None:
"""Create a CardKit streaming card, send it to chat, return card_id."""
from lark_oapi.api.cardkit.v1 import CreateCardRequest, CreateCardRequestBody
card_json = {
"schema": "2.0",
"config": {"wide_screen_mode": True, "update_multi": True, "streaming_mode": True},
"body": {"elements": [{"tag": "markdown", "content": "", "element_id": _STREAM_ELEMENT_ID}]},
"body": {
"elements": [{"tag": "markdown", "content": "", "element_id": _STREAM_ELEMENT_ID}]
},
}
try:
request = CreateCardRequest.builder().request_body(
CreateCardRequestBody.builder()
.type("card_json")
.data(json.dumps(card_json, ensure_ascii=False))
request = (
CreateCardRequest.builder()
.request_body(
CreateCardRequestBody.builder()
.type("card_json")
.data(json.dumps(card_json, ensure_ascii=False))
.build()
)
.build()
).build()
)
response = self._client.cardkit.v1.card.create(request)
if not response.success():
logger.warning("Failed to create streaming card: code={}, msg={}", response.code, response.msg)
logger.warning(
"Failed to create streaming card: code={}, msg={}", response.code, response.msg
)
return None
card_id = getattr(response.data, "card_id", None)
if card_id:
message_id = self._send_message_sync(
receive_id_type, chat_id, "interactive",
receive_id_type,
chat_id,
"interactive",
json.dumps({"type": "card", "data": {"card_id": card_id}}),
)
if message_id:
return card_id
logger.warning("Created streaming card {} but failed to send it to {}", card_id, chat_id)
logger.warning(
"Created streaming card {} but failed to send it to {}", card_id, chat_id
)
return None
except Exception as e:
logger.warning("Error creating streaming card: {}", e)
@ -1019,18 +1191,32 @@ class FeishuChannel(BaseChannel):
def _stream_update_text_sync(self, card_id: str, content: str, sequence: int) -> bool:
"""Stream-update the markdown element on a CardKit card (typewriter effect)."""
from lark_oapi.api.cardkit.v1 import ContentCardElementRequest, ContentCardElementRequestBody
from lark_oapi.api.cardkit.v1 import (
ContentCardElementRequest,
ContentCardElementRequestBody,
)
try:
request = ContentCardElementRequest.builder() \
.card_id(card_id) \
.element_id(_STREAM_ELEMENT_ID) \
request = (
ContentCardElementRequest.builder()
.card_id(card_id)
.element_id(_STREAM_ELEMENT_ID)
.request_body(
ContentCardElementRequestBody.builder()
.content(content).sequence(sequence).build()
).build()
.content(content)
.sequence(sequence)
.build()
)
.build()
)
response = self._client.cardkit.v1.card_element.content(request)
if not response.success():
logger.warning("Failed to stream-update card {}: code={}, msg={}", card_id, response.code, response.msg)
logger.warning(
"Failed to stream-update card {}: code={}, msg={}",
card_id,
response.code,
response.msg,
)
return False
return True
except Exception as e:
@ -1045,22 +1231,28 @@ class FeishuChannel(BaseChannel):
Sequence must strictly exceed that of the previous card OpenAPI operation on this entity.
"""
from lark_oapi.api.cardkit.v1 import SettingsCardRequest, SettingsCardRequestBody
settings_payload = json.dumps({"config": {"streaming_mode": False}}, ensure_ascii=False)
try:
request = SettingsCardRequest.builder() \
.card_id(card_id) \
request = (
SettingsCardRequest.builder()
.card_id(card_id)
.request_body(
SettingsCardRequestBody.builder()
.settings(settings_payload)
.sequence(sequence)
.uuid(str(uuid.uuid4()))
.build()
).build()
)
.build()
)
response = self._client.cardkit.v1.card.settings(request)
if not response.success():
logger.warning(
"Failed to close streaming on card {}: code={}, msg={}",
card_id, response.code, response.msg,
card_id,
response.code,
response.msg,
)
return False
return True
@ -1068,7 +1260,9 @@ class FeishuChannel(BaseChannel):
logger.warning("Error closing streaming on card {}: {}", card_id, e)
return False
async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
async def send_delta(
self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None
) -> None:
"""Progressive streaming via CardKit: create card on first delta, stream-update on subsequent."""
if not self._client:
return
@ -1087,17 +1281,31 @@ class FeishuChannel(BaseChannel):
if buf.card_id:
buf.sequence += 1
await loop.run_in_executor(
None, self._stream_update_text_sync, buf.card_id, buf.text, buf.sequence,
None,
self._stream_update_text_sync,
buf.card_id,
buf.text,
buf.sequence,
)
# Required so the chat list preview exits the streaming placeholder (Feishu streaming card docs).
buf.sequence += 1
await loop.run_in_executor(
None, self._close_streaming_mode_sync, buf.card_id, buf.sequence,
None,
self._close_streaming_mode_sync,
buf.card_id,
buf.sequence,
)
else:
for chunk in self._split_elements_by_table_limit(self._build_card_elements(buf.text)):
card = json.dumps({"config": {"wide_screen_mode": True}, "elements": chunk}, ensure_ascii=False)
await loop.run_in_executor(None, self._send_message_sync, rid_type, chat_id, "interactive", card)
for chunk in self._split_elements_by_table_limit(
self._build_card_elements(buf.text)
):
card = json.dumps(
{"config": {"wide_screen_mode": True}, "elements": chunk},
ensure_ascii=False,
)
await loop.run_in_executor(
None, self._send_message_sync, rid_type, chat_id, "interactive", card
)
return
# --- accumulate delta ---
@ -1111,15 +1319,21 @@ class FeishuChannel(BaseChannel):
now = time.monotonic()
if buf.card_id is None:
card_id = await loop.run_in_executor(None, self._create_streaming_card_sync, rid_type, chat_id)
card_id = await loop.run_in_executor(
None, self._create_streaming_card_sync, rid_type, chat_id
)
if card_id:
buf.card_id = card_id
buf.sequence = 1
await loop.run_in_executor(None, self._stream_update_text_sync, card_id, buf.text, 1)
await loop.run_in_executor(
None, self._stream_update_text_sync, card_id, buf.text, 1
)
buf.last_edit = now
elif (now - buf.last_edit) >= self._STREAM_EDIT_INTERVAL:
buf.sequence += 1
await loop.run_in_executor(None, self._stream_update_text_sync, buf.card_id, buf.text, buf.sequence)
await loop.run_in_executor(
None, self._stream_update_text_sync, buf.card_id, buf.text, buf.sequence
)
buf.last_edit = now
async def send(self, msg: OutboundMessage) -> None:
@ -1145,14 +1359,13 @@ class FeishuChannel(BaseChannel):
# Only the very first send (media or text) in this call uses reply; subsequent
# chunks/media fall back to plain create to avoid redundant quote bubbles.
reply_message_id: str | None = None
if (
self.config.reply_to_message
and not msg.metadata.get("_progress", False)
):
if self.config.reply_to_message and not msg.metadata.get("_progress", False):
reply_message_id = msg.metadata.get("message_id") or None
# For topic group messages, always reply to keep context in thread
elif msg.metadata.get("thread_id"):
reply_message_id = msg.metadata.get("root_id") or msg.metadata.get("message_id") or None
reply_message_id = (
msg.metadata.get("root_id") or msg.metadata.get("message_id") or None
)
first_send = True # tracks whether the reply has already been used
@ -1176,8 +1389,10 @@ class FeishuChannel(BaseChannel):
key = await loop.run_in_executor(None, self._upload_image_sync, file_path)
if key:
await loop.run_in_executor(
None, _do_send,
"image", json.dumps({"image_key": key}, ensure_ascii=False),
None,
_do_send,
"image",
json.dumps({"image_key": key}, ensure_ascii=False),
)
else:
key = await loop.run_in_executor(None, self._upload_file_sync, file_path)
@ -1192,8 +1407,10 @@ class FeishuChannel(BaseChannel):
else:
media_type = "file"
await loop.run_in_executor(
None, _do_send,
media_type, json.dumps({"file_key": key}, ensure_ascii=False),
None,
_do_send,
media_type,
json.dumps({"file_key": key}, ensure_ascii=False),
)
if msg.content and msg.content.strip():
@ -1215,8 +1432,10 @@ class FeishuChannel(BaseChannel):
for chunk in self._split_elements_by_table_limit(elements):
card = {"config": {"wide_screen_mode": True}, "elements": chunk}
await loop.run_in_executor(
None, _do_send,
"interactive", json.dumps(card, ensure_ascii=False),
None,
_do_send,
"interactive",
json.dumps(card, ensure_ascii=False),
)
except Exception as e:
@ -1231,13 +1450,16 @@ class FeishuChannel(BaseChannel):
if self._loop and self._loop.is_running():
asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop)
async def _on_message(self, data: Any) -> None:
async def _on_message(self, data: P2ImMessageReceiveV1) -> None:
"""Handle incoming message from Feishu."""
try:
event = data.event
message = event.message
sender = event.sender
logger.debug("Feishu raw message: {}", message.content)
logger.debug("Feishu mentions: {}", getattr(message, "mentions", None))
# Deduplication check
message_id = message.message_id
if message_id in self._processed_message_ids:
@ -1276,6 +1498,8 @@ class FeishuChannel(BaseChannel):
if msg_type == "text":
text = content_json.get("text", "")
if text:
mentions = getattr(message, "mentions", None)
text = self._resolve_mentions(text, mentions)
content_parts.append(text)
elif msg_type == "post":
@ -1292,7 +1516,9 @@ class FeishuChannel(BaseChannel):
content_parts.append(content_text)
elif msg_type in ("image", "audio", "file", "media"):
file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id)
file_path, content_text = await self._download_and_save_media(
msg_type, content_json, message_id
)
if file_path:
media_paths.append(file_path)
@ -1303,7 +1529,14 @@ class FeishuChannel(BaseChannel):
content_parts.append(content_text)
elif msg_type in ("share_chat", "share_user", "interactive", "share_calendar_event", "system", "merge_forward"):
elif msg_type in (
"share_chat",
"share_user",
"interactive",
"share_calendar_event",
"system",
"merge_forward",
):
# Handle share cards and interactive messages
text = _extract_share_card_content(content_json, msg_type)
if text:
@ -1346,7 +1579,7 @@ class FeishuChannel(BaseChannel):
"parent_id": parent_id,
"root_id": root_id,
"thread_id": thread_id,
}
},
)
except Exception as e:
@ -1356,6 +1589,10 @@ class FeishuChannel(BaseChannel):
"""Ignore reaction events so they do not generate SDK noise."""
pass
def _on_reaction_deleted(self, data: Any) -> None:
"""Ignore reaction deleted events so they do not generate SDK noise."""
pass
def _on_message_read(self, data: Any) -> None:
"""Ignore read events so they do not generate SDK noise."""
pass
@ -1411,7 +1648,9 @@ class FeishuChannel(BaseChannel):
return "\n".join(part for part in parts if part)
async def _send_tool_hint_card(self, receive_id_type: str, receive_id: str, tool_hint: str) -> None:
async def _send_tool_hint_card(
self, receive_id_type: str, receive_id: str, tool_hint: str
) -> None:
"""Send tool hint as an interactive card with formatted code block.
Args:
@ -1427,15 +1666,15 @@ class FeishuChannel(BaseChannel):
card = {
"config": {"wide_screen_mode": True},
"elements": [
{
"tag": "markdown",
"content": f"**Tool Calls**\n\n```text\n{formatted_code}\n```"
}
]
{"tag": "markdown", "content": f"**Tool Calls**\n\n```text\n{formatted_code}\n```"}
],
}
await loop.run_in_executor(
None, self._send_message_sync,
receive_id_type, receive_id, "interactive",
None,
self._send_message_sync,
receive_id_type,
receive_id,
"interactive",
json.dumps(card, ensure_ascii=False),
)



@ -39,7 +39,8 @@ class ChannelManager:
"""Initialize channels discovered via pkgutil scan + entry_points plugins."""
from nanobot.channels.registry import discover_all
groq_key = self.config.providers.groq.api_key
transcription_provider = self.config.channels.transcription_provider
transcription_key = self._resolve_transcription_key(transcription_provider)
for name, cls in discover_all().items():
section = getattr(self.config.channels, name, None)
@ -54,7 +55,8 @@ class ChannelManager:
continue
try:
channel = cls(section, self.bus)
channel.transcription_api_key = groq_key
channel.transcription_provider = transcription_provider
channel.transcription_api_key = transcription_key
self.channels[name] = channel
logger.info("{} channel enabled", cls.display_name)
except Exception as e:
@ -62,6 +64,15 @@ class ChannelManager:
self._validate_allow_from()
def _resolve_transcription_key(self, provider: str) -> str:
"""Pick the API key for the configured transcription provider."""
try:
if provider == "openai":
return self.config.providers.openai.api_key
return self.config.providers.groq.api_key
except AttributeError:
return ""
def _validate_allow_from(self) -> None:
for name, ch in self.channels.items():
if getattr(ch.config, "allow_from", None) == []:


@ -1,6 +1,7 @@
"""Matrix (Element) channel — inbound sync + outbound message/media delivery."""
import asyncio
import json
import logging
import mimetypes
import time
@ -17,10 +18,10 @@ try:
from nio import (
AsyncClient,
AsyncClientConfig,
ContentRepositoryConfigError,
DownloadError,
InviteEvent,
JoinError,
LoginResponse,
MatrixRoom,
MemoryDownloadResponse,
RoomEncryptedMedia,
@ -203,10 +204,11 @@ class MatrixConfig(Base):
enabled: bool = False
homeserver: str = "https://matrix.org"
access_token: str = ""
user_id: str = ""
password: str = ""
access_token: str = ""
device_id: str = ""
e2ee_enabled: bool = True
e2ee_enabled: bool = Field(default=True, alias="e2eeEnabled")
sync_stop_grace_seconds: int = 2
max_media_bytes: int = 20 * 1024 * 1024
allow_from: list[str] = Field(default_factory=list)
@ -256,17 +258,15 @@ class MatrixChannel(BaseChannel):
self._running = True
_configure_nio_logging_bridge()
store_path = get_data_dir() / "matrix-store"
store_path.mkdir(parents=True, exist_ok=True)
self.store_path = get_data_dir() / "matrix-store"
self.store_path.mkdir(parents=True, exist_ok=True)
self.session_path = self.store_path / "session.json"
self.client = AsyncClient(
homeserver=self.config.homeserver, user=self.config.user_id,
store_path=store_path,
store_path=self.store_path,
config=AsyncClientConfig(store_sync_tokens=True, encryption_enabled=self.config.e2ee_enabled),
)
self.client.user_id = self.config.user_id
self.client.access_token = self.config.access_token
self.client.device_id = self.config.device_id
self._register_event_callbacks()
self._register_response_callbacks()
@ -274,13 +274,49 @@ class MatrixChannel(BaseChannel):
if not self.config.e2ee_enabled:
logger.warning("Matrix E2EE disabled; encrypted rooms may be undecryptable.")
if self.config.device_id:
if self.config.password:
if self.config.access_token or self.config.device_id:
logger.warning("Password-based Matrix login active; access_token and device_id fields will be ignored.")
create_new_session = True
if self.session_path.exists():
logger.info("Found session.json at {}; attempting to use existing session...", self.session_path)
try:
with open(self.session_path, "r", encoding="utf-8") as f:
session = json.load(f)
self.client.user_id = self.config.user_id
self.client.access_token = session["access_token"]
self.client.device_id = session["device_id"]
self.client.load_store()
logger.info("Successfully loaded from existing session")
create_new_session = False
except Exception as e:
logger.warning("Failed to load from existing session: {}", e)
logger.info("Falling back to password login...")
if create_new_session:
logger.info("Using password login...")
resp = await self.client.login(self.config.password)
if isinstance(resp, LoginResponse):
logger.info("Logged in using a password; saving details to disk")
self._write_session_to_disk(resp)
else:
logger.error("Failed to log in: {}", resp)
return
elif self.config.access_token and self.config.device_id:
try:
self.client.user_id = self.config.user_id
self.client.access_token = self.config.access_token
self.client.device_id = self.config.device_id
self.client.load_store()
except Exception:
logger.exception("Matrix store load failed; restart may replay recent messages.")
logger.info("Successfully loaded from existing session")
except Exception as e:
logger.warning("Failed to load from existing session: {}", e)
else:
logger.warning("Matrix device_id empty; restart may replay recent messages.")
logger.warning("Unable to load a Matrix session due to missing password, access_token, or device_id; encryption may not work")
return
self._sync_task = asyncio.create_task(self._sync_loop())
@ -304,6 +340,19 @@ class MatrixChannel(BaseChannel):
if self.client:
await self.client.close()
def _write_session_to_disk(self, resp: LoginResponse) -> None:
"""Save login session to disk for persistence across restarts."""
session = {
"access_token": resp.access_token,
"device_id": resp.device_id,
}
try:
with open(self.session_path, "w", encoding="utf-8") as f:
json.dump(session, f, indent=2)
logger.info("Session saved to {}", self.session_path)
except Exception as e:
logger.warning("Failed to save session: {}", e)
def _is_workspace_path_allowed(self, path: Path) -> bool:
"""Check path is inside workspace (when restriction enabled)."""
if not self._restrict_to_workspace or not self._workspace:


@ -6,14 +6,14 @@ import asyncio
import re
import time
import unicodedata
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import Any, Literal
from loguru import logger
from pydantic import Field
from telegram import BotCommand, ReactionTypeEmoji, ReplyParameters, Update
from telegram.error import BadRequest, NetworkError, TimedOut
from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
from telegram.ext import Application, ContextTypes, MessageHandler, filters
from telegram.request import HTTPXRequest
from nanobot.bus.events import OutboundMessage
@ -558,8 +558,10 @@ class TelegramChannel(BaseChannel):
await self._remove_reaction(chat_id, int(reply_to_message_id))
except ValueError:
pass
chunks = split_message(buf.text, TELEGRAM_MAX_MESSAGE_LEN)
primary_text = chunks[0] if chunks else buf.text
try:
html = _markdown_to_telegram_html(buf.text)
html = _markdown_to_telegram_html(primary_text)
await self._call_with_retry(
self._app.bot.edit_message_text,
chat_id=int_chat_id, message_id=buf.message_id,
@ -575,15 +577,18 @@ class TelegramChannel(BaseChannel):
await self._call_with_retry(
self._app.bot.edit_message_text,
chat_id=int_chat_id, message_id=buf.message_id,
text=buf.text,
text=primary_text,
)
except Exception as e2:
if self._is_not_modified_error(e2):
logger.debug("Final stream plain edit already applied for {}", chat_id)
self._stream_bufs.pop(chat_id, None)
return
logger.warning("Final stream edit failed: {}", e2)
raise # Let ChannelManager handle retry
else:
logger.warning("Final stream edit failed: {}", e2)
raise # Let ChannelManager handle retry
# If final content exceeds Telegram limit, keep the first chunk in
# the edited stream message and send the rest as follow-up messages.
for extra_chunk in chunks[1:]:
await self._send_text(int_chat_id, extra_chunk)
self._stream_bufs.pop(chat_id, None)
return
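`split_message` keeps the final edit within Telegram's 4096-character message limit and sends the overflow as follow-up messages. A naive standalone sketch of such a splitter, assuming a preference for newline boundaries (the real helper's strategy may differ):

```python
TELEGRAM_MAX_MESSAGE_LEN = 4096  # Telegram Bot API hard limit per message

def split_message(text: str, limit: int = TELEGRAM_MAX_MESSAGE_LEN) -> list[str]:
    """Split text into chunks of at most `limit` chars, breaking at newlines when possible."""
    chunks: list[str] = []
    while len(text) > limit:
        cut = text.rfind("\n", 0, limit)  # last newline that fits in this chunk
        if cut <= 0:
            cut = limit                   # no usable newline: hard cut
        chunks.append(text[:cut])
        text = text[cut:].lstrip("\n")
    if text:
        chunks.append(text)
    return chunks
```

The first chunk replaces the streamed message via edit; `chunks[1:]` go out as ordinary sends, which is exactly the shape the loop above consumes.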
@ -599,11 +604,15 @@ class TelegramChannel(BaseChannel):
return
now = time.monotonic()
thread_kwargs = {}
if message_thread_id := meta.get("message_thread_id"):
thread_kwargs["message_thread_id"] = message_thread_id
if buf.message_id is None:
try:
sent = await self._call_with_retry(
self._app.bot.send_message,
chat_id=int_chat_id, text=buf.text,
**thread_kwargs,
)
buf.message_id = sent.message_id
buf.last_edit = now
@ -651,9 +660,9 @@ class TelegramChannel(BaseChannel):
@staticmethod
def _derive_topic_session_key(message) -> str | None:
"""Derive topic-scoped session key for non-private Telegram chats."""
"""Derive topic-scoped session key for Telegram chats with threads."""
message_thread_id = getattr(message, "message_thread_id", None)
if message.chat.type == "private" or message_thread_id is None:
if message_thread_id is None:
return None
return f"telegram:{message.chat_id}:topic:{message_thread_id}"
@ -815,7 +824,7 @@ class TelegramChannel(BaseChannel):
return bool(bot_id and reply_user and reply_user.id == bot_id)
def _remember_thread_context(self, message) -> None:
"""Cache topic thread id by chat/message id for follow-up replies."""
"""Cache Telegram thread context by chat/message id for follow-up replies."""
message_thread_id = getattr(message, "message_thread_id", None)
if message_thread_id is None:
return


@ -484,7 +484,7 @@ class WeixinChannel(BaseChannel):
except httpx.TimeoutException:
# Normal for long-poll, just retry
continue
except Exception as e:
except Exception:
if not self._running:
break
consecutive_failures += 1


@ -75,6 +75,7 @@ class WhatsAppChannel(BaseChannel):
self._ws = None
self._connected = False
self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
self._lid_to_phone: dict[str, str] = {}
self._bridge_token: str | None = None
def _effective_bridge_token(self) -> str:
@ -228,21 +229,45 @@ class WhatsAppChannel(BaseChannel):
if not was_mentioned:
return
user_id = pn if pn else sender
sender_id = user_id.split("@")[0] if "@" in user_id else user_id
logger.info("Sender {}", sender)
# Classify by JID suffix: @s.whatsapp.net = phone, @lid.whatsapp.net = LID
# The bridge's pn/sender fields don't consistently map to phone/LID across versions.
raw_a = pn or ""
raw_b = sender or ""
id_a = raw_a.split("@")[0] if "@" in raw_a else raw_a
id_b = raw_b.split("@")[0] if "@" in raw_b else raw_b
# Handle voice transcription if it's a voice message
if content == "[Voice Message]":
logger.info(
"Voice message received from {}, but direct download from bridge is not yet supported.",
sender_id,
)
content = "[Voice Message: Transcription not available for WhatsApp yet]"
phone_id = ""
lid_id = ""
for raw, extracted in [(raw_a, id_a), (raw_b, id_b)]:
if "@s.whatsapp.net" in raw:
phone_id = extracted
elif "@lid.whatsapp.net" in raw:
lid_id = extracted
elif extracted and not phone_id:
phone_id = extracted # best guess for bare values
if phone_id and lid_id:
self._lid_to_phone[lid_id] = phone_id
sender_id = phone_id or self._lid_to_phone.get(lid_id, "") or lid_id or id_a or id_b
logger.info("Sender phone={} lid={} → sender_id={}", phone_id or "(empty)", lid_id or "(empty)", sender_id)
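The suffix heuristic above reduces to a pure function over the two JID fields plus the learned LID→phone cache. A sketch mirroring it (the final fallback chain is simplified, and the suffix strings are the ones this code checks, not a full catalogue of WhatsApp JID forms):

```python
def resolve_sender_id(pn: str, sender: str, lid_to_phone: dict[str, str]) -> str:
    """Pick a stable sender id, preferring a phone JID over a LID."""
    phone_id = lid_id = ""
    for raw in (pn or "", sender or ""):
        bare = raw.split("@")[0] if "@" in raw else raw
        if "@s.whatsapp.net" in raw:
            phone_id = bare               # canonical phone-number JID
        elif "@lid.whatsapp.net" in raw:
            lid_id = bare                 # privacy-preserving LID
        elif bare and not phone_id:
            phone_id = bare               # best guess for bare values
    if phone_id and lid_id:
        lid_to_phone[lid_id] = phone_id   # remember for later LID-only events
    return phone_id or lid_to_phone.get(lid_id, "") or lid_id
```

The cache is what makes `allow_from` phone-number checks keep working when a later event arrives with only a LID.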
# Extract media paths (images/documents/videos downloaded by the bridge)
media_paths = data.get("media") or []
# Handle voice transcription if it's a voice message
if content == "[Voice Message]":
if media_paths:
logger.info("Transcribing voice message from {}...", sender_id)
transcription = await self.transcribe_audio(media_paths[0])
if transcription:
content = transcription
logger.info("Transcribed voice from {}: {}...", sender_id, transcription[:50])
else:
content = "[Voice Message: Transcription failed]"
else:
content = "[Voice Message: Audio not available]"
# Build content tags matching Telegram's pattern: [image: /path] or [file: /path]
if media_paths:
for p in media_paths:


@ -1,12 +1,11 @@
"""CLI commands for nanobot."""
import asyncio
from contextlib import contextmanager, nullcontext
import os
import select
import signal
import sys
from contextlib import nullcontext
from pathlib import Path
from typing import Any
@ -34,6 +33,19 @@ from rich.table import Table
from rich.text import Text
from nanobot import __logo__, __version__
class SafeFileHistory(FileHistory):
"""FileHistory subclass that sanitizes surrogate characters on write.
On Windows, special Unicode input (emoji, mixed-script) can produce
surrogate characters that crash prompt_toolkit's file write.
See issue #2846.
"""
def store_string(self, string: str) -> None:
safe = string.encode("utf-8", errors="surrogateescape").decode("utf-8", errors="replace")
super().store_string(safe)
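The sanitizer relies on a two-step round-trip: `surrogateescape` turns smuggled surrogates in the U+DC80–U+DCFF range back into raw bytes, and decoding those bytes with `errors="replace"` maps anything that is not valid UTF-8 to U+FFFD. The transform in isolation:

```python
def sanitize_surrogates(s: str) -> str:
    """Map surrogate-escaped code points (un-encodable as strict UTF-8) to U+FFFD."""
    # surrogateescape re-emits \udc80..\udcff as the original raw bytes;
    # strict UTF-8 cannot represent them, so errors="replace" yields U+FFFD.
    return s.encode("utf-8", errors="surrogateescape").decode("utf-8", errors="replace")
```

Ordinary text, including emoji, passes through unchanged; only the problematic surrogates are rewritten, so the history file remains valid UTF-8.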
from nanobot.cli.stream import StreamRenderer, ThinkingSpinner
from nanobot.config.paths import get_workspace_path, is_default_workspace
from nanobot.config.schema import Config
@ -73,6 +85,7 @@ def _flush_pending_tty_input() -> None:
try:
import termios
termios.tcflush(fd, termios.TCIFLUSH)
return
except Exception:
@ -95,6 +108,7 @@ def _restore_terminal() -> None:
return
try:
import termios
termios.tcsetattr(sys.stdin.fileno(), termios.TCSADRAIN, _SAVED_TERM_ATTRS)
except Exception:
pass
@ -107,6 +121,7 @@ def _init_prompt_session() -> None:
# Save terminal state so we can restore it on exit
try:
import termios
_SAVED_TERM_ATTRS = termios.tcgetattr(sys.stdin.fileno())
except Exception:
pass
@ -117,9 +132,9 @@ def _init_prompt_session() -> None:
history_file.parent.mkdir(parents=True, exist_ok=True)
_PROMPT_SESSION = PromptSession(
history=FileHistory(str(history_file)),
history=SafeFileHistory(str(history_file)),
enable_open_in_editor=False,
multiline=False, # Enter submits (single line mode)
multiline=False, # Enter submits (single line mode)
)
@ -231,7 +246,6 @@ async def _read_interactive_input_async() -> str:
raise KeyboardInterrupt from exc
def version_callback(value: bool):
if value:
console.print(f"{__logo__} nanobot v{__version__}")
@ -281,8 +295,12 @@ def onboard(
config = _apply_workspace_override(load_config(config_path))
else:
console.print(f"[yellow]Config already exists at {config_path}[/yellow]")
console.print(" [bold]y[/bold] = overwrite with defaults (existing values will be lost)")
console.print(" [bold]N[/bold] = refresh config, keeping existing values and adding new fields")
console.print(
" [bold]y[/bold] = overwrite with defaults (existing values will be lost)"
)
console.print(
" [bold]N[/bold] = refresh config, keeping existing values and adding new fields"
)
if typer.confirm("Overwrite?"):
config = _apply_workspace_override(Config())
save_config(config, config_path)
@ -290,7 +308,9 @@ def onboard(
else:
config = _apply_workspace_override(load_config(config_path))
save_config(config, config_path)
console.print(f"[green]✓[/green] Config refreshed at {config_path} (existing values preserved)")
console.print(
f"[green]✓[/green] Config refreshed at {config_path} (existing values preserved)"
)
else:
config = _apply_workspace_override(Config())
# In wizard mode, don't save yet - the wizard will handle saving if should_save=True
@ -340,7 +360,9 @@ def onboard(
console.print(f" 1. Add your API key to [cyan]{config_path}[/cyan]")
console.print(" Get one at: https://openrouter.ai/keys")
console.print(f" 2. Chat: [cyan]{agent_cmd}[/cyan]")
console.print("\n[dim]Want Telegram/WhatsApp? See: https://github.com/HKUDS/nanobot#-chat-apps[/dim]")
console.print(
"\n[dim]Want Telegram/WhatsApp? See: https://github.com/HKUDS/nanobot#-chat-apps[/dim]"
)
def _merge_missing_defaults(existing: Any, defaults: Any) -> Any:
@ -413,9 +435,11 @@ def _make_provider(config: Config):
# --- instantiation by backend ---
if backend == "openai_codex":
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
provider = OpenAICodexProvider(default_model=model)
elif backend == "azure_openai":
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
provider = AzureOpenAIProvider(
api_key=p.api_key,
api_base=p.api_base,
@ -426,6 +450,7 @@ def _make_provider(config: Config):
provider = GitHubCopilotProvider(default_model=model)
elif backend == "anthropic":
from nanobot.providers.anthropic_provider import AnthropicProvider
provider = AnthropicProvider(
api_key=p.api_key if p else None,
api_base=config.get_api_base(model),
@ -434,6 +459,7 @@ def _make_provider(config: Config):
)
else:
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
provider = OpenAICompatProvider(
api_key=p.api_key if p else None,
api_base=config.get_api_base(model),
@ -453,7 +479,7 @@ def _make_provider(config: Config):
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
"""Load config and optionally override the active workspace."""
from nanobot.config.loader import load_config, set_config_path
from nanobot.config.loader import load_config, resolve_config_env_vars, set_config_path
config_path = None
if config:
@ -464,7 +490,11 @@ def _load_runtime_config(config: str | None = None, workspace: str | None = None
set_config_path(config_path)
console.print(f"[dim]Using config: {config_path}[/dim]")
loaded = load_config(config_path)
try:
loaded = resolve_config_env_vars(load_config(config_path))
except ValueError as e:
console.print(f"[red]Error: {e}[/red]")
raise typer.Exit(1)
_warn_deprecated_config_keys(config_path)
if workspace:
loaded.agents.defaults.workspace = workspace
@ -474,6 +504,7 @@ def _load_runtime_config(config: str | None = None, workspace: str | None = None
def _warn_deprecated_config_keys(config_path: Path | None) -> None:
"""Hint users to remove obsolete keys from their config file."""
import json
from nanobot.config.loader import get_config_path
path = config_path or get_config_path()
@ -497,6 +528,7 @@ def _migrate_cron_store(config: "Config") -> None:
if legacy_path.is_file() and not new_path.exists():
new_path.parent.mkdir(parents=True, exist_ok=True)
import shutil
shutil.move(str(legacy_path), str(new_path))
@ -610,6 +642,7 @@ def gateway(
if verbose:
import logging
logging.basicConfig(level=logging.DEBUG)
config = _load_runtime_config(config, workspace)
@ -695,7 +728,7 @@ def gateway(
if job.payload.deliver and job.payload.to and response:
should_notify = await evaluate_response(
response, job.payload.message, provider, agent.model,
response, reminder_note, provider, agent.model,
)
if should_notify:
from nanobot.bus.events import OutboundMessage
@ -705,6 +738,7 @@ def gateway(
content=response,
))
return response
cron.on_job = on_cron_job
# Create channel manager
@ -808,6 +842,7 @@ def gateway(
console.print("\nShutting down...")
except Exception:
import traceback
console.print("\n[red]Error: Gateway crashed unexpectedly[/red]")
console.print(traceback.format_exc())
finally:
@ -820,8 +855,6 @@ def gateway(
asyncio.run(run())
# ============================================================================
# Agent Commands
# ============================================================================
@ -1296,6 +1329,7 @@ def _register_login(name: str):
def decorator(fn):
_LOGIN_HANDLERS[name] = fn
return fn
return decorator
@ -1326,6 +1360,7 @@ def provider_login(
def _login_openai_codex() -> None:
try:
from oauth_cli_kit import get_token, login_oauth_interactive
token = None
try:
token = get_token()
View File
@ -60,6 +60,20 @@ async def cmd_status(ctx: CommandContext) -> OutboundMessage:
pass
if ctx_est <= 0:
ctx_est = loop._last_usage.get("prompt_tokens", 0)
# Fetch web search provider usage (best-effort, never blocks the response)
search_usage_text: str | None = None
try:
from nanobot.utils.searchusage import fetch_search_usage
web_cfg = getattr(loop, "web_config", None)
search_cfg = getattr(web_cfg, "search", None) if web_cfg else None
if search_cfg is not None:
provider = getattr(search_cfg, "provider", "duckduckgo")
api_key = getattr(search_cfg, "api_key", "") or None
usage = await fetch_search_usage(provider=provider, api_key=api_key)
search_usage_text = usage.format()
except Exception:
pass # Never let usage fetch break /status
return OutboundMessage(
channel=ctx.msg.channel,
chat_id=ctx.msg.chat_id,
@ -69,6 +83,7 @@ async def cmd_status(ctx: CommandContext) -> OutboundMessage:
context_window_tokens=loop.context_window_tokens,
session_msg_count=len(session.get_history(max_messages=0)),
context_tokens_estimate=ctx_est,
search_usage_text=search_usage_text,
),
metadata={**dict(ctx.msg.metadata or {}), "render_as": "text"},
)
@ -93,14 +108,30 @@ async def cmd_new(ctx: CommandContext) -> OutboundMessage:
async def cmd_dream(ctx: CommandContext) -> OutboundMessage:
"""Manually trigger a Dream consolidation run."""
import time
loop = ctx.loop
try:
did_work = await loop.dream.run()
content = "Dream completed." if did_work else "Dream: nothing to process."
except Exception as e:
content = f"Dream failed: {e}"
msg = ctx.msg
async def _run_dream():
t0 = time.monotonic()
try:
did_work = await loop.dream.run()
elapsed = time.monotonic() - t0
if did_work:
content = f"Dream completed in {elapsed:.1f}s."
else:
content = "Dream: nothing to process."
except Exception as e:
elapsed = time.monotonic() - t0
content = f"Dream failed after {elapsed:.1f}s: {e}"
await loop.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, content=content,
))
asyncio.create_task(_run_dream())
return OutboundMessage(
channel=ctx.msg.channel, chat_id=ctx.msg.chat_id, content=content,
channel=msg.channel, chat_id=msg.chat_id, content="Dreaming...",
)
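The reworked `/dream` handler acknowledges immediately and delivers the real result over the bus once the background run finishes. The fire-and-forget shape can be sketched standalone (names like `handle_command` and the `outbox` list are illustrative stand-ins, not nanobot APIs):

```python
import asyncio
import time

async def slow_job() -> str:
    """Stand-in for a long-running consolidation run."""
    await asyncio.sleep(0.05)
    return "Dream completed"

async def handle_command(outbox: list[str]) -> str:
    """Return an immediate ack; append the real result when the task finishes."""
    async def _run() -> None:
        t0 = time.monotonic()
        result = await slow_job()
        outbox.append(f"{result} in {time.monotonic() - t0:.1f}s")

    asyncio.create_task(_run())  # fire-and-forget background task
    return "Dreaming..."         # immediate response to the user

async def main() -> tuple[str, list[str]]:
    outbox: list[str] = []
    ack = await handle_command(outbox)
    await asyncio.sleep(0.2)     # give the background task time to complete
    return ack, outbox

ack, outbox = asyncio.run(main())
print(ack, outbox)
```

Note that a bare `asyncio.create_task` holds no strong reference to the task; in long-lived code the reference should be stored (or a done-callback attached) so the task is not garbage-collected mid-flight.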
View File
@ -1,6 +1,8 @@
"""Configuration loading utilities."""
import json
import os
import re
from pathlib import Path
import pydantic
@ -76,6 +78,38 @@ def save_config(config: Config, config_path: Path | None = None) -> None:
json.dump(data, f, indent=2, ensure_ascii=False)
def resolve_config_env_vars(config: Config) -> Config:
"""Return a copy of *config* with ``${VAR}`` env-var references resolved.
Only string values are affected; other types pass through unchanged.
Raises :class:`ValueError` if a referenced variable is not set.
"""
data = config.model_dump(mode="json", by_alias=True)
data = _resolve_env_vars(data)
return Config.model_validate(data)
def _resolve_env_vars(obj: object) -> object:
"""Recursively resolve ``${VAR}`` patterns in string values."""
if isinstance(obj, str):
return re.sub(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}", _env_replace, obj)
if isinstance(obj, dict):
return {k: _resolve_env_vars(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_resolve_env_vars(v) for v in obj]
return obj
def _env_replace(match: re.Match[str]) -> str:
name = match.group(1)
value = os.environ.get(name)
if value is None:
raise ValueError(
f"Environment variable '{name}' referenced in config is not set"
)
return value
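The resolver above dumps the config, walks the tree substituting `${VAR}` in string values only, and re-validates. The same recursion can be sketched in isolation (`resolve` here is a minimal stand-in, not the loader's actual helper):

```python
import os
import re

_PATTERN = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")

def resolve(obj: object) -> object:
    """Recursively substitute ${VAR} in strings; raise if a variable is unset."""
    if isinstance(obj, str):
        def repl(m: re.Match) -> str:
            value = os.environ.get(m.group(1))
            if value is None:
                raise ValueError(f"Environment variable '{m.group(1)}' is not set")
            return value
        return _PATTERN.sub(repl, obj)
    if isinstance(obj, dict):
        return {k: resolve(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [resolve(v) for v in obj]
    return obj  # ints, bools, None pass through unchanged

os.environ["MY_KEY"] = "sk-123"
print(resolve({"provider": {"api_key": "${MY_KEY}", "timeout": 30}}))
```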
def _migrate_config(data: dict) -> dict:
"""Migrate old config formats to current."""
# Move tools.exec.restrictToWorkspace → tools.restrictToWorkspace
View File
@ -28,6 +28,7 @@ class ChannelsConfig(Base):
send_progress: bool = True # stream agent's text progress to the channel
send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…"))
send_max_retries: int = Field(default=3, ge=0, le=10) # Max delivery attempts (initial send included)
transcription_provider: str = "groq" # Voice transcription backend: "groq" or "openai"
class DreamConfig(Base):
@ -155,6 +156,7 @@ class WebSearchConfig(Base):
api_key: str = ""
base_url: str = "" # SearXNG base URL
max_results: int = 5
timeout: int = 30 # Wall-clock timeout (seconds) for search operations
class WebToolsConfig(Base):
@ -173,6 +175,7 @@ class ExecToolConfig(Base):
enable: bool = True
timeout: int = 60
path_append: str = ""
sandbox: str = "" # sandbox backend: "" (none) or "bwrap"
class MCPServerConfig(Base):
"""MCP server connection configuration (stdio or HTTP)."""
@ -191,7 +194,7 @@ class ToolsConfig(Base):
web: WebToolsConfig = Field(default_factory=WebToolsConfig)
exec: ExecToolConfig = Field(default_factory=ExecToolConfig)
restrict_to_workspace: bool = False # If true, restrict all tool access to workspace directory
restrict_to_workspace: bool = False # restrict all tool access to workspace directory
mcp_servers: dict[str, MCPServerConfig] = Field(default_factory=dict)
ssrf_whitelist: list[str] = Field(default_factory=list) # CIDR ranges to exempt from SSRF blocking (e.g. ["100.64.0.0/10"] for Tailscale)
View File
@ -47,7 +47,7 @@ class Nanobot:
``~/.nanobot/config.json``.
workspace: Override the workspace directory from config.
"""
from nanobot.config.loader import load_config
from nanobot.config.loader import load_config, resolve_config_env_vars
from nanobot.config.schema import Config
resolved: Path | None = None
@ -56,7 +56,7 @@ class Nanobot:
if not resolved.exists():
raise FileNotFoundError(f"Config not found: {resolved}")
config: Config = load_config(resolved)
config: Config = resolve_config_env_vars(load_config(resolved))
if workspace is not None:
config.agents.defaults.workspace = str(
Path(workspace).expanduser().resolve()
View File
@ -52,6 +52,62 @@ class AnthropicProvider(LLMProvider):
client_kw["max_retries"] = 0
self._client = AsyncAnthropic(**client_kw)
@classmethod
def _handle_error(cls, e: Exception) -> LLMResponse:
response = getattr(e, "response", None)
headers = getattr(response, "headers", None)
payload = (
getattr(e, "body", None)
or getattr(e, "doc", None)
or getattr(response, "text", None)
)
if payload is None and response is not None:
response_json = getattr(response, "json", None)
if callable(response_json):
try:
payload = response_json()
except Exception:
payload = None
payload_text = payload if isinstance(payload, str) else str(payload) if payload is not None else ""
msg = f"Error: {payload_text.strip()[:500]}" if payload_text.strip() else f"Error calling LLM: {e}"
retry_after = cls._extract_retry_after_from_headers(headers)
if retry_after is None:
retry_after = LLMProvider._extract_retry_after(msg)
status_code = getattr(e, "status_code", None)
if status_code is None and response is not None:
status_code = getattr(response, "status_code", None)
should_retry: bool | None = None
if headers is not None:
raw = headers.get("x-should-retry")
if isinstance(raw, str):
lowered = raw.strip().lower()
if lowered == "true":
should_retry = True
elif lowered == "false":
should_retry = False
error_kind: str | None = None
error_name = e.__class__.__name__.lower()
if "timeout" in error_name:
error_kind = "timeout"
elif "connection" in error_name:
error_kind = "connection"
error_type, error_code = LLMProvider._extract_error_type_code(payload)
return LLMResponse(
content=msg,
finish_reason="error",
retry_after=retry_after,
error_status_code=int(status_code) if status_code is not None else None,
error_kind=error_kind,
error_type=error_type,
error_code=error_code,
error_retry_after_s=retry_after,
error_should_retry=should_retry,
)
@staticmethod
def _strip_prefix(model: str) -> str:
if model.startswith("anthropic/"):
@ -404,15 +460,6 @@ class AnthropicProvider(LLMProvider):
# Public API
# ------------------------------------------------------------------
@staticmethod
def _handle_error(e: Exception) -> LLMResponse:
msg = f"Error calling LLM: {e}"
response = getattr(e, "response", None)
retry_after = LLMProvider._extract_retry_after_from_headers(getattr(response, "headers", None))
if retry_after is None:
retry_after = LLMProvider._extract_retry_after(msg)
return LLMResponse(content=msg, finish_reason="error", retry_after=retry_after)
async def chat(
self,
messages: list[dict[str, Any]],
@ -474,6 +521,7 @@ class AnthropicProvider(LLMProvider):
f"{idle_timeout_s} seconds"
),
finish_reason="error",
error_kind="timeout",
)
except Exception as e:
return self._handle_error(e)
View File
@ -54,6 +54,13 @@ class LLMResponse:
retry_after: float | None = None # Provider supplied retry wait in seconds.
reasoning_content: str | None = None # Kimi, DeepSeek-R1, MiMo etc.
thinking_blocks: list[dict] | None = None # Anthropic extended thinking
# Structured error metadata used by retry policy when finish_reason == "error".
error_status_code: int | None = None
error_kind: str | None = None # e.g. "timeout", "connection"
error_type: str | None = None # Provider/type semantic, e.g. insufficient_quota.
error_code: str | None = None # Provider/code semantic, e.g. rate_limit_exceeded.
error_retry_after_s: float | None = None
error_should_retry: bool | None = None
@property
def has_tool_calls(self) -> bool:
@ -91,6 +98,52 @@ class LLMProvider(ABC):
"server error",
"temporarily unavailable",
)
_RETRYABLE_STATUS_CODES = frozenset({408, 409, 429})
_TRANSIENT_ERROR_KINDS = frozenset({"timeout", "connection"})
_NON_RETRYABLE_429_ERROR_TOKENS = frozenset({
"insufficient_quota",
"quota_exceeded",
"quota_exhausted",
"billing_hard_limit_reached",
"insufficient_balance",
"credit_balance_too_low",
"billing_not_active",
"payment_required",
})
_RETRYABLE_429_ERROR_TOKENS = frozenset({
"rate_limit_exceeded",
"rate_limit_error",
"too_many_requests",
"request_limit_exceeded",
"requests_limit_exceeded",
"overloaded_error",
})
_NON_RETRYABLE_429_TEXT_MARKERS = (
"insufficient_quota",
"insufficient quota",
"quota exceeded",
"quota exhausted",
"billing hard limit",
"billing_hard_limit_reached",
"billing not active",
"insufficient balance",
"insufficient_balance",
"credit balance too low",
"payment required",
"out of credits",
"out of quota",
"exceeded your current quota",
)
_RETRYABLE_429_TEXT_MARKERS = (
"rate limit",
"rate_limit",
"too many requests",
"retry after",
"try again in",
"temporarily unavailable",
"overloaded",
"concurrency limit",
)
_SENTINEL = object()
@ -226,6 +279,80 @@ class LLMProvider(ABC):
err = (content or "").lower()
return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)
@classmethod
def _is_transient_response(cls, response: LLMResponse) -> bool:
"""Prefer structured error metadata, fallback to text markers for legacy providers."""
if response.error_should_retry is not None:
return bool(response.error_should_retry)
if response.error_status_code is not None:
status = int(response.error_status_code)
if status == 429:
return cls._is_retryable_429_response(response)
if status in cls._RETRYABLE_STATUS_CODES or status >= 500:
return True
kind = (response.error_kind or "").strip().lower()
if kind in cls._TRANSIENT_ERROR_KINDS:
return True
return cls._is_transient_error(response.content)
@staticmethod
def _normalize_error_token(value: Any) -> str | None:
if value is None:
return None
token = str(value).strip().lower()
return token or None
@classmethod
def _extract_error_type_code(cls, payload: Any) -> tuple[str | None, str | None]:
data: dict[str, Any] | None = None
if isinstance(payload, dict):
data = payload
elif isinstance(payload, str):
text = payload.strip()
if text:
try:
parsed = json.loads(text)
except Exception:
parsed = None
if isinstance(parsed, dict):
data = parsed
if not isinstance(data, dict):
return None, None
error_obj = data.get("error")
type_value = data.get("type")
code_value = data.get("code")
if isinstance(error_obj, dict):
type_value = error_obj.get("type") or type_value
code_value = error_obj.get("code") or code_value
return cls._normalize_error_token(type_value), cls._normalize_error_token(code_value)
@classmethod
def _is_retryable_429_response(cls, response: LLMResponse) -> bool:
type_token = cls._normalize_error_token(response.error_type)
code_token = cls._normalize_error_token(response.error_code)
semantic_tokens = {
token for token in (type_token, code_token)
if token is not None
}
if any(token in cls._NON_RETRYABLE_429_ERROR_TOKENS for token in semantic_tokens):
return False
content = (response.content or "").lower()
if any(marker in content for marker in cls._NON_RETRYABLE_429_TEXT_MARKERS):
return False
if any(token in cls._RETRYABLE_429_ERROR_TOKENS for token in semantic_tokens):
return True
if any(marker in content for marker in cls._RETRYABLE_429_TEXT_MARKERS):
return True
# An unknown 429 defaults to wait-and-retry.
return True
@staticmethod
def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
"""Replace image_url blocks with text placeholder. Returns None if no images found."""
@ -397,14 +524,28 @@ class LLMProvider(ABC):
def _extract_retry_after_from_headers(cls, headers: Any) -> float | None:
if not headers:
return None
retry_after: Any = None
if hasattr(headers, "get"):
retry_after = headers.get("retry-after") or headers.get("Retry-After")
if retry_after is None and isinstance(headers, dict):
for key, value in headers.items():
if isinstance(key, str) and key.lower() == "retry-after":
retry_after = value
break
def _header_value(name: str) -> Any:
if hasattr(headers, "get"):
value = headers.get(name) or headers.get(name.title())
if value is not None:
return value
if isinstance(headers, dict):
for key, value in headers.items():
if isinstance(key, str) and key.lower() == name.lower():
return value
return None
try:
retry_ms = _header_value("retry-after-ms")
if retry_ms is not None:
value = float(retry_ms) / 1000.0
if value > 0:
return value
except (TypeError, ValueError):
pass
retry_after = _header_value("retry-after")
if retry_after is None:
return None
retry_after_text = str(retry_after).strip()
@ -421,6 +562,14 @@ class LLMProvider(ABC):
remaining = (retry_at - datetime.now(retry_at.tzinfo)).total_seconds()
return max(0.1, remaining)
@classmethod
def _extract_retry_after_from_response(cls, response: LLMResponse) -> float | None:
if response.error_retry_after_s is not None and response.error_retry_after_s > 0:
return response.error_retry_after_s
if response.retry_after is not None and response.retry_after > 0:
return response.retry_after
return cls._extract_retry_after(response.content)
async def _sleep_with_heartbeat(
self,
delay: float,
@ -469,7 +618,7 @@ class LLMProvider(ABC):
last_error_key = error_key
identical_error_count = 1 if error_key else 0
if not self._is_transient_error(response.content):
if not self._is_transient_response(response):
stripped = self._strip_image_content(original_messages)
if stripped is not None and stripped != kw["messages"]:
logger.warning(
@ -492,7 +641,7 @@ class LLMProvider(ABC):
break
base_delay = delays[min(attempt - 1, len(delays) - 1)]
delay = response.retry_after or self._extract_retry_after(response.content) or base_delay
delay = self._extract_retry_after_from_response(response) or base_delay
if persistent:
delay = min(delay, self._PERSISTENT_MAX_DELAY)
View File
@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
import hashlib
import importlib.util
import os
import secrets
import string
@ -12,7 +13,17 @@ from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any
import json_repair
from openai import AsyncOpenAI
if os.environ.get("LANGFUSE_SECRET_KEY") and importlib.util.find_spec("langfuse"):
from langfuse.openai import AsyncOpenAI
else:
if os.environ.get("LANGFUSE_SECRET_KEY"):
import logging
logging.getLogger(__name__).warning(
"LANGFUSE_SECRET_KEY is set but langfuse is not installed; "
"install with `pip install langfuse` to enable tracing"
)
from openai import AsyncOpenAI
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
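The guarded import above swaps in Langfuse's wrapped client only when both the secret key and the package are present, warning when only the key is set. The general optional-dependency pattern can be sketched as (the function name is illustrative):

```python
import importlib.util
import os

def pick_client_module() -> str:
    """Choose an instrumented client only when it is configured AND installed."""
    want_tracing = bool(os.environ.get("LANGFUSE_SECRET_KEY"))
    have_package = importlib.util.find_spec("langfuse") is not None
    if want_tracing and have_package:
        return "langfuse.openai"   # drop-in wrapper around the OpenAI client
    if want_tracing and not have_package:
        print("tracing requested but langfuse not installed; falling back")
    return "openai"

print(pick_client_module())
```

`find_spec` checks installability without importing, so the fallback path never pays the import cost of the optional package.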
@ -286,6 +297,24 @@ class OpenAICompatProvider(LLMProvider):
if reasoning_effort:
kwargs["reasoning_effort"] = reasoning_effort
# Provider-specific thinking parameters.
# Only sent when reasoning_effort is explicitly configured so that
# the provider default is preserved otherwise.
if spec and reasoning_effort is not None:
thinking_enabled = reasoning_effort.lower() != "minimal"
extra: dict[str, Any] | None = None
if spec.name == "dashscope":
extra = {"enable_thinking": thinking_enabled}
elif spec.name in (
"volcengine", "volcengine_coding_plan",
"byteplus", "byteplus_coding_plan",
):
extra = {
"thinking": {"type": "enabled" if thinking_enabled else "disabled"}
}
if extra:
kwargs.setdefault("extra_body", {}).update(extra)
if tools:
kwargs["tools"] = tools
kwargs["tool_choice"] = tool_choice or "auto"
@ -605,16 +634,73 @@ class OpenAICompatProvider(LLMProvider):
reasoning_content="".join(reasoning_parts) or None,
)
@classmethod
def _extract_error_metadata(cls, e: Exception) -> dict[str, Any]:
response = getattr(e, "response", None)
headers = getattr(response, "headers", None)
payload = (
getattr(e, "body", None)
or getattr(e, "doc", None)
or getattr(response, "text", None)
)
if payload is None and response is not None:
response_json = getattr(response, "json", None)
if callable(response_json):
try:
payload = response_json()
except Exception:
payload = None
error_type, error_code = LLMProvider._extract_error_type_code(payload)
status_code = getattr(e, "status_code", None)
if status_code is None and response is not None:
status_code = getattr(response, "status_code", None)
should_retry: bool | None = None
if headers is not None:
raw = headers.get("x-should-retry")
if isinstance(raw, str):
lowered = raw.strip().lower()
if lowered == "true":
should_retry = True
elif lowered == "false":
should_retry = False
error_kind: str | None = None
error_name = e.__class__.__name__.lower()
if "timeout" in error_name:
error_kind = "timeout"
elif "connection" in error_name:
error_kind = "connection"
return {
"error_status_code": int(status_code) if status_code is not None else None,
"error_kind": error_kind,
"error_type": error_type,
"error_code": error_code,
"error_retry_after_s": cls._extract_retry_after_from_headers(headers),
"error_should_retry": should_retry,
}
@staticmethod
def _handle_error(e: Exception) -> LLMResponse:
body = (
getattr(e, "doc", None)
or getattr(e, "body", None)
or getattr(getattr(e, "response", None), "text", None)
)
body_text = body if isinstance(body, str) else str(body) if body is not None else ""
msg = f"Error: {body_text.strip()[:500]}" if body_text.strip() else f"Error calling LLM: {e}"
response = getattr(e, "response", None)
body = getattr(e, "doc", None) or getattr(response, "text", None)
body_text = str(body).strip() if body is not None else ""
msg = f"Error: {body_text[:500]}" if body_text else f"Error calling LLM: {e}"
retry_after = LLMProvider._extract_retry_after_from_headers(getattr(response, "headers", None))
if retry_after is None:
retry_after = LLMProvider._extract_retry_after(msg)
return LLMResponse(content=msg, finish_reason="error", retry_after=retry_after)
return LLMResponse(
content=msg,
finish_reason="error",
retry_after=retry_after,
**OpenAICompatProvider._extract_error_metadata(e),
)
# ------------------------------------------------------------------
# Public API
@ -682,6 +768,7 @@ class OpenAICompatProvider(LLMProvider):
f"{idle_timeout_s} seconds"
),
finish_reason="error",
error_kind="timeout",
)
except Exception as e:
return self._handle_error(e)
View File
@ -1,4 +1,4 @@
"""Voice transcription provider using Groq."""
"""Voice transcription providers (Groq and OpenAI Whisper)."""
import os
from pathlib import Path
@ -7,6 +7,36 @@ import httpx
from loguru import logger
class OpenAITranscriptionProvider:
"""Voice transcription provider using OpenAI's Whisper API."""
def __init__(self, api_key: str | None = None):
self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
self.api_url = "https://api.openai.com/v1/audio/transcriptions"
async def transcribe(self, file_path: str | Path) -> str:
if not self.api_key:
logger.warning("OpenAI API key not configured for transcription")
return ""
path = Path(file_path)
if not path.exists():
logger.error("Audio file not found: {}", file_path)
return ""
try:
async with httpx.AsyncClient() as client:
with open(path, "rb") as f:
files = {"file": (path.name, f), "model": (None, "whisper-1")}
headers = {"Authorization": f"Bearer {self.api_key}"}
response = await client.post(
self.api_url, headers=headers, files=files, timeout=60.0,
)
response.raise_for_status()
return response.json().get("text", "")
except Exception as e:
logger.error("OpenAI transcription error: {}", e)
return ""
class GroqTranscriptionProvider:
"""
Voice transcription provider using Groq's Whisper API.
View File
@ -1,7 +1,9 @@
{% if part == 'system' %}
You are a notification gate for a background agent. You will be given the original task and the agent's response. Call the evaluate_notification tool to decide whether the user should be notified.
Notify when the response contains actionable information, errors, completed deliverables, or anything the user explicitly asked to be reminded about.
Notify when the response contains actionable information, errors, completed deliverables, scheduled reminder/timer completions, or anything the user explicitly asked to be reminded about.
A user-scheduled reminder should usually notify even when the response is brief or mostly repeats the original reminder.
Suppress when the response is a routine status check with nothing new, a confirmation that everything is normal, or essentially empty.
{% elif part == 'user' %}
View File
@ -1,5 +1,6 @@
"""Utility functions for nanobot."""
from nanobot.utils.helpers import ensure_dir
from nanobot.utils.path import abbreviate_path
__all__ = ["ensure_dir"]
__all__ = ["ensure_dir", "abbreviate_path"]
View File
@ -396,8 +396,15 @@ def build_status_content(
context_window_tokens: int,
session_msg_count: int,
context_tokens_estimate: int,
search_usage_text: str | None = None,
) -> str:
"""Build a human-readable runtime status snapshot."""
"""Build a human-readable runtime status snapshot.
Args:
search_usage_text: Optional pre-formatted web search usage string
(produced by SearchUsageInfo.format()). When provided
it is appended as an extra section.
"""
uptime_s = int(time.time() - start_time)
uptime = (
f"{uptime_s // 3600}h {(uptime_s % 3600) // 60}m"
@ -414,14 +421,17 @@ def build_status_content(
token_line = f"\U0001f4ca Tokens: {last_in} in / {last_out} out"
if cached and last_in:
token_line += f" ({cached * 100 // last_in}% cached)"
return "\n".join([
lines = [
f"\U0001f408 nanobot v{version}",
f"\U0001f9e0 Model: {model}",
token_line,
f"\U0001f4da Context: {ctx_used_str}/{ctx_total_str} ({ctx_pct}%)",
f"\U0001f4ac Session: {session_msg_count} messages",
f"\u23f1 Uptime: {uptime}",
])
]
if search_usage_text:
lines.append(search_usage_text)
return "\n".join(lines)
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
nanobot/utils/path.py Normal file (107 lines)
View File
@ -0,0 +1,107 @@
"""Path abbreviation utilities for display."""
from __future__ import annotations
import os
import re
from urllib.parse import urlparse
def abbreviate_path(path: str, max_len: int = 40) -> str:
"""Abbreviate a file path or URL, preserving basename and key directories.
Strategy:
1. Return as-is if short enough
2. Replace home directory with ~/
3. From right, keep basename + parent dirs until budget exhausted
4. Prefix with /
"""
if not path:
return path
# Handle URLs: preserve scheme://domain + filename
if re.match(r"https?://", path):
return _abbreviate_url(path, max_len)
# Normalize separators to /
normalized = path.replace("\\", "/")
# Replace home directory
home = os.path.expanduser("~").replace("\\", "/")
if normalized.startswith(home + "/"):
normalized = "~" + normalized[len(home):]
elif normalized == home:
normalized = "~"
# Return early only after normalization and home replacement
if len(normalized) <= max_len:
return normalized
# Split into segments
parts = normalized.rstrip("/").split("/")
if len(parts) <= 1:
return normalized[:max_len - 1] + "\u2026"
# Always keep the basename
basename = parts[-1]
# Budget: max_len minus "…/" prefix (2 chars) minus "/" separator minus basename
budget = max_len - len(basename) - 3 # -3 for "…/" + final "/"
# Walk backwards from parent, collecting segments
kept: list[str] = []
for seg in reversed(parts[:-1]):
needed = len(seg) + 1 # segment + "/"
if not kept and needed <= budget:
kept.append(seg)
budget -= needed
elif kept:
needed_with_sep = len(seg) + 1
if needed_with_sep <= budget:
kept.append(seg)
budget -= needed_with_sep
else:
break
else:
break
kept.reverse()
if kept:
return "\u2026/" + "/".join(kept) + "/" + basename
return "\u2026/" + basename
def _abbreviate_url(url: str, max_len: int = 40) -> str:
"""Abbreviate a URL keeping domain and filename."""
if len(url) <= max_len:
return url
parsed = urlparse(url)
domain = parsed.netloc # e.g. "example.com"
path_part = parsed.path # e.g. "/api/v2/resource.json"
# Extract filename from path
segments = path_part.rstrip("/").split("/")
basename = segments[-1] if segments else ""
if not basename:
# No filename, truncate URL
return url[: max_len - 1] + "\u2026"
budget = max_len - len(domain) - len(basename) - 4  # "/…/" (3 chars) + final "/" (1)
if budget < 0:
trunc = max_len - len(domain) - 5 # "…/" + "/"
return domain + "/\u2026/" + (basename[:trunc] if trunc > 0 else "")
# Build abbreviated path
kept: list[str] = []
for seg in reversed(segments[:-1]):
if len(seg) + 1 <= budget:
kept.append(seg)
budget -= len(seg) + 1
else:
break
kept.reverse()
if kept:
return domain + "/\u2026/" + "/".join(kept) + "/" + basename
return domain + "/\u2026/" + basename
View File
@ -16,8 +16,7 @@ EMPTY_FINAL_RESPONSE_MESSAGE = (
)
FINALIZATION_RETRY_PROMPT = (
"You have already finished the tool work. Do not call any more tools. "
"Using only the conversation and tool results above, provide the final answer for the user now."
"Please provide your response to the user based on the conversation above."
)
View File
@ -0,0 +1,168 @@
"""Web search provider usage fetchers for /status command."""
from __future__ import annotations
import os
from dataclasses import dataclass
from typing import Any
@dataclass
class SearchUsageInfo:
"""Structured usage info returned by a provider fetcher."""
provider: str
supported: bool = False # True if the provider has a usage API
error: str | None = None # Set when the API call failed
# Usage counters (None = not available for this provider)
used: int | None = None
limit: int | None = None
remaining: int | None = None
reset_date: str | None = None # ISO date string, e.g. "2026-05-01"
# Tavily-specific breakdown
search_used: int | None = None
extract_used: int | None = None
crawl_used: int | None = None
def format(self) -> str:
"""Return a human-readable multi-line string for /status output."""
lines = [f"🔍 Web Search: {self.provider}"]
if not self.supported:
lines.append(" Usage tracking: not available for this provider")
return "\n".join(lines)
if self.error:
lines.append(f" Usage: unavailable ({self.error})")
return "\n".join(lines)
if self.used is not None and self.limit is not None:
lines.append(f" Usage: {self.used} / {self.limit} requests")
elif self.used is not None:
lines.append(f" Usage: {self.used} requests")
# Tavily breakdown
breakdown_parts = []
if self.search_used is not None:
breakdown_parts.append(f"Search: {self.search_used}")
if self.extract_used is not None:
breakdown_parts.append(f"Extract: {self.extract_used}")
if self.crawl_used is not None:
breakdown_parts.append(f"Crawl: {self.crawl_used}")
if breakdown_parts:
lines.append(f" Breakdown: {' | '.join(breakdown_parts)}")
if self.remaining is not None:
lines.append(f" Remaining: {self.remaining} requests")
if self.reset_date:
lines.append(f" Resets: {self.reset_date}")
return "\n".join(lines)
async def fetch_search_usage(
provider: str,
api_key: str | None = None,
) -> SearchUsageInfo:
"""
Fetch usage info for the configured web search provider.
Args:
provider: Provider name (e.g. "tavily", "brave", "duckduckgo").
api_key: API key for the provider (falls back to env vars).
Returns:
SearchUsageInfo with populated fields where available.
"""
p = (provider or "duckduckgo").strip().lower()
if p == "tavily":
return await _fetch_tavily_usage(api_key)
else:
# brave, duckduckgo, searxng, jina, unknown — no usage API
return SearchUsageInfo(provider=p, supported=False)
# ---------------------------------------------------------------------------
# Tavily
# ---------------------------------------------------------------------------
async def _fetch_tavily_usage(api_key: str | None) -> SearchUsageInfo:
"""Fetch usage from GET https://api.tavily.com/usage."""
import httpx
key = api_key or os.environ.get("TAVILY_API_KEY", "")
if not key:
return SearchUsageInfo(
provider="tavily",
supported=True,
error="TAVILY_API_KEY not configured",
)
try:
async with httpx.AsyncClient(timeout=8.0) as client:
r = await client.get(
"https://api.tavily.com/usage",
headers={"Authorization": f"Bearer {key}"},
)
r.raise_for_status()
data: dict[str, Any] = r.json()
return _parse_tavily_usage(data)
except httpx.HTTPStatusError as e:
return SearchUsageInfo(
provider="tavily",
supported=True,
error=f"HTTP {e.response.status_code}",
)
except Exception as e:
return SearchUsageInfo(
provider="tavily",
supported=True,
error=str(e)[:80],
)
def _parse_tavily_usage(data: dict[str, Any]) -> SearchUsageInfo:
"""
Parse Tavily /usage response.
Actual API response shape:
{
"account": {
"current_plan": "Researcher",
"plan_usage": 20,
"plan_limit": 1000,
"search_usage": 20,
"crawl_usage": 0,
"extract_usage": 0,
"map_usage": 0,
"research_usage": 0,
"paygo_usage": 0,
"paygo_limit": null
}
}
"""
account = data.get("account") or {}
used = account.get("plan_usage")
limit = account.get("plan_limit")
# Compute remaining
remaining = None
if used is not None and limit is not None:
remaining = max(0, limit - used)
return SearchUsageInfo(
provider="tavily",
supported=True,
used=used,
limit=limit,
remaining=remaining,
search_used=account.get("search_usage"),
extract_used=account.get("extract_usage"),
crawl_used=account.get("crawl_usage"),
)
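The remaining-request arithmetic in `_parse_tavily_usage` can be checked directly against the sample payload documented in its docstring (the dict below is that documented example, not a live API response):

```python
# Recompute "remaining" exactly as _parse_tavily_usage does, using the
# sample "account" object from the docstring above.
account = {
    "current_plan": "Researcher",
    "plan_usage": 20,
    "plan_limit": 1000,
    "search_usage": 20,
    "extract_usage": 0,
    "crawl_usage": 0,
}
used, limit = account["plan_usage"], account["plan_limit"]
remaining = max(0, limit - used)
print(used, limit, remaining)  # 20 1000 980
```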

nanobot/utils/tool_hints.py (new file)

@@ -0,0 +1,119 @@
"""Tool hint formatting for concise, human-readable tool call display."""
from __future__ import annotations
from nanobot.utils.path import abbreviate_path
# Registry: tool_name -> (key_args, template, is_path, is_command)
_TOOL_FORMATS: dict[str, tuple[list[str], str, bool, bool]] = {
"read_file": (["path", "file_path"], "read {}", True, False),
"write_file": (["path", "file_path"], "write {}", True, False),
"edit": (["file_path", "path"], "edit {}", True, False),
"glob": (["pattern"], 'glob "{}"', False, False),
"grep": (["pattern"], 'grep "{}"', False, False),
"exec": (["command"], "$ {}", False, True),
"web_search": (["query"], 'search "{}"', False, False),
"web_fetch": (["url"], "fetch {}", True, False),
"list_dir": (["path"], "ls {}", True, False),
}
def format_tool_hints(tool_calls: list) -> str:
"""Format tool calls as concise hints with smart abbreviation."""
if not tool_calls:
return ""
hints = []
for name, count, example_tc in _group_consecutive(tool_calls):
fmt = _TOOL_FORMATS.get(name)
if fmt:
hint = _fmt_known(example_tc, fmt)
elif name.startswith("mcp_"):
hint = _fmt_mcp(example_tc)
else:
hint = _fmt_fallback(example_tc)
if count > 1:
hint = f"{hint} \u00d7 {count}"
hints.append(hint)
return ", ".join(hints)
def _get_args(tc) -> dict:
"""Extract args dict from tc.arguments, handling list/dict/None/empty."""
if tc.arguments is None:
return {}
if isinstance(tc.arguments, list):
return tc.arguments[0] if tc.arguments else {}
if isinstance(tc.arguments, dict):
return tc.arguments
return {}
def _group_consecutive(calls: list) -> list[tuple[str, int, object]]:
"""Group consecutive calls to the same tool: [(name, count, first), ...]."""
groups: list[tuple[str, int, object]] = []
for tc in calls:
if groups and groups[-1][0] == tc.name:
groups[-1] = (groups[-1][0], groups[-1][1] + 1, groups[-1][2])
else:
groups.append((tc.name, 1, tc))
return groups
def _extract_arg(tc, key_args: list[str]) -> str | None:
"""Extract the first available value from preferred key names."""
args = _get_args(tc)
if not isinstance(args, dict):
return None
for key in key_args:
val = args.get(key)
if isinstance(val, str) and val:
return val
for val in args.values():
if isinstance(val, str) and val:
return val
return None
def _fmt_known(tc, fmt: tuple) -> str:
"""Format a registered tool using its template."""
val = _extract_arg(tc, fmt[0])
if val is None:
return tc.name
if fmt[2]: # is_path
val = abbreviate_path(val)
elif fmt[3]: # is_command
val = val[:40] + "\u2026" if len(val) > 40 else val
return fmt[1].format(val)
def _fmt_mcp(tc) -> str:
"""Format MCP tool as server::tool."""
name = tc.name
if "__" in name:
parts = name.split("__", 1)
server = parts[0].removeprefix("mcp_")
tool = parts[1]
else:
rest = name.removeprefix("mcp_")
parts = rest.split("_", 1)
server = parts[0] if parts else rest
tool = parts[1] if len(parts) > 1 else ""
if not tool:
return name
args = _get_args(tc)
val = next((v for v in args.values() if isinstance(v, str) and v), None)
if val is None:
return f"{server}::{tool}"
return f'{server}::{tool}("{abbreviate_path(val, 40)}")'
def _fmt_fallback(tc) -> str:
"""Original formatting logic for unregistered tools."""
args = _get_args(tc)
val = next(iter(args.values()), None) if isinstance(args, dict) else None
if not isinstance(val, str):
return tc.name
return f'{tc.name}("{abbreviate_path(val, 40)}")' if len(val) > 40 else f'{tc.name}("{val}")'
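The consecutive-grouping fold above is easy to sanity-check in isolation; here plain strings stand in for `ToolCallRequest` objects:

```python
# Minimal re-implementation of the fold in _group_consecutive, counting
# consecutive repeats of the same tool name.
def group_consecutive(names):
    groups = []
    for n in names:
        if groups and groups[-1][0] == n:
            groups[-1] = (n, groups[-1][1] + 1)
        else:
            groups.append((n, 1))
    return groups

# Interleaved calls break the run, so read_file appears in two groups:
print(group_consecutive(["read_file", "read_file", "grep", "grep", "read_file"]))
# [('read_file', 2), ('grep', 2), ('read_file', 1)]
```

This is why the tests below expect `× 2` folding for back-to-back calls but no folding once another tool interrupts the run.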


@@ -1,6 +1,6 @@
[project]
name = "nanobot-ai"
version = "0.1.4.post6"
version = "0.1.5"
description = "A lightweight personal AI assistant framework"
readme = { file = "README.md", content-type = "text/markdown" }
requires-python = ">=3.11"


@@ -458,6 +458,7 @@ async def test_runner_uses_raw_messages_when_context_governance_fails():
@pytest.mark.asyncio
async def test_runner_retries_empty_final_response_with_summary_prompt():
"""Empty responses get 2 silent retries before finalization kicks in."""
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
@@ -465,11 +466,11 @@ async def test_runner_retries_empty_final_response_with_summary_prompt():
async def chat_with_retry(*, messages, tools=None, **kwargs):
calls.append({"messages": messages, "tools": tools})
if len(calls) == 1:
if len(calls) <= 2:
return LLMResponse(
content=None,
tool_calls=[],
usage={"prompt_tokens": 10, "completion_tokens": 1},
usage={"prompt_tokens": 5, "completion_tokens": 1},
)
return LLMResponse(
content="final answer",
@@ -486,20 +487,23 @@ async def test_runner_retries_empty_final_response_with_summary_prompt():
initial_messages=[{"role": "user", "content": "do task"}],
tools=tools,
model="test-model",
max_iterations=1,
max_iterations=3,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
))
assert result.final_content == "final answer"
assert len(calls) == 2
assert calls[1]["tools"] is None
assert "Do not call any more tools" in calls[1]["messages"][-1]["content"]
# 2 silent retries (iterations 0,1) + finalization on iteration 1
assert len(calls) == 3
assert calls[0]["tools"] is not None
assert calls[1]["tools"] is not None
assert calls[2]["tools"] is None
assert result.usage["prompt_tokens"] == 13
assert result.usage["completion_tokens"] == 8
assert result.usage["completion_tokens"] == 9
@pytest.mark.asyncio
async def test_runner_uses_specific_message_after_empty_finalization_retry():
"""After silent retries + finalization all return empty, stop_reason is empty_final_response."""
from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
@@ -517,7 +521,7 @@ async def test_runner_uses_specific_message_after_empty_finalization_retry():
initial_messages=[{"role": "user", "content": "do task"}],
tools=tools,
model="test-model",
max_iterations=1,
max_iterations=3,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
))
@@ -525,6 +529,66 @@
assert result.stop_reason == "empty_final_response"
@pytest.mark.asyncio
async def test_runner_empty_response_does_not_break_tool_chain():
"""An empty intermediate response must not kill an ongoing tool chain.
Sequence: tool_call empty tool_call final text.
The runner should recover via silent retry and complete normally.
"""
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
call_count = 0
async def chat_with_retry(*, messages, tools=None, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return LLMResponse(
content=None,
tool_calls=[ToolCallRequest(id="tc1", name="read_file", arguments={"path": "a.txt"})],
usage={"prompt_tokens": 10, "completion_tokens": 5},
)
if call_count == 2:
return LLMResponse(content=None, tool_calls=[], usage={"prompt_tokens": 10, "completion_tokens": 1})
if call_count == 3:
return LLMResponse(
content=None,
tool_calls=[ToolCallRequest(id="tc2", name="read_file", arguments={"path": "b.txt"})],
usage={"prompt_tokens": 10, "completion_tokens": 5},
)
return LLMResponse(
content="Here are the results.",
tool_calls=[],
usage={"prompt_tokens": 10, "completion_tokens": 10},
)
provider.chat_with_retry = chat_with_retry
provider.chat_stream_with_retry = chat_with_retry
async def fake_tool(name, args, **kw):
return "file content"
tool_registry = MagicMock()
tool_registry.get_definitions.return_value = [{"type": "function", "function": {"name": "read_file"}}]
tool_registry.execute = AsyncMock(side_effect=fake_tool)
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "read both files"}],
tools=tool_registry,
model="test-model",
max_iterations=10,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
))
assert result.final_content == "Here are the results."
assert result.stop_reason == "completed"
assert call_count == 4
assert "read_file" in result.tools_used
def test_snip_history_drops_orphaned_tool_results_from_trimmed_slice(monkeypatch):
from nanobot.agent.runner import AgentRunSpec, AgentRunner


@@ -0,0 +1,252 @@
"""Tests for nanobot.agent.skills.SkillsLoader."""
from __future__ import annotations
import json
from pathlib import Path
import pytest
from nanobot.agent.skills import SkillsLoader
def _write_skill(
base: Path,
name: str,
*,
metadata_json: dict | None = None,
body: str = "# Skill\n",
) -> Path:
"""Create ``base / name / SKILL.md`` with optional nanobot metadata JSON."""
skill_dir = base / name
skill_dir.mkdir(parents=True)
lines = ["---"]
if metadata_json is not None:
payload = json.dumps({"nanobot": metadata_json}, separators=(",", ":"))
lines.append(f'metadata: {payload}')
lines.extend(["---", "", body])
path = skill_dir / "SKILL.md"
path.write_text("\n".join(lines), encoding="utf-8")
return path
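For reference, the fixture above produces a SKILL.md shaped like the following (the `ffmpeg` requirement here is an illustrative value, not one used by these tests):

```python
import json

# Mirror _write_skill's layout: frontmatter delimiters with a compact
# JSON "metadata" line, then a blank line and the markdown body.
payload = json.dumps({"nanobot": {"requires": {"bins": ["ffmpeg"]}}},
                     separators=(",", ":"))
skill_md = "\n".join(["---", f"metadata: {payload}", "---", "", "# Skill\n"])
print(skill_md)
```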
def test_list_skills_empty_when_skills_dir_missing(tmp_path: Path) -> None:
workspace = tmp_path / "ws"
workspace.mkdir()
builtin = tmp_path / "builtin"
builtin.mkdir()
loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
assert loader.list_skills(filter_unavailable=False) == []
def test_list_skills_empty_when_skills_dir_exists_but_empty(tmp_path: Path) -> None:
workspace = tmp_path / "ws"
(workspace / "skills").mkdir(parents=True)
builtin = tmp_path / "builtin"
builtin.mkdir()
loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
assert loader.list_skills(filter_unavailable=False) == []
def test_list_skills_workspace_entry_shape_and_source(tmp_path: Path) -> None:
workspace = tmp_path / "ws"
skills_root = workspace / "skills"
skills_root.mkdir(parents=True)
skill_path = _write_skill(skills_root, "alpha", body="# Alpha")
builtin = tmp_path / "builtin"
builtin.mkdir()
loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
entries = loader.list_skills(filter_unavailable=False)
assert entries == [
{"name": "alpha", "path": str(skill_path), "source": "workspace"},
]
def test_list_skills_skips_non_directories_and_missing_skill_md(tmp_path: Path) -> None:
workspace = tmp_path / "ws"
skills_root = workspace / "skills"
skills_root.mkdir(parents=True)
(skills_root / "not_a_dir.txt").write_text("x", encoding="utf-8")
(skills_root / "no_skill_md").mkdir()
ok_path = _write_skill(skills_root, "ok", body="# Ok")
builtin = tmp_path / "builtin"
builtin.mkdir()
loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
entries = loader.list_skills(filter_unavailable=False)
names = {entry["name"] for entry in entries}
assert names == {"ok"}
assert entries[0]["path"] == str(ok_path)
def test_list_skills_workspace_shadows_builtin_same_name(tmp_path: Path) -> None:
workspace = tmp_path / "ws"
ws_skills = workspace / "skills"
ws_skills.mkdir(parents=True)
ws_path = _write_skill(ws_skills, "dup", body="# Workspace wins")
builtin = tmp_path / "builtin"
_write_skill(builtin, "dup", body="# Builtin")
loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
entries = loader.list_skills(filter_unavailable=False)
assert len(entries) == 1
assert entries[0]["source"] == "workspace"
assert entries[0]["path"] == str(ws_path)
def test_list_skills_merges_workspace_and_builtin(tmp_path: Path) -> None:
workspace = tmp_path / "ws"
ws_skills = workspace / "skills"
ws_skills.mkdir(parents=True)
ws_path = _write_skill(ws_skills, "ws_only", body="# W")
builtin = tmp_path / "builtin"
bi_path = _write_skill(builtin, "bi_only", body="# B")
loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
entries = sorted(loader.list_skills(filter_unavailable=False), key=lambda item: item["name"])
assert entries == [
{"name": "bi_only", "path": str(bi_path), "source": "builtin"},
{"name": "ws_only", "path": str(ws_path), "source": "workspace"},
]
def test_list_skills_builtin_omitted_when_dir_missing(tmp_path: Path) -> None:
workspace = tmp_path / "ws"
ws_skills = workspace / "skills"
ws_skills.mkdir(parents=True)
ws_path = _write_skill(ws_skills, "solo", body="# S")
missing_builtin = tmp_path / "no_such_builtin"
loader = SkillsLoader(workspace, builtin_skills_dir=missing_builtin)
entries = loader.list_skills(filter_unavailable=False)
assert entries == [{"name": "solo", "path": str(ws_path), "source": "workspace"}]
def test_list_skills_filter_unavailable_excludes_unmet_bin_requirement(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
workspace = tmp_path / "ws"
skills_root = workspace / "skills"
skills_root.mkdir(parents=True)
_write_skill(
skills_root,
"needs_bin",
metadata_json={"requires": {"bins": ["nanobot_test_fake_binary"]}},
)
builtin = tmp_path / "builtin"
builtin.mkdir()
def fake_which(cmd: str) -> str | None:
if cmd == "nanobot_test_fake_binary":
return None
return "/usr/bin/true"
monkeypatch.setattr("nanobot.agent.skills.shutil.which", fake_which)
loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
assert loader.list_skills(filter_unavailable=True) == []
def test_list_skills_filter_unavailable_includes_when_bin_requirement_met(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
workspace = tmp_path / "ws"
skills_root = workspace / "skills"
skills_root.mkdir(parents=True)
skill_path = _write_skill(
skills_root,
"has_bin",
metadata_json={"requires": {"bins": ["nanobot_test_fake_binary"]}},
)
builtin = tmp_path / "builtin"
builtin.mkdir()
def fake_which(cmd: str) -> str | None:
if cmd == "nanobot_test_fake_binary":
return "/fake/nanobot_test_fake_binary"
return None
monkeypatch.setattr("nanobot.agent.skills.shutil.which", fake_which)
loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
entries = loader.list_skills(filter_unavailable=True)
assert entries == [
{"name": "has_bin", "path": str(skill_path), "source": "workspace"},
]
def test_list_skills_filter_unavailable_false_keeps_unmet_requirements(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
workspace = tmp_path / "ws"
skills_root = workspace / "skills"
skills_root.mkdir(parents=True)
skill_path = _write_skill(
skills_root,
"blocked",
metadata_json={"requires": {"bins": ["nanobot_test_fake_binary"]}},
)
builtin = tmp_path / "builtin"
builtin.mkdir()
monkeypatch.setattr("nanobot.agent.skills.shutil.which", lambda _cmd: None)
loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
entries = loader.list_skills(filter_unavailable=False)
assert entries == [
{"name": "blocked", "path": str(skill_path), "source": "workspace"},
]
def test_list_skills_filter_unavailable_excludes_unmet_env_requirement(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
workspace = tmp_path / "ws"
skills_root = workspace / "skills"
skills_root.mkdir(parents=True)
_write_skill(
skills_root,
"needs_env",
metadata_json={"requires": {"env": ["NANOBOT_SKILLS_TEST_ENV_VAR"]}},
)
builtin = tmp_path / "builtin"
builtin.mkdir()
monkeypatch.delenv("NANOBOT_SKILLS_TEST_ENV_VAR", raising=False)
loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
assert loader.list_skills(filter_unavailable=True) == []
def test_list_skills_openclaw_metadata_parsed_for_requirements(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
workspace = tmp_path / "ws"
skills_root = workspace / "skills"
skills_root.mkdir(parents=True)
skill_dir = skills_root / "openclaw_skill"
skill_dir.mkdir(parents=True)
skill_path = skill_dir / "SKILL.md"
oc_payload = json.dumps({"openclaw": {"requires": {"bins": ["nanobot_oc_bin"]}}}, separators=(",", ":"))
skill_path.write_text(
"\n".join(["---", f"metadata: {oc_payload}", "---", "", "# OC"]),
encoding="utf-8",
)
builtin = tmp_path / "builtin"
builtin.mkdir()
monkeypatch.setattr("nanobot.agent.skills.shutil.which", lambda _cmd: None)
loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
assert loader.list_skills(filter_unavailable=True) == []
monkeypatch.setattr(
"nanobot.agent.skills.shutil.which",
lambda cmd: "/x" if cmd == "nanobot_oc_bin" else None,
)
entries = loader.list_skills(filter_unavailable=True)
assert entries == [
{"name": "openclaw_skill", "path": str(skill_path), "source": "workspace"},
]


@@ -0,0 +1,202 @@
"""Tests for tool hint formatting (nanobot.utils.tool_hints)."""
from nanobot.utils.tool_hints import format_tool_hints
from nanobot.providers.base import ToolCallRequest
def _tc(name: str, args) -> ToolCallRequest:
return ToolCallRequest(id="c1", name=name, arguments=args)
def _hint(calls):
"""Shortcut for format_tool_hints."""
return format_tool_hints(calls)
class TestToolHintKnownTools:
"""Test registered tool types produce correct formatted output."""
def test_read_file_short_path(self):
result = _hint([_tc("read_file", {"path": "foo.txt"})])
assert result == 'read foo.txt'
def test_read_file_long_path(self):
result = _hint([_tc("read_file", {"path": "/home/user/.local/share/uv/tools/nanobot/agent/loop.py"})])
assert "loop.py" in result
assert "read " in result
def test_write_file_shows_path_not_content(self):
result = _hint([_tc("write_file", {"path": "docs/api.md", "content": "# API Reference\n\nLong content..."})])
assert result == "write docs/api.md"
def test_edit_shows_path(self):
result = _hint([_tc("edit", {"file_path": "src/main.py", "old_string": "x", "new_string": "y"})])
assert "main.py" in result
assert "edit " in result
def test_glob_shows_pattern(self):
result = _hint([_tc("glob", {"pattern": "**/*.py", "path": "src"})])
assert result == 'glob "**/*.py"'
def test_grep_shows_pattern(self):
result = _hint([_tc("grep", {"pattern": "TODO|FIXME", "path": "src"})])
assert result == 'grep "TODO|FIXME"'
def test_exec_shows_command(self):
result = _hint([_tc("exec", {"command": "npm install typescript"})])
assert result == "$ npm install typescript"
def test_exec_truncates_long_command(self):
cmd = "cd /very/long/path && cat file && echo done && sleep 1 && ls -la"
result = _hint([_tc("exec", {"command": cmd})])
assert result.startswith("$ ")
assert len(result) <= 50 # reasonable limit
def test_web_search(self):
result = _hint([_tc("web_search", {"query": "Claude 4 vs GPT-4"})])
assert result == 'search "Claude 4 vs GPT-4"'
def test_web_fetch(self):
result = _hint([_tc("web_fetch", {"url": "https://example.com/page"})])
assert result == "fetch https://example.com/page"
class TestToolHintMCP:
"""Test MCP tools are abbreviated to server::tool format."""
def test_mcp_standard_format(self):
result = _hint([_tc("mcp_4_5v_mcp__analyze_image", {"imageSource": "https://img.jpg", "prompt": "describe"})])
assert "4_5v" in result
assert "analyze_image" in result
def test_mcp_simple_name(self):
result = _hint([_tc("mcp_github__create_issue", {"title": "Bug fix"})])
assert "github" in result
assert "create_issue" in result
class TestToolHintFallback:
"""Test unknown tools fall back to original behavior."""
def test_unknown_tool_with_string_arg(self):
result = _hint([_tc("custom_tool", {"data": "hello world"})])
assert result == 'custom_tool("hello world")'
def test_unknown_tool_with_long_arg_truncates(self):
long_val = "a" * 60
result = _hint([_tc("custom_tool", {"data": long_val})])
assert len(result) < 80
assert "\u2026" in result
def test_unknown_tool_no_string_arg(self):
result = _hint([_tc("custom_tool", {"count": 42})])
assert result == "custom_tool"
def test_empty_tool_calls(self):
result = _hint([])
assert result == ""
class TestToolHintFolding:
"""Test consecutive same-tool calls are folded."""
def test_single_call_no_fold(self):
calls = [_tc("grep", {"pattern": "*.py"})]
result = _hint(calls)
assert "\u00d7" not in result
def test_two_consecutive_same_folded(self):
calls = [
_tc("grep", {"pattern": "*.py"}),
_tc("grep", {"pattern": "*.ts"}),
]
result = _hint(calls)
assert "\u00d7 2" in result
def test_three_consecutive_same_folded(self):
calls = [
_tc("read_file", {"path": "a.py"}),
_tc("read_file", {"path": "b.py"}),
_tc("read_file", {"path": "c.py"}),
]
result = _hint(calls)
assert "\u00d7 3" in result
def test_different_tools_not_folded(self):
calls = [
_tc("grep", {"pattern": "TODO"}),
_tc("read_file", {"path": "a.py"}),
]
result = _hint(calls)
assert "\u00d7" not in result
def test_interleaved_same_tools_not_folded(self):
calls = [
_tc("grep", {"pattern": "a"}),
_tc("read_file", {"path": "f.py"}),
_tc("grep", {"pattern": "b"}),
]
result = _hint(calls)
assert "\u00d7" not in result
class TestToolHintMultipleCalls:
"""Test multiple different tool calls are comma-separated."""
def test_two_different_tools(self):
calls = [
_tc("grep", {"pattern": "TODO"}),
_tc("read_file", {"path": "main.py"}),
]
result = _hint(calls)
assert 'grep "TODO"' in result
assert "read main.py" in result
assert ", " in result
class TestToolHintEdgeCases:
"""Test edge cases and defensive handling (G1, G2)."""
def test_known_tool_empty_list_args(self):
"""C1/G1: Empty list arguments should not crash."""
result = _hint([_tc("read_file", [])])
assert result == "read_file"
def test_known_tool_none_args(self):
"""G2: None arguments should not crash."""
result = _hint([_tc("read_file", None)])
assert result == "read_file"
def test_fallback_empty_list_args(self):
"""C1: Empty list args in fallback should not crash."""
result = _hint([_tc("custom_tool", [])])
assert result == "custom_tool"
def test_fallback_none_args(self):
"""G2: None args in fallback should not crash."""
result = _hint([_tc("custom_tool", None)])
assert result == "custom_tool"
def test_list_dir_registered(self):
"""S2: list_dir should use 'ls' format."""
result = _hint([_tc("list_dir", {"path": "/tmp"})])
assert result == "ls /tmp"
class TestToolHintMixedFolding:
"""G4: Mixed folding groups with interleaved same-tool segments."""
def test_read_read_grep_grep_read(self):
"""read×2, grep×2, read — should produce two separate groups."""
calls = [
_tc("read_file", {"path": "a.py"}),
_tc("read_file", {"path": "b.py"}),
_tc("grep", {"pattern": "x"}),
_tc("grep", {"pattern": "y"}),
_tc("read_file", {"path": "c.py"}),
]
result = _hint(calls)
assert "\u00d7 2" in result
# Should have 3 groups: read×2, grep×2, read
parts = result.split(", ")
assert len(parts) == 3


@@ -1,5 +1,6 @@
from email.message import EmailMessage
from datetime import date
from pathlib import Path
import imaplib
import pytest
@@ -650,3 +651,224 @@ def test_check_authentication_results_method() -> None:
spf, dkim = EmailChannel._check_authentication_results(parsed)
assert spf is False
assert dkim is True
# ---------------------------------------------------------------------------
# Attachment extraction tests
# ---------------------------------------------------------------------------
def _make_raw_email_with_attachment(
from_addr: str = "alice@example.com",
subject: str = "With attachment",
body: str = "See attached.",
attachment_name: str = "doc.pdf",
attachment_content: bytes = b"%PDF-1.4 fake pdf content",
attachment_mime: str = "application/pdf",
auth_results: str | None = None,
) -> bytes:
msg = EmailMessage()
msg["From"] = from_addr
msg["To"] = "bot@example.com"
msg["Subject"] = subject
msg["Message-ID"] = "<m1@example.com>"
if auth_results:
msg["Authentication-Results"] = auth_results
msg.set_content(body)
maintype, subtype = attachment_mime.split("/", 1)
msg.add_attachment(
attachment_content,
maintype=maintype,
subtype=subtype,
filename=attachment_name,
)
return msg.as_bytes()
def test_extract_attachments_saves_pdf(tmp_path, monkeypatch) -> None:
"""PDF attachment is saved to media dir and path returned in media list."""
monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path)
raw = _make_raw_email_with_attachment()
fake = _make_fake_imap(raw)
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
cfg = _make_config(allowed_attachment_types=["application/pdf"], verify_dkim=False, verify_spf=False)
channel = EmailChannel(cfg, MessageBus())
items = channel._fetch_new_messages()
assert len(items) == 1
assert len(items[0]["media"]) == 1
saved_path = Path(items[0]["media"][0])
assert saved_path.exists()
assert saved_path.read_bytes() == b"%PDF-1.4 fake pdf content"
assert "500_doc.pdf" in saved_path.name
assert "[attachment:" in items[0]["content"]
def test_extract_attachments_disabled_by_default(monkeypatch) -> None:
"""With no allowed_attachment_types (default), no attachments are extracted."""
raw = _make_raw_email_with_attachment()
fake = _make_fake_imap(raw)
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
cfg = _make_config(verify_dkim=False, verify_spf=False)
assert cfg.allowed_attachment_types == []
channel = EmailChannel(cfg, MessageBus())
items = channel._fetch_new_messages()
assert len(items) == 1
assert items[0]["media"] == []
assert "[attachment:" not in items[0]["content"]
def test_extract_attachments_mime_type_filter(tmp_path, monkeypatch) -> None:
"""Non-allowed MIME types are skipped."""
monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path)
raw = _make_raw_email_with_attachment(
attachment_name="image.png",
attachment_content=b"\x89PNG fake",
attachment_mime="image/png",
)
fake = _make_fake_imap(raw)
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
cfg = _make_config(
allowed_attachment_types=["application/pdf"],
verify_dkim=False,
verify_spf=False,
)
channel = EmailChannel(cfg, MessageBus())
items = channel._fetch_new_messages()
assert len(items) == 1
assert items[0]["media"] == []
def test_extract_attachments_empty_allowed_types_rejects_all(tmp_path, monkeypatch) -> None:
"""Empty allowed_attachment_types means no types are accepted."""
monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path)
raw = _make_raw_email_with_attachment(
attachment_name="image.png",
attachment_content=b"\x89PNG fake",
attachment_mime="image/png",
)
fake = _make_fake_imap(raw)
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
cfg = _make_config(
allowed_attachment_types=[],
verify_dkim=False,
verify_spf=False,
)
channel = EmailChannel(cfg, MessageBus())
items = channel._fetch_new_messages()
assert len(items) == 1
assert items[0]["media"] == []
def test_extract_attachments_wildcard_pattern(tmp_path, monkeypatch) -> None:
"""Glob patterns like 'image/*' match attachment MIME types."""
monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path)
raw = _make_raw_email_with_attachment(
attachment_name="photo.jpg",
attachment_content=b"\xff\xd8\xff fake jpeg",
attachment_mime="image/jpeg",
)
fake = _make_fake_imap(raw)
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
cfg = _make_config(
allowed_attachment_types=["image/*"],
verify_dkim=False,
verify_spf=False,
)
channel = EmailChannel(cfg, MessageBus())
items = channel._fetch_new_messages()
assert len(items) == 1
assert len(items[0]["media"]) == 1
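The wildcard behaviour exercised here (`image/*` matching `image/jpeg`, an empty allow-list rejecting everything) behaves like `fnmatch` over MIME strings. A hedged sketch; the channel's real helper may be implemented differently:

```python
from fnmatch import fnmatch

def mime_allowed(mime, allowed_patterns):
    # An empty allow-list rejects everything, matching the test above
    # for empty allowed_attachment_types.
    return any(fnmatch(mime, pat) for pat in allowed_patterns)

print(mime_allowed("image/jpeg", ["image/*"]))         # True
print(mime_allowed("image/png", ["application/pdf"]))  # False
print(mime_allowed("image/png", []))                   # False
```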
def test_extract_attachments_size_limit(tmp_path, monkeypatch) -> None:
"""Attachments exceeding max_attachment_size are skipped."""
monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path)
raw = _make_raw_email_with_attachment(
attachment_content=b"x" * 1000,
)
fake = _make_fake_imap(raw)
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
cfg = _make_config(
allowed_attachment_types=["*"],
max_attachment_size=500,
verify_dkim=False,
verify_spf=False,
)
channel = EmailChannel(cfg, MessageBus())
items = channel._fetch_new_messages()
assert len(items) == 1
assert items[0]["media"] == []
def test_extract_attachments_max_count(tmp_path, monkeypatch) -> None:
"""Only max_attachments_per_email are saved."""
monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path)
# Build email with 3 attachments
msg = EmailMessage()
msg["From"] = "alice@example.com"
msg["To"] = "bot@example.com"
msg["Subject"] = "Many attachments"
msg["Message-ID"] = "<m1@example.com>"
msg.set_content("See attached.")
for i in range(3):
msg.add_attachment(
f"content {i}".encode(),
maintype="application",
subtype="pdf",
filename=f"doc{i}.pdf",
)
raw = msg.as_bytes()
fake = _make_fake_imap(raw)
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
cfg = _make_config(
allowed_attachment_types=["*"],
max_attachments_per_email=2,
verify_dkim=False,
verify_spf=False,
)
channel = EmailChannel(cfg, MessageBus())
items = channel._fetch_new_messages()
assert len(items) == 1
assert len(items[0]["media"]) == 2
def test_extract_attachments_sanitizes_filename(tmp_path, monkeypatch) -> None:
"""Path traversal in filenames is neutralized."""
monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path)
raw = _make_raw_email_with_attachment(
attachment_name="../../../etc/passwd",
)
fake = _make_fake_imap(raw)
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
cfg = _make_config(allowed_attachment_types=["*"], verify_dkim=False, verify_spf=False)
channel = EmailChannel(cfg, MessageBus())
items = channel._fetch_new_messages()
assert len(items) == 1
assert len(items[0]["media"]) == 1
saved_path = Path(items[0]["media"][0])
# File must be inside the media dir, not escaped via path traversal
assert saved_path.parent == tmp_path
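A minimal sanitizer consistent with this traversal test; the function name and approach below are illustrative, not the channel's actual implementation:

```python
from pathlib import Path

def safe_attachment_name(filename):
    # Drop directory components, then strip any separators that survive
    # (e.g. from Windows-style names processed on POSIX).
    name = Path(filename).name.replace("/", "_").replace("\\", "_")
    return name or "attachment"

print(safe_attachment_name("../../../etc/passwd"))  # passwd
```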


@@ -0,0 +1,62 @@
"""Tests for Feishu _is_bot_mentioned logic."""
from types import SimpleNamespace
import pytest
from nanobot.channels.feishu import FeishuChannel
def _make_channel(bot_open_id: str | None = None) -> FeishuChannel:
config = SimpleNamespace(
app_id="test_id",
app_secret="test_secret",
verification_token="",
event_encrypt_key="",
group_policy="mention",
)
ch = FeishuChannel.__new__(FeishuChannel)
ch.config = config
ch._bot_open_id = bot_open_id
return ch
def _make_message(mentions=None, content="hello"):
return SimpleNamespace(content=content, mentions=mentions)
def _make_mention(open_id: str, user_id: str | None = None):
mid = SimpleNamespace(open_id=open_id, user_id=user_id)
return SimpleNamespace(id=mid)
class TestIsBotMentioned:
def test_exact_match_with_bot_open_id(self):
ch = _make_channel(bot_open_id="ou_bot123")
msg = _make_message(mentions=[_make_mention("ou_bot123")])
assert ch._is_bot_mentioned(msg) is True
def test_no_match_different_bot(self):
ch = _make_channel(bot_open_id="ou_bot123")
msg = _make_message(mentions=[_make_mention("ou_other_bot")])
assert ch._is_bot_mentioned(msg) is False
def test_at_all_always_matches(self):
ch = _make_channel(bot_open_id="ou_bot123")
msg = _make_message(content="@_all hello")
assert ch._is_bot_mentioned(msg) is True
def test_fallback_heuristic_when_no_bot_open_id(self):
ch = _make_channel(bot_open_id=None)
msg = _make_message(mentions=[_make_mention("ou_some_bot", user_id=None)])
assert ch._is_bot_mentioned(msg) is True
def test_fallback_ignores_user_mentions(self):
ch = _make_channel(bot_open_id=None)
msg = _make_message(mentions=[_make_mention("ou_user", user_id="u_12345")])
assert ch._is_bot_mentioned(msg) is False
def test_no_mentions_returns_false(self):
ch = _make_channel(bot_open_id="ou_bot123")
msg = _make_message(mentions=None)
assert ch._is_bot_mentioned(msg) is False

View File

@ -0,0 +1,59 @@
"""Tests for FeishuChannel._resolve_mentions."""
from types import SimpleNamespace
from nanobot.channels.feishu import FeishuChannel
def _mention(key: str, name: str, open_id: str = "", user_id: str = ""):
"""Build a mock MentionEvent-like object."""
id_obj = SimpleNamespace(open_id=open_id, user_id=user_id) if (open_id or user_id) else None
return SimpleNamespace(key=key, name=name, id=id_obj)
class TestResolveMentions:
def test_single_mention_replaced(self):
text = "hello @_user_1 how are you"
mentions = [_mention("@_user_1", "Alice", open_id="ou_abc123")]
result = FeishuChannel._resolve_mentions(text, mentions)
assert "@Alice (ou_abc123)" in result
assert "@_user_1" not in result
def test_mention_with_both_ids(self):
text = "@_user_1 said hi"
mentions = [_mention("@_user_1", "Bob", open_id="ou_abc", user_id="uid_456")]
result = FeishuChannel._resolve_mentions(text, mentions)
assert "@Bob (ou_abc, user id: uid_456)" in result
def test_mention_no_id_skipped(self):
"""When mention has no id object, the placeholder is left unchanged."""
text = "@_user_1 said hi"
mentions = [SimpleNamespace(key="@_user_1", name="Charlie", id=None)]
result = FeishuChannel._resolve_mentions(text, mentions)
assert result == "@_user_1 said hi"
def test_multiple_mentions(self):
text = "@_user_1 and @_user_2 are here"
mentions = [
_mention("@_user_1", "Alice", open_id="ou_a"),
_mention("@_user_2", "Bob", open_id="ou_b"),
]
result = FeishuChannel._resolve_mentions(text, mentions)
assert "@Alice (ou_a)" in result
assert "@Bob (ou_b)" in result
assert "@_user_1" not in result
assert "@_user_2" not in result
def test_no_mentions_returns_text(self):
assert FeishuChannel._resolve_mentions("hello world", None) == "hello world"
assert FeishuChannel._resolve_mentions("hello world", []) == "hello world"
def test_empty_text_returns_empty(self):
mentions = [_mention("@_user_1", "Alice", open_id="ou_a")]
assert FeishuChannel._resolve_mentions("", mentions) == ""
def test_mention_key_not_in_text_skipped(self):
text = "hello world"
mentions = [_mention("@_user_99", "Ghost", open_id="ou_ghost")]
result = FeishuChannel._resolve_mentions(text, mentions)
assert result == "hello world"
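Taken together, these cases describe a straightforward placeholder-substitution pass. A minimal sketch consistent with the assertions (the real static method may differ in details):

```python
def resolve_mentions(text: str, mentions) -> str:
    """Replace Feishu placeholders like "@_user_1" with a readable
    "@Name (open_id[, user id: ...])" form, skipping mentions that
    carry no id object or whose key is absent from the text."""
    if not text or not mentions:
        return text
    for m in mentions:
        mid = getattr(m, "id", None)
        if mid is None or m.key not in text:
            continue
        ids = mid.open_id or ""
        if getattr(mid, "user_id", ""):
            ids = f"{ids}, user id: {mid.user_id}"
        text = text.replace(m.key, f"@{m.name} ({ids})")
    return text
```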

View File

@ -127,6 +127,79 @@ async def test_tool_hint_multiple_tools_in_one_message(mock_feishu_channel):
@mark.asyncio
async def test_tool_hint_new_format_basic(mock_feishu_channel):
"""New format hints (read path, grep "pattern") should parse correctly."""
msg = OutboundMessage(
channel="feishu",
chat_id="oc_123456",
content='read src/main.py, grep "TODO"',
metadata={"_tool_hint": True}
)
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
await mock_feishu_channel.send(msg)
content = json.loads(mock_send.call_args[0][3])
md = content["elements"][0]["content"]
assert "read src/main.py" in md
assert 'grep "TODO"' in md
@mark.asyncio
async def test_tool_hint_new_format_with_comma_in_quotes(mock_feishu_channel):
"""Commas inside quoted arguments must not cause incorrect line splits."""
msg = OutboundMessage(
channel="feishu",
chat_id="oc_123456",
content='grep "hello, world", $ echo test',
metadata={"_tool_hint": True}
)
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
await mock_feishu_channel.send(msg)
content = json.loads(mock_send.call_args[0][3])
md = content["elements"][0]["content"]
# The comma inside quotes should NOT cause a line break
assert 'grep "hello, world"' in md
assert "$ echo test" in md
@mark.asyncio
async def test_tool_hint_new_format_with_folding(mock_feishu_channel):
"""Folded calls (× N) should display on separate lines."""
msg = OutboundMessage(
channel="feishu",
chat_id="oc_123456",
content='read path × 3, grep "pattern"',
metadata={"_tool_hint": True}
)
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
await mock_feishu_channel.send(msg)
content = json.loads(mock_send.call_args[0][3])
md = content["elements"][0]["content"]
assert "\u00d7 3" in md
assert 'grep "pattern"' in md
@mark.asyncio
async def test_tool_hint_new_format_mcp(mock_feishu_channel):
"""MCP tool format (server::tool) should parse correctly."""
msg = OutboundMessage(
channel="feishu",
chat_id="oc_123456",
content='4_5v::analyze_image("photo.jpg")',
metadata={"_tool_hint": True}
)
with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send:
await mock_feishu_channel.send(msg)
content = json.loads(mock_send.call_args[0][3])
md = content["elements"][0]["content"]
assert "4_5v::analyze_image" in md
@mark.asyncio
async def test_tool_hint_keeps_commas_inside_arguments(mock_feishu_channel):
"""Commas inside a single tool argument must not be split onto a new line."""
msg = OutboundMessage(

View File

@ -385,6 +385,32 @@ async def test_send_delta_stream_end_treats_not_modified_as_success() -> None:
assert "123" not in channel._stream_bufs
@pytest.mark.asyncio
async def test_send_delta_stream_end_splits_oversized_reply() -> None:
"""Final streamed reply exceeding Telegram limit is split into chunks."""
from nanobot.channels.telegram import TELEGRAM_MAX_MESSAGE_LEN
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
channel._app.bot.edit_message_text = AsyncMock()
channel._app.bot.send_message = AsyncMock(return_value=SimpleNamespace(message_id=99))
oversized = "x" * (TELEGRAM_MAX_MESSAGE_LEN + 500)
channel._stream_bufs["123"] = _StreamBuf(text=oversized, message_id=7, last_edit=0.0)
await channel.send_delta("123", "", {"_stream_end": True})
channel._app.bot.edit_message_text.assert_called_once()
edit_text = channel._app.bot.edit_message_text.call_args.kwargs.get("text", "")
assert len(edit_text) <= TELEGRAM_MAX_MESSAGE_LEN
channel._app.bot.send_message.assert_called_once()
assert "123" not in channel._stream_bufs
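The oversized-reply split this test exercises reduces to greedy fixed-size chunking, where every chunk stays at or under the platform limit. A sketch under that assumption (the channel's actual splitter may prefer breaking at newlines):

```python
def split_message(text: str, limit: int) -> list[str]:
    """Greedily split text into chunks no longer than `limit`.

    Sketch only: a production splitter would likely break on
    newlines or word boundaries rather than mid-character-run.
    """
    return [text[i:i + limit] for i in range(0, len(text), limit)] or [""]
```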
@pytest.mark.asyncio
async def test_send_delta_new_stream_id_replaces_stale_buffer() -> None:
channel = TelegramChannel(
@ -424,6 +450,23 @@ async def test_send_delta_incremental_edit_treats_not_modified_as_success() -> N
assert channel._stream_bufs["123"].last_edit > 0.0
@pytest.mark.asyncio
async def test_send_delta_initial_send_keeps_message_in_thread() -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
await channel.send_delta(
"123",
"hello",
{"_stream_delta": True, "_stream_id": "s:0", "message_thread_id": 42},
)
assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42
def test_derive_topic_session_key_uses_thread_id() -> None:
message = SimpleNamespace(
chat=SimpleNamespace(type="supergroup"),
@ -434,6 +477,27 @@ def test_derive_topic_session_key_uses_thread_id() -> None:
assert TelegramChannel._derive_topic_session_key(message) == "telegram:-100123:topic:42"
def test_derive_topic_session_key_private_dm_thread() -> None:
"""Private DM threads (Telegram Threaded Mode) must get their own session key."""
message = SimpleNamespace(
chat=SimpleNamespace(type="private"),
chat_id=999,
message_thread_id=7,
)
assert TelegramChannel._derive_topic_session_key(message) == "telegram:999:topic:7"
def test_derive_topic_session_key_none_without_thread() -> None:
"""No thread id → no topic session key, regardless of chat type."""
for chat_type in ("private", "supergroup", "group"):
message = SimpleNamespace(
chat=SimpleNamespace(type=chat_type),
chat_id=123,
message_thread_id=None,
)
assert TelegramChannel._derive_topic_session_key(message) is None
def test_get_extension_falls_back_to_original_filename() -> None:
channel = TelegramChannel(TelegramConfig(), MessageBus())

View File

@ -163,6 +163,107 @@ async def test_group_policy_mention_accepts_mentioned_group_message():
assert kwargs["sender_id"] == "user"
@pytest.mark.asyncio
async def test_sender_id_prefers_phone_jid_over_lid():
"""sender_id should resolve to phone number when @s.whatsapp.net JID is present."""
ch = WhatsAppChannel({"enabled": True}, MagicMock())
ch._handle_message = AsyncMock()
await ch._handle_bridge_message(
json.dumps({
"type": "message",
"id": "lid1",
"sender": "ABC123@lid.whatsapp.net",
"pn": "5551234@s.whatsapp.net",
"content": "hi",
"timestamp": 1,
})
)
kwargs = ch._handle_message.await_args.kwargs
assert kwargs["sender_id"] == "5551234"
@pytest.mark.asyncio
async def test_lid_to_phone_cache_resolves_lid_only_messages():
"""When only LID is present, a cached LID→phone mapping should be used."""
ch = WhatsAppChannel({"enabled": True}, MagicMock())
ch._handle_message = AsyncMock()
# First message: both phone and LID → builds cache
await ch._handle_bridge_message(
json.dumps({
"type": "message",
"id": "c1",
"sender": "LID99@lid.whatsapp.net",
"pn": "5559999@s.whatsapp.net",
"content": "first",
"timestamp": 1,
})
)
# Second message: only LID, no phone
await ch._handle_bridge_message(
json.dumps({
"type": "message",
"id": "c2",
"sender": "LID99@lid.whatsapp.net",
"pn": "",
"content": "second",
"timestamp": 2,
})
)
second_kwargs = ch._handle_message.await_args_list[1].kwargs
assert second_kwargs["sender_id"] == "5559999"
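The two sender-resolution tests above imply a small learn-then-lookup cache: a message carrying both JIDs teaches the LID→phone mapping, and a later LID-only message resolves through it. A sketch under those assumptions (class name illustrative):

```python
class LidPhoneCache:
    """Resolve a WhatsApp sender_id, preferring the phone JID and
    falling back to a learned LID-to-phone mapping (sketch)."""

    def __init__(self) -> None:
        self._map: dict[str, str] = {}

    def resolve(self, sender: str, pn: str) -> str:
        lid = sender.split("@")[0] if sender.endswith("@lid.whatsapp.net") else None
        phone = pn.split("@")[0] if pn.endswith("@s.whatsapp.net") else None
        if lid and phone:
            self._map[lid] = phone  # learn the mapping for LID-only messages
        if phone:
            return phone
        if lid and lid in self._map:
            return self._map[lid]
        return sender.split("@")[0]  # last resort: bare JID local part
```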
@pytest.mark.asyncio
async def test_voice_message_transcription_uses_media_path():
"""Voice messages are transcribed when media path is available."""
ch = WhatsAppChannel({"enabled": True}, MagicMock())
ch.transcription_provider = "openai"
ch.transcription_api_key = "sk-test"
ch._handle_message = AsyncMock()
ch.transcribe_audio = AsyncMock(return_value="Hello world")
await ch._handle_bridge_message(
json.dumps({
"type": "message",
"id": "v1",
"sender": "12345@s.whatsapp.net",
"pn": "",
"content": "[Voice Message]",
"timestamp": 1,
"media": ["/tmp/voice.ogg"],
})
)
ch.transcribe_audio.assert_awaited_once_with("/tmp/voice.ogg")
kwargs = ch._handle_message.await_args.kwargs
assert kwargs["content"].startswith("Hello world")
@pytest.mark.asyncio
async def test_voice_message_no_media_shows_not_available():
"""Voice messages without media produce a fallback placeholder."""
ch = WhatsAppChannel({"enabled": True}, MagicMock())
ch._handle_message = AsyncMock()
await ch._handle_bridge_message(
json.dumps({
"type": "message",
"id": "v2",
"sender": "12345@s.whatsapp.net",
"pn": "",
"content": "[Voice Message]",
"timestamp": 1,
})
)
kwargs = ch._handle_message.await_args.kwargs
assert kwargs["content"] == "[Voice Message: Audio not available]"
def test_load_or_create_bridge_token_persists_generated_secret(tmp_path):
token_path = tmp_path / "whatsapp-auth" / "bridge-token"

View File

@ -1,5 +1,7 @@
import asyncio
import json
import re
import shutil
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
@ -9,6 +11,7 @@ from typer.testing import CliRunner
from nanobot.bus.events import OutboundMessage
from nanobot.cli.commands import _make_provider, app
from nanobot.config.schema import Config
from nanobot.cron.types import CronJob, CronPayload
from nanobot.providers.openai_codex_provider import _strip_model_prefix
from nanobot.providers.registry import find_by_name
@ -19,11 +22,6 @@ class _StopGatewayError(RuntimeError):
pass
import shutil
import pytest
@pytest.fixture
def mock_paths():
"""Mock config/workspace paths for test isolation."""
@ -31,7 +29,6 @@ def mock_paths():
patch("nanobot.config.loader.save_config") as mock_sc, \
patch("nanobot.config.loader.load_config") as mock_lc, \
patch("nanobot.cli.commands.get_workspace_path") as mock_ws:
base_dir = Path("./test_onboard_data")
if base_dir.exists():
shutil.rmtree(base_dir)
@ -425,13 +422,13 @@ def mock_agent_runtime(tmp_path):
config.agents.defaults.workspace = str(tmp_path / "default-workspace")
with patch("nanobot.config.loader.load_config", return_value=config) as mock_load_config, \
patch("nanobot.config.loader.resolve_config_env_vars", side_effect=lambda c: c), \
patch("nanobot.cli.commands.sync_workspace_templates") as mock_sync_templates, \
patch("nanobot.cli.commands._make_provider", return_value=object()), \
patch("nanobot.cli.commands._print_agent_response") as mock_print_response, \
patch("nanobot.bus.queue.MessageBus"), \
patch("nanobot.cron.service.CronService"), \
patch("nanobot.agent.loop.AgentLoop") as mock_agent_loop_cls:
agent_loop = MagicMock()
agent_loop.channels_config = None
agent_loop.process_direct = AsyncMock(
@ -656,7 +653,9 @@ def test_agent_custom_config_workspace_does_not_migrate_legacy_cron(
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
monkeypatch.setattr(
"nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None
)
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])
@ -739,6 +738,7 @@ def _patch_cli_command_runtime(
set_config_path or (lambda _path: None),
)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.config.loader.resolve_config_env_vars", lambda c: c)
monkeypatch.setattr(
"nanobot.cli.commands.sync_workspace_templates",
sync_templates or (lambda _path: None),
@ -868,6 +868,115 @@ def test_gateway_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path:
assert seen["cron_store"] == config.workspace_path / "cron" / "jobs.json"
def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
monkeypatch, tmp_path: Path
) -> None:
config_file = tmp_path / "instance" / "config.json"
config_file.parent.mkdir(parents=True)
config_file.write_text("{}")
config = Config()
config.agents.defaults.workspace = str(tmp_path / "config-workspace")
provider = object()
bus = MagicMock()
bus.publish_outbound = AsyncMock()
seen: dict[str, object] = {}
monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None)
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: provider)
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus)
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
class _FakeCron:
def __init__(self, _store_path: Path) -> None:
self.on_job = None
seen["cron"] = self
class _FakeAgentLoop:
def __init__(self, *args, **kwargs) -> None:
self.model = "test-model"
self.tools = {}
async def process_direct(self, *_args, **_kwargs):
return OutboundMessage(
channel="telegram",
chat_id="user-1",
content="Time to stretch.",
)
async def close_mcp(self) -> None:
return None
async def run(self) -> None:
return None
def stop(self) -> None:
return None
class _StopAfterCronSetup:
def __init__(self, *_args, **_kwargs) -> None:
raise _StopGatewayError("stop")
async def _capture_evaluate_response(
response: str,
task_context: str,
provider_arg: object,
model: str,
) -> bool:
seen["response"] = response
seen["task_context"] = task_context
seen["provider"] = provider_arg
seen["model"] = model
return True
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron)
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _StopAfterCronSetup)
monkeypatch.setattr(
"nanobot.utils.evaluator.evaluate_response",
_capture_evaluate_response,
)
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
assert isinstance(result.exception, _StopGatewayError)
cron = seen["cron"]
assert isinstance(cron, _FakeCron)
assert cron.on_job is not None
job = CronJob(
id="cron-1",
name="stretch",
payload=CronPayload(
message="Remind me to stretch.",
deliver=True,
channel="telegram",
to="user-1",
),
)
response = asyncio.run(cron.on_job(job))
assert response == "Time to stretch."
assert seen["response"] == "Time to stretch."
assert seen["provider"] is provider
assert seen["model"] == "test-model"
assert seen["task_context"] == (
"[Scheduled Task] Timer finished.\n\n"
"Task 'stretch' has been triggered.\n"
"Scheduled instruction: Remind me to stretch."
)
bus.publish_outbound.assert_awaited_once_with(
OutboundMessage(
channel="telegram",
chat_id="user-1",
content="Time to stretch.",
)
)
def test_gateway_workspace_override_does_not_migrate_legacy_cron(
monkeypatch, tmp_path: Path
) -> None:

View File

@ -0,0 +1,44 @@
"""Regression tests for SafeFileHistory (issue #2846).
Surrogate characters in CLI input must not crash history file writes.
"""
from nanobot.cli.commands import SafeFileHistory
class TestSafeFileHistory:
def test_surrogate_replaced(self, tmp_path):
"""Surrogate pairs are replaced with U+FFFD, not crash."""
hist = SafeFileHistory(str(tmp_path / "history"))
hist.store_string("hello \udce9 world")
entries = list(hist.load_history_strings())
assert len(entries) == 1
assert "\udce9" not in entries[0]
assert "hello" in entries[0]
assert "world" in entries[0]
def test_normal_text_unchanged(self, tmp_path):
hist = SafeFileHistory(str(tmp_path / "history"))
hist.store_string("normal ascii text")
entries = list(hist.load_history_strings())
assert entries[0] == "normal ascii text"
def test_emoji_preserved(self, tmp_path):
hist = SafeFileHistory(str(tmp_path / "history"))
hist.store_string("hello 🐈 nanobot")
entries = list(hist.load_history_strings())
assert entries[0] == "hello 🐈 nanobot"
def test_mixed_unicode_preserved(self, tmp_path):
"""CJK + emoji + latin should all pass through cleanly."""
hist = SafeFileHistory(str(tmp_path / "history"))
hist.store_string("你好 hello こんにちは 🎉")
entries = list(hist.load_history_strings())
assert entries[0] == "你好 hello こんにちは 🎉"
def test_multiple_surrogates(self, tmp_path):
hist = SafeFileHistory(str(tmp_path / "history"))
hist.store_string("\udce9\udcf1\udcff")
entries = list(hist.load_history_strings())
assert len(entries) == 1
assert "\udce9" not in entries[0]
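The sanitization these regression tests demand amounts to a UTF-8 round trip with replacement: lone surrogates (which cannot be encoded and would otherwise raise `UnicodeEncodeError` on file write) are replaced, while valid text, including emoji and CJK, passes through untouched. A sketch of that core (the real `SafeFileHistory` wraps prompt_toolkit's `FileHistory`):

```python
def sanitize_history_line(text: str) -> str:
    """Drop unencodable lone surrogates via a UTF-8 round trip with
    errors="replace"; well-formed Unicode is preserved as-is."""
    return text.encode("utf-8", errors="replace").decode("utf-8")
```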

View File

@ -0,0 +1,82 @@
import json
import pytest
from nanobot.config.loader import (
_resolve_env_vars,
load_config,
resolve_config_env_vars,
save_config,
)
class TestResolveEnvVars:
def test_replaces_string_value(self, monkeypatch):
monkeypatch.setenv("MY_SECRET", "hunter2")
assert _resolve_env_vars("${MY_SECRET}") == "hunter2"
def test_partial_replacement(self, monkeypatch):
monkeypatch.setenv("HOST", "example.com")
assert _resolve_env_vars("https://${HOST}/api") == "https://example.com/api"
def test_multiple_vars_in_one_string(self, monkeypatch):
monkeypatch.setenv("USER", "alice")
monkeypatch.setenv("PASS", "secret")
assert _resolve_env_vars("${USER}:${PASS}") == "alice:secret"
def test_nested_dicts(self, monkeypatch):
monkeypatch.setenv("TOKEN", "abc123")
data = {"channels": {"telegram": {"token": "${TOKEN}"}}}
result = _resolve_env_vars(data)
assert result["channels"]["telegram"]["token"] == "abc123"
def test_lists(self, monkeypatch):
monkeypatch.setenv("VAL", "x")
assert _resolve_env_vars(["${VAL}", "plain"]) == ["x", "plain"]
def test_ignores_non_strings(self):
assert _resolve_env_vars(42) == 42
assert _resolve_env_vars(True) is True
assert _resolve_env_vars(None) is None
assert _resolve_env_vars(3.14) == 3.14
def test_plain_strings_unchanged(self):
assert _resolve_env_vars("no vars here") == "no vars here"
def test_missing_var_raises(self):
with pytest.raises(ValueError, match="DOES_NOT_EXIST"):
_resolve_env_vars("${DOES_NOT_EXIST}")
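The contract pinned down by this class — recursive `${VAR}` substitution over strings, dicts, and lists, pass-through for other types, and a `ValueError` naming any missing variable — can be sketched as:

```python
import os
import re

_VAR = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")


def resolve_env_vars(value):
    """Recursively substitute ${VAR} templates from the environment.

    Sketch matching the tests above; the loader's private helper may
    differ in its accepted variable-name syntax.
    """
    if isinstance(value, str):
        def _sub(match: re.Match) -> str:
            name = match.group(1)
            if name not in os.environ:
                raise ValueError(f"environment variable {name} is not set")
            return os.environ[name]
        return _VAR.sub(_sub, value)
    if isinstance(value, dict):
        return {k: resolve_env_vars(v) for k, v in value.items()}
    if isinstance(value, list):
        return [resolve_env_vars(v) for v in value]
    return value  # ints, bools, None, floats pass through unchanged
```

Keeping substitution out of `load_config` and in a separate `resolve_config_env_vars` step is what lets `save_config` write the `${...}` templates back verbatim, as the second test class verifies.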
class TestResolveConfig:
def test_resolves_env_vars_in_config(self, tmp_path, monkeypatch):
monkeypatch.setenv("TEST_API_KEY", "resolved-key")
config_path = tmp_path / "config.json"
config_path.write_text(
json.dumps(
{"providers": {"groq": {"apiKey": "${TEST_API_KEY}"}}}
),
encoding="utf-8",
)
raw = load_config(config_path)
assert raw.providers.groq.api_key == "${TEST_API_KEY}"
resolved = resolve_config_env_vars(raw)
assert resolved.providers.groq.api_key == "resolved-key"
def test_save_preserves_templates(self, tmp_path, monkeypatch):
monkeypatch.setenv("MY_TOKEN", "real-token")
config_path = tmp_path / "config.json"
config_path.write_text(
json.dumps(
{"channels": {"telegram": {"token": "${MY_TOKEN}"}}}
),
encoding="utf-8",
)
raw = load_config(config_path)
save_config(raw, config_path)
saved = json.loads(config_path.read_text(encoding="utf-8"))
assert saved["channels"]["telegram"]["token"] == "${MY_TOKEN}"

View File

@ -299,7 +299,7 @@ def test_add_cron_job_defaults_to_tool_timezone(tmp_path) -> None:
tool = _make_tool_with_tz(tmp_path, "Asia/Shanghai")
tool.set_context("telegram", "chat-1")
result = tool._add_job("Morning standup", None, "0 8 * * *", None, None)
result = tool._add_job(None, "Morning standup", None, "0 8 * * *", None, None)
assert result.startswith("Created job")
job = tool._cron.list_jobs()[0]
@ -310,7 +310,7 @@ def test_add_at_job_uses_default_timezone_for_naive_datetime(tmp_path) -> None:
tool = _make_tool_with_tz(tmp_path, "Asia/Shanghai")
tool.set_context("telegram", "chat-1")
result = tool._add_job("Morning reminder", None, None, None, "2026-03-25T08:00:00")
result = tool._add_job(None, "Morning reminder", None, None, None, "2026-03-25T08:00:00")
assert result.startswith("Created job")
job = tool._cron.list_jobs()[0]
@ -322,7 +322,7 @@ def test_add_job_delivers_by_default(tmp_path) -> None:
tool = _make_tool(tmp_path)
tool.set_context("telegram", "chat-1")
result = tool._add_job("Morning standup", 60, None, None, None)
result = tool._add_job(None, "Morning standup", 60, None, None, None)
assert result.startswith("Created job")
job = tool._cron.list_jobs()[0]
@ -333,7 +333,7 @@ def test_add_job_can_disable_delivery(tmp_path) -> None:
tool = _make_tool(tmp_path)
tool.set_context("telegram", "chat-1")
result = tool._add_job("Background refresh", 60, None, None, None, deliver=False)
result = tool._add_job(None, "Background refresh", 60, None, None, None, deliver=False)
assert result.startswith("Created job")
job = tool._cron.list_jobs()[0]

View File

@ -307,3 +307,54 @@ async def test_openai_compat_stream_watchdog_returns_error_on_stall(monkeypatch)
assert result.finish_reason == "error"
assert result.content is not None
assert "stream stalled" in result.content
# ---------------------------------------------------------------------------
# Provider-specific thinking parameters (extra_body)
# ---------------------------------------------------------------------------
def _build_kwargs_for(provider_name: str, model: str, reasoning_effort=None):
spec = find_by_name(provider_name)
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
p = OpenAICompatProvider(api_key="k", default_model=model, spec=spec)
return p._build_kwargs(
messages=[{"role": "user", "content": "hi"}],
tools=None, model=model, max_tokens=1024, temperature=0.7,
reasoning_effort=reasoning_effort, tool_choice=None,
)
def test_dashscope_thinking_enabled_with_reasoning_effort() -> None:
kw = _build_kwargs_for("dashscope", "qwen3-plus", reasoning_effort="medium")
assert kw["extra_body"] == {"enable_thinking": True}
def test_dashscope_thinking_disabled_for_minimal() -> None:
kw = _build_kwargs_for("dashscope", "qwen3-plus", reasoning_effort="minimal")
assert kw["extra_body"] == {"enable_thinking": False}
def test_dashscope_no_extra_body_when_reasoning_effort_none() -> None:
kw = _build_kwargs_for("dashscope", "qwen-turbo", reasoning_effort=None)
assert "extra_body" not in kw
def test_volcengine_thinking_enabled() -> None:
kw = _build_kwargs_for("volcengine", "doubao-seed-2-0-pro", reasoning_effort="high")
assert kw["extra_body"] == {"thinking": {"type": "enabled"}}
def test_byteplus_thinking_disabled_for_minimal() -> None:
kw = _build_kwargs_for("byteplus", "doubao-seed-2-0-pro", reasoning_effort="minimal")
assert kw["extra_body"] == {"thinking": {"type": "disabled"}}
def test_byteplus_no_extra_body_when_reasoning_effort_none() -> None:
kw = _build_kwargs_for("byteplus", "doubao-seed-2-0-pro", reasoning_effort=None)
assert "extra_body" not in kw
def test_openai_no_thinking_extra_body() -> None:
"""Non-thinking providers should never get extra_body for thinking."""
kw = _build_kwargs_for("openai", "gpt-4o", reasoning_effort="medium")
assert "extra_body" not in kw

View File

@ -0,0 +1,81 @@
from types import SimpleNamespace
from nanobot.providers.anthropic_provider import AnthropicProvider
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
def _fake_response(
*,
status_code: int,
headers: dict[str, str] | None = None,
text: str = "",
) -> SimpleNamespace:
return SimpleNamespace(
status_code=status_code,
headers=headers or {},
text=text,
)
def test_openai_handle_error_extracts_structured_metadata() -> None:
class FakeStatusError(Exception):
pass
err = FakeStatusError("boom")
err.status_code = 409
err.response = _fake_response(
status_code=409,
headers={"retry-after-ms": "250", "x-should-retry": "false"},
text='{"error":{"type":"rate_limit_exceeded","code":"rate_limit_exceeded"}}',
)
err.body = {"error": {"type": "rate_limit_exceeded", "code": "rate_limit_exceeded"}}
response = OpenAICompatProvider._handle_error(err)
assert response.finish_reason == "error"
assert response.error_status_code == 409
assert response.error_type == "rate_limit_exceeded"
assert response.error_code == "rate_limit_exceeded"
assert response.error_retry_after_s == 0.25
assert response.error_should_retry is False
def test_openai_handle_error_marks_timeout_kind() -> None:
class FakeTimeoutError(Exception):
pass
response = OpenAICompatProvider._handle_error(FakeTimeoutError("timeout"))
assert response.finish_reason == "error"
assert response.error_kind == "timeout"
def test_anthropic_handle_error_extracts_structured_metadata() -> None:
class FakeStatusError(Exception):
pass
err = FakeStatusError("boom")
err.status_code = 408
err.response = _fake_response(
status_code=408,
headers={"retry-after": "1.5", "x-should-retry": "true"},
)
err.body = {"type": "error", "error": {"type": "rate_limit_error"}}
response = AnthropicProvider._handle_error(err)
assert response.finish_reason == "error"
assert response.error_status_code == 408
assert response.error_type == "rate_limit_error"
assert response.error_retry_after_s == 1.5
assert response.error_should_retry is True
def test_anthropic_handle_error_marks_connection_kind() -> None:
class FakeConnectionError(Exception):
pass
response = AnthropicProvider._handle_error(FakeConnectionError("connection"))
assert response.finish_reason == "error"
assert response.error_kind == "connection"

View File

@ -254,6 +254,14 @@ def test_extract_retry_after_from_headers_supports_numeric_and_http_date() -> No
) == 0.1
def test_extract_retry_after_from_headers_supports_retry_after_ms() -> None:
assert LLMProvider._extract_retry_after_from_headers({"retry-after-ms": "250"}) == 0.25
assert LLMProvider._extract_retry_after_from_headers({"Retry-After-Ms": "1000"}) == 1.0
assert LLMProvider._extract_retry_after_from_headers(
{"retry-after-ms": "500", "retry-after": "10"},
) == 0.5
@pytest.mark.asyncio
async def test_chat_with_retry_prefers_structured_retry_after_when_present(monkeypatch) -> None:
provider = ScriptedProvider([
@ -273,6 +281,153 @@ async def test_chat_with_retry_prefers_structured_retry_after_when_present(monke
assert delays == [9.0]
@pytest.mark.asyncio
async def test_chat_with_retry_retries_structured_status_code_without_keyword(monkeypatch) -> None:
provider = ScriptedProvider([
LLMResponse(
content="request failed",
finish_reason="error",
error_status_code=409,
),
LLMResponse(content="ok"),
])
delays: list[float] = []
async def _fake_sleep(delay: float) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert response.content == "ok"
assert provider.calls == 2
assert delays == [1]
@pytest.mark.asyncio
async def test_chat_with_retry_stops_on_429_quota_exhausted(monkeypatch) -> None:
provider = ScriptedProvider([
LLMResponse(
content='{"error":{"type":"insufficient_quota","code":"insufficient_quota"}}',
finish_reason="error",
error_status_code=429,
error_type="insufficient_quota",
error_code="insufficient_quota",
),
LLMResponse(content="ok"),
])
delays: list[float] = []
async def _fake_sleep(delay: float) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert response.finish_reason == "error"
assert provider.calls == 1
assert delays == []
@pytest.mark.asyncio
async def test_chat_with_retry_retries_429_transient_rate_limit(monkeypatch) -> None:
provider = ScriptedProvider([
LLMResponse(
content='{"error":{"type":"rate_limit_exceeded","code":"rate_limit_exceeded"}}',
finish_reason="error",
error_status_code=429,
error_type="rate_limit_exceeded",
error_code="rate_limit_exceeded",
error_retry_after_s=0.2,
),
LLMResponse(content="ok"),
])
delays: list[float] = []
async def _fake_sleep(delay: float) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert response.content == "ok"
assert provider.calls == 2
assert delays == [0.2]
@pytest.mark.asyncio
async def test_chat_with_retry_retries_structured_timeout_kind(monkeypatch) -> None:
provider = ScriptedProvider([
LLMResponse(
content="request failed",
finish_reason="error",
error_kind="timeout",
),
LLMResponse(content="ok"),
])
delays: list[float] = []
async def _fake_sleep(delay: float) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert response.content == "ok"
assert provider.calls == 2
assert delays == [1]
@pytest.mark.asyncio
async def test_chat_with_retry_structured_should_retry_false_disables_retry(monkeypatch) -> None:
provider = ScriptedProvider([
LLMResponse(
content="429 rate limit",
finish_reason="error",
error_should_retry=False,
),
])
delays: list[float] = []
async def _fake_sleep(delay: float) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert response.finish_reason == "error"
assert provider.calls == 1
assert delays == []
@pytest.mark.asyncio
async def test_chat_with_retry_prefers_structured_retry_after(monkeypatch) -> None:
provider = ScriptedProvider([
LLMResponse(
content="429 rate limit, retry after 99s",
finish_reason="error",
error_retry_after_s=0.2,
),
LLMResponse(content="ok"),
])
delays: list[float] = []
async def _fake_sleep(delay: float) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert response.content == "ok"
assert delays == [0.2]
@pytest.mark.asyncio
async def test_persistent_retry_aborts_after_ten_identical_transient_errors(monkeypatch) -> None:
provider = ScriptedProvider([
@@ -295,4 +450,3 @@ async def test_persistent_retry_aborts_after_ten_identical_transient_errors(monk
assert response.content == "429 rate limit"
assert provider.calls == 10
assert delays == [1, 2, 4, 4, 4, 4, 4, 4, 4]
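The delay sequence asserted above is a doubling backoff capped at 4 s. The real retry logic lives in `nanobot.providers.base` and is not shown in this diff, but the schedule can be sketched standalone:

```python
def backoff_delays(attempts: int, base: float = 1.0, cap: float = 4.0) -> list[float]:
    """Delays slept between attempts: doubling, capped (a sketch, not the real code)."""
    # Attempt i sleeps min(base * 2**i, cap); the final attempt does not sleep,
    # so 10 attempts produce 9 delays.
    return [min(base * (2 ** i), cap) for i in range(attempts - 1)]
```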

0
tests/tools/__init__.py Normal file
View File

View File

@@ -0,0 +1,38 @@
"""Tests for exec tool environment isolation."""
import pytest
from nanobot.agent.tools.shell import ExecTool
@pytest.mark.asyncio
async def test_exec_does_not_leak_parent_env(monkeypatch):
"""Env vars from the parent process must not be visible to commands."""
monkeypatch.setenv("NANOBOT_SECRET_TOKEN", "super-secret-value")
tool = ExecTool()
result = await tool.execute(command="printenv NANOBOT_SECRET_TOKEN")
assert "super-secret-value" not in result
@pytest.mark.asyncio
async def test_exec_has_working_path():
"""Basic commands should be available via the login shell's PATH."""
tool = ExecTool()
result = await tool.execute(command="echo hello")
assert "hello" in result
@pytest.mark.asyncio
async def test_exec_path_append():
"""The pathAppend config should be available in the command's PATH."""
tool = ExecTool(path_append="/opt/custom/bin")
result = await tool.execute(command="echo $PATH")
assert "/opt/custom/bin" in result
@pytest.mark.asyncio
async def test_exec_path_append_preserves_system_path():
"""pathAppend must not clobber standard system paths."""
tool = ExecTool(path_append="/opt/custom/bin")
result = await tool.execute(command="ls /")
assert "Exit code: 0" in result
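`ExecTool`'s internals are not part of this diff; the isolation behaviour the first test asserts can be sketched by launching the shell with a minimal environment instead of the inherited one (the helper name and env whitelist here are illustrative):

```python
import asyncio
import os


async def run_isolated(command: str) -> str:
    """Run a shell command without inheriting the parent's environment (sketch)."""
    # Start from a minimal env instead of os.environ so parent-process
    # secrets (e.g. API tokens) are never visible to the command.
    clean_env = {
        "PATH": "/usr/local/bin:/usr/bin:/bin",
        "HOME": os.path.expanduser("~"),
    }
    proc = await asyncio.create_subprocess_shell(
        command,
        env=clean_env,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.STDOUT,
    )
    out, _ = await proc.communicate()
    return out.decode()
```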

View File

@@ -112,7 +112,7 @@ class TestMessageToolSuppressLogic:
assert final_content == "Done"
assert progress == [
("Visible", False),
('read_file("foo.txt")', True),
('read foo.txt', True),
]

121
tests/tools/test_sandbox.py Normal file
View File

@@ -0,0 +1,121 @@
"""Tests for nanobot.agent.tools.sandbox."""
import shlex
import pytest
from nanobot.agent.tools.sandbox import wrap_command
def _parse(cmd: str) -> list[str]:
"""Split a wrapped command back into tokens for assertion."""
return shlex.split(cmd)
class TestBwrapBackend:
def test_basic_structure(self, tmp_path):
ws = str(tmp_path / "project")
result = wrap_command("bwrap", "echo hi", ws, ws)
tokens = _parse(result)
assert tokens[0] == "bwrap"
assert "--new-session" in tokens
assert "--die-with-parent" in tokens
assert "--ro-bind" in tokens
assert "--proc" in tokens
assert "--dev" in tokens
assert "--tmpfs" in tokens
sep = tokens.index("--")
assert tokens[sep + 1:] == ["sh", "-c", "echo hi"]
def test_workspace_bind_mounted_rw(self, tmp_path):
ws = str(tmp_path / "project")
result = wrap_command("bwrap", "ls", ws, ws)
tokens = _parse(result)
bind_idx = [i for i, t in enumerate(tokens) if t == "--bind"]
assert any(tokens[i + 1] == ws and tokens[i + 2] == ws for i in bind_idx)
def test_parent_dir_masked_with_tmpfs(self, tmp_path):
ws = tmp_path / "project"
result = wrap_command("bwrap", "ls", str(ws), str(ws))
tokens = _parse(result)
tmpfs_indices = [i for i, t in enumerate(tokens) if t == "--tmpfs"]
tmpfs_targets = {tokens[i + 1] for i in tmpfs_indices}
assert str(ws.parent) in tmpfs_targets
def test_cwd_inside_workspace(self, tmp_path):
ws = tmp_path / "project"
sub = ws / "src" / "lib"
result = wrap_command("bwrap", "pwd", str(ws), str(sub))
tokens = _parse(result)
chdir_idx = tokens.index("--chdir")
assert tokens[chdir_idx + 1] == str(sub)
def test_cwd_outside_workspace_falls_back(self, tmp_path):
ws = tmp_path / "project"
outside = tmp_path / "other"
result = wrap_command("bwrap", "pwd", str(ws), str(outside))
tokens = _parse(result)
chdir_idx = tokens.index("--chdir")
assert tokens[chdir_idx + 1] == str(ws.resolve())
def test_command_with_special_characters(self, tmp_path):
ws = str(tmp_path / "project")
cmd = "echo 'hello world' && cat \"file with spaces.txt\""
result = wrap_command("bwrap", cmd, ws, ws)
tokens = _parse(result)
sep = tokens.index("--")
assert tokens[sep + 1:] == ["sh", "-c", cmd]
def test_system_dirs_ro_bound(self, tmp_path):
ws = str(tmp_path / "project")
result = wrap_command("bwrap", "ls", ws, ws)
tokens = _parse(result)
ro_bind_indices = [i for i, t in enumerate(tokens) if t == "--ro-bind"]
ro_targets = {tokens[i + 1] for i in ro_bind_indices}
assert "/usr" in ro_targets
def test_optional_dirs_use_ro_bind_try(self, tmp_path):
ws = str(tmp_path / "project")
result = wrap_command("bwrap", "ls", ws, ws)
tokens = _parse(result)
try_indices = [i for i, t in enumerate(tokens) if t == "--ro-bind-try"]
try_targets = {tokens[i + 1] for i in try_indices}
assert "/bin" in try_targets
assert "/etc/ssl/certs" in try_targets
def test_media_dir_ro_bind(self, tmp_path, monkeypatch):
"""Media directory should be read-only mounted inside the sandbox."""
fake_media = tmp_path / "media"
fake_media.mkdir()
monkeypatch.setattr(
"nanobot.agent.tools.sandbox.get_media_dir",
lambda: fake_media,
)
ws = str(tmp_path / "project")
result = wrap_command("bwrap", "ls", ws, ws)
tokens = _parse(result)
try_indices = [i for i, t in enumerate(tokens) if t == "--ro-bind-try"]
try_pairs = {(tokens[i + 1], tokens[i + 2]) for i in try_indices}
assert (str(fake_media), str(fake_media)) in try_pairs
class TestUnknownBackend:
def test_raises_value_error(self, tmp_path):
ws = str(tmp_path / "project")
with pytest.raises(ValueError, match="Unknown sandbox backend"):
wrap_command("nonexistent", "ls", ws, ws)
def test_empty_string_raises(self, tmp_path):
ws = str(tmp_path / "project")
with pytest.raises(ValueError):
wrap_command("", "ls", ws, ws)
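`wrap_command` itself is not shown in this diff; assuming the bwrap backend assembles argv roughly as these tests expect, the overall shape can be sketched like this (flags and mount choices here are illustrative, not the real implementation):

```python
import shlex


def sketch_bwrap_command(command: str, workspace: str) -> str:
    """Assemble a bwrap invocation of the shape the tests above parse (sketch)."""
    argv = [
        "bwrap", "--new-session", "--die-with-parent",
        "--ro-bind", "/usr", "/usr",       # system dirs: read-only
        "--ro-bind-try", "/bin", "/bin",   # optional dirs: skipped if absent
        "--proc", "/proc", "--dev", "/dev",
        "--tmpfs", "/tmp",
        "--bind", workspace, workspace,    # workspace: read-write
        "--chdir", workspace,
        "--", "sh", "-c", command,         # user command after the separator
    ]
    return " ".join(shlex.quote(t) for t in argv)
```

Shell-quoting every token is what lets the tests round-trip the string through `shlex.split`, even when the user command contains quotes or spaces.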

View File

@@ -1,3 +1,6 @@
import shlex
import subprocess
import sys
from typing import Any
from nanobot.agent.tools import (
@@ -546,10 +549,15 @@ async def test_exec_head_tail_truncation() -> None:
"""Long output should preserve both head and tail."""
tool = ExecTool()
# Generate output that exceeds _MAX_OUTPUT (10_000 chars)
# Use python to generate output to avoid command line length limits
result = await tool.execute(
command="python -c \"print('A' * 6000 + '\\n' + 'B' * 6000)\""
)
# Use current interpreter (PATH may not have `python`). ExecTool uses
# create_subprocess_shell: POSIX needs shlex.quote; Windows uses cmd.exe
# rules, so list2cmdline is appropriate there.
script = "print('A' * 6000 + '\\n' + 'B' * 6000)"
if sys.platform == "win32":
command = subprocess.list2cmdline([sys.executable, "-c", script])
else:
command = f"{shlex.quote(sys.executable)} -c {shlex.quote(script)}"
result = await tool.execute(command=command)
assert "chars truncated" in result
# Head portion should start with As
assert result.startswith("A")
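The quoting rule from the comment above can be exercised on its own: POSIX shells and cmd.exe disagree on escaping, so the quoting helper has to match whichever shell the subprocess call will invoke.

```python
import shlex
import subprocess
import sys

script = "print('hello world')"
if sys.platform == "win32":
    # cmd.exe parsing: subprocess.list2cmdline implements its quoting rules.
    command = subprocess.list2cmdline([sys.executable, "-c", script])
else:
    # POSIX sh: quote each token with shlex.quote.
    command = f"{shlex.quote(sys.executable)} -c {shlex.quote(script)}"

result = subprocess.run(command, shell=True, capture_output=True, text=True)
```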

View File

@@ -1,5 +1,7 @@
"""Tests for multi-provider web search."""
import asyncio
import httpx
import pytest
@@ -160,3 +162,70 @@ async def test_searxng_invalid_url():
tool = _tool(provider="searxng", base_url="not-a-url")
result = await tool.execute(query="test")
assert "Error" in result
@pytest.mark.asyncio
async def test_jina_422_falls_back_to_duckduckgo(monkeypatch):
class MockDDGS:
def __init__(self, **kw):
pass
def text(self, query, max_results=5):
return [{"title": "Fallback", "href": "https://ddg.example", "body": "DuckDuckGo fallback"}]
async def mock_get(self, url, **kw):
assert "s.jina.ai" in str(url)
raise httpx.HTTPStatusError(
"422 Unprocessable Entity",
request=httpx.Request("GET", str(url)),
response=httpx.Response(422, request=httpx.Request("GET", str(url))),
)
monkeypatch.setattr(httpx.AsyncClient, "get", mock_get)
monkeypatch.setattr("ddgs.DDGS", MockDDGS)
tool = _tool(provider="jina", api_key="jina-key")
result = await tool.execute(query="test")
assert "DuckDuckGo fallback" in result
@pytest.mark.asyncio
async def test_jina_search_uses_path_encoded_query(monkeypatch):
calls = {}
async def mock_get(self, url, **kw):
calls["url"] = str(url)
calls["params"] = kw.get("params")
return _response(json={
"data": [{"title": "Jina Result", "url": "https://jina.ai", "content": "AI search"}]
})
monkeypatch.setattr(httpx.AsyncClient, "get", mock_get)
tool = _tool(provider="jina", api_key="jina-key")
await tool.execute(query="hello world")
assert calls["url"].rstrip("/") == "https://s.jina.ai/hello%20world"
assert calls["params"] in (None, {})
@pytest.mark.asyncio
async def test_duckduckgo_timeout_returns_error(monkeypatch):
"""asyncio.wait_for guard should fire when DDG search hangs."""
import threading
gate = threading.Event()
class HangingDDGS:
def __init__(self, **kw):
pass
def text(self, query, max_results=5):
gate.wait(timeout=10)
return []
monkeypatch.setattr("ddgs.DDGS", HangingDDGS)
tool = _tool(provider="duckduckgo")
tool.config.timeout = 0.2
result = await tool.execute(query="test")
gate.set()
assert "Error" in result

View File

@@ -0,0 +1,105 @@
"""Tests for abbreviate_path utility."""
import os
from nanobot.utils.path import abbreviate_path
class TestAbbreviatePathShort:
def test_short_path_unchanged(self):
assert abbreviate_path("/home/user/file.py") == "/home/user/file.py"
def test_exact_max_len_unchanged(self):
path = "/a/b/c"  # 6 chars
assert abbreviate_path(path, max_len=6) == path
def test_basename_only(self):
assert abbreviate_path("file.py") == "file.py"
def test_empty_string(self):
assert abbreviate_path("") == ""
class TestAbbreviatePathHome:
def test_home_replacement(self):
home = os.path.expanduser("~")
result = abbreviate_path(f"{home}/project/file.py")
assert result.startswith("~/")
assert result.endswith("file.py")
def test_home_preserves_short_path(self):
home = os.path.expanduser("~")
result = abbreviate_path(f"{home}/a.py")
assert result == "~/a.py"
class TestAbbreviatePathLong:
def test_long_path_keeps_basename(self):
path = "/a/b/c/d/e/f/g/h/very_long_filename.py"
result = abbreviate_path(path, max_len=30)
assert result.endswith("very_long_filename.py")
assert "\u2026" in result
def test_long_path_keeps_parent_dir(self):
path = "/a/b/c/d/e/f/g/h/src/loop.py"
result = abbreviate_path(path, max_len=30)
assert "loop.py" in result
assert "src" in result
def test_very_long_path_just_basename(self):
path = "/a/b/c/d/e/f/g/h/i/j/k/l/m/n/o/p/q/r/s/t/u/v/w/x/y/z/file.py"
result = abbreviate_path(path, max_len=20)
assert result.endswith("file.py")
assert len(result) <= 20
class TestAbbreviatePathWindows:
def test_windows_drive_path(self):
path = "D:\\Documents\\GitHub\\nanobot\\src\\utils\\helpers.py"
result = abbreviate_path(path, max_len=40)
assert result.endswith("helpers.py")
assert "nanobot" in result
def test_windows_home(self):
home = os.path.expanduser("~")
path = os.path.join(home, ".nanobot", "workspace", "log.txt")
result = abbreviate_path(path)
assert result.startswith("~/")
assert "log.txt" in result
class TestAbbreviatePathURLs:
def test_url_keeps_domain_and_filename(self):
url = "https://example.com/api/v2/long/path/resource.json"
result = abbreviate_path(url, max_len=40)
assert "resource.json" in result
assert "example.com" in result
def test_short_url_unchanged(self):
url = "https://example.com/api"
assert abbreviate_path(url) == url
def test_url_no_path_just_domain(self):
"""G3: URL with no path should return as-is if short enough."""
url = "https://example.com"
assert abbreviate_path(url) == url
def test_url_with_query_string(self):
"""G3: URL with query params should abbreviate path part."""
url = "https://example.com/api/v2/endpoint?key=value&other=123"
result = abbreviate_path(url, max_len=40)
assert "example.com" in result
assert "\u2026" in result
def test_url_very_long_basename(self):
"""G3: URL with very long basename should truncate basename."""
url = "https://example.com/path/very_long_resource_name_file.json"
result = abbreviate_path(url, max_len=35)
assert "example.com" in result
assert "\u2026" in result
def test_url_negative_budget_consistent_format(self):
"""I3: Negative budget should still produce domain/…/basename format."""
url = "https://a.co/very/deep/path/with/lots/of/segments/and/a/long/basename.txt"
result = abbreviate_path(url, max_len=20)
assert "a.co" in result
assert "/\u2026/" in result
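The real `abbreviate_path` also handles Windows drive paths and URL domains, which is more than a few lines; a deliberately simplified sketch of the core idea the simpler tests pin down (home folding plus a middle ellipsis) might look like:

```python
import os


def abbreviate_sketch(path: str, max_len: int = 40) -> str:
    """Home folding plus middle ellipsis -- a simplified sketch, not the real helper."""
    home = os.path.expanduser("~")
    if path.startswith(home):
        path = "~" + path[len(home):]
    if len(path) <= max_len:
        return path
    # Keep the first segment and the basename; elide the middle with U+2026.
    first = path.split("/", 1)[0]
    basename = path.rsplit("/", 1)[-1]
    return f"{first}/\u2026/{basename}"
```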

View File

@@ -0,0 +1,306 @@
"""Tests for web search provider usage fetching and /status integration."""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from nanobot.utils.searchusage import (
SearchUsageInfo,
_parse_tavily_usage,
fetch_search_usage,
)
from nanobot.utils.helpers import build_status_content
# ---------------------------------------------------------------------------
# SearchUsageInfo.format() tests
# ---------------------------------------------------------------------------
class TestSearchUsageInfoFormat:
def test_unsupported_provider_shows_no_tracking(self):
info = SearchUsageInfo(provider="duckduckgo", supported=False)
text = info.format()
assert "duckduckgo" in text
assert "not available" in text
def test_supported_with_error(self):
info = SearchUsageInfo(provider="tavily", supported=True, error="HTTP 401")
text = info.format()
assert "tavily" in text
assert "HTTP 401" in text
assert "unavailable" in text
def test_full_tavily_usage(self):
info = SearchUsageInfo(
provider="tavily",
supported=True,
used=142,
limit=1000,
remaining=858,
reset_date="2026-05-01",
search_used=120,
extract_used=15,
crawl_used=7,
)
text = info.format()
assert "tavily" in text
assert "142 / 1000" in text
assert "858" in text
assert "2026-05-01" in text
assert "Search: 120" in text
assert "Extract: 15" in text
assert "Crawl: 7" in text
def test_usage_without_limit(self):
info = SearchUsageInfo(provider="tavily", supported=True, used=50)
text = info.format()
assert "50 requests" in text
assert "/" not in text.split("Usage:")[1].split("\n")[0]
def test_no_breakdown_when_none(self):
info = SearchUsageInfo(
provider="tavily", supported=True, used=10, limit=100, remaining=90
)
text = info.format()
assert "Breakdown" not in text
def test_brave_unsupported(self):
info = SearchUsageInfo(provider="brave", supported=False)
text = info.format()
assert "brave" in text
assert "not available" in text
# ---------------------------------------------------------------------------
# _parse_tavily_usage tests
# ---------------------------------------------------------------------------
class TestParseTavilyUsage:
def test_full_response(self):
data = {
"account": {
"current_plan": "Researcher",
"plan_usage": 142,
"plan_limit": 1000,
"search_usage": 120,
"extract_usage": 15,
"crawl_usage": 7,
"map_usage": 0,
"research_usage": 0,
"paygo_usage": 0,
"paygo_limit": None,
},
}
info = _parse_tavily_usage(data)
assert info.provider == "tavily"
assert info.supported is True
assert info.used == 142
assert info.limit == 1000
assert info.remaining == 858
assert info.search_used == 120
assert info.extract_used == 15
assert info.crawl_used == 7
def test_remaining_computed(self):
data = {"account": {"plan_usage": 300, "plan_limit": 1000}}
info = _parse_tavily_usage(data)
assert info.remaining == 700
def test_remaining_not_negative(self):
data = {"account": {"plan_usage": 1100, "plan_limit": 1000}}
info = _parse_tavily_usage(data)
assert info.remaining == 0
def test_empty_response(self):
info = _parse_tavily_usage({})
assert info.provider == "tavily"
assert info.supported is True
assert info.used is None
assert info.limit is None
def test_no_breakdown_fields(self):
data = {"account": {"plan_usage": 5, "plan_limit": 50}}
info = _parse_tavily_usage(data)
assert info.search_used is None
assert info.extract_used is None
assert info.crawl_used is None
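The parsing behaviour pinned down above (remaining derived from the plan fields and clamped at zero) can be sketched independently; the field names follow the response shape used in these tests:

```python
def parse_plan_usage(data: dict) -> dict:
    """Extract used/limit/remaining from a Tavily-style usage payload (sketch)."""
    account = data.get("account", {})
    used = account.get("plan_usage")
    limit = account.get("plan_limit")
    remaining = None
    if used is not None and limit is not None:
        # Clamp so an over-quota account never reports negative remaining.
        remaining = max(limit - used, 0)
    return {"used": used, "limit": limit, "remaining": remaining}
```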
# ---------------------------------------------------------------------------
# fetch_search_usage routing tests
# ---------------------------------------------------------------------------
class TestFetchSearchUsageRouting:
@pytest.mark.asyncio
async def test_duckduckgo_returns_unsupported(self):
info = await fetch_search_usage("duckduckgo")
assert info.provider == "duckduckgo"
assert info.supported is False
@pytest.mark.asyncio
async def test_searxng_returns_unsupported(self):
info = await fetch_search_usage("searxng")
assert info.supported is False
@pytest.mark.asyncio
async def test_jina_returns_unsupported(self):
info = await fetch_search_usage("jina")
assert info.supported is False
@pytest.mark.asyncio
async def test_brave_returns_unsupported(self):
info = await fetch_search_usage("brave")
assert info.provider == "brave"
assert info.supported is False
@pytest.mark.asyncio
async def test_unknown_provider_returns_unsupported(self):
info = await fetch_search_usage("some_unknown_provider")
assert info.supported is False
@pytest.mark.asyncio
async def test_tavily_no_api_key_returns_error(self):
with patch.dict("os.environ", {}, clear=True):
# clear=True empties the patched environment, so TAVILY_API_KEY is unset
info = await fetch_search_usage("tavily", api_key=None)
assert info.provider == "tavily"
assert info.supported is True
assert info.error is not None
assert "not configured" in info.error
@pytest.mark.asyncio
async def test_tavily_success(self):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"account": {
"current_plan": "Researcher",
"plan_usage": 142,
"plan_limit": 1000,
"search_usage": 120,
"extract_usage": 15,
"crawl_usage": 7,
},
}
mock_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client.get = AsyncMock(return_value=mock_response)
with patch("httpx.AsyncClient", return_value=mock_client):
info = await fetch_search_usage("tavily", api_key="test-key")
assert info.provider == "tavily"
assert info.supported is True
assert info.error is None
assert info.used == 142
assert info.limit == 1000
assert info.remaining == 858
assert info.search_used == 120
@pytest.mark.asyncio
async def test_tavily_http_error(self):
import httpx
mock_response = MagicMock()
mock_response.status_code = 401
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
"401", request=MagicMock(), response=mock_response
)
mock_client = AsyncMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client.get = AsyncMock(return_value=mock_response)
with patch("httpx.AsyncClient", return_value=mock_client):
info = await fetch_search_usage("tavily", api_key="bad-key")
assert info.supported is True
assert info.error == "HTTP 401"
@pytest.mark.asyncio
async def test_tavily_network_error(self):
import httpx
mock_client = AsyncMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client.get = AsyncMock(side_effect=httpx.ConnectError("timeout"))
with patch("httpx.AsyncClient", return_value=mock_client):
info = await fetch_search_usage("tavily", api_key="test-key")
assert info.supported is True
assert info.error is not None
@pytest.mark.asyncio
async def test_provider_name_case_insensitive(self):
info = await fetch_search_usage("Tavily", api_key=None)
assert info.provider == "tavily"
assert info.supported is True
# ---------------------------------------------------------------------------
# build_status_content integration tests
# ---------------------------------------------------------------------------
class TestBuildStatusContentWithSearchUsage:
_BASE_KWARGS = dict(
version="0.1.0",
model="claude-opus-4-5",
start_time=1_000_000.0,
last_usage={"prompt_tokens": 1000, "completion_tokens": 200},
context_window_tokens=65536,
session_msg_count=5,
context_tokens_estimate=3000,
)
def test_no_search_usage_unchanged(self):
"""Omitting search_usage_text keeps existing behaviour."""
content = build_status_content(**self._BASE_KWARGS)
assert "🔍" not in content
assert "Web Search" not in content
def test_search_usage_none_unchanged(self):
content = build_status_content(**self._BASE_KWARGS, search_usage_text=None)
assert "🔍" not in content
def test_search_usage_appended(self):
usage_text = "🔍 Web Search: tavily\n Usage: 142 / 1000 requests"
content = build_status_content(**self._BASE_KWARGS, search_usage_text=usage_text)
assert "🔍 Web Search: tavily" in content
assert "142 / 1000" in content
def test_existing_fields_still_present(self):
usage_text = "🔍 Web Search: duckduckgo\n Usage tracking: not available"
content = build_status_content(**self._BASE_KWARGS, search_usage_text=usage_text)
# Original fields must still be present
assert "nanobot v0.1.0" in content
assert "claude-opus-4-5" in content
assert "1000 in / 200 out" in content
# New field appended
assert "duckduckgo" in content
def test_full_tavily_in_status(self):
info = SearchUsageInfo(
provider="tavily",
supported=True,
used=142,
limit=1000,
remaining=858,
reset_date="2026-05-01",
search_used=120,
extract_used=15,
crawl_used=7,
)
content = build_status_content(**self._BASE_KWARGS, search_usage_text=info.format())
assert "142 / 1000" in content
assert "858" in content
assert "2026-05-01" in content
assert "Search: 120" in content