diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000..9da244d8c --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# Ensure shell scripts always use LF line endings (Docker/Linux compat) +*.sh text eol=lf diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e00362d02..fac9be66c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -30,5 +30,8 @@ jobs: - name: Install all dependencies run: uv sync --all-extras + - name: Lint with ruff + run: uv run ruff check nanobot --select F401,F841 + - name: Run tests run: uv run pytest tests/ diff --git a/Dockerfile b/Dockerfile index 90f0e36a5..3b86d61b6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,7 +2,7 @@ FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim # Install Node.js 20 for the WhatsApp bridge RUN apt-get update && \ - apt-get install -y --no-install-recommends curl ca-certificates gnupg git openssh-client && \ + apt-get install -y --no-install-recommends curl ca-certificates gnupg git bubblewrap openssh-client && \ mkdir -p /etc/apt/keyrings && \ curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg && \ echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_20.x nodistro main" > /etc/apt/sources.list.d/nodesource.list && \ @@ -32,11 +32,19 @@ RUN git config --global --add url."https://github.com/".insteadOf ssh://git@gith npm install && npm run build WORKDIR /app -# Create config directory -RUN mkdir -p /root/.nanobot +# Create non-root user and config directory +RUN useradd -m -u 1000 -s /bin/bash nanobot && \ + mkdir -p /home/nanobot/.nanobot && \ + chown -R nanobot:nanobot /home/nanobot /app + +COPY entrypoint.sh /usr/local/bin/entrypoint.sh +RUN sed -i 's/\r$//' /usr/local/bin/entrypoint.sh && chmod +x /usr/local/bin/entrypoint.sh + +USER nanobot +ENV HOME=/home/nanobot # Gateway default port EXPOSE 18790 -ENTRYPOINT ["nanobot"] +ENTRYPOINT 
["entrypoint.sh"] CMD ["status"] diff --git a/README.md b/README.md index b62079351..a2ea20f8c 100644 --- a/README.md +++ b/README.md @@ -1,39 +1,44 @@
 (header image: nanobot logo)
-nanobot: Ultra-Lightweight Personal AI Assistant
+nanobot: Ultra-Lightweight Personal AI Agent
 (badges: PyPI Downloads, Python, License; added: Docs, Feishu, WeChat, Discord)
-🐈 **nanobot** is an **ultra-lightweight** personal AI assistant inspired by [OpenClaw](https://github.com/openclaw/openclaw).
+🐈 **nanobot** is an **ultra-lightweight** personal AI agent inspired by [OpenClaw](https://github.com/openclaw/openclaw).

-⚡️ Delivers core agent functionality with **99% fewer lines of code** than OpenClaw.
+⚡️ Delivers core agent functionality with **99% fewer lines of code**.

📝 Real-time line count: run `bash core_agent_lines.sh` to verify anytime.

## 📢 News

-- **2026-04-02** 🧱 **Long-running tasks** run more reliably — core runtime hardening.
+- **2026-04-05** 🚀 Released **v0.1.5** — sturdier long-running tasks, Dream two-stage memory, production-ready sandboxing, and a programming Agent SDK. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.5) for details.
+- **2026-04-04** 🚀 Jinja2 response templates, Dream memory hardened, smarter retry handling.
+- **2026-04-03** 🧠 Xiaomi MiMo provider, chain-of-thought reasoning visible, Telegram UX polish.
+- **2026-04-02** 🧱 Long-running tasks run more reliably — core runtime hardening.
- **2026-04-01** 🔑 GitHub Copilot auth restored; stricter workspace paths; OpenRouter Claude caching fix.
- **2026-03-31** 🛰️ WeChat multimodal alignment, Discord/Matrix polish, Python SDK facade, MCP and tool fixes.
- **2026-03-30** 🧩 OpenAI-compatible API tightened; composable agent lifecycle hooks.
- **2026-03-29** 💬 WeChat voice, typing, QR/media resilience; fixed-session OpenAI-compatible API.
- **2026-03-28** 📚 Provider docs refresh; skill template wording fix.
- **2026-03-27** 🚀 Released **v0.1.4.post6** — architecture decoupling, litellm removal, end-to-end streaming, WeChat channel, and a security fix. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post6) for details.
-- **2026-03-26** 🏗️ Agent runner extracted and lifecycle hooks unified; stream delta coalescing at boundaries.
-- **2026-03-25** 🌏 StepFun provider, configurable timezone, Gemini thought signatures.
-- **2026-03-24** 🔧 WeChat compatibility, Feishu CardKit streaming, test suite restructured.
+
Earlier news
+- **2026-03-26** 🏗️ Agent runner extracted and lifecycle hooks unified; stream delta coalescing at boundaries.
+- **2026-03-25** 🌏 StepFun provider, configurable timezone, Gemini thought signatures.
+- **2026-03-24** 🔧 WeChat compatibility, Feishu CardKit streaming, test suite restructured.
- **2026-03-23** 🔧 Command routing refactored for plugins, WhatsApp/WeChat media, unified channel login CLI.
- **2026-03-22** ⚡ End-to-end streaming, WeChat channel, Anthropic cache optimization, `/status` command.
- **2026-03-21** 🔒 Replace `litellm` with native `openai` + `anthropic` SDKs. Please see [commit](https://github.com/HKUDS/nanobot/commit/3dfdab7).

@@ -91,7 +96,7 @@
## Key Features of nanobot:

-🪶 **Ultra-Lightweight**: A super lightweight implementation of OpenClaw — 99% smaller, significantly faster.
+🪶 **Ultra-Lightweight**: A lightweight implementation built for stable, long-running AI agents.

🔬 **Research-Ready**: Clean, readable code that's easy to understand, modify, and extend for research.

@@ -140,7 +145,7 @@

 (an image line is swapped in this hunk; markup not recovered)

@@ -252,7 +257,7 @@ Configure these **two parts** in your config (other options have defaults).
nanobot agent
```
-That's it! You have a working AI assistant in 2 minutes.
+That's it! You have a working AI agent in 2 minutes.

## 💬 Chat Apps

@@ -433,9 +438,11 @@ pip install nanobot-ai[matrix]
- You need:
  - `userId` (example: `@nanobot:matrix.org`)
-  - `accessToken`
-  - `deviceId` (recommended so sync tokens can be restored across restarts)
-- You can obtain these from your homeserver login API (`/_matrix/client/v3/login`) or from your client's advanced session settings.
+  - `password`
+
+(Note: `accessToken` and `deviceId` are still supported for legacy reasons, but
+for reliable encryption, password login is recommended instead. If `password`
+is provided, `accessToken` and `deviceId` will be ignored.)

**3. Configure**

@@ -446,8 +453,7 @@ pip install nanobot-ai[matrix]
      "enabled": true,
      "homeserver": "https://matrix.org",
      "userId": "@nanobot:matrix.org",
-     "accessToken": "syt_xxx",
-     "deviceId": "NANOBOT01",
+     "password": "mypasswordhere",
      "e2eeEnabled": true,
      "allowFrom": ["@your_user:matrix.org"],
      "groupPolicy": "open",
@@ -459,7 +465,7 @@ pip install nanobot-ai[matrix]
  }
}
```
-> Keep a persistent `matrix-store` and stable `deviceId` — encrypted session state is lost if these change across restarts.
+> Keep a persistent `matrix-store` — encrypted session state is lost if it changes across restarts.

| Option | Description |
|--------|-------------|
@@ -720,6 +726,9 @@ Give nanobot its own email account. It polls **IMAP** for incoming mail and repl
> - `allowFrom`: Add your email address. Use `["*"]` to accept emails from anyone.
> - `smtpUseTls` and `smtpUseSsl` default to `true` / `false` respectively, which is correct for Gmail (port 587 + STARTTLS). No need to set them explicitly.
> - Set `"autoReplyEnabled": false` if you only want to read/analyze emails without sending automatic replies.
+> - `allowedAttachmentTypes`: Save inbound attachments matching these MIME types β€” `["*"]` for all, e.g. `["application/pdf", "image/*"]` (default `[]` = disabled). +> - `maxAttachmentSize`: Max size per attachment in bytes (default `2000000` / 2MB). +> - `maxAttachmentsPerEmail`: Max attachments to save per email (default `5`). ```json { @@ -736,7 +745,8 @@ Give nanobot its own email account. It polls **IMAP** for incoming mail and repl "smtpUsername": "my-nanobot@gmail.com", "smtpPassword": "your-app-password", "fromAddress": "my-nanobot@gmail.com", - "allowFrom": ["your-real-email@gmail.com"] + "allowFrom": ["your-real-email@gmail.com"], + "allowedAttachmentTypes": ["application/pdf", "image/*"] } } } @@ -861,10 +871,45 @@ Config file: `~/.nanobot/config.json` > run `nanobot onboard`, then answer `N` when asked whether to overwrite the config. > nanobot will merge in missing default fields and keep your current settings. +### Environment Variables for Secrets + +Instead of storing secrets directly in `config.json`, you can use `${VAR_NAME}` references that are resolved from environment variables at startup: + +```json +{ + "channels": { + "telegram": { "token": "${TELEGRAM_TOKEN}" }, + "email": { + "imapPassword": "${IMAP_PASSWORD}", + "smtpPassword": "${SMTP_PASSWORD}" + } + }, + "providers": { + "groq": { "apiKey": "${GROQ_API_KEY}" } + } +} +``` + +For **systemd** deployments, use `EnvironmentFile=` in the service unit to load variables from a file that only the deploying user can read: + +```ini +# /etc/systemd/system/nanobot.service (excerpt) +[Service] +EnvironmentFile=/home/youruser/nanobot_secrets.env +User=nanobot +ExecStart=... +``` + +```bash +# /home/youruser/nanobot_secrets.env (mode 600, owned by youruser) +TELEGRAM_TOKEN=your-token-here +IMAP_PASSWORD=your-password-here +``` + ### Providers > [!TIP] -> - **Groq** provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed. 
+> - **Voice transcription**: Voice messages (Telegram, WhatsApp) are automatically transcribed using Whisper. By default Groq is used (free tier). Set `"transcriptionProvider": "openai"` under `channels` to use OpenAI Whisper instead β€” the API key is picked from the matching provider config. > - **MiniMax Coding Plan**: Exclusive discount links for the nanobot community: [Overseas](https://platform.minimax.io/subscribe/coding-plan?code=9txpdXw04g&source=link) Β· [Mainland China](https://platform.minimaxi.com/subscribe/token-plan?code=GILTJpMTqZ&source=link) > - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config. > - **VolcEngine / BytePlus Coding Plan**: Use dedicated providers `volcengineCodingPlan` or `byteplusCodingPlan` instead of the pay-per-use `volcengine` / `byteplus` providers. @@ -880,9 +925,9 @@ Config file: `~/.nanobot/config.json` | `byteplus` | LLM (VolcEngine international, pay-per-use) | [Coding Plan](https://www.byteplus.com/en/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) Β· [byteplus.com](https://www.byteplus.com) | | `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) | | `azure_openai` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) | -| `openai` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) | +| `openai` | LLM + Voice transcription (Whisper) | [platform.openai.com](https://platform.openai.com) | | `deepseek` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) | -| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) | +| `groq` | LLM + Voice transcription (Whisper, default) | [console.groq.com](https://console.groq.com) | | `minimax` | LLM (MiniMax direct) | 
[platform.minimaxi.com](https://platform.minimaxi.com) | | `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) | | `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) | @@ -1197,6 +1242,7 @@ Global settings that apply to all channels. Configure under the `channels` secti "sendProgress": true, "sendToolHints": false, "sendMaxRetries": 3, + "transcriptionProvider": "groq", "telegram": { ... } } } @@ -1207,6 +1253,7 @@ Global settings that apply to all channels. Configure under the `channels` secti | `sendProgress` | `true` | Stream agent's text progress to the channel | | `sendToolHints` | `false` | Stream tool-call hints (e.g. `read_file("…")`) | | `sendMaxRetries` | `3` | Max delivery attempts per outbound message, including the initial send (0-10 configured, minimum 1 actual attempt) | +| `transcriptionProvider` | `"groq"` | Voice transcription backend: `"groq"` (free tier, default) or `"openai"`. API key is auto-resolved from the matching provider config. | #### Retry Behavior @@ -1434,16 +1481,19 @@ MCP tools are automatically discovered and registered on startup. The LLM can us ### Security > [!TIP] -> For production deployments, set `"restrictToWorkspace": true` in your config to sandbox the agent. +> For production deployments, set `"restrictToWorkspace": true` and `"tools.exec.sandbox": "bwrap"` in your config to sandbox the agent. > In `v0.1.4.post3` and earlier, an empty `allowFrom` allowed all senders. Since `v0.1.4.post4`, empty `allowFrom` denies all access by default. To allow all senders, set `"allowFrom": ["*"]`. | Option | Default | Description | |--------|---------|-------------| | `tools.restrictToWorkspace` | `false` | When `true`, restricts **all** agent tools (shell, file read/write/edit, list) to the workspace directory. Prevents path traversal and out-of-scope access. | +| `tools.exec.sandbox` | `""` | Sandbox backend for shell commands. 
Set to `"bwrap"` to wrap exec calls in a [bubblewrap](https://github.com/containers/bubblewrap) sandbox — the process can only see the workspace (read-write) and media directory (read-only); config files and API keys are hidden. Automatically enables `restrictToWorkspace` for file tools. **Linux only** — requires `bwrap` installed (`apt install bubblewrap`; pre-installed in the Docker image). Not available on macOS or Windows (bwrap depends on Linux kernel namespaces). |
| `tools.exec.enable` | `true` | When `false`, the shell `exec` tool is not registered at all. Use this to completely disable shell command execution. |
| `tools.exec.pathAppend` | `""` | Extra directories to append to `PATH` when running shell commands (e.g. `/usr/sbin` for `ufw`). |
| `channels.*.allowFrom` | `[]` (deny all) | Whitelist of user IDs. Empty denies all; use `["*"]` to allow everyone. |

+**Docker security**: The official Docker image runs as a non-root user (`nanobot`, UID 1000) with bubblewrap pre-installed. When using `docker-compose.yml`, the container drops all Linux capabilities except `SYS_ADMIN` (required for bwrap's namespace isolation).
+
### Timezone

@@ -1763,7 +1813,8 @@ print(resp.choices[0].message.content)
## 🐳 Docker

> [!TIP]
-> The `-v ~/.nanobot:/root/.nanobot` flag mounts your local config directory into the container, so your config and workspace persist across container restarts.
+> The `-v ~/.nanobot:/home/nanobot/.nanobot` flag mounts your local config directory into the container, so your config and workspace persist across container restarts.
+> The container runs as user `nanobot` (UID 1000). If you get **Permission denied**, fix ownership on the host first: `sudo chown -R 1000:1000 ~/.nanobot`, or pass `--user $(id -u):$(id -g)` to match your host UID. Podman users can use `--userns=keep-id` instead.

### Docker Compose

@@ -1786,17 +1837,17 @@ docker compose down # stop
docker build -t nanobot .
# Initialize config (first time only)
-docker run -v ~/.nanobot:/root/.nanobot --rm nanobot onboard
+docker run -v ~/.nanobot:/home/nanobot/.nanobot --rm nanobot onboard

# Edit config on host to add API keys
vim ~/.nanobot/config.json

# Run gateway (connects to enabled channels, e.g. Telegram/Discord/Mochat)
-docker run -v ~/.nanobot:/root/.nanobot -p 18790:18790 nanobot gateway
+docker run -v ~/.nanobot:/home/nanobot/.nanobot -p 18790:18790 nanobot gateway

# Or run a single command
-docker run -v ~/.nanobot:/root/.nanobot --rm nanobot agent -m "Hello!"
-docker run -v ~/.nanobot:/root/.nanobot --rm nanobot status
+docker run -v ~/.nanobot:/home/nanobot/.nanobot --rm nanobot agent -m "Hello!"
+docker run -v ~/.nanobot:/home/nanobot/.nanobot --rm nanobot status
```

## 🐧 Linux Service

diff --git a/SECURITY.md b/SECURITY.md
index d98adb6e9..8e65d4042 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -64,6 +64,7 @@ chmod 600 ~/.nanobot/config.json
The `exec` tool can execute shell commands. While dangerous command patterns are blocked, you should:
+- ✅ **Enable the bwrap sandbox** (`"tools.exec.sandbox": "bwrap"`) for kernel-level isolation (Linux only)
- ✅ Review all tool usage in agent logs
- ✅ Understand what commands the agent is running
- ✅ Use a dedicated user account with limited privileges
@@ -71,6 +72,19 @@ The `exec` tool can execute shell commands. While dangerous command patterns are
- ❌ Don't disable security checks
- ❌ Don't run on systems with sensitive data without careful review

+**Exec sandbox (bwrap):**
+
+On Linux, set `"tools.exec.sandbox": "bwrap"` to wrap every shell command in a [bubblewrap](https://github.com/containers/bubblewrap) sandbox.
This uses Linux kernel namespaces to restrict what the process can see:
+
+- Workspace directory → **read-write** (agent works normally)
+- Media directory → **read-only** (can read uploaded attachments)
+- System directories (`/usr`, `/bin`, `/lib`) → **read-only** (commands still work)
+- Config files and API keys (`~/.nanobot/config.json`) → **hidden** (masked by tmpfs)
+
+Requires `bwrap` installed (`apt install bubblewrap`). Pre-installed in the official Docker image. **Not available on macOS or Windows** — bubblewrap depends on Linux kernel namespaces.
+
+Enabling the sandbox also automatically activates `restrictToWorkspace` for file tools.
+
**Blocked patterns:**
- `rm -rf /` - Root filesystem deletion
- Fork bombs
@@ -82,6 +96,7 @@ The `exec` tool can execute shell commands. While dangerous command patterns are
File operations have path traversal protection, but:
+- ✅ Enable `restrictToWorkspace` or the bwrap sandbox to confine file access
- ✅ Run nanobot with a dedicated user account
- ✅ Use filesystem permissions to protect sensitive directories
- ✅ Regularly audit file operations in logs
@@ -232,7 +247,7 @@ If you suspect a security breach:
1. **No Rate Limiting** - Users can send unlimited messages (add your own if needed)
2. **Plain Text Config** - API keys stored in plain text (use keyring for production)
3. **No Session Management** - No automatic session expiry
-4. **Limited Command Filtering** - Only blocks obvious dangerous patterns
+4. **Limited Command Filtering** - Only blocks obvious dangerous patterns (enable the bwrap sandbox for kernel-level isolation on Linux)
5.
**No Audit Trail** - Limited security event logging (enhance as needed) ## Security Checklist @@ -243,6 +258,7 @@ Before deploying nanobot: - [ ] Config file permissions set to 0600 - [ ] `allowFrom` lists configured for all channels - [ ] Running as non-root user +- [ ] Exec sandbox enabled (`"tools.exec.sandbox": "bwrap"`) on Linux deployments - [ ] File system permissions properly restricted - [ ] Dependencies updated to latest secure versions - [ ] Logs monitored for security events @@ -252,7 +268,7 @@ Before deploying nanobot: ## Updates -**Last Updated**: 2026-02-03 +**Last Updated**: 2026-04-05 For the latest security updates and announcements, check: - GitHub Security Advisories: https://github.com/HKUDS/nanobot/security/advisories diff --git a/case/scedule.gif b/case/schedule.gif similarity index 100% rename from case/scedule.gif rename to case/schedule.gif diff --git a/docker-compose.yml b/docker-compose.yml index 5c27f81a0..21beb1c6f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,7 +3,14 @@ x-common-config: &common-config context: . 
dockerfile: Dockerfile volumes: - - ~/.nanobot:/root/.nanobot + - ~/.nanobot:/home/nanobot/.nanobot + cap_drop: + - ALL + cap_add: + - SYS_ADMIN + security_opt: + - apparmor=unconfined + - seccomp=unconfined services: nanobot-gateway: @@ -16,12 +23,29 @@ services: deploy: resources: limits: - cpus: '1' + cpus: "1" memory: 1G reservations: - cpus: '0.25' + cpus: "0.25" memory: 256M - + + nanobot-api: + container_name: nanobot-api + <<: *common-config + command: + ["serve", "--host", "0.0.0.0", "-w", "/home/nanobot/.nanobot/api-workspace"] + restart: unless-stopped + ports: + - 127.0.0.1:8900:8900 + deploy: + resources: + limits: + cpus: "1" + memory: 1G + reservations: + cpus: "0.25" + memory: 256M + nanobot-cli: <<: *common-config profiles: diff --git a/entrypoint.sh b/entrypoint.sh new file mode 100755 index 000000000..ab780dc96 --- /dev/null +++ b/entrypoint.sh @@ -0,0 +1,15 @@ +#!/bin/sh +dir="$HOME/.nanobot" +if [ -d "$dir" ] && [ ! -w "$dir" ]; then + owner_uid=$(stat -c %u "$dir" 2>/dev/null || stat -f %u "$dir" 2>/dev/null) + cat >&2 < bool: return any(h.wants_streaming() for h in self._hooks) - async def before_iteration(self, context: AgentHookContext) -> None: + async def _for_each_hook_safe(self, method_name: str, *args: Any, **kwargs: Any) -> None: for h in self._hooks: try: - await h.before_iteration(context) + await getattr(h, method_name)(*args, **kwargs) except Exception: - logger.exception("AgentHook.before_iteration error in {}", type(h).__name__) + logger.exception("AgentHook.{} error in {}", method_name, type(h).__name__) + + async def before_iteration(self, context: AgentHookContext) -> None: + await self._for_each_hook_safe("before_iteration", context) async def on_stream(self, context: AgentHookContext, delta: str) -> None: - for h in self._hooks: - try: - await h.on_stream(context, delta) - except Exception: - logger.exception("AgentHook.on_stream error in {}", type(h).__name__) + await self._for_each_hook_safe("on_stream", context, delta) 
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: - for h in self._hooks: - try: - await h.on_stream_end(context, resuming=resuming) - except Exception: - logger.exception("AgentHook.on_stream_end error in {}", type(h).__name__) + await self._for_each_hook_safe("on_stream_end", context, resuming=resuming) async def before_execute_tools(self, context: AgentHookContext) -> None: - for h in self._hooks: - try: - await h.before_execute_tools(context) - except Exception: - logger.exception("AgentHook.before_execute_tools error in {}", type(h).__name__) + await self._for_each_hook_safe("before_execute_tools", context) async def after_iteration(self, context: AgentHookContext) -> None: - for h in self._hooks: - try: - await h.after_iteration(context) - except Exception: - logger.exception("AgentHook.after_iteration error in {}", type(h).__name__) + await self._for_each_hook_safe("after_iteration", context) def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: for h in self._hooks: diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 25a81b62b..66d765d00 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -4,7 +4,6 @@ from __future__ import annotations import asyncio import json -import re import os import time from contextlib import AsyncExitStack, nullcontext @@ -262,7 +261,7 @@ class AgentLoop: def _register_default_tools(self) -> None: """Register the default set of tools.""" - allowed_dir = self.workspace if self.restrict_to_workspace else None + allowed_dir = self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None self.tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read)) for cls in (WriteFileTool, EditFileTool, ListDirTool): @@ -274,6 +273,7 @@ class AgentLoop: working_dir=str(self.workspace), timeout=self.exec_config.timeout, 
restrict_to_workspace=self.restrict_to_workspace, + sandbox=self.exec_config.sandbox, path_append=self.exec_config.path_append, )) if self.web_config.enable: @@ -325,14 +325,10 @@ class AgentLoop: @staticmethod def _tool_hint(tool_calls: list) -> str: - """Format tool calls as concise hint, e.g. 'web_search("query")'.""" - def _fmt(tc): - args = (tc.arguments[0] if isinstance(tc.arguments, list) else tc.arguments) or {} - val = next(iter(args.values()), None) if isinstance(args, dict) else None - if not isinstance(val, str): - return tc.name - return f'{tc.name}("{val[:40]}…")' if len(val) > 40 else f'{tc.name}("{val}")' - return ", ".join(_fmt(tc) for tc in tool_calls) + """Format tool calls as concise hints with smart abbreviation.""" + from nanobot.utils.tool_hints import format_tool_hints + + return format_tool_hints(tool_calls) async def _run_agent_loop( self, diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 12dd2287b..fbc2a4788 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -30,6 +30,7 @@ from nanobot.utils.runtime import ( ) _DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model." 
+_MAX_EMPTY_RETRIES = 2 _SNIP_SAFETY_BUFFER = 1024 @dataclass(slots=True) class AgentRunSpec: @@ -86,6 +87,7 @@ class AgentRunner: stop_reason = "completed" tool_events: list[dict[str, str]] = [] external_lookup_counts: dict[str, int] = {} + empty_content_retries = 0 for iteration in range(spec.max_iterations): try: @@ -178,15 +180,30 @@ class AgentRunner: "pending_tool_calls": [], }, ) + empty_content_retries = 0 await hook.after_iteration(context) continue clean = hook.finalize_content(context, response.content) if response.finish_reason != "error" and is_blank_text(clean): + empty_content_retries += 1 + if empty_content_retries < _MAX_EMPTY_RETRIES: + logger.warning( + "Empty response on turn {} for {} ({}/{}); retrying", + iteration, + spec.session_key or "default", + empty_content_retries, + _MAX_EMPTY_RETRIES, + ) + if hook.wants_streaming(): + await hook.on_stream_end(context, resuming=False) + await hook.after_iteration(context) + continue logger.warning( - "Empty final response on turn {} for {}; retrying with explicit finalization prompt", + "Empty response on turn {} for {} after {} retries; attempting finalization", iteration, spec.session_key or "default", + empty_content_retries, ) if hook.wants_streaming(): await hook.on_stream_end(context, resuming=False) diff --git a/nanobot/agent/skills.py b/nanobot/agent/skills.py index 9afee82f0..ca215cc96 100644 --- a/nanobot/agent/skills.py +++ b/nanobot/agent/skills.py @@ -9,6 +9,16 @@ from pathlib import Path # Default builtin skills directory (relative to this file) BUILTIN_SKILLS_DIR = Path(__file__).parent.parent / "skills" +# Opening ---, YAML body (group 1), closing --- on its own line; supports CRLF. 
+_STRIP_SKILL_FRONTMATTER = re.compile(
+    r"^---\s*\r?\n(.*?)\r?\n---\s*\r?\n?",
+    re.DOTALL,
+)
+
+
+def _escape_xml(text: str) -> str:
+    return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
+

class SkillsLoader:
    """
@@ -23,6 +33,22 @@ class SkillsLoader:
        self.workspace_skills = workspace / "skills"
        self.builtin_skills = builtin_skills_dir or BUILTIN_SKILLS_DIR

+    def _skill_entries_from_dir(self, base: Path, source: str, *, skip_names: set[str] | None = None) -> list[dict[str, str]]:
+        if not base.exists():
+            return []
+        entries: list[dict[str, str]] = []
+        for skill_dir in base.iterdir():
+            if not skill_dir.is_dir():
+                continue
+            skill_file = skill_dir / "SKILL.md"
+            if not skill_file.exists():
+                continue
+            name = skill_dir.name
+            if skip_names is not None and name in skip_names:
+                continue
+            entries.append({"name": name, "path": str(skill_file), "source": source})
+        return entries
+
    def list_skills(self, filter_unavailable: bool = True) -> list[dict[str, str]]:
        """
        List all available skills.
@@ -33,27 +59,15 @@
        Returns:
            List of skill info dicts with 'name', 'path', 'source'.
""" - skills = [] - - # Workspace skills (highest priority) - if self.workspace_skills.exists(): - for skill_dir in self.workspace_skills.iterdir(): - if skill_dir.is_dir(): - skill_file = skill_dir / "SKILL.md" - if skill_file.exists(): - skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "workspace"}) - - # Built-in skills + skills = self._skill_entries_from_dir(self.workspace_skills, "workspace") + workspace_names = {entry["name"] for entry in skills} if self.builtin_skills and self.builtin_skills.exists(): - for skill_dir in self.builtin_skills.iterdir(): - if skill_dir.is_dir(): - skill_file = skill_dir / "SKILL.md" - if skill_file.exists() and not any(s["name"] == skill_dir.name for s in skills): - skills.append({"name": skill_dir.name, "path": str(skill_file), "source": "builtin"}) + skills.extend( + self._skill_entries_from_dir(self.builtin_skills, "builtin", skip_names=workspace_names) + ) - # Filter by requirements if filter_unavailable: - return [s for s in skills if self._check_requirements(self._get_skill_meta(s["name"]))] + return [skill for skill in skills if self._check_requirements(self._get_skill_meta(skill["name"]))] return skills def load_skill(self, name: str) -> str | None: @@ -66,17 +80,13 @@ class SkillsLoader: Returns: Skill content or None if not found. 
""" - # Check workspace first - workspace_skill = self.workspace_skills / name / "SKILL.md" - if workspace_skill.exists(): - return workspace_skill.read_text(encoding="utf-8") - - # Check built-in + roots = [self.workspace_skills] if self.builtin_skills: - builtin_skill = self.builtin_skills / name / "SKILL.md" - if builtin_skill.exists(): - return builtin_skill.read_text(encoding="utf-8") - + roots.append(self.builtin_skills) + for root in roots: + path = root / name / "SKILL.md" + if path.exists(): + return path.read_text(encoding="utf-8") return None def load_skills_for_context(self, skill_names: list[str]) -> str: @@ -89,14 +99,12 @@ class SkillsLoader: Returns: Formatted skills content. """ - parts = [] - for name in skill_names: - content = self.load_skill(name) - if content: - content = self._strip_frontmatter(content) - parts.append(f"### Skill: {name}\n\n{content}") - - return "\n\n---\n\n".join(parts) if parts else "" + parts = [ + f"### Skill: {name}\n\n{self._strip_frontmatter(markdown)}" + for name in skill_names + if (markdown := self.load_skill(name)) + ] + return "\n\n---\n\n".join(parts) def build_skills_summary(self) -> str: """ @@ -112,44 +120,36 @@ class SkillsLoader: if not all_skills: return "" - def escape_xml(s: str) -> str: - return s.replace("&", "&").replace("<", "<").replace(">", ">") - - lines = [""] - for s in all_skills: - name = escape_xml(s["name"]) - path = s["path"] - desc = escape_xml(self._get_skill_description(s["name"])) - skill_meta = self._get_skill_meta(s["name"]) - available = self._check_requirements(skill_meta) - - lines.append(f" ") - lines.append(f" {name}") - lines.append(f" {desc}") - lines.append(f" {path}") - - # Show missing requirements for unavailable skills + lines: list[str] = [""] + for entry in all_skills: + skill_name = entry["name"] + meta = self._get_skill_meta(skill_name) + available = self._check_requirements(meta) + lines.extend( + [ + f' ', + f" {_escape_xml(skill_name)}", + f" 
{_escape_xml(self._get_skill_description(skill_name))}", + f" {entry['path']}", + ] + ) if not available: - missing = self._get_missing_requirements(skill_meta) + missing = self._get_missing_requirements(meta) if missing: - lines.append(f" {escape_xml(missing)}") - + lines.append(f" {_escape_xml(missing)}") lines.append(" ") lines.append("") - return "\n".join(lines) def _get_missing_requirements(self, skill_meta: dict) -> str: """Get a description of missing requirements.""" - missing = [] requires = skill_meta.get("requires", {}) - for b in requires.get("bins", []): - if not shutil.which(b): - missing.append(f"CLI: {b}") - for env in requires.get("env", []): - if not os.environ.get(env): - missing.append(f"ENV: {env}") - return ", ".join(missing) + required_bins = requires.get("bins", []) + required_env_vars = requires.get("env", []) + return ", ".join( + [f"CLI: {command_name}" for command_name in required_bins if not shutil.which(command_name)] + + [f"ENV: {env_name}" for env_name in required_env_vars if not os.environ.get(env_name)] + ) def _get_skill_description(self, name: str) -> str: """Get the description of a skill from its frontmatter.""" @@ -160,30 +160,32 @@ class SkillsLoader: def _strip_frontmatter(self, content: str) -> str: """Remove YAML frontmatter from markdown content.""" - if content.startswith("---"): - match = re.match(r"^---\n.*?\n---\n", content, re.DOTALL) - if match: - return content[match.end():].strip() + if not content.startswith("---"): + return content + match = _STRIP_SKILL_FRONTMATTER.match(content) + if match: + return content[match.end():].strip() return content def _parse_nanobot_metadata(self, raw: str) -> dict: """Parse skill metadata JSON from frontmatter (supports nanobot and openclaw keys).""" try: data = json.loads(raw) - return data.get("nanobot", data.get("openclaw", {})) if isinstance(data, dict) else {} except (json.JSONDecodeError, TypeError): return {} + if not isinstance(data, dict): + return {} + payload = 
data.get("nanobot", data.get("openclaw", {})) + return payload if isinstance(payload, dict) else {} def _check_requirements(self, skill_meta: dict) -> bool: """Check if skill requirements are met (bins, env vars).""" requires = skill_meta.get("requires", {}) - for b in requires.get("bins", []): - if not shutil.which(b): - return False - for env in requires.get("env", []): - if not os.environ.get(env): - return False - return True + required_bins = requires.get("bins", []) + required_env_vars = requires.get("env", []) + return all(shutil.which(cmd) for cmd in required_bins) and all( + os.environ.get(var) for var in required_env_vars + ) def _get_skill_meta(self, name: str) -> dict: """Get nanobot metadata for a skill (cached in frontmatter).""" @@ -192,13 +194,15 @@ class SkillsLoader: def get_always_skills(self) -> list[str]: """Get skills marked as always=true that meet requirements.""" - result = [] - for s in self.list_skills(filter_unavailable=True): - meta = self.get_skill_metadata(s["name"]) or {} - skill_meta = self._parse_nanobot_metadata(meta.get("metadata", "")) - if skill_meta.get("always") or meta.get("always"): - result.append(s["name"]) - return result + return [ + entry["name"] + for entry in self.list_skills(filter_unavailable=True) + if (meta := self.get_skill_metadata(entry["name"]) or {}) + and ( + self._parse_nanobot_metadata(meta.get("metadata", "")).get("always") + or meta.get("always") + ) + ] def get_skill_metadata(self, name: str) -> dict | None: """ @@ -211,18 +215,15 @@ class SkillsLoader: Metadata dict or None. 
""" content = self.load_skill(name) - if not content: + if not content or not content.startswith("---"): return None - - if content.startswith("---"): - match = re.match(r"^---\n(.*?)\n---", content, re.DOTALL) - if match: - # Simple YAML parsing - metadata = {} - for line in match.group(1).split("\n"): - if ":" in line: - key, value = line.split(":", 1) - metadata[key.strip()] = value.strip().strip('"\'') - return metadata - - return None + match = _STRIP_SKILL_FRONTMATTER.match(content) + if not match: + return None + metadata: dict[str, str] = {} + for line in match.group(1).splitlines(): + if ":" not in line: + continue + key, value = line.split(":", 1) + metadata[key.strip()] = value.strip().strip('"\'') + return metadata diff --git a/nanobot/agent/subagent.py b/nanobot/agent/subagent.py index 8ea4b6c06..585139972 100644 --- a/nanobot/agent/subagent.py +++ b/nanobot/agent/subagent.py @@ -111,7 +111,7 @@ class SubagentManager: try: # Build subagent tools (no message tool, no spawn tool) tools = ToolRegistry() - allowed_dir = self.workspace if self.restrict_to_workspace else None + allowed_dir = self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read)) tools.register(WriteFileTool(workspace=self.workspace, allowed_dir=allowed_dir)) @@ -124,6 +124,7 @@ class SubagentManager: working_dir=str(self.workspace), timeout=self.exec_config.timeout, restrict_to_workspace=self.restrict_to_workspace, + sandbox=self.exec_config.sandbox, path_append=self.exec_config.path_append, )) if self.web_config.enable: diff --git a/nanobot/agent/tools/cron.py b/nanobot/agent/tools/cron.py index 064b6e4c9..f0d3ddab9 100644 --- a/nanobot/agent/tools/cron.py +++ b/nanobot/agent/tools/cron.py @@ -13,6 +13,10 @@ from nanobot.cron.types import CronJob, CronJobState, CronSchedule @tool_parameters( 
tool_parameters_schema( action=StringSchema("Action to perform", enum=["add", "list", "remove"]), + name=StringSchema( + "Optional short human-readable label for the job " + "(e.g., 'weather-monitor', 'daily-standup'). Defaults to first 30 chars of message." + ), message=StringSchema( "Instruction for the agent to execute when the job triggers " "(e.g., 'Send a reminder to WeChat: xxx' or 'Check system status and report')" @@ -93,6 +97,7 @@ class CronTool(Tool): async def execute( self, action: str, + name: str | None = None, message: str = "", every_seconds: int | None = None, cron_expr: str | None = None, @@ -105,7 +110,7 @@ class CronTool(Tool): if action == "add": if self._in_cron_context.get(): return "Error: cannot schedule new jobs from within a cron job execution" - return self._add_job(message, every_seconds, cron_expr, tz, at, deliver) + return self._add_job(name, message, every_seconds, cron_expr, tz, at, deliver) elif action == "list": return self._list_jobs() elif action == "remove": @@ -114,6 +119,7 @@ class CronTool(Tool): def _add_job( self, + name: str | None, message: str, every_seconds: int | None, cron_expr: str | None, @@ -158,7 +164,7 @@ class CronTool(Tool): return "Error: either every_seconds, cron_expr, or at is required" job = self._cron.add_job( - name=message[:30], + name=name or message[:30], schedule=schedule, message=message, deliver=deliver, diff --git a/nanobot/agent/tools/filesystem.py b/nanobot/agent/tools/filesystem.py index 11f05c557..27ae3ccd9 100644 --- a/nanobot/agent/tools/filesystem.py +++ b/nanobot/agent/tools/filesystem.py @@ -186,7 +186,7 @@ class WriteFileTool(_FsTool): fp = self._resolve(path) fp.parent.mkdir(parents=True, exist_ok=True) fp.write_text(content, encoding="utf-8") - return f"Successfully wrote {len(content)} bytes to {fp}" + return f"Successfully wrote {len(content)} characters to {fp}" except PermissionError as e: return f"Error: {e}" except Exception as e: diff --git a/nanobot/agent/tools/sandbox.py 
b/nanobot/agent/tools/sandbox.py new file mode 100644 index 000000000..459ce16a3 --- /dev/null +++ b/nanobot/agent/tools/sandbox.py @@ -0,0 +1,55 @@ +"""Sandbox backends for shell command execution. + +To add a new backend, implement a function with the signature: + _wrap_(command: str, workspace: str, cwd: str) -> str +and register it in _BACKENDS below. +""" + +import shlex +from pathlib import Path + +from nanobot.config.paths import get_media_dir + + +def _bwrap(command: str, workspace: str, cwd: str) -> str: + """Wrap command in a bubblewrap sandbox (requires bwrap in container). + + Only the workspace is bind-mounted read-write; its parent dir (which holds + config.json) is hidden behind a fresh tmpfs. The media directory is + bind-mounted read-only so exec commands can read uploaded attachments. + """ + ws = Path(workspace).resolve() + media = get_media_dir().resolve() + + try: + sandbox_cwd = str(ws / Path(cwd).resolve().relative_to(ws)) + except ValueError: + sandbox_cwd = str(ws) + + required = ["/usr"] + optional = ["/bin", "/lib", "/lib64", "/etc/alternatives", + "/etc/ssl/certs", "/etc/resolv.conf", "/etc/ld.so.cache"] + + args = ["bwrap", "--new-session", "--die-with-parent"] + for p in required: args += ["--ro-bind", p, p] + for p in optional: args += ["--ro-bind-try", p, p] + args += [ + "--proc", "/proc", "--dev", "/dev", "--tmpfs", "/tmp", + "--tmpfs", str(ws.parent), # mask config dir + "--dir", str(ws), # recreate workspace mount point + "--bind", str(ws), str(ws), + "--ro-bind-try", str(media), str(media), # read-only access to media + "--chdir", sandbox_cwd, + "--", "sh", "-c", command, + ] + return shlex.join(args) + + +_BACKENDS = {"bwrap": _bwrap} + + +def wrap_command(sandbox: str, command: str, workspace: str, cwd: str) -> str: + """Wrap *command* using the named sandbox backend.""" + if backend := _BACKENDS.get(sandbox): + return backend(command, workspace, cwd) + raise ValueError(f"Unknown sandbox backend {sandbox!r}. 
Available: {list(_BACKENDS)}") diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index c8876827c..e5c04eb72 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -3,6 +3,7 @@ import asyncio import os import re +import shutil import sys from pathlib import Path from typing import Any @@ -10,6 +11,7 @@ from typing import Any from loguru import logger from nanobot.agent.tools.base import Tool, tool_parameters +from nanobot.agent.tools.sandbox import wrap_command from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema from nanobot.config.paths import get_media_dir @@ -40,10 +42,12 @@ class ExecTool(Tool): deny_patterns: list[str] | None = None, allow_patterns: list[str] | None = None, restrict_to_workspace: bool = False, + sandbox: str = "", path_append: str = "", ): self.timeout = timeout self.working_dir = working_dir + self.sandbox = sandbox self.deny_patterns = deny_patterns or [ r"\brm\s+-[rf]{1,2}\b", # rm -r, rm -rf, rm -fr r"\bdel\s+/[fq]\b", # del /f, del /q @@ -83,15 +87,23 @@ class ExecTool(Tool): if guard_error: return guard_error + if self.sandbox: + workspace = self.working_dir or cwd + command = wrap_command(self.sandbox, command, workspace, cwd) + cwd = str(Path(workspace).resolve()) + effective_timeout = min(timeout or self.timeout, self._MAX_TIMEOUT) - env = os.environ.copy() + env = self._build_env() + if self.path_append: - env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append + command = f'export PATH="$PATH:{self.path_append}"; {command}' + + bash = shutil.which("bash") or "/bin/bash" try: - process = await asyncio.create_subprocess_shell( - command, + process = await asyncio.create_subprocess_exec( + bash, "-l", "-c", command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, cwd=cwd, @@ -104,18 +116,11 @@ class ExecTool(Tool): timeout=effective_timeout, ) except asyncio.TimeoutError: - process.kill() - try: - await 
asyncio.wait_for(process.wait(), timeout=5.0) - except asyncio.TimeoutError: - pass - finally: - if sys.platform != "win32": - try: - os.waitpid(process.pid, os.WNOHANG) - except (ProcessLookupError, ChildProcessError) as e: - logger.debug("Process already reaped or not found: {}", e) + await self._kill_process(process) return f"Error: Command timed out after {effective_timeout} seconds" + except asyncio.CancelledError: + await self._kill_process(process) + raise output_parts = [] @@ -146,6 +151,36 @@ class ExecTool(Tool): except Exception as e: return f"Error executing command: {str(e)}" + @staticmethod + async def _kill_process(process: asyncio.subprocess.Process) -> None: + """Kill a subprocess and reap it to prevent zombies.""" + process.kill() + try: + await asyncio.wait_for(process.wait(), timeout=5.0) + except asyncio.TimeoutError: + pass + finally: + if sys.platform != "win32": + try: + os.waitpid(process.pid, os.WNOHANG) + except (ProcessLookupError, ChildProcessError) as e: + logger.debug("Process already reaped or not found: {}", e) + + def _build_env(self) -> dict[str, str]: + """Build a minimal environment for subprocess execution. + + Uses HOME so that ``bash -l`` sources the user's profile (which sets + PATH and other essentials). Only PATH is extended with *path_append*; + the parent process's environment is **not** inherited, preventing + secrets in env vars from leaking to LLM-generated commands. 
+ """ + home = os.environ.get("HOME", "/tmp") + return { + "HOME": home, + "LANG": os.environ.get("LANG", "C.UTF-8"), + "TERM": os.environ.get("TERM", "dumb"), + } + def _guard_command(self, command: str, cwd: str) -> str | None: """Best-effort safety guard for potentially destructive commands.""" cmd = command.strip() diff --git a/nanobot/agent/tools/web.py b/nanobot/agent/tools/web.py index 9ac923050..a6d7be983 100644 --- a/nanobot/agent/tools/web.py +++ b/nanobot/agent/tools/web.py @@ -8,7 +8,7 @@ import json import os import re from typing import TYPE_CHECKING, Any -from urllib.parse import urlparse +from urllib.parse import quote, urlparse import httpx from loguru import logger @@ -182,10 +182,10 @@ class WebSearchTool(Tool): return await self._search_duckduckgo(query, n) try: headers = {"Accept": "application/json", "Authorization": f"Bearer {api_key}"} + encoded_query = quote(query, safe="") async with httpx.AsyncClient(proxy=self.proxy) as client: r = await client.get( - f"https://s.jina.ai/", - params={"q": query}, + f"https://s.jina.ai/{encoded_query}", headers=headers, timeout=15.0, ) @@ -197,7 +197,8 @@ class WebSearchTool(Tool): ] return _format_results(query, items, n) except Exception as e: - return f"Error: {e}" + logger.warning("Jina search failed ({}), falling back to DuckDuckGo", e) + return await self._search_duckduckgo(query, n) async def _search_duckduckgo(self, query: str, n: int) -> str: try: @@ -206,7 +207,10 @@ class WebSearchTool(Tool): from ddgs import DDGS ddgs = DDGS(timeout=10) - raw = await asyncio.to_thread(ddgs.text, query, max_results=n) + raw = await asyncio.wait_for( + asyncio.to_thread(ddgs.text, query, max_results=n), + timeout=self.config.timeout, + ) if not raw: return f"No results for: {query}" items = [ diff --git a/nanobot/channels/base.py b/nanobot/channels/base.py index 86e991344..dd29c0851 100644 --- a/nanobot/channels/base.py +++ b/nanobot/channels/base.py @@ -22,6 +22,7 @@ class BaseChannel(ABC): name: str = "base" 
display_name: str = "Base" + transcription_provider: str = "groq" transcription_api_key: str = "" def __init__(self, config: Any, bus: MessageBus): @@ -37,13 +38,16 @@ class BaseChannel(ABC): self._running = False async def transcribe_audio(self, file_path: str | Path) -> str: - """Transcribe an audio file via Groq Whisper. Returns empty string on failure.""" + """Transcribe an audio file via Whisper (OpenAI or Groq). Returns empty string on failure.""" if not self.transcription_api_key: return "" try: - from nanobot.providers.transcription import GroqTranscriptionProvider - - provider = GroqTranscriptionProvider(api_key=self.transcription_api_key) + if self.transcription_provider == "openai": + from nanobot.providers.transcription import OpenAITranscriptionProvider + provider = OpenAITranscriptionProvider(api_key=self.transcription_api_key) + else: + from nanobot.providers.transcription import GroqTranscriptionProvider + provider = GroqTranscriptionProvider(api_key=self.transcription_api_key) return await provider.transcribe(file_path) except Exception as e: logger.warning("{}: audio transcription failed: {}", self.name, e) diff --git a/nanobot/channels/email.py b/nanobot/channels/email.py index bee2ceccd..f0fcdf9a9 100644 --- a/nanobot/channels/email.py +++ b/nanobot/channels/email.py @@ -12,6 +12,8 @@ from email.header import decode_header, make_header from email.message import EmailMessage from email.parser import BytesParser from email.utils import parseaddr +from fnmatch import fnmatch +from pathlib import Path from typing import Any from loguru import logger @@ -20,7 +22,9 @@ from pydantic import Field from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel +from nanobot.config.paths import get_media_dir from nanobot.config.schema import Base +from nanobot.utils.helpers import safe_filename class EmailConfig(Base): @@ -55,6 +59,11 @@ class EmailConfig(Base): verify_dkim: bool = 
True # Require Authentication-Results with dkim=pass verify_spf: bool = True # Require Authentication-Results with spf=pass + # Attachment handling - set allowed types to enable (e.g. ["application/pdf", "image/*"], or ["*"] for all) + allowed_attachment_types: list[str] = Field(default_factory=list) + max_attachment_size: int = 2_000_000 # 2MB per attachment + max_attachments_per_email: int = 5 + class EmailChannel(BaseChannel): """ @@ -153,6 +162,7 @@ class EmailChannel(BaseChannel): sender_id=sender, chat_id=sender, content=item["content"], + media=item.get("media") or None, metadata=item.get("metadata", {}), ) except Exception as e: @@ -404,6 +414,20 @@ class EmailChannel(BaseChannel): f"{body}" ) + # --- Attachment extraction --- + attachment_paths: list[str] = [] + if self.config.allowed_attachment_types: + saved = self._extract_attachments( + parsed, + uid or "noid", + allowed_types=self.config.allowed_attachment_types, + max_size=self.config.max_attachment_size, + max_count=self.config.max_attachments_per_email, + ) + for p in saved: + attachment_paths.append(str(p)) + content += f"\n[attachment: {p.name} - saved to {p}]" + metadata = { "message_id": message_id, "subject": subject, @@ -418,6 +442,7 @@ "message_id": message_id, "content": content, "metadata": metadata, + "media": attachment_paths, } ) @@ -537,6 +562,61 @@ dkim_pass = True return spf_pass, dkim_pass + @classmethod + def _extract_attachments( + cls, + msg: Any, + uid: str, + *, + allowed_types: list[str], + max_size: int, + max_count: int, + ) -> list[Path]: + """Extract and save email attachments to the media directory. + + Returns list of saved file paths.
+ """ + if not msg.is_multipart(): + return [] + + saved: list[Path] = [] + media_dir = get_media_dir("email") + + for part in msg.walk(): + if len(saved) >= max_count: + break + if part.get_content_disposition() != "attachment": + continue + + content_type = part.get_content_type() + if not any(fnmatch(content_type, pat) for pat in allowed_types): + logger.debug("Email attachment skipped (type {}): not in allowed list", content_type) + continue + + payload = part.get_payload(decode=True) + if payload is None: + continue + if len(payload) > max_size: + logger.warning( + "Email attachment skipped: size {} exceeds limit {}", + len(payload), + max_size, + ) + continue + + raw_name = part.get_filename() or "attachment" + sanitized = safe_filename(raw_name) or "attachment" + dest = media_dir / f"{uid}_{sanitized}" + + try: + dest.write_bytes(payload) + saved.append(dest) + logger.info("Email attachment saved: {}", dest) + except Exception as exc: + logger.warning("Failed to save email attachment {}: {}", dest, exc) + + return saved + @staticmethod def _html_to_text(raw_html: str) -> str: text = re.sub(r"<\s*br\s*/?>", "\n", raw_html, flags=re.IGNORECASE) diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py index 1128c0e16..bac14cb84 100644 --- a/nanobot/channels/feishu.py +++ b/nanobot/channels/feishu.py @@ -1,6 +1,7 @@ """Feishu/Lark channel implementation using lark-oapi SDK with WebSocket long connection.""" import asyncio +import importlib.util import json import os import re @@ -9,19 +10,17 @@ import time import uuid from collections import OrderedDict from dataclasses import dataclass -from pathlib import Path from typing import Any, Literal +from lark_oapi.api.im.v1.model import MentionEvent, P2ImMessageReceiveV1 from loguru import logger +from pydantic import Field from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel from nanobot.config.paths import get_media_dir 
from nanobot.config.schema import Base -from pydantic import Field - -import importlib.util FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None @@ -76,7 +75,9 @@ def _extract_interactive_content(content: dict) -> list[str]: elif isinstance(title, str): parts.append(f"title: {title}") - for elements in content.get("elements", []) if isinstance(content.get("elements"), list) else []: + for elements in ( + content.get("elements", []) if isinstance(content.get("elements"), list) else [] + ): for element in elements: parts.extend(_extract_element_content(element)) @@ -260,6 +261,7 @@ _STREAM_ELEMENT_ID = "streaming_md" @dataclass class _FeishuStreamBuf: """Per-chat streaming accumulator using CardKit streaming API.""" + text: str = "" card_id: str | None = None sequence: int = 0 @@ -288,16 +290,19 @@ class FeishuChannel(BaseChannel): return FeishuConfig().model_dump(by_alias=True) def __init__(self, config: Any, bus: MessageBus): + import lark_oapi as lark + if isinstance(config, dict): config = FeishuConfig.model_validate(config) super().__init__(config, bus) self.config: FeishuConfig = config - self._client: Any = None + self._client: lark.Client = None self._ws_client: Any = None self._ws_thread: threading.Thread | None = None self._processed_message_ids: OrderedDict[str, None] = OrderedDict() # Ordered dedup cache self._loop: asyncio.AbstractEventLoop | None = None self._stream_bufs: dict[str, _FeishuStreamBuf] = {} + self._bot_open_id: str | None = None @staticmethod def _register_optional_event(builder: Any, method_name: str, handler: Any) -> Any: @@ -316,24 +321,28 @@ class FeishuChannel(BaseChannel): return import lark_oapi as lark + self._running = True self._loop = asyncio.get_running_loop() # Create Lark client for sending messages - self._client = lark.Client.builder() \ - .app_id(self.config.app_id) \ - .app_secret(self.config.app_secret) \ - .log_level(lark.LogLevel.INFO) \ + self._client = ( + lark.Client.builder() + 
.app_id(self.config.app_id) + .app_secret(self.config.app_secret) + .log_level(lark.LogLevel.INFO) .build() + ) builder = lark.EventDispatcherHandler.builder( self.config.encrypt_key or "", self.config.verification_token or "", - ).register_p2_im_message_receive_v1( - self._on_message_sync - ) + ).register_p2_im_message_receive_v1(self._on_message_sync) builder = self._register_optional_event( builder, "register_p2_im_message_reaction_created_v1", self._on_reaction_created ) + builder = self._register_optional_event( + builder, "register_p2_im_message_reaction_deleted_v1", self._on_reaction_deleted + ) builder = self._register_optional_event( builder, "register_p2_im_message_message_read_v1", self._on_message_read ) @@ -349,7 +358,7 @@ class FeishuChannel(BaseChannel): self.config.app_id, self.config.app_secret, event_handler=event_handler, - log_level=lark.LogLevel.INFO + log_level=lark.LogLevel.INFO, ) # Start WebSocket client in a separate thread with reconnect loop. @@ -359,7 +368,9 @@ class FeishuChannel(BaseChannel): # "This event loop is already running" errors. 
def run_ws(): import time + import lark_oapi.ws.client as _lark_ws_client + ws_loop = asyncio.new_event_loop() asyncio.set_event_loop(ws_loop) # Patch the module-level loop used by lark's ws Client.start() @@ -378,6 +389,15 @@ class FeishuChannel(BaseChannel): self._ws_thread = threading.Thread(target=run_ws, daemon=True) self._ws_thread.start() + # Fetch bot's own open_id for accurate @mention matching + self._bot_open_id = await asyncio.get_running_loop().run_in_executor( + None, self._fetch_bot_open_id + ) + if self._bot_open_id: + logger.info("Feishu bot open_id: {}", self._bot_open_id) + else: + logger.warning("Could not fetch bot open_id; @mention matching may be inaccurate") + logger.info("Feishu bot started with WebSocket long connection") logger.info("No public IP required - using WebSocket to receive events") @@ -396,6 +416,70 @@ class FeishuChannel(BaseChannel): self._running = False logger.info("Feishu bot stopped") + def _fetch_bot_open_id(self) -> str | None: + """Fetch the bot's own open_id via GET /open-apis/bot/v3/info.""" + try: + import lark_oapi as lark + + request = ( + lark.BaseRequest.builder() + .http_method(lark.HttpMethod.GET) + .uri("/open-apis/bot/v3/info") + .token_types({lark.AccessTokenType.APP}) + .build() + ) + response = self._client.request(request) + if response.success(): + import json + + data = json.loads(response.raw.content) + bot = (data.get("data") or data).get("bot") or data.get("bot") or {} + return bot.get("open_id") + logger.warning("Failed to get bot info: code={}, msg={}", response.code, response.msg) + return None + except Exception as e: + logger.warning("Error fetching bot info: {}", e) + return None + + @staticmethod + def _resolve_mentions(text: str, mentions: list[MentionEvent] | None) -> str: + """Replace @_user_n placeholders with actual user info from mentions. 
+ + Args: + text: The message text containing @_user_n placeholders + mentions: List of mention objects from Feishu message + + Returns: + Text with placeholders replaced by @name (open_id) + """ + if not mentions or not text: + return text + + for mention in mentions: + key = mention.key or None + if not key or key not in text: + continue + + user_id_obj = mention.id or None + if not user_id_obj: + continue + + open_id = user_id_obj.open_id + user_id = user_id_obj.user_id + name = mention.name or key + + # Format: @name (open_id, user_id: xxx) + if open_id and user_id: + replacement = f"@{name} ({open_id}, user id: {user_id})" + elif open_id: + replacement = f"@{name} ({open_id})" + else: + replacement = f"@{name}" + + text = text.replace(key, replacement) + + return text + def _is_bot_mentioned(self, message: Any) -> bool: """Check if the bot is @mentioned in the message.""" raw_content = message.content or "" @@ -406,9 +490,14 @@ mid = getattr(mention, "id", None) if not mid: continue - # Bot mentions have no user_id (None or "") but a valid open_id - if not getattr(mid, "user_id", None) and (getattr(mid, "open_id", None) or "").startswith("ou_"): - return True + mention_open_id = getattr(mid, "open_id", None) or "" + if self._bot_open_id: + if mention_open_id == self._bot_open_id: + return True + else: + # Fallback heuristic when bot open_id is unavailable + if not getattr(mid, "user_id", None) and mention_open_id.startswith("ou_"): + return True return False def _is_group_message_for_bot(self, message: Any) -> bool: @@ -419,20 +508,30 @@ def _add_reaction_sync(self, message_id: str, emoji_type: str) -> str | None: """Sync helper for adding reaction (runs in thread pool).""" - from lark_oapi.api.im.v1 import CreateMessageReactionRequest, CreateMessageReactionRequestBody, Emoji + from lark_oapi.api.im.v1 import ( + CreateMessageReactionRequest, + CreateMessageReactionRequestBody, + Emoji, + ) + try:
- request = CreateMessageReactionRequest.builder() \ - .message_id(message_id) \ + request = ( + CreateMessageReactionRequest.builder() + .message_id(message_id) .request_body( CreateMessageReactionRequestBody.builder() .reaction_type(Emoji.builder().emoji_type(emoji_type).build()) .build() - ).build() + ) + .build() + ) response = self._client.im.v1.message_reaction.create(request) if not response.success(): - logger.warning("Failed to add reaction: code={}, msg={}", response.code, response.msg) + logger.warning( + "Failed to add reaction: code={}, msg={}", response.code, response.msg + ) return None else: logger.debug("Added {} reaction to message {}", emoji_type, message_id) @@ -456,17 +555,22 @@ class FeishuChannel(BaseChannel): def _remove_reaction_sync(self, message_id: str, reaction_id: str) -> None: """Sync helper for removing reaction (runs in thread pool).""" from lark_oapi.api.im.v1 import DeleteMessageReactionRequest + try: - request = DeleteMessageReactionRequest.builder() \ - .message_id(message_id) \ - .reaction_id(reaction_id) \ + request = ( + DeleteMessageReactionRequest.builder() + .message_id(message_id) + .reaction_id(reaction_id) .build() + ) response = self._client.im.v1.message_reaction.delete(request) if response.success(): logger.debug("Removed reaction {} from message {}", reaction_id, message_id) else: - logger.debug("Failed to remove reaction: code={}, msg={}", response.code, response.msg) + logger.debug( + "Failed to remove reaction: code={}, msg={}", response.code, response.msg + ) except Exception as e: logger.debug("Error removing reaction: {}", e) @@ -521,27 +625,35 @@ class FeishuChannel(BaseChannel): lines = [_line.strip() for _line in table_text.strip().split("\n") if _line.strip()] if len(lines) < 3: return None + def split(_line: str) -> list[str]: return [c.strip() for c in _line.strip("|").split("|")] + headers = [cls._strip_md_formatting(h) for h in split(lines[0])] rows = [[cls._strip_md_formatting(c) for c in 
split(_line)] for _line in lines[2:]] - columns = [{"tag": "column", "name": f"c{i}", "display_name": h, "width": "auto"} - for i, h in enumerate(headers)] + columns = [ + {"tag": "column", "name": f"c{i}", "display_name": h, "width": "auto"} + for i, h in enumerate(headers) + ] return { "tag": "table", "page_size": len(rows) + 1, "columns": columns, - "rows": [{f"c{i}": r[i] if i < len(r) else "" for i in range(len(headers))} for r in rows], + "rows": [ + {f"c{i}": r[i] if i < len(r) else "" for i in range(len(headers))} for r in rows + ], } def _build_card_elements(self, content: str) -> list[dict]: """Split content into div/markdown + table elements for Feishu card.""" elements, last_end = [], 0 for m in self._TABLE_RE.finditer(content): - before = content[last_end:m.start()] + before = content[last_end : m.start()] if before.strip(): elements.extend(self._split_headings(before)) - elements.append(self._parse_md_table(m.group(1)) or {"tag": "markdown", "content": m.group(1)}) + elements.append( + self._parse_md_table(m.group(1)) or {"tag": "markdown", "content": m.group(1)} + ) last_end = m.end() remaining = content[last_end:] if remaining.strip(): @@ -549,7 +661,9 @@ class FeishuChannel(BaseChannel): return elements or [{"tag": "markdown", "content": content}] @staticmethod - def _split_elements_by_table_limit(elements: list[dict], max_tables: int = 1) -> list[list[dict]]: + def _split_elements_by_table_limit( + elements: list[dict], max_tables: int = 1 + ) -> list[list[dict]]: """Split card elements into groups with at most *max_tables* table elements each. Feishu cards have a hard limit of one table per card (API error 11310). 
@@ -582,23 +696,25 @@ class FeishuChannel(BaseChannel): code_blocks = [] for m in self._CODE_BLOCK_RE.finditer(content): code_blocks.append(m.group(1)) - protected = protected.replace(m.group(1), f"\x00CODE{len(code_blocks)-1}\x00", 1) + protected = protected.replace(m.group(1), f"\x00CODE{len(code_blocks) - 1}\x00", 1) elements = [] last_end = 0 for m in self._HEADING_RE.finditer(protected): - before = protected[last_end:m.start()].strip() + before = protected[last_end : m.start()].strip() if before: elements.append({"tag": "markdown", "content": before}) text = self._strip_md_formatting(m.group(2).strip()) display_text = f"**{text}**" if text else "" - elements.append({ - "tag": "div", - "text": { - "tag": "lark_md", - "content": display_text, - }, - }) + elements.append( + { + "tag": "div", + "text": { + "tag": "lark_md", + "content": display_text, + }, + } + ) last_end = m.end() remaining = protected[last_end:].strip() if remaining: @@ -614,19 +730,19 @@ class FeishuChannel(BaseChannel): # ── Smart format detection ────────────────────────────────────────── # Patterns that indicate "complex" markdown needing card rendering _COMPLEX_MD_RE = re.compile( - r"```" # fenced code block + r"```" # fenced code block r"|^\|.+\|.*\n\s*\|[-:\s|]+\|" # markdown table (header + separator) - r"|^#{1,6}\s+" # headings - , re.MULTILINE, + r"|^#{1,6}\s+", # headings + re.MULTILINE, ) # Simple markdown patterns (bold, italic, strikethrough) _SIMPLE_MD_RE = re.compile( - r"\*\*.+?\*\*" # **bold** - r"|__.+?__" # __bold__ + r"\*\*.+?\*\*" # **bold** + r"|__.+?__" # __bold__ r"|(? 
str | None: """Upload an image to Feishu and return the image_key.""" from lark_oapi.api.im.v1 import CreateImageRequest, CreateImageRequestBody + try: with open(file_path, "rb") as f: - request = CreateImageRequest.builder() \ + request = ( + CreateImageRequest.builder() .request_body( - CreateImageRequestBody.builder() - .image_type("message") - .image(f) - .build() - ).build() + CreateImageRequestBody.builder().image_type("message").image(f).build() + ) + .build() + ) response = self._client.im.v1.image.create(request) if response.success(): image_key = response.data.image_key logger.debug("Uploaded image {}: {}", os.path.basename(file_path), image_key) return image_key else: - logger.error("Failed to upload image: code={}, msg={}", response.code, response.msg) + logger.error( + "Failed to upload image: code={}, msg={}", response.code, response.msg + ) return None except Exception as e: logger.error("Error uploading image {}: {}", file_path, e) @@ -761,49 +889,62 @@ class FeishuChannel(BaseChannel): def _upload_file_sync(self, file_path: str) -> str | None: """Upload a file to Feishu and return the file_key.""" from lark_oapi.api.im.v1 import CreateFileRequest, CreateFileRequestBody + ext = os.path.splitext(file_path)[1].lower() file_type = self._FILE_TYPE_MAP.get(ext, "stream") file_name = os.path.basename(file_path) try: with open(file_path, "rb") as f: - request = CreateFileRequest.builder() \ + request = ( + CreateFileRequest.builder() .request_body( CreateFileRequestBody.builder() .file_type(file_type) .file_name(file_name) .file(f) .build() - ).build() + ) + .build() + ) response = self._client.im.v1.file.create(request) if response.success(): file_key = response.data.file_key logger.debug("Uploaded file {}: {}", file_name, file_key) return file_key else: - logger.error("Failed to upload file: code={}, msg={}", response.code, response.msg) + logger.error( + "Failed to upload file: code={}, msg={}", response.code, response.msg + ) return None except 
Exception as e: logger.error("Error uploading file {}: {}", file_path, e) return None - def _download_image_sync(self, message_id: str, image_key: str) -> tuple[bytes | None, str | None]: + def _download_image_sync( + self, message_id: str, image_key: str + ) -> tuple[bytes | None, str | None]: """Download an image from Feishu message by message_id and image_key.""" from lark_oapi.api.im.v1 import GetMessageResourceRequest + try: - request = GetMessageResourceRequest.builder() \ - .message_id(message_id) \ - .file_key(image_key) \ - .type("image") \ + request = ( + GetMessageResourceRequest.builder() + .message_id(message_id) + .file_key(image_key) + .type("image") .build() + ) response = self._client.im.v1.message_resource.get(request) if response.success(): file_data = response.file # GetMessageResourceRequest returns BytesIO, need to read bytes - if hasattr(file_data, 'read'): + if hasattr(file_data, "read"): file_data = file_data.read() return file_data, response.file_name else: - logger.error("Failed to download image: code={}, msg={}", response.code, response.msg) + logger.error( + "Failed to download image: code={}, msg={}", response.code, response.msg + ) return None, None except Exception as e: logger.error("Error downloading image {}: {}", image_key, e) @@ -835,17 +976,19 @@ class FeishuChannel(BaseChannel): file_data = file_data.read() return file_data, response.file_name else: - logger.error("Failed to download {}: code={}, msg={}", resource_type, response.code, response.msg) + logger.error( + "Failed to download {}: code={}, msg={}", + resource_type, + response.code, + response.msg, + ) return None, None except Exception: logger.exception("Error downloading {} {}", resource_type, file_key) return None, None async def _download_and_save_media( - self, - msg_type: str, - content_json: dict, - message_id: str | None = None + self, msg_type: str, content_json: dict, message_id: str | None = None ) -> tuple[str | None, str]: """ Download media from Feishu 
and save to local disk. @@ -894,13 +1037,16 @@ class FeishuChannel(BaseChannel): Returns a "[Reply to: ...]" context string, or None on failure. """ from lark_oapi.api.im.v1 import GetMessageRequest + try: request = GetMessageRequest.builder().message_id(message_id).build() response = self._client.im.v1.message.get(request) if not response.success(): logger.debug( "Feishu: could not fetch parent message {}: code={}, msg={}", - message_id, response.code, response.msg, + message_id, + response.code, + response.msg, ) return None items = getattr(response.data, "items", None) @@ -935,20 +1081,24 @@ class FeishuChannel(BaseChannel): def _reply_message_sync(self, parent_message_id: str, msg_type: str, content: str) -> bool: """Reply to an existing Feishu message using the Reply API (synchronous).""" from lark_oapi.api.im.v1 import ReplyMessageRequest, ReplyMessageRequestBody + try: - request = ReplyMessageRequest.builder() \ - .message_id(parent_message_id) \ + request = ( + ReplyMessageRequest.builder() + .message_id(parent_message_id) .request_body( - ReplyMessageRequestBody.builder() - .msg_type(msg_type) - .content(content) - .build() - ).build() + ReplyMessageRequestBody.builder().msg_type(msg_type).content(content).build() + ) + .build() + ) response = self._client.im.v1.message.reply(request) if not response.success(): logger.error( "Failed to reply to Feishu message {}: code={}, msg={}, log_id={}", - parent_message_id, response.code, response.msg, response.get_log_id() + parent_message_id, + response.code, + response.msg, + response.get_log_id(), ) return False logger.debug("Feishu reply sent to message {}", parent_message_id) @@ -957,24 +1107,33 @@ class FeishuChannel(BaseChannel): logger.error("Error replying to Feishu message {}: {}", parent_message_id, e) return False - def _send_message_sync(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> str | None: + def _send_message_sync( + self, receive_id_type: str, receive_id: str, 
msg_type: str, content: str + ) -> str | None: """Send a single message and return the message_id on success.""" from lark_oapi.api.im.v1 import CreateMessageRequest, CreateMessageRequestBody + try: - request = CreateMessageRequest.builder() \ - .receive_id_type(receive_id_type) \ + request = ( + CreateMessageRequest.builder() + .receive_id_type(receive_id_type) .request_body( CreateMessageRequestBody.builder() .receive_id(receive_id) .msg_type(msg_type) .content(content) .build() - ).build() + ) + .build() + ) response = self._client.im.v1.message.create(request) if not response.success(): logger.error( "Failed to send Feishu {} message: code={}, msg={}, log_id={}", - msg_type, response.code, response.msg, response.get_log_id() + msg_type, + response.code, + response.msg, + response.get_log_id(), ) return None msg_id = getattr(response.data, "message_id", None) @@ -987,31 +1146,44 @@ class FeishuChannel(BaseChannel): def _create_streaming_card_sync(self, receive_id_type: str, chat_id: str) -> str | None: """Create a CardKit streaming card, send it to chat, return card_id.""" from lark_oapi.api.cardkit.v1 import CreateCardRequest, CreateCardRequestBody + card_json = { "schema": "2.0", "config": {"wide_screen_mode": True, "update_multi": True, "streaming_mode": True}, - "body": {"elements": [{"tag": "markdown", "content": "", "element_id": _STREAM_ELEMENT_ID}]}, + "body": { + "elements": [{"tag": "markdown", "content": "", "element_id": _STREAM_ELEMENT_ID}] + }, } try: - request = CreateCardRequest.builder().request_body( - CreateCardRequestBody.builder() - .type("card_json") - .data(json.dumps(card_json, ensure_ascii=False)) + request = ( + CreateCardRequest.builder() + .request_body( + CreateCardRequestBody.builder() + .type("card_json") + .data(json.dumps(card_json, ensure_ascii=False)) + .build() + ) .build() - ).build() + ) response = self._client.cardkit.v1.card.create(request) if not response.success(): - logger.warning("Failed to create streaming card: 
code={}, msg={}", response.code, response.msg) + logger.warning( + "Failed to create streaming card: code={}, msg={}", response.code, response.msg + ) return None card_id = getattr(response.data, "card_id", None) if card_id: message_id = self._send_message_sync( - receive_id_type, chat_id, "interactive", + receive_id_type, + chat_id, + "interactive", json.dumps({"type": "card", "data": {"card_id": card_id}}), ) if message_id: return card_id - logger.warning("Created streaming card {} but failed to send it to {}", card_id, chat_id) + logger.warning( + "Created streaming card {} but failed to send it to {}", card_id, chat_id + ) return None except Exception as e: logger.warning("Error creating streaming card: {}", e) @@ -1019,18 +1191,32 @@ class FeishuChannel(BaseChannel): def _stream_update_text_sync(self, card_id: str, content: str, sequence: int) -> bool: """Stream-update the markdown element on a CardKit card (typewriter effect).""" - from lark_oapi.api.cardkit.v1 import ContentCardElementRequest, ContentCardElementRequestBody + from lark_oapi.api.cardkit.v1 import ( + ContentCardElementRequest, + ContentCardElementRequestBody, + ) + try: - request = ContentCardElementRequest.builder() \ - .card_id(card_id) \ - .element_id(_STREAM_ELEMENT_ID) \ + request = ( + ContentCardElementRequest.builder() + .card_id(card_id) + .element_id(_STREAM_ELEMENT_ID) .request_body( ContentCardElementRequestBody.builder() - .content(content).sequence(sequence).build() - ).build() + .content(content) + .sequence(sequence) + .build() + ) + .build() + ) response = self._client.cardkit.v1.card_element.content(request) if not response.success(): - logger.warning("Failed to stream-update card {}: code={}, msg={}", card_id, response.code, response.msg) + logger.warning( + "Failed to stream-update card {}: code={}, msg={}", + card_id, + response.code, + response.msg, + ) return False return True except Exception as e: @@ -1045,22 +1231,28 @@ class FeishuChannel(BaseChannel): Sequence must 
strictly exceed the previous card OpenAPI operation on this entity. """ from lark_oapi.api.cardkit.v1 import SettingsCardRequest, SettingsCardRequestBody + settings_payload = json.dumps({"config": {"streaming_mode": False}}, ensure_ascii=False) try: - request = SettingsCardRequest.builder() \ - .card_id(card_id) \ + request = ( + SettingsCardRequest.builder() + .card_id(card_id) .request_body( SettingsCardRequestBody.builder() .settings(settings_payload) .sequence(sequence) .uuid(str(uuid.uuid4())) .build() - ).build() + ) + .build() + ) response = self._client.cardkit.v1.card.settings(request) if not response.success(): logger.warning( "Failed to close streaming on card {}: code={}, msg={}", - card_id, response.code, response.msg, + card_id, + response.code, + response.msg, ) return False return True @@ -1068,7 +1260,9 @@ class FeishuChannel(BaseChannel): logger.warning("Error closing streaming on card {}: {}", card_id, e) return False - async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None: + async def send_delta( + self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None + ) -> None: """Progressive streaming via CardKit: create card on first delta, stream-update on subsequent.""" if not self._client: return @@ -1087,17 +1281,31 @@ class FeishuChannel(BaseChannel): if buf.card_id: buf.sequence += 1 await loop.run_in_executor( - None, self._stream_update_text_sync, buf.card_id, buf.text, buf.sequence, + None, + self._stream_update_text_sync, + buf.card_id, + buf.text, + buf.sequence, ) # Required so the chat list preview exits the streaming placeholder (Feishu streaming card docs). 
buf.sequence += 1 await loop.run_in_executor( - None, self._close_streaming_mode_sync, buf.card_id, buf.sequence, + None, + self._close_streaming_mode_sync, + buf.card_id, + buf.sequence, ) else: - for chunk in self._split_elements_by_table_limit(self._build_card_elements(buf.text)): - card = json.dumps({"config": {"wide_screen_mode": True}, "elements": chunk}, ensure_ascii=False) - await loop.run_in_executor(None, self._send_message_sync, rid_type, chat_id, "interactive", card) + for chunk in self._split_elements_by_table_limit( + self._build_card_elements(buf.text) + ): + card = json.dumps( + {"config": {"wide_screen_mode": True}, "elements": chunk}, + ensure_ascii=False, + ) + await loop.run_in_executor( + None, self._send_message_sync, rid_type, chat_id, "interactive", card + ) return # --- accumulate delta --- @@ -1111,15 +1319,21 @@ class FeishuChannel(BaseChannel): now = time.monotonic() if buf.card_id is None: - card_id = await loop.run_in_executor(None, self._create_streaming_card_sync, rid_type, chat_id) + card_id = await loop.run_in_executor( + None, self._create_streaming_card_sync, rid_type, chat_id + ) if card_id: buf.card_id = card_id buf.sequence = 1 - await loop.run_in_executor(None, self._stream_update_text_sync, card_id, buf.text, 1) + await loop.run_in_executor( + None, self._stream_update_text_sync, card_id, buf.text, 1 + ) buf.last_edit = now elif (now - buf.last_edit) >= self._STREAM_EDIT_INTERVAL: buf.sequence += 1 - await loop.run_in_executor(None, self._stream_update_text_sync, buf.card_id, buf.text, buf.sequence) + await loop.run_in_executor( + None, self._stream_update_text_sync, buf.card_id, buf.text, buf.sequence + ) buf.last_edit = now async def send(self, msg: OutboundMessage) -> None: @@ -1145,14 +1359,13 @@ class FeishuChannel(BaseChannel): # Only the very first send (media or text) in this call uses reply; subsequent # chunks/media fall back to plain create to avoid redundant quote bubbles. 
reply_message_id: str | None = None - if ( - self.config.reply_to_message - and not msg.metadata.get("_progress", False) - ): + if self.config.reply_to_message and not msg.metadata.get("_progress", False): reply_message_id = msg.metadata.get("message_id") or None # For topic group messages, always reply to keep context in thread elif msg.metadata.get("thread_id"): - reply_message_id = msg.metadata.get("root_id") or msg.metadata.get("message_id") or None + reply_message_id = ( + msg.metadata.get("root_id") or msg.metadata.get("message_id") or None + ) first_send = True # tracks whether the reply has already been used @@ -1176,8 +1389,10 @@ class FeishuChannel(BaseChannel): key = await loop.run_in_executor(None, self._upload_image_sync, file_path) if key: await loop.run_in_executor( - None, _do_send, - "image", json.dumps({"image_key": key}, ensure_ascii=False), + None, + _do_send, + "image", + json.dumps({"image_key": key}, ensure_ascii=False), ) else: key = await loop.run_in_executor(None, self._upload_file_sync, file_path) @@ -1192,8 +1407,10 @@ class FeishuChannel(BaseChannel): else: media_type = "file" await loop.run_in_executor( - None, _do_send, - media_type, json.dumps({"file_key": key}, ensure_ascii=False), + None, + _do_send, + media_type, + json.dumps({"file_key": key}, ensure_ascii=False), ) if msg.content and msg.content.strip(): @@ -1215,8 +1432,10 @@ class FeishuChannel(BaseChannel): for chunk in self._split_elements_by_table_limit(elements): card = {"config": {"wide_screen_mode": True}, "elements": chunk} await loop.run_in_executor( - None, _do_send, - "interactive", json.dumps(card, ensure_ascii=False), + None, + _do_send, + "interactive", + json.dumps(card, ensure_ascii=False), ) except Exception as e: @@ -1231,13 +1450,16 @@ class FeishuChannel(BaseChannel): if self._loop and self._loop.is_running(): asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop) - async def _on_message(self, data: Any) -> None: + async def _on_message(self, 
data: P2ImMessageReceiveV1) -> None: """Handle incoming message from Feishu.""" try: event = data.event message = event.message sender = event.sender - + + logger.debug("Feishu raw message: {}", message.content) + logger.debug("Feishu mentions: {}", getattr(message, "mentions", None)) + # Deduplication check message_id = message.message_id if message_id in self._processed_message_ids: @@ -1276,6 +1498,8 @@ class FeishuChannel(BaseChannel): if msg_type == "text": text = content_json.get("text", "") if text: + mentions = getattr(message, "mentions", None) + text = self._resolve_mentions(text, mentions) content_parts.append(text) elif msg_type == "post": @@ -1292,7 +1516,9 @@ class FeishuChannel(BaseChannel): content_parts.append(content_text) elif msg_type in ("image", "audio", "file", "media"): - file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id) + file_path, content_text = await self._download_and_save_media( + msg_type, content_json, message_id + ) if file_path: media_paths.append(file_path) @@ -1303,7 +1529,14 @@ class FeishuChannel(BaseChannel): content_parts.append(content_text) - elif msg_type in ("share_chat", "share_user", "interactive", "share_calendar_event", "system", "merge_forward"): + elif msg_type in ( + "share_chat", + "share_user", + "interactive", + "share_calendar_event", + "system", + "merge_forward", + ): # Handle share cards and interactive messages text = _extract_share_card_content(content_json, msg_type) if text: @@ -1346,7 +1579,7 @@ class FeishuChannel(BaseChannel): "parent_id": parent_id, "root_id": root_id, "thread_id": thread_id, - } + }, ) except Exception as e: @@ -1356,6 +1589,10 @@ class FeishuChannel(BaseChannel): """Ignore reaction events so they do not generate SDK noise.""" pass + def _on_reaction_deleted(self, data: Any) -> None: + """Ignore reaction deleted events so they do not generate SDK noise.""" + pass + def _on_message_read(self, data: Any) -> None: """Ignore read events so 
they do not generate SDK noise.""" pass @@ -1411,7 +1648,9 @@ class FeishuChannel(BaseChannel): return "\n".join(part for part in parts if part) - async def _send_tool_hint_card(self, receive_id_type: str, receive_id: str, tool_hint: str) -> None: + async def _send_tool_hint_card( + self, receive_id_type: str, receive_id: str, tool_hint: str + ) -> None: """Send tool hint as an interactive card with formatted code block. Args: @@ -1427,15 +1666,15 @@ class FeishuChannel(BaseChannel): card = { "config": {"wide_screen_mode": True}, "elements": [ - { - "tag": "markdown", - "content": f"**Tool Calls**\n\n```text\n{formatted_code}\n```" - } - ] + {"tag": "markdown", "content": f"**Tool Calls**\n\n```text\n{formatted_code}\n```"} + ], } await loop.run_in_executor( - None, self._send_message_sync, - receive_id_type, receive_id, "interactive", + None, + self._send_message_sync, + receive_id_type, + receive_id, + "interactive", json.dumps(card, ensure_ascii=False), ) diff --git a/nanobot/channels/manager.py b/nanobot/channels/manager.py index 1f26f4d7a..aaec5e335 100644 --- a/nanobot/channels/manager.py +++ b/nanobot/channels/manager.py @@ -39,7 +39,8 @@ class ChannelManager: """Initialize channels discovered via pkgutil scan + entry_points plugins.""" from nanobot.channels.registry import discover_all - groq_key = self.config.providers.groq.api_key + transcription_provider = self.config.channels.transcription_provider + transcription_key = self._resolve_transcription_key(transcription_provider) for name, cls in discover_all().items(): section = getattr(self.config.channels, name, None) @@ -54,7 +55,8 @@ class ChannelManager: continue try: channel = cls(section, self.bus) - channel.transcription_api_key = groq_key + channel.transcription_provider = transcription_provider + channel.transcription_api_key = transcription_key self.channels[name] = channel logger.info("{} channel enabled", cls.display_name) except Exception as e: @@ -62,6 +64,15 @@ class ChannelManager: 
self._validate_allow_from() + def _resolve_transcription_key(self, provider: str) -> str: + """Pick the API key for the configured transcription provider.""" + try: + if provider == "openai": + return self.config.providers.openai.api_key + return self.config.providers.groq.api_key + except AttributeError: + return "" + def _validate_allow_from(self) -> None: for name, ch in self.channels.items(): if getattr(ch.config, "allow_from", None) == []: diff --git a/nanobot/channels/matrix.py b/nanobot/channels/matrix.py index bc6d9398a..85a167a3a 100644 --- a/nanobot/channels/matrix.py +++ b/nanobot/channels/matrix.py @@ -1,6 +1,7 @@ """Matrix (Element) channel β€” inbound sync + outbound message/media delivery.""" import asyncio +import json import logging import mimetypes import time @@ -17,10 +18,10 @@ try: from nio import ( AsyncClient, AsyncClientConfig, - ContentRepositoryConfigError, DownloadError, InviteEvent, JoinError, + LoginResponse, MatrixRoom, MemoryDownloadResponse, RoomEncryptedMedia, @@ -203,10 +204,11 @@ class MatrixConfig(Base): enabled: bool = False homeserver: str = "https://matrix.org" - access_token: str = "" user_id: str = "" + password: str = "" + access_token: str = "" device_id: str = "" - e2ee_enabled: bool = True + e2ee_enabled: bool = Field(default=True, alias="e2eeEnabled") sync_stop_grace_seconds: int = 2 max_media_bytes: int = 20 * 1024 * 1024 allow_from: list[str] = Field(default_factory=list) @@ -256,17 +258,15 @@ class MatrixChannel(BaseChannel): self._running = True _configure_nio_logging_bridge() - store_path = get_data_dir() / "matrix-store" - store_path.mkdir(parents=True, exist_ok=True) + self.store_path = get_data_dir() / "matrix-store" + self.store_path.mkdir(parents=True, exist_ok=True) + self.session_path = self.store_path / "session.json" self.client = AsyncClient( homeserver=self.config.homeserver, user=self.config.user_id, - store_path=store_path, + store_path=self.store_path, config=AsyncClientConfig(store_sync_tokens=True, 
encryption_enabled=self.config.e2ee_enabled), ) - self.client.user_id = self.config.user_id - self.client.access_token = self.config.access_token - self.client.device_id = self.config.device_id self._register_event_callbacks() self._register_response_callbacks() @@ -274,13 +274,49 @@ class MatrixChannel(BaseChannel): if not self.config.e2ee_enabled: logger.warning("Matrix E2EE disabled; encrypted rooms may be undecryptable.") - if self.config.device_id: + if self.config.password: + if self.config.access_token or self.config.device_id: + logger.warning("Password-based Matrix login active; access_token and device_id fields will be ignored.") + + create_new_session = True + if self.session_path.exists(): + logger.info("Found session.json at {}; attempting to use existing session...", self.session_path) + try: + with open(self.session_path, "r", encoding="utf-8") as f: + session = json.load(f) + self.client.user_id = self.config.user_id + self.client.access_token = session["access_token"] + self.client.device_id = session["device_id"] + self.client.load_store() + logger.info("Successfully loaded from existing session") + create_new_session = False + except Exception as e: + logger.warning("Failed to load from existing session: {}", e) + logger.info("Falling back to password login...") + + if create_new_session: + logger.info("Using password login...") + resp = await self.client.login(self.config.password) + if isinstance(resp, LoginResponse): + logger.info("Logged in using a password; saving details to disk") + self._write_session_to_disk(resp) + else: + logger.error("Failed to log in: {}", resp) + return + + elif self.config.access_token and self.config.device_id: try: + self.client.user_id = self.config.user_id + self.client.access_token = self.config.access_token + self.client.device_id = self.config.device_id self.client.load_store() - except Exception: - logger.exception("Matrix store load failed; restart may replay recent messages.") + logger.info("Successfully 
loaded from existing session") + except Exception as e: + logger.warning("Failed to load from existing session: {}", e) + else: - logger.warning("Matrix device_id empty; restart may replay recent messages.") + logger.warning("Unable to load a Matrix session due to missing password, access_token, or device_id; encryption may not work") + return self._sync_task = asyncio.create_task(self._sync_loop()) @@ -304,6 +340,19 @@ class MatrixChannel(BaseChannel): if self.client: await self.client.close() + def _write_session_to_disk(self, resp: LoginResponse) -> None: + """Save login session to disk for persistence across restarts.""" + session = { + "access_token": resp.access_token, + "device_id": resp.device_id, + } + try: + with open(self.session_path, "w", encoding="utf-8") as f: + json.dump(session, f, indent=2) + logger.info("Session saved to {}", self.session_path) + except Exception as e: + logger.warning("Failed to save session: {}", e) + def _is_workspace_path_allowed(self, path: Path) -> bool: """Check path is inside workspace (when restriction enabled).""" if not self._restrict_to_workspace or not self._workspace: diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py index 1aa0568c6..aeb36d8e4 100644 --- a/nanobot/channels/telegram.py +++ b/nanobot/channels/telegram.py @@ -6,14 +6,14 @@ import asyncio import re import time import unicodedata -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, Literal from loguru import logger from pydantic import Field from telegram import BotCommand, ReactionTypeEmoji, ReplyParameters, Update from telegram.error import BadRequest, NetworkError, TimedOut -from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters +from telegram.ext import Application, ContextTypes, MessageHandler, filters from telegram.request import HTTPXRequest from nanobot.bus.events import OutboundMessage @@ -558,8 +558,10 @@ class 
TelegramChannel(BaseChannel): await self._remove_reaction(chat_id, int(reply_to_message_id)) except ValueError: pass + chunks = split_message(buf.text, TELEGRAM_MAX_MESSAGE_LEN) + primary_text = chunks[0] if chunks else buf.text try: - html = _markdown_to_telegram_html(buf.text) + html = _markdown_to_telegram_html(primary_text) await self._call_with_retry( self._app.bot.edit_message_text, chat_id=int_chat_id, message_id=buf.message_id, @@ -575,15 +577,18 @@ class TelegramChannel(BaseChannel): await self._call_with_retry( self._app.bot.edit_message_text, chat_id=int_chat_id, message_id=buf.message_id, - text=buf.text, + text=primary_text, ) except Exception as e2: if self._is_not_modified_error(e2): logger.debug("Final stream plain edit already applied for {}", chat_id) - self._stream_bufs.pop(chat_id, None) - return - logger.warning("Final stream edit failed: {}", e2) - raise # Let ChannelManager handle retry + else: + logger.warning("Final stream edit failed: {}", e2) + raise # Let ChannelManager handle retry + # If final content exceeds Telegram limit, keep the first chunk in + # the edited stream message and send the rest as follow-up messages. 
+ for extra_chunk in chunks[1:]: + await self._send_text(int_chat_id, extra_chunk) self._stream_bufs.pop(chat_id, None) return @@ -599,11 +604,15 @@ class TelegramChannel(BaseChannel): return now = time.monotonic() + thread_kwargs = {} + if message_thread_id := meta.get("message_thread_id"): + thread_kwargs["message_thread_id"] = message_thread_id if buf.message_id is None: try: sent = await self._call_with_retry( self._app.bot.send_message, chat_id=int_chat_id, text=buf.text, + **thread_kwargs, ) buf.message_id = sent.message_id buf.last_edit = now @@ -651,9 +660,9 @@ class TelegramChannel(BaseChannel): @staticmethod def _derive_topic_session_key(message) -> str | None: - """Derive topic-scoped session key for non-private Telegram chats.""" + """Derive topic-scoped session key for Telegram chats with threads.""" message_thread_id = getattr(message, "message_thread_id", None) - if message.chat.type == "private" or message_thread_id is None: + if message_thread_id is None: return None return f"telegram:{message.chat_id}:topic:{message_thread_id}" @@ -815,7 +824,7 @@ class TelegramChannel(BaseChannel): return bool(bot_id and reply_user and reply_user.id == bot_id) def _remember_thread_context(self, message) -> None: - """Cache topic thread id by chat/message id for follow-up replies.""" + """Cache Telegram thread context by chat/message id for follow-up replies.""" message_thread_id = getattr(message, "message_thread_id", None) if message_thread_id is None: return diff --git a/nanobot/channels/weixin.py b/nanobot/channels/weixin.py index 2266bc9f0..3f87e2203 100644 --- a/nanobot/channels/weixin.py +++ b/nanobot/channels/weixin.py @@ -484,7 +484,7 @@ class WeixinChannel(BaseChannel): except httpx.TimeoutException: # Normal for long-poll, just retry continue - except Exception as e: + except Exception: if not self._running: break consecutive_failures += 1 diff --git a/nanobot/channels/whatsapp.py b/nanobot/channels/whatsapp.py index a788dd727..a7fd82654 100644 --- 
a/nanobot/channels/whatsapp.py +++ b/nanobot/channels/whatsapp.py @@ -75,6 +75,7 @@ class WhatsAppChannel(BaseChannel): self._ws = None self._connected = False self._processed_message_ids: OrderedDict[str, None] = OrderedDict() + self._lid_to_phone: dict[str, str] = {} self._bridge_token: str | None = None def _effective_bridge_token(self) -> str: @@ -228,21 +229,45 @@ class WhatsAppChannel(BaseChannel): if not was_mentioned: return - user_id = pn if pn else sender - sender_id = user_id.split("@")[0] if "@" in user_id else user_id - logger.info("Sender {}", sender) + # Classify by JID suffix: @s.whatsapp.net = phone, @lid (or @lid.whatsapp.net) = LID + # The bridge's pn/sender fields don't consistently map to phone/LID across versions. + raw_a = pn or "" + raw_b = sender or "" + id_a = raw_a.split("@")[0] if "@" in raw_a else raw_a + id_b = raw_b.split("@")[0] if "@" in raw_b else raw_b - # Handle voice transcription if it's a voice message - if content == "[Voice Message]": - logger.info( - "Voice message received from {}, but direct download from bridge is not yet supported.", - sender_id, - ) - content = "[Voice Message: Transcription not available for WhatsApp yet]" + phone_id = "" + lid_id = "" + for raw, extracted in [(raw_a, id_a), (raw_b, id_b)]: + if "@s.whatsapp.net" in raw: + phone_id = extracted + elif "@lid" in raw: # matches both "@lid" and "@lid.whatsapp.net" forms + lid_id = extracted + elif extracted and not phone_id: + phone_id = extracted # best guess for bare values + + if phone_id and lid_id: + self._lid_to_phone[lid_id] = phone_id + sender_id = phone_id or self._lid_to_phone.get(lid_id, "") or lid_id or id_a or id_b + + logger.info("Sender phone={} lid={} → sender_id={}", phone_id or "(empty)", lid_id or "(empty)", sender_id) # Extract media paths (images/documents/videos downloaded by the bridge) media_paths = data.get("media") or [] + # Handle voice transcription if it's a voice message + if content == "[Voice Message]": + if media_paths: + logger.info("Transcribing voice message from
{}...", sender_id) + transcription = await self.transcribe_audio(media_paths[0]) + if transcription: + content = transcription + logger.info("Transcribed voice from {}: {}...", sender_id, transcription[:50]) + else: + content = "[Voice Message: Transcription failed]" + else: + content = "[Voice Message: Audio not available]" + # Build content tags matching Telegram's pattern: [image: /path] or [file: /path] if media_paths: for p in media_paths: diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index dfb13ba97..a1fb7c0e0 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -1,12 +1,11 @@ """CLI commands for nanobot.""" import asyncio -from contextlib import contextmanager, nullcontext - import os import select import signal import sys +from contextlib import nullcontext from pathlib import Path from typing import Any @@ -34,6 +33,19 @@ from rich.table import Table from rich.text import Text from nanobot import __logo__, __version__ + + +class SafeFileHistory(FileHistory): + """FileHistory subclass that sanitizes surrogate characters on write. + + On Windows, special Unicode input (emoji, mixed-script) can produce + surrogate characters that crash prompt_toolkit's file write. + See issue #2846. 
+ """ + + def store_string(self, string: str) -> None: + safe = string.encode("utf-8", errors="surrogateescape").decode("utf-8", errors="replace") + super().store_string(safe) from nanobot.cli.stream import StreamRenderer, ThinkingSpinner from nanobot.config.paths import get_workspace_path, is_default_workspace from nanobot.config.schema import Config @@ -73,6 +85,7 @@ def _flush_pending_tty_input() -> None: try: import termios + termios.tcflush(fd, termios.TCIFLUSH) return except Exception: @@ -95,6 +108,7 @@ def _restore_terminal() -> None: return try: import termios + termios.tcsetattr(sys.stdin.fileno(), termios.TCSADRAIN, _SAVED_TERM_ATTRS) except Exception: pass @@ -107,6 +121,7 @@ def _init_prompt_session() -> None: # Save terminal state so we can restore it on exit try: import termios + _SAVED_TERM_ATTRS = termios.tcgetattr(sys.stdin.fileno()) except Exception: pass @@ -117,9 +132,9 @@ def _init_prompt_session() -> None: history_file.parent.mkdir(parents=True, exist_ok=True) _PROMPT_SESSION = PromptSession( - history=FileHistory(str(history_file)), + history=SafeFileHistory(str(history_file)), enable_open_in_editor=False, - multiline=False, # Enter submits (single line mode) + multiline=False, # Enter submits (single line mode) ) @@ -231,7 +246,6 @@ async def _read_interactive_input_async() -> str: raise KeyboardInterrupt from exc - def version_callback(value: bool): if value: console.print(f"{__logo__} nanobot v{__version__}") @@ -281,8 +295,12 @@ def onboard( config = _apply_workspace_override(load_config(config_path)) else: console.print(f"[yellow]Config already exists at {config_path}[/yellow]") - console.print(" [bold]y[/bold] = overwrite with defaults (existing values will be lost)") - console.print(" [bold]N[/bold] = refresh config, keeping existing values and adding new fields") + console.print( + " [bold]y[/bold] = overwrite with defaults (existing values will be lost)" + ) + console.print( + " [bold]N[/bold] = refresh config, keeping existing 
values and adding new fields" + ) if typer.confirm("Overwrite?"): config = _apply_workspace_override(Config()) save_config(config, config_path) @@ -290,7 +308,9 @@ def onboard( else: config = _apply_workspace_override(load_config(config_path)) save_config(config, config_path) - console.print(f"[green]βœ“[/green] Config refreshed at {config_path} (existing values preserved)") + console.print( + f"[green]βœ“[/green] Config refreshed at {config_path} (existing values preserved)" + ) else: config = _apply_workspace_override(Config()) # In wizard mode, don't save yet - the wizard will handle saving if should_save=True @@ -340,7 +360,9 @@ def onboard( console.print(f" 1. Add your API key to [cyan]{config_path}[/cyan]") console.print(" Get one at: https://openrouter.ai/keys") console.print(f" 2. Chat: [cyan]{agent_cmd}[/cyan]") - console.print("\n[dim]Want Telegram/WhatsApp? See: https://github.com/HKUDS/nanobot#-chat-apps[/dim]") + console.print( + "\n[dim]Want Telegram/WhatsApp? See: https://github.com/HKUDS/nanobot#-chat-apps[/dim]" + ) def _merge_missing_defaults(existing: Any, defaults: Any) -> Any: @@ -413,9 +435,11 @@ def _make_provider(config: Config): # --- instantiation by backend --- if backend == "openai_codex": from nanobot.providers.openai_codex_provider import OpenAICodexProvider + provider = OpenAICodexProvider(default_model=model) elif backend == "azure_openai": from nanobot.providers.azure_openai_provider import AzureOpenAIProvider + provider = AzureOpenAIProvider( api_key=p.api_key, api_base=p.api_base, @@ -426,6 +450,7 @@ def _make_provider(config: Config): provider = GitHubCopilotProvider(default_model=model) elif backend == "anthropic": from nanobot.providers.anthropic_provider import AnthropicProvider + provider = AnthropicProvider( api_key=p.api_key if p else None, api_base=config.get_api_base(model), @@ -434,6 +459,7 @@ def _make_provider(config: Config): ) else: from nanobot.providers.openai_compat_provider import OpenAICompatProvider + provider 
= OpenAICompatProvider( api_key=p.api_key if p else None, api_base=config.get_api_base(model), @@ -453,7 +479,7 @@ def _make_provider(config: Config): def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config: """Load config and optionally override the active workspace.""" - from nanobot.config.loader import load_config, set_config_path + from nanobot.config.loader import load_config, resolve_config_env_vars, set_config_path config_path = None if config: @@ -464,7 +490,11 @@ def _load_runtime_config(config: str | None = None, workspace: str | None = None set_config_path(config_path) console.print(f"[dim]Using config: {config_path}[/dim]") - loaded = load_config(config_path) + try: + loaded = resolve_config_env_vars(load_config(config_path)) + except ValueError as e: + console.print(f"[red]Error: {e}[/red]") + raise typer.Exit(1) _warn_deprecated_config_keys(config_path) if workspace: loaded.agents.defaults.workspace = workspace @@ -474,6 +504,7 @@ def _load_runtime_config(config: str | None = None, workspace: str | None = None def _warn_deprecated_config_keys(config_path: Path | None) -> None: """Hint users to remove obsolete keys from their config file.""" import json + from nanobot.config.loader import get_config_path path = config_path or get_config_path() @@ -497,6 +528,7 @@ def _migrate_cron_store(config: "Config") -> None: if legacy_path.is_file() and not new_path.exists(): new_path.parent.mkdir(parents=True, exist_ok=True) import shutil + shutil.move(str(legacy_path), str(new_path)) @@ -610,6 +642,7 @@ def gateway( if verbose: import logging + logging.basicConfig(level=logging.DEBUG) config = _load_runtime_config(config, workspace) @@ -695,7 +728,7 @@ def gateway( if job.payload.deliver and job.payload.to and response: should_notify = await evaluate_response( - response, job.payload.message, provider, agent.model, + response, reminder_note, provider, agent.model, ) if should_notify: from nanobot.bus.events import 
OutboundMessage @@ -705,6 +738,7 @@ def gateway( content=response, )) return response + cron.on_job = on_cron_job # Create channel manager @@ -808,6 +842,7 @@ def gateway( console.print("\nShutting down...") except Exception: import traceback + console.print("\n[red]Error: Gateway crashed unexpectedly[/red]") console.print(traceback.format_exc()) finally: @@ -820,8 +855,6 @@ def gateway( asyncio.run(run()) - - # ============================================================================ # Agent Commands # ============================================================================ @@ -1296,6 +1329,7 @@ def _register_login(name: str): def decorator(fn): _LOGIN_HANDLERS[name] = fn return fn + return decorator @@ -1326,6 +1360,7 @@ def provider_login( def _login_openai_codex() -> None: try: from oauth_cli_kit import get_token, login_oauth_interactive + token = None try: token = get_token() diff --git a/nanobot/command/builtin.py b/nanobot/command/builtin.py index a5629f66e..94e46320b 100644 --- a/nanobot/command/builtin.py +++ b/nanobot/command/builtin.py @@ -60,6 +60,20 @@ async def cmd_status(ctx: CommandContext) -> OutboundMessage: pass if ctx_est <= 0: ctx_est = loop._last_usage.get("prompt_tokens", 0) + + # Fetch web search provider usage (best-effort, never blocks the response) + search_usage_text: str | None = None + try: + from nanobot.utils.searchusage import fetch_search_usage + web_cfg = getattr(loop, "web_config", None) + search_cfg = getattr(web_cfg, "search", None) if web_cfg else None + if search_cfg is not None: + provider = getattr(search_cfg, "provider", "duckduckgo") + api_key = getattr(search_cfg, "api_key", "") or None + usage = await fetch_search_usage(provider=provider, api_key=api_key) + search_usage_text = usage.format() + except Exception: + pass # Never let usage fetch break /status return OutboundMessage( channel=ctx.msg.channel, chat_id=ctx.msg.chat_id, @@ -69,6 +83,7 @@ async def cmd_status(ctx: CommandContext) -> OutboundMessage: 
context_window_tokens=loop.context_window_tokens, session_msg_count=len(session.get_history(max_messages=0)), context_tokens_estimate=ctx_est, + search_usage_text=search_usage_text, ), metadata={**dict(ctx.msg.metadata or {}), "render_as": "text"}, ) @@ -93,14 +108,30 @@ async def cmd_new(ctx: CommandContext) -> OutboundMessage: async def cmd_dream(ctx: CommandContext) -> OutboundMessage: """Manually trigger a Dream consolidation run.""" + import time + loop = ctx.loop - try: - did_work = await loop.dream.run() - content = "Dream completed." if did_work else "Dream: nothing to process." - except Exception as e: - content = f"Dream failed: {e}" + msg = ctx.msg + + async def _run_dream(): + t0 = time.monotonic() + try: + did_work = await loop.dream.run() + elapsed = time.monotonic() - t0 + if did_work: + content = f"Dream completed in {elapsed:.1f}s." + else: + content = "Dream: nothing to process." + except Exception as e: + elapsed = time.monotonic() - t0 + content = f"Dream failed after {elapsed:.1f}s: {e}" + await loop.bus.publish_outbound(OutboundMessage( + channel=msg.channel, chat_id=msg.chat_id, content=content, + )) + + asyncio.create_task(_run_dream()) return OutboundMessage( - channel=ctx.msg.channel, chat_id=ctx.msg.chat_id, content=content, + channel=msg.channel, chat_id=msg.chat_id, content="Dreaming...", ) diff --git a/nanobot/config/loader.py b/nanobot/config/loader.py index f5b2f33b8..618334c1c 100644 --- a/nanobot/config/loader.py +++ b/nanobot/config/loader.py @@ -1,6 +1,8 @@ """Configuration loading utilities.""" import json +import os +import re from pathlib import Path import pydantic @@ -76,6 +78,38 @@ def save_config(config: Config, config_path: Path | None = None) -> None: json.dump(data, f, indent=2, ensure_ascii=False) +def resolve_config_env_vars(config: Config) -> Config: + """Return a copy of *config* with ``${VAR}`` env-var references resolved. + + Only string values are affected; other types pass through unchanged. 
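The resolver added to `loader.py` above is easy to exercise standalone. A minimal sketch mirroring `_resolve_env_vars`/`_env_replace` (the config dict here is a plain-dict stand-in for the real `Config` model dump):

```python
import os
import re

_VAR = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")

def _env_replace(match: re.Match) -> str:
    name = match.group(1)
    value = os.environ.get(name)
    if value is None:
        raise ValueError(f"Environment variable '{name}' referenced in config is not set")
    return value

def resolve_env_vars(obj):
    """Recursively resolve ${VAR} references in string values only."""
    if isinstance(obj, str):
        return _VAR.sub(_env_replace, obj)
    if isinstance(obj, dict):
        return {k: resolve_env_vars(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [resolve_env_vars(v) for v in obj]
    return obj  # ints, bools, None pass through unchanged

os.environ["DEMO_KEY"] = "sk-demo"
cfg = {"providers": {"openrouter": {"apiKey": "${DEMO_KEY}", "maxResults": 5}}}
print(resolve_env_vars(cfg)["providers"]["openrouter"]["apiKey"])  # sk-demo
```

Because the walk operates on the JSON-mode model dump, non-string leaves (booleans, numbers) are never touched, and an unset variable fails loudly instead of silently writing `${VAR}` into the provider client.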
+ Raises :class:`ValueError` if a referenced variable is not set. + """ + data = config.model_dump(mode="json", by_alias=True) + data = _resolve_env_vars(data) + return Config.model_validate(data) + + +def _resolve_env_vars(obj: object) -> object: + """Recursively resolve ``${VAR}`` patterns in string values.""" + if isinstance(obj, str): + return re.sub(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}", _env_replace, obj) + if isinstance(obj, dict): + return {k: _resolve_env_vars(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_resolve_env_vars(v) for v in obj] + return obj + + +def _env_replace(match: re.Match[str]) -> str: + name = match.group(1) + value = os.environ.get(name) + if value is None: + raise ValueError( + f"Environment variable '{name}' referenced in config is not set" + ) + return value + + +def _migrate_config(data: dict) -> dict: """Migrate old config formats to current.""" # Move tools.exec.restrictToWorkspace → tools.restrictToWorkspace diff --git a/nanobot/config/schema.py index 0b5d6a817..f147434e7 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -28,6 +28,7 @@ class ChannelsConfig(Base): send_progress: bool = True # stream agent's text progress to the channel send_tool_hints: bool = False # stream tool-call hints (e.g.
read_file("…")) send_max_retries: int = Field(default=3, ge=0, le=10) # Max delivery attempts (initial send included) + transcription_provider: str = "groq" # Voice transcription backend: "groq" or "openai" class DreamConfig(Base): @@ -155,6 +156,7 @@ class WebSearchConfig(Base): api_key: str = "" base_url: str = "" # SearXNG base URL max_results: int = 5 + timeout: int = 30 # Wall-clock timeout (seconds) for search operations class WebToolsConfig(Base): @@ -173,6 +175,7 @@ class ExecToolConfig(Base): enable: bool = True timeout: int = 60 path_append: str = "" + sandbox: str = "" # sandbox backend: "" (none) or "bwrap" class MCPServerConfig(Base): """MCP server connection configuration (stdio or HTTP).""" @@ -191,7 +194,7 @@ class ToolsConfig(Base): web: WebToolsConfig = Field(default_factory=WebToolsConfig) exec: ExecToolConfig = Field(default_factory=ExecToolConfig) - restrict_to_workspace: bool = False # If true, restrict all tool access to workspace directory + restrict_to_workspace: bool = False # restrict all tool access to workspace directory mcp_servers: dict[str, MCPServerConfig] = Field(default_factory=dict) ssrf_whitelist: list[str] = Field(default_factory=list) # CIDR ranges to exempt from SSRF blocking (e.g. ["100.64.0.0/10"] for Tailscale) diff --git a/nanobot/nanobot.py b/nanobot/nanobot.py index 4860fa312..85e9e1ddb 100644 --- a/nanobot/nanobot.py +++ b/nanobot/nanobot.py @@ -47,7 +47,7 @@ class Nanobot: ``~/.nanobot/config.json``. workspace: Override the workspace directory from config. 
""" - from nanobot.config.loader import load_config + from nanobot.config.loader import load_config, resolve_config_env_vars from nanobot.config.schema import Config resolved: Path | None = None @@ -56,7 +56,7 @@ class Nanobot: if not resolved.exists(): raise FileNotFoundError(f"Config not found: {resolved}") - config: Config = load_config(resolved) + config: Config = resolve_config_env_vars(load_config(resolved)) if workspace is not None: config.agents.defaults.workspace = str( Path(workspace).expanduser().resolve() diff --git a/nanobot/providers/anthropic_provider.py b/nanobot/providers/anthropic_provider.py index 1cade5fb5..e389b51ed 100644 --- a/nanobot/providers/anthropic_provider.py +++ b/nanobot/providers/anthropic_provider.py @@ -52,6 +52,62 @@ class AnthropicProvider(LLMProvider): client_kw["max_retries"] = 0 self._client = AsyncAnthropic(**client_kw) + @classmethod + def _handle_error(cls, e: Exception) -> LLMResponse: + response = getattr(e, "response", None) + headers = getattr(response, "headers", None) + payload = ( + getattr(e, "body", None) + or getattr(e, "doc", None) + or getattr(response, "text", None) + ) + if payload is None and response is not None: + response_json = getattr(response, "json", None) + if callable(response_json): + try: + payload = response_json() + except Exception: + payload = None + payload_text = payload if isinstance(payload, str) else str(payload) if payload is not None else "" + msg = f"Error: {payload_text.strip()[:500]}" if payload_text.strip() else f"Error calling LLM: {e}" + retry_after = cls._extract_retry_after_from_headers(headers) + if retry_after is None: + retry_after = LLMProvider._extract_retry_after(msg) + + status_code = getattr(e, "status_code", None) + if status_code is None and response is not None: + status_code = getattr(response, "status_code", None) + + should_retry: bool | None = None + if headers is not None: + raw = headers.get("x-should-retry") + if isinstance(raw, str): + lowered = 
raw.strip().lower() + if lowered == "true": + should_retry = True + elif lowered == "false": + should_retry = False + + error_kind: str | None = None + error_name = e.__class__.__name__.lower() + if "timeout" in error_name: + error_kind = "timeout" + elif "connection" in error_name: + error_kind = "connection" + error_type, error_code = LLMProvider._extract_error_type_code(payload) + + return LLMResponse( + content=msg, + finish_reason="error", + retry_after=retry_after, + error_status_code=int(status_code) if status_code is not None else None, + error_kind=error_kind, + error_type=error_type, + error_code=error_code, + error_retry_after_s=retry_after, + error_should_retry=should_retry, + ) + @staticmethod def _strip_prefix(model: str) -> str: if model.startswith("anthropic/"): @@ -404,15 +460,6 @@ class AnthropicProvider(LLMProvider): # Public API # ------------------------------------------------------------------ - @staticmethod - def _handle_error(e: Exception) -> LLMResponse: - msg = f"Error calling LLM: {e}" - response = getattr(e, "response", None) - retry_after = LLMProvider._extract_retry_after_from_headers(getattr(response, "headers", None)) - if retry_after is None: - retry_after = LLMProvider._extract_retry_after(msg) - return LLMResponse(content=msg, finish_reason="error", retry_after=retry_after) - async def chat( self, messages: list[dict[str, Any]], @@ -474,6 +521,7 @@ class AnthropicProvider(LLMProvider): f"{idle_timeout_s} seconds" ), finish_reason="error", + error_kind="timeout", ) except Exception as e: return self._handle_error(e) diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index 118eb80ca..d5833c9ae 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -54,6 +54,13 @@ class LLMResponse: retry_after: float | None = None # Provider supplied retry wait in seconds. reasoning_content: str | None = None # Kimi, DeepSeek-R1, MiMo etc. 
thinking_blocks: list[dict] | None = None # Anthropic extended thinking + # Structured error metadata used by retry policy when finish_reason == "error". + error_status_code: int | None = None + error_kind: str | None = None # e.g. "timeout", "connection" + error_type: str | None = None # Provider/type semantic, e.g. insufficient_quota. + error_code: str | None = None # Provider/code semantic, e.g. rate_limit_exceeded. + error_retry_after_s: float | None = None + error_should_retry: bool | None = None @property def has_tool_calls(self) -> bool: @@ -91,6 +98,52 @@ class LLMProvider(ABC): "server error", "temporarily unavailable", ) + _RETRYABLE_STATUS_CODES = frozenset({408, 409, 429}) + _TRANSIENT_ERROR_KINDS = frozenset({"timeout", "connection"}) + _NON_RETRYABLE_429_ERROR_TOKENS = frozenset({ + "insufficient_quota", + "quota_exceeded", + "quota_exhausted", + "billing_hard_limit_reached", + "insufficient_balance", + "credit_balance_too_low", + "billing_not_active", + "payment_required", + }) + _RETRYABLE_429_ERROR_TOKENS = frozenset({ + "rate_limit_exceeded", + "rate_limit_error", + "too_many_requests", + "request_limit_exceeded", + "requests_limit_exceeded", + "overloaded_error", + }) + _NON_RETRYABLE_429_TEXT_MARKERS = ( + "insufficient_quota", + "insufficient quota", + "quota exceeded", + "quota exhausted", + "billing hard limit", + "billing_hard_limit_reached", + "billing not active", + "insufficient balance", + "insufficient_balance", + "credit balance too low", + "payment required", + "out of credits", + "out of quota", + "exceeded your current quota", + ) + _RETRYABLE_429_TEXT_MARKERS = ( + "rate limit", + "rate_limit", + "too many requests", + "retry after", + "try again in", + "temporarily unavailable", + "overloaded", + "concurrency limit", + ) _SENTINEL = object() @@ -226,6 +279,80 @@ class LLMProvider(ABC): err = (content or "").lower() return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS) + @classmethod + def _is_transient_response(cls, 
response: LLMResponse) -> bool: + """Prefer structured error metadata, fallback to text markers for legacy providers.""" + if response.error_should_retry is not None: + return bool(response.error_should_retry) + + if response.error_status_code is not None: + status = int(response.error_status_code) + if status == 429: + return cls._is_retryable_429_response(response) + if status in cls._RETRYABLE_STATUS_CODES or status >= 500: + return True + + kind = (response.error_kind or "").strip().lower() + if kind in cls._TRANSIENT_ERROR_KINDS: + return True + + return cls._is_transient_error(response.content) + + @staticmethod + def _normalize_error_token(value: Any) -> str | None: + if value is None: + return None + token = str(value).strip().lower() + return token or None + + @classmethod + def _extract_error_type_code(cls, payload: Any) -> tuple[str | None, str | None]: + data: dict[str, Any] | None = None + if isinstance(payload, dict): + data = payload + elif isinstance(payload, str): + text = payload.strip() + if text: + try: + parsed = json.loads(text) + except Exception: + parsed = None + if isinstance(parsed, dict): + data = parsed + if not isinstance(data, dict): + return None, None + + error_obj = data.get("error") + type_value = data.get("type") + code_value = data.get("code") + if isinstance(error_obj, dict): + type_value = error_obj.get("type") or type_value + code_value = error_obj.get("code") or code_value + + return cls._normalize_error_token(type_value), cls._normalize_error_token(code_value) + + @classmethod + def _is_retryable_429_response(cls, response: LLMResponse) -> bool: + type_token = cls._normalize_error_token(response.error_type) + code_token = cls._normalize_error_token(response.error_code) + semantic_tokens = { + token for token in (type_token, code_token) + if token is not None + } + if any(token in cls._NON_RETRYABLE_429_ERROR_TOKENS for token in semantic_tokens): + return False + + content = (response.content or "").lower() + if any(marker 
in content for marker in cls._NON_RETRYABLE_429_TEXT_MARKERS): + return False + + if any(token in cls._RETRYABLE_429_ERROR_TOKENS for token in semantic_tokens): + return True + if any(marker in content for marker in cls._RETRYABLE_429_TEXT_MARKERS): + return True + # Unknown 429 defaults to WAIT+retry. + return True + @staticmethod def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None: """Replace image_url blocks with text placeholder. Returns None if no images found.""" @@ -397,14 +524,28 @@ class LLMProvider(ABC): def _extract_retry_after_from_headers(cls, headers: Any) -> float | None: if not headers: return None - retry_after: Any = None - if hasattr(headers, "get"): - retry_after = headers.get("retry-after") or headers.get("Retry-After") - if retry_after is None and isinstance(headers, dict): - for key, value in headers.items(): - if isinstance(key, str) and key.lower() == "retry-after": - retry_after = value - break + + def _header_value(name: str) -> Any: + if hasattr(headers, "get"): + value = headers.get(name) or headers.get(name.title()) + if value is not None: + return value + if isinstance(headers, dict): + for key, value in headers.items(): + if isinstance(key, str) and key.lower() == name.lower(): + return value + return None + + try: + retry_ms = _header_value("retry-after-ms") + if retry_ms is not None: + value = float(retry_ms) / 1000.0 + if value > 0: + return value + except (TypeError, ValueError): + pass + + retry_after = _header_value("retry-after") if retry_after is None: return None retry_after_text = str(retry_after).strip() @@ -421,6 +562,14 @@ class LLMProvider(ABC): remaining = (retry_at - datetime.now(retry_at.tzinfo)).total_seconds() return max(0.1, remaining) + @classmethod + def _extract_retry_after_from_response(cls, response: LLMResponse) -> float | None: + if response.error_retry_after_s is not None and response.error_retry_after_s > 0: + return response.error_retry_after_s + if 
response.retry_after is not None and response.retry_after > 0: + return response.retry_after + return cls._extract_retry_after(response.content) + async def _sleep_with_heartbeat( self, delay: float, @@ -469,7 +618,7 @@ class LLMProvider(ABC): last_error_key = error_key identical_error_count = 1 if error_key else 0 - if not self._is_transient_error(response.content): + if not self._is_transient_response(response): stripped = self._strip_image_content(original_messages) if stripped is not None and stripped != kw["messages"]: logger.warning( @@ -492,7 +641,7 @@ class LLMProvider(ABC): break base_delay = delays[min(attempt - 1, len(delays) - 1)] - delay = response.retry_after or self._extract_retry_after(response.content) or base_delay + delay = self._extract_retry_after_from_response(response) or base_delay if persistent: delay = min(delay, self._PERSISTENT_MAX_DELAY) diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index a216e9046..706268585 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio import hashlib +import importlib.util import os import secrets import string @@ -12,7 +13,17 @@ from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any import json_repair -from openai import AsyncOpenAI + +if os.environ.get("LANGFUSE_SECRET_KEY") and importlib.util.find_spec("langfuse"): + from langfuse.openai import AsyncOpenAI +else: + if os.environ.get("LANGFUSE_SECRET_KEY"): + import logging + logging.getLogger(__name__).warning( + "LANGFUSE_SECRET_KEY is set but langfuse is not installed; " + "install with `pip install langfuse` to enable tracing" + ) + from openai import AsyncOpenAI from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest @@ -286,6 +297,24 @@ class OpenAICompatProvider(LLMProvider): if reasoning_effort: kwargs["reasoning_effort"] = 
reasoning_effort + # Provider-specific thinking parameters. + # Only sent when reasoning_effort is explicitly configured so that + # the provider default is preserved otherwise. + if spec and reasoning_effort is not None: + thinking_enabled = reasoning_effort.lower() != "minimal" + extra: dict[str, Any] | None = None + if spec.name == "dashscope": + extra = {"enable_thinking": thinking_enabled} + elif spec.name in ( + "volcengine", "volcengine_coding_plan", + "byteplus", "byteplus_coding_plan", + ): + extra = { + "thinking": {"type": "enabled" if thinking_enabled else "disabled"} + } + if extra: + kwargs.setdefault("extra_body", {}).update(extra) + if tools: kwargs["tools"] = tools kwargs["tool_choice"] = tool_choice or "auto" @@ -605,16 +634,73 @@ class OpenAICompatProvider(LLMProvider): reasoning_content="".join(reasoning_parts) or None, ) + @classmethod + def _extract_error_metadata(cls, e: Exception) -> dict[str, Any]: + response = getattr(e, "response", None) + headers = getattr(response, "headers", None) + payload = ( + getattr(e, "body", None) + or getattr(e, "doc", None) + or getattr(response, "text", None) + ) + if payload is None and response is not None: + response_json = getattr(response, "json", None) + if callable(response_json): + try: + payload = response_json() + except Exception: + payload = None + error_type, error_code = LLMProvider._extract_error_type_code(payload) + + status_code = getattr(e, "status_code", None) + if status_code is None and response is not None: + status_code = getattr(response, "status_code", None) + + should_retry: bool | None = None + if headers is not None: + raw = headers.get("x-should-retry") + if isinstance(raw, str): + lowered = raw.strip().lower() + if lowered == "true": + should_retry = True + elif lowered == "false": + should_retry = False + + error_kind: str | None = None + error_name = e.__class__.__name__.lower() + if "timeout" in error_name: + error_kind = "timeout" + elif "connection" in error_name: + 
error_kind = "connection" + + return { + "error_status_code": int(status_code) if status_code is not None else None, + "error_kind": error_kind, + "error_type": error_type, + "error_code": error_code, + "error_retry_after_s": cls._extract_retry_after_from_headers(headers), + "error_should_retry": should_retry, + } + @staticmethod def _handle_error(e: Exception) -> LLMResponse: + body = ( + getattr(e, "doc", None) + or getattr(e, "body", None) + or getattr(getattr(e, "response", None), "text", None) + ) + body_text = body if isinstance(body, str) else str(body) if body is not None else "" + msg = f"Error: {body_text.strip()[:500]}" if body_text.strip() else f"Error calling LLM: {e}" response = getattr(e, "response", None) - body = getattr(e, "doc", None) or getattr(response, "text", None) - body_text = str(body).strip() if body is not None else "" - msg = f"Error: {body_text[:500]}" if body_text else f"Error calling LLM: {e}" retry_after = LLMProvider._extract_retry_after_from_headers(getattr(response, "headers", None)) if retry_after is None: retry_after = LLMProvider._extract_retry_after(msg) - return LLMResponse(content=msg, finish_reason="error", retry_after=retry_after) + return LLMResponse( + content=msg, + finish_reason="error", + retry_after=retry_after, + **OpenAICompatProvider._extract_error_metadata(e), + ) # ------------------------------------------------------------------ # Public API @@ -682,6 +768,7 @@ class OpenAICompatProvider(LLMProvider): f"{idle_timeout_s} seconds" ), finish_reason="error", + error_kind="timeout", ) except Exception as e: return self._handle_error(e) diff --git a/nanobot/providers/transcription.py b/nanobot/providers/transcription.py index 1c8cb6a3f..aca9693ee 100644 --- a/nanobot/providers/transcription.py +++ b/nanobot/providers/transcription.py @@ -1,4 +1,4 @@ -"""Voice transcription provider using Groq.""" +"""Voice transcription providers (Groq and OpenAI Whisper).""" import os from pathlib import Path @@ -7,6 +7,36 @@ 
import httpx from loguru import logger +class OpenAITranscriptionProvider: + """Voice transcription provider using OpenAI's Whisper API.""" + + def __init__(self, api_key: str | None = None): + self.api_key = api_key or os.environ.get("OPENAI_API_KEY") + self.api_url = "https://api.openai.com/v1/audio/transcriptions" + + async def transcribe(self, file_path: str | Path) -> str: + if not self.api_key: + logger.warning("OpenAI API key not configured for transcription") + return "" + path = Path(file_path) + if not path.exists(): + logger.error("Audio file not found: {}", file_path) + return "" + try: + async with httpx.AsyncClient() as client: + with open(path, "rb") as f: + files = {"file": (path.name, f), "model": (None, "whisper-1")} + headers = {"Authorization": f"Bearer {self.api_key}"} + response = await client.post( + self.api_url, headers=headers, files=files, timeout=60.0, + ) + response.raise_for_status() + return response.json().get("text", "") + except Exception as e: + logger.error("OpenAI transcription error: {}", e) + return "" + + class GroqTranscriptionProvider: """ Voice transcription provider using Groq's Whisper API. diff --git a/nanobot/templates/agent/evaluator.md b/nanobot/templates/agent/evaluator.md index 305e4f8d0..51cf7a4e4 100644 --- a/nanobot/templates/agent/evaluator.md +++ b/nanobot/templates/agent/evaluator.md @@ -1,7 +1,9 @@ {% if part == 'system' %} You are a notification gate for a background agent. You will be given the original task and the agent's response. Call the evaluate_notification tool to decide whether the user should be notified. -Notify when the response contains actionable information, errors, completed deliverables, or anything the user explicitly asked to be reminded about. +Notify when the response contains actionable information, errors, completed deliverables, scheduled reminder/timer completions, or anything the user explicitly asked to be reminded about. 
+ +A user-scheduled reminder should usually notify even when the response is brief or mostly repeats the original reminder. Suppress when the response is a routine status check with nothing new, a confirmation that everything is normal, or essentially empty. {% elif part == 'user' %} diff --git a/nanobot/utils/__init__.py b/nanobot/utils/__init__.py index 46f02acbd..9ad157c2e 100644 --- a/nanobot/utils/__init__.py +++ b/nanobot/utils/__init__.py @@ -1,5 +1,6 @@ """Utility functions for nanobot.""" from nanobot.utils.helpers import ensure_dir +from nanobot.utils.path import abbreviate_path -__all__ = ["ensure_dir"] +__all__ = ["ensure_dir", "abbreviate_path"] diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index 93293c9e0..7267bac2a 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ -396,8 +396,15 @@ def build_status_content( context_window_tokens: int, session_msg_count: int, context_tokens_estimate: int, + search_usage_text: str | None = None, ) -> str: - """Build a human-readable runtime status snapshot.""" + """Build a human-readable runtime status snapshot. + + Args: + search_usage_text: Optional pre-formatted web search usage string + (produced by SearchUsageInfo.format()). When provided + it is appended as an extra section. 
+ """ uptime_s = int(time.time() - start_time) uptime = ( f"{uptime_s // 3600}h {(uptime_s % 3600) // 60}m" @@ -414,14 +421,17 @@ def build_status_content( token_line = f"\U0001f4ca Tokens: {last_in} in / {last_out} out" if cached and last_in: token_line += f" ({cached * 100 // last_in}% cached)" - return "\n".join([ + lines = [ f"\U0001f408 nanobot v{version}", f"\U0001f9e0 Model: {model}", token_line, f"\U0001f4da Context: {ctx_used_str}/{ctx_total_str} ({ctx_pct}%)", f"\U0001f4ac Session: {session_msg_count} messages", f"\u23f1 Uptime: {uptime}", - ]) + ] + if search_usage_text: + lines.append(search_usage_text) + return "\n".join(lines) def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]: diff --git a/nanobot/utils/path.py b/nanobot/utils/path.py new file mode 100644 index 000000000..32591a471 --- /dev/null +++ b/nanobot/utils/path.py @@ -0,0 +1,107 @@ +"""Path abbreviation utilities for display.""" + +from __future__ import annotations + +import os +import re +from urllib.parse import urlparse + + +def abbreviate_path(path: str, max_len: int = 40) -> str: + """Abbreviate a file path or URL, preserving basename and key directories. + + Strategy: + 1. Return as-is if short enough + 2. Replace home directory with ~/ + 3. From right, keep basename + parent dirs until budget exhausted + 4. 
Prefix with …/ + """ + if not path: + return path + + # Handle URLs: preserve scheme://domain + filename + if re.match(r"https?://", path): + return _abbreviate_url(path, max_len) + + # Normalize separators to / + normalized = path.replace("\\", "/") + + # Replace home directory + home = os.path.expanduser("~").replace("\\", "/") + if normalized.startswith(home + "/"): + normalized = "~" + normalized[len(home):] + elif normalized == home: + normalized = "~" + + # Return early only after normalization and home replacement + if len(normalized) <= max_len: + return normalized + + # Split into segments + parts = normalized.rstrip("/").split("/") + if len(parts) <= 1: + return normalized[:max_len - 1] + "\u2026" + + # Always keep the basename + basename = parts[-1] + # Budget: max_len minus "…/" prefix (2 chars) minus "/" separator minus basename + budget = max_len - len(basename) - 3 # -3 for "…/" + final "/" + + # Walk backwards from parent, collecting segments + kept: list[str] = [] + for seg in reversed(parts[:-1]): + needed = len(seg) + 1 # segment + "/" + if not kept and needed <= budget: + kept.append(seg) + budget -= needed + elif kept: + needed_with_sep = len(seg) + 1 + if needed_with_sep <= budget: + kept.append(seg) + budget -= needed_with_sep + else: + break + else: + break + + kept.reverse() + if kept: + return "\u2026/" + "/".join(kept) + "/" + basename + return "\u2026/" + basename + + +def _abbreviate_url(url: str, max_len: int = 40) -> str: + """Abbreviate a URL keeping domain and filename.""" + if len(url) <= max_len: + return url + + parsed = urlparse(url) + domain = parsed.netloc # e.g. "example.com" + path_part = parsed.path # e.g. 
"/api/v2/resource.json" + + # Extract filename from path + segments = path_part.rstrip("/").split("/") + basename = segments[-1] if segments else "" + + if not basename: + # No filename, truncate URL + return url[: max_len - 1] + "\u2026" + + budget = max_len - len(domain) - len(basename) - 4 # "…/" + "/" + if budget < 0: + trunc = max_len - len(domain) - 5 # "…/" + "/" + return domain + "/\u2026/" + (basename[:trunc] if trunc > 0 else "") + + # Build abbreviated path + kept: list[str] = [] + for seg in reversed(segments[:-1]): + if len(seg) + 1 <= budget: + kept.append(seg) + budget -= len(seg) + 1 + else: + break + + kept.reverse() + if kept: + return domain + "/\u2026/" + "/".join(kept) + "/" + basename + return domain + "/\u2026/" + basename diff --git a/nanobot/utils/runtime.py b/nanobot/utils/runtime.py index 7164629c5..25d456955 100644 --- a/nanobot/utils/runtime.py +++ b/nanobot/utils/runtime.py @@ -16,8 +16,7 @@ EMPTY_FINAL_RESPONSE_MESSAGE = ( ) FINALIZATION_RETRY_PROMPT = ( - "You have already finished the tool work. Do not call any more tools. " - "Using only the conversation and tool results above, provide the final answer for the user now." + "Please provide your response to the user based on the conversation above." 
)
diff --git a/nanobot/utils/searchusage.py b/nanobot/utils/searchusage.py
new file mode 100644
index 000000000..ac490aadd
--- /dev/null
+++ b/nanobot/utils/searchusage.py
@@ -0,0 +1,168 @@
+"""Web search provider usage fetchers for /status command."""
+
+from __future__ import annotations
+
+import os
+from dataclasses import dataclass
+from typing import Any
+
+
+@dataclass
+class SearchUsageInfo:
+    """Structured usage info returned by a provider fetcher."""
+
+    provider: str
+    supported: bool = False  # True if the provider has a usage API
+    error: str | None = None  # Set when the API call failed
+
+    # Usage counters (None = not available for this provider)
+    used: int | None = None
+    limit: int | None = None
+    remaining: int | None = None
+    reset_date: str | None = None  # ISO date string, e.g. "2026-05-01"
+
+    # Tavily-specific breakdown
+    search_used: int | None = None
+    extract_used: int | None = None
+    crawl_used: int | None = None
+
+    def format(self) -> str:
+        """Return a human-readable multi-line string for /status output."""
+        lines = [f"🔍 Web Search: {self.provider}"]
+
+        if not self.supported:
+            lines.append(" Usage tracking: not available for this provider")
+            return "\n".join(lines)
+
+        if self.error:
+            lines.append(f" Usage: unavailable ({self.error})")
+            return "\n".join(lines)
+
+        if self.used is not None and self.limit is not None:
+            lines.append(f" Usage: {self.used} / {self.limit} requests")
+        elif self.used is not None:
+            lines.append(f" Usage: {self.used} requests")
+
+        # Tavily breakdown
+        breakdown_parts = []
+        if self.search_used is not None:
+            breakdown_parts.append(f"Search: {self.search_used}")
+        if self.extract_used is not None:
+            breakdown_parts.append(f"Extract: {self.extract_used}")
+        if self.crawl_used is not None:
+            breakdown_parts.append(f"Crawl: {self.crawl_used}")
+        if breakdown_parts:
+            lines.append(f" Breakdown: {' | '.join(breakdown_parts)}")
+
+        if self.remaining is not None:
+            lines.append(f" Remaining: {self.remaining}
requests") + + if self.reset_date: + lines.append(f" Resets: {self.reset_date}") + + return "\n".join(lines) + + +async def fetch_search_usage( + provider: str, + api_key: str | None = None, +) -> SearchUsageInfo: + """ + Fetch usage info for the configured web search provider. + + Args: + provider: Provider name (e.g. "tavily", "brave", "duckduckgo"). + api_key: API key for the provider (falls back to env vars). + + Returns: + SearchUsageInfo with populated fields where available. + """ + p = (provider or "duckduckgo").strip().lower() + + if p == "tavily": + return await _fetch_tavily_usage(api_key) + else: + # brave, duckduckgo, searxng, jina, unknown β€” no usage API + return SearchUsageInfo(provider=p, supported=False) + + +# --------------------------------------------------------------------------- +# Tavily +# --------------------------------------------------------------------------- + +async def _fetch_tavily_usage(api_key: str | None) -> SearchUsageInfo: + """Fetch usage from GET https://api.tavily.com/usage.""" + import httpx + + key = api_key or os.environ.get("TAVILY_API_KEY", "") + if not key: + return SearchUsageInfo( + provider="tavily", + supported=True, + error="TAVILY_API_KEY not configured", + ) + + try: + async with httpx.AsyncClient(timeout=8.0) as client: + r = await client.get( + "https://api.tavily.com/usage", + headers={"Authorization": f"Bearer {key}"}, + ) + r.raise_for_status() + data: dict[str, Any] = r.json() + return _parse_tavily_usage(data) + except httpx.HTTPStatusError as e: + return SearchUsageInfo( + provider="tavily", + supported=True, + error=f"HTTP {e.response.status_code}", + ) + except Exception as e: + return SearchUsageInfo( + provider="tavily", + supported=True, + error=str(e)[:80], + ) + + +def _parse_tavily_usage(data: dict[str, Any]) -> SearchUsageInfo: + """ + Parse Tavily /usage response. 
+ + Actual API response shape: + { + "account": { + "current_plan": "Researcher", + "plan_usage": 20, + "plan_limit": 1000, + "search_usage": 20, + "crawl_usage": 0, + "extract_usage": 0, + "map_usage": 0, + "research_usage": 0, + "paygo_usage": 0, + "paygo_limit": null + } + } + """ + account = data.get("account") or {} + used = account.get("plan_usage") + limit = account.get("plan_limit") + + # Compute remaining + remaining = None + if used is not None and limit is not None: + remaining = max(0, limit - used) + + return SearchUsageInfo( + provider="tavily", + supported=True, + used=used, + limit=limit, + remaining=remaining, + search_used=account.get("search_usage"), + extract_used=account.get("extract_usage"), + crawl_used=account.get("crawl_usage"), + ) + + diff --git a/nanobot/utils/tool_hints.py b/nanobot/utils/tool_hints.py new file mode 100644 index 000000000..a907a2700 --- /dev/null +++ b/nanobot/utils/tool_hints.py @@ -0,0 +1,119 @@ +"""Tool hint formatting for concise, human-readable tool call display.""" + +from __future__ import annotations + +from nanobot.utils.path import abbreviate_path + +# Registry: tool_name -> (key_args, template, is_path, is_command) +_TOOL_FORMATS: dict[str, tuple[list[str], str, bool, bool]] = { + "read_file": (["path", "file_path"], "read {}", True, False), + "write_file": (["path", "file_path"], "write {}", True, False), + "edit": (["file_path", "path"], "edit {}", True, False), + "glob": (["pattern"], 'glob "{}"', False, False), + "grep": (["pattern"], 'grep "{}"', False, False), + "exec": (["command"], "$ {}", False, True), + "web_search": (["query"], 'search "{}"', False, False), + "web_fetch": (["url"], "fetch {}", True, False), + "list_dir": (["path"], "ls {}", True, False), +} + + +def format_tool_hints(tool_calls: list) -> str: + """Format tool calls as concise hints with smart abbreviation.""" + if not tool_calls: + return "" + + hints = [] + for name, count, example_tc in _group_consecutive(tool_calls): + fmt = 
_TOOL_FORMATS.get(name) + if fmt: + hint = _fmt_known(example_tc, fmt) + elif name.startswith("mcp_"): + hint = _fmt_mcp(example_tc) + else: + hint = _fmt_fallback(example_tc) + + if count > 1: + hint = f"{hint} \u00d7 {count}" + hints.append(hint) + + return ", ".join(hints) + + +def _get_args(tc) -> dict: + """Extract args dict from tc.arguments, handling list/dict/None/empty.""" + if tc.arguments is None: + return {} + if isinstance(tc.arguments, list): + return tc.arguments[0] if tc.arguments else {} + if isinstance(tc.arguments, dict): + return tc.arguments + return {} + + +def _group_consecutive(calls: list) -> list[tuple[str, int, object]]: + """Group consecutive calls to the same tool: [(name, count, first), ...].""" + groups: list[tuple[str, int, object]] = [] + for tc in calls: + if groups and groups[-1][0] == tc.name: + groups[-1] = (groups[-1][0], groups[-1][1] + 1, groups[-1][2]) + else: + groups.append((tc.name, 1, tc)) + return groups + + +def _extract_arg(tc, key_args: list[str]) -> str | None: + """Extract the first available value from preferred key names.""" + args = _get_args(tc) + if not isinstance(args, dict): + return None + for key in key_args: + val = args.get(key) + if isinstance(val, str) and val: + return val + for val in args.values(): + if isinstance(val, str) and val: + return val + return None + + +def _fmt_known(tc, fmt: tuple) -> str: + """Format a registered tool using its template.""" + val = _extract_arg(tc, fmt[0]) + if val is None: + return tc.name + if fmt[2]: # is_path + val = abbreviate_path(val) + elif fmt[3]: # is_command + val = val[:40] + "\u2026" if len(val) > 40 else val + return fmt[1].format(val) + + +def _fmt_mcp(tc) -> str: + """Format MCP tool as server::tool.""" + name = tc.name + if "__" in name: + parts = name.split("__", 1) + server = parts[0].removeprefix("mcp_") + tool = parts[1] + else: + rest = name.removeprefix("mcp_") + parts = rest.split("_", 1) + server = parts[0] if parts else rest + tool = parts[1] 
if len(parts) > 1 else "" + if not tool: + return name + args = _get_args(tc) + val = next((v for v in args.values() if isinstance(v, str) and v), None) + if val is None: + return f"{server}::{tool}" + return f'{server}::{tool}("{abbreviate_path(val, 40)}")' + + +def _fmt_fallback(tc) -> str: + """Original formatting logic for unregistered tools.""" + args = _get_args(tc) + val = next(iter(args.values()), None) if isinstance(args, dict) else None + if not isinstance(val, str): + return tc.name + return f'{tc.name}("{abbreviate_path(val, 40)}")' if len(val) > 40 else f'{tc.name}("{val}")' diff --git a/pyproject.toml b/pyproject.toml index e9aa6198d..ab40d3f90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "nanobot-ai" -version = "0.1.4.post6" +version = "0.1.5" description = "A lightweight personal AI assistant framework" readme = { file = "README.md", content-type = "text/markdown" } requires-python = ">=3.11" diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index dcdd15031..a700f495b 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -458,6 +458,7 @@ async def test_runner_uses_raw_messages_when_context_governance_fails(): @pytest.mark.asyncio async def test_runner_retries_empty_final_response_with_summary_prompt(): + """Empty responses get 2 silent retries before finalization kicks in.""" from nanobot.agent.runner import AgentRunSpec, AgentRunner provider = MagicMock() @@ -465,11 +466,11 @@ async def test_runner_retries_empty_final_response_with_summary_prompt(): async def chat_with_retry(*, messages, tools=None, **kwargs): calls.append({"messages": messages, "tools": tools}) - if len(calls) == 1: + if len(calls) <= 2: return LLMResponse( content=None, tool_calls=[], - usage={"prompt_tokens": 10, "completion_tokens": 1}, + usage={"prompt_tokens": 5, "completion_tokens": 1}, ) return LLMResponse( content="final answer", @@ -486,20 +487,23 @@ async def 
test_runner_retries_empty_final_response_with_summary_prompt():
         initial_messages=[{"role": "user", "content": "do task"}],
         tools=tools,
         model="test-model",
-        max_iterations=1,
+        max_iterations=3,
         max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
     ))
 
     assert result.final_content == "final answer"
-    assert len(calls) == 2
-    assert calls[1]["tools"] is None
-    assert "Do not call any more tools" in calls[1]["messages"][-1]["content"]
+    # Two silent retries (calls 1 and 2), then finalization with tools disabled on call 3
+    assert len(calls) == 3
+    assert calls[0]["tools"] is not None
+    assert calls[1]["tools"] is not None
+    assert calls[2]["tools"] is None
     assert result.usage["prompt_tokens"] == 13
-    assert result.usage["completion_tokens"] == 8
+    assert result.usage["completion_tokens"] == 9
 
 
 @pytest.mark.asyncio
 async def test_runner_uses_specific_message_after_empty_finalization_retry():
+    """When the silent retries and the finalization attempt all return empty, stop_reason is empty_final_response."""
     from nanobot.agent.runner import AgentRunSpec, AgentRunner
     from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
@@ -517,7 +521,7 @@ async def test_runner_uses_specific_message_after_empty_finalization_retry():
         initial_messages=[{"role": "user", "content": "do task"}],
         tools=tools,
         model="test-model",
-        max_iterations=1,
+        max_iterations=3,
         max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
     ))
 
@@ -525,6 +529,66 @@
     assert result.stop_reason == "empty_final_response"
 
 
+@pytest.mark.asyncio
+async def test_runner_empty_response_does_not_break_tool_chain():
+    """An empty intermediate response must not kill an ongoing tool chain.
+
+    Sequence: tool_call → empty → tool_call → final text.
+    The runner should recover via silent retry and complete normally.
+ """ + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + call_count = 0 + + async def chat_with_retry(*, messages, tools=None, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return LLMResponse( + content=None, + tool_calls=[ToolCallRequest(id="tc1", name="read_file", arguments={"path": "a.txt"})], + usage={"prompt_tokens": 10, "completion_tokens": 5}, + ) + if call_count == 2: + return LLMResponse(content=None, tool_calls=[], usage={"prompt_tokens": 10, "completion_tokens": 1}) + if call_count == 3: + return LLMResponse( + content=None, + tool_calls=[ToolCallRequest(id="tc2", name="read_file", arguments={"path": "b.txt"})], + usage={"prompt_tokens": 10, "completion_tokens": 5}, + ) + return LLMResponse( + content="Here are the results.", + tool_calls=[], + usage={"prompt_tokens": 10, "completion_tokens": 10}, + ) + + provider.chat_with_retry = chat_with_retry + provider.chat_stream_with_retry = chat_with_retry + + async def fake_tool(name, args, **kw): + return "file content" + + tool_registry = MagicMock() + tool_registry.get_definitions.return_value = [{"type": "function", "function": {"name": "read_file"}}] + tool_registry.execute = AsyncMock(side_effect=fake_tool) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "read both files"}], + tools=tool_registry, + model="test-model", + max_iterations=10, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.final_content == "Here are the results." 
+ assert result.stop_reason == "completed" + assert call_count == 4 + assert "read_file" in result.tools_used + + def test_snip_history_drops_orphaned_tool_results_from_trimmed_slice(monkeypatch): from nanobot.agent.runner import AgentRunSpec, AgentRunner diff --git a/tests/agent/test_skills_loader.py b/tests/agent/test_skills_loader.py new file mode 100644 index 000000000..46923c806 --- /dev/null +++ b/tests/agent/test_skills_loader.py @@ -0,0 +1,252 @@ +"""Tests for nanobot.agent.skills.SkillsLoader.""" + +from __future__ import annotations + +import json +from pathlib import Path + +import pytest + +from nanobot.agent.skills import SkillsLoader + + +def _write_skill( + base: Path, + name: str, + *, + metadata_json: dict | None = None, + body: str = "# Skill\n", +) -> Path: + """Create ``base / name / SKILL.md`` with optional nanobot metadata JSON.""" + skill_dir = base / name + skill_dir.mkdir(parents=True) + lines = ["---"] + if metadata_json is not None: + payload = json.dumps({"nanobot": metadata_json}, separators=(",", ":")) + lines.append(f'metadata: {payload}') + lines.extend(["---", "", body]) + path = skill_dir / "SKILL.md" + path.write_text("\n".join(lines), encoding="utf-8") + return path + + +def test_list_skills_empty_when_skills_dir_missing(tmp_path: Path) -> None: + workspace = tmp_path / "ws" + workspace.mkdir() + builtin = tmp_path / "builtin" + builtin.mkdir() + loader = SkillsLoader(workspace, builtin_skills_dir=builtin) + assert loader.list_skills(filter_unavailable=False) == [] + + +def test_list_skills_empty_when_skills_dir_exists_but_empty(tmp_path: Path) -> None: + workspace = tmp_path / "ws" + (workspace / "skills").mkdir(parents=True) + builtin = tmp_path / "builtin" + builtin.mkdir() + loader = SkillsLoader(workspace, builtin_skills_dir=builtin) + assert loader.list_skills(filter_unavailable=False) == [] + + +def test_list_skills_workspace_entry_shape_and_source(tmp_path: Path) -> None: + workspace = tmp_path / "ws" + skills_root = 
workspace / "skills" + skills_root.mkdir(parents=True) + skill_path = _write_skill(skills_root, "alpha", body="# Alpha") + builtin = tmp_path / "builtin" + builtin.mkdir() + + loader = SkillsLoader(workspace, builtin_skills_dir=builtin) + entries = loader.list_skills(filter_unavailable=False) + assert entries == [ + {"name": "alpha", "path": str(skill_path), "source": "workspace"}, + ] + + +def test_list_skills_skips_non_directories_and_missing_skill_md(tmp_path: Path) -> None: + workspace = tmp_path / "ws" + skills_root = workspace / "skills" + skills_root.mkdir(parents=True) + (skills_root / "not_a_dir.txt").write_text("x", encoding="utf-8") + (skills_root / "no_skill_md").mkdir() + ok_path = _write_skill(skills_root, "ok", body="# Ok") + builtin = tmp_path / "builtin" + builtin.mkdir() + + loader = SkillsLoader(workspace, builtin_skills_dir=builtin) + entries = loader.list_skills(filter_unavailable=False) + names = {entry["name"] for entry in entries} + assert names == {"ok"} + assert entries[0]["path"] == str(ok_path) + + +def test_list_skills_workspace_shadows_builtin_same_name(tmp_path: Path) -> None: + workspace = tmp_path / "ws" + ws_skills = workspace / "skills" + ws_skills.mkdir(parents=True) + ws_path = _write_skill(ws_skills, "dup", body="# Workspace wins") + + builtin = tmp_path / "builtin" + _write_skill(builtin, "dup", body="# Builtin") + + loader = SkillsLoader(workspace, builtin_skills_dir=builtin) + entries = loader.list_skills(filter_unavailable=False) + assert len(entries) == 1 + assert entries[0]["source"] == "workspace" + assert entries[0]["path"] == str(ws_path) + + +def test_list_skills_merges_workspace_and_builtin(tmp_path: Path) -> None: + workspace = tmp_path / "ws" + ws_skills = workspace / "skills" + ws_skills.mkdir(parents=True) + ws_path = _write_skill(ws_skills, "ws_only", body="# W") + builtin = tmp_path / "builtin" + bi_path = _write_skill(builtin, "bi_only", body="# B") + + loader = SkillsLoader(workspace, 
builtin_skills_dir=builtin) + entries = sorted(loader.list_skills(filter_unavailable=False), key=lambda item: item["name"]) + assert entries == [ + {"name": "bi_only", "path": str(bi_path), "source": "builtin"}, + {"name": "ws_only", "path": str(ws_path), "source": "workspace"}, + ] + + +def test_list_skills_builtin_omitted_when_dir_missing(tmp_path: Path) -> None: + workspace = tmp_path / "ws" + ws_skills = workspace / "skills" + ws_skills.mkdir(parents=True) + ws_path = _write_skill(ws_skills, "solo", body="# S") + missing_builtin = tmp_path / "no_such_builtin" + + loader = SkillsLoader(workspace, builtin_skills_dir=missing_builtin) + entries = loader.list_skills(filter_unavailable=False) + assert entries == [{"name": "solo", "path": str(ws_path), "source": "workspace"}] + + +def test_list_skills_filter_unavailable_excludes_unmet_bin_requirement( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + workspace = tmp_path / "ws" + skills_root = workspace / "skills" + skills_root.mkdir(parents=True) + _write_skill( + skills_root, + "needs_bin", + metadata_json={"requires": {"bins": ["nanobot_test_fake_binary"]}}, + ) + builtin = tmp_path / "builtin" + builtin.mkdir() + + def fake_which(cmd: str) -> str | None: + if cmd == "nanobot_test_fake_binary": + return None + return "/usr/bin/true" + + monkeypatch.setattr("nanobot.agent.skills.shutil.which", fake_which) + + loader = SkillsLoader(workspace, builtin_skills_dir=builtin) + assert loader.list_skills(filter_unavailable=True) == [] + + +def test_list_skills_filter_unavailable_includes_when_bin_requirement_met( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + workspace = tmp_path / "ws" + skills_root = workspace / "skills" + skills_root.mkdir(parents=True) + skill_path = _write_skill( + skills_root, + "has_bin", + metadata_json={"requires": {"bins": ["nanobot_test_fake_binary"]}}, + ) + builtin = tmp_path / "builtin" + builtin.mkdir() + + def fake_which(cmd: str) -> str | None: + if cmd == 
"nanobot_test_fake_binary": + return "/fake/nanobot_test_fake_binary" + return None + + monkeypatch.setattr("nanobot.agent.skills.shutil.which", fake_which) + + loader = SkillsLoader(workspace, builtin_skills_dir=builtin) + entries = loader.list_skills(filter_unavailable=True) + assert entries == [ + {"name": "has_bin", "path": str(skill_path), "source": "workspace"}, + ] + + +def test_list_skills_filter_unavailable_false_keeps_unmet_requirements( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + workspace = tmp_path / "ws" + skills_root = workspace / "skills" + skills_root.mkdir(parents=True) + skill_path = _write_skill( + skills_root, + "blocked", + metadata_json={"requires": {"bins": ["nanobot_test_fake_binary"]}}, + ) + builtin = tmp_path / "builtin" + builtin.mkdir() + + monkeypatch.setattr("nanobot.agent.skills.shutil.which", lambda _cmd: None) + + loader = SkillsLoader(workspace, builtin_skills_dir=builtin) + entries = loader.list_skills(filter_unavailable=False) + assert entries == [ + {"name": "blocked", "path": str(skill_path), "source": "workspace"}, + ] + + +def test_list_skills_filter_unavailable_excludes_unmet_env_requirement( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + workspace = tmp_path / "ws" + skills_root = workspace / "skills" + skills_root.mkdir(parents=True) + _write_skill( + skills_root, + "needs_env", + metadata_json={"requires": {"env": ["NANOBOT_SKILLS_TEST_ENV_VAR"]}}, + ) + builtin = tmp_path / "builtin" + builtin.mkdir() + + monkeypatch.delenv("NANOBOT_SKILLS_TEST_ENV_VAR", raising=False) + + loader = SkillsLoader(workspace, builtin_skills_dir=builtin) + assert loader.list_skills(filter_unavailable=True) == [] + + +def test_list_skills_openclaw_metadata_parsed_for_requirements( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + workspace = tmp_path / "ws" + skills_root = workspace / "skills" + skills_root.mkdir(parents=True) + skill_dir = skills_root / "openclaw_skill" + 
skill_dir.mkdir(parents=True) + skill_path = skill_dir / "SKILL.md" + oc_payload = json.dumps({"openclaw": {"requires": {"bins": ["nanobot_oc_bin"]}}}, separators=(",", ":")) + skill_path.write_text( + "\n".join(["---", f"metadata: {oc_payload}", "---", "", "# OC"]), + encoding="utf-8", + ) + builtin = tmp_path / "builtin" + builtin.mkdir() + + monkeypatch.setattr("nanobot.agent.skills.shutil.which", lambda _cmd: None) + + loader = SkillsLoader(workspace, builtin_skills_dir=builtin) + assert loader.list_skills(filter_unavailable=True) == [] + + monkeypatch.setattr( + "nanobot.agent.skills.shutil.which", + lambda cmd: "/x" if cmd == "nanobot_oc_bin" else None, + ) + entries = loader.list_skills(filter_unavailable=True) + assert entries == [ + {"name": "openclaw_skill", "path": str(skill_path), "source": "workspace"}, + ] diff --git a/tests/agent/test_tool_hint.py b/tests/agent/test_tool_hint.py new file mode 100644 index 000000000..2384cfbb2 --- /dev/null +++ b/tests/agent/test_tool_hint.py @@ -0,0 +1,202 @@ +"""Tests for tool hint formatting (nanobot.utils.tool_hints).""" + +from nanobot.utils.tool_hints import format_tool_hints +from nanobot.providers.base import ToolCallRequest + + +def _tc(name: str, args) -> ToolCallRequest: + return ToolCallRequest(id="c1", name=name, arguments=args) + + +def _hint(calls): + """Shortcut for format_tool_hints.""" + return format_tool_hints(calls) + + +class TestToolHintKnownTools: + """Test registered tool types produce correct formatted output.""" + + def test_read_file_short_path(self): + result = _hint([_tc("read_file", {"path": "foo.txt"})]) + assert result == 'read foo.txt' + + def test_read_file_long_path(self): + result = _hint([_tc("read_file", {"path": "/home/user/.local/share/uv/tools/nanobot/agent/loop.py"})]) + assert "loop.py" in result + assert "read " in result + + def test_write_file_shows_path_not_content(self): + result = _hint([_tc("write_file", {"path": "docs/api.md", "content": "# API Reference\n\nLong 
content..."})]) + assert result == "write docs/api.md" + + def test_edit_shows_path(self): + result = _hint([_tc("edit", {"file_path": "src/main.py", "old_string": "x", "new_string": "y"})]) + assert "main.py" in result + assert "edit " in result + + def test_glob_shows_pattern(self): + result = _hint([_tc("glob", {"pattern": "**/*.py", "path": "src"})]) + assert result == 'glob "**/*.py"' + + def test_grep_shows_pattern(self): + result = _hint([_tc("grep", {"pattern": "TODO|FIXME", "path": "src"})]) + assert result == 'grep "TODO|FIXME"' + + def test_exec_shows_command(self): + result = _hint([_tc("exec", {"command": "npm install typescript"})]) + assert result == "$ npm install typescript" + + def test_exec_truncates_long_command(self): + cmd = "cd /very/long/path && cat file && echo done && sleep 1 && ls -la" + result = _hint([_tc("exec", {"command": cmd})]) + assert result.startswith("$ ") + assert len(result) <= 50 # reasonable limit + + def test_web_search(self): + result = _hint([_tc("web_search", {"query": "Claude 4 vs GPT-4"})]) + assert result == 'search "Claude 4 vs GPT-4"' + + def test_web_fetch(self): + result = _hint([_tc("web_fetch", {"url": "https://example.com/page"})]) + assert result == "fetch https://example.com/page" + + +class TestToolHintMCP: + """Test MCP tools are abbreviated to server::tool format.""" + + def test_mcp_standard_format(self): + result = _hint([_tc("mcp_4_5v_mcp__analyze_image", {"imageSource": "https://img.jpg", "prompt": "describe"})]) + assert "4_5v" in result + assert "analyze_image" in result + + def test_mcp_simple_name(self): + result = _hint([_tc("mcp_github__create_issue", {"title": "Bug fix"})]) + assert "github" in result + assert "create_issue" in result + + +class TestToolHintFallback: + """Test unknown tools fall back to original behavior.""" + + def test_unknown_tool_with_string_arg(self): + result = _hint([_tc("custom_tool", {"data": "hello world"})]) + assert result == 'custom_tool("hello world")' + + def 
test_unknown_tool_with_long_arg_truncates(self): + long_val = "a" * 60 + result = _hint([_tc("custom_tool", {"data": long_val})]) + assert len(result) < 80 + assert "\u2026" in result + + def test_unknown_tool_no_string_arg(self): + result = _hint([_tc("custom_tool", {"count": 42})]) + assert result == "custom_tool" + + def test_empty_tool_calls(self): + result = _hint([]) + assert result == "" + + +class TestToolHintFolding: + """Test consecutive same-tool calls are folded.""" + + def test_single_call_no_fold(self): + calls = [_tc("grep", {"pattern": "*.py"})] + result = _hint(calls) + assert "\u00d7" not in result + + def test_two_consecutive_same_folded(self): + calls = [ + _tc("grep", {"pattern": "*.py"}), + _tc("grep", {"pattern": "*.ts"}), + ] + result = _hint(calls) + assert "\u00d7 2" in result + + def test_three_consecutive_same_folded(self): + calls = [ + _tc("read_file", {"path": "a.py"}), + _tc("read_file", {"path": "b.py"}), + _tc("read_file", {"path": "c.py"}), + ] + result = _hint(calls) + assert "\u00d7 3" in result + + def test_different_tools_not_folded(self): + calls = [ + _tc("grep", {"pattern": "TODO"}), + _tc("read_file", {"path": "a.py"}), + ] + result = _hint(calls) + assert "\u00d7" not in result + + def test_interleaved_same_tools_not_folded(self): + calls = [ + _tc("grep", {"pattern": "a"}), + _tc("read_file", {"path": "f.py"}), + _tc("grep", {"pattern": "b"}), + ] + result = _hint(calls) + assert "\u00d7" not in result + + +class TestToolHintMultipleCalls: + """Test multiple different tool calls are comma-separated.""" + + def test_two_different_tools(self): + calls = [ + _tc("grep", {"pattern": "TODO"}), + _tc("read_file", {"path": "main.py"}), + ] + result = _hint(calls) + assert 'grep "TODO"' in result + assert "read main.py" in result + assert ", " in result + + +class TestToolHintEdgeCases: + """Test edge cases and defensive handling (G1, G2).""" + + def test_known_tool_empty_list_args(self): + """C1/G1: Empty list arguments should 
not crash.""" + result = _hint([_tc("read_file", [])]) + assert result == "read_file" + + def test_known_tool_none_args(self): + """G2: None arguments should not crash.""" + result = _hint([_tc("read_file", None)]) + assert result == "read_file" + + def test_fallback_empty_list_args(self): + """C1: Empty list args in fallback should not crash.""" + result = _hint([_tc("custom_tool", [])]) + assert result == "custom_tool" + + def test_fallback_none_args(self): + """G2: None args in fallback should not crash.""" + result = _hint([_tc("custom_tool", None)]) + assert result == "custom_tool" + + def test_list_dir_registered(self): + """S2: list_dir should use 'ls' format.""" + result = _hint([_tc("list_dir", {"path": "/tmp"})]) + assert result == "ls /tmp" + + +class TestToolHintMixedFolding: + """G4: Mixed folding groups with interleaved same-tool segments.""" + + def test_read_read_grep_grep_read(self): + """readΓ—2, grepΓ—2, read β€” should produce two separate groups.""" + calls = [ + _tc("read_file", {"path": "a.py"}), + _tc("read_file", {"path": "b.py"}), + _tc("grep", {"pattern": "x"}), + _tc("grep", {"pattern": "y"}), + _tc("read_file", {"path": "c.py"}), + ] + result = _hint(calls) + assert "\u00d7 2" in result + # Should have 3 groups: readΓ—2, grepΓ—2, read + parts = result.split(", ") + assert len(parts) == 3 diff --git a/tests/channels/test_email_channel.py b/tests/channels/test_email_channel.py index 2d0e33ce3..6d6d2f74f 100644 --- a/tests/channels/test_email_channel.py +++ b/tests/channels/test_email_channel.py @@ -1,5 +1,6 @@ from email.message import EmailMessage from datetime import date +from pathlib import Path import imaplib import pytest @@ -650,3 +651,224 @@ def test_check_authentication_results_method() -> None: spf, dkim = EmailChannel._check_authentication_results(parsed) assert spf is False assert dkim is True + + +# --------------------------------------------------------------------------- +# Attachment extraction tests +# 
--------------------------------------------------------------------------- + + +def _make_raw_email_with_attachment( + from_addr: str = "alice@example.com", + subject: str = "With attachment", + body: str = "See attached.", + attachment_name: str = "doc.pdf", + attachment_content: bytes = b"%PDF-1.4 fake pdf content", + attachment_mime: str = "application/pdf", + auth_results: str | None = None, +) -> bytes: + msg = EmailMessage() + msg["From"] = from_addr + msg["To"] = "bot@example.com" + msg["Subject"] = subject + msg["Message-ID"] = "" + if auth_results: + msg["Authentication-Results"] = auth_results + msg.set_content(body) + maintype, subtype = attachment_mime.split("/", 1) + msg.add_attachment( + attachment_content, + maintype=maintype, + subtype=subtype, + filename=attachment_name, + ) + return msg.as_bytes() + + +def test_extract_attachments_saves_pdf(tmp_path, monkeypatch) -> None: + """PDF attachment is saved to media dir and path returned in media list.""" + monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path) + + raw = _make_raw_email_with_attachment() + fake = _make_fake_imap(raw) + monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake) + + cfg = _make_config(allowed_attachment_types=["application/pdf"], verify_dkim=False, verify_spf=False) + channel = EmailChannel(cfg, MessageBus()) + items = channel._fetch_new_messages() + + assert len(items) == 1 + assert len(items[0]["media"]) == 1 + saved_path = Path(items[0]["media"][0]) + assert saved_path.exists() + assert saved_path.read_bytes() == b"%PDF-1.4 fake pdf content" + assert "500_doc.pdf" in saved_path.name + assert "[attachment:" in items[0]["content"] + + +def test_extract_attachments_disabled_by_default(monkeypatch) -> None: + """With no allowed_attachment_types (default), no attachments are extracted.""" + raw = _make_raw_email_with_attachment() + fake = _make_fake_imap(raw) + 
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake) + + cfg = _make_config(verify_dkim=False, verify_spf=False) + assert cfg.allowed_attachment_types == [] + channel = EmailChannel(cfg, MessageBus()) + items = channel._fetch_new_messages() + + assert len(items) == 1 + assert items[0]["media"] == [] + assert "[attachment:" not in items[0]["content"] + + +def test_extract_attachments_mime_type_filter(tmp_path, monkeypatch) -> None: + """Non-allowed MIME types are skipped.""" + monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path) + + raw = _make_raw_email_with_attachment( + attachment_name="image.png", + attachment_content=b"\x89PNG fake", + attachment_mime="image/png", + ) + fake = _make_fake_imap(raw) + monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake) + + cfg = _make_config( + allowed_attachment_types=["application/pdf"], + verify_dkim=False, + verify_spf=False, + ) + channel = EmailChannel(cfg, MessageBus()) + items = channel._fetch_new_messages() + + assert len(items) == 1 + assert items[0]["media"] == [] + + +def test_extract_attachments_empty_allowed_types_rejects_all(tmp_path, monkeypatch) -> None: + """Empty allowed_attachment_types means no types are accepted.""" + monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path) + + raw = _make_raw_email_with_attachment( + attachment_name="image.png", + attachment_content=b"\x89PNG fake", + attachment_mime="image/png", + ) + fake = _make_fake_imap(raw) + monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake) + + cfg = _make_config( + allowed_attachment_types=[], + verify_dkim=False, + verify_spf=False, + ) + channel = EmailChannel(cfg, MessageBus()) + items = channel._fetch_new_messages() + + assert len(items) == 1 + assert items[0]["media"] == [] + + +def test_extract_attachments_wildcard_pattern(tmp_path, monkeypatch) -> None: + """Glob patterns like 'image/*' match 
attachment MIME types.""" + monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path) + + raw = _make_raw_email_with_attachment( + attachment_name="photo.jpg", + attachment_content=b"\xff\xd8\xff fake jpeg", + attachment_mime="image/jpeg", + ) + fake = _make_fake_imap(raw) + monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake) + + cfg = _make_config( + allowed_attachment_types=["image/*"], + verify_dkim=False, + verify_spf=False, + ) + channel = EmailChannel(cfg, MessageBus()) + items = channel._fetch_new_messages() + + assert len(items) == 1 + assert len(items[0]["media"]) == 1 + + +def test_extract_attachments_size_limit(tmp_path, monkeypatch) -> None: + """Attachments exceeding max_attachment_size are skipped.""" + monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path) + + raw = _make_raw_email_with_attachment( + attachment_content=b"x" * 1000, + ) + fake = _make_fake_imap(raw) + monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake) + + cfg = _make_config( + allowed_attachment_types=["*"], + max_attachment_size=500, + verify_dkim=False, + verify_spf=False, + ) + channel = EmailChannel(cfg, MessageBus()) + items = channel._fetch_new_messages() + + assert len(items) == 1 + assert items[0]["media"] == [] + + +def test_extract_attachments_max_count(tmp_path, monkeypatch) -> None: + """Only max_attachments_per_email are saved.""" + monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path) + + # Build email with 3 attachments + msg = EmailMessage() + msg["From"] = "alice@example.com" + msg["To"] = "bot@example.com" + msg["Subject"] = "Many attachments" + msg["Message-ID"] = "" + msg.set_content("See attached.") + for i in range(3): + msg.add_attachment( + f"content {i}".encode(), + maintype="application", + subtype="pdf", + filename=f"doc{i}.pdf", + ) + raw = msg.as_bytes() + + fake = _make_fake_imap(raw) + 
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake) + + cfg = _make_config( + allowed_attachment_types=["*"], + max_attachments_per_email=2, + verify_dkim=False, + verify_spf=False, + ) + channel = EmailChannel(cfg, MessageBus()) + items = channel._fetch_new_messages() + + assert len(items) == 1 + assert len(items[0]["media"]) == 2 + + +def test_extract_attachments_sanitizes_filename(tmp_path, monkeypatch) -> None: + """Path traversal in filenames is neutralized.""" + monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path) + + raw = _make_raw_email_with_attachment( + attachment_name="../../../etc/passwd", + ) + fake = _make_fake_imap(raw) + monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake) + + cfg = _make_config(allowed_attachment_types=["*"], verify_dkim=False, verify_spf=False) + channel = EmailChannel(cfg, MessageBus()) + items = channel._fetch_new_messages() + + assert len(items) == 1 + assert len(items[0]["media"]) == 1 + saved_path = Path(items[0]["media"][0]) + # File must be inside the media dir, not escaped via path traversal + assert saved_path.parent == tmp_path diff --git a/tests/channels/test_feishu_mention.py b/tests/channels/test_feishu_mention.py new file mode 100644 index 000000000..fb81f2294 --- /dev/null +++ b/tests/channels/test_feishu_mention.py @@ -0,0 +1,60 @@ +"""Tests for Feishu _is_bot_mentioned logic.""" + +from types import SimpleNamespace + +from nanobot.channels.feishu import FeishuChannel + + +def _make_channel(bot_open_id: str | None = None) -> FeishuChannel: + config = SimpleNamespace( + app_id="test_id", + app_secret="test_secret", + verification_token="", + event_encrypt_key="", + group_policy="mention", + ) + ch = FeishuChannel.__new__(FeishuChannel) + ch.config = config + ch._bot_open_id = bot_open_id + return ch + + +def _make_message(mentions=None, content="hello"): + return SimpleNamespace(content=content,
mentions=mentions) + + +def _make_mention(open_id: str, user_id: str | None = None): + mid = SimpleNamespace(open_id=open_id, user_id=user_id) + return SimpleNamespace(id=mid) + + +class TestIsBotMentioned: + def test_exact_match_with_bot_open_id(self): + ch = _make_channel(bot_open_id="ou_bot123") + msg = _make_message(mentions=[_make_mention("ou_bot123")]) + assert ch._is_bot_mentioned(msg) is True + + def test_no_match_different_bot(self): + ch = _make_channel(bot_open_id="ou_bot123") + msg = _make_message(mentions=[_make_mention("ou_other_bot")]) + assert ch._is_bot_mentioned(msg) is False + + def test_at_all_always_matches(self): + ch = _make_channel(bot_open_id="ou_bot123") + msg = _make_message(content="@_all hello") + assert ch._is_bot_mentioned(msg) is True + + def test_fallback_heuristic_when_no_bot_open_id(self): + ch = _make_channel(bot_open_id=None) + msg = _make_message(mentions=[_make_mention("ou_some_bot", user_id=None)]) + assert ch._is_bot_mentioned(msg) is True + + def test_fallback_ignores_user_mentions(self): + ch = _make_channel(bot_open_id=None) + msg = _make_message(mentions=[_make_mention("ou_user", user_id="u_12345")]) + assert ch._is_bot_mentioned(msg) is False + + def test_no_mentions_returns_false(self): + ch = _make_channel(bot_open_id="ou_bot123") + msg = _make_message(mentions=None) + assert ch._is_bot_mentioned(msg) is False diff --git a/tests/channels/test_feishu_mentions.py b/tests/channels/test_feishu_mentions.py new file mode 100644 index 000000000..a49e76ee6 --- /dev/null +++ b/tests/channels/test_feishu_mentions.py @@ -0,0 +1,59 @@ +"""Tests for FeishuChannel._resolve_mentions.""" + +from types import SimpleNamespace + +from nanobot.channels.feishu import FeishuChannel + + +def _mention(key: str, name: str, open_id: str = "", user_id: str = ""): + """Build a mock MentionEvent-like object.""" + id_obj = SimpleNamespace(open_id=open_id, user_id=user_id) if (open_id or user_id) else None + return SimpleNamespace(key=key, 
name=name, id=id_obj) + + +class TestResolveMentions: + def test_single_mention_replaced(self): + text = "hello @_user_1 how are you" + mentions = [_mention("@_user_1", "Alice", open_id="ou_abc123")] + result = FeishuChannel._resolve_mentions(text, mentions) + assert "@Alice (ou_abc123)" in result + assert "@_user_1" not in result + + def test_mention_with_both_ids(self): + text = "@_user_1 said hi" + mentions = [_mention("@_user_1", "Bob", open_id="ou_abc", user_id="uid_456")] + result = FeishuChannel._resolve_mentions(text, mentions) + assert "@Bob (ou_abc, user id: uid_456)" in result + + def test_mention_no_id_skipped(self): + """When mention has no id object, the placeholder is left unchanged.""" + text = "@_user_1 said hi" + mentions = [SimpleNamespace(key="@_user_1", name="Charlie", id=None)] + result = FeishuChannel._resolve_mentions(text, mentions) + assert result == "@_user_1 said hi" + + def test_multiple_mentions(self): + text = "@_user_1 and @_user_2 are here" + mentions = [ + _mention("@_user_1", "Alice", open_id="ou_a"), + _mention("@_user_2", "Bob", open_id="ou_b"), + ] + result = FeishuChannel._resolve_mentions(text, mentions) + assert "@Alice (ou_a)" in result + assert "@Bob (ou_b)" in result + assert "@_user_1" not in result + assert "@_user_2" not in result + + def test_no_mentions_returns_text(self): + assert FeishuChannel._resolve_mentions("hello world", None) == "hello world" + assert FeishuChannel._resolve_mentions("hello world", []) == "hello world" + + def test_empty_text_returns_empty(self): + mentions = [_mention("@_user_1", "Alice", open_id="ou_a")] + assert FeishuChannel._resolve_mentions("", mentions) == "" + + def test_mention_key_not_in_text_skipped(self): + text = "hello world" + mentions = [_mention("@_user_99", "Ghost", open_id="ou_ghost")] + result = FeishuChannel._resolve_mentions(text, mentions) + assert result == "hello world" diff --git a/tests/channels/test_feishu_tool_hint_code_block.py 
b/tests/channels/test_feishu_tool_hint_code_block.py index a65f1d988..a5db5ad69 100644 --- a/tests/channels/test_feishu_tool_hint_code_block.py +++ b/tests/channels/test_feishu_tool_hint_code_block.py @@ -127,6 +127,82 @@ async def test_tool_hint_multiple_tools_in_one_message(mock_feishu_channel): @mark.asyncio +async def test_tool_hint_new_format_basic(mock_feishu_channel): + """New format hints (read path, grep "pattern") should parse correctly.""" + msg = OutboundMessage( + channel="feishu", + chat_id="oc_123456", + content='read src/main.py, grep "TODO"', + metadata={"_tool_hint": True} + ) + + with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send: + await mock_feishu_channel.send(msg) + + content = json.loads(mock_send.call_args[0][3]) + md = content["elements"][0]["content"] + assert "read src/main.py" in md + assert 'grep "TODO"' in md + + +@mark.asyncio +async def test_tool_hint_new_format_with_comma_in_quotes(mock_feishu_channel): + """Commas inside quoted arguments must not cause incorrect line splits.""" + msg = OutboundMessage( + channel="feishu", + chat_id="oc_123456", + content='grep "hello, world", $ echo test', + metadata={"_tool_hint": True} + ) + + with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send: + await mock_feishu_channel.send(msg) + + content = json.loads(mock_send.call_args[0][3]) + md = content["elements"][0]["content"] + # The comma inside quotes should NOT cause a line break + assert 'grep "hello, world"' in md + assert "$ echo test" in md + + +@mark.asyncio +async def test_tool_hint_new_format_with_folding(mock_feishu_channel): + """Folded calls (× N) should display on separate lines.""" + msg = OutboundMessage( + channel="feishu", + chat_id="oc_123456", + content='read path × 3, grep "pattern"', + metadata={"_tool_hint": True} + ) + + with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send: + await mock_feishu_channel.send(msg) + + content =
json.loads(mock_send.call_args[0][3]) + md = content["elements"][0]["content"] + assert "\u00d7 3" in md + assert 'grep "pattern"' in md + + +@mark.asyncio +async def test_tool_hint_new_format_mcp(mock_feishu_channel): + """MCP tool format (server::tool) should parse correctly.""" + msg = OutboundMessage( + channel="feishu", + chat_id="oc_123456", + content='4_5v::analyze_image("photo.jpg")', + metadata={"_tool_hint": True} + ) + + with patch.object(mock_feishu_channel, '_send_message_sync') as mock_send: + await mock_feishu_channel.send(msg) + + content = json.loads(mock_send.call_args[0][3]) + md = content["elements"][0]["content"] + assert "4_5v::analyze_image" in md + + +@mark.asyncio async def test_tool_hint_keeps_commas_inside_arguments(mock_feishu_channel): """Commas inside a single tool argument must not be split onto a new line.""" msg = OutboundMessage( diff --git a/tests/channels/test_telegram_channel.py b/tests/channels/test_telegram_channel.py index 9584ad547..7bc212804 100644 --- a/tests/channels/test_telegram_channel.py +++ b/tests/channels/test_telegram_channel.py @@ -385,6 +385,32 @@ async def test_send_delta_stream_end_treats_not_modified_as_success() -> None: assert "123" not in channel._stream_bufs +@pytest.mark.asyncio +async def test_send_delta_stream_end_splits_oversized_reply() -> None: + """Final streamed reply exceeding Telegram limit is split into chunks.""" + from nanobot.channels.telegram import TELEGRAM_MAX_MESSAGE_LEN + + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + channel._app.bot.edit_message_text = AsyncMock() + channel._app.bot.send_message = AsyncMock(return_value=SimpleNamespace(message_id=99)) + + oversized = "x" * (TELEGRAM_MAX_MESSAGE_LEN + 500) + channel._stream_bufs["123"] = _StreamBuf(text=oversized, message_id=7, last_edit=0.0) + + await channel.send_delta("123", "", {"_stream_end": True}) + +
channel._app.bot.edit_message_text.assert_called_once() + edit_text = channel._app.bot.edit_message_text.call_args.kwargs.get("text", "") + assert len(edit_text) <= TELEGRAM_MAX_MESSAGE_LEN + + channel._app.bot.send_message.assert_called_once() + assert "123" not in channel._stream_bufs + + @pytest.mark.asyncio async def test_send_delta_new_stream_id_replaces_stale_buffer() -> None: channel = TelegramChannel( @@ -424,6 +450,23 @@ async def test_send_delta_incremental_edit_treats_not_modified_as_success() -> N assert channel._stream_bufs["123"].last_edit > 0.0 +@pytest.mark.asyncio +async def test_send_delta_initial_send_keeps_message_in_thread() -> None: + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["*"]), + MessageBus(), + ) + channel._app = _FakeApp(lambda: None) + + await channel.send_delta( + "123", + "hello", + {"_stream_delta": True, "_stream_id": "s:0", "message_thread_id": 42}, + ) + + assert channel._app.bot.sent_messages[0]["message_thread_id"] == 42 + + def test_derive_topic_session_key_uses_thread_id() -> None: message = SimpleNamespace( chat=SimpleNamespace(type="supergroup"), @@ -434,6 +477,27 @@ def test_derive_topic_session_key_uses_thread_id() -> None: assert TelegramChannel._derive_topic_session_key(message) == "telegram:-100123:topic:42" +def test_derive_topic_session_key_private_dm_thread() -> None: + """Private DM threads (Telegram Threaded Mode) must get their own session key.""" + message = SimpleNamespace( + chat=SimpleNamespace(type="private"), + chat_id=999, + message_thread_id=7, + ) + assert TelegramChannel._derive_topic_session_key(message) == "telegram:999:topic:7" + + +def test_derive_topic_session_key_none_without_thread() -> None: + """No thread id → no topic session key, regardless of chat type.""" + for chat_type in ("private", "supergroup", "group"): + message = SimpleNamespace( + chat=SimpleNamespace(type=chat_type), + chat_id=123, + message_thread_id=None, + ) + assert
TelegramChannel._derive_topic_session_key(message) is None + + def test_get_extension_falls_back_to_original_filename() -> None: channel = TelegramChannel(TelegramConfig(), MessageBus()) diff --git a/tests/channels/test_whatsapp_channel.py b/tests/channels/test_whatsapp_channel.py index 8223fdff3..b61033677 100644 --- a/tests/channels/test_whatsapp_channel.py +++ b/tests/channels/test_whatsapp_channel.py @@ -163,6 +163,107 @@ async def test_group_policy_mention_accepts_mentioned_group_message(): assert kwargs["sender_id"] == "user" +@pytest.mark.asyncio +async def test_sender_id_prefers_phone_jid_over_lid(): + """sender_id should resolve to phone number when @s.whatsapp.net JID is present.""" + ch = WhatsAppChannel({"enabled": True}, MagicMock()) + ch._handle_message = AsyncMock() + + await ch._handle_bridge_message( + json.dumps({ + "type": "message", + "id": "lid1", + "sender": "ABC123@lid.whatsapp.net", + "pn": "5551234@s.whatsapp.net", + "content": "hi", + "timestamp": 1, + }) + ) + + kwargs = ch._handle_message.await_args.kwargs + assert kwargs["sender_id"] == "5551234" + + +@pytest.mark.asyncio +async def test_lid_to_phone_cache_resolves_lid_only_messages(): + """When only LID is present, a cached LID→phone mapping should be used.""" + ch = WhatsAppChannel({"enabled": True}, MagicMock()) + ch._handle_message = AsyncMock() + + # First message: both phone and LID → builds cache + await ch._handle_bridge_message( + json.dumps({ + "type": "message", + "id": "c1", + "sender": "LID99@lid.whatsapp.net", + "pn": "5559999@s.whatsapp.net", + "content": "first", + "timestamp": 1, + }) + ) + # Second message: only LID, no phone + await ch._handle_bridge_message( + json.dumps({ + "type": "message", + "id": "c2", + "sender": "LID99@lid.whatsapp.net", + "pn": "", + "content": "second", + "timestamp": 2, + }) + ) + + second_kwargs = ch._handle_message.await_args_list[1].kwargs + assert second_kwargs["sender_id"] == "5559999" + + +@pytest.mark.asyncio +async def
test_voice_message_transcription_uses_media_path(): + """Voice messages are transcribed when media path is available.""" + ch = WhatsAppChannel({"enabled": True}, MagicMock()) + ch.transcription_provider = "openai" + ch.transcription_api_key = "sk-test" + ch._handle_message = AsyncMock() + ch.transcribe_audio = AsyncMock(return_value="Hello world") + + await ch._handle_bridge_message( + json.dumps({ + "type": "message", + "id": "v1", + "sender": "12345@s.whatsapp.net", + "pn": "", + "content": "[Voice Message]", + "timestamp": 1, + "media": ["/tmp/voice.ogg"], + }) + ) + + ch.transcribe_audio.assert_awaited_once_with("/tmp/voice.ogg") + kwargs = ch._handle_message.await_args.kwargs + assert kwargs["content"].startswith("Hello world") + + +@pytest.mark.asyncio +async def test_voice_message_no_media_shows_not_available(): + """Voice messages without media produce a fallback placeholder.""" + ch = WhatsAppChannel({"enabled": True}, MagicMock()) + ch._handle_message = AsyncMock() + + await ch._handle_bridge_message( + json.dumps({ + "type": "message", + "id": "v2", + "sender": "12345@s.whatsapp.net", + "pn": "", + "content": "[Voice Message]", + "timestamp": 1, + }) + ) + + kwargs = ch._handle_message.await_args.kwargs + assert kwargs["content"] == "[Voice Message: Audio not available]" + + def test_load_or_create_bridge_token_persists_generated_secret(tmp_path): token_path = tmp_path / "whatsapp-auth" / "bridge-token" diff --git a/tests/cli/test_commands.py b/tests/cli/test_commands.py index 0f6ff8177..3a1e7145a 100644 --- a/tests/cli/test_commands.py +++ b/tests/cli/test_commands.py @@ -1,5 +1,7 @@ +import asyncio import json import re +import shutil from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch @@ -9,6 +11,7 @@ from typer.testing import CliRunner from nanobot.bus.events import OutboundMessage from nanobot.cli.commands import _make_provider, app from nanobot.config.schema import Config +from nanobot.cron.types import CronJob, 
CronPayload from nanobot.providers.openai_codex_provider import _strip_model_prefix from nanobot.providers.registry import find_by_name @@ -19,11 +22,6 @@ class _StopGatewayError(RuntimeError): pass -import shutil - -import pytest - - @pytest.fixture def mock_paths(): """Mock config/workspace paths for test isolation.""" @@ -31,7 +29,6 @@ def mock_paths(): patch("nanobot.config.loader.save_config") as mock_sc, \ patch("nanobot.config.loader.load_config") as mock_lc, \ patch("nanobot.cli.commands.get_workspace_path") as mock_ws: - base_dir = Path("./test_onboard_data") if base_dir.exists(): shutil.rmtree(base_dir) @@ -425,13 +422,13 @@ def mock_agent_runtime(tmp_path): config.agents.defaults.workspace = str(tmp_path / "default-workspace") with patch("nanobot.config.loader.load_config", return_value=config) as mock_load_config, \ + patch("nanobot.config.loader.resolve_config_env_vars", side_effect=lambda c: c), \ patch("nanobot.cli.commands.sync_workspace_templates") as mock_sync_templates, \ patch("nanobot.cli.commands._make_provider", return_value=object()), \ patch("nanobot.cli.commands._print_agent_response") as mock_print_response, \ patch("nanobot.bus.queue.MessageBus"), \ patch("nanobot.cron.service.CronService"), \ patch("nanobot.agent.loop.AgentLoop") as mock_agent_loop_cls: - agent_loop = MagicMock() agent_loop.channels_config = None agent_loop.process_direct = AsyncMock( @@ -656,7 +653,9 @@ def test_agent_custom_config_workspace_does_not_migrate_legacy_cron( monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron) monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop) - monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None) + monkeypatch.setattr( + "nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None + ) result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)]) @@ -739,6 +738,7 @@ def _patch_cli_command_runtime( set_config_path or (lambda _path: 
None), ) monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) + monkeypatch.setattr("nanobot.config.loader.resolve_config_env_vars", lambda c: c) monkeypatch.setattr( "nanobot.cli.commands.sync_workspace_templates", sync_templates or (lambda _path: None), @@ -868,6 +868,115 @@ def test_gateway_uses_workspace_directory_for_cron_store(monkeypatch, tmp_path: assert seen["cron_store"] == config.workspace_path / "cron" / "jobs.json" +def test_gateway_cron_evaluator_receives_scheduled_reminder_context( + monkeypatch, tmp_path: Path +) -> None: + config_file = tmp_path / "instance" / "config.json" + config_file.parent.mkdir(parents=True) + config_file.write_text("{}") + + config = Config() + config.agents.defaults.workspace = str(tmp_path / "config-workspace") + provider = object() + bus = MagicMock() + bus.publish_outbound = AsyncMock() + seen: dict[str, object] = {} + + monkeypatch.setattr("nanobot.config.loader.set_config_path", lambda _path: None) + monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) + monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) + monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: provider) + monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus) + monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object()) + + class _FakeCron: + def __init__(self, _store_path: Path) -> None: + self.on_job = None + seen["cron"] = self + + class _FakeAgentLoop: + def __init__(self, *args, **kwargs) -> None: + self.model = "test-model" + self.tools = {} + + async def process_direct(self, *_args, **_kwargs): + return OutboundMessage( + channel="telegram", + chat_id="user-1", + content="Time to stretch.", + ) + + async def close_mcp(self) -> None: + return None + + async def run(self) -> None: + return None + + def stop(self) -> None: + return None + + class _StopAfterCronSetup: + def __init__(self, 
*_args, **_kwargs) -> None: + raise _StopGatewayError("stop") + + async def _capture_evaluate_response( + response: str, + task_context: str, + provider_arg: object, + model: str, + ) -> bool: + seen["response"] = response + seen["task_context"] = task_context + seen["provider"] = provider_arg + seen["model"] = model + return True + + monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCron) + monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop) + monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _StopAfterCronSetup) + monkeypatch.setattr( + "nanobot.utils.evaluator.evaluate_response", + _capture_evaluate_response, + ) + + result = runner.invoke(app, ["gateway", "--config", str(config_file)]) + + assert isinstance(result.exception, _StopGatewayError) + cron = seen["cron"] + assert isinstance(cron, _FakeCron) + assert cron.on_job is not None + + job = CronJob( + id="cron-1", + name="stretch", + payload=CronPayload( + message="Remind me to stretch.", + deliver=True, + channel="telegram", + to="user-1", + ), + ) + + response = asyncio.run(cron.on_job(job)) + + assert response == "Time to stretch." + assert seen["response"] == "Time to stretch." + assert seen["provider"] is provider + assert seen["model"] == "test-model" + assert seen["task_context"] == ( + "[Scheduled Task] Timer finished.\n\n" + "Task 'stretch' has been triggered.\n" + "Scheduled instruction: Remind me to stretch." + ) + bus.publish_outbound.assert_awaited_once_with( + OutboundMessage( + channel="telegram", + chat_id="user-1", + content="Time to stretch.", + ) + ) + + def test_gateway_workspace_override_does_not_migrate_legacy_cron( monkeypatch, tmp_path: Path ) -> None: diff --git a/tests/cli/test_safe_file_history.py b/tests/cli/test_safe_file_history.py new file mode 100644 index 000000000..78b5e2339 --- /dev/null +++ b/tests/cli/test_safe_file_history.py @@ -0,0 +1,44 @@ +"""Regression tests for SafeFileHistory (issue #2846). 
+ +Surrogate characters in CLI input must not crash history file writes. +""" + +from nanobot.cli.commands import SafeFileHistory + + +class TestSafeFileHistory: + def test_surrogate_replaced(self, tmp_path): + """Lone surrogates are replaced with U+FFFD instead of crashing.""" + hist = SafeFileHistory(str(tmp_path / "history")) + hist.store_string("hello \udce9 world") + entries = list(hist.load_history_strings()) + assert len(entries) == 1 + assert "\udce9" not in entries[0] + assert "hello" in entries[0] + assert "world" in entries[0] + + def test_normal_text_unchanged(self, tmp_path): + hist = SafeFileHistory(str(tmp_path / "history")) + hist.store_string("normal ascii text") + entries = list(hist.load_history_strings()) + assert entries[0] == "normal ascii text" + + def test_emoji_preserved(self, tmp_path): + hist = SafeFileHistory(str(tmp_path / "history")) + hist.store_string("hello 🐈 nanobot") + entries = list(hist.load_history_strings()) + assert entries[0] == "hello 🐈 nanobot" + + def test_mixed_unicode_preserved(self, tmp_path): + """CJK + emoji + latin should all pass through cleanly.""" + hist = SafeFileHistory(str(tmp_path / "history")) + hist.store_string("你好 hello こんにちは 🎉") + entries = list(hist.load_history_strings()) + assert entries[0] == "你好 hello こんにちは 🎉" + + def test_multiple_surrogates(self, tmp_path): + hist = SafeFileHistory(str(tmp_path / "history")) + hist.store_string("\udce9\udcf1\udcff") + entries = list(hist.load_history_strings()) + assert len(entries) == 1 + assert "\udce9" not in entries[0] diff --git a/tests/config/test_env_interpolation.py b/tests/config/test_env_interpolation.py new file mode 100644 index 000000000..aefcc3e40 --- /dev/null +++ b/tests/config/test_env_interpolation.py @@ -0,0 +1,82 @@ +import json + +import pytest + +from nanobot.config.loader import ( + _resolve_env_vars, + load_config, + resolve_config_env_vars, + save_config, +) + + +class TestResolveEnvVars: + def test_replaces_string_value(self,
monkeypatch): + monkeypatch.setenv("MY_SECRET", "hunter2") + assert _resolve_env_vars("${MY_SECRET}") == "hunter2" + + def test_partial_replacement(self, monkeypatch): + monkeypatch.setenv("HOST", "example.com") + assert _resolve_env_vars("https://${HOST}/api") == "https://example.com/api" + + def test_multiple_vars_in_one_string(self, monkeypatch): + monkeypatch.setenv("USER", "alice") + monkeypatch.setenv("PASS", "secret") + assert _resolve_env_vars("${USER}:${PASS}") == "alice:secret" + + def test_nested_dicts(self, monkeypatch): + monkeypatch.setenv("TOKEN", "abc123") + data = {"channels": {"telegram": {"token": "${TOKEN}"}}} + result = _resolve_env_vars(data) + assert result["channels"]["telegram"]["token"] == "abc123" + + def test_lists(self, monkeypatch): + monkeypatch.setenv("VAL", "x") + assert _resolve_env_vars(["${VAL}", "plain"]) == ["x", "plain"] + + def test_ignores_non_strings(self): + assert _resolve_env_vars(42) == 42 + assert _resolve_env_vars(True) is True + assert _resolve_env_vars(None) is None + assert _resolve_env_vars(3.14) == 3.14 + + def test_plain_strings_unchanged(self): + assert _resolve_env_vars("no vars here") == "no vars here" + + def test_missing_var_raises(self): + with pytest.raises(ValueError, match="DOES_NOT_EXIST"): + _resolve_env_vars("${DOES_NOT_EXIST}") + + +class TestResolveConfig: + def test_resolves_env_vars_in_config(self, tmp_path, monkeypatch): + monkeypatch.setenv("TEST_API_KEY", "resolved-key") + config_path = tmp_path / "config.json" + config_path.write_text( + json.dumps( + {"providers": {"groq": {"apiKey": "${TEST_API_KEY}"}}} + ), + encoding="utf-8", + ) + + raw = load_config(config_path) + assert raw.providers.groq.api_key == "${TEST_API_KEY}" + + resolved = resolve_config_env_vars(raw) + assert resolved.providers.groq.api_key == "resolved-key" + + def test_save_preserves_templates(self, tmp_path, monkeypatch): + monkeypatch.setenv("MY_TOKEN", "real-token") + config_path = tmp_path / "config.json" + 
config_path.write_text( + json.dumps( + {"channels": {"telegram": {"token": "${MY_TOKEN}"}}} + ), + encoding="utf-8", + ) + + raw = load_config(config_path) + save_config(raw, config_path) + + saved = json.loads(config_path.read_text(encoding="utf-8")) + assert saved["channels"]["telegram"]["token"] == "${MY_TOKEN}" diff --git a/tests/cron/test_cron_tool_list.py b/tests/cron/test_cron_tool_list.py index 5da3f4891..e57ab26bd 100644 --- a/tests/cron/test_cron_tool_list.py +++ b/tests/cron/test_cron_tool_list.py @@ -299,7 +299,7 @@ def test_add_cron_job_defaults_to_tool_timezone(tmp_path) -> None: tool = _make_tool_with_tz(tmp_path, "Asia/Shanghai") tool.set_context("telegram", "chat-1") - result = tool._add_job("Morning standup", None, "0 8 * * *", None, None) + result = tool._add_job(None, "Morning standup", None, "0 8 * * *", None, None) assert result.startswith("Created job") job = tool._cron.list_jobs()[0] @@ -310,7 +310,7 @@ def test_add_at_job_uses_default_timezone_for_naive_datetime(tmp_path) -> None: tool = _make_tool_with_tz(tmp_path, "Asia/Shanghai") tool.set_context("telegram", "chat-1") - result = tool._add_job("Morning reminder", None, None, None, "2026-03-25T08:00:00") + result = tool._add_job(None, "Morning reminder", None, None, None, "2026-03-25T08:00:00") assert result.startswith("Created job") job = tool._cron.list_jobs()[0] @@ -322,7 +322,7 @@ def test_add_job_delivers_by_default(tmp_path) -> None: tool = _make_tool(tmp_path) tool.set_context("telegram", "chat-1") - result = tool._add_job("Morning standup", 60, None, None, None) + result = tool._add_job(None, "Morning standup", 60, None, None, None) assert result.startswith("Created job") job = tool._cron.list_jobs()[0] @@ -333,7 +333,7 @@ def test_add_job_can_disable_delivery(tmp_path) -> None: tool = _make_tool(tmp_path) tool.set_context("telegram", "chat-1") - result = tool._add_job("Background refresh", 60, None, None, None, deliver=False) + result = tool._add_job(None, "Background refresh", 
60, None, None, None, deliver=False) assert result.startswith("Created job") job = tool._cron.list_jobs()[0] diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py index 1be505872..2e885e165 100644 --- a/tests/providers/test_litellm_kwargs.py +++ b/tests/providers/test_litellm_kwargs.py @@ -307,3 +307,54 @@ async def test_openai_compat_stream_watchdog_returns_error_on_stall(monkeypatch) assert result.finish_reason == "error" assert result.content is not None assert "stream stalled" in result.content + + +# --------------------------------------------------------------------------- +# Provider-specific thinking parameters (extra_body) +# --------------------------------------------------------------------------- + +def _build_kwargs_for(provider_name: str, model: str, reasoning_effort=None): + spec = find_by_name(provider_name) + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + p = OpenAICompatProvider(api_key="k", default_model=model, spec=spec) + return p._build_kwargs( + messages=[{"role": "user", "content": "hi"}], + tools=None, model=model, max_tokens=1024, temperature=0.7, + reasoning_effort=reasoning_effort, tool_choice=None, + ) + + +def test_dashscope_thinking_enabled_with_reasoning_effort() -> None: + kw = _build_kwargs_for("dashscope", "qwen3-plus", reasoning_effort="medium") + assert kw["extra_body"] == {"enable_thinking": True} + + +def test_dashscope_thinking_disabled_for_minimal() -> None: + kw = _build_kwargs_for("dashscope", "qwen3-plus", reasoning_effort="minimal") + assert kw["extra_body"] == {"enable_thinking": False} + + +def test_dashscope_no_extra_body_when_reasoning_effort_none() -> None: + kw = _build_kwargs_for("dashscope", "qwen-turbo", reasoning_effort=None) + assert "extra_body" not in kw + + +def test_volcengine_thinking_enabled() -> None: + kw = _build_kwargs_for("volcengine", "doubao-seed-2-0-pro", reasoning_effort="high") + assert kw["extra_body"] == {"thinking": {"type": 
"enabled"}} + + +def test_byteplus_thinking_disabled_for_minimal() -> None: + kw = _build_kwargs_for("byteplus", "doubao-seed-2-0-pro", reasoning_effort="minimal") + assert kw["extra_body"] == {"thinking": {"type": "disabled"}} + + +def test_byteplus_no_extra_body_when_reasoning_effort_none() -> None: + kw = _build_kwargs_for("byteplus", "doubao-seed-2-0-pro", reasoning_effort=None) + assert "extra_body" not in kw + + +def test_openai_no_thinking_extra_body() -> None: + """Non-thinking providers should never get extra_body for thinking.""" + kw = _build_kwargs_for("openai", "gpt-4o", reasoning_effort="medium") + assert "extra_body" not in kw diff --git a/tests/providers/test_provider_error_metadata.py b/tests/providers/test_provider_error_metadata.py new file mode 100644 index 000000000..ea2532acf --- /dev/null +++ b/tests/providers/test_provider_error_metadata.py @@ -0,0 +1,81 @@ +from types import SimpleNamespace + +from nanobot.providers.anthropic_provider import AnthropicProvider +from nanobot.providers.openai_compat_provider import OpenAICompatProvider + + +def _fake_response( + *, + status_code: int, + headers: dict[str, str] | None = None, + text: str = "", +) -> SimpleNamespace: + return SimpleNamespace( + status_code=status_code, + headers=headers or {}, + text=text, + ) + + +def test_openai_handle_error_extracts_structured_metadata() -> None: + class FakeStatusError(Exception): + pass + + err = FakeStatusError("boom") + err.status_code = 409 + err.response = _fake_response( + status_code=409, + headers={"retry-after-ms": "250", "x-should-retry": "false"}, + text='{"error":{"type":"rate_limit_exceeded","code":"rate_limit_exceeded"}}', + ) + err.body = {"error": {"type": "rate_limit_exceeded", "code": "rate_limit_exceeded"}} + + response = OpenAICompatProvider._handle_error(err) + + assert response.finish_reason == "error" + assert response.error_status_code == 409 + assert response.error_type == "rate_limit_exceeded" + assert response.error_code == 
"rate_limit_exceeded" + assert response.error_retry_after_s == 0.25 + assert response.error_should_retry is False + + +def test_openai_handle_error_marks_timeout_kind() -> None: + class FakeTimeoutError(Exception): + pass + + response = OpenAICompatProvider._handle_error(FakeTimeoutError("timeout")) + + assert response.finish_reason == "error" + assert response.error_kind == "timeout" + + +def test_anthropic_handle_error_extracts_structured_metadata() -> None: + class FakeStatusError(Exception): + pass + + err = FakeStatusError("boom") + err.status_code = 408 + err.response = _fake_response( + status_code=408, + headers={"retry-after": "1.5", "x-should-retry": "true"}, + ) + err.body = {"type": "error", "error": {"type": "rate_limit_error"}} + + response = AnthropicProvider._handle_error(err) + + assert response.finish_reason == "error" + assert response.error_status_code == 408 + assert response.error_type == "rate_limit_error" + assert response.error_retry_after_s == 1.5 + assert response.error_should_retry is True + + +def test_anthropic_handle_error_marks_connection_kind() -> None: + class FakeConnectionError(Exception): + pass + + response = AnthropicProvider._handle_error(FakeConnectionError("connection")) + + assert response.finish_reason == "error" + assert response.error_kind == "connection" diff --git a/tests/providers/test_provider_retry.py b/tests/providers/test_provider_retry.py index 61e58e22a..78c2a791e 100644 --- a/tests/providers/test_provider_retry.py +++ b/tests/providers/test_provider_retry.py @@ -254,6 +254,14 @@ def test_extract_retry_after_from_headers_supports_numeric_and_http_date() -> No ) == 0.1 +def test_extract_retry_after_from_headers_supports_retry_after_ms() -> None: + assert LLMProvider._extract_retry_after_from_headers({"retry-after-ms": "250"}) == 0.25 + assert LLMProvider._extract_retry_after_from_headers({"Retry-After-Ms": "1000"}) == 1.0 + assert LLMProvider._extract_retry_after_from_headers( + {"retry-after-ms": "500", 
"retry-after": "10"}, + ) == 0.5 + + @pytest.mark.asyncio async def test_chat_with_retry_prefers_structured_retry_after_when_present(monkeypatch) -> None: provider = ScriptedProvider([ @@ -273,6 +281,153 @@ async def test_chat_with_retry_prefers_structured_retry_after_when_present(monke assert delays == [9.0] +@pytest.mark.asyncio +async def test_chat_with_retry_retries_structured_status_code_without_keyword(monkeypatch) -> None: + provider = ScriptedProvider([ + LLMResponse( + content="request failed", + finish_reason="error", + error_status_code=409, + ), + LLMResponse(content="ok"), + ]) + delays: list[float] = [] + + async def _fake_sleep(delay: float) -> None: + delays.append(delay) + + monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep) + + response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}]) + + assert response.content == "ok" + assert provider.calls == 2 + assert delays == [1] + + +@pytest.mark.asyncio +async def test_chat_with_retry_stops_on_429_quota_exhausted(monkeypatch) -> None: + provider = ScriptedProvider([ + LLMResponse( + content='{"error":{"type":"insufficient_quota","code":"insufficient_quota"}}', + finish_reason="error", + error_status_code=429, + error_type="insufficient_quota", + error_code="insufficient_quota", + ), + LLMResponse(content="ok"), + ]) + delays: list[float] = [] + + async def _fake_sleep(delay: float) -> None: + delays.append(delay) + + monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep) + + response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}]) + + assert response.finish_reason == "error" + assert provider.calls == 1 + assert delays == [] + + +@pytest.mark.asyncio +async def test_chat_with_retry_retries_429_transient_rate_limit(monkeypatch) -> None: + provider = ScriptedProvider([ + LLMResponse( + content='{"error":{"type":"rate_limit_exceeded","code":"rate_limit_exceeded"}}', + finish_reason="error", + 
error_status_code=429, + error_type="rate_limit_exceeded", + error_code="rate_limit_exceeded", + error_retry_after_s=0.2, + ), + LLMResponse(content="ok"), + ]) + delays: list[float] = [] + + async def _fake_sleep(delay: float) -> None: + delays.append(delay) + + monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep) + + response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}]) + + assert response.content == "ok" + assert provider.calls == 2 + assert delays == [0.2] + + +@pytest.mark.asyncio +async def test_chat_with_retry_retries_structured_timeout_kind(monkeypatch) -> None: + provider = ScriptedProvider([ + LLMResponse( + content="request failed", + finish_reason="error", + error_kind="timeout", + ), + LLMResponse(content="ok"), + ]) + delays: list[float] = [] + + async def _fake_sleep(delay: float) -> None: + delays.append(delay) + + monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep) + + response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}]) + + assert response.content == "ok" + assert provider.calls == 2 + assert delays == [1] + + +@pytest.mark.asyncio +async def test_chat_with_retry_structured_should_retry_false_disables_retry(monkeypatch) -> None: + provider = ScriptedProvider([ + LLMResponse( + content="429 rate limit", + finish_reason="error", + error_should_retry=False, + ), + ]) + delays: list[float] = [] + + async def _fake_sleep(delay: float) -> None: + delays.append(delay) + + monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep) + + response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}]) + + assert response.finish_reason == "error" + assert provider.calls == 1 + assert delays == [] + + +@pytest.mark.asyncio +async def test_chat_with_retry_prefers_structured_retry_after(monkeypatch) -> None: + provider = ScriptedProvider([ + LLMResponse( + content="429 rate limit, retry after 99s", + 
finish_reason="error", + error_retry_after_s=0.2, + ), + LLMResponse(content="ok"), + ]) + delays: list[float] = [] + + async def _fake_sleep(delay: float) -> None: + delays.append(delay) + + monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep) + + response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}]) + + assert response.content == "ok" + assert delays == [0.2] + + @pytest.mark.asyncio async def test_persistent_retry_aborts_after_ten_identical_transient_errors(monkeypatch) -> None: provider = ScriptedProvider([ @@ -295,4 +450,3 @@ async def test_persistent_retry_aborts_after_ten_identical_transient_errors(monk assert response.content == "429 rate limit" assert provider.calls == 10 assert delays == [1, 2, 4, 4, 4, 4, 4, 4, 4] - diff --git a/tests/tools/__init__.py b/tests/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/tools/test_exec_env.py b/tests/tools/test_exec_env.py new file mode 100644 index 000000000..e5c0f48bb --- /dev/null +++ b/tests/tools/test_exec_env.py @@ -0,0 +1,38 @@ +"""Tests for exec tool environment isolation.""" + +import pytest + +from nanobot.agent.tools.shell import ExecTool + + +@pytest.mark.asyncio +async def test_exec_does_not_leak_parent_env(monkeypatch): + """Env vars from the parent process must not be visible to commands.""" + monkeypatch.setenv("NANOBOT_SECRET_TOKEN", "super-secret-value") + tool = ExecTool() + result = await tool.execute(command="printenv NANOBOT_SECRET_TOKEN") + assert "super-secret-value" not in result + + +@pytest.mark.asyncio +async def test_exec_has_working_path(): + """Basic commands should be available via the login shell's PATH.""" + tool = ExecTool() + result = await tool.execute(command="echo hello") + assert "hello" in result + + +@pytest.mark.asyncio +async def test_exec_path_append(): + """The pathAppend config should be available in the command's PATH.""" + tool = ExecTool(path_append="/opt/custom/bin") + 
result = await tool.execute(command="echo $PATH") + assert "/opt/custom/bin" in result + + +@pytest.mark.asyncio +async def test_exec_path_append_preserves_system_path(): + """pathAppend must not clobber standard system paths.""" + tool = ExecTool(path_append="/opt/custom/bin") + result = await tool.execute(command="ls /") + assert "Exit code: 0" in result diff --git a/tests/tools/test_message_tool_suppress.py b/tests/tools/test_message_tool_suppress.py index 1091de4c7..26d12085f 100644 --- a/tests/tools/test_message_tool_suppress.py +++ b/tests/tools/test_message_tool_suppress.py @@ -112,7 +112,7 @@ class TestMessageToolSuppressLogic: assert final_content == "Done" assert progress == [ ("Visible", False), - ('read_file("foo.txt")', True), + ('read foo.txt', True), ] diff --git a/tests/tools/test_sandbox.py b/tests/tools/test_sandbox.py new file mode 100644 index 000000000..82232d83e --- /dev/null +++ b/tests/tools/test_sandbox.py @@ -0,0 +1,121 @@ +"""Tests for nanobot.agent.tools.sandbox.""" + +import shlex + +import pytest + +from nanobot.agent.tools.sandbox import wrap_command + + +def _parse(cmd: str) -> list[str]: + """Split a wrapped command back into tokens for assertion.""" + return shlex.split(cmd) + + +class TestBwrapBackend: + def test_basic_structure(self, tmp_path): + ws = str(tmp_path / "project") + result = wrap_command("bwrap", "echo hi", ws, ws) + tokens = _parse(result) + + assert tokens[0] == "bwrap" + assert "--new-session" in tokens + assert "--die-with-parent" in tokens + assert "--ro-bind" in tokens + assert "--proc" in tokens + assert "--dev" in tokens + assert "--tmpfs" in tokens + + sep = tokens.index("--") + assert tokens[sep + 1:] == ["sh", "-c", "echo hi"] + + def test_workspace_bind_mounted_rw(self, tmp_path): + ws = str(tmp_path / "project") + result = wrap_command("bwrap", "ls", ws, ws) + tokens = _parse(result) + + bind_idx = [i for i, t in enumerate(tokens) if t == "--bind"] + assert any(tokens[i + 1] == ws and tokens[i + 2] == ws 
for i in bind_idx) + + def test_parent_dir_masked_with_tmpfs(self, tmp_path): + ws = tmp_path / "project" + result = wrap_command("bwrap", "ls", str(ws), str(ws)) + tokens = _parse(result) + + tmpfs_indices = [i for i, t in enumerate(tokens) if t == "--tmpfs"] + tmpfs_targets = {tokens[i + 1] for i in tmpfs_indices} + assert str(ws.parent) in tmpfs_targets + + def test_cwd_inside_workspace(self, tmp_path): + ws = tmp_path / "project" + sub = ws / "src" / "lib" + result = wrap_command("bwrap", "pwd", str(ws), str(sub)) + tokens = _parse(result) + + chdir_idx = tokens.index("--chdir") + assert tokens[chdir_idx + 1] == str(sub) + + def test_cwd_outside_workspace_falls_back(self, tmp_path): + ws = tmp_path / "project" + outside = tmp_path / "other" + result = wrap_command("bwrap", "pwd", str(ws), str(outside)) + tokens = _parse(result) + + chdir_idx = tokens.index("--chdir") + assert tokens[chdir_idx + 1] == str(ws.resolve()) + + def test_command_with_special_characters(self, tmp_path): + ws = str(tmp_path / "project") + cmd = "echo 'hello world' && cat \"file with spaces.txt\"" + result = wrap_command("bwrap", cmd, ws, ws) + tokens = _parse(result) + + sep = tokens.index("--") + assert tokens[sep + 1:] == ["sh", "-c", cmd] + + def test_system_dirs_ro_bound(self, tmp_path): + ws = str(tmp_path / "project") + result = wrap_command("bwrap", "ls", ws, ws) + tokens = _parse(result) + + ro_bind_indices = [i for i, t in enumerate(tokens) if t == "--ro-bind"] + ro_targets = {tokens[i + 1] for i in ro_bind_indices} + assert "/usr" in ro_targets + + def test_optional_dirs_use_ro_bind_try(self, tmp_path): + ws = str(tmp_path / "project") + result = wrap_command("bwrap", "ls", ws, ws) + tokens = _parse(result) + + try_indices = [i for i, t in enumerate(tokens) if t == "--ro-bind-try"] + try_targets = {tokens[i + 1] for i in try_indices} + assert "/bin" in try_targets + assert "/etc/ssl/certs" in try_targets + + def test_media_dir_ro_bind(self, tmp_path, monkeypatch): + """Media 
directory should be read-only mounted inside the sandbox.""" + fake_media = tmp_path / "media" + fake_media.mkdir() + monkeypatch.setattr( + "nanobot.agent.tools.sandbox.get_media_dir", + lambda: fake_media, + ) + ws = str(tmp_path / "project") + result = wrap_command("bwrap", "ls", ws, ws) + tokens = _parse(result) + + try_indices = [i for i, t in enumerate(tokens) if t == "--ro-bind-try"] + try_pairs = {(tokens[i + 1], tokens[i + 2]) for i in try_indices} + assert (str(fake_media), str(fake_media)) in try_pairs + + +class TestUnknownBackend: + def test_raises_value_error(self, tmp_path): + ws = str(tmp_path / "project") + with pytest.raises(ValueError, match="Unknown sandbox backend"): + wrap_command("nonexistent", "ls", ws, ws) + + def test_empty_string_raises(self, tmp_path): + ws = str(tmp_path / "project") + with pytest.raises(ValueError): + wrap_command("", "ls", ws, ws) diff --git a/tests/tools/test_tool_validation.py b/tests/tools/test_tool_validation.py index e56f93185..072623db8 100644 --- a/tests/tools/test_tool_validation.py +++ b/tests/tools/test_tool_validation.py @@ -1,3 +1,6 @@ +import shlex +import subprocess +import sys from typing import Any from nanobot.agent.tools import ( @@ -546,10 +549,15 @@ async def test_exec_head_tail_truncation() -> None: """Long output should preserve both head and tail.""" tool = ExecTool() # Generate output that exceeds _MAX_OUTPUT (10_000 chars) - # Use python to generate output to avoid command line length limits - result = await tool.execute( - command="python -c \"print('A' * 6000 + '\\n' + 'B' * 6000)\"" - ) + # Use current interpreter (PATH may not have `python`). ExecTool uses + # create_subprocess_shell: POSIX needs shlex.quote; Windows uses cmd.exe + # rules, so list2cmdline is appropriate there. 
+ script = "print('A' * 6000 + '\\n' + 'B' * 6000)" + if sys.platform == "win32": + command = subprocess.list2cmdline([sys.executable, "-c", script]) + else: + command = f"{shlex.quote(sys.executable)} -c {shlex.quote(script)}" + result = await tool.execute(command=command) assert "chars truncated" in result # Head portion should start with As assert result.startswith("A") diff --git a/tests/tools/test_web_search_tool.py b/tests/tools/test_web_search_tool.py index 02bf44395..e33dd7e6c 100644 --- a/tests/tools/test_web_search_tool.py +++ b/tests/tools/test_web_search_tool.py @@ -1,5 +1,7 @@ """Tests for multi-provider web search.""" +import asyncio + import httpx import pytest @@ -160,3 +162,70 @@ async def test_searxng_invalid_url(): tool = _tool(provider="searxng", base_url="not-a-url") result = await tool.execute(query="test") assert "Error" in result + + +@pytest.mark.asyncio +async def test_jina_422_falls_back_to_duckduckgo(monkeypatch): + class MockDDGS: + def __init__(self, **kw): + pass + + def text(self, query, max_results=5): + return [{"title": "Fallback", "href": "https://ddg.example", "body": "DuckDuckGo fallback"}] + + async def mock_get(self, url, **kw): + assert "s.jina.ai" in str(url) + raise httpx.HTTPStatusError( + "422 Unprocessable Entity", + request=httpx.Request("GET", str(url)), + response=httpx.Response(422, request=httpx.Request("GET", str(url))), + ) + + monkeypatch.setattr(httpx.AsyncClient, "get", mock_get) + monkeypatch.setattr("ddgs.DDGS", MockDDGS) + + tool = _tool(provider="jina", api_key="jina-key") + result = await tool.execute(query="test") + assert "DuckDuckGo fallback" in result + + +@pytest.mark.asyncio +async def test_jina_search_uses_path_encoded_query(monkeypatch): + calls = {} + + async def mock_get(self, url, **kw): + calls["url"] = str(url) + calls["params"] = kw.get("params") + return _response(json={ + "data": [{"title": "Jina Result", "url": "https://jina.ai", "content": "AI search"}] + }) + + 
monkeypatch.setattr(httpx.AsyncClient, "get", mock_get) + tool = _tool(provider="jina", api_key="jina-key") + await tool.execute(query="hello world") + assert calls["url"].rstrip("/") == "https://s.jina.ai/hello%20world" + assert calls["params"] in (None, {}) + + +@pytest.mark.asyncio +async def test_duckduckgo_timeout_returns_error(monkeypatch): + """asyncio.wait_for guard should fire when DDG search hangs.""" + import threading + gate = threading.Event() + + class HangingDDGS: + def __init__(self, **kw): + pass + + def text(self, query, max_results=5): + gate.wait(timeout=10) + return [] + + monkeypatch.setattr("ddgs.DDGS", HangingDDGS) + tool = _tool(provider="duckduckgo") + tool.config.timeout = 0.2 + result = await tool.execute(query="test") + gate.set() + assert "Error" in result + + diff --git a/tests/utils/test_abbreviate_path.py b/tests/utils/test_abbreviate_path.py new file mode 100644 index 000000000..573ca0a92 --- /dev/null +++ b/tests/utils/test_abbreviate_path.py @@ -0,0 +1,105 @@ +"""Tests for abbreviate_path utility.""" + +import os +from nanobot.utils.path import abbreviate_path + + +class TestAbbreviatePathShort: + def test_short_path_unchanged(self): + assert abbreviate_path("/home/user/file.py") == "/home/user/file.py" + + def test_exact_max_len_unchanged(self): + path = "/a/b/c" # 6 chars + assert abbreviate_path(path, max_len=6) == path + + def test_basename_only(self): + assert abbreviate_path("file.py") == "file.py" + + def test_empty_string(self): + assert abbreviate_path("") == "" + + +class TestAbbreviatePathHome: + def test_home_replacement(self): + home = os.path.expanduser("~") + result = abbreviate_path(f"{home}/project/file.py") + assert result.startswith("~/") + assert result.endswith("file.py") + + def test_home_preserves_short_path(self): + home = os.path.expanduser("~") + result = abbreviate_path(f"{home}/a.py") + assert result == "~/a.py" + + +class TestAbbreviatePathLong: + def test_long_path_keeps_basename(self):
path = "/a/b/c/d/e/f/g/h/very_long_filename.py" + result = abbreviate_path(path, max_len=30) + assert result.endswith("very_long_filename.py") + assert "\u2026" in result + + def test_long_path_keeps_parent_dir(self): + path = "/a/b/c/d/e/f/g/h/src/loop.py" + result = abbreviate_path(path, max_len=30) + assert "loop.py" in result + assert "src" in result + + def test_very_long_path_just_basename(self): + path = "/a/b/c/d/e/f/g/h/i/j/k/l/m/n/o/p/q/r/s/t/u/v/w/x/y/z/file.py" + result = abbreviate_path(path, max_len=20) + assert result.endswith("file.py") + assert len(result) <= 20 + + +class TestAbbreviatePathWindows: + def test_windows_drive_path(self): + path = "D:\\Documents\\GitHub\\nanobot\\src\\utils\\helpers.py" + result = abbreviate_path(path, max_len=40) + assert result.endswith("helpers.py") + assert "nanobot" in result + + def test_windows_home(self): + home = os.path.expanduser("~") + path = os.path.join(home, ".nanobot", "workspace", "log.txt") + result = abbreviate_path(path) + assert result.startswith("~/") + assert "log.txt" in result + + +class TestAbbreviatePathURLs: + def test_url_keeps_domain_and_filename(self): + url = "https://example.com/api/v2/long/path/resource.json" + result = abbreviate_path(url, max_len=40) + assert "resource.json" in result + assert "example.com" in result + + def test_short_url_unchanged(self): + url = "https://example.com/api" + assert abbreviate_path(url) == url + + def test_url_no_path_just_domain(self): + """G3: URL with no path should return as-is if short enough.""" + url = "https://example.com" + assert abbreviate_path(url) == url + + def test_url_with_query_string(self): + """G3: URL with query params should abbreviate path part.""" + url = "https://example.com/api/v2/endpoint?key=value&other=123" + result = abbreviate_path(url, max_len=40) + assert "example.com" in result + assert "\u2026" in result + + def test_url_very_long_basename(self): + """G3: URL with very long basename should truncate basename.""" + url 
= "https://example.com/path/very_long_resource_name_file.json" + result = abbreviate_path(url, max_len=35) + assert "example.com" in result + assert "\u2026" in result + + def test_url_negative_budget_consistent_format(self): + """I3: Negative budget should still produce domain/…/basename format.""" + url = "https://a.co/very/deep/path/with/lots/of/segments/and/a/long/basename.txt" + result = abbreviate_path(url, max_len=20) + assert "a.co" in result + assert "/\u2026/" in result diff --git a/tests/utils/test_searchusage.py b/tests/utils/test_searchusage.py new file mode 100644 index 000000000..205ccd917 --- /dev/null +++ b/tests/utils/test_searchusage.py @@ -0,0 +1,306 @@ +"""Tests for web search provider usage fetching and /status integration.""" + +from __future__ import annotations + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from nanobot.utils.searchusage import ( + SearchUsageInfo, + _parse_tavily_usage, + fetch_search_usage, +) +from nanobot.utils.helpers import build_status_content + + +# --------------------------------------------------------------------------- +# SearchUsageInfo.format() tests +# --------------------------------------------------------------------------- + +class TestSearchUsageInfoFormat: + def test_unsupported_provider_shows_no_tracking(self): + info = SearchUsageInfo(provider="duckduckgo", supported=False) + text = info.format() + assert "duckduckgo" in text + assert "not available" in text + + def test_supported_with_error(self): + info = SearchUsageInfo(provider="tavily", supported=True, error="HTTP 401") + text = info.format() + assert "tavily" in text + assert "HTTP 401" in text + assert "unavailable" in text + + def test_full_tavily_usage(self): + info = SearchUsageInfo( + provider="tavily", + supported=True, + used=142, + limit=1000, + remaining=858, + reset_date="2026-05-01", + search_used=120, + extract_used=15, + crawl_used=7, + ) + text = info.format() + assert "tavily" in text + assert "142 / 
1000" in text + assert "858" in text + assert "2026-05-01" in text + assert "Search: 120" in text + assert "Extract: 15" in text + assert "Crawl: 7" in text + + def test_usage_without_limit(self): + info = SearchUsageInfo(provider="tavily", supported=True, used=50) + text = info.format() + assert "50 requests" in text + assert "/" not in text.split("Usage:")[1].split("\n")[0] + + def test_no_breakdown_when_none(self): + info = SearchUsageInfo( + provider="tavily", supported=True, used=10, limit=100, remaining=90 + ) + text = info.format() + assert "Breakdown" not in text + + def test_brave_unsupported(self): + info = SearchUsageInfo(provider="brave", supported=False) + text = info.format() + assert "brave" in text + assert "not available" in text + + +# --------------------------------------------------------------------------- +# _parse_tavily_usage tests +# --------------------------------------------------------------------------- + +class TestParseTavilyUsage: + def test_full_response(self): + data = { + "account": { + "current_plan": "Researcher", + "plan_usage": 142, + "plan_limit": 1000, + "search_usage": 120, + "extract_usage": 15, + "crawl_usage": 7, + "map_usage": 0, + "research_usage": 0, + "paygo_usage": 0, + "paygo_limit": None, + }, + } + info = _parse_tavily_usage(data) + assert info.provider == "tavily" + assert info.supported is True + assert info.used == 142 + assert info.limit == 1000 + assert info.remaining == 858 + assert info.search_used == 120 + assert info.extract_used == 15 + assert info.crawl_used == 7 + + def test_remaining_computed(self): + data = {"account": {"plan_usage": 300, "plan_limit": 1000}} + info = _parse_tavily_usage(data) + assert info.remaining == 700 + + def test_remaining_not_negative(self): + data = {"account": {"plan_usage": 1100, "plan_limit": 1000}} + info = _parse_tavily_usage(data) + assert info.remaining == 0 + + def test_empty_response(self): + info = _parse_tavily_usage({}) + assert info.provider == "tavily" + 
assert info.supported is True + assert info.used is None + assert info.limit is None + + def test_no_breakdown_fields(self): + data = {"account": {"plan_usage": 5, "plan_limit": 50}} + info = _parse_tavily_usage(data) + assert info.search_used is None + assert info.extract_used is None + assert info.crawl_used is None + + +# --------------------------------------------------------------------------- +# fetch_search_usage routing tests +# --------------------------------------------------------------------------- + +class TestFetchSearchUsageRouting: + @pytest.mark.asyncio + async def test_duckduckgo_returns_unsupported(self): + info = await fetch_search_usage("duckduckgo") + assert info.provider == "duckduckgo" + assert info.supported is False + + @pytest.mark.asyncio + async def test_searxng_returns_unsupported(self): + info = await fetch_search_usage("searxng") + assert info.supported is False + + @pytest.mark.asyncio + async def test_jina_returns_unsupported(self): + info = await fetch_search_usage("jina") + assert info.supported is False + + @pytest.mark.asyncio + async def test_brave_returns_unsupported(self): + info = await fetch_search_usage("brave") + assert info.provider == "brave" + assert info.supported is False + + @pytest.mark.asyncio + async def test_unknown_provider_returns_unsupported(self): + info = await fetch_search_usage("some_unknown_provider") + assert info.supported is False + + @pytest.mark.asyncio + async def test_tavily_no_api_key_returns_error(self): + with patch.dict("os.environ", {}, clear=True): + # Ensure TAVILY_API_KEY is not set + import os + os.environ.pop("TAVILY_API_KEY", None) + info = await fetch_search_usage("tavily", api_key=None) + assert info.provider == "tavily" + assert info.supported is True + assert info.error is not None + assert "not configured" in info.error + + @pytest.mark.asyncio + async def test_tavily_success(self): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value 
= { + "account": { + "current_plan": "Researcher", + "plan_usage": 142, + "plan_limit": 1000, + "search_usage": 120, + "extract_usage": 15, + "crawl_usage": 7, + }, + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.get = AsyncMock(return_value=mock_response) + + with patch("httpx.AsyncClient", return_value=mock_client): + info = await fetch_search_usage("tavily", api_key="test-key") + + assert info.provider == "tavily" + assert info.supported is True + assert info.error is None + assert info.used == 142 + assert info.limit == 1000 + assert info.remaining == 858 + assert info.search_used == 120 + + @pytest.mark.asyncio + async def test_tavily_http_error(self): + import httpx + + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "401", request=MagicMock(), response=mock_response + ) + + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.get = AsyncMock(return_value=mock_response) + + with patch("httpx.AsyncClient", return_value=mock_client): + info = await fetch_search_usage("tavily", api_key="bad-key") + + assert info.supported is True + assert info.error == "HTTP 401" + + @pytest.mark.asyncio + async def test_tavily_network_error(self): + import httpx + + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client.get = AsyncMock(side_effect=httpx.ConnectError("timeout")) + + with patch("httpx.AsyncClient", return_value=mock_client): + info = await fetch_search_usage("tavily", api_key="test-key") + + assert info.supported is True + assert info.error is not None + + @pytest.mark.asyncio + async def 
test_provider_name_case_insensitive(self): + info = await fetch_search_usage("Tavily", api_key=None) + assert info.provider == "tavily" + assert info.supported is True + + +# --------------------------------------------------------------------------- +# build_status_content integration tests +# --------------------------------------------------------------------------- + +class TestBuildStatusContentWithSearchUsage: + _BASE_KWARGS = dict( + version="0.1.0", + model="claude-opus-4-5", + start_time=1_000_000.0, + last_usage={"prompt_tokens": 1000, "completion_tokens": 200}, + context_window_tokens=65536, + session_msg_count=5, + context_tokens_estimate=3000, + ) + + def test_no_search_usage_unchanged(self): + """Omitting search_usage_text keeps existing behaviour.""" + content = build_status_content(**self._BASE_KWARGS) + assert "πŸ”" not in content + assert "Web Search" not in content + + def test_search_usage_none_unchanged(self): + content = build_status_content(**self._BASE_KWARGS, search_usage_text=None) + assert "πŸ”" not in content + + def test_search_usage_appended(self): + usage_text = "πŸ” Web Search: tavily\n Usage: 142 / 1000 requests" + content = build_status_content(**self._BASE_KWARGS, search_usage_text=usage_text) + assert "πŸ” Web Search: tavily" in content + assert "142 / 1000" in content + + def test_existing_fields_still_present(self): + usage_text = "πŸ” Web Search: duckduckgo\n Usage tracking: not available" + content = build_status_content(**self._BASE_KWARGS, search_usage_text=usage_text) + # Original fields must still be present + assert "nanobot v0.1.0" in content + assert "claude-opus-4-5" in content + assert "1000 in / 200 out" in content + # New field appended + assert "duckduckgo" in content + + def test_full_tavily_in_status(self): + info = SearchUsageInfo( + provider="tavily", + supported=True, + used=142, + limit=1000, + remaining=858, + reset_date="2026-05-01", + search_used=120, + extract_used=15, + crawl_used=7, + ) + content 
= build_status_content(**self._BASE_KWARGS, search_usage_text=info.format()) + assert "142 / 1000" in content + assert "858" in content + assert "2026-05-01" in content + assert "Search: 120" in content