mirror of https://github.com/HKUDS/nanobot.git
synced 2026-05-01 23:35:52 +00:00

commit f958eb4cc9

Merge remote-tracking branch 'origin/main' into feat/openai-compatible-session-isolation

# Conflicts:
#   nanobot/agent/context.py
#   tests/test_consolidate_offset.py

.gitignore (vendored, 3 changes)
@@ -1,3 +1,4 @@
+.worktrees/
 .assets
 .env
 *.pyc
@@ -19,4 +20,4 @@ __pycache__/
 poetry.lock
 .pytest_cache/
 botpy.log
-tests/
README.md (89 changes)
@@ -12,20 +12,28 @@
 </p>
 </div>
 
-🐈 **nanobot** is an **ultra-lightweight** personal AI assistant inspired by [OpenClaw](https://github.com/openclaw/openclaw)
+🐈 **nanobot** is an **ultra-lightweight** personal AI assistant inspired by [OpenClaw](https://github.com/openclaw/openclaw).
 
-⚡️ Delivers core agent functionality in just **~4,000** lines of code — **99% smaller** than Clawdbot's 430k+ lines.
+⚡️ Delivers core agent functionality with **99% fewer lines of code** than OpenClaw.
 
-📏 Real-time line count: **3,922 lines** (run `bash core_agent_lines.sh` to verify anytime)
+📏 Real-time line count: run `bash core_agent_lines.sh` to verify anytime.
 
 ## 📢 News
 
+- **2026-02-28** 🚀 Released **v0.1.4.post3** — cleaner context, hardened session history, and smarter agent. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post3) for details.
+- **2026-02-27** 🧠 Experimental thinking mode support, DingTalk media messages, Feishu and QQ channel fixes.
+- **2026-02-26** 🛡️ Session poisoning fix, WhatsApp dedup, Windows path guard, Mistral compatibility.
+- **2026-02-25** 🧹 New Matrix channel, cleaner session context, auto workspace template sync.
 - **2026-02-24** 🚀 Released **v0.1.4.post2** — a reliability-focused release with a redesigned heartbeat, prompt cache optimization, and hardened provider & channel stability. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post2) for details.
 - **2026-02-23** 🔧 Virtual tool-call heartbeat, prompt cache optimization, Slack mrkdwn fixes.
 - **2026-02-22** 🛡️ Slack thread isolation, Discord typing fix, agent reliability improvements.
 - **2026-02-21** 🎉 Released **v0.1.4.post1** — new providers, media support across channels, and major stability improvements. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post1) for details.
 - **2026-02-20** 🐦 Feishu now receives multimodal files from users. More reliable memory under the hood.
 - **2026-02-19** ✨ Slack now sends files, Discord splits long messages, and subagents work in CLI mode.
 
+<details>
+<summary>Earlier news</summary>
+
 - **2026-02-18** ⚡️ nanobot now supports VolcEngine, MCP custom auth headers, and Anthropic prompt caching.
 - **2026-02-17** 🎉 Released **v0.1.4** — MCP support, progress streaming, new providers, and multiple channel improvements. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4) for details.
 - **2026-02-16** 🦞 nanobot now integrates a [ClawHub](https://clawhub.ai) skill — search and install public agent skills.
@@ -34,10 +42,6 @@
 - **2026-02-13** 🎉 Released **v0.1.3.post7** — includes security hardening and multiple improvements. **Please upgrade to the latest version to address security issues**. See [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post7) for more details.
 - **2026-02-12** 🧠 Redesigned memory system — Less code, more reliable. Join the [discussion](https://github.com/HKUDS/nanobot/discussions/566) about it!
 - **2026-02-11** ✨ Enhanced CLI experience and added MiniMax support!
-
-<details>
-<summary>Earlier news</summary>
-
 - **2026-02-10** 🎉 Released **v0.1.3.post6** with improvements! Check the updates [notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.3.post6) and our [roadmap](https://github.com/HKUDS/nanobot/discussions/431).
 - **2026-02-09** 💬 Added Slack, Email, and QQ support — nanobot now supports multiple chat platforms!
 - **2026-02-08** 🔧 Refactored Providers—adding a new LLM provider now takes just 2 simple steps! Check [here](#providers).
@@ -289,12 +293,18 @@ If you prefer to configure manually, add the following to `~/.nanobot/config.jso
     "discord": {
       "enabled": true,
       "token": "YOUR_BOT_TOKEN",
-      "allowFrom": ["YOUR_USER_ID"]
+      "allowFrom": ["YOUR_USER_ID"],
+      "groupPolicy": "mention"
     }
   }
 }
 ```
 
+> `groupPolicy` controls how the bot responds in group channels:
+> - `"mention"` (default) — Only respond when @mentioned
+> - `"open"` — Respond to all messages
+> DMs always respond when the sender is in `allowFrom`.
 
 **5. Invite the bot**
 - OAuth2 → URL Generator
 - Scopes: `bot`
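The gating behavior documented in the added note, condensed into a Python sketch (illustrative only, not the channel implementation):

```python
def should_reply(is_dm: bool, mentioned: bool, sender_allowed: bool,
                 group_policy: str = "mention") -> bool:
    # DMs answer senders in allowFrom; group chats follow groupPolicy.
    if is_dm:
        return sender_allowed
    if group_policy == "open":
        return True        # respond to all group messages
    return mentioned       # "mention" (default): only when @mentioned
```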
@@ -343,7 +353,7 @@ pip install nanobot-ai[matrix]
       "accessToken": "syt_xxx",
       "deviceId": "NANOBOT01",
       "e2eeEnabled": true,
-      "allowFrom": [],
+      "allowFrom": ["@your_user:matrix.org"],
       "groupPolicy": "open",
       "groupAllowFrom": [],
       "allowRoomMentions": false,
@@ -437,14 +447,14 @@ Uses **WebSocket** long connection — no public IP required.
       "appSecret": "xxx",
       "encryptKey": "",
       "verificationToken": "",
-      "allowFrom": []
+      "allowFrom": ["ou_YOUR_OPEN_ID"]
     }
   }
 }
 ```
 
 > `encryptKey` and `verificationToken` are optional for Long Connection mode.
-> `allowFrom`: Leave empty to allow all users, or add `["ou_xxx"]` to restrict access.
+> `allowFrom`: Add your open_id (find it in nanobot logs when you message the bot). Use `["*"]` to allow all users.
 
 **3. Run**
 
@@ -474,7 +484,7 @@ Uses **botpy SDK** with WebSocket — no public IP required. Currently supports
 
 **3. Configure**
 
-> - `allowFrom`: Leave empty for public access, or add user openids to restrict. You can find openids in the nanobot logs when a user messages the bot.
+> - `allowFrom`: Add your openid (find it in nanobot logs when you message the bot). Use `["*"]` for public access.
 > - For production: submit a review in the bot console and publish. See [QQ Bot Docs](https://bot.q.qq.com/wiki/) for the full publishing flow.
 
 ```json
@@ -484,7 +494,7 @@ Uses **botpy SDK** with WebSocket — no public IP required. Currently supports
       "enabled": true,
       "appId": "YOUR_APP_ID",
       "secret": "YOUR_APP_SECRET",
-      "allowFrom": []
+      "allowFrom": ["YOUR_OPENID"]
     }
   }
 }
@@ -523,13 +533,13 @@ Uses **Stream Mode** — no public IP required.
       "enabled": true,
       "clientId": "YOUR_APP_KEY",
       "clientSecret": "YOUR_APP_SECRET",
-      "allowFrom": []
+      "allowFrom": ["YOUR_STAFF_ID"]
     }
   }
 }
 ```
 
-> `allowFrom`: Leave empty to allow all users, or add `["staffId"]` to restrict access.
+> `allowFrom`: Add your staff ID. Use `["*"]` to allow all users.
 
 **3. Run**
 
@@ -564,6 +574,7 @@ Uses **Socket Mode** — no public URL required.
       "enabled": true,
       "botToken": "xoxb-...",
       "appToken": "xapp-...",
+      "allowFrom": ["YOUR_SLACK_USER_ID"],
       "groupPolicy": "mention"
     }
   }
@@ -597,7 +608,7 @@ Give nanobot its own email account. It polls **IMAP** for incoming mail and repl
 **2. Configure**
 
 > - `consentGranted` must be `true` to allow mailbox access. This is a safety gate — set `false` to fully disable.
-> - `allowFrom`: Leave empty to accept emails from anyone, or restrict to specific senders.
+> - `allowFrom`: Add your email address. Use `["*"]` to accept emails from anyone.
 > - `smtpUseTls` and `smtpUseSsl` default to `true` / `false` respectively, which is correct for Gmail (port 587 + STARTTLS). No need to set them explicitly.
 > - Set `"autoReplyEnabled": false` if you only want to read/analyze emails without sending automatic replies.
 
@@ -653,6 +664,7 @@ Config file: `~/.nanobot/config.json`
 > - **Zhipu Coding Plan**: If you're on Zhipu's coding plan, set `"apiBase": "https://open.bigmodel.cn/api/coding/paas/v4"` in your zhipu provider config.
 > - **MiniMax (Mainland China)**: If your API key is from MiniMax's mainland China platform (minimaxi.com), set `"apiBase": "https://api.minimaxi.com/v1"` in your minimax provider config.
 > - **VolcEngine Coding Plan**: If you're on VolcEngine's coding plan, set `"apiBase": "https://ark.cn-beijing.volces.com/api/coding/v3"` in your volcengine provider config.
+> - **Alibaba Cloud Coding Plan**: If you're on the Alibaba Cloud Coding Plan (BaiLian), set `"apiBase": "https://coding.dashscope.aliyuncs.com/v1"` in your dashscope provider config.
 
 | Provider | Purpose | Get API Key |
 |----------|---------|-------------|
@@ -870,6 +882,7 @@ MCP tools are automatically discovered and registered on startup. The LLM can us
 
 > [!TIP]
 > For production deployments, set `"restrictToWorkspace": true` in your config to sandbox the agent.
+> **Change in source / post-`v0.1.4.post3`:** In `v0.1.4.post3` and earlier, an empty `allowFrom` means "allow all senders". In newer versions (including building from source), **empty `allowFrom` denies all access by default**. To allow all senders, set `"allowFrom": ["*"]`.
 
 | Option | Default | Description |
 |--------|---------|-------------|
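The behavior change described in the added tip, condensed into a sketch (function and variable names are illustrative):

```python
def is_allowed(sender_id: str, allow_from: list[str]) -> bool:
    # Post-v0.1.4.post3 semantics: an empty whitelist denies everyone;
    # "*" is the explicit opt-in that restores allow-all.
    if not allow_from:
        return False
    return "*" in allow_from or sender_id in allow_from

assert is_allowed("U123", ["*"]) is True
assert is_allowed("U123", []) is False   # was allowed in v0.1.4.post3 and earlier
```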
@@ -878,6 +891,33 @@ MCP tools are automatically discovered and registered on startup. The LLM can us
 | `channels.*.allowFrom` | `[]` (allow all) | Whitelist of user IDs. Empty = allow everyone; non-empty = only listed users can interact. |
 
 
+## Multiple Instances
+
+Run multiple nanobot instances simultaneously, each with its own workspace and configuration.
+
+```bash
+# Instance A - Telegram bot
+nanobot gateway -w ~/.nanobot/botA -p 18791
+
+# Instance B - Discord bot
+nanobot gateway -w ~/.nanobot/botB -p 18792
+
+# Instance C - Using custom config file
+nanobot gateway -w ~/.nanobot/botC -c ~/.nanobot/botC/config.json -p 18793
+```
+
+| Option | Short | Description |
+|--------|-------|-------------|
+| `--workspace` | `-w` | Workspace directory (default: `~/.nanobot/workspace`) |
+| `--config` | `-c` | Config file path (default: `~/.nanobot/config.json`) |
+| `--port` | `-p` | Gateway port (default: `18790`) |
+
+Each instance has its own:
+- Workspace directory (MEMORY.md, HEARTBEAT.md, session files)
+- Cron jobs storage (`workspace/cron/jobs.json`)
+- Configuration (if using `--config`)
+
+
 ## CLI Reference
 
 | Command | Description |
@@ -895,23 +935,6 @@ MCP tools are automatically discovered and registered on startup. The LLM can us
 
 Interactive mode exits: `exit`, `quit`, `/exit`, `/quit`, `:q`, or `Ctrl+D`.
 
-<details>
-<summary><b>Scheduled Tasks (Cron)</b></summary>
-
-```bash
-# Add a job
-nanobot cron add --name "daily" --message "Good morning!" --cron "0 9 * * *"
-nanobot cron add --name "hourly" --message "Check status" --every 3600
-
-# List jobs
-nanobot cron list
-
-# Remove a job
-nanobot cron remove <job_id>
-```
-
-</details>
-
 <details>
 <summary><b>Heartbeat (Periodic Tasks)</b></summary>
 
@@ -55,7 +55,7 @@ chmod 600 ~/.nanobot/config.json
 ```
 
 **Security Notes:**
-- Empty `allowFrom` list will **ALLOW ALL** users (open by default for personal use)
+- In `v0.1.4.post3` and earlier, an empty `allowFrom` allows all users. In newer versions (including source builds), **empty `allowFrom` denies all access** — set `["*"]` to explicitly allow everyone.
 - Get your Telegram user ID from `@userinfobot`
 - Use full phone numbers with country code for WhatsApp
 - Review access logs regularly for unauthorized access attempts
@@ -212,9 +212,8 @@ If you suspect a security breach:
 - Input length limits on HTTP requests
 
 ✅ **Authentication**
-- Allow-list based access control
+- Allow-list based access control — in `v0.1.4.post3` and earlier empty means allow all; in newer versions empty means deny all (`["*"]` to explicitly allow all)
 - Failed authentication attempt logging
-- Open by default (configure allowFrom for production use)
 
 ✅ **Resource Protection**
 - Command execution timeouts (60s default)
nanobot/__init__.py

@@ -2,5 +2,5 @@
 nanobot - A lightweight AI agent framework
 """
 
-__version__ = "0.1.4.post2"
+__version__ = "0.1.4.post3"
 __logo__ = "🐈"
nanobot/agent/__init__.py

@@ -1,7 +1,7 @@
 """Agent core module."""
 
-from nanobot.agent.loop import AgentLoop
 from nanobot.agent.context import ContextBuilder
+from nanobot.agent.loop import AgentLoop
 from nanobot.agent.memory import MemoryStore
 from nanobot.agent.skills import SkillsLoader
 
nanobot/agent/context.py

@@ -10,6 +10,7 @@ from typing import Any
 
 from nanobot.agent.memory import MemoryStore
 from nanobot.agent.skills import SkillsLoader
+from nanobot.utils.helpers import detect_image_mime
 
 
 class ContextBuilder:
@@ -130,11 +131,20 @@ Reply directly with text for conversations. Only use the 'message' tool to send
         memory_store: "MemoryStore | None" = None,
     ) -> list[dict[str, Any]]:
         """Build the complete message list for an LLM call."""
+        runtime_ctx = self._build_runtime_context(channel, chat_id)
+        user_content = self._build_user_content(current_message, media)
+
+        # Merge runtime context and user content into a single user message
+        # to avoid consecutive same-role messages that some providers reject.
+        if isinstance(user_content, str):
+            merged = f"{runtime_ctx}\n\n{user_content}"
+        else:
+            merged = [{"type": "text", "text": runtime_ctx}] + user_content
+
         return [
             {"role": "system", "content": self.build_system_prompt(skill_names, memory_store=memory_store)},
             *history,
-            {"role": "user", "content": self._build_runtime_context(channel, chat_id)},
-            {"role": "user", "content": self._build_user_content(current_message, media)},
+            {"role": "user", "content": merged},
         ]
 
     def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
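The merge above changes the tail of the prompt from two consecutive user turns to one. A toy illustration of the two shapes (all values made up):

```python
# Before the change: two consecutive "user" messages, which some
# OpenAI-compatible backends reject as a same-role run.
runtime_ctx = "Current time: 2026-02-28 10:00 (illustrative runtime context)"
user_text = "What's on my calendar today?"

before = [
    {"role": "user", "content": runtime_ctx},
    {"role": "user", "content": user_text},
]

# After the change: one merged user message.
after = [{"role": "user", "content": f"{runtime_ctx}\n\n{user_text}"}]
```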
@@ -145,10 +155,14 @@ Reply directly with text for conversations. Only use the 'message' tool to send
         images = []
         for path in media:
             p = Path(path)
-            mime, _ = mimetypes.guess_type(path)
-            if not p.is_file() or not mime or not mime.startswith("image/"):
+            if not p.is_file():
                 continue
-            b64 = base64.b64encode(p.read_bytes()).decode()
+            raw = p.read_bytes()
+            # Detect real MIME type from magic bytes; fallback to filename guess
+            mime = detect_image_mime(raw) or mimetypes.guess_type(path)[0]
+            if not mime or not mime.startswith("image/"):
+                continue
+            b64 = base64.b64encode(raw).decode()
             images.append({"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}})
 
         if not images:
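`detect_image_mime` lives in `nanobot.utils.helpers`, which this diff does not show. A minimal sketch of the magic-byte idea it relies on; the signatures below are the standard ones, but the actual function body is an assumption:

```python
def detect_image_mime(raw: bytes) -> str | None:
    # Sniff well-known image signatures; return None when unrecognized so the
    # caller can fall back to mimetypes.guess_type(path).
    if raw.startswith(b"\x89PNG\r\n\x1a\n"):
        return "image/png"
    if raw.startswith(b"\xff\xd8\xff"):
        return "image/jpeg"
    if raw.startswith((b"GIF87a", b"GIF89a")):
        return "image/gif"
    if len(raw) >= 12 and raw[:4] == b"RIFF" and raw[8:12] == b"WEBP":
        return "image/webp"
    return None
```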
@@ -168,6 +182,7 @@ Reply directly with text for conversations. Only use the 'message' tool to send
         content: str | None,
         tool_calls: list[dict[str, Any]] | None = None,
         reasoning_content: str | None = None,
+        thinking_blocks: list[dict] | None = None,
     ) -> list[dict[str, Any]]:
         """Add an assistant message to the message list."""
         msg: dict[str, Any] = {"role": "assistant", "content": content}
@@ -175,5 +190,7 @@ Reply directly with text for conversations. Only use the 'message' tool to send
             msg["tool_calls"] = tool_calls
         if reasoning_content is not None:
             msg["reasoning_content"] = reasoning_content
+        if thinking_blocks:
+            msg["thinking_blocks"] = thinking_blocks
         messages.append(msg)
         return messages
nanobot/agent/loop.py

@@ -5,6 +5,7 @@ from __future__ import annotations
 import asyncio
 import json
 import re
+import weakref
 from contextlib import AsyncExitStack
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Awaitable, Callable
@@ -55,7 +56,9 @@ class AgentLoop:
         temperature: float = 0.1,
         max_tokens: int = 4096,
         memory_window: int = 100,
+        reasoning_effort: str | None = None,
         brave_api_key: str | None = None,
+        web_proxy: str | None = None,
         exec_config: ExecToolConfig | None = None,
         cron_service: CronService | None = None,
         restrict_to_workspace: bool = False,
@@ -73,7 +76,9 @@ class AgentLoop:
         self.temperature = temperature
         self.max_tokens = max_tokens
         self.memory_window = memory_window
+        self.reasoning_effort = reasoning_effort
         self.brave_api_key = brave_api_key
+        self.web_proxy = web_proxy
         self.exec_config = exec_config or ExecToolConfig()
         self.cron_service = cron_service
         self.restrict_to_workspace = restrict_to_workspace
@@ -88,7 +93,9 @@ class AgentLoop:
             model=self.model,
             temperature=self.temperature,
             max_tokens=self.max_tokens,
+            reasoning_effort=reasoning_effort,
             brave_api_key=brave_api_key,
+            web_proxy=web_proxy,
             exec_config=self.exec_config,
             restrict_to_workspace=restrict_to_workspace,
         )
@@ -100,7 +107,7 @@ class AgentLoop:
         self._mcp_connecting = False
         self._consolidating: set[str] = set()  # Session keys with consolidation in progress
         self._consolidation_tasks: set[asyncio.Task] = set()  # Strong refs to in-flight tasks
-        self._consolidation_locks: dict[str, asyncio.Lock] = {}
+        self._consolidation_locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
        self._active_tasks: dict[str, list[asyncio.Task]] = {}  # session_key -> tasks
         self._processing_lock = asyncio.Lock()
         self._register_default_tools()
@@ -116,8 +123,8 @@ class AgentLoop:
             restrict_to_workspace=self.restrict_to_workspace,
             path_append=self.exec_config.path_append,
         ))
-        self.tools.register(WebSearchTool(api_key=self.brave_api_key))
-        self.tools.register(WebFetchTool())
+        self.tools.register(WebSearchTool(api_key=self.brave_api_key, proxy=self.web_proxy))
+        self.tools.register(WebFetchTool(proxy=self.web_proxy))
         self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
         self.tools.register(SpawnTool(manager=self.subagents))
         if self.cron_service:
@@ -198,13 +205,23 @@ class AgentLoop:
             model=self.model,
             temperature=self.temperature,
             max_tokens=self.max_tokens,
+            reasoning_effort=self.reasoning_effort,
         )
 
         if response.has_tool_calls:
             if on_progress:
-                clean = self._strip_think(response.content)
-                if clean:
-                    await on_progress(clean)
+                thoughts = [
+                    self._strip_think(response.content),
+                    response.reasoning_content,
+                    *(
+                        f"Thinking [{b.get('signature', '...')}]:\n{b.get('thought', '...')}"
+                        for b in (response.thinking_blocks or [])
+                        if isinstance(b, dict) and "signature" in b
+                    ),
+                ]
+                combined_thoughts = "\n\n".join(filter(None, thoughts))
+                if combined_thoughts:
+                    await on_progress(combined_thoughts)
                 await on_progress(self._tool_hint(response.tool_calls), tool_hint=True)
 
             tool_call_dicts = [
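The shapes the comprehension above expects, illustrated with made-up payloads (actual provider thinking-block formats vary):

```python
# Made-up values to show what the combined progress text contains:
stripped = "Plan: search the web, then summarize."   # _strip_think(response.content)
reasoning_content = None                              # no separate reasoning channel
thinking_blocks = [
    {"signature": "sig-1", "thought": "Need web_search first."},
    {"type": "redacted"},  # no "signature" key -> filtered out
]

thoughts = [
    stripped,
    reasoning_content,
    *(
        f"Thinking [{b.get('signature', '...')}]:\n{b.get('thought', '...')}"
        for b in thinking_blocks
        if isinstance(b, dict) and "signature" in b
    ),
]
print("\n\n".join(filter(None, thoughts)))
# Plan: search the web, then summarize.
#
# Thinking [sig-1]:
# Need web_search first.
```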
@@ -221,6 +238,7 @@ class AgentLoop:
             messages = self.context.add_assistant_message(
                 messages, response.content, tool_call_dicts,
                 reasoning_content=response.reasoning_content,
+                thinking_blocks=response.thinking_blocks,
             )
 
             for tool_call in response.tool_calls:
@@ -244,6 +262,7 @@ class AgentLoop:
                 break
             messages = self.context.add_assistant_message(
                 messages, clean, reasoning_content=response.reasoning_content,
+                thinking_blocks=response.thinking_blocks,
             )
             final_content = clean
             break
@@ -391,8 +410,6 @@ class AgentLoop:
                 )
         finally:
             self._consolidating.discard(session.key)
-            if not lock.locked():
-                self._consolidation_locks.pop(session.key, None)
 
         session.clear()
         self.sessions.save(session)
@@ -416,8 +433,6 @@ class AgentLoop:
                 )
         finally:
             self._consolidating.discard(session.key)
-            if not lock.locked():
-                self._consolidation_locks.pop(session.key, None)
             _task = asyncio.current_task()
             if _task is not None:
                 self._consolidation_tasks.discard(_task)
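Why the manual cleanup can go: with a `WeakValueDictionary`, a per-session lock vanishes on its own once nothing references it. A minimal sketch of the pattern (names are illustrative, not the exact AgentLoop code):

```python
import asyncio
import weakref

_locks: "weakref.WeakValueDictionary[str, asyncio.Lock]" = weakref.WeakValueDictionary()

async def consolidate(session_key: str) -> None:
    # The local variable keeps a strong reference while the lock is in use;
    # when the last user returns, the dict entry is dropped automatically,
    # so no "if not lock.locked(): pop(...)" bookkeeping is needed.
    lock = _locks.get(session_key)
    if lock is None:
        lock = asyncio.Lock()
        _locks[session_key] = lock
    async with lock:
        ...  # consolidation work for this session
```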
@@ -472,7 +487,7 @@ class AgentLoop:
         """Save new-turn messages into session, truncating large tool results."""
         from datetime import datetime
         for m in messages[skip:]:
-            entry = {k: v for k, v in m.items() if k != "reasoning_content"}
+            entry = dict(m)
             role, content = entry.get("role"), entry.get("content")
             if role == "assistant" and not content and not entry.get("tool_calls"):
                 continue  # skip empty assistant messages — they poison session context
@@ -480,14 +495,25 @@ class AgentLoop:
                 entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
             elif role == "user":
                 if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
-                    continue
+                    # Strip the runtime-context prefix, keep only the user text.
+                    parts = content.split("\n\n", 1)
+                    if len(parts) > 1 and parts[1].strip():
+                        entry["content"] = parts[1]
+                    else:
+                        continue
                 if isinstance(content, list):
-                    entry["content"] = [
-                        {"type": "text", "text": "[image]"} if (
-                            c.get("type") == "image_url"
-                            and c.get("image_url", {}).get("url", "").startswith("data:image/")
-                        ) else c for c in content
-                    ]
+                    filtered = []
+                    for c in content:
+                        if c.get("type") == "text" and isinstance(c.get("text"), str) and c["text"].startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
+                            continue  # Strip runtime context from multimodal messages
+                        if (c.get("type") == "image_url"
+                                and c.get("image_url", {}).get("url", "").startswith("data:image/")):
+                            filtered.append({"type": "text", "text": "[image]"})
+                        else:
+                            filtered.append(c)
+                    if not filtered:
+                        continue
+                    entry["content"] = filtered
             entry.setdefault("timestamp", datetime.now().isoformat())
             session.messages.append(entry)
             session.updated_at = datetime.now()
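The split-once strip works because `build_messages` (see the context.py hunk earlier) joins the runtime block and the user text with a blank line. A toy check; `"[context]"` stands in for `ContextBuilder._RUNTIME_CONTEXT_TAG`, whose real value this diff does not show:

```python
merged = "[context] time=2026-02-28 10:00 channel=slack\n\nhello there"
parts = merged.split("\n\n", 1)
assert len(parts) > 1 and parts[1].strip()
assert parts[1] == "hello there"  # only this part is persisted to the session
```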
nanobot/agent/memory.py

@@ -128,6 +128,13 @@ class MemoryStore:
         # Some providers return arguments as a JSON string instead of dict
         if isinstance(args, str):
             args = json.loads(args)
+        # Some providers return arguments as a list (handle edge case)
+        if isinstance(args, list):
+            if args and isinstance(args[0], dict):
+                args = args[0]
+            else:
+                logger.warning("Memory consolidation: unexpected arguments as empty or non-dict list")
+                return False
         if not isinstance(args, dict):
             logger.warning("Memory consolidation: unexpected arguments type {}", type(args).__name__)
             return False
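The edge case in concrete terms (payload made up):

```python
import json

raw = '[{"summary": "User prefers short answers."}]'  # list-wrapped tool arguments
args = json.loads(raw)   # -> list, not dict
if isinstance(args, list):
    args = args[0] if args and isinstance(args[0], dict) else None
assert args == {"summary": "User prefers short answers."}
```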
nanobot/agent/skills.py

@@ -134,7 +134,7 @@ class SkillsLoader:
         if missing:
             lines.append(f" <requires>{escape_xml(missing)}</requires>")
 
-        lines.append(f" </skill>")
+        lines.append(" </skill>")
         lines.append("</skills>")
 
         return "\n".join(lines)
@@ -8,13 +8,14 @@ from typing import Any
 
 from loguru import logger
 
+from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
+from nanobot.agent.tools.registry import ToolRegistry
+from nanobot.agent.tools.shell import ExecTool
+from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
 from nanobot.bus.events import InboundMessage
 from nanobot.bus.queue import MessageBus
+from nanobot.config.schema import ExecToolConfig
 from nanobot.providers.base import LLMProvider
-from nanobot.agent.tools.registry import ToolRegistry
-from nanobot.agent.tools.filesystem import ReadFileTool, WriteFileTool, EditFileTool, ListDirTool
-from nanobot.agent.tools.shell import ExecTool
-from nanobot.agent.tools.web import WebSearchTool, WebFetchTool
 
 
 class SubagentManager:
@@ -28,7 +29,9 @@ class SubagentManager:
         model: str | None = None,
         temperature: float = 0.7,
         max_tokens: int = 4096,
+        reasoning_effort: str | None = None,
         brave_api_key: str | None = None,
+        web_proxy: str | None = None,
         exec_config: "ExecToolConfig | None" = None,
         restrict_to_workspace: bool = False,
     ):
@@ -39,7 +42,9 @@ class SubagentManager:
         self.model = model or provider.get_default_model()
         self.temperature = temperature
         self.max_tokens = max_tokens
+        self.reasoning_effort = reasoning_effort
         self.brave_api_key = brave_api_key
+        self.web_proxy = web_proxy
         self.exec_config = exec_config or ExecToolConfig()
         self.restrict_to_workspace = restrict_to_workspace
         self._running_tasks: dict[str, asyncio.Task[None]] = {}
@@ -101,11 +106,10 @@ class SubagentManager:
             restrict_to_workspace=self.restrict_to_workspace,
             path_append=self.exec_config.path_append,
         ))
-        tools.register(WebSearchTool(api_key=self.brave_api_key))
-        tools.register(WebFetchTool())
+        tools.register(WebSearchTool(api_key=self.brave_api_key, proxy=self.web_proxy))
+        tools.register(WebFetchTool(proxy=self.web_proxy))
 
-        # Build messages with subagent-specific prompt
-        system_prompt = self._build_subagent_prompt(task)
+        system_prompt = self._build_subagent_prompt()
         messages: list[dict[str, Any]] = [
             {"role": "system", "content": system_prompt},
             {"role": "user", "content": task},
@@ -125,6 +129,7 @@ class SubagentManager:
             model=self.model,
             temperature=self.temperature,
             max_tokens=self.max_tokens,
+            reasoning_effort=self.reasoning_effort,
         )
 
         if response.has_tool_calls:
@@ -204,42 +209,27 @@ Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not men
         await self.bus.publish_inbound(msg)
         logger.debug("Subagent [{}] announced result to {}:{}", task_id, origin['channel'], origin['chat_id'])
 
-    def _build_subagent_prompt(self, task: str) -> str:
+    def _build_subagent_prompt(self) -> str:
         """Build a focused system prompt for the subagent."""
-        from datetime import datetime
-        import time as _time
-        now = datetime.now().strftime("%Y-%m-%d %H:%M (%A)")
-        tz = _time.strftime("%Z") or "UTC"
+        from nanobot.agent.context import ContextBuilder
+        from nanobot.agent.skills import SkillsLoader
 
-        return f"""# Subagent
+        time_ctx = ContextBuilder._build_runtime_context(None, None)
+        parts = [f"""# Subagent
 
-## Current Time
-{now} ({tz})
+{time_ctx}
 
 You are a subagent spawned by the main agent to complete a specific task.
-
-## Rules
-1. Stay focused - complete only the assigned task, nothing else
-2. Your final response will be reported back to the main agent
-3. Do not initiate conversations or take on side tasks
-4. Be concise but informative in your findings
-
-## What You Can Do
-- Read and write files in the workspace
-- Execute shell commands
-- Search the web and fetch web pages
-- Complete the task thoroughly
-
-## What You Cannot Do
-- Send messages directly to users (no message tool available)
-- Spawn other subagents
-- Access the main agent's conversation history
+Stay focused on the assigned task. Your final response will be reported back to the main agent.
 
 ## Workspace
-Your workspace is at: {self.workspace}
-Skills are available at: {self.workspace}/skills/ (read SKILL.md files as needed)
-
-When you have completed the task, provide a clear summary of your findings or actions."""
+{self.workspace}"""]
+
+        skills_summary = SkillsLoader(self.workspace).build_skills_summary()
+        if skills_summary:
+            parts.append(f"## Skills\n\nRead SKILL.md with read_file to use a skill.\n\n{skills_summary}")
+
+        return "\n\n".join(parts)
 
     async def cancel_by_session(self, session_key: str) -> int:
         """Cancel all subagents for the given session. Returns count cancelled."""
nanobot/agent/tools/base.py

@@ -54,6 +54,8 @@ class Tool(ABC):
 
     def validate_params(self, params: dict[str, Any]) -> list[str]:
         """Validate tool parameters against JSON schema. Returns error list (empty if valid)."""
+        if not isinstance(params, dict):
+            return [f"parameters must be an object, got {type(params).__name__}"]
         schema = self.parameters or {}
         if schema.get("type", "object") != "object":
             raise ValueError(f"Schema must be object type, got {schema.get('type')!r}")
@@ -84,10 +86,12 @@ class Tool(ABC):
                 errors.append(f"missing required {path + '.' + k if path else k}")
             for k, v in val.items():
                 if k in props:
-                    errors.extend(self._validate(v, props[k], path + '.' + k if path else k))
+                    errors.extend(self._validate(v, props[k], path + "." + k if path else k))
         if t == "array" and "items" in schema:
             for i, item in enumerate(val):
-                errors.extend(self._validate(item, schema["items"], f"{path}[{i}]" if path else f"[{i}]"))
+                errors.extend(
+                    self._validate(item, schema["items"], f"{path}[{i}]" if path else f"[{i}]")
+                )
         return errors
 
     def to_schema(self) -> dict[str, Any]:
@@ -98,5 +102,5 @@ class Tool(ABC):
                 "name": self.name,
                 "description": self.description,
                 "parameters": self.parameters,
-            }
+            },
         }
@@ -1,5 +1,6 @@
 """Cron tool for scheduling reminders and tasks."""
 
+from contextvars import ContextVar
 from typing import Any
 
 from nanobot.agent.tools.base import Tool
@@ -14,12 +15,21 @@ class CronTool(Tool):
         self._cron = cron_service
         self._channel = ""
         self._chat_id = ""
+        self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False)
 
     def set_context(self, channel: str, chat_id: str) -> None:
         """Set the current session context for delivery."""
         self._channel = channel
         self._chat_id = chat_id
 
+    def set_cron_context(self, active: bool):
+        """Mark whether the tool is executing inside a cron job callback."""
+        return self._in_cron_context.set(active)
+
+    def reset_cron_context(self, token) -> None:
+        """Restore previous cron context."""
+        self._in_cron_context.reset(token)
+
     @property
     def name(self) -> str:
         return "cron"
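How a caller would bracket a cron-triggered run with the two new methods. Only `set_cron_context` and `reset_cron_context` come from the hunk above; the wrapper and its arguments (`cron_tool`, `agent`, `job`) are hypothetical:

```python
async def run_cron_job(cron_tool, agent, job) -> None:
    # Mark the run so the tool refuses nested scheduling, then always restore.
    token = cron_tool.set_cron_context(True)
    try:
        await agent.process(job.message)  # any cron "add" issued here is refused
    finally:
        cron_tool.reset_cron_context(token)  # the token restores the prior value
```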
@@ -36,34 +46,28 @@ class CronTool(Tool):
             "action": {
                 "type": "string",
                 "enum": ["add", "list", "remove"],
-                "description": "Action to perform"
-            },
-            "message": {
-                "type": "string",
-                "description": "Reminder message (for add)"
+                "description": "Action to perform",
             },
+            "message": {"type": "string", "description": "Reminder message (for add)"},
             "every_seconds": {
                 "type": "integer",
-                "description": "Interval in seconds (for recurring tasks)"
+                "description": "Interval in seconds (for recurring tasks)",
             },
             "cron_expr": {
                 "type": "string",
-                "description": "Cron expression like '0 9 * * *' (for scheduled tasks)"
+                "description": "Cron expression like '0 9 * * *' (for scheduled tasks)",
             },
             "tz": {
                 "type": "string",
-                "description": "IANA timezone for cron expressions (e.g. 'America/Vancouver')"
+                "description": "IANA timezone for cron expressions (e.g. 'America/Vancouver')",
             },
             "at": {
                 "type": "string",
-                "description": "ISO datetime for one-time execution (e.g. '2026-02-12T10:30:00')"
+                "description": "ISO datetime for one-time execution (e.g. '2026-02-12T10:30:00')",
             },
-            "job_id": {
-                "type": "string",
-                "description": "Job ID (for remove)"
-            }
+            "job_id": {"type": "string", "description": "Job ID (for remove)"},
         },
-        "required": ["action"]
+        "required": ["action"],
     }
 
     async def execute(
@@ -75,9 +79,11 @@ class CronTool(Tool):
         tz: str | None = None,
         at: str | None = None,
         job_id: str | None = None,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> str:
         if action == "add":
+            if self._in_cron_context.get():
+                return "Error: cannot schedule new jobs from within a cron job execution"
             return self._add_job(message, every_seconds, cron_expr, tz, at)
         elif action == "list":
             return self._list_jobs()
@@ -101,6 +107,7 @@ class CronTool(Tool):
             return "Error: tz can only be used with cron_expr"
         if tz:
             from zoneinfo import ZoneInfo
+
             try:
                 ZoneInfo(tz)
             except (KeyError, Exception):
@@ -114,7 +121,11 @@ class CronTool(Tool):
             schedule = CronSchedule(kind="cron", expr=cron_expr, tz=tz)
         elif at:
             from datetime import datetime
-            dt = datetime.fromisoformat(at)
+
+            try:
+                dt = datetime.fromisoformat(at)
+            except ValueError:
+                return f"Error: invalid ISO datetime format '{at}'. Expected format: YYYY-MM-DDTHH:MM:SS"
             at_ms = int(dt.timestamp() * 1000)
             schedule = CronSchedule(kind="at", at_ms=at_ms)
             delete_after = True
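What the new guard catches, in isolation:

```python
from datetime import datetime

for value in ("2026-02-12T10:30:00", "tomorrow at ten"):
    try:
        print(datetime.fromisoformat(value))
    except ValueError:
        # Previously this exception propagated out of the tool; now the tool
        # returns a readable "Error: invalid ISO datetime format ..." string.
        print(f"rejected: {value!r}")
```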
nanobot/agent/tools/filesystem.py

@@ -7,7 +7,9 @@ from typing import Any
 from nanobot.agent.tools.base import Tool
 
 
-def _resolve_path(path: str, workspace: Path | None = None, allowed_dir: Path | None = None) -> Path:
+def _resolve_path(
+    path: str, workspace: Path | None = None, allowed_dir: Path | None = None
+) -> Path:
     """Resolve path against workspace (if relative) and enforce directory restriction."""
     p = Path(path).expanduser()
     if not p.is_absolute() and workspace:
@@ -24,6 +26,8 @@ def _resolve_path(path: str, workspace: Path | None = None, allowed_dir: Path |
 class ReadFileTool(Tool):
     """Tool to read file contents."""
 
+    _MAX_CHARS = 128_000  # ~128 KB — prevents OOM from reading huge files into LLM context
+
     def __init__(self, workspace: Path | None = None, allowed_dir: Path | None = None):
         self._workspace = workspace
         self._allowed_dir = allowed_dir
@@ -40,13 +44,8 @@ class ReadFileTool:
     def parameters(self) -> dict[str, Any]:
         return {
             "type": "object",
-            "properties": {
-                "path": {
-                    "type": "string",
-                    "description": "The file path to read"
-                }
-            },
-            "required": ["path"]
+            "properties": {"path": {"type": "string", "description": "The file path to read"}},
+            "required": ["path"],
         }
 
     async def execute(self, path: str, **kwargs: Any) -> str:
@@ -57,7 +56,16 @@ class ReadFileTool:
             if not file_path.is_file():
                 return f"Error: Not a file: {path}"
 
+            size = file_path.stat().st_size
+            if size > self._MAX_CHARS * 4:  # rough upper bound (UTF-8 chars ≤ 4 bytes)
+                return (
+                    f"Error: File too large ({size:,} bytes). "
+                    f"Use exec tool with head/tail/grep to read portions."
+                )
+
             content = file_path.read_text(encoding="utf-8")
+            if len(content) > self._MAX_CHARS:
+                return content[: self._MAX_CHARS] + f"\n\n... (truncated — file is {len(content):,} chars, limit {self._MAX_CHARS:,})"
             return content
         except PermissionError as e:
             return f"Error: {e}"
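The byte bound is sound because a UTF-8 character occupies between 1 and 4 bytes, so a decoded string always has at least size/4 characters:

```python
_MAX_CHARS = 128_000

# chars >= bytes / 4, so any file larger than 128_000 * 4 = 512_000 bytes
# cannot decode to <= 128_000 chars, and read_text() is skipped without
# ever loading the file.
assert _MAX_CHARS * 4 == 512_000
```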
@@ -85,16 +93,10 @@ class WriteFileTool(Tool):
         return {
             "type": "object",
             "properties": {
-                "path": {
-                    "type": "string",
-                    "description": "The file path to write to"
-                },
-                "content": {
-                    "type": "string",
-                    "description": "The content to write"
-                }
+                "path": {"type": "string", "description": "The file path to write to"},
+                "content": {"type": "string", "description": "The content to write"},
             },
-            "required": ["path", "content"]
+            "required": ["path", "content"],
         }
 
     async def execute(self, path: str, content: str, **kwargs: Any) -> str:
@@ -129,20 +131,11 @@ class EditFileTool(Tool):
         return {
             "type": "object",
             "properties": {
-                "path": {
-                    "type": "string",
-                    "description": "The file path to edit"
-                },
-                "old_text": {
-                    "type": "string",
-                    "description": "The exact text to find and replace"
-                },
-                "new_text": {
-                    "type": "string",
-                    "description": "The text to replace with"
-                }
+                "path": {"type": "string", "description": "The file path to edit"},
+                "old_text": {"type": "string", "description": "The exact text to find and replace"},
+                "new_text": {"type": "string", "description": "The text to replace with"},
             },
-            "required": ["path", "old_text", "new_text"]
+            "required": ["path", "old_text", "new_text"],
         }
 
     async def execute(self, path: str, old_text: str, new_text: str, **kwargs: Any) -> str:
@@ -184,13 +177,19 @@
                 best_ratio, best_start = ratio, i
 
         if best_ratio > 0.5:
-            diff = "\n".join(difflib.unified_diff(
-                old_lines, lines[best_start : best_start + window],
-                fromfile="old_text (provided)", tofile=f"{path} (actual, line {best_start + 1})",
-                lineterm="",
-            ))
+            diff = "\n".join(
+                difflib.unified_diff(
+                    old_lines,
+                    lines[best_start : best_start + window],
+                    fromfile="old_text (provided)",
+                    tofile=f"{path} (actual, line {best_start + 1})",
+                    lineterm="",
+                )
+            )
             return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
-        return f"Error: old_text not found in {path}. No similar text found. Verify the file content."
+        return (
+            f"Error: old_text not found in {path}. No similar text found. Verify the file content."
+        )
 
 
 class ListDirTool(Tool):
@@ -212,13 +211,8 @@ class ListDirTool(Tool):
     def parameters(self) -> dict[str, Any]:
         return {
             "type": "object",
-            "properties": {
-                "path": {
-                    "type": "string",
-                    "description": "The directory path to list"
-                }
-            },
-            "required": ["path"]
+            "properties": {"path": {"type": "string", "description": "The directory path to list"}},
+            "required": ["path"],
         }
 
     async def execute(self, path: str, **kwargs: Any) -> str:
@@ -58,17 +58,48 @@ async def connect_mcp_servers(
 ) -> None:
     """Connect to configured MCP servers and register their tools."""
     from mcp import ClientSession, StdioServerParameters
+    from mcp.client.sse import sse_client
     from mcp.client.stdio import stdio_client
+    from mcp.client.streamable_http import streamable_http_client
 
     for name, cfg in mcp_servers.items():
         try:
-            if cfg.command:
+            transport_type = cfg.type
+            if not transport_type:
+                if cfg.command:
+                    transport_type = "stdio"
+                elif cfg.url:
+                    # Convention: URLs ending with /sse use SSE transport; others use streamableHttp
+                    transport_type = (
+                        "sse" if cfg.url.rstrip("/").endswith("/sse") else "streamableHttp"
+                    )
+                else:
+                    logger.warning("MCP server '{}': no command or url configured, skipping", name)
+                    continue
+
+            if transport_type == "stdio":
                 params = StdioServerParameters(
                     command=cfg.command, args=cfg.args, env=cfg.env or None
                 )
                 read, write = await stack.enter_async_context(stdio_client(params))
-            elif cfg.url:
-                from mcp.client.streamable_http import streamable_http_client
+            elif transport_type == "sse":
+                def httpx_client_factory(
+                    headers: dict[str, str] | None = None,
+                    timeout: httpx.Timeout | None = None,
+                    auth: httpx.Auth | None = None,
+                ) -> httpx.AsyncClient:
+                    merged_headers = {**(cfg.headers or {}), **(headers or {})}
+                    return httpx.AsyncClient(
+                        headers=merged_headers or None,
+                        follow_redirects=True,
+                        timeout=timeout,
+                        auth=auth,
+                    )
+
+                read, write = await stack.enter_async_context(
+                    sse_client(cfg.url, httpx_client_factory=httpx_client_factory)
+                )
+            elif transport_type == "streamableHttp":
                 # Always provide an explicit httpx client so MCP HTTP transport does not
                 # inherit httpx's default 5s timeout and preempt the higher-level tool timeout.
                 http_client = await stack.enter_async_context(
@@ -82,7 +113,7 @@ async def connect_mcp_servers(
                     streamable_http_client(cfg.url, http_client=http_client)
                 )
             else:
-                logger.warning("MCP server '{}': no command or url configured, skipping", name)
+                logger.warning("MCP server '{}': unknown transport type '{}'", name, transport_type)
                 continue
 
             session = await stack.enter_async_context(ClientSession(read, write))
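
As a usage sketch, a config exercising all three transports might look like the following. The field names mirror the `cfg` attributes read above (`command`, `args`, `url`, `type`); the surrounding `mcpServers` key and exact schema are assumptions, not confirmed by this diff:

```json
{
  "mcpServers": {
    "local-tools": {"command": "npx", "args": ["-y", "some-mcp-server"]},
    "sse-server": {"url": "https://example.com/mcp/sse"},
    "http-server": {"url": "https://example.com/mcp", "type": "streamableHttp"}
  }
}
```

With `type` omitted, the first entry infers stdio from `command`, and the second infers SSE from the `/sse` URL suffix.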
@@ -1,6 +1,6 @@
 """Spawn tool for creating background subagents."""
 
-from typing import Any, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 from nanobot.agent.tools.base import Tool
@@ -8,6 +8,7 @@ from typing import Any
 from urllib.parse import urlparse
 
 import httpx
+from loguru import logger
 
 from nanobot.agent.tools.base import Tool
 
@@ -57,9 +58,10 @@ class WebSearchTool(Tool):
             "required": ["query"]
         }
 
-    def __init__(self, api_key: str | None = None, max_results: int = 5):
+    def __init__(self, api_key: str | None = None, max_results: int = 5, proxy: str | None = None):
         self._init_api_key = api_key
         self.max_results = max_results
+        self.proxy = proxy
 
     @property
     def api_key(self) -> str:
@@ -69,14 +71,15 @@ class WebSearchTool(Tool):
     async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
         if not self.api_key:
             return (
-                "Error: Brave Search API key not configured. "
-                "Set it in ~/.nanobot/config.json under tools.web.search.apiKey "
+                "Error: Brave Search API key not configured. Set it in "
+                "~/.nanobot/config.json under tools.web.search.apiKey "
                 "(or export BRAVE_API_KEY), then restart the gateway."
             )
 
         try:
             n = min(max(count or self.max_results, 1), 10)
-            async with httpx.AsyncClient() as client:
+            logger.debug("WebSearch: {}", "proxy enabled" if self.proxy else "direct connection")
+            async with httpx.AsyncClient(proxy=self.proxy) as client:
                 r = await client.get(
                     "https://api.search.brave.com/res/v1/web/search",
                     params={"q": query, "count": n},
@@ -85,17 +88,21 @@ class WebSearchTool(Tool):
                 )
                 r.raise_for_status()
 
-            results = r.json().get("web", {}).get("results", [])
+            results = r.json().get("web", {}).get("results", [])[:n]
             if not results:
                 return f"No results for: {query}"
 
             lines = [f"Results for: {query}\n"]
-            for i, item in enumerate(results[:n], 1):
+            for i, item in enumerate(results, 1):
                 lines.append(f"{i}. {item.get('title', '')}\n   {item.get('url', '')}")
                 if desc := item.get("description"):
                     lines.append(f"   {desc}")
             return "\n".join(lines)
+        except httpx.ProxyError as e:
+            logger.error("WebSearch proxy error: {}", e)
+            return f"Proxy error: {e}"
         except Exception as e:
+            logger.error("WebSearch error: {}", e)
             return f"Error: {e}"
 
 
@@ -114,34 +121,33 @@ class WebFetchTool(Tool):
             "required": ["url"]
         }
 
-    def __init__(self, max_chars: int = 50000):
+    def __init__(self, max_chars: int = 50000, proxy: str | None = None):
         self.max_chars = max_chars
+        self.proxy = proxy
 
     async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> str:
         from readability import Document
 
         max_chars = maxChars or self.max_chars
 
-        # Validate URL before fetching
         is_valid, error_msg = _validate_url(url)
         if not is_valid:
             return json.dumps({"error": f"URL validation failed: {error_msg}", "url": url}, ensure_ascii=False)
 
         try:
+            logger.debug("WebFetch: {}", "proxy enabled" if self.proxy else "direct connection")
             async with httpx.AsyncClient(
                 follow_redirects=True,
                 max_redirects=MAX_REDIRECTS,
-                timeout=30.0
+                timeout=30.0,
+                proxy=self.proxy,
             ) as client:
                 r = await client.get(url, headers={"User-Agent": USER_AGENT})
                 r.raise_for_status()
 
             ctype = r.headers.get("content-type", "")
 
-            # JSON
             if "application/json" in ctype:
                 text, extractor = json.dumps(r.json(), indent=2, ensure_ascii=False), "json"
-            # HTML
             elif "text/html" in ctype or r.text[:256].lower().startswith(("<!doctype", "<html")):
                 doc = Document(r.text)
                 content = self._to_markdown(doc.summary()) if extractMode == "markdown" else _strip_tags(doc.summary())
@@ -151,12 +157,15 @@ class WebFetchTool(Tool):
                 text, extractor = r.text, "raw"
 
             truncated = len(text) > max_chars
-            if truncated:
-                text = text[:max_chars]
+            if truncated: text = text[:max_chars]
 
             return json.dumps({"url": url, "finalUrl": str(r.url), "status": r.status_code,
                 "extractor": extractor, "truncated": truncated, "length": len(text), "text": text}, ensure_ascii=False)
+        except httpx.ProxyError as e:
+            logger.error("WebFetch proxy error for {}: {}", url, e)
+            return json.dumps({"error": f"Proxy error: {e}", "url": url}, ensure_ascii=False)
         except Exception as e:
+            logger.error("WebFetch error for {}: {}", url, e)
             return json.dumps({"error": str(e), "url": url}, ensure_ascii=False)
 
     def _to_markdown(self, html: str) -> str:
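
A quick usage sketch of the new `proxy` parameter. The proxy URL is illustrative; `httpx.AsyncClient(proxy=...)` accepts standard HTTP proxy URLs, and both tools now share the same pass-through:

```python
# Illustrative only: route both web tools through a local proxy.
search = WebSearchTool(api_key="...", max_results=5, proxy="http://127.0.0.1:7890")
fetch = WebFetchTool(max_chars=50000, proxy="http://127.0.0.1:7890")
```

Catching `httpx.ProxyError` separately means a misconfigured proxy surfaces as a distinct, actionable error string rather than a generic failure.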
@@ -59,29 +59,17 @@ class BaseChannel(ABC):
         pass
 
     def is_allowed(self, sender_id: str) -> bool:
-        """
-        Check if a sender is allowed to use this bot.
-
-        Args:
-            sender_id: The sender's identifier.
-
-        Returns:
-            True if allowed, False otherwise.
-        """
+        """Check if *sender_id* is permitted. Empty list → deny all; ``"*"`` → allow all."""
         allow_list = getattr(self.config, "allow_from", [])
 
-        # If no allow list, allow everyone
         if not allow_list:
+            logger.warning("{}: allow_from is empty — all access denied", self.name)
+            return False
+        if "*" in allow_list:
             return True
 
         sender_str = str(sender_id)
-        if sender_str in allow_list:
-            return True
-        if "|" in sender_str:
-            for part in sender_str.split("|"):
-                if part and part in allow_list:
-                    return True
-        return False
+        return sender_str in allow_list or any(
+            p in allow_list for p in sender_str.split("|") if p
+        )
 
     async def _handle_message(
         self,
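
The semantics flip here is worth noting: an empty `allow_from` now denies everyone instead of allowing everyone, `"*"` opts back into open access, and composite sender ids such as `guild|user` match when any `|`-separated part is listed. The same logic in isolation, as a standalone sketch with illustrative names:

```python
def is_allowed(sender_id: str, allow_list: list[str]) -> bool:
    # Empty list → deny all; "*" → allow all; otherwise match the full id
    # or any "|"-separated component (e.g. "guild123|user456").
    if not allow_list:
        return False
    if "*" in allow_list:
        return True
    s = str(sender_id)
    return s in allow_list or any(p in allow_list for p in s.split("|") if p)

assert is_allowed("guild123|user456", ["user456"])
assert not is_allowed("anyone", [])
```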
@@ -2,11 +2,15 @@
 
 import asyncio
 import json
+import mimetypes
+import os
 import time
+from pathlib import Path
 from typing import Any
+from urllib.parse import unquote, urlparse
 
-from loguru import logger
 import httpx
+from loguru import logger
 
 from nanobot.bus.events import OutboundMessage
 from nanobot.bus.queue import MessageBus
@@ -15,11 +19,11 @@ from nanobot.config.schema import DingTalkConfig
 
 try:
     from dingtalk_stream import (
-        DingTalkStreamClient,
-        Credential,
+        AckMessage,
         CallbackHandler,
         CallbackMessage,
-        AckMessage,
+        Credential,
+        DingTalkStreamClient,
     )
     from dingtalk_stream.chatbot import ChatbotMessage
 
@@ -96,6 +100,9 @@ class DingTalkChannel(BaseChannel):
     """
 
     name = "dingtalk"
+    _IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"}
+    _AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg", ".m4a", ".aac"}
+    _VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm"}
 
     def __init__(self, config: DingTalkConfig, bus: MessageBus):
         super().__init__(config, bus)
@@ -191,40 +198,224 @@ class DingTalkChannel(BaseChannel):
             logger.error("Failed to get DingTalk access token: {}", e)
             return None
 
+    @staticmethod
+    def _is_http_url(value: str) -> bool:
+        return urlparse(value).scheme in ("http", "https")
+
+    def _guess_upload_type(self, media_ref: str) -> str:
+        ext = Path(urlparse(media_ref).path).suffix.lower()
+        if ext in self._IMAGE_EXTS: return "image"
+        if ext in self._AUDIO_EXTS: return "voice"
+        if ext in self._VIDEO_EXTS: return "video"
+        return "file"
+
+    def _guess_filename(self, media_ref: str, upload_type: str) -> str:
+        name = os.path.basename(urlparse(media_ref).path)
+        return name or {"image": "image.jpg", "voice": "audio.amr", "video": "video.mp4"}.get(upload_type, "file.bin")
+
+    async def _read_media_bytes(
+        self,
+        media_ref: str,
+    ) -> tuple[bytes | None, str | None, str | None]:
+        if not media_ref:
+            return None, None, None
+
+        if self._is_http_url(media_ref):
+            if not self._http:
+                return None, None, None
+            try:
+                resp = await self._http.get(media_ref, follow_redirects=True)
+                if resp.status_code >= 400:
+                    logger.warning(
+                        "DingTalk media download failed status={} ref={}",
+                        resp.status_code,
+                        media_ref,
+                    )
+                    return None, None, None
+                content_type = (resp.headers.get("content-type") or "").split(";")[0].strip()
+                filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
+                return resp.content, filename, content_type or None
+            except Exception as e:
+                logger.error("DingTalk media download error ref={} err={}", media_ref, e)
+                return None, None, None
+
+        try:
+            if media_ref.startswith("file://"):
+                parsed = urlparse(media_ref)
+                local_path = Path(unquote(parsed.path))
+            else:
+                local_path = Path(os.path.expanduser(media_ref))
+            if not local_path.is_file():
+                logger.warning("DingTalk media file not found: {}", local_path)
+                return None, None, None
+            data = await asyncio.to_thread(local_path.read_bytes)
+            content_type = mimetypes.guess_type(local_path.name)[0]
+            return data, local_path.name, content_type
+        except Exception as e:
+            logger.error("DingTalk media read error ref={} err={}", media_ref, e)
+            return None, None, None
+
+    async def _upload_media(
+        self,
+        token: str,
+        data: bytes,
+        media_type: str,
+        filename: str,
+        content_type: str | None,
+    ) -> str | None:
+        if not self._http:
+            return None
+        url = f"https://oapi.dingtalk.com/media/upload?access_token={token}&type={media_type}"
+        mime = content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
+        files = {"media": (filename, data, mime)}
+
+        try:
+            resp = await self._http.post(url, files=files)
+            text = resp.text
+            result = resp.json() if resp.headers.get("content-type", "").startswith("application/json") else {}
+            if resp.status_code >= 400:
+                logger.error("DingTalk media upload failed status={} type={} body={}", resp.status_code, media_type, text[:500])
+                return None
+            errcode = result.get("errcode", 0)
+            if errcode != 0:
+                logger.error("DingTalk media upload api error type={} errcode={} body={}", media_type, errcode, text[:500])
+                return None
+            sub = result.get("result") or {}
+            media_id = result.get("media_id") or result.get("mediaId") or sub.get("media_id") or sub.get("mediaId")
+            if not media_id:
+                logger.error("DingTalk media upload missing media_id body={}", text[:500])
+                return None
+            return str(media_id)
+        except Exception as e:
+            logger.error("DingTalk media upload error type={} err={}", media_type, e)
+            return None
+
+    async def _send_batch_message(
+        self,
+        token: str,
+        chat_id: str,
+        msg_key: str,
+        msg_param: dict[str, Any],
+    ) -> bool:
+        if not self._http:
+            logger.warning("DingTalk HTTP client not initialized, cannot send")
+            return False
+
+        url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
+        headers = {"x-acs-dingtalk-access-token": token}
+        payload = {
+            "robotCode": self.config.client_id,
+            "userIds": [chat_id],
+            "msgKey": msg_key,
+            "msgParam": json.dumps(msg_param, ensure_ascii=False),
+        }
+
+        try:
+            resp = await self._http.post(url, json=payload, headers=headers)
+            body = resp.text
+            if resp.status_code != 200:
+                logger.error("DingTalk send failed msgKey={} status={} body={}", msg_key, resp.status_code, body[:500])
+                return False
+            try: result = resp.json()
+            except Exception: result = {}
+            errcode = result.get("errcode")
+            if errcode not in (None, 0):
+                logger.error("DingTalk send api error msgKey={} errcode={} body={}", msg_key, errcode, body[:500])
+                return False
+            logger.debug("DingTalk message sent to {} with msgKey={}", chat_id, msg_key)
+            return True
+        except Exception as e:
+            logger.error("Error sending DingTalk message msgKey={} err={}", msg_key, e)
+            return False
+
+    async def _send_markdown_text(self, token: str, chat_id: str, content: str) -> bool:
+        return await self._send_batch_message(
+            token,
+            chat_id,
+            "sampleMarkdown",
+            {"text": content, "title": "Nanobot Reply"},
+        )
+
+    async def _send_media_ref(self, token: str, chat_id: str, media_ref: str) -> bool:
+        media_ref = (media_ref or "").strip()
+        if not media_ref:
+            return True
+
+        upload_type = self._guess_upload_type(media_ref)
+        if upload_type == "image" and self._is_http_url(media_ref):
+            ok = await self._send_batch_message(
+                token,
+                chat_id,
+                "sampleImageMsg",
+                {"photoURL": media_ref},
+            )
+            if ok:
+                return True
+            logger.warning("DingTalk image url send failed, trying upload fallback: {}", media_ref)
+
+        data, filename, content_type = await self._read_media_bytes(media_ref)
+        if not data:
+            logger.error("DingTalk media read failed: {}", media_ref)
+            return False
+
+        filename = filename or self._guess_filename(media_ref, upload_type)
+        file_type = Path(filename).suffix.lower().lstrip(".")
+        if not file_type:
+            guessed = mimetypes.guess_extension(content_type or "")
+            file_type = (guessed or ".bin").lstrip(".")
+        if file_type == "jpeg":
+            file_type = "jpg"
+
+        media_id = await self._upload_media(
+            token=token,
+            data=data,
+            media_type=upload_type,
+            filename=filename,
+            content_type=content_type,
+        )
+        if not media_id:
+            return False
+
+        if upload_type == "image":
+            # Verified in production: sampleImageMsg accepts media_id in photoURL.
+            ok = await self._send_batch_message(
+                token,
+                chat_id,
+                "sampleImageMsg",
+                {"photoURL": media_id},
+            )
+            if ok:
+                return True
+            logger.warning("DingTalk image media_id send failed, falling back to file: {}", media_ref)
+
+        return await self._send_batch_message(
+            token,
+            chat_id,
+            "sampleFile",
+            {"mediaId": media_id, "fileName": filename, "fileType": file_type},
+        )
+
     async def send(self, msg: OutboundMessage) -> None:
         """Send a message through DingTalk."""
         token = await self._get_access_token()
         if not token:
             return
 
-        # oToMessages/batchSend: sends to individual users (private chat)
-        # https://open.dingtalk.com/document/orgapp/robot-batch-send-messages
-        url = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
-
-        headers = {"x-acs-dingtalk-access-token": token}
-
-        data = {
-            "robotCode": self.config.client_id,
-            "userIds": [msg.chat_id],  # chat_id is the user's staffId
-            "msgKey": "sampleMarkdown",
-            "msgParam": json.dumps({
-                "text": msg.content,
-                "title": "Nanobot Reply",
-            }, ensure_ascii=False),
-        }
-
-        if not self._http:
-            logger.warning("DingTalk HTTP client not initialized, cannot send")
-            return
-
-        try:
-            resp = await self._http.post(url, json=data, headers=headers)
-            if resp.status_code != 200:
-                logger.error("DingTalk send failed: {}", resp.text)
-            else:
-                logger.debug("DingTalk message sent to {}", msg.chat_id)
-        except Exception as e:
-            logger.error("Error sending DingTalk message: {}", e)
+        if msg.content and msg.content.strip():
+            await self._send_markdown_text(token, msg.chat_id, msg.content.strip())
+
+        for media_ref in msg.media or []:
+            ok = await self._send_media_ref(token, msg.chat_id, media_ref)
+            if ok:
+                continue
+            logger.error("DingTalk media send failed for {}", media_ref)
+            # Send visible fallback so failures are observable by the user.
+            filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
+            await self._send_markdown_text(
+                token,
+                msg.chat_id,
+                f"[Attachment send failed: {filename}]",
+            )
 
     async def _on_message(self, content: str, sender_id: str, sender_name: str) -> None:
         """Handle incoming message (called by NanobotDingTalkHandler).
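
For reference, the three `msgKey` payloads used by `_send_batch_message` map to the oToMessages/batchSend API roughly as follows. This is a sketch with illustrative values; the key names themselves come straight from the code above:

```python
# Illustrative msgParam payloads for the batchSend helper above.
markdown_param = {"text": "**hello**", "title": "Nanobot Reply"}  # msgKey="sampleMarkdown"
image_param = {"photoURL": "https://example.com/cat.png"}         # msgKey="sampleImageMsg"
file_param = {"mediaId": "@lADO...", "fileName": "report.pdf", "fileType": "pdf"}  # msgKey="sampleFile"
```

The image path tries the raw URL first, then falls back to uploading and passing the resulting `media_id` in `photoURL`, and finally to a plain file message, so an image delivery failure degrades rather than disappears.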
@@ -13,35 +13,13 @@ from nanobot.bus.events import OutboundMessage
 from nanobot.bus.queue import MessageBus
 from nanobot.channels.base import BaseChannel
 from nanobot.config.schema import DiscordConfig
+from nanobot.utils.helpers import split_message
 
 DISCORD_API_BASE = "https://discord.com/api/v10"
 MAX_ATTACHMENT_BYTES = 20 * 1024 * 1024  # 20MB
 MAX_MESSAGE_LEN = 2000  # Discord message character limit
 
 
-def _split_message(content: str, max_len: int = MAX_MESSAGE_LEN) -> list[str]:
-    """Split content into chunks within max_len, preferring line breaks."""
-    if not content:
-        return []
-    if len(content) <= max_len:
-        return [content]
-    chunks: list[str] = []
-    while content:
-        if len(content) <= max_len:
-            chunks.append(content)
-            break
-        cut = content[:max_len]
-        pos = cut.rfind('\n')
-        if pos <= 0:
-            pos = cut.rfind(' ')
-        if pos <= 0:
-            pos = max_len
-        chunks.append(content[:pos])
-        content = content[pos:].lstrip()
-    return chunks
-
-
 class DiscordChannel(BaseChannel):
     """Discord channel using Gateway websocket."""
 
@@ -55,6 +33,7 @@ class DiscordChannel(BaseChannel):
         self._heartbeat_task: asyncio.Task | None = None
         self._typing_tasks: dict[str, asyncio.Task] = {}
         self._http: httpx.AsyncClient | None = None
+        self._bot_user_id: str | None = None
 
     async def start(self) -> None:
         """Start the Discord gateway connection."""
@@ -105,7 +84,7 @@ class DiscordChannel(BaseChannel):
         headers = {"Authorization": f"Bot {self.config.token}"}
 
         try:
-            chunks = _split_message(msg.content or "")
+            chunks = split_message(msg.content or "", MAX_MESSAGE_LEN)
             if not chunks:
                 return
 
@@ -171,6 +150,10 @@ class DiscordChannel(BaseChannel):
                 await self._identify()
             elif op == 0 and event_type == "READY":
                 logger.info("Discord gateway READY")
+                # Capture bot user ID for mention detection
+                user_data = payload.get("user") or {}
+                self._bot_user_id = user_data.get("id")
+                logger.info("Discord bot connected as user {}", self._bot_user_id)
             elif op == 0 and event_type == "MESSAGE_CREATE":
                 await self._handle_message_create(payload)
             elif op == 7:
@@ -227,6 +210,7 @@ class DiscordChannel(BaseChannel):
         sender_id = str(author.get("id", ""))
         channel_id = str(payload.get("channel_id", ""))
         content = payload.get("content") or ""
+        guild_id = payload.get("guild_id")
 
         if not sender_id or not channel_id:
             return
@@ -234,6 +218,11 @@ class DiscordChannel(BaseChannel):
         if not self.is_allowed(sender_id):
             return
 
+        # Check group channel policy (DMs always respond if is_allowed passes)
+        if guild_id is not None:
+            if not self._should_respond_in_group(payload, content):
+                return
+
         content_parts = [content] if content else []
         media_paths: list[str] = []
         media_dir = Path.home() / ".nanobot" / "media"
@@ -270,11 +259,32 @@ class DiscordChannel(BaseChannel):
             media=media_paths,
             metadata={
                 "message_id": str(payload.get("id", "")),
-                "guild_id": payload.get("guild_id"),
+                "guild_id": guild_id,
                 "reply_to": reply_to,
             },
         )
 
+    def _should_respond_in_group(self, payload: dict[str, Any], content: str) -> bool:
+        """Check if bot should respond in a group channel based on policy."""
+        if self.config.group_policy == "open":
+            return True
+
+        if self.config.group_policy == "mention":
+            # Check if bot was mentioned in the message
+            if self._bot_user_id:
+                # Check mentions array
+                mentions = payload.get("mentions") or []
+                for mention in mentions:
+                    if str(mention.get("id")) == self._bot_user_id:
+                        return True
+                # Also check content for mention format <@USER_ID>
+                if f"<@{self._bot_user_id}>" in content or f"<@!{self._bot_user_id}>" in content:
+                    return True
+            logger.debug("Discord message in {} ignored (bot not mentioned)", payload.get("channel_id"))
+            return False
+
+        return True
+
     async def _start_typing(self, channel_id: str) -> None:
         """Start periodic typing indicator for a channel."""
         await self._stop_typing(channel_id)
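
Discord encodes mentions both in the `mentions` array of the gateway payload and inline in the text as `<@USER_ID>` or `<@!USER_ID>` (the nickname form), which is why the policy check above looks in both places. A compact standalone sketch of the same gate, with illustrative names:

```python
def mentioned(payload: dict, content: str, bot_id: str) -> bool:
    # True if the bot appears in the mentions array or inline in the text.
    if any(str(m.get("id")) == bot_id for m in payload.get("mentions") or []):
        return True
    return f"<@{bot_id}>" in content or f"<@!{bot_id}>" in content

assert mentioned({"mentions": []}, "hey <@!42> ping", "42")
```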
@@ -16,27 +16,9 @@ from nanobot.bus.queue import MessageBus
 from nanobot.channels.base import BaseChannel
 from nanobot.config.schema import FeishuConfig
 
-try:
-    import lark_oapi as lark
-    from lark_oapi.api.im.v1 import (
-        CreateFileRequest,
-        CreateFileRequestBody,
-        CreateImageRequest,
-        CreateImageRequestBody,
-        CreateMessageRequest,
-        CreateMessageRequestBody,
-        CreateMessageReactionRequest,
-        CreateMessageReactionRequestBody,
-        Emoji,
-        GetFileRequest,
-        GetMessageResourceRequest,
-        P2ImMessageReceiveV1,
-    )
-    FEISHU_AVAILABLE = True
-except ImportError:
-    FEISHU_AVAILABLE = False
-    lark = None
-    Emoji = None
+import importlib.util
+
+FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
 
 # Message type display mapping
 MSG_TYPE_MAP = {
@@ -182,57 +164,59 @@ def _extract_element_content(element: dict) -> list[str]:
 
 
 def _extract_post_content(content_json: dict) -> tuple[str, list[str]]:
-    """Extract text and image keys from Feishu post (rich text) message content.
+    """Extract text and image keys from Feishu post (rich text) message.
 
-    Supports two formats:
-    1. Direct format: {"title": "...", "content": [...]}
-    2. Localized format: {"zh_cn": {"title": "...", "content": [...]}}
-
-    Returns:
-        (text, image_keys) - extracted text and list of image keys
+    Handles three payload shapes:
+    - Direct: {"title": "...", "content": [[...]]}
+    - Localized: {"zh_cn": {"title": "...", "content": [...]}}
+    - Wrapped: {"post": {"zh_cn": {"title": "...", "content": [...]}}}
     """
-    def extract_from_lang(lang_content: dict) -> tuple[str | None, list[str]]:
-        if not isinstance(lang_content, dict):
-            return None, []
-        title = lang_content.get("title", "")
-        content_blocks = lang_content.get("content", [])
-        if not isinstance(content_blocks, list):
-            return None, []
-        text_parts = []
-        image_keys = []
-        if title:
-            text_parts.append(title)
-        for block in content_blocks:
-            if not isinstance(block, list):
-                continue
-            for element in block:
-                if isinstance(element, dict):
-                    tag = element.get("tag")
-                    if tag == "text":
-                        text_parts.append(element.get("text", ""))
-                    elif tag == "a":
-                        text_parts.append(element.get("text", ""))
-                    elif tag == "at":
-                        text_parts.append(f"@{element.get('user_name', 'user')}")
-                    elif tag == "img":
-                        img_key = element.get("image_key")
-                        if img_key:
-                            image_keys.append(img_key)
-        text = " ".join(text_parts).strip() if text_parts else None
-        return text, image_keys
+    def _parse_block(block: dict) -> tuple[str | None, list[str]]:
+        if not isinstance(block, dict) or not isinstance(block.get("content"), list):
+            return None, []
+        texts, images = [], []
+        if title := block.get("title"):
+            texts.append(title)
+        for row in block["content"]:
+            if not isinstance(row, list):
+                continue
+            for el in row:
+                if not isinstance(el, dict):
+                    continue
+                tag = el.get("tag")
+                if tag in ("text", "a"):
+                    texts.append(el.get("text", ""))
+                elif tag == "at":
+                    texts.append(f"@{el.get('user_name', 'user')}")
+                elif tag == "img" and (key := el.get("image_key")):
+                    images.append(key)
+        return (" ".join(texts).strip() or None), images
 
-    # Try direct format first
-    if "content" in content_json:
-        text, images = extract_from_lang(content_json)
-        if text or images:
-            return text or "", images
+    # Unwrap optional {"post": ...} envelope
+    root = content_json
+    if isinstance(root, dict) and isinstance(root.get("post"), dict):
+        root = root["post"]
+    if not isinstance(root, dict):
+        return "", []
 
-    # Try localized format
-    for lang_key in ("zh_cn", "en_us", "ja_jp"):
-        lang_content = content_json.get(lang_key)
-        text, images = extract_from_lang(lang_content)
-        if text or images:
-            return text or "", images
+    # Direct format
+    if "content" in root:
+        text, imgs = _parse_block(root)
+        if text or imgs:
+            return text or "", imgs
+
+    # Localized: prefer known locales, then fall back to any dict child
+    for key in ("zh_cn", "en_us", "ja_jp"):
+        if key in root:
+            text, imgs = _parse_block(root[key])
+            if text or imgs:
+                return text or "", imgs
+    for val in root.values():
+        if isinstance(val, dict):
+            text, imgs = _parse_block(val)
+            if text or imgs:
+                return text or "", imgs
 
     return "", []
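
To make the three payload shapes concrete, here is a wrapped localized payload (illustrative values) and what the parser above reduces it to:

```python
payload = {
    "post": {
        "zh_cn": {
            "title": "Hi",
            "content": [
                [{"tag": "text", "text": "see"},
                 {"tag": "a", "text": "docs", "href": "https://example.com"}],
                [{"tag": "img", "image_key": "img_v2_abc123"}],
            ],
        }
    }
}
# _extract_post_content(payload) -> ("Hi see docs", ["img_v2_abc123"])
```

The `post` envelope is unwrapped first, `zh_cn` is found in the locale loop, and `_parse_block` flattens title, text, and link tags into one string while collecting image keys.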
@@ -279,6 +263,7 @@ class FeishuChannel(BaseChannel):
             logger.error("Feishu app_id and app_secret not configured")
             return
 
+        import lark_oapi as lark
         self._running = True
         self._loop = asyncio.get_running_loop()
 
@@ -305,15 +290,28 @@ class FeishuChannel(BaseChannel):
             log_level=lark.LogLevel.INFO
         )
 
-        # Start WebSocket client in a separate thread with reconnect loop
+        # Start WebSocket client in a separate thread with reconnect loop.
+        # A dedicated event loop is created for this thread so that lark_oapi's
+        # module-level `loop = asyncio.get_event_loop()` picks up an idle loop
+        # instead of the already-running main asyncio loop, which would cause
+        # "This event loop is already running" errors.
         def run_ws():
-            while self._running:
-                try:
-                    self._ws_client.start()
-                except Exception as e:
-                    logger.warning("Feishu WebSocket error: {}", e)
-                    if self._running:
-                        import time; time.sleep(5)
+            import time
+            import lark_oapi.ws.client as _lark_ws_client
+            ws_loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(ws_loop)
+            # Patch the module-level loop used by lark's ws Client.start()
+            _lark_ws_client.loop = ws_loop
+            try:
+                while self._running:
+                    try:
+                        self._ws_client.start()
+                    except Exception as e:
+                        logger.warning("Feishu WebSocket error: {}", e)
+                        if self._running:
+                            time.sleep(5)
+            finally:
+                ws_loop.close()
 
         self._ws_thread = threading.Thread(target=run_ws, daemon=True)
         self._ws_thread.start()
@@ -326,17 +324,19 @@ class FeishuChannel(BaseChannel):
             await asyncio.sleep(1)
 
     async def stop(self) -> None:
-        """Stop the Feishu bot."""
+        """
+        Stop the Feishu bot.
+
+        Notice: lark.ws.Client does not expose stop method, simply exiting the program will close the client.
+
+        Reference: https://github.com/larksuite/oapi-sdk-python/blob/v2_main/lark_oapi/ws/client.py#L86
+        """
         self._running = False
-        if self._ws_client:
-            try:
-                self._ws_client.stop()
-            except Exception as e:
-                logger.warning("Error stopping WebSocket client: {}", e)
         logger.info("Feishu bot stopped")
 
     def _add_reaction_sync(self, message_id: str, emoji_type: str) -> None:
         """Sync helper for adding reaction (runs in thread pool)."""
+        from lark_oapi.api.im.v1 import CreateMessageReactionRequest, CreateMessageReactionRequestBody, Emoji
         try:
             request = CreateMessageReactionRequest.builder() \
                 .message_id(message_id) \
@@ -361,7 +361,7 @@ class FeishuChannel(BaseChannel):
 
         Common emoji types: THUMBSUP, OK, EYES, DONE, OnIt, HEART
         """
-        if not self._client or not Emoji:
+        if not self._client:
             return
 
         loop = asyncio.get_running_loop()
@@ -380,12 +380,13 @@ class FeishuChannel(BaseChannel):
     @staticmethod
     def _parse_md_table(table_text: str) -> dict | None:
         """Parse a markdown table into a Feishu table element."""
-        lines = [l.strip() for l in table_text.strip().split("\n") if l.strip()]
+        lines = [_line.strip() for _line in table_text.strip().split("\n") if _line.strip()]
         if len(lines) < 3:
             return None
-        split = lambda l: [c.strip() for c in l.strip("|").split("|")]
+        def split(_line: str) -> list[str]:
+            return [c.strip() for c in _line.strip("|").split("|")]
         headers = split(lines[0])
-        rows = [split(l) for l in lines[2:]]
+        rows = [split(_line) for _line in lines[2:]]
         columns = [{"tag": "column", "name": f"c{i}", "display_name": h, "width": "auto"}
                    for i, h in enumerate(headers)]
         return {
@@ -409,6 +410,34 @@ class FeishuChannel(BaseChannel):
             elements.extend(self._split_headings(remaining))
         return elements or [{"tag": "markdown", "content": content}]
 
+    @staticmethod
+    def _split_elements_by_table_limit(elements: list[dict], max_tables: int = 1) -> list[list[dict]]:
+        """Split card elements into groups with at most *max_tables* table elements each.
+
+        Feishu cards have a hard limit of one table per card (API error 11310).
+        When the rendered content contains multiple markdown tables each table is
+        placed in a separate card message so every table reaches the user.
+        """
+        if not elements:
+            return [[]]
+        groups: list[list[dict]] = []
+        current: list[dict] = []
+        table_count = 0
+        for el in elements:
+            if el.get("tag") == "table":
+                if table_count >= max_tables:
+                    if current:
+                        groups.append(current)
+                    current = []
+                    table_count = 0
+                current.append(el)
+                table_count += 1
+            else:
+                current.append(el)
+        if current:
+            groups.append(current)
+        return groups or [[]]
+
     def _split_headings(self, content: str) -> list[dict]:
         """Split content by headings, converting headings to div elements."""
         protected = content
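
Worked through on a small input, the splitter above turns a two-table element list into two card payloads, so each card stays under Feishu's one-table limit:

```python
elements = [
    {"tag": "markdown", "content": "intro"},
    {"tag": "table"},  # table payload elided for brevity
    {"tag": "table"},
]
groups = FeishuChannel._split_elements_by_table_limit(elements)
# -> [[markdown, table], [table]]: two groups, so two cards are sent
assert len(groups) == 2
```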
@@ -443,8 +472,124 @@ class FeishuChannel(BaseChannel):
 
         return elements or [{"tag": "markdown", "content": content}]
 
+    # ── Smart format detection ──────────────────────────────────────────
+    # Patterns that indicate "complex" markdown needing card rendering
+    _COMPLEX_MD_RE = re.compile(
+        r"```"                          # fenced code block
+        r"|^\|.+\|.*\n\s*\|[-:\s|]+\|"  # markdown table (header + separator)
+        r"|^#{1,6}\s+"                  # headings
+        , re.MULTILINE,
+    )
+
+    # Simple markdown patterns (bold, italic, strikethrough)
+    _SIMPLE_MD_RE = re.compile(
+        r"\*\*.+?\*\*"                            # **bold**
+        r"|__.+?__"                               # __bold__
+        r"|(?<!\*)\*(?!\*)(.+?)(?<!\*)\*(?!\*)"   # *italic* (single *)
+        r"|~~.+?~~"                               # ~~strikethrough~~
+        , re.DOTALL,
+    )
+
+    # Markdown link: [text](url)
+    _MD_LINK_RE = re.compile(r"\[([^\]]+)\]\((https?://[^\)]+)\)")
+
+    # Unordered list items
+    _LIST_RE = re.compile(r"^[\s]*[-*+]\s+", re.MULTILINE)
+
+    # Ordered list items
+    _OLIST_RE = re.compile(r"^[\s]*\d+\.\s+", re.MULTILINE)
+
+    # Max length for plain text format
+    _TEXT_MAX_LEN = 200
+
+    # Max length for post (rich text) format; beyond this, use card
+    _POST_MAX_LEN = 2000
+
+    @classmethod
+    def _detect_msg_format(cls, content: str) -> str:
+        """Determine the optimal Feishu message format for *content*.
+
+        Returns one of:
+        - ``"text"`` – plain text, short and no markdown
+        - ``"post"`` – rich text (links only, moderate length)
+        - ``"interactive"`` – card with full markdown rendering
+        """
+        stripped = content.strip()
+
+        # Complex markdown (code blocks, tables, headings) → always card
+        if cls._COMPLEX_MD_RE.search(stripped):
+            return "interactive"
+
+        # Long content → card (better readability with card layout)
+        if len(stripped) > cls._POST_MAX_LEN:
+            return "interactive"
+
+        # Has bold/italic/strikethrough → card (post format can't render these)
+        if cls._SIMPLE_MD_RE.search(stripped):
+            return "interactive"
+
+        # Has list items → card (post format can't render list bullets well)
+        if cls._LIST_RE.search(stripped) or cls._OLIST_RE.search(stripped):
+            return "interactive"
+
+        # Has links → post format (supports <a> tags)
+        if cls._MD_LINK_RE.search(stripped):
+            return "post"
+
+        # Short plain text → text format
+        if len(stripped) <= cls._TEXT_MAX_LEN:
+            return "text"
+
+        # Medium plain text without any formatting → post format
+        return "post"
+
+    @classmethod
+    def _markdown_to_post(cls, content: str) -> str:
+        """Convert markdown content to Feishu post message JSON.
+
+        Handles links ``[text](url)`` as ``a`` tags; everything else as ``text`` tags.
+        Each line becomes a paragraph (row) in the post body.
+        """
+        lines = content.strip().split("\n")
+        paragraphs: list[list[dict]] = []
+
+        for line in lines:
+            elements: list[dict] = []
+            last_end = 0
+
+            for m in cls._MD_LINK_RE.finditer(line):
+                # Text before this link
+                before = line[last_end:m.start()]
+                if before:
+                    elements.append({"tag": "text", "text": before})
+                elements.append({
+                    "tag": "a",
+                    "text": m.group(1),
+                    "href": m.group(2),
+                })
+                last_end = m.end()
+
+            # Remaining text after last link
+            remaining = line[last_end:]
+            if remaining:
+                elements.append({"tag": "text", "text": remaining})
+
+            # Empty line → empty paragraph for spacing
+            if not elements:
+                elements.append({"tag": "text", "text": ""})
+
+            paragraphs.append(elements)
+
+        post_body = {
+            "zh_cn": {
+                "content": paragraphs,
+            }
+        }
+        return json.dumps(post_body, ensure_ascii=False)
+
     _IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".ico", ".tiff", ".tif"}
     _AUDIO_EXTS = {".opus"}
+    _VIDEO_EXTS = {".mp4", ".mov", ".avi"}
     _FILE_TYPE_MAP = {
         ".opus": "opus", ".mp4": "mp4", ".pdf": "pdf", ".doc": "doc", ".docx": "doc",
         ".xls": "xls", ".xlsx": "xls", ".ppt": "ppt", ".pptx": "ppt",
@@ -452,6 +597,7 @@ class FeishuChannel(BaseChannel):
 
     def _upload_image_sync(self, file_path: str) -> str | None:
         """Upload an image to Feishu and return the image_key."""
+        from lark_oapi.api.im.v1 import CreateImageRequest, CreateImageRequestBody
         try:
             with open(file_path, "rb") as f:
                 request = CreateImageRequest.builder() \
@@ -475,6 +621,7 @@ class FeishuChannel(BaseChannel):
 
     def _upload_file_sync(self, file_path: str) -> str | None:
         """Upload a file to Feishu and return the file_key."""
+        from lark_oapi.api.im.v1 import CreateFileRequest, CreateFileRequestBody
         ext = os.path.splitext(file_path)[1].lower()
         file_type = self._FILE_TYPE_MAP.get(ext, "stream")
         file_name = os.path.basename(file_path)
@@ -502,6 +649,7 @@ class FeishuChannel(BaseChannel):
 
     def _download_image_sync(self, message_id: str, image_key: str) -> tuple[bytes | None, str | None]:
         """Download an image from Feishu message by message_id and image_key."""
+        from lark_oapi.api.im.v1 import GetMessageResourceRequest
         try:
             request = GetMessageResourceRequest.builder() \
                 .message_id(message_id) \
@@ -526,6 +674,13 @@ class FeishuChannel(BaseChannel):
         self, message_id: str, file_key: str, resource_type: str = "file"
     ) -> tuple[bytes | None, str | None]:
         """Download a file/audio/media from a Feishu message by message_id and file_key."""
+        from lark_oapi.api.im.v1 import GetMessageResourceRequest
+
+        # Feishu API only accepts 'image' or 'file' as type parameter
+        # Convert 'audio' to 'file' for API compatibility
+        if resource_type == "audio":
+            resource_type = "file"
+
         try:
             request = (
                 GetMessageResourceRequest.builder()
@@ -594,6 +749,7 @@ class FeishuChannel(BaseChannel):
 
     def _send_message_sync(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> bool:
         """Send a single message (text/image/file/interactive) synchronously."""
+        from lark_oapi.api.im.v1 import CreateMessageRequest, CreateMessageRequestBody
        try:
             request = CreateMessageRequest.builder() \
                 .receive_id_type(receive_id_type) \
@@ -642,18 +798,45 @@ class FeishuChannel(BaseChannel):
             else:
                 key = await loop.run_in_executor(None, self._upload_file_sync, file_path)
                 if key:
-                    media_type = "audio" if ext in self._AUDIO_EXTS else "file"
+                    # Use msg_type "media" for audio/video so users can play inline;
+                    # "file" for everything else (documents, archives, etc.)
+                    if ext in self._AUDIO_EXTS or ext in self._VIDEO_EXTS:
+                        media_type = "media"
+                    else:
+                        media_type = "file"
                     await loop.run_in_executor(
                         None, self._send_message_sync,
                         receive_id_type, msg.chat_id, media_type, json.dumps({"file_key": key}, ensure_ascii=False),
                     )
 
             if msg.content and msg.content.strip():
-                card = {"config": {"wide_screen_mode": True}, "elements": self._build_card_elements(msg.content)}
-                await loop.run_in_executor(
-                    None, self._send_message_sync,
-                    receive_id_type, msg.chat_id, "interactive", json.dumps(card, ensure_ascii=False),
-                )
+                fmt = self._detect_msg_format(msg.content)
+
+                if fmt == "text":
+                    # Short plain text – send as simple text message
+                    text_body = json.dumps({"text": msg.content.strip()}, ensure_ascii=False)
+                    await loop.run_in_executor(
+                        None, self._send_message_sync,
+                        receive_id_type, msg.chat_id, "text", text_body,
+                    )
+
+                elif fmt == "post":
+                    # Medium content with links – send as rich-text post
+                    post_body = self._markdown_to_post(msg.content)
+                    await loop.run_in_executor(
+                        None, self._send_message_sync,
+                        receive_id_type, msg.chat_id, "post", post_body,
+                    )
+
+                else:
+                    # Complex / long content – send as interactive card
+                    elements = self._build_card_elements(msg.content)
+                    for chunk in self._split_elements_by_table_limit(elements):
+                        card = {"config": {"wide_screen_mode": True}, "elements": chunk}
+                        await loop.run_in_executor(
+                            None, self._send_message_sync,
+                            receive_id_type, msg.chat_id, "interactive", json.dumps(card, ensure_ascii=False),
+                        )
 
         except Exception as e:
             logger.error("Error sending Feishu message: {}", e)
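
A few representative inputs and the route the detector above picks for them (a sketch; the comments trace the matching rule):

```python
FeishuChannel._detect_msg_format("ok, done")
# -> "text": short, no markdown of any kind

FeishuChannel._detect_msg_format("see [docs](https://example.com)")
# -> "post": only a link, which post format renders as an <a> tag

FeishuChannel._detect_msg_format("# Report\n| a | b |\n|---|---|")
# -> "interactive": heading and table match _COMPLEX_MD_RE, so card
```

The ordering matters: complex markdown and length checks win over the link check, so a long linked document still renders as a card rather than a degraded post.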
@@ -149,6 +149,16 @@ class ChannelManager:
             except ImportError as e:
                 logger.warning("Matrix channel not available: {}", e)
 
+        self._validate_allow_from()
+
+    def _validate_allow_from(self) -> None:
+        for name, ch in self.channels.items():
+            if getattr(ch.config, "allow_from", None) == []:
+                raise SystemExit(
+                    f'Error: "{name}" has empty allowFrom (denies all). '
+                    f'Set ["*"] to allow everyone, or add specific user IDs.'
+                )
+
     async def _start_channel(self, name: str, channel: BaseChannel) -> None:
         """Start a channel and log any exceptions."""
         try:
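
Together with the `is_allowed` change earlier in this commit, this makes the deny-all default fail fast at startup instead of silently dropping every message. A config sketch (the `channels` key and per-channel schema are assumed from the error message, not confirmed by this diff):

```json
{
  "channels": {
    "discord": {"allowFrom": ["*"]},
    "dingtalk": {"allowFrom": ["staff-id-1", "staff-id-2"]}
  }
}
```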
|
|||||||
@@ -12,10 +12,22 @@ try:
     import nh3
     from mistune import create_markdown
     from nio import (
-        AsyncClient, AsyncClientConfig, ContentRepositoryConfigError,
-        DownloadError, InviteEvent, JoinError, MatrixRoom, MemoryDownloadResponse,
-        RoomEncryptedMedia, RoomMessage, RoomMessageMedia, RoomMessageText,
-        RoomSendError, RoomTypingError, SyncError, UploadError,
+        AsyncClient,
+        AsyncClientConfig,
+        ContentRepositoryConfigError,
+        DownloadError,
+        InviteEvent,
+        JoinError,
+        MatrixRoom,
+        MemoryDownloadResponse,
+        RoomEncryptedMedia,
+        RoomMessage,
+        RoomMessageMedia,
+        RoomMessageText,
+        RoomSendError,
+        RoomTypingError,
+        SyncError,
+        UploadError,
     )
     from nio.crypto.attachments import decrypt_attachment
     from nio.exceptions import EncryptionError
@@ -350,7 +362,11 @@ class MatrixChannel(BaseChannel):
         limit_bytes = await self._effective_media_limit_bytes()
         for path in candidates:
             if fail := await self._upload_and_send_attachment(
-                msg.chat_id, path, limit_bytes, relates_to):
+                room_id=msg.chat_id,
+                path=path,
+                limit_bytes=limit_bytes,
+                relates_to=relates_to,
+            ):
                 failures.append(fail)
         if failures:
             text = f"{text.rstrip()}\n{chr(10).join(failures)}" if text.strip() else "\n".join(failures)
@@ -438,8 +454,7 @@ class MatrixChannel(BaseChannel):
             await asyncio.sleep(2)
 
     async def _on_room_invite(self, room: MatrixRoom, event: InviteEvent) -> None:
-        allow_from = self.config.allow_from or []
-        if not allow_from or event.sender in allow_from:
+        if self.is_allowed(event.sender):
             await self.client.join(room.room_id)
 
     def _is_direct_room(self, room: MatrixRoom) -> bool:
@@ -664,11 +679,13 @@ class MatrixChannel(BaseChannel):
         parts: list[str] = []
         if isinstance(body := getattr(event, "body", None), str) and body.strip():
             parts.append(body.strip())
-        parts.append(marker)
+        if marker:
+            parts.append(marker)
 
         await self._start_typing_keepalive(room.room_id)
         try:
             meta = self._base_metadata(room, event)
+            meta["attachments"] = []
             if attachment:
                 meta["attachments"] = [attachment]
             await self._handle_message(
@@ -31,7 +31,8 @@ def _make_bot_class(channel: "QQChannel") -> "type[botpy.Client]":
 
     class _Bot(botpy.Client):
         def __init__(self):
-            super().__init__(intents=intents)
+            # Disable botpy's file log — nanobot uses loguru; default "botpy.log" fails on read-only fs
+            super().__init__(intents=intents, ext_handlers=False)
 
         async def on_ready(self):
             logger.info("QQ bot ready: {}", self.robot.name)
@@ -55,6 +56,7 @@ class QQChannel(BaseChannel):
         self.config: QQConfig = config
         self._client: "botpy.Client | None" = None
         self._processed_ids: deque = deque(maxlen=1000)
+        self._msg_seq: int = 1  # Message sequence number; avoids QQ API deduplication
 
     async def start(self) -> None:
         """Start the QQ bot."""
@@ -101,11 +103,13 @@ class QQChannel(BaseChannel):
             return
         try:
             msg_id = msg.metadata.get("message_id")
+            self._msg_seq += 1  # Increment the sequence number
             await self._client.api.post_c2c_message(
                 openid=msg.chat_id,
                 msg_type=0,
                 content=msg.content,
                 msg_id=msg_id,
+                msg_seq=self._msg_seq,  # Attach the sequence number to avoid deduplication
             )
         except Exception as e:
             logger.error("Error sending QQ message: {}", e)
@@ -132,3 +136,4 @@ class QQChannel(BaseChannel):
             )
         except Exception:
             logger.exception("Error handling QQ message")
+
@@ -5,11 +5,10 @@ import re
 from typing import Any
 
 from loguru import logger
-from slack_sdk.socket_mode.websockets import SocketModeClient
 from slack_sdk.socket_mode.request import SocketModeRequest
 from slack_sdk.socket_mode.response import SocketModeResponse
+from slack_sdk.socket_mode.websockets import SocketModeClient
 from slack_sdk.web.async_client import AsyncWebClient
 
 from slackify_markdown import slackify_markdown
 
 from nanobot.bus.events import OutboundMessage
@@ -4,15 +4,19 @@ from __future__ import annotations
 
 import asyncio
 import re
 
 from loguru import logger
-from telegram import BotCommand, Update, ReplyParameters
-from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
+from telegram import BotCommand, ReplyParameters, Update
+from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
 from telegram.request import HTTPXRequest
 
 from nanobot.bus.events import OutboundMessage
 from nanobot.bus.queue import MessageBus
 from nanobot.channels.base import BaseChannel
 from nanobot.config.schema import TelegramConfig
+from nanobot.utils.helpers import split_message
 
+TELEGRAM_MAX_MESSAGE_LEN = 4000  # Telegram message character limit
+
+
 def _markdown_to_telegram_html(text: str) -> str:
@@ -78,26 +82,6 @@ def _markdown_to_telegram_html(text: str) -> str:
     return text
 
 
-def _split_message(content: str, max_len: int = 4000) -> list[str]:
-    """Split content into chunks within max_len, preferring line breaks."""
-    if len(content) <= max_len:
-        return [content]
-    chunks: list[str] = []
-    while content:
-        if len(content) <= max_len:
-            chunks.append(content)
-            break
-        cut = content[:max_len]
-        pos = cut.rfind('\n')
-        if pos == -1:
-            pos = cut.rfind(' ')
-        if pos == -1:
-            pos = max_len
-        chunks.append(content[:pos])
-        content = content[pos:].lstrip()
-    return chunks
-
-
 class TelegramChannel(BaseChannel):
     """
     Telegram channel using long polling.
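The module-level `_split_message` deleted here presumably moved to `nanobot.utils.helpers.split_message`, which the previous hunk imports and calls with `TELEGRAM_MAX_MESSAGE_LEN`. A sketch consistent with the removed implementation, with the limit now passed in rather than defaulted:

```python
def split_message(content: str, max_len: int) -> list[str]:
    """Split content into chunks within max_len, preferring line breaks."""
    chunks: list[str] = []
    while len(content) > max_len:
        cut = content[:max_len]
        # Prefer breaking at a newline, then a space, else hard-cut at the limit.
        pos = cut.rfind("\n")
        if pos == -1:
            pos = cut.rfind(" ")
        if pos == -1:
            pos = max_len
        chunks.append(content[:pos])
        content = content[pos:].lstrip()
    chunks.append(content)
    return chunks
```

Sharing one helper lets every channel reuse the same chunking while passing its own platform limit.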
@@ -224,7 +208,9 @@ class TelegramChannel(BaseChannel):
             logger.warning("Telegram bot not running")
             return
 
-        self._stop_typing(msg.chat_id)
+        # Only stop typing indicator for final responses
+        if not msg.metadata.get("_progress", False):
+            self._stop_typing(msg.chat_id)
 
         try:
             chat_id = int(msg.chat_id)
@@ -268,23 +254,41 @@ class TelegramChannel(BaseChannel):
 
             # Send text content
             if msg.content and msg.content != "[empty message]":
-                for chunk in _split_message(msg.content):
+                is_progress = msg.metadata.get("_progress", False)
+                draft_id = msg.metadata.get("message_id")
+
+                for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN):
                     try:
                         html = _markdown_to_telegram_html(chunk)
-                        await self._app.bot.send_message(
-                            chat_id=chat_id,
-                            text=html,
-                            parse_mode="HTML",
-                            reply_parameters=reply_params
-                        )
+                        if is_progress and draft_id:
+                            await self._app.bot.send_message_draft(
+                                chat_id=chat_id,
+                                draft_id=draft_id,
+                                text=html,
+                                parse_mode="HTML"
+                            )
+                        else:
+                            await self._app.bot.send_message(
+                                chat_id=chat_id,
+                                text=html,
+                                parse_mode="HTML",
+                                reply_parameters=reply_params
+                            )
                     except Exception as e:
                         logger.warning("HTML parse failed, falling back to plain text: {}", e)
                         try:
-                            await self._app.bot.send_message(
-                                chat_id=chat_id,
-                                text=chunk,
-                                reply_parameters=reply_params
-                            )
+                            if is_progress and draft_id:
+                                await self._app.bot.send_message_draft(
+                                    chat_id=chat_id,
+                                    draft_id=draft_id,
+                                    text=chunk
+                                )
+                            else:
+                                await self._app.bot.send_message(
+                                    chat_id=chat_id,
+                                    text=chunk,
+                                    reply_parameters=reply_params
+                                )
                         except Exception as e2:
                             logger.error("Error sending Telegram message: {}", e2)
 
@@ -2,7 +2,7 @@
 
 import asyncio
 import json
-from typing import Any
+from collections import OrderedDict
 
 from loguru import logger
 
@@ -27,6 +27,7 @@ class WhatsAppChannel(BaseChannel):
         self.config: WhatsAppConfig = config
         self._ws = None
         self._connected = False
+        self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
 
     async def start(self) -> None:
         """Start the WhatsApp channel by connecting to the bridge."""
@@ -108,6 +109,14 @@ class WhatsAppChannel(BaseChannel):
         # New LID sytle typically:
         sender = data.get("sender", "")
         content = data.get("content", "")
+        message_id = data.get("id", "")
+
+        if message_id:
+            if message_id in self._processed_message_ids:
+                return
+            self._processed_message_ids[message_id] = None
+            while len(self._processed_message_ids) > 1000:
+                self._processed_message_ids.popitem(last=False)
 
         # Extract just the phone number or lid as chat_id
         user_id = pn if pn else sender
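The `OrderedDict` here acts as an insertion-ordered set with O(1) membership tests and FIFO eviction. The same pattern in isolation (a generic sketch, not repo code):

```python
from collections import OrderedDict

class RecentIds:
    """Remember the last `capacity` message IDs for deduplication."""

    def __init__(self, capacity: int = 1000):
        self._seen: OrderedDict[str, None] = OrderedDict()
        self._capacity = capacity

    def seen_before(self, msg_id: str) -> bool:
        if msg_id in self._seen:
            return True
        self._seen[msg_id] = None
        while len(self._seen) > self._capacity:
            self._seen.popitem(last=False)  # evict the oldest insertion
        return False
```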
@@ -124,7 +133,7 @@ class WhatsAppChannel(BaseChannel):
             chat_id=sender,  # Use full LID for replies
             content=content,
             metadata={
-                "message_id": data.get("id"),
+                "message_id": message_id,
                 "timestamp": data.get("timestamp"),
                 "is_group": data.get("isGroup", False)
             }
@@ -2,23 +2,34 @@
 
 import asyncio
 import os
-import signal
-from pathlib import Path
 import select
+import signal
 import sys
+from pathlib import Path
+
+# Force UTF-8 encoding for Windows console
+if sys.platform == "win32":
+    import locale
+    if sys.stdout.encoding != "utf-8":
+        os.environ["PYTHONIOENCODING"] = "utf-8"
+        # Re-open stdout/stderr with UTF-8 encoding
+        try:
+            sys.stdout.reconfigure(encoding="utf-8", errors="replace")
+            sys.stderr.reconfigure(encoding="utf-8", errors="replace")
+        except Exception:
+            pass
 
 import typer
+from prompt_toolkit import PromptSession
+from prompt_toolkit.formatted_text import HTML
+from prompt_toolkit.history import FileHistory
+from prompt_toolkit.patch_stdout import patch_stdout
 from rich.console import Console
 from rich.markdown import Markdown
 from rich.table import Table
 from rich.text import Text
 
-from prompt_toolkit import PromptSession
-from prompt_toolkit.formatted_text import HTML
-from prompt_toolkit.history import FileHistory
-from prompt_toolkit.patch_stdout import patch_stdout
-
-from nanobot import __version__, __logo__
+from nanobot import __logo__, __version__
 from nanobot.config.schema import Config
 from nanobot.utils.helpers import sync_workspace_templates
 
@@ -201,9 +212,7 @@ def onboard():
 
 def _make_provider(config: Config):
     """Create the appropriate LLM provider from config."""
-    from nanobot.providers.litellm_provider import LiteLLMProvider
     from nanobot.providers.openai_codex_provider import OpenAICodexProvider
-    from nanobot.providers.custom_provider import CustomProvider
 
     model = config.agents.defaults.model
     provider_name = config.get_provider_name(model)
@@ -214,6 +223,7 @@ def _make_provider(config: Config):
         return OpenAICodexProvider(default_model=model)
 
     # Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
+    from nanobot.providers.custom_provider import CustomProvider
     if provider_name == "custom":
         return CustomProvider(
             api_key=p.api_key if p else "no-key",
@@ -221,6 +231,7 @@ def _make_provider(config: Config):
             default_model=model,
         )
 
+    from nanobot.providers.litellm_provider import LiteLLMProvider
    from nanobot.providers.registry import find_by_name
     spec = find_by_name(provider_name)
     if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and spec.is_oauth):
@@ -284,7 +295,9 @@ def serve(
         max_tokens=config.agents.defaults.max_tokens,
         max_iterations=config.agents.defaults.max_tool_iterations,
         memory_window=config.agents.defaults.memory_window,
+        reasoning_effort=config.agents.defaults.reasoning_effort,
         brave_api_key=config.tools.web.search.api_key or None,
+        web_proxy=config.tools.web.proxy or None,
         exec_config=config.tools.exec,
         restrict_to_workspace=config.tools.restrict_to_workspace,
         session_manager=session_manager,
@@ -322,32 +335,38 @@ def serve(
 @app.command()
 def gateway(
     port: int = typer.Option(18790, "--port", "-p", help="Gateway port"),
+    workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
+    config: str | None = typer.Option(None, "--config", "-c", help="Config file path"),
     verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"),
 ):
     """Start the nanobot gateway."""
-    from nanobot.config.loader import load_config, get_data_dir
-    from nanobot.bus.queue import MessageBus
     from nanobot.agent.loop import AgentLoop
+    from nanobot.bus.queue import MessageBus
     from nanobot.channels.manager import ChannelManager
-    from nanobot.session.manager import SessionManager
+    from nanobot.config.loader import load_config
     from nanobot.cron.service import CronService
     from nanobot.cron.types import CronJob
     from nanobot.heartbeat.service import HeartbeatService
+    from nanobot.session.manager import SessionManager
 
     if verbose:
         import logging
         logging.basicConfig(level=logging.DEBUG)
 
-    console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
-    config = load_config()
+    config_path = Path(config) if config else None
+    config = load_config(config_path)
+    if workspace:
+        config.agents.defaults.workspace = workspace
+
+    console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
     sync_workspace_templates(config.workspace_path)
     bus = MessageBus()
     provider = _make_provider(config)
     session_manager = SessionManager(config.workspace_path)
 
     # Create cron service first (callback set after agent creation)
-    cron_store_path = get_data_dir() / "cron" / "jobs.json"
+    # Use workspace path for per-instance cron store
+    cron_store_path = config.workspace_path / "cron" / "jobs.json"
     cron = CronService(cron_store_path)
 
     # Create agent with cron service
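With the two new options, separate gateway instances can run against separate workspaces and configs, and because the cron store moved from the shared data dir into `<workspace>/cron/jobs.json`, each instance now owns its scheduled jobs. A hypothetical invocation (assuming the console script is `nanobot`; paths are illustrative):

```
nanobot gateway --port 18790 --workspace ~/bots/alpha --config ~/bots/alpha/config.json
```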
@@ -360,7 +379,9 @@ def gateway(
         max_tokens=config.agents.defaults.max_tokens,
         max_iterations=config.agents.defaults.max_tool_iterations,
         memory_window=config.agents.defaults.memory_window,
+        reasoning_effort=config.agents.defaults.reasoning_effort,
         brave_api_key=config.tools.web.search.api_key or None,
+        web_proxy=config.tools.web.proxy or None,
         exec_config=config.tools.exec,
         cron_service=cron,
         restrict_to_workspace=config.tools.restrict_to_workspace,
@@ -372,18 +393,40 @@ def gateway(
     # Set cron callback (needs agent)
     async def on_cron_job(job: CronJob) -> str | None:
         """Execute a cron job through the agent."""
-        response = await agent.process_direct(
-            job.payload.message,
-            session_key=f"cron:{job.id}",
-            channel=job.payload.channel or "cli",
-            chat_id=job.payload.to or "direct",
+        from nanobot.agent.tools.cron import CronTool
+        from nanobot.agent.tools.message import MessageTool
+        reminder_note = (
+            "[Scheduled Task] Timer finished.\n\n"
+            f"Task '{job.name}' has been triggered.\n"
+            f"Scheduled instruction: {job.payload.message}"
         )
-        if job.payload.deliver and job.payload.to:
+
+        # Prevent the agent from scheduling new cron jobs during execution
+        cron_tool = agent.tools.get("cron")
+        cron_token = None
+        if isinstance(cron_tool, CronTool):
+            cron_token = cron_tool.set_cron_context(True)
+        try:
+            response = await agent.process_direct(
+                reminder_note,
+                session_key=f"cron:{job.id}",
+                channel=job.payload.channel or "cli",
+                chat_id=job.payload.to or "direct",
+            )
+        finally:
+            if isinstance(cron_tool, CronTool) and cron_token is not None:
+                cron_tool.reset_cron_context(cron_token)
+
+        message_tool = agent.tools.get("message")
+        if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
+            return response
+
+        if job.payload.deliver and job.payload.to and response:
             from nanobot.bus.events import OutboundMessage
             await bus.publish_outbound(OutboundMessage(
                 channel=job.payload.channel or "cli",
                 chat_id=job.payload.to,
-                content=response or ""
+                content=response
             ))
         return response
     cron.on_job = on_cron_job
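`set_cron_context` / `reset_cron_context` are defined on `CronTool` and are not shown in this diff. The token-based set-then-reset-in-`finally` shape mirrors `contextvars`; a hypothetical sketch under that assumption:

```python
import contextvars

# True while the agent is executing inside a scheduled job.
_in_cron_job: contextvars.ContextVar[bool] = contextvars.ContextVar("in_cron_job", default=False)

class CronTool:
    def set_cron_context(self, value: bool) -> contextvars.Token:
        # Return a token so the caller can restore the previous value in `finally`.
        return _in_cron_job.set(value)

    def reset_cron_context(self, token: contextvars.Token) -> None:
        _in_cron_job.reset(token)

    async def execute(self, action: str, **kwargs) -> str:
        # Hypothetical guard: refuse to create jobs while running a job.
        if _in_cron_job.get() and action == "add":
            return "Error: scheduling new cron jobs from within a cron job is disabled."
        ...
```

Token-based restore is safer than flipping a plain boolean because nested or concurrent callers each restore exactly the state they observed.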
@@ -488,12 +531,13 @@ def agent(
     logs: bool = typer.Option(False, "--logs/--no-logs", help="Show nanobot runtime logs during chat"),
 ):
     """Interact with the agent directly."""
-    from nanobot.config.loader import load_config, get_data_dir
-    from nanobot.bus.queue import MessageBus
-    from nanobot.agent.loop import AgentLoop
-    from nanobot.cron.service import CronService
     from loguru import logger
 
+    from nanobot.agent.loop import AgentLoop
+    from nanobot.bus.queue import MessageBus
+    from nanobot.config.loader import get_data_dir, load_config
+    from nanobot.cron.service import CronService
+
     config = load_config()
     sync_workspace_templates(config.workspace_path)
 
@@ -518,7 +562,9 @@ def agent(
         max_tokens=config.agents.defaults.max_tokens,
         max_iterations=config.agents.defaults.max_tool_iterations,
         memory_window=config.agents.defaults.memory_window,
+        reasoning_effort=config.agents.defaults.reasoning_effort,
         brave_api_key=config.tools.web.search.api_key or None,
+        web_proxy=config.tools.web.proxy or None,
         exec_config=config.tools.exec,
         cron_service=cron,
         restrict_to_workspace=config.tools.restrict_to_workspace,
@@ -562,12 +608,21 @@ def agent(
     else:
         cli_channel, cli_chat_id = "cli", session_id
 
-    def _exit_on_sigint(signum, frame):
+    def _handle_signal(signum, frame):
+        sig_name = signal.Signals(signum).name
         _restore_terminal()
-        console.print("\nGoodbye!")
-        os._exit(0)
+        console.print(f"\nReceived {sig_name}, goodbye!")
+        sys.exit(0)
 
-    signal.signal(signal.SIGINT, _exit_on_sigint)
+    signal.signal(signal.SIGINT, _handle_signal)
+    signal.signal(signal.SIGTERM, _handle_signal)
+    # SIGHUP is not available on Windows
+    if hasattr(signal, 'SIGHUP'):
+        signal.signal(signal.SIGHUP, _handle_signal)
+    # Ignore SIGPIPE to prevent silent process termination when writing to closed pipes
+    # SIGPIPE is not available on Windows
+    if hasattr(signal, 'SIGPIPE'):
+        signal.signal(signal.SIGPIPE, signal.SIG_IGN)
 
     async def run_interactive():
         bus_task = asyncio.create_task(agent_loop.run())
@@ -812,6 +867,7 @@ def _get_bridge_dir() -> Path:
 def channels_login():
     """Link device via QR code."""
     import subprocess
+
     from nanobot.config.loader import load_config
 
     config = load_config()
@@ -832,218 +888,6 @@ def channels_login():
         console.print("[red]npm not found. Please install Node.js.[/red]")
-
-
-# ============================================================================
-# Cron Commands
-# ============================================================================
-
-cron_app = typer.Typer(help="Manage scheduled tasks")
-app.add_typer(cron_app, name="cron")
-
-
-@cron_app.command("list")
-def cron_list(
-    all: bool = typer.Option(False, "--all", "-a", help="Include disabled jobs"),
-):
-    """List scheduled jobs."""
-    from nanobot.config.loader import get_data_dir
-    from nanobot.cron.service import CronService
-
-    store_path = get_data_dir() / "cron" / "jobs.json"
-    service = CronService(store_path)
-
-    jobs = service.list_jobs(include_disabled=all)
-
-    if not jobs:
-        console.print("No scheduled jobs.")
-        return
-
-    table = Table(title="Scheduled Jobs")
-    table.add_column("ID", style="cyan")
-    table.add_column("Name")
-    table.add_column("Schedule")
-    table.add_column("Status")
-    table.add_column("Next Run")
-
-    import time
-    from datetime import datetime as _dt
-    from zoneinfo import ZoneInfo
-    for job in jobs:
-        # Format schedule
-        if job.schedule.kind == "every":
-            sched = f"every {(job.schedule.every_ms or 0) // 1000}s"
-        elif job.schedule.kind == "cron":
-            sched = f"{job.schedule.expr or ''} ({job.schedule.tz})" if job.schedule.tz else (job.schedule.expr or "")
-        else:
-            sched = "one-time"
-
-        # Format next run
-        next_run = ""
-        if job.state.next_run_at_ms:
-            ts = job.state.next_run_at_ms / 1000
-            try:
-                tz = ZoneInfo(job.schedule.tz) if job.schedule.tz else None
-                next_run = _dt.fromtimestamp(ts, tz).strftime("%Y-%m-%d %H:%M")
-            except Exception:
-                next_run = time.strftime("%Y-%m-%d %H:%M", time.localtime(ts))
-
-        status = "[green]enabled[/green]" if job.enabled else "[dim]disabled[/dim]"
-
-        table.add_row(job.id, job.name, sched, status, next_run)
-
-    console.print(table)
-
-
-@cron_app.command("add")
-def cron_add(
-    name: str = typer.Option(..., "--name", "-n", help="Job name"),
-    message: str = typer.Option(..., "--message", "-m", help="Message for agent"),
-    every: int = typer.Option(None, "--every", "-e", help="Run every N seconds"),
-    cron_expr: str = typer.Option(None, "--cron", "-c", help="Cron expression (e.g. '0 9 * * *')"),
-    tz: str | None = typer.Option(None, "--tz", help="IANA timezone for cron (e.g. 'America/Vancouver')"),
-    at: str = typer.Option(None, "--at", help="Run once at time (ISO format)"),
-    deliver: bool = typer.Option(False, "--deliver", "-d", help="Deliver response to channel"),
-    to: str = typer.Option(None, "--to", help="Recipient for delivery"),
-    channel: str = typer.Option(None, "--channel", help="Channel for delivery (e.g. 'telegram', 'whatsapp')"),
-):
-    """Add a scheduled job."""
-    from nanobot.config.loader import get_data_dir
-    from nanobot.cron.service import CronService
-    from nanobot.cron.types import CronSchedule
-
-    if tz and not cron_expr:
-        console.print("[red]Error: --tz can only be used with --cron[/red]")
-        raise typer.Exit(1)
-
-    # Determine schedule type
-    if every:
-        schedule = CronSchedule(kind="every", every_ms=every * 1000)
-    elif cron_expr:
-        schedule = CronSchedule(kind="cron", expr=cron_expr, tz=tz)
-    elif at:
-        import datetime
-        dt = datetime.datetime.fromisoformat(at)
-        schedule = CronSchedule(kind="at", at_ms=int(dt.timestamp() * 1000))
-    else:
-        console.print("[red]Error: Must specify --every, --cron, or --at[/red]")
-        raise typer.Exit(1)
-
-    store_path = get_data_dir() / "cron" / "jobs.json"
-    service = CronService(store_path)
-
-    try:
-        job = service.add_job(
-            name=name,
-            schedule=schedule,
-            message=message,
-            deliver=deliver,
-            to=to,
-            channel=channel,
-        )
-    except ValueError as e:
-        console.print(f"[red]Error: {e}[/red]")
-        raise typer.Exit(1) from e
-
-    console.print(f"[green]✓[/green] Added job '{job.name}' ({job.id})")
-
-
-@cron_app.command("remove")
-def cron_remove(
-    job_id: str = typer.Argument(..., help="Job ID to remove"),
-):
-    """Remove a scheduled job."""
-    from nanobot.config.loader import get_data_dir
-    from nanobot.cron.service import CronService
-
-    store_path = get_data_dir() / "cron" / "jobs.json"
-    service = CronService(store_path)
-
-    if service.remove_job(job_id):
-        console.print(f"[green]✓[/green] Removed job {job_id}")
-    else:
-        console.print(f"[red]Job {job_id} not found[/red]")
-
-
-@cron_app.command("enable")
-def cron_enable(
-    job_id: str = typer.Argument(..., help="Job ID"),
-    disable: bool = typer.Option(False, "--disable", help="Disable instead of enable"),
-):
-    """Enable or disable a job."""
-    from nanobot.config.loader import get_data_dir
-    from nanobot.cron.service import CronService
-
-    store_path = get_data_dir() / "cron" / "jobs.json"
-    service = CronService(store_path)
-
-    job = service.enable_job(job_id, enabled=not disable)
-    if job:
-        status = "disabled" if disable else "enabled"
-        console.print(f"[green]✓[/green] Job '{job.name}' {status}")
-    else:
-        console.print(f"[red]Job {job_id} not found[/red]")
-
-
-@cron_app.command("run")
-def cron_run(
-    job_id: str = typer.Argument(..., help="Job ID to run"),
-    force: bool = typer.Option(False, "--force", "-f", help="Run even if disabled"),
-):
-    """Manually run a job."""
-    from loguru import logger
-    from nanobot.config.loader import load_config, get_data_dir
-    from nanobot.cron.service import CronService
-    from nanobot.cron.types import CronJob
-    from nanobot.bus.queue import MessageBus
-    from nanobot.agent.loop import AgentLoop
-    logger.disable("nanobot")
-
-    config = load_config()
-    provider = _make_provider(config)
-    bus = MessageBus()
-    agent_loop = AgentLoop(
-        bus=bus,
-        provider=provider,
-        workspace=config.workspace_path,
-        model=config.agents.defaults.model,
-        temperature=config.agents.defaults.temperature,
-        max_tokens=config.agents.defaults.max_tokens,
-        max_iterations=config.agents.defaults.max_tool_iterations,
-        memory_window=config.agents.defaults.memory_window,
-        brave_api_key=config.tools.web.search.api_key or None,
-        exec_config=config.tools.exec,
-        restrict_to_workspace=config.tools.restrict_to_workspace,
-        mcp_servers=config.tools.mcp_servers,
-        channels_config=config.channels,
-    )
-
-    store_path = get_data_dir() / "cron" / "jobs.json"
-    service = CronService(store_path)
-
-    result_holder = []
-
-    async def on_job(job: CronJob) -> str | None:
-        response = await agent_loop.process_direct(
-            job.payload.message,
-            session_key=f"cron:{job.id}",
-            channel=job.payload.channel or "cli",
-            chat_id=job.payload.to or "direct",
-        )
-        result_holder.append(response)
-        return response
-
-    service.on_job = on_job
-
-    async def run():
-        return await service.run_job(job_id, force=force)
-
-    if asyncio.run(run()):
-        console.print("[green]✓[/green] Job executed")
-        if result_holder:
-            _print_agent_response(result_holder[0], render_markdown=True)
-    else:
-        console.print(f"[red]Failed to run job {job_id}[/red]")
-
-
 # ============================================================================
 # Status Commands
 # ============================================================================
@@ -1052,7 +896,7 @@ def cron_run(
 @app.command()
 def status():
     """Show nanobot status."""
-    from nanobot.config.loader import load_config, get_config_path
+    from nanobot.config.loader import get_config_path, load_config
 
     config_path = get_config_path()
     config = load_config()
@@ -1,6 +1,6 @@
 """Configuration module for nanobot."""
 
-from nanobot.config.loader import load_config, get_config_path
+from nanobot.config.loader import get_config_path, load_config
 from nanobot.config.schema import Config
 
 __all__ = ["Config", "load_config", "get_config_path"]
@@ -3,7 +3,7 @@
 from pathlib import Path
 from typing import Literal
 
-from pydantic import BaseModel, Field, ConfigDict
+from pydantic import BaseModel, ConfigDict, Field
 from pydantic.alias_generators import to_camel
 from pydantic_settings import BaseSettings
 
@@ -29,7 +29,9 @@ class TelegramConfig(Base):
     enabled: bool = False
     token: str = ""  # Bot token from @BotFather
     allow_from: list[str] = Field(default_factory=list)  # Allowed user IDs or usernames
-    proxy: str | None = None  # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
+    proxy: str | None = (
+        None  # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
+    )
     reply_to_message: bool = False  # If true, bot replies quote the original message
 
 
@@ -42,7 +44,9 @@ class FeishuConfig(Base):
     encrypt_key: str = ""  # Encrypt Key for event subscription (optional)
     verification_token: str = ""  # Verification Token for event subscription (optional)
     allow_from: list[str] = Field(default_factory=list)  # Allowed user open_ids
-    react_emoji: str = "THUMBSUP"  # Emoji type for message reactions (e.g. THUMBSUP, OK, DONE, SMILE)
+    react_emoji: str = (
+        "THUMBSUP"  # Emoji type for message reactions (e.g. THUMBSUP, OK, DONE, SMILE)
+    )
 
 
 class DingTalkConfig(Base):
@@ -62,6 +66,7 @@ class DiscordConfig(Base):
     allow_from: list[str] = Field(default_factory=list)  # Allowed user IDs
     gateway_url: str = "wss://gateway.discord.gg/?v=10&encoding=json"
     intents: int = 37377  # GUILDS + GUILD_MESSAGES + DIRECT_MESSAGES + MESSAGE_CONTENT
+    group_policy: Literal["mention", "open"] = "mention"
 
 
 class MatrixConfig(Base):
@@ -72,9 +77,13 @@ class MatrixConfig(Base):
     access_token: str = ""
     user_id: str = ""  # @bot:matrix.org
     device_id: str = ""
     e2ee_enabled: bool = True  # Enable Matrix E2EE support (encryption + encrypted room handling).
-    sync_stop_grace_seconds: int = 2  # Max seconds to wait for sync_forever to stop gracefully before cancellation fallback.
-    max_media_bytes: int = 20 * 1024 * 1024  # Max attachment size accepted for Matrix media handling (inbound + outbound).
+    sync_stop_grace_seconds: int = (
+        2  # Max seconds to wait for sync_forever to stop gracefully before cancellation fallback.
+    )
+    max_media_bytes: int = (
+        20 * 1024 * 1024
+    )  # Max attachment size accepted for Matrix media handling (inbound + outbound).
     allow_from: list[str] = Field(default_factory=list)
     group_policy: Literal["open", "mention", "allowlist"] = "open"
     group_allow_from: list[str] = Field(default_factory=list)
@@ -105,7 +114,9 @@ class EmailConfig(Base):
     from_address: str = ""
 
     # Behavior
-    auto_reply_enabled: bool = True  # If false, inbound email is read but no automatic reply is sent
+    auto_reply_enabled: bool = (
+        True  # If false, inbound email is read but no automatic reply is sent
+    )
     poll_interval_seconds: int = 30
     mark_seen: bool = True
     max_body_chars: int = 12000
@@ -171,6 +182,7 @@ class SlackConfig(Base):
     user_token_read_only: bool = True
     reply_in_thread: bool = True
     react_emoji: str = "eyes"
+    allow_from: list[str] = Field(default_factory=list)  # Allowed Slack user IDs (sender-level)
     group_policy: str = "mention"  # "mention", "open", "allowlist"
     group_allow_from: list[str] = Field(default_factory=list)  # Allowed channel IDs if allowlist
     dm: SlackDMConfig = Field(default_factory=SlackDMConfig)
@@ -182,27 +194,17 @@ class QQConfig(Base):
     enabled: bool = False
     app_id: str = ""  # Bot ID (AppID) from q.qq.com
     secret: str = ""  # Bot secret (AppSecret) from q.qq.com
-    allow_from: list[str] = Field(default_factory=list)  # Allowed user openids (empty = public access)
+    allow_from: list[str] = Field(
+        default_factory=list
+    )  # Allowed user openids (empty = public access)
 
 
-class MatrixConfig(Base):
-    """Matrix (Element) channel configuration."""
-    enabled: bool = False
-    homeserver: str = "https://matrix.org"
-    access_token: str = ""
-    user_id: str = ""  # e.g. @bot:matrix.org
-    device_id: str = ""
-    e2ee_enabled: bool = True  # end-to-end encryption support
-    sync_stop_grace_seconds: int = 2  # graceful sync_forever shutdown timeout
-    max_media_bytes: int = 20 * 1024 * 1024  # inbound + outbound attachment limit
-    allow_from: list[str] = Field(default_factory=list)
-    group_policy: Literal["open", "mention", "allowlist"] = "open"
-    group_allow_from: list[str] = Field(default_factory=list)
-    allow_room_mentions: bool = False
-
-
 class ChannelsConfig(Base):
     """Configuration for chat channels."""
 
     send_progress: bool = True  # stream agent's text progress to the channel
     send_tool_hints: bool = False  # stream tool-call hints (e.g. read_file("…"))
     whatsapp: WhatsAppConfig = Field(default_factory=WhatsAppConfig)
     telegram: TelegramConfig = Field(default_factory=TelegramConfig)
@@ -221,11 +223,14 @@ class AgentDefaults(Base):
 
     workspace: str = "~/.nanobot/workspace"
     model: str = "anthropic/claude-opus-4-5"
-    provider: str = "auto"  # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
+    provider: str = (
+        "auto"  # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
+    )
     max_tokens: int = 8192
     temperature: float = 0.1
     max_tool_iterations: int = 40
     memory_window: int = 100
+    reasoning_effort: str | None = None  # low / medium / high — enables LLM thinking mode
 
 
 class AgentsConfig(Base):
@@ -258,8 +263,8 @@ class ProvidersConfig(Base):
     moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
     minimax: ProviderConfig = Field(default_factory=ProviderConfig)
     aihubmix: ProviderConfig = Field(default_factory=ProviderConfig)  # AiHubMix API gateway
-    siliconflow: ProviderConfig = Field(default_factory=ProviderConfig)  # SiliconFlow (硅基流动) API gateway
-    volcengine: ProviderConfig = Field(default_factory=ProviderConfig)  # VolcEngine (火山引擎) API gateway
+    siliconflow: ProviderConfig = Field(default_factory=ProviderConfig)  # SiliconFlow (硅基流动)
+    volcengine: ProviderConfig = Field(default_factory=ProviderConfig)  # VolcEngine (火山引擎)
     openai_codex: ProviderConfig = Field(default_factory=ProviderConfig)  # OpenAI Codex (OAuth)
     github_copilot: ProviderConfig = Field(default_factory=ProviderConfig)  # Github Copilot (OAuth)
 
@@ -289,6 +294,9 @@ class WebSearchConfig(Base):
 class WebToolsConfig(Base):
     """Web tools configuration."""
 
+    proxy: str | None = (
+        None  # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
+    )
     search: WebSearchConfig = Field(default_factory=WebSearchConfig)
 
 
@@ -302,12 +310,13 @@ class ExecToolConfig(Base):
 class MCPServerConfig(Base):
     """MCP server connection configuration (stdio or HTTP)."""
 
+    type: Literal["stdio", "sse", "streamableHttp"] | None = None  # auto-detected if omitted
     command: str = ""  # Stdio: command to run (e.g. "npx")
     args: list[str] = Field(default_factory=list)  # Stdio: command arguments
     env: dict[str, str] = Field(default_factory=dict)  # Stdio: extra env vars
-    url: str = ""  # HTTP: streamable HTTP endpoint URL
-    headers: dict[str, str] = Field(default_factory=dict)  # HTTP: Custom HTTP Headers
-    tool_timeout: int = 30  # Seconds before a tool call is cancelled
+    url: str = ""  # HTTP/SSE: endpoint URL
+    headers: dict[str, str] = Field(default_factory=dict)  # HTTP/SSE: custom headers
+    tool_timeout: int = 30  # seconds before a tool call is cancelled
 
 
 class ToolsConfig(Base):
@@ -333,7 +342,9 @@ class Config(BaseSettings):
         """Get expanded workspace path."""
         return Path(self.agents.defaults.workspace).expanduser()
 
-    def _match_provider(self, model: str | None = None) -> tuple["ProviderConfig | None", str | None]:
+    def _match_provider(
+        self, model: str | None = None
+    ) -> tuple["ProviderConfig | None", str | None]:
         """Match provider config and its registry name. Returns (config, spec_name)."""
         from nanobot.providers.registry import PROVIDERS
 
@@ -30,8 +30,9 @@ def _compute_next_run(schedule: CronSchedule, now_ms: int) -> int | None:
 
     if schedule.kind == "cron" and schedule.expr:
         try:
-            from croniter import croniter
             from zoneinfo import ZoneInfo
+
+            from croniter import croniter
             # Use caller-provided reference time for deterministic scheduling
             base_time = now_ms / 1000
             tz = ZoneInfo(schedule.tz) if schedule.tz else datetime.now().astimezone().tzinfo
@@ -68,13 +69,19 @@ class CronService:
         on_job: Callable[[CronJob], Coroutine[Any, Any, str | None]] | None = None
     ):
         self.store_path = store_path
-        self.on_job = on_job  # Callback to execute job, returns response text
+        self.on_job = on_job
         self._store: CronStore | None = None
+        self._last_mtime: float = 0.0
         self._timer_task: asyncio.Task | None = None
         self._running = False
 
     def _load_store(self) -> CronStore:
-        """Load jobs from disk."""
+        """Load jobs from disk. Reloads automatically if file was modified externally."""
+        if self._store and self.store_path.exists():
+            mtime = self.store_path.stat().st_mtime
+            if mtime != self._last_mtime:
+                logger.info("Cron: jobs.json modified externally, reloading")
+                self._store = None
        if self._store:
             return self._store
 
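The reload guard is plain mtime-based cache invalidation. Reduced to its core as a generic sketch (not repo code):

```python
import json
from pathlib import Path

class JsonStore:
    """Cache a JSON file; reread it whenever the file's mtime changes."""

    def __init__(self, path: Path):
        self.path = path
        self._data = None
        self._last_mtime = 0.0

    def load(self):
        if self._data is not None and self.path.exists():
            if self.path.stat().st_mtime != self._last_mtime:
                self._data = None  # invalidate: file changed externally
        if self._data is None and self.path.exists():
            self._data = json.loads(self.path.read_text(encoding="utf-8"))
            self._last_mtime = self.path.stat().st_mtime
        return self._data
```

The companion `_save_store` hunk below records the mtime right after writing, so the service's own writes do not trigger a spurious reload.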
@@ -163,6 +170,7 @@ class CronService:
         }
 
         self.store_path.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8")
+        self._last_mtime = self.store_path.stat().st_mtime
 
     async def start(self) -> None:
         """Start the cron service."""
@@ -218,6 +226,7 @@ class CronService:
 
     async def _on_timer(self) -> None:
         """Handle timer tick - run due jobs."""
+        self._load_store()
         if not self._store:
             return
 
@@ -21,6 +21,7 @@ class LLMResponse:
     finish_reason: str = "stop"
     usage: dict[str, int] = field(default_factory=dict)
     reasoning_content: str | None = None  # Kimi, DeepSeek-R1 etc.
+    thinking_blocks: list[dict] | None = None  # Anthropic extended thinking
 
     @property
     def has_tool_calls(self) -> bool:
@@ -77,6 +78,12 @@ class LLMProvider(ABC):
                 result.append(clean)
                 continue
 
+            if isinstance(content, dict):
+                clean = dict(msg)
+                clean["content"] = [content]
+                result.append(clean)
+                continue
+
             result.append(msg)
         return result
 
@@ -88,6 +95,7 @@ class LLMProvider(ABC):
         model: str | None = None,
         max_tokens: int = 4096,
         temperature: float = 0.7,
+        reasoning_effort: str | None = None,
     ) -> LLMResponse:
         """
         Send a chat completion request.
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import uuid
 from typing import Any
 
 import json_repair
|||||||
def __init__(self, api_key: str = "no-key", api_base: str = "http://localhost:8000/v1", default_model: str = "default"):
|
def __init__(self, api_key: str = "no-key", api_base: str = "http://localhost:8000/v1", default_model: str = "default"):
|
||||||
super().__init__(api_key, api_base)
|
super().__init__(api_key, api_base)
|
||||||
self.default_model = default_model
|
self.default_model = default_model
|
||||||
self._client = AsyncOpenAI(api_key=api_key, base_url=api_base)
|
# Keep affinity stable for this provider instance to improve backend cache locality.
|
||||||
|
self._client = AsyncOpenAI(
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=api_base,
|
||||||
|
default_headers={"x-session-affinity": uuid.uuid4().hex},
|
||||||
|
)
|
||||||
|
|
||||||
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
|
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
|
||||||
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7) -> LLMResponse:
|
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
|
||||||
|
reasoning_effort: str | None = None) -> LLMResponse:
|
||||||
kwargs: dict[str, Any] = {
|
kwargs: dict[str, Any] = {
|
||||||
"model": model or self.default_model,
|
"model": model or self.default_model,
|
||||||
"messages": self._sanitize_empty_content(messages),
|
"messages": self._sanitize_empty_content(messages),
|
||||||
"max_tokens": max(1, max_tokens),
|
"max_tokens": max(1, max_tokens),
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
}
|
}
|
||||||
|
if reasoning_effort:
|
||||||
|
kwargs["reasoning_effort"] = reasoning_effort
|
||||||
if tools:
|
if tools:
|
||||||
kwargs.update(tools=tools, tool_choice="auto")
|
kwargs.update(tools=tools, tool_choice="auto")
|
||||||
try:
|
try:
|
||||||
|
|||||||
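The per-instance affinity header above is what pins all requests from one CustomProvider instance together. As a hedged illustration (the header name `x-session-affinity` comes from the diff; a gateway actually honoring it is an assumption about the deployment):

```python
import uuid

from openai import AsyncOpenAI

# One affinity token per provider instance: every request carries the same
# header value, so an affinity-aware gateway can route them to the same
# replica and keep its prompt cache warm.
affinity = uuid.uuid4().hex
client = AsyncOpenAI(
    api_key="no-key",
    base_url="http://localhost:8000/v1",  # assumed local OpenAI-compatible server
    default_headers={"x-session-affinity": affinity},
)
```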
@@ -1,22 +1,21 @@
 """LiteLLM provider implementation for multi-provider support."""

-import json
-import json_repair
 import os
 import secrets
 import string
 from typing import Any

+import json_repair
 import litellm
 from litellm import acompletion
+from loguru import logger

 from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
 from nanobot.providers.registry import find_by_model, find_gateway

-# Standard OpenAI chat-completion message keys plus reasoning_content for
-# thinking-enabled models (Kimi k2.5, DeepSeek-R1, etc.).
+# Standard chat-completion message keys.
 _ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
+_ANTHROPIC_EXTRA_KEYS = frozenset({"thinking_blocks"})
 _ALNUM = string.ascii_letters + string.digits

 def _short_tool_id() -> str:
@@ -160,11 +159,20 @@ class LiteLLMProvider(LLMProvider):
         return

     @staticmethod
-    def _sanitize_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    def _extra_msg_keys(original_model: str, resolved_model: str) -> frozenset[str]:
+        """Return provider-specific extra keys to preserve in request messages."""
+        spec = find_by_model(original_model) or find_by_model(resolved_model)
+        if (spec and spec.name == "anthropic") or "claude" in original_model.lower() or resolved_model.startswith("anthropic/"):
+            return _ANTHROPIC_EXTRA_KEYS
+        return frozenset()
+
+    @staticmethod
+    def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
         """Strip non-standard keys and ensure assistant messages have a content key."""
+        allowed = _ALLOWED_MSG_KEYS | extra_keys
         sanitized = []
         for msg in messages:
-            clean = {k: v for k, v in msg.items() if k in _ALLOWED_MSG_KEYS}
+            clean = {k: v for k, v in msg.items() if k in allowed}
             # Strict providers require "content" even when assistant only has tool_calls
             if clean.get("role") == "assistant" and "content" not in clean:
                 clean["content"] = None
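A quick illustration of what the new `extra_keys` path preserves. The message shapes here are assumptions for the example, not the provider's exact payloads:

```python
msg = {
    "role": "assistant",
    "content": "done",
    "thinking_blocks": [{"type": "thinking", "thinking": "..."}],  # assumed block shape
    "internal_id": 42,  # never a valid request key; always stripped
}

allowed = {"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"}

# Default path: thinking_blocks is dropped with every other unknown key.
strict = {k: v for k, v in msg.items() if k in allowed}

# Anthropic path: the blocks survive sanitization, which extended-thinking
# models require when the conversation is replayed to the API.
anthropic = {k: v for k, v in msg.items() if k in allowed | {"thinking_blocks"}}
```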
@@ -178,6 +186,7 @@ class LiteLLMProvider(LLMProvider):
         model: str | None = None,
         max_tokens: int = 4096,
         temperature: float = 0.7,
+        reasoning_effort: str | None = None,
     ) -> LLMResponse:
         """
         Send a chat completion request via LiteLLM.
@@ -194,6 +203,7 @@ class LiteLLMProvider(LLMProvider):
         """
         original_model = model or self.default_model
         model = self._resolve_model(original_model)
+        extra_msg_keys = self._extra_msg_keys(original_model, model)

         if self._supports_cache_control(original_model):
             messages, tools = self._apply_cache_control(messages, tools)
@@ -204,7 +214,7 @@ class LiteLLMProvider(LLMProvider):

         kwargs: dict[str, Any] = {
             "model": model,
-            "messages": self._sanitize_messages(self._sanitize_empty_content(messages)),
+            "messages": self._sanitize_messages(self._sanitize_empty_content(messages), extra_keys=extra_msg_keys),
             "max_tokens": max_tokens,
             "temperature": temperature,
         }
@@ -224,6 +234,10 @@ class LiteLLMProvider(LLMProvider):
         if self.extra_headers:
             kwargs["extra_headers"] = self.extra_headers

+        if reasoning_effort:
+            kwargs["reasoning_effort"] = reasoning_effort
+            kwargs["drop_params"] = True
+
         if tools:
             kwargs["tools"] = tools
             kwargs["tool_choice"] = "auto"
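A hypothetical call site for the new parameter, assuming `provider` is an instance of this class and the model name is purely illustrative:

```python
async def ask(provider) -> None:
    # drop_params=True (set in the hunk above) lets LiteLLM silently discard
    # reasoning_effort for backends that do not accept it, instead of erroring.
    response = await provider.chat(
        messages=[{"role": "user", "content": "Summarize the tradeoffs."}],
        model="deepseek-reasoner",      # illustrative model name
        reasoning_effort="high",
    )
    print(response.reasoning_content)   # populated only by thinking-enabled models
```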
@@ -242,20 +256,37 @@ class LiteLLMProvider(LLMProvider):
         """Parse LiteLLM response into our standard format."""
         choice = response.choices[0]
         message = choice.message
+        content = message.content
+        finish_reason = choice.finish_reason
+
+        # Some providers (e.g. GitHub Copilot) split content and tool_calls
+        # across multiple choices. Merge them so tool_calls are not lost.
+        raw_tool_calls = []
+        for ch in response.choices:
+            msg = ch.message
+            if hasattr(msg, "tool_calls") and msg.tool_calls:
+                raw_tool_calls.extend(msg.tool_calls)
+            if ch.finish_reason in ("tool_calls", "stop"):
+                finish_reason = ch.finish_reason
+            if not content and msg.content:
+                content = msg.content
+
+        if len(response.choices) > 1:
+            logger.debug("LiteLLM response has {} choices, merged {} tool_calls",
+                         len(response.choices), len(raw_tool_calls))

         tool_calls = []
-        if hasattr(message, "tool_calls") and message.tool_calls:
-            for tc in message.tool_calls:
-                # Parse arguments from JSON string if needed
-                args = tc.function.arguments
-                if isinstance(args, str):
-                    args = json_repair.loads(args)
-
-                tool_calls.append(ToolCallRequest(
-                    id=_short_tool_id(),
-                    name=tc.function.name,
-                    arguments=args,
-                ))
+        for tc in raw_tool_calls:
+            # Parse arguments from JSON string if needed
+            args = tc.function.arguments
+            if isinstance(args, str):
+                args = json_repair.loads(args)
+
+            tool_calls.append(ToolCallRequest(
+                id=_short_tool_id(),
+                name=tc.function.name,
+                arguments=args,
+            ))

         usage = {}
         if hasattr(response, "usage") and response.usage:
@@ -266,13 +297,15 @@ class LiteLLMProvider(LLMProvider):
             }

         reasoning_content = getattr(message, "reasoning_content", None) or None
+        thinking_blocks = getattr(message, "thinking_blocks", None) or None

         return LLMResponse(
-            content=message.content,
+            content=content,
             tool_calls=tool_calls,
-            finish_reason=choice.finish_reason or "stop",
+            finish_reason=finish_reason or "stop",
             usage=usage,
             reasoning_content=reasoning_content,
+            thinking_blocks=thinking_blocks,
         )

     def get_default_model(self) -> str:
@@ -9,8 +9,8 @@ from typing import Any, AsyncGenerator

 import httpx
 from loguru import logger

 from oauth_cli_kit import get_token as get_codex_token

 from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest

 DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses"
@@ -31,6 +31,7 @@ class OpenAICodexProvider:
         model: str | None = None,
         max_tokens: int = 4096,
         temperature: float = 0.7,
+        reasoning_effort: str | None = None,
     ) -> LLMResponse:
         model = model or self.default_model
         system_prompt, input_items = _convert_messages(messages)
@@ -51,6 +52,9 @@ class OpenAICodexProvider:
             "parallel_tool_calls": True,
         }

+        if reasoning_effort:
+            body["reasoning"] = {"effort": reasoning_effort}
+
         if tools:
             body["tools"] = _convert_tools(tools)

@@ -26,33 +26,33 @@ class ProviderSpec:
     """

     # identity
     name: str                    # config field name, e.g. "dashscope"
     keywords: tuple[str, ...]    # model-name keywords for matching (lowercase)
     env_key: str                 # LiteLLM env var, e.g. "DASHSCOPE_API_KEY"
     display_name: str = ""       # shown in `nanobot status`

     # model prefixing
     litellm_prefix: str = ""     # "dashscope" → model becomes "dashscope/{model}"
     skip_prefixes: tuple[str, ...] = ()  # don't prefix if model already starts with these

     # extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),)
     env_extras: tuple[tuple[str, str], ...] = ()

     # gateway / local detection
     is_gateway: bool = False     # routes any model (OpenRouter, AiHubMix)
     is_local: bool = False       # local deployment (vLLM, Ollama)
     detect_by_key_prefix: str = ""    # match api_key prefix, e.g. "sk-or-"
     detect_by_base_keyword: str = ""  # match substring in api_base URL
     default_api_base: str = ""   # fallback base URL

     # gateway behavior
     strip_model_prefix: bool = False  # strip "provider/" before re-prefixing

     # per-model param overrides, e.g. (("kimi-k2.5", {"temperature": 1.0}),)
     model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()

     # OAuth-based providers (e.g., OpenAI Codex) don't use API keys
     is_oauth: bool = False       # if True, uses OAuth flow instead of API key

     # Direct providers bypass LiteLLM entirely (e.g., CustomProvider)
     is_direct: bool = False
@@ -70,7 +70,6 @@ class ProviderSpec:
 # ---------------------------------------------------------------------------

 PROVIDERS: tuple[ProviderSpec, ...] = (
-
     # === Custom (direct OpenAI-compatible endpoint, bypasses LiteLLM) ======
     ProviderSpec(
         name="custom",
@@ -80,17 +79,15 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         litellm_prefix="",
         is_direct=True,
     ),
-
     # === Gateways (detected by api_key / api_base, not model name) =========
     # Gateways can route any model, so they win in fallback.
-
     # OpenRouter: global gateway, keys start with "sk-or-"
     ProviderSpec(
         name="openrouter",
         keywords=("openrouter",),
         env_key="OPENROUTER_API_KEY",
         display_name="OpenRouter",
         litellm_prefix="openrouter",  # claude-3 → openrouter/claude-3
         skip_prefixes=(),
         env_extras=(),
         is_gateway=True,
@@ -102,16 +99,15 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         model_overrides=(),
         supports_prompt_caching=True,
     ),
-
     # AiHubMix: global gateway, OpenAI-compatible interface.
     # strip_model_prefix=True: it doesn't understand "anthropic/claude-3",
     # so we strip to bare "claude-3" then re-prefix as "openai/claude-3".
     ProviderSpec(
         name="aihubmix",
         keywords=("aihubmix",),
         env_key="OPENAI_API_KEY",  # OpenAI-compatible
         display_name="AiHubMix",
         litellm_prefix="openai",  # → openai/{model}
         skip_prefixes=(),
         env_extras=(),
         is_gateway=True,
@@ -119,10 +115,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         detect_by_key_prefix="",
         detect_by_base_keyword="aihubmix",
         default_api_base="https://aihubmix.com/v1",
         strip_model_prefix=True,  # anthropic/claude-3 → claude-3 → openai/claude-3
         model_overrides=(),
     ),
-
     # SiliconFlow (硅基流动): OpenAI-compatible gateway, model names keep org prefix
     ProviderSpec(
         name="siliconflow",
@@ -140,7 +135,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         strip_model_prefix=False,
         model_overrides=(),
     ),
-
     # VolcEngine (火山引擎): OpenAI-compatible gateway
     ProviderSpec(
         name="volcengine",
@@ -158,9 +152,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         strip_model_prefix=False,
         model_overrides=(),
     ),
-
     # === Standard providers (matched by model-name keywords) ===============
-
     # Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
     ProviderSpec(
         name="anthropic",
@@ -179,7 +171,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         model_overrides=(),
         supports_prompt_caching=True,
     ),
-
     # OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed.
     ProviderSpec(
         name="openai",
@@ -197,14 +188,13 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         strip_model_prefix=False,
         model_overrides=(),
     ),
-
     # OpenAI Codex: uses OAuth, not API key.
     ProviderSpec(
         name="openai_codex",
         keywords=("openai-codex",),
         env_key="",  # OAuth-based, no API key
         display_name="OpenAI Codex",
         litellm_prefix="",  # Not routed through LiteLLM
         skip_prefixes=(),
         env_extras=(),
         is_gateway=False,
@@ -214,16 +204,15 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         default_api_base="https://chatgpt.com/backend-api",
         strip_model_prefix=False,
         model_overrides=(),
         is_oauth=True,  # OAuth-based authentication
     ),
-
     # Github Copilot: uses OAuth, not API key.
     ProviderSpec(
         name="github_copilot",
         keywords=("github_copilot", "copilot"),
         env_key="",  # OAuth-based, no API key
         display_name="Github Copilot",
         litellm_prefix="github_copilot",  # github_copilot/model → github_copilot/model
         skip_prefixes=("github_copilot/",),
         env_extras=(),
         is_gateway=False,
@@ -233,17 +222,16 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         default_api_base="",
         strip_model_prefix=False,
         model_overrides=(),
         is_oauth=True,  # OAuth-based authentication
     ),
-
     # DeepSeek: needs "deepseek/" prefix for LiteLLM routing.
     ProviderSpec(
         name="deepseek",
         keywords=("deepseek",),
         env_key="DEEPSEEK_API_KEY",
         display_name="DeepSeek",
         litellm_prefix="deepseek",  # deepseek-chat → deepseek/deepseek-chat
         skip_prefixes=("deepseek/",),  # avoid double-prefix
         env_extras=(),
         is_gateway=False,
         is_local=False,
@@ -253,15 +241,14 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         strip_model_prefix=False,
         model_overrides=(),
     ),
-
     # Gemini: needs "gemini/" prefix for LiteLLM.
     ProviderSpec(
         name="gemini",
         keywords=("gemini",),
         env_key="GEMINI_API_KEY",
         display_name="Gemini",
         litellm_prefix="gemini",  # gemini-pro → gemini/gemini-pro
         skip_prefixes=("gemini/",),  # avoid double-prefix
         env_extras=(),
         is_gateway=False,
         is_local=False,
@@ -271,7 +258,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         strip_model_prefix=False,
         model_overrides=(),
     ),
-
     # Zhipu: LiteLLM uses "zai/" prefix.
     # Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that).
     # skip_prefixes: don't add "zai/" when already routed via gateway.
@@ -280,11 +266,9 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         keywords=("zhipu", "glm", "zai"),
         env_key="ZAI_API_KEY",
         display_name="Zhipu AI",
         litellm_prefix="zai",  # glm-4 → zai/glm-4
         skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
-        env_extras=(
-            ("ZHIPUAI_API_KEY", "{api_key}"),
-        ),
+        env_extras=(("ZHIPUAI_API_KEY", "{api_key}"),),
         is_gateway=False,
         is_local=False,
         detect_by_key_prefix="",
@@ -293,14 +277,13 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         strip_model_prefix=False,
         model_overrides=(),
     ),
-
     # DashScope: Qwen models, needs "dashscope/" prefix.
     ProviderSpec(
         name="dashscope",
         keywords=("qwen", "dashscope"),
         env_key="DASHSCOPE_API_KEY",
         display_name="DashScope",
         litellm_prefix="dashscope",  # qwen-max → dashscope/qwen-max
         skip_prefixes=("dashscope/", "openrouter/"),
         env_extras=(),
         is_gateway=False,
@@ -311,7 +294,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         strip_model_prefix=False,
         model_overrides=(),
     ),
-
     # Moonshot: Kimi models, needs "moonshot/" prefix.
     # LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint.
     # Kimi K2.5 API enforces temperature >= 1.0.
@@ -320,22 +302,17 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         keywords=("moonshot", "kimi"),
         env_key="MOONSHOT_API_KEY",
         display_name="Moonshot",
         litellm_prefix="moonshot",  # kimi-k2.5 → moonshot/kimi-k2.5
         skip_prefixes=("moonshot/", "openrouter/"),
-        env_extras=(
-            ("MOONSHOT_API_BASE", "{api_base}"),
-        ),
+        env_extras=(("MOONSHOT_API_BASE", "{api_base}"),),
         is_gateway=False,
         is_local=False,
         detect_by_key_prefix="",
         detect_by_base_keyword="",
         default_api_base="https://api.moonshot.ai/v1",  # intl; use api.moonshot.cn for China
         strip_model_prefix=False,
-        model_overrides=(
-            ("kimi-k2.5", {"temperature": 1.0}),
-        ),
+        model_overrides=(("kimi-k2.5", {"temperature": 1.0}),),
     ),
-
     # MiniMax: needs "minimax/" prefix for LiteLLM routing.
     # Uses OpenAI-compatible API at api.minimax.io/v1.
     ProviderSpec(
@@ -343,7 +320,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         keywords=("minimax",),
         env_key="MINIMAX_API_KEY",
         display_name="MiniMax",
         litellm_prefix="minimax",  # MiniMax-M2.1 → minimax/MiniMax-M2.1
         skip_prefixes=("minimax/", "openrouter/"),
         env_extras=(),
         is_gateway=False,
@@ -354,9 +331,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         strip_model_prefix=False,
         model_overrides=(),
     ),
-
     # === Local deployment (matched by config key, NOT by api_base) =========
-
     # vLLM / any OpenAI-compatible local server.
     # Detected when config key is "vllm" (provider_name="vllm").
     ProviderSpec(
@@ -364,20 +339,18 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         keywords=("vllm",),
         env_key="HOSTED_VLLM_API_KEY",
         display_name="vLLM/Local",
         litellm_prefix="hosted_vllm",  # Llama-3-8B → hosted_vllm/Llama-3-8B
         skip_prefixes=(),
         env_extras=(),
         is_gateway=False,
         is_local=True,
         detect_by_key_prefix="",
         detect_by_base_keyword="",
         default_api_base="",  # user must provide in config
         strip_model_prefix=False,
         model_overrides=(),
     ),
-
     # === Auxiliary (not a primary LLM provider) ============================
-
     # Groq: mainly used for Whisper voice transcription, also usable for LLM.
     # Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback.
     ProviderSpec(
@@ -385,8 +358,8 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         keywords=("groq",),
         env_key="GROQ_API_KEY",
         display_name="Groq",
         litellm_prefix="groq",  # llama3-8b-8192 → groq/llama3-8b-8192
         skip_prefixes=("groq/",),  # avoid double-prefix
         env_extras=(),
         is_gateway=False,
         is_local=False,
@@ -403,6 +376,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
 # Lookup helpers
 # ---------------------------------------------------------------------------

+
 def find_by_model(model: str) -> ProviderSpec | None:
     """Match a standard provider by model-name keyword (case-insensitive).
     Skips gateways/local — those are matched by api_key/api_base instead."""
@@ -418,7 +392,9 @@ def find_by_model(model: str) -> ProviderSpec | None:
             return spec

     for spec in std_specs:
-        if any(kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords):
+        if any(
+            kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords
+        ):
             return spec
     return None

@@ -2,7 +2,6 @@

 import os
 from pathlib import Path
-from typing import Any

 import httpx
 from loguru import logger
@@ -1,5 +1,5 @@
 """Session management module."""

-from nanobot.session.manager import SessionManager, Session
+from nanobot.session.manager import Session, SessionManager

 __all__ = ["SessionManager", "Session"]
@@ -2,9 +2,9 @@

 import json
 import shutil
-from pathlib import Path
 from dataclasses import dataclass, field
 from datetime import datetime
+from pathlib import Path
 from typing import Any

 from loguru import logger
@@ -4,17 +4,15 @@ You are a helpful AI assistant. Be concise, accurate, and friendly.

 ## Scheduled Reminders

-When user asks for a reminder at a specific time, use `exec` to run:
-```
-nanobot cron add --name "reminder" --message "Your message" --at "YYYY-MM-DDTHH:MM:SS" --deliver --to "USER_ID" --channel "CHANNEL"
-```
+Before scheduling reminders, check available skills and follow skill guidance first.
+Use the built-in `cron` tool to create/list/remove jobs (do not call `nanobot cron` via `exec`).
 Get USER_ID and CHANNEL from the current session (e.g., `8281248569` and `telegram` from `telegram:8281248569`).

 **Do NOT just write reminders to MEMORY.md** — that won't trigger actual notifications.

 ## Heartbeat Tasks

-`HEARTBEAT.md` is checked every 30 minutes. Use file tools to manage periodic tasks:
+`HEARTBEAT.md` is checked on the configured heartbeat interval. Use file tools to manage periodic tasks:

 - **Add**: `edit_file` to append new tasks
 - **Remove**: `edit_file` to delete completed tasks
@@ -1,5 +1,5 @@
 """Utility functions for nanobot."""

-from nanobot.utils.helpers import ensure_dir, get_workspace_path, get_data_path
+from nanobot.utils.helpers import ensure_dir, get_data_path, get_workspace_path

 __all__ = ["ensure_dir", "get_workspace_path", "get_data_path"]
@@ -1,8 +1,21 @@
 """Utility functions for nanobot."""

 import re
-from pathlib import Path
 from datetime import datetime
+from pathlib import Path
+
+
+def detect_image_mime(data: bytes) -> str | None:
+    """Detect image MIME type from magic bytes, ignoring file extension."""
+    if data[:8] == b"\x89PNG\r\n\x1a\n":
+        return "image/png"
+    if data[:3] == b"\xff\xd8\xff":
+        return "image/jpeg"
+    if data[:6] in (b"GIF87a", b"GIF89a"):
+        return "image/gif"
+    if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
+        return "image/webp"
+    return None


 def ensure_dir(path: Path) -> Path:
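A quick usage sketch for the new magic-bytes helper (the file name is illustrative; the import path follows the utils module shown above):

```python
from pathlib import Path

from nanobot.utils.helpers import detect_image_mime

data = Path("upload.bin").read_bytes()   # extension deliberately tells us nothing
mime = detect_image_mime(data)
print(mime or "not a recognized image")  # "image/png", "image/jpeg", "image/gif", "image/webp"
```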
@@ -34,6 +47,38 @@ def safe_filename(name: str) -> str:
     return _UNSAFE_CHARS.sub("_", name).strip()


+def split_message(content: str, max_len: int = 2000) -> list[str]:
+    """
+    Split content into chunks within max_len, preferring line breaks.
+
+    Args:
+        content: The text content to split.
+        max_len: Maximum length per chunk (default 2000 for Discord compatibility).
+
+    Returns:
+        List of message chunks, each within max_len.
+    """
+    if not content:
+        return []
+    if len(content) <= max_len:
+        return [content]
+    chunks: list[str] = []
+    while content:
+        if len(content) <= max_len:
+            chunks.append(content)
+            break
+        cut = content[:max_len]
+        # Try to break at newline first, then space, then hard break
+        pos = cut.rfind('\n')
+        if pos <= 0:
+            pos = cut.rfind(' ')
+        if pos <= 0:
+            pos = max_len
+        chunks.append(content[:pos])
+        content = content[pos:].lstrip()
+    return chunks
+
+
 def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
     """Sync bundled templates to workspace. Only creates missing files."""
     from importlib.resources import files as pkg_files
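A hedged usage sketch for the chunker above, e.g. before sending a long reply to Discord, whose hard message limit motivates the 2000-character default:

```python
from nanobot.utils.helpers import split_message

reply = ("word " * 1000).strip()        # ~5000 chars, no newlines
chunks = split_message(reply)           # default max_len=2000
assert all(len(c) <= 2000 for c in chunks)
# Only whitespace at the cut points is dropped; every word survives.
assert " ".join(chunks).split() == reply.split()
```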
@@ -1,6 +1,6 @@
 [project]
 name = "nanobot-ai"
-version = "0.1.4.post2"
+version = "0.1.4.post3"
 description = "A lightweight personal AI assistant framework"
 requires-python = ">=3.11"
 license = {text = "MIT"}
@@ -30,7 +30,7 @@ dependencies = [
     "rich>=14.0.0,<15.0.0",
     "croniter>=6.0.0,<7.0.0",
     "dingtalk-stream>=0.24.0,<1.0.0",
-    "python-telegram-bot[socks]>=22.0,<23.0",
+    "python-telegram-bot[socks]>=22.6,<23.0",
     "lark-oapi>=1.5.0,<2.0.0",
     "socksio>=1.0.0,<2.0.0",
     "python-socketio>=5.16.0,<6.0.0",
@@ -42,6 +42,8 @@ dependencies = [
     "prompt-toolkit>=3.0.50,<4.0.0",
     "mcp>=1.26.0,<2.0.0",
     "json-repair>=0.57.0,<1.0.0",
+    "chardet>=3.0.2,<6.0.0",
+    "openai>=2.8.0",
 ]

 [project.optional-dependencies]
@@ -58,6 +60,9 @@ dev = [
     "pytest-asyncio>=1.3.0,<2.0.0",
     "aiohttp>=3.9.0,<4.0.0",
     "ruff>=0.1.0",
+    "matrix-nio[e2e]>=0.25.2",
+    "mistune>=3.0.0,<4.0.0",
+    "nh3>=0.2.17,<1.0.0",
 ]

 [project.scripts]
@@ -786,10 +786,8 @@ class TestConsolidationDeduplicationGuard:
         )

     @pytest.mark.asyncio
-    async def test_new_cleans_up_consolidation_lock_for_invalidated_session(
-        self, tmp_path: Path
-    ) -> None:
-        """/new should remove lock entry for fully invalidated session key."""
+    async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
+        """/new clears session and returns confirmation."""
         from nanobot.agent.loop import AgentLoop
         from nanobot.bus.events import InboundMessage
         from nanobot.bus.queue import MessageBus
@@ -801,7 +799,6 @@ class TestConsolidationDeduplicationGuard:
         loop = AgentLoop(
             bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10
         )
-
         loop.provider.chat = AsyncMock(return_value=LLMResponse(content="ok", tool_calls=[]))
         loop.tools.get_definitions = MagicMock(return_value=[])

@@ -811,10 +808,6 @@ class TestConsolidationDeduplicationGuard:
             session.add_message("assistant", f"resp{i}")
         loop.sessions.save(session)

-        # Ensure lock exists before /new.
-        loop._consolidation_locks.setdefault(session.key, asyncio.Lock())
-        assert session.key in loop._consolidation_locks
-
         async def _ok_consolidate(sess, archive_all: bool = False, **kw) -> bool:
             return True

@@ -825,4 +818,4 @@ class TestConsolidationDeduplicationGuard:

         assert response is not None
         assert "new session started" in response.content.lower()
-        assert session.key not in loop._consolidation_locks
+        assert loop.sessions.get_or_create("cli:test").messages == []
@@ -40,7 +40,7 @@ def test_system_prompt_stays_stable_when_clock_changes(tmp_path, monkeypatch) ->


 def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
-    """Runtime metadata should be a separate user message before the actual user message."""
+    """Runtime metadata should be merged with the user message."""
     workspace = _make_workspace(tmp_path)
     builder = ContextBuilder(workspace)

@@ -54,13 +54,12 @@ def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
     assert messages[0]["role"] == "system"
     assert "## Current Session" not in messages[0]["content"]

-    assert messages[-2]["role"] == "user"
-    runtime_content = messages[-2]["content"]
-    assert isinstance(runtime_content, str)
-    assert ContextBuilder._RUNTIME_CONTEXT_TAG in runtime_content
-    assert "Current Time:" in runtime_content
-    assert "Channel: cli" in runtime_content
-    assert "Chat ID: direct" in runtime_content
-
+    # Runtime context is now merged with user message into a single message
     assert messages[-1]["role"] == "user"
-    assert messages[-1]["content"] == "Return exactly: OK"
+    user_content = messages[-1]["content"]
+    assert isinstance(user_content, str)
+    assert ContextBuilder._RUNTIME_CONTEXT_TAG in user_content
+    assert "Current Time:" in user_content
+    assert "Channel: cli" in user_content
+    assert "Chat ID: direct" in user_content
+    assert "Return exactly: OK" in user_content
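A minimal sketch of the merged message shape the updated test asserts. The tag value and builder internals are assumptions here, not the real `ContextBuilder` constants:

```python
RUNTIME_TAG = "<runtime_context>"   # stand-in for ContextBuilder._RUNTIME_CONTEXT_TAG

runtime = f"{RUNTIME_TAG}\nCurrent Time: 2026-02-28 12:00 (UTC)\nChannel: cli\nChat ID: direct"
user_text = "Return exactly: OK"

# One user message carries both the untrusted runtime metadata and the actual
# request, instead of two adjacent user messages as before.
message = {"role": "user", "content": f"{runtime}\n\n{user_text}"}
```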
@@ -1,29 +0,0 @@
-from typer.testing import CliRunner
-
-from nanobot.cli.commands import app
-
-runner = CliRunner()
-
-
-def test_cron_add_rejects_invalid_timezone(monkeypatch, tmp_path) -> None:
-    monkeypatch.setattr("nanobot.config.loader.get_data_dir", lambda: tmp_path)
-
-    result = runner.invoke(
-        app,
-        [
-            "cron",
-            "add",
-            "--name",
-            "demo",
-            "--message",
-            "hello",
-            "--cron",
-            "0 9 * * *",
-            "--tz",
-            "America/Vancovuer",
-        ],
-    )
-
-    assert result.exit_code == 1
-    assert "Error: unknown timezone 'America/Vancovuer'" in result.stdout
-    assert not (tmp_path / "cron" / "jobs.json").exists()
@@ -1,3 +1,5 @@
+import asyncio
+
 import pytest

 from nanobot.cron.service import CronService
@@ -28,3 +30,32 @@ def test_add_job_accepts_valid_timezone(tmp_path) -> None:

     assert job.schedule.tz == "America/Vancouver"
     assert job.state.next_run_at_ms is not None
+
+
+@pytest.mark.asyncio
+async def test_running_service_honors_external_disable(tmp_path) -> None:
+    store_path = tmp_path / "cron" / "jobs.json"
+    called: list[str] = []
+
+    async def on_job(job) -> None:
+        called.append(job.id)
+
+    service = CronService(store_path, on_job=on_job)
+    job = service.add_job(
+        name="external-disable",
+        schedule=CronSchedule(kind="every", every_ms=200),
+        message="hello",
+    )
+    await service.start()
+    try:
+        # Wait slightly to ensure file mtime is definitively different
+        await asyncio.sleep(0.05)
+        external = CronService(store_path)
+        updated = external.enable_job(job.id, enabled=False)
+        assert updated is not None
+        assert updated.enabled is False
+
+        await asyncio.sleep(0.35)
+        assert called == []
+    finally:
+        service.stop()
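The external-disable test above implies the running service re-reads the job store from disk before dispatching. A hedged sketch of such a guard; the class and method names are assumptions, not the actual `CronService` internals:

```python
import os


class JobStoreWatcher:
    """Reload jobs from disk only when the store file's mtime changes."""

    def __init__(self, path: str) -> None:
        self.path = path
        self._mtime = 0.0

    def maybe_reload(self, load) -> None:
        mtime = os.path.getmtime(self.path)
        if mtime != self._mtime:   # another process (e.g. the CLI) rewrote the file
            self._mtime = mtime
            load()                 # refresh in-memory jobs before the next tick
```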
tests/test_feishu_post_content.py (new file, 40 lines)
@@ -0,0 +1,40 @@
+from nanobot.channels.feishu import _extract_post_content
+
+
+def test_extract_post_content_supports_post_wrapper_shape() -> None:
+    payload = {
+        "post": {
+            "zh_cn": {
+                "title": "日报",
+                "content": [
+                    [
+                        {"tag": "text", "text": "完成"},
+                        {"tag": "img", "image_key": "img_1"},
+                    ]
+                ],
+            }
+        }
+    }
+
+    text, image_keys = _extract_post_content(payload)
+
+    assert text == "日报 完成"
+    assert image_keys == ["img_1"]
+
+
+def test_extract_post_content_keeps_direct_shape_behavior() -> None:
+    payload = {
+        "title": "Daily",
+        "content": [
+            [
+                {"tag": "text", "text": "report"},
+                {"tag": "img", "image_key": "img_a"},
+                {"tag": "img", "image_key": "img_b"},
+            ]
+        ],
+    }
+
+    text, image_keys = _extract_post_content(payload)
+
+    assert text == "Daily report"
+    assert image_keys == ["img_a", "img_b"]
tests/test_feishu_table_split.py (new file, 104 lines)
@@ -0,0 +1,104 @@
+"""Tests for FeishuChannel._split_elements_by_table_limit.
+
+Feishu cards reject messages that contain more than one table element
+(API error 11310: card table number over limit). The helper splits a flat
+list of card elements into groups so that each group contains at most one
+table, allowing nanobot to send multiple cards instead of failing.
+"""
+
+from nanobot.channels.feishu import FeishuChannel
+
+
+def _md(text: str) -> dict:
+    return {"tag": "markdown", "content": text}
+
+
+def _table() -> dict:
+    return {
+        "tag": "table",
+        "columns": [{"tag": "column", "name": "c0", "display_name": "A", "width": "auto"}],
+        "rows": [{"c0": "v"}],
+        "page_size": 2,
+    }
+
+
+split = FeishuChannel._split_elements_by_table_limit
+
+
+def test_empty_list_returns_single_empty_group() -> None:
+    assert split([]) == [[]]
+
+
+def test_no_tables_returns_single_group() -> None:
+    els = [_md("hello"), _md("world")]
+    result = split(els)
+    assert result == [els]
+
+
+def test_single_table_stays_in_one_group() -> None:
+    els = [_md("intro"), _table(), _md("outro")]
+    result = split(els)
+    assert len(result) == 1
+    assert result[0] == els
+
+
+def test_two_tables_split_into_two_groups() -> None:
+    # Use different row values so the two tables are not equal
+    t1 = {
+        "tag": "table",
+        "columns": [{"tag": "column", "name": "c0", "display_name": "A", "width": "auto"}],
+        "rows": [{"c0": "table-one"}],
+        "page_size": 2,
+    }
+    t2 = {
+        "tag": "table",
+        "columns": [{"tag": "column", "name": "c0", "display_name": "B", "width": "auto"}],
+        "rows": [{"c0": "table-two"}],
+        "page_size": 2,
+    }
+    els = [_md("before"), t1, _md("between"), t2, _md("after")]
+    result = split(els)
+    assert len(result) == 2
+    # First group: text before table-1 + table-1
+    assert t1 in result[0]
+    assert t2 not in result[0]
+    # Second group: text between tables + table-2 + text after
+    assert t2 in result[1]
+    assert t1 not in result[1]
+
+
+def test_three_tables_split_into_three_groups() -> None:
+    tables = [
+        {"tag": "table", "columns": [], "rows": [{"c0": f"t{i}"}], "page_size": 1}
+        for i in range(3)
+    ]
+    els = tables[:]
+    result = split(els)
+    assert len(result) == 3
+    for i, group in enumerate(result):
+        assert tables[i] in group
+
+
+def test_leading_markdown_stays_with_first_table() -> None:
+    intro = _md("intro")
+    t = _table()
+    result = split([intro, t])
+    assert len(result) == 1
+    assert result[0] == [intro, t]
+
+
+def test_trailing_markdown_after_second_table() -> None:
+    t1, t2 = _table(), _table()
+    tail = _md("end")
+    result = split([t1, t2, tail])
+    assert len(result) == 2
+    assert result[1] == [t2, tail]
+
+
+def test_non_table_elements_before_first_table_kept_in_first_group() -> None:
+    head = _md("head")
+    t1, t2 = _table(), _table()
+    result = split([head, t1, t2])
+    # head + t1 in group 0; t2 in group 1
+    assert result[0] == [head, t1]
+    assert result[1] == [t2]
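For orientation, one plausible implementation consistent with the tests above; the real `_split_elements_by_table_limit` may differ in detail:

```python
def split_elements_by_table_limit(elements: list[dict]) -> list[list[dict]]:
    """Group card elements so each group holds at most one table."""
    groups: list[list[dict]] = [[]]
    has_table = False                  # does the current group already hold a table?
    for el in elements:
        if el.get("tag") == "table" and has_table:
            groups.append([])          # a second table starts a new card
            has_table = False
        if el.get("tag") == "table":
            has_table = True
        groups[-1].append(el)
    return groups
```

This satisfies every case the tests pin down: an empty input yields `[[]]`, text stays attached to the nearest preceding table group, and n tables always produce n cards.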
tests/test_loop_save_turn.py (new file, 41 lines)
@@ -0,0 +1,41 @@
+from nanobot.agent.context import ContextBuilder
+from nanobot.agent.loop import AgentLoop
+from nanobot.session.manager import Session
+
+
+def _mk_loop() -> AgentLoop:
+    loop = AgentLoop.__new__(AgentLoop)
+    loop._TOOL_RESULT_MAX_CHARS = 500
+    return loop
+
+
+def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None:
+    loop = _mk_loop()
+    session = Session(key="test:runtime-only")
+    runtime = ContextBuilder._RUNTIME_CONTEXT_TAG + "\nCurrent Time: now (UTC)"
+
+    loop._save_turn(
+        session,
+        [{"role": "user", "content": [{"type": "text", "text": runtime}]}],
+        skip=0,
+    )
+    assert session.messages == []
+
+
+def test_save_turn_keeps_image_placeholder_after_runtime_strip() -> None:
+    loop = _mk_loop()
+    session = Session(key="test:image")
+    runtime = ContextBuilder._RUNTIME_CONTEXT_TAG + "\nCurrent Time: now (UTC)"
+
+    loop._save_turn(
+        session,
+        [{
+            "role": "user",
+            "content": [
+                {"type": "text", "text": runtime},
+                {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
+            ],
+        }],
+        skip=0,
+    )
+    assert session.messages[0]["content"] == [{"type": "text", "text": "[image]"}]
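A hedged sketch of the persistence rule these tests pin down; the real `_save_turn` has more responsibilities and this helper name is hypothetical:

```python
def strip_runtime_parts(content: list[dict], runtime_tag: str) -> list[dict]:
    """Drop injected runtime metadata; keep cheap placeholders for images."""
    kept: list[dict] = []
    for part in content:
        if part.get("type") == "text" and part.get("text", "").startswith(runtime_tag):
            continue                                             # injected metadata, not user-authored
        if part.get("type") == "image_url":
            kept.append({"type": "text", "text": "[image]"})     # placeholder, no base64 blob in history
        else:
            kept.append(part)
    return kept   # an empty result means there is nothing user-authored worth saving
```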
@@ -159,6 +159,7 @@ class _FakeAsyncClient:


 def _make_config(**kwargs) -> MatrixConfig:
+    kwargs.setdefault("allow_from", ["*"])
     return MatrixConfig(
         enabled=True,
         homeserver="https://matrix.org",
@@ -274,7 +275,7 @@ async def test_stop_stops_sync_forever_before_close(monkeypatch) -> None:


 @pytest.mark.asyncio
-async def test_room_invite_joins_when_allow_list_is_empty() -> None:
+async def test_room_invite_ignores_when_allow_list_is_empty() -> None:
     channel = MatrixChannel(_make_config(allow_from=[]), MessageBus())
     client = _FakeAsyncClient("", "", "", None)
     channel.client = client
@@ -284,9 +285,22 @@ async def test_room_invite_joins_when_allow_list_is_empty() -> None:

     await channel._on_room_invite(room, event)

-    assert client.join_calls == ["!room:matrix.org"]
+    assert client.join_calls == []
+
+
+@pytest.mark.asyncio
+async def test_room_invite_joins_when_sender_allowed() -> None:
+    channel = MatrixChannel(_make_config(allow_from=["@alice:matrix.org"]), MessageBus())
+    client = _FakeAsyncClient("", "", "", None)
+    channel.client = client
+
+    room = SimpleNamespace(room_id="!room:matrix.org")
+    event = SimpleNamespace(sender="@alice:matrix.org")
+
+    await channel._on_room_invite(room, event)
+
+    assert client.join_calls == ["!room:matrix.org"]

 @pytest.mark.asyncio
 async def test_room_invite_respects_allow_list_when_configured() -> None:
     channel = MatrixChannel(_make_config(allow_from=["@bob:matrix.org"]), MessageBus())
@@ -1163,6 +1177,8 @@ async def test_send_progress_keeps_typing_keepalive_running() -> None:
     assert "!room:matrix.org" in channel._typing_tasks
     assert client.typing_calls[-1] == ("!room:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS)

+    await channel.stop()
+

 @pytest.mark.asyncio
 async def test_send_clears_typing_when_send_fails() -> None:
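A sketch of the invite gating these tests encode. The attribute names are assumptions modeled on the fakes above, not the channel's exact internals:

```python
async def _on_room_invite(self, room, event) -> None:
    allow = self.config.allow_from
    if not allow:                                  # empty allow-list: never auto-join
        return
    if "*" in allow or event.sender in allow:      # wildcard or explicit sender match
        await self.client.join(room.room_id)
```

This matches all three cases: an empty list ignores every invite, an allowed sender is joined, and an unlisted sender is rejected.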
@@ -145,3 +145,78 @@ class TestMemoryConsolidationTypeHandling:

         assert result is True
         provider.chat.assert_not_called()
+
+    @pytest.mark.asyncio
+    async def test_list_arguments_extracts_first_dict(self, tmp_path: Path) -> None:
+        """Some providers return arguments as a list - extract first element if it's a dict."""
+        store = MemoryStore(tmp_path)
+        provider = AsyncMock()
+
+        # Simulate arguments being a list containing a dict
+        response = LLMResponse(
+            content=None,
+            tool_calls=[
+                ToolCallRequest(
+                    id="call_1",
+                    name="save_memory",
+                    arguments=[{
+                        "history_entry": "[2026-01-01] User discussed testing.",
+                        "memory_update": "# Memory\nUser likes testing.",
+                    }],
+                )
+            ],
+        )
+        provider.chat = AsyncMock(return_value=response)
+        session = _make_session(message_count=60)
+
+        result = await store.consolidate(session, provider, "test-model", memory_window=50)
+
+        assert result is True
+        assert "User discussed testing." in store.history_file.read_text()
+        assert "User likes testing." in store.memory_file.read_text()
+
+    @pytest.mark.asyncio
+    async def test_list_arguments_empty_list_returns_false(self, tmp_path: Path) -> None:
+        """Empty list arguments should return False."""
+        store = MemoryStore(tmp_path)
+        provider = AsyncMock()
+
+        response = LLMResponse(
+            content=None,
+            tool_calls=[
+                ToolCallRequest(
+                    id="call_1",
+                    name="save_memory",
+                    arguments=[],
+                )
+            ],
+        )
+        provider.chat = AsyncMock(return_value=response)
+        session = _make_session(message_count=60)
+
+        result = await store.consolidate(session, provider, "test-model", memory_window=50)
+
+        assert result is False
+
+    @pytest.mark.asyncio
+    async def test_list_arguments_non_dict_content_returns_false(self, tmp_path: Path) -> None:
+        """List with non-dict content should return False."""
+        store = MemoryStore(tmp_path)
+        provider = AsyncMock()
+
+        response = LLMResponse(
+            content=None,
+            tool_calls=[
+                ToolCallRequest(
+                    id="call_1",
+                    name="save_memory",
+                    arguments=["string", "content"],
+                )
+            ],
+        )
+        provider.chat = AsyncMock(return_value=response)
+        session = _make_session(message_count=60)
+
+        result = await store.consolidate(session, provider, "test-model", memory_window=50)
+
+        assert result is False
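The behavior these tests pin down suggests a small normalizer in front of the save_memory tool-call handler. A hedged sketch; the real consolidate logic is more involved and this function name is hypothetical:

```python
def normalize_tool_arguments(arguments) -> dict | None:
    """Accept dicts directly; unwrap the [dict, ...] lists some providers emit."""
    if isinstance(arguments, dict):
        return arguments
    if isinstance(arguments, list) and arguments and isinstance(arguments[0], dict):
        return arguments[0]      # the first element carries the payload
    return None                  # empty list or non-dict content: caller returns False
```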