feat(webui): stream apply patch edit progress

This commit is contained in:
Xubin Ren 2026-05-21 15:44:01 +08:00
parent 23d5148a57
commit 722b760eae
2 changed files with 265 additions and 0 deletions

View File

@ -393,6 +393,9 @@ class StreamingFileEditTracker:
self._states[key] = state
state.apply_delta(payload)
if state.name == "apply_patch":
await self._update_apply_patch(state)
return
if state.name not in {"write_file", "edit_file"}:
return
if state.path is None:
@ -432,10 +435,62 @@ class StreamingFileEditTracker:
deleted=deleted,
)])
async def _update_apply_patch(self, state: _StreamingFileEditState) -> None:
if _json_bool_true(state.arguments, "dry_run"):
return
patch = _extract_json_string_prefix(state.arguments, "patch")
if not patch:
return
tool = self._tools.get("apply_patch") if hasattr(self._tools, "get") else None
events: list[dict[str, Any]] = []
now = time.monotonic()
for raw_path, added, deleted, delete_file in _streaming_apply_patch_stats(patch):
path = _resolve_raw_file_edit_path(tool, self._workspace, raw_path)
if path is None:
continue
file_state = state.patch_files.get(raw_path)
if file_state is None:
tracker = FileEditTracker(
call_id=state.call_id or state.key,
tool="apply_patch",
path=path,
display_path=display_file_edit_path(path, self._workspace),
before=read_file_snapshot(path),
)
file_state = _StreamingPatchFileState(tracker=tracker)
state.patch_files[raw_path] = file_state
if delete_file and added == 0 and deleted == 0 and file_state.tracker.before.countable:
deleted = _text_line_count(file_state.tracker.before.text or "")
if not file_state.should_emit(added, deleted, now):
continue
file_state.mark_emitted(added, deleted, now)
events.append(build_file_edit_live_event(
file_state.tracker,
added=added,
deleted=deleted,
))
if events:
await self._emit(events)
async def flush(self) -> None:
events: list[dict[str, Any]] = []
now = time.monotonic()
for state in self._states.values():
for file_state in state.patch_files.values():
added, deleted = file_state.last_added, file_state.last_deleted
if not file_state.emitted_once:
continue
if (
file_state.last_emitted_added == added
and file_state.last_emitted_deleted == deleted
):
continue
file_state.mark_emitted(added, deleted, now)
events.append(build_file_edit_live_event(
file_state.tracker,
added=added,
deleted=deleted,
))
if state.tracker is None:
continue
added, deleted = state.live_diff_counts()
@ -480,6 +535,10 @@ class StreamingFileEditTracker:
"""Mark streamed edits as failed when no final tool call will run."""
events: list[dict[str, Any]] = []
for state in self._states.values():
for file_state in state.patch_files.values():
if any(state.matches_final_tool_call(tool_call) for tool_call in final_tool_calls):
continue
events.append(build_file_edit_error_event(file_state.tracker, error))
if state.tracker is None:
continue
if any(state.matches_final_tool_call(tool_call) for tool_call in final_tool_calls):
@ -583,6 +642,39 @@ class _StreamingJsonStringField:
self.last_char_cr = False
@dataclass(slots=True)
class _StreamingPatchFileState:
tracker: FileEditTracker
emitted_once: bool = False
last_emitted_added: int = -1
last_emitted_deleted: int = -1
last_emit_at: float = 0.0
last_added: int = 0
last_deleted: int = 0
def should_emit(self, added: int, deleted: int, now: float) -> bool:
self.last_added = added
self.last_deleted = deleted
if not self.emitted_once:
return True
if added == self.last_emitted_added and deleted == self.last_emitted_deleted:
return False
if max(
abs(added - self.last_emitted_added),
abs(deleted - self.last_emitted_deleted),
) >= _LIVE_EMIT_LINE_STEP:
return True
return now - self.last_emit_at >= _LIVE_EMIT_INTERVAL_S
def mark_emitted(self, added: int, deleted: int, now: float) -> None:
self.emitted_once = True
self.last_added = added
self.last_deleted = deleted
self.last_emitted_added = added
self.last_emitted_deleted = deleted
self.last_emit_at = now
@dataclass(slots=True)
class _StreamingFileEditState:
key: str
@ -600,6 +692,7 @@ class _StreamingFileEditState:
new_text: _StreamingJsonStringField = field(
default_factory=lambda: _StreamingJsonStringField("new_text")
)
patch_files: dict[str, _StreamingPatchFileState] = field(default_factory=dict)
emitted_once: bool = False
last_emitted_added: int = -1
last_emitted_deleted: int = -1
@ -622,6 +715,7 @@ class _StreamingFileEditState:
self.content.reset()
self.old_text.reset()
self.new_text.reset()
self.patch_files.clear()
return
delta = payload.get("arguments_delta")
if isinstance(delta, str) and delta:
@ -681,6 +775,13 @@ class _StreamingFileEditState:
name = getattr(tool_call, "name", None)
if name != self.name:
return False
if self.name == "apply_patch":
arguments = getattr(tool_call, "arguments", None)
if not isinstance(arguments, dict):
return False
patch = arguments.get("patch")
streamed_patch = _extract_complete_json_string(self.arguments, "patch")
return isinstance(patch, str) and streamed_patch == patch
arguments = getattr(tool_call, "arguments", None)
if not isinstance(arguments, dict):
return False
@ -703,6 +804,110 @@ def _stream_key(payload: dict[str, Any]) -> str:
return ""
def _json_bool_true(source: str, key: str) -> bool:
return re.search(rf'"{re.escape(key)}"\s*:\s*true\b', source) is not None
def _extract_json_string_prefix(source: str, key: str) -> str | None:
match = re.search(rf'"{re.escape(key)}"\s*:\s*"', source)
if match is None:
return None
out: list[str] = []
i = match.end()
escape = False
while i < len(source):
ch = source[i]
if escape:
escape = False
if ch == "n":
out.append("\n")
elif ch == "r":
out.append("\r")
elif ch == "t":
out.append("\t")
elif ch == "u":
digits = source[i + 1:i + 5]
if len(digits) < 4:
break
try:
out.append(chr(int(digits, 16)))
except ValueError:
break
i += 4
else:
out.append(ch)
i += 1
continue
if ch == "\\":
escape = True
i += 1
continue
if ch == '"':
return "".join(out)
out.append(ch)
i += 1
return "".join(out)
def _streaming_apply_patch_stats(patch: str) -> list[tuple[str, int, int, bool]]:
stats: dict[str, list[Any]] = {}
order: list[str] = []
current: str | None = None
def ensure(path: str, *, delete_file: bool = False) -> list[Any]:
if path not in stats:
stats[path] = [0, 0, False]
order.append(path)
if delete_file:
stats[path][2] = True
return stats[path]
lines = patch.splitlines()
tail = ""
if patch and not patch.endswith(("\n", "\r")) and lines:
tail = lines.pop()
for line in lines:
if line.startswith("*** Add File: "):
current = line[len("*** Add File: "):].strip()
if current:
ensure(current)
continue
if line.startswith("*** Update File: "):
current = line[len("*** Update File: "):].strip()
if current:
ensure(current)
continue
if line.startswith("*** Delete File: "):
current = line[len("*** Delete File: "):].strip()
if current:
ensure(current, delete_file=True)
continue
if line.startswith("*** Move to: "):
moved = line[len("*** Move to: "):].strip()
if moved:
current = moved
ensure(current)
continue
if line.startswith("*** "):
current = None
continue
if not current:
continue
if line.startswith("+") and not line.startswith("+++"):
ensure(current)[0] += 1
elif line.startswith("-") and not line.startswith("---"):
ensure(current)[1] += 1
if current and tail:
if tail.startswith("+") and not tail.startswith("+++"):
ensure(current)[0] += 1
elif tail.startswith("-") and not tail.startswith("---"):
ensure(current)[1] += 1
return [(path, int(stats[path][0]), int(stats[path][1]), bool(stats[path][2])) for path in order]
def _extract_complete_json_string(source: str, key: str) -> str | None:
match = re.search(rf'"{re.escape(key)}"\s*:\s*"', source)
if match is None:

View File

@ -206,6 +206,66 @@ def test_streaming_write_file_tracker_emits_live_line_counts(tmp_path: Path) ->
assert events[-1]["deleted"] == 0
def test_streaming_apply_patch_tracker_emits_live_counts_per_file(tmp_path: Path) -> None:
(tmp_path / "src").mkdir()
(tmp_path / "src" / "existing.py").write_text("old\nkeep\n", encoding="utf-8")
events: list[dict] = []
async def emit(batch: list[dict]) -> None:
events.extend(batch)
async def run() -> None:
tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit)
await tracker.update({
"index": 0,
"call_id": "call-patch",
"name": "apply_patch",
"arguments_delta": (
'{"patch":"*** Begin Patch\\n'
'*** Update File: src/existing.py\\n'
'@@\\n'
'-old\\n'
'+new\\n'
' keep\\n'
'*** Add File: src/new.py\\n'
'+fresh\\n'
),
})
asyncio.run(run())
by_path = {event["path"]: event for event in events}
assert by_path["src/existing.py"]["tool"] == "apply_patch"
assert by_path["src/existing.py"]["status"] == "editing"
assert by_path["src/existing.py"]["approximate"] is True
assert (by_path["src/existing.py"]["added"], by_path["src/existing.py"]["deleted"]) == (1, 1)
assert (by_path["src/new.py"]["added"], by_path["src/new.py"]["deleted"]) == (1, 0)
def test_streaming_apply_patch_tracker_skips_dry_run(tmp_path: Path) -> None:
events: list[dict] = []
async def emit(batch: list[dict]) -> None:
events.extend(batch)
async def run() -> None:
tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit)
await tracker.update({
"index": 0,
"call_id": "call-patch",
"name": "apply_patch",
"arguments_delta": (
'{"dry_run":true,"patch":"*** Begin Patch\\n'
'*** Add File: dry.md\\n'
'+preview\\n'
),
})
asyncio.run(run())
assert events == []
def test_streaming_write_file_tracker_emits_pending_before_path(tmp_path: Path) -> None:
events: list[dict] = []