mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-21 17:12:32 +00:00
feat(tools): optimize coding workflows
This commit is contained in:
parent
eae51333ad
commit
6851fa57a6
@ -61,7 +61,7 @@
|
||||
- **2026-04-13** 🛡️ Agent turn hardened — user messages persisted early, auto-compact skips active tasks.
|
||||
- **2026-04-12** 🔒 Lark global domain support, Dream learns discovered skills, shell sandbox tightened.
|
||||
- **2026-04-11** ⚡ Context compact shrinks sessions on the fly; Kagi web search; QQ & WeCom full media.
|
||||
- **2026-04-10** 📓 Notebook editing tool, multiple MCP servers, Feishu streaming & done-emoji.
|
||||
- **2026-04-10** 📓 Multiple MCP servers, Feishu streaming & done-emoji.
|
||||
- **2026-04-09** 🔌 WebSocket channel, unified cross-channel session, `disabled_skills` config.
|
||||
- **2026-04-08** 📤 API file uploads, OpenAI reasoning auto-routing with Responses fallback.
|
||||
- **2026-04-07** 🧠 Anthropic adaptive thinking, MCP resources & prompts exposed as tools.
|
||||
|
||||
341
nanobot/agent/tools/apply_patch.py
Normal file
341
nanobot/agent/tools/apply_patch.py
Normal file
@ -0,0 +1,341 @@
|
||||
"""Structured patch editing tool for coding workflows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import difflib
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
from nanobot.agent.tools.base import tool_parameters
|
||||
from nanobot.agent.tools.filesystem import _FsTool
|
||||
from nanobot.agent.tools.schema import StringSchema, tool_parameters_schema
|
||||
|
||||
|
||||
PatchKind = Literal["add", "delete", "update"]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _Hunk:
|
||||
header: str | None
|
||||
lines: list[tuple[str, str]]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _PatchOp:
|
||||
kind: PatchKind
|
||||
path: str
|
||||
new_path: str | None = None
|
||||
add_lines: list[str] | None = None
|
||||
hunks: list[_Hunk] | None = None
|
||||
|
||||
|
||||
class _PatchError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
_ABSOLUTE_WINDOWS_RE = re.compile(r"^[A-Za-z]:[\\/]")
|
||||
|
||||
|
||||
def _is_file_header(line: str) -> bool:
|
||||
return (
|
||||
line.startswith("*** Add File: ")
|
||||
or line.startswith("*** Delete File: ")
|
||||
or line.startswith("*** Update File: ")
|
||||
)
|
||||
|
||||
|
||||
def _validate_relative_path(path: str) -> str:
|
||||
normalized = path.strip()
|
||||
if not normalized:
|
||||
raise _PatchError("patch path cannot be empty")
|
||||
if "\0" in normalized:
|
||||
raise _PatchError(f"patch path contains a null byte: {path!r}")
|
||||
if normalized.startswith(("~", "/", "\\")) or _ABSOLUTE_WINDOWS_RE.match(normalized):
|
||||
raise _PatchError(f"patch path must be relative: {path}")
|
||||
if any(part == ".." for part in re.split(r"[\\/]+", normalized)):
|
||||
raise _PatchError(f"patch path must not contain '..': {path}")
|
||||
return normalized
|
||||
|
||||
|
||||
def _lines_to_text(lines: list[str]) -> str:
|
||||
if not lines:
|
||||
return ""
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
def _parse_patch(patch: str) -> list[_PatchOp]:
|
||||
lines = patch.replace("\r\n", "\n").replace("\r", "\n").split("\n")
|
||||
if lines and lines[-1] == "":
|
||||
lines.pop()
|
||||
if not lines or lines[0] != "*** Begin Patch":
|
||||
raise _PatchError("patch must start with '*** Begin Patch'")
|
||||
if len(lines) < 2 or lines[-1] != "*** End Patch":
|
||||
raise _PatchError("patch must end with '*** End Patch'")
|
||||
|
||||
ops: list[_PatchOp] = []
|
||||
i = 1
|
||||
end = len(lines) - 1
|
||||
while i < end:
|
||||
line = lines[i]
|
||||
if line.startswith("*** Add File: "):
|
||||
path = _validate_relative_path(line.removeprefix("*** Add File: "))
|
||||
i += 1
|
||||
add_lines: list[str] = []
|
||||
while i < end and not _is_file_header(lines[i]):
|
||||
if not lines[i].startswith("+"):
|
||||
raise _PatchError(f"Add File lines must start with '+': {lines[i]!r}")
|
||||
add_lines.append(lines[i][1:])
|
||||
i += 1
|
||||
ops.append(_PatchOp(kind="add", path=path, add_lines=add_lines))
|
||||
continue
|
||||
|
||||
if line.startswith("*** Delete File: "):
|
||||
path = _validate_relative_path(line.removeprefix("*** Delete File: "))
|
||||
ops.append(_PatchOp(kind="delete", path=path))
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if line.startswith("*** Update File: "):
|
||||
path = _validate_relative_path(line.removeprefix("*** Update File: "))
|
||||
i += 1
|
||||
new_path: str | None = None
|
||||
if i < end and lines[i].startswith("*** Move to: "):
|
||||
new_path = _validate_relative_path(lines[i].removeprefix("*** Move to: "))
|
||||
i += 1
|
||||
|
||||
hunks: list[_Hunk] = []
|
||||
while i < end and not _is_file_header(lines[i]):
|
||||
if not lines[i].startswith("@@"):
|
||||
raise _PatchError(f"Update File sections require '@@' hunks: {lines[i]!r}")
|
||||
header = lines[i][2:].strip() or None
|
||||
i += 1
|
||||
hunk_lines: list[tuple[str, str]] = []
|
||||
while i < end and not lines[i].startswith("@@") and not _is_file_header(lines[i]):
|
||||
if lines[i] == "*** End of File":
|
||||
i += 1
|
||||
break
|
||||
if lines[i] == r"\ No newline at end of file":
|
||||
i += 1
|
||||
continue
|
||||
if not lines[i] or lines[i][0] not in {" ", "+", "-"}:
|
||||
raise _PatchError(f"Hunk lines must start with ' ', '+', or '-': {lines[i]!r}")
|
||||
hunk_lines.append((lines[i][0], lines[i][1:]))
|
||||
i += 1
|
||||
if not hunk_lines:
|
||||
raise _PatchError(f"Update File hunk is empty: {path}")
|
||||
hunks.append(_Hunk(header=header, lines=hunk_lines))
|
||||
if not hunks and new_path is None:
|
||||
raise _PatchError(f"Update File requires at least one hunk or Move to: {path}")
|
||||
ops.append(_PatchOp(kind="update", path=path, new_path=new_path, hunks=hunks))
|
||||
continue
|
||||
|
||||
raise _PatchError(f"unknown patch header: {line!r}")
|
||||
|
||||
if not ops:
|
||||
raise _PatchError("patch contains no file operations")
|
||||
return ops
|
||||
|
||||
|
||||
def _find_with_eof_fallback(content: str, needle: str, start: int) -> tuple[int, int]:
|
||||
pos = content.find(needle, start)
|
||||
if pos >= 0:
|
||||
return pos, len(needle)
|
||||
if needle.endswith("\n"):
|
||||
trimmed = needle[:-1]
|
||||
pos = content.find(trimmed, start)
|
||||
if pos >= 0 and pos + len(trimmed) == len(content):
|
||||
return pos, len(trimmed)
|
||||
return -1, 0
|
||||
|
||||
|
||||
def _line_offset(content: str, line_number: int) -> int:
|
||||
if line_number <= 1:
|
||||
return 0
|
||||
offset = 0
|
||||
for current, line in enumerate(content.splitlines(keepends=True), start=1):
|
||||
if current >= line_number:
|
||||
return offset
|
||||
offset += len(line)
|
||||
return len(content)
|
||||
|
||||
|
||||
def _line_hint(header: str | None) -> int | None:
|
||||
if not header:
|
||||
return None
|
||||
match = re.search(r"-(\d+)(?:,\d+)?", header)
|
||||
return int(match.group(1)) if match else None
|
||||
|
||||
|
||||
def _hunk_mismatch(path: str, old_text: str, content: str, header: str | None) -> str:
|
||||
lines = content.splitlines(keepends=True)
|
||||
old_lines = old_text.splitlines(keepends=True)
|
||||
window = max(1, len(old_lines))
|
||||
best_ratio, best_start = -1.0, 0
|
||||
best_lines: list[str] = []
|
||||
for i in range(max(1, len(lines) - window + 1)):
|
||||
current = lines[i : i + window]
|
||||
ratio = difflib.SequenceMatcher(None, "".join(old_lines), "".join(current)).ratio()
|
||||
if ratio > best_ratio:
|
||||
best_ratio, best_start, best_lines = ratio, i, current
|
||||
|
||||
label = f" after header {header!r}" if header else ""
|
||||
if best_ratio <= 0:
|
||||
return f"hunk does not match {path}{label}"
|
||||
diff = "\n".join(difflib.unified_diff(
|
||||
old_lines,
|
||||
best_lines,
|
||||
fromfile="patch hunk",
|
||||
tofile=f"{path} (actual, line {best_start + 1})",
|
||||
lineterm="",
|
||||
))
|
||||
return (
|
||||
f"hunk does not match {path}{label}. "
|
||||
f"Best match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
|
||||
)
|
||||
|
||||
|
||||
def _apply_hunks(path: str, content: str, hunks: list[_Hunk]) -> str:
|
||||
cursor = 0
|
||||
for hunk in hunks:
|
||||
old_lines = [text for marker, text in hunk.lines if marker in {" ", "-"}]
|
||||
new_lines = [text for marker, text in hunk.lines if marker in {" ", "+"}]
|
||||
old_text = _lines_to_text(old_lines)
|
||||
new_text = _lines_to_text(new_lines)
|
||||
|
||||
search_start = cursor
|
||||
line_hint = None
|
||||
if hunk.header:
|
||||
line_hint = _line_hint(hunk.header)
|
||||
if line_hint is not None:
|
||||
search_start = _line_offset(content, line_hint)
|
||||
else:
|
||||
header_pos = content.find(hunk.header, cursor)
|
||||
if header_pos >= 0:
|
||||
search_start = header_pos
|
||||
|
||||
if old_text:
|
||||
pos, match_len = _find_with_eof_fallback(content, old_text, search_start)
|
||||
if pos < 0 and search_start != 0 and line_hint is None:
|
||||
pos, match_len = _find_with_eof_fallback(content, old_text, 0)
|
||||
if pos < 0:
|
||||
raise _PatchError(_hunk_mismatch(path, old_text, content, hunk.header))
|
||||
else:
|
||||
pos = search_start
|
||||
match_len = 0
|
||||
|
||||
content = content[:pos] + new_text + content[pos + match_len:]
|
||||
cursor = pos + len(new_text)
|
||||
return content
|
||||
|
||||
|
||||
@tool_parameters(
|
||||
tool_parameters_schema(
|
||||
patch=StringSchema(
|
||||
"Full patch text. Use *** Begin Patch / *** End Patch and file sections "
|
||||
"for Add File, Update File, Delete File, and optional Move to.",
|
||||
min_length=1,
|
||||
),
|
||||
required=["patch"],
|
||||
)
|
||||
)
|
||||
class ApplyPatchTool(_FsTool):
|
||||
"""Apply a structured multi-file patch."""
|
||||
_scopes = {"core", "subagent"}
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "apply_patch"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Apply a structured patch for code edits. The patch must include "
|
||||
"*** Begin Patch and *** End Patch. Supports Add File, Update File, "
|
||||
"Delete File, and Move to. Paths must be relative. Prefer this for "
|
||||
"multi-file coding changes; use edit_file for small exact replacements."
|
||||
)
|
||||
|
||||
async def execute(self, patch: str, **kwargs: Any) -> str:
|
||||
try:
|
||||
ops = _parse_patch(patch)
|
||||
writes: dict[Path, str] = {}
|
||||
deletes: set[Path] = set()
|
||||
touched: list[str] = []
|
||||
|
||||
for op in ops:
|
||||
source = self._resolve(op.path)
|
||||
if op.kind == "add":
|
||||
if source.exists():
|
||||
raise _PatchError(f"file to add already exists: {op.path}")
|
||||
writes[source] = _lines_to_text(op.add_lines or [])
|
||||
deletes.discard(source)
|
||||
touched.append(f"add {op.path}")
|
||||
continue
|
||||
|
||||
if op.kind == "delete":
|
||||
if not source.exists():
|
||||
raise _PatchError(f"file to delete does not exist: {op.path}")
|
||||
if not source.is_file():
|
||||
raise _PatchError(f"path to delete is not a file: {op.path}")
|
||||
deletes.add(source)
|
||||
writes.pop(source, None)
|
||||
touched.append(f"delete {op.path}")
|
||||
continue
|
||||
|
||||
if not source.exists():
|
||||
raise _PatchError(f"file to update does not exist: {op.path}")
|
||||
if not source.is_file():
|
||||
raise _PatchError(f"path to update is not a file: {op.path}")
|
||||
raw = source.read_bytes()
|
||||
try:
|
||||
content = raw.decode("utf-8")
|
||||
except UnicodeDecodeError as exc:
|
||||
raise _PatchError(f"file to update is not UTF-8 text: {op.path}") from exc
|
||||
uses_crlf = "\r\n" in content
|
||||
content = content.replace("\r\n", "\n")
|
||||
new_content = _apply_hunks(op.path, content, op.hunks or [])
|
||||
if uses_crlf:
|
||||
new_content = new_content.replace("\n", "\r\n")
|
||||
|
||||
target = self._resolve(op.new_path) if op.new_path else source
|
||||
if op.new_path and target.exists() and target != source:
|
||||
raise _PatchError(f"move target already exists: {op.new_path}")
|
||||
writes[target] = new_content
|
||||
deletes.discard(target)
|
||||
if target != source:
|
||||
deletes.add(source)
|
||||
action = f"move {op.path} -> {op.new_path}" if op.new_path else f"update {op.path}"
|
||||
touched.append(action)
|
||||
|
||||
backups: dict[Path, bytes | None] = {}
|
||||
for path in set(writes) | deletes:
|
||||
backups[path] = path.read_bytes() if path.exists() else None
|
||||
|
||||
try:
|
||||
for path in deletes:
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
for path, content in writes.items():
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(content, encoding="utf-8", newline="")
|
||||
except Exception:
|
||||
for path, data in backups.items():
|
||||
if data is None:
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
else:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_bytes(data)
|
||||
raise
|
||||
|
||||
for path in set(writes) | deletes:
|
||||
self._file_states.record_write(path)
|
||||
return "Patch applied:\n" + "\n".join(f"- {item}" for item in touched)
|
||||
except PermissionError as exc:
|
||||
return f"Error: {exc}"
|
||||
except _PatchError as exc:
|
||||
return f"Error applying patch: {exc}"
|
||||
except Exception as exc:
|
||||
return f"Error applying patch: {exc}"
|
||||
409
nanobot/agent/tools/exec_session.py
Normal file
409
nanobot/agent/tools/exec_session.py
Normal file
@ -0,0 +1,409 @@
|
||||
"""Session support for long-running exec workflows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import shutil
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||
from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema
|
||||
|
||||
|
||||
DEFAULT_YIELD_MS = 1000
|
||||
MAX_YIELD_MS = 30_000
|
||||
DEFAULT_MAX_OUTPUT_CHARS = 10_000
|
||||
MAX_OUTPUT_CHARS = 50_000
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _SessionPoll:
|
||||
output: str
|
||||
done: bool
|
||||
exit_code: int | None
|
||||
timed_out: bool = False
|
||||
terminated: bool = False
|
||||
stdin_closed: bool = False
|
||||
truncated_chars: int = 0
|
||||
|
||||
|
||||
class _ExecSession:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
process: asyncio.subprocess.Process,
|
||||
timeout: int,
|
||||
) -> None:
|
||||
self.session_id = session_id
|
||||
self.process = process
|
||||
self.deadline = time.monotonic() + timeout
|
||||
self.last_access = time.monotonic()
|
||||
self._chunks: list[str] = []
|
||||
self._lock = asyncio.Lock()
|
||||
self._timed_out = False
|
||||
self._stdout_task = asyncio.create_task(self._read_stream(process.stdout, ""))
|
||||
self._stderr_task = asyncio.create_task(self._read_stream(process.stderr, "STDERR:\n"))
|
||||
|
||||
async def _read_stream(
|
||||
self,
|
||||
stream: asyncio.StreamReader | None,
|
||||
prefix: str,
|
||||
) -> None:
|
||||
if stream is None:
|
||||
return
|
||||
first = True
|
||||
while True:
|
||||
chunk = await stream.read(4096)
|
||||
if not chunk:
|
||||
break
|
||||
text = chunk.decode("utf-8", errors="replace")
|
||||
if prefix and first:
|
||||
text = prefix + text
|
||||
first = False
|
||||
async with self._lock:
|
||||
self._chunks.append(text)
|
||||
|
||||
async def write(self, chars: str) -> str | None:
|
||||
if self.process.returncode is not None:
|
||||
return "session has already exited"
|
||||
if self.process.stdin is None:
|
||||
return "session stdin is not available"
|
||||
try:
|
||||
self.process.stdin.write(chars.encode("utf-8"))
|
||||
await self.process.stdin.drain()
|
||||
except (BrokenPipeError, ConnectionResetError):
|
||||
return "session stdin is closed"
|
||||
return None
|
||||
|
||||
async def close_stdin(self) -> str | None:
|
||||
if self.process.returncode is not None:
|
||||
return "session has already exited"
|
||||
if self.process.stdin is None:
|
||||
return "session stdin is not available"
|
||||
self.process.stdin.close()
|
||||
with suppress(BrokenPipeError, ConnectionResetError):
|
||||
await self.process.stdin.wait_closed()
|
||||
return None
|
||||
|
||||
async def poll(
|
||||
self,
|
||||
yield_time_ms: int,
|
||||
max_output_chars: int,
|
||||
*,
|
||||
terminated: bool = False,
|
||||
stdin_closed: bool = False,
|
||||
) -> _SessionPoll:
|
||||
self.last_access = time.monotonic()
|
||||
if yield_time_ms > 0 and self.process.returncode is None:
|
||||
await asyncio.sleep(min(yield_time_ms, MAX_YIELD_MS) / 1000)
|
||||
|
||||
if self.process.returncode is None and time.monotonic() >= self.deadline:
|
||||
self._timed_out = True
|
||||
await self.kill()
|
||||
|
||||
if self.process.returncode is not None:
|
||||
with suppress(asyncio.TimeoutError):
|
||||
await asyncio.wait_for(
|
||||
asyncio.gather(self._stdout_task, self._stderr_task),
|
||||
timeout=2.0,
|
||||
)
|
||||
|
||||
async with self._lock:
|
||||
output = "".join(self._chunks)
|
||||
self._chunks.clear()
|
||||
|
||||
output, truncated = _truncate_output(output, max_output_chars)
|
||||
return _SessionPoll(
|
||||
output=output,
|
||||
done=self.process.returncode is not None,
|
||||
exit_code=self.process.returncode,
|
||||
timed_out=self._timed_out,
|
||||
terminated=terminated,
|
||||
stdin_closed=stdin_closed,
|
||||
truncated_chars=truncated,
|
||||
)
|
||||
|
||||
async def kill(self) -> None:
|
||||
if self.process.returncode is not None:
|
||||
return
|
||||
self.process.kill()
|
||||
with suppress(asyncio.TimeoutError):
|
||||
await asyncio.wait_for(self.process.wait(), timeout=5.0)
|
||||
|
||||
|
||||
class ExecSessionManager:
|
||||
def __init__(self, *, max_sessions: int = 8, idle_timeout: int = 1800) -> None:
|
||||
self.max_sessions = max_sessions
|
||||
self.idle_timeout = idle_timeout
|
||||
self._sessions: dict[str, _ExecSession] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def start(
|
||||
self,
|
||||
*,
|
||||
command: str,
|
||||
cwd: str,
|
||||
env: dict[str, str],
|
||||
timeout: int,
|
||||
shell_program: str | None,
|
||||
login: bool,
|
||||
yield_time_ms: int,
|
||||
max_output_chars: int,
|
||||
) -> tuple[str, _SessionPoll]:
|
||||
async with self._lock:
|
||||
await self._cleanup_locked()
|
||||
if len(self._sessions) >= self.max_sessions:
|
||||
raise RuntimeError(f"maximum exec sessions reached ({self.max_sessions})")
|
||||
process = await self._spawn(command, cwd, env, shell_program, login)
|
||||
session_id = uuid.uuid4().hex[:12]
|
||||
session = _ExecSession(session_id=session_id, process=process, timeout=timeout)
|
||||
self._sessions[session_id] = session
|
||||
|
||||
poll = await session.poll(yield_time_ms, max_output_chars)
|
||||
if poll.done:
|
||||
async with self._lock:
|
||||
self._sessions.pop(session_id, None)
|
||||
return session_id, poll
|
||||
|
||||
async def write(
|
||||
self,
|
||||
*,
|
||||
session_id: str,
|
||||
chars: str | None,
|
||||
close_stdin: bool,
|
||||
terminate: bool,
|
||||
yield_time_ms: int,
|
||||
max_output_chars: int,
|
||||
) -> _SessionPoll:
|
||||
async with self._lock:
|
||||
await self._cleanup_locked()
|
||||
session = self._sessions.get(session_id)
|
||||
if session is None:
|
||||
raise KeyError(session_id)
|
||||
|
||||
if chars is not None:
|
||||
error = await session.write(chars)
|
||||
if error:
|
||||
raise RuntimeError(error)
|
||||
stdin_closed = False
|
||||
if close_stdin:
|
||||
error = await session.close_stdin()
|
||||
if error:
|
||||
raise RuntimeError(error)
|
||||
stdin_closed = True
|
||||
if terminate:
|
||||
await session.kill()
|
||||
poll = await session.poll(
|
||||
yield_time_ms,
|
||||
max_output_chars,
|
||||
terminated=terminate,
|
||||
stdin_closed=stdin_closed,
|
||||
)
|
||||
if poll.done:
|
||||
async with self._lock:
|
||||
self._sessions.pop(session_id, None)
|
||||
return poll
|
||||
|
||||
async def _cleanup_locked(self) -> None:
|
||||
now = time.monotonic()
|
||||
stale = [
|
||||
session_id
|
||||
for session_id, session in self._sessions.items()
|
||||
if session.process.returncode is not None
|
||||
or now - session.last_access > self.idle_timeout
|
||||
]
|
||||
for session_id in stale:
|
||||
session = self._sessions.pop(session_id)
|
||||
await session.kill()
|
||||
|
||||
async def _spawn(
|
||||
self,
|
||||
command: str,
|
||||
cwd: str,
|
||||
env: dict[str, str],
|
||||
shell_program: str | None,
|
||||
login: bool,
|
||||
) -> asyncio.subprocess.Process:
|
||||
from nanobot.agent.tools import shell
|
||||
|
||||
if shell._IS_WINDOWS:
|
||||
return await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
)
|
||||
shell_program = shell_program or shutil.which("bash") or "/bin/bash"
|
||||
args = [shell_program]
|
||||
if login and shell_program.rsplit("/", 1)[-1] in {"bash", "zsh"}:
|
||||
args.append("-l")
|
||||
args.extend(["-c", command])
|
||||
return await asyncio.create_subprocess_exec(
|
||||
*args,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_EXEC_SESSION_MANAGER = ExecSessionManager()
|
||||
|
||||
|
||||
def clamp_session_int(value: int | None, default: int, minimum: int, maximum: int) -> int:
|
||||
if value is None:
|
||||
return default
|
||||
return min(max(value, minimum), maximum)
|
||||
|
||||
|
||||
def _truncate_output(output: str, max_output_chars: int) -> tuple[str, int]:
|
||||
if len(output) <= max_output_chars:
|
||||
return output, 0
|
||||
half = max_output_chars // 2
|
||||
omitted = len(output) - max_output_chars
|
||||
return (
|
||||
output[:half]
|
||||
+ f"\n\n... ({omitted:,} chars truncated) ...\n\n"
|
||||
+ output[-half:],
|
||||
omitted,
|
||||
)
|
||||
|
||||
|
||||
def format_session_poll(session_id: str, poll: _SessionPoll) -> str:
|
||||
parts = [poll.output] if poll.output else []
|
||||
if poll.truncated_chars:
|
||||
parts.append(f"(output truncated by {poll.truncated_chars:,} chars)")
|
||||
if poll.timed_out:
|
||||
parts.append("Error: Command timed out; session was terminated.")
|
||||
if poll.terminated and not poll.timed_out:
|
||||
parts.append("Session terminated.")
|
||||
if poll.stdin_closed:
|
||||
parts.append("Stdin closed.")
|
||||
if poll.done:
|
||||
parts.append(f"Exit code: {poll.exit_code}")
|
||||
else:
|
||||
parts.append(f"Process running. session_id: {session_id}")
|
||||
return "\n".join(parts) if parts else "(no output yet)"
|
||||
|
||||
|
||||
@tool_parameters(
|
||||
tool_parameters_schema(
|
||||
session_id=StringSchema("Session id returned by exec when yield_time_ms is used."),
|
||||
chars=StringSchema(
|
||||
"Bytes/text to write to stdin. Omit or pass an empty string to only poll recent output.",
|
||||
nullable=True,
|
||||
),
|
||||
close_stdin=BooleanSchema(
|
||||
description="Close stdin after writing chars. Useful for commands waiting for EOF.",
|
||||
default=False,
|
||||
),
|
||||
terminate=BooleanSchema(
|
||||
description="Terminate the running exec session.",
|
||||
default=False,
|
||||
),
|
||||
yield_time_ms=IntegerSchema(
|
||||
DEFAULT_YIELD_MS,
|
||||
description="Milliseconds to wait before returning recent output (default 1000, max 30000).",
|
||||
minimum=0,
|
||||
maximum=MAX_YIELD_MS,
|
||||
),
|
||||
max_output_chars=IntegerSchema(
|
||||
DEFAULT_MAX_OUTPUT_CHARS,
|
||||
description="Maximum output characters to return from this poll (default 10000, max 50000).",
|
||||
minimum=1000,
|
||||
maximum=MAX_OUTPUT_CHARS,
|
||||
),
|
||||
max_output_tokens=IntegerSchema(
|
||||
DEFAULT_MAX_OUTPUT_CHARS,
|
||||
description="Compatibility alias for max_output_chars. The current runtime uses a character budget.",
|
||||
minimum=1000,
|
||||
maximum=MAX_OUTPUT_CHARS,
|
||||
nullable=True,
|
||||
),
|
||||
required=["session_id"],
|
||||
)
|
||||
)
|
||||
class WriteStdinTool(Tool):
|
||||
"""Write to or poll a running exec session."""
|
||||
|
||||
_scopes = {"core", "subagent"}
|
||||
config_key = "exec"
|
||||
|
||||
@classmethod
|
||||
def config_cls(cls):
|
||||
from nanobot.agent.tools.shell import ExecToolConfig
|
||||
|
||||
return ExecToolConfig
|
||||
|
||||
@classmethod
|
||||
def enabled(cls, ctx: Any) -> bool:
|
||||
return ctx.config.exec.enable
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
manager: ExecSessionManager | None = None,
|
||||
) -> None:
|
||||
self._manager = manager or DEFAULT_EXEC_SESSION_MANAGER
|
||||
|
||||
@classmethod
|
||||
def create(cls, ctx: Any) -> Tool:
|
||||
return cls()
|
||||
|
||||
@property
|
||||
def exclusive(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "write_stdin"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Write text to a running exec session and return recent output. "
|
||||
"Use chars='' to poll without writing. Set close_stdin=true to send EOF, "
|
||||
"or terminate=true to stop the session. Sessions finish automatically "
|
||||
"when their process exits."
|
||||
)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
session_id: str,
|
||||
chars: str | None = None,
|
||||
close_stdin: bool = False,
|
||||
terminate: bool = False,
|
||||
yield_time_ms: int | None = None,
|
||||
max_output_chars: int | None = None,
|
||||
max_output_tokens: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
try:
|
||||
if max_output_chars is None:
|
||||
max_output_chars = max_output_tokens
|
||||
poll = await self._manager.write(
|
||||
session_id=session_id,
|
||||
chars=chars,
|
||||
close_stdin=close_stdin,
|
||||
terminate=terminate,
|
||||
yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS),
|
||||
max_output_chars=clamp_session_int(
|
||||
max_output_chars,
|
||||
DEFAULT_MAX_OUTPUT_CHARS,
|
||||
1000,
|
||||
MAX_OUTPUT_CHARS,
|
||||
),
|
||||
)
|
||||
return format_session_poll(session_id, poll)
|
||||
except KeyError:
|
||||
return f"Error: exec session not found: {session_id}"
|
||||
except Exception as exc:
|
||||
return f"Error writing to exec session: {exc}"
|
||||
@ -132,6 +132,10 @@ def _parse_page_range(pages: str, total: int) -> tuple[int, int]:
|
||||
minimum=1,
|
||||
),
|
||||
pages=StringSchema("Page range for PDF files, e.g. '1-5' (default: all, max 20 pages)"),
|
||||
force=BooleanSchema(
|
||||
description="Bypass same-file read deduplication and return content again.",
|
||||
default=False,
|
||||
),
|
||||
required=["path"],
|
||||
)
|
||||
)
|
||||
@ -155,6 +159,7 @@ class ReadFileTool(_FsTool):
|
||||
"Images return visual content for analysis. "
|
||||
"Supports PDF, DOCX, XLSX, PPTX documents. "
|
||||
"Use offset and limit for large text files. "
|
||||
"Use force=true to re-read content even if unchanged. "
|
||||
"Reads exceeding ~128K chars are truncated."
|
||||
)
|
||||
|
||||
@ -162,7 +167,15 @@ class ReadFileTool(_FsTool):
|
||||
def read_only(self) -> bool:
|
||||
return True
|
||||
|
||||
async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, pages: str | None = None, **kwargs: Any) -> Any:
|
||||
async def execute(
|
||||
self,
|
||||
path: str | None = None,
|
||||
offset: int = 1,
|
||||
limit: int | None = None,
|
||||
pages: str | None = None,
|
||||
force: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
try:
|
||||
if not path:
|
||||
return "Error reading file: Unknown path"
|
||||
@ -202,7 +215,13 @@ class ReadFileTool(_FsTool):
|
||||
current_mtime = os.path.getmtime(fp)
|
||||
except OSError:
|
||||
current_mtime = 0.0
|
||||
if entry and entry.can_dedup and entry.offset == offset and entry.limit == limit:
|
||||
if (
|
||||
not force
|
||||
and entry
|
||||
and entry.can_dedup
|
||||
and entry.offset == offset
|
||||
and entry.limit == limit
|
||||
):
|
||||
if current_mtime != entry.mtime:
|
||||
# File was modified externally - force full read and mark as not dedupable
|
||||
entry.can_dedup = False
|
||||
@ -657,6 +676,24 @@ def _find_match(content: str, old_text: str) -> tuple[str | None, int]:
|
||||
old_text=StringSchema("The text to find and replace"),
|
||||
new_text=StringSchema("The text to replace with"),
|
||||
replace_all=BooleanSchema(description="Replace all occurrences (default false)"),
|
||||
occurrence=IntegerSchema(
|
||||
1,
|
||||
description="Optional 1-based occurrence to replace when old_text appears multiple times.",
|
||||
minimum=1,
|
||||
nullable=True,
|
||||
),
|
||||
line_hint=IntegerSchema(
|
||||
1,
|
||||
description="Optional 1-based line hint used to choose the nearest match.",
|
||||
minimum=1,
|
||||
nullable=True,
|
||||
),
|
||||
expected_replacements=IntegerSchema(
|
||||
1,
|
||||
description="Optional guard for the number of replacements that must be made.",
|
||||
minimum=1,
|
||||
nullable=True,
|
||||
),
|
||||
required=["path", "old_text", "new_text"],
|
||||
)
|
||||
)
|
||||
@ -677,7 +714,7 @@ class EditFileTool(_FsTool):
|
||||
"Edit a file by replacing old_text with new_text. "
|
||||
"Tolerates minor whitespace/indentation differences and curly/straight quote mismatches. "
|
||||
"If old_text matches multiple times, you must provide more context "
|
||||
"or set replace_all=true. Shows a diff of the closest match on failure."
|
||||
"or set occurrence/line_hint/replace_all. Shows a diff of the closest match on failure."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -688,7 +725,8 @@ class EditFileTool(_FsTool):
|
||||
async def execute(
|
||||
self, path: str | None = None, old_text: str | None = None,
|
||||
new_text: str | None = None,
|
||||
replace_all: bool = False, **kwargs: Any,
|
||||
replace_all: bool = False, occurrence: int | None = None,
|
||||
line_hint: int | None = None, expected_replacements: int | None = None, **kwargs: Any,
|
||||
) -> str:
|
||||
try:
|
||||
if not path:
|
||||
@ -697,10 +735,12 @@ class EditFileTool(_FsTool):
|
||||
raise ValueError("Unknown old_text")
|
||||
if new_text is None:
|
||||
raise ValueError("Unknown new_text")
|
||||
|
||||
# .ipynb detection
|
||||
if path.endswith(".ipynb"):
|
||||
return "Error: This is a Jupyter notebook. Use the notebook_edit tool instead of edit_file."
|
||||
if occurrence is not None and occurrence < 1:
|
||||
return "Error: occurrence must be >= 1."
|
||||
if line_hint is not None and line_hint < 1:
|
||||
return "Error: line_hint must be >= 1."
|
||||
if expected_replacements is not None and expected_replacements < 1:
|
||||
return "Error: expected_replacements must be >= 1."
|
||||
|
||||
fp = self._resolve(path)
|
||||
|
||||
@ -743,15 +783,42 @@ class EditFileTool(_FsTool):
|
||||
if not matches:
|
||||
return self._not_found_msg(old_text, content, path)
|
||||
count = len(matches)
|
||||
if replace_all and occurrence is not None:
|
||||
return "Error: occurrence cannot be used with replace_all=true."
|
||||
if replace_all and line_hint is not None:
|
||||
return "Error: line_hint cannot be used with replace_all=true."
|
||||
if occurrence is not None and line_hint is not None:
|
||||
return "Error: line_hint cannot be used with occurrence."
|
||||
if count > 1 and not replace_all:
|
||||
line_numbers = [match.line for match in matches]
|
||||
preview = ", ".join(f"line {n}" for n in line_numbers[:3])
|
||||
if len(line_numbers) > 3:
|
||||
preview += ", ..."
|
||||
location_hint = f" at {preview}" if preview else ""
|
||||
if occurrence is not None:
|
||||
if occurrence > count:
|
||||
return (
|
||||
f"Error: occurrence {occurrence} is out of range; "
|
||||
f"old_text appears {count} times."
|
||||
)
|
||||
elif line_hint is not None:
|
||||
nearest = min(matches, key=lambda match: abs(match.line - line_hint))
|
||||
distance = abs(nearest.line - line_hint)
|
||||
if sum(1 for match in matches if abs(match.line - line_hint) == distance) > 1:
|
||||
return (
|
||||
f"Error: line_hint {line_hint} is ambiguous; "
|
||||
f"old_text appears {count} times."
|
||||
)
|
||||
else:
|
||||
line_numbers = [match.line for match in matches]
|
||||
preview = ", ".join(f"line {n}" for n in line_numbers[:3])
|
||||
if len(line_numbers) > 3:
|
||||
preview += ", ..."
|
||||
location_hint = f" at {preview}" if preview else ""
|
||||
return (
|
||||
f"Warning: old_text appears {count} times{location_hint}. "
|
||||
"Provide more context, set occurrence to choose one match, "
|
||||
"or set replace_all=true."
|
||||
)
|
||||
elif occurrence is not None and occurrence > count:
|
||||
return (
|
||||
f"Warning: old_text appears {count} times{location_hint}. "
|
||||
"Provide more context to make it unique, or set replace_all=true."
|
||||
f"Error: occurrence {occurrence} is out of range; "
|
||||
f"old_text appears {count} time."
|
||||
)
|
||||
|
||||
norm_new = new_text.replace("\r\n", "\n")
|
||||
@ -760,7 +827,17 @@ class EditFileTool(_FsTool):
|
||||
if fp.suffix.lower() not in self._MARKDOWN_EXTS:
|
||||
norm_new = self._strip_trailing_ws(norm_new)
|
||||
|
||||
selected = matches if replace_all else matches[:1]
|
||||
if replace_all:
|
||||
selected = matches
|
||||
elif line_hint is not None:
|
||||
selected = [min(matches, key=lambda match: abs(match.line - line_hint))]
|
||||
else:
|
||||
selected = [matches[occurrence - 1 if occurrence else 0]]
|
||||
if expected_replacements is not None and len(selected) != expected_replacements:
|
||||
return (
|
||||
f"Error: expected {expected_replacements} replacements but "
|
||||
f"would make {len(selected)}."
|
||||
)
|
||||
new_content = content
|
||||
for match in reversed(selected):
|
||||
replacement = _preserve_quote_style(norm_old, match.text, norm_new)
|
||||
|
||||
@ -1,162 +0,0 @@
|
||||
"""NotebookEditTool — edit Jupyter .ipynb notebooks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import tool_parameters
|
||||
from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema
|
||||
from nanobot.agent.tools.filesystem import _FsTool
|
||||
|
||||
|
||||
def _new_cell(source: str, cell_type: str = "code", generate_id: bool = False) -> dict:
|
||||
cell: dict[str, Any] = {
|
||||
"cell_type": cell_type,
|
||||
"source": source,
|
||||
"metadata": {},
|
||||
}
|
||||
if cell_type == "code":
|
||||
cell["outputs"] = []
|
||||
cell["execution_count"] = None
|
||||
if generate_id:
|
||||
cell["id"] = uuid.uuid4().hex[:8]
|
||||
return cell
|
||||
|
||||
|
||||
def _make_empty_notebook() -> dict:
|
||||
return {
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5,
|
||||
"metadata": {
|
||||
"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
|
||||
"language_info": {"name": "python"},
|
||||
},
|
||||
"cells": [],
|
||||
}
|
||||
|
||||
|
||||
@tool_parameters(
|
||||
tool_parameters_schema(
|
||||
path=StringSchema("Path to the .ipynb notebook file"),
|
||||
cell_index=IntegerSchema(0, description="0-based index of the cell to edit", minimum=0),
|
||||
new_source=StringSchema("New source content for the cell"),
|
||||
cell_type=StringSchema(
|
||||
"Cell type: 'code' or 'markdown' (default: code)",
|
||||
enum=["code", "markdown"],
|
||||
),
|
||||
edit_mode=StringSchema(
|
||||
"Mode: 'replace' (default), 'insert' (after target), or 'delete'",
|
||||
enum=["replace", "insert", "delete"],
|
||||
),
|
||||
required=["path", "cell_index"],
|
||||
)
|
||||
)
|
||||
class NotebookEditTool(_FsTool):
|
||||
"""Edit Jupyter notebook cells: replace, insert, or delete."""
|
||||
_scopes = {"core"}
|
||||
|
||||
_VALID_CELL_TYPES = frozenset({"code", "markdown"})
|
||||
_VALID_EDIT_MODES = frozenset({"replace", "insert", "delete"})
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "notebook_edit"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit a Jupyter notebook (.ipynb) cell. "
|
||||
"Modes: replace (default) replaces cell content, "
|
||||
"insert adds a new cell after the target index, "
|
||||
"delete removes the cell at the index. "
|
||||
"cell_index is 0-based."
|
||||
)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
path: str | None = None,
|
||||
cell_index: int = 0,
|
||||
new_source: str = "",
|
||||
cell_type: str = "code",
|
||||
edit_mode: str = "replace",
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
try:
|
||||
if not path:
|
||||
return "Error: path is required"
|
||||
|
||||
if not path.endswith(".ipynb"):
|
||||
return "Error: notebook_edit only works on .ipynb files. Use edit_file for other files."
|
||||
|
||||
if edit_mode not in self._VALID_EDIT_MODES:
|
||||
return (
|
||||
f"Error: Invalid edit_mode '{edit_mode}'. "
|
||||
"Use one of: replace, insert, delete."
|
||||
)
|
||||
|
||||
if cell_type not in self._VALID_CELL_TYPES:
|
||||
return (
|
||||
f"Error: Invalid cell_type '{cell_type}'. "
|
||||
"Use one of: code, markdown."
|
||||
)
|
||||
|
||||
fp = self._resolve(path)
|
||||
|
||||
# Create new notebook if file doesn't exist and mode is insert
|
||||
if not fp.exists():
|
||||
if edit_mode != "insert":
|
||||
return f"Error: File not found: {path}"
|
||||
nb = _make_empty_notebook()
|
||||
cell = _new_cell(new_source, cell_type, generate_id=True)
|
||||
nb["cells"].append(cell)
|
||||
fp.parent.mkdir(parents=True, exist_ok=True)
|
||||
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
|
||||
return f"Successfully created {fp} with 1 cell"
|
||||
|
||||
try:
|
||||
nb = json.loads(fp.read_text(encoding="utf-8"))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||||
return f"Error: Failed to parse notebook: {e}"
|
||||
|
||||
cells = nb.get("cells", [])
|
||||
nbformat_minor = nb.get("nbformat_minor", 0)
|
||||
generate_id = nb.get("nbformat", 0) >= 4 and nbformat_minor >= 5
|
||||
|
||||
if edit_mode == "delete":
|
||||
if cell_index < 0 or cell_index >= len(cells):
|
||||
return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)"
|
||||
cells.pop(cell_index)
|
||||
nb["cells"] = cells
|
||||
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
|
||||
return f"Successfully deleted cell {cell_index} from {fp}"
|
||||
|
||||
if edit_mode == "insert":
|
||||
insert_at = min(cell_index + 1, len(cells))
|
||||
cell = _new_cell(new_source, cell_type, generate_id=generate_id)
|
||||
cells.insert(insert_at, cell)
|
||||
nb["cells"] = cells
|
||||
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
|
||||
return f"Successfully inserted cell at index {insert_at} in {fp}"
|
||||
|
||||
# Default: replace
|
||||
if cell_index < 0 or cell_index >= len(cells):
|
||||
return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)"
|
||||
cells[cell_index]["source"] = new_source
|
||||
if cell_type and cells[cell_index].get("cell_type") != cell_type:
|
||||
cells[cell_index]["cell_type"] = cell_type
|
||||
if cell_type == "code":
|
||||
cells[cell_index].setdefault("outputs", [])
|
||||
cells[cell_index].setdefault("execution_count", None)
|
||||
elif "outputs" in cells[cell_index]:
|
||||
del cells[cell_index]["outputs"]
|
||||
cells[cell_index].pop("execution_count", None)
|
||||
nb["cells"] = cells
|
||||
fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")
|
||||
return f"Successfully edited cell {cell_index} in {fp}"
|
||||
|
||||
except PermissionError as e:
|
||||
return f"Error: {e}"
|
||||
except Exception as e:
|
||||
return f"Error editing notebook: {e}"
|
||||
@ -8,6 +8,7 @@ import re
|
||||
import shutil
|
||||
import sys
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@ -15,8 +16,17 @@ from loguru import logger
|
||||
from pydantic import Field
|
||||
|
||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||
from nanobot.agent.tools.exec_session import (
|
||||
DEFAULT_MAX_OUTPUT_CHARS,
|
||||
DEFAULT_YIELD_MS,
|
||||
DEFAULT_EXEC_SESSION_MANAGER,
|
||||
MAX_OUTPUT_CHARS,
|
||||
MAX_YIELD_MS,
|
||||
clamp_session_int,
|
||||
format_session_poll,
|
||||
)
|
||||
from nanobot.agent.tools.sandbox import wrap_command
|
||||
from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema
|
||||
from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import Base
|
||||
|
||||
@ -44,10 +54,22 @@ class ExecToolConfig(Base):
|
||||
deny_patterns: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _PreparedCommand:
|
||||
command: str
|
||||
cwd: str
|
||||
env: dict[str, str]
|
||||
timeout: int
|
||||
shell_program: str | None
|
||||
login: bool
|
||||
|
||||
|
||||
@tool_parameters(
|
||||
tool_parameters_schema(
|
||||
command=StringSchema("The shell command to execute"),
|
||||
cmd=StringSchema("Compatibility alias for command"),
|
||||
working_dir=StringSchema("Optional working directory for the command"),
|
||||
workdir=StringSchema("Compatibility alias for working_dir"),
|
||||
timeout=IntegerSchema(
|
||||
60,
|
||||
description=(
|
||||
@ -57,7 +79,44 @@ class ExecToolConfig(Base):
|
||||
minimum=1,
|
||||
maximum=600,
|
||||
),
|
||||
required=["command"],
|
||||
shell=StringSchema(
|
||||
"Optional shell binary to launch. On Unix, supports sh, bash, or zsh.",
|
||||
nullable=True,
|
||||
),
|
||||
login=BooleanSchema(
|
||||
description="Whether to run bash/zsh with login shell semantics (default true).",
|
||||
default=True,
|
||||
nullable=True,
|
||||
),
|
||||
yield_time_ms=IntegerSchema(
|
||||
description=(
|
||||
"Optional milliseconds to wait before returning output. "
|
||||
"When set, a still-running command returns a session_id that "
|
||||
"can be polled or written to with write_stdin. Omit this field "
|
||||
"to keep one-shot exec behavior."
|
||||
),
|
||||
minimum=0,
|
||||
maximum=MAX_YIELD_MS,
|
||||
nullable=True,
|
||||
),
|
||||
max_output_chars=IntegerSchema(
|
||||
description=(
|
||||
"Maximum output characters to return when yield_time_ms is used "
|
||||
"(default 10000, max 50000)."
|
||||
),
|
||||
minimum=1000,
|
||||
maximum=MAX_OUTPUT_CHARS,
|
||||
nullable=True,
|
||||
),
|
||||
max_output_tokens=IntegerSchema(
|
||||
description=(
|
||||
"Compatibility alias for max_output_chars. The current runtime "
|
||||
"uses a character budget."
|
||||
),
|
||||
minimum=1000,
|
||||
maximum=MAX_OUTPUT_CHARS,
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
)
|
||||
class ExecTool(Tool):
|
||||
@ -98,6 +157,7 @@ class ExecTool(Tool):
|
||||
sandbox: str = "",
|
||||
path_append: str = "",
|
||||
allowed_env_keys: list[str] | None = None,
|
||||
session_manager: Any | None = None,
|
||||
):
|
||||
self.timeout = timeout
|
||||
self.working_dir = working_dir
|
||||
@ -125,6 +185,7 @@ class ExecTool(Tool):
|
||||
self.restrict_to_workspace = restrict_to_workspace
|
||||
self.path_append = path_append
|
||||
self.allowed_env_keys = allowed_env_keys or []
|
||||
self._session_manager = session_manager or DEFAULT_EXEC_SESSION_MANAGER
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@ -153,7 +214,10 @@ class ExecTool(Tool):
|
||||
"Prefer read_file/write_file/edit_file over cat/echo/sed, "
|
||||
"and grep/glob over shell find/grep. "
|
||||
"Use -y or --yes flags to avoid interactive prompts. "
|
||||
"Output is truncated at 10 000 chars; timeout defaults to 60s."
|
||||
"For long-running or interactive commands, pass yield_time_ms; "
|
||||
"if the command keeps running, exec returns a session_id that can "
|
||||
"be polled or written to with write_stdin. Output is truncated at "
|
||||
"10 000 chars; timeout defaults to 60s."
|
||||
)
|
||||
|
||||
@property
|
||||
@ -161,9 +225,111 @@ class ExecTool(Tool):
|
||||
return True
|
||||
|
||||
async def execute(
|
||||
self, command: str, working_dir: str | None = None,
|
||||
timeout: int | None = None, **kwargs: Any,
|
||||
self, command: str | None = None, cmd: str | None = None,
|
||||
working_dir: str | None = None, workdir: str | None = None,
|
||||
timeout: int | None = None, shell: str | None = None,
|
||||
login: bool | None = None, yield_time_ms: int | None = None,
|
||||
max_output_chars: int | None = None,
|
||||
max_output_tokens: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
command = command or cmd
|
||||
working_dir = working_dir or workdir
|
||||
if not command:
|
||||
return "Error: Missing command. Provide command or cmd."
|
||||
if max_output_chars is None:
|
||||
max_output_chars = max_output_tokens
|
||||
|
||||
prepared = self._prepare_command(command, working_dir, timeout, shell, login)
|
||||
if isinstance(prepared, str):
|
||||
return prepared
|
||||
|
||||
if yield_time_ms is not None:
|
||||
return await self._execute_session(prepared, yield_time_ms, max_output_chars)
|
||||
|
||||
try:
|
||||
process = await self._spawn(
|
||||
prepared.command,
|
||||
prepared.cwd,
|
||||
prepared.env,
|
||||
prepared.shell_program,
|
||||
prepared.login,
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(),
|
||||
timeout=prepared.timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
await self._kill_process(process)
|
||||
return f"Error: Command timed out after {prepared.timeout} seconds"
|
||||
except asyncio.CancelledError:
|
||||
await self._kill_process(process)
|
||||
raise
|
||||
|
||||
output_parts = []
|
||||
|
||||
if stdout:
|
||||
output_parts.append(stdout.decode("utf-8", errors="replace"))
|
||||
|
||||
if stderr:
|
||||
stderr_text = stderr.decode("utf-8", errors="replace")
|
||||
if stderr_text.strip():
|
||||
output_parts.append(f"STDERR:\n{stderr_text}")
|
||||
|
||||
output_parts.append(f"\nExit code: {process.returncode}")
|
||||
|
||||
result = "\n".join(output_parts) if output_parts else "(no output)"
|
||||
|
||||
max_len = clamp_session_int(max_output_chars, self._MAX_OUTPUT, 1000, MAX_OUTPUT_CHARS)
|
||||
if len(result) > max_len:
|
||||
half = max_len // 2
|
||||
result = (
|
||||
result[:half]
|
||||
+ f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n"
|
||||
+ result[-half:]
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return f"Error executing command: {str(e)}"
|
||||
|
||||
async def _execute_session(
|
||||
self,
|
||||
prepared: _PreparedCommand,
|
||||
yield_time_ms: int | None,
|
||||
max_output_chars: int | None,
|
||||
) -> str:
|
||||
try:
|
||||
session_id, poll = await self._session_manager.start(
|
||||
command=prepared.command,
|
||||
cwd=prepared.cwd,
|
||||
env=prepared.env,
|
||||
timeout=prepared.timeout,
|
||||
shell_program=prepared.shell_program,
|
||||
login=prepared.login,
|
||||
yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS),
|
||||
max_output_chars=clamp_session_int(
|
||||
max_output_chars,
|
||||
DEFAULT_MAX_OUTPUT_CHARS,
|
||||
1000,
|
||||
MAX_OUTPUT_CHARS,
|
||||
),
|
||||
)
|
||||
return format_session_poll(session_id, poll)
|
||||
except Exception as exc:
|
||||
return f"Error executing command: {exc}"
|
||||
|
||||
def _prepare_command(
|
||||
self,
|
||||
command: str,
|
||||
working_dir: str | None = None,
|
||||
timeout: int | None = None,
|
||||
shell: str | None = None,
|
||||
login: bool | None = None,
|
||||
) -> _PreparedCommand | str:
|
||||
cwd = working_dir or self.working_dir or os.getcwd()
|
||||
|
||||
# Prevent an LLM-supplied working_dir from escaping the configured
|
||||
@ -211,52 +377,24 @@ class ExecTool(Tool):
|
||||
env["NANOBOT_PATH_APPEND"] = self.path_append
|
||||
command = f'export PATH="$PATH{os.pathsep}$NANOBOT_PATH_APPEND"; {command}'
|
||||
|
||||
try:
|
||||
process = await self._spawn(command, cwd, env)
|
||||
shell_program, shell_error = self._resolve_shell(shell)
|
||||
if shell_error:
|
||||
return shell_error
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(),
|
||||
timeout=effective_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
await self._kill_process(process)
|
||||
return f"Error: Command timed out after {effective_timeout} seconds"
|
||||
except asyncio.CancelledError:
|
||||
await self._kill_process(process)
|
||||
raise
|
||||
|
||||
output_parts = []
|
||||
|
||||
if stdout:
|
||||
output_parts.append(stdout.decode("utf-8", errors="replace"))
|
||||
|
||||
if stderr:
|
||||
stderr_text = stderr.decode("utf-8", errors="replace")
|
||||
if stderr_text.strip():
|
||||
output_parts.append(f"STDERR:\n{stderr_text}")
|
||||
|
||||
output_parts.append(f"\nExit code: {process.returncode}")
|
||||
|
||||
result = "\n".join(output_parts) if output_parts else "(no output)"
|
||||
|
||||
max_len = self._MAX_OUTPUT
|
||||
if len(result) > max_len:
|
||||
half = max_len // 2
|
||||
result = (
|
||||
result[:half]
|
||||
+ f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n"
|
||||
+ result[-half:]
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return f"Error executing command: {str(e)}"
|
||||
return _PreparedCommand(
|
||||
command=command,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
timeout=effective_timeout,
|
||||
shell_program=shell_program,
|
||||
login=True if login is None else login,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _spawn(
|
||||
command: str, cwd: str, env: dict[str, str],
|
||||
shell_program: str | None = None,
|
||||
login: bool = True,
|
||||
) -> asyncio.subprocess.Process:
|
||||
"""Launch *command* in a platform-appropriate shell."""
|
||||
if _IS_WINDOWS:
|
||||
@ -272,9 +410,13 @@ class ExecTool(Tool):
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
)
|
||||
bash = shutil.which("bash") or "/bin/bash"
|
||||
shell_program = shell_program or shutil.which("bash") or "/bin/bash"
|
||||
args = [shell_program]
|
||||
if login and Path(shell_program).name in {"bash", "zsh"}:
|
||||
args.append("-l")
|
||||
args.extend(["-c", command])
|
||||
return await asyncio.create_subprocess_exec(
|
||||
bash, "-l", "-c", command,
|
||||
*args,
|
||||
stdin=asyncio.subprocess.DEVNULL,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
@ -282,6 +424,31 @@ class ExecTool(Tool):
|
||||
env=env,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_shell(shell: str | None) -> tuple[str | None, str | None]:
|
||||
if not shell:
|
||||
return None, None
|
||||
if _IS_WINDOWS:
|
||||
return None, "Error: shell parameter is not supported on Windows"
|
||||
if "\0" in shell or "\n" in shell or "\r" in shell:
|
||||
return None, "Error: shell contains invalid characters"
|
||||
allowed = {"sh", "bash", "zsh"}
|
||||
path = Path(shell).expanduser()
|
||||
if path.is_absolute():
|
||||
if path.name not in allowed:
|
||||
return None, f"Error: unsupported shell {shell!r}. Allowed: bash, sh, zsh"
|
||||
if not path.is_file() or not os.access(path, os.X_OK):
|
||||
return None, f"Error: shell is not executable: {shell}"
|
||||
return str(path), None
|
||||
if "/" in shell or "\\" in shell:
|
||||
return None, "Error: shell must be a shell name or absolute path"
|
||||
if shell not in allowed:
|
||||
return None, f"Error: unsupported shell {shell!r}. Allowed: bash, sh, zsh"
|
||||
resolved = shutil.which(shell)
|
||||
if not resolved:
|
||||
return None, f"Error: shell not found: {shell}"
|
||||
return resolved, None
|
||||
|
||||
@staticmethod
|
||||
async def _kill_process(process: asyncio.subprocess.Process) -> None:
|
||||
"""Kill a subprocess and reap it to prevent zombies."""
|
||||
|
||||
@ -3,7 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import difflib
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
@ -11,7 +10,7 @@ from pathlib import Path
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
|
||||
TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file", "notebook_edit"})
|
||||
TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file"})
|
||||
_MAX_SNAPSHOT_BYTES = 2 * 1024 * 1024
|
||||
_LIVE_EMIT_INTERVAL_S = 0.18
|
||||
_LIVE_EMIT_LINE_STEP = 24
|
||||
@ -704,77 +703,4 @@ def _predict_after_text(
|
||||
return before_text.replace(old_text, new_text)
|
||||
return before_text.replace(old_text, new_text, 1)
|
||||
return None
|
||||
if tool_name == "notebook_edit":
|
||||
return _predict_notebook_after_text(params, before_text)
|
||||
return None
|
||||
|
||||
|
||||
def _predict_notebook_after_text(params: dict[str, Any], before_text: str) -> str | None:
|
||||
try:
|
||||
nb = json.loads(before_text) if before_text.strip() else _empty_notebook()
|
||||
except Exception:
|
||||
return None
|
||||
cells = nb.get("cells")
|
||||
if not isinstance(cells, list):
|
||||
return None
|
||||
try:
|
||||
cell_index = int(params.get("cell_index", 0))
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
new_source = params.get("new_source")
|
||||
source = new_source if isinstance(new_source, str) else ""
|
||||
cell_type = (
|
||||
params.get("cell_type") if params.get("cell_type") in ("code", "markdown") else "code"
|
||||
)
|
||||
mode = (
|
||||
params.get("edit_mode")
|
||||
if params.get("edit_mode") in ("replace", "insert", "delete")
|
||||
else "replace"
|
||||
)
|
||||
if mode == "delete":
|
||||
if 0 <= cell_index < len(cells):
|
||||
cells.pop(cell_index)
|
||||
else:
|
||||
return None
|
||||
elif mode == "insert":
|
||||
insert_at = min(max(cell_index + 1, 0), len(cells))
|
||||
cells.insert(insert_at, _new_notebook_cell(source, str(cell_type)))
|
||||
else:
|
||||
if not (0 <= cell_index < len(cells)):
|
||||
return None
|
||||
cell = cells[cell_index]
|
||||
if not isinstance(cell, dict):
|
||||
return None
|
||||
cell["source"] = source
|
||||
cell["cell_type"] = cell_type
|
||||
if cell_type == "code":
|
||||
cell.setdefault("outputs", [])
|
||||
cell.setdefault("execution_count", None)
|
||||
else:
|
||||
cell.pop("outputs", None)
|
||||
cell.pop("execution_count", None)
|
||||
nb["cells"] = cells
|
||||
try:
|
||||
return json.dumps(nb, indent=1, ensure_ascii=False)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _empty_notebook() -> dict[str, Any]:
|
||||
return {
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5,
|
||||
"metadata": {
|
||||
"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
|
||||
"language_info": {"name": "python"},
|
||||
},
|
||||
"cells": [],
|
||||
}
|
||||
|
||||
|
||||
def _new_notebook_cell(source: str, cell_type: str) -> dict[str, Any]:
|
||||
cell: dict[str, Any] = {"cell_type": cell_type, "source": source, "metadata": {}}
|
||||
if cell_type == "code":
|
||||
cell["outputs"] = []
|
||||
cell["execution_count"] = None
|
||||
return cell
|
||||
|
||||
238
tests/tools/test_apply_patch_tool.py
Normal file
238
tests/tools/test_apply_patch_tool.py
Normal file
@ -0,0 +1,238 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from nanobot.agent.tools.apply_patch import ApplyPatchTool
|
||||
|
||||
|
||||
def test_apply_patch_adds_file(tmp_path):
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
patch="""*** Begin Patch
|
||||
*** Add File: hello.txt
|
||||
+Hello
|
||||
+world
|
||||
*** End Patch
|
||||
"""
|
||||
))
|
||||
|
||||
assert "Patch applied" in result
|
||||
assert (tmp_path / "hello.txt").read_text() == "Hello\nworld\n"
|
||||
|
||||
|
||||
def test_apply_patch_updates_multiple_hunks(tmp_path):
|
||||
target = tmp_path / "multi.txt"
|
||||
target.write_text("line1\nline2\nline3\nline4\n")
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
patch="""*** Begin Patch
|
||||
*** Update File: multi.txt
|
||||
@@
|
||||
-line2
|
||||
+changed2
|
||||
@@
|
||||
-line4
|
||||
+changed4
|
||||
*** End Patch
|
||||
"""
|
||||
))
|
||||
|
||||
assert "update multi.txt" in result
|
||||
assert target.read_text() == "line1\nchanged2\nline3\nchanged4\n"
|
||||
|
||||
|
||||
def test_apply_patch_ignores_standard_no_newline_marker(tmp_path):
|
||||
target = tmp_path / "plain.txt"
|
||||
target.write_text("before")
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
patch="""*** Begin Patch
|
||||
*** Update File: plain.txt
|
||||
@@ -1,1 +1,1 @@
|
||||
-before
|
||||
\\ No newline at end of file
|
||||
+after
|
||||
\\ No newline at end of file
|
||||
*** End Patch
|
||||
"""
|
||||
))
|
||||
|
||||
assert "update plain.txt" in result
|
||||
assert target.read_text() == "after\n"
|
||||
|
||||
|
||||
def test_apply_patch_rejects_empty_hunk(tmp_path):
|
||||
target = tmp_path / "plain.txt"
|
||||
target.write_text("before\n")
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
patch="""*** Begin Patch
|
||||
*** Update File: plain.txt
|
||||
@@
|
||||
*** End Patch
|
||||
"""
|
||||
))
|
||||
|
||||
assert "hunk is empty" in result
|
||||
assert target.read_text() == "before\n"
|
||||
|
||||
|
||||
def test_apply_patch_uses_unified_diff_line_hint(tmp_path):
|
||||
target = tmp_path / "repeated.txt"
|
||||
target.write_text("target\nmiddle\ntarget\n")
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
patch="""*** Begin Patch
|
||||
*** Update File: repeated.txt
|
||||
@@ -3,1 +3,1 @@
|
||||
-target
|
||||
+changed
|
||||
*** End Patch
|
||||
"""
|
||||
))
|
||||
|
||||
assert "update repeated.txt" in result
|
||||
assert target.read_text() == "target\nmiddle\nchanged\n"
|
||||
|
||||
|
||||
def test_apply_patch_line_hint_does_not_fallback_to_earlier_match(tmp_path):
|
||||
target = tmp_path / "repeated.txt"
|
||||
target.write_text("target\nmiddle\nother\n")
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
patch="""*** Begin Patch
|
||||
*** Update File: repeated.txt
|
||||
@@ -3,1 +3,1 @@
|
||||
-target
|
||||
+changed
|
||||
*** End Patch
|
||||
"""
|
||||
))
|
||||
|
||||
assert "hunk does not match repeated.txt" in result
|
||||
assert target.read_text() == "target\nmiddle\nother\n"
|
||||
|
||||
|
||||
def test_apply_patch_mismatch_reports_best_match(tmp_path):
|
||||
target = tmp_path / "near.txt"
|
||||
target.write_text("alpha\nbeta\ngamma\n")
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
patch="""*** Begin Patch
|
||||
*** Update File: near.txt
|
||||
@@ -2,1 +2,1 @@
|
||||
-betx
|
||||
+delta
|
||||
*** End Patch
|
||||
"""
|
||||
))
|
||||
|
||||
assert "hunk does not match near.txt" in result
|
||||
assert "Best match" in result
|
||||
assert "line 2" in result
|
||||
assert target.read_text() == "alpha\nbeta\ngamma\n"
|
||||
|
||||
|
||||
def test_apply_patch_moves_and_updates_file(tmp_path):
|
||||
source = tmp_path / "old" / "name.txt"
|
||||
source.parent.mkdir()
|
||||
source.write_text("old content\n")
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
patch="""*** Begin Patch
|
||||
*** Update File: old/name.txt
|
||||
*** Move to: renamed/dir/name.txt
|
||||
@@
|
||||
-old content
|
||||
+new content
|
||||
*** End Patch
|
||||
"""
|
||||
))
|
||||
|
||||
assert "move old/name.txt -> renamed/dir/name.txt" in result
|
||||
assert not source.exists()
|
||||
assert (tmp_path / "renamed" / "dir" / "name.txt").read_text() == "new content\n"
|
||||
|
||||
|
||||
def test_apply_patch_deletes_file(tmp_path):
|
||||
target = tmp_path / "obsolete.txt"
|
||||
target.write_text("remove me\n")
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
patch="""*** Begin Patch
|
||||
*** Delete File: obsolete.txt
|
||||
*** End Patch
|
||||
"""
|
||||
))
|
||||
|
||||
assert "delete obsolete.txt" in result
|
||||
assert not target.exists()
|
||||
|
||||
|
||||
def test_apply_patch_rejects_absolute_and_parent_paths(tmp_path):
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
absolute = asyncio.run(tool.execute(
|
||||
patch="""*** Begin Patch
|
||||
*** Add File: /tmp/owned.txt
|
||||
+nope
|
||||
*** End Patch
|
||||
"""
|
||||
))
|
||||
parent = asyncio.run(tool.execute(
|
||||
patch="""*** Begin Patch
|
||||
*** Add File: ../owned.txt
|
||||
+nope
|
||||
*** End Patch
|
||||
"""
|
||||
))
|
||||
|
||||
assert "must be relative" in absolute
|
||||
assert "must not contain '..'" in parent
|
||||
assert not (tmp_path.parent / "owned.txt").exists()
|
||||
|
||||
|
||||
def test_apply_patch_does_not_overwrite_existing_file_with_add(tmp_path):
|
||||
target = tmp_path / "existing.txt"
|
||||
target.write_text("keep me\n")
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
patch="""*** Begin Patch
|
||||
*** Add File: existing.txt
|
||||
+replace me
|
||||
*** End Patch
|
||||
"""
|
||||
))
|
||||
|
||||
assert "file to add already exists" in result
|
||||
assert target.read_text() == "keep me\n"
|
||||
|
||||
|
||||
def test_apply_patch_rolls_back_when_late_operation_fails(tmp_path):
|
||||
first = tmp_path / "first.txt"
|
||||
first.write_text("before\n")
|
||||
tool = ApplyPatchTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
patch="""*** Begin Patch
|
||||
*** Update File: first.txt
|
||||
@@
|
||||
-before
|
||||
+after
|
||||
*** Delete File: missing.txt
|
||||
*** End Patch
|
||||
"""
|
||||
))
|
||||
|
||||
assert "file to delete does not exist" in result
|
||||
assert first.read_text() == "before\n"
|
||||
@ -1,5 +1,5 @@
|
||||
"""Tests for EditFileTool enhancements: read-before-edit tracking, path suggestions,
|
||||
.ipynb detection, and create-file semantics."""
|
||||
notebook JSON editing, and create-file semantics."""
|
||||
|
||||
import pytest
|
||||
|
||||
@ -108,22 +108,27 @@ class TestEditCreateFile:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# .ipynb detection
|
||||
# .ipynb editing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEditIpynbDetection:
|
||||
"""edit_file should refuse .ipynb and suggest notebook_edit."""
|
||||
class TestEditIpynbFiles:
|
||||
"""edit_file edits notebooks as normal JSON files."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ipynb_rejected_with_suggestion(self, tool, tmp_path):
|
||||
async def test_ipynb_can_be_edited_as_json(self, tool, tmp_path):
|
||||
f = tmp_path / "analysis.ipynb"
|
||||
f.write_text('{"cells": []}', encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="x", new_text="y")
|
||||
assert "notebook" in result.lower()
|
||||
result = await tool.execute(
|
||||
path=str(f),
|
||||
old_text='"cells": []',
|
||||
new_text='"cells": [{"cell_type": "markdown", "source": "hi"}]',
|
||||
)
|
||||
assert "Successfully edited" in result
|
||||
assert '"source": "hi"' in f.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
242
tests/tools/test_exec_session_tools.py
Normal file
242
tests/tools/test_exec_session_tools.py
Normal file
@ -0,0 +1,242 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import shlex
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
from nanobot.agent.tools.exec_session import ExecSessionManager, WriteStdinTool
|
||||
|
||||
|
||||
def _python_command(code: str) -> str:
|
||||
if sys.platform == "win32":
|
||||
return f"{subprocess.list2cmdline([sys.executable])} -u -c {subprocess.list2cmdline([code])}"
|
||||
return f"{shlex.quote(sys.executable)} -u -c {shlex.quote(code)}"
|
||||
|
||||
|
||||
def _session_id(output: str) -> str:
|
||||
match = re.search(r"session_id:\s*([0-9a-f]+)", output)
|
||||
assert match, output
|
||||
return match.group(1)
|
||||
|
||||
|
||||
def test_exec_keeps_one_shot_behavior_without_yield_time_ms(tmp_path):
|
||||
async def run() -> str:
|
||||
tool = ExecTool(working_dir=str(tmp_path), timeout=5)
|
||||
return await tool.execute(command="echo hello")
|
||||
|
||||
result = asyncio.run(run())
|
||||
|
||||
assert "hello" in result
|
||||
assert "Exit code: 0" in result
|
||||
assert "session_id:" not in result
|
||||
|
||||
|
||||
def test_exec_accepts_command_aliases(tmp_path):
|
||||
async def run() -> str:
|
||||
tool = ExecTool(working_dir="/")
|
||||
return await tool.execute(cmd="pwd", workdir=str(tmp_path))
|
||||
|
||||
result = asyncio.run(run())
|
||||
|
||||
assert str(tmp_path) in result
|
||||
assert "Exit code: 0" in result
|
||||
|
||||
|
||||
def test_exec_returns_completed_session_output_when_yield_time_ms_is_used(tmp_path):
|
||||
async def run() -> str:
|
||||
manager = ExecSessionManager()
|
||||
tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
|
||||
stdin_tool = WriteStdinTool(manager=manager)
|
||||
|
||||
result = await tool.execute(command="echo hello", yield_time_ms=1000)
|
||||
if "session_id:" in result:
|
||||
sid = _session_id(result)
|
||||
result += "\n" + await stdin_tool.execute(
|
||||
session_id=sid,
|
||||
chars="",
|
||||
yield_time_ms=1000,
|
||||
)
|
||||
return result
|
||||
|
||||
result = asyncio.run(run())
|
||||
|
||||
assert "hello" in result
|
||||
assert "Exit code: 0" in result
|
||||
assert "session_id:" not in result
|
||||
|
||||
|
||||
def test_exec_session_accepts_max_output_tokens_alias(tmp_path):
|
||||
async def run() -> str:
|
||||
manager = ExecSessionManager()
|
||||
tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
|
||||
command = _python_command("print('A' * 2000)")
|
||||
return await tool.execute(
|
||||
command=command,
|
||||
yield_time_ms=1000,
|
||||
max_output_tokens=1000,
|
||||
)
|
||||
|
||||
result = asyncio.run(run())
|
||||
|
||||
assert "chars truncated" in result
|
||||
assert "Exit code: 0" in result
|
||||
|
||||
|
||||
def test_exec_one_shot_accepts_max_output_tokens_alias(tmp_path):
|
||||
async def run() -> str:
|
||||
tool = ExecTool(working_dir=str(tmp_path), timeout=5)
|
||||
command = _python_command("print('A' * 2000)")
|
||||
return await tool.execute(command=command, max_output_tokens=1000)
|
||||
|
||||
result = asyncio.run(run())
|
||||
|
||||
assert "chars truncated" in result
|
||||
assert "Exit code: 0" in result
|
||||
|
||||
|
||||
def test_exec_accepts_supported_shell_parameter(tmp_path):
|
||||
async def run() -> str:
|
||||
tool = ExecTool(working_dir=str(tmp_path), timeout=5)
|
||||
return await tool.execute(command="echo shell-ok", shell="sh", login=False)
|
||||
|
||||
if sys.platform == "win32":
|
||||
return
|
||||
result = asyncio.run(run())
|
||||
|
||||
assert "shell-ok" in result
|
||||
assert "Exit code: 0" in result
|
||||
|
||||
|
||||
def test_exec_rejects_unsupported_shell(tmp_path):
|
||||
async def run() -> str:
|
||||
tool = ExecTool(working_dir=str(tmp_path), timeout=5)
|
||||
return await tool.execute(command="echo no", shell="python")
|
||||
|
||||
if sys.platform == "win32":
|
||||
return
|
||||
result = asyncio.run(run())
|
||||
|
||||
assert "unsupported shell" in result
|
||||
|
||||
|
||||
def test_exec_can_continue_with_stdin(tmp_path):
|
||||
async def run() -> tuple[str, str]:
|
||||
manager = ExecSessionManager()
|
||||
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
|
||||
stdin_tool = WriteStdinTool(manager=manager)
|
||||
command = _python_command(
|
||||
"import sys; print('ready', flush=True); "
|
||||
"line=sys.stdin.readline(); print('got:' + line.strip(), flush=True)"
|
||||
)
|
||||
|
||||
initial = await exec_tool.execute(command=command, yield_time_ms=500)
|
||||
sid = _session_id(initial)
|
||||
result = await stdin_tool.execute(session_id=sid, chars="ping\n", yield_time_ms=1000)
|
||||
return initial, result
|
||||
|
||||
initial, result = asyncio.run(run())
|
||||
assert "ready" in initial
|
||||
assert "Process running" in initial
|
||||
assert "got:ping" in result
|
||||
assert "Exit code: 0" in result
|
||||
|
||||
|
||||
def test_write_stdin_can_close_stdin(tmp_path):
|
||||
async def run() -> tuple[str, str]:
|
||||
manager = ExecSessionManager()
|
||||
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
|
||||
stdin_tool = WriteStdinTool(manager=manager)
|
||||
command = _python_command(
|
||||
"import sys; print('ready', flush=True); "
|
||||
"data=sys.stdin.read(); print('got:' + data, flush=True)"
|
||||
)
|
||||
|
||||
initial = await exec_tool.execute(command=command, yield_time_ms=500)
|
||||
sid = _session_id(initial)
|
||||
result = await stdin_tool.execute(
|
||||
session_id=sid,
|
||||
chars="payload",
|
||||
close_stdin=True,
|
||||
yield_time_ms=1000,
|
||||
)
|
||||
return initial, result
|
||||
|
||||
initial, result = asyncio.run(run())
|
||||
assert "ready" in initial
|
||||
assert "got:payload" in result
|
||||
assert "Stdin closed." in result
|
||||
assert "Exit code: 0" in result
|
||||
|
||||
|
||||
def test_write_stdin_can_terminate_session(tmp_path):
|
||||
async def run() -> tuple[str, str]:
|
||||
manager = ExecSessionManager()
|
||||
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=30, session_manager=manager)
|
||||
stdin_tool = WriteStdinTool(manager=manager)
|
||||
command = _python_command(
|
||||
"import time; print('ready', flush=True); time.sleep(30)"
|
||||
)
|
||||
|
||||
initial = await exec_tool.execute(command=command, yield_time_ms=500)
|
||||
sid = _session_id(initial)
|
||||
result = await stdin_tool.execute(
|
||||
session_id=sid,
|
||||
terminate=True,
|
||||
yield_time_ms=0,
|
||||
)
|
||||
return initial, result
|
||||
|
||||
initial, result = asyncio.run(run())
|
||||
assert "ready" in initial
|
||||
assert "Session terminated." in result
|
||||
assert "Exit code:" in result
|
||||
|
||||
|
||||
def test_write_stdin_accepts_max_output_tokens_alias(tmp_path):
|
||||
async def run() -> tuple[str, str, str]:
|
||||
manager = ExecSessionManager()
|
||||
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
|
||||
stdin_tool = WriteStdinTool(manager=manager)
|
||||
command = _python_command(
|
||||
"import time; print('A' * 2000, flush=True); time.sleep(5)"
|
||||
)
|
||||
|
||||
initial = await exec_tool.execute(command=command, yield_time_ms=0)
|
||||
sid = _session_id(initial)
|
||||
poll = await stdin_tool.execute(
|
||||
session_id=sid,
|
||||
yield_time_ms=500,
|
||||
max_output_tokens=1000,
|
||||
)
|
||||
cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0)
|
||||
return initial, poll, cleanup
|
||||
|
||||
initial, poll, cleanup = asyncio.run(run())
|
||||
assert "Process running" in initial
|
||||
assert "chars truncated" in poll
|
||||
assert "Session terminated." in cleanup
|
||||
|
||||
|
||||
def test_exec_session_mode_reuses_exec_safety_guard(tmp_path):
|
||||
manager = ExecSessionManager()
|
||||
tool = ExecTool(
|
||||
working_dir=str(tmp_path),
|
||||
deny_patterns=[r"echo\s+blocked"],
|
||||
session_manager=manager,
|
||||
)
|
||||
|
||||
result = asyncio.run(tool.execute(command="echo blocked", yield_time_ms=0))
|
||||
|
||||
assert "blocked by deny pattern" in result
|
||||
|
||||
|
||||
def test_write_stdin_reports_missing_session(tmp_path):
|
||||
manager = ExecSessionManager()
|
||||
tool = WriteStdinTool(manager=manager)
|
||||
|
||||
result = asyncio.run(tool.execute(session_id="missing", chars=""))
|
||||
|
||||
assert "exec session not found" in result
|
||||
216
tests/tools/test_file_edit_coding_enhancements.py
Normal file
216
tests/tools/test_file_edit_coding_enhancements.py
Normal file
@ -0,0 +1,216 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool
|
||||
|
||||
|
||||
def test_read_file_force_bypasses_dedup(tmp_path):
|
||||
target = tmp_path / "data.txt"
|
||||
target.write_text("alpha\n")
|
||||
tool = ReadFileTool(workspace=tmp_path)
|
||||
|
||||
first = asyncio.run(tool.execute(path=str(target)))
|
||||
second = asyncio.run(tool.execute(path=str(target)))
|
||||
forced = asyncio.run(tool.execute(path=str(target), force=True))
|
||||
|
||||
assert "alpha" in first
|
||||
assert "unchanged" in second.lower()
|
||||
assert "alpha" in forced
|
||||
assert "unchanged" not in forced.lower()
|
||||
|
||||
|
||||
def test_edit_file_can_select_occurrence(tmp_path):
|
||||
target = tmp_path / "duplicate.txt"
|
||||
target.write_text("one\nsame\ntwo\nsame\n")
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text="same",
|
||||
new_text="changed",
|
||||
occurrence=2,
|
||||
))
|
||||
|
||||
assert "Successfully edited" in result
|
||||
assert target.read_text() == "one\nsame\ntwo\nchanged\n"
|
||||
|
||||
|
||||
def test_edit_file_expected_replacements_guards_replace_all(tmp_path):
|
||||
target = tmp_path / "duplicate.txt"
|
||||
target.write_text("same\nsame\n")
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text="same",
|
||||
new_text="changed",
|
||||
replace_all=True,
|
||||
expected_replacements=1,
|
||||
))
|
||||
|
||||
assert "expected 1 replacements but would make 2" in result
|
||||
assert target.read_text() == "same\nsame\n"
|
||||
|
||||
|
||||
def test_edit_file_expected_replacements_allows_replace_all_when_count_matches(tmp_path):
|
||||
target = tmp_path / "duplicate.txt"
|
||||
target.write_text("same\nsame\n")
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text="same",
|
||||
new_text="changed",
|
||||
replace_all=True,
|
||||
expected_replacements=2,
|
||||
))
|
||||
|
||||
assert "Successfully edited" in result
|
||||
assert target.read_text() == "changed\nchanged\n"
|
||||
|
||||
|
||||
def test_edit_file_can_select_nearest_line_hint(tmp_path):
|
||||
target = tmp_path / "duplicate.txt"
|
||||
target.write_text("one\nsame\ntwo\nsame\n")
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text="same",
|
||||
new_text="changed",
|
||||
line_hint=4,
|
||||
))
|
||||
|
||||
assert "Successfully edited" in result
|
||||
assert target.read_text() == "one\nsame\ntwo\nchanged\n"
|
||||
|
||||
|
||||
def test_edit_file_can_edit_ipynb_as_json(tmp_path):
|
||||
target = tmp_path / "analysis.ipynb"
|
||||
target.write_text('{"cells": []}')
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text='"cells": []',
|
||||
new_text='"cells": [{"cell_type": "markdown", "source": "hi"}]',
|
||||
))
|
||||
|
||||
assert "Successfully edited" in result
|
||||
assert '"source": "hi"' in target.read_text()
|
||||
|
||||
|
||||
def test_edit_file_multiple_match_hint_mentions_occurrence(tmp_path):
|
||||
target = tmp_path / "duplicate.txt"
|
||||
target.write_text("same\nsame\n")
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text="same",
|
||||
new_text="changed",
|
||||
))
|
||||
|
||||
assert "old_text appears 2 times" in result
|
||||
assert "occurrence" in result
|
||||
assert target.read_text() == "same\nsame\n"
|
||||
|
||||
|
||||
def test_edit_file_rejects_ambiguous_line_hint(tmp_path):
|
||||
target = tmp_path / "duplicate.txt"
|
||||
target.write_text("same\nmiddle\nsame\n")
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text="same",
|
||||
new_text="changed",
|
||||
line_hint=2,
|
||||
))
|
||||
|
||||
assert "line_hint 2 is ambiguous" in result
|
||||
assert target.read_text() == "same\nmiddle\nsame\n"
|
||||
|
||||
|
||||
def test_edit_file_rejects_occurrence_with_replace_all(tmp_path):
|
||||
target = tmp_path / "duplicate.txt"
|
||||
target.write_text("same\nsame\n")
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text="same",
|
||||
new_text="changed",
|
||||
occurrence=1,
|
||||
replace_all=True,
|
||||
))
|
||||
|
||||
assert "occurrence cannot be used with replace_all" in result
|
||||
assert target.read_text() == "same\nsame\n"
|
||||
|
||||
|
||||
def test_edit_file_rejects_line_hint_with_replace_all(tmp_path):
|
||||
target = tmp_path / "duplicate.txt"
|
||||
target.write_text("same\nsame\n")
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text="same",
|
||||
new_text="changed",
|
||||
line_hint=1,
|
||||
replace_all=True,
|
||||
))
|
||||
|
||||
assert "line_hint cannot be used with replace_all" in result
|
||||
assert target.read_text() == "same\nsame\n"
|
||||
|
||||
|
||||
def test_edit_file_rejects_line_hint_with_occurrence(tmp_path):
|
||||
target = tmp_path / "duplicate.txt"
|
||||
target.write_text("same\nsame\n")
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text="same",
|
||||
new_text="changed",
|
||||
occurrence=1,
|
||||
line_hint=1,
|
||||
))
|
||||
|
||||
assert "line_hint cannot be used with occurrence" in result
|
||||
assert target.read_text() == "same\nsame\n"
|
||||
|
||||
|
||||
def test_edit_file_rejects_zero_occurrence(tmp_path):
|
||||
target = tmp_path / "duplicate.txt"
|
||||
target.write_text("same\n")
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text="same",
|
||||
new_text="changed",
|
||||
occurrence=0,
|
||||
))
|
||||
|
||||
assert "occurrence must be >= 1" in result
|
||||
assert target.read_text() == "same\n"
|
||||
|
||||
|
||||
def test_edit_file_rejects_zero_line_hint(tmp_path):
|
||||
target = tmp_path / "duplicate.txt"
|
||||
target.write_text("same\n")
|
||||
tool = EditFileTool(workspace=tmp_path)
|
||||
|
||||
result = asyncio.run(tool.execute(
|
||||
path=str(target),
|
||||
old_text="same",
|
||||
new_text="changed",
|
||||
line_hint=0,
|
||||
))
|
||||
|
||||
assert "line_hint must be >= 1" in result
|
||||
assert target.read_text() == "same\n"
|
||||
@ -1,147 +0,0 @@
|
||||
"""Tests for NotebookEditTool — Jupyter .ipynb editing."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.notebook import NotebookEditTool
|
||||
|
||||
|
||||
def _make_notebook(cells: list[dict] | None = None, nbformat: int = 4, nbformat_minor: int = 5) -> dict:
|
||||
"""Build a minimal valid .ipynb structure."""
|
||||
return {
|
||||
"nbformat": nbformat,
|
||||
"nbformat_minor": nbformat_minor,
|
||||
"metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}},
|
||||
"cells": cells or [],
|
||||
}
|
||||
|
||||
|
||||
def _code_cell(source: str, cell_id: str | None = None) -> dict:
|
||||
cell = {"cell_type": "code", "source": source, "metadata": {}, "outputs": [], "execution_count": None}
|
||||
if cell_id:
|
||||
cell["id"] = cell_id
|
||||
return cell
|
||||
|
||||
|
||||
def _md_cell(source: str, cell_id: str | None = None) -> dict:
|
||||
cell = {"cell_type": "markdown", "source": source, "metadata": {}}
|
||||
if cell_id:
|
||||
cell["id"] = cell_id
|
||||
return cell
|
||||
|
||||
|
||||
def _write_nb(tmp_path, name: str, nb: dict) -> str:
|
||||
p = tmp_path / name
|
||||
p.write_text(json.dumps(nb), encoding="utf-8")
|
||||
return str(p)
|
||||
|
||||
|
||||
class TestNotebookEdit:
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return NotebookEditTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replace_cell_content(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("print('hello')"), _code_cell("x = 1")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
result = await tool.execute(path=path, cell_index=0, new_source="print('world')")
|
||||
assert "Successfully" in result
|
||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
||||
assert saved["cells"][0]["source"] == "print('world')"
|
||||
assert saved["cells"][1]["source"] == "x = 1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insert_cell_after_target(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("cell 0"), _code_cell("cell 1")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
result = await tool.execute(path=path, cell_index=0, new_source="inserted", edit_mode="insert")
|
||||
assert "Successfully" in result
|
||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
||||
assert len(saved["cells"]) == 3
|
||||
assert saved["cells"][0]["source"] == "cell 0"
|
||||
assert saved["cells"][1]["source"] == "inserted"
|
||||
assert saved["cells"][2]["source"] == "cell 1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_cell(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("A"), _code_cell("B"), _code_cell("C")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
result = await tool.execute(path=path, cell_index=1, edit_mode="delete")
|
||||
assert "Successfully" in result
|
||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
||||
assert len(saved["cells"]) == 2
|
||||
assert saved["cells"][0]["source"] == "A"
|
||||
assert saved["cells"][1]["source"] == "C"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_new_notebook_from_scratch(self, tool, tmp_path):
|
||||
path = str(tmp_path / "new.ipynb")
|
||||
result = await tool.execute(path=path, cell_index=0, new_source="# Hello", edit_mode="insert", cell_type="markdown")
|
||||
assert "Successfully" in result or "created" in result.lower()
|
||||
saved = json.loads((tmp_path / "new.ipynb").read_text())
|
||||
assert saved["nbformat"] == 4
|
||||
assert len(saved["cells"]) == 1
|
||||
assert saved["cells"][0]["cell_type"] == "markdown"
|
||||
assert saved["cells"][0]["source"] == "# Hello"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_cell_index_error(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("only cell")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
result = await tool.execute(path=path, cell_index=5, new_source="x")
|
||||
assert "Error" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_ipynb_rejected(self, tool, tmp_path):
|
||||
f = tmp_path / "script.py"
|
||||
f.write_text("pass")
|
||||
result = await tool.execute(path=str(f), cell_index=0, new_source="x")
|
||||
assert "Error" in result
|
||||
assert ".ipynb" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserves_metadata_and_outputs(self, tool, tmp_path):
|
||||
cell = _code_cell("old")
|
||||
cell["outputs"] = [{"output_type": "stream", "text": "hello\n"}]
|
||||
cell["execution_count"] = 42
|
||||
nb = _make_notebook([cell])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
await tool.execute(path=path, cell_index=0, new_source="new")
|
||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
||||
assert saved["metadata"]["kernelspec"]["language"] == "python"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nbformat_45_generates_cell_id(self, tool, tmp_path):
|
||||
nb = _make_notebook([], nbformat_minor=5)
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
await tool.execute(path=path, cell_index=0, new_source="x = 1", edit_mode="insert")
|
||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
||||
assert "id" in saved["cells"][0]
|
||||
assert len(saved["cells"][0]["id"]) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_insert_with_cell_type_markdown(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("code")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
await tool.execute(path=path, cell_index=0, new_source="# Title", edit_mode="insert", cell_type="markdown")
|
||||
saved = json.loads((tmp_path / "test.ipynb").read_text())
|
||||
assert saved["cells"][1]["cell_type"] == "markdown"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_edit_mode_rejected(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("code")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
result = await tool.execute(path=path, cell_index=0, new_source="x", edit_mode="replcae")
|
||||
assert "Error" in result
|
||||
assert "edit_mode" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_cell_type_rejected(self, tool, tmp_path):
|
||||
nb = _make_notebook([_code_cell("code")])
|
||||
path = _write_nb(tmp_path, "test.ipynb", nb)
|
||||
result = await tool.execute(path=path, cell_index=0, new_source="x", cell_type="raw")
|
||||
assert "Error" in result
|
||||
assert "cell_type" in result
|
||||
@ -89,9 +89,11 @@ def test_discover_finds_concrete_tools():
|
||||
loader = ToolLoader()
|
||||
discovered = loader.discover()
|
||||
class_names = {cls.__name__ for cls in discovered}
|
||||
assert "ApplyPatchTool" in class_names
|
||||
assert "ExecTool" in class_names
|
||||
assert "MessageTool" in class_names
|
||||
assert "SpawnTool" in class_names
|
||||
assert "WriteStdinTool" in class_names
|
||||
|
||||
|
||||
def test_discover_excludes_abstract_and_mcp():
|
||||
@ -406,7 +408,7 @@ def test_loader_registers_same_tools_as_old_hardcoded():
|
||||
|
||||
expected = {
|
||||
"read_file", "write_file", "edit_file", "list_dir",
|
||||
"grep", "notebook_edit", "exec", "web_search", "web_fetch",
|
||||
"grep", "exec", "web_search", "web_fetch",
|
||||
"message", "spawn", "cron",
|
||||
}
|
||||
actual = set(registered)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user