feat(tools): tighten patch and session workflows

This commit is contained in:
Xubin Ren 2026-05-21 01:25:20 +08:00
parent 480ca28a2d
commit 5f0ba05de5
6 changed files with 322 additions and 29 deletions

View File

@ -10,7 +10,7 @@ from typing import Any, Literal
from nanobot.agent.tools.base import tool_parameters
from nanobot.agent.tools.filesystem import _FsTool
from nanobot.agent.tools.schema import StringSchema, tool_parameters_schema
from nanobot.agent.tools.schema import BooleanSchema, StringSchema, tool_parameters_schema
PatchKind = Literal["add", "delete", "update"]
@ -31,6 +31,15 @@ class _PatchOp:
hunks: list[_Hunk] | None = None
@dataclass(slots=True)
class _PatchSummary:
action: str
path: str
added: int = 0
deleted: int = 0
new_path: str | None = None
class _PatchError(ValueError):
pass
@ -65,6 +74,40 @@ def _lines_to_text(lines: list[str]) -> str:
return "\n".join(lines) + "\n"
def _text_line_count(text: str) -> int:
if not text:
return 0
return len(text.splitlines())
def _line_diff_stats(before: str, after: str) -> tuple[int, int]:
before_lines = before.replace("\r\n", "\n").splitlines()
after_lines = after.replace("\r\n", "\n").splitlines()
added = 0
deleted = 0
matcher = difflib.SequenceMatcher(a=before_lines, b=after_lines, autojunk=False)
for tag, i1, i2, j1, j2 in matcher.get_opcodes():
if tag == "equal":
continue
if tag in ("replace", "delete"):
deleted += i2 - i1
if tag in ("replace", "insert"):
added += j2 - j1
return added, deleted
def _format_summary(summary: _PatchSummary) -> str:
path = (
f"{summary.path} -> {summary.new_path}"
if summary.new_path
else summary.path
)
stats = ""
if summary.added or summary.deleted:
stats = f" (+{summary.added}/-{summary.deleted})"
return f"- {summary.action} {path}{stats}"
def _parse_patch(patch: str) -> list[_PatchOp]:
lines = patch.replace("\r\n", "\n").replace("\r", "\n").split("\n")
if lines and lines[-1] == "":
@ -237,6 +280,10 @@ def _apply_hunks(path: str, content: str, hunks: list[_Hunk]) -> str:
"for Add File, Update File, Delete File, and optional Move to.",
min_length=1,
),
dry_run=BooleanSchema(
description="Validate and summarize the patch without writing files.",
default=False,
),
required=["patch"],
)
)
@ -255,39 +302,63 @@ class ApplyPatchTool(_FsTool):
"*** Begin Patch and *** End Patch. Supports Add File, Update File, "
"Delete File, and Move to. Paths must be relative. Prefer this for "
"multi-file coding changes; use edit_file for small exact replacements. "
"Set dry_run=true to validate and preview the change without writing files."
)
async def execute(self, patch: str, **kwargs: Any) -> str:
async def execute(self, patch: str, dry_run: bool = False, **kwargs: Any) -> str:
try:
ops = _parse_patch(patch)
writes: dict[Path, str] = {}
deletes: set[Path] = set()
touched: list[str] = []
summaries: list[_PatchSummary] = []
for op in ops:
source = self._resolve(op.path)
if op.kind == "add":
if source.exists():
if source.exists() or source in writes:
raise _PatchError(f"file to add already exists: {op.path}")
writes[source] = _lines_to_text(op.add_lines or [])
new_content = _lines_to_text(op.add_lines or [])
writes[source] = new_content
deletes.discard(source)
touched.append(f"add {op.path}")
summaries.append(_PatchSummary(
action="add",
path=op.path,
added=_text_line_count(new_content),
))
continue
if op.kind == "delete":
if not source.exists():
pending_content = writes.get(source)
if pending_content is None and not source.exists():
raise _PatchError(f"file to delete does not exist: {op.path}")
if not source.is_file():
if pending_content is None and not source.is_file():
raise _PatchError(f"path to delete is not a file: {op.path}")
deleted_lines = 0
if pending_content is not None:
deleted_lines = _text_line_count(pending_content)
else:
raw = source.read_bytes()
try:
deleted_lines = _text_line_count(raw.decode("utf-8"))
except UnicodeDecodeError:
deleted_lines = 0
deletes.add(source)
writes.pop(source, None)
touched.append(f"delete {op.path}")
summaries.append(_PatchSummary(
action="delete",
path=op.path,
deleted=deleted_lines,
))
continue
if not source.exists():
pending_content = writes.get(source)
if pending_content is None and not source.exists():
raise _PatchError(f"file to update does not exist: {op.path}")
if not source.is_file():
if pending_content is None and not source.is_file():
raise _PatchError(f"path to update is not a file: {op.path}")
if pending_content is not None:
content = pending_content
else:
raw = source.read_bytes()
try:
content = raw.decode("utf-8")
@ -296,18 +367,31 @@ class ApplyPatchTool(_FsTool):
uses_crlf = "\r\n" in content
content = content.replace("\r\n", "\n")
new_content = _apply_hunks(op.path, content, op.hunks or [])
added, deleted = _line_diff_stats(content, new_content)
if uses_crlf:
new_content = new_content.replace("\n", "\r\n")
target = self._resolve(op.new_path) if op.new_path else source
if op.new_path and target.exists() and target != source:
if op.new_path and (target.exists() or target in writes) and target != source:
raise _PatchError(f"move target already exists: {op.new_path}")
writes[target] = new_content
deletes.discard(target)
if target != source:
deletes.add(source)
action = f"move {op.path} -> {op.new_path}" if op.new_path else f"update {op.path}"
touched.append(action)
writes.pop(source, None)
summaries.append(_PatchSummary(
action="move" if op.new_path else "update",
path=op.path,
new_path=op.new_path,
added=added,
deleted=deleted,
))
if dry_run:
return (
"Patch dry-run succeeded:\n"
+ "\n".join(_format_summary(summary) for summary in summaries)
)
backups: dict[Path, bytes | None] = {}
for path in set(writes) | deletes:
@ -332,7 +416,10 @@ class ApplyPatchTool(_FsTool):
for path in set(writes) | deletes:
self._file_states.record_write(path)
return "Patch applied:\n" + "\n".join(f"- {item}" for item in touched)
return (
"Patch applied:\n"
+ "\n".join(_format_summary(summary) for summary in summaries)
)
except PermissionError as exc:
return f"Error: {exc}"
except _PatchError as exc:

View File

@ -16,6 +16,8 @@ from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchem
DEFAULT_YIELD_MS = 1000
MAX_YIELD_MS = 30_000
DEFAULT_WAIT_FOR_MS = 10_000
MAX_WAIT_FOR_MS = 120_000
DEFAULT_MAX_OUTPUT_CHARS = 10_000
MAX_OUTPUT_CHARS = 50_000
@ -356,6 +358,18 @@ def format_session_poll(session_id: str, poll: _SessionPoll) -> str:
minimum=0,
maximum=MAX_YIELD_MS,
),
wait_for=StringSchema(
"Optional text to wait for in output before returning. "
"Useful for interactive commands and dev servers.",
nullable=True,
),
wait_timeout_ms=IntegerSchema(
DEFAULT_WAIT_FOR_MS,
description="Maximum milliseconds to wait for wait_for text (default 10000, max 120000).",
minimum=0,
maximum=MAX_WAIT_FOR_MS,
nullable=True,
),
max_output_chars=IntegerSchema(
DEFAULT_MAX_OUTPUT_CHARS,
description="Maximum output characters to return from this poll (default 10000, max 50000).",
@ -412,8 +426,9 @@ class WriteStdinTool(Tool):
return (
"Write text to a running exec session and return recent output. "
"Use chars='' to poll without writing. Set close_stdin=true to send EOF, "
"or terminate=true to stop the session. Sessions finish automatically "
"when their process exits."
"or terminate=true to stop the session. Use wait_for to keep polling "
"until expected output appears. Sessions finish automatically when "
"their process exits."
)
async def execute(
@ -423,6 +438,8 @@ class WriteStdinTool(Tool):
close_stdin: bool = False,
terminate: bool = False,
yield_time_ms: int | None = None,
wait_for: str | None = None,
wait_timeout_ms: int | None = None,
max_output_chars: int | None = None,
max_output_tokens: int | None = None,
**kwargs: Any,
@ -430,18 +447,34 @@ class WriteStdinTool(Tool):
try:
if max_output_chars is None:
max_output_chars = max_output_tokens
output_limit = clamp_session_int(
max_output_chars,
DEFAULT_MAX_OUTPUT_CHARS,
1000,
MAX_OUTPUT_CHARS,
)
if wait_for:
return await self._wait_for_output(
session_id=session_id,
chars=chars,
close_stdin=close_stdin,
terminate=terminate,
wait_for=wait_for,
wait_timeout_ms=clamp_session_int(
wait_timeout_ms,
DEFAULT_WAIT_FOR_MS,
0,
MAX_WAIT_FOR_MS,
),
max_output_chars=output_limit,
)
poll = await self._manager.write(
session_id=session_id,
chars=chars,
close_stdin=close_stdin,
terminate=terminate,
yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS),
max_output_chars=clamp_session_int(
max_output_chars,
DEFAULT_MAX_OUTPUT_CHARS,
1000,
MAX_OUTPUT_CHARS,
),
max_output_chars=output_limit,
)
return format_session_poll(session_id, poll)
except KeyError:
@ -449,6 +482,47 @@ class WriteStdinTool(Tool):
except Exception as exc:
return f"Error writing to exec session: {exc}"
async def _wait_for_output(
self,
*,
session_id: str,
chars: str | None,
close_stdin: bool,
terminate: bool,
wait_for: str,
wait_timeout_ms: int,
max_output_chars: int,
) -> str:
deadline = time.monotonic() + (wait_timeout_ms / 1000)
aggregate: list[str] = []
first = True
poll: _SessionPoll | None = None
while True:
remaining_ms = max(0, int((deadline - time.monotonic()) * 1000))
step_ms = min(500, remaining_ms)
poll = await self._manager.write(
session_id=session_id,
chars=chars if first else None,
close_stdin=close_stdin if first else False,
terminate=terminate if first else False,
yield_time_ms=step_ms,
max_output_chars=max_output_chars,
)
first = False
if poll.output:
aggregate.append(poll.output)
joined = "".join(aggregate)
if wait_for in joined:
poll.output = joined
return format_session_poll(session_id, poll)
if poll.done or remaining_ms <= 0:
poll.output = "".join(aggregate)
result = format_session_poll(session_id, poll)
if wait_for not in poll.output:
result += f"\nWait target not observed: {wait_for!r}"
return result
@tool_parameters(tool_parameters_schema())
class ListExecSessionsTool(Tool):

View File

@ -219,6 +219,8 @@ def _resolve_apply_patch_paths(
patch = params.get("patch")
if not isinstance(patch, str) or not patch.strip():
return []
if params.get("dry_run") is True:
return []
try:
from nanobot.agent.tools.apply_patch import _parse_patch

View File

@ -40,9 +40,58 @@ def test_apply_patch_updates_multiple_hunks(tmp_path):
))
assert "update multi.txt" in result
assert "(+2/-2)" in result
assert target.read_text() == "line1\nchanged2\nline3\nchanged4\n"
def test_apply_patch_dry_run_validates_without_writing(tmp_path):
target = tmp_path / "dry.txt"
target.write_text("before\n")
tool = ApplyPatchTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
patch="""*** Begin Patch
*** Update File: dry.txt
@@
-before
+after
*** Add File: added.txt
+new
*** End Patch
""",
dry_run=True,
))
assert "Patch dry-run succeeded" in result
assert "- update dry.txt (+1/-1)" in result
assert "- add added.txt (+1/-0)" in result
assert target.read_text() == "before\n"
assert not (tmp_path / "added.txt").exists()
def test_apply_patch_applies_repeated_update_sections_sequentially(tmp_path):
target = tmp_path / "repeat.txt"
target.write_text("one\ntwo\nthree\n")
tool = ApplyPatchTool(workspace=tmp_path)
result = asyncio.run(tool.execute(
patch="""*** Begin Patch
*** Update File: repeat.txt
@@
-one
+ONE
*** Update File: repeat.txt
@@
-three
+THREE
*** End Patch
"""
))
assert result.count("update repeat.txt") == 2
assert target.read_text() == "ONE\ntwo\nTHREE\n"
def test_apply_patch_ignores_standard_no_newline_marker(tmp_path):
target = tmp_path / "plain.txt"
target.write_text("before")

View File

@ -245,6 +245,65 @@ def test_write_stdin_preserves_completed_session_output_until_polled(tmp_path):
assert "Exit code: 0" in final
def test_write_stdin_can_wait_for_expected_output(tmp_path):
async def run() -> tuple[str, str, str]:
manager = ExecSessionManager()
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
stdin_tool = WriteStdinTool(manager=manager)
command = _python_command(
"import time; print('booting', flush=True); "
"time.sleep(0.4); print('ready', flush=True); time.sleep(5)"
)
initial = await exec_tool.execute(command=command, yield_time_ms=100)
sid = _session_id(initial)
waited = await stdin_tool.execute(
session_id=sid,
wait_for="ready",
wait_timeout_ms=3000,
yield_time_ms=0,
)
cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0)
return initial, waited, cleanup
initial, waited, cleanup = asyncio.run(run())
assert "Process running" in initial
assert "booting" in waited
assert "ready" in waited
assert "Wait target not observed" not in waited
assert "Session terminated." in cleanup
def test_write_stdin_wait_for_reports_timeout_without_killing_session(tmp_path):
async def run() -> tuple[str, str, str]:
manager = ExecSessionManager()
exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager)
stdin_tool = WriteStdinTool(manager=manager)
command = _python_command(
"import time; print('booting', flush=True); time.sleep(5)"
)
initial = await exec_tool.execute(command=command, yield_time_ms=100)
sid = _session_id(initial)
waited = await stdin_tool.execute(
session_id=sid,
wait_for="never-ready",
wait_timeout_ms=200,
yield_time_ms=0,
)
cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0)
return initial, waited, cleanup
initial, waited, cleanup = asyncio.run(run())
assert "Process running" in initial
assert "booting" in waited
assert "Process running" in waited
assert "Wait target not observed: 'never-ready'" in waited
assert "Session terminated." in cleanup
def test_exec_session_mode_reuses_exec_safety_guard(tmp_path):
manager = ExecSessionManager()
tool = ExecTool(

View File

@ -125,6 +125,28 @@ def test_apply_patch_prepares_trackers_for_each_touched_file(tmp_path: Path) ->
assert (by_path["src/delete_me.py"]["added"], by_path["src/delete_me.py"]["deleted"]) == (0, 1)
def test_apply_patch_dry_run_does_not_prepare_file_edit_trackers(tmp_path: Path) -> None:
(tmp_path / "file.txt").write_text("old\n", encoding="utf-8")
trackers = prepare_file_edit_trackers(
call_id="call-patch",
tool_name="apply_patch",
tool=None,
workspace=tmp_path,
params={
"dry_run": True,
"patch": """*** Begin Patch
*** Update File: file.txt
@@
-old
+new
*** End Patch""",
},
)
assert trackers == []
def test_oversized_write_file_end_uses_known_content_for_exact_count(tmp_path: Path) -> None:
target = tmp_path / "large.txt"
params = {"path": "large.txt", "content": "x" * (2 * 1024 * 1024 + 1)}