From 09a692be6f6630b951c805467508b9b2b30f18e8 Mon Sep 17 00:00:00 2001 From: chengyongru Date: Wed, 20 May 2026 18:21:48 +0800 Subject: [PATCH 01/54] docs(readme): add multi-language doc site links MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Link nanobot.wiki documentation in 10 languages from README header: English, 简体中文, 繁體中文, Español, Français, Bahasa Indonesia, 日本語, 한국어, Русский, Tiếng Việt. --- README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/README.md b/README.md index b5e4b02c0..bf751f58a 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,18 @@ ![cover-v5-optimized](./images/GitHub_README.png)
+

+ English | + 简体中文 | + 繁體中文 | + Español | + Français | + Bahasa Indonesia | + 日本語 | + 한국어 | + Русский | + Tiếng Việt +

PyPI Downloads From 6851fa57a6d0ca75e754065cc56c5c88a4332cc5 Mon Sep 17 00:00:00 2001 From: Xubin Ren <52506698+Re-bin@users.noreply.github.com> Date: Wed, 20 May 2026 12:51:26 +0800 Subject: [PATCH 02/54] feat(tools): optimize coding workflows --- README.md | 2 +- nanobot/agent/tools/apply_patch.py | 341 +++++++++++++++ nanobot/agent/tools/exec_session.py | 409 ++++++++++++++++++ nanobot/agent/tools/filesystem.py | 109 ++++- nanobot/agent/tools/notebook.py | 162 ------- nanobot/agent/tools/shell.py | 263 +++++++++-- nanobot/utils/file_edit_events.py | 76 +--- tests/tools/test_apply_patch_tool.py | 238 ++++++++++ tests/tools/test_edit_enhancements.py | 19 +- tests/tools/test_exec_session_tools.py | 242 +++++++++++ .../test_file_edit_coding_enhancements.py | 216 +++++++++ tests/tools/test_notebook_tool.py | 147 ------- tests/tools/test_tool_loader.py | 4 +- 13 files changed, 1771 insertions(+), 457 deletions(-) create mode 100644 nanobot/agent/tools/apply_patch.py create mode 100644 nanobot/agent/tools/exec_session.py delete mode 100644 nanobot/agent/tools/notebook.py create mode 100644 tests/tools/test_apply_patch_tool.py create mode 100644 tests/tools/test_exec_session_tools.py create mode 100644 tests/tools/test_file_edit_coding_enhancements.py delete mode 100644 tests/tools/test_notebook_tool.py diff --git a/README.md b/README.md index b5e4b02c0..b97545731 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ - **2026-04-13** 🛡️ Agent turn hardened — user messages persisted early, auto-compact skips active tasks. - **2026-04-12** 🔒 Lark global domain support, Dream learns discovered skills, shell sandbox tightened. - **2026-04-11** ⚡ Context compact shrinks sessions on the fly; Kagi web search; QQ & WeCom full media. -- **2026-04-10** 📓 Notebook editing tool, multiple MCP servers, Feishu streaming & done-emoji. +- **2026-04-10** 📓 Multiple MCP servers, Feishu streaming & done-emoji. - **2026-04-09** 🔌 WebSocket channel, unified cross-channel session, `disabled_skills` config. - **2026-04-08** 📤 API file uploads, OpenAI reasoning auto-routing with Responses fallback. - **2026-04-07** 🧠 Anthropic adaptive thinking, MCP resources & prompts exposed as tools. diff --git a/nanobot/agent/tools/apply_patch.py b/nanobot/agent/tools/apply_patch.py new file mode 100644 index 000000000..e69a65f10 --- /dev/null +++ b/nanobot/agent/tools/apply_patch.py @@ -0,0 +1,341 @@ +"""Structured patch editing tool for coding workflows.""" + +from __future__ import annotations + +import difflib +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Literal + +from nanobot.agent.tools.base import tool_parameters +from nanobot.agent.tools.filesystem import _FsTool +from nanobot.agent.tools.schema import StringSchema, tool_parameters_schema + + +PatchKind = Literal["add", "delete", "update"] + + +@dataclass(slots=True) +class _Hunk: + header: str | None + lines: list[tuple[str, str]] + + +@dataclass(slots=True) +class _PatchOp: + kind: PatchKind + path: str + new_path: str | None = None + add_lines: list[str] | None = None + hunks: list[_Hunk] | None = None + + +class _PatchError(ValueError): + pass + + +_ABSOLUTE_WINDOWS_RE = re.compile(r"^[A-Za-z]:[\\/]") + + +def _is_file_header(line: str) -> bool: + return ( + line.startswith("*** Add File: ") + or line.startswith("*** Delete File: ") + or line.startswith("*** Update File: ") + ) + + +def _validate_relative_path(path: str) -> str: + normalized = path.strip() + if not normalized: + raise _PatchError("patch path cannot be empty") + if "\0" in normalized: + raise _PatchError(f"patch path contains a null byte: {path!r}") + if normalized.startswith(("~", "/", "\\")) or _ABSOLUTE_WINDOWS_RE.match(normalized): + raise _PatchError(f"patch path must be relative: {path}") + if any(part == ".." for part in re.split(r"[\\/]+", normalized)): + raise _PatchError(f"patch path must not contain '..': {path}") + return normalized + + +def _lines_to_text(lines: list[str]) -> str: + if not lines: + return "" + return "\n".join(lines) + "\n" + + +def _parse_patch(patch: str) -> list[_PatchOp]: + lines = patch.replace("\r\n", "\n").replace("\r", "\n").split("\n") + if lines and lines[-1] == "": + lines.pop() + if not lines or lines[0] != "*** Begin Patch": + raise _PatchError("patch must start with '*** Begin Patch'") + if len(lines) < 2 or lines[-1] != "*** End Patch": + raise _PatchError("patch must end with '*** End Patch'") + + ops: list[_PatchOp] = [] + i = 1 + end = len(lines) - 1 + while i < end: + line = lines[i] + if line.startswith("*** Add File: "): + path = _validate_relative_path(line.removeprefix("*** Add File: ")) + i += 1 + add_lines: list[str] = [] + while i < end and not _is_file_header(lines[i]): + if not lines[i].startswith("+"): + raise _PatchError(f"Add File lines must start with '+': {lines[i]!r}") + add_lines.append(lines[i][1:]) + i += 1 + ops.append(_PatchOp(kind="add", path=path, add_lines=add_lines)) + continue + + if line.startswith("*** Delete File: "): + path = _validate_relative_path(line.removeprefix("*** Delete File: ")) + ops.append(_PatchOp(kind="delete", path=path)) + i += 1 + continue + + if line.startswith("*** Update File: "): + path = _validate_relative_path(line.removeprefix("*** Update File: ")) + i += 1 + new_path: str | None = None + if i < end and lines[i].startswith("*** Move to: "): + new_path = _validate_relative_path(lines[i].removeprefix("*** Move to: ")) + i += 1 + + hunks: list[_Hunk] = [] + while i < end and not _is_file_header(lines[i]): + if not lines[i].startswith("@@"): + raise _PatchError(f"Update File sections require '@@' hunks: {lines[i]!r}") + header = lines[i][2:].strip() or None + i += 1 + hunk_lines: list[tuple[str, str]] = [] + while i < end and not lines[i].startswith("@@") and not _is_file_header(lines[i]): + if lines[i] == "*** End of File": + i += 1 + break + if lines[i] == r"\ No newline at end of file": + i += 1 + continue + if not lines[i] or lines[i][0] not in {" ", "+", "-"}: + raise _PatchError(f"Hunk lines must start with ' ', '+', or '-': {lines[i]!r}") + hunk_lines.append((lines[i][0], lines[i][1:])) + i += 1 + if not hunk_lines: + raise _PatchError(f"Update File hunk is empty: {path}") + hunks.append(_Hunk(header=header, lines=hunk_lines)) + if not hunks and new_path is None: + raise _PatchError(f"Update File requires at least one hunk or Move to: {path}") + ops.append(_PatchOp(kind="update", path=path, new_path=new_path, hunks=hunks)) + continue + + raise _PatchError(f"unknown patch header: {line!r}") + + if not ops: + raise _PatchError("patch contains no file operations") + return ops + + +def _find_with_eof_fallback(content: str, needle: str, start: int) -> tuple[int, int]: + pos = content.find(needle, start) + if pos >= 0: + return pos, len(needle) + if needle.endswith("\n"): + trimmed = needle[:-1] + pos = content.find(trimmed, start) + if pos >= 0 and pos + len(trimmed) == len(content): + return pos, len(trimmed) + return -1, 0 + + +def _line_offset(content: str, line_number: int) -> int: + if line_number <= 1: + return 0 + offset = 0 + for current, line in enumerate(content.splitlines(keepends=True), start=1): + if current >= line_number: + return offset + offset += len(line) + return len(content) + + +def _line_hint(header: str | None) -> int | None: + if not header: + return None + match = re.search(r"-(\d+)(?:,\d+)?", header) + return int(match.group(1)) if match else None + + +def _hunk_mismatch(path: str, old_text: str, content: str, header: str | None) -> str: + lines = content.splitlines(keepends=True) + old_lines = old_text.splitlines(keepends=True) + window = max(1, len(old_lines)) + best_ratio, best_start = -1.0, 0 + best_lines: list[str] = [] + for i in range(max(1, len(lines) - window + 1)): + current = lines[i : i + window] + ratio = difflib.SequenceMatcher(None, "".join(old_lines), "".join(current)).ratio() + if ratio > best_ratio: + best_ratio, best_start, best_lines = ratio, i, current + + label = f" after header {header!r}" if header else "" + if best_ratio <= 0: + return f"hunk does not match {path}{label}" + diff = "\n".join(difflib.unified_diff( + old_lines, + best_lines, + fromfile="patch hunk", + tofile=f"{path} (actual, line {best_start + 1})", + lineterm="", + )) + return ( + f"hunk does not match {path}{label}. " + f"Best match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}" + ) + + +def _apply_hunks(path: str, content: str, hunks: list[_Hunk]) -> str: + cursor = 0 + for hunk in hunks: + old_lines = [text for marker, text in hunk.lines if marker in {" ", "-"}] + new_lines = [text for marker, text in hunk.lines if marker in {" ", "+"}] + old_text = _lines_to_text(old_lines) + new_text = _lines_to_text(new_lines) + + search_start = cursor + line_hint = None + if hunk.header: + line_hint = _line_hint(hunk.header) + if line_hint is not None: + search_start = _line_offset(content, line_hint) + else: + header_pos = content.find(hunk.header, cursor) + if header_pos >= 0: + search_start = header_pos + + if old_text: + pos, match_len = _find_with_eof_fallback(content, old_text, search_start) + if pos < 0 and search_start != 0 and line_hint is None: + pos, match_len = _find_with_eof_fallback(content, old_text, 0) + if pos < 0: + raise _PatchError(_hunk_mismatch(path, old_text, content, hunk.header)) + else: + pos = search_start + match_len = 0 + + content = content[:pos] + new_text + content[pos + match_len:] + cursor = pos + len(new_text) + return content + + +@tool_parameters( + tool_parameters_schema( + patch=StringSchema( + "Full patch text. Use *** Begin Patch / *** End Patch and file sections " + "for Add File, Update File, Delete File, and optional Move to.", + min_length=1, + ), + required=["patch"], + ) +) +class ApplyPatchTool(_FsTool): + """Apply a structured multi-file patch.""" + _scopes = {"core", "subagent"} + + @property + def name(self) -> str: + return "apply_patch" + + @property + def description(self) -> str: + return ( + "Apply a structured patch for code edits. The patch must include " + "*** Begin Patch and *** End Patch. Supports Add File, Update File, " + "Delete File, and Move to. Paths must be relative. Prefer this for " + "multi-file coding changes; use edit_file for small exact replacements." + ) + + async def execute(self, patch: str, **kwargs: Any) -> str: + try: + ops = _parse_patch(patch) + writes: dict[Path, str] = {} + deletes: set[Path] = set() + touched: list[str] = [] + + for op in ops: + source = self._resolve(op.path) + if op.kind == "add": + if source.exists(): + raise _PatchError(f"file to add already exists: {op.path}") + writes[source] = _lines_to_text(op.add_lines or []) + deletes.discard(source) + touched.append(f"add {op.path}") + continue + + if op.kind == "delete": + if not source.exists(): + raise _PatchError(f"file to delete does not exist: {op.path}") + if not source.is_file(): + raise _PatchError(f"path to delete is not a file: {op.path}") + deletes.add(source) + writes.pop(source, None) + touched.append(f"delete {op.path}") + continue + + if not source.exists(): + raise _PatchError(f"file to update does not exist: {op.path}") + if not source.is_file(): + raise _PatchError(f"path to update is not a file: {op.path}") + raw = source.read_bytes() + try: + content = raw.decode("utf-8") + except UnicodeDecodeError as exc: + raise _PatchError(f"file to update is not UTF-8 text: {op.path}") from exc + uses_crlf = "\r\n" in content + content = content.replace("\r\n", "\n") + new_content = _apply_hunks(op.path, content, op.hunks or []) + if uses_crlf: + new_content = new_content.replace("\n", "\r\n") + + target = self._resolve(op.new_path) if op.new_path else source + if op.new_path and target.exists() and target != source: + raise _PatchError(f"move target already exists: {op.new_path}") + writes[target] = new_content + deletes.discard(target) + if target != source: + deletes.add(source) + action = f"move {op.path} -> {op.new_path}" if op.new_path else f"update {op.path}" + touched.append(action) + + backups: dict[Path, bytes | None] = {} + for path in set(writes) | deletes: + backups[path] = path.read_bytes() if path.exists() else None + + try: + for path in deletes: + if path.exists(): + path.unlink() + for path, content in writes.items(): + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(content, encoding="utf-8", newline="") + except Exception: + for path, data in backups.items(): + if data is None: + if path.exists(): + path.unlink() + else: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(data) + raise + + for path in set(writes) | deletes: + self._file_states.record_write(path) + return "Patch applied:\n" + "\n".join(f"- {item}" for item in touched) + except PermissionError as exc: + return f"Error: {exc}" + except _PatchError as exc: + return f"Error applying patch: {exc}" + except Exception as exc: + return f"Error applying patch: {exc}" diff --git a/nanobot/agent/tools/exec_session.py b/nanobot/agent/tools/exec_session.py new file mode 100644 index 000000000..202fbc640 --- /dev/null +++ b/nanobot/agent/tools/exec_session.py @@ -0,0 +1,409 @@ +"""Session support for long-running exec workflows.""" + +from __future__ import annotations + +import asyncio +import shutil +import time +import uuid +from contextlib import suppress +from dataclasses import dataclass +from typing import Any + +from nanobot.agent.tools.base import Tool, tool_parameters +from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema + + +DEFAULT_YIELD_MS = 1000 +MAX_YIELD_MS = 30_000 +DEFAULT_MAX_OUTPUT_CHARS = 10_000 +MAX_OUTPUT_CHARS = 50_000 + + +@dataclass(slots=True) +class _SessionPoll: + output: str + done: bool + exit_code: int | None + timed_out: bool = False + terminated: bool = False + stdin_closed: bool = False + truncated_chars: int = 0 + + +class _ExecSession: + def __init__( + self, + *, + session_id: str, + process: asyncio.subprocess.Process, + timeout: int, + ) -> None: + self.session_id = session_id + self.process = process + self.deadline = time.monotonic() + timeout + self.last_access = time.monotonic() + self._chunks: list[str] = [] + self._lock = asyncio.Lock() + self._timed_out = False + self._stdout_task = asyncio.create_task(self._read_stream(process.stdout, "")) + self._stderr_task = asyncio.create_task(self._read_stream(process.stderr, "STDERR:\n")) + + async def _read_stream( + self, + stream: asyncio.StreamReader | None, + prefix: str, + ) -> None: + if stream is None: + return + first = True + while True: + chunk = await stream.read(4096) + if not chunk: + break + text = chunk.decode("utf-8", errors="replace") + if prefix and first: + text = prefix + text + first = False + async with self._lock: + self._chunks.append(text) + + async def write(self, chars: str) -> str | None: + if self.process.returncode is not None: + return "session has already exited" + if self.process.stdin is None: + return "session stdin is not available" + try: + self.process.stdin.write(chars.encode("utf-8")) + await self.process.stdin.drain() + except (BrokenPipeError, ConnectionResetError): + return "session stdin is closed" + return None + + async def close_stdin(self) -> str | None: + if self.process.returncode is not None: + return "session has already exited" + if self.process.stdin is None: + return "session stdin is not available" + self.process.stdin.close() + with suppress(BrokenPipeError, ConnectionResetError): + await self.process.stdin.wait_closed() + return None + + async def poll( + self, + yield_time_ms: int, + max_output_chars: int, + *, + terminated: bool = False, + stdin_closed: bool = False, + ) -> _SessionPoll: + self.last_access = time.monotonic() + if yield_time_ms > 0 and self.process.returncode is None: + await asyncio.sleep(min(yield_time_ms, MAX_YIELD_MS) / 1000) + + if self.process.returncode is None and time.monotonic() >= self.deadline: + self._timed_out = True + await self.kill() + + if self.process.returncode is not None: + with suppress(asyncio.TimeoutError): + await asyncio.wait_for( + asyncio.gather(self._stdout_task, self._stderr_task), + timeout=2.0, + ) + + async with self._lock: + output = "".join(self._chunks) + self._chunks.clear() + + output, truncated = _truncate_output(output, max_output_chars) + return _SessionPoll( + output=output, + done=self.process.returncode is not None, + exit_code=self.process.returncode, + timed_out=self._timed_out, + terminated=terminated, + stdin_closed=stdin_closed, + truncated_chars=truncated, + ) + + async def kill(self) -> None: + if self.process.returncode is not None: + return + self.process.kill() + with suppress(asyncio.TimeoutError): + await asyncio.wait_for(self.process.wait(), timeout=5.0) + + +class ExecSessionManager: + def __init__(self, *, max_sessions: int = 8, idle_timeout: int = 1800) -> None: + self.max_sessions = max_sessions + self.idle_timeout = idle_timeout + self._sessions: dict[str, _ExecSession] = {} + self._lock = asyncio.Lock() + + async def start( + self, + *, + command: str, + cwd: str, + env: dict[str, str], + timeout: int, + shell_program: str | None, + login: bool, + yield_time_ms: int, + max_output_chars: int, + ) -> tuple[str, _SessionPoll]: + async with self._lock: + await self._cleanup_locked() + if len(self._sessions) >= self.max_sessions: + raise RuntimeError(f"maximum exec sessions reached ({self.max_sessions})") + process = await self._spawn(command, cwd, env, shell_program, login) + session_id = uuid.uuid4().hex[:12] + session = _ExecSession(session_id=session_id, process=process, timeout=timeout) + self._sessions[session_id] = session + + poll = await session.poll(yield_time_ms, max_output_chars) + if poll.done: + async with self._lock: + self._sessions.pop(session_id, None) + return session_id, poll + + async def write( + self, + *, + session_id: str, + chars: str | None, + close_stdin: bool, + terminate: bool, + yield_time_ms: int, + max_output_chars: int, + ) -> _SessionPoll: + async with self._lock: + await self._cleanup_locked() + session = self._sessions.get(session_id) + if session is None: + raise KeyError(session_id) + + if chars is not None: + error = await session.write(chars) + if error: + raise RuntimeError(error) + stdin_closed = False + if close_stdin: + error = await session.close_stdin() + if error: + raise RuntimeError(error) + stdin_closed = True + if terminate: + await session.kill() + poll = await session.poll( + yield_time_ms, + max_output_chars, + terminated=terminate, + stdin_closed=stdin_closed, + ) + if poll.done: + async with self._lock: + self._sessions.pop(session_id, None) + return poll + + async def _cleanup_locked(self) -> None: + now = time.monotonic() + stale = [ + session_id + for session_id, session in self._sessions.items() + if session.process.returncode is not None + or now - session.last_access > self.idle_timeout + ] + for session_id in stale: + session = self._sessions.pop(session_id) + await session.kill() + + async def _spawn( + self, + command: str, + cwd: str, + env: dict[str, str], + shell_program: str | None, + login: bool, + ) -> asyncio.subprocess.Process: + from nanobot.agent.tools import shell + + if shell._IS_WINDOWS: + return await asyncio.create_subprocess_shell( + command, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=cwd, + env=env, + ) + shell_program = shell_program or shutil.which("bash") or "/bin/bash" + args = [shell_program] + if login and shell_program.rsplit("/", 1)[-1] in {"bash", "zsh"}: + args.append("-l") + args.extend(["-c", command]) + return await asyncio.create_subprocess_exec( + *args, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=cwd, + env=env, + ) + + +DEFAULT_EXEC_SESSION_MANAGER = ExecSessionManager() + + +def clamp_session_int(value: int | None, default: int, minimum: int, maximum: int) -> int: + if value is None: + return default + return min(max(value, minimum), maximum) + + +def _truncate_output(output: str, max_output_chars: int) -> tuple[str, int]: + if len(output) <= max_output_chars: + return output, 0 + half = max_output_chars // 2 + omitted = len(output) - max_output_chars + return ( + output[:half] + + f"\n\n... ({omitted:,} chars truncated) ...\n\n" + + output[-half:], + omitted, + ) + + +def format_session_poll(session_id: str, poll: _SessionPoll) -> str: + parts = [poll.output] if poll.output else [] + if poll.truncated_chars: + parts.append(f"(output truncated by {poll.truncated_chars:,} chars)") + if poll.timed_out: + parts.append("Error: Command timed out; session was terminated.") + if poll.terminated and not poll.timed_out: + parts.append("Session terminated.") + if poll.stdin_closed: + parts.append("Stdin closed.") + if poll.done: + parts.append(f"Exit code: {poll.exit_code}") + else: + parts.append(f"Process running. session_id: {session_id}") + return "\n".join(parts) if parts else "(no output yet)" + + +@tool_parameters( + tool_parameters_schema( + session_id=StringSchema("Session id returned by exec when yield_time_ms is used."), + chars=StringSchema( + "Bytes/text to write to stdin. Omit or pass an empty string to only poll recent output.", + nullable=True, + ), + close_stdin=BooleanSchema( + description="Close stdin after writing chars. Useful for commands waiting for EOF.", + default=False, + ), + terminate=BooleanSchema( + description="Terminate the running exec session.", + default=False, + ), + yield_time_ms=IntegerSchema( + DEFAULT_YIELD_MS, + description="Milliseconds to wait before returning recent output (default 1000, max 30000).", + minimum=0, + maximum=MAX_YIELD_MS, + ), + max_output_chars=IntegerSchema( + DEFAULT_MAX_OUTPUT_CHARS, + description="Maximum output characters to return from this poll (default 10000, max 50000).", + minimum=1000, + maximum=MAX_OUTPUT_CHARS, + ), + max_output_tokens=IntegerSchema( + DEFAULT_MAX_OUTPUT_CHARS, + description="Compatibility alias for max_output_chars. The current runtime uses a character budget.", + minimum=1000, + maximum=MAX_OUTPUT_CHARS, + nullable=True, + ), + required=["session_id"], + ) +) +class WriteStdinTool(Tool): + """Write to or poll a running exec session.""" + + _scopes = {"core", "subagent"} + config_key = "exec" + + @classmethod + def config_cls(cls): + from nanobot.agent.tools.shell import ExecToolConfig + + return ExecToolConfig + + @classmethod + def enabled(cls, ctx: Any) -> bool: + return ctx.config.exec.enable + + def __init__( + self, + *, + manager: ExecSessionManager | None = None, + ) -> None: + self._manager = manager or DEFAULT_EXEC_SESSION_MANAGER + + @classmethod + def create(cls, ctx: Any) -> Tool: + return cls() + + @property + def exclusive(self) -> bool: + return True + + @property + def name(self) -> str: + return "write_stdin" + + @property + def description(self) -> str: + return ( + "Write text to a running exec session and return recent output. " + "Use chars='' to poll without writing. Set close_stdin=true to send EOF, " + "or terminate=true to stop the session. Sessions finish automatically " + "when their process exits." + ) + + async def execute( + self, + session_id: str, + chars: str | None = None, + close_stdin: bool = False, + terminate: bool = False, + yield_time_ms: int | None = None, + max_output_chars: int | None = None, + max_output_tokens: int | None = None, + **kwargs: Any, + ) -> str: + try: + if max_output_chars is None: + max_output_chars = max_output_tokens + poll = await self._manager.write( + session_id=session_id, + chars=chars, + close_stdin=close_stdin, + terminate=terminate, + yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS), + max_output_chars=clamp_session_int( + max_output_chars, + DEFAULT_MAX_OUTPUT_CHARS, + 1000, + MAX_OUTPUT_CHARS, + ), + ) + return format_session_poll(session_id, poll) + except KeyError: + return f"Error: exec session not found: {session_id}" + except Exception as exc: + return f"Error writing to exec session: {exc}" diff --git a/nanobot/agent/tools/filesystem.py b/nanobot/agent/tools/filesystem.py index 8f4f660da..728ff9317 100644 --- a/nanobot/agent/tools/filesystem.py +++ b/nanobot/agent/tools/filesystem.py @@ -132,6 +132,10 @@ def _parse_page_range(pages: str, total: int) -> tuple[int, int]: minimum=1, ), pages=StringSchema("Page range for PDF files, e.g. '1-5' (default: all, max 20 pages)"), + force=BooleanSchema( + description="Bypass same-file read deduplication and return content again.", + default=False, + ), required=["path"], ) ) @@ -155,6 +159,7 @@ class ReadFileTool(_FsTool): "Images return visual content for analysis. " "Supports PDF, DOCX, XLSX, PPTX documents. " "Use offset and limit for large text files. " + "Use force=true to re-read content even if unchanged. " "Reads exceeding ~128K chars are truncated." ) @@ -162,7 +167,15 @@ class ReadFileTool(_FsTool): def read_only(self) -> bool: return True - async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, pages: str | None = None, **kwargs: Any) -> Any: + async def execute( + self, + path: str | None = None, + offset: int = 1, + limit: int | None = None, + pages: str | None = None, + force: bool = False, + **kwargs: Any, + ) -> Any: try: if not path: return "Error reading file: Unknown path" @@ -202,7 +215,13 @@ class ReadFileTool(_FsTool): current_mtime = os.path.getmtime(fp) except OSError: current_mtime = 0.0 - if entry and entry.can_dedup and entry.offset == offset and entry.limit == limit: + if ( + not force + and entry + and entry.can_dedup + and entry.offset == offset + and entry.limit == limit + ): if current_mtime != entry.mtime: # File was modified externally - force full read and mark as not dedupable entry.can_dedup = False @@ -657,6 +676,24 @@ def _find_match(content: str, old_text: str) -> tuple[str | None, int]: old_text=StringSchema("The text to find and replace"), new_text=StringSchema("The text to replace with"), replace_all=BooleanSchema(description="Replace all occurrences (default false)"), + occurrence=IntegerSchema( + 1, + description="Optional 1-based occurrence to replace when old_text appears multiple times.", + minimum=1, + nullable=True, + ), + line_hint=IntegerSchema( + 1, + description="Optional 1-based line hint used to choose the nearest match.", + minimum=1, + nullable=True, + ), + expected_replacements=IntegerSchema( + 1, + description="Optional guard for the number of replacements that must be made.", + minimum=1, + nullable=True, + ), required=["path", "old_text", "new_text"], ) ) @@ -677,7 +714,7 @@ class EditFileTool(_FsTool): "Edit a file by replacing old_text with new_text. " "Tolerates minor whitespace/indentation differences and curly/straight quote mismatches. " "If old_text matches multiple times, you must provide more context " - "or set replace_all=true. Shows a diff of the closest match on failure." + "or set occurrence/line_hint/replace_all. Shows a diff of the closest match on failure." ) @staticmethod @@ -688,7 +725,8 @@ class EditFileTool(_FsTool): async def execute( self, path: str | None = None, old_text: str | None = None, new_text: str | None = None, - replace_all: bool = False, **kwargs: Any, + replace_all: bool = False, occurrence: int | None = None, + line_hint: int | None = None, expected_replacements: int | None = None, **kwargs: Any, ) -> str: try: if not path: @@ -697,10 +735,12 @@ class EditFileTool(_FsTool): raise ValueError("Unknown old_text") if new_text is None: raise ValueError("Unknown new_text") - - # .ipynb detection - if path.endswith(".ipynb"): - return "Error: This is a Jupyter notebook. Use the notebook_edit tool instead of edit_file." + if occurrence is not None and occurrence < 1: + return "Error: occurrence must be >= 1." + if line_hint is not None and line_hint < 1: + return "Error: line_hint must be >= 1." + if expected_replacements is not None and expected_replacements < 1: + return "Error: expected_replacements must be >= 1." fp = self._resolve(path) @@ -743,15 +783,42 @@ class EditFileTool(_FsTool): if not matches: return self._not_found_msg(old_text, content, path) count = len(matches) + if replace_all and occurrence is not None: + return "Error: occurrence cannot be used with replace_all=true." + if replace_all and line_hint is not None: + return "Error: line_hint cannot be used with replace_all=true." + if occurrence is not None and line_hint is not None: + return "Error: line_hint cannot be used with occurrence." if count > 1 and not replace_all: - line_numbers = [match.line for match in matches] - preview = ", ".join(f"line {n}" for n in line_numbers[:3]) - if len(line_numbers) > 3: - preview += ", ..." - location_hint = f" at {preview}" if preview else "" + if occurrence is not None: + if occurrence > count: + return ( + f"Error: occurrence {occurrence} is out of range; " + f"old_text appears {count} times." + ) + elif line_hint is not None: + nearest = min(matches, key=lambda match: abs(match.line - line_hint)) + distance = abs(nearest.line - line_hint) + if sum(1 for match in matches if abs(match.line - line_hint) == distance) > 1: + return ( + f"Error: line_hint {line_hint} is ambiguous; " + f"old_text appears {count} times." + ) + else: + line_numbers = [match.line for match in matches] + preview = ", ".join(f"line {n}" for n in line_numbers[:3]) + if len(line_numbers) > 3: + preview += ", ..." + location_hint = f" at {preview}" if preview else "" + return ( + f"Warning: old_text appears {count} times{location_hint}. " + "Provide more context, set occurrence to choose one match, " + "or set replace_all=true." + ) + elif occurrence is not None and occurrence > count: return ( - f"Warning: old_text appears {count} times{location_hint}. " - "Provide more context to make it unique, or set replace_all=true." + f"Error: occurrence {occurrence} is out of range; " + f"old_text appears {count} time." ) norm_new = new_text.replace("\r\n", "\n") @@ -760,7 +827,17 @@ class EditFileTool(_FsTool): if fp.suffix.lower() not in self._MARKDOWN_EXTS: norm_new = self._strip_trailing_ws(norm_new) - selected = matches if replace_all else matches[:1] + if replace_all: + selected = matches + elif line_hint is not None: + selected = [min(matches, key=lambda match: abs(match.line - line_hint))] + else: + selected = [matches[occurrence - 1 if occurrence else 0]] + if expected_replacements is not None and len(selected) != expected_replacements: + return ( + f"Error: expected {expected_replacements} replacements but " + f"would make {len(selected)}." + ) new_content = content for match in reversed(selected): replacement = _preserve_quote_style(norm_old, match.text, norm_new) diff --git a/nanobot/agent/tools/notebook.py b/nanobot/agent/tools/notebook.py deleted file mode 100644 index 0980b7c93..000000000 --- a/nanobot/agent/tools/notebook.py +++ /dev/null @@ -1,162 +0,0 @@ -"""NotebookEditTool — edit Jupyter .ipynb notebooks.""" - -from __future__ import annotations - -import json -import uuid -from typing import Any - -from nanobot.agent.tools.base import tool_parameters -from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema -from nanobot.agent.tools.filesystem import _FsTool - - -def _new_cell(source: str, cell_type: str = "code", generate_id: bool = False) -> dict: - cell: dict[str, Any] = { - "cell_type": cell_type, - "source": source, - "metadata": {}, - } - if cell_type == "code": - cell["outputs"] = [] - cell["execution_count"] = None - if generate_id: - cell["id"] = uuid.uuid4().hex[:8] - return cell - - -def _make_empty_notebook() -> dict: - return { - "nbformat": 4, - "nbformat_minor": 5, - "metadata": { - "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, - "language_info": {"name": "python"}, - }, - "cells": [], - } - - -@tool_parameters( - tool_parameters_schema( - path=StringSchema("Path to the .ipynb notebook file"), - cell_index=IntegerSchema(0, description="0-based index of the cell to edit", minimum=0), - new_source=StringSchema("New source content for the cell"), - cell_type=StringSchema( - "Cell type: 'code' or 'markdown' (default: code)", - enum=["code", "markdown"], - ), - edit_mode=StringSchema( - "Mode: 'replace' (default), 'insert' (after target), or 'delete'", - enum=["replace", "insert", "delete"], - ), - required=["path", "cell_index"], - ) -) -class NotebookEditTool(_FsTool): - """Edit Jupyter notebook cells: replace, insert, or delete.""" - _scopes = {"core"} - - _VALID_CELL_TYPES = frozenset({"code", "markdown"}) - _VALID_EDIT_MODES = frozenset({"replace", "insert", "delete"}) - - @property - def name(self) -> str: - return "notebook_edit" - - @property - def description(self) -> str: - return ( - "Edit a Jupyter notebook (.ipynb) cell. " - "Modes: replace (default) replaces cell content, " - "insert adds a new cell after the target index, " - "delete removes the cell at the index. " - "cell_index is 0-based." - ) - - async def execute( - self, - path: str | None = None, - cell_index: int = 0, - new_source: str = "", - cell_type: str = "code", - edit_mode: str = "replace", - **kwargs: Any, - ) -> str: - try: - if not path: - return "Error: path is required" - - if not path.endswith(".ipynb"): - return "Error: notebook_edit only works on .ipynb files. Use edit_file for other files." - - if edit_mode not in self._VALID_EDIT_MODES: - return ( - f"Error: Invalid edit_mode '{edit_mode}'. " - "Use one of: replace, insert, delete." - ) - - if cell_type not in self._VALID_CELL_TYPES: - return ( - f"Error: Invalid cell_type '{cell_type}'. " - "Use one of: code, markdown." - ) - - fp = self._resolve(path) - - # Create new notebook if file doesn't exist and mode is insert - if not fp.exists(): - if edit_mode != "insert": - return f"Error: File not found: {path}" - nb = _make_empty_notebook() - cell = _new_cell(new_source, cell_type, generate_id=True) - nb["cells"].append(cell) - fp.parent.mkdir(parents=True, exist_ok=True) - fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8") - return f"Successfully created {fp} with 1 cell" - - try: - nb = json.loads(fp.read_text(encoding="utf-8")) - except (json.JSONDecodeError, UnicodeDecodeError) as e: - return f"Error: Failed to parse notebook: {e}" - - cells = nb.get("cells", []) - nbformat_minor = nb.get("nbformat_minor", 0) - generate_id = nb.get("nbformat", 0) >= 4 and nbformat_minor >= 5 - - if edit_mode == "delete": - if cell_index < 0 or cell_index >= len(cells): - return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)" - cells.pop(cell_index) - nb["cells"] = cells - fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8") - return f"Successfully deleted cell {cell_index} from {fp}" - - if edit_mode == "insert": - insert_at = min(cell_index + 1, len(cells)) - cell = _new_cell(new_source, cell_type, generate_id=generate_id) - cells.insert(insert_at, cell) - nb["cells"] = cells - fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8") - return f"Successfully inserted cell at index {insert_at} in {fp}" - - # Default: replace - if cell_index < 0 or cell_index >= len(cells): - return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)" - cells[cell_index]["source"] = new_source - if cell_type and cells[cell_index].get("cell_type") != cell_type: - cells[cell_index]["cell_type"] = cell_type - if cell_type == "code": - cells[cell_index].setdefault("outputs", []) - cells[cell_index].setdefault("execution_count", None) - elif "outputs" in cells[cell_index]: - del cells[cell_index]["outputs"] - cells[cell_index].pop("execution_count", None) - nb["cells"] = cells - fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8") - return f"Successfully edited cell {cell_index} in {fp}" - - except PermissionError as e: - return f"Error: {e}" - except Exception as e: - return f"Error editing notebook: {e}" diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index 0252b9746..47d3e9065 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -8,6 +8,7 @@ import re import shutil import sys from contextlib import suppress +from dataclasses import dataclass from pathlib import Path from typing import Any @@ -15,8 +16,17 @@ from loguru import logger from pydantic import Field from nanobot.agent.tools.base import Tool, tool_parameters +from nanobot.agent.tools.exec_session import ( + DEFAULT_MAX_OUTPUT_CHARS, + DEFAULT_YIELD_MS, + DEFAULT_EXEC_SESSION_MANAGER, + MAX_OUTPUT_CHARS, + MAX_YIELD_MS, + clamp_session_int, + format_session_poll, +) from nanobot.agent.tools.sandbox import wrap_command -from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema +from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema from nanobot.config.paths import get_media_dir from nanobot.config.schema import Base @@ -44,10 +54,22 @@ class ExecToolConfig(Base): deny_patterns: list[str] = Field(default_factory=list) +@dataclass(slots=True) +class _PreparedCommand: + command: str + cwd: str + env: dict[str, str] + timeout: int + shell_program: str | None + login: bool + + @tool_parameters( tool_parameters_schema( command=StringSchema("The shell command to execute"), + cmd=StringSchema("Compatibility alias for command"), working_dir=StringSchema("Optional working directory for the command"), + workdir=StringSchema("Compatibility alias for working_dir"), timeout=IntegerSchema( 60, description=( @@ -57,7 +79,44 @@ class ExecToolConfig(Base): minimum=1, maximum=600, ), - required=["command"], + shell=StringSchema( + "Optional shell binary to launch. On Unix, supports sh, bash, or zsh.", + nullable=True, + ), + login=BooleanSchema( + description="Whether to run bash/zsh with login shell semantics (default true).", + default=True, + nullable=True, + ), + yield_time_ms=IntegerSchema( + description=( + "Optional milliseconds to wait before returning output. " + "When set, a still-running command returns a session_id that " + "can be polled or written to with write_stdin. Omit this field " + "to keep one-shot exec behavior." + ), + minimum=0, + maximum=MAX_YIELD_MS, + nullable=True, + ), + max_output_chars=IntegerSchema( + description=( + "Maximum output characters to return when yield_time_ms is used " + "(default 10000, max 50000)." + ), + minimum=1000, + maximum=MAX_OUTPUT_CHARS, + nullable=True, + ), + max_output_tokens=IntegerSchema( + description=( + "Compatibility alias for max_output_chars. The current runtime " + "uses a character budget." + ), + minimum=1000, + maximum=MAX_OUTPUT_CHARS, + nullable=True, + ), ) ) class ExecTool(Tool): @@ -98,6 +157,7 @@ class ExecTool(Tool): sandbox: str = "", path_append: str = "", allowed_env_keys: list[str] | None = None, + session_manager: Any | None = None, ): self.timeout = timeout self.working_dir = working_dir @@ -125,6 +185,7 @@ class ExecTool(Tool): self.restrict_to_workspace = restrict_to_workspace self.path_append = path_append self.allowed_env_keys = allowed_env_keys or [] + self._session_manager = session_manager or DEFAULT_EXEC_SESSION_MANAGER @property def name(self) -> str: @@ -153,7 +214,10 @@ class ExecTool(Tool): "Prefer read_file/write_file/edit_file over cat/echo/sed, " "and grep/glob over shell find/grep. " "Use -y or --yes flags to avoid interactive prompts. " - "Output is truncated at 10 000 chars; timeout defaults to 60s." + "For long-running or interactive commands, pass yield_time_ms; " + "if the command keeps running, exec returns a session_id that can " + "be polled or written to with write_stdin. Output is truncated at " + "10 000 chars; timeout defaults to 60s." ) @property @@ -161,9 +225,111 @@ class ExecTool(Tool): return True async def execute( - self, command: str, working_dir: str | None = None, - timeout: int | None = None, **kwargs: Any, + self, command: str | None = None, cmd: str | None = None, + working_dir: str | None = None, workdir: str | None = None, + timeout: int | None = None, shell: str | None = None, + login: bool | None = None, yield_time_ms: int | None = None, + max_output_chars: int | None = None, + max_output_tokens: int | None = None, + **kwargs: Any, ) -> str: + command = command or cmd + working_dir = working_dir or workdir + if not command: + return "Error: Missing command. Provide command or cmd." + if max_output_chars is None: + max_output_chars = max_output_tokens + + prepared = self._prepare_command(command, working_dir, timeout, shell, login) + if isinstance(prepared, str): + return prepared + + if yield_time_ms is not None: + return await self._execute_session(prepared, yield_time_ms, max_output_chars) + + try: + process = await self._spawn( + prepared.command, + prepared.cwd, + prepared.env, + prepared.shell_program, + prepared.login, + ) + + try: + stdout, stderr = await asyncio.wait_for( + process.communicate(), + timeout=prepared.timeout, + ) + except asyncio.TimeoutError: + await self._kill_process(process) + return f"Error: Command timed out after {prepared.timeout} seconds" + except asyncio.CancelledError: + await self._kill_process(process) + raise + + output_parts = [] + + if stdout: + output_parts.append(stdout.decode("utf-8", errors="replace")) + + if stderr: + stderr_text = stderr.decode("utf-8", errors="replace") + if stderr_text.strip(): + output_parts.append(f"STDERR:\n{stderr_text}") + + output_parts.append(f"\nExit code: {process.returncode}") + + result = "\n".join(output_parts) if output_parts else "(no output)" + + max_len = clamp_session_int(max_output_chars, self._MAX_OUTPUT, 1000, MAX_OUTPUT_CHARS) + if len(result) > max_len: + half = max_len // 2 + result = ( + result[:half] + + f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n" + + result[-half:] + ) + + return result + + except Exception as e: + return f"Error executing command: {str(e)}" + + async def _execute_session( + self, + prepared: _PreparedCommand, + yield_time_ms: int | None, + max_output_chars: int | None, + ) -> str: + try: + session_id, poll = await self._session_manager.start( + command=prepared.command, + cwd=prepared.cwd, + env=prepared.env, + timeout=prepared.timeout, + shell_program=prepared.shell_program, + login=prepared.login, + yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS), + max_output_chars=clamp_session_int( + max_output_chars, + DEFAULT_MAX_OUTPUT_CHARS, + 1000, + MAX_OUTPUT_CHARS, + ), + ) + return format_session_poll(session_id, poll) + except Exception as exc: + return f"Error executing command: {exc}" + + def _prepare_command( + self, + command: str, + working_dir: str | None = None, + timeout: int | None = None, + shell: str | None = None, + login: bool | None = None, + ) -> _PreparedCommand | str: cwd = working_dir or self.working_dir or os.getcwd() # Prevent an LLM-supplied working_dir from escaping the configured @@ -211,52 +377,24 @@ class ExecTool(Tool): env["NANOBOT_PATH_APPEND"] = self.path_append command = f'export PATH="$PATH{os.pathsep}$NANOBOT_PATH_APPEND"; {command}' - try: - process = await self._spawn(command, cwd, env) + shell_program, shell_error = self._resolve_shell(shell) + if shell_error: + return shell_error - try: - stdout, stderr = await asyncio.wait_for( - process.communicate(), - timeout=effective_timeout, - ) - except asyncio.TimeoutError: - await self._kill_process(process) - return f"Error: Command timed out after {effective_timeout} seconds" - except asyncio.CancelledError: - await self._kill_process(process) - raise - - output_parts = [] - - if stdout: - output_parts.append(stdout.decode("utf-8", errors="replace")) - - if stderr: - stderr_text = stderr.decode("utf-8", errors="replace") - if stderr_text.strip(): - output_parts.append(f"STDERR:\n{stderr_text}") - - output_parts.append(f"\nExit code: {process.returncode}") - - result = "\n".join(output_parts) if output_parts else "(no output)" - - max_len = self._MAX_OUTPUT - if len(result) > max_len: - half = max_len // 2 - result = ( - result[:half] - + f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n" - + result[-half:] - ) - - return result - - except Exception as e: - return f"Error executing command: {str(e)}" + return _PreparedCommand( + command=command, + cwd=cwd, + env=env, + timeout=effective_timeout, + shell_program=shell_program, + login=True if login is None else login, + ) @staticmethod async def _spawn( command: str, cwd: str, env: dict[str, str], + shell_program: str | None = None, + login: bool = True, ) -> asyncio.subprocess.Process: """Launch *command* in a platform-appropriate shell.""" if _IS_WINDOWS: @@ -272,9 +410,13 @@ class ExecTool(Tool): cwd=cwd, env=env, ) - bash = shutil.which("bash") or "/bin/bash" + shell_program = shell_program or shutil.which("bash") or "/bin/bash" + args = [shell_program] + if login and Path(shell_program).name in {"bash", "zsh"}: + args.append("-l") + args.extend(["-c", command]) return await asyncio.create_subprocess_exec( - bash, "-l", "-c", command, + *args, stdin=asyncio.subprocess.DEVNULL, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, @@ -282,6 +424,31 @@ class ExecTool(Tool): env=env, ) + @staticmethod + def _resolve_shell(shell: str | None) -> tuple[str | None, str | None]: + if not shell: + return None, None + if _IS_WINDOWS: + return None, "Error: shell parameter is not supported on Windows" + if "\0" in shell or "\n" in shell or "\r" in shell: + return None, "Error: shell contains invalid characters" + allowed = {"sh", "bash", "zsh"} + path = Path(shell).expanduser() + if path.is_absolute(): + if path.name not in allowed: + return None, f"Error: unsupported shell {shell!r}. Allowed: bash, sh, zsh" + if not path.is_file() or not os.access(path, os.X_OK): + return None, f"Error: shell is not executable: {shell}" + return str(path), None + if "/" in shell or "\\" in shell: + return None, "Error: shell must be a shell name or absolute path" + if shell not in allowed: + return None, f"Error: unsupported shell {shell!r}. Allowed: bash, sh, zsh" + resolved = shutil.which(shell) + if not resolved: + return None, f"Error: shell not found: {shell}" + return resolved, None + @staticmethod async def _kill_process(process: asyncio.subprocess.Process) -> None: """Kill a subprocess and reap it to prevent zombies.""" diff --git a/nanobot/utils/file_edit_events.py b/nanobot/utils/file_edit_events.py index b5d2f6d73..056041f4b 100644 --- a/nanobot/utils/file_edit_events.py +++ b/nanobot/utils/file_edit_events.py @@ -3,7 +3,6 @@ from __future__ import annotations import difflib -import json import re import time from dataclasses import dataclass, field @@ -11,7 +10,7 @@ from pathlib import Path from typing import Any, Awaitable, Callable -TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file", "notebook_edit"}) +TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file"}) _MAX_SNAPSHOT_BYTES = 2 * 1024 * 1024 _LIVE_EMIT_INTERVAL_S = 0.18 _LIVE_EMIT_LINE_STEP = 24 @@ -704,77 +703,4 @@ def _predict_after_text( return before_text.replace(old_text, new_text) return before_text.replace(old_text, new_text, 1) return None - if tool_name == "notebook_edit": - return _predict_notebook_after_text(params, before_text) return None - - -def _predict_notebook_after_text(params: dict[str, Any], before_text: str) -> str | None: - try: - nb = json.loads(before_text) if before_text.strip() else _empty_notebook() - except Exception: - return None - cells = nb.get("cells") - if not isinstance(cells, list): - return None - try: - cell_index = int(params.get("cell_index", 0)) - except (TypeError, ValueError): - return None - new_source = params.get("new_source") - source = new_source if isinstance(new_source, str) else "" - cell_type = ( - params.get("cell_type") if params.get("cell_type") in ("code", "markdown") else "code" - ) - mode = ( - params.get("edit_mode") - if params.get("edit_mode") in ("replace", "insert", "delete") - else "replace" - ) - if mode == "delete": - if 0 <= cell_index < len(cells): - cells.pop(cell_index) - else: - return None - elif mode == "insert": - insert_at = min(max(cell_index + 1, 0), len(cells)) - cells.insert(insert_at, _new_notebook_cell(source, str(cell_type))) - else: - if not (0 <= cell_index < len(cells)): - return None - cell = cells[cell_index] - if not isinstance(cell, dict): - return None - cell["source"] = source - cell["cell_type"] = cell_type - if cell_type == "code": - cell.setdefault("outputs", []) - cell.setdefault("execution_count", None) - else: - cell.pop("outputs", None) - cell.pop("execution_count", None) - nb["cells"] = cells - try: - return json.dumps(nb, indent=1, ensure_ascii=False) - except Exception: - return None - - -def _empty_notebook() -> dict[str, Any]: - return { - "nbformat": 4, - "nbformat_minor": 5, - "metadata": { - "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, - "language_info": {"name": "python"}, - }, - "cells": [], - } - - -def _new_notebook_cell(source: str, cell_type: str) -> dict[str, Any]: - cell: dict[str, Any] = {"cell_type": cell_type, "source": source, "metadata": {}} - if cell_type == "code": - cell["outputs"] = [] - cell["execution_count"] = None - return cell diff --git a/tests/tools/test_apply_patch_tool.py b/tests/tools/test_apply_patch_tool.py new file mode 100644 index 000000000..ea98794b3 --- /dev/null +++ b/tests/tools/test_apply_patch_tool.py @@ -0,0 +1,238 @@ +from __future__ import annotations + +import asyncio + +from nanobot.agent.tools.apply_patch import ApplyPatchTool + + +def test_apply_patch_adds_file(tmp_path): + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Add File: hello.txt ++Hello ++world +*** End Patch +""" + )) + + assert "Patch applied" in result + assert (tmp_path / "hello.txt").read_text() == "Hello\nworld\n" + + +def test_apply_patch_updates_multiple_hunks(tmp_path): + target = tmp_path / "multi.txt" + target.write_text("line1\nline2\nline3\nline4\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Update File: multi.txt +@@ +-line2 ++changed2 +@@ +-line4 ++changed4 +*** End Patch +""" + )) + + assert "update multi.txt" in result + assert target.read_text() == "line1\nchanged2\nline3\nchanged4\n" + + +def test_apply_patch_ignores_standard_no_newline_marker(tmp_path): + target = tmp_path / "plain.txt" + target.write_text("before") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Update File: plain.txt +@@ -1,1 +1,1 @@ +-before +\\ No newline at end of file ++after +\\ No newline at end of file +*** End Patch +""" + )) + + assert "update plain.txt" in result + assert target.read_text() == "after\n" + + +def test_apply_patch_rejects_empty_hunk(tmp_path): + target = tmp_path / "plain.txt" + target.write_text("before\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Update File: plain.txt +@@ +*** End Patch +""" + )) + + assert "hunk is empty" in result + assert target.read_text() == "before\n" + + +def test_apply_patch_uses_unified_diff_line_hint(tmp_path): + target = tmp_path / "repeated.txt" + target.write_text("target\nmiddle\ntarget\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Update File: repeated.txt +@@ -3,1 +3,1 @@ +-target ++changed +*** End Patch +""" + )) + + assert "update repeated.txt" in result + assert target.read_text() == "target\nmiddle\nchanged\n" + + +def test_apply_patch_line_hint_does_not_fallback_to_earlier_match(tmp_path): + target = tmp_path / "repeated.txt" + target.write_text("target\nmiddle\nother\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Update File: repeated.txt +@@ -3,1 +3,1 @@ +-target ++changed +*** End Patch +""" + )) + + assert "hunk does not match repeated.txt" in result + assert target.read_text() == "target\nmiddle\nother\n" + + +def test_apply_patch_mismatch_reports_best_match(tmp_path): + target = tmp_path / "near.txt" + target.write_text("alpha\nbeta\ngamma\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Update File: near.txt +@@ -2,1 +2,1 @@ +-betx ++delta +*** End Patch +""" + )) + + assert "hunk does not match near.txt" in result + assert "Best match" in result + assert "line 2" in result + assert target.read_text() == "alpha\nbeta\ngamma\n" + + +def test_apply_patch_moves_and_updates_file(tmp_path): + source = tmp_path / "old" / "name.txt" + source.parent.mkdir() + source.write_text("old content\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Update File: old/name.txt +*** Move to: renamed/dir/name.txt +@@ +-old content ++new content +*** End Patch +""" + )) + + assert "move old/name.txt -> renamed/dir/name.txt" in result + assert not source.exists() + assert (tmp_path / "renamed" / "dir" / "name.txt").read_text() == "new content\n" + + +def test_apply_patch_deletes_file(tmp_path): + target = tmp_path / "obsolete.txt" + target.write_text("remove me\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Delete File: obsolete.txt +*** End Patch +""" + )) + + assert "delete obsolete.txt" in result + assert not target.exists() + + +def test_apply_patch_rejects_absolute_and_parent_paths(tmp_path): + tool = ApplyPatchTool(workspace=tmp_path) + + absolute = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Add File: /tmp/owned.txt ++nope +*** End Patch +""" + )) + parent = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Add File: ../owned.txt ++nope +*** End Patch +""" + )) + + assert "must be relative" in absolute + assert "must not contain '..'" in parent + assert not (tmp_path.parent / "owned.txt").exists() + + +def test_apply_patch_does_not_overwrite_existing_file_with_add(tmp_path): + target = tmp_path / "existing.txt" + target.write_text("keep me\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Add File: existing.txt ++replace me +*** End Patch +""" + )) + + assert "file to add already exists" in result + assert target.read_text() == "keep me\n" + + +def test_apply_patch_rolls_back_when_late_operation_fails(tmp_path): + first = tmp_path / "first.txt" + first.write_text("before\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Update File: first.txt +@@ +-before ++after +*** Delete File: missing.txt +*** End Patch +""" + )) + + assert "file to delete does not exist" in result + assert first.read_text() == "before\n" diff --git a/tests/tools/test_edit_enhancements.py b/tests/tools/test_edit_enhancements.py index 1f22c963b..7202fc37b 100644 --- a/tests/tools/test_edit_enhancements.py +++ b/tests/tools/test_edit_enhancements.py @@ -1,5 +1,5 @@ """Tests for EditFileTool enhancements: read-before-edit tracking, path suggestions, -.ipynb detection, and create-file semantics.""" +notebook JSON editing, and create-file semantics.""" import pytest @@ -108,22 +108,27 @@ class TestEditCreateFile: # --------------------------------------------------------------------------- -# .ipynb detection +# .ipynb editing # --------------------------------------------------------------------------- -class TestEditIpynbDetection: - """edit_file should refuse .ipynb and suggest notebook_edit.""" +class TestEditIpynbFiles: + """edit_file edits notebooks as normal JSON files.""" @pytest.fixture() def tool(self, tmp_path): return EditFileTool(workspace=tmp_path) @pytest.mark.asyncio - async def test_ipynb_rejected_with_suggestion(self, tool, tmp_path): + async def test_ipynb_can_be_edited_as_json(self, tool, tmp_path): f = tmp_path / "analysis.ipynb" f.write_text('{"cells": []}', encoding="utf-8") - result = await tool.execute(path=str(f), old_text="x", new_text="y") - assert "notebook" in result.lower() + result = await tool.execute( + path=str(f), + old_text='"cells": []', + new_text='"cells": [{"cell_type": "markdown", "source": "hi"}]', + ) + assert "Successfully edited" in result + assert '"source": "hi"' in f.read_text(encoding="utf-8") # --------------------------------------------------------------------------- diff --git a/tests/tools/test_exec_session_tools.py b/tests/tools/test_exec_session_tools.py new file mode 100644 index 000000000..945473926 --- /dev/null +++ b/tests/tools/test_exec_session_tools.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +import asyncio +import re +import shlex +import subprocess +import sys + +from nanobot.agent.tools.shell import ExecTool +from nanobot.agent.tools.exec_session import ExecSessionManager, WriteStdinTool + + +def _python_command(code: str) -> str: + if sys.platform == "win32": + return f"{subprocess.list2cmdline([sys.executable])} -u -c {subprocess.list2cmdline([code])}" + return f"{shlex.quote(sys.executable)} -u -c {shlex.quote(code)}" + + +def _session_id(output: str) -> str: + match = re.search(r"session_id:\s*([0-9a-f]+)", output) + assert match, output + return match.group(1) + + +def test_exec_keeps_one_shot_behavior_without_yield_time_ms(tmp_path): + async def run() -> str: + tool = ExecTool(working_dir=str(tmp_path), timeout=5) + return await tool.execute(command="echo hello") + + result = asyncio.run(run()) + + assert "hello" in result + assert "Exit code: 0" in result + assert "session_id:" not in result + + +def test_exec_accepts_command_aliases(tmp_path): + async def run() -> str: + tool = ExecTool(working_dir="/") + return await tool.execute(cmd="pwd", workdir=str(tmp_path)) + + result = asyncio.run(run()) + + assert str(tmp_path) in result + assert "Exit code: 0" in result + + +def test_exec_returns_completed_session_output_when_yield_time_ms_is_used(tmp_path): + async def run() -> str: + manager = ExecSessionManager() + tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + + result = await tool.execute(command="echo hello", yield_time_ms=1000) + if "session_id:" in result: + sid = _session_id(result) + result += "\n" + await stdin_tool.execute( + session_id=sid, + chars="", + yield_time_ms=1000, + ) + return result + + result = asyncio.run(run()) + + assert "hello" in result + assert "Exit code: 0" in result + assert "session_id:" not in result + + +def test_exec_session_accepts_max_output_tokens_alias(tmp_path): + async def run() -> str: + manager = ExecSessionManager() + tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + command = _python_command("print('A' * 2000)") + return await tool.execute( + command=command, + yield_time_ms=1000, + max_output_tokens=1000, + ) + + result = asyncio.run(run()) + + assert "chars truncated" in result + assert "Exit code: 0" in result + + +def test_exec_one_shot_accepts_max_output_tokens_alias(tmp_path): + async def run() -> str: + tool = ExecTool(working_dir=str(tmp_path), timeout=5) + command = _python_command("print('A' * 2000)") + return await tool.execute(command=command, max_output_tokens=1000) + + result = asyncio.run(run()) + + assert "chars truncated" in result + assert "Exit code: 0" in result + + +def test_exec_accepts_supported_shell_parameter(tmp_path): + async def run() -> str: + tool = ExecTool(working_dir=str(tmp_path), timeout=5) + return await tool.execute(command="echo shell-ok", shell="sh", login=False) + + if sys.platform == "win32": + return + result = asyncio.run(run()) + + assert "shell-ok" in result + assert "Exit code: 0" in result + + +def test_exec_rejects_unsupported_shell(tmp_path): + async def run() -> str: + tool = ExecTool(working_dir=str(tmp_path), timeout=5) + return await tool.execute(command="echo no", shell="python") + + if sys.platform == "win32": + return + result = asyncio.run(run()) + + assert "unsupported shell" in result + + +def test_exec_can_continue_with_stdin(tmp_path): + async def run() -> tuple[str, str]: + manager = ExecSessionManager() + exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + command = _python_command( + "import sys; print('ready', flush=True); " + "line=sys.stdin.readline(); print('got:' + line.strip(), flush=True)" + ) + + initial = await exec_tool.execute(command=command, yield_time_ms=500) + sid = _session_id(initial) + result = await stdin_tool.execute(session_id=sid, chars="ping\n", yield_time_ms=1000) + return initial, result + + initial, result = asyncio.run(run()) + assert "ready" in initial + assert "Process running" in initial + assert "got:ping" in result + assert "Exit code: 0" in result + + +def test_write_stdin_can_close_stdin(tmp_path): + async def run() -> tuple[str, str]: + manager = ExecSessionManager() + exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + command = _python_command( + "import sys; print('ready', flush=True); " + "data=sys.stdin.read(); print('got:' + data, flush=True)" + ) + + initial = await exec_tool.execute(command=command, yield_time_ms=500) + sid = _session_id(initial) + result = await stdin_tool.execute( + session_id=sid, + chars="payload", + close_stdin=True, + yield_time_ms=1000, + ) + return initial, result + + initial, result = asyncio.run(run()) + assert "ready" in initial + assert "got:payload" in result + assert "Stdin closed." in result + assert "Exit code: 0" in result + + +def test_write_stdin_can_terminate_session(tmp_path): + async def run() -> tuple[str, str]: + manager = ExecSessionManager() + exec_tool = ExecTool(working_dir=str(tmp_path), timeout=30, session_manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + command = _python_command( + "import time; print('ready', flush=True); time.sleep(30)" + ) + + initial = await exec_tool.execute(command=command, yield_time_ms=500) + sid = _session_id(initial) + result = await stdin_tool.execute( + session_id=sid, + terminate=True, + yield_time_ms=0, + ) + return initial, result + + initial, result = asyncio.run(run()) + assert "ready" in initial + assert "Session terminated." in result + assert "Exit code:" in result + + +def test_write_stdin_accepts_max_output_tokens_alias(tmp_path): + async def run() -> tuple[str, str, str]: + manager = ExecSessionManager() + exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + command = _python_command( + "import time; print('A' * 2000, flush=True); time.sleep(5)" + ) + + initial = await exec_tool.execute(command=command, yield_time_ms=0) + sid = _session_id(initial) + poll = await stdin_tool.execute( + session_id=sid, + yield_time_ms=500, + max_output_tokens=1000, + ) + cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0) + return initial, poll, cleanup + + initial, poll, cleanup = asyncio.run(run()) + assert "Process running" in initial + assert "chars truncated" in poll + assert "Session terminated." in cleanup + + +def test_exec_session_mode_reuses_exec_safety_guard(tmp_path): + manager = ExecSessionManager() + tool = ExecTool( + working_dir=str(tmp_path), + deny_patterns=[r"echo\s+blocked"], + session_manager=manager, + ) + + result = asyncio.run(tool.execute(command="echo blocked", yield_time_ms=0)) + + assert "blocked by deny pattern" in result + + +def test_write_stdin_reports_missing_session(tmp_path): + manager = ExecSessionManager() + tool = WriteStdinTool(manager=manager) + + result = asyncio.run(tool.execute(session_id="missing", chars="")) + + assert "exec session not found" in result diff --git a/tests/tools/test_file_edit_coding_enhancements.py b/tests/tools/test_file_edit_coding_enhancements.py new file mode 100644 index 000000000..d361d88ae --- /dev/null +++ b/tests/tools/test_file_edit_coding_enhancements.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +import asyncio + +from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool + + +def test_read_file_force_bypasses_dedup(tmp_path): + target = tmp_path / "data.txt" + target.write_text("alpha\n") + tool = ReadFileTool(workspace=tmp_path) + + first = asyncio.run(tool.execute(path=str(target))) + second = asyncio.run(tool.execute(path=str(target))) + forced = asyncio.run(tool.execute(path=str(target), force=True)) + + assert "alpha" in first + assert "unchanged" in second.lower() + assert "alpha" in forced + assert "unchanged" not in forced.lower() + + +def test_edit_file_can_select_occurrence(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("one\nsame\ntwo\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + occurrence=2, + )) + + assert "Successfully edited" in result + assert target.read_text() == "one\nsame\ntwo\nchanged\n" + + +def test_edit_file_expected_replacements_guards_replace_all(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + replace_all=True, + expected_replacements=1, + )) + + assert "expected 1 replacements but would make 2" in result + assert target.read_text() == "same\nsame\n" + + +def test_edit_file_expected_replacements_allows_replace_all_when_count_matches(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + replace_all=True, + expected_replacements=2, + )) + + assert "Successfully edited" in result + assert target.read_text() == "changed\nchanged\n" + + +def test_edit_file_can_select_nearest_line_hint(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("one\nsame\ntwo\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + line_hint=4, + )) + + assert "Successfully edited" in result + assert target.read_text() == "one\nsame\ntwo\nchanged\n" + + +def test_edit_file_can_edit_ipynb_as_json(tmp_path): + target = tmp_path / "analysis.ipynb" + target.write_text('{"cells": []}') + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text='"cells": []', + new_text='"cells": [{"cell_type": "markdown", "source": "hi"}]', + )) + + assert "Successfully edited" in result + assert '"source": "hi"' in target.read_text() + + +def test_edit_file_multiple_match_hint_mentions_occurrence(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + )) + + assert "old_text appears 2 times" in result + assert "occurrence" in result + assert target.read_text() == "same\nsame\n" + + +def test_edit_file_rejects_ambiguous_line_hint(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\nmiddle\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + line_hint=2, + )) + + assert "line_hint 2 is ambiguous" in result + assert target.read_text() == "same\nmiddle\nsame\n" + + +def test_edit_file_rejects_occurrence_with_replace_all(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + occurrence=1, + replace_all=True, + )) + + assert "occurrence cannot be used with replace_all" in result + assert target.read_text() == "same\nsame\n" + + +def test_edit_file_rejects_line_hint_with_replace_all(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + line_hint=1, + replace_all=True, + )) + + assert "line_hint cannot be used with replace_all" in result + assert target.read_text() == "same\nsame\n" + + +def test_edit_file_rejects_line_hint_with_occurrence(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + occurrence=1, + line_hint=1, + )) + + assert "line_hint cannot be used with occurrence" in result + assert target.read_text() == "same\nsame\n" + + +def test_edit_file_rejects_zero_occurrence(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + occurrence=0, + )) + + assert "occurrence must be >= 1" in result + assert target.read_text() == "same\n" + + +def test_edit_file_rejects_zero_line_hint(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + line_hint=0, + )) + + assert "line_hint must be >= 1" in result + assert target.read_text() == "same\n" diff --git a/tests/tools/test_notebook_tool.py b/tests/tools/test_notebook_tool.py deleted file mode 100644 index 232f13c4b..000000000 --- a/tests/tools/test_notebook_tool.py +++ /dev/null @@ -1,147 +0,0 @@ -"""Tests for NotebookEditTool — Jupyter .ipynb editing.""" - -import json - -import pytest - -from nanobot.agent.tools.notebook import NotebookEditTool - - -def _make_notebook(cells: list[dict] | None = None, nbformat: int = 4, nbformat_minor: int = 5) -> dict: - """Build a minimal valid .ipynb structure.""" - return { - "nbformat": nbformat, - "nbformat_minor": nbformat_minor, - "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}}, - "cells": cells or [], - } - - -def _code_cell(source: str, cell_id: str | None = None) -> dict: - cell = {"cell_type": "code", "source": source, "metadata": {}, "outputs": [], "execution_count": None} - if cell_id: - cell["id"] = cell_id - return cell - - -def _md_cell(source: str, cell_id: str | None = None) -> dict: - cell = {"cell_type": "markdown", "source": source, "metadata": {}} - if cell_id: - cell["id"] = cell_id - return cell - - -def _write_nb(tmp_path, name: str, nb: dict) -> str: - p = tmp_path / name - p.write_text(json.dumps(nb), encoding="utf-8") - return str(p) - - -class TestNotebookEdit: - - @pytest.fixture() - def tool(self, tmp_path): - return NotebookEditTool(workspace=tmp_path) - - @pytest.mark.asyncio - async def test_replace_cell_content(self, tool, tmp_path): - nb = _make_notebook([_code_cell("print('hello')"), _code_cell("x = 1")]) - path = _write_nb(tmp_path, "test.ipynb", nb) - result = await tool.execute(path=path, cell_index=0, new_source="print('world')") - assert "Successfully" in result - saved = json.loads((tmp_path / "test.ipynb").read_text()) - assert saved["cells"][0]["source"] == "print('world')" - assert saved["cells"][1]["source"] == "x = 1" - - @pytest.mark.asyncio - async def test_insert_cell_after_target(self, tool, tmp_path): - nb = _make_notebook([_code_cell("cell 0"), _code_cell("cell 1")]) - path = _write_nb(tmp_path, "test.ipynb", nb) - result = await tool.execute(path=path, cell_index=0, new_source="inserted", edit_mode="insert") - assert "Successfully" in result - saved = json.loads((tmp_path / "test.ipynb").read_text()) - assert len(saved["cells"]) == 3 - assert saved["cells"][0]["source"] == "cell 0" - assert saved["cells"][1]["source"] == "inserted" - assert saved["cells"][2]["source"] == "cell 1" - - @pytest.mark.asyncio - async def test_delete_cell(self, tool, tmp_path): - nb = _make_notebook([_code_cell("A"), _code_cell("B"), _code_cell("C")]) - path = _write_nb(tmp_path, "test.ipynb", nb) - result = await tool.execute(path=path, cell_index=1, edit_mode="delete") - assert "Successfully" in result - saved = json.loads((tmp_path / "test.ipynb").read_text()) - assert len(saved["cells"]) == 2 - assert saved["cells"][0]["source"] == "A" - assert saved["cells"][1]["source"] == "C" - - @pytest.mark.asyncio - async def test_create_new_notebook_from_scratch(self, tool, tmp_path): - path = str(tmp_path / "new.ipynb") - result = await tool.execute(path=path, cell_index=0, new_source="# Hello", edit_mode="insert", cell_type="markdown") - assert "Successfully" in result or "created" in result.lower() - saved = json.loads((tmp_path / "new.ipynb").read_text()) - assert saved["nbformat"] == 4 - assert len(saved["cells"]) == 1 - assert saved["cells"][0]["cell_type"] == "markdown" - assert saved["cells"][0]["source"] == "# Hello" - - @pytest.mark.asyncio - async def test_invalid_cell_index_error(self, tool, tmp_path): - nb = _make_notebook([_code_cell("only cell")]) - path = _write_nb(tmp_path, "test.ipynb", nb) - result = await tool.execute(path=path, cell_index=5, new_source="x") - assert "Error" in result - - @pytest.mark.asyncio - async def test_non_ipynb_rejected(self, tool, tmp_path): - f = tmp_path / "script.py" - f.write_text("pass") - result = await tool.execute(path=str(f), cell_index=0, new_source="x") - assert "Error" in result - assert ".ipynb" in result - - @pytest.mark.asyncio - async def test_preserves_metadata_and_outputs(self, tool, tmp_path): - cell = _code_cell("old") - cell["outputs"] = [{"output_type": "stream", "text": "hello\n"}] - cell["execution_count"] = 42 - nb = _make_notebook([cell]) - path = _write_nb(tmp_path, "test.ipynb", nb) - await tool.execute(path=path, cell_index=0, new_source="new") - saved = json.loads((tmp_path / "test.ipynb").read_text()) - assert saved["metadata"]["kernelspec"]["language"] == "python" - - @pytest.mark.asyncio - async def test_nbformat_45_generates_cell_id(self, tool, tmp_path): - nb = _make_notebook([], nbformat_minor=5) - path = _write_nb(tmp_path, "test.ipynb", nb) - await tool.execute(path=path, cell_index=0, new_source="x = 1", edit_mode="insert") - saved = json.loads((tmp_path / "test.ipynb").read_text()) - assert "id" in saved["cells"][0] - assert len(saved["cells"][0]["id"]) > 0 - - @pytest.mark.asyncio - async def test_insert_with_cell_type_markdown(self, tool, tmp_path): - nb = _make_notebook([_code_cell("code")]) - path = _write_nb(tmp_path, "test.ipynb", nb) - await tool.execute(path=path, cell_index=0, new_source="# Title", edit_mode="insert", cell_type="markdown") - saved = json.loads((tmp_path / "test.ipynb").read_text()) - assert saved["cells"][1]["cell_type"] == "markdown" - - @pytest.mark.asyncio - async def test_invalid_edit_mode_rejected(self, tool, tmp_path): - nb = _make_notebook([_code_cell("code")]) - path = _write_nb(tmp_path, "test.ipynb", nb) - result = await tool.execute(path=path, cell_index=0, new_source="x", edit_mode="replcae") - assert "Error" in result - assert "edit_mode" in result - - @pytest.mark.asyncio - async def test_invalid_cell_type_rejected(self, tool, tmp_path): - nb = _make_notebook([_code_cell("code")]) - path = _write_nb(tmp_path, "test.ipynb", nb) - result = await tool.execute(path=path, cell_index=0, new_source="x", cell_type="raw") - assert "Error" in result - assert "cell_type" in result diff --git a/tests/tools/test_tool_loader.py b/tests/tools/test_tool_loader.py index 54b4d92d5..2dfb25cb7 100644 --- a/tests/tools/test_tool_loader.py +++ b/tests/tools/test_tool_loader.py @@ -89,9 +89,11 @@ def test_discover_finds_concrete_tools(): loader = ToolLoader() discovered = loader.discover() class_names = {cls.__name__ for cls in discovered} + assert "ApplyPatchTool" in class_names assert "ExecTool" in class_names assert "MessageTool" in class_names assert "SpawnTool" in class_names + assert "WriteStdinTool" in class_names def test_discover_excludes_abstract_and_mcp(): @@ -406,7 +408,7 @@ def test_loader_registers_same_tools_as_old_hardcoded(): expected = { "read_file", "write_file", "edit_file", "list_dir", - "grep", "notebook_edit", "exec", "web_search", "web_fetch", + "grep", "exec", "web_search", "web_fetch", "message", "spawn", "cron", } actual = set(registered) From 3e154bb5cf819ca3a60ea0c0573f814ddc97da30 Mon Sep 17 00:00:00 2001 From: Xubin Ren <52506698+Re-bin@users.noreply.github.com> Date: Wed, 20 May 2026 23:42:55 +0800 Subject: [PATCH 03/54] fix(tools): align exec platform test doubles --- tests/tools/test_exec_platform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tools/test_exec_platform.py b/tests/tools/test_exec_platform.py index 69a271ec1..ffb25f985 100644 --- a/tests/tools/test_exec_platform.py +++ b/tests/tools/test_exec_platform.py @@ -162,7 +162,7 @@ class TestPathAppendPlatform: captured_cmd = None captured_env = {} - async def capture_spawn(cmd, cwd, env): + async def capture_spawn(cmd, cwd, env, shell_program=None, login=True): nonlocal captured_cmd captured_cmd = cmd captured_env.update(env) @@ -190,7 +190,7 @@ class TestPathAppendPlatform: captured_env = {} - async def capture_spawn(cmd, cwd, env): + async def capture_spawn(cmd, cwd, env, shell_program=None, login=True): captured_env.update(env) return mock_proc From 480ca28a2de9de5f36f670e1a766fc2da638fb33 Mon Sep 17 00:00:00 2001 From: Xubin Ren <52506698+Re-bin@users.noreply.github.com> Date: Thu, 21 May 2026 00:58:05 +0800 Subject: [PATCH 04/54] feat(tools): improve coding workflow recovery --- nanobot/agent/runner.py | 36 ++++-- nanobot/agent/tools/exec_session.py | 114 ++++++++++++++++- nanobot/agent/tools/search.py | 166 ++++++++++++++++++++++++- nanobot/utils/file_edit_events.py | 111 +++++++++++++++-- nanobot/utils/tool_hints.py | 4 + tests/tools/test_exec_session_tools.py | 59 ++++++++- tests/tools/test_search_tools.py | 66 +++++++++- tests/tools/test_tool_loader.py | 3 +- tests/utils/test_file_edit_events.py | 44 +++++++ 9 files changed, 571 insertions(+), 32 deletions(-) diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 0b0164fd0..19494034f 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -19,7 +19,8 @@ from nanobot.utils.file_edit_events import ( build_file_edit_end_event, build_file_edit_error_event, build_file_edit_start_event, - prepare_file_edit_tracker, + prepare_file_edit_tracker as _prepare_file_edit_tracker, + prepare_file_edit_trackers, StreamingFileEditTracker, ) from nanobot.utils.helpers import ( @@ -58,11 +59,14 @@ _SNIP_SAFETY_BUFFER = 1024 _MICROCOMPACT_KEEP_RECENT = 10 _MICROCOMPACT_MIN_CHARS = 500 _COMPACTABLE_TOOLS = frozenset({ - "read_file", "exec", "grep", - "web_search", "web_fetch", "list_dir", + "read_file", "exec", "grep", "find_files", + "web_search", "web_fetch", "list_dir", "list_exec_sessions", }) _BACKFILL_CONTENT = "[Tool result unavailable — call was interrupted or lost]" +# Backward-compatible module attribute for tests/extensions that monkeypatch +# the former single-file tracker hook. Runtime uses prepare_file_edit_trackers. +prepare_file_edit_tracker = _prepare_file_edit_tracker @dataclass(slots=True) @@ -857,8 +861,8 @@ class AgentRunner: and on_progress_accepts_file_edit_events(spec.progress_callback) ) progress_callback = spec.progress_callback if emit_file_edit_events else None - file_edit_tracker = ( - prepare_file_edit_tracker( + file_edit_trackers = ( + prepare_file_edit_trackers( call_id=tool_call.id, tool_name=tool_call.name, tool=tool, @@ -868,13 +872,13 @@ class AgentRunner: if progress_callback is not None else None ) - if file_edit_tracker is not None and progress_callback is not None: + if file_edit_trackers and progress_callback is not None: await invoke_file_edit_progress( progress_callback, [build_file_edit_start_event( file_edit_tracker, params if isinstance(params, dict) else None, - )], + ) for file_edit_tracker in file_edit_trackers], ) try: if tool is not None: @@ -884,10 +888,13 @@ class AgentRunner: except asyncio.CancelledError: raise except BaseException as exc: - if file_edit_tracker is not None and progress_callback is not None: + if file_edit_trackers and progress_callback is not None: await invoke_file_edit_progress( progress_callback, - [build_file_edit_error_event(file_edit_tracker, str(exc))], + [ + build_file_edit_error_event(file_edit_tracker, str(exc)) + for file_edit_tracker in file_edit_trackers + ], ) event = { "name": tool_call.name, @@ -910,10 +917,13 @@ class AgentRunner: return payload, event, None if isinstance(result, str) and result.startswith("Error"): - if file_edit_tracker is not None and progress_callback is not None: + if file_edit_trackers and progress_callback is not None: await invoke_file_edit_progress( progress_callback, - [build_file_edit_error_event(file_edit_tracker, result)], + [ + build_file_edit_error_event(file_edit_tracker, result) + for file_edit_tracker in file_edit_trackers + ], ) event = { "name": tool_call.name, @@ -933,13 +943,13 @@ class AgentRunner: return result + hint, event, RuntimeError(result) return result + hint, event, None - if file_edit_tracker is not None and progress_callback is not None: + if file_edit_trackers and progress_callback is not None: await invoke_file_edit_progress( progress_callback, [build_file_edit_end_event( file_edit_tracker, params if isinstance(params, dict) else None, - )], + ) for file_edit_tracker in file_edit_trackers], ) detail = "" if result is None else str(result) diff --git a/nanobot/agent/tools/exec_session.py b/nanobot/agent/tools/exec_session.py index 202fbc640..34667aeaa 100644 --- a/nanobot/agent/tools/exec_session.py +++ b/nanobot/agent/tools/exec_session.py @@ -25,22 +25,39 @@ class _SessionPoll: output: str done: bool exit_code: int | None + elapsed_s: float = 0.0 timed_out: bool = False terminated: bool = False stdin_closed: bool = False truncated_chars: int = 0 +@dataclass(slots=True) +class ExecSessionInfo: + session_id: str + command: str + cwd: str + elapsed_s: float + idle_s: float + remaining_s: float + returncode: int | None + + class _ExecSession: def __init__( self, *, session_id: str, process: asyncio.subprocess.Process, + command: str, + cwd: str, timeout: int, ) -> None: self.session_id = session_id self.process = process + self.command = command + self.cwd = cwd + self.started_at = time.monotonic() self.deadline = time.monotonic() + timeout self.last_access = time.monotonic() self._chunks: list[str] = [] @@ -122,6 +139,7 @@ class _ExecSession: output=output, done=self.process.returncode is not None, exit_code=self.process.returncode, + elapsed_s=max(0.0, time.monotonic() - self.started_at), timed_out=self._timed_out, terminated=terminated, stdin_closed=stdin_closed, @@ -161,7 +179,13 @@ class ExecSessionManager: raise RuntimeError(f"maximum exec sessions reached ({self.max_sessions})") process = await self._spawn(command, cwd, env, shell_program, login) session_id = uuid.uuid4().hex[:12] - session = _ExecSession(session_id=session_id, process=process, timeout=timeout) + session = _ExecSession( + session_id=session_id, + process=process, + command=command, + cwd=cwd, + timeout=timeout, + ) self._sessions[session_id] = session poll = await session.poll(yield_time_ms, max_output_chars) @@ -186,7 +210,7 @@ class ExecSessionManager: if session is None: raise KeyError(session_id) - if chars is not None: + if chars: error = await session.write(chars) if error: raise RuntimeError(error) @@ -209,13 +233,29 @@ class ExecSessionManager: self._sessions.pop(session_id, None) return poll + async def list(self) -> list[ExecSessionInfo]: + async with self._lock: + await self._cleanup_locked() + now = time.monotonic() + return [ + ExecSessionInfo( + session_id=session_id, + command=session.command, + cwd=session.cwd, + elapsed_s=max(0.0, now - session.started_at), + idle_s=max(0.0, now - session.last_access), + remaining_s=max(0.0, session.deadline - now), + returncode=session.process.returncode, + ) + for session_id, session in sorted(self._sessions.items()) + ] + async def _cleanup_locked(self) -> None: now = time.monotonic() stale = [ session_id for session_id, session in self._sessions.items() - if session.process.returncode is not None - or now - session.last_access > self.idle_timeout + if now - session.last_access > self.idle_timeout ] for session_id in stale: session = self._sessions.pop(session_id) @@ -291,6 +331,7 @@ def format_session_poll(session_id: str, poll: _SessionPoll) -> str: parts.append(f"Exit code: {poll.exit_code}") else: parts.append(f"Process running. session_id: {session_id}") + parts.append(f"Elapsed: {poll.elapsed_s:.1f}s") return "\n".join(parts) if parts else "(no output yet)" @@ -407,3 +448,68 @@ class WriteStdinTool(Tool): return f"Error: exec session not found: {session_id}" except Exception as exc: return f"Error writing to exec session: {exc}" + + +@tool_parameters(tool_parameters_schema()) +class ListExecSessionsTool(Tool): + """List active exec sessions.""" + + _scopes = {"core", "subagent"} + config_key = "exec" + + @classmethod + def config_cls(cls): + from nanobot.agent.tools.shell import ExecToolConfig + + return ExecToolConfig + + @classmethod + def enabled(cls, ctx: Any) -> bool: + return ctx.config.exec.enable + + def __init__( + self, + *, + manager: ExecSessionManager | None = None, + ) -> None: + self._manager = manager or DEFAULT_EXEC_SESSION_MANAGER + + @classmethod + def create(cls, ctx: Any) -> Tool: + return cls() + + @property + def name(self) -> str: + return "list_exec_sessions" + + @property + def description(self) -> str: + return ( + "List active long-running exec sessions, including session_id, cwd, " + "elapsed time, idle time, remaining timeout, and command preview. " + "Use this to recover a session_id before polling with write_stdin." + ) + + @property + def read_only(self) -> bool: + return True + + async def execute(self, **kwargs: Any) -> str: + try: + sessions = await self._manager.list() + if not sessions: + return "No active exec sessions." + lines = [] + for info in sessions: + command = " ".join(info.command.split()) + if len(command) > 120: + command = command[:119] + "..." + status = "exited" if info.returncode is not None else "running" + lines.append( + f"{info.session_id} | {status} | elapsed={info.elapsed_s:.1f}s " + f"| idle={info.idle_s:.1f}s | remaining={info.remaining_s:.1f}s " + f"| cwd={info.cwd} | {command}" + ) + return "\n".join(lines) + except Exception as exc: + return f"Error listing exec sessions: {exc}" diff --git a/nanobot/agent/tools/search.py b/nanobot/agent/tools/search.py index 49448030b..52c18f16b 100644 --- a/nanobot/agent/tools/search.py +++ b/nanobot/agent/tools/search.py @@ -1,4 +1,4 @@ -"""Search tools: grep.""" +"""Search tools: file discovery and grep.""" from __future__ import annotations @@ -12,6 +12,7 @@ from typing import Any, Iterable, TypeVar from nanobot.agent.tools.filesystem import ListDirTool, _FsTool _DEFAULT_HEAD_LIMIT = 250 +_DEFAULT_FILE_HEAD_LIMIT = 200 T = TypeVar("T") _TYPE_GLOB_MAP = { "py": ("*.py", "*.pyi"), @@ -88,6 +89,14 @@ def _matches_type(name: str, file_type: str | None) -> bool: return any(fnmatch.fnmatch(name.lower(), pattern.lower()) for pattern in patterns) +def _matches_query(rel_path: str, query: str | None) -> bool: + if not query: + return True + haystack = rel_path.lower() + terms = [part for part in query.lower().split() if part] + return all(term in haystack for term in terms) + + class _SearchTool(_FsTool): _IGNORE_DIRS = set(ListDirTool._IGNORE_DIRS) @@ -109,6 +118,161 @@ class _SearchTool(_FsTool): yield current / filename +class FindFilesTool(_SearchTool): + """Find files by path fragment, glob, or type.""" + _scopes = {"core", "subagent"} + + @property + def name(self) -> str: + return "find_files" + + @property + def description(self) -> str: + return ( + "Find files by path fragment, glob, or file type. " + "Use this before read_file when you need to locate files. " + "Returns workspace-relative paths and skips common dependency/build directories." + ) + + @property + def read_only(self) -> bool: + return True + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Directory or file to search in (default '.')", + }, + "query": { + "type": "string", + "description": ( + "Optional case-insensitive path fragment search. " + "Whitespace-separated terms must all be present." + ), + }, + "glob": { + "type": "string", + "description": "Optional file filter, e.g. '*.py' or 'tests/**/test_*.py'", + }, + "type": { + "type": "string", + "description": "Optional file type shorthand, e.g. 'py', 'ts', 'md', 'json'", + }, + "include_dirs": { + "type": "boolean", + "description": "Include matching directories as well as files (default false)", + }, + "sort": { + "type": "string", + "enum": ["path", "modified"], + "description": "Sort by path or most recently modified first (default path)", + }, + "head_limit": { + "type": "integer", + "description": "Maximum number of paths to return (default 200, 0 for all, max 1000)", + "minimum": 0, + "maximum": 1000, + }, + "offset": { + "type": "integer", + "description": "Skip the first N results before applying head_limit", + "minimum": 0, + "maximum": 100000, + }, + }, + } + + def _iter_paths(self, root: Path, *, include_dirs: bool) -> Iterable[Path]: + if root.is_file(): + yield root + return + if include_dirs: + yield root + for dirpath, dirnames, filenames in os.walk(root): + dirnames[:] = sorted(d for d in dirnames if d not in self._IGNORE_DIRS) + current = Path(dirpath) + if include_dirs and current != root: + yield current + for filename in sorted(filenames): + yield current / filename + + async def execute( + self, + path: str = ".", + query: str | None = None, + glob: str | None = None, + type: str | None = None, + include_dirs: bool = False, + sort: str = "path", + head_limit: int | None = None, + offset: int = 0, + **kwargs: Any, + ) -> str: + try: + target = self._resolve(path or ".") + if not target.exists(): + return f"Error: Path not found: {path}" + if not (target.is_dir() or target.is_file()): + return f"Error: Unsupported path: {path}" + + if sort not in {"path", "modified"}: + return "Error: sort must be 'path' or 'modified'" + + limit = ( + _DEFAULT_FILE_HEAD_LIMIT + if head_limit is None + else None if head_limit == 0 else head_limit + ) + root = target if target.is_dir() else target.parent + matches: list[tuple[str, float]] = [] + + for candidate in self._iter_paths(target, include_dirs=include_dirs): + if candidate.is_dir() and not include_dirs: + continue + rel_path = candidate.relative_to(root).as_posix() + display_path = self._display_path(candidate, root) + name = candidate.name + + if glob and not _match_glob(rel_path, name, glob): + continue + if candidate.is_file() and not _matches_type(name, type): + continue + if candidate.is_dir() and type: + continue + if not _matches_query(display_path, query): + continue + try: + mtime = candidate.stat().st_mtime + except OSError: + mtime = 0.0 + suffix = "/" if candidate.is_dir() else "" + matches.append((display_path + suffix, mtime)) + + if sort == "modified": + matches.sort(key=lambda item: (-item[1], item[0])) + else: + matches.sort(key=lambda item: item[0]) + + paths = [item[0] for item in matches] + paged, truncated = _paginate(paths, limit, offset) + if not paged: + return "No files found" + + result = "\n".join(paged) + note = _pagination_note(limit, offset, truncated) + if note: + result += "\n\n" + note + return result + except PermissionError as e: + return f"Error: {e}" + except Exception as e: + return f"Error finding files: {e}" + + class GrepTool(_SearchTool): """Search file contents using a regex-like pattern.""" _scopes = {"core", "subagent"} diff --git a/nanobot/utils/file_edit_events.py b/nanobot/utils/file_edit_events.py index 056041f4b..c11e8ae60 100644 --- a/nanobot/utils/file_edit_events.py +++ b/nanobot/utils/file_edit_events.py @@ -10,7 +10,7 @@ from pathlib import Path from typing import Any, Awaitable, Callable -TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file"}) +TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file", "apply_patch"}) _MAX_SNAPSHOT_BYTES = 2 * 1024 * 1024 _LIVE_EMIT_INTERVAL_S = 0.18 _LIVE_EMIT_LINE_STEP = 24 @@ -153,19 +153,108 @@ def prepare_file_edit_tracker( workspace: Path | None, params: dict[str, Any] | None, ) -> FileEditTracker | None: + trackers = prepare_file_edit_trackers( + call_id=call_id, + tool_name=tool_name, + tool=tool, + workspace=workspace, + params=params, + ) + return trackers[0] if trackers else None + + +def prepare_file_edit_trackers( + *, + call_id: str, + tool_name: str, + tool: Any, + workspace: Path | None, + params: dict[str, Any] | None, +) -> list[FileEditTracker]: if not is_file_edit_tool(tool_name): - return None + return [] + paths = resolve_file_edit_paths(tool_name, tool, workspace, params) + trackers: list[FileEditTracker] = [] + seen: set[Path] = set() + for path in paths: + try: + resolved = path.resolve() + except Exception: + resolved = path + if resolved in seen: + continue + seen.add(resolved) + before = read_file_snapshot(path) + trackers.append(FileEditTracker( + call_id=str(call_id or ""), + tool=tool_name, + path=path, + display_path=display_file_edit_path(path, workspace), + before=before, + )) + return trackers + + +def resolve_file_edit_paths( + tool_name: str, + tool: Any, + workspace: Path | None, + params: dict[str, Any] | None, +) -> list[Path]: + if tool_name == "apply_patch": + return _resolve_apply_patch_paths(tool, workspace, params) path = resolve_file_edit_path(tool, workspace, params) if path is None: - return None - before = read_file_snapshot(path) - return FileEditTracker( - call_id=str(call_id or ""), - tool=tool_name, - path=path, - display_path=display_file_edit_path(path, workspace), - before=before, - ) + return [] + return [path] + + +def _resolve_apply_patch_paths( + tool: Any, + workspace: Path | None, + params: dict[str, Any] | None, +) -> list[Path]: + if not isinstance(params, dict): + return [] + patch = params.get("patch") + if not isinstance(patch, str) or not patch.strip(): + return [] + try: + from nanobot.agent.tools.apply_patch import _parse_patch + + ops = _parse_patch(patch) + except Exception: + return [] + + resolved: list[Path] = [] + for op in ops: + for raw_path in (op.path, op.new_path): + if not raw_path: + continue + path = _resolve_raw_file_edit_path(tool, workspace, raw_path) + if path is not None: + resolved.append(path) + return resolved + + +def _resolve_raw_file_edit_path( + tool: Any, + workspace: Path | None, + raw_path: str, +) -> Path | None: + resolver = getattr(tool, "_resolve", None) + if callable(resolver): + try: + resolved = resolver(raw_path) + if isinstance(resolved, Path): + return resolved + if resolved: + return Path(resolved) + except Exception: + return None + if workspace is None: + return Path(raw_path).expanduser().resolve() + return (workspace / raw_path).expanduser().resolve() def build_file_edit_start_event( diff --git a/nanobot/utils/tool_hints.py b/nanobot/utils/tool_hints.py index 272a19c9a..3a6460701 100644 --- a/nanobot/utils/tool_hints.py +++ b/nanobot/utils/tool_hints.py @@ -11,8 +11,10 @@ _TOOL_FORMATS: dict[str, tuple[list[str], str, bool, bool]] = { "read_file": (["path", "file_path"], "read {}", True, False), "write_file": (["path", "file_path"], "write {}", True, False), "edit": (["file_path", "path"], "edit {}", True, False), + "find_files": (["query", "glob", "path"], "find {}", False, False), "grep": (["pattern"], 'grep "{}"', False, False), "exec": (["command"], "$ {}", False, True), + "list_exec_sessions": ([], "exec sessions", False, False), "web_search": (["query"], 'search "{}"', False, False), "web_fetch": (["url"], "fetch {}", True, False), "list_dir": (["path"], "ls {}", True, False), @@ -81,6 +83,8 @@ def _extract_arg(tc, key_args: list[str]) -> str | None: def _fmt_known(tc, fmt: tuple, max_length: int = 40) -> str: """Format a registered tool using its template.""" + if not fmt[0] and "{}" not in fmt[1]: + return fmt[1] val = _extract_arg(tc, fmt[0]) if val is None: return tc.name diff --git a/tests/tools/test_exec_session_tools.py b/tests/tools/test_exec_session_tools.py index 945473926..52f72b556 100644 --- a/tests/tools/test_exec_session_tools.py +++ b/tests/tools/test_exec_session_tools.py @@ -7,7 +7,7 @@ import subprocess import sys from nanobot.agent.tools.shell import ExecTool -from nanobot.agent.tools.exec_session import ExecSessionManager, WriteStdinTool +from nanobot.agent.tools.exec_session import ExecSessionManager, ListExecSessionsTool, WriteStdinTool def _python_command(code: str) -> str: @@ -140,8 +140,10 @@ def test_exec_can_continue_with_stdin(tmp_path): initial, result = asyncio.run(run()) assert "ready" in initial assert "Process running" in initial + assert "Elapsed:" in initial assert "got:ping" in result assert "Exit code: 0" in result + assert "Elapsed:" in result def test_write_stdin_can_close_stdin(tmp_path): @@ -220,6 +222,29 @@ def test_write_stdin_accepts_max_output_tokens_alias(tmp_path): assert "Session terminated." in cleanup +def test_write_stdin_preserves_completed_session_output_until_polled(tmp_path): + async def run() -> tuple[str, str]: + manager = ExecSessionManager() + exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + command = _python_command( + "import time; print('ready', flush=True); " + "time.sleep(1.0); print('done', flush=True)" + ) + + initial = await exec_tool.execute(command=command, yield_time_ms=300) + sid = _session_id(initial) + await asyncio.sleep(1.2) + final = await stdin_tool.execute(session_id=sid, chars="", yield_time_ms=0) + return initial, final + + initial, final = asyncio.run(run()) + + assert "ready" in initial + assert "done" in final + assert "Exit code: 0" in final + + def test_exec_session_mode_reuses_exec_safety_guard(tmp_path): manager = ExecSessionManager() tool = ExecTool( @@ -240,3 +265,35 @@ def test_write_stdin_reports_missing_session(tmp_path): result = asyncio.run(tool.execute(session_id="missing", chars="")) assert "exec session not found" in result + + +def test_list_exec_sessions_reports_running_commands(tmp_path): + async def run() -> tuple[str, str, str]: + manager = ExecSessionManager() + exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + list_tool = ListExecSessionsTool(manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + command = _python_command( + "import time; print('ready', flush=True); time.sleep(5)" + ) + + initial = await exec_tool.execute(command=command, yield_time_ms=500) + sid = _session_id(initial) + listing = await list_tool.execute() + cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0) + return sid, listing, cleanup + + sid, listing, cleanup = asyncio.run(run()) + + assert sid in listing + assert "running" in listing + assert "elapsed=" in listing + assert "remaining=" in listing + assert str(tmp_path) in listing + assert "Session terminated." in cleanup + + +def test_list_exec_sessions_reports_empty_state(): + result = asyncio.run(ListExecSessionsTool(manager=ExecSessionManager()).execute()) + + assert result == "No active exec sessions." diff --git a/tests/tools/test_search_tools.py b/tests/tools/test_search_tools.py index 0d3697044..fc7c1944a 100644 --- a/tests/tools/test_search_tools.py +++ b/tests/tools/test_search_tools.py @@ -12,7 +12,7 @@ import pytest from nanobot.agent.loop import AgentLoop from nanobot.agent.subagent import SubagentManager, SubagentStatus -from nanobot.agent.tools.search import GrepTool +from nanobot.agent.tools.search import FindFilesTool, GrepTool from nanobot.agent.tools.web import WebSearchTool from nanobot.bus.queue import MessageBus from nanobot.config.schema import WebSearchConfig @@ -33,6 +33,68 @@ async def test_web_search_tool_refreshes_dynamic_config_loader(monkeypatch) -> N assert await tool.execute("nanobot") == "duckduckgo:nanobot:3" +@pytest.mark.asyncio +async def test_find_files_filters_by_query_glob_and_type(tmp_path: Path) -> None: + (tmp_path / "src").mkdir() + (tmp_path / "src" / "settings_view.tsx").write_text("export {}\n", encoding="utf-8") + (tmp_path / "src" / "settings_api.py").write_text("pass\n", encoding="utf-8") + (tmp_path / "README.md").write_text("settings\n", encoding="utf-8") + + tool = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path) + result = await tool.execute( + path=".", + query="settings", + glob="src/**", + type="ts", + ) + + assert result.splitlines() == ["src/settings_view.tsx"] + + +@pytest.mark.asyncio +async def test_find_files_can_include_directories(tmp_path: Path) -> None: + (tmp_path / "src" / "settings").mkdir(parents=True) + (tmp_path / "src" / "settings" / "index.ts").write_text("export {}\n", encoding="utf-8") + + tool = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path) + result = await tool.execute(path="src", query="settings", include_dirs=True) + + assert "src/settings/" in result.splitlines() + assert "src/settings/index.ts" in result.splitlines() + + +@pytest.mark.asyncio +async def test_find_files_supports_modified_sort_and_pagination(tmp_path: Path) -> None: + (tmp_path / "src").mkdir() + for idx, name in enumerate(("a.py", "b.py", "c.py"), start=1): + file_path = tmp_path / "src" / name + file_path.write_text("pass\n", encoding="utf-8") + os.utime(file_path, (idx, idx)) + + tool = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path) + result = await tool.execute( + path="src", + type="py", + sort="modified", + head_limit=1, + offset=1, + ) + + assert result.splitlines()[0] == "src/b.py" + assert "pagination: limit=1, offset=1" in result + + +@pytest.mark.asyncio +async def test_find_files_rejects_paths_outside_workspace(tmp_path: Path) -> None: + outside = tmp_path.parent / "outside-find-files.txt" + outside.write_text("secret\n", encoding="utf-8") + + tool = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path) + result = await tool.execute(path=str(outside)) + + assert result.startswith("Error:") + + @pytest.mark.asyncio async def test_grep_respects_glob_filter_and_context(tmp_path: Path) -> None: (tmp_path / "src").mkdir() @@ -249,6 +311,7 @@ def test_agent_loop_registers_grep(tmp_path: Path) -> None: loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") + assert "find_files" in loop.tools.tool_names assert "grep" in loop.tools.tool_names @@ -280,6 +343,7 @@ async def test_subagent_registers_grep(tmp_path: Path) -> None: status = SubagentStatus(task_id="sub-1", label="label", task_description="search task", started_at=time.monotonic()) await mgr._run_subagent("sub-1", "search task", "label", {"channel": "cli", "chat_id": "direct"}, status) + assert "find_files" in captured["tool_names"] assert "grep" in captured["tool_names"] diff --git a/tests/tools/test_tool_loader.py b/tests/tools/test_tool_loader.py index 2dfb25cb7..62703883c 100644 --- a/tests/tools/test_tool_loader.py +++ b/tests/tools/test_tool_loader.py @@ -408,7 +408,8 @@ def test_loader_registers_same_tools_as_old_hardcoded(): expected = { "read_file", "write_file", "edit_file", "list_dir", - "grep", "exec", "web_search", "web_fetch", + "find_files", "grep", "exec", "write_stdin", "list_exec_sessions", + "web_search", "web_fetch", "message", "spawn", "cron", } actual = set(registered) diff --git a/tests/utils/test_file_edit_events.py b/tests/utils/test_file_edit_events.py index cdaae5167..7cc8a59fa 100644 --- a/tests/utils/test_file_edit_events.py +++ b/tests/utils/test_file_edit_events.py @@ -9,6 +9,7 @@ from nanobot.utils.file_edit_events import ( build_file_edit_start_event, line_diff_stats, prepare_file_edit_tracker, + prepare_file_edit_trackers, read_file_snapshot, StreamingFileEditTracker, ) @@ -81,6 +82,49 @@ def test_binary_file_is_reported_but_not_counted(tmp_path: Path) -> None: assert (event["added"], event["deleted"]) == (0, 0) +def test_apply_patch_prepares_trackers_for_each_touched_file(tmp_path: Path) -> None: + (tmp_path / "src").mkdir() + existing = tmp_path / "src" / "existing.py" + existing.write_text("old\nkeep\n", encoding="utf-8") + delete_me = tmp_path / "src" / "delete_me.py" + delete_me.write_text("gone\n", encoding="utf-8") + + patch = """*** Begin Patch +*** Add File: src/new.py ++fresh +*** Update File: src/existing.py +@@ +-old ++new + keep +*** Delete File: src/delete_me.py +*** End Patch""" + + trackers = prepare_file_edit_trackers( + call_id="call-patch", + tool_name="apply_patch", + tool=None, + workspace=tmp_path, + params={"patch": patch}, + ) + + assert [tracker.display_path for tracker in trackers] == [ + "src/new.py", + "src/existing.py", + "src/delete_me.py", + ] + + (tmp_path / "src" / "new.py").write_text("fresh\n", encoding="utf-8") + existing.write_text("new\nkeep\n", encoding="utf-8") + delete_me.unlink() + + events = [build_file_edit_end_event(tracker, {"patch": patch}) for tracker in trackers] + by_path = {event["path"]: event for event in events} + assert (by_path["src/new.py"]["added"], by_path["src/new.py"]["deleted"]) == (1, 0) + assert (by_path["src/existing.py"]["added"], by_path["src/existing.py"]["deleted"]) == (1, 1) + assert (by_path["src/delete_me.py"]["added"], by_path["src/delete_me.py"]["deleted"]) == (0, 1) + + def test_oversized_write_file_end_uses_known_content_for_exact_count(tmp_path: Path) -> None: target = tmp_path / "large.txt" params = {"path": "large.txt", "content": "x" * (2 * 1024 * 1024 + 1)} From 8ec1025193918a86e0e3d2c40185df7f007be06f Mon Sep 17 00:00:00 2001 From: Kaloyan Tenchov Date: Fri, 15 May 2026 22:02:43 -0400 Subject: [PATCH 05/54] feat(signal): add Signal channel support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Integrates signal-cli daemon via HTTP JSON-RPC as a nanobot channel. Supports DMs and group chats with open/allowlist access policies, markdown→Signal text style conversion, typing indicators, attachment handling, group message context buffering, and automatic reconnect with exponential backoff. Includes unit tests for channel lifecycle, message routing, mention detection, markdown conversion, and message splitting. Originally based on https://github.com/HKUDS/nanobot/pull/601. --- nanobot/channels/signal.py | 1133 ++++++++++++++++++++++++ tests/channels/test_signal_channel.py | 1058 ++++++++++++++++++++++ tests/channels/test_signal_markdown.py | 244 +++++ 3 files changed, 2435 insertions(+) create mode 100644 nanobot/channels/signal.py create mode 100644 tests/channels/test_signal_channel.py create mode 100644 tests/channels/test_signal_markdown.py diff --git a/nanobot/channels/signal.py b/nanobot/channels/signal.py new file mode 100644 index 000000000..3e35ae676 --- /dev/null +++ b/nanobot/channels/signal.py @@ -0,0 +1,1133 @@ +"""Signal channel implementation using signal-cli daemon JSON-RPC interface.""" + +from __future__ import annotations + +import asyncio +import json +import re +import shutil +import unicodedata +from collections import deque +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import httpx +from pydantic import Field + +from nanobot.bus.events import OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.base import BaseChannel +from nanobot.config.paths import get_media_dir +from nanobot.config.schema import Base +from nanobot.utils.helpers import safe_filename, split_message + + +@dataclass +class _Run: + text: str + styles: frozenset[str] = field(default_factory=frozenset) + opaque: bool = False # code / table content — skip further pattern processing + + +_SIG_CODE_BLOCK_RE = re.compile(r'```(?:\w+)?\n?([\s\S]*?)```') +_SIG_INLINE_CODE_RE = re.compile(r'`([^`\n]+)`') +_SIG_HEADER_RE = re.compile(r'^#{1,6}\s+(.+)$', re.MULTILINE) +_SIG_BLOCKQUOTE_RE = re.compile(r'^>\s*(.*)$', re.MULTILINE) +_SIG_BULLET_RE = re.compile(r'^[-*]\s+', re.MULTILINE) +_SIG_OLIST_RE = re.compile(r'^(\d+)\.\s+', re.MULTILINE) +_SIG_LINK_RE = re.compile(r'\[([^\]]+)\]\(([^)]+)\)') +_SIG_BOLD_RE = re.compile(r'\*\*(.+?)\*\*|__(.+?)__', re.DOTALL) +_SIG_ITALIC_RE = re.compile(r'(? str: + """Strip inline markdown from a table cell for plain-text rendering.""" + s = re.sub(r'\*\*(.+?)\*\*', r'\1', s) + s = re.sub(r'__(.+?)__', r'\1', s) + s = re.sub(r'~~(.+?)~~', r'\1', s) + s = re.sub(r'`([^`]+)`', r'\1', s) + return s.strip() + + +def _sig_render_table(table_lines: list[str]) -> str: + """Render a markdown pipe-table as fixed-width plain text.""" + + def dw(s: str) -> int: + return sum(2 if unicodedata.east_asian_width(c) in ('W', 'F') else 1 for c in s) + + rows: list[list[str]] = [] + has_sep = False + for line in table_lines: + cells = [_sig_strip_cell(c) for c in line.strip().strip('|').split('|')] + if all(re.match(r'^:?-+:?$', c) for c in cells if c): + has_sep = True + continue + rows.append(cells) + if not rows or not has_sep: + return '\n'.join(table_lines) + + ncols = max(len(r) for r in rows) + for r in rows: + r.extend([''] * (ncols - len(r))) + widths = [max(dw(r[c]) for r in rows) for c in range(ncols)] + + def dr(cells: list[str]) -> str: + return ' '.join(f'{c}{" " * (w - dw(c))}' for c, w in zip(cells, widths)) + + out = [dr(rows[0])] + out.append(' '.join('─' * w for w in widths)) + for row in rows[1:]: + out.append(dr(row)) + return '\n'.join(out) + + +def _markdown_to_signal(text: str) -> tuple[str, list[str]]: + """Convert markdown text to Signal plain text + textStyle ranges. + + Returns ``(plain_text, text_styles)`` where ``text_styles`` are + ``"start:length:STYLE"`` strings for the signal-cli ``textStyle`` parameter. + """ + if not text: + return text, [] + + # Phase 1 (text-level): extract code blocks and tables with placeholder tokens + # so they're protected from inline-style processing. + protected: list[str] = [] + + def save_code(m: re.Match) -> str: + protected.append(m.group(1)) + return f"\x00C{len(protected) - 1}\x00" + + text = _SIG_CODE_BLOCK_RE.sub(save_code, text) + + # Detect and render pipe-tables line by line. + lines = text.split('\n') + rebuilt: list[str] = [] + i = 0 + while i < len(lines): + if re.match(r'^\s*\|.+\|', lines[i]): + tbl: list[str] = [] + while i < len(lines) and re.match(r'^\s*\|.+\|', lines[i]): + tbl.append(lines[i]) + i += 1 + rendered = _sig_render_table(tbl) + if rendered != '\n'.join(tbl): + protected.append(rendered) + rebuilt.append(f"\x00C{len(protected) - 1}\x00") + else: + rebuilt.extend(tbl) + else: + rebuilt.append(lines[i]) + i += 1 + text = '\n'.join(rebuilt) + + # Phase 2 (run-based): process inline patterns. + runs: list[_Run] = [_Run(text)] + + def transform(pattern: re.Pattern, make_runs: Any) -> None: + new_runs: list[_Run] = [] + for run in runs: + if run.opaque: + new_runs.append(run) + continue + pos = 0 + for m in pattern.finditer(run.text): + if m.start() > pos: + new_runs.append(_Run(run.text[pos:m.start()], run.styles)) + new_runs.extend(make_runs(m, run.styles)) + pos = m.end() + if pos < len(run.text): + new_runs.append(_Run(run.text[pos:], run.styles)) + runs[:] = new_runs + + # Restore code/table placeholders as opaque MONOSPACE runs. + transform(_SIG_TOKEN_RE, lambda m, s: [_Run(protected[int(m.group(1))], s | {"MONOSPACE"}, opaque=True)]) + + # Inline code (opaque). + transform(_SIG_INLINE_CODE_RE, lambda m, s: [_Run(m.group(1), s | {"MONOSPACE"}, opaque=True)]) + + # Headers → bold plain text. + transform(_SIG_HEADER_RE, lambda m, s: [_Run(m.group(1), s | {"BOLD"})]) + + # Blockquotes → strip marker. + transform(_SIG_BLOCKQUOTE_RE, lambda m, s: [_Run(m.group(1), s)]) + + # Bullet lists → bullet character. + transform(_SIG_BULLET_RE, lambda m, s: [_Run("• ", s)]) + + # Numbered lists → normalize spacing. + transform(_SIG_OLIST_RE, lambda m, s: [_Run(m.group(1) + ". ", s)]) + + # Links → "text (url)" or bare url when text equals url. + def _link_runs(m: re.Match, s: frozenset) -> list[_Run]: + link_text, url = m.group(1), m.group(2) + + def _norm(u: str) -> str: + return re.sub(r'^https?://(www\.)?', '', u).rstrip('/').lower() + + if _norm(url) == _norm(link_text): + return [_Run(url, s)] + return [_Run(f"{link_text} ({url})", s)] + + transform(_SIG_LINK_RE, _link_runs) + + # Bold (before italic so ** doesn't interfere). + transform(_SIG_BOLD_RE, lambda m, s: [_Run(m.group(1) or m.group(2), s | {"BOLD"})]) + + # Italic (single * or _). + transform(_SIG_ITALIC_RE, lambda m, s: [_Run(m.group(1) or m.group(2), s | {"ITALIC"})]) + + # Strikethrough: ~~text~~ (standard) or ~text~ (single-tilde variant). + transform(_SIG_STRIKE_RE, lambda m, s: [_Run(m.group(1) or m.group(2), s | {"STRIKETHROUGH"})]) + + # Phase 3: assemble output. + plain_text = "" + text_styles: list[str] = [] + for run in runs: + if not run.text: + continue + start = len(plain_text) + plain_text += run.text + length = len(plain_text) - start + for style in sorted(run.styles): + text_styles.append(f"{start}:{length}:{style}") + + return plain_text, text_styles + + +class SignalDMConfig(Base): + """Signal DM policy configuration.""" + + enabled: bool = False + policy: str = "allowlist" # "open" or "allowlist" + allow_from: list[str] = Field(default_factory=list) # Allowed phone numbers/UUIDs + + +class SignalGroupConfig(Base): + """Signal group policy configuration.""" + + enabled: bool = False + policy: str = "allowlist" # "open" or "allowlist" - which groups to operate in + allow_from: list[str] = Field(default_factory=list) # Allowed group IDs if allowlist policy + require_mention: bool = True # Whether bot must be mentioned to respond + + +class SignalConfig(Base): + """Signal channel configuration using signal-cli daemon (HTTP mode with -a flag only).""" + + enabled: bool = False + phone_number: str = "" # Your Signal phone number (e.g., "+1234567890") + daemon_host: str = "localhost" + daemon_port: int = 8080 + group_message_buffer_size: int = 20 # Number of recent group messages to keep for context + dm: SignalDMConfig = Field(default_factory=SignalDMConfig) + group: SignalGroupConfig = Field(default_factory=SignalGroupConfig) + + @property + def allow_from(self) -> list[str]: + """Aggregate allowlist for the base-class is_allowed() check. + + Returns the union of dm.allow_from and group.allow_from so the base + channel gate sees a populated list when either sub-policy is configured. + A ``"*"`` wildcard in either sub-list propagates to allow all. + """ + return list(dict.fromkeys(self.dm.allow_from + self.group.allow_from)) + + +class SignalChannel(BaseChannel): + """ + Signal channel using signal-cli daemon via HTTP JSON-RPC interface. + + Requires signal-cli daemon in HTTP mode: + - signal-cli -a +1234567890 daemon --http localhost:8080 + + See https://github.com/AsamK/signal-cli for setup instructions. + """ + + name = "signal" + display_name = "Signal" + _TYPING_REFRESH_SECONDS = 10.0 + _MAX_MESSAGE_LEN = 64_000 # signal-cli practical limit (protocol max ~64 KB) + + @classmethod + def default_config(cls) -> dict[str, Any]: + return SignalConfig().model_dump(by_alias=True) + + def __init__(self, config: SignalConfig, bus: MessageBus): + if isinstance(config, dict): + config = SignalConfig.model_validate(config) + super().__init__(config, bus) + self.config: SignalConfig = config + self._http: httpx.AsyncClient | None = None + self._request_id = 0 + self._sse_task: asyncio.Task | None = None + self._typing_tasks: dict[str, asyncio.Task] = {} + self._typing_uuid_warnings: set[str] = set() + self._account_id_aliases: set[str] = set() + self._remember_account_id_alias(self.config.phone_number) + + # Rolling message buffer for group context (group_id -> deque of messages) + # Each message is a dict with: sender_name, sender_number, content, timestamp + self._group_buffers: dict[str, deque] = {} + + async def start(self) -> None: + """Start the Signal channel and connect to signal-cli daemon.""" + if not self.config.phone_number: + self.self.logger.error("Signal account not configured") + return + + self._running = True + await self._start_http_mode() + + async def _start_http_mode(self) -> None: + """Start Signal channel using Server-Sent Events for receiving messages.""" + base_url = f"http://{self.config.daemon_host}:{self.config.daemon_port}" + reconnect_delay_s = 1.0 + max_reconnect_delay_s = 30.0 + + while self._running: + try: + self.logger.info(f"Connecting to signal-cli daemon at {base_url}...") + + # Create HTTP client + self._http = httpx.AsyncClient(timeout=60.0, base_url=base_url) + + # Test connection + try: + response = await self._http.get("/api/v1/check") + if response.status_code == 200: + self.logger.info("Connected to signal-cli daemon") + else: + raise ConnectionRefusedError( + f"signal-cli daemon check returned status {response.status_code}" + ) + except Exception as e: + raise ConnectionRefusedError(f"signal-cli daemon not responding: {e}") + + # Reset reconnect delay after successful connection check. + reconnect_delay_s = 1.0 + + # Ensure account-level typing indicators are enabled. + await self._ensure_typing_indicators_enabled() + + # Start SSE receiver and supervise it. If it exits while we're still + # running, treat it as a disconnect and reconnect. + self._sse_task = asyncio.create_task(self._sse_receive_loop()) + await self._sse_task + if self._running: + raise ConnectionError("Signal SSE stream ended unexpectedly") + + except asyncio.CancelledError: + break + except ConnectionRefusedError as e: + self.logger.error( + f"{e}. Make sure signal-cli daemon is running: " + f"signal-cli -a {self.config.phone_number} daemon --http {self.config.daemon_host}:{self.config.daemon_port}" + ) + except Exception as e: + self.logger.error(f"Signal channel error: {e}") + finally: + if self._sse_task: + if not self._sse_task.done(): + self._sse_task.cancel() + try: + await self._sse_task + except asyncio.CancelledError: + pass + except Exception: + pass + self._sse_task = None + if self._http: + await self._http.aclose() + self._http = None + + if self._running: + self.logger.info( + f"Reconnecting to signal-cli daemon in {reconnect_delay_s:.0f} seconds..." + ) + await asyncio.sleep(reconnect_delay_s) + reconnect_delay_s = min(reconnect_delay_s * 2, max_reconnect_delay_s) + + async def stop(self) -> None: + """Stop the Signal channel.""" + self._running = False + + # Stop SSE task + if self._sse_task: + self._sse_task.cancel() + try: + await self._sse_task + except asyncio.CancelledError: + pass + + # Cancel active typing indicators + for chat_id in list(self._typing_tasks): + await self._stop_typing(chat_id) + + # Close HTTP client + if self._http: + await self._http.aclose() + self._http = None + + async def send(self, msg: OutboundMessage) -> None: + """Send a message through Signal.""" + is_progress_message = bool(msg.metadata.get("_progress")) + try: + plain_text, text_styles = _markdown_to_signal(msg.content) + if not plain_text and not msg.media: + return + recipient_params = self._recipient_params(msg.chat_id) + + chunks = split_message(plain_text, self._MAX_MESSAGE_LEN) if plain_text else [""] + for i, chunk in enumerate(chunks): + params: dict[str, Any] = {"message": chunk} + if text_styles and i == 0: + params["textStyle"] = text_styles + params.update(recipient_params) + if msg.media and i == 0: + params["attachments"] = msg.media + + response = await self._send_request("send", params) + + if "error" in response: + self.logger.error(f"Error sending Signal message: {response['error']}") + else: + self.logger.debug( + f"Signal message sent, timestamp: {response.get('result', {}).get('timestamp')}" + ) + + except Exception: + self.logger.exception("Error sending Signal message") + raise + finally: + # Keep typing active across progress updates; stop on the final reply. + if not is_progress_message: + # Avoid immediate START->STOP for fast responses, which can be invisible + # in some Signal clients. Let indicator expire naturally (~15s). + await self._stop_typing(msg.chat_id, send_stop=False) + + async def _sse_receive_loop(self) -> None: + """Receive messages via Server-Sent Events (HTTP mode).""" + if not self._http: + raise RuntimeError("HTTP client not initialized for Signal SSE stream") + + self.logger.info("Started Signal message receive loop (SSE)") + + try: + async with self._http.stream("GET", "/api/v1/events") as response: + if response.status_code != 200: + raise ConnectionError( + f"SSE connection failed with status {response.status_code}" + ) + + self.logger.info("Subscribed to Signal messages via SSE") + + # Buffer for accumulating SSE data across multiple lines + event_buffer = [] + + async for line in response.aiter_lines(): + if not self._running: + break + + # Debug: log raw SSE lines (except keepalive pings) + if line and line != ":": + self.logger.debug(f"SSE line received: {line[:200]}") + + # SSE format handling + if isinstance(line, str): + # Empty line signals end of event + if not line or line == ":": + if event_buffer: + # Try to parse the accumulated data + data_str = "" + try: + data_str = "".join(event_buffer) + data = json.loads(data_str) + self.logger.debug(f"SSE event parsed: {data}") + await self._handle_receive_notification(data) + except json.JSONDecodeError as e: + self.logger.warning( + f"Invalid JSON in SSE buffer: {e}, data: {data_str[:200]}" + ) + finally: + event_buffer = [] + + # "data:" line - accumulate it + elif line.startswith("data:"): + event_buffer.append(line[5:]) # Skip "data:" prefix + + # "event:" line - just log it (we only care about data) + elif line.startswith("event:"): + pass # Ignore event type for now + + if self._running: + raise ConnectionError("Signal SSE stream closed by remote endpoint") + + except asyncio.CancelledError: + self.logger.info("SSE receive loop cancelled") + raise + except Exception as e: + self.logger.error(f"Error in SSE receive loop: {e}") + raise + + async def _handle_receive_notification(self, params: dict[str, Any]) -> None: + """Handle incoming message notification from signal-cli.""" + self.logger.debug(f"_handle_receive_notification called with: {params}") + try: + # Extract envelope from SSE notification: {"envelope": {...}} + envelope = params.get("envelope", {}) + + self.logger.debug(f"Extracted envelope: {envelope}") + + if not envelope: + self.logger.debug("No envelope found in params") + return + + # Extract sender information + sender_parts = self._collect_sender_id_parts(envelope) + source_name = envelope.get("sourceName") + + if not sender_parts: + self.logger.debug("Received message without source, skipping") + return + + sender_number = self._primary_sender_id(sender_parts) + sender_id = "|".join(sender_parts) + + # Keep aliases of the bot account for robust mention matching. + if any(self._id_matches_account(part) for part in sender_parts): + for part in sender_parts: + self._remember_account_id_alias(part) + + # Check different message types + data_message = envelope.get("dataMessage") + sync_message = envelope.get("syncMessage") + typing_message = envelope.get("typingMessage") + receipt_message = envelope.get("receiptMessage") + + # Ignore receipt messages (delivery/read receipts) + if receipt_message: + return + + # Handle data messages (incoming messages from others) + if data_message: + await self._handle_data_message(sender_id, sender_number, data_message, source_name) + + # Handle sync messages (messages sent from another device) + elif sync_message and sync_message.get("sentMessage"): + sent_msg = sync_message["sentMessage"] + destination = sent_msg.get("destination") or sent_msg.get("destinationNumber") + if destination: + self.logger.debug( + f"Sync message sent to {destination}: {sent_msg.get('message', '')[:50]}" + ) + + # Handle typing indicators (silently ignore) + elif typing_message: + pass # Ignore typing indicators + + except Exception as e: + self.logger.error(f"Error handling receive notification: {e}") + + async def _handle_data_message( + self, + sender_id: str, + sender_number: str, + data_message: dict[str, Any], + sender_name: str | None, + ) -> None: + """Handle a data message (text, attachments, etc.).""" + message_text = data_message.get("message") or "" + attachments = data_message.get("attachments", []) + group_info = data_message.get("groupInfo") + timestamp = data_message.get("timestamp") + mentions = data_message.get("mentions", []) + reaction = data_message.get("reaction") + + # Log full data_message for debugging group detection + self.logger.info( + f"Data message from {sender_number}: " + f"groupInfo={group_info}, " + f"groupV2={data_message.get('groupV2')}, " + f"keys={list(data_message.keys())}" + ) + + # Ignore reaction messages (emoji reactions to messages) + if reaction: + self.logger.debug(f"Ignoring reaction message from {sender_number}: {reaction}") + return + + # Ignore empty messages (e.g., when bot is added to a group) + if not message_text and not attachments: + self.logger.debug(f"Ignoring empty message from {sender_number}") + return + + # Determine chat_id (group ID or sender number) + # Check both groupInfo (v1) and groupV2 (v2) fields for group detection + group_v2 = data_message.get("groupV2") + is_group_message = group_info is not None or group_v2 is not None + group_id = self._extract_group_id(group_info, group_v2) + + is_command = bool(message_text and message_text.strip().startswith("/")) + + if is_group_message: + chat_id = group_id or sender_number + + # Check if this group is allowed before doing anything else + if not self.config.group.enabled: + self.logger.info(f"Ignoring group message from {chat_id} (groups disabled)") + return + if ( + self.config.group.policy == "allowlist" + and chat_id not in self.config.group.allow_from + ): + self.logger.info( + f"Ignoring group message from {chat_id} (policy: {self.config.group.policy})" + ) + return + + # Add to group message buffer (group is allowed) + self._add_to_group_buffer( + group_id=chat_id, + sender_name=sender_name or sender_number, + sender_number=sender_number, + message_text=message_text, + timestamp=timestamp, + ) + + # Commands bypass the mention requirement; non-commands require it. + if not is_command and not self._should_respond_in_group(message_text, mentions): + self.logger.info( + f"Ignoring group message (require_mention: {self.config.group.require_mention})" + ) + return + else: + # Direct message — check policy first, then forward everything to the bus. + chat_id = sender_number + if not self.config.dm.enabled: + self.logger.debug(f"Ignoring DM from {sender_id} (DMs disabled)") + return + if self.config.dm.policy == "allowlist": + allow_list = self.config.dm.allow_from + sender_str = str(sender_id) + parts = [sender_str] + (sender_str.split("|") if "|" in sender_str else []) + if not any(p for p in parts if p in allow_list): + self.logger.debug(f"Ignoring DM from {sender_id} (policy: {self.config.dm.policy})") + return + + # Build content from text and attachments + content_parts = [] + media_paths = [] + + # For group messages, include recent message context + if is_group_message: + buffer_context = self._get_group_buffer_context(chat_id) + if buffer_context: + content_parts.append(f"[Recent group messages for context:]\n{buffer_context}\n---") + + # Prepend sender name for group messages so history shows who said what + if message_text: + # Strip bot mentions from text (for group messages) + if is_group_message: + message_text = self._strip_bot_mention(message_text, mentions) + # Prepend sender name to make it clear who is speaking + display_name = sender_name or sender_number + message_text = f"[{display_name}]: {message_text}" + content_parts.append(message_text) + + # Handle attachments + if attachments: + media_dir = get_media_dir("signal") + + for attachment in attachments: + attachment_id = attachment.get("id") + content_type = attachment.get("contentType", "") + filename = attachment.get("filename") or f"attachment_{attachment_id}" + + if not attachment_id: + continue + + try: + # signal-cli stores attachments in ~/.local/share/signal-cli/attachments/ + source_path = ( + Path.home() / ".local/share/signal-cli/attachments" / attachment_id + ) + + if source_path.exists(): + dest_path = media_dir / f"signal_{safe_filename(filename)}" + shutil.copy2(source_path, dest_path) + media_paths.append(str(dest_path)) + + # Determine media type from content type + media_type = content_type.split("/")[0] if "/" in content_type else "file" + if media_type not in ("image", "audio", "video"): + media_type = "file" + + content_parts.append(f"[{media_type}: {dest_path}]") + self.logger.debug(f"Downloaded attachment: {filename} -> {dest_path}") + else: + self.logger.warning(f"Attachment not found: {source_path}") + content_parts.append(f"[attachment: {filename} - not found]") + + except Exception as e: + self.logger.warning(f"Failed to process attachment {filename}: {e}") + content_parts.append(f"[attachment: {filename} - error]") + + content = "\n".join(content_parts) if content_parts else "[empty message]" + + self.logger.debug(f"Signal message from {sender_number}: {content[:50]}...") + + await self._start_typing(chat_id) + try: + # Forward to message bus + await self._handle_message( + sender_id=sender_id, + chat_id=chat_id, + content=content, + media=media_paths, + metadata={ + "timestamp": timestamp, + "sender_name": sender_name, + "sender_number": sender_number, + "is_group": is_group_message, + "group_id": group_id, + }, + ) + except Exception: + await self._stop_typing(chat_id) + raise + + def _add_to_group_buffer( + self, + group_id: str, + sender_name: str, + sender_number: str, + message_text: str, + timestamp: int | None, + ) -> None: + """ + Add a message to the group's rolling buffer. + + Args: + group_id: The group ID + sender_name: Display name of sender + sender_number: Phone number of sender + message_text: The message content + timestamp: Message timestamp + """ + if self.config.group_message_buffer_size <= 0: + return + + # Create buffer for this group if it doesn't exist + if group_id not in self._group_buffers: + self._group_buffers[group_id] = deque(maxlen=self.config.group_message_buffer_size) + + # Add message to buffer (deque will automatically drop oldest when full) + self._group_buffers[group_id].append( + { + "sender_name": sender_name, + "sender_number": sender_number, + "content": message_text, + "timestamp": timestamp, + } + ) + + self.logger.debug( + f"Added message to group buffer {group_id}: " + f"{len(self._group_buffers[group_id])}/{self.config.group_message_buffer_size}" + ) + + def _get_group_buffer_context(self, group_id: str) -> str: + """ + Get formatted context from the group's message buffer. + + Args: + group_id: The group ID + + Returns: + Formatted string of recent messages (excluding the current one) + """ + if group_id not in self._group_buffers: + return "" + + buffer = self._group_buffers[group_id] + if len(buffer) <= 1: # Only current message, no context + return "" + + # Format all messages except the last one (which is the current message) + # We want to show context BEFORE the mention + context_messages = list(buffer)[:-1] # Exclude the last (current) message + + lines = [] + for msg in context_messages: + sender = msg["sender_name"] + content = msg["content"][:200] # Limit to 200 chars per message + lines.append(f"{sender}: {content}") + + return "\n".join(lines) + + @staticmethod + def _normalize_signal_id(value: str) -> list[str]: + """Normalize Signal identifiers (phone/uuid/service-id) for matching.""" + raw = value.strip() + if not raw: + return [] + + normalized = [raw, raw.lower()] + if raw.startswith("+") and len(raw) > 1: + normalized.append(raw[1:]) + elif raw.isdigit(): + normalized.append(f"+{raw}") + return list(dict.fromkeys(normalized)) + + def _remember_account_id_alias(self, value: str | None) -> None: + """Remember known bot identifiers for mention matching.""" + if not value: + return + if not isinstance(value, str): + return + for candidate in self._normalize_signal_id(value): + self._account_id_aliases.add(candidate) + + def _id_matches_account(self, value: str | None) -> bool: + """Return True when an identifier refers to the bot account.""" + if not value: + return False + if not isinstance(value, str): + return False + return any( + candidate in self._account_id_aliases for candidate in self._normalize_signal_id(value) + ) + + @staticmethod + def _collect_sender_id_parts(envelope: dict[str, Any]) -> list[str]: + """Collect all known sender identifier variants from an envelope.""" + parts: list[str] = [] + for key in ( + "sourceNumber", + "source", + "sourceUuid", + "sourceServiceId", + "sourceAci", + "sourceACI", + ): + value = envelope.get(key) + if not isinstance(value, str): + continue + candidate = value.strip() + if candidate and candidate not in parts: + parts.append(candidate) + return parts + + @staticmethod + def _primary_sender_id(sender_parts: list[str]) -> str: + """Pick the best sender identifier for routing (prefer phone-like IDs).""" + for part in sender_parts: + if part.startswith("+") or part.isdigit(): + return part + return sender_parts[0] if sender_parts else "" + + @staticmethod + def _extract_group_id(group_info: Any, group_v2: Any) -> str | None: + """Extract group ID from groupInfo/groupV2 payloads across signal-cli variants.""" + for group_obj in (group_info, group_v2): + if not isinstance(group_obj, dict): + continue + for key in ("groupId", "id", "groupID"): + value = group_obj.get(key) + if isinstance(value, str) and value: + return value + return None + + @staticmethod + def _mention_id_candidates(mention: dict[str, Any]) -> list[str]: + """Extract possible identifier fields from a mention payload.""" + ids: list[str] = [] + + def _walk(value: Any, depth: int = 0) -> None: + if depth > 2: + return + if not isinstance(value, dict): + return + for key, child in value.items(): + key_lower = str(key).lower() + if isinstance(child, str) and child: + if any(token in key_lower for token in ("number", "uuid", "serviceid", "aci")): + ids.append(child) + elif isinstance(child, dict): + _walk(child, depth + 1) + + _walk(mention) + return list(dict.fromkeys(ids)) + + @staticmethod + def _mention_span(mention: dict[str, Any]) -> tuple[int, int] | None: + """Extract a safe (start, length) span from a mention.""" + try: + start = int(mention.get("start", 0)) + length = int(mention.get("length", 0)) + except (TypeError, ValueError): + return None + + if start < 0 or length <= 0: + return None + return (start, length) + + @staticmethod + def _leading_placeholder_span(text: str | None) -> tuple[int, int] | None: + """ + Detect a leading Signal mention placeholder when mention metadata is missing. + + Some clients/integrations deliver mentions as a leading placeholder character + (typically U+FFFC) but omit `mentions` metadata in the payload. + """ + if not text: + return None + + start = 0 + while start < len(text) and text[start].isspace(): + start += 1 + + if start >= len(text): + return None + + marker = text[start] + if marker not in ("\ufffc", "\ufffd", "\x1b"): + return None + + next_index = start + 1 + if next_index < len(text) and not text[next_index].isspace(): + return None + + return (start, 1) + + def _should_respond_in_group(self, message_text: str, mentions: list[dict[str, Any]]) -> bool: + """ + Determine if the bot should respond to a group message. + + Args: + message_text: The message text content + mentions: List of mentions from Signal (format: [{"number": "+1234567890", "start": 0, "length": 10}]) + + Returns: + True if bot should respond, False otherwise + """ + # Group reply behavior is controlled only by group.require_mention. + if not self.config.group.require_mention: + return True + + # If mention is required, check if bot was mentioned. + for mention in mentions: + if not isinstance(mention, dict): + continue + for mention_id in self._mention_id_candidates(mention): + if self._id_matches_account(mention_id): + return True + + # Some Signal clients emit mention spans without recipient identifiers + # (for handle-style mentions). Accept a leading identifier-less mention + # as a mention of the bot to avoid false negatives. + for mention in mentions: + if not isinstance(mention, dict): + continue + if self._mention_id_candidates(mention): + continue + span = self._mention_span(mention) + if not span: + continue + start, _ = span + if message_text is not None and not message_text[:start].strip(): + self.logger.debug("Accepting identifier-less leading mention as bot mention") + return True + + # Some payloads omit `mentions` but still include the leading mention + # placeholder character in the message body. + if not mentions and self._leading_placeholder_span(message_text): + self.logger.debug("Accepting leading placeholder mention without mention metadata") + return True + + # Fallback: check for configured phone number in plain text. + if message_text and self.config.phone_number: + for account_id in self._normalize_signal_id(self.config.phone_number): + if account_id and account_id in message_text: + return True + + return False + + def _strip_bot_mention(self, text: str, mentions: list[dict[str, Any]]) -> str: + """ + Remove bot mentions from message text. + + Signal mentions are embedded in the text, so we need to remove them based on + the mentions array which provides start position and length. + + Args: + text: Original message text + mentions: List of mention objects with start/length positions + + Returns: + Text with bot mentions removed + """ + if not text: + return text + + # Build a list of (start, length) tuples for our bot's mentions + bot_mentions = [] + for mention in mentions: + if not isinstance(mention, dict): + continue + mention_ids = self._mention_id_candidates(mention) + span = self._mention_span(mention) + if not span: + continue + + # Strip matched bot mentions by ID. + if any(self._id_matches_account(mention_id) for mention_id in mention_ids): + bot_mentions.append(span) + continue + + # Also strip identifier-less leading mention spans (handle mentions). + if not mention_ids: + start, _ = span + if not text[:start].strip(): + bot_mentions.append(span) + + if not bot_mentions: + placeholder_span = self._leading_placeholder_span(text) + if placeholder_span: + bot_mentions.append(placeholder_span) + + # Sort mentions by start position (descending) to remove from end to start + # This prevents position shifts when removing earlier mentions + bot_mentions.sort(reverse=True) + + # Remove each mention + for start, length in bot_mentions: + if start >= len(text): + continue + end = min(len(text), start + length) + text = text[:start] + text[end:] + + return text.strip() + + @staticmethod + def _is_group_chat_id(chat_id: str) -> bool: + """Return True when chat_id appears to be a Signal group ID (base64).""" + return "=" in chat_id or (len(chat_id) > 40 and "-" not in chat_id) + + def _recipient_params(self, chat_id: str) -> dict[str, Any]: + """Build recipient params for signal-cli JSON-RPC methods.""" + if self._is_group_chat_id(chat_id): + return {"groupId": chat_id} + return {"recipient": [chat_id]} + + async def _start_typing(self, chat_id: str) -> None: + """Start periodic typing indicator updates for a chat.""" + await self._stop_typing(chat_id, send_stop=False) + await self._send_typing(chat_id) + self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id)) + + async def _stop_typing(self, chat_id: str, send_stop: bool = True) -> None: + """Stop typing indicator updates for a chat.""" + task = self._typing_tasks.pop(chat_id, None) + had_task = task is not None + if task and not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + if send_stop and had_task: + await self._send_typing(chat_id, stop=True) + + async def _typing_loop(self, chat_id: str) -> None: + """Send typing updates periodically until cancelled.""" + try: + while self._running: + await asyncio.sleep(self._TYPING_REFRESH_SECONDS) + await self._send_typing(chat_id, quiet_success=True) + except asyncio.CancelledError: + pass + except Exception as e: + self.logger.debug(f"Typing indicator loop stopped for {chat_id}: {e}") + + async def _send_typing( + self, chat_id: str, stop: bool = False, quiet_success: bool = False + ) -> None: + """Send a typing START/STOP message via signal-cli.""" + action = "stop" if stop else "start" + if ( + not self._is_group_chat_id(chat_id) + and chat_id.startswith("+") is False + and chat_id not in self._typing_uuid_warnings + ): + self._typing_uuid_warnings.add(chat_id) + self.logger.warning( + "Signal DM recipient is UUID-only (no phone number in envelope). " + "Some Signal clients may not render typing indicators for this recipient form." + ) + candidate_params: list[dict[str, Any]] + if self._is_group_chat_id(chat_id): + candidate_params = [{"groupId": chat_id}, {"groupId": [chat_id]}] + else: + candidate_params = [{"recipient": chat_id}, {"recipient": [chat_id]}] + + last_error: Any | None = None + for params in candidate_params: + if stop: + params["stop"] = True + try: + response = await self._send_request("sendTyping", params) + except Exception as e: + last_error = str(e) + continue + + if "error" not in response: + if not quiet_success: + self.logger.info(f"Signal typing {action} sent for {chat_id}") + return + + last_error = response["error"] + + self.logger.warning(f"Failed to send Signal typing {action} for {chat_id}: {last_error}") + + async def _ensure_typing_indicators_enabled(self) -> None: + """Enable typing indicators on the bot account.""" + response = await self._send_request("updateConfiguration", {"typingIndicators": True}) + if "error" in response: + self.logger.warning(f"Failed to enable Signal typing indicators: {response['error']}") + else: + self.logger.info("Signal typing indicators enabled on account configuration") + + async def _send_request( + self, method: str, params: dict[str, Any] | None = None + ) -> dict[str, Any]: + """Send a JSON-RPC request via HTTP and wait for response.""" + # Generate request ID + self._request_id += 1 + request_id = self._request_id + + # Build JSON-RPC request + request = {"jsonrpc": "2.0", "method": method, "id": request_id} + + if params: + request["params"] = params + + return await self._send_http_request(request) + + async def _send_http_request(self, request: dict[str, Any]) -> dict[str, Any]: + """Send JSON-RPC request via HTTP.""" + if not self._http: + raise RuntimeError("Not connected to signal-cli daemon") + + try: + response = await self._http.post("/api/v1/rpc", json=request) + response.raise_for_status() + return response.json() + except Exception as e: + self.logger.error(f"HTTP request failed: {e}") + return {"error": {"message": str(e)}} diff --git a/tests/channels/test_signal_channel.py b/tests/channels/test_signal_channel.py new file mode 100644 index 000000000..b5149459b --- /dev/null +++ b/tests/channels/test_signal_channel.py @@ -0,0 +1,1058 @@ +"""Tests for the Signal channel implementation.""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock + +import pytest + +from nanobot.bus.events import OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.signal import ( + SignalChannel, + SignalConfig, + SignalDMConfig, + SignalGroupConfig, +) + +# --------------------------------------------------------------------------- +# Fake HTTP client +# --------------------------------------------------------------------------- + + +class _FakeResponse: + def __init__(self, status_code: int = 200, body: dict | None = None) -> None: + self.status_code = status_code + self._body = body or {} + + def raise_for_status(self) -> None: + if self.status_code >= 400: + raise RuntimeError(f"HTTP {self.status_code}") + + def json(self) -> dict: + return self._body + + +class _FakeHTTPClient: + """Minimal httpx.AsyncClient stand-in that records requests.""" + + def __init__(self, *, default_response: dict | None = None) -> None: + self.posts: list[dict] = [] + self.gets: list[str] = [] + self._response = _FakeResponse(body=default_response or {"result": {"timestamp": 123}}) + self.closed = False + + async def get(self, path: str) -> _FakeResponse: + self.gets.append(path) + return self._response + + async def post(self, path: str, *, json: dict) -> _FakeResponse: + self.posts.append({"path": path, "json": json}) + return self._response + + async def aclose(self) -> None: + self.closed = True + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_channel( + *, + phone_number: str = "+10000000000", + dm_enabled: bool = True, + dm_policy: str = "open", + dm_allow_from: list[str] | None = None, + group_enabled: bool = False, + group_policy: str = "open", + group_allow_from: list[str] | None = None, + require_mention: bool = True, + group_buffer_size: int = 20, +) -> SignalChannel: + config = SignalConfig( + enabled=True, + phone_number=phone_number, + dm=SignalDMConfig( + enabled=dm_enabled, + policy=dm_policy, + allow_from=dm_allow_from or [], + ), + group=SignalGroupConfig( + enabled=group_enabled, + policy=group_policy, + allow_from=group_allow_from or [], + require_mention=require_mention, + ), + group_message_buffer_size=group_buffer_size, + ) + return SignalChannel(config, MessageBus()) + + +def _dm_envelope( + *, + source_number: str = "+19995550001", + source_uuid: str | None = None, + source_name: str | None = "Alice", + message: str = "hello", + attachments: list | None = None, + reaction: dict | None = None, + timestamp: int = 1000, +) -> dict: + data_message: dict = {"message": message, "timestamp": timestamp} + if attachments is not None: + data_message["attachments"] = attachments + if reaction is not None: + data_message["reaction"] = reaction + envelope: dict = { + "sourceNumber": source_number, + "sourceName": source_name, + "dataMessage": data_message, + } + if source_uuid: + envelope["sourceUuid"] = source_uuid + return {"envelope": envelope} + + +def _group_envelope( + *, + source_number: str = "+19995550001", + source_name: str = "Bob", + group_id: str = "group123==", + message: str = "hey group", + mentions: list | None = None, + timestamp: int = 2000, + use_v2: bool = False, +) -> dict: + group_obj = {"groupId": group_id} + key = "groupV2" if use_v2 else "groupInfo" + data_message: dict = { + "message": message, + "timestamp": timestamp, + key: group_obj, + "mentions": mentions or [], + } + return { + "envelope": { + "sourceNumber": source_number, + "sourceName": source_name, + "dataMessage": data_message, + } + } + + +# --------------------------------------------------------------------------- +# Static utility tests +# --------------------------------------------------------------------------- + + +class TestNormalizeSignalId: + def test_phone_number_kept_and_stripped(self): + result = SignalChannel._normalize_signal_id("+12345678901") + assert "+12345678901" in result + assert "12345678901" in result + + def test_digits_only_gets_plus_prefix(self): + result = SignalChannel._normalize_signal_id("12345678901") + assert "+12345678901" in result + + def test_lowercase_variant_added(self): + result = SignalChannel._normalize_signal_id("SOME-UUID") + assert "some-uuid" in result + + def test_empty_string_returns_empty(self): + assert SignalChannel._normalize_signal_id("") == [] + + def test_whitespace_stripped(self): + result = SignalChannel._normalize_signal_id(" +1234 ") + assert "+1234" in result + + +class TestCollectSenderIdParts: + def test_collects_source_number(self): + env = {"sourceNumber": "+15551234567"} + parts = SignalChannel._collect_sender_id_parts(env) + assert "+15551234567" in parts + + def test_collects_multiple_keys(self): + env = {"sourceNumber": "+15551234567", "sourceUuid": "uuid-abc"} + parts = SignalChannel._collect_sender_id_parts(env) + assert "+15551234567" in parts + assert "uuid-abc" in parts + + def test_deduplicates(self): + env = {"sourceNumber": "+15551234567", "source": "+15551234567"} + parts = SignalChannel._collect_sender_id_parts(env) + assert parts.count("+15551234567") == 1 + + def test_ignores_non_string_values(self): + env = {"sourceNumber": 12345, "sourceUuid": None} + parts = SignalChannel._collect_sender_id_parts(env) + assert parts == [] + + def test_empty_envelope_returns_empty(self): + assert SignalChannel._collect_sender_id_parts({}) == [] + + +class TestPrimarySenderId: + def test_prefers_phone_number(self): + assert SignalChannel._primary_sender_id(["+1234", "uuid-abc"]) == "+1234" + + def test_accepts_digit_only(self): + assert SignalChannel._primary_sender_id(["1234567890", "uuid-abc"]) == "1234567890" + + def test_falls_back_to_first_part(self): + assert SignalChannel._primary_sender_id(["uuid-abc", "other"]) == "uuid-abc" + + def test_empty_list_returns_empty(self): + assert SignalChannel._primary_sender_id([]) == "" + + +class TestExtractGroupId: + def test_extracts_from_group_info(self): + gid = SignalChannel._extract_group_id({"groupId": "abc=="}, None) + assert gid == "abc==" + + def test_extracts_from_group_v2(self): + gid = SignalChannel._extract_group_id(None, {"id": "xyz=="}) + assert gid == "xyz==" + + def test_prefers_group_info_over_v2(self): + gid = SignalChannel._extract_group_id({"groupId": "first"}, {"groupId": "second"}) + assert gid == "first" + + def test_returns_none_when_both_none(self): + assert SignalChannel._extract_group_id(None, None) is None + + def test_returns_none_when_not_dicts(self): + assert SignalChannel._extract_group_id("bad", 123) is None + + +class TestIsGroupChatId: + def test_base64_with_equals_is_group(self): + assert SignalChannel._is_group_chat_id("abc==") is True + + def test_long_id_without_dash_is_group(self): + long_id = "a" * 41 + assert SignalChannel._is_group_chat_id(long_id) is True + + def test_phone_number_is_not_group(self): + assert SignalChannel._is_group_chat_id("+12345678901") is False + + def test_uuid_with_dashes_is_not_group(self): + assert SignalChannel._is_group_chat_id("550e8400-e29b-41d4-a716-446655440000") is False + + +class TestRecipientParams: + def test_group_chat_uses_group_id(self): + ch = _make_channel() + params = ch._recipient_params("abc==") + assert params == {"groupId": "abc=="} + + def test_dm_uses_recipient_list(self): + ch = _make_channel() + params = ch._recipient_params("+12345678901") + assert params == {"recipient": ["+12345678901"]} + + +class TestMentionHelpers: + def test_mention_id_candidates_extracts_number(self): + mention = {"number": "+1234567890"} + ids = SignalChannel._mention_id_candidates(mention) + assert "+1234567890" in ids + + def test_mention_id_candidates_extracts_uuid(self): + mention = {"uuid": "some-uuid"} + ids = SignalChannel._mention_id_candidates(mention) + assert "some-uuid" in ids + + def test_mention_span_valid(self): + assert SignalChannel._mention_span({"start": 0, "length": 5}) == (0, 5) + + def test_mention_span_negative_start(self): + assert SignalChannel._mention_span({"start": -1, "length": 5}) is None + + def test_mention_span_zero_length(self): + assert SignalChannel._mention_span({"start": 0, "length": 0}) is None + + def test_mention_span_missing_keys(self): + assert SignalChannel._mention_span({}) is None + + def test_leading_placeholder_ufffc(self): + span = SignalChannel._leading_placeholder_span(" hello") + assert span == (0, 1) + + def test_leading_placeholder_not_at_start(self): + assert SignalChannel._leading_placeholder_span("hello ") is None + + def test_leading_placeholder_empty_string(self): + assert SignalChannel._leading_placeholder_span("") is None + + def test_leading_placeholder_plain_text(self): + assert SignalChannel._leading_placeholder_span("hello") is None + + +# --------------------------------------------------------------------------- +# Account ID alias / mention matching +# --------------------------------------------------------------------------- + + +class TestAccountIdAliases: + def test_phone_number_alias_registered_on_init(self): + ch = _make_channel(phone_number="+10000000000") + assert ch._id_matches_account("+10000000000") + + def test_digit_only_variant_matches(self): + ch = _make_channel(phone_number="+10000000000") + assert ch._id_matches_account("10000000000") + + def test_remember_alias_adds_uuid(self): + ch = _make_channel() + ch._remember_account_id_alias("some-uuid-abc") + assert ch._id_matches_account("some-uuid-abc") + + def test_non_matching_id_returns_false(self): + ch = _make_channel(phone_number="+10000000000") + assert not ch._id_matches_account("+19999999999") + + def test_none_and_non_string_return_false(self): + ch = _make_channel() + assert not ch._id_matches_account(None) + + +# --------------------------------------------------------------------------- +# _should_respond_in_group +# --------------------------------------------------------------------------- + + +class TestShouldRespondInGroup: + def _make_group_channel(self, require_mention: bool = True) -> SignalChannel: + return _make_channel( + phone_number="+10000000000", + group_enabled=True, + require_mention=require_mention, + ) + + def test_no_require_mention_always_responds(self): + ch = self._make_group_channel(require_mention=False) + assert ch._should_respond_in_group("anything", []) is True + + def test_require_mention_with_no_mentions_returns_false(self): + ch = self._make_group_channel(require_mention=True) + assert ch._should_respond_in_group("hello", []) is False + + def test_require_mention_with_bot_number_mention(self): + ch = self._make_group_channel(require_mention=True) + mentions = [{"number": "+10000000000", "start": 0, "length": 12}] + assert ch._should_respond_in_group(" hello", mentions) is True + + def test_require_mention_with_uuid_mention(self): + ch = self._make_group_channel(require_mention=True) + ch._remember_account_id_alias("bot-uuid-123") + mentions = [{"uuid": "bot-uuid-123", "start": 0, "length": 8}] + assert ch._should_respond_in_group(" hello", mentions) is True + + def test_identifier_less_leading_mention_accepted(self): + ch = self._make_group_channel(require_mention=True) + # Mention with no IDs but leading span — treated as bot mention + mentions = [{"start": 0, "length": 1}] + assert ch._should_respond_in_group(" hello", mentions) is True + + def test_identifier_less_non_leading_mention_rejected(self): + ch = self._make_group_channel(require_mention=True) + mentions = [{"start": 5, "length": 1}] + assert ch._should_respond_in_group("hello ", mentions) is False + + def test_leading_placeholder_without_mentions_metadata(self): + ch = self._make_group_channel(require_mention=True) + assert ch._should_respond_in_group(" hello", []) is True + + def test_phone_number_in_text_triggers_response(self): + ch = self._make_group_channel(require_mention=True) + assert ch._should_respond_in_group("hey +10000000000 help", []) is True + + +# --------------------------------------------------------------------------- +# _strip_bot_mention +# --------------------------------------------------------------------------- + + +class TestStripBotMention: + def _make_channel_with_number(self) -> SignalChannel: + return _make_channel(phone_number="+10000000000") + + def test_strips_mention_by_phone(self): + ch = self._make_channel_with_number() + text = " hello" + mentions = [{"number": "+10000000000", "start": 0, "length": 1}] + result = ch._strip_bot_mention(text, mentions) + assert result == "hello" + + def test_strips_identifier_less_leading_mention(self): + ch = self._make_channel_with_number() + text = " hello" + mentions = [{"start": 0, "length": 1}] + result = ch._strip_bot_mention(text, mentions) + assert result == "hello" + + def test_strips_leading_placeholder_without_mention_metadata(self): + ch = self._make_channel_with_number() + text = " hello" + result = ch._strip_bot_mention(text, []) + assert result == "hello" + + def test_non_bot_mention_mid_text_not_stripped(self): + # A non-bot mention that is NOT a leading placeholder leaves the text alone. + ch = self._make_channel_with_number() + text = "hello  world" + mentions = [{"number": "+19999999999", "start": 6, "length": 1}] + result = ch._strip_bot_mention(text, mentions) + # Mid-text placeholder from a non-bot mention should be untouched + assert "" in result + + def test_empty_text_returned_unchanged(self): + ch = self._make_channel_with_number() + assert ch._strip_bot_mention("", []) == "" + + +# --------------------------------------------------------------------------- +# Group message buffer +# --------------------------------------------------------------------------- + + +class TestGroupBuffer: + def test_add_and_get_context(self): + ch = _make_channel(group_buffer_size=5) + ch._add_to_group_buffer("g1", "Alice", "+1111", "first msg", 1000) + ch._add_to_group_buffer("g1", "Bob", "+2222", "second msg", 2000) + # Only messages before the latest are returned as context + ctx = ch._get_group_buffer_context("g1") + assert "first msg" in ctx + # The last message is not included (it's the "current" one) + assert "second msg" not in ctx + + def test_empty_context_when_only_one_message(self): + ch = _make_channel(group_buffer_size=5) + ch._add_to_group_buffer("g1", "Alice", "+1111", "only msg", 1000) + assert ch._get_group_buffer_context("g1") == "" + + def test_empty_context_when_group_unknown(self): + ch = _make_channel() + assert ch._get_group_buffer_context("unknown") == "" + + def test_buffer_respects_max_size(self): + ch = _make_channel(group_buffer_size=3) + for i in range(10): + ch._add_to_group_buffer("g1", "Alice", "+1111", f"msg{i}", i) + assert len(ch._group_buffers["g1"]) == 3 + + def test_zero_buffer_size_does_not_add(self): + ch = _make_channel(group_buffer_size=0) + ch._add_to_group_buffer("g1", "Alice", "+1111", "msg", 1000) + assert "g1" not in ch._group_buffers + + def test_context_limits_message_length(self): + ch = _make_channel(group_buffer_size=5) + long_msg = "x" * 500 + ch._add_to_group_buffer("g1", "Alice", "+1111", long_msg, 1000) + ch._add_to_group_buffer("g1", "Bob", "+2222", "short", 2000) + ctx = ch._get_group_buffer_context("g1") + # Context is capped at 200 chars per message + assert len(ctx.split("Alice: ", 1)[1]) <= 200 + + +# --------------------------------------------------------------------------- +# _handle_data_message — DM routing +# --------------------------------------------------------------------------- + + +class TestHandleDataMessageDM: + def _make_dm_channel(self, policy="open", allow_from=None) -> tuple[SignalChannel, list]: + ch = _make_channel(dm_enabled=True, dm_policy=policy, dm_allow_from=allow_from or []) + handled: list[dict] = [] + + async def capture(**kwargs): + handled.append(kwargs) + + ch._handle_message = capture # type: ignore[method-assign] + + async def noop_typing(chat_id): + pass + + ch._start_typing = noop_typing # type: ignore[method-assign] + return ch, handled + + @pytest.mark.asyncio + async def test_dm_open_policy_accepted(self): + ch, handled = self._make_dm_channel(policy="open") + params = _dm_envelope(source_number="+19995550001", message="hi") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + assert handled[0]["chat_id"] == "+19995550001" + assert handled[0]["content"] == "hi" + + @pytest.mark.asyncio + async def test_dm_allowlist_accepted(self): + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["+19995550001"]) + params = _dm_envelope(source_number="+19995550001") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_dm_allowlist_rejected(self): + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["+10000000001"]) + params = _dm_envelope(source_number="+19995550002") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_dm_disabled_rejected(self): + ch = _make_channel(dm_enabled=False) + handled: list[dict] = [] + + async def capture(**kwargs): + handled.append(kwargs) + + ch._handle_message = capture # type: ignore[method-assign] + + async def noop_typing(chat_id): + pass + + ch._start_typing = noop_typing # type: ignore[method-assign] + params = _dm_envelope(source_number="+19995550001") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_reaction_message_ignored(self): + ch, handled = self._make_dm_channel() + params = _dm_envelope(reaction={"emoji": "👍", "targetTimestamp": 999}) + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_empty_message_ignored(self): + ch, handled = self._make_dm_channel() + params = _dm_envelope(message="") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_receipt_message_ignored(self): + ch, handled = self._make_dm_channel() + notification = { + "envelope": { + "sourceNumber": "+19995550001", + "receiptMessage": {"when": 1234}, + } + } + await ch._handle_receive_notification(notification) + assert handled == [] + + @pytest.mark.asyncio + async def test_typing_indicator_ignored(self): + ch, handled = self._make_dm_channel() + notification = { + "envelope": { + "sourceNumber": "+19995550001", + "typingMessage": {"action": "STARTED"}, + } + } + await ch._handle_receive_notification(notification) + assert handled == [] + + @pytest.mark.asyncio + async def test_missing_envelope_ignored(self): + ch, handled = self._make_dm_channel() + await ch._handle_receive_notification({}) + assert handled == [] + + @pytest.mark.asyncio + async def test_metadata_passed_to_handle(self): + ch, handled = self._make_dm_channel() + params = _dm_envelope(source_number="+19995550001", source_name="Alice", timestamp=9999) + await ch._handle_receive_notification(params) + meta = handled[0]["metadata"] + assert meta["sender_name"] == "Alice" + assert meta["timestamp"] == 9999 + assert meta["is_group"] is False + + @pytest.mark.asyncio + async def test_sender_id_with_uuid_variant(self): + ch, handled = self._make_dm_channel() + params = _dm_envelope(source_number="+19995550001", source_uuid="uuid-abc") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + # sender_id combines both parts + assert "+19995550001" in handled[0]["sender_id"] + assert "uuid-abc" in handled[0]["sender_id"] + + @pytest.mark.asyncio + async def test_stop_typing_called_on_handle_error(self): + ch = _make_channel(dm_enabled=True, dm_policy="open") + typing_stopped: list[str] = [] + + async def fail_handle(**kwargs): + raise RuntimeError("boom") + + async def noop_typing(chat_id): + pass + + async def record_stop(chat_id, **kwargs): + typing_stopped.append(chat_id) + + ch._handle_message = fail_handle # type: ignore[method-assign] + ch._start_typing = noop_typing # type: ignore[method-assign] + ch._stop_typing = record_stop # type: ignore[method-assign] + + # _handle_receive_notification swallows exceptions; the typing stop + # still fires from _handle_data_message's except clause. + params = _dm_envelope(source_number="+19995550001") + await ch._handle_receive_notification(params) + + assert "+19995550001" in typing_stopped + + +# --------------------------------------------------------------------------- +# _handle_data_message — group routing +# --------------------------------------------------------------------------- + + +class TestHandleDataMessageGroup: + def _make_group_channel( + self, + policy="open", + allow_from=None, + require_mention=True, + ) -> tuple[SignalChannel, list]: + ch = _make_channel( + group_enabled=True, + group_policy=policy, + group_allow_from=allow_from or [], + require_mention=require_mention, + ) + handled: list[dict] = [] + + async def capture(**kwargs): + handled.append(kwargs) + + ch._handle_message = capture # type: ignore[method-assign] + + async def noop_typing(chat_id): + pass + + ch._start_typing = noop_typing # type: ignore[method-assign] + return ch, handled + + @pytest.mark.asyncio + async def test_group_disabled_rejected(self): + ch = _make_channel(group_enabled=False) + handled: list[dict] = [] + ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign] + params = _group_envelope(group_id="grp==", message="hi") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_group_open_policy_no_mention_blocked_when_required(self): + ch, handled = self._make_group_channel(require_mention=True) + params = _group_envelope(group_id="grp==", message="hey everyone") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_group_open_policy_no_mention_required(self): + ch, handled = self._make_group_channel(require_mention=False) + params = _group_envelope(group_id="grp==", message="hey everyone") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + assert handled[0]["chat_id"] == "grp==" + + @pytest.mark.asyncio + async def test_group_allowlist_accepted(self): + ch, handled = self._make_group_channel( + policy="allowlist", allow_from=["grp=="], require_mention=False + ) + params = _group_envelope(group_id="grp==", message="hi") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_group_allowlist_rejected(self): + ch, handled = self._make_group_channel(policy="allowlist", allow_from=["other=="]) + params = _group_envelope(group_id="grp==", message="hi") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_group_mention_triggers_response(self): + ch, handled = self._make_group_channel(require_mention=True) + ch._remember_account_id_alias("+10000000000") + mentions = [{"number": "+10000000000", "start": 0, "length": 1}] + params = _group_envelope(group_id="grp==", message=" hello", mentions=mentions) + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_group_v2_id_extracted(self): + ch, handled = self._make_group_channel(require_mention=False) + params = _group_envelope(group_id="grpV2==", message="hi", use_v2=True) + await ch._handle_receive_notification(params) + assert len(handled) == 1 + assert handled[0]["chat_id"] == "grpV2==" + + @pytest.mark.asyncio + async def test_group_message_includes_sender_prefix(self): + ch, handled = self._make_group_channel(require_mention=False) + params = _group_envelope(group_id="grp==", source_name="Bob", message="hello") + await ch._handle_receive_notification(params) + assert "[Bob]:" in handled[0]["content"] + + @pytest.mark.asyncio + async def test_group_message_context_prepended(self): + ch, handled = self._make_group_channel(require_mention=False) + # First message — adds to buffer but no context yet + params1 = _group_envelope(group_id="grp==", source_name="Alice", message="msg1") + await ch._handle_receive_notification(params1) + # Second message — should include context from first + params2 = _group_envelope(group_id="grp==", source_name="Bob", message="msg2") + await ch._handle_receive_notification(params2) + assert "[Recent group messages for context:]" in handled[1]["content"] + assert "msg1" in handled[1]["content"] + + @pytest.mark.asyncio + async def test_group_metadata_marks_is_group(self): + ch, handled = self._make_group_channel(require_mention=False) + params = _group_envelope(group_id="grp==", message="hi") + await ch._handle_receive_notification(params) + assert handled[0]["metadata"]["is_group"] is True + assert handled[0]["metadata"]["group_id"] == "grp==" + + @pytest.mark.asyncio + async def test_bot_account_alias_learned_from_incoming(self): + ch, handled = self._make_group_channel(require_mention=False) + # If the bot's own UUID appears in an envelope we learn it + params = _dm_envelope(source_number="+10000000000", source_uuid="new-bot-uuid") + # DMs from self are processed (learning alias), but DM policy is open + ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign] + ch._start_typing = lambda chat_id: None # type: ignore[method-assign] + await ch._handle_receive_notification(params) + assert ch._id_matches_account("new-bot-uuid") + + +# --------------------------------------------------------------------------- +# Command handling +# --------------------------------------------------------------------------- + + +class TestCommandHandling: + @pytest.mark.asyncio + async def test_dm_command_forwarded_to_bus(self): + """Slash commands in DMs are forwarded to the bus for AgentLoop to handle.""" + ch = _make_channel(dm_enabled=True, dm_policy="open") + forwarded: list[dict] = [] + + async def capture(**kw): + forwarded.append(kw) + + ch._handle_message = capture # type: ignore[method-assign] + ch._start_typing = AsyncMock() + + params = _dm_envelope(source_number="+19995550001", message="/reset") + await ch._handle_receive_notification(params) + + assert len(forwarded) == 1 + assert forwarded[0]["content"].strip() == "/reset" + + @pytest.mark.asyncio + async def test_group_command_bypasses_mention_requirement(self): + """Slash commands in groups bypass the mention requirement and reach the bus.""" + ch = _make_channel( + group_enabled=True, group_policy="open", require_mention=True + ) + forwarded: list[dict] = [] + + async def capture(**kw): + forwarded.append(kw) + + ch._handle_message = capture # type: ignore[method-assign] + ch._start_typing = AsyncMock() + + params = _group_envelope(source_number="+19995550001", group_id="grp==", message="/reset") + await ch._handle_receive_notification(params) + + assert len(forwarded) == 1 + assert "/reset" in forwarded[0]["content"] + + @pytest.mark.asyncio + async def test_command_denied_for_disallowed_dm_sender(self): + """Commands from senders not on the DM allowlist are dropped.""" + ch = _make_channel(dm_enabled=False) + forwarded: list[dict] = [] + + async def capture(**kw): + forwarded.append(kw) + + ch._handle_message = capture # type: ignore[method-assign] + + params = _dm_envelope(source_number="+19995550001", message="/reset") + await ch._handle_receive_notification(params) + + assert forwarded == [] + + +# --------------------------------------------------------------------------- +# send() — outbound messages +# --------------------------------------------------------------------------- + + +class TestSend: + def _make_send_channel(self) -> tuple[SignalChannel, _FakeHTTPClient]: + ch = _make_channel() + client = _FakeHTTPClient() + ch._http = client # type: ignore[assignment] + return ch, client + + @pytest.mark.asyncio + async def test_send_plain_text_posts_rpc(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="hello") + await ch.send(msg) + assert len(client.posts) == 1 + payload = client.posts[0]["json"] + assert payload["method"] == "send" + assert payload["params"]["message"] == "hello" + + @pytest.mark.asyncio + async def test_send_with_markdown_includes_text_styles(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="**bold**") + await ch.send(msg) + params = client.posts[0]["json"]["params"] + assert "textStyle" in params + assert any("BOLD" in s for s in params["textStyle"]) + + @pytest.mark.asyncio + async def test_send_empty_content_skips_rpc(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="") + await ch.send(msg) + assert client.posts == [] + + @pytest.mark.asyncio + async def test_send_to_group_uses_group_id(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="grp==", content="hi group") + await ch.send(msg) + params = client.posts[0]["json"]["params"] + assert "groupId" in params + assert "recipient" not in params + + @pytest.mark.asyncio + async def test_send_to_dm_uses_recipient(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="hi") + await ch.send(msg) + params = client.posts[0]["json"]["params"] + assert "recipient" in params + + @pytest.mark.asyncio + async def test_send_with_media_includes_attachments(self): + ch, client = self._make_send_channel() + msg = OutboundMessage( + channel="signal", + chat_id="+19995550001", + content="see attachment", + media=["/tmp/file.jpg"], + ) + await ch.send(msg) + params = client.posts[0]["json"]["params"] + assert params.get("attachments") == ["/tmp/file.jpg"] + + @pytest.mark.asyncio + async def test_send_progress_message_does_not_stop_typing(self): + ch, client = self._make_send_channel() + stopped: list[str] = [] + + async def record_stop(chat_id, **kwargs): + stopped.append(chat_id) + + ch._stop_typing = record_stop # type: ignore[method-assign] + msg = OutboundMessage( + channel="signal", + chat_id="+19995550001", + content="working...", + metadata={"_progress": True}, + ) + await ch.send(msg) + # Progress messages should NOT stop the typing indicator + assert stopped == [] + + @pytest.mark.asyncio + async def test_send_final_message_stops_typing(self): + ch, client = self._make_send_channel() + stopped: list[str] = [] + + async def record_stop(chat_id, send_stop=True): + stopped.append(chat_id) + + ch._stop_typing = record_stop # type: ignore[method-assign] + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="done") + await ch.send(msg) + assert "+19995550001" in stopped + + @pytest.mark.asyncio + async def test_send_logs_daemon_error_without_raising(self): + ch = _make_channel() + # The daemon returns {"error": {...}} in the JSON body — this is not a Python + # exception; send() logs it but does not raise (only HTTP-level exceptions raise). + ch._http = _FakeHTTPClient(default_response={"error": {"message": "fail"}}) + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="hello") + await ch.send(msg) # must not raise + + +# --------------------------------------------------------------------------- +# stop() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_stop_cancels_sse_task() -> None: + ch = _make_channel() + cancelled = False + + async def long_running(): + nonlocal cancelled + try: + await asyncio.sleep(9999) + except asyncio.CancelledError: + cancelled = True + raise + + ch._sse_task = asyncio.create_task(long_running()) + # Yield so the task can enter its body (reach the first await) before cancel. + await asyncio.sleep(0) + ch._running = True + + await ch.stop() + + assert cancelled + assert ch._running is False + + +@pytest.mark.asyncio +async def test_stop_closes_http_client() -> None: + ch = _make_channel() + client = _FakeHTTPClient() + ch._http = client # type: ignore[assignment] + ch._running = True + + await ch.stop() + + assert client.closed + + +@pytest.mark.asyncio +async def test_stop_safe_when_no_sse_task() -> None: + ch = _make_channel() + ch._running = True + # Should not raise even with no _sse_task + await ch.stop() + assert ch._running is False + + +# --------------------------------------------------------------------------- +# _send_request / _send_http_request +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_send_request_increments_id() -> None: + ch = _make_channel() + client = _FakeHTTPClient() + ch._http = client # type: ignore[assignment] + + await ch._send_request("testMethod", {"key": "val"}) + await ch._send_request("testMethod", {"key": "val"}) + + ids = [p["json"]["id"] for p in client.posts] + assert ids == [1, 2] + + +@pytest.mark.asyncio +async def test_send_request_raises_when_not_connected() -> None: + ch = _make_channel() + # _http is None by default + with pytest.raises(RuntimeError, match="Not connected"): + await ch._send_request("testMethod") + + +# --------------------------------------------------------------------------- +# _handle_receive_notification — envelope shapes +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_handle_notification_sync_message_does_not_forward() -> None: + ch = _make_channel(dm_enabled=True, dm_policy="open") + handled: list[dict] = [] + ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign] + + notification = { + "envelope": { + "sourceNumber": "+19995550001", + "syncMessage": { + "sentMessage": { + "destination": "+19990000000", + "message": "sent from other device", + } + }, + } + } + await ch._handle_receive_notification(notification) + assert handled == [] + + +@pytest.mark.asyncio +async def test_handle_notification_no_source_skipped() -> None: + ch = _make_channel(dm_enabled=True, dm_policy="open") + handled: list[dict] = [] + ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign] + + notification = {"envelope": {"dataMessage": {"message": "ghost"}}} + await ch._handle_receive_notification(notification) + assert handled == [] + + +# --------------------------------------------------------------------------- +# Config: allow_from property aggregation +# --------------------------------------------------------------------------- + + +def test_config_allow_from_aggregates_dm_and_group() -> None: + config = SignalConfig( + enabled=True, + phone_number="+10000000000", + dm=SignalDMConfig(enabled=True, policy="allowlist", allow_from=["+1111", "+2222"]), + group=SignalGroupConfig( + enabled=True, policy="allowlist", allow_from=["+3333", "+1111"] + ), + ) + combined = config.allow_from + assert "+1111" in combined + assert "+2222" in combined + assert "+3333" in combined + # Duplicates removed + assert combined.count("+1111") == 1 + + +def test_config_allow_from_wildcard_propagates() -> None: + config = SignalConfig( + enabled=True, + phone_number="+10000000000", + dm=SignalDMConfig(enabled=True, policy="open", allow_from=["*"]), + group=SignalGroupConfig(enabled=True, policy="open", allow_from=[]), + ) + assert "*" in config.allow_from diff --git a/tests/channels/test_signal_markdown.py b/tests/channels/test_signal_markdown.py new file mode 100644 index 000000000..15eca70ff --- /dev/null +++ b/tests/channels/test_signal_markdown.py @@ -0,0 +1,244 @@ +"""Unit tests for the Signal markdown → plain text + textStyle converter.""" + +from nanobot.channels.signal import _markdown_to_signal + + +def styles_for(plain: str, text_styles: list[str]) -> dict[str, list[str]]: + """Return a dict mapping each styled substring to its style list.""" + result: dict[str, list[str]] = {} + for entry in text_styles: + start_s, length_s, style = entry.split(":", 2) + start, length = int(start_s), int(length_s) + span = plain[start : start + length] + result.setdefault(span, []).append(style) + return result + + +# --------------------------------------------------------------------------- +# Basic cases +# --------------------------------------------------------------------------- + + +def test_empty(): + plain, styles = _markdown_to_signal("") + assert plain == "" + assert styles == [] + + +def test_plain_text(): + plain, styles = _markdown_to_signal("hello world") + assert plain == "hello world" + assert styles == [] + + +def test_bold_stars(): + plain, styles = _markdown_to_signal("say **hello** now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["BOLD"]} + + +def test_bold_underscores(): + plain, styles = _markdown_to_signal("say __hello__ now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["BOLD"]} + + +def test_italic_star(): + plain, styles = _markdown_to_signal("say *hello* now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["ITALIC"]} + + +def test_italic_underscore(): + plain, styles = _markdown_to_signal("say _hello_ now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["ITALIC"]} + + +def test_strikethrough(): + plain, styles = _markdown_to_signal("say ~~hello~~ now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["STRIKETHROUGH"]} + + +# --------------------------------------------------------------------------- +# Code +# --------------------------------------------------------------------------- + + +def test_inline_code(): + plain, styles = _markdown_to_signal("run `ls -la` here") + assert plain == "run ls -la here" + assert styles_for(plain, styles) == {"ls -la": ["MONOSPACE"]} + + +def test_code_block(): + plain, styles = _markdown_to_signal("```\nprint('hi')\n```") + assert "print('hi')" in plain + assert styles_for(plain, styles).get("print('hi')\n") == ["MONOSPACE"] or \ + "MONOSPACE" in str(styles_for(plain, styles)) + + +def test_code_block_with_lang(): + plain, styles = _markdown_to_signal("```python\ncode\n```") + assert "code" in plain + assert any("MONOSPACE" in s for s in styles) + + +def test_code_block_not_processed_further(): + """Markdown inside a code block must not be styled.""" + plain, styles = _markdown_to_signal("```\n**not bold**\n```") + assert "**not bold**" in plain + # Only MONOSPACE should be applied, no BOLD + for entry in styles: + assert "BOLD" not in entry + + +def test_inline_code_not_processed_further(): + """Markdown inside inline code must not be styled.""" + plain, styles = _markdown_to_signal("use `**raw**` please") + assert "**raw**" in plain + for entry in styles: + assert "BOLD" not in entry + + +# --------------------------------------------------------------------------- +# Headers +# --------------------------------------------------------------------------- + + +def test_header_becomes_bold(): + plain, styles = _markdown_to_signal("# My Title") + assert plain == "My Title" + assert styles_for(plain, styles) == {"My Title": ["BOLD"]} + + +def test_h2_becomes_bold(): + plain, styles = _markdown_to_signal("## Sub-section") + assert plain == "Sub-section" + assert styles_for(plain, styles) == {"Sub-section": ["BOLD"]} + + +# --------------------------------------------------------------------------- +# Blockquotes +# --------------------------------------------------------------------------- + + +def test_blockquote_strips_marker(): + plain, styles = _markdown_to_signal("> some quote") + assert plain == "some quote" + assert styles == [] + + +# --------------------------------------------------------------------------- +# Lists +# --------------------------------------------------------------------------- + + +def test_bullet_dash(): + plain, styles = _markdown_to_signal("- item one") + assert plain == "• item one" + + +def test_bullet_star(): + plain, styles = _markdown_to_signal("* item two") + assert plain == "• item two" + + +def test_numbered_list(): + plain, styles = _markdown_to_signal("1. first\n2. second") + assert "1. first" in plain + assert "2. second" in plain + + +# --------------------------------------------------------------------------- +# Links +# --------------------------------------------------------------------------- + + +def test_link_text_differs_from_url(): + plain, styles = _markdown_to_signal("[Click here](https://example.com)") + assert plain == "Click here (https://example.com)" + assert styles == [] + + +def test_link_text_equals_url(): + plain, styles = _markdown_to_signal("[https://example.com](https://example.com)") + assert plain == "https://example.com" + assert styles == [] + + +def test_link_text_equals_url_without_scheme(): + plain, styles = _markdown_to_signal("[example.com](https://example.com)") + assert plain == "https://example.com" + + +# --------------------------------------------------------------------------- +# Mixed / nesting +# --------------------------------------------------------------------------- + + +def test_bold_and_italic_adjacent(): + plain, styles = _markdown_to_signal("**bold** and *italic*") + assert plain == "bold and italic" + sd = styles_for(plain, styles) + assert sd.get("bold") == ["BOLD"] + assert sd.get("italic") == ["ITALIC"] + + +def test_header_with_inline_code(): + """Header becomes BOLD; code inside becomes MONOSPACE (not double-BOLD).""" + plain, styles = _markdown_to_signal("# Use `grep`") + assert plain == "Use grep" + sd = styles_for(plain, styles) + assert "BOLD" in sd.get("Use ", []) or "BOLD" in str(styles) + assert "MONOSPACE" in sd.get("grep", []) + + +def test_multiline_mixed(): + md = "**Title**\n\nSome *italic* text.\n\n- bullet\n- another" + plain, styles = _markdown_to_signal(md) + assert "Title" in plain + assert "italic" in plain + assert "• bullet" in plain + sd = styles_for(plain, styles) + assert "BOLD" in sd.get("Title", []) + assert "ITALIC" in sd.get("italic", []) + + +# --------------------------------------------------------------------------- +# Table rendering +# --------------------------------------------------------------------------- + + +def test_table_rendered_as_monospace(): + md = "| A | B |\n| - | - |\n| 1 | 2 |" + plain, styles = _markdown_to_signal(md) + assert "A" in plain and "B" in plain + assert any("MONOSPACE" in s for s in styles) + + +# --------------------------------------------------------------------------- +# Style range format +# --------------------------------------------------------------------------- + + +def test_style_range_format(): + """Each style entry must be 'start:length:STYLE'.""" + _, styles = _markdown_to_signal("**bold** text") + for entry in styles: + parts = entry.split(":") + assert len(parts) == 3 + assert parts[0].isdigit() + assert parts[1].isdigit() + assert parts[2] in {"BOLD", "ITALIC", "STRIKETHROUGH", "MONOSPACE", "SPOILER"} + + +def test_style_ranges_are_within_bounds(): + text = "hello **world** end" + plain, styles = _markdown_to_signal(text) + for entry in styles: + start_s, length_s, _ = entry.split(":", 2) + start, length = int(start_s), int(length_s) + assert start >= 0 + assert start + length <= len(plain) From 1a6fe093e78c8406a7de1f5399640c9212cd9e93 Mon Sep 17 00:00:00 2001 From: Kaloyan Tenchov Date: Sat, 16 May 2026 11:01:10 -0400 Subject: [PATCH 06/54] fix(signal): drop duplicate self in unconfigured-account log call Addresses review feedback on HKUDS/nanobot#3852: self.self.logger.error would crash if the phone_number guard ever fired. --- nanobot/channels/signal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nanobot/channels/signal.py b/nanobot/channels/signal.py index 3e35ae676..45c520291 100644 --- a/nanobot/channels/signal.py +++ b/nanobot/channels/signal.py @@ -276,7 +276,7 @@ class SignalChannel(BaseChannel): async def start(self) -> None: """Start the Signal channel and connect to signal-cli daemon.""" if not self.config.phone_number: - self.self.logger.error("Signal account not configured") + self.logger.error("Signal account not configured") return self._running = True From 8f6b7611a2560da5e5c0aab4c053c4040e897f87 Mon Sep 17 00:00:00 2001 From: Kaloyan Tenchov Date: Sat, 16 May 2026 10:51:57 -0400 Subject: [PATCH 07/54] fix(signal): emit textStyle offsets in UTF-16 code units Signal's BodyRange (via signal-cli's textStyle) interprets start/length as UTF-16 code units, but the Phase-3 assembly used Python's len(), which counts code points. A single non-BMP character (e.g. an emoji) earlier in a message shifted every subsequent styled span left by one unit, dropping the last letter of bold/italic words. Track a running UTF-16 offset in the assembly loop and add regression tests covering emojis, supplementary CJK, ZWJ sequences, and a multi-section message that mirrors the reported failure. Co-Authored-By: Claude Opus 4.7 --- nanobot/channels/signal.py | 16 +++- tests/channels/test_signal_markdown.py | 109 +++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 3 deletions(-) diff --git a/nanobot/channels/signal.py b/nanobot/channels/signal.py index 45c520291..84bec425d 100644 --- a/nanobot/channels/signal.py +++ b/nanobot/channels/signal.py @@ -43,6 +43,11 @@ _SIG_STRIKE_RE = re.compile(r'~~(.+?)~~|(? int: + """UTF-16 code-unit length, matching Signal BodyRange semantics.""" + return len(s.encode("utf-16-le")) // 2 + + def _sig_strip_cell(s: str) -> str: """Strip inline markdown from a table cell for plain-text rendering.""" s = re.sub(r'\*\*(.+?)\*\*', r'\1', s) @@ -183,15 +188,20 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]: # Strikethrough: ~~text~~ (standard) or ~text~ (single-tilde variant). transform(_SIG_STRIKE_RE, lambda m, s: [_Run(m.group(1) or m.group(2), s | {"STRIKETHROUGH"})]) - # Phase 3: assemble output. + # Phase 3: assemble output. Offsets and lengths are emitted in UTF-16 code + # units because Signal's BodyRange (via signal-cli's textStyle) interprets + # them as such; Python's len() counts code points, which would shift ranges + # left by 1 unit per non-BMP character preceding them. plain_text = "" text_styles: list[str] = [] + utf16_offset = 0 for run in runs: if not run.text: continue - start = len(plain_text) plain_text += run.text - length = len(plain_text) - start + start = utf16_offset + length = _utf16_len(run.text) + utf16_offset += length for style in sorted(run.styles): text_styles.append(f"{start}:{length}:{style}") diff --git a/tests/channels/test_signal_markdown.py b/tests/channels/test_signal_markdown.py index 15eca70ff..36b75f163 100644 --- a/tests/channels/test_signal_markdown.py +++ b/tests/channels/test_signal_markdown.py @@ -3,6 +3,10 @@ from nanobot.channels.signal import _markdown_to_signal +def _utf16_len(s: str) -> int: + return len(s.encode("utf-16-le")) // 2 + + def styles_for(plain: str, text_styles: list[str]) -> dict[str, list[str]]: """Return a dict mapping each styled substring to its style list.""" result: dict[str, list[str]] = {} @@ -14,6 +18,18 @@ def styles_for(plain: str, text_styles: list[str]) -> dict[str, list[str]]: return result +def utf16_styles_for(plain: str, text_styles: list[str]) -> dict[str, list[str]]: + """Like styles_for, but slices `plain` using UTF-16 offsets (Signal's units).""" + encoded = plain.encode("utf-16-le") + result: dict[str, list[str]] = {} + for entry in text_styles: + start_s, length_s, style = entry.split(":", 2) + start, length = int(start_s), int(length_s) + span = encoded[start * 2 : (start + length) * 2].decode("utf-16-le") + result.setdefault(span, []).append(style) + return result + + # --------------------------------------------------------------------------- # Basic cases # --------------------------------------------------------------------------- @@ -242,3 +258,96 @@ def test_style_ranges_are_within_bounds(): start, length = int(start_s), int(length_s) assert start >= 0 assert start + length <= len(plain) + + +# --------------------------------------------------------------------------- +# Non-BMP / UTF-16 offsets +# +# Signal's BodyRange (and signal-cli's textStyle) interprets start/length in +# UTF-16 code units. Python's len() counts code points, so characters outside +# the BMP (emojis, supplementary CJK) shift offsets by +1 per occurrence. +# --------------------------------------------------------------------------- + + +def assert_within_utf16_bounds(plain: str, styles: list[str]) -> None: + limit = _utf16_len(plain) + for entry in styles: + start_s, length_s, _ = entry.split(":", 2) + start, length = int(start_s), int(length_s) + assert start >= 0 + assert start + length <= limit, ( + f"range {entry} exceeds utf-16 length {limit} of {plain!r}" + ) + + +def test_bold_with_emoji_inside(): + plain, styles = _markdown_to_signal("**hi 🎉 bye**") + assert plain == "hi 🎉 bye" + assert utf16_styles_for(plain, styles) == {"hi 🎉 bye": ["BOLD"]} + assert_within_utf16_bounds(plain, styles) + + +def test_italic_with_trailing_emoji(): + plain, styles = _markdown_to_signal("*bye 🎉*") + assert plain == "bye 🎉" + assert utf16_styles_for(plain, styles) == {"bye 🎉": ["ITALIC"]} + assert_within_utf16_bounds(plain, styles) + + +def test_bold_after_emoji_prefix(): + plain, styles = _markdown_to_signal("🎉 **bold**") + assert plain == "🎉 bold" + assert utf16_styles_for(plain, styles) == {"bold": ["BOLD"]} + assert_within_utf16_bounds(plain, styles) + + +def test_bold_after_and_inside_emoji(): + plain, styles = _markdown_to_signal("🎉 **a 🎊 b**") + assert plain == "🎉 a 🎊 b" + assert utf16_styles_for(plain, styles) == {"a 🎊 b": ["BOLD"]} + assert_within_utf16_bounds(plain, styles) + + +def test_supplementary_cjk_in_bold(): + """Non-BMP CJK (U+20BB7) proves the bug is UTF-16, not emoji-specific.""" + plain, styles = _markdown_to_signal("**𠮷野家**") + assert plain == "𠮷野家" + assert utf16_styles_for(plain, styles) == {"𠮷野家": ["BOLD"]} + assert_within_utf16_bounds(plain, styles) + + +def test_zwj_emoji_in_bold(): + """ZWJ family sequence = multiple surrogate pairs + BMP ZWJs.""" + plain, styles = _markdown_to_signal("**hi 👨‍👩‍👧 bye**") + assert plain == "hi 👨‍👩‍👧 bye" + assert utf16_styles_for(plain, styles) == {"hi 👨‍👩‍👧 bye": ["BOLD"]} + assert_within_utf16_bounds(plain, styles) + + +def test_ascii_offsets_unchanged(): + """ASCII-only path must produce the same offsets as before the UTF-16 fix.""" + plain, styles = _markdown_to_signal("**bold** plain *it*") + assert plain == "bold plain it" + assert sorted(styles) == sorted(["0:4:BOLD", "11:2:ITALIC"]) + + +def test_reported_daily_brief_pattern(): + """Regression for the reported bug: a single non-BMP emoji shifts every + subsequent styled span left by 1 UTF-16 unit, lopping off the last letter. + """ + md = ( + "**Weather**\n" + "- Conditions: 🌩️ Thunderstorms\n\n" + "**News**\n" + "*World*\n" + "*Local*\n\n" + "**Quote of the Day**" + ) + plain, styles = _markdown_to_signal(md) + sd = utf16_styles_for(plain, styles) + assert sd.get("Weather") == ["BOLD"] + assert sd.get("News") == ["BOLD"] + assert sd.get("World") == ["ITALIC"] + assert sd.get("Local") == ["ITALIC"] + assert sd.get("Quote of the Day") == ["BOLD"] + assert_within_utf16_bounds(plain, styles) From 96eb3b71947af28052926e1299934601ced71909 Mon Sep 17 00:00:00 2001 From: Kaloyan Tenchov Date: Sat, 16 May 2026 10:56:42 -0400 Subject: [PATCH 08/54] fix(signal): redistribute textStyle ranges across split message chunks split_message can break a long Signal payload into multiple JSON-RPC sends, but the previous code attached the full textStyle list only to chunk 0. Style ranges in later chunks were dropped, and ranges whose offsets pointed past chunk 0's end were sent as invalid metadata against chunk 0. Add _partition_styles, which rebases each range against the chunk it lives in (in UTF-16 code units, matching the markdown converter) and splits boundary-spanning ranges across the chunks they touch. Whitespace trimmed by split_message's lstrip is skipped so offsets stay aligned. Co-Authored-By: Claude Opus 4.7 --- nanobot/channels/signal.py | 53 +++++++++++- tests/channels/test_signal_channel.py | 32 +++++++ tests/channels/test_signal_markdown.py | 110 ++++++++++++++++++++++++- 3 files changed, 192 insertions(+), 3 deletions(-) diff --git a/nanobot/channels/signal.py b/nanobot/channels/signal.py index 84bec425d..83999877c 100644 --- a/nanobot/channels/signal.py +++ b/nanobot/channels/signal.py @@ -208,6 +208,54 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]: return plain_text, text_styles +def _partition_styles( + plain_text: str, chunks: list[str], text_styles: list[str] +) -> list[list[str]]: + """Partition Signal textStyle ranges across message chunks. + + ``split_message`` slices ``plain_text`` into pieces (optionally trimming + whitespace at the boundaries), but the style ranges produced by + ``_markdown_to_signal`` are expressed in UTF-16 offsets relative to the + full ``plain_text``. This redistributes them per chunk with offsets + rebased to each chunk's start. Ranges that span a boundary are split + across the chunks they touch; ranges that fall entirely in trimmed + whitespace are dropped. + """ + if not chunks: + return [] + if not text_styles: + return [[] for _ in chunks] + + # Locate each chunk's UTF-16 start in plain_text. split_message lstrips at + # boundaries (but not before the first chunk), so we skip whitespace + # between chunks to mirror that. + chunk_ranges: list[tuple[int, int]] = [] + cursor = 0 # Python codepoint cursor in plain_text + for i, chunk in enumerate(chunks): + if i > 0: + while cursor < len(plain_text) and plain_text[cursor].isspace(): + cursor += 1 + utf16_start = _utf16_len(plain_text[:cursor]) + utf16_end = utf16_start + _utf16_len(chunk) + chunk_ranges.append((utf16_start, utf16_end)) + cursor += len(chunk) + + result: list[list[str]] = [[] for _ in chunks] + for entry in text_styles: + s, ln, style = entry.split(":", 2) + r_start = int(s) + r_end = r_start + int(ln) + for i, (c_start, c_end) in enumerate(chunk_ranges): + if r_end <= c_start or r_start >= c_end: + continue + new_start = max(r_start, c_start) - c_start + new_end = min(r_end, c_end) - c_start + new_length = new_end - new_start + if new_length > 0: + result[i].append(f"{new_start}:{new_length}:{style}") + return result + + class SignalDMConfig(Base): """Signal DM policy configuration.""" @@ -392,10 +440,11 @@ class SignalChannel(BaseChannel): recipient_params = self._recipient_params(msg.chat_id) chunks = split_message(plain_text, self._MAX_MESSAGE_LEN) if plain_text else [""] + chunk_styles = _partition_styles(plain_text, chunks, text_styles) for i, chunk in enumerate(chunks): params: dict[str, Any] = {"message": chunk} - if text_styles and i == 0: - params["textStyle"] = text_styles + if chunk_styles[i]: + params["textStyle"] = chunk_styles[i] params.update(recipient_params) if msg.media and i == 0: params["attachments"] = msg.media diff --git a/tests/channels/test_signal_channel.py b/tests/channels/test_signal_channel.py index b5149459b..ecdeda334 100644 --- a/tests/channels/test_signal_channel.py +++ b/tests/channels/test_signal_channel.py @@ -834,6 +834,38 @@ class TestSend: assert "textStyle" in params assert any("BOLD" in s for s in params["textStyle"]) + @pytest.mark.asyncio + async def test_send_split_message_redistributes_text_styles(self): + """Long message split across chunks: each chunk gets its own textStyle + with offsets rebased to that chunk.""" + ch, client = self._make_send_channel() + ch._MAX_MESSAGE_LEN = 12 # type: ignore[attr-defined] + msg = OutboundMessage( + channel="signal", + chat_id="+19995550001", + content="**head** middle and **tail**", + ) + await ch.send(msg) + assert len(client.posts) >= 2 + # Chunk 0 has BOLD for "head"; chunk 1+ must also carry BOLD for "tail". + bold_chunks = [ + p["json"]["params"] + for p in client.posts + if any("BOLD" in s for s in p["json"]["params"].get("textStyle", [])) + ] + assert len(bold_chunks) >= 2, ( + "expected BOLD ranges in more than one chunk; got " + f"{[p['json']['params'] for p in client.posts]}" + ) + # Each emitted range must point inside its own chunk's text. + for params in bold_chunks: + chunk_text = params["message"] + for entry in params["textStyle"]: + s, ln, _ = entry.split(":", 2) + start, length = int(s), int(ln) + end_units = start + length + assert end_units <= len(chunk_text.encode("utf-16-le")) // 2 + @pytest.mark.asyncio async def test_send_empty_content_skips_rpc(self): ch, client = self._make_send_channel() diff --git a/tests/channels/test_signal_markdown.py b/tests/channels/test_signal_markdown.py index 36b75f163..55b081095 100644 --- a/tests/channels/test_signal_markdown.py +++ b/tests/channels/test_signal_markdown.py @@ -1,6 +1,9 @@ """Unit tests for the Signal markdown → plain text + textStyle converter.""" -from nanobot.channels.signal import _markdown_to_signal +import pytest + +from nanobot.channels.signal import _markdown_to_signal, _partition_styles +from nanobot.utils.helpers import split_message def _utf16_len(s: str) -> int: @@ -351,3 +354,108 @@ def test_reported_daily_brief_pattern(): assert sd.get("Local") == ["ITALIC"] assert sd.get("Quote of the Day") == ["BOLD"] assert_within_utf16_bounds(plain, styles) + + +# --------------------------------------------------------------------------- +# Chunk redistribution +# +# split_message can break a long Signal payload into multiple chunks. The +# style ranges from _markdown_to_signal are anchored to the full text, so +# they must be redistributed per-chunk with rebased offsets — otherwise +# styles for chunks 1..N are silently lost. +# --------------------------------------------------------------------------- + + +def _resolve_chunk_styles(text: str, max_len: int) -> tuple[list[str], list[list[str]]]: + """Helper: full markdown → signal pipeline, including chunking.""" + plain, styles = _markdown_to_signal(text) + chunks = split_message(plain, max_len) if plain else [""] + return chunks, _partition_styles(plain, chunks, styles) + + +def test_partition_styles_single_chunk_passthrough(): + plain, styles = _markdown_to_signal("**bold** plain *it*") + parts = _partition_styles(plain, [plain], styles) + assert parts == [styles] + + +def test_partition_styles_no_styles(): + plain = "hello world" + assert _partition_styles(plain, [plain], []) == [[]] + assert _partition_styles(plain, ["hello", "world"], []) == [[], []] + + +def test_partition_styles_drops_styles_outside_chunks(): + """Whitespace trimmed by split_message must not carry a style range.""" + plain = "a b" + # Fake a style spanning the trimmed whitespace only. + chunks = ["a", "b"] + parts = _partition_styles(plain, chunks, ["1:3:BOLD"]) + assert parts == [[], []] + + +def test_partition_styles_long_message_preserves_chunk_one_styles(): + """A bold span deep in the message must follow the message into chunk 1.""" + # Two ~30-char paragraphs separated by a blank line, then **tail**. + line_a = "alpha " * 5 # 30 chars, ends with space + line_b = "beta " * 5 + md = f"{line_a.strip()}\n\n{line_b.strip()}\n\n**tail**" + plain, styles = _markdown_to_signal(md) + # Force a split between the paragraphs. + max_len = len(line_a.strip()) + 2 # fits paragraph A + the "\n\n" + chunks = split_message(plain, max_len) + assert len(chunks) >= 2, "test setup must produce a split" + parts = _partition_styles(plain, chunks, styles) + # The bold "tail" should land in the last chunk, with chunk-relative offset. + final_chunk = chunks[-1] + final_styles = parts[-1] + assert any("BOLD" in s for s in final_styles) + for entry in final_styles: + s, ln, _ = entry.split(":", 2) + start, length = int(s), int(ln) + slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode("utf-16-le") + assert slice_ == "tail" + + +def test_partition_styles_chunk_zero_styles_unchanged(): + """Styles entirely in chunk 0 keep their original offsets.""" + md = "**head** middle and **tail**" + plain, styles = _markdown_to_signal(md) + # Split so chunk 0 contains "head" and part of the rest, chunk 1 contains "tail". + chunks = split_message(plain, 12) + assert len(chunks) >= 2 + parts = _partition_styles(plain, chunks, styles) + # "head" lives in chunk 0; assert its offset is unchanged (chunk 0 starts at 0). + head_entries = [s for s in parts[0] if "BOLD" in s] + assert any(s.startswith("0:4:") for s in head_entries) + + +def test_partition_styles_with_non_bmp_chunk_offset(): + """Chunk-start offsets must be expressed in UTF-16 code units.""" + # Emoji in chunk 0, bold in chunk 1. + md = "🎉 alpha beta gamma\n\n**tail**" + plain, styles = _markdown_to_signal(md) + chunks = split_message(plain, 18) + assert len(chunks) >= 2 + parts = _partition_styles(plain, chunks, styles) + final_styles = parts[-1] + assert any("BOLD" in s for s in final_styles) + final_chunk = chunks[-1] + for entry in final_styles: + s, ln, _ = entry.split(":", 2) + start, length = int(s), int(ln) + slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode("utf-16-le") + assert slice_ == "tail" + + +def test_partition_styles_range_spanning_chunks_is_split(): + """A style range that straddles a chunk boundary gets sliced into both chunks.""" + # Construct manually: plain = "abc def", style covers "abc def" (whole thing). + plain = "abc def" + chunks = split_message(plain, 4) # "abc" / "def" + assert chunks == ["abc", "def"] + parts = _partition_styles(plain, chunks, ["0:7:BOLD"]) + # Chunk 0 holds 0:3:BOLD, chunk 1 holds 0:3:BOLD (length=3 each, "def" only + # since the space was trimmed by lstrip). + assert parts[0] == ["0:3:BOLD"] + assert parts[1] == ["0:3:BOLD"] From ca72f6b6c94e3b3ece751e2ea43f2b25d84ff7ce Mon Sep 17 00:00:00 2001 From: Kaloyan Tenchov Date: Sat, 16 May 2026 11:23:12 -0400 Subject: [PATCH 09/54] refactor(signal): hygiene cleanups around constants, typing, and config - Hoist the cell-strip patterns to module level so they match the rest of the module's regex style and aren't reparsed on every call. - Type the markdown transform callback and the mention id walker so the inline Callable signature is no longer an untyped Any. - Add _HTTP_TIMEOUT_SECONDS alongside the other class-level tunables. - Reject group_message_buffer_size <= 0 in a Pydantic field_validator rather than silently disabling the buffer at write time. - Mark SignalConfig.allow_from as a computed_field so it shows up in model_dump() instead of being invisible to serialization. Co-Authored-By: Claude Opus 4.7 --- nanobot/channels/signal.py | 44 ++++++++++++++++++++------- tests/channels/test_signal_channel.py | 11 ++++--- 2 files changed, 40 insertions(+), 15 deletions(-) diff --git a/nanobot/channels/signal.py b/nanobot/channels/signal.py index 83999877c..50d36c3af 100644 --- a/nanobot/channels/signal.py +++ b/nanobot/channels/signal.py @@ -8,12 +8,13 @@ import re import shutil import unicodedata from collections import deque +from collections.abc import Callable from dataclasses import dataclass, field from pathlib import Path from typing import Any import httpx -from pydantic import Field +from pydantic import Field, computed_field, field_validator from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus @@ -42,6 +43,18 @@ _SIG_ITALIC_RE = re.compile(r'(? int: """UTF-16 code-unit length, matching Signal BodyRange semantics.""" @@ -50,10 +63,8 @@ def _utf16_len(s: str) -> int: def _sig_strip_cell(s: str) -> str: """Strip inline markdown from a table cell for plain-text rendering.""" - s = re.sub(r'\*\*(.+?)\*\*', r'\1', s) - s = re.sub(r'__(.+?)__', r'\1', s) - s = re.sub(r'~~(.+?)~~', r'\1', s) - s = re.sub(r'`([^`]+)`', r'\1', s) + for pattern, repl in _SIG_CELL_STRIP_PATTERNS: + s = pattern.sub(repl, s) return s.strip() @@ -132,7 +143,10 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]: # Phase 2 (run-based): process inline patterns. runs: list[_Run] = [_Run(text)] - def transform(pattern: re.Pattern, make_runs: Any) -> None: + def transform( + pattern: re.Pattern, + make_runs: Callable[[re.Match, frozenset[str]], list[_Run]], + ) -> None: new_runs: list[_Run] = [] for run in runs: if run.opaque: @@ -284,6 +298,14 @@ class SignalConfig(Base): dm: SignalDMConfig = Field(default_factory=SignalDMConfig) group: SignalGroupConfig = Field(default_factory=SignalGroupConfig) + @field_validator("group_message_buffer_size") + @classmethod + def _validate_buffer_size(cls, v: int) -> int: + if v <= 0: + raise ValueError("group_message_buffer_size must be > 0") + return v + + @computed_field # type: ignore[prop-decorator] @property def allow_from(self) -> list[str]: """Aggregate allowlist for the base-class is_allowed() check. @@ -309,6 +331,7 @@ class SignalChannel(BaseChannel): display_name = "Signal" _TYPING_REFRESH_SECONDS = 10.0 _MAX_MESSAGE_LEN = 64_000 # signal-cli practical limit (protocol max ~64 KB) + _HTTP_TIMEOUT_SECONDS = 60.0 @classmethod def default_config(cls) -> dict[str, Any]: @@ -351,7 +374,9 @@ class SignalChannel(BaseChannel): self.logger.info(f"Connecting to signal-cli daemon at {base_url}...") # Create HTTP client - self._http = httpx.AsyncClient(timeout=60.0, base_url=base_url) + self._http = httpx.AsyncClient( + timeout=self._HTTP_TIMEOUT_SECONDS, base_url=base_url + ) # Test connection try: @@ -777,9 +802,6 @@ class SignalChannel(BaseChannel): message_text: The message content timestamp: Message timestamp """ - if self.config.group_message_buffer_size <= 0: - return - # Create buffer for this group if it doesn't exist if group_id not in self._group_buffers: self._group_buffers[group_id] = deque(maxlen=self.config.group_message_buffer_size) @@ -906,7 +928,7 @@ class SignalChannel(BaseChannel): """Extract possible identifier fields from a mention payload.""" ids: list[str] = [] - def _walk(value: Any, depth: int = 0) -> None: + def _walk(value: dict[str, Any] | Any, depth: int = 0) -> None: if depth > 2: return if not isinstance(value, dict): diff --git a/tests/channels/test_signal_channel.py b/tests/channels/test_signal_channel.py index ecdeda334..27b8b1e91 100644 --- a/tests/channels/test_signal_channel.py +++ b/tests/channels/test_signal_channel.py @@ -448,10 +448,13 @@ class TestGroupBuffer: ch._add_to_group_buffer("g1", "Alice", "+1111", f"msg{i}", i) assert len(ch._group_buffers["g1"]) == 3 - def test_zero_buffer_size_does_not_add(self): - ch = _make_channel(group_buffer_size=0) - ch._add_to_group_buffer("g1", "Alice", "+1111", "msg", 1000) - assert "g1" not in ch._group_buffers + def test_zero_buffer_size_rejected_by_validator(self): + with pytest.raises(ValueError, match="group_message_buffer_size"): + _make_channel(group_buffer_size=0) + + def test_negative_buffer_size_rejected_by_validator(self): + with pytest.raises(ValueError, match="group_message_buffer_size"): + _make_channel(group_buffer_size=-1) def test_context_limits_message_length(self): ch = _make_channel(group_buffer_size=5) From 882d4139d70559ddcdb29129225e24f5151c529f Mon Sep 17 00:00:00 2001 From: Kaloyan Tenchov Date: Sat, 16 May 2026 11:27:12 -0400 Subject: [PATCH 10/54] fix(signal): normalize identifiers when matching DM allowlist The DM allowlist check split sender_id on '|' and looked for raw membership in the allow_from list. Senders carry their phone number with a leading '+' but admins routinely write allowlist entries without it (or vice versa), and UUID/ACI matches were case-sensitive. Both forms now flow through _normalize_signal_id, so an entry like 19995550001 matches a sender +19995550001 and a UUID matches case-insensitively. Co-Authored-By: Claude Opus 4.7 --- nanobot/channels/signal.py | 27 ++++++++++++++++++++---- tests/channels/test_signal_channel.py | 30 +++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 4 deletions(-) diff --git a/nanobot/channels/signal.py b/nanobot/channels/signal.py index 50d36c3af..3fd23c780 100644 --- a/nanobot/channels/signal.py +++ b/nanobot/channels/signal.py @@ -695,10 +695,7 @@ class SignalChannel(BaseChannel): self.logger.debug(f"Ignoring DM from {sender_id} (DMs disabled)") return if self.config.dm.policy == "allowlist": - allow_list = self.config.dm.allow_from - sender_str = str(sender_id) - parts = [sender_str] + (sender_str.split("|") if "|" in sender_str else []) - if not any(p for p in parts if p in allow_list): + if not self._sender_matches_allowlist(sender_id, self.config.dm.allow_from): self.logger.debug(f"Ignoring DM from {sender_id} (policy: {self.config.dm.policy})") return @@ -864,6 +861,28 @@ class SignalChannel(BaseChannel): normalized.append(f"+{raw}") return list(dict.fromkeys(normalized)) + @classmethod + def _sender_matches_allowlist(cls, sender_id: str, allow_list: list[str]) -> bool: + """Return True if any normalized variant of sender_id is on allow_list. + + sender_id is the pipe-joined identifier string built by + _collect_sender_id_parts. Each part and each allow_list entry is run + through _normalize_signal_id so an allowlist entry like ``1234567890`` + matches a sender ``+1234567890`` (and vice versa), and case-only + differences in UUIDs/ACIs match too. + """ + if not allow_list: + return False + sender_variants: set[str] = set() + for part in str(sender_id).split("|"): + sender_variants.update(cls._normalize_signal_id(part)) + if not sender_variants: + return False + allow_variants: set[str] = set() + for entry in allow_list: + allow_variants.update(cls._normalize_signal_id(entry)) + return bool(sender_variants & allow_variants) + def _remember_account_id_alias(self, value: str | None) -> None: """Remember known bot identifiers for mention matching.""" if not value: diff --git a/tests/channels/test_signal_channel.py b/tests/channels/test_signal_channel.py index 27b8b1e91..d6308b803 100644 --- a/tests/channels/test_signal_channel.py +++ b/tests/channels/test_signal_channel.py @@ -510,6 +510,36 @@ class TestHandleDataMessageDM: await ch._handle_receive_notification(params) assert handled == [] + @pytest.mark.asyncio + async def test_dm_allowlist_matches_without_plus_prefix(self): + """An allowlist entry without '+' must match a sender that carries '+'.""" + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["19995550001"]) + params = _dm_envelope(source_number="+19995550001") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_dm_allowlist_matches_with_plus_prefix(self): + """An allowlist entry with '+' must match a sender without '+'.""" + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["+19995550001"]) + params = _dm_envelope(source_number="+19995550001", source_uuid=None) + # Replace envelope's sourceNumber with the non-prefixed form by editing + # the constructed dict directly so _collect_sender_id_parts sees it. + params["envelope"]["sourceNumber"] = "19995550001" + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_dm_allowlist_matches_uuid_case_insensitive(self): + """UUID matching must be case-insensitive.""" + uuid = "ABCDEF12-3456-7890-ABCD-EF1234567890" + ch, handled = self._make_dm_channel( + policy="allowlist", allow_from=[uuid.lower()] + ) + params = _dm_envelope(source_number="+19995550001", source_uuid=uuid) + await ch._handle_receive_notification(params) + assert len(handled) == 1 + @pytest.mark.asyncio async def test_dm_disabled_rejected(self): ch = _make_channel(dm_enabled=False) From ad7c1ac381dacf20386c83ee956ffbf3ba1ebf69 Mon Sep 17 00:00:00 2001 From: Kaloyan Tenchov Date: Sat, 16 May 2026 11:30:53 -0400 Subject: [PATCH 11/54] refactor(signal): wrap top-level receive handler with _safe_handle Replace the inline try/except at the end of _handle_receive_notification with a small async context manager that swallows the exception, logs self.logger.error with the offending payload's repr (bounded to 200 chars), and attaches the traceback via logger.opt(exception=True). The previous log line only carried `e`, so diagnosing a bad envelope from production logs required correlating timestamps. The wrapper is generic so future receive/dispatch sites can adopt it; for now only this site uses it. Co-Authored-By: Claude Opus 4.7 --- nanobot/channels/signal.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/nanobot/channels/signal.py b/nanobot/channels/signal.py index 3fd23c780..fce941c74 100644 --- a/nanobot/channels/signal.py +++ b/nanobot/channels/signal.py @@ -8,7 +8,8 @@ import re import shutil import unicodedata from collections import deque -from collections.abc import Callable +from collections.abc import AsyncIterator, Callable +from contextlib import asynccontextmanager from dataclasses import dataclass, field from pathlib import Path from typing import Any @@ -557,10 +558,29 @@ class SignalChannel(BaseChannel): self.logger.error(f"Error in SSE receive loop: {e}") raise + @asynccontextmanager + async def _safe_handle( + self, action: str, payload: Any = None + ) -> AsyncIterator[None]: + """Swallow and log any exception from a top-level handler block. + + Logs `self.logger.error` with the action name, the exception, and a + bounded ``repr`` of the offending payload so the offending input is + recoverable from logs without having to correlate by timestamp. + """ + try: + yield + except Exception as e: + snippet = repr(payload)[:200] if payload is not None else "" + text = f"Error in {action}: {e}" + if snippet: + text += f" | payload={snippet}" + self.logger.opt(exception=True).error(text) + async def _handle_receive_notification(self, params: dict[str, Any]) -> None: """Handle incoming message notification from signal-cli.""" self.logger.debug(f"_handle_receive_notification called with: {params}") - try: + async with self._safe_handle("receive notification", params): # Extract envelope from SSE notification: {"envelope": {...}} envelope = params.get("envelope", {}) @@ -613,9 +633,6 @@ class SignalChannel(BaseChannel): elif typing_message: pass # Ignore typing indicators - except Exception as e: - self.logger.error(f"Error handling receive notification: {e}") - async def _handle_data_message( self, sender_id: str, From 83aed436823eef7e54aaa8333d8a5601bf240d00 Mon Sep 17 00:00:00 2001 From: Kaloyan Tenchov Date: Sat, 16 May 2026 11:33:30 -0400 Subject: [PATCH 12/54] feat(signal): make signal-cli attachments directory configurable The inbound attachment loop hardcoded ~/.local/share/signal-cli/attachments as the source path. That is the daemon's default on Linux but not on macOS or Windows, and breaks if the daemon was launched with XDG_DATA_HOME set. Add SignalConfig.attachments_dir as an optional override. When unset the behavior is unchanged; when set the value is run through Path.expanduser() so ~ is honored. Co-Authored-By: Claude Opus 4.7 --- nanobot/channels/signal.py | 21 +++++++++++++++++---- tests/channels/test_signal_channel.py | 18 ++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/nanobot/channels/signal.py b/nanobot/channels/signal.py index fce941c74..96e515467 100644 --- a/nanobot/channels/signal.py +++ b/nanobot/channels/signal.py @@ -296,6 +296,11 @@ class SignalConfig(Base): daemon_host: str = "localhost" daemon_port: int = 8080 group_message_buffer_size: int = 20 # Number of recent group messages to keep for context + # Override the directory signal-cli writes inbound attachments to. When + # None, defaults to ~/.local/share/signal-cli/attachments (the daemon's + # platform default on Linux). Set this if the daemon is running with a + # custom XDG_DATA_HOME or on macOS/Windows where the default path differs. + attachments_dir: str | None = None dm: SignalDMConfig = Field(default_factory=SignalDMConfig) group: SignalGroupConfig = Field(default_factory=SignalGroupConfig) @@ -749,10 +754,7 @@ class SignalChannel(BaseChannel): continue try: - # signal-cli stores attachments in ~/.local/share/signal-cli/attachments/ - source_path = ( - Path.home() / ".local/share/signal-cli/attachments" / attachment_id - ) + source_path = self._signal_attachments_dir() / attachment_id if source_path.exists(): dest_path = media_dir / f"signal_{safe_filename(filename)}" @@ -864,6 +866,17 @@ class SignalChannel(BaseChannel): return "\n".join(lines) + def _signal_attachments_dir(self) -> Path: + """Return the directory signal-cli writes inbound attachments to. + + Defaults to ``~/.local/share/signal-cli/attachments`` (the daemon's + platform default on Linux) when ``config.attachments_dir`` is unset. + """ + configured = self.config.attachments_dir + if configured: + return Path(configured).expanduser() + return Path.home() / ".local/share/signal-cli/attachments" + @staticmethod def _normalize_signal_id(value: str) -> list[str]: """Normalize Signal identifiers (phone/uuid/service-id) for matching.""" diff --git a/tests/channels/test_signal_channel.py b/tests/channels/test_signal_channel.py index d6308b803..020118c4c 100644 --- a/tests/channels/test_signal_channel.py +++ b/tests/channels/test_signal_channel.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +from pathlib import Path from unittest.mock import AsyncMock import pytest @@ -71,6 +72,7 @@ def _make_channel( group_allow_from: list[str] | None = None, require_mention: bool = True, group_buffer_size: int = 20, + attachments_dir: str | None = None, ) -> SignalChannel: config = SignalConfig( enabled=True, @@ -87,6 +89,7 @@ def _make_channel( require_mention=require_mention, ), group_message_buffer_size=group_buffer_size, + attachments_dir=attachments_dir, ) return SignalChannel(config, MessageBus()) @@ -471,6 +474,21 @@ class TestGroupBuffer: # --------------------------------------------------------------------------- +class TestAttachmentsDir: + def test_default_attachments_dir(self): + ch = _make_channel() + expected = Path.home() / ".local/share/signal-cli/attachments" + assert ch._signal_attachments_dir() == expected + + def test_configured_attachments_dir(self, tmp_path): + ch = _make_channel(attachments_dir=str(tmp_path / "custom")) + assert ch._signal_attachments_dir() == tmp_path / "custom" + + def test_attachments_dir_expands_user(self): + ch = _make_channel(attachments_dir="~/signal-attachments") + assert ch._signal_attachments_dir() == Path.home() / "signal-attachments" + + class TestHandleDataMessageDM: def _make_dm_channel(self, policy="open", allow_from=None) -> tuple[SignalChannel, list]: ch = _make_channel(dm_enabled=True, dm_policy=policy, dm_allow_from=allow_from or []) From 7733a7840e8eba9806f17bd63182cbbe6ddfa38a Mon Sep 17 00:00:00 2001 From: Kaloyan Tenchov Date: Sat, 16 May 2026 11:39:18 -0400 Subject: [PATCH 13/54] refactor(signal): split _handle_data_message into policy and assembly helpers The receive-path handler was ~165 lines deep into nested DM/group policy checks, buffer mutations, mention stripping, attachment downloads, and final bus forwarding. Pull the policy gate out into _check_inbound_policy (returns (allow, chat_id), still appends to the group buffer once allowed) and the text+media construction into _assemble_inbound_content. The top-level method now reads as orchestration only. Add TestCheckInboundPolicy that exercises the helper directly across the DM/group policy permutations, including the buffer side effect, so the new seam is locked in. Co-Authored-By: Claude Opus 4.7 --- nanobot/channels/signal.py | 266 +++++++++++++++----------- tests/channels/test_signal_channel.py | 98 ++++++++++ 2 files changed, 252 insertions(+), 112 deletions(-) diff --git a/nanobot/channels/signal.py b/nanobot/channels/signal.py index 96e515467..c53c6a133 100644 --- a/nanobot/channels/signal.py +++ b/nanobot/channels/signal.py @@ -648,141 +648,57 @@ class SignalChannel(BaseChannel): """Handle a data message (text, attachments, etc.).""" message_text = data_message.get("message") or "" attachments = data_message.get("attachments", []) - group_info = data_message.get("groupInfo") - timestamp = data_message.get("timestamp") mentions = data_message.get("mentions", []) - reaction = data_message.get("reaction") + timestamp = data_message.get("timestamp") - # Log full data_message for debugging group detection self.logger.info( f"Data message from {sender_number}: " - f"groupInfo={group_info}, " + f"groupInfo={data_message.get('groupInfo')}, " f"groupV2={data_message.get('groupV2')}, " f"keys={list(data_message.keys())}" ) - # Ignore reaction messages (emoji reactions to messages) - if reaction: - self.logger.debug(f"Ignoring reaction message from {sender_number}: {reaction}") + if data_message.get("reaction"): + self.logger.debug( + f"Ignoring reaction message from {sender_number}: {data_message['reaction']}" + ) return - - # Ignore empty messages (e.g., when bot is added to a group) if not message_text and not attachments: self.logger.debug(f"Ignoring empty message from {sender_number}") return - # Determine chat_id (group ID or sender number) - # Check both groupInfo (v1) and groupV2 (v2) fields for group detection + group_info = data_message.get("groupInfo") group_v2 = data_message.get("groupV2") is_group_message = group_info is not None or group_v2 is not None group_id = self._extract_group_id(group_info, group_v2) - is_command = bool(message_text and message_text.strip().startswith("/")) + allowed, chat_id = self._check_inbound_policy( + sender_id=sender_id, + sender_number=sender_number, + group_id=group_id, + is_group_message=is_group_message, + message_text=message_text, + mentions=mentions, + sender_name=sender_name, + timestamp=timestamp, + ) + if not allowed: + return - if is_group_message: - chat_id = group_id or sender_number - - # Check if this group is allowed before doing anything else - if not self.config.group.enabled: - self.logger.info(f"Ignoring group message from {chat_id} (groups disabled)") - return - if ( - self.config.group.policy == "allowlist" - and chat_id not in self.config.group.allow_from - ): - self.logger.info( - f"Ignoring group message from {chat_id} (policy: {self.config.group.policy})" - ) - return - - # Add to group message buffer (group is allowed) - self._add_to_group_buffer( - group_id=chat_id, - sender_name=sender_name or sender_number, - sender_number=sender_number, - message_text=message_text, - timestamp=timestamp, - ) - - # Commands bypass the mention requirement; non-commands require it. - if not is_command and not self._should_respond_in_group(message_text, mentions): - self.logger.info( - f"Ignoring group message (require_mention: {self.config.group.require_mention})" - ) - return - else: - # Direct message — check policy first, then forward everything to the bus. - chat_id = sender_number - if not self.config.dm.enabled: - self.logger.debug(f"Ignoring DM from {sender_id} (DMs disabled)") - return - if self.config.dm.policy == "allowlist": - if not self._sender_matches_allowlist(sender_id, self.config.dm.allow_from): - self.logger.debug(f"Ignoring DM from {sender_id} (policy: {self.config.dm.policy})") - return - - # Build content from text and attachments - content_parts = [] - media_paths = [] - - # For group messages, include recent message context - if is_group_message: - buffer_context = self._get_group_buffer_context(chat_id) - if buffer_context: - content_parts.append(f"[Recent group messages for context:]\n{buffer_context}\n---") - - # Prepend sender name for group messages so history shows who said what - if message_text: - # Strip bot mentions from text (for group messages) - if is_group_message: - message_text = self._strip_bot_mention(message_text, mentions) - # Prepend sender name to make it clear who is speaking - display_name = sender_name or sender_number - message_text = f"[{display_name}]: {message_text}" - content_parts.append(message_text) - - # Handle attachments - if attachments: - media_dir = get_media_dir("signal") - - for attachment in attachments: - attachment_id = attachment.get("id") - content_type = attachment.get("contentType", "") - filename = attachment.get("filename") or f"attachment_{attachment_id}" - - if not attachment_id: - continue - - try: - source_path = self._signal_attachments_dir() / attachment_id - - if source_path.exists(): - dest_path = media_dir / f"signal_{safe_filename(filename)}" - shutil.copy2(source_path, dest_path) - media_paths.append(str(dest_path)) - - # Determine media type from content type - media_type = content_type.split("/")[0] if "/" in content_type else "file" - if media_type not in ("image", "audio", "video"): - media_type = "file" - - content_parts.append(f"[{media_type}: {dest_path}]") - self.logger.debug(f"Downloaded attachment: {filename} -> {dest_path}") - else: - self.logger.warning(f"Attachment not found: {source_path}") - content_parts.append(f"[attachment: {filename} - not found]") - - except Exception as e: - self.logger.warning(f"Failed to process attachment {filename}: {e}") - content_parts.append(f"[attachment: {filename} - error]") - - content = "\n".join(content_parts) if content_parts else "[empty message]" + content, media_paths = self._assemble_inbound_content( + sender_name=sender_name, + sender_number=sender_number, + message_text=message_text, + attachments=attachments, + mentions=mentions, + is_group_message=is_group_message, + chat_id=chat_id, + ) self.logger.debug(f"Signal message from {sender_number}: {content[:50]}...") await self._start_typing(chat_id) try: - # Forward to message bus await self._handle_message( sender_id=sender_id, chat_id=chat_id, @@ -800,6 +716,132 @@ class SignalChannel(BaseChannel): await self._stop_typing(chat_id) raise + def _check_inbound_policy( + self, + *, + sender_id: str, + sender_number: str, + group_id: str | None, + is_group_message: bool, + message_text: str, + mentions: list, + sender_name: str | None, + timestamp: int | None, + ) -> tuple[bool, str]: + """Decide whether to route an inbound message past DM/group policy. + + Returns ``(allow, chat_id)``. Has one side effect: when a group + message passes the enabled+allowlist gates, it is appended to the + group's rolling context buffer before the mention check. + """ + if is_group_message: + chat_id = group_id or sender_number + if not self.config.group.enabled: + self.logger.info(f"Ignoring group message from {chat_id} (groups disabled)") + return False, chat_id + if ( + self.config.group.policy == "allowlist" + and chat_id not in self.config.group.allow_from + ): + self.logger.info( + f"Ignoring group message from {chat_id} (policy: {self.config.group.policy})" + ) + return False, chat_id + + self._add_to_group_buffer( + group_id=chat_id, + sender_name=sender_name or sender_number, + sender_number=sender_number, + message_text=message_text, + timestamp=timestamp, + ) + + is_command = bool(message_text and message_text.strip().startswith("/")) + if not is_command and not self._should_respond_in_group(message_text, mentions): + self.logger.info( + f"Ignoring group message (require_mention: {self.config.group.require_mention})" + ) + return False, chat_id + return True, chat_id + + # Direct message + chat_id = sender_number + if not self.config.dm.enabled: + self.logger.debug(f"Ignoring DM from {sender_id} (DMs disabled)") + return False, chat_id + if self.config.dm.policy == "allowlist": + if not self._sender_matches_allowlist(sender_id, self.config.dm.allow_from): + self.logger.debug( + f"Ignoring DM from {sender_id} (policy: {self.config.dm.policy})" + ) + return False, chat_id + return True, chat_id + + def _assemble_inbound_content( + self, + *, + sender_name: str | None, + sender_number: str, + message_text: str, + attachments: list, + mentions: list, + is_group_message: bool, + chat_id: str, + ) -> tuple[str, list[str]]: + """Build ``(content, media_paths)`` for an inbound message. + + Pulls in group context, strips bot mentions, prefixes the sender's + display name on group messages, and copies any attachments from + signal-cli's storage into the channel media dir. + """ + content_parts: list[str] = [] + media_paths: list[str] = [] + + if is_group_message: + buffer_context = self._get_group_buffer_context(chat_id) + if buffer_context: + content_parts.append( + f"[Recent group messages for context:]\n{buffer_context}\n---" + ) + + if message_text: + if is_group_message: + message_text = self._strip_bot_mention(message_text, mentions) + display_name = sender_name or sender_number + message_text = f"[{display_name}]: {message_text}" + content_parts.append(message_text) + + if attachments: + media_dir = get_media_dir("signal") + for attachment in attachments: + attachment_id = attachment.get("id") + content_type = attachment.get("contentType", "") + filename = attachment.get("filename") or f"attachment_{attachment_id}" + if not attachment_id: + continue + try: + source_path = self._signal_attachments_dir() / attachment_id + if source_path.exists(): + dest_path = media_dir / f"signal_{safe_filename(filename)}" + shutil.copy2(source_path, dest_path) + media_paths.append(str(dest_path)) + media_type = ( + content_type.split("/")[0] if "/" in content_type else "file" + ) + if media_type not in ("image", "audio", "video"): + media_type = "file" + content_parts.append(f"[{media_type}: {dest_path}]") + self.logger.debug(f"Downloaded attachment: {filename} -> {dest_path}") + else: + self.logger.warning(f"Attachment not found: {source_path}") + content_parts.append(f"[attachment: {filename} - not found]") + except Exception as e: + self.logger.warning(f"Failed to process attachment {filename}: {e}") + content_parts.append(f"[attachment: {filename} - error]") + + content = "\n".join(content_parts) if content_parts else "[empty message]" + return content, media_paths + def _add_to_group_buffer( self, group_id: str, diff --git a/tests/channels/test_signal_channel.py b/tests/channels/test_signal_channel.py index 020118c4c..f12b2f22e 100644 --- a/tests/channels/test_signal_channel.py +++ b/tests/channels/test_signal_channel.py @@ -474,6 +474,104 @@ class TestGroupBuffer: # --------------------------------------------------------------------------- +class TestCheckInboundPolicy: + """Direct tests for the policy gate that _handle_data_message now delegates to.""" + + def _call( + self, + ch: SignalChannel, + *, + sender_id: str = "+19995550001", + sender_number: str = "+19995550001", + group_id: str | None = None, + is_group_message: bool = False, + message_text: str = "hi", + mentions: list | None = None, + sender_name: str | None = "Alice", + timestamp: int | None = 1000, + ) -> tuple[bool, str]: + return ch._check_inbound_policy( + sender_id=sender_id, + sender_number=sender_number, + group_id=group_id, + is_group_message=is_group_message, + message_text=message_text, + mentions=mentions or [], + sender_name=sender_name, + timestamp=timestamp, + ) + + def test_dm_open_allows(self): + ch = _make_channel(dm_enabled=True, dm_policy="open") + allowed, chat_id = self._call(ch) + assert allowed is True + assert chat_id == "+19995550001" + + def test_dm_disabled_blocks(self): + ch = _make_channel(dm_enabled=False) + allowed, _ = self._call(ch) + assert allowed is False + + def test_dm_allowlist_blocks_unknown_sender(self): + ch = _make_channel(dm_policy="allowlist", dm_allow_from=["+12223334444"]) + allowed, _ = self._call(ch, sender_id="+19995550001") + assert allowed is False + + def test_dm_allowlist_allows_known_sender(self): + ch = _make_channel(dm_policy="allowlist", dm_allow_from=["+19995550001"]) + allowed, _ = self._call(ch, sender_id="+19995550001") + assert allowed is True + + def test_group_disabled_blocks(self): + ch = _make_channel(group_enabled=False) + allowed, _ = self._call(ch, is_group_message=True, group_id="g1") + assert allowed is False + + def test_group_open_with_mention_allows(self): + ch = _make_channel( + group_enabled=True, + group_policy="open", + phone_number="+10000000000", + require_mention=True, + ) + allowed, chat_id = self._call( + ch, + is_group_message=True, + group_id="g1", + message_text="hello @bot", + mentions=[{"number": "+10000000000", "start": 6, "length": 4}], + ) + assert allowed is True + assert chat_id == "g1" + + def test_group_open_without_mention_blocks(self): + ch = _make_channel(group_enabled=True, group_policy="open", require_mention=True) + allowed, _ = self._call( + ch, is_group_message=True, group_id="g1", message_text="plain talk" + ) + assert allowed is False + + def test_group_command_bypasses_mention_requirement(self): + ch = _make_channel(group_enabled=True, group_policy="open", require_mention=True) + allowed, _ = self._call( + ch, is_group_message=True, group_id="g1", message_text="/help" + ) + assert allowed is True + + def test_allowed_group_appends_to_buffer(self): + """Side effect: when a group message is allowed, it lands in the buffer.""" + ch = _make_channel(group_enabled=True, group_policy="open", require_mention=False) + self._call(ch, is_group_message=True, group_id="g1", message_text="first") + self._call(ch, is_group_message=True, group_id="g1", message_text="second") + assert len(ch._group_buffers["g1"]) == 2 + + def test_blocked_group_does_not_append_to_buffer(self): + """Side effect: when a group is disabled, the buffer must not change.""" + ch = _make_channel(group_enabled=False) + self._call(ch, is_group_message=True, group_id="g1", message_text="hi") + assert "g1" not in ch._group_buffers + + class TestAttachmentsDir: def test_default_attachments_dir(self): ch = _make_channel() From 590ac99c8ac7932dfcac74346962f3dd5c9576ab Mon Sep 17 00:00:00 2001 From: Kaloyan Tenchov Date: Sat, 16 May 2026 11:41:04 -0400 Subject: [PATCH 14/54] test(signal): cover SSE receive loop and the empty-phone start guard Previously the SSE loop and the empty-phone-number short-circuit in start() had zero coverage. Both now have tests: a fake httpx stream feeds canned SSE lines, exercising the valid-frame, invalid-JSON, non-200, and no-http-client paths; start() with an empty phone number is asserted to return without entering the HTTP loop. Co-Authored-By: Claude Opus 4.7 --- tests/channels/test_signal_channel.py | 109 +++++++++++++++++++++++++- 1 file changed, 108 insertions(+), 1 deletion(-) diff --git a/tests/channels/test_signal_channel.py b/tests/channels/test_signal_channel.py index f12b2f22e..53c8a2aa6 100644 --- a/tests/channels/test_signal_channel.py +++ b/tests/channels/test_signal_channel.py @@ -3,8 +3,9 @@ from __future__ import annotations import asyncio +from contextlib import asynccontextmanager from pathlib import Path -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, MagicMock import pytest @@ -891,6 +892,112 @@ class TestHandleDataMessageGroup: assert ch._id_matches_account("new-bot-uuid") +# --------------------------------------------------------------------------- +# Lifecycle / SSE +# --------------------------------------------------------------------------- + + +class _FakeSSEResponse: + """Minimal stand-in for httpx Response under stream().""" + + def __init__(self, lines: list[str], status_code: int = 200) -> None: + self.status_code = status_code + self._lines = lines + + async def aiter_lines(self): + for line in self._lines: + yield line + + +def _fake_streaming_client(lines: list[str], *, status_code: int = 200) -> MagicMock: + """Return an httpx.AsyncClient stand-in whose .stream() yields a FakeSSEResponse.""" + response = _FakeSSEResponse(lines, status_code=status_code) + + @asynccontextmanager + async def _ctx(*_args, **_kwargs): + yield response + + http = MagicMock() + http.stream = lambda *a, **kw: _ctx(*a, **kw) + return http + + +class TestLifecycle: + @pytest.mark.asyncio + async def test_start_returns_early_when_phone_missing(self): + """start() with an empty phone number must not enter the HTTP loop.""" + ch = _make_channel(phone_number="") + await ch.start() + assert ch._running is False + assert ch._http is None + assert ch._sse_task is None + + +class TestSSEReceiveLoop: + @pytest.mark.asyncio + async def test_dispatches_valid_envelope(self): + ch = _make_channel() + ch._running = True + + captured: list[dict] = [] + + async def capture(params): + captured.append(params) + + ch._handle_receive_notification = capture # type: ignore[method-assign] + ch._http = _fake_streaming_client( + ['data: {"envelope":{"sourceNumber":"+19995550001"}}', ""] + ) + + # Loop ends when lines exhaust; the surrounding _start_http_mode would + # treat that as a disconnect, but the loop itself raises ConnectionError + # when the stream closes while still running. + with pytest.raises(ConnectionError): + await ch._sse_receive_loop() + assert captured == [{"envelope": {"sourceNumber": "+19995550001"}}] + + @pytest.mark.asyncio + async def test_handles_invalid_json_frame(self): + """An unparseable SSE frame is logged and skipped without crashing.""" + ch = _make_channel() + ch._running = True + + captured: list[dict] = [] + + async def capture(params): + captured.append(params) + + ch._handle_receive_notification = capture # type: ignore[method-assign] + ch._http = _fake_streaming_client( + [ + "data: this-is-not-json", + "", # event boundary triggers parse attempt + 'data: {"envelope":{"sourceNumber":"+1"}}', + "", + ] + ) + + with pytest.raises(ConnectionError): + await ch._sse_receive_loop() + # Bad frame skipped; good frame still dispatched. + assert captured == [{"envelope": {"sourceNumber": "+1"}}] + + @pytest.mark.asyncio + async def test_non_200_status_raises(self): + ch = _make_channel() + ch._running = True + ch._http = _fake_streaming_client([], status_code=503) + with pytest.raises(ConnectionError, match="status 503"): + await ch._sse_receive_loop() + + @pytest.mark.asyncio + async def test_no_http_client_raises(self): + ch = _make_channel() + ch._http = None + with pytest.raises(RuntimeError, match="HTTP client not initialized"): + await ch._sse_receive_loop() + + # --------------------------------------------------------------------------- # Command handling # --------------------------------------------------------------------------- From 9c486b90d560f48ef6723b79fe1149676d8e06c5 Mon Sep 17 00:00:00 2001 From: Kaloyan Tenchov Date: Sat, 16 May 2026 11:43:09 -0400 Subject: [PATCH 15/54] test(signal): consolidate channel-capture setup into one factory Two test classes (TestHandleDataMessageDM, TestHandleDataMessageGroup) plus three TestCommandHandling tests each repeated the same handful of lines: build a channel, mock _handle_message to record kwargs, replace _start_typing with a no-op, paper over the assignment with type: ignore. Hoist the pattern into _make_channel_with_capture and call it from all five sites. Drops 30+ lines of duplication and 7 type: ignore comments. Co-Authored-By: Claude Opus 4.7 --- tests/channels/test_signal_channel.py | 86 +++++++++------------------ 1 file changed, 29 insertions(+), 57 deletions(-) diff --git a/tests/channels/test_signal_channel.py b/tests/channels/test_signal_channel.py index 53c8a2aa6..9444e2218 100644 --- a/tests/channels/test_signal_channel.py +++ b/tests/channels/test_signal_channel.py @@ -5,7 +5,7 @@ from __future__ import annotations import asyncio from contextlib import asynccontextmanager from pathlib import Path -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import MagicMock import pytest @@ -62,6 +62,24 @@ class _FakeHTTPClient: # --------------------------------------------------------------------------- +def _make_channel_with_capture(**overrides) -> tuple[SignalChannel, list[dict]]: + """Build a SignalChannel with _handle_message captured into a list and a + no-op _start_typing, used by every receive-flow test class. + """ + ch = _make_channel(**overrides) + handled: list[dict] = [] + + async def capture(**kwargs): + handled.append(kwargs) + + async def noop_typing(chat_id): + pass + + ch._handle_message = capture # type: ignore[method-assign] + ch._start_typing = noop_typing # type: ignore[method-assign] + return ch, handled + + def _make_channel( *, phone_number: str = "+10000000000", @@ -590,19 +608,9 @@ class TestAttachmentsDir: class TestHandleDataMessageDM: def _make_dm_channel(self, policy="open", allow_from=None) -> tuple[SignalChannel, list]: - ch = _make_channel(dm_enabled=True, dm_policy=policy, dm_allow_from=allow_from or []) - handled: list[dict] = [] - - async def capture(**kwargs): - handled.append(kwargs) - - ch._handle_message = capture # type: ignore[method-assign] - - async def noop_typing(chat_id): - pass - - ch._start_typing = noop_typing # type: ignore[method-assign] - return ch, handled + return _make_channel_with_capture( + dm_enabled=True, dm_policy=policy, dm_allow_from=allow_from or [] + ) @pytest.mark.asyncio async def test_dm_open_policy_accepted(self): @@ -777,24 +785,12 @@ class TestHandleDataMessageGroup: allow_from=None, require_mention=True, ) -> tuple[SignalChannel, list]: - ch = _make_channel( + return _make_channel_with_capture( group_enabled=True, group_policy=policy, group_allow_from=allow_from or [], require_mention=require_mention, ) - handled: list[dict] = [] - - async def capture(**kwargs): - handled.append(kwargs) - - ch._handle_message = capture # type: ignore[method-assign] - - async def noop_typing(chat_id): - pass - - ch._start_typing = noop_typing # type: ignore[method-assign] - return ch, handled @pytest.mark.asyncio async def test_group_disabled_rejected(self): @@ -1007,55 +1003,31 @@ class TestCommandHandling: @pytest.mark.asyncio async def test_dm_command_forwarded_to_bus(self): """Slash commands in DMs are forwarded to the bus for AgentLoop to handle.""" - ch = _make_channel(dm_enabled=True, dm_policy="open") - forwarded: list[dict] = [] - - async def capture(**kw): - forwarded.append(kw) - - ch._handle_message = capture # type: ignore[method-assign] - ch._start_typing = AsyncMock() - + ch, forwarded = _make_channel_with_capture(dm_enabled=True, dm_policy="open") params = _dm_envelope(source_number="+19995550001", message="/reset") await ch._handle_receive_notification(params) - assert len(forwarded) == 1 assert forwarded[0]["content"].strip() == "/reset" @pytest.mark.asyncio async def test_group_command_bypasses_mention_requirement(self): """Slash commands in groups bypass the mention requirement and reach the bus.""" - ch = _make_channel( + ch, forwarded = _make_channel_with_capture( group_enabled=True, group_policy="open", require_mention=True ) - forwarded: list[dict] = [] - - async def capture(**kw): - forwarded.append(kw) - - ch._handle_message = capture # type: ignore[method-assign] - ch._start_typing = AsyncMock() - - params = _group_envelope(source_number="+19995550001", group_id="grp==", message="/reset") + params = _group_envelope( + source_number="+19995550001", group_id="grp==", message="/reset" + ) await ch._handle_receive_notification(params) - assert len(forwarded) == 1 assert "/reset" in forwarded[0]["content"] @pytest.mark.asyncio async def test_command_denied_for_disallowed_dm_sender(self): """Commands from senders not on the DM allowlist are dropped.""" - ch = _make_channel(dm_enabled=False) - forwarded: list[dict] = [] - - async def capture(**kw): - forwarded.append(kw) - - ch._handle_message = capture # type: ignore[method-assign] - + ch, forwarded = _make_channel_with_capture(dm_enabled=False) params = _dm_envelope(source_number="+19995550001", message="/reset") await ch._handle_receive_notification(params) - assert forwarded == [] From 632f41e4184a186d80b3d0b08ed8c47a71647fe3 Mon Sep 17 00:00:00 2001 From: Kaloyan Tenchov Date: Sat, 16 May 2026 11:45:39 -0400 Subject: [PATCH 16/54] test(signal): cover markdown adjacency, nesting, and malformed input The existing markdown suite was strong on UTF-16 offsets and chunk redistribution but had no coverage for nested or adjacent styles, no test that an unmatched opener round-trips as plain text, and no test for the blockquote/inline-code interaction. Add six cases including the documented contiguous-BOLD output for `# **wrap** me`, which Signal renders as one visual span. Co-Authored-By: Claude Opus 4.7 --- tests/channels/test_signal_markdown.py | 63 ++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/tests/channels/test_signal_markdown.py b/tests/channels/test_signal_markdown.py index 55b081095..bc65ae6cf 100644 --- a/tests/channels/test_signal_markdown.py +++ b/tests/channels/test_signal_markdown.py @@ -459,3 +459,66 @@ def test_partition_styles_range_spanning_chunks_is_split(): # since the space was trimmed by lstrip). assert parts[0] == ["0:3:BOLD"] assert parts[1] == ["0:3:BOLD"] + + +# --------------------------------------------------------------------------- +# Adjacency, nesting, and malformed input +# --------------------------------------------------------------------------- + + +def test_bold_italic_combo_outer_bold_inner_italic(): + """`**_combo_**` carries both BOLD and ITALIC over the same span.""" + plain, styles = _markdown_to_signal("**_combo_**") + assert plain == "combo" + sd = styles_for(plain, styles) + assert set(sd.get("combo", [])) == {"BOLD", "ITALIC"} + + +def test_bold_and_italic_adjacent_no_separator(): + """`**bold***italic*` produces BOLD on `bold` and ITALIC on `italic`.""" + plain, styles = _markdown_to_signal("**bold***italic*") + assert plain == "bolditalic" + sd = styles_for(plain, styles) + assert sd.get("bold") == ["BOLD"] + assert sd.get("italic") == ["ITALIC"] + + +def test_unclosed_bold_falls_through_as_plain(): + """An unmatched `**` opener round-trips as literal text with no style.""" + plain, styles = _markdown_to_signal("**bold") + assert plain == "**bold" + assert styles == [] + + +def test_unclosed_inline_code_falls_through_as_plain(): + """An unmatched backtick round-trips as literal text with no style.""" + plain, styles = _markdown_to_signal("use `grep") + assert plain == "use `grep" + assert styles == [] + + +def test_inline_code_inside_blockquote(): + """Blockquote prefix is stripped; inline code becomes MONOSPACE.""" + plain, styles = _markdown_to_signal("> use `grep`") + assert plain == "use grep" + sd = styles_for(plain, styles) + assert sd.get("grep") == ["MONOSPACE"] + + +def test_header_with_inner_bold_produces_contiguous_bold_ranges(): + """`# **wrap** me` — header forces BOLD over the whole line; the inner `**` + splits the run, yielding two contiguous BOLD ranges that together cover + "wrap me". This is intentional — Signal renders adjacent same-style ranges + as a single visual span. + """ + plain, styles = _markdown_to_signal("# **wrap** me") + assert plain == "wrap me" + # Both ranges are BOLD; collectively they cover the whole "wrap me". + bold_ranges = [s for s in styles if s.endswith(":BOLD")] + assert len(bold_ranges) == 2 + covered = set() + for entry in bold_ranges: + start, length, _ = entry.split(":", 2) + for i in range(int(start), int(start) + int(length)): + covered.add(i) + assert covered == set(range(len(plain))) From b300ea495fcf703e5e736c6641d1e2fd4bca96b7 Mon Sep 17 00:00:00 2001 From: Kaloyan Tenchov Date: Sat, 16 May 2026 12:27:33 -0400 Subject: [PATCH 17/54] fix(signal): normalize composite sender_ids in is_allowed too The base BaseChannel.is_allowed() does a literal ``sender_id in allow_from`` check, but Signal's sender_id is a pipe-joined composite of phone/UUID parts. After splitting an allowlist entry like ``+phone|uuid`` into two separate entries, the per-DM gate accepted it but the base gate still denied because the composite sender string wasn't literally in the list. Override is_allowed on SignalChannel to delegate to _sender_matches_allowlist, which already splits both sides on ``|`` and normalizes each part. _sender_matches_allowlist itself now also splits allowlist entries on ``|`` so legacy composite entries keep working too. Co-Authored-By: Claude Opus 4.7 --- nanobot/channels/signal.py | 31 +++++++++++--- tests/channels/test_signal_channel.py | 61 +++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 6 deletions(-) diff --git a/nanobot/channels/signal.py b/nanobot/channels/signal.py index c53c6a133..328f82b22 100644 --- a/nanobot/channels/signal.py +++ b/nanobot/channels/signal.py @@ -360,6 +360,23 @@ class SignalChannel(BaseChannel): # Each message is a dict with: sender_name, sender_number, content, timestamp self._group_buffers: dict[str, deque] = {} + def is_allowed(self, sender_id: str) -> bool: + """Override base check to normalize and split pipe-joined identifiers. + + ``sender_id`` from Signal is the pipe-joined composite produced by + ``_collect_sender_id_parts``; allow_from entries may be single + identifiers or composites and may use the ``+`` prefix variant or + not. Delegates to ``_sender_matches_allowlist`` so the base gate + matches the per-policy DM gate. + """ + allow_list = self.config.allow_from + if not allow_list: + self.logger.warning("allow_from is empty — all access denied") + return False + if "*" in allow_list: + return True + return self._sender_matches_allowlist(sender_id, allow_list) + async def start(self) -> None: """Start the Signal channel and connect to signal-cli daemon.""" if not self.config.phone_number: @@ -937,11 +954,12 @@ class SignalChannel(BaseChannel): def _sender_matches_allowlist(cls, sender_id: str, allow_list: list[str]) -> bool: """Return True if any normalized variant of sender_id is on allow_list. - sender_id is the pipe-joined identifier string built by - _collect_sender_id_parts. Each part and each allow_list entry is run - through _normalize_signal_id so an allowlist entry like ``1234567890`` - matches a sender ``+1234567890`` (and vice versa), and case-only - differences in UUIDs/ACIs match too. + Both ``sender_id`` and each allow_list entry can be a single + identifier or a pipe-joined composite of several (e.g. + ``"+1234567890|uuid-abc"``); both sides are split on ``|`` and each + part is run through ``_normalize_signal_id`` so an allowlist entry + like ``1234567890`` matches a sender ``+1234567890`` (and vice + versa), and case-only differences in UUIDs/ACIs match too. """ if not allow_list: return False @@ -952,7 +970,8 @@ class SignalChannel(BaseChannel): return False allow_variants: set[str] = set() for entry in allow_list: - allow_variants.update(cls._normalize_signal_id(entry)) + for part in str(entry).split("|"): + allow_variants.update(cls._normalize_signal_id(part)) return bool(sender_variants & allow_variants) def _remember_account_id_alias(self, value: str | None) -> None: diff --git a/tests/channels/test_signal_channel.py b/tests/channels/test_signal_channel.py index 9444e2218..28433c07a 100644 --- a/tests/channels/test_signal_channel.py +++ b/tests/channels/test_signal_channel.py @@ -493,6 +493,51 @@ class TestGroupBuffer: # --------------------------------------------------------------------------- +class TestIsAllowed: + """The base-channel allowlist gate is overridden to understand Signal's + pipe-joined composite sender_ids and the +/no-+ phone variants. + """ + + def test_denies_when_allowlist_empty(self): + ch = _make_channel(dm_enabled=True, dm_policy="open") # open -> no entries + assert ch.is_allowed("+19995550001") is False + + def test_allows_wildcard(self): + ch = _make_channel(dm_policy="allowlist", dm_allow_from=["*"]) + assert ch.is_allowed("+19995550001|some-uuid") is True + + def test_allows_composite_sender_against_split_allowlist(self): + """Composite sender_id, single-id allow_from — must match either part.""" + ch = _make_channel( + dm_policy="allowlist", + dm_allow_from=["+19995550001"], + ) + assert ch.is_allowed("+19995550001|1872ba20-uuid") is True + + def test_allows_composite_sender_against_composite_allowlist_entry(self): + """Backward compat: pipe-joined composite allowlist entries still match.""" + composite = "+19995550001|1872ba20-uuid" + ch = _make_channel(dm_policy="allowlist", dm_allow_from=[composite]) + assert ch.is_allowed(composite) is True + + def test_allows_when_only_uuid_part_is_listed(self): + ch = _make_channel(dm_policy="allowlist", dm_allow_from=["1872ba20-uuid"]) + assert ch.is_allowed("+19995550001|1872ba20-uuid") is True + + def test_denies_when_no_part_matches(self): + ch = _make_channel(dm_policy="allowlist", dm_allow_from=["+12223334444"]) + assert ch.is_allowed("+19995550001|1872ba20-uuid") is False + + def test_allowlist_union_includes_group_ids(self): + """allow_from is the union of dm.allow_from and group.allow_from.""" + ch = _make_channel( + group_enabled=True, + group_policy="allowlist", + group_allow_from=["group-id-base64=="], + ) + assert "group-id-base64==" in ch.config.allow_from + + class TestCheckInboundPolicy: """Direct tests for the policy gate that _handle_data_message now delegates to.""" @@ -665,6 +710,22 @@ class TestHandleDataMessageDM: await ch._handle_receive_notification(params) assert len(handled) == 1 + @pytest.mark.asyncio + async def test_dm_allowlist_matches_pipe_joined_composite_entry(self): + """Allowlist entries written as ``phone|uuid`` composites still work. + + Some configs pre-date the per-part splitting and store the full + sender_id composite as a single allow_from entry. Keep matching it. + """ + composite = "+19995550001|1872ba20-f52a-4bad-b434-bf7f808c8b22" + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=[composite]) + params = _dm_envelope( + source_number="+19995550001", + source_uuid="1872ba20-f52a-4bad-b434-bf7f808c8b22", + ) + await ch._handle_receive_notification(params) + assert len(handled) == 1 + @pytest.mark.asyncio async def test_dm_disabled_rejected(self): ch = _make_channel(dm_enabled=False) From 96767ca1798de2801513a6c5ef39d531ef6900a6 Mon Sep 17 00:00:00 2001 From: Kaloyan Tenchov Date: Sat, 16 May 2026 12:52:33 -0400 Subject: [PATCH 18/54] Cleanup --- nanobot/channels/signal.py | 83 +++++++++++++------------- tests/channels/test_signal_channel.py | 20 ++----- tests/channels/test_signal_markdown.py | 19 +++--- 3 files changed, 55 insertions(+), 67 deletions(-) diff --git a/nanobot/channels/signal.py b/nanobot/channels/signal.py index 328f82b22..e40f5197e 100644 --- a/nanobot/channels/signal.py +++ b/nanobot/channels/signal.py @@ -32,17 +32,19 @@ class _Run: opaque: bool = False # code / table content — skip further pattern processing -_SIG_CODE_BLOCK_RE = re.compile(r'```(?:\w+)?\n?([\s\S]*?)```') -_SIG_INLINE_CODE_RE = re.compile(r'`([^`\n]+)`') -_SIG_HEADER_RE = re.compile(r'^#{1,6}\s+(.+)$', re.MULTILINE) -_SIG_BLOCKQUOTE_RE = re.compile(r'^>\s*(.*)$', re.MULTILINE) -_SIG_BULLET_RE = re.compile(r'^[-*]\s+', re.MULTILINE) -_SIG_OLIST_RE = re.compile(r'^(\d+)\.\s+', re.MULTILINE) -_SIG_LINK_RE = re.compile(r'\[([^\]]+)\]\(([^)]+)\)') -_SIG_BOLD_RE = re.compile(r'\*\*(.+?)\*\*|__(.+?)__', re.DOTALL) -_SIG_ITALIC_RE = re.compile(r'(?\s*(.*)$", re.MULTILINE) +_SIG_BULLET_RE = re.compile(r"^[-*]\s+", re.MULTILINE) +_SIG_OLIST_RE = re.compile(r"^(\d+)\.\s+", re.MULTILINE) +_SIG_LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)") +_SIG_BOLD_RE = re.compile(r"\*\*(.+?)\*\*|__(.+?)__", re.DOTALL) +_SIG_ITALIC_RE = re.compile( + r"(? str: """Render a markdown pipe-table as fixed-width plain text.""" def dw(s: str) -> int: - return sum(2 if unicodedata.east_asian_width(c) in ('W', 'F') else 1 for c in s) + return sum(2 if unicodedata.east_asian_width(c) in ("W", "F") else 1 for c in s) rows: list[list[str]] = [] has_sep = False for line in table_lines: - cells = [_sig_strip_cell(c) for c in line.strip().strip('|').split('|')] - if all(re.match(r'^:?-+:?$', c) for c in cells if c): + cells = [_sig_strip_cell(c) for c in line.strip().strip("|").split("|")] + if all(re.match(r"^:?-+:?$", c) for c in cells if c): has_sep = True continue rows.append(cells) if not rows or not has_sep: - return '\n'.join(table_lines) + return "\n".join(table_lines) ncols = max(len(r) for r in rows) for r in rows: - r.extend([''] * (ncols - len(r))) + r.extend([""] * (ncols - len(r))) widths = [max(dw(r[c]) for r in rows) for c in range(ncols)] def dr(cells: list[str]) -> str: - return ' '.join(f'{c}{" " * (w - dw(c))}' for c, w in zip(cells, widths)) + return " ".join(f"{c}{' ' * (w - dw(c))}" for c, w in zip(cells, widths)) out = [dr(rows[0])] - out.append(' '.join('─' * w for w in widths)) + out.append(" ".join("─" * w for w in widths)) for row in rows[1:]: out.append(dr(row)) - return '\n'.join(out) + return "\n".join(out) def _markdown_to_signal(text: str) -> tuple[str, list[str]]: @@ -121,17 +123,17 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]: text = _SIG_CODE_BLOCK_RE.sub(save_code, text) # Detect and render pipe-tables line by line. - lines = text.split('\n') + lines = text.split("\n") rebuilt: list[str] = [] i = 0 while i < len(lines): - if re.match(r'^\s*\|.+\|', lines[i]): + if re.match(r"^\s*\|.+\|", lines[i]): tbl: list[str] = [] - while i < len(lines) and re.match(r'^\s*\|.+\|', lines[i]): + while i < len(lines) and re.match(r"^\s*\|.+\|", lines[i]): tbl.append(lines[i]) i += 1 rendered = _sig_render_table(tbl) - if rendered != '\n'.join(tbl): + if rendered != "\n".join(tbl): protected.append(rendered) rebuilt.append(f"\x00C{len(protected) - 1}\x00") else: @@ -139,7 +141,7 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]: else: rebuilt.append(lines[i]) i += 1 - text = '\n'.join(rebuilt) + text = "\n".join(rebuilt) # Phase 2 (run-based): process inline patterns. runs: list[_Run] = [_Run(text)] @@ -156,7 +158,7 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]: pos = 0 for m in pattern.finditer(run.text): if m.start() > pos: - new_runs.append(_Run(run.text[pos:m.start()], run.styles)) + new_runs.append(_Run(run.text[pos : m.start()], run.styles)) new_runs.extend(make_runs(m, run.styles)) pos = m.end() if pos < len(run.text): @@ -164,7 +166,10 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]: runs[:] = new_runs # Restore code/table placeholders as opaque MONOSPACE runs. - transform(_SIG_TOKEN_RE, lambda m, s: [_Run(protected[int(m.group(1))], s | {"MONOSPACE"}, opaque=True)]) + transform( + _SIG_TOKEN_RE, + lambda m, s: [_Run(protected[int(m.group(1))], s | {"MONOSPACE"}, opaque=True)], + ) # Inline code (opaque). transform(_SIG_INLINE_CODE_RE, lambda m, s: [_Run(m.group(1), s | {"MONOSPACE"}, opaque=True)]) @@ -186,7 +191,7 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]: link_text, url = m.group(1), m.group(2) def _norm(u: str) -> str: - return re.sub(r'^https?://(www\.)?', '', u).rstrip('/').lower() + return re.sub(r"^https?://(www\.)?", "", u).rstrip("/").lower() if _norm(url) == _norm(link_text): return [_Run(url, s)] @@ -581,9 +586,7 @@ class SignalChannel(BaseChannel): raise @asynccontextmanager - async def _safe_handle( - self, action: str, payload: Any = None - ) -> AsyncIterator[None]: + async def _safe_handle(self, action: str, payload: Any = None) -> AsyncIterator[None]: """Swallow and log any exception from a top-level handler block. Logs `self.logger.error` with the action name, the exception, and a @@ -788,9 +791,7 @@ class SignalChannel(BaseChannel): return False, chat_id if self.config.dm.policy == "allowlist": if not self._sender_matches_allowlist(sender_id, self.config.dm.allow_from): - self.logger.debug( - f"Ignoring DM from {sender_id} (policy: {self.config.dm.policy})" - ) + self.logger.debug(f"Ignoring DM from {sender_id} (policy: {self.config.dm.policy})") return False, chat_id return True, chat_id @@ -817,9 +818,7 @@ class SignalChannel(BaseChannel): if is_group_message: buffer_context = self._get_group_buffer_context(chat_id) if buffer_context: - content_parts.append( - f"[Recent group messages for context:]\n{buffer_context}\n---" - ) + content_parts.append(f"[Recent group messages for context:]\n{buffer_context}\n---") if message_text: if is_group_message: @@ -842,9 +841,7 @@ class SignalChannel(BaseChannel): dest_path = media_dir / f"signal_{safe_filename(filename)}" shutil.copy2(source_path, dest_path) media_paths.append(str(dest_path)) - media_type = ( - content_type.split("/")[0] if "/" in content_type else "file" - ) + media_type = content_type.split("/")[0] if "/" in content_type else "file" if media_type not in ("image", "audio", "video"): media_type = "file" content_parts.append(f"[{media_type}: {dest_path}]") diff --git a/tests/channels/test_signal_channel.py b/tests/channels/test_signal_channel.py index 28433c07a..56a35b94b 100644 --- a/tests/channels/test_signal_channel.py +++ b/tests/channels/test_signal_channel.py @@ -610,16 +610,12 @@ class TestCheckInboundPolicy: def test_group_open_without_mention_blocks(self): ch = _make_channel(group_enabled=True, group_policy="open", require_mention=True) - allowed, _ = self._call( - ch, is_group_message=True, group_id="g1", message_text="plain talk" - ) + allowed, _ = self._call(ch, is_group_message=True, group_id="g1", message_text="plain talk") assert allowed is False def test_group_command_bypasses_mention_requirement(self): ch = _make_channel(group_enabled=True, group_policy="open", require_mention=True) - allowed, _ = self._call( - ch, is_group_message=True, group_id="g1", message_text="/help" - ) + allowed, _ = self._call(ch, is_group_message=True, group_id="g1", message_text="/help") assert allowed is True def test_allowed_group_appends_to_buffer(self): @@ -703,9 +699,7 @@ class TestHandleDataMessageDM: async def test_dm_allowlist_matches_uuid_case_insensitive(self): """UUID matching must be case-insensitive.""" uuid = "ABCDEF12-3456-7890-ABCD-EF1234567890" - ch, handled = self._make_dm_channel( - policy="allowlist", allow_from=[uuid.lower()] - ) + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=[uuid.lower()]) params = _dm_envelope(source_number="+19995550001", source_uuid=uuid) await ch._handle_receive_notification(params) assert len(handled) == 1 @@ -1076,9 +1070,7 @@ class TestCommandHandling: ch, forwarded = _make_channel_with_capture( group_enabled=True, group_policy="open", require_mention=True ) - params = _group_envelope( - source_number="+19995550001", group_id="grp==", message="/reset" - ) + params = _group_envelope(source_number="+19995550001", group_id="grp==", message="/reset") await ch._handle_receive_notification(params) assert len(forwarded) == 1 assert "/reset" in forwarded[0]["content"] @@ -1357,9 +1349,7 @@ def test_config_allow_from_aggregates_dm_and_group() -> None: enabled=True, phone_number="+10000000000", dm=SignalDMConfig(enabled=True, policy="allowlist", allow_from=["+1111", "+2222"]), - group=SignalGroupConfig( - enabled=True, policy="allowlist", allow_from=["+3333", "+1111"] - ), + group=SignalGroupConfig(enabled=True, policy="allowlist", allow_from=["+3333", "+1111"]), ) combined = config.allow_from assert "+1111" in combined diff --git a/tests/channels/test_signal_markdown.py b/tests/channels/test_signal_markdown.py index bc65ae6cf..37a21c6d8 100644 --- a/tests/channels/test_signal_markdown.py +++ b/tests/channels/test_signal_markdown.py @@ -1,7 +1,5 @@ """Unit tests for the Signal markdown → plain text + textStyle converter.""" -import pytest - from nanobot.channels.signal import _markdown_to_signal, _partition_styles from nanobot.utils.helpers import split_message @@ -94,8 +92,9 @@ def test_inline_code(): def test_code_block(): plain, styles = _markdown_to_signal("```\nprint('hi')\n```") assert "print('hi')" in plain - assert styles_for(plain, styles).get("print('hi')\n") == ["MONOSPACE"] or \ - "MONOSPACE" in str(styles_for(plain, styles)) + assert styles_for(plain, styles).get("print('hi')\n") == ["MONOSPACE"] or "MONOSPACE" in str( + styles_for(plain, styles) + ) def test_code_block_with_lang(): @@ -278,9 +277,7 @@ def assert_within_utf16_bounds(plain: str, styles: list[str]) -> None: start_s, length_s, _ = entry.split(":", 2) start, length = int(start_s), int(length_s) assert start >= 0 - assert start + length <= limit, ( - f"range {entry} exceeds utf-16 length {limit} of {plain!r}" - ) + assert start + length <= limit, f"range {entry} exceeds utf-16 length {limit} of {plain!r}" def test_bold_with_emoji_inside(): @@ -413,7 +410,9 @@ def test_partition_styles_long_message_preserves_chunk_one_styles(): for entry in final_styles: s, ln, _ = entry.split(":", 2) start, length = int(s), int(ln) - slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode("utf-16-le") + slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode( + "utf-16-le" + ) assert slice_ == "tail" @@ -444,7 +443,9 @@ def test_partition_styles_with_non_bmp_chunk_offset(): for entry in final_styles: s, ln, _ = entry.split(":", 2) start, length = int(s), int(ln) - slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode("utf-16-le") + slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode( + "utf-16-le" + ) assert slice_ == "tail" From d653f23aba01b5fe8e7b2f728a90a8a45efdc0cc Mon Sep 17 00:00:00 2001 From: Kaloyan Tenchov Date: Tue, 19 May 2026 08:56:06 -0400 Subject: [PATCH 19/54] fix(signal): raise on signal-cli error response so send is retriable _send_http_request collapses every exception path into a {"error": ...} dict, so the if "error" in response branch inside send() is the only place where send failures surface. Logging-only there meant the ChannelManager retry mechanism never fired. Raise RuntimeError so the base-class retry path is exercised; the outer try/except already re-raises into the caller. Addresses review comment on PR #3852. --- nanobot/channels/signal.py | 1 + tests/channels/test_signal_channel.py | 9 +++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/nanobot/channels/signal.py b/nanobot/channels/signal.py index e40f5197e..64e099f4d 100644 --- a/nanobot/channels/signal.py +++ b/nanobot/channels/signal.py @@ -506,6 +506,7 @@ class SignalChannel(BaseChannel): if "error" in response: self.logger.error(f"Error sending Signal message: {response['error']}") + raise RuntimeError(f"signal-cli send failed: {response['error']}") else: self.logger.debug( f"Signal message sent, timestamp: {response.get('result', {}).get('timestamp')}" diff --git a/tests/channels/test_signal_channel.py b/tests/channels/test_signal_channel.py index 56a35b94b..214117176 100644 --- a/tests/channels/test_signal_channel.py +++ b/tests/channels/test_signal_channel.py @@ -1217,13 +1217,14 @@ class TestSend: assert "+19995550001" in stopped @pytest.mark.asyncio - async def test_send_logs_daemon_error_without_raising(self): + async def test_send_raises_on_daemon_error(self): + # _send_http_request turns every exception into {"error": ...}, so this branch + # is the only place ChannelManager retry can be triggered — must raise. ch = _make_channel() - # The daemon returns {"error": {...}} in the JSON body — this is not a Python - # exception; send() logs it but does not raise (only HTTP-level exceptions raise). ch._http = _FakeHTTPClient(default_response={"error": {"message": "fail"}}) msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="hello") - await ch.send(msg) # must not raise + with pytest.raises(RuntimeError, match="signal-cli send failed"): + await ch.send(msg) # --------------------------------------------------------------------------- From d376ec129d47c1d3e1383f04ad8e80a7fafc0117 Mon Sep 17 00:00:00 2001 From: Kaloyan Tenchov Date: Tue, 19 May 2026 08:56:28 -0400 Subject: [PATCH 20/54] fix(signal): pass is_dm to _handle_message so DM pairing flow runs BaseChannel._handle_message uses is_dm to decide whether to issue a pairing code when is_allowed rejects the sender. Without it the base class treats every denied message as a group message and silently drops it. Forward is_dm=not is_group_message so unapproved DM users get a pairing code through the standard flow. This change only takes effect once denied DMs actually reach _handle_message (next commit); on its own it is a no-op since the policy gate still short-circuits before this call. Addresses review comment on PR #3852. --- nanobot/channels/signal.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nanobot/channels/signal.py b/nanobot/channels/signal.py index 64e099f4d..01e2cd981 100644 --- a/nanobot/channels/signal.py +++ b/nanobot/channels/signal.py @@ -732,6 +732,7 @@ class SignalChannel(BaseChannel): "is_group": is_group_message, "group_id": group_id, }, + is_dm=not is_group_message, ) except Exception: await self._stop_typing(chat_id) From dc332476710a44d53a74d3126da8fbc5ade17f08 Mon Sep 17 00:00:00 2001 From: Kaloyan Tenchov Date: Tue, 19 May 2026 08:57:47 -0400 Subject: [PATCH 21/54] fix(signal): route denied DMs through _handle_message for pairing code MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously _check_inbound_policy returned (False, chat_id) for DMs that failed the allowlist and the caller dropped them — so unapproved DM senders never saw a pairing code. Mirror Slack: when the policy gate denies a DM but dm.enabled is true, still call _handle_message(content="", is_dm=True) so BaseChannel can issue the pairing reply. Group denials stay a hard drop. Combined with the previous is_dm forwarding, unapproved DM senders now receive a pairing code through the standard flow. Addresses review comment on PR #3852. --- nanobot/channels/signal.py | 9 +++++++++ tests/channels/test_signal_channel.py | 9 +++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/nanobot/channels/signal.py b/nanobot/channels/signal.py index 01e2cd981..781a6bdd7 100644 --- a/nanobot/channels/signal.py +++ b/nanobot/channels/signal.py @@ -704,6 +704,15 @@ class SignalChannel(BaseChannel): timestamp=timestamp, ) if not allowed: + # Mirror Slack: let denied DMs reach _handle_message so the base + # class can reply with a pairing code. Group denials stay dropped. + if not is_group_message and self.config.dm.enabled: + await self._handle_message( + sender_id=sender_id, + chat_id=chat_id, + content="", + is_dm=True, + ) return content, media_paths = self._assemble_inbound_content( diff --git a/tests/channels/test_signal_channel.py b/tests/channels/test_signal_channel.py index 214117176..3a0565570 100644 --- a/tests/channels/test_signal_channel.py +++ b/tests/channels/test_signal_channel.py @@ -670,11 +670,16 @@ class TestHandleDataMessageDM: assert len(handled) == 1 @pytest.mark.asyncio - async def test_dm_allowlist_rejected(self): + async def test_dm_allowlist_rejected_triggers_pairing(self): + # Denied DM senders are routed to _handle_message with empty content + # and is_dm=True so BaseChannel issues a pairing code (mirrors Slack). ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["+10000000001"]) params = _dm_envelope(source_number="+19995550002") await ch._handle_receive_notification(params) - assert handled == [] + assert len(handled) == 1 + assert handled[0]["content"] == "" + assert handled[0]["is_dm"] is True + assert handled[0]["chat_id"] == "+19995550002" @pytest.mark.asyncio async def test_dm_allowlist_matches_without_plus_prefix(self): From 82dfe8c1f7f64ec562cc937f045484f93409ad2c Mon Sep 17 00:00:00 2001 From: Kaloyan Tenchov Date: Tue, 19 May 2026 08:58:04 -0400 Subject: [PATCH 22/54] fix(signal): join multi-line SSE data with newline per spec Per the SSE spec, multiple data: lines within a single event must be joined with \n before parsing. signal-cli emits single-line JSON so this was latent, but the joining was wrong. Addresses review comment on PR #3852. --- nanobot/channels/signal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nanobot/channels/signal.py b/nanobot/channels/signal.py index 781a6bdd7..ecee7116d 100644 --- a/nanobot/channels/signal.py +++ b/nanobot/channels/signal.py @@ -557,7 +557,7 @@ class SignalChannel(BaseChannel): # Try to parse the accumulated data data_str = "" try: - data_str = "".join(event_buffer) + data_str = "\n".join(event_buffer) data = json.loads(data_str) self.logger.debug(f"SSE event parsed: {data}") await self._handle_receive_notification(data) From b3d0d24a52cfa203f12f1c46cb44073c9b5373d2 Mon Sep 17 00:00:00 2001 From: Kaloyan Tenchov Date: Tue, 19 May 2026 08:59:17 -0400 Subject: [PATCH 23/54] fix(signal): consult pairing store in is_allowed MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit BaseChannel.is_allowed ORs is_approved (the pairing store) into the allow decision; the signal override dropped that step and only looked at config.allow_from. With the new DM-pairing flow in place, an approved-via-pairing sender's next message would have failed the allow check and triggered another pairing code in a loop. OR in a normalized check against the pairing store: walk each part of the pipe-joined sender_id through _normalize_signal_id and call is_approved for each variant, so an approval stored under one form (phone with/without "+", UUID/ACI) still matches when the next inbound uses a different form. Mirrors how slack.py:643 handles it. Also tightens the empty-allowlist warning to only fire when nothing else granted access, since pairing-store hits are now a valid path. Not part of the original review, but Comments 2 and 3 turn this latent gap into a broken round-trip — included so the pairing UX actually works. --- nanobot/channels/signal.py | 25 +++++++++++++++++++++---- tests/channels/test_signal_channel.py | 16 ++++++++++++++++ 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/nanobot/channels/signal.py b/nanobot/channels/signal.py index ecee7116d..66b7d2b40 100644 --- a/nanobot/channels/signal.py +++ b/nanobot/channels/signal.py @@ -22,6 +22,7 @@ from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel from nanobot.config.paths import get_media_dir from nanobot.config.schema import Base +from nanobot.pairing import is_approved from nanobot.utils.helpers import safe_filename, split_message @@ -375,12 +376,28 @@ class SignalChannel(BaseChannel): matches the per-policy DM gate. """ allow_list = self.config.allow_from - if not allow_list: - self.logger.warning("allow_from is empty — all access denied") - return False if "*" in allow_list: return True - return self._sender_matches_allowlist(sender_id, allow_list) + if self._sender_matches_allowlist(sender_id, allow_list): + return True + if self._sender_approved_via_pairing(sender_id): + return True + if not allow_list: + self.logger.warning("allow_from is empty — all access denied") + return False + + def _sender_approved_via_pairing(self, sender_id: str) -> bool: + """Return True if any normalized variant of sender_id is in the pairing store. + + Pairing approval may be recorded under any of the identifier forms + signal exposes (phone with/without ``+``, UUID, ACI), so we check + each part of the pipe-joined composite against ``is_approved``. + """ + for part in str(sender_id).split("|"): + for variant in self._normalize_signal_id(part): + if is_approved(self.name, variant): + return True + return False async def start(self) -> None: """Start the Signal channel and connect to signal-cli daemon.""" diff --git a/tests/channels/test_signal_channel.py b/tests/channels/test_signal_channel.py index 3a0565570..9822ff3b6 100644 --- a/tests/channels/test_signal_channel.py +++ b/tests/channels/test_signal_channel.py @@ -681,6 +681,22 @@ class TestHandleDataMessageDM: assert handled[0]["is_dm"] is True assert handled[0]["chat_id"] == "+19995550002" + @pytest.mark.asyncio + async def test_dm_paired_sender_allowed_without_allowlist_entry(self, monkeypatch): + # Once a sender completes pairing they should pass is_allowed on every + # subsequent message — otherwise the pairing reply loops forever. + approved = {"+19995550002"} + monkeypatch.setattr( + "nanobot.channels.signal.is_approved", + lambda channel, sender_id: sender_id in approved, + ) + ch = _make_channel(dm_enabled=True, dm_policy="allowlist", dm_allow_from=[]) + assert ch.is_allowed("+19995550002") is True + # Variant forms (with/without "+") must still match a stored approval. + assert ch.is_allowed("19995550002") is True + # Unpaired sender stays denied. + assert ch.is_allowed("+19995559999") is False + @pytest.mark.asyncio async def test_dm_allowlist_matches_without_plus_prefix(self): """An allowlist entry without '+' must match a sender that carries '+'.""" From 886e7e43d5facef585b1b736faa5e8172ae44bcd Mon Sep 17 00:00:00 2001 From: chengyongru <2755839590@qq.com> Date: Wed, 20 May 2026 00:07:54 +0800 Subject: [PATCH 24/54] fix(signal): bypass base is_allowed for policy-approved messages Override _handle_message to publish directly to the bus for messages that have already passed _check_inbound_policy. The denied DM pairing path calls super()._handle_message() to issue pairing codes via the base class. This avoids cross-policy leakage where e.g. group open policy would cause is_allowed to incorrectly allow denied DM senders. Also includes: - SSE: strip one optional leading space after 'data:' per spec - Convert 20+ f-string log calls to loguru lazy formatting - Add end-to-end tests for DM/group routing through the full chain - Add cross-policy test (dm allowlist + group open) for pairing - Add Signal channel documentation to docs/chat-apps.md --- docs/chat-apps.md | 67 +++++++++++++ nanobot/channels/signal.py | 133 ++++++++++++++++-------- tests/channels/test_signal_channel.py | 139 ++++++++++++++++++++++++-- 3 files changed, 291 insertions(+), 48 deletions(-) diff --git a/docs/chat-apps.md b/docs/chat-apps.md index c0c1b4ba0..88242a5f7 100644 --- a/docs/chat-apps.md +++ b/docs/chat-apps.md @@ -17,6 +17,7 @@ Connect nanobot to your favorite chat platform. Want to build your own? See the | **Wecom** | Bot ID + Bot Secret | | **Microsoft Teams** | App ID + App Password + public HTTPS endpoint | | **Mochat** | Claw token (auto-setup available) | +| **Signal** | signal-cli daemon + phone number |

Telegram (Recommended) @@ -669,3 +670,69 @@ nanobot gateway ```
+ +
+Signal + +Uses **signal-cli** daemon in HTTP mode — receive messages via SSE, send via JSON-RPC. + +**1. Install signal-cli** + +Install [signal-cli](https://github.com/AsamK/signal-cli) and register a phone number: + +```bash +signal-cli -u +1234567890 register +signal-cli -u +1234567890 verify +``` + +Start the daemon: + +```bash +signal-cli -a +1234567890 daemon --http localhost:8080 +``` + +**2. Configure** + +```json +{ + "channels": { + "signal": { + "enabled": true, + "phoneNumber": "+1234567890", + "daemonHost": "localhost", + "daemonPort": 8080, + "dm": { + "enabled": true, + "policy": "open" + }, + "group": { + "enabled": true, + "policy": "open", + "requireMention": true + } + } + } +} +``` + +> - `phoneNumber`: Your registered Signal phone number. +> - `daemonHost` / `daemonPort`: Where signal-cli daemon is listening (default `localhost:8080`). +> - `dm.policy`: `"open"` (anyone can DM) or `"allowlist"` (only listed numbers/UUIDs). When `"allowlist"`, unlisted DM senders receive a pairing code. +> - `dm.allowFrom`: List of allowed phone numbers or UUIDs (used when policy is `"allowlist"`). +> - `group.policy`: `"open"` (all groups) or `"allowlist"` (only listed group IDs). +> - `group.requireMention`: When `true` (default), the bot only responds in groups when @mentioned. +> - `group.allowFrom`: List of allowed group IDs (used when group policy is `"allowlist"`). +> - `attachmentsDir`: Override the directory where signal-cli stores inbound attachments. Defaults to `~/.local/share/signal-cli/attachments` (the Linux default). Set this if signal-cli runs with a custom `XDG_DATA_HOME` or on macOS/Windows. +> - `groupMessageBufferSize`: Number of recent group messages kept for context (default `20`, must be > 0). + +**3. Run** + +```bash +nanobot gateway +``` + +> [!TIP] +> The channel automatically reconnects to the signal-cli daemon with exponential backoff if the connection drops. +> Markdown in bot replies is automatically converted to Signal text styles (bold, italic, code, etc.). + +
diff --git a/nanobot/channels/signal.py b/nanobot/channels/signal.py index 66b7d2b40..2a38f60ac 100644 --- a/nanobot/channels/signal.py +++ b/nanobot/channels/signal.py @@ -17,7 +17,7 @@ from typing import Any import httpx from pydantic import Field, computed_field, field_validator -from nanobot.bus.events import OutboundMessage +from nanobot.bus.events import InboundMessage, OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel from nanobot.config.paths import get_media_dir @@ -399,6 +399,39 @@ class SignalChannel(BaseChannel): return True return False + async def _handle_message( + self, + sender_id: str, + chat_id: str, + content: str, + media: list[str] | None = None, + metadata: dict[str, Any] | None = None, + session_key: str | None = None, + is_dm: bool = False, + ) -> None: + """Handle an inbound message whose policy has already been checked. + + ``_check_inbound_policy`` is the authoritative gate for DM/group + access, so we skip the base-class ``is_allowed()`` check and publish + directly to the bus. The denied-DM pairing path calls + ``super()._handle_message`` instead, which goes through + ``is_allowed`` and issues a pairing code. + """ + meta = metadata or {} + if self.supports_streaming: + meta = {**meta, "_wants_stream": True} + await self.bus.publish_inbound( + InboundMessage( + channel=self.name, + sender_id=str(sender_id), + chat_id=str(chat_id), + content=content, + media=media or [], + metadata=meta, + session_key_override=session_key, + ) + ) + async def start(self) -> None: """Start the Signal channel and connect to signal-cli daemon.""" if not self.config.phone_number: @@ -416,7 +449,7 @@ class SignalChannel(BaseChannel): while self._running: try: - self.logger.info(f"Connecting to signal-cli daemon at {base_url}...") + self.logger.info("Connecting to signal-cli daemon at {}...", base_url) # Create HTTP client self._http = httpx.AsyncClient( @@ -452,11 +485,15 @@ class SignalChannel(BaseChannel): break except ConnectionRefusedError as e: self.logger.error( - f"{e}. Make sure signal-cli daemon is running: " - f"signal-cli -a {self.config.phone_number} daemon --http {self.config.daemon_host}:{self.config.daemon_port}" + "{}. Make sure signal-cli daemon is running: " + "signal-cli -a {} daemon --http {}:{}", + e, + self.config.phone_number, + self.config.daemon_host, + self.config.daemon_port, ) except Exception as e: - self.logger.error(f"Signal channel error: {e}") + self.logger.error("Signal channel error: {}", e) finally: if self._sse_task: if not self._sse_task.done(): @@ -474,7 +511,7 @@ class SignalChannel(BaseChannel): if self._running: self.logger.info( - f"Reconnecting to signal-cli daemon in {reconnect_delay_s:.0f} seconds..." + "Reconnecting to signal-cli daemon in {:.0f} seconds...", reconnect_delay_s ) await asyncio.sleep(reconnect_delay_s) reconnect_delay_s = min(reconnect_delay_s * 2, max_reconnect_delay_s) @@ -522,7 +559,7 @@ class SignalChannel(BaseChannel): response = await self._send_request("send", params) if "error" in response: - self.logger.error(f"Error sending Signal message: {response['error']}") + self.logger.error("Error sending Signal message: {}", response['error']) raise RuntimeError(f"signal-cli send failed: {response['error']}") else: self.logger.debug( @@ -564,7 +601,7 @@ class SignalChannel(BaseChannel): # Debug: log raw SSE lines (except keepalive pings) if line and line != ":": - self.logger.debug(f"SSE line received: {line[:200]}") + self.logger.debug("SSE line received: {}", line[:200]) # SSE format handling if isinstance(line, str): @@ -576,18 +613,21 @@ class SignalChannel(BaseChannel): try: data_str = "\n".join(event_buffer) data = json.loads(data_str) - self.logger.debug(f"SSE event parsed: {data}") + self.logger.debug("SSE event parsed: {}", data) await self._handle_receive_notification(data) except json.JSONDecodeError as e: self.logger.warning( - f"Invalid JSON in SSE buffer: {e}, data: {data_str[:200]}" + "Invalid JSON in SSE buffer: {}, data: {}", + e, + data_str[:200], ) finally: event_buffer = [] # "data:" line - accumulate it elif line.startswith("data:"): - event_buffer.append(line[5:]) # Skip "data:" prefix + # SSE spec: strip one optional leading space after "data:". + event_buffer.append(line[6:] if line[5:6] == " " else line[5:]) # "event:" line - just log it (we only care about data) elif line.startswith("event:"): @@ -600,7 +640,7 @@ class SignalChannel(BaseChannel): self.logger.info("SSE receive loop cancelled") raise except Exception as e: - self.logger.error(f"Error in SSE receive loop: {e}") + self.logger.error("Error in SSE receive loop: {}", e) raise @asynccontextmanager @@ -622,12 +662,12 @@ class SignalChannel(BaseChannel): async def _handle_receive_notification(self, params: dict[str, Any]) -> None: """Handle incoming message notification from signal-cli.""" - self.logger.debug(f"_handle_receive_notification called with: {params}") + self.logger.debug("_handle_receive_notification called with: {}", params) async with self._safe_handle("receive notification", params): # Extract envelope from SSE notification: {"envelope": {...}} envelope = params.get("envelope", {}) - self.logger.debug(f"Extracted envelope: {envelope}") + self.logger.debug("Extracted envelope: {}", envelope) if not envelope: self.logger.debug("No envelope found in params") @@ -669,7 +709,7 @@ class SignalChannel(BaseChannel): destination = sent_msg.get("destination") or sent_msg.get("destinationNumber") if destination: self.logger.debug( - f"Sync message sent to {destination}: {sent_msg.get('message', '')[:50]}" + "Sync message sent to {}: {}", destination, sent_msg.get("message", "")[:50] ) # Handle typing indicators (silently ignore) @@ -690,19 +730,20 @@ class SignalChannel(BaseChannel): timestamp = data_message.get("timestamp") self.logger.info( - f"Data message from {sender_number}: " - f"groupInfo={data_message.get('groupInfo')}, " - f"groupV2={data_message.get('groupV2')}, " - f"keys={list(data_message.keys())}" + "Data message from {}: groupInfo={}, groupV2={}, keys={}", + sender_number, + data_message.get("groupInfo"), + data_message.get("groupV2"), + list(data_message.keys()), ) if data_message.get("reaction"): self.logger.debug( - f"Ignoring reaction message from {sender_number}: {data_message['reaction']}" + "Ignoring reaction message from {}: {}", sender_number, data_message["reaction"] ) return if not message_text and not attachments: - self.logger.debug(f"Ignoring empty message from {sender_number}") + self.logger.debug("Ignoring empty message from {}", sender_number) return group_info = data_message.get("groupInfo") @@ -721,10 +762,11 @@ class SignalChannel(BaseChannel): timestamp=timestamp, ) if not allowed: - # Mirror Slack: let denied DMs reach _handle_message so the base - # class can reply with a pairing code. Group denials stay dropped. + # Mirror Slack: let denied DMs reach the base-class + # _handle_message so it can reply with a pairing code. + # Group denials stay dropped. if not is_group_message and self.config.dm.enabled: - await self._handle_message( + await super()._handle_message( sender_id=sender_id, chat_id=chat_id, content="", @@ -742,7 +784,7 @@ class SignalChannel(BaseChannel): chat_id=chat_id, ) - self.logger.debug(f"Signal message from {sender_number}: {content[:50]}...") + self.logger.debug("Signal message from {}: {}...", sender_number, content[:50]) await self._start_typing(chat_id) try: @@ -785,14 +827,16 @@ class SignalChannel(BaseChannel): if is_group_message: chat_id = group_id or sender_number if not self.config.group.enabled: - self.logger.info(f"Ignoring group message from {chat_id} (groups disabled)") + self.logger.info("Ignoring group message from {} (groups disabled)", chat_id) return False, chat_id if ( self.config.group.policy == "allowlist" and chat_id not in self.config.group.allow_from ): self.logger.info( - f"Ignoring group message from {chat_id} (policy: {self.config.group.policy})" + "Ignoring group message from {} (policy: {})", + chat_id, + self.config.group.policy, ) return False, chat_id @@ -807,7 +851,8 @@ class SignalChannel(BaseChannel): is_command = bool(message_text and message_text.strip().startswith("/")) if not is_command and not self._should_respond_in_group(message_text, mentions): self.logger.info( - f"Ignoring group message (require_mention: {self.config.group.require_mention})" + "Ignoring group message (require_mention: {})", + self.config.group.require_mention, ) return False, chat_id return True, chat_id @@ -815,11 +860,13 @@ class SignalChannel(BaseChannel): # Direct message chat_id = sender_number if not self.config.dm.enabled: - self.logger.debug(f"Ignoring DM from {sender_id} (DMs disabled)") + self.logger.debug("Ignoring DM from {} (DMs disabled)", sender_id) return False, chat_id if self.config.dm.policy == "allowlist": if not self._sender_matches_allowlist(sender_id, self.config.dm.allow_from): - self.logger.debug(f"Ignoring DM from {sender_id} (policy: {self.config.dm.policy})") + self.logger.debug( + "Ignoring DM from {} (policy: {})", sender_id, self.config.dm.policy + ) return False, chat_id return True, chat_id @@ -873,12 +920,12 @@ class SignalChannel(BaseChannel): if media_type not in ("image", "audio", "video"): media_type = "file" content_parts.append(f"[{media_type}: {dest_path}]") - self.logger.debug(f"Downloaded attachment: {filename} -> {dest_path}") + self.logger.debug("Downloaded attachment: {} -> {}", filename, dest_path) else: - self.logger.warning(f"Attachment not found: {source_path}") + self.logger.warning("Attachment not found: {}", source_path) content_parts.append(f"[attachment: {filename} - not found]") except Exception as e: - self.logger.warning(f"Failed to process attachment {filename}: {e}") + self.logger.warning("Failed to process attachment {}: {}", filename, e) content_parts.append(f"[attachment: {filename} - error]") content = "\n".join(content_parts) if content_parts else "[empty message]" @@ -917,8 +964,10 @@ class SignalChannel(BaseChannel): ) self.logger.debug( - f"Added message to group buffer {group_id}: " - f"{len(self._group_buffers[group_id])}/{self.config.group_message_buffer_size}" + "Added message to group buffer {}: {}/{}", + group_id, + len(self._group_buffers[group_id]), + self.config.group_message_buffer_size, ) def _get_group_buffer_context(self, group_id: str) -> str: @@ -1269,7 +1318,7 @@ class SignalChannel(BaseChannel): except asyncio.CancelledError: pass except Exception as e: - self.logger.debug(f"Typing indicator loop stopped for {chat_id}: {e}") + self.logger.debug("Typing indicator loop stopped for {}: {}", chat_id, e) async def _send_typing( self, chat_id: str, stop: bool = False, quiet_success: bool = False @@ -1304,18 +1353,22 @@ class SignalChannel(BaseChannel): if "error" not in response: if not quiet_success: - self.logger.info(f"Signal typing {action} sent for {chat_id}") + self.logger.info("Signal typing {} sent for {}", action, chat_id) return last_error = response["error"] - self.logger.warning(f"Failed to send Signal typing {action} for {chat_id}: {last_error}") + self.logger.warning( + "Failed to send Signal typing {} for {}: {}", action, chat_id, last_error + ) async def _ensure_typing_indicators_enabled(self) -> None: """Enable typing indicators on the bot account.""" response = await self._send_request("updateConfiguration", {"typingIndicators": True}) if "error" in response: - self.logger.warning(f"Failed to enable Signal typing indicators: {response['error']}") + self.logger.warning( + "Failed to enable Signal typing indicators: {}", response["error"] + ) else: self.logger.info("Signal typing indicators enabled on account configuration") @@ -1345,5 +1398,5 @@ class SignalChannel(BaseChannel): response.raise_for_status() return response.json() except Exception as e: - self.logger.error(f"HTTP request failed: {e}") + self.logger.error("HTTP request failed: {}", e) return {"error": {"message": str(e)}} diff --git a/tests/channels/test_signal_channel.py b/tests/channels/test_signal_channel.py index 9822ff3b6..277c85b83 100644 --- a/tests/channels/test_signal_channel.py +++ b/tests/channels/test_signal_channel.py @@ -9,7 +9,7 @@ from unittest.mock import MagicMock import pytest -from nanobot.bus.events import OutboundMessage +from nanobot.bus.events import InboundMessage, OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.signal import ( SignalChannel, @@ -499,7 +499,12 @@ class TestIsAllowed: """ def test_denies_when_allowlist_empty(self): - ch = _make_channel(dm_enabled=True, dm_policy="open") # open -> no entries + ch = _make_channel(dm_enabled=True, dm_policy="allowlist") + assert ch.is_allowed("+19995550001") is False + + def test_denies_when_no_policy_allows(self): + """When both dm and group are disabled, is_allowed denies.""" + ch = _make_channel(dm_enabled=False, group_enabled=False) assert ch.is_allowed("+19995550001") is False def test_allows_wildcard(self): @@ -538,6 +543,121 @@ class TestIsAllowed: assert "group-id-base64==" in ch.config.allow_from +class TestEndToEndDMRouting: + """End-to-end tests that keep the real _handle_message chain (no mock), + verifying that _check_inbound_policy + _handle_message work together + correctly for DM routing. The override of _handle_message publishes + directly to bus (policy already checked); denied DMs call + super()._handle_message which issues a pairing code. + """ + + @pytest.mark.asyncio + async def test_open_dm_policy_publishes_to_bus(self): + """Open DM: _check_inbound_policy passes → _handle_message publishes.""" + ch = _make_channel(dm_enabled=True, dm_policy="open") + + async def noop_typing(chat_id): + pass + + ch._start_typing = noop_typing # type: ignore[method-assign] + published: list[InboundMessage] = [] + + async def capture_publish(msg: InboundMessage): + published.append(msg) + + ch.bus.publish_inbound = capture_publish # type: ignore[method-assign] + + params = _dm_envelope(source_number="+19995550001", message="hello") + await ch._handle_receive_notification(params) + + assert len(published) == 1 + assert published[0].content == "hello" + assert published[0].sender_id == "+19995550001" + + @pytest.mark.asyncio + async def test_allowlist_dm_denied_triggers_pairing(self): + """Allowlist DM: denied sender triggers pairing code via send().""" + ch = _make_channel(dm_enabled=True, dm_policy="allowlist", dm_allow_from=[]) + ch._http = _FakeHTTPClient() # type: ignore[assignment] + + async def noop_typing(chat_id): + pass + + ch._start_typing = noop_typing # type: ignore[method-assign] + published: list[InboundMessage] = [] + + async def capture_publish(msg: InboundMessage): + published.append(msg) + + ch.bus.publish_inbound = capture_publish # type: ignore[method-assign] + + params = _dm_envelope(source_number="+19995550002", message="hello") + await ch._handle_receive_notification(params) + + # Should NOT publish to bus — sender is not on allowlist. + assert published == [] + # Should have sent a pairing code via send (captured in HTTP posts). + assert len(ch._http.posts) == 1 # type: ignore[attr-defined] + sent_text = ch._http.posts[0]["json"]["params"]["message"] # type: ignore[attr-defined] + assert "pairing" in sent_text.lower() or "pair" in sent_text.lower() + + @pytest.mark.asyncio + async def test_allowlist_dm_denied_with_group_open_still_pairs(self): + """dm.policy="allowlist" + group.policy="open": denied DM sender + must still get a pairing code, not be leaked by the group open check.""" + ch = _make_channel( + dm_enabled=True, + dm_policy="allowlist", + dm_allow_from=[], + group_enabled=True, + group_policy="open", + ) + ch._http = _FakeHTTPClient() # type: ignore[assignment] + + async def noop_typing(chat_id): + pass + + ch._start_typing = noop_typing # type: ignore[method-assign] + published: list[InboundMessage] = [] + + async def capture_publish(msg: InboundMessage): + published.append(msg) + + ch.bus.publish_inbound = capture_publish # type: ignore[method-assign] + + params = _dm_envelope(source_number="+19995550002", message="hello") + await ch._handle_receive_notification(params) + + assert published == [] + assert len(ch._http.posts) == 1 # type: ignore[attr-defined] + + @pytest.mark.asyncio + async def test_open_group_policy_publishes_to_bus(self): + """Open group: group message from unknown sender publishes to bus.""" + ch = _make_channel( + group_enabled=True, + group_policy="open", + require_mention=False, + ) + + async def noop_typing(chat_id): + pass + + ch._start_typing = noop_typing # type: ignore[method-assign] + published: list[InboundMessage] = [] + + async def capture_publish(msg: InboundMessage): + published.append(msg) + + ch.bus.publish_inbound = capture_publish # type: ignore[method-assign] + + params = _group_envelope(group_id="grp==", message="hello group") + await ch._handle_receive_notification(params) + + assert len(published) == 1 + assert "hello group" in published[0].content + + class TestCheckInboundPolicy: """Direct tests for the policy gate that _handle_data_message now delegates to.""" @@ -671,15 +791,18 @@ class TestHandleDataMessageDM: @pytest.mark.asyncio async def test_dm_allowlist_rejected_triggers_pairing(self): - # Denied DM senders are routed to _handle_message with empty content - # and is_dm=True so BaseChannel issues a pairing code (mirrors Slack). + # Denied DM senders go through super()._handle_message which checks + # is_allowed → sends pairing code via self.send(). ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["+10000000001"]) + ch._http = _FakeHTTPClient() # type: ignore[attr-defined] params = _dm_envelope(source_number="+19995550002") await ch._handle_receive_notification(params) - assert len(handled) == 1 - assert handled[0]["content"] == "" - assert handled[0]["is_dm"] is True - assert handled[0]["chat_id"] == "+19995550002" + # The denied DM path calls super()._handle_message, not self._handle_message, + # so the capture list stays empty. Verify pairing code was sent via HTTP. + assert handled == [] + assert len(ch._http.posts) == 1 # type: ignore[attr-defined] + sent_text = ch._http.posts[0]["json"]["params"]["message"] # type: ignore[attr-defined] + assert "pairing" in sent_text.lower() or "pair" in sent_text.lower() @pytest.mark.asyncio async def test_dm_paired_sender_allowed_without_allowlist_entry(self, monkeypatch): From 5f0ba05de594250525c673da358c7e9933bc76da Mon Sep 17 00:00:00 2001 From: Xubin Ren <52506698+Re-bin@users.noreply.github.com> Date: Thu, 21 May 2026 01:25:20 +0800 Subject: [PATCH 25/54] feat(tools): tighten patch and session workflows --- nanobot/agent/tools/apply_patch.py | 129 +++++++++++++++++++++---- nanobot/agent/tools/exec_session.py | 90 +++++++++++++++-- nanobot/utils/file_edit_events.py | 2 + tests/tools/test_apply_patch_tool.py | 49 ++++++++++ tests/tools/test_exec_session_tools.py | 59 +++++++++++ tests/utils/test_file_edit_events.py | 22 +++++ 6 files changed, 322 insertions(+), 29 deletions(-) diff --git a/nanobot/agent/tools/apply_patch.py b/nanobot/agent/tools/apply_patch.py index e69a65f10..57d60f9b8 100644 --- a/nanobot/agent/tools/apply_patch.py +++ b/nanobot/agent/tools/apply_patch.py @@ -10,7 +10,7 @@ from typing import Any, Literal from nanobot.agent.tools.base import tool_parameters from nanobot.agent.tools.filesystem import _FsTool -from nanobot.agent.tools.schema import StringSchema, tool_parameters_schema +from nanobot.agent.tools.schema import BooleanSchema, StringSchema, tool_parameters_schema PatchKind = Literal["add", "delete", "update"] @@ -31,6 +31,15 @@ class _PatchOp: hunks: list[_Hunk] | None = None +@dataclass(slots=True) +class _PatchSummary: + action: str + path: str + added: int = 0 + deleted: int = 0 + new_path: str | None = None + + class _PatchError(ValueError): pass @@ -65,6 +74,40 @@ def _lines_to_text(lines: list[str]) -> str: return "\n".join(lines) + "\n" +def _text_line_count(text: str) -> int: + if not text: + return 0 + return len(text.splitlines()) + + +def _line_diff_stats(before: str, after: str) -> tuple[int, int]: + before_lines = before.replace("\r\n", "\n").splitlines() + after_lines = after.replace("\r\n", "\n").splitlines() + added = 0 + deleted = 0 + matcher = difflib.SequenceMatcher(a=before_lines, b=after_lines, autojunk=False) + for tag, i1, i2, j1, j2 in matcher.get_opcodes(): + if tag == "equal": + continue + if tag in ("replace", "delete"): + deleted += i2 - i1 + if tag in ("replace", "insert"): + added += j2 - j1 + return added, deleted + + +def _format_summary(summary: _PatchSummary) -> str: + path = ( + f"{summary.path} -> {summary.new_path}" + if summary.new_path + else summary.path + ) + stats = "" + if summary.added or summary.deleted: + stats = f" (+{summary.added}/-{summary.deleted})" + return f"- {summary.action} {path}{stats}" + + def _parse_patch(patch: str) -> list[_PatchOp]: lines = patch.replace("\r\n", "\n").replace("\r", "\n").split("\n") if lines and lines[-1] == "": @@ -237,6 +280,10 @@ def _apply_hunks(path: str, content: str, hunks: list[_Hunk]) -> str: "for Add File, Update File, Delete File, and optional Move to.", min_length=1, ), + dry_run=BooleanSchema( + description="Validate and summarize the patch without writing files.", + default=False, + ), required=["patch"], ) ) @@ -254,60 +301,97 @@ class ApplyPatchTool(_FsTool): "Apply a structured patch for code edits. The patch must include " "*** Begin Patch and *** End Patch. Supports Add File, Update File, " "Delete File, and Move to. Paths must be relative. Prefer this for " - "multi-file coding changes; use edit_file for small exact replacements." + "multi-file coding changes; use edit_file for small exact replacements. " + "Set dry_run=true to validate and preview the change without writing files." ) - async def execute(self, patch: str, **kwargs: Any) -> str: + async def execute(self, patch: str, dry_run: bool = False, **kwargs: Any) -> str: try: ops = _parse_patch(patch) writes: dict[Path, str] = {} deletes: set[Path] = set() - touched: list[str] = [] + summaries: list[_PatchSummary] = [] for op in ops: source = self._resolve(op.path) if op.kind == "add": - if source.exists(): + if source.exists() or source in writes: raise _PatchError(f"file to add already exists: {op.path}") - writes[source] = _lines_to_text(op.add_lines or []) + new_content = _lines_to_text(op.add_lines or []) + writes[source] = new_content deletes.discard(source) - touched.append(f"add {op.path}") + summaries.append(_PatchSummary( + action="add", + path=op.path, + added=_text_line_count(new_content), + )) continue if op.kind == "delete": - if not source.exists(): + pending_content = writes.get(source) + if pending_content is None and not source.exists(): raise _PatchError(f"file to delete does not exist: {op.path}") - if not source.is_file(): + if pending_content is None and not source.is_file(): raise _PatchError(f"path to delete is not a file: {op.path}") + deleted_lines = 0 + if pending_content is not None: + deleted_lines = _text_line_count(pending_content) + else: + raw = source.read_bytes() + try: + deleted_lines = _text_line_count(raw.decode("utf-8")) + except UnicodeDecodeError: + deleted_lines = 0 deletes.add(source) writes.pop(source, None) - touched.append(f"delete {op.path}") + summaries.append(_PatchSummary( + action="delete", + path=op.path, + deleted=deleted_lines, + )) continue - if not source.exists(): + pending_content = writes.get(source) + if pending_content is None and not source.exists(): raise _PatchError(f"file to update does not exist: {op.path}") - if not source.is_file(): + if pending_content is None and not source.is_file(): raise _PatchError(f"path to update is not a file: {op.path}") - raw = source.read_bytes() - try: - content = raw.decode("utf-8") - except UnicodeDecodeError as exc: - raise _PatchError(f"file to update is not UTF-8 text: {op.path}") from exc + if pending_content is not None: + content = pending_content + else: + raw = source.read_bytes() + try: + content = raw.decode("utf-8") + except UnicodeDecodeError as exc: + raise _PatchError(f"file to update is not UTF-8 text: {op.path}") from exc uses_crlf = "\r\n" in content content = content.replace("\r\n", "\n") new_content = _apply_hunks(op.path, content, op.hunks or []) + added, deleted = _line_diff_stats(content, new_content) if uses_crlf: new_content = new_content.replace("\n", "\r\n") target = self._resolve(op.new_path) if op.new_path else source - if op.new_path and target.exists() and target != source: + if op.new_path and (target.exists() or target in writes) and target != source: raise _PatchError(f"move target already exists: {op.new_path}") writes[target] = new_content deletes.discard(target) if target != source: deletes.add(source) - action = f"move {op.path} -> {op.new_path}" if op.new_path else f"update {op.path}" - touched.append(action) + writes.pop(source, None) + summaries.append(_PatchSummary( + action="move" if op.new_path else "update", + path=op.path, + new_path=op.new_path, + added=added, + deleted=deleted, + )) + + if dry_run: + return ( + "Patch dry-run succeeded:\n" + + "\n".join(_format_summary(summary) for summary in summaries) + ) backups: dict[Path, bytes | None] = {} for path in set(writes) | deletes: @@ -332,7 +416,10 @@ class ApplyPatchTool(_FsTool): for path in set(writes) | deletes: self._file_states.record_write(path) - return "Patch applied:\n" + "\n".join(f"- {item}" for item in touched) + return ( + "Patch applied:\n" + + "\n".join(_format_summary(summary) for summary in summaries) + ) except PermissionError as exc: return f"Error: {exc}" except _PatchError as exc: diff --git a/nanobot/agent/tools/exec_session.py b/nanobot/agent/tools/exec_session.py index 34667aeaa..c9ca0a3d0 100644 --- a/nanobot/agent/tools/exec_session.py +++ b/nanobot/agent/tools/exec_session.py @@ -16,6 +16,8 @@ from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchem DEFAULT_YIELD_MS = 1000 MAX_YIELD_MS = 30_000 +DEFAULT_WAIT_FOR_MS = 10_000 +MAX_WAIT_FOR_MS = 120_000 DEFAULT_MAX_OUTPUT_CHARS = 10_000 MAX_OUTPUT_CHARS = 50_000 @@ -356,6 +358,18 @@ def format_session_poll(session_id: str, poll: _SessionPoll) -> str: minimum=0, maximum=MAX_YIELD_MS, ), + wait_for=StringSchema( + "Optional text to wait for in output before returning. " + "Useful for interactive commands and dev servers.", + nullable=True, + ), + wait_timeout_ms=IntegerSchema( + DEFAULT_WAIT_FOR_MS, + description="Maximum milliseconds to wait for wait_for text (default 10000, max 120000).", + minimum=0, + maximum=MAX_WAIT_FOR_MS, + nullable=True, + ), max_output_chars=IntegerSchema( DEFAULT_MAX_OUTPUT_CHARS, description="Maximum output characters to return from this poll (default 10000, max 50000).", @@ -412,8 +426,9 @@ class WriteStdinTool(Tool): return ( "Write text to a running exec session and return recent output. " "Use chars='' to poll without writing. Set close_stdin=true to send EOF, " - "or terminate=true to stop the session. Sessions finish automatically " - "when their process exits." + "or terminate=true to stop the session. Use wait_for to keep polling " + "until expected output appears. Sessions finish automatically when " + "their process exits." ) async def execute( @@ -423,6 +438,8 @@ class WriteStdinTool(Tool): close_stdin: bool = False, terminate: bool = False, yield_time_ms: int | None = None, + wait_for: str | None = None, + wait_timeout_ms: int | None = None, max_output_chars: int | None = None, max_output_tokens: int | None = None, **kwargs: Any, @@ -430,18 +447,34 @@ class WriteStdinTool(Tool): try: if max_output_chars is None: max_output_chars = max_output_tokens + output_limit = clamp_session_int( + max_output_chars, + DEFAULT_MAX_OUTPUT_CHARS, + 1000, + MAX_OUTPUT_CHARS, + ) + if wait_for: + return await self._wait_for_output( + session_id=session_id, + chars=chars, + close_stdin=close_stdin, + terminate=terminate, + wait_for=wait_for, + wait_timeout_ms=clamp_session_int( + wait_timeout_ms, + DEFAULT_WAIT_FOR_MS, + 0, + MAX_WAIT_FOR_MS, + ), + max_output_chars=output_limit, + ) poll = await self._manager.write( session_id=session_id, chars=chars, close_stdin=close_stdin, terminate=terminate, yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS), - max_output_chars=clamp_session_int( - max_output_chars, - DEFAULT_MAX_OUTPUT_CHARS, - 1000, - MAX_OUTPUT_CHARS, - ), + max_output_chars=output_limit, ) return format_session_poll(session_id, poll) except KeyError: @@ -449,6 +482,47 @@ class WriteStdinTool(Tool): except Exception as exc: return f"Error writing to exec session: {exc}" + async def _wait_for_output( + self, + *, + session_id: str, + chars: str | None, + close_stdin: bool, + terminate: bool, + wait_for: str, + wait_timeout_ms: int, + max_output_chars: int, + ) -> str: + deadline = time.monotonic() + (wait_timeout_ms / 1000) + aggregate: list[str] = [] + first = True + poll: _SessionPoll | None = None + + while True: + remaining_ms = max(0, int((deadline - time.monotonic()) * 1000)) + step_ms = min(500, remaining_ms) + poll = await self._manager.write( + session_id=session_id, + chars=chars if first else None, + close_stdin=close_stdin if first else False, + terminate=terminate if first else False, + yield_time_ms=step_ms, + max_output_chars=max_output_chars, + ) + first = False + if poll.output: + aggregate.append(poll.output) + joined = "".join(aggregate) + if wait_for in joined: + poll.output = joined + return format_session_poll(session_id, poll) + if poll.done or remaining_ms <= 0: + poll.output = "".join(aggregate) + result = format_session_poll(session_id, poll) + if wait_for not in poll.output: + result += f"\nWait target not observed: {wait_for!r}" + return result + @tool_parameters(tool_parameters_schema()) class ListExecSessionsTool(Tool): diff --git a/nanobot/utils/file_edit_events.py b/nanobot/utils/file_edit_events.py index c11e8ae60..acef725b0 100644 --- a/nanobot/utils/file_edit_events.py +++ b/nanobot/utils/file_edit_events.py @@ -219,6 +219,8 @@ def _resolve_apply_patch_paths( patch = params.get("patch") if not isinstance(patch, str) or not patch.strip(): return [] + if params.get("dry_run") is True: + return [] try: from nanobot.agent.tools.apply_patch import _parse_patch diff --git a/tests/tools/test_apply_patch_tool.py b/tests/tools/test_apply_patch_tool.py index ea98794b3..a356b83e8 100644 --- a/tests/tools/test_apply_patch_tool.py +++ b/tests/tools/test_apply_patch_tool.py @@ -40,9 +40,58 @@ def test_apply_patch_updates_multiple_hunks(tmp_path): )) assert "update multi.txt" in result + assert "(+2/-2)" in result assert target.read_text() == "line1\nchanged2\nline3\nchanged4\n" +def test_apply_patch_dry_run_validates_without_writing(tmp_path): + target = tmp_path / "dry.txt" + target.write_text("before\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Update File: dry.txt +@@ +-before ++after +*** Add File: added.txt ++new +*** End Patch +""", + dry_run=True, + )) + + assert "Patch dry-run succeeded" in result + assert "- update dry.txt (+1/-1)" in result + assert "- add added.txt (+1/-0)" in result + assert target.read_text() == "before\n" + assert not (tmp_path / "added.txt").exists() + + +def test_apply_patch_applies_repeated_update_sections_sequentially(tmp_path): + target = tmp_path / "repeat.txt" + target.write_text("one\ntwo\nthree\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Update File: repeat.txt +@@ +-one ++ONE +*** Update File: repeat.txt +@@ +-three ++THREE +*** End Patch +""" + )) + + assert result.count("update repeat.txt") == 2 + assert target.read_text() == "ONE\ntwo\nTHREE\n" + + def test_apply_patch_ignores_standard_no_newline_marker(tmp_path): target = tmp_path / "plain.txt" target.write_text("before") diff --git a/tests/tools/test_exec_session_tools.py b/tests/tools/test_exec_session_tools.py index 52f72b556..ad2506739 100644 --- a/tests/tools/test_exec_session_tools.py +++ b/tests/tools/test_exec_session_tools.py @@ -245,6 +245,65 @@ def test_write_stdin_preserves_completed_session_output_until_polled(tmp_path): assert "Exit code: 0" in final +def test_write_stdin_can_wait_for_expected_output(tmp_path): + async def run() -> tuple[str, str, str]: + manager = ExecSessionManager() + exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + command = _python_command( + "import time; print('booting', flush=True); " + "time.sleep(0.4); print('ready', flush=True); time.sleep(5)" + ) + + initial = await exec_tool.execute(command=command, yield_time_ms=100) + sid = _session_id(initial) + waited = await stdin_tool.execute( + session_id=sid, + wait_for="ready", + wait_timeout_ms=3000, + yield_time_ms=0, + ) + cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0) + return initial, waited, cleanup + + initial, waited, cleanup = asyncio.run(run()) + + assert "Process running" in initial + assert "booting" in waited + assert "ready" in waited + assert "Wait target not observed" not in waited + assert "Session terminated." in cleanup + + +def test_write_stdin_wait_for_reports_timeout_without_killing_session(tmp_path): + async def run() -> tuple[str, str, str]: + manager = ExecSessionManager() + exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + command = _python_command( + "import time; print('booting', flush=True); time.sleep(5)" + ) + + initial = await exec_tool.execute(command=command, yield_time_ms=100) + sid = _session_id(initial) + waited = await stdin_tool.execute( + session_id=sid, + wait_for="never-ready", + wait_timeout_ms=200, + yield_time_ms=0, + ) + cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0) + return initial, waited, cleanup + + initial, waited, cleanup = asyncio.run(run()) + + assert "Process running" in initial + assert "booting" in waited + assert "Process running" in waited + assert "Wait target not observed: 'never-ready'" in waited + assert "Session terminated." in cleanup + + def test_exec_session_mode_reuses_exec_safety_guard(tmp_path): manager = ExecSessionManager() tool = ExecTool( diff --git a/tests/utils/test_file_edit_events.py b/tests/utils/test_file_edit_events.py index 7cc8a59fa..3ac4dc929 100644 --- a/tests/utils/test_file_edit_events.py +++ b/tests/utils/test_file_edit_events.py @@ -125,6 +125,28 @@ def test_apply_patch_prepares_trackers_for_each_touched_file(tmp_path: Path) -> assert (by_path["src/delete_me.py"]["added"], by_path["src/delete_me.py"]["deleted"]) == (0, 1) +def test_apply_patch_dry_run_does_not_prepare_file_edit_trackers(tmp_path: Path) -> None: + (tmp_path / "file.txt").write_text("old\n", encoding="utf-8") + + trackers = prepare_file_edit_trackers( + call_id="call-patch", + tool_name="apply_patch", + tool=None, + workspace=tmp_path, + params={ + "dry_run": True, + "patch": """*** Begin Patch +*** Update File: file.txt +@@ +-old ++new +*** End Patch""", + }, + ) + + assert trackers == [] + + def test_oversized_write_file_end_uses_known_content_for_exact_count(tmp_path: Path) -> None: target = tmp_path / "large.txt" params = {"path": "large.txt", "content": "x" * (2 * 1024 * 1024 + 1)} From 8141df0d3f93a77c3af432c244e9f4b8c51a2e54 Mon Sep 17 00:00:00 2001 From: Xubin Ren <52506698+Re-bin@users.noreply.github.com> Date: Thu, 21 May 2026 01:32:27 +0800 Subject: [PATCH 26/54] fix(tools): stabilize session output test --- tests/tools/test_exec_session_tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tools/test_exec_session_tools.py b/tests/tools/test_exec_session_tools.py index ad2506739..76b3c9781 100644 --- a/tests/tools/test_exec_session_tools.py +++ b/tests/tools/test_exec_session_tools.py @@ -269,7 +269,7 @@ def test_write_stdin_can_wait_for_expected_output(tmp_path): initial, waited, cleanup = asyncio.run(run()) assert "Process running" in initial - assert "booting" in waited + assert "booting" in initial + waited assert "ready" in waited assert "Wait target not observed" not in waited assert "Session terminated." in cleanup @@ -298,7 +298,7 @@ def test_write_stdin_wait_for_reports_timeout_without_killing_session(tmp_path): initial, waited, cleanup = asyncio.run(run()) assert "Process running" in initial - assert "booting" in waited + assert "booting" in initial + waited assert "Process running" in waited assert "Wait target not observed: 'never-ready'" in waited assert "Session terminated." in cleanup From 77ec55bf8ea5c408e0ee67b923e78b18ba4a8c84 Mon Sep 17 00:00:00 2001 From: chengyongru Date: Thu, 21 May 2026 10:37:10 +0800 Subject: [PATCH 27/54] fix(provider): deduplicate streaming tool_call_ids for parallel calls --- nanobot/providers/openai_compat_provider.py | 9 +++++++++ nanobot/utils/file_edit_events.py | 6 ++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index b8112b529..03ab35a0e 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -1097,6 +1097,15 @@ class OpenAICompatProvider(LLMProvider): if delta: _accum_legacy_function_call(getattr(delta, "function_call", None)) + # Some providers (e.g. Zhipu/GLM) reuse the same tool_call id for + # parallel tool calls in streaming mode. Deduplicate before building + # the response so downstream tool messages don't collide. + _seen_tc_ids: set[str] = set() + for b in tc_bufs.values(): + if not b["id"] or b["id"] in _seen_tc_ids: + b["id"] = _short_tool_id() + _seen_tc_ids.add(b["id"]) + return LLMResponse( content="".join(content_parts) or None, tool_calls=[ diff --git a/nanobot/utils/file_edit_events.py b/nanobot/utils/file_edit_events.py index b5d2f6d73..ff2594435 100644 --- a/nanobot/utils/file_edit_events.py +++ b/nanobot/utils/file_edit_events.py @@ -367,12 +367,14 @@ class StreamingFileEditTracker: def apply_final_call_ids(self, final_tool_calls: list[Any]) -> None: """Keep final start/end events keyed to any earlier streamed placeholder.""" + used_canonicals: set[str] = set() for tool_call in final_tool_calls: canonical = self.canonical_call_id_for(tool_call) - if canonical: + if canonical and canonical not in used_canonicals: try: tool_call.id = canonical - except Exception: + used_canonicals.add(canonical) + except (AttributeError, TypeError): pass def canonical_call_id_for(self, tool_call: Any) -> str | None: From 3d3ebf11109341f9c26125ee8989b168b0b88cb2 Mon Sep 17 00:00:00 2001 From: Xubin Ren <52506698+Re-bin@users.noreply.github.com> Date: Thu, 21 May 2026 12:20:19 +0800 Subject: [PATCH 28/54] test(provider): cover duplicate streaming tool call ids --- nanobot/utils/file_edit_events.py | 1 - tests/providers/test_custom_provider.py | 29 ++++++++++++++++++ tests/utils/test_file_edit_events.py | 39 ++++++++++++++++++++++++- 3 files changed, 67 insertions(+), 2 deletions(-) diff --git a/nanobot/utils/file_edit_events.py b/nanobot/utils/file_edit_events.py index ff2594435..3b9ec8da8 100644 --- a/nanobot/utils/file_edit_events.py +++ b/nanobot/utils/file_edit_events.py @@ -10,7 +10,6 @@ from dataclasses import dataclass, field from pathlib import Path from typing import Any, Awaitable, Callable - TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file", "notebook_edit"}) _MAX_SNAPSHOT_BYTES = 2 * 1024 * 1024 _LIVE_EMIT_INTERVAL_S = 0.18 diff --git a/tests/providers/test_custom_provider.py b/tests/providers/test_custom_provider.py index 85314dc79..ee1f9a090 100644 --- a/tests/providers/test_custom_provider.py +++ b/tests/providers/test_custom_provider.py @@ -56,6 +56,35 @@ def test_custom_provider_parse_chunks_accepts_plain_text_chunks() -> None: assert result.content == "hello world" +def test_custom_provider_parse_chunks_deduplicates_parallel_tool_call_ids() -> None: + chunks = [{ + "choices": [{ + "finish_reason": "tool_calls", + "delta": { + "tool_calls": [ + { + "index": 0, + "id": "call_dup", + "function": {"name": "read_file", "arguments": '{"path":"a.txt"}'}, + }, + { + "index": 1, + "id": "call_dup", + "function": {"name": "read_file", "arguments": '{"path":"b.txt"}'}, + }, + ], + }, + }], + }] + + result = OpenAICompatProvider._parse_chunks(chunks) + ids = [tool_call.id for tool_call in result.tool_calls or []] + + assert ids[0] == "call_dup" + assert len(ids) == 2 + assert len(set(ids)) == 2 + + def test_local_provider_502_error_includes_reachability_hint() -> None: spec = find_by_name("ollama") with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): diff --git a/tests/utils/test_file_edit_events.py b/tests/utils/test_file_edit_events.py index cdaae5167..768b8d1f6 100644 --- a/tests/utils/test_file_edit_events.py +++ b/tests/utils/test_file_edit_events.py @@ -5,12 +5,12 @@ from pathlib import Path from types import SimpleNamespace from nanobot.utils.file_edit_events import ( + StreamingFileEditTracker, build_file_edit_end_event, build_file_edit_start_event, line_diff_stats, prepare_file_edit_tracker, read_file_snapshot, - StreamingFileEditTracker, ) @@ -308,6 +308,43 @@ def test_streaming_tracker_applies_canonical_call_id_to_final_tool(tmp_path: Pat asyncio.run(run()) +def test_streaming_tracker_does_not_restore_duplicate_canonical_ids(tmp_path: Path) -> None: + events: list[dict] = [] + + async def emit(batch: list[dict]) -> None: + events.extend(batch) + + async def run() -> None: + tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit) + await tracker.update({ + "index": 0, + "call_id": "call_dup", + "name": "write_file", + "arguments_delta": '{"path":"a.md","content":"one\\n"}', + }) + await tracker.update({ + "index": 1, + "call_id": "call_dup", + "name": "write_file", + "arguments_delta": '{"path":"b.md","content":"two\\n"}', + }) + final_a = SimpleNamespace( + id="call_dup", + name="write_file", + arguments={"path": "a.md", "content": "one\n"}, + ) + final_b = SimpleNamespace( + id="call_unique", + name="write_file", + arguments={"path": "b.md", "content": "two\n"}, + ) + tracker.apply_final_call_ids([final_a, final_b]) + assert final_a.id == "call_dup" + assert final_b.id == "call_unique" + + asyncio.run(run()) + + def test_streaming_edit_file_tracker_flushes_small_pending_count(tmp_path: Path) -> None: target = tmp_path / "small.py" target.write_text("old\n", encoding="utf-8") From de0a8f5e41e0b055f60d2c21208f462c645f6379 Mon Sep 17 00:00:00 2001 From: hanyuanling Date: Thu, 21 May 2026 10:59:15 +0800 Subject: [PATCH 29/54] fix(webui): keep new chat during session refresh --- webui/src/hooks/useSessions.ts | 16 +++++++++- webui/src/tests/useSessions.test.tsx | 47 ++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/webui/src/hooks/useSessions.ts b/webui/src/hooks/useSessions.ts index c22751c65..7b468fc89 100644 --- a/webui/src/hooks/useSessions.ts +++ b/webui/src/hooks/useSessions.ts @@ -27,13 +27,25 @@ export function useSessions(): { const [loading, setLoading] = useState(true); const [error, setError] = useState(null); const tokenRef = useRef(token); + const optimisticKeysRef = useRef>(new Set()); tokenRef.current = token; const refresh = useCallback(async () => { try { setLoading(true); const rows = await listSessions(tokenRef.current); - setSessions(rows); + const serverKeys = new Set(rows.map((row) => row.key)); + setSessions((prev) => [ + ...rows, + ...prev.filter( + (session) => + optimisticKeysRef.current.has(session.key) && + !serverKeys.has(session.key), + ), + ]); + for (const key of Array.from(optimisticKeysRef.current)) { + if (serverKeys.has(key)) optimisticKeysRef.current.delete(key); + } setError(null); } catch (e) { const msg = @@ -57,6 +69,7 @@ export function useSessions(): { const createChat = useCallback(async (): Promise => { const chatId = await client.newChat(); const key = `websocket:${chatId}`; + optimisticKeysRef.current.add(key); // Optimistic insert; a subsequent refresh will replace it with the // authoritative row once the server persists the session. setSessions((prev) => [ @@ -77,6 +90,7 @@ export function useSessions(): { const deleteChat = useCallback( async (key: string) => { await apiDeleteSession(tokenRef.current, key); + optimisticKeysRef.current.delete(key); setSessions((prev) => prev.filter((s) => s.key !== key)); }, [], diff --git a/webui/src/tests/useSessions.test.tsx b/webui/src/tests/useSessions.test.tsx index 72df813e0..8e76e697e 100644 --- a/webui/src/tests/useSessions.test.tsx +++ b/webui/src/tests/useSessions.test.tsx @@ -157,6 +157,53 @@ describe("useSessions", () => { expect(api.listSessions).toHaveBeenCalledTimes(2); }); + it("keeps a newly created chat visible until the server session list catches up", async () => { + vi.mocked(api.listSessions) + .mockResolvedValueOnce([]) + .mockResolvedValueOnce([]) + .mockResolvedValueOnce([ + { + key: "websocket:chat-new", + channel: "websocket", + chatId: "chat-new", + createdAt: "2026-05-20T10:00:00Z", + updatedAt: "2026-05-20T10:01:00Z", + title: "Generated title", + preview: "First message", + }, + ]); + const client = fakeClient(); + client.newChat.mockResolvedValue("chat-new"); + + const { result } = renderHook(() => useSessions(), { + wrapper: wrap(client), + }); + + await waitFor(() => expect(result.current.loading).toBe(false)); + expect(result.current.sessions).toEqual([]); + + await act(async () => { + await result.current.createChat(); + }); + + expect(result.current.sessions.map((s) => s.key)).toEqual(["websocket:chat-new"]); + + await act(async () => { + await result.current.refresh(); + }); + + expect(result.current.sessions.map((s) => s.key)).toEqual(["websocket:chat-new"]); + expect(result.current.sessions[0]?.preview).toBe(""); + + await act(async () => { + await result.current.refresh(); + }); + + expect(result.current.sessions.map((s) => s.key)).toEqual(["websocket:chat-new"]); + expect(result.current.sessions[0]?.preview).toBe("First message"); + expect(result.current.sessions[0]?.title).toBe("Generated title"); + }); + it("passes through WebUI transcript user media as images and media", async () => { vi.mocked(api.fetchWebuiThread).mockResolvedValue({ schemaVersion: 3, From e2b51fa5dca86db7d0638848f4c72d390842c28a Mon Sep 17 00:00:00 2001 From: chengyongru Date: Tue, 19 May 2026 16:00:37 +0800 Subject: [PATCH 30/54] fix(weixin): prevent silent message drops from poll exceptions and expired tokens - Remove suppress(Exception) from poll loop and message processing; add logger.exception so inbound errors are visible. - Check both ret and errcode on send to avoid silent drops when iLink returns ret != 0 with errcode == 0. - Proactively refresh context_token via getconfig before sending if the cached token is older than 60s. This prevents message loss on long agent turns and cron pushes without relying on complex retry logic. Refs: openclaw/openclaw#61174, NousResearch/hermes-agent#21011 --- nanobot/channels/weixin.py | 169 ++++++++- tests/channels/test_weixin_channel.py | 526 ++++++++++++++++++++++++++ 2 files changed, 689 insertions(+), 6 deletions(-) diff --git a/nanobot/channels/weixin.py b/nanobot/channels/weixin.py index 41390f8b3..a75c897f4 100644 --- a/nanobot/channels/weixin.py +++ b/nanobot/channels/weixin.py @@ -79,6 +79,12 @@ BASE_INFO: dict[str, str] = {"channel_version": WEIXIN_CHANNEL_VERSION} ERRCODE_SESSION_EXPIRED = -14 SESSION_PAUSE_DURATION_S = 60 * 60 +# iLink context_token is observed to expire server-side after ~90-160s of +# agent inactivity (openclaw/openclaw#61174). Proactively refresh before +# sending if the cached token is older than this threshold. +CONTEXT_TOKEN_MAX_AGE_S = 60 + + # Retry constants (matching the reference plugin's monitor.ts) MAX_CONSECUTIVE_FAILURES = 3 BACKOFF_DELAY_S = 30 @@ -159,6 +165,8 @@ class WeixinChannel(BaseChannel): self._session_pause_until: float = 0.0 self._typing_tasks: dict[str, asyncio.Task] = {} self._typing_tickets: dict[str, dict[str, Any]] = {} + self._context_token_at: dict[str, float] = {} + self._pending_tool_hints: dict[str, list[str]] = {} # ------------------------------------------------------------------ # State persistence @@ -486,6 +494,7 @@ class WeixinChannel(BaseChannel): except Exception: if not self._running: break + self.logger.exception("WeChat poll loop error") consecutive_failures += 1 if consecutive_failures >= MAX_CONSECUTIVE_FAILURES: consecutive_failures = 0 @@ -495,6 +504,7 @@ class WeixinChannel(BaseChannel): async def stop(self) -> None: self._running = False + self._pending_tool_hints.clear() if self._poll_task and not self._poll_task.done(): self._poll_task.cancel() for chat_id in list(self._typing_tasks): @@ -545,6 +555,7 @@ class WeixinChannel(BaseChannel): # Check for API-level errors (monitor.ts checks both ret and errcode) ret = data.get("ret", 0) errcode = data.get("errcode", 0) + is_error = (ret is not None and ret != 0) or (errcode is not None and errcode != 0) if is_error: @@ -575,8 +586,10 @@ class WeixinChannel(BaseChannel): # Process messages (WeixinMessage[] from types.ts) msgs: list[dict] = data.get("msgs", []) or [] for msg in msgs: - with suppress(Exception): + try: await self._process_message(msg) + except Exception: + self.logger.exception("Failed to process WeChat message") # ------------------------------------------------------------------ # Inbound message processing (matches inbound.ts + process-message.ts) @@ -610,6 +623,7 @@ class WeixinChannel(BaseChannel): ctx_token = msg.get("context_token", "") if ctx_token: self._context_tokens[from_user_id] = ctx_token + self._context_token_at[from_user_id] = time.time() self._save_state() # Parse item_list (WeixinMessage.item_list — types.ts:161) @@ -915,6 +929,99 @@ class WeixinChannel(BaseChannel): } return "" + async def _refresh_context_token_if_stale( + self, chat_id: str, context_token: str + ) -> str: + """Return a fresh context_token if the cached one is too old. + + iLink context_token expires server-side after a short idle period + (empirically ~90s). Proactively refreshing before sending prevents + silent message loss on long agent turns or cron pushes. + """ + if not context_token: + return context_token + + now = time.time() + cached_at = self._context_token_at.get(chat_id, 0) + age = now - cached_at + + if age < CONTEXT_TOKEN_MAX_AGE_S: + return context_token + + self.logger.debug( + "WeChat context_token for {} is {:.0f}s old; refreshing via getconfig", + chat_id, + age, + ) + + body: dict[str, Any] = { + "ilink_user_id": chat_id, + "context_token": context_token, + "base_info": BASE_INFO, + } + try: + data = await self._api_post("ilink/bot/getconfig", body) + except Exception as e: + self.logger.warning("WeChat getconfig failed for {}: {}", chat_id, e) + return context_token + + if data.get("ret", 0) != 0: + self.logger.warning( + "WeChat getconfig returned ret={} for {}: {}", + data.get("ret"), + chat_id, + data.get("errmsg", ""), + ) + return context_token + + new_token = str(data.get("context_token", "") or "") + if new_token and new_token != context_token: + self.logger.info( + "WeChat context_token refreshed for {} (age {:.0f}s -> fresh)", + chat_id, + age, + ) + self._context_tokens[chat_id] = new_token + self._context_token_at[chat_id] = now + self._save_state() + return new_token + + return context_token + + async def _flush_tool_hints(self, chat_id: str) -> None: + """Send any buffered tool hints for *chat_id* as a single message. + + Tool hints are coalesced to reduce message count and avoid hitting the + WeChat iLink rate limit (~7 msgs / 5 min). Failures are logged but + not raised so that the main message send is never blocked. + """ + hints = self._pending_tool_hints.pop(chat_id, None) + if not hints: + return + + self.logger.info( + "Flushing {} buffered tool hint(s) for {}", + len(hints), + chat_id, + ) + + ctx_token = self._context_tokens.get(chat_id, "") + ctx_token = await self._refresh_context_token_if_stale(chat_id, ctx_token) + if not ctx_token: + self.logger.warning( + "Dropped {} buffered tool hint(s) for {}: no context_token", + len(hints), + chat_id, + ) + return + + try: + await self._send_text(chat_id, "\n\n".join(hints), ctx_token) + except Exception: + self.logger.exception( + "Failed to flush buffered tool hints for {}", chat_id + ) + async def _send_typing(self, user_id: str, typing_ticket: str, status: int) -> None: """Best-effort sendtyping wrapper.""" if not typing_ticket: @@ -944,11 +1051,47 @@ class WeixinChannel(BaseChannel): self._assert_session_active() is_progress = bool((msg.metadata or {}).get("_progress", False)) + + # Buffer tool hints to coalesce consecutive ones and avoid burning + # WeChat iLink rate-limit quota (~7 msgs / 5 min). + if is_progress and (msg.metadata or {}).get("_tool_hint"): + if not self.send_tool_hints: + return + self._pending_tool_hints.setdefault(msg.chat_id, []).append(msg.content) + self.logger.debug( + "Buffered tool hint for {} (count={})", + msg.chat_id, + len(self._pending_tool_hints[msg.chat_id]), + ) + return + + # Reasoning deltas are invisible in WeChat (there is no reasoning + # UI). Skip them entirely — do not send and do not flush buffer. + if is_progress and (msg.metadata or {}).get("_reasoning_delta"): + self.logger.debug( + "Dropped invisible reasoning delta for {}", msg.chat_id + ) + return + + content = msg.content.strip() + + # Empty progress messages (e.g. after_iteration tool_events) must + # NOT act as separators — they have no visible content. + if is_progress and not content and not (msg.media or []): + self.logger.debug( + "Skipped empty progress message for {} (no visible content)", + msg.chat_id, + ) + return + + # Flush buffered hints before sending any visible message. + await self._flush_tool_hints(msg.chat_id) + if not is_progress: await self._stop_typing(msg.chat_id, clear_remote=True) - content = msg.content.strip() ctx_token = self._context_tokens.get(msg.chat_id, "") + ctx_token = await self._refresh_context_token_if_stale(msg.chat_id, ctx_token) if not ctx_token: raise RuntimeError( f"WeChat context_token missing for chat_id={msg.chat_id}, cannot send" @@ -1037,6 +1180,18 @@ class WeixinChannel(BaseChannel): with suppress(Exception): await self._send_typing(msg.chat_id, typing_ticket, TYPING_STATUS_CANCEL) + async def send_delta( + self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None + ) -> None: + """Weixin iLink does not support native streaming deltas. + + We only hook ``_stream_end`` so buffered tool hints are flushed even + when the final answer carries the ``_streamed`` flag and bypasses + :meth:`send`. + """ + if metadata and metadata.get("_stream_end"): + await self._flush_tool_hints(chat_id) + async def _start_typing(self, chat_id: str, context_token: str = "") -> None: """Start typing indicator immediately when a message is received.""" if not self._client or not self._token or not chat_id: @@ -1120,10 +1275,11 @@ class WeixinChannel(BaseChannel): } data = await self._api_post("ilink/bot/sendmessage", body) + ret = data.get("ret", 0) errcode = data.get("errcode", 0) - if errcode and errcode != 0: + if (ret is not None and ret != 0) or (errcode is not None and errcode != 0): raise RuntimeError( - f"WeChat send text error (code {errcode}): {data.get('errmsg', '')}" + f"WeChat send text error (ret={ret}, errcode={errcode}): {data.get('errmsg', '')}" ) async def _send_media_file( @@ -1270,10 +1426,11 @@ class WeixinChannel(BaseChannel): } data = await self._api_post("ilink/bot/sendmessage", body) + ret = data.get("ret", 0) errcode = data.get("errcode", 0) - if errcode and errcode != 0: + if (ret is not None and ret != 0) or (errcode is not None and errcode != 0): raise RuntimeError( - f"WeChat send media error (code {errcode}): {data.get('errmsg', '')}" + f"WeChat send media error (ret={ret}, errcode={errcode}): {data.get('errmsg', '')}" ) diff --git a/tests/channels/test_weixin_channel.py b/tests/channels/test_weixin_channel.py index a695ba936..3d3606e75 100644 --- a/tests/channels/test_weixin_channel.py +++ b/tests/channels/test_weixin_channel.py @@ -1,6 +1,7 @@ import asyncio import json import tempfile +import time from pathlib import Path from types import SimpleNamespace from unittest.mock import AsyncMock @@ -374,6 +375,7 @@ async def test_send_uses_typing_start_and_cancel_when_ticket_available() -> None channel._client = object() channel._token = "token" channel._context_tokens["wx-user"] = "ctx-typing" + channel._context_token_at["wx-user"] = time.time() channel._send_text = AsyncMock() channel._api_post = AsyncMock( side_effect=[ @@ -402,6 +404,7 @@ async def test_send_still_sends_text_when_typing_ticket_missing() -> None: channel._client = object() channel._token = "token" channel._context_tokens["wx-user"] = "ctx-no-ticket" + channel._context_token_at["wx-user"] = time.time() channel._send_text = AsyncMock() channel._api_post = AsyncMock(return_value={"ret": 1, "errmsg": "no config"}) @@ -1254,3 +1257,526 @@ async def test_send_text_succeeds_on_zero_errcode() -> None: await channel._send_text("wx-user", "hello", "ctx-ok") channel._api_post.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_text_raises_on_nonzero_ret_even_when_errcode_zero() -> None: + """_send_text must raise when the API returns ret != 0, even if errcode is 0. + + The iLink API signals failure through either field. Checking only errcode + caused silent message drops (responses generated but never delivered). + """ + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._api_post = AsyncMock( + return_value={"ret": -100, "errcode": 0, "errmsg": "internal error"} + ) + + with pytest.raises(RuntimeError, match="WeChat send text error.*ret=-100.*errcode=0"): + await channel._send_text("wx-user", "hello", "ctx-ok") + + channel._api_post.assert_awaited_once() + + +# --------------------------------------------------------------------------- +# Tests for _poll_once not silently dropping messages on processing errors +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_poll_once_logs_exception_on_process_message_failure(monkeypatch) -> None: + """When _process_message raises, _poll_once must log the error and continue + processing remaining messages instead of silently swallowing the exception.""" + channel, _bus = _make_channel() + channel._client = SimpleNamespace(timeout=None) + channel._token = "token" + channel._get_updates_buf = "old-buf" + + calls = [] + logged_messages: list[str] = [] + + async def _failing_process(msg: dict) -> None: + calls.append(msg.get("message_id")) + if msg.get("message_id") == "msg-1": + raise RuntimeError("processing failed") + + channel._process_message = _failing_process # type: ignore[method-assign] + + monkeypatch.setattr( + channel.logger, + "exception", + lambda message, *args, **kwargs: logged_messages.append(str(message)), + ) + + channel._api_post = AsyncMock( # type: ignore[method-assign] + return_value={ + "ret": 0, + "errcode": 0, + "get_updates_buf": "new-buf", + "msgs": [ + {"message_id": "msg-1", "message_type": 1}, + {"message_id": "msg-2", "message_type": 1}, + ], + } + ) + + await channel._poll_once() + + # Both messages should have been attempted + assert calls == ["msg-1", "msg-2"] + # Buffer should still advance (already updated before processing) + assert channel._get_updates_buf == "new-buf" + # Error should be logged + assert any("Failed to process WeChat message" in m for m in logged_messages) + + +@pytest.mark.asyncio +async def test_poll_loop_logs_exception_and_continues_on_poll_failure(monkeypatch) -> None: + """When _poll_once raises a non-timeout exception, the start() loop must log + the error and continue polling instead of exiting silently.""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel.config.token = "token" # skip QR login in start() + channel._running = True + + call_count = 0 + logged_messages: list[str] = [] + + async def _failing_poll() -> None: + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RuntimeError("poll exploded") + channel._running = False # Stop after second call + + channel._poll_once = _failing_poll # type: ignore[method-assign] + + monkeypatch.setattr( + channel.logger, + "exception", + lambda message, *args, **kwargs: logged_messages.append(str(message)), + ) + + # Use a tiny retry delay so the test finishes quickly + original_retry = weixin_mod.RETRY_DELAY_S + weixin_mod.RETRY_DELAY_S = 0.01 + try: + await channel.start() + finally: + weixin_mod.RETRY_DELAY_S = original_retry + + assert call_count == 2 + assert any("WeChat poll loop error" in m for m in logged_messages) + + +# --------------------------------------------------------------------------- +# Tool-hint buffering +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_buffer_single_tool_hint_not_sent_immediately() -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel.send_tool_hints = True + channel._context_tokens["wx-user"] = "ctx-1" + channel._context_token_at["wx-user"] = time.time() + channel._send_text = AsyncMock() + + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "Using tool", + "media": [], + "metadata": {"_progress": True, "_tool_hint": True}, + }, + )() + ) + + channel._send_text.assert_not_awaited() + assert channel._pending_tool_hints["wx-user"] == ["Using tool"] + + +@pytest.mark.asyncio +async def test_buffer_multiple_tool_hints_flushed_on_final_answer() -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel.send_tool_hints = True + channel._context_tokens["wx-user"] = "ctx-1" + channel._context_token_at["wx-user"] = time.time() + channel._send_text = AsyncMock() + + for hint in ["tool1", "tool2"]: + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": hint, + "media": [], + "metadata": {"_progress": True, "_tool_hint": True}, + }, + )() + ) + + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "Done", + "media": [], + "metadata": {}, + }, + )() + ) + + assert channel._send_text.await_count == 2 + channel._send_text.assert_any_await("wx-user", "tool1\n\ntool2", "ctx-1") + channel._send_text.assert_any_await("wx-user", "Done", "ctx-1") + assert "wx-user" not in channel._pending_tool_hints + + +@pytest.mark.asyncio +async def test_thought_progress_flushes_tool_hints() -> None: + """Thoughts are visible progress messages and must act as separators, + flushing buffered tool hints before they are sent.""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel.send_tool_hints = True + channel._context_tokens["wx-user"] = "ctx-1" + channel._context_token_at["wx-user"] = time.time() + channel._send_text = AsyncMock() + + # Buffer a tool hint + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "search 'foo'", + "media": [], + "metadata": {"_progress": True, "_tool_hint": True}, + }, + )() + ) + + # Send a thought — progress but not a tool_hint. + # It must act as a separator and flush the buffered hint. + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "Let me think...", + "media": [], + "metadata": {"_progress": True}, + }, + )() + ) + + # The buffered hint was flushed before the thought was sent. + channel._send_text.assert_any_await("wx-user", "search 'foo'", "ctx-1") + channel._send_text.assert_any_await("wx-user", "Let me think...", "ctx-1") + assert "wx-user" not in channel._pending_tool_hints + + # Final answer arrives with nothing left to flush. + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "Done", + "media": [], + "metadata": {}, + }, + )() + ) + + assert channel._send_text.await_count == 3 + channel._send_text.assert_any_await("wx-user", "Done", "ctx-1") + + +@pytest.mark.asyncio +async def test_reasoning_delta_does_not_flush_tool_hints() -> None: + """Reasoning deltas are invisible in WeChat and must NOT flush buffered + tool hints — otherwise hints separated only by hidden reasoning would + fail to coalesce.""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel.send_tool_hints = True + channel._context_tokens["wx-user"] = "ctx-1" + channel._context_token_at["wx-user"] = time.time() + channel._send_text = AsyncMock() + + # Buffer a tool hint + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "search 'foo'", + "media": [], + "metadata": {"_progress": True, "_tool_hint": True}, + }, + )() + ) + + # Send a reasoning delta — invisible in WeChat, must NOT flush + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "Thinking step 1...", + "media": [], + "metadata": {"_progress": True, "_reasoning_delta": True}, + }, + )() + ) + + # Reasoning is invisible; hint stays buffered, _send_text not called + channel._send_text.assert_not_awaited() + assert channel._pending_tool_hints["wx-user"] == ["search 'foo'"] + + # Final answer flushes the buffered hint + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "Done", + "media": [], + "metadata": {}, + }, + )() + ) + + channel._send_text.assert_any_await("wx-user", "search 'foo'", "ctx-1") + channel._send_text.assert_any_await("wx-user", "Done", "ctx-1") + assert "wx-user" not in channel._pending_tool_hints + + +@pytest.mark.asyncio +async def test_empty_progress_message_does_not_flush_tool_hints() -> None: + """Empty progress messages (e.g. after_iteration tool_events) have no + visible content and must NOT act as separators.""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel.send_tool_hints = True + channel._context_tokens["wx-user"] = "ctx-1" + channel._context_token_at["wx-user"] = time.time() + channel._send_text = AsyncMock() + + # Buffer a tool hint + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "search 'foo'", + "media": [], + "metadata": {"_progress": True, "_tool_hint": True}, + }, + )() + ) + + # Send an empty progress message (no content, no media) + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "", + "media": [], + "metadata": {"_progress": True, "_tool_events": [{"phase": "end"}]}, + }, + )() + ) + + # Nothing should have been sent yet + channel._send_text.assert_not_awaited() + assert channel._pending_tool_hints["wx-user"] == ["search 'foo'"] + + # Final answer flushes the buffered hint + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "Done", + "media": [], + "metadata": {}, + }, + )() + ) + + channel._send_text.assert_any_await("wx-user", "search 'foo'", "ctx-1") + channel._send_text.assert_any_await("wx-user", "Done", "ctx-1") + assert "wx-user" not in channel._pending_tool_hints + + +@pytest.mark.asyncio +async def test_buffer_flush_refreshes_context_token() -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel.send_tool_hints = True + channel._context_tokens["wx-user"] = "ctx-old" + channel._context_token_at["wx-user"] = time.time() + channel._refresh_context_token_if_stale = AsyncMock(return_value="ctx-refreshed") + channel._send_text = AsyncMock() + + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "hint", + "media": [], + "metadata": {"_progress": True, "_tool_hint": True}, + }, + )() + ) + + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "Done", + "media": [], + "metadata": {}, + }, + )() + ) + + assert channel._refresh_context_token_if_stale.await_count == 2 + channel._refresh_context_token_if_stale.assert_any_await("wx-user", "ctx-old") + channel._send_text.assert_any_await("wx-user", "hint", "ctx-refreshed") + + +@pytest.mark.asyncio +async def test_buffer_flush_failure_does_not_block_final_answer() -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel.send_tool_hints = True + channel._context_tokens["wx-user"] = "ctx-1" + channel._context_token_at["wx-user"] = time.time() + channel._send_text = AsyncMock(side_effect=[RuntimeError("boom"), None]) + + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "hint", + "media": [], + "metadata": {"_progress": True, "_tool_hint": True}, + }, + )() + ) + + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "Done", + "media": [], + "metadata": {}, + }, + )() + ) + + assert channel._send_text.await_count == 2 + channel._send_text.assert_any_await("wx-user", "hint", "ctx-1") + channel._send_text.assert_any_await("wx-user", "Done", "ctx-1") + + +@pytest.mark.asyncio +async def test_buffer_flushed_on_stream_end() -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel.send_tool_hints = True + channel._context_tokens["wx-user"] = "ctx-1" + channel._context_token_at["wx-user"] = time.time() + channel._send_text = AsyncMock() + + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "hint", + "media": [], + "metadata": {"_progress": True, "_tool_hint": True}, + }, + )() + ) + + await channel.send_delta("wx-user", "", {"_stream_end": True}) + + channel._send_text.assert_awaited_once_with("wx-user", "hint", "ctx-1") + assert "wx-user" not in channel._pending_tool_hints + + +@pytest.mark.asyncio +async def test_stop_clears_buffer() -> None: + channel, _bus = _make_channel() + channel._pending_tool_hints["wx-user"] = ["hint1", "hint2"] + await channel.stop() + assert "wx-user" not in channel._pending_tool_hints + + +@pytest.mark.asyncio +async def test_send_tool_hints_false_drops_tool_hints() -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel.send_tool_hints = False + channel._send_text = AsyncMock() + + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "hint", + "media": [], + "metadata": {"_progress": True, "_tool_hint": True}, + }, + )() + ) + + channel._send_text.assert_not_awaited() + assert "wx-user" not in channel._pending_tool_hints From 44ef697aac9a85f6812604f63c210f06b1264bc0 Mon Sep 17 00:00:00 2001 From: Xubin Ren <52506698+Re-bin@users.noreply.github.com> Date: Thu, 21 May 2026 14:28:39 +0800 Subject: [PATCH 31/54] docs(tools): clarify coding tool guidance --- nanobot/agent/tools/apply_patch.py | 11 ++++--- nanobot/agent/tools/exec_session.py | 14 ++++---- nanobot/agent/tools/filesystem.py | 21 ++++++++---- nanobot/agent/tools/search.py | 9 ++++-- nanobot/agent/tools/shell.py | 6 ++-- tests/tools/test_tool_descriptions.py | 46 +++++++++++++++++++++++++++ 6 files changed, 85 insertions(+), 22 deletions(-) create mode 100644 tests/tools/test_tool_descriptions.py diff --git a/nanobot/agent/tools/apply_patch.py b/nanobot/agent/tools/apply_patch.py index 57d60f9b8..c4dbf9f9f 100644 --- a/nanobot/agent/tools/apply_patch.py +++ b/nanobot/agent/tools/apply_patch.py @@ -298,11 +298,14 @@ class ApplyPatchTool(_FsTool): @property def description(self) -> str: return ( - "Apply a structured patch for code edits. The patch must include " + "Default tool for code edits. Apply a structured patch with " "*** Begin Patch and *** End Patch. Supports Add File, Update File, " - "Delete File, and Move to. Paths must be relative. Prefer this for " - "multi-file coding changes; use edit_file for small exact replacements. " - "Set dry_run=true to validate and preview the change without writing files." + "Delete File, and Move to across one or more files. Use this for " + "multi-file changes, structural edits, generated code, or any edit " + "where a reviewable patch is clearer than an exact replacement. " + "Paths must be relative. Set dry_run=true to validate and preview " + "the change summary without writing files. Use edit_file only for " + "small exact replacements copied from read_file." ) async def execute(self, patch: str, dry_run: bool = False, **kwargs: Any) -> str: diff --git a/nanobot/agent/tools/exec_session.py b/nanobot/agent/tools/exec_session.py index c9ca0a3d0..4dadb2d36 100644 --- a/nanobot/agent/tools/exec_session.py +++ b/nanobot/agent/tools/exec_session.py @@ -424,11 +424,12 @@ class WriteStdinTool(Tool): @property def description(self) -> str: return ( - "Write text to a running exec session and return recent output. " - "Use chars='' to poll without writing. Set close_stdin=true to send EOF, " - "or terminate=true to stop the session. Use wait_for to keep polling " - "until expected output appears. Sessions finish automatically when " - "their process exits." + "Interact with a running exec session created by exec with " + "yield_time_ms. Use chars='' to poll without writing, chars to send " + "stdin, close_stdin=true to send EOF, or terminate=true to stop the " + "process. Use wait_for with wait_timeout_ms for dev servers, test " + "watchers, and prompts where you need to wait for expected output. " + "Do not use this to start new commands; start them with exec." ) async def execute( @@ -561,7 +562,8 @@ class ListExecSessionsTool(Tool): return ( "List active long-running exec sessions, including session_id, cwd, " "elapsed time, idle time, remaining timeout, and command preview. " - "Use this to recover a session_id before polling with write_stdin." + "Use this to recover a session_id after context shifts before " + "polling, writing stdin, or terminating with write_stdin." ) @property diff --git a/nanobot/agent/tools/filesystem.py b/nanobot/agent/tools/filesystem.py index 728ff9317..fa63e5f66 100644 --- a/nanobot/agent/tools/filesystem.py +++ b/nanobot/agent/tools/filesystem.py @@ -158,6 +158,9 @@ class ReadFileTool(_FsTool): "Text output format: LINE_NUM|CONTENT. " "Images return visual content for analysis. " "Supports PDF, DOCX, XLSX, PPTX documents. " + "Use find_files/list_dir first when the path is uncertain. " + "Read the relevant range before editing so replacements or patches " + "are based on current content. " "Use offset and limit for large text files. " "Use force=true to re-read content even if unchanged. " "Reads exceeding ~128K chars are truncated." @@ -384,9 +387,10 @@ class WriteFileTool(_FsTool): @property def description(self) -> str: return ( - "Write content to a file. Overwrites if the file already exists; " - "creates parent directories as needed. " - "For partial edits, prefer edit_file instead." + "Create a new file or intentionally replace an entire file with " + "the provided content. Overwrites existing files and creates parent " + "directories as needed. For code changes or partial edits, prefer " + "apply_patch; use edit_file only for small exact replacements." ) async def execute(self, path: str | None = None, content: str | None = None, **kwargs: Any) -> str: @@ -711,10 +715,13 @@ class EditFileTool(_FsTool): @property def description(self) -> str: return ( - "Edit a file by replacing old_text with new_text. " - "Tolerates minor whitespace/indentation differences and curly/straight quote mismatches. " - "If old_text matches multiple times, you must provide more context " - "or set occurrence/line_hint/replace_all. Shows a diff of the closest match on failure." + "Perform a small, exact replacement in one file by replacing " + "old_text with new_text. Use this for narrow text substitutions " + "with old_text copied from read_file. For multi-file, structural, " + "or generated code edits, prefer apply_patch. If old_text matches " + "multiple times, provide more context or set occurrence, line_hint, " + "replace_all, and expected_replacements. Shows closest-match " + "diagnostics on failure." ) @staticmethod diff --git a/nanobot/agent/tools/search.py b/nanobot/agent/tools/search.py index 52c18f16b..0febb122c 100644 --- a/nanobot/agent/tools/search.py +++ b/nanobot/agent/tools/search.py @@ -130,8 +130,10 @@ class FindFilesTool(_SearchTool): def description(self) -> str: return ( "Find files by path fragment, glob, or file type. " - "Use this before read_file when you need to locate files. " - "Returns workspace-relative paths and skips common dependency/build directories." + "Use this before read_file when you need to locate files, and " + "prefer it over shell find/ls for ordinary workspace discovery. " + "Returns workspace-relative paths and skips common dependency/build " + "directories." ) @property @@ -289,7 +291,8 @@ class GrepTool(_SearchTool): return ( "Search file contents with a regex pattern. " "Default output_mode is files_with_matches (file paths only); " - "use content mode for matching lines with context. " + "use content mode for matching lines with context. Prefer this " + "over shell grep for ordinary workspace searches. " "Skips binary and files >2 MB. Supports glob/type filtering." ) diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index 47d3e9065..090dcc716 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -211,8 +211,10 @@ class ExecTool(Tool): def description(self) -> str: return ( "Execute a shell command and return its output. " - "Prefer read_file/write_file/edit_file over cat/echo/sed, " - "and grep/glob over shell find/grep. " + "Use this for tests, builds, package commands, git commands, and " + "other process execution. Prefer read_file/find_files/grep for " + "inspection and apply_patch/write_file/edit_file for file changes " + "instead of cat, shell find/grep, echo, or sed. " "Use -y or --yes flags to avoid interactive prompts. " "For long-running or interactive commands, pass yield_time_ms; " "if the command keeps running, exec returns a session_id that can " diff --git a/tests/tools/test_tool_descriptions.py b/tests/tools/test_tool_descriptions.py new file mode 100644 index 000000000..bb7665e4e --- /dev/null +++ b/tests/tools/test_tool_descriptions.py @@ -0,0 +1,46 @@ +from nanobot.agent.tools.apply_patch import ApplyPatchTool +from nanobot.agent.tools.exec_session import ListExecSessionsTool, WriteStdinTool +from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, WriteFileTool +from nanobot.agent.tools.search import FindFilesTool, GrepTool +from nanobot.agent.tools.shell import ExecTool + + +def test_coding_tool_descriptions_steer_editing_priority() -> None: + apply_patch = ApplyPatchTool().description.lower() + edit_file = EditFileTool().description.lower() + write_file = WriteFileTool().description.lower() + + assert "default tool for code edits" in apply_patch + assert "multi-file" in apply_patch + assert "dry_run=true" in apply_patch + assert "edit_file only for small exact replacements" in apply_patch + + assert "small, exact replacement" in edit_file + assert "copied from read_file" in edit_file + assert "prefer apply_patch" in edit_file + + assert "replace an entire file" in write_file + assert "prefer apply_patch" in write_file + + +def test_coding_tool_descriptions_steer_discovery_and_shell_usage() -> None: + read_file = ReadFileTool().description.lower() + find_files = FindFilesTool().description.lower() + grep = GrepTool().description.lower() + exec_tool = ExecTool().description.lower() + write_stdin = WriteStdinTool().description.lower() + list_sessions = ListExecSessionsTool().description.lower() + + assert "find_files/list_dir first" in read_file + assert "before editing" in read_file + assert "prefer it over shell find/ls" in find_files + assert "prefer this over shell grep" in grep + + assert "tests, builds" in exec_tool + assert "prefer read_file/find_files/grep" in exec_tool + assert "apply_patch/write_file/edit_file" in exec_tool + assert "yield_time_ms" in exec_tool + + assert "do not use this to start new commands" in write_stdin + assert "wait_for" in write_stdin + assert "recover a session_id" in list_sessions From 0cd2f626c054572f9eb0787daa9c817c470b3249 Mon Sep 17 00:00:00 2001 From: olgagaga Date: Sat, 16 May 2026 12:19:30 -0400 Subject: [PATCH 32/54] fix(providers): inject OpenRouter `reasoning.effort` for thinking models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Follow-up to #3851: that PR added `extra_body.thinking={type: disabled}` for MiMo via OpenRouter, but OR doesn't forward provider-specific thinking shapes to upstream — it strips unknown extra_body fields and uses its own unified `reasoning` parameter. So MiMo via OR kept thinking despite the injection (reproduced by @ClearPlume on #3851 with identical kwargs but provider switched from openrouter → xiaomi_mimo). For known thinking-capable models (Kimi, MiMo) routed via the openrouter spec, also inject `extra_body.reasoning = {effort: }` in OR's documented enum ("none"|"minimal"|"low"|"medium"|"high"|"xhigh"). OR translates this to the upstream model's native shape. Existing tests updated to expect both fields on the OR path. The direct xiaomi_mimo and moonshot paths are unchanged (the new branch is gated on spec.name == "openrouter"). Flash and non-MiMo models on OR continue to receive no injection. --- nanobot/providers/openai_compat_provider.py | 21 ++++++++++ tests/providers/test_litellm_kwargs.py | 19 +++++++-- tests/providers/test_xiaomi_mimo_thinking.py | 44 ++++++++++++++++---- 3 files changed, 72 insertions(+), 12 deletions(-) diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index 03ab35a0e..222159dda 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -615,6 +615,27 @@ class OpenAICompatProvider(LLMProvider): {"thinking": {"type": "enabled" if thinking_enabled else "disabled"}} ) + # OpenRouter uses its own unified `reasoning` field and does not + # forward provider-specific thinking shapes (the Kimi/MiMo + # extra_body.thinking above) to upstream. Reported as the follow-up + # to #3845/#3851: MiMo via OR kept thinking despite our injection. + # For known thinking-capable models routed via OR, mirror the + # effort signal into reasoning.effort (OR's documented enum: + # "none"|"minimal"|"low"|"medium"|"high"|"xhigh"), which OR + # translates to the upstream model's native shape. + if ( + spec + and spec.name == "openrouter" + and reasoning_effort is not None + and ( + _is_kimi_thinking_model(model_name) + or _is_mimo_thinking_model(model_name) + ) + ): + kwargs.setdefault("extra_body", {}).update( + {"reasoning": {"effort": semantic_effort}} + ) + if tools: kwargs["tools"] = tools kwargs["tool_choice"] = tool_choice or "auto" diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py index 5f2ffec59..461913c93 100644 --- a/tests/providers/test_litellm_kwargs.py +++ b/tests/providers/test_litellm_kwargs.py @@ -1391,9 +1391,16 @@ def test_kimi_k25_no_extra_body_when_reasoning_effort_none() -> None: def test_kimi_k25_thinking_enabled_with_openrouter_prefix() -> None: - """OpenRouter-style model names like moonshotai/kimi-k2.5 must trigger thinking.""" + """OpenRouter-style model names like moonshotai/kimi-k2.5 must trigger thinking. + + OR drops upstream-provider `thinking` fields, so the same intent also has + to go through OR's `reasoning.effort` shape (#3851 follow-up). + """ kw = _build_kwargs_for("openrouter", "moonshotai/kimi-k2.5", reasoning_effort="medium") - assert kw.get("extra_body") == {"thinking": {"type": "enabled"}} + assert kw.get("extra_body") == { + "thinking": {"type": "enabled"}, + "reasoning": {"effort": "medium"}, + } def test_kimi_k26_thinking_enabled() -> None: @@ -1403,9 +1410,13 @@ def test_kimi_k26_thinking_enabled() -> None: def test_kimi_k26_thinking_enabled_with_openrouter_prefix() -> None: - """OpenRouter-style names like moonshotai/kimi-k2.6 must trigger thinking.""" + """OpenRouter-style names like moonshotai/kimi-k2.6 must trigger thinking + via both upstream `thinking` and OR's `reasoning.effort`.""" kw = _build_kwargs_for("openrouter", "moonshotai/kimi-k2.6", reasoning_effort="medium") - assert kw.get("extra_body") == {"thinking": {"type": "enabled"}} + assert kw.get("extra_body") == { + "thinking": {"type": "enabled"}, + "reasoning": {"effort": "medium"}, + } def test_moonshot_kimi_k26_temperature_override() -> None: diff --git a/tests/providers/test_xiaomi_mimo_thinking.py b/tests/providers/test_xiaomi_mimo_thinking.py index 68ca6dd80..43dfec537 100644 --- a/tests/providers/test_xiaomi_mimo_thinking.py +++ b/tests/providers/test_xiaomi_mimo_thinking.py @@ -142,9 +142,11 @@ def test_mimo_reasoning_effort_unset_preserves_provider_default(): def test_mimo_via_openrouter_reasoning_effort_none_disables_thinking(): - """OpenRouter routes MiMo as "xiaomi/mimo-v2.5-pro"; the openrouter spec - has no thinking_style, so the disable signal must come from the - model-name path (#3845).""" + """OpenRouter routes MiMo as "xiaomi/mimo-v2.5-pro" and does NOT forward + extra_body.thinking to upstream, so a disable signal must also reach OR + in its own `reasoning.effort` shape. Verifies both the upstream-MiMo + payload (#3845) and the OR-native payload (#3851 follow-up) are sent. + """ provider = _openrouter_provider("xiaomi/mimo-v2.5-pro") kwargs = provider._build_kwargs( messages=_simple_messages(), @@ -152,11 +154,15 @@ def test_mimo_via_openrouter_reasoning_effort_none_disables_thinking(): temperature=0.7, reasoning_effort="none", tool_choice=None, ) assert "reasoning_effort" not in kwargs - assert kwargs["extra_body"] == {"thinking": {"type": "disabled"}} + assert kwargs["extra_body"] == { + "thinking": {"type": "disabled"}, + "reasoning": {"effort": "none"}, + } def test_mimo_via_openrouter_reasoning_effort_medium_enables_thinking(): - """Same as the direct path: any non-none/minimal effort enables thinking.""" + """Non-none/minimal effort enables thinking and the OR `reasoning.effort` + field mirrors the requested effort level.""" provider = _openrouter_provider("xiaomi/mimo-v2.5-pro") kwargs = provider._build_kwargs( messages=_simple_messages(), @@ -164,7 +170,10 @@ def test_mimo_via_openrouter_reasoning_effort_medium_enables_thinking(): temperature=0.7, reasoning_effort="medium", tool_choice=None, ) assert kwargs.get("reasoning_effort") == "medium" - assert kwargs["extra_body"] == {"thinking": {"type": "enabled"}} + assert kwargs["extra_body"] == { + "thinking": {"type": "enabled"}, + "reasoning": {"effort": "medium"}, + } def test_mimo_via_openrouter_bare_slug_also_matches(): @@ -176,12 +185,16 @@ def test_mimo_via_openrouter_bare_slug_also_matches(): tools=None, model=None, max_tokens=100, temperature=0.7, reasoning_effort="none", tool_choice=None, ) - assert kwargs["extra_body"] == {"thinking": {"type": "disabled"}} + assert kwargs["extra_body"] == { + "thinking": {"type": "disabled"}, + "reasoning": {"effort": "none"}, + } def test_mimo_flash_via_openrouter_does_not_inject_thinking(): """mimo-v2-flash has no thinking mode per Xiaomi docs; the allowlist - excludes it, so no thinking field should be injected on the gateway path.""" + excludes it, so neither the upstream `thinking` field nor OR's + `reasoning.effort` should be injected on the gateway path.""" provider = _openrouter_provider("xiaomi/mimo-v2-flash") kwargs = provider._build_kwargs( messages=_simple_messages(), @@ -200,3 +213,18 @@ def test_non_mimo_model_via_openrouter_unaffected(): temperature=0.7, reasoning_effort="none", tool_choice=None, ) assert "extra_body" not in kwargs + + +def test_kimi_via_openrouter_also_injects_reasoning_effort(): + """Kimi has the same gateway problem as MiMo: OR drops the upstream + `thinking` field. The same OR-reasoning injection should fire.""" + provider = _openrouter_provider("moonshotai/kimi-k2.5") + kwargs = provider._build_kwargs( + messages=_simple_messages(), + tools=None, model=None, max_tokens=100, + temperature=0.7, reasoning_effort="none", tool_choice=None, + ) + assert kwargs["extra_body"] == { + "thinking": {"type": "disabled"}, + "reasoning": {"effort": "none"}, + } From 4f895e6307cab731e86f4e67b6a044bc957dbf9e Mon Sep 17 00:00:00 2001 From: Xubin Ren <52506698+Re-bin@users.noreply.github.com> Date: Thu, 21 May 2026 14:34:45 +0800 Subject: [PATCH 33/54] refactor(providers): centralize gateway reasoning control --- nanobot/providers/openai_compat_provider.py | 130 +++++++------------ nanobot/providers/registry.py | 6 + tests/providers/test_xiaomi_mimo_thinking.py | 9 +- 3 files changed, 59 insertions(+), 86 deletions(-) diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index 222159dda..a61439025 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -74,41 +74,43 @@ _THINKING_STYLE_MAP: dict[str, Any] = { "enable_thinking": lambda on: {"enable_thinking": on}, "reasoning_split": lambda on: {"reasoning_split": on}, } +_GATEWAY_REASONING_STYLE_MAP: dict[str, Any] = { + "reasoning_effort": lambda effort: {"reasoning": {"effort": effort}}, +} +_MODEL_THINKING_STYLES: dict[str, str] = { + **dict.fromkeys(_KIMI_THINKING_MODELS, "thinking_type"), + **dict.fromkeys(_MIMO_THINKING_MODELS, "thinking_type"), +} -def _is_kimi_thinking_model(model_name: str) -> bool: - """Return True if model_name refers to a Kimi thinking-capable model. - - Supports two forms: - - Exact match: e.g. kimi-k2.5 / kimi-k2.6 in _KIMI_THINKING_MODELS - - Slug match: moonshotai/kimi-k2.5 -> the part after the last "/" - is checked against _KIMI_THINKING_MODELS - - This covers both the native Moonshot provider (bare slug) and - OpenRouter-style names (``"publisher/slug"``). - """ - name = model_name.lower() - if name in _KIMI_THINKING_MODELS: - return True - if "/" in name and name.rsplit("/", 1)[1] in _KIMI_THINKING_MODELS: - return True - return False +def _model_slug(model_name: str) -> str: + return model_name.lower().rsplit("/", 1)[-1] -def _is_mimo_thinking_model(model_name: str) -> bool: - """Return True if model_name refers to a MiMo thinking-capable model. +def _model_thinking_style(model_name: str) -> str: + return _MODEL_THINKING_STYLES.get(_model_slug(model_name), "") - Mirrors _is_kimi_thinking_model: gateway providers (e.g. OpenRouter - routing ``xiaomi/mimo-v2.5-pro``) have no ``thinking_style`` on their - spec, so the spec-driven branch in _build_kwargs misses them. The - model-name path catches those cases. - """ - name = model_name.lower() - if name in _MIMO_THINKING_MODELS: - return True - if "/" in name and name.rsplit("/", 1)[1] in _MIMO_THINKING_MODELS: - return True - return False + +def _thinking_styles_for(spec: ProviderSpec | None, model_name: str) -> list[str]: + styles: list[str] = [] + if spec and spec.thinking_style: + styles.append(spec.thinking_style) + model_style = _model_thinking_style(model_name) + if model_style and model_style not in styles: + styles.append(model_style) + return styles + + +def _thinking_extra_body(style: str, thinking_enabled: bool) -> dict[str, Any] | None: + builder = _THINKING_STYLE_MAP.get(style) + return builder(thinking_enabled) if builder else None + + +def _gateway_reasoning_extra_body(style: str, effort: str | None) -> dict[str, Any] | None: + if not effort: + return None + builder = _GATEWAY_REASONING_STYLE_MAP.get(style) + return builder(effort) if builder else None def _openai_compat_timeout_s() -> float: @@ -581,60 +583,19 @@ class OpenAICompatProvider(LLMProvider): if wire_effort and semantic_effort != "none": kwargs["reasoning_effort"] = wire_effort - # Provider-specific thinking parameters. - # Only sent when reasoning_effort is explicitly configured so that - # the provider default is preserved otherwise. - # The mapping is driven by ProviderSpec.thinking_style so that adding - # a new provider never requires touching this function. - if spec and spec.thinking_style and reasoning_effort is not None: + # Only send thinking controls when reasoning_effort is explicit so + # omitting the config preserves each provider's default. + if reasoning_effort is not None: thinking_enabled = semantic_effort not in ("none", "minimal") - extra = _THINKING_STYLE_MAP.get(spec.thinking_style, lambda _: None)(thinking_enabled) - if extra: - kwargs.setdefault("extra_body", {}).update(extra) - - # Model-level thinking injection for Kimi thinking-capable models. - # Strip any provider prefix (e.g. "moonshotai/") before the set lookup - # so that OpenRouter-style names like "moonshotai/kimi-k2.5" are handled - # identically to bare names like "kimi-k2.5". - if reasoning_effort is not None and _is_kimi_thinking_model(model_name): - thinking_enabled = semantic_effort not in ("none", "minimal") - kwargs.setdefault("extra_body", {}).update( - {"thinking": {"type": "enabled" if thinking_enabled else "disabled"}} - ) - - # Model-level thinking injection for MiMo thinking-capable models. - # Same shape as Kimi: gateway providers (OpenRouter, etc.) lack the - # xiaomi_mimo spec's thinking_style, so the spec-driven branch above - # misses them — match by model name to catch "xiaomi/mimo-v2.5-pro" - # and friends. (Direct xiaomi_mimo requests are also covered here; - # both branches write the same payload, so the dict update is a - # safe no-op for already-handled cases.) - if reasoning_effort is not None and _is_mimo_thinking_model(model_name): - thinking_enabled = semantic_effort not in ("none", "minimal") - kwargs.setdefault("extra_body", {}).update( - {"thinking": {"type": "enabled" if thinking_enabled else "disabled"}} - ) - - # OpenRouter uses its own unified `reasoning` field and does not - # forward provider-specific thinking shapes (the Kimi/MiMo - # extra_body.thinking above) to upstream. Reported as the follow-up - # to #3845/#3851: MiMo via OR kept thinking despite our injection. - # For known thinking-capable models routed via OR, mirror the - # effort signal into reasoning.effort (OR's documented enum: - # "none"|"minimal"|"low"|"medium"|"high"|"xhigh"), which OR - # translates to the upstream model's native shape. - if ( - spec - and spec.name == "openrouter" - and reasoning_effort is not None - and ( - _is_kimi_thinking_model(model_name) - or _is_mimo_thinking_model(model_name) - ) - ): - kwargs.setdefault("extra_body", {}).update( - {"reasoning": {"effort": semantic_effort}} - ) + for thinking_style in _thinking_styles_for(spec, model_name): + extra = _thinking_extra_body(thinking_style, thinking_enabled) + if extra: + kwargs.setdefault("extra_body", {}).update(extra) + gateway_style = getattr(spec, "gateway_reasoning_style", "") if spec else "" + if gateway_style and _model_thinking_style(model_name): + extra = _gateway_reasoning_extra_body(gateway_style, semantic_effort) + if extra: + kwargs.setdefault("extra_body", {}).update(extra) if tools: kwargs["tools"] = tools @@ -649,8 +610,7 @@ class OpenAICompatProvider(LLMProvider): and semantic_effort not in ("none", "minimal") and ( (spec and spec.thinking_style) - or _is_kimi_thinking_model(model_name) - or _is_mimo_thinking_model(model_name) + or _model_thinking_style(model_name) ) ) implicit_deepseek_thinking = ( diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py index 7c8edd271..d942c03bf 100644 --- a/nanobot/providers/registry.py +++ b/nanobot/providers/registry.py @@ -71,6 +71,11 @@ class ProviderSpec: # "reasoning_split" — {"reasoning_split": true/false} (MiniMax) thinking_style: str = "" + # Gateway-native reasoning control to pair with model-level thinking styles. + # "reasoning_effort" — {"reasoning": {"effort": }} + # (OpenRouter) + gateway_reasoning_style: str = "" + # When True, treat the "reasoning" response field as formal content # when "content" is empty. Only set this for providers (e.g. StepFun) # whose API returns the actual answer in "reasoning" instead of "content". @@ -142,6 +147,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = ( detect_by_base_keyword="openrouter", default_api_base="https://openrouter.ai/api/v1", supports_prompt_caching=True, + gateway_reasoning_style="reasoning_effort", ), # Hugging Face Inference Providers: OpenAI-compatible router for chat models. ProviderSpec( diff --git a/tests/providers/test_xiaomi_mimo_thinking.py b/tests/providers/test_xiaomi_mimo_thinking.py index 43dfec537..92161803f 100644 --- a/tests/providers/test_xiaomi_mimo_thinking.py +++ b/tests/providers/test_xiaomi_mimo_thinking.py @@ -32,7 +32,7 @@ def _mimo_spec(): def _openrouter_spec(): - """Return the registered OpenRouter ProviderSpec (no thinking_style).""" + """Return the registered OpenRouter ProviderSpec.""" specs = {s.name: s for s in PROVIDERS} return specs["openrouter"] @@ -77,6 +77,13 @@ def test_xiaomi_mimo_uses_thinking_type_style(): assert spec.default_api_base == "https://api.xiaomimimo.com/v1" +def test_openrouter_declares_gateway_reasoning_style(): + """OpenRouter uses its own reasoning.effort field for routed thinking models.""" + spec = _openrouter_spec() + assert spec.thinking_style == "" + assert spec.gateway_reasoning_style == "reasoning_effort" + + # --------------------------------------------------------------------------- # _build_kwargs wire-format # --------------------------------------------------------------------------- From e645fbcb34ee61884167c6b2b260545d6d3c8aba Mon Sep 17 00:00:00 2001 From: Haisam Abbas Date: Wed, 20 May 2026 17:16:53 +0500 Subject: [PATCH 34/54] fix shell guard url path detection --- nanobot/agent/tools/shell.py | 2 +- tests/tools/test_tool_validation.py | 36 +++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index 0252b9746..7e0ef57a8 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -418,7 +418,7 @@ class ExecTool(Tool): # Windows: match drive-root paths like `C:\` as well as `C:\path\to\file`, and UNC paths like `\\server\share` # NOTE: `*` is required so `C:\` (nothing after the slash) is still extracted. win_paths = re.findall( - r"(?:[A-Za-z]:[^\s\"'|><;]*|\\\\[^\s\"'|><;]+(?:\\[^\s\"'|><;]+)*)", + r"(?<;]*|\\\\[^\s\"'|><;]+(?:\\[^\s\"'|><;]+)*)", command ) posix_paths = re.findall(r"(?:^|[\s|>'\"])(/[^\s\"'>;|<]+)", command) # POSIX: /absolute only diff --git a/tests/tools/test_tool_validation.py b/tests/tools/test_tool_validation.py index 42620dcc6..188a8952f 100644 --- a/tests/tools/test_tool_validation.py +++ b/tests/tools/test_tool_validation.py @@ -3,6 +3,8 @@ import subprocess import sys from typing import Any +import pytest + from nanobot.agent.tools import ( ArraySchema, IntegerSchema, @@ -15,6 +17,7 @@ from nanobot.agent.tools import ( from nanobot.agent.tools.base import Tool from nanobot.agent.tools.registry import ToolRegistry from nanobot.agent.tools.shell import ExecTool +from nanobot.security.network import configure_ssrf_whitelist class SampleTool(Tool): @@ -218,6 +221,39 @@ def test_exec_extract_absolute_paths_ignores_relative_posix_segments() -> None: assert "/bin/python" not in paths +def test_exec_extract_absolute_paths_ignores_urls() -> None: + cmd = 'curl -s -o /dev/null -w "%{http_code}" https://www.google.com' + paths = ExecTool._extract_absolute_paths(cmd) + assert paths == ["/dev/null"] + + +@pytest.mark.parametrize( + "command", + [ + 'curl -s -o /dev/null -w "%{http_code}" https://www.google.com', + 'wget -q -O - http://example.com 2>&1 | head -c 100', + 'python3 -c "import urllib.request; print(urllib.request.urlopen(\'http://example.com\').read()[:100])"', + ], +) +def test_exec_guard_allows_public_urls(tmp_path, command: str) -> None: + tool = ExecTool(restrict_to_workspace=True) + error = tool._guard_command(command, str(tmp_path)) + assert error is None + + +def test_exec_guard_allows_whitelisted_internal_urls(tmp_path) -> None: + configure_ssrf_whitelist(["10.10.10.0/24"]) + try: + tool = ExecTool(restrict_to_workspace=True) + error = tool._guard_command( + 'curl -s -H "Authorization: Bearer ..." http://10.10.10.3:8123/api/', + str(tmp_path), + ) + assert error is None + finally: + configure_ssrf_whitelist([]) + + def test_exec_extract_absolute_paths_captures_posix_absolute_paths() -> None: cmd = "cat /tmp/data.txt > /tmp/out.txt" paths = ExecTool._extract_absolute_paths(cmd) From 7e3af8c38b48c4e4f99efc9c52eb135fb7a87e60 Mon Sep 17 00:00:00 2001 From: Xubin Ren <52506698+Re-bin@users.noreply.github.com> Date: Thu, 21 May 2026 14:44:34 +0800 Subject: [PATCH 35/54] docs(tools): add general tool workflow contract --- nanobot/templates/TOOLS.md | 72 +++++++++++++++++++++-------- tests/agent/test_context_builder.py | 23 +++++++++ 2 files changed, 75 insertions(+), 20 deletions(-) diff --git a/nanobot/templates/TOOLS.md b/nanobot/templates/TOOLS.md index 374e49778..ee37090ea 100644 --- a/nanobot/templates/TOOLS.md +++ b/nanobot/templates/TOOLS.md @@ -1,28 +1,60 @@ # Tool Usage Notes -Tool signatures are provided automatically via function calling. -This file documents non-obvious constraints and usage patterns. +Tool signatures are provided automatically via function calling. This file +documents the general tool contract and non-obvious usage patterns. -## exec — Safety Limits +## General Tool Contract -- Commands have a configurable timeout (default 60s) -- Dangerous commands are blocked (rm -rf, format, dd, shutdown, etc.) -- Output is truncated at 10,000 characters -- `restrictToWorkspace` config can limit file access to the workspace +- Use the narrowest structured tool that directly matches the task. +- Use read-only discovery before writes when state is uncertain. +- Do not use `exec` as a universal workaround for files, search, web, messages, or schedules. +- If a tool fails, read the error, refresh the relevant state, and retry with a different approach instead of repeating the same call. +- After meaningful changes, verify with the smallest reliable check: re-read changed state, run targeted tests, or inspect command output. +- Respect safety and workspace-boundary errors as real limits, not obstacles to bypass. -## grep — Content Search +## Discovery and Reading -- Use `grep` to search file contents inside the workspace -- Default behavior returns only matching file paths (`output_mode="files_with_matches"`) -- Supports optional `glob` filtering (e.g. `glob="*.py"`) plus `context_before` / `context_after` -- Supports `type="py"`, `type="ts"`, `type="md"` and similar shorthand filters -- Use `fixed_strings=true` for literal keywords containing regex characters -- Use `output_mode="files_with_matches"` to get only matching file paths -- Use `output_mode="count"` to size a search before reading full matches -- Use `head_limit` and `offset` to page across results -- Prefer this over `exec` for code and history searches -- Binary or oversized files may be skipped to keep results readable +- Use `find_files` or `list_dir` to locate workspace paths before `read_file` when a path is uncertain. +- Use `grep` for content search inside the workspace; prefer it over shell grep for ordinary searches. +- `grep` defaults to `output_mode="files_with_matches"`; use `output_mode="content"` for matching lines with context. +- Use `fixed_strings=true` for literal keywords containing regex characters. +- Use `output_mode="count"` to size a broad search before reading full matches. +- Use `head_limit` and `offset` to page across large result sets. +- Binary or oversized files may be skipped to keep results readable. -## cron — Scheduled Reminders +## File and Coding Workflows -- Please refer to cron skill for usage. +- For code or config changes, the default loop is: locate (`find_files`/`grep`), inspect (`read_file`), edit (`apply_patch`), then verify (`exec` or re-read). +- Use `apply_patch` as the default code editing tool, especially for multi-file changes, structural edits, generated code, moves, adds, or deletes. +- Use `apply_patch dry_run=true` when the patch is uncertain and you want validation plus a change summary before writing. +- Use `edit_file` only for small exact replacements in one file, with `old_text` copied from `read_file`; add `occurrence`, `line_hint`, or `expected_replacements` when ambiguity matters. +- Use `write_file` for new files or intentional full-file rewrites, not routine partial edits. +- If `apply_patch` or `edit_file` fails, re-read with `force=true`, narrow the context, and try a smaller patch rather than switching to shell `sed` or `echo`. + +## Process Execution + +- Use `exec` for tests, builds, package commands, git commands, and other process execution. +- Prefer dedicated file/search tools over `cat`, shell `find`, shell `grep`, `sed`, or `echo` for ordinary workspace inspection and edits. +- Use non-interactive flags such as `-y` or `--yes` when available. +- Commands have a configurable timeout (default 60s), dangerous commands are blocked, and output is truncated. +- For long-running or interactive commands, pass `yield_time_ms`; if the process keeps running, continue with `write_stdin`. +- Use `write_stdin` to poll, provide stdin, close stdin, wait for expected output with `wait_for`, or terminate an existing exec session. +- Use `list_exec_sessions` to recover active session IDs after context shifts. + +## Web and External Information + +- Use web tools when the user asks for current information, a specific URL, or information likely to have changed. +- Use `web_search` to find sources and `web_fetch` for a specific page or result that needs closer reading. +- Do not invent freshness-sensitive facts when tools can verify them. + +## Messaging and Media + +- Use `message` to send content or local media to the user/channel. +- `read_file` only reads content for your analysis; it does not deliver a file to the user. +- When sending an existing local file, attach it through the message/media mechanism instead of pasting file contents unless the user asked for text. + +## Scheduling and Background Work + +- Use `cron` for scheduled reminders or recurring jobs; do not run `nanobot cron` through `exec`. +- For heartbeat tasks, update `HEARTBEAT.md` according to the agent instructions. +- Do not write reminders only to memory files when the user expects an actual notification. diff --git a/tests/agent/test_context_builder.py b/tests/agent/test_context_builder.py index 0206d0986..abd934b0a 100644 --- a/tests/agent/test_context_builder.py +++ b/tests/agent/test_context_builder.py @@ -171,6 +171,29 @@ class TestIsTemplateContent: assert ContextBuilder._is_template_content("totally different", "memory/MEMORY.md") is False +# --------------------------------------------------------------------------- +# Bundled bootstrap templates +# --------------------------------------------------------------------------- + + +class TestBundledToolsTemplate: + def test_tools_template_balances_general_and_coding_workflows(self): + from importlib.resources import files as pkg_files + + tpl = pkg_files("nanobot") / "templates" / "TOOLS.md" + content = tpl.read_text(encoding="utf-8") + + assert "## General Tool Contract" in content + assert "Use the narrowest structured tool" in content + assert "Do not use `exec` as a universal workaround" in content + assert "## File and Coding Workflows" in content + assert "apply_patch" in content + assert "## Web and External Information" in content + assert "## Messaging and Media" in content + assert "## Scheduling and Background Work" in content + assert "pure coding" not in content.lower() + + # --------------------------------------------------------------------------- # _build_user_content # --------------------------------------------------------------------------- From 84603f4cf2c1475f2770aea14c647f39161dd68a Mon Sep 17 00:00:00 2001 From: Haisam Abbas Date: Thu, 21 May 2026 12:06:08 +0500 Subject: [PATCH 36/54] Add Ollama image generation support --- docs/image-generation.md | 31 ++++- nanobot/agent/tools/image_generation.py | 7 +- nanobot/providers/image_generation.py | 143 ++++++++++++++++++++++ tests/providers/test_image_generation.py | 49 ++++++++ tests/tools/test_image_generation_tool.py | 33 +++++ 5 files changed, 259 insertions(+), 4 deletions(-) diff --git a/docs/image-generation.md b/docs/image-generation.md index 6ca7ed3fd..a9d6b620c 100644 --- a/docs/image-generation.md +++ b/docs/image-generation.md @@ -23,7 +23,7 @@ The feature is disabled by default. Enable it in `~/.nanobot/config.json`, confi } ``` -See [Provider Notes](#provider-notes) for AIHubMix, MiniMax, and Gemini configuration examples. +See [Provider Notes](#provider-notes) for AIHubMix, MiniMax, Gemini, Ollama, and StepFun configuration examples. > [!TIP] > Prefer environment variables for API keys. nanobot resolves `${VAR_NAME}` values from the environment at startup. @@ -46,7 +46,7 @@ The WebUI hides provider storage details from the user. The agent sees the saved | Option | Type | Default | Description | |--------|------|---------|-------------| | `tools.imageGeneration.enabled` | boolean | `false` | Register the `generate_image` tool | -| `tools.imageGeneration.provider` | string | `"openrouter"` | Image provider name. Supported values: `openrouter`, `aihubmix`, `minimax`, `gemini`, `stepfun` | +| `tools.imageGeneration.provider` | string | `"openrouter"` | Image provider name. Supported values: `openrouter`, `aihubmix`, `minimax`, `gemini`, `ollama`, `stepfun` | | `tools.imageGeneration.model` | string | `"openai/gpt-5.4-image-2"` | Provider model name | | `tools.imageGeneration.defaultAspectRatio` | string | `"1:1"` | Default ratio when the prompt/tool call does not specify one | | `tools.imageGeneration.defaultImageSize` | string | `"1K"` | Default size hint, for example `1K`, `2K`, `4K`, or `1024x1024` | @@ -168,6 +168,31 @@ For reference-image edits, use a Gemini Flash image model: Imagen 4 supports the aspect ratios `1:1`, `9:16`, `16:9`, `3:4`, and `4:3`. Unsupported ratios are ignored and the model uses its default. The `defaultImageSize` setting has no effect on Gemini models; sizing is controlled by `defaultAspectRatio` only. Reference images passed with an Imagen model are ignored (with a warning logged). +### Ollama + +Ollama's experimental native image generation API works with local servers and hosted ollama.com models. Local access at `http://localhost:11434/api` does not require an API key; set `providers.ollama.apiKey` only when targeting `https://ollama.com/api`. + +```json +{ + "providers": { + "ollama": { + "apiBase": "http://localhost:11434/api" + } + }, + "tools": { + "imageGeneration": { + "enabled": true, + "provider": "ollama", + "model": "x/z-image-turbo", + "defaultAspectRatio": "16:9", + "defaultImageSize": "2K" + } + } +} +``` + +Ollama maps `defaultAspectRatio` and `defaultImageSize` to native `width` and `height` values. Reference images are not supported by this integration. + ### StepFun StepFun (阶跃星辰) `step-image-edit-2` supports text-to-image generation. The `step-1x-medium` variant additionally supports **style-reference** image edits, where a reference image guides the visual style of the output. @@ -274,7 +299,7 @@ Use the reference image. Keep the same robot and composition, change the palette |---------|-------| | `generate_image` is not available | Set `tools.imageGeneration.enabled` to `true` and restart the gateway | | Missing API key error | Configure `providers..apiKey`; if using `${VAR_NAME}`, confirm the environment variable is visible to the gateway process | -| `unsupported image generation provider` | Use `openrouter`, `aihubmix`, `minimax`, `gemini`, or `stepfun` | +| `unsupported image generation provider` | Use `openrouter`, `aihubmix`, `minimax`, `gemini`, `ollama`, or `stepfun` | | AIHubMix says `Incorrect model ID` | Use `model: "gpt-image-2-free"`; nanobot expands it to the required `openai/gpt-image-2-free` model path internally | | Generation times out | Try a smaller/default image size, set AIHubMix `extraBody.quality` to `"low"`, or retry later | | Reference image rejected | Reference image paths must be inside the workspace or nanobot media directory and must be valid image files | diff --git a/nanobot/agent/tools/image_generation.py b/nanobot/agent/tools/image_generation.py index f2f599ded..58eaaf7d8 100644 --- a/nanobot/agent/tools/image_generation.py +++ b/nanobot/agent/tools/image_generation.py @@ -21,6 +21,7 @@ from nanobot.providers.image_generation import ( ImageGenerationProvider, get_image_gen_provider, ) +from nanobot.providers.registry import find_by_name from nanobot.utils.artifacts import ( ArtifactError, generated_image_tool_result, @@ -117,6 +118,10 @@ class ImageGenerationTool(Tool): def _provider_config(self) -> ProviderConfig | None: return self.provider_configs.get(self.config.provider) + def _provider_allows_missing_api_key(self) -> bool: + spec = find_by_name(self.config.provider) + return bool(spec and (spec.is_local or spec.is_direct or spec.is_oauth)) + def _provider_client(self) -> ImageGenerationProvider | None: provider = self._provider_config() cls = get_image_gen_provider(self.config.provider) @@ -174,7 +179,7 @@ class ImageGenerationTool(Tool): if client is None: return f"Error: unsupported image generation provider '{self.config.provider}'" provider = self._provider_config() - if not provider or not provider.api_key: + if not self._provider_allows_missing_api_key() and (not provider or not provider.api_key): return self._missing_api_key_error() requested = count or 1 diff --git a/nanobot/providers/image_generation.py b/nanobot/providers/image_generation.py index 501b98fd2..3ea8c374a 100644 --- a/nanobot/providers/image_generation.py +++ b/nanobot/providers/image_generation.py @@ -4,6 +4,7 @@ from __future__ import annotations import base64 import binascii +import re from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path @@ -31,6 +32,14 @@ _AIHUBMIX_ASPECT_RATIO_SIZES = { } _GEMINI_DEFAULT_TIMEOUT_S = 120.0 _GEMINI_IMAGEN_ASPECT_RATIOS = {"1:1", "9:16", "16:9", "3:4", "4:3"} +_OLLAMA_DEFAULT_SIDE = 1024 +_OLLAMA_SIZE_PRESETS = { + "1K": 1024, + "2K": 2048, + "4K": 4096, +} +_OLLAMA_EXPLICIT_SIZE_RE = re.compile(r"^\s*(\d+)\s*[xX]\s*(\d+)\s*$") +_OLLAMA_ASPECT_RATIO_RE = re.compile(r"^\s*(\d+)\s*:\s*(\d+)\s*$") class ImageGenerationError(RuntimeError): @@ -429,6 +438,139 @@ def _http_error_detail(response: httpx.Response) -> str: return response.text[:500] or "" +def _round_to_multiple(value: float, multiple: int = 8) -> int: + rounded = int(round(value / multiple) * multiple) + return max(multiple, rounded) + + +def _ollama_dimensions(aspect_ratio: str | None, image_size: str | None) -> tuple[int, int]: + if image_size: + size = image_size.strip() + explicit = _OLLAMA_EXPLICIT_SIZE_RE.fullmatch(size) + if explicit: + return int(explicit.group(1)), int(explicit.group(2)) + long_side = _OLLAMA_SIZE_PRESETS.get(size.upper(), _OLLAMA_DEFAULT_SIDE) + else: + long_side = _OLLAMA_DEFAULT_SIDE + + if not aspect_ratio: + return long_side, long_side + + ratio = _OLLAMA_ASPECT_RATIO_RE.fullmatch(aspect_ratio.strip()) + if ratio is None: + return long_side, long_side + + width_ratio = int(ratio.group(1)) + height_ratio = int(ratio.group(2)) + if width_ratio <= 0 or height_ratio <= 0: + return long_side, long_side + + if width_ratio >= height_ratio: + width = long_side + height = _round_to_multiple(long_side * height_ratio / width_ratio) + else: + height = long_side + width = _round_to_multiple(long_side * width_ratio / height_ratio) + return max(8, width), max(8, height) + + +def _ollama_image_data_url(value: str) -> str: + if value.startswith("data:image/"): + return value + return _b64_image_data_url(value) + + +def _ollama_images_from_payload(payload: dict[str, Any]) -> list[str]: + images: list[str] = [] + + def collect(value: Any) -> None: + if isinstance(value, str) and value: + images.append(_ollama_image_data_url(value)) + elif isinstance(value, list): + for item in value: + collect(item) + + collect(payload.get("image")) + collect(payload.get("images")) + return images + + +class OllamaImageGenerationClient(ImageGenerationProvider): + """Async client for Ollama native image generation models.""" + + provider_name = "ollama" + default_timeout = 300.0 + + def _default_base_url(self) -> str: + return "http://localhost:11434/api" + + def _resolve_base_url(self, api_base: str | None) -> str: + if api_base: + base = api_base.rstrip("/") + if base.endswith("/v1"): + return f"{base[:-3]}/api" + return base + return self._default_base_url() + + async def generate( + self, + *, + prompt: str, + model: str, + reference_images: list[str] | None = None, + aspect_ratio: str | None = None, + image_size: str | None = None, + ) -> GeneratedImageResponse: + if reference_images: + raise ImageGenerationError( + "Ollama image generation does not support reference images" + ) + + width, height = _ollama_dimensions(aspect_ratio, image_size) + body: dict[str, Any] = { + "model": model, + "prompt": prompt, + "width": width, + "height": height, + "steps": 0, + } + body.update(self.extra_body) + body["stream"] = False + + headers = { + "Content-Type": "application/json", + **self.extra_headers, + } + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + + url = f"{self.api_base}/generate" + response = await self._http_post(url, headers=headers, body=body) + + try: + response.raise_for_status() + except httpx.HTTPStatusError as exc: + detail = _http_error_detail(response) + logger.error( + "Ollama image generation failed (HTTP {}): {}", + response.status_code, + detail, + ) + raise ImageGenerationError( + f"Ollama image generation failed (HTTP {response.status_code}): {detail}" + ) from exc + + data = response.json() + images = _ollama_images_from_payload(data) + + self._require_images(images, data) + + response_text = data.get("response") + content = response_text if isinstance(response_text, str) else "" + + return GeneratedImageResponse(images=images, content=content, raw=data) + + class GeminiImageGenerationClient(ImageGenerationProvider): """Async client for Gemini/Imagen image generation via the Generative Language API.""" @@ -886,5 +1028,6 @@ def _stepfun_images_from_payload(payload: dict[str, Any]) -> list[str]: register_image_gen_provider(OpenRouterImageGenerationClient) register_image_gen_provider(AIHubMixImageGenerationClient) register_image_gen_provider(GeminiImageGenerationClient) +register_image_gen_provider(OllamaImageGenerationClient) register_image_gen_provider(MiniMaxImageGenerationClient) register_image_gen_provider(StepFunImageGenerationClient) diff --git a/tests/providers/test_image_generation.py b/tests/providers/test_image_generation.py index 3bee376d8..701f09f0a 100644 --- a/tests/providers/test_image_generation.py +++ b/tests/providers/test_image_generation.py @@ -13,6 +13,7 @@ from nanobot.providers.image_generation import ( GeneratedImageResponse, ImageGenerationError, MiniMaxImageGenerationClient, + OllamaImageGenerationClient, OpenRouterImageGenerationClient, StepFunImageGenerationClient, ) @@ -133,6 +134,54 @@ async def test_openrouter_image_generation_requires_api_key() -> None: await client.generate(prompt="draw", model="model") +@pytest.mark.asyncio +async def test_ollama_image_generation_payload_and_response() -> None: + raw_b64 = PNG_DATA_URL.removeprefix("data:image/png;base64,") + fake = FakeClient(FakeResponse({"image": raw_b64})) + client = OllamaImageGenerationClient( + api_key="ollama-test", + api_base="http://localhost:11434/v1/", + extra_headers={"X-Test": "1"}, + extra_body={"seed": 123}, + client=fake, # type: ignore[arg-type] + ) + + response = await client.generate( + prompt="a sunset", + model="x/z-image-turbo", + aspect_ratio="16:9", + image_size="1K", + ) + + assert response.images == [PNG_DATA_URL] + assert response.content == "" + + call = fake.calls[0] + assert call["url"] == "http://localhost:11434/api/generate" + assert call["headers"]["Authorization"] == "Bearer ollama-test" + assert call["headers"]["X-Test"] == "1" + body = call["json"] + assert body["model"] == "x/z-image-turbo" + assert body["prompt"] == "a sunset" + assert body["width"] == 1024 + assert body["height"] == 576 + assert body["steps"] == 0 + assert body["stream"] is False + assert body["seed"] == 123 + + +@pytest.mark.asyncio +async def test_ollama_image_generation_rejects_reference_images() -> None: + client = OllamaImageGenerationClient(api_key=None) + + with pytest.raises(ImageGenerationError, match="reference images"): + await client.generate( + prompt="edit this", + model="x/z-image-turbo", + reference_images=["ref.png"], + ) + + @pytest.mark.asyncio async def test_aihubmix_image_generation_payload_and_response() -> None: raw_b64 = PNG_DATA_URL.removeprefix("data:image/png;base64,") diff --git a/tests/tools/test_image_generation_tool.py b/tests/tools/test_image_generation_tool.py index 92ed8a339..f5d2d9183 100644 --- a/tests/tools/test_image_generation_tool.py +++ b/tests/tools/test_image_generation_tool.py @@ -138,6 +138,39 @@ async def test_generate_image_tool_reports_missing_aihubmix_key(tmp_path: Path) assert result.startswith("Error: AIHubMix API key is not configured") +@pytest.mark.asyncio +async def test_generate_image_tool_allows_ollama_without_api_key( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + set_config_path(tmp_path / "config.json") + FakeImageClient.instances = [] + monkeypatch.setattr( + "nanobot.agent.tools.image_generation.get_image_gen_provider", + lambda name: FakeImageClient if name == "ollama" else None, + ) + tool = ImageGenerationTool( + workspace=tmp_path, + config=ImageGenerationToolConfig( + enabled=True, + provider="ollama", + model="x/z-image-turbo", + ), + provider_configs={"ollama": ProviderConfig(api_base="http://localhost:11434/v1")}, + ) + + result = await tool.execute(prompt="draw a cat") + + payload = json.loads(result) + assert len(payload["artifacts"]) == 1 + + fake = FakeImageClient.instances[0] + assert fake.kwargs["api_key"] is None + assert fake.kwargs["api_base"] == "http://localhost:11434/v1" + assert fake.calls[0]["aspect_ratio"] == "1:1" + assert fake.calls[0]["image_size"] == "1K" + + @pytest.mark.asyncio async def test_generate_image_tool_rejects_reference_outside_workspace(tmp_path: Path) -> None: set_config_path(tmp_path / "config.json") From d29fcaf5d14dd587c60ae3f3b5c3687adc6b453a Mon Sep 17 00:00:00 2001 From: Xubin Ren <52506698+Re-bin@users.noreply.github.com> Date: Thu, 21 May 2026 15:21:39 +0800 Subject: [PATCH 37/54] refactor(agent): internalize tool contract prompt --- nanobot/agent/context.py | 5 +++-- nanobot/templates/AGENTS.md | 12 +++++++---- .../{TOOLS.md => agent/tool_contract.md} | 2 +- nanobot/utils/helpers.py | 5 +++-- tests/agent/test_context_builder.py | 21 ++++++++++++++++--- tests/agent/test_onboard_logic.py | 20 ++++++++++++++++++ 6 files changed, 53 insertions(+), 12 deletions(-) rename nanobot/templates/{TOOLS.md => agent/tool_contract.md} (99%) diff --git a/nanobot/agent/context.py b/nanobot/agent/context.py index 19ee935c4..82ebfab65 100644 --- a/nanobot/agent/context.py +++ b/nanobot/agent/context.py @@ -22,7 +22,7 @@ from nanobot.utils.prompt_templates import render_template class ContextBuilder: """Builds the context (system prompt + messages) for the agent.""" - BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md"] + BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md"] _RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]" _MAX_RECENT_HISTORY = 50 _MAX_HISTORY_CHARS = 32_000 # hard cap on recent history section size @@ -47,6 +47,8 @@ class ContextBuilder: if bootstrap: parts.append(bootstrap) + parts.append(render_template("agent/tool_contract.md")) + memory = self.memory.get_memory_context() if memory and not self._is_template_content(self.memory.read_memory(), "memory/MEMORY.md"): parts.append(f"# Memory\n\n{memory}") @@ -210,4 +212,3 @@ class ContextBuilder: if not images: return text return images + [{"type": "text", "text": text}] - diff --git a/nanobot/templates/AGENTS.md b/nanobot/templates/AGENTS.md index 0bf6de3d3..46cfc08c3 100644 --- a/nanobot/templates/AGENTS.md +++ b/nanobot/templates/AGENTS.md @@ -1,5 +1,9 @@ # Agent Instructions +## Workspace Guidance + +Use this file for project-specific preferences, recurring workflow conventions, and instructions you want the agent to remember for this workspace. Keep durable facts about the user in `USER.md`, personality/style guidance in `SOUL.md`, and long-term memory in `memory/MEMORY.md`. + ## Scheduled Reminders Before scheduling reminders, check available skills and follow skill guidance first. @@ -10,10 +14,10 @@ Get USER_ID and CHANNEL from the current session (e.g., `8281248569` and `telegr ## Heartbeat Tasks -`HEARTBEAT.md` is checked on the configured heartbeat interval. Use file tools to manage periodic tasks: +`HEARTBEAT.md` is checked on the configured heartbeat interval. Use file tools to manage periodic tasks. -- **Add**: `edit_file` to append new tasks -- **Remove**: `edit_file` to delete completed tasks -- **Rewrite**: `write_file` to replace all tasks +- Use `apply_patch` for normal task-list updates, especially when adding, removing, or changing multiple lines. +- Use `edit_file` only for small exact replacements copied from the current `HEARTBEAT.md`. +- Use `write_file` for first creation or intentional full-file rewrites. When the user asks for a recurring/periodic task, update `HEARTBEAT.md` instead of creating a one-time cron reminder. diff --git a/nanobot/templates/TOOLS.md b/nanobot/templates/agent/tool_contract.md similarity index 99% rename from nanobot/templates/TOOLS.md rename to nanobot/templates/agent/tool_contract.md index ee37090ea..edbba21c9 100644 --- a/nanobot/templates/TOOLS.md +++ b/nanobot/templates/agent/tool_contract.md @@ -1,6 +1,6 @@ # Tool Usage Notes -Tool signatures are provided automatically via function calling. This file +Tool signatures are provided automatically via function calling. This section documents the general tool contract and non-obvious usage patterns. ## General Tool Contract diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index 2a969298c..ae91bf394 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ -576,7 +576,7 @@ def build_status_content( def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]: - """Sync bundled templates to workspace. Only creates missing files.""" + """Sync bundled templates to workspace. Creates missing files without overwriting user files.""" from importlib.resources import files as pkg_files try: @@ -589,10 +589,11 @@ def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str] added: list[str] = [] def _write(src, dest: Path): + content = src.read_text(encoding="utf-8") if src else "" if dest.exists(): return dest.parent.mkdir(parents=True, exist_ok=True) - dest.write_text(src.read_text(encoding="utf-8") if src else "", encoding="utf-8") + dest.write_text(content, encoding="utf-8") added.append(str(dest.relative_to(workspace))) for item in tpl.iterdir(): diff --git a/tests/agent/test_context_builder.py b/tests/agent/test_context_builder.py index abd934b0a..a36c0a30a 100644 --- a/tests/agent/test_context_builder.py +++ b/tests/agent/test_context_builder.py @@ -139,6 +139,13 @@ class TestLoadBootstrapFiles: for name in ContextBuilder.BOOTSTRAP_FILES: assert f"## {name}" in result + def test_legacy_tools_md_is_not_bootstrapped(self, tmp_path): + (tmp_path / "TOOLS.md").write_text("workspace tool notes", encoding="utf-8") + builder = _builder(tmp_path) + result = builder._load_bootstrap_files() + assert "TOOLS.md" not in result + assert "workspace tool notes" not in result + def test_utf8_content(self, tmp_path): (tmp_path / "AGENTS.md").write_text("用中文回复", encoding="utf-8") builder = _builder(tmp_path) @@ -176,11 +183,11 @@ class TestIsTemplateContent: # --------------------------------------------------------------------------- -class TestBundledToolsTemplate: - def test_tools_template_balances_general_and_coding_workflows(self): +class TestBundledToolContract: + def test_tool_contract_balances_general_and_coding_workflows(self): from importlib.resources import files as pkg_files - tpl = pkg_files("nanobot") / "templates" / "TOOLS.md" + tpl = pkg_files("nanobot") / "templates" / "agent" / "tool_contract.md" content = tpl.read_text(encoding="utf-8") assert "## General Tool Contract" in content @@ -193,6 +200,14 @@ class TestBundledToolsTemplate: assert "## Scheduling and Background Work" in content assert "pure coding" not in content.lower() + def test_tool_contract_is_injected_without_workspace_file(self, tmp_path): + builder = _builder(tmp_path) + prompt = builder.build_system_prompt() + + assert "# Tool Usage Notes" in prompt + assert "## General Tool Contract" in prompt + assert "Do not use `exec` as a universal workaround" in prompt + # --------------------------------------------------------------------------- # _build_user_content diff --git a/tests/agent/test_onboard_logic.py b/tests/agent/test_onboard_logic.py index 11a284bb5..762da4f31 100644 --- a/tests/agent/test_onboard_logic.py +++ b/tests/agent/test_onboard_logic.py @@ -346,6 +346,26 @@ class TestSyncWorkspaceTemplates: content = (workspace / "AGENTS.md").read_text() assert content == "existing content" + def test_does_not_create_tools_md(self, tmp_path): + """Tool contract is injected internally, not copied into user workspaces.""" + workspace = tmp_path / "workspace" + + added = sync_workspace_templates(workspace, silent=True) + + assert "TOOLS.md" not in added + assert not (workspace / "TOOLS.md").exists() + + def test_preserves_existing_tools_md_without_overwriting(self, tmp_path): + """Legacy user workspaces may have TOOLS.md; sync should leave it untouched.""" + workspace = tmp_path / "workspace" + workspace.mkdir(parents=True) + tools_path = workspace / "TOOLS.md" + tools_path.write_text("custom tool notes", encoding="utf-8") + + sync_workspace_templates(workspace, silent=True) + + assert tools_path.read_text(encoding="utf-8") == "custom tool notes" + def test_creates_memory_directory(self, tmp_path): """Should create memory directory structure.""" workspace = tmp_path / "workspace" From 23d5148a57541d9719ea7f0b9c4b94d1d00cd037 Mon Sep 17 00:00:00 2001 From: Xubin Ren <52506698+Re-bin@users.noreply.github.com> Date: Thu, 21 May 2026 15:33:49 +0800 Subject: [PATCH 38/54] fix(provider): dedupe repeated tool ids in history --- nanobot/providers/openai_compat_provider.py | 42 +++++++++++++++++++-- tests/providers/test_litellm_kwargs.py | 35 +++++++++++++++++ 2 files changed, 74 insertions(+), 3 deletions(-) diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index a61439025..3c1bf9b8f 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -11,6 +11,7 @@ import secrets import string import time import uuid +from collections import deque from collections.abc import Awaitable, Callable from ipaddress import ip_address from typing import TYPE_CHECKING, Any @@ -463,6 +464,7 @@ class OpenAICompatProvider(LLMProvider): """Strip non-standard keys, normalize tool_call IDs.""" sanitized = LLMProvider._sanitize_request_messages(messages, _ALLOWED_MSG_KEYS) id_map: dict[str, str] = {} + pending_tool_ids: dict[str, deque[str]] = {} force_string_content = bool(self._spec and self._spec.name == "deepseek") def map_id(value: Any) -> Any: @@ -470,15 +472,49 @@ class OpenAICompatProvider(LLMProvider): return value return id_map.setdefault(value, self._normalize_tool_call_id(value)) + def unique_tool_id(value: Any, used_ids: set[str], idx: int) -> str: + if isinstance(value, str) and value: + base = map_id(value) + else: + base = _short_tool_id() + if not isinstance(base, str) or not base: + base = _short_tool_id() + if base not in used_ids: + return base + seed = value if isinstance(value, str) and value else base + salt = 1 + while True: + candidate = self._normalize_tool_call_id(f"{seed}:{idx}:{salt}") + if isinstance(candidate, str) and candidate not in used_ids: + return candidate + salt += 1 + + def map_tool_result_id(value: Any) -> Any: + if not isinstance(value, str): + return value + queue = pending_tool_ids.get(value) + if queue: + mapped = queue.popleft() + if not queue: + pending_tool_ids.pop(value, None) + return mapped + return map_id(value) + for clean in sanitized: if isinstance(clean.get("tool_calls"), list): normalized = [] - for tc in clean["tool_calls"]: + used_ids: set[str] = set() + for idx, tc in enumerate(clean["tool_calls"]): if not isinstance(tc, dict): normalized.append(tc) continue tc_clean = dict(tc) - tc_clean["id"] = map_id(tc_clean.get("id")) + raw_id = tc_clean.get("id") + mapped_id = unique_tool_id(raw_id, used_ids, idx) + tc_clean["id"] = mapped_id + used_ids.add(mapped_id) + if isinstance(raw_id, str) and raw_id: + pending_tool_ids.setdefault(raw_id, deque()).append(mapped_id) function = tc_clean.get("function") if isinstance(function, dict): function_clean = dict(function) @@ -496,7 +532,7 @@ class OpenAICompatProvider(LLMProvider): # that mix non-empty content with tool_calls. clean["content"] = None if "tool_call_id" in clean and clean["tool_call_id"]: - clean["tool_call_id"] = map_id(clean["tool_call_id"]) + clean["tool_call_id"] = map_tool_result_id(clean["tool_call_id"]) if ( force_string_content and not (clean.get("role") == "assistant" and clean.get("tool_calls")) diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py index 461913c93..6a32981d9 100644 --- a/tests/providers/test_litellm_kwargs.py +++ b/tests/providers/test_litellm_kwargs.py @@ -1007,6 +1007,41 @@ def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() - assert sanitized[2]["tool_call_id"] == "3ec83c30d" +def test_openai_compat_deduplicates_duplicate_tool_call_ids_in_history() -> None: + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() + + sanitized = provider._sanitize_messages([ + {"role": "user", "content": "check both files"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "ab1b45c2a", + "type": "function", + "function": {"name": "read_file", "arguments": '{"path":"a.txt"}'}, + }, + { + "id": "ab1b45c2a", + "type": "function", + "function": {"name": "read_file", "arguments": '{"path":"b.txt"}'}, + }, + ], + }, + {"role": "tool", "tool_call_id": "ab1b45c2a", "name": "read_file", "content": "a"}, + {"role": "tool", "tool_call_id": "ab1b45c2a", "name": "read_file", "content": "b"}, + {"role": "user", "content": "continue"}, + ]) + + tool_call_ids = [tc["id"] for tc in sanitized[1]["tool_calls"]] + tool_result_ids = [sanitized[2]["tool_call_id"], sanitized[3]["tool_call_id"]] + + assert tool_call_ids[0] == "ab1b45c2a" + assert len(tool_call_ids) == len(set(tool_call_ids)) == 2 + assert tool_result_ids == tool_call_ids + + def test_openai_compat_stringifies_dict_tool_arguments() -> None: with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): provider = OpenAICompatProvider() From 722b760eae1c52a4ba107fec25c104b0828a8726 Mon Sep 17 00:00:00 2001 From: Xubin Ren <52506698+Re-bin@users.noreply.github.com> Date: Thu, 21 May 2026 15:44:01 +0800 Subject: [PATCH 39/54] feat(webui): stream apply patch edit progress --- nanobot/utils/file_edit_events.py | 205 +++++++++++++++++++++++++++ tests/utils/test_file_edit_events.py | 60 ++++++++ 2 files changed, 265 insertions(+) diff --git a/nanobot/utils/file_edit_events.py b/nanobot/utils/file_edit_events.py index f02022c13..fc561bcb0 100644 --- a/nanobot/utils/file_edit_events.py +++ b/nanobot/utils/file_edit_events.py @@ -393,6 +393,9 @@ class StreamingFileEditTracker: self._states[key] = state state.apply_delta(payload) + if state.name == "apply_patch": + await self._update_apply_patch(state) + return if state.name not in {"write_file", "edit_file"}: return if state.path is None: @@ -432,10 +435,62 @@ class StreamingFileEditTracker: deleted=deleted, )]) + async def _update_apply_patch(self, state: _StreamingFileEditState) -> None: + if _json_bool_true(state.arguments, "dry_run"): + return + patch = _extract_json_string_prefix(state.arguments, "patch") + if not patch: + return + tool = self._tools.get("apply_patch") if hasattr(self._tools, "get") else None + events: list[dict[str, Any]] = [] + now = time.monotonic() + for raw_path, added, deleted, delete_file in _streaming_apply_patch_stats(patch): + path = _resolve_raw_file_edit_path(tool, self._workspace, raw_path) + if path is None: + continue + file_state = state.patch_files.get(raw_path) + if file_state is None: + tracker = FileEditTracker( + call_id=state.call_id or state.key, + tool="apply_patch", + path=path, + display_path=display_file_edit_path(path, self._workspace), + before=read_file_snapshot(path), + ) + file_state = _StreamingPatchFileState(tracker=tracker) + state.patch_files[raw_path] = file_state + if delete_file and added == 0 and deleted == 0 and file_state.tracker.before.countable: + deleted = _text_line_count(file_state.tracker.before.text or "") + if not file_state.should_emit(added, deleted, now): + continue + file_state.mark_emitted(added, deleted, now) + events.append(build_file_edit_live_event( + file_state.tracker, + added=added, + deleted=deleted, + )) + if events: + await self._emit(events) + async def flush(self) -> None: events: list[dict[str, Any]] = [] now = time.monotonic() for state in self._states.values(): + for file_state in state.patch_files.values(): + added, deleted = file_state.last_added, file_state.last_deleted + if not file_state.emitted_once: + continue + if ( + file_state.last_emitted_added == added + and file_state.last_emitted_deleted == deleted + ): + continue + file_state.mark_emitted(added, deleted, now) + events.append(build_file_edit_live_event( + file_state.tracker, + added=added, + deleted=deleted, + )) if state.tracker is None: continue added, deleted = state.live_diff_counts() @@ -480,6 +535,10 @@ class StreamingFileEditTracker: """Mark streamed edits as failed when no final tool call will run.""" events: list[dict[str, Any]] = [] for state in self._states.values(): + for file_state in state.patch_files.values(): + if any(state.matches_final_tool_call(tool_call) for tool_call in final_tool_calls): + continue + events.append(build_file_edit_error_event(file_state.tracker, error)) if state.tracker is None: continue if any(state.matches_final_tool_call(tool_call) for tool_call in final_tool_calls): @@ -583,6 +642,39 @@ class _StreamingJsonStringField: self.last_char_cr = False +@dataclass(slots=True) +class _StreamingPatchFileState: + tracker: FileEditTracker + emitted_once: bool = False + last_emitted_added: int = -1 + last_emitted_deleted: int = -1 + last_emit_at: float = 0.0 + last_added: int = 0 + last_deleted: int = 0 + + def should_emit(self, added: int, deleted: int, now: float) -> bool: + self.last_added = added + self.last_deleted = deleted + if not self.emitted_once: + return True + if added == self.last_emitted_added and deleted == self.last_emitted_deleted: + return False + if max( + abs(added - self.last_emitted_added), + abs(deleted - self.last_emitted_deleted), + ) >= _LIVE_EMIT_LINE_STEP: + return True + return now - self.last_emit_at >= _LIVE_EMIT_INTERVAL_S + + def mark_emitted(self, added: int, deleted: int, now: float) -> None: + self.emitted_once = True + self.last_added = added + self.last_deleted = deleted + self.last_emitted_added = added + self.last_emitted_deleted = deleted + self.last_emit_at = now + + @dataclass(slots=True) class _StreamingFileEditState: key: str @@ -600,6 +692,7 @@ class _StreamingFileEditState: new_text: _StreamingJsonStringField = field( default_factory=lambda: _StreamingJsonStringField("new_text") ) + patch_files: dict[str, _StreamingPatchFileState] = field(default_factory=dict) emitted_once: bool = False last_emitted_added: int = -1 last_emitted_deleted: int = -1 @@ -622,6 +715,7 @@ class _StreamingFileEditState: self.content.reset() self.old_text.reset() self.new_text.reset() + self.patch_files.clear() return delta = payload.get("arguments_delta") if isinstance(delta, str) and delta: @@ -681,6 +775,13 @@ class _StreamingFileEditState: name = getattr(tool_call, "name", None) if name != self.name: return False + if self.name == "apply_patch": + arguments = getattr(tool_call, "arguments", None) + if not isinstance(arguments, dict): + return False + patch = arguments.get("patch") + streamed_patch = _extract_complete_json_string(self.arguments, "patch") + return isinstance(patch, str) and streamed_patch == patch arguments = getattr(tool_call, "arguments", None) if not isinstance(arguments, dict): return False @@ -703,6 +804,110 @@ def _stream_key(payload: dict[str, Any]) -> str: return "" +def _json_bool_true(source: str, key: str) -> bool: + return re.search(rf'"{re.escape(key)}"\s*:\s*true\b', source) is not None + + +def _extract_json_string_prefix(source: str, key: str) -> str | None: + match = re.search(rf'"{re.escape(key)}"\s*:\s*"', source) + if match is None: + return None + out: list[str] = [] + i = match.end() + escape = False + while i < len(source): + ch = source[i] + if escape: + escape = False + if ch == "n": + out.append("\n") + elif ch == "r": + out.append("\r") + elif ch == "t": + out.append("\t") + elif ch == "u": + digits = source[i + 1:i + 5] + if len(digits) < 4: + break + try: + out.append(chr(int(digits, 16))) + except ValueError: + break + i += 4 + else: + out.append(ch) + i += 1 + continue + if ch == "\\": + escape = True + i += 1 + continue + if ch == '"': + return "".join(out) + out.append(ch) + i += 1 + return "".join(out) + + +def _streaming_apply_patch_stats(patch: str) -> list[tuple[str, int, int, bool]]: + stats: dict[str, list[Any]] = {} + order: list[str] = [] + current: str | None = None + + def ensure(path: str, *, delete_file: bool = False) -> list[Any]: + if path not in stats: + stats[path] = [0, 0, False] + order.append(path) + if delete_file: + stats[path][2] = True + return stats[path] + + lines = patch.splitlines() + tail = "" + if patch and not patch.endswith(("\n", "\r")) and lines: + tail = lines.pop() + + for line in lines: + if line.startswith("*** Add File: "): + current = line[len("*** Add File: "):].strip() + if current: + ensure(current) + continue + if line.startswith("*** Update File: "): + current = line[len("*** Update File: "):].strip() + if current: + ensure(current) + continue + if line.startswith("*** Delete File: "): + current = line[len("*** Delete File: "):].strip() + if current: + ensure(current, delete_file=True) + continue + if line.startswith("*** Move to: "): + moved = line[len("*** Move to: "):].strip() + if moved: + current = moved + ensure(current) + continue + if line.startswith("*** "): + current = None + continue + if not current: + continue + if line.startswith("+") and not line.startswith("+++"): + ensure(current)[0] += 1 + elif line.startswith("-") and not line.startswith("---"): + ensure(current)[1] += 1 + + if current and tail: + if tail.startswith("+") and not tail.startswith("+++"): + ensure(current)[0] += 1 + elif tail.startswith("-") and not tail.startswith("---"): + ensure(current)[1] += 1 + + return [(path, int(stats[path][0]), int(stats[path][1]), bool(stats[path][2])) for path in order] + + def _extract_complete_json_string(source: str, key: str) -> str | None: match = re.search(rf'"{re.escape(key)}"\s*:\s*"', source) if match is None: diff --git a/tests/utils/test_file_edit_events.py b/tests/utils/test_file_edit_events.py index ec1046061..57d1272c9 100644 --- a/tests/utils/test_file_edit_events.py +++ b/tests/utils/test_file_edit_events.py @@ -206,6 +206,66 @@ def test_streaming_write_file_tracker_emits_live_line_counts(tmp_path: Path) -> assert events[-1]["deleted"] == 0 +def test_streaming_apply_patch_tracker_emits_live_counts_per_file(tmp_path: Path) -> None: + (tmp_path / "src").mkdir() + (tmp_path / "src" / "existing.py").write_text("old\nkeep\n", encoding="utf-8") + events: list[dict] = [] + + async def emit(batch: list[dict]) -> None: + events.extend(batch) + + async def run() -> None: + tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit) + await tracker.update({ + "index": 0, + "call_id": "call-patch", + "name": "apply_patch", + "arguments_delta": ( + '{"patch":"*** Begin Patch\\n' + '*** Update File: src/existing.py\\n' + '@@\\n' + '-old\\n' + '+new\\n' + ' keep\\n' + '*** Add File: src/new.py\\n' + '+fresh\\n' + ), + }) + + asyncio.run(run()) + + by_path = {event["path"]: event for event in events} + assert by_path["src/existing.py"]["tool"] == "apply_patch" + assert by_path["src/existing.py"]["status"] == "editing" + assert by_path["src/existing.py"]["approximate"] is True + assert (by_path["src/existing.py"]["added"], by_path["src/existing.py"]["deleted"]) == (1, 1) + assert (by_path["src/new.py"]["added"], by_path["src/new.py"]["deleted"]) == (1, 0) + + +def test_streaming_apply_patch_tracker_skips_dry_run(tmp_path: Path) -> None: + events: list[dict] = [] + + async def emit(batch: list[dict]) -> None: + events.extend(batch) + + async def run() -> None: + tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit) + await tracker.update({ + "index": 0, + "call_id": "call-patch", + "name": "apply_patch", + "arguments_delta": ( + '{"dry_run":true,"patch":"*** Begin Patch\\n' + '*** Add File: dry.md\\n' + '+preview\\n' + ), + }) + + asyncio.run(run()) + + assert events == [] + + def test_streaming_write_file_tracker_emits_pending_before_path(tmp_path: Path) -> None: events: list[dict] = [] From 835bab5f5a65f2be0cf0239a26d4d1f40a0be5cd Mon Sep 17 00:00:00 2001 From: Xubin Ren <52506698+Re-bin@users.noreply.github.com> Date: Thu, 21 May 2026 16:06:52 +0800 Subject: [PATCH 40/54] fix(exec): stabilize Windows shell tests --- nanobot/agent/tools/shell.py | 3 ++- tests/tools/test_exec_session_tools.py | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index 25fbff50c..537c89343 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -414,7 +414,8 @@ class ExecTool(Tool): ) shell_program = shell_program or shutil.which("bash") or "/bin/bash" args = [shell_program] - if login and Path(shell_program).name in {"bash", "zsh"}: + shell_name = Path(shell_program).name.lower() + if login and shell_name in {"bash", "bash.exe", "zsh", "zsh.exe"}: args.append("-l") args.extend(["-c", command]) return await asyncio.create_subprocess_exec( diff --git a/tests/tools/test_exec_session_tools.py b/tests/tools/test_exec_session_tools.py index 76b3c9781..f5fe45e96 100644 --- a/tests/tools/test_exec_session_tools.py +++ b/tests/tools/test_exec_session_tools.py @@ -37,7 +37,10 @@ def test_exec_keeps_one_shot_behavior_without_yield_time_ms(tmp_path): def test_exec_accepts_command_aliases(tmp_path): async def run() -> str: tool = ExecTool(working_dir="/") - return await tool.execute(cmd="pwd", workdir=str(tmp_path)) + return await tool.execute( + cmd=_python_command("import os; print(os.getcwd())"), + workdir=str(tmp_path), + ) result = asyncio.run(run()) From 0d1d23b5fb58fb825b120cb39889dc2f1234574d Mon Sep 17 00:00:00 2001 From: Alex-wuhu Date: Wed, 20 May 2026 16:19:37 +0800 Subject: [PATCH 41/54] feat: add Novita AI provider --- docs/configuration.md | 1 + nanobot/config/schema.py | 1 + nanobot/providers/registry.py | 12 ++++++++++++ tests/config/test_model_presets.py | 16 ++++++++++++++++ tests/providers/test_litellm_kwargs.py | 9 +++++++++ 5 files changed, 39 insertions(+) diff --git a/docs/configuration.md b/docs/configuration.md index dbd5e2626..17f73619e 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -148,6 +148,7 @@ ANTHROPIC_API_KEY="$(bw get password api/anthropic)" nanobot agent | `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) | | `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) | | `siliconflow` | LLM (SiliconFlow/硅基流动) | [siliconflow.cn](https://siliconflow.cn) | +| `novita` | LLM (Novita AI, OpenAI-compatible gateway) | [novita.ai](https://novita.ai) | | `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) | | `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) | | `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) | diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index c0ad7e758..2e094cc09 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -211,6 +211,7 @@ class ProvidersConfig(Base): ant_ling: ProviderConfig = Field(default_factory=ProviderConfig) # Ant Ling aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动) + novita: ProviderConfig = Field(default_factory=ProviderConfig) # Novita AI volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎) volcengine_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine Coding Plan byteplus: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus (VolcEngine international) diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py index d942c03bf..ab7e2cf1e 100644 --- a/nanobot/providers/registry.py +++ b/nanobot/providers/registry.py @@ -199,6 +199,18 @@ PROVIDERS: tuple[ProviderSpec, ...] = ( default_api_base="https://api.siliconflow.cn/v1", ), + # Novita AI: OpenAI-compatible gateway for hosted model APIs. + ProviderSpec( + name="novita", + keywords=("novita",), + env_key="NOVITA_API_KEY", + display_name="Novita AI", + backend="openai_compat", + is_gateway=True, + detect_by_base_keyword="novita", + default_api_base="https://api.novita.ai/openai", + ), + # VolcEngine (火山引擎): OpenAI-compatible gateway, pay-per-use models ProviderSpec( name="volcengine", diff --git a/tests/config/test_model_presets.py b/tests/config/test_model_presets.py index 046c5b04d..4786326a9 100644 --- a/tests/config/test_model_presets.py +++ b/tests/config/test_model_presets.py @@ -192,3 +192,19 @@ def test_match_provider_uses_preset_provider_when_forced() -> None: }) name = config.get_provider_name() assert name == "anthropic" + + +def test_match_provider_routes_novita_prefixed_models() -> None: + config = Config.model_validate({ + "providers": { + "novita": {"apiKey": "sk-test"}, + }, + "agents": { + "defaults": { + "model": "novita/deepseek/deepseek-v4-pro", + } + }, + }) + + assert config.get_provider_name() == "novita" + assert config.get_api_base() == "https://api.novita.ai/openai" diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py index 6a32981d9..dddc70054 100644 --- a/tests/providers/test_litellm_kwargs.py +++ b/tests/providers/test_litellm_kwargs.py @@ -441,6 +441,15 @@ def test_openrouter_spec_is_gateway() -> None: assert spec.default_api_base == "https://openrouter.ai/api/v1" +def test_novita_spec_uses_openai_compatible_gateway() -> None: + spec = find_by_name("novita") + assert spec is not None + assert spec.is_gateway is True + assert spec.backend == "openai_compat" + assert spec.env_key == "NOVITA_API_KEY" + assert spec.default_api_base == "https://api.novita.ai/openai" + + def test_gemma_routes_to_gemini_provider() -> None: """gemma models (e.g. gemma-3-27b-it) must auto-route to Gemini when GEMINI_API_KEY is set. Users running gemma via the Gemini API endpoint expect automatic provider detection.""" From e5476573f4a5b6b3e4367bc613e6050a6b89ba90 Mon Sep 17 00:00:00 2001 From: Alex-wuhu Date: Wed, 20 May 2026 16:38:57 +0800 Subject: [PATCH 42/54] test(providers): align Novita provider coverage --- docs/configuration.md | 2 +- nanobot/providers/registry.py | 3 +- tests/config/test_model_presets.py | 5 +- tests/providers/test_novita_provider.py | 78 +++++++++++++++++++++++++ 4 files changed, 84 insertions(+), 4 deletions(-) create mode 100644 tests/providers/test_novita_provider.py diff --git a/docs/configuration.md b/docs/configuration.md index 17f73619e..f309f5376 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -148,7 +148,7 @@ ANTHROPIC_API_KEY="$(bw get password api/anthropic)" nanobot agent | `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) | | `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) | | `siliconflow` | LLM (SiliconFlow/硅基流动) | [siliconflow.cn](https://siliconflow.cn) | -| `novita` | LLM (Novita AI, OpenAI-compatible gateway) | [novita.ai](https://novita.ai) | +| `novita` | LLM (NovitaAI Model API: 200+ models on an AI-native cloud for builders and agents) | [novita.ai](https://novita.ai) | | `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) | | `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) | | `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) | diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py index ab7e2cf1e..04a3d3757 100644 --- a/nanobot/providers/registry.py +++ b/nanobot/providers/registry.py @@ -199,7 +199,8 @@ PROVIDERS: tuple[ProviderSpec, ...] = ( default_api_base="https://api.siliconflow.cn/v1", ), - # Novita AI: OpenAI-compatible gateway for hosted model APIs. + # NovitaAI: AI-native cloud for builders and agents. Model API exposes + # 200+ models through an OpenAI-compatible gateway. ProviderSpec( name="novita", keywords=("novita",), diff --git a/tests/config/test_model_presets.py b/tests/config/test_model_presets.py index 4786326a9..fe01c2547 100644 --- a/tests/config/test_model_presets.py +++ b/tests/config/test_model_presets.py @@ -194,14 +194,15 @@ def test_match_provider_uses_preset_provider_when_forced() -> None: assert name == "anthropic" -def test_match_provider_routes_novita_prefixed_models() -> None: +def test_match_provider_routes_forced_novita_model_api_models() -> None: config = Config.model_validate({ "providers": { "novita": {"apiKey": "sk-test"}, }, "agents": { "defaults": { - "model": "novita/deepseek/deepseek-v4-pro", + "model": "deepseek-v4-pro", + "provider": "novita", } }, }) diff --git a/tests/providers/test_novita_provider.py b/tests/providers/test_novita_provider.py new file mode 100644 index 000000000..820529e34 --- /dev/null +++ b/tests/providers/test_novita_provider.py @@ -0,0 +1,78 @@ +"""Tests for the Novita AI provider registration.""" + +from unittest.mock import patch + +from nanobot.config.schema import Config, ProvidersConfig +from nanobot.providers.openai_compat_provider import OpenAICompatProvider +from nanobot.providers.registry import PROVIDERS, find_by_name + + +def test_novita_config_field_exists() -> None: + config = ProvidersConfig() + + assert hasattr(config, "novita") + + +def test_novita_provider_in_registry() -> None: + specs = {spec.name: spec for spec in PROVIDERS} + + assert "novita" in specs + novita = specs["novita"] + assert novita.backend == "openai_compat" + assert novita.env_key == "NOVITA_API_KEY" + assert novita.display_name == "Novita AI" + assert novita.is_gateway is True + assert novita.detect_by_base_keyword == "novita" + assert novita.default_api_base == "https://api.novita.ai/openai" + assert novita.strip_model_prefix is False + + +def test_find_by_name_novita() -> None: + spec = find_by_name("novita") + + assert spec is not None + assert spec.name == "novita" + + +def test_novita_forced_provider_uses_default_api_base() -> None: + config = Config.model_validate({ + "providers": { + "novita": { + "apiKey": "novita-key", + }, + }, + "agents": { + "defaults": { + "model": "deepseek-v4-pro", + "provider": "novita", + }, + }, + }) + + assert config.get_provider_name("deepseek-v4-pro") == "novita" + assert config.get_api_key("deepseek-v4-pro") == "novita-key" + assert config.get_api_base("deepseek-v4-pro") == "https://api.novita.ai/openai" + + +def test_novita_preserves_model_api_id() -> None: + spec = find_by_name("novita") + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider( + api_key="novita-key", + default_model="deepseek-v4-pro", + spec=spec, + ) + + kwargs = provider._build_kwargs( + messages=[{"role": "user", "content": "hi"}], + tools=None, + model="deepseek-v4-pro", + max_tokens=1024, + temperature=0.7, + reasoning_effort=None, + tool_choice=None, + ) + + assert kwargs["model"] == "deepseek-v4-pro" + assert kwargs["max_tokens"] == 1024 + assert "max_completion_tokens" not in kwargs From 8281cd1946bc7113e6ed5ebf49d84e043a1a93e5 Mon Sep 17 00:00:00 2001 From: Xubin Ren <52506698+Re-bin@users.noreply.github.com> Date: Wed, 20 May 2026 19:33:25 +0800 Subject: [PATCH 43/54] test(providers): cover Novita gateway fallback --- docs/configuration.md | 2 +- nanobot/providers/registry.py | 3 +-- tests/providers/test_novita_provider.py | 19 +++++++++++++++++++ 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index f309f5376..e4fbe83eb 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -148,7 +148,7 @@ ANTHROPIC_API_KEY="$(bw get password api/anthropic)" nanobot agent | `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) | | `aihubmix` | LLM (API gateway, access to all models) | [aihubmix.com](https://aihubmix.com) | | `siliconflow` | LLM (SiliconFlow/硅基流动) | [siliconflow.cn](https://siliconflow.cn) | -| `novita` | LLM (NovitaAI Model API: 200+ models on an AI-native cloud for builders and agents) | [novita.ai](https://novita.ai) | +| `novita` | LLM (Novita AI OpenAI-compatible gateway) | [novita.ai](https://novita.ai) | | `dashscope` | LLM (Qwen) | [dashscope.console.aliyun.com](https://dashscope.console.aliyun.com) | | `moonshot` | LLM (Moonshot/Kimi) | [platform.moonshot.cn](https://platform.moonshot.cn) | | `zhipu` | LLM (Zhipu GLM) | [open.bigmodel.cn](https://open.bigmodel.cn) | diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py index 04a3d3757..ab7e2cf1e 100644 --- a/nanobot/providers/registry.py +++ b/nanobot/providers/registry.py @@ -199,8 +199,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = ( default_api_base="https://api.siliconflow.cn/v1", ), - # NovitaAI: AI-native cloud for builders and agents. Model API exposes - # 200+ models through an OpenAI-compatible gateway. + # Novita AI: OpenAI-compatible gateway for hosted model APIs. ProviderSpec( name="novita", keywords=("novita",), diff --git a/tests/providers/test_novita_provider.py b/tests/providers/test_novita_provider.py index 820529e34..0b1e8ec12 100644 --- a/tests/providers/test_novita_provider.py +++ b/tests/providers/test_novita_provider.py @@ -54,6 +54,25 @@ def test_novita_forced_provider_uses_default_api_base() -> None: assert config.get_api_base("deepseek-v4-pro") == "https://api.novita.ai/openai" +def test_novita_gateway_routes_unprefixed_models_when_configured() -> None: + config = Config.model_validate({ + "providers": { + "novita": { + "apiKey": "novita-key", + }, + }, + "agents": { + "defaults": { + "model": "deepseek-v4-pro", + }, + }, + }) + + assert config.get_provider_name("deepseek-v4-pro") == "novita" + assert config.get_api_key("deepseek-v4-pro") == "novita-key" + assert config.get_api_base("deepseek-v4-pro") == "https://api.novita.ai/openai" + + def test_novita_preserves_model_api_id() -> None: spec = find_by_name("novita") with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): From cb7daa77db1f40ad34849a895567136b1542c151 Mon Sep 17 00:00:00 2001 From: Xubin Ren <52506698+Re-bin@users.noreply.github.com> Date: Fri, 22 May 2026 00:31:55 +0800 Subject: [PATCH 44/54] feat(webui): refine collapsible sidebar --- webui/src/App.tsx | 25 +- webui/src/components/Sidebar.tsx | 259 +++++++++++++------ webui/src/components/thread/ThreadHeader.tsx | 25 +- webui/src/tests/app-layout.test.tsx | 10 +- 4 files changed, 225 insertions(+), 94 deletions(-) diff --git a/webui/src/App.tsx b/webui/src/App.tsx index c303446e2..aa3cf0cf8 100644 --- a/webui/src/App.tsx +++ b/webui/src/App.tsx @@ -43,6 +43,7 @@ const SIDEBAR_STORAGE_KEY = "nanobot-webui.sidebar"; const COMPLETED_RUNS_STORAGE_KEY = "nanobot-webui.sidebar.completed-runs.v1"; const RESTART_STARTED_KEY = "nanobot-webui.restartStartedAt"; const SIDEBAR_WIDTH = 272; +const SIDEBAR_RAIL_WIDTH = 56; const TOKEN_REFRESH_MARGIN_MS = 30_000; const TOKEN_REFRESH_MIN_DELAY_MS = 5_000; type ShellView = "chat" | "settings"; @@ -411,6 +412,10 @@ function Shell({ setDesktopSidebarOpen(false); }, []); + const openDesktopSidebar = useCallback(() => { + setDesktopSidebarOpen(true); + }, []); + const closeMobileSidebar = useCallback(() => { setMobileSidebarOpen(false); }, []); @@ -732,17 +737,19 @@ function Shell({ "relative z-20 hidden shrink-0 overflow-hidden lg:block", "transition-[width] duration-300 ease-out", )} - style={{ width: desktopSidebarOpen ? SIDEBAR_WIDTH : 0 }} + style={{ + width: desktopSidebarOpen ? SIDEBAR_WIDTH : SIDEBAR_RAIL_WIDTH, + }} >
- +
) : null} @@ -797,7 +804,7 @@ function Shell({ onTurnEnd={onTurnEnd} theme={theme} onToggleTheme={toggle} - hideSidebarToggleOnDesktop={desktopSidebarOpen} + hideSidebarToggleOnDesktop />
{view === "settings" && ( diff --git a/webui/src/components/Sidebar.tsx b/webui/src/components/Sidebar.tsx index 53acfc3a9..1040c9696 100644 --- a/webui/src/components/Sidebar.tsx +++ b/webui/src/components/Sidebar.tsx @@ -1,4 +1,4 @@ -import { useState } from "react"; +import { useState, type ReactNode } from "react"; import { Archive, ListFilter, @@ -28,6 +28,7 @@ import type { SidebarSortMode, SidebarViewState, } from "@/lib/types"; +import { cn } from "@/lib/utils"; interface SidebarProps { sessions: ChatSummary[]; @@ -44,7 +45,9 @@ interface SidebarProps { onToggleArchived: () => void; onUpdateView: (view: Partial) => void; onCollapse: () => void; + onExpand?: () => void; containActionMenus?: boolean; + collapsed?: boolean; pinnedKeys?: string[]; archivedKeys?: string[]; titleOverrides?: Record; @@ -59,6 +62,8 @@ export function Sidebar(props: SidebarProps) { const { t } = useTranslation(); const [menuPortalContainer, setMenuPortalContainer] = useState(null); + const collapsed = Boolean(props.collapsed); + const toggleLabel = t("thread.header.toggleSidebar"); return ( ); } +function SidebarActionButton({ + collapsed, + label, + icon, + onClick, + className, +}: { + collapsed: boolean; + label: string; + icon: ReactNode; + onClick: () => void; + className?: string; +}) { + return ( + + ); +} + function SidebarViewMenu({ + compact = false, view, onUpdateView, }: { + compact?: boolean; view?: SidebarViewState; onUpdateView: (view: Partial) => void; }) { @@ -182,11 +268,28 @@ function SidebarViewMenu({ diff --git a/webui/src/components/thread/ThreadHeader.tsx b/webui/src/components/thread/ThreadHeader.tsx index 72136b10f..f929e7e04 100644 --- a/webui/src/components/thread/ThreadHeader.tsx +++ b/webui/src/components/thread/ThreadHeader.tsx @@ -32,12 +32,17 @@ export function ThreadHeader({ onClick={onToggleSidebar} className={cn( "h-7 w-7 rounded-md text-muted-foreground hover:bg-accent/35 hover:text-foreground", - hideSidebarToggleOnDesktop && "lg:pointer-events-none lg:opacity-0", + hideSidebarToggleOnDesktop && "lg:hidden", )} > - + ); } @@ -52,7 +57,7 @@ export function ThreadHeader({ onClick={onToggleSidebar} className={cn( "h-7 w-7 rounded-md text-muted-foreground hover:bg-accent/35 hover:text-foreground", - hideSidebarToggleOnDesktop && "lg:pointer-events-none lg:opacity-0", + hideSidebarToggleOnDesktop && "lg:hidden", )} > @@ -62,7 +67,12 @@ export function ThreadHeader({ - +
@@ -73,10 +83,12 @@ function ThemeButton({ theme, onToggleTheme, label, + className, }: { theme: "light" | "dark"; onToggleTheme: () => void; label: string; + className?: string; }) { return ( + + ) : null} ); -} +}); function SessionActivityIndicator({ state, @@ -366,6 +421,45 @@ function groupSessions( return groups; } +function limitGroups( + groups: Array<{ label: string; sessions: ChatSummary[] }>, + limit: number, + activeKey: string | null, +): Array<{ label: string; sessions: ChatSummary[] }> { + let remaining = Math.max(0, limit); + let activeVisible = !activeKey; + const out: Array<{ label: string; sessions: ChatSummary[] }> = []; + + for (const group of groups) { + const visible = remaining > 0 + ? group.sessions.slice(0, remaining) + : []; + remaining -= visible.length; + if (activeKey && visible.some((session) => session.key === activeKey)) { + activeVisible = true; + } + if (visible.length > 0) { + out.push({ label: group.label, sessions: visible }); + } + } + + if (activeVisible || !activeKey) return out; + + for (const group of groups) { + const active = group.sessions.find((session) => session.key === activeKey); + if (!active) continue; + const existing = out.find((item) => item.label === group.label); + if (existing) { + existing.sessions = [...existing.sessions, active]; + } else { + out.push({ label: group.label, sessions: [active] }); + } + return out; + } + + return out; +} + function sortSessions( sessions: ChatSummary[], sort: SidebarSortMode, diff --git a/webui/src/components/SessionSearchDialog.tsx b/webui/src/components/SessionSearchDialog.tsx index 3ddea0672..1e0f12044 100644 --- a/webui/src/components/SessionSearchDialog.tsx +++ b/webui/src/components/SessionSearchDialog.tsx @@ -37,13 +37,16 @@ export function SessionSearchDialog({ const [highlightedIndex, setHighlightedIndex] = useState(0); const normalizedQuery = query.trim().toLowerCase(); - const results = useMemo(() => { + const sessionResults = useMemo(() => { + if (!open) return []; if (!normalizedQuery) return sessions; const terms = normalizedQuery.split(/\s+/).filter(Boolean); return sessions.filter((session) => sessionMatchesTerms(session, terms, titleOverrides[session.key]), ); - }, [normalizedQuery, sessions, titleOverrides]); + }, [normalizedQuery, open, sessions, titleOverrides]); + const itemCount = sessionResults.length; + const shortcutLabel = useMemo(getSearchShortcutLabel, []); useEffect(() => { if (!open) return; @@ -58,9 +61,9 @@ export function SessionSearchDialog({ useEffect(() => { setHighlightedIndex((index) => - results.length === 0 ? 0 : Math.min(index, results.length - 1), + itemCount === 0 ? 0 : Math.min(index, itemCount - 1), ); - }, [results.length]); + }, [itemCount]); const handleSelect = (key: string) => { onOpenChange(false); @@ -71,17 +74,19 @@ export function SessionSearchDialog({ if (event.key === "ArrowDown") { event.preventDefault(); setHighlightedIndex((index) => - results.length === 0 ? 0 : Math.min(index + 1, results.length - 1), + itemCount === 0 ? 0 : (index + 1) % itemCount, ); return; } if (event.key === "ArrowUp") { event.preventDefault(); - setHighlightedIndex((index) => Math.max(index - 1, 0)); + setHighlightedIndex((index) => + itemCount === 0 ? 0 : (index - 1 + itemCount) % itemCount, + ); return; } if (event.key === "Enter") { - const highlighted = results[highlightedIndex]; + const highlighted = sessionResults[highlightedIndex]; if (!highlighted) return; event.preventDefault(); handleSelect(highlighted.key); @@ -125,70 +130,75 @@ export function SessionSearchDialog({ aria-label={t("sidebar.searchAria")} className="h-full min-w-0 flex-1 bg-transparent text-[15px] font-medium text-foreground outline-none placeholder:text-muted-foreground/75" /> + + {shortcutLabel} +
-
- {sectionLabel} -
+
+
+ {sectionLabel} +
- {loading && sessions.length === 0 ? ( -
- {t("chat.loading")} -
- ) : results.length === 0 ? ( -
- {emptyLabel} -
- ) : ( -
    - {results.map((session, index) => { - const title = titleOverrides[session.key]?.trim() || - session.title?.trim() || - deriveTitle(session.preview, t("chat.newChat")); - const preview = session.preview.trim(); - const showPreview = - preview.length > 0 && - preview.toLowerCase() !== title.trim().toLowerCase(); - const highlighted = index === highlightedIndex; - const active = session.key === activeKey; - return ( -
  • - -
  • - ); - })} -
- )} + {showPreview ? ( + + {preview} + + ) : null} + + + + ); + })} + + )} +
@@ -211,3 +221,13 @@ function sessionMatchesTerms( return terms.every((term) => haystack.includes(term)); } + +function getSearchShortcutLabel() { + if (typeof navigator === "undefined") return "Ctrl K"; + const platform = navigator.platform.toLowerCase(); + const apple = + platform.includes("mac") || + platform.includes("iphone") || + platform.includes("ipad"); + return apple ? "⌘K" : "Ctrl K"; +} diff --git a/webui/src/i18n/locales/en/common.json b/webui/src/i18n/locales/en/common.json index f44332f95..68822ddd5 100644 --- a/webui/src/i18n/locales/en/common.json +++ b/webui/src/i18n/locales/en/common.json @@ -271,6 +271,7 @@ "fallbackTitle": "Chat {{id}}", "loading": "Loading…", "noSessions": "No sessions yet.", + "showMore": "Show {{count}} more", "actions": "Chat actions for {{title}}", "activity": { "running": "Agent running", diff --git a/webui/src/i18n/locales/es/common.json b/webui/src/i18n/locales/es/common.json index e658803e1..e031b3eef 100644 --- a/webui/src/i18n/locales/es/common.json +++ b/webui/src/i18n/locales/es/common.json @@ -224,6 +224,7 @@ "fallbackTitle": "Chat {{id}}", "loading": "Cargando…", "noSessions": "Todavía no hay sesiones.", + "showMore": "Mostrar {{count}} más", "actions": "Acciones del chat {{title}}", "activity": { "running": "Agent running", diff --git a/webui/src/i18n/locales/fr/common.json b/webui/src/i18n/locales/fr/common.json index a6f80e729..d49edf640 100644 --- a/webui/src/i18n/locales/fr/common.json +++ b/webui/src/i18n/locales/fr/common.json @@ -224,6 +224,7 @@ "fallbackTitle": "Discussion {{id}}", "loading": "Chargement…", "noSessions": "Aucune session pour le moment.", + "showMore": "Afficher {{count}} de plus", "actions": "Actions de la discussion {{title}}", "activity": { "running": "Agent running", diff --git a/webui/src/i18n/locales/id/common.json b/webui/src/i18n/locales/id/common.json index e64db6029..41901b475 100644 --- a/webui/src/i18n/locales/id/common.json +++ b/webui/src/i18n/locales/id/common.json @@ -224,6 +224,7 @@ "fallbackTitle": "Obrolan {{id}}", "loading": "Memuat…", "noSessions": "Belum ada sesi.", + "showMore": "Tampilkan {{count}} lagi", "actions": "Aksi obrolan untuk {{title}}", "activity": { "running": "Agent running", diff --git a/webui/src/i18n/locales/ja/common.json b/webui/src/i18n/locales/ja/common.json index c1bacd12b..b061c6c79 100644 --- a/webui/src/i18n/locales/ja/common.json +++ b/webui/src/i18n/locales/ja/common.json @@ -224,6 +224,7 @@ "fallbackTitle": "チャット {{id}}", "loading": "読み込み中…", "noSessions": "まだセッションがありません。", + "showMore": "さらに {{count}} 件表示", "actions": "「{{title}}」のチャット操作", "activity": { "running": "Agent running", diff --git a/webui/src/i18n/locales/ko/common.json b/webui/src/i18n/locales/ko/common.json index f94936648..86a6c908e 100644 --- a/webui/src/i18n/locales/ko/common.json +++ b/webui/src/i18n/locales/ko/common.json @@ -224,6 +224,7 @@ "fallbackTitle": "채팅 {{id}}", "loading": "불러오는 중…", "noSessions": "아직 세션이 없습니다.", + "showMore": "{{count}}개 더 보기", "actions": "{{title}} 채팅 작업", "activity": { "running": "Agent running", diff --git a/webui/src/i18n/locales/vi/common.json b/webui/src/i18n/locales/vi/common.json index 805e82ced..25d4719b6 100644 --- a/webui/src/i18n/locales/vi/common.json +++ b/webui/src/i18n/locales/vi/common.json @@ -224,6 +224,7 @@ "fallbackTitle": "Trò chuyện {{id}}", "loading": "Đang tải…", "noSessions": "Chưa có phiên nào.", + "showMore": "Hiển thị thêm {{count}}", "actions": "Tác vụ cho cuộc trò chuyện {{title}}", "activity": { "running": "Agent running", diff --git a/webui/src/i18n/locales/zh-CN/common.json b/webui/src/i18n/locales/zh-CN/common.json index 18089c7ad..1ac9fde0e 100644 --- a/webui/src/i18n/locales/zh-CN/common.json +++ b/webui/src/i18n/locales/zh-CN/common.json @@ -259,6 +259,7 @@ "fallbackTitle": "对话 {{id}}", "loading": "加载中…", "noSessions": "还没有会话。", + "showMore": "再显示 {{count}} 个", "actions": "“{{title}}” 的会话操作", "activity": { "running": "Agent 正在运行", diff --git a/webui/src/i18n/locales/zh-TW/common.json b/webui/src/i18n/locales/zh-TW/common.json index cf1cd4aa6..d4360af14 100644 --- a/webui/src/i18n/locales/zh-TW/common.json +++ b/webui/src/i18n/locales/zh-TW/common.json @@ -224,6 +224,7 @@ "fallbackTitle": "對話 {{id}}", "loading": "載入中…", "noSessions": "目前還沒有會話。", + "showMore": "再顯示 {{count}} 個", "actions": "「{{title}}」的會話操作", "activity": { "running": "Agent 正在執行", diff --git a/webui/src/tests/app-layout.test.tsx b/webui/src/tests/app-layout.test.tsx index c7ee44a06..1ae4d29c2 100644 --- a/webui/src/tests/app-layout.test.tsx +++ b/webui/src/tests/app-layout.test.tsx @@ -992,6 +992,73 @@ describe("App layout", () => { ); }); + it("opens search from the keyboard shortcut", async () => { + mockSessions = [ + { + key: "websocket:chat-a", + channel: "websocket", + chatId: "chat-a", + createdAt: "2026-04-16T10:00:00Z", + updatedAt: "2026-04-16T10:00:00Z", + preview: "Existing chat", + }, + ]; + + render(); + + await waitFor(() => expect(connectSpy).toHaveBeenCalled()); + fireEvent.keyDown(window, { key: "k", metaKey: true }); + + const dialog = await screen.findByRole("dialog", { name: "Search" }); + expect(within(dialog).queryByText("Global actions")).not.toBeInTheDocument(); + expect(within(dialog).getByText("Existing chat")).toBeInTheDocument(); + + const textbox = within(dialog).getByRole("textbox", { name: "Search" }); + fireEvent.change(textbox, { target: { value: "missing" } }); + expect(within(dialog).queryByText("Existing chat")).not.toBeInTheDocument(); + + fireEvent.change(textbox, { target: { value: "existing" } }); + expect(within(dialog).getByText("Existing chat")).toBeInTheDocument(); + + fireEvent.keyDown(textbox, { key: "Enter" }); + await waitFor(() => + expect(screen.queryByRole("dialog", { name: "Search" })).not.toBeInTheDocument(), + ); + expect(createChatSpy).not.toHaveBeenCalled(); + }); + + it("keeps large sidebars light while search still covers every chat", async () => { + mockSessions = Array.from({ length: 170 }, (_, index) => { + const chatId = `chat-${index}`; + return { + key: `websocket:${chatId}`, + channel: "websocket" as const, + chatId, + createdAt: new Date(Date.UTC(2026, 3, 16, 12, 0 - index)).toISOString(), + updatedAt: new Date(Date.UTC(2026, 3, 16, 12, 0 - index)).toISOString(), + title: index === 169 ? "Hidden target" : `Bulk chat ${index}`, + preview: "", + }; + }); + + render(); + + await waitFor(() => expect(connectSpy).toHaveBeenCalled()); + const sidebar = screen.getByRole("navigation", { name: "Sidebar navigation" }); + await waitFor(() => + expect(within(sidebar).getByRole("button", { name: "Bulk chat 0" })).toBeInTheDocument(), + ); + expect(within(sidebar).queryByText("Hidden target")).not.toBeInTheDocument(); + expect(within(sidebar).getByRole("button", { name: "Show 10 more" })).toBeInTheDocument(); + + fireEvent.click(within(sidebar).getByRole("button", { name: "Search" })); + const dialog = await screen.findByRole("dialog", { name: "Search" }); + fireEvent.change(within(dialog).getByRole("textbox", { name: "Search" }), { + target: { value: "hidden" }, + }); + expect(within(dialog).getByText("Hidden target")).toBeInTheDocument(); + }); + it("opens a blank start page without creating an empty chat", async () => { mockSessions = [ { diff --git a/webui/src/tests/i18n.test.tsx b/webui/src/tests/i18n.test.tsx index d1359121c..e92c42500 100644 --- a/webui/src/tests/i18n.test.tsx +++ b/webui/src/tests/i18n.test.tsx @@ -78,6 +78,7 @@ describe("webui i18n", () => { const common = resource.common; expect(common.app.system.restarting).toBeTruthy(); expect(common.sidebar.settings).toBeTruthy(); + expect(common.chat.showMore).toBeTruthy(); expect(common.settings.sidebar.title).toBeTruthy(); expect(common.settings.backToChat).toBeTruthy(); for (const key of SETTINGS_NAV_KEYS) { From 9b2f452b6e4b8e96a9e893bb2ed0ec676591f1f7 Mon Sep 17 00:00:00 2001 From: "A.G. Bocsardi" Date: Thu, 21 May 2026 15:44:47 +0200 Subject: [PATCH 46/54] fix: drop redundant reasoning_effort for Kimi thinking models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Moonshot's API rejects requests that carry both 'reasoning_effort' (top-level kwarg) and 'thinking' (extra_body) at the same time. After the unified thinking-style injection loop injects the native 'thinking' param for kimi models, pop 'reasoning_effort' from kwargs since it is redundant and causes a 400 error. Uses _model_slug() + _KIMI_THINKING_MODELS lookup to stay consistent with the refactored code (the old _is_kimi_thinking_model helper was removed in 4f895e63). Existing kimi tests updated to assert 'reasoning_effort' is absent. Xiaomi MiMo models are unaffected — their API accepts both params. Closes #3939 --- nanobot/providers/openai_compat_provider.py | 8 ++++++++ tests/providers/test_litellm_kwargs.py | 9 +++++++++ 2 files changed, 17 insertions(+) diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index 3c1bf9b8f..8281d7d20 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -633,6 +633,14 @@ class OpenAICompatProvider(LLMProvider): if extra: kwargs.setdefault("extra_body", {}).update(extra) + # Moonshot rejects requests that carry both 'reasoning_effort' + # and the native 'thinking' param. We already expressed the + # user's intent via the provider-native shape, so drop the + # redundant wire-level kwarg. Only kimi models need this — + # Xiaomi's API accepts both params. + if _model_slug(model_name) in _KIMI_THINKING_MODELS: + kwargs.pop("reasoning_effort", None) + if tools: kwargs["tools"] = tools kwargs["tool_choice"] = tool_choice or "auto" diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py index dddc70054..924ee0060 100644 --- a/tests/providers/test_litellm_kwargs.py +++ b/tests/providers/test_litellm_kwargs.py @@ -1420,12 +1420,15 @@ def test_kimi_k25_thinking_enabled() -> None: """kimi-k2.5 with reasoning_effort set should opt in to thinking.""" kw = _build_kwargs_for("moonshot", "kimi-k2.5", reasoning_effort="medium") assert kw.get("extra_body") == {"thinking": {"type": "enabled"}} + # Moonshot rejects both 'reasoning_effort' and 'thinking' (#3939) + assert "reasoning_effort" not in kw def test_kimi_k25_thinking_disabled_for_minimal() -> None: """reasoning_effort='minimal' maps to thinking disabled for kimi-k2.5.""" kw = _build_kwargs_for("moonshot", "kimi-k2.5", reasoning_effort="minimal") assert kw.get("extra_body") == {"thinking": {"type": "disabled"}} + assert "reasoning_effort" not in kw def test_kimi_k25_no_extra_body_when_reasoning_effort_none() -> None: @@ -1445,12 +1448,15 @@ def test_kimi_k25_thinking_enabled_with_openrouter_prefix() -> None: "thinking": {"type": "enabled"}, "reasoning": {"effort": "medium"}, } + # Even via OR, reasoning_effort wire kwarg is dropped for kimi models + assert "reasoning_effort" not in kw def test_kimi_k26_thinking_enabled() -> None: """kimi-k2.6 with reasoning_effort set should opt in to thinking.""" kw = _build_kwargs_for("moonshot", "kimi-k2.6", reasoning_effort="medium") assert kw.get("extra_body") == {"thinking": {"type": "enabled"}} + assert "reasoning_effort" not in kw def test_kimi_k26_thinking_enabled_with_openrouter_prefix() -> None: @@ -1461,6 +1467,7 @@ def test_kimi_k26_thinking_enabled_with_openrouter_prefix() -> None: "thinking": {"type": "enabled"}, "reasoning": {"effort": "medium"}, } + assert "reasoning_effort" not in kw def test_moonshot_kimi_k26_temperature_override() -> None: @@ -1479,6 +1486,7 @@ def test_kimi_k26_code_preview_thinking_enabled() -> None: """k2.6-code-preview also supports thinking; should behave like k2.5.""" kw = _build_kwargs_for("moonshot", "k2.6-code-preview", reasoning_effort="high") assert kw.get("extra_body") == {"thinking": {"type": "enabled"}} + assert "reasoning_effort" not in kw def test_kimi_k2_series_no_thinking_injection() -> None: @@ -1508,6 +1516,7 @@ def test_kimi_k25_thinking_disabled_for_none_string() -> None: """reasoning_effort='none' maps to thinking disabled for kimi-k2.5.""" kw = _build_kwargs_for("moonshot", "kimi-k2.5", reasoning_effort="none") assert kw.get("extra_body") == {"thinking": {"type": "disabled"}} + assert "reasoning_effort" not in kw def test_dashscope_thinking_disabled_for_none_string() -> None: From effc1efd92ac47e204a7efe6256024151fc0435c Mon Sep 17 00:00:00 2001 From: Xubin Ren <52506698+Re-bin@users.noreply.github.com> Date: Fri, 22 May 2026 13:26:31 +0800 Subject: [PATCH 47/54] fix(webui): avoid misleading file edit counters --- .../thread/AgentActivityCluster.tsx | 44 +++++++++++-- .../src/tests/agent-activity-cluster.test.tsx | 63 +++++++++++++++++++ 2 files changed, 102 insertions(+), 5 deletions(-) diff --git a/webui/src/components/thread/AgentActivityCluster.tsx b/webui/src/components/thread/AgentActivityCluster.tsx index 46be8b43d..135f33ef6 100644 --- a/webui/src/components/thread/AgentActivityCluster.tsx +++ b/webui/src/components/thread/AgentActivityCluster.tsx @@ -27,6 +27,7 @@ interface ActivityCounts { fileCount: number; added: number; deleted: number; + hasDiffStats: boolean; hasEditingFiles: boolean; hasFailedFiles: boolean; primaryFilePath?: string; @@ -61,6 +62,7 @@ function countActivity(messages: UIMessage[], fileEdits: FileEditSummary[]): Act } let added = 0; let deleted = 0; + let hasDiffStats = false; let hasEditingFiles = false; let failedFileCount = 0; let primaryFilePath: string | undefined; @@ -77,6 +79,10 @@ function countActivity(messages: UIMessage[], fileEdits: FileEditSummary[]): Act if (edit.status === "error" || edit.binary) { continue; } + if (!hasVisibleDiffStats(edit)) { + continue; + } + hasDiffStats = true; added += edit.added; deleted += edit.deleted; } @@ -86,6 +92,7 @@ function countActivity(messages: UIMessage[], fileEdits: FileEditSummary[]): Act fileCount: fileEdits.length, added, deleted, + hasDiffStats, hasEditingFiles, hasFailedFiles: fileEdits.length > 0 && failedFileCount === fileEdits.length, primaryFilePath, @@ -120,6 +127,7 @@ export function AgentActivityCluster({ fileCount, added, deleted, + hasDiffStats, hasEditingFiles, hasFailedFiles, primaryFilePath, @@ -140,6 +148,7 @@ export function AgentActivityCluster({ const headerBusy = fileCount > 0 ? hasEditingFiles : isTurnStreaming; const singleFilePath = fileCount === 1 ? primaryFilePath : undefined; const singleFileTooltipPath = fileCount === 1 ? primaryFileTooltipPath : undefined; + const hasVisibleActivity = reasoningSteps > 0 || toolCalls > 0 || fileCount > 0; const fileActivitySummary = fileCount > 0 ? hasPendingFileEdit && !singleFilePath @@ -243,6 +252,8 @@ export function AgentActivityCluster({ autoFollowActivityRef.current = distance < ACTIVITY_SCROLL_NEAR_BOTTOM_PX; }, []); + if (!hasVisibleActivity) return null; + return (