From 651aeae656e33b029adba1eeab5af1aee05d4df4 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Fri, 10 Apr 2026 15:44:50 +0000 Subject: [PATCH] improve file editing and add notebook tool Enhance file tools with read tracking, PDF support, safer path handling, smarter edit matching/diagnostics, and introduce notebook_edit with tests. --- nanobot/agent/loop.py | 2 + nanobot/agent/tools/file_state.py | 105 ++++++ nanobot/agent/tools/filesystem.py | 503 ++++++++++++++++++++++++-- nanobot/agent/tools/notebook.py | 162 +++++++++ pyproject.toml | 4 + tests/tools/test_edit_advanced.py | 423 ++++++++++++++++++++++ tests/tools/test_edit_enhancements.py | 152 ++++++++ tests/tools/test_notebook_tool.py | 147 ++++++++ tests/tools/test_read_enhancements.py | 180 +++++++++ 9 files changed, 1638 insertions(+), 40 deletions(-) create mode 100644 nanobot/agent/tools/file_state.py create mode 100644 nanobot/agent/tools/notebook.py create mode 100644 tests/tools/test_edit_advanced.py create mode 100644 tests/tools/test_edit_enhancements.py create mode 100644 tests/tools/test_notebook_tool.py create mode 100644 tests/tools/test_read_enhancements.py diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index bc83cc77c..56d79f3f9 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -22,6 +22,7 @@ from nanobot.agent.tools.cron import CronTool from nanobot.agent.skills import BUILTIN_SKILLS_DIR from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool from nanobot.agent.tools.message import MessageTool +from nanobot.agent.tools.notebook import NotebookEditTool from nanobot.agent.tools.registry import ToolRegistry from nanobot.agent.tools.search import GlobTool, GrepTool from nanobot.agent.tools.shell import ExecTool @@ -235,6 +236,7 @@ class AgentLoop: self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir)) for cls in (GlobTool, GrepTool): self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir)) 
+ self.tools.register(NotebookEditTool(workspace=self.workspace, allowed_dir=allowed_dir)) if self.exec_config.enable: self.tools.register(ExecTool( working_dir=str(self.workspace), diff --git a/nanobot/agent/tools/file_state.py b/nanobot/agent/tools/file_state.py new file mode 100644 index 000000000..81b1d4485 --- /dev/null +++ b/nanobot/agent/tools/file_state.py @@ -0,0 +1,105 @@ +"""Track file-read state for read-before-edit warnings and read deduplication.""" + +from __future__ import annotations + +import hashlib +import os +from dataclasses import dataclass +from pathlib import Path + + +@dataclass(slots=True) +class ReadState: + mtime: float + offset: int + limit: int | None + content_hash: str | None + can_dedup: bool + + +_state: dict[str, ReadState] = {} + + +def _hash_file(p: str) -> str | None: + try: + return hashlib.sha256(Path(p).read_bytes()).hexdigest() + except OSError: + return None + + +def record_read(path: str | Path, offset: int = 1, limit: int | None = None) -> None: + """Record that a file was read (called after successful read).""" + p = str(Path(path).resolve()) + try: + mtime = os.path.getmtime(p) + except OSError: + return + _state[p] = ReadState( + mtime=mtime, + offset=offset, + limit=limit, + content_hash=_hash_file(p), + can_dedup=True, + ) + + +def record_write(path: str | Path) -> None: + """Record that a file was written (updates mtime in state).""" + p = str(Path(path).resolve()) + try: + mtime = os.path.getmtime(p) + except OSError: + _state.pop(p, None) + return + _state[p] = ReadState( + mtime=mtime, + offset=1, + limit=None, + content_hash=_hash_file(p), + can_dedup=False, + ) + + +def check_read(path: str | Path) -> str | None: + """Check if a file has been read and is fresh. + + Returns None if OK, or a warning string. + When mtime changed but file content is identical (e.g. touch, editor save), + the check passes to avoid false-positive staleness warnings. 
+ """ + p = str(Path(path).resolve()) + entry = _state.get(p) + if entry is None: + return "Warning: file has not been read yet. Read it first to verify content before editing." + try: + current_mtime = os.path.getmtime(p) + except OSError: + return None + if current_mtime != entry.mtime: + if entry.content_hash and _hash_file(p) == entry.content_hash: + entry.mtime = current_mtime + return None + return "Warning: file has been modified since last read. Re-read to verify content before editing." + return None + + +def is_unchanged(path: str | Path, offset: int = 1, limit: int | None = None) -> bool: + """Return True if file was previously read with same params and mtime is unchanged.""" + p = str(Path(path).resolve()) + entry = _state.get(p) + if entry is None: + return False + if not entry.can_dedup: + return False + if entry.offset != offset or entry.limit != limit: + return False + try: + current_mtime = os.path.getmtime(p) + except OSError: + return False + return current_mtime == entry.mtime + + +def clear() -> None: + """Clear all tracked state (useful for testing).""" + _state.clear() diff --git a/nanobot/agent/tools/filesystem.py b/nanobot/agent/tools/filesystem.py index fdce38b69..e131a2e69 100644 --- a/nanobot/agent/tools/filesystem.py +++ b/nanobot/agent/tools/filesystem.py @@ -2,11 +2,13 @@ import difflib import mimetypes +from dataclasses import dataclass from pathlib import Path from typing import Any from nanobot.agent.tools.base import Tool, tool_parameters from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema +from nanobot.agent.tools import file_state from nanobot.utils.helpers import build_image_content_blocks, detect_image_mime from nanobot.config.paths import get_media_dir @@ -60,6 +62,36 @@ class _FsTool(Tool): # --------------------------------------------------------------------------- +_BLOCKED_DEVICE_PATHS = frozenset({ + "/dev/zero", "/dev/random", "/dev/urandom", "/dev/full", + 
"/dev/stdin", "/dev/stdout", "/dev/stderr", + "/dev/tty", "/dev/console", + "/dev/fd/0", "/dev/fd/1", "/dev/fd/2", +}) + + +def _is_blocked_device(path: str | Path) -> bool: + """Check if path is a blocked device that could hang or produce infinite output.""" + import re + raw = str(path) + if raw in _BLOCKED_DEVICE_PATHS: + return True + if re.match(r"/proc/\d+/fd/[012]$", raw) or re.match(r"/proc/self/fd/[012]$", raw): + return True + return False + + +def _parse_page_range(pages: str, total: int) -> tuple[int, int]: + """Parse a page range like '2-5' into 0-based (start, end) inclusive.""" + parts = pages.strip().split("-") + if len(parts) == 1: + p = int(parts[0]) + return max(0, p - 1), min(p - 1, total - 1) + start = int(parts[0]) + end = int(parts[1]) + return max(0, start - 1), min(end - 1, total - 1) + + @tool_parameters( tool_parameters_schema( path=StringSchema("The file path to read"), @@ -73,6 +105,7 @@ class _FsTool(Tool): description="Maximum number of lines to read (default 2000)", minimum=1, ), + pages=StringSchema("Page range for PDF files, e.g. '1-5' (default: all, max 20 pages)"), required=["path"], ) ) @@ -81,6 +114,7 @@ class ReadFileTool(_FsTool): _MAX_CHARS = 128_000 _DEFAULT_LIMIT = 2000 + _MAX_PDF_PAGES = 20 @property def name(self) -> str: @@ -89,9 +123,10 @@ class ReadFileTool(_FsTool): @property def description(self) -> str: return ( - "Read a text file. Output format: LINE_NUM|CONTENT. " + "Read a file (text or image). Text output format: LINE_NUM|CONTENT. " + "Images return visual content for analysis. " "Use offset and limit for large files. " - "Cannot read binary files or images. " + "Cannot read non-image binary files. " "Reads exceeding ~128K chars are truncated." 
) @@ -99,16 +134,27 @@ class ReadFileTool(_FsTool): def read_only(self) -> bool: return True - async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, **kwargs: Any) -> Any: + async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, pages: str | None = None, **kwargs: Any) -> Any: try: if not path: return "Error reading file: Unknown path" + + # Device path blacklist + if _is_blocked_device(path): + return f"Error: Reading {path} is blocked (device path that could hang or produce infinite output)." + fp = self._resolve(path) + if _is_blocked_device(fp): + return f"Error: Reading {fp} is blocked (device path that could hang or produce infinite output)." if not fp.exists(): return f"Error: File not found: {path}" if not fp.is_file(): return f"Error: Not a file: {path}" + # PDF support + if fp.suffix.lower() == ".pdf": + return self._read_pdf(fp, pages) + raw = fp.read_bytes() if not raw: return f"(Empty file: {path})" @@ -117,6 +163,10 @@ class ReadFileTool(_FsTool): if mime and mime.startswith("image/"): return build_image_content_blocks(raw, mime, str(fp), f"(Image file: {path})") + # Read dedup: same path + offset + limit + unchanged mtime → stub + if file_state.is_unchanged(fp, offset=offset, limit=limit): + return f"[File unchanged since last read: {path}]" + try: text_content = raw.decode("utf-8") except UnicodeDecodeError: @@ -149,12 +199,59 @@ class ReadFileTool(_FsTool): result += f"\n\n(Showing lines {offset}-{end} of {total}. Use offset={end + 1} to continue.)" else: result += f"\n\n(End of file — {total} lines total)" + file_state.record_read(fp, offset=offset, limit=limit) return result except PermissionError as e: return f"Error: {e}" except Exception as e: return f"Error reading file: {e}" + def _read_pdf(self, fp: Path, pages: str | None) -> str: + try: + import fitz # pymupdf + except ImportError: + return "Error: PDF reading requires pymupdf. 
Install with: pip install pymupdf" + + try: + doc = fitz.open(str(fp)) + except Exception as e: + return f"Error reading PDF: {e}" + + total_pages = len(doc) + if pages: + try: + start, end = _parse_page_range(pages, total_pages) + except (ValueError, IndexError): + doc.close() + return f"Error: Invalid page range '{pages}'. Use format like '1-5'." + if start > end or start >= total_pages: + doc.close() + return f"Error: Page range '{pages}' is out of bounds (document has {total_pages} pages)." + else: + start = 0 + end = min(total_pages - 1, self._MAX_PDF_PAGES - 1) + + if end - start + 1 > self._MAX_PDF_PAGES: + end = start + self._MAX_PDF_PAGES - 1 + + parts: list[str] = [] + for i in range(start, end + 1): + page = doc[i] + text = page.get_text().strip() + if text: + parts.append(f"--- Page {i + 1} ---\n{text}") + doc.close() + + if not parts: + return f"(PDF has no extractable text: {fp})" + + result = "\n\n".join(parts) + if end < total_pages - 1: + result += f"\n\n(Showing pages {start + 1}-{end + 1} of {total_pages}. 
Use pages='{end + 2}-{min(end + 1 + self._MAX_PDF_PAGES, total_pages)}' to continue.)" + if len(result) > self._MAX_CHARS: + result = result[:self._MAX_CHARS] + "\n\n(PDF text truncated at ~128K chars)" + return result + # --------------------------------------------------------------------------- # write_file @@ -192,6 +289,7 @@ class WriteFileTool(_FsTool): fp = self._resolve(path) fp.parent.mkdir(parents=True, exist_ok=True) fp.write_text(content, encoding="utf-8") + file_state.record_write(fp) return f"Successfully wrote {len(content)} characters to {fp}" except PermissionError as e: return f"Error: {e}" @@ -203,30 +301,269 @@ class WriteFileTool(_FsTool): # edit_file # --------------------------------------------------------------------------- +_QUOTE_TABLE = str.maketrans({ + "\u2018": "'", "\u2019": "'", # curly single → straight + "\u201c": '"', "\u201d": '"', # curly double → straight + "'": "'", '"': '"', # identity (kept for completeness) +}) + + +def _normalize_quotes(s: str) -> str: + return s.translate(_QUOTE_TABLE) + + +def _curly_double_quotes(text: str) -> str: + parts: list[str] = [] + opening = True + for ch in text: + if ch == '"': + parts.append("\u201c" if opening else "\u201d") + opening = not opening + else: + parts.append(ch) + return "".join(parts) + + +def _curly_single_quotes(text: str) -> str: + parts: list[str] = [] + opening = True + for i, ch in enumerate(text): + if ch != "'": + parts.append(ch) + continue + prev_ch = text[i - 1] if i > 0 else "" + next_ch = text[i + 1] if i + 1 < len(text) else "" + if prev_ch.isalnum() and next_ch.isalnum(): + parts.append("\u2019") + continue + parts.append("\u2018" if opening else "\u2019") + opening = not opening + return "".join(parts) + + +def _preserve_quote_style(old_text: str, actual_text: str, new_text: str) -> str: + """Preserve curly quote style when a quote-normalized fallback matched.""" + if _normalize_quotes(old_text.strip()) != _normalize_quotes(actual_text.strip()) or old_text == 
actual_text: + return new_text + + styled = new_text + if any(ch in actual_text for ch in ("\u201c", "\u201d")) and '"' in styled: + styled = _curly_double_quotes(styled) + if any(ch in actual_text for ch in ("\u2018", "\u2019")) and "'" in styled: + styled = _curly_single_quotes(styled) + return styled + + +def _leading_ws(line: str) -> str: + return line[: len(line) - len(line.lstrip(" \t"))] + + +def _reindent_like_match(old_text: str, actual_text: str, new_text: str) -> str: + """Preserve the outer indentation from the actual matched block.""" + old_lines = old_text.split("\n") + actual_lines = actual_text.split("\n") + if len(old_lines) != len(actual_lines): + return new_text + + comparable = [ + (old_line, actual_line) + for old_line, actual_line in zip(old_lines, actual_lines) + if old_line.strip() and actual_line.strip() + ] + if not comparable or any( + _normalize_quotes(old_line.strip()) != _normalize_quotes(actual_line.strip()) + for old_line, actual_line in comparable + ): + return new_text + + old_ws = _leading_ws(comparable[0][0]) + actual_ws = _leading_ws(comparable[0][1]) + if actual_ws == old_ws: + return new_text + + if old_ws: + if not actual_ws.startswith(old_ws): + return new_text + delta = actual_ws[len(old_ws):] + else: + delta = actual_ws + + if not delta: + return new_text + + return "\n".join((delta + line) if line else line for line in new_text.split("\n")) + + +@dataclass(slots=True) +class _MatchSpan: + start: int + end: int + text: str + line: int + + +def _find_exact_matches(content: str, old_text: str) -> list[_MatchSpan]: + matches: list[_MatchSpan] = [] + start = 0 + while True: + idx = content.find(old_text, start) + if idx == -1: + break + matches.append( + _MatchSpan( + start=idx, + end=idx + len(old_text), + text=content[idx : idx + len(old_text)], + line=content.count("\n", 0, idx) + 1, + ) + ) + start = idx + max(1, len(old_text)) + return matches + + +def _find_trim_matches(content: str, old_text: str, *, normalize_quotes: 
bool = False) -> list[_MatchSpan]: + old_lines = old_text.splitlines() + if not old_lines: + return [] + + content_lines = content.splitlines() + content_lines_keepends = content.splitlines(keepends=True) + if len(content_lines) < len(old_lines): + return [] + + offsets: list[int] = [] + pos = 0 + for line in content_lines_keepends: + offsets.append(pos) + pos += len(line) + offsets.append(pos) + + if normalize_quotes: + stripped_old = [_normalize_quotes(line.strip()) for line in old_lines] + else: + stripped_old = [line.strip() for line in old_lines] + + matches: list[_MatchSpan] = [] + window_size = len(stripped_old) + for i in range(len(content_lines) - window_size + 1): + window = content_lines[i : i + window_size] + if normalize_quotes: + comparable = [_normalize_quotes(line.strip()) for line in window] + else: + comparable = [line.strip() for line in window] + if comparable != stripped_old: + continue + + start = offsets[i] + end = offsets[i + window_size] + if content_lines_keepends[i + window_size - 1].endswith("\n"): + end -= 1 + matches.append( + _MatchSpan( + start=start, + end=end, + text=content[start:end], + line=i + 1, + ) + ) + return matches + + +def _find_quote_matches(content: str, old_text: str) -> list[_MatchSpan]: + norm_content = _normalize_quotes(content) + norm_old = _normalize_quotes(old_text) + matches: list[_MatchSpan] = [] + start = 0 + while True: + idx = norm_content.find(norm_old, start) + if idx == -1: + break + matches.append( + _MatchSpan( + start=idx, + end=idx + len(old_text), + text=content[idx : idx + len(old_text)], + line=content.count("\n", 0, idx) + 1, + ) + ) + start = idx + max(1, len(norm_old)) + return matches + + +def _find_matches(content: str, old_text: str) -> list[_MatchSpan]: + """Locate all matches using progressively looser strategies.""" + for matcher in ( + lambda: _find_exact_matches(content, old_text), + lambda: _find_trim_matches(content, old_text), + lambda: _find_trim_matches(content, old_text, 
normalize_quotes=True), + lambda: _find_quote_matches(content, old_text), + ): + matches = matcher() + if matches: + return matches + return [] + + +def _find_match_line_numbers(content: str, old_text: str) -> list[int]: + """Return 1-based starting line numbers for the current matching strategies.""" + return [match.line for match in _find_matches(content, old_text)] + + +def _collapse_internal_whitespace(text: str) -> str: + return "\n".join(" ".join(line.split()) for line in text.splitlines()) + + +def _diagnose_near_match(old_text: str, actual_text: str) -> list[str]: + """Return actionable hints describing why text was close but not exact.""" + hints: list[str] = [] + + if old_text.lower() == actual_text.lower() and old_text != actual_text: + hints.append("letter case differs") + if _collapse_internal_whitespace(old_text) == _collapse_internal_whitespace(actual_text) and old_text != actual_text: + hints.append("whitespace differs") + if old_text.rstrip("\n") == actual_text.rstrip("\n") and old_text != actual_text: + hints.append("trailing newline differs") + if _normalize_quotes(old_text) == _normalize_quotes(actual_text) and old_text != actual_text: + hints.append("quote style differs") + + return hints + + +def _best_window(old_text: str, content: str) -> tuple[float, int, list[str], list[str]]: + """Find the closest line-window match and return ratio/start/snippet/hints.""" + lines = content.splitlines(keepends=True) + old_lines = old_text.splitlines(keepends=True) + window = max(1, len(old_lines)) + + best_ratio, best_start = -1.0, 0 + best_window_lines: list[str] = [] + + for i in range(max(1, len(lines) - window + 1)): + current = lines[i : i + window] + ratio = difflib.SequenceMatcher(None, old_lines, current).ratio() + if ratio > best_ratio: + best_ratio, best_start = ratio, i + best_window_lines = current + + actual_text = "".join(best_window_lines).replace("\r\n", "\n").rstrip("\n") + hints = _diagnose_near_match(old_text.replace("\r\n", 
"\n").rstrip("\n"), actual_text) + return best_ratio, best_start, best_window_lines, hints + + def _find_match(content: str, old_text: str) -> tuple[str | None, int]: - """Locate old_text in content: exact first, then line-trimmed sliding window. + """Locate old_text in content with a multi-level fallback chain: + + 1. Exact substring match + 2. Line-trimmed sliding window (handles indentation differences) + 3. Smart quote normalization (curly ↔ straight quotes) Both inputs should use LF line endings (caller normalises CRLF). Returns (matched_fragment, count) or (None, 0). """ - if old_text in content: - return old_text, content.count(old_text) - - old_lines = old_text.splitlines() - if not old_lines: + matches = _find_matches(content, old_text) + if not matches: return None, 0 - stripped_old = [l.strip() for l in old_lines] - content_lines = content.splitlines() - - candidates = [] - for i in range(len(content_lines) - len(stripped_old) + 1): - window = content_lines[i : i + len(stripped_old)] - if [l.strip() for l in window] == stripped_old: - candidates.append("\n".join(window)) - - if candidates: - return candidates[0], len(candidates) - return None, 0 + return matches[0].text, len(matches) @tool_parameters( @@ -241,6 +578,9 @@ def _find_match(content: str, old_text: str) -> tuple[str | None, int]: class EditFileTool(_FsTool): """Edit a file by replacing text with fallback matching.""" + _MAX_EDIT_FILE_SIZE = 1024 * 1024 * 1024 # 1 GiB + _MARKDOWN_EXTS = frozenset({".md", ".mdx", ".markdown"}) + @property def name(self) -> str: return "edit_file" @@ -249,11 +589,16 @@ class EditFileTool(_FsTool): def description(self) -> str: return ( "Edit a file by replacing old_text with new_text. " - "Tolerates minor whitespace/indentation differences. " + "Tolerates minor whitespace/indentation differences and curly/straight quote mismatches. " "If old_text matches multiple times, you must provide more context " "or set replace_all=true. 
Shows a diff of the closest match on failure." ) + @staticmethod + def _strip_trailing_ws(text: str) -> str: + """Strip trailing whitespace from each line.""" + return "\n".join(line.rstrip() for line in text.split("\n")) + async def execute( self, path: str | None = None, old_text: str | None = None, new_text: str | None = None, @@ -267,55 +612,133 @@ class EditFileTool(_FsTool): if new_text is None: raise ValueError("Unknown new_text") + # .ipynb detection + if path.endswith(".ipynb"): + return "Error: This is a Jupyter notebook. Use the notebook_edit tool instead of edit_file." + fp = self._resolve(path) + + # Create-file semantics: old_text='' + file doesn't exist → create if not fp.exists(): - return f"Error: File not found: {path}" + if old_text == "": + fp.parent.mkdir(parents=True, exist_ok=True) + fp.write_text(new_text, encoding="utf-8") + file_state.record_write(fp) + return f"Successfully created {fp}" + return self._file_not_found_msg(path, fp) + + # File size protection + try: + fsize = fp.stat().st_size + except OSError: + fsize = 0 + if fsize > self._MAX_EDIT_FILE_SIZE: + return f"Error: File too large to edit ({fsize / (1024**3):.1f} GiB). Maximum is 1 GiB." + + # Create-file: old_text='' but file exists and not empty → reject + if old_text == "": + raw = fp.read_bytes() + content = raw.decode("utf-8") + if content.strip(): + return f"Error: Cannot create file — {path} already exists and is not empty." 
+ fp.write_text(new_text, encoding="utf-8") + file_state.record_write(fp) + return f"Successfully edited {fp}" + + # Read-before-edit check + warning = file_state.check_read(fp) raw = fp.read_bytes() uses_crlf = b"\r\n" in raw content = raw.decode("utf-8").replace("\r\n", "\n") - match, count = _find_match(content, old_text.replace("\r\n", "\n")) + norm_old = old_text.replace("\r\n", "\n") + matches = _find_matches(content, norm_old) - if match is None: + if not matches: return self._not_found_msg(old_text, content, path) + count = len(matches) if count > 1 and not replace_all: + line_numbers = [match.line for match in matches] + preview = ", ".join(f"line {n}" for n in line_numbers[:3]) + if len(line_numbers) > 3: + preview += ", ..." + location_hint = f" at {preview}" if preview else "" return ( - f"Warning: old_text appears {count} times. " + f"Warning: old_text appears {count} times{location_hint}. " "Provide more context to make it unique, or set replace_all=true." ) norm_new = new_text.replace("\r\n", "\n") - new_content = content.replace(match, norm_new) if replace_all else content.replace(match, norm_new, 1) + + # Trailing whitespace stripping (skip markdown to preserve double-space line breaks) + if fp.suffix.lower() not in self._MARKDOWN_EXTS: + norm_new = self._strip_trailing_ws(norm_new) + + selected = matches if replace_all else matches[:1] + new_content = content + for match in reversed(selected): + replacement = _preserve_quote_style(norm_old, match.text, norm_new) + replacement = _reindent_like_match(norm_old, match.text, replacement) + + # Delete-line cleanup: when deleting text (new_text=''), consume trailing + # newline to avoid leaving a blank line + end = match.end + if replacement == "" and not match.text.endswith("\n") and content[end:end + 1] == "\n": + end += 1 + + new_content = new_content[: match.start] + replacement + new_content[end:] if uses_crlf: new_content = new_content.replace("\n", "\r\n") 
fp.write_bytes(new_content.encode("utf-8")) - return f"Successfully edited {fp}" + file_state.record_write(fp) + msg = f"Successfully edited {fp}" + if warning: + msg = f"{warning}\n{msg}" + return msg except PermissionError as e: return f"Error: {e}" except Exception as e: return f"Error editing file: {e}" + def _file_not_found_msg(self, path: str, fp: Path) -> str: + """Build an error message with 'Did you mean ...?' suggestions.""" + parent = fp.parent + suggestions: list[str] = [] + if parent.is_dir(): + siblings = [f.name for f in parent.iterdir() if f.is_file()] + close = difflib.get_close_matches(fp.name, siblings, n=3, cutoff=0.6) + suggestions = [str(parent / c) for c in close] + parts = [f"Error: File not found: {path}"] + if suggestions: + parts.append("Did you mean: " + ", ".join(suggestions) + "?") + return "\n".join(parts) + @staticmethod def _not_found_msg(old_text: str, content: str, path: str) -> str: - lines = content.splitlines(keepends=True) - old_lines = old_text.splitlines(keepends=True) - window = len(old_lines) - - best_ratio, best_start = 0.0, 0 - for i in range(max(1, len(lines) - window + 1)): - ratio = difflib.SequenceMatcher(None, old_lines, lines[i : i + window]).ratio() - if ratio > best_ratio: - best_ratio, best_start = ratio, i - + best_ratio, best_start, best_window_lines, hints = _best_window(old_text, content) if best_ratio > 0.5: diff = "\n".join(difflib.unified_diff( - old_lines, lines[best_start : best_start + window], + old_text.splitlines(keepends=True), + best_window_lines, fromfile="old_text (provided)", tofile=f"{path} (actual, line {best_start + 1})", lineterm="", )) - return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}" + hint_text = "" + if hints: + hint_text = "\nPossible cause: " + ", ".join(hints) + "." + return ( + f"Error: old_text not found in {path}." 
+ f"{hint_text}\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}" + ) + + if hints: + return ( + f"Error: old_text not found in {path}. " + f"Possible cause: {', '.join(hints)}. " + "Copy the exact text from read_file and try again." + ) return f"Error: old_text not found in {path}. No similar text found. Verify the file content." diff --git a/nanobot/agent/tools/notebook.py b/nanobot/agent/tools/notebook.py new file mode 100644 index 000000000..8c4be110d --- /dev/null +++ b/nanobot/agent/tools/notebook.py @@ -0,0 +1,162 @@ +"""NotebookEditTool — edit Jupyter .ipynb notebooks.""" + +from __future__ import annotations + +import json +import uuid +from pathlib import Path +from typing import Any + +from nanobot.agent.tools.base import Tool, tool_parameters +from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema +from nanobot.agent.tools.filesystem import _FsTool + + +def _new_cell(source: str, cell_type: str = "code", generate_id: bool = False) -> dict: + cell: dict[str, Any] = { + "cell_type": cell_type, + "source": source, + "metadata": {}, + } + if cell_type == "code": + cell["outputs"] = [] + cell["execution_count"] = None + if generate_id: + cell["id"] = uuid.uuid4().hex[:8] + return cell + + +def _make_empty_notebook() -> dict: + return { + "nbformat": 4, + "nbformat_minor": 5, + "metadata": { + "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, + "language_info": {"name": "python"}, + }, + "cells": [], + } + + +@tool_parameters( + tool_parameters_schema( + path=StringSchema("Path to the .ipynb notebook file"), + cell_index=IntegerSchema(0, description="0-based index of the cell to edit", minimum=0), + new_source=StringSchema("New source content for the cell"), + cell_type=StringSchema( + "Cell type: 'code' or 'markdown' (default: code)", + enum=["code", "markdown"], + ), + edit_mode=StringSchema( + "Mode: 'replace' (default), 'insert' (after target), or 'delete'", + 
enum=["replace", "insert", "delete"], + ), + required=["path", "cell_index"], + ) +) +class NotebookEditTool(_FsTool): + """Edit Jupyter notebook cells: replace, insert, or delete.""" + + _VALID_CELL_TYPES = frozenset({"code", "markdown"}) + _VALID_EDIT_MODES = frozenset({"replace", "insert", "delete"}) + + @property + def name(self) -> str: + return "notebook_edit" + + @property + def description(self) -> str: + return ( + "Edit a Jupyter notebook (.ipynb) cell. " + "Modes: replace (default) replaces cell content, " + "insert adds a new cell after the target index, " + "delete removes the cell at the index. " + "cell_index is 0-based." + ) + + async def execute( + self, + path: str | None = None, + cell_index: int = 0, + new_source: str = "", + cell_type: str = "code", + edit_mode: str = "replace", + **kwargs: Any, + ) -> str: + try: + if not path: + return "Error: path is required" + + if not path.endswith(".ipynb"): + return "Error: notebook_edit only works on .ipynb files. Use edit_file for other files." + + if edit_mode not in self._VALID_EDIT_MODES: + return ( + f"Error: Invalid edit_mode '{edit_mode}'. " + "Use one of: replace, insert, delete." + ) + + if cell_type not in self._VALID_CELL_TYPES: + return ( + f"Error: Invalid cell_type '{cell_type}'. " + "Use one of: code, markdown." 
+ ) + + fp = self._resolve(path) + + # Create new notebook if file doesn't exist and mode is insert + if not fp.exists(): + if edit_mode != "insert": + return f"Error: File not found: {path}" + nb = _make_empty_notebook() + cell = _new_cell(new_source, cell_type, generate_id=True) + nb["cells"].append(cell) + fp.parent.mkdir(parents=True, exist_ok=True) + fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8") + return f"Successfully created {fp} with 1 cell" + + try: + nb = json.loads(fp.read_text(encoding="utf-8")) + except (json.JSONDecodeError, UnicodeDecodeError) as e: + return f"Error: Failed to parse notebook: {e}" + + cells = nb.get("cells", []) + nbformat_minor = nb.get("nbformat_minor", 0) + generate_id = nb.get("nbformat", 0) >= 4 and nbformat_minor >= 5 + + if edit_mode == "delete": + if cell_index < 0 or cell_index >= len(cells): + return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)" + cells.pop(cell_index) + nb["cells"] = cells + fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8") + return f"Successfully deleted cell {cell_index} from {fp}" + + if edit_mode == "insert": + insert_at = min(cell_index + 1, len(cells)) + cell = _new_cell(new_source, cell_type, generate_id=generate_id) + cells.insert(insert_at, cell) + nb["cells"] = cells + fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8") + return f"Successfully inserted cell at index {insert_at} in {fp}" + + # Default: replace + if cell_index < 0 or cell_index >= len(cells): + return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)" + cells[cell_index]["source"] = new_source + if cell_type and cells[cell_index].get("cell_type") != cell_type: + cells[cell_index]["cell_type"] = cell_type + if cell_type == "code": + cells[cell_index].setdefault("outputs", []) + cells[cell_index].setdefault("execution_count", None) + elif "outputs" in cells[cell_index]: + del 
cells[cell_index]["outputs"] + cells[cell_index].pop("execution_count", None) + nb["cells"] = cells + fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8") + return f"Successfully edited cell {cell_index} in {fp}" + + except PermissionError as e: + return f"Error: {e}" + except Exception as e: + return f"Error editing notebook: {e}" diff --git a/pyproject.toml b/pyproject.toml index 751716135..b2a25bfad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,12 +76,16 @@ discord = [ langsmith = [ "langsmith>=0.1.0", ] +pdf = [ + "pymupdf>=1.25.0", +] dev = [ "pytest>=9.0.0,<10.0.0", "pytest-asyncio>=1.3.0,<2.0.0", "aiohttp>=3.9.0,<4.0.0", "pytest-cov>=6.0.0,<7.0.0", "ruff>=0.1.0", + "pymupdf>=1.25.0", ] [project.scripts] diff --git a/tests/tools/test_edit_advanced.py b/tests/tools/test_edit_advanced.py new file mode 100644 index 000000000..baf0eb02f --- /dev/null +++ b/tests/tools/test_edit_advanced.py @@ -0,0 +1,423 @@ +"""Tests for advanced EditFileTool enhancements inspired by claude-code: +- Delete-line newline cleanup +- Smart quote normalization (curly ↔ straight) +- Quote style preservation in replacements +- Indentation preservation when fallback match is trimmed +- Trailing whitespace stripping for new_text +- File size protection +- Stale detection with content-equality fallback +""" + +import os +import time + +import pytest + +from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, _find_match +from nanobot.agent.tools import file_state + + +@pytest.fixture(autouse=True) +def _clear_file_state(): + file_state.clear() + yield + file_state.clear() + + +# --------------------------------------------------------------------------- +# Delete-line newline cleanup +# --------------------------------------------------------------------------- + + +class TestDeleteLineCleanup: + """When new_text='' and deleting a line, trailing newline should be consumed.""" + + @pytest.fixture() + def tool(self, tmp_path): + return 
EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_delete_line_consumes_trailing_newline(self, tool, tmp_path):
        f = tmp_path / "a.py"
        f.write_text("line1\nline2\nline3\n", encoding="utf-8")
        result = await tool.execute(path=str(f), old_text="line2", new_text="")
        assert "Successfully" in result
        content = f.read_text()
        # Should not leave a blank line where line2 was
        assert content == "line1\nline3\n"

    @pytest.mark.asyncio
    async def test_delete_line_with_explicit_newline_in_old_text(self, tool, tmp_path):
        f = tmp_path / "a.py"
        f.write_text("line1\nline2\nline3\n", encoding="utf-8")
        result = await tool.execute(path=str(f), old_text="line2\n", new_text="")
        assert "Successfully" in result
        assert f.read_text() == "line1\nline3\n"

    @pytest.mark.asyncio
    async def test_delete_preserves_content_when_not_trailing_newline(self, tool, tmp_path):
        """Deleting a word mid-line should not consume extra characters."""
        f = tmp_path / "a.py"
        f.write_text("hello world here\n", encoding="utf-8")
        result = await tool.execute(path=str(f), old_text="world ", new_text="")
        assert "Successfully" in result
        assert f.read_text() == "hello here\n"


# ---------------------------------------------------------------------------
# Smart quote normalization
# ---------------------------------------------------------------------------


class TestSmartQuoteNormalization:
    """_find_match should handle curly ↔ straight quote fallback."""

    def test_curly_double_quotes_match_straight(self):
        content = 'She said \u201chello\u201d to him'
        old_text = 'She said "hello" to him'
        match, count = _find_match(content, old_text)
        assert match is not None
        assert count == 1
        # Returned match should be the ORIGINAL content with curly quotes
        assert "\u201c" in match

    def test_curly_single_quotes_match_straight(self):
        content = "it\u2019s a test"
        old_text = "it's a test"
        match, count = _find_match(content, old_text)
        assert match is not None
        assert count == 1
        assert "\u2019" in match

    def test_straight_matches_curly_in_old_text(self):
        # Reverse direction: old_text has the curly quotes, file has straight.
        content = 'x = "hello"'
        old_text = 'x = \u201chello\u201d'
        match, count = _find_match(content, old_text)
        assert match is not None
        assert count == 1

    def test_exact_match_still_preferred_over_quote_normalization(self):
        content = 'x = "hello"'
        old_text = 'x = "hello"'
        match, count = _find_match(content, old_text)
        assert match == old_text
        assert count == 1


class TestQuoteStylePreservation:
    """When quote-normalized matching occurs, replacement should preserve actual quote style."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_replacement_preserves_curly_double_quotes(self, tool, tmp_path):
        f = tmp_path / "quotes.txt"
        f.write_text('message = “hello”\n', encoding="utf-8")
        result = await tool.execute(
            path=str(f),
            old_text='message = "hello"',
            new_text='message = "goodbye"',
        )
        assert "Successfully" in result
        # The curly quotes of the file win over the straight quotes of new_text.
        assert f.read_text(encoding="utf-8") == 'message = “goodbye”\n'

    @pytest.mark.asyncio
    async def test_replacement_preserves_curly_apostrophe(self, tool, tmp_path):
        f = tmp_path / "apostrophe.txt"
        f.write_text("it’s fine\n", encoding="utf-8")
        result = await tool.execute(
            path=str(f),
            old_text="it's fine",
            new_text="it's better",
        )
        assert "Successfully" in result
        assert f.read_text(encoding="utf-8") == "it’s better\n"


# ---------------------------------------------------------------------------
# Indentation preservation
# ---------------------------------------------------------------------------


class TestIndentationPreservation:
    """Replacement should keep outer indentation when trim fallback matched."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_trim_fallback_preserves_outer_indentation(self, tool, tmp_path):
        f = tmp_path / "indent.py"
        # NOTE(review): exact indentation widths reconstructed from a
        # whitespace-mangled patch — confirm against the committed file.
        f.write_text(
            "if True:\n"
            "    def foo():\n"
            "        pass\n",
            encoding="utf-8",
        )
        result = await tool.execute(
            path=str(f),
            old_text="def foo():\n    pass",
            new_text="def bar():\n    return 1",
        )
        assert "Successfully" in result
        assert f.read_text(encoding="utf-8") == (
            "if True:\n"
            "    def bar():\n"
            "        return 1\n"
        )


# ---------------------------------------------------------------------------
# Failure diagnostics
# ---------------------------------------------------------------------------


class TestEditDiagnostics:
    """Failure paths should offer actionable hints."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_ambiguous_match_reports_candidate_lines(self, tool, tmp_path):
        f = tmp_path / "dup.py"
        f.write_text("aaa\nbbb\naaa\nbbb\n", encoding="utf-8")
        result = await tool.execute(path=str(f), old_text="aaa\nbbb", new_text="xxx")
        assert "appears 2 times" in result.lower()
        assert "line 1" in result.lower()
        assert "line 3" in result.lower()
        assert "replace_all=true" in result

    @pytest.mark.asyncio
    async def test_not_found_reports_whitespace_hint(self, tool, tmp_path):
        f = tmp_path / "space.py"
        # NOTE(review): the file content must differ from old_text only in
        # whitespace for this hint to trigger; the extra space was collapsed
        # in the patch rendering — confirm against the committed file.
        f.write_text("value  = 1\n", encoding="utf-8")
        result = await tool.execute(path=str(f), old_text="value = 1", new_text="value = 2")
        assert "Error" in result
        assert "whitespace" in result.lower()

    @pytest.mark.asyncio
    async def test_not_found_reports_case_hint(self, tool, tmp_path):
        f = tmp_path / "case.py"
        f.write_text("HelloWorld\n", encoding="utf-8")
        result = await tool.execute(path=str(f), old_text="helloworld", new_text="goodbye")
        assert "Error" in result
        assert "letter case differs" in result.lower()


# ---------------------------------------------------------------------------
# Advanced fallback replacement
# ---------------------------------------------------------------------------


class TestAdvancedReplaceAll:
    """replace_all should work correctly for fallback-based matches too."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_replace_all_preserves_each_match_indentation(self, tool, tmp_path):
        # Two occurrences at the same indentation level; replace_all must
        # re-indent the replacement independently for each match.
        f = tmp_path / "indent_multi.py"
        f.write_text(
            "if a:\n"
            "    def foo():\n"
            "        pass\n"
            "if b:\n"
            "    def foo():\n"
            "        pass\n",
            encoding="utf-8",
        )
        result = await tool.execute(
            path=str(f),
            old_text="def foo():\n    pass",
            new_text="def bar():\n    return 1",
            replace_all=True,
        )
        assert "Successfully" in result
        assert f.read_text(encoding="utf-8") == (
            "if a:\n"
            "    def bar():\n"
            "        return 1\n"
            "if b:\n"
            "    def bar():\n"
            "        return 1\n"
        )

    @pytest.mark.asyncio
    async def test_trim_and_quote_fallback_match_succeeds(self, tool, tmp_path):
        # Both fallbacks at once: leading indentation must be trimmed AND the
        # curly quotes of the file matched against straight quotes in old_text.
        f = tmp_path / "quote_indent.py"
        f.write_text("    message = “hello”\n", encoding="utf-8")
        result = await tool.execute(
            path=str(f),
            old_text='message = "hello"',
            new_text='message = "goodbye"',
        )
        assert "Successfully" in result
        assert f.read_text(encoding="utf-8") == "    message = “goodbye”\n"


# ---------------------------------------------------------------------------
# Trailing whitespace stripping on new_text
# ---------------------------------------------------------------------------


class TestTrailingWhitespaceStrip:
    """new_text trailing whitespace should be stripped (except .md files)."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_strips_trailing_whitespace_from_new_text(self, tool, tmp_path):
        f = tmp_path / "a.py"
        f.write_text("x = 1\n", encoding="utf-8")
        result = await tool.execute(
            path=str(f), old_text="x = 1", new_text="x = 2  \ny = 3  ",
        )
        assert "Successfully" in result
        content = f.read_text()
        assert "x = 2\ny = 3\n" == content

    @pytest.mark.asyncio
    async def test_preserves_trailing_whitespace_in_markdown(self, tool, tmp_path):
        f = tmp_path / "doc.md"
        f.write_text("# Title\n", encoding="utf-8")
        # Markdown uses trailing double-space for line breaks
        result = await tool.execute(
            path=str(f), old_text="# Title", new_text="# Title  \nSubtitle  ",
        )
        assert "Successfully" in result
        content = f.read_text()
        # Trailing spaces should be preserved for markdown
        assert "Title  " in content
        assert "Subtitle  " in content


#
--------------------------------------------------------------------------- +# File size protection +# --------------------------------------------------------------------------- + + +class TestFileSizeProtection: + """Editing extremely large files should be rejected.""" + + @pytest.fixture() + def tool(self, tmp_path): + return EditFileTool(workspace=tmp_path) + + @pytest.mark.asyncio + async def test_rejects_file_over_size_limit(self, tool, tmp_path): + f = tmp_path / "huge.txt" + f.write_text("x", encoding="utf-8") + # Monkey-patch the file size check by creating a stat mock + original_stat = f.stat + + class FakeStat: + def __init__(self, real_stat): + self._real = real_stat + + def __getattr__(self, name): + return getattr(self._real, name) + + @property + def st_size(self): + return 2 * 1024 * 1024 * 1024 # 2 GiB + + import unittest.mock + with unittest.mock.patch.object(type(f), 'stat', return_value=FakeStat(f.stat())): + result = await tool.execute(path=str(f), old_text="x", new_text="y") + assert "Error" in result + assert "too large" in result.lower() or "size" in result.lower() + + +# --------------------------------------------------------------------------- +# Stale detection with content-equality fallback +# --------------------------------------------------------------------------- + + +class TestStaleDetectionContentFallback: + """When mtime changed but file content is unchanged, edit should proceed without warning.""" + + @pytest.fixture() + def read_tool(self, tmp_path): + return ReadFileTool(workspace=tmp_path) + + @pytest.fixture() + def edit_tool(self, tmp_path): + return EditFileTool(workspace=tmp_path) + + @pytest.mark.asyncio + async def test_mtime_bump_same_content_no_warning(self, read_tool, edit_tool, tmp_path): + f = tmp_path / "a.py" + f.write_text("hello world", encoding="utf-8") + await read_tool.execute(path=str(f)) + + # Touch the file to bump mtime without changing content + time.sleep(0.05) + original_content = f.read_text() + 
f.write_text(original_content, encoding="utf-8") + + result = await edit_tool.execute(path=str(f), old_text="world", new_text="earth") + assert "Successfully" in result + # Should NOT warn about modification since content is the same + assert "modified" not in result.lower() diff --git a/tests/tools/test_edit_enhancements.py b/tests/tools/test_edit_enhancements.py new file mode 100644 index 000000000..7ad098960 --- /dev/null +++ b/tests/tools/test_edit_enhancements.py @@ -0,0 +1,152 @@ +"""Tests for EditFileTool enhancements: read-before-edit tracking, path suggestions, +.ipynb detection, and create-file semantics.""" + +import pytest + +from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, WriteFileTool +from nanobot.agent.tools import file_state + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +@pytest.fixture(autouse=True) +def _clear_file_state(): + """Reset global read-state between tests.""" + file_state.clear() + yield + file_state.clear() + + +# --------------------------------------------------------------------------- +# Read-before-edit tracking +# --------------------------------------------------------------------------- + +class TestEditReadTracking: + """edit_file should warn when file hasn't been read first.""" + + @pytest.fixture() + def read_tool(self, tmp_path): + return ReadFileTool(workspace=tmp_path) + + @pytest.fixture() + def edit_tool(self, tmp_path): + return EditFileTool(workspace=tmp_path) + + @pytest.mark.asyncio + async def test_edit_warns_if_file_not_read_first(self, edit_tool, tmp_path): + f = tmp_path / "a.py" + f.write_text("hello world", encoding="utf-8") + result = await edit_tool.execute(path=str(f), old_text="world", new_text="earth") + # Should still succeed but include a warning + assert "Successfully" in result + assert "not been read" in result.lower() or "warning" in result.lower() 
    @pytest.mark.asyncio
    async def test_edit_succeeds_cleanly_after_read(self, read_tool, edit_tool, tmp_path):
        f = tmp_path / "a.py"
        f.write_text("hello world", encoding="utf-8")
        await read_tool.execute(path=str(f))
        result = await edit_tool.execute(path=str(f), old_text="world", new_text="earth")
        assert "Successfully" in result
        # No warning when file was read first
        assert "not been read" not in result.lower()
        assert f.read_text() == "hello earth"

    @pytest.mark.asyncio
    async def test_edit_warns_if_file_modified_since_read(self, read_tool, edit_tool, tmp_path):
        f = tmp_path / "a.py"
        f.write_text("hello world", encoding="utf-8")
        await read_tool.execute(path=str(f))
        # External modification
        f.write_text("hello universe", encoding="utf-8")
        result = await edit_tool.execute(path=str(f), old_text="universe", new_text="earth")
        assert "Successfully" in result
        assert "modified" in result.lower() or "warning" in result.lower()


# ---------------------------------------------------------------------------
# Create-file semantics
# ---------------------------------------------------------------------------

class TestEditCreateFile:
    """edit_file with old_text='' creates new file if not exists."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_create_new_file_with_empty_old_text(self, tool, tmp_path):
        # Parent directory does not exist either; creation should mkdir it.
        f = tmp_path / "subdir" / "new.py"
        result = await tool.execute(path=str(f), old_text="", new_text="print('hi')")
        assert "created" in result.lower() or "Successfully" in result
        assert f.exists()
        assert f.read_text() == "print('hi')"

    @pytest.mark.asyncio
    async def test_create_fails_if_file_already_exists_and_not_empty(self, tool, tmp_path):
        f = tmp_path / "existing.py"
        f.write_text("existing content", encoding="utf-8")
        result = await tool.execute(path=str(f), old_text="", new_text="new content")
        assert "Error" in result or "already exists" in result.lower()
        # File should be unchanged
        assert f.read_text() == "existing content"

    @pytest.mark.asyncio
    async def test_create_succeeds_if_file_exists_but_empty(self, tool, tmp_path):
        f = tmp_path / "empty.py"
        f.write_text("", encoding="utf-8")
        result = await tool.execute(path=str(f), old_text="", new_text="print('hi')")
        assert "Successfully" in result
        assert f.read_text() == "print('hi')"


# ---------------------------------------------------------------------------
# .ipynb detection
# ---------------------------------------------------------------------------

class TestEditIpynbDetection:
    """edit_file should refuse .ipynb and suggest notebook_edit."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_ipynb_rejected_with_suggestion(self, tool, tmp_path):
        f = tmp_path / "analysis.ipynb"
        f.write_text('{"cells": []}', encoding="utf-8")
        result = await tool.execute(path=str(f), old_text="x", new_text="y")
        assert "notebook" in result.lower()


# ---------------------------------------------------------------------------
# Path suggestion on not-found
# ---------------------------------------------------------------------------

class TestEditPathSuggestion:
    """edit_file should suggest similar paths on not-found."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return EditFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_suggests_similar_filename(self, tool, tmp_path):
        f = tmp_path / "config.py"
        f.write_text("x = 1", encoding="utf-8")
        # Typo: conifg.py
        result = await tool.execute(
            path=str(tmp_path / "conifg.py"), old_text="x = 1", new_text="x = 2",
        )
        assert "Error" in result
        assert "config.py" in result

    @pytest.mark.asyncio
    async def test_shows_cwd_in_error(self, tool, tmp_path):
        result = await tool.execute(
            path=str(tmp_path / "nonexistent.py"), old_text="a", new_text="b",
        )
        assert "Error" in result
diff --git a/tests/tools/test_notebook_tool.py b/tests/tools/test_notebook_tool.py
new file mode 100644
index 000000000..232f13c4b
--- /dev/null
+++ b/tests/tools/test_notebook_tool.py
@@ -0,0 +1,147 @@
"""Tests for NotebookEditTool — Jupyter .ipynb editing."""

import json

import pytest

from nanobot.agent.tools.notebook import NotebookEditTool


def _make_notebook(cells: list[dict] | None = None, nbformat: int = 4, nbformat_minor: int = 5) -> dict:
    """Build a minimal valid .ipynb structure."""
    return {
        "nbformat": nbformat,
        "nbformat_minor": nbformat_minor,
        "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}},
        "cells": cells or [],
    }


def _code_cell(source: str, cell_id: str | None = None) -> dict:
    # Minimal code cell; outputs/execution_count present as nbformat expects.
    cell = {"cell_type": "code", "source": source, "metadata": {}, "outputs": [], "execution_count": None}
    if cell_id:
        cell["id"] = cell_id
    return cell


def _md_cell(source: str, cell_id: str | None = None) -> dict:
    # Markdown cells carry no outputs/execution_count.
    cell = {"cell_type": "markdown", "source": source, "metadata": {}}
    if cell_id:
        cell["id"] = cell_id
    return cell


def _write_nb(tmp_path, name: str, nb: dict) -> str:
    p = tmp_path / name
    p.write_text(json.dumps(nb), encoding="utf-8")
    return str(p)


class TestNotebookEdit:

    @pytest.fixture()
    def tool(self, tmp_path):
        return NotebookEditTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_replace_cell_content(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("print('hello')"), _code_cell("x = 1")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        result = await tool.execute(path=path, cell_index=0, new_source="print('world')")
        assert "Successfully" in result
        saved = json.loads((tmp_path / "test.ipynb").read_text())
        assert saved["cells"][0]["source"] == "print('world')"
        assert saved["cells"][1]["source"] == "x = 1"

    @pytest.mark.asyncio
    async def
test_insert_cell_after_target(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("cell 0"), _code_cell("cell 1")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        result = await tool.execute(path=path, cell_index=0, new_source="inserted", edit_mode="insert")
        assert "Successfully" in result
        saved = json.loads((tmp_path / "test.ipynb").read_text())
        # Insert lands AFTER the target index.
        assert len(saved["cells"]) == 3
        assert saved["cells"][0]["source"] == "cell 0"
        assert saved["cells"][1]["source"] == "inserted"
        assert saved["cells"][2]["source"] == "cell 1"

    @pytest.mark.asyncio
    async def test_delete_cell(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("A"), _code_cell("B"), _code_cell("C")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        result = await tool.execute(path=path, cell_index=1, edit_mode="delete")
        assert "Successfully" in result
        saved = json.loads((tmp_path / "test.ipynb").read_text())
        assert len(saved["cells"]) == 2
        assert saved["cells"][0]["source"] == "A"
        assert saved["cells"][1]["source"] == "C"

    @pytest.mark.asyncio
    async def test_create_new_notebook_from_scratch(self, tool, tmp_path):
        # insert mode on a nonexistent path creates a fresh notebook.
        path = str(tmp_path / "new.ipynb")
        result = await tool.execute(path=path, cell_index=0, new_source="# Hello", edit_mode="insert", cell_type="markdown")
        assert "Successfully" in result or "created" in result.lower()
        saved = json.loads((tmp_path / "new.ipynb").read_text())
        assert saved["nbformat"] == 4
        assert len(saved["cells"]) == 1
        assert saved["cells"][0]["cell_type"] == "markdown"
        assert saved["cells"][0]["source"] == "# Hello"

    @pytest.mark.asyncio
    async def test_invalid_cell_index_error(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("only cell")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        result = await tool.execute(path=path, cell_index=5, new_source="x")
        assert "Error" in result

    @pytest.mark.asyncio
    async def test_non_ipynb_rejected(self, tool, tmp_path):
        f = tmp_path / "script.py"
        f.write_text("pass")
        result = await tool.execute(path=str(f), cell_index=0, new_source="x")
        assert "Error" in result
        assert ".ipynb" in result

    @pytest.mark.asyncio
    async def test_preserves_metadata_and_outputs(self, tool, tmp_path):
        cell = _code_cell("old")
        cell["outputs"] = [{"output_type": "stream", "text": "hello\n"}]
        cell["execution_count"] = 42
        nb = _make_notebook([cell])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        await tool.execute(path=path, cell_index=0, new_source="new")
        saved = json.loads((tmp_path / "test.ipynb").read_text())
        # Notebook-level metadata must survive a cell edit.
        assert saved["metadata"]["kernelspec"]["language"] == "python"

    @pytest.mark.asyncio
    async def test_nbformat_45_generates_cell_id(self, tool, tmp_path):
        # nbformat >= 4.5 requires every cell to carry an "id" field.
        nb = _make_notebook([], nbformat_minor=5)
        path = _write_nb(tmp_path, "test.ipynb", nb)
        await tool.execute(path=path, cell_index=0, new_source="x = 1", edit_mode="insert")
        saved = json.loads((tmp_path / "test.ipynb").read_text())
        assert "id" in saved["cells"][0]
        assert len(saved["cells"][0]["id"]) > 0

    @pytest.mark.asyncio
    async def test_insert_with_cell_type_markdown(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("code")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        await tool.execute(path=path, cell_index=0, new_source="# Title", edit_mode="insert", cell_type="markdown")
        saved = json.loads((tmp_path / "test.ipynb").read_text())
        assert saved["cells"][1]["cell_type"] == "markdown"

    @pytest.mark.asyncio
    async def test_invalid_edit_mode_rejected(self, tool, tmp_path):
        nb = _make_notebook([_code_cell("code")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        result = await tool.execute(path=path, cell_index=0, new_source="x", edit_mode="replcae")
        assert "Error" in result
        assert "edit_mode" in result

    @pytest.mark.asyncio
    async def test_invalid_cell_type_rejected(self, tool, tmp_path):
        # NOTE(review): "raw" IS a legal nbformat cell type; the tool evidently
        # restricts cell_type to code/markdown — confirm that is intentional.
        nb = _make_notebook([_code_cell("code")])
        path = _write_nb(tmp_path, "test.ipynb", nb)
        result = await tool.execute(path=path, cell_index=0, new_source="x", cell_type="raw")
        assert "Error" in result
        assert "cell_type" in result
diff --git a/tests/tools/test_read_enhancements.py b/tests/tools/test_read_enhancements.py
new file mode 100644
index 000000000..a703ba6e4
--- /dev/null
+++ b/tests/tools/test_read_enhancements.py
@@ -0,0 +1,180 @@
"""Tests for ReadFileTool enhancements: description fix, read dedup, PDF support, device blacklist."""

import pytest

from nanobot.agent.tools.filesystem import ReadFileTool, WriteFileTool
from nanobot.agent.tools import file_state


@pytest.fixture(autouse=True)
def _clear_file_state():
    # Reset the global read-tracking state between tests.
    file_state.clear()
    yield
    file_state.clear()


# ---------------------------------------------------------------------------
# Description fix
# ---------------------------------------------------------------------------

class TestReadDescriptionFix:

    def test_description_mentions_image_support(self):
        tool = ReadFileTool()
        assert "image" in tool.description.lower()

    def test_description_no_longer_says_cannot_read_images(self):
        tool = ReadFileTool()
        assert "cannot read binary files or images" not in tool.description.lower()


# ---------------------------------------------------------------------------
# Read deduplication
# ---------------------------------------------------------------------------

class TestReadDedup:
    """Same file + same offset/limit + unchanged mtime -> short stub."""

    @pytest.fixture()
    def tool(self, tmp_path):
        return ReadFileTool(workspace=tmp_path)

    @pytest.fixture()
    def write_tool(self, tmp_path):
        return WriteFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_second_read_returns_unchanged_stub(self, tool, tmp_path):
        f = tmp_path / "data.txt"
        f.write_text("\n".join(f"line {i}" for i in range(100)), encoding="utf-8")
        first = await tool.execute(path=str(f))
        assert "line 0" in first
        second = await tool.execute(path=str(f))
        assert
"unchanged" in second.lower()
        # Stub should not contain file content
        assert "line 0" not in second

    @pytest.mark.asyncio
    async def test_read_after_external_modification_returns_full(self, tool, tmp_path):
        f = tmp_path / "data.txt"
        f.write_text("original", encoding="utf-8")
        await tool.execute(path=str(f))
        # Modify the file externally
        f.write_text("modified content", encoding="utf-8")
        second = await tool.execute(path=str(f))
        assert "modified content" in second

    @pytest.mark.asyncio
    async def test_different_offset_returns_full(self, tool, tmp_path):
        f = tmp_path / "data.txt"
        f.write_text("\n".join(f"line {i}" for i in range(1, 21)), encoding="utf-8")
        await tool.execute(path=str(f), offset=1, limit=5)
        second = await tool.execute(path=str(f), offset=6, limit=5)
        # Different offset → full read, not stub
        assert "line 6" in second

    @pytest.mark.asyncio
    async def test_first_read_after_write_returns_full_content(self, tool, write_tool, tmp_path):
        f = tmp_path / "fresh.txt"
        result = await write_tool.execute(path=str(f), content="hello")
        assert "Successfully" in result
        read_result = await tool.execute(path=str(f))
        assert "hello" in read_result
        assert "unchanged" not in read_result.lower()

    @pytest.mark.asyncio
    async def test_dedup_does_not_apply_to_images(self, tool, tmp_path):
        f = tmp_path / "img.png"
        f.write_bytes(b"\x89PNG\r\n\x1a\nfake-png-data")
        first = await tool.execute(path=str(f))
        # Image reads return a list of content blocks rather than a string.
        assert isinstance(first, list)
        second = await tool.execute(path=str(f))
        # Images should always return full content blocks, not a stub
        assert isinstance(second, list)


# ---------------------------------------------------------------------------
# PDF support
# ---------------------------------------------------------------------------

class TestReadPdf:

    @pytest.fixture()
    def tool(self, tmp_path):
        return ReadFileTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_pdf_returns_text_content(self, tool, tmp_path):
        # Skip cleanly when PyMuPDF (imported as "fitz") is not installed.
        fitz = pytest.importorskip("fitz")
        pdf_path = tmp_path / "test.pdf"
        doc = fitz.open()
        page = doc.new_page()
        page.insert_text((72, 72), "Hello PDF World")
        doc.save(str(pdf_path))
        doc.close()

        result = await tool.execute(path=str(pdf_path))
        assert "Hello PDF World" in result

    @pytest.mark.asyncio
    async def test_pdf_pages_parameter(self, tool, tmp_path):
        fitz = pytest.importorskip("fitz")
        pdf_path = tmp_path / "multi.pdf"
        doc = fitz.open()
        for i in range(5):
            page = doc.new_page()
            page.insert_text((72, 72), f"Page {i + 1} content")
        doc.save(str(pdf_path))
        doc.close()

        # Page ranges are 1-based and inclusive.
        result = await tool.execute(path=str(pdf_path), pages="2-3")
        assert "Page 2 content" in result
        assert "Page 3 content" in result
        assert "Page 1 content" not in result

    @pytest.mark.asyncio
    async def test_pdf_file_not_found_error(self, tool, tmp_path):
        result = await tool.execute(path=str(tmp_path / "nope.pdf"))
        assert "Error" in result
        assert "not found" in result


# ---------------------------------------------------------------------------
# Device path blacklist
# ---------------------------------------------------------------------------

class TestReadDeviceBlacklist:

    @pytest.fixture()
    def tool(self):
        # No workspace: reads of absolute system paths must still be blocked.
        return ReadFileTool()

    @pytest.mark.asyncio
    async def test_dev_random_blocked(self, tool):
        result = await tool.execute(path="/dev/random")
        assert "Error" in result
        assert "blocked" in result.lower() or "device" in result.lower()

    @pytest.mark.asyncio
    async def test_dev_urandom_blocked(self, tool):
        result = await tool.execute(path="/dev/urandom")
        assert "Error" in result

    @pytest.mark.asyncio
    async def test_dev_zero_blocked(self, tool):
        result = await tool.execute(path="/dev/zero")
        assert "Error" in result

    @pytest.mark.asyncio
    async def test_proc_fd_blocked(self, tool):
        result = await tool.execute(path="/proc/self/fd/0")
        assert "Error" in result

    @pytest.mark.asyncio
    async def test_symlink_to_dev_zero_blocked(self, tmp_path):
        # The blacklist must apply to the resolved target, not the given path.
        tool = ReadFileTool(workspace=tmp_path)
        link = tmp_path / "zero-link"
        link.symlink_to("/dev/zero")
        result = await tool.execute(path=str(link))
        assert "Error" in result
        assert "blocked" in result.lower() or "device" in result.lower()