mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-23 18:12:32 +00:00
Merge remote-tracking branch 'origin/main' into codex/review-pr-3929
This commit is contained in:
commit
143224e25a
14
README.md
14
README.md
@ -1,6 +1,18 @@
|
||||

|
||||
|
||||
<div align="center">
|
||||
<p>
|
||||
<a href="https://nanobot.wiki/docs/latest/getting-started/nanobot-overview">English</a> |
|
||||
<a href="https://nanobot.wiki/cn/docs/latest/getting-started/nanobot-overview">简体中文</a> |
|
||||
<a href="https://nanobot.wiki/zh-Hant/docs/latest/getting-started/nanobot-overview">繁體中文</a> |
|
||||
<a href="https://nanobot.wiki/es/docs/latest/getting-started/nanobot-overview">Español</a> |
|
||||
<a href="https://nanobot.wiki/fr/docs/latest/getting-started/nanobot-overview">Français</a> |
|
||||
<a href="https://nanobot.wiki/id/docs/latest/getting-started/nanobot-overview">Bahasa Indonesia</a> |
|
||||
<a href="https://nanobot.wiki/ja/docs/latest/getting-started/nanobot-overview">日本語</a> |
|
||||
<a href="https://nanobot.wiki/ko/docs/latest/getting-started/nanobot-overview">한국어</a> |
|
||||
<a href="https://nanobot.wiki/ru/docs/latest/getting-started/nanobot-overview">Русский</a> |
|
||||
<a href="https://nanobot.wiki/vi/docs/latest/getting-started/nanobot-overview">Tiếng Việt</a>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://pypi.org/project/nanobot-ai/"><img src="https://img.shields.io/pypi/v/nanobot-ai" alt="PyPI"></a>
|
||||
<a href="https://pepy.tech/project/nanobot-ai"><img src="https://static.pepy.tech/badge/nanobot-ai" alt="Downloads"></a>
|
||||
@ -61,7 +73,7 @@
|
||||
- **2026-04-13** 🛡️ Agent turn hardened — user messages persisted early, auto-compact skips active tasks.
|
||||
- **2026-04-12** 🔒 Lark global domain support, Dream learns discovered skills, shell sandbox tightened.
|
||||
- **2026-04-11** ⚡ Context compact shrinks sessions on the fly; Kagi web search; QQ & WeCom full media.
|
||||
- **2026-04-10** 📓 Notebook editing tool, multiple MCP servers, Feishu streaming & done-emoji.
|
||||
- **2026-04-10** 📓 Multiple MCP servers, Feishu streaming & done-emoji.
|
||||
- **2026-04-09** 🔌 WebSocket channel, unified cross-channel session, `disabled_skills` config.
|
||||
- **2026-04-08** 📤 API file uploads, OpenAI reasoning auto-routing with Responses fallback.
|
||||
- **2026-04-07** 🧠 Anthropic adaptive thinking, MCP resources & prompts exposed as tools.
|
||||
|
||||
@ -17,6 +17,7 @@ Connect nanobot to your favorite chat platform. Want to build your own? See the
|
||||
| **Wecom** | Bot ID + Bot Secret |
|
||||
| **Microsoft Teams** | App ID + App Password + public HTTPS endpoint |
|
||||
| **Mochat** | Claw token (auto-setup available) |
|
||||
| **Signal** | signal-cli daemon + phone number |
|
||||
|
||||
<details>
|
||||
<summary><b>Telegram</b> (Recommended)</summary>
|
||||
@ -669,3 +670,69 @@ nanobot gateway
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>Signal</b></summary>
|
||||
|
||||
Uses **signal-cli** daemon in HTTP mode — receive messages via SSE, send via JSON-RPC.
|
||||
|
||||
**1. Install signal-cli**
|
||||
|
||||
Install [signal-cli](https://github.com/AsamK/signal-cli) and register a phone number:
|
||||
|
||||
```bash
|
||||
signal-cli -u +1234567890 register
|
||||
signal-cli -u +1234567890 verify <CODE>
|
||||
```
|
||||
|
||||
Start the daemon:
|
||||
|
||||
```bash
|
||||
signal-cli -a +1234567890 daemon --http localhost:8080
|
||||
```
|
||||
|
||||
**2. Configure**
|
||||
|
||||
```json
|
||||
{
|
||||
"channels": {
|
||||
"signal": {
|
||||
"enabled": true,
|
||||
"phoneNumber": "+1234567890",
|
||||
"daemonHost": "localhost",
|
||||
"daemonPort": 8080,
|
||||
"dm": {
|
||||
"enabled": true,
|
||||
"policy": "open"
|
||||
},
|
||||
"group": {
|
||||
"enabled": true,
|
||||
"policy": "open",
|
||||
"requireMention": true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> - `phoneNumber`: Your registered Signal phone number.
|
||||
> - `daemonHost` / `daemonPort`: Where signal-cli daemon is listening (default `localhost:8080`).
|
||||
> - `dm.policy`: `"open"` (anyone can DM) or `"allowlist"` (only listed numbers/UUIDs). When `"allowlist"`, unlisted DM senders receive a pairing code.
|
||||
> - `dm.allowFrom`: List of allowed phone numbers or UUIDs (used when policy is `"allowlist"`).
|
||||
> - `group.policy`: `"open"` (all groups) or `"allowlist"` (only listed group IDs).
|
||||
> - `group.requireMention`: When `true` (default), the bot only responds in groups when @mentioned.
|
||||
> - `group.allowFrom`: List of allowed group IDs (used when group policy is `"allowlist"`).
|
||||
> - `attachmentsDir`: Override the directory where signal-cli stores inbound attachments. Defaults to `~/.local/share/signal-cli/attachments` (the Linux default). Set this if signal-cli runs with a custom `XDG_DATA_HOME` or on macOS/Windows.
|
||||
> - `groupMessageBufferSize`: Number of recent group messages kept for context (default `20`, must be > 0).
|
||||
|
||||
**3. Run**
|
||||
|
||||
```bash
|
||||
nanobot gateway
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> The channel automatically reconnects to the signal-cli daemon with exponential backoff if the connection drops.
|
||||
> Markdown in bot replies is automatically converted to Signal text styles (bold, italic, code, etc.).
|
||||
|
||||
</details>
|
||||
|
||||
@ -148,6 +148,7 @@ ANTHROPIC_API_KEY="$(bw get password api/anthropic)" nanobot agent
|
||||
| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) |
|
||||
| `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) |
|
||||
| `siliconflow` | LLM (SiliconFlow/硅基流动) | [siliconflow.cn](https://siliconflow.cn) |
|
||||
| `novita` | LLM (Novita AI OpenAI-compatible gateway) | [novita.ai](https://novita.ai) |
|
||||
| `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
|
||||
| `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) |
|
||||
| `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |
|
||||
|
||||
@ -23,7 +23,7 @@ The feature is disabled by default. Enable it in `~/.nanobot/config.json`, confi
|
||||
}
|
||||
```
|
||||
|
||||
See [Provider Notes](#provider-notes) for AIHubMix, MiniMax, and Gemini configuration examples.
|
||||
See [Provider Notes](#provider-notes) for AIHubMix, MiniMax, Gemini, Ollama, and StepFun configuration examples.
|
||||
|
||||
> [!TIP]
|
||||
> Prefer environment variables for API keys. nanobot resolves `${VAR_NAME}` values from the environment at startup.
|
||||
@ -46,7 +46,7 @@ The WebUI hides provider storage details from the user. The agent sees the saved
|
||||
| Option | Type | Default | Description |
|
||||
|--------|------|---------|-------------|
|
||||
| `tools.imageGeneration.enabled` | boolean | `false` | Register the `generate_image` tool |
|
||||
| `tools.imageGeneration.provider` | string | `"openrouter"` | Image provider name. Supported values: `openrouter`, `aihubmix`, `minimax`, `gemini`, `stepfun` |
|
||||
| `tools.imageGeneration.provider` | string | `"openrouter"` | Image provider name. Supported values: `openrouter`, `aihubmix`, `minimax`, `gemini`, `ollama`, `stepfun` |
|
||||
| `tools.imageGeneration.model` | string | `"openai/gpt-5.4-image-2"` | Provider model name |
|
||||
| `tools.imageGeneration.defaultAspectRatio` | string | `"1:1"` | Default ratio when the prompt/tool call does not specify one |
|
||||
| `tools.imageGeneration.defaultImageSize` | string | `"1K"` | Default size hint, for example `1K`, `2K`, `4K`, or `1024x1024` |
|
||||
@ -168,6 +168,31 @@ For reference-image edits, use a Gemini Flash image model:
|
||||
|
||||
Imagen 4 supports the aspect ratios `1:1`, `9:16`, `16:9`, `3:4`, and `4:3`. Unsupported ratios are ignored and the model uses its default. The `defaultImageSize` setting has no effect on Gemini models; sizing is controlled by `defaultAspectRatio` only. Reference images passed with an Imagen model are ignored (with a warning logged).
|
||||
|
||||
### Ollama
|
||||
|
||||
Ollama's experimental native image generation API works with local servers and hosted ollama.com models. Local access at `http://localhost:11434/api` does not require an API key; set `providers.ollama.apiKey` only when targeting `https://ollama.com/api`.
|
||||
|
||||
```json
|
||||
{
|
||||
"providers": {
|
||||
"ollama": {
|
||||
"apiBase": "http://localhost:11434/api"
|
||||
}
|
||||
},
|
||||
"tools": {
|
||||
"imageGeneration": {
|
||||
"enabled": true,
|
||||
"provider": "ollama",
|
||||
"model": "x/z-image-turbo",
|
||||
"defaultAspectRatio": "16:9",
|
||||
"defaultImageSize": "2K"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Ollama maps `defaultAspectRatio` and `defaultImageSize` to native `width` and `height` values. Reference images are not supported by this integration.
|
||||
|
||||
### StepFun
|
||||
|
||||
StepFun (阶跃星辰) `step-image-edit-2` supports text-to-image generation. The `step-1x-medium` variant additionally supports **style-reference** image edits, where a reference image guides the visual style of the output.
|
||||
@ -274,7 +299,7 @@ Use the reference image. Keep the same robot and composition, change the palette
|
||||
|---------|-------|
|
||||
| `generate_image` is not available | Set `tools.imageGeneration.enabled` to `true` and restart the gateway |
|
||||
| Missing API key error | Configure `providers.<provider>.apiKey`; if using `${VAR_NAME}`, confirm the environment variable is visible to the gateway process |
|
||||
| `unsupported image generation provider` | Use `openrouter`, `aihubmix`, `minimax`, `gemini`, or `stepfun` |
|
||||
| `unsupported image generation provider` | Use `openrouter`, `aihubmix`, `minimax`, `gemini`, `ollama`, or `stepfun` |
|
||||
| AIHubMix says `Incorrect model ID` | Use `model: "gpt-image-2-free"`; nanobot expands it to the required `openai/gpt-image-2-free` model path internally |
|
||||
| Generation times out | Try a smaller/default image size, set AIHubMix `extraBody.quality` to `"low"`, or retry later |
|
||||
| Reference image rejected | Reference image paths must be inside the workspace or nanobot media directory and must be valid image files |
|
||||
|
||||
@ -22,7 +22,7 @@ from nanobot.utils.prompt_templates import render_template
|
||||
class ContextBuilder:
|
||||
"""Builds the context (system prompt + messages) for the agent."""
|
||||
|
||||
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md"]
|
||||
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md"]
|
||||
_RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]"
|
||||
_MAX_RECENT_HISTORY = 50
|
||||
_MAX_HISTORY_CHARS = 32_000 # hard cap on recent history section size
|
||||
@ -47,6 +47,8 @@ class ContextBuilder:
|
||||
if bootstrap:
|
||||
parts.append(bootstrap)
|
||||
|
||||
parts.append(render_template("agent/tool_contract.md"))
|
||||
|
||||
memory = self.memory.get_memory_context()
|
||||
if memory and not self._is_template_content(self.memory.read_memory(), "memory/MEMORY.md"):
|
||||
parts.append(f"# Memory\n\n{memory}")
|
||||
@ -210,4 +212,3 @@ class ContextBuilder:
|
||||
if not images:
|
||||
return text
|
||||
return images + [{"type": "text", "text": text}]
|
||||
|
||||
|
||||
@ -19,7 +19,8 @@ from nanobot.utils.file_edit_events import (
|
||||
build_file_edit_end_event,
|
||||
build_file_edit_error_event,
|
||||
build_file_edit_start_event,
|
||||
prepare_file_edit_tracker,
|
||||
prepare_file_edit_tracker as _prepare_file_edit_tracker,
|
||||
prepare_file_edit_trackers,
|
||||
StreamingFileEditTracker,
|
||||
)
|
||||
from nanobot.utils.helpers import (
|
||||
@ -58,11 +59,14 @@ _SNIP_SAFETY_BUFFER = 1024
|
||||
_MICROCOMPACT_KEEP_RECENT = 10
|
||||
_MICROCOMPACT_MIN_CHARS = 500
|
||||
_COMPACTABLE_TOOLS = frozenset({
|
||||
"read_file", "exec", "grep",
|
||||
"web_search", "web_fetch", "list_dir",
|
||||
"read_file", "exec", "grep", "find_files",
|
||||
"web_search", "web_fetch", "list_dir", "list_exec_sessions",
|
||||
})
|
||||
_BACKFILL_CONTENT = "[Tool result unavailable — call was interrupted or lost]"
|
||||
|
||||
# Backward-compatible module attribute for tests/extensions that monkeypatch
|
||||
# the former single-file tracker hook. Runtime uses prepare_file_edit_trackers.
|
||||
prepare_file_edit_tracker = _prepare_file_edit_tracker
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -857,8 +861,8 @@ class AgentRunner:
|
||||
and on_progress_accepts_file_edit_events(spec.progress_callback)
|
||||
)
|
||||
progress_callback = spec.progress_callback if emit_file_edit_events else None
|
||||
file_edit_tracker = (
|
||||
prepare_file_edit_tracker(
|
||||
file_edit_trackers = (
|
||||
prepare_file_edit_trackers(
|
||||
call_id=tool_call.id,
|
||||
tool_name=tool_call.name,
|
||||
tool=tool,
|
||||
@ -868,13 +872,13 @@ class AgentRunner:
|
||||
if progress_callback is not None
|
||||
else None
|
||||
)
|
||||
if file_edit_tracker is not None and progress_callback is not None:
|
||||
if file_edit_trackers and progress_callback is not None:
|
||||
await invoke_file_edit_progress(
|
||||
progress_callback,
|
||||
[build_file_edit_start_event(
|
||||
file_edit_tracker,
|
||||
params if isinstance(params, dict) else None,
|
||||
)],
|
||||
) for file_edit_tracker in file_edit_trackers],
|
||||
)
|
||||
try:
|
||||
if tool is not None:
|
||||
@ -884,10 +888,13 @@ class AgentRunner:
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except BaseException as exc:
|
||||
if file_edit_tracker is not None and progress_callback is not None:
|
||||
if file_edit_trackers and progress_callback is not None:
|
||||
await invoke_file_edit_progress(
|
||||
progress_callback,
|
||||
[build_file_edit_error_event(file_edit_tracker, str(exc))],
|
||||
[
|
||||
build_file_edit_error_event(file_edit_tracker, str(exc))
|
||||
for file_edit_tracker in file_edit_trackers
|
||||
],
|
||||
)
|
||||
event = {
|
||||
"name": tool_call.name,
|
||||
@ -910,10 +917,13 @@ class AgentRunner:
|
||||
return payload, event, None
|
||||
|
||||
if isinstance(result, str) and result.startswith("Error"):
|
||||
if file_edit_tracker is not None and progress_callback is not None:
|
||||
if file_edit_trackers and progress_callback is not None:
|
||||
await invoke_file_edit_progress(
|
||||
progress_callback,
|
||||
[build_file_edit_error_event(file_edit_tracker, result)],
|
||||
[
|
||||
build_file_edit_error_event(file_edit_tracker, result)
|
||||
for file_edit_tracker in file_edit_trackers
|
||||
],
|
||||
)
|
||||
event = {
|
||||
"name": tool_call.name,
|
||||
@ -933,13 +943,13 @@ class AgentRunner:
|
||||
return result + hint, event, RuntimeError(result)
|
||||
return result + hint, event, None
|
||||
|
||||
if file_edit_tracker is not None and progress_callback is not None:
|
||||
if file_edit_trackers and progress_callback is not None:
|
||||
await invoke_file_edit_progress(
|
||||
progress_callback,
|
||||
[build_file_edit_end_event(
|
||||
file_edit_tracker,
|
||||
params if isinstance(params, dict) else None,
|
||||
)],
|
||||
) for file_edit_tracker in file_edit_trackers],
|
||||
)
|
||||
|
||||
detail = "" if result is None else str(result)
|
||||
|
||||
352
nanobot/agent/tools/apply_patch.py
Normal file
352
nanobot/agent/tools/apply_patch.py
Normal file
@ -0,0 +1,352 @@
|
||||
"""Apply file edits by providing structured edit instructions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import difflib
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import tool_parameters
|
||||
from nanobot.agent.tools.filesystem import _FsTool
|
||||
from nanobot.agent.tools.schema import (
|
||||
ArraySchema,
|
||||
BooleanSchema,
|
||||
ObjectSchema,
|
||||
StringSchema,
|
||||
tool_parameters_schema,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _PatchSummary:
|
||||
action: str
|
||||
path: str
|
||||
added: int = 0
|
||||
deleted: int = 0
|
||||
|
||||
|
||||
class _PatchError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
_ABSOLUTE_WINDOWS_RE = re.compile(r"^[A-Za-z]:[\\/]")
|
||||
|
||||
|
||||
def _validate_relative_path(path: str) -> str:
|
||||
normalized = path.strip()
|
||||
if not normalized:
|
||||
raise _PatchError("patch path cannot be empty")
|
||||
if "\0" in normalized:
|
||||
raise _PatchError(f"patch path contains a null byte: {path!r}")
|
||||
if normalized.startswith(("~", "/", "\\")) or _ABSOLUTE_WINDOWS_RE.match(normalized):
|
||||
raise _PatchError(f"patch path must be relative: {path}")
|
||||
if any(part == ".." for part in re.split(r"[\\/]+", normalized)):
|
||||
raise _PatchError(f"patch path must not contain '..': {path}")
|
||||
return normalized
|
||||
|
||||
|
||||
def _lines_to_text(lines: list[str]) -> str:
|
||||
if not lines:
|
||||
return ""
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
def _text_line_count(text: str) -> int:
|
||||
if not text:
|
||||
return 0
|
||||
return len(text.splitlines())
|
||||
|
||||
|
||||
def _line_diff_stats(before: str, after: str) -> tuple[int, int]:
|
||||
before_lines = before.replace("\r\n", "\n").splitlines()
|
||||
after_lines = after.replace("\r\n", "\n").splitlines()
|
||||
added = 0
|
||||
deleted = 0
|
||||
matcher = difflib.SequenceMatcher(a=before_lines, b=after_lines, autojunk=False)
|
||||
for tag, i1, i2, j1, j2 in matcher.get_opcodes():
|
||||
if tag == "equal":
|
||||
continue
|
||||
if tag in ("replace", "delete"):
|
||||
deleted += i2 - i1
|
||||
if tag in ("replace", "insert"):
|
||||
added += j2 - j1
|
||||
return added, deleted
|
||||
|
||||
|
||||
def _format_summary(summary: _PatchSummary) -> str:
|
||||
stats = ""
|
||||
if summary.added or summary.deleted:
|
||||
stats = f" (+{summary.added}/-{summary.deleted})"
|
||||
return f"- {summary.action} {summary.path}{stats}"
|
||||
|
||||
|
||||
@tool_parameters(
|
||||
tool_parameters_schema(
|
||||
edits=ArraySchema(
|
||||
items=ObjectSchema(
|
||||
path=StringSchema("Relative path to the file to edit."),
|
||||
action=StringSchema(
|
||||
"Operation type: replace (find and replace text), add (append new content or create file), delete (remove text).",
|
||||
enum=["replace", "add", "delete"],
|
||||
),
|
||||
old_text=StringSchema(
|
||||
"Exact text to search for in the file. Required for replace and delete.",
|
||||
nullable=True,
|
||||
),
|
||||
new_text=StringSchema(
|
||||
"Text to replace with or append. Required for replace and add.",
|
||||
nullable=True,
|
||||
),
|
||||
required=["path", "action"],
|
||||
),
|
||||
description="List of edits to apply. Each edit specifies a file and the change to make.",
|
||||
min_items=1,
|
||||
max_items=20,
|
||||
),
|
||||
dry_run=BooleanSchema(
|
||||
description="Validate and summarize the patch without writing files.",
|
||||
default=False,
|
||||
),
|
||||
required=["edits"],
|
||||
)
|
||||
)
|
||||
class ApplyPatchTool(_FsTool):
|
||||
"""Apply file edits by providing structured edit instructions."""
|
||||
_scopes = {"core", "subagent"}
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "apply_patch"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Default tool for code edits. Supports multi-file changes in a single call. "
|
||||
"Provide a list of structured edits, each specifying a file path, action (replace/add/delete), and the text to change. "
|
||||
"Paths must be relative. Set dry_run=true to validate and preview without writing files. "
|
||||
"Use edit_file only for small exact replacements on a single file."
|
||||
)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
edits: list[dict] | None = None,
|
||||
dry_run: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
try:
|
||||
if not edits:
|
||||
raise _PatchError("must provide edits")
|
||||
|
||||
writes: dict[Path, str] = {}
|
||||
deletes: set[Path] = set()
|
||||
summaries: list[_PatchSummary] = []
|
||||
|
||||
for edit in edits:
|
||||
if not isinstance(edit, dict):
|
||||
raise _PatchError("each edit must be an object")
|
||||
raw_path = edit.get("path")
|
||||
if not isinstance(raw_path, str):
|
||||
raise _PatchError("path required for edit")
|
||||
path = _validate_relative_path(raw_path)
|
||||
action = edit.get("action")
|
||||
if not isinstance(action, str):
|
||||
raise _PatchError(f"action required for edit: {path}")
|
||||
source = self._resolve(path)
|
||||
|
||||
if action == "add":
|
||||
new_text = edit.get("new_text")
|
||||
if new_text is None:
|
||||
raise _PatchError(f"new_text required for add: {path}")
|
||||
|
||||
pending = writes.get(source)
|
||||
if pending is not None:
|
||||
content = pending
|
||||
exists = True
|
||||
elif source.exists():
|
||||
raw = source.read_bytes()
|
||||
try:
|
||||
content = raw.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
raise _PatchError(f"file is not UTF-8 text: {path}")
|
||||
exists = True
|
||||
else:
|
||||
content = ""
|
||||
exists = False
|
||||
|
||||
if exists:
|
||||
uses_crlf = "\r\n" in content
|
||||
new_norm = content.replace("\r\n", "\n") + new_text.replace("\r\n", "\n")
|
||||
if new_norm and not new_norm.endswith("\n"):
|
||||
new_norm += "\n"
|
||||
if uses_crlf:
|
||||
new_norm = new_norm.replace("\n", "\r\n")
|
||||
writes[source] = new_norm
|
||||
deletes.discard(source)
|
||||
added, deleted = _line_diff_stats(content, new_norm)
|
||||
action_name = "update"
|
||||
else:
|
||||
new_norm = new_text.replace("\r\n", "\n")
|
||||
if new_norm and not new_norm.endswith("\n"):
|
||||
new_norm += "\n"
|
||||
writes[source] = new_norm
|
||||
deletes.discard(source)
|
||||
added = _text_line_count(new_norm)
|
||||
deleted = 0
|
||||
action_name = "add"
|
||||
|
||||
summaries.append(
|
||||
_PatchSummary(
|
||||
action=action_name, path=path, added=added, deleted=deleted
|
||||
)
|
||||
)
|
||||
|
||||
elif action == "replace":
|
||||
old_text = edit.get("old_text") or ""
|
||||
if not old_text:
|
||||
raise _PatchError(f"old_text required for replace: {path}")
|
||||
new_text = edit.get("new_text")
|
||||
if new_text is None:
|
||||
raise _PatchError(f"new_text required for replace: {path}")
|
||||
|
||||
pending = writes.get(source)
|
||||
if pending is not None:
|
||||
content = pending
|
||||
elif source.exists():
|
||||
raw = source.read_bytes()
|
||||
try:
|
||||
content = raw.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
raise _PatchError(f"file is not UTF-8 text: {path}")
|
||||
else:
|
||||
raise _PatchError(f"file to update does not exist: {path}")
|
||||
|
||||
if pending is None and not source.is_file():
|
||||
raise _PatchError(f"path to update is not a file: {path}")
|
||||
|
||||
uses_crlf = "\r\n" in content
|
||||
norm_content = content.replace("\r\n", "\n")
|
||||
norm_old = old_text.replace("\r\n", "\n")
|
||||
|
||||
pos = norm_content.find(norm_old)
|
||||
if pos < 0:
|
||||
raise _PatchError(f"old_text not found in {path}")
|
||||
if norm_content.find(norm_old, pos + 1) >= 0:
|
||||
raise _PatchError(f"old_text appears multiple times in {path}")
|
||||
|
||||
new_norm = (
|
||||
norm_content[:pos]
|
||||
+ new_text.replace("\r\n", "\n")
|
||||
+ norm_content[pos + len(norm_old) :]
|
||||
)
|
||||
if new_norm and not new_norm.endswith("\n"):
|
||||
new_norm += "\n"
|
||||
if uses_crlf:
|
||||
new_norm = new_norm.replace("\n", "\r\n")
|
||||
|
||||
writes[source] = new_norm
|
||||
deletes.discard(source)
|
||||
added, deleted = _line_diff_stats(content, new_norm)
|
||||
summaries.append(
|
||||
_PatchSummary(
|
||||
action="update", path=path, added=added, deleted=deleted
|
||||
)
|
||||
)
|
||||
|
||||
elif action == "delete":
|
||||
old_text = edit.get("old_text") or ""
|
||||
if not old_text:
|
||||
raise _PatchError(f"old_text required for delete: {path}")
|
||||
|
||||
pending = writes.get(source)
|
||||
if pending is not None:
|
||||
content = pending
|
||||
elif source.exists():
|
||||
raw = source.read_bytes()
|
||||
try:
|
||||
content = raw.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
raise _PatchError(f"file is not UTF-8 text: {path}")
|
||||
else:
|
||||
raise _PatchError(f"file to update does not exist: {path}")
|
||||
|
||||
if pending is None and not source.is_file():
|
||||
raise _PatchError(f"path to update is not a file: {path}")
|
||||
|
||||
uses_crlf = "\r\n" in content
|
||||
norm_content = content.replace("\r\n", "\n")
|
||||
norm_old = old_text.replace("\r\n", "\n")
|
||||
|
||||
pos = norm_content.find(norm_old)
|
||||
if pos < 0:
|
||||
raise _PatchError(f"old_text not found in {path}")
|
||||
if norm_content.find(norm_old, pos + 1) >= 0:
|
||||
raise _PatchError(f"old_text appears multiple times in {path}")
|
||||
|
||||
if norm_old == norm_content:
|
||||
deletes.add(source)
|
||||
writes.pop(source, None)
|
||||
added, deleted = 0, _text_line_count(content)
|
||||
summaries.append(
|
||||
_PatchSummary(
|
||||
action="delete", path=path, added=added, deleted=deleted
|
||||
)
|
||||
)
|
||||
else:
|
||||
new_norm = (
|
||||
norm_content[:pos] + norm_content[pos + len(norm_old) :]
|
||||
)
|
||||
if new_norm and not new_norm.endswith("\n"):
|
||||
new_norm += "\n"
|
||||
if uses_crlf:
|
||||
new_norm = new_norm.replace("\n", "\r\n")
|
||||
writes[source] = new_norm
|
||||
deletes.discard(source)
|
||||
added, deleted = _line_diff_stats(content, new_norm)
|
||||
summaries.append(
|
||||
_PatchSummary(
|
||||
action="update", path=path, added=added, deleted=deleted
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
raise _PatchError(f"unknown action: {action}")
|
||||
|
||||
if dry_run:
|
||||
return "Patch dry-run succeeded:\n" + "\n".join(
|
||||
_format_summary(summary) for summary in summaries
|
||||
)
|
||||
|
||||
backups: dict[Path, bytes | None] = {}
|
||||
for path in set(writes) | deletes:
|
||||
backups[path] = path.read_bytes() if path.exists() else None
|
||||
|
||||
try:
|
||||
for path in deletes:
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
for path, content in writes.items():
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(content, encoding="utf-8", newline="")
|
||||
except Exception:
|
||||
for path, data in backups.items():
|
||||
if data is None:
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
else:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_bytes(data)
|
||||
raise
|
||||
|
||||
for path in set(writes) | deletes:
|
||||
self._file_states.record_write(path)
|
||||
return "Patch applied:\n" + "\n".join(
|
||||
_format_summary(summary) for summary in summaries
|
||||
)
|
||||
except PermissionError as exc:
|
||||
return f"Error: {exc}"
|
||||
except _PatchError as exc:
|
||||
return f"Error applying patch: {exc}"
|
||||
except Exception as exc:
|
||||
return f"Error applying patch: {exc}"
|
||||
591
nanobot/agent/tools/exec_session.py
Normal file
591
nanobot/agent/tools/exec_session.py
Normal file
@ -0,0 +1,591 @@
|
||||
"""Session support for long-running exec workflows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import shutil
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||
from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema
|
||||
|
||||
|
||||
DEFAULT_YIELD_MS = 1000
|
||||
MAX_YIELD_MS = 30_000
|
||||
DEFAULT_WAIT_FOR_MS = 10_000
|
||||
MAX_WAIT_FOR_MS = 120_000
|
||||
DEFAULT_MAX_OUTPUT_CHARS = 10_000
|
||||
MAX_OUTPUT_CHARS = 50_000
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _SessionPoll:
|
||||
output: str
|
||||
done: bool
|
||||
exit_code: int | None
|
||||
elapsed_s: float = 0.0
|
||||
timed_out: bool = False
|
||||
terminated: bool = False
|
||||
stdin_closed: bool = False
|
||||
truncated_chars: int = 0
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ExecSessionInfo:
|
||||
session_id: str
|
||||
command: str
|
||||
cwd: str
|
||||
elapsed_s: float
|
||||
idle_s: float
|
||||
remaining_s: float
|
||||
returncode: int | None
|
||||
|
||||
|
||||
class _ExecSession:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
process: asyncio.subprocess.Process,
|
||||
command: str,
|
||||
cwd: str,
|
||||
timeout: int,
|
||||
) -> None:
|
||||
self.session_id = session_id
|
||||
self.process = process
|
||||
self.command = command
|
||||
self.cwd = cwd
|
||||
self.started_at = time.monotonic()
|
||||
self.deadline = time.monotonic() + timeout
|
||||
self.last_access = time.monotonic()
|
||||
self._chunks: list[str] = []
|
||||
self._lock = asyncio.Lock()
|
||||
self._timed_out = False
|
||||
self._stdout_task = asyncio.create_task(self._read_stream(process.stdout, ""))
|
||||
self._stderr_task = asyncio.create_task(self._read_stream(process.stderr, "STDERR:\n"))
|
||||
|
||||
async def _read_stream(
|
||||
self,
|
||||
stream: asyncio.StreamReader | None,
|
||||
prefix: str,
|
||||
) -> None:
|
||||
if stream is None:
|
||||
return
|
||||
first = True
|
||||
while True:
|
||||
chunk = await stream.read(4096)
|
||||
if not chunk:
|
||||
break
|
||||
text = chunk.decode("utf-8", errors="replace")
|
||||
if prefix and first:
|
||||
text = prefix + text
|
||||
first = False
|
||||
async with self._lock:
|
||||
self._chunks.append(text)
|
||||
|
||||
async def write(self, chars: str) -> str | None:
|
||||
if self.process.returncode is not None:
|
||||
return "session has already exited"
|
||||
if self.process.stdin is None:
|
||||
return "session stdin is not available"
|
||||
try:
|
||||
self.process.stdin.write(chars.encode("utf-8"))
|
||||
await self.process.stdin.drain()
|
||||
except (BrokenPipeError, ConnectionResetError):
|
||||
return "session stdin is closed"
|
||||
return None
|
||||
|
||||
async def close_stdin(self) -> str | None:
|
||||
if self.process.returncode is not None:
|
||||
return "session has already exited"
|
||||
if self.process.stdin is None:
|
||||
return "session stdin is not available"
|
||||
self.process.stdin.close()
|
||||
with suppress(BrokenPipeError, ConnectionResetError):
|
||||
await self.process.stdin.wait_closed()
|
||||
return None
|
||||
|
||||
async def poll(
|
||||
self,
|
||||
yield_time_ms: int,
|
||||
max_output_chars: int,
|
||||
*,
|
||||
terminated: bool = False,
|
||||
stdin_closed: bool = False,
|
||||
) -> _SessionPoll:
|
||||
self.last_access = time.monotonic()
|
||||
if yield_time_ms > 0 and self.process.returncode is None:
|
||||
await asyncio.sleep(min(yield_time_ms, MAX_YIELD_MS) / 1000)
|
||||
|
||||
if self.process.returncode is None and time.monotonic() >= self.deadline:
|
||||
self._timed_out = True
|
||||
await self.kill()
|
||||
|
||||
if self.process.returncode is not None:
|
||||
with suppress(asyncio.TimeoutError):
|
||||
await asyncio.wait_for(
|
||||
asyncio.gather(self._stdout_task, self._stderr_task),
|
||||
timeout=2.0,
|
||||
)
|
||||
|
||||
async with self._lock:
|
||||
output = "".join(self._chunks)
|
||||
self._chunks.clear()
|
||||
|
||||
output, truncated = _truncate_output(output, max_output_chars)
|
||||
return _SessionPoll(
|
||||
output=output,
|
||||
done=self.process.returncode is not None,
|
||||
exit_code=self.process.returncode,
|
||||
elapsed_s=max(0.0, time.monotonic() - self.started_at),
|
||||
timed_out=self._timed_out,
|
||||
terminated=terminated,
|
||||
stdin_closed=stdin_closed,
|
||||
truncated_chars=truncated,
|
||||
)
|
||||
|
||||
async def kill(self) -> None:
|
||||
if self.process.returncode is not None:
|
||||
return
|
||||
self.process.kill()
|
||||
with suppress(asyncio.TimeoutError):
|
||||
await asyncio.wait_for(self.process.wait(), timeout=5.0)
|
||||
|
||||
|
||||
class ExecSessionManager:
|
||||
def __init__(self, *, max_sessions: int = 8, idle_timeout: int = 1800) -> None:
|
||||
self.max_sessions = max_sessions
|
||||
self.idle_timeout = idle_timeout
|
||||
self._sessions: dict[str, _ExecSession] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def start(
|
||||
self,
|
||||
*,
|
||||
command: str,
|
||||
cwd: str,
|
||||
env: dict[str, str],
|
||||
timeout: int,
|
||||
shell_program: str | None,
|
||||
login: bool,
|
||||
yield_time_ms: int,
|
||||
max_output_chars: int,
|
||||
) -> tuple[str, _SessionPoll]:
|
||||
async with self._lock:
|
||||
await self._cleanup_locked()
|
||||
if len(self._sessions) >= self.max_sessions:
|
||||
raise RuntimeError(f"maximum exec sessions reached ({self.max_sessions})")
|
||||
process = await self._spawn(command, cwd, env, shell_program, login)
|
||||
session_id = uuid.uuid4().hex[:12]
|
||||
session = _ExecSession(
|
||||
session_id=session_id,
|
||||
process=process,
|
||||
command=command,
|
||||
cwd=cwd,
|
||||
timeout=timeout,
|
||||
)
|
||||
self._sessions[session_id] = session
|
||||
|
||||
poll = await session.poll(yield_time_ms, max_output_chars)
|
||||
if poll.done:
|
||||
async with self._lock:
|
||||
self._sessions.pop(session_id, None)
|
||||
return session_id, poll
|
||||
|
||||
async def write(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
chars: str | None,
|
||||
close_stdin: bool,
|
||||
terminate: bool,
|
||||
yield_time_ms: int,
|
||||
max_output_chars: int,
|
||||
) -> _SessionPoll:
|
||||
async with self._lock:
|
||||
await self._cleanup_locked()
|
||||
session = self._sessions.get(session_id)
|
||||
if session is None:
|
||||
raise KeyError(session_id)
|
||||
|
||||
if chars:
|
||||
error = await session.write(chars)
|
||||
if error:
|
||||
raise RuntimeError(error)
|
||||
stdin_closed = False
|
||||
if close_stdin:
|
||||
error = await session.close_stdin()
|
||||
if error:
|
||||
raise RuntimeError(error)
|
||||
stdin_closed = True
|
||||
if terminate:
|
||||
await session.kill()
|
||||
poll = await session.poll(
|
||||
yield_time_ms,
|
||||
max_output_chars,
|
||||
terminated=terminate,
|
||||
stdin_closed=stdin_closed,
|
||||
)
|
||||
if poll.done:
|
||||
async with self._lock:
|
||||
self._sessions.pop(session_id, None)
|
||||
return poll
|
||||
|
||||
async def list(self) -> list[ExecSessionInfo]:
|
||||
async with self._lock:
|
||||
await self._cleanup_locked()
|
||||
now = time.monotonic()
|
||||
return [
|
||||
ExecSessionInfo(
|
||||
session_id=session_id,
|
||||
command=session.command,
|
||||
cwd=session.cwd,
|
||||
elapsed_s=max(0.0, now - session.started_at),
|
||||
idle_s=max(0.0, now - session.last_access),
|
||||
remaining_s=max(0.0, session.deadline - now),
|
||||
returncode=session.process.returncode,
|
||||
)
|
||||
for session_id, session in sorted(self._sessions.items())
|
||||
]
|
||||
|
||||
async def _cleanup_locked(self) -> None:
|
||||
now = time.monotonic()
|
||||
stale = [
|
||||
session_id
|
||||
for session_id, session in self._sessions.items()
|
||||
if now - session.last_access > self.idle_timeout
|
||||
]
|
||||
for session_id in stale:
|
||||
session = self._sessions.pop(session_id)
|
||||
await session.kill()
|
||||
|
||||
async def _spawn(
|
||||
self,
|
||||
command: str,
|
||||
cwd: str,
|
||||
env: dict[str, str],
|
||||
shell_program: str | None,
|
||||
login: bool,
|
||||
) -> asyncio.subprocess.Process:
|
||||
from nanobot.agent.tools import shell
|
||||
|
||||
if shell._IS_WINDOWS:
|
||||
return await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
)
|
||||
shell_program = shell_program or shutil.which("bash") or "/bin/bash"
|
||||
args = [shell_program]
|
||||
if login and shell_program.rsplit("/", 1)[-1] in {"bash", "zsh"}:
|
||||
args.append("-l")
|
||||
args.extend(["-c", command])
|
||||
return await asyncio.create_subprocess_exec(
|
||||
*args,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_EXEC_SESSION_MANAGER = ExecSessionManager()
|
||||
|
||||
|
||||
def clamp_session_int(value: int | None, default: int, minimum: int, maximum: int) -> int:
|
||||
if value is None:
|
||||
return default
|
||||
return min(max(value, minimum), maximum)
|
||||
|
||||
|
||||
def _truncate_output(output: str, max_output_chars: int) -> tuple[str, int]:
|
||||
if len(output) <= max_output_chars:
|
||||
return output, 0
|
||||
half = max_output_chars // 2
|
||||
omitted = len(output) - max_output_chars
|
||||
return (
|
||||
output[:half]
|
||||
+ f"\n\n... ({omitted:,} chars truncated) ...\n\n"
|
||||
+ output[-half:],
|
||||
omitted,
|
||||
)
|
||||
|
||||
|
||||
def format_session_poll(session_id: str, poll: _SessionPoll) -> str:
|
||||
parts = [poll.output] if poll.output else []
|
||||
if poll.truncated_chars:
|
||||
parts.append(f"(output truncated by {poll.truncated_chars:,} chars)")
|
||||
if poll.timed_out:
|
||||
parts.append("Error: Command timed out; session was terminated.")
|
||||
if poll.terminated and not poll.timed_out:
|
||||
parts.append("Session terminated.")
|
||||
if poll.stdin_closed:
|
||||
parts.append("Stdin closed.")
|
||||
if poll.done:
|
||||
parts.append(f"Exit code: {poll.exit_code}")
|
||||
else:
|
||||
parts.append(f"Process running. session_id: {session_id}")
|
||||
parts.append(f"Elapsed: {poll.elapsed_s:.1f}s")
|
||||
return "\n".join(parts) if parts else "(no output yet)"
|
||||
|
||||
|
||||
@tool_parameters(
|
||||
tool_parameters_schema(
|
||||
session_id=StringSchema("Session id returned by exec when yield_time_ms is used."),
|
||||
chars=StringSchema(
|
||||
"Bytes/text to write to stdin. Omit or pass an empty string to only poll recent output.",
|
||||
nullable=True,
|
||||
),
|
||||
close_stdin=BooleanSchema(
|
||||
description="Close stdin after writing chars. Useful for commands waiting for EOF.",
|
||||
default=False,
|
||||
),
|
||||
terminate=BooleanSchema(
|
||||
description="Terminate the running exec session.",
|
||||
default=False,
|
||||
),
|
||||
yield_time_ms=IntegerSchema(
|
||||
DEFAULT_YIELD_MS,
|
||||
description="Milliseconds to wait before returning recent output (default 1000, max 30000).",
|
||||
minimum=0,
|
||||
maximum=MAX_YIELD_MS,
|
||||
),
|
||||
wait_for=StringSchema(
|
||||
"Optional text to wait for in output before returning. "
|
||||
"Useful for interactive commands and dev servers.",
|
||||
nullable=True,
|
||||
),
|
||||
wait_timeout_ms=IntegerSchema(
|
||||
DEFAULT_WAIT_FOR_MS,
|
||||
description="Maximum milliseconds to wait for wait_for text (default 10000, max 120000).",
|
||||
minimum=0,
|
||||
maximum=MAX_WAIT_FOR_MS,
|
||||
nullable=True,
|
||||
),
|
||||
max_output_chars=IntegerSchema(
|
||||
DEFAULT_MAX_OUTPUT_CHARS,
|
||||
description="Maximum output characters to return from this poll (default 10000, max 50000).",
|
||||
minimum=1000,
|
||||
maximum=MAX_OUTPUT_CHARS,
|
||||
),
|
||||
max_output_tokens=IntegerSchema(
|
||||
DEFAULT_MAX_OUTPUT_CHARS,
|
||||
description="Compatibility alias for max_output_chars. The current runtime uses a character budget.",
|
||||
minimum=1000,
|
||||
maximum=MAX_OUTPUT_CHARS,
|
||||
nullable=True,
|
||||
),
|
||||
required=["session_id"],
|
||||
)
|
||||
)
|
||||
class WriteStdinTool(Tool):
|
||||
"""Write to or poll a running exec session."""
|
||||
|
||||
_scopes = {"core", "subagent"}
|
||||
config_key = "exec"
|
||||
|
||||
@classmethod
|
||||
def config_cls(cls):
|
||||
from nanobot.agent.tools.shell import ExecToolConfig
|
||||
|
||||
return ExecToolConfig
|
||||
|
||||
@classmethod
|
||||
def enabled(cls, ctx: Any) -> bool:
|
||||
return ctx.config.exec.enable
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
manager: ExecSessionManager | None = None,
|
||||
) -> None:
|
||||
self._manager = manager or DEFAULT_EXEC_SESSION_MANAGER
|
||||
|
||||
@classmethod
|
||||
def create(cls, ctx: Any) -> Tool:
|
||||
return cls()
|
||||
|
||||
@property
|
||||
def exclusive(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "write_stdin"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Interact with a running exec session created by exec with "
|
||||
"yield_time_ms. Use chars='' to poll without writing, chars to send "
|
||||
"stdin, close_stdin=true to send EOF, or terminate=true to stop the "
|
||||
"process. Use wait_for with wait_timeout_ms for dev servers, test "
|
||||
"watchers, and prompts where you need to wait for expected output. "
|
||||
"Do not use this to start new commands; start them with exec."
|
||||
)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
session_id: str,
|
||||
chars: str | None = None,
|
||||
close_stdin: bool = False,
|
||||
terminate: bool = False,
|
||||
yield_time_ms: int | None = None,
|
||||
wait_for: str | None = None,
|
||||
wait_timeout_ms: int | None = None,
|
||||
max_output_chars: int | None = None,
|
||||
max_output_tokens: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
try:
|
||||
if max_output_chars is None:
|
||||
max_output_chars = max_output_tokens
|
||||
output_limit = clamp_session_int(
|
||||
max_output_chars,
|
||||
DEFAULT_MAX_OUTPUT_CHARS,
|
||||
1000,
|
||||
MAX_OUTPUT_CHARS,
|
||||
)
|
||||
if wait_for:
|
||||
return await self._wait_for_output(
|
||||
session_id=session_id,
|
||||
chars=chars,
|
||||
close_stdin=close_stdin,
|
||||
terminate=terminate,
|
||||
wait_for=wait_for,
|
||||
wait_timeout_ms=clamp_session_int(
|
||||
wait_timeout_ms,
|
||||
DEFAULT_WAIT_FOR_MS,
|
||||
0,
|
||||
MAX_WAIT_FOR_MS,
|
||||
),
|
||||
max_output_chars=output_limit,
|
||||
)
|
||||
poll = await self._manager.write(
|
||||
session_id=session_id,
|
||||
chars=chars,
|
||||
close_stdin=close_stdin,
|
||||
terminate=terminate,
|
||||
yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS),
|
||||
max_output_chars=output_limit,
|
||||
)
|
||||
return format_session_poll(session_id, poll)
|
||||
except KeyError:
|
||||
return f"Error: exec session not found: {session_id}"
|
||||
except Exception as exc:
|
||||
return f"Error writing to exec session: {exc}"
|
||||
|
||||
async def _wait_for_output(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
chars: str | None,
|
||||
close_stdin: bool,
|
||||
terminate: bool,
|
||||
wait_for: str,
|
||||
wait_timeout_ms: int,
|
||||
max_output_chars: int,
|
||||
) -> str:
|
||||
deadline = time.monotonic() + (wait_timeout_ms / 1000)
|
||||
aggregate: list[str] = []
|
||||
first = True
|
||||
poll: _SessionPoll | None = None
|
||||
|
||||
while True:
|
||||
remaining_ms = max(0, int((deadline - time.monotonic()) * 1000))
|
||||
step_ms = min(500, remaining_ms)
|
||||
poll = await self._manager.write(
|
||||
session_id=session_id,
|
||||
chars=chars if first else None,
|
||||
close_stdin=close_stdin if first else False,
|
||||
terminate=terminate if first else False,
|
||||
yield_time_ms=step_ms,
|
||||
max_output_chars=max_output_chars,
|
||||
)
|
||||
first = False
|
||||
if poll.output:
|
||||
aggregate.append(poll.output)
|
||||
joined = "".join(aggregate)
|
||||
if wait_for in joined:
|
||||
poll.output = joined
|
||||
return format_session_poll(session_id, poll)
|
||||
if poll.done or remaining_ms <= 0:
|
||||
poll.output = "".join(aggregate)
|
||||
result = format_session_poll(session_id, poll)
|
||||
if wait_for not in poll.output:
|
||||
result += f"\nWait target not observed: {wait_for!r}"
|
||||
return result
|
||||
|
||||
|
||||
@tool_parameters(tool_parameters_schema())
|
||||
class ListExecSessionsTool(Tool):
|
||||
"""List active exec sessions."""
|
||||
|
||||
_scopes = {"core", "subagent"}
|
||||
config_key = "exec"
|
||||
|
||||
@classmethod
|
||||
def config_cls(cls):
|
||||
from nanobot.agent.tools.shell import ExecToolConfig
|
||||
|
||||
return ExecToolConfig
|
||||
|
||||
@classmethod
|
||||
def enabled(cls, ctx: Any) -> bool:
|
||||
return ctx.config.exec.enable
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
manager: ExecSessionManager | None = None,
|
||||
) -> None:
|
||||
self._manager = manager or DEFAULT_EXEC_SESSION_MANAGER
|
||||
|
||||
@classmethod
|
||||
def create(cls, ctx: Any) -> Tool:
|
||||
return cls()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "list_exec_sessions"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"List active long-running exec sessions, including session_id, cwd, "
|
||||
"elapsed time, idle time, remaining timeout, and command preview. "
|
||||
"Use this to recover a session_id after context shifts before "
|
||||
"polling, writing stdin, or terminating with write_stdin."
|
||||
)
|
||||
|
||||
@property
|
||||
def read_only(self) -> bool:
|
||||
return True
|
||||
|
||||
async def execute(self, **kwargs: Any) -> str:
|
||||
try:
|
||||
sessions = await self._manager.list()
|
||||
if not sessions:
|
||||
return "No active exec sessions."
|
||||
lines = []
|
||||
for info in sessions:
|
||||
command = " ".join(info.command.split())
|
||||
if len(command) > 120:
|
||||
command = command[:119] + "..."
|
||||
status = "exited" if info.returncode is not None else "running"
|
||||
lines.append(
|
||||
f"{info.session_id} | {status} | elapsed={info.elapsed_s:.1f}s "
|
||||
f"| idle={info.idle_s:.1f}s | remaining={info.remaining_s:.1f}s "
|
||||
f"| cwd={info.cwd} | {command}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
except Exception as exc:
|
||||
return f"Error listing exec sessions: {exc}"
|
||||
@ -132,6 +132,10 @@ def _parse_page_range(pages: str, total: int) -> tuple[int, int]:
|
||||
minimum=1,
|
||||
),
|
||||
pages=StringSchema("Page range for PDF files, e.g. '1-5' (default: all, max 20 pages)"),
|
||||
force=BooleanSchema(
|
||||
description="Bypass same-file read deduplication and return content again.",
|
||||
default=False,
|
||||
),
|
||||
required=["path"],
|
||||
)
|
||||
)
|
||||
@ -154,7 +158,11 @@ class ReadFileTool(_FsTool):
|
||||
"Text output format: LINE_NUM|CONTENT. "
|
||||
"Images return visual content for analysis. "
|
||||
"Supports PDF, DOCX, XLSX, PPTX documents. "
|
||||
"Use find_files/list_dir first when the path is uncertain. "
|
||||
"Read the relevant range before editing so replacements or patches "
|
||||
"are based on current content. "
|
||||
"Use offset and limit for large text files. "
|
||||
"Use force=true to re-read content even if unchanged. "
|
||||
"Reads exceeding ~128K chars are truncated."
|
||||
)
|
||||
|
||||
@ -162,7 +170,15 @@ class ReadFileTool(_FsTool):
|
||||
def read_only(self) -> bool:
|
||||
return True
|
||||
|
||||
async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, pages: str | None = None, **kwargs: Any) -> Any:
|
||||
async def execute(
|
||||
self,
|
||||
path: str | None = None,
|
||||
offset: int = 1,
|
||||
limit: int | None = None,
|
||||
pages: str | None = None,
|
||||
force: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
try:
|
||||
if not path:
|
||||
return "Error reading file: Unknown path"
|
||||
@ -202,7 +218,13 @@ class ReadFileTool(_FsTool):
|
||||
current_mtime = os.path.getmtime(fp)
|
||||
except OSError:
|
||||
current_mtime = 0.0
|
||||
if entry and entry.can_dedup and entry.offset == offset and entry.limit == limit:
|
||||
if (
|
||||
not force
|
||||
and entry
|
||||
and entry.can_dedup
|
||||
and entry.offset == offset
|
||||
and entry.limit == limit
|
||||
):
|
||||
if current_mtime != entry.mtime:
|
||||
# File was modified externally - force full read and mark as not dedupable
|
||||
entry.can_dedup = False
|
||||
@ -365,9 +387,10 @@ class WriteFileTool(_FsTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Write content to a file. Overwrites if the file already exists; "
|
||||
"creates parent directories as needed. "
|
||||
"For partial edits, prefer edit_file instead."
|
||||
"Create a new file or intentionally replace an entire file with "
|
||||
"the provided content. Overwrites existing files and creates parent "
|
||||
"directories as needed. For code changes or partial edits, prefer "
|
||||
"apply_patch; use edit_file only for small exact replacements."
|
||||
)
|
||||
|
||||
async def execute(self, path: str | None = None, content: str | None = None, **kwargs: Any) -> str:
|
||||
@ -657,6 +680,24 @@ def _find_match(content: str, old_text: str) -> tuple[str | None, int]:
|
||||
old_text=StringSchema("The text to find and replace"),
|
||||
new_text=StringSchema("The text to replace with"),
|
||||
replace_all=BooleanSchema(description="Replace all occurrences (default false)"),
|
||||
occurrence=IntegerSchema(
|
||||
1,
|
||||
description="Optional 1-based occurrence to replace when old_text appears multiple times.",
|
||||
minimum=1,
|
||||
nullable=True,
|
||||
),
|
||||
line_hint=IntegerSchema(
|
||||
1,
|
||||
description="Optional 1-based line hint used to choose the nearest match.",
|
||||
minimum=1,
|
||||
nullable=True,
|
||||
),
|
||||
expected_replacements=IntegerSchema(
|
||||
1,
|
||||
description="Optional guard for the number of replacements that must be made.",
|
||||
minimum=1,
|
||||
nullable=True,
|
||||
),
|
||||
required=["path", "old_text", "new_text"],
|
||||
)
|
||||
)
|
||||
@ -674,10 +715,13 @@ class EditFileTool(_FsTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit a file by replacing old_text with new_text. "
|
||||
"Tolerates minor whitespace/indentation differences and curly/straight quote mismatches. "
|
||||
"If old_text matches multiple times, you must provide more context "
|
||||
"or set replace_all=true. Shows a diff of the closest match on failure."
|
||||
"Perform a small, exact replacement in one file by replacing "
|
||||
"old_text with new_text. Use this for narrow text substitutions "
|
||||
"with old_text copied from read_file. For multi-file, structural, "
|
||||
"or generated code edits, prefer apply_patch. If old_text matches "
|
||||
"multiple times, provide more context or set occurrence, line_hint, "
|
||||
"replace_all, and expected_replacements. Shows closest-match "
|
||||
"diagnostics on failure."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -688,7 +732,8 @@ class EditFileTool(_FsTool):
|
||||
async def execute(
|
||||
self, path: str | None = None, old_text: str | None = None,
|
||||
new_text: str | None = None,
|
||||
replace_all: bool = False, **kwargs: Any,
|
||||
replace_all: bool = False, occurrence: int | None = None,
|
||||
line_hint: int | None = None, expected_replacements: int | None = None, **kwargs: Any,
|
||||
) -> str:
|
||||
try:
|
||||
if not path:
|
||||
@ -697,10 +742,12 @@ class EditFileTool(_FsTool):
|
||||
raise ValueError("Unknown old_text")
|
||||
if new_text is None:
|
||||
raise ValueError("Unknown new_text")
|
||||
|
||||
# .ipynb detection
|
||||
if path.endswith(".ipynb"):
|
||||
return "Error: This is a Jupyter notebook. Use the notebook_edit tool instead of edit_file."
|
||||
if occurrence is not None and occurrence < 1:
|
||||
return "Error: occurrence must be >= 1."
|
||||
if line_hint is not None and line_hint < 1:
|
||||
return "Error: line_hint must be >= 1."
|
||||
if expected_replacements is not None and expected_replacements < 1:
|
||||
return "Error: expected_replacements must be >= 1."
|
||||
|
||||
fp = self._resolve(path)
|
||||
|
||||
@ -743,15 +790,42 @@ class EditFileTool(_FsTool):
|
||||
if not matches:
|
||||
return self._not_found_msg(old_text, content, path)
|
||||
count = len(matches)
|
||||
if replace_all and occurrence is not None:
|
||||
return "Error: occurrence cannot be used with replace_all=true."
|
||||
if replace_all and line_hint is not None:
|
||||
return "Error: line_hint cannot be used with replace_all=true."
|
||||
if occurrence is not None and line_hint is not None:
|
||||
return "Error: line_hint cannot be used with occurrence."
|
||||
if count > 1 and not replace_all:
|
||||
line_numbers = [match.line for match in matches]
|
||||
preview = ", ".join(f"line {n}" for n in line_numbers[:3])
|
||||
if len(line_numbers) > 3:
|
||||
preview += ", ..."
|
||||
location_hint = f" at {preview}" if preview else ""
|
||||
if occurrence is not None:
|
||||
if occurrence > count:
|
||||
return (
|
||||
f"Error: occurrence {occurrence} is out of range; "
|
||||
f"old_text appears {count} times."
|
||||
)
|
||||
elif line_hint is not None:
|
||||
nearest = min(matches, key=lambda match: abs(match.line - line_hint))
|
||||
distance = abs(nearest.line - line_hint)
|
||||
if sum(1 for match in matches if abs(match.line - line_hint) == distance) > 1:
|
||||
return (
|
||||
f"Error: line_hint {line_hint} is ambiguous; "
|
||||
f"old_text appears {count} times."
|
||||
)
|
||||
else:
|
||||
line_numbers = [match.line for match in matches]
|
||||
preview = ", ".join(f"line {n}" for n in line_numbers[:3])
|
||||
if len(line_numbers) > 3:
|
||||
preview += ", ..."
|
||||
location_hint = f" at {preview}" if preview else ""
|
||||
return (
|
||||
f"Warning: old_text appears {count} times{location_hint}. "
|
||||
"Provide more context, set occurrence to choose one match, "
|
||||
"or set replace_all=true."
|
||||
)
|
||||
elif occurrence is not None and occurrence > count:
|
||||
return (
|
||||
f"Warning: old_text appears {count} times{location_hint}. "
|
||||
"Provide more context to make it unique, or set replace_all=true."
|
||||
f"Error: occurrence {occurrence} is out of range; "
|
||||
f"old_text appears {count} time."
|
||||
)
|
||||
|
||||
norm_new = new_text.replace("\r\n", "\n")
|
||||
@ -760,7 +834,17 @@ class EditFileTool(_FsTool):
|
||||
if fp.suffix.lower() not in self._MARKDOWN_EXTS:
|
||||
norm_new = self._strip_trailing_ws(norm_new)
|
||||
|
||||
selected = matches if replace_all else matches[:1]
|
||||
if replace_all:
|
||||
selected = matches
|
||||
elif line_hint is not None:
|
||||
selected = [min(matches, key=lambda match: abs(match.line - line_hint))]
|
||||
else:
|
||||
selected = [matches[occurrence - 1 if occurrence else 0]]
|
||||
if expected_replacements is not None and len(selected) != expected_replacements:
|
||||
return (
|
||||
f"Error: expected {expected_replacements} replacements but "
|
||||
f"would make {len(selected)}."
|
||||
)
|
||||
new_content = content
|
||||
for match in reversed(selected):
|
||||
replacement = _preserve_quote_style(norm_old, match.text, norm_new)
|
||||
|
||||
@ -130,12 +130,6 @@ class ImageGenerationTool(Tool):
|
||||
}
|
||||
return cls(**kwargs)
|
||||
|
||||
def _missing_api_key_error(self) -> str:
|
||||
cls = get_image_gen_provider(self.config.provider)
|
||||
if cls and cls.missing_key_message:
|
||||
return f"Error: {cls.missing_key_message}"
|
||||
return f"Error: {self.config.provider} API key is not configured."
|
||||
|
||||
def _resolve_reference_image(self, value: str) -> str:
|
||||
raw_path = Path(value).expanduser()
|
||||
path = raw_path if raw_path.is_absolute() else self.workspace / raw_path
|
||||
@ -173,9 +167,6 @@ class ImageGenerationTool(Tool):
|
||||
client = self._provider_client()
|
||||
if client is None:
|
||||
return f"Error: unsupported image generation provider '{self.config.provider}'"
|
||||
provider = self._provider_config()
|
||||
if not provider or not provider.api_key:
|
||||
return self._missing_api_key_error()
|
||||
|
||||
requested = count or 1
|
||||
if requested > self.config.max_images_per_turn:
|
||||
|
||||
@ -1,162 +0,0 @@
|
||||
"""NotebookEditTool — edit Jupyter .ipynb notebooks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import tool_parameters
|
||||
from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema
|
||||
from nanobot.agent.tools.filesystem import _FsTool
|
||||
|
||||
|
||||
def _new_cell(source: str, cell_type: str = "code", generate_id: bool = False) -> dict:
|
||||
cell: dict[str, Any] = {
|
||||
"cell_type": cell_type,
|
||||
"source": source,
|
||||
"metadata": {},
|
||||
}
|
||||
if cell_type == "code":
|
||||
cell["outputs"] = []
|
||||
cell["execution_count"] = None
|
||||
if generate_id:
|
||||
cell["id"] = uuid.uuid4().hex[:8]
|
||||
return cell
|
||||
|
||||
|
||||
def _make_empty_notebook() -> dict:
|
||||
return {
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5,
|
||||
"metadata": {
|
||||
"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
|
||||
"language_info": {"name": "python"},
|
||||
},
|
||||
"cells": [],
|
||||
}
|
||||
|
||||
|
||||
@tool_parameters(
|
||||
tool_parameters_schema(
|
||||
path=StringSchema("Path to the .ipynb notebook file"),
|
||||
cell_index=IntegerSchema(0, description="0-based index of the cell to edit", minimum=0),
|
||||
new_source=StringSchema("New source content for the cell"),
|
||||
cell_type=StringSchema(
|
||||
"Cell type: 'code' or 'markdown' (default: code)",
|
||||
enum=["code", "markdown"],
|
||||
),
|
||||
edit_mode=StringSchema(
|
||||
"Mode: 'replace' (default), 'insert' (after target), or 'delete'",
|
||||
enum=["replace", "insert", "delete"],
|
||||
),
|
||||
required=["path", "cell_index"],
|
||||
)
|
||||
)
|
||||
class NotebookEditTool(_FsTool):
|
||||
"""Edit Jupyter notebook cells: replace, insert, or delete."""
|
||||
_scopes = {"core"}
|
||||
|
||||
_VALID_CELL_TYPES = frozenset({"code", "markdown"})
|
||||
_VALID_EDIT_MODES = frozenset({"replace", "insert", "delete"})
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "notebook_edit"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit a Jupyter notebook (.ipynb) cell. "
|
||||
"Modes: replace (default) replaces cell content, "
|
||||
"insert adds a new cell after the target index, "
|
||||
"delete removes the cell at the index. "
|
||||
"cell_index is 0-based."
|
||||
)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
path: str | None = None,
|
||||
cell_index: int = 0,
|
||||
new_source: str = "",
|
||||
cell_type: str = "code",
|
||||
edit_mode: str = "replace",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
try:
|
||||
if not path:
|
||||
return "Error: path is required"
|
||||
|
||||
if not path.endswith(".ipynb"):
|
||||
return "Error: notebook_edit only works on .ipynb files. Use edit_file for other files."
|
||||
|
||||
if edit_mode not in self._VALID_EDIT_MODES:
|
||||
return (
|
||||
f"Error: Invalid edit_mode '{edit_mode}'. "
|
||||
"Use one of: replace, insert, delete."
|
||||
)
|
||||
|
||||
if cell_type not in self._VALID_CELL_TYPES:
|
||||
return (
|
||||
f"Error: Invalid cell_type '{cell_type}'. "
|
||||
"Use one of: code, markdown."
|
||||
)
|
||||
|
||||
fp = self._resolve(path)
|
||||
|
||||
# Create new notebook if file doesn't exist and mode is insert
|
||||
if not fp.exists():
|
||||
if edit_mode != "insert":
|
||||
return f"Error: File not found: {path}"
|
||||
nb = _make_empty_notebook()
|
||||
cell = _new_cell(new_source, cell_type, generate_id=True)
|
||||
nb["cells"].append(cell)
|
||||
fp.parent.mkdir(parents=True, exist_ok=True)
|
||||
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
|
||||
return f"Successfully created {fp} with 1 cell"
|
||||
|
||||
try:
|
||||
nb = json.loads(fp.read_text(encoding="utf-8"))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||||
return f"Error: Failed to parse notebook: {e}"
|
||||
|
||||
cells = nb.get("cells", [])
|
||||
nbformat_minor = nb.get("nbformat_minor", 0)
|
||||
generate_id = nb.get("nbformat", 0) >= 4 and nbformat_minor >= 5
|
||||
|
||||
if edit_mode == "delete":
|
||||
if cell_index < 0 or cell_index >= len(cells):
|
||||
return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)"
|
||||
cells.pop(cell_index)
|
||||
nb["cells"] = cells
|
||||
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
|
||||
return f"Successfully deleted cell {cell_index} from {fp}"
|
||||
|
||||
if edit_mode == "insert":
|
||||
insert_at = min(cell_index + 1, len(cells))
|
||||
cell = _new_cell(new_source, cell_type, generate_id=generate_id)
|
||||
cells.insert(insert_at, cell)
|
||||
nb["cells"] = cells
|
||||
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
|
||||
return f"Successfully inserted cell at index {insert_at} in {fp}"
|
||||
|
||||
# Default: replace
|
||||
if cell_index < 0 or cell_index >= len(cells):
|
||||
return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)"
|
||||
cells[cell_index]["source"] = new_source
|
||||
if cell_type and cells[cell_index].get("cell_type") != cell_type:
|
||||
cells[cell_index]["cell_type"] = cell_type
|
||||
if cell_type == "code":
|
||||
cells[cell_index].setdefault("outputs", [])
|
||||
cells[cell_index].setdefault("execution_count", None)
|
||||
elif "outputs" in cells[cell_index]:
|
||||
del cells[cell_index]["outputs"]
|
||||
cells[cell_index].pop("execution_count", None)
|
||||
nb["cells"] = cells
|
||||
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
|
||||
return f"Successfully edited cell {cell_index} in {fp}"
|
||||
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error editing notebook: {e}"
|
||||
@ -1,4 +1,4 @@
|
||||
"""Search tools: grep."""
|
||||
"""Search tools: file discovery and grep."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -12,6 +12,7 @@ from typing import Any, Iterable, TypeVar
|
||||
from nanobot.agent.tools.filesystem import ListDirTool, _FsTool
|
||||
|
||||
_DEFAULT_HEAD_LIMIT = 250
|
||||
_DEFAULT_FILE_HEAD_LIMIT = 200
|
||||
T = TypeVar("T")
|
||||
_TYPE_GLOB_MAP = {
|
||||
"py": ("*.py", "*.pyi"),
|
||||
@ -88,6 +89,14 @@ def _matches_type(name: str, file_type: str | None) -> bool:
|
||||
return any(fnmatch.fnmatch(name.lower(), pattern.lower()) for pattern in patterns)
|
||||
|
||||
|
||||
def _matches_query(rel_path: str, query: str | None) -> bool:
|
||||
if not query:
|
||||
return True
|
||||
haystack = rel_path.lower()
|
||||
terms = [part for part in query.lower().split() if part]
|
||||
return all(term in haystack for term in terms)
|
||||
|
||||
|
||||
class _SearchTool(_FsTool):
|
||||
_IGNORE_DIRS = set(ListDirTool._IGNORE_DIRS)
|
||||
|
||||
@ -109,6 +118,163 @@ class _SearchTool(_FsTool):
|
||||
yield current / filename
|
||||
|
||||
|
||||
class FindFilesTool(_SearchTool):
|
||||
"""Find files by path fragment, glob, or type."""
|
||||
_scopes = {"core", "subagent"}
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "find_files"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Find files by path fragment, glob, or file type. "
|
||||
"Use this before read_file when you need to locate files, and "
|
||||
"prefer it over shell find/ls for ordinary workspace discovery. "
|
||||
"Returns workspace-relative paths and skips common dependency/build "
|
||||
"directories."
|
||||
)
|
||||
|
||||
@property
|
||||
def read_only(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Directory or file to search in (default '.')",
|
||||
},
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Optional case-insensitive path fragment search. "
|
||||
"Whitespace-separated terms must all be present."
|
||||
),
|
||||
},
|
||||
"glob": {
|
||||
"type": "string",
|
||||
"description": "Optional file filter, e.g. '*.py' or 'tests/**/test_*.py'",
|
||||
},
|
||||
"type": {
|
||||
"type": "string",
|
||||
"description": "Optional file type shorthand, e.g. 'py', 'ts', 'md', 'json'",
|
||||
},
|
||||
"include_dirs": {
|
||||
"type": "boolean",
|
||||
"description": "Include matching directories as well as files (default false)",
|
||||
},
|
||||
"sort": {
|
||||
"type": "string",
|
||||
"enum": ["path", "modified"],
|
||||
"description": "Sort by path or most recently modified first (default path)",
|
||||
},
|
||||
"head_limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of paths to return (default 200, 0 for all, max 1000)",
|
||||
"minimum": 0,
|
||||
"maximum": 1000,
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Skip the first N results before applying head_limit",
|
||||
"minimum": 0,
|
||||
"maximum": 100000,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def _iter_paths(self, root: Path, *, include_dirs: bool) -> Iterable[Path]:
|
||||
if root.is_file():
|
||||
yield root
|
||||
return
|
||||
if include_dirs:
|
||||
yield root
|
||||
for dirpath, dirnames, filenames in os.walk(root):
|
||||
dirnames[:] = sorted(d for d in dirnames if d not in self._IGNORE_DIRS)
|
||||
current = Path(dirpath)
|
||||
if include_dirs and current != root:
|
||||
yield current
|
||||
for filename in sorted(filenames):
|
||||
yield current / filename
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
path: str = ".",
|
||||
query: str | None = None,
|
||||
glob: str | None = None,
|
||||
type: str | None = None,
|
||||
include_dirs: bool = False,
|
||||
sort: str = "path",
|
||||
head_limit: int | None = None,
|
||||
offset: int = 0,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
try:
|
||||
target = self._resolve(path or ".")
|
||||
if not target.exists():
|
||||
return f"Error: Path not found: {path}"
|
||||
if not (target.is_dir() or target.is_file()):
|
||||
return f"Error: Unsupported path: {path}"
|
||||
|
||||
if sort not in {"path", "modified"}:
|
||||
return "Error: sort must be 'path' or 'modified'"
|
||||
|
||||
limit = (
|
||||
_DEFAULT_FILE_HEAD_LIMIT
|
||||
if head_limit is None
|
||||
else None if head_limit == 0 else head_limit
|
||||
)
|
||||
root = target if target.is_dir() else target.parent
|
||||
matches: list[tuple[str, float]] = []
|
||||
|
||||
for candidate in self._iter_paths(target, include_dirs=include_dirs):
|
||||
if candidate.is_dir() and not include_dirs:
|
||||
continue
|
||||
rel_path = candidate.relative_to(root).as_posix()
|
||||
display_path = self._display_path(candidate, root)
|
||||
name = candidate.name
|
||||
|
||||
if glob and not _match_glob(rel_path, name, glob):
|
||||
continue
|
||||
if candidate.is_file() and not _matches_type(name, type):
|
||||
continue
|
||||
if candidate.is_dir() and type:
|
||||
continue
|
||||
if not _matches_query(display_path, query):
|
||||
continue
|
||||
try:
|
||||
mtime = candidate.stat().st_mtime
|
||||
except OSError:
|
||||
mtime = 0.0
|
||||
suffix = "/" if candidate.is_dir() else ""
|
||||
matches.append((display_path + suffix, mtime))
|
||||
|
||||
if sort == "modified":
|
||||
matches.sort(key=lambda item: (-item[1], item[0]))
|
||||
else:
|
||||
matches.sort(key=lambda item: item[0])
|
||||
|
||||
paths = [item[0] for item in matches]
|
||||
paged, truncated = _paginate(paths, limit, offset)
|
||||
if not paged:
|
||||
return "No files found"
|
||||
|
||||
result = "\n".join(paged)
|
||||
note = _pagination_note(limit, offset, truncated)
|
||||
if note:
|
||||
result += "\n\n" + note
|
||||
return result
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error finding files: {e}"
|
||||
|
||||
|
||||
class GrepTool(_SearchTool):
|
||||
"""Search file contents using a regex-like pattern."""
|
||||
_scopes = {"core", "subagent"}
|
||||
@ -125,7 +291,8 @@ class GrepTool(_SearchTool):
|
||||
return (
|
||||
"Search file contents with a regex pattern. "
|
||||
"Default output_mode is files_with_matches (file paths only); "
|
||||
"use content mode for matching lines with context. "
|
||||
"use content mode for matching lines with context. Prefer this "
|
||||
"over shell grep for ordinary workspace searches. "
|
||||
"Skips binary and files >2 MB. Supports glob/type filtering."
|
||||
)
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ import re
|
||||
import shutil
|
||||
import sys
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@ -15,8 +16,17 @@ from loguru import logger
|
||||
from pydantic import Field
|
||||
|
||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||
from nanobot.agent.tools.exec_session import (
|
||||
DEFAULT_MAX_OUTPUT_CHARS,
|
||||
DEFAULT_YIELD_MS,
|
||||
DEFAULT_EXEC_SESSION_MANAGER,
|
||||
MAX_OUTPUT_CHARS,
|
||||
MAX_YIELD_MS,
|
||||
clamp_session_int,
|
||||
format_session_poll,
|
||||
)
|
||||
from nanobot.agent.tools.sandbox import wrap_command
|
||||
from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema
|
||||
from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import Base
|
||||
|
||||
@ -44,10 +54,22 @@ class ExecToolConfig(Base):
|
||||
deny_patterns: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _PreparedCommand:
|
||||
command: str
|
||||
cwd: str
|
||||
env: dict[str, str]
|
||||
timeout: int
|
||||
shell_program: str | None
|
||||
login: bool
|
||||
|
||||
|
||||
@tool_parameters(
|
||||
tool_parameters_schema(
|
||||
command=StringSchema("The shell command to execute"),
|
||||
cmd=StringSchema("Compatibility alias for command"),
|
||||
working_dir=StringSchema("Optional working directory for the command"),
|
||||
workdir=StringSchema("Compatibility alias for working_dir"),
|
||||
timeout=IntegerSchema(
|
||||
60,
|
||||
description=(
|
||||
@ -57,7 +79,44 @@ class ExecToolConfig(Base):
|
||||
minimum=1,
|
||||
maximum=600,
|
||||
),
|
||||
required=["command"],
|
||||
shell=StringSchema(
|
||||
"Optional shell binary to launch. On Unix, supports sh, bash, or zsh.",
|
||||
nullable=True,
|
||||
),
|
||||
login=BooleanSchema(
|
||||
description="Whether to run bash/zsh with login shell semantics (default true).",
|
||||
default=True,
|
||||
nullable=True,
|
||||
),
|
||||
yield_time_ms=IntegerSchema(
|
||||
description=(
|
||||
"Optional milliseconds to wait before returning output. "
|
||||
"When set, a still-running command returns a session_id that "
|
||||
"can be polled or written to with write_stdin. Omit this field "
|
||||
"to keep one-shot exec behavior."
|
||||
),
|
||||
minimum=0,
|
||||
maximum=MAX_YIELD_MS,
|
||||
nullable=True,
|
||||
),
|
||||
max_output_chars=IntegerSchema(
|
||||
description=(
|
||||
"Maximum output characters to return when yield_time_ms is used "
|
||||
"(default 10000, max 50000)."
|
||||
),
|
||||
minimum=1000,
|
||||
maximum=MAX_OUTPUT_CHARS,
|
||||
nullable=True,
|
||||
),
|
||||
max_output_tokens=IntegerSchema(
|
||||
description=(
|
||||
"Compatibility alias for max_output_chars. The current runtime "
|
||||
"uses a character budget."
|
||||
),
|
||||
minimum=1000,
|
||||
maximum=MAX_OUTPUT_CHARS,
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
)
|
||||
class ExecTool(Tool):
|
||||
@ -98,6 +157,7 @@ class ExecTool(Tool):
|
||||
sandbox: str = "",
|
||||
path_append: str = "",
|
||||
allowed_env_keys: list[str] | None = None,
|
||||
session_manager: Any | None = None,
|
||||
):
|
||||
self.timeout = timeout
|
||||
self.working_dir = working_dir
|
||||
@ -125,6 +185,7 @@ class ExecTool(Tool):
|
||||
self.restrict_to_workspace = restrict_to_workspace
|
||||
self.path_append = path_append
|
||||
self.allowed_env_keys = allowed_env_keys or []
|
||||
self._session_manager = session_manager or DEFAULT_EXEC_SESSION_MANAGER
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@ -150,10 +211,15 @@ class ExecTool(Tool):
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Execute a shell command and return its output. "
|
||||
"Prefer read_file/write_file/edit_file over cat/echo/sed, "
|
||||
"and grep/glob over shell find/grep. "
|
||||
"Use this for tests, builds, package commands, git commands, and "
|
||||
"other process execution. Prefer read_file/find_files/grep for "
|
||||
"inspection and apply_patch/write_file/edit_file for file changes "
|
||||
"instead of cat, shell find/grep, echo, or sed. "
|
||||
"Use -y or --yes flags to avoid interactive prompts. "
|
||||
"Output is truncated at 10 000 chars; timeout defaults to 60s."
|
||||
"For long-running or interactive commands, pass yield_time_ms; "
|
||||
"if the command keeps running, exec returns a session_id that can "
|
||||
"be polled or written to with write_stdin. Output is truncated at "
|
||||
"10 000 chars; timeout defaults to 60s."
|
||||
)
|
||||
|
||||
@property
|
||||
@ -161,9 +227,111 @@ class ExecTool(Tool):
|
||||
return True
|
||||
|
||||
async def execute(
|
||||
self, command: str, working_dir: str | None = None,
|
||||
timeout: int | None = None, **kwargs: Any,
|
||||
self, command: str | None = None, cmd: str | None = None,
|
||||
working_dir: str | None = None, workdir: str | None = None,
|
||||
timeout: int | None = None, shell: str | None = None,
|
||||
login: bool | None = None, yield_time_ms: int | None = None,
|
||||
max_output_chars: int | None = None,
|
||||
max_output_tokens: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
command = command or cmd
|
||||
working_dir = working_dir or workdir
|
||||
if not command:
|
||||
return "Error: Missing command. Provide command or cmd."
|
||||
if max_output_chars is None:
|
||||
max_output_chars = max_output_tokens
|
||||
|
||||
prepared = self._prepare_command(command, working_dir, timeout, shell, login)
|
||||
if isinstance(prepared, str):
|
||||
return prepared
|
||||
|
||||
if yield_time_ms is not None:
|
||||
return await self._execute_session(prepared, yield_time_ms, max_output_chars)
|
||||
|
||||
try:
|
||||
process = await self._spawn(
|
||||
prepared.command,
|
||||
prepared.cwd,
|
||||
prepared.env,
|
||||
prepared.shell_program,
|
||||
prepared.login,
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(),
|
||||
timeout=prepared.timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
await self._kill_process(process)
|
||||
return f"Error: Command timed out after {prepared.timeout} seconds"
|
||||
except asyncio.CancelledError:
|
||||
await self._kill_process(process)
|
||||
raise
|
||||
|
||||
output_parts = []
|
||||
|
||||
if stdout:
|
||||
output_parts.append(stdout.decode("utf-8", errors="replace"))
|
||||
|
||||
if stderr:
|
||||
stderr_text = stderr.decode("utf-8", errors="replace")
|
||||
if stderr_text.strip():
|
||||
output_parts.append(f"STDERR:\n{stderr_text}")
|
||||
|
||||
output_parts.append(f"\nExit code: {process.returncode}")
|
||||
|
||||
result = "\n".join(output_parts) if output_parts else "(no output)"
|
||||
|
||||
max_len = clamp_session_int(max_output_chars, self._MAX_OUTPUT, 1000, MAX_OUTPUT_CHARS)
|
||||
if len(result) > max_len:
|
||||
half = max_len // 2
|
||||
result = (
|
||||
result[:half]
|
||||
+ f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n"
|
||||
+ result[-half:]
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return f"Error executing command: {str(e)}"
|
||||
|
||||
async def _execute_session(
|
||||
self,
|
||||
prepared: _PreparedCommand,
|
||||
yield_time_ms: int | None,
|
||||
max_output_chars: int | None,
|
||||
) -> str:
|
||||
try:
|
||||
session_id, poll = await self._session_manager.start(
|
||||
command=prepared.command,
|
||||
cwd=prepared.cwd,
|
||||
env=prepared.env,
|
||||
timeout=prepared.timeout,
|
||||
shell_program=prepared.shell_program,
|
||||
login=prepared.login,
|
||||
yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS),
|
||||
max_output_chars=clamp_session_int(
|
||||
max_output_chars,
|
||||
DEFAULT_MAX_OUTPUT_CHARS,
|
||||
1000,
|
||||
MAX_OUTPUT_CHARS,
|
||||
),
|
||||
)
|
||||
return format_session_poll(session_id, poll)
|
||||
except Exception as exc:
|
||||
return f"Error executing command: {exc}"
|
||||
|
||||
def _prepare_command(
|
||||
self,
|
||||
command: str,
|
||||
working_dir: str | None = None,
|
||||
timeout: int | None = None,
|
||||
shell: str | None = None,
|
||||
login: bool | None = None,
|
||||
) -> _PreparedCommand | str:
|
||||
cwd = working_dir or self.working_dir or os.getcwd()
|
||||
|
||||
# Prevent an LLM-supplied working_dir from escaping the configured
|
||||
@ -211,52 +379,24 @@ class ExecTool(Tool):
|
||||
env["NANOBOT_PATH_APPEND"] = self.path_append
|
||||
command = f'export PATH="$PATH{os.pathsep}$NANOBOT_PATH_APPEND"; {command}'
|
||||
|
||||
try:
|
||||
process = await self._spawn(command, cwd, env)
|
||||
shell_program, shell_error = self._resolve_shell(shell)
|
||||
if shell_error:
|
||||
return shell_error
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(),
|
||||
timeout=effective_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
await self._kill_process(process)
|
||||
return f"Error: Command timed out after {effective_timeout} seconds"
|
||||
except asyncio.CancelledError:
|
||||
await self._kill_process(process)
|
||||
raise
|
||||
|
||||
output_parts = []
|
||||
|
||||
if stdout:
|
||||
output_parts.append(stdout.decode("utf-8", errors="replace"))
|
||||
|
||||
if stderr:
|
||||
stderr_text = stderr.decode("utf-8", errors="replace")
|
||||
if stderr_text.strip():
|
||||
output_parts.append(f"STDERR:\n{stderr_text}")
|
||||
|
||||
output_parts.append(f"\nExit code: {process.returncode}")
|
||||
|
||||
result = "\n".join(output_parts) if output_parts else "(no output)"
|
||||
|
||||
max_len = self._MAX_OUTPUT
|
||||
if len(result) > max_len:
|
||||
half = max_len // 2
|
||||
result = (
|
||||
result[:half]
|
||||
+ f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n"
|
||||
+ result[-half:]
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return f"Error executing command: {str(e)}"
|
||||
return _PreparedCommand(
|
||||
command=command,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
timeout=effective_timeout,
|
||||
shell_program=shell_program,
|
||||
login=True if login is None else login,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _spawn(
|
||||
command: str, cwd: str, env: dict[str, str],
|
||||
shell_program: str | None = None,
|
||||
login: bool = True,
|
||||
) -> asyncio.subprocess.Process:
|
||||
"""Launch *command* in a platform-appropriate shell."""
|
||||
if _IS_WINDOWS:
|
||||
@ -272,9 +412,14 @@ class ExecTool(Tool):
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
)
|
||||
bash = shutil.which("bash") or "/bin/bash"
|
||||
shell_program = shell_program or shutil.which("bash") or "/bin/bash"
|
||||
args = [shell_program]
|
||||
shell_name = Path(shell_program).name.lower()
|
||||
if login and shell_name in {"bash", "bash.exe", "zsh", "zsh.exe"}:
|
||||
args.append("-l")
|
||||
args.extend(["-c", command])
|
||||
return await asyncio.create_subprocess_exec(
|
||||
bash, "-l", "-c", command,
|
||||
*args,
|
||||
stdin=asyncio.subprocess.DEVNULL,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
@ -282,6 +427,31 @@ class ExecTool(Tool):
|
||||
env=env,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_shell(shell: str | None) -> tuple[str | None, str | None]:
|
||||
if not shell:
|
||||
return None, None
|
||||
if _IS_WINDOWS:
|
||||
return None, "Error: shell parameter is not supported on Windows"
|
||||
if "\0" in shell or "\n" in shell or "\r" in shell:
|
||||
return None, "Error: shell contains invalid characters"
|
||||
allowed = {"sh", "bash", "zsh"}
|
||||
path = Path(shell).expanduser()
|
||||
if path.is_absolute():
|
||||
if path.name not in allowed:
|
||||
return None, f"Error: unsupported shell {shell!r}. Allowed: bash, sh, zsh"
|
||||
if not path.is_file() or not os.access(path, os.X_OK):
|
||||
return None, f"Error: shell is not executable: {shell}"
|
||||
return str(path), None
|
||||
if "/" in shell or "\\" in shell:
|
||||
return None, "Error: shell must be a shell name or absolute path"
|
||||
if shell not in allowed:
|
||||
return None, f"Error: unsupported shell {shell!r}. Allowed: bash, sh, zsh"
|
||||
resolved = shutil.which(shell)
|
||||
if not resolved:
|
||||
return None, f"Error: shell not found: {shell}"
|
||||
return resolved, None
|
||||
|
||||
@staticmethod
|
||||
async def _kill_process(process: asyncio.subprocess.Process) -> None:
|
||||
"""Kill a subprocess and reap it to prevent zombies."""
|
||||
@ -418,7 +588,7 @@ class ExecTool(Tool):
|
||||
# Windows: match drive-root paths like `C:\` as well as `C:\path\to\file`, and UNC paths like `\\server\share`
|
||||
# NOTE: `*` is required so `C:\` (nothing after the slash) is still extracted.
|
||||
win_paths = re.findall(
|
||||
r"(?:[A-Za-z]:[^\s\"'|><;]*|\\\\[^\s\"'|><;]+(?:\\[^\s\"'|><;]+)*)",
|
||||
r"(?<![A-Za-z])(?:[A-Za-z]:[^\s\"'|><;]*|\\\\[^\s\"'|><;]+(?:\\[^\s\"'|><;]+)*)",
|
||||
command
|
||||
)
|
||||
posix_paths = re.findall(r"(?:^|[\s|>'\"])(/[^\s\"'>;|<]+)", command) # POSIX: /absolute only
|
||||
|
||||
1402
nanobot/channels/signal.py
Normal file
1402
nanobot/channels/signal.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -79,6 +79,12 @@ BASE_INFO: dict[str, str] = {"channel_version": WEIXIN_CHANNEL_VERSION}
|
||||
ERRCODE_SESSION_EXPIRED = -14
|
||||
SESSION_PAUSE_DURATION_S = 60 * 60
|
||||
|
||||
# iLink context_token is observed to expire server-side after ~90-160s of
|
||||
# agent inactivity (openclaw/openclaw#61174). Proactively refresh before
|
||||
# sending if the cached token is older than this threshold.
|
||||
CONTEXT_TOKEN_MAX_AGE_S = 60
|
||||
|
||||
|
||||
# Retry constants (matching the reference plugin's monitor.ts)
|
||||
MAX_CONSECUTIVE_FAILURES = 3
|
||||
BACKOFF_DELAY_S = 30
|
||||
@ -159,6 +165,8 @@ class WeixinChannel(BaseChannel):
|
||||
self._session_pause_until: float = 0.0
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||
self._typing_tickets: dict[str, dict[str, Any]] = {}
|
||||
self._context_token_at: dict[str, float] = {}
|
||||
self._pending_tool_hints: dict[str, list[str]] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# State persistence
|
||||
@ -486,6 +494,7 @@ class WeixinChannel(BaseChannel):
|
||||
except Exception:
|
||||
if not self._running:
|
||||
break
|
||||
self.logger.exception("WeChat poll loop error")
|
||||
consecutive_failures += 1
|
||||
if consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
|
||||
consecutive_failures = 0
|
||||
@ -495,6 +504,7 @@ class WeixinChannel(BaseChannel):
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._running = False
|
||||
self._pending_tool_hints.clear()
|
||||
if self._poll_task and not self._poll_task.done():
|
||||
self._poll_task.cancel()
|
||||
for chat_id in list(self._typing_tasks):
|
||||
@ -545,6 +555,7 @@ class WeixinChannel(BaseChannel):
|
||||
# Check for API-level errors (monitor.ts checks both ret and errcode)
|
||||
ret = data.get("ret", 0)
|
||||
errcode = data.get("errcode", 0)
|
||||
|
||||
is_error = (ret is not None and ret != 0) or (errcode is not None and errcode != 0)
|
||||
|
||||
if is_error:
|
||||
@ -575,8 +586,10 @@ class WeixinChannel(BaseChannel):
|
||||
# Process messages (WeixinMessage[] from types.ts)
|
||||
msgs: list[dict] = data.get("msgs", []) or []
|
||||
for msg in msgs:
|
||||
with suppress(Exception):
|
||||
try:
|
||||
await self._process_message(msg)
|
||||
except Exception:
|
||||
self.logger.exception("Failed to process WeChat message")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Inbound message processing (matches inbound.ts + process-message.ts)
|
||||
@ -610,6 +623,7 @@ class WeixinChannel(BaseChannel):
|
||||
ctx_token = msg.get("context_token", "")
|
||||
if ctx_token:
|
||||
self._context_tokens[from_user_id] = ctx_token
|
||||
self._context_token_at[from_user_id] = time.time()
|
||||
self._save_state()
|
||||
|
||||
# Parse item_list (WeixinMessage.item_list — types.ts:161)
|
||||
@ -915,6 +929,99 @@ class WeixinChannel(BaseChannel):
|
||||
}
|
||||
return ""
|
||||
|
||||
async def _refresh_context_token_if_stale(
|
||||
self, chat_id: str, context_token: str
|
||||
) -> str:
|
||||
"""Return a fresh context_token if the cached one is too old.
|
||||
|
||||
iLink context_token expires server-side after a short idle period
|
||||
(empirically ~90s). Proactively refreshing before sending prevents
|
||||
silent message loss on long agent turns or cron pushes.
|
||||
"""
|
||||
if not context_token:
|
||||
return context_token
|
||||
|
||||
now = time.time()
|
||||
cached_at = self._context_token_at.get(chat_id, 0)
|
||||
age = now - cached_at
|
||||
|
||||
if age < CONTEXT_TOKEN_MAX_AGE_S:
|
||||
return context_token
|
||||
|
||||
self.logger.debug(
|
||||
"WeChat context_token for {} is {:.0f}s old; refreshing via getconfig",
|
||||
chat_id,
|
||||
age,
|
||||
)
|
||||
|
||||
body: dict[str, Any] = {
|
||||
"ilink_user_id": chat_id,
|
||||
"context_token": context_token,
|
||||
"base_info": BASE_INFO,
|
||||
}
|
||||
try:
|
||||
data = await self._api_post("ilink/bot/getconfig", body)
|
||||
except Exception as e:
|
||||
self.logger.warning("WeChat getconfig failed for {}: {}", chat_id, e)
|
||||
return context_token
|
||||
|
||||
if data.get("ret", 0) != 0:
|
||||
self.logger.warning(
|
||||
"WeChat getconfig returned ret={} for {}: {}",
|
||||
data.get("ret"),
|
||||
chat_id,
|
||||
data.get("errmsg", ""),
|
||||
)
|
||||
return context_token
|
||||
|
||||
new_token = str(data.get("context_token", "") or "")
|
||||
if new_token and new_token != context_token:
|
||||
self.logger.info(
|
||||
"WeChat context_token refreshed for {} (age {:.0f}s -> fresh)",
|
||||
chat_id,
|
||||
age,
|
||||
)
|
||||
self._context_tokens[chat_id] = new_token
|
||||
self._context_token_at[chat_id] = now
|
||||
self._save_state()
|
||||
return new_token
|
||||
|
||||
return context_token
|
||||
|
||||
async def _flush_tool_hints(self, chat_id: str) -> None:
|
||||
"""Send any buffered tool hints for *chat_id* as a single message.
|
||||
|
||||
Tool hints are coalesced to reduce message count and avoid hitting the
|
||||
WeChat iLink rate limit (~7 msgs / 5 min). Failures are logged but
|
||||
not raised so that the main message send is never blocked.
|
||||
"""
|
||||
hints = self._pending_tool_hints.pop(chat_id, None)
|
||||
if not hints:
|
||||
return
|
||||
|
||||
self.logger.info(
|
||||
"Flushing {} buffered tool hint(s) for {}",
|
||||
len(hints),
|
||||
chat_id,
|
||||
)
|
||||
|
||||
ctx_token = self._context_tokens.get(chat_id, "")
|
||||
ctx_token = await self._refresh_context_token_if_stale(chat_id, ctx_token)
|
||||
if not ctx_token:
|
||||
self.logger.warning(
|
||||
"Dropped {} buffered tool hint(s) for {}: no context_token",
|
||||
len(hints),
|
||||
chat_id,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
await self._send_text(chat_id, "\n\n".join(hints), ctx_token)
|
||||
except Exception:
|
||||
self.logger.exception(
|
||||
"Failed to flush buffered tool hints for {}", chat_id
|
||||
)
|
||||
|
||||
async def _send_typing(self, user_id: str, typing_ticket: str, status: int) -> None:
|
||||
"""Best-effort sendtyping wrapper."""
|
||||
if not typing_ticket:
|
||||
@ -944,11 +1051,47 @@ class WeixinChannel(BaseChannel):
|
||||
self._assert_session_active()
|
||||
|
||||
is_progress = bool((msg.metadata or {}).get("_progress", False))
|
||||
|
||||
# Buffer tool hints to coalesce consecutive ones and avoid burning
|
||||
# WeChat iLink rate-limit quota (~7 msgs / 5 min).
|
||||
if is_progress and (msg.metadata or {}).get("_tool_hint"):
|
||||
if not self.send_tool_hints:
|
||||
return
|
||||
self._pending_tool_hints.setdefault(msg.chat_id, []).append(msg.content)
|
||||
self.logger.debug(
|
||||
"Buffered tool hint for {} (count={})",
|
||||
msg.chat_id,
|
||||
len(self._pending_tool_hints[msg.chat_id]),
|
||||
)
|
||||
return
|
||||
|
||||
# Reasoning deltas are invisible in WeChat (there is no reasoning
|
||||
# UI). Skip them entirely — do not send and do not flush buffer.
|
||||
if is_progress and (msg.metadata or {}).get("_reasoning_delta"):
|
||||
self.logger.debug(
|
||||
"Dropped invisible reasoning delta for {}", msg.chat_id
|
||||
)
|
||||
return
|
||||
|
||||
content = msg.content.strip()
|
||||
|
||||
# Empty progress messages (e.g. after_iteration tool_events) must
|
||||
# NOT act as separators — they have no visible content.
|
||||
if is_progress and not content and not (msg.media or []):
|
||||
self.logger.debug(
|
||||
"Skipped empty progress message for {} (no visible content)",
|
||||
msg.chat_id,
|
||||
)
|
||||
return
|
||||
|
||||
# Flush buffered hints before sending any visible message.
|
||||
await self._flush_tool_hints(msg.chat_id)
|
||||
|
||||
if not is_progress:
|
||||
await self._stop_typing(msg.chat_id, clear_remote=True)
|
||||
|
||||
content = msg.content.strip()
|
||||
ctx_token = self._context_tokens.get(msg.chat_id, "")
|
||||
ctx_token = await self._refresh_context_token_if_stale(msg.chat_id, ctx_token)
|
||||
if not ctx_token:
|
||||
raise RuntimeError(
|
||||
f"WeChat context_token missing for chat_id={msg.chat_id}, cannot send"
|
||||
@ -1037,6 +1180,18 @@ class WeixinChannel(BaseChannel):
|
||||
with suppress(Exception):
|
||||
await self._send_typing(msg.chat_id, typing_ticket, TYPING_STATUS_CANCEL)
|
||||
|
||||
async def send_delta(
|
||||
self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None
|
||||
) -> None:
|
||||
"""Weixin iLink does not support native streaming deltas.
|
||||
|
||||
We only hook ``_stream_end`` so buffered tool hints are flushed even
|
||||
when the final answer carries the ``_streamed`` flag and bypasses
|
||||
:meth:`send`.
|
||||
"""
|
||||
if metadata and metadata.get("_stream_end"):
|
||||
await self._flush_tool_hints(chat_id)
|
||||
|
||||
async def _start_typing(self, chat_id: str, context_token: str = "") -> None:
|
||||
"""Start typing indicator immediately when a message is received."""
|
||||
if not self._client or not self._token or not chat_id:
|
||||
@ -1120,10 +1275,11 @@ class WeixinChannel(BaseChannel):
|
||||
}
|
||||
|
||||
data = await self._api_post("ilink/bot/sendmessage", body)
|
||||
ret = data.get("ret", 0)
|
||||
errcode = data.get("errcode", 0)
|
||||
if errcode and errcode != 0:
|
||||
if (ret is not None and ret != 0) or (errcode is not None and errcode != 0):
|
||||
raise RuntimeError(
|
||||
f"WeChat send text error (code {errcode}): {data.get('errmsg', '')}"
|
||||
f"WeChat send text error (ret={ret}, errcode={errcode}): {data.get('errmsg', '')}"
|
||||
)
|
||||
|
||||
async def _send_media_file(
|
||||
@ -1270,10 +1426,11 @@ class WeixinChannel(BaseChannel):
|
||||
}
|
||||
|
||||
data = await self._api_post("ilink/bot/sendmessage", body)
|
||||
ret = data.get("ret", 0)
|
||||
errcode = data.get("errcode", 0)
|
||||
if errcode and errcode != 0:
|
||||
if (ret is not None and ret != 0) or (errcode is not None and errcode != 0):
|
||||
raise RuntimeError(
|
||||
f"WeChat send media error (code {errcode}): {data.get('errmsg', '')}"
|
||||
f"WeChat send media error (ret={ret}, errcode={errcode}): {data.get('errmsg', '')}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -211,6 +211,7 @@ class ProvidersConfig(Base):
|
||||
ant_ling: ProviderConfig = Field(default_factory=ProviderConfig) # Ant Ling
|
||||
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
|
||||
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动)
|
||||
novita: ProviderConfig = Field(default_factory=ProviderConfig) # Novita AI
|
||||
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎)
|
||||
volcengine_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine Coding Plan
|
||||
byteplus: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus (VolcEngine international)
|
||||
|
||||
@ -2,8 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import binascii
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
@ -31,6 +33,14 @@ _AIHUBMIX_ASPECT_RATIO_SIZES = {
|
||||
}
|
||||
_GEMINI_DEFAULT_TIMEOUT_S = 120.0
|
||||
_GEMINI_IMAGEN_ASPECT_RATIOS = {"1:1", "9:16", "16:9", "3:4", "4:3"}
|
||||
_OLLAMA_DEFAULT_SIDE = 1024
|
||||
_OLLAMA_SIZE_PRESETS = {
|
||||
"1K": 1024,
|
||||
"2K": 2048,
|
||||
"4K": 4096,
|
||||
}
|
||||
_OLLAMA_EXPLICIT_SIZE_RE = re.compile(r"^\s*(\d+)\s*[xX]\s*(\d+)\s*$")
|
||||
_OLLAMA_ASPECT_RATIO_RE = re.compile(r"^\s*(\d+)\s*:\s*(\d+)\s*$")
|
||||
|
||||
|
||||
class ImageGenerationError(RuntimeError):
|
||||
@ -438,6 +448,139 @@ def _http_error_detail(response: httpx.Response) -> str:
|
||||
return response.text[:500] or "<empty response body>"
|
||||
|
||||
|
||||
def _round_to_multiple(value: float, multiple: int = 8) -> int:
|
||||
rounded = int(round(value / multiple) * multiple)
|
||||
return max(multiple, rounded)
|
||||
|
||||
|
||||
def _ollama_dimensions(aspect_ratio: str | None, image_size: str | None) -> tuple[int, int]:
|
||||
if image_size:
|
||||
size = image_size.strip()
|
||||
explicit = _OLLAMA_EXPLICIT_SIZE_RE.fullmatch(size)
|
||||
if explicit:
|
||||
return int(explicit.group(1)), int(explicit.group(2))
|
||||
long_side = _OLLAMA_SIZE_PRESETS.get(size.upper(), _OLLAMA_DEFAULT_SIDE)
|
||||
else:
|
||||
long_side = _OLLAMA_DEFAULT_SIDE
|
||||
|
||||
if not aspect_ratio:
|
||||
return long_side, long_side
|
||||
|
||||
ratio = _OLLAMA_ASPECT_RATIO_RE.fullmatch(aspect_ratio.strip())
|
||||
if ratio is None:
|
||||
return long_side, long_side
|
||||
|
||||
width_ratio = int(ratio.group(1))
|
||||
height_ratio = int(ratio.group(2))
|
||||
if width_ratio <= 0 or height_ratio <= 0:
|
||||
return long_side, long_side
|
||||
|
||||
if width_ratio >= height_ratio:
|
||||
width = long_side
|
||||
height = _round_to_multiple(long_side * height_ratio / width_ratio)
|
||||
else:
|
||||
height = long_side
|
||||
width = _round_to_multiple(long_side * width_ratio / height_ratio)
|
||||
return max(8, width), max(8, height)
|
||||
|
||||
|
||||
def _ollama_image_data_url(value: str) -> str:
|
||||
if value.startswith("data:image/"):
|
||||
return value
|
||||
return _b64_image_data_url(value)
|
||||
|
||||
|
||||
def _ollama_images_from_payload(payload: dict[str, Any]) -> list[str]:
|
||||
images: list[str] = []
|
||||
|
||||
def collect(value: Any) -> None:
|
||||
if isinstance(value, str) and value:
|
||||
images.append(_ollama_image_data_url(value))
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
collect(item)
|
||||
|
||||
collect(payload.get("image"))
|
||||
collect(payload.get("images"))
|
||||
return images
|
||||
|
||||
|
||||
class OllamaImageGenerationClient(ImageGenerationProvider):
|
||||
"""Async client for Ollama native image generation models."""
|
||||
|
||||
provider_name = "ollama"
|
||||
default_timeout = 300.0
|
||||
|
||||
def _default_base_url(self) -> str:
|
||||
return "http://localhost:11434/api"
|
||||
|
||||
def _resolve_base_url(self, api_base: str | None) -> str:
|
||||
if api_base:
|
||||
base = api_base.rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
return f"{base[:-3]}/api"
|
||||
return base
|
||||
return self._default_base_url()
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
model: str,
|
||||
reference_images: list[str] | None = None,
|
||||
aspect_ratio: str | None = None,
|
||||
image_size: str | None = None,
|
||||
) -> GeneratedImageResponse:
|
||||
if reference_images:
|
||||
raise ImageGenerationError(
|
||||
"Ollama image generation does not support reference images"
|
||||
)
|
||||
|
||||
width, height = _ollama_dimensions(aspect_ratio, image_size)
|
||||
body: dict[str, Any] = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"steps": 0,
|
||||
}
|
||||
body.update(self.extra_body)
|
||||
body["stream"] = False
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
**self.extra_headers,
|
||||
}
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
|
||||
url = f"{self.api_base}/generate"
|
||||
response = await self._http_post(url, headers=headers, body=body)
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
detail = _http_error_detail(response)
|
||||
logger.error(
|
||||
"Ollama image generation failed (HTTP {}): {}",
|
||||
response.status_code,
|
||||
detail,
|
||||
)
|
||||
raise ImageGenerationError(
|
||||
f"Ollama image generation failed (HTTP {response.status_code}): {detail}"
|
||||
) from exc
|
||||
|
||||
data = response.json()
|
||||
images = _ollama_images_from_payload(data)
|
||||
|
||||
self._require_images(images, data)
|
||||
|
||||
response_text = data.get("response")
|
||||
content = response_text if isinstance(response_text, str) else ""
|
||||
|
||||
return GeneratedImageResponse(images=images, content=content, raw=data)
|
||||
|
||||
|
||||
class GeminiImageGenerationClient(ImageGenerationProvider):
|
||||
"""Async client for Gemini/Imagen image generation via the Generative Language API."""
|
||||
|
||||
@ -759,6 +902,426 @@ def _minimax_images_from_payload(payload: dict[str, Any]) -> list[str]:
|
||||
return images
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAI image generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_OPENAI_DALLE2_SUPPORTED_SIZES = {"256x256", "512x512", "1024x1024"}
|
||||
_OPENAI_DALLE3_SUPPORTED_SIZES = {"1024x1024", "1792x1024", "1024x1792"}
|
||||
_OPENAI_GPT_IMAGE_SUPPORTED_SIZES = {
|
||||
"1024x1024",
|
||||
"1536x1024",
|
||||
"1024x1536",
|
||||
"auto",
|
||||
}
|
||||
_OPENAI_DALLE2_ASPECT_RATIO_SIZES = {
|
||||
"1:1": "1024x1024",
|
||||
"16:9": "1024x1024",
|
||||
"9:16": "1024x1024",
|
||||
"3:4": "1024x1024",
|
||||
"4:3": "1024x1024",
|
||||
}
|
||||
_OPENAI_DALLE3_ASPECT_RATIO_SIZES = {
|
||||
"1:1": "1024x1024",
|
||||
"16:9": "1792x1024",
|
||||
"9:16": "1024x1792",
|
||||
"3:4": "1024x1792",
|
||||
"4:3": "1792x1024",
|
||||
}
|
||||
_OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES = {
|
||||
"1:1": "1024x1024",
|
||||
"16:9": "1536x1024",
|
||||
"9:16": "1024x1536",
|
||||
"3:4": "1024x1536",
|
||||
"4:3": "1536x1024",
|
||||
}
|
||||
|
||||
|
||||
class OpenAIImageGenerationClient(ImageGenerationProvider):
|
||||
"""OpenAI Images API using an API key (``providers.openai.apiKey``)."""
|
||||
|
||||
provider_name = "openai"
|
||||
missing_key_message = (
|
||||
"OpenAI API key is not configured. Set providers.openai.apiKey."
|
||||
)
|
||||
|
||||
def _default_base_url(self) -> str:
|
||||
return "https://api.openai.com/v1"
|
||||
|
||||
@staticmethod
|
||||
def _strip_model_prefix(model: str) -> str:
|
||||
"""Remove ``openai/`` prefix if present (OpenRouter convention)."""
|
||||
if model.startswith("openai/") or model.startswith("openai_codex/"):
|
||||
return model.split("/", 1)[1]
|
||||
return model
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
model: str,
|
||||
reference_images: list[str] | None = None,
|
||||
aspect_ratio: str | None = None,
|
||||
image_size: str | None = None,
|
||||
) -> GeneratedImageResponse:
|
||||
if not self.api_key:
|
||||
raise ImageGenerationError(self.missing_key_message)
|
||||
|
||||
if reference_images:
|
||||
logger.warning(
|
||||
"DALL-E models do not support reference images; "
|
||||
"ignoring {} reference image(s) for {}",
|
||||
len(reference_images),
|
||||
model,
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
**self.extra_headers,
|
||||
}
|
||||
|
||||
clean_model = self._strip_model_prefix(model)
|
||||
body: dict[str, Any] = {
|
||||
"model": clean_model,
|
||||
"prompt": prompt,
|
||||
}
|
||||
|
||||
if not _openai_is_gpt_image_model(clean_model):
|
||||
body["response_format"] = "b64_json"
|
||||
body["n"] = 1
|
||||
|
||||
size = _openai_size(clean_model, aspect_ratio, image_size)
|
||||
if size:
|
||||
body["size"] = size
|
||||
|
||||
body.update(self.extra_body)
|
||||
|
||||
logger.info("OpenAI Images API request: POST {}/images/generations body={}", self.api_base, body)
|
||||
|
||||
response = await self._http_post(
|
||||
f"{self.api_base}/images/generations",
|
||||
headers=headers,
|
||||
body=body,
|
||||
)
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
detail = response.text[:1000]
|
||||
logger.error("OpenAI Images API error ({}): {}", response.status_code, detail)
|
||||
raise ImageGenerationError(
|
||||
f"OpenAI image generation failed (HTTP {response.status_code}): {detail}"
|
||||
) from exc
|
||||
|
||||
payload = response.json()
|
||||
logger.info("OpenAI Images API response ({}): {}", response.status_code,
|
||||
{k: v for k, v in payload.items() if k != "data"})
|
||||
|
||||
client = self._client
|
||||
owns_client = client is None
|
||||
if owns_client:
|
||||
client = httpx.AsyncClient(timeout=self.timeout)
|
||||
try:
|
||||
images = await _openai_images_from_payload(client, payload)
|
||||
finally:
|
||||
if owns_client:
|
||||
await client.aclose()
|
||||
|
||||
self._require_images(images, payload)
|
||||
|
||||
return GeneratedImageResponse(images=images, content="", raw=payload)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAI Codex image generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CodexImageGenerationClient(ImageGenerationProvider):
|
||||
"""OpenAI image generation via Codex subscription OAuth.
|
||||
|
||||
Uses the Codex Responses API with the ``image_generation`` tool
|
||||
(the same mechanism ChatGPT uses internally). No API key required —
|
||||
the Codex OAuth token from ``oauth_cli_kit`` is used instead.
|
||||
"""
|
||||
|
||||
provider_name = "openai_codex"
|
||||
missing_key_message = (
|
||||
"Codex OAuth token is unavailable. "
|
||||
"Log in with Codex subscription first."
|
||||
)
|
||||
|
||||
def _default_base_url(self) -> str:
|
||||
return "https://chatgpt.com/backend-api"
|
||||
|
||||
def _codex_model(self, model: str) -> str:
|
||||
"""Strip the ``openai-codex/`` prefix if present."""
|
||||
if model.startswith(("openai-codex/", "openai_codex/")):
|
||||
return model.split("/", 1)[1]
|
||||
return model
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
model: str,
|
||||
reference_images: list[str] | None = None,
|
||||
aspect_ratio: str | None = None,
|
||||
image_size: str | None = None,
|
||||
) -> GeneratedImageResponse:
|
||||
try:
|
||||
from oauth_cli_kit import get_token as get_codex_token
|
||||
except ImportError:
|
||||
raise ImageGenerationError(self.missing_key_message)
|
||||
|
||||
try:
|
||||
token = await asyncio.to_thread(get_codex_token)
|
||||
except Exception as exc:
|
||||
raise ImageGenerationError(self.missing_key_message) from exc
|
||||
if not token or not token.access:
|
||||
raise ImageGenerationError(self.missing_key_message)
|
||||
|
||||
logger.info(
|
||||
"Using Codex OAuth token for image generation (account: {})",
|
||||
token.account_id,
|
||||
)
|
||||
|
||||
if reference_images:
|
||||
logger.warning(
|
||||
"Codex image generation does not support reference images; "
|
||||
"ignoring {} reference image(s)",
|
||||
len(reference_images),
|
||||
)
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token.access}",
|
||||
"chatgpt-account-id": token.account_id,
|
||||
"OpenAI-Beta": "responses=experimental",
|
||||
"originator": "nanobot",
|
||||
"User-Agent": "nanobot (python)",
|
||||
"Content-Type": "application/json",
|
||||
**self.extra_headers,
|
||||
}
|
||||
|
||||
body: dict[str, Any] = {
|
||||
"model": self._codex_model(model),
|
||||
"instructions": "Generate an image based on the user's request.",
|
||||
"input": [{"role": "user", "content": prompt}],
|
||||
"tools": [{"type": "image_generation"}],
|
||||
"tool_choice": "auto",
|
||||
"stream": True,
|
||||
"store": False,
|
||||
}
|
||||
body.update(self.extra_body)
|
||||
|
||||
logger.info("Codex Responses API request: POST {}/codex/responses body={}",
|
||||
self.api_base, {k: v for k, v in body.items() if k != "input"})
|
||||
|
||||
response = await self._http_post(
|
||||
f"{self.api_base}/codex/responses",
|
||||
headers=headers,
|
||||
body=body,
|
||||
)
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
detail = response.text[:1000]
|
||||
logger.error("Codex Responses API error ({}): {}", response.status_code, detail)
|
||||
raise ImageGenerationError(
|
||||
f"Codex image generation failed (HTTP {response.status_code}): {detail}"
|
||||
) from exc
|
||||
|
||||
images, content_text = await _parse_codex_sse_images(response)
|
||||
|
||||
raw = {"status": "completed"}
|
||||
self._require_images(images, raw)
|
||||
|
||||
return GeneratedImageResponse(images=images, content=content_text, raw=raw)
|
||||
|
||||
|
||||
def _openai_size(
|
||||
model: str,
|
||||
aspect_ratio: str | None,
|
||||
image_size: str | None,
|
||||
) -> str:
|
||||
"""Resolve aspect ratio or image_size to an OpenAI Images API size string."""
|
||||
sizes, supported_sizes = _openai_size_options(model)
|
||||
explicit_size = _normalize_openai_image_size(image_size)
|
||||
if explicit_size and _openai_explicit_size_supported(
|
||||
explicit_size,
|
||||
supported_sizes=supported_sizes,
|
||||
):
|
||||
return explicit_size
|
||||
if explicit_size:
|
||||
logger.warning(
|
||||
"OpenAI image size '{}' is not supported by {}; using aspect ratio/default size",
|
||||
explicit_size,
|
||||
model,
|
||||
)
|
||||
if aspect_ratio and aspect_ratio in sizes:
|
||||
return sizes[aspect_ratio]
|
||||
return "1024x1024"
|
||||
|
||||
|
||||
def _openai_is_gpt_image_model(model: str) -> bool:
|
||||
normalized = model.lower()
|
||||
return normalized.startswith(("gpt-image", "chatgpt-image"))
|
||||
|
||||
|
||||
def _openai_size_options(model: str) -> tuple[dict[str, str], set[str] | None]:
|
||||
normalized = model.lower()
|
||||
if normalized.startswith("dall-e-2"):
|
||||
return _OPENAI_DALLE2_ASPECT_RATIO_SIZES, _OPENAI_DALLE2_SUPPORTED_SIZES
|
||||
if normalized.startswith("dall-e-3"):
|
||||
return _OPENAI_DALLE3_ASPECT_RATIO_SIZES, _OPENAI_DALLE3_SUPPORTED_SIZES
|
||||
if normalized.startswith("gpt-image-2"):
|
||||
return _OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES, None
|
||||
return _OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES, _OPENAI_GPT_IMAGE_SUPPORTED_SIZES
|
||||
|
||||
|
||||
def _normalize_openai_image_size(image_size: str | None) -> str | None:
|
||||
if not image_size:
|
||||
return None
|
||||
normalized = image_size.strip().lower()
|
||||
return normalized or None
|
||||
|
||||
|
||||
def _openai_explicit_size_supported(
|
||||
size: str,
|
||||
*,
|
||||
supported_sizes: set[str] | None,
|
||||
) -> bool:
|
||||
if supported_sizes is not None:
|
||||
return size in supported_sizes
|
||||
width, sep, height = size.partition("x")
|
||||
return bool(sep and width.isdecimal() and height.isdecimal())
|
||||
|
||||
|
||||
async def _openai_images_from_payload(
|
||||
client: httpx.AsyncClient,
|
||||
payload: dict[str, Any],
|
||||
) -> list[str]:
|
||||
"""Extract images from OpenAI Images API response.
|
||||
|
||||
Handles both ``b64_json`` (preferred) and ``url`` (downloaded) formats.
|
||||
"""
|
||||
images: list[str] = []
|
||||
for item in payload.get("data") or []:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
b64 = item.get("b64_json")
|
||||
if isinstance(b64, str) and b64:
|
||||
images.append(_b64_image_data_url(b64))
|
||||
continue
|
||||
url = item.get("url")
|
||||
if isinstance(url, str) and url:
|
||||
images.append(await _download_image_data_url(client, url))
|
||||
return images
|
||||
|
||||
|
||||
def _codex_responses_images_from_payload(payload: dict[str, Any]) -> list[str]:
|
||||
"""Extract images from Codex Responses API ``image_generation_call`` output."""
|
||||
images: list[str] = []
|
||||
for item in payload.get("output") or []:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if item.get("type") != "image_generation_call":
|
||||
continue
|
||||
result = item.get("result")
|
||||
if isinstance(result, str):
|
||||
images.append(result if result.startswith("data:image/") else _b64_image_data_url(result))
|
||||
continue
|
||||
if isinstance(result, dict):
|
||||
image_url = result.get("image_url") or result.get("image") or ""
|
||||
if isinstance(image_url, str):
|
||||
images.append(image_url if image_url.startswith("data:image/") else _b64_image_data_url(image_url))
|
||||
return images
|
||||
|
||||
|
||||
async def _parse_codex_sse_images(
|
||||
response: httpx.Response,
|
||||
) -> tuple[list[str], str]:
|
||||
"""Parse a Codex Responses API SSE stream for image generation output.
|
||||
|
||||
Returns ``(images, content_text)``.
|
||||
"""
|
||||
import json as _json
|
||||
|
||||
images: list[str] = []
|
||||
text_parts: list[str] = []
|
||||
|
||||
buffer: list[str] = []
|
||||
async for line_bytes in response.aiter_lines():
|
||||
line = line_bytes.strip()
|
||||
if line == "":
|
||||
if buffer:
|
||||
data_lines = []
|
||||
for bl in buffer:
|
||||
if bl.startswith("data:"):
|
||||
data_lines.append(bl[5:].strip())
|
||||
buffer.clear()
|
||||
if data_lines:
|
||||
raw = "".join(data_lines)
|
||||
if raw == "[DONE]":
|
||||
break
|
||||
try:
|
||||
event = _json.loads(raw)
|
||||
except Exception:
|
||||
continue
|
||||
ev_type = event.get("type", "")
|
||||
if ev_type in ("error", "response.failed"):
|
||||
logger.error("Codex SSE failure: {}", raw[:2000])
|
||||
_collect_images_from_sse_event(event, images)
|
||||
_collect_text_from_sse_event(event, text_parts)
|
||||
continue
|
||||
buffer.append(line)
|
||||
|
||||
# flush remaining
|
||||
if buffer:
|
||||
data_lines = [bl[5:].strip() for bl in buffer if bl.startswith("data:")]
|
||||
raw = "".join(data_lines)
|
||||
if raw and raw != "[DONE]":
|
||||
try:
|
||||
event = _json.loads(raw)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
_collect_images_from_sse_event(event, images)
|
||||
_collect_text_from_sse_event(event, text_parts)
|
||||
|
||||
return images, "".join(text_parts).strip()
|
||||
|
||||
|
||||
def _collect_images_from_sse_event(event: dict[str, Any], images: list[str]) -> None:
|
||||
if event.get("type") != "response.output_item.done":
|
||||
return
|
||||
item = event.get("item") or {}
|
||||
if item.get("type") != "image_generation_call":
|
||||
return
|
||||
result = item.get("result")
|
||||
if isinstance(result, str):
|
||||
if result.startswith("data:image/"):
|
||||
images.append(result)
|
||||
else:
|
||||
images.append(_b64_image_data_url(result))
|
||||
elif isinstance(result, dict):
|
||||
image_url = result.get("image_url") or result.get("image") or ""
|
||||
if isinstance(image_url, str):
|
||||
if image_url.startswith("data:image/"):
|
||||
images.append(image_url)
|
||||
else:
|
||||
images.append(_b64_image_data_url(image_url))
|
||||
|
||||
|
||||
def _collect_text_from_sse_event(event: dict[str, Any], text_parts: list[str]) -> None:
|
||||
if event.get("type") == "response.output_text.delta":
|
||||
delta = event.get("delta")
|
||||
if isinstance(delta, str) and delta:
|
||||
text_parts.append(delta)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# StepFun (阶跃星辰) image generation
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -886,8 +1449,11 @@ def _stepfun_images_from_payload(payload: dict[str, Any]) -> list[str]:
|
||||
# Provider registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
register_image_gen_provider(OpenRouterImageGenerationClient)
|
||||
register_image_gen_provider(AIHubMixImageGenerationClient)
|
||||
register_image_gen_provider(CodexImageGenerationClient)
|
||||
register_image_gen_provider(GeminiImageGenerationClient)
|
||||
register_image_gen_provider(OllamaImageGenerationClient)
|
||||
register_image_gen_provider(MiniMaxImageGenerationClient)
|
||||
register_image_gen_provider(OpenAIImageGenerationClient)
|
||||
register_image_gen_provider(OpenRouterImageGenerationClient)
|
||||
register_image_gen_provider(StepFunImageGenerationClient)
|
||||
|
||||
@ -11,6 +11,7 @@ import secrets
|
||||
import string
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
from collections.abc import Awaitable, Callable
|
||||
from ipaddress import ip_address
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@ -74,41 +75,43 @@ _THINKING_STYLE_MAP: dict[str, Any] = {
|
||||
"enable_thinking": lambda on: {"enable_thinking": on},
|
||||
"reasoning_split": lambda on: {"reasoning_split": on},
|
||||
}
|
||||
_GATEWAY_REASONING_STYLE_MAP: dict[str, Any] = {
|
||||
"reasoning_effort": lambda effort: {"reasoning": {"effort": effort}},
|
||||
}
|
||||
_MODEL_THINKING_STYLES: dict[str, str] = {
|
||||
**dict.fromkeys(_KIMI_THINKING_MODELS, "thinking_type"),
|
||||
**dict.fromkeys(_MIMO_THINKING_MODELS, "thinking_type"),
|
||||
}
|
||||
|
||||
|
||||
def _is_kimi_thinking_model(model_name: str) -> bool:
|
||||
"""Return True if model_name refers to a Kimi thinking-capable model.
|
||||
|
||||
Supports two forms:
|
||||
- Exact match: e.g. kimi-k2.5 / kimi-k2.6 in _KIMI_THINKING_MODELS
|
||||
- Slug match: moonshotai/kimi-k2.5 -> the part after the last "/"
|
||||
is checked against _KIMI_THINKING_MODELS
|
||||
|
||||
This covers both the native Moonshot provider (bare slug) and
|
||||
OpenRouter-style names (``"publisher/slug"``).
|
||||
"""
|
||||
name = model_name.lower()
|
||||
if name in _KIMI_THINKING_MODELS:
|
||||
return True
|
||||
if "/" in name and name.rsplit("/", 1)[1] in _KIMI_THINKING_MODELS:
|
||||
return True
|
||||
return False
|
||||
def _model_slug(model_name: str) -> str:
|
||||
return model_name.lower().rsplit("/", 1)[-1]
|
||||
|
||||
|
||||
def _is_mimo_thinking_model(model_name: str) -> bool:
|
||||
"""Return True if model_name refers to a MiMo thinking-capable model.
|
||||
def _model_thinking_style(model_name: str) -> str:
|
||||
return _MODEL_THINKING_STYLES.get(_model_slug(model_name), "")
|
||||
|
||||
Mirrors _is_kimi_thinking_model: gateway providers (e.g. OpenRouter
|
||||
routing ``xiaomi/mimo-v2.5-pro``) have no ``thinking_style`` on their
|
||||
spec, so the spec-driven branch in _build_kwargs misses them. The
|
||||
model-name path catches those cases.
|
||||
"""
|
||||
name = model_name.lower()
|
||||
if name in _MIMO_THINKING_MODELS:
|
||||
return True
|
||||
if "/" in name and name.rsplit("/", 1)[1] in _MIMO_THINKING_MODELS:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _thinking_styles_for(spec: ProviderSpec | None, model_name: str) -> list[str]:
|
||||
styles: list[str] = []
|
||||
if spec and spec.thinking_style:
|
||||
styles.append(spec.thinking_style)
|
||||
model_style = _model_thinking_style(model_name)
|
||||
if model_style and model_style not in styles:
|
||||
styles.append(model_style)
|
||||
return styles
|
||||
|
||||
|
||||
def _thinking_extra_body(style: str, thinking_enabled: bool) -> dict[str, Any] | None:
|
||||
builder = _THINKING_STYLE_MAP.get(style)
|
||||
return builder(thinking_enabled) if builder else None
|
||||
|
||||
|
||||
def _gateway_reasoning_extra_body(style: str, effort: str | None) -> dict[str, Any] | None:
|
||||
if not effort:
|
||||
return None
|
||||
builder = _GATEWAY_REASONING_STYLE_MAP.get(style)
|
||||
return builder(effort) if builder else None
|
||||
|
||||
|
||||
def _openai_compat_timeout_s() -> float:
|
||||
@ -461,6 +464,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
"""Strip non-standard keys, normalize tool_call IDs."""
|
||||
sanitized = LLMProvider._sanitize_request_messages(messages, _ALLOWED_MSG_KEYS)
|
||||
id_map: dict[str, str] = {}
|
||||
pending_tool_ids: dict[str, deque[str]] = {}
|
||||
force_string_content = bool(self._spec and self._spec.name == "deepseek")
|
||||
|
||||
def map_id(value: Any) -> Any:
|
||||
@ -468,15 +472,49 @@ class OpenAICompatProvider(LLMProvider):
|
||||
return value
|
||||
return id_map.setdefault(value, self._normalize_tool_call_id(value))
|
||||
|
||||
def unique_tool_id(value: Any, used_ids: set[str], idx: int) -> str:
|
||||
if isinstance(value, str) and value:
|
||||
base = map_id(value)
|
||||
else:
|
||||
base = _short_tool_id()
|
||||
if not isinstance(base, str) or not base:
|
||||
base = _short_tool_id()
|
||||
if base not in used_ids:
|
||||
return base
|
||||
seed = value if isinstance(value, str) and value else base
|
||||
salt = 1
|
||||
while True:
|
||||
candidate = self._normalize_tool_call_id(f"{seed}:{idx}:{salt}")
|
||||
if isinstance(candidate, str) and candidate not in used_ids:
|
||||
return candidate
|
||||
salt += 1
|
||||
|
||||
def map_tool_result_id(value: Any) -> Any:
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
queue = pending_tool_ids.get(value)
|
||||
if queue:
|
||||
mapped = queue.popleft()
|
||||
if not queue:
|
||||
pending_tool_ids.pop(value, None)
|
||||
return mapped
|
||||
return map_id(value)
|
||||
|
||||
for clean in sanitized:
|
||||
if isinstance(clean.get("tool_calls"), list):
|
||||
normalized = []
|
||||
for tc in clean["tool_calls"]:
|
||||
used_ids: set[str] = set()
|
||||
for idx, tc in enumerate(clean["tool_calls"]):
|
||||
if not isinstance(tc, dict):
|
||||
normalized.append(tc)
|
||||
continue
|
||||
tc_clean = dict(tc)
|
||||
tc_clean["id"] = map_id(tc_clean.get("id"))
|
||||
raw_id = tc_clean.get("id")
|
||||
mapped_id = unique_tool_id(raw_id, used_ids, idx)
|
||||
tc_clean["id"] = mapped_id
|
||||
used_ids.add(mapped_id)
|
||||
if isinstance(raw_id, str) and raw_id:
|
||||
pending_tool_ids.setdefault(raw_id, deque()).append(mapped_id)
|
||||
function = tc_clean.get("function")
|
||||
if isinstance(function, dict):
|
||||
function_clean = dict(function)
|
||||
@ -494,7 +532,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
# that mix non-empty content with tool_calls.
|
||||
clean["content"] = None
|
||||
if "tool_call_id" in clean and clean["tool_call_id"]:
|
||||
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
||||
clean["tool_call_id"] = map_tool_result_id(clean["tool_call_id"])
|
||||
if (
|
||||
force_string_content
|
||||
and not (clean.get("role") == "assistant" and clean.get("tool_calls"))
|
||||
@ -581,39 +619,27 @@ class OpenAICompatProvider(LLMProvider):
|
||||
if wire_effort and semantic_effort != "none":
|
||||
kwargs["reasoning_effort"] = wire_effort
|
||||
|
||||
# Provider-specific thinking parameters.
|
||||
# Only sent when reasoning_effort is explicitly configured so that
|
||||
# the provider default is preserved otherwise.
|
||||
# The mapping is driven by ProviderSpec.thinking_style so that adding
|
||||
# a new provider never requires touching this function.
|
||||
if spec and spec.thinking_style and reasoning_effort is not None:
|
||||
# Only send thinking controls when reasoning_effort is explicit so
|
||||
# omitting the config preserves each provider's default.
|
||||
if reasoning_effort is not None:
|
||||
thinking_enabled = semantic_effort not in ("none", "minimal")
|
||||
extra = _THINKING_STYLE_MAP.get(spec.thinking_style, lambda _: None)(thinking_enabled)
|
||||
if extra:
|
||||
kwargs.setdefault("extra_body", {}).update(extra)
|
||||
for thinking_style in _thinking_styles_for(spec, model_name):
|
||||
extra = _thinking_extra_body(thinking_style, thinking_enabled)
|
||||
if extra:
|
||||
kwargs.setdefault("extra_body", {}).update(extra)
|
||||
gateway_style = getattr(spec, "gateway_reasoning_style", "") if spec else ""
|
||||
if gateway_style and _model_thinking_style(model_name):
|
||||
extra = _gateway_reasoning_extra_body(gateway_style, semantic_effort)
|
||||
if extra:
|
||||
kwargs.setdefault("extra_body", {}).update(extra)
|
||||
|
||||
# Model-level thinking injection for Kimi thinking-capable models.
|
||||
# Strip any provider prefix (e.g. "moonshotai/") before the set lookup
|
||||
# so that OpenRouter-style names like "moonshotai/kimi-k2.5" are handled
|
||||
# identically to bare names like "kimi-k2.5".
|
||||
if reasoning_effort is not None and _is_kimi_thinking_model(model_name):
|
||||
thinking_enabled = semantic_effort not in ("none", "minimal")
|
||||
kwargs.setdefault("extra_body", {}).update(
|
||||
{"thinking": {"type": "enabled" if thinking_enabled else "disabled"}}
|
||||
)
|
||||
|
||||
# Model-level thinking injection for MiMo thinking-capable models.
|
||||
# Same shape as Kimi: gateway providers (OpenRouter, etc.) lack the
|
||||
# xiaomi_mimo spec's thinking_style, so the spec-driven branch above
|
||||
# misses them — match by model name to catch "xiaomi/mimo-v2.5-pro"
|
||||
# and friends. (Direct xiaomi_mimo requests are also covered here;
|
||||
# both branches write the same payload, so the dict update is a
|
||||
# safe no-op for already-handled cases.)
|
||||
if reasoning_effort is not None and _is_mimo_thinking_model(model_name):
|
||||
thinking_enabled = semantic_effort not in ("none", "minimal")
|
||||
kwargs.setdefault("extra_body", {}).update(
|
||||
{"thinking": {"type": "enabled" if thinking_enabled else "disabled"}}
|
||||
)
|
||||
# Moonshot rejects requests that carry both 'reasoning_effort'
|
||||
# and the native 'thinking' param. We already expressed the
|
||||
# user's intent via the provider-native shape, so drop the
|
||||
# redundant wire-level kwarg. Only kimi models need this —
|
||||
# Xiaomi's API accepts both params.
|
||||
if _model_slug(model_name) in _KIMI_THINKING_MODELS:
|
||||
kwargs.pop("reasoning_effort", None)
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
@ -628,8 +654,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
and semantic_effort not in ("none", "minimal")
|
||||
and (
|
||||
(spec and spec.thinking_style)
|
||||
or _is_kimi_thinking_model(model_name)
|
||||
or _is_mimo_thinking_model(model_name)
|
||||
or _model_thinking_style(model_name)
|
||||
)
|
||||
)
|
||||
implicit_deepseek_thinking = (
|
||||
@ -1097,6 +1122,15 @@ class OpenAICompatProvider(LLMProvider):
|
||||
if delta:
|
||||
_accum_legacy_function_call(getattr(delta, "function_call", None))
|
||||
|
||||
# Some providers (e.g. Zhipu/GLM) reuse the same tool_call id for
|
||||
# parallel tool calls in streaming mode. Deduplicate before building
|
||||
# the response so downstream tool messages don't collide.
|
||||
_seen_tc_ids: set[str] = set()
|
||||
for b in tc_bufs.values():
|
||||
if not b["id"] or b["id"] in _seen_tc_ids:
|
||||
b["id"] = _short_tool_id()
|
||||
_seen_tc_ids.add(b["id"])
|
||||
|
||||
return LLMResponse(
|
||||
content="".join(content_parts) or None,
|
||||
tool_calls=[
|
||||
|
||||
@ -15,6 +15,7 @@ def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str
|
||||
"""
|
||||
system_prompt = ""
|
||||
input_items: list[dict[str, Any]] = []
|
||||
used_item_ids: set[str] = set()
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
role = msg.get("role")
|
||||
@ -30,17 +31,19 @@ def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str
|
||||
|
||||
if role == "assistant":
|
||||
if isinstance(content, str) and content:
|
||||
message_id = _unique_item_id(f"msg_{idx}", used_item_ids)
|
||||
input_items.append({
|
||||
"type": "message", "role": "assistant",
|
||||
"content": [{"type": "output_text", "text": content}],
|
||||
"status": "completed", "id": f"msg_{idx}",
|
||||
"status": "completed", "id": message_id,
|
||||
})
|
||||
for tool_call in msg.get("tool_calls", []) or []:
|
||||
fn = tool_call.get("function") or {}
|
||||
call_id, item_id = split_tool_call_id(tool_call.get("id"))
|
||||
response_item_id = _unique_item_id(item_id or f"fc_{idx}", used_item_ids)
|
||||
input_items.append({
|
||||
"type": "function_call",
|
||||
"id": item_id or f"fc_{idx}",
|
||||
"id": response_item_id,
|
||||
"call_id": call_id or f"call_{idx}",
|
||||
"name": fn.get("name"),
|
||||
"arguments": fn.get("arguments") or "{}",
|
||||
@ -97,6 +100,20 @@ def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
return converted
|
||||
|
||||
|
||||
def _unique_item_id(item_id: str, used: set[str]) -> str:
|
||||
"""Return a Responses input item id that is unique within one request."""
|
||||
if item_id not in used:
|
||||
used.add(item_id)
|
||||
return item_id
|
||||
|
||||
suffix = 2
|
||||
while f"{item_id}_{suffix}" in used:
|
||||
suffix += 1
|
||||
unique = f"{item_id}_{suffix}"
|
||||
used.add(unique)
|
||||
return unique
|
||||
|
||||
|
||||
def split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
|
||||
"""Split a compound ``call_id|item_id`` string.
|
||||
|
||||
|
||||
@ -71,6 +71,11 @@ class ProviderSpec:
|
||||
# "reasoning_split" — {"reasoning_split": true/false} (MiniMax)
|
||||
thinking_style: str = ""
|
||||
|
||||
# Gateway-native reasoning control to pair with model-level thinking styles.
|
||||
# "reasoning_effort" — {"reasoning": {"effort": <none|minimal|...>}}
|
||||
# (OpenRouter)
|
||||
gateway_reasoning_style: str = ""
|
||||
|
||||
# When True, treat the "reasoning" response field as formal content
|
||||
# when "content" is empty. Only set this for providers (e.g. StepFun)
|
||||
# whose API returns the actual answer in "reasoning" instead of "content".
|
||||
@ -142,6 +147,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
detect_by_base_keyword="openrouter",
|
||||
default_api_base="https://openrouter.ai/api/v1",
|
||||
supports_prompt_caching=True,
|
||||
gateway_reasoning_style="reasoning_effort",
|
||||
),
|
||||
# Hugging Face Inference Providers: OpenAI-compatible router for chat models.
|
||||
ProviderSpec(
|
||||
@ -193,6 +199,18 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
default_api_base="https://api.siliconflow.cn/v1",
|
||||
),
|
||||
|
||||
# Novita AI: OpenAI-compatible gateway for hosted model APIs.
|
||||
ProviderSpec(
|
||||
name="novita",
|
||||
keywords=("novita",),
|
||||
env_key="NOVITA_API_KEY",
|
||||
display_name="Novita AI",
|
||||
backend="openai_compat",
|
||||
is_gateway=True,
|
||||
detect_by_base_keyword="novita",
|
||||
default_api_base="https://api.novita.ai/openai",
|
||||
),
|
||||
|
||||
# VolcEngine (火山引擎): OpenAI-compatible gateway, pay-per-use models
|
||||
ProviderSpec(
|
||||
name="volcengine",
|
||||
|
||||
@ -1,5 +1,9 @@
|
||||
# Agent Instructions
|
||||
|
||||
## Workspace Guidance
|
||||
|
||||
Use this file for project-specific preferences, recurring workflow conventions, and instructions you want the agent to remember for this workspace. Keep durable facts about the user in `USER.md`, personality/style guidance in `SOUL.md`, and long-term memory in `memory/MEMORY.md`.
|
||||
|
||||
## Scheduled Reminders
|
||||
|
||||
Before scheduling reminders, check available skills and follow skill guidance first.
|
||||
@ -10,10 +14,10 @@ Get USER_ID and CHANNEL from the current session (e.g., `8281248569` and `telegr
|
||||
|
||||
## Heartbeat Tasks
|
||||
|
||||
`HEARTBEAT.md` is checked on the configured heartbeat interval. Use file tools to manage periodic tasks:
|
||||
`HEARTBEAT.md` is checked on the configured heartbeat interval. Use file tools to manage periodic tasks.
|
||||
|
||||
- **Add**: `edit_file` to append new tasks
|
||||
- **Remove**: `edit_file` to delete completed tasks
|
||||
- **Rewrite**: `write_file` to replace all tasks
|
||||
- Use `apply_patch` for normal task-list updates, especially when adding, removing, or changing multiple lines.
|
||||
- Use `edit_file` only for small exact replacements copied from the current `HEARTBEAT.md`.
|
||||
- Use `write_file` for first creation or intentional full-file rewrites.
|
||||
|
||||
When the user asks for a recurring/periodic task, update `HEARTBEAT.md` instead of creating a one-time cron reminder.
|
||||
|
||||
@ -1,28 +0,0 @@
|
||||
# Tool Usage Notes
|
||||
|
||||
Tool signatures are provided automatically via function calling.
|
||||
This file documents non-obvious constraints and usage patterns.
|
||||
|
||||
## exec — Safety Limits
|
||||
|
||||
- Commands have a configurable timeout (default 60s)
|
||||
- Dangerous commands are blocked (rm -rf, format, dd, shutdown, etc.)
|
||||
- Output is truncated at 10,000 characters
|
||||
- `restrictToWorkspace` config can limit file access to the workspace
|
||||
|
||||
## grep — Content Search
|
||||
|
||||
- Use `grep` to search file contents inside the workspace
|
||||
- Default behavior returns only matching file paths (`output_mode="files_with_matches"`)
|
||||
- Supports optional `glob` filtering (e.g. `glob="*.py"`) plus `context_before` / `context_after`
|
||||
- Supports `type="py"`, `type="ts"`, `type="md"` and similar shorthand filters
|
||||
- Use `fixed_strings=true` for literal keywords containing regex characters
|
||||
- Use `output_mode="files_with_matches"` to get only matching file paths
|
||||
- Use `output_mode="count"` to size a search before reading full matches
|
||||
- Use `head_limit` and `offset` to page across results
|
||||
- Prefer this over `exec` for code and history searches
|
||||
- Binary or oversized files may be skipped to keep results readable
|
||||
|
||||
## cron — Scheduled Reminders
|
||||
|
||||
- Please refer to cron skill for usage.
|
||||
60
nanobot/templates/agent/tool_contract.md
Normal file
60
nanobot/templates/agent/tool_contract.md
Normal file
@ -0,0 +1,60 @@
|
||||
# Tool Usage Notes
|
||||
|
||||
Tool signatures are provided automatically via function calling. This section
|
||||
documents the general tool contract and non-obvious usage patterns.
|
||||
|
||||
## General Tool Contract
|
||||
|
||||
- Use the narrowest structured tool that directly matches the task.
|
||||
- Use read-only discovery before writes when state is uncertain.
|
||||
- Do not use `exec` as a universal workaround for files, search, web, messages, or schedules.
|
||||
- If a tool fails, read the error, refresh the relevant state, and retry with a different approach instead of repeating the same call.
|
||||
- After meaningful changes, verify with the smallest reliable check: re-read changed state, run targeted tests, or inspect command output.
|
||||
- Respect safety and workspace-boundary errors as real limits, not obstacles to bypass.
|
||||
|
||||
## Discovery and Reading
|
||||
|
||||
- Use `find_files` or `list_dir` to locate workspace paths before `read_file` when a path is uncertain.
|
||||
- Use `grep` for content search inside the workspace; prefer it over shell grep for ordinary searches.
|
||||
- `grep` defaults to `output_mode="files_with_matches"`; use `output_mode="content"` for matching lines with context.
|
||||
- Use `fixed_strings=true` for literal keywords containing regex characters.
|
||||
- Use `output_mode="count"` to size a broad search before reading full matches.
|
||||
- Use `head_limit` and `offset` to page across large result sets.
|
||||
- Binary or oversized files may be skipped to keep results readable.
|
||||
|
||||
## File and Coding Workflows
|
||||
|
||||
- For code or config changes, the default loop is: locate (`find_files`/`grep`), inspect (`read_file`), edit (`apply_patch`), then verify (`exec` or re-read).
|
||||
- Use `apply_patch` as the default code editing tool, especially for multi-file changes, structural edits, generated code, moves, adds, or deletes.
|
||||
- Use `apply_patch dry_run=true` when the patch is uncertain and you want validation plus a change summary before writing.
|
||||
- Use `edit_file` only for small exact replacements in one file, with `old_text` copied from `read_file`; add `occurrence`, `line_hint`, or `expected_replacements` when ambiguity matters.
|
||||
- Use `write_file` for new files or intentional full-file rewrites, not routine partial edits.
|
||||
- If `apply_patch` or `edit_file` fails, re-read with `force=true`, narrow the context, and try a smaller patch rather than switching to shell `sed` or `echo`.
|
||||
|
||||
## Process Execution
|
||||
|
||||
- Use `exec` for tests, builds, package commands, git commands, and other process execution.
|
||||
- Prefer dedicated file/search tools over `cat`, shell `find`, shell `grep`, `sed`, or `echo` for ordinary workspace inspection and edits.
|
||||
- Use non-interactive flags such as `-y` or `--yes` when available.
|
||||
- Commands have a configurable timeout (default 60s), dangerous commands are blocked, and output is truncated.
|
||||
- For long-running or interactive commands, pass `yield_time_ms`; if the process keeps running, continue with `write_stdin`.
|
||||
- Use `write_stdin` to poll, provide stdin, close stdin, wait for expected output with `wait_for`, or terminate an existing exec session.
|
||||
- Use `list_exec_sessions` to recover active session IDs after context shifts.
|
||||
|
||||
## Web and External Information
|
||||
|
||||
- Use web tools when the user asks for current information, a specific URL, or information likely to have changed.
|
||||
- Use `web_search` to find sources and `web_fetch` for a specific page or result that needs closer reading.
|
||||
- Do not invent freshness-sensitive facts when tools can verify them.
|
||||
|
||||
## Messaging and Media
|
||||
|
||||
- Use `message` to send content or local media to the user/channel.
|
||||
- `read_file` only reads content for your analysis; it does not deliver a file to the user.
|
||||
- When sending an existing local file, attach it through the message/media mechanism instead of pasting file contents unless the user asked for text.
|
||||
|
||||
## Scheduling and Background Work
|
||||
|
||||
- Use `cron` for scheduled reminders or recurring jobs; do not run `nanobot cron` through `exec`.
|
||||
- For heartbeat tasks, update `HEARTBEAT.md` according to the agent instructions.
|
||||
- Do not write reminders only to memory files when the user expects an actual notification.
|
||||
@ -3,15 +3,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import difflib
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
|
||||
TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file", "notebook_edit"})
|
||||
TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file", "apply_patch"})
|
||||
_MAX_SNAPSHOT_BYTES = 2 * 1024 * 1024
|
||||
_LIVE_EMIT_INTERVAL_S = 0.18
|
||||
_LIVE_EMIT_LINE_STEP = 24
|
||||
@ -154,19 +152,108 @@ def prepare_file_edit_tracker(
|
||||
workspace: Path | None,
|
||||
params: dict[str, Any] | None,
|
||||
) -> FileEditTracker | None:
|
||||
trackers = prepare_file_edit_trackers(
|
||||
call_id=call_id,
|
||||
tool_name=tool_name,
|
||||
tool=tool,
|
||||
workspace=workspace,
|
||||
params=params,
|
||||
)
|
||||
return trackers[0] if trackers else None
|
||||
|
||||
|
||||
def prepare_file_edit_trackers(
|
||||
*,
|
||||
call_id: str,
|
||||
tool_name: str,
|
||||
tool: Any,
|
||||
workspace: Path | None,
|
||||
params: dict[str, Any] | None,
|
||||
) -> list[FileEditTracker]:
|
||||
if not is_file_edit_tool(tool_name):
|
||||
return None
|
||||
return []
|
||||
paths = resolve_file_edit_paths(tool_name, tool, workspace, params)
|
||||
trackers: list[FileEditTracker] = []
|
||||
seen: set[Path] = set()
|
||||
for path in paths:
|
||||
try:
|
||||
resolved = path.resolve()
|
||||
except Exception:
|
||||
resolved = path
|
||||
if resolved in seen:
|
||||
continue
|
||||
seen.add(resolved)
|
||||
before = read_file_snapshot(path)
|
||||
trackers.append(FileEditTracker(
|
||||
call_id=str(call_id or ""),
|
||||
tool=tool_name,
|
||||
path=path,
|
||||
display_path=display_file_edit_path(path, workspace),
|
||||
before=before,
|
||||
))
|
||||
return trackers
|
||||
|
||||
|
||||
def resolve_file_edit_paths(
|
||||
tool_name: str,
|
||||
tool: Any,
|
||||
workspace: Path | None,
|
||||
params: dict[str, Any] | None,
|
||||
) -> list[Path]:
|
||||
if tool_name == "apply_patch":
|
||||
return _resolve_apply_patch_paths(tool, workspace, params)
|
||||
path = resolve_file_edit_path(tool, workspace, params)
|
||||
if path is None:
|
||||
return None
|
||||
before = read_file_snapshot(path)
|
||||
return FileEditTracker(
|
||||
call_id=str(call_id or ""),
|
||||
tool=tool_name,
|
||||
path=path,
|
||||
display_path=display_file_edit_path(path, workspace),
|
||||
before=before,
|
||||
)
|
||||
return []
|
||||
return [path]
|
||||
|
||||
|
||||
def _resolve_apply_patch_paths(
|
||||
tool: Any,
|
||||
workspace: Path | None,
|
||||
params: dict[str, Any] | None,
|
||||
) -> list[Path]:
|
||||
if not isinstance(params, dict):
|
||||
return []
|
||||
edits = params.get("edits")
|
||||
if not isinstance(edits, list) or not edits:
|
||||
return []
|
||||
if params.get("dry_run") is True:
|
||||
return []
|
||||
|
||||
resolved: list[Path] = []
|
||||
seen: set[Path] = set()
|
||||
for edit in edits:
|
||||
if not isinstance(edit, dict):
|
||||
continue
|
||||
raw_path = edit.get("path")
|
||||
if not isinstance(raw_path, str) or not raw_path.strip():
|
||||
continue
|
||||
path = _resolve_raw_file_edit_path(tool, workspace, raw_path)
|
||||
if path is not None and path not in seen:
|
||||
seen.add(path)
|
||||
resolved.append(path)
|
||||
return resolved
|
||||
|
||||
|
||||
def _resolve_raw_file_edit_path(
|
||||
tool: Any,
|
||||
workspace: Path | None,
|
||||
raw_path: str,
|
||||
) -> Path | None:
|
||||
resolver = getattr(tool, "_resolve", None)
|
||||
if callable(resolver):
|
||||
try:
|
||||
resolved = resolver(raw_path)
|
||||
if isinstance(resolved, Path):
|
||||
return resolved
|
||||
if resolved:
|
||||
return Path(resolved)
|
||||
except Exception:
|
||||
return None
|
||||
if workspace is None:
|
||||
return Path(raw_path).expanduser().resolve()
|
||||
return (workspace / raw_path).expanduser().resolve()
|
||||
|
||||
|
||||
def build_file_edit_start_event(
|
||||
@ -304,6 +391,9 @@ class StreamingFileEditTracker:
|
||||
self._states[key] = state
|
||||
|
||||
state.apply_delta(payload)
|
||||
if state.name == "apply_patch":
|
||||
await self._update_apply_patch(state)
|
||||
return
|
||||
if state.name not in {"write_file", "edit_file"}:
|
||||
return
|
||||
if state.path is None:
|
||||
@ -343,10 +433,80 @@ class StreamingFileEditTracker:
|
||||
deleted=deleted,
|
||||
)])
|
||||
|
||||
async def _update_apply_patch(self, state: _StreamingFileEditState) -> None:
|
||||
if _json_bool_true(state.arguments, "dry_run"):
|
||||
return
|
||||
tool = self._tools.get("apply_patch") if hasattr(self._tools, "get") else None
|
||||
events: list[dict[str, Any]] = []
|
||||
now = time.monotonic()
|
||||
|
||||
path_matches = list(re.finditer(r'"path"\s*:\s*"([^"]+)"', state.arguments))
|
||||
if not path_matches:
|
||||
return
|
||||
|
||||
for i, m in enumerate(path_matches):
|
||||
raw_path = m.group(1)
|
||||
path = _resolve_raw_file_edit_path(tool, self._workspace, raw_path)
|
||||
if path is None:
|
||||
continue
|
||||
|
||||
segment_start = m.start()
|
||||
segment_end = path_matches[i + 1].start() if i + 1 < len(path_matches) else len(state.arguments)
|
||||
segment = state.arguments[segment_start:segment_end]
|
||||
|
||||
action_match = re.search(r'"action"\s*:\s*"(replace|add|delete)"', segment)
|
||||
action = action_match.group(1) if action_match else "replace"
|
||||
|
||||
old_text = _extract_json_string_prefix(segment, "old_text") or ""
|
||||
new_text = _extract_json_string_prefix(segment, "new_text") or ""
|
||||
|
||||
added = _text_line_count(new_text) if action in ("replace", "add") else 0
|
||||
deleted = _text_line_count(old_text) if action in ("replace", "delete") else 0
|
||||
delete_file = action == "delete"
|
||||
|
||||
file_state = state.patch_files.get(raw_path)
|
||||
if file_state is None:
|
||||
tracker = FileEditTracker(
|
||||
call_id=state.call_id or state.key,
|
||||
tool="apply_patch",
|
||||
path=path,
|
||||
display_path=display_file_edit_path(path, self._workspace),
|
||||
before=read_file_snapshot(path),
|
||||
)
|
||||
file_state = _StreamingPatchFileState(tracker=tracker)
|
||||
state.patch_files[raw_path] = file_state
|
||||
if delete_file and added == 0 and deleted == 0 and file_state.tracker.before.countable:
|
||||
deleted = _text_line_count(file_state.tracker.before.text or "")
|
||||
if not file_state.should_emit(added, deleted, now):
|
||||
continue
|
||||
file_state.mark_emitted(added, deleted, now)
|
||||
events.append(build_file_edit_live_event(
|
||||
file_state.tracker,
|
||||
added=added,
|
||||
deleted=deleted,
|
||||
))
|
||||
if events:
|
||||
await self._emit(events)
|
||||
|
||||
async def flush(self) -> None:
|
||||
events: list[dict[str, Any]] = []
|
||||
now = time.monotonic()
|
||||
for state in self._states.values():
|
||||
for file_state in state.patch_files.values():
|
||||
added, deleted = file_state.last_added, file_state.last_deleted
|
||||
if not file_state.emitted_once:
|
||||
continue
|
||||
if (
|
||||
file_state.last_emitted_added == added
|
||||
and file_state.last_emitted_deleted == deleted
|
||||
):
|
||||
continue
|
||||
file_state.mark_emitted(added, deleted, now)
|
||||
events.append(build_file_edit_live_event(
|
||||
file_state.tracker,
|
||||
added=added,
|
||||
deleted=deleted,
|
||||
))
|
||||
if state.tracker is None:
|
||||
continue
|
||||
added, deleted = state.live_diff_counts()
|
||||
@ -367,12 +527,14 @@ class StreamingFileEditTracker:
|
||||
|
||||
def apply_final_call_ids(self, final_tool_calls: list[Any]) -> None:
|
||||
"""Keep final start/end events keyed to any earlier streamed placeholder."""
|
||||
used_canonicals: set[str] = set()
|
||||
for tool_call in final_tool_calls:
|
||||
canonical = self.canonical_call_id_for(tool_call)
|
||||
if canonical:
|
||||
if canonical and canonical not in used_canonicals:
|
||||
try:
|
||||
tool_call.id = canonical
|
||||
except Exception:
|
||||
used_canonicals.add(canonical)
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
|
||||
def canonical_call_id_for(self, tool_call: Any) -> str | None:
|
||||
@ -389,6 +551,10 @@ class StreamingFileEditTracker:
|
||||
"""Mark streamed edits as failed when no final tool call will run."""
|
||||
events: list[dict[str, Any]] = []
|
||||
for state in self._states.values():
|
||||
for file_state in state.patch_files.values():
|
||||
if any(state.matches_final_tool_call(tool_call) for tool_call in final_tool_calls):
|
||||
continue
|
||||
events.append(build_file_edit_error_event(file_state.tracker, error))
|
||||
if state.tracker is None:
|
||||
continue
|
||||
if any(state.matches_final_tool_call(tool_call) for tool_call in final_tool_calls):
|
||||
@ -492,6 +658,39 @@ class _StreamingJsonStringField:
|
||||
self.last_char_cr = False
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _StreamingPatchFileState:
|
||||
tracker: FileEditTracker
|
||||
emitted_once: bool = False
|
||||
last_emitted_added: int = -1
|
||||
last_emitted_deleted: int = -1
|
||||
last_emit_at: float = 0.0
|
||||
last_added: int = 0
|
||||
last_deleted: int = 0
|
||||
|
||||
def should_emit(self, added: int, deleted: int, now: float) -> bool:
|
||||
self.last_added = added
|
||||
self.last_deleted = deleted
|
||||
if not self.emitted_once:
|
||||
return True
|
||||
if added == self.last_emitted_added and deleted == self.last_emitted_deleted:
|
||||
return False
|
||||
if max(
|
||||
abs(added - self.last_emitted_added),
|
||||
abs(deleted - self.last_emitted_deleted),
|
||||
) >= _LIVE_EMIT_LINE_STEP:
|
||||
return True
|
||||
return now - self.last_emit_at >= _LIVE_EMIT_INTERVAL_S
|
||||
|
||||
def mark_emitted(self, added: int, deleted: int, now: float) -> None:
|
||||
self.emitted_once = True
|
||||
self.last_added = added
|
||||
self.last_deleted = deleted
|
||||
self.last_emitted_added = added
|
||||
self.last_emitted_deleted = deleted
|
||||
self.last_emit_at = now
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _StreamingFileEditState:
|
||||
key: str
|
||||
@ -509,6 +708,7 @@ class _StreamingFileEditState:
|
||||
new_text: _StreamingJsonStringField = field(
|
||||
default_factory=lambda: _StreamingJsonStringField("new_text")
|
||||
)
|
||||
patch_files: dict[str, _StreamingPatchFileState] = field(default_factory=dict)
|
||||
emitted_once: bool = False
|
||||
last_emitted_added: int = -1
|
||||
last_emitted_deleted: int = -1
|
||||
@ -531,6 +731,7 @@ class _StreamingFileEditState:
|
||||
self.content.reset()
|
||||
self.old_text.reset()
|
||||
self.new_text.reset()
|
||||
self.patch_files.clear()
|
||||
return
|
||||
delta = payload.get("arguments_delta")
|
||||
if isinstance(delta, str) and delta:
|
||||
@ -590,6 +791,14 @@ class _StreamingFileEditState:
|
||||
name = getattr(tool_call, "name", None)
|
||||
if name != self.name:
|
||||
return False
|
||||
if self.name == "apply_patch":
|
||||
arguments = getattr(tool_call, "arguments", None)
|
||||
if not isinstance(arguments, dict):
|
||||
return False
|
||||
edits = arguments.get("edits")
|
||||
if not isinstance(edits, list):
|
||||
return False
|
||||
return '"edits"' in self.arguments
|
||||
arguments = getattr(tool_call, "arguments", None)
|
||||
if not isinstance(arguments, dict):
|
||||
return False
|
||||
@ -612,6 +821,51 @@ def _stream_key(payload: dict[str, Any]) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _json_bool_true(source: str, key: str) -> bool:
|
||||
return re.search(rf'"{re.escape(key)}"\s*:\s*true\b', source) is not None
|
||||
|
||||
|
||||
def _extract_json_string_prefix(source: str, key: str) -> str | None:
|
||||
match = re.search(rf'"{re.escape(key)}"\s*:\s*"', source)
|
||||
if match is None:
|
||||
return None
|
||||
out: list[str] = []
|
||||
i = match.end()
|
||||
escape = False
|
||||
while i < len(source):
|
||||
ch = source[i]
|
||||
if escape:
|
||||
escape = False
|
||||
if ch == "n":
|
||||
out.append("\n")
|
||||
elif ch == "r":
|
||||
out.append("\r")
|
||||
elif ch == "t":
|
||||
out.append("\t")
|
||||
elif ch == "u":
|
||||
digits = source[i + 1:i + 5]
|
||||
if len(digits) < 4:
|
||||
break
|
||||
try:
|
||||
out.append(chr(int(digits, 16)))
|
||||
except ValueError:
|
||||
break
|
||||
i += 4
|
||||
else:
|
||||
out.append(ch)
|
||||
i += 1
|
||||
continue
|
||||
if ch == "\\":
|
||||
escape = True
|
||||
i += 1
|
||||
continue
|
||||
if ch == '"':
|
||||
return "".join(out)
|
||||
out.append(ch)
|
||||
i += 1
|
||||
return "".join(out)
|
||||
|
||||
|
||||
def _extract_complete_json_string(source: str, key: str) -> str | None:
|
||||
match = re.search(rf'"{re.escape(key)}"\s*:\s*"', source)
|
||||
if match is None:
|
||||
@ -704,77 +958,4 @@ def _predict_after_text(
|
||||
return before_text.replace(old_text, new_text)
|
||||
return before_text.replace(old_text, new_text, 1)
|
||||
return None
|
||||
if tool_name == "notebook_edit":
|
||||
return _predict_notebook_after_text(params, before_text)
|
||||
return None
|
||||
|
||||
|
||||
def _predict_notebook_after_text(params: dict[str, Any], before_text: str) -> str | None:
|
||||
try:
|
||||
nb = json.loads(before_text) if before_text.strip() else _empty_notebook()
|
||||
except Exception:
|
||||
return None
|
||||
cells = nb.get("cells")
|
||||
if not isinstance(cells, list):
|
||||
return None
|
||||
try:
|
||||
cell_index = int(params.get("cell_index", 0))
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
new_source = params.get("new_source")
|
||||
source = new_source if isinstance(new_source, str) else ""
|
||||
cell_type = (
|
||||
params.get("cell_type") if params.get("cell_type") in ("code", "markdown") else "code"
|
||||
)
|
||||
mode = (
|
||||
params.get("edit_mode")
|
||||
if params.get("edit_mode") in ("replace", "insert", "delete")
|
||||
else "replace"
|
||||
)
|
||||
if mode == "delete":
|
||||
if 0 <= cell_index < len(cells):
|
||||
cells.pop(cell_index)
|
||||
else:
|
||||
return None
|
||||
elif mode == "insert":
|
||||
insert_at = min(max(cell_index + 1, 0), len(cells))
|
||||
cells.insert(insert_at, _new_notebook_cell(source, str(cell_type)))
|
||||
else:
|
||||
if not (0 <= cell_index < len(cells)):
|
||||
return None
|
||||
cell = cells[cell_index]
|
||||
if not isinstance(cell, dict):
|
||||
return None
|
||||
cell["source"] = source
|
||||
cell["cell_type"] = cell_type
|
||||
if cell_type == "code":
|
||||
cell.setdefault("outputs", [])
|
||||
cell.setdefault("execution_count", None)
|
||||
else:
|
||||
cell.pop("outputs", None)
|
||||
cell.pop("execution_count", None)
|
||||
nb["cells"] = cells
|
||||
try:
|
||||
return json.dumps(nb, indent=1, ensure_ascii=False)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _empty_notebook() -> dict[str, Any]:
|
||||
return {
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5,
|
||||
"metadata": {
|
||||
"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
|
||||
"language_info": {"name": "python"},
|
||||
},
|
||||
"cells": [],
|
||||
}
|
||||
|
||||
|
||||
def _new_notebook_cell(source: str, cell_type: str) -> dict[str, Any]:
|
||||
cell: dict[str, Any] = {"cell_type": cell_type, "source": source, "metadata": {}}
|
||||
if cell_type == "code":
|
||||
cell["outputs"] = []
|
||||
cell["execution_count"] = None
|
||||
return cell
|
||||
|
||||
@ -576,7 +576,7 @@ def build_status_content(
|
||||
|
||||
|
||||
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
|
||||
"""Sync bundled templates to workspace. Only creates missing files."""
|
||||
"""Sync bundled templates to workspace. Creates missing files without overwriting user files."""
|
||||
from importlib.resources import files as pkg_files
|
||||
|
||||
try:
|
||||
@ -589,10 +589,11 @@ def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]
|
||||
added: list[str] = []
|
||||
|
||||
def _write(src, dest: Path):
|
||||
content = src.read_text(encoding="utf-8") if src else ""
|
||||
if dest.exists():
|
||||
return
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
dest.write_text(src.read_text(encoding="utf-8") if src else "", encoding="utf-8")
|
||||
dest.write_text(content, encoding="utf-8")
|
||||
added.append(str(dest.relative_to(workspace)))
|
||||
|
||||
for item in tpl.iterdir():
|
||||
|
||||
@ -11,8 +11,10 @@ _TOOL_FORMATS: dict[str, tuple[list[str], str, bool, bool]] = {
|
||||
"read_file": (["path", "file_path"], "read {}", True, False),
|
||||
"write_file": (["path", "file_path"], "write {}", True, False),
|
||||
"edit": (["file_path", "path"], "edit {}", True, False),
|
||||
"find_files": (["query", "glob", "path"], "find {}", False, False),
|
||||
"grep": (["pattern"], 'grep "{}"', False, False),
|
||||
"exec": (["command"], "$ {}", False, True),
|
||||
"list_exec_sessions": ([], "exec sessions", False, False),
|
||||
"web_search": (["query"], 'search "{}"', False, False),
|
||||
"web_fetch": (["url"], "fetch {}", True, False),
|
||||
"list_dir": (["path"], "ls {}", True, False),
|
||||
@ -81,6 +83,8 @@ def _extract_arg(tc, key_args: list[str]) -> str | None:
|
||||
|
||||
def _fmt_known(tc, fmt: tuple, max_length: int = 40) -> str:
|
||||
"""Format a registered tool using its template."""
|
||||
if not fmt[0] and "{}" not in fmt[1]:
|
||||
return fmt[1]
|
||||
val = _extract_arg(tc, fmt[0])
|
||||
if val is None:
|
||||
return tc.name
|
||||
|
||||
@ -73,12 +73,16 @@ def _mask_secret_hint(secret: str | None) -> str | None:
|
||||
def _provider_requires_api_key(spec: Any) -> bool:
|
||||
if spec.backend == "azure_openai":
|
||||
return True
|
||||
if spec.is_oauth:
|
||||
return False
|
||||
if spec.is_local or spec.is_direct:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _provider_configured_for_settings(spec: Any, provider_config: Any) -> bool:
|
||||
if spec.is_oauth:
|
||||
return True
|
||||
if _provider_requires_api_key(spec):
|
||||
return bool(provider_config.api_key)
|
||||
return bool(
|
||||
|
||||
@ -139,6 +139,13 @@ class TestLoadBootstrapFiles:
|
||||
for name in ContextBuilder.BOOTSTRAP_FILES:
|
||||
assert f"## {name}" in result
|
||||
|
||||
def test_legacy_tools_md_is_not_bootstrapped(self, tmp_path):
|
||||
(tmp_path / "TOOLS.md").write_text("workspace tool notes", encoding="utf-8")
|
||||
builder = _builder(tmp_path)
|
||||
result = builder._load_bootstrap_files()
|
||||
assert "TOOLS.md" not in result
|
||||
assert "workspace tool notes" not in result
|
||||
|
||||
def test_utf8_content(self, tmp_path):
|
||||
(tmp_path / "AGENTS.md").write_text("用中文回复", encoding="utf-8")
|
||||
builder = _builder(tmp_path)
|
||||
@ -171,6 +178,37 @@ class TestIsTemplateContent:
|
||||
assert ContextBuilder._is_template_content("totally different", "memory/MEMORY.md") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bundled bootstrap templates
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBundledToolContract:
|
||||
def test_tool_contract_balances_general_and_coding_workflows(self):
|
||||
from importlib.resources import files as pkg_files
|
||||
|
||||
tpl = pkg_files("nanobot") / "templates" / "agent" / "tool_contract.md"
|
||||
content = tpl.read_text(encoding="utf-8")
|
||||
|
||||
assert "## General Tool Contract" in content
|
||||
assert "Use the narrowest structured tool" in content
|
||||
assert "Do not use `exec` as a universal workaround" in content
|
||||
assert "## File and Coding Workflows" in content
|
||||
assert "apply_patch" in content
|
||||
assert "## Web and External Information" in content
|
||||
assert "## Messaging and Media" in content
|
||||
assert "## Scheduling and Background Work" in content
|
||||
assert "pure coding" not in content.lower()
|
||||
|
||||
def test_tool_contract_is_injected_without_workspace_file(self, tmp_path):
|
||||
builder = _builder(tmp_path)
|
||||
prompt = builder.build_system_prompt()
|
||||
|
||||
assert "# Tool Usage Notes" in prompt
|
||||
assert "## General Tool Contract" in prompt
|
||||
assert "Do not use `exec` as a universal workaround" in prompt
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_user_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -346,6 +346,26 @@ class TestSyncWorkspaceTemplates:
|
||||
content = (workspace / "AGENTS.md").read_text()
|
||||
assert content == "existing content"
|
||||
|
||||
def test_does_not_create_tools_md(self, tmp_path):
|
||||
"""Tool contract is injected internally, not copied into user workspaces."""
|
||||
workspace = tmp_path / "workspace"
|
||||
|
||||
added = sync_workspace_templates(workspace, silent=True)
|
||||
|
||||
assert "TOOLS.md" not in added
|
||||
assert not (workspace / "TOOLS.md").exists()
|
||||
|
||||
def test_preserves_existing_tools_md_without_overwriting(self, tmp_path):
|
||||
"""Legacy user workspaces may have TOOLS.md; sync should leave it untouched."""
|
||||
workspace = tmp_path / "workspace"
|
||||
workspace.mkdir(parents=True)
|
||||
tools_path = workspace / "TOOLS.md"
|
||||
tools_path.write_text("custom tool notes", encoding="utf-8")
|
||||
|
||||
sync_workspace_templates(workspace, silent=True)
|
||||
|
||||
assert tools_path.read_text(encoding="utf-8") == "custom tool notes"
|
||||
|
||||
def test_creates_memory_directory(self, tmp_path):
|
||||
"""Should create memory directory structure."""
|
||||
workspace = tmp_path / "workspace"
|
||||
|
||||
1514
tests/channels/test_signal_channel.py
Normal file
1514
tests/channels/test_signal_channel.py
Normal file
File diff suppressed because it is too large
Load Diff
525
tests/channels/test_signal_markdown.py
Normal file
525
tests/channels/test_signal_markdown.py
Normal file
@ -0,0 +1,525 @@
|
||||
"""Unit tests for the Signal markdown → plain text + textStyle converter."""
|
||||
|
||||
from nanobot.channels.signal import _markdown_to_signal, _partition_styles
|
||||
from nanobot.utils.helpers import split_message
|
||||
|
||||
|
||||
def _utf16_len(s: str) -> int:
|
||||
return len(s.encode("utf-16-le")) // 2
|
||||
|
||||
|
||||
def styles_for(plain: str, text_styles: list[str]) -> dict[str, list[str]]:
|
||||
"""Return a dict mapping each styled substring to its style list."""
|
||||
result: dict[str, list[str]] = {}
|
||||
for entry in text_styles:
|
||||
start_s, length_s, style = entry.split(":", 2)
|
||||
start, length = int(start_s), int(length_s)
|
||||
span = plain[start : start + length]
|
||||
result.setdefault(span, []).append(style)
|
||||
return result
|
||||
|
||||
|
||||
def utf16_styles_for(plain: str, text_styles: list[str]) -> dict[str, list[str]]:
|
||||
"""Like styles_for, but slices `plain` using UTF-16 offsets (Signal's units)."""
|
||||
encoded = plain.encode("utf-16-le")
|
||||
result: dict[str, list[str]] = {}
|
||||
for entry in text_styles:
|
||||
start_s, length_s, style = entry.split(":", 2)
|
||||
start, length = int(start_s), int(length_s)
|
||||
span = encoded[start * 2 : (start + length) * 2].decode("utf-16-le")
|
||||
result.setdefault(span, []).append(style)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Basic cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_empty():
|
||||
plain, styles = _markdown_to_signal("")
|
||||
assert plain == ""
|
||||
assert styles == []
|
||||
|
||||
|
||||
def test_plain_text():
|
||||
plain, styles = _markdown_to_signal("hello world")
|
||||
assert plain == "hello world"
|
||||
assert styles == []
|
||||
|
||||
|
||||
def test_bold_stars():
|
||||
plain, styles = _markdown_to_signal("say **hello** now")
|
||||
assert plain == "say hello now"
|
||||
assert styles_for(plain, styles) == {"hello": ["BOLD"]}
|
||||
|
||||
|
||||
def test_bold_underscores():
|
||||
plain, styles = _markdown_to_signal("say __hello__ now")
|
||||
assert plain == "say hello now"
|
||||
assert styles_for(plain, styles) == {"hello": ["BOLD"]}
|
||||
|
||||
|
||||
def test_italic_star():
|
||||
plain, styles = _markdown_to_signal("say *hello* now")
|
||||
assert plain == "say hello now"
|
||||
assert styles_for(plain, styles) == {"hello": ["ITALIC"]}
|
||||
|
||||
|
||||
def test_italic_underscore():
|
||||
plain, styles = _markdown_to_signal("say _hello_ now")
|
||||
assert plain == "say hello now"
|
||||
assert styles_for(plain, styles) == {"hello": ["ITALIC"]}
|
||||
|
||||
|
||||
def test_strikethrough():
|
||||
plain, styles = _markdown_to_signal("say ~~hello~~ now")
|
||||
assert plain == "say hello now"
|
||||
assert styles_for(plain, styles) == {"hello": ["STRIKETHROUGH"]}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Code
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_inline_code():
|
||||
plain, styles = _markdown_to_signal("run `ls -la` here")
|
||||
assert plain == "run ls -la here"
|
||||
assert styles_for(plain, styles) == {"ls -la": ["MONOSPACE"]}
|
||||
|
||||
|
||||
def test_code_block():
|
||||
plain, styles = _markdown_to_signal("```\nprint('hi')\n```")
|
||||
assert "print('hi')" in plain
|
||||
assert styles_for(plain, styles).get("print('hi')\n") == ["MONOSPACE"] or "MONOSPACE" in str(
|
||||
styles_for(plain, styles)
|
||||
)
|
||||
|
||||
|
||||
def test_code_block_with_lang():
|
||||
plain, styles = _markdown_to_signal("```python\ncode\n```")
|
||||
assert "code" in plain
|
||||
assert any("MONOSPACE" in s for s in styles)
|
||||
|
||||
|
||||
def test_code_block_not_processed_further():
|
||||
"""Markdown inside a code block must not be styled."""
|
||||
plain, styles = _markdown_to_signal("```\n**not bold**\n```")
|
||||
assert "**not bold**" in plain
|
||||
# Only MONOSPACE should be applied, no BOLD
|
||||
for entry in styles:
|
||||
assert "BOLD" not in entry
|
||||
|
||||
|
||||
def test_inline_code_not_processed_further():
|
||||
"""Markdown inside inline code must not be styled."""
|
||||
plain, styles = _markdown_to_signal("use `**raw**` please")
|
||||
assert "**raw**" in plain
|
||||
for entry in styles:
|
||||
assert "BOLD" not in entry
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Headers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_header_becomes_bold():
|
||||
plain, styles = _markdown_to_signal("# My Title")
|
||||
assert plain == "My Title"
|
||||
assert styles_for(plain, styles) == {"My Title": ["BOLD"]}
|
||||
|
||||
|
||||
def test_h2_becomes_bold():
|
||||
plain, styles = _markdown_to_signal("## Sub-section")
|
||||
assert plain == "Sub-section"
|
||||
assert styles_for(plain, styles) == {"Sub-section": ["BOLD"]}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Blockquotes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_blockquote_strips_marker():
|
||||
plain, styles = _markdown_to_signal("> some quote")
|
||||
assert plain == "some quote"
|
||||
assert styles == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lists
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_bullet_dash():
|
||||
plain, styles = _markdown_to_signal("- item one")
|
||||
assert plain == "• item one"
|
||||
|
||||
|
||||
def test_bullet_star():
|
||||
plain, styles = _markdown_to_signal("* item two")
|
||||
assert plain == "• item two"
|
||||
|
||||
|
||||
def test_numbered_list():
|
||||
plain, styles = _markdown_to_signal("1. first\n2. second")
|
||||
assert "1. first" in plain
|
||||
assert "2. second" in plain
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Links
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_link_text_differs_from_url():
|
||||
plain, styles = _markdown_to_signal("[Click here](https://example.com)")
|
||||
assert plain == "Click here (https://example.com)"
|
||||
assert styles == []
|
||||
|
||||
|
||||
def test_link_text_equals_url():
|
||||
plain, styles = _markdown_to_signal("[https://example.com](https://example.com)")
|
||||
assert plain == "https://example.com"
|
||||
assert styles == []
|
||||
|
||||
|
||||
def test_link_text_equals_url_without_scheme():
|
||||
plain, styles = _markdown_to_signal("[example.com](https://example.com)")
|
||||
assert plain == "https://example.com"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mixed / nesting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_bold_and_italic_adjacent():
|
||||
plain, styles = _markdown_to_signal("**bold** and *italic*")
|
||||
assert plain == "bold and italic"
|
||||
sd = styles_for(plain, styles)
|
||||
assert sd.get("bold") == ["BOLD"]
|
||||
assert sd.get("italic") == ["ITALIC"]
|
||||
|
||||
|
||||
def test_header_with_inline_code():
|
||||
"""Header becomes BOLD; code inside becomes MONOSPACE (not double-BOLD)."""
|
||||
plain, styles = _markdown_to_signal("# Use `grep`")
|
||||
assert plain == "Use grep"
|
||||
sd = styles_for(plain, styles)
|
||||
assert "BOLD" in sd.get("Use ", []) or "BOLD" in str(styles)
|
||||
assert "MONOSPACE" in sd.get("grep", [])
|
||||
|
||||
|
||||
def test_multiline_mixed():
|
||||
md = "**Title**\n\nSome *italic* text.\n\n- bullet\n- another"
|
||||
plain, styles = _markdown_to_signal(md)
|
||||
assert "Title" in plain
|
||||
assert "italic" in plain
|
||||
assert "• bullet" in plain
|
||||
sd = styles_for(plain, styles)
|
||||
assert "BOLD" in sd.get("Title", [])
|
||||
assert "ITALIC" in sd.get("italic", [])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Table rendering
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_table_rendered_as_monospace():
|
||||
md = "| A | B |\n| - | - |\n| 1 | 2 |"
|
||||
plain, styles = _markdown_to_signal(md)
|
||||
assert "A" in plain and "B" in plain
|
||||
assert any("MONOSPACE" in s for s in styles)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Style range format
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_style_range_format():
|
||||
"""Each style entry must be 'start:length:STYLE'."""
|
||||
_, styles = _markdown_to_signal("**bold** text")
|
||||
for entry in styles:
|
||||
parts = entry.split(":")
|
||||
assert len(parts) == 3
|
||||
assert parts[0].isdigit()
|
||||
assert parts[1].isdigit()
|
||||
assert parts[2] in {"BOLD", "ITALIC", "STRIKETHROUGH", "MONOSPACE", "SPOILER"}
|
||||
|
||||
|
||||
def test_style_ranges_are_within_bounds():
|
||||
text = "hello **world** end"
|
||||
plain, styles = _markdown_to_signal(text)
|
||||
for entry in styles:
|
||||
start_s, length_s, _ = entry.split(":", 2)
|
||||
start, length = int(start_s), int(length_s)
|
||||
assert start >= 0
|
||||
assert start + length <= len(plain)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Non-BMP / UTF-16 offsets
|
||||
#
|
||||
# Signal's BodyRange (and signal-cli's textStyle) interprets start/length in
|
||||
# UTF-16 code units. Python's len() counts code points, so characters outside
|
||||
# the BMP (emojis, supplementary CJK) shift offsets by +1 per occurrence.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def assert_within_utf16_bounds(plain: str, styles: list[str]) -> None:
|
||||
limit = _utf16_len(plain)
|
||||
for entry in styles:
|
||||
start_s, length_s, _ = entry.split(":", 2)
|
||||
start, length = int(start_s), int(length_s)
|
||||
assert start >= 0
|
||||
assert start + length <= limit, f"range {entry} exceeds utf-16 length {limit} of {plain!r}"
|
||||
|
||||
|
||||
def test_bold_with_emoji_inside():
|
||||
plain, styles = _markdown_to_signal("**hi 🎉 bye**")
|
||||
assert plain == "hi 🎉 bye"
|
||||
assert utf16_styles_for(plain, styles) == {"hi 🎉 bye": ["BOLD"]}
|
||||
assert_within_utf16_bounds(plain, styles)
|
||||
|
||||
|
||||
def test_italic_with_trailing_emoji():
|
||||
plain, styles = _markdown_to_signal("*bye 🎉*")
|
||||
assert plain == "bye 🎉"
|
||||
assert utf16_styles_for(plain, styles) == {"bye 🎉": ["ITALIC"]}
|
||||
assert_within_utf16_bounds(plain, styles)
|
||||
|
||||
|
||||
def test_bold_after_emoji_prefix():
|
||||
plain, styles = _markdown_to_signal("🎉 **bold**")
|
||||
assert plain == "🎉 bold"
|
||||
assert utf16_styles_for(plain, styles) == {"bold": ["BOLD"]}
|
||||
assert_within_utf16_bounds(plain, styles)
|
||||
|
||||
|
||||
def test_bold_after_and_inside_emoji():
|
||||
plain, styles = _markdown_to_signal("🎉 **a 🎊 b**")
|
||||
assert plain == "🎉 a 🎊 b"
|
||||
assert utf16_styles_for(plain, styles) == {"a 🎊 b": ["BOLD"]}
|
||||
assert_within_utf16_bounds(plain, styles)
|
||||
|
||||
|
||||
def test_supplementary_cjk_in_bold():
|
||||
"""Non-BMP CJK (U+20BB7) proves the bug is UTF-16, not emoji-specific."""
|
||||
plain, styles = _markdown_to_signal("**𠮷野家**")
|
||||
assert plain == "𠮷野家"
|
||||
assert utf16_styles_for(plain, styles) == {"𠮷野家": ["BOLD"]}
|
||||
assert_within_utf16_bounds(plain, styles)
|
||||
|
||||
|
||||
def test_zwj_emoji_in_bold():
|
||||
"""ZWJ family sequence = multiple surrogate pairs + BMP ZWJs."""
|
||||
plain, styles = _markdown_to_signal("**hi 👨👩👧 bye**")
|
||||
assert plain == "hi 👨👩👧 bye"
|
||||
assert utf16_styles_for(plain, styles) == {"hi 👨👩👧 bye": ["BOLD"]}
|
||||
assert_within_utf16_bounds(plain, styles)
|
||||
|
||||
|
||||
def test_ascii_offsets_unchanged():
|
||||
"""ASCII-only path must produce the same offsets as before the UTF-16 fix."""
|
||||
plain, styles = _markdown_to_signal("**bold** plain *it*")
|
||||
assert plain == "bold plain it"
|
||||
assert sorted(styles) == sorted(["0:4:BOLD", "11:2:ITALIC"])
|
||||
|
||||
|
||||
def test_reported_daily_brief_pattern():
|
||||
"""Regression for the reported bug: a single non-BMP emoji shifts every
|
||||
subsequent styled span left by 1 UTF-16 unit, lopping off the last letter.
|
||||
"""
|
||||
md = (
|
||||
"**Weather**\n"
|
||||
"- Conditions: 🌩️ Thunderstorms\n\n"
|
||||
"**News**\n"
|
||||
"*World*\n"
|
||||
"*Local*\n\n"
|
||||
"**Quote of the Day**"
|
||||
)
|
||||
plain, styles = _markdown_to_signal(md)
|
||||
sd = utf16_styles_for(plain, styles)
|
||||
assert sd.get("Weather") == ["BOLD"]
|
||||
assert sd.get("News") == ["BOLD"]
|
||||
assert sd.get("World") == ["ITALIC"]
|
||||
assert sd.get("Local") == ["ITALIC"]
|
||||
assert sd.get("Quote of the Day") == ["BOLD"]
|
||||
assert_within_utf16_bounds(plain, styles)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Chunk redistribution
|
||||
#
|
||||
# split_message can break a long Signal payload into multiple chunks. The
|
||||
# style ranges from _markdown_to_signal are anchored to the full text, so
|
||||
# they must be redistributed per-chunk with rebased offsets — otherwise
|
||||
# styles for chunks 1..N are silently lost.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _resolve_chunk_styles(text: str, max_len: int) -> tuple[list[str], list[list[str]]]:
|
||||
"""Helper: full markdown → signal pipeline, including chunking."""
|
||||
plain, styles = _markdown_to_signal(text)
|
||||
chunks = split_message(plain, max_len) if plain else [""]
|
||||
return chunks, _partition_styles(plain, chunks, styles)
|
||||
|
||||
|
||||
def test_partition_styles_single_chunk_passthrough():
|
||||
plain, styles = _markdown_to_signal("**bold** plain *it*")
|
||||
parts = _partition_styles(plain, [plain], styles)
|
||||
assert parts == [styles]
|
||||
|
||||
|
||||
def test_partition_styles_no_styles():
|
||||
plain = "hello world"
|
||||
assert _partition_styles(plain, [plain], []) == [[]]
|
||||
assert _partition_styles(plain, ["hello", "world"], []) == [[], []]
|
||||
|
||||
|
||||
def test_partition_styles_drops_styles_outside_chunks():
|
||||
"""Whitespace trimmed by split_message must not carry a style range."""
|
||||
plain = "a b"
|
||||
# Fake a style spanning the trimmed whitespace only.
|
||||
chunks = ["a", "b"]
|
||||
parts = _partition_styles(plain, chunks, ["1:3:BOLD"])
|
||||
assert parts == [[], []]
|
||||
|
||||
|
||||
def test_partition_styles_long_message_preserves_chunk_one_styles():
|
||||
"""A bold span deep in the message must follow the message into chunk 1."""
|
||||
# Two ~30-char paragraphs separated by a blank line, then **tail**.
|
||||
line_a = "alpha " * 5 # 30 chars, ends with space
|
||||
line_b = "beta " * 5
|
||||
md = f"{line_a.strip()}\n\n{line_b.strip()}\n\n**tail**"
|
||||
plain, styles = _markdown_to_signal(md)
|
||||
# Force a split between the paragraphs.
|
||||
max_len = len(line_a.strip()) + 2 # fits paragraph A + the "\n\n"
|
||||
chunks = split_message(plain, max_len)
|
||||
assert len(chunks) >= 2, "test setup must produce a split"
|
||||
parts = _partition_styles(plain, chunks, styles)
|
||||
# The bold "tail" should land in the last chunk, with chunk-relative offset.
|
||||
final_chunk = chunks[-1]
|
||||
final_styles = parts[-1]
|
||||
assert any("BOLD" in s for s in final_styles)
|
||||
for entry in final_styles:
|
||||
s, ln, _ = entry.split(":", 2)
|
||||
start, length = int(s), int(ln)
|
||||
slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode(
|
||||
"utf-16-le"
|
||||
)
|
||||
assert slice_ == "tail"
|
||||
|
||||
|
||||
def test_partition_styles_chunk_zero_styles_unchanged():
|
||||
"""Styles entirely in chunk 0 keep their original offsets."""
|
||||
md = "**head** middle and **tail**"
|
||||
plain, styles = _markdown_to_signal(md)
|
||||
# Split so chunk 0 contains "head" and part of the rest, chunk 1 contains "tail".
|
||||
chunks = split_message(plain, 12)
|
||||
assert len(chunks) >= 2
|
||||
parts = _partition_styles(plain, chunks, styles)
|
||||
# "head" lives in chunk 0; assert its offset is unchanged (chunk 0 starts at 0).
|
||||
head_entries = [s for s in parts[0] if "BOLD" in s]
|
||||
assert any(s.startswith("0:4:") for s in head_entries)
|
||||
|
||||
|
||||
def test_partition_styles_with_non_bmp_chunk_offset():
|
||||
"""Chunk-start offsets must be expressed in UTF-16 code units."""
|
||||
# Emoji in chunk 0, bold in chunk 1.
|
||||
md = "🎉 alpha beta gamma\n\n**tail**"
|
||||
plain, styles = _markdown_to_signal(md)
|
||||
chunks = split_message(plain, 18)
|
||||
assert len(chunks) >= 2
|
||||
parts = _partition_styles(plain, chunks, styles)
|
||||
final_styles = parts[-1]
|
||||
assert any("BOLD" in s for s in final_styles)
|
||||
final_chunk = chunks[-1]
|
||||
for entry in final_styles:
|
||||
s, ln, _ = entry.split(":", 2)
|
||||
start, length = int(s), int(ln)
|
||||
slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode(
|
||||
"utf-16-le"
|
||||
)
|
||||
assert slice_ == "tail"
|
||||
|
||||
|
||||
def test_partition_styles_range_spanning_chunks_is_split():
|
||||
"""A style range that straddles a chunk boundary gets sliced into both chunks."""
|
||||
# Construct manually: plain = "abc def", style covers "abc def" (whole thing).
|
||||
plain = "abc def"
|
||||
chunks = split_message(plain, 4) # "abc" / "def"
|
||||
assert chunks == ["abc", "def"]
|
||||
parts = _partition_styles(plain, chunks, ["0:7:BOLD"])
|
||||
# Chunk 0 holds 0:3:BOLD, chunk 1 holds 0:3:BOLD (length=3 each, "def" only
|
||||
# since the space was trimmed by lstrip).
|
||||
assert parts[0] == ["0:3:BOLD"]
|
||||
assert parts[1] == ["0:3:BOLD"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Adjacency, nesting, and malformed input
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_bold_italic_combo_outer_bold_inner_italic():
|
||||
"""`**_combo_**` carries both BOLD and ITALIC over the same span."""
|
||||
plain, styles = _markdown_to_signal("**_combo_**")
|
||||
assert plain == "combo"
|
||||
sd = styles_for(plain, styles)
|
||||
assert set(sd.get("combo", [])) == {"BOLD", "ITALIC"}
|
||||
|
||||
|
||||
def test_bold_and_italic_adjacent_no_separator():
|
||||
"""`**bold***italic*` produces BOLD on `bold` and ITALIC on `italic`."""
|
||||
plain, styles = _markdown_to_signal("**bold***italic*")
|
||||
assert plain == "bolditalic"
|
||||
sd = styles_for(plain, styles)
|
||||
assert sd.get("bold") == ["BOLD"]
|
||||
assert sd.get("italic") == ["ITALIC"]
|
||||
|
||||
|
||||
def test_unclosed_bold_falls_through_as_plain():
|
||||
"""An unmatched `**` opener round-trips as literal text with no style."""
|
||||
plain, styles = _markdown_to_signal("**bold")
|
||||
assert plain == "**bold"
|
||||
assert styles == []
|
||||
|
||||
|
||||
def test_unclosed_inline_code_falls_through_as_plain():
|
||||
"""An unmatched backtick round-trips as literal text with no style."""
|
||||
plain, styles = _markdown_to_signal("use `grep")
|
||||
assert plain == "use `grep"
|
||||
assert styles == []
|
||||
|
||||
|
||||
def test_inline_code_inside_blockquote():
|
||||
"""Blockquote prefix is stripped; inline code becomes MONOSPACE."""
|
||||
plain, styles = _markdown_to_signal("> use `grep`")
|
||||
assert plain == "use grep"
|
||||
sd = styles_for(plain, styles)
|
||||
assert sd.get("grep") == ["MONOSPACE"]
|
||||
|
||||
|
||||
def test_header_with_inner_bold_produces_contiguous_bold_ranges():
|
||||
"""`# **wrap** me` — header forces BOLD over the whole line; the inner `**`
|
||||
splits the run, yielding two contiguous BOLD ranges that together cover
|
||||
"wrap me". This is intentional — Signal renders adjacent same-style ranges
|
||||
as a single visual span.
|
||||
"""
|
||||
plain, styles = _markdown_to_signal("# **wrap** me")
|
||||
assert plain == "wrap me"
|
||||
# Both ranges are BOLD; collectively they cover the whole "wrap me".
|
||||
bold_ranges = [s for s in styles if s.endswith(":BOLD")]
|
||||
assert len(bold_ranges) == 2
|
||||
covered = set()
|
||||
for entry in bold_ranges:
|
||||
start, length, _ = entry.split(":", 2)
|
||||
for i in range(int(start), int(start) + int(length)):
|
||||
covered.add(i)
|
||||
assert covered == set(range(len(plain)))
|
||||
@ -1055,6 +1055,7 @@ async def test_settings_api_returns_safe_subset_and_updates_whitelist(
|
||||
}
|
||||
assert image_providers["openrouter"]["label"] == "OpenRouter"
|
||||
assert image_providers["openrouter"]["configured"] is False
|
||||
assert image_providers["openai_codex"]["configured"] is True
|
||||
assert image_providers["gemini"]["label"] == "Gemini"
|
||||
assert body["runtime"]["config_path"] == str(config_path)
|
||||
workspace_path = body["runtime"]["workspace_path"].replace("\\", "/")
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
@ -374,6 +375,7 @@ async def test_send_uses_typing_start_and_cancel_when_ticket_available() -> None
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._context_tokens["wx-user"] = "ctx-typing"
|
||||
channel._context_token_at["wx-user"] = time.time()
|
||||
channel._send_text = AsyncMock()
|
||||
channel._api_post = AsyncMock(
|
||||
side_effect=[
|
||||
@ -402,6 +404,7 @@ async def test_send_still_sends_text_when_typing_ticket_missing() -> None:
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._context_tokens["wx-user"] = "ctx-no-ticket"
|
||||
channel._context_token_at["wx-user"] = time.time()
|
||||
channel._send_text = AsyncMock()
|
||||
channel._api_post = AsyncMock(return_value={"ret": 1, "errmsg": "no config"})
|
||||
|
||||
@ -1254,3 +1257,526 @@ async def test_send_text_succeeds_on_zero_errcode() -> None:
|
||||
await channel._send_text("wx-user", "hello", "ctx-ok")
|
||||
|
||||
channel._api_post.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_text_raises_on_nonzero_ret_even_when_errcode_zero() -> None:
|
||||
"""_send_text must raise when the API returns ret != 0, even if errcode is 0.
|
||||
|
||||
The iLink API signals failure through either field. Checking only errcode
|
||||
caused silent message drops (responses generated but never delivered).
|
||||
"""
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._api_post = AsyncMock(
|
||||
return_value={"ret": -100, "errcode": 0, "errmsg": "internal error"}
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="WeChat send text error.*ret=-100.*errcode=0"):
|
||||
await channel._send_text("wx-user", "hello", "ctx-ok")
|
||||
|
||||
channel._api_post.assert_awaited_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for _poll_once not silently dropping messages on processing errors
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_once_logs_exception_on_process_message_failure(monkeypatch) -> None:
|
||||
"""When _process_message raises, _poll_once must log the error and continue
|
||||
processing remaining messages instead of silently swallowing the exception."""
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = SimpleNamespace(timeout=None)
|
||||
channel._token = "token"
|
||||
channel._get_updates_buf = "old-buf"
|
||||
|
||||
calls = []
|
||||
logged_messages: list[str] = []
|
||||
|
||||
async def _failing_process(msg: dict) -> None:
|
||||
calls.append(msg.get("message_id"))
|
||||
if msg.get("message_id") == "msg-1":
|
||||
raise RuntimeError("processing failed")
|
||||
|
||||
channel._process_message = _failing_process # type: ignore[method-assign]
|
||||
|
||||
monkeypatch.setattr(
|
||||
channel.logger,
|
||||
"exception",
|
||||
lambda message, *args, **kwargs: logged_messages.append(str(message)),
|
||||
)
|
||||
|
||||
channel._api_post = AsyncMock( # type: ignore[method-assign]
|
||||
return_value={
|
||||
"ret": 0,
|
||||
"errcode": 0,
|
||||
"get_updates_buf": "new-buf",
|
||||
"msgs": [
|
||||
{"message_id": "msg-1", "message_type": 1},
|
||||
{"message_id": "msg-2", "message_type": 1},
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
await channel._poll_once()
|
||||
|
||||
# Both messages should have been attempted
|
||||
assert calls == ["msg-1", "msg-2"]
|
||||
# Buffer should still advance (already updated before processing)
|
||||
assert channel._get_updates_buf == "new-buf"
|
||||
# Error should be logged
|
||||
assert any("Failed to process WeChat message" in m for m in logged_messages)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_loop_logs_exception_and_continues_on_poll_failure(monkeypatch) -> None:
|
||||
"""When _poll_once raises a non-timeout exception, the start() loop must log
|
||||
the error and continue polling instead of exiting silently."""
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel.config.token = "token" # skip QR login in start()
|
||||
channel._running = True
|
||||
|
||||
call_count = 0
|
||||
logged_messages: list[str] = []
|
||||
|
||||
async def _failing_poll() -> None:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise RuntimeError("poll exploded")
|
||||
channel._running = False # Stop after second call
|
||||
|
||||
channel._poll_once = _failing_poll # type: ignore[method-assign]
|
||||
|
||||
monkeypatch.setattr(
|
||||
channel.logger,
|
||||
"exception",
|
||||
lambda message, *args, **kwargs: logged_messages.append(str(message)),
|
||||
)
|
||||
|
||||
# Use a tiny retry delay so the test finishes quickly
|
||||
original_retry = weixin_mod.RETRY_DELAY_S
|
||||
weixin_mod.RETRY_DELAY_S = 0.01
|
||||
try:
|
||||
await channel.start()
|
||||
finally:
|
||||
weixin_mod.RETRY_DELAY_S = original_retry
|
||||
|
||||
assert call_count == 2
|
||||
assert any("WeChat poll loop error" in m for m in logged_messages)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool-hint buffering
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_buffer_single_tool_hint_not_sent_immediately() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel.send_tool_hints = True
|
||||
channel._context_tokens["wx-user"] = "ctx-1"
|
||||
channel._context_token_at["wx-user"] = time.time()
|
||||
channel._send_text = AsyncMock()
|
||||
|
||||
await channel.send(
|
||||
type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"chat_id": "wx-user",
|
||||
"content": "Using tool",
|
||||
"media": [],
|
||||
"metadata": {"_progress": True, "_tool_hint": True},
|
||||
},
|
||||
)()
|
||||
)
|
||||
|
||||
channel._send_text.assert_not_awaited()
|
||||
assert channel._pending_tool_hints["wx-user"] == ["Using tool"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_buffer_multiple_tool_hints_flushed_on_final_answer() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel.send_tool_hints = True
|
||||
channel._context_tokens["wx-user"] = "ctx-1"
|
||||
channel._context_token_at["wx-user"] = time.time()
|
||||
channel._send_text = AsyncMock()
|
||||
|
||||
for hint in ["tool1", "tool2"]:
|
||||
await channel.send(
|
||||
type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"chat_id": "wx-user",
|
||||
"content": hint,
|
||||
"media": [],
|
||||
"metadata": {"_progress": True, "_tool_hint": True},
|
||||
},
|
||||
)()
|
||||
)
|
||||
|
||||
await channel.send(
|
||||
type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"chat_id": "wx-user",
|
||||
"content": "Done",
|
||||
"media": [],
|
||||
"metadata": {},
|
||||
},
|
||||
)()
|
||||
)
|
||||
|
||||
assert channel._send_text.await_count == 2
|
||||
channel._send_text.assert_any_await("wx-user", "tool1\n\ntool2", "ctx-1")
|
||||
channel._send_text.assert_any_await("wx-user", "Done", "ctx-1")
|
||||
assert "wx-user" not in channel._pending_tool_hints
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_thought_progress_flushes_tool_hints() -> None:
|
||||
"""Thoughts are visible progress messages and must act as separators,
|
||||
flushing buffered tool hints before they are sent."""
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel.send_tool_hints = True
|
||||
channel._context_tokens["wx-user"] = "ctx-1"
|
||||
channel._context_token_at["wx-user"] = time.time()
|
||||
channel._send_text = AsyncMock()
|
||||
|
||||
# Buffer a tool hint
|
||||
await channel.send(
|
||||
type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"chat_id": "wx-user",
|
||||
"content": "search 'foo'",
|
||||
"media": [],
|
||||
"metadata": {"_progress": True, "_tool_hint": True},
|
||||
},
|
||||
)()
|
||||
)
|
||||
|
||||
# Send a thought — progress but not a tool_hint.
|
||||
# It must act as a separator and flush the buffered hint.
|
||||
await channel.send(
|
||||
type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"chat_id": "wx-user",
|
||||
"content": "Let me think...",
|
||||
"media": [],
|
||||
"metadata": {"_progress": True},
|
||||
},
|
||||
)()
|
||||
)
|
||||
|
||||
# The buffered hint was flushed before the thought was sent.
|
||||
channel._send_text.assert_any_await("wx-user", "search 'foo'", "ctx-1")
|
||||
channel._send_text.assert_any_await("wx-user", "Let me think...", "ctx-1")
|
||||
assert "wx-user" not in channel._pending_tool_hints
|
||||
|
||||
# Final answer arrives with nothing left to flush.
|
||||
await channel.send(
|
||||
type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"chat_id": "wx-user",
|
||||
"content": "Done",
|
||||
"media": [],
|
||||
"metadata": {},
|
||||
},
|
||||
)()
|
||||
)
|
||||
|
||||
assert channel._send_text.await_count == 3
|
||||
channel._send_text.assert_any_await("wx-user", "Done", "ctx-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reasoning_delta_does_not_flush_tool_hints() -> None:
|
||||
"""Reasoning deltas are invisible in WeChat and must NOT flush buffered
|
||||
tool hints — otherwise hints separated only by hidden reasoning would
|
||||
fail to coalesce."""
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel.send_tool_hints = True
|
||||
channel._context_tokens["wx-user"] = "ctx-1"
|
||||
channel._context_token_at["wx-user"] = time.time()
|
||||
channel._send_text = AsyncMock()
|
||||
|
||||
# Buffer a tool hint
|
||||
await channel.send(
|
||||
type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"chat_id": "wx-user",
|
||||
"content": "search 'foo'",
|
||||
"media": [],
|
||||
"metadata": {"_progress": True, "_tool_hint": True},
|
||||
},
|
||||
)()
|
||||
)
|
||||
|
||||
# Send a reasoning delta — invisible in WeChat, must NOT flush
|
||||
await channel.send(
|
||||
type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"chat_id": "wx-user",
|
||||
"content": "Thinking step 1...",
|
||||
"media": [],
|
||||
"metadata": {"_progress": True, "_reasoning_delta": True},
|
||||
},
|
||||
)()
|
||||
)
|
||||
|
||||
# Reasoning is invisible; hint stays buffered, _send_text not called
|
||||
channel._send_text.assert_not_awaited()
|
||||
assert channel._pending_tool_hints["wx-user"] == ["search 'foo'"]
|
||||
|
||||
# Final answer flushes the buffered hint
|
||||
await channel.send(
|
||||
type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"chat_id": "wx-user",
|
||||
"content": "Done",
|
||||
"media": [],
|
||||
"metadata": {},
|
||||
},
|
||||
)()
|
||||
)
|
||||
|
||||
channel._send_text.assert_any_await("wx-user", "search 'foo'", "ctx-1")
|
||||
channel._send_text.assert_any_await("wx-user", "Done", "ctx-1")
|
||||
assert "wx-user" not in channel._pending_tool_hints
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_progress_message_does_not_flush_tool_hints() -> None:
|
||||
"""Empty progress messages (e.g. after_iteration tool_events) have no
|
||||
visible content and must NOT act as separators."""
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel.send_tool_hints = True
|
||||
channel._context_tokens["wx-user"] = "ctx-1"
|
||||
channel._context_token_at["wx-user"] = time.time()
|
||||
channel._send_text = AsyncMock()
|
||||
|
||||
# Buffer a tool hint
|
||||
await channel.send(
|
||||
type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"chat_id": "wx-user",
|
||||
"content": "search 'foo'",
|
||||
"media": [],
|
||||
"metadata": {"_progress": True, "_tool_hint": True},
|
||||
},
|
||||
)()
|
||||
)
|
||||
|
||||
# Send an empty progress message (no content, no media)
|
||||
await channel.send(
|
||||
type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"chat_id": "wx-user",
|
||||
"content": "",
|
||||
"media": [],
|
||||
"metadata": {"_progress": True, "_tool_events": [{"phase": "end"}]},
|
||||
},
|
||||
)()
|
||||
)
|
||||
|
||||
# Nothing should have been sent yet
|
||||
channel._send_text.assert_not_awaited()
|
||||
assert channel._pending_tool_hints["wx-user"] == ["search 'foo'"]
|
||||
|
||||
# Final answer flushes the buffered hint
|
||||
await channel.send(
|
||||
type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"chat_id": "wx-user",
|
||||
"content": "Done",
|
||||
"media": [],
|
||||
"metadata": {},
|
||||
},
|
||||
)()
|
||||
)
|
||||
|
||||
channel._send_text.assert_any_await("wx-user", "search 'foo'", "ctx-1")
|
||||
channel._send_text.assert_any_await("wx-user", "Done", "ctx-1")
|
||||
assert "wx-user" not in channel._pending_tool_hints
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_buffer_flush_refreshes_context_token() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel.send_tool_hints = True
|
||||
channel._context_tokens["wx-user"] = "ctx-old"
|
||||
channel._context_token_at["wx-user"] = time.time()
|
||||
channel._refresh_context_token_if_stale = AsyncMock(return_value="ctx-refreshed")
|
||||
channel._send_text = AsyncMock()
|
||||
|
||||
await channel.send(
|
||||
type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"chat_id": "wx-user",
|
||||
"content": "hint",
|
||||
"media": [],
|
||||
"metadata": {"_progress": True, "_tool_hint": True},
|
||||
},
|
||||
)()
|
||||
)
|
||||
|
||||
await channel.send(
|
||||
type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"chat_id": "wx-user",
|
||||
"content": "Done",
|
||||
"media": [],
|
||||
"metadata": {},
|
||||
},
|
||||
)()
|
||||
)
|
||||
|
||||
assert channel._refresh_context_token_if_stale.await_count == 2
|
||||
channel._refresh_context_token_if_stale.assert_any_await("wx-user", "ctx-old")
|
||||
channel._send_text.assert_any_await("wx-user", "hint", "ctx-refreshed")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_buffer_flush_failure_does_not_block_final_answer() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel.send_tool_hints = True
|
||||
channel._context_tokens["wx-user"] = "ctx-1"
|
||||
channel._context_token_at["wx-user"] = time.time()
|
||||
channel._send_text = AsyncMock(side_effect=[RuntimeError("boom"), None])
|
||||
|
||||
await channel.send(
|
||||
type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"chat_id": "wx-user",
|
||||
"content": "hint",
|
||||
"media": [],
|
||||
"metadata": {"_progress": True, "_tool_hint": True},
|
||||
},
|
||||
)()
|
||||
)
|
||||
|
||||
await channel.send(
|
||||
type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"chat_id": "wx-user",
|
||||
"content": "Done",
|
||||
"media": [],
|
||||
"metadata": {},
|
||||
},
|
||||
)()
|
||||
)
|
||||
|
||||
assert channel._send_text.await_count == 2
|
||||
channel._send_text.assert_any_await("wx-user", "hint", "ctx-1")
|
||||
channel._send_text.assert_any_await("wx-user", "Done", "ctx-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_buffer_flushed_on_stream_end() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel.send_tool_hints = True
|
||||
channel._context_tokens["wx-user"] = "ctx-1"
|
||||
channel._context_token_at["wx-user"] = time.time()
|
||||
channel._send_text = AsyncMock()
|
||||
|
||||
await channel.send(
|
||||
type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"chat_id": "wx-user",
|
||||
"content": "hint",
|
||||
"media": [],
|
||||
"metadata": {"_progress": True, "_tool_hint": True},
|
||||
},
|
||||
)()
|
||||
)
|
||||
|
||||
await channel.send_delta("wx-user", "", {"_stream_end": True})
|
||||
|
||||
channel._send_text.assert_awaited_once_with("wx-user", "hint", "ctx-1")
|
||||
assert "wx-user" not in channel._pending_tool_hints
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_clears_buffer() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._pending_tool_hints["wx-user"] = ["hint1", "hint2"]
|
||||
await channel.stop()
|
||||
assert "wx-user" not in channel._pending_tool_hints
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_tool_hints_false_drops_tool_hints() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel.send_tool_hints = False
|
||||
channel._send_text = AsyncMock()
|
||||
|
||||
await channel.send(
|
||||
type(
|
||||
"Msg",
|
||||
(),
|
||||
{
|
||||
"chat_id": "wx-user",
|
||||
"content": "hint",
|
||||
"media": [],
|
||||
"metadata": {"_progress": True, "_tool_hint": True},
|
||||
},
|
||||
)()
|
||||
)
|
||||
|
||||
channel._send_text.assert_not_awaited()
|
||||
assert "wx-user" not in channel._pending_tool_hints
|
||||
|
||||
@ -192,3 +192,20 @@ def test_match_provider_uses_preset_provider_when_forced() -> None:
|
||||
})
|
||||
name = config.get_provider_name()
|
||||
assert name == "anthropic"
|
||||
|
||||
|
||||
def test_match_provider_routes_forced_novita_model_api_models() -> None:
|
||||
config = Config.model_validate({
|
||||
"providers": {
|
||||
"novita": {"apiKey": "sk-test"},
|
||||
},
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"model": "deepseek-v4-pro",
|
||||
"provider": "novita",
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
assert config.get_provider_name() == "novita"
|
||||
assert config.get_api_base() == "https://api.novita.ai/openai"
|
||||
|
||||
@ -56,6 +56,35 @@ def test_custom_provider_parse_chunks_accepts_plain_text_chunks() -> None:
|
||||
assert result.content == "hello world"
|
||||
|
||||
|
||||
def test_custom_provider_parse_chunks_deduplicates_parallel_tool_call_ids() -> None:
|
||||
chunks = [{
|
||||
"choices": [{
|
||||
"finish_reason": "tool_calls",
|
||||
"delta": {
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": 0,
|
||||
"id": "call_dup",
|
||||
"function": {"name": "read_file", "arguments": '{"path":"a.txt"}'},
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"id": "call_dup",
|
||||
"function": {"name": "read_file", "arguments": '{"path":"b.txt"}'},
|
||||
},
|
||||
],
|
||||
},
|
||||
}],
|
||||
}]
|
||||
|
||||
result = OpenAICompatProvider._parse_chunks(chunks)
|
||||
ids = [tool_call.id for tool_call in result.tool_calls or []]
|
||||
|
||||
assert ids[0] == "call_dup"
|
||||
assert len(ids) == 2
|
||||
assert len(set(ids)) == 2
|
||||
|
||||
|
||||
def test_local_provider_502_error_includes_reachability_hint() -> None:
|
||||
spec = find_by_name("ollama")
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
|
||||
@ -9,10 +9,13 @@ import pytest
|
||||
|
||||
from nanobot.providers.image_generation import (
|
||||
AIHubMixImageGenerationClient,
|
||||
CodexImageGenerationClient,
|
||||
GeminiImageGenerationClient,
|
||||
GeneratedImageResponse,
|
||||
ImageGenerationError,
|
||||
MiniMaxImageGenerationClient,
|
||||
OllamaImageGenerationClient,
|
||||
OpenAIImageGenerationClient,
|
||||
OpenRouterImageGenerationClient,
|
||||
StepFunImageGenerationClient,
|
||||
)
|
||||
@ -36,12 +39,14 @@ class FakeResponse:
|
||||
payload: dict[str, Any],
|
||||
status_code: int = 200,
|
||||
content: bytes = b"",
|
||||
sse_lines: list[str] | None = None,
|
||||
) -> None:
|
||||
self._payload = payload
|
||||
self.status_code = status_code
|
||||
self.text = str(payload)
|
||||
self.content = content
|
||||
self.request = httpx.Request("POST", "https://openrouter.ai/api/v1/chat/completions")
|
||||
self._sse_lines = sse_lines
|
||||
|
||||
def json(self) -> dict[str, Any]:
|
||||
return self._payload
|
||||
@ -51,6 +56,15 @@ class FakeResponse:
|
||||
response = httpx.Response(self.status_code, request=self.request, text=self.text)
|
||||
raise httpx.HTTPStatusError("failed", request=self.request, response=response)
|
||||
|
||||
async def aiter_lines(self):
|
||||
if self._sse_lines is not None:
|
||||
for line in self._sse_lines:
|
||||
yield line
|
||||
return
|
||||
# Fallback: treat response text as SSE lines
|
||||
for line in self.text.split("\n"):
|
||||
yield line
|
||||
|
||||
|
||||
class FakeClient:
|
||||
def __init__(self, response: FakeResponse) -> None:
|
||||
@ -133,6 +147,54 @@ async def test_openrouter_image_generation_requires_api_key() -> None:
|
||||
await client.generate(prompt="draw", model="model")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ollama_image_generation_payload_and_response() -> None:
|
||||
raw_b64 = PNG_DATA_URL.removeprefix("data:image/png;base64,")
|
||||
fake = FakeClient(FakeResponse({"image": raw_b64}))
|
||||
client = OllamaImageGenerationClient(
|
||||
api_key="ollama-test",
|
||||
api_base="http://localhost:11434/v1/",
|
||||
extra_headers={"X-Test": "1"},
|
||||
extra_body={"seed": 123},
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
response = await client.generate(
|
||||
prompt="a sunset",
|
||||
model="x/z-image-turbo",
|
||||
aspect_ratio="16:9",
|
||||
image_size="1K",
|
||||
)
|
||||
|
||||
assert response.images == [PNG_DATA_URL]
|
||||
assert response.content == ""
|
||||
|
||||
call = fake.calls[0]
|
||||
assert call["url"] == "http://localhost:11434/api/generate"
|
||||
assert call["headers"]["Authorization"] == "Bearer ollama-test"
|
||||
assert call["headers"]["X-Test"] == "1"
|
||||
body = call["json"]
|
||||
assert body["model"] == "x/z-image-turbo"
|
||||
assert body["prompt"] == "a sunset"
|
||||
assert body["width"] == 1024
|
||||
assert body["height"] == 576
|
||||
assert body["steps"] == 0
|
||||
assert body["stream"] is False
|
||||
assert body["seed"] == 123
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ollama_image_generation_rejects_reference_images() -> None:
|
||||
client = OllamaImageGenerationClient(api_key=None)
|
||||
|
||||
with pytest.raises(ImageGenerationError, match="reference images"):
|
||||
await client.generate(
|
||||
prompt="edit this",
|
||||
model="x/z-image-turbo",
|
||||
reference_images=["ref.png"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aihubmix_image_generation_payload_and_response() -> None:
|
||||
raw_b64 = PNG_DATA_URL.removeprefix("data:image/png;base64,")
|
||||
@ -531,3 +593,437 @@ async def test_stepfun_no_images_raises() -> None:
|
||||
|
||||
with pytest.raises(ImageGenerationError, match="returned no images"):
|
||||
await client.generate(prompt="draw", model="step-image-edit-2")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_payload_and_response() -> None:
|
||||
fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
|
||||
client = OpenAIImageGenerationClient(
|
||||
api_key="sk-openai-test",
|
||||
api_base="https://api.openai.com/v1",
|
||||
extra_headers={"X-Test": "1"},
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
response = await client.generate(
|
||||
prompt="a cat on the moon",
|
||||
model="dall-e-3",
|
||||
aspect_ratio="16:9",
|
||||
)
|
||||
|
||||
assert response.images == [PNG_DATA_URL]
|
||||
call = fake.calls[0]
|
||||
assert call["url"] == "https://api.openai.com/v1/images/generations"
|
||||
assert call["headers"]["Authorization"] == "Bearer sk-openai-test"
|
||||
assert call["headers"]["X-Test"] == "1"
|
||||
body = call["json"]
|
||||
assert body["model"] == "dall-e-3"
|
||||
assert body["prompt"] == "a cat on the moon"
|
||||
assert body["response_format"] == "b64_json"
|
||||
assert body["n"] == 1
|
||||
assert body["size"] == "1792x1024"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_b64_json_response_uses_detected_mime() -> None:
|
||||
raw_b64 = base64.b64encode(JPEG_BYTES).decode("ascii")
|
||||
fake = FakeClient(FakeResponse({"data": [{"b64_json": raw_b64}]}))
|
||||
client = OpenAIImageGenerationClient(
|
||||
api_key="sk-openai-test",
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
response = await client.generate(prompt="draw", model="dall-e-3")
|
||||
|
||||
assert response.images == [f"data:image/jpeg;base64,{raw_b64}"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_url_download_fallback() -> None:
|
||||
fake = FakeClient(FakeResponse({"data": [{"url": "https://cdn.example/image.png"}]}))
|
||||
fake.get_response = FakeResponse({}, content=PNG_BYTES)
|
||||
client = OpenAIImageGenerationClient(
|
||||
api_key="sk-openai-test",
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
response = await client.generate(prompt="draw", model="dall-e-3")
|
||||
|
||||
assert response.images[0].startswith("data:image/png;base64,")
|
||||
assert fake.get_calls[0]["url"] == "https://cdn.example/image.png"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_multiple_images() -> None:
|
||||
fake = FakeClient(FakeResponse({
|
||||
"data": [
|
||||
{"b64_json": RAW_B64},
|
||||
{"b64_json": RAW_B64},
|
||||
]
|
||||
}))
|
||||
client = OpenAIImageGenerationClient(
|
||||
api_key="sk-openai-test",
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
response = await client.generate(prompt="draw", model="dall-e-3")
|
||||
|
||||
assert len(response.images) == 2
|
||||
assert response.images == [PNG_DATA_URL, PNG_DATA_URL]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_aspect_ratio_to_size() -> None:
|
||||
fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
|
||||
client = OpenAIImageGenerationClient(
|
||||
api_key="sk-openai-test",
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
await client.generate(prompt="draw", model="dall-e-3", aspect_ratio="1:1")
|
||||
assert fake.calls[0]["json"]["size"] == "1024x1024"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_dalle3_uses_supported_orientation_sizes() -> None:
|
||||
fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
|
||||
client = OpenAIImageGenerationClient(
|
||||
api_key="sk-openai-test",
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
await client.generate(prompt="draw", model="dall-e-3", aspect_ratio="3:4")
|
||||
await client.generate(prompt="draw", model="dall-e-3", aspect_ratio="4:3")
|
||||
|
||||
assert fake.calls[0]["json"]["size"] == "1024x1792"
|
||||
assert fake.calls[1]["json"]["size"] == "1792x1024"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_dalle2_uses_square_size_for_non_square_ratios() -> None:
|
||||
fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
|
||||
client = OpenAIImageGenerationClient(
|
||||
api_key="sk-openai-test",
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
await client.generate(prompt="draw", model="dall-e-2", aspect_ratio="16:9")
|
||||
|
||||
assert fake.calls[0]["json"]["size"] == "1024x1024"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_gpt_image_uses_supported_landscape_size() -> None:
|
||||
fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
|
||||
client = OpenAIImageGenerationClient(
|
||||
api_key="sk-openai-test",
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
await client.generate(prompt="draw", model="gpt-image-1", aspect_ratio="16:9")
|
||||
|
||||
assert fake.calls[0]["json"]["size"] == "1536x1024"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_gpt_image_uses_supported_orientation_sizes() -> None:
|
||||
fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
|
||||
client = OpenAIImageGenerationClient(
|
||||
api_key="sk-openai-test",
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
await client.generate(prompt="draw", model="gpt-image-1", aspect_ratio="3:4")
|
||||
await client.generate(prompt="draw", model="gpt-image-1", aspect_ratio="4:3")
|
||||
|
||||
assert fake.calls[0]["json"]["size"] == "1024x1536"
|
||||
assert fake.calls[1]["json"]["size"] == "1536x1024"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_default_size_when_no_aspect_ratio() -> None:
|
||||
fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
|
||||
client = OpenAIImageGenerationClient(
|
||||
api_key="sk-openai-test",
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
await client.generate(prompt="draw", model="dall-e-3")
|
||||
|
||||
body = fake.calls[0]["json"]
|
||||
assert body["size"] == "1024x1024"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_ignores_explicit_size_unsupported_by_model_family() -> None:
|
||||
fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
|
||||
client = OpenAIImageGenerationClient(
|
||||
api_key="sk-openai-test",
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
await client.generate(
|
||||
prompt="draw",
|
||||
model="dall-e-3",
|
||||
aspect_ratio="16:9",
|
||||
image_size="1536x1024",
|
||||
)
|
||||
|
||||
body = fake.calls[0]["json"]
|
||||
assert body["size"] == "1792x1024"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_uses_explicit_image_size() -> None:
|
||||
fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
|
||||
client = OpenAIImageGenerationClient(
|
||||
api_key="sk-openai-test",
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
await client.generate(
|
||||
prompt="draw",
|
||||
model="dall-e-3",
|
||||
aspect_ratio="16:9",
|
||||
image_size="1024x1024",
|
||||
)
|
||||
|
||||
body = fake.calls[0]["json"]
|
||||
assert body["size"] == "1024x1024"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_requires_api_key() -> None:
|
||||
client = OpenAIImageGenerationClient(api_key=None)
|
||||
|
||||
with pytest.raises(ImageGenerationError, match="API key"):
|
||||
await client.generate(prompt="draw", model="dall-e-3")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAI Codex (Responses API)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_codex_payload_and_response(monkeypatch) -> None:
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
|
||||
@dataclass
|
||||
class FakeToken:
|
||||
account_id: str = "acct-123"
|
||||
access: str = "oauth-token"
|
||||
|
||||
async def fake_to_thread(fn, *args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
|
||||
fake_oauth = SimpleNamespace(get_token=lambda: FakeToken())
|
||||
monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth)
|
||||
|
||||
sse_lines = [
|
||||
'data: {"type":"response.output_item.added","item":{"id":"ig_1","type":"image_generation_call","status":"in_progress"}}',
|
||||
"",
|
||||
f'data: {{"type":"response.output_item.done","item":{{"id":"ig_1","type":"image_generation_call","result":"{PNG_DATA_URL}","status":"completed"}}}}',
|
||||
"",
|
||||
'data: [DONE]',
|
||||
"",
|
||||
]
|
||||
fake = FakeClient(FakeResponse({}, sse_lines=sse_lines))
|
||||
client = CodexImageGenerationClient(
|
||||
api_key=None,
|
||||
api_base="https://chatgpt.com/backend-api",
|
||||
extra_headers={"X-Test": "1"},
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
response = await client.generate(
|
||||
prompt="draw a cat",
|
||||
model="gpt-5.4",
|
||||
)
|
||||
|
||||
assert response.images == [PNG_DATA_URL]
|
||||
assert response.content == ""
|
||||
call = fake.calls[0]
|
||||
assert call["url"] == "https://chatgpt.com/backend-api/codex/responses"
|
||||
assert call["headers"]["Authorization"] == "Bearer oauth-token"
|
||||
assert call["headers"]["chatgpt-account-id"] == "acct-123"
|
||||
assert call["headers"]["OpenAI-Beta"] == "responses=experimental"
|
||||
assert call["headers"]["X-Test"] == "1"
|
||||
body = call["json"]
|
||||
assert body["model"] == "gpt-5.4"
|
||||
assert body["instructions"] == "Generate an image based on the user's request."
|
||||
assert body["input"] == [{"role": "user", "content": "draw a cat"}]
|
||||
assert body["tools"] == [{"type": "image_generation"}]
|
||||
assert body["tool_choice"] == "auto"
|
||||
assert body["store"] is False
|
||||
assert body["stream"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_codex_strips_model_prefix(monkeypatch) -> None:
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
|
||||
@dataclass
|
||||
class FakeToken:
|
||||
account_id: str = "acct-123"
|
||||
access: str = "oauth-token"
|
||||
|
||||
async def fake_to_thread(fn, *args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
|
||||
fake_oauth = SimpleNamespace(get_token=lambda: FakeToken())
|
||||
monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth)
|
||||
|
||||
fake = FakeClient(FakeResponse({}, sse_lines=[
|
||||
f'data: {{"type":"response.output_item.done","item":{{"type":"image_generation_call","result":"{PNG_DATA_URL}"}}}}',
|
||||
"",
|
||||
'data: [DONE]',
|
||||
"",
|
||||
]))
|
||||
client = CodexImageGenerationClient(
|
||||
api_key=None, client=fake # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
await client.generate(prompt="draw", model="openai-codex/gpt-5.4")
|
||||
|
||||
assert fake.calls[0]["json"]["model"] == "gpt-5.4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_codex_requires_oauth(monkeypatch) -> None:
|
||||
async def fake_to_thread(fn, *args, **kwargs):
|
||||
raise RuntimeError("no token")
|
||||
|
||||
monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
|
||||
|
||||
client = CodexImageGenerationClient(api_key=None)
|
||||
|
||||
with pytest.raises(ImageGenerationError, match="OAuth token"):
|
||||
await client.generate(prompt="draw", model="gpt-5.4")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_codex_no_images_raises(monkeypatch) -> None:
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
|
||||
@dataclass
|
||||
class FakeToken:
|
||||
account_id: str = "acct-123"
|
||||
access: str = "oauth-token"
|
||||
|
||||
async def fake_to_thread(fn, *args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
|
||||
fake_oauth = SimpleNamespace(get_token=lambda: FakeToken())
|
||||
monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth)
|
||||
|
||||
fake = FakeClient(FakeResponse({}, sse_lines=[
|
||||
'data: {"type":"response.completed","response":{"status":"completed"}}',
|
||||
"",
|
||||
'data: [DONE]',
|
||||
"",
|
||||
]))
|
||||
client = CodexImageGenerationClient(
|
||||
api_key=None, client=fake # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
with pytest.raises(ImageGenerationError, match="returned no images"):
|
||||
await client.generate(prompt="draw", model="gpt-5.4")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_codex_extracts_text_content(monkeypatch) -> None:
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
|
||||
@dataclass
|
||||
class FakeToken:
|
||||
account_id: str = "acct-123"
|
||||
access: str = "oauth-token"
|
||||
|
||||
async def fake_to_thread(fn, *args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
|
||||
fake_oauth = SimpleNamespace(get_token=lambda: FakeToken())
|
||||
monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth)
|
||||
|
||||
fake = FakeClient(FakeResponse({}, sse_lines=[
|
||||
'data: {"type":"response.output_text.delta","delta":"Here "}',
|
||||
"",
|
||||
'data: {"type":"response.output_text.delta","delta":"is your cat image."}',
|
||||
"",
|
||||
f'data: {{"type":"response.output_item.done","item":{{"type":"image_generation_call","result":"{PNG_DATA_URL}"}}}}',
|
||||
"",
|
||||
'data: [DONE]',
|
||||
"",
|
||||
]))
|
||||
client = CodexImageGenerationClient(
|
||||
api_key=None, client=fake # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
response = await client.generate(prompt="draw a cat", model="gpt-5.4")
|
||||
|
||||
assert response.images == [PNG_DATA_URL]
|
||||
assert response.content == "Here is your cat image."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_codex_json_result_format(monkeypatch) -> None:
|
||||
"""image_generation_call result can be a dict with image_url key."""
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
|
||||
@dataclass
|
||||
class FakeToken:
|
||||
account_id: str = "acct-123"
|
||||
access: str = "oauth-token"
|
||||
|
||||
async def fake_to_thread(fn, *args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
|
||||
fake_oauth = SimpleNamespace(get_token=lambda: FakeToken())
|
||||
monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth)
|
||||
|
||||
fake = FakeClient(FakeResponse({}, sse_lines=[
|
||||
f'data: {{"type":"response.output_item.done","item":{{"type":"image_generation_call","result":{{"image_url":"{PNG_DATA_URL}"}}}}}}',
|
||||
"",
|
||||
'data: [DONE]',
|
||||
"",
|
||||
]))
|
||||
client = CodexImageGenerationClient(
|
||||
api_key=None, client=fake # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
response = await client.generate(prompt="draw", model="gpt-5.4")
|
||||
|
||||
assert response.images == [PNG_DATA_URL]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_no_images_raises() -> None:
|
||||
fake = FakeClient(FakeResponse({"data": []}))
|
||||
client = OpenAIImageGenerationClient(
|
||||
api_key="sk-openai-test",
|
||||
client=fake, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
with pytest.raises(ImageGenerationError, match="returned no images"):
|
||||
await client.generate(prompt="draw", model="dall-e-3")
|
||||
|
||||
@ -441,6 +441,15 @@ def test_openrouter_spec_is_gateway() -> None:
|
||||
assert spec.default_api_base == "https://openrouter.ai/api/v1"
|
||||
|
||||
|
||||
def test_novita_spec_uses_openai_compatible_gateway() -> None:
|
||||
spec = find_by_name("novita")
|
||||
assert spec is not None
|
||||
assert spec.is_gateway is True
|
||||
assert spec.backend == "openai_compat"
|
||||
assert spec.env_key == "NOVITA_API_KEY"
|
||||
assert spec.default_api_base == "https://api.novita.ai/openai"
|
||||
|
||||
|
||||
def test_gemma_routes_to_gemini_provider() -> None:
|
||||
"""gemma models (e.g. gemma-3-27b-it) must auto-route to Gemini when GEMINI_API_KEY is set.
|
||||
Users running gemma via the Gemini API endpoint expect automatic provider detection."""
|
||||
@ -1013,6 +1022,41 @@ def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -
|
||||
assert sanitized[2]["tool_call_id"] == "3ec83c30d"
|
||||
|
||||
|
||||
def test_openai_compat_deduplicates_duplicate_tool_call_ids_in_history() -> None:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
sanitized = provider._sanitize_messages([
|
||||
{"role": "user", "content": "check both files"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "ab1b45c2a",
|
||||
"type": "function",
|
||||
"function": {"name": "read_file", "arguments": '{"path":"a.txt"}'},
|
||||
},
|
||||
{
|
||||
"id": "ab1b45c2a",
|
||||
"type": "function",
|
||||
"function": {"name": "read_file", "arguments": '{"path":"b.txt"}'},
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "ab1b45c2a", "name": "read_file", "content": "a"},
|
||||
{"role": "tool", "tool_call_id": "ab1b45c2a", "name": "read_file", "content": "b"},
|
||||
{"role": "user", "content": "continue"},
|
||||
])
|
||||
|
||||
tool_call_ids = [tc["id"] for tc in sanitized[1]["tool_calls"]]
|
||||
tool_result_ids = [sanitized[2]["tool_call_id"], sanitized[3]["tool_call_id"]]
|
||||
|
||||
assert tool_call_ids[0] == "ab1b45c2a"
|
||||
assert len(tool_call_ids) == len(set(tool_call_ids)) == 2
|
||||
assert tool_result_ids == tool_call_ids
|
||||
|
||||
|
||||
def test_openai_compat_stringifies_dict_tool_arguments() -> None:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
@ -1382,12 +1426,15 @@ def test_kimi_k25_thinking_enabled() -> None:
|
||||
"""kimi-k2.5 with reasoning_effort set should opt in to thinking."""
|
||||
kw = _build_kwargs_for("moonshot", "kimi-k2.5", reasoning_effort="medium")
|
||||
assert kw.get("extra_body") == {"thinking": {"type": "enabled"}}
|
||||
# Moonshot rejects both 'reasoning_effort' and 'thinking' (#3939)
|
||||
assert "reasoning_effort" not in kw
|
||||
|
||||
|
||||
def test_kimi_k25_thinking_disabled_for_minimal() -> None:
|
||||
"""reasoning_effort='minimal' maps to thinking disabled for kimi-k2.5."""
|
||||
kw = _build_kwargs_for("moonshot", "kimi-k2.5", reasoning_effort="minimal")
|
||||
assert kw.get("extra_body") == {"thinking": {"type": "disabled"}}
|
||||
assert "reasoning_effort" not in kw
|
||||
|
||||
|
||||
def test_kimi_k25_no_extra_body_when_reasoning_effort_none() -> None:
|
||||
@ -1397,21 +1444,36 @@ def test_kimi_k25_no_extra_body_when_reasoning_effort_none() -> None:
|
||||
|
||||
|
||||
def test_kimi_k25_thinking_enabled_with_openrouter_prefix() -> None:
|
||||
"""OpenRouter-style model names like moonshotai/kimi-k2.5 must trigger thinking."""
|
||||
"""OpenRouter-style model names like moonshotai/kimi-k2.5 must trigger thinking.
|
||||
|
||||
OR drops upstream-provider `thinking` fields, so the same intent also has
|
||||
to go through OR's `reasoning.effort` shape (#3851 follow-up).
|
||||
"""
|
||||
kw = _build_kwargs_for("openrouter", "moonshotai/kimi-k2.5", reasoning_effort="medium")
|
||||
assert kw.get("extra_body") == {"thinking": {"type": "enabled"}}
|
||||
assert kw.get("extra_body") == {
|
||||
"thinking": {"type": "enabled"},
|
||||
"reasoning": {"effort": "medium"},
|
||||
}
|
||||
# Even via OR, reasoning_effort wire kwarg is dropped for kimi models
|
||||
assert "reasoning_effort" not in kw
|
||||
|
||||
|
||||
def test_kimi_k26_thinking_enabled() -> None:
|
||||
"""kimi-k2.6 with reasoning_effort set should opt in to thinking."""
|
||||
kw = _build_kwargs_for("moonshot", "kimi-k2.6", reasoning_effort="medium")
|
||||
assert kw.get("extra_body") == {"thinking": {"type": "enabled"}}
|
||||
assert "reasoning_effort" not in kw
|
||||
|
||||
|
||||
def test_kimi_k26_thinking_enabled_with_openrouter_prefix() -> None:
|
||||
"""OpenRouter-style names like moonshotai/kimi-k2.6 must trigger thinking."""
|
||||
"""OpenRouter-style names like moonshotai/kimi-k2.6 must trigger thinking
|
||||
via both upstream `thinking` and OR's `reasoning.effort`."""
|
||||
kw = _build_kwargs_for("openrouter", "moonshotai/kimi-k2.6", reasoning_effort="medium")
|
||||
assert kw.get("extra_body") == {"thinking": {"type": "enabled"}}
|
||||
assert kw.get("extra_body") == {
|
||||
"thinking": {"type": "enabled"},
|
||||
"reasoning": {"effort": "medium"},
|
||||
}
|
||||
assert "reasoning_effort" not in kw
|
||||
|
||||
|
||||
def test_moonshot_kimi_k26_temperature_override() -> None:
|
||||
@ -1430,6 +1492,7 @@ def test_kimi_k26_code_preview_thinking_enabled() -> None:
|
||||
"""k2.6-code-preview also supports thinking; should behave like k2.5."""
|
||||
kw = _build_kwargs_for("moonshot", "k2.6-code-preview", reasoning_effort="high")
|
||||
assert kw.get("extra_body") == {"thinking": {"type": "enabled"}}
|
||||
assert "reasoning_effort" not in kw
|
||||
|
||||
|
||||
def test_kimi_k2_series_no_thinking_injection() -> None:
|
||||
@ -1459,6 +1522,7 @@ def test_kimi_k25_thinking_disabled_for_none_string() -> None:
|
||||
"""reasoning_effort='none' maps to thinking disabled for kimi-k2.5."""
|
||||
kw = _build_kwargs_for("moonshot", "kimi-k2.5", reasoning_effort="none")
|
||||
assert kw.get("extra_body") == {"thinking": {"type": "disabled"}}
|
||||
assert "reasoning_effort" not in kw
|
||||
|
||||
|
||||
def test_dashscope_thinking_disabled_for_none_string() -> None:
|
||||
|
||||
97
tests/providers/test_novita_provider.py
Normal file
97
tests/providers/test_novita_provider.py
Normal file
@ -0,0 +1,97 @@
|
||||
"""Tests for the Novita AI provider registration."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from nanobot.config.schema import Config, ProvidersConfig
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
from nanobot.providers.registry import PROVIDERS, find_by_name
|
||||
|
||||
|
||||
def test_novita_config_field_exists() -> None:
|
||||
config = ProvidersConfig()
|
||||
|
||||
assert hasattr(config, "novita")
|
||||
|
||||
|
||||
def test_novita_provider_in_registry() -> None:
|
||||
specs = {spec.name: spec for spec in PROVIDERS}
|
||||
|
||||
assert "novita" in specs
|
||||
novita = specs["novita"]
|
||||
assert novita.backend == "openai_compat"
|
||||
assert novita.env_key == "NOVITA_API_KEY"
|
||||
assert novita.display_name == "Novita AI"
|
||||
assert novita.is_gateway is True
|
||||
assert novita.detect_by_base_keyword == "novita"
|
||||
assert novita.default_api_base == "https://api.novita.ai/openai"
|
||||
assert novita.strip_model_prefix is False
|
||||
|
||||
|
||||
def test_find_by_name_novita() -> None:
|
||||
spec = find_by_name("novita")
|
||||
|
||||
assert spec is not None
|
||||
assert spec.name == "novita"
|
||||
|
||||
|
||||
def test_novita_forced_provider_uses_default_api_base() -> None:
|
||||
config = Config.model_validate({
|
||||
"providers": {
|
||||
"novita": {
|
||||
"apiKey": "novita-key",
|
||||
},
|
||||
},
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"model": "deepseek-v4-pro",
|
||||
"provider": "novita",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
assert config.get_provider_name("deepseek-v4-pro") == "novita"
|
||||
assert config.get_api_key("deepseek-v4-pro") == "novita-key"
|
||||
assert config.get_api_base("deepseek-v4-pro") == "https://api.novita.ai/openai"
|
||||
|
||||
|
||||
def test_novita_gateway_routes_unprefixed_models_when_configured() -> None:
|
||||
config = Config.model_validate({
|
||||
"providers": {
|
||||
"novita": {
|
||||
"apiKey": "novita-key",
|
||||
},
|
||||
},
|
||||
"agents": {
|
||||
"defaults": {
|
||||
"model": "deepseek-v4-pro",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
assert config.get_provider_name("deepseek-v4-pro") == "novita"
|
||||
assert config.get_api_key("deepseek-v4-pro") == "novita-key"
|
||||
assert config.get_api_base("deepseek-v4-pro") == "https://api.novita.ai/openai"
|
||||
|
||||
|
||||
def test_novita_preserves_model_api_id() -> None:
|
||||
spec = find_by_name("novita")
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="novita-key",
|
||||
default_model="deepseek-v4-pro",
|
||||
spec=spec,
|
||||
)
|
||||
|
||||
kwargs = provider._build_kwargs(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
tools=None,
|
||||
model="deepseek-v4-pro",
|
||||
max_tokens=1024,
|
||||
temperature=0.7,
|
||||
reasoning_effort=None,
|
||||
tool_choice=None,
|
||||
)
|
||||
|
||||
assert kwargs["model"] == "deepseek-v4-pro"
|
||||
assert kwargs["max_tokens"] == 1024
|
||||
assert "max_completion_tokens" not in kwargs
|
||||
@ -155,6 +155,49 @@ class TestConvertMessages:
|
||||
assert items[0]["id"] == "fc_1"
|
||||
assert items[0]["name"] == "get_weather"
|
||||
|
||||
def test_duplicate_response_item_ids_are_made_unique(self):
|
||||
"""Codex rejects replayed Responses input items with duplicate ids."""
|
||||
_, items = convert_messages([
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": "call_a|rs_same",
|
||||
"function": {"name": "first", "arguments": "{}"},
|
||||
}],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_a|rs_same", "content": "ok"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": "call_b|rs_same",
|
||||
"function": {"name": "second", "arguments": "{}"},
|
||||
}],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_b|rs_same", "content": "ok"},
|
||||
])
|
||||
function_call_ids = [
|
||||
item["id"] for item in items if item.get("type") == "function_call"
|
||||
]
|
||||
assert function_call_ids == ["rs_same", "rs_same_2"]
|
||||
assert len(function_call_ids) == len(set(function_call_ids))
|
||||
|
||||
def test_fallback_response_item_ids_are_unique_with_multiple_tool_calls(self):
|
||||
_, items = convert_messages([{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{"id": "call_a", "function": {"name": "first", "arguments": "{}"}},
|
||||
{"id": "call_b", "function": {"name": "second", "arguments": "{}"}},
|
||||
],
|
||||
}])
|
||||
function_call_ids = [
|
||||
item["id"] for item in items if item.get("type") == "function_call"
|
||||
]
|
||||
assert function_call_ids == ["fc_0", "fc_0_2"]
|
||||
assert len(function_call_ids) == len(set(function_call_ids))
|
||||
|
||||
def test_assistant_with_tool_calls_no_id(self):
|
||||
"""Fallback IDs when tool_call.id is missing."""
|
||||
_, items = convert_messages([{
|
||||
|
||||
@ -32,7 +32,7 @@ def _mimo_spec():
|
||||
|
||||
|
||||
def _openrouter_spec():
|
||||
"""Return the registered OpenRouter ProviderSpec (no thinking_style)."""
|
||||
"""Return the registered OpenRouter ProviderSpec."""
|
||||
specs = {s.name: s for s in PROVIDERS}
|
||||
return specs["openrouter"]
|
||||
|
||||
@ -77,6 +77,13 @@ def test_xiaomi_mimo_uses_thinking_type_style():
|
||||
assert spec.default_api_base == "https://api.xiaomimimo.com/v1"
|
||||
|
||||
|
||||
def test_openrouter_declares_gateway_reasoning_style():
|
||||
"""OpenRouter uses its own reasoning.effort field for routed thinking models."""
|
||||
spec = _openrouter_spec()
|
||||
assert spec.thinking_style == ""
|
||||
assert spec.gateway_reasoning_style == "reasoning_effort"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_kwargs wire-format
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -142,9 +149,11 @@ def test_mimo_reasoning_effort_unset_preserves_provider_default():
|
||||
|
||||
|
||||
def test_mimo_via_openrouter_reasoning_effort_none_disables_thinking():
|
||||
"""OpenRouter routes MiMo as "xiaomi/mimo-v2.5-pro"; the openrouter spec
|
||||
has no thinking_style, so the disable signal must come from the
|
||||
model-name path (#3845)."""
|
||||
"""OpenRouter routes MiMo as "xiaomi/mimo-v2.5-pro" and does NOT forward
|
||||
extra_body.thinking to upstream, so a disable signal must also reach OR
|
||||
in its own `reasoning.effort` shape. Verifies both the upstream-MiMo
|
||||
payload (#3845) and the OR-native payload (#3851 follow-up) are sent.
|
||||
"""
|
||||
provider = _openrouter_provider("xiaomi/mimo-v2.5-pro")
|
||||
kwargs = provider._build_kwargs(
|
||||
messages=_simple_messages(),
|
||||
@ -152,11 +161,15 @@ def test_mimo_via_openrouter_reasoning_effort_none_disables_thinking():
|
||||
temperature=0.7, reasoning_effort="none", tool_choice=None,
|
||||
)
|
||||
assert "reasoning_effort" not in kwargs
|
||||
assert kwargs["extra_body"] == {"thinking": {"type": "disabled"}}
|
||||
assert kwargs["extra_body"] == {
|
||||
"thinking": {"type": "disabled"},
|
||||
"reasoning": {"effort": "none"},
|
||||
}
|
||||
|
||||
|
||||
def test_mimo_via_openrouter_reasoning_effort_medium_enables_thinking():
|
||||
"""Same as the direct path: any non-none/minimal effort enables thinking."""
|
||||
"""Non-none/minimal effort enables thinking and the OR `reasoning.effort`
|
||||
field mirrors the requested effort level."""
|
||||
provider = _openrouter_provider("xiaomi/mimo-v2.5-pro")
|
||||
kwargs = provider._build_kwargs(
|
||||
messages=_simple_messages(),
|
||||
@ -164,7 +177,10 @@ def test_mimo_via_openrouter_reasoning_effort_medium_enables_thinking():
|
||||
temperature=0.7, reasoning_effort="medium", tool_choice=None,
|
||||
)
|
||||
assert kwargs.get("reasoning_effort") == "medium"
|
||||
assert kwargs["extra_body"] == {"thinking": {"type": "enabled"}}
|
||||
assert kwargs["extra_body"] == {
|
||||
"thinking": {"type": "enabled"},
|
||||
"reasoning": {"effort": "medium"},
|
||||
}
|
||||
|
||||
|
||||
def test_mimo_via_openrouter_bare_slug_also_matches():
|
||||
@ -176,12 +192,16 @@ def test_mimo_via_openrouter_bare_slug_also_matches():
|
||||
tools=None, model=None, max_tokens=100,
|
||||
temperature=0.7, reasoning_effort="none", tool_choice=None,
|
||||
)
|
||||
assert kwargs["extra_body"] == {"thinking": {"type": "disabled"}}
|
||||
assert kwargs["extra_body"] == {
|
||||
"thinking": {"type": "disabled"},
|
||||
"reasoning": {"effort": "none"},
|
||||
}
|
||||
|
||||
|
||||
def test_mimo_flash_via_openrouter_does_not_inject_thinking():
|
||||
"""mimo-v2-flash has no thinking mode per Xiaomi docs; the allowlist
|
||||
excludes it, so no thinking field should be injected on the gateway path."""
|
||||
excludes it, so neither the upstream `thinking` field nor OR's
|
||||
`reasoning.effort` should be injected on the gateway path."""
|
||||
provider = _openrouter_provider("xiaomi/mimo-v2-flash")
|
||||
kwargs = provider._build_kwargs(
|
||||
messages=_simple_messages(),
|
||||
@ -200,3 +220,18 @@ def test_non_mimo_model_via_openrouter_unaffected():
|
||||
temperature=0.7, reasoning_effort="none", tool_choice=None,
|
||||
)
|
||||
assert "extra_body" not in kwargs
|
||||
|
||||
|
||||
def test_kimi_via_openrouter_also_injects_reasoning_effort():
|
||||
"""Kimi has the same gateway problem as MiMo: OR drops the upstream
|
||||
`thinking` field. The same OR-reasoning injection should fire."""
|
||||
provider = _openrouter_provider("moonshotai/kimi-k2.5")
|
||||
kwargs = provider._build_kwargs(
|
||||
messages=_simple_messages(),
|
||||
tools=None, model=None, max_tokens=100,
|
||||
temperature=0.7, reasoning_effort="none", tool_choice=None,
|
||||
)
|
||||
assert kwargs["extra_body"] == {
|
||||
"thinking": {"type": "disabled"},
|
||||
"reasoning": {"effort": "none"},
|
||||
}
|
||||
|
||||
330
tests/tools/test_apply_patch_tool.py
Normal file
330
tests/tools/test_apply_patch_tool.py
Normal file
@ -0,0 +1,330 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from nanobot.agent.tools.apply_patch import ApplyPatchTool
|
||||
|
||||
|
||||
def test_apply_patch_edits_replace(tmp_path):
|
||||
target = tmp_path / "calc.py"
|
||||
target.write_text("def add(a, b):\n return a + b\n")
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(
|
||||
tool.execute(
|
||||
edits=[
|
||||
{
|
||||
"path": "calc.py",
|
||||
"action": "replace",
|
||||
"old_text": " return a + b",
|
||||
"new_text": " return a - b",
|
||||
}
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
assert "update calc.py" in result
|
||||
assert target.read_text() == "def add(a, b):\n return a - b\n"
|
||||
|
||||
|
||||
def test_apply_patch_edits_add_new_file(tmp_path):
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(
|
||||
tool.execute(
|
||||
edits=[
|
||||
{
|
||||
"path": "config.py",
|
||||
"action": "add",
|
||||
"new_text": "DEBUG = True",
|
||||
}
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
assert "add config.py" in result
|
||||
assert (tmp_path / "config.py").read_text() == "DEBUG = True\n"
|
||||
|
||||
|
||||
def test_apply_patch_edits_preserves_new_file_trailing_blank_lines(tmp_path):
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(
|
||||
tool.execute(
|
||||
edits=[
|
||||
{
|
||||
"path": "notes.txt",
|
||||
"action": "add",
|
||||
"new_text": "one\n\n",
|
||||
}
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
assert "add notes.txt" in result
|
||||
assert (tmp_path / "notes.txt").read_text() == "one\n\n"
|
||||
|
||||
|
||||
def test_apply_patch_edits_add_to_existing_file(tmp_path):
|
||||
target = tmp_path / "log.py"
|
||||
target.write_text("import logging\n\nlogger = logging.getLogger(__name__)\n")
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(
|
||||
tool.execute(
|
||||
edits=[
|
||||
{
|
||||
"path": "log.py",
|
||||
"action": "add",
|
||||
"new_text": "def debug(msg):\n logger.debug(msg)",
|
||||
}
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
assert "update log.py" in result
|
||||
assert (
|
||||
target.read_text()
|
||||
== "import logging\n\nlogger = logging.getLogger(__name__)\ndef debug(msg):\n logger.debug(msg)\n"
|
||||
)
|
||||
|
||||
|
||||
def test_apply_patch_edits_delete(tmp_path):
|
||||
target = tmp_path / "utils.py"
|
||||
target.write_text("def unused():\n pass\ndef used():\n return 1\n")
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(
|
||||
tool.execute(
|
||||
edits=[
|
||||
{
|
||||
"path": "utils.py",
|
||||
"action": "delete",
|
||||
"old_text": "def unused():\n pass\n",
|
||||
}
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
assert "update utils.py" in result
|
||||
assert target.read_text() == "def used():\n return 1\n"
|
||||
|
||||
|
||||
def test_apply_patch_edits_delete_entire_file(tmp_path):
|
||||
target = tmp_path / "obsolete.txt"
|
||||
target.write_text("remove me\n")
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(
|
||||
tool.execute(
|
||||
edits=[
|
||||
{
|
||||
"path": "obsolete.txt",
|
||||
"action": "delete",
|
||||
"old_text": "remove me\n",
|
||||
}
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
assert "delete obsolete.txt" in result
|
||||
assert not target.exists()
|
||||
|
||||
|
||||
def test_apply_patch_edits_delete_substring_with_surrounding_whitespace(tmp_path):
|
||||
target = tmp_path / "keep_whitespace.txt"
|
||||
target.write_text(" token \n")
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(
|
||||
tool.execute(
|
||||
edits=[
|
||||
{
|
||||
"path": "keep_whitespace.txt",
|
||||
"action": "delete",
|
||||
"old_text": "token",
|
||||
}
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
assert "update keep_whitespace.txt" in result
|
||||
assert target.exists()
|
||||
assert target.read_text() == " \n"
|
||||
|
||||
|
||||
def test_apply_patch_edits_batch_multiple_files(tmp_path):
|
||||
a = tmp_path / "a.py"
|
||||
a.write_text("X = 1\n")
|
||||
b = tmp_path / "b.py"
|
||||
b.write_text("from a import X\nprint(X)\n")
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(
|
||||
tool.execute(
|
||||
edits=[
|
||||
{
|
||||
"path": "a.py",
|
||||
"action": "replace",
|
||||
"old_text": "X = 1",
|
||||
"new_text": "Y = 1",
|
||||
},
|
||||
{
|
||||
"path": "b.py",
|
||||
"action": "replace",
|
||||
"old_text": "from a import X",
|
||||
"new_text": "from a import Y",
|
||||
},
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
assert "update a.py" in result
|
||||
assert "update b.py" in result
|
||||
assert a.read_text() == "Y = 1\n"
|
||||
assert b.read_text() == "from a import Y\nprint(X)\n"
|
||||
|
||||
|
||||
def test_apply_patch_edits_rejects_ambiguous_old_text(tmp_path):
|
||||
target = tmp_path / "repeated.txt"
|
||||
target.write_text("target\nmiddle\ntarget\n")
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(
|
||||
tool.execute(
|
||||
edits=[
|
||||
{
|
||||
"path": "repeated.txt",
|
||||
"action": "replace",
|
||||
"old_text": "target",
|
||||
"new_text": "changed",
|
||||
}
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
assert "old_text appears multiple times" in result
|
||||
assert target.read_text() == "target\nmiddle\ntarget\n"
|
||||
|
||||
|
||||
def test_apply_patch_edits_dry_run_validates_without_writing(tmp_path):
|
||||
target = tmp_path / "dry.txt"
|
||||
target.write_text("before\n")
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(
|
||||
tool.execute(
|
||||
edits=[
|
||||
{
|
||||
"path": "dry.txt",
|
||||
"action": "replace",
|
||||
"old_text": "before",
|
||||
"new_text": "after",
|
||||
},
|
||||
{
|
||||
"path": "added.txt",
|
||||
"action": "add",
|
||||
"new_text": "new",
|
||||
},
|
||||
],
|
||||
dry_run=True,
|
||||
)
|
||||
)
|
||||
|
||||
assert "Patch dry-run succeeded" in result
|
||||
assert target.read_text() == "before\n"
|
||||
assert not (tmp_path / "added.txt").exists()
|
||||
|
||||
|
||||
def test_apply_patch_edits_rejects_absolute_and_parent_paths(tmp_path):
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
absolute = asyncio.run(
|
||||
tool.execute(
|
||||
edits=[
|
||||
{
|
||||
"path": "/tmp/owned.txt",
|
||||
"action": "add",
|
||||
"new_text": "nope",
|
||||
}
|
||||
]
|
||||
)
|
||||
)
|
||||
parent = asyncio.run(
|
||||
tool.execute(
|
||||
edits=[
|
||||
{
|
||||
"path": "../owned.txt",
|
||||
"action": "add",
|
||||
"new_text": "nope",
|
||||
}
|
||||
]
|
||||
)
|
||||
)
|
||||
windows_absolute = asyncio.run(
|
||||
tool.execute(
|
||||
edits=[
|
||||
{
|
||||
"path": r"C:\owned.txt",
|
||||
"action": "add",
|
||||
"new_text": "nope",
|
||||
}
|
||||
]
|
||||
)
|
||||
)
|
||||
windows_parent = asyncio.run(
|
||||
tool.execute(
|
||||
edits=[
|
||||
{
|
||||
"path": r"..\owned.txt",
|
||||
"action": "add",
|
||||
"new_text": "nope",
|
||||
}
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
assert "must be relative" in absolute
|
||||
assert "must not contain '..'" in parent
|
||||
assert "must be relative" in windows_absolute
|
||||
assert "must not contain '..'" in windows_parent
|
||||
assert not (tmp_path.parent / "owned.txt").exists()
|
||||
|
||||
|
||||
def test_apply_patch_edits_reports_invalid_edit_shapes(tmp_path):
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
missing_path = asyncio.run(tool.execute(edits=[{"action": "add", "new_text": "x"}]))
|
||||
missing_action = asyncio.run(tool.execute(edits=[{"path": "x.txt", "new_text": "x"}]))
|
||||
non_object = asyncio.run(tool.execute(edits=["not an object"])) # type: ignore[list-item]
|
||||
|
||||
assert "path required for edit" in missing_path
|
||||
assert "action required for edit: x.txt" in missing_action
|
||||
assert "each edit must be an object" in non_object
|
||||
|
||||
|
||||
def test_apply_patch_edits_rolls_back_when_late_operation_fails(tmp_path):
|
||||
first = tmp_path / "first.txt"
|
||||
first.write_text("before\n")
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(
|
||||
tool.execute(
|
||||
edits=[
|
||||
{
|
||||
"path": "first.txt",
|
||||
"action": "replace",
|
||||
"old_text": "before",
|
||||
"new_text": "after",
|
||||
},
|
||||
{
|
||||
"path": "missing.txt",
|
||||
"action": "delete",
|
||||
"old_text": "remove me",
|
||||
},
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
assert "file to update does not exist: missing.txt" in result
|
||||
assert first.read_text() == "before\n"
|
||||
@ -1,5 +1,5 @@
|
||||
"""Tests for EditFileTool enhancements: read-before-edit tracking, path suggestions,
|
||||
.ipynb detection, and create-file semantics."""
|
||||
notebook JSON editing, and create-file semantics."""
|
||||
|
||||
import pytest
|
||||
|
||||
@ -108,22 +108,27 @@ class TestEditCreateFile:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# .ipynb detection
|
||||
# .ipynb editing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEditIpynbDetection:
|
||||
"""edit_file should refuse .ipynb and suggest notebook_edit."""
|
||||
class TestEditIpynbFiles:
|
||||
"""edit_file edits notebooks as normal JSON files."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ipynb_rejected_with_suggestion(self, tool, tmp_path):
|
||||
async def test_ipynb_can_be_edited_as_json(self, tool, tmp_path):
|
||||
f = tmp_path / "analysis.ipynb"
|
||||
f.write_text('{"cells": []}', encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="x", new_text="y")
|
||||
assert "notebook" in result.lower()
|
||||
result = await tool.execute(
|
||||
path=str(f),
|
||||
old_text='"cells": []',
|
||||
new_text='"cells": [{"cell_type": "markdown", "source": "hi"}]',
|
||||
)
|
||||
assert "Successfully edited" in result
|
||||
assert '"source": "hi"' in f.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -162,7 +162,7 @@ class TestPathAppendPlatform:
|
||||
captured_cmd = None
|
||||
captured_env = {}
|
||||
|
||||
async def capture_spawn(cmd, cwd, env):
|
||||
async def capture_spawn(cmd, cwd, env, shell_program=None, login=True):
|
||||
nonlocal captured_cmd
|
||||
captured_cmd = cmd
|
||||
captured_env.update(env)
|
||||
@ -190,7 +190,7 @@ class TestPathAppendPlatform:
|
||||
|
||||
captured_env = {}
|
||||
|
||||
async def capture_spawn(cmd, cwd, env):
|
||||
async def capture_spawn(cmd, cwd, env, shell_program=None, login=True):
|
||||
captured_env.update(env)
|
||||
return mock_proc
|
||||
|
||||
|
||||
361
tests/tools/test_exec_session_tools.py
Normal file
361
tests/tools/test_exec_session_tools.py
Normal file
@ -0,0 +1,361 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import shlex
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
from nanobot.agent.tools.exec_session import ExecSessionManager, ListExecSessionsTool, WriteStdinTool
|
||||
|
||||
|
||||
def _python_command(code: str) -> str:
|
||||
if sys.platform == "win32":
|
||||
return f"{subprocess.list2cmdline([sys.executable])} -u -c {subprocess.list2cmdline([code])}"
|
||||
return f"{shlex.quote(sys.executable)} -u -c {shlex.quote(code)}"
|
||||
|
||||
|
||||
def _session_id(output: str) -> str:
|
||||
match = re.search(r"session_id:\s*([0-9a-f]+)", output)
|
||||
assert match, output
|
||||
return match.group(1)
|
||||
|
||||
|
||||
def test_exec_keeps_one_shot_behavior_without_yield_time_ms(tmp_path):
|
||||
async def run() -> str:
|
||||
tool = ExecTool(working_dir=str(tmp_path), timeout=5)
|
||||
return await tool.execute(command="echo hello")
|
||||
|
||||
result = asyncio.run(run())
|
||||
|
||||
assert "hello" in result
|
||||
assert "Exit code: 0" in result
|
||||
assert "session_id:" not in result
|
||||
|
||||
|
||||
def test_exec_accepts_command_aliases(tmp_path):
|
||||
async def run() -> str:
|
||||
tool = ExecTool(working_dir="/")
|
||||
return await tool.execute(
|
||||
cmd=_python_command("import os; print(os.getcwd())"),
|
||||
workdir=str(tmp_path),
|
||||
)
|
||||
|
||||
result = asyncio.run(run())
|
||||
|
||||
assert str(tmp_path) in result
|
||||
assert "Exit code: 0" in result
|
||||
|
||||
|
||||
def test_exec_returns_completed_session_output_when_yield_time_ms_is_used(tmp_path):
|
||||
async def run() -> str:
|
||||
manager = ExecSessionManager()
|
||||
tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
|
||||
stdin_tool = WriteStdinTool(manager=manager)
|
||||
|
||||
result = await tool.execute(command="echo hello", yield_time_ms=1000)
|
||||
if "session_id:" in result:
|
||||
sid = _session_id(result)
|
||||
result += "\n" + await stdin_tool.execute(
|
||||
session_id=sid,
|
||||
chars="",
|
||||
yield_time_ms=1000,
|
||||
)
|
||||
return result
|
||||
|
||||
result = asyncio.run(run())
|
||||
|
||||
assert "hello" in result
|
||||
assert "Exit code: 0" in result
|
||||
assert "session_id:" not in result
|
||||
|
||||
|
||||
def test_exec_session_accepts_max_output_tokens_alias(tmp_path):
|
||||
async def run() -> str:
|
||||
manager = ExecSessionManager()
|
||||
tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
|
||||
command = _python_command("print('A' * 2000)")
|
||||
return await tool.execute(
|
||||
command=command,
|
||||
yield_time_ms=1000,
|
||||
max_output_tokens=1000,
|
||||
)
|
||||
|
||||
result = asyncio.run(run())
|
||||
|
||||
assert "chars truncated" in result
|
||||
assert "Exit code: 0" in result
|
||||
|
||||
|
||||
def test_exec_one_shot_accepts_max_output_tokens_alias(tmp_path):
|
||||
async def run() -> str:
|
||||
tool = ExecTool(working_dir=str(tmp_path), timeout=5)
|
||||
command = _python_command("print('A' * 2000)")
|
||||
return await tool.execute(command=command, max_output_tokens=1000)
|
||||
|
||||
result = asyncio.run(run())
|
||||
|
||||
assert "chars truncated" in result
|
||||
assert "Exit code: 0" in result
|
||||
|
||||
|
||||
def test_exec_accepts_supported_shell_parameter(tmp_path):
|
||||
async def run() -> str:
|
||||
tool = ExecTool(working_dir=str(tmp_path), timeout=5)
|
||||
return await tool.execute(command="echo shell-ok", shell="sh", login=False)
|
||||
|
||||
if sys.platform == "win32":
|
||||
return
|
||||
result = asyncio.run(run())
|
||||
|
||||
assert "shell-ok" in result
|
||||
assert "Exit code: 0" in result
|
||||
|
||||
|
||||
def test_exec_rejects_unsupported_shell(tmp_path):
|
||||
async def run() -> str:
|
||||
tool = ExecTool(working_dir=str(tmp_path), timeout=5)
|
||||
return await tool.execute(command="echo no", shell="python")
|
||||
|
||||
if sys.platform == "win32":
|
||||
return
|
||||
result = asyncio.run(run())
|
||||
|
||||
assert "unsupported shell" in result
|
||||
|
||||
|
||||
def test_exec_can_continue_with_stdin(tmp_path):
|
||||
async def run() -> tuple[str, str]:
|
||||
manager = ExecSessionManager()
|
||||
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
|
||||
stdin_tool = WriteStdinTool(manager=manager)
|
||||
command = _python_command(
|
||||
"import sys; print('ready', flush=True); "
|
||||
"line=sys.stdin.readline(); print('got:' + line.strip(), flush=True)"
|
||||
)
|
||||
|
||||
initial = await exec_tool.execute(command=command, yield_time_ms=500)
|
||||
sid = _session_id(initial)
|
||||
result = await stdin_tool.execute(session_id=sid, chars="ping\n", yield_time_ms=1000)
|
||||
return initial, result
|
||||
|
||||
initial, result = asyncio.run(run())
|
||||
assert "ready" in initial
|
||||
assert "Process running" in initial
|
||||
assert "Elapsed:" in initial
|
||||
assert "got:ping" in result
|
||||
assert "Exit code: 0" in result
|
||||
assert "Elapsed:" in result
|
||||
|
||||
|
||||
def test_write_stdin_can_close_stdin(tmp_path):
|
||||
async def run() -> tuple[str, str]:
|
||||
manager = ExecSessionManager()
|
||||
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
|
||||
stdin_tool = WriteStdinTool(manager=manager)
|
||||
command = _python_command(
|
||||
"import sys; print('ready', flush=True); "
|
||||
"data=sys.stdin.read(); print('got:' + data, flush=True)"
|
||||
)
|
||||
|
||||
initial = await exec_tool.execute(command=command, yield_time_ms=500)
|
||||
sid = _session_id(initial)
|
||||
result = await stdin_tool.execute(
|
||||
session_id=sid,
|
||||
chars="payload",
|
||||
close_stdin=True,
|
||||
yield_time_ms=1000,
|
||||
)
|
||||
return initial, result
|
||||
|
||||
initial, result = asyncio.run(run())
|
||||
assert "ready" in initial
|
||||
assert "got:payload" in result
|
||||
assert "Stdin closed." in result
|
||||
assert "Exit code: 0" in result
|
||||
|
||||
|
||||
def test_write_stdin_can_terminate_session(tmp_path):
|
||||
async def run() -> tuple[str, str]:
|
||||
manager = ExecSessionManager()
|
||||
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=30, session_manager=manager)
|
||||
stdin_tool = WriteStdinTool(manager=manager)
|
||||
command = _python_command(
|
||||
"import time; print('ready', flush=True); time.sleep(30)"
|
||||
)
|
||||
|
||||
initial = await exec_tool.execute(command=command, yield_time_ms=500)
|
||||
sid = _session_id(initial)
|
||||
result = await stdin_tool.execute(
|
||||
session_id=sid,
|
||||
terminate=True,
|
||||
yield_time_ms=0,
|
||||
)
|
||||
return initial, result
|
||||
|
||||
initial, result = asyncio.run(run())
|
||||
assert "ready" in initial
|
||||
assert "Session terminated." in result
|
||||
assert "Exit code:" in result
|
||||
|
||||
|
||||
def test_write_stdin_accepts_max_output_tokens_alias(tmp_path):
|
||||
async def run() -> tuple[str, str, str]:
|
||||
manager = ExecSessionManager()
|
||||
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
|
||||
stdin_tool = WriteStdinTool(manager=manager)
|
||||
command = _python_command(
|
||||
"import time; print('A' * 2000, flush=True); time.sleep(5)"
|
||||
)
|
||||
|
||||
initial = await exec_tool.execute(command=command, yield_time_ms=0)
|
||||
sid = _session_id(initial)
|
||||
poll = await stdin_tool.execute(
|
||||
session_id=sid,
|
||||
yield_time_ms=500,
|
||||
max_output_tokens=1000,
|
||||
)
|
||||
cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0)
|
||||
return initial, poll, cleanup
|
||||
|
||||
initial, poll, cleanup = asyncio.run(run())
|
||||
assert "Process running" in initial
|
||||
assert "chars truncated" in poll
|
||||
assert "Session terminated." in cleanup
|
||||
|
||||
|
||||
def test_write_stdin_preserves_completed_session_output_until_polled(tmp_path):
|
||||
async def run() -> tuple[str, str]:
|
||||
manager = ExecSessionManager()
|
||||
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
|
||||
stdin_tool = WriteStdinTool(manager=manager)
|
||||
command = _python_command(
|
||||
"import time; print('ready', flush=True); "
|
||||
"time.sleep(1.0); print('done', flush=True)"
|
||||
)
|
||||
|
||||
initial = await exec_tool.execute(command=command, yield_time_ms=300)
|
||||
sid = _session_id(initial)
|
||||
await asyncio.sleep(1.2)
|
||||
final = await stdin_tool.execute(session_id=sid, chars="", yield_time_ms=0)
|
||||
return initial, final
|
||||
|
||||
initial, final = asyncio.run(run())
|
||||
|
||||
assert "ready" in initial
|
||||
assert "done" in final
|
||||
assert "Exit code: 0" in final
|
||||
|
||||
|
||||
def test_write_stdin_can_wait_for_expected_output(tmp_path):
|
||||
async def run() -> tuple[str, str, str]:
|
||||
manager = ExecSessionManager()
|
||||
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
|
||||
stdin_tool = WriteStdinTool(manager=manager)
|
||||
command = _python_command(
|
||||
"import time; print('booting', flush=True); "
|
||||
"time.sleep(0.4); print('ready', flush=True); time.sleep(5)"
|
||||
)
|
||||
|
||||
initial = await exec_tool.execute(command=command, yield_time_ms=100)
|
||||
sid = _session_id(initial)
|
||||
waited = await stdin_tool.execute(
|
||||
session_id=sid,
|
||||
wait_for="ready",
|
||||
wait_timeout_ms=3000,
|
||||
yield_time_ms=0,
|
||||
)
|
||||
cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0)
|
||||
return initial, waited, cleanup
|
||||
|
||||
initial, waited, cleanup = asyncio.run(run())
|
||||
|
||||
assert "Process running" in initial
|
||||
assert "booting" in initial + waited
|
||||
assert "ready" in waited
|
||||
assert "Wait target not observed" not in waited
|
||||
assert "Session terminated." in cleanup
|
||||
|
||||
|
||||
def test_write_stdin_wait_for_reports_timeout_without_killing_session(tmp_path):
|
||||
async def run() -> tuple[str, str, str]:
|
||||
manager = ExecSessionManager()
|
||||
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
|
||||
stdin_tool = WriteStdinTool(manager=manager)
|
||||
command = _python_command(
|
||||
"import time; print('booting', flush=True); time.sleep(5)"
|
||||
)
|
||||
|
||||
initial = await exec_tool.execute(command=command, yield_time_ms=100)
|
||||
sid = _session_id(initial)
|
||||
waited = await stdin_tool.execute(
|
||||
session_id=sid,
|
||||
wait_for="never-ready",
|
||||
wait_timeout_ms=200,
|
||||
yield_time_ms=0,
|
||||
)
|
||||
cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0)
|
||||
return initial, waited, cleanup
|
||||
|
||||
initial, waited, cleanup = asyncio.run(run())
|
||||
|
||||
assert "Process running" in initial
|
||||
assert "booting" in initial + waited
|
||||
assert "Process running" in waited
|
||||
assert "Wait target not observed: 'never-ready'" in waited
|
||||
assert "Session terminated." in cleanup
|
||||
|
||||
|
||||
def test_exec_session_mode_reuses_exec_safety_guard(tmp_path):
|
||||
manager = ExecSessionManager()
|
||||
tool = ExecTool(
|
||||
working_dir=str(tmp_path),
|
||||
deny_patterns=[r"echo\s+blocked"],
|
||||
session_manager=manager,
|
||||
)
|
||||
|
||||
result = asyncio.run(tool.execute(command="echo blocked", yield_time_ms=0))
|
||||
|
||||
assert "blocked by deny pattern" in result
|
||||
|
||||
|
||||
def test_write_stdin_reports_missing_session(tmp_path):
|
||||
manager = ExecSessionManager()
|
||||
tool = WriteStdinTool(manager=manager)
|
||||
|
||||
result = asyncio.run(tool.execute(session_id="missing", chars=""))
|
||||
|
||||
assert "exec session not found" in result
|
||||
|
||||
|
||||
def test_list_exec_sessions_reports_running_commands(tmp_path):
|
||||
async def run() -> tuple[str, str, str]:
|
||||
manager = ExecSessionManager()
|
||||
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
|
||||
list_tool = ListExecSessionsTool(manager=manager)
|
||||
stdin_tool = WriteStdinTool(manager=manager)
|
||||
command = _python_command(
|
||||
"import time; print('ready', flush=True); time.sleep(5)"
|
||||
)
|
||||
|
||||
initial = await exec_tool.execute(command=command, yield_time_ms=500)
|
||||
sid = _session_id(initial)
|
||||
listing = await list_tool.execute()
|
||||
cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0)
|
||||
return sid, listing, cleanup
|
||||
|
||||
sid, listing, cleanup = asyncio.run(run())
|
||||
|
||||
assert sid in listing
|
||||
assert "running" in listing
|
||||
assert "elapsed=" in listing
|
||||
assert "remaining=" in listing
|
||||
assert str(tmp_path) in listing
|
||||
assert "Session terminated." in cleanup
|
||||
|
||||
|
||||
def test_list_exec_sessions_reports_empty_state():
|
||||
result = asyncio.run(ListExecSessionsTool(manager=ExecSessionManager()).execute())
|
||||
|
||||
assert result == "No active exec sessions."
|
||||
216
tests/tools/test_file_edit_coding_enhancements.py
Normal file
216
tests/tools/test_file_edit_coding_enhancements.py
Normal file
@ -0,0 +1,216 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool
|
||||
|
||||
|
||||
def test_read_file_force_bypasses_dedup(tmp_path):
|
||||
target = tmp_path / "data.txt"
|
||||
target.write_text("alpha\n")
|
||||
tool = ReadFileTool(workspace=tmp_path)
|
||||
|
||||
first = asyncio.run(tool.execute(path=str(target)))
|
||||
second = asyncio.run(tool.execute(path=str(target)))
|
||||
forced = asyncio.run(tool.execute(path=str(target), force=True))
|
||||
|
||||
assert "alpha" in first
|
||||
assert "unchanged" in second.lower()
|
||||
assert "alpha" in forced
|
||||
assert "unchanged" not in forced.lower()
|
||||
|
||||
|
||||
def test_edit_file_can_select_occurrence(tmp_path):
|
||||
target = tmp_path / "duplicate.txt"
|
||||
target.write_text("one\nsame\ntwo\nsame\n")
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text="same",
|
||||
new_text="changed",
|
||||
occurrence=2,
|
||||
))
|
||||
|
||||
assert "Successfully edited" in result
|
||||
assert target.read_text() == "one\nsame\ntwo\nchanged\n"
|
||||
|
||||
|
||||
def test_edit_file_expected_replacements_guards_replace_all(tmp_path):
|
||||
target = tmp_path / "duplicate.txt"
|
||||
target.write_text("same\nsame\n")
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text="same",
|
||||
new_text="changed",
|
||||
replace_all=True,
|
||||
expected_replacements=1,
|
||||
))
|
||||
|
||||
assert "expected 1 replacements but would make 2" in result
|
||||
assert target.read_text() == "same\nsame\n"
|
||||
|
||||
|
||||
def test_edit_file_expected_replacements_allows_replace_all_when_count_matches(tmp_path):
|
||||
target = tmp_path / "duplicate.txt"
|
||||
target.write_text("same\nsame\n")
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text="same",
|
||||
new_text="changed",
|
||||
replace_all=True,
|
||||
expected_replacements=2,
|
||||
))
|
||||
|
||||
assert "Successfully edited" in result
|
||||
assert target.read_text() == "changed\nchanged\n"
|
||||
|
||||
|
||||
def test_edit_file_can_select_nearest_line_hint(tmp_path):
|
||||
target = tmp_path / "duplicate.txt"
|
||||
target.write_text("one\nsame\ntwo\nsame\n")
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text="same",
|
||||
new_text="changed",
|
||||
line_hint=4,
|
||||
))
|
||||
|
||||
assert "Successfully edited" in result
|
||||
assert target.read_text() == "one\nsame\ntwo\nchanged\n"
|
||||
|
||||
|
||||
def test_edit_file_can_edit_ipynb_as_json(tmp_path):
|
||||
target = tmp_path / "analysis.ipynb"
|
||||
target.write_text('{"cells": []}')
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text='"cells": []',
|
||||
new_text='"cells": [{"cell_type": "markdown", "source": "hi"}]',
|
||||
))
|
||||
|
||||
assert "Successfully edited" in result
|
||||
assert '"source": "hi"' in target.read_text()
|
||||
|
||||
|
||||
def test_edit_file_multiple_match_hint_mentions_occurrence(tmp_path):
|
||||
target = tmp_path / "duplicate.txt"
|
||||
target.write_text("same\nsame\n")
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text="same",
|
||||
new_text="changed",
|
||||
))
|
||||
|
||||
assert "old_text appears 2 times" in result
|
||||
assert "occurrence" in result
|
||||
assert target.read_text() == "same\nsame\n"
|
||||
|
||||
|
||||
def test_edit_file_rejects_ambiguous_line_hint(tmp_path):
|
||||
target = tmp_path / "duplicate.txt"
|
||||
target.write_text("same\nmiddle\nsame\n")
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text="same",
|
||||
new_text="changed",
|
||||
line_hint=2,
|
||||
))
|
||||
|
||||
assert "line_hint 2 is ambiguous" in result
|
||||
assert target.read_text() == "same\nmiddle\nsame\n"
|
||||
|
||||
|
||||
def test_edit_file_rejects_occurrence_with_replace_all(tmp_path):
|
||||
target = tmp_path / "duplicate.txt"
|
||||
target.write_text("same\nsame\n")
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text="same",
|
||||
new_text="changed",
|
||||
occurrence=1,
|
||||
replace_all=True,
|
||||
))
|
||||
|
||||
assert "occurrence cannot be used with replace_all" in result
|
||||
assert target.read_text() == "same\nsame\n"
|
||||
|
||||
|
||||
def test_edit_file_rejects_line_hint_with_replace_all(tmp_path):
|
||||
target = tmp_path / "duplicate.txt"
|
||||
target.write_text("same\nsame\n")
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text="same",
|
||||
new_text="changed",
|
||||
line_hint=1,
|
||||
replace_all=True,
|
||||
))
|
||||
|
||||
assert "line_hint cannot be used with replace_all" in result
|
||||
assert target.read_text() == "same\nsame\n"
|
||||
|
||||
|
||||
def test_edit_file_rejects_line_hint_with_occurrence(tmp_path):
|
||||
target = tmp_path / "duplicate.txt"
|
||||
target.write_text("same\nsame\n")
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text="same",
|
||||
new_text="changed",
|
||||
occurrence=1,
|
||||
line_hint=1,
|
||||
))
|
||||
|
||||
assert "line_hint cannot be used with occurrence" in result
|
||||
assert target.read_text() == "same\nsame\n"
|
||||
|
||||
|
||||
def test_edit_file_rejects_zero_occurrence(tmp_path):
|
||||
target = tmp_path / "duplicate.txt"
|
||||
target.write_text("same\n")
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text="same",
|
||||
new_text="changed",
|
||||
occurrence=0,
|
||||
))
|
||||
|
||||
assert "occurrence must be >= 1" in result
|
||||
assert target.read_text() == "same\n"
|
||||
|
||||
|
||||
def test_edit_file_rejects_zero_line_hint(tmp_path):
|
||||
target = tmp_path / "duplicate.txt"
|
||||
target.write_text("same\n")
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text="same",
|
||||
new_text="changed",
|
||||
line_hint=0,
|
||||
))
|
||||
|
||||
assert "line_hint must be >= 1" in result
|
||||
assert target.read_text() == "same\n"
|
||||
@ -138,6 +138,39 @@ async def test_generate_image_tool_reports_missing_aihubmix_key(tmp_path: Path)
|
||||
assert result.startswith("Error: AIHubMix API key is not configured")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_image_tool_allows_ollama_without_api_key(
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
set_config_path(tmp_path / "config.json")
|
||||
FakeImageClient.instances = []
|
||||
monkeypatch.setattr(
|
||||
"nanobot.agent.tools.image_generation.get_image_gen_provider",
|
||||
lambda name: FakeImageClient if name == "ollama" else None,
|
||||
)
|
||||
tool = ImageGenerationTool(
|
||||
workspace=tmp_path,
|
||||
config=ImageGenerationToolConfig(
|
||||
enabled=True,
|
||||
provider="ollama",
|
||||
model="x/z-image-turbo",
|
||||
),
|
||||
provider_configs={"ollama": ProviderConfig(api_base="http://localhost:11434/v1")},
|
||||
)
|
||||
|
||||
result = await tool.execute(prompt="draw a cat")
|
||||
|
||||
payload = json.loads(result)
|
||||
assert len(payload["artifacts"]) == 1
|
||||
|
||||
fake = FakeImageClient.instances[0]
|
||||
assert fake.kwargs["api_key"] is None
|
||||
assert fake.kwargs["api_base"] == "http://localhost:11434/v1"
|
||||
assert fake.calls[0]["aspect_ratio"] == "1:1"
|
||||
assert fake.calls[0]["image_size"] == "1K"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_image_tool_rejects_reference_outside_workspace(tmp_path: Path) -> None:
|
||||
set_config_path(tmp_path / "config.json")
|
||||
|
||||
@ -1,147 +0,0 @@
|
||||
"""Tests for NotebookEditTool — Jupyter .ipynb editing."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.notebook import NotebookEditTool
|
||||
|
||||
|
||||
def _make_notebook(cells: list[dict] | None = None, nbformat: int = 4, nbformat_minor: int = 5) -> dict:
|
||||
"""Build a minimal valid .ipynb structure."""
|
||||
return {
|
||||
"nbformat": nbformat,
|
||||
"nbformat_minor": nbformat_minor,
|
||||
"metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}},
|
||||
"cells": cells or [],
|
||||
}
|
||||
|
||||
|
||||
def _code_cell(source: str, cell_id: str | None = None) -> dict:
|
||||
cell = {"cell_type": "code", "source": source, "metadata": {}, "outputs": [], "execution_count": None}
|
||||
if cell_id:
|
||||
cell["id"] = cell_id
|
||||
return cell
|
||||
|
||||
|
||||
def _md_cell(source: str, cell_id: str | None = None) -> dict:
|
||||
cell = {"cell_type": "markdown", "source": source, "metadata": {}}
|
||||
if cell_id:
|
||||
cell["id"] = cell_id
|
||||
return cell
|
||||
|
||||
|
||||
def _write_nb(tmp_path, name: str, nb: dict) -> str:
|
||||
p = tmp_path / name
|
||||
p.write_text(json.dumps(nb), encoding="utf-8")
|
||||
return str(p)
|
||||
|
||||
|
||||
class TestNotebookEdit:
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return NotebookEditTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replace_cell_content(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("print('hello')"), _code_cell("x = 1")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
result = await tool.execute(path=path, cell_index=0, new_source="print('world')")
|
||||
assert "Successfully" in result
|
||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
||||
assert saved["cells"][0]["source"] == "print('world')"
|
||||
assert saved["cells"][1]["source"] == "x = 1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insert_cell_after_target(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("cell 0"), _code_cell("cell 1")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
result = await tool.execute(path=path, cell_index=0, new_source="inserted", edit_mode="insert")
|
||||
assert "Successfully" in result
|
||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
||||
assert len(saved["cells"]) == 3
|
||||
assert saved["cells"][0]["source"] == "cell 0"
|
||||
assert saved["cells"][1]["source"] == "inserted"
|
||||
assert saved["cells"][2]["source"] == "cell 1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_cell(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("A"), _code_cell("B"), _code_cell("C")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
result = await tool.execute(path=path, cell_index=1, edit_mode="delete")
|
||||
assert "Successfully" in result
|
||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
||||
assert len(saved["cells"]) == 2
|
||||
assert saved["cells"][0]["source"] == "A"
|
||||
assert saved["cells"][1]["source"] == "C"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_new_notebook_from_scratch(self, tool, tmp_path):
|
||||
path = str(tmp_path / "new.ipynb")
|
||||
result = await tool.execute(path=path, cell_index=0, new_source="# Hello", edit_mode="insert", cell_type="markdown")
|
||||
assert "Successfully" in result or "created" in result.lower()
|
||||
saved = json.loads((tmp_path / "new.ipynb").read_text())
|
||||
assert saved["nbformat"] == 4
|
||||
assert len(saved["cells"]) == 1
|
||||
assert saved["cells"][0]["cell_type"] == "markdown"
|
||||
assert saved["cells"][0]["source"] == "# Hello"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_cell_index_error(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("only cell")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
result = await tool.execute(path=path, cell_index=5, new_source="x")
|
||||
assert "Error" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_ipynb_rejected(self, tool, tmp_path):
|
||||
f = tmp_path / "script.py"
|
||||
f.write_text("pass")
|
||||
result = await tool.execute(path=str(f), cell_index=0, new_source="x")
|
||||
assert "Error" in result
|
||||
assert ".ipynb" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserves_metadata_and_outputs(self, tool, tmp_path):
|
||||
cell = _code_cell("old")
|
||||
cell["outputs"] = [{"output_type": "stream", "text": "hello\n"}]
|
||||
cell["execution_count"] = 42
|
||||
nb = _make_notebook([cell])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
await tool.execute(path=path, cell_index=0, new_source="new")
|
||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
||||
assert saved["metadata"]["kernelspec"]["language"] == "python"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nbformat_45_generates_cell_id(self, tool, tmp_path):
|
||||
nb = _make_notebook([], nbformat_minor=5)
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
await tool.execute(path=path, cell_index=0, new_source="x = 1", edit_mode="insert")
|
||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
||||
assert "id" in saved["cells"][0]
|
||||
assert len(saved["cells"][0]["id"]) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insert_with_cell_type_markdown(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("code")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
await tool.execute(path=path, cell_index=0, new_source="# Title", edit_mode="insert", cell_type="markdown")
|
||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
||||
assert saved["cells"][1]["cell_type"] == "markdown"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_edit_mode_rejected(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("code")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
result = await tool.execute(path=path, cell_index=0, new_source="x", edit_mode="replcae")
|
||||
assert "Error" in result
|
||||
assert "edit_mode" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_cell_type_rejected(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("code")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
result = await tool.execute(path=path, cell_index=0, new_source="x", cell_type="raw")
|
||||
assert "Error" in result
|
||||
assert "cell_type" in result
|
||||
@ -12,7 +12,7 @@ import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.subagent import SubagentManager, SubagentStatus
|
||||
from nanobot.agent.tools.search import GrepTool
|
||||
from nanobot.agent.tools.search import FindFilesTool, GrepTool
|
||||
from nanobot.agent.tools.web import WebSearchTool
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.config.schema import WebSearchConfig
|
||||
@ -33,6 +33,68 @@ async def test_web_search_tool_refreshes_dynamic_config_loader(monkeypatch) -> N
|
||||
assert await tool.execute("nanobot") == "duckduckgo:nanobot:3"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_files_filters_by_query_glob_and_type(tmp_path: Path) -> None:
|
||||
(tmp_path / "src").mkdir()
|
||||
(tmp_path / "src" / "settings_view.tsx").write_text("export {}\n", encoding="utf-8")
|
||||
(tmp_path / "src" / "settings_api.py").write_text("pass\n", encoding="utf-8")
|
||||
(tmp_path / "README.md").write_text("settings\n", encoding="utf-8")
|
||||
|
||||
tool = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path)
|
||||
result = await tool.execute(
|
||||
path=".",
|
||||
query="settings",
|
||||
glob="src/**",
|
||||
type="ts",
|
||||
)
|
||||
|
||||
assert result.splitlines() == ["src/settings_view.tsx"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_files_can_include_directories(tmp_path: Path) -> None:
|
||||
(tmp_path / "src" / "settings").mkdir(parents=True)
|
||||
(tmp_path / "src" / "settings" / "index.ts").write_text("export {}\n", encoding="utf-8")
|
||||
|
||||
tool = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path)
|
||||
result = await tool.execute(path="src", query="settings", include_dirs=True)
|
||||
|
||||
assert "src/settings/" in result.splitlines()
|
||||
assert "src/settings/index.ts" in result.splitlines()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_files_supports_modified_sort_and_pagination(tmp_path: Path) -> None:
|
||||
(tmp_path / "src").mkdir()
|
||||
for idx, name in enumerate(("a.py", "b.py", "c.py"), start=1):
|
||||
file_path = tmp_path / "src" / name
|
||||
file_path.write_text("pass\n", encoding="utf-8")
|
||||
os.utime(file_path, (idx, idx))
|
||||
|
||||
tool = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path)
|
||||
result = await tool.execute(
|
||||
path="src",
|
||||
type="py",
|
||||
sort="modified",
|
||||
head_limit=1,
|
||||
offset=1,
|
||||
)
|
||||
|
||||
assert result.splitlines()[0] == "src/b.py"
|
||||
assert "pagination: limit=1, offset=1" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_files_rejects_paths_outside_workspace(tmp_path: Path) -> None:
|
||||
outside = tmp_path.parent / "outside-find-files.txt"
|
||||
outside.write_text("secret\n", encoding="utf-8")
|
||||
|
||||
tool = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path)
|
||||
result = await tool.execute(path=str(outside))
|
||||
|
||||
assert result.startswith("Error:")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grep_respects_glob_filter_and_context(tmp_path: Path) -> None:
|
||||
(tmp_path / "src").mkdir()
|
||||
@ -249,6 +311,7 @@ def test_agent_loop_registers_grep(tmp_path: Path) -> None:
|
||||
|
||||
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
|
||||
|
||||
assert "find_files" in loop.tools.tool_names
|
||||
assert "grep" in loop.tools.tool_names
|
||||
|
||||
|
||||
@ -280,6 +343,7 @@ async def test_subagent_registers_grep(tmp_path: Path) -> None:
|
||||
status = SubagentStatus(task_id="sub-1", label="label", task_description="search task", started_at=time.monotonic())
|
||||
await mgr._run_subagent("sub-1", "search task", "label", {"channel": "cli", "chat_id": "direct"}, status)
|
||||
|
||||
assert "find_files" in captured["tool_names"]
|
||||
assert "grep" in captured["tool_names"]
|
||||
|
||||
|
||||
|
||||
46
tests/tools/test_tool_descriptions.py
Normal file
46
tests/tools/test_tool_descriptions.py
Normal file
@ -0,0 +1,46 @@
|
||||
from nanobot.agent.tools.apply_patch import ApplyPatchTool
|
||||
from nanobot.agent.tools.exec_session import ListExecSessionsTool, WriteStdinTool
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools.search import FindFilesTool, GrepTool
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
|
||||
|
||||
def test_coding_tool_descriptions_steer_editing_priority() -> None:
|
||||
apply_patch = ApplyPatchTool().description.lower()
|
||||
edit_file = EditFileTool().description.lower()
|
||||
write_file = WriteFileTool().description.lower()
|
||||
|
||||
assert "default tool for code edits" in apply_patch
|
||||
assert "multi-file" in apply_patch
|
||||
assert "dry_run=true" in apply_patch
|
||||
assert "edit_file only for small exact replacements" in apply_patch
|
||||
|
||||
assert "small, exact replacement" in edit_file
|
||||
assert "copied from read_file" in edit_file
|
||||
assert "prefer apply_patch" in edit_file
|
||||
|
||||
assert "replace an entire file" in write_file
|
||||
assert "prefer apply_patch" in write_file
|
||||
|
||||
|
||||
def test_coding_tool_descriptions_steer_discovery_and_shell_usage() -> None:
|
||||
read_file = ReadFileTool().description.lower()
|
||||
find_files = FindFilesTool().description.lower()
|
||||
grep = GrepTool().description.lower()
|
||||
exec_tool = ExecTool().description.lower()
|
||||
write_stdin = WriteStdinTool().description.lower()
|
||||
list_sessions = ListExecSessionsTool().description.lower()
|
||||
|
||||
assert "find_files/list_dir first" in read_file
|
||||
assert "before editing" in read_file
|
||||
assert "prefer it over shell find/ls" in find_files
|
||||
assert "prefer this over shell grep" in grep
|
||||
|
||||
assert "tests, builds" in exec_tool
|
||||
assert "prefer read_file/find_files/grep" in exec_tool
|
||||
assert "apply_patch/write_file/edit_file" in exec_tool
|
||||
assert "yield_time_ms" in exec_tool
|
||||
|
||||
assert "do not use this to start new commands" in write_stdin
|
||||
assert "wait_for" in write_stdin
|
||||
assert "recover a session_id" in list_sessions
|
||||
@ -89,9 +89,11 @@ def test_discover_finds_concrete_tools():
|
||||
loader = ToolLoader()
|
||||
discovered = loader.discover()
|
||||
class_names = {cls.__name__ for cls in discovered}
|
||||
assert "ApplyPatchTool" in class_names
|
||||
assert "ExecTool" in class_names
|
||||
assert "MessageTool" in class_names
|
||||
assert "SpawnTool" in class_names
|
||||
assert "WriteStdinTool" in class_names
|
||||
|
||||
|
||||
def test_discover_excludes_abstract_and_mcp():
|
||||
@ -406,7 +408,8 @@ def test_loader_registers_same_tools_as_old_hardcoded():
|
||||
|
||||
expected = {
|
||||
"read_file", "write_file", "edit_file", "list_dir",
|
||||
"grep", "notebook_edit", "exec", "web_search", "web_fetch",
|
||||
"find_files", "grep", "exec", "write_stdin", "list_exec_sessions",
|
||||
"web_search", "web_fetch",
|
||||
"message", "spawn", "cron",
|
||||
}
|
||||
actual = set(registered)
|
||||
|
||||
@ -3,6 +3,8 @@ import subprocess
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools import (
|
||||
ArraySchema,
|
||||
IntegerSchema,
|
||||
@ -15,6 +17,7 @@ from nanobot.agent.tools import (
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
from nanobot.security.network import configure_ssrf_whitelist
|
||||
|
||||
|
||||
class SampleTool(Tool):
|
||||
@ -218,6 +221,39 @@ def test_exec_extract_absolute_paths_ignores_relative_posix_segments() -> None:
|
||||
assert "/bin/python" not in paths
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_ignores_urls() -> None:
|
||||
cmd = 'curl -s -o /dev/null -w "%{http_code}" https://www.google.com'
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
assert paths == ["/dev/null"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"command",
|
||||
[
|
||||
'curl -s -o /dev/null -w "%{http_code}" https://www.google.com',
|
||||
'wget -q -O - http://example.com 2>&1 | head -c 100',
|
||||
'python3 -c "import urllib.request; print(urllib.request.urlopen(\'http://example.com\').read()[:100])"',
|
||||
],
|
||||
)
|
||||
def test_exec_guard_allows_public_urls(tmp_path, command: str) -> None:
|
||||
tool = ExecTool(restrict_to_workspace=True)
|
||||
error = tool._guard_command(command, str(tmp_path))
|
||||
assert error is None
|
||||
|
||||
|
||||
def test_exec_guard_allows_whitelisted_internal_urls(tmp_path) -> None:
|
||||
configure_ssrf_whitelist(["10.10.10.0/24"])
|
||||
try:
|
||||
tool = ExecTool(restrict_to_workspace=True)
|
||||
error = tool._guard_command(
|
||||
'curl -s -H "Authorization: Bearer ..." http://10.10.10.3:8123/api/',
|
||||
str(tmp_path),
|
||||
)
|
||||
assert error is None
|
||||
finally:
|
||||
configure_ssrf_whitelist([])
|
||||
|
||||
|
||||
def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
|
||||
cmd = "cat /tmp/data.txt > /tmp/out.txt"
|
||||
paths = ExecTool._extract_absolute_paths(cmd)
|
||||
|
||||
@ -5,12 +5,13 @@ from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
from nanobot.utils.file_edit_events import (
|
||||
StreamingFileEditTracker,
|
||||
build_file_edit_end_event,
|
||||
build_file_edit_start_event,
|
||||
line_diff_stats,
|
||||
prepare_file_edit_tracker,
|
||||
prepare_file_edit_trackers,
|
||||
read_file_snapshot,
|
||||
StreamingFileEditTracker,
|
||||
)
|
||||
|
||||
|
||||
@ -81,6 +82,63 @@ def test_binary_file_is_reported_but_not_counted(tmp_path: Path) -> None:
|
||||
assert (event["added"], event["deleted"]) == (0, 0)
|
||||
|
||||
|
||||
def test_apply_patch_prepares_trackers_for_each_touched_file(tmp_path: Path) -> None:
|
||||
(tmp_path / "src").mkdir()
|
||||
existing = tmp_path / "src" / "existing.py"
|
||||
existing.write_text("old\nkeep\n", encoding="utf-8")
|
||||
delete_me = tmp_path / "src" / "delete_me.py"
|
||||
delete_me.write_text("gone\n", encoding="utf-8")
|
||||
|
||||
edits = [
|
||||
{"path": "src/new.py", "action": "add", "new_text": "fresh"},
|
||||
{"path": "src/existing.py", "action": "replace", "old_text": "old", "new_text": "new"},
|
||||
{"path": "src/delete_me.py", "action": "delete", "old_text": "gone\n"},
|
||||
]
|
||||
|
||||
trackers = prepare_file_edit_trackers(
|
||||
call_id="call-patch",
|
||||
tool_name="apply_patch",
|
||||
tool=None,
|
||||
workspace=tmp_path,
|
||||
params={"edits": edits},
|
||||
)
|
||||
|
||||
assert [tracker.display_path for tracker in trackers] == [
|
||||
"src/new.py",
|
||||
"src/existing.py",
|
||||
"src/delete_me.py",
|
||||
]
|
||||
|
||||
(tmp_path / "src" / "new.py").write_text("fresh\n", encoding="utf-8")
|
||||
existing.write_text("new\nkeep\n", encoding="utf-8")
|
||||
delete_me.unlink()
|
||||
|
||||
events = [build_file_edit_end_event(tracker, {"edits": edits}) for tracker in trackers]
|
||||
by_path = {event["path"]: event for event in events}
|
||||
assert (by_path["src/new.py"]["added"], by_path["src/new.py"]["deleted"]) == (1, 0)
|
||||
assert (by_path["src/existing.py"]["added"], by_path["src/existing.py"]["deleted"]) == (1, 1)
|
||||
assert (by_path["src/delete_me.py"]["added"], by_path["src/delete_me.py"]["deleted"]) == (0, 1)
|
||||
|
||||
|
||||
def test_apply_patch_dry_run_does_not_prepare_file_edit_trackers(tmp_path: Path) -> None:
|
||||
(tmp_path / "file.txt").write_text("old\n", encoding="utf-8")
|
||||
|
||||
trackers = prepare_file_edit_trackers(
|
||||
call_id="call-patch",
|
||||
tool_name="apply_patch",
|
||||
tool=None,
|
||||
workspace=tmp_path,
|
||||
params={
|
||||
"dry_run": True,
|
||||
"edits": [
|
||||
{"path": "file.txt", "action": "replace", "old_text": "old", "new_text": "new"}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
assert trackers == []
|
||||
|
||||
|
||||
def test_oversized_write_file_end_uses_known_content_for_exact_count(tmp_path: Path) -> None:
|
||||
target = tmp_path / "large.txt"
|
||||
params = {"path": "large.txt", "content": "x" * (2 * 1024 * 1024 + 1)}
|
||||
@ -140,6 +198,58 @@ def test_streaming_write_file_tracker_emits_live_line_counts(tmp_path: Path) ->
|
||||
assert events[-1]["deleted"] == 0
|
||||
|
||||
|
||||
def test_streaming_apply_patch_tracker_emits_live_counts_per_file(tmp_path: Path) -> None:
|
||||
(tmp_path / "src").mkdir()
|
||||
(tmp_path / "src" / "existing.py").write_text("old\nkeep\n", encoding="utf-8")
|
||||
events: list[dict] = []
|
||||
|
||||
async def emit(batch: list[dict]) -> None:
|
||||
events.extend(batch)
|
||||
|
||||
async def run() -> None:
|
||||
tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit)
|
||||
await tracker.update({
|
||||
"index": 0,
|
||||
"call_id": "call-patch",
|
||||
"name": "apply_patch",
|
||||
"arguments_delta": (
|
||||
'{"edits":[{"path":"src/existing.py","action":"replace","old_text":"old","new_text":"new"}'
|
||||
',{"path":"src/new.py","action":"add","new_text":"fresh"}]}'
|
||||
),
|
||||
})
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
by_path = {event["path"]: event for event in events}
|
||||
assert by_path["src/existing.py"]["tool"] == "apply_patch"
|
||||
assert by_path["src/existing.py"]["status"] == "editing"
|
||||
assert by_path["src/existing.py"]["approximate"] is True
|
||||
assert (by_path["src/existing.py"]["added"], by_path["src/existing.py"]["deleted"]) == (1, 1)
|
||||
assert (by_path["src/new.py"]["added"], by_path["src/new.py"]["deleted"]) == (1, 0)
|
||||
|
||||
|
||||
def test_streaming_apply_patch_tracker_skips_dry_run(tmp_path: Path) -> None:
|
||||
events: list[dict] = []
|
||||
|
||||
async def emit(batch: list[dict]) -> None:
|
||||
events.extend(batch)
|
||||
|
||||
async def run() -> None:
|
||||
tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit)
|
||||
await tracker.update({
|
||||
"index": 0,
|
||||
"call_id": "call-patch",
|
||||
"name": "apply_patch",
|
||||
"arguments_delta": (
|
||||
'{"dry_run":true,"edits":[{"path":"dry.md","action":"add","new_text":"preview"}]}'
|
||||
),
|
||||
})
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
assert events == []
|
||||
|
||||
|
||||
def test_streaming_write_file_tracker_emits_pending_before_path(tmp_path: Path) -> None:
|
||||
events: list[dict] = []
|
||||
|
||||
@ -308,6 +418,43 @@ def test_streaming_tracker_applies_canonical_call_id_to_final_tool(tmp_path: Pat
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
def test_streaming_tracker_does_not_restore_duplicate_canonical_ids(tmp_path: Path) -> None:
|
||||
events: list[dict] = []
|
||||
|
||||
async def emit(batch: list[dict]) -> None:
|
||||
events.extend(batch)
|
||||
|
||||
async def run() -> None:
|
||||
tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit)
|
||||
await tracker.update({
|
||||
"index": 0,
|
||||
"call_id": "call_dup",
|
||||
"name": "write_file",
|
||||
"arguments_delta": '{"path":"a.md","content":"one\\n"}',
|
||||
})
|
||||
await tracker.update({
|
||||
"index": 1,
|
||||
"call_id": "call_dup",
|
||||
"name": "write_file",
|
||||
"arguments_delta": '{"path":"b.md","content":"two\\n"}',
|
||||
})
|
||||
final_a = SimpleNamespace(
|
||||
id="call_dup",
|
||||
name="write_file",
|
||||
arguments={"path": "a.md", "content": "one\n"},
|
||||
)
|
||||
final_b = SimpleNamespace(
|
||||
id="call_unique",
|
||||
name="write_file",
|
||||
arguments={"path": "b.md", "content": "two\n"},
|
||||
)
|
||||
tracker.apply_final_call_ids([final_a, final_b])
|
||||
assert final_a.id == "call_dup"
|
||||
assert final_b.id == "call_unique"
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
def test_streaming_edit_file_tracker_flushes_small_pending_count(tmp_path: Path) -> None:
|
||||
target = tmp_path / "small.py"
|
||||
target.write_text("old\n", encoding="utf-8")
|
||||
|
||||
@ -43,6 +43,7 @@ const SIDEBAR_STORAGE_KEY = "nanobot-webui.sidebar";
|
||||
const COMPLETED_RUNS_STORAGE_KEY = "nanobot-webui.sidebar.completed-runs.v1";
|
||||
const RESTART_STARTED_KEY = "nanobot-webui.restartStartedAt";
|
||||
const SIDEBAR_WIDTH = 272;
|
||||
const SIDEBAR_RAIL_WIDTH = 56;
|
||||
const TOKEN_REFRESH_MARGIN_MS = 30_000;
|
||||
const TOKEN_REFRESH_MIN_DELAY_MS = 5_000;
|
||||
type ShellView = "chat" | "settings";
|
||||
@ -411,6 +412,10 @@ function Shell({
|
||||
setDesktopSidebarOpen(false);
|
||||
}, []);
|
||||
|
||||
const openDesktopSidebar = useCallback(() => {
|
||||
setDesktopSidebarOpen(true);
|
||||
}, []);
|
||||
|
||||
const closeMobileSidebar = useCallback(() => {
|
||||
setMobileSidebarOpen(false);
|
||||
}, []);
|
||||
@ -560,6 +565,21 @@ function Shell({
|
||||
setSessionSearchOpen(true);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
const handleKeyDown = (event: globalThis.KeyboardEvent) => {
|
||||
if (event.defaultPrevented) return;
|
||||
const plainCommandK =
|
||||
(event.metaKey || event.ctrlKey) && !event.altKey && !event.shiftKey;
|
||||
if (!plainCommandK) return;
|
||||
if (event.key.toLowerCase() !== "k") return;
|
||||
event.preventDefault();
|
||||
onOpenSessionSearch();
|
||||
};
|
||||
|
||||
window.addEventListener("keydown", handleKeyDown);
|
||||
return () => window.removeEventListener("keydown", handleKeyDown);
|
||||
}, [onOpenSessionSearch]);
|
||||
|
||||
const onSelectSearchResult = useCallback(
|
||||
(key: string) => {
|
||||
setSessionSearchOpen(false);
|
||||
@ -732,17 +752,19 @@ function Shell({
|
||||
"relative z-20 hidden shrink-0 overflow-hidden lg:block",
|
||||
"transition-[width] duration-300 ease-out",
|
||||
)}
|
||||
style={{ width: desktopSidebarOpen ? SIDEBAR_WIDTH : 0 }}
|
||||
style={{
|
||||
width: desktopSidebarOpen ? SIDEBAR_WIDTH : SIDEBAR_RAIL_WIDTH,
|
||||
}}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
"absolute inset-y-0 left-0 h-full overflow-hidden bg-sidebar shadow-inner-right",
|
||||
"transition-transform duration-300 ease-out",
|
||||
desktopSidebarOpen ? "translate-x-0" : "-translate-x-full",
|
||||
)}
|
||||
style={{ width: SIDEBAR_WIDTH }}
|
||||
className="absolute inset-y-0 left-0 h-full w-full overflow-hidden bg-sidebar shadow-inner-right"
|
||||
>
|
||||
<Sidebar {...sidebarProps} onCollapse={closeDesktopSidebar} />
|
||||
<Sidebar
|
||||
{...sidebarProps}
|
||||
collapsed={!desktopSidebarOpen}
|
||||
onCollapse={closeDesktopSidebar}
|
||||
onExpand={openDesktopSidebar}
|
||||
/>
|
||||
</div>
|
||||
</aside>
|
||||
) : null}
|
||||
@ -769,17 +791,15 @@ function Shell({
|
||||
</Sheet>
|
||||
) : null}
|
||||
|
||||
{showMainSidebar ? (
|
||||
<SessionSearchDialog
|
||||
open={sessionSearchOpen}
|
||||
onOpenChange={setSessionSearchOpen}
|
||||
sessions={sessions}
|
||||
activeKey={activeKey}
|
||||
loading={loading}
|
||||
titleOverrides={sidebarState.title_overrides}
|
||||
onSelect={onSelectSearchResult}
|
||||
/>
|
||||
) : null}
|
||||
<SessionSearchDialog
|
||||
open={sessionSearchOpen}
|
||||
onOpenChange={setSessionSearchOpen}
|
||||
sessions={sessions}
|
||||
activeKey={activeKey}
|
||||
loading={loading}
|
||||
titleOverrides={sidebarState.title_overrides}
|
||||
onSelect={onSelectSearchResult}
|
||||
/>
|
||||
|
||||
<main className="relative flex h-full min-w-0 flex-1 flex-col">
|
||||
<div
|
||||
@ -797,7 +817,7 @@ function Shell({
|
||||
onTurnEnd={onTurnEnd}
|
||||
theme={theme}
|
||||
onToggleTheme={toggle}
|
||||
hideSidebarToggleOnDesktop={desktopSidebarOpen}
|
||||
hideSidebarToggleOnDesktop
|
||||
/>
|
||||
</div>
|
||||
{view === "settings" && (
|
||||
|
||||
@ -1,3 +1,9 @@
|
||||
import {
|
||||
memo,
|
||||
useEffect,
|
||||
useMemo,
|
||||
useState,
|
||||
} from "react";
|
||||
import {
|
||||
Archive,
|
||||
ArchiveRestore,
|
||||
@ -19,6 +25,9 @@ import { deriveTitle, relativeTime } from "@/lib/format";
|
||||
import { cn } from "@/lib/utils";
|
||||
import type { ChatSummary, SidebarDensity, SidebarSortMode } from "@/lib/types";
|
||||
|
||||
const INITIAL_VISIBLE_SESSIONS = 160;
|
||||
const VISIBLE_SESSIONS_INCREMENT = 160;
|
||||
|
||||
interface ChatListProps {
|
||||
sessions: ChatSummary[];
|
||||
activeKey: string | null;
|
||||
@ -42,7 +51,7 @@ interface ChatListProps {
|
||||
emptyLabel?: string;
|
||||
}
|
||||
|
||||
export function ChatList({
|
||||
export const ChatList = memo(function ChatList({
|
||||
sessions,
|
||||
activeKey,
|
||||
onSelect,
|
||||
@ -65,6 +74,52 @@ export function ChatList({
|
||||
emptyLabel,
|
||||
}: ChatListProps) {
|
||||
const { t } = useTranslation();
|
||||
const [visibleLimit, setVisibleLimit] = useState(INITIAL_VISIBLE_SESSIONS);
|
||||
const labels = useMemo(() => ({
|
||||
pinned: t("chat.groups.pinned"),
|
||||
all: t("chat.groups.all"),
|
||||
today: t("chat.groups.today"),
|
||||
yesterday: t("chat.groups.yesterday"),
|
||||
earlier: t("chat.groups.earlier"),
|
||||
archived: t("chat.groups.archived"),
|
||||
fallbackTitle: t("chat.newChat"),
|
||||
}), [t]);
|
||||
const groups = useMemo(
|
||||
() => groupSessions(sessions, labels, {
|
||||
pinnedKeys,
|
||||
archivedKeys,
|
||||
titleOverrides,
|
||||
showArchived,
|
||||
sort,
|
||||
}),
|
||||
[
|
||||
archivedKeys,
|
||||
labels,
|
||||
pinnedKeys,
|
||||
sessions,
|
||||
showArchived,
|
||||
sort,
|
||||
titleOverrides,
|
||||
],
|
||||
);
|
||||
const limitedGroups = useMemo(
|
||||
() => limitGroups(groups, visibleLimit, activeKey),
|
||||
[activeKey, groups, visibleLimit],
|
||||
);
|
||||
const totalSessionCount = useMemo(
|
||||
() => groups.reduce((total, group) => total + group.sessions.length, 0),
|
||||
[groups],
|
||||
);
|
||||
const visibleSessionCount = useMemo(
|
||||
() => limitedGroups.reduce((total, group) => total + group.sessions.length, 0),
|
||||
[limitedGroups],
|
||||
);
|
||||
const hiddenSessionCount = Math.max(0, totalSessionCount - visibleSessionCount);
|
||||
|
||||
useEffect(() => {
|
||||
setVisibleLimit(INITIAL_VISIBLE_SESSIONS);
|
||||
}, [showArchived, sort]);
|
||||
|
||||
if (loading && sessions.length === 0) {
|
||||
return (
|
||||
<div className="px-3 py-6 text-[12px] text-muted-foreground">
|
||||
@ -81,21 +136,6 @@ export function ChatList({
|
||||
);
|
||||
}
|
||||
|
||||
const groups = groupSessions(sessions, {
|
||||
pinned: t("chat.groups.pinned"),
|
||||
all: t("chat.groups.all"),
|
||||
today: t("chat.groups.today"),
|
||||
yesterday: t("chat.groups.yesterday"),
|
||||
earlier: t("chat.groups.earlier"),
|
||||
archived: t("chat.groups.archived"),
|
||||
fallbackTitle: t("chat.newChat"),
|
||||
}, {
|
||||
pinnedKeys,
|
||||
archivedKeys,
|
||||
titleOverrides,
|
||||
showArchived,
|
||||
sort,
|
||||
});
|
||||
const pinned = new Set(pinnedKeys);
|
||||
const archived = new Set(archivedKeys);
|
||||
const running = new Set(runningChatIds);
|
||||
@ -105,7 +145,7 @@ export function ChatList({
|
||||
return (
|
||||
<div className="h-full min-h-0 min-w-0 overflow-x-hidden overflow-y-auto overscroll-contain">
|
||||
<div className="min-w-0 space-y-3 px-2 py-1.5">
|
||||
{groups.map((group) => (
|
||||
{limitedGroups.map((group) => (
|
||||
<section key={group.label} aria-label={group.label}>
|
||||
<div className="px-2 pb-1 text-[12px] font-medium text-muted-foreground/65">
|
||||
{group.label}
|
||||
@ -228,10 +268,25 @@ export function ChatList({
|
||||
</ul>
|
||||
</section>
|
||||
))}
|
||||
{hiddenSessionCount > 0 ? (
|
||||
<div className="px-2 pb-2 pt-1">
|
||||
<button
|
||||
type="button"
|
||||
onClick={() =>
|
||||
setVisibleLimit((limit) =>
|
||||
Math.min(totalSessionCount, limit + VISIBLE_SESSIONS_INCREMENT),
|
||||
)
|
||||
}
|
||||
className="h-8 w-full rounded-full text-[12px] font-medium text-muted-foreground transition-colors hover:bg-sidebar-accent/65 hover:text-sidebar-foreground"
|
||||
>
|
||||
{t("chat.showMore", { count: hiddenSessionCount })}
|
||||
</button>
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
function SessionActivityIndicator({
|
||||
state,
|
||||
@ -366,6 +421,45 @@ function groupSessions(
|
||||
return groups;
|
||||
}
|
||||
|
||||
function limitGroups(
|
||||
groups: Array<{ label: string; sessions: ChatSummary[] }>,
|
||||
limit: number,
|
||||
activeKey: string | null,
|
||||
): Array<{ label: string; sessions: ChatSummary[] }> {
|
||||
let remaining = Math.max(0, limit);
|
||||
let activeVisible = !activeKey;
|
||||
const out: Array<{ label: string; sessions: ChatSummary[] }> = [];
|
||||
|
||||
for (const group of groups) {
|
||||
const visible = remaining > 0
|
||||
? group.sessions.slice(0, remaining)
|
||||
: [];
|
||||
remaining -= visible.length;
|
||||
if (activeKey && visible.some((session) => session.key === activeKey)) {
|
||||
activeVisible = true;
|
||||
}
|
||||
if (visible.length > 0) {
|
||||
out.push({ label: group.label, sessions: visible });
|
||||
}
|
||||
}
|
||||
|
||||
if (activeVisible || !activeKey) return out;
|
||||
|
||||
for (const group of groups) {
|
||||
const active = group.sessions.find((session) => session.key === activeKey);
|
||||
if (!active) continue;
|
||||
const existing = out.find((item) => item.label === group.label);
|
||||
if (existing) {
|
||||
existing.sessions = [...existing.sessions, active];
|
||||
} else {
|
||||
out.push({ label: group.label, sessions: [active] });
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
function sortSessions(
|
||||
sessions: ChatSummary[],
|
||||
sort: SidebarSortMode,
|
||||
|
||||
@ -37,13 +37,16 @@ export function SessionSearchDialog({
|
||||
const [highlightedIndex, setHighlightedIndex] = useState(0);
|
||||
|
||||
const normalizedQuery = query.trim().toLowerCase();
|
||||
const results = useMemo(() => {
|
||||
const sessionResults = useMemo(() => {
|
||||
if (!open) return [];
|
||||
if (!normalizedQuery) return sessions;
|
||||
const terms = normalizedQuery.split(/\s+/).filter(Boolean);
|
||||
return sessions.filter((session) =>
|
||||
sessionMatchesTerms(session, terms, titleOverrides[session.key]),
|
||||
);
|
||||
}, [normalizedQuery, sessions, titleOverrides]);
|
||||
}, [normalizedQuery, open, sessions, titleOverrides]);
|
||||
const itemCount = sessionResults.length;
|
||||
const shortcutLabel = useMemo(getSearchShortcutLabel, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (!open) return;
|
||||
@ -58,9 +61,9 @@ export function SessionSearchDialog({
|
||||
|
||||
useEffect(() => {
|
||||
setHighlightedIndex((index) =>
|
||||
results.length === 0 ? 0 : Math.min(index, results.length - 1),
|
||||
itemCount === 0 ? 0 : Math.min(index, itemCount - 1),
|
||||
);
|
||||
}, [results.length]);
|
||||
}, [itemCount]);
|
||||
|
||||
const handleSelect = (key: string) => {
|
||||
onOpenChange(false);
|
||||
@ -71,17 +74,19 @@ export function SessionSearchDialog({
|
||||
if (event.key === "ArrowDown") {
|
||||
event.preventDefault();
|
||||
setHighlightedIndex((index) =>
|
||||
results.length === 0 ? 0 : Math.min(index + 1, results.length - 1),
|
||||
itemCount === 0 ? 0 : (index + 1) % itemCount,
|
||||
);
|
||||
return;
|
||||
}
|
||||
if (event.key === "ArrowUp") {
|
||||
event.preventDefault();
|
||||
setHighlightedIndex((index) => Math.max(index - 1, 0));
|
||||
setHighlightedIndex((index) =>
|
||||
itemCount === 0 ? 0 : (index - 1 + itemCount) % itemCount,
|
||||
);
|
||||
return;
|
||||
}
|
||||
if (event.key === "Enter") {
|
||||
const highlighted = results[highlightedIndex];
|
||||
const highlighted = sessionResults[highlightedIndex];
|
||||
if (!highlighted) return;
|
||||
event.preventDefault();
|
||||
handleSelect(highlighted.key);
|
||||
@ -125,70 +130,75 @@ export function SessionSearchDialog({
|
||||
aria-label={t("sidebar.searchAria")}
|
||||
className="h-full min-w-0 flex-1 bg-transparent text-[15px] font-medium text-foreground outline-none placeholder:text-muted-foreground/75"
|
||||
/>
|
||||
<kbd className="hidden h-6 shrink-0 items-center rounded-md border border-border/70 bg-muted/60 px-2 text-[11px] font-medium text-muted-foreground sm:inline-flex">
|
||||
{shortcutLabel}
|
||||
</kbd>
|
||||
</div>
|
||||
|
||||
<div className="min-h-0 overflow-y-auto overscroll-contain p-2">
|
||||
<div className="px-2 pb-1.5 pt-1 text-[12px] font-medium text-muted-foreground/70">
|
||||
{sectionLabel}
|
||||
</div>
|
||||
<section>
|
||||
<div className="px-2 pb-1.5 pt-1 text-[12px] font-medium text-muted-foreground/70">
|
||||
{sectionLabel}
|
||||
</div>
|
||||
|
||||
{loading && sessions.length === 0 ? (
|
||||
<div className="px-3 py-7 text-[13px] text-muted-foreground">
|
||||
{t("chat.loading")}
|
||||
</div>
|
||||
) : results.length === 0 ? (
|
||||
<div className="px-3 py-7 text-[13px] text-muted-foreground">
|
||||
{emptyLabel}
|
||||
</div>
|
||||
) : (
|
||||
<ul className="space-y-1">
|
||||
{results.map((session, index) => {
|
||||
const title = titleOverrides[session.key]?.trim() ||
|
||||
session.title?.trim() ||
|
||||
deriveTitle(session.preview, t("chat.newChat"));
|
||||
const preview = session.preview.trim();
|
||||
const showPreview =
|
||||
preview.length > 0 &&
|
||||
preview.toLowerCase() !== title.trim().toLowerCase();
|
||||
const highlighted = index === highlightedIndex;
|
||||
const active = session.key === activeKey;
|
||||
return (
|
||||
<li key={session.key}>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => handleSelect(session.key)}
|
||||
onMouseEnter={() => setHighlightedIndex(index)}
|
||||
aria-current={active ? "page" : undefined}
|
||||
className={cn(
|
||||
"flex min-h-12 w-full min-w-0 rounded-xl px-3 py-2.5 text-left transition-colors",
|
||||
highlighted
|
||||
? "bg-accent text-accent-foreground"
|
||||
: "text-popover-foreground hover:bg-accent/75 hover:text-accent-foreground",
|
||||
)}
|
||||
>
|
||||
<span className="min-w-0 flex-1">
|
||||
<span className="block truncate text-[14px] font-medium leading-5">
|
||||
{title}
|
||||
</span>
|
||||
{showPreview ? (
|
||||
<span
|
||||
className={cn(
|
||||
"block truncate text-[12px] leading-4",
|
||||
highlighted
|
||||
? "text-accent-foreground/70"
|
||||
: "text-muted-foreground",
|
||||
)}
|
||||
>
|
||||
{preview}
|
||||
{loading && sessions.length === 0 ? (
|
||||
<div className="px-3 py-7 text-[13px] text-muted-foreground">
|
||||
{t("chat.loading")}
|
||||
</div>
|
||||
) : sessionResults.length === 0 ? (
|
||||
<div className="px-3 py-7 text-[13px] text-muted-foreground">
|
||||
{emptyLabel}
|
||||
</div>
|
||||
) : (
|
||||
<ul className="space-y-1">
|
||||
{sessionResults.map((session, index) => {
|
||||
const title = titleOverrides[session.key]?.trim() ||
|
||||
session.title?.trim() ||
|
||||
deriveTitle(session.preview, t("chat.newChat"));
|
||||
const preview = session.preview.trim();
|
||||
const showPreview =
|
||||
preview.length > 0 &&
|
||||
preview.toLowerCase() !== title.trim().toLowerCase();
|
||||
const highlighted = index === highlightedIndex;
|
||||
const active = session.key === activeKey;
|
||||
return (
|
||||
<li key={session.key}>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => handleSelect(session.key)}
|
||||
onMouseEnter={() => setHighlightedIndex(index)}
|
||||
aria-current={active ? "page" : undefined}
|
||||
className={cn(
|
||||
"flex min-h-12 w-full min-w-0 rounded-xl px-3 py-2.5 text-left transition-colors",
|
||||
highlighted
|
||||
? "bg-accent text-accent-foreground"
|
||||
: "text-popover-foreground hover:bg-accent/75 hover:text-accent-foreground",
|
||||
)}
|
||||
>
|
||||
<span className="min-w-0 flex-1">
|
||||
<span className="block truncate text-[14px] font-medium leading-5">
|
||||
{title}
|
||||
</span>
|
||||
) : null}
|
||||
</span>
|
||||
</button>
|
||||
</li>
|
||||
);
|
||||
})}
|
||||
</ul>
|
||||
)}
|
||||
{showPreview ? (
|
||||
<span
|
||||
className={cn(
|
||||
"block truncate text-[12px] leading-4",
|
||||
highlighted
|
||||
? "text-accent-foreground/70"
|
||||
: "text-muted-foreground",
|
||||
)}
|
||||
>
|
||||
{preview}
|
||||
</span>
|
||||
) : null}
|
||||
</span>
|
||||
</button>
|
||||
</li>
|
||||
);
|
||||
})}
|
||||
</ul>
|
||||
)}
|
||||
</section>
|
||||
</div>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
@ -211,3 +221,13 @@ function sessionMatchesTerms(
|
||||
|
||||
return terms.every((term) => haystack.includes(term));
|
||||
}
|
||||
|
||||
function getSearchShortcutLabel() {
|
||||
if (typeof navigator === "undefined") return "Ctrl K";
|
||||
const platform = navigator.platform.toLowerCase();
|
||||
const apple =
|
||||
platform.includes("mac") ||
|
||||
platform.includes("iphone") ||
|
||||
platform.includes("ipad");
|
||||
return apple ? "⌘K" : "Ctrl K";
|
||||
}
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { useState } from "react";
|
||||
import { useState, type ReactNode } from "react";
|
||||
import {
|
||||
Archive,
|
||||
ListFilter,
|
||||
@ -28,6 +28,7 @@ import type {
|
||||
SidebarSortMode,
|
||||
SidebarViewState,
|
||||
} from "@/lib/types";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface SidebarProps {
|
||||
sessions: ChatSummary[];
|
||||
@ -44,7 +45,9 @@ interface SidebarProps {
|
||||
onToggleArchived: () => void;
|
||||
onUpdateView: (view: Partial<SidebarViewState>) => void;
|
||||
onCollapse: () => void;
|
||||
onExpand?: () => void;
|
||||
containActionMenus?: boolean;
|
||||
collapsed?: boolean;
|
||||
pinnedKeys?: string[];
|
||||
archivedKeys?: string[];
|
||||
titleOverrides?: Record<string, string>;
|
||||
@ -59,6 +62,8 @@ export function Sidebar(props: SidebarProps) {
|
||||
const { t } = useTranslation();
|
||||
const [menuPortalContainer, setMenuPortalContainer] =
|
||||
useState<HTMLElement | null>(null);
|
||||
const collapsed = Boolean(props.collapsed);
|
||||
const toggleLabel = t("thread.header.toggleSidebar");
|
||||
|
||||
return (
|
||||
<nav
|
||||
@ -66,108 +71,189 @@ export function Sidebar(props: SidebarProps) {
|
||||
aria-label={t("sidebar.navigation")}
|
||||
className="flex h-full w-full min-w-0 flex-col border-r border-sidebar-border/60 bg-sidebar text-sidebar-foreground"
|
||||
>
|
||||
<div className="flex items-center justify-between px-3 pb-2.5 pt-3">
|
||||
<picture className="block min-w-0">
|
||||
<source srcSet="/brand/nanobot_logo.webp" type="image/webp" />
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center px-3 pb-2.5 pt-3",
|
||||
collapsed ? "w-14 justify-start" : "justify-between",
|
||||
)}
|
||||
>
|
||||
<button
|
||||
type="button"
|
||||
aria-label={collapsed ? toggleLabel : undefined}
|
||||
aria-hidden={collapsed ? undefined : true}
|
||||
title={collapsed ? toggleLabel : undefined}
|
||||
onClick={collapsed ? props.onExpand : undefined}
|
||||
tabIndex={collapsed ? 0 : -1}
|
||||
className={cn(
|
||||
"flex h-9 w-9 shrink-0 items-center justify-center overflow-hidden rounded-xl transition-colors",
|
||||
collapsed
|
||||
? "-ml-0.5 hover:bg-sidebar-accent/75"
|
||||
: "pointer-events-none -ml-0.5",
|
||||
)}
|
||||
>
|
||||
<img
|
||||
src="/brand/nanobot_logo.png"
|
||||
alt="nanobot"
|
||||
className="h-6 w-auto select-none object-contain opacity-95"
|
||||
src="/brand/nanobot_icon.png"
|
||||
alt=""
|
||||
className="h-8 w-8 select-none object-contain"
|
||||
draggable={false}
|
||||
/>
|
||||
</picture>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
aria-label={t("sidebar.collapse")}
|
||||
onClick={props.onCollapse}
|
||||
className="h-7 w-7 rounded-lg text-muted-foreground/85 hover:bg-sidebar-accent/75 hover:text-sidebar-foreground"
|
||||
>
|
||||
<Menu className="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
</button>
|
||||
{!collapsed && (
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
aria-label={t("sidebar.collapse")}
|
||||
onClick={props.onCollapse}
|
||||
className="h-7 w-7 rounded-lg text-muted-foreground/85 hover:bg-sidebar-accent/75 hover:text-sidebar-foreground"
|
||||
>
|
||||
<Menu className="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="space-y-1.5 px-2 pb-2">
|
||||
<Button
|
||||
<div
|
||||
className={cn(
|
||||
"space-y-1.5 px-2 pb-2",
|
||||
collapsed && "flex w-14 flex-col items-center px-0",
|
||||
)}
|
||||
>
|
||||
<SidebarActionButton
|
||||
collapsed={collapsed}
|
||||
label={t("sidebar.newChat")}
|
||||
onClick={props.onNewChat}
|
||||
className="h-8 w-full justify-start gap-2 rounded-full px-3 text-[12.5px] font-medium text-sidebar-foreground/92 hover:bg-sidebar-accent/75 hover:text-sidebar-foreground"
|
||||
variant="ghost"
|
||||
>
|
||||
<SquarePen className="h-3.5 w-3.5" />
|
||||
{t("sidebar.newChat")}
|
||||
</Button>
|
||||
<Button
|
||||
type="button"
|
||||
icon={<SquarePen className="h-4 w-4" />}
|
||||
/>
|
||||
<SidebarActionButton
|
||||
collapsed={collapsed}
|
||||
label={t("sidebar.searchAria")}
|
||||
onClick={props.onOpenSearch}
|
||||
className="h-8 w-full justify-start gap-2 rounded-full px-3 text-[12.5px] font-medium text-sidebar-foreground/85 hover:bg-sidebar-accent/75 hover:text-sidebar-foreground"
|
||||
variant="ghost"
|
||||
>
|
||||
<Search className="h-3.5 w-3.5" aria-hidden />
|
||||
{t("sidebar.searchAria")}
|
||||
</Button>
|
||||
icon={<Search className="h-4 w-4" />}
|
||||
/>
|
||||
<SidebarViewMenu
|
||||
compact={collapsed}
|
||||
view={props.viewState}
|
||||
onUpdateView={props.onUpdateView}
|
||||
/>
|
||||
{props.archivedCount ? (
|
||||
<Button
|
||||
type="button"
|
||||
<SidebarActionButton
|
||||
collapsed={collapsed}
|
||||
label={props.showArchived ? t("chat.hideArchived") : t("chat.showArchived")}
|
||||
onClick={props.onToggleArchived}
|
||||
className="h-8 w-full justify-start gap-2 rounded-full px-3 text-[12.5px] font-medium text-sidebar-foreground/75 hover:bg-sidebar-accent/75 hover:text-sidebar-foreground"
|
||||
variant="ghost"
|
||||
>
|
||||
<Archive className="h-3.5 w-3.5" aria-hidden />
|
||||
{props.showArchived ? t("chat.hideArchived") : t("chat.showArchived")}
|
||||
</Button>
|
||||
icon={<Archive className="h-4 w-4" />}
|
||||
/>
|
||||
) : null}
|
||||
</div>
|
||||
<div className="flex min-h-0 min-w-0 flex-1 flex-col overflow-hidden">
|
||||
<ChatList
|
||||
sessions={props.sessions}
|
||||
activeKey={props.activeKey}
|
||||
loading={props.loading}
|
||||
emptyLabel={t("chat.noSessions")}
|
||||
onSelect={props.onSelect}
|
||||
onRequestDelete={props.onRequestDelete}
|
||||
onTogglePin={props.onTogglePin}
|
||||
onRequestRename={props.onRequestRename}
|
||||
onToggleArchive={props.onToggleArchive}
|
||||
pinnedKeys={props.pinnedKeys}
|
||||
archivedKeys={props.archivedKeys}
|
||||
titleOverrides={props.titleOverrides}
|
||||
runningChatIds={props.runningChatIds}
|
||||
completedChatIds={props.completedChatIds}
|
||||
density={props.viewState?.density}
|
||||
showPreviews={props.viewState?.show_previews}
|
||||
showTimestamps={props.viewState?.show_timestamps}
|
||||
sort={props.viewState?.sort}
|
||||
showArchived={props.showArchived}
|
||||
actionMenuPortalContainer={
|
||||
props.containActionMenus ? menuPortalContainer : undefined
|
||||
}
|
||||
/>
|
||||
<div
|
||||
className={cn(
|
||||
"flex min-h-0 min-w-0 flex-1 flex-col overflow-hidden transition-opacity duration-200",
|
||||
collapsed && "pointer-events-none opacity-0",
|
||||
)}
|
||||
>
|
||||
{!collapsed && (
|
||||
<ChatList
|
||||
sessions={props.sessions}
|
||||
activeKey={props.activeKey}
|
||||
loading={props.loading}
|
||||
emptyLabel={t("chat.noSessions")}
|
||||
onSelect={props.onSelect}
|
||||
onRequestDelete={props.onRequestDelete}
|
||||
onTogglePin={props.onTogglePin}
|
||||
onRequestRename={props.onRequestRename}
|
||||
onToggleArchive={props.onToggleArchive}
|
||||
pinnedKeys={props.pinnedKeys}
|
||||
archivedKeys={props.archivedKeys}
|
||||
titleOverrides={props.titleOverrides}
|
||||
runningChatIds={props.runningChatIds}
|
||||
completedChatIds={props.completedChatIds}
|
||||
density={props.viewState?.density}
|
||||
showPreviews={props.viewState?.show_previews}
|
||||
showTimestamps={props.viewState?.show_timestamps}
|
||||
sort={props.viewState?.sort}
|
||||
showArchived={props.showArchived}
|
||||
actionMenuPortalContainer={
|
||||
props.containActionMenus ? menuPortalContainer : undefined
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<Separator className="bg-sidebar-border/50" />
|
||||
<div className="flex items-center gap-1 px-2.5 py-2.5 text-xs">
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
<div
|
||||
className={cn(
|
||||
"flex items-center gap-1 px-2.5 py-2.5 text-xs",
|
||||
collapsed && "w-14 flex-col px-0",
|
||||
)}
|
||||
>
|
||||
<SidebarActionButton
|
||||
collapsed={collapsed}
|
||||
label={t("sidebar.settings")}
|
||||
onClick={props.onOpenSettings}
|
||||
className="h-8 min-w-0 flex-1 justify-start gap-2 rounded-full px-2.5 text-[12.5px] font-medium text-sidebar-foreground/85 hover:bg-sidebar-accent/75 hover:text-sidebar-foreground"
|
||||
>
|
||||
<Settings className="h-3.5 w-3.5" aria-hidden />
|
||||
{t("sidebar.settings")}
|
||||
</Button>
|
||||
className={collapsed ? undefined : "flex-1"}
|
||||
icon={<Settings className="h-4 w-4" />}
|
||||
/>
|
||||
<ConnectionBadge />
|
||||
</div>
|
||||
</nav>
|
||||
);
|
||||
}
|
||||
|
||||
function SidebarActionButton({
|
||||
collapsed,
|
||||
label,
|
||||
icon,
|
||||
onClick,
|
||||
className,
|
||||
}: {
|
||||
collapsed: boolean;
|
||||
label: string;
|
||||
icon: ReactNode;
|
||||
onClick: () => void;
|
||||
className?: string;
|
||||
}) {
|
||||
return (
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
aria-label={label}
|
||||
title={collapsed ? label : undefined}
|
||||
onClick={onClick}
|
||||
className={cn(
|
||||
"group h-8 min-w-0 gap-2 overflow-hidden rounded-full font-medium text-sidebar-foreground/85 hover:bg-sidebar-accent/75 hover:text-sidebar-foreground",
|
||||
"transition-[width,padding,border-radius,color,background-color] duration-300 ease-out",
|
||||
collapsed
|
||||
? "w-9 justify-center gap-0 rounded-xl px-0"
|
||||
: "w-full justify-start gap-2 px-3 text-[12.5px]",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
<span
|
||||
className={cn(
|
||||
"flex shrink-0 items-center justify-center transition-transform duration-300 ease-out",
|
||||
collapsed ? "translate-x-0" : "translate-x-0",
|
||||
)}
|
||||
aria-hidden
|
||||
>
|
||||
{icon}
|
||||
</span>
|
||||
<span
|
||||
className={cn(
|
||||
"min-w-0 overflow-hidden truncate whitespace-nowrap transition-[max-width,opacity,transform] duration-200 ease-out",
|
||||
collapsed
|
||||
? "max-w-0 -translate-x-1 opacity-0"
|
||||
: "max-w-[12rem] translate-x-0 opacity-100",
|
||||
)}
|
||||
>
|
||||
{label}
|
||||
</span>
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
|
||||
function SidebarViewMenu({
|
||||
compact = false,
|
||||
view,
|
||||
onUpdateView,
|
||||
}: {
|
||||
compact?: boolean;
|
||||
view?: SidebarViewState;
|
||||
onUpdateView: (view: Partial<SidebarViewState>) => void;
|
||||
}) {
|
||||
@ -182,11 +268,28 @@ function SidebarViewMenu({
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button
|
||||
type="button"
|
||||
className="h-8 w-full justify-start gap-2 rounded-full px-3 text-[12.5px] font-medium text-sidebar-foreground/75 hover:bg-sidebar-accent/75 hover:text-sidebar-foreground"
|
||||
aria-label={t("sidebar.viewOptions")}
|
||||
title={compact ? t("sidebar.viewOptions") : undefined}
|
||||
className={cn(
|
||||
"h-8 min-w-0 overflow-hidden font-medium text-sidebar-foreground/75 hover:bg-sidebar-accent/75 hover:text-sidebar-foreground",
|
||||
"transition-[width,padding,border-radius,color,background-color] duration-300 ease-out",
|
||||
compact
|
||||
? "w-9 justify-center gap-0 rounded-xl px-0"
|
||||
: "w-full justify-start gap-2 rounded-full px-3 text-[12.5px]",
|
||||
)}
|
||||
variant="ghost"
|
||||
>
|
||||
<ListFilter className="h-3.5 w-3.5" aria-hidden />
|
||||
{t("sidebar.viewOptions")}
|
||||
<ListFilter className="h-4 w-4 shrink-0" aria-hidden />
|
||||
<span
|
||||
className={cn(
|
||||
"min-w-0 overflow-hidden truncate whitespace-nowrap transition-[max-width,opacity,transform] duration-200 ease-out",
|
||||
compact
|
||||
? "max-w-0 -translate-x-1 opacity-0"
|
||||
: "max-w-[12rem] translate-x-0 opacity-100",
|
||||
)}
|
||||
>
|
||||
{t("sidebar.viewOptions")}
|
||||
</span>
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="start" className="w-52">
|
||||
|
||||
@ -27,6 +27,7 @@ interface ActivityCounts {
|
||||
fileCount: number;
|
||||
added: number;
|
||||
deleted: number;
|
||||
hasDiffStats: boolean;
|
||||
hasEditingFiles: boolean;
|
||||
hasFailedFiles: boolean;
|
||||
primaryFilePath?: string;
|
||||
@ -61,6 +62,7 @@ function countActivity(messages: UIMessage[], fileEdits: FileEditSummary[]): Act
|
||||
}
|
||||
let added = 0;
|
||||
let deleted = 0;
|
||||
let hasDiffStats = false;
|
||||
let hasEditingFiles = false;
|
||||
let failedFileCount = 0;
|
||||
let primaryFilePath: string | undefined;
|
||||
@ -77,6 +79,10 @@ function countActivity(messages: UIMessage[], fileEdits: FileEditSummary[]): Act
|
||||
if (edit.status === "error" || edit.binary) {
|
||||
continue;
|
||||
}
|
||||
if (!hasVisibleDiffStats(edit)) {
|
||||
continue;
|
||||
}
|
||||
hasDiffStats = true;
|
||||
added += edit.added;
|
||||
deleted += edit.deleted;
|
||||
}
|
||||
@ -86,6 +92,7 @@ function countActivity(messages: UIMessage[], fileEdits: FileEditSummary[]): Act
|
||||
fileCount: fileEdits.length,
|
||||
added,
|
||||
deleted,
|
||||
hasDiffStats,
|
||||
hasEditingFiles,
|
||||
hasFailedFiles: fileEdits.length > 0 && failedFileCount === fileEdits.length,
|
||||
primaryFilePath,
|
||||
@ -120,6 +127,7 @@ export function AgentActivityCluster({
|
||||
fileCount,
|
||||
added,
|
||||
deleted,
|
||||
hasDiffStats,
|
||||
hasEditingFiles,
|
||||
hasFailedFiles,
|
||||
primaryFilePath,
|
||||
@ -140,6 +148,7 @@ export function AgentActivityCluster({
|
||||
const headerBusy = fileCount > 0 ? hasEditingFiles : isTurnStreaming;
|
||||
const singleFilePath = fileCount === 1 ? primaryFilePath : undefined;
|
||||
const singleFileTooltipPath = fileCount === 1 ? primaryFileTooltipPath : undefined;
|
||||
const hasVisibleActivity = reasoningSteps > 0 || toolCalls > 0 || fileCount > 0;
|
||||
|
||||
const fileActivitySummary = fileCount > 0
|
||||
? hasPendingFileEdit && !singleFilePath
|
||||
@ -243,6 +252,8 @@ export function AgentActivityCluster({
|
||||
autoFollowActivityRef.current = distance < ACTIVITY_SCROLL_NEAR_BOTTOM_PX;
|
||||
}, []);
|
||||
|
||||
if (!hasVisibleActivity) return null;
|
||||
|
||||
return (
|
||||
<div className={cn("w-full", hasBodyBelow && "mb-2")}>
|
||||
<button
|
||||
@ -282,7 +293,7 @@ export function AgentActivityCluster({
|
||||
{summary}
|
||||
</StreamingLabelSheen>
|
||||
)}
|
||||
{fileCount > 0 && (
|
||||
{fileCount > 0 && hasDiffStats && (
|
||||
<span className="inline-flex min-w-0 items-center gap-1 text-muted-foreground/85">
|
||||
<DiffPair added={added} deleted={deleted} />
|
||||
</span>
|
||||
@ -435,6 +446,17 @@ function summarizeFileEdits(edits: UIFileEdit[], active: boolean): FileEditSumma
|
||||
summary.absolute_path = edit.absolute_path;
|
||||
}
|
||||
summary.pending = summary.pending || !!edit.pending || !edit.path;
|
||||
if (!edit.path && edit.pending) {
|
||||
if (active && edit.status === "editing") {
|
||||
summary.hasActiveEditing = true;
|
||||
summary.approximate = summary.approximate || !!edit.approximate;
|
||||
if (!edit.binary) {
|
||||
summary.added += edit.added;
|
||||
summary.deleted += edit.deleted;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (active && edit.status === "editing") {
|
||||
summary.hasActiveEditing = true;
|
||||
summary.binary = summary.binary || !!edit.binary;
|
||||
@ -461,8 +483,16 @@ function summarizeFileEdits(edits: UIFileEdit[], active: boolean): FileEditSumma
|
||||
}
|
||||
}
|
||||
|
||||
return order.map((key) => {
|
||||
return order.flatMap((key) => {
|
||||
const summary = byPath.get(key)!;
|
||||
if (
|
||||
!summary.path
|
||||
&& !summary.hasActiveEditing
|
||||
&& !summary.hasSuccessfulChange
|
||||
&& !summary.hasFailed
|
||||
) {
|
||||
return [];
|
||||
}
|
||||
const status: UIFileEdit["status"] = summary.hasActiveEditing
|
||||
? "editing"
|
||||
: summary.hasSuccessfulChange
|
||||
@ -470,7 +500,7 @@ function summarizeFileEdits(edits: UIFileEdit[], active: boolean): FileEditSumma
|
||||
: summary.hasFailed
|
||||
? "error"
|
||||
: "done";
|
||||
return {
|
||||
return [{
|
||||
key: summary.key,
|
||||
path: summary.path,
|
||||
absolute_path: summary.absolute_path,
|
||||
@ -481,10 +511,14 @@ function summarizeFileEdits(edits: UIFileEdit[], active: boolean): FileEditSumma
|
||||
status,
|
||||
pending: summary.pending && !summary.path,
|
||||
error: summary.error,
|
||||
};
|
||||
}];
|
||||
});
|
||||
}
|
||||
|
||||
function hasVisibleDiffStats(edit: Pick<FileEditSummary, "added" | "deleted">): boolean {
|
||||
return edit.added > 0 || edit.deleted > 0;
|
||||
}
|
||||
|
||||
function FileEditGroup({ edits }: { edits: FileEditSummary[] }) {
|
||||
if (edits.length === 0) return null;
|
||||
return (
|
||||
@ -500,7 +534,7 @@ function FileEditRow({ edit }: { edit: FileEditSummary }) {
|
||||
const { t } = useTranslation();
|
||||
const editing = edit.status === "editing";
|
||||
const failed = edit.status === "error";
|
||||
const hasCountedDiff = !failed && !edit.binary;
|
||||
const hasCountedDiff = !failed && !edit.binary && hasVisibleDiffStats(edit);
|
||||
return (
|
||||
<li className="grid grid-cols-[minmax(0,1fr)_auto] items-center gap-3 rounded-md px-2 py-1.5 text-xs">
|
||||
<div className="flex min-w-0 items-center gap-2">
|
||||
|
||||
@ -32,12 +32,17 @@ export function ThreadHeader({
|
||||
onClick={onToggleSidebar}
|
||||
className={cn(
|
||||
"h-7 w-7 rounded-md text-muted-foreground hover:bg-accent/35 hover:text-foreground",
|
||||
hideSidebarToggleOnDesktop && "lg:pointer-events-none lg:opacity-0",
|
||||
hideSidebarToggleOnDesktop && "lg:hidden",
|
||||
)}
|
||||
>
|
||||
<Menu className="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
<ThemeButton theme={theme} onToggleTheme={onToggleTheme} label={t("thread.header.toggleTheme")} />
|
||||
<ThemeButton
|
||||
theme={theme}
|
||||
onToggleTheme={onToggleTheme}
|
||||
label={t("thread.header.toggleTheme")}
|
||||
className="ml-auto"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@ -52,7 +57,7 @@ export function ThreadHeader({
|
||||
onClick={onToggleSidebar}
|
||||
className={cn(
|
||||
"h-7 w-7 rounded-md text-muted-foreground hover:bg-accent/35 hover:text-foreground",
|
||||
hideSidebarToggleOnDesktop && "lg:pointer-events-none lg:opacity-0",
|
||||
hideSidebarToggleOnDesktop && "lg:hidden",
|
||||
)}
|
||||
>
|
||||
<Menu className="h-3.5 w-3.5" />
|
||||
@ -62,7 +67,12 @@ export function ThreadHeader({
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<ThemeButton theme={theme} onToggleTheme={onToggleTheme} label={t("thread.header.toggleTheme")} />
|
||||
<ThemeButton
|
||||
theme={theme}
|
||||
onToggleTheme={onToggleTheme}
|
||||
label={t("thread.header.toggleTheme")}
|
||||
className="ml-auto shrink-0"
|
||||
/>
|
||||
|
||||
<div aria-hidden className="pointer-events-none absolute inset-x-0 top-full h-4" />
|
||||
</div>
|
||||
@ -73,10 +83,12 @@ function ThemeButton({
|
||||
theme,
|
||||
onToggleTheme,
|
||||
label,
|
||||
className,
|
||||
}: {
|
||||
theme: "light" | "dark";
|
||||
onToggleTheme: () => void;
|
||||
label: string;
|
||||
className?: string;
|
||||
}) {
|
||||
return (
|
||||
<Button
|
||||
@ -84,7 +96,10 @@ function ThemeButton({
|
||||
size="icon"
|
||||
aria-label={label}
|
||||
onClick={onToggleTheme}
|
||||
className="h-8 w-8 rounded-full text-muted-foreground/85 hover:bg-accent/40 hover:text-foreground"
|
||||
className={cn(
|
||||
"h-8 w-8 rounded-full text-muted-foreground/85 hover:bg-accent/40 hover:text-foreground",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
{theme === "dark" ? (
|
||||
<Sun className="h-4 w-4" />
|
||||
|
||||
@ -27,13 +27,25 @@ export function useSessions(): {
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const tokenRef = useRef(token);
|
||||
const optimisticKeysRef = useRef<Set<string>>(new Set());
|
||||
tokenRef.current = token;
|
||||
|
||||
const refresh = useCallback(async () => {
|
||||
try {
|
||||
setLoading(true);
|
||||
const rows = await listSessions(tokenRef.current);
|
||||
setSessions(rows);
|
||||
const serverKeys = new Set(rows.map((row) => row.key));
|
||||
setSessions((prev) => [
|
||||
...rows,
|
||||
...prev.filter(
|
||||
(session) =>
|
||||
optimisticKeysRef.current.has(session.key) &&
|
||||
!serverKeys.has(session.key),
|
||||
),
|
||||
]);
|
||||
for (const key of Array.from(optimisticKeysRef.current)) {
|
||||
if (serverKeys.has(key)) optimisticKeysRef.current.delete(key);
|
||||
}
|
||||
setError(null);
|
||||
} catch (e) {
|
||||
const msg =
|
||||
@ -57,6 +69,7 @@ export function useSessions(): {
|
||||
const createChat = useCallback(async (): Promise<string> => {
|
||||
const chatId = await client.newChat();
|
||||
const key = `websocket:${chatId}`;
|
||||
optimisticKeysRef.current.add(key);
|
||||
// Optimistic insert; a subsequent refresh will replace it with the
|
||||
// authoritative row once the server persists the session.
|
||||
setSessions((prev) => [
|
||||
@ -77,6 +90,7 @@ export function useSessions(): {
|
||||
const deleteChat = useCallback(
|
||||
async (key: string) => {
|
||||
await apiDeleteSession(tokenRef.current, key);
|
||||
optimisticKeysRef.current.delete(key);
|
||||
setSessions((prev) => prev.filter((s) => s.key !== key));
|
||||
},
|
||||
[],
|
||||
|
||||
@ -271,6 +271,7 @@
|
||||
"fallbackTitle": "Chat {{id}}",
|
||||
"loading": "Loading…",
|
||||
"noSessions": "No sessions yet.",
|
||||
"showMore": "Show {{count}} more",
|
||||
"actions": "Chat actions for {{title}}",
|
||||
"activity": {
|
||||
"running": "Agent running",
|
||||
|
||||
@ -224,6 +224,7 @@
|
||||
"fallbackTitle": "Chat {{id}}",
|
||||
"loading": "Cargando…",
|
||||
"noSessions": "Todavía no hay sesiones.",
|
||||
"showMore": "Mostrar {{count}} más",
|
||||
"actions": "Acciones del chat {{title}}",
|
||||
"activity": {
|
||||
"running": "Agent running",
|
||||
|
||||
@ -224,6 +224,7 @@
|
||||
"fallbackTitle": "Discussion {{id}}",
|
||||
"loading": "Chargement…",
|
||||
"noSessions": "Aucune session pour le moment.",
|
||||
"showMore": "Afficher {{count}} de plus",
|
||||
"actions": "Actions de la discussion {{title}}",
|
||||
"activity": {
|
||||
"running": "Agent running",
|
||||
|
||||
@ -224,6 +224,7 @@
|
||||
"fallbackTitle": "Obrolan {{id}}",
|
||||
"loading": "Memuat…",
|
||||
"noSessions": "Belum ada sesi.",
|
||||
"showMore": "Tampilkan {{count}} lagi",
|
||||
"actions": "Aksi obrolan untuk {{title}}",
|
||||
"activity": {
|
||||
"running": "Agent running",
|
||||
|
||||
@ -224,6 +224,7 @@
|
||||
"fallbackTitle": "チャット {{id}}",
|
||||
"loading": "読み込み中…",
|
||||
"noSessions": "まだセッションがありません。",
|
||||
"showMore": "さらに {{count}} 件表示",
|
||||
"actions": "「{{title}}」のチャット操作",
|
||||
"activity": {
|
||||
"running": "Agent running",
|
||||
|
||||
@ -224,6 +224,7 @@
|
||||
"fallbackTitle": "채팅 {{id}}",
|
||||
"loading": "불러오는 중…",
|
||||
"noSessions": "아직 세션이 없습니다.",
|
||||
"showMore": "{{count}}개 더 보기",
|
||||
"actions": "{{title}} 채팅 작업",
|
||||
"activity": {
|
||||
"running": "Agent running",
|
||||
|
||||
@ -224,6 +224,7 @@
|
||||
"fallbackTitle": "Trò chuyện {{id}}",
|
||||
"loading": "Đang tải…",
|
||||
"noSessions": "Chưa có phiên nào.",
|
||||
"showMore": "Hiển thị thêm {{count}}",
|
||||
"actions": "Tác vụ cho cuộc trò chuyện {{title}}",
|
||||
"activity": {
|
||||
"running": "Agent running",
|
||||
|
||||
@ -259,6 +259,7 @@
|
||||
"fallbackTitle": "对话 {{id}}",
|
||||
"loading": "加载中…",
|
||||
"noSessions": "还没有会话。",
|
||||
"showMore": "再显示 {{count}} 个",
|
||||
"actions": "“{{title}}” 的会话操作",
|
||||
"activity": {
|
||||
"running": "Agent 正在运行",
|
||||
|
||||
@ -224,6 +224,7 @@
|
||||
"fallbackTitle": "對話 {{id}}",
|
||||
"loading": "載入中…",
|
||||
"noSessions": "目前還沒有會話。",
|
||||
"showMore": "再顯示 {{count}} 個",
|
||||
"actions": "「{{title}}」的會話操作",
|
||||
"activity": {
|
||||
"running": "Agent 正在執行",
|
||||
|
||||
@ -271,6 +271,69 @@ describe("AgentActivityCluster", () => {
|
||||
}
|
||||
});
|
||||
|
||||
it("does not render zero diff counters for completed edits", () => {
|
||||
render(
|
||||
<AgentActivityCluster
|
||||
messages={activityMessages("", {
|
||||
id: "t2",
|
||||
role: "tool",
|
||||
kind: "trace",
|
||||
content: "edit_file()",
|
||||
traces: ["edit_file()"],
|
||||
fileEdits: [{
|
||||
call_id: "call-edit",
|
||||
tool: "edit_file",
|
||||
path: "src/app.tsx",
|
||||
phase: "end",
|
||||
added: 0,
|
||||
deleted: 0,
|
||||
approximate: false,
|
||||
status: "done",
|
||||
}],
|
||||
createdAt: 3,
|
||||
})}
|
||||
isTurnStreaming={false}
|
||||
hasBodyBelow={false}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.getByRole("button", { name: /edited app\.tsx/i })).toBeInTheDocument();
|
||||
expect(screen.queryByText("+0")).not.toBeInTheDocument();
|
||||
expect(screen.queryByText("-0")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("drops stale pathless pending edits after the turn completes", () => {
|
||||
render(
|
||||
<AgentActivityCluster
|
||||
messages={[{
|
||||
id: "t1",
|
||||
role: "tool",
|
||||
kind: "trace",
|
||||
content: "",
|
||||
traces: [],
|
||||
fileEdits: [{
|
||||
call_id: "call-edit",
|
||||
tool: "edit_file",
|
||||
path: "",
|
||||
phase: "start",
|
||||
added: 98,
|
||||
deleted: 0,
|
||||
approximate: true,
|
||||
status: "editing",
|
||||
pending: true,
|
||||
}],
|
||||
createdAt: 1,
|
||||
}]}
|
||||
isTurnStreaming={false}
|
||||
hasBodyBelow={false}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.queryByRole("button", { name: /preparing edit/i })).not.toBeInTheDocument();
|
||||
expect(screen.queryByText("+98")).not.toBeInTheDocument();
|
||||
expect(screen.queryByText("0 tool calls")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("renders pending file edit placeholders before the path is known", () => {
|
||||
render(
|
||||
<AgentActivityCluster
|
||||
|
||||
@ -992,6 +992,73 @@ describe("App layout", () => {
|
||||
);
|
||||
});
|
||||
|
||||
it("opens search from the keyboard shortcut", async () => {
|
||||
mockSessions = [
|
||||
{
|
||||
key: "websocket:chat-a",
|
||||
channel: "websocket",
|
||||
chatId: "chat-a",
|
||||
createdAt: "2026-04-16T10:00:00Z",
|
||||
updatedAt: "2026-04-16T10:00:00Z",
|
||||
preview: "Existing chat",
|
||||
},
|
||||
];
|
||||
|
||||
render(<App />);
|
||||
|
||||
await waitFor(() => expect(connectSpy).toHaveBeenCalled());
|
||||
fireEvent.keyDown(window, { key: "k", metaKey: true });
|
||||
|
||||
const dialog = await screen.findByRole("dialog", { name: "Search" });
|
||||
expect(within(dialog).queryByText("Global actions")).not.toBeInTheDocument();
|
||||
expect(within(dialog).getByText("Existing chat")).toBeInTheDocument();
|
||||
|
||||
const textbox = within(dialog).getByRole("textbox", { name: "Search" });
|
||||
fireEvent.change(textbox, { target: { value: "missing" } });
|
||||
expect(within(dialog).queryByText("Existing chat")).not.toBeInTheDocument();
|
||||
|
||||
fireEvent.change(textbox, { target: { value: "existing" } });
|
||||
expect(within(dialog).getByText("Existing chat")).toBeInTheDocument();
|
||||
|
||||
fireEvent.keyDown(textbox, { key: "Enter" });
|
||||
await waitFor(() =>
|
||||
expect(screen.queryByRole("dialog", { name: "Search" })).not.toBeInTheDocument(),
|
||||
);
|
||||
expect(createChatSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("keeps large sidebars light while search still covers every chat", async () => {
|
||||
mockSessions = Array.from({ length: 170 }, (_, index) => {
|
||||
const chatId = `chat-${index}`;
|
||||
return {
|
||||
key: `websocket:${chatId}`,
|
||||
channel: "websocket" as const,
|
||||
chatId,
|
||||
createdAt: new Date(Date.UTC(2026, 3, 16, 12, 0 - index)).toISOString(),
|
||||
updatedAt: new Date(Date.UTC(2026, 3, 16, 12, 0 - index)).toISOString(),
|
||||
title: index === 169 ? "Hidden target" : `Bulk chat ${index}`,
|
||||
preview: "",
|
||||
};
|
||||
});
|
||||
|
||||
render(<App />);
|
||||
|
||||
await waitFor(() => expect(connectSpy).toHaveBeenCalled());
|
||||
const sidebar = screen.getByRole("navigation", { name: "Sidebar navigation" });
|
||||
await waitFor(() =>
|
||||
expect(within(sidebar).getByRole("button", { name: "Bulk chat 0" })).toBeInTheDocument(),
|
||||
);
|
||||
expect(within(sidebar).queryByText("Hidden target")).not.toBeInTheDocument();
|
||||
expect(within(sidebar).getByRole("button", { name: "Show 10 more" })).toBeInTheDocument();
|
||||
|
||||
fireEvent.click(within(sidebar).getByRole("button", { name: "Search" }));
|
||||
const dialog = await screen.findByRole("dialog", { name: "Search" });
|
||||
fireEvent.change(within(dialog).getByRole("textbox", { name: "Search" }), {
|
||||
target: { value: "hidden" },
|
||||
});
|
||||
expect(within(dialog).getByText("Hidden target")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("opens a blank start page without creating an empty chat", async () => {
|
||||
mockSessions = [
|
||||
{
|
||||
@ -1025,10 +1092,16 @@ describe("App layout", () => {
|
||||
|
||||
fireEvent.click(screen.getByRole("button", { name: "Collapse sidebar" }));
|
||||
const desktopAside = container.querySelector("aside.lg\\:block") as HTMLElement;
|
||||
await waitFor(() => expect(desktopAside.style.width).toBe("0px"));
|
||||
await waitFor(() => expect(desktopAside.style.width).toBe("56px"));
|
||||
|
||||
expect(screen.queryByRole("button", { name: "Start a new chat" })).not.toBeInTheDocument();
|
||||
fireEvent.click(screen.getByRole("button", { name: "Toggle sidebar" }));
|
||||
const rail = screen.getByRole("navigation", { name: "Sidebar navigation" });
|
||||
expect(within(rail).getByRole("button", { name: "New chat" })).toBeInTheDocument();
|
||||
expect(within(rail).getByRole("button", { name: "Search" })).toBeInTheDocument();
|
||||
expect(within(rail).getByRole("button", { name: "View" })).toBeInTheDocument();
|
||||
expect(within(rail).queryByText("Existing chat")).not.toBeInTheDocument();
|
||||
|
||||
fireEvent.click(within(rail).getByRole("button", { name: "Toggle sidebar" }));
|
||||
await waitFor(() => expect(desktopAside.style.width).toBe("272px"));
|
||||
|
||||
const sidebar = screen.getByRole("navigation", { name: "Sidebar navigation" });
|
||||
|
||||
@ -78,6 +78,7 @@ describe("webui i18n", () => {
|
||||
const common = resource.common;
|
||||
expect(common.app.system.restarting).toBeTruthy();
|
||||
expect(common.sidebar.settings).toBeTruthy();
|
||||
expect(common.chat.showMore).toBeTruthy();
|
||||
expect(common.settings.sidebar.title).toBeTruthy();
|
||||
expect(common.settings.backToChat).toBeTruthy();
|
||||
for (const key of SETTINGS_NAV_KEYS) {
|
||||
|
||||
@ -157,6 +157,53 @@ describe("useSessions", () => {
|
||||
expect(api.listSessions).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it("keeps a newly created chat visible until the server session list catches up", async () => {
|
||||
vi.mocked(api.listSessions)
|
||||
.mockResolvedValueOnce([])
|
||||
.mockResolvedValueOnce([])
|
||||
.mockResolvedValueOnce([
|
||||
{
|
||||
key: "websocket:chat-new",
|
||||
channel: "websocket",
|
||||
chatId: "chat-new",
|
||||
createdAt: "2026-05-20T10:00:00Z",
|
||||
updatedAt: "2026-05-20T10:01:00Z",
|
||||
title: "Generated title",
|
||||
preview: "First message",
|
||||
},
|
||||
]);
|
||||
const client = fakeClient();
|
||||
client.newChat.mockResolvedValue("chat-new");
|
||||
|
||||
const { result } = renderHook(() => useSessions(), {
|
||||
wrapper: wrap(client),
|
||||
});
|
||||
|
||||
await waitFor(() => expect(result.current.loading).toBe(false));
|
||||
expect(result.current.sessions).toEqual([]);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.createChat();
|
||||
});
|
||||
|
||||
expect(result.current.sessions.map((s) => s.key)).toEqual(["websocket:chat-new"]);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.refresh();
|
||||
});
|
||||
|
||||
expect(result.current.sessions.map((s) => s.key)).toEqual(["websocket:chat-new"]);
|
||||
expect(result.current.sessions[0]?.preview).toBe("");
|
||||
|
||||
await act(async () => {
|
||||
await result.current.refresh();
|
||||
});
|
||||
|
||||
expect(result.current.sessions.map((s) => s.key)).toEqual(["websocket:chat-new"]);
|
||||
expect(result.current.sessions[0]?.preview).toBe("First message");
|
||||
expect(result.current.sessions[0]?.title).toBe("Generated title");
|
||||
});
|
||||
|
||||
it("passes through WebUI transcript user media as images and media", async () => {
|
||||
vi.mocked(api.fetchWebuiThread).mockResolvedValue({
|
||||
schemaVersion: 3,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user