mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-25 02:52:37 +00:00
Merge branch 'main' into nightly
This commit is contained in:
commit
80bfcf4473
1
.gitignore
vendored
1
.gitignore
vendored
@ -4,6 +4,7 @@
|
||||
.docs
|
||||
.env
|
||||
.web
|
||||
.orion
|
||||
|
||||
# Python bytecode & caches
|
||||
*.pyc
|
||||
|
||||
196
README.md
196
README.md
@ -21,7 +21,20 @@
|
||||
|
||||
## 📢 News
|
||||
|
||||
- **2026-04-14** 🚀 Released **v0.1.5.post1** — Dream skill discovery, mid-turn follow-up injection, WebSocket channel, and deeper channel integrations. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.5.post1) for details.
|
||||
- **2026-04-13** 🛡️ Agent turn hardened — user messages persisted early, auto-compact skips active tasks.
|
||||
- **2026-04-12** 🔒 Lark global domain support, Dream learns discovered skills, shell sandbox tightened.
|
||||
- **2026-04-11** ⚡ Context compact shrinks sessions on the fly; Kagi web search; QQ & WeCom full media.
|
||||
- **2026-04-10** 📓 Notebook editing tool, multiple MCP servers, Feishu streaming & done-emoji.
|
||||
- **2026-04-09** 🔌 WebSocket channel, unified cross-channel session, `disabled_skills` config.
|
||||
- **2026-04-08** 📤 API file uploads, OpenAI reasoning auto-routing with Responses fallback.
|
||||
- **2026-04-07** 🧠 Anthropic adaptive thinking, MCP resources & prompts exposed as tools.
|
||||
- **2026-04-06** 🛰️ Langfuse observability, unified Whisper transcription, email attachments.
|
||||
- **2026-04-05** 🚀 Released **v0.1.5** — sturdier long-running tasks, Dream two-stage memory, production-ready sandboxing and programming Agent SDK. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.5) for details.
|
||||
|
||||
<details>
|
||||
<summary>Earlier news</summary>
|
||||
|
||||
- **2026-04-04** 🚀 Jinja2 response templates, Dream memory hardened, smarter retry handling.
|
||||
- **2026-04-03** 🧠 Xiaomi MiMo provider, chain-of-thought reasoning visible, Telegram UX polish.
|
||||
- **2026-04-02** 🧱 Long-running tasks run more reliably — core runtime hardening.
|
||||
@ -31,11 +44,6 @@
|
||||
- **2026-03-29** 💬 WeChat voice, typing, QR/media resilience; fixed-session OpenAI-compatible API.
|
||||
- **2026-03-28** 📚 Provider docs refresh; skill template wording fix.
|
||||
- **2026-03-27** 🚀 Released **v0.1.4.post6** — architecture decoupling, litellm removal, end-to-end streaming, WeChat channel, and a security fix. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post6) for details.
|
||||
|
||||
|
||||
<details>
|
||||
<summary>Earlier news</summary>
|
||||
|
||||
- **2026-03-26** 🏗️ Agent runner extracted and lifecycle hooks unified; stream delta coalescing at boundaries.
|
||||
- **2026-03-25** 🌏 StepFun provider, configurable timezone, Gemini thought signatures.
|
||||
- **2026-03-24** 🔧 WeChat compatibility, Feishu CardKit streaming, test suite restructured.
|
||||
@ -112,26 +120,57 @@
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [News](#-news)
|
||||
- [Key Features](#key-features-of-nanobot)
|
||||
- [Architecture](#️-architecture)
|
||||
- [Features](#-features)
|
||||
- [Install](#-install)
|
||||
- [Quick Start](#-quick-start)
|
||||
- [Chat Apps](#-chat-apps)
|
||||
- [Agent Social Network](#-agent-social-network)
|
||||
- [Configuration](#️-configuration)
|
||||
- [Multiple Instances](#-multiple-instances)
|
||||
- [Memory](#-memory)
|
||||
- [CLI Reference](#-cli-reference)
|
||||
- [In-Chat Commands](#-in-chat-commands)
|
||||
- [Python SDK](#-python-sdk)
|
||||
- [OpenAI-Compatible API](#-openai-compatible-api)
|
||||
- [Docker](#-docker)
|
||||
- [Linux Service](#-linux-service)
|
||||
- [Project Structure](#-project-structure)
|
||||
- [Contribute & Roadmap](#-contribute--roadmap)
|
||||
- [Star History](#-star-history)
|
||||
- [📢 News](#-news)
|
||||
- [Key Features of nanobot:](#key-features-of-nanobot)
|
||||
- [🏗️ Architecture](#️-architecture)
|
||||
- [Table of Contents](#table-of-contents)
|
||||
- [✨ Features](#-features)
|
||||
- [📦 Install](#-install)
|
||||
- [Update to latest version](#update-to-latest-version)
|
||||
- [🚀 Quick Start](#-quick-start)
|
||||
- [💬 Chat Apps](#-chat-apps)
|
||||
- [🌐 Agent Social Network](#-agent-social-network)
|
||||
- [⚙️ Configuration](#️-configuration)
|
||||
- [Environment Variables for Secrets](#environment-variables-for-secrets)
|
||||
- [Providers](#providers)
|
||||
- [Channel Settings](#channel-settings)
|
||||
- [Retry Behavior](#retry-behavior)
|
||||
- [Web Search](#web-search)
|
||||
- [`tools.web.search`](#toolswebsearch)
|
||||
- [MCP (Model Context Protocol)](#mcp-model-context-protocol)
|
||||
- [Security](#security)
|
||||
- [Auto Compact](#auto-compact)
|
||||
- [Timezone](#timezone)
|
||||
- [Unified Session](#unified-session)
|
||||
- [Disabled Skills](#disabled-skills)
|
||||
- [🧩 Multiple Instances](#-multiple-instances)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Path Resolution](#path-resolution)
|
||||
- [How It Works](#how-it-works)
|
||||
- [Minimal Setup](#minimal-setup)
|
||||
- [Common Use Cases](#common-use-cases)
|
||||
- [Notes](#notes)
|
||||
- [🧠 Memory](#-memory)
|
||||
- [💻 CLI Reference](#-cli-reference)
|
||||
- [💬 In-Chat Commands](#-in-chat-commands)
|
||||
- [🐍 Python SDK](#-python-sdk)
|
||||
- [🔌 OpenAI-Compatible API](#-openai-compatible-api)
|
||||
- [Behavior](#behavior)
|
||||
- [Endpoints](#endpoints)
|
||||
- [curl](#curl)
|
||||
- [File Upload (JSON base64)](#file-upload-json-base64)
|
||||
- [File Upload (multipart/form-data)](#file-upload-multipartform-data)
|
||||
- [Python (`requests`)](#python-requests)
|
||||
- [Python (`openai`)](#python-openai)
|
||||
- [🐳 Docker](#-docker)
|
||||
- [Docker Compose](#docker-compose)
|
||||
- [Docker](#docker)
|
||||
- [🐧 Linux Service](#-linux-service)
|
||||
- [📁 Project Structure](#-project-structure)
|
||||
- [🤝 Contribute \& Roadmap](#-contribute--roadmap)
|
||||
- [Branching Strategy](#branching-strategy)
|
||||
- [Contributors](#contributors)
|
||||
- [⭐ Star History](#-star-history)
|
||||
|
||||
## ✨ Features
|
||||
|
||||
@ -395,6 +434,7 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso
|
||||
"enabled": true,
|
||||
"token": "YOUR_BOT_TOKEN",
|
||||
"allowFrom": ["YOUR_USER_ID"],
|
||||
"allowChannels": [],
|
||||
"groupPolicy": "mention",
|
||||
"streaming": true
|
||||
}
|
||||
@ -407,6 +447,7 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso
|
||||
> - `"open"` — Respond to all messages
|
||||
> DMs always respond when the sender is in `allowFrom`.
|
||||
> - If you set group policy to open create new threads as private threads and then @ the bot into it. Otherwise the thread itself and the channel in which you spawned it will spawn a bot session.
|
||||
> `allowChannels` restricts the bot to specific Discord channel IDs. Empty (default) means respond in every channel the bot can see. Example: `["1234567890", "0987654321"]`. The filter applies after `allowFrom`, so both must pass.
|
||||
> `streaming` defaults to `true`. Disable it only if you explicitly want non-streaming replies.
|
||||
|
||||
**5. Invite the bot**
|
||||
@ -902,7 +943,7 @@ Create or reuse a Microsoft Teams / Azure bot app registration. Set the bot mess
|
||||
|
||||
> - `replyInThread: true` replies to the triggering Teams activity when a stored `activity_id` is available.
|
||||
> - `mentionOnlyResponse` controls what Nanobot receives when a user sends only a bot mention (`<at>Nanobot</at>`). Set to `""` to ignore mention-only messages.
|
||||
> - `validateInboundAuth: true` (recommended for production) enables inbound Bot Framework bearer-token validation (signature, issuer, audience, lifetime, `serviceUrl`). **Default is `false`** — set explicitly to `true` for production deployments.
|
||||
> - `validateInboundAuth: true` enables inbound Bot Framework bearer-token validation (signature, issuer, audience, lifetime, `serviceUrl`). This is the safe default for public deployments. Only set it to `false` for local development or tightly controlled testing.
|
||||
|
||||
**4. Run**
|
||||
|
||||
@ -973,6 +1014,7 @@ IMAP_PASSWORD=your-password-here
|
||||
> - **Voice transcription**: Voice messages (Telegram, WhatsApp) are automatically transcribed using Whisper. By default Groq is used (free tier). Set `"transcriptionProvider": "openai"` under `channels` to use OpenAI Whisper instead — the API key is picked from the matching provider config.
|
||||
> - **MiniMax Coding Plan**: Exclusive discount links for the nanobot community: [Overseas](https://platform.minimax.io/subscribe/coding-plan?code=9txpdXw04g&source=link) · [Mainland China](https://platform.minimaxi.com/subscribe/token-plan?code=GILTJpMTqZ&source=link)
|
||||
> - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config.
|
||||
> - **MiniMax thinking mode**: Use `providers.minimaxAnthropic` when you want `reasoningEffort` / thinking mode. MiniMax exposes that capability through its Anthropic-compatible endpoint, so nanobot keeps it as a separate provider instead of guessing MiniMax-specific thinking parameters on the generic OpenAI-compatible `minimax` endpoint. It uses the same `MINIMAX_API_KEY`. Default Anthropic-compatible base URL: `https://api.minimax.io/anthropic`; for mainland China use `https://api.minimaxi.com/anthropic`.
|
||||
> - **VolcEngine / BytePlus Coding Plan**: Use dedicated providers `volcengineCodingPlan` or `byteplusCodingPlan` instead of the pay-per-use `volcengine` / `byteplus` providers.
|
||||
> - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config.
|
||||
> - **Alibaba Cloud BaiLian**: If you're using Alibaba Cloud BaiLian's OpenAI-compatible endpoint, set `"apiBase": "https://dashscope.aliyuncs.com/compatible-mode/v1"` in your dashscope provider config.
|
||||
@ -990,6 +1032,7 @@ IMAP_PASSWORD=your-password-here
|
||||
| `deepseek` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
|
||||
| `groq` | LLM + Voice transcription (Whisper, default) | [console.groq.com](https://console.groq.com) |
|
||||
| `minimax` | LLM (MiniMax direct) | [platform.minimaxi.com](https://platform.minimaxi.com) |
|
||||
| `minimax_anthropic` | LLM (MiniMax Anthropic-compatible endpoint, thinking mode) | [platform.minimaxi.com](https://platform.minimaxi.com) |
|
||||
| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) |
|
||||
| `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) |
|
||||
| `siliconflow` | LLM (SiliconFlow/硅基流动) | [siliconflow.cn](https://siliconflow.cn) |
|
||||
@ -998,6 +1041,7 @@ IMAP_PASSWORD=your-password-here
|
||||
| `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |
|
||||
| `mimo` | LLM (MiMo) | [platform.xiaomimimo.com](https://platform.xiaomimimo.com) |
|
||||
| `ollama` | LLM (local, Ollama) | — |
|
||||
| `lm_studio` | LLM (local, LM Studio) | — |
|
||||
| `mistral` | LLM | [docs.mistral.ai](https://docs.mistral.ai/) |
|
||||
| `stepfun` | LLM (Step Fun/阶跃星辰) | [platform.stepfun.com](https://platform.stepfun.com) |
|
||||
| `ovms` | LLM (local, OpenVINO Model Server) | [docs.openvino.ai](https://docs.openvino.ai/2026/model-server/ovms_docs_llm_quickstart.html) |
|
||||
@ -1085,7 +1129,7 @@ nanobot agent -c ~/.nanobot-telegram/config.json -w /tmp/nanobot-telegram-test -
|
||||
<details>
|
||||
<summary><b>Custom Provider (Any OpenAI-compatible API)</b></summary>
|
||||
|
||||
Connects directly to any OpenAI-compatible endpoint — LM Studio, llama.cpp, Together AI, Fireworks, Azure OpenAI, or any self-hosted server. Model name is passed as-is.
|
||||
Connects directly to any OpenAI-compatible endpoint — llama.cpp, Together AI, Fireworks, Azure OpenAI, or any self-hosted server. Model name is passed as-is.
|
||||
|
||||
```json
|
||||
{
|
||||
@ -1103,7 +1147,7 @@ Connects directly to any OpenAI-compatible endpoint — LM Studio, llama.cpp, To
|
||||
}
|
||||
```
|
||||
|
||||
> For local servers that don't require a key, set `apiKey` to any non-empty string (e.g. `"no-key"`).
|
||||
> For local servers that don't require authentication, set `apiKey` to `null`.
|
||||
>
|
||||
> `custom` is the right choice for providers that expose an OpenAI-compatible **chat completions** API. It does **not** force third-party endpoints onto the OpenAI/Azure **Responses API**.
|
||||
>
|
||||
@ -1162,6 +1206,40 @@ ollama run llama3.2
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>LM Studio (local)</b></summary>
|
||||
|
||||
[LM Studio](https://lmstudio.ai/) provides a local OpenAI-compatible server for running LLMs. Download models through the LM Studio UI, then start the local server.
|
||||
|
||||
**1. Start LM Studio server:**
|
||||
- Launch LM Studio
|
||||
- Go to the "Local Server" tab
|
||||
- Load a model (e.g., Llama, Mistral, Qwen)
|
||||
- Click "Start Server" (default port: 1234)
|
||||
|
||||
**2. Add to config** (partial — merge into `~/.nanobot/config.json`):
|
||||
```json
|
||||
{
|
||||
"providers": {
|
||||
"lm_studio": {
|
||||
"apiKey": null,
|
||||
"apiBase": "http://localhost:1234/v1"
|
||||
}
|
||||
},
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"provider": "lm_studio",
|
||||
"model": "local-model"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> **Note:** Set `apiKey` to `null` for LM Studio since it runs locally and doesn't require authentication. The model name should match what's shown in the LM Studio UI.
|
||||
> `provider: "auto"` also works when `providers.lm_studio.apiBase` is configured, but setting `"provider": "lm_studio"` is the clearest option.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>OpenVINO Model Server (local / OpenAI-compatible)</b></summary>
|
||||
|
||||
@ -1249,12 +1327,12 @@ vllm serve meta-llama/Llama-3.1-8B-Instruct --port 8000
|
||||
|
||||
**2. Add to config** (partial — merge into `~/.nanobot/config.json`):
|
||||
|
||||
*Provider (key can be any non-empty string for local):*
|
||||
*Provider (set API key to null for local servers):*
|
||||
```json
|
||||
{
|
||||
"providers": {
|
||||
"vllm": {
|
||||
"apiKey": "dummy",
|
||||
"apiKey": null,
|
||||
"apiBase": "http://localhost:8000/v1"
|
||||
}
|
||||
}
|
||||
@ -1621,8 +1699,12 @@ How it works:
|
||||
3. **Summary injection**: When the user returns, the summary is injected as runtime context (one-shot, not persisted) alongside the retained recent suffix.
|
||||
4. **Restart-safe resume**: The summary is also mirrored into session metadata so it can still be recovered after a process restart.
|
||||
|
||||
> [!TIP]
|
||||
> Think of auto compact as "summarize older context, keep the freshest live turns." It is not a hard session reset.
|
||||
> [!NOTE]
|
||||
> Mental model: "summarize older context, keep the freshest live turns, **and overwrite the session file with the compact form.**" It is not a full `session.clear()`, but it is a write — not a soft cursor move.
|
||||
>
|
||||
> Concretely, auto compact rewrites `sessions/<key>.jsonl` in place: older messages (including their structured `tool_calls` / `tool_call_id` / `reasoning_content`) are replaced by just the retained recent suffix (currently 8 messages), while the archived prefix is preserved only as a plain-text summary appended to `memory/history.jsonl` (or a `[RAW] ...` flattened dump if LLM summarization fails). The original structured JSON of those turns is no longer recoverable from the session file.
|
||||
>
|
||||
> This differs from the **token-driven soft consolidation** that fires when a prompt exceeds the context budget: that path only advances an internal `last_consolidated` cursor and leaves the session file untouched, so the raw tool-call trail stays on disk and can still be replayed or audited. If you rely on that trail for debugging or auditing, leave `idleCompactAfterMinutes` at the default `0` and let only the token-driven path run.
|
||||
|
||||
### Timezone
|
||||
|
||||
@ -1778,6 +1860,7 @@ Example config:
|
||||
}
|
||||
},
|
||||
"gateway": {
|
||||
"host": "127.0.0.1",
|
||||
"port": 18790
|
||||
}
|
||||
}
|
||||
@ -1790,6 +1873,14 @@ nanobot gateway --config ~/.nanobot-telegram/config.json
|
||||
nanobot gateway --config ~/.nanobot-discord/config.json
|
||||
```
|
||||
|
||||
Each gateway instance also exposes a lightweight HTTP health endpoint on
|
||||
`gateway.host:gateway.port`. By default, the gateway binds to `127.0.0.1`,
|
||||
so the endpoint stays local unless you explicitly set `gateway.host` to a
|
||||
public or LAN-facing address.
|
||||
|
||||
- `GET /health` returns `{"status":"ok"}`
|
||||
- Other paths return `404`
|
||||
|
||||
Override workspace for one-off runs when needed:
|
||||
|
||||
```bash
|
||||
@ -1932,7 +2023,8 @@ By default, the API binds to `127.0.0.1:8900`. You can change this in `config.js
|
||||
- Session isolation: pass `"session_id"` in the request body to isolate conversations; omit for a shared default session (`api:default`)
|
||||
- Single-message input: each request must contain exactly one `user` message
|
||||
- Fixed model: omit `model`, or pass the same model shown by `/v1/models`
|
||||
- No streaming: `stream=true` is not supported
|
||||
- Streaming: set `stream=true` to receive Server-Sent Events (`text/event-stream`) with OpenAI-compatible delta chunks, terminated by `data: [DONE]`; omit or set `stream=false` for a single JSON response
|
||||
- **File uploads**: supports images, PDF, Word (.docx), Excel (.xlsx), PowerPoint (.pptx) via JSON base64 or `multipart/form-data` (max 10MB per file)
|
||||
- API requests run in the synthetic `api` channel, so the `message` tool does **not** automatically deliver to Telegram/Discord/etc. To proactively send to another chat, call `message` with an explicit `channel` and `chat_id` for an enabled channel.
|
||||
|
||||
Example tool call for cross-channel delivery from an API session:
|
||||
@ -1964,6 +2056,44 @@ curl http://127.0.0.1:8900/v1/chat/completions \
|
||||
}'
|
||||
```
|
||||
|
||||
### File Upload (JSON base64)
|
||||
|
||||
Send images inline using the OpenAI multimodal content format:
|
||||
|
||||
```bash
|
||||
curl http://127.0.0.1:8900/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"messages": [{"role": "user", "content": [
|
||||
{"type": "text", "text": "Describe this image"},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}}
|
||||
]}]
|
||||
}'
|
||||
```
|
||||
|
||||
### File Upload (multipart/form-data)
|
||||
|
||||
Upload any supported file type (images, PDF, Word, Excel, PPT) via multipart:
|
||||
|
||||
```bash
|
||||
# Single file
|
||||
curl http://127.0.0.1:8900/v1/chat/completions \
|
||||
-F "message=Summarize this report" \
|
||||
-F "files=@report.docx"
|
||||
|
||||
# Multiple files with session isolation
|
||||
curl http://127.0.0.1:8900/v1/chat/completions \
|
||||
-F "message=Compare these files" \
|
||||
-F "files=@chart.png" \
|
||||
-F "files=@data.xlsx" \
|
||||
-F "session_id=my-session"
|
||||
```
|
||||
|
||||
Supported file types:
|
||||
- **Images**: PNG, JPEG, GIF, WebP (sent to AI as base64 for vision analysis)
|
||||
- **Documents**: PDF, Word (.docx), Excel (.xlsx), PowerPoint (.pptx) (text extracted and sent to AI)
|
||||
- **Text**: TXT, Markdown, CSV, JSON, etc. (read directly)
|
||||
|
||||
### Python (`requests`)
|
||||
|
||||
```python
|
||||
|
||||
@ -18,11 +18,15 @@ Enabled by default (read-only mode). The agent can check its state but not set i
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
my_enabled: true # default: true
|
||||
my_set: false # default: false (read-only)
|
||||
my:
|
||||
enable: true # default: true
|
||||
allow_set: false # default: false (read-only)
|
||||
```
|
||||
|
||||
To allow the agent to set its configuration (e.g. switch models, adjust parameters), set `my_set: true`.
|
||||
To allow the agent to set its configuration (e.g. switch models, adjust parameters), set `tools.my.allow_set: true`.
|
||||
|
||||
Legacy `tools.myEnabled` / `tools.mySet` keys are auto-migrated on load, and
|
||||
rewritten in-place the next time `nanobot onboard` refreshes the config.
|
||||
|
||||
All modifications are held in memory only — restart restores defaults.
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ def _resolve_version() -> str:
|
||||
return _pkg_version("nanobot-ai")
|
||||
except PackageNotFoundError:
|
||||
# Source checkouts often import nanobot without installed dist-info.
|
||||
return _read_pyproject_version() or "0.1.5"
|
||||
return _read_pyproject_version() or "0.1.5.post1"
|
||||
|
||||
|
||||
__version__ = _resolve_version()
|
||||
|
||||
@ -3,15 +3,14 @@
|
||||
import base64
|
||||
import mimetypes
|
||||
import platform
|
||||
from importlib.resources import files as pkg_files
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.utils.helpers import current_time_str
|
||||
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.utils.prompt_templates import render_template
|
||||
from nanobot.agent.skills import SkillsLoader
|
||||
from nanobot.utils.helpers import build_assistant_message, detect_image_mime
|
||||
from nanobot.utils.helpers import build_assistant_message, current_time_str, detect_image_mime
|
||||
from nanobot.utils.prompt_templates import render_template
|
||||
|
||||
|
||||
class ContextBuilder:
|
||||
@ -41,7 +40,7 @@ class ContextBuilder:
|
||||
parts.append(bootstrap)
|
||||
|
||||
memory = self.memory.get_memory_context()
|
||||
if memory:
|
||||
if memory and not self._is_template_content(self.memory.read_memory(), "memory/MEMORY.md"):
|
||||
parts.append(f"# Memory\n\n{memory}")
|
||||
|
||||
always_skills = self.skills.get_always_skills()
|
||||
@ -50,7 +49,7 @@ class ContextBuilder:
|
||||
if always_content:
|
||||
parts.append(f"# Active Skills\n\n{always_content}")
|
||||
|
||||
skills_summary = self.skills.build_skills_summary()
|
||||
skills_summary = self.skills.build_skills_summary(exclude=set(always_skills))
|
||||
if skills_summary:
|
||||
parts.append(render_template("agent/skills_section.md", skills_summary=skills_summary))
|
||||
|
||||
@ -116,6 +115,17 @@ class ContextBuilder:
|
||||
|
||||
return "\n\n".join(parts) if parts else ""
|
||||
|
||||
@staticmethod
|
||||
def _is_template_content(content: str, template_path: str) -> bool:
|
||||
"""Check if *content* is identical to the bundled template (user hasn't customized it)."""
|
||||
try:
|
||||
tpl = pkg_files("nanobot") / "templates" / template_path
|
||||
if tpl.is_file():
|
||||
return content.strip() == tpl.read_text(encoding="utf-8").strip()
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
def build_messages(
|
||||
self,
|
||||
history: list[dict[str, Any]],
|
||||
@ -160,7 +170,6 @@ class ContextBuilder:
|
||||
if not p.is_file():
|
||||
continue
|
||||
raw = p.read_bytes()
|
||||
# Detect real MIME type from magic bytes; fallback to filename guess
|
||||
mime = detect_image_mime(raw) or mimetypes.guess_type(path)[0]
|
||||
if not mime or not mime.startswith("image/"):
|
||||
continue
|
||||
|
||||
@ -17,10 +17,10 @@ from nanobot.agent.autocompact import AutoCompact
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
||||
from nanobot.agent.memory import Consolidator, Dream
|
||||
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunSpec, AgentRunner
|
||||
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec
|
||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.agent.tools.notebook import NotebookEditTool
|
||||
@ -31,12 +31,14 @@ from nanobot.agent.tools.self import MyTool
|
||||
from nanobot.agent.tools.spawn import SpawnTool
|
||||
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
from nanobot.utils.helpers import image_placeholder_text, truncate_text as truncate_text_fn
|
||||
from nanobot.utils.document import extract_documents
|
||||
from nanobot.utils.helpers import image_placeholder_text
|
||||
from nanobot.utils.helpers import truncate_text as truncate_text_fn
|
||||
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -230,7 +232,7 @@ class AgentLoop:
|
||||
provider=provider,
|
||||
model=self.model,
|
||||
sessions=self.sessions,
|
||||
context_window_tokens=context_window_tokens,
|
||||
context_window_tokens=self.context_window_tokens,
|
||||
build_messages=self.context.build_messages,
|
||||
get_tool_definitions=self.tools.get_definitions,
|
||||
max_completion_tokens=provider.generation.max_tokens,
|
||||
@ -246,8 +248,8 @@ class AgentLoop:
|
||||
model=self.model,
|
||||
)
|
||||
self._register_default_tools()
|
||||
if _tc.my_enabled:
|
||||
self.tools.register(MyTool(loop=self, modify_allowed=_tc.my_set))
|
||||
if _tc.my.enable:
|
||||
self.tools.register(MyTool(loop=self, modify_allowed=_tc.my.allow_set))
|
||||
self._runtime_vars: dict[str, Any] = {}
|
||||
self._current_iteration: int = 0
|
||||
self.commands = CommandRouter()
|
||||
@ -393,10 +395,12 @@ class AgentLoop:
|
||||
pending_msg = pending_queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
user_content = self.context._build_user_content(
|
||||
pending_msg.content,
|
||||
pending_msg.media if pending_msg.media else None,
|
||||
)
|
||||
content = pending_msg.content
|
||||
media = pending_msg.media if pending_msg.media else None
|
||||
if media:
|
||||
content, media = extract_documents(content, media)
|
||||
media = media or None
|
||||
user_content = self.context._build_user_content(content, media)
|
||||
runtime_ctx = self.context._build_runtime_context(
|
||||
pending_msg.channel,
|
||||
pending_msg.chat_id,
|
||||
@ -667,6 +671,12 @@ class AgentLoop:
|
||||
content=final_content or "Background task completed.",
|
||||
)
|
||||
|
||||
# Extract document text from media at the processing boundary so all
|
||||
# channels benefit without format-specific logic in ContextBuilder.
|
||||
if msg.media:
|
||||
new_content, image_only = extract_documents(msg.content, msg.media)
|
||||
msg = dataclasses.replace(msg, content=new_content, media=image_only)
|
||||
|
||||
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
|
||||
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||
|
||||
@ -967,13 +977,17 @@ class AgentLoop:
|
||||
session_key: str = "cli:direct",
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
media: list[str] | None = None,
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||
) -> OutboundMessage | None:
|
||||
"""Process a message directly and return the outbound payload."""
|
||||
await self._connect_mcp()
|
||||
msg = InboundMessage(channel=channel, sender_id="user", chat_id=chat_id, content=content)
|
||||
msg = InboundMessage(
|
||||
channel=channel, sender_id="user", chat_id=chat_id,
|
||||
content=content, media=media or [],
|
||||
)
|
||||
return await self._process_message(
|
||||
msg,
|
||||
session_key=session_key,
|
||||
|
||||
@ -239,13 +239,13 @@ class MemoryStore:
|
||||
pass
|
||||
# Fallback: read last line's cursor from the JSONL file.
|
||||
last = self._read_last_entry()
|
||||
if last:
|
||||
if last and last.get("cursor"):
|
||||
return last["cursor"] + 1
|
||||
return 1
|
||||
|
||||
def read_unprocessed_history(self, since_cursor: int) -> list[dict[str, Any]]:
|
||||
"""Return history entries with cursor > *since_cursor*."""
|
||||
return [e for e in self._read_entries() if e["cursor"] > since_cursor]
|
||||
return [e for e in self._read_entries() if e.get("cursor", 0) > since_cursor]
|
||||
|
||||
def compact_history(self) -> None:
|
||||
"""Drop oldest entries if the file exceeds *max_history_entries*."""
|
||||
@ -552,6 +552,13 @@ class Consolidator:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Single source of truth for the staleness threshold used in _annotate_with_ages
|
||||
# *and* in the Phase 1 prompt template (passed as `stale_threshold_days`).
|
||||
# Keep code and prompt aligned — if you bump this, the LLM's instruction string
|
||||
# updates automatically.
|
||||
_STALE_THRESHOLD_DAYS = 14
|
||||
|
||||
|
||||
class Dream:
|
||||
"""Two-phase memory processor: analyze history.jsonl, then edit files via AgentRunner.
|
||||
|
||||
@ -568,6 +575,7 @@ class Dream:
|
||||
max_batch_size: int = 20,
|
||||
max_iterations: int = 10,
|
||||
max_tool_result_chars: int = 16_000,
|
||||
annotate_line_ages: bool = True,
|
||||
):
|
||||
self.store = store
|
||||
self.provider = provider
|
||||
@ -575,6 +583,10 @@ class Dream:
|
||||
self.max_batch_size = max_batch_size
|
||||
self.max_iterations = max_iterations
|
||||
self.max_tool_result_chars = max_tool_result_chars
|
||||
# Kill switch for the git-blame-based per-line age annotation in Phase 1.
|
||||
# Default True keeps the #3212 behavior; set False to feed MEMORY.md raw
|
||||
# (e.g. if a specific LLM reacts poorly to the `← Nd` suffix).
|
||||
self.annotate_line_ages = annotate_line_ages
|
||||
self._runner = AgentRunner(provider)
|
||||
self._tools = self._build_tools()
|
||||
|
||||
@ -632,6 +644,52 @@ class Dream:
|
||||
|
||||
# -- main entry ----------------------------------------------------------
|
||||
|
||||
def _annotate_with_ages(self, content: str) -> str:
|
||||
"""Append per-line age suffixes to MEMORY.md content.
|
||||
|
||||
Each non-blank line whose age exceeds ``_STALE_THRESHOLD_DAYS`` gets a
|
||||
suffix like ``← 30d`` indicating days since last modification.
|
||||
Returns the original content unchanged if git is unavailable,
|
||||
annotate fails, or the line count doesn't match the age count
|
||||
(which can happen with an uncommitted working-tree edit — better to
|
||||
skip annotation than to tag the wrong line).
|
||||
SOUL.md and USER.md are never annotated.
|
||||
"""
|
||||
file_path = "memory/MEMORY.md"
|
||||
try:
|
||||
ages = self.store.git.line_ages(file_path)
|
||||
except Exception:
|
||||
logger.debug("line_ages failed for {}", file_path)
|
||||
return content
|
||||
if not ages:
|
||||
return content
|
||||
|
||||
had_trailing = content.endswith("\n")
|
||||
lines = content.splitlines()
|
||||
# If HEAD-blob line count disagrees with the working-tree content we
|
||||
# received, ages would be assigned to the wrong lines — skip entirely
|
||||
# and feed the LLM un-annotated content rather than misleading data.
|
||||
if len(lines) != len(ages):
|
||||
logger.debug(
|
||||
"line_ages length mismatch for {} (lines={}, ages={}); skipping annotation",
|
||||
file_path, len(lines), len(ages),
|
||||
)
|
||||
return content
|
||||
|
||||
annotated: list[str] = []
|
||||
for line, age in zip(lines, ages):
|
||||
if not line.strip():
|
||||
annotated.append(line)
|
||||
continue
|
||||
if age.age_days > _STALE_THRESHOLD_DAYS:
|
||||
annotated.append(f"{line} \u2190 {age.age_days}d")
|
||||
else:
|
||||
annotated.append(line)
|
||||
result = "\n".join(annotated)
|
||||
if had_trailing:
|
||||
result += "\n"
|
||||
return result
|
||||
|
||||
async def run(self) -> bool:
|
||||
"""Process unprocessed history entries. Returns True if work was done."""
|
||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||
@ -652,9 +710,14 @@ class Dream:
|
||||
f"[{e['timestamp']}] {e['content']}" for e in batch
|
||||
)
|
||||
|
||||
# Current file contents
|
||||
# Current file contents + per-line age annotations (MEMORY.md only)
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_memory = self.store.read_memory() or "(empty)"
|
||||
raw_memory = self.store.read_memory() or "(empty)"
|
||||
current_memory = (
|
||||
self._annotate_with_ages(raw_memory)
|
||||
if self.annotate_line_ages
|
||||
else raw_memory
|
||||
)
|
||||
current_soul = self.store.read_soul() or "(empty)"
|
||||
current_user = self.store.read_user() or "(empty)"
|
||||
|
||||
@ -676,7 +739,11 @@ class Dream:
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": render_template("agent/dream_phase1.md", strip=True),
|
||||
"content": render_template(
|
||||
"agent/dream_phase1.md",
|
||||
strip=True,
|
||||
stale_threshold_days=_STALE_THRESHOLD_DAYS,
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": phase1_prompt},
|
||||
],
|
||||
@ -759,7 +826,9 @@ class Dream:
|
||||
# Git auto-commit (only when there are actual changes)
|
||||
if changelog and self.store.git.is_initialized():
|
||||
ts = batch[-1]["timestamp"]
|
||||
sha = self.store.git.auto_commit(f"dream: {ts}, {len(changelog)} change(s)")
|
||||
summary = f"dream: {ts}, {len(changelog)} change(s)"
|
||||
commit_msg = f"{summary}\n\n{analysis.strip()}"
|
||||
sha = self.store.git.auto_commit(commit_msg)
|
||||
if sha:
|
||||
logger.info("Dream commit: {}", sha)
|
||||
|
||||
|
||||
@ -6,6 +6,8 @@ import re
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
# Default builtin skills directory (relative to this file)
|
||||
BUILTIN_SKILLS_DIR = Path(__file__).parent.parent / "skills"
|
||||
|
||||
@ -16,10 +18,6 @@ _STRIP_SKILL_FRONTMATTER = re.compile(
|
||||
)
|
||||
|
||||
|
||||
def _escape_xml(text: str) -> str:
|
||||
return text.replace("&", "&").replace("<", "<").replace(">", ">")
|
||||
|
||||
|
||||
class SkillsLoader:
|
||||
"""
|
||||
Loader for agent skills.
|
||||
@ -110,39 +108,37 @@ class SkillsLoader:
|
||||
]
|
||||
return "\n\n---\n\n".join(parts)
|
||||
|
||||
def build_skills_summary(self) -> str:
|
||||
def build_skills_summary(self, exclude: set[str] | None = None) -> str:
|
||||
"""
|
||||
Build a summary of all skills (name, description, path, availability).
|
||||
|
||||
This is used for progressive loading - the agent can read the full
|
||||
skill content using read_file when needed.
|
||||
|
||||
Args:
|
||||
exclude: Set of skill names to omit from the summary.
|
||||
|
||||
Returns:
|
||||
XML-formatted skills summary.
|
||||
Markdown-formatted skills summary.
|
||||
"""
|
||||
all_skills = self.list_skills(filter_unavailable=False)
|
||||
if not all_skills:
|
||||
return ""
|
||||
|
||||
lines: list[str] = ["<skills>"]
|
||||
lines: list[str] = []
|
||||
for entry in all_skills:
|
||||
skill_name = entry["name"]
|
||||
if exclude and skill_name in exclude:
|
||||
continue
|
||||
meta = self._get_skill_meta(skill_name)
|
||||
available = self._check_requirements(meta)
|
||||
lines.extend(
|
||||
[
|
||||
f' <skill available="{str(available).lower()}">',
|
||||
f" <name>{_escape_xml(skill_name)}</name>",
|
||||
f" <description>{_escape_xml(self._get_skill_description(skill_name))}</description>",
|
||||
f" <location>{entry['path']}</location>",
|
||||
]
|
||||
)
|
||||
if not available:
|
||||
desc = self._get_skill_description(skill_name)
|
||||
if available:
|
||||
lines.append(f"- **{skill_name}** — {desc} `{entry['path']}`")
|
||||
else:
|
||||
missing = self._get_missing_requirements(meta)
|
||||
if missing:
|
||||
lines.append(f" <requires>{_escape_xml(missing)}</requires>")
|
||||
lines.append(" </skill>")
|
||||
lines.append("</skills>")
|
||||
suffix = f" (unavailable: {missing})" if missing else " (unavailable)"
|
||||
lines.append(f"- **{skill_name}** — {desc}{suffix} `{entry['path']}`")
|
||||
return "\n".join(lines)
|
||||
|
||||
def _get_missing_requirements(self, skill_meta: dict) -> str:
|
||||
@ -171,11 +167,19 @@ class SkillsLoader:
|
||||
return content[match.end():].strip()
|
||||
return content
|
||||
|
||||
def _parse_nanobot_metadata(self, raw: str) -> dict:
|
||||
"""Parse skill metadata JSON from frontmatter (supports nanobot and openclaw keys)."""
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
def _parse_nanobot_metadata(self, raw: object) -> dict:
|
||||
"""Extract nanobot/openclaw metadata from a frontmatter field.
|
||||
|
||||
``raw`` may be a dict (already parsed by yaml.safe_load) or a JSON str.
|
||||
"""
|
||||
if isinstance(raw, dict):
|
||||
data = raw
|
||||
elif isinstance(raw, str):
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return {}
|
||||
else:
|
||||
return {}
|
||||
if not isinstance(data, dict):
|
||||
return {}
|
||||
@ -193,8 +197,8 @@ class SkillsLoader:
|
||||
|
||||
def _get_skill_meta(self, name: str) -> dict:
|
||||
"""Get nanobot metadata for a skill (cached in frontmatter)."""
|
||||
meta = self.get_skill_metadata(name) or {}
|
||||
return self._parse_nanobot_metadata(meta.get("metadata", ""))
|
||||
raw_meta = self.get_skill_metadata(name) or {}
|
||||
return self._parse_nanobot_metadata(raw_meta.get("metadata"))
|
||||
|
||||
def get_always_skills(self) -> list[str]:
|
||||
"""Get skills marked as always=true that meet requirements."""
|
||||
@ -203,7 +207,7 @@ class SkillsLoader:
|
||||
for entry in self.list_skills(filter_unavailable=True)
|
||||
if (meta := self.get_skill_metadata(entry["name"]) or {})
|
||||
and (
|
||||
self._parse_nanobot_metadata(meta.get("metadata", "")).get("always")
|
||||
self._parse_nanobot_metadata(meta.get("metadata")).get("always")
|
||||
or meta.get("always")
|
||||
)
|
||||
]
|
||||
@ -224,10 +228,15 @@ class SkillsLoader:
|
||||
match = _STRIP_SKILL_FRONTMATTER.match(content)
|
||||
if not match:
|
||||
return None
|
||||
metadata: dict[str, str] = {}
|
||||
for line in match.group(1).splitlines():
|
||||
if ":" not in line:
|
||||
continue
|
||||
key, value = line.split(":", 1)
|
||||
metadata[key.strip()] = value.strip().strip('"\'')
|
||||
try:
|
||||
parsed = yaml.safe_load(match.group(1))
|
||||
except yaml.YAMLError:
|
||||
return None
|
||||
if not isinstance(parsed, dict):
|
||||
return None
|
||||
# yaml.safe_load returns native types (int, bool, list, etc.);
|
||||
# keep values as-is so downstream consumers get correct types.
|
||||
metadata: dict[str, object] = {}
|
||||
for key, value in parsed.items():
|
||||
metadata[str(key)] = value
|
||||
return metadata
|
||||
|
||||
@ -301,3 +301,11 @@ class SubagentManager:
|
||||
def get_running_count(self) -> int:
|
||||
"""Return the number of currently running subagents."""
|
||||
return len(self._running_tasks)
|
||||
|
||||
def get_running_count_by_session(self, session_key: str) -> int:
|
||||
"""Return the number of currently running subagents for a session."""
|
||||
tids = self._session_tasks.get(session_key, set())
|
||||
return sum(
|
||||
1 for tid in tids
|
||||
if tid in self._running_tasks and not self._running_tasks[tid].done()
|
||||
)
|
||||
|
||||
@ -284,7 +284,7 @@ class MyTool(Tool):
|
||||
if action in ("inspect", "check"):
|
||||
return self._inspect(key)
|
||||
if not self._modify_allowed:
|
||||
return "Error: set is disabled (my_set is False)"
|
||||
return "Error: set is disabled (tools.my.allow_set is false)"
|
||||
if action in ("modify", "set"):
|
||||
return self._modify(key, value)
|
||||
return f"Unknown action: {action}"
|
||||
|
||||
@ -96,10 +96,37 @@ class WebSearchTool(Tool):
|
||||
self.config = config if config is not None else WebSearchConfig()
|
||||
self.proxy = proxy
|
||||
|
||||
def _effective_provider(self) -> str:
|
||||
"""Resolve the backend that execute() will actually use."""
|
||||
provider = self.config.provider.strip().lower() or "brave"
|
||||
if provider == "duckduckgo":
|
||||
return "duckduckgo"
|
||||
if provider == "brave":
|
||||
api_key = self.config.api_key or os.environ.get("BRAVE_API_KEY", "")
|
||||
return "brave" if api_key else "duckduckgo"
|
||||
if provider == "tavily":
|
||||
api_key = self.config.api_key or os.environ.get("TAVILY_API_KEY", "")
|
||||
return "tavily" if api_key else "duckduckgo"
|
||||
if provider == "searxng":
|
||||
base_url = (self.config.base_url or os.environ.get("SEARXNG_BASE_URL", "")).strip()
|
||||
return "searxng" if base_url else "duckduckgo"
|
||||
if provider == "jina":
|
||||
api_key = self.config.api_key or os.environ.get("JINA_API_KEY", "")
|
||||
return "jina" if api_key else "duckduckgo"
|
||||
if provider == "kagi":
|
||||
api_key = self.config.api_key or os.environ.get("KAGI_API_KEY", "")
|
||||
return "kagi" if api_key else "duckduckgo"
|
||||
return provider
|
||||
|
||||
@property
|
||||
def read_only(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def exclusive(self) -> bool:
|
||||
"""DuckDuckGo searches are serialized because ddgs is not concurrency-safe."""
|
||||
return self._effective_provider() == "duckduckgo"
|
||||
|
||||
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
|
||||
provider = self.config.provider.strip().lower() or "brave"
|
||||
n = min(max(count or self.config.max_results, 1), 10)
|
||||
|
||||
@ -7,15 +7,30 @@ All requests route to a single persistent API session.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json as _json
|
||||
import mimetypes
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import web
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.utils.helpers import safe_filename
|
||||
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
|
||||
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB
|
||||
_DATA_URL_RE = re.compile(r"^data:([^;]+);base64,(.+)$", re.DOTALL)
|
||||
|
||||
|
||||
class _FileSizeExceeded(Exception):
|
||||
"""Raised when an uploaded file exceeds the size limit."""
|
||||
|
||||
|
||||
API_SESSION_KEY = "api:default"
|
||||
API_CHAT_ID = "default"
|
||||
|
||||
@ -24,6 +39,7 @@ API_CHAT_ID = "default"
|
||||
# Response helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _error_json(status: int, message: str, err_type: str = "invalid_request_error") -> web.Response:
|
||||
return web.json_response(
|
||||
{"error": {"message": message, "type": err_type, "code": status}},
|
||||
@ -56,50 +72,231 @@ def _response_text(value: Any) -> str:
|
||||
return str(getattr(value, "content") or "")
|
||||
return str(value)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSE helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _sse_chunk(delta: str, model: str, chunk_id: str, finish_reason: str | None = None) -> bytes:
|
||||
"""Format a single OpenAI-compatible SSE chunk."""
|
||||
payload = {
|
||||
"id": chunk_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": delta} if delta else {},
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
}
|
||||
return f"data: {_json.dumps(payload)}\n\n".encode()
|
||||
|
||||
|
||||
_SSE_DONE = b"data: [DONE]\n\n"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Upload helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _save_base64_data_url(data_url: str, media_dir: Path) -> str | None:
|
||||
"""Decode a data:...;base64,... URL and save to disk."""
|
||||
m = _DATA_URL_RE.match(data_url)
|
||||
if not m:
|
||||
return None
|
||||
mime_type, b64_payload = m.group(1), m.group(2)
|
||||
try:
|
||||
raw = base64.b64decode(b64_payload)
|
||||
except Exception:
|
||||
return None
|
||||
if len(raw) > MAX_FILE_SIZE:
|
||||
raise _FileSizeExceeded(f"File exceeds {MAX_FILE_SIZE // (1024 * 1024)}MB limit")
|
||||
ext = mimetypes.guess_extension(mime_type) or ".bin"
|
||||
filename = f"{uuid.uuid4().hex[:12]}{ext}"
|
||||
dest = media_dir / safe_filename(filename)
|
||||
dest.write_bytes(raw)
|
||||
return str(dest)
|
||||
|
||||
|
||||
def _parse_json_content(body: dict) -> tuple[str, list[str]]:
|
||||
"""Parse JSON request body. Returns (text, media_paths)."""
|
||||
messages = body.get("messages")
|
||||
if not isinstance(messages, list) or len(messages) != 1:
|
||||
raise ValueError("Only a single user message is supported")
|
||||
message = messages[0]
|
||||
if not isinstance(message, dict) or message.get("role") != "user":
|
||||
raise ValueError("Only a single user message is supported")
|
||||
|
||||
user_content = message.get("content", "")
|
||||
media_dir = get_media_dir("api")
|
||||
media_paths: list[str] = []
|
||||
|
||||
if isinstance(user_content, list):
|
||||
text_parts: list[str] = []
|
||||
for part in user_content:
|
||||
if not isinstance(part, dict):
|
||||
continue
|
||||
if part.get("type") == "text":
|
||||
text_parts.append(part.get("text", ""))
|
||||
elif part.get("type") == "image_url":
|
||||
url = part.get("image_url", {}).get("url", "")
|
||||
if url.startswith("data:"):
|
||||
saved = _save_base64_data_url(url, media_dir)
|
||||
if saved:
|
||||
media_paths.append(saved)
|
||||
elif url:
|
||||
raise ValueError(
|
||||
"Remote image URLs are not supported. "
|
||||
"Use base64 data URLs or upload files via multipart/form-data."
|
||||
)
|
||||
text = " ".join(text_parts)
|
||||
elif isinstance(user_content, str):
|
||||
text = user_content
|
||||
else:
|
||||
raise ValueError("Invalid content format")
|
||||
|
||||
return text, media_paths
|
||||
|
||||
|
||||
async def _parse_multipart(request: web.Request) -> tuple[str, list[str], str | None, str | None]:
|
||||
"""Parse multipart/form-data. Returns (text, media_paths, session_id, model)."""
|
||||
media_dir = get_media_dir("api")
|
||||
reader = await request.multipart()
|
||||
text = ""
|
||||
session_id = None
|
||||
model = None
|
||||
media_paths: list[str] = []
|
||||
|
||||
while True:
|
||||
part = await reader.next()
|
||||
if part is None:
|
||||
break
|
||||
if part.name == "message":
|
||||
text = (await part.read()).decode("utf-8")
|
||||
elif part.name == "session_id":
|
||||
session_id = (await part.read()).decode("utf-8").strip()
|
||||
elif part.name == "model":
|
||||
model = (await part.read()).decode("utf-8").strip()
|
||||
elif part.name == "files":
|
||||
raw = await part.read()
|
||||
if len(raw) > MAX_FILE_SIZE:
|
||||
raise _FileSizeExceeded(
|
||||
f"File '{part.filename}' exceeds {MAX_FILE_SIZE // (1024 * 1024)}MB limit"
|
||||
)
|
||||
base = safe_filename(part.filename or "upload.bin")
|
||||
filename = f"{uuid.uuid4().hex[:12]}_{base}"
|
||||
dest = media_dir / filename
|
||||
dest.write_bytes(raw)
|
||||
media_paths.append(str(dest))
|
||||
|
||||
if not text:
|
||||
text = "请分析上传的文件"
|
||||
|
||||
return text, media_paths, session_id, model
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Route handlers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def handle_chat_completions(request: web.Request) -> web.Response:
|
||||
"""POST /v1/chat/completions"""
|
||||
|
||||
# --- Parse body ---
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
return _error_json(400, "Invalid JSON body")
|
||||
|
||||
messages = body.get("messages")
|
||||
if not isinstance(messages, list) or len(messages) != 1:
|
||||
return _error_json(400, "Only a single user message is supported")
|
||||
|
||||
# Stream not yet supported
|
||||
if body.get("stream", False):
|
||||
return _error_json(400, "stream=true is not supported yet. Set stream=false or omit it.")
|
||||
|
||||
message = messages[0]
|
||||
if not isinstance(message, dict) or message.get("role") != "user":
|
||||
return _error_json(400, "Only a single user message is supported")
|
||||
user_content = message.get("content", "")
|
||||
if isinstance(user_content, list):
|
||||
# Multi-modal content array — extract text parts
|
||||
user_content = " ".join(
|
||||
part.get("text", "") for part in user_content if part.get("type") == "text"
|
||||
)
|
||||
"""POST /v1/chat/completions — supports JSON and multipart/form-data."""
|
||||
content_type = request.content_type or ""
|
||||
if not isinstance(content_type, str):
|
||||
content_type = ""
|
||||
|
||||
agent_loop = request.app["agent_loop"]
|
||||
timeout_s: float = request.app.get("request_timeout", 120.0)
|
||||
model_name: str = request.app.get("model_name", "nanobot")
|
||||
if (requested_model := body.get("model")) and requested_model != model_name:
|
||||
|
||||
stream = False
|
||||
try:
|
||||
if content_type.startswith("multipart/"):
|
||||
text, media_paths, session_id, requested_model = await _parse_multipart(request)
|
||||
else:
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
return _error_json(400, "Invalid JSON body")
|
||||
stream = body.get("stream", False)
|
||||
requested_model = body.get("model")
|
||||
text, media_paths = _parse_json_content(body)
|
||||
session_id = body.get("session_id")
|
||||
except ValueError as e:
|
||||
return _error_json(400, str(e))
|
||||
except _FileSizeExceeded as e:
|
||||
return _error_json(413, str(e), err_type="invalid_request_error")
|
||||
except Exception:
|
||||
logger.exception("Error parsing upload")
|
||||
return _error_json(413, "File too large or invalid upload")
|
||||
|
||||
if requested_model and requested_model != model_name:
|
||||
return _error_json(400, f"Only configured model '{model_name}' is available")
|
||||
|
||||
session_key = f"api:{body['session_id']}" if body.get("session_id") else API_SESSION_KEY
|
||||
session_key = f"api:{session_id}" if session_id else API_SESSION_KEY
|
||||
session_locks: dict[str, asyncio.Lock] = request.app["session_locks"]
|
||||
session_lock = session_locks.setdefault(session_key, asyncio.Lock())
|
||||
|
||||
logger.info("API request session_key={} content={}", session_key, user_content[:80])
|
||||
logger.info(
|
||||
"API request session_key={} media={} text={} stream={}",
|
||||
session_key, len(media_paths), text[:80], stream,
|
||||
)
|
||||
# -- streaming path --
|
||||
if stream:
|
||||
resp = web.StreamResponse()
|
||||
resp.content_type = "text/event-stream"
|
||||
resp.headers["Cache-Control"] = "no-cache"
|
||||
resp.headers["Connection"] = "keep-alive"
|
||||
resp.enable_compression()
|
||||
await resp.prepare(request)
|
||||
|
||||
chunk_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||
queue: asyncio.Queue[str | None] = asyncio.Queue()
|
||||
|
||||
async def _on_stream(token: str) -> None:
|
||||
await queue.put(token)
|
||||
|
||||
async def _on_stream_end(*_a: Any, **_kw: Any) -> None:
|
||||
await queue.put(None)
|
||||
|
||||
async def _run() -> None:
|
||||
try:
|
||||
async with session_lock:
|
||||
await asyncio.wait_for(
|
||||
agent_loop.process_direct(
|
||||
content=text,
|
||||
media=media_paths if media_paths else None,
|
||||
session_key=session_key,
|
||||
channel="api",
|
||||
chat_id=API_CHAT_ID,
|
||||
on_stream=_on_stream,
|
||||
on_stream_end=_on_stream_end,
|
||||
),
|
||||
timeout=timeout_s,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Streaming error for session {}", session_key)
|
||||
await queue.put(None)
|
||||
|
||||
task = asyncio.create_task(_run())
|
||||
try:
|
||||
while True:
|
||||
token = await queue.get()
|
||||
if token is None:
|
||||
break
|
||||
await resp.write(_sse_chunk(token, model_name, chunk_id))
|
||||
finally:
|
||||
task.cancel()
|
||||
|
||||
await resp.write(_sse_chunk("", model_name, chunk_id, finish_reason="stop"))
|
||||
await resp.write(_SSE_DONE)
|
||||
return resp
|
||||
|
||||
# -- non-streaming path (original logic) --
|
||||
_FALLBACK = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
|
||||
try:
|
||||
@ -107,7 +304,8 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
agent_loop.process_direct(
|
||||
content=user_content,
|
||||
content=text,
|
||||
media=media_paths if media_paths else None,
|
||||
session_key=session_key,
|
||||
channel="api",
|
||||
chat_id=API_CHAT_ID,
|
||||
@ -117,13 +315,11 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
||||
response_text = _response_text(response)
|
||||
|
||||
if not response_text or not response_text.strip():
|
||||
logger.warning(
|
||||
"Empty response for session {}, retrying",
|
||||
session_key,
|
||||
)
|
||||
logger.warning("Empty response for session {}, retrying", session_key)
|
||||
retry_response = await asyncio.wait_for(
|
||||
agent_loop.process_direct(
|
||||
content=user_content,
|
||||
content=text,
|
||||
media=media_paths if media_paths else None,
|
||||
session_key=session_key,
|
||||
channel="api",
|
||||
chat_id=API_CHAT_ID,
|
||||
@ -132,10 +328,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
||||
)
|
||||
response_text = _response_text(retry_response)
|
||||
if not response_text or not response_text.strip():
|
||||
logger.warning(
|
||||
"Empty response after retry for session {}, using fallback",
|
||||
session_key,
|
||||
)
|
||||
logger.warning("Empty response after retry, using fallback")
|
||||
response_text = _FALLBACK
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
@ -153,17 +346,19 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
|
||||
async def handle_models(request: web.Request) -> web.Response:
|
||||
"""GET /v1/models"""
|
||||
model_name = request.app.get("model_name", "nanobot")
|
||||
return web.json_response({
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": model_name,
|
||||
"object": "model",
|
||||
"created": 0,
|
||||
"owned_by": "nanobot",
|
||||
}
|
||||
],
|
||||
})
|
||||
return web.json_response(
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": model_name,
|
||||
"object": "model",
|
||||
"created": 0,
|
||||
"owned_by": "nanobot",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def handle_health(request: web.Request) -> web.Response:
|
||||
@ -175,7 +370,10 @@ async def handle_health(request: web.Request) -> web.Response:
|
||||
# App factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def create_app(agent_loop, model_name: str = "nanobot", request_timeout: float = 120.0) -> web.Application:
|
||||
|
||||
def create_app(
|
||||
agent_loop, model_name: str = "nanobot", request_timeout: float = 120.0
|
||||
) -> web.Application:
|
||||
"""Create the aiohttp application.
|
||||
|
||||
Args:
|
||||
@ -183,7 +381,7 @@ def create_app(agent_loop, model_name: str = "nanobot", request_timeout: float =
|
||||
model_name: Model name reported in responses.
|
||||
request_timeout: Per-request timeout in seconds.
|
||||
"""
|
||||
app = web.Application()
|
||||
app = web.Application(client_max_size=20 * 1024 * 1024) # 20MB for base64 images
|
||||
app["agent_loop"] = agent_loop
|
||||
app["model_name"] = model_name
|
||||
app["request_timeout"] = request_timeout
|
||||
|
||||
@ -123,7 +123,13 @@ class BaseChannel(ABC):
|
||||
|
||||
def is_allowed(self, sender_id: str) -> bool:
|
||||
"""Check if *sender_id* is permitted. Empty list → deny all; ``"*"`` → allow all."""
|
||||
allow_list = getattr(self.config, "allow_from", [])
|
||||
if isinstance(self.config, dict):
|
||||
if "allow_from" in self.config:
|
||||
allow_list = self.config.get("allow_from")
|
||||
else:
|
||||
allow_list = self.config.get("allowFrom", [])
|
||||
else:
|
||||
allow_list = getattr(self.config, "allow_from", [])
|
||||
if not allow_list:
|
||||
logger.warning("{}: allow_from is empty — all access denied", self.name)
|
||||
return False
|
||||
|
||||
@ -53,6 +53,7 @@ class DiscordConfig(Base):
|
||||
enabled: bool = False
|
||||
token: str = ""
|
||||
allow_from: list[str] = Field(default_factory=list)
|
||||
allow_channels: list[str] = Field(default_factory=list) # Allowed channel IDs (empty = all)
|
||||
intents: int = 37377
|
||||
group_policy: Literal["mention", "open"] = "mention"
|
||||
read_receipt_emoji: str = "👀"
|
||||
@ -450,7 +451,6 @@ class DiscordChannel(BaseChannel):
|
||||
await self._start_typing(message.channel)
|
||||
|
||||
# Add read receipt reaction immediately, working emoji after delay
|
||||
channel_id = self._channel_key(message.channel)
|
||||
try:
|
||||
await message.add_reaction(self.config.read_receipt_emoji)
|
||||
self._pending_reactions[channel_id] = message
|
||||
@ -534,6 +534,12 @@ class DiscordChannel(BaseChannel):
|
||||
"""Check if inbound Discord message should be processed."""
|
||||
if not self.is_allowed(sender_id):
|
||||
return False
|
||||
# Channel-based filtering: only respond in allowed channels
|
||||
allow_channels = self.config.allow_channels
|
||||
if allow_channels:
|
||||
channel_id = self._channel_key(message.channel)
|
||||
if channel_id not in allow_channels:
|
||||
return False
|
||||
if message.guild is not None and not self._should_respond_in_group(message, content):
|
||||
return False
|
||||
return True
|
||||
|
||||
@ -86,7 +86,15 @@ class ChannelManager:
|
||||
|
||||
def _validate_allow_from(self) -> None:
|
||||
for name, ch in self.channels.items():
|
||||
if getattr(ch.config, "allow_from", None) == []:
|
||||
cfg = ch.config
|
||||
if isinstance(cfg, dict):
|
||||
if "allow_from" in cfg:
|
||||
allow = cfg.get("allow_from")
|
||||
else:
|
||||
allow = cfg.get("allowFrom")
|
||||
else:
|
||||
allow = getattr(cfg, "allow_from", None)
|
||||
if allow == []:
|
||||
raise SystemExit(
|
||||
f'Error: "{name}" has empty allowFrom (denies all). '
|
||||
f'Set ["*"] to allow everyone, or add specific user IDs.'
|
||||
|
||||
@ -57,7 +57,7 @@ class MSTeamsConfig(Base):
|
||||
allow_from: list[str] = Field(default_factory=list)
|
||||
reply_in_thread: bool = True
|
||||
mention_only_response: str = "Hi — what can I help with?"
|
||||
validate_inbound_auth: bool = False
|
||||
validate_inbound_auth: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -116,9 +116,9 @@ class MSTeamsChannel(BaseChannel):
|
||||
|
||||
if not self.config.validate_inbound_auth:
|
||||
logger.warning(
|
||||
"MSTeams inbound auth validation is DISABLED. "
|
||||
"MSTeams inbound auth validation was explicitly DISABLED in config. "
|
||||
"Anyone who knows the webhook URL can send messages as any user. "
|
||||
"Set validateInboundAuth: true in config for production use."
|
||||
"Only disable this for local development or controlled testing."
|
||||
)
|
||||
|
||||
self._loop = asyncio.get_running_loop()
|
||||
@ -274,6 +274,14 @@ class MSTeamsChannel(BaseChannel):
|
||||
logger.debug("MSTeams ignoring empty message after Teams text sanitization")
|
||||
return
|
||||
|
||||
if not self.is_allowed(sender_id):
|
||||
logger.warning(
|
||||
"Access denied for sender {} on channel {}. "
|
||||
"Add them to allowFrom list in config to grant access.",
|
||||
sender_id, self.name,
|
||||
)
|
||||
return
|
||||
|
||||
self._conversation_refs[conversation_id] = ConversationRef(
|
||||
service_url=service_url,
|
||||
conversation_id=conversation_id,
|
||||
|
||||
@ -5,6 +5,7 @@ import re
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from slack_sdk.socket_mode.websockets import SocketModeClient
|
||||
@ -13,8 +14,6 @@ from slackify_markdown import slackify_markdown
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from pydantic import Field
|
||||
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.schema import Base
|
||||
|
||||
@ -50,6 +49,9 @@ class SlackChannel(BaseChannel):
|
||||
|
||||
name = "slack"
|
||||
display_name = "Slack"
|
||||
_SLACK_ID_RE = re.compile(r"^[CDGUW][A-Z0-9]{2,}$")
|
||||
_SLACK_CHANNEL_REF_RE = re.compile(r"^<#([A-Z0-9]+)(?:\|[^>]+)?>$")
|
||||
_SLACK_USER_REF_RE = re.compile(r"^<@([A-Z0-9]+)(?:\|[^>]+)?>$")
|
||||
|
||||
@classmethod
|
||||
def default_config(cls) -> dict[str, Any]:
|
||||
@ -63,6 +65,7 @@ class SlackChannel(BaseChannel):
|
||||
self._web_client: AsyncWebClient | None = None
|
||||
self._socket_client: SocketModeClient | None = None
|
||||
self._bot_user_id: str | None = None
|
||||
self._target_cache: dict[str, str] = {}
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Slack Socket Mode client."""
|
||||
@ -113,17 +116,23 @@ class SlackChannel(BaseChannel):
|
||||
logger.warning("Slack client not running")
|
||||
return
|
||||
try:
|
||||
target_chat_id = await self._resolve_target_chat_id(msg.chat_id)
|
||||
slack_meta = msg.metadata.get("slack", {}) if msg.metadata else {}
|
||||
thread_ts = slack_meta.get("thread_ts")
|
||||
channel_type = slack_meta.get("channel_type")
|
||||
origin_chat_id = str((slack_meta.get("event", {}) or {}).get("channel") or msg.chat_id)
|
||||
# Slack DMs don't use threads; channel/group replies may keep thread_ts.
|
||||
thread_ts_param = thread_ts if thread_ts and channel_type != "im" else None
|
||||
thread_ts_param = (
|
||||
thread_ts
|
||||
if thread_ts and channel_type != "im" and target_chat_id == origin_chat_id
|
||||
else None
|
||||
)
|
||||
|
||||
# Slack rejects empty text payloads. Keep media-only messages media-only,
|
||||
# but send a single blank message when the bot has no text or files to send.
|
||||
if msg.content or not (msg.media or []):
|
||||
await self._web_client.chat_postMessage(
|
||||
channel=msg.chat_id,
|
||||
channel=target_chat_id,
|
||||
text=self._to_mrkdwn(msg.content) if msg.content else " ",
|
||||
thread_ts=thread_ts_param,
|
||||
)
|
||||
@ -131,7 +140,7 @@ class SlackChannel(BaseChannel):
|
||||
for media_path in msg.media or []:
|
||||
try:
|
||||
await self._web_client.files_upload_v2(
|
||||
channel=msg.chat_id,
|
||||
channel=target_chat_id,
|
||||
file=media_path,
|
||||
thread_ts=thread_ts_param,
|
||||
)
|
||||
@ -141,12 +150,123 @@ class SlackChannel(BaseChannel):
|
||||
# Update reaction emoji when the final (non-progress) response is sent
|
||||
if not (msg.metadata or {}).get("_progress"):
|
||||
event = slack_meta.get("event", {})
|
||||
await self._update_react_emoji(msg.chat_id, event.get("ts"))
|
||||
await self._update_react_emoji(origin_chat_id, event.get("ts"))
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error sending Slack message: {}", e)
|
||||
raise
|
||||
|
||||
async def _resolve_target_chat_id(self, target: str) -> str:
|
||||
"""Resolve human-friendly Slack targets to concrete IDs when needed."""
|
||||
if not self._web_client:
|
||||
return target
|
||||
|
||||
target = target.strip()
|
||||
if not target:
|
||||
return target
|
||||
|
||||
if match := self._SLACK_CHANNEL_REF_RE.fullmatch(target):
|
||||
return match.group(1)
|
||||
if match := self._SLACK_USER_REF_RE.fullmatch(target):
|
||||
return await self._open_dm_for_user(match.group(1))
|
||||
if self._SLACK_ID_RE.fullmatch(target):
|
||||
if target.startswith(("U", "W")):
|
||||
return await self._open_dm_for_user(target)
|
||||
return target
|
||||
|
||||
if target.startswith("#"):
|
||||
return await self._resolve_channel_name(target[1:])
|
||||
if target.startswith("@"):
|
||||
return await self._resolve_user_handle(target[1:])
|
||||
|
||||
try:
|
||||
return await self._resolve_channel_name(target)
|
||||
except ValueError:
|
||||
return await self._resolve_user_handle(target)
|
||||
|
||||
async def _resolve_channel_name(self, name: str) -> str:
|
||||
normalized = self._normalize_target_name(name)
|
||||
if not normalized:
|
||||
raise ValueError("Slack target channel name is empty")
|
||||
|
||||
cache_key = f"channel:{normalized}"
|
||||
if cache_key in self._target_cache:
|
||||
return self._target_cache[cache_key]
|
||||
|
||||
cursor: str | None = None
|
||||
while True:
|
||||
response = await self._web_client.conversations_list(
|
||||
types="public_channel,private_channel",
|
||||
exclude_archived=True,
|
||||
limit=200,
|
||||
cursor=cursor,
|
||||
)
|
||||
for channel in response.get("channels", []):
|
||||
if self._normalize_target_name(str(channel.get("name") or "")) == normalized:
|
||||
channel_id = str(channel.get("id") or "")
|
||||
if channel_id:
|
||||
self._target_cache[cache_key] = channel_id
|
||||
return channel_id
|
||||
cursor = ((response.get("response_metadata") or {}).get("next_cursor") or "").strip()
|
||||
if not cursor:
|
||||
break
|
||||
|
||||
raise ValueError(
|
||||
f"Slack channel '{name}' was not found. Use a joined channel name like "
|
||||
f"'#general' or a concrete channel ID."
|
||||
)
|
||||
|
||||
async def _resolve_user_handle(self, handle: str) -> str:
|
||||
normalized = self._normalize_target_name(handle)
|
||||
if not normalized:
|
||||
raise ValueError("Slack target user handle is empty")
|
||||
|
||||
cache_key = f"user:{normalized}"
|
||||
if cache_key in self._target_cache:
|
||||
return self._target_cache[cache_key]
|
||||
|
||||
cursor: str | None = None
|
||||
while True:
|
||||
response = await self._web_client.users_list(limit=200, cursor=cursor)
|
||||
for member in response.get("members", []):
|
||||
if self._member_matches_handle(member, normalized):
|
||||
user_id = str(member.get("id") or "")
|
||||
if not user_id:
|
||||
continue
|
||||
dm_id = await self._open_dm_for_user(user_id)
|
||||
self._target_cache[cache_key] = dm_id
|
||||
return dm_id
|
||||
cursor = ((response.get("response_metadata") or {}).get("next_cursor") or "").strip()
|
||||
if not cursor:
|
||||
break
|
||||
|
||||
raise ValueError(
|
||||
f"Slack user '{handle}' was not found. Use '@name' or a concrete DM/channel ID."
|
||||
)
|
||||
|
||||
async def _open_dm_for_user(self, user_id: str) -> str:
|
||||
response = await self._web_client.conversations_open(users=user_id)
|
||||
channel_id = str(((response.get("channel") or {}).get("id")) or "")
|
||||
if not channel_id:
|
||||
raise ValueError(f"Slack DM target for user '{user_id}' could not be opened.")
|
||||
return channel_id
|
||||
|
||||
@staticmethod
|
||||
def _normalize_target_name(value: str) -> str:
|
||||
return value.strip().lstrip("#@").lower()
|
||||
|
||||
@classmethod
|
||||
def _member_matches_handle(cls, member: dict[str, Any], normalized: str) -> bool:
|
||||
profile = member.get("profile") or {}
|
||||
candidates = {
|
||||
str(member.get("name") or ""),
|
||||
str(profile.get("display_name") or ""),
|
||||
str(profile.get("display_name_normalized") or ""),
|
||||
str(profile.get("real_name") or ""),
|
||||
str(profile.get("real_name_normalized") or ""),
|
||||
}
|
||||
return normalized in {cls._normalize_target_name(candidate) for candidate in candidates if candidate}
|
||||
|
||||
async def _on_socket_request(
|
||||
self,
|
||||
client: SocketModeClient,
|
||||
|
||||
@ -302,13 +302,22 @@ class WecomChannel(BaseChannel):
|
||||
|
||||
elif msg_type == "mixed":
|
||||
# Mixed content contains multiple message items
|
||||
msg_items = body.get("mixed", {}).get("item", [])
|
||||
msg_items = body.get("mixed", {}).get("msg_item", [])
|
||||
for item in msg_items:
|
||||
item_type = item.get("type", "")
|
||||
item_type = item.get("msgtype", "")
|
||||
if item_type == "text":
|
||||
text = item.get("text", {}).get("content", "")
|
||||
if text:
|
||||
content_parts.append(text)
|
||||
elif item_type == "image":
|
||||
file_url = item.get("image", {}).get("url", "")
|
||||
aes_key = item.get("image", {}).get("aeskey", "")
|
||||
if file_url and aes_key:
|
||||
file_path = await self._download_and_save_media(file_url, aes_key, "image")
|
||||
if file_path:
|
||||
filename = os.path.basename(file_path)
|
||||
content_parts.append(f"[image: {filename}]")
|
||||
media_paths.append(file_path)
|
||||
else:
|
||||
content_parts.append(MSG_TYPE_MAP.get(item_type, f"[{item_type}]"))
|
||||
|
||||
|
||||
@ -823,12 +823,55 @@ def gateway(
|
||||
|
||||
console.print(f"[green]✓[/green] Heartbeat: every {hb_cfg.interval_s}s")
|
||||
|
||||
async def _health_server(host: str, health_port: int):
|
||||
"""Lightweight HTTP health endpoint on the gateway port."""
|
||||
import json as _json
|
||||
|
||||
async def handle(reader, writer):
|
||||
try:
|
||||
data = await asyncio.wait_for(reader.read(4096), timeout=5)
|
||||
except (asyncio.TimeoutError, ConnectionError):
|
||||
writer.close()
|
||||
return
|
||||
|
||||
request_line = data.split(b"\r\n", 1)[0].decode("utf-8", errors="replace")
|
||||
method, path = "", ""
|
||||
parts = request_line.split(" ")
|
||||
if len(parts) >= 2:
|
||||
method, path = parts[0], parts[1]
|
||||
|
||||
if method == "GET" and path == "/health":
|
||||
body = _json.dumps({"status": "ok"})
|
||||
resp = (
|
||||
f"HTTP/1.0 200 OK\r\n"
|
||||
f"Content-Type: application/json\r\n"
|
||||
f"Content-Length: {len(body)}\r\n"
|
||||
f"\r\n{body}"
|
||||
)
|
||||
else:
|
||||
body = "Not Found"
|
||||
resp = (
|
||||
f"HTTP/1.0 404 Not Found\r\n"
|
||||
f"Content-Type: text/plain\r\n"
|
||||
f"Content-Length: {len(body)}\r\n"
|
||||
f"\r\n{body}"
|
||||
)
|
||||
|
||||
writer.write(resp.encode())
|
||||
await writer.drain()
|
||||
writer.close()
|
||||
|
||||
server = await asyncio.start_server(handle, host, health_port)
|
||||
console.print(f"[green]✓[/green] Health endpoint: http://{host}:{health_port}/health")
|
||||
async with server:
|
||||
await server.serve_forever()
|
||||
# Register Dream system job (always-on, idempotent on restart)
|
||||
dream_cfg = config.agents.defaults.dream
|
||||
if dream_cfg.model_override:
|
||||
agent.dream.model = dream_cfg.model_override
|
||||
agent.dream.max_batch_size = dream_cfg.max_batch_size
|
||||
agent.dream.max_iterations = dream_cfg.max_iterations
|
||||
agent.dream.annotate_line_ages = dream_cfg.annotate_line_ages
|
||||
from nanobot.cron.types import CronJob, CronPayload
|
||||
cron.register_system_job(CronJob(
|
||||
id="dream",
|
||||
@ -845,6 +888,7 @@ def gateway(
|
||||
await asyncio.gather(
|
||||
agent.run(),
|
||||
channels.start_all(),
|
||||
_health_server(config.gateway.host, port),
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
console.print("\nShutting down...")
|
||||
@ -967,7 +1011,7 @@ def agent(
|
||||
# Interactive mode — route through bus like other channels
|
||||
from nanobot.bus.events import InboundMessage
|
||||
_init_prompt_session()
|
||||
console.print(f"{__logo__} Interactive mode (type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit)\n")
|
||||
console.print(f"{__logo__} Interactive mode [bold blue]({config.agents.defaults.model})[/bold blue] — type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit\n")
|
||||
|
||||
if ":" in session_id:
|
||||
cli_channel, cli_chat_id = session_id.split(":", 1)
|
||||
|
||||
@ -102,7 +102,7 @@ class StreamRenderer:
|
||||
self._live = Live(self._render(), console=c, auto_refresh=False)
|
||||
self._live.start()
|
||||
now = time.monotonic()
|
||||
if "\n" in delta or (now - self._t) > 0.05:
|
||||
if (now - self._t) > 0.15:
|
||||
self._live.update(self._render())
|
||||
self._live.refresh()
|
||||
self._t = now
|
||||
|
||||
@ -74,6 +74,12 @@ async def cmd_status(ctx: CommandContext) -> OutboundMessage:
|
||||
search_usage_text = usage.format()
|
||||
except Exception:
|
||||
pass # Never let usage fetch break /status
|
||||
active_tasks = loop._active_tasks.get(ctx.key, [])
|
||||
task_count = sum(1 for t in active_tasks if not t.done())
|
||||
try:
|
||||
task_count += loop.subagents.get_running_count_by_session(ctx.key)
|
||||
except Exception:
|
||||
pass
|
||||
return OutboundMessage(
|
||||
channel=ctx.msg.channel,
|
||||
chat_id=ctx.msg.chat_id,
|
||||
@ -84,6 +90,10 @@ async def cmd_status(ctx: CommandContext) -> OutboundMessage:
|
||||
session_msg_count=len(session.get_history(max_messages=0)),
|
||||
context_tokens_estimate=ctx_est,
|
||||
search_usage_text=search_usage_text,
|
||||
active_task_count=task_count,
|
||||
max_completion_tokens=getattr(
|
||||
getattr(loop.provider, "generation", None), "max_tokens", 8192
|
||||
),
|
||||
),
|
||||
metadata={**dict(ctx.msg.metadata or {}), "render_as": "text"},
|
||||
)
|
||||
|
||||
@ -117,4 +117,19 @@ def _migrate_config(data: dict) -> dict:
|
||||
exec_cfg = tools.get("exec", {})
|
||||
if "restrictToWorkspace" in exec_cfg and "restrictToWorkspace" not in tools:
|
||||
tools["restrictToWorkspace"] = exec_cfg.pop("restrictToWorkspace")
|
||||
|
||||
# Move tools.myEnabled / tools.mySet → tools.my.{enable, allowSet}.
|
||||
# The old flat keys shipped in the initial MyTool landing; wrapping them in a
|
||||
# sub-config keeps `web` / `exec` / `my` symmetric and gives room to grow.
|
||||
if "myEnabled" in tools or "mySet" in tools:
|
||||
my_cfg = tools.setdefault("my", {})
|
||||
if "myEnabled" in tools and "enable" not in my_cfg:
|
||||
my_cfg["enable"] = tools.pop("myEnabled")
|
||||
else:
|
||||
tools.pop("myEnabled", None)
|
||||
if "mySet" in tools and "allowSet" not in my_cfg:
|
||||
my_cfg["allowSet"] = tools.pop("mySet")
|
||||
else:
|
||||
tools.pop("mySet", None)
|
||||
|
||||
return data
|
||||
|
||||
@ -43,7 +43,12 @@ class DreamConfig(Base):
|
||||
validation_alias=AliasChoices("modelOverride", "model", "model_override"),
|
||||
) # Optional Dream-specific model override
|
||||
max_batch_size: int = Field(default=20, ge=1) # Max history entries per run
|
||||
max_iterations: int = Field(default=10, ge=1) # Max tool calls per Phase 2
|
||||
# Bumped from 10 to 15 in #3212 (exp002: +30% dedup, no accuracy loss; >15 plateaus).
|
||||
max_iterations: int = Field(default=15, ge=1) # Max tool calls per Phase 2
|
||||
# Per-line git-blame age annotation in Phase 1 prompt (see #3212). Default
|
||||
# on — set to False to feed MEMORY.md raw if a specific LLM reacts poorly
|
||||
# to the `← Nd` suffix or you want deterministic, git-independent prompts.
|
||||
annotate_line_ages: bool = True
|
||||
|
||||
def build_schedule(self, timezone: str) -> CronSchedule:
|
||||
"""Build the runtime schedule, preferring the legacy cron override if present."""
|
||||
@ -96,7 +101,7 @@ class AgentsConfig(Base):
|
||||
class ProviderConfig(Base):
|
||||
"""LLM provider configuration."""
|
||||
|
||||
api_key: str = ""
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
extra_headers: dict[str, str] | None = None # Custom headers (e.g. APP-Code for AiHubMix)
|
||||
|
||||
@ -115,10 +120,12 @@ class ProvidersConfig(Base):
|
||||
dashscope: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
vllm: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
ollama: ProviderConfig = Field(default_factory=ProviderConfig) # Ollama local models
|
||||
lm_studio: ProviderConfig = Field(default_factory=ProviderConfig) # LM Studio local models
|
||||
ovms: ProviderConfig = Field(default_factory=ProviderConfig) # OpenVINO Model Server (OVMS)
|
||||
gemini: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
minimax: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
minimax_anthropic: ProviderConfig = Field(default_factory=ProviderConfig) # MiniMax Anthropic endpoint (thinking)
|
||||
mistral: ProviderConfig = Field(default_factory=ProviderConfig)
|
||||
stepfun: ProviderConfig = Field(default_factory=ProviderConfig) # Step Fun (阶跃星辰)
|
||||
xiaomi_mimo: ProviderConfig = Field(default_factory=ProviderConfig) # Xiaomi MIMO (小米)
|
||||
@ -152,7 +159,7 @@ class ApiConfig(Base):
|
||||
class GatewayConfig(Base):
|
||||
"""Gateway/server configuration."""
|
||||
|
||||
host: str = "0.0.0.0"
|
||||
host: str = "127.0.0.1" # Safer default: local-only bind.
|
||||
port: int = 18790
|
||||
heartbeat: HeartbeatConfig = Field(default_factory=HeartbeatConfig)
|
||||
|
||||
@ -198,16 +205,22 @@ class MCPServerConfig(Base):
|
||||
tool_timeout: int = 30 # seconds before a tool call is cancelled
|
||||
enabled_tools: list[str] = Field(default_factory=lambda: ["*"]) # Only register these tools; accepts raw MCP names or wrapped mcp_<server>_<tool> names; ["*"] = all tools; [] = no tools
|
||||
|
||||
class MyToolConfig(Base):
|
||||
"""Self-inspection tool configuration."""
|
||||
|
||||
enable: bool = True # register the `my` tool (agent runtime state inspection)
|
||||
allow_set: bool = False # let `my` modify loop state (read-only if False)
|
||||
|
||||
|
||||
class ToolsConfig(Base):
|
||||
"""Tools configuration."""
|
||||
|
||||
web: WebToolsConfig = Field(default_factory=WebToolsConfig)
|
||||
exec: ExecToolConfig = Field(default_factory=ExecToolConfig)
|
||||
my: MyToolConfig = Field(default_factory=MyToolConfig)
|
||||
restrict_to_workspace: bool = False # restrict all tool access to workspace directory
|
||||
mcp_servers: dict[str, MCPServerConfig] = Field(default_factory=dict)
|
||||
ssrf_whitelist: list[str] = Field(default_factory=list) # CIDR ranges to exempt from SSRF blocking (e.g. ["100.64.0.0/10"] for Tailscale)
|
||||
my_enabled: bool = True # enable the my tool (agent runtime state inspection)
|
||||
my_set: bool = False # allow my tool to set state (read-only if False)
|
||||
|
||||
|
||||
class Config(BaseSettings):
|
||||
|
||||
@ -512,9 +512,9 @@ class LLMProvider(ABC):
|
||||
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Call chat_stream() with retry on transient provider failures."""
|
||||
if max_tokens is self._SENTINEL:
|
||||
if max_tokens is self._SENTINEL or max_tokens is None:
|
||||
max_tokens = self.generation.max_tokens
|
||||
if temperature is self._SENTINEL:
|
||||
if temperature is self._SENTINEL or temperature is None:
|
||||
temperature = self.generation.temperature
|
||||
if reasoning_effort is self._SENTINEL:
|
||||
reasoning_effort = self.generation.reasoning_effort
|
||||
@ -549,11 +549,14 @@ class LLMProvider(ABC):
|
||||
|
||||
Parameters default to ``self.generation`` when not explicitly passed,
|
||||
so callers no longer need to thread temperature / max_tokens /
|
||||
reasoning_effort through every layer.
|
||||
reasoning_effort through every layer. Explicit ``None`` is also
|
||||
normalized to the provider's generation defaults so that downstream
|
||||
``_build_kwargs`` never sees ``None`` for ``max_tokens`` / ``temperature``
|
||||
(which would crash ``max(1, max_tokens)``).
|
||||
"""
|
||||
if max_tokens is self._SENTINEL:
|
||||
if max_tokens is self._SENTINEL or max_tokens is None:
|
||||
max_tokens = self.generation.max_tokens
|
||||
if temperature is self._SENTINEL:
|
||||
if temperature is self._SENTINEL or temperature is None:
|
||||
temperature = self.generation.temperature
|
||||
if reasoning_effort is self._SENTINEL:
|
||||
reasoning_effort = self.generation.reasoning_effort
|
||||
@ -718,9 +721,22 @@ class LLMProvider(ABC):
|
||||
identical_error_count,
|
||||
(response.content or "")[:120].lower(),
|
||||
)
|
||||
if on_retry_wait:
|
||||
await on_retry_wait(
|
||||
f"Persistent retry stopped after {identical_error_count} identical errors."
|
||||
)
|
||||
return response
|
||||
|
||||
if not persistent and attempt > len(delays):
|
||||
logger.warning(
|
||||
"LLM request failed after {} retries, giving up: {}",
|
||||
attempt,
|
||||
(response.content or "")[:120].lower(),
|
||||
)
|
||||
if on_retry_wait:
|
||||
await on_retry_wait(
|
||||
f"Model request failed after {attempt} retries, giving up."
|
||||
)
|
||||
break
|
||||
|
||||
base_delay = delays[min(attempt - 1, len(delays) - 1)]
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import hashlib
|
||||
import importlib.util
|
||||
import os
|
||||
@ -49,6 +50,29 @@ _DEFAULT_OPENROUTER_HEADERS = {
|
||||
"X-OpenRouter-Title": "nanobot",
|
||||
"X-OpenRouter-Categories": "cli-agent,personal-agent",
|
||||
}
|
||||
_KIMI_THINKING_MODELS: frozenset[str] = frozenset({
|
||||
"kimi-k2.5",
|
||||
"k2.6-code-preview",
|
||||
})
|
||||
|
||||
|
||||
def _is_kimi_thinking_model(model_name: str) -> bool:
|
||||
"""Return True if model_name refers to a Kimi thinking-capable model.
|
||||
|
||||
Supports two forms:
|
||||
- Exact match: kimi-k2.5 in _KIMI_THINKING_MODELS
|
||||
- Slug match: moonshotai/kimi-k2.5 -> the part after the last "/"
|
||||
is checked against _KIMI_THINKING_MODELS
|
||||
|
||||
This covers both the native Moonshot provider (bare slug) and
|
||||
OpenRouter-style names (``"publisher/slug"``).
|
||||
"""
|
||||
name = model_name.lower()
|
||||
if name in _KIMI_THINKING_MODELS:
|
||||
return True
|
||||
if "/" in name and name.rsplit("/", 1)[1] in _KIMI_THINKING_MODELS:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _short_tool_id() -> str:
|
||||
@ -222,6 +246,24 @@ class OpenAICompatProvider(LLMProvider):
|
||||
return tool_call_id
|
||||
return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]
|
||||
|
||||
@staticmethod
|
||||
def _normalize_tool_call_arguments(arguments: Any) -> str:
|
||||
"""Force function.arguments into a valid JSON object string."""
|
||||
if isinstance(arguments, str):
|
||||
stripped = arguments.strip()
|
||||
if not stripped:
|
||||
return "{}"
|
||||
try:
|
||||
parsed = json_repair.loads(stripped)
|
||||
except Exception:
|
||||
return "{}"
|
||||
if isinstance(parsed, dict):
|
||||
return json.dumps(parsed, ensure_ascii=False)
|
||||
return "{}"
|
||||
if isinstance(arguments, dict):
|
||||
return json.dumps(arguments, ensure_ascii=False)
|
||||
return "{}"
|
||||
|
||||
def _sanitize_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Strip non-standard keys, normalize tool_call IDs."""
|
||||
sanitized = LLMProvider._sanitize_request_messages(messages, _ALLOWED_MSG_KEYS)
|
||||
@ -241,6 +283,16 @@ class OpenAICompatProvider(LLMProvider):
|
||||
continue
|
||||
tc_clean = dict(tc)
|
||||
tc_clean["id"] = map_id(tc_clean.get("id"))
|
||||
function = tc_clean.get("function")
|
||||
if isinstance(function, dict):
|
||||
function_clean = dict(function)
|
||||
if "arguments" in function_clean:
|
||||
function_clean["arguments"] = self._normalize_tool_call_arguments(
|
||||
function_clean.get("arguments")
|
||||
)
|
||||
else:
|
||||
function_clean["arguments"] = "{}"
|
||||
tc_clean["function"] = function_clean
|
||||
normalized.append(tc_clean)
|
||||
clean["tool_calls"] = normalized
|
||||
if clean.get("role") == "assistant":
|
||||
@ -334,6 +386,16 @@ class OpenAICompatProvider(LLMProvider):
|
||||
if extra:
|
||||
kwargs.setdefault("extra_body", {}).update(extra)
|
||||
|
||||
# Model-level thinking injection for Kimi thinking-capable models.
|
||||
# Strip any provider prefix (e.g. "moonshotai/") before the set lookup
|
||||
# so that OpenRouter-style names like "moonshotai/kimi-k2.5" are handled
|
||||
# identically to bare names like "kimi-k2.5".
|
||||
if reasoning_effort is not None and _is_kimi_thinking_model(model_name):
|
||||
thinking_enabled = reasoning_effort.lower() != "minimal"
|
||||
kwargs.setdefault("extra_body", {}).update(
|
||||
{"thinking": {"type": "enabled" if thinking_enabled else "disabled"}}
|
||||
)
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
kwargs["tool_choice"] = tool_choice or "auto"
|
||||
|
||||
@ -280,6 +280,15 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
backend="openai_compat",
|
||||
default_api_base="https://api.minimax.io/v1",
|
||||
),
|
||||
# MiniMax Anthropic-compatible endpoint: supports thinking mode
|
||||
ProviderSpec(
|
||||
name="minimax_anthropic",
|
||||
keywords=("minimax_anthropic",),
|
||||
env_key="MINIMAX_API_KEY",
|
||||
display_name="MiniMax (Anthropic)",
|
||||
backend="anthropic",
|
||||
default_api_base="https://api.minimax.io/anthropic",
|
||||
),
|
||||
# Mistral AI: OpenAI-compatible API
|
||||
ProviderSpec(
|
||||
name="mistral",
|
||||
@ -328,6 +337,17 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
detect_by_base_keyword="11434",
|
||||
default_api_base="http://localhost:11434/v1",
|
||||
),
|
||||
# LM Studio (local, OpenAI-compatible)
|
||||
ProviderSpec(
|
||||
name="lm_studio",
|
||||
keywords=("lm-studio", "lmstudio", "lm_studio"),
|
||||
env_key="LM_STUDIO_API_KEY",
|
||||
display_name="LM Studio",
|
||||
backend="openai_compat",
|
||||
is_local=True,
|
||||
detect_by_base_keyword="1234",
|
||||
default_api_base="http://localhost:1234/v1",
|
||||
),
|
||||
# === OpenVINO Model Server (direct, local, OpenAI-compatible at /v3) ===
|
||||
ProviderSpec(
|
||||
name="ovms",
|
||||
|
||||
@ -59,7 +59,7 @@ always: true
|
||||
|
||||
- All modifications in-memory only — restart resets everything
|
||||
- Protected params have type/range validation: `max_iterations` (1–100), `context_window_tokens` (4096–1M), `model` (non-empty str)
|
||||
- If `my_set` is false, check only
|
||||
- If `tools.my.allow_set` is false, check only
|
||||
|
||||
## Related tools
|
||||
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
Compare conversation history against current memory files. Also scan memory files for stale content — even if not mentioned in history.
|
||||
You have TWO equally important tasks:
|
||||
1. Extract new facts from conversation history
|
||||
2. Deduplicate existing memory files — find and flag redundant, overlapping, or stale content even if NOT mentioned in history
|
||||
|
||||
Output one line per finding:
|
||||
[FILE] atomic fact (not already in memory)
|
||||
@ -12,12 +14,20 @@ Rules:
|
||||
- Corrections: [USER] location is Tokyo, not Osaka
|
||||
- Capture confirmed approaches the user validated
|
||||
|
||||
Staleness — flag for [FILE-REMOVE]:
|
||||
- Time-sensitive data older than 14 days: weather, daily status, one-time meetings, passed events
|
||||
- Completed one-time tasks: triage, one-time reviews, finished research, resolved incidents
|
||||
- Resolved tracking: merged/closed PRs, fixed issues, completed migrations
|
||||
- Detailed incident info after 14 days — reduce to one-line summary
|
||||
- Superseded: approaches replaced by newer solutions, deprecated dependencies
|
||||
Deduplication — scan ALL memory files for these redundancy patterns:
|
||||
- Same fact stated in multiple places (e.g., "communicates in Chinese" in both USER.md and multiple MEMORY.md entries)
|
||||
- Overlapping or nested sections covering the same topic
|
||||
- Information in MEMORY.md that is already captured in USER.md or SOUL.md (MEMORY.md should not duplicate permanent-file content)
|
||||
- Verbose entries that can be condensed without losing information
|
||||
For each duplicate found, output [FILE-REMOVE] for the less authoritative copy (prefer keeping facts in their canonical location)
|
||||
|
||||
Staleness — MEMORY.md lines may have a ``← Nd`` suffix showing days since last modification:
|
||||
- SOUL.md and USER.md have no age annotations — they are permanent, only update with corrections
|
||||
- Age only indicates when content was last touched, not whether it should be removed
|
||||
- Use content judgment: user habits/preferences/personality traits are permanent regardless of age
|
||||
- Only prune content that is objectively outdated: passed events, resolved tracking, superseded approaches
|
||||
- Lines with ``← Nd`` (N>{{ stale_threshold_days }}) deserve closer review but are NOT automatically removable
|
||||
- When removing: prefer deleting individual items over entire sections
|
||||
|
||||
Skill discovery — flag [SKILL] when ALL of these are true:
|
||||
- A specific, repeatable workflow appeared 2+ times in the conversation history
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# Skills
|
||||
|
||||
The following skills extend your capabilities. To use a skill, read its SKILL.md file using the read_file tool.
|
||||
Skills with available="false" need dependencies installed first - you can try installing them with apt/brew.
|
||||
Unavailable skills need dependencies installed first — you can try installing them with apt/brew.
|
||||
|
||||
{{ skills_summary }}
|
||||
|
||||
267
nanobot/utils/document.py
Normal file
267
nanobot/utils/document.py
Normal file
@ -0,0 +1,267 @@
|
||||
"""Document text extraction utilities for nanobot."""
|
||||
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.utils.helpers import detect_image_mime
|
||||
|
||||
try:
|
||||
from pypdf import PdfReader
|
||||
except ImportError:
|
||||
PdfReader = None # type: ignore
|
||||
|
||||
try:
|
||||
from docx import Document as DocxDocument
|
||||
except ImportError:
|
||||
DocxDocument = None # type: ignore
|
||||
|
||||
try:
|
||||
from openpyxl import load_workbook
|
||||
except ImportError:
|
||||
load_workbook = None # type: ignore
|
||||
|
||||
try:
|
||||
from pptx import Presentation as PptxPresentation
|
||||
except ImportError:
|
||||
PptxPresentation = None # type: ignore
|
||||
|
||||
|
||||
# Supported file extensions for text extraction
|
||||
SUPPORTED_EXTENSIONS: set[str] = {
|
||||
# Document formats
|
||||
".pdf",
|
||||
".docx",
|
||||
".xlsx",
|
||||
".pptx",
|
||||
# Text formats
|
||||
".txt",
|
||||
".md",
|
||||
".csv",
|
||||
".json",
|
||||
".xml",
|
||||
".html",
|
||||
".htm",
|
||||
".log",
|
||||
".yaml",
|
||||
".yml",
|
||||
".toml",
|
||||
".ini",
|
||||
".cfg",
|
||||
# Image formats (for future OCR support)
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".gif",
|
||||
".webp",
|
||||
}
|
||||
|
||||
_MAX_TEXT_LENGTH = 200_000
|
||||
|
||||
|
||||
def extract_text(path: Path) -> str | None:
|
||||
"""Extract text from a file.
|
||||
|
||||
Args:
|
||||
path: Path to the file.
|
||||
|
||||
Returns:
|
||||
Extracted text as string, None for unsupported types,
|
||||
or error string for failures.
|
||||
"""
|
||||
if not isinstance(path, Path):
|
||||
path = Path(path)
|
||||
|
||||
if not path.exists():
|
||||
return f"[error: file not found: {path}]"
|
||||
|
||||
ext = path.suffix.lower()
|
||||
|
||||
# Document formats
|
||||
if ext == ".pdf":
|
||||
if PdfReader is None:
|
||||
return "[error: pypdf not installed]"
|
||||
return _extract_pdf(path)
|
||||
elif ext == ".docx":
|
||||
if DocxDocument is None:
|
||||
return "[error: python-docx not installed]"
|
||||
return _extract_docx(path)
|
||||
elif ext == ".xlsx":
|
||||
if load_workbook is None:
|
||||
return "[error: openpyxl not installed]"
|
||||
return _extract_xlsx(path)
|
||||
elif ext == ".pptx":
|
||||
if PptxPresentation is None:
|
||||
return "[error: python-pptx not installed]"
|
||||
return _extract_pptx(path)
|
||||
elif _is_text_extension(ext):
|
||||
return _extract_text_file(path)
|
||||
elif ext in {".png", ".jpg", ".jpeg", ".gif", ".webp"}:
|
||||
# Image files - for future OCR support
|
||||
return f"[image: {path.name}]"
|
||||
else:
|
||||
# Unsupported extension
|
||||
return None
|
||||
|
||||
|
||||
def _extract_pdf(path: Path) -> str:
|
||||
"""Extract text from PDF using pypdf."""
|
||||
try:
|
||||
reader = PdfReader(path)
|
||||
pages: list[str] = []
|
||||
for i, page in enumerate(reader.pages, 1):
|
||||
text = page.extract_text() or ""
|
||||
pages.append(f"--- Page {i} ---\n{text}")
|
||||
return _truncate("\n\n".join(pages), _MAX_TEXT_LENGTH)
|
||||
except Exception as e:
|
||||
logger.error("Failed to extract PDF {}: {}", path, e)
|
||||
return f"[error: failed to extract PDF: {e!s}]"
|
||||
|
||||
|
||||
def _extract_docx(path: Path) -> str:
|
||||
"""Extract text from DOCX using python-docx."""
|
||||
try:
|
||||
doc = DocxDocument(path)
|
||||
paragraphs: list[str] = [p.text for p in doc.paragraphs if p.text.strip()]
|
||||
return _truncate("\n\n".join(paragraphs), _MAX_TEXT_LENGTH)
|
||||
except Exception as e:
|
||||
logger.error("Failed to extract DOCX {}: {}", path, e)
|
||||
return f"[error: failed to extract DOCX: {e!s}]"
|
||||
|
||||
|
||||
def _extract_xlsx(path: Path) -> str:
|
||||
"""Extract text from XLSX using openpyxl."""
|
||||
try:
|
||||
wb = load_workbook(path, read_only=True, data_only=True)
|
||||
sheets: list[str] = []
|
||||
for sheet_name in wb.sheetnames:
|
||||
ws = wb[sheet_name]
|
||||
rows: list[str] = []
|
||||
for row in ws.iter_rows(values_only=True):
|
||||
row_text = "\t".join(str(cell) if cell is not None else "" for cell in row)
|
||||
if row_text.strip():
|
||||
rows.append(row_text)
|
||||
if rows:
|
||||
sheets.append(f"--- Sheet: {sheet_name} ---\n" + "\n".join(rows))
|
||||
wb.close()
|
||||
return _truncate("\n\n".join(sheets), _MAX_TEXT_LENGTH)
|
||||
except Exception as e:
|
||||
logger.error("Failed to extract XLSX {}: {}", path, e)
|
||||
return f"[error: failed to extract XLSX: {e!s}]"
|
||||
|
||||
|
||||
def _extract_pptx(path: Path) -> str:
|
||||
"""Extract text from PPTX using python-pptx."""
|
||||
try:
|
||||
prs = PptxPresentation(path)
|
||||
slides: list[str] = []
|
||||
for i, slide in enumerate(prs.slides, 1):
|
||||
slide_text: list[str] = []
|
||||
for shape in slide.shapes:
|
||||
if hasattr(shape, "text") and shape.text:
|
||||
slide_text.append(shape.text)
|
||||
if slide_text:
|
||||
slides.append(f"--- Slide {i} ---\n" + "\n".join(slide_text))
|
||||
return _truncate("\n\n".join(slides), _MAX_TEXT_LENGTH)
|
||||
except Exception as e:
|
||||
logger.error("Failed to extract PPTX {}: {}", path, e)
|
||||
return f"[error: failed to extract PPTX: {e!s}]"
|
||||
|
||||
|
||||
def _extract_text_file(path: Path) -> str:
|
||||
"""Extract text from a plain text file."""
|
||||
try:
|
||||
# Try UTF-8 first, then latin-1 fallback
|
||||
try:
|
||||
content = path.read_text(encoding="utf-8")
|
||||
except UnicodeDecodeError:
|
||||
content = path.read_text(encoding="latin-1")
|
||||
return _truncate(content, _MAX_TEXT_LENGTH)
|
||||
except Exception as e:
|
||||
logger.error("Failed to read text file {}: {}", path, e)
|
||||
return f"[error: failed to read file: {e!s}]"
|
||||
|
||||
|
||||
def _truncate(text: str, max_length: int) -> str:
|
||||
"""Truncate text with a suffix indicating truncation."""
|
||||
if len(text) <= max_length:
|
||||
return text
|
||||
return text[:max_length] + f"... (truncated, {len(text)} chars total)"
|
||||
|
||||
|
||||
def _is_text_extension(ext: str) -> bool:
|
||||
"""Check if extension is a text format."""
|
||||
return ext in {
|
||||
".txt",
|
||||
".md",
|
||||
".csv",
|
||||
".json",
|
||||
".xml",
|
||||
".html",
|
||||
".htm",
|
||||
".log",
|
||||
".yaml",
|
||||
".yml",
|
||||
".toml",
|
||||
".ini",
|
||||
".cfg",
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# High-level helper: split media into images + extracted document text
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_MAX_EXTRACT_FILE_SIZE = 50 * 1024 * 1024 # 50 MB
|
||||
|
||||
|
||||
def extract_documents(
|
||||
text: str,
|
||||
media_paths: list[str],
|
||||
*,
|
||||
max_file_size: int = _MAX_EXTRACT_FILE_SIZE,
|
||||
) -> tuple[str, list[str]]:
|
||||
"""Separate images from documents in *media_paths*.
|
||||
|
||||
Documents (PDF, DOCX, XLSX, PPTX, plain-text, …) have their text
|
||||
extracted and appended to *text*. Only image paths are kept in the
|
||||
returned list so that downstream layers only need to handle vision
|
||||
blocks.
|
||||
|
||||
Files larger than *max_file_size* bytes are skipped with a warning
|
||||
to avoid unbounded memory / CPU usage.
|
||||
"""
|
||||
image_paths: list[str] = []
|
||||
doc_texts: list[str] = []
|
||||
|
||||
for path_str in media_paths:
|
||||
p = Path(path_str)
|
||||
if not p.is_file():
|
||||
continue
|
||||
|
||||
try:
|
||||
size = p.stat().st_size
|
||||
except OSError:
|
||||
continue
|
||||
if size > max_file_size:
|
||||
logger.warning(
|
||||
"Skipping oversized file for extraction: {} ({:.1f} MB > {} MB limit)",
|
||||
p.name, size / (1024 * 1024), max_file_size // (1024 * 1024),
|
||||
)
|
||||
continue
|
||||
|
||||
with open(p, "rb") as f:
|
||||
header = f.read(16)
|
||||
mime = detect_image_mime(header) or mimetypes.guess_type(path_str)[0]
|
||||
if mime and mime.startswith("image/"):
|
||||
image_paths.append(path_str)
|
||||
else:
|
||||
extracted = extract_text(p)
|
||||
if extracted and not extracted.startswith("[error:"):
|
||||
doc_texts.append(f"[File: {p.name}]\n{extracted}")
|
||||
|
||||
if doc_texts:
|
||||
text = text + "\n\n" + "\n\n".join(doc_texts)
|
||||
|
||||
return text, image_paths
|
||||
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import io
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from loguru import logger
|
||||
@ -24,6 +25,23 @@ class CommitInfo:
|
||||
return f"{header}\n(no file changes)"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LineAge:
|
||||
"""Age of a single line based on git blame."""
|
||||
|
||||
age_days: int # days since last modification
|
||||
|
||||
|
||||
def _compute_line_ages(annotated) -> list[LineAge]:
|
||||
"""Convert annotate results to per-line ages."""
|
||||
now = datetime.now(tz=timezone.utc).date()
|
||||
ages: list[LineAge] = []
|
||||
for (commit, _tree_entry), _line_bytes in annotated:
|
||||
dt = datetime.fromtimestamp(commit.commit_time, tz=timezone.utc).date()
|
||||
ages.append(LineAge(age_days=(now - dt).days))
|
||||
return ages
|
||||
|
||||
|
||||
class GitStore:
|
||||
"""Git-backed version control for memory files."""
|
||||
|
||||
@ -191,6 +209,34 @@ class GitStore:
|
||||
logger.warning("Git log failed")
|
||||
return []
|
||||
|
||||
def line_ages(self, file_path: str) -> list[LineAge]:
|
||||
"""Compute the age of each line in a tracked file via git blame.
|
||||
|
||||
Returns one LineAge per line, in order.
|
||||
Returns an empty list if the repo is not initialized, the file is
|
||||
empty, or annotation fails.
|
||||
"""
|
||||
|
||||
if not self.is_initialized():
|
||||
return []
|
||||
|
||||
target = self._workspace / file_path
|
||||
if not target.exists() or target.stat().st_size == 0:
|
||||
return []
|
||||
|
||||
try:
|
||||
from dulwich import porcelain
|
||||
|
||||
annotated = porcelain.annotate(str(self._workspace), file_path)
|
||||
except Exception:
|
||||
logger.warning("Git line_ages annotate failed for {}", file_path)
|
||||
return []
|
||||
|
||||
if not annotated:
|
||||
return []
|
||||
|
||||
return _compute_line_ages(annotated)
|
||||
|
||||
def diff_commits(self, sha1: str, sha2: str) -> str:
|
||||
"""Show diff between two commits."""
|
||||
if not self.is_initialized():
|
||||
|
||||
@ -400,6 +400,8 @@ def build_status_content(
|
||||
session_msg_count: int,
|
||||
context_tokens_estimate: int,
|
||||
search_usage_text: str | None = None,
|
||||
active_task_count: int = 0,
|
||||
max_completion_tokens: int = 8192,
|
||||
) -> str:
|
||||
"""Build a human-readable runtime status snapshot.
|
||||
|
||||
@ -418,7 +420,9 @@ def build_status_content(
|
||||
last_out = last_usage.get("completion_tokens", 0)
|
||||
cached = last_usage.get("cached_tokens", 0)
|
||||
ctx_total = max(context_window_tokens, 0)
|
||||
ctx_pct = int((context_tokens_estimate / ctx_total) * 100) if ctx_total > 0 else 0
|
||||
# Budget mirrors Consolidator formula: ctx_window - max_completion - _SAFETY_BUFFER
|
||||
ctx_budget = max(ctx_total - int(max_completion_tokens) - 1024, 1)
|
||||
ctx_pct = min(int((context_tokens_estimate / ctx_budget) * 100), 999) if ctx_budget > 0 else 0
|
||||
ctx_used_str = f"{context_tokens_estimate // 1000}k" if context_tokens_estimate >= 1000 else str(context_tokens_estimate)
|
||||
ctx_total_str = f"{ctx_total // 1000}k" if ctx_total > 0 else "n/a"
|
||||
token_line = f"\U0001f4ca Tokens: {last_in} in / {last_out} out"
|
||||
@ -428,9 +432,10 @@ def build_status_content(
|
||||
f"\U0001f408 nanobot v{version}",
|
||||
f"\U0001f9e0 Model: {model}",
|
||||
token_line,
|
||||
f"\U0001f4da Context: {ctx_used_str}/{ctx_total_str} ({ctx_pct}%)",
|
||||
f"\U0001f4da Context: {ctx_used_str}/{ctx_total_str} ({ctx_pct}% of input budget)",
|
||||
f"\U0001f4ac Session: {session_msg_count} messages",
|
||||
f"\u23f1 Uptime: {uptime}",
|
||||
f"\u26a1 Tasks: {active_task_count} active",
|
||||
]
|
||||
if search_usage_text:
|
||||
lines.append(search_usage_text)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "nanobot-ai"
|
||||
version = "0.1.5"
|
||||
version = "0.1.5.post1"
|
||||
description = "A lightweight personal AI assistant framework"
|
||||
readme = { file = "README.md", content-type = "text/markdown" }
|
||||
requires-python = ">=3.11"
|
||||
@ -50,6 +50,11 @@ dependencies = [
|
||||
"tiktoken>=0.12.0,<1.0.0",
|
||||
"jinja2>=3.1.0,<4.0.0",
|
||||
"dulwich>=0.22.0,<1.0.0",
|
||||
"pyyaml>=6.0,<7.0.0",
|
||||
"pypdf>=5.0.0,<6.0.0",
|
||||
"python-docx>=1.1.0,<2.0.0",
|
||||
"openpyxl>=3.1.0,<4.0.0",
|
||||
"python-pptx>=1.0.0,<2.0.0",
|
||||
"filelock>=3.25.2",
|
||||
]
|
||||
|
||||
|
||||
@ -219,3 +219,55 @@ def test_subagent_result_does_not_create_consecutive_assistant_messages(tmp_path
|
||||
|
||||
for left, right in zip(messages, messages[1:]):
|
||||
assert not (left.get("role") == right.get("role") == "assistant")
|
||||
|
||||
|
||||
def test_always_skills_excluded_from_skills_index(tmp_path) -> None:
|
||||
"""Always skills should appear in Active Skills but NOT in the skills index."""
|
||||
workspace = _make_workspace(tmp_path)
|
||||
builder = ContextBuilder(workspace)
|
||||
|
||||
prompt = builder.build_system_prompt()
|
||||
|
||||
# memory skill should be in Active Skills section
|
||||
assert "# Active Skills" in prompt
|
||||
assert "### Skill: memory" in prompt
|
||||
|
||||
# memory skill should NOT appear in the skills index
|
||||
skills_section = prompt.split("# Skills\n", 1)
|
||||
if len(skills_section) > 1:
|
||||
index_text = skills_section[1].split("\n\n---")[0]
|
||||
assert "**memory**" not in index_text
|
||||
|
||||
|
||||
def test_template_memory_md_is_skipped(tmp_path) -> None:
|
||||
"""MEMORY.md matching the bundled template should not inject the Memory section."""
|
||||
workspace = _make_workspace(tmp_path)
|
||||
from nanobot.utils.helpers import sync_workspace_templates
|
||||
sync_workspace_templates(workspace, silent=True)
|
||||
|
||||
builder = ContextBuilder(workspace)
|
||||
prompt = builder.build_system_prompt()
|
||||
|
||||
# The "# Memory\n\n## Long-term Memory" block is produced only by
|
||||
# build_system_prompt() when MEMORY.md is injected. The memory skill
|
||||
# also contains "# Memory" but is followed by "## Structure", not
|
||||
# "## Long-term Memory".
|
||||
assert "# Memory\n\n## Long-term Memory" not in prompt
|
||||
assert "This file is automatically updated by nanobot" not in prompt
|
||||
|
||||
|
||||
def test_customized_memory_md_is_injected(tmp_path) -> None:
|
||||
"""A Dream-populated MEMORY.md should be injected normally."""
|
||||
workspace = _make_workspace(tmp_path)
|
||||
from nanobot.utils.helpers import sync_workspace_templates
|
||||
sync_workspace_templates(workspace, silent=True)
|
||||
|
||||
(workspace / "memory" / "MEMORY.md").write_text(
|
||||
"# Long-term Memory\n\nUser prefers dark mode.\n", encoding="utf-8"
|
||||
)
|
||||
|
||||
builder = ContextBuilder(workspace)
|
||||
prompt = builder.build_system_prompt()
|
||||
|
||||
assert "# Memory\n\n## Long-term Memory" in prompt
|
||||
assert "User prefers dark mode" in prompt
|
||||
|
||||
@ -2,11 +2,12 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from nanobot.agent.memory import Dream, MemoryStore
|
||||
from nanobot.agent.runner import AgentRunResult
|
||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||
from nanobot.utils.gitstore import LineAge
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -123,3 +124,135 @@ class TestDreamRun:
|
||||
assert "Successfully wrote" in result
|
||||
assert (store.workspace / "skills" / "test-skill" / "SKILL.md").exists()
|
||||
|
||||
async def test_phase1_prompt_includes_line_age_annotations(self, dream, mock_provider, mock_runner, store):
|
||||
"""Phase 1 prompt should have per-line age suffixes in MEMORY.md when git is available."""
|
||||
store.append_history("some event")
|
||||
mock_provider.chat_with_retry.return_value = MagicMock(content="[SKIP]")
|
||||
mock_runner.run = AsyncMock(return_value=_make_run_result())
|
||||
|
||||
# Init git so line_ages works
|
||||
store.git.init()
|
||||
store.git.auto_commit("initial memory state")
|
||||
|
||||
await dream.run()
|
||||
|
||||
# The MEMORY.md section should not crash and should contain the memory content
|
||||
call_args = mock_provider.chat_with_retry.call_args
|
||||
user_msg = call_args.kwargs.get("messages", call_args[1].get("messages"))[1]["content"]
|
||||
assert "## Current MEMORY.md" in user_msg
|
||||
|
||||
async def test_phase1_annotates_only_memory_not_soul_or_user(self, dream, mock_provider, mock_runner, store):
|
||||
"""SOUL.md and USER.md should never have age annotations — they are permanent."""
|
||||
store.append_history("some event")
|
||||
mock_provider.chat_with_retry.return_value = MagicMock(content="[SKIP]")
|
||||
mock_runner.run = AsyncMock(return_value=_make_run_result())
|
||||
|
||||
store.git.init()
|
||||
store.git.auto_commit("initial state")
|
||||
|
||||
await dream.run()
|
||||
|
||||
call_args = mock_provider.chat_with_retry.call_args
|
||||
user_msg = call_args.kwargs.get("messages", call_args[1].get("messages"))[1]["content"]
|
||||
# The ← suffix should only appear in MEMORY.md section
|
||||
memory_section = user_msg.split("## Current MEMORY.md")[1].split("## Current SOUL.md")[0]
|
||||
soul_section = user_msg.split("## Current SOUL.md")[1].split("## Current USER.md")[0]
|
||||
user_section = user_msg.split("## Current USER.md")[1]
|
||||
# SOUL and USER should not contain age arrows
|
||||
assert "\u2190" not in soul_section
|
||||
assert "\u2190" not in user_section
|
||||
|
||||
async def test_phase1_prompt_works_without_git(self, dream, mock_provider, mock_runner, store):
|
||||
"""Phase 1 should work fine even if git is not initialized (no age annotations)."""
|
||||
store.append_history("some event")
|
||||
mock_provider.chat_with_retry.return_value = MagicMock(content="[SKIP]")
|
||||
mock_runner.run = AsyncMock(return_value=_make_run_result())
|
||||
|
||||
await dream.run()
|
||||
|
||||
# Should still succeed — just without age annotations
|
||||
mock_provider.chat_with_retry.assert_called_once()
|
||||
call_args = mock_provider.chat_with_retry.call_args
|
||||
user_msg = call_args.kwargs.get("messages", call_args[1].get("messages"))[1]["content"]
|
||||
assert "## Current MEMORY.md" in user_msg
|
||||
|
||||
async def test_phase1_prompt_carries_age_suffix_for_stale_lines(
|
||||
self, dream, mock_provider, mock_runner, store,
|
||||
):
|
||||
"""End-to-end: ages >14d must appear verbatim in the LLM prompt, ages ≤14d must not."""
|
||||
# MEMORY.md fixture has 2 non-blank lines ("# Memory" and "- Project X active").
|
||||
# Inject four ages to cover threshold boundaries: >14 suffix, ==14 no suffix, <14 no suffix.
|
||||
store.write_memory("# Memory\n- Project X active\n- fresh item\n- edge case line")
|
||||
store.append_history("some event")
|
||||
mock_provider.chat_with_retry.return_value = MagicMock(content="[SKIP]")
|
||||
mock_runner.run = AsyncMock(return_value=_make_run_result())
|
||||
|
||||
fake_ages = [
|
||||
LineAge(age_days=30), # "# Memory" → should get ← 30d
|
||||
LineAge(age_days=20), # "- Project X..." → should get ← 20d
|
||||
LineAge(age_days=14), # "- fresh item" → ==14, threshold is strictly >14, no suffix
|
||||
LineAge(age_days=5), # "- edge case..." → no suffix
|
||||
]
|
||||
with patch.object(store.git, "line_ages", return_value=fake_ages):
|
||||
await dream.run()
|
||||
|
||||
call_args = mock_provider.chat_with_retry.call_args
|
||||
user_msg = call_args.kwargs.get("messages", call_args[1].get("messages"))[1]["content"]
|
||||
memory_section = user_msg.split("## Current MEMORY.md")[1].split("## Current SOUL.md")[0]
|
||||
assert "\u2190 30d" in memory_section
|
||||
assert "\u2190 20d" in memory_section
|
||||
assert "\u2190 14d" not in memory_section
|
||||
assert "\u2190 5d" not in memory_section
|
||||
|
||||
async def test_phase1_skips_annotation_when_disabled(
|
||||
self, dream, mock_provider, mock_runner, store,
|
||||
):
|
||||
"""`annotate_line_ages=False` must bypass the git lookup entirely and keep MEMORY.md raw."""
|
||||
store.append_history("some event")
|
||||
mock_provider.chat_with_retry.return_value = MagicMock(content="[SKIP]")
|
||||
mock_runner.run = AsyncMock(return_value=_make_run_result())
|
||||
|
||||
dream.annotate_line_ages = False
|
||||
# line_ages must be bypassed entirely — verify with a spy rather than a
|
||||
# raising side_effect, because _annotate_with_ages catches Exception
|
||||
# (which swallows AssertionError) and would hide an accidental call.
|
||||
with patch.object(store.git, "line_ages") as mock_line_ages:
|
||||
await dream.run()
|
||||
mock_line_ages.assert_not_called()
|
||||
|
||||
call_args = mock_provider.chat_with_retry.call_args
|
||||
user_msg = call_args.kwargs.get("messages", call_args[1].get("messages"))[1]["content"]
|
||||
assert "\u2190" not in user_msg
|
||||
|
||||
async def test_phase1_skips_annotation_on_line_ages_length_mismatch(
|
||||
self, dream, mock_provider, mock_runner, store,
|
||||
):
|
||||
"""If ages length != lines length (dirty working tree), skip annotation instead of mis-tagging."""
|
||||
# MEMORY.md has 2 non-blank lines but we hand back only 1 age → mismatch.
|
||||
store.append_history("some event")
|
||||
mock_provider.chat_with_retry.return_value = MagicMock(content="[SKIP]")
|
||||
mock_runner.run = AsyncMock(return_value=_make_run_result())
|
||||
|
||||
with patch.object(store.git, "line_ages", return_value=[LineAge(age_days=999)]):
|
||||
await dream.run()
|
||||
|
||||
call_args = mock_provider.chat_with_retry.call_args
|
||||
user_msg = call_args.kwargs.get("messages", call_args[1].get("messages"))[1]["content"]
|
||||
memory_section = user_msg.split("## Current MEMORY.md")[1].split("## Current SOUL.md")[0]
|
||||
# No age arrow at all — we refused to annotate rather than tag the wrong line.
|
||||
assert "\u2190" not in memory_section
|
||||
|
||||
async def test_phase1_prompt_uses_threshold_from_template_var(
|
||||
self, dream, mock_provider, mock_runner, store,
|
||||
):
|
||||
"""System prompt should reference the stale-threshold constant, not a hardcoded 14."""
|
||||
store.append_history("some event")
|
||||
mock_provider.chat_with_retry.return_value = MagicMock(content="[SKIP]")
|
||||
mock_runner.run = AsyncMock(return_value=_make_run_result())
|
||||
|
||||
await dream.run()
|
||||
|
||||
system_msg = mock_provider.chat_with_retry.call_args.kwargs["messages"][0]["content"]
|
||||
# The template renders with stale_threshold_days=14 → LLM must see "N>14"
|
||||
assert "N>14" in system_msg
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
@ -308,3 +309,111 @@ async def test_next_turn_after_crash_closes_pending_user_turn_before_new_input(t
|
||||
{"role": "assistant", "content": "new answer"},
|
||||
]
|
||||
assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_preserves_runtime_checkpoint_for_next_turn(tmp_path: Path) -> None:
|
||||
from nanobot.command.builtin import cmd_stop
|
||||
from nanobot.command.router import CommandContext
|
||||
|
||||
loop = _make_full_loop(tmp_path)
|
||||
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
|
||||
|
||||
checkpoint_saved = asyncio.Event()
|
||||
|
||||
async def interrupted_run_agent_loop(_initial_messages, *, session=None, **_kwargs):
|
||||
assert session is not None
|
||||
loop._set_runtime_checkpoint(
|
||||
session,
|
||||
{
|
||||
"assistant_message": {
|
||||
"role": "assistant",
|
||||
"content": "working",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_done",
|
||||
"type": "function",
|
||||
"function": {"name": "read_file", "arguments": "{}"},
|
||||
},
|
||||
{
|
||||
"id": "call_pending",
|
||||
"type": "function",
|
||||
"function": {"name": "exec", "arguments": "{}"},
|
||||
},
|
||||
],
|
||||
},
|
||||
"completed_tool_results": [
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_done",
|
||||
"name": "read_file",
|
||||
"content": "ok",
|
||||
}
|
||||
],
|
||||
"pending_tool_calls": [
|
||||
{
|
||||
"id": "call_pending",
|
||||
"type": "function",
|
||||
"function": {"name": "exec", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
checkpoint_saved.set()
|
||||
await asyncio.Event().wait()
|
||||
|
||||
loop._run_agent_loop = interrupted_run_agent_loop # type: ignore[method-assign]
|
||||
|
||||
first_msg = InboundMessage(channel="feishu", sender_id="u1", chat_id="c4", content="keep progress")
|
||||
task = asyncio.create_task(loop._process_message(first_msg))
|
||||
loop._active_tasks[first_msg.session_key] = [task]
|
||||
await asyncio.wait_for(checkpoint_saved.wait(), timeout=1.0)
|
||||
|
||||
stop_msg = InboundMessage(channel="feishu", sender_id="u1", chat_id="c4", content="/stop")
|
||||
stop_ctx = CommandContext(msg=stop_msg, session=None, key=stop_msg.session_key, raw="/stop", loop=loop)
|
||||
stop_result = await cmd_stop(stop_ctx)
|
||||
|
||||
assert "Stopped 1 task" in stop_result.content
|
||||
assert task.done()
|
||||
|
||||
loop.sessions.invalidate("feishu:c4")
|
||||
interrupted = loop.sessions.get_or_create("feishu:c4")
|
||||
assert interrupted.metadata.get(AgentLoop._PENDING_USER_TURN_KEY) is True
|
||||
assert interrupted.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is not None
|
||||
|
||||
async def resumed_run_agent_loop(initial_messages, **_kwargs):
|
||||
return (
|
||||
"next answer",
|
||||
None,
|
||||
[*initial_messages, {"role": "assistant", "content": "next answer"}],
|
||||
"stop",
|
||||
False,
|
||||
)
|
||||
|
||||
loop._run_agent_loop = resumed_run_agent_loop # type: ignore[method-assign]
|
||||
result = await loop._process_message(
|
||||
InboundMessage(channel="feishu", sender_id="u1", chat_id="c4", content="continue here")
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.content == "next answer"
|
||||
|
||||
session = loop.sessions.get_or_create("feishu:c4")
|
||||
assert [
|
||||
{k: v for k, v in m.items() if k in {"role", "content", "tool_call_id", "name"}}
|
||||
for m in session.messages
|
||||
] == [
|
||||
{"role": "user", "content": "keep progress"},
|
||||
{"role": "assistant", "content": "working"},
|
||||
{"role": "tool", "tool_call_id": "call_done", "name": "read_file", "content": "ok"},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_pending",
|
||||
"name": "exec",
|
||||
"content": "Error: Task interrupted before this tool finished.",
|
||||
},
|
||||
{"role": "user", "content": "continue here"},
|
||||
{"role": "assistant", "content": "next answer"},
|
||||
]
|
||||
assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata
|
||||
assert AgentLoop._RUNTIME_CHECKPOINT_KEY not in session.metadata
|
||||
|
||||
@ -79,6 +79,29 @@ class TestHistoryWithCursor:
|
||||
entries = store.read_unprocessed_history(since_cursor=0)
|
||||
assert len(entries) == 2
|
||||
|
||||
def test_read_unprocessed_skips_entries_without_cursor(self, store):
|
||||
"""Regression: entries missing the cursor key should be silently skipped."""
|
||||
store.history_file.write_text(
|
||||
'{"timestamp": "2026-04-01 10:00", "content": "no cursor"}\n'
|
||||
'{"cursor": 2, "timestamp": "2026-04-01 10:01", "content": "valid"}\n'
|
||||
'{"cursor": 3, "timestamp": "2026-04-01 10:02", "content": "also valid"}\n',
|
||||
encoding="utf-8",
|
||||
)
|
||||
entries = store.read_unprocessed_history(since_cursor=0)
|
||||
assert [e["cursor"] for e in entries] == [2, 3]
|
||||
|
||||
def test_next_cursor_falls_back_when_last_entry_has_no_cursor(self, store):
|
||||
"""Regression: _next_cursor should not KeyError on entries without cursor."""
|
||||
store.history_file.write_text(
|
||||
'{"timestamp": "2026-04-01 10:01", "content": "no cursor"}\n',
|
||||
encoding="utf-8",
|
||||
)
|
||||
# Delete .cursor file so _next_cursor falls back to reading JSONL
|
||||
store._cursor_file.unlink(missing_ok=True)
|
||||
# Last entry has no cursor — should safely return 1, not KeyError
|
||||
cursor = store.append_history("new event")
|
||||
assert cursor == 1
|
||||
|
||||
def test_compact_history_drops_oldest(self, tmp_path):
|
||||
store = MemoryStore(tmp_path, max_history_entries=2)
|
||||
store.append_history("event 1")
|
||||
|
||||
@ -689,11 +689,20 @@ async def test_runner_keeps_going_when_tool_result_persistence_fails():
|
||||
|
||||
|
||||
class _DelayTool(Tool):
|
||||
def __init__(self, name: str, *, delay: float, read_only: bool, shared_events: list[str]):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
*,
|
||||
delay: float,
|
||||
read_only: bool,
|
||||
shared_events: list[str],
|
||||
exclusive: bool = False,
|
||||
):
|
||||
self._name = name
|
||||
self._delay = delay
|
||||
self._read_only = read_only
|
||||
self._shared_events = shared_events
|
||||
self._exclusive = exclusive
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@ -711,6 +720,10 @@ class _DelayTool(Tool):
|
||||
def read_only(self) -> bool:
|
||||
return self._read_only
|
||||
|
||||
@property
|
||||
def exclusive(self) -> bool:
|
||||
return self._exclusive
|
||||
|
||||
async def execute(self, **kwargs):
|
||||
self._shared_events.append(f"start:{self._name}")
|
||||
await asyncio.sleep(self._delay)
|
||||
@ -756,6 +769,48 @@ async def test_runner_batches_read_only_tools_before_exclusive_work():
|
||||
assert shared_events[-2:] == ["start:write_a", "end:write_a"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_does_not_batch_exclusive_read_only_tools():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
tools = ToolRegistry()
|
||||
shared_events: list[str] = []
|
||||
read_a = _DelayTool("read_a", delay=0.03, read_only=True, shared_events=shared_events)
|
||||
read_b = _DelayTool("read_b", delay=0.03, read_only=True, shared_events=shared_events)
|
||||
ddg_like = _DelayTool(
|
||||
"ddg_like",
|
||||
delay=0.01,
|
||||
read_only=True,
|
||||
shared_events=shared_events,
|
||||
exclusive=True,
|
||||
)
|
||||
tools.register(read_a)
|
||||
tools.register(ddg_like)
|
||||
tools.register(read_b)
|
||||
|
||||
runner = AgentRunner(MagicMock())
|
||||
await runner._execute_tools(
|
||||
AgentRunSpec(
|
||||
initial_messages=[],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
concurrent_tools=True,
|
||||
),
|
||||
[
|
||||
ToolCallRequest(id="ro1", name="read_a", arguments={}),
|
||||
ToolCallRequest(id="ddg1", name="ddg_like", arguments={}),
|
||||
ToolCallRequest(id="ro2", name="read_b", arguments={}),
|
||||
],
|
||||
{},
|
||||
)
|
||||
|
||||
assert shared_events[0] == "start:read_a"
|
||||
assert shared_events.index("end:read_a") < shared_events.index("start:ddg_like")
|
||||
assert shared_events.index("end:ddg_like") < shared_events.index("start:read_b")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_blocks_repeated_external_fetches():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
@ -310,3 +310,90 @@ def test_disabled_skills_excluded_from_get_always_skills(tmp_path: Path) -> None
|
||||
always = loader.get_always_skills()
|
||||
assert "alpha" not in always
|
||||
assert "beta" in always
|
||||
|
||||
|
||||
# -- multiline description tests (YAML folded > and literal |) -----------------
|
||||
|
||||
|
||||
def test_build_skills_summary_folded_description(tmp_path: Path) -> None:
|
||||
"""description: > (YAML folded scalar) should be parsed correctly."""
|
||||
workspace = tmp_path / "ws"
|
||||
ws_skills = workspace / "skills"
|
||||
ws_skills.mkdir(parents=True)
|
||||
skill_dir = ws_skills / "pdf"
|
||||
skill_dir.mkdir(parents=True)
|
||||
skill_path = skill_dir / "SKILL.md"
|
||||
skill_path.write_text(
|
||||
"---\n"
|
||||
"name: pdf\n"
|
||||
"description: >\n"
|
||||
" Use this skill when visual quality and design identity matter for a PDF.\n"
|
||||
" CREATE (generate from scratch): \"make a PDF\".\n"
|
||||
"---\n\n# PDF Skill\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
builtin = tmp_path / "builtin"
|
||||
builtin.mkdir()
|
||||
|
||||
loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
|
||||
summary = loader.build_skills_summary()
|
||||
assert "pdf" in summary
|
||||
assert "visual quality" in summary
|
||||
|
||||
|
||||
def test_build_skills_summary_literal_description(tmp_path: Path) -> None:
|
||||
"""description: | (YAML literal scalar) should be parsed correctly."""
|
||||
workspace = tmp_path / "ws"
|
||||
ws_skills = workspace / "skills"
|
||||
ws_skills.mkdir(parents=True)
|
||||
skill_dir = ws_skills / "multi"
|
||||
skill_dir.mkdir(parents=True)
|
||||
skill_path = skill_dir / "SKILL.md"
|
||||
skill_path.write_text(
|
||||
"---\n"
|
||||
"name: multi\n"
|
||||
"description: |\n"
|
||||
" Line one of description.\n"
|
||||
" Line two of description.\n"
|
||||
"---\n\n# Multi\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
builtin = tmp_path / "builtin"
|
||||
builtin.mkdir()
|
||||
|
||||
loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
|
||||
meta = loader.get_skill_metadata("multi")
|
||||
assert meta is not None
|
||||
desc = meta.get("description")
|
||||
assert isinstance(desc, str)
|
||||
assert "Line one" in desc
|
||||
assert "Line two" in desc
|
||||
|
||||
|
||||
def test_get_skill_metadata_handles_yaml_types(tmp_path: Path) -> None:
|
||||
"""yaml.safe_load returns native types; always should be True, not 'true'."""
|
||||
workspace = tmp_path / "ws"
|
||||
ws_skills = workspace / "skills"
|
||||
ws_skills.mkdir(parents=True)
|
||||
skill_dir = ws_skills / "typed"
|
||||
skill_dir.mkdir(parents=True)
|
||||
payload = json.dumps({"nanobot": {"requires": {"bins": ["gh"]}, "always": True}}, separators=(",", ":"))
|
||||
skill_path = skill_dir / "SKILL.md"
|
||||
skill_path.write_text(
|
||||
"---\n"
|
||||
"name: typed\n"
|
||||
f"metadata: {payload}\n"
|
||||
"always: true\n"
|
||||
"---\n\n# Typed\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
builtin = tmp_path / "builtin"
|
||||
builtin.mkdir()
|
||||
|
||||
loader = SkillsLoader(workspace, builtin_skills_dir=builtin)
|
||||
meta = loader.get_skill_metadata("typed")
|
||||
assert meta is not None
|
||||
# YAML parsed 'true' to Python True
|
||||
assert meta.get("always") is True
|
||||
# metadata is a parsed dict, not a JSON string
|
||||
assert isinstance(meta.get("metadata"), dict)
|
||||
|
||||
@ -835,7 +835,7 @@ class TestInspectTaskStatuses:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# read-only mode (my_set=False)
|
||||
# read-only mode (tools.my.allow_set=False)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadOnlyMode:
|
||||
|
||||
@ -23,3 +23,15 @@ def test_is_allowed_requires_exact_match() -> None:
|
||||
|
||||
assert channel.is_allowed("allow@email.com") is True
|
||||
assert channel.is_allowed("attacker|allow@email.com") is False
|
||||
|
||||
|
||||
def test_is_allowed_supports_dict_allow_from_alias() -> None:
|
||||
channel = _DummyChannel({"allowFrom": ["alice"]}, MessageBus())
|
||||
|
||||
assert channel.is_allowed("alice") is True
|
||||
|
||||
|
||||
def test_is_allowed_denies_empty_dict_allow_from() -> None:
|
||||
channel = _DummyChannel({"allow_from": []}, MessageBus())
|
||||
|
||||
assert channel.is_allowed("alice") is False
|
||||
|
||||
@ -753,7 +753,10 @@ class _ChannelWithAllowFrom(BaseChannel):
|
||||
|
||||
def __init__(self, config, bus, allow_from):
|
||||
super().__init__(config, bus)
|
||||
self.config.allow_from = allow_from
|
||||
if isinstance(self.config, dict):
|
||||
self.config["allow_from"] = allow_from
|
||||
else:
|
||||
self.config.allow_from = allow_from
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
@ -821,6 +824,25 @@ async def test_validate_allow_from_passes_with_asterisk():
|
||||
mgr._validate_allow_from()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_allow_from_raises_on_empty_dict_allow_from():
|
||||
"""_validate_allow_from should reject empty dict-backed allow_from lists."""
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
)
|
||||
|
||||
mgr = ChannelManager.__new__(ChannelManager)
|
||||
mgr.config = fake_config
|
||||
mgr.channels = {"test": _ChannelWithAllowFrom({"enabled": True}, None, [])}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
mgr._validate_allow_from()
|
||||
|
||||
assert "empty allowFrom" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_channel_returns_channel_if_exists():
|
||||
"""get_channel should return the channel if it exists."""
|
||||
|
||||
@ -313,6 +313,45 @@ async def test_on_message_accepts_allowlisted_dm() -> None:
|
||||
assert handled[0]["metadata"] == {"message_id": "789", "guild_id": None, "reply_to": None}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_accepts_when_channel_in_allow_channels() -> None:
|
||||
# When allow_channels is set, messages from listed channels should be forwarded.
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, allow_from=["*"], allow_channels=["456"]),
|
||||
MessageBus(),
|
||||
)
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
|
||||
await channel._on_message(_make_message(author_id=123, channel_id=456))
|
||||
|
||||
assert len(handled) == 1
|
||||
assert handled[0]["chat_id"] == "456"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_drops_when_channel_not_in_allow_channels() -> None:
|
||||
# When allow_channels is set and incoming channel is not listed, drop silently.
|
||||
channel = DiscordChannel(
|
||||
DiscordConfig(enabled=True, allow_from=["*"], allow_channels=["999"]),
|
||||
MessageBus(),
|
||||
)
|
||||
handled: list[dict] = []
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle # type: ignore[method-assign]
|
||||
|
||||
await channel._on_message(_make_message(author_id=123, channel_id=456))
|
||||
|
||||
assert handled == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_ignores_unmentioned_guild_message() -> None:
|
||||
# With mention-only group policy, guild messages without a bot mention are dropped.
|
||||
|
||||
@ -10,8 +10,7 @@ except ImportError:
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.slack import SlackChannel
|
||||
from nanobot.channels.slack import SlackConfig
|
||||
from nanobot.channels.slack import SlackChannel, SlackConfig
|
||||
|
||||
|
||||
class _FakeAsyncWebClient:
|
||||
@ -20,6 +19,12 @@ class _FakeAsyncWebClient:
|
||||
self.file_upload_calls: list[dict[str, object | None]] = []
|
||||
self.reactions_add_calls: list[dict[str, object | None]] = []
|
||||
self.reactions_remove_calls: list[dict[str, object | None]] = []
|
||||
self.conversations_list_calls: list[dict[str, object | None]] = []
|
||||
self.users_list_calls: list[dict[str, object | None]] = []
|
||||
self.conversations_open_calls: list[dict[str, object | None]] = []
|
||||
self._conversations_pages: list[dict[str, object]] = []
|
||||
self._users_pages: list[dict[str, object]] = []
|
||||
self._open_dm_response: dict[str, object] = {"channel": {"id": "D_OPENED"}}
|
||||
|
||||
async def chat_postMessage(
|
||||
self,
|
||||
@ -81,6 +86,22 @@ class _FakeAsyncWebClient:
|
||||
}
|
||||
)
|
||||
|
||||
async def conversations_list(self, **kwargs):
|
||||
self.conversations_list_calls.append(kwargs)
|
||||
if self._conversations_pages:
|
||||
return self._conversations_pages.pop(0)
|
||||
return {"channels": [], "response_metadata": {"next_cursor": ""}}
|
||||
|
||||
async def users_list(self, **kwargs):
|
||||
self.users_list_calls.append(kwargs)
|
||||
if self._users_pages:
|
||||
return self._users_pages.pop(0)
|
||||
return {"members": [], "response_metadata": {"next_cursor": ""}}
|
||||
|
||||
async def conversations_open(self, **kwargs):
|
||||
self.conversations_open_calls.append(kwargs)
|
||||
return self._open_dm_response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_uses_thread_for_channel_messages() -> None:
|
||||
@ -151,3 +172,147 @@ async def test_send_updates_reaction_when_final_response_sent() -> None:
|
||||
assert fake_web.reactions_add_calls == [
|
||||
{"channel": "C123", "name": "white_check_mark", "timestamp": "1700000000.000100"}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_resolves_channel_name_to_channel_id() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
fake_web._conversations_pages = [
|
||||
{
|
||||
"channels": [{"id": "C999", "name": "channel_x"}],
|
||||
"response_metadata": {"next_cursor": ""},
|
||||
}
|
||||
]
|
||||
channel._web_client = fake_web
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="slack",
|
||||
chat_id="#channel_x",
|
||||
content="hello",
|
||||
)
|
||||
)
|
||||
|
||||
assert fake_web.chat_post_calls == [
|
||||
{"channel": "C999", "text": "hello\n", "thread_ts": None}
|
||||
]
|
||||
assert len(fake_web.conversations_list_calls) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_resolves_user_handle_to_dm_channel() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
fake_web._users_pages = [
|
||||
{
|
||||
"members": [
|
||||
{
|
||||
"id": "U234",
|
||||
"name": "alice",
|
||||
"profile": {"display_name": "Alice"},
|
||||
}
|
||||
],
|
||||
"response_metadata": {"next_cursor": ""},
|
||||
}
|
||||
]
|
||||
fake_web._open_dm_response = {"channel": {"id": "D234"}}
|
||||
channel._web_client = fake_web
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="slack",
|
||||
chat_id="@alice",
|
||||
content="hello",
|
||||
)
|
||||
)
|
||||
|
||||
assert fake_web.conversations_open_calls == [{"users": "U234"}]
|
||||
assert fake_web.chat_post_calls == [
|
||||
{"channel": "D234", "text": "hello\n", "thread_ts": None}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_updates_reaction_on_origin_channel_for_cross_channel_send() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True, react_emoji="eyes"), MessageBus())
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
fake_web._conversations_pages = [
|
||||
{
|
||||
"channels": [{"id": "C999", "name": "channel_x"}],
|
||||
"response_metadata": {"next_cursor": ""},
|
||||
}
|
||||
]
|
||||
channel._web_client = fake_web
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="slack",
|
||||
chat_id="channel_x",
|
||||
content="done",
|
||||
metadata={
|
||||
"slack": {
|
||||
"event": {"ts": "1700000000.000100", "channel": "D_ORIGIN"},
|
||||
"channel_type": "im",
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert fake_web.chat_post_calls == [
|
||||
{"channel": "C999", "text": "done\n", "thread_ts": None}
|
||||
]
|
||||
assert fake_web.reactions_remove_calls == [
|
||||
{"channel": "D_ORIGIN", "name": "eyes", "timestamp": "1700000000.000100"}
|
||||
]
|
||||
assert fake_web.reactions_add_calls == [
|
||||
{"channel": "D_ORIGIN", "name": "white_check_mark", "timestamp": "1700000000.000100"}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_does_not_reuse_origin_thread_ts_for_cross_channel_send() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
fake_web._conversations_pages = [
|
||||
{
|
||||
"channels": [{"id": "C999", "name": "channel_x"}],
|
||||
"response_metadata": {"next_cursor": ""},
|
||||
}
|
||||
]
|
||||
channel._web_client = fake_web
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="slack",
|
||||
chat_id="channel_x",
|
||||
content="done",
|
||||
metadata={
|
||||
"slack": {
|
||||
"event": {"ts": "1700000000.000100", "channel": "C_ORIGIN"},
|
||||
"thread_ts": "1700000000.000200",
|
||||
"channel_type": "channel",
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert fake_web.chat_post_calls == [
|
||||
{"channel": "C999", "text": "done\n", "thread_ts": None}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_raises_when_named_target_cannot_be_resolved() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
channel._web_client = fake_web
|
||||
|
||||
with pytest.raises(ValueError, match="was not found"):
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="slack",
|
||||
chat_id="#missing-channel",
|
||||
content="hello",
|
||||
)
|
||||
)
|
||||
|
||||
@ -541,6 +541,50 @@ async def test_process_voice_message() -> None:
|
||||
assert "[voice]" in msg.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_mixed_message() -> None:
|
||||
"""Mixed message: contains picture and text message types."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||
f.write(b"\x89PNG\r\n")
|
||||
saved = f.name
|
||||
|
||||
client.download_file.return_value = (b"\x89PNG\r\n", "photo.png")
|
||||
channel._client = client
|
||||
|
||||
try:
|
||||
with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(os.path.dirname(saved))):
|
||||
frame = _FakeFrame(body={
|
||||
"msgid": "msg_mixed_1",
|
||||
"chatid": "chat1",
|
||||
"msgtype": "mixed",
|
||||
"from": {"userid": "user1"},
|
||||
"mixed": {
|
||||
"msg_item": [
|
||||
{"msgtype": "text", "text": {"content": "hello wecom"}},
|
||||
{"msgtype": "image", "image": {"url": "https://example.com/img.png", "aeskey": "key123"}}
|
||||
]
|
||||
}
|
||||
})
|
||||
await channel._process_message(frame, "mixed")
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert msg.sender_id == "user1"
|
||||
assert msg.chat_id == "chat1"
|
||||
assert msg.content.startswith("hello wecom")
|
||||
assert msg.metadata["msg_type"] == "mixed"
|
||||
assert len(msg.media) == 1
|
||||
assert msg.media[0].endswith("photo.png")
|
||||
assert "[image:" in msg.content
|
||||
finally:
|
||||
# Clean up any photo.png in tempdir
|
||||
p = os.path.join(os.path.dirname(saved), "photo.png")
|
||||
if os.path.exists(p):
|
||||
os.unlink(p)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_deduplication() -> None:
|
||||
"""Same msg_id is not processed twice."""
|
||||
|
||||
@ -257,6 +257,28 @@ def test_config_accepts_camel_case_explicit_provider_name_for_coding_plan():
|
||||
assert config.get_api_base() == "https://ark.cn-beijing.volces.com/api/coding/v3"
|
||||
|
||||
|
||||
def test_config_accepts_lm_studio_without_api_key_and_uses_default_localhost_api_base():
|
||||
config = Config.model_validate(
|
||||
{
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"provider": "lm_studio",
|
||||
"model": "local-model",
|
||||
}
|
||||
},
|
||||
"providers": {
|
||||
"lmStudio": {
|
||||
"apiKey": None,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
assert config.get_provider_name() == "lm_studio"
|
||||
assert config.get_api_key() is None
|
||||
assert config.get_api_base() == "http://localhost:1234/v1"
|
||||
|
||||
|
||||
def test_find_by_name_accepts_camel_case_and_hyphen_aliases():
|
||||
assert find_by_name("volcengineCodingPlan") is not None
|
||||
assert find_by_name("volcengineCodingPlan").name == "volcengine_coding_plan"
|
||||
@ -1126,6 +1148,153 @@ def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path)
|
||||
assert "port 18792" in result.stdout
|
||||
|
||||
|
||||
def test_gateway_health_endpoint_binds_and_serves_expected_responses(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
config_file = _write_instance_config(tmp_path)
|
||||
config = Config()
|
||||
config.gateway.port = 18791
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
class _FakeDream:
|
||||
model = None
|
||||
max_batch_size = 0
|
||||
max_iterations = 0
|
||||
|
||||
async def run(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeAgentLoop:
|
||||
def __init__(self, **_kwargs) -> None:
|
||||
self.model = "test-model"
|
||||
self.dream = _FakeDream()
|
||||
|
||||
async def run(self) -> None:
|
||||
await asyncio.Event().wait()
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
return None
|
||||
|
||||
def stop(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeChannelManager:
|
||||
def __init__(self, _config, _bus) -> None:
|
||||
self.enabled_channels = ["telegram", "discord"]
|
||||
|
||||
async def start_all(self) -> None:
|
||||
await asyncio.Event().wait()
|
||||
|
||||
async def stop_all(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeCronService:
|
||||
def __init__(self, _store_path: Path) -> None:
|
||||
self.on_job = None
|
||||
|
||||
async def start(self) -> None:
|
||||
return None
|
||||
|
||||
def stop(self) -> None:
|
||||
return None
|
||||
|
||||
def status(self) -> dict[str, int]:
|
||||
return {"jobs": 0}
|
||||
|
||||
def register_system_job(self, _job) -> None:
|
||||
return None
|
||||
|
||||
class _FakeHeartbeatService:
|
||||
def __init__(self, **_kwargs) -> None:
|
||||
return None
|
||||
|
||||
async def start(self) -> None:
|
||||
return None
|
||||
|
||||
def stop(self) -> None:
|
||||
return None
|
||||
|
||||
class _FakeServer:
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> bool:
|
||||
return False
|
||||
|
||||
async def serve_forever(self) -> None:
|
||||
raise _StopGatewayError("stop")
|
||||
|
||||
async def _fake_start_server(handler, host: str, port: int):
|
||||
captured["handler"] = handler
|
||||
captured["host"] = host
|
||||
captured["port"] = port
|
||||
return _FakeServer()
|
||||
|
||||
class _FakeReader:
|
||||
def __init__(self, payload: bytes) -> None:
|
||||
self.payload = payload
|
||||
|
||||
async def read(self, _size: int) -> bytes:
|
||||
return self.payload
|
||||
|
||||
class _FakeWriter:
|
||||
def __init__(self) -> None:
|
||||
self.output = b""
|
||||
self.closed = False
|
||||
|
||||
def write(self, data: bytes) -> None:
|
||||
self.output += data
|
||||
|
||||
async def drain(self) -> None:
|
||||
return None
|
||||
|
||||
def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
_patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config,
|
||||
message_bus=lambda: object(),
|
||||
session_manager=lambda _workspace: object(),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.channels.manager.ChannelManager", _FakeChannelManager)
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _FakeCronService)
|
||||
monkeypatch.setattr("nanobot.heartbeat.service.HeartbeatService", _FakeHeartbeatService)
|
||||
monkeypatch.setattr("asyncio.start_server", _fake_start_server)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert captured["host"] == "127.0.0.1"
|
||||
assert captured["port"] == 18791
|
||||
assert "Health endpoint: http://127.0.0.1:18791/health" in result.stdout
|
||||
|
||||
def _call_handler(path: str) -> tuple[str, _FakeWriter]:
|
||||
request = f"GET {path} HTTP/1.1\r\nHost: localhost\r\n\r\n".encode()
|
||||
writer = _FakeWriter()
|
||||
handler = captured["handler"]
|
||||
assert callable(handler)
|
||||
asyncio.run(handler(_FakeReader(request), writer))
|
||||
return writer.output.decode(), writer
|
||||
|
||||
root_response, root_writer = _call_handler("/")
|
||||
assert root_writer.closed is True
|
||||
assert "HTTP/1.0 404 Not Found" in root_response
|
||||
assert root_response.endswith("\r\n\r\nNot Found")
|
||||
|
||||
health_response, health_writer = _call_handler("/health")
|
||||
assert health_writer.closed is True
|
||||
assert "HTTP/1.0 200 OK" in health_response
|
||||
health_body = json.loads(health_response.split("\r\n\r\n", 1)[1])
|
||||
assert health_body == {"status": "ok"}
|
||||
|
||||
missing_response, missing_writer = _call_handler("/missing")
|
||||
assert missing_writer.closed is True
|
||||
assert "HTTP/1.0 404 Not Found" in missing_response
|
||||
assert missing_response.endswith("\r\n\r\nNot Found")
|
||||
|
||||
|
||||
def test_serve_uses_api_config_defaults_and_workspace_override(
|
||||
monkeypatch, tmp_path: Path
|
||||
) -> None:
|
||||
|
||||
@ -140,6 +140,7 @@ class TestRestartCommand:
|
||||
loop.consolidator.estimate_session_prompt_tokens = MagicMock(
|
||||
return_value=(20500, "tiktoken")
|
||||
)
|
||||
loop.subagents.get_running_count_by_session.return_value = 0
|
||||
|
||||
msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/status")
|
||||
|
||||
@ -148,11 +149,36 @@ class TestRestartCommand:
|
||||
assert response is not None
|
||||
assert "Model: test-model" in response.content
|
||||
assert "Tokens: 0 in / 0 out" in response.content
|
||||
assert "Context: 20k/65k (31%)" in response.content
|
||||
assert "Context: 20k/65k (31% of input budget)" in response.content
|
||||
assert "Session: 3 messages" in response.content
|
||||
assert "Uptime: 2m 5s" in response.content
|
||||
assert "Tasks: 0 active" in response.content
|
||||
assert response.metadata == {"render_as": "text"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_counts_running_dispatch_and_subagent_tasks(self):
|
||||
loop, _bus = _make_loop()
|
||||
session = MagicMock()
|
||||
session.get_history.return_value = [{"role": "user"}]
|
||||
loop.sessions.get_or_create.return_value = session
|
||||
loop.consolidator.estimate_session_prompt_tokens = MagicMock(
|
||||
return_value=(1000, "tiktoken")
|
||||
)
|
||||
|
||||
running_task = MagicMock()
|
||||
running_task.done.return_value = False
|
||||
finished_task = MagicMock()
|
||||
finished_task.done.return_value = True
|
||||
|
||||
msg = InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/status")
|
||||
loop._active_tasks[msg.session_key] = [running_task, finished_task]
|
||||
loop.subagents.get_running_count_by_session.return_value = 2
|
||||
|
||||
response = await loop._process_message(msg)
|
||||
|
||||
assert response is not None
|
||||
assert "Tasks: 3 active" in response.content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_loop_resets_usage_when_provider_omits_it(self):
|
||||
loop, _bus = _make_loop()
|
||||
@ -179,6 +205,7 @@ class TestRestartCommand:
|
||||
loop.consolidator.estimate_session_prompt_tokens = MagicMock(
|
||||
return_value=(0, "none")
|
||||
)
|
||||
loop.subagents.get_running_count_by_session.return_value = 0
|
||||
|
||||
response = await loop._process_message(
|
||||
InboundMessage(channel="telegram", sender_id="u1", chat_id="c1", content="/status")
|
||||
@ -186,7 +213,8 @@ class TestRestartCommand:
|
||||
|
||||
assert response is not None
|
||||
assert "Tokens: 1200 in / 34 out" in response.content
|
||||
assert "Context: 1k/65k (1%)" in response.content
|
||||
assert "Context: 1k/65k (1% of input budget)" in response.content
|
||||
assert "Tasks: 0 active" in response.content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_direct_preserves_render_metadata(self):
|
||||
@ -195,6 +223,7 @@ class TestRestartCommand:
|
||||
session.get_history.return_value = []
|
||||
loop.sessions.get_or_create.return_value = session
|
||||
loop.subagents.get_running_count.return_value = 0
|
||||
loop.subagents.get_running_count_by_session.return_value = 0
|
||||
|
||||
response = await loop.process_direct("/status", session_key="cli:test")
|
||||
|
||||
|
||||
@ -140,6 +140,71 @@ def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch)
|
||||
assert saved["channels"]["qq"]["msgFormat"] == "plain"
|
||||
|
||||
|
||||
def test_load_config_migrates_legacy_my_tool_keys(tmp_path) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"tools": {
|
||||
"myEnabled": False,
|
||||
"mySet": True,
|
||||
}
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
config = load_config(config_path)
|
||||
|
||||
assert config.tools.my.enable is False
|
||||
assert config.tools.my.allow_set is True
|
||||
|
||||
|
||||
def test_save_config_rewrites_legacy_my_tool_keys(tmp_path) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"tools": {
|
||||
"myEnabled": False,
|
||||
"mySet": True,
|
||||
}
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
config = load_config(config_path)
|
||||
save_config(config, config_path)
|
||||
saved = json.loads(config_path.read_text(encoding="utf-8"))
|
||||
|
||||
tools = saved["tools"]
|
||||
assert "myEnabled" not in tools
|
||||
assert "mySet" not in tools
|
||||
assert tools["my"] == {"enable": False, "allowSet": True}
|
||||
|
||||
|
||||
def test_new_my_tool_keys_take_precedence_over_legacy(tmp_path) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"tools": {
|
||||
"myEnabled": False,
|
||||
"mySet": False,
|
||||
"my": {"enable": True, "allowSet": True},
|
||||
}
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
config = load_config(config_path)
|
||||
|
||||
assert config.tools.my.enable is True
|
||||
assert config.tools.my.allow_set is True
|
||||
|
||||
|
||||
def test_load_config_resets_ssrf_whitelist_when_next_config_is_empty(tmp_path) -> None:
|
||||
whitelisted = tmp_path / "whitelisted.json"
|
||||
whitelisted.write_text(
|
||||
|
||||
@ -584,6 +584,78 @@ def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -
|
||||
assert sanitized[2]["tool_call_id"] == "3ec83c30d"
|
||||
|
||||
|
||||
def test_openai_compat_stringifies_dict_tool_arguments() -> None:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
sanitized = provider._sanitize_messages([
|
||||
{"role": "user", "content": "hi"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "exec", "arguments": {"cmd": "ls -la"}},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_1", "name": "exec", "content": "ok"},
|
||||
{"role": "user", "content": "done"},
|
||||
])
|
||||
|
||||
assert sanitized[1]["tool_calls"][0]["function"]["arguments"] == '{"cmd": "ls -la"}'
|
||||
|
||||
|
||||
def test_openai_compat_repairs_non_json_tool_arguments_string() -> None:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
sanitized = provider._sanitize_messages([
|
||||
{"role": "user", "content": "hi"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "exec", "arguments": "{'cmd': 'pwd'}"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_1", "name": "exec", "content": "ok"},
|
||||
{"role": "user", "content": "done"},
|
||||
])
|
||||
|
||||
assert sanitized[1]["tool_calls"][0]["function"]["arguments"] == '{"cmd": "pwd"}'
|
||||
|
||||
|
||||
def test_openai_compat_defaults_missing_tool_arguments_to_empty_object() -> None:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
sanitized = provider._sanitize_messages([
|
||||
{"role": "user", "content": "hi"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "exec"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_1", "name": "exec", "content": "ok"},
|
||||
{"role": "user", "content": "done"},
|
||||
])
|
||||
|
||||
assert sanitized[1]["tool_calls"][0]["function"]["arguments"] == "{}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_compat_stream_watchdog_returns_error_on_stall(monkeypatch) -> None:
|
||||
monkeypatch.setenv("NANOBOT_STREAM_IDLE_TIMEOUT_S", "0")
|
||||
@ -658,3 +730,50 @@ def test_openai_no_thinking_extra_body() -> None:
|
||||
"""Non-thinking providers should never get extra_body for thinking."""
|
||||
kw = _build_kwargs_for("openai", "gpt-4o", reasoning_effort="medium")
|
||||
assert "extra_body" not in kw
|
||||
|
||||
|
||||
def test_kimi_k25_thinking_enabled() -> None:
|
||||
"""kimi-k2.5 with reasoning_effort set should opt in to thinking."""
|
||||
kw = _build_kwargs_for("moonshot", "kimi-k2.5", reasoning_effort="medium")
|
||||
assert kw.get("extra_body") == {"thinking": {"type": "enabled"}}
|
||||
|
||||
|
||||
def test_kimi_k25_thinking_disabled_for_minimal() -> None:
|
||||
"""reasoning_effort='minimal' maps to thinking disabled for kimi-k2.5."""
|
||||
kw = _build_kwargs_for("moonshot", "kimi-k2.5", reasoning_effort="minimal")
|
||||
assert kw.get("extra_body") == {"thinking": {"type": "disabled"}}
|
||||
|
||||
|
||||
def test_kimi_k25_no_extra_body_when_reasoning_effort_none() -> None:
|
||||
"""Without reasoning_effort the thinking param must not be injected."""
|
||||
kw = _build_kwargs_for("moonshot", "kimi-k2.5", reasoning_effort=None)
|
||||
assert "extra_body" not in kw
|
||||
|
||||
|
||||
def test_kimi_k25_thinking_enabled_with_openrouter_prefix() -> None:
|
||||
"""OpenRouter-style model names like moonshotai/kimi-k2.5 must trigger thinking."""
|
||||
kw = _build_kwargs_for("openrouter", "moonshotai/kimi-k2.5", reasoning_effort="medium")
|
||||
assert kw.get("extra_body") == {"thinking": {"type": "enabled"}}
|
||||
|
||||
def test_kimi_k25_thinking_disabled_with_openrouter_prefix() -> None:
|
||||
"""OpenRouter names must NOT trigger thinking without reasoning_effort."""
|
||||
kw = _build_kwargs_for("openrouter", "moonshotai/kimi-k2.5", reasoning_effort=None)
|
||||
assert "extra_body" not in kw
|
||||
|
||||
|
||||
def test_kimi_k26_code_preview_thinking_enabled() -> None:
|
||||
"""k2.6-code-preview also supports thinking; should behave like k2.5."""
|
||||
kw = _build_kwargs_for("moonshot", "k2.6-code-preview", reasoning_effort="high")
|
||||
assert kw.get("extra_body") == {"thinking": {"type": "enabled"}}
|
||||
|
||||
|
||||
def test_kimi_k2_series_no_thinking_injection() -> None:
|
||||
"""kimi-k2 (non-thinking) models must NOT receive extra_body.thinking."""
|
||||
kw = _build_kwargs_for("moonshot", "kimi-k2", reasoning_effort="high")
|
||||
assert "extra_body" not in kw
|
||||
|
||||
|
||||
def test_kimi_k2_thinking_series_no_thinking_injection() -> None:
|
||||
"""kimi-k2-thinking series models must NOT receive extra_body.thinking."""
|
||||
kw = _build_kwargs_for("moonshot", "kimi-k2-thinking", reasoning_effort="high")
|
||||
assert "extra_body" not in kw
|
||||
|
||||
21
tests/providers/test_minimax_anthropic_provider.py
Normal file
21
tests/providers/test_minimax_anthropic_provider.py
Normal file
@ -0,0 +1,21 @@
|
||||
"""Tests for the MiniMax Anthropic provider registration."""
|
||||
|
||||
from nanobot.config.schema import ProvidersConfig
|
||||
from nanobot.providers.registry import PROVIDERS
|
||||
|
||||
|
||||
def test_minimax_anthropic_config_field_exists():
|
||||
"""ProvidersConfig should expose a minimax_anthropic field."""
|
||||
config = ProvidersConfig()
|
||||
assert hasattr(config, "minimax_anthropic")
|
||||
|
||||
|
||||
def test_minimax_anthropic_provider_in_registry():
|
||||
"""MiniMax Anthropic endpoint should be registered with Anthropic backend."""
|
||||
specs = {s.name: s for s in PROVIDERS}
|
||||
assert "minimax_anthropic" in specs
|
||||
|
||||
minimax_anthropic = specs["minimax_anthropic"]
|
||||
assert minimax_anthropic.env_key == "MINIMAX_API_KEY"
|
||||
assert minimax_anthropic.backend == "anthropic"
|
||||
assert minimax_anthropic.default_api_base == "https://api.minimax.io/anthropic"
|
||||
@ -87,6 +87,33 @@ async def test_chat_with_retry_returns_final_error_after_retries(monkeypatch) ->
|
||||
assert delays == [1, 2, 4]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_retry_emits_terminal_progress_when_standard_retries_exhaust(monkeypatch) -> None:
|
||||
provider = ScriptedProvider([
|
||||
LLMResponse(content="429 rate limit a", finish_reason="error"),
|
||||
LLMResponse(content="429 rate limit b", finish_reason="error"),
|
||||
LLMResponse(content="429 rate limit c", finish_reason="error"),
|
||||
LLMResponse(content="503 final server error", finish_reason="error"),
|
||||
])
|
||||
progress: list[str] = []
|
||||
|
||||
async def _fake_sleep(delay: int) -> None:
|
||||
return None
|
||||
|
||||
async def _progress(msg: str) -> None:
|
||||
progress.append(msg)
|
||||
|
||||
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||
|
||||
response = await provider.chat_with_retry(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
on_retry_wait=_progress,
|
||||
)
|
||||
|
||||
assert response.content == "503 final server error"
|
||||
assert progress[-1] == "Model request failed after 4 retries, giving up."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_retry_preserves_cancelled_error() -> None:
|
||||
provider = ScriptedProvider([asyncio.CancelledError()])
|
||||
@ -469,3 +496,67 @@ async def test_persistent_retry_aborts_after_ten_identical_transient_errors(monk
|
||||
assert response.content == "429 rate limit"
|
||||
assert provider.calls == 10
|
||||
assert delays == [1, 2, 4, 4, 4, 4, 4, 4, 4]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persistent_retry_emits_terminal_progress_on_identical_error_limit(monkeypatch) -> None:
|
||||
provider = ScriptedProvider([
|
||||
*[LLMResponse(content="429 rate limit", finish_reason="error") for _ in range(10)],
|
||||
])
|
||||
progress: list[str] = []
|
||||
|
||||
async def _fake_sleep(delay: float) -> None:
|
||||
return None
|
||||
|
||||
async def _progress(msg: str) -> None:
|
||||
progress.append(msg)
|
||||
|
||||
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||
|
||||
response = await provider.chat_with_retry(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
retry_mode="persistent",
|
||||
on_retry_wait=_progress,
|
||||
)
|
||||
|
||||
assert response.finish_reason == "error"
|
||||
assert progress[-1] == "Persistent retry stopped after 10 identical errors."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_retry_normalizes_explicit_none_max_tokens() -> None:
|
||||
"""Explicit max_tokens=None must fall back to generation defaults.
|
||||
|
||||
Regression for #3102: callers that construct AgentRunSpec with
|
||||
max_tokens=None propagate None into chat_with_retry, which used to
|
||||
reach ``_build_kwargs`` and crash on ``max(1, None)``.
|
||||
"""
|
||||
provider = ScriptedProvider([LLMResponse(content="ok")])
|
||||
|
||||
response = await provider.chat_with_retry(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
max_tokens=None,
|
||||
temperature=None,
|
||||
)
|
||||
|
||||
assert response.content == "ok"
|
||||
# Generation settings default to 4096 / 0.7; explicit None should
|
||||
# have been replaced before reaching chat().
|
||||
assert provider.last_kwargs["max_tokens"] == 4096
|
||||
assert provider.last_kwargs["temperature"] == 0.7
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_stream_with_retry_normalizes_explicit_none_max_tokens() -> None:
|
||||
"""chat_stream_with_retry must apply the same None-guard as chat_with_retry."""
|
||||
provider = ScriptedProvider([LLMResponse(content="ok")])
|
||||
|
||||
response = await provider.chat_stream_with_retry(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
max_tokens=None,
|
||||
temperature=None,
|
||||
)
|
||||
|
||||
assert response.content == "ok"
|
||||
assert provider.last_kwargs["max_tokens"] == 4096
|
||||
assert provider.last_kwargs["temperature"] == 0.7
|
||||
|
||||
496
tests/test_api_attachment.py
Normal file
496
tests/test_api_attachment.py
Normal file
@ -0,0 +1,496 @@
|
||||
"""Tests for API file upload functionality (JSON base64 + multipart)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from nanobot.api.server import (
|
||||
_FileSizeExceeded,
|
||||
_parse_json_content,
|
||||
_save_base64_data_url,
|
||||
create_app,
|
||||
)
|
||||
from nanobot.utils.document import extract_documents
|
||||
|
||||
try:
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
HAS_AIOHTTP = True
|
||||
except ImportError:
|
||||
HAS_AIOHTTP = False
|
||||
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
|
||||
|
||||
def _make_mock_agent(response_text: str = "mock response") -> MagicMock:
|
||||
agent = MagicMock()
|
||||
agent.process_direct = AsyncMock(return_value=response_text)
|
||||
agent._connect_mcp = AsyncMock()
|
||||
agent.close_mcp = AsyncMock()
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
return _make_mock_agent()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(mock_agent):
|
||||
return create_app(mock_agent, model_name="test-model", request_timeout=10.0)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def aiohttp_client():
|
||||
clients: list[TestClient] = []
|
||||
|
||||
async def _make_client(app):
|
||||
client = TestClient(TestServer(app))
|
||||
await client.start_server()
|
||||
clients.append(client)
|
||||
return client
|
||||
|
||||
try:
|
||||
yield _make_client
|
||||
finally:
|
||||
for client in clients:
|
||||
await client.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper function tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_save_base64_data_url_saves_png(tmp_path) -> None:
|
||||
"""Saving a base64 data URL creates a file with correct extension."""
|
||||
b64_data = base64.b64encode(b"fake png data").decode()
|
||||
data_url = f"data:image/png;base64,{b64_data}"
|
||||
result = _save_base64_data_url(data_url, tmp_path)
|
||||
assert result is not None
|
||||
assert result.endswith(".png")
|
||||
assert (tmp_path / result.replace(str(tmp_path) + "/", "")).read_bytes() == b"fake png data"
|
||||
|
||||
|
||||
def test_save_base64_data_url_handles_invalid_b64(tmp_path) -> None:
|
||||
"""Invalid base64 returns None."""
|
||||
result = _save_base64_data_url("data:image/png;base64,not-valid-base64!!!", tmp_path)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_save_base64_data_url_handles_unknown_mime(tmp_path) -> None:
|
||||
"""Unknown MIME type defaults to .bin."""
|
||||
b64_data = base64.b64encode(b"some data").decode()
|
||||
data_url = f"data:unknown/type;base64,{b64_data}"
|
||||
result = _save_base64_data_url(data_url, tmp_path)
|
||||
assert result is not None
|
||||
assert result.endswith(".bin")
|
||||
|
||||
|
||||
def test_save_base64_data_url_rejects_oversized_payload(tmp_path) -> None:
|
||||
"""Base64 uploads should respect the same per-file limit as multipart."""
|
||||
large_payload = base64.b64encode(b"x" * (11 * 1024 * 1024)).decode()
|
||||
data_url = f"data:image/png;base64,{large_payload}"
|
||||
|
||||
with pytest.raises(_FileSizeExceeded, match="10MB limit"):
|
||||
_save_base64_data_url(data_url, tmp_path)
|
||||
|
||||
|
||||
def test_parse_json_content_extracts_text_and_media(tmp_path) -> None:
|
||||
"""Parse JSON with text + base64 image saves image and returns paths."""
|
||||
b64_data = base64.b64encode(b"img").decode()
|
||||
body = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "describe this"},
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64_data}"}},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
import os
|
||||
original_cwd = os.getcwd()
|
||||
os.chdir(tmp_path)
|
||||
|
||||
try:
|
||||
text, media_paths = _parse_json_content(body)
|
||||
assert text == "describe this"
|
||||
assert len(media_paths) == 1
|
||||
finally:
|
||||
os.chdir(original_cwd)
|
||||
|
||||
|
||||
def test_parse_json_content_plain_text_only() -> None:
|
||||
"""Plain text string content returns no media."""
|
||||
body = {"messages": [{"role": "user", "content": "hello"}]}
|
||||
text, media_paths = _parse_json_content(body)
|
||||
assert text == "hello"
|
||||
assert media_paths == []
|
||||
|
||||
|
||||
def test_parse_json_content_validates_single_message() -> None:
|
||||
"""Multiple messages raise ValueError."""
|
||||
body = {
|
||||
"messages": [
|
||||
{"role": "user", "content": "first"},
|
||||
{"role": "user", "content": "second"},
|
||||
]
|
||||
}
|
||||
with pytest.raises(ValueError, match="single user message"):
|
||||
_parse_json_content(body)
|
||||
|
||||
|
||||
def test_parse_json_content_validates_user_role() -> None:
|
||||
"""Non-user role raises ValueError."""
|
||||
body = {"messages": [{"role": "system", "content": "you are a bot"}]}
|
||||
with pytest.raises(ValueError, match="single user message"):
|
||||
_parse_json_content(body)
|
||||
|
||||
|
||||
def test_parse_json_content_rejects_oversized_base64_file(tmp_path) -> None:
|
||||
"""Oversized JSON data URLs should fail before writing to disk."""
|
||||
large_payload = base64.b64encode(b"x" * (11 * 1024 * 1024)).decode()
|
||||
body = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "describe"},
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{large_payload}"}},
|
||||
],
|
||||
}
|
||||
]
|
||||
}
|
||||
import os
|
||||
original_cwd = os.getcwd()
|
||||
os.chdir(tmp_path)
|
||||
|
||||
try:
|
||||
with pytest.raises(_FileSizeExceeded, match="10MB limit"):
|
||||
_parse_json_content(body)
|
||||
finally:
|
||||
os.chdir(original_cwd)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Multipart upload tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_multipart_upload_saves_file(aiohttp_client, mock_agent, tmp_path) -> None:
|
||||
"""Multipart upload saves file to media dir and passes path to process_direct."""
|
||||
import os
|
||||
original_cwd = os.getcwd()
|
||||
os.chdir(tmp_path)
|
||||
|
||||
try:
|
||||
app = create_app(mock_agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
file_data = b"test file content"
|
||||
data = BytesIO(file_data)
|
||||
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
data={"message": "analyze this", "files": data},
|
||||
)
|
||||
assert resp.status == 200
|
||||
call_kwargs = mock_agent.process_direct.call_args.kwargs
|
||||
assert call_kwargs["content"] == "analyze this"
|
||||
assert len(call_kwargs.get("media") or []) == 1
|
||||
finally:
|
||||
os.chdir(original_cwd)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_multipart_multiple_files(aiohttp_client, mock_agent, tmp_path) -> None:
|
||||
"""Multipart upload with multiple files saves all and passes paths."""
|
||||
import os
|
||||
original_cwd = os.getcwd()
|
||||
os.chdir(tmp_path)
|
||||
|
||||
try:
|
||||
app = create_app(mock_agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
# Note: aiohttp test client has limited multipart support
|
||||
# This test verifies the basic flow
|
||||
file_data = b"test content"
|
||||
data = BytesIO(file_data)
|
||||
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
data={"message": "analyze", "files": data},
|
||||
)
|
||||
assert resp.status == 200
|
||||
finally:
|
||||
os.chdir(original_cwd)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_multipart_file_size_limit(aiohttp_client, mock_agent, tmp_path) -> None:
|
||||
"""File exceeding MAX_FILE_SIZE returns 413."""
|
||||
import os
|
||||
original_cwd = os.getcwd()
|
||||
os.chdir(tmp_path)
|
||||
|
||||
try:
|
||||
app = create_app(mock_agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
# Create a file larger than 10MB
|
||||
large_data = b"x" * (11 * 1024 * 1024)
|
||||
data = BytesIO(large_data)
|
||||
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
data={"message": "analyze", "files": data},
|
||||
)
|
||||
assert resp.status == 413
|
||||
finally:
|
||||
os.chdir(original_cwd)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_multipart_defaults_text_when_missing(aiohttp_client, mock_agent, tmp_path) -> None:
|
||||
"""Multipart without message field uses default text."""
|
||||
import os
|
||||
original_cwd = os.getcwd()
|
||||
os.chdir(tmp_path)
|
||||
|
||||
try:
|
||||
app = create_app(mock_agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
file_data = b"content"
|
||||
data = BytesIO(file_data)
|
||||
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
data={"files": data},
|
||||
)
|
||||
assert resp.status == 200
|
||||
call_kwargs = mock_agent.process_direct.call_args.kwargs
|
||||
assert call_kwargs["content"] == "请分析上传的文件"
|
||||
finally:
|
||||
os.chdir(original_cwd)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_multipart_with_session_id(aiohttp_client, mock_agent, tmp_path) -> None:
|
||||
"""Multipart upload with session_id uses custom session key."""
|
||||
import os
|
||||
original_cwd = os.getcwd()
|
||||
os.chdir(tmp_path)
|
||||
|
||||
try:
|
||||
app = create_app(mock_agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
file_data = b"content"
|
||||
data = BytesIO(file_data)
|
||||
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
data={"message": "hello", "session_id": "my-session", "files": data},
|
||||
)
|
||||
assert resp.status == 200
|
||||
call_kwargs = mock_agent.process_direct.call_args.kwargs
|
||||
assert call_kwargs["session_key"] == "api:my-session"
|
||||
finally:
|
||||
os.chdir(original_cwd)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backward compatibility tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_plain_text_backward_compat(aiohttp_client, mock_agent) -> None:
|
||||
"""Plain text JSON request (no media) works as before."""
|
||||
app = create_app(mock_agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hello world"}]},
|
||||
)
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["choices"][0]["message"]["content"] == "mock response"
|
||||
call_kwargs = mock_agent.process_direct.call_args.kwargs
|
||||
assert call_kwargs["content"] == "hello world"
|
||||
assert call_kwargs.get("media") is None
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_json_base64_image_upload(aiohttp_client, mock_agent, tmp_path) -> None:
|
||||
"""JSON request with base64 data URL saves file and passes path."""
|
||||
import os
|
||||
original_cwd = os.getcwd()
|
||||
os.chdir(tmp_path)
|
||||
|
||||
try:
|
||||
app = create_app(mock_agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
# Use valid base64 for a tiny PNG (1x1 transparent pixel)
|
||||
tiny_png_b64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg=="
|
||||
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "what is this"},
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{tiny_png_b64}"}},
|
||||
],
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
assert resp.status == 200
|
||||
call_kwargs = mock_agent.process_direct.call_args.kwargs
|
||||
assert call_kwargs["content"] == "what is this"
|
||||
assert len(call_kwargs.get("media", [])) == 1
|
||||
finally:
|
||||
os.chdir(original_cwd)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extract_documents tests (now in nanobot.utils.document)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_extract_documents_separates_images_from_docs(tmp_path) -> None:
|
||||
"""Images stay in media; document text is appended to content."""
|
||||
from docx import Document
|
||||
|
||||
png = tmp_path / "chart.png"
|
||||
png.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)
|
||||
|
||||
doc = Document()
|
||||
doc.add_paragraph("Quarterly revenue is $5M")
|
||||
docx_path = tmp_path / "report.docx"
|
||||
doc.save(docx_path)
|
||||
|
||||
text, image_paths = extract_documents("summarize", [str(png), str(docx_path)])
|
||||
assert len(image_paths) == 1
|
||||
assert image_paths[0] == str(png)
|
||||
assert "Quarterly revenue" in text
|
||||
assert "summarize" in text
|
||||
|
||||
|
||||
def test_extract_documents_skips_extraction_errors(tmp_path, monkeypatch) -> None:
|
||||
"""Document extraction errors should not leak into user text."""
|
||||
bad_file = tmp_path / "broken.docx"
|
||||
bad_file.write_text("not a docx", encoding="utf-8")
|
||||
|
||||
import nanobot.utils.document as _doc
|
||||
monkeypatch.setattr(
|
||||
_doc, "extract_text",
|
||||
lambda _path: "[error: failed to extract DOCX: boom]",
|
||||
)
|
||||
|
||||
text, image_paths = extract_documents("hello", [str(bad_file)])
|
||||
assert text == "hello"
|
||||
assert image_paths == []
|
||||
|
||||
|
||||
def test_extract_documents_images_only(tmp_path) -> None:
|
||||
"""When all files are images, text is unchanged and all paths kept."""
|
||||
png = tmp_path / "a.png"
|
||||
png.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)
|
||||
text, image_paths = extract_documents("describe", [str(png)])
|
||||
assert text == "describe"
|
||||
assert len(image_paths) == 1
|
||||
|
||||
|
||||
def test_extract_documents_skips_oversized_files(tmp_path) -> None:
|
||||
"""Files exceeding the size limit should be silently skipped."""
|
||||
big = tmp_path / "huge.txt"
|
||||
big.write_bytes(b"x" * 200)
|
||||
|
||||
text, image_paths = extract_documents("hello", [str(big)], max_file_size=100)
|
||||
assert text == "hello"
|
||||
assert image_paths == []
|
||||
|
||||
|
||||
def test_extract_documents_does_not_read_full_file_for_mime(tmp_path) -> None:
|
||||
"""MIME detection should only read header bytes, not the entire file."""
|
||||
from pathlib import Path as _Path
|
||||
|
||||
big_txt = tmp_path / "big.txt"
|
||||
big_txt.write_bytes(b"hello world " * 100_000) # ~1.2 MB
|
||||
|
||||
original_read_bytes = _Path.read_bytes
|
||||
read_sizes: list[int] = []
|
||||
|
||||
def _tracking_read_bytes(self):
|
||||
data = original_read_bytes(self)
|
||||
read_sizes.append(len(data))
|
||||
return data
|
||||
|
||||
import unittest.mock
|
||||
with unittest.mock.patch.object(_Path, "read_bytes", _tracking_read_bytes):
|
||||
extract_documents("test", [str(big_txt)])
|
||||
|
||||
# If the full file was read for MIME detection, read_sizes would
|
||||
# contain a >1MB entry. After the fix, only a small header is read.
|
||||
assert all(size <= 4096 for size in read_sizes), (
|
||||
f"extract_documents read full file for MIME detection: sizes={read_sizes}"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DOCX upload test — API saves file, loop layer extracts text
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_docx_upload_passes_media_path(aiohttp_client, tmp_path) -> None:
|
||||
"""Uploaded DOCX is saved to disk and its path passed as media.
|
||||
(Text extraction happens later in AgentLoop._process_message.)"""
|
||||
agent = _make_mock_agent("report summary")
|
||||
import os
|
||||
original_cwd = os.getcwd()
|
||||
os.chdir(tmp_path)
|
||||
|
||||
try:
|
||||
app = create_app(agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
from docx import Document
|
||||
doc = Document()
|
||||
doc.add_paragraph("Total revenue: $5,000,000")
|
||||
buf = BytesIO()
|
||||
doc.save(buf)
|
||||
|
||||
import aiohttp
|
||||
data = aiohttp.FormData()
|
||||
data.add_field("message", "summarize the report")
|
||||
data.add_field("files", buf.getvalue(), filename="report.docx",
|
||||
content_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document")
|
||||
|
||||
resp = await client.post("/v1/chat/completions", data=data)
|
||||
assert resp.status == 200
|
||||
call_kwargs = agent.process_direct.call_args.kwargs
|
||||
assert call_kwargs["content"] == "summarize the report"
|
||||
media = call_kwargs.get("media", [])
|
||||
assert len(media) == 1
|
||||
assert "report.docx" in media[0]
|
||||
finally:
|
||||
os.chdir(original_cwd)
|
||||
253
tests/test_api_stream.py
Normal file
253
tests/test_api_stream.py
Normal file
@ -0,0 +1,253 @@
|
||||
"""Tests for SSE streaming support in /v1/chat/completions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from nanobot.api.server import (
|
||||
_sse_chunk,
|
||||
_SSE_DONE,
|
||||
create_app,
|
||||
)
|
||||
|
||||
try:
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
HAS_AIOHTTP = True
|
||||
except ImportError:
|
||||
HAS_AIOHTTP = False
|
||||
|
||||
pytest_plugins = ("pytest_asyncio",)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests for SSE helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_sse_chunk_with_delta() -> None:
|
||||
raw = _sse_chunk("hello", "test-model", "chatcmpl-abc123")
|
||||
line = raw.decode()
|
||||
assert line.startswith("data: ")
|
||||
payload = json.loads(line[len("data: "):])
|
||||
assert payload["id"] == "chatcmpl-abc123"
|
||||
assert payload["object"] == "chat.completion.chunk"
|
||||
assert payload["model"] == "test-model"
|
||||
assert payload["choices"][0]["delta"]["content"] == "hello"
|
||||
assert payload["choices"][0]["finish_reason"] is None
|
||||
|
||||
|
||||
def test_sse_chunk_finish_reason() -> None:
|
||||
raw = _sse_chunk("", "m", "id1", finish_reason="stop")
|
||||
payload = json.loads(raw.decode().split("data: ", 1)[1])
|
||||
assert payload["choices"][0]["delta"] == {}
|
||||
assert payload["choices"][0]["finish_reason"] == "stop"
|
||||
|
||||
|
||||
def test_sse_done_format() -> None:
|
||||
assert _SSE_DONE == b"data: [DONE]\n\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration tests with aiohttp TestClient
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_streaming_agent(tokens: list[str]) -> MagicMock:
|
||||
"""Create a mock agent that streams tokens via on_stream callback."""
|
||||
agent = MagicMock()
|
||||
agent._connect_mcp = AsyncMock()
|
||||
agent.close_mcp = AsyncMock()
|
||||
|
||||
async def fake_process_direct(*, content="", media=None, session_key="",
|
||||
channel="", chat_id="", on_stream=None,
|
||||
on_stream_end=None, **kwargs):
|
||||
if on_stream:
|
||||
for token in tokens:
|
||||
await on_stream(token)
|
||||
if on_stream_end:
|
||||
await on_stream_end()
|
||||
return " ".join(tokens)
|
||||
|
||||
agent.process_direct = fake_process_direct
|
||||
return agent
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def aiohttp_client():
|
||||
clients: list[TestClient] = []
|
||||
|
||||
async def _make_client(app):
|
||||
client = TestClient(TestServer(app))
|
||||
await client.start_server()
|
||||
clients.append(client)
|
||||
return client
|
||||
|
||||
try:
|
||||
yield _make_client
|
||||
finally:
|
||||
for client in clients:
|
||||
await client.close()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_true_returns_sse(aiohttp_client) -> None:
|
||||
"""stream=true should return text/event-stream with SSE chunks."""
|
||||
agent = _make_streaming_agent(["Hello", " world"])
|
||||
app = create_app(agent, model_name="test-model")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hi"}], "stream": True},
|
||||
)
|
||||
assert resp.status == 200
|
||||
assert resp.content_type == "text/event-stream"
|
||||
|
||||
body = await resp.text()
|
||||
lines = [l for l in body.split("\n") if l.startswith("data: ")]
|
||||
|
||||
# Should have: 2 token chunks + 1 finish chunk + [DONE]
|
||||
data_lines = [l[len("data: "):] for l in lines]
|
||||
assert data_lines[-1] == "[DONE]"
|
||||
|
||||
chunks = [json.loads(l) for l in data_lines[:-1]]
|
||||
assert chunks[0]["choices"][0]["delta"]["content"] == "Hello"
|
||||
assert chunks[1]["choices"][0]["delta"]["content"] == " world"
|
||||
# Last chunk before [DONE] should have finish_reason=stop
|
||||
assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
|
||||
assert chunks[-1]["choices"][0]["delta"] == {}
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_false_returns_json(aiohttp_client) -> None:
|
||||
"""stream=false should still return regular JSON response."""
|
||||
agent = MagicMock()
|
||||
agent.process_direct = AsyncMock(return_value="normal reply")
|
||||
agent._connect_mcp = AsyncMock()
|
||||
agent.close_mcp = AsyncMock()
|
||||
|
||||
app = create_app(agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hi"}], "stream": False},
|
||||
)
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["object"] == "chat.completion"
|
||||
assert body["choices"][0]["message"]["content"] == "normal reply"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_default_is_false(aiohttp_client) -> None:
|
||||
"""Omitting stream should behave like stream=false."""
|
||||
agent = MagicMock()
|
||||
agent.process_direct = AsyncMock(return_value="default reply")
|
||||
agent._connect_mcp = AsyncMock()
|
||||
agent.close_mcp = AsyncMock()
|
||||
|
||||
app = create_app(agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hi"}]},
|
||||
)
|
||||
assert resp.status == 200
|
||||
body = await resp.json()
|
||||
assert body["object"] == "chat.completion"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_sse_chunk_ids_are_consistent(aiohttp_client) -> None:
|
||||
"""All SSE chunks in a single stream should share the same id."""
|
||||
agent = _make_streaming_agent(["A", "B", "C"])
|
||||
app = create_app(agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "go"}], "stream": True},
|
||||
)
|
||||
body = await resp.text()
|
||||
data_lines = [l[len("data: "):] for l in body.split("\n") if l.startswith("data: ") and l != "data: [DONE]"]
|
||||
chunks = [json.loads(l) for l in data_lines]
|
||||
|
||||
chunk_ids = {c["id"] for c in chunks}
|
||||
assert len(chunk_ids) == 1, f"Expected single chunk id, got {chunk_ids}"
|
||||
assert chunk_ids.pop().startswith("chatcmpl-")
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_passes_on_stream_callbacks(aiohttp_client) -> None:
|
||||
"""process_direct should be called with on_stream and on_stream_end when streaming."""
|
||||
captured_kwargs: dict = {}
|
||||
|
||||
async def fake_process_direct(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
if kwargs.get("on_stream_end"):
|
||||
await kwargs["on_stream_end"]()
|
||||
return "done"
|
||||
|
||||
agent = MagicMock()
|
||||
agent.process_direct = fake_process_direct
|
||||
agent._connect_mcp = AsyncMock()
|
||||
agent.close_mcp = AsyncMock()
|
||||
|
||||
app = create_app(agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hi"}], "stream": True},
|
||||
)
|
||||
assert resp.status == 200
|
||||
assert captured_kwargs.get("on_stream") is not None
|
||||
assert captured_kwargs.get("on_stream_end") is not None
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_with_session_id(aiohttp_client) -> None:
|
||||
"""Streaming should respect session_id for session key routing."""
|
||||
captured_key: str = ""
|
||||
|
||||
async def fake_process_direct(*, session_key="", on_stream=None, on_stream_end=None, **kwargs):
|
||||
nonlocal captured_key
|
||||
captured_key = session_key
|
||||
if on_stream:
|
||||
await on_stream("ok")
|
||||
if on_stream_end:
|
||||
await on_stream_end()
|
||||
return "ok"
|
||||
|
||||
agent = MagicMock()
|
||||
agent.process_direct = fake_process_direct
|
||||
agent._connect_mcp = AsyncMock()
|
||||
agent.close_mcp = AsyncMock()
|
||||
|
||||
app = create_app(agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"messages": [{"role": "user", "content": "hi"}],
|
||||
"stream": True,
|
||||
"session_id": "my-session",
|
||||
},
|
||||
)
|
||||
assert resp.status == 200
|
||||
assert captured_key == "api:my-session"
|
||||
@ -15,6 +15,7 @@ def test_status_shows_cache_hit_rate():
|
||||
)
|
||||
assert "60% cached" in content
|
||||
assert "2000 in / 300 out" in content
|
||||
assert "Tasks: 0 active" in content
|
||||
|
||||
|
||||
def test_status_no_cache_info():
|
||||
@ -30,6 +31,7 @@ def test_status_no_cache_info():
|
||||
)
|
||||
assert "cached" not in content.lower()
|
||||
assert "2000 in / 300 out" in content
|
||||
assert "Tasks: 0 active" in content
|
||||
|
||||
|
||||
def test_status_zero_cached_tokens():
|
||||
@ -57,3 +59,34 @@ def test_status_100_percent_cached():
|
||||
context_tokens_estimate=3000,
|
||||
)
|
||||
assert "100% cached" in content
|
||||
|
||||
|
||||
def test_status_context_pct_uses_budget_not_total():
|
||||
"""Percentage should be calculated against input budget, not raw context window."""
|
||||
content = build_status_content(
|
||||
version="0.1.0",
|
||||
model="test",
|
||||
start_time=1000000.0,
|
||||
last_usage={"prompt_tokens": 2000, "completion_tokens": 300},
|
||||
context_window_tokens=128000,
|
||||
session_msg_count=10,
|
||||
context_tokens_estimate=120000,
|
||||
max_completion_tokens=8192,
|
||||
)
|
||||
# budget = 128000 - 8192 - 1024 = 118784; pct = 120000/118784*100 ≈ 101%
|
||||
assert "(101% of input budget)" in content
|
||||
|
||||
|
||||
def test_status_context_pct_capped_at_999():
|
||||
"""Extreme overflow should be capped at 999."""
|
||||
content = build_status_content(
|
||||
version="0.1.0",
|
||||
model="test",
|
||||
start_time=1000000.0,
|
||||
last_usage={"prompt_tokens": 2000, "completion_tokens": 300},
|
||||
context_window_tokens=10000,
|
||||
session_msg_count=10,
|
||||
context_tokens_estimate=100000,
|
||||
max_completion_tokens=4096,
|
||||
)
|
||||
assert "(999% of input budget)" in content
|
||||
|
||||
113
tests/test_context_documents.py
Normal file
113
tests/test_context_documents.py
Normal file
@ -0,0 +1,113 @@
|
||||
"""Tests for context builder media handling.
|
||||
|
||||
The ContextBuilder._build_user_content method should ONLY handle images.
|
||||
Document text extraction is the responsibility of the processing layer
|
||||
(AgentLoop._process_message and _drain_pending).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.utils.document import extract_documents
|
||||
|
||||
|
||||
def _make_builder(tmp_path: Path) -> ContextBuilder:
|
||||
"""Create a minimal ContextBuilder for testing."""
|
||||
return ContextBuilder(workspace=tmp_path, timezone="UTC")
|
||||
|
||||
|
||||
def test_build_user_content_with_no_media_returns_string(tmp_path: Path) -> None:
|
||||
builder = _make_builder(tmp_path)
|
||||
result = builder._build_user_content("hello", None)
|
||||
assert result == "hello"
|
||||
|
||||
|
||||
def test_build_user_content_with_image_returns_list(tmp_path: Path) -> None:
|
||||
"""Image files should produce base64 content blocks."""
|
||||
builder = _make_builder(tmp_path)
|
||||
png = tmp_path / "test.png"
|
||||
png.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)
|
||||
result = builder._build_user_content("describe this", [str(png)])
|
||||
assert isinstance(result, list)
|
||||
types = [b["type"] for b in result]
|
||||
assert "image_url" in types
|
||||
assert "text" in types
|
||||
|
||||
|
||||
def test_build_user_content_ignores_non_image_files(tmp_path: Path) -> None:
|
||||
"""Non-image files should be silently skipped — extraction is not context builder's job."""
|
||||
builder = _make_builder(tmp_path)
|
||||
txt = tmp_path / "notes.txt"
|
||||
txt.write_text("some text", encoding="utf-8")
|
||||
result = builder._build_user_content("summarize", [str(txt)])
|
||||
assert result == "summarize"
|
||||
|
||||
|
||||
def test_build_user_content_mixed_image_and_non_image(tmp_path: Path) -> None:
|
||||
"""Only images should be included; non-image files are skipped."""
|
||||
builder = _make_builder(tmp_path)
|
||||
png = tmp_path / "chart.png"
|
||||
png.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100)
|
||||
txt = tmp_path / "report.txt"
|
||||
txt.write_text("report text", encoding="utf-8")
|
||||
|
||||
result = builder._build_user_content("analyze", [str(png), str(txt)])
|
||||
assert isinstance(result, list)
|
||||
assert any(b["type"] == "image_url" for b in result)
|
||||
text_parts = [b.get("text", "") for b in result if b.get("type") == "text"]
|
||||
assert all("report text" not in t for t in text_parts)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bug detection: extract_documents must be called BEFORE _build_user_content
|
||||
# to prevent document media from being silently dropped.
|
||||
# This simulates the _drain_pending code path.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_drain_pending_path_preserves_document_text(tmp_path: Path) -> None:
|
||||
"""Simulates the _drain_pending path: a pending follow-up message
|
||||
with a document attachment must have its text extracted before being
|
||||
passed to _build_user_content. Without extract_documents, the
|
||||
document is silently dropped."""
|
||||
from docx import Document
|
||||
|
||||
doc = Document()
|
||||
doc.add_paragraph("Quarterly revenue is $5M")
|
||||
docx_path = tmp_path / "report.docx"
|
||||
doc.save(docx_path)
|
||||
|
||||
content = "summarize"
|
||||
media = [str(docx_path)]
|
||||
|
||||
# Step 1: extract_documents separates docs from images
|
||||
new_content, image_only = extract_documents(content, media)
|
||||
|
||||
# Step 2: _build_user_content handles only images (none left here)
|
||||
builder = _make_builder(tmp_path)
|
||||
result = builder._build_user_content(new_content, image_only if image_only else None)
|
||||
|
||||
# The document text should be present in the final content
|
||||
assert "Quarterly revenue" in result
|
||||
assert "summarize" in result
|
||||
|
||||
|
||||
def test_drain_pending_path_without_extract_loses_document(tmp_path: Path) -> None:
|
||||
"""Demonstrates the BUG: if _drain_pending calls _build_user_content
|
||||
directly without extract_documents, document content is lost."""
|
||||
from docx import Document
|
||||
|
||||
doc = Document()
|
||||
doc.add_paragraph("Secret data in document")
|
||||
docx_path = tmp_path / "report.docx"
|
||||
doc.save(docx_path)
|
||||
|
||||
builder = _make_builder(tmp_path)
|
||||
|
||||
# Bug path: call _build_user_content directly with document media
|
||||
result = builder._build_user_content("summarize", [str(docx_path)])
|
||||
|
||||
# The document text is LOST — _build_user_content ignores non-images
|
||||
assert result == "summarize" # only the original text, no doc content
|
||||
assert "Secret data" not in result
|
||||
273
tests/test_document_parsing.py
Normal file
273
tests/test_document_parsing.py
Normal file
@ -0,0 +1,273 @@
|
||||
"""Tests for document text extraction utilities."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from nanobot.utils.document import (
|
||||
SUPPORTED_EXTENSIONS,
|
||||
_is_text_extension,
|
||||
extract_text,
|
||||
)
|
||||
|
||||
|
||||
class TestSupportedExtensions:
|
||||
"""Test the SUPPORTED_EXTENSIONS constant."""
|
||||
|
||||
def test_supported_extensions_include_common_formats(self):
|
||||
"""Test that common document formats are included."""
|
||||
# Document formats
|
||||
assert ".pdf" in SUPPORTED_EXTENSIONS
|
||||
assert ".docx" in SUPPORTED_EXTENSIONS
|
||||
assert ".xlsx" in SUPPORTED_EXTENSIONS
|
||||
assert ".pptx" in SUPPORTED_EXTENSIONS
|
||||
|
||||
# Text formats
|
||||
assert ".txt" in SUPPORTED_EXTENSIONS
|
||||
assert ".md" in SUPPORTED_EXTENSIONS
|
||||
assert ".csv" in SUPPORTED_EXTENSIONS
|
||||
assert ".json" in SUPPORTED_EXTENSIONS
|
||||
assert ".yaml" in SUPPORTED_EXTENSIONS
|
||||
assert ".yml" in SUPPORTED_EXTENSIONS
|
||||
|
||||
# Image formats
|
||||
assert ".png" in SUPPORTED_EXTENSIONS
|
||||
assert ".jpg" in SUPPORTED_EXTENSIONS
|
||||
assert ".jpeg" in SUPPORTED_EXTENSIONS
|
||||
|
||||
|
||||
class TestExtractText:
|
||||
"""Test the extract_text function."""
|
||||
|
||||
def test_extract_text_unsupported_returns_none(self, tmp_path: Path):
|
||||
"""Test that unsupported file types return None."""
|
||||
unsupported_file = tmp_path / "file.xyz"
|
||||
unsupported_file.write_text("content")
|
||||
|
||||
result = extract_text(unsupported_file)
|
||||
assert result is None
|
||||
|
||||
def test_extract_text_file_not_found(self, tmp_path: Path):
|
||||
"""Test that non-existent files return error string."""
|
||||
missing_file = tmp_path / "nonexistent.txt"
|
||||
|
||||
result = extract_text(missing_file)
|
||||
assert result is not None
|
||||
assert "[error: file not found:" in result
|
||||
|
||||
def test_extract_text_txt_file(self, tmp_path: Path):
|
||||
"""Test extracting text from a .txt file."""
|
||||
txt_file = tmp_path / "test.txt"
|
||||
content = "Hello, world!\nThis is a test."
|
||||
txt_file.write_text(content, encoding="utf-8")
|
||||
|
||||
result = extract_text(txt_file)
|
||||
assert result == content
|
||||
|
||||
def test_extract_text_txt_file_with_truncation(self, tmp_path: Path):
|
||||
"""Test that large text files are truncated."""
|
||||
txt_file = tmp_path / "large.txt"
|
||||
# Create content larger than _MAX_TEXT_LENGTH
|
||||
content = "x" * 300_000
|
||||
txt_file.write_text(content, encoding="utf-8")
|
||||
|
||||
result = extract_text(txt_file)
|
||||
assert len(result) < 300_000
|
||||
assert "(truncated," in result
|
||||
assert "chars total)" in result
|
||||
|
||||
def test_extract_text_md_file(self, tmp_path: Path):
|
||||
"""Test extracting text from a .md file."""
|
||||
md_file = tmp_path / "test.md"
|
||||
content = "# Header\n\nSome markdown content."
|
||||
md_file.write_text(content, encoding="utf-8")
|
||||
|
||||
result = extract_text(md_file)
|
||||
assert result == content
|
||||
|
||||
def test_extract_text_csv_file(self, tmp_path: Path):
|
||||
"""Test extracting text from a .csv file."""
|
||||
csv_file = tmp_path / "test.csv"
|
||||
content = "name,age\nAlice,30\nBob,25"
|
||||
csv_file.write_text(content, encoding="utf-8")
|
||||
|
||||
result = extract_text(csv_file)
|
||||
assert result == content
|
||||
|
||||
def test_extract_text_json_file(self, tmp_path: Path):
|
||||
"""Test extracting text from a .json file."""
|
||||
json_file = tmp_path / "test.json"
|
||||
content = '{"key": "value", "number": 42}'
|
||||
json_file.write_text(content, encoding="utf-8")
|
||||
|
||||
result = extract_text(json_file)
|
||||
assert result == content
|
||||
|
||||
def test_extract_text_xlsx(self, tmp_path: Path):
|
||||
"""Test extracting text from an .xlsx file."""
|
||||
from openpyxl import Workbook
|
||||
|
||||
xlsx_file = tmp_path / "test.xlsx"
|
||||
wb = Workbook()
|
||||
ws = wb.active
|
||||
ws.title = "Sheet1"
|
||||
ws["A1"] = "Name"
|
||||
ws["B1"] = "Age"
|
||||
ws["A2"] = "Alice"
|
||||
ws["B2"] = 30
|
||||
ws["A3"] = "Bob"
|
||||
ws["B3"] = 25
|
||||
|
||||
# Add a second sheet
|
||||
ws2 = wb.create_sheet("Sheet2")
|
||||
ws2["A1"] = "Product"
|
||||
ws2["B1"] = "Price"
|
||||
ws2["A2"] = "Widget"
|
||||
ws2["B2"] = 9.99
|
||||
|
||||
wb.save(xlsx_file)
|
||||
wb.close()
|
||||
|
||||
result = extract_text(xlsx_file)
|
||||
assert result is not None
|
||||
assert "--- Sheet: Sheet1 ---" in result
|
||||
assert "--- Sheet: Sheet2 ---" in result
|
||||
assert "Alice" in result
|
||||
assert "Bob" in result
|
||||
assert "Widget" in result
|
||||
assert "9.99" in result
|
||||
|
||||
def test_extract_text_xlsx_empty_sheet(self, tmp_path: Path):
|
||||
"""Test extracting text from an .xlsx file with empty sheets."""
|
||||
from openpyxl import Workbook
|
||||
|
||||
xlsx_file = tmp_path / "empty.xlsx"
|
||||
wb = Workbook()
|
||||
# Clear the default sheet
|
||||
wb.remove(wb.active)
|
||||
# Add an empty sheet
|
||||
wb.create_sheet("EmptySheet")
|
||||
wb.save(xlsx_file)
|
||||
wb.close()
|
||||
|
||||
result = extract_text(xlsx_file)
|
||||
# Empty sheets should return empty string or header only
|
||||
assert result == "--- Sheet: EmptySheet ---" or result == ""
|
||||
|
||||
def test_extract_text_docx(self, tmp_path: Path):
|
||||
"""Test extracting text from a .docx file."""
|
||||
from docx import Document
|
||||
|
||||
docx_file = tmp_path / "test.docx"
|
||||
doc = Document()
|
||||
doc.add_heading("Test Document", 0)
|
||||
doc.add_paragraph("This is paragraph one.")
|
||||
doc.add_paragraph("This is paragraph two.")
|
||||
doc.save(docx_file)
|
||||
|
||||
result = extract_text(docx_file)
|
||||
assert result is not None
|
||||
assert "Test Document" in result
|
||||
assert "This is paragraph one." in result
|
||||
assert "This is paragraph two." in result
|
||||
|
||||
def test_extract_text_docx_empty(self, tmp_path: Path):
|
||||
"""Test extracting text from an empty .docx file."""
|
||||
from docx import Document
|
||||
|
||||
docx_file = tmp_path / "empty.docx"
|
||||
doc = Document()
|
||||
doc.save(docx_file)
|
||||
|
||||
result = extract_text(docx_file)
|
||||
assert result == ""
|
||||
|
||||
def test_extract_text_pptx(self, tmp_path: Path):
|
||||
"""Test extracting text from a .pptx file."""
|
||||
from pptx import Presentation
|
||||
|
||||
pptx_file = tmp_path / "test.pptx"
|
||||
prs = Presentation()
|
||||
|
||||
# Slide 1
|
||||
slide1 = prs.slides.add_slide(prs.slide_layouts[0])
|
||||
for shape in slide1.shapes:
|
||||
if hasattr(shape, "text"):
|
||||
shape.text = "First Slide Title"
|
||||
|
||||
# Slide 2
|
||||
slide2 = prs.slides.add_slide(prs.slide_layouts[5])
|
||||
left = top = width = height = 1000000
|
||||
textbox = slide2.shapes.add_textbox(left, top, width, height)
|
||||
text_frame = textbox.text_frame
|
||||
text_frame.text = "Bullet point content"
|
||||
|
||||
prs.save(pptx_file)
|
||||
|
||||
result = extract_text(pptx_file)
|
||||
assert result is not None
|
||||
assert "--- Slide 1 ---" in result
|
||||
assert "--- Slide 2 ---" in result
|
||||
# Text content may vary depending on PowerPoint layout defaults
|
||||
assert len(result) > 0
|
||||
|
||||
def test_extract_text_pdf_not_found(self, tmp_path: Path):
|
||||
"""Test that missing PDF files return error string."""
|
||||
missing_pdf = tmp_path / "nonexistent.pdf"
|
||||
|
||||
result = extract_text(missing_pdf)
|
||||
assert result is not None
|
||||
assert "[error: file not found:" in result
|
||||
|
||||
def test_extract_text_image_files(self, tmp_path: Path):
|
||||
"""Test that image files return placeholder text."""
|
||||
# Create a minimal PNG file (1x1 pixel)
|
||||
png_file = tmp_path / "test.png"
|
||||
# Minimal valid PNG: 8-byte signature + IHDR + IDAT + IEND
|
||||
png_data = (
|
||||
b"\x89PNG\r\n\x1a\n"
|
||||
b"\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01"
|
||||
b"\x08\x02\x00\x00\x00\x90wS\xde"
|
||||
b"\x00\x00\x00\x0cIDATx\x9cc\x00\x01\x00\x00\x05\x00\x01"
|
||||
b"\r\n-\xb4\x00\x00\x00\x00IEND\xaeB`\x82"
|
||||
)
|
||||
png_file.write_bytes(png_data)
|
||||
|
||||
result = extract_text(png_file)
|
||||
assert result is not None
|
||||
assert "[image:" in result
|
||||
assert "test.png" in result
|
||||
|
||||
|
||||
class TestIsTextExtension:
|
||||
"""Test the _is_text_extension helper."""
|
||||
|
||||
def test_text_extensions_return_true(self):
|
||||
"""Test that known text extensions return True."""
|
||||
assert _is_text_extension(".txt") is True
|
||||
assert _is_text_extension(".md") is True
|
||||
assert _is_text_extension(".csv") is True
|
||||
assert _is_text_extension(".json") is True
|
||||
assert _is_text_extension(".yaml") is True
|
||||
assert _is_text_extension(".yml") is True
|
||||
assert _is_text_extension(".xml") is True
|
||||
assert _is_text_extension(".html") is True
|
||||
assert _is_text_extension(".htm") is True
|
||||
|
||||
def test_non_text_extensions_return_false(self):
|
||||
"""Test that non-text extensions return False."""
|
||||
assert _is_text_extension(".pdf") is False
|
||||
assert _is_text_extension(".docx") is False
|
||||
assert _is_text_extension(".xlsx") is False
|
||||
assert _is_text_extension(".pptx") is False
|
||||
assert _is_text_extension(".png") is False
|
||||
assert _is_text_extension(".xyz") is False
|
||||
|
||||
def test_case_sensitivity(self):
|
||||
"""Test that _is_text_extension requires lowercase extension.
|
||||
|
||||
Note: The main extract_text function handles case-insensitivity by
|
||||
converting extensions to lowercase before calling _is_text_extension.
|
||||
"""
|
||||
# _is_text_extension itself is case-sensitive (lowercase only)
|
||||
assert _is_text_extension(".txt") is True
|
||||
assert _is_text_extension(".TXT") is False
|
||||
assert _is_text_extension(".pdf") is False
|
||||
@ -147,6 +147,40 @@ async def test_handle_activity_ignores_group_messages(make_channel):
|
||||
assert ch._conversation_refs == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_activity_denied_sender_does_not_store_ref(make_channel, tmp_path):
|
||||
ch = make_channel(allowFrom=["allowed-user"])
|
||||
|
||||
activity = {
|
||||
"type": "message",
|
||||
"id": "activity-denied",
|
||||
"text": "Hello from denied user",
|
||||
"serviceUrl": "https://smba.trafficmanager.net/amer/",
|
||||
"conversation": {
|
||||
"id": "conv-denied",
|
||||
"conversationType": "personal",
|
||||
},
|
||||
"from": {
|
||||
"id": "29:user-id",
|
||||
"aadObjectId": "aad-user-1",
|
||||
"name": "Bob",
|
||||
},
|
||||
"recipient": {
|
||||
"id": "28:bot-id",
|
||||
"name": "nanobot",
|
||||
},
|
||||
"channelData": {
|
||||
"tenant": {"id": "tenant-id"},
|
||||
},
|
||||
}
|
||||
|
||||
await ch._handle_activity(activity)
|
||||
|
||||
assert ch.bus.inbound == []
|
||||
assert ch._conversation_refs == {}
|
||||
assert not (tmp_path / "state" / "msteams_conversations.json").exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_activity_mention_only_uses_default_response(make_channel):
|
||||
ch = make_channel()
|
||||
@ -520,6 +554,7 @@ async def test_start_logs_install_hint_when_pyjwt_missing(make_channel, monkeypa
|
||||
def test_msteams_default_config_includes_restart_notify_fields():
|
||||
cfg = MSTeamsChannel.default_config()
|
||||
|
||||
assert cfg["validateInboundAuth"] is True
|
||||
assert "restartNotifyEnabled" not in cfg
|
||||
assert "restartNotifyPreMessage" not in cfg
|
||||
assert "restartNotifyPostMessage" not in cfg
|
||||
|
||||
@ -101,15 +101,14 @@ async def test_no_user_message_returns_400(aiohttp_client, app) -> None:
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_true_returns_400(aiohttp_client, app) -> None:
|
||||
async def test_stream_true_returns_sse(aiohttp_client, app) -> None:
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={"messages": [{"role": "user", "content": "hello"}], "stream": True},
|
||||
)
|
||||
assert resp.status == 400
|
||||
body = await resp.json()
|
||||
assert "stream" in body["error"]["message"].lower()
|
||||
assert resp.status == 200
|
||||
assert resp.content_type == "text/event-stream"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -194,6 +193,7 @@ async def test_successful_request_uses_fixed_api_session(aiohttp_client, mock_ag
|
||||
assert body["model"] == "test-model"
|
||||
mock_agent.process_direct.assert_called_once_with(
|
||||
content="hello",
|
||||
media=None,
|
||||
session_key=API_SESSION_KEY,
|
||||
channel="api",
|
||||
chat_id=API_CHAT_ID,
|
||||
@ -205,7 +205,7 @@ async def test_successful_request_uses_fixed_api_session(aiohttp_client, mock_ag
|
||||
async def test_followup_requests_share_same_session_key(aiohttp_client) -> None:
|
||||
call_log: list[str] = []
|
||||
|
||||
async def fake_process(content, session_key="", channel="", chat_id=""):
|
||||
async def fake_process(content, session_key="", channel="", chat_id="", **kwargs):
|
||||
call_log.append(session_key)
|
||||
return f"reply to {content}"
|
||||
|
||||
@ -236,7 +236,7 @@ async def test_followup_requests_share_same_session_key(aiohttp_client) -> None:
|
||||
async def test_fixed_session_requests_are_serialized(aiohttp_client) -> None:
|
||||
order: list[str] = []
|
||||
|
||||
async def slow_process(content, session_key="", channel="", chat_id=""):
|
||||
async def slow_process(content, session_key="", channel="", chat_id="", **kwargs):
|
||||
order.append(f"start:{content}")
|
||||
await asyncio.sleep(0.1)
|
||||
order.append(f"end:{content}")
|
||||
@ -307,20 +307,46 @@ async def test_multimodal_content_extracts_text(aiohttp_client, mock_agent) -> N
|
||||
},
|
||||
)
|
||||
assert resp.status == 200
|
||||
mock_agent.process_direct.assert_called_once_with(
|
||||
content="describe this",
|
||||
session_key=API_SESSION_KEY,
|
||||
channel="api",
|
||||
chat_id=API_CHAT_ID,
|
||||
call_kwargs = mock_agent.process_direct.call_args.kwargs
|
||||
assert call_kwargs["content"] == "describe this"
|
||||
assert call_kwargs["session_key"] == API_SESSION_KEY
|
||||
assert call_kwargs["channel"] == "api"
|
||||
assert call_kwargs["chat_id"] == API_CHAT_ID
|
||||
assert len(call_kwargs.get("media") or []) >= 0 # base64 images saved to disk
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_multimodal_remote_image_url_returns_400(aiohttp_client, mock_agent) -> None:
|
||||
app = create_app(mock_agent, model_name="m")
|
||||
client = await aiohttp_client(app)
|
||||
resp = await client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "describe this"},
|
||||
{"type": "image_url", "image_url": {"url": "https://example.com/image.png"}},
|
||||
],
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status == 400
|
||||
body = await resp.json()
|
||||
assert "remote image urls are not supported" in body["error"]["message"].lower()
|
||||
mock_agent.process_direct.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_response_retry_then_success(aiohttp_client) -> None:
|
||||
call_count = 0
|
||||
|
||||
async def sometimes_empty(content, session_key="", channel="", chat_id=""):
|
||||
async def sometimes_empty(content, session_key="", channel="", chat_id="", **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
@ -351,7 +377,7 @@ async def test_empty_response_falls_back(aiohttp_client) -> None:
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def always_empty(content, session_key="", channel="", chat_id=""):
|
||||
async def always_empty(content, session_key="", channel="", chat_id="", **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return ""
|
||||
@ -371,3 +397,31 @@ async def test_empty_response_falls_back(aiohttp_client) -> None:
|
||||
body = await resp.json()
|
||||
assert body["choices"][0]["message"]["content"] == EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
assert call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_direct_accepts_media() -> None:
|
||||
"""process_direct should forward media paths to _process_message."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
|
||||
loop = AgentLoop.__new__(AgentLoop)
|
||||
loop._connect_mcp = AsyncMock()
|
||||
|
||||
captured_msg = None
|
||||
|
||||
async def fake_process(msg, *, session_key="", on_progress=None, on_stream=None, on_stream_end=None):
|
||||
nonlocal captured_msg
|
||||
captured_msg = msg
|
||||
return None
|
||||
|
||||
loop._process_message = fake_process
|
||||
|
||||
await loop.process_direct(
|
||||
content="analyze this",
|
||||
media=["/tmp/image.png", "/tmp/report.pdf"],
|
||||
session_key="test:1",
|
||||
)
|
||||
|
||||
assert captured_msg is not None
|
||||
assert captured_msg.media == ["/tmp/image.png", "/tmp/report.pdf"]
|
||||
assert captured_msg.content == "analyze this"
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
"""Tests for multi-provider web search."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
@ -20,6 +18,25 @@ def _response(status: int = 200, json: dict | None = None) -> httpx.Response:
|
||||
return r
|
||||
|
||||
|
||||
def test_duckduckgo_search_is_exclusive():
|
||||
tool = _tool(provider="duckduckgo")
|
||||
assert tool.exclusive is True
|
||||
assert tool.concurrency_safe is False
|
||||
|
||||
|
||||
def test_brave_with_api_key_remains_concurrency_safe():
|
||||
tool = _tool(provider="brave", api_key="brave-key")
|
||||
assert tool.exclusive is False
|
||||
assert tool.concurrency_safe is True
|
||||
|
||||
|
||||
def test_brave_without_api_key_is_treated_as_duckduckgo_for_concurrency(monkeypatch):
|
||||
monkeypatch.delenv("BRAVE_API_KEY", raising=False)
|
||||
tool = _tool(provider="brave", api_key="")
|
||||
assert tool.exclusive is True
|
||||
assert tool.concurrency_safe is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_brave_search(monkeypatch):
|
||||
async def mock_get(self, url, **kw):
|
||||
@ -79,7 +96,6 @@ async def test_duckduckgo_search(monkeypatch):
|
||||
import nanobot.agent.tools.web as web_mod
|
||||
monkeypatch.setattr(web_mod, "DDGS", MockDDGS, raising=False)
|
||||
|
||||
from ddgs import DDGS
|
||||
monkeypatch.setattr("ddgs.DDGS", MockDDGS)
|
||||
|
||||
tool = _tool(provider="duckduckgo")
|
||||
@ -265,5 +281,3 @@ async def test_duckduckgo_timeout_returns_error(monkeypatch):
|
||||
result = await tool.execute(query="test")
|
||||
gate.set()
|
||||
assert "Error" in result
|
||||
|
||||
|
||||
|
||||
91
tests/utils/test_gitstore.py
Normal file
91
tests/utils/test_gitstore.py
Normal file
@ -0,0 +1,91 @@
|
||||
"""Tests for GitStore — line_ages() and core git operations."""
|
||||
|
||||
import time
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.utils.gitstore import GitStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def git(tmp_path):
|
||||
"""Create an initialized GitStore with tracked MEMORY.md."""
|
||||
g = GitStore(tmp_path, tracked_files=["MEMORY.md", "SOUL.md"])
|
||||
g.init()
|
||||
return g
|
||||
|
||||
|
||||
class TestLineAges:
|
||||
def test_returns_empty_when_not_initialized(self, tmp_path):
|
||||
"""line_ages should return [] if the git repo is not initialized."""
|
||||
git = GitStore(tmp_path, tracked_files=["MEMORY.md"])
|
||||
assert git.line_ages("MEMORY.md") == []
|
||||
|
||||
def test_returns_empty_for_missing_file(self, git):
|
||||
"""line_ages should return [] for a file that doesn't exist."""
|
||||
assert git.line_ages("SOUL.md") == []
|
||||
|
||||
def test_returns_empty_for_empty_file(self, git, tmp_path):
|
||||
"""line_ages should return [] for an empty tracked file."""
|
||||
(tmp_path / "SOUL.md").write_text("", encoding="utf-8")
|
||||
git.auto_commit("empty soul")
|
||||
assert git.line_ages("SOUL.md") == []
|
||||
|
||||
def test_one_age_per_line(self, git, tmp_path):
|
||||
"""line_ages should return one entry per line in the file."""
|
||||
content = "# Memory\n\n## Section A\n- item 1\n"
|
||||
(tmp_path / "MEMORY.md").write_text(content, encoding="utf-8")
|
||||
git.auto_commit("initial")
|
||||
ages = git.line_ages("MEMORY.md")
|
||||
assert len(ages) == len(content.splitlines())
|
||||
|
||||
def test_fresh_lines_have_age_zero(self, git, tmp_path):
|
||||
"""Lines committed today should have age_days=0."""
|
||||
(tmp_path / "MEMORY.md").write_text("## A\n- x\n", encoding="utf-8")
|
||||
git.auto_commit("initial")
|
||||
ages = git.line_ages("MEMORY.md")
|
||||
assert all(a.age_days == 0 for a in ages)
|
||||
|
||||
def test_age_differentiates_across_days(self, git, tmp_path):
|
||||
"""Lines committed today should show correct age when 'now' is mocked forward."""
|
||||
(tmp_path / "MEMORY.md").write_text("## A\n- x\n", encoding="utf-8")
|
||||
git.auto_commit("initial")
|
||||
|
||||
future_now = datetime.now(tz=timezone.utc) + timedelta(days=30)
|
||||
with patch("nanobot.utils.gitstore.datetime") as mock_dt:
|
||||
mock_dt.now.return_value = future_now
|
||||
mock_dt.fromtimestamp = datetime.fromtimestamp
|
||||
ages = git.line_ages("MEMORY.md")
|
||||
|
||||
assert len(ages) == 2
|
||||
assert all(a.age_days == 30 for a in ages)
|
||||
|
||||
def test_annotate_failure_returns_empty(self, tmp_path):
|
||||
"""If annotate fails, line_ages should return [] gracefully."""
|
||||
git = GitStore(tmp_path, tracked_files=["MEMORY.md"])
|
||||
# Don't init — annotate will fail
|
||||
assert git.line_ages("MEMORY.md") == []
|
||||
|
||||
def test_partial_edit_only_updates_changed_lines(self, git, tmp_path):
|
||||
"""Only modified lines should reflect the new commit's timestamp."""
|
||||
(tmp_path / "MEMORY.md").write_text(
|
||||
"# Memory\n\n## A\n- old\n\n## B\n- keep\n", encoding="utf-8"
|
||||
)
|
||||
git.auto_commit("commit1")
|
||||
time.sleep(1.1)
|
||||
|
||||
# Only modify section A
|
||||
(tmp_path / "MEMORY.md").write_text(
|
||||
"# Memory\n\n## A\n- new\n\n## B\n- keep\n", encoding="utf-8"
|
||||
)
|
||||
git.auto_commit("commit2")
|
||||
|
||||
ages = git.line_ages("MEMORY.md")
|
||||
lines = (tmp_path / "MEMORY.md").read_text(encoding="utf-8").splitlines()
|
||||
# All lines are from today, but verify line-level tracking works
|
||||
assert len(ages) == len(lines)
|
||||
# "- new" line and "- keep" line both age=0 (same day), but
|
||||
# the key point is we get per-line results
|
||||
assert len(ages) == 7
|
||||
Loading…
x
Reference in New Issue
Block a user