diff --git a/README.md b/README.md
index b5e4b02c0..1dbc82db8 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,18 @@

+
+ English |
+ 简体中文 |
+ 繁體中文 |
+ Español |
+ Français |
+ Bahasa Indonesia |
+ 日本語 |
+ 한국어 |
+ Русский |
+ Tiếng Việt
+
@@ -61,7 +73,7 @@
- **2026-04-13** 🛡️ Agent turn hardened — user messages persisted early, auto-compact skips active tasks.
- **2026-04-12** 🔒 Lark global domain support, Dream learns discovered skills, shell sandbox tightened.
- **2026-04-11** ⚡ Context compact shrinks sessions on the fly; Kagi web search; QQ & WeCom full media.
-- **2026-04-10** 📓 Notebook editing tool, multiple MCP servers, Feishu streaming & done-emoji.
+- **2026-04-10** 📓 Multiple MCP servers, Feishu streaming & done-emoji.
- **2026-04-09** 🔌 WebSocket channel, unified cross-channel session, `disabled_skills` config.
- **2026-04-08** 📤 API file uploads, OpenAI reasoning auto-routing with Responses fallback.
- **2026-04-07** 🧠 Anthropic adaptive thinking, MCP resources & prompts exposed as tools.
diff --git a/docs/chat-apps.md b/docs/chat-apps.md
index c0c1b4ba0..88242a5f7 100644
--- a/docs/chat-apps.md
+++ b/docs/chat-apps.md
@@ -17,6 +17,7 @@ Connect nanobot to your favorite chat platform. Want to build your own? See the
| **Wecom** | Bot ID + Bot Secret |
| **Microsoft Teams** | App ID + App Password + public HTTPS endpoint |
| **Mochat** | Claw token (auto-setup available) |
+| **Signal** | signal-cli daemon + phone number |
Telegram (Recommended)
@@ -669,3 +670,69 @@ nanobot gateway
```
+
+
+Signal
+
+Uses **signal-cli** daemon in HTTP mode — receive messages via SSE, send via JSON-RPC.
+
+**1. Install signal-cli**
+
+Install [signal-cli](https://github.com/AsamK/signal-cli) and register a phone number:
+
+```bash
+signal-cli -u +1234567890 register
+signal-cli -u +1234567890 verify
+```
+
+Start the daemon:
+
+```bash
+signal-cli -a +1234567890 daemon --http localhost:8080
+```
+
+**2. Configure**
+
+```json
+{
+ "channels": {
+ "signal": {
+ "enabled": true,
+ "phoneNumber": "+1234567890",
+ "daemonHost": "localhost",
+ "daemonPort": 8080,
+ "dm": {
+ "enabled": true,
+ "policy": "open"
+ },
+ "group": {
+ "enabled": true,
+ "policy": "open",
+ "requireMention": true
+ }
+ }
+ }
+}
+```
+
+> - `phoneNumber`: Your registered Signal phone number.
+> - `daemonHost` / `daemonPort`: Where signal-cli daemon is listening (default `localhost:8080`).
+> - `dm.policy`: `"open"` (anyone can DM) or `"allowlist"` (only listed numbers/UUIDs). When `"allowlist"`, unlisted DM senders receive a pairing code.
+> - `dm.allowFrom`: List of allowed phone numbers or UUIDs (used when policy is `"allowlist"`).
+> - `group.policy`: `"open"` (all groups) or `"allowlist"` (only listed group IDs).
+> - `group.requireMention`: When `true` (default), the bot only responds in groups when @mentioned.
+> - `group.allowFrom`: List of allowed group IDs (used when group policy is `"allowlist"`).
+> - `attachmentsDir`: Override the directory where signal-cli stores inbound attachments. Defaults to `~/.local/share/signal-cli/attachments` (the Linux default). Set this if signal-cli runs with a custom `XDG_DATA_HOME` or on macOS/Windows.
+> - `groupMessageBufferSize`: Number of recent group messages kept for context (default `20`, must be > 0).
+
+**3. Run**
+
+```bash
+nanobot gateway
+```
+
+> [!TIP]
+> The channel automatically reconnects to the signal-cli daemon with exponential backoff if the connection drops.
+> Markdown in bot replies is automatically converted to Signal text styles (bold, italic, code, etc.).
+
+
diff --git a/docs/configuration.md b/docs/configuration.md
index dbd5e2626..e4fbe83eb 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -148,6 +148,7 @@ ANTHROPIC_API_KEY="$(bw get password api/anthropic)" nanobot agent
| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) |
| `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) |
| `siliconflow` | LLM (SiliconFlow/硅基流动) | [siliconflow.cn](https://siliconflow.cn) |
+| `novita` | LLM (Novita AI OpenAI-compatible gateway) | [novita.ai](https://novita.ai) |
| `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) |
| `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) |
| `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) |
diff --git a/docs/image-generation.md b/docs/image-generation.md
index 6ca7ed3fd..a9d6b620c 100644
--- a/docs/image-generation.md
+++ b/docs/image-generation.md
@@ -23,7 +23,7 @@ The feature is disabled by default. Enable it in `~/.nanobot/config.json`, confi
}
```
-See [Provider Notes](#provider-notes) for AIHubMix, MiniMax, and Gemini configuration examples.
+See [Provider Notes](#provider-notes) for AIHubMix, MiniMax, Gemini, Ollama, and StepFun configuration examples.
> [!TIP]
> Prefer environment variables for API keys. nanobot resolves `${VAR_NAME}` values from the environment at startup.
@@ -46,7 +46,7 @@ The WebUI hides provider storage details from the user. The agent sees the saved
| Option | Type | Default | Description |
|--------|------|---------|-------------|
| `tools.imageGeneration.enabled` | boolean | `false` | Register the `generate_image` tool |
-| `tools.imageGeneration.provider` | string | `"openrouter"` | Image provider name. Supported values: `openrouter`, `aihubmix`, `minimax`, `gemini`, `stepfun` |
+| `tools.imageGeneration.provider` | string | `"openrouter"` | Image provider name. Supported values: `openrouter`, `aihubmix`, `minimax`, `gemini`, `ollama`, `stepfun` |
| `tools.imageGeneration.model` | string | `"openai/gpt-5.4-image-2"` | Provider model name |
| `tools.imageGeneration.defaultAspectRatio` | string | `"1:1"` | Default ratio when the prompt/tool call does not specify one |
| `tools.imageGeneration.defaultImageSize` | string | `"1K"` | Default size hint, for example `1K`, `2K`, `4K`, or `1024x1024` |
@@ -168,6 +168,31 @@ For reference-image edits, use a Gemini Flash image model:
Imagen 4 supports the aspect ratios `1:1`, `9:16`, `16:9`, `3:4`, and `4:3`. Unsupported ratios are ignored and the model uses its default. The `defaultImageSize` setting has no effect on Gemini models; sizing is controlled by `defaultAspectRatio` only. Reference images passed with an Imagen model are ignored (with a warning logged).
+### Ollama
+
+Ollama's experimental native image generation API works with local servers and hosted ollama.com models. Local access at `http://localhost:11434/api` does not require an API key; set `providers.ollama.apiKey` only when targeting `https://ollama.com/api`.
+
+```json
+{
+ "providers": {
+ "ollama": {
+ "apiBase": "http://localhost:11434/api"
+ }
+ },
+ "tools": {
+ "imageGeneration": {
+ "enabled": true,
+ "provider": "ollama",
+ "model": "x/z-image-turbo",
+ "defaultAspectRatio": "16:9",
+ "defaultImageSize": "2K"
+ }
+ }
+}
+```
+
+Ollama maps `defaultAspectRatio` and `defaultImageSize` to native `width` and `height` values. Reference images are not supported by this integration.
+
### StepFun
StepFun (阶跃星辰) `step-image-edit-2` supports text-to-image generation. The `step-1x-medium` variant additionally supports **style-reference** image edits, where a reference image guides the visual style of the output.
@@ -274,7 +299,7 @@ Use the reference image. Keep the same robot and composition, change the palette
|---------|-------|
| `generate_image` is not available | Set `tools.imageGeneration.enabled` to `true` and restart the gateway |
| Missing API key error | Configure `providers..apiKey`; if using `${VAR_NAME}`, confirm the environment variable is visible to the gateway process |
-| `unsupported image generation provider` | Use `openrouter`, `aihubmix`, `minimax`, `gemini`, or `stepfun` |
+| `unsupported image generation provider` | Use `openrouter`, `aihubmix`, `minimax`, `gemini`, `ollama`, or `stepfun` |
| AIHubMix says `Incorrect model ID` | Use `model: "gpt-image-2-free"`; nanobot expands it to the required `openai/gpt-image-2-free` model path internally |
| Generation times out | Try a smaller/default image size, set AIHubMix `extraBody.quality` to `"low"`, or retry later |
| Reference image rejected | Reference image paths must be inside the workspace or nanobot media directory and must be valid image files |
diff --git a/nanobot/agent/context.py b/nanobot/agent/context.py
index 19ee935c4..82ebfab65 100644
--- a/nanobot/agent/context.py
+++ b/nanobot/agent/context.py
@@ -22,7 +22,7 @@ from nanobot.utils.prompt_templates import render_template
class ContextBuilder:
"""Builds the context (system prompt + messages) for the agent."""
- BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md"]
+ BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md"]
_RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]"
_MAX_RECENT_HISTORY = 50
_MAX_HISTORY_CHARS = 32_000 # hard cap on recent history section size
@@ -47,6 +47,8 @@ class ContextBuilder:
if bootstrap:
parts.append(bootstrap)
+ parts.append(render_template("agent/tool_contract.md"))
+
memory = self.memory.get_memory_context()
if memory and not self._is_template_content(self.memory.read_memory(), "memory/MEMORY.md"):
parts.append(f"# Memory\n\n{memory}")
@@ -210,4 +212,3 @@ class ContextBuilder:
if not images:
return text
return images + [{"type": "text", "text": text}]
-
diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py
index 0b0164fd0..19494034f 100644
--- a/nanobot/agent/runner.py
+++ b/nanobot/agent/runner.py
@@ -19,7 +19,8 @@ from nanobot.utils.file_edit_events import (
build_file_edit_end_event,
build_file_edit_error_event,
build_file_edit_start_event,
- prepare_file_edit_tracker,
+ prepare_file_edit_tracker as _prepare_file_edit_tracker,
+ prepare_file_edit_trackers,
StreamingFileEditTracker,
)
from nanobot.utils.helpers import (
@@ -58,11 +59,14 @@ _SNIP_SAFETY_BUFFER = 1024
_MICROCOMPACT_KEEP_RECENT = 10
_MICROCOMPACT_MIN_CHARS = 500
_COMPACTABLE_TOOLS = frozenset({
- "read_file", "exec", "grep",
- "web_search", "web_fetch", "list_dir",
+ "read_file", "exec", "grep", "find_files",
+ "web_search", "web_fetch", "list_dir", "list_exec_sessions",
})
_BACKFILL_CONTENT = "[Tool result unavailable — call was interrupted or lost]"
+# Backward-compatible module attribute for tests/extensions that monkeypatch
+# the former single-file tracker hook. Runtime uses prepare_file_edit_trackers.
+prepare_file_edit_tracker = _prepare_file_edit_tracker
@dataclass(slots=True)
@@ -857,8 +861,8 @@ class AgentRunner:
and on_progress_accepts_file_edit_events(spec.progress_callback)
)
progress_callback = spec.progress_callback if emit_file_edit_events else None
- file_edit_tracker = (
- prepare_file_edit_tracker(
+ file_edit_trackers = (
+ prepare_file_edit_trackers(
call_id=tool_call.id,
tool_name=tool_call.name,
tool=tool,
@@ -868,13 +872,13 @@ class AgentRunner:
if progress_callback is not None
else None
)
- if file_edit_tracker is not None and progress_callback is not None:
+ if file_edit_trackers and progress_callback is not None:
await invoke_file_edit_progress(
progress_callback,
[build_file_edit_start_event(
file_edit_tracker,
params if isinstance(params, dict) else None,
- )],
+ ) for file_edit_tracker in file_edit_trackers],
)
try:
if tool is not None:
@@ -884,10 +888,13 @@ class AgentRunner:
except asyncio.CancelledError:
raise
except BaseException as exc:
- if file_edit_tracker is not None and progress_callback is not None:
+ if file_edit_trackers and progress_callback is not None:
await invoke_file_edit_progress(
progress_callback,
- [build_file_edit_error_event(file_edit_tracker, str(exc))],
+ [
+ build_file_edit_error_event(file_edit_tracker, str(exc))
+ for file_edit_tracker in file_edit_trackers
+ ],
)
event = {
"name": tool_call.name,
@@ -910,10 +917,13 @@ class AgentRunner:
return payload, event, None
if isinstance(result, str) and result.startswith("Error"):
- if file_edit_tracker is not None and progress_callback is not None:
+ if file_edit_trackers and progress_callback is not None:
await invoke_file_edit_progress(
progress_callback,
- [build_file_edit_error_event(file_edit_tracker, result)],
+ [
+ build_file_edit_error_event(file_edit_tracker, result)
+ for file_edit_tracker in file_edit_trackers
+ ],
)
event = {
"name": tool_call.name,
@@ -933,13 +943,13 @@ class AgentRunner:
return result + hint, event, RuntimeError(result)
return result + hint, event, None
- if file_edit_tracker is not None and progress_callback is not None:
+ if file_edit_trackers and progress_callback is not None:
await invoke_file_edit_progress(
progress_callback,
[build_file_edit_end_event(
file_edit_tracker,
params if isinstance(params, dict) else None,
- )],
+ ) for file_edit_tracker in file_edit_trackers],
)
detail = "" if result is None else str(result)
diff --git a/nanobot/agent/tools/apply_patch.py b/nanobot/agent/tools/apply_patch.py
new file mode 100644
index 000000000..ac524f7fc
--- /dev/null
+++ b/nanobot/agent/tools/apply_patch.py
@@ -0,0 +1,352 @@
+"""Apply file edits by providing structured edit instructions."""
+
+from __future__ import annotations
+
+import difflib
+import re
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+
+from nanobot.agent.tools.base import tool_parameters
+from nanobot.agent.tools.filesystem import _FsTool
+from nanobot.agent.tools.schema import (
+ ArraySchema,
+ BooleanSchema,
+ ObjectSchema,
+ StringSchema,
+ tool_parameters_schema,
+)
+
+
+@dataclass(slots=True)
+class _PatchSummary:
+ action: str
+ path: str
+ added: int = 0
+ deleted: int = 0
+
+
+class _PatchError(ValueError):
+ pass
+
+
+_ABSOLUTE_WINDOWS_RE = re.compile(r"^[A-Za-z]:[\\/]")
+
+
+def _validate_relative_path(path: str) -> str:
+ normalized = path.strip()
+ if not normalized:
+ raise _PatchError("patch path cannot be empty")
+ if "\0" in normalized:
+ raise _PatchError(f"patch path contains a null byte: {path!r}")
+ if normalized.startswith(("~", "/", "\\")) or _ABSOLUTE_WINDOWS_RE.match(normalized):
+ raise _PatchError(f"patch path must be relative: {path}")
+ if any(part == ".." for part in re.split(r"[\\/]+", normalized)):
+ raise _PatchError(f"patch path must not contain '..': {path}")
+ return normalized
+
+
+def _lines_to_text(lines: list[str]) -> str:
+ if not lines:
+ return ""
+ return "\n".join(lines) + "\n"
+
+
+def _text_line_count(text: str) -> int:
+ if not text:
+ return 0
+ return len(text.splitlines())
+
+
+def _line_diff_stats(before: str, after: str) -> tuple[int, int]:
+ before_lines = before.replace("\r\n", "\n").splitlines()
+ after_lines = after.replace("\r\n", "\n").splitlines()
+ added = 0
+ deleted = 0
+ matcher = difflib.SequenceMatcher(a=before_lines, b=after_lines, autojunk=False)
+ for tag, i1, i2, j1, j2 in matcher.get_opcodes():
+ if tag == "equal":
+ continue
+ if tag in ("replace", "delete"):
+ deleted += i2 - i1
+ if tag in ("replace", "insert"):
+ added += j2 - j1
+ return added, deleted
+
+
+def _format_summary(summary: _PatchSummary) -> str:
+ stats = ""
+ if summary.added or summary.deleted:
+ stats = f" (+{summary.added}/-{summary.deleted})"
+ return f"- {summary.action} {summary.path}{stats}"
+
+
+@tool_parameters(
+ tool_parameters_schema(
+ edits=ArraySchema(
+ items=ObjectSchema(
+ path=StringSchema("Relative path to the file to edit."),
+ action=StringSchema(
+ "Operation type: replace (find and replace text), add (append new content or create file), delete (remove text).",
+ enum=["replace", "add", "delete"],
+ ),
+ old_text=StringSchema(
+ "Exact text to search for in the file. Required for replace and delete.",
+ nullable=True,
+ ),
+ new_text=StringSchema(
+ "Text to replace with or append. Required for replace and add.",
+ nullable=True,
+ ),
+ required=["path", "action"],
+ ),
+ description="List of edits to apply. Each edit specifies a file and the change to make.",
+ min_items=1,
+ max_items=20,
+ ),
+ dry_run=BooleanSchema(
+ description="Validate and summarize the patch without writing files.",
+ default=False,
+ ),
+ required=["edits"],
+ )
+)
+class ApplyPatchTool(_FsTool):
+ """Apply file edits by providing structured edit instructions."""
+ _scopes = {"core", "subagent"}
+
+ @property
+ def name(self) -> str:
+ return "apply_patch"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Default tool for code edits. Supports multi-file changes in a single call. "
+ "Provide a list of structured edits, each specifying a file path, action (replace/add/delete), and the text to change. "
+ "Paths must be relative. Set dry_run=true to validate and preview without writing files. "
+ "Use edit_file only for small exact replacements on a single file."
+ )
+
+ async def execute(
+ self,
+ edits: list[dict] | None = None,
+ dry_run: bool = False,
+ **kwargs: Any,
+ ) -> str:
+ try:
+ if not edits:
+ raise _PatchError("must provide edits")
+
+ writes: dict[Path, str] = {}
+ deletes: set[Path] = set()
+ summaries: list[_PatchSummary] = []
+
+ for edit in edits:
+ if not isinstance(edit, dict):
+ raise _PatchError("each edit must be an object")
+ raw_path = edit.get("path")
+ if not isinstance(raw_path, str):
+ raise _PatchError("path required for edit")
+ path = _validate_relative_path(raw_path)
+ action = edit.get("action")
+ if not isinstance(action, str):
+ raise _PatchError(f"action required for edit: {path}")
+ source = self._resolve(path)
+
+ if action == "add":
+ new_text = edit.get("new_text")
+ if new_text is None:
+ raise _PatchError(f"new_text required for add: {path}")
+
+ pending = writes.get(source)
+ if pending is not None:
+ content = pending
+ exists = True
+ elif source.exists():
+ raw = source.read_bytes()
+ try:
+ content = raw.decode("utf-8")
+ except UnicodeDecodeError:
+ raise _PatchError(f"file is not UTF-8 text: {path}")
+ exists = True
+ else:
+ content = ""
+ exists = False
+
+ if exists:
+ uses_crlf = "\r\n" in content
+ new_norm = content.replace("\r\n", "\n") + new_text.replace("\r\n", "\n")
+ if new_norm and not new_norm.endswith("\n"):
+ new_norm += "\n"
+ if uses_crlf:
+ new_norm = new_norm.replace("\n", "\r\n")
+ writes[source] = new_norm
+ deletes.discard(source)
+ added, deleted = _line_diff_stats(content, new_norm)
+ action_name = "update"
+ else:
+ new_norm = new_text.replace("\r\n", "\n")
+ if new_norm and not new_norm.endswith("\n"):
+ new_norm += "\n"
+ writes[source] = new_norm
+ deletes.discard(source)
+ added = _text_line_count(new_norm)
+ deleted = 0
+ action_name = "add"
+
+ summaries.append(
+ _PatchSummary(
+ action=action_name, path=path, added=added, deleted=deleted
+ )
+ )
+
+ elif action == "replace":
+ old_text = edit.get("old_text") or ""
+ if not old_text:
+ raise _PatchError(f"old_text required for replace: {path}")
+ new_text = edit.get("new_text")
+ if new_text is None:
+ raise _PatchError(f"new_text required for replace: {path}")
+
+ pending = writes.get(source)
+ if pending is not None:
+ content = pending
+ elif source.exists():
+ raw = source.read_bytes()
+ try:
+ content = raw.decode("utf-8")
+ except UnicodeDecodeError:
+ raise _PatchError(f"file is not UTF-8 text: {path}")
+ else:
+ raise _PatchError(f"file to update does not exist: {path}")
+
+ if pending is None and not source.is_file():
+ raise _PatchError(f"path to update is not a file: {path}")
+
+ uses_crlf = "\r\n" in content
+ norm_content = content.replace("\r\n", "\n")
+ norm_old = old_text.replace("\r\n", "\n")
+
+ pos = norm_content.find(norm_old)
+ if pos < 0:
+ raise _PatchError(f"old_text not found in {path}")
+ if norm_content.find(norm_old, pos + 1) >= 0:
+ raise _PatchError(f"old_text appears multiple times in {path}")
+
+ new_norm = (
+ norm_content[:pos]
+ + new_text.replace("\r\n", "\n")
+ + norm_content[pos + len(norm_old) :]
+ )
+ if new_norm and not new_norm.endswith("\n"):
+ new_norm += "\n"
+ if uses_crlf:
+ new_norm = new_norm.replace("\n", "\r\n")
+
+ writes[source] = new_norm
+ deletes.discard(source)
+ added, deleted = _line_diff_stats(content, new_norm)
+ summaries.append(
+ _PatchSummary(
+ action="update", path=path, added=added, deleted=deleted
+ )
+ )
+
+ elif action == "delete":
+ old_text = edit.get("old_text") or ""
+ if not old_text:
+ raise _PatchError(f"old_text required for delete: {path}")
+
+ pending = writes.get(source)
+ if pending is not None:
+ content = pending
+ elif source.exists():
+ raw = source.read_bytes()
+ try:
+ content = raw.decode("utf-8")
+ except UnicodeDecodeError:
+ raise _PatchError(f"file is not UTF-8 text: {path}")
+ else:
+ raise _PatchError(f"file to update does not exist: {path}")
+
+ if pending is None and not source.is_file():
+ raise _PatchError(f"path to update is not a file: {path}")
+
+ uses_crlf = "\r\n" in content
+ norm_content = content.replace("\r\n", "\n")
+ norm_old = old_text.replace("\r\n", "\n")
+
+ pos = norm_content.find(norm_old)
+ if pos < 0:
+ raise _PatchError(f"old_text not found in {path}")
+ if norm_content.find(norm_old, pos + 1) >= 0:
+ raise _PatchError(f"old_text appears multiple times in {path}")
+
+ if norm_old == norm_content:
+ deletes.add(source)
+ writes.pop(source, None)
+ added, deleted = 0, _text_line_count(content)
+ summaries.append(
+ _PatchSummary(
+ action="delete", path=path, added=added, deleted=deleted
+ )
+ )
+ else:
+ new_norm = (
+ norm_content[:pos] + norm_content[pos + len(norm_old) :]
+ )
+ if new_norm and not new_norm.endswith("\n"):
+ new_norm += "\n"
+ if uses_crlf:
+ new_norm = new_norm.replace("\n", "\r\n")
+ writes[source] = new_norm
+ deletes.discard(source)
+ added, deleted = _line_diff_stats(content, new_norm)
+ summaries.append(
+ _PatchSummary(
+ action="update", path=path, added=added, deleted=deleted
+ )
+ )
+
+ else:
+ raise _PatchError(f"unknown action: {action}")
+
+ if dry_run:
+ return "Patch dry-run succeeded:\n" + "\n".join(
+ _format_summary(summary) for summary in summaries
+ )
+
+ backups: dict[Path, bytes | None] = {}
+ for path in set(writes) | deletes:
+ backups[path] = path.read_bytes() if path.exists() else None
+
+ try:
+ for path in deletes:
+ if path.exists():
+ path.unlink()
+ for path, content in writes.items():
+ path.parent.mkdir(parents=True, exist_ok=True)
+ path.write_text(content, encoding="utf-8", newline="")
+ except Exception:
+ for path, data in backups.items():
+ if data is None:
+ if path.exists():
+ path.unlink()
+ else:
+ path.parent.mkdir(parents=True, exist_ok=True)
+ path.write_bytes(data)
+ raise
+
+ for path in set(writes) | deletes:
+ self._file_states.record_write(path)
+ return "Patch applied:\n" + "\n".join(
+ _format_summary(summary) for summary in summaries
+ )
+ except PermissionError as exc:
+ return f"Error: {exc}"
+ except _PatchError as exc:
+ return f"Error applying patch: {exc}"
+ except Exception as exc:
+ return f"Error applying patch: {exc}"
diff --git a/nanobot/agent/tools/exec_session.py b/nanobot/agent/tools/exec_session.py
new file mode 100644
index 000000000..4dadb2d36
--- /dev/null
+++ b/nanobot/agent/tools/exec_session.py
@@ -0,0 +1,591 @@
+"""Session support for long-running exec workflows."""
+
+from __future__ import annotations
+
+import asyncio
+import shutil
+import time
+import uuid
+from contextlib import suppress
+from dataclasses import dataclass
+from typing import Any
+
+from nanobot.agent.tools.base import Tool, tool_parameters
+from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema
+
+
+DEFAULT_YIELD_MS = 1000
+MAX_YIELD_MS = 30_000
+DEFAULT_WAIT_FOR_MS = 10_000
+MAX_WAIT_FOR_MS = 120_000
+DEFAULT_MAX_OUTPUT_CHARS = 10_000
+MAX_OUTPUT_CHARS = 50_000
+
+
+@dataclass(slots=True)
+class _SessionPoll:
+ output: str
+ done: bool
+ exit_code: int | None
+ elapsed_s: float = 0.0
+ timed_out: bool = False
+ terminated: bool = False
+ stdin_closed: bool = False
+ truncated_chars: int = 0
+
+
+@dataclass(slots=True)
+class ExecSessionInfo:
+ session_id: str
+ command: str
+ cwd: str
+ elapsed_s: float
+ idle_s: float
+ remaining_s: float
+ returncode: int | None
+
+
+class _ExecSession:
+ def __init__(
+ self,
+ *,
+ session_id: str,
+ process: asyncio.subprocess.Process,
+ command: str,
+ cwd: str,
+ timeout: int,
+ ) -> None:
+ self.session_id = session_id
+ self.process = process
+ self.command = command
+ self.cwd = cwd
+ self.started_at = time.monotonic()
+ self.deadline = time.monotonic() + timeout
+ self.last_access = time.monotonic()
+ self._chunks: list[str] = []
+ self._lock = asyncio.Lock()
+ self._timed_out = False
+ self._stdout_task = asyncio.create_task(self._read_stream(process.stdout, ""))
+ self._stderr_task = asyncio.create_task(self._read_stream(process.stderr, "STDERR:\n"))
+
+ async def _read_stream(
+ self,
+ stream: asyncio.StreamReader | None,
+ prefix: str,
+ ) -> None:
+ if stream is None:
+ return
+ first = True
+ while True:
+ chunk = await stream.read(4096)
+ if not chunk:
+ break
+ text = chunk.decode("utf-8", errors="replace")
+ if prefix and first:
+ text = prefix + text
+ first = False
+ async with self._lock:
+ self._chunks.append(text)
+
+ async def write(self, chars: str) -> str | None:
+ if self.process.returncode is not None:
+ return "session has already exited"
+ if self.process.stdin is None:
+ return "session stdin is not available"
+ try:
+ self.process.stdin.write(chars.encode("utf-8"))
+ await self.process.stdin.drain()
+ except (BrokenPipeError, ConnectionResetError):
+ return "session stdin is closed"
+ return None
+
+ async def close_stdin(self) -> str | None:
+ if self.process.returncode is not None:
+ return "session has already exited"
+ if self.process.stdin is None:
+ return "session stdin is not available"
+ self.process.stdin.close()
+ with suppress(BrokenPipeError, ConnectionResetError):
+ await self.process.stdin.wait_closed()
+ return None
+
+ async def poll(
+ self,
+ yield_time_ms: int,
+ max_output_chars: int,
+ *,
+ terminated: bool = False,
+ stdin_closed: bool = False,
+ ) -> _SessionPoll:
+ self.last_access = time.monotonic()
+ if yield_time_ms > 0 and self.process.returncode is None:
+ await asyncio.sleep(min(yield_time_ms, MAX_YIELD_MS) / 1000)
+
+ if self.process.returncode is None and time.monotonic() >= self.deadline:
+ self._timed_out = True
+ await self.kill()
+
+ if self.process.returncode is not None:
+ with suppress(asyncio.TimeoutError):
+ await asyncio.wait_for(
+ asyncio.gather(self._stdout_task, self._stderr_task),
+ timeout=2.0,
+ )
+
+ async with self._lock:
+ output = "".join(self._chunks)
+ self._chunks.clear()
+
+ output, truncated = _truncate_output(output, max_output_chars)
+ return _SessionPoll(
+ output=output,
+ done=self.process.returncode is not None,
+ exit_code=self.process.returncode,
+ elapsed_s=max(0.0, time.monotonic() - self.started_at),
+ timed_out=self._timed_out,
+ terminated=terminated,
+ stdin_closed=stdin_closed,
+ truncated_chars=truncated,
+ )
+
+ async def kill(self) -> None:
+ if self.process.returncode is not None:
+ return
+ self.process.kill()
+ with suppress(asyncio.TimeoutError):
+ await asyncio.wait_for(self.process.wait(), timeout=5.0)
+
+
+class ExecSessionManager:
+ def __init__(self, *, max_sessions: int = 8, idle_timeout: int = 1800) -> None:
+ self.max_sessions = max_sessions
+ self.idle_timeout = idle_timeout
+ self._sessions: dict[str, _ExecSession] = {}
+ self._lock = asyncio.Lock()
+
+ async def start(
+ self,
+ *,
+ command: str,
+ cwd: str,
+ env: dict[str, str],
+ timeout: int,
+ shell_program: str | None,
+ login: bool,
+ yield_time_ms: int,
+ max_output_chars: int,
+ ) -> tuple[str, _SessionPoll]:
+ async with self._lock:
+ await self._cleanup_locked()
+ if len(self._sessions) >= self.max_sessions:
+ raise RuntimeError(f"maximum exec sessions reached ({self.max_sessions})")
+ process = await self._spawn(command, cwd, env, shell_program, login)
+ session_id = uuid.uuid4().hex[:12]
+ session = _ExecSession(
+ session_id=session_id,
+ process=process,
+ command=command,
+ cwd=cwd,
+ timeout=timeout,
+ )
+ self._sessions[session_id] = session
+
+ poll = await session.poll(yield_time_ms, max_output_chars)
+ if poll.done:
+ async with self._lock:
+ self._sessions.pop(session_id, None)
+ return session_id, poll
+
+ async def write(
+ self,
+ *,
+ session_id: str,
+ chars: str | None,
+ close_stdin: bool,
+ terminate: bool,
+ yield_time_ms: int,
+ max_output_chars: int,
+ ) -> _SessionPoll:
+ async with self._lock:
+ await self._cleanup_locked()
+ session = self._sessions.get(session_id)
+ if session is None:
+ raise KeyError(session_id)
+
+ if chars:
+ error = await session.write(chars)
+ if error:
+ raise RuntimeError(error)
+ stdin_closed = False
+ if close_stdin:
+ error = await session.close_stdin()
+ if error:
+ raise RuntimeError(error)
+ stdin_closed = True
+ if terminate:
+ await session.kill()
+ poll = await session.poll(
+ yield_time_ms,
+ max_output_chars,
+ terminated=terminate,
+ stdin_closed=stdin_closed,
+ )
+ if poll.done:
+ async with self._lock:
+ self._sessions.pop(session_id, None)
+ return poll
+
+ async def list(self) -> list[ExecSessionInfo]:
+ async with self._lock:
+ await self._cleanup_locked()
+ now = time.monotonic()
+ return [
+ ExecSessionInfo(
+ session_id=session_id,
+ command=session.command,
+ cwd=session.cwd,
+ elapsed_s=max(0.0, now - session.started_at),
+ idle_s=max(0.0, now - session.last_access),
+ remaining_s=max(0.0, session.deadline - now),
+ returncode=session.process.returncode,
+ )
+ for session_id, session in sorted(self._sessions.items())
+ ]
+
+ async def _cleanup_locked(self) -> None:
+ now = time.monotonic()
+ stale = [
+ session_id
+ for session_id, session in self._sessions.items()
+ if now - session.last_access > self.idle_timeout
+ ]
+ for session_id in stale:
+ session = self._sessions.pop(session_id)
+ await session.kill()
+
+ async def _spawn(
+ self,
+ command: str,
+ cwd: str,
+ env: dict[str, str],
+ shell_program: str | None,
+ login: bool,
+ ) -> asyncio.subprocess.Process:
+ from nanobot.agent.tools import shell
+
+ if shell._IS_WINDOWS:
+ return await asyncio.create_subprocess_shell(
+ command,
+ stdin=asyncio.subprocess.PIPE,
+ stdout=asyncio.subprocess.PIPE,
+ stderr=asyncio.subprocess.PIPE,
+ cwd=cwd,
+ env=env,
+ )
+ shell_program = shell_program or shutil.which("bash") or "/bin/bash"
+ args = [shell_program]
+ if login and shell_program.rsplit("/", 1)[-1] in {"bash", "zsh"}:
+ args.append("-l")
+ args.extend(["-c", command])
+ return await asyncio.create_subprocess_exec(
+ *args,
+ stdin=asyncio.subprocess.PIPE,
+ stdout=asyncio.subprocess.PIPE,
+ stderr=asyncio.subprocess.PIPE,
+ cwd=cwd,
+ env=env,
+ )
+
+
+DEFAULT_EXEC_SESSION_MANAGER = ExecSessionManager()
+
+
+def clamp_session_int(value: int | None, default: int, minimum: int, maximum: int) -> int:
+ if value is None:
+ return default
+ return min(max(value, minimum), maximum)
+
+
+def _truncate_output(output: str, max_output_chars: int) -> tuple[str, int]:
+ if len(output) <= max_output_chars:
+ return output, 0
+ half = max_output_chars // 2
+ omitted = len(output) - max_output_chars
+ return (
+ output[:half]
+ + f"\n\n... ({omitted:,} chars truncated) ...\n\n"
+ + output[-half:],
+ omitted,
+ )
+
+
+def format_session_poll(session_id: str, poll: _SessionPoll) -> str:
+ parts = [poll.output] if poll.output else []
+ if poll.truncated_chars:
+ parts.append(f"(output truncated by {poll.truncated_chars:,} chars)")
+ if poll.timed_out:
+ parts.append("Error: Command timed out; session was terminated.")
+ if poll.terminated and not poll.timed_out:
+ parts.append("Session terminated.")
+ if poll.stdin_closed:
+ parts.append("Stdin closed.")
+ if poll.done:
+ parts.append(f"Exit code: {poll.exit_code}")
+ else:
+ parts.append(f"Process running. session_id: {session_id}")
+ parts.append(f"Elapsed: {poll.elapsed_s:.1f}s")
+ return "\n".join(parts) if parts else "(no output yet)"
+
+
+@tool_parameters(
+ tool_parameters_schema(
+ session_id=StringSchema("Session id returned by exec when yield_time_ms is used."),
+ chars=StringSchema(
+ "Bytes/text to write to stdin. Omit or pass an empty string to only poll recent output.",
+ nullable=True,
+ ),
+ close_stdin=BooleanSchema(
+ description="Close stdin after writing chars. Useful for commands waiting for EOF.",
+ default=False,
+ ),
+ terminate=BooleanSchema(
+ description="Terminate the running exec session.",
+ default=False,
+ ),
+ yield_time_ms=IntegerSchema(
+ DEFAULT_YIELD_MS,
+ description="Milliseconds to wait before returning recent output (default 1000, max 30000).",
+ minimum=0,
+ maximum=MAX_YIELD_MS,
+ ),
+ wait_for=StringSchema(
+ "Optional text to wait for in output before returning. "
+ "Useful for interactive commands and dev servers.",
+ nullable=True,
+ ),
+ wait_timeout_ms=IntegerSchema(
+ DEFAULT_WAIT_FOR_MS,
+ description="Maximum milliseconds to wait for wait_for text (default 10000, max 120000).",
+ minimum=0,
+ maximum=MAX_WAIT_FOR_MS,
+ nullable=True,
+ ),
+ max_output_chars=IntegerSchema(
+ DEFAULT_MAX_OUTPUT_CHARS,
+ description="Maximum output characters to return from this poll (default 10000, max 50000).",
+ minimum=1000,
+ maximum=MAX_OUTPUT_CHARS,
+ ),
+ max_output_tokens=IntegerSchema(
+ DEFAULT_MAX_OUTPUT_CHARS,
+ description="Compatibility alias for max_output_chars. The current runtime uses a character budget.",
+ minimum=1000,
+ maximum=MAX_OUTPUT_CHARS,
+ nullable=True,
+ ),
+ required=["session_id"],
+ )
+)
+class WriteStdinTool(Tool):
+ """Write to or poll a running exec session."""
+
+ _scopes = {"core", "subagent"}
+ config_key = "exec"
+
+ @classmethod
+ def config_cls(cls):
+ from nanobot.agent.tools.shell import ExecToolConfig
+
+ return ExecToolConfig
+
+ @classmethod
+ def enabled(cls, ctx: Any) -> bool:
+ return ctx.config.exec.enable
+
+ def __init__(
+ self,
+ *,
+ manager: ExecSessionManager | None = None,
+ ) -> None:
+ self._manager = manager or DEFAULT_EXEC_SESSION_MANAGER
+
+ @classmethod
+ def create(cls, ctx: Any) -> Tool:
+ return cls()
+
+ @property
+ def exclusive(self) -> bool:
+ return True
+
+ @property
+ def name(self) -> str:
+ return "write_stdin"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Interact with a running exec session created by exec with "
+ "yield_time_ms. Use chars='' to poll without writing, chars to send "
+ "stdin, close_stdin=true to send EOF, or terminate=true to stop the "
+ "process. Use wait_for with wait_timeout_ms for dev servers, test "
+ "watchers, and prompts where you need to wait for expected output. "
+ "Do not use this to start new commands; start them with exec."
+ )
+
+ async def execute(
+ self,
+ session_id: str,
+ chars: str | None = None,
+ close_stdin: bool = False,
+ terminate: bool = False,
+ yield_time_ms: int | None = None,
+ wait_for: str | None = None,
+ wait_timeout_ms: int | None = None,
+ max_output_chars: int | None = None,
+ max_output_tokens: int | None = None,
+ **kwargs: Any,
+ ) -> str:
+ try:
+ if max_output_chars is None:
+ max_output_chars = max_output_tokens
+ output_limit = clamp_session_int(
+ max_output_chars,
+ DEFAULT_MAX_OUTPUT_CHARS,
+ 1000,
+ MAX_OUTPUT_CHARS,
+ )
+ if wait_for:
+ return await self._wait_for_output(
+ session_id=session_id,
+ chars=chars,
+ close_stdin=close_stdin,
+ terminate=terminate,
+ wait_for=wait_for,
+ wait_timeout_ms=clamp_session_int(
+ wait_timeout_ms,
+ DEFAULT_WAIT_FOR_MS,
+ 0,
+ MAX_WAIT_FOR_MS,
+ ),
+ max_output_chars=output_limit,
+ )
+ poll = await self._manager.write(
+ session_id=session_id,
+ chars=chars,
+ close_stdin=close_stdin,
+ terminate=terminate,
+ yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS),
+ max_output_chars=output_limit,
+ )
+ return format_session_poll(session_id, poll)
+ except KeyError:
+ return f"Error: exec session not found: {session_id}"
+ except Exception as exc:
+ return f"Error writing to exec session: {exc}"
+
+ async def _wait_for_output(
+ self,
+ *,
+ session_id: str,
+ chars: str | None,
+ close_stdin: bool,
+ terminate: bool,
+ wait_for: str,
+ wait_timeout_ms: int,
+ max_output_chars: int,
+ ) -> str:
+ deadline = time.monotonic() + (wait_timeout_ms / 1000)
+ aggregate: list[str] = []
+ first = True
+ poll: _SessionPoll | None = None
+
+ while True:
+ remaining_ms = max(0, int((deadline - time.monotonic()) * 1000))
+ step_ms = min(500, remaining_ms)
+ poll = await self._manager.write(
+ session_id=session_id,
+ chars=chars if first else None,
+ close_stdin=close_stdin if first else False,
+ terminate=terminate if first else False,
+ yield_time_ms=step_ms,
+ max_output_chars=max_output_chars,
+ )
+ first = False
+ if poll.output:
+ aggregate.append(poll.output)
+ joined = "".join(aggregate)
+ if wait_for in joined:
+ poll.output = joined
+ return format_session_poll(session_id, poll)
+ if poll.done or remaining_ms <= 0:
+ poll.output = "".join(aggregate)
+ result = format_session_poll(session_id, poll)
+ if wait_for not in poll.output:
+ result += f"\nWait target not observed: {wait_for!r}"
+ return result
+
+
+@tool_parameters(tool_parameters_schema())
+class ListExecSessionsTool(Tool):
+ """List active exec sessions."""
+
+ _scopes = {"core", "subagent"}
+ config_key = "exec"
+
+ @classmethod
+ def config_cls(cls):
+ from nanobot.agent.tools.shell import ExecToolConfig
+
+ return ExecToolConfig
+
+ @classmethod
+ def enabled(cls, ctx: Any) -> bool:
+ return ctx.config.exec.enable
+
+ def __init__(
+ self,
+ *,
+ manager: ExecSessionManager | None = None,
+ ) -> None:
+ self._manager = manager or DEFAULT_EXEC_SESSION_MANAGER
+
+ @classmethod
+ def create(cls, ctx: Any) -> Tool:
+ return cls()
+
+ @property
+ def name(self) -> str:
+ return "list_exec_sessions"
+
+ @property
+ def description(self) -> str:
+ return (
+ "List active long-running exec sessions, including session_id, cwd, "
+ "elapsed time, idle time, remaining timeout, and command preview. "
+ "Use this to recover a session_id after context shifts before "
+ "polling, writing stdin, or terminating with write_stdin."
+ )
+
+ @property
+ def read_only(self) -> bool:
+ return True
+
+ async def execute(self, **kwargs: Any) -> str:
+ try:
+ sessions = await self._manager.list()
+ if not sessions:
+ return "No active exec sessions."
+ lines = []
+ for info in sessions:
+ command = " ".join(info.command.split())
+ if len(command) > 120:
+ command = command[:119] + "..."
+ status = "exited" if info.returncode is not None else "running"
+ lines.append(
+ f"{info.session_id} | {status} | elapsed={info.elapsed_s:.1f}s "
+ f"| idle={info.idle_s:.1f}s | remaining={info.remaining_s:.1f}s "
+ f"| cwd={info.cwd} | {command}"
+ )
+ return "\n".join(lines)
+ except Exception as exc:
+ return f"Error listing exec sessions: {exc}"
diff --git a/nanobot/agent/tools/filesystem.py b/nanobot/agent/tools/filesystem.py
index 8f4f660da..fa63e5f66 100644
--- a/nanobot/agent/tools/filesystem.py
+++ b/nanobot/agent/tools/filesystem.py
@@ -132,6 +132,10 @@ def _parse_page_range(pages: str, total: int) -> tuple[int, int]:
minimum=1,
),
pages=StringSchema("Page range for PDF files, e.g. '1-5' (default: all, max 20 pages)"),
+ force=BooleanSchema(
+ description="Bypass same-file read deduplication and return content again.",
+ default=False,
+ ),
required=["path"],
)
)
@@ -154,7 +158,11 @@ class ReadFileTool(_FsTool):
"Text output format: LINE_NUM|CONTENT. "
"Images return visual content for analysis. "
"Supports PDF, DOCX, XLSX, PPTX documents. "
+ "Use find_files/list_dir first when the path is uncertain. "
+ "Read the relevant range before editing so replacements or patches "
+ "are based on current content. "
"Use offset and limit for large text files. "
+ "Use force=true to re-read content even if unchanged. "
"Reads exceeding ~128K chars are truncated."
)
@@ -162,7 +170,15 @@ class ReadFileTool(_FsTool):
def read_only(self) -> bool:
return True
- async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, pages: str | None = None, **kwargs: Any) -> Any:
+ async def execute(
+ self,
+ path: str | None = None,
+ offset: int = 1,
+ limit: int | None = None,
+ pages: str | None = None,
+ force: bool = False,
+ **kwargs: Any,
+ ) -> Any:
try:
if not path:
return "Error reading file: Unknown path"
@@ -202,7 +218,13 @@ class ReadFileTool(_FsTool):
current_mtime = os.path.getmtime(fp)
except OSError:
current_mtime = 0.0
- if entry and entry.can_dedup and entry.offset == offset and entry.limit == limit:
+ if (
+ not force
+ and entry
+ and entry.can_dedup
+ and entry.offset == offset
+ and entry.limit == limit
+ ):
if current_mtime != entry.mtime:
# File was modified externally - force full read and mark as not dedupable
entry.can_dedup = False
@@ -365,9 +387,10 @@ class WriteFileTool(_FsTool):
@property
def description(self) -> str:
return (
- "Write content to a file. Overwrites if the file already exists; "
- "creates parent directories as needed. "
- "For partial edits, prefer edit_file instead."
+ "Create a new file or intentionally replace an entire file with "
+ "the provided content. Overwrites existing files and creates parent "
+ "directories as needed. For code changes or partial edits, prefer "
+ "apply_patch; use edit_file only for small exact replacements."
)
async def execute(self, path: str | None = None, content: str | None = None, **kwargs: Any) -> str:
@@ -657,6 +680,24 @@ def _find_match(content: str, old_text: str) -> tuple[str | None, int]:
old_text=StringSchema("The text to find and replace"),
new_text=StringSchema("The text to replace with"),
replace_all=BooleanSchema(description="Replace all occurrences (default false)"),
+ occurrence=IntegerSchema(
+ 1,
+ description="Optional 1-based occurrence to replace when old_text appears multiple times.",
+ minimum=1,
+ nullable=True,
+ ),
+ line_hint=IntegerSchema(
+ 1,
+ description="Optional 1-based line hint used to choose the nearest match.",
+ minimum=1,
+ nullable=True,
+ ),
+ expected_replacements=IntegerSchema(
+ 1,
+ description="Optional guard for the number of replacements that must be made.",
+ minimum=1,
+ nullable=True,
+ ),
required=["path", "old_text", "new_text"],
)
)
@@ -674,10 +715,13 @@ class EditFileTool(_FsTool):
@property
def description(self) -> str:
return (
- "Edit a file by replacing old_text with new_text. "
- "Tolerates minor whitespace/indentation differences and curly/straight quote mismatches. "
- "If old_text matches multiple times, you must provide more context "
- "or set replace_all=true. Shows a diff of the closest match on failure."
+ "Perform a small, exact replacement in one file by replacing "
+ "old_text with new_text. Use this for narrow text substitutions "
+ "with old_text copied from read_file. For multi-file, structural, "
+ "or generated code edits, prefer apply_patch. If old_text matches "
+ "multiple times, provide more context or set occurrence, line_hint, "
+ "replace_all, and expected_replacements. Shows closest-match "
+ "diagnostics on failure."
)
@staticmethod
@@ -688,7 +732,8 @@ class EditFileTool(_FsTool):
async def execute(
self, path: str | None = None, old_text: str | None = None,
new_text: str | None = None,
- replace_all: bool = False, **kwargs: Any,
+ replace_all: bool = False, occurrence: int | None = None,
+ line_hint: int | None = None, expected_replacements: int | None = None, **kwargs: Any,
) -> str:
try:
if not path:
@@ -697,10 +742,12 @@ class EditFileTool(_FsTool):
raise ValueError("Unknown old_text")
if new_text is None:
raise ValueError("Unknown new_text")
-
- # .ipynb detection
- if path.endswith(".ipynb"):
- return "Error: This is a Jupyter notebook. Use the notebook_edit tool instead of edit_file."
+ if occurrence is not None and occurrence < 1:
+ return "Error: occurrence must be >= 1."
+ if line_hint is not None and line_hint < 1:
+ return "Error: line_hint must be >= 1."
+ if expected_replacements is not None and expected_replacements < 1:
+ return "Error: expected_replacements must be >= 1."
fp = self._resolve(path)
@@ -743,15 +790,42 @@ class EditFileTool(_FsTool):
if not matches:
return self._not_found_msg(old_text, content, path)
count = len(matches)
+ if replace_all and occurrence is not None:
+ return "Error: occurrence cannot be used with replace_all=true."
+ if replace_all and line_hint is not None:
+ return "Error: line_hint cannot be used with replace_all=true."
+ if occurrence is not None and line_hint is not None:
+ return "Error: line_hint cannot be used with occurrence."
if count > 1 and not replace_all:
- line_numbers = [match.line for match in matches]
- preview = ", ".join(f"line {n}" for n in line_numbers[:3])
- if len(line_numbers) > 3:
- preview += ", ..."
- location_hint = f" at {preview}" if preview else ""
+ if occurrence is not None:
+ if occurrence > count:
+ return (
+ f"Error: occurrence {occurrence} is out of range; "
+ f"old_text appears {count} times."
+ )
+ elif line_hint is not None:
+ nearest = min(matches, key=lambda match: abs(match.line - line_hint))
+ distance = abs(nearest.line - line_hint)
+ if sum(1 for match in matches if abs(match.line - line_hint) == distance) > 1:
+ return (
+ f"Error: line_hint {line_hint} is ambiguous; "
+ f"old_text appears {count} times."
+ )
+ else:
+ line_numbers = [match.line for match in matches]
+ preview = ", ".join(f"line {n}" for n in line_numbers[:3])
+ if len(line_numbers) > 3:
+ preview += ", ..."
+ location_hint = f" at {preview}" if preview else ""
+ return (
+ f"Warning: old_text appears {count} times{location_hint}. "
+ "Provide more context, set occurrence to choose one match, "
+ "or set replace_all=true."
+ )
+ elif occurrence is not None and occurrence > count:
return (
- f"Warning: old_text appears {count} times{location_hint}. "
- "Provide more context to make it unique, or set replace_all=true."
+ f"Error: occurrence {occurrence} is out of range; "
+ f"old_text appears {count} time."
)
norm_new = new_text.replace("\r\n", "\n")
@@ -760,7 +834,17 @@ class EditFileTool(_FsTool):
if fp.suffix.lower() not in self._MARKDOWN_EXTS:
norm_new = self._strip_trailing_ws(norm_new)
- selected = matches if replace_all else matches[:1]
+ if replace_all:
+ selected = matches
+ elif line_hint is not None:
+ selected = [min(matches, key=lambda match: abs(match.line - line_hint))]
+ else:
+ selected = [matches[occurrence - 1 if occurrence else 0]]
+ if expected_replacements is not None and len(selected) != expected_replacements:
+ return (
+ f"Error: expected {expected_replacements} replacements but "
+ f"would make {len(selected)}."
+ )
new_content = content
for match in reversed(selected):
replacement = _preserve_quote_style(norm_old, match.text, norm_new)
diff --git a/nanobot/agent/tools/image_generation.py b/nanobot/agent/tools/image_generation.py
index f2f599ded..a194d0fee 100644
--- a/nanobot/agent/tools/image_generation.py
+++ b/nanobot/agent/tools/image_generation.py
@@ -130,12 +130,6 @@ class ImageGenerationTool(Tool):
}
return cls(**kwargs)
- def _missing_api_key_error(self) -> str:
- cls = get_image_gen_provider(self.config.provider)
- if cls and cls.missing_key_message:
- return f"Error: {cls.missing_key_message}"
- return f"Error: {self.config.provider} API key is not configured."
-
def _resolve_reference_image(self, value: str) -> str:
raw_path = Path(value).expanduser()
path = raw_path if raw_path.is_absolute() else self.workspace / raw_path
@@ -173,9 +167,6 @@ class ImageGenerationTool(Tool):
client = self._provider_client()
if client is None:
return f"Error: unsupported image generation provider '{self.config.provider}'"
- provider = self._provider_config()
- if not provider or not provider.api_key:
- return self._missing_api_key_error()
requested = count or 1
if requested > self.config.max_images_per_turn:
diff --git a/nanobot/agent/tools/notebook.py b/nanobot/agent/tools/notebook.py
deleted file mode 100644
index 0980b7c93..000000000
--- a/nanobot/agent/tools/notebook.py
+++ /dev/null
@@ -1,162 +0,0 @@
-"""NotebookEditTool — edit Jupyter .ipynb notebooks."""
-
-from __future__ import annotations
-
-import json
-import uuid
-from typing import Any
-
-from nanobot.agent.tools.base import tool_parameters
-from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema
-from nanobot.agent.tools.filesystem import _FsTool
-
-
-def _new_cell(source: str, cell_type: str = "code", generate_id: bool = False) -> dict:
- cell: dict[str, Any] = {
- "cell_type": cell_type,
- "source": source,
- "metadata": {},
- }
- if cell_type == "code":
- cell["outputs"] = []
- cell["execution_count"] = None
- if generate_id:
- cell["id"] = uuid.uuid4().hex[:8]
- return cell
-
-
-def _make_empty_notebook() -> dict:
- return {
- "nbformat": 4,
- "nbformat_minor": 5,
- "metadata": {
- "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
- "language_info": {"name": "python"},
- },
- "cells": [],
- }
-
-
-@tool_parameters(
- tool_parameters_schema(
- path=StringSchema("Path to the .ipynb notebook file"),
- cell_index=IntegerSchema(0, description="0-based index of the cell to edit", minimum=0),
- new_source=StringSchema("New source content for the cell"),
- cell_type=StringSchema(
- "Cell type: 'code' or 'markdown' (default: code)",
- enum=["code", "markdown"],
- ),
- edit_mode=StringSchema(
- "Mode: 'replace' (default), 'insert' (after target), or 'delete'",
- enum=["replace", "insert", "delete"],
- ),
- required=["path", "cell_index"],
- )
-)
-class NotebookEditTool(_FsTool):
- """Edit Jupyter notebook cells: replace, insert, or delete."""
- _scopes = {"core"}
-
- _VALID_CELL_TYPES = frozenset({"code", "markdown"})
- _VALID_EDIT_MODES = frozenset({"replace", "insert", "delete"})
-
- @property
- def name(self) -> str:
- return "notebook_edit"
-
- @property
- def description(self) -> str:
- return (
- "Edit a Jupyter notebook (.ipynb) cell. "
- "Modes: replace (default) replaces cell content, "
- "insert adds a new cell after the target index, "
- "delete removes the cell at the index. "
- "cell_index is 0-based."
- )
-
- async def execute(
- self,
- path: str | None = None,
- cell_index: int = 0,
- new_source: str = "",
- cell_type: str = "code",
- edit_mode: str = "replace",
- **kwargs: Any,
- ) -> str:
- try:
- if not path:
- return "Error: path is required"
-
- if not path.endswith(".ipynb"):
- return "Error: notebook_edit only works on .ipynb files. Use edit_file for other files."
-
- if edit_mode not in self._VALID_EDIT_MODES:
- return (
- f"Error: Invalid edit_mode '{edit_mode}'. "
- "Use one of: replace, insert, delete."
- )
-
- if cell_type not in self._VALID_CELL_TYPES:
- return (
- f"Error: Invalid cell_type '{cell_type}'. "
- "Use one of: code, markdown."
- )
-
- fp = self._resolve(path)
-
- # Create new notebook if file doesn't exist and mode is insert
- if not fp.exists():
- if edit_mode != "insert":
- return f"Error: File not found: {path}"
- nb = _make_empty_notebook()
- cell = _new_cell(new_source, cell_type, generate_id=True)
- nb["cells"].append(cell)
- fp.parent.mkdir(parents=True, exist_ok=True)
- fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
- return f"Successfully created {fp} with 1 cell"
-
- try:
- nb = json.loads(fp.read_text(encoding="utf-8"))
- except (json.JSONDecodeError, UnicodeDecodeError) as e:
- return f"Error: Failed to parse notebook: {e}"
-
- cells = nb.get("cells", [])
- nbformat_minor = nb.get("nbformat_minor", 0)
- generate_id = nb.get("nbformat", 0) >= 4 and nbformat_minor >= 5
-
- if edit_mode == "delete":
- if cell_index < 0 or cell_index >= len(cells):
- return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)"
- cells.pop(cell_index)
- nb["cells"] = cells
- fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
- return f"Successfully deleted cell {cell_index} from {fp}"
-
- if edit_mode == "insert":
- insert_at = min(cell_index + 1, len(cells))
- cell = _new_cell(new_source, cell_type, generate_id=generate_id)
- cells.insert(insert_at, cell)
- nb["cells"] = cells
- fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
- return f"Successfully inserted cell at index {insert_at} in {fp}"
-
- # Default: replace
- if cell_index < 0 or cell_index >= len(cells):
- return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)"
- cells[cell_index]["source"] = new_source
- if cell_type and cells[cell_index].get("cell_type") != cell_type:
- cells[cell_index]["cell_type"] = cell_type
- if cell_type == "code":
- cells[cell_index].setdefault("outputs", [])
- cells[cell_index].setdefault("execution_count", None)
- elif "outputs" in cells[cell_index]:
- del cells[cell_index]["outputs"]
- cells[cell_index].pop("execution_count", None)
- nb["cells"] = cells
- fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
- return f"Successfully edited cell {cell_index} in {fp}"
-
- except PermissionError as e:
- return f"Error: {e}"
- except Exception as e:
- return f"Error editing notebook: {e}"
diff --git a/nanobot/agent/tools/search.py b/nanobot/agent/tools/search.py
index 49448030b..0febb122c 100644
--- a/nanobot/agent/tools/search.py
+++ b/nanobot/agent/tools/search.py
@@ -1,4 +1,4 @@
-"""Search tools: grep."""
+"""Search tools: file discovery and grep."""
from __future__ import annotations
@@ -12,6 +12,7 @@ from typing import Any, Iterable, TypeVar
from nanobot.agent.tools.filesystem import ListDirTool, _FsTool
_DEFAULT_HEAD_LIMIT = 250
+_DEFAULT_FILE_HEAD_LIMIT = 200
T = TypeVar("T")
_TYPE_GLOB_MAP = {
"py": ("*.py", "*.pyi"),
@@ -88,6 +89,14 @@ def _matches_type(name: str, file_type: str | None) -> bool:
return any(fnmatch.fnmatch(name.lower(), pattern.lower()) for pattern in patterns)
+def _matches_query(rel_path: str, query: str | None) -> bool:
+ if not query:
+ return True
+ haystack = rel_path.lower()
+ terms = [part for part in query.lower().split() if part]
+ return all(term in haystack for term in terms)
+
+
class _SearchTool(_FsTool):
_IGNORE_DIRS = set(ListDirTool._IGNORE_DIRS)
@@ -109,6 +118,163 @@ class _SearchTool(_FsTool):
yield current / filename
+class FindFilesTool(_SearchTool):
+ """Find files by path fragment, glob, or type."""
+ _scopes = {"core", "subagent"}
+
+ @property
+ def name(self) -> str:
+ return "find_files"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Find files by path fragment, glob, or file type. "
+ "Use this before read_file when you need to locate files, and "
+ "prefer it over shell find/ls for ordinary workspace discovery. "
+ "Returns workspace-relative paths and skips common dependency/build "
+ "directories."
+ )
+
+ @property
+ def read_only(self) -> bool:
+ return True
+
+ @property
+ def parameters(self) -> dict[str, Any]:
+ return {
+ "type": "object",
+ "properties": {
+ "path": {
+ "type": "string",
+ "description": "Directory or file to search in (default '.')",
+ },
+ "query": {
+ "type": "string",
+ "description": (
+ "Optional case-insensitive path fragment search. "
+ "Whitespace-separated terms must all be present."
+ ),
+ },
+ "glob": {
+ "type": "string",
+ "description": "Optional file filter, e.g. '*.py' or 'tests/**/test_*.py'",
+ },
+ "type": {
+ "type": "string",
+ "description": "Optional file type shorthand, e.g. 'py', 'ts', 'md', 'json'",
+ },
+ "include_dirs": {
+ "type": "boolean",
+ "description": "Include matching directories as well as files (default false)",
+ },
+ "sort": {
+ "type": "string",
+ "enum": ["path", "modified"],
+ "description": "Sort by path or most recently modified first (default path)",
+ },
+ "head_limit": {
+ "type": "integer",
+ "description": "Maximum number of paths to return (default 200, 0 for all, max 1000)",
+ "minimum": 0,
+ "maximum": 1000,
+ },
+ "offset": {
+ "type": "integer",
+ "description": "Skip the first N results before applying head_limit",
+ "minimum": 0,
+ "maximum": 100000,
+ },
+ },
+ }
+
+ def _iter_paths(self, root: Path, *, include_dirs: bool) -> Iterable[Path]:
+ if root.is_file():
+ yield root
+ return
+ if include_dirs:
+ yield root
+ for dirpath, dirnames, filenames in os.walk(root):
+ dirnames[:] = sorted(d for d in dirnames if d not in self._IGNORE_DIRS)
+ current = Path(dirpath)
+ if include_dirs and current != root:
+ yield current
+ for filename in sorted(filenames):
+ yield current / filename
+
+ async def execute(
+ self,
+ path: str = ".",
+ query: str | None = None,
+ glob: str | None = None,
+ type: str | None = None,
+ include_dirs: bool = False,
+ sort: str = "path",
+ head_limit: int | None = None,
+ offset: int = 0,
+ **kwargs: Any,
+ ) -> str:
+ try:
+ target = self._resolve(path or ".")
+ if not target.exists():
+ return f"Error: Path not found: {path}"
+ if not (target.is_dir() or target.is_file()):
+ return f"Error: Unsupported path: {path}"
+
+ if sort not in {"path", "modified"}:
+ return "Error: sort must be 'path' or 'modified'"
+
+ limit = (
+ _DEFAULT_FILE_HEAD_LIMIT
+ if head_limit is None
+ else None if head_limit == 0 else head_limit
+ )
+ root = target if target.is_dir() else target.parent
+ matches: list[tuple[str, float]] = []
+
+ for candidate in self._iter_paths(target, include_dirs=include_dirs):
+ if candidate.is_dir() and not include_dirs:
+ continue
+ rel_path = candidate.relative_to(root).as_posix()
+ display_path = self._display_path(candidate, root)
+ name = candidate.name
+
+ if glob and not _match_glob(rel_path, name, glob):
+ continue
+ if candidate.is_file() and not _matches_type(name, type):
+ continue
+ if candidate.is_dir() and type:
+ continue
+ if not _matches_query(display_path, query):
+ continue
+ try:
+ mtime = candidate.stat().st_mtime
+ except OSError:
+ mtime = 0.0
+ suffix = "/" if candidate.is_dir() else ""
+ matches.append((display_path + suffix, mtime))
+
+ if sort == "modified":
+ matches.sort(key=lambda item: (-item[1], item[0]))
+ else:
+ matches.sort(key=lambda item: item[0])
+
+ paths = [item[0] for item in matches]
+ paged, truncated = _paginate(paths, limit, offset)
+ if not paged:
+ return "No files found"
+
+ result = "\n".join(paged)
+ note = _pagination_note(limit, offset, truncated)
+ if note:
+ result += "\n\n" + note
+ return result
+ except PermissionError as e:
+ return f"Error: {e}"
+ except Exception as e:
+ return f"Error finding files: {e}"
+
+
class GrepTool(_SearchTool):
"""Search file contents using a regex-like pattern."""
_scopes = {"core", "subagent"}
@@ -125,7 +291,8 @@ class GrepTool(_SearchTool):
return (
"Search file contents with a regex pattern. "
"Default output_mode is files_with_matches (file paths only); "
- "use content mode for matching lines with context. "
+ "use content mode for matching lines with context. Prefer this "
+ "over shell grep for ordinary workspace searches. "
"Skips binary and files >2 MB. Supports glob/type filtering."
)
diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py
index 0252b9746..537c89343 100644
--- a/nanobot/agent/tools/shell.py
+++ b/nanobot/agent/tools/shell.py
@@ -8,6 +8,7 @@ import re
import shutil
import sys
from contextlib import suppress
+from dataclasses import dataclass
from pathlib import Path
from typing import Any
@@ -15,8 +16,17 @@ from loguru import logger
from pydantic import Field
from nanobot.agent.tools.base import Tool, tool_parameters
+from nanobot.agent.tools.exec_session import (
+ DEFAULT_MAX_OUTPUT_CHARS,
+ DEFAULT_YIELD_MS,
+ DEFAULT_EXEC_SESSION_MANAGER,
+ MAX_OUTPUT_CHARS,
+ MAX_YIELD_MS,
+ clamp_session_int,
+ format_session_poll,
+)
from nanobot.agent.tools.sandbox import wrap_command
-from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema
+from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema
from nanobot.config.paths import get_media_dir
from nanobot.config.schema import Base
@@ -44,10 +54,22 @@ class ExecToolConfig(Base):
deny_patterns: list[str] = Field(default_factory=list)
+@dataclass(slots=True)
+class _PreparedCommand:
+ command: str
+ cwd: str
+ env: dict[str, str]
+ timeout: int
+ shell_program: str | None
+ login: bool
+
+
@tool_parameters(
tool_parameters_schema(
command=StringSchema("The shell command to execute"),
+ cmd=StringSchema("Compatibility alias for command"),
working_dir=StringSchema("Optional working directory for the command"),
+ workdir=StringSchema("Compatibility alias for working_dir"),
timeout=IntegerSchema(
60,
description=(
@@ -57,7 +79,44 @@ class ExecToolConfig(Base):
minimum=1,
maximum=600,
),
- required=["command"],
+ shell=StringSchema(
+ "Optional shell binary to launch. On Unix, supports sh, bash, or zsh.",
+ nullable=True,
+ ),
+ login=BooleanSchema(
+ description="Whether to run bash/zsh with login shell semantics (default true).",
+ default=True,
+ nullable=True,
+ ),
+ yield_time_ms=IntegerSchema(
+ description=(
+ "Optional milliseconds to wait before returning output. "
+ "When set, a still-running command returns a session_id that "
+ "can be polled or written to with write_stdin. Omit this field "
+ "to keep one-shot exec behavior."
+ ),
+ minimum=0,
+ maximum=MAX_YIELD_MS,
+ nullable=True,
+ ),
+ max_output_chars=IntegerSchema(
+ description=(
+ "Maximum output characters to return when yield_time_ms is used "
+ "(default 10000, max 50000)."
+ ),
+ minimum=1000,
+ maximum=MAX_OUTPUT_CHARS,
+ nullable=True,
+ ),
+ max_output_tokens=IntegerSchema(
+ description=(
+ "Compatibility alias for max_output_chars. The current runtime "
+ "uses a character budget."
+ ),
+ minimum=1000,
+ maximum=MAX_OUTPUT_CHARS,
+ nullable=True,
+ ),
)
)
class ExecTool(Tool):
@@ -98,6 +157,7 @@ class ExecTool(Tool):
sandbox: str = "",
path_append: str = "",
allowed_env_keys: list[str] | None = None,
+ session_manager: Any | None = None,
):
self.timeout = timeout
self.working_dir = working_dir
@@ -125,6 +185,7 @@ class ExecTool(Tool):
self.restrict_to_workspace = restrict_to_workspace
self.path_append = path_append
self.allowed_env_keys = allowed_env_keys or []
+ self._session_manager = session_manager or DEFAULT_EXEC_SESSION_MANAGER
@property
def name(self) -> str:
@@ -150,10 +211,15 @@ class ExecTool(Tool):
def description(self) -> str:
return (
"Execute a shell command and return its output. "
- "Prefer read_file/write_file/edit_file over cat/echo/sed, "
- "and grep/glob over shell find/grep. "
+ "Use this for tests, builds, package commands, git commands, and "
+ "other process execution. Prefer read_file/find_files/grep for "
+ "inspection and apply_patch/write_file/edit_file for file changes "
+ "instead of cat, shell find/grep, echo, or sed. "
"Use -y or --yes flags to avoid interactive prompts. "
- "Output is truncated at 10 000 chars; timeout defaults to 60s."
+ "For long-running or interactive commands, pass yield_time_ms; "
+ "if the command keeps running, exec returns a session_id that can "
+ "be polled or written to with write_stdin. Output is truncated at "
+ "10 000 chars; timeout defaults to 60s."
)
@property
@@ -161,9 +227,111 @@ class ExecTool(Tool):
return True
async def execute(
- self, command: str, working_dir: str | None = None,
- timeout: int | None = None, **kwargs: Any,
+ self, command: str | None = None, cmd: str | None = None,
+ working_dir: str | None = None, workdir: str | None = None,
+ timeout: int | None = None, shell: str | None = None,
+ login: bool | None = None, yield_time_ms: int | None = None,
+ max_output_chars: int | None = None,
+ max_output_tokens: int | None = None,
+ **kwargs: Any,
) -> str:
+ command = command or cmd
+ working_dir = working_dir or workdir
+ if not command:
+ return "Error: Missing command. Provide command or cmd."
+ if max_output_chars is None:
+ max_output_chars = max_output_tokens
+
+ prepared = self._prepare_command(command, working_dir, timeout, shell, login)
+ if isinstance(prepared, str):
+ return prepared
+
+ if yield_time_ms is not None:
+ return await self._execute_session(prepared, yield_time_ms, max_output_chars)
+
+ try:
+ process = await self._spawn(
+ prepared.command,
+ prepared.cwd,
+ prepared.env,
+ prepared.shell_program,
+ prepared.login,
+ )
+
+ try:
+ stdout, stderr = await asyncio.wait_for(
+ process.communicate(),
+ timeout=prepared.timeout,
+ )
+ except asyncio.TimeoutError:
+ await self._kill_process(process)
+ return f"Error: Command timed out after {prepared.timeout} seconds"
+ except asyncio.CancelledError:
+ await self._kill_process(process)
+ raise
+
+ output_parts = []
+
+ if stdout:
+ output_parts.append(stdout.decode("utf-8", errors="replace"))
+
+ if stderr:
+ stderr_text = stderr.decode("utf-8", errors="replace")
+ if stderr_text.strip():
+ output_parts.append(f"STDERR:\n{stderr_text}")
+
+ output_parts.append(f"\nExit code: {process.returncode}")
+
+ result = "\n".join(output_parts) if output_parts else "(no output)"
+
+ max_len = clamp_session_int(max_output_chars, self._MAX_OUTPUT, 1000, MAX_OUTPUT_CHARS)
+ if len(result) > max_len:
+ half = max_len // 2
+ result = (
+ result[:half]
+ + f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n"
+ + result[-half:]
+ )
+
+ return result
+
+ except Exception as e:
+ return f"Error executing command: {str(e)}"
+
+ async def _execute_session(
+ self,
+ prepared: _PreparedCommand,
+ yield_time_ms: int | None,
+ max_output_chars: int | None,
+ ) -> str:
+ try:
+ session_id, poll = await self._session_manager.start(
+ command=prepared.command,
+ cwd=prepared.cwd,
+ env=prepared.env,
+ timeout=prepared.timeout,
+ shell_program=prepared.shell_program,
+ login=prepared.login,
+ yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS),
+ max_output_chars=clamp_session_int(
+ max_output_chars,
+ DEFAULT_MAX_OUTPUT_CHARS,
+ 1000,
+ MAX_OUTPUT_CHARS,
+ ),
+ )
+ return format_session_poll(session_id, poll)
+ except Exception as exc:
+ return f"Error executing command: {exc}"
+
+ def _prepare_command(
+ self,
+ command: str,
+ working_dir: str | None = None,
+ timeout: int | None = None,
+ shell: str | None = None,
+ login: bool | None = None,
+ ) -> _PreparedCommand | str:
cwd = working_dir or self.working_dir or os.getcwd()
# Prevent an LLM-supplied working_dir from escaping the configured
@@ -211,52 +379,24 @@ class ExecTool(Tool):
env["NANOBOT_PATH_APPEND"] = self.path_append
command = f'export PATH="$PATH{os.pathsep}$NANOBOT_PATH_APPEND"; {command}'
- try:
- process = await self._spawn(command, cwd, env)
+ shell_program, shell_error = self._resolve_shell(shell)
+ if shell_error:
+ return shell_error
- try:
- stdout, stderr = await asyncio.wait_for(
- process.communicate(),
- timeout=effective_timeout,
- )
- except asyncio.TimeoutError:
- await self._kill_process(process)
- return f"Error: Command timed out after {effective_timeout} seconds"
- except asyncio.CancelledError:
- await self._kill_process(process)
- raise
-
- output_parts = []
-
- if stdout:
- output_parts.append(stdout.decode("utf-8", errors="replace"))
-
- if stderr:
- stderr_text = stderr.decode("utf-8", errors="replace")
- if stderr_text.strip():
- output_parts.append(f"STDERR:\n{stderr_text}")
-
- output_parts.append(f"\nExit code: {process.returncode}")
-
- result = "\n".join(output_parts) if output_parts else "(no output)"
-
- max_len = self._MAX_OUTPUT
- if len(result) > max_len:
- half = max_len // 2
- result = (
- result[:half]
- + f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n"
- + result[-half:]
- )
-
- return result
-
- except Exception as e:
- return f"Error executing command: {str(e)}"
+ return _PreparedCommand(
+ command=command,
+ cwd=cwd,
+ env=env,
+ timeout=effective_timeout,
+ shell_program=shell_program,
+ login=True if login is None else login,
+ )
@staticmethod
async def _spawn(
command: str, cwd: str, env: dict[str, str],
+ shell_program: str | None = None,
+ login: bool = True,
) -> asyncio.subprocess.Process:
"""Launch *command* in a platform-appropriate shell."""
if _IS_WINDOWS:
@@ -272,9 +412,14 @@ class ExecTool(Tool):
cwd=cwd,
env=env,
)
- bash = shutil.which("bash") or "/bin/bash"
+ shell_program = shell_program or shutil.which("bash") or "/bin/bash"
+ args = [shell_program]
+ shell_name = Path(shell_program).name.lower()
+ if login and shell_name in {"bash", "bash.exe", "zsh", "zsh.exe"}:
+ args.append("-l")
+ args.extend(["-c", command])
return await asyncio.create_subprocess_exec(
- bash, "-l", "-c", command,
+ *args,
stdin=asyncio.subprocess.DEVNULL,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
@@ -282,6 +427,31 @@ class ExecTool(Tool):
env=env,
)
+ @staticmethod
+ def _resolve_shell(shell: str | None) -> tuple[str | None, str | None]:
+ if not shell:
+ return None, None
+ if _IS_WINDOWS:
+ return None, "Error: shell parameter is not supported on Windows"
+ if "\0" in shell or "\n" in shell or "\r" in shell:
+ return None, "Error: shell contains invalid characters"
+ allowed = {"sh", "bash", "zsh"}
+ path = Path(shell).expanduser()
+ if path.is_absolute():
+ if path.name not in allowed:
+ return None, f"Error: unsupported shell {shell!r}. Allowed: bash, sh, zsh"
+ if not path.is_file() or not os.access(path, os.X_OK):
+ return None, f"Error: shell is not executable: {shell}"
+ return str(path), None
+ if "/" in shell or "\\" in shell:
+ return None, "Error: shell must be a shell name or absolute path"
+ if shell not in allowed:
+ return None, f"Error: unsupported shell {shell!r}. Allowed: bash, sh, zsh"
+ resolved = shutil.which(shell)
+ if not resolved:
+ return None, f"Error: shell not found: {shell}"
+ return resolved, None
+
@staticmethod
async def _kill_process(process: asyncio.subprocess.Process) -> None:
"""Kill a subprocess and reap it to prevent zombies."""
@@ -418,7 +588,7 @@ class ExecTool(Tool):
# Windows: match drive-root paths like `C:\` as well as `C:\path\to\file`, and UNC paths like `\\server\share`
# NOTE: `*` is required so `C:\` (nothing after the slash) is still extracted.
win_paths = re.findall(
- r"(?:[A-Za-z]:[^\s\"'|><;]*|\\\\[^\s\"'|><;]+(?:\\[^\s\"'|><;]+)*)",
+ r"(?<;]*|\\\\[^\s\"'|><;]+(?:\\[^\s\"'|><;]+)*)",
command
)
posix_paths = re.findall(r"(?:^|[\s|>'\"])(/[^\s\"'>;|<]+)", command) # POSIX: /absolute only
diff --git a/nanobot/channels/signal.py b/nanobot/channels/signal.py
new file mode 100644
index 000000000..2a38f60ac
--- /dev/null
+++ b/nanobot/channels/signal.py
@@ -0,0 +1,1402 @@
+"""Signal channel implementation using signal-cli daemon JSON-RPC interface."""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import re
+import shutil
+import unicodedata
+from collections import deque
+from collections.abc import AsyncIterator, Callable
+from contextlib import asynccontextmanager
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any
+
+import httpx
+from pydantic import Field, computed_field, field_validator
+
+from nanobot.bus.events import InboundMessage, OutboundMessage
+from nanobot.bus.queue import MessageBus
+from nanobot.channels.base import BaseChannel
+from nanobot.config.paths import get_media_dir
+from nanobot.config.schema import Base
+from nanobot.pairing import is_approved
+from nanobot.utils.helpers import safe_filename, split_message
+
+
+@dataclass
+class _Run:
+ text: str
+ styles: frozenset[str] = field(default_factory=frozenset)
+ opaque: bool = False # code / table content — skip further pattern processing
+
+
+_SIG_CODE_BLOCK_RE = re.compile(r"```(?:\w+)?\n?([\s\S]*?)```")
+_SIG_INLINE_CODE_RE = re.compile(r"`([^`\n]+)`")
+_SIG_HEADER_RE = re.compile(r"^#{1,6}\s+(.+)$", re.MULTILINE)
+_SIG_BLOCKQUOTE_RE = re.compile(r"^>\s*(.*)$", re.MULTILINE)
+_SIG_BULLET_RE = re.compile(r"^[-*]\s+", re.MULTILINE)
+_SIG_OLIST_RE = re.compile(r"^(\d+)\.\s+", re.MULTILINE)
+_SIG_LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")
+_SIG_BOLD_RE = re.compile(r"\*\*(.+?)\*\*|__(.+?)__", re.DOTALL)
+_SIG_ITALIC_RE = re.compile(
+ r"(? int:
+ """UTF-16 code-unit length, matching Signal BodyRange semantics."""
+ return len(s.encode("utf-16-le")) // 2
+
+
+def _sig_strip_cell(s: str) -> str:
+ """Strip inline markdown from a table cell for plain-text rendering."""
+ for pattern, repl in _SIG_CELL_STRIP_PATTERNS:
+ s = pattern.sub(repl, s)
+ return s.strip()
+
+
+def _sig_render_table(table_lines: list[str]) -> str:
+ """Render a markdown pipe-table as fixed-width plain text."""
+
+ def dw(s: str) -> int:
+ return sum(2 if unicodedata.east_asian_width(c) in ("W", "F") else 1 for c in s)
+
+ rows: list[list[str]] = []
+ has_sep = False
+ for line in table_lines:
+ cells = [_sig_strip_cell(c) for c in line.strip().strip("|").split("|")]
+ if all(re.match(r"^:?-+:?$", c) for c in cells if c):
+ has_sep = True
+ continue
+ rows.append(cells)
+ if not rows or not has_sep:
+ return "\n".join(table_lines)
+
+ ncols = max(len(r) for r in rows)
+ for r in rows:
+ r.extend([""] * (ncols - len(r)))
+ widths = [max(dw(r[c]) for r in rows) for c in range(ncols)]
+
+ def dr(cells: list[str]) -> str:
+ return " ".join(f"{c}{' ' * (w - dw(c))}" for c, w in zip(cells, widths))
+
+ out = [dr(rows[0])]
+ out.append(" ".join("─" * w for w in widths))
+ for row in rows[1:]:
+ out.append(dr(row))
+ return "\n".join(out)
+
+
+def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
+ """Convert markdown text to Signal plain text + textStyle ranges.
+
+ Returns ``(plain_text, text_styles)`` where ``text_styles`` are
+ ``"start:length:STYLE"`` strings for the signal-cli ``textStyle`` parameter.
+ """
+ if not text:
+ return text, []
+
+ # Phase 1 (text-level): extract code blocks and tables with placeholder tokens
+ # so they're protected from inline-style processing.
+ protected: list[str] = []
+
+ def save_code(m: re.Match) -> str:
+ protected.append(m.group(1))
+ return f"\x00C{len(protected) - 1}\x00"
+
+ text = _SIG_CODE_BLOCK_RE.sub(save_code, text)
+
+ # Detect and render pipe-tables line by line.
+ lines = text.split("\n")
+ rebuilt: list[str] = []
+ i = 0
+ while i < len(lines):
+ if re.match(r"^\s*\|.+\|", lines[i]):
+ tbl: list[str] = []
+ while i < len(lines) and re.match(r"^\s*\|.+\|", lines[i]):
+ tbl.append(lines[i])
+ i += 1
+ rendered = _sig_render_table(tbl)
+ if rendered != "\n".join(tbl):
+ protected.append(rendered)
+ rebuilt.append(f"\x00C{len(protected) - 1}\x00")
+ else:
+ rebuilt.extend(tbl)
+ else:
+ rebuilt.append(lines[i])
+ i += 1
+ text = "\n".join(rebuilt)
+
+ # Phase 2 (run-based): process inline patterns.
+ runs: list[_Run] = [_Run(text)]
+
+ def transform(
+ pattern: re.Pattern,
+ make_runs: Callable[[re.Match, frozenset[str]], list[_Run]],
+ ) -> None:
+ new_runs: list[_Run] = []
+ for run in runs:
+ if run.opaque:
+ new_runs.append(run)
+ continue
+ pos = 0
+ for m in pattern.finditer(run.text):
+ if m.start() > pos:
+ new_runs.append(_Run(run.text[pos : m.start()], run.styles))
+ new_runs.extend(make_runs(m, run.styles))
+ pos = m.end()
+ if pos < len(run.text):
+ new_runs.append(_Run(run.text[pos:], run.styles))
+ runs[:] = new_runs
+
+ # Restore code/table placeholders as opaque MONOSPACE runs.
+ transform(
+ _SIG_TOKEN_RE,
+ lambda m, s: [_Run(protected[int(m.group(1))], s | {"MONOSPACE"}, opaque=True)],
+ )
+
+ # Inline code (opaque).
+ transform(_SIG_INLINE_CODE_RE, lambda m, s: [_Run(m.group(1), s | {"MONOSPACE"}, opaque=True)])
+
+ # Headers → bold plain text.
+ transform(_SIG_HEADER_RE, lambda m, s: [_Run(m.group(1), s | {"BOLD"})])
+
+ # Blockquotes → strip marker.
+ transform(_SIG_BLOCKQUOTE_RE, lambda m, s: [_Run(m.group(1), s)])
+
+ # Bullet lists → bullet character.
+ transform(_SIG_BULLET_RE, lambda m, s: [_Run("• ", s)])
+
+ # Numbered lists → normalize spacing.
+ transform(_SIG_OLIST_RE, lambda m, s: [_Run(m.group(1) + ". ", s)])
+
+ # Links → "text (url)" or bare url when text equals url.
+ def _link_runs(m: re.Match, s: frozenset) -> list[_Run]:
+ link_text, url = m.group(1), m.group(2)
+
+ def _norm(u: str) -> str:
+ return re.sub(r"^https?://(www\.)?", "", u).rstrip("/").lower()
+
+ if _norm(url) == _norm(link_text):
+ return [_Run(url, s)]
+ return [_Run(f"{link_text} ({url})", s)]
+
+ transform(_SIG_LINK_RE, _link_runs)
+
+ # Bold (before italic so ** doesn't interfere).
+ transform(_SIG_BOLD_RE, lambda m, s: [_Run(m.group(1) or m.group(2), s | {"BOLD"})])
+
+ # Italic (single * or _).
+ transform(_SIG_ITALIC_RE, lambda m, s: [_Run(m.group(1) or m.group(2), s | {"ITALIC"})])
+
+ # Strikethrough: ~~text~~ (standard) or ~text~ (single-tilde variant).
+ transform(_SIG_STRIKE_RE, lambda m, s: [_Run(m.group(1) or m.group(2), s | {"STRIKETHROUGH"})])
+
+ # Phase 3: assemble output. Offsets and lengths are emitted in UTF-16 code
+ # units because Signal's BodyRange (via signal-cli's textStyle) interprets
+ # them as such; Python's len() counts code points, which would shift ranges
+ # left by 1 unit per non-BMP character preceding them.
+ plain_text = ""
+ text_styles: list[str] = []
+ utf16_offset = 0
+ for run in runs:
+ if not run.text:
+ continue
+ plain_text += run.text
+ start = utf16_offset
+ length = _utf16_len(run.text)
+ utf16_offset += length
+ for style in sorted(run.styles):
+ text_styles.append(f"{start}:{length}:{style}")
+
+ return plain_text, text_styles
+
+
+def _partition_styles(
+ plain_text: str, chunks: list[str], text_styles: list[str]
+) -> list[list[str]]:
+ """Partition Signal textStyle ranges across message chunks.
+
+ ``split_message`` slices ``plain_text`` into pieces (optionally trimming
+ whitespace at the boundaries), but the style ranges produced by
+ ``_markdown_to_signal`` are expressed in UTF-16 offsets relative to the
+ full ``plain_text``. This redistributes them per chunk with offsets
+ rebased to each chunk's start. Ranges that span a boundary are split
+ across the chunks they touch; ranges that fall entirely in trimmed
+ whitespace are dropped.
+ """
+ if not chunks:
+ return []
+ if not text_styles:
+ return [[] for _ in chunks]
+
+ # Locate each chunk's UTF-16 start in plain_text. split_message lstrips at
+ # boundaries (but not before the first chunk), so we skip whitespace
+ # between chunks to mirror that.
+ chunk_ranges: list[tuple[int, int]] = []
+ cursor = 0 # Python codepoint cursor in plain_text
+ for i, chunk in enumerate(chunks):
+ if i > 0:
+ while cursor < len(plain_text) and plain_text[cursor].isspace():
+ cursor += 1
+ utf16_start = _utf16_len(plain_text[:cursor])
+ utf16_end = utf16_start + _utf16_len(chunk)
+ chunk_ranges.append((utf16_start, utf16_end))
+ cursor += len(chunk)
+
+ result: list[list[str]] = [[] for _ in chunks]
+ for entry in text_styles:
+ s, ln, style = entry.split(":", 2)
+ r_start = int(s)
+ r_end = r_start + int(ln)
+ for i, (c_start, c_end) in enumerate(chunk_ranges):
+ if r_end <= c_start or r_start >= c_end:
+ continue
+ new_start = max(r_start, c_start) - c_start
+ new_end = min(r_end, c_end) - c_start
+ new_length = new_end - new_start
+ if new_length > 0:
+ result[i].append(f"{new_start}:{new_length}:{style}")
+ return result
+
+
+class SignalDMConfig(Base):
+ """Signal DM policy configuration."""
+
+ enabled: bool = False
+ policy: str = "allowlist" # "open" or "allowlist"
+ allow_from: list[str] = Field(default_factory=list) # Allowed phone numbers/UUIDs
+
+
+class SignalGroupConfig(Base):
+ """Signal group policy configuration."""
+
+ enabled: bool = False
+ policy: str = "allowlist" # "open" or "allowlist" - which groups to operate in
+ allow_from: list[str] = Field(default_factory=list) # Allowed group IDs if allowlist policy
+ require_mention: bool = True # Whether bot must be mentioned to respond
+
+
+class SignalConfig(Base):
+ """Signal channel configuration using signal-cli daemon (HTTP mode with -a flag only)."""
+
+ enabled: bool = False
+ phone_number: str = "" # Your Signal phone number (e.g., "+1234567890")
+ daemon_host: str = "localhost"
+ daemon_port: int = 8080
+ group_message_buffer_size: int = 20 # Number of recent group messages to keep for context
+ # Override the directory signal-cli writes inbound attachments to. When
+ # None, defaults to ~/.local/share/signal-cli/attachments (the daemon's
+ # platform default on Linux). Set this if the daemon is running with a
+ # custom XDG_DATA_HOME or on macOS/Windows where the default path differs.
+ attachments_dir: str | None = None
+ dm: SignalDMConfig = Field(default_factory=SignalDMConfig)
+ group: SignalGroupConfig = Field(default_factory=SignalGroupConfig)
+
+ @field_validator("group_message_buffer_size")
+ @classmethod
+ def _validate_buffer_size(cls, v: int) -> int:
+ if v <= 0:
+ raise ValueError("group_message_buffer_size must be > 0")
+ return v
+
+ @computed_field # type: ignore[prop-decorator]
+ @property
+ def allow_from(self) -> list[str]:
+ """Aggregate allowlist for the base-class is_allowed() check.
+
+ Returns the union of dm.allow_from and group.allow_from so the base
+ channel gate sees a populated list when either sub-policy is configured.
+ A ``"*"`` wildcard in either sub-list propagates to allow all.
+ """
+ return list(dict.fromkeys(self.dm.allow_from + self.group.allow_from))
+
+
+class SignalChannel(BaseChannel):
+ """
+ Signal channel using signal-cli daemon via HTTP JSON-RPC interface.
+
+ Requires signal-cli daemon in HTTP mode:
+ - signal-cli -a +1234567890 daemon --http localhost:8080
+
+ See https://github.com/AsamK/signal-cli for setup instructions.
+ """
+
+ name = "signal"
+ display_name = "Signal"
+ _TYPING_REFRESH_SECONDS = 10.0
+ _MAX_MESSAGE_LEN = 64_000 # signal-cli practical limit (protocol max ~64 KB)
+ _HTTP_TIMEOUT_SECONDS = 60.0
+
+ @classmethod
+ def default_config(cls) -> dict[str, Any]:
+ return SignalConfig().model_dump(by_alias=True)
+
+ def __init__(self, config: SignalConfig, bus: MessageBus):
+ if isinstance(config, dict):
+ config = SignalConfig.model_validate(config)
+ super().__init__(config, bus)
+ self.config: SignalConfig = config
+ self._http: httpx.AsyncClient | None = None
+ self._request_id = 0
+ self._sse_task: asyncio.Task | None = None
+ self._typing_tasks: dict[str, asyncio.Task] = {}
+ self._typing_uuid_warnings: set[str] = set()
+ self._account_id_aliases: set[str] = set()
+ self._remember_account_id_alias(self.config.phone_number)
+
+ # Rolling message buffer for group context (group_id -> deque of messages)
+ # Each message is a dict with: sender_name, sender_number, content, timestamp
+ self._group_buffers: dict[str, deque] = {}
+
+ def is_allowed(self, sender_id: str) -> bool:
+ """Override base check to normalize and split pipe-joined identifiers.
+
+ ``sender_id`` from Signal is the pipe-joined composite produced by
+ ``_collect_sender_id_parts``; allow_from entries may be single
+ identifiers or composites and may use the ``+`` prefix variant or
+ not. Delegates to ``_sender_matches_allowlist`` so the base gate
+ matches the per-policy DM gate.
+ """
+ allow_list = self.config.allow_from
+ if "*" in allow_list:
+ return True
+ if self._sender_matches_allowlist(sender_id, allow_list):
+ return True
+ if self._sender_approved_via_pairing(sender_id):
+ return True
+ if not allow_list:
+ self.logger.warning("allow_from is empty — all access denied")
+ return False
+
+ def _sender_approved_via_pairing(self, sender_id: str) -> bool:
+ """Return True if any normalized variant of sender_id is in the pairing store.
+
+ Pairing approval may be recorded under any of the identifier forms
+ signal exposes (phone with/without ``+``, UUID, ACI), so we check
+ each part of the pipe-joined composite against ``is_approved``.
+ """
+ for part in str(sender_id).split("|"):
+ for variant in self._normalize_signal_id(part):
+ if is_approved(self.name, variant):
+ return True
+ return False
+
+ async def _handle_message(
+ self,
+ sender_id: str,
+ chat_id: str,
+ content: str,
+ media: list[str] | None = None,
+ metadata: dict[str, Any] | None = None,
+ session_key: str | None = None,
+ is_dm: bool = False,
+ ) -> None:
+ """Handle an inbound message whose policy has already been checked.
+
+ ``_check_inbound_policy`` is the authoritative gate for DM/group
+ access, so we skip the base-class ``is_allowed()`` check and publish
+ directly to the bus. The denied-DM pairing path calls
+ ``super()._handle_message`` instead, which goes through
+ ``is_allowed`` and issues a pairing code.
+ """
+ meta = metadata or {}
+ if self.supports_streaming:
+ meta = {**meta, "_wants_stream": True}
+ await self.bus.publish_inbound(
+ InboundMessage(
+ channel=self.name,
+ sender_id=str(sender_id),
+ chat_id=str(chat_id),
+ content=content,
+ media=media or [],
+ metadata=meta,
+ session_key_override=session_key,
+ )
+ )
+
+ async def start(self) -> None:
+ """Start the Signal channel and connect to signal-cli daemon."""
+ if not self.config.phone_number:
+ self.logger.error("Signal account not configured")
+ return
+
+ self._running = True
+ await self._start_http_mode()
+
+ async def _start_http_mode(self) -> None:
+ """Start Signal channel using Server-Sent Events for receiving messages."""
+ base_url = f"http://{self.config.daemon_host}:{self.config.daemon_port}"
+ reconnect_delay_s = 1.0
+ max_reconnect_delay_s = 30.0
+
+ while self._running:
+ try:
+ self.logger.info("Connecting to signal-cli daemon at {}...", base_url)
+
+ # Create HTTP client
+ self._http = httpx.AsyncClient(
+ timeout=self._HTTP_TIMEOUT_SECONDS, base_url=base_url
+ )
+
+ # Test connection
+ try:
+ response = await self._http.get("/api/v1/check")
+ if response.status_code == 200:
+ self.logger.info("Connected to signal-cli daemon")
+ else:
+ raise ConnectionRefusedError(
+ f"signal-cli daemon check returned status {response.status_code}"
+ )
+ except Exception as e:
+ raise ConnectionRefusedError(f"signal-cli daemon not responding: {e}")
+
+ # Reset reconnect delay after successful connection check.
+ reconnect_delay_s = 1.0
+
+ # Ensure account-level typing indicators are enabled.
+ await self._ensure_typing_indicators_enabled()
+
+ # Start SSE receiver and supervise it. If it exits while we're still
+ # running, treat it as a disconnect and reconnect.
+ self._sse_task = asyncio.create_task(self._sse_receive_loop())
+ await self._sse_task
+ if self._running:
+ raise ConnectionError("Signal SSE stream ended unexpectedly")
+
+ except asyncio.CancelledError:
+ break
+ except ConnectionRefusedError as e:
+ self.logger.error(
+ "{}. Make sure signal-cli daemon is running: "
+ "signal-cli -a {} daemon --http {}:{}",
+ e,
+ self.config.phone_number,
+ self.config.daemon_host,
+ self.config.daemon_port,
+ )
+ except Exception as e:
+ self.logger.error("Signal channel error: {}", e)
+ finally:
+ if self._sse_task:
+ if not self._sse_task.done():
+ self._sse_task.cancel()
+ try:
+ await self._sse_task
+ except asyncio.CancelledError:
+ pass
+ except Exception:
+ pass
+ self._sse_task = None
+ if self._http:
+ await self._http.aclose()
+ self._http = None
+
+ if self._running:
+ self.logger.info(
+ "Reconnecting to signal-cli daemon in {:.0f} seconds...", reconnect_delay_s
+ )
+ await asyncio.sleep(reconnect_delay_s)
+ reconnect_delay_s = min(reconnect_delay_s * 2, max_reconnect_delay_s)
+
+ async def stop(self) -> None:
+ """Stop the Signal channel."""
+ self._running = False
+
+ # Stop SSE task
+ if self._sse_task:
+ self._sse_task.cancel()
+ try:
+ await self._sse_task
+ except asyncio.CancelledError:
+ pass
+
+ # Cancel active typing indicators
+ for chat_id in list(self._typing_tasks):
+ await self._stop_typing(chat_id)
+
+ # Close HTTP client
+ if self._http:
+ await self._http.aclose()
+ self._http = None
+
+ async def send(self, msg: OutboundMessage) -> None:
+ """Send a message through Signal."""
+ is_progress_message = bool(msg.metadata.get("_progress"))
+ try:
+ plain_text, text_styles = _markdown_to_signal(msg.content)
+ if not plain_text and not msg.media:
+ return
+ recipient_params = self._recipient_params(msg.chat_id)
+
+ chunks = split_message(plain_text, self._MAX_MESSAGE_LEN) if plain_text else [""]
+ chunk_styles = _partition_styles(plain_text, chunks, text_styles)
+ for i, chunk in enumerate(chunks):
+ params: dict[str, Any] = {"message": chunk}
+ if chunk_styles[i]:
+ params["textStyle"] = chunk_styles[i]
+ params.update(recipient_params)
+ if msg.media and i == 0:
+ params["attachments"] = msg.media
+
+ response = await self._send_request("send", params)
+
+ if "error" in response:
+ self.logger.error("Error sending Signal message: {}", response['error'])
+ raise RuntimeError(f"signal-cli send failed: {response['error']}")
+ else:
+ self.logger.debug(
+ f"Signal message sent, timestamp: {response.get('result', {}).get('timestamp')}"
+ )
+
+ except Exception:
+ self.logger.exception("Error sending Signal message")
+ raise
+ finally:
+ # Keep typing active across progress updates; stop on the final reply.
+ if not is_progress_message:
+ # Avoid immediate START->STOP for fast responses, which can be invisible
+ # in some Signal clients. Let indicator expire naturally (~15s).
+ await self._stop_typing(msg.chat_id, send_stop=False)
+
+ async def _sse_receive_loop(self) -> None:
+ """Receive messages via Server-Sent Events (HTTP mode)."""
+ if not self._http:
+ raise RuntimeError("HTTP client not initialized for Signal SSE stream")
+
+ self.logger.info("Started Signal message receive loop (SSE)")
+
+ try:
+ async with self._http.stream("GET", "/api/v1/events") as response:
+ if response.status_code != 200:
+ raise ConnectionError(
+ f"SSE connection failed with status {response.status_code}"
+ )
+
+ self.logger.info("Subscribed to Signal messages via SSE")
+
+ # Buffer for accumulating SSE data across multiple lines
+ event_buffer = []
+
+ async for line in response.aiter_lines():
+ if not self._running:
+ break
+
+ # Debug: log raw SSE lines (except keepalive pings)
+ if line and line != ":":
+ self.logger.debug("SSE line received: {}", line[:200])
+
+ # SSE format handling
+ if isinstance(line, str):
+ # Empty line signals end of event
+ if not line or line == ":":
+ if event_buffer:
+ # Try to parse the accumulated data
+ data_str = ""
+ try:
+ data_str = "\n".join(event_buffer)
+ data = json.loads(data_str)
+ self.logger.debug("SSE event parsed: {}", data)
+ await self._handle_receive_notification(data)
+ except json.JSONDecodeError as e:
+ self.logger.warning(
+ "Invalid JSON in SSE buffer: {}, data: {}",
+ e,
+ data_str[:200],
+ )
+ finally:
+ event_buffer = []
+
+ # "data:" line - accumulate it
+ elif line.startswith("data:"):
+ # SSE spec: strip one optional leading space after "data:".
+ event_buffer.append(line[6:] if line[5:6] == " " else line[5:])
+
+ # "event:" line - just log it (we only care about data)
+ elif line.startswith("event:"):
+ pass # Ignore event type for now
+
+ if self._running:
+ raise ConnectionError("Signal SSE stream closed by remote endpoint")
+
+ except asyncio.CancelledError:
+ self.logger.info("SSE receive loop cancelled")
+ raise
+ except Exception as e:
+ self.logger.error("Error in SSE receive loop: {}", e)
+ raise
+
+ @asynccontextmanager
+ async def _safe_handle(self, action: str, payload: Any = None) -> AsyncIterator[None]:
+ """Swallow and log any exception from a top-level handler block.
+
+ Logs `self.logger.error` with the action name, the exception, and a
+ bounded ``repr`` of the offending payload so the offending input is
+ recoverable from logs without having to correlate by timestamp.
+ """
+ try:
+ yield
+ except Exception as e:
+ snippet = repr(payload)[:200] if payload is not None else ""
+ text = f"Error in {action}: {e}"
+ if snippet:
+ text += f" | payload={snippet}"
+ self.logger.opt(exception=True).error(text)
+
+ async def _handle_receive_notification(self, params: dict[str, Any]) -> None:
+ """Handle incoming message notification from signal-cli."""
+ self.logger.debug("_handle_receive_notification called with: {}", params)
+ async with self._safe_handle("receive notification", params):
+ # Extract envelope from SSE notification: {"envelope": {...}}
+ envelope = params.get("envelope", {})
+
+ self.logger.debug("Extracted envelope: {}", envelope)
+
+ if not envelope:
+ self.logger.debug("No envelope found in params")
+ return
+
+ # Extract sender information
+ sender_parts = self._collect_sender_id_parts(envelope)
+ source_name = envelope.get("sourceName")
+
+ if not sender_parts:
+ self.logger.debug("Received message without source, skipping")
+ return
+
+ sender_number = self._primary_sender_id(sender_parts)
+ sender_id = "|".join(sender_parts)
+
+ # Keep aliases of the bot account for robust mention matching.
+ if any(self._id_matches_account(part) for part in sender_parts):
+ for part in sender_parts:
+ self._remember_account_id_alias(part)
+
+ # Check different message types
+ data_message = envelope.get("dataMessage")
+ sync_message = envelope.get("syncMessage")
+ typing_message = envelope.get("typingMessage")
+ receipt_message = envelope.get("receiptMessage")
+
+ # Ignore receipt messages (delivery/read receipts)
+ if receipt_message:
+ return
+
+ # Handle data messages (incoming messages from others)
+ if data_message:
+ await self._handle_data_message(sender_id, sender_number, data_message, source_name)
+
+ # Handle sync messages (messages sent from another device)
+ elif sync_message and sync_message.get("sentMessage"):
+ sent_msg = sync_message["sentMessage"]
+ destination = sent_msg.get("destination") or sent_msg.get("destinationNumber")
+ if destination:
+ self.logger.debug(
+ "Sync message sent to {}: {}", destination, sent_msg.get("message", "")[:50]
+ )
+
+ # Handle typing indicators (silently ignore)
+ elif typing_message:
+ pass # Ignore typing indicators
+
+ async def _handle_data_message(
+ self,
+ sender_id: str,
+ sender_number: str,
+ data_message: dict[str, Any],
+ sender_name: str | None,
+ ) -> None:
+ """Handle a data message (text, attachments, etc.)."""
+ message_text = data_message.get("message") or ""
+ attachments = data_message.get("attachments", [])
+ mentions = data_message.get("mentions", [])
+ timestamp = data_message.get("timestamp")
+
+ self.logger.info(
+ "Data message from {}: groupInfo={}, groupV2={}, keys={}",
+ sender_number,
+ data_message.get("groupInfo"),
+ data_message.get("groupV2"),
+ list(data_message.keys()),
+ )
+
+ if data_message.get("reaction"):
+ self.logger.debug(
+ "Ignoring reaction message from {}: {}", sender_number, data_message["reaction"]
+ )
+ return
+ if not message_text and not attachments:
+ self.logger.debug("Ignoring empty message from {}", sender_number)
+ return
+
+ group_info = data_message.get("groupInfo")
+ group_v2 = data_message.get("groupV2")
+ is_group_message = group_info is not None or group_v2 is not None
+ group_id = self._extract_group_id(group_info, group_v2)
+
+ allowed, chat_id = self._check_inbound_policy(
+ sender_id=sender_id,
+ sender_number=sender_number,
+ group_id=group_id,
+ is_group_message=is_group_message,
+ message_text=message_text,
+ mentions=mentions,
+ sender_name=sender_name,
+ timestamp=timestamp,
+ )
+ if not allowed:
+ # Mirror Slack: let denied DMs reach the base-class
+ # _handle_message so it can reply with a pairing code.
+ # Group denials stay dropped.
+ if not is_group_message and self.config.dm.enabled:
+ await super()._handle_message(
+ sender_id=sender_id,
+ chat_id=chat_id,
+ content="",
+ is_dm=True,
+ )
+ return
+
+ content, media_paths = self._assemble_inbound_content(
+ sender_name=sender_name,
+ sender_number=sender_number,
+ message_text=message_text,
+ attachments=attachments,
+ mentions=mentions,
+ is_group_message=is_group_message,
+ chat_id=chat_id,
+ )
+
+ self.logger.debug("Signal message from {}: {}...", sender_number, content[:50])
+
+ await self._start_typing(chat_id)
+ try:
+ await self._handle_message(
+ sender_id=sender_id,
+ chat_id=chat_id,
+ content=content,
+ media=media_paths,
+ metadata={
+ "timestamp": timestamp,
+ "sender_name": sender_name,
+ "sender_number": sender_number,
+ "is_group": is_group_message,
+ "group_id": group_id,
+ },
+ is_dm=not is_group_message,
+ )
+ except Exception:
+ await self._stop_typing(chat_id)
+ raise
+
+ def _check_inbound_policy(
+ self,
+ *,
+ sender_id: str,
+ sender_number: str,
+ group_id: str | None,
+ is_group_message: bool,
+ message_text: str,
+ mentions: list,
+ sender_name: str | None,
+ timestamp: int | None,
+ ) -> tuple[bool, str]:
+ """Decide whether to route an inbound message past DM/group policy.
+
+ Returns ``(allow, chat_id)``. Has one side effect: when a group
+ message passes the enabled+allowlist gates, it is appended to the
+ group's rolling context buffer before the mention check.
+ """
+ if is_group_message:
+ chat_id = group_id or sender_number
+ if not self.config.group.enabled:
+ self.logger.info("Ignoring group message from {} (groups disabled)", chat_id)
+ return False, chat_id
+ if (
+ self.config.group.policy == "allowlist"
+ and chat_id not in self.config.group.allow_from
+ ):
+ self.logger.info(
+ "Ignoring group message from {} (policy: {})",
+ chat_id,
+ self.config.group.policy,
+ )
+ return False, chat_id
+
+ self._add_to_group_buffer(
+ group_id=chat_id,
+ sender_name=sender_name or sender_number,
+ sender_number=sender_number,
+ message_text=message_text,
+ timestamp=timestamp,
+ )
+
+ is_command = bool(message_text and message_text.strip().startswith("/"))
+ if not is_command and not self._should_respond_in_group(message_text, mentions):
+ self.logger.info(
+ "Ignoring group message (require_mention: {})",
+ self.config.group.require_mention,
+ )
+ return False, chat_id
+ return True, chat_id
+
+ # Direct message
+ chat_id = sender_number
+ if not self.config.dm.enabled:
+ self.logger.debug("Ignoring DM from {} (DMs disabled)", sender_id)
+ return False, chat_id
+ if self.config.dm.policy == "allowlist":
+ if not self._sender_matches_allowlist(sender_id, self.config.dm.allow_from):
+ self.logger.debug(
+ "Ignoring DM from {} (policy: {})", sender_id, self.config.dm.policy
+ )
+ return False, chat_id
+ return True, chat_id
+
+ def _assemble_inbound_content(
+ self,
+ *,
+ sender_name: str | None,
+ sender_number: str,
+ message_text: str,
+ attachments: list,
+ mentions: list,
+ is_group_message: bool,
+ chat_id: str,
+ ) -> tuple[str, list[str]]:
+ """Build ``(content, media_paths)`` for an inbound message.
+
+ Pulls in group context, strips bot mentions, prefixes the sender's
+ display name on group messages, and copies any attachments from
+ signal-cli's storage into the channel media dir.
+ """
+ content_parts: list[str] = []
+ media_paths: list[str] = []
+
+ if is_group_message:
+ buffer_context = self._get_group_buffer_context(chat_id)
+ if buffer_context:
+ content_parts.append(f"[Recent group messages for context:]\n{buffer_context}\n---")
+
+ if message_text:
+ if is_group_message:
+ message_text = self._strip_bot_mention(message_text, mentions)
+ display_name = sender_name or sender_number
+ message_text = f"[{display_name}]: {message_text}"
+ content_parts.append(message_text)
+
+ if attachments:
+ media_dir = get_media_dir("signal")
+ for attachment in attachments:
+ attachment_id = attachment.get("id")
+ content_type = attachment.get("contentType", "")
+ filename = attachment.get("filename") or f"attachment_{attachment_id}"
+ if not attachment_id:
+ continue
+ try:
+ source_path = self._signal_attachments_dir() / attachment_id
+ if source_path.exists():
+ dest_path = media_dir / f"signal_{safe_filename(filename)}"
+ shutil.copy2(source_path, dest_path)
+ media_paths.append(str(dest_path))
+ media_type = content_type.split("/")[0] if "/" in content_type else "file"
+ if media_type not in ("image", "audio", "video"):
+ media_type = "file"
+ content_parts.append(f"[{media_type}: {dest_path}]")
+ self.logger.debug("Downloaded attachment: {} -> {}", filename, dest_path)
+ else:
+ self.logger.warning("Attachment not found: {}", source_path)
+ content_parts.append(f"[attachment: {filename} - not found]")
+ except Exception as e:
+ self.logger.warning("Failed to process attachment {}: {}", filename, e)
+ content_parts.append(f"[attachment: {filename} - error]")
+
+ content = "\n".join(content_parts) if content_parts else "[empty message]"
+ return content, media_paths
+
+ def _add_to_group_buffer(
+ self,
+ group_id: str,
+ sender_name: str,
+ sender_number: str,
+ message_text: str,
+ timestamp: int | None,
+ ) -> None:
+ """
+ Add a message to the group's rolling buffer.
+
+ Args:
+ group_id: The group ID
+ sender_name: Display name of sender
+ sender_number: Phone number of sender
+ message_text: The message content
+ timestamp: Message timestamp
+ """
+ # Create buffer for this group if it doesn't exist
+ if group_id not in self._group_buffers:
+ self._group_buffers[group_id] = deque(maxlen=self.config.group_message_buffer_size)
+
+ # Add message to buffer (deque will automatically drop oldest when full)
+ self._group_buffers[group_id].append(
+ {
+ "sender_name": sender_name,
+ "sender_number": sender_number,
+ "content": message_text,
+ "timestamp": timestamp,
+ }
+ )
+
+ self.logger.debug(
+ "Added message to group buffer {}: {}/{}",
+ group_id,
+ len(self._group_buffers[group_id]),
+ self.config.group_message_buffer_size,
+ )
+
+ def _get_group_buffer_context(self, group_id: str) -> str:
+ """
+ Get formatted context from the group's message buffer.
+
+ Args:
+ group_id: The group ID
+
+ Returns:
+ Formatted string of recent messages (excluding the current one)
+ """
+ if group_id not in self._group_buffers:
+ return ""
+
+ buffer = self._group_buffers[group_id]
+ if len(buffer) <= 1: # Only current message, no context
+ return ""
+
+ # Format all messages except the last one (which is the current message)
+ # We want to show context BEFORE the mention
+ context_messages = list(buffer)[:-1] # Exclude the last (current) message
+
+ lines = []
+ for msg in context_messages:
+ sender = msg["sender_name"]
+ content = msg["content"][:200] # Limit to 200 chars per message
+ lines.append(f"{sender}: {content}")
+
+ return "\n".join(lines)
+
+ def _signal_attachments_dir(self) -> Path:
+ """Return the directory signal-cli writes inbound attachments to.
+
+ Defaults to ``~/.local/share/signal-cli/attachments`` (the daemon's
+ platform default on Linux) when ``config.attachments_dir`` is unset.
+ """
+ configured = self.config.attachments_dir
+ if configured:
+ return Path(configured).expanduser()
+ return Path.home() / ".local/share/signal-cli/attachments"
+
+ @staticmethod
+ def _normalize_signal_id(value: str) -> list[str]:
+ """Normalize Signal identifiers (phone/uuid/service-id) for matching."""
+ raw = value.strip()
+ if not raw:
+ return []
+
+ normalized = [raw, raw.lower()]
+ if raw.startswith("+") and len(raw) > 1:
+ normalized.append(raw[1:])
+ elif raw.isdigit():
+ normalized.append(f"+{raw}")
+ return list(dict.fromkeys(normalized))
+
+ @classmethod
+ def _sender_matches_allowlist(cls, sender_id: str, allow_list: list[str]) -> bool:
+ """Return True if any normalized variant of sender_id is on allow_list.
+
+ Both ``sender_id`` and each allow_list entry can be a single
+ identifier or a pipe-joined composite of several (e.g.
+ ``"+1234567890|uuid-abc"``); both sides are split on ``|`` and each
+ part is run through ``_normalize_signal_id`` so an allowlist entry
+ like ``1234567890`` matches a sender ``+1234567890`` (and vice
+ versa), and case-only differences in UUIDs/ACIs match too.
+ """
+ if not allow_list:
+ return False
+ sender_variants: set[str] = set()
+ for part in str(sender_id).split("|"):
+ sender_variants.update(cls._normalize_signal_id(part))
+ if not sender_variants:
+ return False
+ allow_variants: set[str] = set()
+ for entry in allow_list:
+ for part in str(entry).split("|"):
+ allow_variants.update(cls._normalize_signal_id(part))
+ return bool(sender_variants & allow_variants)
+
+ def _remember_account_id_alias(self, value: str | None) -> None:
+ """Remember known bot identifiers for mention matching."""
+ if not value:
+ return
+ if not isinstance(value, str):
+ return
+ for candidate in self._normalize_signal_id(value):
+ self._account_id_aliases.add(candidate)
+
+ def _id_matches_account(self, value: str | None) -> bool:
+ """Return True when an identifier refers to the bot account."""
+ if not value:
+ return False
+ if not isinstance(value, str):
+ return False
+ return any(
+ candidate in self._account_id_aliases for candidate in self._normalize_signal_id(value)
+ )
+
+ @staticmethod
+ def _collect_sender_id_parts(envelope: dict[str, Any]) -> list[str]:
+ """Collect all known sender identifier variants from an envelope."""
+ parts: list[str] = []
+ for key in (
+ "sourceNumber",
+ "source",
+ "sourceUuid",
+ "sourceServiceId",
+ "sourceAci",
+ "sourceACI",
+ ):
+ value = envelope.get(key)
+ if not isinstance(value, str):
+ continue
+ candidate = value.strip()
+ if candidate and candidate not in parts:
+ parts.append(candidate)
+ return parts
+
+ @staticmethod
+ def _primary_sender_id(sender_parts: list[str]) -> str:
+ """Pick the best sender identifier for routing (prefer phone-like IDs)."""
+ for part in sender_parts:
+ if part.startswith("+") or part.isdigit():
+ return part
+ return sender_parts[0] if sender_parts else ""
+
+ @staticmethod
+ def _extract_group_id(group_info: Any, group_v2: Any) -> str | None:
+ """Extract group ID from groupInfo/groupV2 payloads across signal-cli variants."""
+ for group_obj in (group_info, group_v2):
+ if not isinstance(group_obj, dict):
+ continue
+ for key in ("groupId", "id", "groupID"):
+ value = group_obj.get(key)
+ if isinstance(value, str) and value:
+ return value
+ return None
+
+ @staticmethod
+ def _mention_id_candidates(mention: dict[str, Any]) -> list[str]:
+ """Extract possible identifier fields from a mention payload."""
+ ids: list[str] = []
+
+ def _walk(value: dict[str, Any] | Any, depth: int = 0) -> None:
+ if depth > 2:
+ return
+ if not isinstance(value, dict):
+ return
+ for key, child in value.items():
+ key_lower = str(key).lower()
+ if isinstance(child, str) and child:
+ if any(token in key_lower for token in ("number", "uuid", "serviceid", "aci")):
+ ids.append(child)
+ elif isinstance(child, dict):
+ _walk(child, depth + 1)
+
+ _walk(mention)
+ return list(dict.fromkeys(ids))
+
+ @staticmethod
+ def _mention_span(mention: dict[str, Any]) -> tuple[int, int] | None:
+ """Extract a safe (start, length) span from a mention."""
+ try:
+ start = int(mention.get("start", 0))
+ length = int(mention.get("length", 0))
+ except (TypeError, ValueError):
+ return None
+
+ if start < 0 or length <= 0:
+ return None
+ return (start, length)
+
+ @staticmethod
+ def _leading_placeholder_span(text: str | None) -> tuple[int, int] | None:
+ """
+ Detect a leading Signal mention placeholder when mention metadata is missing.
+
+ Some clients/integrations deliver mentions as a leading placeholder character
+ (typically U+FFFC) but omit `mentions` metadata in the payload.
+ """
+ if not text:
+ return None
+
+ start = 0
+ while start < len(text) and text[start].isspace():
+ start += 1
+
+ if start >= len(text):
+ return None
+
+ marker = text[start]
+ if marker not in ("\ufffc", "\ufffd", "\x1b"):
+ return None
+
+ next_index = start + 1
+ if next_index < len(text) and not text[next_index].isspace():
+ return None
+
+ return (start, 1)
+
+ def _should_respond_in_group(self, message_text: str, mentions: list[dict[str, Any]]) -> bool:
+ """
+ Determine if the bot should respond to a group message.
+
+ Args:
+ message_text: The message text content
+ mentions: List of mentions from Signal (format: [{"number": "+1234567890", "start": 0, "length": 10}])
+
+ Returns:
+ True if bot should respond, False otherwise
+ """
+ # Group reply behavior is controlled only by group.require_mention.
+ if not self.config.group.require_mention:
+ return True
+
+ # If mention is required, check if bot was mentioned.
+ for mention in mentions:
+ if not isinstance(mention, dict):
+ continue
+ for mention_id in self._mention_id_candidates(mention):
+ if self._id_matches_account(mention_id):
+ return True
+
+ # Some Signal clients emit mention spans without recipient identifiers
+ # (for handle-style mentions). Accept a leading identifier-less mention
+ # as a mention of the bot to avoid false negatives.
+ for mention in mentions:
+ if not isinstance(mention, dict):
+ continue
+ if self._mention_id_candidates(mention):
+ continue
+ span = self._mention_span(mention)
+ if not span:
+ continue
+ start, _ = span
+ if message_text is not None and not message_text[:start].strip():
+ self.logger.debug("Accepting identifier-less leading mention as bot mention")
+ return True
+
+ # Some payloads omit `mentions` but still include the leading mention
+ # placeholder character in the message body.
+ if not mentions and self._leading_placeholder_span(message_text):
+ self.logger.debug("Accepting leading placeholder mention without mention metadata")
+ return True
+
+ # Fallback: check for configured phone number in plain text.
+ if message_text and self.config.phone_number:
+ for account_id in self._normalize_signal_id(self.config.phone_number):
+ if account_id and account_id in message_text:
+ return True
+
+ return False
+
+ def _strip_bot_mention(self, text: str, mentions: list[dict[str, Any]]) -> str:
+ """
+ Remove bot mentions from message text.
+
+ Signal mentions are embedded in the text, so we need to remove them based on
+ the mentions array which provides start position and length.
+
+ Args:
+ text: Original message text
+ mentions: List of mention objects with start/length positions
+
+ Returns:
+ Text with bot mentions removed
+ """
+ if not text:
+ return text
+
+ # Build a list of (start, length) tuples for our bot's mentions
+ bot_mentions = []
+ for mention in mentions:
+ if not isinstance(mention, dict):
+ continue
+ mention_ids = self._mention_id_candidates(mention)
+ span = self._mention_span(mention)
+ if not span:
+ continue
+
+ # Strip matched bot mentions by ID.
+ if any(self._id_matches_account(mention_id) for mention_id in mention_ids):
+ bot_mentions.append(span)
+ continue
+
+ # Also strip identifier-less leading mention spans (handle mentions).
+ if not mention_ids:
+ start, _ = span
+ if not text[:start].strip():
+ bot_mentions.append(span)
+
+ if not bot_mentions:
+ placeholder_span = self._leading_placeholder_span(text)
+ if placeholder_span:
+ bot_mentions.append(placeholder_span)
+
+ # Sort mentions by start position (descending) to remove from end to start
+ # This prevents position shifts when removing earlier mentions
+ bot_mentions.sort(reverse=True)
+
+ # Remove each mention
+ for start, length in bot_mentions:
+ if start >= len(text):
+ continue
+ end = min(len(text), start + length)
+ text = text[:start] + text[end:]
+
+ return text.strip()
+
+ @staticmethod
+ def _is_group_chat_id(chat_id: str) -> bool:
+ """Return True when chat_id appears to be a Signal group ID (base64)."""
+ return "=" in chat_id or (len(chat_id) > 40 and "-" not in chat_id)
+
+ def _recipient_params(self, chat_id: str) -> dict[str, Any]:
+ """Build recipient params for signal-cli JSON-RPC methods."""
+ if self._is_group_chat_id(chat_id):
+ return {"groupId": chat_id}
+ return {"recipient": [chat_id]}
+
+ async def _start_typing(self, chat_id: str) -> None:
+ """Start periodic typing indicator updates for a chat."""
+ await self._stop_typing(chat_id, send_stop=False)
+ await self._send_typing(chat_id)
+ self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id))
+
+ async def _stop_typing(self, chat_id: str, send_stop: bool = True) -> None:
+ """Stop typing indicator updates for a chat."""
+ task = self._typing_tasks.pop(chat_id, None)
+ had_task = task is not None
+ if task and not task.done():
+ task.cancel()
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+
+ if send_stop and had_task:
+ await self._send_typing(chat_id, stop=True)
+
+ async def _typing_loop(self, chat_id: str) -> None:
+ """Send typing updates periodically until cancelled."""
+ try:
+ while self._running:
+ await asyncio.sleep(self._TYPING_REFRESH_SECONDS)
+ await self._send_typing(chat_id, quiet_success=True)
+ except asyncio.CancelledError:
+ pass
+ except Exception as e:
+ self.logger.debug("Typing indicator loop stopped for {}: {}", chat_id, e)
+
+ async def _send_typing(
+ self, chat_id: str, stop: bool = False, quiet_success: bool = False
+ ) -> None:
+ """Send a typing START/STOP message via signal-cli."""
+ action = "stop" if stop else "start"
+ if (
+ not self._is_group_chat_id(chat_id)
+ and chat_id.startswith("+") is False
+ and chat_id not in self._typing_uuid_warnings
+ ):
+ self._typing_uuid_warnings.add(chat_id)
+ self.logger.warning(
+ "Signal DM recipient is UUID-only (no phone number in envelope). "
+ "Some Signal clients may not render typing indicators for this recipient form."
+ )
+ candidate_params: list[dict[str, Any]]
+ if self._is_group_chat_id(chat_id):
+ candidate_params = [{"groupId": chat_id}, {"groupId": [chat_id]}]
+ else:
+ candidate_params = [{"recipient": chat_id}, {"recipient": [chat_id]}]
+
+ last_error: Any | None = None
+ for params in candidate_params:
+ if stop:
+ params["stop"] = True
+ try:
+ response = await self._send_request("sendTyping", params)
+ except Exception as e:
+ last_error = str(e)
+ continue
+
+ if "error" not in response:
+ if not quiet_success:
+ self.logger.info("Signal typing {} sent for {}", action, chat_id)
+ return
+
+ last_error = response["error"]
+
+ self.logger.warning(
+ "Failed to send Signal typing {} for {}: {}", action, chat_id, last_error
+ )
+
+ async def _ensure_typing_indicators_enabled(self) -> None:
+ """Enable typing indicators on the bot account."""
+ response = await self._send_request("updateConfiguration", {"typingIndicators": True})
+ if "error" in response:
+ self.logger.warning(
+ "Failed to enable Signal typing indicators: {}", response["error"]
+ )
+ else:
+ self.logger.info("Signal typing indicators enabled on account configuration")
+
+ async def _send_request(
+ self, method: str, params: dict[str, Any] | None = None
+ ) -> dict[str, Any]:
+ """Send a JSON-RPC request via HTTP and wait for response."""
+ # Generate request ID
+ self._request_id += 1
+ request_id = self._request_id
+
+ # Build JSON-RPC request
+ request = {"jsonrpc": "2.0", "method": method, "id": request_id}
+
+ if params:
+ request["params"] = params
+
+ return await self._send_http_request(request)
+
+ async def _send_http_request(self, request: dict[str, Any]) -> dict[str, Any]:
+ """Send JSON-RPC request via HTTP."""
+ if not self._http:
+ raise RuntimeError("Not connected to signal-cli daemon")
+
+ try:
+ response = await self._http.post("/api/v1/rpc", json=request)
+ response.raise_for_status()
+ return response.json()
+ except Exception as e:
+ self.logger.error("HTTP request failed: {}", e)
+ return {"error": {"message": str(e)}}
diff --git a/nanobot/channels/weixin.py b/nanobot/channels/weixin.py
index 41390f8b3..a75c897f4 100644
--- a/nanobot/channels/weixin.py
+++ b/nanobot/channels/weixin.py
@@ -79,6 +79,12 @@ BASE_INFO: dict[str, str] = {"channel_version": WEIXIN_CHANNEL_VERSION}
ERRCODE_SESSION_EXPIRED = -14
SESSION_PAUSE_DURATION_S = 60 * 60
+# iLink context_token is observed to expire server-side after ~90-160s of
+# agent inactivity (openclaw/openclaw#61174). Proactively refresh before
+# sending if the cached token is older than this threshold.
+CONTEXT_TOKEN_MAX_AGE_S = 60
+
+
# Retry constants (matching the reference plugin's monitor.ts)
MAX_CONSECUTIVE_FAILURES = 3
BACKOFF_DELAY_S = 30
@@ -159,6 +165,8 @@ class WeixinChannel(BaseChannel):
self._session_pause_until: float = 0.0
self._typing_tasks: dict[str, asyncio.Task] = {}
self._typing_tickets: dict[str, dict[str, Any]] = {}
+ self._context_token_at: dict[str, float] = {}
+ self._pending_tool_hints: dict[str, list[str]] = {}
# ------------------------------------------------------------------
# State persistence
@@ -486,6 +494,7 @@ class WeixinChannel(BaseChannel):
except Exception:
if not self._running:
break
+ self.logger.exception("WeChat poll loop error")
consecutive_failures += 1
if consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
consecutive_failures = 0
@@ -495,6 +504,7 @@ class WeixinChannel(BaseChannel):
async def stop(self) -> None:
self._running = False
+ self._pending_tool_hints.clear()
if self._poll_task and not self._poll_task.done():
self._poll_task.cancel()
for chat_id in list(self._typing_tasks):
@@ -545,6 +555,7 @@ class WeixinChannel(BaseChannel):
# Check for API-level errors (monitor.ts checks both ret and errcode)
ret = data.get("ret", 0)
errcode = data.get("errcode", 0)
+
is_error = (ret is not None and ret != 0) or (errcode is not None and errcode != 0)
if is_error:
@@ -575,8 +586,10 @@ class WeixinChannel(BaseChannel):
# Process messages (WeixinMessage[] from types.ts)
msgs: list[dict] = data.get("msgs", []) or []
for msg in msgs:
- with suppress(Exception):
+ try:
await self._process_message(msg)
+ except Exception:
+ self.logger.exception("Failed to process WeChat message")
# ------------------------------------------------------------------
# Inbound message processing (matches inbound.ts + process-message.ts)
@@ -610,6 +623,7 @@ class WeixinChannel(BaseChannel):
ctx_token = msg.get("context_token", "")
if ctx_token:
self._context_tokens[from_user_id] = ctx_token
+ self._context_token_at[from_user_id] = time.time()
self._save_state()
# Parse item_list (WeixinMessage.item_list — types.ts:161)
@@ -915,6 +929,99 @@ class WeixinChannel(BaseChannel):
}
return ""
+ async def _refresh_context_token_if_stale(
+ self, chat_id: str, context_token: str
+ ) -> str:
+ """Return a fresh context_token if the cached one is too old.
+
+ iLink context_token expires server-side after a short idle period
+ (empirically ~90s). Proactively refreshing before sending prevents
+ silent message loss on long agent turns or cron pushes.
+ """
+ if not context_token:
+ return context_token
+
+ now = time.time()
+ cached_at = self._context_token_at.get(chat_id, 0)
+ age = now - cached_at
+
+ if age < CONTEXT_TOKEN_MAX_AGE_S:
+ return context_token
+
+ self.logger.debug(
+ "WeChat context_token for {} is {:.0f}s old; refreshing via getconfig",
+ chat_id,
+ age,
+ )
+
+ body: dict[str, Any] = {
+ "ilink_user_id": chat_id,
+ "context_token": context_token,
+ "base_info": BASE_INFO,
+ }
+ try:
+ data = await self._api_post("ilink/bot/getconfig", body)
+ except Exception as e:
+ self.logger.warning("WeChat getconfig failed for {}: {}", chat_id, e)
+ return context_token
+
+ if data.get("ret", 0) != 0:
+ self.logger.warning(
+ "WeChat getconfig returned ret={} for {}: {}",
+ data.get("ret"),
+ chat_id,
+ data.get("errmsg", ""),
+ )
+ return context_token
+
+ new_token = str(data.get("context_token", "") or "")
+ if new_token and new_token != context_token:
+ self.logger.info(
+ "WeChat context_token refreshed for {} (age {:.0f}s -> fresh)",
+ chat_id,
+ age,
+ )
+ self._context_tokens[chat_id] = new_token
+ self._context_token_at[chat_id] = now
+ self._save_state()
+ return new_token
+
+ return context_token
+
+ async def _flush_tool_hints(self, chat_id: str) -> None:
+ """Send any buffered tool hints for *chat_id* as a single message.
+
+ Tool hints are coalesced to reduce message count and avoid hitting the
+ WeChat iLink rate limit (~7 msgs / 5 min). Failures are logged but
+ not raised so that the main message send is never blocked.
+ """
+ hints = self._pending_tool_hints.pop(chat_id, None)
+ if not hints:
+ return
+
+ self.logger.info(
+ "Flushing {} buffered tool hint(s) for {}",
+ len(hints),
+ chat_id,
+ )
+
+ ctx_token = self._context_tokens.get(chat_id, "")
+ ctx_token = await self._refresh_context_token_if_stale(chat_id, ctx_token)
+ if not ctx_token:
+ self.logger.warning(
+ "Dropped {} buffered tool hint(s) for {}: no context_token",
+ len(hints),
+ chat_id,
+ )
+ return
+
+ try:
+ await self._send_text(chat_id, "\n\n".join(hints), ctx_token)
+ except Exception:
+ self.logger.exception(
+ "Failed to flush buffered tool hints for {}", chat_id
+ )
+
async def _send_typing(self, user_id: str, typing_ticket: str, status: int) -> None:
"""Best-effort sendtyping wrapper."""
if not typing_ticket:
@@ -944,11 +1051,47 @@ class WeixinChannel(BaseChannel):
self._assert_session_active()
is_progress = bool((msg.metadata or {}).get("_progress", False))
+
+ # Buffer tool hints to coalesce consecutive ones and avoid burning
+ # WeChat iLink rate-limit quota (~7 msgs / 5 min).
+ if is_progress and (msg.metadata or {}).get("_tool_hint"):
+ if not self.send_tool_hints:
+ return
+ self._pending_tool_hints.setdefault(msg.chat_id, []).append(msg.content)
+ self.logger.debug(
+ "Buffered tool hint for {} (count={})",
+ msg.chat_id,
+ len(self._pending_tool_hints[msg.chat_id]),
+ )
+ return
+
+ # Reasoning deltas are invisible in WeChat (there is no reasoning
+ # UI). Skip them entirely — do not send and do not flush buffer.
+ if is_progress and (msg.metadata or {}).get("_reasoning_delta"):
+ self.logger.debug(
+ "Dropped invisible reasoning delta for {}", msg.chat_id
+ )
+ return
+
+ content = msg.content.strip()
+
+ # Empty progress messages (e.g. after_iteration tool_events) must
+ # NOT act as separators — they have no visible content.
+ if is_progress and not content and not (msg.media or []):
+ self.logger.debug(
+ "Skipped empty progress message for {} (no visible content)",
+ msg.chat_id,
+ )
+ return
+
+ # Flush buffered hints before sending any visible message.
+ await self._flush_tool_hints(msg.chat_id)
+
if not is_progress:
await self._stop_typing(msg.chat_id, clear_remote=True)
- content = msg.content.strip()
ctx_token = self._context_tokens.get(msg.chat_id, "")
+ ctx_token = await self._refresh_context_token_if_stale(msg.chat_id, ctx_token)
if not ctx_token:
raise RuntimeError(
f"WeChat context_token missing for chat_id={msg.chat_id}, cannot send"
@@ -1037,6 +1180,18 @@ class WeixinChannel(BaseChannel):
with suppress(Exception):
await self._send_typing(msg.chat_id, typing_ticket, TYPING_STATUS_CANCEL)
+ async def send_delta(
+ self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None
+ ) -> None:
+ """Weixin iLink does not support native streaming deltas.
+
+ We only hook ``_stream_end`` so buffered tool hints are flushed even
+ when the final answer carries the ``_streamed`` flag and bypasses
+ :meth:`send`.
+ """
+ if metadata and metadata.get("_stream_end"):
+ await self._flush_tool_hints(chat_id)
+
async def _start_typing(self, chat_id: str, context_token: str = "") -> None:
"""Start typing indicator immediately when a message is received."""
if not self._client or not self._token or not chat_id:
@@ -1120,10 +1275,11 @@ class WeixinChannel(BaseChannel):
}
data = await self._api_post("ilink/bot/sendmessage", body)
+ ret = data.get("ret", 0)
errcode = data.get("errcode", 0)
- if errcode and errcode != 0:
+ if (ret is not None and ret != 0) or (errcode is not None and errcode != 0):
raise RuntimeError(
- f"WeChat send text error (code {errcode}): {data.get('errmsg', '')}"
+ f"WeChat send text error (ret={ret}, errcode={errcode}): {data.get('errmsg', '')}"
)
async def _send_media_file(
@@ -1270,10 +1426,11 @@ class WeixinChannel(BaseChannel):
}
data = await self._api_post("ilink/bot/sendmessage", body)
+ ret = data.get("ret", 0)
errcode = data.get("errcode", 0)
- if errcode and errcode != 0:
+ if (ret is not None and ret != 0) or (errcode is not None and errcode != 0):
raise RuntimeError(
- f"WeChat send media error (code {errcode}): {data.get('errmsg', '')}"
+ f"WeChat send media error (ret={ret}, errcode={errcode}): {data.get('errmsg', '')}"
)
diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py
index c0ad7e758..2e094cc09 100644
--- a/nanobot/config/schema.py
+++ b/nanobot/config/schema.py
@@ -211,6 +211,7 @@ class ProvidersConfig(Base):
ant_ling: ProviderConfig = Field(default_factory=ProviderConfig) # Ant Ling
aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动)
+ novita: ProviderConfig = Field(default_factory=ProviderConfig) # Novita AI
volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎)
volcengine_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine Coding Plan
byteplus: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus (VolcEngine international)
diff --git a/nanobot/providers/image_generation.py b/nanobot/providers/image_generation.py
index 8f25195cf..8adde0f55 100644
--- a/nanobot/providers/image_generation.py
+++ b/nanobot/providers/image_generation.py
@@ -2,8 +2,10 @@
from __future__ import annotations
+import asyncio
import base64
import binascii
+import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
@@ -31,6 +33,14 @@ _AIHUBMIX_ASPECT_RATIO_SIZES = {
}
_GEMINI_DEFAULT_TIMEOUT_S = 120.0
_GEMINI_IMAGEN_ASPECT_RATIOS = {"1:1", "9:16", "16:9", "3:4", "4:3"}
+_OLLAMA_DEFAULT_SIDE = 1024
+_OLLAMA_SIZE_PRESETS = {
+ "1K": 1024,
+ "2K": 2048,
+ "4K": 4096,
+}
+_OLLAMA_EXPLICIT_SIZE_RE = re.compile(r"^\s*(\d+)\s*[xX]\s*(\d+)\s*$")
+_OLLAMA_ASPECT_RATIO_RE = re.compile(r"^\s*(\d+)\s*:\s*(\d+)\s*$")
class ImageGenerationError(RuntimeError):
@@ -438,6 +448,139 @@ def _http_error_detail(response: httpx.Response) -> str:
return response.text[:500] or ""
+def _round_to_multiple(value: float, multiple: int = 8) -> int:
+ rounded = int(round(value / multiple) * multiple)
+ return max(multiple, rounded)
+
+
+def _ollama_dimensions(aspect_ratio: str | None, image_size: str | None) -> tuple[int, int]:
+ if image_size:
+ size = image_size.strip()
+ explicit = _OLLAMA_EXPLICIT_SIZE_RE.fullmatch(size)
+ if explicit:
+ return int(explicit.group(1)), int(explicit.group(2))
+ long_side = _OLLAMA_SIZE_PRESETS.get(size.upper(), _OLLAMA_DEFAULT_SIDE)
+ else:
+ long_side = _OLLAMA_DEFAULT_SIDE
+
+ if not aspect_ratio:
+ return long_side, long_side
+
+ ratio = _OLLAMA_ASPECT_RATIO_RE.fullmatch(aspect_ratio.strip())
+ if ratio is None:
+ return long_side, long_side
+
+ width_ratio = int(ratio.group(1))
+ height_ratio = int(ratio.group(2))
+ if width_ratio <= 0 or height_ratio <= 0:
+ return long_side, long_side
+
+ if width_ratio >= height_ratio:
+ width = long_side
+ height = _round_to_multiple(long_side * height_ratio / width_ratio)
+ else:
+ height = long_side
+ width = _round_to_multiple(long_side * width_ratio / height_ratio)
+ return max(8, width), max(8, height)
+
+
+def _ollama_image_data_url(value: str) -> str:
+ if value.startswith("data:image/"):
+ return value
+ return _b64_image_data_url(value)
+
+
+def _ollama_images_from_payload(payload: dict[str, Any]) -> list[str]:
+ images: list[str] = []
+
+ def collect(value: Any) -> None:
+ if isinstance(value, str) and value:
+ images.append(_ollama_image_data_url(value))
+ elif isinstance(value, list):
+ for item in value:
+ collect(item)
+
+ collect(payload.get("image"))
+ collect(payload.get("images"))
+ return images
+
+
+class OllamaImageGenerationClient(ImageGenerationProvider):
+ """Async client for Ollama native image generation models."""
+
+ provider_name = "ollama"
+ default_timeout = 300.0
+
+ def _default_base_url(self) -> str:
+ return "http://localhost:11434/api"
+
+ def _resolve_base_url(self, api_base: str | None) -> str:
+ if api_base:
+ base = api_base.rstrip("/")
+ if base.endswith("/v1"):
+ return f"{base[:-3]}/api"
+ return base
+ return self._default_base_url()
+
+ async def generate(
+ self,
+ *,
+ prompt: str,
+ model: str,
+ reference_images: list[str] | None = None,
+ aspect_ratio: str | None = None,
+ image_size: str | None = None,
+ ) -> GeneratedImageResponse:
+ if reference_images:
+ raise ImageGenerationError(
+ "Ollama image generation does not support reference images"
+ )
+
+ width, height = _ollama_dimensions(aspect_ratio, image_size)
+ body: dict[str, Any] = {
+ "model": model,
+ "prompt": prompt,
+ "width": width,
+ "height": height,
+ "steps": 0,
+ }
+ body.update(self.extra_body)
+ body["stream"] = False
+
+ headers = {
+ "Content-Type": "application/json",
+ **self.extra_headers,
+ }
+ if self.api_key:
+ headers["Authorization"] = f"Bearer {self.api_key}"
+
+ url = f"{self.api_base}/generate"
+ response = await self._http_post(url, headers=headers, body=body)
+
+ try:
+ response.raise_for_status()
+ except httpx.HTTPStatusError as exc:
+ detail = _http_error_detail(response)
+ logger.error(
+ "Ollama image generation failed (HTTP {}): {}",
+ response.status_code,
+ detail,
+ )
+ raise ImageGenerationError(
+ f"Ollama image generation failed (HTTP {response.status_code}): {detail}"
+ ) from exc
+
+ data = response.json()
+ images = _ollama_images_from_payload(data)
+
+ self._require_images(images, data)
+
+ response_text = data.get("response")
+ content = response_text if isinstance(response_text, str) else ""
+
+ return GeneratedImageResponse(images=images, content=content, raw=data)
+
+
class GeminiImageGenerationClient(ImageGenerationProvider):
"""Async client for Gemini/Imagen image generation via the Generative Language API."""
@@ -759,6 +902,426 @@ def _minimax_images_from_payload(payload: dict[str, Any]) -> list[str]:
return images
+# ---------------------------------------------------------------------------
+# OpenAI image generation
+# ---------------------------------------------------------------------------
+
+_OPENAI_DALLE2_SUPPORTED_SIZES = {"256x256", "512x512", "1024x1024"}
+_OPENAI_DALLE3_SUPPORTED_SIZES = {"1024x1024", "1792x1024", "1024x1792"}
+_OPENAI_GPT_IMAGE_SUPPORTED_SIZES = {
+ "1024x1024",
+ "1536x1024",
+ "1024x1536",
+ "auto",
+}
+_OPENAI_DALLE2_ASPECT_RATIO_SIZES = {
+ "1:1": "1024x1024",
+ "16:9": "1024x1024",
+ "9:16": "1024x1024",
+ "3:4": "1024x1024",
+ "4:3": "1024x1024",
+}
+_OPENAI_DALLE3_ASPECT_RATIO_SIZES = {
+ "1:1": "1024x1024",
+ "16:9": "1792x1024",
+ "9:16": "1024x1792",
+ "3:4": "1024x1792",
+ "4:3": "1792x1024",
+}
+_OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES = {
+ "1:1": "1024x1024",
+ "16:9": "1536x1024",
+ "9:16": "1024x1536",
+ "3:4": "1024x1536",
+ "4:3": "1536x1024",
+}
+
+
+class OpenAIImageGenerationClient(ImageGenerationProvider):
+ """OpenAI Images API using an API key (``providers.openai.apiKey``)."""
+
+ provider_name = "openai"
+ missing_key_message = (
+ "OpenAI API key is not configured. Set providers.openai.apiKey."
+ )
+
+ def _default_base_url(self) -> str:
+ return "https://api.openai.com/v1"
+
+ @staticmethod
+ def _strip_model_prefix(model: str) -> str:
+ """Remove ``openai/`` prefix if present (OpenRouter convention)."""
+ if model.startswith("openai/") or model.startswith("openai_codex/"):
+ return model.split("/", 1)[1]
+ return model
+
+ async def generate(
+ self,
+ *,
+ prompt: str,
+ model: str,
+ reference_images: list[str] | None = None,
+ aspect_ratio: str | None = None,
+ image_size: str | None = None,
+ ) -> GeneratedImageResponse:
+ if not self.api_key:
+ raise ImageGenerationError(self.missing_key_message)
+
+ if reference_images:
+ logger.warning(
+ "DALL-E models do not support reference images; "
+ "ignoring {} reference image(s) for {}",
+ len(reference_images),
+ model,
+ )
+
+ headers = {
+ "Authorization": f"Bearer {self.api_key}",
+ "Content-Type": "application/json",
+ **self.extra_headers,
+ }
+
+ clean_model = self._strip_model_prefix(model)
+ body: dict[str, Any] = {
+ "model": clean_model,
+ "prompt": prompt,
+ }
+
+ if not _openai_is_gpt_image_model(clean_model):
+ body["response_format"] = "b64_json"
+ body["n"] = 1
+
+ size = _openai_size(clean_model, aspect_ratio, image_size)
+ if size:
+ body["size"] = size
+
+ body.update(self.extra_body)
+
+ logger.info("OpenAI Images API request: POST {}/images/generations body={}", self.api_base, body)
+
+ response = await self._http_post(
+ f"{self.api_base}/images/generations",
+ headers=headers,
+ body=body,
+ )
+
+ try:
+ response.raise_for_status()
+ except httpx.HTTPStatusError as exc:
+ detail = response.text[:1000]
+ logger.error("OpenAI Images API error ({}): {}", response.status_code, detail)
+ raise ImageGenerationError(
+ f"OpenAI image generation failed (HTTP {response.status_code}): {detail}"
+ ) from exc
+
+ payload = response.json()
+ logger.info("OpenAI Images API response ({}): {}", response.status_code,
+ {k: v for k, v in payload.items() if k != "data"})
+
+ client = self._client
+ owns_client = client is None
+ if owns_client:
+ client = httpx.AsyncClient(timeout=self.timeout)
+ try:
+ images = await _openai_images_from_payload(client, payload)
+ finally:
+ if owns_client:
+ await client.aclose()
+
+ self._require_images(images, payload)
+
+ return GeneratedImageResponse(images=images, content="", raw=payload)
+
+
+# ---------------------------------------------------------------------------
+# OpenAI Codex image generation
+# ---------------------------------------------------------------------------
+
+
+class CodexImageGenerationClient(ImageGenerationProvider):
+ """OpenAI image generation via Codex subscription OAuth.
+
+ Uses the Codex Responses API with the ``image_generation`` tool
+ (the same mechanism ChatGPT uses internally). No API key required —
+ the Codex OAuth token from ``oauth_cli_kit`` is used instead.
+ """
+
+ provider_name = "openai_codex"
+ missing_key_message = (
+ "Codex OAuth token is unavailable. "
+ "Log in with Codex subscription first."
+ )
+
+ def _default_base_url(self) -> str:
+ return "https://chatgpt.com/backend-api"
+
+ def _codex_model(self, model: str) -> str:
+ """Strip the ``openai-codex/`` prefix if present."""
+ if model.startswith(("openai-codex/", "openai_codex/")):
+ return model.split("/", 1)[1]
+ return model
+
+ async def generate(
+ self,
+ *,
+ prompt: str,
+ model: str,
+ reference_images: list[str] | None = None,
+ aspect_ratio: str | None = None,
+ image_size: str | None = None,
+ ) -> GeneratedImageResponse:
+ try:
+ from oauth_cli_kit import get_token as get_codex_token
+ except ImportError:
+ raise ImageGenerationError(self.missing_key_message)
+
+ try:
+ token = await asyncio.to_thread(get_codex_token)
+ except Exception as exc:
+ raise ImageGenerationError(self.missing_key_message) from exc
+ if not token or not token.access:
+ raise ImageGenerationError(self.missing_key_message)
+
+ logger.info(
+ "Using Codex OAuth token for image generation (account: {})",
+ token.account_id,
+ )
+
+ if reference_images:
+ logger.warning(
+ "Codex image generation does not support reference images; "
+ "ignoring {} reference image(s)",
+ len(reference_images),
+ )
+
+ headers = {
+ "Authorization": f"Bearer {token.access}",
+ "chatgpt-account-id": token.account_id,
+ "OpenAI-Beta": "responses=experimental",
+ "originator": "nanobot",
+ "User-Agent": "nanobot (python)",
+ "Content-Type": "application/json",
+ **self.extra_headers,
+ }
+
+ body: dict[str, Any] = {
+ "model": self._codex_model(model),
+ "instructions": "Generate an image based on the user's request.",
+ "input": [{"role": "user", "content": prompt}],
+ "tools": [{"type": "image_generation"}],
+ "tool_choice": "auto",
+ "stream": True,
+ "store": False,
+ }
+ body.update(self.extra_body)
+
+ logger.info("Codex Responses API request: POST {}/codex/responses body={}",
+ self.api_base, {k: v for k, v in body.items() if k != "input"})
+
+ response = await self._http_post(
+ f"{self.api_base}/codex/responses",
+ headers=headers,
+ body=body,
+ )
+
+ try:
+ response.raise_for_status()
+ except httpx.HTTPStatusError as exc:
+ detail = response.text[:1000]
+ logger.error("Codex Responses API error ({}): {}", response.status_code, detail)
+ raise ImageGenerationError(
+ f"Codex image generation failed (HTTP {response.status_code}): {detail}"
+ ) from exc
+
+ images, content_text = await _parse_codex_sse_images(response)
+
+ raw = {"status": "completed"}
+ self._require_images(images, raw)
+
+ return GeneratedImageResponse(images=images, content=content_text, raw=raw)
+
+
+def _openai_size(
+ model: str,
+ aspect_ratio: str | None,
+ image_size: str | None,
+) -> str:
+ """Resolve aspect ratio or image_size to an OpenAI Images API size string."""
+ sizes, supported_sizes = _openai_size_options(model)
+ explicit_size = _normalize_openai_image_size(image_size)
+ if explicit_size and _openai_explicit_size_supported(
+ explicit_size,
+ supported_sizes=supported_sizes,
+ ):
+ return explicit_size
+ if explicit_size:
+ logger.warning(
+ "OpenAI image size '{}' is not supported by {}; using aspect ratio/default size",
+ explicit_size,
+ model,
+ )
+ if aspect_ratio and aspect_ratio in sizes:
+ return sizes[aspect_ratio]
+ return "1024x1024"
+
+
+def _openai_is_gpt_image_model(model: str) -> bool:
+ normalized = model.lower()
+ return normalized.startswith(("gpt-image", "chatgpt-image"))
+
+
+def _openai_size_options(model: str) -> tuple[dict[str, str], set[str] | None]:
+ normalized = model.lower()
+ if normalized.startswith("dall-e-2"):
+ return _OPENAI_DALLE2_ASPECT_RATIO_SIZES, _OPENAI_DALLE2_SUPPORTED_SIZES
+ if normalized.startswith("dall-e-3"):
+ return _OPENAI_DALLE3_ASPECT_RATIO_SIZES, _OPENAI_DALLE3_SUPPORTED_SIZES
+ if normalized.startswith("gpt-image-2"):
+ return _OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES, None
+ return _OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES, _OPENAI_GPT_IMAGE_SUPPORTED_SIZES
+
+
+def _normalize_openai_image_size(image_size: str | None) -> str | None:
+ if not image_size:
+ return None
+ normalized = image_size.strip().lower()
+ return normalized or None
+
+
+def _openai_explicit_size_supported(
+ size: str,
+ *,
+ supported_sizes: set[str] | None,
+) -> bool:
+ if supported_sizes is not None:
+ return size in supported_sizes
+ width, sep, height = size.partition("x")
+ return bool(sep and width.isdecimal() and height.isdecimal())
+
+
+async def _openai_images_from_payload(
+ client: httpx.AsyncClient,
+ payload: dict[str, Any],
+) -> list[str]:
+ """Extract images from OpenAI Images API response.
+
+ Handles both ``b64_json`` (preferred) and ``url`` (downloaded) formats.
+ """
+ images: list[str] = []
+ for item in payload.get("data") or []:
+ if not isinstance(item, dict):
+ continue
+ b64 = item.get("b64_json")
+ if isinstance(b64, str) and b64:
+ images.append(_b64_image_data_url(b64))
+ continue
+ url = item.get("url")
+ if isinstance(url, str) and url:
+ images.append(await _download_image_data_url(client, url))
+ return images
+
+
+def _codex_responses_images_from_payload(payload: dict[str, Any]) -> list[str]:
+ """Extract images from Codex Responses API ``image_generation_call`` output."""
+ images: list[str] = []
+ for item in payload.get("output") or []:
+ if not isinstance(item, dict):
+ continue
+ if item.get("type") != "image_generation_call":
+ continue
+ result = item.get("result")
+ if isinstance(result, str):
+ images.append(result if result.startswith("data:image/") else _b64_image_data_url(result))
+ continue
+ if isinstance(result, dict):
+ image_url = result.get("image_url") or result.get("image") or ""
+ if isinstance(image_url, str):
+ images.append(image_url if image_url.startswith("data:image/") else _b64_image_data_url(image_url))
+ return images
+
+
+async def _parse_codex_sse_images(
+ response: httpx.Response,
+) -> tuple[list[str], str]:
+ """Parse a Codex Responses API SSE stream for image generation output.
+
+ Returns ``(images, content_text)``.
+ """
+ import json as _json
+
+ images: list[str] = []
+ text_parts: list[str] = []
+
+ buffer: list[str] = []
+ async for line_bytes in response.aiter_lines():
+ line = line_bytes.strip()
+ if line == "":
+ if buffer:
+ data_lines = []
+ for bl in buffer:
+ if bl.startswith("data:"):
+ data_lines.append(bl[5:].strip())
+ buffer.clear()
+ if data_lines:
+ raw = "".join(data_lines)
+ if raw == "[DONE]":
+ break
+ try:
+ event = _json.loads(raw)
+ except Exception:
+ continue
+ ev_type = event.get("type", "")
+ if ev_type in ("error", "response.failed"):
+ logger.error("Codex SSE failure: {}", raw[:2000])
+ _collect_images_from_sse_event(event, images)
+ _collect_text_from_sse_event(event, text_parts)
+ continue
+ buffer.append(line)
+
+ # flush remaining
+ if buffer:
+ data_lines = [bl[5:].strip() for bl in buffer if bl.startswith("data:")]
+ raw = "".join(data_lines)
+ if raw and raw != "[DONE]":
+ try:
+ event = _json.loads(raw)
+ except Exception:
+ pass
+ else:
+ _collect_images_from_sse_event(event, images)
+ _collect_text_from_sse_event(event, text_parts)
+
+ return images, "".join(text_parts).strip()
+
+
+def _collect_images_from_sse_event(event: dict[str, Any], images: list[str]) -> None:
+ if event.get("type") != "response.output_item.done":
+ return
+ item = event.get("item") or {}
+ if item.get("type") != "image_generation_call":
+ return
+ result = item.get("result")
+ if isinstance(result, str):
+ if result.startswith("data:image/"):
+ images.append(result)
+ else:
+ images.append(_b64_image_data_url(result))
+ elif isinstance(result, dict):
+ image_url = result.get("image_url") or result.get("image") or ""
+ if isinstance(image_url, str):
+ if image_url.startswith("data:image/"):
+ images.append(image_url)
+ else:
+ images.append(_b64_image_data_url(image_url))
+
+
+def _collect_text_from_sse_event(event: dict[str, Any], text_parts: list[str]) -> None:
+ if event.get("type") == "response.output_text.delta":
+ delta = event.get("delta")
+ if isinstance(delta, str) and delta:
+ text_parts.append(delta)
+
+
# ---------------------------------------------------------------------------
# StepFun (阶跃星辰) image generation
# ---------------------------------------------------------------------------
@@ -886,8 +1449,11 @@ def _stepfun_images_from_payload(payload: dict[str, Any]) -> list[str]:
# Provider registration
# ---------------------------------------------------------------------------
-register_image_gen_provider(OpenRouterImageGenerationClient)
register_image_gen_provider(AIHubMixImageGenerationClient)
+register_image_gen_provider(CodexImageGenerationClient)
register_image_gen_provider(GeminiImageGenerationClient)
+register_image_gen_provider(OllamaImageGenerationClient)
register_image_gen_provider(MiniMaxImageGenerationClient)
+register_image_gen_provider(OpenAIImageGenerationClient)
+register_image_gen_provider(OpenRouterImageGenerationClient)
register_image_gen_provider(StepFunImageGenerationClient)
diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py
index b8112b529..8281d7d20 100644
--- a/nanobot/providers/openai_compat_provider.py
+++ b/nanobot/providers/openai_compat_provider.py
@@ -11,6 +11,7 @@ import secrets
import string
import time
import uuid
+from collections import deque
from collections.abc import Awaitable, Callable
from ipaddress import ip_address
from typing import TYPE_CHECKING, Any
@@ -74,41 +75,43 @@ _THINKING_STYLE_MAP: dict[str, Any] = {
"enable_thinking": lambda on: {"enable_thinking": on},
"reasoning_split": lambda on: {"reasoning_split": on},
}
+_GATEWAY_REASONING_STYLE_MAP: dict[str, Any] = {
+ "reasoning_effort": lambda effort: {"reasoning": {"effort": effort}},
+}
+_MODEL_THINKING_STYLES: dict[str, str] = {
+ **dict.fromkeys(_KIMI_THINKING_MODELS, "thinking_type"),
+ **dict.fromkeys(_MIMO_THINKING_MODELS, "thinking_type"),
+}
-def _is_kimi_thinking_model(model_name: str) -> bool:
- """Return True if model_name refers to a Kimi thinking-capable model.
-
- Supports two forms:
- - Exact match: e.g. kimi-k2.5 / kimi-k2.6 in _KIMI_THINKING_MODELS
- - Slug match: moonshotai/kimi-k2.5 -> the part after the last "/"
- is checked against _KIMI_THINKING_MODELS
-
- This covers both the native Moonshot provider (bare slug) and
- OpenRouter-style names (``"publisher/slug"``).
- """
- name = model_name.lower()
- if name in _KIMI_THINKING_MODELS:
- return True
- if "/" in name and name.rsplit("/", 1)[1] in _KIMI_THINKING_MODELS:
- return True
- return False
+def _model_slug(model_name: str) -> str:
+ return model_name.lower().rsplit("/", 1)[-1]
-def _is_mimo_thinking_model(model_name: str) -> bool:
- """Return True if model_name refers to a MiMo thinking-capable model.
+def _model_thinking_style(model_name: str) -> str:
+ return _MODEL_THINKING_STYLES.get(_model_slug(model_name), "")
- Mirrors _is_kimi_thinking_model: gateway providers (e.g. OpenRouter
- routing ``xiaomi/mimo-v2.5-pro``) have no ``thinking_style`` on their
- spec, so the spec-driven branch in _build_kwargs misses them. The
- model-name path catches those cases.
- """
- name = model_name.lower()
- if name in _MIMO_THINKING_MODELS:
- return True
- if "/" in name and name.rsplit("/", 1)[1] in _MIMO_THINKING_MODELS:
- return True
- return False
+
+def _thinking_styles_for(spec: ProviderSpec | None, model_name: str) -> list[str]:
+ styles: list[str] = []
+ if spec and spec.thinking_style:
+ styles.append(spec.thinking_style)
+ model_style = _model_thinking_style(model_name)
+ if model_style and model_style not in styles:
+ styles.append(model_style)
+ return styles
+
+
+def _thinking_extra_body(style: str, thinking_enabled: bool) -> dict[str, Any] | None:
+ builder = _THINKING_STYLE_MAP.get(style)
+ return builder(thinking_enabled) if builder else None
+
+
+def _gateway_reasoning_extra_body(style: str, effort: str | None) -> dict[str, Any] | None:
+ if not effort:
+ return None
+ builder = _GATEWAY_REASONING_STYLE_MAP.get(style)
+ return builder(effort) if builder else None
def _openai_compat_timeout_s() -> float:
@@ -461,6 +464,7 @@ class OpenAICompatProvider(LLMProvider):
"""Strip non-standard keys, normalize tool_call IDs."""
sanitized = LLMProvider._sanitize_request_messages(messages, _ALLOWED_MSG_KEYS)
id_map: dict[str, str] = {}
+ pending_tool_ids: dict[str, deque[str]] = {}
force_string_content = bool(self._spec and self._spec.name == "deepseek")
def map_id(value: Any) -> Any:
@@ -468,15 +472,49 @@ class OpenAICompatProvider(LLMProvider):
return value
return id_map.setdefault(value, self._normalize_tool_call_id(value))
+ def unique_tool_id(value: Any, used_ids: set[str], idx: int) -> str:
+ if isinstance(value, str) and value:
+ base = map_id(value)
+ else:
+ base = _short_tool_id()
+ if not isinstance(base, str) or not base:
+ base = _short_tool_id()
+ if base not in used_ids:
+ return base
+ seed = value if isinstance(value, str) and value else base
+ salt = 1
+ while True:
+ candidate = self._normalize_tool_call_id(f"{seed}:{idx}:{salt}")
+ if isinstance(candidate, str) and candidate not in used_ids:
+ return candidate
+ salt += 1
+
+ def map_tool_result_id(value: Any) -> Any:
+ if not isinstance(value, str):
+ return value
+ queue = pending_tool_ids.get(value)
+ if queue:
+ mapped = queue.popleft()
+ if not queue:
+ pending_tool_ids.pop(value, None)
+ return mapped
+ return map_id(value)
+
for clean in sanitized:
if isinstance(clean.get("tool_calls"), list):
normalized = []
- for tc in clean["tool_calls"]:
+ used_ids: set[str] = set()
+ for idx, tc in enumerate(clean["tool_calls"]):
if not isinstance(tc, dict):
normalized.append(tc)
continue
tc_clean = dict(tc)
- tc_clean["id"] = map_id(tc_clean.get("id"))
+ raw_id = tc_clean.get("id")
+ mapped_id = unique_tool_id(raw_id, used_ids, idx)
+ tc_clean["id"] = mapped_id
+ used_ids.add(mapped_id)
+ if isinstance(raw_id, str) and raw_id:
+ pending_tool_ids.setdefault(raw_id, deque()).append(mapped_id)
function = tc_clean.get("function")
if isinstance(function, dict):
function_clean = dict(function)
@@ -494,7 +532,7 @@ class OpenAICompatProvider(LLMProvider):
# that mix non-empty content with tool_calls.
clean["content"] = None
if "tool_call_id" in clean and clean["tool_call_id"]:
- clean["tool_call_id"] = map_id(clean["tool_call_id"])
+ clean["tool_call_id"] = map_tool_result_id(clean["tool_call_id"])
if (
force_string_content
and not (clean.get("role") == "assistant" and clean.get("tool_calls"))
@@ -581,39 +619,27 @@ class OpenAICompatProvider(LLMProvider):
if wire_effort and semantic_effort != "none":
kwargs["reasoning_effort"] = wire_effort
- # Provider-specific thinking parameters.
- # Only sent when reasoning_effort is explicitly configured so that
- # the provider default is preserved otherwise.
- # The mapping is driven by ProviderSpec.thinking_style so that adding
- # a new provider never requires touching this function.
- if spec and spec.thinking_style and reasoning_effort is not None:
+ # Only send thinking controls when reasoning_effort is explicit so
+ # omitting the config preserves each provider's default.
+ if reasoning_effort is not None:
thinking_enabled = semantic_effort not in ("none", "minimal")
- extra = _THINKING_STYLE_MAP.get(spec.thinking_style, lambda _: None)(thinking_enabled)
- if extra:
- kwargs.setdefault("extra_body", {}).update(extra)
+ for thinking_style in _thinking_styles_for(spec, model_name):
+ extra = _thinking_extra_body(thinking_style, thinking_enabled)
+ if extra:
+ kwargs.setdefault("extra_body", {}).update(extra)
+ gateway_style = getattr(spec, "gateway_reasoning_style", "") if spec else ""
+ if gateway_style and _model_thinking_style(model_name):
+ extra = _gateway_reasoning_extra_body(gateway_style, semantic_effort)
+ if extra:
+ kwargs.setdefault("extra_body", {}).update(extra)
- # Model-level thinking injection for Kimi thinking-capable models.
- # Strip any provider prefix (e.g. "moonshotai/") before the set lookup
- # so that OpenRouter-style names like "moonshotai/kimi-k2.5" are handled
- # identically to bare names like "kimi-k2.5".
- if reasoning_effort is not None and _is_kimi_thinking_model(model_name):
- thinking_enabled = semantic_effort not in ("none", "minimal")
- kwargs.setdefault("extra_body", {}).update(
- {"thinking": {"type": "enabled" if thinking_enabled else "disabled"}}
- )
-
- # Model-level thinking injection for MiMo thinking-capable models.
- # Same shape as Kimi: gateway providers (OpenRouter, etc.) lack the
- # xiaomi_mimo spec's thinking_style, so the spec-driven branch above
- # misses them — match by model name to catch "xiaomi/mimo-v2.5-pro"
- # and friends. (Direct xiaomi_mimo requests are also covered here;
- # both branches write the same payload, so the dict update is a
- # safe no-op for already-handled cases.)
- if reasoning_effort is not None and _is_mimo_thinking_model(model_name):
- thinking_enabled = semantic_effort not in ("none", "minimal")
- kwargs.setdefault("extra_body", {}).update(
- {"thinking": {"type": "enabled" if thinking_enabled else "disabled"}}
- )
+ # Moonshot rejects requests that carry both 'reasoning_effort'
+ # and the native 'thinking' param. We already expressed the
+ # user's intent via the provider-native shape, so drop the
+ # redundant wire-level kwarg. Only kimi models need this —
+ # Xiaomi's API accepts both params.
+ if _model_slug(model_name) in _KIMI_THINKING_MODELS:
+ kwargs.pop("reasoning_effort", None)
if tools:
kwargs["tools"] = tools
@@ -628,8 +654,7 @@ class OpenAICompatProvider(LLMProvider):
and semantic_effort not in ("none", "minimal")
and (
(spec and spec.thinking_style)
- or _is_kimi_thinking_model(model_name)
- or _is_mimo_thinking_model(model_name)
+ or _model_thinking_style(model_name)
)
)
implicit_deepseek_thinking = (
@@ -1097,6 +1122,15 @@ class OpenAICompatProvider(LLMProvider):
if delta:
_accum_legacy_function_call(getattr(delta, "function_call", None))
+ # Some providers (e.g. Zhipu/GLM) reuse the same tool_call id for
+ # parallel tool calls in streaming mode. Deduplicate before building
+ # the response so downstream tool messages don't collide.
+ _seen_tc_ids: set[str] = set()
+ for b in tc_bufs.values():
+ if not b["id"] or b["id"] in _seen_tc_ids:
+ b["id"] = _short_tool_id()
+ _seen_tc_ids.add(b["id"])
+
return LLMResponse(
content="".join(content_parts) or None,
tool_calls=[
diff --git a/nanobot/providers/openai_responses/converters.py b/nanobot/providers/openai_responses/converters.py
index e0bfe832d..27c59ab58 100644
--- a/nanobot/providers/openai_responses/converters.py
+++ b/nanobot/providers/openai_responses/converters.py
@@ -15,6 +15,7 @@ def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str
"""
system_prompt = ""
input_items: list[dict[str, Any]] = []
+ used_item_ids: set[str] = set()
for idx, msg in enumerate(messages):
role = msg.get("role")
@@ -30,17 +31,19 @@ def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str
if role == "assistant":
if isinstance(content, str) and content:
+ message_id = _unique_item_id(f"msg_{idx}", used_item_ids)
input_items.append({
"type": "message", "role": "assistant",
"content": [{"type": "output_text", "text": content}],
- "status": "completed", "id": f"msg_{idx}",
+ "status": "completed", "id": message_id,
})
for tool_call in msg.get("tool_calls", []) or []:
fn = tool_call.get("function") or {}
call_id, item_id = split_tool_call_id(tool_call.get("id"))
+ response_item_id = _unique_item_id(item_id or f"fc_{idx}", used_item_ids)
input_items.append({
"type": "function_call",
- "id": item_id or f"fc_{idx}",
+ "id": response_item_id,
"call_id": call_id or f"call_{idx}",
"name": fn.get("name"),
"arguments": fn.get("arguments") or "{}",
@@ -97,6 +100,20 @@ def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
return converted
+def _unique_item_id(item_id: str, used: set[str]) -> str:
+ """Return a Responses input item id that is unique within one request."""
+ if item_id not in used:
+ used.add(item_id)
+ return item_id
+
+ suffix = 2
+ while f"{item_id}_{suffix}" in used:
+ suffix += 1
+ unique = f"{item_id}_{suffix}"
+ used.add(unique)
+ return unique
+
+
def split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]:
"""Split a compound ``call_id|item_id`` string.
diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py
index 7c8edd271..ab7e2cf1e 100644
--- a/nanobot/providers/registry.py
+++ b/nanobot/providers/registry.py
@@ -71,6 +71,11 @@ class ProviderSpec:
# "reasoning_split" — {"reasoning_split": true/false} (MiniMax)
thinking_style: str = ""
+ # Gateway-native reasoning control to pair with model-level thinking styles.
+ # "reasoning_effort" — {"reasoning": {"effort": }}
+ # (OpenRouter)
+ gateway_reasoning_style: str = ""
+
# When True, treat the "reasoning" response field as formal content
# when "content" is empty. Only set this for providers (e.g. StepFun)
# whose API returns the actual answer in "reasoning" instead of "content".
@@ -142,6 +147,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
detect_by_base_keyword="openrouter",
default_api_base="https://openrouter.ai/api/v1",
supports_prompt_caching=True,
+ gateway_reasoning_style="reasoning_effort",
),
# Hugging Face Inference Providers: OpenAI-compatible router for chat models.
ProviderSpec(
@@ -193,6 +199,18 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
default_api_base="https://api.siliconflow.cn/v1",
),
+ # Novita AI: OpenAI-compatible gateway for hosted model APIs.
+ ProviderSpec(
+ name="novita",
+ keywords=("novita",),
+ env_key="NOVITA_API_KEY",
+ display_name="Novita AI",
+ backend="openai_compat",
+ is_gateway=True,
+ detect_by_base_keyword="novita",
+ default_api_base="https://api.novita.ai/openai",
+ ),
+
# VolcEngine (火山引擎): OpenAI-compatible gateway, pay-per-use models
ProviderSpec(
name="volcengine",
diff --git a/nanobot/templates/AGENTS.md b/nanobot/templates/AGENTS.md
index 0bf6de3d3..46cfc08c3 100644
--- a/nanobot/templates/AGENTS.md
+++ b/nanobot/templates/AGENTS.md
@@ -1,5 +1,9 @@
# Agent Instructions
+## Workspace Guidance
+
+Use this file for project-specific preferences, recurring workflow conventions, and instructions you want the agent to remember for this workspace. Keep durable facts about the user in `USER.md`, personality/style guidance in `SOUL.md`, and long-term memory in `memory/MEMORY.md`.
+
## Scheduled Reminders
Before scheduling reminders, check available skills and follow skill guidance first.
@@ -10,10 +14,10 @@ Get USER_ID and CHANNEL from the current session (e.g., `8281248569` and `telegr
## Heartbeat Tasks
-`HEARTBEAT.md` is checked on the configured heartbeat interval. Use file tools to manage periodic tasks:
+`HEARTBEAT.md` is checked on the configured heartbeat interval. Use file tools to manage periodic tasks.
-- **Add**: `edit_file` to append new tasks
-- **Remove**: `edit_file` to delete completed tasks
-- **Rewrite**: `write_file` to replace all tasks
+- Use `apply_patch` for normal task-list updates, especially when adding, removing, or changing multiple lines.
+- Use `edit_file` only for small exact replacements copied from the current `HEARTBEAT.md`.
+- Use `write_file` for first creation or intentional full-file rewrites.
When the user asks for a recurring/periodic task, update `HEARTBEAT.md` instead of creating a one-time cron reminder.
diff --git a/nanobot/templates/TOOLS.md b/nanobot/templates/TOOLS.md
deleted file mode 100644
index 374e49778..000000000
--- a/nanobot/templates/TOOLS.md
+++ /dev/null
@@ -1,28 +0,0 @@
-# Tool Usage Notes
-
-Tool signatures are provided automatically via function calling.
-This file documents non-obvious constraints and usage patterns.
-
-## exec — Safety Limits
-
-- Commands have a configurable timeout (default 60s)
-- Dangerous commands are blocked (rm -rf, format, dd, shutdown, etc.)
-- Output is truncated at 10,000 characters
-- `restrictToWorkspace` config can limit file access to the workspace
-
-## grep — Content Search
-
-- Use `grep` to search file contents inside the workspace
-- Default behavior returns only matching file paths (`output_mode="files_with_matches"`)
-- Supports optional `glob` filtering (e.g. `glob="*.py"`) plus `context_before` / `context_after`
-- Supports `type="py"`, `type="ts"`, `type="md"` and similar shorthand filters
-- Use `fixed_strings=true` for literal keywords containing regex characters
-- Use `output_mode="files_with_matches"` to get only matching file paths
-- Use `output_mode="count"` to size a search before reading full matches
-- Use `head_limit` and `offset` to page across results
-- Prefer this over `exec` for code and history searches
-- Binary or oversized files may be skipped to keep results readable
-
-## cron — Scheduled Reminders
-
-- Please refer to cron skill for usage.
diff --git a/nanobot/templates/agent/tool_contract.md b/nanobot/templates/agent/tool_contract.md
new file mode 100644
index 000000000..edbba21c9
--- /dev/null
+++ b/nanobot/templates/agent/tool_contract.md
@@ -0,0 +1,60 @@
+# Tool Usage Notes
+
+Tool signatures are provided automatically via function calling. This section
+documents the general tool contract and non-obvious usage patterns.
+
+## General Tool Contract
+
+- Use the narrowest structured tool that directly matches the task.
+- Use read-only discovery before writes when state is uncertain.
+- Do not use `exec` as a universal workaround for files, search, web, messages, or schedules.
+- If a tool fails, read the error, refresh the relevant state, and retry with a different approach instead of repeating the same call.
+- After meaningful changes, verify with the smallest reliable check: re-read changed state, run targeted tests, or inspect command output.
+- Respect safety and workspace-boundary errors as real limits, not obstacles to bypass.
+
+## Discovery and Reading
+
+- Use `find_files` or `list_dir` to locate workspace paths before `read_file` when a path is uncertain.
+- Use `grep` for content search inside the workspace; prefer it over shell grep for ordinary searches.
+- `grep` defaults to `output_mode="files_with_matches"`; use `output_mode="content"` for matching lines with context.
+- Use `fixed_strings=true` for literal keywords containing regex characters.
+- Use `output_mode="count"` to size a broad search before reading full matches.
+- Use `head_limit` and `offset` to page across large result sets.
+- Binary or oversized files may be skipped to keep results readable.
+
+## File and Coding Workflows
+
+- For code or config changes, the default loop is: locate (`find_files`/`grep`), inspect (`read_file`), edit (`apply_patch`), then verify (`exec` or re-read).
+- Use `apply_patch` as the default code editing tool, especially for multi-file changes, structural edits, generated code, moves, adds, or deletes.
+- Use `apply_patch dry_run=true` when the patch is uncertain and you want validation plus a change summary before writing.
+- Use `edit_file` only for small exact replacements in one file, with `old_text` copied from `read_file`; add `occurrence`, `line_hint`, or `expected_replacements` when ambiguity matters.
+- Use `write_file` for new files or intentional full-file rewrites, not routine partial edits.
+- If `apply_patch` or `edit_file` fails, re-read with `force=true`, narrow the context, and try a smaller patch rather than switching to shell `sed` or `echo`.
+
+## Process Execution
+
+- Use `exec` for tests, builds, package commands, git commands, and other process execution.
+- Prefer dedicated file/search tools over `cat`, shell `find`, shell `grep`, `sed`, or `echo` for ordinary workspace inspection and edits.
+- Use non-interactive flags such as `-y` or `--yes` when available.
+- Commands have a configurable timeout (default 60s), dangerous commands are blocked, and output is truncated.
+- For long-running or interactive commands, pass `yield_time_ms`; if the process keeps running, continue with `write_stdin`.
+- Use `write_stdin` to poll, provide stdin, close stdin, wait for expected output with `wait_for`, or terminate an existing exec session.
+- Use `list_exec_sessions` to recover active session IDs after context shifts.
+
+## Web and External Information
+
+- Use web tools when the user asks for current information, a specific URL, or information likely to have changed.
+- Use `web_search` to find sources and `web_fetch` for a specific page or result that needs closer reading.
+- Do not invent freshness-sensitive facts when tools can verify them.
+
+## Messaging and Media
+
+- Use `message` to send content or local media to the user/channel.
+- `read_file` only reads content for your analysis; it does not deliver a file to the user.
+- When sending an existing local file, attach it through the message/media mechanism instead of pasting file contents unless the user asked for text.
+
+## Scheduling and Background Work
+
+- Use `cron` for scheduled reminders or recurring jobs; do not run `nanobot cron` through `exec`.
+- For heartbeat tasks, update `HEARTBEAT.md` according to the agent instructions.
+- Do not write reminders only to memory files when the user expects an actual notification.
diff --git a/nanobot/utils/file_edit_events.py b/nanobot/utils/file_edit_events.py
index b5d2f6d73..fd929134d 100644
--- a/nanobot/utils/file_edit_events.py
+++ b/nanobot/utils/file_edit_events.py
@@ -3,15 +3,13 @@
from __future__ import annotations
import difflib
-import json
import re
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Awaitable, Callable
-
-TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file", "notebook_edit"})
+TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file", "apply_patch"})
_MAX_SNAPSHOT_BYTES = 2 * 1024 * 1024
_LIVE_EMIT_INTERVAL_S = 0.18
_LIVE_EMIT_LINE_STEP = 24
@@ -154,19 +152,108 @@ def prepare_file_edit_tracker(
workspace: Path | None,
params: dict[str, Any] | None,
) -> FileEditTracker | None:
+ trackers = prepare_file_edit_trackers(
+ call_id=call_id,
+ tool_name=tool_name,
+ tool=tool,
+ workspace=workspace,
+ params=params,
+ )
+ return trackers[0] if trackers else None
+
+
+def prepare_file_edit_trackers(
+ *,
+ call_id: str,
+ tool_name: str,
+ tool: Any,
+ workspace: Path | None,
+ params: dict[str, Any] | None,
+) -> list[FileEditTracker]:
if not is_file_edit_tool(tool_name):
- return None
+ return []
+ paths = resolve_file_edit_paths(tool_name, tool, workspace, params)
+ trackers: list[FileEditTracker] = []
+ seen: set[Path] = set()
+ for path in paths:
+ try:
+ resolved = path.resolve()
+ except Exception:
+ resolved = path
+ if resolved in seen:
+ continue
+ seen.add(resolved)
+ before = read_file_snapshot(path)
+ trackers.append(FileEditTracker(
+ call_id=str(call_id or ""),
+ tool=tool_name,
+ path=path,
+ display_path=display_file_edit_path(path, workspace),
+ before=before,
+ ))
+ return trackers
+
+
+def resolve_file_edit_paths(
+ tool_name: str,
+ tool: Any,
+ workspace: Path | None,
+ params: dict[str, Any] | None,
+) -> list[Path]:
+ if tool_name == "apply_patch":
+ return _resolve_apply_patch_paths(tool, workspace, params)
path = resolve_file_edit_path(tool, workspace, params)
if path is None:
- return None
- before = read_file_snapshot(path)
- return FileEditTracker(
- call_id=str(call_id or ""),
- tool=tool_name,
- path=path,
- display_path=display_file_edit_path(path, workspace),
- before=before,
- )
+ return []
+ return [path]
+
+
+def _resolve_apply_patch_paths(
+ tool: Any,
+ workspace: Path | None,
+ params: dict[str, Any] | None,
+) -> list[Path]:
+ if not isinstance(params, dict):
+ return []
+ edits = params.get("edits")
+ if not isinstance(edits, list) or not edits:
+ return []
+ if params.get("dry_run") is True:
+ return []
+
+ resolved: list[Path] = []
+ seen: set[Path] = set()
+ for edit in edits:
+ if not isinstance(edit, dict):
+ continue
+ raw_path = edit.get("path")
+ if not isinstance(raw_path, str) or not raw_path.strip():
+ continue
+ path = _resolve_raw_file_edit_path(tool, workspace, raw_path)
+ if path is not None and path not in seen:
+ seen.add(path)
+ resolved.append(path)
+ return resolved
+
+
+def _resolve_raw_file_edit_path(
+ tool: Any,
+ workspace: Path | None,
+ raw_path: str,
+) -> Path | None:
+ resolver = getattr(tool, "_resolve", None)
+ if callable(resolver):
+ try:
+ resolved = resolver(raw_path)
+ if isinstance(resolved, Path):
+ return resolved
+ if resolved:
+ return Path(resolved)
+ except Exception:
+ return None
+ if workspace is None:
+ return Path(raw_path).expanduser().resolve()
+ return (workspace / raw_path).expanduser().resolve()
def build_file_edit_start_event(
@@ -304,6 +391,9 @@ class StreamingFileEditTracker:
self._states[key] = state
state.apply_delta(payload)
+ if state.name == "apply_patch":
+ await self._update_apply_patch(state)
+ return
if state.name not in {"write_file", "edit_file"}:
return
if state.path is None:
@@ -343,10 +433,80 @@ class StreamingFileEditTracker:
deleted=deleted,
)])
+ async def _update_apply_patch(self, state: _StreamingFileEditState) -> None:
+ if _json_bool_true(state.arguments, "dry_run"):
+ return
+ tool = self._tools.get("apply_patch") if hasattr(self._tools, "get") else None
+ events: list[dict[str, Any]] = []
+ now = time.monotonic()
+
+ path_matches = list(re.finditer(r'"path"\s*:\s*"([^"]+)"', state.arguments))
+ if not path_matches:
+ return
+
+ for i, m in enumerate(path_matches):
+ raw_path = m.group(1)
+ path = _resolve_raw_file_edit_path(tool, self._workspace, raw_path)
+ if path is None:
+ continue
+
+ segment_start = m.start()
+ segment_end = path_matches[i + 1].start() if i + 1 < len(path_matches) else len(state.arguments)
+ segment = state.arguments[segment_start:segment_end]
+
+ action_match = re.search(r'"action"\s*:\s*"(replace|add|delete)"', segment)
+ action = action_match.group(1) if action_match else "replace"
+
+ old_text = _extract_json_string_prefix(segment, "old_text") or ""
+ new_text = _extract_json_string_prefix(segment, "new_text") or ""
+
+ added = _text_line_count(new_text) if action in ("replace", "add") else 0
+ deleted = _text_line_count(old_text) if action in ("replace", "delete") else 0
+ delete_file = action == "delete"
+
+ file_state = state.patch_files.get(raw_path)
+ if file_state is None:
+ tracker = FileEditTracker(
+ call_id=state.call_id or state.key,
+ tool="apply_patch",
+ path=path,
+ display_path=display_file_edit_path(path, self._workspace),
+ before=read_file_snapshot(path),
+ )
+ file_state = _StreamingPatchFileState(tracker=tracker)
+ state.patch_files[raw_path] = file_state
+ if delete_file and added == 0 and deleted == 0 and file_state.tracker.before.countable:
+ deleted = _text_line_count(file_state.tracker.before.text or "")
+ if not file_state.should_emit(added, deleted, now):
+ continue
+ file_state.mark_emitted(added, deleted, now)
+ events.append(build_file_edit_live_event(
+ file_state.tracker,
+ added=added,
+ deleted=deleted,
+ ))
+ if events:
+ await self._emit(events)
+
async def flush(self) -> None:
events: list[dict[str, Any]] = []
now = time.monotonic()
for state in self._states.values():
+ for file_state in state.patch_files.values():
+ added, deleted = file_state.last_added, file_state.last_deleted
+ if not file_state.emitted_once:
+ continue
+ if (
+ file_state.last_emitted_added == added
+ and file_state.last_emitted_deleted == deleted
+ ):
+ continue
+ file_state.mark_emitted(added, deleted, now)
+ events.append(build_file_edit_live_event(
+ file_state.tracker,
+ added=added,
+ deleted=deleted,
+ ))
if state.tracker is None:
continue
added, deleted = state.live_diff_counts()
@@ -367,12 +527,14 @@ class StreamingFileEditTracker:
def apply_final_call_ids(self, final_tool_calls: list[Any]) -> None:
"""Keep final start/end events keyed to any earlier streamed placeholder."""
+ used_canonicals: set[str] = set()
for tool_call in final_tool_calls:
canonical = self.canonical_call_id_for(tool_call)
- if canonical:
+ if canonical and canonical not in used_canonicals:
try:
tool_call.id = canonical
- except Exception:
+ used_canonicals.add(canonical)
+ except (AttributeError, TypeError):
pass
def canonical_call_id_for(self, tool_call: Any) -> str | None:
@@ -389,6 +551,10 @@ class StreamingFileEditTracker:
"""Mark streamed edits as failed when no final tool call will run."""
events: list[dict[str, Any]] = []
for state in self._states.values():
+ for file_state in state.patch_files.values():
+ if any(state.matches_final_tool_call(tool_call) for tool_call in final_tool_calls):
+ continue
+ events.append(build_file_edit_error_event(file_state.tracker, error))
if state.tracker is None:
continue
if any(state.matches_final_tool_call(tool_call) for tool_call in final_tool_calls):
@@ -492,6 +658,39 @@ class _StreamingJsonStringField:
self.last_char_cr = False
+@dataclass(slots=True)
+class _StreamingPatchFileState:
+ tracker: FileEditTracker
+ emitted_once: bool = False
+ last_emitted_added: int = -1
+ last_emitted_deleted: int = -1
+ last_emit_at: float = 0.0
+ last_added: int = 0
+ last_deleted: int = 0
+
+ def should_emit(self, added: int, deleted: int, now: float) -> bool:
+ self.last_added = added
+ self.last_deleted = deleted
+ if not self.emitted_once:
+ return True
+ if added == self.last_emitted_added and deleted == self.last_emitted_deleted:
+ return False
+ if max(
+ abs(added - self.last_emitted_added),
+ abs(deleted - self.last_emitted_deleted),
+ ) >= _LIVE_EMIT_LINE_STEP:
+ return True
+ return now - self.last_emit_at >= _LIVE_EMIT_INTERVAL_S
+
+ def mark_emitted(self, added: int, deleted: int, now: float) -> None:
+ self.emitted_once = True
+ self.last_added = added
+ self.last_deleted = deleted
+ self.last_emitted_added = added
+ self.last_emitted_deleted = deleted
+ self.last_emit_at = now
+
+
@dataclass(slots=True)
class _StreamingFileEditState:
key: str
@@ -509,6 +708,7 @@ class _StreamingFileEditState:
new_text: _StreamingJsonStringField = field(
default_factory=lambda: _StreamingJsonStringField("new_text")
)
+ patch_files: dict[str, _StreamingPatchFileState] = field(default_factory=dict)
emitted_once: bool = False
last_emitted_added: int = -1
last_emitted_deleted: int = -1
@@ -531,6 +731,7 @@ class _StreamingFileEditState:
self.content.reset()
self.old_text.reset()
self.new_text.reset()
+ self.patch_files.clear()
return
delta = payload.get("arguments_delta")
if isinstance(delta, str) and delta:
@@ -590,6 +791,14 @@ class _StreamingFileEditState:
name = getattr(tool_call, "name", None)
if name != self.name:
return False
+ if self.name == "apply_patch":
+ arguments = getattr(tool_call, "arguments", None)
+ if not isinstance(arguments, dict):
+ return False
+ edits = arguments.get("edits")
+ if not isinstance(edits, list):
+ return False
+ return '"edits"' in self.arguments
arguments = getattr(tool_call, "arguments", None)
if not isinstance(arguments, dict):
return False
@@ -612,6 +821,51 @@ def _stream_key(payload: dict[str, Any]) -> str:
return ""
+def _json_bool_true(source: str, key: str) -> bool:
+ return re.search(rf'"{re.escape(key)}"\s*:\s*true\b', source) is not None
+
+
+def _extract_json_string_prefix(source: str, key: str) -> str | None:
+ match = re.search(rf'"{re.escape(key)}"\s*:\s*"', source)
+ if match is None:
+ return None
+ out: list[str] = []
+ i = match.end()
+ escape = False
+ while i < len(source):
+ ch = source[i]
+ if escape:
+ escape = False
+ if ch == "n":
+ out.append("\n")
+ elif ch == "r":
+ out.append("\r")
+ elif ch == "t":
+ out.append("\t")
+ elif ch == "u":
+ digits = source[i + 1:i + 5]
+ if len(digits) < 4:
+ break
+ try:
+ out.append(chr(int(digits, 16)))
+ except ValueError:
+ break
+ i += 4
+ else:
+ out.append(ch)
+ i += 1
+ continue
+ if ch == "\\":
+ escape = True
+ i += 1
+ continue
+ if ch == '"':
+ return "".join(out)
+ out.append(ch)
+ i += 1
+ return "".join(out)
+
+
def _extract_complete_json_string(source: str, key: str) -> str | None:
match = re.search(rf'"{re.escape(key)}"\s*:\s*"', source)
if match is None:
@@ -704,77 +958,4 @@ def _predict_after_text(
return before_text.replace(old_text, new_text)
return before_text.replace(old_text, new_text, 1)
return None
- if tool_name == "notebook_edit":
- return _predict_notebook_after_text(params, before_text)
return None
-
-
-def _predict_notebook_after_text(params: dict[str, Any], before_text: str) -> str | None:
- try:
- nb = json.loads(before_text) if before_text.strip() else _empty_notebook()
- except Exception:
- return None
- cells = nb.get("cells")
- if not isinstance(cells, list):
- return None
- try:
- cell_index = int(params.get("cell_index", 0))
- except (TypeError, ValueError):
- return None
- new_source = params.get("new_source")
- source = new_source if isinstance(new_source, str) else ""
- cell_type = (
- params.get("cell_type") if params.get("cell_type") in ("code", "markdown") else "code"
- )
- mode = (
- params.get("edit_mode")
- if params.get("edit_mode") in ("replace", "insert", "delete")
- else "replace"
- )
- if mode == "delete":
- if 0 <= cell_index < len(cells):
- cells.pop(cell_index)
- else:
- return None
- elif mode == "insert":
- insert_at = min(max(cell_index + 1, 0), len(cells))
- cells.insert(insert_at, _new_notebook_cell(source, str(cell_type)))
- else:
- if not (0 <= cell_index < len(cells)):
- return None
- cell = cells[cell_index]
- if not isinstance(cell, dict):
- return None
- cell["source"] = source
- cell["cell_type"] = cell_type
- if cell_type == "code":
- cell.setdefault("outputs", [])
- cell.setdefault("execution_count", None)
- else:
- cell.pop("outputs", None)
- cell.pop("execution_count", None)
- nb["cells"] = cells
- try:
- return json.dumps(nb, indent=1, ensure_ascii=False)
- except Exception:
- return None
-
-
-def _empty_notebook() -> dict[str, Any]:
- return {
- "nbformat": 4,
- "nbformat_minor": 5,
- "metadata": {
- "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
- "language_info": {"name": "python"},
- },
- "cells": [],
- }
-
-
-def _new_notebook_cell(source: str, cell_type: str) -> dict[str, Any]:
- cell: dict[str, Any] = {"cell_type": cell_type, "source": source, "metadata": {}}
- if cell_type == "code":
- cell["outputs"] = []
- cell["execution_count"] = None
- return cell
diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py
index 2a969298c..ae91bf394 100644
--- a/nanobot/utils/helpers.py
+++ b/nanobot/utils/helpers.py
@@ -576,7 +576,7 @@ def build_status_content(
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
- """Sync bundled templates to workspace. Only creates missing files."""
+ """Sync bundled templates to workspace. Creates missing files without overwriting user files."""
from importlib.resources import files as pkg_files
try:
@@ -589,10 +589,11 @@ def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]
added: list[str] = []
def _write(src, dest: Path):
+ content = src.read_text(encoding="utf-8") if src else ""
if dest.exists():
return
dest.parent.mkdir(parents=True, exist_ok=True)
- dest.write_text(src.read_text(encoding="utf-8") if src else "", encoding="utf-8")
+ dest.write_text(content, encoding="utf-8")
added.append(str(dest.relative_to(workspace)))
for item in tpl.iterdir():
diff --git a/nanobot/utils/tool_hints.py b/nanobot/utils/tool_hints.py
index 272a19c9a..3a6460701 100644
--- a/nanobot/utils/tool_hints.py
+++ b/nanobot/utils/tool_hints.py
@@ -11,8 +11,10 @@ _TOOL_FORMATS: dict[str, tuple[list[str], str, bool, bool]] = {
"read_file": (["path", "file_path"], "read {}", True, False),
"write_file": (["path", "file_path"], "write {}", True, False),
"edit": (["file_path", "path"], "edit {}", True, False),
+ "find_files": (["query", "glob", "path"], "find {}", False, False),
"grep": (["pattern"], 'grep "{}"', False, False),
"exec": (["command"], "$ {}", False, True),
+ "list_exec_sessions": ([], "exec sessions", False, False),
"web_search": (["query"], 'search "{}"', False, False),
"web_fetch": (["url"], "fetch {}", True, False),
"list_dir": (["path"], "ls {}", True, False),
@@ -81,6 +83,8 @@ def _extract_arg(tc, key_args: list[str]) -> str | None:
def _fmt_known(tc, fmt: tuple, max_length: int = 40) -> str:
"""Format a registered tool using its template."""
+ if not fmt[0] and "{}" not in fmt[1]:
+ return fmt[1]
val = _extract_arg(tc, fmt[0])
if val is None:
return tc.name
diff --git a/nanobot/webui/settings_api.py b/nanobot/webui/settings_api.py
index a5ab13c5a..6d43e22c8 100644
--- a/nanobot/webui/settings_api.py
+++ b/nanobot/webui/settings_api.py
@@ -73,12 +73,16 @@ def _mask_secret_hint(secret: str | None) -> str | None:
def _provider_requires_api_key(spec: Any) -> bool:
if spec.backend == "azure_openai":
return True
+ if spec.is_oauth:
+ return False
if spec.is_local or spec.is_direct:
return False
return True
def _provider_configured_for_settings(spec: Any, provider_config: Any) -> bool:
+ if spec.is_oauth:
+ return True
if _provider_requires_api_key(spec):
return bool(provider_config.api_key)
return bool(
diff --git a/tests/agent/test_context_builder.py b/tests/agent/test_context_builder.py
index 0206d0986..a36c0a30a 100644
--- a/tests/agent/test_context_builder.py
+++ b/tests/agent/test_context_builder.py
@@ -139,6 +139,13 @@ class TestLoadBootstrapFiles:
for name in ContextBuilder.BOOTSTRAP_FILES:
assert f"## {name}" in result
+ def test_legacy_tools_md_is_not_bootstrapped(self, tmp_path):
+ (tmp_path / "TOOLS.md").write_text("workspace tool notes", encoding="utf-8")
+ builder = _builder(tmp_path)
+ result = builder._load_bootstrap_files()
+ assert "TOOLS.md" not in result
+ assert "workspace tool notes" not in result
+
def test_utf8_content(self, tmp_path):
(tmp_path / "AGENTS.md").write_text("用中文回复", encoding="utf-8")
builder = _builder(tmp_path)
@@ -171,6 +178,37 @@ class TestIsTemplateContent:
assert ContextBuilder._is_template_content("totally different", "memory/MEMORY.md") is False
+# ---------------------------------------------------------------------------
+# Bundled bootstrap templates
+# ---------------------------------------------------------------------------
+
+
+class TestBundledToolContract:
+ def test_tool_contract_balances_general_and_coding_workflows(self):
+ from importlib.resources import files as pkg_files
+
+ tpl = pkg_files("nanobot") / "templates" / "agent" / "tool_contract.md"
+ content = tpl.read_text(encoding="utf-8")
+
+ assert "## General Tool Contract" in content
+ assert "Use the narrowest structured tool" in content
+ assert "Do not use `exec` as a universal workaround" in content
+ assert "## File and Coding Workflows" in content
+ assert "apply_patch" in content
+ assert "## Web and External Information" in content
+ assert "## Messaging and Media" in content
+ assert "## Scheduling and Background Work" in content
+ assert "pure coding" not in content.lower()
+
+ def test_tool_contract_is_injected_without_workspace_file(self, tmp_path):
+ builder = _builder(tmp_path)
+ prompt = builder.build_system_prompt()
+
+ assert "# Tool Usage Notes" in prompt
+ assert "## General Tool Contract" in prompt
+ assert "Do not use `exec` as a universal workaround" in prompt
+
+
# ---------------------------------------------------------------------------
# _build_user_content
# ---------------------------------------------------------------------------
diff --git a/tests/agent/test_onboard_logic.py b/tests/agent/test_onboard_logic.py
index 11a284bb5..762da4f31 100644
--- a/tests/agent/test_onboard_logic.py
+++ b/tests/agent/test_onboard_logic.py
@@ -346,6 +346,26 @@ class TestSyncWorkspaceTemplates:
content = (workspace / "AGENTS.md").read_text()
assert content == "existing content"
+ def test_does_not_create_tools_md(self, tmp_path):
+ """Tool contract is injected internally, not copied into user workspaces."""
+ workspace = tmp_path / "workspace"
+
+ added = sync_workspace_templates(workspace, silent=True)
+
+ assert "TOOLS.md" not in added
+ assert not (workspace / "TOOLS.md").exists()
+
+ def test_preserves_existing_tools_md_without_overwriting(self, tmp_path):
+ """Legacy user workspaces may have TOOLS.md; sync should leave it untouched."""
+ workspace = tmp_path / "workspace"
+ workspace.mkdir(parents=True)
+ tools_path = workspace / "TOOLS.md"
+ tools_path.write_text("custom tool notes", encoding="utf-8")
+
+ sync_workspace_templates(workspace, silent=True)
+
+ assert tools_path.read_text(encoding="utf-8") == "custom tool notes"
+
def test_creates_memory_directory(self, tmp_path):
"""Should create memory directory structure."""
workspace = tmp_path / "workspace"
diff --git a/tests/channels/test_signal_channel.py b/tests/channels/test_signal_channel.py
new file mode 100644
index 000000000..277c85b83
--- /dev/null
+++ b/tests/channels/test_signal_channel.py
@@ -0,0 +1,1514 @@
+"""Tests for the Signal channel implementation."""
+
+from __future__ import annotations
+
+import asyncio
+from contextlib import asynccontextmanager
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+from nanobot.bus.events import InboundMessage, OutboundMessage
+from nanobot.bus.queue import MessageBus
+from nanobot.channels.signal import (
+ SignalChannel,
+ SignalConfig,
+ SignalDMConfig,
+ SignalGroupConfig,
+)
+
+# ---------------------------------------------------------------------------
+# Fake HTTP client
+# ---------------------------------------------------------------------------
+
+
+class _FakeResponse:
+ def __init__(self, status_code: int = 200, body: dict | None = None) -> None:
+ self.status_code = status_code
+ self._body = body or {}
+
+ def raise_for_status(self) -> None:
+ if self.status_code >= 400:
+ raise RuntimeError(f"HTTP {self.status_code}")
+
+ def json(self) -> dict:
+ return self._body
+
+
+class _FakeHTTPClient:
+ """Minimal httpx.AsyncClient stand-in that records requests."""
+
+ def __init__(self, *, default_response: dict | None = None) -> None:
+ self.posts: list[dict] = []
+ self.gets: list[str] = []
+ self._response = _FakeResponse(body=default_response or {"result": {"timestamp": 123}})
+ self.closed = False
+
+ async def get(self, path: str) -> _FakeResponse:
+ self.gets.append(path)
+ return self._response
+
+ async def post(self, path: str, *, json: dict) -> _FakeResponse:
+ self.posts.append({"path": path, "json": json})
+ return self._response
+
+ async def aclose(self) -> None:
+ self.closed = True
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _make_channel_with_capture(**overrides) -> tuple[SignalChannel, list[dict]]:
+ """Build a SignalChannel with _handle_message captured into a list and a
+ no-op _start_typing, used by every receive-flow test class.
+ """
+ ch = _make_channel(**overrides)
+ handled: list[dict] = []
+
+ async def capture(**kwargs):
+ handled.append(kwargs)
+
+ async def noop_typing(chat_id):
+ pass
+
+ ch._handle_message = capture # type: ignore[method-assign]
+ ch._start_typing = noop_typing # type: ignore[method-assign]
+ return ch, handled
+
+
+def _make_channel(
+ *,
+ phone_number: str = "+10000000000",
+ dm_enabled: bool = True,
+ dm_policy: str = "open",
+ dm_allow_from: list[str] | None = None,
+ group_enabled: bool = False,
+ group_policy: str = "open",
+ group_allow_from: list[str] | None = None,
+ require_mention: bool = True,
+ group_buffer_size: int = 20,
+ attachments_dir: str | None = None,
+) -> SignalChannel:
+ config = SignalConfig(
+ enabled=True,
+ phone_number=phone_number,
+ dm=SignalDMConfig(
+ enabled=dm_enabled,
+ policy=dm_policy,
+ allow_from=dm_allow_from or [],
+ ),
+ group=SignalGroupConfig(
+ enabled=group_enabled,
+ policy=group_policy,
+ allow_from=group_allow_from or [],
+ require_mention=require_mention,
+ ),
+ group_message_buffer_size=group_buffer_size,
+ attachments_dir=attachments_dir,
+ )
+ return SignalChannel(config, MessageBus())
+
+
+def _dm_envelope(
+ *,
+ source_number: str = "+19995550001",
+ source_uuid: str | None = None,
+ source_name: str | None = "Alice",
+ message: str = "hello",
+ attachments: list | None = None,
+ reaction: dict | None = None,
+ timestamp: int = 1000,
+) -> dict:
+ data_message: dict = {"message": message, "timestamp": timestamp}
+ if attachments is not None:
+ data_message["attachments"] = attachments
+ if reaction is not None:
+ data_message["reaction"] = reaction
+ envelope: dict = {
+ "sourceNumber": source_number,
+ "sourceName": source_name,
+ "dataMessage": data_message,
+ }
+ if source_uuid:
+ envelope["sourceUuid"] = source_uuid
+ return {"envelope": envelope}
+
+
+def _group_envelope(
+ *,
+ source_number: str = "+19995550001",
+ source_name: str = "Bob",
+ group_id: str = "group123==",
+ message: str = "hey group",
+ mentions: list | None = None,
+ timestamp: int = 2000,
+ use_v2: bool = False,
+) -> dict:
+ group_obj = {"groupId": group_id}
+ key = "groupV2" if use_v2 else "groupInfo"
+ data_message: dict = {
+ "message": message,
+ "timestamp": timestamp,
+ key: group_obj,
+ "mentions": mentions or [],
+ }
+ return {
+ "envelope": {
+ "sourceNumber": source_number,
+ "sourceName": source_name,
+ "dataMessage": data_message,
+ }
+ }
+
+
+# ---------------------------------------------------------------------------
+# Static utility tests
+# ---------------------------------------------------------------------------
+
+
+class TestNormalizeSignalId:
+ def test_phone_number_kept_and_stripped(self):
+ result = SignalChannel._normalize_signal_id("+12345678901")
+ assert "+12345678901" in result
+ assert "12345678901" in result
+
+ def test_digits_only_gets_plus_prefix(self):
+ result = SignalChannel._normalize_signal_id("12345678901")
+ assert "+12345678901" in result
+
+ def test_lowercase_variant_added(self):
+ result = SignalChannel._normalize_signal_id("SOME-UUID")
+ assert "some-uuid" in result
+
+ def test_empty_string_returns_empty(self):
+ assert SignalChannel._normalize_signal_id("") == []
+
+ def test_whitespace_stripped(self):
+ result = SignalChannel._normalize_signal_id(" +1234 ")
+ assert "+1234" in result
+
+
+class TestCollectSenderIdParts:
+ def test_collects_source_number(self):
+ env = {"sourceNumber": "+15551234567"}
+ parts = SignalChannel._collect_sender_id_parts(env)
+ assert "+15551234567" in parts
+
+ def test_collects_multiple_keys(self):
+ env = {"sourceNumber": "+15551234567", "sourceUuid": "uuid-abc"}
+ parts = SignalChannel._collect_sender_id_parts(env)
+ assert "+15551234567" in parts
+ assert "uuid-abc" in parts
+
+ def test_deduplicates(self):
+ env = {"sourceNumber": "+15551234567", "source": "+15551234567"}
+ parts = SignalChannel._collect_sender_id_parts(env)
+ assert parts.count("+15551234567") == 1
+
+ def test_ignores_non_string_values(self):
+ env = {"sourceNumber": 12345, "sourceUuid": None}
+ parts = SignalChannel._collect_sender_id_parts(env)
+ assert parts == []
+
+ def test_empty_envelope_returns_empty(self):
+ assert SignalChannel._collect_sender_id_parts({}) == []
+
+
+class TestPrimarySenderId:
+ def test_prefers_phone_number(self):
+ assert SignalChannel._primary_sender_id(["+1234", "uuid-abc"]) == "+1234"
+
+ def test_accepts_digit_only(self):
+ assert SignalChannel._primary_sender_id(["1234567890", "uuid-abc"]) == "1234567890"
+
+ def test_falls_back_to_first_part(self):
+ assert SignalChannel._primary_sender_id(["uuid-abc", "other"]) == "uuid-abc"
+
+ def test_empty_list_returns_empty(self):
+ assert SignalChannel._primary_sender_id([]) == ""
+
+
+class TestExtractGroupId:
+ def test_extracts_from_group_info(self):
+ gid = SignalChannel._extract_group_id({"groupId": "abc=="}, None)
+ assert gid == "abc=="
+
+ def test_extracts_from_group_v2(self):
+ gid = SignalChannel._extract_group_id(None, {"id": "xyz=="})
+ assert gid == "xyz=="
+
+ def test_prefers_group_info_over_v2(self):
+ gid = SignalChannel._extract_group_id({"groupId": "first"}, {"groupId": "second"})
+ assert gid == "first"
+
+ def test_returns_none_when_both_none(self):
+ assert SignalChannel._extract_group_id(None, None) is None
+
+ def test_returns_none_when_not_dicts(self):
+ assert SignalChannel._extract_group_id("bad", 123) is None
+
+
+class TestIsGroupChatId:
+ def test_base64_with_equals_is_group(self):
+ assert SignalChannel._is_group_chat_id("abc==") is True
+
+ def test_long_id_without_dash_is_group(self):
+ long_id = "a" * 41
+ assert SignalChannel._is_group_chat_id(long_id) is True
+
+ def test_phone_number_is_not_group(self):
+ assert SignalChannel._is_group_chat_id("+12345678901") is False
+
+ def test_uuid_with_dashes_is_not_group(self):
+ assert SignalChannel._is_group_chat_id("550e8400-e29b-41d4-a716-446655440000") is False
+
+
+class TestRecipientParams:
+ def test_group_chat_uses_group_id(self):
+ ch = _make_channel()
+ params = ch._recipient_params("abc==")
+ assert params == {"groupId": "abc=="}
+
+ def test_dm_uses_recipient_list(self):
+ ch = _make_channel()
+ params = ch._recipient_params("+12345678901")
+ assert params == {"recipient": ["+12345678901"]}
+
+
+class TestMentionHelpers:
+ def test_mention_id_candidates_extracts_number(self):
+ mention = {"number": "+1234567890"}
+ ids = SignalChannel._mention_id_candidates(mention)
+ assert "+1234567890" in ids
+
+ def test_mention_id_candidates_extracts_uuid(self):
+ mention = {"uuid": "some-uuid"}
+ ids = SignalChannel._mention_id_candidates(mention)
+ assert "some-uuid" in ids
+
+ def test_mention_span_valid(self):
+ assert SignalChannel._mention_span({"start": 0, "length": 5}) == (0, 5)
+
+ def test_mention_span_negative_start(self):
+ assert SignalChannel._mention_span({"start": -1, "length": 5}) is None
+
+ def test_mention_span_zero_length(self):
+ assert SignalChannel._mention_span({"start": 0, "length": 0}) is None
+
+ def test_mention_span_missing_keys(self):
+ assert SignalChannel._mention_span({}) is None
+
+ def test_leading_placeholder_ufffc(self):
+ span = SignalChannel._leading_placeholder_span(" hello")
+ assert span == (0, 1)
+
+ def test_leading_placeholder_not_at_start(self):
+ assert SignalChannel._leading_placeholder_span("hello ") is None
+
+ def test_leading_placeholder_empty_string(self):
+ assert SignalChannel._leading_placeholder_span("") is None
+
+ def test_leading_placeholder_plain_text(self):
+ assert SignalChannel._leading_placeholder_span("hello") is None
+
+
+# ---------------------------------------------------------------------------
+# Account ID alias / mention matching
+# ---------------------------------------------------------------------------
+
+
+class TestAccountIdAliases:
+ def test_phone_number_alias_registered_on_init(self):
+ ch = _make_channel(phone_number="+10000000000")
+ assert ch._id_matches_account("+10000000000")
+
+ def test_digit_only_variant_matches(self):
+ ch = _make_channel(phone_number="+10000000000")
+ assert ch._id_matches_account("10000000000")
+
+ def test_remember_alias_adds_uuid(self):
+ ch = _make_channel()
+ ch._remember_account_id_alias("some-uuid-abc")
+ assert ch._id_matches_account("some-uuid-abc")
+
+ def test_non_matching_id_returns_false(self):
+ ch = _make_channel(phone_number="+10000000000")
+ assert not ch._id_matches_account("+19999999999")
+
+ def test_none_and_non_string_return_false(self):
+ ch = _make_channel()
+ assert not ch._id_matches_account(None)
+
+
+# ---------------------------------------------------------------------------
+# _should_respond_in_group
+# ---------------------------------------------------------------------------
+
+
+class TestShouldRespondInGroup:
+ def _make_group_channel(self, require_mention: bool = True) -> SignalChannel:
+ return _make_channel(
+ phone_number="+10000000000",
+ group_enabled=True,
+ require_mention=require_mention,
+ )
+
+ def test_no_require_mention_always_responds(self):
+ ch = self._make_group_channel(require_mention=False)
+ assert ch._should_respond_in_group("anything", []) is True
+
+ def test_require_mention_with_no_mentions_returns_false(self):
+ ch = self._make_group_channel(require_mention=True)
+ assert ch._should_respond_in_group("hello", []) is False
+
+ def test_require_mention_with_bot_number_mention(self):
+ ch = self._make_group_channel(require_mention=True)
+ mentions = [{"number": "+10000000000", "start": 0, "length": 12}]
+ assert ch._should_respond_in_group(" hello", mentions) is True
+
+ def test_require_mention_with_uuid_mention(self):
+ ch = self._make_group_channel(require_mention=True)
+ ch._remember_account_id_alias("bot-uuid-123")
+ mentions = [{"uuid": "bot-uuid-123", "start": 0, "length": 8}]
+ assert ch._should_respond_in_group(" hello", mentions) is True
+
+ def test_identifier_less_leading_mention_accepted(self):
+ ch = self._make_group_channel(require_mention=True)
+ # Mention with no IDs but leading span — treated as bot mention
+ mentions = [{"start": 0, "length": 1}]
+ assert ch._should_respond_in_group(" hello", mentions) is True
+
+ def test_identifier_less_non_leading_mention_rejected(self):
+ ch = self._make_group_channel(require_mention=True)
+ mentions = [{"start": 5, "length": 1}]
+ assert ch._should_respond_in_group("hello ", mentions) is False
+
+ def test_leading_placeholder_without_mentions_metadata(self):
+ ch = self._make_group_channel(require_mention=True)
+ assert ch._should_respond_in_group(" hello", []) is True
+
+ def test_phone_number_in_text_triggers_response(self):
+ ch = self._make_group_channel(require_mention=True)
+ assert ch._should_respond_in_group("hey +10000000000 help", []) is True
+
+
+# ---------------------------------------------------------------------------
+# _strip_bot_mention
+# ---------------------------------------------------------------------------
+
+
+class TestStripBotMention:
+ def _make_channel_with_number(self) -> SignalChannel:
+ return _make_channel(phone_number="+10000000000")
+
+ def test_strips_mention_by_phone(self):
+ ch = self._make_channel_with_number()
+ text = " hello"
+ mentions = [{"number": "+10000000000", "start": 0, "length": 1}]
+ result = ch._strip_bot_mention(text, mentions)
+ assert result == "hello"
+
+ def test_strips_identifier_less_leading_mention(self):
+ ch = self._make_channel_with_number()
+ text = " hello"
+ mentions = [{"start": 0, "length": 1}]
+ result = ch._strip_bot_mention(text, mentions)
+ assert result == "hello"
+
+ def test_strips_leading_placeholder_without_mention_metadata(self):
+ ch = self._make_channel_with_number()
+ text = " hello"
+ result = ch._strip_bot_mention(text, [])
+ assert result == "hello"
+
+ def test_non_bot_mention_mid_text_not_stripped(self):
+ # A non-bot mention that is NOT a leading placeholder leaves the text alone.
+ ch = self._make_channel_with_number()
+ text = "hello  world"
+ mentions = [{"number": "+19999999999", "start": 6, "length": 1}]
+ result = ch._strip_bot_mention(text, mentions)
+ # Mid-text placeholder from a non-bot mention should be untouched
+ assert "" in result
+
+ def test_empty_text_returned_unchanged(self):
+ ch = self._make_channel_with_number()
+ assert ch._strip_bot_mention("", []) == ""
+
+
+# ---------------------------------------------------------------------------
+# Group message buffer
+# ---------------------------------------------------------------------------
+
+
+class TestGroupBuffer:
+ def test_add_and_get_context(self):
+ ch = _make_channel(group_buffer_size=5)
+ ch._add_to_group_buffer("g1", "Alice", "+1111", "first msg", 1000)
+ ch._add_to_group_buffer("g1", "Bob", "+2222", "second msg", 2000)
+ # Only messages before the latest are returned as context
+ ctx = ch._get_group_buffer_context("g1")
+ assert "first msg" in ctx
+ # The last message is not included (it's the "current" one)
+ assert "second msg" not in ctx
+
+ def test_empty_context_when_only_one_message(self):
+ ch = _make_channel(group_buffer_size=5)
+ ch._add_to_group_buffer("g1", "Alice", "+1111", "only msg", 1000)
+ assert ch._get_group_buffer_context("g1") == ""
+
+ def test_empty_context_when_group_unknown(self):
+ ch = _make_channel()
+ assert ch._get_group_buffer_context("unknown") == ""
+
+ def test_buffer_respects_max_size(self):
+ ch = _make_channel(group_buffer_size=3)
+ for i in range(10):
+ ch._add_to_group_buffer("g1", "Alice", "+1111", f"msg{i}", i)
+ assert len(ch._group_buffers["g1"]) == 3
+
+ def test_zero_buffer_size_rejected_by_validator(self):
+ with pytest.raises(ValueError, match="group_message_buffer_size"):
+ _make_channel(group_buffer_size=0)
+
+ def test_negative_buffer_size_rejected_by_validator(self):
+ with pytest.raises(ValueError, match="group_message_buffer_size"):
+ _make_channel(group_buffer_size=-1)
+
+ def test_context_limits_message_length(self):
+ ch = _make_channel(group_buffer_size=5)
+ long_msg = "x" * 500
+ ch._add_to_group_buffer("g1", "Alice", "+1111", long_msg, 1000)
+ ch._add_to_group_buffer("g1", "Bob", "+2222", "short", 2000)
+ ctx = ch._get_group_buffer_context("g1")
+ # Context is capped at 200 chars per message
+ assert len(ctx.split("Alice: ", 1)[1]) <= 200
+
+
+# ---------------------------------------------------------------------------
+# _handle_data_message — DM routing
+# ---------------------------------------------------------------------------
+
+
+class TestIsAllowed:
+ """The base-channel allowlist gate is overridden to understand Signal's
+ pipe-joined composite sender_ids and the +/no-+ phone variants.
+ """
+
+ def test_denies_when_allowlist_empty(self):
+ ch = _make_channel(dm_enabled=True, dm_policy="allowlist")
+ assert ch.is_allowed("+19995550001") is False
+
+ def test_denies_when_no_policy_allows(self):
+ """When both dm and group are disabled, is_allowed denies."""
+ ch = _make_channel(dm_enabled=False, group_enabled=False)
+ assert ch.is_allowed("+19995550001") is False
+
+ def test_allows_wildcard(self):
+ ch = _make_channel(dm_policy="allowlist", dm_allow_from=["*"])
+ assert ch.is_allowed("+19995550001|some-uuid") is True
+
+ def test_allows_composite_sender_against_split_allowlist(self):
+ """Composite sender_id, single-id allow_from — must match either part."""
+ ch = _make_channel(
+ dm_policy="allowlist",
+ dm_allow_from=["+19995550001"],
+ )
+ assert ch.is_allowed("+19995550001|1872ba20-uuid") is True
+
+ def test_allows_composite_sender_against_composite_allowlist_entry(self):
+ """Backward compat: pipe-joined composite allowlist entries still match."""
+ composite = "+19995550001|1872ba20-uuid"
+ ch = _make_channel(dm_policy="allowlist", dm_allow_from=[composite])
+ assert ch.is_allowed(composite) is True
+
+ def test_allows_when_only_uuid_part_is_listed(self):
+ ch = _make_channel(dm_policy="allowlist", dm_allow_from=["1872ba20-uuid"])
+ assert ch.is_allowed("+19995550001|1872ba20-uuid") is True
+
+ def test_denies_when_no_part_matches(self):
+ ch = _make_channel(dm_policy="allowlist", dm_allow_from=["+12223334444"])
+ assert ch.is_allowed("+19995550001|1872ba20-uuid") is False
+
+ def test_allowlist_union_includes_group_ids(self):
+ """allow_from is the union of dm.allow_from and group.allow_from."""
+ ch = _make_channel(
+ group_enabled=True,
+ group_policy="allowlist",
+ group_allow_from=["group-id-base64=="],
+ )
+ assert "group-id-base64==" in ch.config.allow_from
+
+
+class TestEndToEndDMRouting:
+ """End-to-end tests that keep the real _handle_message chain (no mock),
+ verifying that _check_inbound_policy + _handle_message work together
+ correctly for DM routing. The override of _handle_message publishes
+ directly to bus (policy already checked); denied DMs call
+ super()._handle_message which issues a pairing code.
+ """
+
+ @pytest.mark.asyncio
+ async def test_open_dm_policy_publishes_to_bus(self):
+ """Open DM: _check_inbound_policy passes → _handle_message publishes."""
+ ch = _make_channel(dm_enabled=True, dm_policy="open")
+
+ async def noop_typing(chat_id):
+ pass
+
+ ch._start_typing = noop_typing # type: ignore[method-assign]
+ published: list[InboundMessage] = []
+
+ async def capture_publish(msg: InboundMessage):
+ published.append(msg)
+
+ ch.bus.publish_inbound = capture_publish # type: ignore[method-assign]
+
+ params = _dm_envelope(source_number="+19995550001", message="hello")
+ await ch._handle_receive_notification(params)
+
+ assert len(published) == 1
+ assert published[0].content == "hello"
+ assert published[0].sender_id == "+19995550001"
+
+ @pytest.mark.asyncio
+ async def test_allowlist_dm_denied_triggers_pairing(self):
+ """Allowlist DM: denied sender triggers pairing code via send()."""
+ ch = _make_channel(dm_enabled=True, dm_policy="allowlist", dm_allow_from=[])
+ ch._http = _FakeHTTPClient() # type: ignore[assignment]
+
+ async def noop_typing(chat_id):
+ pass
+
+ ch._start_typing = noop_typing # type: ignore[method-assign]
+ published: list[InboundMessage] = []
+
+ async def capture_publish(msg: InboundMessage):
+ published.append(msg)
+
+ ch.bus.publish_inbound = capture_publish # type: ignore[method-assign]
+
+ params = _dm_envelope(source_number="+19995550002", message="hello")
+ await ch._handle_receive_notification(params)
+
+ # Should NOT publish to bus — sender is not on allowlist.
+ assert published == []
+ # Should have sent a pairing code via send (captured in HTTP posts).
+ assert len(ch._http.posts) == 1 # type: ignore[attr-defined]
+ sent_text = ch._http.posts[0]["json"]["params"]["message"] # type: ignore[attr-defined]
+ assert "pairing" in sent_text.lower() or "pair" in sent_text.lower()
+
+ @pytest.mark.asyncio
+ async def test_allowlist_dm_denied_with_group_open_still_pairs(self):
+ """dm.policy="allowlist" + group.policy="open": denied DM sender
+ must still get a pairing code, not be leaked by the group open check."""
+ ch = _make_channel(
+ dm_enabled=True,
+ dm_policy="allowlist",
+ dm_allow_from=[],
+ group_enabled=True,
+ group_policy="open",
+ )
+ ch._http = _FakeHTTPClient() # type: ignore[assignment]
+
+ async def noop_typing(chat_id):
+ pass
+
+ ch._start_typing = noop_typing # type: ignore[method-assign]
+ published: list[InboundMessage] = []
+
+ async def capture_publish(msg: InboundMessage):
+ published.append(msg)
+
+ ch.bus.publish_inbound = capture_publish # type: ignore[method-assign]
+
+ params = _dm_envelope(source_number="+19995550002", message="hello")
+ await ch._handle_receive_notification(params)
+
+ assert published == []
+ assert len(ch._http.posts) == 1 # type: ignore[attr-defined]
+
+ @pytest.mark.asyncio
+ async def test_open_group_policy_publishes_to_bus(self):
+ """Open group: group message from unknown sender publishes to bus."""
+ ch = _make_channel(
+ group_enabled=True,
+ group_policy="open",
+ require_mention=False,
+ )
+
+ async def noop_typing(chat_id):
+ pass
+
+ ch._start_typing = noop_typing # type: ignore[method-assign]
+ published: list[InboundMessage] = []
+
+ async def capture_publish(msg: InboundMessage):
+ published.append(msg)
+
+ ch.bus.publish_inbound = capture_publish # type: ignore[method-assign]
+
+ params = _group_envelope(group_id="grp==", message="hello group")
+ await ch._handle_receive_notification(params)
+
+ assert len(published) == 1
+ assert "hello group" in published[0].content
+
+
+class TestCheckInboundPolicy:
+ """Direct tests for the policy gate that _handle_data_message now delegates to."""
+
+ def _call(
+ self,
+ ch: SignalChannel,
+ *,
+ sender_id: str = "+19995550001",
+ sender_number: str = "+19995550001",
+ group_id: str | None = None,
+ is_group_message: bool = False,
+ message_text: str = "hi",
+ mentions: list | None = None,
+ sender_name: str | None = "Alice",
+ timestamp: int | None = 1000,
+ ) -> tuple[bool, str]:
+ return ch._check_inbound_policy(
+ sender_id=sender_id,
+ sender_number=sender_number,
+ group_id=group_id,
+ is_group_message=is_group_message,
+ message_text=message_text,
+ mentions=mentions or [],
+ sender_name=sender_name,
+ timestamp=timestamp,
+ )
+
+ def test_dm_open_allows(self):
+ ch = _make_channel(dm_enabled=True, dm_policy="open")
+ allowed, chat_id = self._call(ch)
+ assert allowed is True
+ assert chat_id == "+19995550001"
+
+ def test_dm_disabled_blocks(self):
+ ch = _make_channel(dm_enabled=False)
+ allowed, _ = self._call(ch)
+ assert allowed is False
+
+ def test_dm_allowlist_blocks_unknown_sender(self):
+ ch = _make_channel(dm_policy="allowlist", dm_allow_from=["+12223334444"])
+ allowed, _ = self._call(ch, sender_id="+19995550001")
+ assert allowed is False
+
+ def test_dm_allowlist_allows_known_sender(self):
+ ch = _make_channel(dm_policy="allowlist", dm_allow_from=["+19995550001"])
+ allowed, _ = self._call(ch, sender_id="+19995550001")
+ assert allowed is True
+
+ def test_group_disabled_blocks(self):
+ ch = _make_channel(group_enabled=False)
+ allowed, _ = self._call(ch, is_group_message=True, group_id="g1")
+ assert allowed is False
+
+ def test_group_open_with_mention_allows(self):
+ ch = _make_channel(
+ group_enabled=True,
+ group_policy="open",
+ phone_number="+10000000000",
+ require_mention=True,
+ )
+ allowed, chat_id = self._call(
+ ch,
+ is_group_message=True,
+ group_id="g1",
+ message_text="hello @bot",
+ mentions=[{"number": "+10000000000", "start": 6, "length": 4}],
+ )
+ assert allowed is True
+ assert chat_id == "g1"
+
+ def test_group_open_without_mention_blocks(self):
+ ch = _make_channel(group_enabled=True, group_policy="open", require_mention=True)
+ allowed, _ = self._call(ch, is_group_message=True, group_id="g1", message_text="plain talk")
+ assert allowed is False
+
+ def test_group_command_bypasses_mention_requirement(self):
+ ch = _make_channel(group_enabled=True, group_policy="open", require_mention=True)
+ allowed, _ = self._call(ch, is_group_message=True, group_id="g1", message_text="/help")
+ assert allowed is True
+
+ def test_allowed_group_appends_to_buffer(self):
+ """Side effect: when a group message is allowed, it lands in the buffer."""
+ ch = _make_channel(group_enabled=True, group_policy="open", require_mention=False)
+ self._call(ch, is_group_message=True, group_id="g1", message_text="first")
+ self._call(ch, is_group_message=True, group_id="g1", message_text="second")
+ assert len(ch._group_buffers["g1"]) == 2
+
+ def test_blocked_group_does_not_append_to_buffer(self):
+ """Side effect: when a group is disabled, the buffer must not change."""
+ ch = _make_channel(group_enabled=False)
+ self._call(ch, is_group_message=True, group_id="g1", message_text="hi")
+ assert "g1" not in ch._group_buffers
+
+
+class TestAttachmentsDir:
+ def test_default_attachments_dir(self):
+ ch = _make_channel()
+ expected = Path.home() / ".local/share/signal-cli/attachments"
+ assert ch._signal_attachments_dir() == expected
+
+ def test_configured_attachments_dir(self, tmp_path):
+ ch = _make_channel(attachments_dir=str(tmp_path / "custom"))
+ assert ch._signal_attachments_dir() == tmp_path / "custom"
+
+ def test_attachments_dir_expands_user(self):
+ ch = _make_channel(attachments_dir="~/signal-attachments")
+ assert ch._signal_attachments_dir() == Path.home() / "signal-attachments"
+
+
+class TestHandleDataMessageDM:
+ def _make_dm_channel(self, policy="open", allow_from=None) -> tuple[SignalChannel, list]:
+ return _make_channel_with_capture(
+ dm_enabled=True, dm_policy=policy, dm_allow_from=allow_from or []
+ )
+
+ @pytest.mark.asyncio
+ async def test_dm_open_policy_accepted(self):
+ ch, handled = self._make_dm_channel(policy="open")
+ params = _dm_envelope(source_number="+19995550001", message="hi")
+ await ch._handle_receive_notification(params)
+ assert len(handled) == 1
+ assert handled[0]["chat_id"] == "+19995550001"
+ assert handled[0]["content"] == "hi"
+
+ @pytest.mark.asyncio
+ async def test_dm_allowlist_accepted(self):
+ ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["+19995550001"])
+ params = _dm_envelope(source_number="+19995550001")
+ await ch._handle_receive_notification(params)
+ assert len(handled) == 1
+
+ @pytest.mark.asyncio
+ async def test_dm_allowlist_rejected_triggers_pairing(self):
+ # Denied DM senders go through super()._handle_message which checks
+ # is_allowed → sends pairing code via self.send().
+ ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["+10000000001"])
+ ch._http = _FakeHTTPClient() # type: ignore[attr-defined]
+ params = _dm_envelope(source_number="+19995550002")
+ await ch._handle_receive_notification(params)
+ # The denied DM path calls super()._handle_message, not self._handle_message,
+ # so the capture list stays empty. Verify pairing code was sent via HTTP.
+ assert handled == []
+ assert len(ch._http.posts) == 1 # type: ignore[attr-defined]
+ sent_text = ch._http.posts[0]["json"]["params"]["message"] # type: ignore[attr-defined]
+ assert "pairing" in sent_text.lower() or "pair" in sent_text.lower()
+
+ @pytest.mark.asyncio
+ async def test_dm_paired_sender_allowed_without_allowlist_entry(self, monkeypatch):
+ # Once a sender completes pairing they should pass is_allowed on every
+ # subsequent message — otherwise the pairing reply loops forever.
+ approved = {"+19995550002"}
+ monkeypatch.setattr(
+ "nanobot.channels.signal.is_approved",
+ lambda channel, sender_id: sender_id in approved,
+ )
+ ch = _make_channel(dm_enabled=True, dm_policy="allowlist", dm_allow_from=[])
+ assert ch.is_allowed("+19995550002") is True
+ # Variant forms (with/without "+") must still match a stored approval.
+ assert ch.is_allowed("19995550002") is True
+ # Unpaired sender stays denied.
+ assert ch.is_allowed("+19995559999") is False
+
+ @pytest.mark.asyncio
+ async def test_dm_allowlist_matches_without_plus_prefix(self):
+ """An allowlist entry without '+' must match a sender that carries '+'."""
+ ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["19995550001"])
+ params = _dm_envelope(source_number="+19995550001")
+ await ch._handle_receive_notification(params)
+ assert len(handled) == 1
+
+ @pytest.mark.asyncio
+ async def test_dm_allowlist_matches_with_plus_prefix(self):
+ """An allowlist entry with '+' must match a sender without '+'."""
+ ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["+19995550001"])
+ params = _dm_envelope(source_number="+19995550001", source_uuid=None)
+ # Replace envelope's sourceNumber with the non-prefixed form by editing
+ # the constructed dict directly so _collect_sender_id_parts sees it.
+ params["envelope"]["sourceNumber"] = "19995550001"
+ await ch._handle_receive_notification(params)
+ assert len(handled) == 1
+
+ @pytest.mark.asyncio
+ async def test_dm_allowlist_matches_uuid_case_insensitive(self):
+ """UUID matching must be case-insensitive."""
+ uuid = "ABCDEF12-3456-7890-ABCD-EF1234567890"
+ ch, handled = self._make_dm_channel(policy="allowlist", allow_from=[uuid.lower()])
+ params = _dm_envelope(source_number="+19995550001", source_uuid=uuid)
+ await ch._handle_receive_notification(params)
+ assert len(handled) == 1
+
+ @pytest.mark.asyncio
+ async def test_dm_allowlist_matches_pipe_joined_composite_entry(self):
+ """Allowlist entries written as ``phone|uuid`` composites still work.
+
+ Some configs pre-date the per-part splitting and store the full
+ sender_id composite as a single allow_from entry. Keep matching it.
+ """
+ composite = "+19995550001|1872ba20-f52a-4bad-b434-bf7f808c8b22"
+ ch, handled = self._make_dm_channel(policy="allowlist", allow_from=[composite])
+ params = _dm_envelope(
+ source_number="+19995550001",
+ source_uuid="1872ba20-f52a-4bad-b434-bf7f808c8b22",
+ )
+ await ch._handle_receive_notification(params)
+ assert len(handled) == 1
+
+ @pytest.mark.asyncio
+ async def test_dm_disabled_rejected(self):
+ ch = _make_channel(dm_enabled=False)
+ handled: list[dict] = []
+
+ async def capture(**kwargs):
+ handled.append(kwargs)
+
+ ch._handle_message = capture # type: ignore[method-assign]
+
+ async def noop_typing(chat_id):
+ pass
+
+ ch._start_typing = noop_typing # type: ignore[method-assign]
+ params = _dm_envelope(source_number="+19995550001")
+ await ch._handle_receive_notification(params)
+ assert handled == []
+
+ @pytest.mark.asyncio
+ async def test_reaction_message_ignored(self):
+ ch, handled = self._make_dm_channel()
+ params = _dm_envelope(reaction={"emoji": "👍", "targetTimestamp": 999})
+ await ch._handle_receive_notification(params)
+ assert handled == []
+
+ @pytest.mark.asyncio
+ async def test_empty_message_ignored(self):
+ ch, handled = self._make_dm_channel()
+ params = _dm_envelope(message="")
+ await ch._handle_receive_notification(params)
+ assert handled == []
+
+ @pytest.mark.asyncio
+ async def test_receipt_message_ignored(self):
+ ch, handled = self._make_dm_channel()
+ notification = {
+ "envelope": {
+ "sourceNumber": "+19995550001",
+ "receiptMessage": {"when": 1234},
+ }
+ }
+ await ch._handle_receive_notification(notification)
+ assert handled == []
+
+ @pytest.mark.asyncio
+ async def test_typing_indicator_ignored(self):
+ ch, handled = self._make_dm_channel()
+ notification = {
+ "envelope": {
+ "sourceNumber": "+19995550001",
+ "typingMessage": {"action": "STARTED"},
+ }
+ }
+ await ch._handle_receive_notification(notification)
+ assert handled == []
+
+ @pytest.mark.asyncio
+ async def test_missing_envelope_ignored(self):
+ ch, handled = self._make_dm_channel()
+ await ch._handle_receive_notification({})
+ assert handled == []
+
+ @pytest.mark.asyncio
+ async def test_metadata_passed_to_handle(self):
+ ch, handled = self._make_dm_channel()
+ params = _dm_envelope(source_number="+19995550001", source_name="Alice", timestamp=9999)
+ await ch._handle_receive_notification(params)
+ meta = handled[0]["metadata"]
+ assert meta["sender_name"] == "Alice"
+ assert meta["timestamp"] == 9999
+ assert meta["is_group"] is False
+
+ @pytest.mark.asyncio
+ async def test_sender_id_with_uuid_variant(self):
+ ch, handled = self._make_dm_channel()
+ params = _dm_envelope(source_number="+19995550001", source_uuid="uuid-abc")
+ await ch._handle_receive_notification(params)
+ assert len(handled) == 1
+ # sender_id combines both parts
+ assert "+19995550001" in handled[0]["sender_id"]
+ assert "uuid-abc" in handled[0]["sender_id"]
+
+ @pytest.mark.asyncio
+ async def test_stop_typing_called_on_handle_error(self):
+ ch = _make_channel(dm_enabled=True, dm_policy="open")
+ typing_stopped: list[str] = []
+
+ async def fail_handle(**kwargs):
+ raise RuntimeError("boom")
+
+ async def noop_typing(chat_id):
+ pass
+
+ async def record_stop(chat_id, **kwargs):
+ typing_stopped.append(chat_id)
+
+ ch._handle_message = fail_handle # type: ignore[method-assign]
+ ch._start_typing = noop_typing # type: ignore[method-assign]
+ ch._stop_typing = record_stop # type: ignore[method-assign]
+
+ # _handle_receive_notification swallows exceptions; the typing stop
+ # still fires from _handle_data_message's except clause.
+ params = _dm_envelope(source_number="+19995550001")
+ await ch._handle_receive_notification(params)
+
+ assert "+19995550001" in typing_stopped
+
+
+# ---------------------------------------------------------------------------
+# _handle_data_message — group routing
+# ---------------------------------------------------------------------------
+
+
+class TestHandleDataMessageGroup:
+ def _make_group_channel(
+ self,
+ policy="open",
+ allow_from=None,
+ require_mention=True,
+ ) -> tuple[SignalChannel, list]:
+ return _make_channel_with_capture(
+ group_enabled=True,
+ group_policy=policy,
+ group_allow_from=allow_from or [],
+ require_mention=require_mention,
+ )
+
+ @pytest.mark.asyncio
+ async def test_group_disabled_rejected(self):
+ ch = _make_channel(group_enabled=False)
+ handled: list[dict] = []
+ ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign]
+ params = _group_envelope(group_id="grp==", message="hi")
+ await ch._handle_receive_notification(params)
+ assert handled == []
+
+ @pytest.mark.asyncio
+ async def test_group_open_policy_no_mention_blocked_when_required(self):
+ ch, handled = self._make_group_channel(require_mention=True)
+ params = _group_envelope(group_id="grp==", message="hey everyone")
+ await ch._handle_receive_notification(params)
+ assert handled == []
+
+ @pytest.mark.asyncio
+ async def test_group_open_policy_no_mention_required(self):
+ ch, handled = self._make_group_channel(require_mention=False)
+ params = _group_envelope(group_id="grp==", message="hey everyone")
+ await ch._handle_receive_notification(params)
+ assert len(handled) == 1
+ assert handled[0]["chat_id"] == "grp=="
+
+ @pytest.mark.asyncio
+ async def test_group_allowlist_accepted(self):
+ ch, handled = self._make_group_channel(
+ policy="allowlist", allow_from=["grp=="], require_mention=False
+ )
+ params = _group_envelope(group_id="grp==", message="hi")
+ await ch._handle_receive_notification(params)
+ assert len(handled) == 1
+
+ @pytest.mark.asyncio
+ async def test_group_allowlist_rejected(self):
+ ch, handled = self._make_group_channel(policy="allowlist", allow_from=["other=="])
+ params = _group_envelope(group_id="grp==", message="hi")
+ await ch._handle_receive_notification(params)
+ assert handled == []
+
+ @pytest.mark.asyncio
+ async def test_group_mention_triggers_response(self):
+ ch, handled = self._make_group_channel(require_mention=True)
+ ch._remember_account_id_alias("+10000000000")
+ mentions = [{"number": "+10000000000", "start": 0, "length": 1}]
+ params = _group_envelope(group_id="grp==", message=" hello", mentions=mentions)
+ await ch._handle_receive_notification(params)
+ assert len(handled) == 1
+
+ @pytest.mark.asyncio
+ async def test_group_v2_id_extracted(self):
+ ch, handled = self._make_group_channel(require_mention=False)
+ params = _group_envelope(group_id="grpV2==", message="hi", use_v2=True)
+ await ch._handle_receive_notification(params)
+ assert len(handled) == 1
+ assert handled[0]["chat_id"] == "grpV2=="
+
+ @pytest.mark.asyncio
+ async def test_group_message_includes_sender_prefix(self):
+ ch, handled = self._make_group_channel(require_mention=False)
+ params = _group_envelope(group_id="grp==", source_name="Bob", message="hello")
+ await ch._handle_receive_notification(params)
+ assert "[Bob]:" in handled[0]["content"]
+
+ @pytest.mark.asyncio
+ async def test_group_message_context_prepended(self):
+ ch, handled = self._make_group_channel(require_mention=False)
+ # First message — adds to buffer but no context yet
+ params1 = _group_envelope(group_id="grp==", source_name="Alice", message="msg1")
+ await ch._handle_receive_notification(params1)
+ # Second message — should include context from first
+ params2 = _group_envelope(group_id="grp==", source_name="Bob", message="msg2")
+ await ch._handle_receive_notification(params2)
+ assert "[Recent group messages for context:]" in handled[1]["content"]
+ assert "msg1" in handled[1]["content"]
+
+ @pytest.mark.asyncio
+ async def test_group_metadata_marks_is_group(self):
+ ch, handled = self._make_group_channel(require_mention=False)
+ params = _group_envelope(group_id="grp==", message="hi")
+ await ch._handle_receive_notification(params)
+ assert handled[0]["metadata"]["is_group"] is True
+ assert handled[0]["metadata"]["group_id"] == "grp=="
+
+ @pytest.mark.asyncio
+ async def test_bot_account_alias_learned_from_incoming(self):
+ ch, handled = self._make_group_channel(require_mention=False)
+ # If the bot's own UUID appears in an envelope we learn it
+ params = _dm_envelope(source_number="+10000000000", source_uuid="new-bot-uuid")
+ # DMs from self are processed (learning alias), but DM policy is open
+ ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign]
+ ch._start_typing = lambda chat_id: None # type: ignore[method-assign]
+ await ch._handle_receive_notification(params)
+ assert ch._id_matches_account("new-bot-uuid")
+
+
+# ---------------------------------------------------------------------------
+# Lifecycle / SSE
+# ---------------------------------------------------------------------------
+
+
+class _FakeSSEResponse:
+ """Minimal stand-in for httpx Response under stream()."""
+
+ def __init__(self, lines: list[str], status_code: int = 200) -> None:
+ self.status_code = status_code
+ self._lines = lines
+
+ async def aiter_lines(self):
+ for line in self._lines:
+ yield line
+
+
+def _fake_streaming_client(lines: list[str], *, status_code: int = 200) -> MagicMock:
+ """Return an httpx.AsyncClient stand-in whose .stream() yields a FakeSSEResponse."""
+ response = _FakeSSEResponse(lines, status_code=status_code)
+
+ @asynccontextmanager
+ async def _ctx(*_args, **_kwargs):
+ yield response
+
+ http = MagicMock()
+ http.stream = lambda *a, **kw: _ctx(*a, **kw)
+ return http
+
+
+class TestLifecycle:
+ @pytest.mark.asyncio
+ async def test_start_returns_early_when_phone_missing(self):
+ """start() with an empty phone number must not enter the HTTP loop."""
+ ch = _make_channel(phone_number="")
+ await ch.start()
+ assert ch._running is False
+ assert ch._http is None
+ assert ch._sse_task is None
+
+
+class TestSSEReceiveLoop:
+ @pytest.mark.asyncio
+ async def test_dispatches_valid_envelope(self):
+ ch = _make_channel()
+ ch._running = True
+
+ captured: list[dict] = []
+
+ async def capture(params):
+ captured.append(params)
+
+ ch._handle_receive_notification = capture # type: ignore[method-assign]
+ ch._http = _fake_streaming_client(
+ ['data: {"envelope":{"sourceNumber":"+19995550001"}}', ""]
+ )
+
+ # Loop ends when lines exhaust; the surrounding _start_http_mode would
+ # treat that as a disconnect, but the loop itself raises ConnectionError
+ # when the stream closes while still running.
+ with pytest.raises(ConnectionError):
+ await ch._sse_receive_loop()
+ assert captured == [{"envelope": {"sourceNumber": "+19995550001"}}]
+
+ @pytest.mark.asyncio
+ async def test_handles_invalid_json_frame(self):
+ """An unparseable SSE frame is logged and skipped without crashing."""
+ ch = _make_channel()
+ ch._running = True
+
+ captured: list[dict] = []
+
+ async def capture(params):
+ captured.append(params)
+
+ ch._handle_receive_notification = capture # type: ignore[method-assign]
+ ch._http = _fake_streaming_client(
+ [
+ "data: this-is-not-json",
+ "", # event boundary triggers parse attempt
+ 'data: {"envelope":{"sourceNumber":"+1"}}',
+ "",
+ ]
+ )
+
+ with pytest.raises(ConnectionError):
+ await ch._sse_receive_loop()
+ # Bad frame skipped; good frame still dispatched.
+ assert captured == [{"envelope": {"sourceNumber": "+1"}}]
+
+ @pytest.mark.asyncio
+ async def test_non_200_status_raises(self):
+ ch = _make_channel()
+ ch._running = True
+ ch._http = _fake_streaming_client([], status_code=503)
+ with pytest.raises(ConnectionError, match="status 503"):
+ await ch._sse_receive_loop()
+
+ @pytest.mark.asyncio
+ async def test_no_http_client_raises(self):
+ ch = _make_channel()
+ ch._http = None
+ with pytest.raises(RuntimeError, match="HTTP client not initialized"):
+ await ch._sse_receive_loop()
+
+
+# ---------------------------------------------------------------------------
+# Command handling
+# ---------------------------------------------------------------------------
+
+
+class TestCommandHandling:
+ @pytest.mark.asyncio
+ async def test_dm_command_forwarded_to_bus(self):
+ """Slash commands in DMs are forwarded to the bus for AgentLoop to handle."""
+ ch, forwarded = _make_channel_with_capture(dm_enabled=True, dm_policy="open")
+ params = _dm_envelope(source_number="+19995550001", message="/reset")
+ await ch._handle_receive_notification(params)
+ assert len(forwarded) == 1
+ assert forwarded[0]["content"].strip() == "/reset"
+
+ @pytest.mark.asyncio
+ async def test_group_command_bypasses_mention_requirement(self):
+ """Slash commands in groups bypass the mention requirement and reach the bus."""
+ ch, forwarded = _make_channel_with_capture(
+ group_enabled=True, group_policy="open", require_mention=True
+ )
+ params = _group_envelope(source_number="+19995550001", group_id="grp==", message="/reset")
+ await ch._handle_receive_notification(params)
+ assert len(forwarded) == 1
+ assert "/reset" in forwarded[0]["content"]
+
+ @pytest.mark.asyncio
+ async def test_command_denied_for_disallowed_dm_sender(self):
+ """Commands from senders not on the DM allowlist are dropped."""
+ ch, forwarded = _make_channel_with_capture(dm_enabled=False)
+ params = _dm_envelope(source_number="+19995550001", message="/reset")
+ await ch._handle_receive_notification(params)
+ assert forwarded == []
+
+
+# ---------------------------------------------------------------------------
+# send() — outbound messages
+# ---------------------------------------------------------------------------
+
+
+class TestSend:
+ def _make_send_channel(self) -> tuple[SignalChannel, _FakeHTTPClient]:
+ ch = _make_channel()
+ client = _FakeHTTPClient()
+ ch._http = client # type: ignore[assignment]
+ return ch, client
+
+ @pytest.mark.asyncio
+ async def test_send_plain_text_posts_rpc(self):
+ ch, client = self._make_send_channel()
+ msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="hello")
+ await ch.send(msg)
+ assert len(client.posts) == 1
+ payload = client.posts[0]["json"]
+ assert payload["method"] == "send"
+ assert payload["params"]["message"] == "hello"
+
+ @pytest.mark.asyncio
+ async def test_send_with_markdown_includes_text_styles(self):
+ ch, client = self._make_send_channel()
+ msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="**bold**")
+ await ch.send(msg)
+ params = client.posts[0]["json"]["params"]
+ assert "textStyle" in params
+ assert any("BOLD" in s for s in params["textStyle"])
+
+ @pytest.mark.asyncio
+ async def test_send_split_message_redistributes_text_styles(self):
+ """Long message split across chunks: each chunk gets its own textStyle
+ with offsets rebased to that chunk."""
+ ch, client = self._make_send_channel()
+ ch._MAX_MESSAGE_LEN = 12 # type: ignore[attr-defined]
+ msg = OutboundMessage(
+ channel="signal",
+ chat_id="+19995550001",
+ content="**head** middle and **tail**",
+ )
+ await ch.send(msg)
+ assert len(client.posts) >= 2
+ # Chunk 0 has BOLD for "head"; chunk 1+ must also carry BOLD for "tail".
+ bold_chunks = [
+ p["json"]["params"]
+ for p in client.posts
+ if any("BOLD" in s for s in p["json"]["params"].get("textStyle", []))
+ ]
+ assert len(bold_chunks) >= 2, (
+ "expected BOLD ranges in more than one chunk; got "
+ f"{[p['json']['params'] for p in client.posts]}"
+ )
+ # Each emitted range must point inside its own chunk's text.
+ for params in bold_chunks:
+ chunk_text = params["message"]
+ for entry in params["textStyle"]:
+ s, ln, _ = entry.split(":", 2)
+ start, length = int(s), int(ln)
+ end_units = start + length
+ assert end_units <= len(chunk_text.encode("utf-16-le")) // 2
+
+ @pytest.mark.asyncio
+ async def test_send_empty_content_skips_rpc(self):
+ ch, client = self._make_send_channel()
+ msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="")
+ await ch.send(msg)
+ assert client.posts == []
+
+ @pytest.mark.asyncio
+ async def test_send_to_group_uses_group_id(self):
+ ch, client = self._make_send_channel()
+ msg = OutboundMessage(channel="signal", chat_id="grp==", content="hi group")
+ await ch.send(msg)
+ params = client.posts[0]["json"]["params"]
+ assert "groupId" in params
+ assert "recipient" not in params
+
+ @pytest.mark.asyncio
+ async def test_send_to_dm_uses_recipient(self):
+ ch, client = self._make_send_channel()
+ msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="hi")
+ await ch.send(msg)
+ params = client.posts[0]["json"]["params"]
+ assert "recipient" in params
+
+ @pytest.mark.asyncio
+ async def test_send_with_media_includes_attachments(self):
+ ch, client = self._make_send_channel()
+ msg = OutboundMessage(
+ channel="signal",
+ chat_id="+19995550001",
+ content="see attachment",
+ media=["/tmp/file.jpg"],
+ )
+ await ch.send(msg)
+ params = client.posts[0]["json"]["params"]
+ assert params.get("attachments") == ["/tmp/file.jpg"]
+
+ @pytest.mark.asyncio
+ async def test_send_progress_message_does_not_stop_typing(self):
+ ch, client = self._make_send_channel()
+ stopped: list[str] = []
+
+ async def record_stop(chat_id, **kwargs):
+ stopped.append(chat_id)
+
+ ch._stop_typing = record_stop # type: ignore[method-assign]
+ msg = OutboundMessage(
+ channel="signal",
+ chat_id="+19995550001",
+ content="working...",
+ metadata={"_progress": True},
+ )
+ await ch.send(msg)
+ # Progress messages should NOT stop the typing indicator
+ assert stopped == []
+
+ @pytest.mark.asyncio
+ async def test_send_final_message_stops_typing(self):
+ ch, client = self._make_send_channel()
+ stopped: list[str] = []
+
+ async def record_stop(chat_id, send_stop=True):
+ stopped.append(chat_id)
+
+ ch._stop_typing = record_stop # type: ignore[method-assign]
+ msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="done")
+ await ch.send(msg)
+ assert "+19995550001" in stopped
+
+ @pytest.mark.asyncio
+ async def test_send_raises_on_daemon_error(self):
+ # _send_http_request turns every exception into {"error": ...}, so this branch
+ # is the only place ChannelManager retry can be triggered — must raise.
+ ch = _make_channel()
+ ch._http = _FakeHTTPClient(default_response={"error": {"message": "fail"}})
+ msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="hello")
+ with pytest.raises(RuntimeError, match="signal-cli send failed"):
+ await ch.send(msg)
+
+
+# ---------------------------------------------------------------------------
+# stop()
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_stop_cancels_sse_task() -> None:
+ ch = _make_channel()
+ cancelled = False
+
+ async def long_running():
+ nonlocal cancelled
+ try:
+ await asyncio.sleep(9999)
+ except asyncio.CancelledError:
+ cancelled = True
+ raise
+
+ ch._sse_task = asyncio.create_task(long_running())
+ # Yield so the task can enter its body (reach the first await) before cancel.
+ await asyncio.sleep(0)
+ ch._running = True
+
+ await ch.stop()
+
+ assert cancelled
+ assert ch._running is False
+
+
+@pytest.mark.asyncio
+async def test_stop_closes_http_client() -> None:
+ ch = _make_channel()
+ client = _FakeHTTPClient()
+ ch._http = client # type: ignore[assignment]
+ ch._running = True
+
+ await ch.stop()
+
+ assert client.closed
+
+
+@pytest.mark.asyncio
+async def test_stop_safe_when_no_sse_task() -> None:
+ ch = _make_channel()
+ ch._running = True
+ # Should not raise even with no _sse_task
+ await ch.stop()
+ assert ch._running is False
+
+
+# ---------------------------------------------------------------------------
+# _send_request / _send_http_request
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_send_request_increments_id() -> None:
+ ch = _make_channel()
+ client = _FakeHTTPClient()
+ ch._http = client # type: ignore[assignment]
+
+ await ch._send_request("testMethod", {"key": "val"})
+ await ch._send_request("testMethod", {"key": "val"})
+
+ ids = [p["json"]["id"] for p in client.posts]
+ assert ids == [1, 2]
+
+
+@pytest.mark.asyncio
+async def test_send_request_raises_when_not_connected() -> None:
+ ch = _make_channel()
+ # _http is None by default
+ with pytest.raises(RuntimeError, match="Not connected"):
+ await ch._send_request("testMethod")
+
+
+# ---------------------------------------------------------------------------
+# _handle_receive_notification — envelope shapes
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_handle_notification_sync_message_does_not_forward() -> None:
+ ch = _make_channel(dm_enabled=True, dm_policy="open")
+ handled: list[dict] = []
+ ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign]
+
+ notification = {
+ "envelope": {
+ "sourceNumber": "+19995550001",
+ "syncMessage": {
+ "sentMessage": {
+ "destination": "+19990000000",
+ "message": "sent from other device",
+ }
+ },
+ }
+ }
+ await ch._handle_receive_notification(notification)
+ assert handled == []
+
+
+@pytest.mark.asyncio
+async def test_handle_notification_no_source_skipped() -> None:
+ ch = _make_channel(dm_enabled=True, dm_policy="open")
+ handled: list[dict] = []
+ ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign]
+
+ notification = {"envelope": {"dataMessage": {"message": "ghost"}}}
+ await ch._handle_receive_notification(notification)
+ assert handled == []
+
+
+# ---------------------------------------------------------------------------
+# Config: allow_from property aggregation
+# ---------------------------------------------------------------------------
+
+
+def test_config_allow_from_aggregates_dm_and_group() -> None:
+ config = SignalConfig(
+ enabled=True,
+ phone_number="+10000000000",
+ dm=SignalDMConfig(enabled=True, policy="allowlist", allow_from=["+1111", "+2222"]),
+ group=SignalGroupConfig(enabled=True, policy="allowlist", allow_from=["+3333", "+1111"]),
+ )
+ combined = config.allow_from
+ assert "+1111" in combined
+ assert "+2222" in combined
+ assert "+3333" in combined
+ # Duplicates removed
+ assert combined.count("+1111") == 1
+
+
+def test_config_allow_from_wildcard_propagates() -> None:
+ config = SignalConfig(
+ enabled=True,
+ phone_number="+10000000000",
+ dm=SignalDMConfig(enabled=True, policy="open", allow_from=["*"]),
+ group=SignalGroupConfig(enabled=True, policy="open", allow_from=[]),
+ )
+ assert "*" in config.allow_from
diff --git a/tests/channels/test_signal_markdown.py b/tests/channels/test_signal_markdown.py
new file mode 100644
index 000000000..37a21c6d8
--- /dev/null
+++ b/tests/channels/test_signal_markdown.py
@@ -0,0 +1,525 @@
+"""Unit tests for the Signal markdown → plain text + textStyle converter."""
+
+from nanobot.channels.signal import _markdown_to_signal, _partition_styles
+from nanobot.utils.helpers import split_message
+
+
+def _utf16_len(s: str) -> int:
+ return len(s.encode("utf-16-le")) // 2
+
+
+def styles_for(plain: str, text_styles: list[str]) -> dict[str, list[str]]:
+ """Return a dict mapping each styled substring to its style list."""
+ result: dict[str, list[str]] = {}
+ for entry in text_styles:
+ start_s, length_s, style = entry.split(":", 2)
+ start, length = int(start_s), int(length_s)
+ span = plain[start : start + length]
+ result.setdefault(span, []).append(style)
+ return result
+
+
+def utf16_styles_for(plain: str, text_styles: list[str]) -> dict[str, list[str]]:
+ """Like styles_for, but slices `plain` using UTF-16 offsets (Signal's units)."""
+ encoded = plain.encode("utf-16-le")
+ result: dict[str, list[str]] = {}
+ for entry in text_styles:
+ start_s, length_s, style = entry.split(":", 2)
+ start, length = int(start_s), int(length_s)
+ span = encoded[start * 2 : (start + length) * 2].decode("utf-16-le")
+ result.setdefault(span, []).append(style)
+ return result
+
+
+# ---------------------------------------------------------------------------
+# Basic cases
+# ---------------------------------------------------------------------------
+
+
+def test_empty():
+ plain, styles = _markdown_to_signal("")
+ assert plain == ""
+ assert styles == []
+
+
+def test_plain_text():
+ plain, styles = _markdown_to_signal("hello world")
+ assert plain == "hello world"
+ assert styles == []
+
+
+def test_bold_stars():
+ plain, styles = _markdown_to_signal("say **hello** now")
+ assert plain == "say hello now"
+ assert styles_for(plain, styles) == {"hello": ["BOLD"]}
+
+
+def test_bold_underscores():
+ plain, styles = _markdown_to_signal("say __hello__ now")
+ assert plain == "say hello now"
+ assert styles_for(plain, styles) == {"hello": ["BOLD"]}
+
+
+def test_italic_star():
+ plain, styles = _markdown_to_signal("say *hello* now")
+ assert plain == "say hello now"
+ assert styles_for(plain, styles) == {"hello": ["ITALIC"]}
+
+
+def test_italic_underscore():
+ plain, styles = _markdown_to_signal("say _hello_ now")
+ assert plain == "say hello now"
+ assert styles_for(plain, styles) == {"hello": ["ITALIC"]}
+
+
+def test_strikethrough():
+ plain, styles = _markdown_to_signal("say ~~hello~~ now")
+ assert plain == "say hello now"
+ assert styles_for(plain, styles) == {"hello": ["STRIKETHROUGH"]}
+
+
+# ---------------------------------------------------------------------------
+# Code
+# ---------------------------------------------------------------------------
+
+
+def test_inline_code():
+ plain, styles = _markdown_to_signal("run `ls -la` here")
+ assert plain == "run ls -la here"
+ assert styles_for(plain, styles) == {"ls -la": ["MONOSPACE"]}
+
+
+def test_code_block():
+ plain, styles = _markdown_to_signal("```\nprint('hi')\n```")
+ assert "print('hi')" in plain
+ assert styles_for(plain, styles).get("print('hi')\n") == ["MONOSPACE"] or "MONOSPACE" in str(
+ styles_for(plain, styles)
+ )
+
+
+def test_code_block_with_lang():
+ plain, styles = _markdown_to_signal("```python\ncode\n```")
+ assert "code" in plain
+ assert any("MONOSPACE" in s for s in styles)
+
+
+def test_code_block_not_processed_further():
+ """Markdown inside a code block must not be styled."""
+ plain, styles = _markdown_to_signal("```\n**not bold**\n```")
+ assert "**not bold**" in plain
+ # Only MONOSPACE should be applied, no BOLD
+ for entry in styles:
+ assert "BOLD" not in entry
+
+
+def test_inline_code_not_processed_further():
+ """Markdown inside inline code must not be styled."""
+ plain, styles = _markdown_to_signal("use `**raw**` please")
+ assert "**raw**" in plain
+ for entry in styles:
+ assert "BOLD" not in entry
+
+
+# ---------------------------------------------------------------------------
+# Headers
+# ---------------------------------------------------------------------------
+
+
+def test_header_becomes_bold():
+ plain, styles = _markdown_to_signal("# My Title")
+ assert plain == "My Title"
+ assert styles_for(plain, styles) == {"My Title": ["BOLD"]}
+
+
+def test_h2_becomes_bold():
+ plain, styles = _markdown_to_signal("## Sub-section")
+ assert plain == "Sub-section"
+ assert styles_for(plain, styles) == {"Sub-section": ["BOLD"]}
+
+
+# ---------------------------------------------------------------------------
+# Blockquotes
+# ---------------------------------------------------------------------------
+
+
+def test_blockquote_strips_marker():
+ plain, styles = _markdown_to_signal("> some quote")
+ assert plain == "some quote"
+ assert styles == []
+
+
+# ---------------------------------------------------------------------------
+# Lists
+# ---------------------------------------------------------------------------
+
+
+def test_bullet_dash():
+ plain, styles = _markdown_to_signal("- item one")
+ assert plain == "• item one"
+
+
+def test_bullet_star():
+ plain, styles = _markdown_to_signal("* item two")
+ assert plain == "• item two"
+
+
+def test_numbered_list():
+ plain, styles = _markdown_to_signal("1. first\n2. second")
+ assert "1. first" in plain
+ assert "2. second" in plain
+
+
+# ---------------------------------------------------------------------------
+# Links
+# ---------------------------------------------------------------------------
+
+
+def test_link_text_differs_from_url():
+ plain, styles = _markdown_to_signal("[Click here](https://example.com)")
+ assert plain == "Click here (https://example.com)"
+ assert styles == []
+
+
+def test_link_text_equals_url():
+ plain, styles = _markdown_to_signal("[https://example.com](https://example.com)")
+ assert plain == "https://example.com"
+ assert styles == []
+
+
+def test_link_text_equals_url_without_scheme():
+ plain, styles = _markdown_to_signal("[example.com](https://example.com)")
+ assert plain == "https://example.com"
+
+
+# ---------------------------------------------------------------------------
+# Mixed / nesting
+# ---------------------------------------------------------------------------
+
+
+def test_bold_and_italic_adjacent():
+ plain, styles = _markdown_to_signal("**bold** and *italic*")
+ assert plain == "bold and italic"
+ sd = styles_for(plain, styles)
+ assert sd.get("bold") == ["BOLD"]
+ assert sd.get("italic") == ["ITALIC"]
+
+
+def test_header_with_inline_code():
+ """Header becomes BOLD; code inside becomes MONOSPACE (not double-BOLD)."""
+ plain, styles = _markdown_to_signal("# Use `grep`")
+ assert plain == "Use grep"
+ sd = styles_for(plain, styles)
+ assert "BOLD" in sd.get("Use ", []) or "BOLD" in str(styles)
+ assert "MONOSPACE" in sd.get("grep", [])
+
+
+def test_multiline_mixed():
+ md = "**Title**\n\nSome *italic* text.\n\n- bullet\n- another"
+ plain, styles = _markdown_to_signal(md)
+ assert "Title" in plain
+ assert "italic" in plain
+ assert "• bullet" in plain
+ sd = styles_for(plain, styles)
+ assert "BOLD" in sd.get("Title", [])
+ assert "ITALIC" in sd.get("italic", [])
+
+
+# ---------------------------------------------------------------------------
+# Table rendering
+# ---------------------------------------------------------------------------
+
+
+def test_table_rendered_as_monospace():
+ md = "| A | B |\n| - | - |\n| 1 | 2 |"
+ plain, styles = _markdown_to_signal(md)
+ assert "A" in plain and "B" in plain
+ assert any("MONOSPACE" in s for s in styles)
+
+
+# ---------------------------------------------------------------------------
+# Style range format
+# ---------------------------------------------------------------------------
+
+
+def test_style_range_format():
+ """Each style entry must be 'start:length:STYLE'."""
+ _, styles = _markdown_to_signal("**bold** text")
+ for entry in styles:
+ parts = entry.split(":")
+ assert len(parts) == 3
+ assert parts[0].isdigit()
+ assert parts[1].isdigit()
+ assert parts[2] in {"BOLD", "ITALIC", "STRIKETHROUGH", "MONOSPACE", "SPOILER"}
+
+
+def test_style_ranges_are_within_bounds():
+ text = "hello **world** end"
+ plain, styles = _markdown_to_signal(text)
+ for entry in styles:
+ start_s, length_s, _ = entry.split(":", 2)
+ start, length = int(start_s), int(length_s)
+ assert start >= 0
+ assert start + length <= len(plain)
+
+
+# ---------------------------------------------------------------------------
+# Non-BMP / UTF-16 offsets
+#
+# Signal's BodyRange (and signal-cli's textStyle) interprets start/length in
+# UTF-16 code units. Python's len() counts code points, so characters outside
+# the BMP (emojis, supplementary CJK) shift offsets by +1 per occurrence.
+# ---------------------------------------------------------------------------
+
+
+def assert_within_utf16_bounds(plain: str, styles: list[str]) -> None:
+ limit = _utf16_len(plain)
+ for entry in styles:
+ start_s, length_s, _ = entry.split(":", 2)
+ start, length = int(start_s), int(length_s)
+ assert start >= 0
+ assert start + length <= limit, f"range {entry} exceeds utf-16 length {limit} of {plain!r}"
+
+
+def test_bold_with_emoji_inside():
+ plain, styles = _markdown_to_signal("**hi 🎉 bye**")
+ assert plain == "hi 🎉 bye"
+ assert utf16_styles_for(plain, styles) == {"hi 🎉 bye": ["BOLD"]}
+ assert_within_utf16_bounds(plain, styles)
+
+
+def test_italic_with_trailing_emoji():
+ plain, styles = _markdown_to_signal("*bye 🎉*")
+ assert plain == "bye 🎉"
+ assert utf16_styles_for(plain, styles) == {"bye 🎉": ["ITALIC"]}
+ assert_within_utf16_bounds(plain, styles)
+
+
+def test_bold_after_emoji_prefix():
+ plain, styles = _markdown_to_signal("🎉 **bold**")
+ assert plain == "🎉 bold"
+ assert utf16_styles_for(plain, styles) == {"bold": ["BOLD"]}
+ assert_within_utf16_bounds(plain, styles)
+
+
+def test_bold_after_and_inside_emoji():
+ plain, styles = _markdown_to_signal("🎉 **a 🎊 b**")
+ assert plain == "🎉 a 🎊 b"
+ assert utf16_styles_for(plain, styles) == {"a 🎊 b": ["BOLD"]}
+ assert_within_utf16_bounds(plain, styles)
+
+
+def test_supplementary_cjk_in_bold():
+ """Non-BMP CJK (U+20BB7) proves the bug is UTF-16, not emoji-specific."""
+ plain, styles = _markdown_to_signal("**𠮷野家**")
+ assert plain == "𠮷野家"
+ assert utf16_styles_for(plain, styles) == {"𠮷野家": ["BOLD"]}
+ assert_within_utf16_bounds(plain, styles)
+
+
+def test_zwj_emoji_in_bold():
+ """ZWJ family sequence = multiple surrogate pairs + BMP ZWJs."""
+ plain, styles = _markdown_to_signal("**hi 👨👩👧 bye**")
+ assert plain == "hi 👨👩👧 bye"
+ assert utf16_styles_for(plain, styles) == {"hi 👨👩👧 bye": ["BOLD"]}
+ assert_within_utf16_bounds(plain, styles)
+
+
+def test_ascii_offsets_unchanged():
+ """ASCII-only path must produce the same offsets as before the UTF-16 fix."""
+ plain, styles = _markdown_to_signal("**bold** plain *it*")
+ assert plain == "bold plain it"
+ assert sorted(styles) == sorted(["0:4:BOLD", "11:2:ITALIC"])
+
+
+def test_reported_daily_brief_pattern():
+ """Regression for the reported bug: a single non-BMP emoji shifts every
+ subsequent styled span left by 1 UTF-16 unit, lopping off the last letter.
+ """
+ md = (
+ "**Weather**\n"
+ "- Conditions: 🌩️ Thunderstorms\n\n"
+ "**News**\n"
+ "*World*\n"
+ "*Local*\n\n"
+ "**Quote of the Day**"
+ )
+ plain, styles = _markdown_to_signal(md)
+ sd = utf16_styles_for(plain, styles)
+ assert sd.get("Weather") == ["BOLD"]
+ assert sd.get("News") == ["BOLD"]
+ assert sd.get("World") == ["ITALIC"]
+ assert sd.get("Local") == ["ITALIC"]
+ assert sd.get("Quote of the Day") == ["BOLD"]
+ assert_within_utf16_bounds(plain, styles)
+
+
+# ---------------------------------------------------------------------------
+# Chunk redistribution
+#
+# split_message can break a long Signal payload into multiple chunks. The
+# style ranges from _markdown_to_signal are anchored to the full text, so
+# they must be redistributed per-chunk with rebased offsets — otherwise
+# styles for chunks 1..N are silently lost.
+# ---------------------------------------------------------------------------
+
+
+def _resolve_chunk_styles(text: str, max_len: int) -> tuple[list[str], list[list[str]]]:
+ """Helper: full markdown → signal pipeline, including chunking."""
+ plain, styles = _markdown_to_signal(text)
+ chunks = split_message(plain, max_len) if plain else [""]
+ return chunks, _partition_styles(plain, chunks, styles)
+
+
+def test_partition_styles_single_chunk_passthrough():
+ plain, styles = _markdown_to_signal("**bold** plain *it*")
+ parts = _partition_styles(plain, [plain], styles)
+ assert parts == [styles]
+
+
+def test_partition_styles_no_styles():
+ plain = "hello world"
+ assert _partition_styles(plain, [plain], []) == [[]]
+ assert _partition_styles(plain, ["hello", "world"], []) == [[], []]
+
+
+def test_partition_styles_drops_styles_outside_chunks():
+ """Whitespace trimmed by split_message must not carry a style range."""
+ plain = "a b"
+ # Fake a style spanning the trimmed whitespace only.
+ chunks = ["a", "b"]
+ parts = _partition_styles(plain, chunks, ["1:3:BOLD"])
+ assert parts == [[], []]
+
+
+def test_partition_styles_long_message_preserves_chunk_one_styles():
+ """A bold span deep in the message must follow the message into chunk 1."""
+ # Two ~30-char paragraphs separated by a blank line, then **tail**.
+ line_a = "alpha " * 5 # 30 chars, ends with space
+ line_b = "beta " * 5
+ md = f"{line_a.strip()}\n\n{line_b.strip()}\n\n**tail**"
+ plain, styles = _markdown_to_signal(md)
+ # Force a split between the paragraphs.
+ max_len = len(line_a.strip()) + 2 # fits paragraph A + the "\n\n"
+ chunks = split_message(plain, max_len)
+ assert len(chunks) >= 2, "test setup must produce a split"
+ parts = _partition_styles(plain, chunks, styles)
+ # The bold "tail" should land in the last chunk, with chunk-relative offset.
+ final_chunk = chunks[-1]
+ final_styles = parts[-1]
+ assert any("BOLD" in s for s in final_styles)
+ for entry in final_styles:
+ s, ln, _ = entry.split(":", 2)
+ start, length = int(s), int(ln)
+ slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode(
+ "utf-16-le"
+ )
+ assert slice_ == "tail"
+
+
+def test_partition_styles_chunk_zero_styles_unchanged():
+ """Styles entirely in chunk 0 keep their original offsets."""
+ md = "**head** middle and **tail**"
+ plain, styles = _markdown_to_signal(md)
+ # Split so chunk 0 contains "head" and part of the rest, chunk 1 contains "tail".
+ chunks = split_message(plain, 12)
+ assert len(chunks) >= 2
+ parts = _partition_styles(plain, chunks, styles)
+ # "head" lives in chunk 0; assert its offset is unchanged (chunk 0 starts at 0).
+ head_entries = [s for s in parts[0] if "BOLD" in s]
+ assert any(s.startswith("0:4:") for s in head_entries)
+
+
+def test_partition_styles_with_non_bmp_chunk_offset():
+ """Chunk-start offsets must be expressed in UTF-16 code units."""
+ # Emoji in chunk 0, bold in chunk 1.
+ md = "🎉 alpha beta gamma\n\n**tail**"
+ plain, styles = _markdown_to_signal(md)
+ chunks = split_message(plain, 18)
+ assert len(chunks) >= 2
+ parts = _partition_styles(plain, chunks, styles)
+ final_styles = parts[-1]
+ assert any("BOLD" in s for s in final_styles)
+ final_chunk = chunks[-1]
+ for entry in final_styles:
+ s, ln, _ = entry.split(":", 2)
+ start, length = int(s), int(ln)
+ slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode(
+ "utf-16-le"
+ )
+ assert slice_ == "tail"
+
+
+def test_partition_styles_range_spanning_chunks_is_split():
+ """A style range that straddles a chunk boundary gets sliced into both chunks."""
+ # Construct manually: plain = "abc def", style covers "abc def" (whole thing).
+ plain = "abc def"
+ chunks = split_message(plain, 4) # "abc" / "def"
+ assert chunks == ["abc", "def"]
+ parts = _partition_styles(plain, chunks, ["0:7:BOLD"])
+ # Chunk 0 holds 0:3:BOLD, chunk 1 holds 0:3:BOLD (length=3 each, "def" only
+ # since the space was trimmed by lstrip).
+ assert parts[0] == ["0:3:BOLD"]
+ assert parts[1] == ["0:3:BOLD"]
+
+
+# ---------------------------------------------------------------------------
+# Adjacency, nesting, and malformed input
+# ---------------------------------------------------------------------------
+
+
+def test_bold_italic_combo_outer_bold_inner_italic():
+ """`**_combo_**` carries both BOLD and ITALIC over the same span."""
+ plain, styles = _markdown_to_signal("**_combo_**")
+ assert plain == "combo"
+ sd = styles_for(plain, styles)
+ assert set(sd.get("combo", [])) == {"BOLD", "ITALIC"}
+
+
+def test_bold_and_italic_adjacent_no_separator():
+ """`**bold***italic*` produces BOLD on `bold` and ITALIC on `italic`."""
+ plain, styles = _markdown_to_signal("**bold***italic*")
+ assert plain == "bolditalic"
+ sd = styles_for(plain, styles)
+ assert sd.get("bold") == ["BOLD"]
+ assert sd.get("italic") == ["ITALIC"]
+
+
+def test_unclosed_bold_falls_through_as_plain():
+ """An unmatched `**` opener round-trips as literal text with no style."""
+ plain, styles = _markdown_to_signal("**bold")
+ assert plain == "**bold"
+ assert styles == []
+
+
+def test_unclosed_inline_code_falls_through_as_plain():
+ """An unmatched backtick round-trips as literal text with no style."""
+ plain, styles = _markdown_to_signal("use `grep")
+ assert plain == "use `grep"
+ assert styles == []
+
+
+def test_inline_code_inside_blockquote():
+ """Blockquote prefix is stripped; inline code becomes MONOSPACE."""
+ plain, styles = _markdown_to_signal("> use `grep`")
+ assert plain == "use grep"
+ sd = styles_for(plain, styles)
+ assert sd.get("grep") == ["MONOSPACE"]
+
+
+def test_header_with_inner_bold_produces_contiguous_bold_ranges():
+ """`# **wrap** me` — header forces BOLD over the whole line; the inner `**`
+ splits the run, yielding two contiguous BOLD ranges that together cover
+ "wrap me". This is intentional — Signal renders adjacent same-style ranges
+ as a single visual span.
+ """
+ plain, styles = _markdown_to_signal("# **wrap** me")
+ assert plain == "wrap me"
+ # Both ranges are BOLD; collectively they cover the whole "wrap me".
+ bold_ranges = [s for s in styles if s.endswith(":BOLD")]
+ assert len(bold_ranges) == 2
+ covered = set()
+ for entry in bold_ranges:
+ start, length, _ = entry.split(":", 2)
+ for i in range(int(start), int(start) + int(length)):
+ covered.add(i)
+ assert covered == set(range(len(plain)))
diff --git a/tests/channels/test_websocket_channel.py b/tests/channels/test_websocket_channel.py
index cc011a244..74a780c80 100644
--- a/tests/channels/test_websocket_channel.py
+++ b/tests/channels/test_websocket_channel.py
@@ -1055,6 +1055,7 @@ async def test_settings_api_returns_safe_subset_and_updates_whitelist(
}
assert image_providers["openrouter"]["label"] == "OpenRouter"
assert image_providers["openrouter"]["configured"] is False
+ assert image_providers["openai_codex"]["configured"] is True
assert image_providers["gemini"]["label"] == "Gemini"
assert body["runtime"]["config_path"] == str(config_path)
workspace_path = body["runtime"]["workspace_path"].replace("\\", "/")
diff --git a/tests/channels/test_weixin_channel.py b/tests/channels/test_weixin_channel.py
index a695ba936..3d3606e75 100644
--- a/tests/channels/test_weixin_channel.py
+++ b/tests/channels/test_weixin_channel.py
@@ -1,6 +1,7 @@
import asyncio
import json
import tempfile
+import time
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock
@@ -374,6 +375,7 @@ async def test_send_uses_typing_start_and_cancel_when_ticket_available() -> None
channel._client = object()
channel._token = "token"
channel._context_tokens["wx-user"] = "ctx-typing"
+ channel._context_token_at["wx-user"] = time.time()
channel._send_text = AsyncMock()
channel._api_post = AsyncMock(
side_effect=[
@@ -402,6 +404,7 @@ async def test_send_still_sends_text_when_typing_ticket_missing() -> None:
channel._client = object()
channel._token = "token"
channel._context_tokens["wx-user"] = "ctx-no-ticket"
+ channel._context_token_at["wx-user"] = time.time()
channel._send_text = AsyncMock()
channel._api_post = AsyncMock(return_value={"ret": 1, "errmsg": "no config"})
@@ -1254,3 +1257,526 @@ async def test_send_text_succeeds_on_zero_errcode() -> None:
await channel._send_text("wx-user", "hello", "ctx-ok")
channel._api_post.assert_awaited_once()
+
+
+@pytest.mark.asyncio
+async def test_send_text_raises_on_nonzero_ret_even_when_errcode_zero() -> None:
+ """_send_text must raise when the API returns ret != 0, even if errcode is 0.
+
+ The iLink API signals failure through either field. Checking only errcode
+ caused silent message drops (responses generated but never delivered).
+ """
+ channel, _bus = _make_channel()
+ channel._client = object()
+ channel._token = "token"
+ channel._api_post = AsyncMock(
+ return_value={"ret": -100, "errcode": 0, "errmsg": "internal error"}
+ )
+
+ with pytest.raises(RuntimeError, match="WeChat send text error.*ret=-100.*errcode=0"):
+ await channel._send_text("wx-user", "hello", "ctx-ok")
+
+ channel._api_post.assert_awaited_once()
+
+
+# ---------------------------------------------------------------------------
+# Tests for _poll_once not silently dropping messages on processing errors
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_poll_once_logs_exception_on_process_message_failure(monkeypatch) -> None:
+ """When _process_message raises, _poll_once must log the error and continue
+ processing remaining messages instead of silently swallowing the exception."""
+ channel, _bus = _make_channel()
+ channel._client = SimpleNamespace(timeout=None)
+ channel._token = "token"
+ channel._get_updates_buf = "old-buf"
+
+ calls = []
+ logged_messages: list[str] = []
+
+ async def _failing_process(msg: dict) -> None:
+ calls.append(msg.get("message_id"))
+ if msg.get("message_id") == "msg-1":
+ raise RuntimeError("processing failed")
+
+ channel._process_message = _failing_process # type: ignore[method-assign]
+
+ monkeypatch.setattr(
+ channel.logger,
+ "exception",
+ lambda message, *args, **kwargs: logged_messages.append(str(message)),
+ )
+
+ channel._api_post = AsyncMock( # type: ignore[method-assign]
+ return_value={
+ "ret": 0,
+ "errcode": 0,
+ "get_updates_buf": "new-buf",
+ "msgs": [
+ {"message_id": "msg-1", "message_type": 1},
+ {"message_id": "msg-2", "message_type": 1},
+ ],
+ }
+ )
+
+ await channel._poll_once()
+
+ # Both messages should have been attempted
+ assert calls == ["msg-1", "msg-2"]
+ # Buffer should still advance (already updated before processing)
+ assert channel._get_updates_buf == "new-buf"
+ # Error should be logged
+ assert any("Failed to process WeChat message" in m for m in logged_messages)
+
+
+@pytest.mark.asyncio
+async def test_poll_loop_logs_exception_and_continues_on_poll_failure(monkeypatch) -> None:
+ """When _poll_once raises a non-timeout exception, the start() loop must log
+ the error and continue polling instead of exiting silently."""
+ channel, _bus = _make_channel()
+ channel._client = object()
+ channel._token = "token"
+ channel.config.token = "token" # skip QR login in start()
+ channel._running = True
+
+ call_count = 0
+ logged_messages: list[str] = []
+
+ async def _failing_poll() -> None:
+ nonlocal call_count
+ call_count += 1
+ if call_count == 1:
+ raise RuntimeError("poll exploded")
+ channel._running = False # Stop after second call
+
+ channel._poll_once = _failing_poll # type: ignore[method-assign]
+
+ monkeypatch.setattr(
+ channel.logger,
+ "exception",
+ lambda message, *args, **kwargs: logged_messages.append(str(message)),
+ )
+
+ # Use a tiny retry delay so the test finishes quickly
+ original_retry = weixin_mod.RETRY_DELAY_S
+ weixin_mod.RETRY_DELAY_S = 0.01
+ try:
+ await channel.start()
+ finally:
+ weixin_mod.RETRY_DELAY_S = original_retry
+
+ assert call_count == 2
+ assert any("WeChat poll loop error" in m for m in logged_messages)
+
+
+# ---------------------------------------------------------------------------
+# Tool-hint buffering
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_buffer_single_tool_hint_not_sent_immediately() -> None:
+ channel, _bus = _make_channel()
+ channel._client = object()
+ channel._token = "token"
+ channel.send_tool_hints = True
+ channel._context_tokens["wx-user"] = "ctx-1"
+ channel._context_token_at["wx-user"] = time.time()
+ channel._send_text = AsyncMock()
+
+ await channel.send(
+ type(
+ "Msg",
+ (),
+ {
+ "chat_id": "wx-user",
+ "content": "Using tool",
+ "media": [],
+ "metadata": {"_progress": True, "_tool_hint": True},
+ },
+ )()
+ )
+
+ channel._send_text.assert_not_awaited()
+ assert channel._pending_tool_hints["wx-user"] == ["Using tool"]
+
+
+@pytest.mark.asyncio
+async def test_buffer_multiple_tool_hints_flushed_on_final_answer() -> None:
+ channel, _bus = _make_channel()
+ channel._client = object()
+ channel._token = "token"
+ channel.send_tool_hints = True
+ channel._context_tokens["wx-user"] = "ctx-1"
+ channel._context_token_at["wx-user"] = time.time()
+ channel._send_text = AsyncMock()
+
+ for hint in ["tool1", "tool2"]:
+ await channel.send(
+ type(
+ "Msg",
+ (),
+ {
+ "chat_id": "wx-user",
+ "content": hint,
+ "media": [],
+ "metadata": {"_progress": True, "_tool_hint": True},
+ },
+ )()
+ )
+
+ await channel.send(
+ type(
+ "Msg",
+ (),
+ {
+ "chat_id": "wx-user",
+ "content": "Done",
+ "media": [],
+ "metadata": {},
+ },
+ )()
+ )
+
+ assert channel._send_text.await_count == 2
+ channel._send_text.assert_any_await("wx-user", "tool1\n\ntool2", "ctx-1")
+ channel._send_text.assert_any_await("wx-user", "Done", "ctx-1")
+ assert "wx-user" not in channel._pending_tool_hints
+
+
+@pytest.mark.asyncio
+async def test_thought_progress_flushes_tool_hints() -> None:
+ """Thoughts are visible progress messages and must act as separators,
+ flushing buffered tool hints before they are sent."""
+ channel, _bus = _make_channel()
+ channel._client = object()
+ channel._token = "token"
+ channel.send_tool_hints = True
+ channel._context_tokens["wx-user"] = "ctx-1"
+ channel._context_token_at["wx-user"] = time.time()
+ channel._send_text = AsyncMock()
+
+ # Buffer a tool hint
+ await channel.send(
+ type(
+ "Msg",
+ (),
+ {
+ "chat_id": "wx-user",
+ "content": "search 'foo'",
+ "media": [],
+ "metadata": {"_progress": True, "_tool_hint": True},
+ },
+ )()
+ )
+
+ # Send a thought — progress but not a tool_hint.
+ # It must act as a separator and flush the buffered hint.
+ await channel.send(
+ type(
+ "Msg",
+ (),
+ {
+ "chat_id": "wx-user",
+ "content": "Let me think...",
+ "media": [],
+ "metadata": {"_progress": True},
+ },
+ )()
+ )
+
+ # The buffered hint was flushed before the thought was sent.
+ channel._send_text.assert_any_await("wx-user", "search 'foo'", "ctx-1")
+ channel._send_text.assert_any_await("wx-user", "Let me think...", "ctx-1")
+ assert "wx-user" not in channel._pending_tool_hints
+
+ # Final answer arrives with nothing left to flush.
+ await channel.send(
+ type(
+ "Msg",
+ (),
+ {
+ "chat_id": "wx-user",
+ "content": "Done",
+ "media": [],
+ "metadata": {},
+ },
+ )()
+ )
+
+ assert channel._send_text.await_count == 3
+ channel._send_text.assert_any_await("wx-user", "Done", "ctx-1")
+
+
+@pytest.mark.asyncio
+async def test_reasoning_delta_does_not_flush_tool_hints() -> None:
+ """Reasoning deltas are invisible in WeChat and must NOT flush buffered
+ tool hints — otherwise hints separated only by hidden reasoning would
+ fail to coalesce."""
+ channel, _bus = _make_channel()
+ channel._client = object()
+ channel._token = "token"
+ channel.send_tool_hints = True
+ channel._context_tokens["wx-user"] = "ctx-1"
+ channel._context_token_at["wx-user"] = time.time()
+ channel._send_text = AsyncMock()
+
+ # Buffer a tool hint
+ await channel.send(
+ type(
+ "Msg",
+ (),
+ {
+ "chat_id": "wx-user",
+ "content": "search 'foo'",
+ "media": [],
+ "metadata": {"_progress": True, "_tool_hint": True},
+ },
+ )()
+ )
+
+ # Send a reasoning delta — invisible in WeChat, must NOT flush
+ await channel.send(
+ type(
+ "Msg",
+ (),
+ {
+ "chat_id": "wx-user",
+ "content": "Thinking step 1...",
+ "media": [],
+ "metadata": {"_progress": True, "_reasoning_delta": True},
+ },
+ )()
+ )
+
+ # Reasoning is invisible; hint stays buffered, _send_text not called
+ channel._send_text.assert_not_awaited()
+ assert channel._pending_tool_hints["wx-user"] == ["search 'foo'"]
+
+ # Final answer flushes the buffered hint
+ await channel.send(
+ type(
+ "Msg",
+ (),
+ {
+ "chat_id": "wx-user",
+ "content": "Done",
+ "media": [],
+ "metadata": {},
+ },
+ )()
+ )
+
+ channel._send_text.assert_any_await("wx-user", "search 'foo'", "ctx-1")
+ channel._send_text.assert_any_await("wx-user", "Done", "ctx-1")
+ assert "wx-user" not in channel._pending_tool_hints
+
+
+@pytest.mark.asyncio
+async def test_empty_progress_message_does_not_flush_tool_hints() -> None:
+ """Empty progress messages (e.g. after_iteration tool_events) have no
+ visible content and must NOT act as separators."""
+ channel, _bus = _make_channel()
+ channel._client = object()
+ channel._token = "token"
+ channel.send_tool_hints = True
+ channel._context_tokens["wx-user"] = "ctx-1"
+ channel._context_token_at["wx-user"] = time.time()
+ channel._send_text = AsyncMock()
+
+ # Buffer a tool hint
+ await channel.send(
+ type(
+ "Msg",
+ (),
+ {
+ "chat_id": "wx-user",
+ "content": "search 'foo'",
+ "media": [],
+ "metadata": {"_progress": True, "_tool_hint": True},
+ },
+ )()
+ )
+
+ # Send an empty progress message (no content, no media)
+ await channel.send(
+ type(
+ "Msg",
+ (),
+ {
+ "chat_id": "wx-user",
+ "content": "",
+ "media": [],
+ "metadata": {"_progress": True, "_tool_events": [{"phase": "end"}]},
+ },
+ )()
+ )
+
+ # Nothing should have been sent yet
+ channel._send_text.assert_not_awaited()
+ assert channel._pending_tool_hints["wx-user"] == ["search 'foo'"]
+
+ # Final answer flushes the buffered hint
+ await channel.send(
+ type(
+ "Msg",
+ (),
+ {
+ "chat_id": "wx-user",
+ "content": "Done",
+ "media": [],
+ "metadata": {},
+ },
+ )()
+ )
+
+ channel._send_text.assert_any_await("wx-user", "search 'foo'", "ctx-1")
+ channel._send_text.assert_any_await("wx-user", "Done", "ctx-1")
+ assert "wx-user" not in channel._pending_tool_hints
+
+
+@pytest.mark.asyncio
+async def test_buffer_flush_refreshes_context_token() -> None:
+ channel, _bus = _make_channel()
+ channel._client = object()
+ channel._token = "token"
+ channel.send_tool_hints = True
+ channel._context_tokens["wx-user"] = "ctx-old"
+ channel._context_token_at["wx-user"] = time.time()
+ channel._refresh_context_token_if_stale = AsyncMock(return_value="ctx-refreshed")
+ channel._send_text = AsyncMock()
+
+ await channel.send(
+ type(
+ "Msg",
+ (),
+ {
+ "chat_id": "wx-user",
+ "content": "hint",
+ "media": [],
+ "metadata": {"_progress": True, "_tool_hint": True},
+ },
+ )()
+ )
+
+ await channel.send(
+ type(
+ "Msg",
+ (),
+ {
+ "chat_id": "wx-user",
+ "content": "Done",
+ "media": [],
+ "metadata": {},
+ },
+ )()
+ )
+
+ assert channel._refresh_context_token_if_stale.await_count == 2
+ channel._refresh_context_token_if_stale.assert_any_await("wx-user", "ctx-old")
+ channel._send_text.assert_any_await("wx-user", "hint", "ctx-refreshed")
+
+
+@pytest.mark.asyncio
+async def test_buffer_flush_failure_does_not_block_final_answer() -> None:
+ channel, _bus = _make_channel()
+ channel._client = object()
+ channel._token = "token"
+ channel.send_tool_hints = True
+ channel._context_tokens["wx-user"] = "ctx-1"
+ channel._context_token_at["wx-user"] = time.time()
+ channel._send_text = AsyncMock(side_effect=[RuntimeError("boom"), None])
+
+ await channel.send(
+ type(
+ "Msg",
+ (),
+ {
+ "chat_id": "wx-user",
+ "content": "hint",
+ "media": [],
+ "metadata": {"_progress": True, "_tool_hint": True},
+ },
+ )()
+ )
+
+ await channel.send(
+ type(
+ "Msg",
+ (),
+ {
+ "chat_id": "wx-user",
+ "content": "Done",
+ "media": [],
+ "metadata": {},
+ },
+ )()
+ )
+
+ assert channel._send_text.await_count == 2
+ channel._send_text.assert_any_await("wx-user", "hint", "ctx-1")
+ channel._send_text.assert_any_await("wx-user", "Done", "ctx-1")
+
+
+@pytest.mark.asyncio
+async def test_buffer_flushed_on_stream_end() -> None:
+ channel, _bus = _make_channel()
+ channel._client = object()
+ channel._token = "token"
+ channel.send_tool_hints = True
+ channel._context_tokens["wx-user"] = "ctx-1"
+ channel._context_token_at["wx-user"] = time.time()
+ channel._send_text = AsyncMock()
+
+ await channel.send(
+ type(
+ "Msg",
+ (),
+ {
+ "chat_id": "wx-user",
+ "content": "hint",
+ "media": [],
+ "metadata": {"_progress": True, "_tool_hint": True},
+ },
+ )()
+ )
+
+ await channel.send_delta("wx-user", "", {"_stream_end": True})
+
+ channel._send_text.assert_awaited_once_with("wx-user", "hint", "ctx-1")
+ assert "wx-user" not in channel._pending_tool_hints
+
+
+@pytest.mark.asyncio
+async def test_stop_clears_buffer() -> None:
+ channel, _bus = _make_channel()
+ channel._pending_tool_hints["wx-user"] = ["hint1", "hint2"]
+ await channel.stop()
+ assert "wx-user" not in channel._pending_tool_hints
+
+
+@pytest.mark.asyncio
+async def test_send_tool_hints_false_drops_tool_hints() -> None:
+ channel, _bus = _make_channel()
+ channel._client = object()
+ channel._token = "token"
+ channel.send_tool_hints = False
+ channel._send_text = AsyncMock()
+
+ await channel.send(
+ type(
+ "Msg",
+ (),
+ {
+ "chat_id": "wx-user",
+ "content": "hint",
+ "media": [],
+ "metadata": {"_progress": True, "_tool_hint": True},
+ },
+ )()
+ )
+
+ channel._send_text.assert_not_awaited()
+ assert "wx-user" not in channel._pending_tool_hints
diff --git a/tests/config/test_model_presets.py b/tests/config/test_model_presets.py
index 046c5b04d..fe01c2547 100644
--- a/tests/config/test_model_presets.py
+++ b/tests/config/test_model_presets.py
@@ -192,3 +192,20 @@ def test_match_provider_uses_preset_provider_when_forced() -> None:
})
name = config.get_provider_name()
assert name == "anthropic"
+
+
+def test_match_provider_routes_forced_novita_model_api_models() -> None:
+ config = Config.model_validate({
+ "providers": {
+ "novita": {"apiKey": "sk-test"},
+ },
+ "agents": {
+ "defaults": {
+ "model": "deepseek-v4-pro",
+ "provider": "novita",
+ }
+ },
+ })
+
+ assert config.get_provider_name() == "novita"
+ assert config.get_api_base() == "https://api.novita.ai/openai"
diff --git a/tests/providers/test_custom_provider.py b/tests/providers/test_custom_provider.py
index 85314dc79..ee1f9a090 100644
--- a/tests/providers/test_custom_provider.py
+++ b/tests/providers/test_custom_provider.py
@@ -56,6 +56,35 @@ def test_custom_provider_parse_chunks_accepts_plain_text_chunks() -> None:
assert result.content == "hello world"
+def test_custom_provider_parse_chunks_deduplicates_parallel_tool_call_ids() -> None:
+ chunks = [{
+ "choices": [{
+ "finish_reason": "tool_calls",
+ "delta": {
+ "tool_calls": [
+ {
+ "index": 0,
+ "id": "call_dup",
+ "function": {"name": "read_file", "arguments": '{"path":"a.txt"}'},
+ },
+ {
+ "index": 1,
+ "id": "call_dup",
+ "function": {"name": "read_file", "arguments": '{"path":"b.txt"}'},
+ },
+ ],
+ },
+ }],
+ }]
+
+ result = OpenAICompatProvider._parse_chunks(chunks)
+ ids = [tool_call.id for tool_call in result.tool_calls or []]
+
+ assert ids[0] == "call_dup"
+ assert len(ids) == 2
+ assert len(set(ids)) == 2
+
+
def test_local_provider_502_error_includes_reachability_hint() -> None:
spec = find_by_name("ollama")
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
diff --git a/tests/providers/test_image_generation.py b/tests/providers/test_image_generation.py
index c42d947d5..77025895c 100644
--- a/tests/providers/test_image_generation.py
+++ b/tests/providers/test_image_generation.py
@@ -9,10 +9,13 @@ import pytest
from nanobot.providers.image_generation import (
AIHubMixImageGenerationClient,
+ CodexImageGenerationClient,
GeminiImageGenerationClient,
GeneratedImageResponse,
ImageGenerationError,
MiniMaxImageGenerationClient,
+ OllamaImageGenerationClient,
+ OpenAIImageGenerationClient,
OpenRouterImageGenerationClient,
StepFunImageGenerationClient,
)
@@ -36,12 +39,14 @@ class FakeResponse:
payload: dict[str, Any],
status_code: int = 200,
content: bytes = b"",
+ sse_lines: list[str] | None = None,
) -> None:
self._payload = payload
self.status_code = status_code
self.text = str(payload)
self.content = content
self.request = httpx.Request("POST", "https://openrouter.ai/api/v1/chat/completions")
+ self._sse_lines = sse_lines
def json(self) -> dict[str, Any]:
return self._payload
@@ -51,6 +56,15 @@ class FakeResponse:
response = httpx.Response(self.status_code, request=self.request, text=self.text)
raise httpx.HTTPStatusError("failed", request=self.request, response=response)
+ async def aiter_lines(self):
+ if self._sse_lines is not None:
+ for line in self._sse_lines:
+ yield line
+ return
+ # Fallback: treat response text as SSE lines
+ for line in self.text.split("\n"):
+ yield line
+
class FakeClient:
def __init__(self, response: FakeResponse) -> None:
@@ -133,6 +147,54 @@ async def test_openrouter_image_generation_requires_api_key() -> None:
await client.generate(prompt="draw", model="model")
+@pytest.mark.asyncio
+async def test_ollama_image_generation_payload_and_response() -> None:
+ raw_b64 = PNG_DATA_URL.removeprefix("data:image/png;base64,")
+ fake = FakeClient(FakeResponse({"image": raw_b64}))
+ client = OllamaImageGenerationClient(
+ api_key="ollama-test",
+ api_base="http://localhost:11434/v1/",
+ extra_headers={"X-Test": "1"},
+ extra_body={"seed": 123},
+ client=fake, # type: ignore[arg-type]
+ )
+
+ response = await client.generate(
+ prompt="a sunset",
+ model="x/z-image-turbo",
+ aspect_ratio="16:9",
+ image_size="1K",
+ )
+
+ assert response.images == [PNG_DATA_URL]
+ assert response.content == ""
+
+ call = fake.calls[0]
+ assert call["url"] == "http://localhost:11434/api/generate"
+ assert call["headers"]["Authorization"] == "Bearer ollama-test"
+ assert call["headers"]["X-Test"] == "1"
+ body = call["json"]
+ assert body["model"] == "x/z-image-turbo"
+ assert body["prompt"] == "a sunset"
+ assert body["width"] == 1024
+ assert body["height"] == 576
+ assert body["steps"] == 0
+ assert body["stream"] is False
+ assert body["seed"] == 123
+
+
+@pytest.mark.asyncio
+async def test_ollama_image_generation_rejects_reference_images() -> None:
+ client = OllamaImageGenerationClient(api_key=None)
+
+ with pytest.raises(ImageGenerationError, match="reference images"):
+ await client.generate(
+ prompt="edit this",
+ model="x/z-image-turbo",
+ reference_images=["ref.png"],
+ )
+
+
@pytest.mark.asyncio
async def test_aihubmix_image_generation_payload_and_response() -> None:
raw_b64 = PNG_DATA_URL.removeprefix("data:image/png;base64,")
@@ -531,3 +593,437 @@ async def test_stepfun_no_images_raises() -> None:
with pytest.raises(ImageGenerationError, match="returned no images"):
await client.generate(prompt="draw", model="step-image-edit-2")
+
+
+# ---------------------------------------------------------------------------
+# OpenAI
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_openai_payload_and_response() -> None:
+ fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
+ client = OpenAIImageGenerationClient(
+ api_key="sk-openai-test",
+ api_base="https://api.openai.com/v1",
+ extra_headers={"X-Test": "1"},
+ client=fake, # type: ignore[arg-type]
+ )
+
+ response = await client.generate(
+ prompt="a cat on the moon",
+ model="dall-e-3",
+ aspect_ratio="16:9",
+ )
+
+ assert response.images == [PNG_DATA_URL]
+ call = fake.calls[0]
+ assert call["url"] == "https://api.openai.com/v1/images/generations"
+ assert call["headers"]["Authorization"] == "Bearer sk-openai-test"
+ assert call["headers"]["X-Test"] == "1"
+ body = call["json"]
+ assert body["model"] == "dall-e-3"
+ assert body["prompt"] == "a cat on the moon"
+ assert body["response_format"] == "b64_json"
+ assert body["n"] == 1
+ assert body["size"] == "1792x1024"
+
+
+@pytest.mark.asyncio
+async def test_openai_b64_json_response_uses_detected_mime() -> None:
+ raw_b64 = base64.b64encode(JPEG_BYTES).decode("ascii")
+ fake = FakeClient(FakeResponse({"data": [{"b64_json": raw_b64}]}))
+ client = OpenAIImageGenerationClient(
+ api_key="sk-openai-test",
+ client=fake, # type: ignore[arg-type]
+ )
+
+ response = await client.generate(prompt="draw", model="dall-e-3")
+
+ assert response.images == [f"data:image/jpeg;base64,{raw_b64}"]
+
+
+@pytest.mark.asyncio
+async def test_openai_url_download_fallback() -> None:
+ fake = FakeClient(FakeResponse({"data": [{"url": "https://cdn.example/image.png"}]}))
+ fake.get_response = FakeResponse({}, content=PNG_BYTES)
+ client = OpenAIImageGenerationClient(
+ api_key="sk-openai-test",
+ client=fake, # type: ignore[arg-type]
+ )
+
+ response = await client.generate(prompt="draw", model="dall-e-3")
+
+ assert response.images[0].startswith("data:image/png;base64,")
+ assert fake.get_calls[0]["url"] == "https://cdn.example/image.png"
+
+
+@pytest.mark.asyncio
+async def test_openai_multiple_images() -> None:
+ fake = FakeClient(FakeResponse({
+ "data": [
+ {"b64_json": RAW_B64},
+ {"b64_json": RAW_B64},
+ ]
+ }))
+ client = OpenAIImageGenerationClient(
+ api_key="sk-openai-test",
+ client=fake, # type: ignore[arg-type]
+ )
+
+ response = await client.generate(prompt="draw", model="dall-e-3")
+
+ assert len(response.images) == 2
+ assert response.images == [PNG_DATA_URL, PNG_DATA_URL]
+
+
+@pytest.mark.asyncio
+async def test_openai_aspect_ratio_to_size() -> None:
+ fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
+ client = OpenAIImageGenerationClient(
+ api_key="sk-openai-test",
+ client=fake, # type: ignore[arg-type]
+ )
+
+ await client.generate(prompt="draw", model="dall-e-3", aspect_ratio="1:1")
+ assert fake.calls[0]["json"]["size"] == "1024x1024"
+
+
+@pytest.mark.asyncio
+async def test_openai_dalle3_uses_supported_orientation_sizes() -> None:
+ fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
+ client = OpenAIImageGenerationClient(
+ api_key="sk-openai-test",
+ client=fake, # type: ignore[arg-type]
+ )
+
+ await client.generate(prompt="draw", model="dall-e-3", aspect_ratio="3:4")
+ await client.generate(prompt="draw", model="dall-e-3", aspect_ratio="4:3")
+
+ assert fake.calls[0]["json"]["size"] == "1024x1792"
+ assert fake.calls[1]["json"]["size"] == "1792x1024"
+
+
+@pytest.mark.asyncio
+async def test_openai_dalle2_uses_square_size_for_non_square_ratios() -> None:
+ fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
+ client = OpenAIImageGenerationClient(
+ api_key="sk-openai-test",
+ client=fake, # type: ignore[arg-type]
+ )
+
+ await client.generate(prompt="draw", model="dall-e-2", aspect_ratio="16:9")
+
+ assert fake.calls[0]["json"]["size"] == "1024x1024"
+
+
+@pytest.mark.asyncio
+async def test_openai_gpt_image_uses_supported_landscape_size() -> None:
+ fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
+ client = OpenAIImageGenerationClient(
+ api_key="sk-openai-test",
+ client=fake, # type: ignore[arg-type]
+ )
+
+ await client.generate(prompt="draw", model="gpt-image-1", aspect_ratio="16:9")
+
+ assert fake.calls[0]["json"]["size"] == "1536x1024"
+
+
+@pytest.mark.asyncio
+async def test_openai_gpt_image_uses_supported_orientation_sizes() -> None:
+ fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
+ client = OpenAIImageGenerationClient(
+ api_key="sk-openai-test",
+ client=fake, # type: ignore[arg-type]
+ )
+
+ await client.generate(prompt="draw", model="gpt-image-1", aspect_ratio="3:4")
+ await client.generate(prompt="draw", model="gpt-image-1", aspect_ratio="4:3")
+
+ assert fake.calls[0]["json"]["size"] == "1024x1536"
+ assert fake.calls[1]["json"]["size"] == "1536x1024"
+
+
+@pytest.mark.asyncio
+async def test_openai_default_size_when_no_aspect_ratio() -> None:
+ fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
+ client = OpenAIImageGenerationClient(
+ api_key="sk-openai-test",
+ client=fake, # type: ignore[arg-type]
+ )
+
+ await client.generate(prompt="draw", model="dall-e-3")
+
+ body = fake.calls[0]["json"]
+ assert body["size"] == "1024x1024"
+
+
+@pytest.mark.asyncio
+async def test_openai_ignores_explicit_size_unsupported_by_model_family() -> None:
+ fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
+ client = OpenAIImageGenerationClient(
+ api_key="sk-openai-test",
+ client=fake, # type: ignore[arg-type]
+ )
+
+ await client.generate(
+ prompt="draw",
+ model="dall-e-3",
+ aspect_ratio="16:9",
+ image_size="1536x1024",
+ )
+
+ body = fake.calls[0]["json"]
+ assert body["size"] == "1792x1024"
+
+
+@pytest.mark.asyncio
+async def test_openai_uses_explicit_image_size() -> None:
+ fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]}))
+ client = OpenAIImageGenerationClient(
+ api_key="sk-openai-test",
+ client=fake, # type: ignore[arg-type]
+ )
+
+ await client.generate(
+ prompt="draw",
+ model="dall-e-3",
+ aspect_ratio="16:9",
+ image_size="1024x1024",
+ )
+
+ body = fake.calls[0]["json"]
+ assert body["size"] == "1024x1024"
+
+
+@pytest.mark.asyncio
+async def test_openai_requires_api_key() -> None:
+ client = OpenAIImageGenerationClient(api_key=None)
+
+ with pytest.raises(ImageGenerationError, match="API key"):
+ await client.generate(prompt="draw", model="dall-e-3")
+
+
+# ---------------------------------------------------------------------------
+# OpenAI Codex (Responses API)
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_codex_payload_and_response(monkeypatch) -> None:
+ import sys
+ from dataclasses import dataclass
+ from types import SimpleNamespace
+
+ @dataclass
+ class FakeToken:
+ account_id: str = "acct-123"
+ access: str = "oauth-token"
+
+ async def fake_to_thread(fn, *args, **kwargs):
+ return fn(*args, **kwargs)
+
+ monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
+ fake_oauth = SimpleNamespace(get_token=lambda: FakeToken())
+ monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth)
+
+ sse_lines = [
+ 'data: {"type":"response.output_item.added","item":{"id":"ig_1","type":"image_generation_call","status":"in_progress"}}',
+ "",
+ f'data: {{"type":"response.output_item.done","item":{{"id":"ig_1","type":"image_generation_call","result":"{PNG_DATA_URL}","status":"completed"}}}}',
+ "",
+ 'data: [DONE]',
+ "",
+ ]
+ fake = FakeClient(FakeResponse({}, sse_lines=sse_lines))
+ client = CodexImageGenerationClient(
+ api_key=None,
+ api_base="https://chatgpt.com/backend-api",
+ extra_headers={"X-Test": "1"},
+ client=fake, # type: ignore[arg-type]
+ )
+
+ response = await client.generate(
+ prompt="draw a cat",
+ model="gpt-5.4",
+ )
+
+ assert response.images == [PNG_DATA_URL]
+ assert response.content == ""
+ call = fake.calls[0]
+ assert call["url"] == "https://chatgpt.com/backend-api/codex/responses"
+ assert call["headers"]["Authorization"] == "Bearer oauth-token"
+ assert call["headers"]["chatgpt-account-id"] == "acct-123"
+ assert call["headers"]["OpenAI-Beta"] == "responses=experimental"
+ assert call["headers"]["X-Test"] == "1"
+ body = call["json"]
+ assert body["model"] == "gpt-5.4"
+ assert body["instructions"] == "Generate an image based on the user's request."
+ assert body["input"] == [{"role": "user", "content": "draw a cat"}]
+ assert body["tools"] == [{"type": "image_generation"}]
+ assert body["tool_choice"] == "auto"
+ assert body["store"] is False
+ assert body["stream"] is True
+
+
+@pytest.mark.asyncio
+async def test_codex_strips_model_prefix(monkeypatch) -> None:
+ import sys
+ from dataclasses import dataclass
+ from types import SimpleNamespace
+
+ @dataclass
+ class FakeToken:
+ account_id: str = "acct-123"
+ access: str = "oauth-token"
+
+ async def fake_to_thread(fn, *args, **kwargs):
+ return fn(*args, **kwargs)
+
+ monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
+ fake_oauth = SimpleNamespace(get_token=lambda: FakeToken())
+ monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth)
+
+ fake = FakeClient(FakeResponse({}, sse_lines=[
+ f'data: {{"type":"response.output_item.done","item":{{"type":"image_generation_call","result":"{PNG_DATA_URL}"}}}}',
+ "",
+ 'data: [DONE]',
+ "",
+ ]))
+ client = CodexImageGenerationClient(
+ api_key=None, client=fake # type: ignore[arg-type]
+ )
+
+ await client.generate(prompt="draw", model="openai-codex/gpt-5.4")
+
+ assert fake.calls[0]["json"]["model"] == "gpt-5.4"
+
+
+@pytest.mark.asyncio
+async def test_codex_requires_oauth(monkeypatch) -> None:
+ async def fake_to_thread(fn, *args, **kwargs):
+ raise RuntimeError("no token")
+
+ monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
+
+ client = CodexImageGenerationClient(api_key=None)
+
+ with pytest.raises(ImageGenerationError, match="OAuth token"):
+ await client.generate(prompt="draw", model="gpt-5.4")
+
+
+@pytest.mark.asyncio
+async def test_codex_no_images_raises(monkeypatch) -> None:
+ import sys
+ from dataclasses import dataclass
+ from types import SimpleNamespace
+
+ @dataclass
+ class FakeToken:
+ account_id: str = "acct-123"
+ access: str = "oauth-token"
+
+ async def fake_to_thread(fn, *args, **kwargs):
+ return fn(*args, **kwargs)
+
+ monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
+ fake_oauth = SimpleNamespace(get_token=lambda: FakeToken())
+ monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth)
+
+ fake = FakeClient(FakeResponse({}, sse_lines=[
+ 'data: {"type":"response.completed","response":{"status":"completed"}}',
+ "",
+ 'data: [DONE]',
+ "",
+ ]))
+ client = CodexImageGenerationClient(
+ api_key=None, client=fake # type: ignore[arg-type]
+ )
+
+ with pytest.raises(ImageGenerationError, match="returned no images"):
+ await client.generate(prompt="draw", model="gpt-5.4")
+
+
+@pytest.mark.asyncio
+async def test_codex_extracts_text_content(monkeypatch) -> None:
+ import sys
+ from dataclasses import dataclass
+ from types import SimpleNamespace
+
+ @dataclass
+ class FakeToken:
+ account_id: str = "acct-123"
+ access: str = "oauth-token"
+
+ async def fake_to_thread(fn, *args, **kwargs):
+ return fn(*args, **kwargs)
+
+ monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
+ fake_oauth = SimpleNamespace(get_token=lambda: FakeToken())
+ monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth)
+
+ fake = FakeClient(FakeResponse({}, sse_lines=[
+ 'data: {"type":"response.output_text.delta","delta":"Here "}',
+ "",
+ 'data: {"type":"response.output_text.delta","delta":"is your cat image."}',
+ "",
+ f'data: {{"type":"response.output_item.done","item":{{"type":"image_generation_call","result":"{PNG_DATA_URL}"}}}}',
+ "",
+ 'data: [DONE]',
+ "",
+ ]))
+ client = CodexImageGenerationClient(
+ api_key=None, client=fake # type: ignore[arg-type]
+ )
+
+ response = await client.generate(prompt="draw a cat", model="gpt-5.4")
+
+ assert response.images == [PNG_DATA_URL]
+ assert response.content == "Here is your cat image."
+
+
+@pytest.mark.asyncio
+async def test_codex_json_result_format(monkeypatch) -> None:
+ """image_generation_call result can be a dict with image_url key."""
+ import sys
+ from dataclasses import dataclass
+ from types import SimpleNamespace
+
+ @dataclass
+ class FakeToken:
+ account_id: str = "acct-123"
+ access: str = "oauth-token"
+
+ async def fake_to_thread(fn, *args, **kwargs):
+ return fn(*args, **kwargs)
+
+ monkeypatch.setattr("asyncio.to_thread", fake_to_thread)
+ fake_oauth = SimpleNamespace(get_token=lambda: FakeToken())
+ monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth)
+
+ fake = FakeClient(FakeResponse({}, sse_lines=[
+ f'data: {{"type":"response.output_item.done","item":{{"type":"image_generation_call","result":{{"image_url":"{PNG_DATA_URL}"}}}}}}',
+ "",
+ 'data: [DONE]',
+ "",
+ ]))
+ client = CodexImageGenerationClient(
+ api_key=None, client=fake # type: ignore[arg-type]
+ )
+
+ response = await client.generate(prompt="draw", model="gpt-5.4")
+
+ assert response.images == [PNG_DATA_URL]
+
+
+@pytest.mark.asyncio
+async def test_openai_no_images_raises() -> None:
+ fake = FakeClient(FakeResponse({"data": []}))
+ client = OpenAIImageGenerationClient(
+ api_key="sk-openai-test",
+ client=fake, # type: ignore[arg-type]
+ )
+
+ with pytest.raises(ImageGenerationError, match="returned no images"):
+ await client.generate(prompt="draw", model="dall-e-3")
diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py
index 76414ad35..6e00cba19 100644
--- a/tests/providers/test_litellm_kwargs.py
+++ b/tests/providers/test_litellm_kwargs.py
@@ -441,6 +441,15 @@ def test_openrouter_spec_is_gateway() -> None:
assert spec.default_api_base == "https://openrouter.ai/api/v1"
+def test_novita_spec_uses_openai_compatible_gateway() -> None:
+ spec = find_by_name("novita")
+ assert spec is not None
+ assert spec.is_gateway is True
+ assert spec.backend == "openai_compat"
+ assert spec.env_key == "NOVITA_API_KEY"
+ assert spec.default_api_base == "https://api.novita.ai/openai"
+
+
def test_gemma_routes_to_gemini_provider() -> None:
"""gemma models (e.g. gemma-3-27b-it) must auto-route to Gemini when GEMINI_API_KEY is set.
Users running gemma via the Gemini API endpoint expect automatic provider detection."""
@@ -1013,6 +1022,41 @@ def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -
assert sanitized[2]["tool_call_id"] == "3ec83c30d"
+def test_openai_compat_deduplicates_duplicate_tool_call_ids_in_history() -> None:
+ with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
+ provider = OpenAICompatProvider()
+
+ sanitized = provider._sanitize_messages([
+ {"role": "user", "content": "check both files"},
+ {
+ "role": "assistant",
+ "content": "",
+ "tool_calls": [
+ {
+ "id": "ab1b45c2a",
+ "type": "function",
+ "function": {"name": "read_file", "arguments": '{"path":"a.txt"}'},
+ },
+ {
+ "id": "ab1b45c2a",
+ "type": "function",
+ "function": {"name": "read_file", "arguments": '{"path":"b.txt"}'},
+ },
+ ],
+ },
+ {"role": "tool", "tool_call_id": "ab1b45c2a", "name": "read_file", "content": "a"},
+ {"role": "tool", "tool_call_id": "ab1b45c2a", "name": "read_file", "content": "b"},
+ {"role": "user", "content": "continue"},
+ ])
+
+ tool_call_ids = [tc["id"] for tc in sanitized[1]["tool_calls"]]
+ tool_result_ids = [sanitized[2]["tool_call_id"], sanitized[3]["tool_call_id"]]
+
+ assert tool_call_ids[0] == "ab1b45c2a"
+ assert len(tool_call_ids) == len(set(tool_call_ids)) == 2
+ assert tool_result_ids == tool_call_ids
+
+
def test_openai_compat_stringifies_dict_tool_arguments() -> None:
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider()
@@ -1382,12 +1426,15 @@ def test_kimi_k25_thinking_enabled() -> None:
"""kimi-k2.5 with reasoning_effort set should opt in to thinking."""
kw = _build_kwargs_for("moonshot", "kimi-k2.5", reasoning_effort="medium")
assert kw.get("extra_body") == {"thinking": {"type": "enabled"}}
+ # Moonshot rejects both 'reasoning_effort' and 'thinking' (#3939)
+ assert "reasoning_effort" not in kw
def test_kimi_k25_thinking_disabled_for_minimal() -> None:
"""reasoning_effort='minimal' maps to thinking disabled for kimi-k2.5."""
kw = _build_kwargs_for("moonshot", "kimi-k2.5", reasoning_effort="minimal")
assert kw.get("extra_body") == {"thinking": {"type": "disabled"}}
+ assert "reasoning_effort" not in kw
def test_kimi_k25_no_extra_body_when_reasoning_effort_none() -> None:
@@ -1397,21 +1444,36 @@ def test_kimi_k25_no_extra_body_when_reasoning_effort_none() -> None:
def test_kimi_k25_thinking_enabled_with_openrouter_prefix() -> None:
- """OpenRouter-style model names like moonshotai/kimi-k2.5 must trigger thinking."""
+ """OpenRouter-style model names like moonshotai/kimi-k2.5 must trigger thinking.
+
+ OR drops upstream-provider `thinking` fields, so the same intent also has
+ to go through OR's `reasoning.effort` shape (#3851 follow-up).
+ """
kw = _build_kwargs_for("openrouter", "moonshotai/kimi-k2.5", reasoning_effort="medium")
- assert kw.get("extra_body") == {"thinking": {"type": "enabled"}}
+ assert kw.get("extra_body") == {
+ "thinking": {"type": "enabled"},
+ "reasoning": {"effort": "medium"},
+ }
+ # Even via OR, reasoning_effort wire kwarg is dropped for kimi models
+ assert "reasoning_effort" not in kw
def test_kimi_k26_thinking_enabled() -> None:
"""kimi-k2.6 with reasoning_effort set should opt in to thinking."""
kw = _build_kwargs_for("moonshot", "kimi-k2.6", reasoning_effort="medium")
assert kw.get("extra_body") == {"thinking": {"type": "enabled"}}
+ assert "reasoning_effort" not in kw
def test_kimi_k26_thinking_enabled_with_openrouter_prefix() -> None:
- """OpenRouter-style names like moonshotai/kimi-k2.6 must trigger thinking."""
+ """OpenRouter-style names like moonshotai/kimi-k2.6 must trigger thinking
+ via both upstream `thinking` and OR's `reasoning.effort`."""
kw = _build_kwargs_for("openrouter", "moonshotai/kimi-k2.6", reasoning_effort="medium")
- assert kw.get("extra_body") == {"thinking": {"type": "enabled"}}
+ assert kw.get("extra_body") == {
+ "thinking": {"type": "enabled"},
+ "reasoning": {"effort": "medium"},
+ }
+ assert "reasoning_effort" not in kw
def test_moonshot_kimi_k26_temperature_override() -> None:
@@ -1430,6 +1492,7 @@ def test_kimi_k26_code_preview_thinking_enabled() -> None:
"""k2.6-code-preview also supports thinking; should behave like k2.5."""
kw = _build_kwargs_for("moonshot", "k2.6-code-preview", reasoning_effort="high")
assert kw.get("extra_body") == {"thinking": {"type": "enabled"}}
+ assert "reasoning_effort" not in kw
def test_kimi_k2_series_no_thinking_injection() -> None:
@@ -1459,6 +1522,7 @@ def test_kimi_k25_thinking_disabled_for_none_string() -> None:
"""reasoning_effort='none' maps to thinking disabled for kimi-k2.5."""
kw = _build_kwargs_for("moonshot", "kimi-k2.5", reasoning_effort="none")
assert kw.get("extra_body") == {"thinking": {"type": "disabled"}}
+ assert "reasoning_effort" not in kw
def test_dashscope_thinking_disabled_for_none_string() -> None:
diff --git a/tests/providers/test_novita_provider.py b/tests/providers/test_novita_provider.py
new file mode 100644
index 000000000..0b1e8ec12
--- /dev/null
+++ b/tests/providers/test_novita_provider.py
@@ -0,0 +1,97 @@
+"""Tests for the Novita AI provider registration."""
+
+from unittest.mock import patch
+
+from nanobot.config.schema import Config, ProvidersConfig
+from nanobot.providers.openai_compat_provider import OpenAICompatProvider
+from nanobot.providers.registry import PROVIDERS, find_by_name
+
+
+def test_novita_config_field_exists() -> None:
+ config = ProvidersConfig()
+
+ assert hasattr(config, "novita")
+
+
+def test_novita_provider_in_registry() -> None:
+ specs = {spec.name: spec for spec in PROVIDERS}
+
+ assert "novita" in specs
+ novita = specs["novita"]
+ assert novita.backend == "openai_compat"
+ assert novita.env_key == "NOVITA_API_KEY"
+ assert novita.display_name == "Novita AI"
+ assert novita.is_gateway is True
+ assert novita.detect_by_base_keyword == "novita"
+ assert novita.default_api_base == "https://api.novita.ai/openai"
+ assert novita.strip_model_prefix is False
+
+
+def test_find_by_name_novita() -> None:
+ spec = find_by_name("novita")
+
+ assert spec is not None
+ assert spec.name == "novita"
+
+
+def test_novita_forced_provider_uses_default_api_base() -> None:
+ config = Config.model_validate({
+ "providers": {
+ "novita": {
+ "apiKey": "novita-key",
+ },
+ },
+ "agents": {
+ "defaults": {
+ "model": "deepseek-v4-pro",
+ "provider": "novita",
+ },
+ },
+ })
+
+ assert config.get_provider_name("deepseek-v4-pro") == "novita"
+ assert config.get_api_key("deepseek-v4-pro") == "novita-key"
+ assert config.get_api_base("deepseek-v4-pro") == "https://api.novita.ai/openai"
+
+
+def test_novita_gateway_routes_unprefixed_models_when_configured() -> None:
+ config = Config.model_validate({
+ "providers": {
+ "novita": {
+ "apiKey": "novita-key",
+ },
+ },
+ "agents": {
+ "defaults": {
+ "model": "deepseek-v4-pro",
+ },
+ },
+ })
+
+ assert config.get_provider_name("deepseek-v4-pro") == "novita"
+ assert config.get_api_key("deepseek-v4-pro") == "novita-key"
+ assert config.get_api_base("deepseek-v4-pro") == "https://api.novita.ai/openai"
+
+
+def test_novita_preserves_model_api_id() -> None:
+ spec = find_by_name("novita")
+ with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
+ provider = OpenAICompatProvider(
+ api_key="novita-key",
+ default_model="deepseek-v4-pro",
+ spec=spec,
+ )
+
+ kwargs = provider._build_kwargs(
+ messages=[{"role": "user", "content": "hi"}],
+ tools=None,
+ model="deepseek-v4-pro",
+ max_tokens=1024,
+ temperature=0.7,
+ reasoning_effort=None,
+ tool_choice=None,
+ )
+
+ assert kwargs["model"] == "deepseek-v4-pro"
+ assert kwargs["max_tokens"] == 1024
+ assert "max_completion_tokens" not in kwargs
diff --git a/tests/providers/test_openai_responses.py b/tests/providers/test_openai_responses.py
index 74a934f85..36040db58 100644
--- a/tests/providers/test_openai_responses.py
+++ b/tests/providers/test_openai_responses.py
@@ -155,6 +155,49 @@ class TestConvertMessages:
assert items[0]["id"] == "fc_1"
assert items[0]["name"] == "get_weather"
+ def test_duplicate_response_item_ids_are_made_unique(self):
+ """Codex rejects replayed Responses input items with duplicate ids."""
+ _, items = convert_messages([
+ {
+ "role": "assistant",
+ "content": None,
+ "tool_calls": [{
+ "id": "call_a|rs_same",
+ "function": {"name": "first", "arguments": "{}"},
+ }],
+ },
+ {"role": "tool", "tool_call_id": "call_a|rs_same", "content": "ok"},
+ {
+ "role": "assistant",
+ "content": None,
+ "tool_calls": [{
+ "id": "call_b|rs_same",
+ "function": {"name": "second", "arguments": "{}"},
+ }],
+ },
+ {"role": "tool", "tool_call_id": "call_b|rs_same", "content": "ok"},
+ ])
+ function_call_ids = [
+ item["id"] for item in items if item.get("type") == "function_call"
+ ]
+ assert function_call_ids == ["rs_same", "rs_same_2"]
+ assert len(function_call_ids) == len(set(function_call_ids))
+
+ def test_fallback_response_item_ids_are_unique_with_multiple_tool_calls(self):
+ _, items = convert_messages([{
+ "role": "assistant",
+ "content": None,
+ "tool_calls": [
+ {"id": "call_a", "function": {"name": "first", "arguments": "{}"}},
+ {"id": "call_b", "function": {"name": "second", "arguments": "{}"}},
+ ],
+ }])
+ function_call_ids = [
+ item["id"] for item in items if item.get("type") == "function_call"
+ ]
+ assert function_call_ids == ["fc_0", "fc_0_2"]
+ assert len(function_call_ids) == len(set(function_call_ids))
+
def test_assistant_with_tool_calls_no_id(self):
"""Fallback IDs when tool_call.id is missing."""
_, items = convert_messages([{
diff --git a/tests/providers/test_xiaomi_mimo_thinking.py b/tests/providers/test_xiaomi_mimo_thinking.py
index 68ca6dd80..92161803f 100644
--- a/tests/providers/test_xiaomi_mimo_thinking.py
+++ b/tests/providers/test_xiaomi_mimo_thinking.py
@@ -32,7 +32,7 @@ def _mimo_spec():
def _openrouter_spec():
- """Return the registered OpenRouter ProviderSpec (no thinking_style)."""
+ """Return the registered OpenRouter ProviderSpec."""
specs = {s.name: s for s in PROVIDERS}
return specs["openrouter"]
@@ -77,6 +77,13 @@ def test_xiaomi_mimo_uses_thinking_type_style():
assert spec.default_api_base == "https://api.xiaomimimo.com/v1"
+def test_openrouter_declares_gateway_reasoning_style():
+ """OpenRouter uses its own reasoning.effort field for routed thinking models."""
+ spec = _openrouter_spec()
+ assert spec.thinking_style == ""
+ assert spec.gateway_reasoning_style == "reasoning_effort"
+
+
# ---------------------------------------------------------------------------
# _build_kwargs wire-format
# ---------------------------------------------------------------------------
@@ -142,9 +149,11 @@ def test_mimo_reasoning_effort_unset_preserves_provider_default():
def test_mimo_via_openrouter_reasoning_effort_none_disables_thinking():
- """OpenRouter routes MiMo as "xiaomi/mimo-v2.5-pro"; the openrouter spec
- has no thinking_style, so the disable signal must come from the
- model-name path (#3845)."""
+ """OpenRouter routes MiMo as "xiaomi/mimo-v2.5-pro" and does NOT forward
+ extra_body.thinking to upstream, so a disable signal must also reach OR
+ in its own `reasoning.effort` shape. Verifies both the upstream-MiMo
+ payload (#3845) and the OR-native payload (#3851 follow-up) are sent.
+ """
provider = _openrouter_provider("xiaomi/mimo-v2.5-pro")
kwargs = provider._build_kwargs(
messages=_simple_messages(),
@@ -152,11 +161,15 @@ def test_mimo_via_openrouter_reasoning_effort_none_disables_thinking():
temperature=0.7, reasoning_effort="none", tool_choice=None,
)
assert "reasoning_effort" not in kwargs
- assert kwargs["extra_body"] == {"thinking": {"type": "disabled"}}
+ assert kwargs["extra_body"] == {
+ "thinking": {"type": "disabled"},
+ "reasoning": {"effort": "none"},
+ }
def test_mimo_via_openrouter_reasoning_effort_medium_enables_thinking():
- """Same as the direct path: any non-none/minimal effort enables thinking."""
+ """Non-none/minimal effort enables thinking and the OR `reasoning.effort`
+ field mirrors the requested effort level."""
provider = _openrouter_provider("xiaomi/mimo-v2.5-pro")
kwargs = provider._build_kwargs(
messages=_simple_messages(),
@@ -164,7 +177,10 @@ def test_mimo_via_openrouter_reasoning_effort_medium_enables_thinking():
temperature=0.7, reasoning_effort="medium", tool_choice=None,
)
assert kwargs.get("reasoning_effort") == "medium"
- assert kwargs["extra_body"] == {"thinking": {"type": "enabled"}}
+ assert kwargs["extra_body"] == {
+ "thinking": {"type": "enabled"},
+ "reasoning": {"effort": "medium"},
+ }
def test_mimo_via_openrouter_bare_slug_also_matches():
@@ -176,12 +192,16 @@ def test_mimo_via_openrouter_bare_slug_also_matches():
tools=None, model=None, max_tokens=100,
temperature=0.7, reasoning_effort="none", tool_choice=None,
)
- assert kwargs["extra_body"] == {"thinking": {"type": "disabled"}}
+ assert kwargs["extra_body"] == {
+ "thinking": {"type": "disabled"},
+ "reasoning": {"effort": "none"},
+ }
def test_mimo_flash_via_openrouter_does_not_inject_thinking():
"""mimo-v2-flash has no thinking mode per Xiaomi docs; the allowlist
- excludes it, so no thinking field should be injected on the gateway path."""
+ excludes it, so neither the upstream `thinking` field nor OR's
+ `reasoning.effort` should be injected on the gateway path."""
provider = _openrouter_provider("xiaomi/mimo-v2-flash")
kwargs = provider._build_kwargs(
messages=_simple_messages(),
@@ -200,3 +220,18 @@ def test_non_mimo_model_via_openrouter_unaffected():
temperature=0.7, reasoning_effort="none", tool_choice=None,
)
assert "extra_body" not in kwargs
+
+
+def test_kimi_via_openrouter_also_injects_reasoning_effort():
+ """Kimi has the same gateway problem as MiMo: OR drops the upstream
+ `thinking` field. The same OR-reasoning injection should fire."""
+ provider = _openrouter_provider("moonshotai/kimi-k2.5")
+ kwargs = provider._build_kwargs(
+ messages=_simple_messages(),
+ tools=None, model=None, max_tokens=100,
+ temperature=0.7, reasoning_effort="none", tool_choice=None,
+ )
+ assert kwargs["extra_body"] == {
+ "thinking": {"type": "disabled"},
+ "reasoning": {"effort": "none"},
+ }
diff --git a/tests/tools/test_apply_patch_tool.py b/tests/tools/test_apply_patch_tool.py
new file mode 100644
index 000000000..2ba247368
--- /dev/null
+++ b/tests/tools/test_apply_patch_tool.py
@@ -0,0 +1,330 @@
+from __future__ import annotations
+
+import asyncio
+
+from nanobot.agent.tools.apply_patch import ApplyPatchTool
+
+
+def test_apply_patch_edits_replace(tmp_path):
+ target = tmp_path / "calc.py"
+ target.write_text("def add(a, b):\n return a + b\n")
+ tool = ApplyPatchTool(workspace=tmp_path)
+
+ result = asyncio.run(
+ tool.execute(
+ edits=[
+ {
+ "path": "calc.py",
+ "action": "replace",
+ "old_text": " return a + b",
+ "new_text": " return a - b",
+ }
+ ]
+ )
+ )
+
+ assert "update calc.py" in result
+ assert target.read_text() == "def add(a, b):\n return a - b\n"
+
+
+def test_apply_patch_edits_add_new_file(tmp_path):
+ tool = ApplyPatchTool(workspace=tmp_path)
+
+ result = asyncio.run(
+ tool.execute(
+ edits=[
+ {
+ "path": "config.py",
+ "action": "add",
+ "new_text": "DEBUG = True",
+ }
+ ]
+ )
+ )
+
+ assert "add config.py" in result
+ assert (tmp_path / "config.py").read_text() == "DEBUG = True\n"
+
+
+def test_apply_patch_edits_preserves_new_file_trailing_blank_lines(tmp_path):
+ tool = ApplyPatchTool(workspace=tmp_path)
+
+ result = asyncio.run(
+ tool.execute(
+ edits=[
+ {
+ "path": "notes.txt",
+ "action": "add",
+ "new_text": "one\n\n",
+ }
+ ]
+ )
+ )
+
+ assert "add notes.txt" in result
+ assert (tmp_path / "notes.txt").read_text() == "one\n\n"
+
+
+def test_apply_patch_edits_add_to_existing_file(tmp_path):
+ target = tmp_path / "log.py"
+ target.write_text("import logging\n\nlogger = logging.getLogger(__name__)\n")
+ tool = ApplyPatchTool(workspace=tmp_path)
+
+ result = asyncio.run(
+ tool.execute(
+ edits=[
+ {
+ "path": "log.py",
+ "action": "add",
+ "new_text": "def debug(msg):\n logger.debug(msg)",
+ }
+ ]
+ )
+ )
+
+ assert "update log.py" in result
+ assert (
+ target.read_text()
+ == "import logging\n\nlogger = logging.getLogger(__name__)\ndef debug(msg):\n logger.debug(msg)\n"
+ )
+
+
+def test_apply_patch_edits_delete(tmp_path):
+ target = tmp_path / "utils.py"
+ target.write_text("def unused():\n pass\ndef used():\n return 1\n")
+ tool = ApplyPatchTool(workspace=tmp_path)
+
+ result = asyncio.run(
+ tool.execute(
+ edits=[
+ {
+ "path": "utils.py",
+ "action": "delete",
+ "old_text": "def unused():\n pass\n",
+ }
+ ]
+ )
+ )
+
+ assert "update utils.py" in result
+ assert target.read_text() == "def used():\n return 1\n"
+
+
+def test_apply_patch_edits_delete_entire_file(tmp_path):
+ target = tmp_path / "obsolete.txt"
+ target.write_text("remove me\n")
+ tool = ApplyPatchTool(workspace=tmp_path)
+
+ result = asyncio.run(
+ tool.execute(
+ edits=[
+ {
+ "path": "obsolete.txt",
+ "action": "delete",
+ "old_text": "remove me\n",
+ }
+ ]
+ )
+ )
+
+ assert "delete obsolete.txt" in result
+ assert not target.exists()
+
+
+def test_apply_patch_edits_delete_substring_with_surrounding_whitespace(tmp_path):
+ target = tmp_path / "keep_whitespace.txt"
+ target.write_text(" token \n")
+ tool = ApplyPatchTool(workspace=tmp_path)
+
+ result = asyncio.run(
+ tool.execute(
+ edits=[
+ {
+ "path": "keep_whitespace.txt",
+ "action": "delete",
+ "old_text": "token",
+ }
+ ]
+ )
+ )
+
+ assert "update keep_whitespace.txt" in result
+ assert target.exists()
+ assert target.read_text() == " \n"
+
+
+def test_apply_patch_edits_batch_multiple_files(tmp_path):
+ a = tmp_path / "a.py"
+ a.write_text("X = 1\n")
+ b = tmp_path / "b.py"
+ b.write_text("from a import X\nprint(X)\n")
+ tool = ApplyPatchTool(workspace=tmp_path)
+
+ result = asyncio.run(
+ tool.execute(
+ edits=[
+ {
+ "path": "a.py",
+ "action": "replace",
+ "old_text": "X = 1",
+ "new_text": "Y = 1",
+ },
+ {
+ "path": "b.py",
+ "action": "replace",
+ "old_text": "from a import X",
+ "new_text": "from a import Y",
+ },
+ ]
+ )
+ )
+
+ assert "update a.py" in result
+ assert "update b.py" in result
+ assert a.read_text() == "Y = 1\n"
+ assert b.read_text() == "from a import Y\nprint(X)\n"
+
+
+def test_apply_patch_edits_rejects_ambiguous_old_text(tmp_path):
+ target = tmp_path / "repeated.txt"
+ target.write_text("target\nmiddle\ntarget\n")
+ tool = ApplyPatchTool(workspace=tmp_path)
+
+ result = asyncio.run(
+ tool.execute(
+ edits=[
+ {
+ "path": "repeated.txt",
+ "action": "replace",
+ "old_text": "target",
+ "new_text": "changed",
+ }
+ ]
+ )
+ )
+
+ assert "old_text appears multiple times" in result
+ assert target.read_text() == "target\nmiddle\ntarget\n"
+
+
+def test_apply_patch_edits_dry_run_validates_without_writing(tmp_path):
+ target = tmp_path / "dry.txt"
+ target.write_text("before\n")
+ tool = ApplyPatchTool(workspace=tmp_path)
+
+ result = asyncio.run(
+ tool.execute(
+ edits=[
+ {
+ "path": "dry.txt",
+ "action": "replace",
+ "old_text": "before",
+ "new_text": "after",
+ },
+ {
+ "path": "added.txt",
+ "action": "add",
+ "new_text": "new",
+ },
+ ],
+ dry_run=True,
+ )
+ )
+
+ assert "Patch dry-run succeeded" in result
+ assert target.read_text() == "before\n"
+ assert not (tmp_path / "added.txt").exists()
+
+
+def test_apply_patch_edits_rejects_absolute_and_parent_paths(tmp_path):
+ tool = ApplyPatchTool(workspace=tmp_path)
+
+ absolute = asyncio.run(
+ tool.execute(
+ edits=[
+ {
+ "path": "/tmp/owned.txt",
+ "action": "add",
+ "new_text": "nope",
+ }
+ ]
+ )
+ )
+ parent = asyncio.run(
+ tool.execute(
+ edits=[
+ {
+ "path": "../owned.txt",
+ "action": "add",
+ "new_text": "nope",
+ }
+ ]
+ )
+ )
+ windows_absolute = asyncio.run(
+ tool.execute(
+ edits=[
+ {
+ "path": r"C:\owned.txt",
+ "action": "add",
+ "new_text": "nope",
+ }
+ ]
+ )
+ )
+ windows_parent = asyncio.run(
+ tool.execute(
+ edits=[
+ {
+ "path": r"..\owned.txt",
+ "action": "add",
+ "new_text": "nope",
+ }
+ ]
+ )
+ )
+
+ assert "must be relative" in absolute
+ assert "must not contain '..'" in parent
+ assert "must be relative" in windows_absolute
+ assert "must not contain '..'" in windows_parent
+ assert not (tmp_path.parent / "owned.txt").exists()
+
+
+def test_apply_patch_edits_reports_invalid_edit_shapes(tmp_path):
+ tool = ApplyPatchTool(workspace=tmp_path)
+
+ missing_path = asyncio.run(tool.execute(edits=[{"action": "add", "new_text": "x"}]))
+ missing_action = asyncio.run(tool.execute(edits=[{"path": "x.txt", "new_text": "x"}]))
+ non_object = asyncio.run(tool.execute(edits=["not an object"])) # type: ignore[list-item]
+
+ assert "path required for edit" in missing_path
+ assert "action required for edit: x.txt" in missing_action
+ assert "each edit must be an object" in non_object
+
+
+def test_apply_patch_edits_rolls_back_when_late_operation_fails(tmp_path):
+ first = tmp_path / "first.txt"
+ first.write_text("before\n")
+ tool = ApplyPatchTool(workspace=tmp_path)
+
+ result = asyncio.run(
+ tool.execute(
+ edits=[
+ {
+ "path": "first.txt",
+ "action": "replace",
+ "old_text": "before",
+ "new_text": "after",
+ },
+ {
+ "path": "missing.txt",
+ "action": "delete",
+ "old_text": "remove me",
+ },
+ ]
+ )
+ )
+
+ assert "file to update does not exist: missing.txt" in result
+ assert first.read_text() == "before\n"
diff --git a/tests/tools/test_edit_enhancements.py b/tests/tools/test_edit_enhancements.py
index 1f22c963b..7202fc37b 100644
--- a/tests/tools/test_edit_enhancements.py
+++ b/tests/tools/test_edit_enhancements.py
@@ -1,5 +1,5 @@
"""Tests for EditFileTool enhancements: read-before-edit tracking, path suggestions,
-.ipynb detection, and create-file semantics."""
+notebook JSON editing, and create-file semantics."""
import pytest
@@ -108,22 +108,27 @@ class TestEditCreateFile:
# ---------------------------------------------------------------------------
-# .ipynb detection
+# .ipynb editing
# ---------------------------------------------------------------------------
-class TestEditIpynbDetection:
- """edit_file should refuse .ipynb and suggest notebook_edit."""
+class TestEditIpynbFiles:
+ """edit_file edits notebooks as normal JSON files."""
@pytest.fixture()
def tool(self, tmp_path):
return EditFileTool(workspace=tmp_path)
@pytest.mark.asyncio
- async def test_ipynb_rejected_with_suggestion(self, tool, tmp_path):
+ async def test_ipynb_can_be_edited_as_json(self, tool, tmp_path):
f = tmp_path / "analysis.ipynb"
f.write_text('{"cells": []}', encoding="utf-8")
- result = await tool.execute(path=str(f), old_text="x", new_text="y")
- assert "notebook" in result.lower()
+ result = await tool.execute(
+ path=str(f),
+ old_text='"cells": []',
+ new_text='"cells": [{"cell_type": "markdown", "source": "hi"}]',
+ )
+ assert "Successfully edited" in result
+ assert '"source": "hi"' in f.read_text(encoding="utf-8")
# ---------------------------------------------------------------------------
diff --git a/tests/tools/test_exec_platform.py b/tests/tools/test_exec_platform.py
index 69a271ec1..ffb25f985 100644
--- a/tests/tools/test_exec_platform.py
+++ b/tests/tools/test_exec_platform.py
@@ -162,7 +162,7 @@ class TestPathAppendPlatform:
captured_cmd = None
captured_env = {}
- async def capture_spawn(cmd, cwd, env):
+ async def capture_spawn(cmd, cwd, env, shell_program=None, login=True):
nonlocal captured_cmd
captured_cmd = cmd
captured_env.update(env)
@@ -190,7 +190,7 @@ class TestPathAppendPlatform:
captured_env = {}
- async def capture_spawn(cmd, cwd, env):
+ async def capture_spawn(cmd, cwd, env, shell_program=None, login=True):
captured_env.update(env)
return mock_proc
diff --git a/tests/tools/test_exec_session_tools.py b/tests/tools/test_exec_session_tools.py
new file mode 100644
index 000000000..f5fe45e96
--- /dev/null
+++ b/tests/tools/test_exec_session_tools.py
@@ -0,0 +1,361 @@
+from __future__ import annotations
+
+import asyncio
+import re
+import shlex
+import subprocess
+import sys
+
+from nanobot.agent.tools.shell import ExecTool
+from nanobot.agent.tools.exec_session import ExecSessionManager, ListExecSessionsTool, WriteStdinTool
+
+
+def _python_command(code: str) -> str:
+ if sys.platform == "win32":
+ return f"{subprocess.list2cmdline([sys.executable])} -u -c {subprocess.list2cmdline([code])}"
+ return f"{shlex.quote(sys.executable)} -u -c {shlex.quote(code)}"
+
+
+def _session_id(output: str) -> str:
+ match = re.search(r"session_id:\s*([0-9a-f]+)", output)
+ assert match, output
+ return match.group(1)
+
+
+def test_exec_keeps_one_shot_behavior_without_yield_time_ms(tmp_path):
+ async def run() -> str:
+ tool = ExecTool(working_dir=str(tmp_path), timeout=5)
+ return await tool.execute(command="echo hello")
+
+ result = asyncio.run(run())
+
+ assert "hello" in result
+ assert "Exit code: 0" in result
+ assert "session_id:" not in result
+
+
+def test_exec_accepts_command_aliases(tmp_path):
+ async def run() -> str:
+ tool = ExecTool(working_dir="/")
+ return await tool.execute(
+ cmd=_python_command("import os; print(os.getcwd())"),
+ workdir=str(tmp_path),
+ )
+
+ result = asyncio.run(run())
+
+ assert str(tmp_path) in result
+ assert "Exit code: 0" in result
+
+
+def test_exec_returns_completed_session_output_when_yield_time_ms_is_used(tmp_path):
+ async def run() -> str:
+ manager = ExecSessionManager()
+ tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
+ stdin_tool = WriteStdinTool(manager=manager)
+
+ result = await tool.execute(command="echo hello", yield_time_ms=1000)
+ if "session_id:" in result:
+ sid = _session_id(result)
+ result += "\n" + await stdin_tool.execute(
+ session_id=sid,
+ chars="",
+ yield_time_ms=1000,
+ )
+ return result
+
+ result = asyncio.run(run())
+
+ assert "hello" in result
+ assert "Exit code: 0" in result
+ assert "session_id:" not in result
+
+
+def test_exec_session_accepts_max_output_tokens_alias(tmp_path):
+ async def run() -> str:
+ manager = ExecSessionManager()
+ tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
+ command = _python_command("print('A' * 2000)")
+ return await tool.execute(
+ command=command,
+ yield_time_ms=1000,
+ max_output_tokens=1000,
+ )
+
+ result = asyncio.run(run())
+
+ assert "chars truncated" in result
+ assert "Exit code: 0" in result
+
+
+def test_exec_one_shot_accepts_max_output_tokens_alias(tmp_path):
+ async def run() -> str:
+ tool = ExecTool(working_dir=str(tmp_path), timeout=5)
+ command = _python_command("print('A' * 2000)")
+ return await tool.execute(command=command, max_output_tokens=1000)
+
+ result = asyncio.run(run())
+
+ assert "chars truncated" in result
+ assert "Exit code: 0" in result
+
+
+def test_exec_accepts_supported_shell_parameter(tmp_path):
+ async def run() -> str:
+ tool = ExecTool(working_dir=str(tmp_path), timeout=5)
+ return await tool.execute(command="echo shell-ok", shell="sh", login=False)
+
+ if sys.platform == "win32":
+ return
+ result = asyncio.run(run())
+
+ assert "shell-ok" in result
+ assert "Exit code: 0" in result
+
+
+def test_exec_rejects_unsupported_shell(tmp_path):
+ async def run() -> str:
+ tool = ExecTool(working_dir=str(tmp_path), timeout=5)
+ return await tool.execute(command="echo no", shell="python")
+
+ if sys.platform == "win32":
+ return
+ result = asyncio.run(run())
+
+ assert "unsupported shell" in result
+
+
+def test_exec_can_continue_with_stdin(tmp_path):
+ async def run() -> tuple[str, str]:
+ manager = ExecSessionManager()
+ exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
+ stdin_tool = WriteStdinTool(manager=manager)
+ command = _python_command(
+ "import sys; print('ready', flush=True); "
+ "line=sys.stdin.readline(); print('got:' + line.strip(), flush=True)"
+ )
+
+ initial = await exec_tool.execute(command=command, yield_time_ms=500)
+ sid = _session_id(initial)
+ result = await stdin_tool.execute(session_id=sid, chars="ping\n", yield_time_ms=1000)
+ return initial, result
+
+ initial, result = asyncio.run(run())
+ assert "ready" in initial
+ assert "Process running" in initial
+ assert "Elapsed:" in initial
+ assert "got:ping" in result
+ assert "Exit code: 0" in result
+ assert "Elapsed:" in result
+
+
+def test_write_stdin_can_close_stdin(tmp_path):
+ async def run() -> tuple[str, str]:
+ manager = ExecSessionManager()
+ exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
+ stdin_tool = WriteStdinTool(manager=manager)
+ command = _python_command(
+ "import sys; print('ready', flush=True); "
+ "data=sys.stdin.read(); print('got:' + data, flush=True)"
+ )
+
+ initial = await exec_tool.execute(command=command, yield_time_ms=500)
+ sid = _session_id(initial)
+ result = await stdin_tool.execute(
+ session_id=sid,
+ chars="payload",
+ close_stdin=True,
+ yield_time_ms=1000,
+ )
+ return initial, result
+
+ initial, result = asyncio.run(run())
+ assert "ready" in initial
+ assert "got:payload" in result
+ assert "Stdin closed." in result
+ assert "Exit code: 0" in result
+
+
+def test_write_stdin_can_terminate_session(tmp_path):
+ async def run() -> tuple[str, str]:
+ manager = ExecSessionManager()
+ exec_tool = ExecTool(working_dir=str(tmp_path), timeout=30, session_manager=manager)
+ stdin_tool = WriteStdinTool(manager=manager)
+ command = _python_command(
+ "import time; print('ready', flush=True); time.sleep(30)"
+ )
+
+ initial = await exec_tool.execute(command=command, yield_time_ms=500)
+ sid = _session_id(initial)
+ result = await stdin_tool.execute(
+ session_id=sid,
+ terminate=True,
+ yield_time_ms=0,
+ )
+ return initial, result
+
+ initial, result = asyncio.run(run())
+ assert "ready" in initial
+ assert "Session terminated." in result
+ assert "Exit code:" in result
+
+
+def test_write_stdin_accepts_max_output_tokens_alias(tmp_path):
+ async def run() -> tuple[str, str, str]:
+ manager = ExecSessionManager()
+ exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
+ stdin_tool = WriteStdinTool(manager=manager)
+ command = _python_command(
+ "import time; print('A' * 2000, flush=True); time.sleep(5)"
+ )
+
+ initial = await exec_tool.execute(command=command, yield_time_ms=0)
+ sid = _session_id(initial)
+ poll = await stdin_tool.execute(
+ session_id=sid,
+ yield_time_ms=500,
+ max_output_tokens=1000,
+ )
+ cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0)
+ return initial, poll, cleanup
+
+ initial, poll, cleanup = asyncio.run(run())
+ assert "Process running" in initial
+ assert "chars truncated" in poll
+ assert "Session terminated." in cleanup
+
+
+def test_write_stdin_preserves_completed_session_output_until_polled(tmp_path):
+ async def run() -> tuple[str, str]:
+ manager = ExecSessionManager()
+ exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
+ stdin_tool = WriteStdinTool(manager=manager)
+ command = _python_command(
+ "import time; print('ready', flush=True); "
+ "time.sleep(1.0); print('done', flush=True)"
+ )
+
+ initial = await exec_tool.execute(command=command, yield_time_ms=300)
+ sid = _session_id(initial)
+ await asyncio.sleep(1.2)
+ final = await stdin_tool.execute(session_id=sid, chars="", yield_time_ms=0)
+ return initial, final
+
+ initial, final = asyncio.run(run())
+
+ assert "ready" in initial
+ assert "done" in final
+ assert "Exit code: 0" in final
+
+
+def test_write_stdin_can_wait_for_expected_output(tmp_path):
+ async def run() -> tuple[str, str, str]:
+ manager = ExecSessionManager()
+ exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
+ stdin_tool = WriteStdinTool(manager=manager)
+ command = _python_command(
+ "import time; print('booting', flush=True); "
+ "time.sleep(0.4); print('ready', flush=True); time.sleep(5)"
+ )
+
+ initial = await exec_tool.execute(command=command, yield_time_ms=100)
+ sid = _session_id(initial)
+ waited = await stdin_tool.execute(
+ session_id=sid,
+ wait_for="ready",
+ wait_timeout_ms=3000,
+ yield_time_ms=0,
+ )
+ cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0)
+ return initial, waited, cleanup
+
+ initial, waited, cleanup = asyncio.run(run())
+
+ assert "Process running" in initial
+ assert "booting" in initial + waited
+ assert "ready" in waited
+ assert "Wait target not observed" not in waited
+ assert "Session terminated." in cleanup
+
+
+def test_write_stdin_wait_for_reports_timeout_without_killing_session(tmp_path):
+ async def run() -> tuple[str, str, str]:
+ manager = ExecSessionManager()
+ exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
+ stdin_tool = WriteStdinTool(manager=manager)
+ command = _python_command(
+ "import time; print('booting', flush=True); time.sleep(5)"
+ )
+
+ initial = await exec_tool.execute(command=command, yield_time_ms=100)
+ sid = _session_id(initial)
+ waited = await stdin_tool.execute(
+ session_id=sid,
+ wait_for="never-ready",
+ wait_timeout_ms=200,
+ yield_time_ms=0,
+ )
+ cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0)
+ return initial, waited, cleanup
+
+ initial, waited, cleanup = asyncio.run(run())
+
+ assert "Process running" in initial
+ assert "booting" in initial + waited
+ assert "Process running" in waited
+ assert "Wait target not observed: 'never-ready'" in waited
+ assert "Session terminated." in cleanup
+
+
+def test_exec_session_mode_reuses_exec_safety_guard(tmp_path):
+ manager = ExecSessionManager()
+ tool = ExecTool(
+ working_dir=str(tmp_path),
+ deny_patterns=[r"echo\s+blocked"],
+ session_manager=manager,
+ )
+
+ result = asyncio.run(tool.execute(command="echo blocked", yield_time_ms=0))
+
+ assert "blocked by deny pattern" in result
+
+
+def test_write_stdin_reports_missing_session(tmp_path):
+ manager = ExecSessionManager()
+ tool = WriteStdinTool(manager=manager)
+
+ result = asyncio.run(tool.execute(session_id="missing", chars=""))
+
+ assert "exec session not found" in result
+
+
+def test_list_exec_sessions_reports_running_commands(tmp_path):
+ async def run() -> tuple[str, str, str]:
+ manager = ExecSessionManager()
+ exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
+ list_tool = ListExecSessionsTool(manager=manager)
+ stdin_tool = WriteStdinTool(manager=manager)
+ command = _python_command(
+ "import time; print('ready', flush=True); time.sleep(5)"
+ )
+
+ initial = await exec_tool.execute(command=command, yield_time_ms=500)
+ sid = _session_id(initial)
+ listing = await list_tool.execute()
+ cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0)
+ return sid, listing, cleanup
+
+ sid, listing, cleanup = asyncio.run(run())
+
+ assert sid in listing
+ assert "running" in listing
+ assert "elapsed=" in listing
+ assert "remaining=" in listing
+ assert str(tmp_path) in listing
+ assert "Session terminated." in cleanup
+
+
+def test_list_exec_sessions_reports_empty_state():
+ result = asyncio.run(ListExecSessionsTool(manager=ExecSessionManager()).execute())
+
+ assert result == "No active exec sessions."
diff --git a/tests/tools/test_file_edit_coding_enhancements.py b/tests/tools/test_file_edit_coding_enhancements.py
new file mode 100644
index 000000000..d361d88ae
--- /dev/null
+++ b/tests/tools/test_file_edit_coding_enhancements.py
@@ -0,0 +1,216 @@
+from __future__ import annotations
+
+import asyncio
+
+from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool
+
+
+def test_read_file_force_bypasses_dedup(tmp_path):
+ target = tmp_path / "data.txt"
+ target.write_text("alpha\n")
+ tool = ReadFileTool(workspace=tmp_path)
+
+ first = asyncio.run(tool.execute(path=str(target)))
+ second = asyncio.run(tool.execute(path=str(target)))
+ forced = asyncio.run(tool.execute(path=str(target), force=True))
+
+ assert "alpha" in first
+ assert "unchanged" in second.lower()
+ assert "alpha" in forced
+ assert "unchanged" not in forced.lower()
+
+
+def test_edit_file_can_select_occurrence(tmp_path):
+ target = tmp_path / "duplicate.txt"
+ target.write_text("one\nsame\ntwo\nsame\n")
+ tool = EditFileTool(workspace=tmp_path)
+
+ result = asyncio.run(tool.execute(
+ path=str(target),
+ old_text="same",
+ new_text="changed",
+ occurrence=2,
+ ))
+
+ assert "Successfully edited" in result
+ assert target.read_text() == "one\nsame\ntwo\nchanged\n"
+
+
+def test_edit_file_expected_replacements_guards_replace_all(tmp_path):
+ target = tmp_path / "duplicate.txt"
+ target.write_text("same\nsame\n")
+ tool = EditFileTool(workspace=tmp_path)
+
+ result = asyncio.run(tool.execute(
+ path=str(target),
+ old_text="same",
+ new_text="changed",
+ replace_all=True,
+ expected_replacements=1,
+ ))
+
+ assert "expected 1 replacements but would make 2" in result
+ assert target.read_text() == "same\nsame\n"
+
+
+def test_edit_file_expected_replacements_allows_replace_all_when_count_matches(tmp_path):
+ target = tmp_path / "duplicate.txt"
+ target.write_text("same\nsame\n")
+ tool = EditFileTool(workspace=tmp_path)
+
+ result = asyncio.run(tool.execute(
+ path=str(target),
+ old_text="same",
+ new_text="changed",
+ replace_all=True,
+ expected_replacements=2,
+ ))
+
+ assert "Successfully edited" in result
+ assert target.read_text() == "changed\nchanged\n"
+
+
+def test_edit_file_can_select_nearest_line_hint(tmp_path):
+ target = tmp_path / "duplicate.txt"
+ target.write_text("one\nsame\ntwo\nsame\n")
+ tool = EditFileTool(workspace=tmp_path)
+
+ result = asyncio.run(tool.execute(
+ path=str(target),
+ old_text="same",
+ new_text="changed",
+ line_hint=4,
+ ))
+
+ assert "Successfully edited" in result
+ assert target.read_text() == "one\nsame\ntwo\nchanged\n"
+
+
+def test_edit_file_can_edit_ipynb_as_json(tmp_path):
+ target = tmp_path / "analysis.ipynb"
+ target.write_text('{"cells": []}')
+ tool = EditFileTool(workspace=tmp_path)
+
+ result = asyncio.run(tool.execute(
+ path=str(target),
+ old_text='"cells": []',
+ new_text='"cells": [{"cell_type": "markdown", "source": "hi"}]',
+ ))
+
+ assert "Successfully edited" in result
+ assert '"source": "hi"' in target.read_text()
+
+
+def test_edit_file_multiple_match_hint_mentions_occurrence(tmp_path):
+ target = tmp_path / "duplicate.txt"
+ target.write_text("same\nsame\n")
+ tool = EditFileTool(workspace=tmp_path)
+
+ result = asyncio.run(tool.execute(
+ path=str(target),
+ old_text="same",
+ new_text="changed",
+ ))
+
+ assert "old_text appears 2 times" in result
+ assert "occurrence" in result
+ assert target.read_text() == "same\nsame\n"
+
+
+def test_edit_file_rejects_ambiguous_line_hint(tmp_path):
+ target = tmp_path / "duplicate.txt"
+ target.write_text("same\nmiddle\nsame\n")
+ tool = EditFileTool(workspace=tmp_path)
+
+ result = asyncio.run(tool.execute(
+ path=str(target),
+ old_text="same",
+ new_text="changed",
+ line_hint=2,
+ ))
+
+ assert "line_hint 2 is ambiguous" in result
+ assert target.read_text() == "same\nmiddle\nsame\n"
+
+
+def test_edit_file_rejects_occurrence_with_replace_all(tmp_path):
+ target = tmp_path / "duplicate.txt"
+ target.write_text("same\nsame\n")
+ tool = EditFileTool(workspace=tmp_path)
+
+ result = asyncio.run(tool.execute(
+ path=str(target),
+ old_text="same",
+ new_text="changed",
+ occurrence=1,
+ replace_all=True,
+ ))
+
+ assert "occurrence cannot be used with replace_all" in result
+ assert target.read_text() == "same\nsame\n"
+
+
+def test_edit_file_rejects_line_hint_with_replace_all(tmp_path):
+ target = tmp_path / "duplicate.txt"
+ target.write_text("same\nsame\n")
+ tool = EditFileTool(workspace=tmp_path)
+
+ result = asyncio.run(tool.execute(
+ path=str(target),
+ old_text="same",
+ new_text="changed",
+ line_hint=1,
+ replace_all=True,
+ ))
+
+ assert "line_hint cannot be used with replace_all" in result
+ assert target.read_text() == "same\nsame\n"
+
+
+def test_edit_file_rejects_line_hint_with_occurrence(tmp_path):
+ target = tmp_path / "duplicate.txt"
+ target.write_text("same\nsame\n")
+ tool = EditFileTool(workspace=tmp_path)
+
+ result = asyncio.run(tool.execute(
+ path=str(target),
+ old_text="same",
+ new_text="changed",
+ occurrence=1,
+ line_hint=1,
+ ))
+
+ assert "line_hint cannot be used with occurrence" in result
+ assert target.read_text() == "same\nsame\n"
+
+
+def test_edit_file_rejects_zero_occurrence(tmp_path):
+ target = tmp_path / "duplicate.txt"
+ target.write_text("same\n")
+ tool = EditFileTool(workspace=tmp_path)
+
+ result = asyncio.run(tool.execute(
+ path=str(target),
+ old_text="same",
+ new_text="changed",
+ occurrence=0,
+ ))
+
+ assert "occurrence must be >= 1" in result
+ assert target.read_text() == "same\n"
+
+
+def test_edit_file_rejects_zero_line_hint(tmp_path):
+ target = tmp_path / "duplicate.txt"
+ target.write_text("same\n")
+ tool = EditFileTool(workspace=tmp_path)
+
+ result = asyncio.run(tool.execute(
+ path=str(target),
+ old_text="same",
+ new_text="changed",
+ line_hint=0,
+ ))
+
+ assert "line_hint must be >= 1" in result
+ assert target.read_text() == "same\n"
diff --git a/tests/tools/test_image_generation_tool.py b/tests/tools/test_image_generation_tool.py
index 92ed8a339..f5d2d9183 100644
--- a/tests/tools/test_image_generation_tool.py
+++ b/tests/tools/test_image_generation_tool.py
@@ -138,6 +138,39 @@ async def test_generate_image_tool_reports_missing_aihubmix_key(tmp_path: Path)
assert result.startswith("Error: AIHubMix API key is not configured")
+@pytest.mark.asyncio
+async def test_generate_image_tool_allows_ollama_without_api_key(
+ tmp_path: Path,
+ monkeypatch: pytest.MonkeyPatch,
+) -> None:
+ set_config_path(tmp_path / "config.json")
+ FakeImageClient.instances = []
+ monkeypatch.setattr(
+ "nanobot.agent.tools.image_generation.get_image_gen_provider",
+ lambda name: FakeImageClient if name == "ollama" else None,
+ )
+ tool = ImageGenerationTool(
+ workspace=tmp_path,
+ config=ImageGenerationToolConfig(
+ enabled=True,
+ provider="ollama",
+ model="x/z-image-turbo",
+ ),
+ provider_configs={"ollama": ProviderConfig(api_base="http://localhost:11434/v1")},
+ )
+
+ result = await tool.execute(prompt="draw a cat")
+
+ payload = json.loads(result)
+ assert len(payload["artifacts"]) == 1
+
+ fake = FakeImageClient.instances[0]
+ assert fake.kwargs["api_key"] is None
+ assert fake.kwargs["api_base"] == "http://localhost:11434/v1"
+ assert fake.calls[0]["aspect_ratio"] == "1:1"
+ assert fake.calls[0]["image_size"] == "1K"
+
+
@pytest.mark.asyncio
async def test_generate_image_tool_rejects_reference_outside_workspace(tmp_path: Path) -> None:
set_config_path(tmp_path / "config.json")
diff --git a/tests/tools/test_notebook_tool.py b/tests/tools/test_notebook_tool.py
deleted file mode 100644
index 232f13c4b..000000000
--- a/tests/tools/test_notebook_tool.py
+++ /dev/null
@@ -1,147 +0,0 @@
-"""Tests for NotebookEditTool — Jupyter .ipynb editing."""
-
-import json
-
-import pytest
-
-from nanobot.agent.tools.notebook import NotebookEditTool
-
-
-def _make_notebook(cells: list[dict] | None = None, nbformat: int = 4, nbformat_minor: int = 5) -> dict:
- """Build a minimal valid .ipynb structure."""
- return {
- "nbformat": nbformat,
- "nbformat_minor": nbformat_minor,
- "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}},
- "cells": cells or [],
- }
-
-
-def _code_cell(source: str, cell_id: str | None = None) -> dict:
- cell = {"cell_type": "code", "source": source, "metadata": {}, "outputs": [], "execution_count": None}
- if cell_id:
- cell["id"] = cell_id
- return cell
-
-
-def _md_cell(source: str, cell_id: str | None = None) -> dict:
- cell = {"cell_type": "markdown", "source": source, "metadata": {}}
- if cell_id:
- cell["id"] = cell_id
- return cell
-
-
-def _write_nb(tmp_path, name: str, nb: dict) -> str:
- p = tmp_path / name
- p.write_text(json.dumps(nb), encoding="utf-8")
- return str(p)
-
-
-class TestNotebookEdit:
-
- @pytest.fixture()
- def tool(self, tmp_path):
- return NotebookEditTool(workspace=tmp_path)
-
- @pytest.mark.asyncio
- async def test_replace_cell_content(self, tool, tmp_path):
- nb = _make_notebook([_code_cell("print('hello')"), _code_cell("x = 1")])
- path = _write_nb(tmp_path, "test.ipynb", nb)
- result = await tool.execute(path=path, cell_index=0, new_source="print('world')")
- assert "Successfully" in result
- saved = json.loads((tmp_path / "test.ipynb").read_text())
- assert saved["cells"][0]["source"] == "print('world')"
- assert saved["cells"][1]["source"] == "x = 1"
-
- @pytest.mark.asyncio
- async def test_insert_cell_after_target(self, tool, tmp_path):
- nb = _make_notebook([_code_cell("cell 0"), _code_cell("cell 1")])
- path = _write_nb(tmp_path, "test.ipynb", nb)
- result = await tool.execute(path=path, cell_index=0, new_source="inserted", edit_mode="insert")
- assert "Successfully" in result
- saved = json.loads((tmp_path / "test.ipynb").read_text())
- assert len(saved["cells"]) == 3
- assert saved["cells"][0]["source"] == "cell 0"
- assert saved["cells"][1]["source"] == "inserted"
- assert saved["cells"][2]["source"] == "cell 1"
-
- @pytest.mark.asyncio
- async def test_delete_cell(self, tool, tmp_path):
- nb = _make_notebook([_code_cell("A"), _code_cell("B"), _code_cell("C")])
- path = _write_nb(tmp_path, "test.ipynb", nb)
- result = await tool.execute(path=path, cell_index=1, edit_mode="delete")
- assert "Successfully" in result
- saved = json.loads((tmp_path / "test.ipynb").read_text())
- assert len(saved["cells"]) == 2
- assert saved["cells"][0]["source"] == "A"
- assert saved["cells"][1]["source"] == "C"
-
- @pytest.mark.asyncio
- async def test_create_new_notebook_from_scratch(self, tool, tmp_path):
- path = str(tmp_path / "new.ipynb")
- result = await tool.execute(path=path, cell_index=0, new_source="# Hello", edit_mode="insert", cell_type="markdown")
- assert "Successfully" in result or "created" in result.lower()
- saved = json.loads((tmp_path / "new.ipynb").read_text())
- assert saved["nbformat"] == 4
- assert len(saved["cells"]) == 1
- assert saved["cells"][0]["cell_type"] == "markdown"
- assert saved["cells"][0]["source"] == "# Hello"
-
- @pytest.mark.asyncio
- async def test_invalid_cell_index_error(self, tool, tmp_path):
- nb = _make_notebook([_code_cell("only cell")])
- path = _write_nb(tmp_path, "test.ipynb", nb)
- result = await tool.execute(path=path, cell_index=5, new_source="x")
- assert "Error" in result
-
- @pytest.mark.asyncio
- async def test_non_ipynb_rejected(self, tool, tmp_path):
- f = tmp_path / "script.py"
- f.write_text("pass")
- result = await tool.execute(path=str(f), cell_index=0, new_source="x")
- assert "Error" in result
- assert ".ipynb" in result
-
- @pytest.mark.asyncio
- async def test_preserves_metadata_and_outputs(self, tool, tmp_path):
- cell = _code_cell("old")
- cell["outputs"] = [{"output_type": "stream", "text": "hello\n"}]
- cell["execution_count"] = 42
- nb = _make_notebook([cell])
- path = _write_nb(tmp_path, "test.ipynb", nb)
- await tool.execute(path=path, cell_index=0, new_source="new")
- saved = json.loads((tmp_path / "test.ipynb").read_text())
- assert saved["metadata"]["kernelspec"]["language"] == "python"
-
- @pytest.mark.asyncio
- async def test_nbformat_45_generates_cell_id(self, tool, tmp_path):
- nb = _make_notebook([], nbformat_minor=5)
- path = _write_nb(tmp_path, "test.ipynb", nb)
- await tool.execute(path=path, cell_index=0, new_source="x = 1", edit_mode="insert")
- saved = json.loads((tmp_path / "test.ipynb").read_text())
- assert "id" in saved["cells"][0]
- assert len(saved["cells"][0]["id"]) > 0
-
- @pytest.mark.asyncio
- async def test_insert_with_cell_type_markdown(self, tool, tmp_path):
- nb = _make_notebook([_code_cell("code")])
- path = _write_nb(tmp_path, "test.ipynb", nb)
- await tool.execute(path=path, cell_index=0, new_source="# Title", edit_mode="insert", cell_type="markdown")
- saved = json.loads((tmp_path / "test.ipynb").read_text())
- assert saved["cells"][1]["cell_type"] == "markdown"
-
- @pytest.mark.asyncio
- async def test_invalid_edit_mode_rejected(self, tool, tmp_path):
- nb = _make_notebook([_code_cell("code")])
- path = _write_nb(tmp_path, "test.ipynb", nb)
- result = await tool.execute(path=path, cell_index=0, new_source="x", edit_mode="replcae")
- assert "Error" in result
- assert "edit_mode" in result
-
- @pytest.mark.asyncio
- async def test_invalid_cell_type_rejected(self, tool, tmp_path):
- nb = _make_notebook([_code_cell("code")])
- path = _write_nb(tmp_path, "test.ipynb", nb)
- result = await tool.execute(path=path, cell_index=0, new_source="x", cell_type="raw")
- assert "Error" in result
- assert "cell_type" in result
diff --git a/tests/tools/test_search_tools.py b/tests/tools/test_search_tools.py
index 0d3697044..fc7c1944a 100644
--- a/tests/tools/test_search_tools.py
+++ b/tests/tools/test_search_tools.py
@@ -12,7 +12,7 @@ import pytest
from nanobot.agent.loop import AgentLoop
from nanobot.agent.subagent import SubagentManager, SubagentStatus
-from nanobot.agent.tools.search import GrepTool
+from nanobot.agent.tools.search import FindFilesTool, GrepTool
from nanobot.agent.tools.web import WebSearchTool
from nanobot.bus.queue import MessageBus
from nanobot.config.schema import WebSearchConfig
@@ -33,6 +33,68 @@ async def test_web_search_tool_refreshes_dynamic_config_loader(monkeypatch) -> N
assert await tool.execute("nanobot") == "duckduckgo:nanobot:3"
+@pytest.mark.asyncio
+async def test_find_files_filters_by_query_glob_and_type(tmp_path: Path) -> None:
+ (tmp_path / "src").mkdir()
+ (tmp_path / "src" / "settings_view.tsx").write_text("export {}\n", encoding="utf-8")
+ (tmp_path / "src" / "settings_api.py").write_text("pass\n", encoding="utf-8")
+ (tmp_path / "README.md").write_text("settings\n", encoding="utf-8")
+
+ tool = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path)
+ result = await tool.execute(
+ path=".",
+ query="settings",
+ glob="src/**",
+ type="ts",
+ )
+
+ assert result.splitlines() == ["src/settings_view.tsx"]
+
+
+@pytest.mark.asyncio
+async def test_find_files_can_include_directories(tmp_path: Path) -> None:
+ (tmp_path / "src" / "settings").mkdir(parents=True)
+ (tmp_path / "src" / "settings" / "index.ts").write_text("export {}\n", encoding="utf-8")
+
+ tool = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path)
+ result = await tool.execute(path="src", query="settings", include_dirs=True)
+
+ assert "src/settings/" in result.splitlines()
+ assert "src/settings/index.ts" in result.splitlines()
+
+
+@pytest.mark.asyncio
+async def test_find_files_supports_modified_sort_and_pagination(tmp_path: Path) -> None:
+ (tmp_path / "src").mkdir()
+ for idx, name in enumerate(("a.py", "b.py", "c.py"), start=1):
+ file_path = tmp_path / "src" / name
+ file_path.write_text("pass\n", encoding="utf-8")
+ os.utime(file_path, (idx, idx))
+
+ tool = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path)
+ result = await tool.execute(
+ path="src",
+ type="py",
+ sort="modified",
+ head_limit=1,
+ offset=1,
+ )
+
+ assert result.splitlines()[0] == "src/b.py"
+ assert "pagination: limit=1, offset=1" in result
+
+
+@pytest.mark.asyncio
+async def test_find_files_rejects_paths_outside_workspace(tmp_path: Path) -> None:
+ outside = tmp_path.parent / "outside-find-files.txt"
+ outside.write_text("secret\n", encoding="utf-8")
+
+ tool = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path)
+ result = await tool.execute(path=str(outside))
+
+ assert result.startswith("Error:")
+
+
@pytest.mark.asyncio
async def test_grep_respects_glob_filter_and_context(tmp_path: Path) -> None:
(tmp_path / "src").mkdir()
@@ -249,6 +311,7 @@ def test_agent_loop_registers_grep(tmp_path: Path) -> None:
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
+ assert "find_files" in loop.tools.tool_names
assert "grep" in loop.tools.tool_names
@@ -280,6 +343,7 @@ async def test_subagent_registers_grep(tmp_path: Path) -> None:
status = SubagentStatus(task_id="sub-1", label="label", task_description="search task", started_at=time.monotonic())
await mgr._run_subagent("sub-1", "search task", "label", {"channel": "cli", "chat_id": "direct"}, status)
+ assert "find_files" in captured["tool_names"]
assert "grep" in captured["tool_names"]
diff --git a/tests/tools/test_tool_descriptions.py b/tests/tools/test_tool_descriptions.py
new file mode 100644
index 000000000..bb7665e4e
--- /dev/null
+++ b/tests/tools/test_tool_descriptions.py
@@ -0,0 +1,46 @@
+from nanobot.agent.tools.apply_patch import ApplyPatchTool
+from nanobot.agent.tools.exec_session import ListExecSessionsTool, WriteStdinTool
+from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, WriteFileTool
+from nanobot.agent.tools.search import FindFilesTool, GrepTool
+from nanobot.agent.tools.shell import ExecTool
+
+
+def test_coding_tool_descriptions_steer_editing_priority() -> None:
+ apply_patch = ApplyPatchTool().description.lower()
+ edit_file = EditFileTool().description.lower()
+ write_file = WriteFileTool().description.lower()
+
+ assert "default tool for code edits" in apply_patch
+ assert "multi-file" in apply_patch
+ assert "dry_run=true" in apply_patch
+ assert "edit_file only for small exact replacements" in apply_patch
+
+ assert "small, exact replacement" in edit_file
+ assert "copied from read_file" in edit_file
+ assert "prefer apply_patch" in edit_file
+
+ assert "replace an entire file" in write_file
+ assert "prefer apply_patch" in write_file
+
+
+def test_coding_tool_descriptions_steer_discovery_and_shell_usage() -> None:
+ read_file = ReadFileTool().description.lower()
+ find_files = FindFilesTool().description.lower()
+ grep = GrepTool().description.lower()
+ exec_tool = ExecTool().description.lower()
+ write_stdin = WriteStdinTool().description.lower()
+ list_sessions = ListExecSessionsTool().description.lower()
+
+ assert "find_files/list_dir first" in read_file
+ assert "before editing" in read_file
+ assert "prefer it over shell find/ls" in find_files
+ assert "prefer this over shell grep" in grep
+
+ assert "tests, builds" in exec_tool
+ assert "prefer read_file/find_files/grep" in exec_tool
+ assert "apply_patch/write_file/edit_file" in exec_tool
+ assert "yield_time_ms" in exec_tool
+
+ assert "do not use this to start new commands" in write_stdin
+ assert "wait_for" in write_stdin
+ assert "recover a session_id" in list_sessions
diff --git a/tests/tools/test_tool_loader.py b/tests/tools/test_tool_loader.py
index 54b4d92d5..62703883c 100644
--- a/tests/tools/test_tool_loader.py
+++ b/tests/tools/test_tool_loader.py
@@ -89,9 +89,11 @@ def test_discover_finds_concrete_tools():
loader = ToolLoader()
discovered = loader.discover()
class_names = {cls.__name__ for cls in discovered}
+ assert "ApplyPatchTool" in class_names
assert "ExecTool" in class_names
assert "MessageTool" in class_names
assert "SpawnTool" in class_names
+ assert "WriteStdinTool" in class_names
def test_discover_excludes_abstract_and_mcp():
@@ -406,7 +408,8 @@ def test_loader_registers_same_tools_as_old_hardcoded():
expected = {
"read_file", "write_file", "edit_file", "list_dir",
- "grep", "notebook_edit", "exec", "web_search", "web_fetch",
+ "find_files", "grep", "exec", "write_stdin", "list_exec_sessions",
+ "web_search", "web_fetch",
"message", "spawn", "cron",
}
actual = set(registered)
diff --git a/tests/tools/test_tool_validation.py b/tests/tools/test_tool_validation.py
index 42620dcc6..188a8952f 100644
--- a/tests/tools/test_tool_validation.py
+++ b/tests/tools/test_tool_validation.py
@@ -3,6 +3,8 @@ import subprocess
import sys
from typing import Any
+import pytest
+
from nanobot.agent.tools import (
ArraySchema,
IntegerSchema,
@@ -15,6 +17,7 @@ from nanobot.agent.tools import (
from nanobot.agent.tools.base import Tool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.shell import ExecTool
+from nanobot.security.network import configure_ssrf_whitelist
class SampleTool(Tool):
@@ -218,6 +221,39 @@ def test_exec_extract_absolute_paths_ignores_relative_posix_segments() -> None:
assert "/bin/python" not in paths
+def test_exec_extract_absolute_paths_ignores_urls() -> None:
+ cmd = 'curl -s -o /dev/null -w "%{http_code}" https://www.google.com'
+ paths = ExecTool._extract_absolute_paths(cmd)
+ assert paths == ["/dev/null"]
+
+
+@pytest.mark.parametrize(
+ "command",
+ [
+ 'curl -s -o /dev/null -w "%{http_code}" https://www.google.com',
+ 'wget -q -O - http://example.com 2>&1 | head -c 100',
+ 'python3 -c "import urllib.request; print(urllib.request.urlopen(\'http://example.com\').read()[:100])"',
+ ],
+)
+def test_exec_guard_allows_public_urls(tmp_path, command: str) -> None:
+ tool = ExecTool(restrict_to_workspace=True)
+ error = tool._guard_command(command, str(tmp_path))
+ assert error is None
+
+
+def test_exec_guard_allows_whitelisted_internal_urls(tmp_path) -> None:
+ configure_ssrf_whitelist(["10.10.10.0/24"])
+ try:
+ tool = ExecTool(restrict_to_workspace=True)
+ error = tool._guard_command(
+ 'curl -s -H "Authorization: Bearer ..." http://10.10.10.3:8123/api/',
+ str(tmp_path),
+ )
+ assert error is None
+ finally:
+ configure_ssrf_whitelist([])
+
+
def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None:
cmd = "cat /tmp/data.txt > /tmp/out.txt"
paths = ExecTool._extract_absolute_paths(cmd)
diff --git a/tests/utils/test_file_edit_events.py b/tests/utils/test_file_edit_events.py
index cdaae5167..fe035b41b 100644
--- a/tests/utils/test_file_edit_events.py
+++ b/tests/utils/test_file_edit_events.py
@@ -5,12 +5,13 @@ from pathlib import Path
from types import SimpleNamespace
from nanobot.utils.file_edit_events import (
+ StreamingFileEditTracker,
build_file_edit_end_event,
build_file_edit_start_event,
line_diff_stats,
prepare_file_edit_tracker,
+ prepare_file_edit_trackers,
read_file_snapshot,
- StreamingFileEditTracker,
)
@@ -81,6 +82,63 @@ def test_binary_file_is_reported_but_not_counted(tmp_path: Path) -> None:
assert (event["added"], event["deleted"]) == (0, 0)
+def test_apply_patch_prepares_trackers_for_each_touched_file(tmp_path: Path) -> None:
+ (tmp_path / "src").mkdir()
+ existing = tmp_path / "src" / "existing.py"
+ existing.write_text("old\nkeep\n", encoding="utf-8")
+ delete_me = tmp_path / "src" / "delete_me.py"
+ delete_me.write_text("gone\n", encoding="utf-8")
+
+ edits = [
+ {"path": "src/new.py", "action": "add", "new_text": "fresh"},
+ {"path": "src/existing.py", "action": "replace", "old_text": "old", "new_text": "new"},
+ {"path": "src/delete_me.py", "action": "delete", "old_text": "gone\n"},
+ ]
+
+ trackers = prepare_file_edit_trackers(
+ call_id="call-patch",
+ tool_name="apply_patch",
+ tool=None,
+ workspace=tmp_path,
+ params={"edits": edits},
+ )
+
+ assert [tracker.display_path for tracker in trackers] == [
+ "src/new.py",
+ "src/existing.py",
+ "src/delete_me.py",
+ ]
+
+ (tmp_path / "src" / "new.py").write_text("fresh\n", encoding="utf-8")
+ existing.write_text("new\nkeep\n", encoding="utf-8")
+ delete_me.unlink()
+
+ events = [build_file_edit_end_event(tracker, {"edits": edits}) for tracker in trackers]
+ by_path = {event["path"]: event for event in events}
+ assert (by_path["src/new.py"]["added"], by_path["src/new.py"]["deleted"]) == (1, 0)
+ assert (by_path["src/existing.py"]["added"], by_path["src/existing.py"]["deleted"]) == (1, 1)
+ assert (by_path["src/delete_me.py"]["added"], by_path["src/delete_me.py"]["deleted"]) == (0, 1)
+
+
+def test_apply_patch_dry_run_does_not_prepare_file_edit_trackers(tmp_path: Path) -> None:
+ (tmp_path / "file.txt").write_text("old\n", encoding="utf-8")
+
+ trackers = prepare_file_edit_trackers(
+ call_id="call-patch",
+ tool_name="apply_patch",
+ tool=None,
+ workspace=tmp_path,
+ params={
+ "dry_run": True,
+ "edits": [
+ {"path": "file.txt", "action": "replace", "old_text": "old", "new_text": "new"}
+ ],
+ },
+ )
+
+ assert trackers == []
+
+
def test_oversized_write_file_end_uses_known_content_for_exact_count(tmp_path: Path) -> None:
target = tmp_path / "large.txt"
params = {"path": "large.txt", "content": "x" * (2 * 1024 * 1024 + 1)}
@@ -140,6 +198,58 @@ def test_streaming_write_file_tracker_emits_live_line_counts(tmp_path: Path) ->
assert events[-1]["deleted"] == 0
+def test_streaming_apply_patch_tracker_emits_live_counts_per_file(tmp_path: Path) -> None:
+ (tmp_path / "src").mkdir()
+ (tmp_path / "src" / "existing.py").write_text("old\nkeep\n", encoding="utf-8")
+ events: list[dict] = []
+
+ async def emit(batch: list[dict]) -> None:
+ events.extend(batch)
+
+ async def run() -> None:
+ tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit)
+ await tracker.update({
+ "index": 0,
+ "call_id": "call-patch",
+ "name": "apply_patch",
+ "arguments_delta": (
+ '{"edits":[{"path":"src/existing.py","action":"replace","old_text":"old","new_text":"new"}'
+ ',{"path":"src/new.py","action":"add","new_text":"fresh"}]}'
+ ),
+ })
+
+ asyncio.run(run())
+
+ by_path = {event["path"]: event for event in events}
+ assert by_path["src/existing.py"]["tool"] == "apply_patch"
+ assert by_path["src/existing.py"]["status"] == "editing"
+ assert by_path["src/existing.py"]["approximate"] is True
+ assert (by_path["src/existing.py"]["added"], by_path["src/existing.py"]["deleted"]) == (1, 1)
+ assert (by_path["src/new.py"]["added"], by_path["src/new.py"]["deleted"]) == (1, 0)
+
+
+def test_streaming_apply_patch_tracker_skips_dry_run(tmp_path: Path) -> None:
+ events: list[dict] = []
+
+ async def emit(batch: list[dict]) -> None:
+ events.extend(batch)
+
+ async def run() -> None:
+ tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit)
+ await tracker.update({
+ "index": 0,
+ "call_id": "call-patch",
+ "name": "apply_patch",
+ "arguments_delta": (
+ '{"dry_run":true,"edits":[{"path":"dry.md","action":"add","new_text":"preview"}]}'
+ ),
+ })
+
+ asyncio.run(run())
+
+ assert events == []
+
+
def test_streaming_write_file_tracker_emits_pending_before_path(tmp_path: Path) -> None:
events: list[dict] = []
@@ -308,6 +418,43 @@ def test_streaming_tracker_applies_canonical_call_id_to_final_tool(tmp_path: Pat
asyncio.run(run())
+def test_streaming_tracker_does_not_restore_duplicate_canonical_ids(tmp_path: Path) -> None:
+ events: list[dict] = []
+
+ async def emit(batch: list[dict]) -> None:
+ events.extend(batch)
+
+ async def run() -> None:
+ tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit)
+ await tracker.update({
+ "index": 0,
+ "call_id": "call_dup",
+ "name": "write_file",
+ "arguments_delta": '{"path":"a.md","content":"one\\n"}',
+ })
+ await tracker.update({
+ "index": 1,
+ "call_id": "call_dup",
+ "name": "write_file",
+ "arguments_delta": '{"path":"b.md","content":"two\\n"}',
+ })
+ final_a = SimpleNamespace(
+ id="call_dup",
+ name="write_file",
+ arguments={"path": "a.md", "content": "one\n"},
+ )
+ final_b = SimpleNamespace(
+ id="call_unique",
+ name="write_file",
+ arguments={"path": "b.md", "content": "two\n"},
+ )
+ tracker.apply_final_call_ids([final_a, final_b])
+ assert final_a.id == "call_dup"
+ assert final_b.id == "call_unique"
+
+ asyncio.run(run())
+
+
def test_streaming_edit_file_tracker_flushes_small_pending_count(tmp_path: Path) -> None:
target = tmp_path / "small.py"
target.write_text("old\n", encoding="utf-8")
diff --git a/webui/src/App.tsx b/webui/src/App.tsx
index c303446e2..8c6127829 100644
--- a/webui/src/App.tsx
+++ b/webui/src/App.tsx
@@ -43,6 +43,7 @@ const SIDEBAR_STORAGE_KEY = "nanobot-webui.sidebar";
const COMPLETED_RUNS_STORAGE_KEY = "nanobot-webui.sidebar.completed-runs.v1";
const RESTART_STARTED_KEY = "nanobot-webui.restartStartedAt";
const SIDEBAR_WIDTH = 272;
+const SIDEBAR_RAIL_WIDTH = 56;
const TOKEN_REFRESH_MARGIN_MS = 30_000;
const TOKEN_REFRESH_MIN_DELAY_MS = 5_000;
type ShellView = "chat" | "settings";
@@ -411,6 +412,10 @@ function Shell({
setDesktopSidebarOpen(false);
}, []);
+ const openDesktopSidebar = useCallback(() => {
+ setDesktopSidebarOpen(true);
+ }, []);
+
const closeMobileSidebar = useCallback(() => {
setMobileSidebarOpen(false);
}, []);
@@ -560,6 +565,21 @@ function Shell({
setSessionSearchOpen(true);
}, []);
+ useEffect(() => {
+ const handleKeyDown = (event: globalThis.KeyboardEvent) => {
+ if (event.defaultPrevented) return;
+ const plainCommandK =
+ (event.metaKey || event.ctrlKey) && !event.altKey && !event.shiftKey;
+ if (!plainCommandK) return;
+ if (event.key.toLowerCase() !== "k") return;
+ event.preventDefault();
+ onOpenSessionSearch();
+ };
+
+ window.addEventListener("keydown", handleKeyDown);
+ return () => window.removeEventListener("keydown", handleKeyDown);
+ }, [onOpenSessionSearch]);
+
const onSelectSearchResult = useCallback(
(key: string) => {
setSessionSearchOpen(false);
@@ -732,17 +752,19 @@ function Shell({
"relative z-20 hidden shrink-0 overflow-hidden lg:block",
"transition-[width] duration-300 ease-out",
)}
- style={{ width: desktopSidebarOpen ? SIDEBAR_WIDTH : 0 }}
+ style={{
+ width: desktopSidebarOpen ? SIDEBAR_WIDTH : SIDEBAR_RAIL_WIDTH,
+ }}
>
-
+
) : null}
@@ -769,17 +791,15 @@ function Shell({
) : null}
- {showMainSidebar ? (
-
- ) : null}
+
{view === "settings" && (
diff --git a/webui/src/components/ChatList.tsx b/webui/src/components/ChatList.tsx
index 705039aea..d098a5972 100644
--- a/webui/src/components/ChatList.tsx
+++ b/webui/src/components/ChatList.tsx
@@ -1,3 +1,9 @@
+import {
+ memo,
+ useEffect,
+ useMemo,
+ useState,
+} from "react";
import {
Archive,
ArchiveRestore,
@@ -19,6 +25,9 @@ import { deriveTitle, relativeTime } from "@/lib/format";
import { cn } from "@/lib/utils";
import type { ChatSummary, SidebarDensity, SidebarSortMode } from "@/lib/types";
+const INITIAL_VISIBLE_SESSIONS = 160;
+const VISIBLE_SESSIONS_INCREMENT = 160;
+
interface ChatListProps {
sessions: ChatSummary[];
activeKey: string | null;
@@ -42,7 +51,7 @@ interface ChatListProps {
emptyLabel?: string;
}
-export function ChatList({
+export const ChatList = memo(function ChatList({
sessions,
activeKey,
onSelect,
@@ -65,6 +74,52 @@ export function ChatList({
emptyLabel,
}: ChatListProps) {
const { t } = useTranslation();
+ const [visibleLimit, setVisibleLimit] = useState(INITIAL_VISIBLE_SESSIONS);
+ const labels = useMemo(() => ({
+ pinned: t("chat.groups.pinned"),
+ all: t("chat.groups.all"),
+ today: t("chat.groups.today"),
+ yesterday: t("chat.groups.yesterday"),
+ earlier: t("chat.groups.earlier"),
+ archived: t("chat.groups.archived"),
+ fallbackTitle: t("chat.newChat"),
+ }), [t]);
+ const groups = useMemo(
+ () => groupSessions(sessions, labels, {
+ pinnedKeys,
+ archivedKeys,
+ titleOverrides,
+ showArchived,
+ sort,
+ }),
+ [
+ archivedKeys,
+ labels,
+ pinnedKeys,
+ sessions,
+ showArchived,
+ sort,
+ titleOverrides,
+ ],
+ );
+ const limitedGroups = useMemo(
+ () => limitGroups(groups, visibleLimit, activeKey),
+ [activeKey, groups, visibleLimit],
+ );
+ const totalSessionCount = useMemo(
+ () => groups.reduce((total, group) => total + group.sessions.length, 0),
+ [groups],
+ );
+ const visibleSessionCount = useMemo(
+ () => limitedGroups.reduce((total, group) => total + group.sessions.length, 0),
+ [limitedGroups],
+ );
+ const hiddenSessionCount = Math.max(0, totalSessionCount - visibleSessionCount);
+
+ useEffect(() => {
+ setVisibleLimit(INITIAL_VISIBLE_SESSIONS);
+ }, [showArchived, sort]);
+
if (loading && sessions.length === 0) {
return (