diff --git a/README.md b/README.md index b5e4b02c0..1dbc82db8 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,18 @@ ![cover-v5-optimized](./images/GitHub_README.png)
+

+ English | + 简体中文 | + 繁體中文 | + Español | + Français | + Bahasa Indonesia | + 日本語 | + 한국어 | + Русский | + Tiếng Việt +

PyPI Downloads @@ -61,7 +73,7 @@ - **2026-04-13** 🛡️ Agent turn hardened — user messages persisted early, auto-compact skips active tasks. - **2026-04-12** 🔒 Lark global domain support, Dream learns discovered skills, shell sandbox tightened. - **2026-04-11** ⚡ Context compact shrinks sessions on the fly; Kagi web search; QQ & WeCom full media. -- **2026-04-10** 📓 Notebook editing tool, multiple MCP servers, Feishu streaming & done-emoji. +- **2026-04-10** 📓 Multiple MCP servers, Feishu streaming & done-emoji. - **2026-04-09** 🔌 WebSocket channel, unified cross-channel session, `disabled_skills` config. - **2026-04-08** 📤 API file uploads, OpenAI reasoning auto-routing with Responses fallback. - **2026-04-07** 🧠 Anthropic adaptive thinking, MCP resources & prompts exposed as tools. diff --git a/docs/chat-apps.md b/docs/chat-apps.md index c0c1b4ba0..88242a5f7 100644 --- a/docs/chat-apps.md +++ b/docs/chat-apps.md @@ -17,6 +17,7 @@ Connect nanobot to your favorite chat platform. Want to build your own? See the | **Wecom** | Bot ID + Bot Secret | | **Microsoft Teams** | App ID + App Password + public HTTPS endpoint | | **Mochat** | Claw token (auto-setup available) | +| **Signal** | signal-cli daemon + phone number |

Telegram (Recommended) @@ -669,3 +670,69 @@ nanobot gateway ```
+ +
+Signal + +Uses **signal-cli** daemon in HTTP mode — receive messages via SSE, send via JSON-RPC. + +**1. Install signal-cli** + +Install [signal-cli](https://github.com/AsamK/signal-cli) and register a phone number: + +```bash +signal-cli -u +1234567890 register +signal-cli -u +1234567890 verify +``` + +Start the daemon: + +```bash +signal-cli -a +1234567890 daemon --http localhost:8080 +``` + +**2. Configure** + +```json +{ + "channels": { + "signal": { + "enabled": true, + "phoneNumber": "+1234567890", + "daemonHost": "localhost", + "daemonPort": 8080, + "dm": { + "enabled": true, + "policy": "open" + }, + "group": { + "enabled": true, + "policy": "open", + "requireMention": true + } + } + } +} +``` + +> - `phoneNumber`: Your registered Signal phone number. +> - `daemonHost` / `daemonPort`: Where signal-cli daemon is listening (default `localhost:8080`). +> - `dm.policy`: `"open"` (anyone can DM) or `"allowlist"` (only listed numbers/UUIDs). When `"allowlist"`, unlisted DM senders receive a pairing code. +> - `dm.allowFrom`: List of allowed phone numbers or UUIDs (used when policy is `"allowlist"`). +> - `group.policy`: `"open"` (all groups) or `"allowlist"` (only listed group IDs). +> - `group.requireMention`: When `true` (default), the bot only responds in groups when @mentioned. +> - `group.allowFrom`: List of allowed group IDs (used when group policy is `"allowlist"`). +> - `attachmentsDir`: Override the directory where signal-cli stores inbound attachments. Defaults to `~/.local/share/signal-cli/attachments` (the Linux default). Set this if signal-cli runs with a custom `XDG_DATA_HOME` or on macOS/Windows. +> - `groupMessageBufferSize`: Number of recent group messages kept for context (default `20`, must be > 0). + +**3. Run** + +```bash +nanobot gateway +``` + +> [!TIP] +> The channel automatically reconnects to the signal-cli daemon with exponential backoff if the connection drops. +> Markdown in bot replies is automatically converted to Signal text styles (bold, italic, code, etc.). + +
diff --git a/docs/configuration.md b/docs/configuration.md index dbd5e2626..e4fbe83eb 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -148,6 +148,7 @@ ANTHROPIC_API_KEY="$(bw get password api/anthropic)" nanobot agent | `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) | | `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) | | `siliconflow` | LLM (SiliconFlow/硅基流动) | [siliconflow.cn](https://siliconflow.cn) | +| `novita` | LLM (Novita AI OpenAI-compatible gateway) | [novita.ai](https://novita.ai) | | `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) | | `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) | | `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) | diff --git a/nanobot/agent/context.py b/nanobot/agent/context.py index 19ee935c4..82ebfab65 100644 --- a/nanobot/agent/context.py +++ b/nanobot/agent/context.py @@ -22,7 +22,7 @@ from nanobot.utils.prompt_templates import render_template class ContextBuilder: """Builds the context (system prompt + messages) for the agent.""" - BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md"] + BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md"] _RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]" _MAX_RECENT_HISTORY = 50 _MAX_HISTORY_CHARS = 32_000 # hard cap on recent history section size @@ -47,6 +47,8 @@ class ContextBuilder: if bootstrap: parts.append(bootstrap) + parts.append(render_template("agent/tool_contract.md")) + memory = self.memory.get_memory_context() if memory and not self._is_template_content(self.memory.read_memory(), "memory/MEMORY.md"): parts.append(f"# Memory\n\n{memory}") @@ -210,4 +212,3 @@ class ContextBuilder: if not images: return text return images + [{"type": "text", "text": text}] - diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 0b0164fd0..19494034f 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -19,7 +19,8 @@ from nanobot.utils.file_edit_events import ( build_file_edit_end_event, build_file_edit_error_event, build_file_edit_start_event, - prepare_file_edit_tracker, + prepare_file_edit_tracker as _prepare_file_edit_tracker, + prepare_file_edit_trackers, StreamingFileEditTracker, ) from nanobot.utils.helpers import ( @@ -58,11 +59,14 @@ _SNIP_SAFETY_BUFFER = 1024 _MICROCOMPACT_KEEP_RECENT = 10 _MICROCOMPACT_MIN_CHARS = 500 _COMPACTABLE_TOOLS = frozenset({ - "read_file", "exec", "grep", - "web_search", "web_fetch", "list_dir", + "read_file", "exec", "grep", "find_files", + "web_search", "web_fetch", "list_dir", "list_exec_sessions", }) _BACKFILL_CONTENT = "[Tool result unavailable — call was interrupted or lost]" +# Backward-compatible module attribute for tests/extensions that monkeypatch +# the former single-file tracker hook. Runtime uses prepare_file_edit_trackers. +prepare_file_edit_tracker = _prepare_file_edit_tracker @dataclass(slots=True) @@ -857,8 +861,8 @@ class AgentRunner: and on_progress_accepts_file_edit_events(spec.progress_callback) ) progress_callback = spec.progress_callback if emit_file_edit_events else None - file_edit_tracker = ( - prepare_file_edit_tracker( + file_edit_trackers = ( + prepare_file_edit_trackers( call_id=tool_call.id, tool_name=tool_call.name, tool=tool, @@ -868,13 +872,13 @@ class AgentRunner: if progress_callback is not None else None ) - if file_edit_tracker is not None and progress_callback is not None: + if file_edit_trackers and progress_callback is not None: await invoke_file_edit_progress( progress_callback, [build_file_edit_start_event( file_edit_tracker, params if isinstance(params, dict) else None, - )], + ) for file_edit_tracker in file_edit_trackers], ) try: if tool is not None: @@ -884,10 +888,13 @@ class AgentRunner: except asyncio.CancelledError: raise except BaseException as exc: - if file_edit_tracker is not None and progress_callback is not None: + if file_edit_trackers and progress_callback is not None: await invoke_file_edit_progress( progress_callback, - [build_file_edit_error_event(file_edit_tracker, str(exc))], + [ + build_file_edit_error_event(file_edit_tracker, str(exc)) + for file_edit_tracker in file_edit_trackers + ], ) event = { "name": tool_call.name, @@ -910,10 +917,13 @@ class AgentRunner: return payload, event, None if isinstance(result, str) and result.startswith("Error"): - if file_edit_tracker is not None and progress_callback is not None: + if file_edit_trackers and progress_callback is not None: await invoke_file_edit_progress( progress_callback, - [build_file_edit_error_event(file_edit_tracker, result)], + [ + build_file_edit_error_event(file_edit_tracker, result) + for file_edit_tracker in file_edit_trackers + ], ) event = { "name": tool_call.name, @@ -933,13 +943,13 @@ class AgentRunner: return result + hint, event, RuntimeError(result) return result + hint, event, None - if file_edit_tracker is not None and progress_callback is not None: + if file_edit_trackers and progress_callback is not None: await invoke_file_edit_progress( progress_callback, [build_file_edit_end_event( file_edit_tracker, params if isinstance(params, dict) else None, - )], + ) for file_edit_tracker in file_edit_trackers], ) detail = "" if result is None else str(result) diff --git a/nanobot/agent/tools/apply_patch.py b/nanobot/agent/tools/apply_patch.py new file mode 100644 index 000000000..ac524f7fc --- /dev/null +++ b/nanobot/agent/tools/apply_patch.py @@ -0,0 +1,352 @@ +"""Apply file edits by providing structured edit instructions.""" + +from __future__ import annotations + +import difflib +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from nanobot.agent.tools.base import tool_parameters +from nanobot.agent.tools.filesystem import _FsTool +from nanobot.agent.tools.schema import ( + ArraySchema, + BooleanSchema, + ObjectSchema, + StringSchema, + tool_parameters_schema, +) + + +@dataclass(slots=True) +class _PatchSummary: + action: str + path: str + added: int = 0 + deleted: int = 0 + + +class _PatchError(ValueError): + pass + + +_ABSOLUTE_WINDOWS_RE = re.compile(r"^[A-Za-z]:[\\/]") + + +def _validate_relative_path(path: str) -> str: + normalized = path.strip() + if not normalized: + raise _PatchError("patch path cannot be empty") + if "\0" in normalized: + raise _PatchError(f"patch path contains a null byte: {path!r}") + if normalized.startswith(("~", "/", "\\")) or _ABSOLUTE_WINDOWS_RE.match(normalized): + raise _PatchError(f"patch path must be relative: {path}") + if any(part == ".." for part in re.split(r"[\\/]+", normalized)): + raise _PatchError(f"patch path must not contain '..': {path}") + return normalized + + +def _lines_to_text(lines: list[str]) -> str: + if not lines: + return "" + return "\n".join(lines) + "\n" + + +def _text_line_count(text: str) -> int: + if not text: + return 0 + return len(text.splitlines()) + + +def _line_diff_stats(before: str, after: str) -> tuple[int, int]: + before_lines = before.replace("\r\n", "\n").splitlines() + after_lines = after.replace("\r\n", "\n").splitlines() + added = 0 + deleted = 0 + matcher = difflib.SequenceMatcher(a=before_lines, b=after_lines, autojunk=False) + for tag, i1, i2, j1, j2 in matcher.get_opcodes(): + if tag == "equal": + continue + if tag in ("replace", "delete"): + deleted += i2 - i1 + if tag in ("replace", "insert"): + added += j2 - j1 + return added, deleted + + +def _format_summary(summary: _PatchSummary) -> str: + stats = "" + if summary.added or summary.deleted: + stats = f" (+{summary.added}/-{summary.deleted})" + return f"- {summary.action} {summary.path}{stats}" + + +@tool_parameters( + tool_parameters_schema( + edits=ArraySchema( + items=ObjectSchema( + path=StringSchema("Relative path to the file to edit."), + action=StringSchema( + "Operation type: replace (find and replace text), add (append new content or create file), delete (remove text).", + enum=["replace", "add", "delete"], + ), + old_text=StringSchema( + "Exact text to search for in the file. Required for replace and delete.", + nullable=True, + ), + new_text=StringSchema( + "Text to replace with or append. Required for replace and add.", + nullable=True, + ), + required=["path", "action"], + ), + description="List of edits to apply. Each edit specifies a file and the change to make.", + min_items=1, + max_items=20, + ), + dry_run=BooleanSchema( + description="Validate and summarize the patch without writing files.", + default=False, + ), + required=["edits"], + ) +) +class ApplyPatchTool(_FsTool): + """Apply file edits by providing structured edit instructions.""" + _scopes = {"core", "subagent"} + + @property + def name(self) -> str: + return "apply_patch" + + @property + def description(self) -> str: + return ( + "Default tool for code edits. Supports multi-file changes in a single call. " + "Provide a list of structured edits, each specifying a file path, action (replace/add/delete), and the text to change. " + "Paths must be relative. Set dry_run=true to validate and preview without writing files. " + "Use edit_file only for small exact replacements on a single file." + ) + + async def execute( + self, + edits: list[dict] | None = None, + dry_run: bool = False, + **kwargs: Any, + ) -> str: + try: + if not edits: + raise _PatchError("must provide edits") + + writes: dict[Path, str] = {} + deletes: set[Path] = set() + summaries: list[_PatchSummary] = [] + + for edit in edits: + if not isinstance(edit, dict): + raise _PatchError("each edit must be an object") + raw_path = edit.get("path") + if not isinstance(raw_path, str): + raise _PatchError("path required for edit") + path = _validate_relative_path(raw_path) + action = edit.get("action") + if not isinstance(action, str): + raise _PatchError(f"action required for edit: {path}") + source = self._resolve(path) + + if action == "add": + new_text = edit.get("new_text") + if new_text is None: + raise _PatchError(f"new_text required for add: {path}") + + pending = writes.get(source) + if pending is not None: + content = pending + exists = True + elif source.exists(): + raw = source.read_bytes() + try: + content = raw.decode("utf-8") + except UnicodeDecodeError: + raise _PatchError(f"file is not UTF-8 text: {path}") + exists = True + else: + content = "" + exists = False + + if exists: + uses_crlf = "\r\n" in content + new_norm = content.replace("\r\n", "\n") + new_text.replace("\r\n", "\n") + if new_norm and not new_norm.endswith("\n"): + new_norm += "\n" + if uses_crlf: + new_norm = new_norm.replace("\n", "\r\n") + writes[source] = new_norm + deletes.discard(source) + added, deleted = _line_diff_stats(content, new_norm) + action_name = "update" + else: + new_norm = new_text.replace("\r\n", "\n") + if new_norm and not new_norm.endswith("\n"): + new_norm += "\n" + writes[source] = new_norm + deletes.discard(source) + added = _text_line_count(new_norm) + deleted = 0 + action_name = "add" + + summaries.append( + _PatchSummary( + action=action_name, path=path, added=added, deleted=deleted + ) + ) + + elif action == "replace": + old_text = edit.get("old_text") or "" + if not old_text: + raise _PatchError(f"old_text required for replace: {path}") + new_text = edit.get("new_text") + if new_text is None: + raise _PatchError(f"new_text required for replace: {path}") + + pending = writes.get(source) + if pending is not None: + content = pending + elif source.exists(): + raw = source.read_bytes() + try: + content = raw.decode("utf-8") + except UnicodeDecodeError: + raise _PatchError(f"file is not UTF-8 text: {path}") + else: + raise _PatchError(f"file to update does not exist: {path}") + + if pending is None and not source.is_file(): + raise _PatchError(f"path to update is not a file: {path}") + + uses_crlf = "\r\n" in content + norm_content = content.replace("\r\n", "\n") + norm_old = old_text.replace("\r\n", "\n") + + pos = norm_content.find(norm_old) + if pos < 0: + raise _PatchError(f"old_text not found in {path}") + if norm_content.find(norm_old, pos + 1) >= 0: + raise _PatchError(f"old_text appears multiple times in {path}") + + new_norm = ( + norm_content[:pos] + + new_text.replace("\r\n", "\n") + + norm_content[pos + len(norm_old) :] + ) + if new_norm and not new_norm.endswith("\n"): + new_norm += "\n" + if uses_crlf: + new_norm = new_norm.replace("\n", "\r\n") + + writes[source] = new_norm + deletes.discard(source) + added, deleted = _line_diff_stats(content, new_norm) + summaries.append( + _PatchSummary( + action="update", path=path, added=added, deleted=deleted + ) + ) + + elif action == "delete": + old_text = edit.get("old_text") or "" + if not old_text: + raise _PatchError(f"old_text required for delete: {path}") + + pending = writes.get(source) + if pending is not None: + content = pending + elif source.exists(): + raw = source.read_bytes() + try: + content = raw.decode("utf-8") + except UnicodeDecodeError: + raise _PatchError(f"file is not UTF-8 text: {path}") + else: + raise _PatchError(f"file to update does not exist: {path}") + + if pending is None and not source.is_file(): + raise _PatchError(f"path to update is not a file: {path}") + + uses_crlf = "\r\n" in content + norm_content = content.replace("\r\n", "\n") + norm_old = old_text.replace("\r\n", "\n") + + pos = norm_content.find(norm_old) + if pos < 0: + raise _PatchError(f"old_text not found in {path}") + if norm_content.find(norm_old, pos + 1) >= 0: + raise _PatchError(f"old_text appears multiple times in {path}") + + if norm_old == norm_content: + deletes.add(source) + writes.pop(source, None) + added, deleted = 0, _text_line_count(content) + summaries.append( + _PatchSummary( + action="delete", path=path, added=added, deleted=deleted + ) + ) + else: + new_norm = ( + norm_content[:pos] + norm_content[pos + len(norm_old) :] + ) + if new_norm and not new_norm.endswith("\n"): + new_norm += "\n" + if uses_crlf: + new_norm = new_norm.replace("\n", "\r\n") + writes[source] = new_norm + deletes.discard(source) + added, deleted = _line_diff_stats(content, new_norm) + summaries.append( + _PatchSummary( + action="update", path=path, added=added, deleted=deleted + ) + ) + + else: + raise _PatchError(f"unknown action: {action}") + + if dry_run: + return "Patch dry-run succeeded:\n" + "\n".join( + _format_summary(summary) for summary in summaries + ) + + backups: dict[Path, bytes | None] = {} + for path in set(writes) | deletes: + backups[path] = path.read_bytes() if path.exists() else None + + try: + for path in deletes: + if path.exists(): + path.unlink() + for path, content in writes.items(): + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(content, encoding="utf-8", newline="") + except Exception: + for path, data in backups.items(): + if data is None: + if path.exists(): + path.unlink() + else: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(data) + raise + + for path in set(writes) | deletes: + self._file_states.record_write(path) + return "Patch applied:\n" + "\n".join( + _format_summary(summary) for summary in summaries + ) + except PermissionError as exc: + return f"Error: {exc}" + except _PatchError as exc: + return f"Error applying patch: {exc}" + except Exception as exc: + return f"Error applying patch: {exc}" diff --git a/nanobot/agent/tools/exec_session.py b/nanobot/agent/tools/exec_session.py new file mode 100644 index 000000000..4dadb2d36 --- /dev/null +++ b/nanobot/agent/tools/exec_session.py @@ -0,0 +1,591 @@ +"""Session support for long-running exec workflows.""" + +from __future__ import annotations + +import asyncio +import shutil +import time +import uuid +from contextlib import suppress +from dataclasses import dataclass +from typing import Any + +from nanobot.agent.tools.base import Tool, tool_parameters +from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema + + +DEFAULT_YIELD_MS = 1000 +MAX_YIELD_MS = 30_000 +DEFAULT_WAIT_FOR_MS = 10_000 +MAX_WAIT_FOR_MS = 120_000 +DEFAULT_MAX_OUTPUT_CHARS = 10_000 +MAX_OUTPUT_CHARS = 50_000 + + +@dataclass(slots=True) +class _SessionPoll: + output: str + done: bool + exit_code: int | None + elapsed_s: float = 0.0 + timed_out: bool = False + terminated: bool = False + stdin_closed: bool = False + truncated_chars: int = 0 + + +@dataclass(slots=True) +class ExecSessionInfo: + session_id: str + command: str + cwd: str + elapsed_s: float + idle_s: float + remaining_s: float + returncode: int | None + + +class _ExecSession: + def __init__( + self, + *, + session_id: str, + process: asyncio.subprocess.Process, + command: str, + cwd: str, + timeout: int, + ) -> None: + self.session_id = session_id + self.process = process + self.command = command + self.cwd = cwd + self.started_at = time.monotonic() + self.deadline = time.monotonic() + timeout + self.last_access = time.monotonic() + self._chunks: list[str] = [] + self._lock = asyncio.Lock() + self._timed_out = False + self._stdout_task = asyncio.create_task(self._read_stream(process.stdout, "")) + self._stderr_task = asyncio.create_task(self._read_stream(process.stderr, "STDERR:\n")) + + async def _read_stream( + self, + stream: asyncio.StreamReader | None, + prefix: str, + ) -> None: + if stream is None: + return + first = True + while True: + chunk = await stream.read(4096) + if not chunk: + break + text = chunk.decode("utf-8", errors="replace") + if prefix and first: + text = prefix + text + first = False + async with self._lock: + self._chunks.append(text) + + async def write(self, chars: str) -> str | None: + if self.process.returncode is not None: + return "session has already exited" + if self.process.stdin is None: + return "session stdin is not available" + try: + self.process.stdin.write(chars.encode("utf-8")) + await self.process.stdin.drain() + except (BrokenPipeError, ConnectionResetError): + return "session stdin is closed" + return None + + async def close_stdin(self) -> str | None: + if self.process.returncode is not None: + return "session has already exited" + if self.process.stdin is None: + return "session stdin is not available" + self.process.stdin.close() + with suppress(BrokenPipeError, ConnectionResetError): + await self.process.stdin.wait_closed() + return None + + async def poll( + self, + yield_time_ms: int, + max_output_chars: int, + *, + terminated: bool = False, + stdin_closed: bool = False, + ) -> _SessionPoll: + self.last_access = time.monotonic() + if yield_time_ms > 0 and self.process.returncode is None: + await asyncio.sleep(min(yield_time_ms, MAX_YIELD_MS) / 1000) + + if self.process.returncode is None and time.monotonic() >= self.deadline: + self._timed_out = True + await self.kill() + + if self.process.returncode is not None: + with suppress(asyncio.TimeoutError): + await asyncio.wait_for( + asyncio.gather(self._stdout_task, self._stderr_task), + timeout=2.0, + ) + + async with self._lock: + output = "".join(self._chunks) + self._chunks.clear() + + output, truncated = _truncate_output(output, max_output_chars) + return _SessionPoll( + output=output, + done=self.process.returncode is not None, + exit_code=self.process.returncode, + elapsed_s=max(0.0, time.monotonic() - self.started_at), + timed_out=self._timed_out, + terminated=terminated, + stdin_closed=stdin_closed, + truncated_chars=truncated, + ) + + async def kill(self) -> None: + if self.process.returncode is not None: + return + self.process.kill() + with suppress(asyncio.TimeoutError): + await asyncio.wait_for(self.process.wait(), timeout=5.0) + + +class ExecSessionManager: + def __init__(self, *, max_sessions: int = 8, idle_timeout: int = 1800) -> None: + self.max_sessions = max_sessions + self.idle_timeout = idle_timeout + self._sessions: dict[str, _ExecSession] = {} + self._lock = asyncio.Lock() + + async def start( + self, + *, + command: str, + cwd: str, + env: dict[str, str], + timeout: int, + shell_program: str | None, + login: bool, + yield_time_ms: int, + max_output_chars: int, + ) -> tuple[str, _SessionPoll]: + async with self._lock: + await self._cleanup_locked() + if len(self._sessions) >= self.max_sessions: + raise RuntimeError(f"maximum exec sessions reached ({self.max_sessions})") + process = await self._spawn(command, cwd, env, shell_program, login) + session_id = uuid.uuid4().hex[:12] + session = _ExecSession( + session_id=session_id, + process=process, + command=command, + cwd=cwd, + timeout=timeout, + ) + self._sessions[session_id] = session + + poll = await session.poll(yield_time_ms, max_output_chars) + if poll.done: + async with self._lock: + self._sessions.pop(session_id, None) + return session_id, poll + + async def write( + self, + *, + session_id: str, + chars: str | None, + close_stdin: bool, + terminate: bool, + yield_time_ms: int, + max_output_chars: int, + ) -> _SessionPoll: + async with self._lock: + await self._cleanup_locked() + session = self._sessions.get(session_id) + if session is None: + raise KeyError(session_id) + + if chars: + error = await session.write(chars) + if error: + raise RuntimeError(error) + stdin_closed = False + if close_stdin: + error = await session.close_stdin() + if error: + raise RuntimeError(error) + stdin_closed = True + if terminate: + await session.kill() + poll = await session.poll( + yield_time_ms, + max_output_chars, + terminated=terminate, + stdin_closed=stdin_closed, + ) + if poll.done: + async with self._lock: + self._sessions.pop(session_id, None) + return poll + + async def list(self) -> list[ExecSessionInfo]: + async with self._lock: + await self._cleanup_locked() + now = time.monotonic() + return [ + ExecSessionInfo( + session_id=session_id, + command=session.command, + cwd=session.cwd, + elapsed_s=max(0.0, now - session.started_at), + idle_s=max(0.0, now - session.last_access), + remaining_s=max(0.0, session.deadline - now), + returncode=session.process.returncode, + ) + for session_id, session in sorted(self._sessions.items()) + ] + + async def _cleanup_locked(self) -> None: + now = time.monotonic() + stale = [ + session_id + for session_id, session in self._sessions.items() + if now - session.last_access > self.idle_timeout + ] + for session_id in stale: + session = self._sessions.pop(session_id) + await session.kill() + + async def _spawn( + self, + command: str, + cwd: str, + env: dict[str, str], + shell_program: str | None, + login: bool, + ) -> asyncio.subprocess.Process: + from nanobot.agent.tools import shell + + if shell._IS_WINDOWS: + return await asyncio.create_subprocess_shell( + command, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=cwd, + env=env, + ) + shell_program = shell_program or shutil.which("bash") or "/bin/bash" + args = [shell_program] + if login and shell_program.rsplit("/", 1)[-1] in {"bash", "zsh"}: + args.append("-l") + args.extend(["-c", command]) + return await asyncio.create_subprocess_exec( + *args, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=cwd, + env=env, + ) + + +DEFAULT_EXEC_SESSION_MANAGER = ExecSessionManager() + + +def clamp_session_int(value: int | None, default: int, minimum: int, maximum: int) -> int: + if value is None: + return default + return min(max(value, minimum), maximum) + + +def _truncate_output(output: str, max_output_chars: int) -> tuple[str, int]: + if len(output) <= max_output_chars: + return output, 0 + half = max_output_chars // 2 + omitted = len(output) - max_output_chars + return ( + output[:half] + + f"\n\n... ({omitted:,} chars truncated) ...\n\n" + + output[-half:], + omitted, + ) + + +def format_session_poll(session_id: str, poll: _SessionPoll) -> str: + parts = [poll.output] if poll.output else [] + if poll.truncated_chars: + parts.append(f"(output truncated by {poll.truncated_chars:,} chars)") + if poll.timed_out: + parts.append("Error: Command timed out; session was terminated.") + if poll.terminated and not poll.timed_out: + parts.append("Session terminated.") + if poll.stdin_closed: + parts.append("Stdin closed.") + if poll.done: + parts.append(f"Exit code: {poll.exit_code}") + else: + parts.append(f"Process running. session_id: {session_id}") + parts.append(f"Elapsed: {poll.elapsed_s:.1f}s") + return "\n".join(parts) if parts else "(no output yet)" + + +@tool_parameters( + tool_parameters_schema( + session_id=StringSchema("Session id returned by exec when yield_time_ms is used."), + chars=StringSchema( + "Bytes/text to write to stdin. Omit or pass an empty string to only poll recent output.", + nullable=True, + ), + close_stdin=BooleanSchema( + description="Close stdin after writing chars. Useful for commands waiting for EOF.", + default=False, + ), + terminate=BooleanSchema( + description="Terminate the running exec session.", + default=False, + ), + yield_time_ms=IntegerSchema( + DEFAULT_YIELD_MS, + description="Milliseconds to wait before returning recent output (default 1000, max 30000).", + minimum=0, + maximum=MAX_YIELD_MS, + ), + wait_for=StringSchema( + "Optional text to wait for in output before returning. " + "Useful for interactive commands and dev servers.", + nullable=True, + ), + wait_timeout_ms=IntegerSchema( + DEFAULT_WAIT_FOR_MS, + description="Maximum milliseconds to wait for wait_for text (default 10000, max 120000).", + minimum=0, + maximum=MAX_WAIT_FOR_MS, + nullable=True, + ), + max_output_chars=IntegerSchema( + DEFAULT_MAX_OUTPUT_CHARS, + description="Maximum output characters to return from this poll (default 10000, max 50000).", + minimum=1000, + maximum=MAX_OUTPUT_CHARS, + ), + max_output_tokens=IntegerSchema( + DEFAULT_MAX_OUTPUT_CHARS, + description="Compatibility alias for max_output_chars. The current runtime uses a character budget.", + minimum=1000, + maximum=MAX_OUTPUT_CHARS, + nullable=True, + ), + required=["session_id"], + ) +) +class WriteStdinTool(Tool): + """Write to or poll a running exec session.""" + + _scopes = {"core", "subagent"} + config_key = "exec" + + @classmethod + def config_cls(cls): + from nanobot.agent.tools.shell import ExecToolConfig + + return ExecToolConfig + + @classmethod + def enabled(cls, ctx: Any) -> bool: + return ctx.config.exec.enable + + def __init__( + self, + *, + manager: ExecSessionManager | None = None, + ) -> None: + self._manager = manager or DEFAULT_EXEC_SESSION_MANAGER + + @classmethod + def create(cls, ctx: Any) -> Tool: + return cls() + + @property + def exclusive(self) -> bool: + return True + + @property + def name(self) -> str: + return "write_stdin" + + @property + def description(self) -> str: + return ( + "Interact with a running exec session created by exec with " + "yield_time_ms. Use chars='' to poll without writing, chars to send " + "stdin, close_stdin=true to send EOF, or terminate=true to stop the " + "process. Use wait_for with wait_timeout_ms for dev servers, test " + "watchers, and prompts where you need to wait for expected output. " + "Do not use this to start new commands; start them with exec." + ) + + async def execute( + self, + session_id: str, + chars: str | None = None, + close_stdin: bool = False, + terminate: bool = False, + yield_time_ms: int | None = None, + wait_for: str | None = None, + wait_timeout_ms: int | None = None, + max_output_chars: int | None = None, + max_output_tokens: int | None = None, + **kwargs: Any, + ) -> str: + try: + if max_output_chars is None: + max_output_chars = max_output_tokens + output_limit = clamp_session_int( + max_output_chars, + DEFAULT_MAX_OUTPUT_CHARS, + 1000, + MAX_OUTPUT_CHARS, + ) + if wait_for: + return await self._wait_for_output( + session_id=session_id, + chars=chars, + close_stdin=close_stdin, + terminate=terminate, + wait_for=wait_for, + wait_timeout_ms=clamp_session_int( + wait_timeout_ms, + DEFAULT_WAIT_FOR_MS, + 0, + MAX_WAIT_FOR_MS, + ), + max_output_chars=output_limit, + ) + poll = await self._manager.write( + session_id=session_id, + chars=chars, + close_stdin=close_stdin, + terminate=terminate, + yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS), + max_output_chars=output_limit, + ) + return format_session_poll(session_id, poll) + except KeyError: + return f"Error: exec session not found: {session_id}" + except Exception as exc: + return f"Error writing to exec session: {exc}" + + async def _wait_for_output( + self, + *, + session_id: str, + chars: str | None, + close_stdin: bool, + terminate: bool, + wait_for: str, + wait_timeout_ms: int, + max_output_chars: int, + ) -> str: + deadline = time.monotonic() + (wait_timeout_ms / 1000) + aggregate: list[str] = [] + first = True + poll: _SessionPoll | None = None + + while True: + remaining_ms = max(0, int((deadline - time.monotonic()) * 1000)) + step_ms = min(500, remaining_ms) + poll = await self._manager.write( + session_id=session_id, + chars=chars if first else None, + close_stdin=close_stdin if first else False, + terminate=terminate if first else False, + yield_time_ms=step_ms, + max_output_chars=max_output_chars, + ) + first = False + if poll.output: + aggregate.append(poll.output) + joined = "".join(aggregate) + if wait_for in joined: + poll.output = joined + return format_session_poll(session_id, poll) + if poll.done or remaining_ms <= 0: + poll.output = "".join(aggregate) + result = format_session_poll(session_id, poll) + if wait_for not in poll.output: + result += f"\nWait target not observed: {wait_for!r}" + return result + + +@tool_parameters(tool_parameters_schema()) +class ListExecSessionsTool(Tool): + """List active exec sessions.""" + + _scopes = {"core", "subagent"} + config_key = "exec" + + @classmethod + def config_cls(cls): + from nanobot.agent.tools.shell import ExecToolConfig + + return ExecToolConfig + + @classmethod + def enabled(cls, ctx: Any) -> bool: + return ctx.config.exec.enable + + def __init__( + self, + *, + manager: ExecSessionManager | None = None, + ) -> None: + self._manager = manager or DEFAULT_EXEC_SESSION_MANAGER + + @classmethod + def create(cls, ctx: Any) -> Tool: + return cls() + + @property + def name(self) -> str: + return "list_exec_sessions" + + @property + def description(self) -> str: + return ( + "List active long-running exec sessions, including session_id, cwd, " + "elapsed time, idle time, remaining timeout, and command preview. " + "Use this to recover a session_id after context shifts before " + "polling, writing stdin, or terminating with write_stdin." + ) + + @property + def read_only(self) -> bool: + return True + + async def execute(self, **kwargs: Any) -> str: + try: + sessions = await self._manager.list() + if not sessions: + return "No active exec sessions." + lines = [] + for info in sessions: + command = " ".join(info.command.split()) + if len(command) > 120: + command = command[:119] + "..." + status = "exited" if info.returncode is not None else "running" + lines.append( + f"{info.session_id} | {status} | elapsed={info.elapsed_s:.1f}s " + f"| idle={info.idle_s:.1f}s | remaining={info.remaining_s:.1f}s " + f"| cwd={info.cwd} | {command}" + ) + return "\n".join(lines) + except Exception as exc: + return f"Error listing exec sessions: {exc}" diff --git a/nanobot/agent/tools/filesystem.py b/nanobot/agent/tools/filesystem.py index 8f4f660da..fa63e5f66 100644 --- a/nanobot/agent/tools/filesystem.py +++ b/nanobot/agent/tools/filesystem.py @@ -132,6 +132,10 @@ def _parse_page_range(pages: str, total: int) -> tuple[int, int]: minimum=1, ), pages=StringSchema("Page range for PDF files, e.g. '1-5' (default: all, max 20 pages)"), + force=BooleanSchema( + description="Bypass same-file read deduplication and return content again.", + default=False, + ), required=["path"], ) ) @@ -154,7 +158,11 @@ class ReadFileTool(_FsTool): "Text output format: LINE_NUM|CONTENT. " "Images return visual content for analysis. " "Supports PDF, DOCX, XLSX, PPTX documents. " + "Use find_files/list_dir first when the path is uncertain. " + "Read the relevant range before editing so replacements or patches " + "are based on current content. " "Use offset and limit for large text files. " + "Use force=true to re-read content even if unchanged. " "Reads exceeding ~128K chars are truncated." ) @@ -162,7 +170,15 @@ class ReadFileTool(_FsTool): def read_only(self) -> bool: return True - async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, pages: str | None = None, **kwargs: Any) -> Any: + async def execute( + self, + path: str | None = None, + offset: int = 1, + limit: int | None = None, + pages: str | None = None, + force: bool = False, + **kwargs: Any, + ) -> Any: try: if not path: return "Error reading file: Unknown path" @@ -202,7 +218,13 @@ class ReadFileTool(_FsTool): current_mtime = os.path.getmtime(fp) except OSError: current_mtime = 0.0 - if entry and entry.can_dedup and entry.offset == offset and entry.limit == limit: + if ( + not force + and entry + and entry.can_dedup + and entry.offset == offset + and entry.limit == limit + ): if current_mtime != entry.mtime: # File was modified externally - force full read and mark as not dedupable entry.can_dedup = False @@ -365,9 +387,10 @@ class WriteFileTool(_FsTool): @property def description(self) -> str: return ( - "Write content to a file. Overwrites if the file already exists; " - "creates parent directories as needed. " - "For partial edits, prefer edit_file instead." + "Create a new file or intentionally replace an entire file with " + "the provided content. Overwrites existing files and creates parent " + "directories as needed. For code changes or partial edits, prefer " + "apply_patch; use edit_file only for small exact replacements." ) async def execute(self, path: str | None = None, content: str | None = None, **kwargs: Any) -> str: @@ -657,6 +680,24 @@ def _find_match(content: str, old_text: str) -> tuple[str | None, int]: old_text=StringSchema("The text to find and replace"), new_text=StringSchema("The text to replace with"), replace_all=BooleanSchema(description="Replace all occurrences (default false)"), + occurrence=IntegerSchema( + 1, + description="Optional 1-based occurrence to replace when old_text appears multiple times.", + minimum=1, + nullable=True, + ), + line_hint=IntegerSchema( + 1, + description="Optional 1-based line hint used to choose the nearest match.", + minimum=1, + nullable=True, + ), + expected_replacements=IntegerSchema( + 1, + description="Optional guard for the number of replacements that must be made.", + minimum=1, + nullable=True, + ), required=["path", "old_text", "new_text"], ) ) @@ -674,10 +715,13 @@ class EditFileTool(_FsTool): @property def description(self) -> str: return ( - "Edit a file by replacing old_text with new_text. " - "Tolerates minor whitespace/indentation differences and curly/straight quote mismatches. " - "If old_text matches multiple times, you must provide more context " - "or set replace_all=true. Shows a diff of the closest match on failure." + "Perform a small, exact replacement in one file by replacing " + "old_text with new_text. Use this for narrow text substitutions " + "with old_text copied from read_file. For multi-file, structural, " + "or generated code edits, prefer apply_patch. If old_text matches " + "multiple times, provide more context or set occurrence, line_hint, " + "replace_all, and expected_replacements. Shows closest-match " + "diagnostics on failure." ) @staticmethod @@ -688,7 +732,8 @@ class EditFileTool(_FsTool): async def execute( self, path: str | None = None, old_text: str | None = None, new_text: str | None = None, - replace_all: bool = False, **kwargs: Any, + replace_all: bool = False, occurrence: int | None = None, + line_hint: int | None = None, expected_replacements: int | None = None, **kwargs: Any, ) -> str: try: if not path: @@ -697,10 +742,12 @@ class EditFileTool(_FsTool): raise ValueError("Unknown old_text") if new_text is None: raise ValueError("Unknown new_text") - - # .ipynb detection - if path.endswith(".ipynb"): - return "Error: This is a Jupyter notebook. Use the notebook_edit tool instead of edit_file." + if occurrence is not None and occurrence < 1: + return "Error: occurrence must be >= 1." + if line_hint is not None and line_hint < 1: + return "Error: line_hint must be >= 1." + if expected_replacements is not None and expected_replacements < 1: + return "Error: expected_replacements must be >= 1." fp = self._resolve(path) @@ -743,15 +790,42 @@ class EditFileTool(_FsTool): if not matches: return self._not_found_msg(old_text, content, path) count = len(matches) + if replace_all and occurrence is not None: + return "Error: occurrence cannot be used with replace_all=true." + if replace_all and line_hint is not None: + return "Error: line_hint cannot be used with replace_all=true." + if occurrence is not None and line_hint is not None: + return "Error: line_hint cannot be used with occurrence." if count > 1 and not replace_all: - line_numbers = [match.line for match in matches] - preview = ", ".join(f"line {n}" for n in line_numbers[:3]) - if len(line_numbers) > 3: - preview += ", ..." - location_hint = f" at {preview}" if preview else "" + if occurrence is not None: + if occurrence > count: + return ( + f"Error: occurrence {occurrence} is out of range; " + f"old_text appears {count} times." + ) + elif line_hint is not None: + nearest = min(matches, key=lambda match: abs(match.line - line_hint)) + distance = abs(nearest.line - line_hint) + if sum(1 for match in matches if abs(match.line - line_hint) == distance) > 1: + return ( + f"Error: line_hint {line_hint} is ambiguous; " + f"old_text appears {count} times." + ) + else: + line_numbers = [match.line for match in matches] + preview = ", ".join(f"line {n}" for n in line_numbers[:3]) + if len(line_numbers) > 3: + preview += ", ..." + location_hint = f" at {preview}" if preview else "" + return ( + f"Warning: old_text appears {count} times{location_hint}. " + "Provide more context, set occurrence to choose one match, " + "or set replace_all=true." + ) + elif occurrence is not None and occurrence > count: return ( - f"Warning: old_text appears {count} times{location_hint}. " - "Provide more context to make it unique, or set replace_all=true." + f"Error: occurrence {occurrence} is out of range; " + f"old_text appears {count} time." ) norm_new = new_text.replace("\r\n", "\n") @@ -760,7 +834,17 @@ class EditFileTool(_FsTool): if fp.suffix.lower() not in self._MARKDOWN_EXTS: norm_new = self._strip_trailing_ws(norm_new) - selected = matches if replace_all else matches[:1] + if replace_all: + selected = matches + elif line_hint is not None: + selected = [min(matches, key=lambda match: abs(match.line - line_hint))] + else: + selected = [matches[occurrence - 1 if occurrence else 0]] + if expected_replacements is not None and len(selected) != expected_replacements: + return ( + f"Error: expected {expected_replacements} replacements but " + f"would make {len(selected)}." + ) new_content = content for match in reversed(selected): replacement = _preserve_quote_style(norm_old, match.text, norm_new) diff --git a/nanobot/agent/tools/image_generation.py b/nanobot/agent/tools/image_generation.py index 58eaaf7d8..a194d0fee 100644 --- a/nanobot/agent/tools/image_generation.py +++ b/nanobot/agent/tools/image_generation.py @@ -21,7 +21,6 @@ from nanobot.providers.image_generation import ( ImageGenerationProvider, get_image_gen_provider, ) -from nanobot.providers.registry import find_by_name from nanobot.utils.artifacts import ( ArtifactError, generated_image_tool_result, @@ -118,10 +117,6 @@ class ImageGenerationTool(Tool): def _provider_config(self) -> ProviderConfig | None: return self.provider_configs.get(self.config.provider) - def _provider_allows_missing_api_key(self) -> bool: - spec = find_by_name(self.config.provider) - return bool(spec and (spec.is_local or spec.is_direct or spec.is_oauth)) - def _provider_client(self) -> ImageGenerationProvider | None: provider = self._provider_config() cls = get_image_gen_provider(self.config.provider) @@ -135,12 +130,6 @@ class ImageGenerationTool(Tool): } return cls(**kwargs) - def _missing_api_key_error(self) -> str: - cls = get_image_gen_provider(self.config.provider) - if cls and cls.missing_key_message: - return f"Error: {cls.missing_key_message}" - return f"Error: {self.config.provider} API key is not configured." - def _resolve_reference_image(self, value: str) -> str: raw_path = Path(value).expanduser() path = raw_path if raw_path.is_absolute() else self.workspace / raw_path @@ -178,9 +167,6 @@ class ImageGenerationTool(Tool): client = self._provider_client() if client is None: return f"Error: unsupported image generation provider '{self.config.provider}'" - provider = self._provider_config() - if not self._provider_allows_missing_api_key() and (not provider or not provider.api_key): - return self._missing_api_key_error() requested = count or 1 if requested > self.config.max_images_per_turn: diff --git a/nanobot/agent/tools/notebook.py b/nanobot/agent/tools/notebook.py deleted file mode 100644 index 0980b7c93..000000000 --- a/nanobot/agent/tools/notebook.py +++ /dev/null @@ -1,162 +0,0 @@ -"""NotebookEditTool — edit Jupyter .ipynb notebooks.""" - -from __future__ import annotations - -import json -import uuid -from typing import Any - -from nanobot.agent.tools.base import tool_parameters -from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema -from nanobot.agent.tools.filesystem import _FsTool - - -def _new_cell(source: str, cell_type: str = "code", generate_id: bool = False) -> dict: - cell: dict[str, Any] = { - "cell_type": cell_type, - "source": source, - "metadata": {}, - } - if cell_type == "code": - cell["outputs"] = [] - cell["execution_count"] = None - if generate_id: - cell["id"] = uuid.uuid4().hex[:8] - return cell - - -def _make_empty_notebook() -> dict: - return { - "nbformat": 4, - "nbformat_minor": 5, - "metadata": { - "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, - "language_info": {"name": "python"}, - }, - "cells": [], - } - - -@tool_parameters( - tool_parameters_schema( - path=StringSchema("Path to the .ipynb notebook file"), - cell_index=IntegerSchema(0, description="0-based index of the cell to edit", minimum=0), - new_source=StringSchema("New source content for the cell"), - cell_type=StringSchema( - "Cell type: 'code' or 'markdown' (default: code)", - enum=["code", "markdown"], - ), - edit_mode=StringSchema( - "Mode: 'replace' (default), 'insert' (after target), or 'delete'", - enum=["replace", "insert", "delete"], - ), - required=["path", "cell_index"], - ) -) -class NotebookEditTool(_FsTool): - """Edit Jupyter notebook cells: replace, insert, or delete.""" - _scopes = {"core"} - - _VALID_CELL_TYPES = frozenset({"code", "markdown"}) - _VALID_EDIT_MODES = frozenset({"replace", "insert", "delete"}) - - @property - def name(self) -> str: - return "notebook_edit" - - @property - def description(self) -> str: - return ( - "Edit a Jupyter notebook (.ipynb) cell. " - "Modes: replace (default) replaces cell content, " - "insert adds a new cell after the target index, " - "delete removes the cell at the index. " - "cell_index is 0-based." - ) - - async def execute( - self, - path: str | None = None, - cell_index: int = 0, - new_source: str = "", - cell_type: str = "code", - edit_mode: str = "replace", - **kwargs: Any, - ) -> str: - try: - if not path: - return "Error: path is required" - - if not path.endswith(".ipynb"): - return "Error: notebook_edit only works on .ipynb files. Use edit_file for other files." - - if edit_mode not in self._VALID_EDIT_MODES: - return ( - f"Error: Invalid edit_mode '{edit_mode}'. " - "Use one of: replace, insert, delete." - ) - - if cell_type not in self._VALID_CELL_TYPES: - return ( - f"Error: Invalid cell_type '{cell_type}'. " - "Use one of: code, markdown." - ) - - fp = self._resolve(path) - - # Create new notebook if file doesn't exist and mode is insert - if not fp.exists(): - if edit_mode != "insert": - return f"Error: File not found: {path}" - nb = _make_empty_notebook() - cell = _new_cell(new_source, cell_type, generate_id=True) - nb["cells"].append(cell) - fp.parent.mkdir(parents=True, exist_ok=True) - fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8") - return f"Successfully created {fp} with 1 cell" - - try: - nb = json.loads(fp.read_text(encoding="utf-8")) - except (json.JSONDecodeError, UnicodeDecodeError) as e: - return f"Error: Failed to parse notebook: {e}" - - cells = nb.get("cells", []) - nbformat_minor = nb.get("nbformat_minor", 0) - generate_id = nb.get("nbformat", 0) >= 4 and nbformat_minor >= 5 - - if edit_mode == "delete": - if cell_index < 0 or cell_index >= len(cells): - return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)" - cells.pop(cell_index) - nb["cells"] = cells - fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8") - return f"Successfully deleted cell {cell_index} from {fp}" - - if edit_mode == "insert": - insert_at = min(cell_index + 1, len(cells)) - cell = _new_cell(new_source, cell_type, generate_id=generate_id) - cells.insert(insert_at, cell) - nb["cells"] = cells - fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8") - return f"Successfully inserted cell at index {insert_at} in {fp}" - - # Default: replace - if cell_index < 0 or cell_index >= len(cells): - return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)" - cells[cell_index]["source"] = new_source - if cell_type and cells[cell_index].get("cell_type") != cell_type: - cells[cell_index]["cell_type"] = cell_type - if cell_type == "code": - cells[cell_index].setdefault("outputs", []) - cells[cell_index].setdefault("execution_count", None) - elif "outputs" in cells[cell_index]: - del cells[cell_index]["outputs"] - cells[cell_index].pop("execution_count", None) - nb["cells"] = cells - fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8") - return f"Successfully edited cell {cell_index} in {fp}" - - except PermissionError as e: - return f"Error: {e}" - except Exception as e: - return f"Error editing notebook: {e}" diff --git a/nanobot/agent/tools/search.py b/nanobot/agent/tools/search.py index 49448030b..0febb122c 100644 --- a/nanobot/agent/tools/search.py +++ b/nanobot/agent/tools/search.py @@ -1,4 +1,4 @@ -"""Search tools: grep.""" +"""Search tools: file discovery and grep.""" from __future__ import annotations @@ -12,6 +12,7 @@ from typing import Any, Iterable, TypeVar from nanobot.agent.tools.filesystem import ListDirTool, _FsTool _DEFAULT_HEAD_LIMIT = 250 +_DEFAULT_FILE_HEAD_LIMIT = 200 T = TypeVar("T") _TYPE_GLOB_MAP = { "py": ("*.py", "*.pyi"), @@ -88,6 +89,14 @@ def _matches_type(name: str, file_type: str | None) -> bool: return any(fnmatch.fnmatch(name.lower(), pattern.lower()) for pattern in patterns) +def _matches_query(rel_path: str, query: str | None) -> bool: + if not query: + return True + haystack = rel_path.lower() + terms = [part for part in query.lower().split() if part] + return all(term in haystack for term in terms) + + class _SearchTool(_FsTool): _IGNORE_DIRS = set(ListDirTool._IGNORE_DIRS) @@ -109,6 +118,163 @@ class _SearchTool(_FsTool): yield current / filename +class FindFilesTool(_SearchTool): + """Find files by path fragment, glob, or type.""" + _scopes = {"core", "subagent"} + + @property + def name(self) -> str: + return "find_files" + + @property + def description(self) -> str: + return ( + "Find files by path fragment, glob, or file type. " + "Use this before read_file when you need to locate files, and " + "prefer it over shell find/ls for ordinary workspace discovery. " + "Returns workspace-relative paths and skips common dependency/build " + "directories." + ) + + @property + def read_only(self) -> bool: + return True + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Directory or file to search in (default '.')", + }, + "query": { + "type": "string", + "description": ( + "Optional case-insensitive path fragment search. " + "Whitespace-separated terms must all be present." + ), + }, + "glob": { + "type": "string", + "description": "Optional file filter, e.g. '*.py' or 'tests/**/test_*.py'", + }, + "type": { + "type": "string", + "description": "Optional file type shorthand, e.g. 'py', 'ts', 'md', 'json'", + }, + "include_dirs": { + "type": "boolean", + "description": "Include matching directories as well as files (default false)", + }, + "sort": { + "type": "string", + "enum": ["path", "modified"], + "description": "Sort by path or most recently modified first (default path)", + }, + "head_limit": { + "type": "integer", + "description": "Maximum number of paths to return (default 200, 0 for all, max 1000)", + "minimum": 0, + "maximum": 1000, + }, + "offset": { + "type": "integer", + "description": "Skip the first N results before applying head_limit", + "minimum": 0, + "maximum": 100000, + }, + }, + } + + def _iter_paths(self, root: Path, *, include_dirs: bool) -> Iterable[Path]: + if root.is_file(): + yield root + return + if include_dirs: + yield root + for dirpath, dirnames, filenames in os.walk(root): + dirnames[:] = sorted(d for d in dirnames if d not in self._IGNORE_DIRS) + current = Path(dirpath) + if include_dirs and current != root: + yield current + for filename in sorted(filenames): + yield current / filename + + async def execute( + self, + path: str = ".", + query: str | None = None, + glob: str | None = None, + type: str | None = None, + include_dirs: bool = False, + sort: str = "path", + head_limit: int | None = None, + offset: int = 0, + **kwargs: Any, + ) -> str: + try: + target = self._resolve(path or ".") + if not target.exists(): + return f"Error: Path not found: {path}" + if not (target.is_dir() or target.is_file()): + return f"Error: Unsupported path: {path}" + + if sort not in {"path", "modified"}: + return "Error: sort must be 'path' or 'modified'" + + limit = ( + _DEFAULT_FILE_HEAD_LIMIT + if head_limit is None + else None if head_limit == 0 else head_limit + ) + root = target if target.is_dir() else target.parent + matches: list[tuple[str, float]] = [] + + for candidate in self._iter_paths(target, include_dirs=include_dirs): + if candidate.is_dir() and not include_dirs: + continue + rel_path = candidate.relative_to(root).as_posix() + display_path = self._display_path(candidate, root) + name = candidate.name + + if glob and not _match_glob(rel_path, name, glob): + continue + if candidate.is_file() and not _matches_type(name, type): + continue + if candidate.is_dir() and type: + continue + if not _matches_query(display_path, query): + continue + try: + mtime = candidate.stat().st_mtime + except OSError: + mtime = 0.0 + suffix = "/" if candidate.is_dir() else "" + matches.append((display_path + suffix, mtime)) + + if sort == "modified": + matches.sort(key=lambda item: (-item[1], item[0])) + else: + matches.sort(key=lambda item: item[0]) + + paths = [item[0] for item in matches] + paged, truncated = _paginate(paths, limit, offset) + if not paged: + return "No files found" + + result = "\n".join(paged) + note = _pagination_note(limit, offset, truncated) + if note: + result += "\n\n" + note + return result + except PermissionError as e: + return f"Error: {e}" + except Exception as e: + return f"Error finding files: {e}" + + class GrepTool(_SearchTool): """Search file contents using a regex-like pattern.""" _scopes = {"core", "subagent"} @@ -125,7 +291,8 @@ class GrepTool(_SearchTool): return ( "Search file contents with a regex pattern. " "Default output_mode is files_with_matches (file paths only); " - "use content mode for matching lines with context. " + "use content mode for matching lines with context. Prefer this " + "over shell grep for ordinary workspace searches. " "Skips binary and files >2 MB. Supports glob/type filtering." ) diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index 0252b9746..537c89343 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -8,6 +8,7 @@ import re import shutil import sys from contextlib import suppress +from dataclasses import dataclass from pathlib import Path from typing import Any @@ -15,8 +16,17 @@ from loguru import logger from pydantic import Field from nanobot.agent.tools.base import Tool, tool_parameters +from nanobot.agent.tools.exec_session import ( + DEFAULT_MAX_OUTPUT_CHARS, + DEFAULT_YIELD_MS, + DEFAULT_EXEC_SESSION_MANAGER, + MAX_OUTPUT_CHARS, + MAX_YIELD_MS, + clamp_session_int, + format_session_poll, +) from nanobot.agent.tools.sandbox import wrap_command -from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema +from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema from nanobot.config.paths import get_media_dir from nanobot.config.schema import Base @@ -44,10 +54,22 @@ class ExecToolConfig(Base): deny_patterns: list[str] = Field(default_factory=list) +@dataclass(slots=True) +class _PreparedCommand: + command: str + cwd: str + env: dict[str, str] + timeout: int + shell_program: str | None + login: bool + + @tool_parameters( tool_parameters_schema( command=StringSchema("The shell command to execute"), + cmd=StringSchema("Compatibility alias for command"), working_dir=StringSchema("Optional working directory for the command"), + workdir=StringSchema("Compatibility alias for working_dir"), timeout=IntegerSchema( 60, description=( @@ -57,7 +79,44 @@ class ExecToolConfig(Base): minimum=1, maximum=600, ), - required=["command"], + shell=StringSchema( + "Optional shell binary to launch. On Unix, supports sh, bash, or zsh.", + nullable=True, + ), + login=BooleanSchema( + description="Whether to run bash/zsh with login shell semantics (default true).", + default=True, + nullable=True, + ), + yield_time_ms=IntegerSchema( + description=( + "Optional milliseconds to wait before returning output. " + "When set, a still-running command returns a session_id that " + "can be polled or written to with write_stdin. Omit this field " + "to keep one-shot exec behavior." + ), + minimum=0, + maximum=MAX_YIELD_MS, + nullable=True, + ), + max_output_chars=IntegerSchema( + description=( + "Maximum output characters to return when yield_time_ms is used " + "(default 10000, max 50000)." + ), + minimum=1000, + maximum=MAX_OUTPUT_CHARS, + nullable=True, + ), + max_output_tokens=IntegerSchema( + description=( + "Compatibility alias for max_output_chars. The current runtime " + "uses a character budget." + ), + minimum=1000, + maximum=MAX_OUTPUT_CHARS, + nullable=True, + ), ) ) class ExecTool(Tool): @@ -98,6 +157,7 @@ class ExecTool(Tool): sandbox: str = "", path_append: str = "", allowed_env_keys: list[str] | None = None, + session_manager: Any | None = None, ): self.timeout = timeout self.working_dir = working_dir @@ -125,6 +185,7 @@ class ExecTool(Tool): self.restrict_to_workspace = restrict_to_workspace self.path_append = path_append self.allowed_env_keys = allowed_env_keys or [] + self._session_manager = session_manager or DEFAULT_EXEC_SESSION_MANAGER @property def name(self) -> str: @@ -150,10 +211,15 @@ class ExecTool(Tool): def description(self) -> str: return ( "Execute a shell command and return its output. " - "Prefer read_file/write_file/edit_file over cat/echo/sed, " - "and grep/glob over shell find/grep. " + "Use this for tests, builds, package commands, git commands, and " + "other process execution. Prefer read_file/find_files/grep for " + "inspection and apply_patch/write_file/edit_file for file changes " + "instead of cat, shell find/grep, echo, or sed. " "Use -y or --yes flags to avoid interactive prompts. " - "Output is truncated at 10 000 chars; timeout defaults to 60s." + "For long-running or interactive commands, pass yield_time_ms; " + "if the command keeps running, exec returns a session_id that can " + "be polled or written to with write_stdin. Output is truncated at " + "10 000 chars; timeout defaults to 60s." ) @property @@ -161,9 +227,111 @@ class ExecTool(Tool): return True async def execute( - self, command: str, working_dir: str | None = None, - timeout: int | None = None, **kwargs: Any, + self, command: str | None = None, cmd: str | None = None, + working_dir: str | None = None, workdir: str | None = None, + timeout: int | None = None, shell: str | None = None, + login: bool | None = None, yield_time_ms: int | None = None, + max_output_chars: int | None = None, + max_output_tokens: int | None = None, + **kwargs: Any, ) -> str: + command = command or cmd + working_dir = working_dir or workdir + if not command: + return "Error: Missing command. Provide command or cmd." + if max_output_chars is None: + max_output_chars = max_output_tokens + + prepared = self._prepare_command(command, working_dir, timeout, shell, login) + if isinstance(prepared, str): + return prepared + + if yield_time_ms is not None: + return await self._execute_session(prepared, yield_time_ms, max_output_chars) + + try: + process = await self._spawn( + prepared.command, + prepared.cwd, + prepared.env, + prepared.shell_program, + prepared.login, + ) + + try: + stdout, stderr = await asyncio.wait_for( + process.communicate(), + timeout=prepared.timeout, + ) + except asyncio.TimeoutError: + await self._kill_process(process) + return f"Error: Command timed out after {prepared.timeout} seconds" + except asyncio.CancelledError: + await self._kill_process(process) + raise + + output_parts = [] + + if stdout: + output_parts.append(stdout.decode("utf-8", errors="replace")) + + if stderr: + stderr_text = stderr.decode("utf-8", errors="replace") + if stderr_text.strip(): + output_parts.append(f"STDERR:\n{stderr_text}") + + output_parts.append(f"\nExit code: {process.returncode}") + + result = "\n".join(output_parts) if output_parts else "(no output)" + + max_len = clamp_session_int(max_output_chars, self._MAX_OUTPUT, 1000, MAX_OUTPUT_CHARS) + if len(result) > max_len: + half = max_len // 2 + result = ( + result[:half] + + f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n" + + result[-half:] + ) + + return result + + except Exception as e: + return f"Error executing command: {str(e)}" + + async def _execute_session( + self, + prepared: _PreparedCommand, + yield_time_ms: int | None, + max_output_chars: int | None, + ) -> str: + try: + session_id, poll = await self._session_manager.start( + command=prepared.command, + cwd=prepared.cwd, + env=prepared.env, + timeout=prepared.timeout, + shell_program=prepared.shell_program, + login=prepared.login, + yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS), + max_output_chars=clamp_session_int( + max_output_chars, + DEFAULT_MAX_OUTPUT_CHARS, + 1000, + MAX_OUTPUT_CHARS, + ), + ) + return format_session_poll(session_id, poll) + except Exception as exc: + return f"Error executing command: {exc}" + + def _prepare_command( + self, + command: str, + working_dir: str | None = None, + timeout: int | None = None, + shell: str | None = None, + login: bool | None = None, + ) -> _PreparedCommand | str: cwd = working_dir or self.working_dir or os.getcwd() # Prevent an LLM-supplied working_dir from escaping the configured @@ -211,52 +379,24 @@ class ExecTool(Tool): env["NANOBOT_PATH_APPEND"] = self.path_append command = f'export PATH="$PATH{os.pathsep}$NANOBOT_PATH_APPEND"; {command}' - try: - process = await self._spawn(command, cwd, env) + shell_program, shell_error = self._resolve_shell(shell) + if shell_error: + return shell_error - try: - stdout, stderr = await asyncio.wait_for( - process.communicate(), - timeout=effective_timeout, - ) - except asyncio.TimeoutError: - await self._kill_process(process) - return f"Error: Command timed out after {effective_timeout} seconds" - except asyncio.CancelledError: - await self._kill_process(process) - raise - - output_parts = [] - - if stdout: - output_parts.append(stdout.decode("utf-8", errors="replace")) - - if stderr: - stderr_text = stderr.decode("utf-8", errors="replace") - if stderr_text.strip(): - output_parts.append(f"STDERR:\n{stderr_text}") - - output_parts.append(f"\nExit code: {process.returncode}") - - result = "\n".join(output_parts) if output_parts else "(no output)" - - max_len = self._MAX_OUTPUT - if len(result) > max_len: - half = max_len // 2 - result = ( - result[:half] - + f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n" - + result[-half:] - ) - - return result - - except Exception as e: - return f"Error executing command: {str(e)}" + return _PreparedCommand( + command=command, + cwd=cwd, + env=env, + timeout=effective_timeout, + shell_program=shell_program, + login=True if login is None else login, + ) @staticmethod async def _spawn( command: str, cwd: str, env: dict[str, str], + shell_program: str | None = None, + login: bool = True, ) -> asyncio.subprocess.Process: """Launch *command* in a platform-appropriate shell.""" if _IS_WINDOWS: @@ -272,9 +412,14 @@ class ExecTool(Tool): cwd=cwd, env=env, ) - bash = shutil.which("bash") or "/bin/bash" + shell_program = shell_program or shutil.which("bash") or "/bin/bash" + args = [shell_program] + shell_name = Path(shell_program).name.lower() + if login and shell_name in {"bash", "bash.exe", "zsh", "zsh.exe"}: + args.append("-l") + args.extend(["-c", command]) return await asyncio.create_subprocess_exec( - bash, "-l", "-c", command, + *args, stdin=asyncio.subprocess.DEVNULL, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, @@ -282,6 +427,31 @@ class ExecTool(Tool): env=env, ) + @staticmethod + def _resolve_shell(shell: str | None) -> tuple[str | None, str | None]: + if not shell: + return None, None + if _IS_WINDOWS: + return None, "Error: shell parameter is not supported on Windows" + if "\0" in shell or "\n" in shell or "\r" in shell: + return None, "Error: shell contains invalid characters" + allowed = {"sh", "bash", "zsh"} + path = Path(shell).expanduser() + if path.is_absolute(): + if path.name not in allowed: + return None, f"Error: unsupported shell {shell!r}. Allowed: bash, sh, zsh" + if not path.is_file() or not os.access(path, os.X_OK): + return None, f"Error: shell is not executable: {shell}" + return str(path), None + if "/" in shell or "\\" in shell: + return None, "Error: shell must be a shell name or absolute path" + if shell not in allowed: + return None, f"Error: unsupported shell {shell!r}. Allowed: bash, sh, zsh" + resolved = shutil.which(shell) + if not resolved: + return None, f"Error: shell not found: {shell}" + return resolved, None + @staticmethod async def _kill_process(process: asyncio.subprocess.Process) -> None: """Kill a subprocess and reap it to prevent zombies.""" @@ -418,7 +588,7 @@ class ExecTool(Tool): # Windows: match drive-root paths like `C:\` as well as `C:\path\to\file`, and UNC paths like `\\server\share` # NOTE: `*` is required so `C:\` (nothing after the slash) is still extracted. win_paths = re.findall( - r"(?:[A-Za-z]:[^\s\"'|><;]*|\\\\[^\s\"'|><;]+(?:\\[^\s\"'|><;]+)*)", + r"(?<;]*|\\\\[^\s\"'|><;]+(?:\\[^\s\"'|><;]+)*)", command ) posix_paths = re.findall(r"(?:^|[\s|>'\"])(/[^\s\"'>;|<]+)", command) # POSIX: /absolute only diff --git a/nanobot/channels/signal.py b/nanobot/channels/signal.py new file mode 100644 index 000000000..2a38f60ac --- /dev/null +++ b/nanobot/channels/signal.py @@ -0,0 +1,1402 @@ +"""Signal channel implementation using signal-cli daemon JSON-RPC interface.""" + +from __future__ import annotations + +import asyncio +import json +import re +import shutil +import unicodedata +from collections import deque +from collections.abc import AsyncIterator, Callable +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import httpx +from pydantic import Field, computed_field, field_validator + +from nanobot.bus.events import InboundMessage, OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.base import BaseChannel +from nanobot.config.paths import get_media_dir +from nanobot.config.schema import Base +from nanobot.pairing import is_approved +from nanobot.utils.helpers import safe_filename, split_message + + +@dataclass +class _Run: + text: str + styles: frozenset[str] = field(default_factory=frozenset) + opaque: bool = False # code / table content — skip further pattern processing + + +_SIG_CODE_BLOCK_RE = re.compile(r"```(?:\w+)?\n?([\s\S]*?)```") +_SIG_INLINE_CODE_RE = re.compile(r"`([^`\n]+)`") +_SIG_HEADER_RE = re.compile(r"^#{1,6}\s+(.+)$", re.MULTILINE) +_SIG_BLOCKQUOTE_RE = re.compile(r"^>\s*(.*)$", re.MULTILINE) +_SIG_BULLET_RE = re.compile(r"^[-*]\s+", re.MULTILINE) +_SIG_OLIST_RE = re.compile(r"^(\d+)\.\s+", re.MULTILINE) +_SIG_LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)") +_SIG_BOLD_RE = re.compile(r"\*\*(.+?)\*\*|__(.+?)__", re.DOTALL) +_SIG_ITALIC_RE = re.compile( + r"(? int: + """UTF-16 code-unit length, matching Signal BodyRange semantics.""" + return len(s.encode("utf-16-le")) // 2 + + +def _sig_strip_cell(s: str) -> str: + """Strip inline markdown from a table cell for plain-text rendering.""" + for pattern, repl in _SIG_CELL_STRIP_PATTERNS: + s = pattern.sub(repl, s) + return s.strip() + + +def _sig_render_table(table_lines: list[str]) -> str: + """Render a markdown pipe-table as fixed-width plain text.""" + + def dw(s: str) -> int: + return sum(2 if unicodedata.east_asian_width(c) in ("W", "F") else 1 for c in s) + + rows: list[list[str]] = [] + has_sep = False + for line in table_lines: + cells = [_sig_strip_cell(c) for c in line.strip().strip("|").split("|")] + if all(re.match(r"^:?-+:?$", c) for c in cells if c): + has_sep = True + continue + rows.append(cells) + if not rows or not has_sep: + return "\n".join(table_lines) + + ncols = max(len(r) for r in rows) + for r in rows: + r.extend([""] * (ncols - len(r))) + widths = [max(dw(r[c]) for r in rows) for c in range(ncols)] + + def dr(cells: list[str]) -> str: + return " ".join(f"{c}{' ' * (w - dw(c))}" for c, w in zip(cells, widths)) + + out = [dr(rows[0])] + out.append(" ".join("─" * w for w in widths)) + for row in rows[1:]: + out.append(dr(row)) + return "\n".join(out) + + +def _markdown_to_signal(text: str) -> tuple[str, list[str]]: + """Convert markdown text to Signal plain text + textStyle ranges. + + Returns ``(plain_text, text_styles)`` where ``text_styles`` are + ``"start:length:STYLE"`` strings for the signal-cli ``textStyle`` parameter. + """ + if not text: + return text, [] + + # Phase 1 (text-level): extract code blocks and tables with placeholder tokens + # so they're protected from inline-style processing. + protected: list[str] = [] + + def save_code(m: re.Match) -> str: + protected.append(m.group(1)) + return f"\x00C{len(protected) - 1}\x00" + + text = _SIG_CODE_BLOCK_RE.sub(save_code, text) + + # Detect and render pipe-tables line by line. + lines = text.split("\n") + rebuilt: list[str] = [] + i = 0 + while i < len(lines): + if re.match(r"^\s*\|.+\|", lines[i]): + tbl: list[str] = [] + while i < len(lines) and re.match(r"^\s*\|.+\|", lines[i]): + tbl.append(lines[i]) + i += 1 + rendered = _sig_render_table(tbl) + if rendered != "\n".join(tbl): + protected.append(rendered) + rebuilt.append(f"\x00C{len(protected) - 1}\x00") + else: + rebuilt.extend(tbl) + else: + rebuilt.append(lines[i]) + i += 1 + text = "\n".join(rebuilt) + + # Phase 2 (run-based): process inline patterns. + runs: list[_Run] = [_Run(text)] + + def transform( + pattern: re.Pattern, + make_runs: Callable[[re.Match, frozenset[str]], list[_Run]], + ) -> None: + new_runs: list[_Run] = [] + for run in runs: + if run.opaque: + new_runs.append(run) + continue + pos = 0 + for m in pattern.finditer(run.text): + if m.start() > pos: + new_runs.append(_Run(run.text[pos : m.start()], run.styles)) + new_runs.extend(make_runs(m, run.styles)) + pos = m.end() + if pos < len(run.text): + new_runs.append(_Run(run.text[pos:], run.styles)) + runs[:] = new_runs + + # Restore code/table placeholders as opaque MONOSPACE runs. + transform( + _SIG_TOKEN_RE, + lambda m, s: [_Run(protected[int(m.group(1))], s | {"MONOSPACE"}, opaque=True)], + ) + + # Inline code (opaque). + transform(_SIG_INLINE_CODE_RE, lambda m, s: [_Run(m.group(1), s | {"MONOSPACE"}, opaque=True)]) + + # Headers → bold plain text. + transform(_SIG_HEADER_RE, lambda m, s: [_Run(m.group(1), s | {"BOLD"})]) + + # Blockquotes → strip marker. + transform(_SIG_BLOCKQUOTE_RE, lambda m, s: [_Run(m.group(1), s)]) + + # Bullet lists → bullet character. + transform(_SIG_BULLET_RE, lambda m, s: [_Run("• ", s)]) + + # Numbered lists → normalize spacing. + transform(_SIG_OLIST_RE, lambda m, s: [_Run(m.group(1) + ". ", s)]) + + # Links → "text (url)" or bare url when text equals url. + def _link_runs(m: re.Match, s: frozenset) -> list[_Run]: + link_text, url = m.group(1), m.group(2) + + def _norm(u: str) -> str: + return re.sub(r"^https?://(www\.)?", "", u).rstrip("/").lower() + + if _norm(url) == _norm(link_text): + return [_Run(url, s)] + return [_Run(f"{link_text} ({url})", s)] + + transform(_SIG_LINK_RE, _link_runs) + + # Bold (before italic so ** doesn't interfere). + transform(_SIG_BOLD_RE, lambda m, s: [_Run(m.group(1) or m.group(2), s | {"BOLD"})]) + + # Italic (single * or _). + transform(_SIG_ITALIC_RE, lambda m, s: [_Run(m.group(1) or m.group(2), s | {"ITALIC"})]) + + # Strikethrough: ~~text~~ (standard) or ~text~ (single-tilde variant). + transform(_SIG_STRIKE_RE, lambda m, s: [_Run(m.group(1) or m.group(2), s | {"STRIKETHROUGH"})]) + + # Phase 3: assemble output. Offsets and lengths are emitted in UTF-16 code + # units because Signal's BodyRange (via signal-cli's textStyle) interprets + # them as such; Python's len() counts code points, which would shift ranges + # left by 1 unit per non-BMP character preceding them. + plain_text = "" + text_styles: list[str] = [] + utf16_offset = 0 + for run in runs: + if not run.text: + continue + plain_text += run.text + start = utf16_offset + length = _utf16_len(run.text) + utf16_offset += length + for style in sorted(run.styles): + text_styles.append(f"{start}:{length}:{style}") + + return plain_text, text_styles + + +def _partition_styles( + plain_text: str, chunks: list[str], text_styles: list[str] +) -> list[list[str]]: + """Partition Signal textStyle ranges across message chunks. + + ``split_message`` slices ``plain_text`` into pieces (optionally trimming + whitespace at the boundaries), but the style ranges produced by + ``_markdown_to_signal`` are expressed in UTF-16 offsets relative to the + full ``plain_text``. This redistributes them per chunk with offsets + rebased to each chunk's start. Ranges that span a boundary are split + across the chunks they touch; ranges that fall entirely in trimmed + whitespace are dropped. + """ + if not chunks: + return [] + if not text_styles: + return [[] for _ in chunks] + + # Locate each chunk's UTF-16 start in plain_text. split_message lstrips at + # boundaries (but not before the first chunk), so we skip whitespace + # between chunks to mirror that. + chunk_ranges: list[tuple[int, int]] = [] + cursor = 0 # Python codepoint cursor in plain_text + for i, chunk in enumerate(chunks): + if i > 0: + while cursor < len(plain_text) and plain_text[cursor].isspace(): + cursor += 1 + utf16_start = _utf16_len(plain_text[:cursor]) + utf16_end = utf16_start + _utf16_len(chunk) + chunk_ranges.append((utf16_start, utf16_end)) + cursor += len(chunk) + + result: list[list[str]] = [[] for _ in chunks] + for entry in text_styles: + s, ln, style = entry.split(":", 2) + r_start = int(s) + r_end = r_start + int(ln) + for i, (c_start, c_end) in enumerate(chunk_ranges): + if r_end <= c_start or r_start >= c_end: + continue + new_start = max(r_start, c_start) - c_start + new_end = min(r_end, c_end) - c_start + new_length = new_end - new_start + if new_length > 0: + result[i].append(f"{new_start}:{new_length}:{style}") + return result + + +class SignalDMConfig(Base): + """Signal DM policy configuration.""" + + enabled: bool = False + policy: str = "allowlist" # "open" or "allowlist" + allow_from: list[str] = Field(default_factory=list) # Allowed phone numbers/UUIDs + + +class SignalGroupConfig(Base): + """Signal group policy configuration.""" + + enabled: bool = False + policy: str = "allowlist" # "open" or "allowlist" - which groups to operate in + allow_from: list[str] = Field(default_factory=list) # Allowed group IDs if allowlist policy + require_mention: bool = True # Whether bot must be mentioned to respond + + +class SignalConfig(Base): + """Signal channel configuration using signal-cli daemon (HTTP mode with -a flag only).""" + + enabled: bool = False + phone_number: str = "" # Your Signal phone number (e.g., "+1234567890") + daemon_host: str = "localhost" + daemon_port: int = 8080 + group_message_buffer_size: int = 20 # Number of recent group messages to keep for context + # Override the directory signal-cli writes inbound attachments to. When + # None, defaults to ~/.local/share/signal-cli/attachments (the daemon's + # platform default on Linux). Set this if the daemon is running with a + # custom XDG_DATA_HOME or on macOS/Windows where the default path differs. + attachments_dir: str | None = None + dm: SignalDMConfig = Field(default_factory=SignalDMConfig) + group: SignalGroupConfig = Field(default_factory=SignalGroupConfig) + + @field_validator("group_message_buffer_size") + @classmethod + def _validate_buffer_size(cls, v: int) -> int: + if v <= 0: + raise ValueError("group_message_buffer_size must be > 0") + return v + + @computed_field # type: ignore[prop-decorator] + @property + def allow_from(self) -> list[str]: + """Aggregate allowlist for the base-class is_allowed() check. + + Returns the union of dm.allow_from and group.allow_from so the base + channel gate sees a populated list when either sub-policy is configured. + A ``"*"`` wildcard in either sub-list propagates to allow all. + """ + return list(dict.fromkeys(self.dm.allow_from + self.group.allow_from)) + + +class SignalChannel(BaseChannel): + """ + Signal channel using signal-cli daemon via HTTP JSON-RPC interface. + + Requires signal-cli daemon in HTTP mode: + - signal-cli -a +1234567890 daemon --http localhost:8080 + + See https://github.com/AsamK/signal-cli for setup instructions. + """ + + name = "signal" + display_name = "Signal" + _TYPING_REFRESH_SECONDS = 10.0 + _MAX_MESSAGE_LEN = 64_000 # signal-cli practical limit (protocol max ~64 KB) + _HTTP_TIMEOUT_SECONDS = 60.0 + + @classmethod + def default_config(cls) -> dict[str, Any]: + return SignalConfig().model_dump(by_alias=True) + + def __init__(self, config: SignalConfig, bus: MessageBus): + if isinstance(config, dict): + config = SignalConfig.model_validate(config) + super().__init__(config, bus) + self.config: SignalConfig = config + self._http: httpx.AsyncClient | None = None + self._request_id = 0 + self._sse_task: asyncio.Task | None = None + self._typing_tasks: dict[str, asyncio.Task] = {} + self._typing_uuid_warnings: set[str] = set() + self._account_id_aliases: set[str] = set() + self._remember_account_id_alias(self.config.phone_number) + + # Rolling message buffer for group context (group_id -> deque of messages) + # Each message is a dict with: sender_name, sender_number, content, timestamp + self._group_buffers: dict[str, deque] = {} + + def is_allowed(self, sender_id: str) -> bool: + """Override base check to normalize and split pipe-joined identifiers. + + ``sender_id`` from Signal is the pipe-joined composite produced by + ``_collect_sender_id_parts``; allow_from entries may be single + identifiers or composites and may use the ``+`` prefix variant or + not. Delegates to ``_sender_matches_allowlist`` so the base gate + matches the per-policy DM gate. + """ + allow_list = self.config.allow_from + if "*" in allow_list: + return True + if self._sender_matches_allowlist(sender_id, allow_list): + return True + if self._sender_approved_via_pairing(sender_id): + return True + if not allow_list: + self.logger.warning("allow_from is empty — all access denied") + return False + + def _sender_approved_via_pairing(self, sender_id: str) -> bool: + """Return True if any normalized variant of sender_id is in the pairing store. + + Pairing approval may be recorded under any of the identifier forms + signal exposes (phone with/without ``+``, UUID, ACI), so we check + each part of the pipe-joined composite against ``is_approved``. + """ + for part in str(sender_id).split("|"): + for variant in self._normalize_signal_id(part): + if is_approved(self.name, variant): + return True + return False + + async def _handle_message( + self, + sender_id: str, + chat_id: str, + content: str, + media: list[str] | None = None, + metadata: dict[str, Any] | None = None, + session_key: str | None = None, + is_dm: bool = False, + ) -> None: + """Handle an inbound message whose policy has already been checked. + + ``_check_inbound_policy`` is the authoritative gate for DM/group + access, so we skip the base-class ``is_allowed()`` check and publish + directly to the bus. The denied-DM pairing path calls + ``super()._handle_message`` instead, which goes through + ``is_allowed`` and issues a pairing code. + """ + meta = metadata or {} + if self.supports_streaming: + meta = {**meta, "_wants_stream": True} + await self.bus.publish_inbound( + InboundMessage( + channel=self.name, + sender_id=str(sender_id), + chat_id=str(chat_id), + content=content, + media=media or [], + metadata=meta, + session_key_override=session_key, + ) + ) + + async def start(self) -> None: + """Start the Signal channel and connect to signal-cli daemon.""" + if not self.config.phone_number: + self.logger.error("Signal account not configured") + return + + self._running = True + await self._start_http_mode() + + async def _start_http_mode(self) -> None: + """Start Signal channel using Server-Sent Events for receiving messages.""" + base_url = f"http://{self.config.daemon_host}:{self.config.daemon_port}" + reconnect_delay_s = 1.0 + max_reconnect_delay_s = 30.0 + + while self._running: + try: + self.logger.info("Connecting to signal-cli daemon at {}...", base_url) + + # Create HTTP client + self._http = httpx.AsyncClient( + timeout=self._HTTP_TIMEOUT_SECONDS, base_url=base_url + ) + + # Test connection + try: + response = await self._http.get("/api/v1/check") + if response.status_code == 200: + self.logger.info("Connected to signal-cli daemon") + else: + raise ConnectionRefusedError( + f"signal-cli daemon check returned status {response.status_code}" + ) + except Exception as e: + raise ConnectionRefusedError(f"signal-cli daemon not responding: {e}") + + # Reset reconnect delay after successful connection check. + reconnect_delay_s = 1.0 + + # Ensure account-level typing indicators are enabled. + await self._ensure_typing_indicators_enabled() + + # Start SSE receiver and supervise it. If it exits while we're still + # running, treat it as a disconnect and reconnect. + self._sse_task = asyncio.create_task(self._sse_receive_loop()) + await self._sse_task + if self._running: + raise ConnectionError("Signal SSE stream ended unexpectedly") + + except asyncio.CancelledError: + break + except ConnectionRefusedError as e: + self.logger.error( + "{}. Make sure signal-cli daemon is running: " + "signal-cli -a {} daemon --http {}:{}", + e, + self.config.phone_number, + self.config.daemon_host, + self.config.daemon_port, + ) + except Exception as e: + self.logger.error("Signal channel error: {}", e) + finally: + if self._sse_task: + if not self._sse_task.done(): + self._sse_task.cancel() + try: + await self._sse_task + except asyncio.CancelledError: + pass + except Exception: + pass + self._sse_task = None + if self._http: + await self._http.aclose() + self._http = None + + if self._running: + self.logger.info( + "Reconnecting to signal-cli daemon in {:.0f} seconds...", reconnect_delay_s + ) + await asyncio.sleep(reconnect_delay_s) + reconnect_delay_s = min(reconnect_delay_s * 2, max_reconnect_delay_s) + + async def stop(self) -> None: + """Stop the Signal channel.""" + self._running = False + + # Stop SSE task + if self._sse_task: + self._sse_task.cancel() + try: + await self._sse_task + except asyncio.CancelledError: + pass + + # Cancel active typing indicators + for chat_id in list(self._typing_tasks): + await self._stop_typing(chat_id) + + # Close HTTP client + if self._http: + await self._http.aclose() + self._http = None + + async def send(self, msg: OutboundMessage) -> None: + """Send a message through Signal.""" + is_progress_message = bool(msg.metadata.get("_progress")) + try: + plain_text, text_styles = _markdown_to_signal(msg.content) + if not plain_text and not msg.media: + return + recipient_params = self._recipient_params(msg.chat_id) + + chunks = split_message(plain_text, self._MAX_MESSAGE_LEN) if plain_text else [""] + chunk_styles = _partition_styles(plain_text, chunks, text_styles) + for i, chunk in enumerate(chunks): + params: dict[str, Any] = {"message": chunk} + if chunk_styles[i]: + params["textStyle"] = chunk_styles[i] + params.update(recipient_params) + if msg.media and i == 0: + params["attachments"] = msg.media + + response = await self._send_request("send", params) + + if "error" in response: + self.logger.error("Error sending Signal message: {}", response['error']) + raise RuntimeError(f"signal-cli send failed: {response['error']}") + else: + self.logger.debug( + f"Signal message sent, timestamp: {response.get('result', {}).get('timestamp')}" + ) + + except Exception: + self.logger.exception("Error sending Signal message") + raise + finally: + # Keep typing active across progress updates; stop on the final reply. + if not is_progress_message: + # Avoid immediate START->STOP for fast responses, which can be invisible + # in some Signal clients. Let indicator expire naturally (~15s). + await self._stop_typing(msg.chat_id, send_stop=False) + + async def _sse_receive_loop(self) -> None: + """Receive messages via Server-Sent Events (HTTP mode).""" + if not self._http: + raise RuntimeError("HTTP client not initialized for Signal SSE stream") + + self.logger.info("Started Signal message receive loop (SSE)") + + try: + async with self._http.stream("GET", "/api/v1/events") as response: + if response.status_code != 200: + raise ConnectionError( + f"SSE connection failed with status {response.status_code}" + ) + + self.logger.info("Subscribed to Signal messages via SSE") + + # Buffer for accumulating SSE data across multiple lines + event_buffer = [] + + async for line in response.aiter_lines(): + if not self._running: + break + + # Debug: log raw SSE lines (except keepalive pings) + if line and line != ":": + self.logger.debug("SSE line received: {}", line[:200]) + + # SSE format handling + if isinstance(line, str): + # Empty line signals end of event + if not line or line == ":": + if event_buffer: + # Try to parse the accumulated data + data_str = "" + try: + data_str = "\n".join(event_buffer) + data = json.loads(data_str) + self.logger.debug("SSE event parsed: {}", data) + await self._handle_receive_notification(data) + except json.JSONDecodeError as e: + self.logger.warning( + "Invalid JSON in SSE buffer: {}, data: {}", + e, + data_str[:200], + ) + finally: + event_buffer = [] + + # "data:" line - accumulate it + elif line.startswith("data:"): + # SSE spec: strip one optional leading space after "data:". + event_buffer.append(line[6:] if line[5:6] == " " else line[5:]) + + # "event:" line - just log it (we only care about data) + elif line.startswith("event:"): + pass # Ignore event type for now + + if self._running: + raise ConnectionError("Signal SSE stream closed by remote endpoint") + + except asyncio.CancelledError: + self.logger.info("SSE receive loop cancelled") + raise + except Exception as e: + self.logger.error("Error in SSE receive loop: {}", e) + raise + + @asynccontextmanager + async def _safe_handle(self, action: str, payload: Any = None) -> AsyncIterator[None]: + """Swallow and log any exception from a top-level handler block. + + Logs `self.logger.error` with the action name, the exception, and a + bounded ``repr`` of the offending payload so the offending input is + recoverable from logs without having to correlate by timestamp. + """ + try: + yield + except Exception as e: + snippet = repr(payload)[:200] if payload is not None else "" + text = f"Error in {action}: {e}" + if snippet: + text += f" | payload={snippet}" + self.logger.opt(exception=True).error(text) + + async def _handle_receive_notification(self, params: dict[str, Any]) -> None: + """Handle incoming message notification from signal-cli.""" + self.logger.debug("_handle_receive_notification called with: {}", params) + async with self._safe_handle("receive notification", params): + # Extract envelope from SSE notification: {"envelope": {...}} + envelope = params.get("envelope", {}) + + self.logger.debug("Extracted envelope: {}", envelope) + + if not envelope: + self.logger.debug("No envelope found in params") + return + + # Extract sender information + sender_parts = self._collect_sender_id_parts(envelope) + source_name = envelope.get("sourceName") + + if not sender_parts: + self.logger.debug("Received message without source, skipping") + return + + sender_number = self._primary_sender_id(sender_parts) + sender_id = "|".join(sender_parts) + + # Keep aliases of the bot account for robust mention matching. + if any(self._id_matches_account(part) for part in sender_parts): + for part in sender_parts: + self._remember_account_id_alias(part) + + # Check different message types + data_message = envelope.get("dataMessage") + sync_message = envelope.get("syncMessage") + typing_message = envelope.get("typingMessage") + receipt_message = envelope.get("receiptMessage") + + # Ignore receipt messages (delivery/read receipts) + if receipt_message: + return + + # Handle data messages (incoming messages from others) + if data_message: + await self._handle_data_message(sender_id, sender_number, data_message, source_name) + + # Handle sync messages (messages sent from another device) + elif sync_message and sync_message.get("sentMessage"): + sent_msg = sync_message["sentMessage"] + destination = sent_msg.get("destination") or sent_msg.get("destinationNumber") + if destination: + self.logger.debug( + "Sync message sent to {}: {}", destination, sent_msg.get("message", "")[:50] + ) + + # Handle typing indicators (silently ignore) + elif typing_message: + pass # Ignore typing indicators + + async def _handle_data_message( + self, + sender_id: str, + sender_number: str, + data_message: dict[str, Any], + sender_name: str | None, + ) -> None: + """Handle a data message (text, attachments, etc.).""" + message_text = data_message.get("message") or "" + attachments = data_message.get("attachments", []) + mentions = data_message.get("mentions", []) + timestamp = data_message.get("timestamp") + + self.logger.info( + "Data message from {}: groupInfo={}, groupV2={}, keys={}", + sender_number, + data_message.get("groupInfo"), + data_message.get("groupV2"), + list(data_message.keys()), + ) + + if data_message.get("reaction"): + self.logger.debug( + "Ignoring reaction message from {}: {}", sender_number, data_message["reaction"] + ) + return + if not message_text and not attachments: + self.logger.debug("Ignoring empty message from {}", sender_number) + return + + group_info = data_message.get("groupInfo") + group_v2 = data_message.get("groupV2") + is_group_message = group_info is not None or group_v2 is not None + group_id = self._extract_group_id(group_info, group_v2) + + allowed, chat_id = self._check_inbound_policy( + sender_id=sender_id, + sender_number=sender_number, + group_id=group_id, + is_group_message=is_group_message, + message_text=message_text, + mentions=mentions, + sender_name=sender_name, + timestamp=timestamp, + ) + if not allowed: + # Mirror Slack: let denied DMs reach the base-class + # _handle_message so it can reply with a pairing code. + # Group denials stay dropped. + if not is_group_message and self.config.dm.enabled: + await super()._handle_message( + sender_id=sender_id, + chat_id=chat_id, + content="", + is_dm=True, + ) + return + + content, media_paths = self._assemble_inbound_content( + sender_name=sender_name, + sender_number=sender_number, + message_text=message_text, + attachments=attachments, + mentions=mentions, + is_group_message=is_group_message, + chat_id=chat_id, + ) + + self.logger.debug("Signal message from {}: {}...", sender_number, content[:50]) + + await self._start_typing(chat_id) + try: + await self._handle_message( + sender_id=sender_id, + chat_id=chat_id, + content=content, + media=media_paths, + metadata={ + "timestamp": timestamp, + "sender_name": sender_name, + "sender_number": sender_number, + "is_group": is_group_message, + "group_id": group_id, + }, + is_dm=not is_group_message, + ) + except Exception: + await self._stop_typing(chat_id) + raise + + def _check_inbound_policy( + self, + *, + sender_id: str, + sender_number: str, + group_id: str | None, + is_group_message: bool, + message_text: str, + mentions: list, + sender_name: str | None, + timestamp: int | None, + ) -> tuple[bool, str]: + """Decide whether to route an inbound message past DM/group policy. + + Returns ``(allow, chat_id)``. Has one side effect: when a group + message passes the enabled+allowlist gates, it is appended to the + group's rolling context buffer before the mention check. + """ + if is_group_message: + chat_id = group_id or sender_number + if not self.config.group.enabled: + self.logger.info("Ignoring group message from {} (groups disabled)", chat_id) + return False, chat_id + if ( + self.config.group.policy == "allowlist" + and chat_id not in self.config.group.allow_from + ): + self.logger.info( + "Ignoring group message from {} (policy: {})", + chat_id, + self.config.group.policy, + ) + return False, chat_id + + self._add_to_group_buffer( + group_id=chat_id, + sender_name=sender_name or sender_number, + sender_number=sender_number, + message_text=message_text, + timestamp=timestamp, + ) + + is_command = bool(message_text and message_text.strip().startswith("/")) + if not is_command and not self._should_respond_in_group(message_text, mentions): + self.logger.info( + "Ignoring group message (require_mention: {})", + self.config.group.require_mention, + ) + return False, chat_id + return True, chat_id + + # Direct message + chat_id = sender_number + if not self.config.dm.enabled: + self.logger.debug("Ignoring DM from {} (DMs disabled)", sender_id) + return False, chat_id + if self.config.dm.policy == "allowlist": + if not self._sender_matches_allowlist(sender_id, self.config.dm.allow_from): + self.logger.debug( + "Ignoring DM from {} (policy: {})", sender_id, self.config.dm.policy + ) + return False, chat_id + return True, chat_id + + def _assemble_inbound_content( + self, + *, + sender_name: str | None, + sender_number: str, + message_text: str, + attachments: list, + mentions: list, + is_group_message: bool, + chat_id: str, + ) -> tuple[str, list[str]]: + """Build ``(content, media_paths)`` for an inbound message. + + Pulls in group context, strips bot mentions, prefixes the sender's + display name on group messages, and copies any attachments from + signal-cli's storage into the channel media dir. + """ + content_parts: list[str] = [] + media_paths: list[str] = [] + + if is_group_message: + buffer_context = self._get_group_buffer_context(chat_id) + if buffer_context: + content_parts.append(f"[Recent group messages for context:]\n{buffer_context}\n---") + + if message_text: + if is_group_message: + message_text = self._strip_bot_mention(message_text, mentions) + display_name = sender_name or sender_number + message_text = f"[{display_name}]: {message_text}" + content_parts.append(message_text) + + if attachments: + media_dir = get_media_dir("signal") + for attachment in attachments: + attachment_id = attachment.get("id") + content_type = attachment.get("contentType", "") + filename = attachment.get("filename") or f"attachment_{attachment_id}" + if not attachment_id: + continue + try: + source_path = self._signal_attachments_dir() / attachment_id + if source_path.exists(): + dest_path = media_dir / f"signal_{safe_filename(filename)}" + shutil.copy2(source_path, dest_path) + media_paths.append(str(dest_path)) + media_type = content_type.split("/")[0] if "/" in content_type else "file" + if media_type not in ("image", "audio", "video"): + media_type = "file" + content_parts.append(f"[{media_type}: {dest_path}]") + self.logger.debug("Downloaded attachment: {} -> {}", filename, dest_path) + else: + self.logger.warning("Attachment not found: {}", source_path) + content_parts.append(f"[attachment: {filename} - not found]") + except Exception as e: + self.logger.warning("Failed to process attachment {}: {}", filename, e) + content_parts.append(f"[attachment: {filename} - error]") + + content = "\n".join(content_parts) if content_parts else "[empty message]" + return content, media_paths + + def _add_to_group_buffer( + self, + group_id: str, + sender_name: str, + sender_number: str, + message_text: str, + timestamp: int | None, + ) -> None: + """ + Add a message to the group's rolling buffer. + + Args: + group_id: The group ID + sender_name: Display name of sender + sender_number: Phone number of sender + message_text: The message content + timestamp: Message timestamp + """ + # Create buffer for this group if it doesn't exist + if group_id not in self._group_buffers: + self._group_buffers[group_id] = deque(maxlen=self.config.group_message_buffer_size) + + # Add message to buffer (deque will automatically drop oldest when full) + self._group_buffers[group_id].append( + { + "sender_name": sender_name, + "sender_number": sender_number, + "content": message_text, + "timestamp": timestamp, + } + ) + + self.logger.debug( + "Added message to group buffer {}: {}/{}", + group_id, + len(self._group_buffers[group_id]), + self.config.group_message_buffer_size, + ) + + def _get_group_buffer_context(self, group_id: str) -> str: + """ + Get formatted context from the group's message buffer. + + Args: + group_id: The group ID + + Returns: + Formatted string of recent messages (excluding the current one) + """ + if group_id not in self._group_buffers: + return "" + + buffer = self._group_buffers[group_id] + if len(buffer) <= 1: # Only current message, no context + return "" + + # Format all messages except the last one (which is the current message) + # We want to show context BEFORE the mention + context_messages = list(buffer)[:-1] # Exclude the last (current) message + + lines = [] + for msg in context_messages: + sender = msg["sender_name"] + content = msg["content"][:200] # Limit to 200 chars per message + lines.append(f"{sender}: {content}") + + return "\n".join(lines) + + def _signal_attachments_dir(self) -> Path: + """Return the directory signal-cli writes inbound attachments to. + + Defaults to ``~/.local/share/signal-cli/attachments`` (the daemon's + platform default on Linux) when ``config.attachments_dir`` is unset. + """ + configured = self.config.attachments_dir + if configured: + return Path(configured).expanduser() + return Path.home() / ".local/share/signal-cli/attachments" + + @staticmethod + def _normalize_signal_id(value: str) -> list[str]: + """Normalize Signal identifiers (phone/uuid/service-id) for matching.""" + raw = value.strip() + if not raw: + return [] + + normalized = [raw, raw.lower()] + if raw.startswith("+") and len(raw) > 1: + normalized.append(raw[1:]) + elif raw.isdigit(): + normalized.append(f"+{raw}") + return list(dict.fromkeys(normalized)) + + @classmethod + def _sender_matches_allowlist(cls, sender_id: str, allow_list: list[str]) -> bool: + """Return True if any normalized variant of sender_id is on allow_list. + + Both ``sender_id`` and each allow_list entry can be a single + identifier or a pipe-joined composite of several (e.g. + ``"+1234567890|uuid-abc"``); both sides are split on ``|`` and each + part is run through ``_normalize_signal_id`` so an allowlist entry + like ``1234567890`` matches a sender ``+1234567890`` (and vice + versa), and case-only differences in UUIDs/ACIs match too. + """ + if not allow_list: + return False + sender_variants: set[str] = set() + for part in str(sender_id).split("|"): + sender_variants.update(cls._normalize_signal_id(part)) + if not sender_variants: + return False + allow_variants: set[str] = set() + for entry in allow_list: + for part in str(entry).split("|"): + allow_variants.update(cls._normalize_signal_id(part)) + return bool(sender_variants & allow_variants) + + def _remember_account_id_alias(self, value: str | None) -> None: + """Remember known bot identifiers for mention matching.""" + if not value: + return + if not isinstance(value, str): + return + for candidate in self._normalize_signal_id(value): + self._account_id_aliases.add(candidate) + + def _id_matches_account(self, value: str | None) -> bool: + """Return True when an identifier refers to the bot account.""" + if not value: + return False + if not isinstance(value, str): + return False + return any( + candidate in self._account_id_aliases for candidate in self._normalize_signal_id(value) + ) + + @staticmethod + def _collect_sender_id_parts(envelope: dict[str, Any]) -> list[str]: + """Collect all known sender identifier variants from an envelope.""" + parts: list[str] = [] + for key in ( + "sourceNumber", + "source", + "sourceUuid", + "sourceServiceId", + "sourceAci", + "sourceACI", + ): + value = envelope.get(key) + if not isinstance(value, str): + continue + candidate = value.strip() + if candidate and candidate not in parts: + parts.append(candidate) + return parts + + @staticmethod + def _primary_sender_id(sender_parts: list[str]) -> str: + """Pick the best sender identifier for routing (prefer phone-like IDs).""" + for part in sender_parts: + if part.startswith("+") or part.isdigit(): + return part + return sender_parts[0] if sender_parts else "" + + @staticmethod + def _extract_group_id(group_info: Any, group_v2: Any) -> str | None: + """Extract group ID from groupInfo/groupV2 payloads across signal-cli variants.""" + for group_obj in (group_info, group_v2): + if not isinstance(group_obj, dict): + continue + for key in ("groupId", "id", "groupID"): + value = group_obj.get(key) + if isinstance(value, str) and value: + return value + return None + + @staticmethod + def _mention_id_candidates(mention: dict[str, Any]) -> list[str]: + """Extract possible identifier fields from a mention payload.""" + ids: list[str] = [] + + def _walk(value: dict[str, Any] | Any, depth: int = 0) -> None: + if depth > 2: + return + if not isinstance(value, dict): + return + for key, child in value.items(): + key_lower = str(key).lower() + if isinstance(child, str) and child: + if any(token in key_lower for token in ("number", "uuid", "serviceid", "aci")): + ids.append(child) + elif isinstance(child, dict): + _walk(child, depth + 1) + + _walk(mention) + return list(dict.fromkeys(ids)) + + @staticmethod + def _mention_span(mention: dict[str, Any]) -> tuple[int, int] | None: + """Extract a safe (start, length) span from a mention.""" + try: + start = int(mention.get("start", 0)) + length = int(mention.get("length", 0)) + except (TypeError, ValueError): + return None + + if start < 0 or length <= 0: + return None + return (start, length) + + @staticmethod + def _leading_placeholder_span(text: str | None) -> tuple[int, int] | None: + """ + Detect a leading Signal mention placeholder when mention metadata is missing. + + Some clients/integrations deliver mentions as a leading placeholder character + (typically U+FFFC) but omit `mentions` metadata in the payload. + """ + if not text: + return None + + start = 0 + while start < len(text) and text[start].isspace(): + start += 1 + + if start >= len(text): + return None + + marker = text[start] + if marker not in ("\ufffc", "\ufffd", "\x1b"): + return None + + next_index = start + 1 + if next_index < len(text) and not text[next_index].isspace(): + return None + + return (start, 1) + + def _should_respond_in_group(self, message_text: str, mentions: list[dict[str, Any]]) -> bool: + """ + Determine if the bot should respond to a group message. + + Args: + message_text: The message text content + mentions: List of mentions from Signal (format: [{"number": "+1234567890", "start": 0, "length": 10}]) + + Returns: + True if bot should respond, False otherwise + """ + # Group reply behavior is controlled only by group.require_mention. + if not self.config.group.require_mention: + return True + + # If mention is required, check if bot was mentioned. + for mention in mentions: + if not isinstance(mention, dict): + continue + for mention_id in self._mention_id_candidates(mention): + if self._id_matches_account(mention_id): + return True + + # Some Signal clients emit mention spans without recipient identifiers + # (for handle-style mentions). Accept a leading identifier-less mention + # as a mention of the bot to avoid false negatives. + for mention in mentions: + if not isinstance(mention, dict): + continue + if self._mention_id_candidates(mention): + continue + span = self._mention_span(mention) + if not span: + continue + start, _ = span + if message_text is not None and not message_text[:start].strip(): + self.logger.debug("Accepting identifier-less leading mention as bot mention") + return True + + # Some payloads omit `mentions` but still include the leading mention + # placeholder character in the message body. + if not mentions and self._leading_placeholder_span(message_text): + self.logger.debug("Accepting leading placeholder mention without mention metadata") + return True + + # Fallback: check for configured phone number in plain text. + if message_text and self.config.phone_number: + for account_id in self._normalize_signal_id(self.config.phone_number): + if account_id and account_id in message_text: + return True + + return False + + def _strip_bot_mention(self, text: str, mentions: list[dict[str, Any]]) -> str: + """ + Remove bot mentions from message text. + + Signal mentions are embedded in the text, so we need to remove them based on + the mentions array which provides start position and length. + + Args: + text: Original message text + mentions: List of mention objects with start/length positions + + Returns: + Text with bot mentions removed + """ + if not text: + return text + + # Build a list of (start, length) tuples for our bot's mentions + bot_mentions = [] + for mention in mentions: + if not isinstance(mention, dict): + continue + mention_ids = self._mention_id_candidates(mention) + span = self._mention_span(mention) + if not span: + continue + + # Strip matched bot mentions by ID. + if any(self._id_matches_account(mention_id) for mention_id in mention_ids): + bot_mentions.append(span) + continue + + # Also strip identifier-less leading mention spans (handle mentions). + if not mention_ids: + start, _ = span + if not text[:start].strip(): + bot_mentions.append(span) + + if not bot_mentions: + placeholder_span = self._leading_placeholder_span(text) + if placeholder_span: + bot_mentions.append(placeholder_span) + + # Sort mentions by start position (descending) to remove from end to start + # This prevents position shifts when removing earlier mentions + bot_mentions.sort(reverse=True) + + # Remove each mention + for start, length in bot_mentions: + if start >= len(text): + continue + end = min(len(text), start + length) + text = text[:start] + text[end:] + + return text.strip() + + @staticmethod + def _is_group_chat_id(chat_id: str) -> bool: + """Return True when chat_id appears to be a Signal group ID (base64).""" + return "=" in chat_id or (len(chat_id) > 40 and "-" not in chat_id) + + def _recipient_params(self, chat_id: str) -> dict[str, Any]: + """Build recipient params for signal-cli JSON-RPC methods.""" + if self._is_group_chat_id(chat_id): + return {"groupId": chat_id} + return {"recipient": [chat_id]} + + async def _start_typing(self, chat_id: str) -> None: + """Start periodic typing indicator updates for a chat.""" + await self._stop_typing(chat_id, send_stop=False) + await self._send_typing(chat_id) + self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id)) + + async def _stop_typing(self, chat_id: str, send_stop: bool = True) -> None: + """Stop typing indicator updates for a chat.""" + task = self._typing_tasks.pop(chat_id, None) + had_task = task is not None + if task and not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + if send_stop and had_task: + await self._send_typing(chat_id, stop=True) + + async def _typing_loop(self, chat_id: str) -> None: + """Send typing updates periodically until cancelled.""" + try: + while self._running: + await asyncio.sleep(self._TYPING_REFRESH_SECONDS) + await self._send_typing(chat_id, quiet_success=True) + except asyncio.CancelledError: + pass + except Exception as e: + self.logger.debug("Typing indicator loop stopped for {}: {}", chat_id, e) + + async def _send_typing( + self, chat_id: str, stop: bool = False, quiet_success: bool = False + ) -> None: + """Send a typing START/STOP message via signal-cli.""" + action = "stop" if stop else "start" + if ( + not self._is_group_chat_id(chat_id) + and chat_id.startswith("+") is False + and chat_id not in self._typing_uuid_warnings + ): + self._typing_uuid_warnings.add(chat_id) + self.logger.warning( + "Signal DM recipient is UUID-only (no phone number in envelope). " + "Some Signal clients may not render typing indicators for this recipient form." + ) + candidate_params: list[dict[str, Any]] + if self._is_group_chat_id(chat_id): + candidate_params = [{"groupId": chat_id}, {"groupId": [chat_id]}] + else: + candidate_params = [{"recipient": chat_id}, {"recipient": [chat_id]}] + + last_error: Any | None = None + for params in candidate_params: + if stop: + params["stop"] = True + try: + response = await self._send_request("sendTyping", params) + except Exception as e: + last_error = str(e) + continue + + if "error" not in response: + if not quiet_success: + self.logger.info("Signal typing {} sent for {}", action, chat_id) + return + + last_error = response["error"] + + self.logger.warning( + "Failed to send Signal typing {} for {}: {}", action, chat_id, last_error + ) + + async def _ensure_typing_indicators_enabled(self) -> None: + """Enable typing indicators on the bot account.""" + response = await self._send_request("updateConfiguration", {"typingIndicators": True}) + if "error" in response: + self.logger.warning( + "Failed to enable Signal typing indicators: {}", response["error"] + ) + else: + self.logger.info("Signal typing indicators enabled on account configuration") + + async def _send_request( + self, method: str, params: dict[str, Any] | None = None + ) -> dict[str, Any]: + """Send a JSON-RPC request via HTTP and wait for response.""" + # Generate request ID + self._request_id += 1 + request_id = self._request_id + + # Build JSON-RPC request + request = {"jsonrpc": "2.0", "method": method, "id": request_id} + + if params: + request["params"] = params + + return await self._send_http_request(request) + + async def _send_http_request(self, request: dict[str, Any]) -> dict[str, Any]: + """Send JSON-RPC request via HTTP.""" + if not self._http: + raise RuntimeError("Not connected to signal-cli daemon") + + try: + response = await self._http.post("/api/v1/rpc", json=request) + response.raise_for_status() + return response.json() + except Exception as e: + self.logger.error("HTTP request failed: {}", e) + return {"error": {"message": str(e)}} diff --git a/nanobot/channels/weixin.py b/nanobot/channels/weixin.py index 41390f8b3..a75c897f4 100644 --- a/nanobot/channels/weixin.py +++ b/nanobot/channels/weixin.py @@ -79,6 +79,12 @@ BASE_INFO: dict[str, str] = {"channel_version": WEIXIN_CHANNEL_VERSION} ERRCODE_SESSION_EXPIRED = -14 SESSION_PAUSE_DURATION_S = 60 * 60 +# iLink context_token is observed to expire server-side after ~90-160s of +# agent inactivity (openclaw/openclaw#61174). Proactively refresh before +# sending if the cached token is older than this threshold. +CONTEXT_TOKEN_MAX_AGE_S = 60 + + # Retry constants (matching the reference plugin's monitor.ts) MAX_CONSECUTIVE_FAILURES = 3 BACKOFF_DELAY_S = 30 @@ -159,6 +165,8 @@ class WeixinChannel(BaseChannel): self._session_pause_until: float = 0.0 self._typing_tasks: dict[str, asyncio.Task] = {} self._typing_tickets: dict[str, dict[str, Any]] = {} + self._context_token_at: dict[str, float] = {} + self._pending_tool_hints: dict[str, list[str]] = {} # ------------------------------------------------------------------ # State persistence @@ -486,6 +494,7 @@ class WeixinChannel(BaseChannel): except Exception: if not self._running: break + self.logger.exception("WeChat poll loop error") consecutive_failures += 1 if consecutive_failures >= MAX_CONSECUTIVE_FAILURES: consecutive_failures = 0 @@ -495,6 +504,7 @@ class WeixinChannel(BaseChannel): async def stop(self) -> None: self._running = False + self._pending_tool_hints.clear() if self._poll_task and not self._poll_task.done(): self._poll_task.cancel() for chat_id in list(self._typing_tasks): @@ -545,6 +555,7 @@ class WeixinChannel(BaseChannel): # Check for API-level errors (monitor.ts checks both ret and errcode) ret = data.get("ret", 0) errcode = data.get("errcode", 0) + is_error = (ret is not None and ret != 0) or (errcode is not None and errcode != 0) if is_error: @@ -575,8 +586,10 @@ class WeixinChannel(BaseChannel): # Process messages (WeixinMessage[] from types.ts) msgs: list[dict] = data.get("msgs", []) or [] for msg in msgs: - with suppress(Exception): + try: await self._process_message(msg) + except Exception: + self.logger.exception("Failed to process WeChat message") # ------------------------------------------------------------------ # Inbound message processing (matches inbound.ts + process-message.ts) @@ -610,6 +623,7 @@ class WeixinChannel(BaseChannel): ctx_token = msg.get("context_token", "") if ctx_token: self._context_tokens[from_user_id] = ctx_token + self._context_token_at[from_user_id] = time.time() self._save_state() # Parse item_list (WeixinMessage.item_list — types.ts:161) @@ -915,6 +929,99 @@ class WeixinChannel(BaseChannel): } return "" + async def _refresh_context_token_if_stale( + self, chat_id: str, context_token: str + ) -> str: + """Return a fresh context_token if the cached one is too old. + + iLink context_token expires server-side after a short idle period + (empirically ~90s). Proactively refreshing before sending prevents + silent message loss on long agent turns or cron pushes. + """ + if not context_token: + return context_token + + now = time.time() + cached_at = self._context_token_at.get(chat_id, 0) + age = now - cached_at + + if age < CONTEXT_TOKEN_MAX_AGE_S: + return context_token + + self.logger.debug( + "WeChat context_token for {} is {:.0f}s old; refreshing via getconfig", + chat_id, + age, + ) + + body: dict[str, Any] = { + "ilink_user_id": chat_id, + "context_token": context_token, + "base_info": BASE_INFO, + } + try: + data = await self._api_post("ilink/bot/getconfig", body) + except Exception as e: + self.logger.warning("WeChat getconfig failed for {}: {}", chat_id, e) + return context_token + + if data.get("ret", 0) != 0: + self.logger.warning( + "WeChat getconfig returned ret={} for {}: {}", + data.get("ret"), + chat_id, + data.get("errmsg", ""), + ) + return context_token + + new_token = str(data.get("context_token", "") or "") + if new_token and new_token != context_token: + self.logger.info( + "WeChat context_token refreshed for {} (age {:.0f}s -> fresh)", + chat_id, + age, + ) + self._context_tokens[chat_id] = new_token + self._context_token_at[chat_id] = now + self._save_state() + return new_token + + return context_token + + async def _flush_tool_hints(self, chat_id: str) -> None: + """Send any buffered tool hints for *chat_id* as a single message. + + Tool hints are coalesced to reduce message count and avoid hitting the + WeChat iLink rate limit (~7 msgs / 5 min). Failures are logged but + not raised so that the main message send is never blocked. + """ + hints = self._pending_tool_hints.pop(chat_id, None) + if not hints: + return + + self.logger.info( + "Flushing {} buffered tool hint(s) for {}", + len(hints), + chat_id, + ) + + ctx_token = self._context_tokens.get(chat_id, "") + ctx_token = await self._refresh_context_token_if_stale(chat_id, ctx_token) + if not ctx_token: + self.logger.warning( + "Dropped {} buffered tool hint(s) for {}: no context_token", + len(hints), + chat_id, + ) + return + + try: + await self._send_text(chat_id, "\n\n".join(hints), ctx_token) + except Exception: + self.logger.exception( + "Failed to flush buffered tool hints for {}", chat_id + ) + async def _send_typing(self, user_id: str, typing_ticket: str, status: int) -> None: """Best-effort sendtyping wrapper.""" if not typing_ticket: @@ -944,11 +1051,47 @@ class WeixinChannel(BaseChannel): self._assert_session_active() is_progress = bool((msg.metadata or {}).get("_progress", False)) + + # Buffer tool hints to coalesce consecutive ones and avoid burning + # WeChat iLink rate-limit quota (~7 msgs / 5 min). + if is_progress and (msg.metadata or {}).get("_tool_hint"): + if not self.send_tool_hints: + return + self._pending_tool_hints.setdefault(msg.chat_id, []).append(msg.content) + self.logger.debug( + "Buffered tool hint for {} (count={})", + msg.chat_id, + len(self._pending_tool_hints[msg.chat_id]), + ) + return + + # Reasoning deltas are invisible in WeChat (there is no reasoning + # UI). Skip them entirely — do not send and do not flush buffer. + if is_progress and (msg.metadata or {}).get("_reasoning_delta"): + self.logger.debug( + "Dropped invisible reasoning delta for {}", msg.chat_id + ) + return + + content = msg.content.strip() + + # Empty progress messages (e.g. after_iteration tool_events) must + # NOT act as separators — they have no visible content. + if is_progress and not content and not (msg.media or []): + self.logger.debug( + "Skipped empty progress message for {} (no visible content)", + msg.chat_id, + ) + return + + # Flush buffered hints before sending any visible message. + await self._flush_tool_hints(msg.chat_id) + if not is_progress: await self._stop_typing(msg.chat_id, clear_remote=True) - content = msg.content.strip() ctx_token = self._context_tokens.get(msg.chat_id, "") + ctx_token = await self._refresh_context_token_if_stale(msg.chat_id, ctx_token) if not ctx_token: raise RuntimeError( f"WeChat context_token missing for chat_id={msg.chat_id}, cannot send" @@ -1037,6 +1180,18 @@ class WeixinChannel(BaseChannel): with suppress(Exception): await self._send_typing(msg.chat_id, typing_ticket, TYPING_STATUS_CANCEL) + async def send_delta( + self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None + ) -> None: + """Weixin iLink does not support native streaming deltas. + + We only hook ``_stream_end`` so buffered tool hints are flushed even + when the final answer carries the ``_streamed`` flag and bypasses + :meth:`send`. + """ + if metadata and metadata.get("_stream_end"): + await self._flush_tool_hints(chat_id) + async def _start_typing(self, chat_id: str, context_token: str = "") -> None: """Start typing indicator immediately when a message is received.""" if not self._client or not self._token or not chat_id: @@ -1120,10 +1275,11 @@ class WeixinChannel(BaseChannel): } data = await self._api_post("ilink/bot/sendmessage", body) + ret = data.get("ret", 0) errcode = data.get("errcode", 0) - if errcode and errcode != 0: + if (ret is not None and ret != 0) or (errcode is not None and errcode != 0): raise RuntimeError( - f"WeChat send text error (code {errcode}): {data.get('errmsg', '')}" + f"WeChat send text error (ret={ret}, errcode={errcode}): {data.get('errmsg', '')}" ) async def _send_media_file( @@ -1270,10 +1426,11 @@ class WeixinChannel(BaseChannel): } data = await self._api_post("ilink/bot/sendmessage", body) + ret = data.get("ret", 0) errcode = data.get("errcode", 0) - if errcode and errcode != 0: + if (ret is not None and ret != 0) or (errcode is not None and errcode != 0): raise RuntimeError( - f"WeChat send media error (code {errcode}): {data.get('errmsg', '')}" + f"WeChat send media error (ret={ret}, errcode={errcode}): {data.get('errmsg', '')}" ) diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index c0ad7e758..2e094cc09 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -211,6 +211,7 @@ class ProvidersConfig(Base): ant_ling: ProviderConfig = Field(default_factory=ProviderConfig) # Ant Ling aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动) + novita: ProviderConfig = Field(default_factory=ProviderConfig) # Novita AI volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎) volcengine_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine Coding Plan byteplus: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus (VolcEngine international) diff --git a/nanobot/providers/image_generation.py b/nanobot/providers/image_generation.py index 3ea8c374a..6cab279b1 100644 --- a/nanobot/providers/image_generation.py +++ b/nanobot/providers/image_generation.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import base64 import binascii import re @@ -898,6 +899,426 @@ def _minimax_images_from_payload(payload: dict[str, Any]) -> list[str]: return images +# --------------------------------------------------------------------------- +# OpenAI image generation +# --------------------------------------------------------------------------- + +_OPENAI_DALLE2_SUPPORTED_SIZES = {"256x256", "512x512", "1024x1024"} +_OPENAI_DALLE3_SUPPORTED_SIZES = {"1024x1024", "1792x1024", "1024x1792"} +_OPENAI_GPT_IMAGE_SUPPORTED_SIZES = { + "1024x1024", + "1536x1024", + "1024x1536", + "auto", +} +_OPENAI_DALLE2_ASPECT_RATIO_SIZES = { + "1:1": "1024x1024", + "16:9": "1024x1024", + "9:16": "1024x1024", + "3:4": "1024x1024", + "4:3": "1024x1024", +} +_OPENAI_DALLE3_ASPECT_RATIO_SIZES = { + "1:1": "1024x1024", + "16:9": "1792x1024", + "9:16": "1024x1792", + "3:4": "1024x1792", + "4:3": "1792x1024", +} +_OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES = { + "1:1": "1024x1024", + "16:9": "1536x1024", + "9:16": "1024x1536", + "3:4": "1024x1536", + "4:3": "1536x1024", +} + + +class OpenAIImageGenerationClient(ImageGenerationProvider): + """OpenAI Images API using an API key (``providers.openai.apiKey``).""" + + provider_name = "openai" + missing_key_message = ( + "OpenAI API key is not configured. Set providers.openai.apiKey." + ) + + def _default_base_url(self) -> str: + return "https://api.openai.com/v1" + + @staticmethod + def _strip_model_prefix(model: str) -> str: + """Remove ``openai/`` prefix if present (OpenRouter convention).""" + if model.startswith("openai/") or model.startswith("openai_codex/"): + return model.split("/", 1)[1] + return model + + async def generate( + self, + *, + prompt: str, + model: str, + reference_images: list[str] | None = None, + aspect_ratio: str | None = None, + image_size: str | None = None, + ) -> GeneratedImageResponse: + if not self.api_key: + raise ImageGenerationError(self.missing_key_message) + + if reference_images: + logger.warning( + "DALL-E models do not support reference images; " + "ignoring {} reference image(s) for {}", + len(reference_images), + model, + ) + + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + **self.extra_headers, + } + + clean_model = self._strip_model_prefix(model) + body: dict[str, Any] = { + "model": clean_model, + "prompt": prompt, + } + + if not _openai_is_gpt_image_model(clean_model): + body["response_format"] = "b64_json" + body["n"] = 1 + + size = _openai_size(clean_model, aspect_ratio, image_size) + if size: + body["size"] = size + + body.update(self.extra_body) + + logger.info("OpenAI Images API request: POST {}/images/generations body={}", self.api_base, body) + + response = await self._http_post( + f"{self.api_base}/images/generations", + headers=headers, + body=body, + ) + + try: + response.raise_for_status() + except httpx.HTTPStatusError as exc: + detail = response.text[:1000] + logger.error("OpenAI Images API error ({}): {}", response.status_code, detail) + raise ImageGenerationError( + f"OpenAI image generation failed (HTTP {response.status_code}): {detail}" + ) from exc + + payload = response.json() + logger.info("OpenAI Images API response ({}): {}", response.status_code, + {k: v for k, v in payload.items() if k != "data"}) + + client = self._client + owns_client = client is None + if owns_client: + client = httpx.AsyncClient(timeout=self.timeout) + try: + images = await _openai_images_from_payload(client, payload) + finally: + if owns_client: + await client.aclose() + + self._require_images(images, payload) + + return GeneratedImageResponse(images=images, content="", raw=payload) + + +# --------------------------------------------------------------------------- +# OpenAI Codex image generation +# --------------------------------------------------------------------------- + + +class CodexImageGenerationClient(ImageGenerationProvider): + """OpenAI image generation via Codex subscription OAuth. + + Uses the Codex Responses API with the ``image_generation`` tool + (the same mechanism ChatGPT uses internally). No API key required — + the Codex OAuth token from ``oauth_cli_kit`` is used instead. + """ + + provider_name = "openai_codex" + missing_key_message = ( + "Codex OAuth token is unavailable. " + "Log in with Codex subscription first." + ) + + def _default_base_url(self) -> str: + return "https://chatgpt.com/backend-api" + + def _codex_model(self, model: str) -> str: + """Strip the ``openai-codex/`` prefix if present.""" + if model.startswith(("openai-codex/", "openai_codex/")): + return model.split("/", 1)[1] + return model + + async def generate( + self, + *, + prompt: str, + model: str, + reference_images: list[str] | None = None, + aspect_ratio: str | None = None, + image_size: str | None = None, + ) -> GeneratedImageResponse: + try: + from oauth_cli_kit import get_token as get_codex_token + except ImportError: + raise ImageGenerationError(self.missing_key_message) + + try: + token = await asyncio.to_thread(get_codex_token) + except Exception as exc: + raise ImageGenerationError(self.missing_key_message) from exc + if not token or not token.access: + raise ImageGenerationError(self.missing_key_message) + + logger.info( + "Using Codex OAuth token for image generation (account: {})", + token.account_id, + ) + + if reference_images: + logger.warning( + "Codex image generation does not support reference images; " + "ignoring {} reference image(s)", + len(reference_images), + ) + + headers = { + "Authorization": f"Bearer {token.access}", + "chatgpt-account-id": token.account_id, + "OpenAI-Beta": "responses=experimental", + "originator": "nanobot", + "User-Agent": "nanobot (python)", + "Content-Type": "application/json", + **self.extra_headers, + } + + body: dict[str, Any] = { + "model": self._codex_model(model), + "instructions": "Generate an image based on the user's request.", + "input": [{"role": "user", "content": prompt}], + "tools": [{"type": "image_generation"}], + "tool_choice": "auto", + "stream": True, + "store": False, + } + body.update(self.extra_body) + + logger.info("Codex Responses API request: POST {}/codex/responses body={}", + self.api_base, {k: v for k, v in body.items() if k != "input"}) + + response = await self._http_post( + f"{self.api_base}/codex/responses", + headers=headers, + body=body, + ) + + try: + response.raise_for_status() + except httpx.HTTPStatusError as exc: + detail = response.text[:1000] + logger.error("Codex Responses API error ({}): {}", response.status_code, detail) + raise ImageGenerationError( + f"Codex image generation failed (HTTP {response.status_code}): {detail}" + ) from exc + + images, content_text = await _parse_codex_sse_images(response) + + raw = {"status": "completed"} + self._require_images(images, raw) + + return GeneratedImageResponse(images=images, content=content_text, raw=raw) + + +def _openai_size( + model: str, + aspect_ratio: str | None, + image_size: str | None, +) -> str: + """Resolve aspect ratio or image_size to an OpenAI Images API size string.""" + sizes, supported_sizes = _openai_size_options(model) + explicit_size = _normalize_openai_image_size(image_size) + if explicit_size and _openai_explicit_size_supported( + explicit_size, + supported_sizes=supported_sizes, + ): + return explicit_size + if explicit_size: + logger.warning( + "OpenAI image size '{}' is not supported by {}; using aspect ratio/default size", + explicit_size, + model, + ) + if aspect_ratio and aspect_ratio in sizes: + return sizes[aspect_ratio] + return "1024x1024" + + +def _openai_is_gpt_image_model(model: str) -> bool: + normalized = model.lower() + return normalized.startswith(("gpt-image", "chatgpt-image")) + + +def _openai_size_options(model: str) -> tuple[dict[str, str], set[str] | None]: + normalized = model.lower() + if normalized.startswith("dall-e-2"): + return _OPENAI_DALLE2_ASPECT_RATIO_SIZES, _OPENAI_DALLE2_SUPPORTED_SIZES + if normalized.startswith("dall-e-3"): + return _OPENAI_DALLE3_ASPECT_RATIO_SIZES, _OPENAI_DALLE3_SUPPORTED_SIZES + if normalized.startswith("gpt-image-2"): + return _OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES, None + return _OPENAI_GPT_IMAGE_ASPECT_RATIO_SIZES, _OPENAI_GPT_IMAGE_SUPPORTED_SIZES + + +def _normalize_openai_image_size(image_size: str | None) -> str | None: + if not image_size: + return None + normalized = image_size.strip().lower() + return normalized or None + + +def _openai_explicit_size_supported( + size: str, + *, + supported_sizes: set[str] | None, +) -> bool: + if supported_sizes is not None: + return size in supported_sizes + width, sep, height = size.partition("x") + return bool(sep and width.isdecimal() and height.isdecimal()) + + +async def _openai_images_from_payload( + client: httpx.AsyncClient, + payload: dict[str, Any], +) -> list[str]: + """Extract images from OpenAI Images API response. + + Handles both ``b64_json`` (preferred) and ``url`` (downloaded) formats. + """ + images: list[str] = [] + for item in payload.get("data") or []: + if not isinstance(item, dict): + continue + b64 = item.get("b64_json") + if isinstance(b64, str) and b64: + images.append(_b64_image_data_url(b64)) + continue + url = item.get("url") + if isinstance(url, str) and url: + images.append(await _download_image_data_url(client, url)) + return images + + +def _codex_responses_images_from_payload(payload: dict[str, Any]) -> list[str]: + """Extract images from Codex Responses API ``image_generation_call`` output.""" + images: list[str] = [] + for item in payload.get("output") or []: + if not isinstance(item, dict): + continue + if item.get("type") != "image_generation_call": + continue + result = item.get("result") + if isinstance(result, str): + images.append(result if result.startswith("data:image/") else _b64_image_data_url(result)) + continue + if isinstance(result, dict): + image_url = result.get("image_url") or result.get("image") or "" + if isinstance(image_url, str): + images.append(image_url if image_url.startswith("data:image/") else _b64_image_data_url(image_url)) + return images + + +async def _parse_codex_sse_images( + response: httpx.Response, +) -> tuple[list[str], str]: + """Parse a Codex Responses API SSE stream for image generation output. + + Returns ``(images, content_text)``. + """ + import json as _json + + images: list[str] = [] + text_parts: list[str] = [] + + buffer: list[str] = [] + async for line_bytes in response.aiter_lines(): + line = line_bytes.strip() + if line == "": + if buffer: + data_lines = [] + for bl in buffer: + if bl.startswith("data:"): + data_lines.append(bl[5:].strip()) + buffer.clear() + if data_lines: + raw = "".join(data_lines) + if raw == "[DONE]": + break + try: + event = _json.loads(raw) + except Exception: + continue + ev_type = event.get("type", "") + if ev_type in ("error", "response.failed"): + logger.error("Codex SSE failure: {}", raw[:2000]) + _collect_images_from_sse_event(event, images) + _collect_text_from_sse_event(event, text_parts) + continue + buffer.append(line) + + # flush remaining + if buffer: + data_lines = [bl[5:].strip() for bl in buffer if bl.startswith("data:")] + raw = "".join(data_lines) + if raw and raw != "[DONE]": + try: + event = _json.loads(raw) + except Exception: + pass + else: + _collect_images_from_sse_event(event, images) + _collect_text_from_sse_event(event, text_parts) + + return images, "".join(text_parts).strip() + + +def _collect_images_from_sse_event(event: dict[str, Any], images: list[str]) -> None: + if event.get("type") != "response.output_item.done": + return + item = event.get("item") or {} + if item.get("type") != "image_generation_call": + return + result = item.get("result") + if isinstance(result, str): + if result.startswith("data:image/"): + images.append(result) + else: + images.append(_b64_image_data_url(result)) + elif isinstance(result, dict): + image_url = result.get("image_url") or result.get("image") or "" + if isinstance(image_url, str): + if image_url.startswith("data:image/"): + images.append(image_url) + else: + images.append(_b64_image_data_url(image_url)) + + +def _collect_text_from_sse_event(event: dict[str, Any], text_parts: list[str]) -> None: + if event.get("type") == "response.output_text.delta": + delta = event.get("delta") + if isinstance(delta, str) and delta: + text_parts.append(delta) + + # --------------------------------------------------------------------------- # StepFun (阶跃星辰) image generation # --------------------------------------------------------------------------- @@ -1025,9 +1446,11 @@ def _stepfun_images_from_payload(payload: dict[str, Any]) -> list[str]: # Provider registration # --------------------------------------------------------------------------- -register_image_gen_provider(OpenRouterImageGenerationClient) register_image_gen_provider(AIHubMixImageGenerationClient) +register_image_gen_provider(CodexImageGenerationClient) register_image_gen_provider(GeminiImageGenerationClient) register_image_gen_provider(OllamaImageGenerationClient) register_image_gen_provider(MiniMaxImageGenerationClient) +register_image_gen_provider(OpenAIImageGenerationClient) +register_image_gen_provider(OpenRouterImageGenerationClient) register_image_gen_provider(StepFunImageGenerationClient) diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index b8112b529..8281d7d20 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -11,6 +11,7 @@ import secrets import string import time import uuid +from collections import deque from collections.abc import Awaitable, Callable from ipaddress import ip_address from typing import TYPE_CHECKING, Any @@ -74,41 +75,43 @@ _THINKING_STYLE_MAP: dict[str, Any] = { "enable_thinking": lambda on: {"enable_thinking": on}, "reasoning_split": lambda on: {"reasoning_split": on}, } +_GATEWAY_REASONING_STYLE_MAP: dict[str, Any] = { + "reasoning_effort": lambda effort: {"reasoning": {"effort": effort}}, +} +_MODEL_THINKING_STYLES: dict[str, str] = { + **dict.fromkeys(_KIMI_THINKING_MODELS, "thinking_type"), + **dict.fromkeys(_MIMO_THINKING_MODELS, "thinking_type"), +} -def _is_kimi_thinking_model(model_name: str) -> bool: - """Return True if model_name refers to a Kimi thinking-capable model. - - Supports two forms: - - Exact match: e.g. kimi-k2.5 / kimi-k2.6 in _KIMI_THINKING_MODELS - - Slug match: moonshotai/kimi-k2.5 -> the part after the last "/" - is checked against _KIMI_THINKING_MODELS - - This covers both the native Moonshot provider (bare slug) and - OpenRouter-style names (``"publisher/slug"``). - """ - name = model_name.lower() - if name in _KIMI_THINKING_MODELS: - return True - if "/" in name and name.rsplit("/", 1)[1] in _KIMI_THINKING_MODELS: - return True - return False +def _model_slug(model_name: str) -> str: + return model_name.lower().rsplit("/", 1)[-1] -def _is_mimo_thinking_model(model_name: str) -> bool: - """Return True if model_name refers to a MiMo thinking-capable model. +def _model_thinking_style(model_name: str) -> str: + return _MODEL_THINKING_STYLES.get(_model_slug(model_name), "") - Mirrors _is_kimi_thinking_model: gateway providers (e.g. OpenRouter - routing ``xiaomi/mimo-v2.5-pro``) have no ``thinking_style`` on their - spec, so the spec-driven branch in _build_kwargs misses them. The - model-name path catches those cases. - """ - name = model_name.lower() - if name in _MIMO_THINKING_MODELS: - return True - if "/" in name and name.rsplit("/", 1)[1] in _MIMO_THINKING_MODELS: - return True - return False + +def _thinking_styles_for(spec: ProviderSpec | None, model_name: str) -> list[str]: + styles: list[str] = [] + if spec and spec.thinking_style: + styles.append(spec.thinking_style) + model_style = _model_thinking_style(model_name) + if model_style and model_style not in styles: + styles.append(model_style) + return styles + + +def _thinking_extra_body(style: str, thinking_enabled: bool) -> dict[str, Any] | None: + builder = _THINKING_STYLE_MAP.get(style) + return builder(thinking_enabled) if builder else None + + +def _gateway_reasoning_extra_body(style: str, effort: str | None) -> dict[str, Any] | None: + if not effort: + return None + builder = _GATEWAY_REASONING_STYLE_MAP.get(style) + return builder(effort) if builder else None def _openai_compat_timeout_s() -> float: @@ -461,6 +464,7 @@ class OpenAICompatProvider(LLMProvider): """Strip non-standard keys, normalize tool_call IDs.""" sanitized = LLMProvider._sanitize_request_messages(messages, _ALLOWED_MSG_KEYS) id_map: dict[str, str] = {} + pending_tool_ids: dict[str, deque[str]] = {} force_string_content = bool(self._spec and self._spec.name == "deepseek") def map_id(value: Any) -> Any: @@ -468,15 +472,49 @@ class OpenAICompatProvider(LLMProvider): return value return id_map.setdefault(value, self._normalize_tool_call_id(value)) + def unique_tool_id(value: Any, used_ids: set[str], idx: int) -> str: + if isinstance(value, str) and value: + base = map_id(value) + else: + base = _short_tool_id() + if not isinstance(base, str) or not base: + base = _short_tool_id() + if base not in used_ids: + return base + seed = value if isinstance(value, str) and value else base + salt = 1 + while True: + candidate = self._normalize_tool_call_id(f"{seed}:{idx}:{salt}") + if isinstance(candidate, str) and candidate not in used_ids: + return candidate + salt += 1 + + def map_tool_result_id(value: Any) -> Any: + if not isinstance(value, str): + return value + queue = pending_tool_ids.get(value) + if queue: + mapped = queue.popleft() + if not queue: + pending_tool_ids.pop(value, None) + return mapped + return map_id(value) + for clean in sanitized: if isinstance(clean.get("tool_calls"), list): normalized = [] - for tc in clean["tool_calls"]: + used_ids: set[str] = set() + for idx, tc in enumerate(clean["tool_calls"]): if not isinstance(tc, dict): normalized.append(tc) continue tc_clean = dict(tc) - tc_clean["id"] = map_id(tc_clean.get("id")) + raw_id = tc_clean.get("id") + mapped_id = unique_tool_id(raw_id, used_ids, idx) + tc_clean["id"] = mapped_id + used_ids.add(mapped_id) + if isinstance(raw_id, str) and raw_id: + pending_tool_ids.setdefault(raw_id, deque()).append(mapped_id) function = tc_clean.get("function") if isinstance(function, dict): function_clean = dict(function) @@ -494,7 +532,7 @@ class OpenAICompatProvider(LLMProvider): # that mix non-empty content with tool_calls. clean["content"] = None if "tool_call_id" in clean and clean["tool_call_id"]: - clean["tool_call_id"] = map_id(clean["tool_call_id"]) + clean["tool_call_id"] = map_tool_result_id(clean["tool_call_id"]) if ( force_string_content and not (clean.get("role") == "assistant" and clean.get("tool_calls")) @@ -581,39 +619,27 @@ class OpenAICompatProvider(LLMProvider): if wire_effort and semantic_effort != "none": kwargs["reasoning_effort"] = wire_effort - # Provider-specific thinking parameters. - # Only sent when reasoning_effort is explicitly configured so that - # the provider default is preserved otherwise. - # The mapping is driven by ProviderSpec.thinking_style so that adding - # a new provider never requires touching this function. - if spec and spec.thinking_style and reasoning_effort is not None: + # Only send thinking controls when reasoning_effort is explicit so + # omitting the config preserves each provider's default. + if reasoning_effort is not None: thinking_enabled = semantic_effort not in ("none", "minimal") - extra = _THINKING_STYLE_MAP.get(spec.thinking_style, lambda _: None)(thinking_enabled) - if extra: - kwargs.setdefault("extra_body", {}).update(extra) + for thinking_style in _thinking_styles_for(spec, model_name): + extra = _thinking_extra_body(thinking_style, thinking_enabled) + if extra: + kwargs.setdefault("extra_body", {}).update(extra) + gateway_style = getattr(spec, "gateway_reasoning_style", "") if spec else "" + if gateway_style and _model_thinking_style(model_name): + extra = _gateway_reasoning_extra_body(gateway_style, semantic_effort) + if extra: + kwargs.setdefault("extra_body", {}).update(extra) - # Model-level thinking injection for Kimi thinking-capable models. - # Strip any provider prefix (e.g. "moonshotai/") before the set lookup - # so that OpenRouter-style names like "moonshotai/kimi-k2.5" are handled - # identically to bare names like "kimi-k2.5". - if reasoning_effort is not None and _is_kimi_thinking_model(model_name): - thinking_enabled = semantic_effort not in ("none", "minimal") - kwargs.setdefault("extra_body", {}).update( - {"thinking": {"type": "enabled" if thinking_enabled else "disabled"}} - ) - - # Model-level thinking injection for MiMo thinking-capable models. - # Same shape as Kimi: gateway providers (OpenRouter, etc.) lack the - # xiaomi_mimo spec's thinking_style, so the spec-driven branch above - # misses them — match by model name to catch "xiaomi/mimo-v2.5-pro" - # and friends. (Direct xiaomi_mimo requests are also covered here; - # both branches write the same payload, so the dict update is a - # safe no-op for already-handled cases.) - if reasoning_effort is not None and _is_mimo_thinking_model(model_name): - thinking_enabled = semantic_effort not in ("none", "minimal") - kwargs.setdefault("extra_body", {}).update( - {"thinking": {"type": "enabled" if thinking_enabled else "disabled"}} - ) + # Moonshot rejects requests that carry both 'reasoning_effort' + # and the native 'thinking' param. We already expressed the + # user's intent via the provider-native shape, so drop the + # redundant wire-level kwarg. Only kimi models need this — + # Xiaomi's API accepts both params. + if _model_slug(model_name) in _KIMI_THINKING_MODELS: + kwargs.pop("reasoning_effort", None) if tools: kwargs["tools"] = tools @@ -628,8 +654,7 @@ class OpenAICompatProvider(LLMProvider): and semantic_effort not in ("none", "minimal") and ( (spec and spec.thinking_style) - or _is_kimi_thinking_model(model_name) - or _is_mimo_thinking_model(model_name) + or _model_thinking_style(model_name) ) ) implicit_deepseek_thinking = ( @@ -1097,6 +1122,15 @@ class OpenAICompatProvider(LLMProvider): if delta: _accum_legacy_function_call(getattr(delta, "function_call", None)) + # Some providers (e.g. Zhipu/GLM) reuse the same tool_call id for + # parallel tool calls in streaming mode. Deduplicate before building + # the response so downstream tool messages don't collide. + _seen_tc_ids: set[str] = set() + for b in tc_bufs.values(): + if not b["id"] or b["id"] in _seen_tc_ids: + b["id"] = _short_tool_id() + _seen_tc_ids.add(b["id"]) + return LLMResponse( content="".join(content_parts) or None, tool_calls=[ diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py index 7c8edd271..ab7e2cf1e 100644 --- a/nanobot/providers/registry.py +++ b/nanobot/providers/registry.py @@ -71,6 +71,11 @@ class ProviderSpec: # "reasoning_split" — {"reasoning_split": true/false} (MiniMax) thinking_style: str = "" + # Gateway-native reasoning control to pair with model-level thinking styles. + # "reasoning_effort" — {"reasoning": {"effort": }} + # (OpenRouter) + gateway_reasoning_style: str = "" + # When True, treat the "reasoning" response field as formal content # when "content" is empty. Only set this for providers (e.g. StepFun) # whose API returns the actual answer in "reasoning" instead of "content". @@ -142,6 +147,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = ( detect_by_base_keyword="openrouter", default_api_base="https://openrouter.ai/api/v1", supports_prompt_caching=True, + gateway_reasoning_style="reasoning_effort", ), # Hugging Face Inference Providers: OpenAI-compatible router for chat models. ProviderSpec( @@ -193,6 +199,18 @@ PROVIDERS: tuple[ProviderSpec, ...] = ( default_api_base="https://api.siliconflow.cn/v1", ), + # Novita AI: OpenAI-compatible gateway for hosted model APIs. + ProviderSpec( + name="novita", + keywords=("novita",), + env_key="NOVITA_API_KEY", + display_name="Novita AI", + backend="openai_compat", + is_gateway=True, + detect_by_base_keyword="novita", + default_api_base="https://api.novita.ai/openai", + ), + # VolcEngine (火山引擎): OpenAI-compatible gateway, pay-per-use models ProviderSpec( name="volcengine", diff --git a/nanobot/templates/AGENTS.md b/nanobot/templates/AGENTS.md index 0bf6de3d3..46cfc08c3 100644 --- a/nanobot/templates/AGENTS.md +++ b/nanobot/templates/AGENTS.md @@ -1,5 +1,9 @@ # Agent Instructions +## Workspace Guidance + +Use this file for project-specific preferences, recurring workflow conventions, and instructions you want the agent to remember for this workspace. Keep durable facts about the user in `USER.md`, personality/style guidance in `SOUL.md`, and long-term memory in `memory/MEMORY.md`. + ## Scheduled Reminders Before scheduling reminders, check available skills and follow skill guidance first. @@ -10,10 +14,10 @@ Get USER_ID and CHANNEL from the current session (e.g., `8281248569` and `telegr ## Heartbeat Tasks -`HEARTBEAT.md` is checked on the configured heartbeat interval. Use file tools to manage periodic tasks: +`HEARTBEAT.md` is checked on the configured heartbeat interval. Use file tools to manage periodic tasks. -- **Add**: `edit_file` to append new tasks -- **Remove**: `edit_file` to delete completed tasks -- **Rewrite**: `write_file` to replace all tasks +- Use `apply_patch` for normal task-list updates, especially when adding, removing, or changing multiple lines. +- Use `edit_file` only for small exact replacements copied from the current `HEARTBEAT.md`. +- Use `write_file` for first creation or intentional full-file rewrites. When the user asks for a recurring/periodic task, update `HEARTBEAT.md` instead of creating a one-time cron reminder. diff --git a/nanobot/templates/TOOLS.md b/nanobot/templates/TOOLS.md deleted file mode 100644 index 374e49778..000000000 --- a/nanobot/templates/TOOLS.md +++ /dev/null @@ -1,28 +0,0 @@ -# Tool Usage Notes - -Tool signatures are provided automatically via function calling. -This file documents non-obvious constraints and usage patterns. - -## exec — Safety Limits - -- Commands have a configurable timeout (default 60s) -- Dangerous commands are blocked (rm -rf, format, dd, shutdown, etc.) -- Output is truncated at 10,000 characters -- `restrictToWorkspace` config can limit file access to the workspace - -## grep — Content Search - -- Use `grep` to search file contents inside the workspace -- Default behavior returns only matching file paths (`output_mode="files_with_matches"`) -- Supports optional `glob` filtering (e.g. `glob="*.py"`) plus `context_before` / `context_after` -- Supports `type="py"`, `type="ts"`, `type="md"` and similar shorthand filters -- Use `fixed_strings=true` for literal keywords containing regex characters -- Use `output_mode="files_with_matches"` to get only matching file paths -- Use `output_mode="count"` to size a search before reading full matches -- Use `head_limit` and `offset` to page across results -- Prefer this over `exec` for code and history searches -- Binary or oversized files may be skipped to keep results readable - -## cron — Scheduled Reminders - -- Please refer to cron skill for usage. diff --git a/nanobot/templates/agent/tool_contract.md b/nanobot/templates/agent/tool_contract.md new file mode 100644 index 000000000..edbba21c9 --- /dev/null +++ b/nanobot/templates/agent/tool_contract.md @@ -0,0 +1,60 @@ +# Tool Usage Notes + +Tool signatures are provided automatically via function calling. This section +documents the general tool contract and non-obvious usage patterns. + +## General Tool Contract + +- Use the narrowest structured tool that directly matches the task. +- Use read-only discovery before writes when state is uncertain. +- Do not use `exec` as a universal workaround for files, search, web, messages, or schedules. +- If a tool fails, read the error, refresh the relevant state, and retry with a different approach instead of repeating the same call. +- After meaningful changes, verify with the smallest reliable check: re-read changed state, run targeted tests, or inspect command output. +- Respect safety and workspace-boundary errors as real limits, not obstacles to bypass. + +## Discovery and Reading + +- Use `find_files` or `list_dir` to locate workspace paths before `read_file` when a path is uncertain. +- Use `grep` for content search inside the workspace; prefer it over shell grep for ordinary searches. +- `grep` defaults to `output_mode="files_with_matches"`; use `output_mode="content"` for matching lines with context. +- Use `fixed_strings=true` for literal keywords containing regex characters. +- Use `output_mode="count"` to size a broad search before reading full matches. +- Use `head_limit` and `offset` to page across large result sets. +- Binary or oversized files may be skipped to keep results readable. + +## File and Coding Workflows + +- For code or config changes, the default loop is: locate (`find_files`/`grep`), inspect (`read_file`), edit (`apply_patch`), then verify (`exec` or re-read). +- Use `apply_patch` as the default code editing tool, especially for multi-file changes, structural edits, generated code, moves, adds, or deletes. +- Use `apply_patch dry_run=true` when the patch is uncertain and you want validation plus a change summary before writing. +- Use `edit_file` only for small exact replacements in one file, with `old_text` copied from `read_file`; add `occurrence`, `line_hint`, or `expected_replacements` when ambiguity matters. +- Use `write_file` for new files or intentional full-file rewrites, not routine partial edits. +- If `apply_patch` or `edit_file` fails, re-read with `force=true`, narrow the context, and try a smaller patch rather than switching to shell `sed` or `echo`. + +## Process Execution + +- Use `exec` for tests, builds, package commands, git commands, and other process execution. +- Prefer dedicated file/search tools over `cat`, shell `find`, shell `grep`, `sed`, or `echo` for ordinary workspace inspection and edits. +- Use non-interactive flags such as `-y` or `--yes` when available. +- Commands have a configurable timeout (default 60s), dangerous commands are blocked, and output is truncated. +- For long-running or interactive commands, pass `yield_time_ms`; if the process keeps running, continue with `write_stdin`. +- Use `write_stdin` to poll, provide stdin, close stdin, wait for expected output with `wait_for`, or terminate an existing exec session. +- Use `list_exec_sessions` to recover active session IDs after context shifts. + +## Web and External Information + +- Use web tools when the user asks for current information, a specific URL, or information likely to have changed. +- Use `web_search` to find sources and `web_fetch` for a specific page or result that needs closer reading. +- Do not invent freshness-sensitive facts when tools can verify them. + +## Messaging and Media + +- Use `message` to send content or local media to the user/channel. +- `read_file` only reads content for your analysis; it does not deliver a file to the user. +- When sending an existing local file, attach it through the message/media mechanism instead of pasting file contents unless the user asked for text. + +## Scheduling and Background Work + +- Use `cron` for scheduled reminders or recurring jobs; do not run `nanobot cron` through `exec`. +- For heartbeat tasks, update `HEARTBEAT.md` according to the agent instructions. +- Do not write reminders only to memory files when the user expects an actual notification. diff --git a/nanobot/utils/file_edit_events.py b/nanobot/utils/file_edit_events.py index b5d2f6d73..fd929134d 100644 --- a/nanobot/utils/file_edit_events.py +++ b/nanobot/utils/file_edit_events.py @@ -3,15 +3,13 @@ from __future__ import annotations import difflib -import json import re import time from dataclasses import dataclass, field from pathlib import Path from typing import Any, Awaitable, Callable - -TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file", "notebook_edit"}) +TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file", "apply_patch"}) _MAX_SNAPSHOT_BYTES = 2 * 1024 * 1024 _LIVE_EMIT_INTERVAL_S = 0.18 _LIVE_EMIT_LINE_STEP = 24 @@ -154,19 +152,108 @@ def prepare_file_edit_tracker( workspace: Path | None, params: dict[str, Any] | None, ) -> FileEditTracker | None: + trackers = prepare_file_edit_trackers( + call_id=call_id, + tool_name=tool_name, + tool=tool, + workspace=workspace, + params=params, + ) + return trackers[0] if trackers else None + + +def prepare_file_edit_trackers( + *, + call_id: str, + tool_name: str, + tool: Any, + workspace: Path | None, + params: dict[str, Any] | None, +) -> list[FileEditTracker]: if not is_file_edit_tool(tool_name): - return None + return [] + paths = resolve_file_edit_paths(tool_name, tool, workspace, params) + trackers: list[FileEditTracker] = [] + seen: set[Path] = set() + for path in paths: + try: + resolved = path.resolve() + except Exception: + resolved = path + if resolved in seen: + continue + seen.add(resolved) + before = read_file_snapshot(path) + trackers.append(FileEditTracker( + call_id=str(call_id or ""), + tool=tool_name, + path=path, + display_path=display_file_edit_path(path, workspace), + before=before, + )) + return trackers + + +def resolve_file_edit_paths( + tool_name: str, + tool: Any, + workspace: Path | None, + params: dict[str, Any] | None, +) -> list[Path]: + if tool_name == "apply_patch": + return _resolve_apply_patch_paths(tool, workspace, params) path = resolve_file_edit_path(tool, workspace, params) if path is None: - return None - before = read_file_snapshot(path) - return FileEditTracker( - call_id=str(call_id or ""), - tool=tool_name, - path=path, - display_path=display_file_edit_path(path, workspace), - before=before, - ) + return [] + return [path] + + +def _resolve_apply_patch_paths( + tool: Any, + workspace: Path | None, + params: dict[str, Any] | None, +) -> list[Path]: + if not isinstance(params, dict): + return [] + edits = params.get("edits") + if not isinstance(edits, list) or not edits: + return [] + if params.get("dry_run") is True: + return [] + + resolved: list[Path] = [] + seen: set[Path] = set() + for edit in edits: + if not isinstance(edit, dict): + continue + raw_path = edit.get("path") + if not isinstance(raw_path, str) or not raw_path.strip(): + continue + path = _resolve_raw_file_edit_path(tool, workspace, raw_path) + if path is not None and path not in seen: + seen.add(path) + resolved.append(path) + return resolved + + +def _resolve_raw_file_edit_path( + tool: Any, + workspace: Path | None, + raw_path: str, +) -> Path | None: + resolver = getattr(tool, "_resolve", None) + if callable(resolver): + try: + resolved = resolver(raw_path) + if isinstance(resolved, Path): + return resolved + if resolved: + return Path(resolved) + except Exception: + return None + if workspace is None: + return Path(raw_path).expanduser().resolve() + return (workspace / raw_path).expanduser().resolve() def build_file_edit_start_event( @@ -304,6 +391,9 @@ class StreamingFileEditTracker: self._states[key] = state state.apply_delta(payload) + if state.name == "apply_patch": + await self._update_apply_patch(state) + return if state.name not in {"write_file", "edit_file"}: return if state.path is None: @@ -343,10 +433,80 @@ class StreamingFileEditTracker: deleted=deleted, )]) + async def _update_apply_patch(self, state: _StreamingFileEditState) -> None: + if _json_bool_true(state.arguments, "dry_run"): + return + tool = self._tools.get("apply_patch") if hasattr(self._tools, "get") else None + events: list[dict[str, Any]] = [] + now = time.monotonic() + + path_matches = list(re.finditer(r'"path"\s*:\s*"([^"]+)"', state.arguments)) + if not path_matches: + return + + for i, m in enumerate(path_matches): + raw_path = m.group(1) + path = _resolve_raw_file_edit_path(tool, self._workspace, raw_path) + if path is None: + continue + + segment_start = m.start() + segment_end = path_matches[i + 1].start() if i + 1 < len(path_matches) else len(state.arguments) + segment = state.arguments[segment_start:segment_end] + + action_match = re.search(r'"action"\s*:\s*"(replace|add|delete)"', segment) + action = action_match.group(1) if action_match else "replace" + + old_text = _extract_json_string_prefix(segment, "old_text") or "" + new_text = _extract_json_string_prefix(segment, "new_text") or "" + + added = _text_line_count(new_text) if action in ("replace", "add") else 0 + deleted = _text_line_count(old_text) if action in ("replace", "delete") else 0 + delete_file = action == "delete" + + file_state = state.patch_files.get(raw_path) + if file_state is None: + tracker = FileEditTracker( + call_id=state.call_id or state.key, + tool="apply_patch", + path=path, + display_path=display_file_edit_path(path, self._workspace), + before=read_file_snapshot(path), + ) + file_state = _StreamingPatchFileState(tracker=tracker) + state.patch_files[raw_path] = file_state + if delete_file and added == 0 and deleted == 0 and file_state.tracker.before.countable: + deleted = _text_line_count(file_state.tracker.before.text or "") + if not file_state.should_emit(added, deleted, now): + continue + file_state.mark_emitted(added, deleted, now) + events.append(build_file_edit_live_event( + file_state.tracker, + added=added, + deleted=deleted, + )) + if events: + await self._emit(events) + async def flush(self) -> None: events: list[dict[str, Any]] = [] now = time.monotonic() for state in self._states.values(): + for file_state in state.patch_files.values(): + added, deleted = file_state.last_added, file_state.last_deleted + if not file_state.emitted_once: + continue + if ( + file_state.last_emitted_added == added + and file_state.last_emitted_deleted == deleted + ): + continue + file_state.mark_emitted(added, deleted, now) + events.append(build_file_edit_live_event( + file_state.tracker, + added=added, + deleted=deleted, + )) if state.tracker is None: continue added, deleted = state.live_diff_counts() @@ -367,12 +527,14 @@ class StreamingFileEditTracker: def apply_final_call_ids(self, final_tool_calls: list[Any]) -> None: """Keep final start/end events keyed to any earlier streamed placeholder.""" + used_canonicals: set[str] = set() for tool_call in final_tool_calls: canonical = self.canonical_call_id_for(tool_call) - if canonical: + if canonical and canonical not in used_canonicals: try: tool_call.id = canonical - except Exception: + used_canonicals.add(canonical) + except (AttributeError, TypeError): pass def canonical_call_id_for(self, tool_call: Any) -> str | None: @@ -389,6 +551,10 @@ class StreamingFileEditTracker: """Mark streamed edits as failed when no final tool call will run.""" events: list[dict[str, Any]] = [] for state in self._states.values(): + for file_state in state.patch_files.values(): + if any(state.matches_final_tool_call(tool_call) for tool_call in final_tool_calls): + continue + events.append(build_file_edit_error_event(file_state.tracker, error)) if state.tracker is None: continue if any(state.matches_final_tool_call(tool_call) for tool_call in final_tool_calls): @@ -492,6 +658,39 @@ class _StreamingJsonStringField: self.last_char_cr = False +@dataclass(slots=True) +class _StreamingPatchFileState: + tracker: FileEditTracker + emitted_once: bool = False + last_emitted_added: int = -1 + last_emitted_deleted: int = -1 + last_emit_at: float = 0.0 + last_added: int = 0 + last_deleted: int = 0 + + def should_emit(self, added: int, deleted: int, now: float) -> bool: + self.last_added = added + self.last_deleted = deleted + if not self.emitted_once: + return True + if added == self.last_emitted_added and deleted == self.last_emitted_deleted: + return False + if max( + abs(added - self.last_emitted_added), + abs(deleted - self.last_emitted_deleted), + ) >= _LIVE_EMIT_LINE_STEP: + return True + return now - self.last_emit_at >= _LIVE_EMIT_INTERVAL_S + + def mark_emitted(self, added: int, deleted: int, now: float) -> None: + self.emitted_once = True + self.last_added = added + self.last_deleted = deleted + self.last_emitted_added = added + self.last_emitted_deleted = deleted + self.last_emit_at = now + + @dataclass(slots=True) class _StreamingFileEditState: key: str @@ -509,6 +708,7 @@ class _StreamingFileEditState: new_text: _StreamingJsonStringField = field( default_factory=lambda: _StreamingJsonStringField("new_text") ) + patch_files: dict[str, _StreamingPatchFileState] = field(default_factory=dict) emitted_once: bool = False last_emitted_added: int = -1 last_emitted_deleted: int = -1 @@ -531,6 +731,7 @@ class _StreamingFileEditState: self.content.reset() self.old_text.reset() self.new_text.reset() + self.patch_files.clear() return delta = payload.get("arguments_delta") if isinstance(delta, str) and delta: @@ -590,6 +791,14 @@ class _StreamingFileEditState: name = getattr(tool_call, "name", None) if name != self.name: return False + if self.name == "apply_patch": + arguments = getattr(tool_call, "arguments", None) + if not isinstance(arguments, dict): + return False + edits = arguments.get("edits") + if not isinstance(edits, list): + return False + return '"edits"' in self.arguments arguments = getattr(tool_call, "arguments", None) if not isinstance(arguments, dict): return False @@ -612,6 +821,51 @@ def _stream_key(payload: dict[str, Any]) -> str: return "" +def _json_bool_true(source: str, key: str) -> bool: + return re.search(rf'"{re.escape(key)}"\s*:\s*true\b', source) is not None + + +def _extract_json_string_prefix(source: str, key: str) -> str | None: + match = re.search(rf'"{re.escape(key)}"\s*:\s*"', source) + if match is None: + return None + out: list[str] = [] + i = match.end() + escape = False + while i < len(source): + ch = source[i] + if escape: + escape = False + if ch == "n": + out.append("\n") + elif ch == "r": + out.append("\r") + elif ch == "t": + out.append("\t") + elif ch == "u": + digits = source[i + 1:i + 5] + if len(digits) < 4: + break + try: + out.append(chr(int(digits, 16))) + except ValueError: + break + i += 4 + else: + out.append(ch) + i += 1 + continue + if ch == "\\": + escape = True + i += 1 + continue + if ch == '"': + return "".join(out) + out.append(ch) + i += 1 + return "".join(out) + + def _extract_complete_json_string(source: str, key: str) -> str | None: match = re.search(rf'"{re.escape(key)}"\s*:\s*"', source) if match is None: @@ -704,77 +958,4 @@ def _predict_after_text( return before_text.replace(old_text, new_text) return before_text.replace(old_text, new_text, 1) return None - if tool_name == "notebook_edit": - return _predict_notebook_after_text(params, before_text) return None - - -def _predict_notebook_after_text(params: dict[str, Any], before_text: str) -> str | None: - try: - nb = json.loads(before_text) if before_text.strip() else _empty_notebook() - except Exception: - return None - cells = nb.get("cells") - if not isinstance(cells, list): - return None - try: - cell_index = int(params.get("cell_index", 0)) - except (TypeError, ValueError): - return None - new_source = params.get("new_source") - source = new_source if isinstance(new_source, str) else "" - cell_type = ( - params.get("cell_type") if params.get("cell_type") in ("code", "markdown") else "code" - ) - mode = ( - params.get("edit_mode") - if params.get("edit_mode") in ("replace", "insert", "delete") - else "replace" - ) - if mode == "delete": - if 0 <= cell_index < len(cells): - cells.pop(cell_index) - else: - return None - elif mode == "insert": - insert_at = min(max(cell_index + 1, 0), len(cells)) - cells.insert(insert_at, _new_notebook_cell(source, str(cell_type))) - else: - if not (0 <= cell_index < len(cells)): - return None - cell = cells[cell_index] - if not isinstance(cell, dict): - return None - cell["source"] = source - cell["cell_type"] = cell_type - if cell_type == "code": - cell.setdefault("outputs", []) - cell.setdefault("execution_count", None) - else: - cell.pop("outputs", None) - cell.pop("execution_count", None) - nb["cells"] = cells - try: - return json.dumps(nb, indent=1, ensure_ascii=False) - except Exception: - return None - - -def _empty_notebook() -> dict[str, Any]: - return { - "nbformat": 4, - "nbformat_minor": 5, - "metadata": { - "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, - "language_info": {"name": "python"}, - }, - "cells": [], - } - - -def _new_notebook_cell(source: str, cell_type: str) -> dict[str, Any]: - cell: dict[str, Any] = {"cell_type": cell_type, "source": source, "metadata": {}} - if cell_type == "code": - cell["outputs"] = [] - cell["execution_count"] = None - return cell diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index 2a969298c..ae91bf394 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ -576,7 +576,7 @@ def build_status_content( def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]: - """Sync bundled templates to workspace. Only creates missing files.""" + """Sync bundled templates to workspace. Creates missing files without overwriting user files.""" from importlib.resources import files as pkg_files try: @@ -589,10 +589,11 @@ def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str] added: list[str] = [] def _write(src, dest: Path): + content = src.read_text(encoding="utf-8") if src else "" if dest.exists(): return dest.parent.mkdir(parents=True, exist_ok=True) - dest.write_text(src.read_text(encoding="utf-8") if src else "", encoding="utf-8") + dest.write_text(content, encoding="utf-8") added.append(str(dest.relative_to(workspace))) for item in tpl.iterdir(): diff --git a/nanobot/utils/tool_hints.py b/nanobot/utils/tool_hints.py index 272a19c9a..3a6460701 100644 --- a/nanobot/utils/tool_hints.py +++ b/nanobot/utils/tool_hints.py @@ -11,8 +11,10 @@ _TOOL_FORMATS: dict[str, tuple[list[str], str, bool, bool]] = { "read_file": (["path", "file_path"], "read {}", True, False), "write_file": (["path", "file_path"], "write {}", True, False), "edit": (["file_path", "path"], "edit {}", True, False), + "find_files": (["query", "glob", "path"], "find {}", False, False), "grep": (["pattern"], 'grep "{}"', False, False), "exec": (["command"], "$ {}", False, True), + "list_exec_sessions": ([], "exec sessions", False, False), "web_search": (["query"], 'search "{}"', False, False), "web_fetch": (["url"], "fetch {}", True, False), "list_dir": (["path"], "ls {}", True, False), @@ -81,6 +83,8 @@ def _extract_arg(tc, key_args: list[str]) -> str | None: def _fmt_known(tc, fmt: tuple, max_length: int = 40) -> str: """Format a registered tool using its template.""" + if not fmt[0] and "{}" not in fmt[1]: + return fmt[1] val = _extract_arg(tc, fmt[0]) if val is None: return tc.name diff --git a/nanobot/webui/settings_api.py b/nanobot/webui/settings_api.py index a5ab13c5a..6d43e22c8 100644 --- a/nanobot/webui/settings_api.py +++ b/nanobot/webui/settings_api.py @@ -73,12 +73,16 @@ def _mask_secret_hint(secret: str | None) -> str | None: def _provider_requires_api_key(spec: Any) -> bool: if spec.backend == "azure_openai": return True + if spec.is_oauth: + return False if spec.is_local or spec.is_direct: return False return True def _provider_configured_for_settings(spec: Any, provider_config: Any) -> bool: + if spec.is_oauth: + return True if _provider_requires_api_key(spec): return bool(provider_config.api_key) return bool( diff --git a/tests/agent/test_context_builder.py b/tests/agent/test_context_builder.py index 0206d0986..a36c0a30a 100644 --- a/tests/agent/test_context_builder.py +++ b/tests/agent/test_context_builder.py @@ -139,6 +139,13 @@ class TestLoadBootstrapFiles: for name in ContextBuilder.BOOTSTRAP_FILES: assert f"## {name}" in result + def test_legacy_tools_md_is_not_bootstrapped(self, tmp_path): + (tmp_path / "TOOLS.md").write_text("workspace tool notes", encoding="utf-8") + builder = _builder(tmp_path) + result = builder._load_bootstrap_files() + assert "TOOLS.md" not in result + assert "workspace tool notes" not in result + def test_utf8_content(self, tmp_path): (tmp_path / "AGENTS.md").write_text("用中文回复", encoding="utf-8") builder = _builder(tmp_path) @@ -171,6 +178,37 @@ class TestIsTemplateContent: assert ContextBuilder._is_template_content("totally different", "memory/MEMORY.md") is False +# --------------------------------------------------------------------------- +# Bundled bootstrap templates +# --------------------------------------------------------------------------- + + +class TestBundledToolContract: + def test_tool_contract_balances_general_and_coding_workflows(self): + from importlib.resources import files as pkg_files + + tpl = pkg_files("nanobot") / "templates" / "agent" / "tool_contract.md" + content = tpl.read_text(encoding="utf-8") + + assert "## General Tool Contract" in content + assert "Use the narrowest structured tool" in content + assert "Do not use `exec` as a universal workaround" in content + assert "## File and Coding Workflows" in content + assert "apply_patch" in content + assert "## Web and External Information" in content + assert "## Messaging and Media" in content + assert "## Scheduling and Background Work" in content + assert "pure coding" not in content.lower() + + def test_tool_contract_is_injected_without_workspace_file(self, tmp_path): + builder = _builder(tmp_path) + prompt = builder.build_system_prompt() + + assert "# Tool Usage Notes" in prompt + assert "## General Tool Contract" in prompt + assert "Do not use `exec` as a universal workaround" in prompt + + # --------------------------------------------------------------------------- # _build_user_content # --------------------------------------------------------------------------- diff --git a/tests/agent/test_onboard_logic.py b/tests/agent/test_onboard_logic.py index 11a284bb5..762da4f31 100644 --- a/tests/agent/test_onboard_logic.py +++ b/tests/agent/test_onboard_logic.py @@ -346,6 +346,26 @@ class TestSyncWorkspaceTemplates: content = (workspace / "AGENTS.md").read_text() assert content == "existing content" + def test_does_not_create_tools_md(self, tmp_path): + """Tool contract is injected internally, not copied into user workspaces.""" + workspace = tmp_path / "workspace" + + added = sync_workspace_templates(workspace, silent=True) + + assert "TOOLS.md" not in added + assert not (workspace / "TOOLS.md").exists() + + def test_preserves_existing_tools_md_without_overwriting(self, tmp_path): + """Legacy user workspaces may have TOOLS.md; sync should leave it untouched.""" + workspace = tmp_path / "workspace" + workspace.mkdir(parents=True) + tools_path = workspace / "TOOLS.md" + tools_path.write_text("custom tool notes", encoding="utf-8") + + sync_workspace_templates(workspace, silent=True) + + assert tools_path.read_text(encoding="utf-8") == "custom tool notes" + def test_creates_memory_directory(self, tmp_path): """Should create memory directory structure.""" workspace = tmp_path / "workspace" diff --git a/tests/channels/test_signal_channel.py b/tests/channels/test_signal_channel.py new file mode 100644 index 000000000..277c85b83 --- /dev/null +++ b/tests/channels/test_signal_channel.py @@ -0,0 +1,1514 @@ +"""Tests for the Signal channel implementation.""" + +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from nanobot.bus.events import InboundMessage, OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.signal import ( + SignalChannel, + SignalConfig, + SignalDMConfig, + SignalGroupConfig, +) + +# --------------------------------------------------------------------------- +# Fake HTTP client +# --------------------------------------------------------------------------- + + +class _FakeResponse: + def __init__(self, status_code: int = 200, body: dict | None = None) -> None: + self.status_code = status_code + self._body = body or {} + + def raise_for_status(self) -> None: + if self.status_code >= 400: + raise RuntimeError(f"HTTP {self.status_code}") + + def json(self) -> dict: + return self._body + + +class _FakeHTTPClient: + """Minimal httpx.AsyncClient stand-in that records requests.""" + + def __init__(self, *, default_response: dict | None = None) -> None: + self.posts: list[dict] = [] + self.gets: list[str] = [] + self._response = _FakeResponse(body=default_response or {"result": {"timestamp": 123}}) + self.closed = False + + async def get(self, path: str) -> _FakeResponse: + self.gets.append(path) + return self._response + + async def post(self, path: str, *, json: dict) -> _FakeResponse: + self.posts.append({"path": path, "json": json}) + return self._response + + async def aclose(self) -> None: + self.closed = True + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_channel_with_capture(**overrides) -> tuple[SignalChannel, list[dict]]: + """Build a SignalChannel with _handle_message captured into a list and a + no-op _start_typing, used by every receive-flow test class. + """ + ch = _make_channel(**overrides) + handled: list[dict] = [] + + async def capture(**kwargs): + handled.append(kwargs) + + async def noop_typing(chat_id): + pass + + ch._handle_message = capture # type: ignore[method-assign] + ch._start_typing = noop_typing # type: ignore[method-assign] + return ch, handled + + +def _make_channel( + *, + phone_number: str = "+10000000000", + dm_enabled: bool = True, + dm_policy: str = "open", + dm_allow_from: list[str] | None = None, + group_enabled: bool = False, + group_policy: str = "open", + group_allow_from: list[str] | None = None, + require_mention: bool = True, + group_buffer_size: int = 20, + attachments_dir: str | None = None, +) -> SignalChannel: + config = SignalConfig( + enabled=True, + phone_number=phone_number, + dm=SignalDMConfig( + enabled=dm_enabled, + policy=dm_policy, + allow_from=dm_allow_from or [], + ), + group=SignalGroupConfig( + enabled=group_enabled, + policy=group_policy, + allow_from=group_allow_from or [], + require_mention=require_mention, + ), + group_message_buffer_size=group_buffer_size, + attachments_dir=attachments_dir, + ) + return SignalChannel(config, MessageBus()) + + +def _dm_envelope( + *, + source_number: str = "+19995550001", + source_uuid: str | None = None, + source_name: str | None = "Alice", + message: str = "hello", + attachments: list | None = None, + reaction: dict | None = None, + timestamp: int = 1000, +) -> dict: + data_message: dict = {"message": message, "timestamp": timestamp} + if attachments is not None: + data_message["attachments"] = attachments + if reaction is not None: + data_message["reaction"] = reaction + envelope: dict = { + "sourceNumber": source_number, + "sourceName": source_name, + "dataMessage": data_message, + } + if source_uuid: + envelope["sourceUuid"] = source_uuid + return {"envelope": envelope} + + +def _group_envelope( + *, + source_number: str = "+19995550001", + source_name: str = "Bob", + group_id: str = "group123==", + message: str = "hey group", + mentions: list | None = None, + timestamp: int = 2000, + use_v2: bool = False, +) -> dict: + group_obj = {"groupId": group_id} + key = "groupV2" if use_v2 else "groupInfo" + data_message: dict = { + "message": message, + "timestamp": timestamp, + key: group_obj, + "mentions": mentions or [], + } + return { + "envelope": { + "sourceNumber": source_number, + "sourceName": source_name, + "dataMessage": data_message, + } + } + + +# --------------------------------------------------------------------------- +# Static utility tests +# --------------------------------------------------------------------------- + + +class TestNormalizeSignalId: + def test_phone_number_kept_and_stripped(self): + result = SignalChannel._normalize_signal_id("+12345678901") + assert "+12345678901" in result + assert "12345678901" in result + + def test_digits_only_gets_plus_prefix(self): + result = SignalChannel._normalize_signal_id("12345678901") + assert "+12345678901" in result + + def test_lowercase_variant_added(self): + result = SignalChannel._normalize_signal_id("SOME-UUID") + assert "some-uuid" in result + + def test_empty_string_returns_empty(self): + assert SignalChannel._normalize_signal_id("") == [] + + def test_whitespace_stripped(self): + result = SignalChannel._normalize_signal_id(" +1234 ") + assert "+1234" in result + + +class TestCollectSenderIdParts: + def test_collects_source_number(self): + env = {"sourceNumber": "+15551234567"} + parts = SignalChannel._collect_sender_id_parts(env) + assert "+15551234567" in parts + + def test_collects_multiple_keys(self): + env = {"sourceNumber": "+15551234567", "sourceUuid": "uuid-abc"} + parts = SignalChannel._collect_sender_id_parts(env) + assert "+15551234567" in parts + assert "uuid-abc" in parts + + def test_deduplicates(self): + env = {"sourceNumber": "+15551234567", "source": "+15551234567"} + parts = SignalChannel._collect_sender_id_parts(env) + assert parts.count("+15551234567") == 1 + + def test_ignores_non_string_values(self): + env = {"sourceNumber": 12345, "sourceUuid": None} + parts = SignalChannel._collect_sender_id_parts(env) + assert parts == [] + + def test_empty_envelope_returns_empty(self): + assert SignalChannel._collect_sender_id_parts({}) == [] + + +class TestPrimarySenderId: + def test_prefers_phone_number(self): + assert SignalChannel._primary_sender_id(["+1234", "uuid-abc"]) == "+1234" + + def test_accepts_digit_only(self): + assert SignalChannel._primary_sender_id(["1234567890", "uuid-abc"]) == "1234567890" + + def test_falls_back_to_first_part(self): + assert SignalChannel._primary_sender_id(["uuid-abc", "other"]) == "uuid-abc" + + def test_empty_list_returns_empty(self): + assert SignalChannel._primary_sender_id([]) == "" + + +class TestExtractGroupId: + def test_extracts_from_group_info(self): + gid = SignalChannel._extract_group_id({"groupId": "abc=="}, None) + assert gid == "abc==" + + def test_extracts_from_group_v2(self): + gid = SignalChannel._extract_group_id(None, {"id": "xyz=="}) + assert gid == "xyz==" + + def test_prefers_group_info_over_v2(self): + gid = SignalChannel._extract_group_id({"groupId": "first"}, {"groupId": "second"}) + assert gid == "first" + + def test_returns_none_when_both_none(self): + assert SignalChannel._extract_group_id(None, None) is None + + def test_returns_none_when_not_dicts(self): + assert SignalChannel._extract_group_id("bad", 123) is None + + +class TestIsGroupChatId: + def test_base64_with_equals_is_group(self): + assert SignalChannel._is_group_chat_id("abc==") is True + + def test_long_id_without_dash_is_group(self): + long_id = "a" * 41 + assert SignalChannel._is_group_chat_id(long_id) is True + + def test_phone_number_is_not_group(self): + assert SignalChannel._is_group_chat_id("+12345678901") is False + + def test_uuid_with_dashes_is_not_group(self): + assert SignalChannel._is_group_chat_id("550e8400-e29b-41d4-a716-446655440000") is False + + +class TestRecipientParams: + def test_group_chat_uses_group_id(self): + ch = _make_channel() + params = ch._recipient_params("abc==") + assert params == {"groupId": "abc=="} + + def test_dm_uses_recipient_list(self): + ch = _make_channel() + params = ch._recipient_params("+12345678901") + assert params == {"recipient": ["+12345678901"]} + + +class TestMentionHelpers: + def test_mention_id_candidates_extracts_number(self): + mention = {"number": "+1234567890"} + ids = SignalChannel._mention_id_candidates(mention) + assert "+1234567890" in ids + + def test_mention_id_candidates_extracts_uuid(self): + mention = {"uuid": "some-uuid"} + ids = SignalChannel._mention_id_candidates(mention) + assert "some-uuid" in ids + + def test_mention_span_valid(self): + assert SignalChannel._mention_span({"start": 0, "length": 5}) == (0, 5) + + def test_mention_span_negative_start(self): + assert SignalChannel._mention_span({"start": -1, "length": 5}) is None + + def test_mention_span_zero_length(self): + assert SignalChannel._mention_span({"start": 0, "length": 0}) is None + + def test_mention_span_missing_keys(self): + assert SignalChannel._mention_span({}) is None + + def test_leading_placeholder_ufffc(self): + span = SignalChannel._leading_placeholder_span(" hello") + assert span == (0, 1) + + def test_leading_placeholder_not_at_start(self): + assert SignalChannel._leading_placeholder_span("hello ") is None + + def test_leading_placeholder_empty_string(self): + assert SignalChannel._leading_placeholder_span("") is None + + def test_leading_placeholder_plain_text(self): + assert SignalChannel._leading_placeholder_span("hello") is None + + +# --------------------------------------------------------------------------- +# Account ID alias / mention matching +# --------------------------------------------------------------------------- + + +class TestAccountIdAliases: + def test_phone_number_alias_registered_on_init(self): + ch = _make_channel(phone_number="+10000000000") + assert ch._id_matches_account("+10000000000") + + def test_digit_only_variant_matches(self): + ch = _make_channel(phone_number="+10000000000") + assert ch._id_matches_account("10000000000") + + def test_remember_alias_adds_uuid(self): + ch = _make_channel() + ch._remember_account_id_alias("some-uuid-abc") + assert ch._id_matches_account("some-uuid-abc") + + def test_non_matching_id_returns_false(self): + ch = _make_channel(phone_number="+10000000000") + assert not ch._id_matches_account("+19999999999") + + def test_none_and_non_string_return_false(self): + ch = _make_channel() + assert not ch._id_matches_account(None) + + +# --------------------------------------------------------------------------- +# _should_respond_in_group +# --------------------------------------------------------------------------- + + +class TestShouldRespondInGroup: + def _make_group_channel(self, require_mention: bool = True) -> SignalChannel: + return _make_channel( + phone_number="+10000000000", + group_enabled=True, + require_mention=require_mention, + ) + + def test_no_require_mention_always_responds(self): + ch = self._make_group_channel(require_mention=False) + assert ch._should_respond_in_group("anything", []) is True + + def test_require_mention_with_no_mentions_returns_false(self): + ch = self._make_group_channel(require_mention=True) + assert ch._should_respond_in_group("hello", []) is False + + def test_require_mention_with_bot_number_mention(self): + ch = self._make_group_channel(require_mention=True) + mentions = [{"number": "+10000000000", "start": 0, "length": 12}] + assert ch._should_respond_in_group(" hello", mentions) is True + + def test_require_mention_with_uuid_mention(self): + ch = self._make_group_channel(require_mention=True) + ch._remember_account_id_alias("bot-uuid-123") + mentions = [{"uuid": "bot-uuid-123", "start": 0, "length": 8}] + assert ch._should_respond_in_group(" hello", mentions) is True + + def test_identifier_less_leading_mention_accepted(self): + ch = self._make_group_channel(require_mention=True) + # Mention with no IDs but leading span — treated as bot mention + mentions = [{"start": 0, "length": 1}] + assert ch._should_respond_in_group(" hello", mentions) is True + + def test_identifier_less_non_leading_mention_rejected(self): + ch = self._make_group_channel(require_mention=True) + mentions = [{"start": 5, "length": 1}] + assert ch._should_respond_in_group("hello ", mentions) is False + + def test_leading_placeholder_without_mentions_metadata(self): + ch = self._make_group_channel(require_mention=True) + assert ch._should_respond_in_group(" hello", []) is True + + def test_phone_number_in_text_triggers_response(self): + ch = self._make_group_channel(require_mention=True) + assert ch._should_respond_in_group("hey +10000000000 help", []) is True + + +# --------------------------------------------------------------------------- +# _strip_bot_mention +# --------------------------------------------------------------------------- + + +class TestStripBotMention: + def _make_channel_with_number(self) -> SignalChannel: + return _make_channel(phone_number="+10000000000") + + def test_strips_mention_by_phone(self): + ch = self._make_channel_with_number() + text = " hello" + mentions = [{"number": "+10000000000", "start": 0, "length": 1}] + result = ch._strip_bot_mention(text, mentions) + assert result == "hello" + + def test_strips_identifier_less_leading_mention(self): + ch = self._make_channel_with_number() + text = " hello" + mentions = [{"start": 0, "length": 1}] + result = ch._strip_bot_mention(text, mentions) + assert result == "hello" + + def test_strips_leading_placeholder_without_mention_metadata(self): + ch = self._make_channel_with_number() + text = " hello" + result = ch._strip_bot_mention(text, []) + assert result == "hello" + + def test_non_bot_mention_mid_text_not_stripped(self): + # A non-bot mention that is NOT a leading placeholder leaves the text alone. + ch = self._make_channel_with_number() + text = "hello  world" + mentions = [{"number": "+19999999999", "start": 6, "length": 1}] + result = ch._strip_bot_mention(text, mentions) + # Mid-text placeholder from a non-bot mention should be untouched + assert "" in result + + def test_empty_text_returned_unchanged(self): + ch = self._make_channel_with_number() + assert ch._strip_bot_mention("", []) == "" + + +# --------------------------------------------------------------------------- +# Group message buffer +# --------------------------------------------------------------------------- + + +class TestGroupBuffer: + def test_add_and_get_context(self): + ch = _make_channel(group_buffer_size=5) + ch._add_to_group_buffer("g1", "Alice", "+1111", "first msg", 1000) + ch._add_to_group_buffer("g1", "Bob", "+2222", "second msg", 2000) + # Only messages before the latest are returned as context + ctx = ch._get_group_buffer_context("g1") + assert "first msg" in ctx + # The last message is not included (it's the "current" one) + assert "second msg" not in ctx + + def test_empty_context_when_only_one_message(self): + ch = _make_channel(group_buffer_size=5) + ch._add_to_group_buffer("g1", "Alice", "+1111", "only msg", 1000) + assert ch._get_group_buffer_context("g1") == "" + + def test_empty_context_when_group_unknown(self): + ch = _make_channel() + assert ch._get_group_buffer_context("unknown") == "" + + def test_buffer_respects_max_size(self): + ch = _make_channel(group_buffer_size=3) + for i in range(10): + ch._add_to_group_buffer("g1", "Alice", "+1111", f"msg{i}", i) + assert len(ch._group_buffers["g1"]) == 3 + + def test_zero_buffer_size_rejected_by_validator(self): + with pytest.raises(ValueError, match="group_message_buffer_size"): + _make_channel(group_buffer_size=0) + + def test_negative_buffer_size_rejected_by_validator(self): + with pytest.raises(ValueError, match="group_message_buffer_size"): + _make_channel(group_buffer_size=-1) + + def test_context_limits_message_length(self): + ch = _make_channel(group_buffer_size=5) + long_msg = "x" * 500 + ch._add_to_group_buffer("g1", "Alice", "+1111", long_msg, 1000) + ch._add_to_group_buffer("g1", "Bob", "+2222", "short", 2000) + ctx = ch._get_group_buffer_context("g1") + # Context is capped at 200 chars per message + assert len(ctx.split("Alice: ", 1)[1]) <= 200 + + +# --------------------------------------------------------------------------- +# _handle_data_message — DM routing +# --------------------------------------------------------------------------- + + +class TestIsAllowed: + """The base-channel allowlist gate is overridden to understand Signal's + pipe-joined composite sender_ids and the +/no-+ phone variants. + """ + + def test_denies_when_allowlist_empty(self): + ch = _make_channel(dm_enabled=True, dm_policy="allowlist") + assert ch.is_allowed("+19995550001") is False + + def test_denies_when_no_policy_allows(self): + """When both dm and group are disabled, is_allowed denies.""" + ch = _make_channel(dm_enabled=False, group_enabled=False) + assert ch.is_allowed("+19995550001") is False + + def test_allows_wildcard(self): + ch = _make_channel(dm_policy="allowlist", dm_allow_from=["*"]) + assert ch.is_allowed("+19995550001|some-uuid") is True + + def test_allows_composite_sender_against_split_allowlist(self): + """Composite sender_id, single-id allow_from — must match either part.""" + ch = _make_channel( + dm_policy="allowlist", + dm_allow_from=["+19995550001"], + ) + assert ch.is_allowed("+19995550001|1872ba20-uuid") is True + + def test_allows_composite_sender_against_composite_allowlist_entry(self): + """Backward compat: pipe-joined composite allowlist entries still match.""" + composite = "+19995550001|1872ba20-uuid" + ch = _make_channel(dm_policy="allowlist", dm_allow_from=[composite]) + assert ch.is_allowed(composite) is True + + def test_allows_when_only_uuid_part_is_listed(self): + ch = _make_channel(dm_policy="allowlist", dm_allow_from=["1872ba20-uuid"]) + assert ch.is_allowed("+19995550001|1872ba20-uuid") is True + + def test_denies_when_no_part_matches(self): + ch = _make_channel(dm_policy="allowlist", dm_allow_from=["+12223334444"]) + assert ch.is_allowed("+19995550001|1872ba20-uuid") is False + + def test_allowlist_union_includes_group_ids(self): + """allow_from is the union of dm.allow_from and group.allow_from.""" + ch = _make_channel( + group_enabled=True, + group_policy="allowlist", + group_allow_from=["group-id-base64=="], + ) + assert "group-id-base64==" in ch.config.allow_from + + +class TestEndToEndDMRouting: + """End-to-end tests that keep the real _handle_message chain (no mock), + verifying that _check_inbound_policy + _handle_message work together + correctly for DM routing. The override of _handle_message publishes + directly to bus (policy already checked); denied DMs call + super()._handle_message which issues a pairing code. + """ + + @pytest.mark.asyncio + async def test_open_dm_policy_publishes_to_bus(self): + """Open DM: _check_inbound_policy passes → _handle_message publishes.""" + ch = _make_channel(dm_enabled=True, dm_policy="open") + + async def noop_typing(chat_id): + pass + + ch._start_typing = noop_typing # type: ignore[method-assign] + published: list[InboundMessage] = [] + + async def capture_publish(msg: InboundMessage): + published.append(msg) + + ch.bus.publish_inbound = capture_publish # type: ignore[method-assign] + + params = _dm_envelope(source_number="+19995550001", message="hello") + await ch._handle_receive_notification(params) + + assert len(published) == 1 + assert published[0].content == "hello" + assert published[0].sender_id == "+19995550001" + + @pytest.mark.asyncio + async def test_allowlist_dm_denied_triggers_pairing(self): + """Allowlist DM: denied sender triggers pairing code via send().""" + ch = _make_channel(dm_enabled=True, dm_policy="allowlist", dm_allow_from=[]) + ch._http = _FakeHTTPClient() # type: ignore[assignment] + + async def noop_typing(chat_id): + pass + + ch._start_typing = noop_typing # type: ignore[method-assign] + published: list[InboundMessage] = [] + + async def capture_publish(msg: InboundMessage): + published.append(msg) + + ch.bus.publish_inbound = capture_publish # type: ignore[method-assign] + + params = _dm_envelope(source_number="+19995550002", message="hello") + await ch._handle_receive_notification(params) + + # Should NOT publish to bus — sender is not on allowlist. + assert published == [] + # Should have sent a pairing code via send (captured in HTTP posts). + assert len(ch._http.posts) == 1 # type: ignore[attr-defined] + sent_text = ch._http.posts[0]["json"]["params"]["message"] # type: ignore[attr-defined] + assert "pairing" in sent_text.lower() or "pair" in sent_text.lower() + + @pytest.mark.asyncio + async def test_allowlist_dm_denied_with_group_open_still_pairs(self): + """dm.policy="allowlist" + group.policy="open": denied DM sender + must still get a pairing code, not be leaked by the group open check.""" + ch = _make_channel( + dm_enabled=True, + dm_policy="allowlist", + dm_allow_from=[], + group_enabled=True, + group_policy="open", + ) + ch._http = _FakeHTTPClient() # type: ignore[assignment] + + async def noop_typing(chat_id): + pass + + ch._start_typing = noop_typing # type: ignore[method-assign] + published: list[InboundMessage] = [] + + async def capture_publish(msg: InboundMessage): + published.append(msg) + + ch.bus.publish_inbound = capture_publish # type: ignore[method-assign] + + params = _dm_envelope(source_number="+19995550002", message="hello") + await ch._handle_receive_notification(params) + + assert published == [] + assert len(ch._http.posts) == 1 # type: ignore[attr-defined] + + @pytest.mark.asyncio + async def test_open_group_policy_publishes_to_bus(self): + """Open group: group message from unknown sender publishes to bus.""" + ch = _make_channel( + group_enabled=True, + group_policy="open", + require_mention=False, + ) + + async def noop_typing(chat_id): + pass + + ch._start_typing = noop_typing # type: ignore[method-assign] + published: list[InboundMessage] = [] + + async def capture_publish(msg: InboundMessage): + published.append(msg) + + ch.bus.publish_inbound = capture_publish # type: ignore[method-assign] + + params = _group_envelope(group_id="grp==", message="hello group") + await ch._handle_receive_notification(params) + + assert len(published) == 1 + assert "hello group" in published[0].content + + +class TestCheckInboundPolicy: + """Direct tests for the policy gate that _handle_data_message now delegates to.""" + + def _call( + self, + ch: SignalChannel, + *, + sender_id: str = "+19995550001", + sender_number: str = "+19995550001", + group_id: str | None = None, + is_group_message: bool = False, + message_text: str = "hi", + mentions: list | None = None, + sender_name: str | None = "Alice", + timestamp: int | None = 1000, + ) -> tuple[bool, str]: + return ch._check_inbound_policy( + sender_id=sender_id, + sender_number=sender_number, + group_id=group_id, + is_group_message=is_group_message, + message_text=message_text, + mentions=mentions or [], + sender_name=sender_name, + timestamp=timestamp, + ) + + def test_dm_open_allows(self): + ch = _make_channel(dm_enabled=True, dm_policy="open") + allowed, chat_id = self._call(ch) + assert allowed is True + assert chat_id == "+19995550001" + + def test_dm_disabled_blocks(self): + ch = _make_channel(dm_enabled=False) + allowed, _ = self._call(ch) + assert allowed is False + + def test_dm_allowlist_blocks_unknown_sender(self): + ch = _make_channel(dm_policy="allowlist", dm_allow_from=["+12223334444"]) + allowed, _ = self._call(ch, sender_id="+19995550001") + assert allowed is False + + def test_dm_allowlist_allows_known_sender(self): + ch = _make_channel(dm_policy="allowlist", dm_allow_from=["+19995550001"]) + allowed, _ = self._call(ch, sender_id="+19995550001") + assert allowed is True + + def test_group_disabled_blocks(self): + ch = _make_channel(group_enabled=False) + allowed, _ = self._call(ch, is_group_message=True, group_id="g1") + assert allowed is False + + def test_group_open_with_mention_allows(self): + ch = _make_channel( + group_enabled=True, + group_policy="open", + phone_number="+10000000000", + require_mention=True, + ) + allowed, chat_id = self._call( + ch, + is_group_message=True, + group_id="g1", + message_text="hello @bot", + mentions=[{"number": "+10000000000", "start": 6, "length": 4}], + ) + assert allowed is True + assert chat_id == "g1" + + def test_group_open_without_mention_blocks(self): + ch = _make_channel(group_enabled=True, group_policy="open", require_mention=True) + allowed, _ = self._call(ch, is_group_message=True, group_id="g1", message_text="plain talk") + assert allowed is False + + def test_group_command_bypasses_mention_requirement(self): + ch = _make_channel(group_enabled=True, group_policy="open", require_mention=True) + allowed, _ = self._call(ch, is_group_message=True, group_id="g1", message_text="/help") + assert allowed is True + + def test_allowed_group_appends_to_buffer(self): + """Side effect: when a group message is allowed, it lands in the buffer.""" + ch = _make_channel(group_enabled=True, group_policy="open", require_mention=False) + self._call(ch, is_group_message=True, group_id="g1", message_text="first") + self._call(ch, is_group_message=True, group_id="g1", message_text="second") + assert len(ch._group_buffers["g1"]) == 2 + + def test_blocked_group_does_not_append_to_buffer(self): + """Side effect: when a group is disabled, the buffer must not change.""" + ch = _make_channel(group_enabled=False) + self._call(ch, is_group_message=True, group_id="g1", message_text="hi") + assert "g1" not in ch._group_buffers + + +class TestAttachmentsDir: + def test_default_attachments_dir(self): + ch = _make_channel() + expected = Path.home() / ".local/share/signal-cli/attachments" + assert ch._signal_attachments_dir() == expected + + def test_configured_attachments_dir(self, tmp_path): + ch = _make_channel(attachments_dir=str(tmp_path / "custom")) + assert ch._signal_attachments_dir() == tmp_path / "custom" + + def test_attachments_dir_expands_user(self): + ch = _make_channel(attachments_dir="~/signal-attachments") + assert ch._signal_attachments_dir() == Path.home() / "signal-attachments" + + +class TestHandleDataMessageDM: + def _make_dm_channel(self, policy="open", allow_from=None) -> tuple[SignalChannel, list]: + return _make_channel_with_capture( + dm_enabled=True, dm_policy=policy, dm_allow_from=allow_from or [] + ) + + @pytest.mark.asyncio + async def test_dm_open_policy_accepted(self): + ch, handled = self._make_dm_channel(policy="open") + params = _dm_envelope(source_number="+19995550001", message="hi") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + assert handled[0]["chat_id"] == "+19995550001" + assert handled[0]["content"] == "hi" + + @pytest.mark.asyncio + async def test_dm_allowlist_accepted(self): + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["+19995550001"]) + params = _dm_envelope(source_number="+19995550001") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_dm_allowlist_rejected_triggers_pairing(self): + # Denied DM senders go through super()._handle_message which checks + # is_allowed → sends pairing code via self.send(). + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["+10000000001"]) + ch._http = _FakeHTTPClient() # type: ignore[attr-defined] + params = _dm_envelope(source_number="+19995550002") + await ch._handle_receive_notification(params) + # The denied DM path calls super()._handle_message, not self._handle_message, + # so the capture list stays empty. Verify pairing code was sent via HTTP. + assert handled == [] + assert len(ch._http.posts) == 1 # type: ignore[attr-defined] + sent_text = ch._http.posts[0]["json"]["params"]["message"] # type: ignore[attr-defined] + assert "pairing" in sent_text.lower() or "pair" in sent_text.lower() + + @pytest.mark.asyncio + async def test_dm_paired_sender_allowed_without_allowlist_entry(self, monkeypatch): + # Once a sender completes pairing they should pass is_allowed on every + # subsequent message — otherwise the pairing reply loops forever. + approved = {"+19995550002"} + monkeypatch.setattr( + "nanobot.channels.signal.is_approved", + lambda channel, sender_id: sender_id in approved, + ) + ch = _make_channel(dm_enabled=True, dm_policy="allowlist", dm_allow_from=[]) + assert ch.is_allowed("+19995550002") is True + # Variant forms (with/without "+") must still match a stored approval. + assert ch.is_allowed("19995550002") is True + # Unpaired sender stays denied. + assert ch.is_allowed("+19995559999") is False + + @pytest.mark.asyncio + async def test_dm_allowlist_matches_without_plus_prefix(self): + """An allowlist entry without '+' must match a sender that carries '+'.""" + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["19995550001"]) + params = _dm_envelope(source_number="+19995550001") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_dm_allowlist_matches_with_plus_prefix(self): + """An allowlist entry with '+' must match a sender without '+'.""" + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["+19995550001"]) + params = _dm_envelope(source_number="+19995550001", source_uuid=None) + # Replace envelope's sourceNumber with the non-prefixed form by editing + # the constructed dict directly so _collect_sender_id_parts sees it. + params["envelope"]["sourceNumber"] = "19995550001" + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_dm_allowlist_matches_uuid_case_insensitive(self): + """UUID matching must be case-insensitive.""" + uuid = "ABCDEF12-3456-7890-ABCD-EF1234567890" + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=[uuid.lower()]) + params = _dm_envelope(source_number="+19995550001", source_uuid=uuid) + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_dm_allowlist_matches_pipe_joined_composite_entry(self): + """Allowlist entries written as ``phone|uuid`` composites still work. + + Some configs pre-date the per-part splitting and store the full + sender_id composite as a single allow_from entry. Keep matching it. + """ + composite = "+19995550001|1872ba20-f52a-4bad-b434-bf7f808c8b22" + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=[composite]) + params = _dm_envelope( + source_number="+19995550001", + source_uuid="1872ba20-f52a-4bad-b434-bf7f808c8b22", + ) + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_dm_disabled_rejected(self): + ch = _make_channel(dm_enabled=False) + handled: list[dict] = [] + + async def capture(**kwargs): + handled.append(kwargs) + + ch._handle_message = capture # type: ignore[method-assign] + + async def noop_typing(chat_id): + pass + + ch._start_typing = noop_typing # type: ignore[method-assign] + params = _dm_envelope(source_number="+19995550001") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_reaction_message_ignored(self): + ch, handled = self._make_dm_channel() + params = _dm_envelope(reaction={"emoji": "👍", "targetTimestamp": 999}) + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_empty_message_ignored(self): + ch, handled = self._make_dm_channel() + params = _dm_envelope(message="") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_receipt_message_ignored(self): + ch, handled = self._make_dm_channel() + notification = { + "envelope": { + "sourceNumber": "+19995550001", + "receiptMessage": {"when": 1234}, + } + } + await ch._handle_receive_notification(notification) + assert handled == [] + + @pytest.mark.asyncio + async def test_typing_indicator_ignored(self): + ch, handled = self._make_dm_channel() + notification = { + "envelope": { + "sourceNumber": "+19995550001", + "typingMessage": {"action": "STARTED"}, + } + } + await ch._handle_receive_notification(notification) + assert handled == [] + + @pytest.mark.asyncio + async def test_missing_envelope_ignored(self): + ch, handled = self._make_dm_channel() + await ch._handle_receive_notification({}) + assert handled == [] + + @pytest.mark.asyncio + async def test_metadata_passed_to_handle(self): + ch, handled = self._make_dm_channel() + params = _dm_envelope(source_number="+19995550001", source_name="Alice", timestamp=9999) + await ch._handle_receive_notification(params) + meta = handled[0]["metadata"] + assert meta["sender_name"] == "Alice" + assert meta["timestamp"] == 9999 + assert meta["is_group"] is False + + @pytest.mark.asyncio + async def test_sender_id_with_uuid_variant(self): + ch, handled = self._make_dm_channel() + params = _dm_envelope(source_number="+19995550001", source_uuid="uuid-abc") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + # sender_id combines both parts + assert "+19995550001" in handled[0]["sender_id"] + assert "uuid-abc" in handled[0]["sender_id"] + + @pytest.mark.asyncio + async def test_stop_typing_called_on_handle_error(self): + ch = _make_channel(dm_enabled=True, dm_policy="open") + typing_stopped: list[str] = [] + + async def fail_handle(**kwargs): + raise RuntimeError("boom") + + async def noop_typing(chat_id): + pass + + async def record_stop(chat_id, **kwargs): + typing_stopped.append(chat_id) + + ch._handle_message = fail_handle # type: ignore[method-assign] + ch._start_typing = noop_typing # type: ignore[method-assign] + ch._stop_typing = record_stop # type: ignore[method-assign] + + # _handle_receive_notification swallows exceptions; the typing stop + # still fires from _handle_data_message's except clause. + params = _dm_envelope(source_number="+19995550001") + await ch._handle_receive_notification(params) + + assert "+19995550001" in typing_stopped + + +# --------------------------------------------------------------------------- +# _handle_data_message — group routing +# --------------------------------------------------------------------------- + + +class TestHandleDataMessageGroup: + def _make_group_channel( + self, + policy="open", + allow_from=None, + require_mention=True, + ) -> tuple[SignalChannel, list]: + return _make_channel_with_capture( + group_enabled=True, + group_policy=policy, + group_allow_from=allow_from or [], + require_mention=require_mention, + ) + + @pytest.mark.asyncio + async def test_group_disabled_rejected(self): + ch = _make_channel(group_enabled=False) + handled: list[dict] = [] + ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign] + params = _group_envelope(group_id="grp==", message="hi") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_group_open_policy_no_mention_blocked_when_required(self): + ch, handled = self._make_group_channel(require_mention=True) + params = _group_envelope(group_id="grp==", message="hey everyone") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_group_open_policy_no_mention_required(self): + ch, handled = self._make_group_channel(require_mention=False) + params = _group_envelope(group_id="grp==", message="hey everyone") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + assert handled[0]["chat_id"] == "grp==" + + @pytest.mark.asyncio + async def test_group_allowlist_accepted(self): + ch, handled = self._make_group_channel( + policy="allowlist", allow_from=["grp=="], require_mention=False + ) + params = _group_envelope(group_id="grp==", message="hi") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_group_allowlist_rejected(self): + ch, handled = self._make_group_channel(policy="allowlist", allow_from=["other=="]) + params = _group_envelope(group_id="grp==", message="hi") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_group_mention_triggers_response(self): + ch, handled = self._make_group_channel(require_mention=True) + ch._remember_account_id_alias("+10000000000") + mentions = [{"number": "+10000000000", "start": 0, "length": 1}] + params = _group_envelope(group_id="grp==", message=" hello", mentions=mentions) + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_group_v2_id_extracted(self): + ch, handled = self._make_group_channel(require_mention=False) + params = _group_envelope(group_id="grpV2==", message="hi", use_v2=True) + await ch._handle_receive_notification(params) + assert len(handled) == 1 + assert handled[0]["chat_id"] == "grpV2==" + + @pytest.mark.asyncio + async def test_group_message_includes_sender_prefix(self): + ch, handled = self._make_group_channel(require_mention=False) + params = _group_envelope(group_id="grp==", source_name="Bob", message="hello") + await ch._handle_receive_notification(params) + assert "[Bob]:" in handled[0]["content"] + + @pytest.mark.asyncio + async def test_group_message_context_prepended(self): + ch, handled = self._make_group_channel(require_mention=False) + # First message — adds to buffer but no context yet + params1 = _group_envelope(group_id="grp==", source_name="Alice", message="msg1") + await ch._handle_receive_notification(params1) + # Second message — should include context from first + params2 = _group_envelope(group_id="grp==", source_name="Bob", message="msg2") + await ch._handle_receive_notification(params2) + assert "[Recent group messages for context:]" in handled[1]["content"] + assert "msg1" in handled[1]["content"] + + @pytest.mark.asyncio + async def test_group_metadata_marks_is_group(self): + ch, handled = self._make_group_channel(require_mention=False) + params = _group_envelope(group_id="grp==", message="hi") + await ch._handle_receive_notification(params) + assert handled[0]["metadata"]["is_group"] is True + assert handled[0]["metadata"]["group_id"] == "grp==" + + @pytest.mark.asyncio + async def test_bot_account_alias_learned_from_incoming(self): + ch, handled = self._make_group_channel(require_mention=False) + # If the bot's own UUID appears in an envelope we learn it + params = _dm_envelope(source_number="+10000000000", source_uuid="new-bot-uuid") + # DMs from self are processed (learning alias), but DM policy is open + ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign] + ch._start_typing = lambda chat_id: None # type: ignore[method-assign] + await ch._handle_receive_notification(params) + assert ch._id_matches_account("new-bot-uuid") + + +# --------------------------------------------------------------------------- +# Lifecycle / SSE +# --------------------------------------------------------------------------- + + +class _FakeSSEResponse: + """Minimal stand-in for httpx Response under stream().""" + + def __init__(self, lines: list[str], status_code: int = 200) -> None: + self.status_code = status_code + self._lines = lines + + async def aiter_lines(self): + for line in self._lines: + yield line + + +def _fake_streaming_client(lines: list[str], *, status_code: int = 200) -> MagicMock: + """Return an httpx.AsyncClient stand-in whose .stream() yields a FakeSSEResponse.""" + response = _FakeSSEResponse(lines, status_code=status_code) + + @asynccontextmanager + async def _ctx(*_args, **_kwargs): + yield response + + http = MagicMock() + http.stream = lambda *a, **kw: _ctx(*a, **kw) + return http + + +class TestLifecycle: + @pytest.mark.asyncio + async def test_start_returns_early_when_phone_missing(self): + """start() with an empty phone number must not enter the HTTP loop.""" + ch = _make_channel(phone_number="") + await ch.start() + assert ch._running is False + assert ch._http is None + assert ch._sse_task is None + + +class TestSSEReceiveLoop: + @pytest.mark.asyncio + async def test_dispatches_valid_envelope(self): + ch = _make_channel() + ch._running = True + + captured: list[dict] = [] + + async def capture(params): + captured.append(params) + + ch._handle_receive_notification = capture # type: ignore[method-assign] + ch._http = _fake_streaming_client( + ['data: {"envelope":{"sourceNumber":"+19995550001"}}', ""] + ) + + # Loop ends when lines exhaust; the surrounding _start_http_mode would + # treat that as a disconnect, but the loop itself raises ConnectionError + # when the stream closes while still running. + with pytest.raises(ConnectionError): + await ch._sse_receive_loop() + assert captured == [{"envelope": {"sourceNumber": "+19995550001"}}] + + @pytest.mark.asyncio + async def test_handles_invalid_json_frame(self): + """An unparseable SSE frame is logged and skipped without crashing.""" + ch = _make_channel() + ch._running = True + + captured: list[dict] = [] + + async def capture(params): + captured.append(params) + + ch._handle_receive_notification = capture # type: ignore[method-assign] + ch._http = _fake_streaming_client( + [ + "data: this-is-not-json", + "", # event boundary triggers parse attempt + 'data: {"envelope":{"sourceNumber":"+1"}}', + "", + ] + ) + + with pytest.raises(ConnectionError): + await ch._sse_receive_loop() + # Bad frame skipped; good frame still dispatched. + assert captured == [{"envelope": {"sourceNumber": "+1"}}] + + @pytest.mark.asyncio + async def test_non_200_status_raises(self): + ch = _make_channel() + ch._running = True + ch._http = _fake_streaming_client([], status_code=503) + with pytest.raises(ConnectionError, match="status 503"): + await ch._sse_receive_loop() + + @pytest.mark.asyncio + async def test_no_http_client_raises(self): + ch = _make_channel() + ch._http = None + with pytest.raises(RuntimeError, match="HTTP client not initialized"): + await ch._sse_receive_loop() + + +# --------------------------------------------------------------------------- +# Command handling +# --------------------------------------------------------------------------- + + +class TestCommandHandling: + @pytest.mark.asyncio + async def test_dm_command_forwarded_to_bus(self): + """Slash commands in DMs are forwarded to the bus for AgentLoop to handle.""" + ch, forwarded = _make_channel_with_capture(dm_enabled=True, dm_policy="open") + params = _dm_envelope(source_number="+19995550001", message="/reset") + await ch._handle_receive_notification(params) + assert len(forwarded) == 1 + assert forwarded[0]["content"].strip() == "/reset" + + @pytest.mark.asyncio + async def test_group_command_bypasses_mention_requirement(self): + """Slash commands in groups bypass the mention requirement and reach the bus.""" + ch, forwarded = _make_channel_with_capture( + group_enabled=True, group_policy="open", require_mention=True + ) + params = _group_envelope(source_number="+19995550001", group_id="grp==", message="/reset") + await ch._handle_receive_notification(params) + assert len(forwarded) == 1 + assert "/reset" in forwarded[0]["content"] + + @pytest.mark.asyncio + async def test_command_denied_for_disallowed_dm_sender(self): + """Commands from senders not on the DM allowlist are dropped.""" + ch, forwarded = _make_channel_with_capture(dm_enabled=False) + params = _dm_envelope(source_number="+19995550001", message="/reset") + await ch._handle_receive_notification(params) + assert forwarded == [] + + +# --------------------------------------------------------------------------- +# send() — outbound messages +# --------------------------------------------------------------------------- + + +class TestSend: + def _make_send_channel(self) -> tuple[SignalChannel, _FakeHTTPClient]: + ch = _make_channel() + client = _FakeHTTPClient() + ch._http = client # type: ignore[assignment] + return ch, client + + @pytest.mark.asyncio + async def test_send_plain_text_posts_rpc(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="hello") + await ch.send(msg) + assert len(client.posts) == 1 + payload = client.posts[0]["json"] + assert payload["method"] == "send" + assert payload["params"]["message"] == "hello" + + @pytest.mark.asyncio + async def test_send_with_markdown_includes_text_styles(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="**bold**") + await ch.send(msg) + params = client.posts[0]["json"]["params"] + assert "textStyle" in params + assert any("BOLD" in s for s in params["textStyle"]) + + @pytest.mark.asyncio + async def test_send_split_message_redistributes_text_styles(self): + """Long message split across chunks: each chunk gets its own textStyle + with offsets rebased to that chunk.""" + ch, client = self._make_send_channel() + ch._MAX_MESSAGE_LEN = 12 # type: ignore[attr-defined] + msg = OutboundMessage( + channel="signal", + chat_id="+19995550001", + content="**head** middle and **tail**", + ) + await ch.send(msg) + assert len(client.posts) >= 2 + # Chunk 0 has BOLD for "head"; chunk 1+ must also carry BOLD for "tail". + bold_chunks = [ + p["json"]["params"] + for p in client.posts + if any("BOLD" in s for s in p["json"]["params"].get("textStyle", [])) + ] + assert len(bold_chunks) >= 2, ( + "expected BOLD ranges in more than one chunk; got " + f"{[p['json']['params'] for p in client.posts]}" + ) + # Each emitted range must point inside its own chunk's text. + for params in bold_chunks: + chunk_text = params["message"] + for entry in params["textStyle"]: + s, ln, _ = entry.split(":", 2) + start, length = int(s), int(ln) + end_units = start + length + assert end_units <= len(chunk_text.encode("utf-16-le")) // 2 + + @pytest.mark.asyncio + async def test_send_empty_content_skips_rpc(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="") + await ch.send(msg) + assert client.posts == [] + + @pytest.mark.asyncio + async def test_send_to_group_uses_group_id(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="grp==", content="hi group") + await ch.send(msg) + params = client.posts[0]["json"]["params"] + assert "groupId" in params + assert "recipient" not in params + + @pytest.mark.asyncio + async def test_send_to_dm_uses_recipient(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="hi") + await ch.send(msg) + params = client.posts[0]["json"]["params"] + assert "recipient" in params + + @pytest.mark.asyncio + async def test_send_with_media_includes_attachments(self): + ch, client = self._make_send_channel() + msg = OutboundMessage( + channel="signal", + chat_id="+19995550001", + content="see attachment", + media=["/tmp/file.jpg"], + ) + await ch.send(msg) + params = client.posts[0]["json"]["params"] + assert params.get("attachments") == ["/tmp/file.jpg"] + + @pytest.mark.asyncio + async def test_send_progress_message_does_not_stop_typing(self): + ch, client = self._make_send_channel() + stopped: list[str] = [] + + async def record_stop(chat_id, **kwargs): + stopped.append(chat_id) + + ch._stop_typing = record_stop # type: ignore[method-assign] + msg = OutboundMessage( + channel="signal", + chat_id="+19995550001", + content="working...", + metadata={"_progress": True}, + ) + await ch.send(msg) + # Progress messages should NOT stop the typing indicator + assert stopped == [] + + @pytest.mark.asyncio + async def test_send_final_message_stops_typing(self): + ch, client = self._make_send_channel() + stopped: list[str] = [] + + async def record_stop(chat_id, send_stop=True): + stopped.append(chat_id) + + ch._stop_typing = record_stop # type: ignore[method-assign] + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="done") + await ch.send(msg) + assert "+19995550001" in stopped + + @pytest.mark.asyncio + async def test_send_raises_on_daemon_error(self): + # _send_http_request turns every exception into {"error": ...}, so this branch + # is the only place ChannelManager retry can be triggered — must raise. + ch = _make_channel() + ch._http = _FakeHTTPClient(default_response={"error": {"message": "fail"}}) + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="hello") + with pytest.raises(RuntimeError, match="signal-cli send failed"): + await ch.send(msg) + + +# --------------------------------------------------------------------------- +# stop() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_stop_cancels_sse_task() -> None: + ch = _make_channel() + cancelled = False + + async def long_running(): + nonlocal cancelled + try: + await asyncio.sleep(9999) + except asyncio.CancelledError: + cancelled = True + raise + + ch._sse_task = asyncio.create_task(long_running()) + # Yield so the task can enter its body (reach the first await) before cancel. + await asyncio.sleep(0) + ch._running = True + + await ch.stop() + + assert cancelled + assert ch._running is False + + +@pytest.mark.asyncio +async def test_stop_closes_http_client() -> None: + ch = _make_channel() + client = _FakeHTTPClient() + ch._http = client # type: ignore[assignment] + ch._running = True + + await ch.stop() + + assert client.closed + + +@pytest.mark.asyncio +async def test_stop_safe_when_no_sse_task() -> None: + ch = _make_channel() + ch._running = True + # Should not raise even with no _sse_task + await ch.stop() + assert ch._running is False + + +# --------------------------------------------------------------------------- +# _send_request / _send_http_request +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_send_request_increments_id() -> None: + ch = _make_channel() + client = _FakeHTTPClient() + ch._http = client # type: ignore[assignment] + + await ch._send_request("testMethod", {"key": "val"}) + await ch._send_request("testMethod", {"key": "val"}) + + ids = [p["json"]["id"] for p in client.posts] + assert ids == [1, 2] + + +@pytest.mark.asyncio +async def test_send_request_raises_when_not_connected() -> None: + ch = _make_channel() + # _http is None by default + with pytest.raises(RuntimeError, match="Not connected"): + await ch._send_request("testMethod") + + +# --------------------------------------------------------------------------- +# _handle_receive_notification — envelope shapes +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_handle_notification_sync_message_does_not_forward() -> None: + ch = _make_channel(dm_enabled=True, dm_policy="open") + handled: list[dict] = [] + ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign] + + notification = { + "envelope": { + "sourceNumber": "+19995550001", + "syncMessage": { + "sentMessage": { + "destination": "+19990000000", + "message": "sent from other device", + } + }, + } + } + await ch._handle_receive_notification(notification) + assert handled == [] + + +@pytest.mark.asyncio +async def test_handle_notification_no_source_skipped() -> None: + ch = _make_channel(dm_enabled=True, dm_policy="open") + handled: list[dict] = [] + ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign] + + notification = {"envelope": {"dataMessage": {"message": "ghost"}}} + await ch._handle_receive_notification(notification) + assert handled == [] + + +# --------------------------------------------------------------------------- +# Config: allow_from property aggregation +# --------------------------------------------------------------------------- + + +def test_config_allow_from_aggregates_dm_and_group() -> None: + config = SignalConfig( + enabled=True, + phone_number="+10000000000", + dm=SignalDMConfig(enabled=True, policy="allowlist", allow_from=["+1111", "+2222"]), + group=SignalGroupConfig(enabled=True, policy="allowlist", allow_from=["+3333", "+1111"]), + ) + combined = config.allow_from + assert "+1111" in combined + assert "+2222" in combined + assert "+3333" in combined + # Duplicates removed + assert combined.count("+1111") == 1 + + +def test_config_allow_from_wildcard_propagates() -> None: + config = SignalConfig( + enabled=True, + phone_number="+10000000000", + dm=SignalDMConfig(enabled=True, policy="open", allow_from=["*"]), + group=SignalGroupConfig(enabled=True, policy="open", allow_from=[]), + ) + assert "*" in config.allow_from diff --git a/tests/channels/test_signal_markdown.py b/tests/channels/test_signal_markdown.py new file mode 100644 index 000000000..37a21c6d8 --- /dev/null +++ b/tests/channels/test_signal_markdown.py @@ -0,0 +1,525 @@ +"""Unit tests for the Signal markdown → plain text + textStyle converter.""" + +from nanobot.channels.signal import _markdown_to_signal, _partition_styles +from nanobot.utils.helpers import split_message + + +def _utf16_len(s: str) -> int: + return len(s.encode("utf-16-le")) // 2 + + +def styles_for(plain: str, text_styles: list[str]) -> dict[str, list[str]]: + """Return a dict mapping each styled substring to its style list.""" + result: dict[str, list[str]] = {} + for entry in text_styles: + start_s, length_s, style = entry.split(":", 2) + start, length = int(start_s), int(length_s) + span = plain[start : start + length] + result.setdefault(span, []).append(style) + return result + + +def utf16_styles_for(plain: str, text_styles: list[str]) -> dict[str, list[str]]: + """Like styles_for, but slices `plain` using UTF-16 offsets (Signal's units).""" + encoded = plain.encode("utf-16-le") + result: dict[str, list[str]] = {} + for entry in text_styles: + start_s, length_s, style = entry.split(":", 2) + start, length = int(start_s), int(length_s) + span = encoded[start * 2 : (start + length) * 2].decode("utf-16-le") + result.setdefault(span, []).append(style) + return result + + +# --------------------------------------------------------------------------- +# Basic cases +# --------------------------------------------------------------------------- + + +def test_empty(): + plain, styles = _markdown_to_signal("") + assert plain == "" + assert styles == [] + + +def test_plain_text(): + plain, styles = _markdown_to_signal("hello world") + assert plain == "hello world" + assert styles == [] + + +def test_bold_stars(): + plain, styles = _markdown_to_signal("say **hello** now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["BOLD"]} + + +def test_bold_underscores(): + plain, styles = _markdown_to_signal("say __hello__ now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["BOLD"]} + + +def test_italic_star(): + plain, styles = _markdown_to_signal("say *hello* now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["ITALIC"]} + + +def test_italic_underscore(): + plain, styles = _markdown_to_signal("say _hello_ now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["ITALIC"]} + + +def test_strikethrough(): + plain, styles = _markdown_to_signal("say ~~hello~~ now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["STRIKETHROUGH"]} + + +# --------------------------------------------------------------------------- +# Code +# --------------------------------------------------------------------------- + + +def test_inline_code(): + plain, styles = _markdown_to_signal("run `ls -la` here") + assert plain == "run ls -la here" + assert styles_for(plain, styles) == {"ls -la": ["MONOSPACE"]} + + +def test_code_block(): + plain, styles = _markdown_to_signal("```\nprint('hi')\n```") + assert "print('hi')" in plain + assert styles_for(plain, styles).get("print('hi')\n") == ["MONOSPACE"] or "MONOSPACE" in str( + styles_for(plain, styles) + ) + + +def test_code_block_with_lang(): + plain, styles = _markdown_to_signal("```python\ncode\n```") + assert "code" in plain + assert any("MONOSPACE" in s for s in styles) + + +def test_code_block_not_processed_further(): + """Markdown inside a code block must not be styled.""" + plain, styles = _markdown_to_signal("```\n**not bold**\n```") + assert "**not bold**" in plain + # Only MONOSPACE should be applied, no BOLD + for entry in styles: + assert "BOLD" not in entry + + +def test_inline_code_not_processed_further(): + """Markdown inside inline code must not be styled.""" + plain, styles = _markdown_to_signal("use `**raw**` please") + assert "**raw**" in plain + for entry in styles: + assert "BOLD" not in entry + + +# --------------------------------------------------------------------------- +# Headers +# --------------------------------------------------------------------------- + + +def test_header_becomes_bold(): + plain, styles = _markdown_to_signal("# My Title") + assert plain == "My Title" + assert styles_for(plain, styles) == {"My Title": ["BOLD"]} + + +def test_h2_becomes_bold(): + plain, styles = _markdown_to_signal("## Sub-section") + assert plain == "Sub-section" + assert styles_for(plain, styles) == {"Sub-section": ["BOLD"]} + + +# --------------------------------------------------------------------------- +# Blockquotes +# --------------------------------------------------------------------------- + + +def test_blockquote_strips_marker(): + plain, styles = _markdown_to_signal("> some quote") + assert plain == "some quote" + assert styles == [] + + +# --------------------------------------------------------------------------- +# Lists +# --------------------------------------------------------------------------- + + +def test_bullet_dash(): + plain, styles = _markdown_to_signal("- item one") + assert plain == "• item one" + + +def test_bullet_star(): + plain, styles = _markdown_to_signal("* item two") + assert plain == "• item two" + + +def test_numbered_list(): + plain, styles = _markdown_to_signal("1. first\n2. second") + assert "1. first" in plain + assert "2. second" in plain + + +# --------------------------------------------------------------------------- +# Links +# --------------------------------------------------------------------------- + + +def test_link_text_differs_from_url(): + plain, styles = _markdown_to_signal("[Click here](https://example.com)") + assert plain == "Click here (https://example.com)" + assert styles == [] + + +def test_link_text_equals_url(): + plain, styles = _markdown_to_signal("[https://example.com](https://example.com)") + assert plain == "https://example.com" + assert styles == [] + + +def test_link_text_equals_url_without_scheme(): + plain, styles = _markdown_to_signal("[example.com](https://example.com)") + assert plain == "https://example.com" + + +# --------------------------------------------------------------------------- +# Mixed / nesting +# --------------------------------------------------------------------------- + + +def test_bold_and_italic_adjacent(): + plain, styles = _markdown_to_signal("**bold** and *italic*") + assert plain == "bold and italic" + sd = styles_for(plain, styles) + assert sd.get("bold") == ["BOLD"] + assert sd.get("italic") == ["ITALIC"] + + +def test_header_with_inline_code(): + """Header becomes BOLD; code inside becomes MONOSPACE (not double-BOLD).""" + plain, styles = _markdown_to_signal("# Use `grep`") + assert plain == "Use grep" + sd = styles_for(plain, styles) + assert "BOLD" in sd.get("Use ", []) or "BOLD" in str(styles) + assert "MONOSPACE" in sd.get("grep", []) + + +def test_multiline_mixed(): + md = "**Title**\n\nSome *italic* text.\n\n- bullet\n- another" + plain, styles = _markdown_to_signal(md) + assert "Title" in plain + assert "italic" in plain + assert "• bullet" in plain + sd = styles_for(plain, styles) + assert "BOLD" in sd.get("Title", []) + assert "ITALIC" in sd.get("italic", []) + + +# --------------------------------------------------------------------------- +# Table rendering +# --------------------------------------------------------------------------- + + +def test_table_rendered_as_monospace(): + md = "| A | B |\n| - | - |\n| 1 | 2 |" + plain, styles = _markdown_to_signal(md) + assert "A" in plain and "B" in plain + assert any("MONOSPACE" in s for s in styles) + + +# --------------------------------------------------------------------------- +# Style range format +# --------------------------------------------------------------------------- + + +def test_style_range_format(): + """Each style entry must be 'start:length:STYLE'.""" + _, styles = _markdown_to_signal("**bold** text") + for entry in styles: + parts = entry.split(":") + assert len(parts) == 3 + assert parts[0].isdigit() + assert parts[1].isdigit() + assert parts[2] in {"BOLD", "ITALIC", "STRIKETHROUGH", "MONOSPACE", "SPOILER"} + + +def test_style_ranges_are_within_bounds(): + text = "hello **world** end" + plain, styles = _markdown_to_signal(text) + for entry in styles: + start_s, length_s, _ = entry.split(":", 2) + start, length = int(start_s), int(length_s) + assert start >= 0 + assert start + length <= len(plain) + + +# --------------------------------------------------------------------------- +# Non-BMP / UTF-16 offsets +# +# Signal's BodyRange (and signal-cli's textStyle) interprets start/length in +# UTF-16 code units. Python's len() counts code points, so characters outside +# the BMP (emojis, supplementary CJK) shift offsets by +1 per occurrence. +# --------------------------------------------------------------------------- + + +def assert_within_utf16_bounds(plain: str, styles: list[str]) -> None: + limit = _utf16_len(plain) + for entry in styles: + start_s, length_s, _ = entry.split(":", 2) + start, length = int(start_s), int(length_s) + assert start >= 0 + assert start + length <= limit, f"range {entry} exceeds utf-16 length {limit} of {plain!r}" + + +def test_bold_with_emoji_inside(): + plain, styles = _markdown_to_signal("**hi 🎉 bye**") + assert plain == "hi 🎉 bye" + assert utf16_styles_for(plain, styles) == {"hi 🎉 bye": ["BOLD"]} + assert_within_utf16_bounds(plain, styles) + + +def test_italic_with_trailing_emoji(): + plain, styles = _markdown_to_signal("*bye 🎉*") + assert plain == "bye 🎉" + assert utf16_styles_for(plain, styles) == {"bye 🎉": ["ITALIC"]} + assert_within_utf16_bounds(plain, styles) + + +def test_bold_after_emoji_prefix(): + plain, styles = _markdown_to_signal("🎉 **bold**") + assert plain == "🎉 bold" + assert utf16_styles_for(plain, styles) == {"bold": ["BOLD"]} + assert_within_utf16_bounds(plain, styles) + + +def test_bold_after_and_inside_emoji(): + plain, styles = _markdown_to_signal("🎉 **a 🎊 b**") + assert plain == "🎉 a 🎊 b" + assert utf16_styles_for(plain, styles) == {"a 🎊 b": ["BOLD"]} + assert_within_utf16_bounds(plain, styles) + + +def test_supplementary_cjk_in_bold(): + """Non-BMP CJK (U+20BB7) proves the bug is UTF-16, not emoji-specific.""" + plain, styles = _markdown_to_signal("**𠮷野家**") + assert plain == "𠮷野家" + assert utf16_styles_for(plain, styles) == {"𠮷野家": ["BOLD"]} + assert_within_utf16_bounds(plain, styles) + + +def test_zwj_emoji_in_bold(): + """ZWJ family sequence = multiple surrogate pairs + BMP ZWJs.""" + plain, styles = _markdown_to_signal("**hi 👨‍👩‍👧 bye**") + assert plain == "hi 👨‍👩‍👧 bye" + assert utf16_styles_for(plain, styles) == {"hi 👨‍👩‍👧 bye": ["BOLD"]} + assert_within_utf16_bounds(plain, styles) + + +def test_ascii_offsets_unchanged(): + """ASCII-only path must produce the same offsets as before the UTF-16 fix.""" + plain, styles = _markdown_to_signal("**bold** plain *it*") + assert plain == "bold plain it" + assert sorted(styles) == sorted(["0:4:BOLD", "11:2:ITALIC"]) + + +def test_reported_daily_brief_pattern(): + """Regression for the reported bug: a single non-BMP emoji shifts every + subsequent styled span left by 1 UTF-16 unit, lopping off the last letter. + """ + md = ( + "**Weather**\n" + "- Conditions: 🌩️ Thunderstorms\n\n" + "**News**\n" + "*World*\n" + "*Local*\n\n" + "**Quote of the Day**" + ) + plain, styles = _markdown_to_signal(md) + sd = utf16_styles_for(plain, styles) + assert sd.get("Weather") == ["BOLD"] + assert sd.get("News") == ["BOLD"] + assert sd.get("World") == ["ITALIC"] + assert sd.get("Local") == ["ITALIC"] + assert sd.get("Quote of the Day") == ["BOLD"] + assert_within_utf16_bounds(plain, styles) + + +# --------------------------------------------------------------------------- +# Chunk redistribution +# +# split_message can break a long Signal payload into multiple chunks. The +# style ranges from _markdown_to_signal are anchored to the full text, so +# they must be redistributed per-chunk with rebased offsets — otherwise +# styles for chunks 1..N are silently lost. +# --------------------------------------------------------------------------- + + +def _resolve_chunk_styles(text: str, max_len: int) -> tuple[list[str], list[list[str]]]: + """Helper: full markdown → signal pipeline, including chunking.""" + plain, styles = _markdown_to_signal(text) + chunks = split_message(plain, max_len) if plain else [""] + return chunks, _partition_styles(plain, chunks, styles) + + +def test_partition_styles_single_chunk_passthrough(): + plain, styles = _markdown_to_signal("**bold** plain *it*") + parts = _partition_styles(plain, [plain], styles) + assert parts == [styles] + + +def test_partition_styles_no_styles(): + plain = "hello world" + assert _partition_styles(plain, [plain], []) == [[]] + assert _partition_styles(plain, ["hello", "world"], []) == [[], []] + + +def test_partition_styles_drops_styles_outside_chunks(): + """Whitespace trimmed by split_message must not carry a style range.""" + plain = "a b" + # Fake a style spanning the trimmed whitespace only. + chunks = ["a", "b"] + parts = _partition_styles(plain, chunks, ["1:3:BOLD"]) + assert parts == [[], []] + + +def test_partition_styles_long_message_preserves_chunk_one_styles(): + """A bold span deep in the message must follow the message into chunk 1.""" + # Two ~30-char paragraphs separated by a blank line, then **tail**. + line_a = "alpha " * 5 # 30 chars, ends with space + line_b = "beta " * 5 + md = f"{line_a.strip()}\n\n{line_b.strip()}\n\n**tail**" + plain, styles = _markdown_to_signal(md) + # Force a split between the paragraphs. + max_len = len(line_a.strip()) + 2 # fits paragraph A + the "\n\n" + chunks = split_message(plain, max_len) + assert len(chunks) >= 2, "test setup must produce a split" + parts = _partition_styles(plain, chunks, styles) + # The bold "tail" should land in the last chunk, with chunk-relative offset. + final_chunk = chunks[-1] + final_styles = parts[-1] + assert any("BOLD" in s for s in final_styles) + for entry in final_styles: + s, ln, _ = entry.split(":", 2) + start, length = int(s), int(ln) + slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode( + "utf-16-le" + ) + assert slice_ == "tail" + + +def test_partition_styles_chunk_zero_styles_unchanged(): + """Styles entirely in chunk 0 keep their original offsets.""" + md = "**head** middle and **tail**" + plain, styles = _markdown_to_signal(md) + # Split so chunk 0 contains "head" and part of the rest, chunk 1 contains "tail". + chunks = split_message(plain, 12) + assert len(chunks) >= 2 + parts = _partition_styles(plain, chunks, styles) + # "head" lives in chunk 0; assert its offset is unchanged (chunk 0 starts at 0). + head_entries = [s for s in parts[0] if "BOLD" in s] + assert any(s.startswith("0:4:") for s in head_entries) + + +def test_partition_styles_with_non_bmp_chunk_offset(): + """Chunk-start offsets must be expressed in UTF-16 code units.""" + # Emoji in chunk 0, bold in chunk 1. + md = "🎉 alpha beta gamma\n\n**tail**" + plain, styles = _markdown_to_signal(md) + chunks = split_message(plain, 18) + assert len(chunks) >= 2 + parts = _partition_styles(plain, chunks, styles) + final_styles = parts[-1] + assert any("BOLD" in s for s in final_styles) + final_chunk = chunks[-1] + for entry in final_styles: + s, ln, _ = entry.split(":", 2) + start, length = int(s), int(ln) + slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode( + "utf-16-le" + ) + assert slice_ == "tail" + + +def test_partition_styles_range_spanning_chunks_is_split(): + """A style range that straddles a chunk boundary gets sliced into both chunks.""" + # Construct manually: plain = "abc def", style covers "abc def" (whole thing). + plain = "abc def" + chunks = split_message(plain, 4) # "abc" / "def" + assert chunks == ["abc", "def"] + parts = _partition_styles(plain, chunks, ["0:7:BOLD"]) + # Chunk 0 holds 0:3:BOLD, chunk 1 holds 0:3:BOLD (length=3 each, "def" only + # since the space was trimmed by lstrip). + assert parts[0] == ["0:3:BOLD"] + assert parts[1] == ["0:3:BOLD"] + + +# --------------------------------------------------------------------------- +# Adjacency, nesting, and malformed input +# --------------------------------------------------------------------------- + + +def test_bold_italic_combo_outer_bold_inner_italic(): + """`**_combo_**` carries both BOLD and ITALIC over the same span.""" + plain, styles = _markdown_to_signal("**_combo_**") + assert plain == "combo" + sd = styles_for(plain, styles) + assert set(sd.get("combo", [])) == {"BOLD", "ITALIC"} + + +def test_bold_and_italic_adjacent_no_separator(): + """`**bold***italic*` produces BOLD on `bold` and ITALIC on `italic`.""" + plain, styles = _markdown_to_signal("**bold***italic*") + assert plain == "bolditalic" + sd = styles_for(plain, styles) + assert sd.get("bold") == ["BOLD"] + assert sd.get("italic") == ["ITALIC"] + + +def test_unclosed_bold_falls_through_as_plain(): + """An unmatched `**` opener round-trips as literal text with no style.""" + plain, styles = _markdown_to_signal("**bold") + assert plain == "**bold" + assert styles == [] + + +def test_unclosed_inline_code_falls_through_as_plain(): + """An unmatched backtick round-trips as literal text with no style.""" + plain, styles = _markdown_to_signal("use `grep") + assert plain == "use `grep" + assert styles == [] + + +def test_inline_code_inside_blockquote(): + """Blockquote prefix is stripped; inline code becomes MONOSPACE.""" + plain, styles = _markdown_to_signal("> use `grep`") + assert plain == "use grep" + sd = styles_for(plain, styles) + assert sd.get("grep") == ["MONOSPACE"] + + +def test_header_with_inner_bold_produces_contiguous_bold_ranges(): + """`# **wrap** me` — header forces BOLD over the whole line; the inner `**` + splits the run, yielding two contiguous BOLD ranges that together cover + "wrap me". This is intentional — Signal renders adjacent same-style ranges + as a single visual span. + """ + plain, styles = _markdown_to_signal("# **wrap** me") + assert plain == "wrap me" + # Both ranges are BOLD; collectively they cover the whole "wrap me". + bold_ranges = [s for s in styles if s.endswith(":BOLD")] + assert len(bold_ranges) == 2 + covered = set() + for entry in bold_ranges: + start, length, _ = entry.split(":", 2) + for i in range(int(start), int(start) + int(length)): + covered.add(i) + assert covered == set(range(len(plain))) diff --git a/tests/channels/test_websocket_channel.py b/tests/channels/test_websocket_channel.py index cc011a244..74a780c80 100644 --- a/tests/channels/test_websocket_channel.py +++ b/tests/channels/test_websocket_channel.py @@ -1055,6 +1055,7 @@ async def test_settings_api_returns_safe_subset_and_updates_whitelist( } assert image_providers["openrouter"]["label"] == "OpenRouter" assert image_providers["openrouter"]["configured"] is False + assert image_providers["openai_codex"]["configured"] is True assert image_providers["gemini"]["label"] == "Gemini" assert body["runtime"]["config_path"] == str(config_path) workspace_path = body["runtime"]["workspace_path"].replace("\\", "/") diff --git a/tests/channels/test_weixin_channel.py b/tests/channels/test_weixin_channel.py index a695ba936..3d3606e75 100644 --- a/tests/channels/test_weixin_channel.py +++ b/tests/channels/test_weixin_channel.py @@ -1,6 +1,7 @@ import asyncio import json import tempfile +import time from pathlib import Path from types import SimpleNamespace from unittest.mock import AsyncMock @@ -374,6 +375,7 @@ async def test_send_uses_typing_start_and_cancel_when_ticket_available() -> None channel._client = object() channel._token = "token" channel._context_tokens["wx-user"] = "ctx-typing" + channel._context_token_at["wx-user"] = time.time() channel._send_text = AsyncMock() channel._api_post = AsyncMock( side_effect=[ @@ -402,6 +404,7 @@ async def test_send_still_sends_text_when_typing_ticket_missing() -> None: channel._client = object() channel._token = "token" channel._context_tokens["wx-user"] = "ctx-no-ticket" + channel._context_token_at["wx-user"] = time.time() channel._send_text = AsyncMock() channel._api_post = AsyncMock(return_value={"ret": 1, "errmsg": "no config"}) @@ -1254,3 +1257,526 @@ async def test_send_text_succeeds_on_zero_errcode() -> None: await channel._send_text("wx-user", "hello", "ctx-ok") channel._api_post.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_text_raises_on_nonzero_ret_even_when_errcode_zero() -> None: + """_send_text must raise when the API returns ret != 0, even if errcode is 0. + + The iLink API signals failure through either field. Checking only errcode + caused silent message drops (responses generated but never delivered). + """ + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._api_post = AsyncMock( + return_value={"ret": -100, "errcode": 0, "errmsg": "internal error"} + ) + + with pytest.raises(RuntimeError, match="WeChat send text error.*ret=-100.*errcode=0"): + await channel._send_text("wx-user", "hello", "ctx-ok") + + channel._api_post.assert_awaited_once() + + +# --------------------------------------------------------------------------- +# Tests for _poll_once not silently dropping messages on processing errors +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_poll_once_logs_exception_on_process_message_failure(monkeypatch) -> None: + """When _process_message raises, _poll_once must log the error and continue + processing remaining messages instead of silently swallowing the exception.""" + channel, _bus = _make_channel() + channel._client = SimpleNamespace(timeout=None) + channel._token = "token" + channel._get_updates_buf = "old-buf" + + calls = [] + logged_messages: list[str] = [] + + async def _failing_process(msg: dict) -> None: + calls.append(msg.get("message_id")) + if msg.get("message_id") == "msg-1": + raise RuntimeError("processing failed") + + channel._process_message = _failing_process # type: ignore[method-assign] + + monkeypatch.setattr( + channel.logger, + "exception", + lambda message, *args, **kwargs: logged_messages.append(str(message)), + ) + + channel._api_post = AsyncMock( # type: ignore[method-assign] + return_value={ + "ret": 0, + "errcode": 0, + "get_updates_buf": "new-buf", + "msgs": [ + {"message_id": "msg-1", "message_type": 1}, + {"message_id": "msg-2", "message_type": 1}, + ], + } + ) + + await channel._poll_once() + + # Both messages should have been attempted + assert calls == ["msg-1", "msg-2"] + # Buffer should still advance (already updated before processing) + assert channel._get_updates_buf == "new-buf" + # Error should be logged + assert any("Failed to process WeChat message" in m for m in logged_messages) + + +@pytest.mark.asyncio +async def test_poll_loop_logs_exception_and_continues_on_poll_failure(monkeypatch) -> None: + """When _poll_once raises a non-timeout exception, the start() loop must log + the error and continue polling instead of exiting silently.""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel.config.token = "token" # skip QR login in start() + channel._running = True + + call_count = 0 + logged_messages: list[str] = [] + + async def _failing_poll() -> None: + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RuntimeError("poll exploded") + channel._running = False # Stop after second call + + channel._poll_once = _failing_poll # type: ignore[method-assign] + + monkeypatch.setattr( + channel.logger, + "exception", + lambda message, *args, **kwargs: logged_messages.append(str(message)), + ) + + # Use a tiny retry delay so the test finishes quickly + original_retry = weixin_mod.RETRY_DELAY_S + weixin_mod.RETRY_DELAY_S = 0.01 + try: + await channel.start() + finally: + weixin_mod.RETRY_DELAY_S = original_retry + + assert call_count == 2 + assert any("WeChat poll loop error" in m for m in logged_messages) + + +# --------------------------------------------------------------------------- +# Tool-hint buffering +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_buffer_single_tool_hint_not_sent_immediately() -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel.send_tool_hints = True + channel._context_tokens["wx-user"] = "ctx-1" + channel._context_token_at["wx-user"] = time.time() + channel._send_text = AsyncMock() + + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "Using tool", + "media": [], + "metadata": {"_progress": True, "_tool_hint": True}, + }, + )() + ) + + channel._send_text.assert_not_awaited() + assert channel._pending_tool_hints["wx-user"] == ["Using tool"] + + +@pytest.mark.asyncio +async def test_buffer_multiple_tool_hints_flushed_on_final_answer() -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel.send_tool_hints = True + channel._context_tokens["wx-user"] = "ctx-1" + channel._context_token_at["wx-user"] = time.time() + channel._send_text = AsyncMock() + + for hint in ["tool1", "tool2"]: + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": hint, + "media": [], + "metadata": {"_progress": True, "_tool_hint": True}, + }, + )() + ) + + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "Done", + "media": [], + "metadata": {}, + }, + )() + ) + + assert channel._send_text.await_count == 2 + channel._send_text.assert_any_await("wx-user", "tool1\n\ntool2", "ctx-1") + channel._send_text.assert_any_await("wx-user", "Done", "ctx-1") + assert "wx-user" not in channel._pending_tool_hints + + +@pytest.mark.asyncio +async def test_thought_progress_flushes_tool_hints() -> None: + """Thoughts are visible progress messages and must act as separators, + flushing buffered tool hints before they are sent.""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel.send_tool_hints = True + channel._context_tokens["wx-user"] = "ctx-1" + channel._context_token_at["wx-user"] = time.time() + channel._send_text = AsyncMock() + + # Buffer a tool hint + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "search 'foo'", + "media": [], + "metadata": {"_progress": True, "_tool_hint": True}, + }, + )() + ) + + # Send a thought — progress but not a tool_hint. + # It must act as a separator and flush the buffered hint. + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "Let me think...", + "media": [], + "metadata": {"_progress": True}, + }, + )() + ) + + # The buffered hint was flushed before the thought was sent. + channel._send_text.assert_any_await("wx-user", "search 'foo'", "ctx-1") + channel._send_text.assert_any_await("wx-user", "Let me think...", "ctx-1") + assert "wx-user" not in channel._pending_tool_hints + + # Final answer arrives with nothing left to flush. + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "Done", + "media": [], + "metadata": {}, + }, + )() + ) + + assert channel._send_text.await_count == 3 + channel._send_text.assert_any_await("wx-user", "Done", "ctx-1") + + +@pytest.mark.asyncio +async def test_reasoning_delta_does_not_flush_tool_hints() -> None: + """Reasoning deltas are invisible in WeChat and must NOT flush buffered + tool hints — otherwise hints separated only by hidden reasoning would + fail to coalesce.""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel.send_tool_hints = True + channel._context_tokens["wx-user"] = "ctx-1" + channel._context_token_at["wx-user"] = time.time() + channel._send_text = AsyncMock() + + # Buffer a tool hint + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "search 'foo'", + "media": [], + "metadata": {"_progress": True, "_tool_hint": True}, + }, + )() + ) + + # Send a reasoning delta — invisible in WeChat, must NOT flush + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "Thinking step 1...", + "media": [], + "metadata": {"_progress": True, "_reasoning_delta": True}, + }, + )() + ) + + # Reasoning is invisible; hint stays buffered, _send_text not called + channel._send_text.assert_not_awaited() + assert channel._pending_tool_hints["wx-user"] == ["search 'foo'"] + + # Final answer flushes the buffered hint + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "Done", + "media": [], + "metadata": {}, + }, + )() + ) + + channel._send_text.assert_any_await("wx-user", "search 'foo'", "ctx-1") + channel._send_text.assert_any_await("wx-user", "Done", "ctx-1") + assert "wx-user" not in channel._pending_tool_hints + + +@pytest.mark.asyncio +async def test_empty_progress_message_does_not_flush_tool_hints() -> None: + """Empty progress messages (e.g. after_iteration tool_events) have no + visible content and must NOT act as separators.""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel.send_tool_hints = True + channel._context_tokens["wx-user"] = "ctx-1" + channel._context_token_at["wx-user"] = time.time() + channel._send_text = AsyncMock() + + # Buffer a tool hint + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "search 'foo'", + "media": [], + "metadata": {"_progress": True, "_tool_hint": True}, + }, + )() + ) + + # Send an empty progress message (no content, no media) + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "", + "media": [], + "metadata": {"_progress": True, "_tool_events": [{"phase": "end"}]}, + }, + )() + ) + + # Nothing should have been sent yet + channel._send_text.assert_not_awaited() + assert channel._pending_tool_hints["wx-user"] == ["search 'foo'"] + + # Final answer flushes the buffered hint + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "Done", + "media": [], + "metadata": {}, + }, + )() + ) + + channel._send_text.assert_any_await("wx-user", "search 'foo'", "ctx-1") + channel._send_text.assert_any_await("wx-user", "Done", "ctx-1") + assert "wx-user" not in channel._pending_tool_hints + + +@pytest.mark.asyncio +async def test_buffer_flush_refreshes_context_token() -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel.send_tool_hints = True + channel._context_tokens["wx-user"] = "ctx-old" + channel._context_token_at["wx-user"] = time.time() + channel._refresh_context_token_if_stale = AsyncMock(return_value="ctx-refreshed") + channel._send_text = AsyncMock() + + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "hint", + "media": [], + "metadata": {"_progress": True, "_tool_hint": True}, + }, + )() + ) + + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "Done", + "media": [], + "metadata": {}, + }, + )() + ) + + assert channel._refresh_context_token_if_stale.await_count == 2 + channel._refresh_context_token_if_stale.assert_any_await("wx-user", "ctx-old") + channel._send_text.assert_any_await("wx-user", "hint", "ctx-refreshed") + + +@pytest.mark.asyncio +async def test_buffer_flush_failure_does_not_block_final_answer() -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel.send_tool_hints = True + channel._context_tokens["wx-user"] = "ctx-1" + channel._context_token_at["wx-user"] = time.time() + channel._send_text = AsyncMock(side_effect=[RuntimeError("boom"), None]) + + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "hint", + "media": [], + "metadata": {"_progress": True, "_tool_hint": True}, + }, + )() + ) + + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "Done", + "media": [], + "metadata": {}, + }, + )() + ) + + assert channel._send_text.await_count == 2 + channel._send_text.assert_any_await("wx-user", "hint", "ctx-1") + channel._send_text.assert_any_await("wx-user", "Done", "ctx-1") + + +@pytest.mark.asyncio +async def test_buffer_flushed_on_stream_end() -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel.send_tool_hints = True + channel._context_tokens["wx-user"] = "ctx-1" + channel._context_token_at["wx-user"] = time.time() + channel._send_text = AsyncMock() + + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "hint", + "media": [], + "metadata": {"_progress": True, "_tool_hint": True}, + }, + )() + ) + + await channel.send_delta("wx-user", "", {"_stream_end": True}) + + channel._send_text.assert_awaited_once_with("wx-user", "hint", "ctx-1") + assert "wx-user" not in channel._pending_tool_hints + + +@pytest.mark.asyncio +async def test_stop_clears_buffer() -> None: + channel, _bus = _make_channel() + channel._pending_tool_hints["wx-user"] = ["hint1", "hint2"] + await channel.stop() + assert "wx-user" not in channel._pending_tool_hints + + +@pytest.mark.asyncio +async def test_send_tool_hints_false_drops_tool_hints() -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel.send_tool_hints = False + channel._send_text = AsyncMock() + + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "hint", + "media": [], + "metadata": {"_progress": True, "_tool_hint": True}, + }, + )() + ) + + channel._send_text.assert_not_awaited() + assert "wx-user" not in channel._pending_tool_hints diff --git a/tests/config/test_model_presets.py b/tests/config/test_model_presets.py index 046c5b04d..fe01c2547 100644 --- a/tests/config/test_model_presets.py +++ b/tests/config/test_model_presets.py @@ -192,3 +192,20 @@ def test_match_provider_uses_preset_provider_when_forced() -> None: }) name = config.get_provider_name() assert name == "anthropic" + + +def test_match_provider_routes_forced_novita_model_api_models() -> None: + config = Config.model_validate({ + "providers": { + "novita": {"apiKey": "sk-test"}, + }, + "agents": { + "defaults": { + "model": "deepseek-v4-pro", + "provider": "novita", + } + }, + }) + + assert config.get_provider_name() == "novita" + assert config.get_api_base() == "https://api.novita.ai/openai" diff --git a/tests/providers/test_custom_provider.py b/tests/providers/test_custom_provider.py index 85314dc79..ee1f9a090 100644 --- a/tests/providers/test_custom_provider.py +++ b/tests/providers/test_custom_provider.py @@ -56,6 +56,35 @@ def test_custom_provider_parse_chunks_accepts_plain_text_chunks() -> None: assert result.content == "hello world" +def test_custom_provider_parse_chunks_deduplicates_parallel_tool_call_ids() -> None: + chunks = [{ + "choices": [{ + "finish_reason": "tool_calls", + "delta": { + "tool_calls": [ + { + "index": 0, + "id": "call_dup", + "function": {"name": "read_file", "arguments": '{"path":"a.txt"}'}, + }, + { + "index": 1, + "id": "call_dup", + "function": {"name": "read_file", "arguments": '{"path":"b.txt"}'}, + }, + ], + }, + }], + }] + + result = OpenAICompatProvider._parse_chunks(chunks) + ids = [tool_call.id for tool_call in result.tool_calls or []] + + assert ids[0] == "call_dup" + assert len(ids) == 2 + assert len(set(ids)) == 2 + + def test_local_provider_502_error_includes_reachability_hint() -> None: spec = find_by_name("ollama") with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): diff --git a/tests/providers/test_image_generation.py b/tests/providers/test_image_generation.py index 701f09f0a..f3ca1459c 100644 --- a/tests/providers/test_image_generation.py +++ b/tests/providers/test_image_generation.py @@ -9,11 +9,13 @@ import pytest from nanobot.providers.image_generation import ( AIHubMixImageGenerationClient, + CodexImageGenerationClient, GeminiImageGenerationClient, GeneratedImageResponse, ImageGenerationError, MiniMaxImageGenerationClient, OllamaImageGenerationClient, + OpenAIImageGenerationClient, OpenRouterImageGenerationClient, StepFunImageGenerationClient, ) @@ -37,12 +39,14 @@ class FakeResponse: payload: dict[str, Any], status_code: int = 200, content: bytes = b"", + sse_lines: list[str] | None = None, ) -> None: self._payload = payload self.status_code = status_code self.text = str(payload) self.content = content self.request = httpx.Request("POST", "https://openrouter.ai/api/v1/chat/completions") + self._sse_lines = sse_lines def json(self) -> dict[str, Any]: return self._payload @@ -52,6 +56,15 @@ class FakeResponse: response = httpx.Response(self.status_code, request=self.request, text=self.text) raise httpx.HTTPStatusError("failed", request=self.request, response=response) + async def aiter_lines(self): + if self._sse_lines is not None: + for line in self._sse_lines: + yield line + return + # Fallback: treat response text as SSE lines + for line in self.text.split("\n"): + yield line + class FakeClient: def __init__(self, response: FakeResponse) -> None: @@ -564,3 +577,437 @@ async def test_stepfun_no_images_raises() -> None: with pytest.raises(ImageGenerationError, match="returned no images"): await client.generate(prompt="draw", model="step-image-edit-2") + + +# --------------------------------------------------------------------------- +# OpenAI +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_openai_payload_and_response() -> None: + fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]})) + client = OpenAIImageGenerationClient( + api_key="sk-openai-test", + api_base="https://api.openai.com/v1", + extra_headers={"X-Test": "1"}, + client=fake, # type: ignore[arg-type] + ) + + response = await client.generate( + prompt="a cat on the moon", + model="dall-e-3", + aspect_ratio="16:9", + ) + + assert response.images == [PNG_DATA_URL] + call = fake.calls[0] + assert call["url"] == "https://api.openai.com/v1/images/generations" + assert call["headers"]["Authorization"] == "Bearer sk-openai-test" + assert call["headers"]["X-Test"] == "1" + body = call["json"] + assert body["model"] == "dall-e-3" + assert body["prompt"] == "a cat on the moon" + assert body["response_format"] == "b64_json" + assert body["n"] == 1 + assert body["size"] == "1792x1024" + + +@pytest.mark.asyncio +async def test_openai_b64_json_response_uses_detected_mime() -> None: + raw_b64 = base64.b64encode(JPEG_BYTES).decode("ascii") + fake = FakeClient(FakeResponse({"data": [{"b64_json": raw_b64}]})) + client = OpenAIImageGenerationClient( + api_key="sk-openai-test", + client=fake, # type: ignore[arg-type] + ) + + response = await client.generate(prompt="draw", model="dall-e-3") + + assert response.images == [f"data:image/jpeg;base64,{raw_b64}"] + + +@pytest.mark.asyncio +async def test_openai_url_download_fallback() -> None: + fake = FakeClient(FakeResponse({"data": [{"url": "https://cdn.example/image.png"}]})) + fake.get_response = FakeResponse({}, content=PNG_BYTES) + client = OpenAIImageGenerationClient( + api_key="sk-openai-test", + client=fake, # type: ignore[arg-type] + ) + + response = await client.generate(prompt="draw", model="dall-e-3") + + assert response.images[0].startswith("data:image/png;base64,") + assert fake.get_calls[0]["url"] == "https://cdn.example/image.png" + + +@pytest.mark.asyncio +async def test_openai_multiple_images() -> None: + fake = FakeClient(FakeResponse({ + "data": [ + {"b64_json": RAW_B64}, + {"b64_json": RAW_B64}, + ] + })) + client = OpenAIImageGenerationClient( + api_key="sk-openai-test", + client=fake, # type: ignore[arg-type] + ) + + response = await client.generate(prompt="draw", model="dall-e-3") + + assert len(response.images) == 2 + assert response.images == [PNG_DATA_URL, PNG_DATA_URL] + + +@pytest.mark.asyncio +async def test_openai_aspect_ratio_to_size() -> None: + fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]})) + client = OpenAIImageGenerationClient( + api_key="sk-openai-test", + client=fake, # type: ignore[arg-type] + ) + + await client.generate(prompt="draw", model="dall-e-3", aspect_ratio="1:1") + assert fake.calls[0]["json"]["size"] == "1024x1024" + + +@pytest.mark.asyncio +async def test_openai_dalle3_uses_supported_orientation_sizes() -> None: + fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]})) + client = OpenAIImageGenerationClient( + api_key="sk-openai-test", + client=fake, # type: ignore[arg-type] + ) + + await client.generate(prompt="draw", model="dall-e-3", aspect_ratio="3:4") + await client.generate(prompt="draw", model="dall-e-3", aspect_ratio="4:3") + + assert fake.calls[0]["json"]["size"] == "1024x1792" + assert fake.calls[1]["json"]["size"] == "1792x1024" + + +@pytest.mark.asyncio +async def test_openai_dalle2_uses_square_size_for_non_square_ratios() -> None: + fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]})) + client = OpenAIImageGenerationClient( + api_key="sk-openai-test", + client=fake, # type: ignore[arg-type] + ) + + await client.generate(prompt="draw", model="dall-e-2", aspect_ratio="16:9") + + assert fake.calls[0]["json"]["size"] == "1024x1024" + + +@pytest.mark.asyncio +async def test_openai_gpt_image_uses_supported_landscape_size() -> None: + fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]})) + client = OpenAIImageGenerationClient( + api_key="sk-openai-test", + client=fake, # type: ignore[arg-type] + ) + + await client.generate(prompt="draw", model="gpt-image-1", aspect_ratio="16:9") + + assert fake.calls[0]["json"]["size"] == "1536x1024" + + +@pytest.mark.asyncio +async def test_openai_gpt_image_uses_supported_orientation_sizes() -> None: + fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]})) + client = OpenAIImageGenerationClient( + api_key="sk-openai-test", + client=fake, # type: ignore[arg-type] + ) + + await client.generate(prompt="draw", model="gpt-image-1", aspect_ratio="3:4") + await client.generate(prompt="draw", model="gpt-image-1", aspect_ratio="4:3") + + assert fake.calls[0]["json"]["size"] == "1024x1536" + assert fake.calls[1]["json"]["size"] == "1536x1024" + + +@pytest.mark.asyncio +async def test_openai_default_size_when_no_aspect_ratio() -> None: + fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]})) + client = OpenAIImageGenerationClient( + api_key="sk-openai-test", + client=fake, # type: ignore[arg-type] + ) + + await client.generate(prompt="draw", model="dall-e-3") + + body = fake.calls[0]["json"] + assert body["size"] == "1024x1024" + + +@pytest.mark.asyncio +async def test_openai_ignores_explicit_size_unsupported_by_model_family() -> None: + fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]})) + client = OpenAIImageGenerationClient( + api_key="sk-openai-test", + client=fake, # type: ignore[arg-type] + ) + + await client.generate( + prompt="draw", + model="dall-e-3", + aspect_ratio="16:9", + image_size="1536x1024", + ) + + body = fake.calls[0]["json"] + assert body["size"] == "1792x1024" + + +@pytest.mark.asyncio +async def test_openai_uses_explicit_image_size() -> None: + fake = FakeClient(FakeResponse({"data": [{"b64_json": RAW_B64}]})) + client = OpenAIImageGenerationClient( + api_key="sk-openai-test", + client=fake, # type: ignore[arg-type] + ) + + await client.generate( + prompt="draw", + model="dall-e-3", + aspect_ratio="16:9", + image_size="1024x1024", + ) + + body = fake.calls[0]["json"] + assert body["size"] == "1024x1024" + + +@pytest.mark.asyncio +async def test_openai_requires_api_key() -> None: + client = OpenAIImageGenerationClient(api_key=None) + + with pytest.raises(ImageGenerationError, match="API key"): + await client.generate(prompt="draw", model="dall-e-3") + + +# --------------------------------------------------------------------------- +# OpenAI Codex (Responses API) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_codex_payload_and_response(monkeypatch) -> None: + import sys + from dataclasses import dataclass + from types import SimpleNamespace + + @dataclass + class FakeToken: + account_id: str = "acct-123" + access: str = "oauth-token" + + async def fake_to_thread(fn, *args, **kwargs): + return fn(*args, **kwargs) + + monkeypatch.setattr("asyncio.to_thread", fake_to_thread) + fake_oauth = SimpleNamespace(get_token=lambda: FakeToken()) + monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth) + + sse_lines = [ + 'data: {"type":"response.output_item.added","item":{"id":"ig_1","type":"image_generation_call","status":"in_progress"}}', + "", + f'data: {{"type":"response.output_item.done","item":{{"id":"ig_1","type":"image_generation_call","result":"{PNG_DATA_URL}","status":"completed"}}}}', + "", + 'data: [DONE]', + "", + ] + fake = FakeClient(FakeResponse({}, sse_lines=sse_lines)) + client = CodexImageGenerationClient( + api_key=None, + api_base="https://chatgpt.com/backend-api", + extra_headers={"X-Test": "1"}, + client=fake, # type: ignore[arg-type] + ) + + response = await client.generate( + prompt="draw a cat", + model="gpt-5.4", + ) + + assert response.images == [PNG_DATA_URL] + assert response.content == "" + call = fake.calls[0] + assert call["url"] == "https://chatgpt.com/backend-api/codex/responses" + assert call["headers"]["Authorization"] == "Bearer oauth-token" + assert call["headers"]["chatgpt-account-id"] == "acct-123" + assert call["headers"]["OpenAI-Beta"] == "responses=experimental" + assert call["headers"]["X-Test"] == "1" + body = call["json"] + assert body["model"] == "gpt-5.4" + assert body["instructions"] == "Generate an image based on the user's request." + assert body["input"] == [{"role": "user", "content": "draw a cat"}] + assert body["tools"] == [{"type": "image_generation"}] + assert body["tool_choice"] == "auto" + assert body["store"] is False + assert body["stream"] is True + + +@pytest.mark.asyncio +async def test_codex_strips_model_prefix(monkeypatch) -> None: + import sys + from dataclasses import dataclass + from types import SimpleNamespace + + @dataclass + class FakeToken: + account_id: str = "acct-123" + access: str = "oauth-token" + + async def fake_to_thread(fn, *args, **kwargs): + return fn(*args, **kwargs) + + monkeypatch.setattr("asyncio.to_thread", fake_to_thread) + fake_oauth = SimpleNamespace(get_token=lambda: FakeToken()) + monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth) + + fake = FakeClient(FakeResponse({}, sse_lines=[ + f'data: {{"type":"response.output_item.done","item":{{"type":"image_generation_call","result":"{PNG_DATA_URL}"}}}}', + "", + 'data: [DONE]', + "", + ])) + client = CodexImageGenerationClient( + api_key=None, client=fake # type: ignore[arg-type] + ) + + await client.generate(prompt="draw", model="openai-codex/gpt-5.4") + + assert fake.calls[0]["json"]["model"] == "gpt-5.4" + + +@pytest.mark.asyncio +async def test_codex_requires_oauth(monkeypatch) -> None: + async def fake_to_thread(fn, *args, **kwargs): + raise RuntimeError("no token") + + monkeypatch.setattr("asyncio.to_thread", fake_to_thread) + + client = CodexImageGenerationClient(api_key=None) + + with pytest.raises(ImageGenerationError, match="OAuth token"): + await client.generate(prompt="draw", model="gpt-5.4") + + +@pytest.mark.asyncio +async def test_codex_no_images_raises(monkeypatch) -> None: + import sys + from dataclasses import dataclass + from types import SimpleNamespace + + @dataclass + class FakeToken: + account_id: str = "acct-123" + access: str = "oauth-token" + + async def fake_to_thread(fn, *args, **kwargs): + return fn(*args, **kwargs) + + monkeypatch.setattr("asyncio.to_thread", fake_to_thread) + fake_oauth = SimpleNamespace(get_token=lambda: FakeToken()) + monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth) + + fake = FakeClient(FakeResponse({}, sse_lines=[ + 'data: {"type":"response.completed","response":{"status":"completed"}}', + "", + 'data: [DONE]', + "", + ])) + client = CodexImageGenerationClient( + api_key=None, client=fake # type: ignore[arg-type] + ) + + with pytest.raises(ImageGenerationError, match="returned no images"): + await client.generate(prompt="draw", model="gpt-5.4") + + +@pytest.mark.asyncio +async def test_codex_extracts_text_content(monkeypatch) -> None: + import sys + from dataclasses import dataclass + from types import SimpleNamespace + + @dataclass + class FakeToken: + account_id: str = "acct-123" + access: str = "oauth-token" + + async def fake_to_thread(fn, *args, **kwargs): + return fn(*args, **kwargs) + + monkeypatch.setattr("asyncio.to_thread", fake_to_thread) + fake_oauth = SimpleNamespace(get_token=lambda: FakeToken()) + monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth) + + fake = FakeClient(FakeResponse({}, sse_lines=[ + 'data: {"type":"response.output_text.delta","delta":"Here "}', + "", + 'data: {"type":"response.output_text.delta","delta":"is your cat image."}', + "", + f'data: {{"type":"response.output_item.done","item":{{"type":"image_generation_call","result":"{PNG_DATA_URL}"}}}}', + "", + 'data: [DONE]', + "", + ])) + client = CodexImageGenerationClient( + api_key=None, client=fake # type: ignore[arg-type] + ) + + response = await client.generate(prompt="draw a cat", model="gpt-5.4") + + assert response.images == [PNG_DATA_URL] + assert response.content == "Here is your cat image." + + +@pytest.mark.asyncio +async def test_codex_json_result_format(monkeypatch) -> None: + """image_generation_call result can be a dict with image_url key.""" + import sys + from dataclasses import dataclass + from types import SimpleNamespace + + @dataclass + class FakeToken: + account_id: str = "acct-123" + access: str = "oauth-token" + + async def fake_to_thread(fn, *args, **kwargs): + return fn(*args, **kwargs) + + monkeypatch.setattr("asyncio.to_thread", fake_to_thread) + fake_oauth = SimpleNamespace(get_token=lambda: FakeToken()) + monkeypatch.setitem(sys.modules, "oauth_cli_kit", fake_oauth) + + fake = FakeClient(FakeResponse({}, sse_lines=[ + f'data: {{"type":"response.output_item.done","item":{{"type":"image_generation_call","result":{{"image_url":"{PNG_DATA_URL}"}}}}}}', + "", + 'data: [DONE]', + "", + ])) + client = CodexImageGenerationClient( + api_key=None, client=fake # type: ignore[arg-type] + ) + + response = await client.generate(prompt="draw", model="gpt-5.4") + + assert response.images == [PNG_DATA_URL] + + +@pytest.mark.asyncio +async def test_openai_no_images_raises() -> None: + fake = FakeClient(FakeResponse({"data": []})) + client = OpenAIImageGenerationClient( + api_key="sk-openai-test", + client=fake, # type: ignore[arg-type] + ) + + with pytest.raises(ImageGenerationError, match="returned no images"): + await client.generate(prompt="draw", model="dall-e-3") diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py index 5f2ffec59..924ee0060 100644 --- a/tests/providers/test_litellm_kwargs.py +++ b/tests/providers/test_litellm_kwargs.py @@ -441,6 +441,15 @@ def test_openrouter_spec_is_gateway() -> None: assert spec.default_api_base == "https://openrouter.ai/api/v1" +def test_novita_spec_uses_openai_compatible_gateway() -> None: + spec = find_by_name("novita") + assert spec is not None + assert spec.is_gateway is True + assert spec.backend == "openai_compat" + assert spec.env_key == "NOVITA_API_KEY" + assert spec.default_api_base == "https://api.novita.ai/openai" + + def test_gemma_routes_to_gemini_provider() -> None: """gemma models (e.g. gemma-3-27b-it) must auto-route to Gemini when GEMINI_API_KEY is set. Users running gemma via the Gemini API endpoint expect automatic provider detection.""" @@ -1007,6 +1016,41 @@ def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() - assert sanitized[2]["tool_call_id"] == "3ec83c30d" +def test_openai_compat_deduplicates_duplicate_tool_call_ids_in_history() -> None: + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() + + sanitized = provider._sanitize_messages([ + {"role": "user", "content": "check both files"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "ab1b45c2a", + "type": "function", + "function": {"name": "read_file", "arguments": '{"path":"a.txt"}'}, + }, + { + "id": "ab1b45c2a", + "type": "function", + "function": {"name": "read_file", "arguments": '{"path":"b.txt"}'}, + }, + ], + }, + {"role": "tool", "tool_call_id": "ab1b45c2a", "name": "read_file", "content": "a"}, + {"role": "tool", "tool_call_id": "ab1b45c2a", "name": "read_file", "content": "b"}, + {"role": "user", "content": "continue"}, + ]) + + tool_call_ids = [tc["id"] for tc in sanitized[1]["tool_calls"]] + tool_result_ids = [sanitized[2]["tool_call_id"], sanitized[3]["tool_call_id"]] + + assert tool_call_ids[0] == "ab1b45c2a" + assert len(tool_call_ids) == len(set(tool_call_ids)) == 2 + assert tool_result_ids == tool_call_ids + + def test_openai_compat_stringifies_dict_tool_arguments() -> None: with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): provider = OpenAICompatProvider() @@ -1376,12 +1420,15 @@ def test_kimi_k25_thinking_enabled() -> None: """kimi-k2.5 with reasoning_effort set should opt in to thinking.""" kw = _build_kwargs_for("moonshot", "kimi-k2.5", reasoning_effort="medium") assert kw.get("extra_body") == {"thinking": {"type": "enabled"}} + # Moonshot rejects both 'reasoning_effort' and 'thinking' (#3939) + assert "reasoning_effort" not in kw def test_kimi_k25_thinking_disabled_for_minimal() -> None: """reasoning_effort='minimal' maps to thinking disabled for kimi-k2.5.""" kw = _build_kwargs_for("moonshot", "kimi-k2.5", reasoning_effort="minimal") assert kw.get("extra_body") == {"thinking": {"type": "disabled"}} + assert "reasoning_effort" not in kw def test_kimi_k25_no_extra_body_when_reasoning_effort_none() -> None: @@ -1391,21 +1438,36 @@ def test_kimi_k25_no_extra_body_when_reasoning_effort_none() -> None: def test_kimi_k25_thinking_enabled_with_openrouter_prefix() -> None: - """OpenRouter-style model names like moonshotai/kimi-k2.5 must trigger thinking.""" + """OpenRouter-style model names like moonshotai/kimi-k2.5 must trigger thinking. + + OR drops upstream-provider `thinking` fields, so the same intent also has + to go through OR's `reasoning.effort` shape (#3851 follow-up). + """ kw = _build_kwargs_for("openrouter", "moonshotai/kimi-k2.5", reasoning_effort="medium") - assert kw.get("extra_body") == {"thinking": {"type": "enabled"}} + assert kw.get("extra_body") == { + "thinking": {"type": "enabled"}, + "reasoning": {"effort": "medium"}, + } + # Even via OR, reasoning_effort wire kwarg is dropped for kimi models + assert "reasoning_effort" not in kw def test_kimi_k26_thinking_enabled() -> None: """kimi-k2.6 with reasoning_effort set should opt in to thinking.""" kw = _build_kwargs_for("moonshot", "kimi-k2.6", reasoning_effort="medium") assert kw.get("extra_body") == {"thinking": {"type": "enabled"}} + assert "reasoning_effort" not in kw def test_kimi_k26_thinking_enabled_with_openrouter_prefix() -> None: - """OpenRouter-style names like moonshotai/kimi-k2.6 must trigger thinking.""" + """OpenRouter-style names like moonshotai/kimi-k2.6 must trigger thinking + via both upstream `thinking` and OR's `reasoning.effort`.""" kw = _build_kwargs_for("openrouter", "moonshotai/kimi-k2.6", reasoning_effort="medium") - assert kw.get("extra_body") == {"thinking": {"type": "enabled"}} + assert kw.get("extra_body") == { + "thinking": {"type": "enabled"}, + "reasoning": {"effort": "medium"}, + } + assert "reasoning_effort" not in kw def test_moonshot_kimi_k26_temperature_override() -> None: @@ -1424,6 +1486,7 @@ def test_kimi_k26_code_preview_thinking_enabled() -> None: """k2.6-code-preview also supports thinking; should behave like k2.5.""" kw = _build_kwargs_for("moonshot", "k2.6-code-preview", reasoning_effort="high") assert kw.get("extra_body") == {"thinking": {"type": "enabled"}} + assert "reasoning_effort" not in kw def test_kimi_k2_series_no_thinking_injection() -> None: @@ -1453,6 +1516,7 @@ def test_kimi_k25_thinking_disabled_for_none_string() -> None: """reasoning_effort='none' maps to thinking disabled for kimi-k2.5.""" kw = _build_kwargs_for("moonshot", "kimi-k2.5", reasoning_effort="none") assert kw.get("extra_body") == {"thinking": {"type": "disabled"}} + assert "reasoning_effort" not in kw def test_dashscope_thinking_disabled_for_none_string() -> None: diff --git a/tests/providers/test_novita_provider.py b/tests/providers/test_novita_provider.py new file mode 100644 index 000000000..0b1e8ec12 --- /dev/null +++ b/tests/providers/test_novita_provider.py @@ -0,0 +1,97 @@ +"""Tests for the Novita AI provider registration.""" + +from unittest.mock import patch + +from nanobot.config.schema import Config, ProvidersConfig +from nanobot.providers.openai_compat_provider import OpenAICompatProvider +from nanobot.providers.registry import PROVIDERS, find_by_name + + +def test_novita_config_field_exists() -> None: + config = ProvidersConfig() + + assert hasattr(config, "novita") + + +def test_novita_provider_in_registry() -> None: + specs = {spec.name: spec for spec in PROVIDERS} + + assert "novita" in specs + novita = specs["novita"] + assert novita.backend == "openai_compat" + assert novita.env_key == "NOVITA_API_KEY" + assert novita.display_name == "Novita AI" + assert novita.is_gateway is True + assert novita.detect_by_base_keyword == "novita" + assert novita.default_api_base == "https://api.novita.ai/openai" + assert novita.strip_model_prefix is False + + +def test_find_by_name_novita() -> None: + spec = find_by_name("novita") + + assert spec is not None + assert spec.name == "novita" + + +def test_novita_forced_provider_uses_default_api_base() -> None: + config = Config.model_validate({ + "providers": { + "novita": { + "apiKey": "novita-key", + }, + }, + "agents": { + "defaults": { + "model": "deepseek-v4-pro", + "provider": "novita", + }, + }, + }) + + assert config.get_provider_name("deepseek-v4-pro") == "novita" + assert config.get_api_key("deepseek-v4-pro") == "novita-key" + assert config.get_api_base("deepseek-v4-pro") == "https://api.novita.ai/openai" + + +def test_novita_gateway_routes_unprefixed_models_when_configured() -> None: + config = Config.model_validate({ + "providers": { + "novita": { + "apiKey": "novita-key", + }, + }, + "agents": { + "defaults": { + "model": "deepseek-v4-pro", + }, + }, + }) + + assert config.get_provider_name("deepseek-v4-pro") == "novita" + assert config.get_api_key("deepseek-v4-pro") == "novita-key" + assert config.get_api_base("deepseek-v4-pro") == "https://api.novita.ai/openai" + + +def test_novita_preserves_model_api_id() -> None: + spec = find_by_name("novita") + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider( + api_key="novita-key", + default_model="deepseek-v4-pro", + spec=spec, + ) + + kwargs = provider._build_kwargs( + messages=[{"role": "user", "content": "hi"}], + tools=None, + model="deepseek-v4-pro", + max_tokens=1024, + temperature=0.7, + reasoning_effort=None, + tool_choice=None, + ) + + assert kwargs["model"] == "deepseek-v4-pro" + assert kwargs["max_tokens"] == 1024 + assert "max_completion_tokens" not in kwargs diff --git a/tests/providers/test_xiaomi_mimo_thinking.py b/tests/providers/test_xiaomi_mimo_thinking.py index 68ca6dd80..92161803f 100644 --- a/tests/providers/test_xiaomi_mimo_thinking.py +++ b/tests/providers/test_xiaomi_mimo_thinking.py @@ -32,7 +32,7 @@ def _mimo_spec(): def _openrouter_spec(): - """Return the registered OpenRouter ProviderSpec (no thinking_style).""" + """Return the registered OpenRouter ProviderSpec.""" specs = {s.name: s for s in PROVIDERS} return specs["openrouter"] @@ -77,6 +77,13 @@ def test_xiaomi_mimo_uses_thinking_type_style(): assert spec.default_api_base == "https://api.xiaomimimo.com/v1" +def test_openrouter_declares_gateway_reasoning_style(): + """OpenRouter uses its own reasoning.effort field for routed thinking models.""" + spec = _openrouter_spec() + assert spec.thinking_style == "" + assert spec.gateway_reasoning_style == "reasoning_effort" + + # --------------------------------------------------------------------------- # _build_kwargs wire-format # --------------------------------------------------------------------------- @@ -142,9 +149,11 @@ def test_mimo_reasoning_effort_unset_preserves_provider_default(): def test_mimo_via_openrouter_reasoning_effort_none_disables_thinking(): - """OpenRouter routes MiMo as "xiaomi/mimo-v2.5-pro"; the openrouter spec - has no thinking_style, so the disable signal must come from the - model-name path (#3845).""" + """OpenRouter routes MiMo as "xiaomi/mimo-v2.5-pro" and does NOT forward + extra_body.thinking to upstream, so a disable signal must also reach OR + in its own `reasoning.effort` shape. Verifies both the upstream-MiMo + payload (#3845) and the OR-native payload (#3851 follow-up) are sent. + """ provider = _openrouter_provider("xiaomi/mimo-v2.5-pro") kwargs = provider._build_kwargs( messages=_simple_messages(), @@ -152,11 +161,15 @@ def test_mimo_via_openrouter_reasoning_effort_none_disables_thinking(): temperature=0.7, reasoning_effort="none", tool_choice=None, ) assert "reasoning_effort" not in kwargs - assert kwargs["extra_body"] == {"thinking": {"type": "disabled"}} + assert kwargs["extra_body"] == { + "thinking": {"type": "disabled"}, + "reasoning": {"effort": "none"}, + } def test_mimo_via_openrouter_reasoning_effort_medium_enables_thinking(): - """Same as the direct path: any non-none/minimal effort enables thinking.""" + """Non-none/minimal effort enables thinking and the OR `reasoning.effort` + field mirrors the requested effort level.""" provider = _openrouter_provider("xiaomi/mimo-v2.5-pro") kwargs = provider._build_kwargs( messages=_simple_messages(), @@ -164,7 +177,10 @@ def test_mimo_via_openrouter_reasoning_effort_medium_enables_thinking(): temperature=0.7, reasoning_effort="medium", tool_choice=None, ) assert kwargs.get("reasoning_effort") == "medium" - assert kwargs["extra_body"] == {"thinking": {"type": "enabled"}} + assert kwargs["extra_body"] == { + "thinking": {"type": "enabled"}, + "reasoning": {"effort": "medium"}, + } def test_mimo_via_openrouter_bare_slug_also_matches(): @@ -176,12 +192,16 @@ def test_mimo_via_openrouter_bare_slug_also_matches(): tools=None, model=None, max_tokens=100, temperature=0.7, reasoning_effort="none", tool_choice=None, ) - assert kwargs["extra_body"] == {"thinking": {"type": "disabled"}} + assert kwargs["extra_body"] == { + "thinking": {"type": "disabled"}, + "reasoning": {"effort": "none"}, + } def test_mimo_flash_via_openrouter_does_not_inject_thinking(): """mimo-v2-flash has no thinking mode per Xiaomi docs; the allowlist - excludes it, so no thinking field should be injected on the gateway path.""" + excludes it, so neither the upstream `thinking` field nor OR's + `reasoning.effort` should be injected on the gateway path.""" provider = _openrouter_provider("xiaomi/mimo-v2-flash") kwargs = provider._build_kwargs( messages=_simple_messages(), @@ -200,3 +220,18 @@ def test_non_mimo_model_via_openrouter_unaffected(): temperature=0.7, reasoning_effort="none", tool_choice=None, ) assert "extra_body" not in kwargs + + +def test_kimi_via_openrouter_also_injects_reasoning_effort(): + """Kimi has the same gateway problem as MiMo: OR drops the upstream + `thinking` field. The same OR-reasoning injection should fire.""" + provider = _openrouter_provider("moonshotai/kimi-k2.5") + kwargs = provider._build_kwargs( + messages=_simple_messages(), + tools=None, model=None, max_tokens=100, + temperature=0.7, reasoning_effort="none", tool_choice=None, + ) + assert kwargs["extra_body"] == { + "thinking": {"type": "disabled"}, + "reasoning": {"effort": "none"}, + } diff --git a/tests/tools/test_apply_patch_tool.py b/tests/tools/test_apply_patch_tool.py new file mode 100644 index 000000000..2ba247368 --- /dev/null +++ b/tests/tools/test_apply_patch_tool.py @@ -0,0 +1,330 @@ +from __future__ import annotations + +import asyncio + +from nanobot.agent.tools.apply_patch import ApplyPatchTool + + +def test_apply_patch_edits_replace(tmp_path): + target = tmp_path / "calc.py" + target.write_text("def add(a, b):\n return a + b\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run( + tool.execute( + edits=[ + { + "path": "calc.py", + "action": "replace", + "old_text": " return a + b", + "new_text": " return a - b", + } + ] + ) + ) + + assert "update calc.py" in result + assert target.read_text() == "def add(a, b):\n return a - b\n" + + +def test_apply_patch_edits_add_new_file(tmp_path): + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run( + tool.execute( + edits=[ + { + "path": "config.py", + "action": "add", + "new_text": "DEBUG = True", + } + ] + ) + ) + + assert "add config.py" in result + assert (tmp_path / "config.py").read_text() == "DEBUG = True\n" + + +def test_apply_patch_edits_preserves_new_file_trailing_blank_lines(tmp_path): + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run( + tool.execute( + edits=[ + { + "path": "notes.txt", + "action": "add", + "new_text": "one\n\n", + } + ] + ) + ) + + assert "add notes.txt" in result + assert (tmp_path / "notes.txt").read_text() == "one\n\n" + + +def test_apply_patch_edits_add_to_existing_file(tmp_path): + target = tmp_path / "log.py" + target.write_text("import logging\n\nlogger = logging.getLogger(__name__)\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run( + tool.execute( + edits=[ + { + "path": "log.py", + "action": "add", + "new_text": "def debug(msg):\n logger.debug(msg)", + } + ] + ) + ) + + assert "update log.py" in result + assert ( + target.read_text() + == "import logging\n\nlogger = logging.getLogger(__name__)\ndef debug(msg):\n logger.debug(msg)\n" + ) + + +def test_apply_patch_edits_delete(tmp_path): + target = tmp_path / "utils.py" + target.write_text("def unused():\n pass\ndef used():\n return 1\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run( + tool.execute( + edits=[ + { + "path": "utils.py", + "action": "delete", + "old_text": "def unused():\n pass\n", + } + ] + ) + ) + + assert "update utils.py" in result + assert target.read_text() == "def used():\n return 1\n" + + +def test_apply_patch_edits_delete_entire_file(tmp_path): + target = tmp_path / "obsolete.txt" + target.write_text("remove me\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run( + tool.execute( + edits=[ + { + "path": "obsolete.txt", + "action": "delete", + "old_text": "remove me\n", + } + ] + ) + ) + + assert "delete obsolete.txt" in result + assert not target.exists() + + +def test_apply_patch_edits_delete_substring_with_surrounding_whitespace(tmp_path): + target = tmp_path / "keep_whitespace.txt" + target.write_text(" token \n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run( + tool.execute( + edits=[ + { + "path": "keep_whitespace.txt", + "action": "delete", + "old_text": "token", + } + ] + ) + ) + + assert "update keep_whitespace.txt" in result + assert target.exists() + assert target.read_text() == " \n" + + +def test_apply_patch_edits_batch_multiple_files(tmp_path): + a = tmp_path / "a.py" + a.write_text("X = 1\n") + b = tmp_path / "b.py" + b.write_text("from a import X\nprint(X)\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run( + tool.execute( + edits=[ + { + "path": "a.py", + "action": "replace", + "old_text": "X = 1", + "new_text": "Y = 1", + }, + { + "path": "b.py", + "action": "replace", + "old_text": "from a import X", + "new_text": "from a import Y", + }, + ] + ) + ) + + assert "update a.py" in result + assert "update b.py" in result + assert a.read_text() == "Y = 1\n" + assert b.read_text() == "from a import Y\nprint(X)\n" + + +def test_apply_patch_edits_rejects_ambiguous_old_text(tmp_path): + target = tmp_path / "repeated.txt" + target.write_text("target\nmiddle\ntarget\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run( + tool.execute( + edits=[ + { + "path": "repeated.txt", + "action": "replace", + "old_text": "target", + "new_text": "changed", + } + ] + ) + ) + + assert "old_text appears multiple times" in result + assert target.read_text() == "target\nmiddle\ntarget\n" + + +def test_apply_patch_edits_dry_run_validates_without_writing(tmp_path): + target = tmp_path / "dry.txt" + target.write_text("before\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run( + tool.execute( + edits=[ + { + "path": "dry.txt", + "action": "replace", + "old_text": "before", + "new_text": "after", + }, + { + "path": "added.txt", + "action": "add", + "new_text": "new", + }, + ], + dry_run=True, + ) + ) + + assert "Patch dry-run succeeded" in result + assert target.read_text() == "before\n" + assert not (tmp_path / "added.txt").exists() + + +def test_apply_patch_edits_rejects_absolute_and_parent_paths(tmp_path): + tool = ApplyPatchTool(workspace=tmp_path) + + absolute = asyncio.run( + tool.execute( + edits=[ + { + "path": "/tmp/owned.txt", + "action": "add", + "new_text": "nope", + } + ] + ) + ) + parent = asyncio.run( + tool.execute( + edits=[ + { + "path": "../owned.txt", + "action": "add", + "new_text": "nope", + } + ] + ) + ) + windows_absolute = asyncio.run( + tool.execute( + edits=[ + { + "path": r"C:\owned.txt", + "action": "add", + "new_text": "nope", + } + ] + ) + ) + windows_parent = asyncio.run( + tool.execute( + edits=[ + { + "path": r"..\owned.txt", + "action": "add", + "new_text": "nope", + } + ] + ) + ) + + assert "must be relative" in absolute + assert "must not contain '..'" in parent + assert "must be relative" in windows_absolute + assert "must not contain '..'" in windows_parent + assert not (tmp_path.parent / "owned.txt").exists() + + +def test_apply_patch_edits_reports_invalid_edit_shapes(tmp_path): + tool = ApplyPatchTool(workspace=tmp_path) + + missing_path = asyncio.run(tool.execute(edits=[{"action": "add", "new_text": "x"}])) + missing_action = asyncio.run(tool.execute(edits=[{"path": "x.txt", "new_text": "x"}])) + non_object = asyncio.run(tool.execute(edits=["not an object"])) # type: ignore[list-item] + + assert "path required for edit" in missing_path + assert "action required for edit: x.txt" in missing_action + assert "each edit must be an object" in non_object + + +def test_apply_patch_edits_rolls_back_when_late_operation_fails(tmp_path): + first = tmp_path / "first.txt" + first.write_text("before\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run( + tool.execute( + edits=[ + { + "path": "first.txt", + "action": "replace", + "old_text": "before", + "new_text": "after", + }, + { + "path": "missing.txt", + "action": "delete", + "old_text": "remove me", + }, + ] + ) + ) + + assert "file to update does not exist: missing.txt" in result + assert first.read_text() == "before\n" diff --git a/tests/tools/test_edit_enhancements.py b/tests/tools/test_edit_enhancements.py index 1f22c963b..7202fc37b 100644 --- a/tests/tools/test_edit_enhancements.py +++ b/tests/tools/test_edit_enhancements.py @@ -1,5 +1,5 @@ """Tests for EditFileTool enhancements: read-before-edit tracking, path suggestions, -.ipynb detection, and create-file semantics.""" +notebook JSON editing, and create-file semantics.""" import pytest @@ -108,22 +108,27 @@ class TestEditCreateFile: # --------------------------------------------------------------------------- -# .ipynb detection +# .ipynb editing # --------------------------------------------------------------------------- -class TestEditIpynbDetection: - """edit_file should refuse .ipynb and suggest notebook_edit.""" +class TestEditIpynbFiles: + """edit_file edits notebooks as normal JSON files.""" @pytest.fixture() def tool(self, tmp_path): return EditFileTool(workspace=tmp_path) @pytest.mark.asyncio - async def test_ipynb_rejected_with_suggestion(self, tool, tmp_path): + async def test_ipynb_can_be_edited_as_json(self, tool, tmp_path): f = tmp_path / "analysis.ipynb" f.write_text('{"cells": []}', encoding="utf-8") - result = await tool.execute(path=str(f), old_text="x", new_text="y") - assert "notebook" in result.lower() + result = await tool.execute( + path=str(f), + old_text='"cells": []', + new_text='"cells": [{"cell_type": "markdown", "source": "hi"}]', + ) + assert "Successfully edited" in result + assert '"source": "hi"' in f.read_text(encoding="utf-8") # --------------------------------------------------------------------------- diff --git a/tests/tools/test_exec_platform.py b/tests/tools/test_exec_platform.py index 69a271ec1..ffb25f985 100644 --- a/tests/tools/test_exec_platform.py +++ b/tests/tools/test_exec_platform.py @@ -162,7 +162,7 @@ class TestPathAppendPlatform: captured_cmd = None captured_env = {} - async def capture_spawn(cmd, cwd, env): + async def capture_spawn(cmd, cwd, env, shell_program=None, login=True): nonlocal captured_cmd captured_cmd = cmd captured_env.update(env) @@ -190,7 +190,7 @@ class TestPathAppendPlatform: captured_env = {} - async def capture_spawn(cmd, cwd, env): + async def capture_spawn(cmd, cwd, env, shell_program=None, login=True): captured_env.update(env) return mock_proc diff --git a/tests/tools/test_exec_session_tools.py b/tests/tools/test_exec_session_tools.py new file mode 100644 index 000000000..f5fe45e96 --- /dev/null +++ b/tests/tools/test_exec_session_tools.py @@ -0,0 +1,361 @@ +from __future__ import annotations + +import asyncio +import re +import shlex +import subprocess +import sys + +from nanobot.agent.tools.shell import ExecTool +from nanobot.agent.tools.exec_session import ExecSessionManager, ListExecSessionsTool, WriteStdinTool + + +def _python_command(code: str) -> str: + if sys.platform == "win32": + return f"{subprocess.list2cmdline([sys.executable])} -u -c {subprocess.list2cmdline([code])}" + return f"{shlex.quote(sys.executable)} -u -c {shlex.quote(code)}" + + +def _session_id(output: str) -> str: + match = re.search(r"session_id:\s*([0-9a-f]+)", output) + assert match, output + return match.group(1) + + +def test_exec_keeps_one_shot_behavior_without_yield_time_ms(tmp_path): + async def run() -> str: + tool = ExecTool(working_dir=str(tmp_path), timeout=5) + return await tool.execute(command="echo hello") + + result = asyncio.run(run()) + + assert "hello" in result + assert "Exit code: 0" in result + assert "session_id:" not in result + + +def test_exec_accepts_command_aliases(tmp_path): + async def run() -> str: + tool = ExecTool(working_dir="/") + return await tool.execute( + cmd=_python_command("import os; print(os.getcwd())"), + workdir=str(tmp_path), + ) + + result = asyncio.run(run()) + + assert str(tmp_path) in result + assert "Exit code: 0" in result + + +def test_exec_returns_completed_session_output_when_yield_time_ms_is_used(tmp_path): + async def run() -> str: + manager = ExecSessionManager() + tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + + result = await tool.execute(command="echo hello", yield_time_ms=1000) + if "session_id:" in result: + sid = _session_id(result) + result += "\n" + await stdin_tool.execute( + session_id=sid, + chars="", + yield_time_ms=1000, + ) + return result + + result = asyncio.run(run()) + + assert "hello" in result + assert "Exit code: 0" in result + assert "session_id:" not in result + + +def test_exec_session_accepts_max_output_tokens_alias(tmp_path): + async def run() -> str: + manager = ExecSessionManager() + tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + command = _python_command("print('A' * 2000)") + return await tool.execute( + command=command, + yield_time_ms=1000, + max_output_tokens=1000, + ) + + result = asyncio.run(run()) + + assert "chars truncated" in result + assert "Exit code: 0" in result + + +def test_exec_one_shot_accepts_max_output_tokens_alias(tmp_path): + async def run() -> str: + tool = ExecTool(working_dir=str(tmp_path), timeout=5) + command = _python_command("print('A' * 2000)") + return await tool.execute(command=command, max_output_tokens=1000) + + result = asyncio.run(run()) + + assert "chars truncated" in result + assert "Exit code: 0" in result + + +def test_exec_accepts_supported_shell_parameter(tmp_path): + async def run() -> str: + tool = ExecTool(working_dir=str(tmp_path), timeout=5) + return await tool.execute(command="echo shell-ok", shell="sh", login=False) + + if sys.platform == "win32": + return + result = asyncio.run(run()) + + assert "shell-ok" in result + assert "Exit code: 0" in result + + +def test_exec_rejects_unsupported_shell(tmp_path): + async def run() -> str: + tool = ExecTool(working_dir=str(tmp_path), timeout=5) + return await tool.execute(command="echo no", shell="python") + + if sys.platform == "win32": + return + result = asyncio.run(run()) + + assert "unsupported shell" in result + + +def test_exec_can_continue_with_stdin(tmp_path): + async def run() -> tuple[str, str]: + manager = ExecSessionManager() + exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + command = _python_command( + "import sys; print('ready', flush=True); " + "line=sys.stdin.readline(); print('got:' + line.strip(), flush=True)" + ) + + initial = await exec_tool.execute(command=command, yield_time_ms=500) + sid = _session_id(initial) + result = await stdin_tool.execute(session_id=sid, chars="ping\n", yield_time_ms=1000) + return initial, result + + initial, result = asyncio.run(run()) + assert "ready" in initial + assert "Process running" in initial + assert "Elapsed:" in initial + assert "got:ping" in result + assert "Exit code: 0" in result + assert "Elapsed:" in result + + +def test_write_stdin_can_close_stdin(tmp_path): + async def run() -> tuple[str, str]: + manager = ExecSessionManager() + exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + command = _python_command( + "import sys; print('ready', flush=True); " + "data=sys.stdin.read(); print('got:' + data, flush=True)" + ) + + initial = await exec_tool.execute(command=command, yield_time_ms=500) + sid = _session_id(initial) + result = await stdin_tool.execute( + session_id=sid, + chars="payload", + close_stdin=True, + yield_time_ms=1000, + ) + return initial, result + + initial, result = asyncio.run(run()) + assert "ready" in initial + assert "got:payload" in result + assert "Stdin closed." in result + assert "Exit code: 0" in result + + +def test_write_stdin_can_terminate_session(tmp_path): + async def run() -> tuple[str, str]: + manager = ExecSessionManager() + exec_tool = ExecTool(working_dir=str(tmp_path), timeout=30, session_manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + command = _python_command( + "import time; print('ready', flush=True); time.sleep(30)" + ) + + initial = await exec_tool.execute(command=command, yield_time_ms=500) + sid = _session_id(initial) + result = await stdin_tool.execute( + session_id=sid, + terminate=True, + yield_time_ms=0, + ) + return initial, result + + initial, result = asyncio.run(run()) + assert "ready" in initial + assert "Session terminated." in result + assert "Exit code:" in result + + +def test_write_stdin_accepts_max_output_tokens_alias(tmp_path): + async def run() -> tuple[str, str, str]: + manager = ExecSessionManager() + exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + command = _python_command( + "import time; print('A' * 2000, flush=True); time.sleep(5)" + ) + + initial = await exec_tool.execute(command=command, yield_time_ms=0) + sid = _session_id(initial) + poll = await stdin_tool.execute( + session_id=sid, + yield_time_ms=500, + max_output_tokens=1000, + ) + cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0) + return initial, poll, cleanup + + initial, poll, cleanup = asyncio.run(run()) + assert "Process running" in initial + assert "chars truncated" in poll + assert "Session terminated." in cleanup + + +def test_write_stdin_preserves_completed_session_output_until_polled(tmp_path): + async def run() -> tuple[str, str]: + manager = ExecSessionManager() + exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + command = _python_command( + "import time; print('ready', flush=True); " + "time.sleep(1.0); print('done', flush=True)" + ) + + initial = await exec_tool.execute(command=command, yield_time_ms=300) + sid = _session_id(initial) + await asyncio.sleep(1.2) + final = await stdin_tool.execute(session_id=sid, chars="", yield_time_ms=0) + return initial, final + + initial, final = asyncio.run(run()) + + assert "ready" in initial + assert "done" in final + assert "Exit code: 0" in final + + +def test_write_stdin_can_wait_for_expected_output(tmp_path): + async def run() -> tuple[str, str, str]: + manager = ExecSessionManager() + exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + command = _python_command( + "import time; print('booting', flush=True); " + "time.sleep(0.4); print('ready', flush=True); time.sleep(5)" + ) + + initial = await exec_tool.execute(command=command, yield_time_ms=100) + sid = _session_id(initial) + waited = await stdin_tool.execute( + session_id=sid, + wait_for="ready", + wait_timeout_ms=3000, + yield_time_ms=0, + ) + cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0) + return initial, waited, cleanup + + initial, waited, cleanup = asyncio.run(run()) + + assert "Process running" in initial + assert "booting" in initial + waited + assert "ready" in waited + assert "Wait target not observed" not in waited + assert "Session terminated." in cleanup + + +def test_write_stdin_wait_for_reports_timeout_without_killing_session(tmp_path): + async def run() -> tuple[str, str, str]: + manager = ExecSessionManager() + exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + command = _python_command( + "import time; print('booting', flush=True); time.sleep(5)" + ) + + initial = await exec_tool.execute(command=command, yield_time_ms=100) + sid = _session_id(initial) + waited = await stdin_tool.execute( + session_id=sid, + wait_for="never-ready", + wait_timeout_ms=200, + yield_time_ms=0, + ) + cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0) + return initial, waited, cleanup + + initial, waited, cleanup = asyncio.run(run()) + + assert "Process running" in initial + assert "booting" in initial + waited + assert "Process running" in waited + assert "Wait target not observed: 'never-ready'" in waited + assert "Session terminated." in cleanup + + +def test_exec_session_mode_reuses_exec_safety_guard(tmp_path): + manager = ExecSessionManager() + tool = ExecTool( + working_dir=str(tmp_path), + deny_patterns=[r"echo\s+blocked"], + session_manager=manager, + ) + + result = asyncio.run(tool.execute(command="echo blocked", yield_time_ms=0)) + + assert "blocked by deny pattern" in result + + +def test_write_stdin_reports_missing_session(tmp_path): + manager = ExecSessionManager() + tool = WriteStdinTool(manager=manager) + + result = asyncio.run(tool.execute(session_id="missing", chars="")) + + assert "exec session not found" in result + + +def test_list_exec_sessions_reports_running_commands(tmp_path): + async def run() -> tuple[str, str, str]: + manager = ExecSessionManager() + exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + list_tool = ListExecSessionsTool(manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + command = _python_command( + "import time; print('ready', flush=True); time.sleep(5)" + ) + + initial = await exec_tool.execute(command=command, yield_time_ms=500) + sid = _session_id(initial) + listing = await list_tool.execute() + cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0) + return sid, listing, cleanup + + sid, listing, cleanup = asyncio.run(run()) + + assert sid in listing + assert "running" in listing + assert "elapsed=" in listing + assert "remaining=" in listing + assert str(tmp_path) in listing + assert "Session terminated." in cleanup + + +def test_list_exec_sessions_reports_empty_state(): + result = asyncio.run(ListExecSessionsTool(manager=ExecSessionManager()).execute()) + + assert result == "No active exec sessions." diff --git a/tests/tools/test_file_edit_coding_enhancements.py b/tests/tools/test_file_edit_coding_enhancements.py new file mode 100644 index 000000000..d361d88ae --- /dev/null +++ b/tests/tools/test_file_edit_coding_enhancements.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +import asyncio + +from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool + + +def test_read_file_force_bypasses_dedup(tmp_path): + target = tmp_path / "data.txt" + target.write_text("alpha\n") + tool = ReadFileTool(workspace=tmp_path) + + first = asyncio.run(tool.execute(path=str(target))) + second = asyncio.run(tool.execute(path=str(target))) + forced = asyncio.run(tool.execute(path=str(target), force=True)) + + assert "alpha" in first + assert "unchanged" in second.lower() + assert "alpha" in forced + assert "unchanged" not in forced.lower() + + +def test_edit_file_can_select_occurrence(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("one\nsame\ntwo\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + occurrence=2, + )) + + assert "Successfully edited" in result + assert target.read_text() == "one\nsame\ntwo\nchanged\n" + + +def test_edit_file_expected_replacements_guards_replace_all(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + replace_all=True, + expected_replacements=1, + )) + + assert "expected 1 replacements but would make 2" in result + assert target.read_text() == "same\nsame\n" + + +def test_edit_file_expected_replacements_allows_replace_all_when_count_matches(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + replace_all=True, + expected_replacements=2, + )) + + assert "Successfully edited" in result + assert target.read_text() == "changed\nchanged\n" + + +def test_edit_file_can_select_nearest_line_hint(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("one\nsame\ntwo\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + line_hint=4, + )) + + assert "Successfully edited" in result + assert target.read_text() == "one\nsame\ntwo\nchanged\n" + + +def test_edit_file_can_edit_ipynb_as_json(tmp_path): + target = tmp_path / "analysis.ipynb" + target.write_text('{"cells": []}') + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text='"cells": []', + new_text='"cells": [{"cell_type": "markdown", "source": "hi"}]', + )) + + assert "Successfully edited" in result + assert '"source": "hi"' in target.read_text() + + +def test_edit_file_multiple_match_hint_mentions_occurrence(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + )) + + assert "old_text appears 2 times" in result + assert "occurrence" in result + assert target.read_text() == "same\nsame\n" + + +def test_edit_file_rejects_ambiguous_line_hint(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\nmiddle\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + line_hint=2, + )) + + assert "line_hint 2 is ambiguous" in result + assert target.read_text() == "same\nmiddle\nsame\n" + + +def test_edit_file_rejects_occurrence_with_replace_all(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + occurrence=1, + replace_all=True, + )) + + assert "occurrence cannot be used with replace_all" in result + assert target.read_text() == "same\nsame\n" + + +def test_edit_file_rejects_line_hint_with_replace_all(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + line_hint=1, + replace_all=True, + )) + + assert "line_hint cannot be used with replace_all" in result + assert target.read_text() == "same\nsame\n" + + +def test_edit_file_rejects_line_hint_with_occurrence(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + occurrence=1, + line_hint=1, + )) + + assert "line_hint cannot be used with occurrence" in result + assert target.read_text() == "same\nsame\n" + + +def test_edit_file_rejects_zero_occurrence(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + occurrence=0, + )) + + assert "occurrence must be >= 1" in result + assert target.read_text() == "same\n" + + +def test_edit_file_rejects_zero_line_hint(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + line_hint=0, + )) + + assert "line_hint must be >= 1" in result + assert target.read_text() == "same\n" diff --git a/tests/tools/test_notebook_tool.py b/tests/tools/test_notebook_tool.py deleted file mode 100644 index 232f13c4b..000000000 --- a/tests/tools/test_notebook_tool.py +++ /dev/null @@ -1,147 +0,0 @@ -"""Tests for NotebookEditTool — Jupyter .ipynb editing.""" - -import json - -import pytest - -from nanobot.agent.tools.notebook import NotebookEditTool - - -def _make_notebook(cells: list[dict] | None = None, nbformat: int = 4, nbformat_minor: int = 5) -> dict: - """Build a minimal valid .ipynb structure.""" - return { - "nbformat": nbformat, - "nbformat_minor": nbformat_minor, - "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}}, - "cells": cells or [], - } - - -def _code_cell(source: str, cell_id: str | None = None) -> dict: - cell = {"cell_type": "code", "source": source, "metadata": {}, "outputs": [], "execution_count": None} - if cell_id: - cell["id"] = cell_id - return cell - - -def _md_cell(source: str, cell_id: str | None = None) -> dict: - cell = {"cell_type": "markdown", "source": source, "metadata": {}} - if cell_id: - cell["id"] = cell_id - return cell - - -def _write_nb(tmp_path, name: str, nb: dict) -> str: - p = tmp_path / name - p.write_text(json.dumps(nb), encoding="utf-8") - return str(p) - - -class TestNotebookEdit: - - @pytest.fixture() - def tool(self, tmp_path): - return NotebookEditTool(workspace=tmp_path) - - @pytest.mark.asyncio - async def test_replace_cell_content(self, tool, tmp_path): - nb = _make_notebook([_code_cell("print('hello')"), _code_cell("x = 1")]) - path = _write_nb(tmp_path, "test.ipynb", nb) - result = await tool.execute(path=path, cell_index=0, new_source="print('world')") - assert "Successfully" in result - saved = json.loads((tmp_path / "test.ipynb").read_text()) - assert saved["cells"][0]["source"] == "print('world')" - assert saved["cells"][1]["source"] == "x = 1" - - @pytest.mark.asyncio - async def test_insert_cell_after_target(self, tool, tmp_path): - nb = _make_notebook([_code_cell("cell 0"), _code_cell("cell 1")]) - path = _write_nb(tmp_path, "test.ipynb", nb) - result = await tool.execute(path=path, cell_index=0, new_source="inserted", edit_mode="insert") - assert "Successfully" in result - saved = json.loads((tmp_path / "test.ipynb").read_text()) - assert len(saved["cells"]) == 3 - assert saved["cells"][0]["source"] == "cell 0" - assert saved["cells"][1]["source"] == "inserted" - assert saved["cells"][2]["source"] == "cell 1" - - @pytest.mark.asyncio - async def test_delete_cell(self, tool, tmp_path): - nb = _make_notebook([_code_cell("A"), _code_cell("B"), _code_cell("C")]) - path = _write_nb(tmp_path, "test.ipynb", nb) - result = await tool.execute(path=path, cell_index=1, edit_mode="delete") - assert "Successfully" in result - saved = json.loads((tmp_path / "test.ipynb").read_text()) - assert len(saved["cells"]) == 2 - assert saved["cells"][0]["source"] == "A" - assert saved["cells"][1]["source"] == "C" - - @pytest.mark.asyncio - async def test_create_new_notebook_from_scratch(self, tool, tmp_path): - path = str(tmp_path / "new.ipynb") - result = await tool.execute(path=path, cell_index=0, new_source="# Hello", edit_mode="insert", cell_type="markdown") - assert "Successfully" in result or "created" in result.lower() - saved = json.loads((tmp_path / "new.ipynb").read_text()) - assert saved["nbformat"] == 4 - assert len(saved["cells"]) == 1 - assert saved["cells"][0]["cell_type"] == "markdown" - assert saved["cells"][0]["source"] == "# Hello" - - @pytest.mark.asyncio - async def test_invalid_cell_index_error(self, tool, tmp_path): - nb = _make_notebook([_code_cell("only cell")]) - path = _write_nb(tmp_path, "test.ipynb", nb) - result = await tool.execute(path=path, cell_index=5, new_source="x") - assert "Error" in result - - @pytest.mark.asyncio - async def test_non_ipynb_rejected(self, tool, tmp_path): - f = tmp_path / "script.py" - f.write_text("pass") - result = await tool.execute(path=str(f), cell_index=0, new_source="x") - assert "Error" in result - assert ".ipynb" in result - - @pytest.mark.asyncio - async def test_preserves_metadata_and_outputs(self, tool, tmp_path): - cell = _code_cell("old") - cell["outputs"] = [{"output_type": "stream", "text": "hello\n"}] - cell["execution_count"] = 42 - nb = _make_notebook([cell]) - path = _write_nb(tmp_path, "test.ipynb", nb) - await tool.execute(path=path, cell_index=0, new_source="new") - saved = json.loads((tmp_path / "test.ipynb").read_text()) - assert saved["metadata"]["kernelspec"]["language"] == "python" - - @pytest.mark.asyncio - async def test_nbformat_45_generates_cell_id(self, tool, tmp_path): - nb = _make_notebook([], nbformat_minor=5) - path = _write_nb(tmp_path, "test.ipynb", nb) - await tool.execute(path=path, cell_index=0, new_source="x = 1", edit_mode="insert") - saved = json.loads((tmp_path / "test.ipynb").read_text()) - assert "id" in saved["cells"][0] - assert len(saved["cells"][0]["id"]) > 0 - - @pytest.mark.asyncio - async def test_insert_with_cell_type_markdown(self, tool, tmp_path): - nb = _make_notebook([_code_cell("code")]) - path = _write_nb(tmp_path, "test.ipynb", nb) - await tool.execute(path=path, cell_index=0, new_source="# Title", edit_mode="insert", cell_type="markdown") - saved = json.loads((tmp_path / "test.ipynb").read_text()) - assert saved["cells"][1]["cell_type"] == "markdown" - - @pytest.mark.asyncio - async def test_invalid_edit_mode_rejected(self, tool, tmp_path): - nb = _make_notebook([_code_cell("code")]) - path = _write_nb(tmp_path, "test.ipynb", nb) - result = await tool.execute(path=path, cell_index=0, new_source="x", edit_mode="replcae") - assert "Error" in result - assert "edit_mode" in result - - @pytest.mark.asyncio - async def test_invalid_cell_type_rejected(self, tool, tmp_path): - nb = _make_notebook([_code_cell("code")]) - path = _write_nb(tmp_path, "test.ipynb", nb) - result = await tool.execute(path=path, cell_index=0, new_source="x", cell_type="raw") - assert "Error" in result - assert "cell_type" in result diff --git a/tests/tools/test_search_tools.py b/tests/tools/test_search_tools.py index 0d3697044..fc7c1944a 100644 --- a/tests/tools/test_search_tools.py +++ b/tests/tools/test_search_tools.py @@ -12,7 +12,7 @@ import pytest from nanobot.agent.loop import AgentLoop from nanobot.agent.subagent import SubagentManager, SubagentStatus -from nanobot.agent.tools.search import GrepTool +from nanobot.agent.tools.search import FindFilesTool, GrepTool from nanobot.agent.tools.web import WebSearchTool from nanobot.bus.queue import MessageBus from nanobot.config.schema import WebSearchConfig @@ -33,6 +33,68 @@ async def test_web_search_tool_refreshes_dynamic_config_loader(monkeypatch) -> N assert await tool.execute("nanobot") == "duckduckgo:nanobot:3" +@pytest.mark.asyncio +async def test_find_files_filters_by_query_glob_and_type(tmp_path: Path) -> None: + (tmp_path / "src").mkdir() + (tmp_path / "src" / "settings_view.tsx").write_text("export {}\n", encoding="utf-8") + (tmp_path / "src" / "settings_api.py").write_text("pass\n", encoding="utf-8") + (tmp_path / "README.md").write_text("settings\n", encoding="utf-8") + + tool = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path) + result = await tool.execute( + path=".", + query="settings", + glob="src/**", + type="ts", + ) + + assert result.splitlines() == ["src/settings_view.tsx"] + + +@pytest.mark.asyncio +async def test_find_files_can_include_directories(tmp_path: Path) -> None: + (tmp_path / "src" / "settings").mkdir(parents=True) + (tmp_path / "src" / "settings" / "index.ts").write_text("export {}\n", encoding="utf-8") + + tool = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path) + result = await tool.execute(path="src", query="settings", include_dirs=True) + + assert "src/settings/" in result.splitlines() + assert "src/settings/index.ts" in result.splitlines() + + +@pytest.mark.asyncio +async def test_find_files_supports_modified_sort_and_pagination(tmp_path: Path) -> None: + (tmp_path / "src").mkdir() + for idx, name in enumerate(("a.py", "b.py", "c.py"), start=1): + file_path = tmp_path / "src" / name + file_path.write_text("pass\n", encoding="utf-8") + os.utime(file_path, (idx, idx)) + + tool = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path) + result = await tool.execute( + path="src", + type="py", + sort="modified", + head_limit=1, + offset=1, + ) + + assert result.splitlines()[0] == "src/b.py" + assert "pagination: limit=1, offset=1" in result + + +@pytest.mark.asyncio +async def test_find_files_rejects_paths_outside_workspace(tmp_path: Path) -> None: + outside = tmp_path.parent / "outside-find-files.txt" + outside.write_text("secret\n", encoding="utf-8") + + tool = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path) + result = await tool.execute(path=str(outside)) + + assert result.startswith("Error:") + + @pytest.mark.asyncio async def test_grep_respects_glob_filter_and_context(tmp_path: Path) -> None: (tmp_path / "src").mkdir() @@ -249,6 +311,7 @@ def test_agent_loop_registers_grep(tmp_path: Path) -> None: loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") + assert "find_files" in loop.tools.tool_names assert "grep" in loop.tools.tool_names @@ -280,6 +343,7 @@ async def test_subagent_registers_grep(tmp_path: Path) -> None: status = SubagentStatus(task_id="sub-1", label="label", task_description="search task", started_at=time.monotonic()) await mgr._run_subagent("sub-1", "search task", "label", {"channel": "cli", "chat_id": "direct"}, status) + assert "find_files" in captured["tool_names"] assert "grep" in captured["tool_names"] diff --git a/tests/tools/test_tool_descriptions.py b/tests/tools/test_tool_descriptions.py new file mode 100644 index 000000000..bb7665e4e --- /dev/null +++ b/tests/tools/test_tool_descriptions.py @@ -0,0 +1,46 @@ +from nanobot.agent.tools.apply_patch import ApplyPatchTool +from nanobot.agent.tools.exec_session import ListExecSessionsTool, WriteStdinTool +from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, WriteFileTool +from nanobot.agent.tools.search import FindFilesTool, GrepTool +from nanobot.agent.tools.shell import ExecTool + + +def test_coding_tool_descriptions_steer_editing_priority() -> None: + apply_patch = ApplyPatchTool().description.lower() + edit_file = EditFileTool().description.lower() + write_file = WriteFileTool().description.lower() + + assert "default tool for code edits" in apply_patch + assert "multi-file" in apply_patch + assert "dry_run=true" in apply_patch + assert "edit_file only for small exact replacements" in apply_patch + + assert "small, exact replacement" in edit_file + assert "copied from read_file" in edit_file + assert "prefer apply_patch" in edit_file + + assert "replace an entire file" in write_file + assert "prefer apply_patch" in write_file + + +def test_coding_tool_descriptions_steer_discovery_and_shell_usage() -> None: + read_file = ReadFileTool().description.lower() + find_files = FindFilesTool().description.lower() + grep = GrepTool().description.lower() + exec_tool = ExecTool().description.lower() + write_stdin = WriteStdinTool().description.lower() + list_sessions = ListExecSessionsTool().description.lower() + + assert "find_files/list_dir first" in read_file + assert "before editing" in read_file + assert "prefer it over shell find/ls" in find_files + assert "prefer this over shell grep" in grep + + assert "tests, builds" in exec_tool + assert "prefer read_file/find_files/grep" in exec_tool + assert "apply_patch/write_file/edit_file" in exec_tool + assert "yield_time_ms" in exec_tool + + assert "do not use this to start new commands" in write_stdin + assert "wait_for" in write_stdin + assert "recover a session_id" in list_sessions diff --git a/tests/tools/test_tool_loader.py b/tests/tools/test_tool_loader.py index 54b4d92d5..62703883c 100644 --- a/tests/tools/test_tool_loader.py +++ b/tests/tools/test_tool_loader.py @@ -89,9 +89,11 @@ def test_discover_finds_concrete_tools(): loader = ToolLoader() discovered = loader.discover() class_names = {cls.__name__ for cls in discovered} + assert "ApplyPatchTool" in class_names assert "ExecTool" in class_names assert "MessageTool" in class_names assert "SpawnTool" in class_names + assert "WriteStdinTool" in class_names def test_discover_excludes_abstract_and_mcp(): @@ -406,7 +408,8 @@ def test_loader_registers_same_tools_as_old_hardcoded(): expected = { "read_file", "write_file", "edit_file", "list_dir", - "grep", "notebook_edit", "exec", "web_search", "web_fetch", + "find_files", "grep", "exec", "write_stdin", "list_exec_sessions", + "web_search", "web_fetch", "message", "spawn", "cron", } actual = set(registered) diff --git a/tests/tools/test_tool_validation.py b/tests/tools/test_tool_validation.py index 42620dcc6..188a8952f 100644 --- a/tests/tools/test_tool_validation.py +++ b/tests/tools/test_tool_validation.py @@ -3,6 +3,8 @@ import subprocess import sys from typing import Any +import pytest + from nanobot.agent.tools import ( ArraySchema, IntegerSchema, @@ -15,6 +17,7 @@ from nanobot.agent.tools import ( from nanobot.agent.tools.base import Tool from nanobot.agent.tools.registry import ToolRegistry from nanobot.agent.tools.shell import ExecTool +from nanobot.security.network import configure_ssrf_whitelist class SampleTool(Tool): @@ -218,6 +221,39 @@ def test_exec_extract_absolute_paths_ignores_relative_posix_segments() -> None: assert "/bin/python" not in paths +def test_exec_extract_absolute_paths_ignores_urls() -> None: + cmd = 'curl -s -o /dev/null -w "%{http_code}" https://www.google.com' + paths = ExecTool._extract_absolute_paths(cmd) + assert paths == ["/dev/null"] + + +@pytest.mark.parametrize( + "command", + [ + 'curl -s -o /dev/null -w "%{http_code}" https://www.google.com', + 'wget -q -O - http://example.com 2>&1 | head -c 100', + 'python3 -c "import urllib.request; print(urllib.request.urlopen(\'http://example.com\').read()[:100])"', + ], +) +def test_exec_guard_allows_public_urls(tmp_path, command: str) -> None: + tool = ExecTool(restrict_to_workspace=True) + error = tool._guard_command(command, str(tmp_path)) + assert error is None + + +def test_exec_guard_allows_whitelisted_internal_urls(tmp_path) -> None: + configure_ssrf_whitelist(["10.10.10.0/24"]) + try: + tool = ExecTool(restrict_to_workspace=True) + error = tool._guard_command( + 'curl -s -H "Authorization: Bearer ..." http://10.10.10.3:8123/api/', + str(tmp_path), + ) + assert error is None + finally: + configure_ssrf_whitelist([]) + + def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None: cmd = "cat /tmp/data.txt > /tmp/out.txt" paths = ExecTool._extract_absolute_paths(cmd) diff --git a/tests/utils/test_file_edit_events.py b/tests/utils/test_file_edit_events.py index cdaae5167..fe035b41b 100644 --- a/tests/utils/test_file_edit_events.py +++ b/tests/utils/test_file_edit_events.py @@ -5,12 +5,13 @@ from pathlib import Path from types import SimpleNamespace from nanobot.utils.file_edit_events import ( + StreamingFileEditTracker, build_file_edit_end_event, build_file_edit_start_event, line_diff_stats, prepare_file_edit_tracker, + prepare_file_edit_trackers, read_file_snapshot, - StreamingFileEditTracker, ) @@ -81,6 +82,63 @@ def test_binary_file_is_reported_but_not_counted(tmp_path: Path) -> None: assert (event["added"], event["deleted"]) == (0, 0) +def test_apply_patch_prepares_trackers_for_each_touched_file(tmp_path: Path) -> None: + (tmp_path / "src").mkdir() + existing = tmp_path / "src" / "existing.py" + existing.write_text("old\nkeep\n", encoding="utf-8") + delete_me = tmp_path / "src" / "delete_me.py" + delete_me.write_text("gone\n", encoding="utf-8") + + edits = [ + {"path": "src/new.py", "action": "add", "new_text": "fresh"}, + {"path": "src/existing.py", "action": "replace", "old_text": "old", "new_text": "new"}, + {"path": "src/delete_me.py", "action": "delete", "old_text": "gone\n"}, + ] + + trackers = prepare_file_edit_trackers( + call_id="call-patch", + tool_name="apply_patch", + tool=None, + workspace=tmp_path, + params={"edits": edits}, + ) + + assert [tracker.display_path for tracker in trackers] == [ + "src/new.py", + "src/existing.py", + "src/delete_me.py", + ] + + (tmp_path / "src" / "new.py").write_text("fresh\n", encoding="utf-8") + existing.write_text("new\nkeep\n", encoding="utf-8") + delete_me.unlink() + + events = [build_file_edit_end_event(tracker, {"edits": edits}) for tracker in trackers] + by_path = {event["path"]: event for event in events} + assert (by_path["src/new.py"]["added"], by_path["src/new.py"]["deleted"]) == (1, 0) + assert (by_path["src/existing.py"]["added"], by_path["src/existing.py"]["deleted"]) == (1, 1) + assert (by_path["src/delete_me.py"]["added"], by_path["src/delete_me.py"]["deleted"]) == (0, 1) + + +def test_apply_patch_dry_run_does_not_prepare_file_edit_trackers(tmp_path: Path) -> None: + (tmp_path / "file.txt").write_text("old\n", encoding="utf-8") + + trackers = prepare_file_edit_trackers( + call_id="call-patch", + tool_name="apply_patch", + tool=None, + workspace=tmp_path, + params={ + "dry_run": True, + "edits": [ + {"path": "file.txt", "action": "replace", "old_text": "old", "new_text": "new"} + ], + }, + ) + + assert trackers == [] + + def test_oversized_write_file_end_uses_known_content_for_exact_count(tmp_path: Path) -> None: target = tmp_path / "large.txt" params = {"path": "large.txt", "content": "x" * (2 * 1024 * 1024 + 1)} @@ -140,6 +198,58 @@ def test_streaming_write_file_tracker_emits_live_line_counts(tmp_path: Path) -> assert events[-1]["deleted"] == 0 +def test_streaming_apply_patch_tracker_emits_live_counts_per_file(tmp_path: Path) -> None: + (tmp_path / "src").mkdir() + (tmp_path / "src" / "existing.py").write_text("old\nkeep\n", encoding="utf-8") + events: list[dict] = [] + + async def emit(batch: list[dict]) -> None: + events.extend(batch) + + async def run() -> None: + tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit) + await tracker.update({ + "index": 0, + "call_id": "call-patch", + "name": "apply_patch", + "arguments_delta": ( + '{"edits":[{"path":"src/existing.py","action":"replace","old_text":"old","new_text":"new"}' + ',{"path":"src/new.py","action":"add","new_text":"fresh"}]}' + ), + }) + + asyncio.run(run()) + + by_path = {event["path"]: event for event in events} + assert by_path["src/existing.py"]["tool"] == "apply_patch" + assert by_path["src/existing.py"]["status"] == "editing" + assert by_path["src/existing.py"]["approximate"] is True + assert (by_path["src/existing.py"]["added"], by_path["src/existing.py"]["deleted"]) == (1, 1) + assert (by_path["src/new.py"]["added"], by_path["src/new.py"]["deleted"]) == (1, 0) + + +def test_streaming_apply_patch_tracker_skips_dry_run(tmp_path: Path) -> None: + events: list[dict] = [] + + async def emit(batch: list[dict]) -> None: + events.extend(batch) + + async def run() -> None: + tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit) + await tracker.update({ + "index": 0, + "call_id": "call-patch", + "name": "apply_patch", + "arguments_delta": ( + '{"dry_run":true,"edits":[{"path":"dry.md","action":"add","new_text":"preview"}]}' + ), + }) + + asyncio.run(run()) + + assert events == [] + + def test_streaming_write_file_tracker_emits_pending_before_path(tmp_path: Path) -> None: events: list[dict] = [] @@ -308,6 +418,43 @@ def test_streaming_tracker_applies_canonical_call_id_to_final_tool(tmp_path: Pat asyncio.run(run()) +def test_streaming_tracker_does_not_restore_duplicate_canonical_ids(tmp_path: Path) -> None: + events: list[dict] = [] + + async def emit(batch: list[dict]) -> None: + events.extend(batch) + + async def run() -> None: + tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit) + await tracker.update({ + "index": 0, + "call_id": "call_dup", + "name": "write_file", + "arguments_delta": '{"path":"a.md","content":"one\\n"}', + }) + await tracker.update({ + "index": 1, + "call_id": "call_dup", + "name": "write_file", + "arguments_delta": '{"path":"b.md","content":"two\\n"}', + }) + final_a = SimpleNamespace( + id="call_dup", + name="write_file", + arguments={"path": "a.md", "content": "one\n"}, + ) + final_b = SimpleNamespace( + id="call_unique", + name="write_file", + arguments={"path": "b.md", "content": "two\n"}, + ) + tracker.apply_final_call_ids([final_a, final_b]) + assert final_a.id == "call_dup" + assert final_b.id == "call_unique" + + asyncio.run(run()) + + def test_streaming_edit_file_tracker_flushes_small_pending_count(tmp_path: Path) -> None: target = tmp_path / "small.py" target.write_text("old\n", encoding="utf-8") diff --git a/webui/src/App.tsx b/webui/src/App.tsx index c303446e2..8c6127829 100644 --- a/webui/src/App.tsx +++ b/webui/src/App.tsx @@ -43,6 +43,7 @@ const SIDEBAR_STORAGE_KEY = "nanobot-webui.sidebar"; const COMPLETED_RUNS_STORAGE_KEY = "nanobot-webui.sidebar.completed-runs.v1"; const RESTART_STARTED_KEY = "nanobot-webui.restartStartedAt"; const SIDEBAR_WIDTH = 272; +const SIDEBAR_RAIL_WIDTH = 56; const TOKEN_REFRESH_MARGIN_MS = 30_000; const TOKEN_REFRESH_MIN_DELAY_MS = 5_000; type ShellView = "chat" | "settings"; @@ -411,6 +412,10 @@ function Shell({ setDesktopSidebarOpen(false); }, []); + const openDesktopSidebar = useCallback(() => { + setDesktopSidebarOpen(true); + }, []); + const closeMobileSidebar = useCallback(() => { setMobileSidebarOpen(false); }, []); @@ -560,6 +565,21 @@ function Shell({ setSessionSearchOpen(true); }, []); + useEffect(() => { + const handleKeyDown = (event: globalThis.KeyboardEvent) => { + if (event.defaultPrevented) return; + const plainCommandK = + (event.metaKey || event.ctrlKey) && !event.altKey && !event.shiftKey; + if (!plainCommandK) return; + if (event.key.toLowerCase() !== "k") return; + event.preventDefault(); + onOpenSessionSearch(); + }; + + window.addEventListener("keydown", handleKeyDown); + return () => window.removeEventListener("keydown", handleKeyDown); + }, [onOpenSessionSearch]); + const onSelectSearchResult = useCallback( (key: string) => { setSessionSearchOpen(false); @@ -732,17 +752,19 @@ function Shell({ "relative z-20 hidden shrink-0 overflow-hidden lg:block", "transition-[width] duration-300 ease-out", )} - style={{ width: desktopSidebarOpen ? SIDEBAR_WIDTH : 0 }} + style={{ + width: desktopSidebarOpen ? SIDEBAR_WIDTH : SIDEBAR_RAIL_WIDTH, + }} >
- +
) : null} @@ -769,17 +791,15 @@ function Shell({ ) : null} - {showMainSidebar ? ( - - ) : null} +
{view === "settings" && ( diff --git a/webui/src/components/ChatList.tsx b/webui/src/components/ChatList.tsx index 705039aea..d098a5972 100644 --- a/webui/src/components/ChatList.tsx +++ b/webui/src/components/ChatList.tsx @@ -1,3 +1,9 @@ +import { + memo, + useEffect, + useMemo, + useState, +} from "react"; import { Archive, ArchiveRestore, @@ -19,6 +25,9 @@ import { deriveTitle, relativeTime } from "@/lib/format"; import { cn } from "@/lib/utils"; import type { ChatSummary, SidebarDensity, SidebarSortMode } from "@/lib/types"; +const INITIAL_VISIBLE_SESSIONS = 160; +const VISIBLE_SESSIONS_INCREMENT = 160; + interface ChatListProps { sessions: ChatSummary[]; activeKey: string | null; @@ -42,7 +51,7 @@ interface ChatListProps { emptyLabel?: string; } -export function ChatList({ +export const ChatList = memo(function ChatList({ sessions, activeKey, onSelect, @@ -65,6 +74,52 @@ export function ChatList({ emptyLabel, }: ChatListProps) { const { t } = useTranslation(); + const [visibleLimit, setVisibleLimit] = useState(INITIAL_VISIBLE_SESSIONS); + const labels = useMemo(() => ({ + pinned: t("chat.groups.pinned"), + all: t("chat.groups.all"), + today: t("chat.groups.today"), + yesterday: t("chat.groups.yesterday"), + earlier: t("chat.groups.earlier"), + archived: t("chat.groups.archived"), + fallbackTitle: t("chat.newChat"), + }), [t]); + const groups = useMemo( + () => groupSessions(sessions, labels, { + pinnedKeys, + archivedKeys, + titleOverrides, + showArchived, + sort, + }), + [ + archivedKeys, + labels, + pinnedKeys, + sessions, + showArchived, + sort, + titleOverrides, + ], + ); + const limitedGroups = useMemo( + () => limitGroups(groups, visibleLimit, activeKey), + [activeKey, groups, visibleLimit], + ); + const totalSessionCount = useMemo( + () => groups.reduce((total, group) => total + group.sessions.length, 0), + [groups], + ); + const visibleSessionCount = useMemo( + () => limitedGroups.reduce((total, group) => total + group.sessions.length, 0), + [limitedGroups], + ); + const hiddenSessionCount = Math.max(0, totalSessionCount - visibleSessionCount); + + useEffect(() => { + setVisibleLimit(INITIAL_VISIBLE_SESSIONS); + }, [showArchived, sort]); + if (loading && sessions.length === 0) { return (
@@ -81,21 +136,6 @@ export function ChatList({ ); } - const groups = groupSessions(sessions, { - pinned: t("chat.groups.pinned"), - all: t("chat.groups.all"), - today: t("chat.groups.today"), - yesterday: t("chat.groups.yesterday"), - earlier: t("chat.groups.earlier"), - archived: t("chat.groups.archived"), - fallbackTitle: t("chat.newChat"), - }, { - pinnedKeys, - archivedKeys, - titleOverrides, - showArchived, - sort, - }); const pinned = new Set(pinnedKeys); const archived = new Set(archivedKeys); const running = new Set(runningChatIds); @@ -105,7 +145,7 @@ export function ChatList({ return (
- {groups.map((group) => ( + {limitedGroups.map((group) => (
{group.label} @@ -228,10 +268,25 @@ export function ChatList({
))} + {hiddenSessionCount > 0 ? ( +
+ +
+ ) : null}
); -} +}); function SessionActivityIndicator({ state, @@ -366,6 +421,45 @@ function groupSessions( return groups; } +function limitGroups( + groups: Array<{ label: string; sessions: ChatSummary[] }>, + limit: number, + activeKey: string | null, +): Array<{ label: string; sessions: ChatSummary[] }> { + let remaining = Math.max(0, limit); + let activeVisible = !activeKey; + const out: Array<{ label: string; sessions: ChatSummary[] }> = []; + + for (const group of groups) { + const visible = remaining > 0 + ? group.sessions.slice(0, remaining) + : []; + remaining -= visible.length; + if (activeKey && visible.some((session) => session.key === activeKey)) { + activeVisible = true; + } + if (visible.length > 0) { + out.push({ label: group.label, sessions: visible }); + } + } + + if (activeVisible || !activeKey) return out; + + for (const group of groups) { + const active = group.sessions.find((session) => session.key === activeKey); + if (!active) continue; + const existing = out.find((item) => item.label === group.label); + if (existing) { + existing.sessions = [...existing.sessions, active]; + } else { + out.push({ label: group.label, sessions: [active] }); + } + return out; + } + + return out; +} + function sortSessions( sessions: ChatSummary[], sort: SidebarSortMode, diff --git a/webui/src/components/SessionSearchDialog.tsx b/webui/src/components/SessionSearchDialog.tsx index 3ddea0672..1e0f12044 100644 --- a/webui/src/components/SessionSearchDialog.tsx +++ b/webui/src/components/SessionSearchDialog.tsx @@ -37,13 +37,16 @@ export function SessionSearchDialog({ const [highlightedIndex, setHighlightedIndex] = useState(0); const normalizedQuery = query.trim().toLowerCase(); - const results = useMemo(() => { + const sessionResults = useMemo(() => { + if (!open) return []; if (!normalizedQuery) return sessions; const terms = normalizedQuery.split(/\s+/).filter(Boolean); return sessions.filter((session) => sessionMatchesTerms(session, terms, titleOverrides[session.key]), ); - }, [normalizedQuery, sessions, titleOverrides]); + }, [normalizedQuery, open, sessions, titleOverrides]); + const itemCount = sessionResults.length; + const shortcutLabel = useMemo(getSearchShortcutLabel, []); useEffect(() => { if (!open) return; @@ -58,9 +61,9 @@ export function SessionSearchDialog({ useEffect(() => { setHighlightedIndex((index) => - results.length === 0 ? 0 : Math.min(index, results.length - 1), + itemCount === 0 ? 0 : Math.min(index, itemCount - 1), ); - }, [results.length]); + }, [itemCount]); const handleSelect = (key: string) => { onOpenChange(false); @@ -71,17 +74,19 @@ export function SessionSearchDialog({ if (event.key === "ArrowDown") { event.preventDefault(); setHighlightedIndex((index) => - results.length === 0 ? 0 : Math.min(index + 1, results.length - 1), + itemCount === 0 ? 0 : (index + 1) % itemCount, ); return; } if (event.key === "ArrowUp") { event.preventDefault(); - setHighlightedIndex((index) => Math.max(index - 1, 0)); + setHighlightedIndex((index) => + itemCount === 0 ? 0 : (index - 1 + itemCount) % itemCount, + ); return; } if (event.key === "Enter") { - const highlighted = results[highlightedIndex]; + const highlighted = sessionResults[highlightedIndex]; if (!highlighted) return; event.preventDefault(); handleSelect(highlighted.key); @@ -125,70 +130,75 @@ export function SessionSearchDialog({ aria-label={t("sidebar.searchAria")} className="h-full min-w-0 flex-1 bg-transparent text-[15px] font-medium text-foreground outline-none placeholder:text-muted-foreground/75" /> + + {shortcutLabel} +
-
- {sectionLabel} -
+
+
+ {sectionLabel} +
- {loading && sessions.length === 0 ? ( -
- {t("chat.loading")} -
- ) : results.length === 0 ? ( -
- {emptyLabel} -
- ) : ( -
    - {results.map((session, index) => { - const title = titleOverrides[session.key]?.trim() || - session.title?.trim() || - deriveTitle(session.preview, t("chat.newChat")); - const preview = session.preview.trim(); - const showPreview = - preview.length > 0 && - preview.toLowerCase() !== title.trim().toLowerCase(); - const highlighted = index === highlightedIndex; - const active = session.key === activeKey; - return ( -
  • - -
  • - ); - })} -
- )} + {showPreview ? ( + + {preview} + + ) : null} + + + + ); + })} + + )} +
@@ -211,3 +221,13 @@ function sessionMatchesTerms( return terms.every((term) => haystack.includes(term)); } + +function getSearchShortcutLabel() { + if (typeof navigator === "undefined") return "Ctrl K"; + const platform = navigator.platform.toLowerCase(); + const apple = + platform.includes("mac") || + platform.includes("iphone") || + platform.includes("ipad"); + return apple ? "⌘K" : "Ctrl K"; +} diff --git a/webui/src/components/Sidebar.tsx b/webui/src/components/Sidebar.tsx index 53acfc3a9..1040c9696 100644 --- a/webui/src/components/Sidebar.tsx +++ b/webui/src/components/Sidebar.tsx @@ -1,4 +1,4 @@ -import { useState } from "react"; +import { useState, type ReactNode } from "react"; import { Archive, ListFilter, @@ -28,6 +28,7 @@ import type { SidebarSortMode, SidebarViewState, } from "@/lib/types"; +import { cn } from "@/lib/utils"; interface SidebarProps { sessions: ChatSummary[]; @@ -44,7 +45,9 @@ interface SidebarProps { onToggleArchived: () => void; onUpdateView: (view: Partial) => void; onCollapse: () => void; + onExpand?: () => void; containActionMenus?: boolean; + collapsed?: boolean; pinnedKeys?: string[]; archivedKeys?: string[]; titleOverrides?: Record; @@ -59,6 +62,8 @@ export function Sidebar(props: SidebarProps) { const { t } = useTranslation(); const [menuPortalContainer, setMenuPortalContainer] = useState(null); + const collapsed = Boolean(props.collapsed); + const toggleLabel = t("thread.header.toggleSidebar"); return ( ); } +function SidebarActionButton({ + collapsed, + label, + icon, + onClick, + className, +}: { + collapsed: boolean; + label: string; + icon: ReactNode; + onClick: () => void; + className?: string; +}) { + return ( + + ); +} + function SidebarViewMenu({ + compact = false, view, onUpdateView, }: { + compact?: boolean; view?: SidebarViewState; onUpdateView: (view: Partial) => void; }) { @@ -182,11 +268,28 @@ function SidebarViewMenu({ diff --git a/webui/src/components/thread/AgentActivityCluster.tsx b/webui/src/components/thread/AgentActivityCluster.tsx index 46be8b43d..135f33ef6 100644 --- a/webui/src/components/thread/AgentActivityCluster.tsx +++ b/webui/src/components/thread/AgentActivityCluster.tsx @@ -27,6 +27,7 @@ interface ActivityCounts { fileCount: number; added: number; deleted: number; + hasDiffStats: boolean; hasEditingFiles: boolean; hasFailedFiles: boolean; primaryFilePath?: string; @@ -61,6 +62,7 @@ function countActivity(messages: UIMessage[], fileEdits: FileEditSummary[]): Act } let added = 0; let deleted = 0; + let hasDiffStats = false; let hasEditingFiles = false; let failedFileCount = 0; let primaryFilePath: string | undefined; @@ -77,6 +79,10 @@ function countActivity(messages: UIMessage[], fileEdits: FileEditSummary[]): Act if (edit.status === "error" || edit.binary) { continue; } + if (!hasVisibleDiffStats(edit)) { + continue; + } + hasDiffStats = true; added += edit.added; deleted += edit.deleted; } @@ -86,6 +92,7 @@ function countActivity(messages: UIMessage[], fileEdits: FileEditSummary[]): Act fileCount: fileEdits.length, added, deleted, + hasDiffStats, hasEditingFiles, hasFailedFiles: fileEdits.length > 0 && failedFileCount === fileEdits.length, primaryFilePath, @@ -120,6 +127,7 @@ export function AgentActivityCluster({ fileCount, added, deleted, + hasDiffStats, hasEditingFiles, hasFailedFiles, primaryFilePath, @@ -140,6 +148,7 @@ export function AgentActivityCluster({ const headerBusy = fileCount > 0 ? hasEditingFiles : isTurnStreaming; const singleFilePath = fileCount === 1 ? primaryFilePath : undefined; const singleFileTooltipPath = fileCount === 1 ? primaryFileTooltipPath : undefined; + const hasVisibleActivity = reasoningSteps > 0 || toolCalls > 0 || fileCount > 0; const fileActivitySummary = fileCount > 0 ? hasPendingFileEdit && !singleFilePath @@ -243,6 +252,8 @@ export function AgentActivityCluster({ autoFollowActivityRef.current = distance < ACTIVITY_SCROLL_NEAR_BOTTOM_PX; }, []); + if (!hasVisibleActivity) return null; + return (
- +
); } @@ -52,7 +57,7 @@ export function ThreadHeader({ onClick={onToggleSidebar} className={cn( "h-7 w-7 rounded-md text-muted-foreground hover:bg-accent/35 hover:text-foreground", - hideSidebarToggleOnDesktop && "lg:pointer-events-none lg:opacity-0", + hideSidebarToggleOnDesktop && "lg:hidden", )} > @@ -62,7 +67,12 @@ export function ThreadHeader({
- +
@@ -73,10 +83,12 @@ function ThemeButton({ theme, onToggleTheme, label, + className, }: { theme: "light" | "dark"; onToggleTheme: () => void; label: string; + className?: string; }) { return (