mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-22 09:32:33 +00:00
feat(tools): tighten patch and session workflows
This commit is contained in:
parent
480ca28a2d
commit
5f0ba05de5
@ -10,7 +10,7 @@ from typing import Any, Literal
|
||||
|
||||
from nanobot.agent.tools.base import tool_parameters
|
||||
from nanobot.agent.tools.filesystem import _FsTool
|
||||
from nanobot.agent.tools.schema import StringSchema, tool_parameters_schema
|
||||
from nanobot.agent.tools.schema import BooleanSchema, StringSchema, tool_parameters_schema
|
||||
|
||||
|
||||
PatchKind = Literal["add", "delete", "update"]
|
||||
@ -31,6 +31,15 @@ class _PatchOp:
|
||||
hunks: list[_Hunk] | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _PatchSummary:
|
||||
action: str
|
||||
path: str
|
||||
added: int = 0
|
||||
deleted: int = 0
|
||||
new_path: str | None = None
|
||||
|
||||
|
||||
class _PatchError(ValueError):
|
||||
pass
|
||||
|
||||
@ -65,6 +74,40 @@ def _lines_to_text(lines: list[str]) -> str:
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
def _text_line_count(text: str) -> int:
|
||||
if not text:
|
||||
return 0
|
||||
return len(text.splitlines())
|
||||
|
||||
|
||||
def _line_diff_stats(before: str, after: str) -> tuple[int, int]:
|
||||
before_lines = before.replace("\r\n", "\n").splitlines()
|
||||
after_lines = after.replace("\r\n", "\n").splitlines()
|
||||
added = 0
|
||||
deleted = 0
|
||||
matcher = difflib.SequenceMatcher(a=before_lines, b=after_lines, autojunk=False)
|
||||
for tag, i1, i2, j1, j2 in matcher.get_opcodes():
|
||||
if tag == "equal":
|
||||
continue
|
||||
if tag in ("replace", "delete"):
|
||||
deleted += i2 - i1
|
||||
if tag in ("replace", "insert"):
|
||||
added += j2 - j1
|
||||
return added, deleted
|
||||
|
||||
|
||||
def _format_summary(summary: _PatchSummary) -> str:
|
||||
path = (
|
||||
f"{summary.path} -> {summary.new_path}"
|
||||
if summary.new_path
|
||||
else summary.path
|
||||
)
|
||||
stats = ""
|
||||
if summary.added or summary.deleted:
|
||||
stats = f" (+{summary.added}/-{summary.deleted})"
|
||||
return f"- {summary.action} {path}{stats}"
|
||||
|
||||
|
||||
def _parse_patch(patch: str) -> list[_PatchOp]:
|
||||
lines = patch.replace("\r\n", "\n").replace("\r", "\n").split("\n")
|
||||
if lines and lines[-1] == "":
|
||||
@ -237,6 +280,10 @@ def _apply_hunks(path: str, content: str, hunks: list[_Hunk]) -> str:
|
||||
"for Add File, Update File, Delete File, and optional Move to.",
|
||||
min_length=1,
|
||||
),
|
||||
dry_run=BooleanSchema(
|
||||
description="Validate and summarize the patch without writing files.",
|
||||
default=False,
|
||||
),
|
||||
required=["patch"],
|
||||
)
|
||||
)
|
||||
@ -254,60 +301,97 @@ class ApplyPatchTool(_FsTool):
|
||||
"Apply a structured patch for code edits. The patch must include "
|
||||
"*** Begin Patch and *** End Patch. Supports Add File, Update File, "
|
||||
"Delete File, and Move to. Paths must be relative. Prefer this for "
|
||||
"multi-file coding changes; use edit_file for small exact replacements."
|
||||
"multi-file coding changes; use edit_file for small exact replacements. "
|
||||
"Set dry_run=true to validate and preview the change without writing files."
|
||||
)
|
||||
|
||||
async def execute(self, patch: str, **kwargs: Any) -> str:
|
||||
async def execute(self, patch: str, dry_run: bool = False, **kwargs: Any) -> str:
|
||||
try:
|
||||
ops = _parse_patch(patch)
|
||||
writes: dict[Path, str] = {}
|
||||
deletes: set[Path] = set()
|
||||
touched: list[str] = []
|
||||
summaries: list[_PatchSummary] = []
|
||||
|
||||
for op in ops:
|
||||
source = self._resolve(op.path)
|
||||
if op.kind == "add":
|
||||
if source.exists():
|
||||
if source.exists() or source in writes:
|
||||
raise _PatchError(f"file to add already exists: {op.path}")
|
||||
writes[source] = _lines_to_text(op.add_lines or [])
|
||||
new_content = _lines_to_text(op.add_lines or [])
|
||||
writes[source] = new_content
|
||||
deletes.discard(source)
|
||||
touched.append(f"add {op.path}")
|
||||
summaries.append(_PatchSummary(
|
||||
action="add",
|
||||
path=op.path,
|
||||
added=_text_line_count(new_content),
|
||||
))
|
||||
continue
|
||||
|
||||
if op.kind == "delete":
|
||||
if not source.exists():
|
||||
pending_content = writes.get(source)
|
||||
if pending_content is None and not source.exists():
|
||||
raise _PatchError(f"file to delete does not exist: {op.path}")
|
||||
if not source.is_file():
|
||||
if pending_content is None and not source.is_file():
|
||||
raise _PatchError(f"path to delete is not a file: {op.path}")
|
||||
deleted_lines = 0
|
||||
if pending_content is not None:
|
||||
deleted_lines = _text_line_count(pending_content)
|
||||
else:
|
||||
raw = source.read_bytes()
|
||||
try:
|
||||
deleted_lines = _text_line_count(raw.decode("utf-8"))
|
||||
except UnicodeDecodeError:
|
||||
deleted_lines = 0
|
||||
deletes.add(source)
|
||||
writes.pop(source, None)
|
||||
touched.append(f"delete {op.path}")
|
||||
summaries.append(_PatchSummary(
|
||||
action="delete",
|
||||
path=op.path,
|
||||
deleted=deleted_lines,
|
||||
))
|
||||
continue
|
||||
|
||||
if not source.exists():
|
||||
pending_content = writes.get(source)
|
||||
if pending_content is None and not source.exists():
|
||||
raise _PatchError(f"file to update does not exist: {op.path}")
|
||||
if not source.is_file():
|
||||
if pending_content is None and not source.is_file():
|
||||
raise _PatchError(f"path to update is not a file: {op.path}")
|
||||
raw = source.read_bytes()
|
||||
try:
|
||||
content = raw.decode("utf-8")
|
||||
except UnicodeDecodeError as exc:
|
||||
raise _PatchError(f"file to update is not UTF-8 text: {op.path}") from exc
|
||||
if pending_content is not None:
|
||||
content = pending_content
|
||||
else:
|
||||
raw = source.read_bytes()
|
||||
try:
|
||||
content = raw.decode("utf-8")
|
||||
except UnicodeDecodeError as exc:
|
||||
raise _PatchError(f"file to update is not UTF-8 text: {op.path}") from exc
|
||||
uses_crlf = "\r\n" in content
|
||||
content = content.replace("\r\n", "\n")
|
||||
new_content = _apply_hunks(op.path, content, op.hunks or [])
|
||||
added, deleted = _line_diff_stats(content, new_content)
|
||||
if uses_crlf:
|
||||
new_content = new_content.replace("\n", "\r\n")
|
||||
|
||||
target = self._resolve(op.new_path) if op.new_path else source
|
||||
if op.new_path and target.exists() and target != source:
|
||||
if op.new_path and (target.exists() or target in writes) and target != source:
|
||||
raise _PatchError(f"move target already exists: {op.new_path}")
|
||||
writes[target] = new_content
|
||||
deletes.discard(target)
|
||||
if target != source:
|
||||
deletes.add(source)
|
||||
action = f"move {op.path} -> {op.new_path}" if op.new_path else f"update {op.path}"
|
||||
touched.append(action)
|
||||
writes.pop(source, None)
|
||||
summaries.append(_PatchSummary(
|
||||
action="move" if op.new_path else "update",
|
||||
path=op.path,
|
||||
new_path=op.new_path,
|
||||
added=added,
|
||||
deleted=deleted,
|
||||
))
|
||||
|
||||
if dry_run:
|
||||
return (
|
||||
"Patch dry-run succeeded:\n"
|
||||
+ "\n".join(_format_summary(summary) for summary in summaries)
|
||||
)
|
||||
|
||||
backups: dict[Path, bytes | None] = {}
|
||||
for path in set(writes) | deletes:
|
||||
@ -332,7 +416,10 @@ class ApplyPatchTool(_FsTool):
|
||||
|
||||
for path in set(writes) | deletes:
|
||||
self._file_states.record_write(path)
|
||||
return "Patch applied:\n" + "\n".join(f"- {item}" for item in touched)
|
||||
return (
|
||||
"Patch applied:\n"
|
||||
+ "\n".join(_format_summary(summary) for summary in summaries)
|
||||
)
|
||||
except PermissionError as exc:
|
||||
return f"Error: {exc}"
|
||||
except _PatchError as exc:
|
||||
|
||||
@ -16,6 +16,8 @@ from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchem
|
||||
|
||||
DEFAULT_YIELD_MS = 1000
|
||||
MAX_YIELD_MS = 30_000
|
||||
DEFAULT_WAIT_FOR_MS = 10_000
|
||||
MAX_WAIT_FOR_MS = 120_000
|
||||
DEFAULT_MAX_OUTPUT_CHARS = 10_000
|
||||
MAX_OUTPUT_CHARS = 50_000
|
||||
|
||||
@ -356,6 +358,18 @@ def format_session_poll(session_id: str, poll: _SessionPoll) -> str:
|
||||
minimum=0,
|
||||
maximum=MAX_YIELD_MS,
|
||||
),
|
||||
wait_for=StringSchema(
|
||||
"Optional text to wait for in output before returning. "
|
||||
"Useful for interactive commands and dev servers.",
|
||||
nullable=True,
|
||||
),
|
||||
wait_timeout_ms=IntegerSchema(
|
||||
DEFAULT_WAIT_FOR_MS,
|
||||
description="Maximum milliseconds to wait for wait_for text (default 10000, max 120000).",
|
||||
minimum=0,
|
||||
maximum=MAX_WAIT_FOR_MS,
|
||||
nullable=True,
|
||||
),
|
||||
max_output_chars=IntegerSchema(
|
||||
DEFAULT_MAX_OUTPUT_CHARS,
|
||||
description="Maximum output characters to return from this poll (default 10000, max 50000).",
|
||||
@ -412,8 +426,9 @@ class WriteStdinTool(Tool):
|
||||
return (
|
||||
"Write text to a running exec session and return recent output. "
|
||||
"Use chars='' to poll without writing. Set close_stdin=true to send EOF, "
|
||||
"or terminate=true to stop the session. Sessions finish automatically "
|
||||
"when their process exits."
|
||||
"or terminate=true to stop the session. Use wait_for to keep polling "
|
||||
"until expected output appears. Sessions finish automatically when "
|
||||
"their process exits."
|
||||
)
|
||||
|
||||
async def execute(
|
||||
@ -423,6 +438,8 @@ class WriteStdinTool(Tool):
|
||||
close_stdin: bool = False,
|
||||
terminate: bool = False,
|
||||
yield_time_ms: int | None = None,
|
||||
wait_for: str | None = None,
|
||||
wait_timeout_ms: int | None = None,
|
||||
max_output_chars: int | None = None,
|
||||
max_output_tokens: int | None = None,
|
||||
**kwargs: Any,
|
||||
@ -430,18 +447,34 @@ class WriteStdinTool(Tool):
|
||||
try:
|
||||
if max_output_chars is None:
|
||||
max_output_chars = max_output_tokens
|
||||
output_limit = clamp_session_int(
|
||||
max_output_chars,
|
||||
DEFAULT_MAX_OUTPUT_CHARS,
|
||||
1000,
|
||||
MAX_OUTPUT_CHARS,
|
||||
)
|
||||
if wait_for:
|
||||
return await self._wait_for_output(
|
||||
session_id=session_id,
|
||||
chars=chars,
|
||||
close_stdin=close_stdin,
|
||||
terminate=terminate,
|
||||
wait_for=wait_for,
|
||||
wait_timeout_ms=clamp_session_int(
|
||||
wait_timeout_ms,
|
||||
DEFAULT_WAIT_FOR_MS,
|
||||
0,
|
||||
MAX_WAIT_FOR_MS,
|
||||
),
|
||||
max_output_chars=output_limit,
|
||||
)
|
||||
poll = await self._manager.write(
|
||||
session_id=session_id,
|
||||
chars=chars,
|
||||
close_stdin=close_stdin,
|
||||
terminate=terminate,
|
||||
yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS),
|
||||
max_output_chars=clamp_session_int(
|
||||
max_output_chars,
|
||||
DEFAULT_MAX_OUTPUT_CHARS,
|
||||
1000,
|
||||
MAX_OUTPUT_CHARS,
|
||||
),
|
||||
max_output_chars=output_limit,
|
||||
)
|
||||
return format_session_poll(session_id, poll)
|
||||
except KeyError:
|
||||
@ -449,6 +482,47 @@ class WriteStdinTool(Tool):
|
||||
except Exception as exc:
|
||||
return f"Error writing to exec session: {exc}"
|
||||
|
||||
async def _wait_for_output(
    self,
    *,
    session_id: str,
    chars: str | None,
    close_stdin: bool,
    terminate: bool,
    wait_for: str,
    wait_timeout_ms: int,
    max_output_chars: int,
) -> str:
    """Poll the session until *wait_for* shows up, the process ends, or time runs out.

    ``chars``, ``close_stdin`` and ``terminate`` are forwarded on the first
    write only; later iterations are plain polls. Each poll yields for at
    most 500 ms, bounded by the time left before the deadline.

    Output from all polls is accumulated so the target can match across
    poll boundaries. The returned output is capped to the last
    ``max_output_chars`` characters; without the cap a long wait could
    return many times the per-poll limit. A note is appended when earlier
    output was dropped.

    Args:
        session_id: Session to write to / poll.
        chars: Text sent on the first write (None to only poll).
        close_stdin: Forwarded on the first write only.
        terminate: Forwarded on the first write only.
        wait_for: Text to look for in the accumulated output.
        wait_timeout_ms: Overall wait budget in milliseconds.
        max_output_chars: Per-poll output limit, also used as the cap for the
            accumulated output (no cap is applied when it is not positive).

    Returns:
        The formatted poll result, followed by a
        ``Wait target not observed`` note when the target never appeared.
    """
    deadline = time.monotonic() + (wait_timeout_ms / 1000)
    # Keep at least len(wait_for) characters so a target split across two
    # polls is still found after older output has been discarded.
    keep = max(max_output_chars, len(wait_for)) if max_output_chars > 0 else 0
    buffer = ""
    total_chars = 0
    first = True

    while True:
        remaining_ms = max(0, int((deadline - time.monotonic()) * 1000))
        poll = await self._manager.write(
            session_id=session_id,
            chars=chars if first else None,
            close_stdin=close_stdin if first else False,
            terminate=terminate if first else False,
            yield_time_ms=min(500, remaining_ms),
            max_output_chars=max_output_chars,
        )
        first = False

        matched = False
        if poll.output:
            total_chars += len(poll.output)
            buffer += poll.output
            matched = wait_for in buffer
            if keep and len(buffer) > keep:
                buffer = buffer[-keep:]

        if not (matched or poll.done or remaining_ms <= 0):
            continue

        output = buffer
        if max_output_chars > 0 and len(output) > max_output_chars:
            output = output[-max_output_chars:]
        poll.output = output
        result = format_session_poll(session_id, poll)
        if not matched and wait_for not in output:
            result += f"\nWait target not observed: {wait_for!r}"
        if total_chars > len(output):
            result += (
                f"\nEarlier output omitted: showing last {len(output)} "
                f"of {total_chars} characters."
            )
        return result
|
||||
|
||||
|
||||
@tool_parameters(tool_parameters_schema())
|
||||
class ListExecSessionsTool(Tool):
|
||||
|
||||
@ -219,6 +219,8 @@ def _resolve_apply_patch_paths(
|
||||
patch = params.get("patch")
|
||||
if not isinstance(patch, str) or not patch.strip():
|
||||
return []
|
||||
if params.get("dry_run") is True:
|
||||
return []
|
||||
try:
|
||||
from nanobot.agent.tools.apply_patch import _parse_patch
|
||||
|
||||
|
||||
@ -40,9 +40,58 @@ def test_apply_patch_updates_multiple_hunks(tmp_path):
|
||||
))
|
||||
|
||||
assert "update multi.txt" in result
|
||||
assert "(+2/-2)" in result
|
||||
assert target.read_text() == "line1\nchanged2\nline3\nchanged4\n"
|
||||
|
||||
|
||||
def test_apply_patch_dry_run_validates_without_writing(tmp_path):
    """A dry run reports the planned changes but leaves the workspace untouched."""
    existing = tmp_path / "dry.txt"
    existing.write_text("before\n")
    apply_patch = ApplyPatchTool(workspace=tmp_path)
    patch_text = """*** Begin Patch
*** Update File: dry.txt
@@
-before
+after
*** Add File: added.txt
+new
*** End Patch
"""

    outcome = asyncio.run(apply_patch.execute(patch=patch_text, dry_run=True))

    for expected in (
        "Patch dry-run succeeded",
        "- update dry.txt (+1/-1)",
        "- add added.txt (+1/-0)",
    ):
        assert expected in outcome
    # Nothing on disk may have changed.
    assert existing.read_text() == "before\n"
    assert not (tmp_path / "added.txt").exists()
|
||||
|
||||
|
||||
def test_apply_patch_applies_repeated_update_sections_sequentially(tmp_path):
    """Two Update File sections for the same path are applied one after the other."""
    repeated = tmp_path / "repeat.txt"
    repeated.write_text("one\ntwo\nthree\n")
    apply_patch = ApplyPatchTool(workspace=tmp_path)
    patch_text = """*** Begin Patch
*** Update File: repeat.txt
@@
-one
+ONE
*** Update File: repeat.txt
@@
-three
+THREE
*** End Patch
"""

    outcome = asyncio.run(apply_patch.execute(patch=patch_text))

    # Each section gets its own summary line.
    assert outcome.count("update repeat.txt") == 2
    assert repeated.read_text() == "ONE\ntwo\nTHREE\n"
|
||||
|
||||
|
||||
def test_apply_patch_ignores_standard_no_newline_marker(tmp_path):
|
||||
target = tmp_path / "plain.txt"
|
||||
target.write_text("before")
|
||||
|
||||
@ -245,6 +245,65 @@ def test_write_stdin_preserves_completed_session_output_until_polled(tmp_path):
|
||||
assert "Exit code: 0" in final
|
||||
|
||||
|
||||
def test_write_stdin_can_wait_for_expected_output(tmp_path):
    """wait_for keeps polling until the expected text appears in the output."""

    async def scenario() -> tuple[str, str, str]:
        sessions = ExecSessionManager()
        runner = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=sessions)
        writer = WriteStdinTool(manager=sessions)
        # Prints "booting" at once and "ready" after 0.4 s, then stays alive.
        script = _python_command(
            "import time; print('booting', flush=True); "
            "time.sleep(0.4); print('ready', flush=True); time.sleep(5)"
        )

        started = await runner.execute(command=script, yield_time_ms=100)
        sid = _session_id(started)
        observed = await writer.execute(
            session_id=sid,
            wait_for="ready",
            wait_timeout_ms=3000,
            yield_time_ms=0,
        )
        stopped = await writer.execute(session_id=sid, terminate=True, yield_time_ms=0)
        return started, observed, stopped

    initial, waited, cleanup = asyncio.run(scenario())

    assert "Process running" in initial
    for expected in ("booting", "ready"):
        assert expected in waited
    assert "Wait target not observed" not in waited
    assert "Session terminated." in cleanup
|
||||
|
||||
|
||||
def test_write_stdin_wait_for_reports_timeout_without_killing_session(tmp_path):
    """A wait_for timeout is reported while the session keeps running."""

    async def scenario() -> tuple[str, str, str]:
        sessions = ExecSessionManager()
        runner = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=sessions)
        writer = WriteStdinTool(manager=sessions)
        # Prints "booting" and then just sleeps, so the target never shows up.
        script = _python_command(
            "import time; print('booting', flush=True); time.sleep(5)"
        )

        started = await runner.execute(command=script, yield_time_ms=100)
        sid = _session_id(started)
        observed = await writer.execute(
            session_id=sid,
            wait_for="never-ready",
            wait_timeout_ms=200,
            yield_time_ms=0,
        )
        stopped = await writer.execute(session_id=sid, terminate=True, yield_time_ms=0)
        return started, observed, stopped

    initial, waited, cleanup = asyncio.run(scenario())

    assert "Process running" in initial
    for expected in (
        "booting",
        "Process running",
        "Wait target not observed: 'never-ready'",
    ):
        assert expected in waited
    assert "Session terminated." in cleanup
|
||||
|
||||
|
||||
def test_exec_session_mode_reuses_exec_safety_guard(tmp_path):
|
||||
manager = ExecSessionManager()
|
||||
tool = ExecTool(
|
||||
|
||||
@ -125,6 +125,28 @@ def test_apply_patch_prepares_trackers_for_each_touched_file(tmp_path: Path) ->
|
||||
assert (by_path["src/delete_me.py"]["added"], by_path["src/delete_me.py"]["deleted"]) == (0, 1)
|
||||
|
||||
|
||||
def test_apply_patch_dry_run_does_not_prepare_file_edit_trackers(tmp_path: Path) -> None:
    """A dry-run apply_patch call yields no file edit trackers."""
    (tmp_path / "file.txt").write_text("old\n", encoding="utf-8")
    dry_run_params = {
        "dry_run": True,
        "patch": """*** Begin Patch
*** Update File: file.txt
@@
-old
+new
*** End Patch""",
    }

    prepared = prepare_file_edit_trackers(
        call_id="call-patch",
        tool_name="apply_patch",
        tool=None,
        workspace=tmp_path,
        params=dry_run_params,
    )

    assert prepared == []
|
||||
|
||||
|
||||
def test_oversized_write_file_end_uses_known_content_for_exact_count(tmp_path: Path) -> None:
|
||||
target = tmp_path / "large.txt"
|
||||
params = {"path": "large.txt", "content": "x" * (2 * 1024 * 1024 + 1)}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user