feat(tools): optimize coding workflows

This commit is contained in:
Xubin Ren 2026-05-20 12:51:26 +08:00
parent eae51333ad
commit 6851fa57a6
13 changed files with 1771 additions and 457 deletions

View File

@ -61,7 +61,7 @@
- **2026-04-13** 🛡️ Agent turn hardened — user messages persisted early, auto-compact skips active tasks. - **2026-04-13** 🛡️ Agent turn hardened — user messages persisted early, auto-compact skips active tasks.
- **2026-04-12** 🔒 Lark global domain support, Dream learns discovered skills, shell sandbox tightened. - **2026-04-12** 🔒 Lark global domain support, Dream learns discovered skills, shell sandbox tightened.
- **2026-04-11** ⚡ Context compact shrinks sessions on the fly; Kagi web search; QQ & WeCom full media. - **2026-04-11** ⚡ Context compact shrinks sessions on the fly; Kagi web search; QQ & WeCom full media.
- **2026-04-10** 📓 Notebook editing tool, multiple MCP servers, Feishu streaming & done-emoji. - **2026-04-10** 📓 Multiple MCP servers, Feishu streaming & done-emoji.
- **2026-04-09** 🔌 WebSocket channel, unified cross-channel session, `disabled_skills` config. - **2026-04-09** 🔌 WebSocket channel, unified cross-channel session, `disabled_skills` config.
- **2026-04-08** 📤 API file uploads, OpenAI reasoning auto-routing with Responses fallback. - **2026-04-08** 📤 API file uploads, OpenAI reasoning auto-routing with Responses fallback.
- **2026-04-07** 🧠 Anthropic adaptive thinking, MCP resources & prompts exposed as tools. - **2026-04-07** 🧠 Anthropic adaptive thinking, MCP resources & prompts exposed as tools.

View File

@ -0,0 +1,341 @@
"""Structured patch editing tool for coding workflows."""
from __future__ import annotations
import difflib
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal
from nanobot.agent.tools.base import tool_parameters
from nanobot.agent.tools.filesystem import _FsTool
from nanobot.agent.tools.schema import StringSchema, tool_parameters_schema
PatchKind = Literal["add", "delete", "update"]
@dataclass(slots=True)
class _Hunk:
header: str | None
lines: list[tuple[str, str]]
@dataclass(slots=True)
class _PatchOp:
kind: PatchKind
path: str
new_path: str | None = None
add_lines: list[str] | None = None
hunks: list[_Hunk] | None = None
class _PatchError(ValueError):
pass
_ABSOLUTE_WINDOWS_RE = re.compile(r"^[A-Za-z]:[\\/]")
def _is_file_header(line: str) -> bool:
return (
line.startswith("*** Add File: ")
or line.startswith("*** Delete File: ")
or line.startswith("*** Update File: ")
)
def _validate_relative_path(path: str) -> str:
normalized = path.strip()
if not normalized:
raise _PatchError("patch path cannot be empty")
if "\0" in normalized:
raise _PatchError(f"patch path contains a null byte: {path!r}")
if normalized.startswith(("~", "/", "\\")) or _ABSOLUTE_WINDOWS_RE.match(normalized):
raise _PatchError(f"patch path must be relative: {path}")
if any(part == ".." for part in re.split(r"[\\/]+", normalized)):
raise _PatchError(f"patch path must not contain '..': {path}")
return normalized
def _lines_to_text(lines: list[str]) -> str:
if not lines:
return ""
return "\n".join(lines) + "\n"
def _parse_patch(patch: str) -> list[_PatchOp]:
lines = patch.replace("\r\n", "\n").replace("\r", "\n").split("\n")
if lines and lines[-1] == "":
lines.pop()
if not lines or lines[0] != "*** Begin Patch":
raise _PatchError("patch must start with '*** Begin Patch'")
if len(lines) < 2 or lines[-1] != "*** End Patch":
raise _PatchError("patch must end with '*** End Patch'")
ops: list[_PatchOp] = []
i = 1
end = len(lines) - 1
while i < end:
line = lines[i]
if line.startswith("*** Add File: "):
path = _validate_relative_path(line.removeprefix("*** Add File: "))
i += 1
add_lines: list[str] = []
while i < end and not _is_file_header(lines[i]):
if not lines[i].startswith("+"):
raise _PatchError(f"Add File lines must start with '+': {lines[i]!r}")
add_lines.append(lines[i][1:])
i += 1
ops.append(_PatchOp(kind="add", path=path, add_lines=add_lines))
continue
if line.startswith("*** Delete File: "):
path = _validate_relative_path(line.removeprefix("*** Delete File: "))
ops.append(_PatchOp(kind="delete", path=path))
i += 1
continue
if line.startswith("*** Update File: "):
path = _validate_relative_path(line.removeprefix("*** Update File: "))
i += 1
new_path: str | None = None
if i < end and lines[i].startswith("*** Move to: "):
new_path = _validate_relative_path(lines[i].removeprefix("*** Move to: "))
i += 1
hunks: list[_Hunk] = []
while i < end and not _is_file_header(lines[i]):
if not lines[i].startswith("@@"):
raise _PatchError(f"Update File sections require '@@' hunks: {lines[i]!r}")
header = lines[i][2:].strip() or None
i += 1
hunk_lines: list[tuple[str, str]] = []
while i < end and not lines[i].startswith("@@") and not _is_file_header(lines[i]):
if lines[i] == "*** End of File":
i += 1
break
if lines[i] == r"\ No newline at end of file":
i += 1
continue
if not lines[i] or lines[i][0] not in {" ", "+", "-"}:
raise _PatchError(f"Hunk lines must start with ' ', '+', or '-': {lines[i]!r}")
hunk_lines.append((lines[i][0], lines[i][1:]))
i += 1
if not hunk_lines:
raise _PatchError(f"Update File hunk is empty: {path}")
hunks.append(_Hunk(header=header, lines=hunk_lines))
if not hunks and new_path is None:
raise _PatchError(f"Update File requires at least one hunk or Move to: {path}")
ops.append(_PatchOp(kind="update", path=path, new_path=new_path, hunks=hunks))
continue
raise _PatchError(f"unknown patch header: {line!r}")
if not ops:
raise _PatchError("patch contains no file operations")
return ops
def _find_with_eof_fallback(content: str, needle: str, start: int) -> tuple[int, int]:
pos = content.find(needle, start)
if pos >= 0:
return pos, len(needle)
if needle.endswith("\n"):
trimmed = needle[:-1]
pos = content.find(trimmed, start)
if pos >= 0 and pos + len(trimmed) == len(content):
return pos, len(trimmed)
return -1, 0
def _line_offset(content: str, line_number: int) -> int:
if line_number <= 1:
return 0
offset = 0
for current, line in enumerate(content.splitlines(keepends=True), start=1):
if current >= line_number:
return offset
offset += len(line)
return len(content)
def _line_hint(header: str | None) -> int | None:
if not header:
return None
match = re.search(r"-(\d+)(?:,\d+)?", header)
return int(match.group(1)) if match else None
def _hunk_mismatch(path: str, old_text: str, content: str, header: str | None) -> str:
lines = content.splitlines(keepends=True)
old_lines = old_text.splitlines(keepends=True)
window = max(1, len(old_lines))
best_ratio, best_start = -1.0, 0
best_lines: list[str] = []
for i in range(max(1, len(lines) - window + 1)):
current = lines[i : i + window]
ratio = difflib.SequenceMatcher(None, "".join(old_lines), "".join(current)).ratio()
if ratio > best_ratio:
best_ratio, best_start, best_lines = ratio, i, current
label = f" after header {header!r}" if header else ""
if best_ratio <= 0:
return f"hunk does not match {path}{label}"
diff = "\n".join(difflib.unified_diff(
old_lines,
best_lines,
fromfile="patch hunk",
tofile=f"{path} (actual, line {best_start + 1})",
lineterm="",
))
return (
f"hunk does not match {path}{label}. "
f"Best match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
)
def _apply_hunks(path: str, content: str, hunks: list[_Hunk]) -> str:
cursor = 0
for hunk in hunks:
old_lines = [text for marker, text in hunk.lines if marker in {" ", "-"}]
new_lines = [text for marker, text in hunk.lines if marker in {" ", "+"}]
old_text = _lines_to_text(old_lines)
new_text = _lines_to_text(new_lines)
search_start = cursor
line_hint = None
if hunk.header:
line_hint = _line_hint(hunk.header)
if line_hint is not None:
search_start = _line_offset(content, line_hint)
else:
header_pos = content.find(hunk.header, cursor)
if header_pos >= 0:
search_start = header_pos
if old_text:
pos, match_len = _find_with_eof_fallback(content, old_text, search_start)
if pos < 0 and search_start != 0 and line_hint is None:
pos, match_len = _find_with_eof_fallback(content, old_text, 0)
if pos < 0:
raise _PatchError(_hunk_mismatch(path, old_text, content, hunk.header))
else:
pos = search_start
match_len = 0
content = content[:pos] + new_text + content[pos + match_len:]
cursor = pos + len(new_text)
return content
@tool_parameters(
tool_parameters_schema(
patch=StringSchema(
"Full patch text. Use *** Begin Patch / *** End Patch and file sections "
"for Add File, Update File, Delete File, and optional Move to.",
min_length=1,
),
required=["patch"],
)
)
class ApplyPatchTool(_FsTool):
"""Apply a structured multi-file patch."""
_scopes = {"core", "subagent"}
@property
def name(self) -> str:
return "apply_patch"
@property
def description(self) -> str:
return (
"Apply a structured patch for code edits. The patch must include "
"*** Begin Patch and *** End Patch. Supports Add File, Update File, "
"Delete File, and Move to. Paths must be relative. Prefer this for "
"multi-file coding changes; use edit_file for small exact replacements."
)
async def execute(self, patch: str, **kwargs: Any) -> str:
try:
ops = _parse_patch(patch)
writes: dict[Path, str] = {}
deletes: set[Path] = set()
touched: list[str] = []
for op in ops:
source = self._resolve(op.path)
if op.kind == "add":
if source.exists():
raise _PatchError(f"file to add already exists: {op.path}")
writes[source] = _lines_to_text(op.add_lines or [])
deletes.discard(source)
touched.append(f"add {op.path}")
continue
if op.kind == "delete":
if not source.exists():
raise _PatchError(f"file to delete does not exist: {op.path}")
if not source.is_file():
raise _PatchError(f"path to delete is not a file: {op.path}")
deletes.add(source)
writes.pop(source, None)
touched.append(f"delete {op.path}")
continue
if not source.exists():
raise _PatchError(f"file to update does not exist: {op.path}")
if not source.is_file():
raise _PatchError(f"path to update is not a file: {op.path}")
raw = source.read_bytes()
try:
content = raw.decode("utf-8")
except UnicodeDecodeError as exc:
raise _PatchError(f"file to update is not UTF-8 text: {op.path}") from exc
uses_crlf = "\r\n" in content
content = content.replace("\r\n", "\n")
new_content = _apply_hunks(op.path, content, op.hunks or [])
if uses_crlf:
new_content = new_content.replace("\n", "\r\n")
target = self._resolve(op.new_path) if op.new_path else source
if op.new_path and target.exists() and target != source:
raise _PatchError(f"move target already exists: {op.new_path}")
writes[target] = new_content
deletes.discard(target)
if target != source:
deletes.add(source)
action = f"move {op.path} -> {op.new_path}" if op.new_path else f"update {op.path}"
touched.append(action)
backups: dict[Path, bytes | None] = {}
for path in set(writes) | deletes:
backups[path] = path.read_bytes() if path.exists() else None
try:
for path in deletes:
if path.exists():
path.unlink()
for path, content in writes.items():
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(content, encoding="utf-8", newline="")
except Exception:
for path, data in backups.items():
if data is None:
if path.exists():
path.unlink()
else:
path.parent.mkdir(parents=True, exist_ok=True)
path.write_bytes(data)
raise
for path in set(writes) | deletes:
self._file_states.record_write(path)
return "Patch applied:\n" + "\n".join(f"- {item}" for item in touched)
except PermissionError as exc:
return f"Error: {exc}"
except _PatchError as exc:
return f"Error applying patch: {exc}"
except Exception as exc:
return f"Error applying patch: {exc}"

View File

@ -0,0 +1,409 @@
"""Session support for long-running exec workflows."""
from __future__ import annotations
import asyncio
import shutil
import time
import uuid
from contextlib import suppress
from dataclasses import dataclass
from typing import Any
from nanobot.agent.tools.base import Tool, tool_parameters
from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema
DEFAULT_YIELD_MS = 1000
MAX_YIELD_MS = 30_000
DEFAULT_MAX_OUTPUT_CHARS = 10_000
MAX_OUTPUT_CHARS = 50_000
@dataclass(slots=True)
class _SessionPoll:
output: str
done: bool
exit_code: int | None
timed_out: bool = False
terminated: bool = False
stdin_closed: bool = False
truncated_chars: int = 0
class _ExecSession:
def __init__(
self,
*,
session_id: str,
process: asyncio.subprocess.Process,
timeout: int,
) -> None:
self.session_id = session_id
self.process = process
self.deadline = time.monotonic() + timeout
self.last_access = time.monotonic()
self._chunks: list[str] = []
self._lock = asyncio.Lock()
self._timed_out = False
self._stdout_task = asyncio.create_task(self._read_stream(process.stdout, ""))
self._stderr_task = asyncio.create_task(self._read_stream(process.stderr, "STDERR:\n"))
async def _read_stream(
self,
stream: asyncio.StreamReader | None,
prefix: str,
) -> None:
if stream is None:
return
first = True
while True:
chunk = await stream.read(4096)
if not chunk:
break
text = chunk.decode("utf-8", errors="replace")
if prefix and first:
text = prefix + text
first = False
async with self._lock:
self._chunks.append(text)
async def write(self, chars: str) -> str | None:
if self.process.returncode is not None:
return "session has already exited"
if self.process.stdin is None:
return "session stdin is not available"
try:
self.process.stdin.write(chars.encode("utf-8"))
await self.process.stdin.drain()
except (BrokenPipeError, ConnectionResetError):
return "session stdin is closed"
return None
async def close_stdin(self) -> str | None:
if self.process.returncode is not None:
return "session has already exited"
if self.process.stdin is None:
return "session stdin is not available"
self.process.stdin.close()
with suppress(BrokenPipeError, ConnectionResetError):
await self.process.stdin.wait_closed()
return None
async def poll(
self,
yield_time_ms: int,
max_output_chars: int,
*,
terminated: bool = False,
stdin_closed: bool = False,
) -> _SessionPoll:
self.last_access = time.monotonic()
if yield_time_ms > 0 and self.process.returncode is None:
await asyncio.sleep(min(yield_time_ms, MAX_YIELD_MS) / 1000)
if self.process.returncode is None and time.monotonic() >= self.deadline:
self._timed_out = True
await self.kill()
if self.process.returncode is not None:
with suppress(asyncio.TimeoutError):
await asyncio.wait_for(
asyncio.gather(self._stdout_task, self._stderr_task),
timeout=2.0,
)
async with self._lock:
output = "".join(self._chunks)
self._chunks.clear()
output, truncated = _truncate_output(output, max_output_chars)
return _SessionPoll(
output=output,
done=self.process.returncode is not None,
exit_code=self.process.returncode,
timed_out=self._timed_out,
terminated=terminated,
stdin_closed=stdin_closed,
truncated_chars=truncated,
)
async def kill(self) -> None:
if self.process.returncode is not None:
return
self.process.kill()
with suppress(asyncio.TimeoutError):
await asyncio.wait_for(self.process.wait(), timeout=5.0)
class ExecSessionManager:
def __init__(self, *, max_sessions: int = 8, idle_timeout: int = 1800) -> None:
self.max_sessions = max_sessions
self.idle_timeout = idle_timeout
self._sessions: dict[str, _ExecSession] = {}
self._lock = asyncio.Lock()
async def start(
self,
*,
command: str,
cwd: str,
env: dict[str, str],
timeout: int,
shell_program: str | None,
login: bool,
yield_time_ms: int,
max_output_chars: int,
) -> tuple[str, _SessionPoll]:
async with self._lock:
await self._cleanup_locked()
if len(self._sessions) >= self.max_sessions:
raise RuntimeError(f"maximum exec sessions reached ({self.max_sessions})")
process = await self._spawn(command, cwd, env, shell_program, login)
session_id = uuid.uuid4().hex[:12]
session = _ExecSession(session_id=session_id, process=process, timeout=timeout)
self._sessions[session_id] = session
poll = await session.poll(yield_time_ms, max_output_chars)
if poll.done:
async with self._lock:
self._sessions.pop(session_id, None)
return session_id, poll
async def write(
self,
*,
session_id: str,
chars: str | None,
close_stdin: bool,
terminate: bool,
yield_time_ms: int,
max_output_chars: int,
) -> _SessionPoll:
async with self._lock:
await self._cleanup_locked()
session = self._sessions.get(session_id)
if session is None:
raise KeyError(session_id)
if chars is not None:
error = await session.write(chars)
if error:
raise RuntimeError(error)
stdin_closed = False
if close_stdin:
error = await session.close_stdin()
if error:
raise RuntimeError(error)
stdin_closed = True
if terminate:
await session.kill()
poll = await session.poll(
yield_time_ms,
max_output_chars,
terminated=terminate,
stdin_closed=stdin_closed,
)
if poll.done:
async with self._lock:
self._sessions.pop(session_id, None)
return poll
async def _cleanup_locked(self) -> None:
now = time.monotonic()
stale = [
session_id
for session_id, session in self._sessions.items()
if session.process.returncode is not None
or now - session.last_access > self.idle_timeout
]
for session_id in stale:
session = self._sessions.pop(session_id)
await session.kill()
async def _spawn(
self,
command: str,
cwd: str,
env: dict[str, str],
shell_program: str | None,
login: bool,
) -> asyncio.subprocess.Process:
from nanobot.agent.tools import shell
if shell._IS_WINDOWS:
return await asyncio.create_subprocess_shell(
command,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=cwd,
env=env,
)
shell_program = shell_program or shutil.which("bash") or "/bin/bash"
args = [shell_program]
if login and shell_program.rsplit("/", 1)[-1] in {"bash", "zsh"}:
args.append("-l")
args.extend(["-c", command])
return await asyncio.create_subprocess_exec(
*args,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=cwd,
env=env,
)
DEFAULT_EXEC_SESSION_MANAGER = ExecSessionManager()
def clamp_session_int(value: int | None, default: int, minimum: int, maximum: int) -> int:
if value is None:
return default
return min(max(value, minimum), maximum)
def _truncate_output(output: str, max_output_chars: int) -> tuple[str, int]:
if len(output) <= max_output_chars:
return output, 0
half = max_output_chars // 2
omitted = len(output) - max_output_chars
return (
output[:half]
+ f"\n\n... ({omitted:,} chars truncated) ...\n\n"
+ output[-half:],
omitted,
)
def format_session_poll(session_id: str, poll: _SessionPoll) -> str:
parts = [poll.output] if poll.output else []
if poll.truncated_chars:
parts.append(f"(output truncated by {poll.truncated_chars:,} chars)")
if poll.timed_out:
parts.append("Error: Command timed out; session was terminated.")
if poll.terminated and not poll.timed_out:
parts.append("Session terminated.")
if poll.stdin_closed:
parts.append("Stdin closed.")
if poll.done:
parts.append(f"Exit code: {poll.exit_code}")
else:
parts.append(f"Process running. session_id: {session_id}")
return "\n".join(parts) if parts else "(no output yet)"
@tool_parameters(
tool_parameters_schema(
session_id=StringSchema("Session id returned by exec when yield_time_ms is used."),
chars=StringSchema(
"Bytes/text to write to stdin. Omit or pass an empty string to only poll recent output.",
nullable=True,
),
close_stdin=BooleanSchema(
description="Close stdin after writing chars. Useful for commands waiting for EOF.",
default=False,
),
terminate=BooleanSchema(
description="Terminate the running exec session.",
default=False,
),
yield_time_ms=IntegerSchema(
DEFAULT_YIELD_MS,
description="Milliseconds to wait before returning recent output (default 1000, max 30000).",
minimum=0,
maximum=MAX_YIELD_MS,
),
max_output_chars=IntegerSchema(
DEFAULT_MAX_OUTPUT_CHARS,
description="Maximum output characters to return from this poll (default 10000, max 50000).",
minimum=1000,
maximum=MAX_OUTPUT_CHARS,
),
max_output_tokens=IntegerSchema(
DEFAULT_MAX_OUTPUT_CHARS,
description="Compatibility alias for max_output_chars. The current runtime uses a character budget.",
minimum=1000,
maximum=MAX_OUTPUT_CHARS,
nullable=True,
),
required=["session_id"],
)
)
class WriteStdinTool(Tool):
"""Write to or poll a running exec session."""
_scopes = {"core", "subagent"}
config_key = "exec"
@classmethod
def config_cls(cls):
from nanobot.agent.tools.shell import ExecToolConfig
return ExecToolConfig
@classmethod
def enabled(cls, ctx: Any) -> bool:
return ctx.config.exec.enable
def __init__(
self,
*,
manager: ExecSessionManager | None = None,
) -> None:
self._manager = manager or DEFAULT_EXEC_SESSION_MANAGER
@classmethod
def create(cls, ctx: Any) -> Tool:
return cls()
@property
def exclusive(self) -> bool:
return True
@property
def name(self) -> str:
return "write_stdin"
@property
def description(self) -> str:
return (
"Write text to a running exec session and return recent output. "
"Use chars='' to poll without writing. Set close_stdin=true to send EOF, "
"or terminate=true to stop the session. Sessions finish automatically "
"when their process exits."
)
async def execute(
self,
session_id: str,
chars: str | None = None,
close_stdin: bool = False,
terminate: bool = False,
yield_time_ms: int | None = None,
max_output_chars: int | None = None,
max_output_tokens: int | None = None,
**kwargs: Any,
) -> str:
try:
if max_output_chars is None:
max_output_chars = max_output_tokens
poll = await self._manager.write(
session_id=session_id,
chars=chars,
close_stdin=close_stdin,
terminate=terminate,
yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS),
max_output_chars=clamp_session_int(
max_output_chars,
DEFAULT_MAX_OUTPUT_CHARS,
1000,
MAX_OUTPUT_CHARS,
),
)
return format_session_poll(session_id, poll)
except KeyError:
return f"Error: exec session not found: {session_id}"
except Exception as exc:
return f"Error writing to exec session: {exc}"

View File

@ -132,6 +132,10 @@ def _parse_page_range(pages: str, total: int) -> tuple[int, int]:
minimum=1, minimum=1,
), ),
pages=StringSchema("Page range for PDF files, e.g. '1-5' (default: all, max 20 pages)"), pages=StringSchema("Page range for PDF files, e.g. '1-5' (default: all, max 20 pages)"),
force=BooleanSchema(
description="Bypass same-file read deduplication and return content again.",
default=False,
),
required=["path"], required=["path"],
) )
) )
@ -155,6 +159,7 @@ class ReadFileTool(_FsTool):
"Images return visual content for analysis. " "Images return visual content for analysis. "
"Supports PDF, DOCX, XLSX, PPTX documents. " "Supports PDF, DOCX, XLSX, PPTX documents. "
"Use offset and limit for large text files. " "Use offset and limit for large text files. "
"Use force=true to re-read content even if unchanged. "
"Reads exceeding ~128K chars are truncated." "Reads exceeding ~128K chars are truncated."
) )
@ -162,7 +167,15 @@ class ReadFileTool(_FsTool):
def read_only(self) -> bool: def read_only(self) -> bool:
return True return True
async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, pages: str | None = None, **kwargs: Any) -> Any: async def execute(
self,
path: str | None = None,
offset: int = 1,
limit: int | None = None,
pages: str | None = None,
force: bool = False,
**kwargs: Any,
) -> Any:
try: try:
if not path: if not path:
return "Error reading file: Unknown path" return "Error reading file: Unknown path"
@ -202,7 +215,13 @@ class ReadFileTool(_FsTool):
current_mtime = os.path.getmtime(fp) current_mtime = os.path.getmtime(fp)
except OSError: except OSError:
current_mtime = 0.0 current_mtime = 0.0
if entry and entry.can_dedup and entry.offset == offset and entry.limit == limit: if (
not force
and entry
and entry.can_dedup
and entry.offset == offset
and entry.limit == limit
):
if current_mtime != entry.mtime: if current_mtime != entry.mtime:
# File was modified externally - force full read and mark as not dedupable # File was modified externally - force full read and mark as not dedupable
entry.can_dedup = False entry.can_dedup = False
@ -657,6 +676,24 @@ def _find_match(content: str, old_text: str) -> tuple[str | None, int]:
old_text=StringSchema("The text to find and replace"), old_text=StringSchema("The text to find and replace"),
new_text=StringSchema("The text to replace with"), new_text=StringSchema("The text to replace with"),
replace_all=BooleanSchema(description="Replace all occurrences (default false)"), replace_all=BooleanSchema(description="Replace all occurrences (default false)"),
occurrence=IntegerSchema(
1,
description="Optional 1-based occurrence to replace when old_text appears multiple times.",
minimum=1,
nullable=True,
),
line_hint=IntegerSchema(
1,
description="Optional 1-based line hint used to choose the nearest match.",
minimum=1,
nullable=True,
),
expected_replacements=IntegerSchema(
1,
description="Optional guard for the number of replacements that must be made.",
minimum=1,
nullable=True,
),
required=["path", "old_text", "new_text"], required=["path", "old_text", "new_text"],
) )
) )
@ -677,7 +714,7 @@ class EditFileTool(_FsTool):
"Edit a file by replacing old_text with new_text. " "Edit a file by replacing old_text with new_text. "
"Tolerates minor whitespace/indentation differences and curly/straight quote mismatches. " "Tolerates minor whitespace/indentation differences and curly/straight quote mismatches. "
"If old_text matches multiple times, you must provide more context " "If old_text matches multiple times, you must provide more context "
"or set replace_all=true. Shows a diff of the closest match on failure." "or set occurrence/line_hint/replace_all. Shows a diff of the closest match on failure."
) )
@staticmethod @staticmethod
@ -688,7 +725,8 @@ class EditFileTool(_FsTool):
async def execute( async def execute(
self, path: str | None = None, old_text: str | None = None, self, path: str | None = None, old_text: str | None = None,
new_text: str | None = None, new_text: str | None = None,
replace_all: bool = False, **kwargs: Any, replace_all: bool = False, occurrence: int | None = None,
line_hint: int | None = None, expected_replacements: int | None = None, **kwargs: Any,
) -> str: ) -> str:
try: try:
if not path: if not path:
@ -697,10 +735,12 @@ class EditFileTool(_FsTool):
raise ValueError("Unknown old_text") raise ValueError("Unknown old_text")
if new_text is None: if new_text is None:
raise ValueError("Unknown new_text") raise ValueError("Unknown new_text")
if occurrence is not None and occurrence < 1:
# .ipynb detection return "Error: occurrence must be >= 1."
if path.endswith(".ipynb"): if line_hint is not None and line_hint < 1:
return "Error: This is a Jupyter notebook. Use the notebook_edit tool instead of edit_file." return "Error: line_hint must be >= 1."
if expected_replacements is not None and expected_replacements < 1:
return "Error: expected_replacements must be >= 1."
fp = self._resolve(path) fp = self._resolve(path)
@ -743,7 +783,28 @@ class EditFileTool(_FsTool):
if not matches: if not matches:
return self._not_found_msg(old_text, content, path) return self._not_found_msg(old_text, content, path)
count = len(matches) count = len(matches)
if replace_all and occurrence is not None:
return "Error: occurrence cannot be used with replace_all=true."
if replace_all and line_hint is not None:
return "Error: line_hint cannot be used with replace_all=true."
if occurrence is not None and line_hint is not None:
return "Error: line_hint cannot be used with occurrence."
if count > 1 and not replace_all: if count > 1 and not replace_all:
if occurrence is not None:
if occurrence > count:
return (
f"Error: occurrence {occurrence} is out of range; "
f"old_text appears {count} times."
)
elif line_hint is not None:
nearest = min(matches, key=lambda match: abs(match.line - line_hint))
distance = abs(nearest.line - line_hint)
if sum(1 for match in matches if abs(match.line - line_hint) == distance) > 1:
return (
f"Error: line_hint {line_hint} is ambiguous; "
f"old_text appears {count} times."
)
else:
line_numbers = [match.line for match in matches] line_numbers = [match.line for match in matches]
preview = ", ".join(f"line {n}" for n in line_numbers[:3]) preview = ", ".join(f"line {n}" for n in line_numbers[:3])
if len(line_numbers) > 3: if len(line_numbers) > 3:
@ -751,7 +812,13 @@ class EditFileTool(_FsTool):
location_hint = f" at {preview}" if preview else "" location_hint = f" at {preview}" if preview else ""
return ( return (
f"Warning: old_text appears {count} times{location_hint}. " f"Warning: old_text appears {count} times{location_hint}. "
"Provide more context to make it unique, or set replace_all=true." "Provide more context, set occurrence to choose one match, "
"or set replace_all=true."
)
elif occurrence is not None and occurrence > count:
return (
f"Error: occurrence {occurrence} is out of range; "
f"old_text appears {count} time."
) )
norm_new = new_text.replace("\r\n", "\n") norm_new = new_text.replace("\r\n", "\n")
@ -760,7 +827,17 @@ class EditFileTool(_FsTool):
if fp.suffix.lower() not in self._MARKDOWN_EXTS: if fp.suffix.lower() not in self._MARKDOWN_EXTS:
norm_new = self._strip_trailing_ws(norm_new) norm_new = self._strip_trailing_ws(norm_new)
selected = matches if replace_all else matches[:1] if replace_all:
selected = matches
elif line_hint is not None:
selected = [min(matches, key=lambda match: abs(match.line - line_hint))]
else:
selected = [matches[occurrence - 1 if occurrence else 0]]
if expected_replacements is not None and len(selected) != expected_replacements:
return (
f"Error: expected {expected_replacements} replacements but "
f"would make {len(selected)}."
)
new_content = content new_content = content
for match in reversed(selected): for match in reversed(selected):
replacement = _preserve_quote_style(norm_old, match.text, norm_new) replacement = _preserve_quote_style(norm_old, match.text, norm_new)

View File

@ -1,162 +0,0 @@
"""NotebookEditTool — edit Jupyter .ipynb notebooks."""
from __future__ import annotations
import json
import uuid
from typing import Any
from nanobot.agent.tools.base import tool_parameters
from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema
from nanobot.agent.tools.filesystem import _FsTool
def _new_cell(source: str, cell_type: str = "code", generate_id: bool = False) -> dict:
cell: dict[str, Any] = {
"cell_type": cell_type,
"source": source,
"metadata": {},
}
if cell_type == "code":
cell["outputs"] = []
cell["execution_count"] = None
if generate_id:
cell["id"] = uuid.uuid4().hex[:8]
return cell
def _make_empty_notebook() -> dict:
return {
"nbformat": 4,
"nbformat_minor": 5,
"metadata": {
"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
"language_info": {"name": "python"},
},
"cells": [],
}
@tool_parameters(
tool_parameters_schema(
path=StringSchema("Path to the .ipynb notebook file"),
cell_index=IntegerSchema(0, description="0-based index of the cell to edit", minimum=0),
new_source=StringSchema("New source content for the cell"),
cell_type=StringSchema(
"Cell type: 'code' or 'markdown' (default: code)",
enum=["code", "markdown"],
),
edit_mode=StringSchema(
"Mode: 'replace' (default), 'insert' (after target), or 'delete'",
enum=["replace", "insert", "delete"],
),
required=["path", "cell_index"],
)
)
class NotebookEditTool(_FsTool):
"""Edit Jupyter notebook cells: replace, insert, or delete."""
_scopes = {"core"}
_VALID_CELL_TYPES = frozenset({"code", "markdown"})
_VALID_EDIT_MODES = frozenset({"replace", "insert", "delete"})
@property
def name(self) -> str:
return "notebook_edit"
@property
def description(self) -> str:
return (
"Edit a Jupyter notebook (.ipynb) cell. "
"Modes: replace (default) replaces cell content, "
"insert adds a new cell after the target index, "
"delete removes the cell at the index. "
"cell_index is 0-based."
)
async def execute(
self,
path: str | None = None,
cell_index: int = 0,
new_source: str = "",
cell_type: str = "code",
edit_mode: str = "replace",
**kwargs: Any,
) -> str:
try:
if not path:
return "Error: path is required"
if not path.endswith(".ipynb"):
return "Error: notebook_edit only works on .ipynb files. Use edit_file for other files."
if edit_mode not in self._VALID_EDIT_MODES:
return (
f"Error: Invalid edit_mode '{edit_mode}'. "
"Use one of: replace, insert, delete."
)
if cell_type not in self._VALID_CELL_TYPES:
return (
f"Error: Invalid cell_type '{cell_type}'. "
"Use one of: code, markdown."
)
fp = self._resolve(path)
# Create new notebook if file doesn't exist and mode is insert
if not fp.exists():
if edit_mode != "insert":
return f"Error: File not found: {path}"
nb = _make_empty_notebook()
cell = _new_cell(new_source, cell_type, generate_id=True)
nb["cells"].append(cell)
fp.parent.mkdir(parents=True, exist_ok=True)
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
return f"Successfully created {fp} with 1 cell"
try:
nb = json.loads(fp.read_text(encoding="utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError) as e:
return f"Error: Failed to parse notebook: {e}"
cells = nb.get("cells", [])
nbformat_minor = nb.get("nbformat_minor", 0)
generate_id = nb.get("nbformat", 0) >= 4 and nbformat_minor >= 5
if edit_mode == "delete":
if cell_index < 0 or cell_index >= len(cells):
return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)"
cells.pop(cell_index)
nb["cells"] = cells
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
return f"Successfully deleted cell {cell_index} from {fp}"
if edit_mode == "insert":
insert_at = min(cell_index + 1, len(cells))
cell = _new_cell(new_source, cell_type, generate_id=generate_id)
cells.insert(insert_at, cell)
nb["cells"] = cells
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
return f"Successfully inserted cell at index {insert_at} in {fp}"
# Default: replace
if cell_index < 0 or cell_index >= len(cells):
return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)"
cells[cell_index]["source"] = new_source
if cell_type and cells[cell_index].get("cell_type") != cell_type:
cells[cell_index]["cell_type"] = cell_type
if cell_type == "code":
cells[cell_index].setdefault("outputs", [])
cells[cell_index].setdefault("execution_count", None)
elif "outputs" in cells[cell_index]:
del cells[cell_index]["outputs"]
cells[cell_index].pop("execution_count", None)
nb["cells"] = cells
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
return f"Successfully edited cell {cell_index} in {fp}"
except PermissionError as e:
return f"Error: {e}"
except Exception as e:
return f"Error editing notebook: {e}"

View File

@ -8,6 +8,7 @@ import re
import shutil import shutil
import sys import sys
from contextlib import suppress from contextlib import suppress
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -15,8 +16,17 @@ from loguru import logger
from pydantic import Field from pydantic import Field
from nanobot.agent.tools.base import Tool, tool_parameters from nanobot.agent.tools.base import Tool, tool_parameters
from nanobot.agent.tools.exec_session import (
DEFAULT_MAX_OUTPUT_CHARS,
DEFAULT_YIELD_MS,
DEFAULT_EXEC_SESSION_MANAGER,
MAX_OUTPUT_CHARS,
MAX_YIELD_MS,
clamp_session_int,
format_session_poll,
)
from nanobot.agent.tools.sandbox import wrap_command from nanobot.agent.tools.sandbox import wrap_command
from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema
from nanobot.config.paths import get_media_dir from nanobot.config.paths import get_media_dir
from nanobot.config.schema import Base from nanobot.config.schema import Base
@ -44,10 +54,22 @@ class ExecToolConfig(Base):
deny_patterns: list[str] = Field(default_factory=list) deny_patterns: list[str] = Field(default_factory=list)
@dataclass(slots=True)
class _PreparedCommand:
command: str
cwd: str
env: dict[str, str]
timeout: int
shell_program: str | None
login: bool
@tool_parameters( @tool_parameters(
tool_parameters_schema( tool_parameters_schema(
command=StringSchema("The shell command to execute"), command=StringSchema("The shell command to execute"),
cmd=StringSchema("Compatibility alias for command"),
working_dir=StringSchema("Optional working directory for the command"), working_dir=StringSchema("Optional working directory for the command"),
workdir=StringSchema("Compatibility alias for working_dir"),
timeout=IntegerSchema( timeout=IntegerSchema(
60, 60,
description=( description=(
@ -57,7 +79,44 @@ class ExecToolConfig(Base):
minimum=1, minimum=1,
maximum=600, maximum=600,
), ),
required=["command"], shell=StringSchema(
"Optional shell binary to launch. On Unix, supports sh, bash, or zsh.",
nullable=True,
),
login=BooleanSchema(
description="Whether to run bash/zsh with login shell semantics (default true).",
default=True,
nullable=True,
),
yield_time_ms=IntegerSchema(
description=(
"Optional milliseconds to wait before returning output. "
"When set, a still-running command returns a session_id that "
"can be polled or written to with write_stdin. Omit this field "
"to keep one-shot exec behavior."
),
minimum=0,
maximum=MAX_YIELD_MS,
nullable=True,
),
max_output_chars=IntegerSchema(
description=(
"Maximum output characters to return when yield_time_ms is used "
"(default 10000, max 50000)."
),
minimum=1000,
maximum=MAX_OUTPUT_CHARS,
nullable=True,
),
max_output_tokens=IntegerSchema(
description=(
"Compatibility alias for max_output_chars. The current runtime "
"uses a character budget."
),
minimum=1000,
maximum=MAX_OUTPUT_CHARS,
nullable=True,
),
) )
) )
class ExecTool(Tool): class ExecTool(Tool):
@ -98,6 +157,7 @@ class ExecTool(Tool):
sandbox: str = "", sandbox: str = "",
path_append: str = "", path_append: str = "",
allowed_env_keys: list[str] | None = None, allowed_env_keys: list[str] | None = None,
session_manager: Any | None = None,
): ):
self.timeout = timeout self.timeout = timeout
self.working_dir = working_dir self.working_dir = working_dir
@ -125,6 +185,7 @@ class ExecTool(Tool):
self.restrict_to_workspace = restrict_to_workspace self.restrict_to_workspace = restrict_to_workspace
self.path_append = path_append self.path_append = path_append
self.allowed_env_keys = allowed_env_keys or [] self.allowed_env_keys = allowed_env_keys or []
self._session_manager = session_manager or DEFAULT_EXEC_SESSION_MANAGER
@property @property
def name(self) -> str: def name(self) -> str:
@ -153,7 +214,10 @@ class ExecTool(Tool):
"Prefer read_file/write_file/edit_file over cat/echo/sed, " "Prefer read_file/write_file/edit_file over cat/echo/sed, "
"and grep/glob over shell find/grep. " "and grep/glob over shell find/grep. "
"Use -y or --yes flags to avoid interactive prompts. " "Use -y or --yes flags to avoid interactive prompts. "
"Output is truncated at 10 000 chars; timeout defaults to 60s." "For long-running or interactive commands, pass yield_time_ms; "
"if the command keeps running, exec returns a session_id that can "
"be polled or written to with write_stdin. Output is truncated at "
"10 000 chars; timeout defaults to 60s."
) )
@property @property
@ -161,9 +225,111 @@ class ExecTool(Tool):
return True return True
async def execute( async def execute(
self, command: str, working_dir: str | None = None, self, command: str | None = None, cmd: str | None = None,
timeout: int | None = None, **kwargs: Any, working_dir: str | None = None, workdir: str | None = None,
timeout: int | None = None, shell: str | None = None,
login: bool | None = None, yield_time_ms: int | None = None,
max_output_chars: int | None = None,
max_output_tokens: int | None = None,
**kwargs: Any,
) -> str: ) -> str:
command = command or cmd
working_dir = working_dir or workdir
if not command:
return "Error: Missing command. Provide command or cmd."
if max_output_chars is None:
max_output_chars = max_output_tokens
prepared = self._prepare_command(command, working_dir, timeout, shell, login)
if isinstance(prepared, str):
return prepared
if yield_time_ms is not None:
return await self._execute_session(prepared, yield_time_ms, max_output_chars)
try:
process = await self._spawn(
prepared.command,
prepared.cwd,
prepared.env,
prepared.shell_program,
prepared.login,
)
try:
stdout, stderr = await asyncio.wait_for(
process.communicate(),
timeout=prepared.timeout,
)
except asyncio.TimeoutError:
await self._kill_process(process)
return f"Error: Command timed out after {prepared.timeout} seconds"
except asyncio.CancelledError:
await self._kill_process(process)
raise
output_parts = []
if stdout:
output_parts.append(stdout.decode("utf-8", errors="replace"))
if stderr:
stderr_text = stderr.decode("utf-8", errors="replace")
if stderr_text.strip():
output_parts.append(f"STDERR:\n{stderr_text}")
output_parts.append(f"\nExit code: {process.returncode}")
result = "\n".join(output_parts) if output_parts else "(no output)"
max_len = clamp_session_int(max_output_chars, self._MAX_OUTPUT, 1000, MAX_OUTPUT_CHARS)
if len(result) > max_len:
half = max_len // 2
result = (
result[:half]
+ f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n"
+ result[-half:]
)
return result
except Exception as e:
return f"Error executing command: {str(e)}"
async def _execute_session(
self,
prepared: _PreparedCommand,
yield_time_ms: int | None,
max_output_chars: int | None,
) -> str:
try:
session_id, poll = await self._session_manager.start(
command=prepared.command,
cwd=prepared.cwd,
env=prepared.env,
timeout=prepared.timeout,
shell_program=prepared.shell_program,
login=prepared.login,
yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS),
max_output_chars=clamp_session_int(
max_output_chars,
DEFAULT_MAX_OUTPUT_CHARS,
1000,
MAX_OUTPUT_CHARS,
),
)
return format_session_poll(session_id, poll)
except Exception as exc:
return f"Error executing command: {exc}"
def _prepare_command(
self,
command: str,
working_dir: str | None = None,
timeout: int | None = None,
shell: str | None = None,
login: bool | None = None,
) -> _PreparedCommand | str:
cwd = working_dir or self.working_dir or os.getcwd() cwd = working_dir or self.working_dir or os.getcwd()
# Prevent an LLM-supplied working_dir from escaping the configured # Prevent an LLM-supplied working_dir from escaping the configured
@ -211,52 +377,24 @@ class ExecTool(Tool):
env["NANOBOT_PATH_APPEND"] = self.path_append env["NANOBOT_PATH_APPEND"] = self.path_append
command = f'export PATH="$PATH{os.pathsep}$NANOBOT_PATH_APPEND"; {command}' command = f'export PATH="$PATH{os.pathsep}$NANOBOT_PATH_APPEND"; {command}'
try: shell_program, shell_error = self._resolve_shell(shell)
process = await self._spawn(command, cwd, env) if shell_error:
return shell_error
try: return _PreparedCommand(
stdout, stderr = await asyncio.wait_for( command=command,
process.communicate(), cwd=cwd,
env=env,
timeout=effective_timeout, timeout=effective_timeout,
shell_program=shell_program,
login=True if login is None else login,
) )
except asyncio.TimeoutError:
await self._kill_process(process)
return f"Error: Command timed out after {effective_timeout} seconds"
except asyncio.CancelledError:
await self._kill_process(process)
raise
output_parts = []
if stdout:
output_parts.append(stdout.decode("utf-8", errors="replace"))
if stderr:
stderr_text = stderr.decode("utf-8", errors="replace")
if stderr_text.strip():
output_parts.append(f"STDERR:\n{stderr_text}")
output_parts.append(f"\nExit code: {process.returncode}")
result = "\n".join(output_parts) if output_parts else "(no output)"
max_len = self._MAX_OUTPUT
if len(result) > max_len:
half = max_len // 2
result = (
result[:half]
+ f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n"
+ result[-half:]
)
return result
except Exception as e:
return f"Error executing command: {str(e)}"
@staticmethod @staticmethod
async def _spawn( async def _spawn(
command: str, cwd: str, env: dict[str, str], command: str, cwd: str, env: dict[str, str],
shell_program: str | None = None,
login: bool = True,
) -> asyncio.subprocess.Process: ) -> asyncio.subprocess.Process:
"""Launch *command* in a platform-appropriate shell.""" """Launch *command* in a platform-appropriate shell."""
if _IS_WINDOWS: if _IS_WINDOWS:
@ -272,9 +410,13 @@ class ExecTool(Tool):
cwd=cwd, cwd=cwd,
env=env, env=env,
) )
bash = shutil.which("bash") or "/bin/bash" shell_program = shell_program or shutil.which("bash") or "/bin/bash"
args = [shell_program]
if login and Path(shell_program).name in {"bash", "zsh"}:
args.append("-l")
args.extend(["-c", command])
return await asyncio.create_subprocess_exec( return await asyncio.create_subprocess_exec(
bash, "-l", "-c", command, *args,
stdin=asyncio.subprocess.DEVNULL, stdin=asyncio.subprocess.DEVNULL,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
@ -282,6 +424,31 @@ class ExecTool(Tool):
env=env, env=env,
) )
@staticmethod
def _resolve_shell(shell: str | None) -> tuple[str | None, str | None]:
if not shell:
return None, None
if _IS_WINDOWS:
return None, "Error: shell parameter is not supported on Windows"
if "\0" in shell or "\n" in shell or "\r" in shell:
return None, "Error: shell contains invalid characters"
allowed = {"sh", "bash", "zsh"}
path = Path(shell).expanduser()
if path.is_absolute():
if path.name not in allowed:
return None, f"Error: unsupported shell {shell!r}. Allowed: bash, sh, zsh"
if not path.is_file() or not os.access(path, os.X_OK):
return None, f"Error: shell is not executable: {shell}"
return str(path), None
if "/" in shell or "\\" in shell:
return None, "Error: shell must be a shell name or absolute path"
if shell not in allowed:
return None, f"Error: unsupported shell {shell!r}. Allowed: bash, sh, zsh"
resolved = shutil.which(shell)
if not resolved:
return None, f"Error: shell not found: {shell}"
return resolved, None
@staticmethod @staticmethod
async def _kill_process(process: asyncio.subprocess.Process) -> None: async def _kill_process(process: asyncio.subprocess.Process) -> None:
"""Kill a subprocess and reap it to prevent zombies.""" """Kill a subprocess and reap it to prevent zombies."""

View File

@ -3,7 +3,6 @@
from __future__ import annotations from __future__ import annotations
import difflib import difflib
import json
import re import re
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
@ -11,7 +10,7 @@ from pathlib import Path
from typing import Any, Awaitable, Callable from typing import Any, Awaitable, Callable
TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file", "notebook_edit"}) TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file"})
_MAX_SNAPSHOT_BYTES = 2 * 1024 * 1024 _MAX_SNAPSHOT_BYTES = 2 * 1024 * 1024
_LIVE_EMIT_INTERVAL_S = 0.18 _LIVE_EMIT_INTERVAL_S = 0.18
_LIVE_EMIT_LINE_STEP = 24 _LIVE_EMIT_LINE_STEP = 24
@ -704,77 +703,4 @@ def _predict_after_text(
return before_text.replace(old_text, new_text) return before_text.replace(old_text, new_text)
return before_text.replace(old_text, new_text, 1) return before_text.replace(old_text, new_text, 1)
return None return None
if tool_name == "notebook_edit":
return _predict_notebook_after_text(params, before_text)
return None return None
def _predict_notebook_after_text(params: dict[str, Any], before_text: str) -> str | None:
try:
nb = json.loads(before_text) if before_text.strip() else _empty_notebook()
except Exception:
return None
cells = nb.get("cells")
if not isinstance(cells, list):
return None
try:
cell_index = int(params.get("cell_index", 0))
except (TypeError, ValueError):
return None
new_source = params.get("new_source")
source = new_source if isinstance(new_source, str) else ""
cell_type = (
params.get("cell_type") if params.get("cell_type") in ("code", "markdown") else "code"
)
mode = (
params.get("edit_mode")
if params.get("edit_mode") in ("replace", "insert", "delete")
else "replace"
)
if mode == "delete":
if 0 <= cell_index < len(cells):
cells.pop(cell_index)
else:
return None
elif mode == "insert":
insert_at = min(max(cell_index + 1, 0), len(cells))
cells.insert(insert_at, _new_notebook_cell(source, str(cell_type)))
else:
if not (0 <= cell_index < len(cells)):
return None
cell = cells[cell_index]
if not isinstance(cell, dict):
return None
cell["source"] = source
cell["cell_type"] = cell_type
if cell_type == "code":
cell.setdefault("outputs", [])
cell.setdefault("execution_count", None)
else:
cell.pop("outputs", None)
cell.pop("execution_count", None)
nb["cells"] = cells
try:
return json.dumps(nb, indent=1, ensure_ascii=False)
except Exception:
return None
def _empty_notebook() -> dict[str, Any]:
return {
"nbformat": 4,
"nbformat_minor": 5,
"metadata": {
"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
"language_info": {"name": "python"},
},
"cells": [],
}
def _new_notebook_cell(source: str, cell_type: str) -> dict[str, Any]:
cell: dict[str, Any] = {"cell_type": cell_type, "source": source, "metadata": {}}
if cell_type == "code":
cell["outputs"] = []
cell["execution_count"] = None
return cell

View File

@ -0,0 +1,238 @@
from __future__ import annotations
import asyncio
from nanobot.agent.tools.apply_patch import ApplyPatchTool
def test_apply_patch_adds_file(tmp_path):
tool = ApplyPatchTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
patch="""*** Begin Patch
*** Add File: hello.txt
+Hello
+world
*** End Patch
"""
))
assert "Patch applied" in result
assert (tmp_path / "hello.txt").read_text() == "Hello\nworld\n"
def test_apply_patch_updates_multiple_hunks(tmp_path):
target = tmp_path / "multi.txt"
target.write_text("line1\nline2\nline3\nline4\n")
tool = ApplyPatchTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
patch="""*** Begin Patch
*** Update File: multi.txt
@@
-line2
+changed2
@@
-line4
+changed4
*** End Patch
"""
))
assert "update multi.txt" in result
assert target.read_text() == "line1\nchanged2\nline3\nchanged4\n"
def test_apply_patch_ignores_standard_no_newline_marker(tmp_path):
target = tmp_path / "plain.txt"
target.write_text("before")
tool = ApplyPatchTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
patch="""*** Begin Patch
*** Update File: plain.txt
@@ -1,1 +1,1 @@
-before
\\ No newline at end of file
+after
\\ No newline at end of file
*** End Patch
"""
))
assert "update plain.txt" in result
assert target.read_text() == "after\n"
def test_apply_patch_rejects_empty_hunk(tmp_path):
target = tmp_path / "plain.txt"
target.write_text("before\n")
tool = ApplyPatchTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
patch="""*** Begin Patch
*** Update File: plain.txt
@@
*** End Patch
"""
))
assert "hunk is empty" in result
assert target.read_text() == "before\n"
def test_apply_patch_uses_unified_diff_line_hint(tmp_path):
target = tmp_path / "repeated.txt"
target.write_text("target\nmiddle\ntarget\n")
tool = ApplyPatchTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
patch="""*** Begin Patch
*** Update File: repeated.txt
@@ -3,1 +3,1 @@
-target
+changed
*** End Patch
"""
))
assert "update repeated.txt" in result
assert target.read_text() == "target\nmiddle\nchanged\n"
def test_apply_patch_line_hint_does_not_fallback_to_earlier_match(tmp_path):
target = tmp_path / "repeated.txt"
target.write_text("target\nmiddle\nother\n")
tool = ApplyPatchTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
patch="""*** Begin Patch
*** Update File: repeated.txt
@@ -3,1 +3,1 @@
-target
+changed
*** End Patch
"""
))
assert "hunk does not match repeated.txt" in result
assert target.read_text() == "target\nmiddle\nother\n"
def test_apply_patch_mismatch_reports_best_match(tmp_path):
target = tmp_path / "near.txt"
target.write_text("alpha\nbeta\ngamma\n")
tool = ApplyPatchTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
patch="""*** Begin Patch
*** Update File: near.txt
@@ -2,1 +2,1 @@
-betx
+delta
*** End Patch
"""
))
assert "hunk does not match near.txt" in result
assert "Best match" in result
assert "line 2" in result
assert target.read_text() == "alpha\nbeta\ngamma\n"
def test_apply_patch_moves_and_updates_file(tmp_path):
source = tmp_path / "old" / "name.txt"
source.parent.mkdir()
source.write_text("old content\n")
tool = ApplyPatchTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
patch="""*** Begin Patch
*** Update File: old/name.txt
*** Move to: renamed/dir/name.txt
@@
-old content
+new content
*** End Patch
"""
))
assert "move old/name.txt -> renamed/dir/name.txt" in result
assert not source.exists()
assert (tmp_path / "renamed" / "dir" / "name.txt").read_text() == "new content\n"
def test_apply_patch_deletes_file(tmp_path):
target = tmp_path / "obsolete.txt"
target.write_text("remove me\n")
tool = ApplyPatchTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
patch="""*** Begin Patch
*** Delete File: obsolete.txt
*** End Patch
"""
))
assert "delete obsolete.txt" in result
assert not target.exists()
def test_apply_patch_rejects_absolute_and_parent_paths(tmp_path):
tool = ApplyPatchTool(workspace=tmp_path)
absolute = asyncio.run(tool.execute(
patch="""*** Begin Patch
*** Add File: /tmp/owned.txt
+nope
*** End Patch
"""
))
parent = asyncio.run(tool.execute(
patch="""*** Begin Patch
*** Add File: ../owned.txt
+nope
*** End Patch
"""
))
assert "must be relative" in absolute
assert "must not contain '..'" in parent
assert not (tmp_path.parent / "owned.txt").exists()
def test_apply_patch_does_not_overwrite_existing_file_with_add(tmp_path):
target = tmp_path / "existing.txt"
target.write_text("keep me\n")
tool = ApplyPatchTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
patch="""*** Begin Patch
*** Add File: existing.txt
+replace me
*** End Patch
"""
))
assert "file to add already exists" in result
assert target.read_text() == "keep me\n"
def test_apply_patch_rolls_back_when_late_operation_fails(tmp_path):
first = tmp_path / "first.txt"
first.write_text("before\n")
tool = ApplyPatchTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
patch="""*** Begin Patch
*** Update File: first.txt
@@
-before
+after
*** Delete File: missing.txt
*** End Patch
"""
))
assert "file to delete does not exist" in result
assert first.read_text() == "before\n"

View File

@ -1,5 +1,5 @@
"""Tests for EditFileTool enhancements: read-before-edit tracking, path suggestions, """Tests for EditFileTool enhancements: read-before-edit tracking, path suggestions,
.ipynb detection, and create-file semantics.""" notebook JSON editing, and create-file semantics."""
import pytest import pytest
@ -108,22 +108,27 @@ class TestEditCreateFile:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# .ipynb detection # .ipynb editing
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestEditIpynbDetection: class TestEditIpynbFiles:
"""edit_file should refuse .ipynb and suggest notebook_edit.""" """edit_file edits notebooks as normal JSON files."""
@pytest.fixture() @pytest.fixture()
def tool(self, tmp_path): def tool(self, tmp_path):
return EditFileTool(workspace=tmp_path) return EditFileTool(workspace=tmp_path)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_ipynb_rejected_with_suggestion(self, tool, tmp_path): async def test_ipynb_can_be_edited_as_json(self, tool, tmp_path):
f = tmp_path / "analysis.ipynb" f = tmp_path / "analysis.ipynb"
f.write_text('{"cells": []}', encoding="utf-8") f.write_text('{"cells": []}', encoding="utf-8")
result = await tool.execute(path=str(f), old_text="x", new_text="y") result = await tool.execute(
assert "notebook" in result.lower() path=str(f),
old_text='"cells": []',
new_text='"cells": [{"cell_type": "markdown", "source": "hi"}]',
)
assert "Successfully edited" in result
assert '"source": "hi"' in f.read_text(encoding="utf-8")
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@ -0,0 +1,242 @@
from __future__ import annotations
import asyncio
import re
import shlex
import subprocess
import sys
from nanobot.agent.tools.shell import ExecTool
from nanobot.agent.tools.exec_session import ExecSessionManager, WriteStdinTool
def _python_command(code: str) -> str:
if sys.platform == "win32":
return f"{subprocess.list2cmdline([sys.executable])} -u -c {subprocess.list2cmdline([code])}"
return f"{shlex.quote(sys.executable)} -u -c {shlex.quote(code)}"
def _session_id(output: str) -> str:
match = re.search(r"session_id:\s*([0-9a-f]+)", output)
assert match, output
return match.group(1)
def test_exec_keeps_one_shot_behavior_without_yield_time_ms(tmp_path):
async def run() -> str:
tool = ExecTool(working_dir=str(tmp_path), timeout=5)
return await tool.execute(command="echo hello")
result = asyncio.run(run())
assert "hello" in result
assert "Exit code: 0" in result
assert "session_id:" not in result
def test_exec_accepts_command_aliases(tmp_path):
async def run() -> str:
tool = ExecTool(working_dir="/")
return await tool.execute(cmd="pwd", workdir=str(tmp_path))
result = asyncio.run(run())
assert str(tmp_path) in result
assert "Exit code: 0" in result
def test_exec_returns_completed_session_output_when_yield_time_ms_is_used(tmp_path):
async def run() -> str:
manager = ExecSessionManager()
tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
stdin_tool = WriteStdinTool(manager=manager)
result = await tool.execute(command="echo hello", yield_time_ms=1000)
if "session_id:" in result:
sid = _session_id(result)
result += "\n" + await stdin_tool.execute(
session_id=sid,
chars="",
yield_time_ms=1000,
)
return result
result = asyncio.run(run())
assert "hello" in result
assert "Exit code: 0" in result
assert "session_id:" not in result
def test_exec_session_accepts_max_output_tokens_alias(tmp_path):
async def run() -> str:
manager = ExecSessionManager()
tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
command = _python_command("print('A' * 2000)")
return await tool.execute(
command=command,
yield_time_ms=1000,
max_output_tokens=1000,
)
result = asyncio.run(run())
assert "chars truncated" in result
assert "Exit code: 0" in result
def test_exec_one_shot_accepts_max_output_tokens_alias(tmp_path):
async def run() -> str:
tool = ExecTool(working_dir=str(tmp_path), timeout=5)
command = _python_command("print('A' * 2000)")
return await tool.execute(command=command, max_output_tokens=1000)
result = asyncio.run(run())
assert "chars truncated" in result
assert "Exit code: 0" in result
def test_exec_accepts_supported_shell_parameter(tmp_path):
async def run() -> str:
tool = ExecTool(working_dir=str(tmp_path), timeout=5)
return await tool.execute(command="echo shell-ok", shell="sh", login=False)
if sys.platform == "win32":
return
result = asyncio.run(run())
assert "shell-ok" in result
assert "Exit code: 0" in result
def test_exec_rejects_unsupported_shell(tmp_path):
async def run() -> str:
tool = ExecTool(working_dir=str(tmp_path), timeout=5)
return await tool.execute(command="echo no", shell="python")
if sys.platform == "win32":
return
result = asyncio.run(run())
assert "unsupported shell" in result
def test_exec_can_continue_with_stdin(tmp_path):
async def run() -> tuple[str, str]:
manager = ExecSessionManager()
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
stdin_tool = WriteStdinTool(manager=manager)
command = _python_command(
"import sys; print('ready', flush=True); "
"line=sys.stdin.readline(); print('got:' + line.strip(), flush=True)"
)
initial = await exec_tool.execute(command=command, yield_time_ms=500)
sid = _session_id(initial)
result = await stdin_tool.execute(session_id=sid, chars="ping\n", yield_time_ms=1000)
return initial, result
initial, result = asyncio.run(run())
assert "ready" in initial
assert "Process running" in initial
assert "got:ping" in result
assert "Exit code: 0" in result
def test_write_stdin_can_close_stdin(tmp_path):
async def run() -> tuple[str, str]:
manager = ExecSessionManager()
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
stdin_tool = WriteStdinTool(manager=manager)
command = _python_command(
"import sys; print('ready', flush=True); "
"data=sys.stdin.read(); print('got:' + data, flush=True)"
)
initial = await exec_tool.execute(command=command, yield_time_ms=500)
sid = _session_id(initial)
result = await stdin_tool.execute(
session_id=sid,
chars="payload",
close_stdin=True,
yield_time_ms=1000,
)
return initial, result
initial, result = asyncio.run(run())
assert "ready" in initial
assert "got:payload" in result
assert "Stdin closed." in result
assert "Exit code: 0" in result
def test_write_stdin_can_terminate_session(tmp_path):
async def run() -> tuple[str, str]:
manager = ExecSessionManager()
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=30, session_manager=manager)
stdin_tool = WriteStdinTool(manager=manager)
command = _python_command(
"import time; print('ready', flush=True); time.sleep(30)"
)
initial = await exec_tool.execute(command=command, yield_time_ms=500)
sid = _session_id(initial)
result = await stdin_tool.execute(
session_id=sid,
terminate=True,
yield_time_ms=0,
)
return initial, result
initial, result = asyncio.run(run())
assert "ready" in initial
assert "Session terminated." in result
assert "Exit code:" in result
def test_write_stdin_accepts_max_output_tokens_alias(tmp_path):
async def run() -> tuple[str, str, str]:
manager = ExecSessionManager()
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
stdin_tool = WriteStdinTool(manager=manager)
command = _python_command(
"import time; print('A' * 2000, flush=True); time.sleep(5)"
)
initial = await exec_tool.execute(command=command, yield_time_ms=0)
sid = _session_id(initial)
poll = await stdin_tool.execute(
session_id=sid,
yield_time_ms=500,
max_output_tokens=1000,
)
cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0)
return initial, poll, cleanup
initial, poll, cleanup = asyncio.run(run())
assert "Process running" in initial
assert "chars truncated" in poll
assert "Session terminated." in cleanup
def test_exec_session_mode_reuses_exec_safety_guard(tmp_path):
manager = ExecSessionManager()
tool = ExecTool(
working_dir=str(tmp_path),
deny_patterns=[r"echo\s+blocked"],
session_manager=manager,
)
result = asyncio.run(tool.execute(command="echo blocked", yield_time_ms=0))
assert "blocked by deny pattern" in result
def test_write_stdin_reports_missing_session(tmp_path):
manager = ExecSessionManager()
tool = WriteStdinTool(manager=manager)
result = asyncio.run(tool.execute(session_id="missing", chars=""))
assert "exec session not found" in result

View File

@ -0,0 +1,216 @@
from __future__ import annotations
import asyncio
from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool
def test_read_file_force_bypasses_dedup(tmp_path):
target = tmp_path / "data.txt"
target.write_text("alpha\n")
tool = ReadFileTool(workspace=tmp_path)
first = asyncio.run(tool.execute(path=str(target)))
second = asyncio.run(tool.execute(path=str(target)))
forced = asyncio.run(tool.execute(path=str(target), force=True))
assert "alpha" in first
assert "unchanged" in second.lower()
assert "alpha" in forced
assert "unchanged" not in forced.lower()
def test_edit_file_can_select_occurrence(tmp_path):
target = tmp_path / "duplicate.txt"
target.write_text("one\nsame\ntwo\nsame\n")
tool = EditFileTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
path=str(target),
old_text="same",
new_text="changed",
occurrence=2,
))
assert "Successfully edited" in result
assert target.read_text() == "one\nsame\ntwo\nchanged\n"
def test_edit_file_expected_replacements_guards_replace_all(tmp_path):
target = tmp_path / "duplicate.txt"
target.write_text("same\nsame\n")
tool = EditFileTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
path=str(target),
old_text="same",
new_text="changed",
replace_all=True,
expected_replacements=1,
))
assert "expected 1 replacements but would make 2" in result
assert target.read_text() == "same\nsame\n"
def test_edit_file_expected_replacements_allows_replace_all_when_count_matches(tmp_path):
target = tmp_path / "duplicate.txt"
target.write_text("same\nsame\n")
tool = EditFileTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
path=str(target),
old_text="same",
new_text="changed",
replace_all=True,
expected_replacements=2,
))
assert "Successfully edited" in result
assert target.read_text() == "changed\nchanged\n"
def test_edit_file_can_select_nearest_line_hint(tmp_path):
target = tmp_path / "duplicate.txt"
target.write_text("one\nsame\ntwo\nsame\n")
tool = EditFileTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
path=str(target),
old_text="same",
new_text="changed",
line_hint=4,
))
assert "Successfully edited" in result
assert target.read_text() == "one\nsame\ntwo\nchanged\n"
def test_edit_file_can_edit_ipynb_as_json(tmp_path):
target = tmp_path / "analysis.ipynb"
target.write_text('{"cells": []}')
tool = EditFileTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
path=str(target),
old_text='"cells": []',
new_text='"cells": [{"cell_type": "markdown", "source": "hi"}]',
))
assert "Successfully edited" in result
assert '"source": "hi"' in target.read_text()
def test_edit_file_multiple_match_hint_mentions_occurrence(tmp_path):
target = tmp_path / "duplicate.txt"
target.write_text("same\nsame\n")
tool = EditFileTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
path=str(target),
old_text="same",
new_text="changed",
))
assert "old_text appears 2 times" in result
assert "occurrence" in result
assert target.read_text() == "same\nsame\n"
def test_edit_file_rejects_ambiguous_line_hint(tmp_path):
target = tmp_path / "duplicate.txt"
target.write_text("same\nmiddle\nsame\n")
tool = EditFileTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
path=str(target),
old_text="same",
new_text="changed",
line_hint=2,
))
assert "line_hint 2 is ambiguous" in result
assert target.read_text() == "same\nmiddle\nsame\n"
def test_edit_file_rejects_occurrence_with_replace_all(tmp_path):
target = tmp_path / "duplicate.txt"
target.write_text("same\nsame\n")
tool = EditFileTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
path=str(target),
old_text="same",
new_text="changed",
occurrence=1,
replace_all=True,
))
assert "occurrence cannot be used with replace_all" in result
assert target.read_text() == "same\nsame\n"
def test_edit_file_rejects_line_hint_with_replace_all(tmp_path):
target = tmp_path / "duplicate.txt"
target.write_text("same\nsame\n")
tool = EditFileTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
path=str(target),
old_text="same",
new_text="changed",
line_hint=1,
replace_all=True,
))
assert "line_hint cannot be used with replace_all" in result
assert target.read_text() == "same\nsame\n"
def test_edit_file_rejects_line_hint_with_occurrence(tmp_path):
target = tmp_path / "duplicate.txt"
target.write_text("same\nsame\n")
tool = EditFileTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
path=str(target),
old_text="same",
new_text="changed",
occurrence=1,
line_hint=1,
))
assert "line_hint cannot be used with occurrence" in result
assert target.read_text() == "same\nsame\n"
def test_edit_file_rejects_zero_occurrence(tmp_path):
target = tmp_path / "duplicate.txt"
target.write_text("same\n")
tool = EditFileTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
path=str(target),
old_text="same",
new_text="changed",
occurrence=0,
))
assert "occurrence must be >= 1" in result
assert target.read_text() == "same\n"
def test_edit_file_rejects_zero_line_hint(tmp_path):
target = tmp_path / "duplicate.txt"
target.write_text("same\n")
tool = EditFileTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
path=str(target),
old_text="same",
new_text="changed",
line_hint=0,
))
assert "line_hint must be >= 1" in result
assert target.read_text() == "same\n"

View File

@ -1,147 +0,0 @@
"""Tests for NotebookEditTool — Jupyter .ipynb editing."""
import json
import pytest
from nanobot.agent.tools.notebook import NotebookEditTool
def _make_notebook(cells: list[dict] | None = None, nbformat: int = 4, nbformat_minor: int = 5) -> dict:
"""Build a minimal valid .ipynb structure."""
return {
"nbformat": nbformat,
"nbformat_minor": nbformat_minor,
"metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}},
"cells": cells or [],
}
def _code_cell(source: str, cell_id: str | None = None) -> dict:
cell = {"cell_type": "code", "source": source, "metadata": {}, "outputs": [], "execution_count": None}
if cell_id:
cell["id"] = cell_id
return cell
def _md_cell(source: str, cell_id: str | None = None) -> dict:
cell = {"cell_type": "markdown", "source": source, "metadata": {}}
if cell_id:
cell["id"] = cell_id
return cell
def _write_nb(tmp_path, name: str, nb: dict) -> str:
p = tmp_path / name
p.write_text(json.dumps(nb), encoding="utf-8")
return str(p)
class TestNotebookEdit:
@pytest.fixture()
def tool(self, tmp_path):
return NotebookEditTool(workspace=tmp_path)
@pytest.mark.asyncio
async def test_replace_cell_content(self, tool, tmp_path):
nb = _make_notebook([_code_cell("print('hello')"), _code_cell("x = 1")])
path = _write_nb(tmp_path, "test.ipynb", nb)
result = await tool.execute(path=path, cell_index=0, new_source="print('world')")
assert "Successfully" in result
saved = json.loads((tmp_path / "test.ipynb").read_text())
assert saved["cells"][0]["source"] == "print('world')"
assert saved["cells"][1]["source"] == "x = 1"
@pytest.mark.asyncio
async def test_insert_cell_after_target(self, tool, tmp_path):
nb = _make_notebook([_code_cell("cell 0"), _code_cell("cell 1")])
path = _write_nb(tmp_path, "test.ipynb", nb)
result = await tool.execute(path=path, cell_index=0, new_source="inserted", edit_mode="insert")
assert "Successfully" in result
saved = json.loads((tmp_path / "test.ipynb").read_text())
assert len(saved["cells"]) == 3
assert saved["cells"][0]["source"] == "cell 0"
assert saved["cells"][1]["source"] == "inserted"
assert saved["cells"][2]["source"] == "cell 1"
@pytest.mark.asyncio
async def test_delete_cell(self, tool, tmp_path):
nb = _make_notebook([_code_cell("A"), _code_cell("B"), _code_cell("C")])
path = _write_nb(tmp_path, "test.ipynb", nb)
result = await tool.execute(path=path, cell_index=1, edit_mode="delete")
assert "Successfully" in result
saved = json.loads((tmp_path / "test.ipynb").read_text())
assert len(saved["cells"]) == 2
assert saved["cells"][0]["source"] == "A"
assert saved["cells"][1]["source"] == "C"
@pytest.mark.asyncio
async def test_create_new_notebook_from_scratch(self, tool, tmp_path):
path = str(tmp_path / "new.ipynb")
result = await tool.execute(path=path, cell_index=0, new_source="# Hello", edit_mode="insert", cell_type="markdown")
assert "Successfully" in result or "created" in result.lower()
saved = json.loads((tmp_path / "new.ipynb").read_text())
assert saved["nbformat"] == 4
assert len(saved["cells"]) == 1
assert saved["cells"][0]["cell_type"] == "markdown"
assert saved["cells"][0]["source"] == "# Hello"
@pytest.mark.asyncio
async def test_invalid_cell_index_error(self, tool, tmp_path):
nb = _make_notebook([_code_cell("only cell")])
path = _write_nb(tmp_path, "test.ipynb", nb)
result = await tool.execute(path=path, cell_index=5, new_source="x")
assert "Error" in result
@pytest.mark.asyncio
async def test_non_ipynb_rejected(self, tool, tmp_path):
f = tmp_path / "script.py"
f.write_text("pass")
result = await tool.execute(path=str(f), cell_index=0, new_source="x")
assert "Error" in result
assert ".ipynb" in result
@pytest.mark.asyncio
async def test_preserves_metadata_and_outputs(self, tool, tmp_path):
cell = _code_cell("old")
cell["outputs"] = [{"output_type": "stream", "text": "hello\n"}]
cell["execution_count"] = 42
nb = _make_notebook([cell])
path = _write_nb(tmp_path, "test.ipynb", nb)
await tool.execute(path=path, cell_index=0, new_source="new")
saved = json.loads((tmp_path / "test.ipynb").read_text())
assert saved["metadata"]["kernelspec"]["language"] == "python"
@pytest.mark.asyncio
async def test_nbformat_45_generates_cell_id(self, tool, tmp_path):
nb = _make_notebook([], nbformat_minor=5)
path = _write_nb(tmp_path, "test.ipynb", nb)
await tool.execute(path=path, cell_index=0, new_source="x = 1", edit_mode="insert")
saved = json.loads((tmp_path / "test.ipynb").read_text())
assert "id" in saved["cells"][0]
assert len(saved["cells"][0]["id"]) > 0
@pytest.mark.asyncio
async def test_insert_with_cell_type_markdown(self, tool, tmp_path):
nb = _make_notebook([_code_cell("code")])
path = _write_nb(tmp_path, "test.ipynb", nb)
await tool.execute(path=path, cell_index=0, new_source="# Title", edit_mode="insert", cell_type="markdown")
saved = json.loads((tmp_path / "test.ipynb").read_text())
assert saved["cells"][1]["cell_type"] == "markdown"
@pytest.mark.asyncio
async def test_invalid_edit_mode_rejected(self, tool, tmp_path):
nb = _make_notebook([_code_cell("code")])
path = _write_nb(tmp_path, "test.ipynb", nb)
result = await tool.execute(path=path, cell_index=0, new_source="x", edit_mode="replcae")
assert "Error" in result
assert "edit_mode" in result
@pytest.mark.asyncio
async def test_invalid_cell_type_rejected(self, tool, tmp_path):
nb = _make_notebook([_code_cell("code")])
path = _write_nb(tmp_path, "test.ipynb", nb)
result = await tool.execute(path=path, cell_index=0, new_source="x", cell_type="raw")
assert "Error" in result
assert "cell_type" in result

View File

@ -89,9 +89,11 @@ def test_discover_finds_concrete_tools():
loader = ToolLoader() loader = ToolLoader()
discovered = loader.discover() discovered = loader.discover()
class_names = {cls.__name__ for cls in discovered} class_names = {cls.__name__ for cls in discovered}
assert "ApplyPatchTool" in class_names
assert "ExecTool" in class_names assert "ExecTool" in class_names
assert "MessageTool" in class_names assert "MessageTool" in class_names
assert "SpawnTool" in class_names assert "SpawnTool" in class_names
assert "WriteStdinTool" in class_names
def test_discover_excludes_abstract_and_mcp(): def test_discover_excludes_abstract_and_mcp():
@ -406,7 +408,7 @@ def test_loader_registers_same_tools_as_old_hardcoded():
expected = { expected = {
"read_file", "write_file", "edit_file", "list_dir", "read_file", "write_file", "edit_file", "list_dir",
"grep", "notebook_edit", "exec", "web_search", "web_fetch", "grep", "exec", "web_search", "web_fetch",
"message", "spawn", "cron", "message", "spawn", "cron",
} }
actual = set(registered) actual = set(registered)