mirror of https://github.com/HKUDS/nanobot.git
synced 2026-05-10 19:56:00 +00:00

commit e157392250
fix(agent): scope subagent reply dedupe to origin message
Made-with: Cursor
@@ -43,6 +43,26 @@ We use a two-branch model to balance stability and exploration:

**When in doubt, target `nightly`.** It is easier to move a stable idea from `nightly`
to `main` than to undo a risky change after it lands in the stable branch.

### Starting Work

Before making changes, sync the target branch and create a topic branch from it.
For stable bug fixes and documentation-only changes, start from the latest `main`.
For experimental work, start from the latest `nightly`.

```bash
git fetch upstream
git switch main
git pull --ff-only upstream main
git switch -c your-topic-branch
```

Use your primary HKUDS/nanobot remote in place of `upstream` if your checkout
uses a different remote name.

Keep unrelated local changes out of the topic branch. If your checkout already has
work in progress, use a separate worktree or finish that work before starting a
new branch.
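The separate-worktree suggestion can be sketched as follows. This is an illustrative walkthrough, not part of the project docs: the throwaway repo exists only so the snippet is self-contained, and the paths and branch names are examples.

```shell
# Illustrative sketch (git >= 2.28 for `init -b`); in a real checkout you
# would run only the `git worktree` commands, with your own paths/branch.
tmp=$(mktemp -d) && cd "$tmp"
git init -q -b main repo && cd repo
git -c user.name=nanobot -c user.email=nanobot@example.com \
    commit -q --allow-empty -m "init"

# Leave in-progress work in this checkout; start the topic branch in a
# sibling worktree instead.
git worktree add ../topic-wt -b your-topic-branch main

# Once the branch is merged (or abandoned), remove the worktree.
git worktree remove ../topic-wt
```

Removing the worktree deletes the checkout directory but keeps the branch, so the work is never lost.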
### How Does Nightly Get Merged to Main?

We don't merge the entire `nightly` branch. Instead, stable features are **cherry-picked** from `nightly` into individual PRs targeting `main`:
@@ -63,6 +63,7 @@ IMAP_PASSWORD=your-password-here

| `byteplus` | LLM (VolcEngine international, pay-per-use) | [Coding Plan](https://www.byteplus.com/en/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) · [byteplus.com](https://www.byteplus.com) |
| `anthropic` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
| `azure_openai` | LLM (Azure OpenAI) | [portal.azure.com](https://portal.azure.com) |
| `bedrock` | LLM (AWS Bedrock Converse, Claude/Nova/Llama/etc.) | [aws.amazon.com/bedrock](https://aws.amazon.com/bedrock/) |
| `openai` | LLM + Voice transcription (Whisper) | [platform.openai.com](https://platform.openai.com) |
| `deepseek` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
| `groq` | LLM + Voice transcription (Whisper, default) | [console.groq.com](https://console.groq.com) |
@@ -85,6 +86,183 @@ IMAP_PASSWORD=your-password-here

| `github_copilot` | LLM (GitHub Copilot, OAuth) | `nanobot provider login github-copilot` |
| `qianfan` | LLM (Baidu Qianfan) | [cloud.baidu.com](https://cloud.baidu.com/doc/qianfan/s/Hmh4suq26) |

<details>
<summary><b>AWS Bedrock (Converse API)</b></summary>

Bedrock uses the native `bedrock-runtime` Converse API, so it can call Bedrock model IDs such as Claude Opus 4.7, Claude Sonnet, Amazon Nova, Meta Llama, Mistral, Qwen, and other models that support Converse. It supports normal chat, streaming, tool calling, tool results, token usage, and Bedrock error metadata.

This provider is for Bedrock's native Converse API, not Bedrock's OpenAI-compatible `/openai/v1` endpoint. For OpenAI-compatible Bedrock models, you can still use `custom` if you specifically want that API surface.

**1. Configure credentials**

Use the normal AWS credential chain (`AWS_ACCESS_KEY_ID` / `AWS_SECRET_ACCESS_KEY`, an AWS profile, or an IAM role). The IAM identity needs:

```json
{
  "Effect": "Allow",
  "Action": [
    "bedrock:InvokeModel",
    "bedrock:InvokeModelWithResponseStream"
  ],
  "Resource": "*"
}
```

You can also set `providers.bedrock.apiKey` to a Bedrock API key; nanobot exports it as `AWS_BEARER_TOKEN_BEDROCK` for the AWS SDK.

Credential options:

- **AWS CLI/default profile**: leave `apiKey` and `profile` empty, then run `aws configure` or provide `AWS_ACCESS_KEY_ID` / `AWS_SECRET_ACCESS_KEY`.
- **Named AWS profile**: set `profile` to a profile from `~/.aws/config` or `~/.aws/credentials`.
- **IAM role**: on EC2/ECS/Lambda, leave `apiKey` and `profile` empty and attach a role with Bedrock permissions.
- **Bedrock API key**: set `apiKey` or `AWS_BEARER_TOKEN_BEDROCK`; `profile` can stay `null`.

**2. Minimal config**

For a non-Anthropic model such as Amazon Nova:

```json
{
  "providers": {
    "bedrock": {
      "region": "us-east-1"
    }
  },
  "agents": {
    "defaults": {
      "provider": "bedrock",
      "model": "bedrock/amazon.nova-lite-v1:0",
      "reasoningEffort": null
    }
  }
}
```

With a Bedrock API key:

```json
{
  "providers": {
    "bedrock": {
      "region": "us-east-1",
      "apiKey": "${AWS_BEARER_TOKEN_BEDROCK}"
    }
  },
  "agents": {
    "defaults": {
      "provider": "bedrock",
      "model": "bedrock/amazon.nova-lite-v1:0",
      "reasoningEffort": null
    }
  }
}
```

With a named AWS profile:

```json
{
  "providers": {
    "bedrock": {
      "region": "us-east-1",
      "profile": "my-bedrock-profile"
    }
  },
  "agents": {
    "defaults": {
      "provider": "bedrock",
      "model": "bedrock/amazon.nova-lite-v1:0"
    }
  }
}
```

**3. Claude Opus 4.7 example**

```json
{
  "providers": {
    "bedrock": {
      "region": "us-east-1"
    }
  },
  "agents": {
    "defaults": {
      "provider": "bedrock",
      "model": "bedrock/global.anthropic.claude-opus-4-7",
      "reasoningEffort": "medium",
      "maxTokens": 8192
    }
  }
}
```

For regional routing, use one of Bedrock's inference IDs, for example `bedrock/us.anthropic.claude-opus-4-7`, `bedrock/eu.anthropic.claude-opus-4-7`, or `bedrock/jp.anthropic.claude-opus-4-7`.

Claude Opus 4.7 does not accept `temperature`, `top_p`, or `top_k`; nanobot omits `temperature` automatically for this model. If `reasoningEffort` is set to `low`, `medium`, `high`, `max`, or `adaptive`, nanobot sends Bedrock's adaptive thinking parameter.

Anthropic models on Bedrock can also require Anthropic use-case registration and are subject to Anthropic-supported country/region restrictions. If Claude fails with a `ValidationException` about unsupported countries or regions, try a non-Anthropic Bedrock model such as Amazon Nova to verify the provider setup.

**4. Model IDs**

Use Bedrock model IDs or inference profile IDs with a `bedrock/` prefix in nanobot config. nanobot removes the prefix before calling AWS.

Examples:

- `bedrock/amazon.nova-micro-v1:0`
- `bedrock/amazon.nova-lite-v1:0`
- `bedrock/global.anthropic.claude-opus-4-7`
- `bedrock/us.anthropic.claude-opus-4-7`
- `bedrock/openai.gpt-oss-20b-1:0`
- `bedrock/meta.llama...`
- `bedrock/mistral...`

Check the Bedrock console for the exact model ID and region availability. Some models require cross-region inference profile IDs such as `us.*`, `eu.*`, or `global.*`.
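The prefix handling just described amounts to a one-line transformation. As a hypothetical helper (nanobot's internal name and implementation may differ):

```python
def to_bedrock_model_id(model: str) -> str:
    # Strip the config-level "bedrock/" prefix before calling AWS;
    # IDs without the prefix pass through unchanged.
    prefix = "bedrock/"
    return model[len(prefix):] if model.startswith(prefix) else model

print(to_bedrock_model_id("bedrock/amazon.nova-lite-v1:0"))  # → amazon.nova-lite-v1:0
```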
**5. Advanced model fields**

Model-specific fields can be supplied with `extraBody`; nanobot merges it into Converse `additionalModelRequestFields`:

```json
{
  "providers": {
    "bedrock": {
      "region": "us-east-1",
      "extraBody": {
        "thinking": {
          "type": "adaptive",
          "effort": "medium",
          "display": "summarized"
        }
      }
    }
  }
}
```

Use `apiBase` only for a custom Bedrock Runtime endpoint URL, such as a VPC endpoint or proxy. It is not needed for normal AWS regions.

Current scope: nanobot passes `messages`, `system`, `inferenceConfig`, `toolConfig`, and `additionalModelRequestFields`. Bedrock Prompt Management, Guardrails, `serviceTier`, and other top-level Converse options are not first-class config fields yet.

**6. Quick checks**

```bash
# For AWS credential-chain usage:
aws sts get-caller-identity

# For API-key usage:
export AWS_BEARER_TOKEN_BEDROCK="your-bedrock-api-key"
export AWS_REGION="us-east-1"
```

Then run:

```bash
nanobot agent -m "Reply with one short sentence."
```

</details>

<details>
<summary><b>OpenAI Codex (OAuth)</b></summary>
@@ -3,6 +3,7 @@
 import base64
 import mimetypes
 import platform
+from contextlib import suppress
 from importlib.resources import files as pkg_files
 from pathlib import Path
 from typing import Any

@@ -82,12 +83,14 @@ class ContextBuilder:
     @staticmethod
     def _build_runtime_context(
         channel: str | None, chat_id: str | None, timezone: str | None = None,
-        session_summary: str | None = None,
+        session_summary: str | None = None, sender_id: str | None = None,
     ) -> str:
         """Build untrusted runtime metadata block for injection before the user message."""
         lines = [f"Current Time: {current_time_str(timezone)}"]
         if channel and chat_id:
             lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
+        if sender_id:
+            lines += [f"Sender ID: {sender_id}"]
         if session_summary:
             lines += ["", "[Resumed Session]", session_summary]
         return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines) + "\n" + ContextBuilder._RUNTIME_CONTEXT_END

@@ -121,12 +124,10 @@ class ContextBuilder:
     @staticmethod
     def _is_template_content(content: str, template_path: str) -> bool:
         """Check if *content* is identical to the bundled template (user hasn't customized it)."""
-        try:
+        with suppress(Exception):
             tpl = pkg_files("nanobot") / "templates" / template_path
             if tpl.is_file():
                 return content.strip() == tpl.read_text(encoding="utf-8").strip()
-        except Exception:
-            pass
         return False

     def build_messages(

@@ -139,9 +140,10 @@ class ContextBuilder:
         chat_id: str | None = None,
         current_role: str = "user",
         session_summary: str | None = None,
+        sender_id: str | None = None,
     ) -> list[dict[str, Any]]:
         """Build the complete message list for an LLM call."""
-        runtime_ctx = self._build_runtime_context(channel, chat_id, self.timezone, session_summary=session_summary)
+        runtime_ctx = self._build_runtime_context(channel, chat_id, self.timezone, session_summary=session_summary, sender_id=sender_id)
         user_content = self._build_user_content(current_message, media)

         # Merge runtime context and user content into a single user message
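Several hunks in this commit replace `try/except ...: pass` with `contextlib.suppress`. The two forms are equivalent for ignore-and-continue error handling; a minimal standalone illustration (not from the nanobot codebase):

```python
from contextlib import suppress

def parse_or_default(raw: str, default: int = 0) -> int:
    # Equivalent to try/int(raw)/except ValueError: pass — the with-block
    # body stops at the first suppressed exception and execution resumes
    # after the block.
    with suppress(ValueError):
        return int(raw)
    return default

print(parse_or_default("42"))   # → 42
print(parse_or_default("n/a"))  # → 0
```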
@@ -7,7 +7,7 @@ import dataclasses
 import json
 import os
 import time
-from contextlib import AsyncExitStack, nullcontext
+from contextlib import AsyncExitStack, nullcontext, suppress
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Awaitable, Callable

@@ -28,6 +28,7 @@ from nanobot.agent.tools.ask import (
     pending_ask_user_id,
 )
 from nanobot.agent.tools.cron import CronTool
+from nanobot.agent.tools.file_state import FileStateStore, bind_file_states, reset_file_states
 from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
 from nanobot.agent.tools.message import MessageTool
 from nanobot.agent.tools.notebook import NotebookEditTool

@@ -247,6 +248,9 @@ class AgentLoop:
         self.context = ContextBuilder(workspace, timezone=timezone, disabled_skills=disabled_skills)
         self.sessions = session_manager or SessionManager(workspace)
         self.tools = ToolRegistry()
+        # One file-read/write tracker per logical session. The tool registry is
+        # shared by this loop, so tools resolve the active state via contextvars.
+        self._file_state_store = FileStateStore()
         self.runner = AgentRunner(provider)
         self.subagents = SubagentManager(
             provider=provider,

@@ -351,7 +355,9 @@ class AgentLoop:
         self.tools.register(AskUserTool())
         self.tools.register(
             ReadFileTool(
-                workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read
+                workspace=self.workspace,
+                allowed_dir=allowed_dir,
+                extra_allowed_dirs=extra_read,
             )
         )
         for cls in (WriteFileTool, EditFileTool, ListDirTool):

@@ -435,6 +441,8 @@ class AgentLoop:
             if hasattr(tool, "set_context"):
                 if name == "spawn":
                     tool.set_context(channel, chat_id, effective_key=effective_key)
+                    if hasattr(tool, "set_origin_message_id"):
+                        tool.set_origin_message_id(message_id)
                 elif name == "cron":
                     tool.set_context(channel, chat_id, metadata=metadata, session_key=session_key)
                 elif name == "message":

@@ -486,10 +494,8 @@ class AgentLoop:
         tasks = self._active_tasks.pop(key, [])
         cancelled = sum(1 for t in tasks if not t.done() and t.cancel())
         for t in tasks:
-            try:
+            with suppress(asyncio.CancelledError, Exception):
                 await t
-            except (asyncio.CancelledError, Exception):
-                pass
         sub_cancelled = await self.subagents.cancel_by_session(key)
         return cancelled + sub_cancelled

@@ -618,6 +624,9 @@ class AgentLoop:

             return items

+        active_session_key = session.key if session else session_key
+        file_state_token = bind_file_states(self._file_state_store.for_session(active_session_key))
+        try:
             result = await self.runner.run(AgentRunSpec(
                 initial_messages=initial_messages,
                 tools=self.tools,

@@ -637,6 +646,8 @@ class AgentLoop:
                 checkpoint_callback=_checkpoint,
                 injection_callback=_drain_pending,
             ))
+        finally:
+            reset_file_states(file_state_token)
         self._last_usage = result.usage
         if result.stop_reason == "max_iterations":
             logger.warning("Max iterations ({}) reached", self.max_iterations)

@@ -920,6 +931,7 @@ class AgentLoop:
             chat_id=chat_id,
             session_summary=pending,
             current_role=current_role,
+            sender_id=msg.sender_id,
         )
         final_content, _, all_msgs, stop_reason, _ = await self._run_agent_loop(
             messages, session=session, channel=channel, chat_id=chat_id,

@@ -947,6 +959,8 @@ class AgentLoop:
         outbound_metadata: dict[str, Any] = {}
         if channel == "slack" and key.startswith("slack:") and key.count(":") >= 2:
             outbound_metadata["slack"] = {"thread_ts": key.split(":", 2)[2]}
+        if origin_message_id := msg.metadata.get("origin_message_id"):
+            outbound_metadata["origin_message_id"] = origin_message_id
         return OutboundMessage(
             channel=channel,
             chat_id=chat_id,

@@ -1015,6 +1029,7 @@ class AgentLoop:
             media=msg.media if msg.media else None,
             channel=msg.channel,
             chat_id=self._runtime_chat_id(msg),
+            sender_id=msg.sender_id,
         )

     async def _bus_progress(
@@ -7,6 +7,7 @@ import json
 import os
 import re
 import weakref
+from contextlib import suppress
 import tiktoken
 from datetime import datetime
 from pathlib import Path

@@ -296,10 +297,8 @@ class MemoryStore:
     def _next_cursor(self) -> int:
         """Read the current cursor counter and return the next value."""
         if self._cursor_file.exists():
-            try:
+            with suppress(ValueError, OSError):
                 return int(self._cursor_file.read_text(encoding="utf-8").strip()) + 1
-            except (ValueError, OSError):
-                pass
         # Fast path: trust the tail when intact. Otherwise scan the whole
         # file and take ``max`` — that stays correct even if the monotonic
         # invariant was broken by external writes.

@@ -328,7 +327,7 @@ class MemoryStore:
     def _read_entries(self) -> list[dict[str, Any]]:
         """Read all entries from history.jsonl."""
         entries: list[dict[str, Any]] = []
-        try:
+        with suppress(FileNotFoundError):
             with open(self.history_file, "r", encoding="utf-8") as f:
                 for line in f:
                     line = line.strip()

@@ -337,8 +336,7 @@ class MemoryStore:
                         entries.append(json.loads(line))
                     except json.JSONDecodeError:
                         continue
-        except FileNotFoundError:
-            pass
         return entries

     def _read_last_entry(self) -> dict[str, Any] | None:

@@ -374,14 +372,12 @@ class MemoryStore:
                 # On Windows, opening a directory with O_RDONLY raises
                 # PermissionError — skip the dir sync there (NTFS
                 # journals metadata synchronously).
-                try:
+                with suppress(PermissionError):
                     fd = os.open(str(self.history_file.parent), os.O_RDONLY)
                     try:
                         os.fsync(fd)
                     finally:
                         os.close(fd)
-                except PermissionError:
-                    pass  # Windows — directory fsync not supported
         except BaseException:
             tmp_path.unlink(missing_ok=True)
             raise

@@ -390,10 +386,8 @@ class MemoryStore:
     def get_last_dream_cursor(self) -> int:
         if self._dream_cursor_file.exists():
-            try:
+            with suppress(ValueError, OSError):
                 return int(self._dream_cursor_file.read_text(encoding="utf-8").strip())
-            except (ValueError, OSError):
-                pass
         return 0

     def set_last_dream_cursor(self, cursor: int) -> None:

@@ -524,6 +518,7 @@ class Consolidator:
             channel=channel,
             chat_id=chat_id,
             session_summary=session_summary,
+            sender_id=None,
         )
         return estimate_prompt_tokens_chain(
             self.provider,

@@ -753,23 +748,28 @@ class Dream:
     def _build_tools(self) -> ToolRegistry:
         """Build a minimal tool registry for the Dream agent."""
         from nanobot.agent.skills import BUILTIN_SKILLS_DIR
+        from nanobot.agent.tools.file_state import FileStates
         from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, WriteFileTool

         tools = ToolRegistry()
         workspace = self.store.workspace
         # Allow reading builtin skills for reference during skill creation
         extra_read = [BUILTIN_SKILLS_DIR] if BUILTIN_SKILLS_DIR.exists() else None
+        # Dream gets its own FileStates so its caches stay isolated from the
+        # main loop's sessions (issue #3571).
+        file_states = FileStates()
         tools.register(ReadFileTool(
             workspace=workspace,
             allowed_dir=workspace,
             extra_allowed_dirs=extra_read,
+            file_states=file_states,
         ))
-        tools.register(EditFileTool(workspace=workspace, allowed_dir=workspace))
+        tools.register(EditFileTool(workspace=workspace, allowed_dir=workspace, file_states=file_states))
         # write_file resolves relative paths from workspace root, but can only
         # write under skills/ so the prompt can safely use skills/<name>/SKILL.md.
         skills_dir = workspace / "skills"
         skills_dir.mkdir(parents=True, exist_ok=True)
-        tools.register(WriteFileTool(workspace=workspace, allowed_dir=skills_dir))
+        tools.register(WriteFileTool(workspace=workspace, allowed_dir=skills_dir, file_states=file_states))
         return tools

     # -- skill listing --------------------------------------------------------
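Several of the memory-store hunks above revolve around the same durable-write pattern: fsync the file, atomically replace, then fsync the parent directory, tolerating Windows's PermissionError on directory opens. A standalone sketch of that pattern (the helper name is ours; this is not nanobot's actual code):

```python
import os
import tempfile
from contextlib import suppress
from pathlib import Path

def durable_write(path: Path, data: str) -> None:
    """Write via temp file + atomic replace, fsyncing file then parent dir."""
    tmp = path.with_suffix(".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        f.write(data)
        f.flush()
        os.fsync(f.fileno())           # file contents reach disk
    os.replace(tmp, path)              # atomic rename over the old file
    with suppress(PermissionError):    # Windows: opening a dir O_RDONLY raises
        fd = os.open(str(path.parent), os.O_RDONLY)
        try:
            os.fsync(fd)               # directory entry reaches disk
        finally:
            os.close(fd)

target = Path(tempfile.mkdtemp()) / "history.jsonl"
durable_write(target, '{"cursor": 1}\n')
```

Without the directory fsync, a crash right after `os.replace` can leave the rename itself unrecorded on some filesystems, which is why the diff keeps it despite the Windows special case.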
@@ -5,6 +5,7 @@ from __future__ import annotations
 import asyncio
 import inspect
 import os
+from contextlib import suppress
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any

@@ -752,12 +753,10 @@ class AgentRunner:
         prepare_call = getattr(spec.tools, "prepare_call", None)
         tool, params, prep_error = None, tool_call.arguments, None
         if callable(prepare_call):
-            try:
+            with suppress(Exception):
                 prepared = prepare_call(tool_call.name, tool_call.arguments)
                 if isinstance(prepared, tuple) and len(prepared) == 3:
                     tool, params, prep_error = prepared
-            except Exception:
-                pass
         if prep_error:
             event = {
                 "name": tool_call.name,
@ -170,12 +170,16 @@ class SubagentManager:
|
|||||||
tools = ToolRegistry()
|
tools = ToolRegistry()
|
||||||
allowed_dir = self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
|
allowed_dir = self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
|
||||||
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
|
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
|
||||||
tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read))
|
# Subagent gets its own FileStates so its read-dedup cache is
|
||||||
tools.register(WriteFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
# isolated from the parent loop's sessions (issue #3571).
|
||||||
tools.register(EditFileTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
from nanobot.agent.tools.file_state import FileStates
|
||||||
tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
file_states = FileStates()
|
||||||
tools.register(GlobTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
tools.register(ReadFileTool(workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read, file_states=file_states))
|
||||||
tools.register(GrepTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
tools.register(WriteFileTool(workspace=self.workspace, allowed_dir=allowed_dir, file_states=file_states))
|
||||||
|
tools.register(EditFileTool(workspace=self.workspace, allowed_dir=allowed_dir, file_states=file_states))
|
||||||
|
tools.register(ListDirTool(workspace=self.workspace, allowed_dir=allowed_dir, file_states=file_states))
|
||||||
|
tools.register(GlobTool(workspace=self.workspace, allowed_dir=allowed_dir, file_states=file_states))
|
||||||
|
tools.register(GrepTool(workspace=self.workspace, allowed_dir=allowed_dir, file_states=file_states))
|
||||||
if self.exec_config.enable:
|
if self.exec_config.enable:
|
||||||
tools.register(ExecTool(
|
tools.register(ExecTool(
|
||||||
working_dir=str(self.workspace),
|
working_dir=str(self.workspace),
|
||||||
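The per-subagent `FileStates` above exists because a process-wide read cache leaks dedup stubs across sessions: once one session records a read, another session's first read of the same file would come back as an "unchanged" stub. A minimal standalone sketch of the aliasing problem (simplified stand-in class, not the real nanobot `FileStates`):

```python
# Simplified stand-in: real FileStates also tracks mtime/offset/limit/hash.
class FileStates:
    def __init__(self) -> None:
        self._state: dict[str, bool] = {}

    def record_read(self, path: str) -> None:
        self._state[path] = True

    def is_unchanged(self, path: str) -> bool:
        return self._state.get(path, False)


# Shared cache: the parent's read makes the subagent's first read dedup.
shared = FileStates()
shared.record_read("notes.txt")
print(shared.is_unchanged("notes.txt"))  # True - subagent would get a stub

# Separate instances: the subagent performs a real first read.
parent, subagent = FileStates(), FileStates()
parent.record_read("notes.txt")
print(subagent.is_unchanged("notes.txt"))  # False
```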
@@ -4,6 +4,7 @@ from __future__ import annotations

 import hashlib
 import os
+from contextvars import ContextVar, Token
 from dataclasses import dataclass
 from pathlib import Path
@@ -17,9 +18,6 @@ class ReadState:
     can_dedup: bool


-_state: dict[str, ReadState] = {}
-
-
 def _hash_file(p: str) -> str | None:
     try:
         return hashlib.sha256(Path(p).read_bytes()).hexdigest()
@@ -27,14 +25,27 @@ def _hash_file(p: str) -> str | None:
         return None


-def record_read(path: str | Path, offset: int = 1, limit: int | None = None) -> None:
+class FileStates:
+    """Per-session read/write tracker.
+
+    Owns its own state dict so read-dedup ("File unchanged since last read")
+    and read-before-edit warnings stay scoped to one agent session and do
+    not leak across sessions sharing this process.
+    """
+
+    __slots__ = ("_state",)
+
+    def __init__(self) -> None:
+        self._state: dict[str, ReadState] = {}
+
+    def record_read(self, path: str | Path, offset: int = 1, limit: int | None = None) -> None:
         """Record that a file was read (called after successful read)."""
         p = str(Path(path).resolve())
         try:
             mtime = os.path.getmtime(p)
         except OSError:
             return
-        _state[p] = ReadState(
+        self._state[p] = ReadState(
             mtime=mtime,
             offset=offset,
             limit=limit,
@@ -42,16 +53,15 @@ def record_read(path: str | Path, offset: int = 1, limit: int | None = None) ->
             can_dedup=True,
         )

-def record_write(path: str | Path) -> None:
+    def record_write(self, path: str | Path) -> None:
         """Record that a file was written (updates mtime in state)."""
         p = str(Path(path).resolve())
         try:
             mtime = os.path.getmtime(p)
         except OSError:
-            _state.pop(p, None)
+            self._state.pop(p, None)
             return
-        _state[p] = ReadState(
+        self._state[p] = ReadState(
             mtime=mtime,
             offset=1,
             limit=None,
@@ -59,8 +69,7 @@ def record_write(path: str | Path) -> None:
             can_dedup=False,
         )

-def check_read(path: str | Path) -> str | None:
+    def check_read(self, path: str | Path) -> str | None:
         """Check if a file has been read and is fresh.

         Returns None if OK, or a warning string.
@@ -68,7 +77,7 @@ def check_read(path: str | Path) -> str | None:
         the check passes to avoid false-positive staleness warnings.
         """
         p = str(Path(path).resolve())
-        entry = _state.get(p)
+        entry = self._state.get(p)
         if entry is None:
             return "Warning: file has not been read yet. Read it first to verify content before editing."
         try:
@@ -85,11 +94,10 @@ def check_read(path: str | Path) -> str | None:
             return "Warning: file has been modified since last read. Re-read to verify content before editing."
         return None

-def is_unchanged(path: str | Path, offset: int = 1, limit: int | None = None) -> bool:
+    def is_unchanged(self, path: str | Path, offset: int = 1, limit: int | None = None) -> bool:
         """Return True if file was previously read with same params and content is unchanged."""
         p = str(Path(path).resolve())
-        entry = _state.get(p)
+        entry = self._state.get(p)
         if entry is None:
             return False
         if not entry.can_dedup:
@@ -113,7 +121,85 @@ def is_unchanged(path: str | Path, offset: int = 1, limit: int | None = None) ->
         # mtime unchanged - content must be identical
         return True

+    def get(self, path: str | Path) -> ReadState | None:
+        """Return the raw ReadState entry for a path, or None."""
+        return self._state.get(str(Path(path).resolve()))
+
+    def clear(self) -> None:
+        """Clear all tracked state (useful for testing)."""
+        self._state.clear()
+
+
+class FileStateStore:
+    """Lookup table for per-session file read/write state."""
+
+    __slots__ = ("_states_by_key",)
+
+    def __init__(self) -> None:
+        self._states_by_key: dict[str, FileStates] = {}
+
+    def for_session(self, session_key: str | None) -> FileStates:
+        key = session_key or "__default__"
+        states = self._states_by_key.get(key)
+        if states is None:
+            states = FileStates()
+            self._states_by_key[key] = states
+        return states
+
+    def clear(self) -> None:
+        self._states_by_key.clear()
+
+
+_current_file_states: ContextVar[FileStates | None] = ContextVar(
+    "nanobot_file_states",
+    default=None,
+)
+
+
+def current_file_states(default: FileStates) -> FileStates:
+    """Return the FileStates bound to the current agent task, or a fallback."""
+    return _current_file_states.get() or default
+
+
+def bind_file_states(file_states: FileStates) -> Token[FileStates | None]:
+    """Bind file read/write state for the current async task."""
+    return _current_file_states.set(file_states)
+
+
+def reset_file_states(token: Token[FileStates | None]) -> None:
+    _current_file_states.reset(token)
+
+
+# Module-level default instance, retained for backward compatibility with
+# tests and callers that reach in directly. Per-session callers should hold
+# their own FileStates instance instead of touching this one.
+_default = FileStates()
+
+
+def record_read(path: str | Path, offset: int = 1, limit: int | None = None) -> None:
+    _default.record_read(path, offset=offset, limit=limit)
+
+
+def record_write(path: str | Path) -> None:
+    _default.record_write(path)
+
+
+def check_read(path: str | Path) -> str | None:
+    return _default.check_read(path)
+
+
+def is_unchanged(path: str | Path, offset: int = 1, limit: int | None = None) -> bool:
+    return _default.is_unchanged(path, offset=offset, limit=limit)
+
+
 def clear() -> None:
-    """Clear all tracked state (useful for testing)."""
-    _state.clear()
+    _default.clear()
+
+
+# Legacy attribute for callers that reached into the module-level dict
+# directly (filesystem.py used to do this). Kept as a property-like accessor
+# so existing imports keep working.
+def __getattr__(name: str):
+    if name == "_state":
+        return _default._state
+    raise AttributeError(name)
@@ -9,7 +9,7 @@ from typing import Any

 from nanobot.agent.tools.base import Tool, tool_parameters
 from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema
-from nanobot.agent.tools import file_state
+from nanobot.agent.tools.file_state import FileStates, _hash_file, current_file_states
 from nanobot.utils.helpers import build_image_content_blocks, detect_image_mime
 from nanobot.config.paths import get_media_dir

@@ -49,10 +49,22 @@ class _FsTool(Tool):
         workspace: Path | None = None,
         allowed_dir: Path | None = None,
         extra_allowed_dirs: list[Path] | None = None,
+        file_states: FileStates | None = None,
     ):
         self._workspace = workspace
         self._allowed_dir = allowed_dir
         self._extra_allowed_dirs = extra_allowed_dirs
+        # Explicit state is used by isolated runners like Dream/subagents.
+        # Main AgentLoop tools leave this unset and resolve state from the
+        # current async task, which keeps shared tool instances session-safe.
+        self._explicit_file_states = file_states
+        self._fallback_file_states = FileStates()
+
+    @property
+    def _file_states(self) -> FileStates:
+        if self._explicit_file_states is not None:
+            return self._explicit_file_states
+        return current_file_states(self._fallback_file_states)

     def _resolve(self, path: str) -> Path:
         return _resolve_path(path, self._workspace, self._allowed_dir, self._extra_allowed_dirs)
@@ -184,7 +196,7 @@ class ReadFileTool(_FsTool):

         # Read dedup: same path + offset + limit + unchanged mtime → stub
         # Always check for external modifications before dedup
-        entry = file_state._state.get(str(fp.resolve()))
+        entry = self._file_states.get(fp)
         try:
             current_mtime = os.path.getmtime(fp)
         except OSError:
@@ -193,21 +205,21 @@ class ReadFileTool(_FsTool):
             if current_mtime != entry.mtime:
                 # File was modified externally - force full read and mark as not dedupable
                 entry.can_dedup = False
-                file_state.record_read(fp, offset=offset, limit=limit)  # Update state with new mtime
+                self._file_states.record_read(fp, offset=offset, limit=limit)  # Update state with new mtime
                 # Continue to read full content (don't return dedup message)
             else:
                 # File unchanged - return dedup message
                 # But only if content is actually unchanged (not just mtime)
-                current_hash = file_state._hash_file(str(fp))
+                current_hash = _hash_file(str(fp))
                 if current_hash == entry.content_hash:
                     return f"[File unchanged since last read: {path}]"
                 else:
                     # Content changed despite same mtime - force full read
                     entry.can_dedup = False
-                    file_state.record_read(fp, offset=offset, limit=limit)
+                    self._file_states.record_read(fp, offset=offset, limit=limit)
         else:
             # No previous state or marked as not dedupable - read full content
-            file_state.record_read(fp, offset=offset, limit=limit)
+            self._file_states.record_read(fp, offset=offset, limit=limit)
             # Force full read by setting can_dedup to False for this read
             if entry:
                 entry.can_dedup = False
@@ -256,7 +268,7 @@ class ReadFileTool(_FsTool):
                 result += f"\n\n(Showing lines {offset}-{end} of {total}. Use offset={end + 1} to continue.)"
             else:
                 result += f"\n\n(End of file — {total} lines total)"
-            file_state.record_read(fp, offset=offset, limit=limit)
+            self._file_states.record_read(fp, offset=offset, limit=limit)
             return result
         except PermissionError as e:
             return f"Error: {e}"
@@ -365,7 +377,7 @@ class WriteFileTool(_FsTool):
             fp = self._resolve(path)
             fp.parent.mkdir(parents=True, exist_ok=True)
             fp.write_text(content, encoding="utf-8")
-            file_state.record_write(fp)
+            self._file_states.record_write(fp)
             return f"Successfully wrote {len(content)} characters to {fp}"
         except PermissionError as e:
             return f"Error: {e}"
@@ -699,7 +711,7 @@ class EditFileTool(_FsTool):
             if old_text == "":
                 fp.parent.mkdir(parents=True, exist_ok=True)
                 fp.write_text(new_text, encoding="utf-8")
-                file_state.record_write(fp)
+                self._file_states.record_write(fp)
                 return f"Successfully created {fp}"
             return self._file_not_found_msg(path, fp)

@@ -718,11 +730,11 @@ class EditFileTool(_FsTool):
             if content.strip():
                 return f"Error: Cannot create file — {path} already exists and is not empty."
             fp.write_text(new_text, encoding="utf-8")
-            file_state.record_write(fp)
+            self._file_states.record_write(fp)
             return f"Successfully edited {fp}"

         # Read-before-edit check
-        warning = file_state.check_read(fp)
+        warning = self._file_states.check_read(fp)

         raw = fp.read_bytes()
         uses_crlf = b"\r\n" in raw
@@ -767,7 +779,7 @@ class EditFileTool(_FsTool):
             new_content = new_content.replace("\n", "\r\n")

         fp.write_bytes(new_content.encode("utf-8"))
-        file_state.record_write(fp)
+        self._file_states.record_write(fp)
         msg = f"Successfully edited {fp}"
         if warning:
             msg = f"{warning}\n{msg}"
@@ -4,7 +4,7 @@ import asyncio
 import os
 import re
 import shutil
-from contextlib import AsyncExitStack
+from contextlib import AsyncExitStack, suppress
 from typing import Any

 import httpx
@@ -609,10 +609,8 @@ async def connect_mcp_servers(
                 "only JSON-RPC to stdout and sends logs/debug output to stderr instead."
             )
         logger.error("MCP server '{}': failed to connect: {}{}", name, e, hint)
-        try:
+        with suppress(Exception):
             await server_stack.aclose()
-        except Exception:
-            pass
        return name, None

     server_stacks: dict[str, AsyncExitStack] = {}
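The `try/except/pass` to `contextlib.suppress` rewrites in this commit are behavior-preserving: both forms swallow the named exception and let anything else propagate. A standalone sketch of the equivalence (hypothetical helper names, not from the codebase):

```python
from contextlib import suppress


def close_noisy(close) -> None:
    # Pre-commit shape: explicit try/except/pass.
    try:
        close()
    except Exception:
        pass


def close_quiet(close) -> None:
    # Post-commit shape: same semantics, one statement shorter.
    with suppress(Exception):
        close()


def boom() -> None:
    raise RuntimeError("already closed")


close_noisy(boom)  # exception swallowed
close_quiet(boom)  # exception swallowed the same way
```

Note that `suppress` only catches the exception types it is given, so narrowing the argument (e.g. `suppress(ValueError)` in `_display_path`) keeps unexpected errors visible.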
@@ -5,6 +5,7 @@ from __future__ import annotations
 import fnmatch
 import os
 import re
+from contextlib import suppress
 from pathlib import Path, PurePosixPath
 from typing import Any, Iterable, TypeVar

@@ -92,10 +93,8 @@ class _SearchTool(_FsTool):

     def _display_path(self, target: Path, root: Path) -> str:
         if self._workspace:
-            try:
+            with suppress(ValueError):
                 return target.relative_to(self._workspace).as_posix()
-            except ValueError:
-                pass
         return target.relative_to(root).as_posix()

     def _iter_files(self, root: Path) -> Iterable[Path]:
@@ -5,6 +5,7 @@ import os
 import re
 import shutil
 import sys
+from contextlib import suppress
 from pathlib import Path
 from typing import Any

@@ -212,9 +213,8 @@ class ExecTool(Tool):
         """Kill a subprocess and reap it to prevent zombies."""
         process.kill()
         try:
+            with suppress(asyncio.TimeoutError):
                 await asyncio.wait_for(process.wait(), timeout=5.0)
-        except asyncio.TimeoutError:
-            pass
         finally:
             if not _IS_WINDOWS:
                 try:
@@ -25,7 +25,10 @@ class SpawnTool(Tool):
         self._origin_channel: ContextVar[str] = ContextVar("spawn_origin_channel", default="cli")
         self._origin_chat_id: ContextVar[str] = ContextVar("spawn_origin_chat_id", default="direct")
         self._session_key: ContextVar[str] = ContextVar("spawn_session_key", default="cli:direct")
-        self._origin_message_id: str | None = None
+        self._origin_message_id: ContextVar[str | None] = ContextVar(
+            "spawn_origin_message_id",
+            default=None,
+        )

     def set_context(self, channel: str, chat_id: str, effective_key: str | None = None) -> None:
         """Set the origin context for subagent announcements."""
@@ -35,7 +38,7 @@ class SpawnTool(Tool):

     def set_origin_message_id(self, message_id: str | None) -> None:
         """Set the source message id for downstream deduplication."""
-        self._origin_message_id = message_id
+        self._origin_message_id.set(message_id)

     @property
     def name(self) -> str:
@@ -59,5 +62,5 @@ class SpawnTool(Tool):
             origin_channel=self._origin_channel.get(),
             origin_chat_id=self._origin_chat_id.get(),
             session_key=self._session_key.get(),
-            origin_message_id=self._origin_message_id,
+            origin_message_id=self._origin_message_id.get(),
         )
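The `SpawnTool` change above swaps a plain instance attribute for a `ContextVar` because one tool instance is shared by concurrent handlers: a plain attribute set by one in-flight message gets clobbered by the next, while a `ContextVar` keeps each task's value separate. A standalone sketch of the leak (illustrative class, not the real `SpawnTool`):

```python
import asyncio
from contextvars import ContextVar


class SpawnLike:
    def __init__(self) -> None:
        self.plain: str | None = None  # pre-fix: shared across all tasks
        self.scoped: ContextVar[str | None] = ContextVar("origin", default=None)


tool = SpawnLike()


async def handle(message_id: str, seen: list[tuple]) -> None:
    tool.plain = message_id
    tool.scoped.set(message_id)
    await asyncio.sleep(0)  # yield: the other handler runs and overwrites `plain`
    seen.append((message_id, tool.plain, tool.scoped.get()))


async def main() -> list[tuple]:
    seen: list[tuple] = []
    await asyncio.gather(handle("m1", seen), handle("m2", seen))
    return seen


# The m1 handler observes plain == "m2" (leaked) but scoped == "m1" (isolated).
print(asyncio.run(main()))
```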
@@ -7,6 +7,7 @@ All requests route to a single persistent API session.
 from __future__ import annotations

 import asyncio
+import contextlib
 import json as _json
 import time
 import uuid
@@ -18,8 +19,12 @@ from loguru import logger
 from nanobot.config.paths import get_media_dir
 from nanobot.utils.helpers import safe_filename
 from nanobot.utils.media_decode import (
-    FileSizeExceeded as _FileSizeExceeded,
     MAX_FILE_SIZE,
+)
+from nanobot.utils.media_decode import (
+    FileSizeExceeded as _FileSizeExceeded,
+)
+from nanobot.utils.media_decode import (
     save_base64_data_url as _save_base64_data_url,
 )
 from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
@@ -240,18 +245,25 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
         chunk_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
         queue: asyncio.Queue[str | None] = asyncio.Queue()
         stream_failed = False
+        emitted_content = False

         async def _on_stream(token: str) -> None:
+            nonlocal emitted_content
+            if token:
+                emitted_content = True
             await queue.put(token)

         async def _on_stream_end(*_a: Any, **_kw: Any) -> None:
-            await queue.put(None)
+            # Agent stream-end callbacks mark generation segment boundaries.
+            # Tool-backed requests may continue after a segment ends, so the
+            # HTTP SSE stream is closed only when process_direct returns.
+            return None

         async def _run() -> None:
             nonlocal stream_failed
             try:
                 async with session_lock:
-                    await asyncio.wait_for(
+                    response = await asyncio.wait_for(
                         agent_loop.process_direct(
                             content=text,
                             media=media_paths if media_paths else None,
@@ -263,9 +275,14 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
                         ),
                         timeout=timeout_s,
                     )
+                if not emitted_content:
+                    response_text = _response_text(response)
+                    if response_text.strip():
+                        await queue.put(response_text)
             except Exception:
                 stream_failed = True
                 logger.exception("Streaming error for session {}", session_key)
+            finally:
                 await queue.put(None)

         task = asyncio.create_task(_run())
@@ -276,7 +293,10 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
                     break
                 await resp.write(_sse_chunk(token, model_name, chunk_id))
         finally:
-            task.cancel()
+            if not task.done():
+                task.cancel()
+                with contextlib.suppress(asyncio.CancelledError):
+                    await task

         if not stream_failed:
             await resp.write(_sse_chunk("", model_name, chunk_id, finish_reason="stop"))
@@ -284,7 +304,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
         return resp

     # -- non-streaming path (original logic) --
-    _FALLBACK = EMPTY_FINAL_RESPONSE_MESSAGE
+    fallback = EMPTY_FINAL_RESPONSE_MESSAGE

     try:
         async with session_lock:
@@ -316,7 +336,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response:
             response_text = _response_text(retry_response)
             if not response_text or not response_text.strip():
                 logger.warning("Empty response after retry, using fallback")
-                response_text = _FALLBACK
+                response_text = fallback

     except asyncio.TimeoutError:
         return _error_json(504, f"Request timed out after {timeout_s}s")
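The streaming hunk above hinges on the producer enqueueing the `None` terminator from a `finally` block, so the consumer loop drains cleanly even when the handler raises mid-stream. A self-contained sketch of that queue/sentinel shape (token values and the simulated failure are illustrative):

```python
import asyncio


async def produce(queue: asyncio.Queue) -> None:
    try:
        for token in ("hel", "lo"):
            await queue.put(token)
        raise RuntimeError("mid-stream failure")
    except RuntimeError:
        pass  # the real handler logs and sets a failure flag here
    finally:
        await queue.put(None)  # terminator reaches the consumer regardless


async def consume() -> str:
    queue: asyncio.Queue = asyncio.Queue()
    task = asyncio.create_task(produce(queue))
    out: list[str] = []
    while True:
        token = await queue.get()
        if token is None:  # sentinel: producer is done (or failed)
            break
        out.append(token)
    await task
    return "".join(out)


print(asyncio.run(consume()))  # hello
```

Without the `finally`, a failure before the sentinel is enqueued would leave the consumer blocked on `queue.get()` forever, which is exactly the hang the commit guards against.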
@@ -9,7 +9,7 @@ import zipfile
 from io import BytesIO
 from pathlib import Path
 from typing import Any
-from urllib.parse import unquote, urlparse
+from urllib.parse import unquote, urljoin, urlparse

 import httpx
 from loguru import logger
@@ -19,6 +19,10 @@ from nanobot.bus.events import OutboundMessage
 from nanobot.bus.queue import MessageBus
 from nanobot.channels.base import BaseChannel
 from nanobot.config.schema import Base
+from nanobot.security.network import validate_resolved_url, validate_url_target
+
+DINGTALK_MAX_REMOTE_MEDIA_BYTES = 20 * 1024 * 1024
+DINGTALK_MAX_REMOTE_MEDIA_REDIRECTS = 3

 try:
     from dingtalk_stream import (
@@ -155,6 +159,8 @@ class DingTalkConfig(Base):
     client_id: str = ""
     client_secret: str = ""
     allow_from: list[str] = Field(default_factory=list)
+    allow_remote_media_redirects: bool = False
+    remote_media_redirect_allowed_hosts: list[str] = Field(default_factory=list)


 class DingTalkChannel(BaseChannel):
@@ -281,9 +287,12 @@ class DingTalkChannel(BaseChannel):

     def _guess_upload_type(self, media_ref: str) -> str:
         ext = Path(urlparse(media_ref).path).suffix.lower()
-        if ext in self._IMAGE_EXTS: return "image"
-        if ext in self._AUDIO_EXTS: return "voice"
-        if ext in self._VIDEO_EXTS: return "video"
+        if ext in self._IMAGE_EXTS:
+            return "image"
+        if ext in self._AUDIO_EXTS:
+            return "voice"
+        if ext in self._VIDEO_EXTS:
+            return "video"
         return "file"

     def _guess_filename(self, media_ref: str, upload_type: str) -> str:
@@ -315,6 +324,146 @@ class DingTalkChannel(BaseChannel):
             return self._zip_bytes(filename, data)
         return data, filename, content_type

+    def _validate_remote_media_url(self, media_ref: str) -> bool:
+        ok, err = validate_url_target(media_ref)
+        if not ok:
+            logger.warning("DingTalk remote media URL blocked ref={} reason={}", media_ref, err)
+            return False
+        return True
+
+    def _redirect_host_allowed(self, current_url: str, next_url: str) -> bool:
+        current_host = (urlparse(current_url).hostname or "").lower()
+        next_host = (urlparse(next_url).hostname or "").lower()
+        if not next_host:
+            return False
+        if next_host == current_host:
+            return True
+        allowed_hosts = {host.lower() for host in self.config.remote_media_redirect_allowed_hosts}
+        return next_host in allowed_hosts
+
+    def _next_remote_media_url(self, current_url: str, location: str | None) -> str | None:
+        if not self.config.allow_remote_media_redirects:
+            logger.warning("DingTalk media download redirect refused ref={}", current_url)
+            return None
+        if not location:
+            logger.warning("DingTalk media download redirect without Location ref={}", current_url)
+            return None
+        next_url = urljoin(current_url, location)
+        if not self._redirect_host_allowed(current_url, next_url):
+            logger.warning(
+                "DingTalk media download cross-host redirect refused ref={} next={}",
+                current_url,
+                next_url,
+            )
+            return None
+        if not self._validate_remote_media_url(next_url):
+            return None
+        return next_url
+
+    async def _fetch_remote_media_bytes(
+        self,
+        media_ref: str,
+    ) -> tuple[bytes | None, str | None]:
+        """Fetch a remote media URL with SSRF, redirect, and size checks."""
+        if not self._http:
+            return None, None
+
+        if not self._validate_remote_media_url(media_ref):
+            return None, None
+
+        try:
+            # Prefer streaming with a running byte cap so large responses are not
+            # materialized before the limit is enforced. Test fakes may only
+            # implement get(), so keep a small compatibility fallback below.
+            stream = getattr(self._http, "stream", None)
+            if stream is not None:
+                current_url = media_ref
+                for _ in range(DINGTALK_MAX_REMOTE_MEDIA_REDIRECTS + 1):
+                    async with stream("GET", current_url, follow_redirects=False) as resp:
+                        final_ok, final_err = validate_resolved_url(str(resp.url))
+                        if not final_ok:
+                            logger.warning(
+                                "DingTalk remote media redirect blocked ref={} final={} reason={}",
+                                media_ref,
+                                resp.url,
+                                final_err,
+                            )
+                            return None, None
+                        if 300 <= resp.status_code < 400:
+                            next_url = self._next_remote_media_url(
+                                str(resp.url), resp.headers.get("location")
+                            )
+                            if not next_url:
+                                return None, None
+                            current_url = next_url
+                            continue
+                        if resp.status_code >= 400:
+                            logger.warning(
+                                "DingTalk media download failed status={} ref={}",
+                                resp.status_code,
+                                current_url,
+                            )
+                            return None, None
+                        chunks: list[bytes] = []
+                        total = 0
+                        async for chunk in resp.aiter_bytes():
+                            total += len(chunk)
+                            if total > DINGTALK_MAX_REMOTE_MEDIA_BYTES:
+                                logger.warning(
+                                    "DingTalk media download too large ref={} bytes>{}",
|
||||||
|
current_url,
|
||||||
|
DINGTALK_MAX_REMOTE_MEDIA_BYTES,
|
||||||
|
)
|
||||||
|
return None, None
|
||||||
|
chunks.append(chunk)
|
||||||
|
return b"".join(chunks), (resp.headers.get("content-type") or "")
|
||||||
|
logger.warning("DingTalk media download exceeded redirect limit ref={}", media_ref)
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
current_url = media_ref
|
||||||
|
for _ in range(DINGTALK_MAX_REMOTE_MEDIA_REDIRECTS + 1):
|
||||||
|
resp = await self._http.get(current_url, follow_redirects=False)
|
||||||
|
final_ok, final_err = validate_resolved_url(str(getattr(resp, "url", current_url)))
|
||||||
|
if not final_ok:
|
||||||
|
logger.warning(
|
||||||
|
"DingTalk remote media redirect blocked ref={} final={} reason={}",
|
||||||
|
media_ref,
|
||||||
|
getattr(resp, "url", current_url),
|
||||||
|
final_err,
|
||||||
|
)
|
||||||
|
return None, None
|
||||||
|
if 300 <= resp.status_code < 400:
|
||||||
|
next_url = self._next_remote_media_url(
|
||||||
|
str(getattr(resp, "url", current_url)), resp.headers.get("location")
|
||||||
|
)
|
||||||
|
if not next_url:
|
||||||
|
return None, None
|
||||||
|
current_url = next_url
|
||||||
|
continue
|
||||||
|
if resp.status_code >= 400:
|
||||||
|
logger.warning(
|
||||||
|
"DingTalk media download failed status={} ref={}",
|
||||||
|
resp.status_code,
|
||||||
|
current_url,
|
||||||
|
)
|
||||||
|
return None, None
|
||||||
|
if len(resp.content) > DINGTALK_MAX_REMOTE_MEDIA_BYTES:
|
||||||
|
logger.warning(
|
||||||
|
"DingTalk media download too large ref={} bytes>{}",
|
||||||
|
current_url,
|
||||||
|
DINGTALK_MAX_REMOTE_MEDIA_BYTES,
|
||||||
|
)
|
||||||
|
return None, None
|
||||||
|
return resp.content, (resp.headers.get("content-type") or "")
|
||||||
|
logger.warning("DingTalk media download exceeded redirect limit ref={}", media_ref)
|
||||||
|
return None, None
|
||||||
|
except httpx.TransportError as e:
|
||||||
|
logger.error("DingTalk media download network error ref={} err={}", media_ref, e)
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("DingTalk media download error ref={} err={}", media_ref, e)
|
||||||
|
return None, None
|
||||||
|
|
||||||
async def _read_media_bytes(
|
async def _read_media_bytes(
|
||||||
self,
|
self,
|
||||||
media_ref: str,
|
media_ref: str,
|
||||||
@ -323,26 +472,12 @@ class DingTalkChannel(BaseChannel):
|
|||||||
return None, None, None
|
return None, None, None
|
||||||
|
|
||||||
if self._is_http_url(media_ref):
|
if self._is_http_url(media_ref):
|
||||||
if not self._http:
|
data, raw_content_type = await self._fetch_remote_media_bytes(media_ref)
|
||||||
|
if data is None:
|
||||||
return None, None, None
|
return None, None, None
|
||||||
try:
|
content_type = (raw_content_type or "").split(";")[0].strip()
|
||||||
resp = await self._http.get(media_ref, follow_redirects=True)
|
|
||||||
if resp.status_code >= 400:
|
|
||||||
logger.warning(
|
|
||||||
"DingTalk media download failed status={} ref={}",
|
|
||||||
resp.status_code,
|
|
||||||
media_ref,
|
|
||||||
)
|
|
||||||
return None, None, None
|
|
||||||
content_type = (resp.headers.get("content-type") or "").split(";")[0].strip()
|
|
||||||
filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
|
filename = self._guess_filename(media_ref, self._guess_upload_type(media_ref))
|
||||||
return resp.content, filename, content_type or None
|
return data, filename, content_type or None
|
||||||
except httpx.TransportError as e:
|
|
||||||
logger.error("DingTalk media download network error ref={} err={}", media_ref, e)
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("DingTalk media download error ref={} err={}", media_ref, e)
|
|
||||||
return None, None, None
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if media_ref.startswith("file://"):
|
if media_ref.startswith("file://"):
|
||||||
@ -435,8 +570,10 @@ class DingTalkChannel(BaseChannel):
|
|||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
logger.error("DingTalk send failed msgKey={} status={} body={}", msg_key, resp.status_code, body[:500])
|
logger.error("DingTalk send failed msgKey={} status={} body={}", msg_key, resp.status_code, body[:500])
|
||||||
return False
|
return False
|
||||||
try: result = resp.json()
|
try:
|
||||||
except Exception: result = {}
|
result = resp.json()
|
||||||
|
except Exception:
|
||||||
|
result = {}
|
||||||
errcode = result.get("errcode")
|
errcode = result.get("errcode")
|
||||||
if errcode not in (None, 0):
|
if errcode not in (None, 0):
|
||||||
logger.error("DingTalk send api error msgKey={} errcode={} body={}", msg_key, errcode, body[:500])
|
logger.error("DingTalk send api error msgKey={} errcode={} body={}", msg_key, errcode, body[:500])
|
||||||
|
|||||||
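The DingTalk change above replaces a buffered `get()` download with a streamed read that enforces a running byte cap, so an oversized response is refused before it is fully held in memory. A minimal standalone sketch of that cap (`read_capped` and `MAX_BYTES` are hypothetical names, standing in for the real channel constants):

```python
MAX_BYTES = 20 * 1024 * 1024  # assumed cap, mirroring DINGTALK_MAX_REMOTE_MEDIA_BYTES


def read_capped(chunks, max_bytes=MAX_BYTES):
    """Join an iterable of byte chunks, or return None once the cap is exceeded."""
    total = 0
    parts = []
    for chunk in chunks:
        total += len(chunk)
        if total > max_bytes:
            # Refuse mid-stream, mirroring the "too large" warning path above.
            return None
        parts.append(chunk)
    return b"".join(parts)
```

Checking the running total per chunk is what makes the cap effective for streamed bodies; a `len()` check after the fact (as in the non-streaming fallback) only helps once the response is already materialized.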
@@ -5,6 +5,7 @@ from __future__ import annotations
 import asyncio
 import importlib.util
 import time
+from contextlib import suppress
 from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Literal
@@ -564,10 +565,8 @@ class DiscordChannel(BaseChannel):
         # Delayed working indicator (cosmetic — not tied to subagent lifecycle)
         async def _delayed_working_emoji() -> None:
             await asyncio.sleep(self.config.working_emoji_delay)
-            try:
+            with suppress(Exception):
                 await message.add_reaction(self.config.working_emoji)
-            except Exception:
-                pass
 
         self._working_emoji_tasks[channel_id] = asyncio.create_task(_delayed_working_emoji())
 
@@ -771,10 +770,8 @@ class DiscordChannel(BaseChannel):
         if task is None:
             return
         task.cancel()
-        try:
+        with suppress(asyncio.CancelledError):
             await task
-        except asyncio.CancelledError:
-            pass
 
     async def _clear_reactions(self, chat_id: str) -> None:
         """Remove all pending reactions after bot replies."""
@@ -788,10 +785,8 @@ class DiscordChannel(BaseChannel):
             return
         bot_user = self._client.user if self._client else None
         for emoji in (self.config.read_receipt_emoji, self.config.working_emoji):
-            try:
+            with suppress(Exception):
                 await msg_obj.remove_reaction(emoji, bot_user)
-            except Exception:
-                pass
 
     async def _cancel_all_typing(self) -> None:
         """Stop all typing tasks."""
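The Discord hunks above collapse repeated `try:/except ...: pass` blocks into `contextlib.suppress`, which ignores only the listed exception types. A self-contained sketch of the cancel-and-await pattern (the helper name is illustrative, not the channel's API):

```python
import asyncio
from contextlib import suppress


async def cancel_and_wait(task: asyncio.Task) -> None:
    # Cancel a background task and swallow only the expected CancelledError.
    task.cancel()
    with suppress(asyncio.CancelledError):
        await task


async def main() -> None:
    task = asyncio.create_task(asyncio.sleep(60))
    await cancel_and_wait(task)  # returns promptly instead of raising
    assert task.cancelled()


asyncio.run(main())
```

Unlike a bare `except Exception: pass`, `suppress(asyncio.CancelledError)` would still let an unexpected error propagate, so the two spellings are equivalent only when the suppressed types match the original `except` clause.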
@@ -6,6 +6,7 @@ import imaplib
 import re
 import smtplib
 import ssl
+from contextlib import suppress
 from datetime import date
 from email import policy
 from email.header import decode_header, make_header
@@ -460,10 +461,8 @@ class EmailChannel(BaseChannel):
             if mark_seen:
                 client.store(imap_id, "+FLAGS", "\\Seen")
         finally:
-            try:
+            with suppress(Exception):
                 client.logout()
-            except Exception:
-                pass
 
     def _collect_self_addresses(self) -> set[str]:
         """Return normalized email addresses owned by this channel instance."""
@@ -9,6 +9,7 @@ import threading
 import time
 import uuid
 from collections import OrderedDict
+from contextlib import suppress
 from dataclasses import dataclass
 from typing import Any, Literal
 
@@ -612,12 +613,11 @@ class FeishuChannel(BaseChannel):
         """Callback: store reaction_id after background add-reaction completes."""
         if task.cancelled():
             return
-        try:
+        # Failures already logged by _on_background_task_done.
+        with suppress(Exception):
             reaction_id = task.result()
             if reaction_id:
                 self._reaction_ids[message_id] = reaction_id
-        except Exception:
-            pass  # already logged by _on_background_task_done
         # Trim cache to prevent unbounded growth
         if len(self._reaction_ids) > 500:
             self._reaction_ids.pop(next(iter(self._reaction_ids)))
@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import asyncio
 import hashlib
+from contextlib import suppress
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
 
@@ -37,13 +38,6 @@ _BOOL_CAMEL_ALIASES: dict[str, str] = {
     "send_tool_hints": "sendToolHints",
 }
 
-
-@dataclass
-class _RecentOutbound:
-    fingerprint: str
-    ts: float
-
-
 class ChannelManager:
     """
     Manages chat channels and coordinates message routing.
@@ -66,7 +60,7 @@ class ChannelManager:
         self._session_manager = session_manager
         self.channels: dict[str, BaseChannel] = {}
         self._dispatch_task: asyncio.Task | None = None
-        self._recent_outbound: dict[tuple[str, str], _RecentOutbound] = {}
+        self._origin_reply_fingerprints: dict[tuple[str, str, str], str] = {}
 
         self._init_channels()
 
@@ -228,10 +222,8 @@ class ChannelManager:
         # Stop dispatcher
         if self._dispatch_task:
             self._dispatch_task.cancel()
-            try:
+            with suppress(asyncio.CancelledError):
                 await self._dispatch_task
-            except asyncio.CancelledError:
-                pass
 
         # Stop all channels
         for name, channel in self.channels.items():
@@ -247,17 +239,25 @@ class ChannelManager:
         return hashlib.sha1(normalized.encode("utf-8")).hexdigest() if normalized else ""
 
     def _should_suppress_outbound(self, msg: OutboundMessage) -> bool:
-        if msg.metadata.get("_progress"):
+        metadata = msg.metadata or {}
+        if metadata.get("_progress"):
             return False
         fingerprint = self._fingerprint_content(msg.content)
         if not fingerprint:
             return False
-        key = (msg.channel, msg.chat_id)
-        recent = self._recent_outbound.get(key)
-        now = asyncio.get_running_loop().time()
-        if recent and recent.fingerprint == fingerprint and now - recent.ts <= 8.0:
-            return True
-        self._recent_outbound[key] = _RecentOutbound(fingerprint=fingerprint, ts=now)
+        origin_message_id = metadata.get("origin_message_id")
+        if isinstance(origin_message_id, str) and origin_message_id:
+            key = (msg.channel, msg.chat_id, origin_message_id)
+            if self._origin_reply_fingerprints.get(key) == fingerprint:
+                return True
+            self._origin_reply_fingerprints[key] = fingerprint
+
+        message_id = metadata.get("message_id")
+        if isinstance(message_id, str) and message_id:
+            key = (msg.channel, msg.chat_id, message_id)
+            self._origin_reply_fingerprints[key] = fingerprint
+
         return False
 
     async def _dispatch_outbound(self) -> None:
@@ -300,8 +300,13 @@ class ChannelManager:
 
             channel = self.channels.get(msg.channel)
            if channel:
-                # Duplicate suppression (non-streaming only)
-                if not msg.metadata.get("_stream_delta") and not msg.metadata.get("_stream_end") and not msg.metadata.get("_streamed"):
+                # Duplicate suppression is scoped to a known source message
+                # so repeated content from separate turns is still delivered.
+                if (
+                    not msg.metadata.get("_stream_delta")
+                    and not msg.metadata.get("_stream_end")
+                    and not msg.metadata.get("_streamed")
+                ):
                     if self._should_suppress_outbound(msg):
                         logger.info("Suppressing duplicate outbound message to {}:{}", msg.channel, msg.chat_id)
                         continue
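The commit's core behavior change is in the `ChannelManager` hunks above: the time-windowed `_RecentOutbound` cache becomes a fingerprint map keyed by `(channel, chat_id, origin_message_id)`, so identical replies to different origin messages are no longer dropped. A standalone sketch of that rule (the class shape and the normalization inside `fingerprint` are assumptions, not nanobot's exact code):

```python
import hashlib


def fingerprint(content: str) -> str:
    # Assumed normalization: collapse whitespace and lowercase before hashing.
    normalized = " ".join(content.split()).strip().lower()
    return hashlib.sha1(normalized.encode("utf-8")).hexdigest() if normalized else ""


class Deduper:
    def __init__(self) -> None:
        # One remembered fingerprint per (channel, chat, origin message).
        self._seen: dict[tuple[str, str, str], str] = {}

    def should_suppress(self, channel: str, chat_id: str, origin_id: str, content: str) -> bool:
        fp = fingerprint(content)
        if not fp:
            return False
        key = (channel, chat_id, origin_id)
        if self._seen.get(key) == fp:
            return True  # same reply to the same origin message: duplicate
        self._seen[key] = fp
        return False
```

Scoping the key to the origin message is what fixes the over-eager dedupe: two turns that legitimately produce the same text differ in `origin_id`, so only a true re-send of one reply is suppressed.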
@@ -5,6 +5,7 @@ import json
 import logging
 import mimetypes
 import time
+from contextlib import suppress
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, Literal, TypeAlias
@@ -214,7 +215,7 @@ class MatrixConfig(Base):
     allow_from: list[str] = Field(default_factory=list)
     group_policy: Literal["open", "mention", "allowlist"] = "open"
     group_allow_from: list[str] = Field(default_factory=list)
-    allow_room_mentions: bool = False,
+    allow_room_mentions: bool = False
     streaming: bool = False
 
 
@@ -251,11 +252,13 @@ class MatrixChannel(BaseChannel):
         self._server_upload_limit_bytes: int | None = None
         self._server_upload_limit_checked = False
         self._stream_bufs: dict[str, _StreamBuf] = {}
+        self._started_at_ms: int = 0
 
 
     async def start(self) -> None:
         """Start Matrix client and begin sync loop."""
         self._running = True
+        self._started_at_ms = int(time.time() * 1000)
         _configure_nio_logging_bridge()
 
         self.store_path = get_data_dir() / "matrix-store"
@@ -341,10 +344,8 @@ class MatrixChannel(BaseChannel):
                                             timeout=self.config.sync_stop_grace_seconds)
             except (asyncio.TimeoutError, asyncio.CancelledError):
                 self._sync_task.cancel()
-                try:
+                with suppress(asyncio.CancelledError):
                     await self._sync_task
-                except asyncio.CancelledError:
-                    pass
         if self.client:
             await self.client.close()
 
@@ -523,7 +524,7 @@ class MatrixChannel(BaseChannel):
                 failures.append(fail)
         if failures:
             text = f"{text.rstrip()}\n{chr(10).join(failures)}" if text.strip() else "\n".join(failures)
-        if text or not candidates:
+        if text.strip():
             content = _build_matrix_text_content(text)
             if relates_to:
                 content["m.relates_to"] = relates_to
@@ -609,13 +610,11 @@ class MatrixChannel(BaseChannel):
         """Best-effort typing indicator update."""
         if not self.client:
             return
-        try:
+        with suppress(Exception):
             response = await self.client.room_typing(room_id=room_id, typing_state=typing,
                                                      timeout=TYPING_NOTICE_TIMEOUT_MS)
             if isinstance(response, RoomTypingError):
                 logger.debug("Matrix typing failed for {}: {}", room_id, response)
-        except Exception:
-            pass
 
     async def _start_typing_keepalive(self, room_id: str) -> None:
         """Start periodic typing refresh (spec-recommended keepalive)."""
@@ -625,22 +624,18 @@ class MatrixChannel(BaseChannel):
             return
 
         async def loop() -> None:
-            try:
+            with suppress(asyncio.CancelledError):
                 while self._running:
                     await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_MS / 1000)
                     await self._set_typing(room_id, True)
-            except asyncio.CancelledError:
-                pass
 
         self._typing_tasks[room_id] = asyncio.create_task(loop())
 
     async def _stop_typing_keepalive(self, room_id: str, *, clear_typing: bool) -> None:
         if task := self._typing_tasks.pop(room_id, None):
             task.cancel()
-            try:
+            with suppress(asyncio.CancelledError):
                 await task
-            except asyncio.CancelledError:
-                pass
         if clear_typing:
             await self._set_typing(room_id, False)
 
@@ -674,6 +669,16 @@ class MatrixChannel(BaseChannel):
             return True
         return bool(self.config.allow_room_mentions and mentions.get("room") is True)
 
+    def _is_pre_startup_event(self, event: RoomMessage) -> bool:
+        """Skip events that landed in the timeline before this process started.
+
+        Matrix sync replays the room timeline on each startup/restart; without
+        this filter old messages would be re-handled as if they were fresh
+        (#3553).
+        """
+        ts = getattr(event, "server_timestamp", None)
+        return isinstance(ts, int) and ts < self._started_at_ms
+
     def _should_process_message(self, room: MatrixRoom, event: RoomMessage) -> bool:
         """Apply sender and room policy checks."""
         if not self.is_allowed(event.sender):
@@ -858,7 +863,11 @@ class MatrixChannel(BaseChannel):
         return meta
 
     async def _on_message(self, room: MatrixRoom, event: RoomMessageText) -> None:
-        if event.sender == self.config.user_id or not self._should_process_message(room, event):
+        if (
+            event.sender == self.config.user_id
+            or self._is_pre_startup_event(event)
+            or not self._should_process_message(room, event)
+        ):
             return
         await self._start_typing_keepalive(room.room_id)
         try:
@@ -871,7 +880,11 @@ class MatrixChannel(BaseChannel):
             raise
 
     async def _on_media_message(self, room: MatrixRoom, event: MatrixMediaEvent) -> None:
-        if event.sender == self.config.user_id or not self._should_process_message(room, event):
+        if (
+            event.sender == self.config.user_id
+            or self._is_pre_startup_event(event)
+            or not self._should_process_message(room, event)
+        ):
             return
         attachment, marker = await self._fetch_media_attachment(room, event)
         parts: list[str] = []
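The Matrix `_is_pre_startup_event` guard above compares each event's server timestamp against the process start time recorded in `start()`. The sketch below exercises the same comparison in isolation; `FakeEvent` and `ReplayGuard` are stubs introduced here, since only `server_timestamp` matters for the check:

```python
import time


class FakeEvent:
    def __init__(self, server_timestamp):
        self.server_timestamp = server_timestamp


class ReplayGuard:
    def __init__(self) -> None:
        # Recorded once at startup, in milliseconds, like _started_at_ms.
        self.started_at_ms = int(time.time() * 1000)

    def is_pre_startup(self, event) -> bool:
        # Events older than process start are timeline replay, not fresh traffic.
        ts = getattr(event, "server_timestamp", None)
        return isinstance(ts, int) and ts < self.started_at_ms
```

The `isinstance(ts, int)` check means events with a missing or malformed timestamp are treated as fresh, which fails open toward processing rather than silently dropping messages.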
@@ -5,6 +5,7 @@ from __future__ import annotations
 import asyncio
 import json
 from collections import deque
+from contextlib import suppress
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Any
@@ -330,10 +331,8 @@ class MochatChannel(BaseChannel):
         await self._cancel_delay_timers()
 
         if self._socket:
-            try:
+            with suppress(Exception):
                 await self._socket.disconnect()
-            except Exception:
-                pass
             self._socket = None
 
         if self._cursor_save_task:
@@ -460,10 +459,8 @@ class MochatChannel(BaseChannel):
             return True
         except Exception as e:
             logger.error("Failed to connect Mochat websocket: {}", e)
-            try:
+            with suppress(Exception):
                 await client.disconnect()
-            except Exception:
-                pass
             self._socket = None
             return False
 
@@ -20,7 +20,7 @@ import re
 import tempfile
 import threading
 import time
-from contextlib import contextmanager
+from contextlib import contextmanager, suppress
 from dataclasses import dataclass
 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
 from typing import TYPE_CHECKING, Any
@@ -712,10 +712,8 @@ class MSTeamsChannel(BaseChannel):
             os.replace(tmp_path, path)
         finally:
             if tmp_path and os.path.exists(tmp_path):
-                try:
+                with suppress(OSError):
                     os.unlink(tmp_path)
-                except OSError:
-                    pass
 
     def _save_refs_locked(self, *, prune: bool = True) -> None:
         """Persist conversation references (caller must hold _refs_guard)."""
@@ -25,6 +25,7 @@ import os
 import re
 import time
 from collections import deque
+from contextlib import suppress
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Literal
 from urllib.parse import unquote, urlparse
@@ -221,17 +222,13 @@ class QQChannel(BaseChannel):
         """Stop bot and cleanup resources."""
         self._running = False
         if self._client:
-            try:
+            with suppress(Exception):
                 await self._client.close()
-            except Exception:
-                pass
             self._client = None
 
         if self._http:
-            try:
+            with suppress(Exception):
                 await self._http.close()
-            except Exception:
-                pass
             self._http = None
 
         logger.info("QQ bot stopped")
@@ -683,7 +680,5 @@ class QQChannel(BaseChannel):
         finally:
             # Cleanup partial file
             if tmp_path is not None:
-                try:
+                with suppress(Exception):
                     tmp_path.unlink(missing_ok=True)
-                except Exception:
-                    pass
@@ -6,6 +6,7 @@ import asyncio
 import re
 import time
 import unicodedata
+from contextlib import suppress
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Any, Literal
@@ -462,10 +463,8 @@ class TelegramChannel(BaseChannel):
         if not msg.metadata.get("_progress", False):
             self._stop_typing(msg.chat_id)
             if reply_to_message_id := msg.metadata.get("message_id"):
-                try:
+                with suppress(ValueError):
                     await self._remove_reaction(msg.chat_id, int(reply_to_message_id))
-                except ValueError:
-                    pass
 
         try:
             chat_id = int(msg.chat_id)
@@ -642,10 +641,8 @@ class TelegramChannel(BaseChannel):
             return
         self._stop_typing(chat_id)
         if reply_to_message_id := meta.get("message_id"):
-            try:
+            with suppress(ValueError):
                 await self._remove_reaction(chat_id, int(reply_to_message_id))
-            except ValueError:
-                pass
         thread_kwargs = {}
         if message_thread_id := meta.get("message_thread_id"):
             thread_kwargs["message_thread_id"] = message_thread_id
@@ -1162,11 +1159,10 @@
     async def _typing_loop(self, chat_id: str) -> None:
         """Repeatedly send 'typing' action until cancelled."""
         try:
+            with suppress(asyncio.CancelledError):
                 while self._app:
                     await self._app.bot.send_chat_action(chat_id=int(chat_id), action="typing")
                     await asyncio.sleep(4)
-        except asyncio.CancelledError:
-            pass
         except Exception as e:
             logger.debug("Typing indicator stopped for {}: {}", chat_id, e)
 
@@ -1265,10 +1261,8 @@ class TelegramChannel(BaseChannel):
         button_label = query.data or ""
         await query.answer()
         if query.message:
-            try:
+            with suppress(Exception):
                 await query.message.edit_reply_markup(reply_markup=None)
-            except Exception:
-                pass
         logger.debug("Inline button tap from {}: {}", sender_id, button_label)
         self._start_typing(str(chat_id))
         await self._handle_message(
@@ -19,6 +19,7 @@ import re
 import time
 import uuid
 from collections import OrderedDict
+from contextlib import suppress
 from pathlib import Path
 from typing import Any
 from urllib.parse import quote
@@ -211,7 +212,7 @@ class WeixinChannel(BaseChannel):
 
     def _save_state(self) -> None:
         state_file = self._get_state_dir() / "account.json"
-        try:
+        with suppress(Exception):
             data = {
                 "token": self._token,
                 "get_updates_buf": self._get_updates_buf,
@@ -220,8 +221,6 @@ class WeixinChannel(BaseChannel):
                 "base_url": self.config.base_url,
             }
             state_file.write_text(json.dumps(data, ensure_ascii=False))
-        except Exception:
-            pass
 
     # ------------------------------------------------------------------
     # HTTP helpers (matches api.ts buildHeaders / apiFetch)
@@ -576,10 +575,8 @@ class WeixinChannel(BaseChannel):
         # Process messages (WeixinMessage[] from types.ts)
         msgs: list[dict] = data.get("msgs", []) or []
         for msg in msgs:
-            try:
+            with suppress(Exception):
                 await self._process_message(msg)
-            except Exception:
-                pass
 
     # ------------------------------------------------------------------
     # Inbound message processing (matches inbound.ts + process-message.ts)
@@ -932,10 +929,8 @@ class WeixinChannel(BaseChannel):
                 await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_S)
                 if stop_event.is_set():
                     break
-                try:
+                with suppress(Exception):
                     await self._send_typing(user_id, typing_ticket, TYPING_STATUS_TYPING)
-                except Exception:
-                    pass
         finally:
             pass
 
@@ -962,16 +957,12 @@ class WeixinChannel(BaseChannel):
             return
 
         typing_ticket = ""
-        try:
+        with suppress(Exception):
             typing_ticket = await self._get_typing_ticket(msg.chat_id, ctx_token)
-        except Exception:
-            typing_ticket = ""
 
         if typing_ticket:
-            try:
+            with suppress(Exception):
                 await self._send_typing(msg.chat_id, typing_ticket, TYPING_STATUS_TYPING)
-            except Exception:
-                pass
 
         typing_keepalive_stop = asyncio.Event()
         typing_keepalive_task: asyncio.Task | None = None
@@ -1043,16 +1034,12 @@ class WeixinChannel(BaseChannel):
         if typing_keepalive_task:
             typing_keepalive_stop.set()
             typing_keepalive_task.cancel()
-            try:
+            with suppress(asyncio.CancelledError):
                 await typing_keepalive_task
-            except asyncio.CancelledError:
-                pass
 
         if typing_ticket and not is_progress:
-            try:
+            with suppress(Exception):
                 await self._send_typing(msg.chat_id, typing_ticket, TYPING_STATUS_CANCEL)
-            except Exception:
-                pass
 
     async def _start_typing(self, chat_id: str, context_token: str = "") -> None:
         """Start typing indicator immediately when a message is received."""
@@ -1076,10 +1063,8 @@ class WeixinChannel(BaseChannel):
                 await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_S)
                 if stop_event.is_set():
                     break
-                try:
+                with suppress(Exception):
                     await self._send_typing(chat_id, ticket, TYPING_STATUS_TYPING)
-                except Exception:
-                    pass
         finally:
             pass
 
@@ -1095,10 +1080,8 @@ class WeixinChannel(BaseChannel):
         if stop_event:
             stop_event.set()
         task.cancel()
-        try:
+        with suppress(asyncio.CancelledError):
             await task
-        except asyncio.CancelledError:
-            pass
         if not clear_remote:
             return
         entry = self._typing_tickets.get(chat_id)
@@ -1339,13 +1322,11 @@ def _encrypt_aes_ecb(data: bytes, aes_key_b64: str) -> bytes:
     pad_len = 16 - len(data) % 16
     padded = data + bytes([pad_len] * pad_len)
 
-    try:
+    with suppress(ImportError):
         from Crypto.Cipher import AES
 
         cipher = AES.new(key, AES.MODE_ECB)
         return cipher.encrypt(padded)
-    except ImportError:
-        pass
 
     try:
         from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
@@ -1371,13 +1352,11 @@ def _decrypt_aes_ecb(data: bytes, aes_key_b64: str) -> bytes:
 
     decrypted: bytes | None = None
 
-    try:
+    with suppress(ImportError):
        from Crypto.Cipher import AES
 
        cipher = AES.new(key, AES.MODE_ECB)
        decrypted = cipher.decrypt(data)
-    except ImportError:
-        pass
 
     if decrypted is None:
         try:
@@ -8,6 +8,7 @@ import os
 import secrets
 import shutil
 import subprocess
+from contextlib import suppress
 from collections import OrderedDict
 from pathlib import Path
 from typing import Any, Literal
@@ -47,10 +48,8 @@ def _load_or_create_bridge_token(path: Path) -> str:
     path.parent.mkdir(parents=True, exist_ok=True)
     token = secrets.token_urlsafe(32)
     path.write_text(token, encoding="utf-8")
-    try:
+    with suppress(OSError):
         path.chmod(0o600)
-    except OSError:
-        pass
     return token
 
 
@@ -5,7 +5,7 @@ import os
 import select
 import signal
 import sys
-from contextlib import nullcontext
+from contextlib import nullcontext, suppress
 from pathlib import Path
 from typing import Any
 
@@ -14,11 +14,9 @@ if sys.platform == "win32":
     if sys.stdout.encoding != "utf-8":
         os.environ["PYTHONIOENCODING"] = "utf-8"
         # Re-open stdout/stderr with UTF-8 encoding
-        try:
+        with suppress(Exception):
             sys.stdout.reconfigure(encoding="utf-8", errors="replace")
             sys.stderr.reconfigure(encoding="utf-8", errors="replace")
-        except Exception:
-            pass
 
 import typer
 from loguru import logger
@@ -83,35 +81,29 @@ def _flush_pending_tty_input() -> None:
     except Exception:
         return
 
-    try:
+    with suppress(Exception):
         import termios
 
         termios.tcflush(fd, termios.TCIFLUSH)
         return
-    except Exception:
-        pass
 
-    try:
+    with suppress(Exception):
         while True:
             ready, _, _ = select.select([fd], [], [], 0)
             if not ready:
                 break
             if not os.read(fd, 4096):
                 break
-    except Exception:
-        return
 
 
 def _restore_terminal() -> None:
     """Restore terminal to its original state (echo, line buffering, etc.)."""
     if _SAVED_TERM_ATTRS is None:
         return
-    try:
+    with suppress(Exception):
         import termios
 
         termios.tcsetattr(sys.stdin.fileno(), termios.TCSADRAIN, _SAVED_TERM_ATTRS)
-    except Exception:
-        pass
 
 
 def _init_prompt_session() -> None:
@@ -119,12 +111,10 @@ def _init_prompt_session() -> None:
     global _PROMPT_SESSION, _SAVED_TERM_ATTRS
 
     # Save terminal state so we can restore it on exit
-    try:
+    with suppress(Exception):
         import termios
 
         _SAVED_TERM_ATTRS = termios.tcgetattr(sys.stdin.fileno())
-    except Exception:
-        pass
 
     from nanobot.config.paths import get_cli_history_path
 
@@ -936,10 +926,8 @@ def _run_gateway(
                     config.gateway.host or "127.0.0.1", port
                 )
                 writer.close()
-                try:
+                with suppress(Exception):
                     await writer.wait_closed()
-                except Exception:
-                    pass
                 break
             except OSError:
                 await asyncio.sleep(0.1)
@@ -1520,10 +1508,8 @@ def _login_openai_codex() -> None:
    from oauth_cli_kit import get_token, login_oauth_interactive
 
    token = None
-   try:
+   with suppress(Exception):
        token = get_token()
-   except Exception:
-       pass
    if not (token and token.access):
        console.print("[cyan]Starting interactive OAuth login...[/cyan]\n")
        token = login_oauth_interactive(
@@ -5,6 +5,7 @@ from __future__ import annotations
 import asyncio
 import os
 import sys
+from contextlib import suppress
 
 from nanobot import __version__
 from nanobot.bus.events import OutboundMessage
@@ -50,16 +51,15 @@ async def cmd_status(ctx: CommandContext) -> OutboundMessage:
     loop = ctx.loop
     session = ctx.session or loop.sessions.get_or_create(ctx.key)
     ctx_est = 0
-    try:
+    with suppress(Exception):
         ctx_est, _ = loop.consolidator.estimate_session_prompt_tokens(session)
-    except Exception:
-        pass
     if ctx_est <= 0:
         ctx_est = loop._last_usage.get("prompt_tokens", 0)
 
     # Fetch web search provider usage (best-effort, never blocks the response)
     search_usage_text: str | None = None
-    try:
+    # Never let usage fetch break /status
+    with suppress(Exception):
         from nanobot.utils.searchusage import fetch_search_usage
         web_cfg = getattr(loop, "web_config", None)
         search_cfg = getattr(web_cfg, "search", None) if web_cfg else None
@@ -68,14 +68,10 @@ async def cmd_status(ctx: CommandContext) -> OutboundMessage:
             api_key = getattr(search_cfg, "api_key", "") or None
             usage = await fetch_search_usage(provider=provider, api_key=api_key)
             search_usage_text = usage.format()
-    except Exception:
-        pass  # Never let usage fetch break /status
     active_tasks = loop._active_tasks.get(ctx.key, [])
     task_count = sum(1 for t in active_tasks if not t.done())
-    try:
+    with suppress(Exception):
         task_count += loop.subagents.get_running_count_by_session(ctx.key)
-    except Exception:
-        pass
     return OutboundMessage(
         channel=ctx.msg.channel,
         chat_id=ctx.msg.chat_id,
@@ -119,11 +119,19 @@ class ProviderConfig(Base):
     extra_body: dict[str, Any] | None = None  # Extra fields merged into every request body
 
 
+class BedrockProviderConfig(ProviderConfig):
+    """AWS Bedrock Runtime provider configuration."""
+
+    region: str | None = None  # AWS region, falls back to AWS_REGION/AWS_DEFAULT_REGION/profile
+    profile: str | None = None  # Optional AWS shared config profile
+
+
 class ProvidersConfig(Base):
     """Configuration for LLM providers."""
 
     custom: ProviderConfig = Field(default_factory=ProviderConfig)  # Any OpenAI-compatible endpoint
     azure_openai: ProviderConfig = Field(default_factory=ProviderConfig)  # Azure OpenAI (model = deployment name)
+    bedrock: BedrockProviderConfig = Field(default_factory=BedrockProviderConfig)  # AWS Bedrock Converse
     anthropic: ProviderConfig = Field(default_factory=ProviderConfig)
     openai: ProviderConfig = Field(default_factory=ProviderConfig)
     openrouter: ProviderConfig = Field(default_factory=ProviderConfig)
@@ -287,14 +295,14 @@ class Config(BaseSettings):
         for spec in PROVIDERS:
             p = getattr(self.providers, spec.name, None)
             if p and model_prefix and normalized_prefix == spec.name:
-                if spec.is_oauth or spec.is_local or p.api_key:
+                if spec.is_oauth or spec.is_local or spec.is_direct or p.api_key:
                     return p, spec.name
 
         # Match by keyword (order follows PROVIDERS registry)
         for spec in PROVIDERS:
             p = getattr(self.providers, spec.name, None)
             if p and any(_kw_matches(kw) for kw in spec.keywords):
-                if spec.is_oauth or spec.is_local or p.api_key:
+                if spec.is_oauth or spec.is_local or spec.is_direct or p.api_key:
                     return p, spec.name
 
         # Fallback: configured local providers can route models without
@@ -15,6 +15,7 @@ __all__ = [
     "OpenAICodexProvider",
     "GitHubCopilotProvider",
     "AzureOpenAIProvider",
+    "BedrockProvider",
 ]
 
 _LAZY_IMPORTS = {
@@ -23,11 +24,13 @@ _LAZY_IMPORTS = {
     "OpenAICodexProvider": ".openai_codex_provider",
     "GitHubCopilotProvider": ".github_copilot_provider",
     "AzureOpenAIProvider": ".azure_openai_provider",
+    "BedrockProvider": ".bedrock_provider",
 }
 
 if TYPE_CHECKING:
     from nanobot.providers.anthropic_provider import AnthropicProvider
     from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
+    from nanobot.providers.bedrock_provider import BedrockProvider
     from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
     from nanobot.providers.openai_compat_provider import OpenAICompatProvider
     from nanobot.providers.openai_codex_provider import OpenAICodexProvider
@@ -4,6 +4,7 @@ import asyncio
 import json
 import re
 from abc import ABC, abstractmethod
+from contextlib import suppress
 from collections.abc import Awaitable, Callable
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
@@ -643,14 +644,12 @@ class LLMProvider(ABC):
                 return value
             return None
 
-        try:
+        with suppress(TypeError, ValueError):
             retry_ms = _header_value("retry-after-ms")
             if retry_ms is not None:
                 value = float(retry_ms) / 1000.0
                 if value > 0:
                     return value
-        except (TypeError, ValueError):
-            pass
 
         retry_after = _header_value("retry-after")
         if retry_after is None:
730
nanobot/providers/bedrock_provider.py
Normal file
730
nanobot/providers/bedrock_provider.py
Normal file
@ -0,0 +1,730 @@
|
|||||||
|
"""AWS Bedrock Converse provider."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from collections.abc import Awaitable, Callable, Iterator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import json_repair
|
||||||
|
|
||||||
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
_IMAGE_DATA_URL = re.compile(r"^data:image/([a-zA-Z0-9.+-]+);base64,(.*)$", re.DOTALL)
|
||||||
|
_TEXT_BLOCK_TYPES = {"text", "input_text", "output_text"}
|
||||||
|
_TEMPERATURE_UNSUPPORTED_MODEL_TOKENS = ("claude-opus-4-7",)
|
||||||
|
_ADAPTIVE_THINKING_ONLY_MODEL_TOKENS = ("claude-opus-4-7",)
|
||||||
|
|
||||||
|
|
||||||
|
def _deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
merged = dict(base)
|
||||||
|
for key, value in override.items():
|
||||||
|
if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
|
||||||
|
merged[key] = _deep_merge(merged[key], value)
|
||||||
|
else:
|
||||||
|
merged[key] = value
|
||||||
|
return merged
|
||||||
|
|
||||||
|
|
||||||
|
def _next_or_none(iterator: Iterator[dict[str, Any]]) -> dict[str, Any] | None:
|
||||||
|
try:
|
||||||
|
return next(iterator)
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockProvider(LLMProvider):
|
||||||
|
"""LLM provider using AWS Bedrock Runtime's Converse APIs."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: str | None = None,
|
||||||
|
api_base: str | None = None,
|
||||||
|
default_model: str = "bedrock/global.anthropic.claude-opus-4-7",
|
||||||
|
*,
|
||||||
|
region: str | None = None,
|
||||||
|
profile: str | None = None,
|
||||||
|
extra_body: dict[str, Any] | None = None,
|
||||||
|
client: Any | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(api_key, api_base)
|
||||||
|
self.default_model = default_model
|
||||||
|
self.region = region or os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION")
|
||||||
|
self.profile = profile
|
||||||
|
self._extra_body = extra_body or {}
|
||||||
|
self._client = client if client is not None else self._make_client()
|
||||||
|
|
||||||
|
def _make_client(self) -> Any:
|
||||||
|
if self.api_key:
|
||||||
|
os.environ["AWS_BEARER_TOKEN_BEDROCK"] = self.api_key
|
||||||
|
try:
|
||||||
|
import boto3
|
||||||
|
except ImportError as exc: # pragma: no cover - exercised only without boto3 installed
|
||||||
|
raise RuntimeError(
|
||||||
|
"AWS Bedrock provider requires boto3. Install it with `pip install boto3`."
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
session_kwargs: dict[str, Any] = {}
|
||||||
|
if self.profile:
|
||||||
|
session_kwargs["profile_name"] = self.profile
|
||||||
|
session = boto3.Session(**session_kwargs)
|
||||||
|
|
||||||
|
client_kwargs: dict[str, Any] = {}
|
||||||
|
if self.region:
|
||||||
|
client_kwargs["region_name"] = self.region
|
||||||
|
if self.api_base:
|
||||||
|
client_kwargs["endpoint_url"] = self.api_base
|
||||||
|
return session.client("bedrock-runtime", **client_kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _strip_prefix(model: str) -> str:
|
||||||
|
if model.startswith("bedrock/"):
|
||||||
|
return model[len("bedrock/"):]
|
||||||
|
return model
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _matches_model_token(model: str, tokens: tuple[str, ...]) -> bool:
|
||||||
|
model_lower = model.lower()
|
||||||
|
return any(token in model_lower for token in tokens)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _supports_temperature(cls, model: str) -> bool:
|
||||||
|
return not cls._matches_model_token(model, _TEMPERATURE_UNSUPPORTED_MODEL_TOKENS)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _uses_adaptive_thinking_only(cls, model: str) -> bool:
|
||||||
|
return cls._matches_model_token(model, _ADAPTIVE_THINKING_ONLY_MODEL_TOKENS)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _image_url_block(block: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
url = (block.get("image_url") or {}).get("url", "")
|
||||||
|
if not isinstance(url, str) or not url:
|
||||||
|
return None
|
||||||
|
match = _IMAGE_DATA_URL.match(url)
|
||||||
|
if not match:
|
||||||
|
return {"text": f"(image URL: {url})"}
|
||||||
|
fmt = match.group(1).lower()
|
||||||
|
if fmt == "jpg":
|
||||||
|
fmt = "jpeg"
|
||||||
|
try:
|
||||||
|
data = base64.b64decode(match.group(2), validate=False)
|
||||||
|
except Exception:
|
||||||
|
return {"text": "(invalid image data)"}
|
||||||
|
return {"image": {"format": fmt, "source": {"bytes": data}}}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _content_blocks(cls, content: Any, *, for_tool_result: bool = False) -> list[dict[str, Any]]:
|
||||||
|
if isinstance(content, str) or content is None:
|
||||||
|
return [{"text": content or "(empty)"}]
|
||||||
|
if not isinstance(content, list):
|
||||||
|
if for_tool_result and isinstance(content, dict):
|
||||||
|
return [{"json": content}]
|
||||||
|
return [{"text": str(content)}]
|
||||||
|
|
||||||
|
blocks: list[dict[str, Any]] = []
|
||||||
|
for item in content:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
blocks.append({"text": str(item)})
|
||||||
|
continue
|
||||||
|
|
||||||
|
item_type = item.get("type")
|
||||||
|
if item_type in _TEXT_BLOCK_TYPES or "text" in item:
|
||||||
|
text = item.get("text")
|
||||||
|
if text:
|
||||||
|
blocks.append({"text": str(text)})
|
||||||
|
continue
|
||||||
|
if item_type == "image_url":
|
||||||
|
converted = cls._image_url_block(item)
|
||||||
|
if converted:
|
||||||
|
blocks.append(converted)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Preserve already-Bedrock-shaped content where possible.
|
||||||
|
for key in ("text", "image", "document", "video", "json", "searchResult"):
|
||||||
|
if key in item:
|
||||||
|
blocks.append({key: item[key]})
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
blocks.append({"json": item} if for_tool_result else {"text": json.dumps(item)})
|
||||||
|
|
||||||
|
return blocks or [{"text": "(empty)"}]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _system_blocks(cls, content: Any) -> list[dict[str, Any]]:
|
||||||
|
return [
|
||||||
|
block for block in cls._content_blocks(content)
|
||||||
|
if "text" in block or "cachePoint" in block or "guardContent" in block
|
||||||
|
]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _tool_result_block(cls, msg: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"toolResult": {
|
||||||
|
"toolUseId": str(msg.get("tool_call_id") or ""),
|
||||||
|
"content": cls._content_blocks(msg.get("content"), for_tool_result=True),
|
||||||
|
"status": "success",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _tool_use_block(tool_call: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
function = tool_call.get("function")
|
||||||
|
if not isinstance(function, dict):
|
||||||
|
return None
|
||||||
|
args = function.get("arguments", {})
|
||||||
|
if isinstance(args, str):
|
||||||
|
try:
|
||||||
|
args = json_repair.loads(args) if args.strip() else {}
|
||||||
|
except Exception:
|
||||||
|
args = {}
|
||||||
|
if not isinstance(args, dict):
|
||||||
|
args = {}
|
||||||
|
return {
|
||||||
|
"toolUse": {
|
||||||
|
"toolUseId": str(tool_call.get("id") or ""),
|
||||||
|
"name": str(function.get("name") or ""),
|
||||||
|
"input": args,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _reasoning_block(block: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
if block.get("type") not in {"thinking", "reasoning", "redacted_thinking"}:
|
||||||
|
return None
|
||||||
|
text = block.get("thinking") or block.get("text")
|
||||||
|
signature = block.get("signature")
|
||||||
|
if text and signature:
|
||||||
|
return {
|
||||||
|
"reasoningContent": {
|
||||||
|
"reasoningText": {"text": str(text), "signature": str(signature)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
redacted = block.get("redactedContent")
|
||||||
|
if redacted is None and isinstance(block.get("redactedContentBase64"), str):
|
||||||
|
try:
|
||||||
|
redacted = base64.b64decode(block["redactedContentBase64"])
|
||||||
|
except Exception:
|
||||||
|
redacted = None
|
||||||
|
if redacted is not None:
|
||||||
|
return {"reasoningContent": {"redactedContent": redacted}}
|
||||||
|
return None
|
||||||
|
|
||||||
|
    @classmethod
    def _assistant_blocks(cls, msg: dict[str, Any]) -> list[dict[str, Any]]:
        blocks: list[dict[str, Any]] = []

        for thinking in msg.get("thinking_blocks") or []:
            if isinstance(thinking, dict):
                reasoning = cls._reasoning_block(thinking)
                if reasoning:
                    blocks.append(reasoning)

        content = msg.get("content")
        if isinstance(content, str) and content:
            blocks.append({"text": content})
        elif isinstance(content, list):
            blocks.extend(block for block in cls._content_blocks(content) if "text" in block)

        for tool_call in msg.get("tool_calls") or []:
            if isinstance(tool_call, dict):
                block = cls._tool_use_block(tool_call)
                if block:
                    blocks.append(block)

        return blocks or [{"text": ""}]

    @staticmethod
    def _has_tool_use(msg: dict[str, Any]) -> bool:
        content = msg.get("content")
        return isinstance(content, list) and any(
            isinstance(block, dict) and "toolUse" in block for block in content
        )

    @staticmethod
    def _merge_consecutive(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        merged: list[dict[str, Any]] = []
        for msg in messages:
            if merged and merged[-1].get("role") == msg.get("role"):
                prev = merged[-1].setdefault("content", [])
                cur = msg.get("content") or []
                if not isinstance(prev, list):
                    prev = [{"text": str(prev)}]
                    merged[-1]["content"] = prev
                if isinstance(cur, list):
                    prev.extend(cur)
                else:
                    prev.append({"text": str(cur)})
            else:
                merged.append(msg)

        last_popped: dict[str, Any] | None = None
        while merged and merged[-1].get("role") == "assistant":
            last_popped = merged.pop()
        if not merged and last_popped is not None and not BedrockProvider._has_tool_use(last_popped):
            merged.append({"role": "user", "content": last_popped.get("content") or [{"text": "(empty)"}]})
        if merged and merged[0].get("role") == "assistant" and not BedrockProvider._has_tool_use(merged[0]):
            merged.insert(0, {"role": "user", "content": [{"text": "(conversation continued)"}]})
        return merged

    def _convert_messages(
        self,
        messages: list[dict[str, Any]],
    ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
        system: list[dict[str, Any]] = []
        converted: list[dict[str, Any]] = []

        for msg in messages:
            role = msg.get("role")
            content = msg.get("content")
            if role == "system":
                system.extend(self._system_blocks(content))
                continue
            if role == "tool":
                block = self._tool_result_block(msg)
                if converted and converted[-1].get("role") == "user":
                    converted[-1].setdefault("content", []).append(block)
                else:
                    converted.append({"role": "user", "content": [block]})
                continue
            if role == "assistant":
                converted.append({"role": "assistant", "content": self._assistant_blocks(msg)})
                continue
            if role == "user":
                converted.append({"role": "user", "content": self._content_blocks(content)})

        return system, self._merge_consecutive(converted)

    @staticmethod
    def _convert_tools(tools: list[dict[str, Any]] | None) -> list[dict[str, Any]] | None:
        if not tools:
            return None
        result: list[dict[str, Any]] = []
        for tool in tools:
            func = tool.get("function") if isinstance(tool.get("function"), dict) else tool
            if not isinstance(func, dict):
                continue
            name = str(func.get("name") or "")
            if not name:
                continue
            spec: dict[str, Any] = {
                "name": name,
                "inputSchema": {
                    "json": func.get("parameters") or {"type": "object", "properties": {}}
                },
            }
            description = func.get("description")
            if description:
                spec["description"] = str(description)
            strict = func.get("strict", tool.get("strict"))
            if isinstance(strict, bool):
                spec["strict"] = strict
            result.append({"toolSpec": spec})
        return result or None

    @staticmethod
    def _convert_tool_choice(
        tool_choice: str | dict[str, Any] | None,
    ) -> dict[str, Any] | None:
        if tool_choice is None or tool_choice == "auto":
            return {"auto": {}}
        if tool_choice == "required":
            return {"any": {}}
        if tool_choice == "none":
            return None
        if isinstance(tool_choice, dict):
            name = tool_choice.get("function", {}).get("name")
            if name:
                return {"tool": {"name": str(name)}}
        return {"auto": {}}

    @staticmethod
    def _adaptive_thinking(reasoning_effort: str | None) -> dict[str, Any] | None:
        if not reasoning_effort:
            return None
        effort = reasoning_effort.lower()
        if effort == "none":
            return None
        thinking: dict[str, Any] = {"type": "adaptive"}
        if effort != "adaptive":
            thinking["effort"] = effort
        return thinking

    def _build_kwargs(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
        model: str | None,
        max_tokens: int,
        temperature: float,
        reasoning_effort: str | None,
        tool_choice: str | dict[str, Any] | None,
    ) -> dict[str, Any]:
        model_id = self._strip_prefix(model or self.default_model)
        system, bedrock_messages = self._convert_messages(self._sanitize_empty_content(messages))
        if not bedrock_messages:
            bedrock_messages = [{"role": "user", "content": [{"text": "(empty)"}]}]

        kwargs: dict[str, Any] = {
            "modelId": model_id,
            "messages": bedrock_messages,
            "inferenceConfig": {"maxTokens": max(1, max_tokens)},
        }
        if system:
            kwargs["system"] = system
        if self._supports_temperature(model_id):
            kwargs["inferenceConfig"]["temperature"] = temperature

        additional: dict[str, Any] = {}
        if self._uses_adaptive_thinking_only(model_id):
            thinking = self._adaptive_thinking(reasoning_effort)
            if thinking:
                additional["thinking"] = thinking
        if self._extra_body:
            additional = _deep_merge(additional, self._extra_body)
        if additional:
            kwargs["additionalModelRequestFields"] = additional

        bedrock_tools = self._convert_tools(tools)
        if bedrock_tools:
            tool_config: dict[str, Any] = {"tools": bedrock_tools}
            choice = self._convert_tool_choice(tool_choice)
            if choice:
                tool_config["toolChoice"] = choice
            kwargs["toolConfig"] = tool_config

        return kwargs

    @staticmethod
    def _finish_reason(stop_reason: str | None) -> str:
        return {
            "end_turn": "stop",
            "tool_use": "tool_calls",
            "max_tokens": "length",
        }.get(stop_reason or "", stop_reason or "stop")

    @staticmethod
    def _usage(usage: dict[str, Any] | None) -> dict[str, int]:
        if not usage:
            return {}
        prompt = int(usage.get("inputTokens") or 0)
        completion = int(usage.get("outputTokens") or 0)
        total = int(usage.get("totalTokens") or prompt + completion)
        result = {
            "prompt_tokens": prompt,
            "completion_tokens": completion,
            "total_tokens": total,
        }
        cache_read = int(usage.get("cacheReadInputTokens") or 0)
        cache_write = int(usage.get("cacheWriteInputTokens") or 0)
        if cache_read:
            result["cached_tokens"] = cache_read
            result["cache_read_input_tokens"] = cache_read
        if cache_write:
            result["cache_creation_input_tokens"] = cache_write
        return result

    @staticmethod
    def _parse_reasoning(block: dict[str, Any]) -> tuple[str | None, dict[str, Any] | None]:
        reasoning = block.get("reasoningContent")
        if not isinstance(reasoning, dict):
            return None, None
        text_obj = reasoning.get("reasoningText")
        if isinstance(text_obj, dict):
            text = text_obj.get("text")
            if isinstance(text, str):
                return text, {
                    "type": "thinking",
                    "thinking": text,
                    "signature": text_obj.get("signature", ""),
                }
        redacted = reasoning.get("redactedContent")
        if redacted is not None:
            if isinstance(redacted, (bytes, bytearray)):
                encoded = base64.b64encode(bytes(redacted)).decode("ascii")
                return None, {"type": "redacted_thinking", "redactedContentBase64": encoded}
            return None, {"type": "redacted_thinking", "redactedContent": redacted}
        return None, None

    @classmethod
    def _parse_response(cls, response: dict[str, Any]) -> LLMResponse:
        content_parts: list[str] = []
        reasoning_parts: list[str] = []
        tool_calls: list[ToolCallRequest] = []
        thinking_blocks: list[dict[str, Any]] = []
        message = (response.get("output") or {}).get("message") or {}

        for block in message.get("content") or []:
            if not isinstance(block, dict):
                continue
            if isinstance(block.get("text"), str):
                content_parts.append(block["text"])
            tool_use = block.get("toolUse")
            if isinstance(tool_use, dict):
                arguments = tool_use.get("input") if isinstance(tool_use.get("input"), dict) else {}
                tool_calls.append(ToolCallRequest(
                    id=str(tool_use.get("toolUseId") or ""),
                    name=str(tool_use.get("name") or ""),
                    arguments=arguments,
                ))
            reasoning_text, thinking = cls._parse_reasoning(block)
            if reasoning_text:
                reasoning_parts.append(reasoning_text)
            if thinking:
                thinking_blocks.append(thinking)

        return LLMResponse(
            content="".join(content_parts) or None,
            tool_calls=tool_calls,
            finish_reason=cls._finish_reason(response.get("stopReason")),
            usage=cls._usage(response.get("usage")),
            reasoning_content="".join(reasoning_parts) or None,
            thinking_blocks=thinking_blocks or None,
        )

    @classmethod
    def _parse_stream_event(
        cls,
        event: dict[str, Any],
        *,
        content_parts: list[str],
        reasoning_parts: list[str],
        thinking_blocks: list[dict[str, Any]],
        tool_buffers: dict[int, dict[str, Any]],
        state: dict[str, Any],
    ) -> str | None:
        if "contentBlockStart" in event:
            data = event["contentBlockStart"]
            idx = int(data.get("contentBlockIndex") or 0)
            start = data.get("start") or {}
            tool_use = start.get("toolUse")
            if isinstance(tool_use, dict):
                tool_buffers[idx] = {
                    "id": str(tool_use.get("toolUseId") or ""),
                    "name": str(tool_use.get("name") or ""),
                    "input": "",
                }
            return None

        if "contentBlockDelta" in event:
            data = event["contentBlockDelta"]
            idx = int(data.get("contentBlockIndex") or 0)
            delta = data.get("delta") or {}
            text = delta.get("text")
            if isinstance(text, str):
                content_parts.append(text)
                return text
            tool_delta = delta.get("toolUse")
            if isinstance(tool_delta, dict):
                buf = tool_buffers.setdefault(idx, {"id": "", "name": "", "input": ""})
                if isinstance(tool_delta.get("input"), str):
                    buf["input"] += tool_delta["input"]
            reasoning = delta.get("reasoningContent")
            if isinstance(reasoning, dict):
                buf = state.setdefault("reasoning_buffers", {}).setdefault(
                    idx, {"text": "", "signature": "", "redactedContent": None}
                )
                if isinstance(reasoning.get("text"), str):
                    buf["text"] += reasoning["text"]
                    reasoning_parts.append(reasoning["text"])
                if isinstance(reasoning.get("signature"), str):
                    buf["signature"] = reasoning["signature"]
                if reasoning.get("redactedContent") is not None:
                    buf["redactedContent"] = reasoning["redactedContent"]
            return None

        if "contentBlockStop" in event:
            idx = int((event["contentBlockStop"] or {}).get("contentBlockIndex") or 0)
            reasoning_buf = state.setdefault("reasoning_buffers", {}).pop(idx, None)
            if reasoning_buf:
                if reasoning_buf.get("text"):
                    thinking_blocks.append({
                        "type": "thinking",
                        "thinking": reasoning_buf["text"],
                        "signature": reasoning_buf.get("signature", ""),
                    })
                elif reasoning_buf.get("redactedContent") is not None:
                    redacted = reasoning_buf["redactedContent"]
                    if isinstance(redacted, (bytes, bytearray)):
                        redacted_block = {
                            "type": "redacted_thinking",
                            "redactedContentBase64": base64.b64encode(bytes(redacted)).decode("ascii"),
                        }
                    else:
                        redacted_block = {
                            "type": "redacted_thinking",
                            "redactedContent": redacted,
                        }
                    thinking_blocks.append(redacted_block)
            return None

        if "messageStop" in event:
            state["stop_reason"] = (event["messageStop"] or {}).get("stopReason")
            return None

        if "metadata" in event:
            metadata = event["metadata"] or {}
            if isinstance(metadata.get("usage"), dict):
                state["usage"] = metadata["usage"]
            return None

        return None

    @classmethod
    def _stream_result(
        cls,
        *,
        content_parts: list[str],
        reasoning_parts: list[str],
        thinking_blocks: list[dict[str, Any]],
        tool_buffers: dict[int, dict[str, Any]],
        state: dict[str, Any],
    ) -> LLMResponse:
        tool_calls: list[ToolCallRequest] = []
        for buf in tool_buffers.values():
            args: Any = {}
            if buf.get("input"):
                try:
                    args = json_repair.loads(buf["input"])
                except Exception:
                    args = {}
            tool_calls.append(ToolCallRequest(
                id=buf.get("id") or "",
                name=buf.get("name") or "",
                arguments=args if isinstance(args, dict) else {},
            ))
        return LLMResponse(
            content="".join(content_parts) or None,
            tool_calls=tool_calls,
            finish_reason=cls._finish_reason(state.get("stop_reason")),
            usage=cls._usage(state.get("usage")),
            reasoning_content="".join(reasoning_parts) or None,
            thinking_blocks=thinking_blocks or None,
        )

    @classmethod
    def _handle_error(cls, e: Exception) -> LLMResponse:
        response = getattr(e, "response", None)
        metadata = response.get("ResponseMetadata", {}) if isinstance(response, dict) else {}
        headers = metadata.get("HTTPHeaders") if isinstance(metadata, dict) else None
        error_obj = response.get("Error", {}) if isinstance(response, dict) else {}
        message = error_obj.get("Message") if isinstance(error_obj, dict) else None
        code = error_obj.get("Code") if isinstance(error_obj, dict) else None
        status_code = metadata.get("HTTPStatusCode") if isinstance(metadata, dict) else None
        body = message or str(e)
        retry_after = cls._extract_retry_after_from_headers(headers)
        if retry_after is None:
            retry_after = cls._extract_retry_after(body)

        error_name = e.__class__.__name__.lower()
        error_kind = None
        if "timeout" in error_name:
            error_kind = "timeout"
        elif "connection" in error_name or "endpoint" in error_name:
            error_kind = "connection"

        code_text = str(code or "").lower()
        should_retry = None
        if status_code is not None:
            should_retry = int(status_code) == 429 or int(status_code) >= 500
        if any(token in code_text for token in ("throttl", "timeout", "unavailable", "modelnotready")):
            should_retry = True

        return LLMResponse(
            content=f"Error: {str(body).strip()[:500]}",
            finish_reason="error",
            retry_after=retry_after,
            error_status_code=int(status_code) if status_code is not None else None,
            error_kind=error_kind,
            error_type=code_text or None,
            error_code=code_text or None,
            error_retry_after_s=retry_after,
            error_should_retry=should_retry,
        )

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, Any] | None = None,
    ) -> LLMResponse:
        try:
            kwargs = self._build_kwargs(
                messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice
            )
            response = await asyncio.to_thread(self._client.converse, **kwargs)
            return self._parse_response(response)
        except Exception as e:
            return self._handle_error(e)

    async def chat_stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        on_content_delta: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
        content_parts: list[str] = []
        reasoning_parts: list[str] = []
        thinking_blocks: list[dict[str, Any]] = []
        tool_buffers: dict[int, dict[str, Any]] = {}
        state: dict[str, Any] = {}

        try:
            kwargs = self._build_kwargs(
                messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice
            )
            response = await asyncio.to_thread(self._client.converse_stream, **kwargs)
            stream = iter(response.get("stream") or [])
            while True:
                event = await asyncio.wait_for(
                    asyncio.to_thread(_next_or_none, stream),
                    timeout=idle_timeout_s,
                )
                if event is None:
                    break
                delta = self._parse_stream_event(
                    event,
                    content_parts=content_parts,
                    reasoning_parts=reasoning_parts,
                    thinking_blocks=thinking_blocks,
                    tool_buffers=tool_buffers,
                    state=state,
                )
                if delta and on_content_delta:
                    await on_content_delta(delta)
            return self._stream_result(
                content_parts=content_parts,
                reasoning_parts=reasoning_parts,
                thinking_blocks=thinking_blocks,
                tool_buffers=tool_buffers,
                state=state,
            )
        except asyncio.TimeoutError:
            return LLMResponse(
                content=(
                    "Error calling LLM: stream stalled for more than "
                    f"{idle_timeout_s} seconds"
                ),
                finish_reason="error",
                error_kind="timeout",
            )
        except Exception as e:
            return self._handle_error(e)

    def get_default_model(self) -> str:
        return self.default_model
@@ -60,6 +60,17 @@ def make_provider(config: Config) -> LLMProvider:
             default_model=model,
             extra_headers=p.extra_headers if p else None,
         )
+    elif backend == "bedrock":
+        from nanobot.providers.bedrock_provider import BedrockProvider
+
+        provider = BedrockProvider(
+            api_key=p.api_key if p else None,
+            api_base=p.api_base if p else None,
+            default_model=model,
+            region=getattr(p, "region", None) if p else None,
+            profile=getattr(p, "profile", None) if p else None,
+            extra_body=p.extra_body if p else None,
+        )
     else:
         from nanobot.providers.openai_compat_provider import OpenAICompatProvider

@@ -85,12 +96,17 @@ def provider_signature(config: Config) -> tuple[object, ...]:
     """Return the config fields that affect the primary LLM provider."""
     model = config.agents.defaults.model
     defaults = config.agents.defaults
+    p = config.get_provider(model)
     return (
         model,
         defaults.provider,
         config.get_provider_name(model),
         config.get_api_key(model),
         config.get_api_base(model),
+        p.extra_headers if p else None,
+        p.extra_body if p else None,
+        getattr(p, "region", None) if p else None,
+        getattr(p, "profile", None) if p else None,
         defaults.max_tokens,
         defaults.temperature,
         defaults.reasoning_effort,
@@ -5,6 +5,7 @@ from __future__ import annotations
 import time
 import webbrowser
 from collections.abc import Callable
+from contextlib import suppress

 import httpx
 from oauth_cli_kit.models import OAuthToken
@@ -86,10 +87,8 @@ def login_github_copilot(
     printer(f"Open: {verify_url}")
     printer(f"Code: {user_code}")
     if verify_complete:
-        try:
-            webbrowser.open(verify_complete)
-        except Exception:
-            pass
+        with suppress(Exception):
+            webbrowser.open(verify_complete)

     deadline = time.time() + expires_in
     current_interval = interval
@@ -34,7 +34,7 @@ class ProviderSpec:
     display_name: str = ""  # shown in `nanobot status`

     # which provider implementation to use
-    # "openai_compat" | "anthropic" | "azure_openai" | "openai_codex" | "github_copilot"
+    # "openai_compat" | "anthropic" | "azure_openai" | "openai_codex" | "github_copilot" | "bedrock"
     backend: str = "openai_compat"

     # extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),)
@@ -105,6 +105,29 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
         backend="azure_openai",
         is_direct=True,
     ),
+    # === AWS Bedrock (native Converse API via bedrock-runtime) =============
+    ProviderSpec(
+        name="bedrock",
+        keywords=(
+            "bedrock",
+            "anthropic.claude",
+            "amazon.nova",
+            "meta.",
+            "mistral.",
+            "cohere.",
+            "qwen.",
+            "deepseek.",
+            "openai.gpt-oss",
+            "ai21.",
+            "moonshot.",
+            "writer.",
+            "zai.",
+        ),
+        env_key="AWS_BEARER_TOKEN_BEDROCK",
+        display_name="AWS Bedrock",
+        backend="bedrock",
+        is_direct=True,
+    ),
     # === Gateways (detected by api_key / api_base, not model name) =========
     # Gateways can route any model, so they win in fallback.
     # OpenRouter: global gateway, keys start with "sk-or-"
@@ -5,6 +5,7 @@ from __future__ import annotations
 import ipaddress
 import re
 import socket
+from contextlib import suppress
 from urllib.parse import urlparse

 _BLOCKED_NETWORKS = [
@@ -30,10 +31,8 @@ def configure_ssrf_whitelist(cidrs: list[str]) -> None:
     global _allowed_networks
     nets = []
     for cidr in cidrs:
-        try:
-            nets.append(ipaddress.ip_network(cidr, strict=False))
-        except ValueError:
-            pass
+        with suppress(ValueError):
+            nets.append(ipaddress.ip_network(cidr, strict=False))
     _allowed_networks = nets

@@ -3,6 +3,7 @@
 import json
 import os
 import shutil
+from contextlib import suppress
 from dataclasses import dataclass, field
 from datetime import datetime
 from pathlib import Path
@@ -118,7 +119,7 @@ class Session:
             if include_timestamps:
                 content = self._annotate_message_time(message, content)
             entry: dict[str, Any] = {"role": message["role"], "content": content}
-            for key in ("tool_calls", "tool_call_id", "name", "reasoning_content"):
+            for key in ("tool_calls", "tool_call_id", "name", "reasoning_content", "thinking_blocks"):
                 if key in message:
                     entry[key] = message[key]
             out.append(entry)
@@ -362,15 +363,11 @@ class SessionManager:
                 if data.get("_type") == "metadata":
                     metadata = data.get("metadata", {})
                     if data.get("created_at"):
-                        try:
-                            created_at = datetime.fromisoformat(data["created_at"])
-                        except (ValueError, TypeError):
-                            pass
+                        with suppress(ValueError, TypeError):
+                            created_at = datetime.fromisoformat(data["created_at"])
                     if data.get("updated_at"):
-                        try:
-                            updated_at = datetime.fromisoformat(data["updated_at"])
-                        except (ValueError, TypeError):
-                            pass
+                        with suppress(ValueError, TypeError):
+                            updated_at = datetime.fromisoformat(data["updated_at"])
                     last_consolidated = data.get("last_consolidated", 0)
                 else:
                     messages.append(data)
@@ -440,14 +437,12 @@ class SessionManager:
                 # On Windows, opening a directory with O_RDONLY raises
                 # PermissionError — skip the dir sync there (NTFS
                 # journals metadata synchronously).
-                try:
+                with suppress(PermissionError):
                     fd = os.open(str(path.parent), os.O_RDONLY)
                     try:
                         os.fsync(fd)
                     finally:
                         os.close(fd)
-                except PermissionError:
-                    pass  # Windows — directory fsync not supported
         except BaseException:
             tmp_path.unlink(missing_ok=True)
             raise
@@ -12,25 +12,23 @@ Example:

 import sys
 import zipfile
+from contextlib import suppress
 from pathlib import Path

 from quick_validate import validate_skill


 def _is_within(path: Path, root: Path) -> bool:
-    try:
+    with suppress(ValueError):
         path.relative_to(root)
         return True
-    except ValueError:
-        return False
+    return False


 def _cleanup_partial_archive(skill_filename: Path) -> None:
-    try:
-        if skill_filename.exists():
+    if skill_filename.exists():
+        with suppress(OSError):
             skill_filename.unlink()
-    except OSError:
-        pass


 def package_skill(skill_path, output_dir=None):
86
nanobot/skills/update-setup/SKILL.md
Normal file
@@ -0,0 +1,86 @@
---
name: update-setup
description: One-time setup wizard for the nanobot upgrade skill. Triggers: setup update, configure update, 切设置更新, 初始化更新.
---

# Update Setup

Generate a personalized upgrade skill for this workspace.

## Step 1: Check Existing

Use `read_file` to check if `skills/update/SKILL.md` already exists in the workspace.

If it exists, use `ask_user` to ask: "An upgrade skill already exists. Reconfigure?" with options ["yes", "no"]. If no, stop here.

## Step 2: Current Version

Use `exec` to run `nanobot --version`. Tell the user the current version.

## Step 3: Ask Questions

Use `ask_user` for the questions below, one question per call.

**Question 1 — Install method:**

```
question: "How did you install nanobot?"
options: ["uv", "pipx", "pip", "source (git clone)"]
```

If the user selected `source (git clone)`, ask for the local checkout path:
`question: "Where is your nanobot source checkout? Enter an absolute path or a path relative to this workspace:"`.

**Question 2 — Optional dependencies:**

```
question: "Which optional dependencies do you need? List names separated by spaces, or reply 'none'. Available: api, wecom, weixin, msteams, matrix, discord, langsmith, pdf"
```

Parse the reply. If the user says "none" or similar, set extras to empty. Otherwise collect the valid names.

**Question 3 — Proxy:**

```
question: "Do you need an HTTP proxy to reach PyPI or GitHub?"
options: ["no", "yes"]
```

If yes, ask one more time for the proxy URL: `question: "Enter proxy URL (e.g. http://127.0.0.1:7890):"`.

## Step 4: Generate Skill

Build the extras string. If the user selected dependencies, format as `[dep1,dep2,...]`. Otherwise omit the brackets entirely.

Determine the upgrade command from the install method:

| Method | Command |
|--------|---------|
| uv | `uv tool install "nanobot-ai[EXTRAS]" --force` |
| pipx | `pipx install --force "nanobot-ai[EXTRAS]"` |
| pip | `python -m pip install --upgrade "nanobot-ai[EXTRAS]"` |
| source | `cd <SOURCE_CHECKOUT> && git pull && python -m pip install -e ".[EXTRAS]"` |

For source installs, include extras in the editable install command when selected. Quote the source checkout path if it contains spaces.

Build the skill content. If proxy is configured, add `export http_proxy=URL` and `export https_proxy=URL` lines before the upgrade command.

Use `write_file` to write `skills/update/SKILL.md` with this content:

```
---
name: update
|
||||||
|
description: "Upgrade nanobot to the latest version. Triggers: upgrade nanobot, update nanobot, 升级nanobot, 更新nanobot."
|
||||||
|
---
|
||||||
|
|
||||||
|
# Update Nanobot
|
||||||
|
|
||||||
|
1. (If proxy configured) Set proxy: `export http_proxy=URL && export https_proxy=URL`
|
||||||
|
2. Use `exec` to run the upgrade command: <UPGRADE_COMMAND>
|
||||||
|
3. Use `exec` to verify: `nanobot --version`
|
||||||
|
4. Tell the user the new version. Say: "Run `/restart` to restart nanobot and apply the update. If `/restart` is unavailable in this channel, restart the nanobot process manually."
|
||||||
|
```
|
||||||
|
|
||||||
|
## Step 5: Confirm
|
||||||
|
|
||||||
|
Tell the user: "Upgrade skill created. Say 'upgrade nanobot' when you want to update."
|
||||||
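The Step 4 table above maps each install method to an upgrade command. That mapping can be sketched in Python as follows — `build_upgrade_command` is a hypothetical helper for illustration only; it is not part of nanobot or the generated skill:

```python
def build_upgrade_command(method: str, extras: list[str], checkout: str = "") -> str:
    """Sketch of Step 4: turn (install method, extras) into an upgrade command."""
    # Extras render as [dep1,dep2,...]; omit the brackets entirely when empty.
    pkg = "nanobot-ai" + (f"[{','.join(extras)}]" if extras else "")
    if method == "source":
        # Editable install from the user's checkout, with extras when selected.
        dep = "." + (f"[{','.join(extras)}]" if extras else "")
        return f'cd {checkout} && git pull && python -m pip install -e "{dep}"'
    commands = {
        "uv": f'uv tool install "{pkg}" --force',
        "pipx": f'pipx install --force "{pkg}"',
        "pip": f'python -m pip install --upgrade "{pkg}"',
    }
    return commands[method]

print(build_upgrade_command("uv", ["matrix", "pdf"]))
# uv tool install "nanobot-ai[matrix,pdf]" --force
```

A proxy, when configured, would simply be prepended as `export http_proxy=URL && export https_proxy=URL && ` before whichever command this returns.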
@@ -6,6 +6,7 @@ import re
 import shutil
 import time
 import uuid
+from contextlib import suppress
 from datetime import datetime
 from pathlib import Path
 from typing import Any
@@ -416,14 +417,10 @@ def estimate_prompt_tokens_chain(
     """Estimate prompt tokens via provider counter first, then tiktoken fallback."""
     provider_counter = getattr(provider, "estimate_prompt_tokens", None)
     if callable(provider_counter):
-        try:
+        with suppress(Exception):
             tokens, source = provider_counter(messages, tools, model)
             if isinstance(tokens, (int, float)) and tokens > 0:
                 return int(tokens), str(source or "provider_counter")
-        except Exception:
-            pass
-
-    estimated = estimate_prompt_tokens(messages, tools)
     if estimated > 0:
         return int(estimated), "tiktoken"
     return 0, "none"
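The hunks above (and the others in this commit) swap `try`/`except`/`pass` for `contextlib.suppress`, which is behaviorally equivalent for the swallow-and-fall-through pattern. A standalone sketch — `parse_or_zero` is an illustrative function, not from this codebase:

```python
from contextlib import suppress

def parse_or_zero(raw: str) -> float:
    value = 0.0
    # Equivalent to: try: value = float(raw) / except ValueError: pass.
    # The suppressed exception aborts the with-block body and falls through.
    with suppress(ValueError):
        value = float(raw)
    return value

print(parse_or_zero("1.5"))   # 1.5
print(parse_or_zero("oops"))  # 0.0
```

One behavioral detail worth noting: like a bare `except`, `suppress(Exception)` hides all ordinary errors from the block, so it suits best-effort paths such as the provider token counter above, where any failure should fall back to the tiktoken estimate.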
@@ -5,6 +5,7 @@ from __future__ import annotations
 import json
 import os
 import time
+from contextlib import suppress
 from dataclasses import dataclass, field
 from typing import Any

@@ -26,11 +27,9 @@ def format_restart_completed_message(started_at_raw: str) -> str:
     """Build restart completion text and include elapsed time when available."""
     elapsed_suffix = ""
     if started_at_raw:
-        try:
+        with suppress(ValueError):
             elapsed_s = max(0.0, time.time() - float(started_at_raw))
             elapsed_suffix = f" in {elapsed_s:.1f}s"
-        except ValueError:
-            pass
     return f"Restart completed{elapsed_suffix}."
@@ -61,6 +61,7 @@ dependencies = [
     "openpyxl>=3.1.0,<4.0.0",
     "python-pptx>=1.0.0,<2.0.0",
     "filelock>=3.25.2",
+    "boto3>=1.43.0",
 ]

 [project.optional-dependencies]
@@ -87,6 +87,42 @@ def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
     assert "Return exactly: OK" in user_content


+def test_runtime_context_includes_sender_id_when_provided(tmp_path) -> None:
+    """Sender ID should be included in runtime context when provided."""
+    workspace = _make_workspace(tmp_path)
+    builder = ContextBuilder(workspace)
+
+    messages = builder.build_messages(
+        history=[],
+        current_message="Return exactly: OK",
+        channel="cli",
+        chat_id="direct",
+        sender_id="user-12345",
+    )
+
+    user_content = messages[-1]["content"]
+    assert isinstance(user_content, str)
+    assert "Sender ID: user-12345" in user_content
+
+
+def test_runtime_context_excludes_sender_id_when_not_provided(tmp_path) -> None:
+    """Sender ID should not be present in runtime context when not provided."""
+    workspace = _make_workspace(tmp_path)
+    builder = ContextBuilder(workspace)
+
+    messages = builder.build_messages(
+        history=[],
+        current_message="Return exactly: OK",
+        channel="cli",
+        chat_id="direct",
+        sender_id=None,
+    )
+
+    user_content = messages[-1]["content"]
+    assert isinstance(user_content, str)
+    assert "Sender ID:" not in user_content
+
+
 def test_unprocessed_history_injected_into_system_prompt(tmp_path) -> None:
     """Entries in history.jsonl not yet consumed by Dream appear with timestamps."""
     workspace = _make_workspace(tmp_path)
@@ -727,6 +727,7 @@ def test_set_tool_context_passes_thread_session_key_to_spawn(tmp_path: Path) ->
     loop._set_tool_context(
         "slack",
         "C123",
+        message_id="msg-123",
         metadata={"slack": {"thread_ts": "1700.42", "channel_type": "channel"}},
         session_key="slack:C123:1700.42",
     )
@@ -734,6 +735,7 @@ def test_set_tool_context_passes_thread_session_key_to_spawn(tmp_path: Path) ->
     spawn_tool = loop.tools.get("spawn")
     assert spawn_tool is not None
     assert spawn_tool._session_key.get() == "slack:C123:1700.42"
+    assert spawn_tool._origin_message_id.get() == "msg-123"


 @pytest.mark.asyncio
@@ -766,14 +768,17 @@ async def test_system_subagent_followup_uses_thread_session_and_slack_metadata(t
             chat_id="slack:C123",
             content="subagent result",
             session_key_override="slack:C123:1700.42",
-            metadata={"subagent_task_id": "sub-1"},
+            metadata={"subagent_task_id": "sub-1", "origin_message_id": "msg-123"},
         )
     )

     assert outbound is not None
     assert outbound.channel == "slack"
     assert outbound.chat_id == "C123"
-    assert outbound.metadata == {"slack": {"thread_ts": "1700.42"}}
+    assert outbound.metadata == {
+        "slack": {"thread_ts": "1700.42"},
+        "origin_message_id": "msg-123",
+    }
     assert "thread question" in seen["initial_messages"][1]["content"]

     loop.sessions.invalidate("slack:C123:1700.42")
@@ -180,6 +180,7 @@ def test_get_history_preserves_reasoning_content():
         "role": "assistant",
         "content": "done",
         "reasoning_content": "hidden chain of thought",
+        "thinking_blocks": [{"type": "thinking", "thinking": "hidden chain of thought", "signature": "sig"}],
     })

     history = session.get_history(max_messages=500)
@@ -190,6 +191,11 @@ def test_get_history_preserves_reasoning_content():
             "role": "assistant",
             "content": "done",
             "reasoning_content": "hidden chain of thought",
+            "thinking_blocks": [{
+                "type": "thinking",
+                "thinking": "hidden chain of thought",
+                "signature": "sig",
+            }],
         },
     ]
@@ -13,6 +13,8 @@ from nanobot.bus.queue import MessageBus
 from nanobot.channels.base import BaseChannel
 from nanobot.channels.manager import ChannelManager
 from nanobot.config.schema import ChannelsConfig
+from nanobot.providers.transcription import GroqTranscriptionProvider as _GroqProvider
+from nanobot.providers.transcription import OpenAITranscriptionProvider as _OpenAIProvider
 from nanobot.utils.restart import RestartNotice

 # ---------------------------------------------------------------------------
@@ -338,9 +340,6 @@ async def test_base_channel_passes_language_to_groq_transcription_provider():
 # Transcription provider HTTP tests
 # ---------------------------------------------------------------------------

-from nanobot.providers.transcription import GroqTranscriptionProvider as _GroqProvider
-from nanobot.providers.transcription import OpenAITranscriptionProvider as _OpenAIProvider
-

 class _StubResponse:
     def raise_for_status(self):
@@ -791,6 +790,50 @@ async def test_send_with_retry_skips_send_when_streamed():
     assert send_delta_called is False


+def test_outbound_duplicate_suppression_is_scoped_to_origin_message() -> None:
+    fake_config = SimpleNamespace(
+        channels=ChannelsConfig(send_max_retries=3),
+        providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
+    )
+
+    mgr = ChannelManager.__new__(ChannelManager)
+    mgr.config = fake_config
+    mgr.bus = MessageBus()
+    mgr.channels = {}
+    mgr._dispatch_task = None
+    mgr._origin_reply_fingerprints = {}
+
+    first = OutboundMessage(
+        channel="feishu",
+        chat_id="chat123",
+        content="Done",
+        metadata={"message_id": "msg-1"},
+    )
+    duplicate = OutboundMessage(
+        channel="feishu",
+        chat_id="chat123",
+        content=" Done ",
+        metadata={"origin_message_id": "msg-1"},
+    )
+    separate_turn = OutboundMessage(
+        channel="feishu",
+        chat_id="chat123",
+        content="Done",
+        metadata={"message_id": "msg-2"},
+    )
+    new_origin_content = OutboundMessage(
+        channel="feishu",
+        chat_id="chat123",
+        content="Done with extra details",
+        metadata={"origin_message_id": "msg-1"},
+    )
+
+    assert mgr._should_suppress_outbound(first) is False
+    assert mgr._should_suppress_outbound(duplicate) is True
+    assert mgr._should_suppress_outbound(separate_turn) is False
+    assert mgr._should_suppress_outbound(new_origin_content) is False
+
+
 @pytest.mark.asyncio
 async def test_send_with_retry_propagates_cancelled_error():
     """_send_with_retry should re-raise CancelledError for graceful shutdown."""
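The new manager test pins down the dedupe contract from the commit message: a reply is suppressed only when it repeats the same normalized content for the same origin message; a different origin or different content always goes through. A rough sketch of that contract — `ReplyDeduper` is a guess at the behavior for illustration; the actual `ChannelManager._should_suppress_outbound` implementation is not shown in this diff:

```python
class ReplyDeduper:
    """Suppress duplicate replies, scoped per (channel, chat, origin message)."""

    def __init__(self) -> None:
        self._origin_reply_fingerprints: dict[tuple[str, str, str], str] = {}

    def should_suppress(self, channel: str, chat_id: str, content: str, metadata: dict) -> bool:
        origin = metadata.get("origin_message_id") or metadata.get("message_id")
        if not origin:
            return False  # no origin to scope to: never suppress
        key = (channel, chat_id, origin)
        fingerprint = " ".join(content.split())  # whitespace-insensitive compare
        if self._origin_reply_fingerprints.get(key) == fingerprint:
            return True  # same origin already got this exact reply
        self._origin_reply_fingerprints[key] = fingerprint
        return False
```

Under this sketch, " Done " for origin `msg-1` is suppressed after "Done" was sent for `msg-1`, while "Done" for origin `msg-2` and new content for `msg-1` both pass — the four cases the test asserts.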
@@ -2,7 +2,6 @@ import asyncio
 import zipfile
 from io import BytesIO
 from types import SimpleNamespace
-from unittest.mock import AsyncMock

 import httpx
 import pytest
@@ -17,19 +16,27 @@ except ImportError:
 if not DINGTALK_AVAILABLE:
     pytest.skip("DingTalk dependencies not installed (dingtalk-stream)", allow_module_level=True)

-from nanobot.bus.queue import MessageBus
 import nanobot.channels.dingtalk as dingtalk_module
-from nanobot.channels.dingtalk import DingTalkChannel, NanobotDingTalkHandler
-from nanobot.channels.dingtalk import DingTalkConfig
+from nanobot.bus.queue import MessageBus
+from nanobot.channels.dingtalk import DingTalkChannel, DingTalkConfig, NanobotDingTalkHandler


 class _FakeResponse:
-    def __init__(self, status_code: int = 200, json_body: dict | None = None) -> None:
+    def __init__(
+        self,
+        status_code: int = 200,
+        json_body: dict | None = None,
+        *,
+        content: bytes = b"",
+        headers: dict[str, str] | None = None,
+        url: str = "https://example.com/file",
+    ) -> None:
         self.status_code = status_code
         self._json_body = json_body or {}
-        self.text = "{}"
-        self.content = b""
-        self.headers = {"content-type": "application/json"}
+        self.text = content.decode("utf-8", errors="replace") if content else "{}"
+        self.content = content
+        self.headers = headers or {"content-type": "application/json"}
+        self.url = httpx.URL(url)

     def json(self) -> dict:
         return self._json_body
@@ -46,11 +53,13 @@ class _FakeHttp:
         return _FakeResponse()

     async def post(self, url: str, json=None, headers=None, **kwargs):
-        self.calls.append({"method": "POST", "url": url, "json": json, "headers": headers})
+        self.calls.append(
+            {"method": "POST", "url": url, "json": json, "headers": headers, "kwargs": kwargs}
+        )
         return self._next_response()

     async def get(self, url: str, **kwargs):
-        self.calls.append({"method": "GET", "url": url})
+        self.calls.append({"method": "GET", "url": url, "kwargs": kwargs})
         return self._next_response()


@@ -242,6 +251,245 @@ async def test_download_dingtalk_file(tmp_path, monkeypatch) -> None:
     assert channel._http.calls[1]["method"] == "GET"


+@pytest.mark.asyncio
+async def test_read_media_bytes_rejects_private_http_target_before_fetch() -> None:
+    """Remote media fetches must not reach loopback/private addresses."""
+    channel = DingTalkChannel(
+        DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]),
+        MessageBus(),
+    )
+    channel._http = _FakeHttp(
+        responses=[
+            _FakeResponse(
+                200,
+                content=b"internal secret",
+                headers={"content-type": "text/plain"},
+                url="http://127.0.0.1/admin.txt",
+            )
+        ]
+    )
+
+    data, filename, content_type = await channel._read_media_bytes("http://127.0.0.1/admin.txt")
+
+    assert (data, filename, content_type) == (None, None, None)
+    assert channel._http.calls == []
+
+
+@pytest.mark.asyncio
+async def test_read_media_bytes_rejects_private_redirect_result() -> None:
+    """A public-looking media URL must not be accepted after redirecting private."""
+    channel = DingTalkChannel(
+        DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]),
+        MessageBus(),
+    )
+    channel._http = _FakeHttp(
+        responses=[
+            _FakeResponse(
+                200,
+                content=b"metadata bytes",
+                headers={"content-type": "text/plain"},
+                url="http://127.0.0.1/metadata",
+            )
+        ]
+    )
+
+    data, filename, content_type = await channel._read_media_bytes("https://example.com/safe.txt")
+
+    assert (data, filename, content_type) == (None, None, None)
+    assert len(channel._http.calls) == 1
+
+
+@pytest.mark.asyncio
+async def test_read_media_bytes_rejects_oversized_remote_response(monkeypatch) -> None:
+    """DingTalk media downloads should enforce a byte cap before upload."""
+    monkeypatch.setattr(dingtalk_module, "DINGTALK_MAX_REMOTE_MEDIA_BYTES", 8, raising=False)
+    channel = DingTalkChannel(
+        DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]),
+        MessageBus(),
+    )
+    channel._http = _FakeHttp(
+        responses=[
+            _FakeResponse(
+                200,
+                content=b"123456789",
+                headers={"content-type": "text/plain"},
+                url="https://example.com/large.txt",
+            )
+        ]
+    )
+
+    data, filename, content_type = await channel._read_media_bytes("https://example.com/large.txt")
+
+    assert (data, filename, content_type) == (None, None, None)
+
+
+@pytest.mark.asyncio
+async def test_read_media_bytes_does_not_follow_remote_redirects_by_default() -> None:
+    """Redirects are refused by default instead of followed into internal networks."""
+    channel = DingTalkChannel(
+        DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]),
+        MessageBus(),
+    )
+    channel._http = _FakeHttp(
+        responses=[
+            _FakeResponse(
+                302,
+                headers={"location": "http://127.0.0.1/metadata"},
+                url="https://example.com/redirect.txt",
+            )
+        ]
+    )
+
+    data, filename, content_type = await channel._read_media_bytes("https://example.com/redirect.txt")
+
+    assert (data, filename, content_type) == (None, None, None)
+    assert channel._http.calls[0]["kwargs"]["follow_redirects"] is False
+
+
+@pytest.mark.asyncio
+async def test_read_media_bytes_follows_safe_redirect_when_explicitly_enabled() -> None:
+    """Operators can opt in to public redirects without enabling private redirects."""
+    channel = DingTalkChannel(
+        DingTalkConfig(
+            client_id="app",
+            client_secret="secret",
+            allow_from=["*"],
+            allow_remote_media_redirects=True,
+        ),
+        MessageBus(),
+    )
+    channel._http = _FakeHttp(
+        responses=[
+            _FakeResponse(
+                302,
+                headers={"location": "https://example.com/final.txt"},
+                url="https://example.com/redirect.txt",
+            ),
+            _FakeResponse(
+                200,
+                content=b"redirected media",
+                headers={"content-type": "text/plain"},
+                url="https://example.com/final.txt",
+            ),
+        ]
+    )
+
+    data, filename, content_type = await channel._read_media_bytes("https://example.com/redirect.txt")
+
+    assert (data, filename, content_type) == (b"redirected media", "redirect.txt", "text/plain")
+    assert [call["url"] for call in channel._http.calls] == [
+        "https://example.com/redirect.txt",
+        "https://example.com/final.txt",
+    ]
+    assert all(call["kwargs"]["follow_redirects"] is False for call in channel._http.calls)
+
+
+@pytest.mark.asyncio
+async def test_read_media_bytes_blocks_cross_host_redirect_without_allowlist() -> None:
+    """Redirect opt-in should not allow arbitrary cross-host redirects by default."""
+    channel = DingTalkChannel(
+        DingTalkConfig(
+            client_id="app",
+            client_secret="secret",
+            allow_from=["*"],
+            allow_remote_media_redirects=True,
+        ),
+        MessageBus(),
+    )
+    channel._http = _FakeHttp(
+        responses=[
+            _FakeResponse(
+                302,
+                headers={"location": "https://example.org/final.txt"},
+                url="https://example.com/redirect.txt",
+            ),
+            _FakeResponse(
+                200,
+                content=b"cross-host media",
+                headers={"content-type": "text/plain"},
+                url="https://example.org/final.txt",
+            ),
+        ]
+    )
+
+    data, filename, content_type = await channel._read_media_bytes("https://example.com/redirect.txt")
+
+    assert (data, filename, content_type) == (None, None, None)
+    assert [call["url"] for call in channel._http.calls] == ["https://example.com/redirect.txt"]
+
+
+@pytest.mark.asyncio
+async def test_read_media_bytes_allows_cross_host_redirect_when_allowlisted() -> None:
+    """Operators can explicitly allow a known CDN/download host for redirects."""
+    channel = DingTalkChannel(
+        DingTalkConfig(
+            client_id="app",
+            client_secret="secret",
+            allow_from=["*"],
+            allow_remote_media_redirects=True,
+            remote_media_redirect_allowed_hosts=["example.org"],
+        ),
+        MessageBus(),
+    )
+    channel._http = _FakeHttp(
+        responses=[
+            _FakeResponse(
+                302,
+                headers={"location": "https://example.org/final.txt"},
+                url="https://example.com/redirect.txt",
+            ),
+            _FakeResponse(
+                200,
+                content=b"cross-host media",
+                headers={"content-type": "text/plain"},
+                url="https://example.org/final.txt",
+            ),
+        ]
+    )
+
+    data, filename, content_type = await channel._read_media_bytes("https://example.com/redirect.txt")
+
+    assert (data, filename, content_type) == (b"cross-host media", "redirect.txt", "text/plain")
+    assert [call["url"] for call in channel._http.calls] == [
+        "https://example.com/redirect.txt",
+        "https://example.org/final.txt",
+    ]
+
+
+@pytest.mark.asyncio
+async def test_read_media_bytes_blocks_private_redirect_even_when_redirects_enabled() -> None:
+    """Redirect opt-in must still validate each hop before fetching it."""
+    channel = DingTalkChannel(
+        DingTalkConfig(
+            client_id="app",
+            client_secret="secret",
+            allow_from=["*"],
+            allow_remote_media_redirects=True,
+        ),
+        MessageBus(),
+    )
+    channel._http = _FakeHttp(
+        responses=[
+            _FakeResponse(
+                302,
+                headers={"location": "http://127.0.0.1/metadata"},
+                url="https://example.com/redirect.txt",
+            ),
+            _FakeResponse(
+                200,
+                content=b"internal secret",
+                headers={"content-type": "text/plain"},
+                url="http://127.0.0.1/metadata",
+            ),
+        ]
+    )
+
+    data, filename, content_type = await channel._read_media_bytes("https://example.com/redirect.txt")
+
+    assert (data, filename, content_type) == (None, None, None)
+    assert [call["url"] for call in channel._http.calls] == ["https://example.com/redirect.txt"]
+
+
 def test_normalize_upload_payload_zips_html_attachment() -> None:
     channel = DingTalkChannel(
         DingTalkConfig(client_id="app", client_secret="secret", allow_from=["*"]),
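The SSRF tests above expect the channel to reject loopback and private targets before any bytes are fetched. A minimal sketch of such a pre-fetch check using the standard `ipaddress` module — `is_private_http_target` is illustrative only; the real channel code is not shown in this diff and may additionally resolve DNS and re-validate every redirect hop:

```python
import ipaddress
from urllib.parse import urlparse

def is_private_http_target(url: str) -> bool:
    """Return True when the URL's host is a literal private/loopback address."""
    host = urlparse(url).hostname or ""
    try:
        addr = ipaddress.ip_address(host)
    except ValueError:
        # Not a literal IP; production code would resolve the name and
        # re-run this check on every resolved address.
        return False
    return addr.is_private or addr.is_loopback or addr.is_link_local

print(is_private_http_target("http://127.0.0.1/admin.txt"))   # True
print(is_private_http_target("https://example.com/safe.txt")) # False
```

Checking the final response URL as well (as the redirect tests do) matters because a public hostname can 302 into a private address, so each hop needs the same validation.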
@ -380,6 +380,62 @@ async def test_on_message_skips_typing_for_self_message() -> None:
|
|||||||
assert client.typing_calls == []
|
assert client.typing_calls == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_message_skips_pre_startup_event() -> None:
|
||||||
|
channel = MatrixChannel(_make_config(), MessageBus())
|
||||||
|
client = _FakeAsyncClient("", "", "", None)
|
||||||
|
channel.client = client
|
||||||
|
channel._started_at_ms = 1_000_000
|
||||||
|
|
||||||
|
handled: list[str] = []
|
||||||
|
|
||||||
|
async def _fake_handle_message(**kwargs) -> None:
|
||||||
|
handled.append(kwargs["sender_id"])
|
||||||
|
|
||||||
|
channel._handle_message = _fake_handle_message # type: ignore[method-assign]
|
||||||
|
|
||||||
|
room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room")
|
||||||
|
old_event = SimpleNamespace(
|
||||||
|
sender="@alice:matrix.org", body="old", source={}, server_timestamp=999_999
|
||||||
|
)
|
||||||
|
fresh_event = SimpleNamespace(
|
||||||
|
sender="@alice:matrix.org", body="fresh", source={}, server_timestamp=1_000_001
|
||||||
|
)
|
||||||
|
|
||||||
|
await channel._on_message(room, old_event)
|
||||||
|
await channel._on_message(room, fresh_event)
|
||||||
|
|
||||||
|
assert handled == ["@alice:matrix.org"]
|
||||||
|
assert client.typing_calls == [
|
||||||
|
("!room:matrix.org", True, TYPING_NOTICE_TIMEOUT_MS),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_media_message_skips_pre_startup_event() -> None:
|
||||||
|
channel = MatrixChannel(_make_config(), MessageBus())
|
||||||
|
client = _FakeAsyncClient("", "", "", None)
|
||||||
|
channel.client = client
|
||||||
|
channel._started_at_ms = 1_000_000
|
||||||
|
|
||||||
|
handled: list[str] = []
|
||||||
|
|
||||||
|
async def _fake_handle_message(**kwargs) -> None:
|
||||||
|
handled.append(kwargs["sender_id"])
|
||||||
|
|
||||||
|
channel._handle_message = _fake_handle_message # type: ignore[method-assign]
|
||||||
|
|
||||||
|
room = SimpleNamespace(room_id="!room:matrix.org", display_name="Test room")
|
||||||
|
old_event = SimpleNamespace(
|
||||||
|
sender="@alice:matrix.org", body="old", source={}, server_timestamp=999_999
|
||||||
|
)
|
||||||
|
|
||||||
|
await channel._on_media_message(room, old_event)
|
||||||
|
|
||||||
|
assert handled == []
|
||||||
|
assert client.typing_calls == []
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_on_message_skips_typing_for_denied_sender() -> None:
|
async def test_on_message_skips_typing_for_denied_sender() -> None:
|
||||||
channel = MatrixChannel(_make_config(allow_from=["@bob:matrix.org"]), MessageBus())
|
channel = MatrixChannel(_make_config(allow_from=["@bob:matrix.org"]), MessageBus())
|
@@ -1190,6 +1246,44 @@ async def test_send_progress_keeps_typing_keepalive_running() -> None:
     await channel.stop()
 
 
+@pytest.mark.asyncio
+async def test_send_empty_content_does_not_call_room_send() -> None:
+    """Progress messages with empty content must not produce an empty body: '' event."""
+    channel = MatrixChannel(_make_config(), MessageBus())
+    client = _FakeAsyncClient("", "", "", None)
+    channel.client = client
+
+    await channel.send(
+        OutboundMessage(
+            channel="matrix",
+            chat_id="!room:matrix.org",
+            content="",
+            metadata={"_progress": True},
+        )
+    )
+
+    assert client.room_send_calls == []
+
+
+@pytest.mark.asyncio
+async def test_send_whitespace_only_content_does_not_call_room_send() -> None:
+    """Progress messages with whitespace-only content must not produce an empty message."""
+    channel = MatrixChannel(_make_config(), MessageBus())
+    client = _FakeAsyncClient("", "", "", None)
+    channel.client = client
+
+    await channel.send(
+        OutboundMessage(
+            channel="matrix",
+            chat_id="!room:matrix.org",
+            content=" \n\n ",
+            metadata={"_progress": True},
+        )
+    )
+
+    assert client.room_send_calls == []
+
+
 @pytest.mark.asyncio
 async def test_send_clears_typing_when_send_fails() -> None:
     channel = MatrixChannel(_make_config(), MessageBus())
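Both new tests pin the same invariant: a progress update whose body would render as empty text must never reach `room_send`. One plausible guard, written as a standalone helper for illustration (the name is not nanobot's):

```python
def should_emit_progress(content: str) -> bool:
    """A progress update is only worth a Matrix event if it has visible text.

    Whitespace-only bodies (spaces, newlines) would render as an empty
    message bubble, so they are dropped before room_send is ever called.
    """
    return bool(content.strip())
```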
 tests/providers/test_bedrock_provider.py (new file, +254)
@@ -0,0 +1,254 @@
+"""Tests for the native AWS Bedrock Converse provider."""
+
+from __future__ import annotations
+
+from typing import Any
+
+import pytest
+
+from nanobot.config.schema import Config, ProvidersConfig
+from nanobot.providers.bedrock_provider import BedrockProvider
+from nanobot.providers.registry import find_by_name
+
+
+class FakeClient:
+    def __init__(
+        self,
+        *,
+        response: dict[str, Any] | None = None,
+        stream_events: list[dict[str, Any]] | None = None,
+        error: Exception | None = None,
+    ) -> None:
+        self.response = response
+        self.stream_events = stream_events or []
+        self.error = error
+        self.calls: list[dict[str, Any]] = []
+        self.stream_calls: list[dict[str, Any]] = []
+
+    def converse(self, **kwargs):
+        self.calls.append(kwargs)
+        if self.error:
+            raise self.error
+        return self.response or {}
+
+    def converse_stream(self, **kwargs):
+        self.stream_calls.append(kwargs)
+        if self.error:
+            raise self.error
+        return {"stream": iter(self.stream_events)}
+
+
+class FakeBedrockError(Exception):
+    def __init__(self) -> None:
+        super().__init__("too many requests")
+        self.response = {
+            "ResponseMetadata": {
+                "HTTPStatusCode": 429,
+                "HTTPHeaders": {"retry-after": "3"},
+            },
+            "Error": {
+                "Code": "ThrottlingException",
+                "Message": "Rate exceeded",
+            },
+        }
+
+
+def test_bedrock_provider_is_registered_and_matches_without_api_key() -> None:
+    spec = find_by_name("bedrock")
+    assert spec is not None
+    assert spec.backend == "bedrock"
+    assert spec.is_direct is True
+    assert hasattr(ProvidersConfig(), "bedrock")
+
+    cfg = Config.model_validate({
+        "agents": {"defaults": {"model": "bedrock/global.anthropic.claude-opus-4-7"}},
+        "providers": {"bedrock": {"region": "us-east-1"}},
+    })
+
+    assert cfg.get_provider_name() == "bedrock"
+    assert cfg.get_provider().region == "us-east-1"
+
+
+def test_opus_47_uses_adaptive_thinking_and_omits_temperature() -> None:
+    provider = BedrockProvider(region="us-east-1", client=FakeClient())
+
+    kwargs = provider._build_kwargs(
+        messages=[{"role": "user", "content": "hi"}],
+        tools=None,
+        model="bedrock/global.anthropic.claude-opus-4-7",
+        max_tokens=2048,
+        temperature=0.1,
+        reasoning_effort="medium",
+        tool_choice=None,
+    )
+
+    assert kwargs["modelId"] == "global.anthropic.claude-opus-4-7"
+    assert kwargs["inferenceConfig"] == {"maxTokens": 2048}
+    assert kwargs["additionalModelRequestFields"]["thinking"] == {
+        "type": "adaptive",
+        "effort": "medium",
+    }
+
+
+def test_generic_bedrock_model_keeps_temperature_and_skips_anthropic_thinking() -> None:
+    provider = BedrockProvider(region="us-east-1", client=FakeClient())
+
+    kwargs = provider._build_kwargs(
+        messages=[{"role": "user", "content": "hi"}],
+        tools=None,
+        model="bedrock/amazon.nova-lite-v1:0",
+        max_tokens=1024,
+        temperature=0.3,
+        reasoning_effort="medium",
+        tool_choice=None,
+    )
+
+    assert kwargs["modelId"] == "amazon.nova-lite-v1:0"
+    assert kwargs["inferenceConfig"] == {"maxTokens": 1024, "temperature": 0.3}
+    assert "additionalModelRequestFields" not in kwargs
+
+
+def test_build_kwargs_converts_messages_tools_and_tool_results() -> None:
+    provider = BedrockProvider(region="us-east-1", client=FakeClient())
+    tools = [{
+        "type": "function",
+        "function": {
+            "name": "read_file",
+            "description": "Read a file",
+            "parameters": {"type": "object", "properties": {"path": {"type": "string"}}},
+        },
+    }]
+    messages = [
+        {"role": "system", "content": "You are helpful."},
+        {"role": "user", "content": "read x"},
+        {
+            "role": "assistant",
+            "content": "",
+            "tool_calls": [{
+                "id": "toolu_1",
+                "type": "function",
+                "function": {"name": "read_file", "arguments": '{"path": "x"}'},
+            }],
+        },
+        {"role": "tool", "tool_call_id": "toolu_1", "name": "read_file", "content": "ok"},
+        {"role": "user", "content": "continue"},
+    ]
+
+    kwargs = provider._build_kwargs(
+        messages=messages,
+        tools=tools,
+        model="bedrock/anthropic.claude-opus-4-7",
+        max_tokens=1024,
+        temperature=0.7,
+        reasoning_effort=None,
+        tool_choice="required",
+    )
+
+    assert kwargs["system"] == [{"text": "You are helpful."}]
+    assert kwargs["messages"][1]["content"] == [{
+        "toolUse": {
+            "toolUseId": "toolu_1",
+            "name": "read_file",
+            "input": {"path": "x"},
+        }
+    }]
+    assert kwargs["messages"][2]["role"] == "user"
+    assert kwargs["messages"][2]["content"][0]["toolResult"]["toolUseId"] == "toolu_1"
+    assert kwargs["messages"][2]["content"][1] == {"text": "continue"}
+    tool_spec = kwargs["toolConfig"]["tools"][0]["toolSpec"]
+    assert tool_spec["name"] == "read_file"
+    assert kwargs["toolConfig"]["toolChoice"] == {"any": {}}
+
+
+def test_parse_response_maps_text_tools_reasoning_usage_and_stop_reason() -> None:
+    response = {
+        "output": {
+            "message": {
+                "role": "assistant",
+                "content": [
+                    {"reasoningContent": {"reasoningText": {"text": "think", "signature": "sig"}}},
+                    {"text": "hello"},
+                    {"toolUse": {"toolUseId": "t1", "name": "search", "input": {"q": "x"}}},
+                ],
+            }
+        },
+        "stopReason": "tool_use",
+        "usage": {
+            "inputTokens": 10,
+            "outputTokens": 5,
+            "totalTokens": 15,
+            "cacheReadInputTokens": 2,
+        },
+    }
+
+    result = BedrockProvider._parse_response(response)
+
+    assert result.content == "hello"
+    assert result.finish_reason == "tool_calls"
+    assert result.usage["prompt_tokens"] == 10
+    assert result.usage["cached_tokens"] == 2
+    assert result.reasoning_content == "think"
+    assert result.thinking_blocks == [{"type": "thinking", "thinking": "think", "signature": "sig"}]
+    assert result.tool_calls[0].id == "t1"
+    assert result.tool_calls[0].arguments == {"q": "x"}
+
+
+@pytest.mark.asyncio
+async def test_chat_stream_aggregates_text_tool_use_and_usage() -> None:
+    client = FakeClient(stream_events=[
+        {"contentBlockDelta": {"contentBlockIndex": 0, "delta": {"text": "he"}}},
+        {"contentBlockDelta": {"contentBlockIndex": 0, "delta": {"text": "llo"}}},
+        {
+            "contentBlockStart": {
+                "contentBlockIndex": 1,
+                "start": {"toolUse": {"toolUseId": "t1", "name": "search"}},
+            }
+        },
+        {
+            "contentBlockDelta": {
+                "contentBlockIndex": 1,
+                "delta": {"toolUse": {"input": '{"q":'}},
+            }
+        },
+        {
+            "contentBlockDelta": {
+                "contentBlockIndex": 1,
+                "delta": {"toolUse": {"input": '"x"}'}},
+            }
+        },
+        {"contentBlockStop": {"contentBlockIndex": 1}},
+        {"messageStop": {"stopReason": "tool_use"}},
+        {"metadata": {"usage": {"inputTokens": 3, "outputTokens": 4, "totalTokens": 7}}},
+    ])
+    provider = BedrockProvider(region="us-east-1", client=client)
+    deltas: list[str] = []
+
+    result = await provider.chat_stream(
+        messages=[{"role": "user", "content": "hi"}],
+        model="bedrock/anthropic.claude-opus-4-7",
+        on_content_delta=lambda text: _append_delta(deltas, text),
+    )
+
+    assert deltas == ["he", "llo"]
+    assert result.content == "hello"
+    assert result.finish_reason == "tool_calls"
+    assert result.usage == {"prompt_tokens": 3, "completion_tokens": 4, "total_tokens": 7}
+    assert result.tool_calls[0].name == "search"
+    assert result.tool_calls[0].arguments == {"q": "x"}
+
+
+async def _append_delta(deltas: list[str], text: str) -> None:
+    deltas.append(text)
+
+
+@pytest.mark.asyncio
+async def test_chat_error_maps_retry_metadata() -> None:
+    provider = BedrockProvider(region="us-east-1", client=FakeClient(error=FakeBedrockError()))
+
+    result = await provider.chat(messages=[{"role": "user", "content": "hi"}])
+
+    assert result.finish_reason == "error"
+    assert result.error_status_code == 429
+    assert result.error_should_retry is True
+    assert result.error_code == "throttlingexception"
+    assert result.retry_after == 3
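The throttling test above shows the botocore-style error shape the provider has to unpack: status code and headers under `ResponseMetadata`, error code under `Error`. A hedged sketch of that mapping, assuming the "429/5xx are transient" policy the assertions imply (`map_converse_error` is a hypothetical name, not nanobot's function):

```python
def map_converse_error(exc: Exception) -> dict:
    """Turn a botocore-style ClientError into flat retry metadata."""
    resp = getattr(exc, "response", None) or {}
    meta = resp.get("ResponseMetadata", {})
    status = meta.get("HTTPStatusCode")
    # Header names are case-insensitive; normalize before lookup.
    headers = {k.lower(): v for k, v in meta.get("HTTPHeaders", {}).items()}
    code = (resp.get("Error", {}).get("Code") or "").lower()
    retry_after = headers.get("retry-after")
    return {
        "error_status_code": status,
        "error_code": code,
        # Assumption: throttling and server errors are worth retrying.
        "error_should_retry": status in (429, 500, 502, 503),
        "retry_after": int(retry_after) if retry_after is not None else None,
    }
```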
@@ -13,6 +13,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None:
     monkeypatch.delitem(sys.modules, "nanobot.providers.openai_codex_provider", raising=False)
     monkeypatch.delitem(sys.modules, "nanobot.providers.github_copilot_provider", raising=False)
     monkeypatch.delitem(sys.modules, "nanobot.providers.azure_openai_provider", raising=False)
+    monkeypatch.delitem(sys.modules, "nanobot.providers.bedrock_provider", raising=False)
 
     providers = importlib.import_module("nanobot.providers")
 
@@ -21,6 +22,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None:
     assert "nanobot.providers.openai_codex_provider" not in sys.modules
     assert "nanobot.providers.github_copilot_provider" not in sys.modules
     assert "nanobot.providers.azure_openai_provider" not in sys.modules
+    assert "nanobot.providers.bedrock_provider" not in sys.modules
     assert providers.__all__ == [
         "LLMProvider",
         "LLMResponse",
@@ -29,6 +31,7 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None:
         "OpenAICodexProvider",
         "GitHubCopilotProvider",
         "AzureOpenAIProvider",
+        "BedrockProvider",
     ]
 
 
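The laziness test checks that importing `nanobot.providers` pulls in none of the provider modules, while each class is still reachable by name. The usual mechanism is a PEP 562 module-level `__getattr__` over a name-to-module table. A self-contained sketch, using stdlib modules as stand-ins for provider modules:

```python
import importlib

# Attribute name -> module that defines it, imported only on first access.
# In a real package __init__.py this dict would map provider class names
# to their submodules; math/json here are just runnable stand-ins.
_LAZY_EXPORTS = {"sqrt": "math", "dumps": "json"}


def lazy_get(name: str):
    """Resolve a lazily exported attribute, importing its module on demand."""
    try:
        module_name = _LAZY_EXPORTS[name]
    except KeyError:
        raise AttributeError(name) from None
    return getattr(importlib.import_module(module_name), name)
```

In a package `__init__.py` this function would be named `__getattr__`, which Python calls automatically for attributes not found in the module dict.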
@@ -10,8 +10,8 @@ import pytest
 import pytest_asyncio
 
 from nanobot.api.server import (
-    _sse_chunk,
     _SSE_DONE,
+    _sse_chunk,
     create_app,
 )
 
@@ -111,13 +111,13 @@ async def test_stream_true_returns_sse(aiohttp_client) -> None:
     assert resp.content_type == "text/event-stream"
 
     body = await resp.text()
-    lines = [l for l in body.split("\n") if l.startswith("data: ")]
+    lines = [line for line in body.split("\n") if line.startswith("data: ")]
 
     # Should have: 2 token chunks + 1 finish chunk + [DONE]
-    data_lines = [l[len("data: "):] for l in lines]
+    data_lines = [line[len("data: "):] for line in lines]
     assert data_lines[-1] == "[DONE]"
 
-    chunks = [json.loads(l) for l in data_lines[:-1]]
+    chunks = [json.loads(line) for line in data_lines[:-1]]
     assert chunks[0]["choices"][0]["delta"]["content"] == "Hello"
     assert chunks[1]["choices"][0]["delta"]["content"] == " world"
     # Last chunk before [DONE] should have finish_reason=stop
@@ -181,8 +181,12 @@ async def test_stream_sse_chunk_ids_are_consistent(aiohttp_client) -> None:
         json={"messages": [{"role": "user", "content": "go"}], "stream": True},
     )
     body = await resp.text()
-    data_lines = [l[len("data: "):] for l in body.split("\n") if l.startswith("data: ") and l != "data: [DONE]"]
-    chunks = [json.loads(l) for l in data_lines]
+    data_lines = [
+        line[len("data: "):]
+        for line in body.split("\n")
+        if line.startswith("data: ") and line != "data: [DONE]"
+    ]
+    chunks = [json.loads(line) for line in data_lines]
 
     chunk_ids = {c["id"] for c in chunks}
     assert len(chunk_ids) == 1, f"Expected single chunk id, got {chunk_ids}"
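These renames all touch one recurring parsing idiom: collect the `data: ` lines from the SSE body, require the `[DONE]` sentinel last, and JSON-decode the rest. Factored into a helper for illustration (the tests themselves keep it inline):

```python
import json


def parse_sse_chunks(body: str) -> list:
    """Split an SSE response body into JSON chunks.

    Requires the stream to terminate with the OpenAI-style "data: [DONE]"
    sentinel; everything before it is decoded as JSON.
    """
    data_lines = [
        line[len("data: "):]
        for line in body.split("\n")
        if line.startswith("data: ")
    ]
    if not data_lines or data_lines[-1] != "[DONE]":
        raise ValueError("stream did not terminate with data: [DONE]")
    return [json.loads(line) for line in data_lines[:-1]]
```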
@@ -218,6 +222,85 @@ async def test_stream_passes_on_stream_callbacks(aiohttp_client) -> None:
     assert captured_kwargs.get("on_stream_end") is not None
 
 
+@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
+@pytest.mark.asyncio
+async def test_stream_segment_end_does_not_close_sse(aiohttp_client) -> None:
+    """Intermediate stream-end callbacks should not terminate the HTTP stream."""
+    agent = MagicMock()
+
+    async def fake_process_direct(*, on_stream=None, on_stream_end=None, **kwargs):
+        assert on_stream is not None
+        assert on_stream_end is not None
+        await on_stream("planning")
+        await on_stream_end(resuming=True)
+        await asyncio.sleep(0.05)
+        await on_stream(" final")
+        await on_stream_end(resuming=False)
+        return "planning final"
+
+    agent.process_direct = fake_process_direct
+    agent._connect_mcp = AsyncMock()
+    agent.close_mcp = AsyncMock()
+
+    app = create_app(agent, model_name="m")
+    client = await aiohttp_client(app)
+
+    resp = await client.post(
+        "/v1/chat/completions",
+        json={"messages": [{"role": "user", "content": "use a tool"}], "stream": True},
+    )
+
+    assert resp.status == 200
+    body = await resp.text()
+    data_lines = [
+        line[len("data: "):] for line in body.split("\n") if line.startswith("data: ")
+    ]
+    assert data_lines[-1] == "[DONE]"
+
+    chunks = [json.loads(line) for line in data_lines[:-1]]
+    deltas = [c["choices"][0]["delta"].get("content", "") for c in chunks]
+    assert "planning" in deltas
+    assert " final" in deltas
+    assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
+
+
+@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
+@pytest.mark.asyncio
+async def test_stream_uses_final_response_when_no_deltas(aiohttp_client) -> None:
+    """stream=true should not return an empty stream when the agent returns content."""
+    agent = MagicMock()
+
+    async def fake_process_direct(*, on_stream=None, on_stream_end=None, **kwargs):
+        assert on_stream is not None
+        assert on_stream_end is not None
+        await on_stream_end(resuming=False)
+        return "plain final"
+
+    agent.process_direct = fake_process_direct
+    agent._connect_mcp = AsyncMock()
+    agent.close_mcp = AsyncMock()
+
+    app = create_app(agent, model_name="m")
+    client = await aiohttp_client(app)
+
+    resp = await client.post(
+        "/v1/chat/completions",
+        json={"messages": [{"role": "user", "content": "hi"}], "stream": True},
+    )
+
+    assert resp.status == 200
+    body = await resp.text()
+    data_lines = [
+        line[len("data: "):] for line in body.split("\n") if line.startswith("data: ")
+    ]
+    chunks = [json.loads(line) for line in data_lines[:-1]]
+    deltas = [c["choices"][0]["delta"].get("content", "") for c in chunks]
+
+    assert "plain final" in deltas
+    assert data_lines[-1] == "[DONE]"
+    assert chunks[-1]["choices"][0]["finish_reason"] == "stop"
+
+
 @pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
 @pytest.mark.asyncio
 async def test_stream_with_session_id(aiohttp_client) -> None:
@@ -27,12 +27,16 @@ class TestEditReadTracking:
     """edit_file should warn when file hasn't been read first."""
 
     @pytest.fixture()
-    def read_tool(self, tmp_path):
-        return ReadFileTool(workspace=tmp_path)
+    def file_states(self):
+        return file_state.FileStates()
 
     @pytest.fixture()
-    def edit_tool(self, tmp_path):
-        return EditFileTool(workspace=tmp_path)
+    def read_tool(self, tmp_path, file_states):
+        return ReadFileTool(workspace=tmp_path, file_states=file_states)
+
+    @pytest.fixture()
+    def edit_tool(self, tmp_path, file_states):
+        return EditFileTool(workspace=tmp_path, file_states=file_states)
 
     @pytest.mark.asyncio
     async def test_edit_warns_if_file_not_read_first(self, edit_tool, tmp_path):
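The fixture change threads one `FileStates` instance through both tools, which is what lets `edit_file` see whether `read_file` already ran. A stripped-down sketch of that shared-state idea (nanobot's real `FileStates` tracks more than a set of paths, e.g. content freshness; `edit_warning` is a hypothetical helper):

```python
from __future__ import annotations


class FileStates:
    """Minimal per-session read tracker (illustrative, not nanobot's class)."""

    def __init__(self) -> None:
        self._read: set[str] = set()

    def mark_read(self, path: str) -> None:
        self._read.add(path)

    def was_read(self, path: str) -> bool:
        return path in self._read


def edit_warning(states: FileStates, path: str) -> str | None:
    """edit_file consults the same FileStates the read tool writes to."""
    if not states.was_read(path):
        return f"Warning: {path} was not read before editing"
    return None
```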
@@ -2,7 +2,7 @@
 
 import asyncio
 from pathlib import Path
-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock
 
 import pytest
 
@@ -166,35 +166,3 @@ class TestMessageToolTurnTracking:
         tool._sent_in_turn = True
         tool.start_turn()
         assert not tool._sent_in_turn
-
-
-class TestSystemReplySuppression:
-    @pytest.mark.asyncio
-    async def test_subagent_system_reply_suppressed_when_duplicate(self, tmp_path: Path) -> None:
-        with patch("nanobot.agent.loop.ContextBuilder"), \
-            patch("nanobot.agent.loop.SessionManager") as MockSessionManager, \
-            patch("nanobot.agent.loop.SubagentManager"):
-            session = MagicMock()
-            session.get_history.return_value = []
-            MockSessionManager.return_value.get_or_create.return_value = session
-
-            bus = MessageBus()
-            provider = MagicMock()
-            provider.get_default_model.return_value = "test-model"
-            loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model", memory_window=10)
-
-            loop._remember_visible_reply("feishu:chat123", "Done")
-            loop._run_agent_loop = AsyncMock(return_value=("Done", [], []))
-            loop._save_turn = MagicMock()
-            loop.sessions.save = MagicMock()
-
-            msg = InboundMessage(
-                channel="system",
-                sender_id="subagent",
-                chat_id="feishu:chat123",
-                content="background result",
-                metadata={"source": "subagent"},
-            )
-
-            result = await loop._process_message(msg)
-            assert result is None
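The removed test asserted chat-scoped suppression: any subagent reply matching the chat's last visible reply was dropped. The commit instead scopes dedupe to the origin message, so an identical reply for a different request still goes through. A hypothetical sketch of the difference (keys and names are illustrative, not `AgentLoop`'s internals):

```python
class ReplyDedupe:
    """Message-scoped reply dedupe sketch.

    A subagent reply is suppressed only when it duplicates the reply already
    shown for its *origin* message, not any earlier reply in the chat.
    """

    def __init__(self) -> None:
        self._seen: dict[tuple[str, str], str] = {}

    def remember(self, chat_id: str, origin_id: str, text: str) -> None:
        self._seen[(chat_id, origin_id)] = text

    def is_duplicate(self, chat_id: str, origin_id: str, text: str) -> bool:
        return self._seen.get((chat_id, origin_id)) == text
```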
@@ -97,6 +97,67 @@ class TestReadDedup:
         assert isinstance(second, list)
 
 
+# ---------------------------------------------------------------------------
+# Cross-session isolation (issue #3571)
+# ---------------------------------------------------------------------------
+# Each session must keep its own read cache. When session A reads a file,
+# session B reading the same file must still receive the full content, not
+# the "[File unchanged since last read]" dedup stub. The stub is only valid
+# within the session that first cached the read.
+
+class TestReadDedupSessionIsolation:
+
+    @pytest.mark.asyncio
+    async def test_separate_sessions_do_not_share_dedup_state(self, tmp_path):
+        f = tmp_path / "shared.txt"
+        f.write_text("\n".join(f"line {i}" for i in range(10)), encoding="utf-8")
+
+        session_a_tool = ReadFileTool(workspace=tmp_path)
+        session_b_tool = ReadFileTool(workspace=tmp_path)
+
+        first = await session_a_tool.execute(path=str(f))
+        assert "line 0" in first
+
+        # Session B has never read this file before — it must see the full
+        # content, not the dedup stub from session A.
+        second = await session_b_tool.execute(path=str(f))
+        assert "unchanged" not in second.lower(), (
+            "Session B should not inherit session A's read-dedup state. "
+            f"Got: {second!r}"
+        )
+        assert "line 0" in second
+
+    @pytest.mark.asyncio
+    async def test_shared_loop_tool_uses_bound_session_state(self, tmp_path):
+        f = tmp_path / "shared.txt"
+        f.write_text("\n".join(f"line {i}" for i in range(10)), encoding="utf-8")
+
+        # AgentLoop registers one shared ReadFileTool instance. The session
+        # boundary is the task-local FileStates binding, not the tool object.
+        shared_tool = ReadFileTool(workspace=tmp_path)
+        session_a = file_state.FileStates()
+        session_b = file_state.FileStates()
+
+        token = file_state.bind_file_states(session_a)
+        try:
+            first = await shared_tool.execute(path=str(f))
+            repeat = await shared_tool.execute(path=str(f))
+        finally:
+            file_state.reset_file_states(token)
+
+        assert "line 0" in first
+        assert "unchanged" in repeat.lower()
+
+        token = file_state.bind_file_states(session_b)
+        try:
+            second_session_read = await shared_tool.execute(path=str(f))
+        finally:
+            file_state.reset_file_states(token)
+
+        assert "unchanged" not in second_session_read.lower()
+        assert "line 0" in second_session_read
+
+
 # ---------------------------------------------------------------------------
 # PDF support
 # ---------------------------------------------------------------------------
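The second test depends on `bind_file_states` returning a token that `reset_file_states` later consumes, which is the standard `ContextVar.set`/`reset` protocol: the binding is local to the current task, so one shared tool object can serve many sessions. A self-contained sketch of that binding, simplified to a plain dict payload (function names mirror the test, the implementation is assumed):

```python
from contextvars import ContextVar

# Task-local slot holding the current session's state, if any.
_file_states: ContextVar = ContextVar("file_states", default=None)


def bind_file_states(states: dict):
    """Bind a session's state to the current task; keep the token to undo it."""
    return _file_states.set(states)


def reset_file_states(token) -> None:
    """Restore whatever binding was in effect before bind_file_states()."""
    _file_states.reset(token)


def current_file_states() -> dict:
    """Fall back to an empty, throwaway state when no session is bound."""
    return _file_states.get() or {}
```

Because each asyncio task gets its own context, two concurrent sessions calling the same shared tool resolve to different `FileStates` instances, which is exactly the isolation the tests assert.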