From 5f0ba05de594250525c673da358c7e9933bc76da Mon Sep 17 00:00:00 2001 From: Xubin Ren <52506698+Re-bin@users.noreply.github.com> Date: Thu, 21 May 2026 01:25:20 +0800 Subject: [PATCH] feat(tools): tighten patch and session workflows --- nanobot/agent/tools/apply_patch.py | 129 +++++++++++++++++++++---- nanobot/agent/tools/exec_session.py | 90 +++++++++++++++-- nanobot/utils/file_edit_events.py | 2 + tests/tools/test_apply_patch_tool.py | 49 ++++++++++ tests/tools/test_exec_session_tools.py | 59 +++++++++++ tests/utils/test_file_edit_events.py | 22 +++++ 6 files changed, 322 insertions(+), 29 deletions(-) diff --git a/nanobot/agent/tools/apply_patch.py b/nanobot/agent/tools/apply_patch.py index e69a65f10..57d60f9b8 100644 --- a/nanobot/agent/tools/apply_patch.py +++ b/nanobot/agent/tools/apply_patch.py @@ -10,7 +10,7 @@ from typing import Any, Literal from nanobot.agent.tools.base import tool_parameters from nanobot.agent.tools.filesystem import _FsTool -from nanobot.agent.tools.schema import StringSchema, tool_parameters_schema +from nanobot.agent.tools.schema import BooleanSchema, StringSchema, tool_parameters_schema PatchKind = Literal["add", "delete", "update"] @@ -31,6 +31,15 @@ class _PatchOp: hunks: list[_Hunk] | None = None +@dataclass(slots=True) +class _PatchSummary: + action: str + path: str + added: int = 0 + deleted: int = 0 + new_path: str | None = None + + class _PatchError(ValueError): pass @@ -65,6 +74,40 @@ def _lines_to_text(lines: list[str]) -> str: return "\n".join(lines) + "\n" +def _text_line_count(text: str) -> int: + if not text: + return 0 + return len(text.splitlines()) + + +def _line_diff_stats(before: str, after: str) -> tuple[int, int]: + before_lines = before.replace("\r\n", "\n").splitlines() + after_lines = after.replace("\r\n", "\n").splitlines() + added = 0 + deleted = 0 + matcher = difflib.SequenceMatcher(a=before_lines, b=after_lines, autojunk=False) + for tag, i1, i2, j1, j2 in matcher.get_opcodes(): + if tag == "equal": + continue + if tag in ("replace", "delete"): + deleted += i2 - i1 + if tag in ("replace", "insert"): + added += j2 - j1 + return added, deleted + + +def _format_summary(summary: _PatchSummary) -> str: + path = ( + f"{summary.path} -> {summary.new_path}" + if summary.new_path + else summary.path + ) + stats = "" + if summary.added or summary.deleted: + stats = f" (+{summary.added}/-{summary.deleted})" + return f"- {summary.action} {path}{stats}" + + def _parse_patch(patch: str) -> list[_PatchOp]: lines = patch.replace("\r\n", "\n").replace("\r", "\n").split("\n") if lines and lines[-1] == "": @@ -237,6 +280,10 @@ def _apply_hunks(path: str, content: str, hunks: list[_Hunk]) -> str: "for Add File, Update File, Delete File, and optional Move to.", min_length=1, ), + dry_run=BooleanSchema( + description="Validate and summarize the patch without writing files.", + default=False, + ), required=["patch"], ) ) @@ -254,60 +301,97 @@ class ApplyPatchTool(_FsTool): "Apply a structured patch for code edits. The patch must include " "*** Begin Patch and *** End Patch. Supports Add File, Update File, " "Delete File, and Move to. Paths must be relative. Prefer this for " - "multi-file coding changes; use edit_file for small exact replacements." + "multi-file coding changes; use edit_file for small exact replacements. " + "Set dry_run=true to validate and preview the change without writing files." ) - async def execute(self, patch: str, **kwargs: Any) -> str: + async def execute(self, patch: str, dry_run: bool = False, **kwargs: Any) -> str: try: ops = _parse_patch(patch) writes: dict[Path, str] = {} deletes: set[Path] = set() - touched: list[str] = [] + summaries: list[_PatchSummary] = [] for op in ops: source = self._resolve(op.path) if op.kind == "add": - if source.exists(): + if source.exists() or source in writes: raise _PatchError(f"file to add already exists: {op.path}") - writes[source] = _lines_to_text(op.add_lines or []) + new_content = _lines_to_text(op.add_lines or []) + writes[source] = new_content deletes.discard(source) - touched.append(f"add {op.path}") + summaries.append(_PatchSummary( + action="add", + path=op.path, + added=_text_line_count(new_content), + )) continue if op.kind == "delete": - if not source.exists(): + pending_content = writes.get(source) + if pending_content is None and not source.exists(): raise _PatchError(f"file to delete does not exist: {op.path}") - if not source.is_file(): + if pending_content is None and not source.is_file(): raise _PatchError(f"path to delete is not a file: {op.path}") + deleted_lines = 0 + if pending_content is not None: + deleted_lines = _text_line_count(pending_content) + else: + raw = source.read_bytes() + try: + deleted_lines = _text_line_count(raw.decode("utf-8")) + except UnicodeDecodeError: + deleted_lines = 0 deletes.add(source) writes.pop(source, None) - touched.append(f"delete {op.path}") + summaries.append(_PatchSummary( + action="delete", + path=op.path, + deleted=deleted_lines, + )) continue - if not source.exists(): + pending_content = writes.get(source) + if pending_content is None and not source.exists(): raise _PatchError(f"file to update does not exist: {op.path}") - if not source.is_file(): + if pending_content is None and not source.is_file(): raise _PatchError(f"path to update is not a file: {op.path}") - raw = source.read_bytes() - try: - content = raw.decode("utf-8") - except UnicodeDecodeError as exc: - raise _PatchError(f"file to update is not UTF-8 text: {op.path}") from exc + if pending_content is not None: + content = pending_content + else: + raw = source.read_bytes() + try: + content = raw.decode("utf-8") + except UnicodeDecodeError as exc: + raise _PatchError(f"file to update is not UTF-8 text: {op.path}") from exc uses_crlf = "\r\n" in content content = content.replace("\r\n", "\n") new_content = _apply_hunks(op.path, content, op.hunks or []) + added, deleted = _line_diff_stats(content, new_content) if uses_crlf: new_content = new_content.replace("\n", "\r\n") target = self._resolve(op.new_path) if op.new_path else source - if op.new_path and target.exists() and target != source: + if op.new_path and (target.exists() or target in writes) and target != source: raise _PatchError(f"move target already exists: {op.new_path}") writes[target] = new_content deletes.discard(target) if target != source: deletes.add(source) - action = f"move {op.path} -> {op.new_path}" if op.new_path else f"update {op.path}" - touched.append(action) + writes.pop(source, None) + summaries.append(_PatchSummary( + action="move" if op.new_path else "update", + path=op.path, + new_path=op.new_path, + added=added, + deleted=deleted, + )) + + if dry_run: + return ( + "Patch dry-run succeeded:\n" + + "\n".join(_format_summary(summary) for summary in summaries) + ) backups: dict[Path, bytes | None] = {} for path in set(writes) | deletes: @@ -332,7 +416,10 @@ class ApplyPatchTool(_FsTool): for path in set(writes) | deletes: self._file_states.record_write(path) - return "Patch applied:\n" + "\n".join(f"- {item}" for item in touched) + return ( + "Patch applied:\n" + + "\n".join(_format_summary(summary) for summary in summaries) + ) except PermissionError as exc: return f"Error: {exc}" except _PatchError as exc: diff --git a/nanobot/agent/tools/exec_session.py b/nanobot/agent/tools/exec_session.py index 34667aeaa..c9ca0a3d0 100644 --- a/nanobot/agent/tools/exec_session.py +++ b/nanobot/agent/tools/exec_session.py @@ -16,6 +16,8 @@ from nanobot.agent.tools.schema import BooleanSchema, IntegerSchema, StringSchem DEFAULT_YIELD_MS = 1000 MAX_YIELD_MS = 30_000 +DEFAULT_WAIT_FOR_MS = 10_000 +MAX_WAIT_FOR_MS = 120_000 DEFAULT_MAX_OUTPUT_CHARS = 10_000 MAX_OUTPUT_CHARS = 50_000 @@ -356,6 +358,18 @@ def format_session_poll(session_id: str, poll: _SessionPoll) -> str: minimum=0, maximum=MAX_YIELD_MS, ), + wait_for=StringSchema( + "Optional text to wait for in output before returning. " + "Useful for interactive commands and dev servers.", + nullable=True, + ), + wait_timeout_ms=IntegerSchema( + DEFAULT_WAIT_FOR_MS, + description="Maximum milliseconds to wait for wait_for text (default 10000, max 120000).", + minimum=0, + maximum=MAX_WAIT_FOR_MS, + nullable=True, + ), max_output_chars=IntegerSchema( DEFAULT_MAX_OUTPUT_CHARS, description="Maximum output characters to return from this poll (default 10000, max 50000).", @@ -412,8 +426,9 @@ class WriteStdinTool(Tool): return ( "Write text to a running exec session and return recent output. " "Use chars='' to poll without writing. Set close_stdin=true to send EOF, " - "or terminate=true to stop the session. Sessions finish automatically " - "when their process exits." + "or terminate=true to stop the session. Use wait_for to keep polling " + "until expected output appears. Sessions finish automatically when " + "their process exits." ) async def execute( @@ -423,6 +438,8 @@ class WriteStdinTool(Tool): close_stdin: bool = False, terminate: bool = False, yield_time_ms: int | None = None, + wait_for: str | None = None, + wait_timeout_ms: int | None = None, max_output_chars: int | None = None, max_output_tokens: int | None = None, **kwargs: Any, @@ -430,18 +447,34 @@ class WriteStdinTool(Tool): try: if max_output_chars is None: max_output_chars = max_output_tokens + output_limit = clamp_session_int( + max_output_chars, + DEFAULT_MAX_OUTPUT_CHARS, + 1000, + MAX_OUTPUT_CHARS, + ) + if wait_for: + return await self._wait_for_output( + session_id=session_id, + chars=chars, + close_stdin=close_stdin, + terminate=terminate, + wait_for=wait_for, + wait_timeout_ms=clamp_session_int( + wait_timeout_ms, + DEFAULT_WAIT_FOR_MS, + 0, + MAX_WAIT_FOR_MS, + ), + max_output_chars=output_limit, + ) poll = await self._manager.write( session_id=session_id, chars=chars, close_stdin=close_stdin, terminate=terminate, yield_time_ms=clamp_session_int(yield_time_ms, DEFAULT_YIELD_MS, 0, MAX_YIELD_MS), - max_output_chars=clamp_session_int( - max_output_chars, - DEFAULT_MAX_OUTPUT_CHARS, - 1000, - MAX_OUTPUT_CHARS, - ), + max_output_chars=output_limit, ) return format_session_poll(session_id, poll) except KeyError: @@ -449,6 +482,47 @@ class WriteStdinTool(Tool): except Exception as exc: return f"Error writing to exec session: {exc}" + async def _wait_for_output( + self, + *, + session_id: str, + chars: str | None, + close_stdin: bool, + terminate: bool, + wait_for: str, + wait_timeout_ms: int, + max_output_chars: int, + ) -> str: + deadline = time.monotonic() + (wait_timeout_ms / 1000) + aggregate: list[str] = [] + first = True + poll: _SessionPoll | None = None + + while True: + remaining_ms = max(0, int((deadline - time.monotonic()) * 1000)) + step_ms = min(500, remaining_ms) + poll = await self._manager.write( + session_id=session_id, + chars=chars if first else None, + close_stdin=close_stdin if first else False, + terminate=terminate if first else False, + yield_time_ms=step_ms, + max_output_chars=max_output_chars, + ) + first = False + if poll.output: + aggregate.append(poll.output) + joined = "".join(aggregate) + if wait_for in joined: + poll.output = joined + return format_session_poll(session_id, poll) + if poll.done or remaining_ms <= 0: + poll.output = "".join(aggregate) + result = format_session_poll(session_id, poll) + if wait_for not in poll.output: + result += f"\nWait target not observed: {wait_for!r}" + return result + @tool_parameters(tool_parameters_schema()) class ListExecSessionsTool(Tool): diff --git a/nanobot/utils/file_edit_events.py b/nanobot/utils/file_edit_events.py index c11e8ae60..acef725b0 100644 --- a/nanobot/utils/file_edit_events.py +++ b/nanobot/utils/file_edit_events.py @@ -219,6 +219,8 @@ def _resolve_apply_patch_paths( patch = params.get("patch") if not isinstance(patch, str) or not patch.strip(): return [] + if params.get("dry_run") is True: + return [] try: from nanobot.agent.tools.apply_patch import _parse_patch diff --git a/tests/tools/test_apply_patch_tool.py b/tests/tools/test_apply_patch_tool.py index ea98794b3..a356b83e8 100644 --- a/tests/tools/test_apply_patch_tool.py +++ b/tests/tools/test_apply_patch_tool.py @@ -40,9 +40,58 @@ def test_apply_patch_updates_multiple_hunks(tmp_path): )) assert "update multi.txt" in result + assert "(+2/-2)" in result assert target.read_text() == "line1\nchanged2\nline3\nchanged4\n" +def test_apply_patch_dry_run_validates_without_writing(tmp_path): + target = tmp_path / "dry.txt" + target.write_text("before\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Update File: dry.txt +@@ +-before ++after +*** Add File: added.txt ++new +*** End Patch +""", + dry_run=True, + )) + + assert "Patch dry-run succeeded" in result + assert "- update dry.txt (+1/-1)" in result + assert "- add added.txt (+1/-0)" in result + assert target.read_text() == "before\n" + assert not (tmp_path / "added.txt").exists() + + +def test_apply_patch_applies_repeated_update_sections_sequentially(tmp_path): + target = tmp_path / "repeat.txt" + target.write_text("one\ntwo\nthree\n") + tool = ApplyPatchTool(workspace=tmp_path) + + result = asyncio.run(tool.execute( + patch="""*** Begin Patch +*** Update File: repeat.txt +@@ +-one ++ONE +*** Update File: repeat.txt +@@ +-three ++THREE +*** End Patch +""" + )) + + assert result.count("update repeat.txt") == 2 + assert target.read_text() == "ONE\ntwo\nTHREE\n" + + def test_apply_patch_ignores_standard_no_newline_marker(tmp_path): target = tmp_path / "plain.txt" target.write_text("before") diff --git a/tests/tools/test_exec_session_tools.py b/tests/tools/test_exec_session_tools.py index 52f72b556..ad2506739 100644 --- a/tests/tools/test_exec_session_tools.py +++ b/tests/tools/test_exec_session_tools.py @@ -245,6 +245,65 @@ def test_write_stdin_preserves_completed_session_output_until_polled(tmp_path): assert "Exit code: 0" in final +def test_write_stdin_can_wait_for_expected_output(tmp_path): + async def run() -> tuple[str, str, str]: + manager = ExecSessionManager() + exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + command = _python_command( + "import time; print('booting', flush=True); " + "time.sleep(0.4); print('ready', flush=True); time.sleep(5)" + ) + + initial = await exec_tool.execute(command=command, yield_time_ms=100) + sid = _session_id(initial) + waited = await stdin_tool.execute( + session_id=sid, + wait_for="ready", + wait_timeout_ms=3000, + yield_time_ms=0, + ) + cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0) + return initial, waited, cleanup + + initial, waited, cleanup = asyncio.run(run()) + + assert "Process running" in initial + assert "booting" in waited + assert "ready" in waited + assert "Wait target not observed" not in waited + assert "Session terminated." in cleanup + + +def test_write_stdin_wait_for_reports_timeout_without_killing_session(tmp_path): + async def run() -> tuple[str, str, str]: + manager = ExecSessionManager() + exec_tool = ExecTool(working_dir=str(tmp_path), timeout=5, session_manager=manager) + stdin_tool = WriteStdinTool(manager=manager) + command = _python_command( + "import time; print('booting', flush=True); time.sleep(5)" + ) + + initial = await exec_tool.execute(command=command, yield_time_ms=100) + sid = _session_id(initial) + waited = await stdin_tool.execute( + session_id=sid, + wait_for="never-ready", + wait_timeout_ms=200, + yield_time_ms=0, + ) + cleanup = await stdin_tool.execute(session_id=sid, terminate=True, yield_time_ms=0) + return initial, waited, cleanup + + initial, waited, cleanup = asyncio.run(run()) + + assert "Process running" in initial + assert "booting" in waited + assert "Process running" in waited + assert "Wait target not observed: 'never-ready'" in waited + assert "Session terminated." in cleanup + + def test_exec_session_mode_reuses_exec_safety_guard(tmp_path): manager = ExecSessionManager() tool = ExecTool( diff --git a/tests/utils/test_file_edit_events.py b/tests/utils/test_file_edit_events.py index 7cc8a59fa..3ac4dc929 100644 --- a/tests/utils/test_file_edit_events.py +++ b/tests/utils/test_file_edit_events.py @@ -125,6 +125,28 @@ def test_apply_patch_prepares_trackers_for_each_touched_file(tmp_path: Path) -> assert (by_path["src/delete_me.py"]["added"], by_path["src/delete_me.py"]["deleted"]) == (0, 1) +def test_apply_patch_dry_run_does_not_prepare_file_edit_trackers(tmp_path: Path) -> None: + (tmp_path / "file.txt").write_text("old\n", encoding="utf-8") + + trackers = prepare_file_edit_trackers( + call_id="call-patch", + tool_name="apply_patch", + tool=None, + workspace=tmp_path, + params={ + "dry_run": True, + "patch": """*** Begin Patch +*** Update File: file.txt +@@ +-old ++new +*** End Patch""", + }, + ) + + assert trackers == [] + + def test_oversized_write_file_end_uses_known_content_for_exact_count(tmp_path: Path) -> None: target = tmp_path / "large.txt" params = {"path": "large.txt", "content": "x" * (2 * 1024 * 1024 + 1)}