mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-22 09:32:33 +00:00
feat(webui): stream apply patch edit progress
This commit is contained in:
parent
23d5148a57
commit
722b760eae
@ -393,6 +393,9 @@ class StreamingFileEditTracker:
|
||||
self._states[key] = state
|
||||
|
||||
state.apply_delta(payload)
|
||||
if state.name == "apply_patch":
|
||||
await self._update_apply_patch(state)
|
||||
return
|
||||
if state.name not in {"write_file", "edit_file"}:
|
||||
return
|
||||
if state.path is None:
|
||||
@ -432,10 +435,62 @@ class StreamingFileEditTracker:
|
||||
deleted=deleted,
|
||||
)])
|
||||
|
||||
async def _update_apply_patch(self, state: _StreamingFileEditState) -> None:
|
||||
if _json_bool_true(state.arguments, "dry_run"):
|
||||
return
|
||||
patch = _extract_json_string_prefix(state.arguments, "patch")
|
||||
if not patch:
|
||||
return
|
||||
tool = self._tools.get("apply_patch") if hasattr(self._tools, "get") else None
|
||||
events: list[dict[str, Any]] = []
|
||||
now = time.monotonic()
|
||||
for raw_path, added, deleted, delete_file in _streaming_apply_patch_stats(patch):
|
||||
path = _resolve_raw_file_edit_path(tool, self._workspace, raw_path)
|
||||
if path is None:
|
||||
continue
|
||||
file_state = state.patch_files.get(raw_path)
|
||||
if file_state is None:
|
||||
tracker = FileEditTracker(
|
||||
call_id=state.call_id or state.key,
|
||||
tool="apply_patch",
|
||||
path=path,
|
||||
display_path=display_file_edit_path(path, self._workspace),
|
||||
before=read_file_snapshot(path),
|
||||
)
|
||||
file_state = _StreamingPatchFileState(tracker=tracker)
|
||||
state.patch_files[raw_path] = file_state
|
||||
if delete_file and added == 0 and deleted == 0 and file_state.tracker.before.countable:
|
||||
deleted = _text_line_count(file_state.tracker.before.text or "")
|
||||
if not file_state.should_emit(added, deleted, now):
|
||||
continue
|
||||
file_state.mark_emitted(added, deleted, now)
|
||||
events.append(build_file_edit_live_event(
|
||||
file_state.tracker,
|
||||
added=added,
|
||||
deleted=deleted,
|
||||
))
|
||||
if events:
|
||||
await self._emit(events)
|
||||
|
||||
async def flush(self) -> None:
|
||||
events: list[dict[str, Any]] = []
|
||||
now = time.monotonic()
|
||||
for state in self._states.values():
|
||||
for file_state in state.patch_files.values():
|
||||
added, deleted = file_state.last_added, file_state.last_deleted
|
||||
if not file_state.emitted_once:
|
||||
continue
|
||||
if (
|
||||
file_state.last_emitted_added == added
|
||||
and file_state.last_emitted_deleted == deleted
|
||||
):
|
||||
continue
|
||||
file_state.mark_emitted(added, deleted, now)
|
||||
events.append(build_file_edit_live_event(
|
||||
file_state.tracker,
|
||||
added=added,
|
||||
deleted=deleted,
|
||||
))
|
||||
if state.tracker is None:
|
||||
continue
|
||||
added, deleted = state.live_diff_counts()
|
||||
@ -480,6 +535,10 @@ class StreamingFileEditTracker:
|
||||
"""Mark streamed edits as failed when no final tool call will run."""
|
||||
events: list[dict[str, Any]] = []
|
||||
for state in self._states.values():
|
||||
for file_state in state.patch_files.values():
|
||||
if any(state.matches_final_tool_call(tool_call) for tool_call in final_tool_calls):
|
||||
continue
|
||||
events.append(build_file_edit_error_event(file_state.tracker, error))
|
||||
if state.tracker is None:
|
||||
continue
|
||||
if any(state.matches_final_tool_call(tool_call) for tool_call in final_tool_calls):
|
||||
@ -583,6 +642,39 @@ class _StreamingJsonStringField:
|
||||
self.last_char_cr = False
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _StreamingPatchFileState:
|
||||
tracker: FileEditTracker
|
||||
emitted_once: bool = False
|
||||
last_emitted_added: int = -1
|
||||
last_emitted_deleted: int = -1
|
||||
last_emit_at: float = 0.0
|
||||
last_added: int = 0
|
||||
last_deleted: int = 0
|
||||
|
||||
def should_emit(self, added: int, deleted: int, now: float) -> bool:
|
||||
self.last_added = added
|
||||
self.last_deleted = deleted
|
||||
if not self.emitted_once:
|
||||
return True
|
||||
if added == self.last_emitted_added and deleted == self.last_emitted_deleted:
|
||||
return False
|
||||
if max(
|
||||
abs(added - self.last_emitted_added),
|
||||
abs(deleted - self.last_emitted_deleted),
|
||||
) >= _LIVE_EMIT_LINE_STEP:
|
||||
return True
|
||||
return now - self.last_emit_at >= _LIVE_EMIT_INTERVAL_S
|
||||
|
||||
def mark_emitted(self, added: int, deleted: int, now: float) -> None:
|
||||
self.emitted_once = True
|
||||
self.last_added = added
|
||||
self.last_deleted = deleted
|
||||
self.last_emitted_added = added
|
||||
self.last_emitted_deleted = deleted
|
||||
self.last_emit_at = now
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _StreamingFileEditState:
|
||||
key: str
|
||||
@ -600,6 +692,7 @@ class _StreamingFileEditState:
|
||||
new_text: _StreamingJsonStringField = field(
|
||||
default_factory=lambda: _StreamingJsonStringField("new_text")
|
||||
)
|
||||
patch_files: dict[str, _StreamingPatchFileState] = field(default_factory=dict)
|
||||
emitted_once: bool = False
|
||||
last_emitted_added: int = -1
|
||||
last_emitted_deleted: int = -1
|
||||
@ -622,6 +715,7 @@ class _StreamingFileEditState:
|
||||
self.content.reset()
|
||||
self.old_text.reset()
|
||||
self.new_text.reset()
|
||||
self.patch_files.clear()
|
||||
return
|
||||
delta = payload.get("arguments_delta")
|
||||
if isinstance(delta, str) and delta:
|
||||
@ -681,6 +775,13 @@ class _StreamingFileEditState:
|
||||
name = getattr(tool_call, "name", None)
|
||||
if name != self.name:
|
||||
return False
|
||||
if self.name == "apply_patch":
|
||||
arguments = getattr(tool_call, "arguments", None)
|
||||
if not isinstance(arguments, dict):
|
||||
return False
|
||||
patch = arguments.get("patch")
|
||||
streamed_patch = _extract_complete_json_string(self.arguments, "patch")
|
||||
return isinstance(patch, str) and streamed_patch == patch
|
||||
arguments = getattr(tool_call, "arguments", None)
|
||||
if not isinstance(arguments, dict):
|
||||
return False
|
||||
@ -703,6 +804,110 @@ def _stream_key(payload: dict[str, Any]) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _json_bool_true(source: str, key: str) -> bool:
|
||||
return re.search(rf'"{re.escape(key)}"\s*:\s*true\b', source) is not None
|
||||
|
||||
|
||||
def _extract_json_string_prefix(source: str, key: str) -> str | None:
|
||||
match = re.search(rf'"{re.escape(key)}"\s*:\s*"', source)
|
||||
if match is None:
|
||||
return None
|
||||
out: list[str] = []
|
||||
i = match.end()
|
||||
escape = False
|
||||
while i < len(source):
|
||||
ch = source[i]
|
||||
if escape:
|
||||
escape = False
|
||||
if ch == "n":
|
||||
out.append("\n")
|
||||
elif ch == "r":
|
||||
out.append("\r")
|
||||
elif ch == "t":
|
||||
out.append("\t")
|
||||
elif ch == "u":
|
||||
digits = source[i + 1:i + 5]
|
||||
if len(digits) < 4:
|
||||
break
|
||||
try:
|
||||
out.append(chr(int(digits, 16)))
|
||||
except ValueError:
|
||||
break
|
||||
i += 4
|
||||
else:
|
||||
out.append(ch)
|
||||
i += 1
|
||||
continue
|
||||
if ch == "\\":
|
||||
escape = True
|
||||
i += 1
|
||||
continue
|
||||
if ch == '"':
|
||||
return "".join(out)
|
||||
out.append(ch)
|
||||
i += 1
|
||||
return "".join(out)
|
||||
|
||||
|
||||
def _streaming_apply_patch_stats(patch: str) -> list[tuple[str, int, int, bool]]:
|
||||
stats: dict[str, list[Any]] = {}
|
||||
order: list[str] = []
|
||||
current: str | None = None
|
||||
|
||||
def ensure(path: str, *, delete_file: bool = False) -> list[Any]:
|
||||
if path not in stats:
|
||||
stats[path] = [0, 0, False]
|
||||
order.append(path)
|
||||
if delete_file:
|
||||
stats[path][2] = True
|
||||
return stats[path]
|
||||
|
||||
lines = patch.splitlines()
|
||||
tail = ""
|
||||
if patch and not patch.endswith(("\n", "\r")) and lines:
|
||||
tail = lines.pop()
|
||||
|
||||
for line in lines:
|
||||
if line.startswith("*** Add File: "):
|
||||
current = line[len("*** Add File: "):].strip()
|
||||
if current:
|
||||
ensure(current)
|
||||
continue
|
||||
if line.startswith("*** Update File: "):
|
||||
current = line[len("*** Update File: "):].strip()
|
||||
if current:
|
||||
ensure(current)
|
||||
continue
|
||||
if line.startswith("*** Delete File: "):
|
||||
current = line[len("*** Delete File: "):].strip()
|
||||
if current:
|
||||
ensure(current, delete_file=True)
|
||||
continue
|
||||
if line.startswith("*** Move to: "):
|
||||
moved = line[len("*** Move to: "):].strip()
|
||||
if moved:
|
||||
current = moved
|
||||
ensure(current)
|
||||
continue
|
||||
if line.startswith("*** "):
|
||||
current = None
|
||||
continue
|
||||
if not current:
|
||||
continue
|
||||
if line.startswith("+") and not line.startswith("+++"):
|
||||
ensure(current)[0] += 1
|
||||
elif line.startswith("-") and not line.startswith("---"):
|
||||
ensure(current)[1] += 1
|
||||
|
||||
if current and tail:
|
||||
if tail.startswith("+") and not tail.startswith("+++"):
|
||||
ensure(current)[0] += 1
|
||||
elif tail.startswith("-") and not tail.startswith("---"):
|
||||
ensure(current)[1] += 1
|
||||
|
||||
return [(path, int(stats[path][0]), int(stats[path][1]), bool(stats[path][2])) for path in order]
|
||||
|
||||
|
||||
def _extract_complete_json_string(source: str, key: str) -> str | None:
|
||||
match = re.search(rf'"{re.escape(key)}"\s*:\s*"', source)
|
||||
if match is None:
|
||||
|
||||
@ -206,6 +206,66 @@ def test_streaming_write_file_tracker_emits_live_line_counts(tmp_path: Path) ->
|
||||
assert events[-1]["deleted"] == 0
|
||||
|
||||
|
||||
def test_streaming_apply_patch_tracker_emits_live_counts_per_file(tmp_path: Path) -> None:
|
||||
(tmp_path / "src").mkdir()
|
||||
(tmp_path / "src" / "existing.py").write_text("old\nkeep\n", encoding="utf-8")
|
||||
events: list[dict] = []
|
||||
|
||||
async def emit(batch: list[dict]) -> None:
|
||||
events.extend(batch)
|
||||
|
||||
async def run() -> None:
|
||||
tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit)
|
||||
await tracker.update({
|
||||
"index": 0,
|
||||
"call_id": "call-patch",
|
||||
"name": "apply_patch",
|
||||
"arguments_delta": (
|
||||
'{"patch":"*** Begin Patch\\n'
|
||||
'*** Update File: src/existing.py\\n'
|
||||
'@@\\n'
|
||||
'-old\\n'
|
||||
'+new\\n'
|
||||
' keep\\n'
|
||||
'*** Add File: src/new.py\\n'
|
||||
'+fresh\\n'
|
||||
),
|
||||
})
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
by_path = {event["path"]: event for event in events}
|
||||
assert by_path["src/existing.py"]["tool"] == "apply_patch"
|
||||
assert by_path["src/existing.py"]["status"] == "editing"
|
||||
assert by_path["src/existing.py"]["approximate"] is True
|
||||
assert (by_path["src/existing.py"]["added"], by_path["src/existing.py"]["deleted"]) == (1, 1)
|
||||
assert (by_path["src/new.py"]["added"], by_path["src/new.py"]["deleted"]) == (1, 0)
|
||||
|
||||
|
||||
def test_streaming_apply_patch_tracker_skips_dry_run(tmp_path: Path) -> None:
|
||||
events: list[dict] = []
|
||||
|
||||
async def emit(batch: list[dict]) -> None:
|
||||
events.extend(batch)
|
||||
|
||||
async def run() -> None:
|
||||
tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit)
|
||||
await tracker.update({
|
||||
"index": 0,
|
||||
"call_id": "call-patch",
|
||||
"name": "apply_patch",
|
||||
"arguments_delta": (
|
||||
'{"dry_run":true,"patch":"*** Begin Patch\\n'
|
||||
'*** Add File: dry.md\\n'
|
||||
'+preview\\n'
|
||||
),
|
||||
})
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
assert events == []
|
||||
|
||||
|
||||
def test_streaming_write_file_tracker_emits_pending_before_path(tmp_path: Path) -> None:
|
||||
events: list[dict] = []
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user