feat(tools): tighten patch and session workflows

This commit is contained in:
Xubin Ren 2026-05-21 01:25:20 +08:00
parent 480ca28a2d
commit 5f0ba05de5
6 changed files with 322 additions and 29 deletions

View File

@ -10,7 +10,7 @@ from typing import Any, Literal
from nanobot.agent.tools.base import tool_parameters from nanobot.agent.tools.base import tool_parameters
from nanobot.agent.tools.filesystem import _FsTool from nanobot.agent.tools.filesystem import _FsTool
from nanobot.agent.tools.schema import StringSchema, tool_parameters_schema from nanobot.agent.tools.schema import BooleanSchema, StringSchema, tool_parameters_schema
PatchKind = Literal["add", "delete", "update"] PatchKind = Literal["add", "delete", "update"]
@ -31,6 +31,15 @@ class _PatchOp:
hunks: list[_Hunk] | None = None hunks: list[_Hunk] | None = None
@dataclass(slots=True)
class _PatchSummary:
action: str
path: str
added: int = 0
deleted: int = 0
new_path: str | None = None
class _PatchError(ValueError): class _PatchError(ValueError):
pass pass
@ -65,6 +74,40 @@ def _lines_to_text(lines: list[str]) -> str:
return "\n".join(lines) + "\n" return "\n".join(lines) + "\n"
def _text_line_count(text: str) -> int:
if not text:
return 0
return len(text.splitlines())
def _line_diff_stats(before: str, after: str) -> tuple[int, int]:
before_lines = before.replace("\r\n", "\n").splitlines()
after_lines = after.replace("\r\n", "\n").splitlines()
added = 0
deleted = 0
matcher = difflib.SequenceMatcher(a=before_lines, b=after_lines, autojunk=False)
for tag, i1, i2, j1, j2 in matcher.get_opcodes():
if tag == "equal":
continue
if tag in ("replace", "delete"):
deleted += i2 - i1
if tag in ("replace", "insert"):
added += j2 - j1
return added, deleted
def _format_summary(summary: _PatchSummary) -> str:
path = (
f"{summary.path} -> {summary.new_path}"
if summary.new_path
else summary.path
)
stats = ""
if summary.added or summary.deleted:
stats = f" (+{summary.added}/-{summary.deleted})"
return f"- {summary.action} {path}{stats}"
def _parse_patch(patch: str) -> list[_PatchOp]: def _parse_patch(patch: str) -> list[_PatchOp]:
lines = patch.replace("\r\n", "\n").replace("\r", "\n").split("\n") lines = patch.replace("\r\n", "\n").replace("\r", "\n").split("\n")
if lines and lines[-1] == "": if lines and lines[-1] == "":
@ -237,6 +280,10 @@ def _apply_hunks(path: str, content: str, hunks: list[_Hunk]) -> str:
"for Add File, Update File, Delete File, and optional Move to.", "for Add File, Update File, Delete File, and optional Move to.",
min_length=1, min_length=1,
), ),
dry_run=BooleanSchema(
description="Validate and summarize the patch without writing files.",
default=False,
),
required=["patch"], required=["patch"],
) )
) )
@ -254,40 +301,64 @@ class ApplyPatchTool(_FsTool):
"Apply a structured patch for code edits. The patch must include " "Apply a structured patch for code edits. The patch must include "
"*** Begin Patch and *** End Patch. Supports Add File, Update File, " "*** Begin Patch and *** End Patch. Supports Add File, Update File, "
"Delete File, and Move to. Paths must be relative. Prefer this for " "Delete File, and Move to. Paths must be relative. Prefer this for "
"multi-file coding changes; use edit_file for small exact replacements." "multi-file coding changes; use edit_file for small exact replacements. "
"Set dry_run=true to validate and preview the change without writing files."
) )
async def execute(self, patch: str, **kwargs: Any) -> str: async def execute(self, patch: str, dry_run: bool = False, **kwargs: Any) -> str:
try: try:
ops = _parse_patch(patch) ops = _parse_patch(patch)
writes: dict[Path, str] = {} writes: dict[Path, str] = {}
deletes: set[Path] = set() deletes: set[Path] = set()
touched: list[str] = [] summaries: list[_PatchSummary] = []
for op in ops: for op in ops:
source = self._resolve(op.path) source = self._resolve(op.path)
if op.kind == "add": if op.kind == "add":
if source.exists(): if source.exists() or source in writes:
raise _PatchError(f"file to add already exists: {op.path}") raise _PatchError(f"file to add already exists: {op.path}")
writes[source] = _lines_to_text(op.add_lines or []) new_content = _lines_to_text(op.add_lines or [])
writes[source] = new_content
deletes.discard(source) deletes.discard(source)
touched.append(f"add {op.path}") summaries.append(_PatchSummary(
action="add",
path=op.path,
added=_text_line_count(new_content),
))
continue continue
if op.kind == "delete": if op.kind == "delete":
if not source.exists(): pending_content = writes.get(source)
if pending_content is None and not source.exists():
raise _PatchError(f"file to delete does not exist: {op.path}") raise _PatchError(f"file to delete does not exist: {op.path}")
if not source.is_file(): if pending_content is None and not source.is_file():
raise _PatchError(f"path to delete is not a file: {op.path}") raise _PatchError(f"path to delete is not a file: {op.path}")
deleted_lines = 0
if pending_content is not None:
deleted_lines = _text_line_count(pending_content)
else:
raw = source.read_bytes()
try:
deleted_lines = _text_line_count(raw.decode("utf-8"))
except UnicodeDecodeError:
deleted_lines = 0
deletes.add(source) deletes.add(source)
writes.pop(source, None) writes.pop(source, None)
touched.append(f"delete {op.path}") summaries.append(_PatchSummary(
action="delete",
path=op.path,
deleted=deleted_lines,
))
continue continue
if not source.exists(): pending_content = writes.get(source)
if pending_content is None and not source.exists():
raise _PatchError(f"file to update does not exist: {op.path}") raise _PatchError(f"file to update does not exist: {op.path}")
if not source.is_file(): if pending_content is None and not source.is_file():
raise _PatchError(f"path to update is not a file: {op.path}") raise _PatchError(f"path to update is not a file: {op.path}")
if pending_content is not None:
content = pending_content
else:
raw = source.read_bytes() raw = source.read_bytes()
try: try:
content = raw.decode("utf-8") content = raw.decode("utf-8")
@ -296,18 +367,31 @@ class ApplyPatchTool(_FsTool):
uses_crlf = "\r\n" in content uses_crlf = "\r\n" in content
content = content.replace("\r\n", "\n") content = content.replace("\r\n", "\n")
new_content = _apply_hunks(op.path, content, op.hunks or []) new_content = _apply_hunks(op.path, content, op.hunks or [])
added, deleted = _line_diff_stats(content, new_content)
if uses_crlf: if uses_crlf:
new_content = new_content.replace("\n", "\r\n") new_content = new_content.replace("\n", "\r\n")
target = self._resolve(op.new_path) if op.new_path else source target = self._resolve(op.new_path) if op.new_path else source
if op.new_path and target.exists() and target != source: if op.new_path and (target.exists() or target in writes) and target != source:
raise _PatchError(f"move target already exists: {op.new_path}") raise _PatchError(f"move target already exists: {op.new_path}")
writes[target] = new_content writes[target] = new_content
deletes.discard(target) deletes.discard(target)
if target != source: if target != source:
deletes.add(source) deletes.add(source)
action = f"move {op.path} -> {op.new_path}" if op.new_path else f"update {op.path}" writes.pop(source, None)
touched.append(action) summaries.append(_PatchSummary(
action="move" if op.new_path else "update",
path=op.path,
new_path=op.new_path,
added=added,
deleted=deleted,
))
if dry_run:
return (
"Patch dry-run succeeded:\n"
+ "\n".join(_format_summary(summary) for summary in summaries)
)
backups: dict[Path, bytes | None] = {} backups: dict[Path, bytes | None] = {}
for path in set(writes) | deletes: for path in set(writes) | deletes:
@ -332,7 +416,10 @@ class ApplyPatchTool(_FsTool):
for path in set(writes) | deletes: for path in set(writes) | deletes:
self._file_states.record_write(path) self._file_states.record_write(path)
return "Patch applied:\n" + "\n".join(f"- {item}" for item in touched) return (
"Patch applied:\n"
+ "\n".join(_format_summary(summary) for summary in summaries)
)
except PermissionError as exc: except PermissionError as exc:
return f"Error: {exc}" return f"Error: {exc}"
except _PatchError as exc: except _PatchError as exc:

View File

@ -16,6 +16,8 @@ from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchem
DEFAULT_YIELD_MS = 1000 DEFAULT_YIELD_MS = 1000
MAX_YIELD_MS = 30_000 MAX_YIELD_MS = 30_000
DEFAULT_WAIT_FOR_MS = 10_000
MAX_WAIT_FOR_MS = 120_000
DEFAULT_MAX_OUTPUT_CHARS = 10_000 DEFAULT_MAX_OUTPUT_CHARS = 10_000
MAX_OUTPUT_CHARS = 50_000 MAX_OUTPUT_CHARS = 50_000
@ -356,6 +358,18 @@ def format_session_poll(session_id: str, poll: _SessionPoll) -> str:
minimum=0, minimum=0,
maximum=MAX_YIELD_MS, maximum=MAX_YIELD_MS,
), ),
wait_for=StringSchema(
"Optional text to wait for in output before returning. "
"Useful for interactive commands and dev servers.",
nullable=True,
),
wait_timeout_ms=IntegerSchema(
DEFAULT_WAIT_FOR_MS,
description="Maximum milliseconds to wait for wait_for text (default 10000, max 120000).",
minimum=0,
maximum=MAX_WAIT_FOR_MS,
nullable=True,
),
max_output_chars=IntegerSchema( max_output_chars=IntegerSchema(
DEFAULT_MAX_OUTPUT_CHARS, DEFAULT_MAX_OUTPUT_CHARS,
description="Maximum output characters to return from this poll (default 10000, max 50000).", description="Maximum output characters to return from this poll (default 10000, max 50000).",
@ -412,8 +426,9 @@ class WriteStdinTool(Tool):
return ( return (
"Write text to a running exec session and return recent output. " "Write text to a running exec session and return recent output. "
"Use chars='' to poll without writing. Set close_stdin=true to send EOF, " "Use chars='' to poll without writing. Set close_stdin=true to send EOF, "
"or terminate=true to stop the session. Sessions finish automatically " "or terminate=true to stop the session. Use wait_for to keep polling "
"when their process exits." "until expected output appears. Sessions finish automatically when "
"their process exits."
) )
async def execute( async def execute(
@ -423,6 +438,8 @@ class WriteStdinTool(Tool):
close_stdin: bool = False, close_stdin: bool = False,
terminate: bool = False, terminate: bool = False,
yield_time_ms: int | None = None, yield_time_ms: int | None = None,
wait_for: str | None = None,
wait_timeout_ms: int | None = None,
max_output_chars: int | None = None, max_output_chars: int | None = None,
max_output_tokens: int | None = None, max_output_tokens: int | None = None,
**kwargs: Any, **kwargs: Any,
@ -430,18 +447,34 @@ class WriteStdinTool(Tool):
try: try:
if max_output_chars is None: if max_output_chars is None:
max_output_chars = max_output_tokens max_output_chars = max_output_tokens
output_limit = clamp_session_int(
max_output_chars,
DEFAULT_MAX_OUTPUT_CHARS,
1000,
MAX_OUTPUT_CHARS,
)
if wait_for:
return await self._wait_for_output(
session_id=session_id,
chars=chars,
close_stdin=close_stdin,
terminate=terminate,
wait_for=wait_for,
wait_timeout_ms=clamp_session_int(
wait_timeout_ms,
DEFAULT_WAIT_FOR_MS,
0,
MAX_WAIT_FOR_MS,
),
max_output_chars=output_limit,
)
poll = await self._manager.write( poll = await self._manager.write(
session_id=session_id, session_id=session_id,
chars=chars, chars=chars,
close_stdin=close_stdin, close_stdin=close_stdin,
terminate=terminate, terminate=terminate,
yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS), yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS),
max_output_chars=clamp_session_int( max_output_chars=output_limit,
max_output_chars,
DEFAULT_MAX_OUTPUT_CHARS,
1000,
MAX_OUTPUT_CHARS,
),
) )
return format_session_poll(session_id, poll) return format_session_poll(session_id, poll)
except KeyError: except KeyError:
@ -449,6 +482,47 @@ class WriteStdinTool(Tool):
except Exception as exc: except Exception as exc:
return f"Error writing to exec session: {exc}" return f"Error writing to exec session: {exc}"
async def _wait_for_output(
self,
*,
session_id: str,
chars: str | None,
close_stdin: bool,
terminate: bool,
wait_for: str,
wait_timeout_ms: int,
max_output_chars: int,
) -> str:
deadline = time.monotonic() + (wait_timeout_ms / 1000)
aggregate: list[str] = []
first = True
poll: _SessionPoll | None = None
while True:
remaining_ms = max(0, int((deadline - time.monotonic()) * 1000))
step_ms = min(500, remaining_ms)
poll = await self._manager.write(
session_id=session_id,
chars=chars if first else None,
close_stdin=close_stdin if first else False,
terminate=terminate if first else False,
yield_time_ms=step_ms,
max_output_chars=max_output_chars,
)
first = False
if poll.output:
aggregate.append(poll.output)
joined = "".join(aggregate)
if wait_for in joined:
poll.output = joined
return format_session_poll(session_id, poll)
if poll.done or remaining_ms <= 0:
poll.output = "".join(aggregate)
result = format_session_poll(session_id, poll)
if wait_for not in poll.output:
result += f"\nWait target not observed: {wait_for!r}"
return result
@tool_parameters(tool_parameters_schema()) @tool_parameters(tool_parameters_schema())
class ListExecSessionsTool(Tool): class ListExecSessionsTool(Tool):

View File

@ -219,6 +219,8 @@ def _resolve_apply_patch_paths(
patch = params.get("patch") patch = params.get("patch")
if not isinstance(patch, str) or not patch.strip(): if not isinstance(patch, str) or not patch.strip():
return [] return []
if params.get("dry_run") is True:
return []
try: try:
from nanobot.agent.tools.apply_patch import _parse_patch from nanobot.agent.tools.apply_patch import _parse_patch

View File

@ -40,9 +40,58 @@ def test_apply_patch_updates_multiple_hunks(tmp_path):
)) ))
assert "update multi.txt" in result assert "update multi.txt" in result
assert "(+2/-2)" in result
assert target.read_text() == "line1\nchanged2\nline3\nchanged4\n" assert target.read_text() == "line1\nchanged2\nline3\nchanged4\n"
def test_apply_patch_dry_run_validates_without_writing(tmp_path):
target = tmp_path / "dry.txt"
target.write_text("before\n")
tool = ApplyPatchTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
patch="""*** Begin Patch
*** Update File: dry.txt
@@
-before
+after
*** Add File: added.txt
+new
*** End Patch
""",
dry_run=True,
))
assert "Patch dry-run succeeded" in result
assert "- update dry.txt (+1/-1)" in result
assert "- add added.txt (+1/-0)" in result
assert target.read_text() == "before\n"
assert not (tmp_path / "added.txt").exists()
def test_apply_patch_applies_repeated_update_sections_sequentially(tmp_path):
target = tmp_path / "repeat.txt"
target.write_text("one\ntwo\nthree\n")
tool = ApplyPatchTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
patch="""*** Begin Patch
*** Update File: repeat.txt
@@
-one
+ONE
*** Update File: repeat.txt
@@
-three
+THREE
*** End Patch
"""
))
assert result.count("update repeat.txt") == 2
assert target.read_text() == "ONE\ntwo\nTHREE\n"
def test_apply_patch_ignores_standard_no_newline_marker(tmp_path): def test_apply_patch_ignores_standard_no_newline_marker(tmp_path):
target = tmp_path / "plain.txt" target = tmp_path / "plain.txt"
target.write_text("before") target.write_text("before")

View File

@ -245,6 +245,65 @@ def test_write_stdin_preserves_completed_session_output_until_polled(tmp_path):
assert "Exit code: 0" in final assert "Exit code: 0" in final
def test_write_stdin_can_wait_for_expected_output(tmp_path):
async def run() -> tuple[str, str, str]:
manager = ExecSessionManager()
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
stdin_tool = WriteStdinTool(manager=manager)
command = _python_command(
"import time; print('booting', flush=True); "
"time.sleep(0.4); print('ready', flush=True); time.sleep(5)"
)
initial = await exec_tool.execute(command=command, yield_time_ms=100)
sid = _session_id(initial)
waited = await stdin_tool.execute(
session_id=sid,
wait_for="ready",
wait_timeout_ms=3000,
yield_time_ms=0,
)
cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0)
return initial, waited, cleanup
initial, waited, cleanup = asyncio.run(run())
assert "Process running" in initial
assert "booting" in waited
assert "ready" in waited
assert "Wait target not observed" not in waited
assert "Session terminated." in cleanup
def test_write_stdin_wait_for_reports_timeout_without_killing_session(tmp_path):
async def run() -> tuple[str, str, str]:
manager = ExecSessionManager()
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
stdin_tool = WriteStdinTool(manager=manager)
command = _python_command(
"import time; print('booting', flush=True); time.sleep(5)"
)
initial = await exec_tool.execute(command=command, yield_time_ms=100)
sid = _session_id(initial)
waited = await stdin_tool.execute(
session_id=sid,
wait_for="never-ready",
wait_timeout_ms=200,
yield_time_ms=0,
)
cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0)
return initial, waited, cleanup
initial, waited, cleanup = asyncio.run(run())
assert "Process running" in initial
assert "booting" in waited
assert "Process running" in waited
assert "Wait target not observed: 'never-ready'" in waited
assert "Session terminated." in cleanup
def test_exec_session_mode_reuses_exec_safety_guard(tmp_path): def test_exec_session_mode_reuses_exec_safety_guard(tmp_path):
manager = ExecSessionManager() manager = ExecSessionManager()
tool = ExecTool( tool = ExecTool(

View File

@ -125,6 +125,28 @@ def test_apply_patch_prepares_trackers_for_each_touched_file(tmp_path: Path) ->
assert (by_path["src/delete_me.py"]["added"], by_path["src/delete_me.py"]["deleted"]) == (0, 1) assert (by_path["src/delete_me.py"]["added"], by_path["src/delete_me.py"]["deleted"]) == (0, 1)
def test_apply_patch_dry_run_does_not_prepare_file_edit_trackers(tmp_path: Path) -> None:
(tmp_path / "file.txt").write_text("old\n", encoding="utf-8")
trackers = prepare_file_edit_trackers(
call_id="call-patch",
tool_name="apply_patch",
tool=None,
workspace=tmp_path,
params={
"dry_run": True,
"patch": """*** Begin Patch
*** Update File: file.txt
@@
-old
+new
*** End Patch""",
},
)
assert trackers == []
def test_oversized_write_file_end_uses_known_content_for_exact_count(tmp_path: Path) -> None: def test_oversized_write_file_end_uses_known_content_for_exact_count(tmp_path: Path) -> None:
target = tmp_path / "large.txt" target = tmp_path / "large.txt"
params = {"path": "large.txt", "content": "x" * (2 * 1024 * 1024 + 1)} params = {"path": "large.txt", "content": "x" * (2 * 1024 * 1024 + 1)}