From 6851fa57a6d0ca75e754065cc56c5c88a4332cc5 Mon Sep 17 00:00:00 2001 From: Xubin Ren <52506698+Re-bin@users.noreply.github.com> Date: Wed, 20 May 2026 12:51:26 +0800 Subject: [PATCH] feat(tools): optimize coding workflows --- README.md | 2 +- nanobot/agent/tools/apply_patch.py | 341 +++++++++++++++ nanobot/agent/tools/exec_session.py | 409 ++++++++++++++++++ nanobot/agent/tools/filesystem.py | 109 ++++- nanobot/agent/tools/notebook.py | 162 ------- nanobot/agent/tools/shell.py | 263 +++++++++-- nanobot/utils/file_edit_events.py | 76 +--- tests/tools/test_apply_patch_tool.py | 238 ++++++++++ tests/tools/test_edit_enhancements.py | 19 +- tests/tools/test_exec_session_tools.py | 242 +++++++++++ .../test_file_edit_coding_enhancements.py | 216 +++++++++ tests/tools/test_notebook_tool.py | 147 ------- tests/tools/test_tool_loader.py | 4 +- 13 files changed, 1771 insertions(+), 457 deletions(-) create mode 100644 nanobot/agent/tools/apply_patch.py create mode 100644 nanobot/agent/tools/exec_session.py delete mode 100644 nanobot/agent/tools/notebook.py create mode 100644 tests/tools/test_apply_patch_tool.py create mode 100644 tests/tools/test_exec_session_tools.py create mode 100644 tests/tools/test_file_edit_coding_enhancements.py delete mode 100644 tests/tools/test_notebook_tool.py diff --git a/README.md b/README.md index b5e4b02c0..b97545731 100644 --- a/README.md +++ b/README.md @@ -61,7 +61,7 @@ - **2026-04-13** 🛡️ Agent turn hardened — user messages persisted early, auto-compact skips active tasks. - **2026-04-12** 🔒 Lark global domain support, Dream learns discovered skills, shell sandbox tightened. - **2026-04-11** ⚡ Context compact shrinks sessions on the fly; Kagi web search; QQ & WeCom full media. -- **2026-04-10** 📓 Notebook editing tool, multiple MCP servers, Feishu streaming & done-emoji. +- **2026-04-10** 📓 Multiple MCP servers, Feishu streaming & done-emoji. - **2026-04-09** 🔌 WebSocket channel, unified cross-channel session, `disabled_skills` config. - **2026-04-08** 📤 API file uploads, OpenAI reasoning auto-routing with Responses fallback. - **2026-04-07** 🧠 Anthropic adaptive thinking, MCP resources & prompts exposed as tools. diff --git a/nanobot/agent/tools/apply_patch.py b/nanobot/agent/tools/apply_patch.py new file mode 100644 index 000000000..e69a65f10 --- /dev/null +++ b/nanobot/agent/tools/apply_patch.py @@ -0,0 +1,341 @@ +"""Structured patch editing tool for coding workflows.""" + +from __future__ import annotations + +import difflib +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Literal + +from nanobot.agent.tools.base import tool_parameters +from nanobot.agent.tools.filesystem import _FsTool +from nanobot.agent.tools.schema import StringSchema, tool_parameters_schema + + +PatchKind = Literal["add", "delete", "update"] + + +@dataclass(slots=True) +class _Hunk: + header: str | None + lines: list[tuple[str, str]] + + +@dataclass(slots=True) +class _PatchOp: + kind: PatchKind + path: str + new_path: str | None = None + add_lines: list[str] | None = None + hunks: list[_Hunk] | None = None + + +class _PatchError(ValueError): + pass + + +_ABSOLUTE_WINDOWS_RE = re.compile(r"^[A-Za-z]:[\\/]") + + +def _is_file_header(line: str) -> bool: + return ( + line.startswith("*** Add File: ") + or line.startswith("*** Delete File: ") + or line.startswith("*** Update File: ") + ) + + +def _validate_relative_path(path: str) -> str: + normalized = path.strip() + if not normalized: + raise _PatchError("patch path cannot be empty") + if "\0" in normalized: + raise _PatchError(f"patch path contains a null byte: {path!r}") + if normalized.startswith(("~", "/", "\\")) or _ABSOLUTE_WINDOWS_RE.match(normalized): + raise _PatchError(f"patch path must be relative: {path}") + if any(part == ".." for part in re.split(r"[\\/]+", normalized)): + raise _PatchError(f"patch path must not contain '..': {path}") + return normalized + + +def _lines_to_text(lines: list[str]) -> str: + if not lines: + return "" + return "\n".join(lines) + "\n" + + +def _parse_patch(patch: str) -> list[_PatchOp]: + lines = patch.replace("\r\n", "\n").replace("\r", "\n").split("\n") + if lines and lines[-1] == "": + lines.pop() + if not lines or lines[0] != "*** Begin Patch": + raise _PatchError("patch must start with '*** Begin Patch'") + if len(lines) < 2 or lines[-1] != "*** End Patch": + raise _PatchError("patch must end with '*** End Patch'") + + ops: list[_PatchOp] = [] + i = 1 + end = len(lines) - 1 + while i < end: + line = lines[i] + if line.startswith("*** Add File: "): + path = _validate_relative_path(line.removeprefix("*** Add File: ")) + i += 1 + add_lines: list[str] = [] + while i < end and not _is_file_header(lines[i]): + if not lines[i].startswith("+"): + raise _PatchError(f"Add File lines must start with '+': {lines[i]!r}") + add_lines.append(lines[i][1:]) + i += 1 + ops.append(_PatchOp(kind="add", path=path, add_lines=add_lines)) + continue + + if line.startswith("*** Delete File: "): + path = _validate_relative_path(line.removeprefix("*** Delete File: ")) + ops.append(_PatchOp(kind="delete", path=path)) + i += 1 + continue + + if line.startswith("*** Update File: "): + path = _validate_relative_path(line.removeprefix("*** Update File: ")) + i += 1 + new_path: str | None = None + if i < end and lines[i].startswith("*** Move to: "): + new_path = _validate_relative_path(lines[i].removeprefix("*** Move to: ")) + i += 1 + + hunks: list[_Hunk] = [] + while i < end and not _is_file_header(lines[i]): + if not lines[i].startswith("@@"): + raise _PatchError(f"Update File sections require '@@' hunks: {lines[i]!r}") + header = lines[i][2:].strip() or None + i += 1 + hunk_lines: list[tuple[str, str]] = [] + while i < end and not lines[i].startswith("@@") and not _is_file_header(lines[i]): + if lines[i] == "*** End of File": + i += 1 + break + if lines[i] == r"\ No newline at end of file": + i += 1 + continue + if not lines[i] or lines[i][0] not in {" ", "+", "-"}: + raise _PatchError(f"Hunk lines must start with ' ', '+', or '-': {lines[i]!r}") + hunk_lines.append((lines[i][0], lines[i][1:])) + i += 1 + if not hunk_lines: + raise _PatchError(f"Update File hunk is empty: {path}") + hunks.append(_Hunk(header=header, lines=hunk_lines)) + if not hunks and new_path is None: + raise _PatchError(f"Update File requires at least one hunk or Move to: {path}") + ops.append(_PatchOp(kind="update", path=path, new_path=new_path, hunks=hunks)) + continue + + raise _PatchError(f"unknown patch header: {line!r}") + + if not ops: + raise _PatchError("patch contains no file operations") + return ops + + +def _find_with_eof_fallback(content: str, needle: str, start: int) -> tuple[int, int]: + pos = content.find(needle, start) + if pos >= 0: + return pos, len(needle) + if needle.endswith("\n"): + trimmed = needle[:-1] + pos = content.find(trimmed, start) + if pos >= 0 and pos + len(trimmed) == len(content): + return pos, len(trimmed) + return -1, 0 + + +def _line_offset(content: str, line_number: int) -> int: + if line_number <= 1: + return 0 + offset = 0 + for current, line in enumerate(content.splitlines(keepends=True), start=1): + if current >= line_number: + return offset + offset += len(line) + return len(content) + + +def _line_hint(header: str | None) -> int | None: + if not header: + return None + match = re.search(r"-(\d+)(?:,\d+)?", header) + return int(match.group(1)) if match else None + + +def _hunk_mismatch(path: str, old_text: str, content: str, header: str | None) -> str: + lines = content.splitlines(keepends=True) + old_lines = old_text.splitlines(keepends=True) + window = max(1, len(old_lines)) + best_ratio, best_start = -1.0, 0 + best_lines: list[str] = [] + for i in range(max(1, len(lines) - window + 1)): + current = lines[i : i + window] + ratio = difflib.SequenceMatcher(None, "".join(old_lines), "".join(current)).ratio() + if ratio > best_ratio: + best_ratio, best_start, best_lines = ratio, i, current + + label = f" after header {header!r}" if header else "" + if best_ratio <= 0: + return f"hunk does not match {path}{label}" + diff = "\n".join(difflib.unified_diff( + old_lines, + best_lines, + fromfile="patch hunk", + tofile=f"{path} (actual, line {best_start + 1})", + lineterm="", + )) + return ( + f"hunk does not match {path}{label}. " + f"Best match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}" + ) + + +def _apply_hunks(path: str, content: str, hunks: list[_Hunk]) -> str: + cursor = 0 + for hunk in hunks: + old_lines = [text for marker, text in hunk.lines if marker in {" ", "-"}] + new_lines = [text for marker, text in hunk.lines if marker in {" ", "+"}] + old_text = _lines_to_text(old_lines) + new_text = _lines_to_text(new_lines) + + search_start = cursor + line_hint = None + if hunk.header: + line_hint = _line_hint(hunk.header) + if line_hint is not None: + search_start = _line_offset(content, line_hint) + else: + header_pos = content.find(hunk.header, cursor) + if header_pos >= 0: + search_start = header_pos + + if old_text: + pos, match_len = _find_with_eof_fallback(content, old_text, search_start) + if pos < 0 and search_start != 0 and line_hint is None: + pos, match_len = _find_with_eof_fallback(content, old_text, 0) + if pos < 0: + raise _PatchError(_hunk_mismatch(path, old_text, content, hunk.header)) + else: + pos = search_start + match_len = 0 + + content = content[:pos] + new_text + content[pos + match_len:] + cursor = pos + len(new_text) + return content + + +@tool_parameters( + tool_parameters_schema( + patch=StringSchema( + "Full patch text. Use *** Begin Patch / *** End Patch and file sections " + "for Add File, Update File, Delete File, and optional Move to.", + min_length=1, + ), + required=["patch"], + ) +) +class ApplyPatchTool(_FsTool): + """Apply a structured multi-file patch.""" + _scopes = {"core", "subagent"} + + @property + def name(self) -> str: + return "apply_patch" + + @property + def description(self) -> str: + return ( + "Apply a structured patch for code edits. The patch must include " + "*** Begin Patch and *** End Patch. Supports Add File, Update File, " + "Delete File, and Move to. Paths must be relative. Prefer this for " + "multi-file coding changes; use edit_file for small exact replacements." + ) + + async def execute(self, patch: str, **kwargs: Any) -> str: + try: + ops = _parse_patch(patch) + writes: dict[Path, str] = {} + deletes: set[Path] = set() + touched: list[str] = [] + + for op in ops: + source = self._resolve(op.path) + if op.kind == "add": + if source.exists(): + raise _PatchError(f"file to add already exists: {op.path}") + writes[source] = _lines_to_text(op.add_lines or []) + deletes.discard(source) + touched.append(f"add {op.path}") + continue + + if op.kind == "delete": + if not source.exists(): + raise _PatchError(f"file to delete does not exist: {op.path}") + if not source.is_file(): + raise _PatchError(f"path to delete is not a file: {op.path}") + deletes.add(source) + writes.pop(source, None) + touched.append(f"delete {op.path}") + continue + + if not source.exists(): + raise _PatchError(f"file to update does not exist: {op.path}") + if not source.is_file(): + raise _PatchError(f"path to update is not a file: {op.path}") + raw = source.read_bytes() + try: + content = raw.decode("utf-8") + except UnicodeDecodeError as exc: + raise _PatchError(f"file to update is not UTF-8 text: {op.path}") from exc + uses_crlf = "\r\n" in content + content = content.replace("\r\n", "\n") + new_content = _apply_hunks(op.path, content, op.hunks or []) + if uses_crlf: + new_content = new_content.replace("\n", "\r\n") + + target = self._resolve(op.new_path) if op.new_path else source + if op.new_path and target.exists() and target != source: + raise _PatchError(f"move target already exists: {op.new_path}") + writes[target] = new_content + deletes.discard(target) + if target != source: + deletes.add(source) + action = f"move {op.path} -> {op.new_path}" if op.new_path else f"update {op.path}" + touched.append(action) + + backups: dict[Path, bytes | None] = {} + for path in set(writes) | deletes: + backups[path] = path.read_bytes() if path.exists() else None + + try: + for path in deletes: + if path.exists(): + path.unlink() + for path, content in writes.items(): + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(content, encoding="utf-8", newline="") + except Exception: + for path, data in backups.items(): + if data is None: + if path.exists(): + path.unlink() + else: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(data) + raise + + for path in set(writes) | deletes: + self._file_states.record_write(path) + return "Patch applied:\n" + "\n".join(f"- {item}" for item in touched) + except PermissionError as exc: + return f"Error: {exc}" + except _PatchError as exc: + return f"Error applying patch: {exc}" + except Exception as exc: + return f"Error applying patch: {exc}" diff --git a/nanobot/agent/tools/exec_session.py b/nanobot/agent/tools/exec_session.py new file mode 100644 index 000000000..202fbc640 --- /dev/null +++ b/nanobot/agent/tools/exec_session.py @@ -0,0 +1,409 @@ +"""Session support for long-running exec workflows.""" + +from __future__ import annotations + +import asyncio +import shutil +import time +import uuid +from contextlib import suppress +from dataclasses import dataclass +from typing import Any + +from nanobot.agent.tools.base import Tool, tool_parameters +from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema + + +DEFAULT_YIELD_MS = 1000 +MAX_YIELD_MS = 30_000 +DEFAULT_MAX_OUTPUT_CHARS = 10_000 +MAX_OUTPUT_CHARS = 50_000 + + +@dataclass(slots=True) +class _SessionPoll: + output: str + done: bool + exit_code: int | None + timed_out: bool = False + terminated: bool = False + stdin_closed: bool = False + truncated_chars: int = 0 + + +class _ExecSession: + def __init__( + self, + *, + session_id: str, + process: asyncio.subprocess.Process, + timeout: int, + ) -> None: + self.session_id = session_id + self.process = process + self.deadline = time.monotonic() + timeout + self.last_access = time.monotonic() + self._chunks: list[str] = [] + self._lock = asyncio.Lock() + self._timed_out = False + self._stdout_task = asyncio.create_task(self._read_stream(process.stdout, "")) + self._stderr_task = asyncio.create_task(self._read_stream(process.stderr, "STDERR:\n")) + + async def _read_stream( + self, + stream: asyncio.StreamReader | None, + prefix: str, + ) -> None: + if stream is None: + return + first = True + while True: + chunk = await stream.read(4096) + if not chunk: + break + text = chunk.decode("utf-8", errors="replace") + if prefix and first: + text = prefix + text + first = False + async with self._lock: + self._chunks.append(text) + + async def write(self, chars: str) -> str | None: + if self.process.returncode is not None: + return "session has already exited" + if self.process.stdin is None: + return "session stdin is not available" + try: + self.process.stdin.write(chars.encode("utf-8")) + await self.process.stdin.drain() + except (BrokenPipeError, ConnectionResetError): + return "session stdin is closed" + return None + + async def close_stdin(self) -> str | None: + if self.process.returncode is not None: + return "session has already exited" + if self.process.stdin is None: + return "session stdin is not available" + self.process.stdin.close() + with suppress(BrokenPipeError, ConnectionResetError): + await self.process.stdin.wait_closed() + return None + + async def poll( + self, + yield_time_ms: int, + max_output_chars: int, + *, + terminated: bool = False, + stdin_closed: bool = False, + ) -> _SessionPoll: + self.last_access = time.monotonic() + if yield_time_ms > 0 and self.process.returncode is None: + await asyncio.sleep(min(yield_time_ms, MAX_YIELD_MS) / 1000) + + if self.process.returncode is None and time.monotonic() >= self.deadline: + self._timed_out = True + await self.kill() + + if self.process.returncode is not None: + with suppress(asyncio.TimeoutError): + await asyncio.wait_for( + asyncio.gather(self._stdout_task, self._stderr_task), + timeout=2.0, + ) + + async with self._lock: + output = "".join(self._chunks) + self._chunks.clear() + + output, truncated = _truncate_output(output, max_output_chars) + return _SessionPoll( + output=output, + done=self.process.returncode is not None, + exit_code=self.process.returncode, + timed_out=self._timed_out, + terminated=terminated, + stdin_closed=stdin_closed, + truncated_chars=truncated, + ) + + async def kill(self) -> None: + if self.process.returncode is not None: + return + self.process.kill() + with suppress(asyncio.TimeoutError): + await asyncio.wait_for(self.process.wait(), timeout=5.0) + + +class ExecSessionManager: + def __init__(self, *, max_sessions: int = 8, idle_timeout: int = 1800) -> None: + self.max_sessions = max_sessions + self.idle_timeout = idle_timeout + self._sessions: dict[str, _ExecSession] = {} + self._lock = asyncio.Lock() + + async def start( + self, + *, + command: str, + cwd: str, + env: dict[str, str], + timeout: int, + shell_program: str | None, + login: bool, + yield_time_ms: int, + max_output_chars: int, + ) -> tuple[str, _SessionPoll]: + async with self._lock: + await self._cleanup_locked() + if len(self._sessions) >= self.max_sessions: + raise RuntimeError(f"maximum exec sessions reached ({self.max_sessions})") + process = await self._spawn(command, cwd, env, shell_program, login) + session_id = uuid.uuid4().hex[:12] + session = _ExecSession(session_id=session_id, process=process, timeout=timeout) + self._sessions[session_id] = session + + poll = await session.poll(yield_time_ms, max_output_chars) + if poll.done: + async with self._lock: + self._sessions.pop(session_id, None) + return session_id, poll + + async def write( + self, + *, + session_id: str, + chars: str | None, + close_stdin: bool, + terminate: bool, + yield_time_ms: int, + max_output_chars: int, + ) -> _SessionPoll: + async with self._lock: + await self._cleanup_locked() + session = self._sessions.get(session_id) + if session is None: + raise KeyError(session_id) + + if chars is not None: + error = await session.write(chars) + if error: + raise RuntimeError(error) + stdin_closed = False + if close_stdin: + error = await session.close_stdin() + if error: + raise RuntimeError(error) + stdin_closed = True + if terminate: + await session.kill() + poll = await session.poll( + yield_time_ms, + max_output_chars, + terminated=terminate, + stdin_closed=stdin_closed, + ) + if poll.done: + async with self._lock: + self._sessions.pop(session_id, None) + return poll + + async def _cleanup_locked(self) -> None: + now = time.monotonic() + stale = [ + session_id + for session_id, session in self._sessions.items() + if session.process.returncode is not None + or now - session.last_access > self.idle_timeout + ] + for session_id in stale: + session = self._sessions.pop(session_id) + await session.kill() + + async def _spawn( + self, + command: str, + cwd: str, + env: dict[str, str], + shell_program: str | None, + login: bool, + ) -> asyncio.subprocess.Process: + from nanobot.agent.tools import shell + + if shell._IS_WINDOWS: + return await asyncio.create_subprocess_shell( + command, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=cwd, + env=env, + ) + shell_program = shell_program or shutil.which("bash") or "/bin/bash" + args = [shell_program] + if login and shell_program.rsplit("/", 1)[-1] in {"bash", "zsh"}: + args.append("-l") + args.extend(["-c", command]) + return await asyncio.create_subprocess_exec( + *args, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + cwd=cwd, + env=env, + ) + + +DEFAULT_EXEC_SESSION_MANAGER = ExecSessionManager() + + +def clamp_session_int(value: int | None, default: int, minimum: int, maximum: int) -> int: + if value is None: + return default + return min(max(value, minimum), maximum) + + +def _truncate_output(output: str, max_output_chars: int) -> tuple[str, int]: + if len(output) <= max_output_chars: + return output, 0 + half = max_output_chars // 2 + omitted = len(output) - max_output_chars + return ( + output[:half] + + f"\n\n... ({omitted:,} chars truncated) ...\n\n" + + output[-half:], + omitted, + ) + + +def format_session_poll(session_id: str, poll: _SessionPoll) -> str: + parts = [poll.output] if poll.output else [] + if poll.truncated_chars: + parts.append(f"(output truncated by {poll.truncated_chars:,} chars)") + if poll.timed_out: + parts.append("Error: Command timed out; session was terminated.") + if poll.terminated and not poll.timed_out: + parts.append("Session terminated.") + if poll.stdin_closed: + parts.append("Stdin closed.") + if poll.done: + parts.append(f"Exit code: {poll.exit_code}") + else: + parts.append(f"Process running. session_id: {session_id}") + return "\n".join(parts) if parts else "(no output yet)" + + +@tool_parameters( + tool_parameters_schema( + session_id=StringSchema("Session id returned by exec when yield_time_ms is used."), + chars=StringSchema( + "Bytes/text to write to stdin. Omit or pass an empty string to only poll recent output.", + nullable=True, + ), + close_stdin=BooleanSchema( + description="Close stdin after writing chars. Useful for commands waiting for EOF.", + default=False, + ), + terminate=BooleanSchema( + description="Terminate the running exec session.", + default=False, + ), + yield_time_ms=IntegerSchema( + DEFAULT_YIELD_MS, + description="Milliseconds to wait before returning recent output (default 1000, max 30000).", + minimum=0, + maximum=MAX_YIELD_MS, + ), + max_output_chars=IntegerSchema( + DEFAULT_MAX_OUTPUT_CHARS, + description="Maximum output characters to return from this poll (default 10000, max 50000).", + minimum=1000, + maximum=MAX_OUTPUT_CHARS, + ), + max_output_tokens=IntegerSchema( + DEFAULT_MAX_OUTPUT_CHARS, + description="Compatibility alias for max_output_chars. The current runtime uses a character budget.", + minimum=1000, + maximum=MAX_OUTPUT_CHARS, + nullable=True, + ), + required=["session_id"], + ) +) +class WriteStdinTool(Tool): + """Write to or poll a running exec session.""" + + _scopes = {"core", "subagent"} + config_key = "exec" + + @classmethod + def config_cls(cls): + from nanobot.agent.tools.shell import ExecToolConfig + + return ExecToolConfig + + @classmethod + def enabled(cls, ctx: Any) -> bool: + return ctx.config.exec.enable + + def __init__( + self, + *, + manager: ExecSessionManager | None = None, + ) -> None: + self._manager = manager or DEFAULT_EXEC_SESSION_MANAGER + + @classmethod + def create(cls, ctx: Any) -> Tool: + return cls() + + @property + def exclusive(self) -> bool: + return True + + @property + def name(self) -> str: + return "write_stdin" + + @property + def description(self) -> str: + return ( + "Write text to a running exec session and return recent output. " + "Use chars='' to poll without writing. Set close_stdin=true to send EOF, " + "or terminate=true to stop the session. Sessions finish automatically " + "when their process exits." + ) + + async def execute( + self, + session_id: str, + chars: str | None = None, + close_stdin: bool = False, + terminate: bool = False, + yield_time_ms: int | None = None, + max_output_chars: int | None = None, + max_output_tokens: int | None = None, + **kwargs: Any, + ) -> str: + try: + if max_output_chars is None: + max_output_chars = max_output_tokens + poll = await self._manager.write( + session_id=session_id, + chars=chars, + close_stdin=close_stdin, + terminate=terminate, + yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS), + max_output_chars=clamp_session_int( + max_output_chars, + DEFAULT_MAX_OUTPUT_CHARS, + 1000, + MAX_OUTPUT_CHARS, + ), + ) + return format_session_poll(session_id, poll) + except KeyError: + return f"Error: exec session not found: {session_id}" + except Exception as exc: + return f"Error writing to exec session: {exc}" diff --git a/nanobot/agent/tools/filesystem.py b/nanobot/agent/tools/filesystem.py index 8f4f660da..728ff9317 100644 --- a/nanobot/agent/tools/filesystem.py +++ b/nanobot/agent/tools/filesystem.py @@ -132,6 +132,10 @@ def _parse_page_range(pages: str, total: int) -> tuple[int, int]: minimum=1, ), pages=StringSchema("Page range for PDF files, e.g. '1-5' (default: all, max 20 pages)"), + force=BooleanSchema( + description="Bypass same-file read deduplication and return content again.", + default=False, + ), required=["path"], ) ) @@ -155,6 +159,7 @@ class ReadFileTool(_FsTool): "Images return visual content for analysis. " "Supports PDF, DOCX, XLSX, PPTX documents. " "Use offset and limit for large text files. " + "Use force=true to re-read content even if unchanged. " "Reads exceeding ~128K chars are truncated." ) @@ -162,7 +167,15 @@ class ReadFileTool(_FsTool): def read_only(self) -> bool: return True - async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, pages: str | None = None, **kwargs: Any) -> Any: + async def execute( + self, + path: str | None = None, + offset: int = 1, + limit: int | None = None, + pages: str | None = None, + force: bool = False, + **kwargs: Any, + ) -> Any: try: if not path: return "Error reading file: Unknown path" @@ -202,7 +215,13 @@ class ReadFileTool(_FsTool): current_mtime = os.path.getmtime(fp) except OSError: current_mtime = 0.0 - if entry and entry.can_dedup and entry.offset == offset and entry.limit == limit: + if ( + not force + and entry + and entry.can_dedup + and entry.offset == offset + and entry.limit == limit + ): if current_mtime != entry.mtime: # File was modified externally - force full read and mark as not dedupable entry.can_dedup = False @@ -657,6 +676,24 @@ def _find_match(content: str, old_text: str) -> tuple[str | None, int]: old_text=StringSchema("The text to find and replace"), new_text=StringSchema("The text to replace with"), replace_all=BooleanSchema(description="Replace all occurrences (default false)"), + occurrence=IntegerSchema( + 1, + description="Optional 1-based occurrence to replace when old_text appears multiple times.", + minimum=1, + nullable=True, + ), + line_hint=IntegerSchema( + 1, + description="Optional 1-based line hint used to choose the nearest match.", + minimum=1, + nullable=True, + ), + expected_replacements=IntegerSchema( + 1, + description="Optional guard for the number of replacements that must be made.", + minimum=1, + nullable=True, + ), required=["path", "old_text", "new_text"], ) ) @@ -677,7 +714,7 @@ class EditFileTool(_FsTool): "Edit a file by replacing old_text with new_text. " "Tolerates minor whitespace/indentation differences and curly/straight quote mismatches. " "If old_text matches multiple times, you must provide more context " - "or set replace_all=true. Shows a diff of the closest match on failure." + "or set occurrence/line_hint/replace_all. Shows a diff of the closest match on failure." ) @staticmethod @@ -688,7 +725,8 @@ class EditFileTool(_FsTool): async def execute( self, path: str | None = None, old_text: str | None = None, new_text: str | None = None, - replace_all: bool = False, **kwargs: Any, + replace_all: bool = False, occurrence: int | None = None, + line_hint: int | None = None, expected_replacements: int | None = None, **kwargs: Any, ) -> str: try: if not path: @@ -697,10 +735,12 @@ class EditFileTool(_FsTool): raise ValueError("Unknown old_text") if new_text is None: raise ValueError("Unknown new_text") - - # .ipynb detection - if path.endswith(".ipynb"): - return "Error: This is a Jupyter notebook. Use the notebook_edit tool instead of edit_file." + if occurrence is not None and occurrence < 1: + return "Error: occurrence must be >= 1." + if line_hint is not None and line_hint < 1: + return "Error: line_hint must be >= 1." + if expected_replacements is not None and expected_replacements < 1: + return "Error: expected_replacements must be >= 1." fp = self._resolve(path) @@ -743,15 +783,42 @@ class EditFileTool(_FsTool): if not matches: return self._not_found_msg(old_text, content, path) count = len(matches) + if replace_all and occurrence is not None: + return "Error: occurrence cannot be used with replace_all=true." + if replace_all and line_hint is not None: + return "Error: line_hint cannot be used with replace_all=true." + if occurrence is not None and line_hint is not None: + return "Error: line_hint cannot be used with occurrence." if count > 1 and not replace_all: - line_numbers = [match.line for match in matches] - preview = ", ".join(f"line {n}" for n in line_numbers[:3]) - if len(line_numbers) > 3: - preview += ", ..." - location_hint = f" at {preview}" if preview else "" + if occurrence is not None: + if occurrence > count: + return ( + f"Error: occurrence {occurrence} is out of range; " + f"old_text appears {count} times." + ) + elif line_hint is not None: + nearest = min(matches, key=lambda match: abs(match.line - line_hint)) + distance = abs(nearest.line - line_hint) + if sum(1 for match in matches if abs(match.line - line_hint) == distance) > 1: + return ( + f"Error: line_hint {line_hint} is ambiguous; " + f"old_text appears {count} times." + ) + else: + line_numbers = [match.line for match in matches] + preview = ", ".join(f"line {n}" for n in line_numbers[:3]) + if len(line_numbers) > 3: + preview += ", ..." + location_hint = f" at {preview}" if preview else "" + return ( + f"Warning: old_text appears {count} times{location_hint}. " + "Provide more context, set occurrence to choose one match, " + "or set replace_all=true." + ) + elif occurrence is not None and occurrence > count: return ( - f"Warning: old_text appears {count} times{location_hint}. " - "Provide more context to make it unique, or set replace_all=true." + f"Error: occurrence {occurrence} is out of range; " + f"old_text appears {count} time." ) norm_new = new_text.replace("\r\n", "\n") @@ -760,7 +827,17 @@ class EditFileTool(_FsTool): if fp.suffix.lower() not in self._MARKDOWN_EXTS: norm_new = self._strip_trailing_ws(norm_new) - selected = matches if replace_all else matches[:1] + if replace_all: + selected = matches + elif line_hint is not None: + selected = [min(matches, key=lambda match: abs(match.line - line_hint))] + else: + selected = [matches[occurrence - 1 if occurrence else 0]] + if expected_replacements is not None and len(selected) != expected_replacements: + return ( + f"Error: expected {expected_replacements} replacements but " + f"would make {len(selected)}." + ) new_content = content for match in reversed(selected): replacement = _preserve_quote_style(norm_old, match.text, norm_new) diff --git a/nanobot/agent/tools/notebook.py b/nanobot/agent/tools/notebook.py deleted file mode 100644 index 0980b7c93..000000000 --- a/nanobot/agent/tools/notebook.py +++ /dev/null @@ -1,162 +0,0 @@ -"""NotebookEditTool — edit Jupyter .ipynb notebooks.""" - -from __future__ import annotations - -import json -import uuid -from typing import Any - -from nanobot.agent.tools.base import tool_parameters -from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema -from nanobot.agent.tools.filesystem import _FsTool - - -def _new_cell(source: str, cell_type: str = "code", generate_id: bool = False) -> dict: - cell: dict[str, Any] = { - "cell_type": cell_type, - "source": source, - "metadata": {}, - } - if cell_type == "code": - cell["outputs"] = [] - cell["execution_count"] = None - if generate_id: - cell["id"] = uuid.uuid4().hex[:8] - return cell - - -def _make_empty_notebook() -> dict: - return { - "nbformat": 4, - "nbformat_minor": 5, - "metadata": { - "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, - "language_info": {"name": "python"}, - }, - "cells": [], - } - - -@tool_parameters( - tool_parameters_schema( - path=StringSchema("Path to the .ipynb notebook file"), - cell_index=IntegerSchema(0, description="0-based index of the cell to edit", minimum=0), - new_source=StringSchema("New source content for the cell"), - cell_type=StringSchema( - "Cell type: 'code' or 'markdown' (default: code)", - enum=["code", "markdown"], - ), - edit_mode=StringSchema( - "Mode: 'replace' (default), 'insert' (after target), or 'delete'", - enum=["replace", "insert", "delete"], - ), - required=["path", "cell_index"], - ) -) -class NotebookEditTool(_FsTool): - """Edit Jupyter notebook cells: replace, insert, or delete.""" - _scopes = {"core"} - - _VALID_CELL_TYPES = frozenset({"code", "markdown"}) - _VALID_EDIT_MODES = frozenset({"replace", "insert", "delete"}) - - @property - def name(self) -> str: - return "notebook_edit" - - @property - def description(self) -> str: - return ( - "Edit a Jupyter notebook (.ipynb) cell. " - "Modes: replace (default) replaces cell content, " - "insert adds a new cell after the target index, " - "delete removes the cell at the index. " - "cell_index is 0-based." - ) - - async def execute( - self, - path: str | None = None, - cell_index: int = 0, - new_source: str = "", - cell_type: str = "code", - edit_mode: str = "replace", - **kwargs: Any, - ) -> str: - try: - if not path: - return "Error: path is required" - - if not path.endswith(".ipynb"): - return "Error: notebook_edit only works on .ipynb files. Use edit_file for other files." - - if edit_mode not in self._VALID_EDIT_MODES: - return ( - f"Error: Invalid edit_mode '{edit_mode}'. " - "Use one of: replace, insert, delete." - ) - - if cell_type not in self._VALID_CELL_TYPES: - return ( - f"Error: Invalid cell_type '{cell_type}'. " - "Use one of: code, markdown." - ) - - fp = self._resolve(path) - - # Create new notebook if file doesn't exist and mode is insert - if not fp.exists(): - if edit_mode != "insert": - return f"Error: File not found: {path}" - nb = _make_empty_notebook() - cell = _new_cell(new_source, cell_type, generate_id=True) - nb["cells"].append(cell) - fp.parent.mkdir(parents=True, exist_ok=True) - fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8") - return f"Successfully created {fp} with 1 cell" - - try: - nb = json.loads(fp.read_text(encoding="utf-8")) - except (json.JSONDecodeError, UnicodeDecodeError) as e: - return f"Error: Failed to parse notebook: {e}" - - cells = nb.get("cells", []) - nbformat_minor = nb.get("nbformat_minor", 0) - generate_id = nb.get("nbformat", 0) >= 4 and nbformat_minor >= 5 - - if edit_mode == "delete": - if cell_index < 0 or cell_index >= len(cells): - return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)" - cells.pop(cell_index) - nb["cells"] = cells - fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8") - return f"Successfully deleted cell {cell_index} from {fp}" - - if edit_mode == "insert": - insert_at = min(cell_index + 1, len(cells)) - cell = _new_cell(new_source, cell_type, generate_id=generate_id) - cells.insert(insert_at, cell) - nb["cells"] = cells - fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8") - return f"Successfully inserted cell at index {insert_at} in {fp}" - - # Default: replace - if cell_index < 0 or cell_index >= len(cells): - return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)" - cells[cell_index]["source"] = new_source - if cell_type and cells[cell_index].get("cell_type") != cell_type: - cells[cell_index]["cell_type"] = cell_type - if cell_type == "code": - cells[cell_index].setdefault("outputs", []) - cells[cell_index].setdefault("execution_count", None) - elif "outputs" in cells[cell_index]: - del cells[cell_index]["outputs"] - cells[cell_index].pop("execution_count", None) - nb["cells"] = cells - fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8") - return f"Successfully edited cell {cell_index} in {fp}" - - except PermissionError as e: - return f"Error: {e}" - except Exception as e: - return f"Error editing notebook: {e}" diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index 0252b9746..47d3e9065 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -8,6 +8,7 @@ import re import shutil import sys from contextlib import suppress +from dataclasses import dataclass from pathlib import Path from typing import Any @@ -15,8 +16,17 @@ from loguru import logger from pydantic import Field from nanobot.agent.tools.base import Tool, tool_parameters +from nanobot.agent.tools.exec_session import ( + DEFAULT_MAX_OUTPUT_CHARS, + DEFAULT_YIELD_MS, + DEFAULT_EXEC_SESSION_MANAGER, + MAX_OUTPUT_CHARS, + MAX_YIELD_MS, + clamp_session_int, + format_session_poll, +) from nanobot.agent.tools.sandbox import wrap_command -from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema +from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema from nanobot.config.paths import get_media_dir from nanobot.config.schema import Base @@ -44,10 +54,22 @@ class ExecToolConfig(Base): deny_patterns: list[str] = Field(default_factory=list) +@dataclass(slots=True) +class _PreparedCommand: + command: str + cwd: str + env: dict[str, str] + timeout: int + shell_program: str | None + login: bool + + @tool_parameters( tool_parameters_schema( command=StringSchema("The shell command to execute"), + cmd=StringSchema("Compatibility alias for command"), working_dir=StringSchema("Optional working directory for the command"), + workdir=StringSchema("Compatibility alias for working_dir"), timeout=IntegerSchema( 60, description=( @@ -57,7 +79,44 @@ class ExecToolConfig(Base): minimum=1, maximum=600, ), - required=["command"], + shell=StringSchema( + "Optional shell binary to launch. On Unix, supports sh, bash, or zsh.", + nullable=True, + ), + login=BooleanSchema( + description="Whether to run bash/zsh with login shell semantics (default true).", + default=True, + nullable=True, + ), + yield_time_ms=IntegerSchema( + description=( + "Optional milliseconds to wait before returning output. " + "When set, a still-running command returns a session_id that " + "can be polled or written to with write_stdin. Omit this field " + "to keep one-shot exec behavior." + ), + minimum=0, + maximum=MAX_YIELD_MS, + nullable=True, + ), + max_output_chars=IntegerSchema( + description=( + "Maximum output characters to return when yield_time_ms is used " + "(default 10000, max 50000)." + ), + minimum=1000, + maximum=MAX_OUTPUT_CHARS, + nullable=True, + ), + max_output_tokens=IntegerSchema( + description=( + "Compatibility alias for max_output_chars. The current runtime " + "uses a character budget." + ), + minimum=1000, + maximum=MAX_OUTPUT_CHARS, + nullable=True, + ), ) ) class ExecTool(Tool): @@ -98,6 +157,7 @@ class ExecTool(Tool): sandbox: str = "", path_append: str = "", allowed_env_keys: list[str] | None = None, + session_manager: Any | None = None, ): self.timeout = timeout self.working_dir = working_dir @@ -125,6 +185,7 @@ class ExecTool(Tool): self.restrict_to_workspace = restrict_to_workspace self.path_append = path_append self.allowed_env_keys = allowed_env_keys or [] + self._session_manager = session_manager or DEFAULT_EXEC_SESSION_MANAGER @property def name(self) -> str: @@ -153,7 +214,10 @@ class ExecTool(Tool): "Prefer read_file/write_file/edit_file over cat/echo/sed, " "and grep/glob over shell find/grep. " "Use -y or --yes flags to avoid interactive prompts. " - "Output is truncated at 10 000 chars; timeout defaults to 60s." + "For long-running or interactive commands, pass yield_time_ms; " + "if the command keeps running, exec returns a session_id that can " + "be polled or written to with write_stdin. Output is truncated at " + "10 000 chars; timeout defaults to 60s." ) @property @@ -161,9 +225,111 @@ class ExecTool(Tool): return True async def execute( - self, command: str, working_dir: str | None = None, - timeout: int | None = None, **kwargs: Any, + self, command: str | None = None, cmd: str | None = None, + working_dir: str | None = None, workdir: str | None = None, + timeout: int | None = None, shell: str | None = None, + login: bool | None = None, yield_time_ms: int | None = None, + max_output_chars: int | None = None, + max_output_tokens: int | None = None, + **kwargs: Any, ) -> str: + command = command or cmd + working_dir = working_dir or workdir + if not command: + return "Error: Missing command. Provide command or cmd." + if max_output_chars is None: + max_output_chars = max_output_tokens + + prepared = self._prepare_command(command, working_dir, timeout, shell, login) + if isinstance(prepared, str): + return prepared + + if yield_time_ms is not None: + return await self._execute_session(prepared, yield_time_ms, max_output_chars) + + try: + process = await self._spawn( + prepared.command, + prepared.cwd, + prepared.env, + prepared.shell_program, + prepared.login, + ) + + try: + stdout, stderr = await asyncio.wait_for( + process.communicate(), + timeout=prepared.timeout, + ) + except asyncio.TimeoutError: + await self._kill_process(process) + return f"Error: Command timed out after {prepared.timeout} seconds" + except asyncio.CancelledError: + await self._kill_process(process) + raise + + output_parts = [] + + if stdout: + output_parts.append(stdout.decode("utf-8", errors="replace")) + + if stderr: + stderr_text = stderr.decode("utf-8", errors="replace") + if stderr_text.strip(): + output_parts.append(f"STDERR:\n{stderr_text}") + + output_parts.append(f"\nExit code: {process.returncode}") + + result = "\n".join(output_parts) if output_parts else "(no output)" + + max_len = clamp_session_int(max_output_chars, self._MAX_OUTPUT, 1000, MAX_OUTPUT_CHARS) + if len(result) > max_len: + half = max_len // 2 + result = ( + result[:half] + + f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n" + + result[-half:] + ) + + return result + + except Exception as e: + return f"Error executing command: {str(e)}" + + async def _execute_session( + self, + prepared: _PreparedCommand, + yield_time_ms: int | None, + max_output_chars: int | None, + ) -> str: + try: + session_id, poll = await self._session_manager.start( + command=prepared.command, + cwd=prepared.cwd, + env=prepared.env, + timeout=prepared.timeout, + shell_program=prepared.shell_program, + login=prepared.login, + yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS), + max_output_chars=clamp_session_int( + max_output_chars, + DEFAULT_MAX_OUTPUT_CHARS, + 1000, + MAX_OUTPUT_CHARS, + ), + ) + return format_session_poll(session_id, poll) + except Exception as exc: + return f"Error executing command: {exc}" + + def _prepare_command( + self, + command: str, + working_dir: str | None = None, + timeout: int | None = None, + shell: str | None = None, + login: bool | None = None, + ) -> _PreparedCommand | str: cwd = working_dir or self.working_dir or os.getcwd() # Prevent an LLM-supplied working_dir from escaping the configured @@ -211,52 +377,24 @@ class ExecTool(Tool): env["NANOBOT_PATH_APPEND"] = self.path_append command = f'export PATH="$PATH{os.pathsep}$NANOBOT_PATH_APPEND"; {command}' - try: - process = await self._spawn(command, cwd, env) + shell_program, shell_error = self._resolve_shell(shell) + if shell_error: + return shell_error - try: - stdout, stderr = await asyncio.wait_for( - process.communicate(), - timeout=effective_timeout, - ) - except asyncio.TimeoutError: - await self._kill_process(process) - return f"Error: Command timed out after {effective_timeout} seconds" - except asyncio.CancelledError: - await self._kill_process(process) - raise - - output_parts = [] - - if stdout: - output_parts.append(stdout.decode("utf-8", errors="replace")) - - if stderr: - stderr_text = stderr.decode("utf-8", errors="replace") - if stderr_text.strip(): - output_parts.append(f"STDERR:\n{stderr_text}") - - output_parts.append(f"\nExit code: {process.returncode}") - - result = "\n".join(output_parts) if output_parts else "(no output)" - - max_len = self._MAX_OUTPUT - if len(result) > max_len: - half = max_len // 2 - result = ( - result[:half] - + f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n" - + result[-half:] - ) - - return result - - except Exception as e: - return f"Error executing command: {str(e)}" + return _PreparedCommand( + command=command, + cwd=cwd, + env=env, + timeout=effective_timeout, + shell_program=shell_program, + login=True if login is None else login, + ) @staticmethod async def _spawn( command: str, cwd: str, env: dict[str, str], + shell_program: str | None = None, + login: bool = True, ) -> asyncio.subprocess.Process: """Launch *command* in a platform-appropriate shell.""" if _IS_WINDOWS: @@ -272,9 +410,13 @@ class ExecTool(Tool): cwd=cwd, env=env, ) - bash = shutil.which("bash") or "/bin/bash" + shell_program = shell_program or shutil.which("bash") or "/bin/bash" + args = [shell_program] + if login and Path(shell_program).name in {"bash", "zsh"}: + args.append("-l") + args.extend(["-c", command]) return await asyncio.create_subprocess_exec( - bash, "-l", "-c", command, + *args, stdin=asyncio.subprocess.DEVNULL, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, @@ -282,6 +424,31 @@ class ExecTool(Tool): env=env, ) + @staticmethod + def _resolve_shell(shell: str | None) -> tuple[str | None, str | None]: + if not shell: + return None, None + if _IS_WINDOWS: + return None, "Error: shell parameter is not supported on Windows" + if "\0" in shell or "\n" in shell or "\r" in shell: + return None, "Error: shell contains invalid characters" + allowed = {"sh", "bash", "zsh"} + path = Path(shell).expanduser() + if path.is_absolute(): + if path.name not in allowed: + return None, f"Error: unsupported shell {shell!r}. Allowed: bash, sh, zsh" + if not path.is_file() or not os.access(path, os.X_OK): + return None, f"Error: shell is not executable: {shell}" + return str(path), None + if "/" in shell or "\\" in shell: + return None, "Error: shell must be a shell name or absolute path" + if shell not in allowed: + return None, f"Error: unsupported shell {shell!r}. Allowed: bash, sh, zsh" + resolved = shutil.which(shell) + if not resolved: + return None, f"Error: shell not found: {shell}" + return resolved, None + @staticmethod async def _kill_process(process: asyncio.subprocess.Process) -> None: """Kill a subprocess and reap it to prevent zombies.""" diff --git a/nanobot/utils/file_edit_events.py b/nanobot/utils/file_edit_events.py index b5d2f6d73..056041f4b 100644 --- a/nanobot/utils/file_edit_events.py +++ b/nanobot/utils/file_edit_events.py @@ -3,7 +3,6 @@ from __future__ import annotations import difflib -import json import re import time from dataclasses import dataclass, field @@ -11,7 +10,7 @@ from pathlib import Path from typing import Any, Awaitable, Callable -TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file", "notebook_edit"}) +TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file"}) _MAX_SNAPSHOT_BYTES = 2 * 1024 * 1024 _LIVE_EMIT_INTERVAL_S = 0.18 _LIVE_EMIT_LINE_STEP = 24 @@ -704,77 +703,4 @@ def _predict_after_text( return before_text.replace(old_text, new_text) return before_text.replace(old_text, new_text, 1) return None - if tool_name == "notebook_edit": - return _predict_notebook_after_text(params, before_text) return None - - -def _predict_notebook_after_text(params: dict[str, Any], before_text: str) -> str | None: - try: - nb = json.loads(before_text) if before_text.strip() else _empty_notebook() - except Exception: - return None - cells = nb.get("cells") - if not isinstance(cells, list): - return None - try: - cell_index = int(params.get("cell_index", 0)) - except (TypeError, ValueError): - return None - new_source = params.get("new_source") - source = new_source if isinstance(new_source, str) else "" - cell_type = ( - params.get("cell_type") if params.get("cell_type") in ("code", "markdown") else "code" - ) - mode = ( - params.get("edit_mode") - if params.get("edit_mode") in ("replace", "insert", "delete") - else "replace" - ) - if mode == "delete": - if 0 <= cell_index < len(cells): - cells.pop(cell_index) - else: - return None - elif mode == "insert": - insert_at = min(max(cell_index + 1, 0), len(cells)) - cells.insert(insert_at, _new_notebook_cell(source, str(cell_type))) - else: - if not (0 <= cell_index < len(cells)): - return None - cell = cells[cell_index] - if not isinstance(cell, dict): - return None - cell["source"] = source - cell["cell_type"] = cell_type - if cell_type == "code": - cell.setdefault("outputs", []) - cell.setdefault("execution_count", None) - else: - cell.pop("outputs", None) - cell.pop("execution_count", None) - nb["cells"] = cells - try: - return json.dumps(nb, indent=1, ensure_ascii=False) - except Exception: - return None - - -def _empty_notebook() -> dict[str, Any]: - return { - "nbformat": 4, - "nbformat_minor": 5, - "metadata": { - "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, - "language_info": {"name": "python"}, - }, - "cells": [], - } - - -def _new_notebook_cell(source: str, cell_type: str) -> dict[str, Any]: - cell: dict[str, Any] = {"cell_type": cell_type, "source": source, "metadata": {}} - if cell_type == "code": - cell["outputs"] = [] - cell["execution_count"] = None - return cell diff --git a/tests/tools/test_apply_patch_tool.py b/tests/tools/test_apply_patch_tool.py new file mode 100644 index 000000000..ea98794b3 --- /dev/null +++ b/tests/tools/test_apply_patch_tool.py @@ -0,0 +1,238 @@ +from __future__ import annotations + +import asyncio + +from nanobot.agent.tools.apply_patch import ApplyPatchTool + + +def test_apply_patch_adds_file(tmp_path): + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Add File: hello.txt ++Hello ++world +*** End Patch +""" + )) + + assert "Patch applied" in result + assert (tmp_path / "hello.txt").read_text() == "Hello\nworld\n" + + +def test_apply_patch_updates_multiple_hunks(tmp_path): + target = tmp_path / "multi.txt" + target.write_text("line1\nline2\nline3\nline4\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Update File: multi.txt +@@ +-line2 ++changed2 +@@ +-line4 ++changed4 +*** End Patch +""" + )) + + assert "update multi.txt" in result + assert target.read_text() == "line1\nchanged2\nline3\nchanged4\n" + + +def test_apply_patch_ignores_standard_no_newline_marker(tmp_path): + target = tmp_path / "plain.txt" + target.write_text("before") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Update File: plain.txt +@@ -1,1 +1,1 @@ +-before +\\ No newline at end of file ++after +\\ No newline at end of file +*** End Patch +""" + )) + + assert "update plain.txt" in result + assert target.read_text() == "after\n" + + +def test_apply_patch_rejects_empty_hunk(tmp_path): + target = tmp_path / "plain.txt" + target.write_text("before\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Update File: plain.txt +@@ +*** End Patch +""" + )) + + assert "hunk is empty" in result + assert target.read_text() == "before\n" + + +def test_apply_patch_uses_unified_diff_line_hint(tmp_path): + target = tmp_path / "repeated.txt" + target.write_text("target\nmiddle\ntarget\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Update File: repeated.txt +@@ -3,1 +3,1 @@ +-target ++changed +*** End Patch +""" + )) + + assert "update repeated.txt" in result + assert target.read_text() == "target\nmiddle\nchanged\n" + + +def test_apply_patch_line_hint_does_not_fallback_to_earlier_match(tmp_path): + target = tmp_path / "repeated.txt" + target.write_text("target\nmiddle\nother\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Update File: repeated.txt +@@ -3,1 +3,1 @@ +-target ++changed +*** End Patch +""" + )) + + assert "hunk does not match repeated.txt" in result + assert target.read_text() == "target\nmiddle\nother\n" + + +def test_apply_patch_mismatch_reports_best_match(tmp_path): + target = tmp_path / "near.txt" + target.write_text("alpha\nbeta\ngamma\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Update File: near.txt +@@ -2,1 +2,1 @@ +-betx ++delta +*** End Patch +""" + )) + + assert "hunk does not match near.txt" in result + assert "Best match" in result + assert "line 2" in result + assert target.read_text() == "alpha\nbeta\ngamma\n" + + +def test_apply_patch_moves_and_updates_file(tmp_path): + source = tmp_path / "old" / "name.txt" + source.parent.mkdir() + source.write_text("old content\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Update File: old/name.txt +*** Move to: renamed/dir/name.txt +@@ +-old content ++new content +*** End Patch +""" + )) + + assert "move old/name.txt -> renamed/dir/name.txt" in result + assert not source.exists() + assert (tmp_path / "renamed" / "dir" / "name.txt").read_text() == "new content\n" + + +def test_apply_patch_deletes_file(tmp_path): + target = tmp_path / "obsolete.txt" + target.write_text("remove me\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Delete File: obsolete.txt +*** End Patch +""" + )) + + assert "delete obsolete.txt" in result + assert not target.exists() + + +def test_apply_patch_rejects_absolute_and_parent_paths(tmp_path): + tool = ApplyPatchTool(workspace=tmp_path) + + absolute = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Add File: /tmp/owned.txt ++nope +*** End Patch +""" + )) + parent = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Add File: ../owned.txt ++nope +*** End Patch +""" + )) + + assert "must be relative" in absolute + assert "must not contain '..'" in parent + assert not (tmp_path.parent / "owned.txt").exists() + + +def test_apply_patch_does_not_overwrite_existing_file_with_add(tmp_path): + target = tmp_path / "existing.txt" + target.write_text("keep me\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Add File: existing.txt ++replace me +*** End Patch +""" + )) + + assert "file to add already exists" in result + assert target.read_text() == "keep me\n" + + +def test_apply_patch_rolls_back_when_late_operation_fails(tmp_path): + first = tmp_path / "first.txt" + first.write_text("before\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Update File: first.txt +@@ +-before ++after +*** Delete File: missing.txt +*** End Patch +""" + )) + + assert "file to delete does not exist" in result + assert first.read_text() == "before\n" diff --git a/tests/tools/test_edit_enhancements.py b/tests/tools/test_edit_enhancements.py index 1f22c963b..7202fc37b 100644 --- a/tests/tools/test_edit_enhancements.py +++ b/tests/tools/test_edit_enhancements.py @@ -1,5 +1,5 @@ """Tests for EditFileTool enhancements: read-before-edit tracking, path suggestions, -.ipynb detection, and create-file semantics.""" +notebook JSON editing, and create-file semantics.""" import pytest @@ -108,22 +108,27 @@ class TestEditCreateFile: # --------------------------------------------------------------------------- -# .ipynb detection +# .ipynb editing # --------------------------------------------------------------------------- -class TestEditIpynbDetection: - """edit_file should refuse .ipynb and suggest notebook_edit.""" +class TestEditIpynbFiles: + """edit_file edits notebooks as normal JSON files.""" @pytest.fixture() def tool(self, tmp_path): return EditFileTool(workspace=tmp_path) @pytest.mark.asyncio - async def test_ipynb_rejected_with_suggestion(self, tool, tmp_path): + async def test_ipynb_can_be_edited_as_json(self, tool, tmp_path): f = tmp_path / "analysis.ipynb" f.write_text('{"cells": []}', encoding="utf-8") - result = await tool.execute(path=str(f), old_text="x", new_text="y") - assert "notebook" in result.lower() + result = await tool.execute( + path=str(f), + old_text='"cells": []', + new_text='"cells": [{"cell_type": "markdown", "source": "hi"}]', + ) + assert "Successfully edited" in result + assert '"source": "hi"' in f.read_text(encoding="utf-8") # --------------------------------------------------------------------------- diff --git a/tests/tools/test_exec_session_tools.py b/tests/tools/test_exec_session_tools.py new file mode 100644 index 000000000..945473926 --- /dev/null +++ b/tests/tools/test_exec_session_tools.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +import asyncio +import re +import shlex +import subprocess +import sys + +from nanobot.agent.tools.shell import ExecTool +from nanobot.agent.tools.exec_session import ExecSessionManager, WriteStdinTool + + +def _python_command(code: str) -> str: + if sys.platform == "win32": + return f"{subprocess.list2cmdline([sys.executable])} -u -c {subprocess.list2cmdline([code])}" + return f"{shlex.quote(sys.executable)} -u -c {shlex.quote(code)}" + + +def _session_id(output: str) -> str: + match = re.search(r"session_id:\s*([0-9a-f]+)", output) + assert match, output + return match.group(1) + + +def test_exec_keeps_one_shot_behavior_without_yield_time_ms(tmp_path): + async def run() -> str: + tool = ExecTool(working_dir=str(tmp_path), timeout=5) + return await tool.execute(command="echo hello") + + result = asyncio.run(run()) + + assert "hello" in result + assert "Exit code: 0" in result + assert "session_id:" not in result + + +def test_exec_accepts_command_aliases(tmp_path): + async def run() -> str: + tool = ExecTool(working_dir="/") + return await tool.execute(cmd="pwd", workdir=str(tmp_path)) + + result = asyncio.run(run()) + + assert str(tmp_path) in result + assert "Exit code: 0" in result + + +def test_exec_returns_completed_session_output_when_yield_time_ms_is_used(tmp_path): + async def run() -> str: + manager = ExecSessionManager() + tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + + result = await tool.execute(command="echo hello", yield_time_ms=1000) + if "session_id:" in result: + sid = _session_id(result) + result += "\n" + await stdin_tool.execute( + session_id=sid, + chars="", + yield_time_ms=1000, + ) + return result + + result = asyncio.run(run()) + + assert "hello" in result + assert "Exit code: 0" in result + assert "session_id:" not in result + + +def test_exec_session_accepts_max_output_tokens_alias(tmp_path): + async def run() -> str: + manager = ExecSessionManager() + tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + command = _python_command("print('A' * 2000)") + return await tool.execute( + command=command, + yield_time_ms=1000, + max_output_tokens=1000, + ) + + result = asyncio.run(run()) + + assert "chars truncated" in result + assert "Exit code: 0" in result + + +def test_exec_one_shot_accepts_max_output_tokens_alias(tmp_path): + async def run() -> str: + tool = ExecTool(working_dir=str(tmp_path), timeout=5) + command = _python_command("print('A' * 2000)") + return await tool.execute(command=command, max_output_tokens=1000) + + result = asyncio.run(run()) + + assert "chars truncated" in result + assert "Exit code: 0" in result + + +def test_exec_accepts_supported_shell_parameter(tmp_path): + async def run() -> str: + tool = ExecTool(working_dir=str(tmp_path), timeout=5) + return await tool.execute(command="echo shell-ok", shell="sh", login=False) + + if sys.platform == "win32": + return + result = asyncio.run(run()) + + assert "shell-ok" in result + assert "Exit code: 0" in result + + +def test_exec_rejects_unsupported_shell(tmp_path): + async def run() -> str: + tool = ExecTool(working_dir=str(tmp_path), timeout=5) + return await tool.execute(command="echo no", shell="python") + + if sys.platform == "win32": + return + result = asyncio.run(run()) + + assert "unsupported shell" in result + + +def test_exec_can_continue_with_stdin(tmp_path): + async def run() -> tuple[str, str]: + manager = ExecSessionManager() + exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + command = _python_command( + "import sys; print('ready', flush=True); " + "line=sys.stdin.readline(); print('got:' + line.strip(), flush=True)" + ) + + initial = await exec_tool.execute(command=command, yield_time_ms=500) + sid = _session_id(initial) + result = await stdin_tool.execute(session_id=sid, chars="ping\n", yield_time_ms=1000) + return initial, result + + initial, result = asyncio.run(run()) + assert "ready" in initial + assert "Process running" in initial + assert "got:ping" in result + assert "Exit code: 0" in result + + +def test_write_stdin_can_close_stdin(tmp_path): + async def run() -> tuple[str, str]: + manager = ExecSessionManager() + exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + command = _python_command( + "import sys; print('ready', flush=True); " + "data=sys.stdin.read(); print('got:' + data, flush=True)" + ) + + initial = await exec_tool.execute(command=command, yield_time_ms=500) + sid = _session_id(initial) + result = await stdin_tool.execute( + session_id=sid, + chars="payload", + close_stdin=True, + yield_time_ms=1000, + ) + return initial, result + + initial, result = asyncio.run(run()) + assert "ready" in initial + assert "got:payload" in result + assert "Stdin closed." in result + assert "Exit code: 0" in result + + +def test_write_stdin_can_terminate_session(tmp_path): + async def run() -> tuple[str, str]: + manager = ExecSessionManager() + exec_tool = ExecTool(working_dir=str(tmp_path), timeout=30, session_manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + command = _python_command( + "import time; print('ready', flush=True); time.sleep(30)" + ) + + initial = await exec_tool.execute(command=command, yield_time_ms=500) + sid = _session_id(initial) + result = await stdin_tool.execute( + session_id=sid, + terminate=True, + yield_time_ms=0, + ) + return initial, result + + initial, result = asyncio.run(run()) + assert "ready" in initial + assert "Session terminated." in result + assert "Exit code:" in result + + +def test_write_stdin_accepts_max_output_tokens_alias(tmp_path): + async def run() -> tuple[str, str, str]: + manager = ExecSessionManager() + exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + command = _python_command( + "import time; print('A' * 2000, flush=True); time.sleep(5)" + ) + + initial = await exec_tool.execute(command=command, yield_time_ms=0) + sid = _session_id(initial) + poll = await stdin_tool.execute( + session_id=sid, + yield_time_ms=500, + max_output_tokens=1000, + ) + cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0) + return initial, poll, cleanup + + initial, poll, cleanup = asyncio.run(run()) + assert "Process running" in initial + assert "chars truncated" in poll + assert "Session terminated." in cleanup + + +def test_exec_session_mode_reuses_exec_safety_guard(tmp_path): + manager = ExecSessionManager() + tool = ExecTool( + working_dir=str(tmp_path), + deny_patterns=[r"echo\s+blocked"], + session_manager=manager, + ) + + result = asyncio.run(tool.execute(command="echo blocked", yield_time_ms=0)) + + assert "blocked by deny pattern" in result + + +def test_write_stdin_reports_missing_session(tmp_path): + manager = ExecSessionManager() + tool = WriteStdinTool(manager=manager) + + result = asyncio.run(tool.execute(session_id="missing", chars="")) + + assert "exec session not found" in result diff --git a/tests/tools/test_file_edit_coding_enhancements.py b/tests/tools/test_file_edit_coding_enhancements.py new file mode 100644 index 000000000..d361d88ae --- /dev/null +++ b/tests/tools/test_file_edit_coding_enhancements.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +import asyncio + +from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool + + +def test_read_file_force_bypasses_dedup(tmp_path): + target = tmp_path / "data.txt" + target.write_text("alpha\n") + tool = ReadFileTool(workspace=tmp_path) + + first = asyncio.run(tool.execute(path=str(target))) + second = asyncio.run(tool.execute(path=str(target))) + forced = asyncio.run(tool.execute(path=str(target), force=True)) + + assert "alpha" in first + assert "unchanged" in second.lower() + assert "alpha" in forced + assert "unchanged" not in forced.lower() + + +def test_edit_file_can_select_occurrence(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("one\nsame\ntwo\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + occurrence=2, + )) + + assert "Successfully edited" in result + assert target.read_text() == "one\nsame\ntwo\nchanged\n" + + +def test_edit_file_expected_replacements_guards_replace_all(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + replace_all=True, + expected_replacements=1, + )) + + assert "expected 1 replacements but would make 2" in result + assert target.read_text() == "same\nsame\n" + + +def test_edit_file_expected_replacements_allows_replace_all_when_count_matches(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + replace_all=True, + expected_replacements=2, + )) + + assert "Successfully edited" in result + assert target.read_text() == "changed\nchanged\n" + + +def test_edit_file_can_select_nearest_line_hint(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("one\nsame\ntwo\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + line_hint=4, + )) + + assert "Successfully edited" in result + assert target.read_text() == "one\nsame\ntwo\nchanged\n" + + +def test_edit_file_can_edit_ipynb_as_json(tmp_path): + target = tmp_path / "analysis.ipynb" + target.write_text('{"cells": []}') + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text='"cells": []', + new_text='"cells": [{"cell_type": "markdown", "source": "hi"}]', + )) + + assert "Successfully edited" in result + assert '"source": "hi"' in target.read_text() + + +def test_edit_file_multiple_match_hint_mentions_occurrence(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + )) + + assert "old_text appears 2 times" in result + assert "occurrence" in result + assert target.read_text() == "same\nsame\n" + + +def test_edit_file_rejects_ambiguous_line_hint(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\nmiddle\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + line_hint=2, + )) + + assert "line_hint 2 is ambiguous" in result + assert target.read_text() == "same\nmiddle\nsame\n" + + +def test_edit_file_rejects_occurrence_with_replace_all(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + occurrence=1, + replace_all=True, + )) + + assert "occurrence cannot be used with replace_all" in result + assert target.read_text() == "same\nsame\n" + + +def test_edit_file_rejects_line_hint_with_replace_all(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + line_hint=1, + replace_all=True, + )) + + assert "line_hint cannot be used with replace_all" in result + assert target.read_text() == "same\nsame\n" + + +def test_edit_file_rejects_line_hint_with_occurrence(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\nsame\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + occurrence=1, + line_hint=1, + )) + + assert "line_hint cannot be used with occurrence" in result + assert target.read_text() == "same\nsame\n" + + +def test_edit_file_rejects_zero_occurrence(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + occurrence=0, + )) + + assert "occurrence must be >= 1" in result + assert target.read_text() == "same\n" + + +def test_edit_file_rejects_zero_line_hint(tmp_path): + target = tmp_path / "duplicate.txt" + target.write_text("same\n") + tool = EditFileTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + path=str(target), + old_text="same", + new_text="changed", + line_hint=0, + )) + + assert "line_hint must be >= 1" in result + assert target.read_text() == "same\n" diff --git a/tests/tools/test_notebook_tool.py b/tests/tools/test_notebook_tool.py deleted file mode 100644 index 232f13c4b..000000000 --- a/tests/tools/test_notebook_tool.py +++ /dev/null @@ -1,147 +0,0 @@ -"""Tests for NotebookEditTool — Jupyter .ipynb editing.""" - -import json - -import pytest - -from nanobot.agent.tools.notebook import NotebookEditTool - - -def _make_notebook(cells: list[dict] | None = None, nbformat: int = 4, nbformat_minor: int = 5) -> dict: - """Build a minimal valid .ipynb structure.""" - return { - "nbformat": nbformat, - "nbformat_minor": nbformat_minor, - "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}}, - "cells": cells or [], - } - - -def _code_cell(source: str, cell_id: str | None = None) -> dict: - cell = {"cell_type": "code", "source": source, "metadata": {}, "outputs": [], "execution_count": None} - if cell_id: - cell["id"] = cell_id - return cell - - -def _md_cell(source: str, cell_id: str | None = None) -> dict: - cell = {"cell_type": "markdown", "source": source, "metadata": {}} - if cell_id: - cell["id"] = cell_id - return cell - - -def _write_nb(tmp_path, name: str, nb: dict) -> str: - p = tmp_path / name - p.write_text(json.dumps(nb), encoding="utf-8") - return str(p) - - -class TestNotebookEdit: - - @pytest.fixture() - def tool(self, tmp_path): - return NotebookEditTool(workspace=tmp_path) - - @pytest.mark.asyncio - async def test_replace_cell_content(self, tool, tmp_path): - nb = _make_notebook([_code_cell("print('hello')"), _code_cell("x = 1")]) - path = _write_nb(tmp_path, "test.ipynb", nb) - result = await tool.execute(path=path, cell_index=0, new_source="print('world')") - assert "Successfully" in result - saved = json.loads((tmp_path / "test.ipynb").read_text()) - assert saved["cells"][0]["source"] == "print('world')" - assert saved["cells"][1]["source"] == "x = 1" - - @pytest.mark.asyncio - async def test_insert_cell_after_target(self, tool, tmp_path): - nb = _make_notebook([_code_cell("cell 0"), _code_cell("cell 1")]) - path = _write_nb(tmp_path, "test.ipynb", nb) - result = await tool.execute(path=path, cell_index=0, new_source="inserted", edit_mode="insert") - assert "Successfully" in result - saved = json.loads((tmp_path / "test.ipynb").read_text()) - assert len(saved["cells"]) == 3 - assert saved["cells"][0]["source"] == "cell 0" - assert saved["cells"][1]["source"] == "inserted" - assert saved["cells"][2]["source"] == "cell 1" - - @pytest.mark.asyncio - async def test_delete_cell(self, tool, tmp_path): - nb = _make_notebook([_code_cell("A"), _code_cell("B"), _code_cell("C")]) - path = _write_nb(tmp_path, "test.ipynb", nb) - result = await tool.execute(path=path, cell_index=1, edit_mode="delete") - assert "Successfully" in result - saved = json.loads((tmp_path / "test.ipynb").read_text()) - assert len(saved["cells"]) == 2 - assert saved["cells"][0]["source"] == "A" - assert saved["cells"][1]["source"] == "C" - - @pytest.mark.asyncio - async def test_create_new_notebook_from_scratch(self, tool, tmp_path): - path = str(tmp_path / "new.ipynb") - result = await tool.execute(path=path, cell_index=0, new_source="# Hello", edit_mode="insert", cell_type="markdown") - assert "Successfully" in result or "created" in result.lower() - saved = json.loads((tmp_path / "new.ipynb").read_text()) - assert saved["nbformat"] == 4 - assert len(saved["cells"]) == 1 - assert saved["cells"][0]["cell_type"] == "markdown" - assert saved["cells"][0]["source"] == "# Hello" - - @pytest.mark.asyncio - async def test_invalid_cell_index_error(self, tool, tmp_path): - nb = _make_notebook([_code_cell("only cell")]) - path = _write_nb(tmp_path, "test.ipynb", nb) - result = await tool.execute(path=path, cell_index=5, new_source="x") - assert "Error" in result - - @pytest.mark.asyncio - async def test_non_ipynb_rejected(self, tool, tmp_path): - f = tmp_path / "script.py" - f.write_text("pass") - result = await tool.execute(path=str(f), cell_index=0, new_source="x") - assert "Error" in result - assert ".ipynb" in result - - @pytest.mark.asyncio - async def test_preserves_metadata_and_outputs(self, tool, tmp_path): - cell = _code_cell("old") - cell["outputs"] = [{"output_type": "stream", "text": "hello\n"}] - cell["execution_count"] = 42 - nb = _make_notebook([cell]) - path = _write_nb(tmp_path, "test.ipynb", nb) - await tool.execute(path=path, cell_index=0, new_source="new") - saved = json.loads((tmp_path / "test.ipynb").read_text()) - assert saved["metadata"]["kernelspec"]["language"] == "python" - - @pytest.mark.asyncio - async def test_nbformat_45_generates_cell_id(self, tool, tmp_path): - nb = _make_notebook([], nbformat_minor=5) - path = _write_nb(tmp_path, "test.ipynb", nb) - await tool.execute(path=path, cell_index=0, new_source="x = 1", edit_mode="insert") - saved = json.loads((tmp_path / "test.ipynb").read_text()) - assert "id" in saved["cells"][0] - assert len(saved["cells"][0]["id"]) > 0 - - @pytest.mark.asyncio - async def test_insert_with_cell_type_markdown(self, tool, tmp_path): - nb = _make_notebook([_code_cell("code")]) - path = _write_nb(tmp_path, "test.ipynb", nb) - await tool.execute(path=path, cell_index=0, new_source="# Title", edit_mode="insert", cell_type="markdown") - saved = json.loads((tmp_path / "test.ipynb").read_text()) - assert saved["cells"][1]["cell_type"] == "markdown" - - @pytest.mark.asyncio - async def test_invalid_edit_mode_rejected(self, tool, tmp_path): - nb = _make_notebook([_code_cell("code")]) - path = _write_nb(tmp_path, "test.ipynb", nb) - result = await tool.execute(path=path, cell_index=0, new_source="x", edit_mode="replcae") - assert "Error" in result - assert "edit_mode" in result - - @pytest.mark.asyncio - async def test_invalid_cell_type_rejected(self, tool, tmp_path): - nb = _make_notebook([_code_cell("code")]) - path = _write_nb(tmp_path, "test.ipynb", nb) - result = await tool.execute(path=path, cell_index=0, new_source="x", cell_type="raw") - assert "Error" in result - assert "cell_type" in result diff --git a/tests/tools/test_tool_loader.py b/tests/tools/test_tool_loader.py index 54b4d92d5..2dfb25cb7 100644 --- a/tests/tools/test_tool_loader.py +++ b/tests/tools/test_tool_loader.py @@ -89,9 +89,11 @@ def test_discover_finds_concrete_tools(): loader = ToolLoader() discovered = loader.discover() class_names = {cls.__name__ for cls in discovered} + assert "ApplyPatchTool" in class_names assert "ExecTool" in class_names assert "MessageTool" in class_names assert "SpawnTool" in class_names + assert "WriteStdinTool" in class_names def test_discover_excludes_abstract_and_mcp(): @@ -406,7 +408,7 @@ def test_loader_registers_same_tools_as_old_hardcoded(): expected = { "read_file", "write_file", "edit_file", "list_dir", - "grep", "notebook_edit", "exec", "web_search", "web_fetch", + "grep", "exec", "web_search", "web_fetch", "message", "spawn", "cron", } actual = set(registered)