mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-24 10:32:45 +00:00
Merge origin/main into fix-ollama-image-generation
This commit is contained in:
commit
f5534bcaa0
14
README.md
14
README.md
@ -1,6 +1,18 @@
|
|||||||

|

|
||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
|
<p>
|
||||||
|
<a href="https://nanobot.wiki/docs/latest/getting-started/nanobot-overview">English</a> |
|
||||||
|
<a href="https://nanobot.wiki/cn/docs/latest/getting-started/nanobot-overview">简体中文</a> |
|
||||||
|
<a href="https://nanobot.wiki/zh-Hant/docs/latest/getting-started/nanobot-overview">繁體中文</a> |
|
||||||
|
<a href="https://nanobot.wiki/es/docs/latest/getting-started/nanobot-overview">Español</a> |
|
||||||
|
<a href="https://nanobot.wiki/fr/docs/latest/getting-started/nanobot-overview">Français</a> |
|
||||||
|
<a href="https://nanobot.wiki/id/docs/latest/getting-started/nanobot-overview">Bahasa Indonesia</a> |
|
||||||
|
<a href="https://nanobot.wiki/ja/docs/latest/getting-started/nanobot-overview">日本語</a> |
|
||||||
|
<a href="https://nanobot.wiki/ko/docs/latest/getting-started/nanobot-overview">한국어</a> |
|
||||||
|
<a href="https://nanobot.wiki/ru/docs/latest/getting-started/nanobot-overview">Русский</a> |
|
||||||
|
<a href="https://nanobot.wiki/vi/docs/latest/getting-started/nanobot-overview">Tiếng Việt</a>
|
||||||
|
</p>
|
||||||
<p>
|
<p>
|
||||||
<a href="https://pypi.org/project/nanobot-ai/"><img src="https://img.shields.io/pypi/v/nanobot-ai" alt="PyPI"></a>
|
<a href="https://pypi.org/project/nanobot-ai/"><img src="https://img.shields.io/pypi/v/nanobot-ai" alt="PyPI"></a>
|
||||||
<a href="https://pepy.tech/project/nanobot-ai"><img src="https://static.pepy.tech/badge/nanobot-ai" alt="Downloads"></a>
|
<a href="https://pepy.tech/project/nanobot-ai"><img src="https://static.pepy.tech/badge/nanobot-ai" alt="Downloads"></a>
|
||||||
@ -61,7 +73,7 @@
|
|||||||
- **2026-04-13** 🛡️ Agent turn hardened — user messages persisted early, auto-compact skips active tasks.
|
- **2026-04-13** 🛡️ Agent turn hardened — user messages persisted early, auto-compact skips active tasks.
|
||||||
- **2026-04-12** 🔒 Lark global domain support, Dream learns discovered skills, shell sandbox tightened.
|
- **2026-04-12** 🔒 Lark global domain support, Dream learns discovered skills, shell sandbox tightened.
|
||||||
- **2026-04-11** ⚡ Context compact shrinks sessions on the fly; Kagi web search; QQ & WeCom full media.
|
- **2026-04-11** ⚡ Context compact shrinks sessions on the fly; Kagi web search; QQ & WeCom full media.
|
||||||
- **2026-04-10** 📓 Notebook editing tool, multiple MCP servers, Feishu streaming & done-emoji.
|
- **2026-04-10** 📓 Multiple MCP servers, Feishu streaming & done-emoji.
|
||||||
- **2026-04-09** 🔌 WebSocket channel, unified cross-channel session, `disabled_skills` config.
|
- **2026-04-09** 🔌 WebSocket channel, unified cross-channel session, `disabled_skills` config.
|
||||||
- **2026-04-08** 📤 API file uploads, OpenAI reasoning auto-routing with Responses fallback.
|
- **2026-04-08** 📤 API file uploads, OpenAI reasoning auto-routing with Responses fallback.
|
||||||
- **2026-04-07** 🧠 Anthropic adaptive thinking, MCP resources & prompts exposed as tools.
|
- **2026-04-07** 🧠 Anthropic adaptive thinking, MCP resources & prompts exposed as tools.
|
||||||
|
|||||||
@ -17,6 +17,7 @@ Connect nanobot to your favorite chat platform. Want to build your own? See the
|
|||||||
| **Wecom** | Bot ID + Bot Secret |
|
| **Wecom** | Bot ID + Bot Secret |
|
||||||
| **Microsoft Teams** | App ID + App Password + public HTTPS endpoint |
|
| **Microsoft Teams** | App ID + App Password + public HTTPS endpoint |
|
||||||
| **Mochat** | Claw token (auto-setup available) |
|
| **Mochat** | Claw token (auto-setup available) |
|
||||||
|
| **Signal** | signal-cli daemon + phone number |
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><b>Telegram</b> (Recommended)</summary>
|
<summary><b>Telegram</b> (Recommended)</summary>
|
||||||
@ -669,3 +670,69 @@ nanobot gateway
|
|||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary><b>Signal</b></summary>
|
||||||
|
|
||||||
|
Uses **signal-cli** daemon in HTTP mode — receive messages via SSE, send via JSON-RPC.
|
||||||
|
|
||||||
|
**1. Install signal-cli**
|
||||||
|
|
||||||
|
Install [signal-cli](https://github.com/AsamK/signal-cli) and register a phone number:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
signal-cli -u +1234567890 register
|
||||||
|
signal-cli -u +1234567890 verify <CODE>
|
||||||
|
```
|
||||||
|
|
||||||
|
Start the daemon:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
signal-cli -a +1234567890 daemon --http localhost:8080
|
||||||
|
```
|
||||||
|
|
||||||
|
**2. Configure**
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"channels": {
|
||||||
|
"signal": {
|
||||||
|
"enabled": true,
|
||||||
|
"phoneNumber": "+1234567890",
|
||||||
|
"daemonHost": "localhost",
|
||||||
|
"daemonPort": 8080,
|
||||||
|
"dm": {
|
||||||
|
"enabled": true,
|
||||||
|
"policy": "open"
|
||||||
|
},
|
||||||
|
"group": {
|
||||||
|
"enabled": true,
|
||||||
|
"policy": "open",
|
||||||
|
"requireMention": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
> - `phoneNumber`: Your registered Signal phone number.
|
||||||
|
> - `daemonHost` / `daemonPort`: Where signal-cli daemon is listening (default `localhost:8080`).
|
||||||
|
> - `dm.policy`: `"open"` (anyone can DM) or `"allowlist"` (only listed numbers/UUIDs). When `"allowlist"`, unlisted DM senders receive a pairing code.
|
||||||
|
> - `dm.allowFrom`: List of allowed phone numbers or UUIDs (used when policy is `"allowlist"`).
|
||||||
|
> - `group.policy`: `"open"` (all groups) or `"allowlist"` (only listed group IDs).
|
||||||
|
> - `group.requireMention`: When `true` (default), the bot only responds in groups when @mentioned.
|
||||||
|
> - `group.allowFrom`: List of allowed group IDs (used when group policy is `"allowlist"`).
|
||||||
|
> - `attachmentsDir`: Override the directory where signal-cli stores inbound attachments. Defaults to `~/.local/share/signal-cli/attachments` (the Linux default). Set this if signal-cli runs with a custom `XDG_DATA_HOME` or on macOS/Windows.
|
||||||
|
> - `groupMessageBufferSize`: Number of recent group messages kept for context (default `20`, must be > 0).
|
||||||
|
|
||||||
|
**3. Run**
|
||||||
|
|
||||||
|
```bash
|
||||||
|
nanobot gateway
|
||||||
|
```
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> The channel automatically reconnects to the signal-cli daemon with exponential backoff if the connection drops.
|
||||||
|
> Markdown in bot replies is automatically converted to Signal text styles (bold, italic, code, etc.).
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|||||||
@ -148,6 +148,7 @@ ANTHROPIC_API_KEY="$(bw get password api/anthropic)" nanobot agent
|
|||||||
| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) |
|
| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) |
|
||||||
| `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) |
|
| `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) |
|
||||||
| `siliconflow` | LLM (SiliconFlow/硅基流动) | [siliconflow.cn](https://siliconflow.cn) |
|
| `siliconflow` | LLM (SiliconFlow/硅基流动) | [siliconflow.cn](https://siliconflow.cn) |
|
||||||
|
| `novita` | LLM (Novita AI OpenAI-compatible gateway) | [novita.ai](https://novita.ai) |
|
||||||
| `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
|
| `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
|
||||||
| `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) |
|
| `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) |
|
||||||
| `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |
|
| `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |
|
||||||
|
|||||||
@ -22,7 +22,7 @@ from nanobot.utils.prompt_templates import render_template
|
|||||||
class ContextBuilder:
|
class ContextBuilder:
|
||||||
"""Builds the context (system prompt + messages) for the agent."""
|
"""Builds the context (system prompt + messages) for the agent."""
|
||||||
|
|
||||||
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md"]
|
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md"]
|
||||||
_RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]"
|
_RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]"
|
||||||
_MAX_RECENT_HISTORY = 50
|
_MAX_RECENT_HISTORY = 50
|
||||||
_MAX_HISTORY_CHARS = 32_000 # hard cap on recent history section size
|
_MAX_HISTORY_CHARS = 32_000 # hard cap on recent history section size
|
||||||
@ -47,6 +47,8 @@ class ContextBuilder:
|
|||||||
if bootstrap:
|
if bootstrap:
|
||||||
parts.append(bootstrap)
|
parts.append(bootstrap)
|
||||||
|
|
||||||
|
parts.append(render_template("agent/tool_contract.md"))
|
||||||
|
|
||||||
memory = self.memory.get_memory_context()
|
memory = self.memory.get_memory_context()
|
||||||
if memory and not self._is_template_content(self.memory.read_memory(), "memory/MEMORY.md"):
|
if memory and not self._is_template_content(self.memory.read_memory(), "memory/MEMORY.md"):
|
||||||
parts.append(f"# Memory\n\n{memory}")
|
parts.append(f"# Memory\n\n{memory}")
|
||||||
@ -210,4 +212,3 @@ class ContextBuilder:
|
|||||||
if not images:
|
if not images:
|
||||||
return text
|
return text
|
||||||
return images + [{"type": "text", "text": text}]
|
return images + [{"type": "text", "text": text}]
|
||||||
|
|
||||||
|
|||||||
@ -19,7 +19,8 @@ from nanobot.utils.file_edit_events import (
|
|||||||
build_file_edit_end_event,
|
build_file_edit_end_event,
|
||||||
build_file_edit_error_event,
|
build_file_edit_error_event,
|
||||||
build_file_edit_start_event,
|
build_file_edit_start_event,
|
||||||
prepare_file_edit_tracker,
|
prepare_file_edit_tracker as _prepare_file_edit_tracker,
|
||||||
|
prepare_file_edit_trackers,
|
||||||
StreamingFileEditTracker,
|
StreamingFileEditTracker,
|
||||||
)
|
)
|
||||||
from nanobot.utils.helpers import (
|
from nanobot.utils.helpers import (
|
||||||
@ -58,11 +59,14 @@ _SNIP_SAFETY_BUFFER = 1024
|
|||||||
_MICROCOMPACT_KEEP_RECENT = 10
|
_MICROCOMPACT_KEEP_RECENT = 10
|
||||||
_MICROCOMPACT_MIN_CHARS = 500
|
_MICROCOMPACT_MIN_CHARS = 500
|
||||||
_COMPACTABLE_TOOLS = frozenset({
|
_COMPACTABLE_TOOLS = frozenset({
|
||||||
"read_file", "exec", "grep",
|
"read_file", "exec", "grep", "find_files",
|
||||||
"web_search", "web_fetch", "list_dir",
|
"web_search", "web_fetch", "list_dir", "list_exec_sessions",
|
||||||
})
|
})
|
||||||
_BACKFILL_CONTENT = "[Tool result unavailable — call was interrupted or lost]"
|
_BACKFILL_CONTENT = "[Tool result unavailable — call was interrupted or lost]"
|
||||||
|
|
||||||
|
# Backward-compatible module attribute for tests/extensions that monkeypatch
|
||||||
|
# the former single-file tracker hook. Runtime uses prepare_file_edit_trackers.
|
||||||
|
prepare_file_edit_tracker = _prepare_file_edit_tracker
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
@ -857,8 +861,8 @@ class AgentRunner:
|
|||||||
and on_progress_accepts_file_edit_events(spec.progress_callback)
|
and on_progress_accepts_file_edit_events(spec.progress_callback)
|
||||||
)
|
)
|
||||||
progress_callback = spec.progress_callback if emit_file_edit_events else None
|
progress_callback = spec.progress_callback if emit_file_edit_events else None
|
||||||
file_edit_tracker = (
|
file_edit_trackers = (
|
||||||
prepare_file_edit_tracker(
|
prepare_file_edit_trackers(
|
||||||
call_id=tool_call.id,
|
call_id=tool_call.id,
|
||||||
tool_name=tool_call.name,
|
tool_name=tool_call.name,
|
||||||
tool=tool,
|
tool=tool,
|
||||||
@ -868,13 +872,13 @@ class AgentRunner:
|
|||||||
if progress_callback is not None
|
if progress_callback is not None
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
if file_edit_tracker is not None and progress_callback is not None:
|
if file_edit_trackers and progress_callback is not None:
|
||||||
await invoke_file_edit_progress(
|
await invoke_file_edit_progress(
|
||||||
progress_callback,
|
progress_callback,
|
||||||
[build_file_edit_start_event(
|
[build_file_edit_start_event(
|
||||||
file_edit_tracker,
|
file_edit_tracker,
|
||||||
params if isinstance(params, dict) else None,
|
params if isinstance(params, dict) else None,
|
||||||
)],
|
) for file_edit_tracker in file_edit_trackers],
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
if tool is not None:
|
if tool is not None:
|
||||||
@ -884,10 +888,13 @@ class AgentRunner:
|
|||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
raise
|
raise
|
||||||
except BaseException as exc:
|
except BaseException as exc:
|
||||||
if file_edit_tracker is not None and progress_callback is not None:
|
if file_edit_trackers and progress_callback is not None:
|
||||||
await invoke_file_edit_progress(
|
await invoke_file_edit_progress(
|
||||||
progress_callback,
|
progress_callback,
|
||||||
[build_file_edit_error_event(file_edit_tracker, str(exc))],
|
[
|
||||||
|
build_file_edit_error_event(file_edit_tracker, str(exc))
|
||||||
|
for file_edit_tracker in file_edit_trackers
|
||||||
|
],
|
||||||
)
|
)
|
||||||
event = {
|
event = {
|
||||||
"name": tool_call.name,
|
"name": tool_call.name,
|
||||||
@ -910,10 +917,13 @@ class AgentRunner:
|
|||||||
return payload, event, None
|
return payload, event, None
|
||||||
|
|
||||||
if isinstance(result, str) and result.startswith("Error"):
|
if isinstance(result, str) and result.startswith("Error"):
|
||||||
if file_edit_tracker is not None and progress_callback is not None:
|
if file_edit_trackers and progress_callback is not None:
|
||||||
await invoke_file_edit_progress(
|
await invoke_file_edit_progress(
|
||||||
progress_callback,
|
progress_callback,
|
||||||
[build_file_edit_error_event(file_edit_tracker, result)],
|
[
|
||||||
|
build_file_edit_error_event(file_edit_tracker, result)
|
||||||
|
for file_edit_tracker in file_edit_trackers
|
||||||
|
],
|
||||||
)
|
)
|
||||||
event = {
|
event = {
|
||||||
"name": tool_call.name,
|
"name": tool_call.name,
|
||||||
@ -933,13 +943,13 @@ class AgentRunner:
|
|||||||
return result + hint, event, RuntimeError(result)
|
return result + hint, event, RuntimeError(result)
|
||||||
return result + hint, event, None
|
return result + hint, event, None
|
||||||
|
|
||||||
if file_edit_tracker is not None and progress_callback is not None:
|
if file_edit_trackers and progress_callback is not None:
|
||||||
await invoke_file_edit_progress(
|
await invoke_file_edit_progress(
|
||||||
progress_callback,
|
progress_callback,
|
||||||
[build_file_edit_end_event(
|
[build_file_edit_end_event(
|
||||||
file_edit_tracker,
|
file_edit_tracker,
|
||||||
params if isinstance(params, dict) else None,
|
params if isinstance(params, dict) else None,
|
||||||
)],
|
) for file_edit_tracker in file_edit_trackers],
|
||||||
)
|
)
|
||||||
|
|
||||||
detail = "" if result is None else str(result)
|
detail = "" if result is None else str(result)
|
||||||
|
|||||||
352
nanobot/agent/tools/apply_patch.py
Normal file
352
nanobot/agent/tools/apply_patch.py
Normal file
@ -0,0 +1,352 @@
|
|||||||
|
"""Apply file edits by providing structured edit instructions."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import difflib
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from nanobot.agent.tools.base import tool_parameters
|
||||||
|
from nanobot.agent.tools.filesystem import _FsTool
|
||||||
|
from nanobot.agent.tools.schema import (
|
||||||
|
ArraySchema,
|
||||||
|
BooleanSchema,
|
||||||
|
ObjectSchema,
|
||||||
|
StringSchema,
|
||||||
|
tool_parameters_schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class _PatchSummary:
|
||||||
|
action: str
|
||||||
|
path: str
|
||||||
|
added: int = 0
|
||||||
|
deleted: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class _PatchError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
_ABSOLUTE_WINDOWS_RE = re.compile(r"^[A-Za-z]:[\\/]")
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_relative_path(path: str) -> str:
|
||||||
|
normalized = path.strip()
|
||||||
|
if not normalized:
|
||||||
|
raise _PatchError("patch path cannot be empty")
|
||||||
|
if "\0" in normalized:
|
||||||
|
raise _PatchError(f"patch path contains a null byte: {path!r}")
|
||||||
|
if normalized.startswith(("~", "/", "\\")) or _ABSOLUTE_WINDOWS_RE.match(normalized):
|
||||||
|
raise _PatchError(f"patch path must be relative: {path}")
|
||||||
|
if any(part == ".." for part in re.split(r"[\\/]+", normalized)):
|
||||||
|
raise _PatchError(f"patch path must not contain '..': {path}")
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
|
def _lines_to_text(lines: list[str]) -> str:
|
||||||
|
if not lines:
|
||||||
|
return ""
|
||||||
|
return "\n".join(lines) + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _text_line_count(text: str) -> int:
|
||||||
|
if not text:
|
||||||
|
return 0
|
||||||
|
return len(text.splitlines())
|
||||||
|
|
||||||
|
|
||||||
|
def _line_diff_stats(before: str, after: str) -> tuple[int, int]:
|
||||||
|
before_lines = before.replace("\r\n", "\n").splitlines()
|
||||||
|
after_lines = after.replace("\r\n", "\n").splitlines()
|
||||||
|
added = 0
|
||||||
|
deleted = 0
|
||||||
|
matcher = difflib.SequenceMatcher(a=before_lines, b=after_lines, autojunk=False)
|
||||||
|
for tag, i1, i2, j1, j2 in matcher.get_opcodes():
|
||||||
|
if tag == "equal":
|
||||||
|
continue
|
||||||
|
if tag in ("replace", "delete"):
|
||||||
|
deleted += i2 - i1
|
||||||
|
if tag in ("replace", "insert"):
|
||||||
|
added += j2 - j1
|
||||||
|
return added, deleted
|
||||||
|
|
||||||
|
|
||||||
|
def _format_summary(summary: _PatchSummary) -> str:
|
||||||
|
stats = ""
|
||||||
|
if summary.added or summary.deleted:
|
||||||
|
stats = f" (+{summary.added}/-{summary.deleted})"
|
||||||
|
return f"- {summary.action} {summary.path}{stats}"
|
||||||
|
|
||||||
|
|
||||||
|
@tool_parameters(
|
||||||
|
tool_parameters_schema(
|
||||||
|
edits=ArraySchema(
|
||||||
|
items=ObjectSchema(
|
||||||
|
path=StringSchema("Relative path to the file to edit."),
|
||||||
|
action=StringSchema(
|
||||||
|
"Operation type: replace (find and replace text), add (append new content or create file), delete (remove text).",
|
||||||
|
enum=["replace", "add", "delete"],
|
||||||
|
),
|
||||||
|
old_text=StringSchema(
|
||||||
|
"Exact text to search for in the file. Required for replace and delete.",
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
new_text=StringSchema(
|
||||||
|
"Text to replace with or append. Required for replace and add.",
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
required=["path", "action"],
|
||||||
|
),
|
||||||
|
description="List of edits to apply. Each edit specifies a file and the change to make.",
|
||||||
|
min_items=1,
|
||||||
|
max_items=20,
|
||||||
|
),
|
||||||
|
dry_run=BooleanSchema(
|
||||||
|
description="Validate and summarize the patch without writing files.",
|
||||||
|
default=False,
|
||||||
|
),
|
||||||
|
required=["edits"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
class ApplyPatchTool(_FsTool):
|
||||||
|
"""Apply file edits by providing structured edit instructions."""
|
||||||
|
_scopes = {"core", "subagent"}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "apply_patch"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Default tool for code edits. Supports multi-file changes in a single call. "
|
||||||
|
"Provide a list of structured edits, each specifying a file path, action (replace/add/delete), and the text to change. "
|
||||||
|
"Paths must be relative. Set dry_run=true to validate and preview without writing files. "
|
||||||
|
"Use edit_file only for small exact replacements on a single file."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
edits: list[dict] | None = None,
|
||||||
|
dry_run: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
try:
|
||||||
|
if not edits:
|
||||||
|
raise _PatchError("must provide edits")
|
||||||
|
|
||||||
|
writes: dict[Path, str] = {}
|
||||||
|
deletes: set[Path] = set()
|
||||||
|
summaries: list[_PatchSummary] = []
|
||||||
|
|
||||||
|
for edit in edits:
|
||||||
|
if not isinstance(edit, dict):
|
||||||
|
raise _PatchError("each edit must be an object")
|
||||||
|
raw_path = edit.get("path")
|
||||||
|
if not isinstance(raw_path, str):
|
||||||
|
raise _PatchError("path required for edit")
|
||||||
|
path = _validate_relative_path(raw_path)
|
||||||
|
action = edit.get("action")
|
||||||
|
if not isinstance(action, str):
|
||||||
|
raise _PatchError(f"action required for edit: {path}")
|
||||||
|
source = self._resolve(path)
|
||||||
|
|
||||||
|
if action == "add":
|
||||||
|
new_text = edit.get("new_text")
|
||||||
|
if new_text is None:
|
||||||
|
raise _PatchError(f"new_text required for add: {path}")
|
||||||
|
|
||||||
|
pending = writes.get(source)
|
||||||
|
if pending is not None:
|
||||||
|
content = pending
|
||||||
|
exists = True
|
||||||
|
elif source.exists():
|
||||||
|
raw = source.read_bytes()
|
||||||
|
try:
|
||||||
|
content = raw.decode("utf-8")
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
raise _PatchError(f"file is not UTF-8 text: {path}")
|
||||||
|
exists = True
|
||||||
|
else:
|
||||||
|
content = ""
|
||||||
|
exists = False
|
||||||
|
|
||||||
|
if exists:
|
||||||
|
uses_crlf = "\r\n" in content
|
||||||
|
new_norm = content.replace("\r\n", "\n") + new_text.replace("\r\n", "\n")
|
||||||
|
if new_norm and not new_norm.endswith("\n"):
|
||||||
|
new_norm += "\n"
|
||||||
|
if uses_crlf:
|
||||||
|
new_norm = new_norm.replace("\n", "\r\n")
|
||||||
|
writes[source] = new_norm
|
||||||
|
deletes.discard(source)
|
||||||
|
added, deleted = _line_diff_stats(content, new_norm)
|
||||||
|
action_name = "update"
|
||||||
|
else:
|
||||||
|
new_norm = new_text.replace("\r\n", "\n")
|
||||||
|
if new_norm and not new_norm.endswith("\n"):
|
||||||
|
new_norm += "\n"
|
||||||
|
writes[source] = new_norm
|
||||||
|
deletes.discard(source)
|
||||||
|
added = _text_line_count(new_norm)
|
||||||
|
deleted = 0
|
||||||
|
action_name = "add"
|
||||||
|
|
||||||
|
summaries.append(
|
||||||
|
_PatchSummary(
|
||||||
|
action=action_name, path=path, added=added, deleted=deleted
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif action == "replace":
|
||||||
|
old_text = edit.get("old_text") or ""
|
||||||
|
if not old_text:
|
||||||
|
raise _PatchError(f"old_text required for replace: {path}")
|
||||||
|
new_text = edit.get("new_text")
|
||||||
|
if new_text is None:
|
||||||
|
raise _PatchError(f"new_text required for replace: {path}")
|
||||||
|
|
||||||
|
pending = writes.get(source)
|
||||||
|
if pending is not None:
|
||||||
|
content = pending
|
||||||
|
elif source.exists():
|
||||||
|
raw = source.read_bytes()
|
||||||
|
try:
|
||||||
|
content = raw.decode("utf-8")
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
raise _PatchError(f"file is not UTF-8 text: {path}")
|
||||||
|
else:
|
||||||
|
raise _PatchError(f"file to update does not exist: {path}")
|
||||||
|
|
||||||
|
if pending is None and not source.is_file():
|
||||||
|
raise _PatchError(f"path to update is not a file: {path}")
|
||||||
|
|
||||||
|
uses_crlf = "\r\n" in content
|
||||||
|
norm_content = content.replace("\r\n", "\n")
|
||||||
|
norm_old = old_text.replace("\r\n", "\n")
|
||||||
|
|
||||||
|
pos = norm_content.find(norm_old)
|
||||||
|
if pos < 0:
|
||||||
|
raise _PatchError(f"old_text not found in {path}")
|
||||||
|
if norm_content.find(norm_old, pos + 1) >= 0:
|
||||||
|
raise _PatchError(f"old_text appears multiple times in {path}")
|
||||||
|
|
||||||
|
new_norm = (
|
||||||
|
norm_content[:pos]
|
||||||
|
+ new_text.replace("\r\n", "\n")
|
||||||
|
+ norm_content[pos + len(norm_old) :]
|
||||||
|
)
|
||||||
|
if new_norm and not new_norm.endswith("\n"):
|
||||||
|
new_norm += "\n"
|
||||||
|
if uses_crlf:
|
||||||
|
new_norm = new_norm.replace("\n", "\r\n")
|
||||||
|
|
||||||
|
writes[source] = new_norm
|
||||||
|
deletes.discard(source)
|
||||||
|
added, deleted = _line_diff_stats(content, new_norm)
|
||||||
|
summaries.append(
|
||||||
|
_PatchSummary(
|
||||||
|
action="update", path=path, added=added, deleted=deleted
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif action == "delete":
|
||||||
|
old_text = edit.get("old_text") or ""
|
||||||
|
if not old_text:
|
||||||
|
raise _PatchError(f"old_text required for delete: {path}")
|
||||||
|
|
||||||
|
pending = writes.get(source)
|
||||||
|
if pending is not None:
|
||||||
|
content = pending
|
||||||
|
elif source.exists():
|
||||||
|
raw = source.read_bytes()
|
||||||
|
try:
|
||||||
|
content = raw.decode("utf-8")
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
raise _PatchError(f"file is not UTF-8 text: {path}")
|
||||||
|
else:
|
||||||
|
raise _PatchError(f"file to update does not exist: {path}")
|
||||||
|
|
||||||
|
if pending is None and not source.is_file():
|
||||||
|
raise _PatchError(f"path to update is not a file: {path}")
|
||||||
|
|
||||||
|
uses_crlf = "\r\n" in content
|
||||||
|
norm_content = content.replace("\r\n", "\n")
|
||||||
|
norm_old = old_text.replace("\r\n", "\n")
|
||||||
|
|
||||||
|
pos = norm_content.find(norm_old)
|
||||||
|
if pos < 0:
|
||||||
|
raise _PatchError(f"old_text not found in {path}")
|
||||||
|
if norm_content.find(norm_old, pos + 1) >= 0:
|
||||||
|
raise _PatchError(f"old_text appears multiple times in {path}")
|
||||||
|
|
||||||
|
if norm_old == norm_content:
|
||||||
|
deletes.add(source)
|
||||||
|
writes.pop(source, None)
|
||||||
|
added, deleted = 0, _text_line_count(content)
|
||||||
|
summaries.append(
|
||||||
|
_PatchSummary(
|
||||||
|
action="delete", path=path, added=added, deleted=deleted
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_norm = (
|
||||||
|
norm_content[:pos] + norm_content[pos + len(norm_old) :]
|
||||||
|
)
|
||||||
|
if new_norm and not new_norm.endswith("\n"):
|
||||||
|
new_norm += "\n"
|
||||||
|
if uses_crlf:
|
||||||
|
new_norm = new_norm.replace("\n", "\r\n")
|
||||||
|
writes[source] = new_norm
|
||||||
|
deletes.discard(source)
|
||||||
|
added, deleted = _line_diff_stats(content, new_norm)
|
||||||
|
summaries.append(
|
||||||
|
_PatchSummary(
|
||||||
|
action="update", path=path, added=added, deleted=deleted
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise _PatchError(f"unknown action: {action}")
|
||||||
|
|
||||||
|
if dry_run:
|
||||||
|
return "Patch dry-run succeeded:\n" + "\n".join(
|
||||||
|
_format_summary(summary) for summary in summaries
|
||||||
|
)
|
||||||
|
|
||||||
|
backups: dict[Path, bytes | None] = {}
|
||||||
|
for path in set(writes) | deletes:
|
||||||
|
backups[path] = path.read_bytes() if path.exists() else None
|
||||||
|
|
||||||
|
try:
|
||||||
|
for path in deletes:
|
||||||
|
if path.exists():
|
||||||
|
path.unlink()
|
||||||
|
for path, content in writes.items():
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
path.write_text(content, encoding="utf-8", newline="")
|
||||||
|
except Exception:
|
||||||
|
for path, data in backups.items():
|
||||||
|
if data is None:
|
||||||
|
if path.exists():
|
||||||
|
path.unlink()
|
||||||
|
else:
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
path.write_bytes(data)
|
||||||
|
raise
|
||||||
|
|
||||||
|
for path in set(writes) | deletes:
|
||||||
|
self._file_states.record_write(path)
|
||||||
|
return "Patch applied:\n" + "\n".join(
|
||||||
|
_format_summary(summary) for summary in summaries
|
||||||
|
)
|
||||||
|
except PermissionError as exc:
|
||||||
|
return f"Error: {exc}"
|
||||||
|
except _PatchError as exc:
|
||||||
|
return f"Error applying patch: {exc}"
|
||||||
|
except Exception as exc:
|
||||||
|
return f"Error applying patch: {exc}"
|
||||||
591
nanobot/agent/tools/exec_session.py
Normal file
591
nanobot/agent/tools/exec_session.py
Normal file
@ -0,0 +1,591 @@
|
|||||||
|
"""Session support for long-running exec workflows."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import shutil
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from contextlib import suppress
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||||
|
from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_YIELD_MS = 1000
|
||||||
|
MAX_YIELD_MS = 30_000
|
||||||
|
DEFAULT_WAIT_FOR_MS = 10_000
|
||||||
|
MAX_WAIT_FOR_MS = 120_000
|
||||||
|
DEFAULT_MAX_OUTPUT_CHARS = 10_000
|
||||||
|
MAX_OUTPUT_CHARS = 50_000
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class _SessionPoll:
|
||||||
|
output: str
|
||||||
|
done: bool
|
||||||
|
exit_code: int | None
|
||||||
|
elapsed_s: float = 0.0
|
||||||
|
timed_out: bool = False
|
||||||
|
terminated: bool = False
|
||||||
|
stdin_closed: bool = False
|
||||||
|
truncated_chars: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class ExecSessionInfo:
|
||||||
|
session_id: str
|
||||||
|
command: str
|
||||||
|
cwd: str
|
||||||
|
elapsed_s: float
|
||||||
|
idle_s: float
|
||||||
|
remaining_s: float
|
||||||
|
returncode: int | None
|
||||||
|
|
||||||
|
|
||||||
|
class _ExecSession:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
session_id: str,
|
||||||
|
process: asyncio.subprocess.Process,
|
||||||
|
command: str,
|
||||||
|
cwd: str,
|
||||||
|
timeout: int,
|
||||||
|
) -> None:
|
||||||
|
self.session_id = session_id
|
||||||
|
self.process = process
|
||||||
|
self.command = command
|
||||||
|
self.cwd = cwd
|
||||||
|
self.started_at = time.monotonic()
|
||||||
|
self.deadline = time.monotonic() + timeout
|
||||||
|
self.last_access = time.monotonic()
|
||||||
|
self._chunks: list[str] = []
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
self._timed_out = False
|
||||||
|
self._stdout_task = asyncio.create_task(self._read_stream(process.stdout, ""))
|
||||||
|
self._stderr_task = asyncio.create_task(self._read_stream(process.stderr, "STDERR:\n"))
|
||||||
|
|
||||||
|
async def _read_stream(
|
||||||
|
self,
|
||||||
|
stream: asyncio.StreamReader | None,
|
||||||
|
prefix: str,
|
||||||
|
) -> None:
|
||||||
|
if stream is None:
|
||||||
|
return
|
||||||
|
first = True
|
||||||
|
while True:
|
||||||
|
chunk = await stream.read(4096)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
text = chunk.decode("utf-8", errors="replace")
|
||||||
|
if prefix and first:
|
||||||
|
text = prefix + text
|
||||||
|
first = False
|
||||||
|
async with self._lock:
|
||||||
|
self._chunks.append(text)
|
||||||
|
|
||||||
|
async def write(self, chars: str) -> str | None:
|
||||||
|
if self.process.returncode is not None:
|
||||||
|
return "session has already exited"
|
||||||
|
if self.process.stdin is None:
|
||||||
|
return "session stdin is not available"
|
||||||
|
try:
|
||||||
|
self.process.stdin.write(chars.encode("utf-8"))
|
||||||
|
await self.process.stdin.drain()
|
||||||
|
except (BrokenPipeError, ConnectionResetError):
|
||||||
|
return "session stdin is closed"
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def close_stdin(self) -> str | None:
|
||||||
|
if self.process.returncode is not None:
|
||||||
|
return "session has already exited"
|
||||||
|
if self.process.stdin is None:
|
||||||
|
return "session stdin is not available"
|
||||||
|
self.process.stdin.close()
|
||||||
|
with suppress(BrokenPipeError, ConnectionResetError):
|
||||||
|
await self.process.stdin.wait_closed()
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def poll(
|
||||||
|
self,
|
||||||
|
yield_time_ms: int,
|
||||||
|
max_output_chars: int,
|
||||||
|
*,
|
||||||
|
terminated: bool = False,
|
||||||
|
stdin_closed: bool = False,
|
||||||
|
) -> _SessionPoll:
|
||||||
|
self.last_access = time.monotonic()
|
||||||
|
if yield_time_ms > 0 and self.process.returncode is None:
|
||||||
|
await asyncio.sleep(min(yield_time_ms, MAX_YIELD_MS) / 1000)
|
||||||
|
|
||||||
|
if self.process.returncode is None and time.monotonic() >= self.deadline:
|
||||||
|
self._timed_out = True
|
||||||
|
await self.kill()
|
||||||
|
|
||||||
|
if self.process.returncode is not None:
|
||||||
|
with suppress(asyncio.TimeoutError):
|
||||||
|
await asyncio.wait_for(
|
||||||
|
asyncio.gather(self._stdout_task, self._stderr_task),
|
||||||
|
timeout=2.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
output = "".join(self._chunks)
|
||||||
|
self._chunks.clear()
|
||||||
|
|
||||||
|
output, truncated = _truncate_output(output, max_output_chars)
|
||||||
|
return _SessionPoll(
|
||||||
|
output=output,
|
||||||
|
done=self.process.returncode is not None,
|
||||||
|
exit_code=self.process.returncode,
|
||||||
|
elapsed_s=max(0.0, time.monotonic() - self.started_at),
|
||||||
|
timed_out=self._timed_out,
|
||||||
|
terminated=terminated,
|
||||||
|
stdin_closed=stdin_closed,
|
||||||
|
truncated_chars=truncated,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def kill(self) -> None:
|
||||||
|
if self.process.returncode is not None:
|
||||||
|
return
|
||||||
|
self.process.kill()
|
||||||
|
with suppress(asyncio.TimeoutError):
|
||||||
|
await asyncio.wait_for(self.process.wait(), timeout=5.0)
|
||||||
|
|
||||||
|
|
||||||
|
class ExecSessionManager:
|
||||||
|
def __init__(self, *, max_sessions: int = 8, idle_timeout: int = 1800) -> None:
|
||||||
|
self.max_sessions = max_sessions
|
||||||
|
self.idle_timeout = idle_timeout
|
||||||
|
self._sessions: dict[str, _ExecSession] = {}
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def start(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
command: str,
|
||||||
|
cwd: str,
|
||||||
|
env: dict[str, str],
|
||||||
|
timeout: int,
|
||||||
|
shell_program: str | None,
|
||||||
|
login: bool,
|
||||||
|
yield_time_ms: int,
|
||||||
|
max_output_chars: int,
|
||||||
|
) -> tuple[str, _SessionPoll]:
|
||||||
|
async with self._lock:
|
||||||
|
await self._cleanup_locked()
|
||||||
|
if len(self._sessions) >= self.max_sessions:
|
||||||
|
raise RuntimeError(f"maximum exec sessions reached ({self.max_sessions})")
|
||||||
|
process = await self._spawn(command, cwd, env, shell_program, login)
|
||||||
|
session_id = uuid.uuid4().hex[:12]
|
||||||
|
session = _ExecSession(
|
||||||
|
session_id=session_id,
|
||||||
|
process=process,
|
||||||
|
command=command,
|
||||||
|
cwd=cwd,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
self._sessions[session_id] = session
|
||||||
|
|
||||||
|
poll = await session.poll(yield_time_ms, max_output_chars)
|
||||||
|
if poll.done:
|
||||||
|
async with self._lock:
|
||||||
|
self._sessions.pop(session_id, None)
|
||||||
|
return session_id, poll
|
||||||
|
|
||||||
|
async def write(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
session_id: str,
|
||||||
|
chars: str | None,
|
||||||
|
close_stdin: bool,
|
||||||
|
terminate: bool,
|
||||||
|
yield_time_ms: int,
|
||||||
|
max_output_chars: int,
|
||||||
|
) -> _SessionPoll:
|
||||||
|
async with self._lock:
|
||||||
|
await self._cleanup_locked()
|
||||||
|
session = self._sessions.get(session_id)
|
||||||
|
if session is None:
|
||||||
|
raise KeyError(session_id)
|
||||||
|
|
||||||
|
if chars:
|
||||||
|
error = await session.write(chars)
|
||||||
|
if error:
|
||||||
|
raise RuntimeError(error)
|
||||||
|
stdin_closed = False
|
||||||
|
if close_stdin:
|
||||||
|
error = await session.close_stdin()
|
||||||
|
if error:
|
||||||
|
raise RuntimeError(error)
|
||||||
|
stdin_closed = True
|
||||||
|
if terminate:
|
||||||
|
await session.kill()
|
||||||
|
poll = await session.poll(
|
||||||
|
yield_time_ms,
|
||||||
|
max_output_chars,
|
||||||
|
terminated=terminate,
|
||||||
|
stdin_closed=stdin_closed,
|
||||||
|
)
|
||||||
|
if poll.done:
|
||||||
|
async with self._lock:
|
||||||
|
self._sessions.pop(session_id, None)
|
||||||
|
return poll
|
||||||
|
|
||||||
|
async def list(self) -> list[ExecSessionInfo]:
|
||||||
|
async with self._lock:
|
||||||
|
await self._cleanup_locked()
|
||||||
|
now = time.monotonic()
|
||||||
|
return [
|
||||||
|
ExecSessionInfo(
|
||||||
|
session_id=session_id,
|
||||||
|
command=session.command,
|
||||||
|
cwd=session.cwd,
|
||||||
|
elapsed_s=max(0.0, now - session.started_at),
|
||||||
|
idle_s=max(0.0, now - session.last_access),
|
||||||
|
remaining_s=max(0.0, session.deadline - now),
|
||||||
|
returncode=session.process.returncode,
|
||||||
|
)
|
||||||
|
for session_id, session in sorted(self._sessions.items())
|
||||||
|
]
|
||||||
|
|
||||||
|
async def _cleanup_locked(self) -> None:
|
||||||
|
now = time.monotonic()
|
||||||
|
stale = [
|
||||||
|
session_id
|
||||||
|
for session_id, session in self._sessions.items()
|
||||||
|
if now - session.last_access > self.idle_timeout
|
||||||
|
]
|
||||||
|
for session_id in stale:
|
||||||
|
session = self._sessions.pop(session_id)
|
||||||
|
await session.kill()
|
||||||
|
|
||||||
|
async def _spawn(
|
||||||
|
self,
|
||||||
|
command: str,
|
||||||
|
cwd: str,
|
||||||
|
env: dict[str, str],
|
||||||
|
shell_program: str | None,
|
||||||
|
login: bool,
|
||||||
|
) -> asyncio.subprocess.Process:
|
||||||
|
from nanobot.agent.tools import shell
|
||||||
|
|
||||||
|
if shell._IS_WINDOWS:
|
||||||
|
return await asyncio.create_subprocess_shell(
|
||||||
|
command,
|
||||||
|
stdin=asyncio.subprocess.PIPE,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
cwd=cwd,
|
||||||
|
env=env,
|
||||||
|
)
|
||||||
|
shell_program = shell_program or shutil.which("bash") or "/bin/bash"
|
||||||
|
args = [shell_program]
|
||||||
|
if login and shell_program.rsplit("/", 1)[-1] in {"bash", "zsh"}:
|
||||||
|
args.append("-l")
|
||||||
|
args.extend(["-c", command])
|
||||||
|
return await asyncio.create_subprocess_exec(
|
||||||
|
*args,
|
||||||
|
stdin=asyncio.subprocess.PIPE,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
cwd=cwd,
|
||||||
|
env=env,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_EXEC_SESSION_MANAGER = ExecSessionManager()
|
||||||
|
|
||||||
|
|
||||||
|
def clamp_session_int(value: int | None, default: int, minimum: int, maximum: int) -> int:
|
||||||
|
if value is None:
|
||||||
|
return default
|
||||||
|
return min(max(value, minimum), maximum)
|
||||||
|
|
||||||
|
|
||||||
|
def _truncate_output(output: str, max_output_chars: int) -> tuple[str, int]:
|
||||||
|
if len(output) <= max_output_chars:
|
||||||
|
return output, 0
|
||||||
|
half = max_output_chars // 2
|
||||||
|
omitted = len(output) - max_output_chars
|
||||||
|
return (
|
||||||
|
output[:half]
|
||||||
|
+ f"\n\n... ({omitted:,} chars truncated) ...\n\n"
|
||||||
|
+ output[-half:],
|
||||||
|
omitted,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def format_session_poll(session_id: str, poll: _SessionPoll) -> str:
|
||||||
|
parts = [poll.output] if poll.output else []
|
||||||
|
if poll.truncated_chars:
|
||||||
|
parts.append(f"(output truncated by {poll.truncated_chars:,} chars)")
|
||||||
|
if poll.timed_out:
|
||||||
|
parts.append("Error: Command timed out; session was terminated.")
|
||||||
|
if poll.terminated and not poll.timed_out:
|
||||||
|
parts.append("Session terminated.")
|
||||||
|
if poll.stdin_closed:
|
||||||
|
parts.append("Stdin closed.")
|
||||||
|
if poll.done:
|
||||||
|
parts.append(f"Exit code: {poll.exit_code}")
|
||||||
|
else:
|
||||||
|
parts.append(f"Process running. session_id: {session_id}")
|
||||||
|
parts.append(f"Elapsed: {poll.elapsed_s:.1f}s")
|
||||||
|
return "\n".join(parts) if parts else "(no output yet)"
|
||||||
|
|
||||||
|
|
||||||
|
@tool_parameters(
|
||||||
|
tool_parameters_schema(
|
||||||
|
session_id=StringSchema("Session id returned by exec when yield_time_ms is used."),
|
||||||
|
chars=StringSchema(
|
||||||
|
"Bytes/text to write to stdin. Omit or pass an empty string to only poll recent output.",
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
close_stdin=BooleanSchema(
|
||||||
|
description="Close stdin after writing chars. Useful for commands waiting for EOF.",
|
||||||
|
default=False,
|
||||||
|
),
|
||||||
|
terminate=BooleanSchema(
|
||||||
|
description="Terminate the running exec session.",
|
||||||
|
default=False,
|
||||||
|
),
|
||||||
|
yield_time_ms=IntegerSchema(
|
||||||
|
DEFAULT_YIELD_MS,
|
||||||
|
description="Milliseconds to wait before returning recent output (default 1000, max 30000).",
|
||||||
|
minimum=0,
|
||||||
|
maximum=MAX_YIELD_MS,
|
||||||
|
),
|
||||||
|
wait_for=StringSchema(
|
||||||
|
"Optional text to wait for in output before returning. "
|
||||||
|
"Useful for interactive commands and dev servers.",
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
wait_timeout_ms=IntegerSchema(
|
||||||
|
DEFAULT_WAIT_FOR_MS,
|
||||||
|
description="Maximum milliseconds to wait for wait_for text (default 10000, max 120000).",
|
||||||
|
minimum=0,
|
||||||
|
maximum=MAX_WAIT_FOR_MS,
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
max_output_chars=IntegerSchema(
|
||||||
|
DEFAULT_MAX_OUTPUT_CHARS,
|
||||||
|
description="Maximum output characters to return from this poll (default 10000, max 50000).",
|
||||||
|
minimum=1000,
|
||||||
|
maximum=MAX_OUTPUT_CHARS,
|
||||||
|
),
|
||||||
|
max_output_tokens=IntegerSchema(
|
||||||
|
DEFAULT_MAX_OUTPUT_CHARS,
|
||||||
|
description="Compatibility alias for max_output_chars. The current runtime uses a character budget.",
|
||||||
|
minimum=1000,
|
||||||
|
maximum=MAX_OUTPUT_CHARS,
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
required=["session_id"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
class WriteStdinTool(Tool):
|
||||||
|
"""Write to or poll a running exec session."""
|
||||||
|
|
||||||
|
_scopes = {"core", "subagent"}
|
||||||
|
config_key = "exec"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def config_cls(cls):
|
||||||
|
from nanobot.agent.tools.shell import ExecToolConfig
|
||||||
|
|
||||||
|
return ExecToolConfig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def enabled(cls, ctx: Any) -> bool:
|
||||||
|
return ctx.config.exec.enable
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
manager: ExecSessionManager | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._manager = manager or DEFAULT_EXEC_SESSION_MANAGER
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, ctx: Any) -> Tool:
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def exclusive(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "write_stdin"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Interact with a running exec session created by exec with "
|
||||||
|
"yield_time_ms. Use chars='' to poll without writing, chars to send "
|
||||||
|
"stdin, close_stdin=true to send EOF, or terminate=true to stop the "
|
||||||
|
"process. Use wait_for with wait_timeout_ms for dev servers, test "
|
||||||
|
"watchers, and prompts where you need to wait for expected output. "
|
||||||
|
"Do not use this to start new commands; start them with exec."
|
||||||
|
)
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
chars: str | None = None,
|
||||||
|
close_stdin: bool = False,
|
||||||
|
terminate: bool = False,
|
||||||
|
yield_time_ms: int | None = None,
|
||||||
|
wait_for: str | None = None,
|
||||||
|
wait_timeout_ms: int | None = None,
|
||||||
|
max_output_chars: int | None = None,
|
||||||
|
max_output_tokens: int | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
try:
|
||||||
|
if max_output_chars is None:
|
||||||
|
max_output_chars = max_output_tokens
|
||||||
|
output_limit = clamp_session_int(
|
||||||
|
max_output_chars,
|
||||||
|
DEFAULT_MAX_OUTPUT_CHARS,
|
||||||
|
1000,
|
||||||
|
MAX_OUTPUT_CHARS,
|
||||||
|
)
|
||||||
|
if wait_for:
|
||||||
|
return await self._wait_for_output(
|
||||||
|
session_id=session_id,
|
||||||
|
chars=chars,
|
||||||
|
close_stdin=close_stdin,
|
||||||
|
terminate=terminate,
|
||||||
|
wait_for=wait_for,
|
||||||
|
wait_timeout_ms=clamp_session_int(
|
||||||
|
wait_timeout_ms,
|
||||||
|
DEFAULT_WAIT_FOR_MS,
|
||||||
|
0,
|
||||||
|
MAX_WAIT_FOR_MS,
|
||||||
|
),
|
||||||
|
max_output_chars=output_limit,
|
||||||
|
)
|
||||||
|
poll = await self._manager.write(
|
||||||
|
session_id=session_id,
|
||||||
|
chars=chars,
|
||||||
|
close_stdin=close_stdin,
|
||||||
|
terminate=terminate,
|
||||||
|
yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS),
|
||||||
|
max_output_chars=output_limit,
|
||||||
|
)
|
||||||
|
return format_session_poll(session_id, poll)
|
||||||
|
except KeyError:
|
||||||
|
return f"Error: exec session not found: {session_id}"
|
||||||
|
except Exception as exc:
|
||||||
|
return f"Error writing to exec session: {exc}"
|
||||||
|
|
||||||
|
async def _wait_for_output(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
session_id: str,
|
||||||
|
chars: str | None,
|
||||||
|
close_stdin: bool,
|
||||||
|
terminate: bool,
|
||||||
|
wait_for: str,
|
||||||
|
wait_timeout_ms: int,
|
||||||
|
max_output_chars: int,
|
||||||
|
) -> str:
|
||||||
|
deadline = time.monotonic() + (wait_timeout_ms / 1000)
|
||||||
|
aggregate: list[str] = []
|
||||||
|
first = True
|
||||||
|
poll: _SessionPoll | None = None
|
||||||
|
|
||||||
|
while True:
|
||||||
|
remaining_ms = max(0, int((deadline - time.monotonic()) * 1000))
|
||||||
|
step_ms = min(500, remaining_ms)
|
||||||
|
poll = await self._manager.write(
|
||||||
|
session_id=session_id,
|
||||||
|
chars=chars if first else None,
|
||||||
|
close_stdin=close_stdin if first else False,
|
||||||
|
terminate=terminate if first else False,
|
||||||
|
yield_time_ms=step_ms,
|
||||||
|
max_output_chars=max_output_chars,
|
||||||
|
)
|
||||||
|
first = False
|
||||||
|
if poll.output:
|
||||||
|
aggregate.append(poll.output)
|
||||||
|
joined = "".join(aggregate)
|
||||||
|
if wait_for in joined:
|
||||||
|
poll.output = joined
|
||||||
|
return format_session_poll(session_id, poll)
|
||||||
|
if poll.done or remaining_ms <= 0:
|
||||||
|
poll.output = "".join(aggregate)
|
||||||
|
result = format_session_poll(session_id, poll)
|
||||||
|
if wait_for not in poll.output:
|
||||||
|
result += f"\nWait target not observed: {wait_for!r}"
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@tool_parameters(tool_parameters_schema())
|
||||||
|
class ListExecSessionsTool(Tool):
|
||||||
|
"""List active exec sessions."""
|
||||||
|
|
||||||
|
_scopes = {"core", "subagent"}
|
||||||
|
config_key = "exec"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def config_cls(cls):
|
||||||
|
from nanobot.agent.tools.shell import ExecToolConfig
|
||||||
|
|
||||||
|
return ExecToolConfig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def enabled(cls, ctx: Any) -> bool:
|
||||||
|
return ctx.config.exec.enable
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
manager: ExecSessionManager | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._manager = manager or DEFAULT_EXEC_SESSION_MANAGER
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, ctx: Any) -> Tool:
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "list_exec_sessions"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"List active long-running exec sessions, including session_id, cwd, "
|
||||||
|
"elapsed time, idle time, remaining timeout, and command preview. "
|
||||||
|
"Use this to recover a session_id after context shifts before "
|
||||||
|
"polling, writing stdin, or terminating with write_stdin."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def read_only(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def execute(self, **kwargs: Any) -> str:
|
||||||
|
try:
|
||||||
|
sessions = await self._manager.list()
|
||||||
|
if not sessions:
|
||||||
|
return "No active exec sessions."
|
||||||
|
lines = []
|
||||||
|
for info in sessions:
|
||||||
|
command = " ".join(info.command.split())
|
||||||
|
if len(command) > 120:
|
||||||
|
command = command[:119] + "..."
|
||||||
|
status = "exited" if info.returncode is not None else "running"
|
||||||
|
lines.append(
|
||||||
|
f"{info.session_id} | {status} | elapsed={info.elapsed_s:.1f}s "
|
||||||
|
f"| idle={info.idle_s:.1f}s | remaining={info.remaining_s:.1f}s "
|
||||||
|
f"| cwd={info.cwd} | {command}"
|
||||||
|
)
|
||||||
|
return "\n".join(lines)
|
||||||
|
except Exception as exc:
|
||||||
|
return f"Error listing exec sessions: {exc}"
|
||||||
@ -132,6 +132,10 @@ def _parse_page_range(pages: str, total: int) -> tuple[int, int]:
|
|||||||
minimum=1,
|
minimum=1,
|
||||||
),
|
),
|
||||||
pages=StringSchema("Page range for PDF files, e.g. '1-5' (default: all, max 20 pages)"),
|
pages=StringSchema("Page range for PDF files, e.g. '1-5' (default: all, max 20 pages)"),
|
||||||
|
force=BooleanSchema(
|
||||||
|
description="Bypass same-file read deduplication and return content again.",
|
||||||
|
default=False,
|
||||||
|
),
|
||||||
required=["path"],
|
required=["path"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -154,7 +158,11 @@ class ReadFileTool(_FsTool):
|
|||||||
"Text output format: LINE_NUM|CONTENT. "
|
"Text output format: LINE_NUM|CONTENT. "
|
||||||
"Images return visual content for analysis. "
|
"Images return visual content for analysis. "
|
||||||
"Supports PDF, DOCX, XLSX, PPTX documents. "
|
"Supports PDF, DOCX, XLSX, PPTX documents. "
|
||||||
|
"Use find_files/list_dir first when the path is uncertain. "
|
||||||
|
"Read the relevant range before editing so replacements or patches "
|
||||||
|
"are based on current content. "
|
||||||
"Use offset and limit for large text files. "
|
"Use offset and limit for large text files. "
|
||||||
|
"Use force=true to re-read content even if unchanged. "
|
||||||
"Reads exceeding ~128K chars are truncated."
|
"Reads exceeding ~128K chars are truncated."
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -162,7 +170,15 @@ class ReadFileTool(_FsTool):
|
|||||||
def read_only(self) -> bool:
|
def read_only(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, pages: str | None = None, **kwargs: Any) -> Any:
|
async def execute(
|
||||||
|
self,
|
||||||
|
path: str | None = None,
|
||||||
|
offset: int = 1,
|
||||||
|
limit: int | None = None,
|
||||||
|
pages: str | None = None,
|
||||||
|
force: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Any:
|
||||||
try:
|
try:
|
||||||
if not path:
|
if not path:
|
||||||
return "Error reading file: Unknown path"
|
return "Error reading file: Unknown path"
|
||||||
@ -202,7 +218,13 @@ class ReadFileTool(_FsTool):
|
|||||||
current_mtime = os.path.getmtime(fp)
|
current_mtime = os.path.getmtime(fp)
|
||||||
except OSError:
|
except OSError:
|
||||||
current_mtime = 0.0
|
current_mtime = 0.0
|
||||||
if entry and entry.can_dedup and entry.offset == offset and entry.limit == limit:
|
if (
|
||||||
|
not force
|
||||||
|
and entry
|
||||||
|
and entry.can_dedup
|
||||||
|
and entry.offset == offset
|
||||||
|
and entry.limit == limit
|
||||||
|
):
|
||||||
if current_mtime != entry.mtime:
|
if current_mtime != entry.mtime:
|
||||||
# File was modified externally - force full read and mark as not dedupable
|
# File was modified externally - force full read and mark as not dedupable
|
||||||
entry.can_dedup = False
|
entry.can_dedup = False
|
||||||
@ -365,9 +387,10 @@ class WriteFileTool(_FsTool):
|
|||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return (
|
return (
|
||||||
"Write content to a file. Overwrites if the file already exists; "
|
"Create a new file or intentionally replace an entire file with "
|
||||||
"creates parent directories as needed. "
|
"the provided content. Overwrites existing files and creates parent "
|
||||||
"For partial edits, prefer edit_file instead."
|
"directories as needed. For code changes or partial edits, prefer "
|
||||||
|
"apply_patch; use edit_file only for small exact replacements."
|
||||||
)
|
)
|
||||||
|
|
||||||
async def execute(self, path: str | None = None, content: str | None = None, **kwargs: Any) -> str:
|
async def execute(self, path: str | None = None, content: str | None = None, **kwargs: Any) -> str:
|
||||||
@ -657,6 +680,24 @@ def _find_match(content: str, old_text: str) -> tuple[str | None, int]:
|
|||||||
old_text=StringSchema("The text to find and replace"),
|
old_text=StringSchema("The text to find and replace"),
|
||||||
new_text=StringSchema("The text to replace with"),
|
new_text=StringSchema("The text to replace with"),
|
||||||
replace_all=BooleanSchema(description="Replace all occurrences (default false)"),
|
replace_all=BooleanSchema(description="Replace all occurrences (default false)"),
|
||||||
|
occurrence=IntegerSchema(
|
||||||
|
1,
|
||||||
|
description="Optional 1-based occurrence to replace when old_text appears multiple times.",
|
||||||
|
minimum=1,
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
line_hint=IntegerSchema(
|
||||||
|
1,
|
||||||
|
description="Optional 1-based line hint used to choose the nearest match.",
|
||||||
|
minimum=1,
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
expected_replacements=IntegerSchema(
|
||||||
|
1,
|
||||||
|
description="Optional guard for the number of replacements that must be made.",
|
||||||
|
minimum=1,
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
required=["path", "old_text", "new_text"],
|
required=["path", "old_text", "new_text"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -674,10 +715,13 @@ class EditFileTool(_FsTool):
|
|||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return (
|
return (
|
||||||
"Edit a file by replacing old_text with new_text. "
|
"Perform a small, exact replacement in one file by replacing "
|
||||||
"Tolerates minor whitespace/indentation differences and curly/straight quote mismatches. "
|
"old_text with new_text. Use this for narrow text substitutions "
|
||||||
"If old_text matches multiple times, you must provide more context "
|
"with old_text copied from read_file. For multi-file, structural, "
|
||||||
"or set replace_all=true. Shows a diff of the closest match on failure."
|
"or generated code edits, prefer apply_patch. If old_text matches "
|
||||||
|
"multiple times, provide more context or set occurrence, line_hint, "
|
||||||
|
"replace_all, and expected_replacements. Shows closest-match "
|
||||||
|
"diagnostics on failure."
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -688,7 +732,8 @@ class EditFileTool(_FsTool):
|
|||||||
async def execute(
|
async def execute(
|
||||||
self, path: str | None = None, old_text: str | None = None,
|
self, path: str | None = None, old_text: str | None = None,
|
||||||
new_text: str | None = None,
|
new_text: str | None = None,
|
||||||
replace_all: bool = False, **kwargs: Any,
|
replace_all: bool = False, occurrence: int | None = None,
|
||||||
|
line_hint: int | None = None, expected_replacements: int | None = None, **kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
try:
|
try:
|
||||||
if not path:
|
if not path:
|
||||||
@ -697,10 +742,12 @@ class EditFileTool(_FsTool):
|
|||||||
raise ValueError("Unknown old_text")
|
raise ValueError("Unknown old_text")
|
||||||
if new_text is None:
|
if new_text is None:
|
||||||
raise ValueError("Unknown new_text")
|
raise ValueError("Unknown new_text")
|
||||||
|
if occurrence is not None and occurrence < 1:
|
||||||
# .ipynb detection
|
return "Error: occurrence must be >= 1."
|
||||||
if path.endswith(".ipynb"):
|
if line_hint is not None and line_hint < 1:
|
||||||
return "Error: This is a Jupyter notebook. Use the notebook_edit tool instead of edit_file."
|
return "Error: line_hint must be >= 1."
|
||||||
|
if expected_replacements is not None and expected_replacements < 1:
|
||||||
|
return "Error: expected_replacements must be >= 1."
|
||||||
|
|
||||||
fp = self._resolve(path)
|
fp = self._resolve(path)
|
||||||
|
|
||||||
@ -743,15 +790,42 @@ class EditFileTool(_FsTool):
|
|||||||
if not matches:
|
if not matches:
|
||||||
return self._not_found_msg(old_text, content, path)
|
return self._not_found_msg(old_text, content, path)
|
||||||
count = len(matches)
|
count = len(matches)
|
||||||
|
if replace_all and occurrence is not None:
|
||||||
|
return "Error: occurrence cannot be used with replace_all=true."
|
||||||
|
if replace_all and line_hint is not None:
|
||||||
|
return "Error: line_hint cannot be used with replace_all=true."
|
||||||
|
if occurrence is not None and line_hint is not None:
|
||||||
|
return "Error: line_hint cannot be used with occurrence."
|
||||||
if count > 1 and not replace_all:
|
if count > 1 and not replace_all:
|
||||||
line_numbers = [match.line for match in matches]
|
if occurrence is not None:
|
||||||
preview = ", ".join(f"line {n}" for n in line_numbers[:3])
|
if occurrence > count:
|
||||||
if len(line_numbers) > 3:
|
return (
|
||||||
preview += ", ..."
|
f"Error: occurrence {occurrence} is out of range; "
|
||||||
location_hint = f" at {preview}" if preview else ""
|
f"old_text appears {count} times."
|
||||||
|
)
|
||||||
|
elif line_hint is not None:
|
||||||
|
nearest = min(matches, key=lambda match: abs(match.line - line_hint))
|
||||||
|
distance = abs(nearest.line - line_hint)
|
||||||
|
if sum(1 for match in matches if abs(match.line - line_hint) == distance) > 1:
|
||||||
|
return (
|
||||||
|
f"Error: line_hint {line_hint} is ambiguous; "
|
||||||
|
f"old_text appears {count} times."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
line_numbers = [match.line for match in matches]
|
||||||
|
preview = ", ".join(f"line {n}" for n in line_numbers[:3])
|
||||||
|
if len(line_numbers) > 3:
|
||||||
|
preview += ", ..."
|
||||||
|
location_hint = f" at {preview}" if preview else ""
|
||||||
|
return (
|
||||||
|
f"Warning: old_text appears {count} times{location_hint}. "
|
||||||
|
"Provide more context, set occurrence to choose one match, "
|
||||||
|
"or set replace_all=true."
|
||||||
|
)
|
||||||
|
elif occurrence is not None and occurrence > count:
|
||||||
return (
|
return (
|
||||||
f"Warning: old_text appears {count} times{location_hint}. "
|
f"Error: occurrence {occurrence} is out of range; "
|
||||||
"Provide more context to make it unique, or set replace_all=true."
|
f"old_text appears {count} time."
|
||||||
)
|
)
|
||||||
|
|
||||||
norm_new = new_text.replace("\r\n", "\n")
|
norm_new = new_text.replace("\r\n", "\n")
|
||||||
@ -760,7 +834,17 @@ class EditFileTool(_FsTool):
|
|||||||
if fp.suffix.lower() not in self._MARKDOWN_EXTS:
|
if fp.suffix.lower() not in self._MARKDOWN_EXTS:
|
||||||
norm_new = self._strip_trailing_ws(norm_new)
|
norm_new = self._strip_trailing_ws(norm_new)
|
||||||
|
|
||||||
selected = matches if replace_all else matches[:1]
|
if replace_all:
|
||||||
|
selected = matches
|
||||||
|
elif line_hint is not None:
|
||||||
|
selected = [min(matches, key=lambda match: abs(match.line - line_hint))]
|
||||||
|
else:
|
||||||
|
selected = [matches[occurrence - 1 if occurrence else 0]]
|
||||||
|
if expected_replacements is not None and len(selected) != expected_replacements:
|
||||||
|
return (
|
||||||
|
f"Error: expected {expected_replacements} replacements but "
|
||||||
|
f"would make {len(selected)}."
|
||||||
|
)
|
||||||
new_content = content
|
new_content = content
|
||||||
for match in reversed(selected):
|
for match in reversed(selected):
|
||||||
replacement = _preserve_quote_style(norm_old, match.text, norm_new)
|
replacement = _preserve_quote_style(norm_old, match.text, norm_new)
|
||||||
|
|||||||
@ -21,7 +21,6 @@ from nanobot.providers.image_generation import (
|
|||||||
ImageGenerationProvider,
|
ImageGenerationProvider,
|
||||||
get_image_gen_provider,
|
get_image_gen_provider,
|
||||||
)
|
)
|
||||||
from nanobot.providers.registry import find_by_name
|
|
||||||
from nanobot.utils.artifacts import (
|
from nanobot.utils.artifacts import (
|
||||||
ArtifactError,
|
ArtifactError,
|
||||||
generated_image_tool_result,
|
generated_image_tool_result,
|
||||||
@ -118,10 +117,6 @@ class ImageGenerationTool(Tool):
|
|||||||
def _provider_config(self) -> ProviderConfig | None:
|
def _provider_config(self) -> ProviderConfig | None:
|
||||||
return self.provider_configs.get(self.config.provider)
|
return self.provider_configs.get(self.config.provider)
|
||||||
|
|
||||||
def _provider_allows_missing_api_key(self) -> bool:
|
|
||||||
spec = find_by_name(self.config.provider)
|
|
||||||
return bool(spec and (spec.is_local or spec.is_direct or spec.is_oauth))
|
|
||||||
|
|
||||||
def _provider_client(self) -> ImageGenerationProvider | None:
|
def _provider_client(self) -> ImageGenerationProvider | None:
|
||||||
provider = self._provider_config()
|
provider = self._provider_config()
|
||||||
cls = get_image_gen_provider(self.config.provider)
|
cls = get_image_gen_provider(self.config.provider)
|
||||||
@ -135,12 +130,6 @@ class ImageGenerationTool(Tool):
|
|||||||
}
|
}
|
||||||
return cls(**kwargs)
|
return cls(**kwargs)
|
||||||
|
|
||||||
def _missing_api_key_error(self) -> str:
|
|
||||||
cls = get_image_gen_provider(self.config.provider)
|
|
||||||
if cls and cls.missing_key_message:
|
|
||||||
return f"Error: {cls.missing_key_message}"
|
|
||||||
return f"Error: {self.config.provider} API key is not configured."
|
|
||||||
|
|
||||||
def _resolve_reference_image(self, value: str) -> str:
|
def _resolve_reference_image(self, value: str) -> str:
|
||||||
raw_path = Path(value).expanduser()
|
raw_path = Path(value).expanduser()
|
||||||
path = raw_path if raw_path.is_absolute() else self.workspace / raw_path
|
path = raw_path if raw_path.is_absolute() else self.workspace / raw_path
|
||||||
@ -178,9 +167,6 @@ class ImageGenerationTool(Tool):
|
|||||||
client = self._provider_client()
|
client = self._provider_client()
|
||||||
if client is None:
|
if client is None:
|
||||||
return f"Error: unsupported image generation provider '{self.config.provider}'"
|
return f"Error: unsupported image generation provider '{self.config.provider}'"
|
||||||
provider = self._provider_config()
|
|
||||||
if not self._provider_allows_missing_api_key() and (not provider or not provider.api_key):
|
|
||||||
return self._missing_api_key_error()
|
|
||||||
|
|
||||||
requested = count or 1
|
requested = count or 1
|
||||||
if requested > self.config.max_images_per_turn:
|
if requested > self.config.max_images_per_turn:
|
||||||
|
|||||||
@ -1,162 +0,0 @@
|
|||||||
"""NotebookEditTool — edit Jupyter .ipynb notebooks."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import uuid
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from nanobot.agent.tools.base import tool_parameters
|
|
||||||
from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema
|
|
||||||
from nanobot.agent.tools.filesystem import _FsTool
|
|
||||||
|
|
||||||
|
|
||||||
def _new_cell(source: str, cell_type: str = "code", generate_id: bool = False) -> dict:
|
|
||||||
cell: dict[str, Any] = {
|
|
||||||
"cell_type": cell_type,
|
|
||||||
"source": source,
|
|
||||||
"metadata": {},
|
|
||||||
}
|
|
||||||
if cell_type == "code":
|
|
||||||
cell["outputs"] = []
|
|
||||||
cell["execution_count"] = None
|
|
||||||
if generate_id:
|
|
||||||
cell["id"] = uuid.uuid4().hex[:8]
|
|
||||||
return cell
|
|
||||||
|
|
||||||
|
|
||||||
def _make_empty_notebook() -> dict:
|
|
||||||
return {
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5,
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
|
|
||||||
"language_info": {"name": "python"},
|
|
||||||
},
|
|
||||||
"cells": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@tool_parameters(
|
|
||||||
tool_parameters_schema(
|
|
||||||
path=StringSchema("Path to the .ipynb notebook file"),
|
|
||||||
cell_index=IntegerSchema(0, description="0-based index of the cell to edit", minimum=0),
|
|
||||||
new_source=StringSchema("New source content for the cell"),
|
|
||||||
cell_type=StringSchema(
|
|
||||||
"Cell type: 'code' or 'markdown' (default: code)",
|
|
||||||
enum=["code", "markdown"],
|
|
||||||
),
|
|
||||||
edit_mode=StringSchema(
|
|
||||||
"Mode: 'replace' (default), 'insert' (after target), or 'delete'",
|
|
||||||
enum=["replace", "insert", "delete"],
|
|
||||||
),
|
|
||||||
required=["path", "cell_index"],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
class NotebookEditTool(_FsTool):
|
|
||||||
"""Edit Jupyter notebook cells: replace, insert, or delete."""
|
|
||||||
_scopes = {"core"}
|
|
||||||
|
|
||||||
_VALID_CELL_TYPES = frozenset({"code", "markdown"})
|
|
||||||
_VALID_EDIT_MODES = frozenset({"replace", "insert", "delete"})
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return "notebook_edit"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def description(self) -> str:
|
|
||||||
return (
|
|
||||||
"Edit a Jupyter notebook (.ipynb) cell. "
|
|
||||||
"Modes: replace (default) replaces cell content, "
|
|
||||||
"insert adds a new cell after the target index, "
|
|
||||||
"delete removes the cell at the index. "
|
|
||||||
"cell_index is 0-based."
|
|
||||||
)
|
|
||||||
|
|
||||||
async def execute(
|
|
||||||
self,
|
|
||||||
path: str | None = None,
|
|
||||||
cell_index: int = 0,
|
|
||||||
new_source: str = "",
|
|
||||||
cell_type: str = "code",
|
|
||||||
edit_mode: str = "replace",
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> str:
|
|
||||||
try:
|
|
||||||
if not path:
|
|
||||||
return "Error: path is required"
|
|
||||||
|
|
||||||
if not path.endswith(".ipynb"):
|
|
||||||
return "Error: notebook_edit only works on .ipynb files. Use edit_file for other files."
|
|
||||||
|
|
||||||
if edit_mode not in self._VALID_EDIT_MODES:
|
|
||||||
return (
|
|
||||||
f"Error: Invalid edit_mode '{edit_mode}'. "
|
|
||||||
"Use one of: replace, insert, delete."
|
|
||||||
)
|
|
||||||
|
|
||||||
if cell_type not in self._VALID_CELL_TYPES:
|
|
||||||
return (
|
|
||||||
f"Error: Invalid cell_type '{cell_type}'. "
|
|
||||||
"Use one of: code, markdown."
|
|
||||||
)
|
|
||||||
|
|
||||||
fp = self._resolve(path)
|
|
||||||
|
|
||||||
# Create new notebook if file doesn't exist and mode is insert
|
|
||||||
if not fp.exists():
|
|
||||||
if edit_mode != "insert":
|
|
||||||
return f"Error: File not found: {path}"
|
|
||||||
nb = _make_empty_notebook()
|
|
||||||
cell = _new_cell(new_source, cell_type, generate_id=True)
|
|
||||||
nb["cells"].append(cell)
|
|
||||||
fp.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
|
|
||||||
return f"Successfully created {fp} with 1 cell"
|
|
||||||
|
|
||||||
try:
|
|
||||||
nb = json.loads(fp.read_text(encoding="utf-8"))
|
|
||||||
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
|
||||||
return f"Error: Failed to parse notebook: {e}"
|
|
||||||
|
|
||||||
cells = nb.get("cells", [])
|
|
||||||
nbformat_minor = nb.get("nbformat_minor", 0)
|
|
||||||
generate_id = nb.get("nbformat", 0) >= 4 and nbformat_minor >= 5
|
|
||||||
|
|
||||||
if edit_mode == "delete":
|
|
||||||
if cell_index < 0 or cell_index >= len(cells):
|
|
||||||
return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)"
|
|
||||||
cells.pop(cell_index)
|
|
||||||
nb["cells"] = cells
|
|
||||||
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
|
|
||||||
return f"Successfully deleted cell {cell_index} from {fp}"
|
|
||||||
|
|
||||||
if edit_mode == "insert":
|
|
||||||
insert_at = min(cell_index + 1, len(cells))
|
|
||||||
cell = _new_cell(new_source, cell_type, generate_id=generate_id)
|
|
||||||
cells.insert(insert_at, cell)
|
|
||||||
nb["cells"] = cells
|
|
||||||
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
|
|
||||||
return f"Successfully inserted cell at index {insert_at} in {fp}"
|
|
||||||
|
|
||||||
# Default: replace
|
|
||||||
if cell_index < 0 or cell_index >= len(cells):
|
|
||||||
return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)"
|
|
||||||
cells[cell_index]["source"] = new_source
|
|
||||||
if cell_type and cells[cell_index].get("cell_type") != cell_type:
|
|
||||||
cells[cell_index]["cell_type"] = cell_type
|
|
||||||
if cell_type == "code":
|
|
||||||
cells[cell_index].setdefault("outputs", [])
|
|
||||||
cells[cell_index].setdefault("execution_count", None)
|
|
||||||
elif "outputs" in cells[cell_index]:
|
|
||||||
del cells[cell_index]["outputs"]
|
|
||||||
cells[cell_index].pop("execution_count", None)
|
|
||||||
nb["cells"] = cells
|
|
||||||
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
|
|
||||||
return f"Successfully edited cell {cell_index} in {fp}"
|
|
||||||
|
|
||||||
except PermissionError as e:
|
|
||||||
return f"Error: {e}"
|
|
||||||
except Exception as e:
|
|
||||||
return f"Error editing notebook: {e}"
|
|
||||||
@ -1,4 +1,4 @@
|
|||||||
"""Search tools: grep."""
|
"""Search tools: file discovery and grep."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@ -12,6 +12,7 @@ from typing import Any, Iterable, TypeVar
|
|||||||
from nanobot.agent.tools.filesystem import ListDirTool, _FsTool
|
from nanobot.agent.tools.filesystem import ListDirTool, _FsTool
|
||||||
|
|
||||||
_DEFAULT_HEAD_LIMIT = 250
|
_DEFAULT_HEAD_LIMIT = 250
|
||||||
|
_DEFAULT_FILE_HEAD_LIMIT = 200
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
_TYPE_GLOB_MAP = {
|
_TYPE_GLOB_MAP = {
|
||||||
"py": ("*.py", "*.pyi"),
|
"py": ("*.py", "*.pyi"),
|
||||||
@ -88,6 +89,14 @@ def _matches_type(name: str, file_type: str | None) -> bool:
|
|||||||
return any(fnmatch.fnmatch(name.lower(), pattern.lower()) for pattern in patterns)
|
return any(fnmatch.fnmatch(name.lower(), pattern.lower()) for pattern in patterns)
|
||||||
|
|
||||||
|
|
||||||
|
def _matches_query(rel_path: str, query: str | None) -> bool:
|
||||||
|
if not query:
|
||||||
|
return True
|
||||||
|
haystack = rel_path.lower()
|
||||||
|
terms = [part for part in query.lower().split() if part]
|
||||||
|
return all(term in haystack for term in terms)
|
||||||
|
|
||||||
|
|
||||||
class _SearchTool(_FsTool):
|
class _SearchTool(_FsTool):
|
||||||
_IGNORE_DIRS = set(ListDirTool._IGNORE_DIRS)
|
_IGNORE_DIRS = set(ListDirTool._IGNORE_DIRS)
|
||||||
|
|
||||||
@ -109,6 +118,163 @@ class _SearchTool(_FsTool):
|
|||||||
yield current / filename
|
yield current / filename
|
||||||
|
|
||||||
|
|
||||||
|
class FindFilesTool(_SearchTool):
|
||||||
|
"""Find files by path fragment, glob, or type."""
|
||||||
|
_scopes = {"core", "subagent"}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "find_files"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Find files by path fragment, glob, or file type. "
|
||||||
|
"Use this before read_file when you need to locate files, and "
|
||||||
|
"prefer it over shell find/ls for ordinary workspace discovery. "
|
||||||
|
"Returns workspace-relative paths and skips common dependency/build "
|
||||||
|
"directories."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def read_only(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parameters(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Directory or file to search in (default '.')",
|
||||||
|
},
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": (
|
||||||
|
"Optional case-insensitive path fragment search. "
|
||||||
|
"Whitespace-separated terms must all be present."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
"glob": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Optional file filter, e.g. '*.py' or 'tests/**/test_*.py'",
|
||||||
|
},
|
||||||
|
"type": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Optional file type shorthand, e.g. 'py', 'ts', 'md', 'json'",
|
||||||
|
},
|
||||||
|
"include_dirs": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Include matching directories as well as files (default false)",
|
||||||
|
},
|
||||||
|
"sort": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["path", "modified"],
|
||||||
|
"description": "Sort by path or most recently modified first (default path)",
|
||||||
|
},
|
||||||
|
"head_limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Maximum number of paths to return (default 200, 0 for all, max 1000)",
|
||||||
|
"minimum": 0,
|
||||||
|
"maximum": 1000,
|
||||||
|
},
|
||||||
|
"offset": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Skip the first N results before applying head_limit",
|
||||||
|
"minimum": 0,
|
||||||
|
"maximum": 100000,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def _iter_paths(self, root: Path, *, include_dirs: bool) -> Iterable[Path]:
|
||||||
|
if root.is_file():
|
||||||
|
yield root
|
||||||
|
return
|
||||||
|
if include_dirs:
|
||||||
|
yield root
|
||||||
|
for dirpath, dirnames, filenames in os.walk(root):
|
||||||
|
dirnames[:] = sorted(d for d in dirnames if d not in self._IGNORE_DIRS)
|
||||||
|
current = Path(dirpath)
|
||||||
|
if include_dirs and current != root:
|
||||||
|
yield current
|
||||||
|
for filename in sorted(filenames):
|
||||||
|
yield current / filename
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
path: str = ".",
|
||||||
|
query: str | None = None,
|
||||||
|
glob: str | None = None,
|
||||||
|
type: str | None = None,
|
||||||
|
include_dirs: bool = False,
|
||||||
|
sort: str = "path",
|
||||||
|
head_limit: int | None = None,
|
||||||
|
offset: int = 0,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> str:
|
||||||
|
try:
|
||||||
|
target = self._resolve(path or ".")
|
||||||
|
if not target.exists():
|
||||||
|
return f"Error: Path not found: {path}"
|
||||||
|
if not (target.is_dir() or target.is_file()):
|
||||||
|
return f"Error: Unsupported path: {path}"
|
||||||
|
|
||||||
|
if sort not in {"path", "modified"}:
|
||||||
|
return "Error: sort must be 'path' or 'modified'"
|
||||||
|
|
||||||
|
limit = (
|
||||||
|
_DEFAULT_FILE_HEAD_LIMIT
|
||||||
|
if head_limit is None
|
||||||
|
else None if head_limit == 0 else head_limit
|
||||||
|
)
|
||||||
|
root = target if target.is_dir() else target.parent
|
||||||
|
matches: list[tuple[str, float]] = []
|
||||||
|
|
||||||
|
for candidate in self._iter_paths(target, include_dirs=include_dirs):
|
||||||
|
if candidate.is_dir() and not include_dirs:
|
||||||
|
continue
|
||||||
|
rel_path = candidate.relative_to(root).as_posix()
|
||||||
|
display_path = self._display_path(candidate, root)
|
||||||
|
name = candidate.name
|
||||||
|
|
||||||
|
if glob and not _match_glob(rel_path, name, glob):
|
||||||
|
continue
|
||||||
|
if candidate.is_file() and not _matches_type(name, type):
|
||||||
|
continue
|
||||||
|
if candidate.is_dir() and type:
|
||||||
|
continue
|
||||||
|
if not _matches_query(display_path, query):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
mtime = candidate.stat().st_mtime
|
||||||
|
except OSError:
|
||||||
|
mtime = 0.0
|
||||||
|
suffix = "/" if candidate.is_dir() else ""
|
||||||
|
matches.append((display_path + suffix, mtime))
|
||||||
|
|
||||||
|
if sort == "modified":
|
||||||
|
matches.sort(key=lambda item: (-item[1], item[0]))
|
||||||
|
else:
|
||||||
|
matches.sort(key=lambda item: item[0])
|
||||||
|
|
||||||
|
paths = [item[0] for item in matches]
|
||||||
|
paged, truncated = _paginate(paths, limit, offset)
|
||||||
|
if not paged:
|
||||||
|
return "No files found"
|
||||||
|
|
||||||
|
result = "\n".join(paged)
|
||||||
|
note = _pagination_note(limit, offset, truncated)
|
||||||
|
if note:
|
||||||
|
result += "\n\n" + note
|
||||||
|
return result
|
||||||
|
except PermissionError as e:
|
||||||
|
return f"Error: {e}"
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error finding files: {e}"
|
||||||
|
|
||||||
|
|
||||||
class GrepTool(_SearchTool):
|
class GrepTool(_SearchTool):
|
||||||
"""Search file contents using a regex-like pattern."""
|
"""Search file contents using a regex-like pattern."""
|
||||||
_scopes = {"core", "subagent"}
|
_scopes = {"core", "subagent"}
|
||||||
@ -125,7 +291,8 @@ class GrepTool(_SearchTool):
|
|||||||
return (
|
return (
|
||||||
"Search file contents with a regex pattern. "
|
"Search file contents with a regex pattern. "
|
||||||
"Default output_mode is files_with_matches (file paths only); "
|
"Default output_mode is files_with_matches (file paths only); "
|
||||||
"use content mode for matching lines with context. "
|
"use content mode for matching lines with context. Prefer this "
|
||||||
|
"over shell grep for ordinary workspace searches. "
|
||||||
"Skips binary and files >2 MB. Supports glob/type filtering."
|
"Skips binary and files >2 MB. Supports glob/type filtering."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import re
|
|||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@ -15,8 +16,17 @@ from loguru import logger
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||||
|
from nanobot.agent.tools.exec_session import (
|
||||||
|
DEFAULT_MAX_OUTPUT_CHARS,
|
||||||
|
DEFAULT_YIELD_MS,
|
||||||
|
DEFAULT_EXEC_SESSION_MANAGER,
|
||||||
|
MAX_OUTPUT_CHARS,
|
||||||
|
MAX_YIELD_MS,
|
||||||
|
clamp_session_int,
|
||||||
|
format_session_poll,
|
||||||
|
)
|
||||||
from nanobot.agent.tools.sandbox import wrap_command
|
from nanobot.agent.tools.sandbox import wrap_command
|
||||||
from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema
|
from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema
|
||||||
from nanobot.config.paths import get_media_dir
|
from nanobot.config.paths import get_media_dir
|
||||||
from nanobot.config.schema import Base
|
from nanobot.config.schema import Base
|
||||||
|
|
||||||
@ -44,10 +54,22 @@ class ExecToolConfig(Base):
|
|||||||
deny_patterns: list[str] = Field(default_factory=list)
|
deny_patterns: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class _PreparedCommand:
|
||||||
|
command: str
|
||||||
|
cwd: str
|
||||||
|
env: dict[str, str]
|
||||||
|
timeout: int
|
||||||
|
shell_program: str | None
|
||||||
|
login: bool
|
||||||
|
|
||||||
|
|
||||||
@tool_parameters(
|
@tool_parameters(
|
||||||
tool_parameters_schema(
|
tool_parameters_schema(
|
||||||
command=StringSchema("The shell command to execute"),
|
command=StringSchema("The shell command to execute"),
|
||||||
|
cmd=StringSchema("Compatibility alias for command"),
|
||||||
working_dir=StringSchema("Optional working directory for the command"),
|
working_dir=StringSchema("Optional working directory for the command"),
|
||||||
|
workdir=StringSchema("Compatibility alias for working_dir"),
|
||||||
timeout=IntegerSchema(
|
timeout=IntegerSchema(
|
||||||
60,
|
60,
|
||||||
description=(
|
description=(
|
||||||
@ -57,7 +79,44 @@ class ExecToolConfig(Base):
|
|||||||
minimum=1,
|
minimum=1,
|
||||||
maximum=600,
|
maximum=600,
|
||||||
),
|
),
|
||||||
required=["command"],
|
shell=StringSchema(
|
||||||
|
"Optional shell binary to launch. On Unix, supports sh, bash, or zsh.",
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
login=BooleanSchema(
|
||||||
|
description="Whether to run bash/zsh with login shell semantics (default true).",
|
||||||
|
default=True,
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
yield_time_ms=IntegerSchema(
|
||||||
|
description=(
|
||||||
|
"Optional milliseconds to wait before returning output. "
|
||||||
|
"When set, a still-running command returns a session_id that "
|
||||||
|
"can be polled or written to with write_stdin. Omit this field "
|
||||||
|
"to keep one-shot exec behavior."
|
||||||
|
),
|
||||||
|
minimum=0,
|
||||||
|
maximum=MAX_YIELD_MS,
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
max_output_chars=IntegerSchema(
|
||||||
|
description=(
|
||||||
|
"Maximum output characters to return when yield_time_ms is used "
|
||||||
|
"(default 10000, max 50000)."
|
||||||
|
),
|
||||||
|
minimum=1000,
|
||||||
|
maximum=MAX_OUTPUT_CHARS,
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
max_output_tokens=IntegerSchema(
|
||||||
|
description=(
|
||||||
|
"Compatibility alias for max_output_chars. The current runtime "
|
||||||
|
"uses a character budget."
|
||||||
|
),
|
||||||
|
minimum=1000,
|
||||||
|
maximum=MAX_OUTPUT_CHARS,
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
class ExecTool(Tool):
|
class ExecTool(Tool):
|
||||||
@ -98,6 +157,7 @@ class ExecTool(Tool):
|
|||||||
sandbox: str = "",
|
sandbox: str = "",
|
||||||
path_append: str = "",
|
path_append: str = "",
|
||||||
allowed_env_keys: list[str] | None = None,
|
allowed_env_keys: list[str] | None = None,
|
||||||
|
session_manager: Any | None = None,
|
||||||
):
|
):
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.working_dir = working_dir
|
self.working_dir = working_dir
|
||||||
@ -125,6 +185,7 @@ class ExecTool(Tool):
|
|||||||
self.restrict_to_workspace = restrict_to_workspace
|
self.restrict_to_workspace = restrict_to_workspace
|
||||||
self.path_append = path_append
|
self.path_append = path_append
|
||||||
self.allowed_env_keys = allowed_env_keys or []
|
self.allowed_env_keys = allowed_env_keys or []
|
||||||
|
self._session_manager = session_manager or DEFAULT_EXEC_SESSION_MANAGER
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
@ -150,10 +211,15 @@ class ExecTool(Tool):
|
|||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return (
|
return (
|
||||||
"Execute a shell command and return its output. "
|
"Execute a shell command and return its output. "
|
||||||
"Prefer read_file/write_file/edit_file over cat/echo/sed, "
|
"Use this for tests, builds, package commands, git commands, and "
|
||||||
"and grep/glob over shell find/grep. "
|
"other process execution. Prefer read_file/find_files/grep for "
|
||||||
|
"inspection and apply_patch/write_file/edit_file for file changes "
|
||||||
|
"instead of cat, shell find/grep, echo, or sed. "
|
||||||
"Use -y or --yes flags to avoid interactive prompts. "
|
"Use -y or --yes flags to avoid interactive prompts. "
|
||||||
"Output is truncated at 10 000 chars; timeout defaults to 60s."
|
"For long-running or interactive commands, pass yield_time_ms; "
|
||||||
|
"if the command keeps running, exec returns a session_id that can "
|
||||||
|
"be polled or written to with write_stdin. Output is truncated at "
|
||||||
|
"10 000 chars; timeout defaults to 60s."
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -161,9 +227,111 @@ class ExecTool(Tool):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self, command: str, working_dir: str | None = None,
|
self, command: str | None = None, cmd: str | None = None,
|
||||||
timeout: int | None = None, **kwargs: Any,
|
working_dir: str | None = None, workdir: str | None = None,
|
||||||
|
timeout: int | None = None, shell: str | None = None,
|
||||||
|
login: bool | None = None, yield_time_ms: int | None = None,
|
||||||
|
max_output_chars: int | None = None,
|
||||||
|
max_output_tokens: int | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
command = command or cmd
|
||||||
|
working_dir = working_dir or workdir
|
||||||
|
if not command:
|
||||||
|
return "Error: Missing command. Provide command or cmd."
|
||||||
|
if max_output_chars is None:
|
||||||
|
max_output_chars = max_output_tokens
|
||||||
|
|
||||||
|
prepared = self._prepare_command(command, working_dir, timeout, shell, login)
|
||||||
|
if isinstance(prepared, str):
|
||||||
|
return prepared
|
||||||
|
|
||||||
|
if yield_time_ms is not None:
|
||||||
|
return await self._execute_session(prepared, yield_time_ms, max_output_chars)
|
||||||
|
|
||||||
|
try:
|
||||||
|
process = await self._spawn(
|
||||||
|
prepared.command,
|
||||||
|
prepared.cwd,
|
||||||
|
prepared.env,
|
||||||
|
prepared.shell_program,
|
||||||
|
prepared.login,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
stdout, stderr = await asyncio.wait_for(
|
||||||
|
process.communicate(),
|
||||||
|
timeout=prepared.timeout,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
await self._kill_process(process)
|
||||||
|
return f"Error: Command timed out after {prepared.timeout} seconds"
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
await self._kill_process(process)
|
||||||
|
raise
|
||||||
|
|
||||||
|
output_parts = []
|
||||||
|
|
||||||
|
if stdout:
|
||||||
|
output_parts.append(stdout.decode("utf-8", errors="replace"))
|
||||||
|
|
||||||
|
if stderr:
|
||||||
|
stderr_text = stderr.decode("utf-8", errors="replace")
|
||||||
|
if stderr_text.strip():
|
||||||
|
output_parts.append(f"STDERR:\n{stderr_text}")
|
||||||
|
|
||||||
|
output_parts.append(f"\nExit code: {process.returncode}")
|
||||||
|
|
||||||
|
result = "\n".join(output_parts) if output_parts else "(no output)"
|
||||||
|
|
||||||
|
max_len = clamp_session_int(max_output_chars, self._MAX_OUTPUT, 1000, MAX_OUTPUT_CHARS)
|
||||||
|
if len(result) > max_len:
|
||||||
|
half = max_len // 2
|
||||||
|
result = (
|
||||||
|
result[:half]
|
||||||
|
+ f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n"
|
||||||
|
+ result[-half:]
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error executing command: {str(e)}"
|
||||||
|
|
||||||
|
async def _execute_session(
|
||||||
|
self,
|
||||||
|
prepared: _PreparedCommand,
|
||||||
|
yield_time_ms: int | None,
|
||||||
|
max_output_chars: int | None,
|
||||||
|
) -> str:
|
||||||
|
try:
|
||||||
|
session_id, poll = await self._session_manager.start(
|
||||||
|
command=prepared.command,
|
||||||
|
cwd=prepared.cwd,
|
||||||
|
env=prepared.env,
|
||||||
|
timeout=prepared.timeout,
|
||||||
|
shell_program=prepared.shell_program,
|
||||||
|
login=prepared.login,
|
||||||
|
yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS),
|
||||||
|
max_output_chars=clamp_session_int(
|
||||||
|
max_output_chars,
|
||||||
|
DEFAULT_MAX_OUTPUT_CHARS,
|
||||||
|
1000,
|
||||||
|
MAX_OUTPUT_CHARS,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return format_session_poll(session_id, poll)
|
||||||
|
except Exception as exc:
|
||||||
|
return f"Error executing command: {exc}"
|
||||||
|
|
||||||
|
def _prepare_command(
|
||||||
|
self,
|
||||||
|
command: str,
|
||||||
|
working_dir: str | None = None,
|
||||||
|
timeout: int | None = None,
|
||||||
|
shell: str | None = None,
|
||||||
|
login: bool | None = None,
|
||||||
|
) -> _PreparedCommand | str:
|
||||||
cwd = working_dir or self.working_dir or os.getcwd()
|
cwd = working_dir or self.working_dir or os.getcwd()
|
||||||
|
|
||||||
# Prevent an LLM-supplied working_dir from escaping the configured
|
# Prevent an LLM-supplied working_dir from escaping the configured
|
||||||
@ -211,52 +379,24 @@ class ExecTool(Tool):
|
|||||||
env["NANOBOT_PATH_APPEND"] = self.path_append
|
env["NANOBOT_PATH_APPEND"] = self.path_append
|
||||||
command = f'export PATH="$PATH{os.pathsep}$NANOBOT_PATH_APPEND"; {command}'
|
command = f'export PATH="$PATH{os.pathsep}$NANOBOT_PATH_APPEND"; {command}'
|
||||||
|
|
||||||
try:
|
shell_program, shell_error = self._resolve_shell(shell)
|
||||||
process = await self._spawn(command, cwd, env)
|
if shell_error:
|
||||||
|
return shell_error
|
||||||
|
|
||||||
try:
|
return _PreparedCommand(
|
||||||
stdout, stderr = await asyncio.wait_for(
|
command=command,
|
||||||
process.communicate(),
|
cwd=cwd,
|
||||||
timeout=effective_timeout,
|
env=env,
|
||||||
)
|
timeout=effective_timeout,
|
||||||
except asyncio.TimeoutError:
|
shell_program=shell_program,
|
||||||
await self._kill_process(process)
|
login=True if login is None else login,
|
||||||
return f"Error: Command timed out after {effective_timeout} seconds"
|
)
|
||||||
except asyncio.CancelledError:
|
|
||||||
await self._kill_process(process)
|
|
||||||
raise
|
|
||||||
|
|
||||||
output_parts = []
|
|
||||||
|
|
||||||
if stdout:
|
|
||||||
output_parts.append(stdout.decode("utf-8", errors="replace"))
|
|
||||||
|
|
||||||
if stderr:
|
|
||||||
stderr_text = stderr.decode("utf-8", errors="replace")
|
|
||||||
if stderr_text.strip():
|
|
||||||
output_parts.append(f"STDERR:\n{stderr_text}")
|
|
||||||
|
|
||||||
output_parts.append(f"\nExit code: {process.returncode}")
|
|
||||||
|
|
||||||
result = "\n".join(output_parts) if output_parts else "(no output)"
|
|
||||||
|
|
||||||
max_len = self._MAX_OUTPUT
|
|
||||||
if len(result) > max_len:
|
|
||||||
half = max_len // 2
|
|
||||||
result = (
|
|
||||||
result[:half]
|
|
||||||
+ f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n"
|
|
||||||
+ result[-half:]
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return f"Error executing command: {str(e)}"
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _spawn(
|
async def _spawn(
|
||||||
command: str, cwd: str, env: dict[str, str],
|
command: str, cwd: str, env: dict[str, str],
|
||||||
|
shell_program: str | None = None,
|
||||||
|
login: bool = True,
|
||||||
) -> asyncio.subprocess.Process:
|
) -> asyncio.subprocess.Process:
|
||||||
"""Launch *command* in a platform-appropriate shell."""
|
"""Launch *command* in a platform-appropriate shell."""
|
||||||
if _IS_WINDOWS:
|
if _IS_WINDOWS:
|
||||||
@ -272,9 +412,14 @@ class ExecTool(Tool):
|
|||||||
cwd=cwd,
|
cwd=cwd,
|
||||||
env=env,
|
env=env,
|
||||||
)
|
)
|
||||||
bash = shutil.which("bash") or "/bin/bash"
|
shell_program = shell_program or shutil.which("bash") or "/bin/bash"
|
||||||
|
args = [shell_program]
|
||||||
|
shell_name = Path(shell_program).name.lower()
|
||||||
|
if login and shell_name in {"bash", "bash.exe", "zsh", "zsh.exe"}:
|
||||||
|
args.append("-l")
|
||||||
|
args.extend(["-c", command])
|
||||||
return await asyncio.create_subprocess_exec(
|
return await asyncio.create_subprocess_exec(
|
||||||
bash, "-l", "-c", command,
|
*args,
|
||||||
stdin=asyncio.subprocess.DEVNULL,
|
stdin=asyncio.subprocess.DEVNULL,
|
||||||
stdout=asyncio.subprocess.PIPE,
|
stdout=asyncio.subprocess.PIPE,
|
||||||
stderr=asyncio.subprocess.PIPE,
|
stderr=asyncio.subprocess.PIPE,
|
||||||
@ -282,6 +427,31 @@ class ExecTool(Tool):
|
|||||||
env=env,
|
env=env,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resolve_shell(shell: str | None) -> tuple[str | None, str | None]:
|
||||||
|
if not shell:
|
||||||
|
return None, None
|
||||||
|
if _IS_WINDOWS:
|
||||||
|
return None, "Error: shell parameter is not supported on Windows"
|
||||||
|
if "\0" in shell or "\n" in shell or "\r" in shell:
|
||||||
|
return None, "Error: shell contains invalid characters"
|
||||||
|
allowed = {"sh", "bash", "zsh"}
|
||||||
|
path = Path(shell).expanduser()
|
||||||
|
if path.is_absolute():
|
||||||
|
if path.name not in allowed:
|
||||||
|
return None, f"Error: unsupported shell {shell!r}. Allowed: bash, sh, zsh"
|
||||||
|
if not path.is_file() or not os.access(path, os.X_OK):
|
||||||
|
return None, f"Error: shell is not executable: {shell}"
|
||||||
|
return str(path), None
|
||||||
|
if "/" in shell or "\\" in shell:
|
||||||
|
return None, "Error: shell must be a shell name or absolute path"
|
||||||
|
if shell not in allowed:
|
||||||
|
return None, f"Error: unsupported shell {shell!r}. Allowed: bash, sh, zsh"
|
||||||
|
resolved = shutil.which(shell)
|
||||||
|
if not resolved:
|
||||||
|
return None, f"Error: shell not found: {shell}"
|
||||||
|
return resolved, None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _kill_process(process: asyncio.subprocess.Process) -> None:
|
async def _kill_process(process: asyncio.subprocess.Process) -> None:
|
||||||
"""Kill a subprocess and reap it to prevent zombies."""
|
"""Kill a subprocess and reap it to prevent zombies."""
|
||||||
@ -418,7 +588,7 @@ class ExecTool(Tool):
|
|||||||
# Windows: match drive-root paths like `C:\` as well as `C:\path\to\file`, and UNC paths like `\\server\share`
|
# Windows: match drive-root paths like `C:\` as well as `C:\path\to\file`, and UNC paths like `\\server\share`
|
||||||
# NOTE: `*` is required so `C:\` (nothing after the slash) is still extracted.
|
# NOTE: `*` is required so `C:\` (nothing after the slash) is still extracted.
|
||||||
win_paths = re.findall(
|
win_paths = re.findall(
|
||||||
r"(?:[A-Za-z]:[^\s\"'|><;]*|\\\\[^\s\"'|><;]+(?:\\[^\s\"'|><;]+)*)",
|
r"(?<![A-Za-z])(?:[A-Za-z]:[^\s\"'|><;]*|\\\\[^\s\"'|><;]+(?:\\[^\s\"'|><;]+)*)",
|
||||||
command
|
command
|
||||||
)
|
)
|
||||||
posix_paths = re.findall(r"(?:^|[\s|>'\"])(/[^\s\"'>;|<]+)", command) # POSIX: /absolute only
|
posix_paths = re.findall(r"(?:^|[\s|>'\"])(/[^\s\"'>;|<]+)", command) # POSIX: /absolute only
|
||||||
|
|||||||
1402
nanobot/channels/signal.py
Normal file
1402
nanobot/channels/signal.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -79,6 +79,12 @@ BASE_INFO: dict[str, str] = {"channel_version": WEIXIN_CHANNEL_VERSION}
|
|||||||
ERRCODE_SESSION_EXPIRED = -14
|
ERRCODE_SESSION_EXPIRED = -14
|
||||||
SESSION_PAUSE_DURATION_S = 60 * 60
|
SESSION_PAUSE_DURATION_S = 60 * 60
|
||||||
|
|
||||||
|
# iLink context_token is observed to expire server-side after ~90-160s of
|
||||||
|
# agent inactivity (openclaw/openclaw#61174). Proactively refresh before
|
||||||
|
# sending if the cached token is older than this threshold.
|
||||||
|
CONTEXT_TOKEN_MAX_AGE_S = 60
|
||||||
|
|
||||||
|
|
||||||
# Retry constants (matching the reference plugin's monitor.ts)
|
# Retry constants (matching the reference plugin's monitor.ts)
|
||||||
MAX_CONSECUTIVE_FAILURES = 3
|
MAX_CONSECUTIVE_FAILURES = 3
|
||||||
BACKOFF_DELAY_S = 30
|
BACKOFF_DELAY_S = 30
|
||||||
@ -159,6 +165,8 @@ class WeixinChannel(BaseChannel):
|
|||||||
self._session_pause_until: float = 0.0
|
self._session_pause_until: float = 0.0
|
||||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||||
self._typing_tickets: dict[str, dict[str, Any]] = {}
|
self._typing_tickets: dict[str, dict[str, Any]] = {}
|
||||||
|
self._context_token_at: dict[str, float] = {}
|
||||||
|
self._pending_tool_hints: dict[str, list[str]] = {}
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# State persistence
|
# State persistence
|
||||||
@ -486,6 +494,7 @@ class WeixinChannel(BaseChannel):
|
|||||||
except Exception:
|
except Exception:
|
||||||
if not self._running:
|
if not self._running:
|
||||||
break
|
break
|
||||||
|
self.logger.exception("WeChat poll loop error")
|
||||||
consecutive_failures += 1
|
consecutive_failures += 1
|
||||||
if consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
|
if consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
|
||||||
consecutive_failures = 0
|
consecutive_failures = 0
|
||||||
@ -495,6 +504,7 @@ class WeixinChannel(BaseChannel):
|
|||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
self._running = False
|
self._running = False
|
||||||
|
self._pending_tool_hints.clear()
|
||||||
if self._poll_task and not self._poll_task.done():
|
if self._poll_task and not self._poll_task.done():
|
||||||
self._poll_task.cancel()
|
self._poll_task.cancel()
|
||||||
for chat_id in list(self._typing_tasks):
|
for chat_id in list(self._typing_tasks):
|
||||||
@ -545,6 +555,7 @@ class WeixinChannel(BaseChannel):
|
|||||||
# Check for API-level errors (monitor.ts checks both ret and errcode)
|
# Check for API-level errors (monitor.ts checks both ret and errcode)
|
||||||
ret = data.get("ret", 0)
|
ret = data.get("ret", 0)
|
||||||
errcode = data.get("errcode", 0)
|
errcode = data.get("errcode", 0)
|
||||||
|
|
||||||
is_error = (ret is not None and ret != 0) or (errcode is not None and errcode != 0)
|
is_error = (ret is not None and ret != 0) or (errcode is not None and errcode != 0)
|
||||||
|
|
||||||
if is_error:
|
if is_error:
|
||||||
@ -575,8 +586,10 @@ class WeixinChannel(BaseChannel):
|
|||||||
# Process messages (WeixinMessage[] from types.ts)
|
# Process messages (WeixinMessage[] from types.ts)
|
||||||
msgs: list[dict] = data.get("msgs", []) or []
|
msgs: list[dict] = data.get("msgs", []) or []
|
||||||
for msg in msgs:
|
for msg in msgs:
|
||||||
with suppress(Exception):
|
try:
|
||||||
await self._process_message(msg)
|
await self._process_message(msg)
|
||||||
|
except Exception:
|
||||||
|
self.logger.exception("Failed to process WeChat message")
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Inbound message processing (matches inbound.ts + process-message.ts)
|
# Inbound message processing (matches inbound.ts + process-message.ts)
|
||||||
@ -610,6 +623,7 @@ class WeixinChannel(BaseChannel):
|
|||||||
ctx_token = msg.get("context_token", "")
|
ctx_token = msg.get("context_token", "")
|
||||||
if ctx_token:
|
if ctx_token:
|
||||||
self._context_tokens[from_user_id] = ctx_token
|
self._context_tokens[from_user_id] = ctx_token
|
||||||
|
self._context_token_at[from_user_id] = time.time()
|
||||||
self._save_state()
|
self._save_state()
|
||||||
|
|
||||||
# Parse item_list (WeixinMessage.item_list — types.ts:161)
|
# Parse item_list (WeixinMessage.item_list — types.ts:161)
|
||||||
@ -915,6 +929,99 @@ class WeixinChannel(BaseChannel):
|
|||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
async def _refresh_context_token_if_stale(
|
||||||
|
self, chat_id: str, context_token: str
|
||||||
|
) -> str:
|
||||||
|
"""Return a fresh context_token if the cached one is too old.
|
||||||
|
|
||||||
|
iLink context_token expires server-side after a short idle period
|
||||||
|
(empirically ~90s). Proactively refreshing before sending prevents
|
||||||
|
silent message loss on long agent turns or cron pushes.
|
||||||
|
"""
|
||||||
|
if not context_token:
|
||||||
|
return context_token
|
||||||
|
|
||||||
|
now = time.time()
|
||||||
|
cached_at = self._context_token_at.get(chat_id, 0)
|
||||||
|
age = now - cached_at
|
||||||
|
|
||||||
|
if age < CONTEXT_TOKEN_MAX_AGE_S:
|
||||||
|
return context_token
|
||||||
|
|
||||||
|
self.logger.debug(
|
||||||
|
"WeChat context_token for {} is {:.0f}s old; refreshing via getconfig",
|
||||||
|
chat_id,
|
||||||
|
age,
|
||||||
|
)
|
||||||
|
|
||||||
|
body: dict[str, Any] = {
|
||||||
|
"ilink_user_id": chat_id,
|
||||||
|
"context_token": context_token,
|
||||||
|
"base_info": BASE_INFO,
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
data = await self._api_post("ilink/bot/getconfig", body)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.warning("WeChat getconfig failed for {}: {}", chat_id, e)
|
||||||
|
return context_token
|
||||||
|
|
||||||
|
if data.get("ret", 0) != 0:
|
||||||
|
self.logger.warning(
|
||||||
|
"WeChat getconfig returned ret={} for {}: {}",
|
||||||
|
data.get("ret"),
|
||||||
|
chat_id,
|
||||||
|
data.get("errmsg", ""),
|
||||||
|
)
|
||||||
|
return context_token
|
||||||
|
|
||||||
|
new_token = str(data.get("context_token", "") or "")
|
||||||
|
if new_token and new_token != context_token:
|
||||||
|
self.logger.info(
|
||||||
|
"WeChat context_token refreshed for {} (age {:.0f}s -> fresh)",
|
||||||
|
chat_id,
|
||||||
|
age,
|
||||||
|
)
|
||||||
|
self._context_tokens[chat_id] = new_token
|
||||||
|
self._context_token_at[chat_id] = now
|
||||||
|
self._save_state()
|
||||||
|
return new_token
|
||||||
|
|
||||||
|
return context_token
|
||||||
|
|
||||||
|
async def _flush_tool_hints(self, chat_id: str) -> None:
|
||||||
|
"""Send any buffered tool hints for *chat_id* as a single message.
|
||||||
|
|
||||||
|
Tool hints are coalesced to reduce message count and avoid hitting the
|
||||||
|
WeChat iLink rate limit (~7 msgs / 5 min). Failures are logged but
|
||||||
|
not raised so that the main message send is never blocked.
|
||||||
|
"""
|
||||||
|
hints = self._pending_tool_hints.pop(chat_id, None)
|
||||||
|
if not hints:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.logger.info(
|
||||||
|
"Flushing {} buffered tool hint(s) for {}",
|
||||||
|
len(hints),
|
||||||
|
chat_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx_token = self._context_tokens.get(chat_id, "")
|
||||||
|
ctx_token = await self._refresh_context_token_if_stale(chat_id, ctx_token)
|
||||||
|
if not ctx_token:
|
||||||
|
self.logger.warning(
|
||||||
|
"Dropped {} buffered tool hint(s) for {}: no context_token",
|
||||||
|
len(hints),
|
||||||
|
chat_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._send_text(chat_id, "\n\n".join(hints), ctx_token)
|
||||||
|
except Exception:
|
||||||
|
self.logger.exception(
|
||||||
|
"Failed to flush buffered tool hints for {}", chat_id
|
||||||
|
)
|
||||||
|
|
||||||
async def _send_typing(self, user_id: str, typing_ticket: str, status: int) -> None:
|
async def _send_typing(self, user_id: str, typing_ticket: str, status: int) -> None:
|
||||||
"""Best-effort sendtyping wrapper."""
|
"""Best-effort sendtyping wrapper."""
|
||||||
if not typing_ticket:
|
if not typing_ticket:
|
||||||
@ -944,11 +1051,47 @@ class WeixinChannel(BaseChannel):
|
|||||||
self._assert_session_active()
|
self._assert_session_active()
|
||||||
|
|
||||||
is_progress = bool((msg.metadata or {}).get("_progress", False))
|
is_progress = bool((msg.metadata or {}).get("_progress", False))
|
||||||
|
|
||||||
|
# Buffer tool hints to coalesce consecutive ones and avoid burning
|
||||||
|
# WeChat iLink rate-limit quota (~7 msgs / 5 min).
|
||||||
|
if is_progress and (msg.metadata or {}).get("_tool_hint"):
|
||||||
|
if not self.send_tool_hints:
|
||||||
|
return
|
||||||
|
self._pending_tool_hints.setdefault(msg.chat_id, []).append(msg.content)
|
||||||
|
self.logger.debug(
|
||||||
|
"Buffered tool hint for {} (count={})",
|
||||||
|
msg.chat_id,
|
||||||
|
len(self._pending_tool_hints[msg.chat_id]),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Reasoning deltas are invisible in WeChat (there is no reasoning
|
||||||
|
# UI). Skip them entirely — do not send and do not flush buffer.
|
||||||
|
if is_progress and (msg.metadata or {}).get("_reasoning_delta"):
|
||||||
|
self.logger.debug(
|
||||||
|
"Dropped invisible reasoning delta for {}", msg.chat_id
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
content = msg.content.strip()
|
||||||
|
|
||||||
|
# Empty progress messages (e.g. after_iteration tool_events) must
|
||||||
|
# NOT act as separators — they have no visible content.
|
||||||
|
if is_progress and not content and not (msg.media or []):
|
||||||
|
self.logger.debug(
|
||||||
|
"Skipped empty progress message for {} (no visible content)",
|
||||||
|
msg.chat_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Flush buffered hints before sending any visible message.
|
||||||
|
await self._flush_tool_hints(msg.chat_id)
|
||||||
|
|
||||||
if not is_progress:
|
if not is_progress:
|
||||||
await self._stop_typing(msg.chat_id, clear_remote=True)
|
await self._stop_typing(msg.chat_id, clear_remote=True)
|
||||||
|
|
||||||
content = msg.content.strip()
|
|
||||||
ctx_token = self._context_tokens.get(msg.chat_id, "")
|
ctx_token = self._context_tokens.get(msg.chat_id, "")
|
||||||
|
ctx_token = await self._refresh_context_token_if_stale(msg.chat_id, ctx_token)
|
||||||
if not ctx_token:
|
if not ctx_token:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"WeChat context_token missing for chat_id={msg.chat_id}, cannot send"
|
f"WeChat context_token missing for chat_id={msg.chat_id}, cannot send"
|
||||||
@ -1037,6 +1180,18 @@ class WeixinChannel(BaseChannel):
|
|||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
await self._send_typing(msg.chat_id, typing_ticket, TYPING_STATUS_CANCEL)
|
await self._send_typing(msg.chat_id, typing_ticket, TYPING_STATUS_CANCEL)
|
||||||
|
|
||||||
|
async def send_delta(
|
||||||
|
self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None
|
||||||
|
) -> None:
|
||||||
|
"""Weixin iLink does not support native streaming deltas.
|
||||||
|
|
||||||
|
We only hook ``_stream_end`` so buffered tool hints are flushed even
|
||||||
|
when the final answer carries the ``_streamed`` flag and bypasses
|
||||||
|
:meth:`send`.
|
||||||
|
"""
|
||||||
|
if metadata and metadata.get("_stream_end"):
|
||||||
|
await self._flush_tool_hints(chat_id)
|
||||||
|
|
||||||
async def _start_typing(self, chat_id: str, context_token: str = "") -> None:
|
async def _start_typing(self, chat_id: str, context_token: str = "") -> None:
|
||||||
"""Start typing indicator immediately when a message is received."""
|
"""Start typing indicator immediately when a message is received."""
|
||||||
if not self._client or not self._token or not chat_id:
|
if not self._client or not self._token or not chat_id:
|
||||||
@ -1120,10 +1275,11 @@ class WeixinChannel(BaseChannel):
|
|||||||
}
|
}
|
||||||
|
|
||||||
data = await self._api_post("ilink/bot/sendmessage", body)
|
data = await self._api_post("ilink/bot/sendmessage", body)
|
||||||
|
ret = data.get("ret", 0)
|
||||||
errcode = data.get("errcode", 0)
|
errcode = data.get("errcode", 0)
|
||||||
if errcode and errcode != 0:
|
if (ret is not None and ret != 0) or (errcode is not None and errcode != 0):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"WeChat send text error (code {errcode}): {data.get('errmsg', '')}"
|
f"WeChat send text error (ret={ret}, errcode={errcode}): {data.get('errmsg', '')}"
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _send_media_file(
|
async def _send_media_file(
|
||||||
@ -1270,10 +1426,11 @@ class WeixinChannel(BaseChannel):
|
|||||||
}
|
}
|
||||||
|
|
||||||
data = await self._api_post("ilink/bot/sendmessage", body)
|
data = await self._api_post("ilink/bot/sendmessage", body)
|
||||||
|
ret = data.get("ret", 0)
|
||||||
errcode = data.get("errcode", 0)
|
errcode = data.get("errcode", 0)
|
||||||
if errcode and errcode != 0:
|
if (ret is not None and ret != 0) or (errcode is not None and errcode != 0):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"WeChat send media error (code {errcode}): {data.get('errmsg', '')}"
|
f"WeChat send media error (ret={ret}, errcode={errcode}): {data.get('errmsg', '')}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -211,6 +211,7 @@ class ProvidersConfig(Base):
|
|||||||
ant_ling: ProviderConfig = Field(default_factory=ProviderConfig) # Ant Ling
|
ant_ling: ProviderConfig = Field(default_factory=ProviderConfig) # Ant Ling
|
||||||
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
|
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
|
||||||
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动)
|
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动)
|
||||||
|
novita: ProviderConfig = Field(default_factory=ProviderConfig) # Novita AI
|
||||||
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎)
|
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎)
|
||||||
volcengine_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine Coding Plan
|
volcengine_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine Coding Plan
|
||||||
byteplus: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus (VolcEngine international)
|
byteplus: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus (VolcEngine international)
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import binascii
|
import binascii
|
||||||
import re
|
import re
|
||||||
@ -898,6 +899,426 @@ def _minimax_images_from_payload(payload: dict[str, Any]) -> list[str]:
|
|||||||
return images
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# OpenAI image generation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_OPENAI_DALLE2_SUPPORTED_SIZES = {"256x256", "512x512", "1024x1024"}
|
||||||
|
_OPENAI_DALLE3_SUPPORTED_SIZES = {"1024x1024", "1792x1024", "1024x1792"}
|
||||||
|
_OPENAI_GPT_IMAGE_SUPPORTED_SIZES = {
|
||||||
|
"1024x1024",
|
||||||
|
"1536x1024",
|
||||||
|
"1024x1536",
|
||||||
|
"auto",
|
||||||
|
}
|
||||||
|
_OPENAI_DALLE2_ASPECT_RATIO_SIZES = {
|
||||||
|
"1:1": "1024x1024",
|
||||||
|
"16:9": "1024x1024",
|
||||||
|
"9:16": "1024x1024",
|
||||||
|
"3:4": "1024x1024",
|
||||||
|
"4:3": "1024x1024",
|
||||||
|
}
|
||||||
|
_OPENAI_DALLE3_ASPECT_RATIO_SIZES = {
|
||||||
|
"1:1": "1024x1024",
|
||||||
|
"16:9": "1792x1024",
|
||||||
|
"9:16": "1024x1792",
|
||||||
|
"3:4": "1024x1792",
|
||||||
|
"4:3": "1792x1024",
|
||||||
|
}
|
||||||
|
_OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES = {
|
||||||
|
"1:1": "1024x1024",
|
||||||
|
"16:9": "1536x1024",
|
||||||
|
"9:16": "1024x1536",
|
||||||
|
"3:4": "1024x1536",
|
||||||
|
"4:3": "1536x1024",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIImageGenerationClient(ImageGenerationProvider):
|
||||||
|
"""OpenAI Images API using an API key (``providers.openai.apiKey``)."""
|
||||||
|
|
||||||
|
provider_name = "openai"
|
||||||
|
missing_key_message = (
|
||||||
|
"OpenAI API key is not configured. Set providers.openai.apiKey."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _default_base_url(self) -> str:
|
||||||
|
return "https://api.openai.com/v1"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _strip_model_prefix(model: str) -> str:
|
||||||
|
"""Remove ``openai/`` prefix if present (OpenRouter convention)."""
|
||||||
|
if model.startswith("openai/") or model.startswith("openai_codex/"):
|
||||||
|
return model.split("/", 1)[1]
|
||||||
|
return model
|
||||||
|
|
||||||
|
async def generate(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
prompt: str,
|
||||||
|
model: str,
|
||||||
|
reference_images: list[str] | None = None,
|
||||||
|
aspect_ratio: str | None = None,
|
||||||
|
image_size: str | None = None,
|
||||||
|
) -> GeneratedImageResponse:
|
||||||
|
if not self.api_key:
|
||||||
|
raise ImageGenerationError(self.missing_key_message)
|
||||||
|
|
||||||
|
if reference_images:
|
||||||
|
logger.warning(
|
||||||
|
"DALL-E models do not support reference images; "
|
||||||
|
"ignoring {} reference image(s) for {}",
|
||||||
|
len(reference_images),
|
||||||
|
model,
|
||||||
|
)
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
**self.extra_headers,
|
||||||
|
}
|
||||||
|
|
||||||
|
clean_model = self._strip_model_prefix(model)
|
||||||
|
body: dict[str, Any] = {
|
||||||
|
"model": clean_model,
|
||||||
|
"prompt": prompt,
|
||||||
|
}
|
||||||
|
|
||||||
|
if not _openai_is_gpt_image_model(clean_model):
|
||||||
|
body["response_format"] = "b64_json"
|
||||||
|
body["n"] = 1
|
||||||
|
|
||||||
|
size = _openai_size(clean_model, aspect_ratio, image_size)
|
||||||
|
if size:
|
||||||
|
body["size"] = size
|
||||||
|
|
||||||
|
body.update(self.extra_body)
|
||||||
|
|
||||||
|
logger.info("OpenAI Images API request: POST {}/images/generations body={}", self.api_base, body)
|
||||||
|
|
||||||
|
response = await self._http_post(
|
||||||
|
f"{self.api_base}/images/generations",
|
||||||
|
headers=headers,
|
||||||
|
body=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as exc:
|
||||||
|
detail = response.text[:1000]
|
||||||
|
logger.error("OpenAI Images API error ({}): {}", response.status_code, detail)
|
||||||
|
raise ImageGenerationError(
|
||||||
|
f"OpenAI image generation failed (HTTP {response.status_code}): {detail}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
payload = response.json()
|
||||||
|
logger.info("OpenAI Images API response ({}): {}", response.status_code,
|
||||||
|
{k: v for k, v in payload.items() if k != "data"})
|
||||||
|
|
||||||
|
client = self._client
|
||||||
|
owns_client = client is None
|
||||||
|
if owns_client:
|
||||||
|
client = httpx.AsyncClient(timeout=self.timeout)
|
||||||
|
try:
|
||||||
|
images = await _openai_images_from_payload(client, payload)
|
||||||
|
finally:
|
||||||
|
if owns_client:
|
||||||
|
await client.aclose()
|
||||||
|
|
||||||
|
self._require_images(images, payload)
|
||||||
|
|
||||||
|
return GeneratedImageResponse(images=images, content="", raw=payload)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# OpenAI Codex image generation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class CodexImageGenerationClient(ImageGenerationProvider):
|
||||||
|
"""OpenAI image generation via Codex subscription OAuth.
|
||||||
|
|
||||||
|
Uses the Codex Responses API with the ``image_generation`` tool
|
||||||
|
(the same mechanism ChatGPT uses internally). No API key required —
|
||||||
|
the Codex OAuth token from ``oauth_cli_kit`` is used instead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
provider_name = "openai_codex"
|
||||||
|
missing_key_message = (
|
||||||
|
"Codex OAuth token is unavailable. "
|
||||||
|
"Log in with Codex subscription first."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _default_base_url(self) -> str:
|
||||||
|
return "https://chatgpt.com/backend-api"
|
||||||
|
|
||||||
|
def _codex_model(self, model: str) -> str:
|
||||||
|
"""Strip the ``openai-codex/`` prefix if present."""
|
||||||
|
if model.startswith(("openai-codex/", "openai_codex/")):
|
||||||
|
return model.split("/", 1)[1]
|
||||||
|
return model
|
||||||
|
|
||||||
|
async def generate(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
prompt: str,
|
||||||
|
model: str,
|
||||||
|
reference_images: list[str] | None = None,
|
||||||
|
aspect_ratio: str | None = None,
|
||||||
|
image_size: str | None = None,
|
||||||
|
) -> GeneratedImageResponse:
|
||||||
|
try:
|
||||||
|
from oauth_cli_kit import get_token as get_codex_token
|
||||||
|
except ImportError:
|
||||||
|
raise ImageGenerationError(self.missing_key_message)
|
||||||
|
|
||||||
|
try:
|
||||||
|
token = await asyncio.to_thread(get_codex_token)
|
||||||
|
except Exception as exc:
|
||||||
|
raise ImageGenerationError(self.missing_key_message) from exc
|
||||||
|
if not token or not token.access:
|
||||||
|
raise ImageGenerationError(self.missing_key_message)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Using Codex OAuth token for image generation (account: {})",
|
||||||
|
token.account_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if reference_images:
|
||||||
|
logger.warning(
|
||||||
|
"Codex image generation does not support reference images; "
|
||||||
|
"ignoring {} reference image(s)",
|
||||||
|
len(reference_images),
|
||||||
|
)
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {token.access}",
|
||||||
|
"chatgpt-account-id": token.account_id,
|
||||||
|
"OpenAI-Beta": "responses=experimental",
|
||||||
|
"originator": "nanobot",
|
||||||
|
"User-Agent": "nanobot (python)",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
**self.extra_headers,
|
||||||
|
}
|
||||||
|
|
||||||
|
body: dict[str, Any] = {
|
||||||
|
"model": self._codex_model(model),
|
||||||
|
"instructions": "Generate an image based on the user's request.",
|
||||||
|
"input": [{"role": "user", "content": prompt}],
|
||||||
|
"tools": [{"type": "image_generation"}],
|
||||||
|
"tool_choice": "auto",
|
||||||
|
"stream": True,
|
||||||
|
"store": False,
|
||||||
|
}
|
||||||
|
body.update(self.extra_body)
|
||||||
|
|
||||||
|
logger.info("Codex Responses API request: POST {}/codex/responses body={}",
|
||||||
|
self.api_base, {k: v for k, v in body.items() if k != "input"})
|
||||||
|
|
||||||
|
response = await self._http_post(
|
||||||
|
f"{self.api_base}/codex/responses",
|
||||||
|
headers=headers,
|
||||||
|
body=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as exc:
|
||||||
|
detail = response.text[:1000]
|
||||||
|
logger.error("Codex Responses API error ({}): {}", response.status_code, detail)
|
||||||
|
raise ImageGenerationError(
|
||||||
|
f"Codex image generation failed (HTTP {response.status_code}): {detail}"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
images, content_text = await _parse_codex_sse_images(response)
|
||||||
|
|
||||||
|
raw = {"status": "completed"}
|
||||||
|
self._require_images(images, raw)
|
||||||
|
|
||||||
|
return GeneratedImageResponse(images=images, content=content_text, raw=raw)
|
||||||
|
|
||||||
|
|
||||||
|
def _openai_size(
|
||||||
|
model: str,
|
||||||
|
aspect_ratio: str | None,
|
||||||
|
image_size: str | None,
|
||||||
|
) -> str:
|
||||||
|
"""Resolve aspect ratio or image_size to an OpenAI Images API size string."""
|
||||||
|
sizes, supported_sizes = _openai_size_options(model)
|
||||||
|
explicit_size = _normalize_openai_image_size(image_size)
|
||||||
|
if explicit_size and _openai_explicit_size_supported(
|
||||||
|
explicit_size,
|
||||||
|
supported_sizes=supported_sizes,
|
||||||
|
):
|
||||||
|
return explicit_size
|
||||||
|
if explicit_size:
|
||||||
|
logger.warning(
|
||||||
|
"OpenAI image size '{}' is not supported by {}; using aspect ratio/default size",
|
||||||
|
explicit_size,
|
||||||
|
model,
|
||||||
|
)
|
||||||
|
if aspect_ratio and aspect_ratio in sizes:
|
||||||
|
return sizes[aspect_ratio]
|
||||||
|
return "1024x1024"
|
||||||
|
|
||||||
|
|
||||||
|
def _openai_is_gpt_image_model(model: str) -> bool:
|
||||||
|
normalized = model.lower()
|
||||||
|
return normalized.startswith(("gpt-image", "chatgpt-image"))
|
||||||
|
|
||||||
|
|
||||||
|
def _openai_size_options(model: str) -> tuple[dict[str, str], set[str] | None]:
|
||||||
|
normalized = model.lower()
|
||||||
|
if normalized.startswith("dall-e-2"):
|
||||||
|
return _OPENAI_DALLE2_ASPECT_RATIO_SIZES, _OPENAI_DALLE2_SUPPORTED_SIZES
|
||||||
|
if normalized.startswith("dall-e-3"):
|
||||||
|
return _OPENAI_DALLE3_ASPECT_RATIO_SIZES, _OPENAI_DALLE3_SUPPORTED_SIZES
|
||||||
|
if normalized.startswith("gpt-image-2"):
|
||||||
|
return _OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES, None
|
||||||
|
return _OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES, _OPENAI_GPT_IMAGE_SUPPORTED_SIZES
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_openai_image_size(image_size: str | None) -> str | None:
|
||||||
|
if not image_size:
|
||||||
|
return None
|
||||||
|
normalized = image_size.strip().lower()
|
||||||
|
return normalized or None
|
||||||
|
|
||||||
|
|
||||||
|
def _openai_explicit_size_supported(
|
||||||
|
size: str,
|
||||||
|
*,
|
||||||
|
supported_sizes: set[str] | None,
|
||||||
|
) -> bool:
|
||||||
|
if supported_sizes is not None:
|
||||||
|
return size in supported_sizes
|
||||||
|
width, sep, height = size.partition("x")
|
||||||
|
return bool(sep and width.isdecimal() and height.isdecimal())
|
||||||
|
|
||||||
|
|
||||||
|
async def _openai_images_from_payload(
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
payload: dict[str, Any],
|
||||||
|
) -> list[str]:
|
||||||
|
"""Extract images from OpenAI Images API response.
|
||||||
|
|
||||||
|
Handles both ``b64_json`` (preferred) and ``url`` (downloaded) formats.
|
||||||
|
"""
|
||||||
|
images: list[str] = []
|
||||||
|
for item in payload.get("data") or []:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
b64 = item.get("b64_json")
|
||||||
|
if isinstance(b64, str) and b64:
|
||||||
|
images.append(_b64_image_data_url(b64))
|
||||||
|
continue
|
||||||
|
url = item.get("url")
|
||||||
|
if isinstance(url, str) and url:
|
||||||
|
images.append(await _download_image_data_url(client, url))
|
||||||
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
def _codex_responses_images_from_payload(payload: dict[str, Any]) -> list[str]:
|
||||||
|
"""Extract images from Codex Responses API ``image_generation_call`` output."""
|
||||||
|
images: list[str] = []
|
||||||
|
for item in payload.get("output") or []:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
if item.get("type") != "image_generation_call":
|
||||||
|
continue
|
||||||
|
result = item.get("result")
|
||||||
|
if isinstance(result, str):
|
||||||
|
images.append(result if result.startswith("data:image/") else _b64_image_data_url(result))
|
||||||
|
continue
|
||||||
|
if isinstance(result, dict):
|
||||||
|
image_url = result.get("image_url") or result.get("image") or ""
|
||||||
|
if isinstance(image_url, str):
|
||||||
|
images.append(image_url if image_url.startswith("data:image/") else _b64_image_data_url(image_url))
|
||||||
|
return images
|
||||||
|
|
||||||
|
|
||||||
|
async def _parse_codex_sse_images(
|
||||||
|
response: httpx.Response,
|
||||||
|
) -> tuple[list[str], str]:
|
||||||
|
"""Parse a Codex Responses API SSE stream for image generation output.
|
||||||
|
|
||||||
|
Returns ``(images, content_text)``.
|
||||||
|
"""
|
||||||
|
import json as _json
|
||||||
|
|
||||||
|
images: list[str] = []
|
||||||
|
text_parts: list[str] = []
|
||||||
|
|
||||||
|
buffer: list[str] = []
|
||||||
|
async for line_bytes in response.aiter_lines():
|
||||||
|
line = line_bytes.strip()
|
||||||
|
if line == "":
|
||||||
|
if buffer:
|
||||||
|
data_lines = []
|
||||||
|
for bl in buffer:
|
||||||
|
if bl.startswith("data:"):
|
||||||
|
data_lines.append(bl[5:].strip())
|
||||||
|
buffer.clear()
|
||||||
|
if data_lines:
|
||||||
|
raw = "".join(data_lines)
|
||||||
|
if raw == "[DONE]":
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
event = _json.loads(raw)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
ev_type = event.get("type", "")
|
||||||
|
if ev_type in ("error", "response.failed"):
|
||||||
|
logger.error("Codex SSE failure: {}", raw[:2000])
|
||||||
|
_collect_images_from_sse_event(event, images)
|
||||||
|
_collect_text_from_sse_event(event, text_parts)
|
||||||
|
continue
|
||||||
|
buffer.append(line)
|
||||||
|
|
||||||
|
# flush remaining
|
||||||
|
if buffer:
|
||||||
|
data_lines = [bl[5:].strip() for bl in buffer if bl.startswith("data:")]
|
||||||
|
raw = "".join(data_lines)
|
||||||
|
if raw and raw != "[DONE]":
|
||||||
|
try:
|
||||||
|
event = _json.loads(raw)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
_collect_images_from_sse_event(event, images)
|
||||||
|
_collect_text_from_sse_event(event, text_parts)
|
||||||
|
|
||||||
|
return images, "".join(text_parts).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_images_from_sse_event(event: dict[str, Any], images: list[str]) -> None:
|
||||||
|
if event.get("type") != "response.output_item.done":
|
||||||
|
return
|
||||||
|
item = event.get("item") or {}
|
||||||
|
if item.get("type") != "image_generation_call":
|
||||||
|
return
|
||||||
|
result = item.get("result")
|
||||||
|
if isinstance(result, str):
|
||||||
|
if result.startswith("data:image/"):
|
||||||
|
images.append(result)
|
||||||
|
else:
|
||||||
|
images.append(_b64_image_data_url(result))
|
||||||
|
elif isinstance(result, dict):
|
||||||
|
image_url = result.get("image_url") or result.get("image") or ""
|
||||||
|
if isinstance(image_url, str):
|
||||||
|
if image_url.startswith("data:image/"):
|
||||||
|
images.append(image_url)
|
||||||
|
else:
|
||||||
|
images.append(_b64_image_data_url(image_url))
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_text_from_sse_event(event: dict[str, Any], text_parts: list[str]) -> None:
|
||||||
|
if event.get("type") == "response.output_text.delta":
|
||||||
|
delta = event.get("delta")
|
||||||
|
if isinstance(delta, str) and delta:
|
||||||
|
text_parts.append(delta)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# StepFun (阶跃星辰) image generation
|
# StepFun (阶跃星辰) image generation
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@ -1025,9 +1446,11 @@ def _stepfun_images_from_payload(payload: dict[str, Any]) -> list[str]:
|
|||||||
# Provider registration
|
# Provider registration
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
register_image_gen_provider(OpenRouterImageGenerationClient)
|
|
||||||
register_image_gen_provider(AIHubMixImageGenerationClient)
|
register_image_gen_provider(AIHubMixImageGenerationClient)
|
||||||
|
register_image_gen_provider(CodexImageGenerationClient)
|
||||||
register_image_gen_provider(GeminiImageGenerationClient)
|
register_image_gen_provider(GeminiImageGenerationClient)
|
||||||
register_image_gen_provider(OllamaImageGenerationClient)
|
register_image_gen_provider(OllamaImageGenerationClient)
|
||||||
register_image_gen_provider(MiniMaxImageGenerationClient)
|
register_image_gen_provider(MiniMaxImageGenerationClient)
|
||||||
|
register_image_gen_provider(OpenAIImageGenerationClient)
|
||||||
|
register_image_gen_provider(OpenRouterImageGenerationClient)
|
||||||
register_image_gen_provider(StepFunImageGenerationClient)
|
register_image_gen_provider(StepFunImageGenerationClient)
|
||||||
|
|||||||
@ -11,6 +11,7 @@ import secrets
|
|||||||
import string
|
import string
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
from collections import deque
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from ipaddress import ip_address
|
from ipaddress import ip_address
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
@ -74,41 +75,43 @@ _THINKING_STYLE_MAP: dict[str, Any] = {
|
|||||||
"enable_thinking": lambda on: {"enable_thinking": on},
|
"enable_thinking": lambda on: {"enable_thinking": on},
|
||||||
"reasoning_split": lambda on: {"reasoning_split": on},
|
"reasoning_split": lambda on: {"reasoning_split": on},
|
||||||
}
|
}
|
||||||
|
_GATEWAY_REASONING_STYLE_MAP: dict[str, Any] = {
|
||||||
|
"reasoning_effort": lambda effort: {"reasoning": {"effort": effort}},
|
||||||
|
}
|
||||||
|
_MODEL_THINKING_STYLES: dict[str, str] = {
|
||||||
|
**dict.fromkeys(_KIMI_THINKING_MODELS, "thinking_type"),
|
||||||
|
**dict.fromkeys(_MIMO_THINKING_MODELS, "thinking_type"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def _is_kimi_thinking_model(model_name: str) -> bool:
|
def _model_slug(model_name: str) -> str:
|
||||||
"""Return True if model_name refers to a Kimi thinking-capable model.
|
return model_name.lower().rsplit("/", 1)[-1]
|
||||||
|
|
||||||
Supports two forms:
|
|
||||||
- Exact match: e.g. kimi-k2.5 / kimi-k2.6 in _KIMI_THINKING_MODELS
|
|
||||||
- Slug match: moonshotai/kimi-k2.5 -> the part after the last "/"
|
|
||||||
is checked against _KIMI_THINKING_MODELS
|
|
||||||
|
|
||||||
This covers both the native Moonshot provider (bare slug) and
|
|
||||||
OpenRouter-style names (``"publisher/slug"``).
|
|
||||||
"""
|
|
||||||
name = model_name.lower()
|
|
||||||
if name in _KIMI_THINKING_MODELS:
|
|
||||||
return True
|
|
||||||
if "/" in name and name.rsplit("/", 1)[1] in _KIMI_THINKING_MODELS:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _is_mimo_thinking_model(model_name: str) -> bool:
|
def _model_thinking_style(model_name: str) -> str:
|
||||||
"""Return True if model_name refers to a MiMo thinking-capable model.
|
return _MODEL_THINKING_STYLES.get(_model_slug(model_name), "")
|
||||||
|
|
||||||
Mirrors _is_kimi_thinking_model: gateway providers (e.g. OpenRouter
|
|
||||||
routing ``xiaomi/mimo-v2.5-pro``) have no ``thinking_style`` on their
|
def _thinking_styles_for(spec: ProviderSpec | None, model_name: str) -> list[str]:
|
||||||
spec, so the spec-driven branch in _build_kwargs misses them. The
|
styles: list[str] = []
|
||||||
model-name path catches those cases.
|
if spec and spec.thinking_style:
|
||||||
"""
|
styles.append(spec.thinking_style)
|
||||||
name = model_name.lower()
|
model_style = _model_thinking_style(model_name)
|
||||||
if name in _MIMO_THINKING_MODELS:
|
if model_style and model_style not in styles:
|
||||||
return True
|
styles.append(model_style)
|
||||||
if "/" in name and name.rsplit("/", 1)[1] in _MIMO_THINKING_MODELS:
|
return styles
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
def _thinking_extra_body(style: str, thinking_enabled: bool) -> dict[str, Any] | None:
|
||||||
|
builder = _THINKING_STYLE_MAP.get(style)
|
||||||
|
return builder(thinking_enabled) if builder else None
|
||||||
|
|
||||||
|
|
||||||
|
def _gateway_reasoning_extra_body(style: str, effort: str | None) -> dict[str, Any] | None:
|
||||||
|
if not effort:
|
||||||
|
return None
|
||||||
|
builder = _GATEWAY_REASONING_STYLE_MAP.get(style)
|
||||||
|
return builder(effort) if builder else None
|
||||||
|
|
||||||
|
|
||||||
def _openai_compat_timeout_s() -> float:
|
def _openai_compat_timeout_s() -> float:
|
||||||
@ -461,6 +464,7 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
"""Strip non-standard keys, normalize tool_call IDs."""
|
"""Strip non-standard keys, normalize tool_call IDs."""
|
||||||
sanitized = LLMProvider._sanitize_request_messages(messages, _ALLOWED_MSG_KEYS)
|
sanitized = LLMProvider._sanitize_request_messages(messages, _ALLOWED_MSG_KEYS)
|
||||||
id_map: dict[str, str] = {}
|
id_map: dict[str, str] = {}
|
||||||
|
pending_tool_ids: dict[str, deque[str]] = {}
|
||||||
force_string_content = bool(self._spec and self._spec.name == "deepseek")
|
force_string_content = bool(self._spec and self._spec.name == "deepseek")
|
||||||
|
|
||||||
def map_id(value: Any) -> Any:
|
def map_id(value: Any) -> Any:
|
||||||
@ -468,15 +472,49 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
return value
|
return value
|
||||||
return id_map.setdefault(value, self._normalize_tool_call_id(value))
|
return id_map.setdefault(value, self._normalize_tool_call_id(value))
|
||||||
|
|
||||||
|
def unique_tool_id(value: Any, used_ids: set[str], idx: int) -> str:
|
||||||
|
if isinstance(value, str) and value:
|
||||||
|
base = map_id(value)
|
||||||
|
else:
|
||||||
|
base = _short_tool_id()
|
||||||
|
if not isinstance(base, str) or not base:
|
||||||
|
base = _short_tool_id()
|
||||||
|
if base not in used_ids:
|
||||||
|
return base
|
||||||
|
seed = value if isinstance(value, str) and value else base
|
||||||
|
salt = 1
|
||||||
|
while True:
|
||||||
|
candidate = self._normalize_tool_call_id(f"{seed}:{idx}:{salt}")
|
||||||
|
if isinstance(candidate, str) and candidate not in used_ids:
|
||||||
|
return candidate
|
||||||
|
salt += 1
|
||||||
|
|
||||||
|
def map_tool_result_id(value: Any) -> Any:
|
||||||
|
if not isinstance(value, str):
|
||||||
|
return value
|
||||||
|
queue = pending_tool_ids.get(value)
|
||||||
|
if queue:
|
||||||
|
mapped = queue.popleft()
|
||||||
|
if not queue:
|
||||||
|
pending_tool_ids.pop(value, None)
|
||||||
|
return mapped
|
||||||
|
return map_id(value)
|
||||||
|
|
||||||
for clean in sanitized:
|
for clean in sanitized:
|
||||||
if isinstance(clean.get("tool_calls"), list):
|
if isinstance(clean.get("tool_calls"), list):
|
||||||
normalized = []
|
normalized = []
|
||||||
for tc in clean["tool_calls"]:
|
used_ids: set[str] = set()
|
||||||
|
for idx, tc in enumerate(clean["tool_calls"]):
|
||||||
if not isinstance(tc, dict):
|
if not isinstance(tc, dict):
|
||||||
normalized.append(tc)
|
normalized.append(tc)
|
||||||
continue
|
continue
|
||||||
tc_clean = dict(tc)
|
tc_clean = dict(tc)
|
||||||
tc_clean["id"] = map_id(tc_clean.get("id"))
|
raw_id = tc_clean.get("id")
|
||||||
|
mapped_id = unique_tool_id(raw_id, used_ids, idx)
|
||||||
|
tc_clean["id"] = mapped_id
|
||||||
|
used_ids.add(mapped_id)
|
||||||
|
if isinstance(raw_id, str) and raw_id:
|
||||||
|
pending_tool_ids.setdefault(raw_id, deque()).append(mapped_id)
|
||||||
function = tc_clean.get("function")
|
function = tc_clean.get("function")
|
||||||
if isinstance(function, dict):
|
if isinstance(function, dict):
|
||||||
function_clean = dict(function)
|
function_clean = dict(function)
|
||||||
@ -494,7 +532,7 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
# that mix non-empty content with tool_calls.
|
# that mix non-empty content with tool_calls.
|
||||||
clean["content"] = None
|
clean["content"] = None
|
||||||
if "tool_call_id" in clean and clean["tool_call_id"]:
|
if "tool_call_id" in clean and clean["tool_call_id"]:
|
||||||
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
clean["tool_call_id"] = map_tool_result_id(clean["tool_call_id"])
|
||||||
if (
|
if (
|
||||||
force_string_content
|
force_string_content
|
||||||
and not (clean.get("role") == "assistant" and clean.get("tool_calls"))
|
and not (clean.get("role") == "assistant" and clean.get("tool_calls"))
|
||||||
@ -581,39 +619,27 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
if wire_effort and semantic_effort != "none":
|
if wire_effort and semantic_effort != "none":
|
||||||
kwargs["reasoning_effort"] = wire_effort
|
kwargs["reasoning_effort"] = wire_effort
|
||||||
|
|
||||||
# Provider-specific thinking parameters.
|
# Only send thinking controls when reasoning_effort is explicit so
|
||||||
# Only sent when reasoning_effort is explicitly configured so that
|
# omitting the config preserves each provider's default.
|
||||||
# the provider default is preserved otherwise.
|
if reasoning_effort is not None:
|
||||||
# The mapping is driven by ProviderSpec.thinking_style so that adding
|
|
||||||
# a new provider never requires touching this function.
|
|
||||||
if spec and spec.thinking_style and reasoning_effort is not None:
|
|
||||||
thinking_enabled = semantic_effort not in ("none", "minimal")
|
thinking_enabled = semantic_effort not in ("none", "minimal")
|
||||||
extra = _THINKING_STYLE_MAP.get(spec.thinking_style, lambda _: None)(thinking_enabled)
|
for thinking_style in _thinking_styles_for(spec, model_name):
|
||||||
if extra:
|
extra = _thinking_extra_body(thinking_style, thinking_enabled)
|
||||||
kwargs.setdefault("extra_body", {}).update(extra)
|
if extra:
|
||||||
|
kwargs.setdefault("extra_body", {}).update(extra)
|
||||||
|
gateway_style = getattr(spec, "gateway_reasoning_style", "") if spec else ""
|
||||||
|
if gateway_style and _model_thinking_style(model_name):
|
||||||
|
extra = _gateway_reasoning_extra_body(gateway_style, semantic_effort)
|
||||||
|
if extra:
|
||||||
|
kwargs.setdefault("extra_body", {}).update(extra)
|
||||||
|
|
||||||
# Model-level thinking injection for Kimi thinking-capable models.
|
# Moonshot rejects requests that carry both 'reasoning_effort'
|
||||||
# Strip any provider prefix (e.g. "moonshotai/") before the set lookup
|
# and the native 'thinking' param. We already expressed the
|
||||||
# so that OpenRouter-style names like "moonshotai/kimi-k2.5" are handled
|
# user's intent via the provider-native shape, so drop the
|
||||||
# identically to bare names like "kimi-k2.5".
|
# redundant wire-level kwarg. Only kimi models need this —
|
||||||
if reasoning_effort is not None and _is_kimi_thinking_model(model_name):
|
# Xiaomi's API accepts both params.
|
||||||
thinking_enabled = semantic_effort not in ("none", "minimal")
|
if _model_slug(model_name) in _KIMI_THINKING_MODELS:
|
||||||
kwargs.setdefault("extra_body", {}).update(
|
kwargs.pop("reasoning_effort", None)
|
||||||
{"thinking": {"type": "enabled" if thinking_enabled else "disabled"}}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Model-level thinking injection for MiMo thinking-capable models.
|
|
||||||
# Same shape as Kimi: gateway providers (OpenRouter, etc.) lack the
|
|
||||||
# xiaomi_mimo spec's thinking_style, so the spec-driven branch above
|
|
||||||
# misses them — match by model name to catch "xiaomi/mimo-v2.5-pro"
|
|
||||||
# and friends. (Direct xiaomi_mimo requests are also covered here;
|
|
||||||
# both branches write the same payload, so the dict update is a
|
|
||||||
# safe no-op for already-handled cases.)
|
|
||||||
if reasoning_effort is not None and _is_mimo_thinking_model(model_name):
|
|
||||||
thinking_enabled = semantic_effort not in ("none", "minimal")
|
|
||||||
kwargs.setdefault("extra_body", {}).update(
|
|
||||||
{"thinking": {"type": "enabled" if thinking_enabled else "disabled"}}
|
|
||||||
)
|
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
kwargs["tools"] = tools
|
kwargs["tools"] = tools
|
||||||
@ -628,8 +654,7 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
and semantic_effort not in ("none", "minimal")
|
and semantic_effort not in ("none", "minimal")
|
||||||
and (
|
and (
|
||||||
(spec and spec.thinking_style)
|
(spec and spec.thinking_style)
|
||||||
or _is_kimi_thinking_model(model_name)
|
or _model_thinking_style(model_name)
|
||||||
or _is_mimo_thinking_model(model_name)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
implicit_deepseek_thinking = (
|
implicit_deepseek_thinking = (
|
||||||
@ -1097,6 +1122,15 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
if delta:
|
if delta:
|
||||||
_accum_legacy_function_call(getattr(delta, "function_call", None))
|
_accum_legacy_function_call(getattr(delta, "function_call", None))
|
||||||
|
|
||||||
|
# Some providers (e.g. Zhipu/GLM) reuse the same tool_call id for
|
||||||
|
# parallel tool calls in streaming mode. Deduplicate before building
|
||||||
|
# the response so downstream tool messages don't collide.
|
||||||
|
_seen_tc_ids: set[str] = set()
|
||||||
|
for b in tc_bufs.values():
|
||||||
|
if not b["id"] or b["id"] in _seen_tc_ids:
|
||||||
|
b["id"] = _short_tool_id()
|
||||||
|
_seen_tc_ids.add(b["id"])
|
||||||
|
|
||||||
return LLMResponse(
|
return LLMResponse(
|
||||||
content="".join(content_parts) or None,
|
content="".join(content_parts) or None,
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
|
|||||||
@ -71,6 +71,11 @@ class ProviderSpec:
|
|||||||
# "reasoning_split" — {"reasoning_split": true/false} (MiniMax)
|
# "reasoning_split" — {"reasoning_split": true/false} (MiniMax)
|
||||||
thinking_style: str = ""
|
thinking_style: str = ""
|
||||||
|
|
||||||
|
# Gateway-native reasoning control to pair with model-level thinking styles.
|
||||||
|
# "reasoning_effort" — {"reasoning": {"effort": <none|minimal|...>}}
|
||||||
|
# (OpenRouter)
|
||||||
|
gateway_reasoning_style: str = ""
|
||||||
|
|
||||||
# When True, treat the "reasoning" response field as formal content
|
# When True, treat the "reasoning" response field as formal content
|
||||||
# when "content" is empty. Only set this for providers (e.g. StepFun)
|
# when "content" is empty. Only set this for providers (e.g. StepFun)
|
||||||
# whose API returns the actual answer in "reasoning" instead of "content".
|
# whose API returns the actual answer in "reasoning" instead of "content".
|
||||||
@ -142,6 +147,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
detect_by_base_keyword="openrouter",
|
detect_by_base_keyword="openrouter",
|
||||||
default_api_base="https://openrouter.ai/api/v1",
|
default_api_base="https://openrouter.ai/api/v1",
|
||||||
supports_prompt_caching=True,
|
supports_prompt_caching=True,
|
||||||
|
gateway_reasoning_style="reasoning_effort",
|
||||||
),
|
),
|
||||||
# Hugging Face Inference Providers: OpenAI-compatible router for chat models.
|
# Hugging Face Inference Providers: OpenAI-compatible router for chat models.
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
@ -193,6 +199,18 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
default_api_base="https://api.siliconflow.cn/v1",
|
default_api_base="https://api.siliconflow.cn/v1",
|
||||||
),
|
),
|
||||||
|
|
||||||
|
# Novita AI: OpenAI-compatible gateway for hosted model APIs.
|
||||||
|
ProviderSpec(
|
||||||
|
name="novita",
|
||||||
|
keywords=("novita",),
|
||||||
|
env_key="NOVITA_API_KEY",
|
||||||
|
display_name="Novita AI",
|
||||||
|
backend="openai_compat",
|
||||||
|
is_gateway=True,
|
||||||
|
detect_by_base_keyword="novita",
|
||||||
|
default_api_base="https://api.novita.ai/openai",
|
||||||
|
),
|
||||||
|
|
||||||
# VolcEngine (火山引擎): OpenAI-compatible gateway, pay-per-use models
|
# VolcEngine (火山引擎): OpenAI-compatible gateway, pay-per-use models
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
name="volcengine",
|
name="volcengine",
|
||||||
|
|||||||
@ -1,5 +1,9 @@
|
|||||||
# Agent Instructions
|
# Agent Instructions
|
||||||
|
|
||||||
|
## Workspace Guidance
|
||||||
|
|
||||||
|
Use this file for project-specific preferences, recurring workflow conventions, and instructions you want the agent to remember for this workspace. Keep durable facts about the user in `USER.md`, personality/style guidance in `SOUL.md`, and long-term memory in `memory/MEMORY.md`.
|
||||||
|
|
||||||
## Scheduled Reminders
|
## Scheduled Reminders
|
||||||
|
|
||||||
Before scheduling reminders, check available skills and follow skill guidance first.
|
Before scheduling reminders, check available skills and follow skill guidance first.
|
||||||
@ -10,10 +14,10 @@ Get USER_ID and CHANNEL from the current session (e.g., `8281248569` and `telegr
|
|||||||
|
|
||||||
## Heartbeat Tasks
|
## Heartbeat Tasks
|
||||||
|
|
||||||
`HEARTBEAT.md` is checked on the configured heartbeat interval. Use file tools to manage periodic tasks:
|
`HEARTBEAT.md` is checked on the configured heartbeat interval. Use file tools to manage periodic tasks.
|
||||||
|
|
||||||
- **Add**: `edit_file` to append new tasks
|
- Use `apply_patch` for normal task-list updates, especially when adding, removing, or changing multiple lines.
|
||||||
- **Remove**: `edit_file` to delete completed tasks
|
- Use `edit_file` only for small exact replacements copied from the current `HEARTBEAT.md`.
|
||||||
- **Rewrite**: `write_file` to replace all tasks
|
- Use `write_file` for first creation or intentional full-file rewrites.
|
||||||
|
|
||||||
When the user asks for a recurring/periodic task, update `HEARTBEAT.md` instead of creating a one-time cron reminder.
|
When the user asks for a recurring/periodic task, update `HEARTBEAT.md` instead of creating a one-time cron reminder.
|
||||||
|
|||||||
@ -1,28 +0,0 @@
|
|||||||
# Tool Usage Notes
|
|
||||||
|
|
||||||
Tool signatures are provided automatically via function calling.
|
|
||||||
This file documents non-obvious constraints and usage patterns.
|
|
||||||
|
|
||||||
## exec — Safety Limits
|
|
||||||
|
|
||||||
- Commands have a configurable timeout (default 60s)
|
|
||||||
- Dangerous commands are blocked (rm -rf, format, dd, shutdown, etc.)
|
|
||||||
- Output is truncated at 10,000 characters
|
|
||||||
- `restrictToWorkspace` config can limit file access to the workspace
|
|
||||||
|
|
||||||
## grep — Content Search
|
|
||||||
|
|
||||||
- Use `grep` to search file contents inside the workspace
|
|
||||||
- Default behavior returns only matching file paths (`output_mode="files_with_matches"`)
|
|
||||||
- Supports optional `glob` filtering (e.g. `glob="*.py"`) plus `context_before` / `context_after`
|
|
||||||
- Supports `type="py"`, `type="ts"`, `type="md"` and similar shorthand filters
|
|
||||||
- Use `fixed_strings=true` for literal keywords containing regex characters
|
|
||||||
- Use `output_mode="files_with_matches"` to get only matching file paths
|
|
||||||
- Use `output_mode="count"` to size a search before reading full matches
|
|
||||||
- Use `head_limit` and `offset` to page across results
|
|
||||||
- Prefer this over `exec` for code and history searches
|
|
||||||
- Binary or oversized files may be skipped to keep results readable
|
|
||||||
|
|
||||||
## cron — Scheduled Reminders
|
|
||||||
|
|
||||||
- Please refer to cron skill for usage.
|
|
||||||
60
nanobot/templates/agent/tool_contract.md
Normal file
60
nanobot/templates/agent/tool_contract.md
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
# Tool Usage Notes
|
||||||
|
|
||||||
|
Tool signatures are provided automatically via function calling. This section
|
||||||
|
documents the general tool contract and non-obvious usage patterns.
|
||||||
|
|
||||||
|
## General Tool Contract
|
||||||
|
|
||||||
|
- Use the narrowest structured tool that directly matches the task.
|
||||||
|
- Use read-only discovery before writes when state is uncertain.
|
||||||
|
- Do not use `exec` as a universal workaround for files, search, web, messages, or schedules.
|
||||||
|
- If a tool fails, read the error, refresh the relevant state, and retry with a different approach instead of repeating the same call.
|
||||||
|
- After meaningful changes, verify with the smallest reliable check: re-read changed state, run targeted tests, or inspect command output.
|
||||||
|
- Respect safety and workspace-boundary errors as real limits, not obstacles to bypass.
|
||||||
|
|
||||||
|
## Discovery and Reading
|
||||||
|
|
||||||
|
- Use `find_files` or `list_dir` to locate workspace paths before `read_file` when a path is uncertain.
|
||||||
|
- Use `grep` for content search inside the workspace; prefer it over shell grep for ordinary searches.
|
||||||
|
- `grep` defaults to `output_mode="files_with_matches"`; use `output_mode="content"` for matching lines with context.
|
||||||
|
- Use `fixed_strings=true` for literal keywords containing regex characters.
|
||||||
|
- Use `output_mode="count"` to size a broad search before reading full matches.
|
||||||
|
- Use `head_limit` and `offset` to page across large result sets.
|
||||||
|
- Binary or oversized files may be skipped to keep results readable.
|
||||||
|
|
||||||
|
## File and Coding Workflows
|
||||||
|
|
||||||
|
- For code or config changes, the default loop is: locate (`find_files`/`grep`), inspect (`read_file`), edit (`apply_patch`), then verify (`exec` or re-read).
|
||||||
|
- Use `apply_patch` as the default code editing tool, especially for multi-file changes, structural edits, generated code, moves, adds, or deletes.
|
||||||
|
- Use `apply_patch dry_run=true` when the patch is uncertain and you want validation plus a change summary before writing.
|
||||||
|
- Use `edit_file` only for small exact replacements in one file, with `old_text` copied from `read_file`; add `occurrence`, `line_hint`, or `expected_replacements` when ambiguity matters.
|
||||||
|
- Use `write_file` for new files or intentional full-file rewrites, not routine partial edits.
|
||||||
|
- If `apply_patch` or `edit_file` fails, re-read with `force=true`, narrow the context, and try a smaller patch rather than switching to shell `sed` or `echo`.
|
||||||
|
|
||||||
|
## Process Execution
|
||||||
|
|
||||||
|
- Use `exec` for tests, builds, package commands, git commands, and other process execution.
|
||||||
|
- Prefer dedicated file/search tools over `cat`, shell `find`, shell `grep`, `sed`, or `echo` for ordinary workspace inspection and edits.
|
||||||
|
- Use non-interactive flags such as `-y` or `--yes` when available.
|
||||||
|
- Commands have a configurable timeout (default 60s), dangerous commands are blocked, and output is truncated.
|
||||||
|
- For long-running or interactive commands, pass `yield_time_ms`; if the process keeps running, continue with `write_stdin`.
|
||||||
|
- Use `write_stdin` to poll, provide stdin, close stdin, wait for expected output with `wait_for`, or terminate an existing exec session.
|
||||||
|
- Use `list_exec_sessions` to recover active session IDs after context shifts.
|
||||||
|
|
||||||
|
## Web and External Information
|
||||||
|
|
||||||
|
- Use web tools when the user asks for current information, a specific URL, or information likely to have changed.
|
||||||
|
- Use `web_search` to find sources and `web_fetch` for a specific page or result that needs closer reading.
|
||||||
|
- Do not invent freshness-sensitive facts when tools can verify them.
|
||||||
|
|
||||||
|
## Messaging and Media
|
||||||
|
|
||||||
|
- Use `message` to send content or local media to the user/channel.
|
||||||
|
- `read_file` only reads content for your analysis; it does not deliver a file to the user.
|
||||||
|
- When sending an existing local file, attach it through the message/media mechanism instead of pasting file contents unless the user asked for text.
|
||||||
|
|
||||||
|
## Scheduling and Background Work
|
||||||
|
|
||||||
|
- Use `cron` for scheduled reminders or recurring jobs; do not run `nanobot cron` through `exec`.
|
||||||
|
- For heartbeat tasks, update `HEARTBEAT.md` according to the agent instructions.
|
||||||
|
- Do not write reminders only to memory files when the user expects an actual notification.
|
||||||
@ -3,15 +3,13 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import difflib
|
import difflib
|
||||||
import json
|
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Awaitable, Callable
|
from typing import Any, Awaitable, Callable
|
||||||
|
|
||||||
|
TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file", "apply_patch"})
|
||||||
TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file", "notebook_edit"})
|
|
||||||
_MAX_SNAPSHOT_BYTES = 2 * 1024 * 1024
|
_MAX_SNAPSHOT_BYTES = 2 * 1024 * 1024
|
||||||
_LIVE_EMIT_INTERVAL_S = 0.18
|
_LIVE_EMIT_INTERVAL_S = 0.18
|
||||||
_LIVE_EMIT_LINE_STEP = 24
|
_LIVE_EMIT_LINE_STEP = 24
|
||||||
@ -154,19 +152,108 @@ def prepare_file_edit_tracker(
|
|||||||
workspace: Path | None,
|
workspace: Path | None,
|
||||||
params: dict[str, Any] | None,
|
params: dict[str, Any] | None,
|
||||||
) -> FileEditTracker | None:
|
) -> FileEditTracker | None:
|
||||||
|
trackers = prepare_file_edit_trackers(
|
||||||
|
call_id=call_id,
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool=tool,
|
||||||
|
workspace=workspace,
|
||||||
|
params=params,
|
||||||
|
)
|
||||||
|
return trackers[0] if trackers else None
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_file_edit_trackers(
|
||||||
|
*,
|
||||||
|
call_id: str,
|
||||||
|
tool_name: str,
|
||||||
|
tool: Any,
|
||||||
|
workspace: Path | None,
|
||||||
|
params: dict[str, Any] | None,
|
||||||
|
) -> list[FileEditTracker]:
|
||||||
if not is_file_edit_tool(tool_name):
|
if not is_file_edit_tool(tool_name):
|
||||||
return None
|
return []
|
||||||
|
paths = resolve_file_edit_paths(tool_name, tool, workspace, params)
|
||||||
|
trackers: list[FileEditTracker] = []
|
||||||
|
seen: set[Path] = set()
|
||||||
|
for path in paths:
|
||||||
|
try:
|
||||||
|
resolved = path.resolve()
|
||||||
|
except Exception:
|
||||||
|
resolved = path
|
||||||
|
if resolved in seen:
|
||||||
|
continue
|
||||||
|
seen.add(resolved)
|
||||||
|
before = read_file_snapshot(path)
|
||||||
|
trackers.append(FileEditTracker(
|
||||||
|
call_id=str(call_id or ""),
|
||||||
|
tool=tool_name,
|
||||||
|
path=path,
|
||||||
|
display_path=display_file_edit_path(path, workspace),
|
||||||
|
before=before,
|
||||||
|
))
|
||||||
|
return trackers
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_file_edit_paths(
|
||||||
|
tool_name: str,
|
||||||
|
tool: Any,
|
||||||
|
workspace: Path | None,
|
||||||
|
params: dict[str, Any] | None,
|
||||||
|
) -> list[Path]:
|
||||||
|
if tool_name == "apply_patch":
|
||||||
|
return _resolve_apply_patch_paths(tool, workspace, params)
|
||||||
path = resolve_file_edit_path(tool, workspace, params)
|
path = resolve_file_edit_path(tool, workspace, params)
|
||||||
if path is None:
|
if path is None:
|
||||||
return None
|
return []
|
||||||
before = read_file_snapshot(path)
|
return [path]
|
||||||
return FileEditTracker(
|
|
||||||
call_id=str(call_id or ""),
|
|
||||||
tool=tool_name,
|
def _resolve_apply_patch_paths(
|
||||||
path=path,
|
tool: Any,
|
||||||
display_path=display_file_edit_path(path, workspace),
|
workspace: Path | None,
|
||||||
before=before,
|
params: dict[str, Any] | None,
|
||||||
)
|
) -> list[Path]:
|
||||||
|
if not isinstance(params, dict):
|
||||||
|
return []
|
||||||
|
edits = params.get("edits")
|
||||||
|
if not isinstance(edits, list) or not edits:
|
||||||
|
return []
|
||||||
|
if params.get("dry_run") is True:
|
||||||
|
return []
|
||||||
|
|
||||||
|
resolved: list[Path] = []
|
||||||
|
seen: set[Path] = set()
|
||||||
|
for edit in edits:
|
||||||
|
if not isinstance(edit, dict):
|
||||||
|
continue
|
||||||
|
raw_path = edit.get("path")
|
||||||
|
if not isinstance(raw_path, str) or not raw_path.strip():
|
||||||
|
continue
|
||||||
|
path = _resolve_raw_file_edit_path(tool, workspace, raw_path)
|
||||||
|
if path is not None and path not in seen:
|
||||||
|
seen.add(path)
|
||||||
|
resolved.append(path)
|
||||||
|
return resolved
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_raw_file_edit_path(
|
||||||
|
tool: Any,
|
||||||
|
workspace: Path | None,
|
||||||
|
raw_path: str,
|
||||||
|
) -> Path | None:
|
||||||
|
resolver = getattr(tool, "_resolve", None)
|
||||||
|
if callable(resolver):
|
||||||
|
try:
|
||||||
|
resolved = resolver(raw_path)
|
||||||
|
if isinstance(resolved, Path):
|
||||||
|
return resolved
|
||||||
|
if resolved:
|
||||||
|
return Path(resolved)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
if workspace is None:
|
||||||
|
return Path(raw_path).expanduser().resolve()
|
||||||
|
return (workspace / raw_path).expanduser().resolve()
|
||||||
|
|
||||||
|
|
||||||
def build_file_edit_start_event(
|
def build_file_edit_start_event(
|
||||||
@ -304,6 +391,9 @@ class StreamingFileEditTracker:
|
|||||||
self._states[key] = state
|
self._states[key] = state
|
||||||
|
|
||||||
state.apply_delta(payload)
|
state.apply_delta(payload)
|
||||||
|
if state.name == "apply_patch":
|
||||||
|
await self._update_apply_patch(state)
|
||||||
|
return
|
||||||
if state.name not in {"write_file", "edit_file"}:
|
if state.name not in {"write_file", "edit_file"}:
|
||||||
return
|
return
|
||||||
if state.path is None:
|
if state.path is None:
|
||||||
@ -343,10 +433,80 @@ class StreamingFileEditTracker:
|
|||||||
deleted=deleted,
|
deleted=deleted,
|
||||||
)])
|
)])
|
||||||
|
|
||||||
|
async def _update_apply_patch(self, state: _StreamingFileEditState) -> None:
|
||||||
|
if _json_bool_true(state.arguments, "dry_run"):
|
||||||
|
return
|
||||||
|
tool = self._tools.get("apply_patch") if hasattr(self._tools, "get") else None
|
||||||
|
events: list[dict[str, Any]] = []
|
||||||
|
now = time.monotonic()
|
||||||
|
|
||||||
|
path_matches = list(re.finditer(r'"path"\s*:\s*"([^"]+)"', state.arguments))
|
||||||
|
if not path_matches:
|
||||||
|
return
|
||||||
|
|
||||||
|
for i, m in enumerate(path_matches):
|
||||||
|
raw_path = m.group(1)
|
||||||
|
path = _resolve_raw_file_edit_path(tool, self._workspace, raw_path)
|
||||||
|
if path is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
segment_start = m.start()
|
||||||
|
segment_end = path_matches[i + 1].start() if i + 1 < len(path_matches) else len(state.arguments)
|
||||||
|
segment = state.arguments[segment_start:segment_end]
|
||||||
|
|
||||||
|
action_match = re.search(r'"action"\s*:\s*"(replace|add|delete)"', segment)
|
||||||
|
action = action_match.group(1) if action_match else "replace"
|
||||||
|
|
||||||
|
old_text = _extract_json_string_prefix(segment, "old_text") or ""
|
||||||
|
new_text = _extract_json_string_prefix(segment, "new_text") or ""
|
||||||
|
|
||||||
|
added = _text_line_count(new_text) if action in ("replace", "add") else 0
|
||||||
|
deleted = _text_line_count(old_text) if action in ("replace", "delete") else 0
|
||||||
|
delete_file = action == "delete"
|
||||||
|
|
||||||
|
file_state = state.patch_files.get(raw_path)
|
||||||
|
if file_state is None:
|
||||||
|
tracker = FileEditTracker(
|
||||||
|
call_id=state.call_id or state.key,
|
||||||
|
tool="apply_patch",
|
||||||
|
path=path,
|
||||||
|
display_path=display_file_edit_path(path, self._workspace),
|
||||||
|
before=read_file_snapshot(path),
|
||||||
|
)
|
||||||
|
file_state = _StreamingPatchFileState(tracker=tracker)
|
||||||
|
state.patch_files[raw_path] = file_state
|
||||||
|
if delete_file and added == 0 and deleted == 0 and file_state.tracker.before.countable:
|
||||||
|
deleted = _text_line_count(file_state.tracker.before.text or "")
|
||||||
|
if not file_state.should_emit(added, deleted, now):
|
||||||
|
continue
|
||||||
|
file_state.mark_emitted(added, deleted, now)
|
||||||
|
events.append(build_file_edit_live_event(
|
||||||
|
file_state.tracker,
|
||||||
|
added=added,
|
||||||
|
deleted=deleted,
|
||||||
|
))
|
||||||
|
if events:
|
||||||
|
await self._emit(events)
|
||||||
|
|
||||||
async def flush(self) -> None:
|
async def flush(self) -> None:
|
||||||
events: list[dict[str, Any]] = []
|
events: list[dict[str, Any]] = []
|
||||||
now = time.monotonic()
|
now = time.monotonic()
|
||||||
for state in self._states.values():
|
for state in self._states.values():
|
||||||
|
for file_state in state.patch_files.values():
|
||||||
|
added, deleted = file_state.last_added, file_state.last_deleted
|
||||||
|
if not file_state.emitted_once:
|
||||||
|
continue
|
||||||
|
if (
|
||||||
|
file_state.last_emitted_added == added
|
||||||
|
and file_state.last_emitted_deleted == deleted
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
file_state.mark_emitted(added, deleted, now)
|
||||||
|
events.append(build_file_edit_live_event(
|
||||||
|
file_state.tracker,
|
||||||
|
added=added,
|
||||||
|
deleted=deleted,
|
||||||
|
))
|
||||||
if state.tracker is None:
|
if state.tracker is None:
|
||||||
continue
|
continue
|
||||||
added, deleted = state.live_diff_counts()
|
added, deleted = state.live_diff_counts()
|
||||||
@ -367,12 +527,14 @@ class StreamingFileEditTracker:
|
|||||||
|
|
||||||
def apply_final_call_ids(self, final_tool_calls: list[Any]) -> None:
|
def apply_final_call_ids(self, final_tool_calls: list[Any]) -> None:
|
||||||
"""Keep final start/end events keyed to any earlier streamed placeholder."""
|
"""Keep final start/end events keyed to any earlier streamed placeholder."""
|
||||||
|
used_canonicals: set[str] = set()
|
||||||
for tool_call in final_tool_calls:
|
for tool_call in final_tool_calls:
|
||||||
canonical = self.canonical_call_id_for(tool_call)
|
canonical = self.canonical_call_id_for(tool_call)
|
||||||
if canonical:
|
if canonical and canonical not in used_canonicals:
|
||||||
try:
|
try:
|
||||||
tool_call.id = canonical
|
tool_call.id = canonical
|
||||||
except Exception:
|
used_canonicals.add(canonical)
|
||||||
|
except (AttributeError, TypeError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def canonical_call_id_for(self, tool_call: Any) -> str | None:
|
def canonical_call_id_for(self, tool_call: Any) -> str | None:
|
||||||
@ -389,6 +551,10 @@ class StreamingFileEditTracker:
|
|||||||
"""Mark streamed edits as failed when no final tool call will run."""
|
"""Mark streamed edits as failed when no final tool call will run."""
|
||||||
events: list[dict[str, Any]] = []
|
events: list[dict[str, Any]] = []
|
||||||
for state in self._states.values():
|
for state in self._states.values():
|
||||||
|
for file_state in state.patch_files.values():
|
||||||
|
if any(state.matches_final_tool_call(tool_call) for tool_call in final_tool_calls):
|
||||||
|
continue
|
||||||
|
events.append(build_file_edit_error_event(file_state.tracker, error))
|
||||||
if state.tracker is None:
|
if state.tracker is None:
|
||||||
continue
|
continue
|
||||||
if any(state.matches_final_tool_call(tool_call) for tool_call in final_tool_calls):
|
if any(state.matches_final_tool_call(tool_call) for tool_call in final_tool_calls):
|
||||||
@ -492,6 +658,39 @@ class _StreamingJsonStringField:
|
|||||||
self.last_char_cr = False
|
self.last_char_cr = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class _StreamingPatchFileState:
|
||||||
|
tracker: FileEditTracker
|
||||||
|
emitted_once: bool = False
|
||||||
|
last_emitted_added: int = -1
|
||||||
|
last_emitted_deleted: int = -1
|
||||||
|
last_emit_at: float = 0.0
|
||||||
|
last_added: int = 0
|
||||||
|
last_deleted: int = 0
|
||||||
|
|
||||||
|
def should_emit(self, added: int, deleted: int, now: float) -> bool:
|
||||||
|
self.last_added = added
|
||||||
|
self.last_deleted = deleted
|
||||||
|
if not self.emitted_once:
|
||||||
|
return True
|
||||||
|
if added == self.last_emitted_added and deleted == self.last_emitted_deleted:
|
||||||
|
return False
|
||||||
|
if max(
|
||||||
|
abs(added - self.last_emitted_added),
|
||||||
|
abs(deleted - self.last_emitted_deleted),
|
||||||
|
) >= _LIVE_EMIT_LINE_STEP:
|
||||||
|
return True
|
||||||
|
return now - self.last_emit_at >= _LIVE_EMIT_INTERVAL_S
|
||||||
|
|
||||||
|
def mark_emitted(self, added: int, deleted: int, now: float) -> None:
|
||||||
|
self.emitted_once = True
|
||||||
|
self.last_added = added
|
||||||
|
self.last_deleted = deleted
|
||||||
|
self.last_emitted_added = added
|
||||||
|
self.last_emitted_deleted = deleted
|
||||||
|
self.last_emit_at = now
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class _StreamingFileEditState:
|
class _StreamingFileEditState:
|
||||||
key: str
|
key: str
|
||||||
@ -509,6 +708,7 @@ class _StreamingFileEditState:
|
|||||||
new_text: _StreamingJsonStringField = field(
|
new_text: _StreamingJsonStringField = field(
|
||||||
default_factory=lambda: _StreamingJsonStringField("new_text")
|
default_factory=lambda: _StreamingJsonStringField("new_text")
|
||||||
)
|
)
|
||||||
|
patch_files: dict[str, _StreamingPatchFileState] = field(default_factory=dict)
|
||||||
emitted_once: bool = False
|
emitted_once: bool = False
|
||||||
last_emitted_added: int = -1
|
last_emitted_added: int = -1
|
||||||
last_emitted_deleted: int = -1
|
last_emitted_deleted: int = -1
|
||||||
@ -531,6 +731,7 @@ class _StreamingFileEditState:
|
|||||||
self.content.reset()
|
self.content.reset()
|
||||||
self.old_text.reset()
|
self.old_text.reset()
|
||||||
self.new_text.reset()
|
self.new_text.reset()
|
||||||
|
self.patch_files.clear()
|
||||||
return
|
return
|
||||||
delta = payload.get("arguments_delta")
|
delta = payload.get("arguments_delta")
|
||||||
if isinstance(delta, str) and delta:
|
if isinstance(delta, str) and delta:
|
||||||
@ -590,6 +791,14 @@ class _StreamingFileEditState:
|
|||||||
name = getattr(tool_call, "name", None)
|
name = getattr(tool_call, "name", None)
|
||||||
if name != self.name:
|
if name != self.name:
|
||||||
return False
|
return False
|
||||||
|
if self.name == "apply_patch":
|
||||||
|
arguments = getattr(tool_call, "arguments", None)
|
||||||
|
if not isinstance(arguments, dict):
|
||||||
|
return False
|
||||||
|
edits = arguments.get("edits")
|
||||||
|
if not isinstance(edits, list):
|
||||||
|
return False
|
||||||
|
return '"edits"' in self.arguments
|
||||||
arguments = getattr(tool_call, "arguments", None)
|
arguments = getattr(tool_call, "arguments", None)
|
||||||
if not isinstance(arguments, dict):
|
if not isinstance(arguments, dict):
|
||||||
return False
|
return False
|
||||||
@ -612,6 +821,51 @@ def _stream_key(payload: dict[str, Any]) -> str:
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _json_bool_true(source: str, key: str) -> bool:
|
||||||
|
return re.search(rf'"{re.escape(key)}"\s*:\s*true\b', source) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_json_string_prefix(source: str, key: str) -> str | None:
|
||||||
|
match = re.search(rf'"{re.escape(key)}"\s*:\s*"', source)
|
||||||
|
if match is None:
|
||||||
|
return None
|
||||||
|
out: list[str] = []
|
||||||
|
i = match.end()
|
||||||
|
escape = False
|
||||||
|
while i < len(source):
|
||||||
|
ch = source[i]
|
||||||
|
if escape:
|
||||||
|
escape = False
|
||||||
|
if ch == "n":
|
||||||
|
out.append("\n")
|
||||||
|
elif ch == "r":
|
||||||
|
out.append("\r")
|
||||||
|
elif ch == "t":
|
||||||
|
out.append("\t")
|
||||||
|
elif ch == "u":
|
||||||
|
digits = source[i + 1:i + 5]
|
||||||
|
if len(digits) < 4:
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
out.append(chr(int(digits, 16)))
|
||||||
|
except ValueError:
|
||||||
|
break
|
||||||
|
i += 4
|
||||||
|
else:
|
||||||
|
out.append(ch)
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
if ch == "\\":
|
||||||
|
escape = True
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
if ch == '"':
|
||||||
|
return "".join(out)
|
||||||
|
out.append(ch)
|
||||||
|
i += 1
|
||||||
|
return "".join(out)
|
||||||
|
|
||||||
|
|
||||||
def _extract_complete_json_string(source: str, key: str) -> str | None:
|
def _extract_complete_json_string(source: str, key: str) -> str | None:
|
||||||
match = re.search(rf'"{re.escape(key)}"\s*:\s*"', source)
|
match = re.search(rf'"{re.escape(key)}"\s*:\s*"', source)
|
||||||
if match is None:
|
if match is None:
|
||||||
@ -704,77 +958,4 @@ def _predict_after_text(
|
|||||||
return before_text.replace(old_text, new_text)
|
return before_text.replace(old_text, new_text)
|
||||||
return before_text.replace(old_text, new_text, 1)
|
return before_text.replace(old_text, new_text, 1)
|
||||||
return None
|
return None
|
||||||
if tool_name == "notebook_edit":
|
|
||||||
return _predict_notebook_after_text(params, before_text)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _predict_notebook_after_text(params: dict[str, Any], before_text: str) -> str | None:
|
|
||||||
try:
|
|
||||||
nb = json.loads(before_text) if before_text.strip() else _empty_notebook()
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
cells = nb.get("cells")
|
|
||||||
if not isinstance(cells, list):
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
cell_index = int(params.get("cell_index", 0))
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
return None
|
|
||||||
new_source = params.get("new_source")
|
|
||||||
source = new_source if isinstance(new_source, str) else ""
|
|
||||||
cell_type = (
|
|
||||||
params.get("cell_type") if params.get("cell_type") in ("code", "markdown") else "code"
|
|
||||||
)
|
|
||||||
mode = (
|
|
||||||
params.get("edit_mode")
|
|
||||||
if params.get("edit_mode") in ("replace", "insert", "delete")
|
|
||||||
else "replace"
|
|
||||||
)
|
|
||||||
if mode == "delete":
|
|
||||||
if 0 <= cell_index < len(cells):
|
|
||||||
cells.pop(cell_index)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
elif mode == "insert":
|
|
||||||
insert_at = min(max(cell_index + 1, 0), len(cells))
|
|
||||||
cells.insert(insert_at, _new_notebook_cell(source, str(cell_type)))
|
|
||||||
else:
|
|
||||||
if not (0 <= cell_index < len(cells)):
|
|
||||||
return None
|
|
||||||
cell = cells[cell_index]
|
|
||||||
if not isinstance(cell, dict):
|
|
||||||
return None
|
|
||||||
cell["source"] = source
|
|
||||||
cell["cell_type"] = cell_type
|
|
||||||
if cell_type == "code":
|
|
||||||
cell.setdefault("outputs", [])
|
|
||||||
cell.setdefault("execution_count", None)
|
|
||||||
else:
|
|
||||||
cell.pop("outputs", None)
|
|
||||||
cell.pop("execution_count", None)
|
|
||||||
nb["cells"] = cells
|
|
||||||
try:
|
|
||||||
return json.dumps(nb, indent=1, ensure_ascii=False)
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _empty_notebook() -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5,
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
|
|
||||||
"language_info": {"name": "python"},
|
|
||||||
},
|
|
||||||
"cells": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _new_notebook_cell(source: str, cell_type: str) -> dict[str, Any]:
|
|
||||||
cell: dict[str, Any] = {"cell_type": cell_type, "source": source, "metadata": {}}
|
|
||||||
if cell_type == "code":
|
|
||||||
cell["outputs"] = []
|
|
||||||
cell["execution_count"] = None
|
|
||||||
return cell
|
|
||||||
|
|||||||
@ -576,7 +576,7 @@ def build_status_content(
|
|||||||
|
|
||||||
|
|
||||||
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
|
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
|
||||||
"""Sync bundled templates to workspace. Only creates missing files."""
|
"""Sync bundled templates to workspace. Creates missing files without overwriting user files."""
|
||||||
from importlib.resources import files as pkg_files
|
from importlib.resources import files as pkg_files
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -589,10 +589,11 @@ def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]
|
|||||||
added: list[str] = []
|
added: list[str] = []
|
||||||
|
|
||||||
def _write(src, dest: Path):
|
def _write(src, dest: Path):
|
||||||
|
content = src.read_text(encoding="utf-8") if src else ""
|
||||||
if dest.exists():
|
if dest.exists():
|
||||||
return
|
return
|
||||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||||
dest.write_text(src.read_text(encoding="utf-8") if src else "", encoding="utf-8")
|
dest.write_text(content, encoding="utf-8")
|
||||||
added.append(str(dest.relative_to(workspace)))
|
added.append(str(dest.relative_to(workspace)))
|
||||||
|
|
||||||
for item in tpl.iterdir():
|
for item in tpl.iterdir():
|
||||||
|
|||||||
@ -11,8 +11,10 @@ _TOOL_FORMATS: dict[str, tuple[list[str], str, bool, bool]] = {
|
|||||||
"read_file": (["path", "file_path"], "read {}", True, False),
|
"read_file": (["path", "file_path"], "read {}", True, False),
|
||||||
"write_file": (["path", "file_path"], "write {}", True, False),
|
"write_file": (["path", "file_path"], "write {}", True, False),
|
||||||
"edit": (["file_path", "path"], "edit {}", True, False),
|
"edit": (["file_path", "path"], "edit {}", True, False),
|
||||||
|
"find_files": (["query", "glob", "path"], "find {}", False, False),
|
||||||
"grep": (["pattern"], 'grep "{}"', False, False),
|
"grep": (["pattern"], 'grep "{}"', False, False),
|
||||||
"exec": (["command"], "$ {}", False, True),
|
"exec": (["command"], "$ {}", False, True),
|
||||||
|
"list_exec_sessions": ([], "exec sessions", False, False),
|
||||||
"web_search": (["query"], 'search "{}"', False, False),
|
"web_search": (["query"], 'search "{}"', False, False),
|
||||||
"web_fetch": (["url"], "fetch {}", True, False),
|
"web_fetch": (["url"], "fetch {}", True, False),
|
||||||
"list_dir": (["path"], "ls {}", True, False),
|
"list_dir": (["path"], "ls {}", True, False),
|
||||||
@ -81,6 +83,8 @@ def _extract_arg(tc, key_args: list[str]) -> str | None:
|
|||||||
|
|
||||||
def _fmt_known(tc, fmt: tuple, max_length: int = 40) -> str:
|
def _fmt_known(tc, fmt: tuple, max_length: int = 40) -> str:
|
||||||
"""Format a registered tool using its template."""
|
"""Format a registered tool using its template."""
|
||||||
|
if not fmt[0] and "{}" not in fmt[1]:
|
||||||
|
return fmt[1]
|
||||||
val = _extract_arg(tc, fmt[0])
|
val = _extract_arg(tc, fmt[0])
|
||||||
if val is None:
|
if val is None:
|
||||||
return tc.name
|
return tc.name
|
||||||
|
|||||||
@ -73,12 +73,16 @@ def _mask_secret_hint(secret: str | None) -> str | None:
|
|||||||
def _provider_requires_api_key(spec: Any) -> bool:
|
def _provider_requires_api_key(spec: Any) -> bool:
|
||||||
if spec.backend == "azure_openai":
|
if spec.backend == "azure_openai":
|
||||||
return True
|
return True
|
||||||
|
if spec.is_oauth:
|
||||||
|
return False
|
||||||
if spec.is_local or spec.is_direct:
|
if spec.is_local or spec.is_direct:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _provider_configured_for_settings(spec: Any, provider_config: Any) -> bool:
|
def _provider_configured_for_settings(spec: Any, provider_config: Any) -> bool:
|
||||||
|
if spec.is_oauth:
|
||||||
|
return True
|
||||||
if _provider_requires_api_key(spec):
|
if _provider_requires_api_key(spec):
|
||||||
return bool(provider_config.api_key)
|
return bool(provider_config.api_key)
|
||||||
return bool(
|
return bool(
|
||||||
|
|||||||
@ -139,6 +139,13 @@ class TestLoadBootstrapFiles:
|
|||||||
for name in ContextBuilder.BOOTSTRAP_FILES:
|
for name in ContextBuilder.BOOTSTRAP_FILES:
|
||||||
assert f"## {name}" in result
|
assert f"## {name}" in result
|
||||||
|
|
||||||
|
def test_legacy_tools_md_is_not_bootstrapped(self, tmp_path):
|
||||||
|
(tmp_path / "TOOLS.md").write_text("workspace tool notes", encoding="utf-8")
|
||||||
|
builder = _builder(tmp_path)
|
||||||
|
result = builder._load_bootstrap_files()
|
||||||
|
assert "TOOLS.md" not in result
|
||||||
|
assert "workspace tool notes" not in result
|
||||||
|
|
||||||
def test_utf8_content(self, tmp_path):
|
def test_utf8_content(self, tmp_path):
|
||||||
(tmp_path / "AGENTS.md").write_text("用中文回复", encoding="utf-8")
|
(tmp_path / "AGENTS.md").write_text("用中文回复", encoding="utf-8")
|
||||||
builder = _builder(tmp_path)
|
builder = _builder(tmp_path)
|
||||||
@ -171,6 +178,37 @@ class TestIsTemplateContent:
|
|||||||
assert ContextBuilder._is_template_content("totally different", "memory/MEMORY.md") is False
|
assert ContextBuilder._is_template_content("totally different", "memory/MEMORY.md") is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Bundled bootstrap templates
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBundledToolContract:
|
||||||
|
def test_tool_contract_balances_general_and_coding_workflows(self):
|
||||||
|
from importlib.resources import files as pkg_files
|
||||||
|
|
||||||
|
tpl = pkg_files("nanobot") / "templates" / "agent" / "tool_contract.md"
|
||||||
|
content = tpl.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
assert "## General Tool Contract" in content
|
||||||
|
assert "Use the narrowest structured tool" in content
|
||||||
|
assert "Do not use `exec` as a universal workaround" in content
|
||||||
|
assert "## File and Coding Workflows" in content
|
||||||
|
assert "apply_patch" in content
|
||||||
|
assert "## Web and External Information" in content
|
||||||
|
assert "## Messaging and Media" in content
|
||||||
|
assert "## Scheduling and Background Work" in content
|
||||||
|
assert "pure coding" not in content.lower()
|
||||||
|
|
||||||
|
def test_tool_contract_is_injected_without_workspace_file(self, tmp_path):
|
||||||
|
builder = _builder(tmp_path)
|
||||||
|
prompt = builder.build_system_prompt()
|
||||||
|
|
||||||
|
assert "# Tool Usage Notes" in prompt
|
||||||
|
assert "## General Tool Contract" in prompt
|
||||||
|
assert "Do not use `exec` as a universal workaround" in prompt
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# _build_user_content
|
# _build_user_content
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@ -346,6 +346,26 @@ class TestSyncWorkspaceTemplates:
|
|||||||
content = (workspace / "AGENTS.md").read_text()
|
content = (workspace / "AGENTS.md").read_text()
|
||||||
assert content == "existing content"
|
assert content == "existing content"
|
||||||
|
|
||||||
|
def test_does_not_create_tools_md(self, tmp_path):
|
||||||
|
"""Tool contract is injected internally, not copied into user workspaces."""
|
||||||
|
workspace = tmp_path / "workspace"
|
||||||
|
|
||||||
|
added = sync_workspace_templates(workspace, silent=True)
|
||||||
|
|
||||||
|
assert "TOOLS.md" not in added
|
||||||
|
assert not (workspace / "TOOLS.md").exists()
|
||||||
|
|
||||||
|
def test_preserves_existing_tools_md_without_overwriting(self, tmp_path):
|
||||||
|
"""Legacy user workspaces may have TOOLS.md; sync should leave it untouched."""
|
||||||
|
workspace = tmp_path / "workspace"
|
||||||
|
workspace.mkdir(parents=True)
|
||||||
|
tools_path = workspace / "TOOLS.md"
|
||||||
|
tools_path.write_text("custom tool notes", encoding="utf-8")
|
||||||
|
|
||||||
|
sync_workspace_templates(workspace, silent=True)
|
||||||
|
|
||||||
|
assert tools_path.read_text(encoding="utf-8") == "custom tool notes"
|
||||||
|
|
||||||
def test_creates_memory_directory(self, tmp_path):
|
def test_creates_memory_directory(self, tmp_path):
|
||||||
"""Should create memory directory structure."""
|
"""Should create memory directory structure."""
|
||||||
workspace = tmp_path / "workspace"
|
workspace = tmp_path / "workspace"
|
||||||
|
|||||||
1514
tests/channels/test_signal_channel.py
Normal file
1514
tests/channels/test_signal_channel.py
Normal file
File diff suppressed because it is too large
Load Diff
525
tests/channels/test_signal_markdown.py
Normal file
525
tests/channels/test_signal_markdown.py
Normal file
@ -0,0 +1,525 @@
|
|||||||
|
"""Unit tests for the Signal markdown → plain text + textStyle converter."""
|
||||||
|
|
||||||
|
from nanobot.channels.signal import _markdown_to_signal, _partition_styles
|
||||||
|
from nanobot.utils.helpers import split_message
|
||||||
|
|
||||||
|
|
||||||
|
def _utf16_len(s: str) -> int:
|
||||||
|
return len(s.encode("utf-16-le")) // 2
|
||||||
|
|
||||||
|
|
||||||
|
def styles_for(plain: str, text_styles: list[str]) -> dict[str, list[str]]:
|
||||||
|
"""Return a dict mapping each styled substring to its style list."""
|
||||||
|
result: dict[str, list[str]] = {}
|
||||||
|
for entry in text_styles:
|
||||||
|
start_s, length_s, style = entry.split(":", 2)
|
||||||
|
start, length = int(start_s), int(length_s)
|
||||||
|
span = plain[start : start + length]
|
||||||
|
result.setdefault(span, []).append(style)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def utf16_styles_for(plain: str, text_styles: list[str]) -> dict[str, list[str]]:
|
||||||
|
"""Like styles_for, but slices `plain` using UTF-16 offsets (Signal's units)."""
|
||||||
|
encoded = plain.encode("utf-16-le")
|
||||||
|
result: dict[str, list[str]] = {}
|
||||||
|
for entry in text_styles:
|
||||||
|
start_s, length_s, style = entry.split(":", 2)
|
||||||
|
start, length = int(start_s), int(length_s)
|
||||||
|
span = encoded[start * 2 : (start + length) * 2].decode("utf-16-le")
|
||||||
|
result.setdefault(span, []).append(style)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Basic cases
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty():
|
||||||
|
plain, styles = _markdown_to_signal("")
|
||||||
|
assert plain == ""
|
||||||
|
assert styles == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_plain_text():
|
||||||
|
plain, styles = _markdown_to_signal("hello world")
|
||||||
|
assert plain == "hello world"
|
||||||
|
assert styles == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_bold_stars():
|
||||||
|
plain, styles = _markdown_to_signal("say **hello** now")
|
||||||
|
assert plain == "say hello now"
|
||||||
|
assert styles_for(plain, styles) == {"hello": ["BOLD"]}
|
||||||
|
|
||||||
|
|
||||||
|
def test_bold_underscores():
|
||||||
|
plain, styles = _markdown_to_signal("say __hello__ now")
|
||||||
|
assert plain == "say hello now"
|
||||||
|
assert styles_for(plain, styles) == {"hello": ["BOLD"]}
|
||||||
|
|
||||||
|
|
||||||
|
def test_italic_star():
|
||||||
|
plain, styles = _markdown_to_signal("say *hello* now")
|
||||||
|
assert plain == "say hello now"
|
||||||
|
assert styles_for(plain, styles) == {"hello": ["ITALIC"]}
|
||||||
|
|
||||||
|
|
||||||
|
def test_italic_underscore():
|
||||||
|
plain, styles = _markdown_to_signal("say _hello_ now")
|
||||||
|
assert plain == "say hello now"
|
||||||
|
assert styles_for(plain, styles) == {"hello": ["ITALIC"]}
|
||||||
|
|
||||||
|
|
||||||
|
def test_strikethrough():
|
||||||
|
plain, styles = _markdown_to_signal("say ~~hello~~ now")
|
||||||
|
assert plain == "say hello now"
|
||||||
|
assert styles_for(plain, styles) == {"hello": ["STRIKETHROUGH"]}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Code
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_inline_code():
|
||||||
|
plain, styles = _markdown_to_signal("run `ls -la` here")
|
||||||
|
assert plain == "run ls -la here"
|
||||||
|
assert styles_for(plain, styles) == {"ls -la": ["MONOSPACE"]}
|
||||||
|
|
||||||
|
|
||||||
|
def test_code_block():
|
||||||
|
plain, styles = _markdown_to_signal("```\nprint('hi')\n```")
|
||||||
|
assert "print('hi')" in plain
|
||||||
|
assert styles_for(plain, styles).get("print('hi')\n") == ["MONOSPACE"] or "MONOSPACE" in str(
|
||||||
|
styles_for(plain, styles)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_code_block_with_lang():
|
||||||
|
plain, styles = _markdown_to_signal("```python\ncode\n```")
|
||||||
|
assert "code" in plain
|
||||||
|
assert any("MONOSPACE" in s for s in styles)
|
||||||
|
|
||||||
|
|
||||||
|
def test_code_block_not_processed_further():
|
||||||
|
"""Markdown inside a code block must not be styled."""
|
||||||
|
plain, styles = _markdown_to_signal("```\n**not bold**\n```")
|
||||||
|
assert "**not bold**" in plain
|
||||||
|
# Only MONOSPACE should be applied, no BOLD
|
||||||
|
for entry in styles:
|
||||||
|
assert "BOLD" not in entry
|
||||||
|
|
||||||
|
|
||||||
|
def test_inline_code_not_processed_further():
|
||||||
|
"""Markdown inside inline code must not be styled."""
|
||||||
|
plain, styles = _markdown_to_signal("use `**raw**` please")
|
||||||
|
assert "**raw**" in plain
|
||||||
|
for entry in styles:
|
||||||
|
assert "BOLD" not in entry
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Headers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_header_becomes_bold():
|
||||||
|
plain, styles = _markdown_to_signal("# My Title")
|
||||||
|
assert plain == "My Title"
|
||||||
|
assert styles_for(plain, styles) == {"My Title": ["BOLD"]}
|
||||||
|
|
||||||
|
|
||||||
|
def test_h2_becomes_bold():
|
||||||
|
plain, styles = _markdown_to_signal("## Sub-section")
|
||||||
|
assert plain == "Sub-section"
|
||||||
|
assert styles_for(plain, styles) == {"Sub-section": ["BOLD"]}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Blockquotes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_blockquote_strips_marker():
|
||||||
|
plain, styles = _markdown_to_signal("> some quote")
|
||||||
|
assert plain == "some quote"
|
||||||
|
assert styles == []
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Lists
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_bullet_dash():
|
||||||
|
plain, styles = _markdown_to_signal("- item one")
|
||||||
|
assert plain == "• item one"
|
||||||
|
|
||||||
|
|
||||||
|
def test_bullet_star():
|
||||||
|
plain, styles = _markdown_to_signal("* item two")
|
||||||
|
assert plain == "• item two"
|
||||||
|
|
||||||
|
|
||||||
|
def test_numbered_list():
|
||||||
|
plain, styles = _markdown_to_signal("1. first\n2. second")
|
||||||
|
assert "1. first" in plain
|
||||||
|
assert "2. second" in plain
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Links
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_link_text_differs_from_url():
|
||||||
|
plain, styles = _markdown_to_signal("[Click here](https://example.com)")
|
||||||
|
assert plain == "Click here (https://example.com)"
|
||||||
|
assert styles == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_link_text_equals_url():
|
||||||
|
plain, styles = _markdown_to_signal("[https://example.com](https://example.com)")
|
||||||
|
assert plain == "https://example.com"
|
||||||
|
assert styles == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_link_text_equals_url_without_scheme():
|
||||||
|
plain, styles = _markdown_to_signal("[example.com](https://example.com)")
|
||||||
|
assert plain == "https://example.com"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Mixed / nesting
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_bold_and_italic_adjacent():
|
||||||
|
plain, styles = _markdown_to_signal("**bold** and *italic*")
|
||||||
|
assert plain == "bold and italic"
|
||||||
|
sd = styles_for(plain, styles)
|
||||||
|
assert sd.get("bold") == ["BOLD"]
|
||||||
|
assert sd.get("italic") == ["ITALIC"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_header_with_inline_code():
|
||||||
|
"""Header becomes BOLD; code inside becomes MONOSPACE (not double-BOLD)."""
|
||||||
|
plain, styles = _markdown_to_signal("# Use `grep`")
|
||||||
|
assert plain == "Use grep"
|
||||||
|
sd = styles_for(plain, styles)
|
||||||
|
assert "BOLD" in sd.get("Use ", []) or "BOLD" in str(styles)
|
||||||
|
assert "MONOSPACE" in sd.get("grep", [])
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiline_mixed():
|
||||||
|
md = "**Title**\n\nSome *italic* text.\n\n- bullet\n- another"
|
||||||
|
plain, styles = _markdown_to_signal(md)
|
||||||
|
assert "Title" in plain
|
||||||
|
assert "italic" in plain
|
||||||
|
assert "• bullet" in plain
|
||||||
|
sd = styles_for(plain, styles)
|
||||||
|
assert "BOLD" in sd.get("Title", [])
|
||||||
|
assert "ITALIC" in sd.get("italic", [])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Table rendering
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_table_rendered_as_monospace():
|
||||||
|
md = "| A | B |\n| - | - |\n| 1 | 2 |"
|
||||||
|
plain, styles = _markdown_to_signal(md)
|
||||||
|
assert "A" in plain and "B" in plain
|
||||||
|
assert any("MONOSPACE" in s for s in styles)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Style range format
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_style_range_format():
|
||||||
|
"""Each style entry must be 'start:length:STYLE'."""
|
||||||
|
_, styles = _markdown_to_signal("**bold** text")
|
||||||
|
for entry in styles:
|
||||||
|
parts = entry.split(":")
|
||||||
|
assert len(parts) == 3
|
||||||
|
assert parts[0].isdigit()
|
||||||
|
assert parts[1].isdigit()
|
||||||
|
assert parts[2] in {"BOLD", "ITALIC", "STRIKETHROUGH", "MONOSPACE", "SPOILER"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_style_ranges_are_within_bounds():
|
||||||
|
text = "hello **world** end"
|
||||||
|
plain, styles = _markdown_to_signal(text)
|
||||||
|
for entry in styles:
|
||||||
|
start_s, length_s, _ = entry.split(":", 2)
|
||||||
|
start, length = int(start_s), int(length_s)
|
||||||
|
assert start >= 0
|
||||||
|
assert start + length <= len(plain)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Non-BMP / UTF-16 offsets
|
||||||
|
#
|
||||||
|
# Signal's BodyRange (and signal-cli's textStyle) interprets start/length in
|
||||||
|
# UTF-16 code units. Python's len() counts code points, so characters outside
|
||||||
|
# the BMP (emojis, supplementary CJK) shift offsets by +1 per occurrence.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def assert_within_utf16_bounds(plain: str, styles: list[str]) -> None:
|
||||||
|
limit = _utf16_len(plain)
|
||||||
|
for entry in styles:
|
||||||
|
start_s, length_s, _ = entry.split(":", 2)
|
||||||
|
start, length = int(start_s), int(length_s)
|
||||||
|
assert start >= 0
|
||||||
|
assert start + length <= limit, f"range {entry} exceeds utf-16 length {limit} of {plain!r}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_bold_with_emoji_inside():
|
||||||
|
plain, styles = _markdown_to_signal("**hi 🎉 bye**")
|
||||||
|
assert plain == "hi 🎉 bye"
|
||||||
|
assert utf16_styles_for(plain, styles) == {"hi 🎉 bye": ["BOLD"]}
|
||||||
|
assert_within_utf16_bounds(plain, styles)
|
||||||
|
|
||||||
|
|
||||||
|
def test_italic_with_trailing_emoji():
|
||||||
|
plain, styles = _markdown_to_signal("*bye 🎉*")
|
||||||
|
assert plain == "bye 🎉"
|
||||||
|
assert utf16_styles_for(plain, styles) == {"bye 🎉": ["ITALIC"]}
|
||||||
|
assert_within_utf16_bounds(plain, styles)
|
||||||
|
|
||||||
|
|
||||||
|
def test_bold_after_emoji_prefix():
|
||||||
|
plain, styles = _markdown_to_signal("🎉 **bold**")
|
||||||
|
assert plain == "🎉 bold"
|
||||||
|
assert utf16_styles_for(plain, styles) == {"bold": ["BOLD"]}
|
||||||
|
assert_within_utf16_bounds(plain, styles)
|
||||||
|
|
||||||
|
|
||||||
|
def test_bold_after_and_inside_emoji():
|
||||||
|
plain, styles = _markdown_to_signal("🎉 **a 🎊 b**")
|
||||||
|
assert plain == "🎉 a 🎊 b"
|
||||||
|
assert utf16_styles_for(plain, styles) == {"a 🎊 b": ["BOLD"]}
|
||||||
|
assert_within_utf16_bounds(plain, styles)
|
||||||
|
|
||||||
|
|
||||||
|
def test_supplementary_cjk_in_bold():
|
||||||
|
"""Non-BMP CJK (U+20BB7) proves the bug is UTF-16, not emoji-specific."""
|
||||||
|
plain, styles = _markdown_to_signal("**𠮷野家**")
|
||||||
|
assert plain == "𠮷野家"
|
||||||
|
assert utf16_styles_for(plain, styles) == {"𠮷野家": ["BOLD"]}
|
||||||
|
assert_within_utf16_bounds(plain, styles)
|
||||||
|
|
||||||
|
|
||||||
|
def test_zwj_emoji_in_bold():
|
||||||
|
"""ZWJ family sequence = multiple surrogate pairs + BMP ZWJs."""
|
||||||
|
plain, styles = _markdown_to_signal("**hi 👨👩👧 bye**")
|
||||||
|
assert plain == "hi 👨👩👧 bye"
|
||||||
|
assert utf16_styles_for(plain, styles) == {"hi 👨👩👧 bye": ["BOLD"]}
|
||||||
|
assert_within_utf16_bounds(plain, styles)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ascii_offsets_unchanged():
|
||||||
|
"""ASCII-only path must produce the same offsets as before the UTF-16 fix."""
|
||||||
|
plain, styles = _markdown_to_signal("**bold** plain *it*")
|
||||||
|
assert plain == "bold plain it"
|
||||||
|
assert sorted(styles) == sorted(["0:4:BOLD", "11:2:ITALIC"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_reported_daily_brief_pattern():
|
||||||
|
"""Regression for the reported bug: a single non-BMP emoji shifts every
|
||||||
|
subsequent styled span left by 1 UTF-16 unit, lopping off the last letter.
|
||||||
|
"""
|
||||||
|
md = (
|
||||||
|
"**Weather**\n"
|
||||||
|
"- Conditions: 🌩️ Thunderstorms\n\n"
|
||||||
|
"**News**\n"
|
||||||
|
"*World*\n"
|
||||||
|
"*Local*\n\n"
|
||||||
|
"**Quote of the Day**"
|
||||||
|
)
|
||||||
|
plain, styles = _markdown_to_signal(md)
|
||||||
|
sd = utf16_styles_for(plain, styles)
|
||||||
|
assert sd.get("Weather") == ["BOLD"]
|
||||||
|
assert sd.get("News") == ["BOLD"]
|
||||||
|
assert sd.get("World") == ["ITALIC"]
|
||||||
|
assert sd.get("Local") == ["ITALIC"]
|
||||||
|
assert sd.get("Quote of the Day") == ["BOLD"]
|
||||||
|
assert_within_utf16_bounds(plain, styles)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Chunk redistribution
|
||||||
|
#
|
||||||
|
# split_message can break a long Signal payload into multiple chunks. The
|
||||||
|
# style ranges from _markdown_to_signal are anchored to the full text, so
|
||||||
|
# they must be redistributed per-chunk with rebased offsets — otherwise
|
||||||
|
# styles for chunks 1..N are silently lost.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_chunk_styles(text: str, max_len: int) -> tuple[list[str], list[list[str]]]:
|
||||||
|
"""Helper: full markdown → signal pipeline, including chunking."""
|
||||||
|
plain, styles = _markdown_to_signal(text)
|
||||||
|
chunks = split_message(plain, max_len) if plain else [""]
|
||||||
|
return chunks, _partition_styles(plain, chunks, styles)
|
||||||
|
|
||||||
|
|
||||||
|
def test_partition_styles_single_chunk_passthrough():
|
||||||
|
plain, styles = _markdown_to_signal("**bold** plain *it*")
|
||||||
|
parts = _partition_styles(plain, [plain], styles)
|
||||||
|
assert parts == [styles]
|
||||||
|
|
||||||
|
|
||||||
|
def test_partition_styles_no_styles():
|
||||||
|
plain = "hello world"
|
||||||
|
assert _partition_styles(plain, [plain], []) == [[]]
|
||||||
|
assert _partition_styles(plain, ["hello", "world"], []) == [[], []]
|
||||||
|
|
||||||
|
|
||||||
|
def test_partition_styles_drops_styles_outside_chunks():
|
||||||
|
"""Whitespace trimmed by split_message must not carry a style range."""
|
||||||
|
plain = "a b"
|
||||||
|
# Fake a style spanning the trimmed whitespace only.
|
||||||
|
chunks = ["a", "b"]
|
||||||
|
parts = _partition_styles(plain, chunks, ["1:3:BOLD"])
|
||||||
|
assert parts == [[], []]
|
||||||
|
|
||||||
|
|
||||||
|
def test_partition_styles_long_message_preserves_chunk_one_styles():
|
||||||
|
"""A bold span deep in the message must follow the message into chunk 1."""
|
||||||
|
# Two ~30-char paragraphs separated by a blank line, then **tail**.
|
||||||
|
line_a = "alpha " * 5 # 30 chars, ends with space
|
||||||
|
line_b = "beta " * 5
|
||||||
|
md = f"{line_a.strip()}\n\n{line_b.strip()}\n\n**tail**"
|
||||||
|
plain, styles = _markdown_to_signal(md)
|
||||||
|
# Force a split between the paragraphs.
|
||||||
|
max_len = len(line_a.strip()) + 2 # fits paragraph A + the "\n\n"
|
||||||
|
chunks = split_message(plain, max_len)
|
||||||
|
assert len(chunks) >= 2, "test setup must produce a split"
|
||||||
|
parts = _partition_styles(plain, chunks, styles)
|
||||||
|
# The bold "tail" should land in the last chunk, with chunk-relative offset.
|
||||||
|
final_chunk = chunks[-1]
|
||||||
|
final_styles = parts[-1]
|
||||||
|
assert any("BOLD" in s for s in final_styles)
|
||||||
|
for entry in final_styles:
|
||||||
|
s, ln, _ = entry.split(":", 2)
|
||||||
|
start, length = int(s), int(ln)
|
||||||
|
slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode(
|
||||||
|
"utf-16-le"
|
||||||
|
)
|
||||||
|
assert slice_ == "tail"
|
||||||
|
|
||||||
|
|
||||||
|
def test_partition_styles_chunk_zero_styles_unchanged():
|
||||||
|
"""Styles entirely in chunk 0 keep their original offsets."""
|
||||||
|
md = "**head** middle and **tail**"
|
||||||
|
plain, styles = _markdown_to_signal(md)
|
||||||
|
# Split so chunk 0 contains "head" and part of the rest, chunk 1 contains "tail".
|
||||||
|
chunks = split_message(plain, 12)
|
||||||
|
assert len(chunks) >= 2
|
||||||
|
parts = _partition_styles(plain, chunks, styles)
|
||||||
|
# "head" lives in chunk 0; assert its offset is unchanged (chunk 0 starts at 0).
|
||||||
|
head_entries = [s for s in parts[0] if "BOLD" in s]
|
||||||
|
assert any(s.startswith("0:4:") for s in head_entries)
|
||||||
|
|
||||||
|
|
||||||
|
def test_partition_styles_with_non_bmp_chunk_offset():
|
||||||
|
"""Chunk-start offsets must be expressed in UTF-16 code units."""
|
||||||
|
# Emoji in chunk 0, bold in chunk 1.
|
||||||
|
md = "🎉 alpha beta gamma\n\n**tail**"
|
||||||
|
plain, styles = _markdown_to_signal(md)
|
||||||
|
chunks = split_message(plain, 18)
|
||||||
|
assert len(chunks) >= 2
|
||||||
|
parts = _partition_styles(plain, chunks, styles)
|
||||||
|
final_styles = parts[-1]
|
||||||
|
assert any("BOLD" in s for s in final_styles)
|
||||||
|
final_chunk = chunks[-1]
|
||||||
|
for entry in final_styles:
|
||||||
|
s, ln, _ = entry.split(":", 2)
|
||||||
|
start, length = int(s), int(ln)
|
||||||
|
slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode(
|
||||||
|
"utf-16-le"
|
||||||
|
)
|
||||||
|
assert slice_ == "tail"
|
||||||
|
|
||||||
|
|
||||||
|
def test_partition_styles_range_spanning_chunks_is_split():
|
||||||
|
"""A style range that straddles a chunk boundary gets sliced into both chunks."""
|
||||||
|
# Construct manually: plain = "abc def", style covers "abc def" (whole thing).
|
||||||
|
plain = "abc def"
|
||||||
|
chunks = split_message(plain, 4) # "abc" / "def"
|
||||||
|
assert chunks == ["abc", "def"]
|
||||||
|
parts = _partition_styles(plain, chunks, ["0:7:BOLD"])
|
||||||
|
# Chunk 0 holds 0:3:BOLD, chunk 1 holds 0:3:BOLD (length=3 each, "def" only
|
||||||
|
# since the space was trimmed by lstrip).
|
||||||
|
assert parts[0] == ["0:3:BOLD"]
|
||||||
|
assert parts[1] == ["0:3:BOLD"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Adjacency, nesting, and malformed input
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_bold_italic_combo_outer_bold_inner_italic():
|
||||||
|
"""`**_combo_**` carries both BOLD and ITALIC over the same span."""
|
||||||
|
plain, styles = _markdown_to_signal("**_combo_**")
|
||||||
|
assert plain == "combo"
|
||||||
|
sd = styles_for(plain, styles)
|
||||||
|
assert set(sd.get("combo", [])) == {"BOLD", "ITALIC"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_bold_and_italic_adjacent_no_separator():
|
||||||
|
"""`**bold***italic*` produces BOLD on `bold` and ITALIC on `italic`."""
|
||||||
|
plain, styles = _markdown_to_signal("**bold***italic*")
|
||||||
|
assert plain == "bolditalic"
|
||||||
|
sd = styles_for(plain, styles)
|
||||||
|
assert sd.get("bold") == ["BOLD"]
|
||||||
|
assert sd.get("italic") == ["ITALIC"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_unclosed_bold_falls_through_as_plain():
|
||||||
|
"""An unmatched `**` opener round-trips as literal text with no style."""
|
||||||
|
plain, styles = _markdown_to_signal("**bold")
|
||||||
|
assert plain == "**bold"
|
||||||
|
assert styles == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_unclosed_inline_code_falls_through_as_plain():
|
||||||
|
"""An unmatched backtick round-trips as literal text with no style."""
|
||||||
|
plain, styles = _markdown_to_signal("use `grep")
|
||||||
|
assert plain == "use `grep"
|
||||||
|
assert styles == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_inline_code_inside_blockquote():
|
||||||
|
"""Blockquote prefix is stripped; inline code becomes MONOSPACE."""
|
||||||
|
plain, styles = _markdown_to_signal("> use `grep`")
|
||||||
|
assert plain == "use grep"
|
||||||
|
sd = styles_for(plain, styles)
|
||||||
|
assert sd.get("grep") == ["MONOSPACE"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_header_with_inner_bold_produces_contiguous_bold_ranges():
|
||||||
|
"""`# **wrap** me` — header forces BOLD over the whole line; the inner `**`
|
||||||
|
splits the run, yielding two contiguous BOLD ranges that together cover
|
||||||
|
"wrap me". This is intentional — Signal renders adjacent same-style ranges
|
||||||
|
as a single visual span.
|
||||||
|
"""
|
||||||
|
plain, styles = _markdown_to_signal("# **wrap** me")
|
||||||
|
assert plain == "wrap me"
|
||||||
|
# Both ranges are BOLD; collectively they cover the whole "wrap me".
|
||||||
|
bold_ranges = [s for s in styles if s.endswith(":BOLD")]
|
||||||
|
assert len(bold_ranges) == 2
|
||||||
|
covered = set()
|
||||||
|
for entry in bold_ranges:
|
||||||
|
start, length, _ = entry.split(":", 2)
|
||||||
|
for i in range(int(start), int(start) + int(length)):
|
||||||
|
covered.add(i)
|
||||||
|
assert covered == set(range(len(plain)))
|
||||||
@ -1055,6 +1055,7 @@ async def test_settings_api_returns_safe_subset_and_updates_whitelist(
|
|||||||
}
|
}
|
||||||
assert image_providers["openrouter"]["label"] == "OpenRouter"
|
assert image_providers["openrouter"]["label"] == "OpenRouter"
|
||||||
assert image_providers["openrouter"]["configured"] is False
|
assert image_providers["openrouter"]["configured"] is False
|
||||||
|
assert image_providers["openai_codex"]["configured"] is True
|
||||||
assert image_providers["gemini"]["label"] == "Gemini"
|
assert image_providers["gemini"]["label"] == "Gemini"
|
||||||
assert body["runtime"]["config_path"] == str(config_path)
|
assert body["runtime"]["config_path"] == str(config_path)
|
||||||
workspace_path = body["runtime"]["workspace_path"].replace("\\", "/")
|
workspace_path = body["runtime"]["workspace_path"].replace("\\", "/")
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
@ -374,6 +375,7 @@ async def test_send_uses_typing_start_and_cancel_when_ticket_available() -> None
|
|||||||
channel._client = object()
|
channel._client = object()
|
||||||
channel._token = "token"
|
channel._token = "token"
|
||||||
channel._context_tokens["wx-user"] = "ctx-typing"
|
channel._context_tokens["wx-user"] = "ctx-typing"
|
||||||
|
channel._context_token_at["wx-user"] = time.time()
|
||||||
channel._send_text = AsyncMock()
|
channel._send_text = AsyncMock()
|
||||||
channel._api_post = AsyncMock(
|
channel._api_post = AsyncMock(
|
||||||
side_effect=[
|
side_effect=[
|
||||||
@ -402,6 +404,7 @@ async def test_send_still_sends_text_when_typing_ticket_missing() -> None:
|
|||||||
channel._client = object()
|
channel._client = object()
|
||||||
channel._token = "token"
|
channel._token = "token"
|
||||||
channel._context_tokens["wx-user"] = "ctx-no-ticket"
|
channel._context_tokens["wx-user"] = "ctx-no-ticket"
|
||||||
|
channel._context_token_at["wx-user"] = time.time()
|
||||||
channel._send_text = AsyncMock()
|
channel._send_text = AsyncMock()
|
||||||
channel._api_post = AsyncMock(return_value={"ret": 1, "errmsg": "no config"})
|
channel._api_post = AsyncMock(return_value={"ret": 1, "errmsg": "no config"})
|
||||||
|
|
||||||
@ -1254,3 +1257,526 @@ async def test_send_text_succeeds_on_zero_errcode() -> None:
|
|||||||
await channel._send_text("wx-user", "hello", "ctx-ok")
|
await channel._send_text("wx-user", "hello", "ctx-ok")
|
||||||
|
|
||||||
channel._api_post.assert_awaited_once()
|
channel._api_post.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_text_raises_on_nonzero_ret_even_when_errcode_zero() -> None:
|
||||||
|
"""_send_text must raise when the API returns ret != 0, even if errcode is 0.
|
||||||
|
|
||||||
|
The iLink API signals failure through either field. Checking only errcode
|
||||||
|
caused silent message drops (responses generated but never delivered).
|
||||||
|
"""
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._client = object()
|
||||||
|
channel._token = "token"
|
||||||
|
channel._api_post = AsyncMock(
|
||||||
|
return_value={"ret": -100, "errcode": 0, "errmsg": "internal error"}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="WeChat send text error.*ret=-100.*errcode=0"):
|
||||||
|
await channel._send_text("wx-user", "hello", "ctx-ok")
|
||||||
|
|
||||||
|
channel._api_post.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests for _poll_once not silently dropping messages on processing errors
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_poll_once_logs_exception_on_process_message_failure(monkeypatch) -> None:
|
||||||
|
"""When _process_message raises, _poll_once must log the error and continue
|
||||||
|
processing remaining messages instead of silently swallowing the exception."""
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._client = SimpleNamespace(timeout=None)
|
||||||
|
channel._token = "token"
|
||||||
|
channel._get_updates_buf = "old-buf"
|
||||||
|
|
||||||
|
calls = []
|
||||||
|
logged_messages: list[str] = []
|
||||||
|
|
||||||
|
async def _failing_process(msg: dict) -> None:
|
||||||
|
calls.append(msg.get("message_id"))
|
||||||
|
if msg.get("message_id") == "msg-1":
|
||||||
|
raise RuntimeError("processing failed")
|
||||||
|
|
||||||
|
channel._process_message = _failing_process # type: ignore[method-assign]
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
channel.logger,
|
||||||
|
"exception",
|
||||||
|
lambda message, *args, **kwargs: logged_messages.append(str(message)),
|
||||||
|
)
|
||||||
|
|
||||||
|
channel._api_post = AsyncMock( # type: ignore[method-assign]
|
||||||
|
return_value={
|
||||||
|
"ret": 0,
|
||||||
|
"errcode": 0,
|
||||||
|
"get_updates_buf": "new-buf",
|
||||||
|
"msgs": [
|
||||||
|
{"message_id": "msg-1", "message_type": 1},
|
||||||
|
{"message_id": "msg-2", "message_type": 1},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
await channel._poll_once()
|
||||||
|
|
||||||
|
# Both messages should have been attempted
|
||||||
|
assert calls == ["msg-1", "msg-2"]
|
||||||
|
# Buffer should still advance (already updated before processing)
|
||||||
|
assert channel._get_updates_buf == "new-buf"
|
||||||
|
# Error should be logged
|
||||||
|
assert any("Failed to process WeChat message" in m for m in logged_messages)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_poll_loop_logs_exception_and_continues_on_poll_failure(monkeypatch) -> None:
|
||||||
|
"""When _poll_once raises a non-timeout exception, the start() loop must log
|
||||||
|
the error and continue polling instead of exiting silently."""
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._client = object()
|
||||||
|
channel._token = "token"
|
||||||
|
channel.config.token = "token" # skip QR login in start()
|
||||||
|
channel._running = True
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
logged_messages: list[str] = []
|
||||||
|
|
||||||
|
async def _failing_poll() -> None:
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
raise RuntimeError("poll exploded")
|
||||||
|
channel._running = False # Stop after second call
|
||||||
|
|
||||||
|
channel._poll_once = _failing_poll # type: ignore[method-assign]
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
channel.logger,
|
||||||
|
"exception",
|
||||||
|
lambda message, *args, **kwargs: logged_messages.append(str(message)),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use a tiny retry delay so the test finishes quickly
|
||||||
|
original_retry = weixin_mod.RETRY_DELAY_S
|
||||||
|
weixin_mod.RETRY_DELAY_S = 0.01
|
||||||
|
try:
|
||||||
|
await channel.start()
|
||||||
|
finally:
|
||||||
|
weixin_mod.RETRY_DELAY_S = original_retry
|
||||||
|
|
||||||
|
assert call_count == 2
|
||||||
|
assert any("WeChat poll loop error" in m for m in logged_messages)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tool-hint buffering
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_buffer_single_tool_hint_not_sent_immediately() -> None:
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._client = object()
|
||||||
|
channel._token = "token"
|
||||||
|
channel.send_tool_hints = True
|
||||||
|
channel._context_tokens["wx-user"] = "ctx-1"
|
||||||
|
channel._context_token_at["wx-user"] = time.time()
|
||||||
|
channel._send_text = AsyncMock()
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
type(
|
||||||
|
"Msg",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"chat_id": "wx-user",
|
||||||
|
"content": "Using tool",
|
||||||
|
"media": [],
|
||||||
|
"metadata": {"_progress": True, "_tool_hint": True},
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
)
|
||||||
|
|
||||||
|
channel._send_text.assert_not_awaited()
|
||||||
|
assert channel._pending_tool_hints["wx-user"] == ["Using tool"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_buffer_multiple_tool_hints_flushed_on_final_answer() -> None:
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._client = object()
|
||||||
|
channel._token = "token"
|
||||||
|
channel.send_tool_hints = True
|
||||||
|
channel._context_tokens["wx-user"] = "ctx-1"
|
||||||
|
channel._context_token_at["wx-user"] = time.time()
|
||||||
|
channel._send_text = AsyncMock()
|
||||||
|
|
||||||
|
for hint in ["tool1", "tool2"]:
|
||||||
|
await channel.send(
|
||||||
|
type(
|
||||||
|
"Msg",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"chat_id": "wx-user",
|
||||||
|
"content": hint,
|
||||||
|
"media": [],
|
||||||
|
"metadata": {"_progress": True, "_tool_hint": True},
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
)
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
type(
|
||||||
|
"Msg",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"chat_id": "wx-user",
|
||||||
|
"content": "Done",
|
||||||
|
"media": [],
|
||||||
|
"metadata": {},
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert channel._send_text.await_count == 2
|
||||||
|
channel._send_text.assert_any_await("wx-user", "tool1\n\ntool2", "ctx-1")
|
||||||
|
channel._send_text.assert_any_await("wx-user", "Done", "ctx-1")
|
||||||
|
assert "wx-user" not in channel._pending_tool_hints
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_thought_progress_flushes_tool_hints() -> None:
|
||||||
|
"""Thoughts are visible progress messages and must act as separators,
|
||||||
|
flushing buffered tool hints before they are sent."""
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._client = object()
|
||||||
|
channel._token = "token"
|
||||||
|
channel.send_tool_hints = True
|
||||||
|
channel._context_tokens["wx-user"] = "ctx-1"
|
||||||
|
channel._context_token_at["wx-user"] = time.time()
|
||||||
|
channel._send_text = AsyncMock()
|
||||||
|
|
||||||
|
# Buffer a tool hint
|
||||||
|
await channel.send(
|
||||||
|
type(
|
||||||
|
"Msg",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"chat_id": "wx-user",
|
||||||
|
"content": "search 'foo'",
|
||||||
|
"media": [],
|
||||||
|
"metadata": {"_progress": True, "_tool_hint": True},
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send a thought — progress but not a tool_hint.
|
||||||
|
# It must act as a separator and flush the buffered hint.
|
||||||
|
await channel.send(
|
||||||
|
type(
|
||||||
|
"Msg",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"chat_id": "wx-user",
|
||||||
|
"content": "Let me think...",
|
||||||
|
"media": [],
|
||||||
|
"metadata": {"_progress": True},
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
)
|
||||||
|
|
||||||
|
# The buffered hint was flushed before the thought was sent.
|
||||||
|
channel._send_text.assert_any_await("wx-user", "search 'foo'", "ctx-1")
|
||||||
|
channel._send_text.assert_any_await("wx-user", "Let me think...", "ctx-1")
|
||||||
|
assert "wx-user" not in channel._pending_tool_hints
|
||||||
|
|
||||||
|
# Final answer arrives with nothing left to flush.
|
||||||
|
await channel.send(
|
||||||
|
type(
|
||||||
|
"Msg",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"chat_id": "wx-user",
|
||||||
|
"content": "Done",
|
||||||
|
"media": [],
|
||||||
|
"metadata": {},
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert channel._send_text.await_count == 3
|
||||||
|
channel._send_text.assert_any_await("wx-user", "Done", "ctx-1")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reasoning_delta_does_not_flush_tool_hints() -> None:
|
||||||
|
"""Reasoning deltas are invisible in WeChat and must NOT flush buffered
|
||||||
|
tool hints — otherwise hints separated only by hidden reasoning would
|
||||||
|
fail to coalesce."""
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._client = object()
|
||||||
|
channel._token = "token"
|
||||||
|
channel.send_tool_hints = True
|
||||||
|
channel._context_tokens["wx-user"] = "ctx-1"
|
||||||
|
channel._context_token_at["wx-user"] = time.time()
|
||||||
|
channel._send_text = AsyncMock()
|
||||||
|
|
||||||
|
# Buffer a tool hint
|
||||||
|
await channel.send(
|
||||||
|
type(
|
||||||
|
"Msg",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"chat_id": "wx-user",
|
||||||
|
"content": "search 'foo'",
|
||||||
|
"media": [],
|
||||||
|
"metadata": {"_progress": True, "_tool_hint": True},
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send a reasoning delta — invisible in WeChat, must NOT flush
|
||||||
|
await channel.send(
|
||||||
|
type(
|
||||||
|
"Msg",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"chat_id": "wx-user",
|
||||||
|
"content": "Thinking step 1...",
|
||||||
|
"media": [],
|
||||||
|
"metadata": {"_progress": True, "_reasoning_delta": True},
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reasoning is invisible; hint stays buffered, _send_text not called
|
||||||
|
channel._send_text.assert_not_awaited()
|
||||||
|
assert channel._pending_tool_hints["wx-user"] == ["search 'foo'"]
|
||||||
|
|
||||||
|
# Final answer flushes the buffered hint
|
||||||
|
await channel.send(
|
||||||
|
type(
|
||||||
|
"Msg",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"chat_id": "wx-user",
|
||||||
|
"content": "Done",
|
||||||
|
"media": [],
|
||||||
|
"metadata": {},
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
)
|
||||||
|
|
||||||
|
channel._send_text.assert_any_await("wx-user", "search 'foo'", "ctx-1")
|
||||||
|
channel._send_text.assert_any_await("wx-user", "Done", "ctx-1")
|
||||||
|
assert "wx-user" not in channel._pending_tool_hints
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_progress_message_does_not_flush_tool_hints() -> None:
|
||||||
|
"""Empty progress messages (e.g. after_iteration tool_events) have no
|
||||||
|
visible content and must NOT act as separators."""
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._client = object()
|
||||||
|
channel._token = "token"
|
||||||
|
channel.send_tool_hints = True
|
||||||
|
channel._context_tokens["wx-user"] = "ctx-1"
|
||||||
|
channel._context_token_at["wx-user"] = time.time()
|
||||||
|
channel._send_text = AsyncMock()
|
||||||
|
|
||||||
|
# Buffer a tool hint
|
||||||
|
await channel.send(
|
||||||
|
type(
|
||||||
|
"Msg",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"chat_id": "wx-user",
|
||||||
|
"content": "search 'foo'",
|
||||||
|
"media": [],
|
||||||
|
"metadata": {"_progress": True, "_tool_hint": True},
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Send an empty progress message (no content, no media)
|
||||||
|
await channel.send(
|
||||||
|
type(
|
||||||
|
"Msg",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"chat_id": "wx-user",
|
||||||
|
"content": "",
|
||||||
|
"media": [],
|
||||||
|
"metadata": {"_progress": True, "_tool_events": [{"phase": "end"}]},
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Nothing should have been sent yet
|
||||||
|
channel._send_text.assert_not_awaited()
|
||||||
|
assert channel._pending_tool_hints["wx-user"] == ["search 'foo'"]
|
||||||
|
|
||||||
|
# Final answer flushes the buffered hint
|
||||||
|
await channel.send(
|
||||||
|
type(
|
||||||
|
"Msg",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"chat_id": "wx-user",
|
||||||
|
"content": "Done",
|
||||||
|
"media": [],
|
||||||
|
"metadata": {},
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
)
|
||||||
|
|
||||||
|
channel._send_text.assert_any_await("wx-user", "search 'foo'", "ctx-1")
|
||||||
|
channel._send_text.assert_any_await("wx-user", "Done", "ctx-1")
|
||||||
|
assert "wx-user" not in channel._pending_tool_hints
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_buffer_flush_refreshes_context_token() -> None:
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._client = object()
|
||||||
|
channel._token = "token"
|
||||||
|
channel.send_tool_hints = True
|
||||||
|
channel._context_tokens["wx-user"] = "ctx-old"
|
||||||
|
channel._context_token_at["wx-user"] = time.time()
|
||||||
|
channel._refresh_context_token_if_stale = AsyncMock(return_value="ctx-refreshed")
|
||||||
|
channel._send_text = AsyncMock()
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
type(
|
||||||
|
"Msg",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"chat_id": "wx-user",
|
||||||
|
"content": "hint",
|
||||||
|
"media": [],
|
||||||
|
"metadata": {"_progress": True, "_tool_hint": True},
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
)
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
type(
|
||||||
|
"Msg",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"chat_id": "wx-user",
|
||||||
|
"content": "Done",
|
||||||
|
"media": [],
|
||||||
|
"metadata": {},
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert channel._refresh_context_token_if_stale.await_count == 2
|
||||||
|
channel._refresh_context_token_if_stale.assert_any_await("wx-user", "ctx-old")
|
||||||
|
channel._send_text.assert_any_await("wx-user", "hint", "ctx-refreshed")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_buffer_flush_failure_does_not_block_final_answer() -> None:
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._client = object()
|
||||||
|
channel._token = "token"
|
||||||
|
channel.send_tool_hints = True
|
||||||
|
channel._context_tokens["wx-user"] = "ctx-1"
|
||||||
|
channel._context_token_at["wx-user"] = time.time()
|
||||||
|
channel._send_text = AsyncMock(side_effect=[RuntimeError("boom"), None])
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
type(
|
||||||
|
"Msg",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"chat_id": "wx-user",
|
||||||
|
"content": "hint",
|
||||||
|
"media": [],
|
||||||
|
"metadata": {"_progress": True, "_tool_hint": True},
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
)
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
type(
|
||||||
|
"Msg",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"chat_id": "wx-user",
|
||||||
|
"content": "Done",
|
||||||
|
"media": [],
|
||||||
|
"metadata": {},
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
)
|
||||||
|
|
||||||
|
assert channel._send_text.await_count == 2
|
||||||
|
channel._send_text.assert_any_await("wx-user", "hint", "ctx-1")
|
||||||
|
channel._send_text.assert_any_await("wx-user", "Done", "ctx-1")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_buffer_flushed_on_stream_end() -> None:
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._client = object()
|
||||||
|
channel._token = "token"
|
||||||
|
channel.send_tool_hints = True
|
||||||
|
channel._context_tokens["wx-user"] = "ctx-1"
|
||||||
|
channel._context_token_at["wx-user"] = time.time()
|
||||||
|
channel._send_text = AsyncMock()
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
type(
|
||||||
|
"Msg",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"chat_id": "wx-user",
|
||||||
|
"content": "hint",
|
||||||
|
"media": [],
|
||||||
|
"metadata": {"_progress": True, "_tool_hint": True},
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
)
|
||||||
|
|
||||||
|
await channel.send_delta("wx-user", "", {"_stream_end": True})
|
||||||
|
|
||||||
|
channel._send_text.assert_awaited_once_with("wx-user", "hint", "ctx-1")
|
||||||
|
assert "wx-user" not in channel._pending_tool_hints
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_clears_buffer() -> None:
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._pending_tool_hints["wx-user"] = ["hint1", "hint2"]
|
||||||
|
await channel.stop()
|
||||||
|
assert "wx-user" not in channel._pending_tool_hints
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_tool_hints_false_drops_tool_hints() -> None:
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._client = object()
|
||||||
|
channel._token = "token"
|
||||||
|
channel.send_tool_hints = False
|
||||||
|
channel._send_text = AsyncMock()
|
||||||
|
|
||||||
|
await channel.send(
|
||||||
|
type(
|
||||||
|
"Msg",
|
||||||
|
(),
|
||||||
|
{
|
||||||
|
"chat_id": "wx-user",
|
||||||
|
"content": "hint",
|
||||||
|
"media": [],
|
||||||
|
"metadata": {"_progress": True, "_tool_hint": True},
|
||||||
|
},
|
||||||
|
)()
|
||||||
|
)
|
||||||
|
|
||||||
|
channel._send_text.assert_not_awaited()
|
||||||
|
assert "wx-user" not in channel._pending_tool_hints
|
||||||
|
|||||||
@ -192,3 +192,20 @@ def test_match_provider_uses_preset_provider_when_forced() -> None:
|
|||||||
})
|
})
|
||||||
name = config.get_provider_name()
|
name = config.get_provider_name()
|
||||||
assert name == "anthropic"
|
assert name == "anthropic"
|
||||||
|
|
||||||
|
|
||||||
|
def test_match_provider_routes_forced_novita_model_api_models() -> None:
|
||||||
|
config = Config.model_validate({
|
||||||
|
"providers": {
|
||||||
|
"novita": {"apiKey": "sk-test"},
|
||||||
|
},
|
||||||
|
"agents": {
|
||||||
|
"defaults": {
|
||||||
|
"model": "deepseek-v4-pro",
|
||||||
|
"provider": "novita",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert config.get_provider_name() == "novita"
|
||||||
|
assert config.get_api_base() == "https://api.novita.ai/openai"
|
||||||
|
|||||||
@ -56,6 +56,35 @@ def test_custom_provider_parse_chunks_accepts_plain_text_chunks() -> None:
|
|||||||
assert result.content == "hello world"
|
assert result.content == "hello world"
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_provider_parse_chunks_deduplicates_parallel_tool_call_ids() -> None:
|
||||||
|
chunks = [{
|
||||||
|
"choices": [{
|
||||||
|
"finish_reason": "tool_calls",
|
||||||
|
"delta": {
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"id": "call_dup",
|
||||||
|
"function": {"name": "read_file", "arguments": '{"path":"a.txt"}'},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"index": 1,
|
||||||
|
"id": "call_dup",
|
||||||
|
"function": {"name": "read_file", "arguments": '{"path":"b.txt"}'},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}],
|
||||||
|
}]
|
||||||
|
|
||||||
|
result = OpenAICompatProvider._parse_chunks(chunks)
|
||||||
|
ids = [tool_call.id for tool_call in result.tool_calls or []]
|
||||||
|
|
||||||
|
assert ids[0] == "call_dup"
|
||||||
|
assert len(ids) == 2
|
||||||
|
assert len(set(ids)) == 2
|
||||||
|
|
||||||
|
|
||||||
def test_local_provider_502_error_includes_reachability_hint() -> None:
|
def test_local_provider_502_error_includes_reachability_hint() -> None:
|
||||||
spec = find_by_name("ollama")
|
spec = find_by_name("ollama")
|
||||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
|||||||
@ -9,11 +9,13 @@ import pytest
|
|||||||
|
|
||||||
from nanobot.providers.image_generation import (
|
from nanobot.providers.image_generation import (
|
||||||
AIHubMixImageGenerationClient,
|
AIHubMixImageGenerationClient,
|
||||||
|
CodexImageGenerationClient,
|
||||||
GeminiImageGenerationClient,
|
GeminiImageGenerationClient,
|
||||||
GeneratedImageResponse,
|
GeneratedImageResponse,
|
||||||
ImageGenerationError,
|
ImageGenerationError,
|
||||||
MiniMaxImageGenerationClient,
|
MiniMaxImageGenerationClient,
|
||||||
OllamaImageGenerationClient,
|
OllamaImageGenerationClient,
|
||||||
|
OpenAIImageGenerationClient,
|
||||||
OpenRouterImageGenerationClient,
|
OpenRouterImageGenerationClient,
|
||||||
StepFunImageGenerationClient,
|
StepFunImageGenerationClient,
|
||||||
)
|
)
|
||||||
@ -37,12 +39,14 @@ class FakeResponse:
|
|||||||
payload: dict[str, Any],
|
payload: dict[str, Any],
|
||||||
status_code: int = 200,
|
status_code: int = 200,
|
||||||
content: bytes = b"",
|
content: bytes = b"",
|
||||||
|
sse_lines: list[str] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._payload = payload
|
self._payload = payload
|
||||||
self.status_code = status_code
|
self.status_code = status_code
|
||||||
self.text = str(payload)
|
self.text = str(payload)
|
||||||
self.content = content
|
self.content = content
|
||||||
self.request = httpx.Request("POST", "https://openrouter.ai/api/v1/chat/completions")
|
self.request = httpx.Request("POST", "https://openrouter.ai/api/v1/chat/completions")
|
||||||
|
self._sse_lines = sse_lines
|
||||||
|
|
||||||
def json(self) -> dict[str, Any]:
|
def json(self) -> dict[str, Any]:
|
||||||
return self._payload
|
return self._payload
|
||||||
@ -52,6 +56,15 @@ class FakeResponse:
|
|||||||
response = httpx.Response(self.status_code, request=self.request, text=self.text)
|
response = httpx.Response(self.status_code, request=self.request, text=self.text)
|
||||||
raise httpx.HTTPStatusError("failed", request=self.request, response=response)
|
raise httpx.HTTPStatusError("failed", request=self.request, response=response)
|
||||||
|
|
||||||
|
async def aiter_lines(self):
|
||||||
|
if self._sse_lines is not None:
|
||||||
|
for line in self._sse_lines:
|
||||||
|
yield line
|
||||||
|
return
|
||||||
|
# Fallback: treat response text as SSE lines
|
||||||
|
for line in self.text.split("\n"):
|
||||||
|
yield line
|
||||||
|
|
||||||
|
|
||||||
class FakeClient:
|
class FakeClient:
|
||||||
def __init__(self, response: FakeResponse) -> None:
|
def __init__(self, response: FakeResponse) -> None:
|
||||||
@ -564,3 +577,437 @@ async def test_stepfun_no_images_raises() -> None:
|
|||||||
|
|
||||||
with pytest.raises(ImageGenerationError, match="returned no images"):
|
with pytest.raises(ImageGenerationError, match="returned no images"):
|
||||||
await client.generate(prompt="draw", model="step-image-edit-2")
|
await client.generate(prompt="draw", model="step-image-edit-2")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# OpenAI
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_payload_and_response() -> None:
|
||||||
|
fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
|
||||||
|
client = OpenAIImageGenerationClient(
|
||||||
|
api_key="sk-openai-test",
|
||||||
|
api_base="https://api.openai.com/v1",
|
||||||
|
extra_headers={"X-Test": "1"},
|
||||||
|
client=fake, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.generate(
|
||||||
|
prompt="a cat on the moon",
|
||||||
|
model="dall-e-3",
|
||||||
|
aspect_ratio="16:9",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.images == [PNG_DATA_URL]
|
||||||
|
call = fake.calls[0]
|
||||||
|
assert call["url"] == "https://api.openai.com/v1/images/generations"
|
||||||
|
assert call["headers"]["Authorization"] == "Bearer sk-openai-test"
|
||||||
|
assert call["headers"]["X-Test"] == "1"
|
||||||
|
body = call["json"]
|
||||||
|
assert body["model"] == "dall-e-3"
|
||||||
|
assert body["prompt"] == "a cat on the moon"
|
||||||
|
assert body["response_format"] == "b64_json"
|
||||||
|
assert body["n"] == 1
|
||||||
|
assert body["size"] == "1792x1024"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_b64_json_response_uses_detected_mime() -> None:
|
||||||
|
raw_b64 = base64.b64encode(JPEG_BYTES).decode("ascii")
|
||||||
|
fake = FakeClient(FakeResponse({"data": [{"b64_json": raw_b64}]}))
|
||||||
|
client = OpenAIImageGenerationClient(
|
||||||
|
api_key="sk-openai-test",
|
||||||
|
client=fake, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.generate(prompt="draw", model="dall-e-3")
|
||||||
|
|
||||||
|
assert response.images == [f"data:image/jpeg;base64,{raw_b64}"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_url_download_fallback() -> None:
|
||||||
|
fake = FakeClient(FakeResponse({"data": [{"url": "https://cdn.example/image.png"}]}))
|
||||||
|
fake.get_response = FakeResponse({}, content=PNG_BYTES)
|
||||||
|
client = OpenAIImageGenerationClient(
|
||||||
|
api_key="sk-openai-test",
|
||||||
|
client=fake, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.generate(prompt="draw", model="dall-e-3")
|
||||||
|
|
||||||
|
assert response.images[0].startswith("data:image/png;base64,")
|
||||||
|
assert fake.get_calls[0]["url"] == "https://cdn.example/image.png"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_multiple_images() -> None:
|
||||||
|
fake = FakeClient(FakeResponse({
|
||||||
|
"data": [
|
||||||
|
{"b64_json": RAW_B64},
|
||||||
|
{"b64_json": RAW_B64},
|
||||||
|
]
|
||||||
|
}))
|
||||||
|
client = OpenAIImageGenerationClient(
|
||||||
|
api_key="sk-openai-test",
|
||||||
|
client=fake, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.generate(prompt="draw", model="dall-e-3")
|
||||||
|
|
||||||
|
assert len(response.images) == 2
|
||||||
|
assert response.images == [PNG_DATA_URL, PNG_DATA_URL]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_aspect_ratio_to_size() -> None:
|
||||||
|
fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
|
||||||
|
client = OpenAIImageGenerationClient(
|
||||||
|
api_key="sk-openai-test",
|
||||||
|
client=fake, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
await client.generate(prompt="draw", model="dall-e-3", aspect_ratio="1:1")
|
||||||
|
assert fake.calls[0]["json"]["size"] == "1024x1024"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_dalle3_uses_supported_orientation_sizes() -> None:
|
||||||
|
fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
|
||||||
|
client = OpenAIImageGenerationClient(
|
||||||
|
api_key="sk-openai-test",
|
||||||
|
client=fake, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
await client.generate(prompt="draw", model="dall-e-3", aspect_ratio="3:4")
|
||||||
|
await client.generate(prompt="draw", model="dall-e-3", aspect_ratio="4:3")
|
||||||
|
|
||||||
|
assert fake.calls[0]["json"]["size"] == "1024x1792"
|
||||||
|
assert fake.calls[1]["json"]["size"] == "1792x1024"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_dalle2_uses_square_size_for_non_square_ratios() -> None:
|
||||||
|
fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
|
||||||
|
client = OpenAIImageGenerationClient(
|
||||||
|
api_key="sk-openai-test",
|
||||||
|
client=fake, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
await client.generate(prompt="draw", model="dall-e-2", aspect_ratio="16:9")
|
||||||
|
|
||||||
|
assert fake.calls[0]["json"]["size"] == "1024x1024"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_gpt_image_uses_supported_landscape_size() -> None:
|
||||||
|
fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
|
||||||
|
client = OpenAIImageGenerationClient(
|
||||||
|
api_key="sk-openai-test",
|
||||||
|
client=fake, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
await client.generate(prompt="draw", model="gpt-image-1", aspect_ratio="16:9")
|
||||||
|
|
||||||
|
assert fake.calls[0]["json"]["size"] == "1536x1024"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_gpt_image_uses_supported_orientation_sizes() -> None:
|
||||||
|
fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
|
||||||
|
client = OpenAIImageGenerationClient(
|
||||||
|
api_key="sk-openai-test",
|
||||||
|
client=fake, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
await client.generate(prompt="draw", model="gpt-image-1", aspect_ratio="3:4")
|
||||||
|
await client.generate(prompt="draw", model="gpt-image-1", aspect_ratio="4:3")
|
||||||
|
|
||||||
|
assert fake.calls[0]["json"]["size"] == "1024x1536"
|
||||||
|
assert fake.calls[1]["json"]["size"] == "1536x1024"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_default_size_when_no_aspect_ratio() -> None:
|
||||||
|
fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
|
||||||
|
client = OpenAIImageGenerationClient(
|
||||||
|
api_key="sk-openai-test",
|
||||||
|
client=fake, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
await client.generate(prompt="draw", model="dall-e-3")
|
||||||
|
|
||||||
|
body = fake.calls[0]["json"]
|
||||||
|
assert body["size"] == "1024x1024"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_ignores_explicit_size_unsupported_by_model_family() -> None:
|
||||||
|
fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
|
||||||
|
client = OpenAIImageGenerationClient(
|
||||||
|
api_key="sk-openai-test",
|
||||||
|
client=fake, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
await client.generate(
|
||||||
|
prompt="draw",
|
||||||
|
model="dall-e-3",
|
||||||
|
aspect_ratio="16:9",
|
||||||
|
image_size="1536x1024",
|
||||||
|
)
|
||||||
|
|
||||||
|
body = fake.calls[0]["json"]
|
||||||
|
assert body["size"] == "1792x1024"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_uses_explicit_image_size() -> None:
|
||||||
|
fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
|
||||||
|
client = OpenAIImageGenerationClient(
|
||||||
|
api_key="sk-openai-test",
|
||||||
|
client=fake, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
await client.generate(
|
||||||
|
prompt="draw",
|
||||||
|
model="dall-e-3",
|
||||||
|
aspect_ratio="16:9",
|
||||||
|
image_size="1024x1024",
|
||||||
|
)
|
||||||
|
|
||||||
|
body = fake.calls[0]["json"]
|
||||||
|
assert body["size"] == "1024x1024"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_requires_api_key() -> None:
|
||||||
|
client = OpenAIImageGenerationClient(api_key=None)
|
||||||
|
|
||||||
|
with pytest.raises(ImageGenerationError, match="API key"):
|
||||||
|
await client.generate(prompt="draw", model="dall-e-3")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# OpenAI Codex (Responses API)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_codex_payload_and_response(monkeypatch) -> None:
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FakeToken:
|
||||||
|
account_id: str = "acct-123"
|
||||||
|
access: str = "oauth-token"
|
||||||
|
|
||||||
|
async def fake_to_thread(fn, *args, **kwargs):
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
|
||||||
|
fake_oauth = SimpleNamespace(get_token=lambda: FakeToken())
|
||||||
|
monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth)
|
||||||
|
|
||||||
|
sse_lines = [
|
||||||
|
'data: {"type":"response.output_item.added","item":{"id":"ig_1","type":"image_generation_call","status":"in_progress"}}',
|
||||||
|
"",
|
||||||
|
f'data: {{"type":"response.output_item.done","item":{{"id":"ig_1","type":"image_generation_call","result":"{PNG_DATA_URL}","status":"completed"}}}}',
|
||||||
|
"",
|
||||||
|
'data: [DONE]',
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
fake = FakeClient(FakeResponse({}, sse_lines=sse_lines))
|
||||||
|
client = CodexImageGenerationClient(
|
||||||
|
api_key=None,
|
||||||
|
api_base="https://chatgpt.com/backend-api",
|
||||||
|
extra_headers={"X-Test": "1"},
|
||||||
|
client=fake, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.generate(
|
||||||
|
prompt="draw a cat",
|
||||||
|
model="gpt-5.4",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.images == [PNG_DATA_URL]
|
||||||
|
assert response.content == ""
|
||||||
|
call = fake.calls[0]
|
||||||
|
assert call["url"] == "https://chatgpt.com/backend-api/codex/responses"
|
||||||
|
assert call["headers"]["Authorization"] == "Bearer oauth-token"
|
||||||
|
assert call["headers"]["chatgpt-account-id"] == "acct-123"
|
||||||
|
assert call["headers"]["OpenAI-Beta"] == "responses=experimental"
|
||||||
|
assert call["headers"]["X-Test"] == "1"
|
||||||
|
body = call["json"]
|
||||||
|
assert body["model"] == "gpt-5.4"
|
||||||
|
assert body["instructions"] == "Generate an image based on the user's request."
|
||||||
|
assert body["input"] == [{"role": "user", "content": "draw a cat"}]
|
||||||
|
assert body["tools"] == [{"type": "image_generation"}]
|
||||||
|
assert body["tool_choice"] == "auto"
|
||||||
|
assert body["store"] is False
|
||||||
|
assert body["stream"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_codex_strips_model_prefix(monkeypatch) -> None:
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FakeToken:
|
||||||
|
account_id: str = "acct-123"
|
||||||
|
access: str = "oauth-token"
|
||||||
|
|
||||||
|
async def fake_to_thread(fn, *args, **kwargs):
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
|
||||||
|
fake_oauth = SimpleNamespace(get_token=lambda: FakeToken())
|
||||||
|
monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth)
|
||||||
|
|
||||||
|
fake = FakeClient(FakeResponse({}, sse_lines=[
|
||||||
|
f'data: {{"type":"response.output_item.done","item":{{"type":"image_generation_call","result":"{PNG_DATA_URL}"}}}}',
|
||||||
|
"",
|
||||||
|
'data: [DONE]',
|
||||||
|
"",
|
||||||
|
]))
|
||||||
|
client = CodexImageGenerationClient(
|
||||||
|
api_key=None, client=fake # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
await client.generate(prompt="draw", model="openai-codex/gpt-5.4")
|
||||||
|
|
||||||
|
assert fake.calls[0]["json"]["model"] == "gpt-5.4"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_codex_requires_oauth(monkeypatch) -> None:
|
||||||
|
async def fake_to_thread(fn, *args, **kwargs):
|
||||||
|
raise RuntimeError("no token")
|
||||||
|
|
||||||
|
monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
|
||||||
|
|
||||||
|
client = CodexImageGenerationClient(api_key=None)
|
||||||
|
|
||||||
|
with pytest.raises(ImageGenerationError, match="OAuth token"):
|
||||||
|
await client.generate(prompt="draw", model="gpt-5.4")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_codex_no_images_raises(monkeypatch) -> None:
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FakeToken:
|
||||||
|
account_id: str = "acct-123"
|
||||||
|
access: str = "oauth-token"
|
||||||
|
|
||||||
|
async def fake_to_thread(fn, *args, **kwargs):
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
|
||||||
|
fake_oauth = SimpleNamespace(get_token=lambda: FakeToken())
|
||||||
|
monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth)
|
||||||
|
|
||||||
|
fake = FakeClient(FakeResponse({}, sse_lines=[
|
||||||
|
'data: {"type":"response.completed","response":{"status":"completed"}}',
|
||||||
|
"",
|
||||||
|
'data: [DONE]',
|
||||||
|
"",
|
||||||
|
]))
|
||||||
|
client = CodexImageGenerationClient(
|
||||||
|
api_key=None, client=fake # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ImageGenerationError, match="returned no images"):
|
||||||
|
await client.generate(prompt="draw", model="gpt-5.4")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_codex_extracts_text_content(monkeypatch) -> None:
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FakeToken:
|
||||||
|
account_id: str = "acct-123"
|
||||||
|
access: str = "oauth-token"
|
||||||
|
|
||||||
|
async def fake_to_thread(fn, *args, **kwargs):
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
|
||||||
|
fake_oauth = SimpleNamespace(get_token=lambda: FakeToken())
|
||||||
|
monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth)
|
||||||
|
|
||||||
|
fake = FakeClient(FakeResponse({}, sse_lines=[
|
||||||
|
'data: {"type":"response.output_text.delta","delta":"Here "}',
|
||||||
|
"",
|
||||||
|
'data: {"type":"response.output_text.delta","delta":"is your cat image."}',
|
||||||
|
"",
|
||||||
|
f'data: {{"type":"response.output_item.done","item":{{"type":"image_generation_call","result":"{PNG_DATA_URL}"}}}}',
|
||||||
|
"",
|
||||||
|
'data: [DONE]',
|
||||||
|
"",
|
||||||
|
]))
|
||||||
|
client = CodexImageGenerationClient(
|
||||||
|
api_key=None, client=fake # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.generate(prompt="draw a cat", model="gpt-5.4")
|
||||||
|
|
||||||
|
assert response.images == [PNG_DATA_URL]
|
||||||
|
assert response.content == "Here is your cat image."
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_codex_json_result_format(monkeypatch) -> None:
|
||||||
|
"""image_generation_call result can be a dict with image_url key."""
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FakeToken:
|
||||||
|
account_id: str = "acct-123"
|
||||||
|
access: str = "oauth-token"
|
||||||
|
|
||||||
|
async def fake_to_thread(fn, *args, **kwargs):
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
|
||||||
|
fake_oauth = SimpleNamespace(get_token=lambda: FakeToken())
|
||||||
|
monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth)
|
||||||
|
|
||||||
|
fake = FakeClient(FakeResponse({}, sse_lines=[
|
||||||
|
f'data: {{"type":"response.output_item.done","item":{{"type":"image_generation_call","result":{{"image_url":"{PNG_DATA_URL}"}}}}}}',
|
||||||
|
"",
|
||||||
|
'data: [DONE]',
|
||||||
|
"",
|
||||||
|
]))
|
||||||
|
client = CodexImageGenerationClient(
|
||||||
|
api_key=None, client=fake # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await client.generate(prompt="draw", model="gpt-5.4")
|
||||||
|
|
||||||
|
assert response.images == [PNG_DATA_URL]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_no_images_raises() -> None:
|
||||||
|
fake = FakeClient(FakeResponse({"data": []}))
|
||||||
|
client = OpenAIImageGenerationClient(
|
||||||
|
api_key="sk-openai-test",
|
||||||
|
client=fake, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ImageGenerationError, match="returned no images"):
|
||||||
|
await client.generate(prompt="draw", model="dall-e-3")
|
||||||
|
|||||||
@ -441,6 +441,15 @@ def test_openrouter_spec_is_gateway() -> None:
|
|||||||
assert spec.default_api_base == "https://openrouter.ai/api/v1"
|
assert spec.default_api_base == "https://openrouter.ai/api/v1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_novita_spec_uses_openai_compatible_gateway() -> None:
|
||||||
|
spec = find_by_name("novita")
|
||||||
|
assert spec is not None
|
||||||
|
assert spec.is_gateway is True
|
||||||
|
assert spec.backend == "openai_compat"
|
||||||
|
assert spec.env_key == "NOVITA_API_KEY"
|
||||||
|
assert spec.default_api_base == "https://api.novita.ai/openai"
|
||||||
|
|
||||||
|
|
||||||
def test_gemma_routes_to_gemini_provider() -> None:
|
def test_gemma_routes_to_gemini_provider() -> None:
|
||||||
"""gemma models (e.g. gemma-3-27b-it) must auto-route to Gemini when GEMINI_API_KEY is set.
|
"""gemma models (e.g. gemma-3-27b-it) must auto-route to Gemini when GEMINI_API_KEY is set.
|
||||||
Users running gemma via the Gemini API endpoint expect automatic provider detection."""
|
Users running gemma via the Gemini API endpoint expect automatic provider detection."""
|
||||||
@ -1007,6 +1016,41 @@ def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -
|
|||||||
assert sanitized[2]["tool_call_id"] == "3ec83c30d"
|
assert sanitized[2]["tool_call_id"] == "3ec83c30d"
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_compat_deduplicates_duplicate_tool_call_ids_in_history() -> None:
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
provider = OpenAICompatProvider()
|
||||||
|
|
||||||
|
sanitized = provider._sanitize_messages([
|
||||||
|
{"role": "user", "content": "check both files"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "ab1b45c2a",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "read_file", "arguments": '{"path":"a.txt"}'},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "ab1b45c2a",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "read_file", "arguments": '{"path":"b.txt"}'},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{"role": "tool", "tool_call_id": "ab1b45c2a", "name": "read_file", "content": "a"},
|
||||||
|
{"role": "tool", "tool_call_id": "ab1b45c2a", "name": "read_file", "content": "b"},
|
||||||
|
{"role": "user", "content": "continue"},
|
||||||
|
])
|
||||||
|
|
||||||
|
tool_call_ids = [tc["id"] for tc in sanitized[1]["tool_calls"]]
|
||||||
|
tool_result_ids = [sanitized[2]["tool_call_id"], sanitized[3]["tool_call_id"]]
|
||||||
|
|
||||||
|
assert tool_call_ids[0] == "ab1b45c2a"
|
||||||
|
assert len(tool_call_ids) == len(set(tool_call_ids)) == 2
|
||||||
|
assert tool_result_ids == tool_call_ids
|
||||||
|
|
||||||
|
|
||||||
def test_openai_compat_stringifies_dict_tool_arguments() -> None:
|
def test_openai_compat_stringifies_dict_tool_arguments() -> None:
|
||||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
provider = OpenAICompatProvider()
|
provider = OpenAICompatProvider()
|
||||||
@ -1376,12 +1420,15 @@ def test_kimi_k25_thinking_enabled() -> None:
|
|||||||
"""kimi-k2.5 with reasoning_effort set should opt in to thinking."""
|
"""kimi-k2.5 with reasoning_effort set should opt in to thinking."""
|
||||||
kw = _build_kwargs_for("moonshot", "kimi-k2.5", reasoning_effort="medium")
|
kw = _build_kwargs_for("moonshot", "kimi-k2.5", reasoning_effort="medium")
|
||||||
assert kw.get("extra_body") == {"thinking": {"type": "enabled"}}
|
assert kw.get("extra_body") == {"thinking": {"type": "enabled"}}
|
||||||
|
# Moonshot rejects both 'reasoning_effort' and 'thinking' (#3939)
|
||||||
|
assert "reasoning_effort" not in kw
|
||||||
|
|
||||||
|
|
||||||
def test_kimi_k25_thinking_disabled_for_minimal() -> None:
|
def test_kimi_k25_thinking_disabled_for_minimal() -> None:
|
||||||
"""reasoning_effort='minimal' maps to thinking disabled for kimi-k2.5."""
|
"""reasoning_effort='minimal' maps to thinking disabled for kimi-k2.5."""
|
||||||
kw = _build_kwargs_for("moonshot", "kimi-k2.5", reasoning_effort="minimal")
|
kw = _build_kwargs_for("moonshot", "kimi-k2.5", reasoning_effort="minimal")
|
||||||
assert kw.get("extra_body") == {"thinking": {"type": "disabled"}}
|
assert kw.get("extra_body") == {"thinking": {"type": "disabled"}}
|
||||||
|
assert "reasoning_effort" not in kw
|
||||||
|
|
||||||
|
|
||||||
def test_kimi_k25_no_extra_body_when_reasoning_effort_none() -> None:
|
def test_kimi_k25_no_extra_body_when_reasoning_effort_none() -> None:
|
||||||
@ -1391,21 +1438,36 @@ def test_kimi_k25_no_extra_body_when_reasoning_effort_none() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_kimi_k25_thinking_enabled_with_openrouter_prefix() -> None:
|
def test_kimi_k25_thinking_enabled_with_openrouter_prefix() -> None:
|
||||||
"""OpenRouter-style model names like moonshotai/kimi-k2.5 must trigger thinking."""
|
"""OpenRouter-style model names like moonshotai/kimi-k2.5 must trigger thinking.
|
||||||
|
|
||||||
|
OR drops upstream-provider `thinking` fields, so the same intent also has
|
||||||
|
to go through OR's `reasoning.effort` shape (#3851 follow-up).
|
||||||
|
"""
|
||||||
kw = _build_kwargs_for("openrouter", "moonshotai/kimi-k2.5", reasoning_effort="medium")
|
kw = _build_kwargs_for("openrouter", "moonshotai/kimi-k2.5", reasoning_effort="medium")
|
||||||
assert kw.get("extra_body") == {"thinking": {"type": "enabled"}}
|
assert kw.get("extra_body") == {
|
||||||
|
"thinking": {"type": "enabled"},
|
||||||
|
"reasoning": {"effort": "medium"},
|
||||||
|
}
|
||||||
|
# Even via OR, reasoning_effort wire kwarg is dropped for kimi models
|
||||||
|
assert "reasoning_effort" not in kw
|
||||||
|
|
||||||
|
|
||||||
def test_kimi_k26_thinking_enabled() -> None:
|
def test_kimi_k26_thinking_enabled() -> None:
|
||||||
"""kimi-k2.6 with reasoning_effort set should opt in to thinking."""
|
"""kimi-k2.6 with reasoning_effort set should opt in to thinking."""
|
||||||
kw = _build_kwargs_for("moonshot", "kimi-k2.6", reasoning_effort="medium")
|
kw = _build_kwargs_for("moonshot", "kimi-k2.6", reasoning_effort="medium")
|
||||||
assert kw.get("extra_body") == {"thinking": {"type": "enabled"}}
|
assert kw.get("extra_body") == {"thinking": {"type": "enabled"}}
|
||||||
|
assert "reasoning_effort" not in kw
|
||||||
|
|
||||||
|
|
||||||
def test_kimi_k26_thinking_enabled_with_openrouter_prefix() -> None:
|
def test_kimi_k26_thinking_enabled_with_openrouter_prefix() -> None:
|
||||||
"""OpenRouter-style names like moonshotai/kimi-k2.6 must trigger thinking."""
|
"""OpenRouter-style names like moonshotai/kimi-k2.6 must trigger thinking
|
||||||
|
via both upstream `thinking` and OR's `reasoning.effort`."""
|
||||||
kw = _build_kwargs_for("openrouter", "moonshotai/kimi-k2.6", reasoning_effort="medium")
|
kw = _build_kwargs_for("openrouter", "moonshotai/kimi-k2.6", reasoning_effort="medium")
|
||||||
assert kw.get("extra_body") == {"thinking": {"type": "enabled"}}
|
assert kw.get("extra_body") == {
|
||||||
|
"thinking": {"type": "enabled"},
|
||||||
|
"reasoning": {"effort": "medium"},
|
||||||
|
}
|
||||||
|
assert "reasoning_effort" not in kw
|
||||||
|
|
||||||
|
|
||||||
def test_moonshot_kimi_k26_temperature_override() -> None:
|
def test_moonshot_kimi_k26_temperature_override() -> None:
|
||||||
@ -1424,6 +1486,7 @@ def test_kimi_k26_code_preview_thinking_enabled() -> None:
|
|||||||
"""k2.6-code-preview also supports thinking; should behave like k2.5."""
|
"""k2.6-code-preview also supports thinking; should behave like k2.5."""
|
||||||
kw = _build_kwargs_for("moonshot", "k2.6-code-preview", reasoning_effort="high")
|
kw = _build_kwargs_for("moonshot", "k2.6-code-preview", reasoning_effort="high")
|
||||||
assert kw.get("extra_body") == {"thinking": {"type": "enabled"}}
|
assert kw.get("extra_body") == {"thinking": {"type": "enabled"}}
|
||||||
|
assert "reasoning_effort" not in kw
|
||||||
|
|
||||||
|
|
||||||
def test_kimi_k2_series_no_thinking_injection() -> None:
|
def test_kimi_k2_series_no_thinking_injection() -> None:
|
||||||
@ -1453,6 +1516,7 @@ def test_kimi_k25_thinking_disabled_for_none_string() -> None:
|
|||||||
"""reasoning_effort='none' maps to thinking disabled for kimi-k2.5."""
|
"""reasoning_effort='none' maps to thinking disabled for kimi-k2.5."""
|
||||||
kw = _build_kwargs_for("moonshot", "kimi-k2.5", reasoning_effort="none")
|
kw = _build_kwargs_for("moonshot", "kimi-k2.5", reasoning_effort="none")
|
||||||
assert kw.get("extra_body") == {"thinking": {"type": "disabled"}}
|
assert kw.get("extra_body") == {"thinking": {"type": "disabled"}}
|
||||||
|
assert "reasoning_effort" not in kw
|
||||||
|
|
||||||
|
|
||||||
def test_dashscope_thinking_disabled_for_none_string() -> None:
|
def test_dashscope_thinking_disabled_for_none_string() -> None:
|
||||||
|
|||||||
97
tests/providers/test_novita_provider.py
Normal file
97
tests/providers/test_novita_provider.py
Normal file
@ -0,0 +1,97 @@
|
|||||||
|
"""Tests for the Novita AI provider registration."""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from nanobot.config.schema import Config, ProvidersConfig
|
||||||
|
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||||
|
from nanobot.providers.registry import PROVIDERS, find_by_name
|
||||||
|
|
||||||
|
|
||||||
|
def test_novita_config_field_exists() -> None:
|
||||||
|
config = ProvidersConfig()
|
||||||
|
|
||||||
|
assert hasattr(config, "novita")
|
||||||
|
|
||||||
|
|
||||||
|
def test_novita_provider_in_registry() -> None:
|
||||||
|
specs = {spec.name: spec for spec in PROVIDERS}
|
||||||
|
|
||||||
|
assert "novita" in specs
|
||||||
|
novita = specs["novita"]
|
||||||
|
assert novita.backend == "openai_compat"
|
||||||
|
assert novita.env_key == "NOVITA_API_KEY"
|
||||||
|
assert novita.display_name == "Novita AI"
|
||||||
|
assert novita.is_gateway is True
|
||||||
|
assert novita.detect_by_base_keyword == "novita"
|
||||||
|
assert novita.default_api_base == "https://api.novita.ai/openai"
|
||||||
|
assert novita.strip_model_prefix is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_by_name_novita() -> None:
|
||||||
|
spec = find_by_name("novita")
|
||||||
|
|
||||||
|
assert spec is not None
|
||||||
|
assert spec.name == "novita"
|
||||||
|
|
||||||
|
|
||||||
|
def test_novita_forced_provider_uses_default_api_base() -> None:
|
||||||
|
config = Config.model_validate({
|
||||||
|
"providers": {
|
||||||
|
"novita": {
|
||||||
|
"apiKey": "novita-key",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"agents": {
|
||||||
|
"defaults": {
|
||||||
|
"model": "deepseek-v4-pro",
|
||||||
|
"provider": "novita",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert config.get_provider_name("deepseek-v4-pro") == "novita"
|
||||||
|
assert config.get_api_key("deepseek-v4-pro") == "novita-key"
|
||||||
|
assert config.get_api_base("deepseek-v4-pro") == "https://api.novita.ai/openai"
|
||||||
|
|
||||||
|
|
||||||
|
def test_novita_gateway_routes_unprefixed_models_when_configured() -> None:
|
||||||
|
config = Config.model_validate({
|
||||||
|
"providers": {
|
||||||
|
"novita": {
|
||||||
|
"apiKey": "novita-key",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"agents": {
|
||||||
|
"defaults": {
|
||||||
|
"model": "deepseek-v4-pro",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
assert config.get_provider_name("deepseek-v4-pro") == "novita"
|
||||||
|
assert config.get_api_key("deepseek-v4-pro") == "novita-key"
|
||||||
|
assert config.get_api_base("deepseek-v4-pro") == "https://api.novita.ai/openai"
|
||||||
|
|
||||||
|
|
||||||
|
def test_novita_preserves_model_api_id() -> None:
|
||||||
|
spec = find_by_name("novita")
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
provider = OpenAICompatProvider(
|
||||||
|
api_key="novita-key",
|
||||||
|
default_model="deepseek-v4-pro",
|
||||||
|
spec=spec,
|
||||||
|
)
|
||||||
|
|
||||||
|
kwargs = provider._build_kwargs(
|
||||||
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
|
tools=None,
|
||||||
|
model="deepseek-v4-pro",
|
||||||
|
max_tokens=1024,
|
||||||
|
temperature=0.7,
|
||||||
|
reasoning_effort=None,
|
||||||
|
tool_choice=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert kwargs["model"] == "deepseek-v4-pro"
|
||||||
|
assert kwargs["max_tokens"] == 1024
|
||||||
|
assert "max_completion_tokens" not in kwargs
|
||||||
@ -32,7 +32,7 @@ def _mimo_spec():
|
|||||||
|
|
||||||
|
|
||||||
def _openrouter_spec():
|
def _openrouter_spec():
|
||||||
"""Return the registered OpenRouter ProviderSpec (no thinking_style)."""
|
"""Return the registered OpenRouter ProviderSpec."""
|
||||||
specs = {s.name: s for s in PROVIDERS}
|
specs = {s.name: s for s in PROVIDERS}
|
||||||
return specs["openrouter"]
|
return specs["openrouter"]
|
||||||
|
|
||||||
@ -77,6 +77,13 @@ def test_xiaomi_mimo_uses_thinking_type_style():
|
|||||||
assert spec.default_api_base == "https://api.xiaomimimo.com/v1"
|
assert spec.default_api_base == "https://api.xiaomimimo.com/v1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_openrouter_declares_gateway_reasoning_style():
|
||||||
|
"""OpenRouter uses its own reasoning.effort field for routed thinking models."""
|
||||||
|
spec = _openrouter_spec()
|
||||||
|
assert spec.thinking_style == ""
|
||||||
|
assert spec.gateway_reasoning_style == "reasoning_effort"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# _build_kwargs wire-format
|
# _build_kwargs wire-format
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@ -142,9 +149,11 @@ def test_mimo_reasoning_effort_unset_preserves_provider_default():
|
|||||||
|
|
||||||
|
|
||||||
def test_mimo_via_openrouter_reasoning_effort_none_disables_thinking():
|
def test_mimo_via_openrouter_reasoning_effort_none_disables_thinking():
|
||||||
"""OpenRouter routes MiMo as "xiaomi/mimo-v2.5-pro"; the openrouter spec
|
"""OpenRouter routes MiMo as "xiaomi/mimo-v2.5-pro" and does NOT forward
|
||||||
has no thinking_style, so the disable signal must come from the
|
extra_body.thinking to upstream, so a disable signal must also reach OR
|
||||||
model-name path (#3845)."""
|
in its own `reasoning.effort` shape. Verifies both the upstream-MiMo
|
||||||
|
payload (#3845) and the OR-native payload (#3851 follow-up) are sent.
|
||||||
|
"""
|
||||||
provider = _openrouter_provider("xiaomi/mimo-v2.5-pro")
|
provider = _openrouter_provider("xiaomi/mimo-v2.5-pro")
|
||||||
kwargs = provider._build_kwargs(
|
kwargs = provider._build_kwargs(
|
||||||
messages=_simple_messages(),
|
messages=_simple_messages(),
|
||||||
@ -152,11 +161,15 @@ def test_mimo_via_openrouter_reasoning_effort_none_disables_thinking():
|
|||||||
temperature=0.7, reasoning_effort="none", tool_choice=None,
|
temperature=0.7, reasoning_effort="none", tool_choice=None,
|
||||||
)
|
)
|
||||||
assert "reasoning_effort" not in kwargs
|
assert "reasoning_effort" not in kwargs
|
||||||
assert kwargs["extra_body"] == {"thinking": {"type": "disabled"}}
|
assert kwargs["extra_body"] == {
|
||||||
|
"thinking": {"type": "disabled"},
|
||||||
|
"reasoning": {"effort": "none"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_mimo_via_openrouter_reasoning_effort_medium_enables_thinking():
|
def test_mimo_via_openrouter_reasoning_effort_medium_enables_thinking():
|
||||||
"""Same as the direct path: any non-none/minimal effort enables thinking."""
|
"""Non-none/minimal effort enables thinking and the OR `reasoning.effort`
|
||||||
|
field mirrors the requested effort level."""
|
||||||
provider = _openrouter_provider("xiaomi/mimo-v2.5-pro")
|
provider = _openrouter_provider("xiaomi/mimo-v2.5-pro")
|
||||||
kwargs = provider._build_kwargs(
|
kwargs = provider._build_kwargs(
|
||||||
messages=_simple_messages(),
|
messages=_simple_messages(),
|
||||||
@ -164,7 +177,10 @@ def test_mimo_via_openrouter_reasoning_effort_medium_enables_thinking():
|
|||||||
temperature=0.7, reasoning_effort="medium", tool_choice=None,
|
temperature=0.7, reasoning_effort="medium", tool_choice=None,
|
||||||
)
|
)
|
||||||
assert kwargs.get("reasoning_effort") == "medium"
|
assert kwargs.get("reasoning_effort") == "medium"
|
||||||
assert kwargs["extra_body"] == {"thinking": {"type": "enabled"}}
|
assert kwargs["extra_body"] == {
|
||||||
|
"thinking": {"type": "enabled"},
|
||||||
|
"reasoning": {"effort": "medium"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_mimo_via_openrouter_bare_slug_also_matches():
|
def test_mimo_via_openrouter_bare_slug_also_matches():
|
||||||
@ -176,12 +192,16 @@ def test_mimo_via_openrouter_bare_slug_also_matches():
|
|||||||
tools=None, model=None, max_tokens=100,
|
tools=None, model=None, max_tokens=100,
|
||||||
temperature=0.7, reasoning_effort="none", tool_choice=None,
|
temperature=0.7, reasoning_effort="none", tool_choice=None,
|
||||||
)
|
)
|
||||||
assert kwargs["extra_body"] == {"thinking": {"type": "disabled"}}
|
assert kwargs["extra_body"] == {
|
||||||
|
"thinking": {"type": "disabled"},
|
||||||
|
"reasoning": {"effort": "none"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_mimo_flash_via_openrouter_does_not_inject_thinking():
|
def test_mimo_flash_via_openrouter_does_not_inject_thinking():
|
||||||
"""mimo-v2-flash has no thinking mode per Xiaomi docs; the allowlist
|
"""mimo-v2-flash has no thinking mode per Xiaomi docs; the allowlist
|
||||||
excludes it, so no thinking field should be injected on the gateway path."""
|
excludes it, so neither the upstream `thinking` field nor OR's
|
||||||
|
`reasoning.effort` should be injected on the gateway path."""
|
||||||
provider = _openrouter_provider("xiaomi/mimo-v2-flash")
|
provider = _openrouter_provider("xiaomi/mimo-v2-flash")
|
||||||
kwargs = provider._build_kwargs(
|
kwargs = provider._build_kwargs(
|
||||||
messages=_simple_messages(),
|
messages=_simple_messages(),
|
||||||
@ -200,3 +220,18 @@ def test_non_mimo_model_via_openrouter_unaffected():
|
|||||||
temperature=0.7, reasoning_effort="none", tool_choice=None,
|
temperature=0.7, reasoning_effort="none", tool_choice=None,
|
||||||
)
|
)
|
||||||
assert "extra_body" not in kwargs
|
assert "extra_body" not in kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def test_kimi_via_openrouter_also_injects_reasoning_effort():
|
||||||
|
"""Kimi has the same gateway problem as MiMo: OR drops the upstream
|
||||||
|
`thinking` field. The same OR-reasoning injection should fire."""
|
||||||
|
provider = _openrouter_provider("moonshotai/kimi-k2.5")
|
||||||
|
kwargs = provider._build_kwargs(
|
||||||
|
messages=_simple_messages(),
|
||||||
|
tools=None, model=None, max_tokens=100,
|
||||||
|
temperature=0.7, reasoning_effort="none", tool_choice=None,
|
||||||
|
)
|
||||||
|
assert kwargs["extra_body"] == {
|
||||||
|
"thinking": {"type": "disabled"},
|
||||||
|
"reasoning": {"effort": "none"},
|
||||||
|
}
|
||||||
|
|||||||
330
tests/tools/test_apply_patch_tool.py
Normal file
330
tests/tools/test_apply_patch_tool.py
Normal file
@ -0,0 +1,330 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from nanobot.agent.tools.apply_patch import ApplyPatchTool
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_patch_edits_replace(tmp_path):
|
||||||
|
target = tmp_path / "calc.py"
|
||||||
|
target.write_text("def add(a, b):\n return a + b\n")
|
||||||
|
tool = ApplyPatchTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
result = asyncio.run(
|
||||||
|
tool.execute(
|
||||||
|
edits=[
|
||||||
|
{
|
||||||
|
"path": "calc.py",
|
||||||
|
"action": "replace",
|
||||||
|
"old_text": " return a + b",
|
||||||
|
"new_text": " return a - b",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "update calc.py" in result
|
||||||
|
assert target.read_text() == "def add(a, b):\n return a - b\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_patch_edits_add_new_file(tmp_path):
|
||||||
|
tool = ApplyPatchTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
result = asyncio.run(
|
||||||
|
tool.execute(
|
||||||
|
edits=[
|
||||||
|
{
|
||||||
|
"path": "config.py",
|
||||||
|
"action": "add",
|
||||||
|
"new_text": "DEBUG = True",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "add config.py" in result
|
||||||
|
assert (tmp_path / "config.py").read_text() == "DEBUG = True\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_patch_edits_preserves_new_file_trailing_blank_lines(tmp_path):
|
||||||
|
tool = ApplyPatchTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
result = asyncio.run(
|
||||||
|
tool.execute(
|
||||||
|
edits=[
|
||||||
|
{
|
||||||
|
"path": "notes.txt",
|
||||||
|
"action": "add",
|
||||||
|
"new_text": "one\n\n",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "add notes.txt" in result
|
||||||
|
assert (tmp_path / "notes.txt").read_text() == "one\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_patch_edits_add_to_existing_file(tmp_path):
|
||||||
|
target = tmp_path / "log.py"
|
||||||
|
target.write_text("import logging\n\nlogger = logging.getLogger(__name__)\n")
|
||||||
|
tool = ApplyPatchTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
result = asyncio.run(
|
||||||
|
tool.execute(
|
||||||
|
edits=[
|
||||||
|
{
|
||||||
|
"path": "log.py",
|
||||||
|
"action": "add",
|
||||||
|
"new_text": "def debug(msg):\n logger.debug(msg)",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "update log.py" in result
|
||||||
|
assert (
|
||||||
|
target.read_text()
|
||||||
|
== "import logging\n\nlogger = logging.getLogger(__name__)\ndef debug(msg):\n logger.debug(msg)\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_patch_edits_delete(tmp_path):
|
||||||
|
target = tmp_path / "utils.py"
|
||||||
|
target.write_text("def unused():\n pass\ndef used():\n return 1\n")
|
||||||
|
tool = ApplyPatchTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
result = asyncio.run(
|
||||||
|
tool.execute(
|
||||||
|
edits=[
|
||||||
|
{
|
||||||
|
"path": "utils.py",
|
||||||
|
"action": "delete",
|
||||||
|
"old_text": "def unused():\n pass\n",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "update utils.py" in result
|
||||||
|
assert target.read_text() == "def used():\n return 1\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_patch_edits_delete_entire_file(tmp_path):
|
||||||
|
target = tmp_path / "obsolete.txt"
|
||||||
|
target.write_text("remove me\n")
|
||||||
|
tool = ApplyPatchTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
result = asyncio.run(
|
||||||
|
tool.execute(
|
||||||
|
edits=[
|
||||||
|
{
|
||||||
|
"path": "obsolete.txt",
|
||||||
|
"action": "delete",
|
||||||
|
"old_text": "remove me\n",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "delete obsolete.txt" in result
|
||||||
|
assert not target.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_patch_edits_delete_substring_with_surrounding_whitespace(tmp_path):
|
||||||
|
target = tmp_path / "keep_whitespace.txt"
|
||||||
|
target.write_text(" token \n")
|
||||||
|
tool = ApplyPatchTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
result = asyncio.run(
|
||||||
|
tool.execute(
|
||||||
|
edits=[
|
||||||
|
{
|
||||||
|
"path": "keep_whitespace.txt",
|
||||||
|
"action": "delete",
|
||||||
|
"old_text": "token",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "update keep_whitespace.txt" in result
|
||||||
|
assert target.exists()
|
||||||
|
assert target.read_text() == " \n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_patch_edits_batch_multiple_files(tmp_path):
|
||||||
|
a = tmp_path / "a.py"
|
||||||
|
a.write_text("X = 1\n")
|
||||||
|
b = tmp_path / "b.py"
|
||||||
|
b.write_text("from a import X\nprint(X)\n")
|
||||||
|
tool = ApplyPatchTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
result = asyncio.run(
|
||||||
|
tool.execute(
|
||||||
|
edits=[
|
||||||
|
{
|
||||||
|
"path": "a.py",
|
||||||
|
"action": "replace",
|
||||||
|
"old_text": "X = 1",
|
||||||
|
"new_text": "Y = 1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "b.py",
|
||||||
|
"action": "replace",
|
||||||
|
"old_text": "from a import X",
|
||||||
|
"new_text": "from a import Y",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "update a.py" in result
|
||||||
|
assert "update b.py" in result
|
||||||
|
assert a.read_text() == "Y = 1\n"
|
||||||
|
assert b.read_text() == "from a import Y\nprint(X)\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_patch_edits_rejects_ambiguous_old_text(tmp_path):
|
||||||
|
target = tmp_path / "repeated.txt"
|
||||||
|
target.write_text("target\nmiddle\ntarget\n")
|
||||||
|
tool = ApplyPatchTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
result = asyncio.run(
|
||||||
|
tool.execute(
|
||||||
|
edits=[
|
||||||
|
{
|
||||||
|
"path": "repeated.txt",
|
||||||
|
"action": "replace",
|
||||||
|
"old_text": "target",
|
||||||
|
"new_text": "changed",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "old_text appears multiple times" in result
|
||||||
|
assert target.read_text() == "target\nmiddle\ntarget\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_patch_edits_dry_run_validates_without_writing(tmp_path):
|
||||||
|
target = tmp_path / "dry.txt"
|
||||||
|
target.write_text("before\n")
|
||||||
|
tool = ApplyPatchTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
result = asyncio.run(
|
||||||
|
tool.execute(
|
||||||
|
edits=[
|
||||||
|
{
|
||||||
|
"path": "dry.txt",
|
||||||
|
"action": "replace",
|
||||||
|
"old_text": "before",
|
||||||
|
"new_text": "after",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "added.txt",
|
||||||
|
"action": "add",
|
||||||
|
"new_text": "new",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
dry_run=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Patch dry-run succeeded" in result
|
||||||
|
assert target.read_text() == "before\n"
|
||||||
|
assert not (tmp_path / "added.txt").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_patch_edits_rejects_absolute_and_parent_paths(tmp_path):
|
||||||
|
tool = ApplyPatchTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
absolute = asyncio.run(
|
||||||
|
tool.execute(
|
||||||
|
edits=[
|
||||||
|
{
|
||||||
|
"path": "/tmp/owned.txt",
|
||||||
|
"action": "add",
|
||||||
|
"new_text": "nope",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
parent = asyncio.run(
|
||||||
|
tool.execute(
|
||||||
|
edits=[
|
||||||
|
{
|
||||||
|
"path": "../owned.txt",
|
||||||
|
"action": "add",
|
||||||
|
"new_text": "nope",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
windows_absolute = asyncio.run(
|
||||||
|
tool.execute(
|
||||||
|
edits=[
|
||||||
|
{
|
||||||
|
"path": r"C:\owned.txt",
|
||||||
|
"action": "add",
|
||||||
|
"new_text": "nope",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
windows_parent = asyncio.run(
|
||||||
|
tool.execute(
|
||||||
|
edits=[
|
||||||
|
{
|
||||||
|
"path": r"..\owned.txt",
|
||||||
|
"action": "add",
|
||||||
|
"new_text": "nope",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "must be relative" in absolute
|
||||||
|
assert "must not contain '..'" in parent
|
||||||
|
assert "must be relative" in windows_absolute
|
||||||
|
assert "must not contain '..'" in windows_parent
|
||||||
|
assert not (tmp_path.parent / "owned.txt").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_patch_edits_reports_invalid_edit_shapes(tmp_path):
|
||||||
|
tool = ApplyPatchTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
missing_path = asyncio.run(tool.execute(edits=[{"action": "add", "new_text": "x"}]))
|
||||||
|
missing_action = asyncio.run(tool.execute(edits=[{"path": "x.txt", "new_text": "x"}]))
|
||||||
|
non_object = asyncio.run(tool.execute(edits=["not an object"])) # type: ignore[list-item]
|
||||||
|
|
||||||
|
assert "path required for edit" in missing_path
|
||||||
|
assert "action required for edit: x.txt" in missing_action
|
||||||
|
assert "each edit must be an object" in non_object
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_patch_edits_rolls_back_when_late_operation_fails(tmp_path):
|
||||||
|
first = tmp_path / "first.txt"
|
||||||
|
first.write_text("before\n")
|
||||||
|
tool = ApplyPatchTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
result = asyncio.run(
|
||||||
|
tool.execute(
|
||||||
|
edits=[
|
||||||
|
{
|
||||||
|
"path": "first.txt",
|
||||||
|
"action": "replace",
|
||||||
|
"old_text": "before",
|
||||||
|
"new_text": "after",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "missing.txt",
|
||||||
|
"action": "delete",
|
||||||
|
"old_text": "remove me",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "file to update does not exist: missing.txt" in result
|
||||||
|
assert first.read_text() == "before\n"
|
||||||
@ -1,5 +1,5 @@
|
|||||||
"""Tests for EditFileTool enhancements: read-before-edit tracking, path suggestions,
|
"""Tests for EditFileTool enhancements: read-before-edit tracking, path suggestions,
|
||||||
.ipynb detection, and create-file semantics."""
|
notebook JSON editing, and create-file semantics."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -108,22 +108,27 @@ class TestEditCreateFile:
|
|||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# .ipynb detection
|
# .ipynb editing
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
class TestEditIpynbDetection:
|
class TestEditIpynbFiles:
|
||||||
"""edit_file should refuse .ipynb and suggest notebook_edit."""
|
"""edit_file edits notebooks as normal JSON files."""
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def tool(self, tmp_path):
|
def tool(self, tmp_path):
|
||||||
return EditFileTool(workspace=tmp_path)
|
return EditFileTool(workspace=tmp_path)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_ipynb_rejected_with_suggestion(self, tool, tmp_path):
|
async def test_ipynb_can_be_edited_as_json(self, tool, tmp_path):
|
||||||
f = tmp_path / "analysis.ipynb"
|
f = tmp_path / "analysis.ipynb"
|
||||||
f.write_text('{"cells": []}', encoding="utf-8")
|
f.write_text('{"cells": []}', encoding="utf-8")
|
||||||
result = await tool.execute(path=str(f), old_text="x", new_text="y")
|
result = await tool.execute(
|
||||||
assert "notebook" in result.lower()
|
path=str(f),
|
||||||
|
old_text='"cells": []',
|
||||||
|
new_text='"cells": [{"cell_type": "markdown", "source": "hi"}]',
|
||||||
|
)
|
||||||
|
assert "Successfully edited" in result
|
||||||
|
assert '"source": "hi"' in f.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@ -162,7 +162,7 @@ class TestPathAppendPlatform:
|
|||||||
captured_cmd = None
|
captured_cmd = None
|
||||||
captured_env = {}
|
captured_env = {}
|
||||||
|
|
||||||
async def capture_spawn(cmd, cwd, env):
|
async def capture_spawn(cmd, cwd, env, shell_program=None, login=True):
|
||||||
nonlocal captured_cmd
|
nonlocal captured_cmd
|
||||||
captured_cmd = cmd
|
captured_cmd = cmd
|
||||||
captured_env.update(env)
|
captured_env.update(env)
|
||||||
@ -190,7 +190,7 @@ class TestPathAppendPlatform:
|
|||||||
|
|
||||||
captured_env = {}
|
captured_env = {}
|
||||||
|
|
||||||
async def capture_spawn(cmd, cwd, env):
|
async def capture_spawn(cmd, cwd, env, shell_program=None, login=True):
|
||||||
captured_env.update(env)
|
captured_env.update(env)
|
||||||
return mock_proc
|
return mock_proc
|
||||||
|
|
||||||
|
|||||||
361
tests/tools/test_exec_session_tools.py
Normal file
361
tests/tools/test_exec_session_tools.py
Normal file
@ -0,0 +1,361 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import re
|
||||||
|
import shlex
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from nanobot.agent.tools.shell import ExecTool
|
||||||
|
from nanobot.agent.tools.exec_session import ExecSessionManager, ListExecSessionsTool, WriteStdinTool
|
||||||
|
|
||||||
|
|
||||||
|
def _python_command(code: str) -> str:
|
||||||
|
if sys.platform == "win32":
|
||||||
|
return f"{subprocess.list2cmdline([sys.executable])} -u -c {subprocess.list2cmdline([code])}"
|
||||||
|
return f"{shlex.quote(sys.executable)} -u -c {shlex.quote(code)}"
|
||||||
|
|
||||||
|
|
||||||
|
def _session_id(output: str) -> str:
|
||||||
|
match = re.search(r"session_id:\s*([0-9a-f]+)", output)
|
||||||
|
assert match, output
|
||||||
|
return match.group(1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_exec_keeps_one_shot_behavior_without_yield_time_ms(tmp_path):
|
||||||
|
async def run() -> str:
|
||||||
|
tool = ExecTool(working_dir=str(tmp_path), timeout=5)
|
||||||
|
return await tool.execute(command="echo hello")
|
||||||
|
|
||||||
|
result = asyncio.run(run())
|
||||||
|
|
||||||
|
assert "hello" in result
|
||||||
|
assert "Exit code: 0" in result
|
||||||
|
assert "session_id:" not in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_exec_accepts_command_aliases(tmp_path):
|
||||||
|
async def run() -> str:
|
||||||
|
tool = ExecTool(working_dir="/")
|
||||||
|
return await tool.execute(
|
||||||
|
cmd=_python_command("import os; print(os.getcwd())"),
|
||||||
|
workdir=str(tmp_path),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = asyncio.run(run())
|
||||||
|
|
||||||
|
assert str(tmp_path) in result
|
||||||
|
assert "Exit code: 0" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_exec_returns_completed_session_output_when_yield_time_ms_is_used(tmp_path):
|
||||||
|
async def run() -> str:
|
||||||
|
manager = ExecSessionManager()
|
||||||
|
tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
|
||||||
|
stdin_tool = WriteStdinTool(manager=manager)
|
||||||
|
|
||||||
|
result = await tool.execute(command="echo hello", yield_time_ms=1000)
|
||||||
|
if "session_id:" in result:
|
||||||
|
sid = _session_id(result)
|
||||||
|
result += "\n" + await stdin_tool.execute(
|
||||||
|
session_id=sid,
|
||||||
|
chars="",
|
||||||
|
yield_time_ms=1000,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
result = asyncio.run(run())
|
||||||
|
|
||||||
|
assert "hello" in result
|
||||||
|
assert "Exit code: 0" in result
|
||||||
|
assert "session_id:" not in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_exec_session_accepts_max_output_tokens_alias(tmp_path):
|
||||||
|
async def run() -> str:
|
||||||
|
manager = ExecSessionManager()
|
||||||
|
tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
|
||||||
|
command = _python_command("print('A' * 2000)")
|
||||||
|
return await tool.execute(
|
||||||
|
command=command,
|
||||||
|
yield_time_ms=1000,
|
||||||
|
max_output_tokens=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = asyncio.run(run())
|
||||||
|
|
||||||
|
assert "chars truncated" in result
|
||||||
|
assert "Exit code: 0" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_exec_one_shot_accepts_max_output_tokens_alias(tmp_path):
|
||||||
|
async def run() -> str:
|
||||||
|
tool = ExecTool(working_dir=str(tmp_path), timeout=5)
|
||||||
|
command = _python_command("print('A' * 2000)")
|
||||||
|
return await tool.execute(command=command, max_output_tokens=1000)
|
||||||
|
|
||||||
|
result = asyncio.run(run())
|
||||||
|
|
||||||
|
assert "chars truncated" in result
|
||||||
|
assert "Exit code: 0" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_exec_accepts_supported_shell_parameter(tmp_path):
|
||||||
|
async def run() -> str:
|
||||||
|
tool = ExecTool(working_dir=str(tmp_path), timeout=5)
|
||||||
|
return await tool.execute(command="echo shell-ok", shell="sh", login=False)
|
||||||
|
|
||||||
|
if sys.platform == "win32":
|
||||||
|
return
|
||||||
|
result = asyncio.run(run())
|
||||||
|
|
||||||
|
assert "shell-ok" in result
|
||||||
|
assert "Exit code: 0" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_exec_rejects_unsupported_shell(tmp_path):
|
||||||
|
async def run() -> str:
|
||||||
|
tool = ExecTool(working_dir=str(tmp_path), timeout=5)
|
||||||
|
return await tool.execute(command="echo no", shell="python")
|
||||||
|
|
||||||
|
if sys.platform == "win32":
|
||||||
|
return
|
||||||
|
result = asyncio.run(run())
|
||||||
|
|
||||||
|
assert "unsupported shell" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_exec_can_continue_with_stdin(tmp_path):
|
||||||
|
async def run() -> tuple[str, str]:
|
||||||
|
manager = ExecSessionManager()
|
||||||
|
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
|
||||||
|
stdin_tool = WriteStdinTool(manager=manager)
|
||||||
|
command = _python_command(
|
||||||
|
"import sys; print('ready', flush=True); "
|
||||||
|
"line=sys.stdin.readline(); print('got:' + line.strip(), flush=True)"
|
||||||
|
)
|
||||||
|
|
||||||
|
initial = await exec_tool.execute(command=command, yield_time_ms=500)
|
||||||
|
sid = _session_id(initial)
|
||||||
|
result = await stdin_tool.execute(session_id=sid, chars="ping\n", yield_time_ms=1000)
|
||||||
|
return initial, result
|
||||||
|
|
||||||
|
initial, result = asyncio.run(run())
|
||||||
|
assert "ready" in initial
|
||||||
|
assert "Process running" in initial
|
||||||
|
assert "Elapsed:" in initial
|
||||||
|
assert "got:ping" in result
|
||||||
|
assert "Exit code: 0" in result
|
||||||
|
assert "Elapsed:" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_stdin_can_close_stdin(tmp_path):
|
||||||
|
async def run() -> tuple[str, str]:
|
||||||
|
manager = ExecSessionManager()
|
||||||
|
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
|
||||||
|
stdin_tool = WriteStdinTool(manager=manager)
|
||||||
|
command = _python_command(
|
||||||
|
"import sys; print('ready', flush=True); "
|
||||||
|
"data=sys.stdin.read(); print('got:' + data, flush=True)"
|
||||||
|
)
|
||||||
|
|
||||||
|
initial = await exec_tool.execute(command=command, yield_time_ms=500)
|
||||||
|
sid = _session_id(initial)
|
||||||
|
result = await stdin_tool.execute(
|
||||||
|
session_id=sid,
|
||||||
|
chars="payload",
|
||||||
|
close_stdin=True,
|
||||||
|
yield_time_ms=1000,
|
||||||
|
)
|
||||||
|
return initial, result
|
||||||
|
|
||||||
|
initial, result = asyncio.run(run())
|
||||||
|
assert "ready" in initial
|
||||||
|
assert "got:payload" in result
|
||||||
|
assert "Stdin closed." in result
|
||||||
|
assert "Exit code: 0" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_stdin_can_terminate_session(tmp_path):
|
||||||
|
async def run() -> tuple[str, str]:
|
||||||
|
manager = ExecSessionManager()
|
||||||
|
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=30, session_manager=manager)
|
||||||
|
stdin_tool = WriteStdinTool(manager=manager)
|
||||||
|
command = _python_command(
|
||||||
|
"import time; print('ready', flush=True); time.sleep(30)"
|
||||||
|
)
|
||||||
|
|
||||||
|
initial = await exec_tool.execute(command=command, yield_time_ms=500)
|
||||||
|
sid = _session_id(initial)
|
||||||
|
result = await stdin_tool.execute(
|
||||||
|
session_id=sid,
|
||||||
|
terminate=True,
|
||||||
|
yield_time_ms=0,
|
||||||
|
)
|
||||||
|
return initial, result
|
||||||
|
|
||||||
|
initial, result = asyncio.run(run())
|
||||||
|
assert "ready" in initial
|
||||||
|
assert "Session terminated." in result
|
||||||
|
assert "Exit code:" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_stdin_accepts_max_output_tokens_alias(tmp_path):
|
||||||
|
async def run() -> tuple[str, str, str]:
|
||||||
|
manager = ExecSessionManager()
|
||||||
|
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
|
||||||
|
stdin_tool = WriteStdinTool(manager=manager)
|
||||||
|
command = _python_command(
|
||||||
|
"import time; print('A' * 2000, flush=True); time.sleep(5)"
|
||||||
|
)
|
||||||
|
|
||||||
|
initial = await exec_tool.execute(command=command, yield_time_ms=0)
|
||||||
|
sid = _session_id(initial)
|
||||||
|
poll = await stdin_tool.execute(
|
||||||
|
session_id=sid,
|
||||||
|
yield_time_ms=500,
|
||||||
|
max_output_tokens=1000,
|
||||||
|
)
|
||||||
|
cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0)
|
||||||
|
return initial, poll, cleanup
|
||||||
|
|
||||||
|
initial, poll, cleanup = asyncio.run(run())
|
||||||
|
assert "Process running" in initial
|
||||||
|
assert "chars truncated" in poll
|
||||||
|
assert "Session terminated." in cleanup
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_stdin_preserves_completed_session_output_until_polled(tmp_path):
|
||||||
|
async def run() -> tuple[str, str]:
|
||||||
|
manager = ExecSessionManager()
|
||||||
|
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
|
||||||
|
stdin_tool = WriteStdinTool(manager=manager)
|
||||||
|
command = _python_command(
|
||||||
|
"import time; print('ready', flush=True); "
|
||||||
|
"time.sleep(1.0); print('done', flush=True)"
|
||||||
|
)
|
||||||
|
|
||||||
|
initial = await exec_tool.execute(command=command, yield_time_ms=300)
|
||||||
|
sid = _session_id(initial)
|
||||||
|
await asyncio.sleep(1.2)
|
||||||
|
final = await stdin_tool.execute(session_id=sid, chars="", yield_time_ms=0)
|
||||||
|
return initial, final
|
||||||
|
|
||||||
|
initial, final = asyncio.run(run())
|
||||||
|
|
||||||
|
assert "ready" in initial
|
||||||
|
assert "done" in final
|
||||||
|
assert "Exit code: 0" in final
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_stdin_can_wait_for_expected_output(tmp_path):
|
||||||
|
async def run() -> tuple[str, str, str]:
|
||||||
|
manager = ExecSessionManager()
|
||||||
|
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
|
||||||
|
stdin_tool = WriteStdinTool(manager=manager)
|
||||||
|
command = _python_command(
|
||||||
|
"import time; print('booting', flush=True); "
|
||||||
|
"time.sleep(0.4); print('ready', flush=True); time.sleep(5)"
|
||||||
|
)
|
||||||
|
|
||||||
|
initial = await exec_tool.execute(command=command, yield_time_ms=100)
|
||||||
|
sid = _session_id(initial)
|
||||||
|
waited = await stdin_tool.execute(
|
||||||
|
session_id=sid,
|
||||||
|
wait_for="ready",
|
||||||
|
wait_timeout_ms=3000,
|
||||||
|
yield_time_ms=0,
|
||||||
|
)
|
||||||
|
cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0)
|
||||||
|
return initial, waited, cleanup
|
||||||
|
|
||||||
|
initial, waited, cleanup = asyncio.run(run())
|
||||||
|
|
||||||
|
assert "Process running" in initial
|
||||||
|
assert "booting" in initial + waited
|
||||||
|
assert "ready" in waited
|
||||||
|
assert "Wait target not observed" not in waited
|
||||||
|
assert "Session terminated." in cleanup
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_stdin_wait_for_reports_timeout_without_killing_session(tmp_path):
|
||||||
|
async def run() -> tuple[str, str, str]:
|
||||||
|
manager = ExecSessionManager()
|
||||||
|
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
|
||||||
|
stdin_tool = WriteStdinTool(manager=manager)
|
||||||
|
command = _python_command(
|
||||||
|
"import time; print('booting', flush=True); time.sleep(5)"
|
||||||
|
)
|
||||||
|
|
||||||
|
initial = await exec_tool.execute(command=command, yield_time_ms=100)
|
||||||
|
sid = _session_id(initial)
|
||||||
|
waited = await stdin_tool.execute(
|
||||||
|
session_id=sid,
|
||||||
|
wait_for="never-ready",
|
||||||
|
wait_timeout_ms=200,
|
||||||
|
yield_time_ms=0,
|
||||||
|
)
|
||||||
|
cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0)
|
||||||
|
return initial, waited, cleanup
|
||||||
|
|
||||||
|
initial, waited, cleanup = asyncio.run(run())
|
||||||
|
|
||||||
|
assert "Process running" in initial
|
||||||
|
assert "booting" in initial + waited
|
||||||
|
assert "Process running" in waited
|
||||||
|
assert "Wait target not observed: 'never-ready'" in waited
|
||||||
|
assert "Session terminated." in cleanup
|
||||||
|
|
||||||
|
|
||||||
|
def test_exec_session_mode_reuses_exec_safety_guard(tmp_path):
|
||||||
|
manager = ExecSessionManager()
|
||||||
|
tool = ExecTool(
|
||||||
|
working_dir=str(tmp_path),
|
||||||
|
deny_patterns=[r"echo\s+blocked"],
|
||||||
|
session_manager=manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = asyncio.run(tool.execute(command="echo blocked", yield_time_ms=0))
|
||||||
|
|
||||||
|
assert "blocked by deny pattern" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_write_stdin_reports_missing_session(tmp_path):
|
||||||
|
manager = ExecSessionManager()
|
||||||
|
tool = WriteStdinTool(manager=manager)
|
||||||
|
|
||||||
|
result = asyncio.run(tool.execute(session_id="missing", chars=""))
|
||||||
|
|
||||||
|
assert "exec session not found" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_exec_sessions_reports_running_commands(tmp_path):
|
||||||
|
async def run() -> tuple[str, str, str]:
|
||||||
|
manager = ExecSessionManager()
|
||||||
|
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
|
||||||
|
list_tool = ListExecSessionsTool(manager=manager)
|
||||||
|
stdin_tool = WriteStdinTool(manager=manager)
|
||||||
|
command = _python_command(
|
||||||
|
"import time; print('ready', flush=True); time.sleep(5)"
|
||||||
|
)
|
||||||
|
|
||||||
|
initial = await exec_tool.execute(command=command, yield_time_ms=500)
|
||||||
|
sid = _session_id(initial)
|
||||||
|
listing = await list_tool.execute()
|
||||||
|
cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0)
|
||||||
|
return sid, listing, cleanup
|
||||||
|
|
||||||
|
sid, listing, cleanup = asyncio.run(run())
|
||||||
|
|
||||||
|
assert sid in listing
|
||||||
|
assert "running" in listing
|
||||||
|
assert "elapsed=" in listing
|
||||||
|
assert "remaining=" in listing
|
||||||
|
assert str(tmp_path) in listing
|
||||||
|
assert "Session terminated." in cleanup
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_exec_sessions_reports_empty_state():
|
||||||
|
result = asyncio.run(ListExecSessionsTool(manager=ExecSessionManager()).execute())
|
||||||
|
|
||||||
|
assert result == "No active exec sessions."
|
||||||
216
tests/tools/test_file_edit_coding_enhancements.py
Normal file
216
tests/tools/test_file_edit_coding_enhancements.py
Normal file
@ -0,0 +1,216 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool
|
||||||
|
|
||||||
|
|
||||||
|
def test_read_file_force_bypasses_dedup(tmp_path):
|
||||||
|
target = tmp_path / "data.txt"
|
||||||
|
target.write_text("alpha\n")
|
||||||
|
tool = ReadFileTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
first = asyncio.run(tool.execute(path=str(target)))
|
||||||
|
second = asyncio.run(tool.execute(path=str(target)))
|
||||||
|
forced = asyncio.run(tool.execute(path=str(target), force=True))
|
||||||
|
|
||||||
|
assert "alpha" in first
|
||||||
|
assert "unchanged" in second.lower()
|
||||||
|
assert "alpha" in forced
|
||||||
|
assert "unchanged" not in forced.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_edit_file_can_select_occurrence(tmp_path):
|
||||||
|
target = tmp_path / "duplicate.txt"
|
||||||
|
target.write_text("one\nsame\ntwo\nsame\n")
|
||||||
|
tool = EditFileTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
result = asyncio.run(tool.execute(
|
||||||
|
path=str(target),
|
||||||
|
old_text="same",
|
||||||
|
new_text="changed",
|
||||||
|
occurrence=2,
|
||||||
|
))
|
||||||
|
|
||||||
|
assert "Successfully edited" in result
|
||||||
|
assert target.read_text() == "one\nsame\ntwo\nchanged\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_edit_file_expected_replacements_guards_replace_all(tmp_path):
|
||||||
|
target = tmp_path / "duplicate.txt"
|
||||||
|
target.write_text("same\nsame\n")
|
||||||
|
tool = EditFileTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
result = asyncio.run(tool.execute(
|
||||||
|
path=str(target),
|
||||||
|
old_text="same",
|
||||||
|
new_text="changed",
|
||||||
|
replace_all=True,
|
||||||
|
expected_replacements=1,
|
||||||
|
))
|
||||||
|
|
||||||
|
assert "expected 1 replacements but would make 2" in result
|
||||||
|
assert target.read_text() == "same\nsame\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_edit_file_expected_replacements_allows_replace_all_when_count_matches(tmp_path):
|
||||||
|
target = tmp_path / "duplicate.txt"
|
||||||
|
target.write_text("same\nsame\n")
|
||||||
|
tool = EditFileTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
result = asyncio.run(tool.execute(
|
||||||
|
path=str(target),
|
||||||
|
old_text="same",
|
||||||
|
new_text="changed",
|
||||||
|
replace_all=True,
|
||||||
|
expected_replacements=2,
|
||||||
|
))
|
||||||
|
|
||||||
|
assert "Successfully edited" in result
|
||||||
|
assert target.read_text() == "changed\nchanged\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_edit_file_can_select_nearest_line_hint(tmp_path):
|
||||||
|
target = tmp_path / "duplicate.txt"
|
||||||
|
target.write_text("one\nsame\ntwo\nsame\n")
|
||||||
|
tool = EditFileTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
result = asyncio.run(tool.execute(
|
||||||
|
path=str(target),
|
||||||
|
old_text="same",
|
||||||
|
new_text="changed",
|
||||||
|
line_hint=4,
|
||||||
|
))
|
||||||
|
|
||||||
|
assert "Successfully edited" in result
|
||||||
|
assert target.read_text() == "one\nsame\ntwo\nchanged\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_edit_file_can_edit_ipynb_as_json(tmp_path):
|
||||||
|
target = tmp_path / "analysis.ipynb"
|
||||||
|
target.write_text('{"cells": []}')
|
||||||
|
tool = EditFileTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
result = asyncio.run(tool.execute(
|
||||||
|
path=str(target),
|
||||||
|
old_text='"cells": []',
|
||||||
|
new_text='"cells": [{"cell_type": "markdown", "source": "hi"}]',
|
||||||
|
))
|
||||||
|
|
||||||
|
assert "Successfully edited" in result
|
||||||
|
assert '"source": "hi"' in target.read_text()
|
||||||
|
|
||||||
|
|
||||||
|
def test_edit_file_multiple_match_hint_mentions_occurrence(tmp_path):
|
||||||
|
target = tmp_path / "duplicate.txt"
|
||||||
|
target.write_text("same\nsame\n")
|
||||||
|
tool = EditFileTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
result = asyncio.run(tool.execute(
|
||||||
|
path=str(target),
|
||||||
|
old_text="same",
|
||||||
|
new_text="changed",
|
||||||
|
))
|
||||||
|
|
||||||
|
assert "old_text appears 2 times" in result
|
||||||
|
assert "occurrence" in result
|
||||||
|
assert target.read_text() == "same\nsame\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_edit_file_rejects_ambiguous_line_hint(tmp_path):
|
||||||
|
target = tmp_path / "duplicate.txt"
|
||||||
|
target.write_text("same\nmiddle\nsame\n")
|
||||||
|
tool = EditFileTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
result = asyncio.run(tool.execute(
|
||||||
|
path=str(target),
|
||||||
|
old_text="same",
|
||||||
|
new_text="changed",
|
||||||
|
line_hint=2,
|
||||||
|
))
|
||||||
|
|
||||||
|
assert "line_hint 2 is ambiguous" in result
|
||||||
|
assert target.read_text() == "same\nmiddle\nsame\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_edit_file_rejects_occurrence_with_replace_all(tmp_path):
|
||||||
|
target = tmp_path / "duplicate.txt"
|
||||||
|
target.write_text("same\nsame\n")
|
||||||
|
tool = EditFileTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
result = asyncio.run(tool.execute(
|
||||||
|
path=str(target),
|
||||||
|
old_text="same",
|
||||||
|
new_text="changed",
|
||||||
|
occurrence=1,
|
||||||
|
replace_all=True,
|
||||||
|
))
|
||||||
|
|
||||||
|
assert "occurrence cannot be used with replace_all" in result
|
||||||
|
assert target.read_text() == "same\nsame\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_edit_file_rejects_line_hint_with_replace_all(tmp_path):
|
||||||
|
target = tmp_path / "duplicate.txt"
|
||||||
|
target.write_text("same\nsame\n")
|
||||||
|
tool = EditFileTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
result = asyncio.run(tool.execute(
|
||||||
|
path=str(target),
|
||||||
|
old_text="same",
|
||||||
|
new_text="changed",
|
||||||
|
line_hint=1,
|
||||||
|
replace_all=True,
|
||||||
|
))
|
||||||
|
|
||||||
|
assert "line_hint cannot be used with replace_all" in result
|
||||||
|
assert target.read_text() == "same\nsame\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_edit_file_rejects_line_hint_with_occurrence(tmp_path):
|
||||||
|
target = tmp_path / "duplicate.txt"
|
||||||
|
target.write_text("same\nsame\n")
|
||||||
|
tool = EditFileTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
result = asyncio.run(tool.execute(
|
||||||
|
path=str(target),
|
||||||
|
old_text="same",
|
||||||
|
new_text="changed",
|
||||||
|
occurrence=1,
|
||||||
|
line_hint=1,
|
||||||
|
))
|
||||||
|
|
||||||
|
assert "line_hint cannot be used with occurrence" in result
|
||||||
|
assert target.read_text() == "same\nsame\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_edit_file_rejects_zero_occurrence(tmp_path):
|
||||||
|
target = tmp_path / "duplicate.txt"
|
||||||
|
target.write_text("same\n")
|
||||||
|
tool = EditFileTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
result = asyncio.run(tool.execute(
|
||||||
|
path=str(target),
|
||||||
|
old_text="same",
|
||||||
|
new_text="changed",
|
||||||
|
occurrence=0,
|
||||||
|
))
|
||||||
|
|
||||||
|
assert "occurrence must be >= 1" in result
|
||||||
|
assert target.read_text() == "same\n"
|
||||||
|
|
||||||
|
|
||||||
|
def test_edit_file_rejects_zero_line_hint(tmp_path):
|
||||||
|
target = tmp_path / "duplicate.txt"
|
||||||
|
target.write_text("same\n")
|
||||||
|
tool = EditFileTool(workspace=tmp_path)
|
||||||
|
|
||||||
|
result = asyncio.run(tool.execute(
|
||||||
|
path=str(target),
|
||||||
|
old_text="same",
|
||||||
|
new_text="changed",
|
||||||
|
line_hint=0,
|
||||||
|
))
|
||||||
|
|
||||||
|
assert "line_hint must be >= 1" in result
|
||||||
|
assert target.read_text() == "same\n"
|
||||||
@ -1,147 +0,0 @@
|
|||||||
"""Tests for NotebookEditTool — Jupyter .ipynb editing."""
|
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from nanobot.agent.tools.notebook import NotebookEditTool
|
|
||||||
|
|
||||||
|
|
||||||
def _make_notebook(cells: list[dict] | None = None, nbformat: int = 4, nbformat_minor: int = 5) -> dict:
|
|
||||||
"""Build a minimal valid .ipynb structure."""
|
|
||||||
return {
|
|
||||||
"nbformat": nbformat,
|
|
||||||
"nbformat_minor": nbformat_minor,
|
|
||||||
"metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}},
|
|
||||||
"cells": cells or [],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _code_cell(source: str, cell_id: str | None = None) -> dict:
|
|
||||||
cell = {"cell_type": "code", "source": source, "metadata": {}, "outputs": [], "execution_count": None}
|
|
||||||
if cell_id:
|
|
||||||
cell["id"] = cell_id
|
|
||||||
return cell
|
|
||||||
|
|
||||||
|
|
||||||
def _md_cell(source: str, cell_id: str | None = None) -> dict:
|
|
||||||
cell = {"cell_type": "markdown", "source": source, "metadata": {}}
|
|
||||||
if cell_id:
|
|
||||||
cell["id"] = cell_id
|
|
||||||
return cell
|
|
||||||
|
|
||||||
|
|
||||||
def _write_nb(tmp_path, name: str, nb: dict) -> str:
|
|
||||||
p = tmp_path / name
|
|
||||||
p.write_text(json.dumps(nb), encoding="utf-8")
|
|
||||||
return str(p)
|
|
||||||
|
|
||||||
|
|
||||||
class TestNotebookEdit:
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def tool(self, tmp_path):
|
|
||||||
return NotebookEditTool(workspace=tmp_path)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_replace_cell_content(self, tool, tmp_path):
|
|
||||||
nb = _make_notebook([_code_cell("print('hello')"), _code_cell("x = 1")])
|
|
||||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
|
||||||
result = await tool.execute(path=path, cell_index=0, new_source="print('world')")
|
|
||||||
assert "Successfully" in result
|
|
||||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
|
||||||
assert saved["cells"][0]["source"] == "print('world')"
|
|
||||||
assert saved["cells"][1]["source"] == "x = 1"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_insert_cell_after_target(self, tool, tmp_path):
|
|
||||||
nb = _make_notebook([_code_cell("cell 0"), _code_cell("cell 1")])
|
|
||||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
|
||||||
result = await tool.execute(path=path, cell_index=0, new_source="inserted", edit_mode="insert")
|
|
||||||
assert "Successfully" in result
|
|
||||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
|
||||||
assert len(saved["cells"]) == 3
|
|
||||||
assert saved["cells"][0]["source"] == "cell 0"
|
|
||||||
assert saved["cells"][1]["source"] == "inserted"
|
|
||||||
assert saved["cells"][2]["source"] == "cell 1"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_cell(self, tool, tmp_path):
|
|
||||||
nb = _make_notebook([_code_cell("A"), _code_cell("B"), _code_cell("C")])
|
|
||||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
|
||||||
result = await tool.execute(path=path, cell_index=1, edit_mode="delete")
|
|
||||||
assert "Successfully" in result
|
|
||||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
|
||||||
assert len(saved["cells"]) == 2
|
|
||||||
assert saved["cells"][0]["source"] == "A"
|
|
||||||
assert saved["cells"][1]["source"] == "C"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_new_notebook_from_scratch(self, tool, tmp_path):
|
|
||||||
path = str(tmp_path / "new.ipynb")
|
|
||||||
result = await tool.execute(path=path, cell_index=0, new_source="# Hello", edit_mode="insert", cell_type="markdown")
|
|
||||||
assert "Successfully" in result or "created" in result.lower()
|
|
||||||
saved = json.loads((tmp_path / "new.ipynb").read_text())
|
|
||||||
assert saved["nbformat"] == 4
|
|
||||||
assert len(saved["cells"]) == 1
|
|
||||||
assert saved["cells"][0]["cell_type"] == "markdown"
|
|
||||||
assert saved["cells"][0]["source"] == "# Hello"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_invalid_cell_index_error(self, tool, tmp_path):
|
|
||||||
nb = _make_notebook([_code_cell("only cell")])
|
|
||||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
|
||||||
result = await tool.execute(path=path, cell_index=5, new_source="x")
|
|
||||||
assert "Error" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_non_ipynb_rejected(self, tool, tmp_path):
|
|
||||||
f = tmp_path / "script.py"
|
|
||||||
f.write_text("pass")
|
|
||||||
result = await tool.execute(path=str(f), cell_index=0, new_source="x")
|
|
||||||
assert "Error" in result
|
|
||||||
assert ".ipynb" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_preserves_metadata_and_outputs(self, tool, tmp_path):
|
|
||||||
cell = _code_cell("old")
|
|
||||||
cell["outputs"] = [{"output_type": "stream", "text": "hello\n"}]
|
|
||||||
cell["execution_count"] = 42
|
|
||||||
nb = _make_notebook([cell])
|
|
||||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
|
||||||
await tool.execute(path=path, cell_index=0, new_source="new")
|
|
||||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
|
||||||
assert saved["metadata"]["kernelspec"]["language"] == "python"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_nbformat_45_generates_cell_id(self, tool, tmp_path):
|
|
||||||
nb = _make_notebook([], nbformat_minor=5)
|
|
||||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
|
||||||
await tool.execute(path=path, cell_index=0, new_source="x = 1", edit_mode="insert")
|
|
||||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
|
||||||
assert "id" in saved["cells"][0]
|
|
||||||
assert len(saved["cells"][0]["id"]) > 0
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_insert_with_cell_type_markdown(self, tool, tmp_path):
|
|
||||||
nb = _make_notebook([_code_cell("code")])
|
|
||||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
|
||||||
await tool.execute(path=path, cell_index=0, new_source="# Title", edit_mode="insert", cell_type="markdown")
|
|
||||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
|
||||||
assert saved["cells"][1]["cell_type"] == "markdown"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_invalid_edit_mode_rejected(self, tool, tmp_path):
|
|
||||||
nb = _make_notebook([_code_cell("code")])
|
|
||||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
|
||||||
result = await tool.execute(path=path, cell_index=0, new_source="x", edit_mode="replcae")
|
|
||||||
assert "Error" in result
|
|
||||||
assert "edit_mode" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_invalid_cell_type_rejected(self, tool, tmp_path):
|
|
||||||
nb = _make_notebook([_code_cell("code")])
|
|
||||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
|
||||||
result = await tool.execute(path=path, cell_index=0, new_source="x", cell_type="raw")
|
|
||||||
assert "Error" in result
|
|
||||||
assert "cell_type" in result
|
|
||||||
@ -12,7 +12,7 @@ import pytest
|
|||||||
|
|
||||||
from nanobot.agent.loop import AgentLoop
|
from nanobot.agent.loop import AgentLoop
|
||||||
from nanobot.agent.subagent import SubagentManager, SubagentStatus
|
from nanobot.agent.subagent import SubagentManager, SubagentStatus
|
||||||
from nanobot.agent.tools.search import GrepTool
|
from nanobot.agent.tools.search import FindFilesTool, GrepTool
|
||||||
from nanobot.agent.tools.web import WebSearchTool
|
from nanobot.agent.tools.web import WebSearchTool
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.config.schema import WebSearchConfig
|
from nanobot.config.schema import WebSearchConfig
|
||||||
@ -33,6 +33,68 @@ async def test_web_search_tool_refreshes_dynamic_config_loader(monkeypatch) -> N
|
|||||||
assert await tool.execute("nanobot") == "duckduckgo:nanobot:3"
|
assert await tool.execute("nanobot") == "duckduckgo:nanobot:3"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_find_files_filters_by_query_glob_and_type(tmp_path: Path) -> None:
|
||||||
|
(tmp_path / "src").mkdir()
|
||||||
|
(tmp_path / "src" / "settings_view.tsx").write_text("export {}\n", encoding="utf-8")
|
||||||
|
(tmp_path / "src" / "settings_api.py").write_text("pass\n", encoding="utf-8")
|
||||||
|
(tmp_path / "README.md").write_text("settings\n", encoding="utf-8")
|
||||||
|
|
||||||
|
tool = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path)
|
||||||
|
result = await tool.execute(
|
||||||
|
path=".",
|
||||||
|
query="settings",
|
||||||
|
glob="src/**",
|
||||||
|
type="ts",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.splitlines() == ["src/settings_view.tsx"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_find_files_can_include_directories(tmp_path: Path) -> None:
|
||||||
|
(tmp_path / "src" / "settings").mkdir(parents=True)
|
||||||
|
(tmp_path / "src" / "settings" / "index.ts").write_text("export {}\n", encoding="utf-8")
|
||||||
|
|
||||||
|
tool = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path)
|
||||||
|
result = await tool.execute(path="src", query="settings", include_dirs=True)
|
||||||
|
|
||||||
|
assert "src/settings/" in result.splitlines()
|
||||||
|
assert "src/settings/index.ts" in result.splitlines()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_find_files_supports_modified_sort_and_pagination(tmp_path: Path) -> None:
|
||||||
|
(tmp_path / "src").mkdir()
|
||||||
|
for idx, name in enumerate(("a.py", "b.py", "c.py"), start=1):
|
||||||
|
file_path = tmp_path / "src" / name
|
||||||
|
file_path.write_text("pass\n", encoding="utf-8")
|
||||||
|
os.utime(file_path, (idx, idx))
|
||||||
|
|
||||||
|
tool = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path)
|
||||||
|
result = await tool.execute(
|
||||||
|
path="src",
|
||||||
|
type="py",
|
||||||
|
sort="modified",
|
||||||
|
head_limit=1,
|
||||||
|
offset=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.splitlines()[0] == "src/b.py"
|
||||||
|
assert "pagination: limit=1, offset=1" in result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_find_files_rejects_paths_outside_workspace(tmp_path: Path) -> None:
|
||||||
|
outside = tmp_path.parent / "outside-find-files.txt"
|
||||||
|
outside.write_text("secret\n", encoding="utf-8")
|
||||||
|
|
||||||
|
tool = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path)
|
||||||
|
result = await tool.execute(path=str(outside))
|
||||||
|
|
||||||
|
assert result.startswith("Error:")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_grep_respects_glob_filter_and_context(tmp_path: Path) -> None:
|
async def test_grep_respects_glob_filter_and_context(tmp_path: Path) -> None:
|
||||||
(tmp_path / "src").mkdir()
|
(tmp_path / "src").mkdir()
|
||||||
@ -249,6 +311,7 @@ def test_agent_loop_registers_grep(tmp_path: Path) -> None:
|
|||||||
|
|
||||||
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
|
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
|
||||||
|
|
||||||
|
assert "find_files" in loop.tools.tool_names
|
||||||
assert "grep" in loop.tools.tool_names
|
assert "grep" in loop.tools.tool_names
|
||||||
|
|
||||||
|
|
||||||
@ -280,6 +343,7 @@ async def test_subagent_registers_grep(tmp_path: Path) -> None:
|
|||||||
status = SubagentStatus(task_id="sub-1", label="label", task_description="search task", started_at=time.monotonic())
|
status = SubagentStatus(task_id="sub-1", label="label", task_description="search task", started_at=time.monotonic())
|
||||||
await mgr._run_subagent("sub-1", "search task", "label", {"channel": "cli", "chat_id": "direct"}, status)
|
await mgr._run_subagent("sub-1", "search task", "label", {"channel": "cli", "chat_id": "direct"}, status)
|
||||||
|
|
||||||
|
assert "find_files" in captured["tool_names"]
|
||||||
assert "grep" in captured["tool_names"]
|
assert "grep" in captured["tool_names"]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
46
tests/tools/test_tool_descriptions.py
Normal file
46
tests/tools/test_tool_descriptions.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
from nanobot.agent.tools.apply_patch import ApplyPatchTool
|
||||||
|
from nanobot.agent.tools.exec_session import ListExecSessionsTool, WriteStdinTool
|
||||||
|
from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, WriteFileTool
|
||||||
|
from nanobot.agent.tools.search import FindFilesTool, GrepTool
|
||||||
|
from nanobot.agent.tools.shell import ExecTool
|
||||||
|
|
||||||
|
|
||||||
|
def test_coding_tool_descriptions_steer_editing_priority() -> None:
|
||||||
|
apply_patch = ApplyPatchTool().description.lower()
|
||||||
|
edit_file = EditFileTool().description.lower()
|
||||||
|
write_file = WriteFileTool().description.lower()
|
||||||
|
|
||||||
|
assert "default tool for code edits" in apply_patch
|
||||||
|
assert "multi-file" in apply_patch
|
||||||
|
assert "dry_run=true" in apply_patch
|
||||||
|
assert "edit_file only for small exact replacements" in apply_patch
|
||||||
|
|
||||||
|
assert "small, exact replacement" in edit_file
|
||||||
|
assert "copied from read_file" in edit_file
|
||||||
|
assert "prefer apply_patch" in edit_file
|
||||||
|
|
||||||
|
assert "replace an entire file" in write_file
|
||||||
|
assert "prefer apply_patch" in write_file
|
||||||
|
|
||||||
|
|
||||||
|
def test_coding_tool_descriptions_steer_discovery_and_shell_usage() -> None:
|
||||||
|
read_file = ReadFileTool().description.lower()
|
||||||
|
find_files = FindFilesTool().description.lower()
|
||||||
|
grep = GrepTool().description.lower()
|
||||||
|
exec_tool = ExecTool().description.lower()
|
||||||
|
write_stdin = WriteStdinTool().description.lower()
|
||||||
|
list_sessions = ListExecSessionsTool().description.lower()
|
||||||
|
|
||||||
|
assert "find_files/list_dir first" in read_file
|
||||||
|
assert "before editing" in read_file
|
||||||
|
assert "prefer it over shell find/ls" in find_files
|
||||||
|
assert "prefer this over shell grep" in grep
|
||||||
|
|
||||||
|
assert "tests, builds" in exec_tool
|
||||||
|
assert "prefer read_file/find_files/grep" in exec_tool
|
||||||
|
assert "apply_patch/write_file/edit_file" in exec_tool
|
||||||
|
assert "yield_time_ms" in exec_tool
|
||||||
|
|
||||||
|
assert "do not use this to start new commands" in write_stdin
|
||||||
|
assert "wait_for" in write_stdin
|
||||||
|
assert "recover a session_id" in list_sessions
|
||||||
@ -89,9 +89,11 @@ def test_discover_finds_concrete_tools():
|
|||||||
loader = ToolLoader()
|
loader = ToolLoader()
|
||||||
discovered = loader.discover()
|
discovered = loader.discover()
|
||||||
class_names = {cls.__name__ for cls in discovered}
|
class_names = {cls.__name__ for cls in discovered}
|
||||||
|
assert "ApplyPatchTool" in class_names
|
||||||
assert "ExecTool" in class_names
|
assert "ExecTool" in class_names
|
||||||
assert "MessageTool" in class_names
|
assert "MessageTool" in class_names
|
||||||
assert "SpawnTool" in class_names
|
assert "SpawnTool" in class_names
|
||||||
|
assert "WriteStdinTool" in class_names
|
||||||
|
|
||||||
|
|
||||||
def test_discover_excludes_abstract_and_mcp():
|
def test_discover_excludes_abstract_and_mcp():
|
||||||
@ -406,7 +408,8 @@ def test_loader_registers_same_tools_as_old_hardcoded():
|
|||||||
|
|
||||||
expected = {
|
expected = {
|
||||||
"read_file", "write_file", "edit_file", "list_dir",
|
"read_file", "write_file", "edit_file", "list_dir",
|
||||||
"grep", "notebook_edit", "exec", "web_search", "web_fetch",
|
"find_files", "grep", "exec", "write_stdin", "list_exec_sessions",
|
||||||
|
"web_search", "web_fetch",
|
||||||
"message", "spawn", "cron",
|
"message", "spawn", "cron",
|
||||||
}
|
}
|
||||||
actual = set(registered)
|
actual = set(registered)
|
||||||
|
|||||||
@ -3,6 +3,8 @@ import subprocess
|
|||||||
import sys
|
import sys
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from nanobot.agent.tools import (
|
from nanobot.agent.tools import (
|
||||||
ArraySchema,
|
ArraySchema,
|
||||||
IntegerSchema,
|
IntegerSchema,
|
||||||
@ -15,6 +17,7 @@ from nanobot.agent.tools import (
|
|||||||
from nanobot.agent.tools.base import Tool
|
from nanobot.agent.tools.base import Tool
|
||||||
from nanobot.agent.tools.registry import ToolRegistry
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
from nanobot.agent.tools.shell import ExecTool
|
from nanobot.agent.tools.shell import ExecTool
|
||||||
|
from nanobot.security.network import configure_ssrf_whitelist
|
||||||
|
|
||||||
|
|
||||||
class SampleTool(Tool):
|
class SampleTool(Tool):
|
||||||
@ -218,6 +221,39 @@ def test_exec_extract_absolute_paths_ignores_relative_posix_segments() -> None:
|
|||||||
assert "/bin/python" not in paths
|
assert "/bin/python" not in paths
|
||||||
|
|
||||||
|
|
||||||
|
def test_exec_extract_absolute_paths_ignores_urls() -> None:
|
||||||
|
cmd = 'curl -s -o /dev/null -w "%{http_code}" https://www.google.com'
|
||||||
|
paths = ExecTool._extract_absolute_paths(cmd)
|
||||||
|
assert paths == ["/dev/null"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"command",
|
||||||
|
[
|
||||||
|
'curl -s -o /dev/null -w "%{http_code}" https://www.google.com',
|
||||||
|
'wget -q -O - http://example.com 2>&1 | head -c 100',
|
||||||
|
'python3 -c "import urllib.request; print(urllib.request.urlopen(\'http://example.com\').read()[:100])"',
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_exec_guard_allows_public_urls(tmp_path, command: str) -> None:
|
||||||
|
tool = ExecTool(restrict_to_workspace=True)
|
||||||
|
error = tool._guard_command(command, str(tmp_path))
|
||||||
|
assert error is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_exec_guard_allows_whitelisted_internal_urls(tmp_path) -> None:
|
||||||
|
configure_ssrf_whitelist(["10.10.10.0/24"])
|
||||||
|
try:
|
||||||
|
tool = ExecTool(restrict_to_workspace=True)
|
||||||
|
error = tool._guard_command(
|
||||||
|
'curl -s -H "Authorization: Bearer ..." http://10.10.10.3:8123/api/',
|
||||||
|
str(tmp_path),
|
||||||
|
)
|
||||||
|
assert error is None
|
||||||
|
finally:
|
||||||
|
configure_ssrf_whitelist([])
|
||||||
|
|
||||||
|
|
||||||
def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
|
def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
|
||||||
cmd = "cat /tmp/data.txt > /tmp/out.txt"
|
cmd = "cat /tmp/data.txt > /tmp/out.txt"
|
||||||
paths = ExecTool._extract_absolute_paths(cmd)
|
paths = ExecTool._extract_absolute_paths(cmd)
|
||||||
|
|||||||
@ -5,12 +5,13 @@ from pathlib import Path
|
|||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
from nanobot.utils.file_edit_events import (
|
from nanobot.utils.file_edit_events import (
|
||||||
|
StreamingFileEditTracker,
|
||||||
build_file_edit_end_event,
|
build_file_edit_end_event,
|
||||||
build_file_edit_start_event,
|
build_file_edit_start_event,
|
||||||
line_diff_stats,
|
line_diff_stats,
|
||||||
prepare_file_edit_tracker,
|
prepare_file_edit_tracker,
|
||||||
|
prepare_file_edit_trackers,
|
||||||
read_file_snapshot,
|
read_file_snapshot,
|
||||||
StreamingFileEditTracker,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -81,6 +82,63 @@ def test_binary_file_is_reported_but_not_counted(tmp_path: Path) -> None:
|
|||||||
assert (event["added"], event["deleted"]) == (0, 0)
|
assert (event["added"], event["deleted"]) == (0, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_patch_prepares_trackers_for_each_touched_file(tmp_path: Path) -> None:
|
||||||
|
(tmp_path / "src").mkdir()
|
||||||
|
existing = tmp_path / "src" / "existing.py"
|
||||||
|
existing.write_text("old\nkeep\n", encoding="utf-8")
|
||||||
|
delete_me = tmp_path / "src" / "delete_me.py"
|
||||||
|
delete_me.write_text("gone\n", encoding="utf-8")
|
||||||
|
|
||||||
|
edits = [
|
||||||
|
{"path": "src/new.py", "action": "add", "new_text": "fresh"},
|
||||||
|
{"path": "src/existing.py", "action": "replace", "old_text": "old", "new_text": "new"},
|
||||||
|
{"path": "src/delete_me.py", "action": "delete", "old_text": "gone\n"},
|
||||||
|
]
|
||||||
|
|
||||||
|
trackers = prepare_file_edit_trackers(
|
||||||
|
call_id="call-patch",
|
||||||
|
tool_name="apply_patch",
|
||||||
|
tool=None,
|
||||||
|
workspace=tmp_path,
|
||||||
|
params={"edits": edits},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert [tracker.display_path for tracker in trackers] == [
|
||||||
|
"src/new.py",
|
||||||
|
"src/existing.py",
|
||||||
|
"src/delete_me.py",
|
||||||
|
]
|
||||||
|
|
||||||
|
(tmp_path / "src" / "new.py").write_text("fresh\n", encoding="utf-8")
|
||||||
|
existing.write_text("new\nkeep\n", encoding="utf-8")
|
||||||
|
delete_me.unlink()
|
||||||
|
|
||||||
|
events = [build_file_edit_end_event(tracker, {"edits": edits}) for tracker in trackers]
|
||||||
|
by_path = {event["path"]: event for event in events}
|
||||||
|
assert (by_path["src/new.py"]["added"], by_path["src/new.py"]["deleted"]) == (1, 0)
|
||||||
|
assert (by_path["src/existing.py"]["added"], by_path["src/existing.py"]["deleted"]) == (1, 1)
|
||||||
|
assert (by_path["src/delete_me.py"]["added"], by_path["src/delete_me.py"]["deleted"]) == (0, 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_patch_dry_run_does_not_prepare_file_edit_trackers(tmp_path: Path) -> None:
|
||||||
|
(tmp_path / "file.txt").write_text("old\n", encoding="utf-8")
|
||||||
|
|
||||||
|
trackers = prepare_file_edit_trackers(
|
||||||
|
call_id="call-patch",
|
||||||
|
tool_name="apply_patch",
|
||||||
|
tool=None,
|
||||||
|
workspace=tmp_path,
|
||||||
|
params={
|
||||||
|
"dry_run": True,
|
||||||
|
"edits": [
|
||||||
|
{"path": "file.txt", "action": "replace", "old_text": "old", "new_text": "new"}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert trackers == []
|
||||||
|
|
||||||
|
|
||||||
def test_oversized_write_file_end_uses_known_content_for_exact_count(tmp_path: Path) -> None:
|
def test_oversized_write_file_end_uses_known_content_for_exact_count(tmp_path: Path) -> None:
|
||||||
target = tmp_path / "large.txt"
|
target = tmp_path / "large.txt"
|
||||||
params = {"path": "large.txt", "content": "x" * (2 * 1024 * 1024 + 1)}
|
params = {"path": "large.txt", "content": "x" * (2 * 1024 * 1024 + 1)}
|
||||||
@ -140,6 +198,58 @@ def test_streaming_write_file_tracker_emits_live_line_counts(tmp_path: Path) ->
|
|||||||
assert events[-1]["deleted"] == 0
|
assert events[-1]["deleted"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_apply_patch_tracker_emits_live_counts_per_file(tmp_path: Path) -> None:
|
||||||
|
(tmp_path / "src").mkdir()
|
||||||
|
(tmp_path / "src" / "existing.py").write_text("old\nkeep\n", encoding="utf-8")
|
||||||
|
events: list[dict] = []
|
||||||
|
|
||||||
|
async def emit(batch: list[dict]) -> None:
|
||||||
|
events.extend(batch)
|
||||||
|
|
||||||
|
async def run() -> None:
|
||||||
|
tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit)
|
||||||
|
await tracker.update({
|
||||||
|
"index": 0,
|
||||||
|
"call_id": "call-patch",
|
||||||
|
"name": "apply_patch",
|
||||||
|
"arguments_delta": (
|
||||||
|
'{"edits":[{"path":"src/existing.py","action":"replace","old_text":"old","new_text":"new"}'
|
||||||
|
',{"path":"src/new.py","action":"add","new_text":"fresh"}]}'
|
||||||
|
),
|
||||||
|
})
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
by_path = {event["path"]: event for event in events}
|
||||||
|
assert by_path["src/existing.py"]["tool"] == "apply_patch"
|
||||||
|
assert by_path["src/existing.py"]["status"] == "editing"
|
||||||
|
assert by_path["src/existing.py"]["approximate"] is True
|
||||||
|
assert (by_path["src/existing.py"]["added"], by_path["src/existing.py"]["deleted"]) == (1, 1)
|
||||||
|
assert (by_path["src/new.py"]["added"], by_path["src/new.py"]["deleted"]) == (1, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_apply_patch_tracker_skips_dry_run(tmp_path: Path) -> None:
|
||||||
|
events: list[dict] = []
|
||||||
|
|
||||||
|
async def emit(batch: list[dict]) -> None:
|
||||||
|
events.extend(batch)
|
||||||
|
|
||||||
|
async def run() -> None:
|
||||||
|
tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit)
|
||||||
|
await tracker.update({
|
||||||
|
"index": 0,
|
||||||
|
"call_id": "call-patch",
|
||||||
|
"name": "apply_patch",
|
||||||
|
"arguments_delta": (
|
||||||
|
'{"dry_run":true,"edits":[{"path":"dry.md","action":"add","new_text":"preview"}]}'
|
||||||
|
),
|
||||||
|
})
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
assert events == []
|
||||||
|
|
||||||
|
|
||||||
def test_streaming_write_file_tracker_emits_pending_before_path(tmp_path: Path) -> None:
|
def test_streaming_write_file_tracker_emits_pending_before_path(tmp_path: Path) -> None:
|
||||||
events: list[dict] = []
|
events: list[dict] = []
|
||||||
|
|
||||||
@ -308,6 +418,43 @@ def test_streaming_tracker_applies_canonical_call_id_to_final_tool(tmp_path: Pat
|
|||||||
asyncio.run(run())
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_tracker_does_not_restore_duplicate_canonical_ids(tmp_path: Path) -> None:
|
||||||
|
events: list[dict] = []
|
||||||
|
|
||||||
|
async def emit(batch: list[dict]) -> None:
|
||||||
|
events.extend(batch)
|
||||||
|
|
||||||
|
async def run() -> None:
|
||||||
|
tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit)
|
||||||
|
await tracker.update({
|
||||||
|
"index": 0,
|
||||||
|
"call_id": "call_dup",
|
||||||
|
"name": "write_file",
|
||||||
|
"arguments_delta": '{"path":"a.md","content":"one\\n"}',
|
||||||
|
})
|
||||||
|
await tracker.update({
|
||||||
|
"index": 1,
|
||||||
|
"call_id": "call_dup",
|
||||||
|
"name": "write_file",
|
||||||
|
"arguments_delta": '{"path":"b.md","content":"two\\n"}',
|
||||||
|
})
|
||||||
|
final_a = SimpleNamespace(
|
||||||
|
id="call_dup",
|
||||||
|
name="write_file",
|
||||||
|
arguments={"path": "a.md", "content": "one\n"},
|
||||||
|
)
|
||||||
|
final_b = SimpleNamespace(
|
||||||
|
id="call_unique",
|
||||||
|
name="write_file",
|
||||||
|
arguments={"path": "b.md", "content": "two\n"},
|
||||||
|
)
|
||||||
|
tracker.apply_final_call_ids([final_a, final_b])
|
||||||
|
assert final_a.id == "call_dup"
|
||||||
|
assert final_b.id == "call_unique"
|
||||||
|
|
||||||
|
asyncio.run(run())
|
||||||
|
|
||||||
|
|
||||||
def test_streaming_edit_file_tracker_flushes_small_pending_count(tmp_path: Path) -> None:
|
def test_streaming_edit_file_tracker_flushes_small_pending_count(tmp_path: Path) -> None:
|
||||||
target = tmp_path / "small.py"
|
target = tmp_path / "small.py"
|
||||||
target.write_text("old\n", encoding="utf-8")
|
target.write_text("old\n", encoding="utf-8")
|
||||||
|
|||||||
@ -43,6 +43,7 @@ const SIDEBAR_STORAGE_KEY = "nanobot-webui.sidebar";
|
|||||||
const COMPLETED_RUNS_STORAGE_KEY = "nanobot-webui.sidebar.completed-runs.v1";
|
const COMPLETED_RUNS_STORAGE_KEY = "nanobot-webui.sidebar.completed-runs.v1";
|
||||||
const RESTART_STARTED_KEY = "nanobot-webui.restartStartedAt";
|
const RESTART_STARTED_KEY = "nanobot-webui.restartStartedAt";
|
||||||
const SIDEBAR_WIDTH = 272;
|
const SIDEBAR_WIDTH = 272;
|
||||||
|
const SIDEBAR_RAIL_WIDTH = 56;
|
||||||
const TOKEN_REFRESH_MARGIN_MS = 30_000;
|
const TOKEN_REFRESH_MARGIN_MS = 30_000;
|
||||||
const TOKEN_REFRESH_MIN_DELAY_MS = 5_000;
|
const TOKEN_REFRESH_MIN_DELAY_MS = 5_000;
|
||||||
type ShellView = "chat" | "settings";
|
type ShellView = "chat" | "settings";
|
||||||
@ -411,6 +412,10 @@ function Shell({
|
|||||||
setDesktopSidebarOpen(false);
|
setDesktopSidebarOpen(false);
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
const openDesktopSidebar = useCallback(() => {
|
||||||
|
setDesktopSidebarOpen(true);
|
||||||
|
}, []);
|
||||||
|
|
||||||
const closeMobileSidebar = useCallback(() => {
|
const closeMobileSidebar = useCallback(() => {
|
||||||
setMobileSidebarOpen(false);
|
setMobileSidebarOpen(false);
|
||||||
}, []);
|
}, []);
|
||||||
@ -560,6 +565,21 @@ function Shell({
|
|||||||
setSessionSearchOpen(true);
|
setSessionSearchOpen(true);
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const handleKeyDown = (event: globalThis.KeyboardEvent) => {
|
||||||
|
if (event.defaultPrevented) return;
|
||||||
|
const plainCommandK =
|
||||||
|
(event.metaKey || event.ctrlKey) && !event.altKey && !event.shiftKey;
|
||||||
|
if (!plainCommandK) return;
|
||||||
|
if (event.key.toLowerCase() !== "k") return;
|
||||||
|
event.preventDefault();
|
||||||
|
onOpenSessionSearch();
|
||||||
|
};
|
||||||
|
|
||||||
|
window.addEventListener("keydown", handleKeyDown);
|
||||||
|
return () => window.removeEventListener("keydown", handleKeyDown);
|
||||||
|
}, [onOpenSessionSearch]);
|
||||||
|
|
||||||
const onSelectSearchResult = useCallback(
|
const onSelectSearchResult = useCallback(
|
||||||
(key: string) => {
|
(key: string) => {
|
||||||
setSessionSearchOpen(false);
|
setSessionSearchOpen(false);
|
||||||
@ -732,17 +752,19 @@ function Shell({
|
|||||||
"relative z-20 hidden shrink-0 overflow-hidden lg:block",
|
"relative z-20 hidden shrink-0 overflow-hidden lg:block",
|
||||||
"transition-[width] duration-300 ease-out",
|
"transition-[width] duration-300 ease-out",
|
||||||
)}
|
)}
|
||||||
style={{ width: desktopSidebarOpen ? SIDEBAR_WIDTH : 0 }}
|
style={{
|
||||||
|
width: desktopSidebarOpen ? SIDEBAR_WIDTH : SIDEBAR_RAIL_WIDTH,
|
||||||
|
}}
|
||||||
>
|
>
|
||||||
<div
|
<div
|
||||||
className={cn(
|
className="absolute inset-y-0 left-0 h-full w-full overflow-hidden bg-sidebar shadow-inner-right"
|
||||||
"absolute inset-y-0 left-0 h-full overflow-hidden bg-sidebar shadow-inner-right",
|
|
||||||
"transition-transform duration-300 ease-out",
|
|
||||||
desktopSidebarOpen ? "translate-x-0" : "-translate-x-full",
|
|
||||||
)}
|
|
||||||
style={{ width: SIDEBAR_WIDTH }}
|
|
||||||
>
|
>
|
||||||
<Sidebar {...sidebarProps} onCollapse={closeDesktopSidebar} />
|
<Sidebar
|
||||||
|
{...sidebarProps}
|
||||||
|
collapsed={!desktopSidebarOpen}
|
||||||
|
onCollapse={closeDesktopSidebar}
|
||||||
|
onExpand={openDesktopSidebar}
|
||||||
|
/>
|
||||||
</div>
|
</div>
|
||||||
</aside>
|
</aside>
|
||||||
) : null}
|
) : null}
|
||||||
@ -769,17 +791,15 @@ function Shell({
|
|||||||
</Sheet>
|
</Sheet>
|
||||||
) : null}
|
) : null}
|
||||||
|
|
||||||
{showMainSidebar ? (
|
<SessionSearchDialog
|
||||||
<SessionSearchDialog
|
open={sessionSearchOpen}
|
||||||
open={sessionSearchOpen}
|
onOpenChange={setSessionSearchOpen}
|
||||||
onOpenChange={setSessionSearchOpen}
|
sessions={sessions}
|
||||||
sessions={sessions}
|
activeKey={activeKey}
|
||||||
activeKey={activeKey}
|
loading={loading}
|
||||||
loading={loading}
|
titleOverrides={sidebarState.title_overrides}
|
||||||
titleOverrides={sidebarState.title_overrides}
|
onSelect={onSelectSearchResult}
|
||||||
onSelect={onSelectSearchResult}
|
/>
|
||||||
/>
|
|
||||||
) : null}
|
|
||||||
|
|
||||||
<main className="relative flex h-full min-w-0 flex-1 flex-col">
|
<main className="relative flex h-full min-w-0 flex-1 flex-col">
|
||||||
<div
|
<div
|
||||||
@ -797,7 +817,7 @@ function Shell({
|
|||||||
onTurnEnd={onTurnEnd}
|
onTurnEnd={onTurnEnd}
|
||||||
theme={theme}
|
theme={theme}
|
||||||
onToggleTheme={toggle}
|
onToggleTheme={toggle}
|
||||||
hideSidebarToggleOnDesktop={desktopSidebarOpen}
|
hideSidebarToggleOnDesktop
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
{view === "settings" && (
|
{view === "settings" && (
|
||||||
|
|||||||
@ -1,3 +1,9 @@
|
|||||||
|
import {
|
||||||
|
memo,
|
||||||
|
useEffect,
|
||||||
|
useMemo,
|
||||||
|
useState,
|
||||||
|
} from "react";
|
||||||
import {
|
import {
|
||||||
Archive,
|
Archive,
|
||||||
ArchiveRestore,
|
ArchiveRestore,
|
||||||
@ -19,6 +25,9 @@ import { deriveTitle, relativeTime } from "@/lib/format";
|
|||||||
import { cn } from "@/lib/utils";
|
import { cn } from "@/lib/utils";
|
||||||
import type { ChatSummary, SidebarDensity, SidebarSortMode } from "@/lib/types";
|
import type { ChatSummary, SidebarDensity, SidebarSortMode } from "@/lib/types";
|
||||||
|
|
||||||
|
const INITIAL_VISIBLE_SESSIONS = 160;
|
||||||
|
const VISIBLE_SESSIONS_INCREMENT = 160;
|
||||||
|
|
||||||
interface ChatListProps {
|
interface ChatListProps {
|
||||||
sessions: ChatSummary[];
|
sessions: ChatSummary[];
|
||||||
activeKey: string | null;
|
activeKey: string | null;
|
||||||
@ -42,7 +51,7 @@ interface ChatListProps {
|
|||||||
emptyLabel?: string;
|
emptyLabel?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function ChatList({
|
export const ChatList = memo(function ChatList({
|
||||||
sessions,
|
sessions,
|
||||||
activeKey,
|
activeKey,
|
||||||
onSelect,
|
onSelect,
|
||||||
@ -65,6 +74,52 @@ export function ChatList({
|
|||||||
emptyLabel,
|
emptyLabel,
|
||||||
}: ChatListProps) {
|
}: ChatListProps) {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
const [visibleLimit, setVisibleLimit] = useState(INITIAL_VISIBLE_SESSIONS);
|
||||||
|
const labels = useMemo(() => ({
|
||||||
|
pinned: t("chat.groups.pinned"),
|
||||||
|
all: t("chat.groups.all"),
|
||||||
|
today: t("chat.groups.today"),
|
||||||
|
yesterday: t("chat.groups.yesterday"),
|
||||||
|
earlier: t("chat.groups.earlier"),
|
||||||
|
archived: t("chat.groups.archived"),
|
||||||
|
fallbackTitle: t("chat.newChat"),
|
||||||
|
}), [t]);
|
||||||
|
const groups = useMemo(
|
||||||
|
() => groupSessions(sessions, labels, {
|
||||||
|
pinnedKeys,
|
||||||
|
archivedKeys,
|
||||||
|
titleOverrides,
|
||||||
|
showArchived,
|
||||||
|
sort,
|
||||||
|
}),
|
||||||
|
[
|
||||||
|
archivedKeys,
|
||||||
|
labels,
|
||||||
|
pinnedKeys,
|
||||||
|
sessions,
|
||||||
|
showArchived,
|
||||||
|
sort,
|
||||||
|
titleOverrides,
|
||||||
|
],
|
||||||
|
);
|
||||||
|
const limitedGroups = useMemo(
|
||||||
|
() => limitGroups(groups, visibleLimit, activeKey),
|
||||||
|
[activeKey, groups, visibleLimit],
|
||||||
|
);
|
||||||
|
const totalSessionCount = useMemo(
|
||||||
|
() => groups.reduce((total, group) => total + group.sessions.length, 0),
|
||||||
|
[groups],
|
||||||
|
);
|
||||||
|
const visibleSessionCount = useMemo(
|
||||||
|
() => limitedGroups.reduce((total, group) => total + group.sessions.length, 0),
|
||||||
|
[limitedGroups],
|
||||||
|
);
|
||||||
|
const hiddenSessionCount = Math.max(0, totalSessionCount - visibleSessionCount);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
setVisibleLimit(INITIAL_VISIBLE_SESSIONS);
|
||||||
|
}, [showArchived, sort]);
|
||||||
|
|
||||||
if (loading && sessions.length === 0) {
|
if (loading && sessions.length === 0) {
|
||||||
return (
|
return (
|
||||||
<div className="px-3 py-6 text-[12px] text-muted-foreground">
|
<div className="px-3 py-6 text-[12px] text-muted-foreground">
|
||||||
@ -81,21 +136,6 @@ export function ChatList({
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const groups = groupSessions(sessions, {
|
|
||||||
pinned: t("chat.groups.pinned"),
|
|
||||||
all: t("chat.groups.all"),
|
|
||||||
today: t("chat.groups.today"),
|
|
||||||
yesterday: t("chat.groups.yesterday"),
|
|
||||||
earlier: t("chat.groups.earlier"),
|
|
||||||
archived: t("chat.groups.archived"),
|
|
||||||
fallbackTitle: t("chat.newChat"),
|
|
||||||
}, {
|
|
||||||
pinnedKeys,
|
|
||||||
archivedKeys,
|
|
||||||
titleOverrides,
|
|
||||||
showArchived,
|
|
||||||
sort,
|
|
||||||
});
|
|
||||||
const pinned = new Set(pinnedKeys);
|
const pinned = new Set(pinnedKeys);
|
||||||
const archived = new Set(archivedKeys);
|
const archived = new Set(archivedKeys);
|
||||||
const running = new Set(runningChatIds);
|
const running = new Set(runningChatIds);
|
||||||
@ -105,7 +145,7 @@ export function ChatList({
|
|||||||
return (
|
return (
|
||||||
<div className="h-full min-h-0 min-w-0 overflow-x-hidden overflow-y-auto overscroll-contain">
|
<div className="h-full min-h-0 min-w-0 overflow-x-hidden overflow-y-auto overscroll-contain">
|
||||||
<div className="min-w-0 space-y-3 px-2 py-1.5">
|
<div className="min-w-0 space-y-3 px-2 py-1.5">
|
||||||
{groups.map((group) => (
|
{limitedGroups.map((group) => (
|
||||||
<section key={group.label} aria-label={group.label}>
|
<section key={group.label} aria-label={group.label}>
|
||||||
<div className="px-2 pb-1 text-[12px] font-medium text-muted-foreground/65">
|
<div className="px-2 pb-1 text-[12px] font-medium text-muted-foreground/65">
|
||||||
{group.label}
|
{group.label}
|
||||||
@ -228,10 +268,25 @@ export function ChatList({
|
|||||||
</ul>
|
</ul>
|
||||||
</section>
|
</section>
|
||||||
))}
|
))}
|
||||||
|
{hiddenSessionCount > 0 ? (
|
||||||
|
<div className="px-2 pb-2 pt-1">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
onClick={() =>
|
||||||
|
setVisibleLimit((limit) =>
|
||||||
|
Math.min(totalSessionCount, limit + VISIBLE_SESSIONS_INCREMENT),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
className="h-8 w-full rounded-full text-[12px] font-medium text-muted-foreground transition-colors hover:bg-sidebar-accent/65 hover:text-sidebar-foreground"
|
||||||
|
>
|
||||||
|
{t("chat.showMore", { count: hiddenSessionCount })}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
) : null}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
});
|
||||||
|
|
||||||
function SessionActivityIndicator({
|
function SessionActivityIndicator({
|
||||||
state,
|
state,
|
||||||
@ -366,6 +421,45 @@ function groupSessions(
|
|||||||
return groups;
|
return groups;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function limitGroups(
|
||||||
|
groups: Array<{ label: string; sessions: ChatSummary[] }>,
|
||||||
|
limit: number,
|
||||||
|
activeKey: string | null,
|
||||||
|
): Array<{ label: string; sessions: ChatSummary[] }> {
|
||||||
|
let remaining = Math.max(0, limit);
|
||||||
|
let activeVisible = !activeKey;
|
||||||
|
const out: Array<{ label: string; sessions: ChatSummary[] }> = [];
|
||||||
|
|
||||||
|
for (const group of groups) {
|
||||||
|
const visible = remaining > 0
|
||||||
|
? group.sessions.slice(0, remaining)
|
||||||
|
: [];
|
||||||
|
remaining -= visible.length;
|
||||||
|
if (activeKey && visible.some((session) => session.key === activeKey)) {
|
||||||
|
activeVisible = true;
|
||||||
|
}
|
||||||
|
if (visible.length > 0) {
|
||||||
|
out.push({ label: group.label, sessions: visible });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (activeVisible || !activeKey) return out;
|
||||||
|
|
||||||
|
for (const group of groups) {
|
||||||
|
const active = group.sessions.find((session) => session.key === activeKey);
|
||||||
|
if (!active) continue;
|
||||||
|
const existing = out.find((item) => item.label === group.label);
|
||||||
|
if (existing) {
|
||||||
|
existing.sessions = [...existing.sessions, active];
|
||||||
|
} else {
|
||||||
|
out.push({ label: group.label, sessions: [active] });
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
function sortSessions(
|
function sortSessions(
|
||||||
sessions: ChatSummary[],
|
sessions: ChatSummary[],
|
||||||
sort: SidebarSortMode,
|
sort: SidebarSortMode,
|
||||||
|
|||||||
@ -37,13 +37,16 @@ export function SessionSearchDialog({
|
|||||||
const [highlightedIndex, setHighlightedIndex] = useState(0);
|
const [highlightedIndex, setHighlightedIndex] = useState(0);
|
||||||
|
|
||||||
const normalizedQuery = query.trim().toLowerCase();
|
const normalizedQuery = query.trim().toLowerCase();
|
||||||
const results = useMemo(() => {
|
const sessionResults = useMemo(() => {
|
||||||
|
if (!open) return [];
|
||||||
if (!normalizedQuery) return sessions;
|
if (!normalizedQuery) return sessions;
|
||||||
const terms = normalizedQuery.split(/\s+/).filter(Boolean);
|
const terms = normalizedQuery.split(/\s+/).filter(Boolean);
|
||||||
return sessions.filter((session) =>
|
return sessions.filter((session) =>
|
||||||
sessionMatchesTerms(session, terms, titleOverrides[session.key]),
|
sessionMatchesTerms(session, terms, titleOverrides[session.key]),
|
||||||
);
|
);
|
||||||
}, [normalizedQuery, sessions, titleOverrides]);
|
}, [normalizedQuery, open, sessions, titleOverrides]);
|
||||||
|
const itemCount = sessionResults.length;
|
||||||
|
const shortcutLabel = useMemo(getSearchShortcutLabel, []);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!open) return;
|
if (!open) return;
|
||||||
@ -58,9 +61,9 @@ export function SessionSearchDialog({
|
|||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
setHighlightedIndex((index) =>
|
setHighlightedIndex((index) =>
|
||||||
results.length === 0 ? 0 : Math.min(index, results.length - 1),
|
itemCount === 0 ? 0 : Math.min(index, itemCount - 1),
|
||||||
);
|
);
|
||||||
}, [results.length]);
|
}, [itemCount]);
|
||||||
|
|
||||||
const handleSelect = (key: string) => {
|
const handleSelect = (key: string) => {
|
||||||
onOpenChange(false);
|
onOpenChange(false);
|
||||||
@ -71,17 +74,19 @@ export function SessionSearchDialog({
|
|||||||
if (event.key === "ArrowDown") {
|
if (event.key === "ArrowDown") {
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
setHighlightedIndex((index) =>
|
setHighlightedIndex((index) =>
|
||||||
results.length === 0 ? 0 : Math.min(index + 1, results.length - 1),
|
itemCount === 0 ? 0 : (index + 1) % itemCount,
|
||||||
);
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (event.key === "ArrowUp") {
|
if (event.key === "ArrowUp") {
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
setHighlightedIndex((index) => Math.max(index - 1, 0));
|
setHighlightedIndex((index) =>
|
||||||
|
itemCount === 0 ? 0 : (index - 1 + itemCount) % itemCount,
|
||||||
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (event.key === "Enter") {
|
if (event.key === "Enter") {
|
||||||
const highlighted = results[highlightedIndex];
|
const highlighted = sessionResults[highlightedIndex];
|
||||||
if (!highlighted) return;
|
if (!highlighted) return;
|
||||||
event.preventDefault();
|
event.preventDefault();
|
||||||
handleSelect(highlighted.key);
|
handleSelect(highlighted.key);
|
||||||
@ -125,70 +130,75 @@ export function SessionSearchDialog({
|
|||||||
aria-label={t("sidebar.searchAria")}
|
aria-label={t("sidebar.searchAria")}
|
||||||
className="h-full min-w-0 flex-1 bg-transparent text-[15px] font-medium text-foreground outline-none placeholder:text-muted-foreground/75"
|
className="h-full min-w-0 flex-1 bg-transparent text-[15px] font-medium text-foreground outline-none placeholder:text-muted-foreground/75"
|
||||||
/>
|
/>
|
||||||
|
<kbd className="hidden h-6 shrink-0 items-center rounded-md border border-border/70 bg-muted/60 px-2 text-[11px] font-medium text-muted-foreground sm:inline-flex">
|
||||||
|
{shortcutLabel}
|
||||||
|
</kbd>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="min-h-0 overflow-y-auto overscroll-contain p-2">
|
<div className="min-h-0 overflow-y-auto overscroll-contain p-2">
|
||||||
<div className="px-2 pb-1.5 pt-1 text-[12px] font-medium text-muted-foreground/70">
|
<section>
|
||||||
{sectionLabel}
|
<div className="px-2 pb-1.5 pt-1 text-[12px] font-medium text-muted-foreground/70">
|
||||||
</div>
|
{sectionLabel}
|
||||||
|
</div>
|
||||||
|
|
||||||
{loading && sessions.length === 0 ? (
|
{loading && sessions.length === 0 ? (
|
||||||
<div className="px-3 py-7 text-[13px] text-muted-foreground">
|
<div className="px-3 py-7 text-[13px] text-muted-foreground">
|
||||||
{t("chat.loading")}
|
{t("chat.loading")}
|
||||||
</div>
|
</div>
|
||||||
) : results.length === 0 ? (
|
) : sessionResults.length === 0 ? (
|
||||||
<div className="px-3 py-7 text-[13px] text-muted-foreground">
|
<div className="px-3 py-7 text-[13px] text-muted-foreground">
|
||||||
{emptyLabel}
|
{emptyLabel}
|
||||||
</div>
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<ul className="space-y-1">
|
<ul className="space-y-1">
|
||||||
{results.map((session, index) => {
|
{sessionResults.map((session, index) => {
|
||||||
const title = titleOverrides[session.key]?.trim() ||
|
const title = titleOverrides[session.key]?.trim() ||
|
||||||
session.title?.trim() ||
|
session.title?.trim() ||
|
||||||
deriveTitle(session.preview, t("chat.newChat"));
|
deriveTitle(session.preview, t("chat.newChat"));
|
||||||
const preview = session.preview.trim();
|
const preview = session.preview.trim();
|
||||||
const showPreview =
|
const showPreview =
|
||||||
preview.length > 0 &&
|
preview.length > 0 &&
|
||||||
preview.toLowerCase() !== title.trim().toLowerCase();
|
preview.toLowerCase() !== title.trim().toLowerCase();
|
||||||
const highlighted = index === highlightedIndex;
|
const highlighted = index === highlightedIndex;
|
||||||
const active = session.key === activeKey;
|
const active = session.key === activeKey;
|
||||||
return (
|
return (
|
||||||
<li key={session.key}>
|
<li key={session.key}>
|
||||||
<button
|
<button
|
||||||
type="button"
|
type="button"
|
||||||
onClick={() => handleSelect(session.key)}
|
onClick={() => handleSelect(session.key)}
|
||||||
onMouseEnter={() => setHighlightedIndex(index)}
|
onMouseEnter={() => setHighlightedIndex(index)}
|
||||||
aria-current={active ? "page" : undefined}
|
aria-current={active ? "page" : undefined}
|
||||||
className={cn(
|
className={cn(
|
||||||
"flex min-h-12 w-full min-w-0 rounded-xl px-3 py-2.5 text-left transition-colors",
|
"flex min-h-12 w-full min-w-0 rounded-xl px-3 py-2.5 text-left transition-colors",
|
||||||
highlighted
|
highlighted
|
||||||
? "bg-accent text-accent-foreground"
|
? "bg-accent text-accent-foreground"
|
||||||
: "text-popover-foreground hover:bg-accent/75 hover:text-accent-foreground",
|
: "text-popover-foreground hover:bg-accent/75 hover:text-accent-foreground",
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
<span className="min-w-0 flex-1">
|
<span className="min-w-0 flex-1">
|
||||||
<span className="block truncate text-[14px] font-medium leading-5">
|
<span className="block truncate text-[14px] font-medium leading-5">
|
||||||
{title}
|
{title}
|
||||||
</span>
|
|
||||||
{showPreview ? (
|
|
||||||
<span
|
|
||||||
className={cn(
|
|
||||||
"block truncate text-[12px] leading-4",
|
|
||||||
highlighted
|
|
||||||
? "text-accent-foreground/70"
|
|
||||||
: "text-muted-foreground",
|
|
||||||
)}
|
|
||||||
>
|
|
||||||
{preview}
|
|
||||||
</span>
|
</span>
|
||||||
) : null}
|
{showPreview ? (
|
||||||
</span>
|
<span
|
||||||
</button>
|
className={cn(
|
||||||
</li>
|
"block truncate text-[12px] leading-4",
|
||||||
);
|
highlighted
|
||||||
})}
|
? "text-accent-foreground/70"
|
||||||
</ul>
|
: "text-muted-foreground",
|
||||||
)}
|
)}
|
||||||
|
>
|
||||||
|
{preview}
|
||||||
|
</span>
|
||||||
|
) : null}
|
||||||
|
</span>
|
||||||
|
</button>
|
||||||
|
</li>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</ul>
|
||||||
|
)}
|
||||||
|
</section>
|
||||||
</div>
|
</div>
|
||||||
</DialogContent>
|
</DialogContent>
|
||||||
</Dialog>
|
</Dialog>
|
||||||
@ -211,3 +221,13 @@ function sessionMatchesTerms(
|
|||||||
|
|
||||||
return terms.every((term) => haystack.includes(term));
|
return terms.every((term) => haystack.includes(term));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function getSearchShortcutLabel() {
|
||||||
|
if (typeof navigator === "undefined") return "Ctrl K";
|
||||||
|
const platform = navigator.platform.toLowerCase();
|
||||||
|
const apple =
|
||||||
|
platform.includes("mac") ||
|
||||||
|
platform.includes("iphone") ||
|
||||||
|
platform.includes("ipad");
|
||||||
|
return apple ? "⌘K" : "Ctrl K";
|
||||||
|
}
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
import { useState } from "react";
|
import { useState, type ReactNode } from "react";
|
||||||
import {
|
import {
|
||||||
Archive,
|
Archive,
|
||||||
ListFilter,
|
ListFilter,
|
||||||
@ -28,6 +28,7 @@ import type {
|
|||||||
SidebarSortMode,
|
SidebarSortMode,
|
||||||
SidebarViewState,
|
SidebarViewState,
|
||||||
} from "@/lib/types";
|
} from "@/lib/types";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
interface SidebarProps {
|
interface SidebarProps {
|
||||||
sessions: ChatSummary[];
|
sessions: ChatSummary[];
|
||||||
@ -44,7 +45,9 @@ interface SidebarProps {
|
|||||||
onToggleArchived: () => void;
|
onToggleArchived: () => void;
|
||||||
onUpdateView: (view: Partial<SidebarViewState>) => void;
|
onUpdateView: (view: Partial<SidebarViewState>) => void;
|
||||||
onCollapse: () => void;
|
onCollapse: () => void;
|
||||||
|
onExpand?: () => void;
|
||||||
containActionMenus?: boolean;
|
containActionMenus?: boolean;
|
||||||
|
collapsed?: boolean;
|
||||||
pinnedKeys?: string[];
|
pinnedKeys?: string[];
|
||||||
archivedKeys?: string[];
|
archivedKeys?: string[];
|
||||||
titleOverrides?: Record<string, string>;
|
titleOverrides?: Record<string, string>;
|
||||||
@ -59,6 +62,8 @@ export function Sidebar(props: SidebarProps) {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const [menuPortalContainer, setMenuPortalContainer] =
|
const [menuPortalContainer, setMenuPortalContainer] =
|
||||||
useState<HTMLElement | null>(null);
|
useState<HTMLElement | null>(null);
|
||||||
|
const collapsed = Boolean(props.collapsed);
|
||||||
|
const toggleLabel = t("thread.header.toggleSidebar");
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<nav
|
<nav
|
||||||
@ -66,108 +71,189 @@ export function Sidebar(props: SidebarProps) {
|
|||||||
aria-label={t("sidebar.navigation")}
|
aria-label={t("sidebar.navigation")}
|
||||||
className="flex h-full w-full min-w-0 flex-col border-r border-sidebar-border/60 bg-sidebar text-sidebar-foreground"
|
className="flex h-full w-full min-w-0 flex-col border-r border-sidebar-border/60 bg-sidebar text-sidebar-foreground"
|
||||||
>
|
>
|
||||||
<div className="flex items-center justify-between px-3 pb-2.5 pt-3">
|
<div
|
||||||
<picture className="block min-w-0">
|
className={cn(
|
||||||
<source srcSet="/brand/nanobot_logo.webp" type="image/webp" />
|
"flex items-center px-3 pb-2.5 pt-3",
|
||||||
|
collapsed ? "w-14 justify-start" : "justify-between",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
aria-label={collapsed ? toggleLabel : undefined}
|
||||||
|
aria-hidden={collapsed ? undefined : true}
|
||||||
|
title={collapsed ? toggleLabel : undefined}
|
||||||
|
onClick={collapsed ? props.onExpand : undefined}
|
||||||
|
tabIndex={collapsed ? 0 : -1}
|
||||||
|
className={cn(
|
||||||
|
"flex h-9 w-9 shrink-0 items-center justify-center overflow-hidden rounded-xl transition-colors",
|
||||||
|
collapsed
|
||||||
|
? "-ml-0.5 hover:bg-sidebar-accent/75"
|
||||||
|
: "pointer-events-none -ml-0.5",
|
||||||
|
)}
|
||||||
|
>
|
||||||
<img
|
<img
|
||||||
src="/brand/nanobot_logo.png"
|
src="/brand/nanobot_icon.png"
|
||||||
alt="nanobot"
|
alt=""
|
||||||
className="h-6 w-auto select-none object-contain opacity-95"
|
className="h-8 w-8 select-none object-contain"
|
||||||
draggable={false}
|
draggable={false}
|
||||||
/>
|
/>
|
||||||
</picture>
|
</button>
|
||||||
<Button
|
{!collapsed && (
|
||||||
variant="ghost"
|
<Button
|
||||||
size="icon"
|
variant="ghost"
|
||||||
aria-label={t("sidebar.collapse")}
|
size="icon"
|
||||||
onClick={props.onCollapse}
|
aria-label={t("sidebar.collapse")}
|
||||||
className="h-7 w-7 rounded-lg text-muted-foreground/85 hover:bg-sidebar-accent/75 hover:text-sidebar-foreground"
|
onClick={props.onCollapse}
|
||||||
>
|
className="h-7 w-7 rounded-lg text-muted-foreground/85 hover:bg-sidebar-accent/75 hover:text-sidebar-foreground"
|
||||||
<Menu className="h-3.5 w-3.5" />
|
>
|
||||||
</Button>
|
<Menu className="h-3.5 w-3.5" />
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="space-y-1.5 px-2 pb-2">
|
<div
|
||||||
<Button
|
className={cn(
|
||||||
|
"space-y-1.5 px-2 pb-2",
|
||||||
|
collapsed && "flex w-14 flex-col items-center px-0",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<SidebarActionButton
|
||||||
|
collapsed={collapsed}
|
||||||
|
label={t("sidebar.newChat")}
|
||||||
onClick={props.onNewChat}
|
onClick={props.onNewChat}
|
||||||
className="h-8 w-full justify-start gap-2 rounded-full px-3 text-[12.5px] font-medium text-sidebar-foreground/92 hover:bg-sidebar-accent/75 hover:text-sidebar-foreground"
|
icon={<SquarePen className="h-4 w-4" />}
|
||||||
variant="ghost"
|
/>
|
||||||
>
|
<SidebarActionButton
|
||||||
<SquarePen className="h-3.5 w-3.5" />
|
collapsed={collapsed}
|
||||||
{t("sidebar.newChat")}
|
label={t("sidebar.searchAria")}
|
||||||
</Button>
|
|
||||||
<Button
|
|
||||||
type="button"
|
|
||||||
onClick={props.onOpenSearch}
|
onClick={props.onOpenSearch}
|
||||||
className="h-8 w-full justify-start gap-2 rounded-full px-3 text-[12.5px] font-medium text-sidebar-foreground/85 hover:bg-sidebar-accent/75 hover:text-sidebar-foreground"
|
icon={<Search className="h-4 w-4" />}
|
||||||
variant="ghost"
|
/>
|
||||||
>
|
|
||||||
<Search className="h-3.5 w-3.5" aria-hidden />
|
|
||||||
{t("sidebar.searchAria")}
|
|
||||||
</Button>
|
|
||||||
<SidebarViewMenu
|
<SidebarViewMenu
|
||||||
|
compact={collapsed}
|
||||||
view={props.viewState}
|
view={props.viewState}
|
||||||
onUpdateView={props.onUpdateView}
|
onUpdateView={props.onUpdateView}
|
||||||
/>
|
/>
|
||||||
{props.archivedCount ? (
|
{props.archivedCount ? (
|
||||||
<Button
|
<SidebarActionButton
|
||||||
type="button"
|
collapsed={collapsed}
|
||||||
|
label={props.showArchived ? t("chat.hideArchived") : t("chat.showArchived")}
|
||||||
onClick={props.onToggleArchived}
|
onClick={props.onToggleArchived}
|
||||||
className="h-8 w-full justify-start gap-2 rounded-full px-3 text-[12.5px] font-medium text-sidebar-foreground/75 hover:bg-sidebar-accent/75 hover:text-sidebar-foreground"
|
icon={<Archive className="h-4 w-4" />}
|
||||||
variant="ghost"
|
/>
|
||||||
>
|
|
||||||
<Archive className="h-3.5 w-3.5" aria-hidden />
|
|
||||||
{props.showArchived ? t("chat.hideArchived") : t("chat.showArchived")}
|
|
||||||
</Button>
|
|
||||||
) : null}
|
) : null}
|
||||||
</div>
|
</div>
|
||||||
<div className="flex min-h-0 min-w-0 flex-1 flex-col overflow-hidden">
|
<div
|
||||||
<ChatList
|
className={cn(
|
||||||
sessions={props.sessions}
|
"flex min-h-0 min-w-0 flex-1 flex-col overflow-hidden transition-opacity duration-200",
|
||||||
activeKey={props.activeKey}
|
collapsed && "pointer-events-none opacity-0",
|
||||||
loading={props.loading}
|
)}
|
||||||
emptyLabel={t("chat.noSessions")}
|
>
|
||||||
onSelect={props.onSelect}
|
{!collapsed && (
|
||||||
onRequestDelete={props.onRequestDelete}
|
<ChatList
|
||||||
onTogglePin={props.onTogglePin}
|
sessions={props.sessions}
|
||||||
onRequestRename={props.onRequestRename}
|
activeKey={props.activeKey}
|
||||||
onToggleArchive={props.onToggleArchive}
|
loading={props.loading}
|
||||||
pinnedKeys={props.pinnedKeys}
|
emptyLabel={t("chat.noSessions")}
|
||||||
archivedKeys={props.archivedKeys}
|
onSelect={props.onSelect}
|
||||||
titleOverrides={props.titleOverrides}
|
onRequestDelete={props.onRequestDelete}
|
||||||
runningChatIds={props.runningChatIds}
|
onTogglePin={props.onTogglePin}
|
||||||
completedChatIds={props.completedChatIds}
|
onRequestRename={props.onRequestRename}
|
||||||
density={props.viewState?.density}
|
onToggleArchive={props.onToggleArchive}
|
||||||
showPreviews={props.viewState?.show_previews}
|
pinnedKeys={props.pinnedKeys}
|
||||||
showTimestamps={props.viewState?.show_timestamps}
|
archivedKeys={props.archivedKeys}
|
||||||
sort={props.viewState?.sort}
|
titleOverrides={props.titleOverrides}
|
||||||
showArchived={props.showArchived}
|
runningChatIds={props.runningChatIds}
|
||||||
actionMenuPortalContainer={
|
completedChatIds={props.completedChatIds}
|
||||||
props.containActionMenus ? menuPortalContainer : undefined
|
density={props.viewState?.density}
|
||||||
}
|
showPreviews={props.viewState?.show_previews}
|
||||||
/>
|
showTimestamps={props.viewState?.show_timestamps}
|
||||||
|
sort={props.viewState?.sort}
|
||||||
|
showArchived={props.showArchived}
|
||||||
|
actionMenuPortalContainer={
|
||||||
|
props.containActionMenus ? menuPortalContainer : undefined
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
<Separator className="bg-sidebar-border/50" />
|
<Separator className="bg-sidebar-border/50" />
|
||||||
<div className="flex items-center gap-1 px-2.5 py-2.5 text-xs">
|
<div
|
||||||
<Button
|
className={cn(
|
||||||
type="button"
|
"flex items-center gap-1 px-2.5 py-2.5 text-xs",
|
||||||
variant="ghost"
|
collapsed && "w-14 flex-col px-0",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<SidebarActionButton
|
||||||
|
collapsed={collapsed}
|
||||||
|
label={t("sidebar.settings")}
|
||||||
onClick={props.onOpenSettings}
|
onClick={props.onOpenSettings}
|
||||||
className="h-8 min-w-0 flex-1 justify-start gap-2 rounded-full px-2.5 text-[12.5px] font-medium text-sidebar-foreground/85 hover:bg-sidebar-accent/75 hover:text-sidebar-foreground"
|
className={collapsed ? undefined : "flex-1"}
|
||||||
>
|
icon={<Settings className="h-4 w-4" />}
|
||||||
<Settings className="h-3.5 w-3.5" aria-hidden />
|
/>
|
||||||
{t("sidebar.settings")}
|
|
||||||
</Button>
|
|
||||||
<ConnectionBadge />
|
<ConnectionBadge />
|
||||||
</div>
|
</div>
|
||||||
</nav>
|
</nav>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function SidebarActionButton({
|
||||||
|
collapsed,
|
||||||
|
label,
|
||||||
|
icon,
|
||||||
|
onClick,
|
||||||
|
className,
|
||||||
|
}: {
|
||||||
|
collapsed: boolean;
|
||||||
|
label: string;
|
||||||
|
icon: ReactNode;
|
||||||
|
onClick: () => void;
|
||||||
|
className?: string;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
variant="ghost"
|
||||||
|
aria-label={label}
|
||||||
|
title={collapsed ? label : undefined}
|
||||||
|
onClick={onClick}
|
||||||
|
className={cn(
|
||||||
|
"group h-8 min-w-0 gap-2 overflow-hidden rounded-full font-medium text-sidebar-foreground/85 hover:bg-sidebar-accent/75 hover:text-sidebar-foreground",
|
||||||
|
"transition-[width,padding,border-radius,color,background-color] duration-300 ease-out",
|
||||||
|
collapsed
|
||||||
|
? "w-9 justify-center gap-0 rounded-xl px-0"
|
||||||
|
: "w-full justify-start gap-2 px-3 text-[12.5px]",
|
||||||
|
className,
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<span
|
||||||
|
className={cn(
|
||||||
|
"flex shrink-0 items-center justify-center transition-transform duration-300 ease-out",
|
||||||
|
collapsed ? "translate-x-0" : "translate-x-0",
|
||||||
|
)}
|
||||||
|
aria-hidden
|
||||||
|
>
|
||||||
|
{icon}
|
||||||
|
</span>
|
||||||
|
<span
|
||||||
|
className={cn(
|
||||||
|
"min-w-0 overflow-hidden truncate whitespace-nowrap transition-[max-width,opacity,transform] duration-200 ease-out",
|
||||||
|
collapsed
|
||||||
|
? "max-w-0 -translate-x-1 opacity-0"
|
||||||
|
: "max-w-[12rem] translate-x-0 opacity-100",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{label}
|
||||||
|
</span>
|
||||||
|
</Button>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
function SidebarViewMenu({
|
function SidebarViewMenu({
|
||||||
|
compact = false,
|
||||||
view,
|
view,
|
||||||
onUpdateView,
|
onUpdateView,
|
||||||
}: {
|
}: {
|
||||||
|
compact?: boolean;
|
||||||
view?: SidebarViewState;
|
view?: SidebarViewState;
|
||||||
onUpdateView: (view: Partial<SidebarViewState>) => void;
|
onUpdateView: (view: Partial<SidebarViewState>) => void;
|
||||||
}) {
|
}) {
|
||||||
@ -182,11 +268,28 @@ function SidebarViewMenu({
|
|||||||
<DropdownMenuTrigger asChild>
|
<DropdownMenuTrigger asChild>
|
||||||
<Button
|
<Button
|
||||||
type="button"
|
type="button"
|
||||||
className="h-8 w-full justify-start gap-2 rounded-full px-3 text-[12.5px] font-medium text-sidebar-foreground/75 hover:bg-sidebar-accent/75 hover:text-sidebar-foreground"
|
aria-label={t("sidebar.viewOptions")}
|
||||||
|
title={compact ? t("sidebar.viewOptions") : undefined}
|
||||||
|
className={cn(
|
||||||
|
"h-8 min-w-0 overflow-hidden font-medium text-sidebar-foreground/75 hover:bg-sidebar-accent/75 hover:text-sidebar-foreground",
|
||||||
|
"transition-[width,padding,border-radius,color,background-color] duration-300 ease-out",
|
||||||
|
compact
|
||||||
|
? "w-9 justify-center gap-0 rounded-xl px-0"
|
||||||
|
: "w-full justify-start gap-2 rounded-full px-3 text-[12.5px]",
|
||||||
|
)}
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
>
|
>
|
||||||
<ListFilter className="h-3.5 w-3.5" aria-hidden />
|
<ListFilter className="h-4 w-4 shrink-0" aria-hidden />
|
||||||
{t("sidebar.viewOptions")}
|
<span
|
||||||
|
className={cn(
|
||||||
|
"min-w-0 overflow-hidden truncate whitespace-nowrap transition-[max-width,opacity,transform] duration-200 ease-out",
|
||||||
|
compact
|
||||||
|
? "max-w-0 -translate-x-1 opacity-0"
|
||||||
|
: "max-w-[12rem] translate-x-0 opacity-100",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{t("sidebar.viewOptions")}
|
||||||
|
</span>
|
||||||
</Button>
|
</Button>
|
||||||
</DropdownMenuTrigger>
|
</DropdownMenuTrigger>
|
||||||
<DropdownMenuContent align="start" className="w-52">
|
<DropdownMenuContent align="start" className="w-52">
|
||||||
|
|||||||
@ -27,6 +27,7 @@ interface ActivityCounts {
|
|||||||
fileCount: number;
|
fileCount: number;
|
||||||
added: number;
|
added: number;
|
||||||
deleted: number;
|
deleted: number;
|
||||||
|
hasDiffStats: boolean;
|
||||||
hasEditingFiles: boolean;
|
hasEditingFiles: boolean;
|
||||||
hasFailedFiles: boolean;
|
hasFailedFiles: boolean;
|
||||||
primaryFilePath?: string;
|
primaryFilePath?: string;
|
||||||
@ -61,6 +62,7 @@ function countActivity(messages: UIMessage[], fileEdits: FileEditSummary[]): Act
|
|||||||
}
|
}
|
||||||
let added = 0;
|
let added = 0;
|
||||||
let deleted = 0;
|
let deleted = 0;
|
||||||
|
let hasDiffStats = false;
|
||||||
let hasEditingFiles = false;
|
let hasEditingFiles = false;
|
||||||
let failedFileCount = 0;
|
let failedFileCount = 0;
|
||||||
let primaryFilePath: string | undefined;
|
let primaryFilePath: string | undefined;
|
||||||
@ -77,6 +79,10 @@ function countActivity(messages: UIMessage[], fileEdits: FileEditSummary[]): Act
|
|||||||
if (edit.status === "error" || edit.binary) {
|
if (edit.status === "error" || edit.binary) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
if (!hasVisibleDiffStats(edit)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
hasDiffStats = true;
|
||||||
added += edit.added;
|
added += edit.added;
|
||||||
deleted += edit.deleted;
|
deleted += edit.deleted;
|
||||||
}
|
}
|
||||||
@ -86,6 +92,7 @@ function countActivity(messages: UIMessage[], fileEdits: FileEditSummary[]): Act
|
|||||||
fileCount: fileEdits.length,
|
fileCount: fileEdits.length,
|
||||||
added,
|
added,
|
||||||
deleted,
|
deleted,
|
||||||
|
hasDiffStats,
|
||||||
hasEditingFiles,
|
hasEditingFiles,
|
||||||
hasFailedFiles: fileEdits.length > 0 && failedFileCount === fileEdits.length,
|
hasFailedFiles: fileEdits.length > 0 && failedFileCount === fileEdits.length,
|
||||||
primaryFilePath,
|
primaryFilePath,
|
||||||
@ -120,6 +127,7 @@ export function AgentActivityCluster({
|
|||||||
fileCount,
|
fileCount,
|
||||||
added,
|
added,
|
||||||
deleted,
|
deleted,
|
||||||
|
hasDiffStats,
|
||||||
hasEditingFiles,
|
hasEditingFiles,
|
||||||
hasFailedFiles,
|
hasFailedFiles,
|
||||||
primaryFilePath,
|
primaryFilePath,
|
||||||
@ -140,6 +148,7 @@ export function AgentActivityCluster({
|
|||||||
const headerBusy = fileCount > 0 ? hasEditingFiles : isTurnStreaming;
|
const headerBusy = fileCount > 0 ? hasEditingFiles : isTurnStreaming;
|
||||||
const singleFilePath = fileCount === 1 ? primaryFilePath : undefined;
|
const singleFilePath = fileCount === 1 ? primaryFilePath : undefined;
|
||||||
const singleFileTooltipPath = fileCount === 1 ? primaryFileTooltipPath : undefined;
|
const singleFileTooltipPath = fileCount === 1 ? primaryFileTooltipPath : undefined;
|
||||||
|
const hasVisibleActivity = reasoningSteps > 0 || toolCalls > 0 || fileCount > 0;
|
||||||
|
|
||||||
const fileActivitySummary = fileCount > 0
|
const fileActivitySummary = fileCount > 0
|
||||||
? hasPendingFileEdit && !singleFilePath
|
? hasPendingFileEdit && !singleFilePath
|
||||||
@ -243,6 +252,8 @@ export function AgentActivityCluster({
|
|||||||
autoFollowActivityRef.current = distance < ACTIVITY_SCROLL_NEAR_BOTTOM_PX;
|
autoFollowActivityRef.current = distance < ACTIVITY_SCROLL_NEAR_BOTTOM_PX;
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
if (!hasVisibleActivity) return null;
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className={cn("w-full", hasBodyBelow && "mb-2")}>
|
<div className={cn("w-full", hasBodyBelow && "mb-2")}>
|
||||||
<button
|
<button
|
||||||
@ -282,7 +293,7 @@ export function AgentActivityCluster({
|
|||||||
{summary}
|
{summary}
|
||||||
</StreamingLabelSheen>
|
</StreamingLabelSheen>
|
||||||
)}
|
)}
|
||||||
{fileCount > 0 && (
|
{fileCount > 0 && hasDiffStats && (
|
||||||
<span className="inline-flex min-w-0 items-center gap-1 text-muted-foreground/85">
|
<span className="inline-flex min-w-0 items-center gap-1 text-muted-foreground/85">
|
||||||
<DiffPair added={added} deleted={deleted} />
|
<DiffPair added={added} deleted={deleted} />
|
||||||
</span>
|
</span>
|
||||||
@ -435,6 +446,17 @@ function summarizeFileEdits(edits: UIFileEdit[], active: boolean): FileEditSumma
|
|||||||
summary.absolute_path = edit.absolute_path;
|
summary.absolute_path = edit.absolute_path;
|
||||||
}
|
}
|
||||||
summary.pending = summary.pending || !!edit.pending || !edit.path;
|
summary.pending = summary.pending || !!edit.pending || !edit.path;
|
||||||
|
if (!edit.path && edit.pending) {
|
||||||
|
if (active && edit.status === "editing") {
|
||||||
|
summary.hasActiveEditing = true;
|
||||||
|
summary.approximate = summary.approximate || !!edit.approximate;
|
||||||
|
if (!edit.binary) {
|
||||||
|
summary.added += edit.added;
|
||||||
|
summary.deleted += edit.deleted;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
if (active && edit.status === "editing") {
|
if (active && edit.status === "editing") {
|
||||||
summary.hasActiveEditing = true;
|
summary.hasActiveEditing = true;
|
||||||
summary.binary = summary.binary || !!edit.binary;
|
summary.binary = summary.binary || !!edit.binary;
|
||||||
@ -461,8 +483,16 @@ function summarizeFileEdits(edits: UIFileEdit[], active: boolean): FileEditSumma
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return order.map((key) => {
|
return order.flatMap((key) => {
|
||||||
const summary = byPath.get(key)!;
|
const summary = byPath.get(key)!;
|
||||||
|
if (
|
||||||
|
!summary.path
|
||||||
|
&& !summary.hasActiveEditing
|
||||||
|
&& !summary.hasSuccessfulChange
|
||||||
|
&& !summary.hasFailed
|
||||||
|
) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
const status: UIFileEdit["status"] = summary.hasActiveEditing
|
const status: UIFileEdit["status"] = summary.hasActiveEditing
|
||||||
? "editing"
|
? "editing"
|
||||||
: summary.hasSuccessfulChange
|
: summary.hasSuccessfulChange
|
||||||
@ -470,7 +500,7 @@ function summarizeFileEdits(edits: UIFileEdit[], active: boolean): FileEditSumma
|
|||||||
: summary.hasFailed
|
: summary.hasFailed
|
||||||
? "error"
|
? "error"
|
||||||
: "done";
|
: "done";
|
||||||
return {
|
return [{
|
||||||
key: summary.key,
|
key: summary.key,
|
||||||
path: summary.path,
|
path: summary.path,
|
||||||
absolute_path: summary.absolute_path,
|
absolute_path: summary.absolute_path,
|
||||||
@ -481,10 +511,14 @@ function summarizeFileEdits(edits: UIFileEdit[], active: boolean): FileEditSumma
|
|||||||
status,
|
status,
|
||||||
pending: summary.pending && !summary.path,
|
pending: summary.pending && !summary.path,
|
||||||
error: summary.error,
|
error: summary.error,
|
||||||
};
|
}];
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function hasVisibleDiffStats(edit: Pick<FileEditSummary, "added" | "deleted">): boolean {
|
||||||
|
return edit.added > 0 || edit.deleted > 0;
|
||||||
|
}
|
||||||
|
|
||||||
function FileEditGroup({ edits }: { edits: FileEditSummary[] }) {
|
function FileEditGroup({ edits }: { edits: FileEditSummary[] }) {
|
||||||
if (edits.length === 0) return null;
|
if (edits.length === 0) return null;
|
||||||
return (
|
return (
|
||||||
@ -500,7 +534,7 @@ function FileEditRow({ edit }: { edit: FileEditSummary }) {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const editing = edit.status === "editing";
|
const editing = edit.status === "editing";
|
||||||
const failed = edit.status === "error";
|
const failed = edit.status === "error";
|
||||||
const hasCountedDiff = !failed && !edit.binary;
|
const hasCountedDiff = !failed && !edit.binary && hasVisibleDiffStats(edit);
|
||||||
return (
|
return (
|
||||||
<li className="grid grid-cols-[minmax(0,1fr)_auto] items-center gap-3 rounded-md px-2 py-1.5 text-xs">
|
<li className="grid grid-cols-[minmax(0,1fr)_auto] items-center gap-3 rounded-md px-2 py-1.5 text-xs">
|
||||||
<div className="flex min-w-0 items-center gap-2">
|
<div className="flex min-w-0 items-center gap-2">
|
||||||
|
|||||||
@ -32,12 +32,17 @@ export function ThreadHeader({
|
|||||||
onClick={onToggleSidebar}
|
onClick={onToggleSidebar}
|
||||||
className={cn(
|
className={cn(
|
||||||
"h-7 w-7 rounded-md text-muted-foreground hover:bg-accent/35 hover:text-foreground",
|
"h-7 w-7 rounded-md text-muted-foreground hover:bg-accent/35 hover:text-foreground",
|
||||||
hideSidebarToggleOnDesktop && "lg:pointer-events-none lg:opacity-0",
|
hideSidebarToggleOnDesktop && "lg:hidden",
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
<Menu className="h-3.5 w-3.5" />
|
<Menu className="h-3.5 w-3.5" />
|
||||||
</Button>
|
</Button>
|
||||||
<ThemeButton theme={theme} onToggleTheme={onToggleTheme} label={t("thread.header.toggleTheme")} />
|
<ThemeButton
|
||||||
|
theme={theme}
|
||||||
|
onToggleTheme={onToggleTheme}
|
||||||
|
label={t("thread.header.toggleTheme")}
|
||||||
|
className="ml-auto"
|
||||||
|
/>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -52,7 +57,7 @@ export function ThreadHeader({
|
|||||||
onClick={onToggleSidebar}
|
onClick={onToggleSidebar}
|
||||||
className={cn(
|
className={cn(
|
||||||
"h-7 w-7 rounded-md text-muted-foreground hover:bg-accent/35 hover:text-foreground",
|
"h-7 w-7 rounded-md text-muted-foreground hover:bg-accent/35 hover:text-foreground",
|
||||||
hideSidebarToggleOnDesktop && "lg:pointer-events-none lg:opacity-0",
|
hideSidebarToggleOnDesktop && "lg:hidden",
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
<Menu className="h-3.5 w-3.5" />
|
<Menu className="h-3.5 w-3.5" />
|
||||||
@ -62,7 +67,12 @@ export function ThreadHeader({
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<ThemeButton theme={theme} onToggleTheme={onToggleTheme} label={t("thread.header.toggleTheme")} />
|
<ThemeButton
|
||||||
|
theme={theme}
|
||||||
|
onToggleTheme={onToggleTheme}
|
||||||
|
label={t("thread.header.toggleTheme")}
|
||||||
|
className="ml-auto shrink-0"
|
||||||
|
/>
|
||||||
|
|
||||||
<div aria-hidden className="pointer-events-none absolute inset-x-0 top-full h-4" />
|
<div aria-hidden className="pointer-events-none absolute inset-x-0 top-full h-4" />
|
||||||
</div>
|
</div>
|
||||||
@ -73,10 +83,12 @@ function ThemeButton({
|
|||||||
theme,
|
theme,
|
||||||
onToggleTheme,
|
onToggleTheme,
|
||||||
label,
|
label,
|
||||||
|
className,
|
||||||
}: {
|
}: {
|
||||||
theme: "light" | "dark";
|
theme: "light" | "dark";
|
||||||
onToggleTheme: () => void;
|
onToggleTheme: () => void;
|
||||||
label: string;
|
label: string;
|
||||||
|
className?: string;
|
||||||
}) {
|
}) {
|
||||||
return (
|
return (
|
||||||
<Button
|
<Button
|
||||||
@ -84,7 +96,10 @@ function ThemeButton({
|
|||||||
size="icon"
|
size="icon"
|
||||||
aria-label={label}
|
aria-label={label}
|
||||||
onClick={onToggleTheme}
|
onClick={onToggleTheme}
|
||||||
className="h-8 w-8 rounded-full text-muted-foreground/85 hover:bg-accent/40 hover:text-foreground"
|
className={cn(
|
||||||
|
"h-8 w-8 rounded-full text-muted-foreground/85 hover:bg-accent/40 hover:text-foreground",
|
||||||
|
className,
|
||||||
|
)}
|
||||||
>
|
>
|
||||||
{theme === "dark" ? (
|
{theme === "dark" ? (
|
||||||
<Sun className="h-4 w-4" />
|
<Sun className="h-4 w-4" />
|
||||||
|
|||||||
@ -27,13 +27,25 @@ export function useSessions(): {
|
|||||||
const [loading, setLoading] = useState(true);
|
const [loading, setLoading] = useState(true);
|
||||||
const [error, setError] = useState<string | null>(null);
|
const [error, setError] = useState<string | null>(null);
|
||||||
const tokenRef = useRef(token);
|
const tokenRef = useRef(token);
|
||||||
|
const optimisticKeysRef = useRef<Set<string>>(new Set());
|
||||||
tokenRef.current = token;
|
tokenRef.current = token;
|
||||||
|
|
||||||
const refresh = useCallback(async () => {
|
const refresh = useCallback(async () => {
|
||||||
try {
|
try {
|
||||||
setLoading(true);
|
setLoading(true);
|
||||||
const rows = await listSessions(tokenRef.current);
|
const rows = await listSessions(tokenRef.current);
|
||||||
setSessions(rows);
|
const serverKeys = new Set(rows.map((row) => row.key));
|
||||||
|
setSessions((prev) => [
|
||||||
|
...rows,
|
||||||
|
...prev.filter(
|
||||||
|
(session) =>
|
||||||
|
optimisticKeysRef.current.has(session.key) &&
|
||||||
|
!serverKeys.has(session.key),
|
||||||
|
),
|
||||||
|
]);
|
||||||
|
for (const key of Array.from(optimisticKeysRef.current)) {
|
||||||
|
if (serverKeys.has(key)) optimisticKeysRef.current.delete(key);
|
||||||
|
}
|
||||||
setError(null);
|
setError(null);
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
const msg =
|
const msg =
|
||||||
@ -57,6 +69,7 @@ export function useSessions(): {
|
|||||||
const createChat = useCallback(async (): Promise<string> => {
|
const createChat = useCallback(async (): Promise<string> => {
|
||||||
const chatId = await client.newChat();
|
const chatId = await client.newChat();
|
||||||
const key = `websocket:${chatId}`;
|
const key = `websocket:${chatId}`;
|
||||||
|
optimisticKeysRef.current.add(key);
|
||||||
// Optimistic insert; a subsequent refresh will replace it with the
|
// Optimistic insert; a subsequent refresh will replace it with the
|
||||||
// authoritative row once the server persists the session.
|
// authoritative row once the server persists the session.
|
||||||
setSessions((prev) => [
|
setSessions((prev) => [
|
||||||
@ -77,6 +90,7 @@ export function useSessions(): {
|
|||||||
const deleteChat = useCallback(
|
const deleteChat = useCallback(
|
||||||
async (key: string) => {
|
async (key: string) => {
|
||||||
await apiDeleteSession(tokenRef.current, key);
|
await apiDeleteSession(tokenRef.current, key);
|
||||||
|
optimisticKeysRef.current.delete(key);
|
||||||
setSessions((prev) => prev.filter((s) => s.key !== key));
|
setSessions((prev) => prev.filter((s) => s.key !== key));
|
||||||
},
|
},
|
||||||
[],
|
[],
|
||||||
|
|||||||
@ -271,6 +271,7 @@
|
|||||||
"fallbackTitle": "Chat {{id}}",
|
"fallbackTitle": "Chat {{id}}",
|
||||||
"loading": "Loading…",
|
"loading": "Loading…",
|
||||||
"noSessions": "No sessions yet.",
|
"noSessions": "No sessions yet.",
|
||||||
|
"showMore": "Show {{count}} more",
|
||||||
"actions": "Chat actions for {{title}}",
|
"actions": "Chat actions for {{title}}",
|
||||||
"activity": {
|
"activity": {
|
||||||
"running": "Agent running",
|
"running": "Agent running",
|
||||||
|
|||||||
@ -224,6 +224,7 @@
|
|||||||
"fallbackTitle": "Chat {{id}}",
|
"fallbackTitle": "Chat {{id}}",
|
||||||
"loading": "Cargando…",
|
"loading": "Cargando…",
|
||||||
"noSessions": "Todavía no hay sesiones.",
|
"noSessions": "Todavía no hay sesiones.",
|
||||||
|
"showMore": "Mostrar {{count}} más",
|
||||||
"actions": "Acciones del chat {{title}}",
|
"actions": "Acciones del chat {{title}}",
|
||||||
"activity": {
|
"activity": {
|
||||||
"running": "Agent running",
|
"running": "Agent running",
|
||||||
|
|||||||
@ -224,6 +224,7 @@
|
|||||||
"fallbackTitle": "Discussion {{id}}",
|
"fallbackTitle": "Discussion {{id}}",
|
||||||
"loading": "Chargement…",
|
"loading": "Chargement…",
|
||||||
"noSessions": "Aucune session pour le moment.",
|
"noSessions": "Aucune session pour le moment.",
|
||||||
|
"showMore": "Afficher {{count}} de plus",
|
||||||
"actions": "Actions de la discussion {{title}}",
|
"actions": "Actions de la discussion {{title}}",
|
||||||
"activity": {
|
"activity": {
|
||||||
"running": "Agent running",
|
"running": "Agent running",
|
||||||
|
|||||||
@ -224,6 +224,7 @@
|
|||||||
"fallbackTitle": "Obrolan {{id}}",
|
"fallbackTitle": "Obrolan {{id}}",
|
||||||
"loading": "Memuat…",
|
"loading": "Memuat…",
|
||||||
"noSessions": "Belum ada sesi.",
|
"noSessions": "Belum ada sesi.",
|
||||||
|
"showMore": "Tampilkan {{count}} lagi",
|
||||||
"actions": "Aksi obrolan untuk {{title}}",
|
"actions": "Aksi obrolan untuk {{title}}",
|
||||||
"activity": {
|
"activity": {
|
||||||
"running": "Agent running",
|
"running": "Agent running",
|
||||||
|
|||||||
@ -224,6 +224,7 @@
|
|||||||
"fallbackTitle": "チャット {{id}}",
|
"fallbackTitle": "チャット {{id}}",
|
||||||
"loading": "読み込み中…",
|
"loading": "読み込み中…",
|
||||||
"noSessions": "まだセッションがありません。",
|
"noSessions": "まだセッションがありません。",
|
||||||
|
"showMore": "さらに {{count}} 件表示",
|
||||||
"actions": "「{{title}}」のチャット操作",
|
"actions": "「{{title}}」のチャット操作",
|
||||||
"activity": {
|
"activity": {
|
||||||
"running": "Agent running",
|
"running": "Agent running",
|
||||||
|
|||||||
@ -224,6 +224,7 @@
|
|||||||
"fallbackTitle": "채팅 {{id}}",
|
"fallbackTitle": "채팅 {{id}}",
|
||||||
"loading": "불러오는 중…",
|
"loading": "불러오는 중…",
|
||||||
"noSessions": "아직 세션이 없습니다.",
|
"noSessions": "아직 세션이 없습니다.",
|
||||||
|
"showMore": "{{count}}개 더 보기",
|
||||||
"actions": "{{title}} 채팅 작업",
|
"actions": "{{title}} 채팅 작업",
|
||||||
"activity": {
|
"activity": {
|
||||||
"running": "Agent running",
|
"running": "Agent running",
|
||||||
|
|||||||
@ -224,6 +224,7 @@
|
|||||||
"fallbackTitle": "Trò chuyện {{id}}",
|
"fallbackTitle": "Trò chuyện {{id}}",
|
||||||
"loading": "Đang tải…",
|
"loading": "Đang tải…",
|
||||||
"noSessions": "Chưa có phiên nào.",
|
"noSessions": "Chưa có phiên nào.",
|
||||||
|
"showMore": "Hiển thị thêm {{count}}",
|
||||||
"actions": "Tác vụ cho cuộc trò chuyện {{title}}",
|
"actions": "Tác vụ cho cuộc trò chuyện {{title}}",
|
||||||
"activity": {
|
"activity": {
|
||||||
"running": "Agent running",
|
"running": "Agent running",
|
||||||
|
|||||||
@ -259,6 +259,7 @@
|
|||||||
"fallbackTitle": "对话 {{id}}",
|
"fallbackTitle": "对话 {{id}}",
|
||||||
"loading": "加载中…",
|
"loading": "加载中…",
|
||||||
"noSessions": "还没有会话。",
|
"noSessions": "还没有会话。",
|
||||||
|
"showMore": "再显示 {{count}} 个",
|
||||||
"actions": "“{{title}}” 的会话操作",
|
"actions": "“{{title}}” 的会话操作",
|
||||||
"activity": {
|
"activity": {
|
||||||
"running": "Agent 正在运行",
|
"running": "Agent 正在运行",
|
||||||
|
|||||||
@ -224,6 +224,7 @@
|
|||||||
"fallbackTitle": "對話 {{id}}",
|
"fallbackTitle": "對話 {{id}}",
|
||||||
"loading": "載入中…",
|
"loading": "載入中…",
|
||||||
"noSessions": "目前還沒有會話。",
|
"noSessions": "目前還沒有會話。",
|
||||||
|
"showMore": "再顯示 {{count}} 個",
|
||||||
"actions": "「{{title}}」的會話操作",
|
"actions": "「{{title}}」的會話操作",
|
||||||
"activity": {
|
"activity": {
|
||||||
"running": "Agent 正在執行",
|
"running": "Agent 正在執行",
|
||||||
|
|||||||
@ -271,6 +271,69 @@ describe("AgentActivityCluster", () => {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("does not render zero diff counters for completed edits", () => {
|
||||||
|
render(
|
||||||
|
<AgentActivityCluster
|
||||||
|
messages={activityMessages("", {
|
||||||
|
id: "t2",
|
||||||
|
role: "tool",
|
||||||
|
kind: "trace",
|
||||||
|
content: "edit_file()",
|
||||||
|
traces: ["edit_file()"],
|
||||||
|
fileEdits: [{
|
||||||
|
call_id: "call-edit",
|
||||||
|
tool: "edit_file",
|
||||||
|
path: "src/app.tsx",
|
||||||
|
phase: "end",
|
||||||
|
added: 0,
|
||||||
|
deleted: 0,
|
||||||
|
approximate: false,
|
||||||
|
status: "done",
|
||||||
|
}],
|
||||||
|
createdAt: 3,
|
||||||
|
})}
|
||||||
|
isTurnStreaming={false}
|
||||||
|
hasBodyBelow={false}
|
||||||
|
/>,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(screen.getByRole("button", { name: /edited app\.tsx/i })).toBeInTheDocument();
|
||||||
|
expect(screen.queryByText("+0")).not.toBeInTheDocument();
|
||||||
|
expect(screen.queryByText("-0")).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("drops stale pathless pending edits after the turn completes", () => {
|
||||||
|
render(
|
||||||
|
<AgentActivityCluster
|
||||||
|
messages={[{
|
||||||
|
id: "t1",
|
||||||
|
role: "tool",
|
||||||
|
kind: "trace",
|
||||||
|
content: "",
|
||||||
|
traces: [],
|
||||||
|
fileEdits: [{
|
||||||
|
call_id: "call-edit",
|
||||||
|
tool: "edit_file",
|
||||||
|
path: "",
|
||||||
|
phase: "start",
|
||||||
|
added: 98,
|
||||||
|
deleted: 0,
|
||||||
|
approximate: true,
|
||||||
|
status: "editing",
|
||||||
|
pending: true,
|
||||||
|
}],
|
||||||
|
createdAt: 1,
|
||||||
|
}]}
|
||||||
|
isTurnStreaming={false}
|
||||||
|
hasBodyBelow={false}
|
||||||
|
/>,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(screen.queryByRole("button", { name: /preparing edit/i })).not.toBeInTheDocument();
|
||||||
|
expect(screen.queryByText("+98")).not.toBeInTheDocument();
|
||||||
|
expect(screen.queryByText("0 tool calls")).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
it("renders pending file edit placeholders before the path is known", () => {
|
it("renders pending file edit placeholders before the path is known", () => {
|
||||||
render(
|
render(
|
||||||
<AgentActivityCluster
|
<AgentActivityCluster
|
||||||
|
|||||||
@ -992,6 +992,73 @@ describe("App layout", () => {
|
|||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("opens search from the keyboard shortcut", async () => {
|
||||||
|
mockSessions = [
|
||||||
|
{
|
||||||
|
key: "websocket:chat-a",
|
||||||
|
channel: "websocket",
|
||||||
|
chatId: "chat-a",
|
||||||
|
createdAt: "2026-04-16T10:00:00Z",
|
||||||
|
updatedAt: "2026-04-16T10:00:00Z",
|
||||||
|
preview: "Existing chat",
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
render(<App />);
|
||||||
|
|
||||||
|
await waitFor(() => expect(connectSpy).toHaveBeenCalled());
|
||||||
|
fireEvent.keyDown(window, { key: "k", metaKey: true });
|
||||||
|
|
||||||
|
const dialog = await screen.findByRole("dialog", { name: "Search" });
|
||||||
|
expect(within(dialog).queryByText("Global actions")).not.toBeInTheDocument();
|
||||||
|
expect(within(dialog).getByText("Existing chat")).toBeInTheDocument();
|
||||||
|
|
||||||
|
const textbox = within(dialog).getByRole("textbox", { name: "Search" });
|
||||||
|
fireEvent.change(textbox, { target: { value: "missing" } });
|
||||||
|
expect(within(dialog).queryByText("Existing chat")).not.toBeInTheDocument();
|
||||||
|
|
||||||
|
fireEvent.change(textbox, { target: { value: "existing" } });
|
||||||
|
expect(within(dialog).getByText("Existing chat")).toBeInTheDocument();
|
||||||
|
|
||||||
|
fireEvent.keyDown(textbox, { key: "Enter" });
|
||||||
|
await waitFor(() =>
|
||||||
|
expect(screen.queryByRole("dialog", { name: "Search" })).not.toBeInTheDocument(),
|
||||||
|
);
|
||||||
|
expect(createChatSpy).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("keeps large sidebars light while search still covers every chat", async () => {
|
||||||
|
mockSessions = Array.from({ length: 170 }, (_, index) => {
|
||||||
|
const chatId = `chat-${index}`;
|
||||||
|
return {
|
||||||
|
key: `websocket:${chatId}`,
|
||||||
|
channel: "websocket" as const,
|
||||||
|
chatId,
|
||||||
|
createdAt: new Date(Date.UTC(2026, 3, 16, 12, 0 - index)).toISOString(),
|
||||||
|
updatedAt: new Date(Date.UTC(2026, 3, 16, 12, 0 - index)).toISOString(),
|
||||||
|
title: index === 169 ? "Hidden target" : `Bulk chat ${index}`,
|
||||||
|
preview: "",
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
render(<App />);
|
||||||
|
|
||||||
|
await waitFor(() => expect(connectSpy).toHaveBeenCalled());
|
||||||
|
const sidebar = screen.getByRole("navigation", { name: "Sidebar navigation" });
|
||||||
|
await waitFor(() =>
|
||||||
|
expect(within(sidebar).getByRole("button", { name: "Bulk chat 0" })).toBeInTheDocument(),
|
||||||
|
);
|
||||||
|
expect(within(sidebar).queryByText("Hidden target")).not.toBeInTheDocument();
|
||||||
|
expect(within(sidebar).getByRole("button", { name: "Show 10 more" })).toBeInTheDocument();
|
||||||
|
|
||||||
|
fireEvent.click(within(sidebar).getByRole("button", { name: "Search" }));
|
||||||
|
const dialog = await screen.findByRole("dialog", { name: "Search" });
|
||||||
|
fireEvent.change(within(dialog).getByRole("textbox", { name: "Search" }), {
|
||||||
|
target: { value: "hidden" },
|
||||||
|
});
|
||||||
|
expect(within(dialog).getByText("Hidden target")).toBeInTheDocument();
|
||||||
|
});
|
||||||
|
|
||||||
it("opens a blank start page without creating an empty chat", async () => {
|
it("opens a blank start page without creating an empty chat", async () => {
|
||||||
mockSessions = [
|
mockSessions = [
|
||||||
{
|
{
|
||||||
@ -1025,10 +1092,16 @@ describe("App layout", () => {
|
|||||||
|
|
||||||
fireEvent.click(screen.getByRole("button", { name: "Collapse sidebar" }));
|
fireEvent.click(screen.getByRole("button", { name: "Collapse sidebar" }));
|
||||||
const desktopAside = container.querySelector("aside.lg\\:block") as HTMLElement;
|
const desktopAside = container.querySelector("aside.lg\\:block") as HTMLElement;
|
||||||
await waitFor(() => expect(desktopAside.style.width).toBe("0px"));
|
await waitFor(() => expect(desktopAside.style.width).toBe("56px"));
|
||||||
|
|
||||||
expect(screen.queryByRole("button", { name: "Start a new chat" })).not.toBeInTheDocument();
|
expect(screen.queryByRole("button", { name: "Start a new chat" })).not.toBeInTheDocument();
|
||||||
fireEvent.click(screen.getByRole("button", { name: "Toggle sidebar" }));
|
const rail = screen.getByRole("navigation", { name: "Sidebar navigation" });
|
||||||
|
expect(within(rail).getByRole("button", { name: "New chat" })).toBeInTheDocument();
|
||||||
|
expect(within(rail).getByRole("button", { name: "Search" })).toBeInTheDocument();
|
||||||
|
expect(within(rail).getByRole("button", { name: "View" })).toBeInTheDocument();
|
||||||
|
expect(within(rail).queryByText("Existing chat")).not.toBeInTheDocument();
|
||||||
|
|
||||||
|
fireEvent.click(within(rail).getByRole("button", { name: "Toggle sidebar" }));
|
||||||
await waitFor(() => expect(desktopAside.style.width).toBe("272px"));
|
await waitFor(() => expect(desktopAside.style.width).toBe("272px"));
|
||||||
|
|
||||||
const sidebar = screen.getByRole("navigation", { name: "Sidebar navigation" });
|
const sidebar = screen.getByRole("navigation", { name: "Sidebar navigation" });
|
||||||
|
|||||||
@ -78,6 +78,7 @@ describe("webui i18n", () => {
|
|||||||
const common = resource.common;
|
const common = resource.common;
|
||||||
expect(common.app.system.restarting).toBeTruthy();
|
expect(common.app.system.restarting).toBeTruthy();
|
||||||
expect(common.sidebar.settings).toBeTruthy();
|
expect(common.sidebar.settings).toBeTruthy();
|
||||||
|
expect(common.chat.showMore).toBeTruthy();
|
||||||
expect(common.settings.sidebar.title).toBeTruthy();
|
expect(common.settings.sidebar.title).toBeTruthy();
|
||||||
expect(common.settings.backToChat).toBeTruthy();
|
expect(common.settings.backToChat).toBeTruthy();
|
||||||
for (const key of SETTINGS_NAV_KEYS) {
|
for (const key of SETTINGS_NAV_KEYS) {
|
||||||
|
|||||||
@ -157,6 +157,53 @@ describe("useSessions", () => {
|
|||||||
expect(api.listSessions).toHaveBeenCalledTimes(2);
|
expect(api.listSessions).toHaveBeenCalledTimes(2);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("keeps a newly created chat visible until the server session list catches up", async () => {
|
||||||
|
vi.mocked(api.listSessions)
|
||||||
|
.mockResolvedValueOnce([])
|
||||||
|
.mockResolvedValueOnce([])
|
||||||
|
.mockResolvedValueOnce([
|
||||||
|
{
|
||||||
|
key: "websocket:chat-new",
|
||||||
|
channel: "websocket",
|
||||||
|
chatId: "chat-new",
|
||||||
|
createdAt: "2026-05-20T10:00:00Z",
|
||||||
|
updatedAt: "2026-05-20T10:01:00Z",
|
||||||
|
title: "Generated title",
|
||||||
|
preview: "First message",
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
const client = fakeClient();
|
||||||
|
client.newChat.mockResolvedValue("chat-new");
|
||||||
|
|
||||||
|
const { result } = renderHook(() => useSessions(), {
|
||||||
|
wrapper: wrap(client),
|
||||||
|
});
|
||||||
|
|
||||||
|
await waitFor(() => expect(result.current.loading).toBe(false));
|
||||||
|
expect(result.current.sessions).toEqual([]);
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current.createChat();
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.current.sessions.map((s) => s.key)).toEqual(["websocket:chat-new"]);
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current.refresh();
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.current.sessions.map((s) => s.key)).toEqual(["websocket:chat-new"]);
|
||||||
|
expect(result.current.sessions[0]?.preview).toBe("");
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await result.current.refresh();
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.current.sessions.map((s) => s.key)).toEqual(["websocket:chat-new"]);
|
||||||
|
expect(result.current.sessions[0]?.preview).toBe("First message");
|
||||||
|
expect(result.current.sessions[0]?.title).toBe("Generated title");
|
||||||
|
});
|
||||||
|
|
||||||
it("passes through WebUI transcript user media as images and media", async () => {
|
it("passes through WebUI transcript user media as images and media", async () => {
|
||||||
vi.mocked(api.fetchWebuiThread).mockResolvedValue({
|
vi.mocked(api.fetchWebuiThread).mockResolvedValue({
|
||||||
schemaVersion: 3,
|
schemaVersion: 3,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user