async def _update_apply_patch(self, state: _StreamingFileEditState) -> None:
    """Emit live per-file line-count events for a streaming ``apply_patch`` call.

    Nothing is emitted when the arguments streamed so far contain
    ``"dry_run": true`` or when no patch text has arrived yet.  Otherwise the
    decoded patch prefix is summarised per file, and every file whose
    throttling state allows it contributes one live event to a single batch
    passed to ``self._emit``.
    """
    if _json_bool_true(state.arguments, "dry_run"):
        return
    patch_text = _extract_json_string_prefix(state.arguments, "patch")
    if not patch_text:
        return

    # The tool container may not expose ``get``; fall back to no tool then.
    patch_tool = None
    if hasattr(self._tools, "get"):
        patch_tool = self._tools.get("apply_patch")

    pending: list[dict[str, Any]] = []
    timestamp = time.monotonic()

    for raw_path, n_added, n_deleted, is_delete in _streaming_apply_patch_stats(patch_text):
        resolved = _resolve_raw_file_edit_path(patch_tool, self._workspace, raw_path)
        if resolved is None:
            continue

        if state.patch_files.get(raw_path) is None:
            # First sighting of this file: snapshot it before any edit lands.
            state.patch_files[raw_path] = _StreamingPatchFileState(
                tracker=FileEditTracker(
                    call_id=state.call_id or state.key,
                    tool="apply_patch",
                    path=resolved,
                    display_path=display_file_edit_path(resolved, self._workspace),
                    before=read_file_snapshot(resolved),
                ),
            )
        entry = state.patch_files[raw_path]

        # A delete section carries no +/- lines, so report the line count of
        # the pre-edit snapshot as deleted when that snapshot is countable.
        if is_delete and not (n_added or n_deleted) and entry.tracker.before.countable:
            n_deleted = _text_line_count(entry.tracker.before.text or "")

        if entry.should_emit(n_added, n_deleted, timestamp):
            entry.mark_emitted(n_added, n_deleted, timestamp)
            pending.append(
                build_file_edit_live_event(
                    entry.tracker,
                    added=n_added,
                    deleted=n_deleted,
                )
            )

    if pending:
        await self._emit(pending)
self._states.values(): + for file_state in state.patch_files.values(): + added, deleted = file_state.last_added, file_state.last_deleted + if not file_state.emitted_once: + continue + if ( + file_state.last_emitted_added == added + and file_state.last_emitted_deleted == deleted + ): + continue + file_state.mark_emitted(added, deleted, now) + events.append(build_file_edit_live_event( + file_state.tracker, + added=added, + deleted=deleted, + )) if state.tracker is None: continue added, deleted = state.live_diff_counts() @@ -480,6 +535,10 @@ class StreamingFileEditTracker: """Mark streamed edits as failed when no final tool call will run.""" events: list[dict[str, Any]] = [] for state in self._states.values(): + for file_state in state.patch_files.values(): + if any(state.matches_final_tool_call(tool_call) for tool_call in final_tool_calls): + continue + events.append(build_file_edit_error_event(file_state.tracker, error)) if state.tracker is None: continue if any(state.matches_final_tool_call(tool_call) for tool_call in final_tool_calls): @@ -583,6 +642,39 @@ class _StreamingJsonStringField: self.last_char_cr = False +@dataclass(slots=True) +class _StreamingPatchFileState: + tracker: FileEditTracker + emitted_once: bool = False + last_emitted_added: int = -1 + last_emitted_deleted: int = -1 + last_emit_at: float = 0.0 + last_added: int = 0 + last_deleted: int = 0 + + def should_emit(self, added: int, deleted: int, now: float) -> bool: + self.last_added = added + self.last_deleted = deleted + if not self.emitted_once: + return True + if added == self.last_emitted_added and deleted == self.last_emitted_deleted: + return False + if max( + abs(added - self.last_emitted_added), + abs(deleted - self.last_emitted_deleted), + ) >= _LIVE_EMIT_LINE_STEP: + return True + return now - self.last_emit_at >= _LIVE_EMIT_INTERVAL_S + + def mark_emitted(self, added: int, deleted: int, now: float) -> None: + self.emitted_once = True + self.last_added = added + self.last_deleted = 
deleted + self.last_emitted_added = added + self.last_emitted_deleted = deleted + self.last_emit_at = now + + @dataclass(slots=True) class _StreamingFileEditState: key: str @@ -600,6 +692,7 @@ class _StreamingFileEditState: new_text: _StreamingJsonStringField = field( default_factory=lambda: _StreamingJsonStringField("new_text") ) + patch_files: dict[str, _StreamingPatchFileState] = field(default_factory=dict) emitted_once: bool = False last_emitted_added: int = -1 last_emitted_deleted: int = -1 @@ -622,6 +715,7 @@ class _StreamingFileEditState: self.content.reset() self.old_text.reset() self.new_text.reset() + self.patch_files.clear() return delta = payload.get("arguments_delta") if isinstance(delta, str) and delta: @@ -681,6 +775,13 @@ class _StreamingFileEditState: name = getattr(tool_call, "name", None) if name != self.name: return False + if self.name == "apply_patch": + arguments = getattr(tool_call, "arguments", None) + if not isinstance(arguments, dict): + return False + patch = arguments.get("patch") + streamed_patch = _extract_complete_json_string(self.arguments, "patch") + return isinstance(patch, str) and streamed_patch == patch arguments = getattr(tool_call, "arguments", None) if not isinstance(arguments, dict): return False @@ -703,6 +804,110 @@ def _stream_key(payload: dict[str, Any]) -> str: return "" +def _json_bool_true(source: str, key: str) -> bool: + return re.search(rf'"{re.escape(key)}"\s*:\s*true\b', source) is not None + + +def _extract_json_string_prefix(source: str, key: str) -> str | None: + match = re.search(rf'"{re.escape(key)}"\s*:\s*"', source) + if match is None: + return None + out: list[str] = [] + i = match.end() + escape = False + while i < len(source): + ch = source[i] + if escape: + escape = False + if ch == "n": + out.append("\n") + elif ch == "r": + out.append("\r") + elif ch == "t": + out.append("\t") + elif ch == "u": + digits = source[i + 1:i + 5] + if len(digits) < 4: + break + try: + out.append(chr(int(digits, 16))) + 
except ValueError: + break + i += 4 + else: + out.append(ch) + i += 1 + continue + if ch == "\\": + escape = True + i += 1 + continue + if ch == '"': + return "".join(out) + out.append(ch) + i += 1 + return "".join(out) + + +def _streaming_apply_patch_stats(patch: str) -> list[tuple[str, int, int, bool]]: + stats: dict[str, list[Any]] = {} + order: list[str] = [] + current: str | None = None + + def ensure(path: str, *, delete_file: bool = False) -> list[Any]: + if path not in stats: + stats[path] = [0, 0, False] + order.append(path) + if delete_file: + stats[path][2] = True + return stats[path] + + lines = patch.splitlines() + tail = "" + if patch and not patch.endswith(("\n", "\r")) and lines: + tail = lines.pop() + + for line in lines: + if line.startswith("*** Add File: "): + current = line[len("*** Add File: "):].strip() + if current: + ensure(current) + continue + if line.startswith("*** Update File: "): + current = line[len("*** Update File: "):].strip() + if current: + ensure(current) + continue + if line.startswith("*** Delete File: "): + current = line[len("*** Delete File: "):].strip() + if current: + ensure(current, delete_file=True) + continue + if line.startswith("*** Move to: "): + moved = line[len("*** Move to: "):].strip() + if moved: + current = moved + ensure(current) + continue + if line.startswith("*** "): + current = None + continue + if not current: + continue + if line.startswith("+") and not line.startswith("+++"): + ensure(current)[0] += 1 + elif line.startswith("-") and not line.startswith("---"): + ensure(current)[1] += 1 + + if current and tail: + if tail.startswith("+") and not tail.startswith("+++"): + ensure(current)[0] += 1 + elif tail.startswith("-") and not tail.startswith("---"): + ensure(current)[1] += 1 + + return [(path, int(stats[path][0]), int(stats[path][1]), bool(stats[path][2])) for path in order] + + def _extract_complete_json_string(source: str, key: str) -> str | None: match = re.search(rf'"{re.escape(key)}"\s*:\s*"', 
def test_streaming_apply_patch_tracker_emits_live_counts_per_file(tmp_path: Path) -> None:
    """A streamed apply_patch payload produces a live event for each patched file."""
    source_dir = tmp_path / "src"
    source_dir.mkdir()
    (source_dir / "existing.py").write_text("old\nkeep\n", encoding="utf-8")
    received: list[dict] = []

    async def collect(batch: list[dict]) -> None:
        received.extend(batch)

    # The patch string is deliberately left unterminated, as in a live stream.
    payload = {
        "index": 0,
        "call_id": "call-patch",
        "name": "apply_patch",
        "arguments_delta": (
            '{"patch":"*** Begin Patch\\n'
            '*** Update File: src/existing.py\\n'
            '@@\\n'
            '-old\\n'
            '+new\\n'
            ' keep\\n'
            '*** Add File: src/new.py\\n'
            '+fresh\\n'
        ),
    }

    async def scenario() -> None:
        tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=collect)
        await tracker.update(payload)

    asyncio.run(scenario())

    latest = {event["path"]: event for event in received}
    existing = latest["src/existing.py"]
    assert existing["tool"] == "apply_patch"
    assert existing["status"] == "editing"
    assert existing["approximate"] is True
    assert (existing["added"], existing["deleted"]) == (1, 1)
    created = latest["src/new.py"]
    assert (created["added"], created["deleted"]) == (1, 0)


def test_streaming_apply_patch_tracker_skips_dry_run(tmp_path: Path) -> None:
    """No live events are emitted when the streamed arguments set dry_run to true."""
    received: list[dict] = []

    async def collect(batch: list[dict]) -> None:
        received.extend(batch)

    payload = {
        "index": 0,
        "call_id": "call-patch",
        "name": "apply_patch",
        "arguments_delta": (
            '{"dry_run":true,"patch":"*** Begin Patch\\n'
            '*** Add File: dry.md\\n'
            '+preview\\n'
        ),
    }

    async def scenario() -> None:
        tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=collect)
        await tracker.update(payload)

    asyncio.run(scenario())

    assert received == []