mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-14 23:19:55 +00:00
improve file editing and add notebook tool
Enhance file tools with read tracking, PDF support, safer path handling, smarter edit matching/diagnostics, and introduce notebook_edit with tests.
This commit is contained in:
parent
9bccfa63d2
commit
651aeae656
@ -22,6 +22,7 @@ from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.agent.tools.notebook import NotebookEditTool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.search import GlobTool, GrepTool
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
@ -235,6 +236,7 @@ class AgentLoop:
|
||||
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
for cls in (GlobTool, GrepTool):
|
||||
self.tools.register(cls(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
self.tools.register(NotebookEditTool(workspace=self.workspace, allowed_dir=allowed_dir))
|
||||
if self.exec_config.enable:
|
||||
self.tools.register(ExecTool(
|
||||
working_dir=str(self.workspace),
|
||||
|
||||
105
nanobot/agent/tools/file_state.py
Normal file
105
nanobot/agent/tools/file_state.py
Normal file
@ -0,0 +1,105 @@
|
||||
"""Track file-read state for read-before-edit warnings and read deduplication."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ReadState:
|
||||
mtime: float
|
||||
offset: int
|
||||
limit: int | None
|
||||
content_hash: str | None
|
||||
can_dedup: bool
|
||||
|
||||
|
||||
_state: dict[str, ReadState] = {}
|
||||
|
||||
|
||||
def _hash_file(p: str) -> str | None:
|
||||
try:
|
||||
return hashlib.sha256(Path(p).read_bytes()).hexdigest()
|
||||
except OSError:
|
||||
return None
|
||||
|
||||
|
||||
def record_read(path: str | Path, offset: int = 1, limit: int | None = None) -> None:
|
||||
"""Record that a file was read (called after successful read)."""
|
||||
p = str(Path(path).resolve())
|
||||
try:
|
||||
mtime = os.path.getmtime(p)
|
||||
except OSError:
|
||||
return
|
||||
_state[p] = ReadState(
|
||||
mtime=mtime,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
content_hash=_hash_file(p),
|
||||
can_dedup=True,
|
||||
)
|
||||
|
||||
|
||||
def record_write(path: str | Path) -> None:
|
||||
"""Record that a file was written (updates mtime in state)."""
|
||||
p = str(Path(path).resolve())
|
||||
try:
|
||||
mtime = os.path.getmtime(p)
|
||||
except OSError:
|
||||
_state.pop(p, None)
|
||||
return
|
||||
_state[p] = ReadState(
|
||||
mtime=mtime,
|
||||
offset=1,
|
||||
limit=None,
|
||||
content_hash=_hash_file(p),
|
||||
can_dedup=False,
|
||||
)
|
||||
|
||||
|
||||
def check_read(path: str | Path) -> str | None:
|
||||
"""Check if a file has been read and is fresh.
|
||||
|
||||
Returns None if OK, or a warning string.
|
||||
When mtime changed but file content is identical (e.g. touch, editor save),
|
||||
the check passes to avoid false-positive staleness warnings.
|
||||
"""
|
||||
p = str(Path(path).resolve())
|
||||
entry = _state.get(p)
|
||||
if entry is None:
|
||||
return "Warning: file has not been read yet. Read it first to verify content before editing."
|
||||
try:
|
||||
current_mtime = os.path.getmtime(p)
|
||||
except OSError:
|
||||
return None
|
||||
if current_mtime != entry.mtime:
|
||||
if entry.content_hash and _hash_file(p) == entry.content_hash:
|
||||
entry.mtime = current_mtime
|
||||
return None
|
||||
return "Warning: file has been modified since last read. Re-read to verify content before editing."
|
||||
return None
|
||||
|
||||
|
||||
def is_unchanged(path: str | Path, offset: int = 1, limit: int | None = None) -> bool:
|
||||
"""Return True if file was previously read with same params and mtime is unchanged."""
|
||||
p = str(Path(path).resolve())
|
||||
entry = _state.get(p)
|
||||
if entry is None:
|
||||
return False
|
||||
if not entry.can_dedup:
|
||||
return False
|
||||
if entry.offset != offset or entry.limit != limit:
|
||||
return False
|
||||
try:
|
||||
current_mtime = os.path.getmtime(p)
|
||||
except OSError:
|
||||
return False
|
||||
return current_mtime == entry.mtime
|
||||
|
||||
|
||||
def clear() -> None:
|
||||
"""Clear all tracked state (useful for testing)."""
|
||||
_state.clear()
|
||||
@ -2,11 +2,13 @@
|
||||
|
||||
import difflib
|
||||
import mimetypes
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||
from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema
|
||||
from nanobot.agent.tools import file_state
|
||||
from nanobot.utils.helpers import build_image_content_blocks, detect_image_mime
|
||||
from nanobot.config.paths import get_media_dir
|
||||
|
||||
@ -60,6 +62,36 @@ class _FsTool(Tool):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_BLOCKED_DEVICE_PATHS = frozenset({
|
||||
"/dev/zero", "/dev/random", "/dev/urandom", "/dev/full",
|
||||
"/dev/stdin", "/dev/stdout", "/dev/stderr",
|
||||
"/dev/tty", "/dev/console",
|
||||
"/dev/fd/0", "/dev/fd/1", "/dev/fd/2",
|
||||
})
|
||||
|
||||
|
||||
def _is_blocked_device(path: str | Path) -> bool:
|
||||
"""Check if path is a blocked device that could hang or produce infinite output."""
|
||||
import re
|
||||
raw = str(path)
|
||||
if raw in _BLOCKED_DEVICE_PATHS:
|
||||
return True
|
||||
if re.match(r"/proc/\d+/fd/[012]$", raw) or re.match(r"/proc/self/fd/[012]$", raw):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _parse_page_range(pages: str, total: int) -> tuple[int, int]:
|
||||
"""Parse a page range like '2-5' into 0-based (start, end) inclusive."""
|
||||
parts = pages.strip().split("-")
|
||||
if len(parts) == 1:
|
||||
p = int(parts[0])
|
||||
return max(0, p - 1), min(p - 1, total - 1)
|
||||
start = int(parts[0])
|
||||
end = int(parts[1])
|
||||
return max(0, start - 1), min(end - 1, total - 1)
|
||||
|
||||
|
||||
@tool_parameters(
|
||||
tool_parameters_schema(
|
||||
path=StringSchema("The file path to read"),
|
||||
@ -73,6 +105,7 @@ class _FsTool(Tool):
|
||||
description="Maximum number of lines to read (default 2000)",
|
||||
minimum=1,
|
||||
),
|
||||
pages=StringSchema("Page range for PDF files, e.g. '1-5' (default: all, max 20 pages)"),
|
||||
required=["path"],
|
||||
)
|
||||
)
|
||||
@ -81,6 +114,7 @@ class ReadFileTool(_FsTool):
|
||||
|
||||
_MAX_CHARS = 128_000
|
||||
_DEFAULT_LIMIT = 2000
|
||||
_MAX_PDF_PAGES = 20
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@ -89,9 +123,10 @@ class ReadFileTool(_FsTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Read a text file. Output format: LINE_NUM|CONTENT. "
|
||||
"Read a file (text or image). Text output format: LINE_NUM|CONTENT. "
|
||||
"Images return visual content for analysis. "
|
||||
"Use offset and limit for large files. "
|
||||
"Cannot read binary files or images. "
|
||||
"Cannot read non-image binary files. "
|
||||
"Reads exceeding ~128K chars are truncated."
|
||||
)
|
||||
|
||||
@ -99,16 +134,27 @@ class ReadFileTool(_FsTool):
|
||||
def read_only(self) -> bool:
|
||||
return True
|
||||
|
||||
async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, **kwargs: Any) -> Any:
|
||||
async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, pages: str | None = None, **kwargs: Any) -> Any:
|
||||
try:
|
||||
if not path:
|
||||
return "Error reading file: Unknown path"
|
||||
|
||||
# Device path blacklist
|
||||
if _is_blocked_device(path):
|
||||
return f"Error: Reading {path} is blocked (device path that could hang or produce infinite output)."
|
||||
|
||||
fp = self._resolve(path)
|
||||
if _is_blocked_device(fp):
|
||||
return f"Error: Reading {fp} is blocked (device path that could hang or produce infinite output)."
|
||||
if not fp.exists():
|
||||
return f"Error: File not found: {path}"
|
||||
if not fp.is_file():
|
||||
return f"Error: Not a file: {path}"
|
||||
|
||||
# PDF support
|
||||
if fp.suffix.lower() == ".pdf":
|
||||
return self._read_pdf(fp, pages)
|
||||
|
||||
raw = fp.read_bytes()
|
||||
if not raw:
|
||||
return f"(Empty file: {path})"
|
||||
@ -117,6 +163,10 @@ class ReadFileTool(_FsTool):
|
||||
if mime and mime.startswith("image/"):
|
||||
return build_image_content_blocks(raw, mime, str(fp), f"(Image file: {path})")
|
||||
|
||||
# Read dedup: same path + offset + limit + unchanged mtime → stub
|
||||
if file_state.is_unchanged(fp, offset=offset, limit=limit):
|
||||
return f"[File unchanged since last read: {path}]"
|
||||
|
||||
try:
|
||||
text_content = raw.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
@ -149,12 +199,59 @@ class ReadFileTool(_FsTool):
|
||||
result += f"\n\n(Showing lines {offset}-{end} of {total}. Use offset={end + 1} to continue.)"
|
||||
else:
|
||||
result += f"\n\n(End of file — {total} lines total)"
|
||||
file_state.record_read(fp, offset=offset, limit=limit)
|
||||
return result
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error reading file: {e}"
|
||||
|
||||
def _read_pdf(self, fp: Path, pages: str | None) -> str:
|
||||
try:
|
||||
import fitz # pymupdf
|
||||
except ImportError:
|
||||
return "Error: PDF reading requires pymupdf. Install with: pip install pymupdf"
|
||||
|
||||
try:
|
||||
doc = fitz.open(str(fp))
|
||||
except Exception as e:
|
||||
return f"Error reading PDF: {e}"
|
||||
|
||||
total_pages = len(doc)
|
||||
if pages:
|
||||
try:
|
||||
start, end = _parse_page_range(pages, total_pages)
|
||||
except (ValueError, IndexError):
|
||||
doc.close()
|
||||
return f"Error: Invalid page range '{pages}'. Use format like '1-5'."
|
||||
if start > end or start >= total_pages:
|
||||
doc.close()
|
||||
return f"Error: Page range '{pages}' is out of bounds (document has {total_pages} pages)."
|
||||
else:
|
||||
start = 0
|
||||
end = min(total_pages - 1, self._MAX_PDF_PAGES - 1)
|
||||
|
||||
if end - start + 1 > self._MAX_PDF_PAGES:
|
||||
end = start + self._MAX_PDF_PAGES - 1
|
||||
|
||||
parts: list[str] = []
|
||||
for i in range(start, end + 1):
|
||||
page = doc[i]
|
||||
text = page.get_text().strip()
|
||||
if text:
|
||||
parts.append(f"--- Page {i + 1} ---\n{text}")
|
||||
doc.close()
|
||||
|
||||
if not parts:
|
||||
return f"(PDF has no extractable text: {fp})"
|
||||
|
||||
result = "\n\n".join(parts)
|
||||
if end < total_pages - 1:
|
||||
result += f"\n\n(Showing pages {start + 1}-{end + 1} of {total_pages}. Use pages='{end + 2}-{min(end + 1 + self._MAX_PDF_PAGES, total_pages)}' to continue.)"
|
||||
if len(result) > self._MAX_CHARS:
|
||||
result = result[:self._MAX_CHARS] + "\n\n(PDF text truncated at ~128K chars)"
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# write_file
|
||||
@ -192,6 +289,7 @@ class WriteFileTool(_FsTool):
|
||||
fp = self._resolve(path)
|
||||
fp.parent.mkdir(parents=True, exist_ok=True)
|
||||
fp.write_text(content, encoding="utf-8")
|
||||
file_state.record_write(fp)
|
||||
return f"Successfully wrote {len(content)} characters to {fp}"
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
@ -203,30 +301,269 @@ class WriteFileTool(_FsTool):
|
||||
# edit_file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_QUOTE_TABLE = str.maketrans({
|
||||
"\u2018": "'", "\u2019": "'", # curly single → straight
|
||||
"\u201c": '"', "\u201d": '"', # curly double → straight
|
||||
"'": "'", '"': '"', # identity (kept for completeness)
|
||||
})
|
||||
|
||||
|
||||
def _normalize_quotes(s: str) -> str:
|
||||
return s.translate(_QUOTE_TABLE)
|
||||
|
||||
|
||||
def _curly_double_quotes(text: str) -> str:
|
||||
parts: list[str] = []
|
||||
opening = True
|
||||
for ch in text:
|
||||
if ch == '"':
|
||||
parts.append("\u201c" if opening else "\u201d")
|
||||
opening = not opening
|
||||
else:
|
||||
parts.append(ch)
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def _curly_single_quotes(text: str) -> str:
|
||||
parts: list[str] = []
|
||||
opening = True
|
||||
for i, ch in enumerate(text):
|
||||
if ch != "'":
|
||||
parts.append(ch)
|
||||
continue
|
||||
prev_ch = text[i - 1] if i > 0 else ""
|
||||
next_ch = text[i + 1] if i + 1 < len(text) else ""
|
||||
if prev_ch.isalnum() and next_ch.isalnum():
|
||||
parts.append("\u2019")
|
||||
continue
|
||||
parts.append("\u2018" if opening else "\u2019")
|
||||
opening = not opening
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def _preserve_quote_style(old_text: str, actual_text: str, new_text: str) -> str:
|
||||
"""Preserve curly quote style when a quote-normalized fallback matched."""
|
||||
if _normalize_quotes(old_text.strip()) != _normalize_quotes(actual_text.strip()) or old_text == actual_text:
|
||||
return new_text
|
||||
|
||||
styled = new_text
|
||||
if any(ch in actual_text for ch in ("\u201c", "\u201d")) and '"' in styled:
|
||||
styled = _curly_double_quotes(styled)
|
||||
if any(ch in actual_text for ch in ("\u2018", "\u2019")) and "'" in styled:
|
||||
styled = _curly_single_quotes(styled)
|
||||
return styled
|
||||
|
||||
|
||||
def _leading_ws(line: str) -> str:
|
||||
return line[: len(line) - len(line.lstrip(" \t"))]
|
||||
|
||||
|
||||
def _reindent_like_match(old_text: str, actual_text: str, new_text: str) -> str:
|
||||
"""Preserve the outer indentation from the actual matched block."""
|
||||
old_lines = old_text.split("\n")
|
||||
actual_lines = actual_text.split("\n")
|
||||
if len(old_lines) != len(actual_lines):
|
||||
return new_text
|
||||
|
||||
comparable = [
|
||||
(old_line, actual_line)
|
||||
for old_line, actual_line in zip(old_lines, actual_lines)
|
||||
if old_line.strip() and actual_line.strip()
|
||||
]
|
||||
if not comparable or any(
|
||||
_normalize_quotes(old_line.strip()) != _normalize_quotes(actual_line.strip())
|
||||
for old_line, actual_line in comparable
|
||||
):
|
||||
return new_text
|
||||
|
||||
old_ws = _leading_ws(comparable[0][0])
|
||||
actual_ws = _leading_ws(comparable[0][1])
|
||||
if actual_ws == old_ws:
|
||||
return new_text
|
||||
|
||||
if old_ws:
|
||||
if not actual_ws.startswith(old_ws):
|
||||
return new_text
|
||||
delta = actual_ws[len(old_ws):]
|
||||
else:
|
||||
delta = actual_ws
|
||||
|
||||
if not delta:
|
||||
return new_text
|
||||
|
||||
return "\n".join((delta + line) if line else line for line in new_text.split("\n"))
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _MatchSpan:
|
||||
start: int
|
||||
end: int
|
||||
text: str
|
||||
line: int
|
||||
|
||||
|
||||
def _find_exact_matches(content: str, old_text: str) -> list[_MatchSpan]:
|
||||
matches: list[_MatchSpan] = []
|
||||
start = 0
|
||||
while True:
|
||||
idx = content.find(old_text, start)
|
||||
if idx == -1:
|
||||
break
|
||||
matches.append(
|
||||
_MatchSpan(
|
||||
start=idx,
|
||||
end=idx + len(old_text),
|
||||
text=content[idx : idx + len(old_text)],
|
||||
line=content.count("\n", 0, idx) + 1,
|
||||
)
|
||||
)
|
||||
start = idx + max(1, len(old_text))
|
||||
return matches
|
||||
|
||||
|
||||
def _find_trim_matches(content: str, old_text: str, *, normalize_quotes: bool = False) -> list[_MatchSpan]:
|
||||
old_lines = old_text.splitlines()
|
||||
if not old_lines:
|
||||
return []
|
||||
|
||||
content_lines = content.splitlines()
|
||||
content_lines_keepends = content.splitlines(keepends=True)
|
||||
if len(content_lines) < len(old_lines):
|
||||
return []
|
||||
|
||||
offsets: list[int] = []
|
||||
pos = 0
|
||||
for line in content_lines_keepends:
|
||||
offsets.append(pos)
|
||||
pos += len(line)
|
||||
offsets.append(pos)
|
||||
|
||||
if normalize_quotes:
|
||||
stripped_old = [_normalize_quotes(line.strip()) for line in old_lines]
|
||||
else:
|
||||
stripped_old = [line.strip() for line in old_lines]
|
||||
|
||||
matches: list[_MatchSpan] = []
|
||||
window_size = len(stripped_old)
|
||||
for i in range(len(content_lines) - window_size + 1):
|
||||
window = content_lines[i : i + window_size]
|
||||
if normalize_quotes:
|
||||
comparable = [_normalize_quotes(line.strip()) for line in window]
|
||||
else:
|
||||
comparable = [line.strip() for line in window]
|
||||
if comparable != stripped_old:
|
||||
continue
|
||||
|
||||
start = offsets[i]
|
||||
end = offsets[i + window_size]
|
||||
if content_lines_keepends[i + window_size - 1].endswith("\n"):
|
||||
end -= 1
|
||||
matches.append(
|
||||
_MatchSpan(
|
||||
start=start,
|
||||
end=end,
|
||||
text=content[start:end],
|
||||
line=i + 1,
|
||||
)
|
||||
)
|
||||
return matches
|
||||
|
||||
|
||||
def _find_quote_matches(content: str, old_text: str) -> list[_MatchSpan]:
|
||||
norm_content = _normalize_quotes(content)
|
||||
norm_old = _normalize_quotes(old_text)
|
||||
matches: list[_MatchSpan] = []
|
||||
start = 0
|
||||
while True:
|
||||
idx = norm_content.find(norm_old, start)
|
||||
if idx == -1:
|
||||
break
|
||||
matches.append(
|
||||
_MatchSpan(
|
||||
start=idx,
|
||||
end=idx + len(old_text),
|
||||
text=content[idx : idx + len(old_text)],
|
||||
line=content.count("\n", 0, idx) + 1,
|
||||
)
|
||||
)
|
||||
start = idx + max(1, len(norm_old))
|
||||
return matches
|
||||
|
||||
|
||||
def _find_matches(content: str, old_text: str) -> list[_MatchSpan]:
|
||||
"""Locate all matches using progressively looser strategies."""
|
||||
for matcher in (
|
||||
lambda: _find_exact_matches(content, old_text),
|
||||
lambda: _find_trim_matches(content, old_text),
|
||||
lambda: _find_trim_matches(content, old_text, normalize_quotes=True),
|
||||
lambda: _find_quote_matches(content, old_text),
|
||||
):
|
||||
matches = matcher()
|
||||
if matches:
|
||||
return matches
|
||||
return []
|
||||
|
||||
|
||||
def _find_match_line_numbers(content: str, old_text: str) -> list[int]:
|
||||
"""Return 1-based starting line numbers for the current matching strategies."""
|
||||
return [match.line for match in _find_matches(content, old_text)]
|
||||
|
||||
|
||||
def _collapse_internal_whitespace(text: str) -> str:
|
||||
return "\n".join(" ".join(line.split()) for line in text.splitlines())
|
||||
|
||||
|
||||
def _diagnose_near_match(old_text: str, actual_text: str) -> list[str]:
|
||||
"""Return actionable hints describing why text was close but not exact."""
|
||||
hints: list[str] = []
|
||||
|
||||
if old_text.lower() == actual_text.lower() and old_text != actual_text:
|
||||
hints.append("letter case differs")
|
||||
if _collapse_internal_whitespace(old_text) == _collapse_internal_whitespace(actual_text) and old_text != actual_text:
|
||||
hints.append("whitespace differs")
|
||||
if old_text.rstrip("\n") == actual_text.rstrip("\n") and old_text != actual_text:
|
||||
hints.append("trailing newline differs")
|
||||
if _normalize_quotes(old_text) == _normalize_quotes(actual_text) and old_text != actual_text:
|
||||
hints.append("quote style differs")
|
||||
|
||||
return hints
|
||||
|
||||
|
||||
def _best_window(old_text: str, content: str) -> tuple[float, int, list[str], list[str]]:
|
||||
"""Find the closest line-window match and return ratio/start/snippet/hints."""
|
||||
lines = content.splitlines(keepends=True)
|
||||
old_lines = old_text.splitlines(keepends=True)
|
||||
window = max(1, len(old_lines))
|
||||
|
||||
best_ratio, best_start = -1.0, 0
|
||||
best_window_lines: list[str] = []
|
||||
|
||||
for i in range(max(1, len(lines) - window + 1)):
|
||||
current = lines[i : i + window]
|
||||
ratio = difflib.SequenceMatcher(None, old_lines, current).ratio()
|
||||
if ratio > best_ratio:
|
||||
best_ratio, best_start = ratio, i
|
||||
best_window_lines = current
|
||||
|
||||
actual_text = "".join(best_window_lines).replace("\r\n", "\n").rstrip("\n")
|
||||
hints = _diagnose_near_match(old_text.replace("\r\n", "\n").rstrip("\n"), actual_text)
|
||||
return best_ratio, best_start, best_window_lines, hints
|
||||
|
||||
|
||||
def _find_match(content: str, old_text: str) -> tuple[str | None, int]:
|
||||
"""Locate old_text in content: exact first, then line-trimmed sliding window.
|
||||
"""Locate old_text in content with a multi-level fallback chain:
|
||||
|
||||
1. Exact substring match
|
||||
2. Line-trimmed sliding window (handles indentation differences)
|
||||
3. Smart quote normalization (curly ↔ straight quotes)
|
||||
|
||||
Both inputs should use LF line endings (caller normalises CRLF).
|
||||
Returns (matched_fragment, count) or (None, 0).
|
||||
"""
|
||||
if old_text in content:
|
||||
return old_text, content.count(old_text)
|
||||
|
||||
old_lines = old_text.splitlines()
|
||||
if not old_lines:
|
||||
matches = _find_matches(content, old_text)
|
||||
if not matches:
|
||||
return None, 0
|
||||
stripped_old = [l.strip() for l in old_lines]
|
||||
content_lines = content.splitlines()
|
||||
|
||||
candidates = []
|
||||
for i in range(len(content_lines) - len(stripped_old) + 1):
|
||||
window = content_lines[i : i + len(stripped_old)]
|
||||
if [l.strip() for l in window] == stripped_old:
|
||||
candidates.append("\n".join(window))
|
||||
|
||||
if candidates:
|
||||
return candidates[0], len(candidates)
|
||||
return None, 0
|
||||
return matches[0].text, len(matches)
|
||||
|
||||
|
||||
@tool_parameters(
|
||||
@ -241,6 +578,9 @@ def _find_match(content: str, old_text: str) -> tuple[str | None, int]:
|
||||
class EditFileTool(_FsTool):
|
||||
"""Edit a file by replacing text with fallback matching."""
|
||||
|
||||
_MAX_EDIT_FILE_SIZE = 1024 * 1024 * 1024 # 1 GiB
|
||||
_MARKDOWN_EXTS = frozenset({".md", ".mdx", ".markdown"})
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "edit_file"
|
||||
@ -249,11 +589,16 @@ class EditFileTool(_FsTool):
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit a file by replacing old_text with new_text. "
|
||||
"Tolerates minor whitespace/indentation differences. "
|
||||
"Tolerates minor whitespace/indentation differences and curly/straight quote mismatches. "
|
||||
"If old_text matches multiple times, you must provide more context "
|
||||
"or set replace_all=true. Shows a diff of the closest match on failure."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _strip_trailing_ws(text: str) -> str:
|
||||
"""Strip trailing whitespace from each line."""
|
||||
return "\n".join(line.rstrip() for line in text.split("\n"))
|
||||
|
||||
async def execute(
|
||||
self, path: str | None = None, old_text: str | None = None,
|
||||
new_text: str | None = None,
|
||||
@ -267,55 +612,133 @@ class EditFileTool(_FsTool):
|
||||
if new_text is None:
|
||||
raise ValueError("Unknown new_text")
|
||||
|
||||
# .ipynb detection
|
||||
if path.endswith(".ipynb"):
|
||||
return "Error: This is a Jupyter notebook. Use the notebook_edit tool instead of edit_file."
|
||||
|
||||
fp = self._resolve(path)
|
||||
|
||||
# Create-file semantics: old_text='' + file doesn't exist → create
|
||||
if not fp.exists():
|
||||
return f"Error: File not found: {path}"
|
||||
if old_text == "":
|
||||
fp.parent.mkdir(parents=True, exist_ok=True)
|
||||
fp.write_text(new_text, encoding="utf-8")
|
||||
file_state.record_write(fp)
|
||||
return f"Successfully created {fp}"
|
||||
return self._file_not_found_msg(path, fp)
|
||||
|
||||
# File size protection
|
||||
try:
|
||||
fsize = fp.stat().st_size
|
||||
except OSError:
|
||||
fsize = 0
|
||||
if fsize > self._MAX_EDIT_FILE_SIZE:
|
||||
return f"Error: File too large to edit ({fsize / (1024**3):.1f} GiB). Maximum is 1 GiB."
|
||||
|
||||
# Create-file: old_text='' but file exists and not empty → reject
|
||||
if old_text == "":
|
||||
raw = fp.read_bytes()
|
||||
content = raw.decode("utf-8")
|
||||
if content.strip():
|
||||
return f"Error: Cannot create file — {path} already exists and is not empty."
|
||||
fp.write_text(new_text, encoding="utf-8")
|
||||
file_state.record_write(fp)
|
||||
return f"Successfully edited {fp}"
|
||||
|
||||
# Read-before-edit check
|
||||
warning = file_state.check_read(fp)
|
||||
|
||||
raw = fp.read_bytes()
|
||||
uses_crlf = b"\r\n" in raw
|
||||
content = raw.decode("utf-8").replace("\r\n", "\n")
|
||||
match, count = _find_match(content, old_text.replace("\r\n", "\n"))
|
||||
norm_old = old_text.replace("\r\n", "\n")
|
||||
matches = _find_matches(content, norm_old)
|
||||
|
||||
if match is None:
|
||||
if not matches:
|
||||
return self._not_found_msg(old_text, content, path)
|
||||
count = len(matches)
|
||||
if count > 1 and not replace_all:
|
||||
line_numbers = [match.line for match in matches]
|
||||
preview = ", ".join(f"line {n}" for n in line_numbers[:3])
|
||||
if len(line_numbers) > 3:
|
||||
preview += ", ..."
|
||||
location_hint = f" at {preview}" if preview else ""
|
||||
return (
|
||||
f"Warning: old_text appears {count} times. "
|
||||
f"Warning: old_text appears {count} times{location_hint}. "
|
||||
"Provide more context to make it unique, or set replace_all=true."
|
||||
)
|
||||
|
||||
norm_new = new_text.replace("\r\n", "\n")
|
||||
new_content = content.replace(match, norm_new) if replace_all else content.replace(match, norm_new, 1)
|
||||
|
||||
# Trailing whitespace stripping (skip markdown to preserve double-space line breaks)
|
||||
if fp.suffix.lower() not in self._MARKDOWN_EXTS:
|
||||
norm_new = self._strip_trailing_ws(norm_new)
|
||||
|
||||
selected = matches if replace_all else matches[:1]
|
||||
new_content = content
|
||||
for match in reversed(selected):
|
||||
replacement = _preserve_quote_style(norm_old, match.text, norm_new)
|
||||
replacement = _reindent_like_match(norm_old, match.text, replacement)
|
||||
|
||||
# Delete-line cleanup: when deleting text (new_text=''), consume trailing
|
||||
# newline to avoid leaving a blank line
|
||||
end = match.end
|
||||
if replacement == "" and not match.text.endswith("\n") and content[end:end + 1] == "\n":
|
||||
end += 1
|
||||
|
||||
new_content = new_content[: match.start] + replacement + new_content[end:]
|
||||
if uses_crlf:
|
||||
new_content = new_content.replace("\n", "\r\n")
|
||||
|
||||
fp.write_bytes(new_content.encode("utf-8"))
|
||||
return f"Successfully edited {fp}"
|
||||
file_state.record_write(fp)
|
||||
msg = f"Successfully edited {fp}"
|
||||
if warning:
|
||||
msg = f"{warning}\n{msg}"
|
||||
return msg
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error editing file: {e}"
|
||||
|
||||
def _file_not_found_msg(self, path: str, fp: Path) -> str:
|
||||
"""Build an error message with 'Did you mean ...?' suggestions."""
|
||||
parent = fp.parent
|
||||
suggestions: list[str] = []
|
||||
if parent.is_dir():
|
||||
siblings = [f.name for f in parent.iterdir() if f.is_file()]
|
||||
close = difflib.get_close_matches(fp.name, siblings, n=3, cutoff=0.6)
|
||||
suggestions = [str(parent / c) for c in close]
|
||||
parts = [f"Error: File not found: {path}"]
|
||||
if suggestions:
|
||||
parts.append("Did you mean: " + ", ".join(suggestions) + "?")
|
||||
return "\n".join(parts)
|
||||
|
||||
@staticmethod
|
||||
def _not_found_msg(old_text: str, content: str, path: str) -> str:
|
||||
lines = content.splitlines(keepends=True)
|
||||
old_lines = old_text.splitlines(keepends=True)
|
||||
window = len(old_lines)
|
||||
|
||||
best_ratio, best_start = 0.0, 0
|
||||
for i in range(max(1, len(lines) - window + 1)):
|
||||
ratio = difflib.SequenceMatcher(None, old_lines, lines[i : i + window]).ratio()
|
||||
if ratio > best_ratio:
|
||||
best_ratio, best_start = ratio, i
|
||||
|
||||
best_ratio, best_start, best_window_lines, hints = _best_window(old_text, content)
|
||||
if best_ratio > 0.5:
|
||||
diff = "\n".join(difflib.unified_diff(
|
||||
old_lines, lines[best_start : best_start + window],
|
||||
old_text.splitlines(keepends=True),
|
||||
best_window_lines,
|
||||
fromfile="old_text (provided)",
|
||||
tofile=f"{path} (actual, line {best_start + 1})",
|
||||
lineterm="",
|
||||
))
|
||||
return f"Error: old_text not found in {path}.\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
|
||||
hint_text = ""
|
||||
if hints:
|
||||
hint_text = "\nPossible cause: " + ", ".join(hints) + "."
|
||||
return (
|
||||
f"Error: old_text not found in {path}."
|
||||
f"{hint_text}\nBest match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
|
||||
)
|
||||
|
||||
if hints:
|
||||
return (
|
||||
f"Error: old_text not found in {path}. "
|
||||
f"Possible cause: {', '.join(hints)}. "
|
||||
"Copy the exact text from read_file and try again."
|
||||
)
|
||||
return f"Error: old_text not found in {path}. No similar text found. Verify the file content."
|
||||
|
||||
|
||||
|
||||
162
nanobot/agent/tools/notebook.py
Normal file
162
nanobot/agent/tools/notebook.py
Normal file
@ -0,0 +1,162 @@
|
||||
"""NotebookEditTool — edit Jupyter .ipynb notebooks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||
from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema
|
||||
from nanobot.agent.tools.filesystem import _FsTool
|
||||
|
||||
|
||||
def _new_cell(source: str, cell_type: str = "code", generate_id: bool = False) -> dict:
|
||||
cell: dict[str, Any] = {
|
||||
"cell_type": cell_type,
|
||||
"source": source,
|
||||
"metadata": {},
|
||||
}
|
||||
if cell_type == "code":
|
||||
cell["outputs"] = []
|
||||
cell["execution_count"] = None
|
||||
if generate_id:
|
||||
cell["id"] = uuid.uuid4().hex[:8]
|
||||
return cell
|
||||
|
||||
|
||||
def _make_empty_notebook() -> dict:
|
||||
return {
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5,
|
||||
"metadata": {
|
||||
"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
|
||||
"language_info": {"name": "python"},
|
||||
},
|
||||
"cells": [],
|
||||
}
|
||||
|
||||
|
||||
@tool_parameters(
|
||||
tool_parameters_schema(
|
||||
path=StringSchema("Path to the .ipynb notebook file"),
|
||||
cell_index=IntegerSchema(0, description="0-based index of the cell to edit", minimum=0),
|
||||
new_source=StringSchema("New source content for the cell"),
|
||||
cell_type=StringSchema(
|
||||
"Cell type: 'code' or 'markdown' (default: code)",
|
||||
enum=["code", "markdown"],
|
||||
),
|
||||
edit_mode=StringSchema(
|
||||
"Mode: 'replace' (default), 'insert' (after target), or 'delete'",
|
||||
enum=["replace", "insert", "delete"],
|
||||
),
|
||||
required=["path", "cell_index"],
|
||||
)
|
||||
)
|
||||
class NotebookEditTool(_FsTool):
|
||||
"""Edit Jupyter notebook cells: replace, insert, or delete."""
|
||||
|
||||
_VALID_CELL_TYPES = frozenset({"code", "markdown"})
|
||||
_VALID_EDIT_MODES = frozenset({"replace", "insert", "delete"})
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "notebook_edit"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit a Jupyter notebook (.ipynb) cell. "
|
||||
"Modes: replace (default) replaces cell content, "
|
||||
"insert adds a new cell after the target index, "
|
||||
"delete removes the cell at the index. "
|
||||
"cell_index is 0-based."
|
||||
)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
path: str | None = None,
|
||||
cell_index: int = 0,
|
||||
new_source: str = "",
|
||||
cell_type: str = "code",
|
||||
edit_mode: str = "replace",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
try:
|
||||
if not path:
|
||||
return "Error: path is required"
|
||||
|
||||
if not path.endswith(".ipynb"):
|
||||
return "Error: notebook_edit only works on .ipynb files. Use edit_file for other files."
|
||||
|
||||
if edit_mode not in self._VALID_EDIT_MODES:
|
||||
return (
|
||||
f"Error: Invalid edit_mode '{edit_mode}'. "
|
||||
"Use one of: replace, insert, delete."
|
||||
)
|
||||
|
||||
if cell_type not in self._VALID_CELL_TYPES:
|
||||
return (
|
||||
f"Error: Invalid cell_type '{cell_type}'. "
|
||||
"Use one of: code, markdown."
|
||||
)
|
||||
|
||||
fp = self._resolve(path)
|
||||
|
||||
# Create new notebook if file doesn't exist and mode is insert
|
||||
if not fp.exists():
|
||||
if edit_mode != "insert":
|
||||
return f"Error: File not found: {path}"
|
||||
nb = _make_empty_notebook()
|
||||
cell = _new_cell(new_source, cell_type, generate_id=True)
|
||||
nb["cells"].append(cell)
|
||||
fp.parent.mkdir(parents=True, exist_ok=True)
|
||||
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
|
||||
return f"Successfully created {fp} with 1 cell"
|
||||
|
||||
try:
|
||||
nb = json.loads(fp.read_text(encoding="utf-8"))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||||
return f"Error: Failed to parse notebook: {e}"
|
||||
|
||||
cells = nb.get("cells", [])
|
||||
nbformat_minor = nb.get("nbformat_minor", 0)
|
||||
generate_id = nb.get("nbformat", 0) >= 4 and nbformat_minor >= 5
|
||||
|
||||
if edit_mode == "delete":
|
||||
if cell_index < 0 or cell_index >= len(cells):
|
||||
return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)"
|
||||
cells.pop(cell_index)
|
||||
nb["cells"] = cells
|
||||
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
|
||||
return f"Successfully deleted cell {cell_index} from {fp}"
|
||||
|
||||
if edit_mode == "insert":
|
||||
insert_at = min(cell_index + 1, len(cells))
|
||||
cell = _new_cell(new_source, cell_type, generate_id=generate_id)
|
||||
cells.insert(insert_at, cell)
|
||||
nb["cells"] = cells
|
||||
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
|
||||
return f"Successfully inserted cell at index {insert_at} in {fp}"
|
||||
|
||||
# Default: replace
|
||||
if cell_index < 0 or cell_index >= len(cells):
|
||||
return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)"
|
||||
cells[cell_index]["source"] = new_source
|
||||
if cell_type and cells[cell_index].get("cell_type") != cell_type:
|
||||
cells[cell_index]["cell_type"] = cell_type
|
||||
if cell_type == "code":
|
||||
cells[cell_index].setdefault("outputs", [])
|
||||
cells[cell_index].setdefault("execution_count", None)
|
||||
elif "outputs" in cells[cell_index]:
|
||||
del cells[cell_index]["outputs"]
|
||||
cells[cell_index].pop("execution_count", None)
|
||||
nb["cells"] = cells
|
||||
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
|
||||
return f"Successfully edited cell {cell_index} in {fp}"
|
||||
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error editing notebook: {e}"
|
||||
@ -76,12 +76,16 @@ discord = [
|
||||
langsmith = [
|
||||
"langsmith>=0.1.0",
|
||||
]
|
||||
pdf = [
|
||||
"pymupdf>=1.25.0",
|
||||
]
|
||||
dev = [
|
||||
"pytest>=9.0.0,<10.0.0",
|
||||
"pytest-asyncio>=1.3.0,<2.0.0",
|
||||
"aiohttp>=3.9.0,<4.0.0",
|
||||
"pytest-cov>=6.0.0,<7.0.0",
|
||||
"ruff>=0.1.0",
|
||||
"pymupdf>=1.25.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
423
tests/tools/test_edit_advanced.py
Normal file
423
tests/tools/test_edit_advanced.py
Normal file
@ -0,0 +1,423 @@
|
||||
"""Tests for advanced EditFileTool enhancements inspired by claude-code:
|
||||
- Delete-line newline cleanup
|
||||
- Smart quote normalization (curly ↔ straight)
|
||||
- Quote style preservation in replacements
|
||||
- Indentation preservation when fallback match is trimmed
|
||||
- Trailing whitespace stripping for new_text
|
||||
- File size protection
|
||||
- Stale detection with content-equality fallback
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, _find_match
|
||||
from nanobot.agent.tools import file_state
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_file_state():
|
||||
file_state.clear()
|
||||
yield
|
||||
file_state.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Delete-line newline cleanup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeleteLineCleanup:
|
||||
"""When new_text='' and deleting a line, trailing newline should be consumed."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_line_consumes_trailing_newline(self, tool, tmp_path):
|
||||
f = tmp_path / "a.py"
|
||||
f.write_text("line1\nline2\nline3\n", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="line2", new_text="")
|
||||
assert "Successfully" in result
|
||||
content = f.read_text()
|
||||
# Should not leave a blank line where line2 was
|
||||
assert content == "line1\nline3\n"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_line_with_explicit_newline_in_old_text(self, tool, tmp_path):
|
||||
f = tmp_path / "a.py"
|
||||
f.write_text("line1\nline2\nline3\n", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="line2\n", new_text="")
|
||||
assert "Successfully" in result
|
||||
assert f.read_text() == "line1\nline3\n"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_preserves_content_when_not_trailing_newline(self, tool, tmp_path):
|
||||
"""Deleting a word mid-line should not consume extra characters."""
|
||||
f = tmp_path / "a.py"
|
||||
f.write_text("hello world here\n", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="world ", new_text="")
|
||||
assert "Successfully" in result
|
||||
assert f.read_text() == "hello here\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Smart quote normalization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSmartQuoteNormalization:
|
||||
"""_find_match should handle curly ↔ straight quote fallback."""
|
||||
|
||||
def test_curly_double_quotes_match_straight(self):
|
||||
content = 'She said \u201chello\u201d to him'
|
||||
old_text = 'She said "hello" to him'
|
||||
match, count = _find_match(content, old_text)
|
||||
assert match is not None
|
||||
assert count == 1
|
||||
# Returned match should be the ORIGINAL content with curly quotes
|
||||
assert "\u201c" in match
|
||||
|
||||
def test_curly_single_quotes_match_straight(self):
|
||||
content = "it\u2019s a test"
|
||||
old_text = "it's a test"
|
||||
match, count = _find_match(content, old_text)
|
||||
assert match is not None
|
||||
assert count == 1
|
||||
assert "\u2019" in match
|
||||
|
||||
def test_straight_matches_curly_in_old_text(self):
|
||||
content = 'x = "hello"'
|
||||
old_text = 'x = \u201chello\u201d'
|
||||
match, count = _find_match(content, old_text)
|
||||
assert match is not None
|
||||
assert count == 1
|
||||
|
||||
def test_exact_match_still_preferred_over_quote_normalization(self):
|
||||
content = 'x = "hello"'
|
||||
old_text = 'x = "hello"'
|
||||
match, count = _find_match(content, old_text)
|
||||
assert match == old_text
|
||||
assert count == 1
|
||||
|
||||
|
||||
class TestQuoteStylePreservation:
|
||||
"""When quote-normalized matching occurs, replacement should preserve actual quote style."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replacement_preserves_curly_double_quotes(self, tool, tmp_path):
|
||||
f = tmp_path / "quotes.txt"
|
||||
f.write_text('message = “hello”\n', encoding="utf-8")
|
||||
result = await tool.execute(
|
||||
path=str(f),
|
||||
old_text='message = "hello"',
|
||||
new_text='message = "goodbye"',
|
||||
)
|
||||
assert "Successfully" in result
|
||||
assert f.read_text(encoding="utf-8") == 'message = “goodbye”\n'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replacement_preserves_curly_apostrophe(self, tool, tmp_path):
|
||||
f = tmp_path / "apostrophe.txt"
|
||||
f.write_text("it’s fine\n", encoding="utf-8")
|
||||
result = await tool.execute(
|
||||
path=str(f),
|
||||
old_text="it's fine",
|
||||
new_text="it's better",
|
||||
)
|
||||
assert "Successfully" in result
|
||||
assert f.read_text(encoding="utf-8") == "it’s better\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Indentation preservation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIndentationPreservation:
|
||||
"""Replacement should keep outer indentation when trim fallback matched."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trim_fallback_preserves_outer_indentation(self, tool, tmp_path):
|
||||
f = tmp_path / "indent.py"
|
||||
f.write_text(
|
||||
"if True:\n"
|
||||
" def foo():\n"
|
||||
" pass\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
result = await tool.execute(
|
||||
path=str(f),
|
||||
old_text="def foo():\n pass",
|
||||
new_text="def bar():\n return 1",
|
||||
)
|
||||
assert "Successfully" in result
|
||||
assert f.read_text(encoding="utf-8") == (
|
||||
"if True:\n"
|
||||
" def bar():\n"
|
||||
" return 1\n"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Failure diagnostics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEditDiagnostics:
|
||||
"""Failure paths should offer actionable hints."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ambiguous_match_reports_candidate_lines(self, tool, tmp_path):
|
||||
f = tmp_path / "dup.py"
|
||||
f.write_text("aaa\nbbb\naaa\nbbb\n", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="aaa\nbbb", new_text="xxx")
|
||||
assert "appears 2 times" in result.lower()
|
||||
assert "line 1" in result.lower()
|
||||
assert "line 3" in result.lower()
|
||||
assert "replace_all=true" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found_reports_whitespace_hint(self, tool, tmp_path):
|
||||
f = tmp_path / "space.py"
|
||||
f.write_text("value = 1\n", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="value = 1", new_text="value = 2")
|
||||
assert "Error" in result
|
||||
assert "whitespace" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found_reports_case_hint(self, tool, tmp_path):
|
||||
f = tmp_path / "case.py"
|
||||
f.write_text("HelloWorld\n", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="helloworld", new_text="goodbye")
|
||||
assert "Error" in result
|
||||
assert "letter case differs" in result.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Advanced fallback replacement behavior
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAdvancedReplaceAll:
|
||||
"""replace_all should work correctly for fallback-based matches too."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replace_all_preserves_each_match_indentation(self, tool, tmp_path):
|
||||
f = tmp_path / "indent_multi.py"
|
||||
f.write_text(
|
||||
"if a:\n"
|
||||
" def foo():\n"
|
||||
" pass\n"
|
||||
"if b:\n"
|
||||
" def foo():\n"
|
||||
" pass\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
result = await tool.execute(
|
||||
path=str(f),
|
||||
old_text="def foo():\n pass",
|
||||
new_text="def bar():\n return 1",
|
||||
replace_all=True,
|
||||
)
|
||||
assert "Successfully" in result
|
||||
assert f.read_text(encoding="utf-8") == (
|
||||
"if a:\n"
|
||||
" def bar():\n"
|
||||
" return 1\n"
|
||||
"if b:\n"
|
||||
" def bar():\n"
|
||||
" return 1\n"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trim_and_quote_fallback_match_succeeds(self, tool, tmp_path):
|
||||
f = tmp_path / "quote_indent.py"
|
||||
f.write_text(" message = “hello”\n", encoding="utf-8")
|
||||
result = await tool.execute(
|
||||
path=str(f),
|
||||
old_text='message = "hello"',
|
||||
new_text='message = "goodbye"',
|
||||
)
|
||||
assert "Successfully" in result
|
||||
assert f.read_text(encoding="utf-8") == " message = “goodbye”\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Advanced fallback replacement behavior
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAdvancedReplaceAll:
|
||||
"""replace_all should work correctly for fallback-based matches too."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replace_all_preserves_each_match_indentation(self, tool, tmp_path):
|
||||
f = tmp_path / "indent_multi.py"
|
||||
f.write_text(
|
||||
"if a:\n"
|
||||
" def foo():\n"
|
||||
" pass\n"
|
||||
"if b:\n"
|
||||
" def foo():\n"
|
||||
" pass\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
result = await tool.execute(
|
||||
path=str(f),
|
||||
old_text="def foo():\n pass",
|
||||
new_text="def bar():\n return 1",
|
||||
replace_all=True,
|
||||
)
|
||||
assert "Successfully" in result
|
||||
assert f.read_text(encoding="utf-8") == (
|
||||
"if a:\n"
|
||||
" def bar():\n"
|
||||
" return 1\n"
|
||||
"if b:\n"
|
||||
" def bar():\n"
|
||||
" return 1\n"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trim_and_quote_fallback_match_succeeds(self, tool, tmp_path):
|
||||
f = tmp_path / "quote_indent.py"
|
||||
f.write_text(" message = “hello”\n", encoding="utf-8")
|
||||
result = await tool.execute(
|
||||
path=str(f),
|
||||
old_text='message = "hello"',
|
||||
new_text='message = "goodbye"',
|
||||
)
|
||||
assert "Successfully" in result
|
||||
assert f.read_text(encoding="utf-8") == " message = “goodbye”\n"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Trailing whitespace stripping on new_text
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTrailingWhitespaceStrip:
|
||||
"""new_text trailing whitespace should be stripped (except .md files)."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_strips_trailing_whitespace_from_new_text(self, tool, tmp_path):
|
||||
f = tmp_path / "a.py"
|
||||
f.write_text("x = 1\n", encoding="utf-8")
|
||||
result = await tool.execute(
|
||||
path=str(f), old_text="x = 1", new_text="x = 2 \ny = 3 ",
|
||||
)
|
||||
assert "Successfully" in result
|
||||
content = f.read_text()
|
||||
assert "x = 2\ny = 3\n" == content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserves_trailing_whitespace_in_markdown(self, tool, tmp_path):
|
||||
f = tmp_path / "doc.md"
|
||||
f.write_text("# Title\n", encoding="utf-8")
|
||||
# Markdown uses trailing double-space for line breaks
|
||||
result = await tool.execute(
|
||||
path=str(f), old_text="# Title", new_text="# Title \nSubtitle ",
|
||||
)
|
||||
assert "Successfully" in result
|
||||
content = f.read_text()
|
||||
# Trailing spaces should be preserved for markdown
|
||||
assert "Title " in content
|
||||
assert "Subtitle " in content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File size protection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFileSizeProtection:
|
||||
"""Editing extremely large files should be rejected."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_file_over_size_limit(self, tool, tmp_path):
|
||||
f = tmp_path / "huge.txt"
|
||||
f.write_text("x", encoding="utf-8")
|
||||
# Monkey-patch the file size check by creating a stat mock
|
||||
original_stat = f.stat
|
||||
|
||||
class FakeStat:
|
||||
def __init__(self, real_stat):
|
||||
self._real = real_stat
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._real, name)
|
||||
|
||||
@property
|
||||
def st_size(self):
|
||||
return 2 * 1024 * 1024 * 1024 # 2 GiB
|
||||
|
||||
import unittest.mock
|
||||
with unittest.mock.patch.object(type(f), 'stat', return_value=FakeStat(f.stat())):
|
||||
result = await tool.execute(path=str(f), old_text="x", new_text="y")
|
||||
assert "Error" in result
|
||||
assert "too large" in result.lower() or "size" in result.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Stale detection with content-equality fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStaleDetectionContentFallback:
|
||||
"""When mtime changed but file content is unchanged, edit should proceed without warning."""
|
||||
|
||||
@pytest.fixture()
|
||||
def read_tool(self, tmp_path):
|
||||
return ReadFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.fixture()
|
||||
def edit_tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mtime_bump_same_content_no_warning(self, read_tool, edit_tool, tmp_path):
|
||||
f = tmp_path / "a.py"
|
||||
f.write_text("hello world", encoding="utf-8")
|
||||
await read_tool.execute(path=str(f))
|
||||
|
||||
# Touch the file to bump mtime without changing content
|
||||
time.sleep(0.05)
|
||||
original_content = f.read_text()
|
||||
f.write_text(original_content, encoding="utf-8")
|
||||
|
||||
result = await edit_tool.execute(path=str(f), old_text="world", new_text="earth")
|
||||
assert "Successfully" in result
|
||||
# Should NOT warn about modification since content is the same
|
||||
assert "modified" not in result.lower()
|
||||
152
tests/tools/test_edit_enhancements.py
Normal file
152
tests/tools/test_edit_enhancements.py
Normal file
@ -0,0 +1,152 @@
|
||||
"""Tests for EditFileTool enhancements: read-before-edit tracking, path suggestions,
|
||||
.ipynb detection, and create-file semantics."""
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools import file_state
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_file_state():
|
||||
"""Reset global read-state between tests."""
|
||||
file_state.clear()
|
||||
yield
|
||||
file_state.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Read-before-edit tracking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEditReadTracking:
|
||||
"""edit_file should warn when file hasn't been read first."""
|
||||
|
||||
@pytest.fixture()
|
||||
def read_tool(self, tmp_path):
|
||||
return ReadFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.fixture()
|
||||
def edit_tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_edit_warns_if_file_not_read_first(self, edit_tool, tmp_path):
|
||||
f = tmp_path / "a.py"
|
||||
f.write_text("hello world", encoding="utf-8")
|
||||
result = await edit_tool.execute(path=str(f), old_text="world", new_text="earth")
|
||||
# Should still succeed but include a warning
|
||||
assert "Successfully" in result
|
||||
assert "not been read" in result.lower() or "warning" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_edit_succeeds_cleanly_after_read(self, read_tool, edit_tool, tmp_path):
|
||||
f = tmp_path / "a.py"
|
||||
f.write_text("hello world", encoding="utf-8")
|
||||
await read_tool.execute(path=str(f))
|
||||
result = await edit_tool.execute(path=str(f), old_text="world", new_text="earth")
|
||||
assert "Successfully" in result
|
||||
# No warning when file was read first
|
||||
assert "not been read" not in result.lower()
|
||||
assert f.read_text() == "hello earth"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_edit_warns_if_file_modified_since_read(self, read_tool, edit_tool, tmp_path):
|
||||
f = tmp_path / "a.py"
|
||||
f.write_text("hello world", encoding="utf-8")
|
||||
await read_tool.execute(path=str(f))
|
||||
# External modification
|
||||
f.write_text("hello universe", encoding="utf-8")
|
||||
result = await edit_tool.execute(path=str(f), old_text="universe", new_text="earth")
|
||||
assert "Successfully" in result
|
||||
assert "modified" in result.lower() or "warning" in result.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Create-file semantics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEditCreateFile:
|
||||
"""edit_file with old_text='' creates new file if not exists."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_new_file_with_empty_old_text(self, tool, tmp_path):
|
||||
f = tmp_path / "subdir" / "new.py"
|
||||
result = await tool.execute(path=str(f), old_text="", new_text="print('hi')")
|
||||
assert "created" in result.lower() or "Successfully" in result
|
||||
assert f.exists()
|
||||
assert f.read_text() == "print('hi')"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_fails_if_file_already_exists_and_not_empty(self, tool, tmp_path):
|
||||
f = tmp_path / "existing.py"
|
||||
f.write_text("existing content", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="", new_text="new content")
|
||||
assert "Error" in result or "already exists" in result.lower()
|
||||
# File should be unchanged
|
||||
assert f.read_text() == "existing content"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_succeeds_if_file_exists_but_empty(self, tool, tmp_path):
|
||||
f = tmp_path / "empty.py"
|
||||
f.write_text("", encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="", new_text="print('hi')")
|
||||
assert "Successfully" in result
|
||||
assert f.read_text() == "print('hi')"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# .ipynb detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEditIpynbDetection:
|
||||
"""edit_file should refuse .ipynb and suggest notebook_edit."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ipynb_rejected_with_suggestion(self, tool, tmp_path):
|
||||
f = tmp_path / "analysis.ipynb"
|
||||
f.write_text('{"cells": []}', encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="x", new_text="y")
|
||||
assert "notebook" in result.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Path suggestion on not-found
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEditPathSuggestion:
|
||||
"""edit_file should suggest similar paths on not-found."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_suggests_similar_filename(self, tool, tmp_path):
|
||||
f = tmp_path / "config.py"
|
||||
f.write_text("x = 1", encoding="utf-8")
|
||||
# Typo: conifg.py
|
||||
result = await tool.execute(
|
||||
path=str(tmp_path / "conifg.py"), old_text="x = 1", new_text="x = 2",
|
||||
)
|
||||
assert "Error" in result
|
||||
assert "config.py" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_shows_cwd_in_error(self, tool, tmp_path):
|
||||
result = await tool.execute(
|
||||
path=str(tmp_path / "nonexistent.py"), old_text="a", new_text="b",
|
||||
)
|
||||
assert "Error" in result
|
||||
147
tests/tools/test_notebook_tool.py
Normal file
147
tests/tools/test_notebook_tool.py
Normal file
@ -0,0 +1,147 @@
|
||||
"""Tests for NotebookEditTool — Jupyter .ipynb editing."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.notebook import NotebookEditTool
|
||||
|
||||
|
||||
def _make_notebook(cells: list[dict] | None = None, nbformat: int = 4, nbformat_minor: int = 5) -> dict:
|
||||
"""Build a minimal valid .ipynb structure."""
|
||||
return {
|
||||
"nbformat": nbformat,
|
||||
"nbformat_minor": nbformat_minor,
|
||||
"metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}},
|
||||
"cells": cells or [],
|
||||
}
|
||||
|
||||
|
||||
def _code_cell(source: str, cell_id: str | None = None) -> dict:
|
||||
cell = {"cell_type": "code", "source": source, "metadata": {}, "outputs": [], "execution_count": None}
|
||||
if cell_id:
|
||||
cell["id"] = cell_id
|
||||
return cell
|
||||
|
||||
|
||||
def _md_cell(source: str, cell_id: str | None = None) -> dict:
|
||||
cell = {"cell_type": "markdown", "source": source, "metadata": {}}
|
||||
if cell_id:
|
||||
cell["id"] = cell_id
|
||||
return cell
|
||||
|
||||
|
||||
def _write_nb(tmp_path, name: str, nb: dict) -> str:
|
||||
p = tmp_path / name
|
||||
p.write_text(json.dumps(nb), encoding="utf-8")
|
||||
return str(p)
|
||||
|
||||
|
||||
class TestNotebookEdit:
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return NotebookEditTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replace_cell_content(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("print('hello')"), _code_cell("x = 1")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
result = await tool.execute(path=path, cell_index=0, new_source="print('world')")
|
||||
assert "Successfully" in result
|
||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
||||
assert saved["cells"][0]["source"] == "print('world')"
|
||||
assert saved["cells"][1]["source"] == "x = 1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insert_cell_after_target(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("cell 0"), _code_cell("cell 1")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
result = await tool.execute(path=path, cell_index=0, new_source="inserted", edit_mode="insert")
|
||||
assert "Successfully" in result
|
||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
||||
assert len(saved["cells"]) == 3
|
||||
assert saved["cells"][0]["source"] == "cell 0"
|
||||
assert saved["cells"][1]["source"] == "inserted"
|
||||
assert saved["cells"][2]["source"] == "cell 1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_cell(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("A"), _code_cell("B"), _code_cell("C")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
result = await tool.execute(path=path, cell_index=1, edit_mode="delete")
|
||||
assert "Successfully" in result
|
||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
||||
assert len(saved["cells"]) == 2
|
||||
assert saved["cells"][0]["source"] == "A"
|
||||
assert saved["cells"][1]["source"] == "C"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_new_notebook_from_scratch(self, tool, tmp_path):
|
||||
path = str(tmp_path / "new.ipynb")
|
||||
result = await tool.execute(path=path, cell_index=0, new_source="# Hello", edit_mode="insert", cell_type="markdown")
|
||||
assert "Successfully" in result or "created" in result.lower()
|
||||
saved = json.loads((tmp_path / "new.ipynb").read_text())
|
||||
assert saved["nbformat"] == 4
|
||||
assert len(saved["cells"]) == 1
|
||||
assert saved["cells"][0]["cell_type"] == "markdown"
|
||||
assert saved["cells"][0]["source"] == "# Hello"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_cell_index_error(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("only cell")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
result = await tool.execute(path=path, cell_index=5, new_source="x")
|
||||
assert "Error" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_ipynb_rejected(self, tool, tmp_path):
|
||||
f = tmp_path / "script.py"
|
||||
f.write_text("pass")
|
||||
result = await tool.execute(path=str(f), cell_index=0, new_source="x")
|
||||
assert "Error" in result
|
||||
assert ".ipynb" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserves_metadata_and_outputs(self, tool, tmp_path):
|
||||
cell = _code_cell("old")
|
||||
cell["outputs"] = [{"output_type": "stream", "text": "hello\n"}]
|
||||
cell["execution_count"] = 42
|
||||
nb = _make_notebook([cell])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
await tool.execute(path=path, cell_index=0, new_source="new")
|
||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
||||
assert saved["metadata"]["kernelspec"]["language"] == "python"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nbformat_45_generates_cell_id(self, tool, tmp_path):
|
||||
nb = _make_notebook([], nbformat_minor=5)
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
await tool.execute(path=path, cell_index=0, new_source="x = 1", edit_mode="insert")
|
||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
||||
assert "id" in saved["cells"][0]
|
||||
assert len(saved["cells"][0]["id"]) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insert_with_cell_type_markdown(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("code")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
await tool.execute(path=path, cell_index=0, new_source="# Title", edit_mode="insert", cell_type="markdown")
|
||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
||||
assert saved["cells"][1]["cell_type"] == "markdown"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_edit_mode_rejected(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("code")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
result = await tool.execute(path=path, cell_index=0, new_source="x", edit_mode="replcae")
|
||||
assert "Error" in result
|
||||
assert "edit_mode" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_cell_type_rejected(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("code")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
result = await tool.execute(path=path, cell_index=0, new_source="x", cell_type="raw")
|
||||
assert "Error" in result
|
||||
assert "cell_type" in result
|
||||
180
tests/tools/test_read_enhancements.py
Normal file
180
tests/tools/test_read_enhancements.py
Normal file
@ -0,0 +1,180 @@
|
||||
"""Tests for ReadFileTool enhancements: description fix, read dedup, PDF support, device blacklist."""
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.filesystem import ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools import file_state
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_file_state():
|
||||
file_state.clear()
|
||||
yield
|
||||
file_state.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Description fix
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadDescriptionFix:
|
||||
|
||||
def test_description_mentions_image_support(self):
|
||||
tool = ReadFileTool()
|
||||
assert "image" in tool.description.lower()
|
||||
|
||||
def test_description_no_longer_says_cannot_read_images(self):
|
||||
tool = ReadFileTool()
|
||||
assert "cannot read binary files or images" not in tool.description.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Read deduplication
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadDedup:
|
||||
"""Same file + same offset/limit + unchanged mtime -> short stub."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return ReadFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.fixture()
|
||||
def write_tool(self, tmp_path):
|
||||
return WriteFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_second_read_returns_unchanged_stub(self, tool, tmp_path):
|
||||
f = tmp_path / "data.txt"
|
||||
f.write_text("\n".join(f"line {i}" for i in range(100)), encoding="utf-8")
|
||||
first = await tool.execute(path=str(f))
|
||||
assert "line 0" in first
|
||||
second = await tool.execute(path=str(f))
|
||||
assert "unchanged" in second.lower()
|
||||
# Stub should not contain file content
|
||||
assert "line 0" not in second
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_after_external_modification_returns_full(self, tool, tmp_path):
|
||||
f = tmp_path / "data.txt"
|
||||
f.write_text("original", encoding="utf-8")
|
||||
await tool.execute(path=str(f))
|
||||
# Modify the file externally
|
||||
f.write_text("modified content", encoding="utf-8")
|
||||
second = await tool.execute(path=str(f))
|
||||
assert "modified content" in second
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_offset_returns_full(self, tool, tmp_path):
|
||||
f = tmp_path / "data.txt"
|
||||
f.write_text("\n".join(f"line {i}" for i in range(1, 21)), encoding="utf-8")
|
||||
await tool.execute(path=str(f), offset=1, limit=5)
|
||||
second = await tool.execute(path=str(f), offset=6, limit=5)
|
||||
# Different offset → full read, not stub
|
||||
assert "line 6" in second
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_read_after_write_returns_full_content(self, tool, write_tool, tmp_path):
|
||||
f = tmp_path / "fresh.txt"
|
||||
result = await write_tool.execute(path=str(f), content="hello")
|
||||
assert "Successfully" in result
|
||||
read_result = await tool.execute(path=str(f))
|
||||
assert "hello" in read_result
|
||||
assert "unchanged" not in read_result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dedup_does_not_apply_to_images(self, tool, tmp_path):
|
||||
f = tmp_path / "img.png"
|
||||
f.write_bytes(b"\x89PNG\r\n\x1a\nfake-png-data")
|
||||
first = await tool.execute(path=str(f))
|
||||
assert isinstance(first, list)
|
||||
second = await tool.execute(path=str(f))
|
||||
# Images should always return full content blocks, not a stub
|
||||
assert isinstance(second, list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PDF support
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadPdf:
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return ReadFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pdf_returns_text_content(self, tool, tmp_path):
|
||||
fitz = pytest.importorskip("fitz")
|
||||
pdf_path = tmp_path / "test.pdf"
|
||||
doc = fitz.open()
|
||||
page = doc.new_page()
|
||||
page.insert_text((72, 72), "Hello PDF World")
|
||||
doc.save(str(pdf_path))
|
||||
doc.close()
|
||||
|
||||
result = await tool.execute(path=str(pdf_path))
|
||||
assert "Hello PDF World" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pdf_pages_parameter(self, tool, tmp_path):
|
||||
fitz = pytest.importorskip("fitz")
|
||||
pdf_path = tmp_path / "multi.pdf"
|
||||
doc = fitz.open()
|
||||
for i in range(5):
|
||||
page = doc.new_page()
|
||||
page.insert_text((72, 72), f"Page {i + 1} content")
|
||||
doc.save(str(pdf_path))
|
||||
doc.close()
|
||||
|
||||
result = await tool.execute(path=str(pdf_path), pages="2-3")
|
||||
assert "Page 2 content" in result
|
||||
assert "Page 3 content" in result
|
||||
assert "Page 1 content" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pdf_file_not_found_error(self, tool, tmp_path):
|
||||
result = await tool.execute(path=str(tmp_path / "nope.pdf"))
|
||||
assert "Error" in result
|
||||
assert "not found" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Device path blacklist
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadDeviceBlacklist:
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self):
|
||||
return ReadFileTool()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dev_random_blocked(self, tool):
|
||||
result = await tool.execute(path="/dev/random")
|
||||
assert "Error" in result
|
||||
assert "blocked" in result.lower() or "device" in result.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dev_urandom_blocked(self, tool):
|
||||
result = await tool.execute(path="/dev/urandom")
|
||||
assert "Error" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dev_zero_blocked(self, tool):
|
||||
result = await tool.execute(path="/dev/zero")
|
||||
assert "Error" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proc_fd_blocked(self, tool):
|
||||
result = await tool.execute(path="/proc/self/fd/0")
|
||||
assert "Error" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_symlink_to_dev_zero_blocked(self, tmp_path):
|
||||
tool = ReadFileTool(workspace=tmp_path)
|
||||
link = tmp_path / "zero-link"
|
||||
link.symlink_to("/dev/zero")
|
||||
result = await tool.execute(path=str(link))
|
||||
assert "Error" in result
|
||||
assert "blocked" in result.lower() or "device" in result.lower()
|
||||
Loading…
x
Reference in New Issue
Block a user