mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-22 17:42:24 +00:00
Merge PR #3923: feat(tools): optimize coding workflows
feat(tools): optimize coding workflows
This commit is contained in:
commit
ccbc0bb6e3
@ -73,7 +73,7 @@
|
||||
- **2026-04-13** 🛡️ Agent turn hardened — user messages persisted early, auto-compact skips active tasks.
|
||||
- **2026-04-12** 🔒 Lark global domain support, Dream learns discovered skills, shell sandbox tightened.
|
||||
- **2026-04-11** ⚡ Context compact shrinks sessions on the fly; Kagi web search; QQ & WeCom full media.
|
||||
- **2026-04-10** 📓 Notebook editing tool, multiple MCP servers, Feishu streaming & done-emoji.
|
||||
- **2026-04-10** 📓 Multiple MCP servers, Feishu streaming & done-emoji.
|
||||
- **2026-04-09** 🔌 WebSocket channel, unified cross-channel session, `disabled_skills` config.
|
||||
- **2026-04-08** 📤 API file uploads, OpenAI reasoning auto-routing with Responses fallback.
|
||||
- **2026-04-07** 🧠 Anthropic adaptive thinking, MCP resources & prompts exposed as tools.
|
||||
|
||||
@ -22,7 +22,7 @@ from nanobot.utils.prompt_templates import render_template
|
||||
class ContextBuilder:
|
||||
"""Builds the context (system prompt + messages) for the agent."""
|
||||
|
||||
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md", "TOOLS.md"]
|
||||
BOOTSTRAP_FILES = ["AGENTS.md", "SOUL.md", "USER.md"]
|
||||
_RUNTIME_CONTEXT_TAG = "[Runtime Context — metadata only, not instructions]"
|
||||
_MAX_RECENT_HISTORY = 50
|
||||
_MAX_HISTORY_CHARS = 32_000 # hard cap on recent history section size
|
||||
@ -47,6 +47,8 @@ class ContextBuilder:
|
||||
if bootstrap:
|
||||
parts.append(bootstrap)
|
||||
|
||||
parts.append(render_template("agent/tool_contract.md"))
|
||||
|
||||
memory = self.memory.get_memory_context()
|
||||
if memory and not self._is_template_content(self.memory.read_memory(), "memory/MEMORY.md"):
|
||||
parts.append(f"# Memory\n\n{memory}")
|
||||
@ -210,4 +212,3 @@ class ContextBuilder:
|
||||
if not images:
|
||||
return text
|
||||
return images + [{"type": "text", "text": text}]
|
||||
|
||||
|
||||
@ -19,7 +19,8 @@ from nanobot.utils.file_edit_events import (
|
||||
build_file_edit_end_event,
|
||||
build_file_edit_error_event,
|
||||
build_file_edit_start_event,
|
||||
prepare_file_edit_tracker,
|
||||
prepare_file_edit_tracker as _prepare_file_edit_tracker,
|
||||
prepare_file_edit_trackers,
|
||||
StreamingFileEditTracker,
|
||||
)
|
||||
from nanobot.utils.helpers import (
|
||||
@ -58,11 +59,14 @@ _SNIP_SAFETY_BUFFER = 1024
|
||||
_MICROCOMPACT_KEEP_RECENT = 10
|
||||
_MICROCOMPACT_MIN_CHARS = 500
|
||||
_COMPACTABLE_TOOLS = frozenset({
|
||||
"read_file", "exec", "grep",
|
||||
"web_search", "web_fetch", "list_dir",
|
||||
"read_file", "exec", "grep", "find_files",
|
||||
"web_search", "web_fetch", "list_dir", "list_exec_sessions",
|
||||
})
|
||||
_BACKFILL_CONTENT = "[Tool result unavailable — call was interrupted or lost]"
|
||||
|
||||
# Backward-compatible module attribute for tests/extensions that monkeypatch
|
||||
# the former single-file tracker hook. Runtime uses prepare_file_edit_trackers.
|
||||
prepare_file_edit_tracker = _prepare_file_edit_tracker
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -857,8 +861,8 @@ class AgentRunner:
|
||||
and on_progress_accepts_file_edit_events(spec.progress_callback)
|
||||
)
|
||||
progress_callback = spec.progress_callback if emit_file_edit_events else None
|
||||
file_edit_tracker = (
|
||||
prepare_file_edit_tracker(
|
||||
file_edit_trackers = (
|
||||
prepare_file_edit_trackers(
|
||||
call_id=tool_call.id,
|
||||
tool_name=tool_call.name,
|
||||
tool=tool,
|
||||
@ -868,13 +872,13 @@ class AgentRunner:
|
||||
if progress_callback is not None
|
||||
else None
|
||||
)
|
||||
if file_edit_tracker is not None and progress_callback is not None:
|
||||
if file_edit_trackers and progress_callback is not None:
|
||||
await invoke_file_edit_progress(
|
||||
progress_callback,
|
||||
[build_file_edit_start_event(
|
||||
file_edit_tracker,
|
||||
params if isinstance(params, dict) else None,
|
||||
)],
|
||||
) for file_edit_tracker in file_edit_trackers],
|
||||
)
|
||||
try:
|
||||
if tool is not None:
|
||||
@ -884,10 +888,13 @@ class AgentRunner:
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except BaseException as exc:
|
||||
if file_edit_tracker is not None and progress_callback is not None:
|
||||
if file_edit_trackers and progress_callback is not None:
|
||||
await invoke_file_edit_progress(
|
||||
progress_callback,
|
||||
[build_file_edit_error_event(file_edit_tracker, str(exc))],
|
||||
[
|
||||
build_file_edit_error_event(file_edit_tracker, str(exc))
|
||||
for file_edit_tracker in file_edit_trackers
|
||||
],
|
||||
)
|
||||
event = {
|
||||
"name": tool_call.name,
|
||||
@ -910,10 +917,13 @@ class AgentRunner:
|
||||
return payload, event, None
|
||||
|
||||
if isinstance(result, str) and result.startswith("Error"):
|
||||
if file_edit_tracker is not None and progress_callback is not None:
|
||||
if file_edit_trackers and progress_callback is not None:
|
||||
await invoke_file_edit_progress(
|
||||
progress_callback,
|
||||
[build_file_edit_error_event(file_edit_tracker, result)],
|
||||
[
|
||||
build_file_edit_error_event(file_edit_tracker, result)
|
||||
for file_edit_tracker in file_edit_trackers
|
||||
],
|
||||
)
|
||||
event = {
|
||||
"name": tool_call.name,
|
||||
@ -933,13 +943,13 @@ class AgentRunner:
|
||||
return result + hint, event, RuntimeError(result)
|
||||
return result + hint, event, None
|
||||
|
||||
if file_edit_tracker is not None and progress_callback is not None:
|
||||
if file_edit_trackers and progress_callback is not None:
|
||||
await invoke_file_edit_progress(
|
||||
progress_callback,
|
||||
[build_file_edit_end_event(
|
||||
file_edit_tracker,
|
||||
params if isinstance(params, dict) else None,
|
||||
)],
|
||||
) for file_edit_tracker in file_edit_trackers],
|
||||
)
|
||||
|
||||
detail = "" if result is None else str(result)
|
||||
|
||||
431
nanobot/agent/tools/apply_patch.py
Normal file
431
nanobot/agent/tools/apply_patch.py
Normal file
@ -0,0 +1,431 @@
|
||||
"""Structured patch editing tool for coding workflows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import difflib
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
from nanobot.agent.tools.base import tool_parameters
|
||||
from nanobot.agent.tools.filesystem import _FsTool
|
||||
from nanobot.agent.tools.schema import BooleanSchema, StringSchema, tool_parameters_schema
|
||||
|
||||
|
||||
PatchKind = Literal["add", "delete", "update"]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _Hunk:
|
||||
header: str | None
|
||||
lines: list[tuple[str, str]]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _PatchOp:
|
||||
kind: PatchKind
|
||||
path: str
|
||||
new_path: str | None = None
|
||||
add_lines: list[str] | None = None
|
||||
hunks: list[_Hunk] | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _PatchSummary:
|
||||
action: str
|
||||
path: str
|
||||
added: int = 0
|
||||
deleted: int = 0
|
||||
new_path: str | None = None
|
||||
|
||||
|
||||
class _PatchError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
_ABSOLUTE_WINDOWS_RE = re.compile(r"^[A-Za-z]:[\\/]")
|
||||
|
||||
|
||||
def _is_file_header(line: str) -> bool:
|
||||
return (
|
||||
line.startswith("*** Add File: ")
|
||||
or line.startswith("*** Delete File: ")
|
||||
or line.startswith("*** Update File: ")
|
||||
)
|
||||
|
||||
|
||||
def _validate_relative_path(path: str) -> str:
    """Check that *path* is a safe relative path and return it stripped.

    Raises:
        _PatchError: if the path is blank, contains a null byte, starts with
            "~", "/" or a backslash, starts with a Windows drive prefix, or
            has a ".." segment.
    """
    candidate = path.strip()
    if candidate == "":
        raise _PatchError("patch path cannot be empty")
    if "\0" in candidate:
        raise _PatchError(f"patch path contains a null byte: {path!r}")

    is_rooted = candidate.startswith(("~", "/", "\\"))
    if is_rooted or _ABSOLUTE_WINDOWS_RE.match(candidate):
        raise _PatchError(f"patch path must be relative: {path}")

    # Split on either separator so "a\..\b" is rejected just like "a/../b".
    segments = re.split(r"[\\/]+", candidate)
    if ".." in segments:
        raise _PatchError(f"patch path must not contain '..': {path}")
    return candidate
|
||||
|
||||
|
||||
def _lines_to_text(lines: list[str]) -> str:
|
||||
if not lines:
|
||||
return ""
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
def _text_line_count(text: str) -> int:
|
||||
if not text:
|
||||
return 0
|
||||
return len(text.splitlines())
|
||||
|
||||
|
||||
def _line_diff_stats(before: str, after: str) -> tuple[int, int]:
|
||||
before_lines = before.replace("\r\n", "\n").splitlines()
|
||||
after_lines = after.replace("\r\n", "\n").splitlines()
|
||||
added = 0
|
||||
deleted = 0
|
||||
matcher = difflib.SequenceMatcher(a=before_lines, b=after_lines, autojunk=False)
|
||||
for tag, i1, i2, j1, j2 in matcher.get_opcodes():
|
||||
if tag == "equal":
|
||||
continue
|
||||
if tag in ("replace", "delete"):
|
||||
deleted += i2 - i1
|
||||
if tag in ("replace", "insert"):
|
||||
added += j2 - j1
|
||||
return added, deleted
|
||||
|
||||
|
||||
def _format_summary(summary: _PatchSummary) -> str:
|
||||
path = (
|
||||
f"{summary.path} -> {summary.new_path}"
|
||||
if summary.new_path
|
||||
else summary.path
|
||||
)
|
||||
stats = ""
|
||||
if summary.added or summary.deleted:
|
||||
stats = f" (+{summary.added}/-{summary.deleted})"
|
||||
return f"- {summary.action} {path}{stats}"
|
||||
|
||||
|
||||
def _parse_patch(patch: str) -> list[_PatchOp]:
    """Parse ``*** Begin Patch`` ... ``*** End Patch`` text into operations.

    Line endings are normalised to LF and a single trailing empty line is
    ignored. Between the two envelope lines each section must start with an
    Add File, Delete File or Update File header.

    Raises:
        _PatchError: for a missing envelope, an unknown header, a malformed
            section body, an invalid path, or a patch with no operations.
    """
    add_tag = "*** Add File: "
    delete_tag = "*** Delete File: "
    update_tag = "*** Update File: "
    move_tag = "*** Move to: "

    rows = patch.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    if rows and rows[-1] == "":
        del rows[-1]
    if not rows or rows[0] != "*** Begin Patch":
        raise _PatchError("patch must start with '*** Begin Patch'")
    if len(rows) < 2 or rows[-1] != "*** End Patch":
        raise _PatchError("patch must end with '*** End Patch'")

    parsed: list[_PatchOp] = []
    stop = len(rows) - 1  # index of the "*** End Patch" line
    pos = 1
    while pos < stop:
        row = rows[pos]
        pos += 1

        if row.startswith(add_tag):
            path = _validate_relative_path(row.removeprefix(add_tag))
            body: list[str] = []
            # Everything up to the next file header is the new file's content.
            while pos < stop and not _is_file_header(rows[pos]):
                entry = rows[pos]
                if not entry.startswith("+"):
                    raise _PatchError(f"Add File lines must start with '+': {entry!r}")
                body.append(entry[1:])
                pos += 1
            parsed.append(_PatchOp(kind="add", path=path, add_lines=body))

        elif row.startswith(delete_tag):
            path = _validate_relative_path(row.removeprefix(delete_tag))
            parsed.append(_PatchOp(kind="delete", path=path))

        elif row.startswith(update_tag):
            path = _validate_relative_path(row.removeprefix(update_tag))
            moved_to: str | None = None
            # An optional rename line may directly follow the update header.
            if pos < stop and rows[pos].startswith(move_tag):
                moved_to = _validate_relative_path(rows[pos].removeprefix(move_tag))
                pos += 1

            hunks: list[_Hunk] = []
            while pos < stop and not _is_file_header(rows[pos]):
                marker_row = rows[pos]
                if not marker_row.startswith("@@"):
                    raise _PatchError(
                        f"Update File sections require '@@' hunks: {marker_row!r}"
                    )
                header = marker_row[2:].strip() or None
                pos += 1

                changes: list[tuple[str, str]] = []
                while pos < stop:
                    entry = rows[pos]
                    if entry.startswith("@@") or _is_file_header(entry):
                        break
                    pos += 1
                    if entry == "*** End of File":
                        # Consumed, and it closes the current hunk.
                        break
                    if entry == r"\ No newline at end of file":
                        # Consumed and ignored.
                        continue
                    if not entry or entry[0] not in {" ", "+", "-"}:
                        raise _PatchError(
                            f"Hunk lines must start with ' ', '+', or '-': {entry!r}"
                        )
                    changes.append((entry[0], entry[1:]))

                if not changes:
                    raise _PatchError(f"Update File hunk is empty: {path}")
                hunks.append(_Hunk(header=header, lines=changes))

            if not hunks and moved_to is None:
                raise _PatchError(f"Update File requires at least one hunk or Move to: {path}")
            parsed.append(_PatchOp(kind="update", path=path, new_path=moved_to, hunks=hunks))

        else:
            raise _PatchError(f"unknown patch header: {row!r}")

    if not parsed:
        raise _PatchError("patch contains no file operations")
    return parsed
|
||||
|
||||
|
||||
def _find_with_eof_fallback(content: str, needle: str, start: int) -> tuple[int, int]:
|
||||
pos = content.find(needle, start)
|
||||
if pos >= 0:
|
||||
return pos, len(needle)
|
||||
if needle.endswith("\n"):
|
||||
trimmed = needle[:-1]
|
||||
pos = content.find(trimmed, start)
|
||||
if pos >= 0 and pos + len(trimmed) == len(content):
|
||||
return pos, len(trimmed)
|
||||
return -1, 0
|
||||
|
||||
|
||||
def _line_offset(content: str, line_number: int) -> int:
|
||||
if line_number <= 1:
|
||||
return 0
|
||||
offset = 0
|
||||
for current, line in enumerate(content.splitlines(keepends=True), start=1):
|
||||
if current >= line_number:
|
||||
return offset
|
||||
offset += len(line)
|
||||
return len(content)
|
||||
|
||||
|
||||
def _line_hint(header: str | None) -> int | None:
|
||||
if not header:
|
||||
return None
|
||||
match = re.search(r"-(\d+)(?:,\d+)?", header)
|
||||
return int(match.group(1)) if match else None
|
||||
|
||||
|
||||
def _hunk_mismatch(path: str, old_text: str, content: str, header: str | None) -> str:
|
||||
lines = content.splitlines(keepends=True)
|
||||
old_lines = old_text.splitlines(keepends=True)
|
||||
window = max(1, len(old_lines))
|
||||
best_ratio, best_start = -1.0, 0
|
||||
best_lines: list[str] = []
|
||||
for i in range(max(1, len(lines) - window + 1)):
|
||||
current = lines[i : i + window]
|
||||
ratio = difflib.SequenceMatcher(None, "".join(old_lines), "".join(current)).ratio()
|
||||
if ratio > best_ratio:
|
||||
best_ratio, best_start, best_lines = ratio, i, current
|
||||
|
||||
label = f" after header {header!r}" if header else ""
|
||||
if best_ratio <= 0:
|
||||
return f"hunk does not match {path}{label}"
|
||||
diff = "\n".join(difflib.unified_diff(
|
||||
old_lines,
|
||||
best_lines,
|
||||
fromfile="patch hunk",
|
||||
tofile=f"{path} (actual, line {best_start + 1})",
|
||||
lineterm="",
|
||||
))
|
||||
return (
|
||||
f"hunk does not match {path}{label}. "
|
||||
f"Best match ({best_ratio:.0%} similar) at line {best_start + 1}:\n{diff}"
|
||||
)
|
||||
|
||||
|
||||
def _apply_hunks(path: str, content: str, hunks: list[_Hunk]) -> str:
    """Apply *hunks* to *content* in order and return the resulting text.

    For each hunk the context and removed lines form the text to find, and
    the context and added lines form its replacement. The search starts at
    the hinted line when the header holds one, otherwise at the header text
    if it occurs after the previous hunk, otherwise where the previous hunk
    ended. Without a line hint a failed search is retried from the top of
    the file. A hunk with no old text is inserted at the search start.

    Raises:
        _PatchError: when a hunk's old text cannot be located.
    """
    resume_at = 0
    for hunk in hunks:
        before = _lines_to_text([text for marker, text in hunk.lines if marker in {" ", "-"}])
        after = _lines_to_text([text for marker, text in hunk.lines if marker in {" ", "+"}])

        hinted_line = _line_hint(hunk.header) if hunk.header else None
        origin = resume_at
        if hinted_line is not None:
            origin = _line_offset(content, hinted_line)
        elif hunk.header:
            header_at = content.find(hunk.header, resume_at)
            if header_at >= 0:
                origin = header_at

        if not before:
            # Pure insertion: nothing to match, so splice in at the origin.
            at, width = origin, 0
        else:
            at, width = _find_with_eof_fallback(content, before, origin)
            if at < 0 and origin != 0 and hinted_line is None:
                at, width = _find_with_eof_fallback(content, before, 0)
            if at < 0:
                raise _PatchError(_hunk_mismatch(path, before, content, hunk.header))

        content = content[:at] + after + content[at + width:]
        resume_at = at + len(after)
    return content
|
||||
|
||||
|
||||
@tool_parameters(
    tool_parameters_schema(
        patch=StringSchema(
            "Full patch text. Use *** Begin Patch / *** End Patch and file sections "
            "for Add File, Update File, Delete File, and optional Move to.",
            min_length=1,
        ),
        dry_run=BooleanSchema(
            description="Validate and summarize the patch without writing files.",
            default=False,
        ),
        required=["patch"],
    )
)
class ApplyPatchTool(_FsTool):
    """Apply a structured multi-file patch.

    All operations are first staged in memory; the filesystem is touched only
    after every operation has validated, and touched files are restored from
    in-memory backups if a write or delete fails part-way.
    """

    _scopes = {"core", "subagent"}

    @property
    def name(self) -> str:
        """Tool name exposed to the model."""
        return "apply_patch"

    @property
    def description(self) -> str:
        """Usage guidance exposed to the model."""
        return (
            "Default tool for code edits. Apply a structured patch with "
            "*** Begin Patch and *** End Patch. Supports Add File, Update File, "
            "Delete File, and Move to across one or more files. Use this for "
            "multi-file changes, structural edits, generated code, or any edit "
            "where a reviewable patch is clearer than an exact replacement. "
            "Paths must be relative. Set dry_run=true to validate and preview "
            "the change summary without writing files. Use edit_file only for "
            "small exact replacements copied from read_file."
        )

    async def execute(self, patch: str, dry_run: bool = False, **kwargs: Any) -> str:
        """Parse, validate and (unless *dry_run*) apply *patch*.

        Args:
            patch: Full patch text in the ``*** Begin Patch`` format.
            dry_run: When true, return the summary without writing anything.
            **kwargs: Ignored.

        Returns:
            A summary with one bullet per operation on success, or a string
            starting with ``"Error"`` describing the failure. Exceptions are
            converted to error strings rather than propagated.
        """
        try:
            ops = _parse_patch(patch)
            # Final content per path, paths to remove, and report lines.
            writes: dict[Path, str] = {}
            deletes: set[Path] = set()
            summaries: list[_PatchSummary] = []

            def staged_exists(candidate: Path) -> bool:
                """Whether *candidate* exists once the ops staged so far are applied."""
                if candidate in writes:
                    return True
                # A file deleted earlier in this patch no longer counts, even
                # though it is still on disk until the patch is applied.
                return candidate not in deletes and candidate.exists()

            for op in ops:
                # NOTE(review): _resolve comes from _FsTool; presumably it maps the
                # relative path into the workspace and enforces access rules.
                source = self._resolve(op.path)
                if op.kind == "add":
                    if staged_exists(source):
                        raise _PatchError(f"file to add already exists: {op.path}")
                    new_content = _lines_to_text(op.add_lines or [])
                    writes[source] = new_content
                    deletes.discard(source)
                    summaries.append(_PatchSummary(
                        action="add",
                        path=op.path,
                        added=_text_line_count(new_content),
                    ))
                    continue

                if op.kind == "delete":
                    pending_content = writes.get(source)
                    if pending_content is None and not staged_exists(source):
                        raise _PatchError(f"file to delete does not exist: {op.path}")
                    if pending_content is None and not source.is_file():
                        raise _PatchError(f"path to delete is not a file: {op.path}")
                    deleted_lines = 0
                    if pending_content is not None:
                        deleted_lines = _text_line_count(pending_content)
                    else:
                        raw = source.read_bytes()
                        try:
                            deleted_lines = _text_line_count(raw.decode("utf-8"))
                        except UnicodeDecodeError:
                            # Non-text files can still be deleted; just report 0 lines.
                            deleted_lines = 0
                    deletes.add(source)
                    writes.pop(source, None)
                    summaries.append(_PatchSummary(
                        action="delete",
                        path=op.path,
                        deleted=deleted_lines,
                    ))
                    continue

                # Remaining kind: update (optionally with a move).
                pending_content = writes.get(source)
                if pending_content is None and not staged_exists(source):
                    raise _PatchError(f"file to update does not exist: {op.path}")
                if pending_content is None and not source.is_file():
                    raise _PatchError(f"path to update is not a file: {op.path}")
                if pending_content is not None:
                    content = pending_content
                else:
                    raw = source.read_bytes()
                    try:
                        content = raw.decode("utf-8")
                    except UnicodeDecodeError as exc:
                        raise _PatchError(f"file to update is not UTF-8 text: {op.path}") from exc
                # Hunks are matched on LF text; CRLF is restored afterwards.
                uses_crlf = "\r\n" in content
                content = content.replace("\r\n", "\n")
                new_content = _apply_hunks(op.path, content, op.hunks or [])
                added, deleted = _line_diff_stats(content, new_content)
                if uses_crlf:
                    new_content = new_content.replace("\n", "\r\n")

                target = self._resolve(op.new_path) if op.new_path else source
                if op.new_path and target != source and staged_exists(target):
                    raise _PatchError(f"move target already exists: {op.new_path}")
                writes[target] = new_content
                deletes.discard(target)
                if target != source:
                    deletes.add(source)
                    writes.pop(source, None)
                summaries.append(_PatchSummary(
                    action="move" if op.new_path else "update",
                    path=op.path,
                    new_path=op.new_path,
                    added=added,
                    deleted=deleted,
                ))

            if dry_run:
                return (
                    "Patch dry-run succeeded:\n"
                    + "\n".join(_format_summary(summary) for summary in summaries)
                )

            # Snapshot every touched path (None = did not exist) for rollback.
            backups: dict[Path, bytes | None] = {}
            for path in set(writes) | deletes:
                backups[path] = path.read_bytes() if path.exists() else None

            try:
                for path in deletes:
                    if path.exists():
                        path.unlink()
                for path, content in writes.items():
                    path.parent.mkdir(parents=True, exist_ok=True)
                    # newline="" writes the staged line endings untranslated.
                    path.write_text(content, encoding="utf-8", newline="")
            except Exception:
                # Restore each touched path to its snapshot, then let the
                # outer handlers turn the failure into an error string.
                for path, data in backups.items():
                    if data is None:
                        if path.exists():
                            path.unlink()
                    else:
                        path.parent.mkdir(parents=True, exist_ok=True)
                        path.write_bytes(data)
                raise

            for path in set(writes) | deletes:
                self._file_states.record_write(path)
            return (
                "Patch applied:\n"
                + "\n".join(_format_summary(summary) for summary in summaries)
            )
        except PermissionError as exc:
            return f"Error: {exc}"
        except _PatchError as exc:
            return f"Error applying patch: {exc}"
        except Exception as exc:
            return f"Error applying patch: {exc}"
|
||||
591
nanobot/agent/tools/exec_session.py
Normal file
591
nanobot/agent/tools/exec_session.py
Normal file
@ -0,0 +1,591 @@
|
||||
"""Session support for long-running exec workflows."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import shutil
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||
from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema
|
||||
|
||||
|
||||
# Poll wait before returning recent output, in milliseconds.
DEFAULT_YIELD_MS = 1000
MAX_YIELD_MS = 30_000

# Wait budget for a "wait_for" text match, in milliseconds.
DEFAULT_WAIT_FOR_MS = 10_000
MAX_WAIT_FOR_MS = 120_000

# Output budget per poll, in characters.
DEFAULT_MAX_OUTPUT_CHARS = 10_000
MAX_OUTPUT_CHARS = 50_000
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _SessionPoll:
|
||||
output: str
|
||||
done: bool
|
||||
exit_code: int | None
|
||||
elapsed_s: float = 0.0
|
||||
timed_out: bool = False
|
||||
terminated: bool = False
|
||||
stdin_closed: bool = False
|
||||
truncated_chars: int = 0
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ExecSessionInfo:
|
||||
session_id: str
|
||||
command: str
|
||||
cwd: str
|
||||
elapsed_s: float
|
||||
idle_s: float
|
||||
remaining_s: float
|
||||
returncode: int | None
|
||||
|
||||
|
||||
class _ExecSession:
    """A running subprocess plus the buffered output read from its pipes.

    Two background tasks drain stdout and stderr into a shared chunk list;
    each ``poll`` hands back and clears whatever has accumulated.
    """

    def __init__(
        self,
        *,
        session_id: str,
        process: asyncio.subprocess.Process,
        command: str,
        cwd: str,
        timeout: int,
    ) -> None:
        """Wrap *process* and start reading its output.

        Must be called with a running event loop, because reader tasks are
        created immediately. *timeout* is the session lifetime in seconds,
        counted from now.
        """
        self.session_id = session_id
        self.process = process
        self.command = command
        self.cwd = cwd
        now = time.monotonic()
        self.started_at = now
        self.deadline = now + timeout
        self.last_access = now
        self._chunks: list[str] = []
        self._lock = asyncio.Lock()
        self._timed_out = False
        self._stdout_task = asyncio.create_task(self._read_stream(process.stdout, ""))
        self._stderr_task = asyncio.create_task(self._read_stream(process.stderr, "STDERR:\n"))

    async def _read_stream(
        self,
        stream: asyncio.StreamReader | None,
        prefix: str,
    ) -> None:
        """Drain *stream* into the chunk buffer until EOF.

        *prefix*, when non-empty, is prepended once to the first text emitted
        from this stream.
        """
        if stream is None:
            return
        import codecs

        # Decode incrementally so a multi-byte UTF-8 character that straddles
        # two 4096-byte reads is not turned into replacement characters.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        pending_prefix = prefix
        while True:
            chunk = await stream.read(4096)
            # An empty read is EOF; final=True flushes any dangling bytes.
            text = decoder.decode(chunk, final=not chunk)
            if text:
                if pending_prefix:
                    text = pending_prefix + text
                    pending_prefix = ""
                async with self._lock:
                    self._chunks.append(text)
            if not chunk:
                break

    async def write(self, chars: str) -> str | None:
        """Send *chars* to the process's stdin as UTF-8.

        Returns:
            None on success, otherwise a short message explaining why the
            write was not possible.
        """
        if self.process.returncode is not None:
            return "session has already exited"
        if self.process.stdin is None:
            return "session stdin is not available"
        try:
            self.process.stdin.write(chars.encode("utf-8"))
            await self.process.stdin.drain()
        except (BrokenPipeError, ConnectionResetError):
            return "session stdin is closed"
        return None

    async def close_stdin(self) -> str | None:
        """Close the process's stdin.

        Returns:
            None on success, otherwise a short message explaining why stdin
            could not be closed.
        """
        if self.process.returncode is not None:
            return "session has already exited"
        if self.process.stdin is None:
            return "session stdin is not available"
        self.process.stdin.close()
        with suppress(BrokenPipeError, ConnectionResetError):
            await self.process.stdin.wait_closed()
        return None

    async def poll(
        self,
        yield_time_ms: int,
        max_output_chars: int,
        *,
        terminated: bool = False,
        stdin_closed: bool = False,
    ) -> _SessionPoll:
        """Wait briefly, then return and clear the buffered output.

        Waits up to *yield_time_ms* (capped at ``MAX_YIELD_MS``) but returns as
        soon as the process exits, and never waits past the session deadline.
        A process still running at the deadline is killed and reported as
        timed out. *terminated* and *stdin_closed* are echoed into the result.
        """
        self.last_access = time.monotonic()
        if yield_time_ms > 0 and self.process.returncode is None:
            wait_s = min(yield_time_ms, MAX_YIELD_MS) / 1000
            # Do not sleep through the deadline: the timeout must fire on time.
            wait_s = min(wait_s, max(0.0, self.deadline - time.monotonic()))
            if wait_s > 0:
                # Wake early if the process exits before the budget is spent.
                with suppress(asyncio.TimeoutError):
                    await asyncio.wait_for(self.process.wait(), timeout=wait_s)

        if self.process.returncode is None and time.monotonic() >= self.deadline:
            self._timed_out = True
            await self.kill()

        if self.process.returncode is not None:
            # Give the reader tasks a moment to collect the final output.
            with suppress(asyncio.TimeoutError):
                await asyncio.wait_for(
                    asyncio.gather(self._stdout_task, self._stderr_task),
                    timeout=2.0,
                )

        async with self._lock:
            output = "".join(self._chunks)
            self._chunks.clear()

        output, truncated = _truncate_output(output, max_output_chars)
        return _SessionPoll(
            output=output,
            done=self.process.returncode is not None,
            exit_code=self.process.returncode,
            elapsed_s=max(0.0, time.monotonic() - self.started_at),
            timed_out=self._timed_out,
            terminated=terminated,
            stdin_closed=stdin_closed,
            truncated_chars=truncated,
        )

    async def kill(self) -> None:
        """Kill the process if it is still running and wait up to 5s for exit."""
        if self.process.returncode is not None:
            return
        # The process may vanish between the check above and the signal.
        with suppress(ProcessLookupError):
            self.process.kill()
        with suppress(asyncio.TimeoutError):
            await asyncio.wait_for(self.process.wait(), timeout=5.0)
|
||||
|
||||
|
||||
class ExecSessionManager:
    """Registry of exec sessions with a size cap and idle eviction."""

    def __init__(self, *, max_sessions: int = 8, idle_timeout: int = 1800) -> None:
        """Create an empty registry.

        Args:
            max_sessions: Maximum number of sessions tracked at once.
            idle_timeout: Seconds without a poll after which a session is
                killed and dropped during cleanup.
        """
        self._lock = asyncio.Lock()
        self._sessions: dict[str, _ExecSession] = {}
        self.idle_timeout = idle_timeout
        self.max_sessions = max_sessions

    async def start(
        self,
        *,
        command: str,
        cwd: str,
        env: dict[str, str],
        timeout: int,
        shell_program: str | None,
        login: bool,
        yield_time_ms: int,
        max_output_chars: int,
    ) -> tuple[str, _SessionPoll]:
        """Spawn *command*, register the session and return its first poll.

        A session whose first poll already reports completion is removed from
        the registry before returning.

        Raises:
            RuntimeError: when the session cap is reached after cleanup.
        """
        async with self._lock:
            await self._cleanup_locked()
            if not len(self._sessions) < self.max_sessions:
                raise RuntimeError(f"maximum exec sessions reached ({self.max_sessions})")
            spawned = await self._spawn(command, cwd, env, shell_program, login)
            new_id = uuid.uuid4().hex[:12]
            created = _ExecSession(
                session_id=new_id,
                process=spawned,
                command=command,
                cwd=cwd,
                timeout=timeout,
            )
            self._sessions[new_id] = created

        # Poll outside the lock so other sessions stay usable while waiting.
        first_poll = await created.poll(yield_time_ms, max_output_chars)
        if first_poll.done:
            async with self._lock:
                self._sessions.pop(new_id, None)
        return new_id, first_poll

    async def write(
        self,
        *,
        session_id: str,
        chars: str | None,
        close_stdin: bool,
        terminate: bool,
        yield_time_ms: int,
        max_output_chars: int,
    ) -> _SessionPoll:
        """Optionally write, close stdin and/or kill, then poll the session.

        A session whose poll reports completion is removed from the registry.

        Raises:
            KeyError: when *session_id* is not tracked.
            RuntimeError: when writing or closing stdin is not possible.
        """
        async with self._lock:
            await self._cleanup_locked()
            target = self._sessions.get(session_id)
        if target is None:
            raise KeyError(session_id)

        if chars:
            failure = await target.write(chars)
            if failure:
                raise RuntimeError(failure)

        did_close = False
        if close_stdin:
            failure = await target.close_stdin()
            if failure:
                raise RuntimeError(failure)
            did_close = True

        if terminate:
            await target.kill()

        result = await target.poll(
            yield_time_ms,
            max_output_chars,
            terminated=terminate,
            stdin_closed=did_close,
        )
        if result.done:
            async with self._lock:
                self._sessions.pop(session_id, None)
        return result

    async def list(self) -> list[ExecSessionInfo]:
        """Return a description of every tracked session, ordered by id."""
        async with self._lock:
            await self._cleanup_locked()
            now = time.monotonic()
            infos = []
            for sid in sorted(self._sessions):
                tracked = self._sessions[sid]
                infos.append(
                    ExecSessionInfo(
                        session_id=sid,
                        command=tracked.command,
                        cwd=tracked.cwd,
                        elapsed_s=max(0.0, now - tracked.started_at),
                        idle_s=max(0.0, now - tracked.last_access),
                        remaining_s=max(0.0, tracked.deadline - now),
                        returncode=tracked.process.returncode,
                    )
                )
            return infos

    async def _cleanup_locked(self) -> None:
        """Kill and drop sessions idle for longer than ``idle_timeout``.

        The caller must hold ``self._lock``.
        """
        now = time.monotonic()
        # Collect ids first; the dict cannot be mutated while iterating it.
        expired = [
            sid
            for sid, tracked in self._sessions.items()
            if now - tracked.last_access > self.idle_timeout
        ]
        for sid in expired:
            await self._sessions.pop(sid).kill()

    async def _spawn(
        self,
        command: str,
        cwd: str,
        env: dict[str, str],
        shell_program: str | None,
        login: bool,
    ) -> asyncio.subprocess.Process:
        """Start *command* with stdin, stdout and stderr all piped.

        On Windows the command goes through the platform shell. Elsewhere it
        runs as ``<shell> [-l] -c <command>``, where ``-l`` is added only for
        bash or zsh when *login* is true.
        """
        from nanobot.agent.tools import shell

        pipe = asyncio.subprocess.PIPE
        io_options = {"stdin": pipe, "stdout": pipe, "stderr": pipe, "cwd": cwd, "env": env}

        if shell._IS_WINDOWS:
            return await asyncio.create_subprocess_shell(command, **io_options)

        program = shell_program or shutil.which("bash") or "/bin/bash"
        argv = [program]
        wants_login_flag = login and program.rsplit("/", 1)[-1] in {"bash", "zsh"}
        if wants_login_flag:
            argv.append("-l")
        argv += ["-c", command]
        return await asyncio.create_subprocess_exec(*argv, **io_options)


# Shared process-wide manager instance.
DEFAULT_EXEC_SESSION_MANAGER = ExecSessionManager()
|
||||
|
||||
|
||||
def clamp_session_int(value: int | None, default: int, minimum: int, maximum: int) -> int:
    """Return *default* for None, otherwise *value* limited to [minimum, maximum]."""
    if value is None:
        return default
    raised_to_floor = minimum if value < minimum else value
    return maximum if raised_to_floor > maximum else raised_to_floor
|
||||
|
||||
|
||||
def _truncate_output(output: str, max_output_chars: int) -> tuple[str, int]:
|
||||
if len(output) <= max_output_chars:
|
||||
return output, 0
|
||||
half = max_output_chars // 2
|
||||
omitted = len(output) - max_output_chars
|
||||
return (
|
||||
output[:half]
|
||||
+ f"\n\n... ({omitted:,} chars truncated) ...\n\n"
|
||||
+ output[-half:],
|
||||
omitted,
|
||||
)
|
||||
|
||||
|
||||
def format_session_poll(session_id: str, poll: _SessionPoll) -> str:
    """Render a poll result as the newline-joined text shown to the caller.

    Order: output, truncation note, timeout or termination note, stdin note,
    exit code (or a "still running" line with the session id), elapsed time.
    """
    report: list[str] = []
    if poll.output:
        report.append(poll.output)
    if poll.truncated_chars:
        report.append(f"(output truncated by {poll.truncated_chars:,} chars)")

    # A timeout takes precedence over a plain termination notice.
    if poll.timed_out:
        report.append("Error: Command timed out; session was terminated.")
    elif poll.terminated:
        report.append("Session terminated.")

    if poll.stdin_closed:
        report.append("Stdin closed.")

    status = (
        f"Exit code: {poll.exit_code}"
        if poll.done
        else f"Process running. session_id: {session_id}"
    )
    report.append(status)
    report.append(f"Elapsed: {poll.elapsed_s:.1f}s")
    return "\n".join(report) if report else "(no output yet)"
|
||||
|
||||
|
||||
@tool_parameters(
    tool_parameters_schema(
        session_id=StringSchema("Session id returned by exec when yield_time_ms is used."),
        chars=StringSchema(
            "Bytes/text to write to stdin. Omit or pass an empty string to only poll recent output.",
            nullable=True,
        ),
        close_stdin=BooleanSchema(
            description="Close stdin after writing chars. Useful for commands waiting for EOF.",
            default=False,
        ),
        terminate=BooleanSchema(
            description="Terminate the running exec session.",
            default=False,
        ),
        yield_time_ms=IntegerSchema(
            DEFAULT_YIELD_MS,
            description="Milliseconds to wait before returning recent output (default 1000, max 30000).",
            minimum=0,
            maximum=MAX_YIELD_MS,
        ),
        wait_for=StringSchema(
            "Optional text to wait for in output before returning. "
            "Useful for interactive commands and dev servers.",
            nullable=True,
        ),
        wait_timeout_ms=IntegerSchema(
            DEFAULT_WAIT_FOR_MS,
            description="Maximum milliseconds to wait for wait_for text (default 10000, max 120000).",
            minimum=0,
            maximum=MAX_WAIT_FOR_MS,
            nullable=True,
        ),
        max_output_chars=IntegerSchema(
            DEFAULT_MAX_OUTPUT_CHARS,
            description="Maximum output characters to return from this poll (default 10000, max 50000).",
            minimum=1000,
            maximum=MAX_OUTPUT_CHARS,
        ),
        max_output_tokens=IntegerSchema(
            DEFAULT_MAX_OUTPUT_CHARS,
            description="Compatibility alias for max_output_chars. The current runtime uses a character budget.",
            minimum=1000,
            maximum=MAX_OUTPUT_CHARS,
            nullable=True,
        ),
        required=["session_id"],
    )
)
class WriteStdinTool(Tool):
    """Write to or poll a running exec session."""

    _scopes = {"core", "subagent"}
    config_key = "exec"

    # Upper bound, in milliseconds, for a single poll while waiting for text.
    _WAIT_STEP_MS = 500

    @classmethod
    def config_cls(cls):
        """Return the config model shared with the exec tool.

        The import is deferred to call time; the shell module itself imports
        names from this module, so a top-level import would presumably be
        circular.
        """
        from nanobot.agent.tools.shell import ExecToolConfig

        return ExecToolConfig

    @classmethod
    def enabled(cls, ctx: Any) -> bool:
        """Mirror the exec tool's ``enable`` flag read from ``ctx.config.exec``."""
        return ctx.config.exec.enable

    def __init__(
        self,
        *,
        manager: ExecSessionManager | None = None,
    ) -> None:
        """Bind the tool to ``manager`` or, when omitted, the module-wide default manager."""
        self._manager = manager or DEFAULT_EXEC_SESSION_MANAGER

    @classmethod
    def create(cls, ctx: Any) -> Tool:
        """Build an instance backed by the default session manager (``ctx`` is unused)."""
        return cls()

    @property
    def exclusive(self) -> bool:
        """Always ``True``; how the flag is interpreted is up to the tool runtime."""
        return True

    @property
    def name(self) -> str:
        """Tool identifier exposed to the model."""
        return "write_stdin"

    @property
    def description(self) -> str:
        """Usage guidance exposed to the model."""
        return (
            "Interact with a running exec session created by exec with "
            "yield_time_ms. Use chars='' to poll without writing, chars to send "
            "stdin, close_stdin=true to send EOF, or terminate=true to stop the "
            "process. Use wait_for with wait_timeout_ms for dev servers, test "
            "watchers, and prompts where you need to wait for expected output. "
            "Do not use this to start new commands; start them with exec."
        )

    async def execute(
        self,
        session_id: str,
        chars: str | None = None,
        close_stdin: bool = False,
        terminate: bool = False,
        yield_time_ms: int | None = None,
        wait_for: str | None = None,
        wait_timeout_ms: int | None = None,
        max_output_chars: int | None = None,
        max_output_tokens: int | None = None,
        **kwargs: Any,
    ) -> str:
        """Write to, poll, close, or terminate an exec session and report its state.

        Args:
            session_id: Session to act on.
            chars: Text to send to stdin; ``None`` or ``""`` only polls.
            close_stdin: Close stdin after writing.
            terminate: Terminate the session.
            yield_time_ms: How long to wait for output (clamped to
                ``[0, MAX_YIELD_MS]``); ignored when ``wait_for`` is set.
            wait_for: When non-empty, keep polling until this text appears.
            wait_timeout_ms: Deadline for ``wait_for`` (clamped to
                ``[0, MAX_WAIT_FOR_MS]``).
            max_output_chars: Output budget (clamped to ``[1000, MAX_OUTPUT_CHARS]``).
            max_output_tokens: Alias used only when ``max_output_chars`` is ``None``.
            **kwargs: Ignored.

        Returns:
            The formatted poll report, or an ``Error: ...`` string. Exceptions
            are reported as text rather than raised.
        """
        try:
            if max_output_chars is None:
                max_output_chars = max_output_tokens
            output_limit = clamp_session_int(
                max_output_chars,
                DEFAULT_MAX_OUTPUT_CHARS,
                1000,
                MAX_OUTPUT_CHARS,
            )
            if wait_for:
                return await self._wait_for_output(
                    session_id=session_id,
                    chars=chars,
                    close_stdin=close_stdin,
                    terminate=terminate,
                    wait_for=wait_for,
                    wait_timeout_ms=clamp_session_int(
                        wait_timeout_ms,
                        DEFAULT_WAIT_FOR_MS,
                        0,
                        MAX_WAIT_FOR_MS,
                    ),
                    max_output_chars=output_limit,
                )
            poll = await self._manager.write(
                session_id=session_id,
                chars=chars,
                close_stdin=close_stdin,
                terminate=terminate,
                yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS),
                max_output_chars=output_limit,
            )
            return format_session_poll(session_id, poll)
        except KeyError:
            return f"Error: exec session not found: {session_id}"
        except Exception as exc:
            return f"Error writing to exec session: {exc}"

    async def _wait_for_output(
        self,
        *,
        session_id: str,
        chars: str | None,
        close_stdin: bool,
        terminate: bool,
        wait_for: str,
        wait_timeout_ms: int,
        max_output_chars: int,
    ) -> str:
        """Poll the session in short steps until ``wait_for`` shows up in its output.

        Stdin data, ``close_stdin`` and ``terminate`` are forwarded on the first
        poll only; later polls are read-only. Polling stops when the text is
        seen, the process is done, or the deadline has passed.

        The accumulated output is cut to ``max_output_chars`` before it is
        returned, and truncation reported by individual polls is added up, so a
        long wait cannot return more text than the caller asked for.

        Returns:
            The formatted report, with a ``Wait target not observed`` line
            appended when the text never appeared.
        """
        deadline = time.monotonic() + (wait_timeout_ms / 1000)
        collected = ""
        truncated_total = 0
        first = True

        while True:
            remaining_ms = max(0, int((deadline - time.monotonic()) * 1000))
            poll = await self._manager.write(
                session_id=session_id,
                chars=chars if first else None,
                close_stdin=close_stdin if first else False,
                terminate=terminate if first else False,
                yield_time_ms=min(self._WAIT_STEP_MS, remaining_ms),
                max_output_chars=max_output_chars,
            )
            first = False
            if poll.output:
                collected += poll.output
            truncated_total += poll.truncated_chars or 0

            # Decide on the full text, before it is shortened for display.
            found = wait_for in collected
            if not (found or poll.done or remaining_ms <= 0):
                continue

            poll.output, omitted = _truncate_output(collected, max_output_chars)
            poll.truncated_chars = truncated_total + omitted
            result = format_session_poll(session_id, poll)
            if not found:
                result += f"\nWait target not observed: {wait_for!r}"
            return result
|
||||
|
||||
|
||||
@tool_parameters(tool_parameters_schema())
class ListExecSessionsTool(Tool):
    """List active exec sessions."""

    _scopes = {"core", "subagent"}
    config_key = "exec"

    # Maximum length of the command preview, including the trailing "...".
    _COMMAND_PREVIEW_CHARS = 120

    @classmethod
    def config_cls(cls):
        """Return the config model shared with the exec tool (imported at call time)."""
        from nanobot.agent.tools.shell import ExecToolConfig

        return ExecToolConfig

    @classmethod
    def enabled(cls, ctx: Any) -> bool:
        """Mirror the exec tool's ``enable`` flag read from ``ctx.config.exec``."""
        return ctx.config.exec.enable

    def __init__(
        self,
        *,
        manager: ExecSessionManager | None = None,
    ) -> None:
        """Bind the tool to ``manager`` or, when omitted, the module-wide default manager."""
        self._manager = manager or DEFAULT_EXEC_SESSION_MANAGER

    @classmethod
    def create(cls, ctx: Any) -> Tool:
        """Build an instance backed by the default session manager (``ctx`` is unused)."""
        return cls()

    @property
    def name(self) -> str:
        """Tool identifier exposed to the model."""
        return "list_exec_sessions"

    @property
    def description(self) -> str:
        """Usage guidance exposed to the model."""
        return (
            "List active long-running exec sessions, including session_id, cwd, "
            "elapsed time, idle time, remaining timeout, and command preview. "
            "Use this to recover a session_id after context shifts before "
            "polling, writing stdin, or terminating with write_stdin."
        )

    @property
    def read_only(self) -> bool:
        """Always ``True``: listing sessions does not modify them."""
        return True

    async def execute(self, **kwargs: Any) -> str:
        """Return one ``|``-separated line per session known to the manager.

        Each line carries the session id, ``running``/``exited`` status, elapsed,
        idle and remaining seconds, the working directory, and a single-line
        command preview capped at ``_COMMAND_PREVIEW_CHARS`` characters.

        Returns:
            The listing, ``"No active exec sessions."`` when there are none, or
            an ``Error listing exec sessions: ...`` string on failure.
        """
        try:
            sessions = await self._manager.list()
            if not sessions:
                return "No active exec sessions."

            limit = self._COMMAND_PREVIEW_CHARS
            lines = []
            for info in sessions:
                # Collapse newlines and runs of whitespace so each session stays on one line.
                command = " ".join(info.command.split())
                if len(command) > limit:
                    # Reserve room for the ellipsis so the preview never exceeds the cap
                    # (the previous ``[:119] + "..."`` produced 122 characters).
                    command = command[: limit - 3] + "..."
                status = "exited" if info.returncode is not None else "running"
                lines.append(
                    f"{info.session_id} | {status} | elapsed={info.elapsed_s:.1f}s "
                    f"| idle={info.idle_s:.1f}s | remaining={info.remaining_s:.1f}s "
                    f"| cwd={info.cwd} | {command}"
                )
            return "\n".join(lines)
        except Exception as exc:
            return f"Error listing exec sessions: {exc}"
|
||||
@ -132,6 +132,10 @@ def _parse_page_range(pages: str, total: int) -> tuple[int, int]:
|
||||
minimum=1,
|
||||
),
|
||||
pages=StringSchema("Page range for PDF files, e.g. '1-5' (default: all, max 20 pages)"),
|
||||
force=BooleanSchema(
|
||||
description="Bypass same-file read deduplication and return content again.",
|
||||
default=False,
|
||||
),
|
||||
required=["path"],
|
||||
)
|
||||
)
|
||||
@ -154,7 +158,11 @@ class ReadFileTool(_FsTool):
|
||||
"Text output format: LINE_NUM|CONTENT. "
|
||||
"Images return visual content for analysis. "
|
||||
"Supports PDF, DOCX, XLSX, PPTX documents. "
|
||||
"Use find_files/list_dir first when the path is uncertain. "
|
||||
"Read the relevant range before editing so replacements or patches "
|
||||
"are based on current content. "
|
||||
"Use offset and limit for large text files. "
|
||||
"Use force=true to re-read content even if unchanged. "
|
||||
"Reads exceeding ~128K chars are truncated."
|
||||
)
|
||||
|
||||
@ -162,7 +170,15 @@ class ReadFileTool(_FsTool):
|
||||
def read_only(self) -> bool:
|
||||
return True
|
||||
|
||||
async def execute(self, path: str | None = None, offset: int = 1, limit: int | None = None, pages: str | None = None, **kwargs: Any) -> Any:
|
||||
async def execute(
|
||||
self,
|
||||
path: str | None = None,
|
||||
offset: int = 1,
|
||||
limit: int | None = None,
|
||||
pages: str | None = None,
|
||||
force: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
try:
|
||||
if not path:
|
||||
return "Error reading file: Unknown path"
|
||||
@ -202,7 +218,13 @@ class ReadFileTool(_FsTool):
|
||||
current_mtime = os.path.getmtime(fp)
|
||||
except OSError:
|
||||
current_mtime = 0.0
|
||||
if entry and entry.can_dedup and entry.offset == offset and entry.limit == limit:
|
||||
if (
|
||||
not force
|
||||
and entry
|
||||
and entry.can_dedup
|
||||
and entry.offset == offset
|
||||
and entry.limit == limit
|
||||
):
|
||||
if current_mtime != entry.mtime:
|
||||
# File was modified externally - force full read and mark as not dedupable
|
||||
entry.can_dedup = False
|
||||
@ -365,9 +387,10 @@ class WriteFileTool(_FsTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Write content to a file. Overwrites if the file already exists; "
|
||||
"creates parent directories as needed. "
|
||||
"For partial edits, prefer edit_file instead."
|
||||
"Create a new file or intentionally replace an entire file with "
|
||||
"the provided content. Overwrites existing files and creates parent "
|
||||
"directories as needed. For code changes or partial edits, prefer "
|
||||
"apply_patch; use edit_file only for small exact replacements."
|
||||
)
|
||||
|
||||
async def execute(self, path: str | None = None, content: str | None = None, **kwargs: Any) -> str:
|
||||
@ -657,6 +680,24 @@ def _find_match(content: str, old_text: str) -> tuple[str | None, int]:
|
||||
old_text=StringSchema("The text to find and replace"),
|
||||
new_text=StringSchema("The text to replace with"),
|
||||
replace_all=BooleanSchema(description="Replace all occurrences (default false)"),
|
||||
occurrence=IntegerSchema(
|
||||
1,
|
||||
description="Optional 1-based occurrence to replace when old_text appears multiple times.",
|
||||
minimum=1,
|
||||
nullable=True,
|
||||
),
|
||||
line_hint=IntegerSchema(
|
||||
1,
|
||||
description="Optional 1-based line hint used to choose the nearest match.",
|
||||
minimum=1,
|
||||
nullable=True,
|
||||
),
|
||||
expected_replacements=IntegerSchema(
|
||||
1,
|
||||
description="Optional guard for the number of replacements that must be made.",
|
||||
minimum=1,
|
||||
nullable=True,
|
||||
),
|
||||
required=["path", "old_text", "new_text"],
|
||||
)
|
||||
)
|
||||
@ -674,10 +715,13 @@ class EditFileTool(_FsTool):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Edit a file by replacing old_text with new_text. "
|
||||
"Tolerates minor whitespace/indentation differences and curly/straight quote mismatches. "
|
||||
"If old_text matches multiple times, you must provide more context "
|
||||
"or set replace_all=true. Shows a diff of the closest match on failure."
|
||||
"Perform a small, exact replacement in one file by replacing "
|
||||
"old_text with new_text. Use this for narrow text substitutions "
|
||||
"with old_text copied from read_file. For multi-file, structural, "
|
||||
"or generated code edits, prefer apply_patch. If old_text matches "
|
||||
"multiple times, provide more context or set occurrence, line_hint, "
|
||||
"replace_all, and expected_replacements. Shows closest-match "
|
||||
"diagnostics on failure."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -688,7 +732,8 @@ class EditFileTool(_FsTool):
|
||||
async def execute(
|
||||
self, path: str | None = None, old_text: str | None = None,
|
||||
new_text: str | None = None,
|
||||
replace_all: bool = False, **kwargs: Any,
|
||||
replace_all: bool = False, occurrence: int | None = None,
|
||||
line_hint: int | None = None, expected_replacements: int | None = None, **kwargs: Any,
|
||||
) -> str:
|
||||
try:
|
||||
if not path:
|
||||
@ -697,10 +742,12 @@ class EditFileTool(_FsTool):
|
||||
raise ValueError("Unknown old_text")
|
||||
if new_text is None:
|
||||
raise ValueError("Unknown new_text")
|
||||
|
||||
# .ipynb detection
|
||||
if path.endswith(".ipynb"):
|
||||
return "Error: This is a Jupyter notebook. Use the notebook_edit tool instead of edit_file."
|
||||
if occurrence is not None and occurrence < 1:
|
||||
return "Error: occurrence must be >= 1."
|
||||
if line_hint is not None and line_hint < 1:
|
||||
return "Error: line_hint must be >= 1."
|
||||
if expected_replacements is not None and expected_replacements < 1:
|
||||
return "Error: expected_replacements must be >= 1."
|
||||
|
||||
fp = self._resolve(path)
|
||||
|
||||
@ -743,7 +790,28 @@ class EditFileTool(_FsTool):
|
||||
if not matches:
|
||||
return self._not_found_msg(old_text, content, path)
|
||||
count = len(matches)
|
||||
if replace_all and occurrence is not None:
|
||||
return "Error: occurrence cannot be used with replace_all=true."
|
||||
if replace_all and line_hint is not None:
|
||||
return "Error: line_hint cannot be used with replace_all=true."
|
||||
if occurrence is not None and line_hint is not None:
|
||||
return "Error: line_hint cannot be used with occurrence."
|
||||
if count > 1 and not replace_all:
|
||||
if occurrence is not None:
|
||||
if occurrence > count:
|
||||
return (
|
||||
f"Error: occurrence {occurrence} is out of range; "
|
||||
f"old_text appears {count} times."
|
||||
)
|
||||
elif line_hint is not None:
|
||||
nearest = min(matches, key=lambda match: abs(match.line - line_hint))
|
||||
distance = abs(nearest.line - line_hint)
|
||||
if sum(1 for match in matches if abs(match.line - line_hint) == distance) > 1:
|
||||
return (
|
||||
f"Error: line_hint {line_hint} is ambiguous; "
|
||||
f"old_text appears {count} times."
|
||||
)
|
||||
else:
|
||||
line_numbers = [match.line for match in matches]
|
||||
preview = ", ".join(f"line {n}" for n in line_numbers[:3])
|
||||
if len(line_numbers) > 3:
|
||||
@ -751,7 +819,13 @@ class EditFileTool(_FsTool):
|
||||
location_hint = f" at {preview}" if preview else ""
|
||||
return (
|
||||
f"Warning: old_text appears {count} times{location_hint}. "
|
||||
"Provide more context to make it unique, or set replace_all=true."
|
||||
"Provide more context, set occurrence to choose one match, "
|
||||
"or set replace_all=true."
|
||||
)
|
||||
elif occurrence is not None and occurrence > count:
|
||||
return (
|
||||
f"Error: occurrence {occurrence} is out of range; "
|
||||
f"old_text appears {count} time."
|
||||
)
|
||||
|
||||
norm_new = new_text.replace("\r\n", "\n")
|
||||
@ -760,7 +834,17 @@ class EditFileTool(_FsTool):
|
||||
if fp.suffix.lower() not in self._MARKDOWN_EXTS:
|
||||
norm_new = self._strip_trailing_ws(norm_new)
|
||||
|
||||
selected = matches if replace_all else matches[:1]
|
||||
if replace_all:
|
||||
selected = matches
|
||||
elif line_hint is not None:
|
||||
selected = [min(matches, key=lambda match: abs(match.line - line_hint))]
|
||||
else:
|
||||
selected = [matches[occurrence - 1 if occurrence else 0]]
|
||||
if expected_replacements is not None and len(selected) != expected_replacements:
|
||||
return (
|
||||
f"Error: expected {expected_replacements} replacements but "
|
||||
f"would make {len(selected)}."
|
||||
)
|
||||
new_content = content
|
||||
for match in reversed(selected):
|
||||
replacement = _preserve_quote_style(norm_old, match.text, norm_new)
|
||||
|
||||
@ -1,162 +0,0 @@
|
||||
"""NotebookEditTool — edit Jupyter .ipynb notebooks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import tool_parameters
|
||||
from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema
|
||||
from nanobot.agent.tools.filesystem import _FsTool
|
||||
|
||||
|
||||
def _new_cell(source: str, cell_type: str = "code", generate_id: bool = False) -> dict:
|
||||
cell: dict[str, Any] = {
|
||||
"cell_type": cell_type,
|
||||
"source": source,
|
||||
"metadata": {},
|
||||
}
|
||||
if cell_type == "code":
|
||||
cell["outputs"] = []
|
||||
cell["execution_count"] = None
|
||||
if generate_id:
|
||||
cell["id"] = uuid.uuid4().hex[:8]
|
||||
return cell
|
||||
|
||||
|
||||
def _make_empty_notebook() -> dict:
|
||||
return {
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5,
|
||||
"metadata": {
|
||||
"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
|
||||
"language_info": {"name": "python"},
|
||||
},
|
||||
"cells": [],
|
||||
}
|
||||
|
||||
|
||||
@tool_parameters(
    tool_parameters_schema(
        path=StringSchema("Path to the .ipynb notebook file"),
        cell_index=IntegerSchema(0, description="0-based index of the cell to edit", minimum=0),
        new_source=StringSchema("New source content for the cell"),
        cell_type=StringSchema(
            "Cell type: 'code' or 'markdown' (default: code)",
            enum=["code", "markdown"],
        ),
        edit_mode=StringSchema(
            "Mode: 'replace' (default), 'insert' (after target), or 'delete'",
            enum=["replace", "insert", "delete"],
        ),
        required=["path", "cell_index"],
    )
)
class NotebookEditTool(_FsTool):
    """Edit Jupyter notebook cells: replace, insert, or delete."""

    _scopes = {"core"}

    _VALID_CELL_TYPES = frozenset({"code", "markdown"})
    _VALID_EDIT_MODES = frozenset({"replace", "insert", "delete"})

    @property
    def name(self) -> str:
        """Tool identifier exposed to the model."""
        return "notebook_edit"

    @property
    def description(self) -> str:
        """Usage guidance exposed to the model."""
        return (
            "Edit a Jupyter notebook (.ipynb) cell. "
            "Modes: replace (default) replaces cell content, "
            "insert adds a new cell after the target index, "
            "delete removes the cell at the index. "
            "cell_index is 0-based."
        )

    @staticmethod
    def _write_notebook(fp: Any, nb: dict) -> None:
        """Serialize ``nb`` as UTF-8 JSON (1-space indent, non-ASCII kept) into ``fp``."""
        fp.write_text(json.dumps(nb, indent=1, ensure_ascii=False), encoding="utf-8")

    async def execute(
        self,
        path: str | None = None,
        cell_index: int = 0,
        new_source: str = "",
        cell_type: str = "code",
        edit_mode: str = "replace",
        **kwargs: Any,
    ) -> str:
        """Apply one replace/insert/delete edit to a notebook file.

        Args:
            path: Notebook path; must end in ``.ipynb``.
            cell_index: 0-based target cell. For ``insert`` the new cell goes
                *after* this index, clamped to the valid range.
            new_source: Source for the replaced or inserted cell.
            cell_type: ``"code"`` or ``"markdown"``.
            edit_mode: ``"replace"``, ``"insert"`` or ``"delete"``.
            **kwargs: Ignored.

        Returns:
            A ``Successfully ...`` message, or an ``Error...`` string. Failures
            are reported as text rather than raised. A missing file is created
            only in ``insert`` mode.
        """
        try:
            if not path:
                return "Error: path is required"

            if not path.endswith(".ipynb"):
                return "Error: notebook_edit only works on .ipynb files. Use edit_file for other files."

            if edit_mode not in self._VALID_EDIT_MODES:
                return (
                    f"Error: Invalid edit_mode '{edit_mode}'. "
                    "Use one of: replace, insert, delete."
                )

            if cell_type not in self._VALID_CELL_TYPES:
                return (
                    f"Error: Invalid cell_type '{cell_type}'. "
                    "Use one of: code, markdown."
                )

            fp = self._resolve(path)

            # Create new notebook if file doesn't exist and mode is insert
            if not fp.exists():
                if edit_mode != "insert":
                    return f"Error: File not found: {path}"
                nb = _make_empty_notebook()
                nb["cells"].append(_new_cell(new_source, cell_type, generate_id=True))
                fp.parent.mkdir(parents=True, exist_ok=True)
                self._write_notebook(fp, nb)
                return f"Successfully created {fp} with 1 cell"

            try:
                nb = json.loads(fp.read_text(encoding="utf-8"))
            except (json.JSONDecodeError, UnicodeDecodeError) as e:
                return f"Error: Failed to parse notebook: {e}"

            cells = nb.get("cells", [])
            nbformat_minor = nb.get("nbformat_minor", 0)
            # New cells only get an id when the notebook declares nbformat >= 4.5.
            generate_id = nb.get("nbformat", 0) >= 4 and nbformat_minor >= 5

            if edit_mode == "delete":
                if cell_index < 0 or cell_index >= len(cells):
                    return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)"
                cells.pop(cell_index)
                nb["cells"] = cells
                self._write_notebook(fp, nb)
                return f"Successfully deleted cell {cell_index} from {fp}"

            if edit_mode == "insert":
                # Clamp at both ends: a negative position would make list.insert
                # count from the end and place the cell somewhere unexpected.
                insert_at = max(0, min(cell_index + 1, len(cells)))
                cells.insert(insert_at, _new_cell(new_source, cell_type, generate_id=generate_id))
                nb["cells"] = cells
                self._write_notebook(fp, nb)
                return f"Successfully inserted cell at index {insert_at} in {fp}"

            # Default: replace
            if cell_index < 0 or cell_index >= len(cells):
                return f"Error: cell_index {cell_index} out of range (notebook has {len(cells)} cells)"
            target = cells[cell_index]
            target["source"] = new_source
            # NOTE(review): cell_type defaults to "code", so replacing a markdown
            # cell without passing cell_type converts it to a code cell — confirm
            # this is intended.
            if cell_type and target.get("cell_type") != cell_type:
                target["cell_type"] = cell_type
                if cell_type == "code":
                    target.setdefault("outputs", [])
                    target.setdefault("execution_count", None)
                elif "outputs" in target:
                    # Output fields are dropped when the cell stops being code.
                    del target["outputs"]
                    target.pop("execution_count", None)
            nb["cells"] = cells
            self._write_notebook(fp, nb)
            return f"Successfully edited cell {cell_index} in {fp}"

        except PermissionError as e:
            return f"Error: {e}"
        except Exception as e:
            return f"Error editing notebook: {e}"
|
||||
@ -1,4 +1,4 @@
|
||||
"""Search tools: grep."""
|
||||
"""Search tools: file discovery and grep."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -12,6 +12,7 @@ from typing import Any, Iterable, TypeVar
|
||||
from nanobot.agent.tools.filesystem import ListDirTool, _FsTool
|
||||
|
||||
_DEFAULT_HEAD_LIMIT = 250
|
||||
_DEFAULT_FILE_HEAD_LIMIT = 200
|
||||
T = TypeVar("T")
|
||||
_TYPE_GLOB_MAP = {
|
||||
"py": ("*.py", "*.pyi"),
|
||||
@ -88,6 +89,14 @@ def _matches_type(name: str, file_type: str | None) -> bool:
|
||||
return any(fnmatch.fnmatch(name.lower(), pattern.lower()) for pattern in patterns)
|
||||
|
||||
|
||||
def _matches_query(rel_path: str, query: str | None) -> bool:
|
||||
if not query:
|
||||
return True
|
||||
haystack = rel_path.lower()
|
||||
terms = [part for part in query.lower().split() if part]
|
||||
return all(term in haystack for term in terms)
|
||||
|
||||
|
||||
class _SearchTool(_FsTool):
|
||||
_IGNORE_DIRS = set(ListDirTool._IGNORE_DIRS)
|
||||
|
||||
@ -109,6 +118,163 @@ class _SearchTool(_FsTool):
|
||||
yield current / filename
|
||||
|
||||
|
||||
class FindFilesTool(_SearchTool):
    """Locate files (and optionally directories) by path fragment, glob, or type shorthand."""

    _scopes = {"core", "subagent"}

    @property
    def name(self) -> str:
        """Tool identifier exposed to the model."""
        return "find_files"

    @property
    def description(self) -> str:
        """Usage guidance exposed to the model."""
        return (
            "Find files by path fragment, glob, or file type. "
            "Use this before read_file when you need to locate files, and "
            "prefer it over shell find/ls for ordinary workspace discovery. "
            "Returns workspace-relative paths and skips common dependency/build "
            "directories."
        )

    @property
    def read_only(self) -> bool:
        """Always ``True``: the tool only inspects the filesystem."""
        return True

    @property
    def parameters(self) -> dict[str, Any]:
        """JSON-schema description of the accepted arguments (none are required)."""
        properties: dict[str, Any] = {
            "path": {
                "type": "string",
                "description": "Directory or file to search in (default '.')",
            },
            "query": {
                "type": "string",
                "description": (
                    "Optional case-insensitive path fragment search. "
                    "Whitespace-separated terms must all be present."
                ),
            },
            "glob": {
                "type": "string",
                "description": "Optional file filter, e.g. '*.py' or 'tests/**/test_*.py'",
            },
            "type": {
                "type": "string",
                "description": "Optional file type shorthand, e.g. 'py', 'ts', 'md', 'json'",
            },
            "include_dirs": {
                "type": "boolean",
                "description": "Include matching directories as well as files (default false)",
            },
            "sort": {
                "type": "string",
                "enum": ["path", "modified"],
                "description": "Sort by path or most recently modified first (default path)",
            },
            "head_limit": {
                "type": "integer",
                "description": "Maximum number of paths to return (default 200, 0 for all, max 1000)",
                "minimum": 0,
                "maximum": 1000,
            },
            "offset": {
                "type": "integer",
                "description": "Skip the first N results before applying head_limit",
                "minimum": 0,
                "maximum": 100000,
            },
        }
        return {"type": "object", "properties": properties}

    def _iter_paths(self, root: Path, *, include_dirs: bool) -> Iterable[Path]:
        """Yield candidate paths under ``root`` in sorted walk order.

        A file ``root`` yields only itself. For a directory, entries named in
        ``_IGNORE_DIRS`` are pruned from the walk; directories (including
        ``root`` itself) are yielded only when ``include_dirs`` is true.
        """
        if root.is_file():
            yield root
            return
        if include_dirs:
            yield root
        for dirpath, subdirs, files in os.walk(root):
            # Assigning through the slice prunes the walk in place.
            subdirs[:] = sorted(entry for entry in subdirs if entry not in self._IGNORE_DIRS)
            here = Path(dirpath)
            if include_dirs and here != root:
                yield here
            yield from (here / filename for filename in sorted(files))

    async def execute(
        self,
        path: str = ".",
        query: str | None = None,
        glob: str | None = None,
        type: str | None = None,
        include_dirs: bool = False,
        sort: str = "path",
        head_limit: int | None = None,
        offset: int = 0,
        **kwargs: Any,
    ) -> str:
        """Search ``path`` and return the matching paths, one per line.

        Args:
            path: File or directory to search; falsy values mean ``"."``.
            query: Optional case-insensitive fragment filter on the displayed path.
            glob: Optional glob filter.
            type: Optional file-type shorthand; directories never match a type.
            include_dirs: Also report directories (shown with a trailing ``/``).
            sort: ``"path"`` or ``"modified"`` (newest first, ties by path).
            head_limit: ``None`` uses the default limit, ``0`` disables it.
            offset: Number of leading results to skip.
            **kwargs: Ignored.

        Returns:
            The result page plus an optional pagination note, ``"No files found"``,
            or an ``Error...`` string. Failures are reported as text, not raised.
        """
        try:
            target = self._resolve(path or ".")
            if not target.exists():
                return f"Error: Path not found: {path}"
            if not target.is_dir() and not target.is_file():
                return f"Error: Unsupported path: {path}"

            if sort not in {"path", "modified"}:
                return "Error: sort must be 'path' or 'modified'"

            if head_limit is None:
                limit = _DEFAULT_FILE_HEAD_LIMIT
            elif head_limit == 0:
                limit = None
            else:
                limit = head_limit

            # Paths are shown relative to the searched directory (or the file's parent).
            root = target if target.is_dir() else target.parent
            found: list[tuple[str, float]] = []

            for candidate in self._iter_paths(target, include_dirs=include_dirs):
                if not include_dirs and candidate.is_dir():
                    continue
                rel_path = candidate.relative_to(root).as_posix()
                display_path = self._display_path(candidate, root)
                name = candidate.name

                if glob and not _match_glob(rel_path, name, glob):
                    continue
                if candidate.is_file() and not _matches_type(name, type):
                    continue
                if type and candidate.is_dir():
                    continue
                if not _matches_query(display_path, query):
                    continue

                try:
                    mtime = candidate.stat().st_mtime
                except OSError:
                    # Unreadable entries sort as the oldest.
                    mtime = 0.0
                shown = display_path + "/" if candidate.is_dir() else display_path
                found.append((shown, mtime))

            if sort == "modified":
                ordered = sorted(found, key=lambda entry: (-entry[1], entry[0]))
            else:
                ordered = sorted(found, key=lambda entry: entry[0])

            paths = [shown for shown, _mtime in ordered]
            page, truncated = _paginate(paths, limit, offset)
            if not page:
                return "No files found"

            listing = "\n".join(page)
            note = _pagination_note(limit, offset, truncated)
            return listing + "\n\n" + note if note else listing
        except PermissionError as e:
            return f"Error: {e}"
        except Exception as e:
            return f"Error finding files: {e}"
|
||||
|
||||
|
||||
class GrepTool(_SearchTool):
|
||||
"""Search file contents using a regex-like pattern."""
|
||||
_scopes = {"core", "subagent"}
|
||||
@ -125,7 +291,8 @@ class GrepTool(_SearchTool):
|
||||
return (
|
||||
"Search file contents with a regex pattern. "
|
||||
"Default output_mode is files_with_matches (file paths only); "
|
||||
"use content mode for matching lines with context. "
|
||||
"use content mode for matching lines with context. Prefer this "
|
||||
"over shell grep for ordinary workspace searches. "
|
||||
"Skips binary and files >2 MB. Supports glob/type filtering."
|
||||
)
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ import re
|
||||
import shutil
|
||||
import sys
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@ -15,8 +16,17 @@ from loguru import logger
|
||||
from pydantic import Field
|
||||
|
||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||
from nanobot.agent.tools.exec_session import (
|
||||
DEFAULT_MAX_OUTPUT_CHARS,
|
||||
DEFAULT_YIELD_MS,
|
||||
DEFAULT_EXEC_SESSION_MANAGER,
|
||||
MAX_OUTPUT_CHARS,
|
||||
MAX_YIELD_MS,
|
||||
clamp_session_int,
|
||||
format_session_poll,
|
||||
)
|
||||
from nanobot.agent.tools.sandbox import wrap_command
|
||||
from nanobot.agent.tools.schema import IntegerSchema, StringSchema, tool_parameters_schema
|
||||
from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchema, tool_parameters_schema
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import Base
|
||||
|
||||
@ -44,10 +54,22 @@ class ExecToolConfig(Base):
|
||||
deny_patterns: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _PreparedCommand:
|
||||
command: str
|
||||
cwd: str
|
||||
env: dict[str, str]
|
||||
timeout: int
|
||||
shell_program: str | None
|
||||
login: bool
|
||||
|
||||
|
||||
@tool_parameters(
|
||||
tool_parameters_schema(
|
||||
command=StringSchema("The shell command to execute"),
|
||||
cmd=StringSchema("Compatibility alias for command"),
|
||||
working_dir=StringSchema("Optional working directory for the command"),
|
||||
workdir=StringSchema("Compatibility alias for working_dir"),
|
||||
timeout=IntegerSchema(
|
||||
60,
|
||||
description=(
|
||||
@ -57,7 +79,44 @@ class ExecToolConfig(Base):
|
||||
minimum=1,
|
||||
maximum=600,
|
||||
),
|
||||
required=["command"],
|
||||
shell=StringSchema(
|
||||
"Optional shell binary to launch. On Unix, supports sh, bash, or zsh.",
|
||||
nullable=True,
|
||||
),
|
||||
login=BooleanSchema(
|
||||
description="Whether to run bash/zsh with login shell semantics (default true).",
|
||||
default=True,
|
||||
nullable=True,
|
||||
),
|
||||
yield_time_ms=IntegerSchema(
|
||||
description=(
|
||||
"Optional milliseconds to wait before returning output. "
|
||||
"When set, a still-running command returns a session_id that "
|
||||
"can be polled or written to with write_stdin. Omit this field "
|
||||
"to keep one-shot exec behavior."
|
||||
),
|
||||
minimum=0,
|
||||
maximum=MAX_YIELD_MS,
|
||||
nullable=True,
|
||||
),
|
||||
max_output_chars=IntegerSchema(
|
||||
description=(
|
||||
"Maximum output characters to return when yield_time_ms is used "
|
||||
"(default 10000, max 50000)."
|
||||
),
|
||||
minimum=1000,
|
||||
maximum=MAX_OUTPUT_CHARS,
|
||||
nullable=True,
|
||||
),
|
||||
max_output_tokens=IntegerSchema(
|
||||
description=(
|
||||
"Compatibility alias for max_output_chars. The current runtime "
|
||||
"uses a character budget."
|
||||
),
|
||||
minimum=1000,
|
||||
maximum=MAX_OUTPUT_CHARS,
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
)
|
||||
class ExecTool(Tool):
|
||||
@ -98,6 +157,7 @@ class ExecTool(Tool):
|
||||
sandbox: str = "",
|
||||
path_append: str = "",
|
||||
allowed_env_keys: list[str] | None = None,
|
||||
session_manager: Any | None = None,
|
||||
):
|
||||
self.timeout = timeout
|
||||
self.working_dir = working_dir
|
||||
@ -125,6 +185,7 @@ class ExecTool(Tool):
|
||||
self.restrict_to_workspace = restrict_to_workspace
|
||||
self.path_append = path_append
|
||||
self.allowed_env_keys = allowed_env_keys or []
|
||||
self._session_manager = session_manager or DEFAULT_EXEC_SESSION_MANAGER
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@ -150,10 +211,15 @@ class ExecTool(Tool):
|
||||
def description(self) -> str:
    """Return the model-facing summary of when and how to use ``exec``.

    The text steers the model towards the dedicated file/search tools for
    inspection and edits, and documents the session workflow that
    ``yield_time_ms`` enables. Each statement appears exactly once so the
    prompt does not carry duplicated or run-together sentences.
    """
    return (
        "Execute a shell command and return its output. "
        "Use this for tests, builds, package commands, git commands, and "
        "other process execution. Prefer read_file/find_files/grep for "
        "inspection and apply_patch/write_file/edit_file for file changes "
        "instead of cat, shell find/grep, echo, or sed. "
        "Use -y or --yes flags to avoid interactive prompts. "
        "For long-running or interactive commands, pass yield_time_ms; "
        "if the command keeps running, exec returns a session_id that can "
        "be polled or written to with write_stdin. Output is truncated at "
        "10 000 chars; timeout defaults to 60s."
    )
|
||||
|
||||
@property
|
||||
@ -161,9 +227,111 @@ class ExecTool(Tool):
|
||||
return True
|
||||
|
||||
async def execute(
    self,
    command: str | None = None,
    cmd: str | None = None,
    working_dir: str | None = None,
    workdir: str | None = None,
    timeout: int | None = None,
    shell: str | None = None,
    login: bool | None = None,
    yield_time_ms: int | None = None,
    max_output_chars: int | None = None,
    max_output_tokens: int | None = None,
    **kwargs: Any,
) -> str:
    """Run a shell command and return its output as text.

    Args:
        command: Shell command to run.
        cmd: Alias for ``command``; used only when ``command`` is empty.
        working_dir: Directory to run in.
        workdir: Alias for ``working_dir``.
        timeout: Timeout forwarded to ``_prepare_command``.
        shell: Optional shell name or path, validated by ``_prepare_command``.
        login: Whether to request login-shell semantics.
        yield_time_ms: When not None, the command is started through the
            session path instead of being awaited to completion here.
        max_output_chars: Character budget for the returned text.
        max_output_tokens: Alias for ``max_output_chars``; consulted only
            when ``max_output_chars`` is None.
        **kwargs: Accepted and ignored.

    Returns:
        The command output (stdout, then a ``STDERR:`` section when stderr is
        non-blank, then the exit code), or a string starting with ``Error``.
        Failures are reported as text; only cancellation is re-raised.
    """
    command = command or cmd
    working_dir = working_dir or workdir
    if not command:
        return "Error: Missing command. Provide command or cmd."
    if max_output_chars is None:
        max_output_chars = max_output_tokens

    prepared = self._prepare_command(command, working_dir, timeout, shell, login)
    if isinstance(prepared, str):
        # _prepare_command signals a rejection with a ready-made error string.
        return prepared

    if yield_time_ms is not None:
        return await self._execute_session(prepared, yield_time_ms, max_output_chars)

    try:
        process = await self._spawn(
            prepared.command,
            prepared.cwd,
            prepared.env,
            prepared.shell_program,
            prepared.login,
        )

        try:
            stdout, stderr = await asyncio.wait_for(
                process.communicate(),
                timeout=prepared.timeout,
            )
        except asyncio.TimeoutError:
            await self._kill_process(process)
            return f"Error: Command timed out after {prepared.timeout} seconds"
        except asyncio.CancelledError:
            # Do not leave the child running when the surrounding task is cancelled.
            await self._kill_process(process)
            raise

        output_parts = []
        if stdout:
            output_parts.append(stdout.decode("utf-8", errors="replace"))
        if stderr:
            stderr_text = stderr.decode("utf-8", errors="replace")
            if stderr_text.strip():
                output_parts.append(f"STDERR:\n{stderr_text}")
        # The exit-code line is always present, so the joined text is never empty.
        output_parts.append(f"\nExit code: {process.returncode}")
        result = "\n".join(output_parts)

        max_len = clamp_session_int(max_output_chars, self._MAX_OUTPUT, 1000, MAX_OUTPUT_CHARS)
        if len(result) > max_len:
            # Keep the head and the tail; the middle is usually the least useful part.
            half = max_len // 2
            result = (
                result[:half]
                + f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n"
                + result[-half:]
            )

        return result

    except Exception as e:
        # Tool boundary: report the failure as text rather than raising.
        return f"Error executing command: {e}"
|
||||
|
||||
async def _execute_session(
    self,
    prepared: _PreparedCommand,
    yield_time_ms: int | None,
    max_output_chars: int | None,
) -> str:
    """Start *prepared* through the session manager and format its first poll.

    ``yield_time_ms`` and ``max_output_chars`` are clamped with
    ``clamp_session_int`` before being passed on. Any exception raised while
    clamping, starting, or formatting is returned as an error string.
    """
    try:
        wait_ms = clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS)
        output_budget = clamp_session_int(
            max_output_chars,
            DEFAULT_MAX_OUTPUT_CHARS,
            1000,
            MAX_OUTPUT_CHARS,
        )
        started = await self._session_manager.start(
            command=prepared.command,
            cwd=prepared.cwd,
            env=prepared.env,
            timeout=prepared.timeout,
            shell_program=prepared.shell_program,
            login=prepared.login,
            yield_time_ms=wait_ms,
            max_output_chars=output_budget,
        )
        session_id, poll = started
        return format_session_poll(session_id, poll)
    except Exception as exc:
        return f"Error executing command: {exc}"
|
||||
|
||||
def _prepare_command(
|
||||
self,
|
||||
command: str,
|
||||
working_dir: str | None = None,
|
||||
timeout: int | None = None,
|
||||
shell: str | None = None,
|
||||
login: bool | None = None,
|
||||
) -> _PreparedCommand | str:
|
||||
cwd = working_dir or self.working_dir or os.getcwd()
|
||||
|
||||
# Prevent an LLM-supplied working_dir from escaping the configured
|
||||
@ -211,52 +379,24 @@ class ExecTool(Tool):
|
||||
env["NANOBOT_PATH_APPEND"] = self.path_append
|
||||
command = f'export PATH="$PATH{os.pathsep}$NANOBOT_PATH_APPEND"; {command}'
|
||||
|
||||
try:
|
||||
process = await self._spawn(command, cwd, env)
|
||||
shell_program, shell_error = self._resolve_shell(shell)
|
||||
if shell_error:
|
||||
return shell_error
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(),
|
||||
return _PreparedCommand(
|
||||
command=command,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
timeout=effective_timeout,
|
||||
shell_program=shell_program,
|
||||
login=True if login is None else login,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
await self._kill_process(process)
|
||||
return f"Error: Command timed out after {effective_timeout} seconds"
|
||||
except asyncio.CancelledError:
|
||||
await self._kill_process(process)
|
||||
raise
|
||||
|
||||
output_parts = []
|
||||
|
||||
if stdout:
|
||||
output_parts.append(stdout.decode("utf-8", errors="replace"))
|
||||
|
||||
if stderr:
|
||||
stderr_text = stderr.decode("utf-8", errors="replace")
|
||||
if stderr_text.strip():
|
||||
output_parts.append(f"STDERR:\n{stderr_text}")
|
||||
|
||||
output_parts.append(f"\nExit code: {process.returncode}")
|
||||
|
||||
result = "\n".join(output_parts) if output_parts else "(no output)"
|
||||
|
||||
max_len = self._MAX_OUTPUT
|
||||
if len(result) > max_len:
|
||||
half = max_len // 2
|
||||
result = (
|
||||
result[:half]
|
||||
+ f"\n\n... ({len(result) - max_len:,} chars truncated) ...\n\n"
|
||||
+ result[-half:]
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return f"Error executing command: {str(e)}"
|
||||
|
||||
@staticmethod
|
||||
async def _spawn(
|
||||
command: str, cwd: str, env: dict[str, str],
|
||||
shell_program: str | None = None,
|
||||
login: bool = True,
|
||||
) -> asyncio.subprocess.Process:
|
||||
"""Launch *command* in a platform-appropriate shell."""
|
||||
if _IS_WINDOWS:
|
||||
@ -272,9 +412,13 @@ class ExecTool(Tool):
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
)
|
||||
bash = shutil.which("bash") or "/bin/bash"
|
||||
shell_program = shell_program or shutil.which("bash") or "/bin/bash"
|
||||
args = [shell_program]
|
||||
if login and Path(shell_program).name in {"bash", "zsh"}:
|
||||
args.append("-l")
|
||||
args.extend(["-c", command])
|
||||
return await asyncio.create_subprocess_exec(
|
||||
bash, "-l", "-c", command,
|
||||
*args,
|
||||
stdin=asyncio.subprocess.DEVNULL,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
@ -282,6 +426,31 @@ class ExecTool(Tool):
|
||||
env=env,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_shell(shell: str | None) -> tuple[str | None, str | None]:
|
||||
if not shell:
|
||||
return None, None
|
||||
if _IS_WINDOWS:
|
||||
return None, "Error: shell parameter is not supported on Windows"
|
||||
if "\0" in shell or "\n" in shell or "\r" in shell:
|
||||
return None, "Error: shell contains invalid characters"
|
||||
allowed = {"sh", "bash", "zsh"}
|
||||
path = Path(shell).expanduser()
|
||||
if path.is_absolute():
|
||||
if path.name not in allowed:
|
||||
return None, f"Error: unsupported shell {shell!r}. Allowed: bash, sh, zsh"
|
||||
if not path.is_file() or not os.access(path, os.X_OK):
|
||||
return None, f"Error: shell is not executable: {shell}"
|
||||
return str(path), None
|
||||
if "/" in shell or "\\" in shell:
|
||||
return None, "Error: shell must be a shell name or absolute path"
|
||||
if shell not in allowed:
|
||||
return None, f"Error: unsupported shell {shell!r}. Allowed: bash, sh, zsh"
|
||||
resolved = shutil.which(shell)
|
||||
if not resolved:
|
||||
return None, f"Error: shell not found: {shell}"
|
||||
return resolved, None
|
||||
|
||||
@staticmethod
|
||||
async def _kill_process(process: asyncio.subprocess.Process) -> None:
|
||||
"""Kill a subprocess and reap it to prevent zombies."""
|
||||
|
||||
@ -11,6 +11,7 @@ import secrets
|
||||
import string
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
from collections.abc import Awaitable, Callable
|
||||
from ipaddress import ip_address
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@ -463,6 +464,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
"""Strip non-standard keys, normalize tool_call IDs."""
|
||||
sanitized = LLMProvider._sanitize_request_messages(messages, _ALLOWED_MSG_KEYS)
|
||||
id_map: dict[str, str] = {}
|
||||
pending_tool_ids: dict[str, deque[str]] = {}
|
||||
force_string_content = bool(self._spec and self._spec.name == "deepseek")
|
||||
|
||||
def map_id(value: Any) -> Any:
|
||||
@ -470,15 +472,49 @@ class OpenAICompatProvider(LLMProvider):
|
||||
return value
|
||||
return id_map.setdefault(value, self._normalize_tool_call_id(value))
|
||||
|
||||
def unique_tool_id(value: Any, used_ids: set[str], idx: int) -> str:
|
||||
if isinstance(value, str) and value:
|
||||
base = map_id(value)
|
||||
else:
|
||||
base = _short_tool_id()
|
||||
if not isinstance(base, str) or not base:
|
||||
base = _short_tool_id()
|
||||
if base not in used_ids:
|
||||
return base
|
||||
seed = value if isinstance(value, str) and value else base
|
||||
salt = 1
|
||||
while True:
|
||||
candidate = self._normalize_tool_call_id(f"{seed}:{idx}:{salt}")
|
||||
if isinstance(candidate, str) and candidate not in used_ids:
|
||||
return candidate
|
||||
salt += 1
|
||||
|
||||
def map_tool_result_id(value: Any) -> Any:
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
queue = pending_tool_ids.get(value)
|
||||
if queue:
|
||||
mapped = queue.popleft()
|
||||
if not queue:
|
||||
pending_tool_ids.pop(value, None)
|
||||
return mapped
|
||||
return map_id(value)
|
||||
|
||||
for clean in sanitized:
|
||||
if isinstance(clean.get("tool_calls"), list):
|
||||
normalized = []
|
||||
for tc in clean["tool_calls"]:
|
||||
used_ids: set[str] = set()
|
||||
for idx, tc in enumerate(clean["tool_calls"]):
|
||||
if not isinstance(tc, dict):
|
||||
normalized.append(tc)
|
||||
continue
|
||||
tc_clean = dict(tc)
|
||||
tc_clean["id"] = map_id(tc_clean.get("id"))
|
||||
raw_id = tc_clean.get("id")
|
||||
mapped_id = unique_tool_id(raw_id, used_ids, idx)
|
||||
tc_clean["id"] = mapped_id
|
||||
used_ids.add(mapped_id)
|
||||
if isinstance(raw_id, str) and raw_id:
|
||||
pending_tool_ids.setdefault(raw_id, deque()).append(mapped_id)
|
||||
function = tc_clean.get("function")
|
||||
if isinstance(function, dict):
|
||||
function_clean = dict(function)
|
||||
@ -496,7 +532,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
# that mix non-empty content with tool_calls.
|
||||
clean["content"] = None
|
||||
if "tool_call_id" in clean and clean["tool_call_id"]:
|
||||
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
||||
clean["tool_call_id"] = map_tool_result_id(clean["tool_call_id"])
|
||||
if (
|
||||
force_string_content
|
||||
and not (clean.get("role") == "assistant" and clean.get("tool_calls"))
|
||||
|
||||
@ -1,5 +1,9 @@
|
||||
# Agent Instructions
|
||||
|
||||
## Workspace Guidance
|
||||
|
||||
Use this file for project-specific preferences, recurring workflow conventions, and instructions you want the agent to remember for this workspace. Keep durable facts about the user in `USER.md`, personality/style guidance in `SOUL.md`, and long-term memory in `memory/MEMORY.md`.
|
||||
|
||||
## Scheduled Reminders
|
||||
|
||||
Before scheduling reminders, check available skills and follow skill guidance first.
|
||||
@ -10,10 +14,10 @@ Get USER_ID and CHANNEL from the current session (e.g., `8281248569` and `telegr
|
||||
|
||||
## Heartbeat Tasks
|
||||
|
||||
`HEARTBEAT.md` is checked on the configured heartbeat interval. Use file tools to manage periodic tasks:
|
||||
`HEARTBEAT.md` is checked on the configured heartbeat interval. Use file tools to manage periodic tasks.
|
||||
|
||||
- **Add**: `edit_file` to append new tasks
|
||||
- **Remove**: `edit_file` to delete completed tasks
|
||||
- **Rewrite**: `write_file` to replace all tasks
|
||||
- Use `apply_patch` for normal task-list updates, especially when adding, removing, or changing multiple lines.
|
||||
- Use `edit_file` only for small exact replacements copied from the current `HEARTBEAT.md`.
|
||||
- Use `write_file` for first creation or intentional full-file rewrites.
|
||||
|
||||
When the user asks for a recurring/periodic task, update `HEARTBEAT.md` instead of creating a one-time cron reminder.
|
||||
|
||||
@ -1,28 +0,0 @@
|
||||
# Tool Usage Notes
|
||||
|
||||
Tool signatures are provided automatically via function calling.
|
||||
This file documents non-obvious constraints and usage patterns.
|
||||
|
||||
## exec — Safety Limits
|
||||
|
||||
- Commands have a configurable timeout (default 60s)
|
||||
- Dangerous commands are blocked (rm -rf, format, dd, shutdown, etc.)
|
||||
- Output is truncated at 10,000 characters
|
||||
- `restrictToWorkspace` config can limit file access to the workspace
|
||||
|
||||
## grep — Content Search
|
||||
|
||||
- Use `grep` to search file contents inside the workspace
|
||||
- Default behavior returns only matching file paths (`output_mode="files_with_matches"`)
|
||||
- Supports optional `glob` filtering (e.g. `glob="*.py"`) plus `context_before` / `context_after`
|
||||
- Supports `type="py"`, `type="ts"`, `type="md"` and similar shorthand filters
|
||||
- Use `fixed_strings=true` for literal keywords containing regex characters
|
||||
- Use `output_mode="files_with_matches"` to get only matching file paths
|
||||
- Use `output_mode="count"` to size a search before reading full matches
|
||||
- Use `head_limit` and `offset` to page across results
|
||||
- Prefer this over `exec` for code and history searches
|
||||
- Binary or oversized files may be skipped to keep results readable
|
||||
|
||||
## cron — Scheduled Reminders
|
||||
|
||||
- Please refer to cron skill for usage.
|
||||
60
nanobot/templates/agent/tool_contract.md
Normal file
60
nanobot/templates/agent/tool_contract.md
Normal file
@ -0,0 +1,60 @@
|
||||
# Tool Usage Notes
|
||||
|
||||
Tool signatures are provided automatically via function calling. This section
|
||||
documents the general tool contract and non-obvious usage patterns.
|
||||
|
||||
## General Tool Contract
|
||||
|
||||
- Use the narrowest structured tool that directly matches the task.
|
||||
- Use read-only discovery before writes when state is uncertain.
|
||||
- Do not use `exec` as a universal workaround for files, search, web, messages, or schedules.
|
||||
- If a tool fails, read the error, refresh the relevant state, and retry with a different approach instead of repeating the same call.
|
||||
- After meaningful changes, verify with the smallest reliable check: re-read changed state, run targeted tests, or inspect command output.
|
||||
- Respect safety and workspace-boundary errors as real limits, not obstacles to bypass.
|
||||
|
||||
## Discovery and Reading
|
||||
|
||||
- Use `find_files` or `list_dir` to locate workspace paths before `read_file` when a path is uncertain.
|
||||
- Use `grep` for content search inside the workspace; prefer it over shell grep for ordinary searches.
|
||||
- `grep` defaults to `output_mode="files_with_matches"`; use `output_mode="content"` for matching lines with context.
|
||||
- Use `fixed_strings=true` for literal keywords containing regex characters.
|
||||
- Use `output_mode="count"` to size a broad search before reading full matches.
|
||||
- Use `head_limit` and `offset` to page across large result sets.
|
||||
- Binary or oversized files may be skipped to keep results readable.
|
||||
|
||||
## File and Coding Workflows
|
||||
|
||||
- For code or config changes, the default loop is: locate (`find_files`/`grep`), inspect (`read_file`), edit (`apply_patch`), then verify (`exec` or re-read).
|
||||
- Use `apply_patch` as the default code editing tool, especially for multi-file changes, structural edits, generated code, moves, adds, or deletes.
|
||||
- Use `apply_patch dry_run=true` when the patch is uncertain and you want validation plus a change summary before writing.
|
||||
- Use `edit_file` only for small exact replacements in one file, with `old_text` copied from `read_file`; add `occurrence`, `line_hint`, or `expected_replacements` when ambiguity matters.
|
||||
- Use `write_file` for new files or intentional full-file rewrites, not routine partial edits.
|
||||
- If `apply_patch` or `edit_file` fails, re-read with `force=true`, narrow the context, and try a smaller patch rather than switching to shell `sed` or `echo`.
|
||||
|
||||
## Process Execution
|
||||
|
||||
- Use `exec` for tests, builds, package commands, git commands, and other process execution.
|
||||
- Prefer dedicated file/search tools over `cat`, shell `find`, shell `grep`, `sed`, or `echo` for ordinary workspace inspection and edits.
|
||||
- Use non-interactive flags such as `-y` or `--yes` when available.
|
||||
- Commands have a configurable timeout (default 60s), dangerous commands are blocked, and output is truncated.
|
||||
- For long-running or interactive commands, pass `yield_time_ms`; if the process keeps running, continue with `write_stdin`.
|
||||
- Use `write_stdin` to poll, provide stdin, close stdin, wait for expected output with `wait_for`, or terminate an existing exec session.
|
||||
- Use `list_exec_sessions` to recover active session IDs after context shifts.
|
||||
|
||||
## Web and External Information
|
||||
|
||||
- Use web tools when the user asks for current information, a specific URL, or information likely to have changed.
|
||||
- Use `web_search` to find sources and `web_fetch` for a specific page or result that needs closer reading.
|
||||
- Do not invent freshness-sensitive facts when tools can verify them.
|
||||
|
||||
## Messaging and Media
|
||||
|
||||
- Use `message` to send content or local media to the user/channel.
|
||||
- `read_file` only reads content for your analysis; it does not deliver a file to the user.
|
||||
- When sending an existing local file, attach it through the message/media mechanism instead of pasting file contents unless the user asked for text.
|
||||
|
||||
## Scheduling and Background Work
|
||||
|
||||
- Use `cron` for scheduled reminders or recurring jobs; do not run `nanobot cron` through `exec`.
|
||||
- For heartbeat tasks, update `HEARTBEAT.md` according to the agent instructions.
|
||||
- Do not write reminders only to memory files when the user expects an actual notification.
|
||||
@ -3,14 +3,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import difflib
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file", "notebook_edit"})
|
||||
TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file", "apply_patch"})
|
||||
_MAX_SNAPSHOT_BYTES = 2 * 1024 * 1024
|
||||
_LIVE_EMIT_INTERVAL_S = 0.18
|
||||
_LIVE_EMIT_LINE_STEP = 24
|
||||
@ -153,19 +152,110 @@ def prepare_file_edit_tracker(
|
||||
workspace: Path | None,
|
||||
params: dict[str, Any] | None,
|
||||
) -> FileEditTracker | None:
|
||||
trackers = prepare_file_edit_trackers(
|
||||
call_id=call_id,
|
||||
tool_name=tool_name,
|
||||
tool=tool,
|
||||
workspace=workspace,
|
||||
params=params,
|
||||
)
|
||||
return trackers[0] if trackers else None
|
||||
|
||||
|
||||
def prepare_file_edit_trackers(
    *,
    call_id: str,
    tool_name: str,
    tool: Any,
    workspace: Path | None,
    params: dict[str, Any] | None,
) -> list[FileEditTracker]:
    """Build one tracker per distinct file that a tool call is about to edit.

    Args:
        call_id: Identifier of the tool call; stored as a string, with falsy
            values stored as ``""``.
        tool_name: Name of the tool being invoked.
        tool: Tool object, forwarded to path resolution.
        workspace: Workspace root used for resolution and display paths.
        params: Arguments of the tool call.

    Returns:
        A list of trackers, each holding a snapshot taken before the edit.
        The list is empty when the tool is not a tracked edit tool or no path
        could be resolved. The result is always a list, never ``None``.
    """
    if not is_file_edit_tool(tool_name):
        return []

    trackers: list[FileEditTracker] = []
    seen: set[Path] = set()
    for path in resolve_file_edit_paths(tool_name, tool, workspace, params):
        try:
            resolved = path.resolve()
        except Exception:
            # Best effort: fall back to the unresolved path for de-duplication.
            resolved = path
        if resolved in seen:
            # The same file may be named more than once by a single call.
            continue
        seen.add(resolved)
        trackers.append(FileEditTracker(
            call_id=str(call_id or ""),
            tool=tool_name,
            path=path,
            display_path=display_file_edit_path(path, workspace),
            before=read_file_snapshot(path),
        ))
    return trackers
|
||||
|
||||
|
||||
def resolve_file_edit_paths(
    tool_name: str,
    tool: Any,
    workspace: Path | None,
    params: dict[str, Any] | None,
) -> list[Path]:
    """Return every path a file-edit tool call targets.

    ``apply_patch`` calls are delegated to the patch-aware resolver; every
    other tool yields at most one path via ``resolve_file_edit_path``.
    """
    if tool_name != "apply_patch":
        single = resolve_file_edit_path(tool, workspace, params)
        return [] if single is None else [single]
    return _resolve_apply_patch_paths(tool, workspace, params)
|
||||
|
||||
|
||||
def _resolve_apply_patch_paths(
|
||||
tool: Any,
|
||||
workspace: Path | None,
|
||||
params: dict[str, Any] | None,
|
||||
) -> list[Path]:
|
||||
if not isinstance(params, dict):
|
||||
return []
|
||||
patch = params.get("patch")
|
||||
if not isinstance(patch, str) or not patch.strip():
|
||||
return []
|
||||
if params.get("dry_run") is True:
|
||||
return []
|
||||
try:
|
||||
from nanobot.agent.tools.apply_patch import _parse_patch
|
||||
|
||||
ops = _parse_patch(patch)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
resolved: list[Path] = []
|
||||
for op in ops:
|
||||
for raw_path in (op.path, op.new_path):
|
||||
if not raw_path:
|
||||
continue
|
||||
path = _resolve_raw_file_edit_path(tool, workspace, raw_path)
|
||||
if path is not None:
|
||||
resolved.append(path)
|
||||
return resolved
|
||||
|
||||
|
||||
def _resolve_raw_file_edit_path(
|
||||
tool: Any,
|
||||
workspace: Path | None,
|
||||
raw_path: str,
|
||||
) -> Path | None:
|
||||
resolver = getattr(tool, "_resolve", None)
|
||||
if callable(resolver):
|
||||
try:
|
||||
resolved = resolver(raw_path)
|
||||
if isinstance(resolved, Path):
|
||||
return resolved
|
||||
if resolved:
|
||||
return Path(resolved)
|
||||
except Exception:
|
||||
return None
|
||||
if workspace is None:
|
||||
return Path(raw_path).expanduser().resolve()
|
||||
return (workspace / raw_path).expanduser().resolve()
|
||||
|
||||
|
||||
def build_file_edit_start_event(
|
||||
@ -303,6 +393,9 @@ class StreamingFileEditTracker:
|
||||
self._states[key] = state
|
||||
|
||||
state.apply_delta(payload)
|
||||
if state.name == "apply_patch":
|
||||
await self._update_apply_patch(state)
|
||||
return
|
||||
if state.name not in {"write_file", "edit_file"}:
|
||||
return
|
||||
if state.path is None:
|
||||
@ -342,10 +435,62 @@ class StreamingFileEditTracker:
|
||||
deleted=deleted,
|
||||
)])
|
||||
|
||||
async def _update_apply_patch(self, state: _StreamingFileEditState) -> None:
    """Emit live line-count events for a partially streamed ``apply_patch`` call.

    Nothing is emitted for dry runs or while no patch text has arrived. For
    each file named so far in the patch, a per-file state (with a snapshot
    of the file taken on first sight) decides whether the current counts are
    worth emitting; all resulting events are sent in a single batch.
    """
    if _json_bool_true(state.arguments, "dry_run"):
        return
    patch_so_far = _extract_json_string_prefix(state.arguments, "patch")
    if not patch_so_far:
        return

    patch_tool = self._tools.get("apply_patch") if hasattr(self._tools, "get") else None
    pending: list[dict[str, Any]] = []
    stamp = time.monotonic()

    for raw_path, added, deleted, delete_file in _streaming_apply_patch_stats(patch_so_far):
        target = _resolve_raw_file_edit_path(patch_tool, self._workspace, raw_path)
        if target is None:
            continue

        per_file = state.patch_files.get(raw_path)
        if per_file is None:
            per_file = _StreamingPatchFileState(
                tracker=FileEditTracker(
                    call_id=state.call_id or state.key,
                    tool="apply_patch",
                    path=target,
                    display_path=display_file_edit_path(target, self._workspace),
                    before=read_file_snapshot(target),
                )
            )
            state.patch_files[raw_path] = per_file

        snapshot = per_file.tracker.before
        if delete_file and added == 0 and deleted == 0 and snapshot.countable:
            # A delete operation carries no "-" lines; report the line count
            # of the snapshot instead.
            deleted = _text_line_count(snapshot.text or "")

        if per_file.should_emit(added, deleted, stamp):
            per_file.mark_emitted(added, deleted, stamp)
            pending.append(build_file_edit_live_event(
                per_file.tracker,
                added=added,
                deleted=deleted,
            ))

    if pending:
        await self._emit(pending)
|
||||
|
||||
async def flush(self) -> None:
|
||||
events: list[dict[str, Any]] = []
|
||||
now = time.monotonic()
|
||||
for state in self._states.values():
|
||||
for file_state in state.patch_files.values():
|
||||
added, deleted = file_state.last_added, file_state.last_deleted
|
||||
if not file_state.emitted_once:
|
||||
continue
|
||||
if (
|
||||
file_state.last_emitted_added == added
|
||||
and file_state.last_emitted_deleted == deleted
|
||||
):
|
||||
continue
|
||||
file_state.mark_emitted(added, deleted, now)
|
||||
events.append(build_file_edit_live_event(
|
||||
file_state.tracker,
|
||||
added=added,
|
||||
deleted=deleted,
|
||||
))
|
||||
if state.tracker is None:
|
||||
continue
|
||||
added, deleted = state.live_diff_counts()
|
||||
@ -390,6 +535,10 @@ class StreamingFileEditTracker:
|
||||
"""Mark streamed edits as failed when no final tool call will run."""
|
||||
events: list[dict[str, Any]] = []
|
||||
for state in self._states.values():
|
||||
for file_state in state.patch_files.values():
|
||||
if any(state.matches_final_tool_call(tool_call) for tool_call in final_tool_calls):
|
||||
continue
|
||||
events.append(build_file_edit_error_event(file_state.tracker, error))
|
||||
if state.tracker is None:
|
||||
continue
|
||||
if any(state.matches_final_tool_call(tool_call) for tool_call in final_tool_calls):
|
||||
@ -493,6 +642,39 @@ class _StreamingJsonStringField:
|
||||
self.last_char_cr = False
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _StreamingPatchFileState:
|
||||
tracker: FileEditTracker
|
||||
emitted_once: bool = False
|
||||
last_emitted_added: int = -1
|
||||
last_emitted_deleted: int = -1
|
||||
last_emit_at: float = 0.0
|
||||
last_added: int = 0
|
||||
last_deleted: int = 0
|
||||
|
||||
def should_emit(self, added: int, deleted: int, now: float) -> bool:
|
||||
self.last_added = added
|
||||
self.last_deleted = deleted
|
||||
if not self.emitted_once:
|
||||
return True
|
||||
if added == self.last_emitted_added and deleted == self.last_emitted_deleted:
|
||||
return False
|
||||
if max(
|
||||
abs(added - self.last_emitted_added),
|
||||
abs(deleted - self.last_emitted_deleted),
|
||||
) >= _LIVE_EMIT_LINE_STEP:
|
||||
return True
|
||||
return now - self.last_emit_at >= _LIVE_EMIT_INTERVAL_S
|
||||
|
||||
def mark_emitted(self, added: int, deleted: int, now: float) -> None:
|
||||
self.emitted_once = True
|
||||
self.last_added = added
|
||||
self.last_deleted = deleted
|
||||
self.last_emitted_added = added
|
||||
self.last_emitted_deleted = deleted
|
||||
self.last_emit_at = now
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _StreamingFileEditState:
|
||||
key: str
|
||||
@ -510,6 +692,7 @@ class _StreamingFileEditState:
|
||||
new_text: _StreamingJsonStringField = field(
|
||||
default_factory=lambda: _StreamingJsonStringField("new_text")
|
||||
)
|
||||
patch_files: dict[str, _StreamingPatchFileState] = field(default_factory=dict)
|
||||
emitted_once: bool = False
|
||||
last_emitted_added: int = -1
|
||||
last_emitted_deleted: int = -1
|
||||
@ -532,6 +715,7 @@ class _StreamingFileEditState:
|
||||
self.content.reset()
|
||||
self.old_text.reset()
|
||||
self.new_text.reset()
|
||||
self.patch_files.clear()
|
||||
return
|
||||
delta = payload.get("arguments_delta")
|
||||
if isinstance(delta, str) and delta:
|
||||
@ -591,6 +775,13 @@ class _StreamingFileEditState:
|
||||
name = getattr(tool_call, "name", None)
|
||||
if name != self.name:
|
||||
return False
|
||||
if self.name == "apply_patch":
|
||||
arguments = getattr(tool_call, "arguments", None)
|
||||
if not isinstance(arguments, dict):
|
||||
return False
|
||||
patch = arguments.get("patch")
|
||||
streamed_patch = _extract_complete_json_string(self.arguments, "patch")
|
||||
return isinstance(patch, str) and streamed_patch == patch
|
||||
arguments = getattr(tool_call, "arguments", None)
|
||||
if not isinstance(arguments, dict):
|
||||
return False
|
||||
@ -613,6 +804,110 @@ def _stream_key(payload: dict[str, Any]) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _json_bool_true(source: str, key: str) -> bool:
|
||||
return re.search(rf'"{re.escape(key)}"\s*:\s*true\b', source) is not None
|
||||
|
||||
|
||||
def _extract_json_string_prefix(source: str, key: str) -> str | None:
|
||||
match = re.search(rf'"{re.escape(key)}"\s*:\s*"', source)
|
||||
if match is None:
|
||||
return None
|
||||
out: list[str] = []
|
||||
i = match.end()
|
||||
escape = False
|
||||
while i < len(source):
|
||||
ch = source[i]
|
||||
if escape:
|
||||
escape = False
|
||||
if ch == "n":
|
||||
out.append("\n")
|
||||
elif ch == "r":
|
||||
out.append("\r")
|
||||
elif ch == "t":
|
||||
out.append("\t")
|
||||
elif ch == "u":
|
||||
digits = source[i + 1:i + 5]
|
||||
if len(digits) < 4:
|
||||
break
|
||||
try:
|
||||
out.append(chr(int(digits, 16)))
|
||||
except ValueError:
|
||||
break
|
||||
i += 4
|
||||
else:
|
||||
out.append(ch)
|
||||
i += 1
|
||||
continue
|
||||
if ch == "\\":
|
||||
escape = True
|
||||
i += 1
|
||||
continue
|
||||
if ch == '"':
|
||||
return "".join(out)
|
||||
out.append(ch)
|
||||
i += 1
|
||||
return "".join(out)
|
||||
|
||||
|
||||
def _streaming_apply_patch_stats(patch: str) -> list[tuple[str, int, int, bool]]:
|
||||
stats: dict[str, list[Any]] = {}
|
||||
order: list[str] = []
|
||||
current: str | None = None
|
||||
|
||||
def ensure(path: str, *, delete_file: bool = False) -> list[Any]:
|
||||
if path not in stats:
|
||||
stats[path] = [0, 0, False]
|
||||
order.append(path)
|
||||
if delete_file:
|
||||
stats[path][2] = True
|
||||
return stats[path]
|
||||
|
||||
lines = patch.splitlines()
|
||||
tail = ""
|
||||
if patch and not patch.endswith(("\n", "\r")) and lines:
|
||||
tail = lines.pop()
|
||||
|
||||
for line in lines:
|
||||
if line.startswith("*** Add File: "):
|
||||
current = line[len("*** Add File: "):].strip()
|
||||
if current:
|
||||
ensure(current)
|
||||
continue
|
||||
if line.startswith("*** Update File: "):
|
||||
current = line[len("*** Update File: "):].strip()
|
||||
if current:
|
||||
ensure(current)
|
||||
continue
|
||||
if line.startswith("*** Delete File: "):
|
||||
current = line[len("*** Delete File: "):].strip()
|
||||
if current:
|
||||
ensure(current, delete_file=True)
|
||||
continue
|
||||
if line.startswith("*** Move to: "):
|
||||
moved = line[len("*** Move to: "):].strip()
|
||||
if moved:
|
||||
current = moved
|
||||
ensure(current)
|
||||
continue
|
||||
if line.startswith("*** "):
|
||||
current = None
|
||||
continue
|
||||
if not current:
|
||||
continue
|
||||
if line.startswith("+") and not line.startswith("+++"):
|
||||
ensure(current)[0] += 1
|
||||
elif line.startswith("-") and not line.startswith("---"):
|
||||
ensure(current)[1] += 1
|
||||
|
||||
if current and tail:
|
||||
if tail.startswith("+") and not tail.startswith("+++"):
|
||||
ensure(current)[0] += 1
|
||||
elif tail.startswith("-") and not tail.startswith("---"):
|
||||
ensure(current)[1] += 1
|
||||
|
||||
return [(path, int(stats[path][0]), int(stats[path][1]), bool(stats[path][2])) for path in order]
|
||||
|
||||
|
||||
def _extract_complete_json_string(source: str, key: str) -> str | None:
|
||||
match = re.search(rf'"{re.escape(key)}"\s*:\s*"', source)
|
||||
if match is None:
|
||||
@ -705,77 +1000,4 @@ def _predict_after_text(
|
||||
return before_text.replace(old_text, new_text)
|
||||
return before_text.replace(old_text, new_text, 1)
|
||||
return None
|
||||
if tool_name == "notebook_edit":
|
||||
return _predict_notebook_after_text(params, before_text)
|
||||
return None
|
||||
|
||||
|
||||
def _predict_notebook_after_text(params: dict[str, Any], before_text: str) -> str | None:
|
||||
try:
|
||||
nb = json.loads(before_text) if before_text.strip() else _empty_notebook()
|
||||
except Exception:
|
||||
return None
|
||||
cells = nb.get("cells")
|
||||
if not isinstance(cells, list):
|
||||
return None
|
||||
try:
|
||||
cell_index = int(params.get("cell_index", 0))
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
new_source = params.get("new_source")
|
||||
source = new_source if isinstance(new_source, str) else ""
|
||||
cell_type = (
|
||||
params.get("cell_type") if params.get("cell_type") in ("code", "markdown") else "code"
|
||||
)
|
||||
mode = (
|
||||
params.get("edit_mode")
|
||||
if params.get("edit_mode") in ("replace", "insert", "delete")
|
||||
else "replace"
|
||||
)
|
||||
if mode == "delete":
|
||||
if 0 <= cell_index < len(cells):
|
||||
cells.pop(cell_index)
|
||||
else:
|
||||
return None
|
||||
elif mode == "insert":
|
||||
insert_at = min(max(cell_index + 1, 0), len(cells))
|
||||
cells.insert(insert_at, _new_notebook_cell(source, str(cell_type)))
|
||||
else:
|
||||
if not (0 <= cell_index < len(cells)):
|
||||
return None
|
||||
cell = cells[cell_index]
|
||||
if not isinstance(cell, dict):
|
||||
return None
|
||||
cell["source"] = source
|
||||
cell["cell_type"] = cell_type
|
||||
if cell_type == "code":
|
||||
cell.setdefault("outputs", [])
|
||||
cell.setdefault("execution_count", None)
|
||||
else:
|
||||
cell.pop("outputs", None)
|
||||
cell.pop("execution_count", None)
|
||||
nb["cells"] = cells
|
||||
try:
|
||||
return json.dumps(nb, indent=1, ensure_ascii=False)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _empty_notebook() -> dict[str, Any]:
|
||||
return {
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5,
|
||||
"metadata": {
|
||||
"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
|
||||
"language_info": {"name": "python"},
|
||||
},
|
||||
"cells": [],
|
||||
}
|
||||
|
||||
|
||||
def _new_notebook_cell(source: str, cell_type: str) -> dict[str, Any]:
|
||||
cell: dict[str, Any] = {"cell_type": cell_type, "source": source, "metadata": {}}
|
||||
if cell_type == "code":
|
||||
cell["outputs"] = []
|
||||
cell["execution_count"] = None
|
||||
return cell
|
||||
|
||||
@ -576,7 +576,7 @@ def build_status_content(
|
||||
|
||||
|
||||
def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]:
|
||||
"""Sync bundled templates to workspace. Only creates missing files."""
|
||||
"""Sync bundled templates to workspace. Creates missing files without overwriting user files."""
|
||||
from importlib.resources import files as pkg_files
|
||||
|
||||
try:
|
||||
@ -589,10 +589,11 @@ def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str]
|
||||
added: list[str] = []
|
||||
|
||||
def _write(src, dest: Path):
|
||||
content = src.read_text(encoding="utf-8") if src else ""
|
||||
if dest.exists():
|
||||
return
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
dest.write_text(src.read_text(encoding="utf-8") if src else "", encoding="utf-8")
|
||||
dest.write_text(content, encoding="utf-8")
|
||||
added.append(str(dest.relative_to(workspace)))
|
||||
|
||||
for item in tpl.iterdir():
|
||||
|
||||
@ -11,8 +11,10 @@ _TOOL_FORMATS: dict[str, tuple[list[str], str, bool, bool]] = {
|
||||
"read_file": (["path", "file_path"], "read {}", True, False),
|
||||
"write_file": (["path", "file_path"], "write {}", True, False),
|
||||
"edit": (["file_path", "path"], "edit {}", True, False),
|
||||
"find_files": (["query", "glob", "path"], "find {}", False, False),
|
||||
"grep": (["pattern"], 'grep "{}"', False, False),
|
||||
"exec": (["command"], "$ {}", False, True),
|
||||
"list_exec_sessions": ([], "exec sessions", False, False),
|
||||
"web_search": (["query"], 'search "{}"', False, False),
|
||||
"web_fetch": (["url"], "fetch {}", True, False),
|
||||
"list_dir": (["path"], "ls {}", True, False),
|
||||
@ -81,6 +83,8 @@ def _extract_arg(tc, key_args: list[str]) -> str | None:
|
||||
|
||||
def _fmt_known(tc, fmt: tuple, max_length: int = 40) -> str:
|
||||
"""Format a registered tool using its template."""
|
||||
if not fmt[0] and "{}" not in fmt[1]:
|
||||
return fmt[1]
|
||||
val = _extract_arg(tc, fmt[0])
|
||||
if val is None:
|
||||
return tc.name
|
||||
|
||||
@ -139,6 +139,13 @@ class TestLoadBootstrapFiles:
|
||||
for name in ContextBuilder.BOOTSTRAP_FILES:
|
||||
assert f"## {name}" in result
|
||||
|
||||
def test_legacy_tools_md_is_not_bootstrapped(self, tmp_path):
    """A workspace TOOLS.md must not show up in the loaded bootstrap text."""
    legacy_note = "workspace tool notes"
    (tmp_path / "TOOLS.md").write_text(legacy_note, encoding="utf-8")

    loaded = _builder(tmp_path)._load_bootstrap_files()

    for unwanted in ("TOOLS.md", legacy_note):
        assert unwanted not in loaded
|
||||
|
||||
def test_utf8_content(self, tmp_path):
|
||||
(tmp_path / "AGENTS.md").write_text("用中文回复", encoding="utf-8")
|
||||
builder = _builder(tmp_path)
|
||||
@ -171,6 +178,37 @@ class TestIsTemplateContent:
|
||||
assert ContextBuilder._is_template_content("totally different", "memory/MEMORY.md") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bundled bootstrap templates
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBundledToolContract:
    """Checks on the bundled ``agent/tool_contract.md`` template and its injection."""

    def test_tool_contract_balances_general_and_coding_workflows(self):
        """The template contains every expected section and key phrase."""
        from importlib.resources import files as pkg_files

        template = pkg_files("nanobot") / "templates" / "agent" / "tool_contract.md"
        text = template.read_text(encoding="utf-8")

        expected_fragments = (
            "## General Tool Contract",
            "Use the narrowest structured tool",
            "Do not use `exec` as a universal workaround",
            "## File and Coding Workflows",
            "apply_patch",
            "## Web and External Information",
            "## Messaging and Media",
            "## Scheduling and Background Work",
        )
        for fragment in expected_fragments:
            assert fragment in text
        assert "pure coding" not in text.lower()

    def test_tool_contract_is_injected_without_workspace_file(self, tmp_path):
        """The system prompt carries the contract for a workspace the test wrote nothing into."""
        system_prompt = _builder(tmp_path).build_system_prompt()

        for fragment in (
            "# Tool Usage Notes",
            "## General Tool Contract",
            "Do not use `exec` as a universal workaround",
        ):
            assert fragment in system_prompt
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_user_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -346,6 +346,26 @@ class TestSyncWorkspaceTemplates:
|
||||
content = (workspace / "AGENTS.md").read_text()
|
||||
assert content == "existing content"
|
||||
|
||||
def test_does_not_create_tools_md(self, tmp_path):
    """Tool contract is injected internally, not copied into user workspaces."""
    target_dir = tmp_path / "workspace"

    created = sync_workspace_templates(target_dir, silent=True)

    assert "TOOLS.md" not in created
    tools_file = target_dir / "TOOLS.md"
    assert not tools_file.exists()
|
||||
|
||||
def test_preserves_existing_tools_md_without_overwriting(self, tmp_path):
    """Legacy user workspaces may have TOOLS.md; sync should leave it untouched."""
    original_notes = "custom tool notes"
    target_dir = tmp_path / "workspace"
    target_dir.mkdir(parents=True)
    legacy_file = target_dir / "TOOLS.md"
    legacy_file.write_text(original_notes, encoding="utf-8")

    sync_workspace_templates(target_dir, silent=True)

    assert legacy_file.read_text(encoding="utf-8") == original_notes
|
||||
|
||||
def test_creates_memory_directory(self, tmp_path):
|
||||
"""Should create memory directory structure."""
|
||||
workspace = tmp_path / "workspace"
|
||||
|
||||
@ -1007,6 +1007,41 @@ def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -
|
||||
assert sanitized[2]["tool_call_id"] == "3ec83c30d"
|
||||
|
||||
|
||||
def test_openai_compat_deduplicates_duplicate_tool_call_ids_in_history() -> None:
    """Two tool calls sharing one id come out unique, and the tool results match them."""
    shared_id = "ab1b45c2a"

    def read_call(arguments: str) -> dict:
        return {
            "id": shared_id,
            "type": "function",
            "function": {"name": "read_file", "arguments": arguments},
        }

    def read_result(content: str) -> dict:
        return {"role": "tool", "tool_call_id": shared_id, "name": "read_file", "content": content}

    history = [
        {"role": "user", "content": "check both files"},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [read_call('{"path":"a.txt"}'), read_call('{"path":"b.txt"}')],
        },
        read_result("a"),
        read_result("b"),
        {"role": "user", "content": "continue"},
    ]

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        provider = OpenAICompatProvider()

    sanitized = provider._sanitize_messages(history)

    call_ids = [call["id"] for call in sanitized[1]["tool_calls"]]
    result_ids = [sanitized[position]["tool_call_id"] for position in (2, 3)]

    assert call_ids[0] == shared_id
    assert len(call_ids) == len(set(call_ids)) == 2
    assert result_ids == call_ids
|
||||
|
||||
|
||||
def test_openai_compat_stringifies_dict_tool_arguments() -> None:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
287
tests/tools/test_apply_patch_tool.py
Normal file
287
tests/tools/test_apply_patch_tool.py
Normal file
@ -0,0 +1,287 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from nanobot.agent.tools.apply_patch import ApplyPatchTool
|
||||
|
||||
|
||||
def test_apply_patch_adds_file(tmp_path):
    """An ``Add File`` section creates the file with the ``+`` lines as content."""
    patch_text = """*** Begin Patch
*** Add File: hello.txt
+Hello
+world
*** End Patch
"""
    patcher = ApplyPatchTool(workspace=tmp_path)

    report = asyncio.run(patcher.execute(patch=patch_text))

    assert "Patch applied" in report
    created = tmp_path / "hello.txt"
    assert created.read_text() == "Hello\nworld\n"
|
||||
|
||||
|
||||
def test_apply_patch_updates_multiple_hunks(tmp_path):
    """Two ``@@`` hunks in one update section are both applied and reported as +2/-2."""
    patch_text = """*** Begin Patch
*** Update File: multi.txt
@@
-line2
+changed2
@@
-line4
+changed4
*** End Patch
"""
    edited = tmp_path / "multi.txt"
    edited.write_text("line1\nline2\nline3\nline4\n")
    patcher = ApplyPatchTool(workspace=tmp_path)

    report = asyncio.run(patcher.execute(patch=patch_text))

    assert "update multi.txt" in report
    assert "(+2/-2)" in report
    assert edited.read_text() == "line1\nchanged2\nline3\nchanged4\n"
|
||||
|
||||
|
||||
def test_apply_patch_dry_run_validates_without_writing(tmp_path):
    """``dry_run=True`` reports the planned operations but leaves the workspace unchanged."""
    patch_text = """*** Begin Patch
*** Update File: dry.txt
@@
-before
+after
*** Add File: added.txt
+new
*** End Patch
"""
    existing = tmp_path / "dry.txt"
    existing.write_text("before\n")
    patcher = ApplyPatchTool(workspace=tmp_path)

    report = asyncio.run(patcher.execute(patch=patch_text, dry_run=True))

    for expected in (
        "Patch dry-run succeeded",
        "- update dry.txt (+1/-1)",
        "- add added.txt (+1/-0)",
    ):
        assert expected in report
    assert existing.read_text() == "before\n"
    assert not (tmp_path / "added.txt").exists()
|
||||
|
||||
|
||||
def test_apply_patch_applies_repeated_update_sections_sequentially(tmp_path):
    """Two update sections for the same file are both applied and both reported."""
    patch_text = """*** Begin Patch
*** Update File: repeat.txt
@@
-one
+ONE
*** Update File: repeat.txt
@@
-three
+THREE
*** End Patch
"""
    edited = tmp_path / "repeat.txt"
    edited.write_text("one\ntwo\nthree\n")
    patcher = ApplyPatchTool(workspace=tmp_path)

    report = asyncio.run(patcher.execute(patch=patch_text))

    assert report.count("update repeat.txt") == 2
    assert edited.read_text() == "ONE\ntwo\nTHREE\n"
|
||||
|
||||
|
||||
def test_apply_patch_ignores_standard_no_newline_marker(tmp_path):
    """``\\ No newline at end of file`` marker lines do not stop the hunk from applying."""
    patch_text = """*** Begin Patch
*** Update File: plain.txt
@@ -1,1 +1,1 @@
-before
\\ No newline at end of file
+after
\\ No newline at end of file
*** End Patch
"""
    edited = tmp_path / "plain.txt"
    edited.write_text("before")
    patcher = ApplyPatchTool(workspace=tmp_path)

    report = asyncio.run(patcher.execute(patch=patch_text))

    assert "update plain.txt" in report
    assert edited.read_text() == "after\n"
|
||||
|
||||
|
||||
def test_apply_patch_rejects_empty_hunk(tmp_path):
    """A bare ``@@`` with no body is reported as an empty hunk and the file is left alone."""
    patch_text = """*** Begin Patch
*** Update File: plain.txt
@@
*** End Patch
"""
    untouched = tmp_path / "plain.txt"
    untouched.write_text("before\n")
    patcher = ApplyPatchTool(workspace=tmp_path)

    report = asyncio.run(patcher.execute(patch=patch_text))

    assert "hunk is empty" in report
    assert untouched.read_text() == "before\n"
|
||||
|
||||
|
||||
def test_apply_patch_uses_unified_diff_line_hint(tmp_path):
    """With ``@@ -3,1 +3,1 @@`` the second of two identical lines is the one replaced."""
    patch_text = """*** Begin Patch
*** Update File: repeated.txt
@@ -3,1 +3,1 @@
-target
+changed
*** End Patch
"""
    edited = tmp_path / "repeated.txt"
    edited.write_text("target\nmiddle\ntarget\n")
    patcher = ApplyPatchTool(workspace=tmp_path)

    report = asyncio.run(patcher.execute(patch=patch_text))

    assert "update repeated.txt" in report
    assert edited.read_text() == "target\nmiddle\nchanged\n"
|
||||
|
||||
|
||||
def test_apply_patch_line_hint_does_not_fallback_to_earlier_match(tmp_path):
    """A hunk hinted at line 3 is rejected even though line 1 has the same text."""
    patch_text = """*** Begin Patch
*** Update File: repeated.txt
@@ -3,1 +3,1 @@
-target
+changed
*** End Patch
"""
    original = "target\nmiddle\nother\n"
    untouched = tmp_path / "repeated.txt"
    untouched.write_text(original)
    patcher = ApplyPatchTool(workspace=tmp_path)

    report = asyncio.run(patcher.execute(patch=patch_text))

    assert "hunk does not match repeated.txt" in report
    assert untouched.read_text() == original
|
||||
|
||||
|
||||
def test_apply_patch_mismatch_reports_best_match(tmp_path):
    """A near-miss hunk is rejected with a "Best match" hint naming line 2."""
    patch_text = """*** Begin Patch
*** Update File: near.txt
@@ -2,1 +2,1 @@
-betx
+delta
*** End Patch
"""
    original = "alpha\nbeta\ngamma\n"
    untouched = tmp_path / "near.txt"
    untouched.write_text(original)
    patcher = ApplyPatchTool(workspace=tmp_path)

    report = asyncio.run(patcher.execute(patch=patch_text))

    for expected in ("hunk does not match near.txt", "Best match", "line 2"):
        assert expected in report
    assert untouched.read_text() == original
|
||||
|
||||
|
||||
def test_apply_patch_moves_and_updates_file(tmp_path):
    """``Move to`` combined with a hunk relocates the file and applies the edit."""
    patch_text = """*** Begin Patch
*** Update File: old/name.txt
*** Move to: renamed/dir/name.txt
@@
-old content
+new content
*** End Patch
"""
    old_location = tmp_path / "old" / "name.txt"
    old_location.parent.mkdir()
    old_location.write_text("old content\n")
    patcher = ApplyPatchTool(workspace=tmp_path)

    report = asyncio.run(patcher.execute(patch=patch_text))

    new_location = tmp_path / "renamed" / "dir" / "name.txt"
    assert "move old/name.txt -> renamed/dir/name.txt" in report
    assert not old_location.exists()
    assert new_location.read_text() == "new content\n"
|
||||
|
||||
|
||||
def test_apply_patch_deletes_file(tmp_path):
    """A ``Delete File`` section removes the named file."""
    patch_text = """*** Begin Patch
*** Delete File: obsolete.txt
*** End Patch
"""
    victim = tmp_path / "obsolete.txt"
    victim.write_text("remove me\n")
    patcher = ApplyPatchTool(workspace=tmp_path)

    report = asyncio.run(patcher.execute(patch=patch_text))

    assert "delete obsolete.txt" in report
    assert not victim.exists()
|
||||
|
||||
|
||||
def test_apply_patch_rejects_absolute_and_parent_paths(tmp_path):
    """Absolute paths and ``..`` paths are each rejected with their own message."""
    absolute_patch = """*** Begin Patch
*** Add File: /tmp/owned.txt
+nope
*** End Patch
"""
    parent_patch = """*** Begin Patch
*** Add File: ../owned.txt
+nope
*** End Patch
"""
    patcher = ApplyPatchTool(workspace=tmp_path)

    absolute_report = asyncio.run(patcher.execute(patch=absolute_patch))
    parent_report = asyncio.run(patcher.execute(patch=parent_patch))

    assert "must be relative" in absolute_report
    assert "must not contain '..'" in parent_report
    escaped = tmp_path.parent / "owned.txt"
    assert not escaped.exists()
|
||||
|
||||
|
||||
def test_apply_patch_does_not_overwrite_existing_file_with_add(tmp_path):
    """``Add File`` on an existing path is refused and the file keeps its content."""
    patch_text = """*** Begin Patch
*** Add File: existing.txt
+replace me
*** End Patch
"""
    original = "keep me\n"
    occupied = tmp_path / "existing.txt"
    occupied.write_text(original)
    patcher = ApplyPatchTool(workspace=tmp_path)

    report = asyncio.run(patcher.execute(patch=patch_text))

    assert "file to add already exists" in report
    assert occupied.read_text() == original
|
||||
|
||||
|
||||
def test_apply_patch_rolls_back_when_late_operation_fails(tmp_path):
    """When a later section fails, the file touched by an earlier section is unchanged."""
    patch_text = """*** Begin Patch
*** Update File: first.txt
@@
-before
+after
*** Delete File: missing.txt
*** End Patch
"""
    original = "before\n"
    earlier_target = tmp_path / "first.txt"
    earlier_target.write_text(original)
    patcher = ApplyPatchTool(workspace=tmp_path)

    report = asyncio.run(patcher.execute(patch=patch_text))

    assert "file to delete does not exist" in report
    assert earlier_target.read_text() == original
|
||||
@ -1,5 +1,5 @@
|
||||
"""Tests for EditFileTool enhancements: read-before-edit tracking, path suggestions,
|
||||
.ipynb detection, and create-file semantics."""
|
||||
notebook JSON editing, and create-file semantics."""
|
||||
|
||||
import pytest
|
||||
|
||||
@ -108,22 +108,27 @@ class TestEditCreateFile:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# .ipynb detection
|
||||
# .ipynb editing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEditIpynbDetection:
|
||||
"""edit_file should refuse .ipynb and suggest notebook_edit."""
|
||||
class TestEditIpynbFiles:
|
||||
"""edit_file edits notebooks as normal JSON files."""
|
||||
|
||||
@pytest.fixture()
|
||||
def tool(self, tmp_path):
|
||||
return EditFileTool(workspace=tmp_path)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ipynb_rejected_with_suggestion(self, tool, tmp_path):
|
||||
async def test_ipynb_can_be_edited_as_json(self, tool, tmp_path):
|
||||
f = tmp_path / "analysis.ipynb"
|
||||
f.write_text('{"cells": []}', encoding="utf-8")
|
||||
result = await tool.execute(path=str(f), old_text="x", new_text="y")
|
||||
assert "notebook" in result.lower()
|
||||
result = await tool.execute(
|
||||
path=str(f),
|
||||
old_text='"cells": []',
|
||||
new_text='"cells": [{"cell_type": "markdown", "source": "hi"}]',
|
||||
)
|
||||
assert "Successfully edited" in result
|
||||
assert '"source": "hi"' in f.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -162,7 +162,7 @@ class TestPathAppendPlatform:
|
||||
captured_cmd = None
|
||||
captured_env = {}
|
||||
|
||||
async def capture_spawn(cmd, cwd, env):
|
||||
async def capture_spawn(cmd, cwd, env, shell_program=None, login=True):
|
||||
nonlocal captured_cmd
|
||||
captured_cmd = cmd
|
||||
captured_env.update(env)
|
||||
@ -190,7 +190,7 @@ class TestPathAppendPlatform:
|
||||
|
||||
captured_env = {}
|
||||
|
||||
async def capture_spawn(cmd, cwd, env):
|
||||
async def capture_spawn(cmd, cwd, env, shell_program=None, login=True):
|
||||
captured_env.update(env)
|
||||
return mock_proc
|
||||
|
||||
|
||||
358
tests/tools/test_exec_session_tools.py
Normal file
358
tests/tools/test_exec_session_tools.py
Normal file
@ -0,0 +1,358 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
import shlex
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
from nanobot.agent.tools.exec_session import ExecSessionManager, ListExecSessionsTool, WriteStdinTool
|
||||
|
||||
|
||||
def _python_command(code: str) -> str:
|
||||
if sys.platform == "win32":
|
||||
return f"{subprocess.list2cmdline([sys.executable])} -u -c {subprocess.list2cmdline([code])}"
|
||||
return f"{shlex.quote(sys.executable)} -u -c {shlex.quote(code)}"
|
||||
|
||||
|
||||
def _session_id(output: str) -> str:
|
||||
match = re.search(r"session_id:\s*([0-9a-f]+)", output)
|
||||
assert match, output
|
||||
return match.group(1)
|
||||
|
||||
|
||||
def test_exec_keeps_one_shot_behavior_without_yield_time_ms(tmp_path):
    """Without ``yield_time_ms`` the output carries an exit code and no session id."""

    async def scenario() -> str:
        shell = ExecTool(working_dir=str(tmp_path), timeout=5)
        output = await shell.execute(command="echo hello")
        return output

    output = asyncio.run(scenario())

    assert "hello" in output
    assert "Exit code: 0" in output
    assert "session_id:" not in output
|
||||
|
||||
|
||||
def test_exec_accepts_command_aliases(tmp_path):
    """``cmd`` and ``workdir`` are accepted: ``pwd`` prints the requested directory."""

    async def scenario() -> str:
        shell = ExecTool(working_dir="/")
        return await shell.execute(cmd="pwd", workdir=str(tmp_path))

    output = asyncio.run(scenario())

    assert str(tmp_path) in output
    assert "Exit code: 0" in output
|
||||
|
||||
|
||||
def test_exec_returns_completed_session_output_when_yield_time_ms_is_used(tmp_path):
    """A command that finishes within ``yield_time_ms`` returns its result with no session id."""

    async def scenario() -> str:
        sessions = ExecSessionManager()
        shell = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=sessions)
        stdin_writer = WriteStdinTool(manager=sessions)

        first = await shell.execute(command="echo hello", yield_time_ms=1000)
        if "session_id:" not in first:
            return first
        # Still running: poll once and append the output; the final assertions then fail on it.
        follow_up = await stdin_writer.execute(
            session_id=_session_id(first),
            chars="",
            yield_time_ms=1000,
        )
        return first + "\n" + follow_up

    output = asyncio.run(scenario())

    assert "hello" in output
    assert "Exit code: 0" in output
    assert "session_id:" not in output
|
||||
|
||||
|
||||
def test_exec_session_accepts_max_output_tokens_alias(tmp_path):
    """``max_output_tokens`` truncates the output of a session-mode exec call."""

    async def scenario() -> str:
        sessions = ExecSessionManager()
        shell = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=sessions)
        noisy = _python_command("print('A' * 2000)")
        return await shell.execute(command=noisy, yield_time_ms=1000, max_output_tokens=1000)

    output = asyncio.run(scenario())

    assert "chars truncated" in output
    assert "Exit code: 0" in output
|
||||
|
||||
|
||||
def test_exec_one_shot_accepts_max_output_tokens_alias(tmp_path):
    """``max_output_tokens`` truncates the output of a one-shot exec call."""

    async def scenario() -> str:
        shell = ExecTool(working_dir=str(tmp_path), timeout=5)
        noisy = _python_command("print('A' * 2000)")
        return await shell.execute(command=noisy, max_output_tokens=1000)

    output = asyncio.run(scenario())

    assert "chars truncated" in output
    assert "Exit code: 0" in output
|
||||
|
||||
|
||||
def test_exec_accepts_supported_shell_parameter(tmp_path):
    """``shell="sh"`` with ``login=False`` runs the command (not exercised on Windows)."""
    if sys.platform == "win32":
        return

    async def scenario() -> str:
        shell_tool = ExecTool(working_dir=str(tmp_path), timeout=5)
        return await shell_tool.execute(command="echo shell-ok", shell="sh", login=False)

    output = asyncio.run(scenario())

    assert "shell-ok" in output
    assert "Exit code: 0" in output
|
||||
|
||||
|
||||
def test_exec_rejects_unsupported_shell(tmp_path):
    """``shell="python"`` is rejected as an unsupported shell (not exercised on Windows)."""
    if sys.platform == "win32":
        return

    async def scenario() -> str:
        shell_tool = ExecTool(working_dir=str(tmp_path), timeout=5)
        return await shell_tool.execute(command="echo no", shell="python")

    output = asyncio.run(scenario())

    assert "unsupported shell" in output
|
||||
|
||||
|
||||
def test_exec_can_continue_with_stdin(tmp_path):
    """A yielded session reads a line written via ``write_stdin`` and then exits."""
    echo_script = (
        "import sys; print('ready', flush=True); "
        "line=sys.stdin.readline(); print('got:' + line.strip(), flush=True)"
    )

    async def scenario() -> tuple[str, str]:
        sessions = ExecSessionManager()
        shell = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=sessions)
        stdin_writer = WriteStdinTool(manager=sessions)

        started = await shell.execute(command=_python_command(echo_script), yield_time_ms=500)
        replied = await stdin_writer.execute(
            session_id=_session_id(started),
            chars="ping\n",
            yield_time_ms=1000,
        )
        return started, replied

    started, replied = asyncio.run(scenario())

    for expected in ("ready", "Process running", "Elapsed:"):
        assert expected in started
    for expected in ("got:ping", "Exit code: 0", "Elapsed:"):
        assert expected in replied
|
||||
|
||||
|
||||
def test_write_stdin_can_close_stdin(tmp_path):
    """``close_stdin=True`` delivers EOF so a process reading all of stdin can finish."""
    slurp_script = (
        "import sys; print('ready', flush=True); "
        "data=sys.stdin.read(); print('got:' + data, flush=True)"
    )

    async def scenario() -> tuple[str, str]:
        sessions = ExecSessionManager()
        shell = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=sessions)
        stdin_writer = WriteStdinTool(manager=sessions)

        started = await shell.execute(command=_python_command(slurp_script), yield_time_ms=500)
        finished = await stdin_writer.execute(
            session_id=_session_id(started),
            chars="payload",
            close_stdin=True,
            yield_time_ms=1000,
        )
        return started, finished

    started, finished = asyncio.run(scenario())

    assert "ready" in started
    for expected in ("got:payload", "Stdin closed.", "Exit code: 0"):
        assert expected in finished
|
||||
|
||||
|
||||
def test_write_stdin_can_terminate_session(tmp_path):
    """``terminate=True`` ends a long-running session and reports an exit code."""
    sleeper_script = "import time; print('ready', flush=True); time.sleep(30)"

    async def scenario() -> tuple[str, str]:
        sessions = ExecSessionManager()
        shell = ExecTool(working_dir=str(tmp_path), timeout=30, session_manager=sessions)
        stdin_writer = WriteStdinTool(manager=sessions)

        started = await shell.execute(command=_python_command(sleeper_script), yield_time_ms=500)
        stopped = await stdin_writer.execute(
            session_id=_session_id(started),
            terminate=True,
            yield_time_ms=0,
        )
        return started, stopped

    started, stopped = asyncio.run(scenario())

    assert "ready" in started
    assert "Session terminated." in stopped
    assert "Exit code:" in stopped
|
||||
|
||||
|
||||
def test_write_stdin_accepts_max_output_tokens_alias(tmp_path):
    """``max_output_tokens`` truncates the output returned by a ``write_stdin`` poll."""
    noisy_script = "import time; print('A' * 2000, flush=True); time.sleep(5)"

    async def scenario() -> tuple[str, str, str]:
        sessions = ExecSessionManager()
        shell = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=sessions)
        stdin_writer = WriteStdinTool(manager=sessions)

        started = await shell.execute(command=_python_command(noisy_script), yield_time_ms=0)
        session = _session_id(started)
        polled = await stdin_writer.execute(
            session_id=session,
            yield_time_ms=500,
            max_output_tokens=1000,
        )
        # Stop the sleeping process so it does not outlive the test.
        stopped = await stdin_writer.execute(session_id=session, terminate=True, yield_time_ms=0)
        return started, polled, stopped

    started, polled, stopped = asyncio.run(scenario())

    assert "Process running" in started
    assert "chars truncated" in polled
    assert "Session terminated." in stopped
|
||||
|
||||
|
||||
def test_write_stdin_preserves_completed_session_output_until_polled(tmp_path):
    """Output printed about a second after the initial yield is still returned by a later poll."""
    two_step_script = (
        "import time; print('ready', flush=True); "
        "time.sleep(1.0); print('done', flush=True)"
    )

    async def scenario() -> tuple[str, str]:
        sessions = ExecSessionManager()
        shell = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=sessions)
        stdin_writer = WriteStdinTool(manager=sessions)

        started = await shell.execute(command=_python_command(two_step_script), yield_time_ms=300)
        session = _session_id(started)
        # Wait past the script's 1.0 s sleep before polling.
        await asyncio.sleep(1.2)
        collected = await stdin_writer.execute(session_id=session, chars="", yield_time_ms=0)
        return started, collected

    started, collected = asyncio.run(scenario())

    assert "ready" in started
    assert "done" in collected
    assert "Exit code: 0" in collected
|
||||
|
||||
|
||||
def test_write_stdin_can_wait_for_expected_output(tmp_path):
    """``wait_for`` returns once the expected text appears, without a timeout notice."""
    boot_script = (
        "import time; print('booting', flush=True); "
        "time.sleep(0.4); print('ready', flush=True); time.sleep(5)"
    )

    async def scenario() -> tuple[str, str, str]:
        sessions = ExecSessionManager()
        shell = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=sessions)
        stdin_writer = WriteStdinTool(manager=sessions)

        started = await shell.execute(command=_python_command(boot_script), yield_time_ms=100)
        session = _session_id(started)
        awaited = await stdin_writer.execute(
            session_id=session,
            wait_for="ready",
            wait_timeout_ms=3000,
            yield_time_ms=0,
        )
        # Stop the sleeping process so it does not outlive the test.
        stopped = await stdin_writer.execute(session_id=session, terminate=True, yield_time_ms=0)
        return started, awaited, stopped

    started, awaited, stopped = asyncio.run(scenario())

    assert "Process running" in started
    # "booting" may land in either chunk depending on timing.
    assert "booting" in started + awaited
    assert "ready" in awaited
    assert "Wait target not observed" not in awaited
    assert "Session terminated." in stopped
|
||||
|
||||
|
||||
def test_write_stdin_wait_for_reports_timeout_without_killing_session(tmp_path):
    """A ``wait_for`` that times out reports it and leaves the process running."""

    async def scenario() -> tuple[str, str, str]:
        sessions = ExecSessionManager()
        runner = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=sessions)
        writer = WriteStdinTool(manager=sessions)
        script = "import time; print('booting', flush=True); time.sleep(5)"

        started = await runner.execute(command=_python_command(script), yield_time_ms=100)
        session = _session_id(started)
        wait_request = {
            "session_id": session,
            "wait_for": "never-ready",  # the script never prints this
            "wait_timeout_ms": 200,
            "yield_time_ms": 0,
        }
        observed = await writer.execute(**wait_request)
        stopped = await writer.execute(session_id=session, terminate=True, yield_time_ms=0)
        return started, observed, stopped

    initial, waited, cleanup = asyncio.run(scenario())

    assert "Process running" in initial
    assert "booting" in initial + waited
    assert "Process running" in waited
    assert "Wait target not observed: 'never-ready'" in waited
    assert "Session terminated." in cleanup
|
||||
|
||||
|
||||
def test_exec_session_mode_reuses_exec_safety_guard(tmp_path):
    """Deny patterns are reported when ExecTool runs with a session manager."""
    guarded = ExecTool(
        working_dir=str(tmp_path),
        deny_patterns=[r"echo\s+blocked"],
        session_manager=ExecSessionManager(),
    )

    outcome = asyncio.run(guarded.execute(command="echo blocked", yield_time_ms=0))

    assert "blocked by deny pattern" in outcome
|
||||
|
||||
|
||||
def test_write_stdin_reports_missing_session(tmp_path):
    """Writing to an unknown session id produces a 'not found' message."""
    writer = WriteStdinTool(manager=ExecSessionManager())

    outcome = asyncio.run(writer.execute(session_id="missing", chars=""))

    assert "exec session not found" in outcome
|
||||
|
||||
|
||||
def test_list_exec_sessions_reports_running_commands(tmp_path):
    """The session listing mentions the id, state, timings and working directory."""

    async def scenario() -> tuple[str, str, str]:
        sessions = ExecSessionManager()
        runner = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=sessions)
        lister = ListExecSessionsTool(manager=sessions)
        writer = WriteStdinTool(manager=sessions)
        script = "import time; print('ready', flush=True); time.sleep(5)"

        started = await runner.execute(command=_python_command(script), yield_time_ms=500)
        session = _session_id(started)
        overview = await lister.execute()
        stopped = await writer.execute(session_id=session, terminate=True, yield_time_ms=0)
        return session, overview, stopped

    sid, listing, cleanup = asyncio.run(scenario())

    for fragment in (sid, "running", "elapsed=", "remaining=", str(tmp_path)):
        assert fragment in listing
    assert "Session terminated." in cleanup
|
||||
|
||||
|
||||
def test_list_exec_sessions_reports_empty_state():
    """With no sessions registered the listing is a fixed placeholder message."""
    lister = ListExecSessionsTool(manager=ExecSessionManager())

    listing = asyncio.run(lister.execute())

    assert listing == "No active exec sessions."
|
||||
216
tests/tools/test_file_edit_coding_enhancements.py
Normal file
216
tests/tools/test_file_edit_coding_enhancements.py
Normal file
@ -0,0 +1,216 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool
|
||||
|
||||
|
||||
def test_read_file_force_bypasses_dedup(tmp_path):
    """``force=True`` returns content even when a repeat read reports 'unchanged'."""
    source = tmp_path / "data.txt"
    source.write_text("alpha\n")
    reader = ReadFileTool(workspace=tmp_path)

    def read(**extra):
        return asyncio.run(reader.execute(path=str(source), **extra))

    first = read()
    second = read()
    forced = read(force=True)

    assert "alpha" in first
    assert "unchanged" in second.lower()
    assert "alpha" in forced
    assert "unchanged" not in forced.lower()
|
||||
|
||||
|
||||
def test_edit_file_can_select_occurrence(tmp_path):
    """``occurrence=2`` replaces only the second match of ``old_text``."""
    path = tmp_path / "duplicate.txt"
    path.write_text("one\nsame\ntwo\nsame\n")
    editor = EditFileTool(workspace=tmp_path)
    request = {
        "path": str(path),
        "old_text": "same",
        "new_text": "changed",
        "occurrence": 2,
    }

    outcome = asyncio.run(editor.execute(**request))

    assert "Successfully edited" in outcome
    assert path.read_text() == "one\nsame\ntwo\nchanged\n"
|
||||
|
||||
|
||||
def test_edit_file_expected_replacements_guards_replace_all(tmp_path):
    """A mismatching ``expected_replacements`` is reported and the file is untouched."""
    original = "same\nsame\n"
    path = tmp_path / "duplicate.txt"
    path.write_text(original)
    editor = EditFileTool(workspace=tmp_path)
    request = {
        "path": str(path),
        "old_text": "same",
        "new_text": "changed",
        "replace_all": True,
        "expected_replacements": 1,
    }

    outcome = asyncio.run(editor.execute(**request))

    assert "expected 1 replacements but would make 2" in outcome
    assert path.read_text() == original
|
||||
|
||||
|
||||
def test_edit_file_expected_replacements_allows_replace_all_when_count_matches(tmp_path):
    """``replace_all`` succeeds when ``expected_replacements`` equals the match count."""
    path = tmp_path / "duplicate.txt"
    path.write_text("same\nsame\n")
    editor = EditFileTool(workspace=tmp_path)
    request = {
        "path": str(path),
        "old_text": "same",
        "new_text": "changed",
        "replace_all": True,
        "expected_replacements": 2,
    }

    outcome = asyncio.run(editor.execute(**request))

    assert "Successfully edited" in outcome
    assert path.read_text() == "changed\nchanged\n"
|
||||
|
||||
|
||||
def test_edit_file_can_select_nearest_line_hint(tmp_path):
    """``line_hint=4`` selects the match on line 4 rather than the one on line 2."""
    path = tmp_path / "duplicate.txt"
    path.write_text("one\nsame\ntwo\nsame\n")
    editor = EditFileTool(workspace=tmp_path)
    request = {
        "path": str(path),
        "old_text": "same",
        "new_text": "changed",
        "line_hint": 4,
    }

    outcome = asyncio.run(editor.execute(**request))

    assert "Successfully edited" in outcome
    assert path.read_text() == "one\nsame\ntwo\nchanged\n"
|
||||
|
||||
|
||||
def test_edit_file_can_edit_ipynb_as_json(tmp_path):
    """A ``.ipynb`` file can be edited with plain text replacement."""
    notebook = tmp_path / "analysis.ipynb"
    notebook.write_text('{"cells": []}')
    editor = EditFileTool(workspace=tmp_path)
    request = {
        "path": str(notebook),
        "old_text": '"cells": []',
        "new_text": '"cells": [{"cell_type": "markdown", "source": "hi"}]',
    }

    outcome = asyncio.run(editor.execute(**request))

    assert "Successfully edited" in outcome
    assert '"source": "hi"' in notebook.read_text()
|
||||
|
||||
|
||||
def test_edit_file_multiple_match_hint_mentions_occurrence(tmp_path):
    """An ambiguous edit reports the match count and mentions ``occurrence``."""
    original = "same\nsame\n"
    path = tmp_path / "duplicate.txt"
    path.write_text(original)
    editor = EditFileTool(workspace=tmp_path)

    outcome = asyncio.run(
        editor.execute(path=str(path), old_text="same", new_text="changed")
    )

    for fragment in ("old_text appears 2 times", "occurrence"):
        assert fragment in outcome
    assert path.read_text() == original
|
||||
|
||||
|
||||
def test_edit_file_rejects_ambiguous_line_hint(tmp_path):
    """A ``line_hint`` equally close to two matches is rejected without editing."""
    original = "same\nmiddle\nsame\n"
    path = tmp_path / "duplicate.txt"
    path.write_text(original)
    editor = EditFileTool(workspace=tmp_path)
    request = {
        "path": str(path),
        "old_text": "same",
        "new_text": "changed",
        "line_hint": 2,  # line 2 sits between the matches on lines 1 and 3
    }

    outcome = asyncio.run(editor.execute(**request))

    assert "line_hint 2 is ambiguous" in outcome
    assert path.read_text() == original
|
||||
|
||||
|
||||
def test_edit_file_rejects_occurrence_with_replace_all(tmp_path):
    """Combining ``occurrence`` with ``replace_all`` is rejected without editing."""
    original = "same\nsame\n"
    path = tmp_path / "duplicate.txt"
    path.write_text(original)
    editor = EditFileTool(workspace=tmp_path)
    request = {
        "path": str(path),
        "old_text": "same",
        "new_text": "changed",
        "occurrence": 1,
        "replace_all": True,
    }

    outcome = asyncio.run(editor.execute(**request))

    assert "occurrence cannot be used with replace_all" in outcome
    assert path.read_text() == original
|
||||
|
||||
|
||||
def test_edit_file_rejects_line_hint_with_replace_all(tmp_path):
    """Combining ``line_hint`` with ``replace_all`` is rejected without editing."""
    original = "same\nsame\n"
    path = tmp_path / "duplicate.txt"
    path.write_text(original)
    editor = EditFileTool(workspace=tmp_path)
    request = {
        "path": str(path),
        "old_text": "same",
        "new_text": "changed",
        "line_hint": 1,
        "replace_all": True,
    }

    outcome = asyncio.run(editor.execute(**request))

    assert "line_hint cannot be used with replace_all" in outcome
    assert path.read_text() == original
|
||||
|
||||
|
||||
def test_edit_file_rejects_line_hint_with_occurrence(tmp_path):
    """Combining ``line_hint`` with ``occurrence`` is rejected without editing."""
    original = "same\nsame\n"
    path = tmp_path / "duplicate.txt"
    path.write_text(original)
    editor = EditFileTool(workspace=tmp_path)
    request = {
        "path": str(path),
        "old_text": "same",
        "new_text": "changed",
        "occurrence": 1,
        "line_hint": 1,
    }

    outcome = asyncio.run(editor.execute(**request))

    assert "line_hint cannot be used with occurrence" in outcome
    assert path.read_text() == original
|
||||
|
||||
|
||||
def test_edit_file_rejects_zero_occurrence(tmp_path):
    """``occurrence=0`` is rejected with a '>= 1' message and no edit is made."""
    original = "same\n"
    path = tmp_path / "duplicate.txt"
    path.write_text(original)
    editor = EditFileTool(workspace=tmp_path)
    request = {
        "path": str(path),
        "old_text": "same",
        "new_text": "changed",
        "occurrence": 0,
    }

    outcome = asyncio.run(editor.execute(**request))

    assert "occurrence must be >= 1" in outcome
    assert path.read_text() == original
|
||||
|
||||
|
||||
def test_edit_file_rejects_zero_line_hint(tmp_path):
    """``line_hint=0`` is rejected with a '>= 1' message and no edit is made."""
    original = "same\n"
    path = tmp_path / "duplicate.txt"
    path.write_text(original)
    editor = EditFileTool(workspace=tmp_path)
    request = {
        "path": str(path),
        "old_text": "same",
        "new_text": "changed",
        "line_hint": 0,
    }

    outcome = asyncio.run(editor.execute(**request))

    assert "line_hint must be >= 1" in outcome
    assert path.read_text() == original
|
||||
@ -1,147 +0,0 @@
|
||||
"""Tests for NotebookEditTool — Jupyter .ipynb editing."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.notebook import NotebookEditTool
|
||||
|
||||
|
||||
def _make_notebook(cells: list[dict] | None = None, nbformat: int = 4, nbformat_minor: int = 5) -> dict:
|
||||
"""Build a minimal valid .ipynb structure."""
|
||||
return {
|
||||
"nbformat": nbformat,
|
||||
"nbformat_minor": nbformat_minor,
|
||||
"metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}},
|
||||
"cells": cells or [],
|
||||
}
|
||||
|
||||
|
||||
def _code_cell(source: str, cell_id: str | None = None) -> dict:
|
||||
cell = {"cell_type": "code", "source": source, "metadata": {}, "outputs": [], "execution_count": None}
|
||||
if cell_id:
|
||||
cell["id"] = cell_id
|
||||
return cell
|
||||
|
||||
|
||||
def _md_cell(source: str, cell_id: str | None = None) -> dict:
|
||||
cell = {"cell_type": "markdown", "source": source, "metadata": {}}
|
||||
if cell_id:
|
||||
cell["id"] = cell_id
|
||||
return cell
|
||||
|
||||
|
||||
def _write_nb(tmp_path, name: str, nb: dict) -> str:
|
||||
p = tmp_path / name
|
||||
p.write_text(json.dumps(nb), encoding="utf-8")
|
||||
return str(p)
|
||||
|
||||
|
||||
class TestNotebookEdit:
    """Tests for ``NotebookEditTool`` replacing, inserting and deleting cells."""

    @staticmethod
    def _load(tmp_path, name: str = "test.ipynb") -> dict:
        """Parse the notebook ``name`` under ``tmp_path`` back into a dict."""
        return json.loads((tmp_path / name).read_text())

    @pytest.fixture()
    def tool(self, tmp_path):
        return NotebookEditTool(workspace=tmp_path)

    @pytest.mark.asyncio
    async def test_replace_cell_content(self, tool, tmp_path):
        cells = [_code_cell("print('hello')"), _code_cell("x = 1")]
        path = _write_nb(tmp_path, "test.ipynb", _make_notebook(cells))

        result = await tool.execute(path=path, cell_index=0, new_source="print('world')")

        assert "Successfully" in result
        saved_cells = self._load(tmp_path)["cells"]
        assert saved_cells[0]["source"] == "print('world')"
        assert saved_cells[1]["source"] == "x = 1"

    @pytest.mark.asyncio
    async def test_insert_cell_after_target(self, tool, tmp_path):
        cells = [_code_cell("cell 0"), _code_cell("cell 1")]
        path = _write_nb(tmp_path, "test.ipynb", _make_notebook(cells))

        result = await tool.execute(
            path=path, cell_index=0, new_source="inserted", edit_mode="insert"
        )

        assert "Successfully" in result
        saved_cells = self._load(tmp_path)["cells"]
        assert len(saved_cells) == 3
        assert [cell["source"] for cell in saved_cells] == ["cell 0", "inserted", "cell 1"]

    @pytest.mark.asyncio
    async def test_delete_cell(self, tool, tmp_path):
        cells = [_code_cell(label) for label in ("A", "B", "C")]
        path = _write_nb(tmp_path, "test.ipynb", _make_notebook(cells))

        result = await tool.execute(path=path, cell_index=1, edit_mode="delete")

        assert "Successfully" in result
        saved_cells = self._load(tmp_path)["cells"]
        assert len(saved_cells) == 2
        assert [cell["source"] for cell in saved_cells] == ["A", "C"]

    @pytest.mark.asyncio
    async def test_create_new_notebook_from_scratch(self, tool, tmp_path):
        # The target file does not exist before the call.
        path = str(tmp_path / "new.ipynb")

        result = await tool.execute(
            path=path,
            cell_index=0,
            new_source="# Hello",
            edit_mode="insert",
            cell_type="markdown",
        )

        assert "Successfully" in result or "created" in result.lower()
        saved = self._load(tmp_path, "new.ipynb")
        assert saved["nbformat"] == 4
        assert len(saved["cells"]) == 1
        only_cell = saved["cells"][0]
        assert only_cell["cell_type"] == "markdown"
        assert only_cell["source"] == "# Hello"

    @pytest.mark.asyncio
    async def test_invalid_cell_index_error(self, tool, tmp_path):
        path = _write_nb(tmp_path, "test.ipynb", _make_notebook([_code_cell("only cell")]))

        result = await tool.execute(path=path, cell_index=5, new_source="x")

        assert "Error" in result

    @pytest.mark.asyncio
    async def test_non_ipynb_rejected(self, tool, tmp_path):
        script = tmp_path / "script.py"
        script.write_text("pass")

        result = await tool.execute(path=str(script), cell_index=0, new_source="x")

        assert "Error" in result
        assert ".ipynb" in result

    @pytest.mark.asyncio
    async def test_preserves_metadata_and_outputs(self, tool, tmp_path):
        cell = _code_cell("old")
        cell["outputs"] = [{"output_type": "stream", "text": "hello\n"}]
        cell["execution_count"] = 42
        path = _write_nb(tmp_path, "test.ipynb", _make_notebook([cell]))

        await tool.execute(path=path, cell_index=0, new_source="new")

        # NOTE(review): only notebook-level metadata is asserted here; the cell's
        # outputs/execution_count are set up above but never checked.
        saved = self._load(tmp_path)
        assert saved["metadata"]["kernelspec"]["language"] == "python"

    @pytest.mark.asyncio
    async def test_nbformat_45_generates_cell_id(self, tool, tmp_path):
        path = _write_nb(tmp_path, "test.ipynb", _make_notebook([], nbformat_minor=5))

        await tool.execute(path=path, cell_index=0, new_source="x = 1", edit_mode="insert")

        first_cell = self._load(tmp_path)["cells"][0]
        assert "id" in first_cell
        assert len(first_cell["id"]) > 0

    @pytest.mark.asyncio
    async def test_insert_with_cell_type_markdown(self, tool, tmp_path):
        path = _write_nb(tmp_path, "test.ipynb", _make_notebook([_code_cell("code")]))

        await tool.execute(
            path=path,
            cell_index=0,
            new_source="# Title",
            edit_mode="insert",
            cell_type="markdown",
        )

        saved = self._load(tmp_path)
        assert saved["cells"][1]["cell_type"] == "markdown"

    @pytest.mark.asyncio
    async def test_invalid_edit_mode_rejected(self, tool, tmp_path):
        path = _write_nb(tmp_path, "test.ipynb", _make_notebook([_code_cell("code")]))

        # "replcae" is a deliberately misspelled mode.
        result = await tool.execute(path=path, cell_index=0, new_source="x", edit_mode="replcae")

        assert "Error" in result
        assert "edit_mode" in result

    @pytest.mark.asyncio
    async def test_invalid_cell_type_rejected(self, tool, tmp_path):
        path = _write_nb(tmp_path, "test.ipynb", _make_notebook([_code_cell("code")]))

        result = await tool.execute(path=path, cell_index=0, new_source="x", cell_type="raw")

        assert "Error" in result
        assert "cell_type" in result
|
||||
@ -12,7 +12,7 @@ import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.subagent import SubagentManager, SubagentStatus
|
||||
from nanobot.agent.tools.search import GrepTool
|
||||
from nanobot.agent.tools.search import FindFilesTool, GrepTool
|
||||
from nanobot.agent.tools.web import WebSearchTool
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.config.schema import WebSearchConfig
|
||||
@ -33,6 +33,68 @@ async def test_web_search_tool_refreshes_dynamic_config_loader(monkeypatch) -> N
|
||||
assert await tool.execute("nanobot") == "duckduckgo:nanobot:3"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_files_filters_by_query_glob_and_type(tmp_path: Path) -> None:
    """Query, glob and type filters combine so only the matching ``.tsx`` file is listed."""
    src = tmp_path / "src"
    src.mkdir()
    fixtures = {
        src / "settings_view.tsx": "export {}\n",
        src / "settings_api.py": "pass\n",
        tmp_path / "README.md": "settings\n",
    }
    for file_path, text in fixtures.items():
        file_path.write_text(text, encoding="utf-8")

    finder = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path)
    listing = await finder.execute(path=".", query="settings", glob="src/**", type="ts")

    assert listing.splitlines() == ["src/settings_view.tsx"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_files_can_include_directories(tmp_path: Path) -> None:
    """With ``include_dirs=True`` matching directories are listed with a trailing slash."""
    settings_dir = tmp_path / "src" / "settings"
    settings_dir.mkdir(parents=True)
    (settings_dir / "index.ts").write_text("export {}\n", encoding="utf-8")

    finder = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path)
    listing = await finder.execute(path="src", query="settings", include_dirs=True)

    entries = listing.splitlines()
    assert "src/settings/" in entries
    assert "src/settings/index.ts" in entries
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_files_supports_modified_sort_and_pagination(tmp_path: Path) -> None:
    """``sort="modified"`` with ``head_limit``/``offset`` pages through the results."""
    src = tmp_path / "src"
    src.mkdir()
    # Give each file a distinct, increasing timestamp (1, 2, 3 seconds since epoch).
    for stamp, name in zip((1, 2, 3), ("a.py", "b.py", "c.py")):
        file_path = src / name
        file_path.write_text("pass\n", encoding="utf-8")
        os.utime(file_path, (stamp, stamp))

    finder = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path)
    listing = await finder.execute(
        path="src", type="py", sort="modified", head_limit=1, offset=1
    )

    first_line = listing.splitlines()[0]
    assert first_line == "src/b.py"
    assert "pagination: limit=1, offset=1" in listing
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_files_rejects_paths_outside_workspace(tmp_path: Path) -> None:
    """A path outside ``allowed_dir`` yields an ``Error:`` result.

    The probe file lives in ``tmp_path.parent``, i.e. outside the per-test
    directory, so it is removed explicitly once the check is done rather than
    left behind for other tests.
    """
    outside = tmp_path.parent / "outside-find-files.txt"
    outside.write_text("secret\n", encoding="utf-8")

    try:
        tool = FindFilesTool(workspace=tmp_path, allowed_dir=tmp_path)
        result = await tool.execute(path=str(outside))

        assert result.startswith("Error:")
    finally:
        # Clean up even when the tool call or the assertion fails.
        outside.unlink(missing_ok=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_grep_respects_glob_filter_and_context(tmp_path: Path) -> None:
|
||||
(tmp_path / "src").mkdir()
|
||||
@ -249,6 +311,7 @@ def test_agent_loop_registers_grep(tmp_path: Path) -> None:
|
||||
|
||||
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
|
||||
|
||||
assert "find_files" in loop.tools.tool_names
|
||||
assert "grep" in loop.tools.tool_names
|
||||
|
||||
|
||||
@ -280,6 +343,7 @@ async def test_subagent_registers_grep(tmp_path: Path) -> None:
|
||||
status = SubagentStatus(task_id="sub-1", label="label", task_description="search task", started_at=time.monotonic())
|
||||
await mgr._run_subagent("sub-1", "search task", "label", {"channel": "cli", "chat_id": "direct"}, status)
|
||||
|
||||
assert "find_files" in captured["tool_names"]
|
||||
assert "grep" in captured["tool_names"]
|
||||
|
||||
|
||||
|
||||
46
tests/tools/test_tool_descriptions.py
Normal file
46
tests/tools/test_tool_descriptions.py
Normal file
@ -0,0 +1,46 @@
|
||||
from nanobot.agent.tools.apply_patch import ApplyPatchTool
|
||||
from nanobot.agent.tools.exec_session import ListExecSessionsTool, WriteStdinTool
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools.search import FindFilesTool, GrepTool
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
|
||||
|
||||
def test_coding_tool_descriptions_steer_editing_priority() -> None:
    """Editing tool descriptions contain the phrases that rank the tools against each other."""
    apply_patch = ApplyPatchTool().description.lower()
    edit_file = EditFileTool().description.lower()
    write_file = WriteFileTool().description.lower()

    expectations = (
        (apply_patch, "default tool for code edits"),
        (apply_patch, "multi-file"),
        (apply_patch, "dry_run=true"),
        (apply_patch, "edit_file only for small exact replacements"),
        (edit_file, "small, exact replacement"),
        (edit_file, "copied from read_file"),
        (edit_file, "prefer apply_patch"),
        (write_file, "replace an entire file"),
        (write_file, "prefer apply_patch"),
    )
    for description, phrase in expectations:
        assert phrase in description
|
||||
|
||||
|
||||
def test_coding_tool_descriptions_steer_discovery_and_shell_usage() -> None:
    """Discovery, shell and session tool descriptions contain their steering phrases."""
    read_file = ReadFileTool().description.lower()
    find_files = FindFilesTool().description.lower()
    grep = GrepTool().description.lower()
    exec_tool = ExecTool().description.lower()
    write_stdin = WriteStdinTool().description.lower()
    list_sessions = ListExecSessionsTool().description.lower()

    expectations = (
        (read_file, "find_files/list_dir first"),
        (read_file, "before editing"),
        (find_files, "prefer it over shell find/ls"),
        (grep, "prefer this over shell grep"),
        (exec_tool, "tests, builds"),
        (exec_tool, "prefer read_file/find_files/grep"),
        (exec_tool, "apply_patch/write_file/edit_file"),
        (exec_tool, "yield_time_ms"),
        (write_stdin, "do not use this to start new commands"),
        (write_stdin, "wait_for"),
        (list_sessions, "recover a session_id"),
    )
    for description, phrase in expectations:
        assert phrase in description
|
||||
@ -89,9 +89,11 @@ def test_discover_finds_concrete_tools():
|
||||
loader = ToolLoader()
|
||||
discovered = loader.discover()
|
||||
class_names = {cls.__name__ for cls in discovered}
|
||||
assert "ApplyPatchTool" in class_names
|
||||
assert "ExecTool" in class_names
|
||||
assert "MessageTool" in class_names
|
||||
assert "SpawnTool" in class_names
|
||||
assert "WriteStdinTool" in class_names
|
||||
|
||||
|
||||
def test_discover_excludes_abstract_and_mcp():
|
||||
@ -406,7 +408,8 @@ def test_loader_registers_same_tools_as_old_hardcoded():
|
||||
|
||||
expected = {
|
||||
"read_file", "write_file", "edit_file", "list_dir",
|
||||
"grep", "notebook_edit", "exec", "web_search", "web_fetch",
|
||||
"find_files", "grep", "exec", "write_stdin", "list_exec_sessions",
|
||||
"web_search", "web_fetch",
|
||||
"message", "spawn", "cron",
|
||||
}
|
||||
actual = set(registered)
|
||||
|
||||
@ -10,6 +10,7 @@ from nanobot.utils.file_edit_events import (
|
||||
build_file_edit_start_event,
|
||||
line_diff_stats,
|
||||
prepare_file_edit_tracker,
|
||||
prepare_file_edit_trackers,
|
||||
read_file_snapshot,
|
||||
)
|
||||
|
||||
@ -81,6 +82,71 @@ def test_binary_file_is_reported_but_not_counted(tmp_path: Path) -> None:
|
||||
assert (event["added"], event["deleted"]) == (0, 0)
|
||||
|
||||
|
||||
def test_apply_patch_prepares_trackers_for_each_touched_file(tmp_path: Path) -> None:
    """One tracker is prepared per added, updated and deleted file, in patch order."""
    src = tmp_path / "src"
    src.mkdir()
    existing = src / "existing.py"
    existing.write_text("old\nkeep\n", encoding="utf-8")
    delete_me = src / "delete_me.py"
    delete_me.write_text("gone\n", encoding="utf-8")

    patch = "\n".join(
        [
            "*** Begin Patch",
            "*** Add File: src/new.py",
            "+fresh",
            "*** Update File: src/existing.py",
            "@@",
            "-old",
            "+new",
            # NOTE(review): context line restored with its leading space; confirm
            # against the upstream test file.
            " keep",
            "*** Delete File: src/delete_me.py",
            "*** End Patch",
        ]
    )
    params = {"patch": patch}

    trackers = prepare_file_edit_trackers(
        call_id="call-patch",
        tool_name="apply_patch",
        tool=None,
        workspace=tmp_path,
        params=params,
    )

    touched = [tracker.display_path for tracker in trackers]
    assert touched == ["src/new.py", "src/existing.py", "src/delete_me.py"]

    # Apply by hand the changes the patch describes, then build the end events.
    (src / "new.py").write_text("fresh\n", encoding="utf-8")
    existing.write_text("new\nkeep\n", encoding="utf-8")
    delete_me.unlink()

    events = [build_file_edit_end_event(tracker, {"patch": patch}) for tracker in trackers]
    by_path = {event["path"]: event for event in events}
    expected_counts = (
        ("src/new.py", (1, 0)),
        ("src/existing.py", (1, 1)),
        ("src/delete_me.py", (0, 1)),
    )
    for path, counts in expected_counts:
        assert (by_path[path]["added"], by_path[path]["deleted"]) == counts
|
||||
|
||||
|
||||
def test_apply_patch_dry_run_does_not_prepare_file_edit_trackers(tmp_path: Path) -> None:
    """With ``dry_run`` set, no file edit trackers are prepared for the patch."""
    (tmp_path / "file.txt").write_text("old\n", encoding="utf-8")
    patch = "\n".join(
        [
            "*** Begin Patch",
            "*** Update File: file.txt",
            "@@",
            "-old",
            "+new",
            "*** End Patch",
        ]
    )

    trackers = prepare_file_edit_trackers(
        call_id="call-patch",
        tool_name="apply_patch",
        tool=None,
        workspace=tmp_path,
        params={"dry_run": True, "patch": patch},
    )

    assert trackers == []
|
||||
|
||||
|
||||
def test_oversized_write_file_end_uses_known_content_for_exact_count(tmp_path: Path) -> None:
|
||||
target = tmp_path / "large.txt"
|
||||
params = {"path": "large.txt", "content": "x" * (2 * 1024 * 1024 + 1)}
|
||||
@ -140,6 +206,66 @@ def test_streaming_write_file_tracker_emits_live_line_counts(tmp_path: Path) ->
|
||||
assert events[-1]["deleted"] == 0
|
||||
|
||||
|
||||
def test_streaming_apply_patch_tracker_emits_live_counts_per_file(tmp_path: Path) -> None:
    """A streamed, unfinished apply_patch call yields per-file approximate counts."""
    src = tmp_path / "src"
    src.mkdir()
    (src / "existing.py").write_text("old\nkeep\n", encoding="utf-8")
    captured: list[dict] = []

    async def collect(batch: list[dict]) -> None:
        captured.extend(batch)

    # Arguments as streamed so far: the JSON (and the patch string) is not closed.
    partial_arguments = (
        '{"patch":"*** Begin Patch\\n'
        '*** Update File: src/existing.py\\n'
        '@@\\n'
        '-old\\n'
        '+new\\n'
        ' keep\\n'
        '*** Add File: src/new.py\\n'
        '+fresh\\n'
    )

    async def drive() -> None:
        tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=collect)
        delta = {
            "index": 0,
            "call_id": "call-patch",
            "name": "apply_patch",
            "arguments_delta": partial_arguments,
        }
        await tracker.update(delta)

    asyncio.run(drive())

    by_path = {event["path"]: event for event in captured}
    updated = by_path["src/existing.py"]
    assert updated["tool"] == "apply_patch"
    assert updated["status"] == "editing"
    assert updated["approximate"] is True
    assert (updated["added"], updated["deleted"]) == (1, 1)
    created = by_path["src/new.py"]
    assert (created["added"], created["deleted"]) == (1, 0)
|
||||
|
||||
|
||||
def test_streaming_apply_patch_tracker_skips_dry_run(tmp_path: Path) -> None:
    """A streamed apply_patch call with ``dry_run`` true emits no events."""
    captured: list[dict] = []

    async def collect(batch: list[dict]) -> None:
        captured.extend(batch)

    # Partial JSON arguments; ``dry_run`` arrives before the patch body.
    partial_arguments = (
        '{"dry_run":true,"patch":"*** Begin Patch\\n'
        '*** Add File: dry.md\\n'
        '+preview\\n'
    )

    async def drive() -> None:
        tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=collect)
        delta = {
            "index": 0,
            "call_id": "call-patch",
            "name": "apply_patch",
            "arguments_delta": partial_arguments,
        }
        await tracker.update(delta)

    asyncio.run(drive())

    assert captured == []
|
||||
|
||||
|
||||
def test_streaming_write_file_tracker_emits_pending_before_path(tmp_path: Path) -> None:
|
||||
events: list[dict] = []
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user